Starting attempt #2 at metadata filtering

2025-10-14 22:13:01 -04:00
parent b872750444
commit 2bbe33fedc
5 changed files with 122 additions and 61 deletions


@@ -13,7 +13,9 @@ from llm import LLMClient
load_dotenv()
ollama_client = Client(host=os.getenv("OLLAMA_HOST", "http://localhost:11434"))
ollama_client = Client(
host=os.getenv("OLLAMA_HOST", "http://localhost:11434"), timeout=10.0
)
def remove_headers_footers(text, header_patterns=None, footer_patterns=None):
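
Note on the timeout change (not part of the diff itself): all three ollama Client constructions touched by this commit now pass timeout=10.0. In the ollama Python client, extra keyword arguments such as timeout are forwarded to the underlying httpx client, so a hung Ollama server fails fast instead of blocking indefinitely. A minimal sketch of what catching that failure could look like; the fallback handling below is an assumption, not something this commit adds:

    import os
    import httpx
    from ollama import Client

    ollama_client = Client(
        host=os.getenv("OLLAMA_HOST", "http://localhost:11434"), timeout=10.0
    )

    try:
        # Cheap connectivity check, as llm.py does; raises if the server
        # is unreachable or slower than the 10 second timeout.
        ollama_client.chat(
            model="gemma3:4b", messages=[{"role": "system", "content": "test"}]
        )
    except (httpx.ConnectError, httpx.TimeoutException) as exc:
        # Hypothetical handling: llm.py wraps this in a try block and could
        # fall back to OpenAI when Ollama is unavailable.
        print(f"Ollama unavailable: {exc}")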

llm.py

@@ -3,8 +3,6 @@ import os
from ollama import Client
from openai import OpenAI
import typing
import logging
logging.basicConfig(level=logging.INFO)
@@ -14,7 +12,7 @@ class LLMClient:
def __init__(self):
try:
self.ollama_client = Client(
host=os.getenv("OLLAMA_URL", "http://localhost:11434")
host=os.getenv("OLLAMA_URL", "http://localhost:11434"), timeout=10.0
)
self.ollama_client.chat(
model="gemma3:4b", messages=[{"role": "system", "content": "test"}]

main.py

@@ -7,12 +7,10 @@ from typing import Any, Union
import argparse
import chromadb
import ollama
from openai import OpenAI
from request import PaperlessNGXService
from chunker import Chunker
from query import QueryGenerator
from cleaner import pdf_to_image, summarize_pdf_image
from llm import LLMClient
@@ -21,13 +19,13 @@ from dotenv import load_dotenv
load_dotenv()
USE_OPENAI = os.getenv("OPENAI_API_KEY") != None
# Configure ollama client with URL from environment or default to localhost
ollama_client = ollama.Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434"))
ollama_client = ollama.Client(
host=os.getenv("OLLAMA_URL", "http://localhost:11434"), timeout=10.0
)
client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
simba_docs = client.get_or_create_collection(name="simba_docs3")
simba_docs = client.get_or_create_collection(name="simba_docs2")
feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup")
parser = argparse.ArgumentParser(
@@ -46,6 +44,7 @@ llm_client = LLMClient()
def index_using_pdf_llm():
logging.info("reindex data...")
files = ppngx.get_data()
for file in files:
document_id = file["id"]
@@ -72,28 +71,35 @@ def date_to_epoch(date_str: str) -> float:
return date.timestamp()
def chunk_data(docs: list[dict[str, Union[str, Any]]], collection):
def chunk_data(docs: list[dict[str, Union[str, Any]]], collection, doctypes):
# Step 2: Create chunks
chunker = Chunker(collection)
print(f"chunking {len(docs)} documents")
texts: list[str] = [doc["content"] for doc in docs]
with sqlite3.connect("visited.db") as conn:
with sqlite3.connect("visited.db") as conn:
to_insert = []
c = conn.cursor()
for index, text in enumerate(texts):
metadata = {
"created_date": date_to_epoch(docs[index]["created_date"]),
"filename": docs[index]["original_file_name"],
"document_type": doctypes.get(docs[index]["document_type"], ""),
}
if doctypes:
metadata["type"] = doctypes.get(docs[index]["document_type"])
chunker.chunk_document(
document=text,
metadata=metadata,
)
to_insert.append((docs[index]["id"],))
c.executemany("INSERT INTO indexed_documents (paperless_id) values (?)", to_insert)
c.executemany(
"INSERT INTO indexed_documents (paperless_id) values (?)", to_insert
)
conn.commit()
def chunk_text(texts: list[str], collection):
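
For reference, a sketch (not from the diff) of the per-chunk metadata this version of chunk_data builds, assuming doctypes is the id-to-name mapping returned by PaperlessNGXService.get_doctypes() and the example document values are hypothetical:

    # Hypothetical document shape as returned by ppngx.get_data()
    doc = {
        "id": 42,
        "content": "...",
        "created_date": "2025-03-01",
        "original_file_name": "simba_bloodwork.pdf",
        "document_type": 3,  # Paperless doctype id
    }
    doctypes = {3: "Medical Record"}  # shape returned by get_doctypes()

    metadata = {
        "created_date": date_to_epoch(doc["created_date"]),  # float epoch seconds
        "filename": doc["original_file_name"],
        "document_type": doctypes.get(doc["document_type"], ""),
    }
    if doctypes:
        metadata["type"] = doctypes.get(doc["document_type"])  # "Medical Record"
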
@@ -169,10 +175,13 @@ def consult_simba_oracle(input: str):
collection=simba_docs,
)
def filter_indexed_files(docs):
with sqlite3.connect("visited.db") as conn:
c = conn.cursor()
c.execute("CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)")
c.execute(
"CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)"
)
c.execute("SELECT paperless_id FROM indexed_documents")
rows = c.fetchall()
conn.commit()
@@ -181,7 +190,6 @@ def filter_indexed_files(docs):
return [doc for doc in docs if doc["id"] not in visited]
if __name__ == "__main__":
args = parser.parse_args()
if args.reindex:
@@ -192,20 +200,22 @@ if __name__ == "__main__":
print(f"Fetched {len(docs)} documents")
#
print("Chunking documents now ...")
chunk_data(docs, collection=simba_docs)
tag_lookup = ppngx.get_tags()
doctype_lookup = ppngx.get_doctypes()
chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
print("Done chunking documents")
# index_using_pdf_llm()
if args.index:
with open(args.index) as file:
extension = args.index.split(".")[-1]
if extension == "pdf":
pdf_path = ppngx.download_pdf_from_id(id=document_id)
image_paths = pdf_to_image(filepath=pdf_path)
print(f"summarizing {file}")
generated_summary = summarize_pdf_image(filepaths=image_paths)
elif extension in [".md", ".txt"]:
chunk_text(texts=[file.readall()], collection=simba_docs)
# if args.index:
# with open(args.index) as file:
# extension = args.index.split(".")[-1]
# if extension == "pdf":
# pdf_path = ppngx.download_pdf_from_id(id=document_id)
# image_paths = pdf_to_image(filepath=pdf_path)
# print(f"summarizing {file}")
# generated_summary = summarize_pdf_image(filepaths=image_paths)
# elif extension in [".md", ".txt"]:
# chunk_text(texts=[file.readall()], collection=simba_docs)
if args.query:
print("Consulting oracle ...")

query.py

@@ -2,7 +2,7 @@ import json
import os
from typing import Literal
import datetime
from ollama import chat, ChatResponse, Client
from ollama import Client
from openai import OpenAI
@@ -38,6 +38,20 @@ class Time(BaseModel):
time: int
DOCTYPE_OPTIONS = [
"Bill",
"Image Description",
"Insurance",
"Medical Record",
"Documentation",
"Letter",
]
class DocumentType(BaseModel):
type: list[str] = Field(description="type of document", enum=DOCTYPE_OPTIONS)
PROMPT = """
You are an information specialist that processes user queries. The current year is 2025. The user queries are all about
a cat, Simba, and its records. The types of records are listed below. Using the query, extract the
@@ -82,7 +96,8 @@ document_types:
Only return the extracted metadata fields. Make sure the extracted metadata fields are valid JSON
"""
USE_OPENAI = os.getenv("OPENAI_API_KEY", None) != None
DOCTYPE_PROMPT = f"You are an information specialist that processes user queries. A query can have two tags attached from the following options. Based on the query, determine which of the following options is most appropriate: {','.join(DOCTYPE_OPTIONS)}"
class QueryGenerator:
@@ -102,43 +117,67 @@ class QueryGenerator:
return date.timestamp()
def get_doctype_query(self, input: str):
print(DOCTYPE_PROMPT)
client = OpenAI()
response = client.chat.completions.create(
messages=[
{
"role": "system",
"content": "You are an information specialist that is really good at deciding what tags a query should have",
},
{"role": "user", "content": DOCTYPE_PROMPT + " " + input},
],
model="gpt-4o",
response_format={
"type": "json_schema",
"json_schema": {
"name": "document_type",
"schema": DocumentType.model_json_schema(),
},
},
)
response_json_str = response.choices[0].message.content
type_data = json.loads(response_json_str)
print(type_data)
return type_data
def get_query(self, input: str):
if USE_OPENAI:
client = OpenAI()
response = client.responses.parse(
model="gpt-4o",
input=[
{"role": "system", "content": PROMPT},
{"role": "user", "content": input},
],
text_format=GeneratedQuery,
)
print(response.output)
query = json.loads(response.output_parsed.extracted_metadata_fields)
else:
response: ChatResponse = ollama_client.chat(
model="gemma3n:e4b",
messages=[
{"role": "system", "content": PROMPT},
{"role": "user", "content": input},
],
format=GeneratedQuery.model_json_schema(),
)
client = OpenAI()
response = client.responses.parse(
model="gpt-4o",
input=[
{"role": "system", "content": PROMPT},
{"role": "user", "content": input},
],
text_format=GeneratedQuery,
)
print(response.output)
query = json.loads(response.output_parsed.extracted_metadata_fields)
# response: ChatResponse = ollama_client.chat(
# model="gemma3n:e4b",
# messages=[
# {"role": "system", "content": PROMPT},
# {"role": "user", "content": input},
# ],
# format=GeneratedQuery.model_json_schema(),
# )
query = json.loads(
json.loads(response["message"]["content"])["extracted_metadata_fields"]
)
date_key = list(query["created_date"].keys())[0]
query["created_date"][date_key] = self.date_to_epoch(
query["created_date"][date_key]
)
# query = json.loads(
# json.loads(response["message"]["content"])["extracted_metadata_fields"]
# )
# date_key = list(query["created_date"].keys())[0]
# query["created_date"][date_key] = self.date_to_epoch(
# query["created_date"][date_key]
# )
if "$" not in date_key:
query["created_date"]["$" + date_key] = query["created_date"][date_key]
# if "$" not in date_key:
# query["created_date"]["$" + date_key] = query["created_date"][date_key]
return query
if __name__ == "__main__":
qg = QueryGenerator()
print(qg.get_query("How heavy is Simba?"))
print(qg.get_doctype_query("How heavy is Simba?"))
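
A worked example of the post-processing in get_query (a standalone sketch; the raw model output is hypothetical): the model returns an operator name without ChromaDB's "$" prefix and an ISO date, and the code converts the date to epoch seconds and adds the "$"-prefixed key alongside the original one.

    # Hypothetical extracted_metadata_fields parsed from the model response:
    query = {"created_date": {"gte": "2025-03-01"}}

    date_key = list(query["created_date"].keys())[0]  # "gte"
    query["created_date"][date_key] = date_to_epoch("2025-03-01")  # float epoch seconds
    if "$" not in date_key:
        # Adds "$gte" for ChromaDB; the unprefixed "gte" key is left in place.
        query["created_date"]["$" + date_key] = query["created_date"][date_key]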

request.py

@@ -14,7 +14,7 @@ class PaperlessNGXService:
def __init__(self):
self.base_url = os.getenv("BASE_URL")
self.token = os.getenv("PAPERLESS_TOKEN")
self.url = f"http://{os.getenv('BASE_URL')}/api/documents/?query=simba"
self.url = f"http://{os.getenv('BASE_URL')}/api/documents/?tags__id=8"
self.headers = {"Authorization": f"Token {os.getenv('PAPERLESS_TOKEN')}"}
def get_data(self):
@@ -68,6 +68,18 @@ class PaperlessNGXService:
r = httpx.post(POST_URL, headers=self.headers, data=data, files=files)
r.raise_for_status()
def get_tags(self):
GET_URL = f"http://{os.getenv('BASE_URL')}/api/tags/"
r = httpx.get(GET_URL, headers=self.headers)
data = r.json()
return {tag["id"]: tag["name"] for tag in data["results"]}
def get_doctypes(self):
GET_URL = f"http://{os.getenv('BASE_URL')}/api/document_types/"
r = httpx.get(GET_URL, headers=self.headers)
data = r.json()
return {doctype["id"]: doctype["name"] for doctype in data["results"]}
if __name__ == "__main__":
pp = PaperlessNGXService()
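
Usage sketch for the two new lookups (the id values are hypothetical, and the guess that the hardcoded tags__id=8 corresponds to the Simba tag is an assumption):

    pp = PaperlessNGXService()

    tags = pp.get_tags()          # e.g. {8: "simba", 12: "insurance"}
    doctypes = pp.get_doctypes()  # e.g. {3: "Medical Record", 5: "Bill"}

    # main.py passes the doctype lookup into chunk_data so each chunk is
    # stored with a human-readable document type instead of a Paperless id:
    # chunk_data(docs, collection=simba_docs, doctypes=doctypes)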