From 2bbe33fedcee7d500fc6e8cec482654ad26efe65 Mon Sep 17 00:00:00 2001 From: Ryan Chen Date: Tue, 14 Oct 2025 22:13:01 -0400 Subject: [PATCH] Starting attempt #2 at metadata filtering --- chunker.py | 4 +- llm.py | 4 +- main.py | 56 ++++++++++++++++------------ query.py | 105 ++++++++++++++++++++++++++++++++++++----------------- request.py | 14 ++++++- 5 files changed, 122 insertions(+), 61 deletions(-) diff --git a/chunker.py b/chunker.py index 7eeb73a..8185af0 100644 --- a/chunker.py +++ b/chunker.py @@ -13,7 +13,9 @@ from llm import LLMClient load_dotenv() -ollama_client = Client(host=os.getenv("OLLAMA_HOST", "http://localhost:11434")) +ollama_client = Client( + host=os.getenv("OLLAMA_HOST", "http://localhost:11434"), timeout=10.0 +) def remove_headers_footers(text, header_patterns=None, footer_patterns=None): diff --git a/llm.py b/llm.py index bc4167b..e700fce 100644 --- a/llm.py +++ b/llm.py @@ -3,8 +3,6 @@ import os from ollama import Client from openai import OpenAI -import typing - import logging logging.basicConfig(level=logging.INFO) @@ -14,7 +12,7 @@ class LLMClient: def __init__(self): try: self.ollama_client = Client( - host=os.getenv("OLLAMA_URL", "http://localhost:11434") + host=os.getenv("OLLAMA_URL", "http://localhost:11434"), timeout=10.0 ) self.ollama_client.chat( model="gemma3:4b", messages=[{"role": "system", "content": "test"}] diff --git a/main.py b/main.py index b93fa95..20b7035 100644 --- a/main.py +++ b/main.py @@ -7,12 +7,10 @@ from typing import Any, Union import argparse import chromadb import ollama -from openai import OpenAI from request import PaperlessNGXService from chunker import Chunker -from query import QueryGenerator from cleaner import pdf_to_image, summarize_pdf_image from llm import LLMClient @@ -21,13 +19,13 @@ from dotenv import load_dotenv load_dotenv() -USE_OPENAI = os.getenv("OPENAI_API_KEY") != None - # Configure ollama client with URL from environment or default to localhost -ollama_client = ollama.Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434")) +ollama_client = ollama.Client( + host=os.getenv("OLLAMA_URL", "http://localhost:11434"), timeout=10.0 +) client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", "")) -simba_docs = client.get_or_create_collection(name="simba_docs3") +simba_docs = client.get_or_create_collection(name="simba_docs2") feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup") parser = argparse.ArgumentParser( @@ -46,6 +44,7 @@ llm_client = LLMClient() def index_using_pdf_llm(): + logging.info("reindex data...") files = ppngx.get_data() for file in files: document_id = file["id"] @@ -72,28 +71,35 @@ def date_to_epoch(date_str: str) -> float: return date.timestamp() -def chunk_data(docs: list[dict[str, Union[str, Any]]], collection): +def chunk_data(docs: list[dict[str, Union[str, Any]]], collection, doctypes): # Step 2: Create chunks chunker = Chunker(collection) print(f"chunking {len(docs)} documents") texts: list[str] = [doc["content"] for doc in docs] - with sqlite3.connect("visited.db") as conn: + with sqlite3.connect("visited.db") as conn: to_insert = [] c = conn.cursor() for index, text in enumerate(texts): metadata = { "created_date": date_to_epoch(docs[index]["created_date"]), "filename": docs[index]["original_file_name"], + "document_type": doctypes.get(docs[index]["document_type"], ""), } + + if doctypes: + metadata["type"] = doctypes.get(docs[index]["document_type"]) + chunker.chunk_document( document=text, metadata=metadata, ) to_insert.append((docs[index]["id"],)) - c.executemany("INSERT INTO indexed_documents (paperless_id) values (?)", to_insert) - + c.executemany( + "INSERT INTO indexed_documents (paperless_id) values (?)", to_insert + ) + conn.commit() def chunk_text(texts: list[str], collection): @@ -169,10 +175,13 @@ def consult_simba_oracle(input: str): collection=simba_docs, ) + def filter_indexed_files(docs): with sqlite3.connect("visited.db") as conn: c = conn.cursor() - c.execute("CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)") + c.execute( + "CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)" + ) c.execute("SELECT paperless_id FROM indexed_documents") rows = c.fetchall() conn.commit() @@ -181,7 +190,6 @@ def filter_indexed_files(docs): return [doc for doc in docs if doc["id"] not in visited] - if __name__ == "__main__": args = parser.parse_args() if args.reindex: @@ -192,20 +200,22 @@ if __name__ == "__main__": print(f"Fetched {len(docs)} documents") # print("Chunking documents now ...") - chunk_data(docs, collection=simba_docs) + tag_lookup = ppngx.get_tags() + doctype_lookup = ppngx.get_doctypes() + chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup) print("Done chunking documents") # index_using_pdf_llm() - if args.index: - with open(args.index) as file: - extension = args.index.split(".")[-1] - if extension == "pdf": - pdf_path = ppngx.download_pdf_from_id(id=document_id) - image_paths = pdf_to_image(filepath=pdf_path) - print(f"summarizing {file}") - generated_summary = summarize_pdf_image(filepaths=image_paths) - elif extension in [".md", ".txt"]: - chunk_text(texts=[file.readall()], collection=simba_docs) + # if args.index: + # with open(args.index) as file: + # extension = args.index.split(".")[-1] + # if extension == "pdf": + # pdf_path = ppngx.download_pdf_from_id(id=document_id) + # image_paths = pdf_to_image(filepath=pdf_path) + # print(f"summarizing {file}") + # generated_summary = summarize_pdf_image(filepaths=image_paths) + # elif extension in [".md", ".txt"]: + # chunk_text(texts=[file.readall()], collection=simba_docs) if args.query: print("Consulting oracle ...") diff --git a/query.py b/query.py index defa02a..6dd53c6 100644 --- a/query.py +++ b/query.py @@ -2,7 +2,7 @@ import json import os from typing import Literal import datetime -from ollama import chat, ChatResponse, Client +from ollama import Client from openai import OpenAI @@ -38,6 +38,20 @@ class Time(BaseModel): time: int +DOCTYPE_OPTIONS = [ + "Bill", + "Image Description", + "Insurance", + "Medical Record", + "Documentation", + "Letter", +] + + +class DocumentType(BaseModel): + type: list[str] = Field(description="type of document", enum=DOCTYPE_OPTIONS) + + PROMPT = """ You are an information specialist that processes user queries. The current year is 2025. The user queries are all about a cat, Simba, and its records. The types of records are listed below. Using the query, extract the @@ -82,7 +96,8 @@ document_types: Only return the extracted metadata fields. Make sure the extracted metadata fields are valid JSON """ -USE_OPENAI = os.getenv("OPENAI_API_KEY", None) != None + +DOCTYPE_PROMPT = f"You are an information specialist that processes user queries. A query can have two tags attached from the following options. Based on the query, determine which of the following options is most appropriate: {','.join(DOCTYPE_OPTIONS)}" class QueryGenerator: @@ -102,43 +117,67 @@ class QueryGenerator: return date.timestamp() + def get_doctype_query(self, input: str): + print(DOCTYPE_PROMPT) + client = OpenAI() + response = client.chat.completions.create( + messages=[ + { + "role": "system", + "content": "You are an information specialist that is really good at deciding what tags a query should have", + }, + {"role": "user", "content": DOCTYPE_PROMPT + " " + input}, + ], + model="gpt-4o", + response_format={ + "type": "json_schema", + "json_schema": { + "name": "document_type", + "schema": DocumentType.model_json_schema(), + }, + }, + ) + + response_json_str = response.choices[0].message.content + type_data = json.loads(response_json_str) + print(type_data) + return type_data + def get_query(self, input: str): - if USE_OPENAI: - client = OpenAI() - response = client.responses.parse( - model="gpt-4o", - input=[ - {"role": "system", "content": PROMPT}, - {"role": "user", "content": input}, - ], - text_format=GeneratedQuery, - ) - print(response.output) - query = json.loads(response.output_parsed.extracted_metadata_fields) - else: - response: ChatResponse = ollama_client.chat( - model="gemma3n:e4b", - messages=[ - {"role": "system", "content": PROMPT}, - {"role": "user", "content": input}, - ], - format=GeneratedQuery.model_json_schema(), - ) + client = OpenAI() + response = client.responses.parse( + model="gpt-4o", + input=[ + {"role": "system", "content": PROMPT}, + {"role": "user", "content": input}, + ], + text_format=GeneratedQuery, + ) + print(response.output) + query = json.loads(response.output_parsed.extracted_metadata_fields) + # response: ChatResponse = ollama_client.chat( + # model="gemma3n:e4b", + # messages=[ + # {"role": "system", "content": PROMPT}, + # {"role": "user", "content": input}, + # ], + # format=GeneratedQuery.model_json_schema(), + # ) - query = json.loads( - json.loads(response["message"]["content"])["extracted_metadata_fields"] - ) - date_key = list(query["created_date"].keys())[0] - query["created_date"][date_key] = self.date_to_epoch( - query["created_date"][date_key] - ) + # query = json.loads( + # json.loads(response["message"]["content"])["extracted_metadata_fields"] + # ) + # date_key = list(query["created_date"].keys())[0] + # query["created_date"][date_key] = self.date_to_epoch( + # query["created_date"][date_key] + # ) - if "$" not in date_key: - query["created_date"]["$" + date_key] = query["created_date"][date_key] + # if "$" not in date_key: + # query["created_date"]["$" + date_key] = query["created_date"][date_key] return query if __name__ == "__main__": qg = QueryGenerator() - print(qg.get_query("How heavy is Simba?")) + print(qg.get_doctype_query("How heavy is Simba?")) diff --git a/request.py b/request.py index 1eac623..f379480 100644 --- a/request.py +++ b/request.py @@ -14,7 +14,7 @@ class PaperlessNGXService: def __init__(self): self.base_url = os.getenv("BASE_URL") self.token = os.getenv("PAPERLESS_TOKEN") - self.url = f"http://{os.getenv('BASE_URL')}/api/documents/?query=simba" + self.url = f"http://{os.getenv('BASE_URL')}/api/documents/?tags__id=8" self.headers = {"Authorization": f"Token {os.getenv('PAPERLESS_TOKEN')}"} def get_data(self): @@ -68,6 +68,18 @@ class PaperlessNGXService: r = httpx.post(POST_URL, headers=self.headers, data=data, files=files) r.raise_for_status() + def get_tags(self): + GET_URL = f"http://{os.getenv('BASE_URL')}/api/tags/" + r = httpx.get(GET_URL, headers=self.headers) + data = r.json() + return {tag["id"]: tag["name"] for tag in data["results"]} + + def get_doctypes(self): + GET_URL = f"http://{os.getenv('BASE_URL')}/api/document_types/" + r = httpx.get(GET_URL, headers=self.headers) + data = r.json() + return {doctype["id"]: doctype["name"] for doctype in data["results"]} + if __name__ == "__main__": pp = PaperlessNGXService()