Metadata filtering

Date:   2025-10-16 22:36:21 -04:00
Parent: 2bbe33fedc
Commit: acaf681927
2 changed files with 46 additions and 22 deletions
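This change wires an LLM-inferred document-type filter into the ChromaDB retrieval path: QueryGenerator.get_doctype_query() now returns a ready-made metadata filter, and consult_oracle() passes it to the collection query via where=. A minimal sketch of the resulting call shape (collection and field names follow the diff; the example values are made up):

# Sketch only: the filter shape this commit feeds to ChromaDB.
# "simba_docs" is the collection created in main.py; "document_type" is the
# metadata field the filter keys on (presumably written at chunking time).
metadata_filter = {"document_type": {"$in": ["Bill", "Medical Record"]}}

results = simba_docs.query(
    query_texts=["When was Simba's last vet visit?"],  # hypothetical user query
    where=metadata_filter,  # only chunks whose metadata matches are searched
    n_results=5,
)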

main.py (43 changes)

@@ -2,7 +2,6 @@ import datetime
 import logging
 import os
 import sqlite3
-from typing import Any, Union
 import argparse
 import chromadb
@@ -13,11 +12,12 @@ from request import PaperlessNGXService
 from chunker import Chunker
 from cleaner import pdf_to_image, summarize_pdf_image
 from llm import LLMClient
+from query import QueryGenerator
 from dotenv import load_dotenv

-load_dotenv()
+_dotenv_loaded = load_dotenv()

 # Configure ollama client with URL from environment or default to localhost
 ollama_client = ollama.Client(
@@ -43,18 +43,18 @@ ppngx = PaperlessNGXService()
 llm_client = LLMClient()

-def index_using_pdf_llm():
+def index_using_pdf_llm(doctypes):
     logging.info("reindex data...")
     files = ppngx.get_data()
     for file in files:
-        document_id = file["id"]
+        document_id: int = file["id"]
         pdf_path = ppngx.download_pdf_from_id(id=document_id)
         image_paths = pdf_to_image(filepath=pdf_path)
         print(f"summarizing {file}")
         generated_summary = summarize_pdf_image(filepaths=image_paths)
         file["content"] = generated_summary
-    chunk_data(files, simba_docs)
+    chunk_data(files, simba_docs, doctypes=doctypes)

 def date_to_epoch(date_str: str) -> float:
@@ -71,7 +71,7 @@ def date_to_epoch(date_str: str) -> float:
     return date.timestamp()

-def chunk_data(docs: list[dict[str, Union[str, Any]]], collection, doctypes):
+def chunk_data(docs, collection, doctypes):
     # Step 2: Create chunks
     chunker = Chunker(collection)
@@ -121,13 +121,15 @@ def consult_oracle(input: str, collection):
     start_time = time.time()

     # Ask
-    # print("Starting query generation")
-    # qg_start = time.time()
-    # qg = QueryGenerator()
+    print("Starting query generation")
+    qg_start = time.time()
+    qg = QueryGenerator()
+    doctype_query = qg.get_doctype_query(input=input)
     # metadata_filter = qg.get_query(input)
-    # qg_end = time.time()
-    # print(f"Query generation took {qg_end - qg_start:.2f} seconds")
-    # print(metadata_filter)
+    metadata_filter = {**doctype_query}
+    print(metadata_filter)
+    qg_end = time.time()
+    print(f"Query generation took {qg_end - qg_start:.2f} seconds")

     print("Starting embedding generation")
     embedding_start = time.time()
@@ -140,8 +142,9 @@ def consult_oracle(input: str, collection):
     results = collection.query(
         query_texts=[input],
         query_embeddings=embeddings,
-        # where=metadata_filter,
+        where=metadata_filter,
     )
+    print(results)
     query_end = time.time()
     print(f"Collection query took {query_end - query_start:.2f} seconds")
@@ -193,18 +196,28 @@ def filter_indexed_files(docs):

 if __name__ == "__main__":
     args = parser.parse_args()
     if args.reindex:
+        with sqlite3.connect("./visited.db") as conn:
+            c = conn.cursor()
+            c.execute("DELETE FROM indexed_documents")
         print("Fetching documents from Paperless-NGX")
         ppngx = PaperlessNGXService()
         docs = ppngx.get_data()
         docs = filter_indexed_files(docs)
         print(f"Fetched {len(docs)} documents")
+        #
+        # Delete all chromadb data
+        ids = simba_docs.get(ids=None, limit=None, offset=0)
+        all_ids = ids["ids"]
+        if len(all_ids) > 0:
+            simba_docs.delete(ids=all_ids)
+        # Chunk documents
         print("Chunking documents now ...")
         tag_lookup = ppngx.get_tags()
         doctype_lookup = ppngx.get_doctypes()
         chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
         print("Done chunking documents")
         # index_using_pdf_llm()
     # if args.index:
     #     with open(args.index) as file:

query.py (25 changes)

@@ -9,7 +9,9 @@ from openai import OpenAI
 from pydantic import BaseModel, Field

 # Configure ollama client with URL from environment or default to localhost
-ollama_client = Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434"))
+ollama_client = Client(
+    host=os.getenv("OLLAMA_URL", "http://localhost:11434"), timeout=10.0
+)

 # This uses inferred filters — which means using LLM to create the metadata filters
@@ -53,8 +55,8 @@ class DocumentType(BaseModel):
 PROMPT = """
 You are an information specialist that processes user queries. The current year is 2025. The user queries are all about
 a cat, Simba, and its records. The types of records are listed below. Using the query, extract the
 the date range the user is trying to query. You should return it as a JSON. The date tag is created_date. Return the date in epoch time.
 If the created_date cannot be ascertained, set it to epoch time start.
@@ -97,7 +99,17 @@ Only return the extracted metadata fields. Make sure the extracted metadata fiel
 """

-DOCTYPE_PROMPT = f"You are an information specialist that processes user queries. A query can have two tags attached from the following options. Based on the query, determine which of the following options is most appropriate: {','.join(DOCTYPE_OPTIONS)}"
+DOCTYPE_PROMPT = f"""You are an information specialist that processes user queries. A query can have two tags attached from the following options. Based on the query, determine which of the following options is most appropriate: {",".join(DOCTYPE_OPTIONS)}
+
+### Example 1
+Query: "Who is Simba's current vet?"
+Tags: ["Bill", "Medical Record"]
+
+### Example 2
+Query: "Who does Simba know?"
+Tags: ["Letter", "Documentation"]
+"""

 class QueryGenerator:
@@ -118,7 +130,6 @@ class QueryGenerator:
         return date.timestamp()

     def get_doctype_query(self, input: str):
-        print(DOCTYPE_PROMPT)
         client = OpenAI()
         response = client.chat.completions.create(
             messages=[
@@ -140,8 +151,8 @@ class QueryGenerator:
         response_json_str = response.choices[0].message.content
         type_data = json.loads(response_json_str)
-        print(type_data)
-        return type_data
+        metadata_query = {"document_type": {"$in": type_data["type"]}}
+        return metadata_query

     def get_query(self, input: str):
         client = OpenAI()
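get_doctype_query() now returns a ChromaDB-ready filter instead of the raw parsed JSON. A minimal sketch of the conversion, assuming the model replies with JSON whose "type" field is a list drawn from DOCTYPE_OPTIONS (as the DocumentType schema and the prompt's examples suggest):

import json

# Hypothetical model output; in get_doctype_query it comes from
# response.choices[0].message.content.
response_json_str = '{"type": ["Bill", "Medical Record"]}'

type_data = json.loads(response_json_str)
metadata_query = {"document_type": {"$in": type_data["type"]}}
# -> {'document_type': {'$in': ['Bill', 'Medical Record']}}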