diff --git a/chunker.py b/chunker.py
index 589e348..7fef7a6 100644
--- a/chunker.py
+++ b/chunker.py
@@ -8,6 +8,7 @@ from chromadb.utils.embedding_functions.openai_embedding_function import (
     OpenAIEmbeddingFunction,
 )
 from dotenv import load_dotenv
+from llm import LLMClient
 
 USE_OPENAI = os.getenv("OPENAI_API_KEY") != None
 
@@ -91,9 +92,10 @@ class Chunker:
 
     def __init__(self, collection) -> None:
         self.collection = collection
+        self.llm_client = LLMClient()
 
     def embedding_fx(self, inputs):
-        if USE_OPENAI:
+        if self.llm_client.PROVIDER == "openai":
             openai_embedding_fx = OpenAIEmbeddingFunction(
                 api_key=os.getenv("OPENAI_API_KEY"),
                 model_name="text-embedding-3-small",
diff --git a/llm.py b/llm.py
index 003b667..bc4167b 100644
--- a/llm.py
+++ b/llm.py
@@ -16,7 +16,7 @@ class LLMClient:
             self.ollama_client = Client(
                 host=os.getenv("OLLAMA_URL", "http://localhost:11434")
             )
-            client.chat(
+            self.ollama_client.chat(
                 model="gemma3:4b", messages=[{"role": "system", "content": "test"}]
             )
             self.PROVIDER = "ollama"
@@ -35,9 +35,16 @@
         if self.PROVIDER == "ollama":
             response = self.ollama_client.chat(
                 model="gemma3:4b",
-                prompt=prompt,
+                messages=[
+                    {
+                        "role": "system",
+                        "content": system_prompt,
+                    },
+                    {"role": "user", "content": prompt},
+                ],
             )
-            output = response["response"]
+            print(response)
+            output = response.message.content
         elif self.PROVIDER == "openai":
             response = self.openai_client.responses.create(
                 model="gpt-4o-mini",
@@ -51,6 +58,8 @@
             )
             output = response.output_text
 
+        return output
+
 
 if __name__ == "__main__":
     client = Client()
diff --git a/main.py b/main.py
index 8ec4f34..b93fa95 100644
--- a/main.py
+++ b/main.py
@@ -1,6 +1,7 @@
 import datetime
 import logging
 import os
+import sqlite3
 from typing import Any, Union
 
 import argparse
@@ -15,6 +16,7 @@
 from query import QueryGenerator
 from cleaner import pdf_to_image, summarize_pdf_image
 from llm import LLMClient
+
 from dotenv import load_dotenv
 
 load_dotenv()
@@ -25,7 +27,7 @@ USE_OPENAI = os.getenv("OPENAI_API_KEY") != None
 
 ollama_client = ollama.Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434"))
 client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
-simba_docs = client.get_or_create_collection(name="simba_docs2")
+simba_docs = client.get_or_create_collection(name="simba_docs3")
 feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup")
 
 parser = argparse.ArgumentParser(
@@ -76,15 +78,22 @@ def chunk_data(docs: list[dict[str, Union[str, Any]]], collection):
     print(f"chunking {len(docs)} documents")
 
     texts: list[str] = [doc["content"] for doc in docs]
-    for index, text in enumerate(texts):
-        metadata = {
-            "created_date": date_to_epoch(docs[index]["created_date"]),
-            "filename": docs[index]["original_file_name"],
-        }
-        chunker.chunk_document(
-            document=text,
-            metadata=metadata,
-        )
+    with sqlite3.connect("visited.db") as conn:
+        to_insert = []
+        c = conn.cursor()
+        for index, text in enumerate(texts):
+            metadata = {
+                "created_date": date_to_epoch(docs[index]["created_date"]),
+                "filename": docs[index]["original_file_name"],
+            }
+            chunker.chunk_document(
+                document=text,
+                metadata=metadata,
+            )
+            to_insert.append((docs[index]["id"],))
+
+        c.executemany("INSERT INTO indexed_documents (paperless_id) values (?)", to_insert)
+
 
 
 def chunk_text(texts: list[str], collection):
@@ -160,6 +169,18 @@ def consult_simba_oracle(input: str):
         collection=simba_docs,
     )
 
 
+def filter_indexed_files(docs):
+    with sqlite3.connect("visited.db") as conn:
+        c = conn.cursor()
+        c.execute("CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)")
+        c.execute("SELECT paperless_id FROM indexed_documents")
+        rows = c.fetchall()
+        conn.commit()
+
+        visited = {row[0] for row in rows}
+        return [doc for doc in docs if doc["id"] not in visited]
+
+
 if __name__ == "__main__":
     args = parser.parse_args()
@@ -167,6 +188,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     print("Fetching documents from Paperless-NGX")
     ppngx = PaperlessNGXService()
     docs = ppngx.get_data()
+    docs = filter_indexed_files(docs)
     print(f"Fetched {len(docs)} documents")
     # print("Chunking documents now ...")