diff --git a/main.py b/main.py
index 20b7035..6f1b095 100644
--- a/main.py
+++ b/main.py
@@ -2,7 +2,6 @@ import datetime
 import logging
 import os
 import sqlite3
-from typing import Any, Union
 
 import argparse
 import chromadb
@@ -13,11 +12,12 @@ from request import PaperlessNGXService
 from chunker import Chunker
 from cleaner import pdf_to_image, summarize_pdf_image
 from llm import LLMClient
+from query import QueryGenerator
 
 from dotenv import load_dotenv
 
-load_dotenv()
+_dotenv_loaded = load_dotenv()
 
 
 # Configure ollama client with URL from environment or default to localhost
 ollama_client = ollama.Client(
@@ -43,18 +43,18 @@ ppngx = PaperlessNGXService()
 llm_client = LLMClient()
 
 
-def index_using_pdf_llm():
+def index_using_pdf_llm(doctypes):
     logging.info("reindex data...")
     files = ppngx.get_data()
     for file in files:
-        document_id = file["id"]
+        document_id: int = file["id"]
         pdf_path = ppngx.download_pdf_from_id(id=document_id)
         image_paths = pdf_to_image(filepath=pdf_path)
         print(f"summarizing {file}")
         generated_summary = summarize_pdf_image(filepaths=image_paths)
         file["content"] = generated_summary
 
-    chunk_data(files, simba_docs)
+    chunk_data(files, simba_docs, doctypes=doctypes)
 
 
 def date_to_epoch(date_str: str) -> float:
@@ -71,7 +71,7 @@ def date_to_epoch(date_str: str) -> float:
     return date.timestamp()
 
 
-def chunk_data(docs: list[dict[str, Union[str, Any]]], collection, doctypes):
+def chunk_data(docs, collection, doctypes):
     # Step 2: Create chunks
     chunker = Chunker(collection)
 
@@ -121,13 +121,15 @@ def consult_oracle(input: str, collection):
     start_time = time.time()
 
     # Ask
-    # print("Starting query generation")
-    # qg_start = time.time()
-    # qg = QueryGenerator()
+    print("Starting query generation")
+    qg_start = time.time()
+    qg = QueryGenerator()
+    doctype_query = qg.get_doctype_query(input=input)
     # metadata_filter = qg.get_query(input)
-    # qg_end = time.time()
-    # print(f"Query generation took {qg_end - qg_start:.2f} seconds")
-    # print(metadata_filter)
+    metadata_filter = {**doctype_query}
+    print(metadata_filter)
+    qg_end = time.time()
+    print(f"Query generation took {qg_end - qg_start:.2f} seconds")
 
     print("Starting embedding generation")
     embedding_start = time.time()
@@ -140,8 +142,9 @@ def consult_oracle(input: str, collection):
     results = collection.query(
         query_texts=[input],
         query_embeddings=embeddings,
-        # where=metadata_filter,
+        where=metadata_filter,
     )
+    print(results)
     query_end = time.time()
     print(f"Collection query took {query_end - query_start:.2f} seconds")
 
@@ -193,18 +196,28 @@ def filter_indexed_files(docs):
 if __name__ == "__main__":
     args = parser.parse_args()
     if args.reindex:
+        with sqlite3.connect("./visited.db") as conn:
+            c = conn.cursor()
+            c.execute("DELETE FROM indexed_documents")
+
         print("Fetching documents from Paperless-NGX")
         ppngx = PaperlessNGXService()
         docs = ppngx.get_data()
         docs = filter_indexed_files(docs)
         print(f"Fetched {len(docs)} documents")
-        #
+
+        # Delete all chromadb data
+        ids = simba_docs.get(ids=None, limit=None, offset=0)
+        all_ids = ids["ids"]
+        if len(all_ids) > 0:
+            simba_docs.delete(ids=all_ids)
+
+        # Chunk documents
         print("Chunking documents now ...")
         tag_lookup = ppngx.get_tags()
         doctype_lookup = ppngx.get_doctypes()
        chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
         print("Done chunking documents")
-        # index_using_pdf_llm()
 
     # if args.index:
     #     with open(args.index) as file:
diff --git a/query.py b/query.py
index 6dd53c6..0435a62 100644
--- a/query.py
+++ b/query.py
@@ -9,7 +9,9 @@ from openai import OpenAI
 from pydantic import BaseModel, Field
 
 # Configure ollama client with URL from environment or default to localhost
-ollama_client = Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434"))
+ollama_client = Client(
+    host=os.getenv("OLLAMA_URL", "http://localhost:11434"), timeout=10.0
+)
 
 
 # This uses inferred filters — which means using LLM to create the metadata filters
@@ -53,8 +55,8 @@ class DocumentType(BaseModel):
 
 
 PROMPT = """
-You are an information specialist that processes user queries. The current year is 2025. The user queries are all about 
-a cat, Simba, and its records. The types of records are listed below. Using the query, extract the 
+You are an information specialist that processes user queries. The current year is 2025. The user queries are all about
+a cat, Simba, and its records. The types of records are listed below. Using the query, extract the
 the date range the user is trying to query. You should return it as a JSON. The date tag is created_date.
 Return the date in epoch time. If the created_date cannot be ascertained, set it to epoch time start.
@@ -97,7 +99,17 @@ Only return the extracted metadata fields. Make sure the extracted metadata fiel
 """
 
 
-DOCTYPE_PROMPT = f"You are an information specialist that processes user queries. A query can have two tags attached from the following options. Based on the query, determine which of the following options is most appropriate: {','.join(DOCTYPE_OPTIONS)}"
+DOCTYPE_PROMPT = f"""You are an information specialist that processes user queries. A query can have two tags attached from the following options. Based on the query, determine which of the following options is most appropriate: {",".join(DOCTYPE_OPTIONS)}
+
+### Example 1
+Query: "Who is Simba's current vet?"
+Tags: ["Bill", "Medical Record"]
+
+
+### Example 2
+Query: "Who does Simba know?"
+Tags: ["Letter", "Documentation"]
+"""
 
 
 class QueryGenerator:
@@ -118,7 +130,6 @@ class QueryGenerator:
         return date.timestamp()
 
     def get_doctype_query(self, input: str):
-        print(DOCTYPE_PROMPT)
         client = OpenAI()
         response = client.chat.completions.create(
             messages=[
@@ -140,8 +151,8 @@ class QueryGenerator:
         response_json_str = response.choices[0].message.content
         type_data = json.loads(response_json_str)
 
-        print(type_data)
-        return type_data
+        metadata_query = {"document_type": {"$in": type_data["type"]}}
+        return metadata_query
 
     def get_query(self, input: str):
         client = OpenAI()
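
Reviewer note: a minimal sketch (not part of the patch) of the filter path this change introduces. Per the diff, get_doctype_query() now returns a Chroma-style "$in" filter keyed on document_type, and consult_oracle() forwards it unchanged via where=. The collection name, tag values, and query text below are illustrative assumptions only.

# Sketch only: shows the expected shape of the new metadata filter.
import chromadb

client = chromadb.Client()
collection = client.get_or_create_collection("simba_docs")  # name assumed

# Shape produced by QueryGenerator.get_doctype_query() in this patch.
metadata_filter = {"document_type": {"$in": ["Bill", "Medical Record"]}}

# consult_oracle() passes the filter straight through to the vector query.
results = collection.query(
    query_texts=["When was Simba last at the vet?"],  # example query
    n_results=5,
    where=metadata_filter,
)
print(results["ids"])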