diff --git a/main.py b/main.py index ff890b2..e911e69 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ import ollama +import os from uuid import uuid4, UUID from request import PaperlessNGXService @@ -7,9 +8,17 @@ from math import ceil import chromadb +from chromadb.utils.embedding_functions.ollama_embedding_function import ( + OllamaEmbeddingFunction, +) + +from dotenv import load_dotenv + client = chromadb.EphemeralClient() collection = client.create_collection(name="docs") +load_dotenv() + class Chunk: def __init__( @@ -29,6 +38,11 @@ class Chunk: class Chunker: def __init__(self) -> None: + self.embedding_fx = OllamaEmbeddingFunction( + url=os.getenv("OLLAMA_URL", ""), + model_name="mxbai-embed-large", + ) + pass def chunk_document(self, document: str, chunk_size: int = 300) -> list[Chunk]: @@ -47,15 +61,20 @@ class Chunker: ) text_chunk = document[curr_pos:to_pos] + embedding = self.embedding_fx([text_chunk]) collection.add( ids=[str(doc_uuid) + ":" + str(i)], documents=[text_chunk], + embeddings=embedding, ) return chunks -# Setup +embedding_fx = OllamaEmbeddingFunction( + url=os.getenv("OLLAMA_URL", ""), + model_name="mxbai-embed-large", +) # Step 1: Get the text ppngx = PaperlessNGXService() @@ -70,9 +89,9 @@ for text in texts: chunker.chunk_document(document=text) # Ask -input = "How many teeth has Simba had removed?" -response = ollama.embed(model="mxbai-embed-large", input=input) -results = collection.query(query_texts=[input], n_results=1) +input = "How many teeth has Simba had removed? Who is his current vet?" +embeddings = embedding_fx(input=[input]) +results = collection.query(query_texts=[input], query_embeddings=embeddings) print(results) # Generate output = ollama.generate( diff --git a/request.py b/request.py index 7770fcf..229d619 100644 --- a/request.py +++ b/request.py @@ -17,3 +17,8 @@ class PaperlessNGXService: print(f"Getting data from: {self.url}") r = httpx.get(self.url, headers=self.headers) return r.json()["results"] + + +if __name__ == "__main__": + pp = PaperlessNGXService() + print(pp.get_data()[0].keys())