main.py
import os

import chromadb
import pdfplumber
from langchain_ollama import OllamaEmbeddings, OllamaLLM

# Define the LLM model to be used
llm_model = "llama3.2"

# Configure ChromaDB
# Initialize the ChromaDB client with persistent storage in the current directory
chroma_client = chromadb.PersistentClient(path=os.path.join(os.getcwd(), "chroma_db"))


# Define a custom embedding function for ChromaDB using Ollama
class ChromaDBEmbeddingFunction:
    """
    Custom embedding function for ChromaDB using embeddings from Ollama.
    """

    def __init__(self, langchain_embeddings):
        self.langchain_embeddings = langchain_embeddings

    def __call__(self, input):
        # Ensure the input is in a list format for processing
        if isinstance(input, str):
            input = [input]
        return self.langchain_embeddings.embed_documents(input)


# Initialize the embedding function with Ollama embeddings
embedding = ChromaDBEmbeddingFunction(
    OllamaEmbeddings(
        model=llm_model,
        base_url="http://localhost:11434",  # Adjust the base URL as per your Ollama server configuration
    )
)
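# A quick sanity check of the embedding callable (illustrative; the example
# strings are hypothetical, and this assumes the Ollama server above is
# reachable with the llama3.2 model pulled). ChromaDB invokes the function
# the same way: a list of strings in, one vector per string out.
# vectors = embedding(["hello world", "retrieval-augmented generation"])
# print(len(vectors), len(vectors[0]))  # 2 vectors, each of fixed dimension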
# Define a collection for the RAG workflow
collection_name = "rag_collection_demo_1"
collection = chroma_client.get_or_create_collection(
    name=collection_name,
    metadata={"description": "A collection for RAG with Ollama - Demo1"},
    embedding_function=embedding  # Use the custom embedding function
)


# Function to add documents to the ChromaDB collection
def add_documents_to_collection(documents, ids):
    """
    Add documents to the ChromaDB collection.

    Args:
        documents (list of str): The documents to add.
        ids (list of str): Unique IDs for the documents.
    """
    collection.add(
        documents=documents,
        ids=ids
    )
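# Illustrative usage (the documents and IDs below are hypothetical; each ID
# must be unique within the collection, or ChromaDB will treat it as the
# same record):
# add_documents_to_collection(
#     ["First sample document.", "Second sample document."],
#     ["doc_1", "doc_2"],
# )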
# Function to extract text from a PDF file
def extract_text_from_pdf(pdf_path):
    """
    Extract text from a PDF file.

    Args:
        pdf_path (str): The path to the PDF file.

    Returns:
        str: The extracted text.
    """
    with pdfplumber.open(pdf_path) as pdf:
        text = ""
        for page in pdf.pages:
            # extract_text() returns None for pages with no extractable text,
            # so guard against concatenating None onto the string
            page_text = page.extract_text()
            if page_text:
                text += page_text
    return text
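# Illustrative usage (the file name is hypothetical):
# print(extract_text_from_pdf("sample.pdf")[:200])  # first 200 characters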
# Function to process all PDF files in the data directory
def process_pdfs_in_directory(directory_path):
    """
    Process all PDF files in the specified directory.

    Args:
        directory_path (str): The path to the directory containing the PDF files.
    """
    # Match the extension case-insensitively so files like report.PDF are included
    pdf_files = [f for f in os.listdir(directory_path) if f.lower().endswith('.pdf')]
    for pdf_file in pdf_files:
        pdf_path = os.path.join(directory_path, pdf_file)
        pdf_text = extract_text_from_pdf(pdf_path)
        if pdf_text:
            # Add the extracted text to ChromaDB
            doc_id = f"pdf_doc_{pdf_file}"
            documents = [pdf_text]
            add_documents_to_collection(documents, [doc_id])
            print(f"Text from {pdf_file} has been added to ChromaDB.")
        else:
            print(f"Failed to extract text from {pdf_file}.")


# Define the PDF directory path (assuming PDFs are located in the 'data' directory)
pdf_directory = os.path.join(os.getcwd(), "data")

# Process all PDFs in the data directory
process_pdfs_in_directory(pdf_directory)
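# Expected layout (the file names are hypothetical): each PDF under ./data
# becomes one ChromaDB document, keyed by its file name.
# ./data/report.pdf  ->  id "pdf_doc_report.pdf"
# ./data/notes.pdf   ->  id "pdf_doc_notes.pdf"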
# Function to query the ChromaDB collection
def query_chromadb(query_text, n_results=1):
    """
    Query the ChromaDB collection for relevant documents.

    Args:
        query_text (str): The input query.
        n_results (int): The number of top results to return.

    Returns:
        tuple: The top matching documents and their metadata, each as a
        list of lists (one inner list per query text).
    """
    results = collection.query(
        query_texts=[query_text],
        n_results=n_results
    )
    return results["documents"], results["metadatas"]
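# Illustrative usage (the query string is hypothetical). ChromaDB returns one
# inner list per query text, so docs[0] holds the matches for our single query:
# docs, metas = query_chromadb("What topics do the PDFs cover?", n_results=2)
# print(docs[0])  # up to 2 best-matching documents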
# Function to interact with the Ollama LLM
def query_ollama(prompt):
    """
    Send a query to Ollama and retrieve the response.

    Args:
        prompt (str): The input prompt for Ollama.

    Returns:
        str: The response from Ollama.
    """
    llm = OllamaLLM(model=llm_model)
    return llm.invoke(prompt)
# RAG pipeline: Combine ChromaDB and Ollama for Retrieval-Augmented Generation
def rag_pipeline(query_text):
    """
    Perform Retrieval-Augmented Generation (RAG) by combining ChromaDB and Ollama.

    Args:
        query_text (str): The input query.

    Returns:
        str: The generated response from Ollama, augmented with retrieved context.
    """
    # Step 1: Retrieve relevant documents from ChromaDB
    retrieved_docs, metadata = query_chromadb(query_text)
    context = " ".join(retrieved_docs[0]) if retrieved_docs else "No relevant documents found."

    # Step 2: Send the query along with the context to Ollama
    augmented_prompt = f"Context: {context}\n\nQuestion: {query_text}\nAnswer:"
    print("######## Augmented Prompt ########")
    print(augmented_prompt)

    response = query_ollama(augmented_prompt)
    return response
# Example usage
# Define a query to test the RAG pipeline
query = "I want to build a ROP chain in ARM assembly that initiates a callback to the download server. How do I go about this?"  # Change the query as needed
response = rag_pipeline(query)
print("######## Response from LLM ########\n", response)