Use an explicit Ollama embedding function for indexing and querying.

Each chunk is embedded with OllamaEmbeddingFunction (mxbai-embed-large) and
the vector is passed to collection.add alongside the chunk text. The question
is embedded with the same model and the collection is queried by that vector.

Review fix: the query now passes only query_embeddings. The question is
already embedded on the line above, so also passing query_texts was
redundant, and chromadb releases that require exactly one query input raise
ValueError when both are given (TODO confirm against the pinned version).

NOTE(review): line breaks and indentation were restored from a
whitespace-collapsed copy of this patch; verify the context lines against
main.py before applying. The post-image hash on the index line predates the
review fix.

diff --git a/main.py b/main.py
index ff890b2..9069559 100644
--- a/main.py
+++ b/main.py
@@ -7,6 +7,10 @@ from math import ceil
 
 import chromadb
 
+from chromadb.utils.embedding_functions.ollama_embedding_function import (
+    OllamaEmbeddingFunction,
+)
+
 client = chromadb.EphemeralClient()
 collection = client.create_collection(name="docs")
 
@@ -29,6 +33,11 @@ class Chunk:
 
 class Chunker:
     def __init__(self) -> None:
+        self.embedding_fx = OllamaEmbeddingFunction(
+            url="http://localhost:11434",
+            model_name="mxbai-embed-large",
+        )
+
         pass
 
     def chunk_document(self, document: str, chunk_size: int = 300) -> list[Chunk]:
@@ -47,15 +56,20 @@ class Chunker:
             )
 
             text_chunk = document[curr_pos:to_pos]
+            embedding = self.embedding_fx([text_chunk])
             collection.add(
                 ids=[str(doc_uuid) + ":" + str(i)],
                 documents=[text_chunk],
+                embeddings=embedding,
             )
 
         return chunks
 
 
-# Setup
+embedding_fx = OllamaEmbeddingFunction(
+    url="http://localhost:11434",
+    model_name="mxbai-embed-large",
+)
 
 # Step 1: Get the text
 ppngx = PaperlessNGXService()
@@ -71,8 +85,8 @@ for text in texts:
 
 # Ask
 input = "How many teeth has Simba had removed?"
-response = ollama.embed(model="mxbai-embed-large", input=input)
-results = collection.query(query_texts=[input], n_results=1)
+embeddings = embedding_fx(input=[input])
+results = collection.query(query_embeddings=embeddings)
 print(results)
 # Generate
 output = ollama.generate(