Metadata filtering
This commit is contained in:
43
main.py
43
main.py
@@ -2,7 +2,6 @@ import datetime
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
from typing import Any, Union
|
||||
|
||||
import argparse
|
||||
import chromadb
|
||||
@@ -13,11 +12,12 @@ from request import PaperlessNGXService
|
||||
from chunker import Chunker
|
||||
from cleaner import pdf_to_image, summarize_pdf_image
|
||||
from llm import LLMClient
|
||||
from query import QueryGenerator
|
||||
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
_dotenv_loaded = load_dotenv()
|
||||
|
||||
# Configure ollama client with URL from environment or default to localhost
|
||||
ollama_client = ollama.Client(
|
||||
@@ -43,18 +43,18 @@ ppngx = PaperlessNGXService()
|
||||
llm_client = LLMClient()
|
||||
|
||||
|
||||
def index_using_pdf_llm():
|
||||
def index_using_pdf_llm(doctypes):
|
||||
logging.info("reindex data...")
|
||||
files = ppngx.get_data()
|
||||
for file in files:
|
||||
document_id = file["id"]
|
||||
document_id: int = file["id"]
|
||||
pdf_path = ppngx.download_pdf_from_id(id=document_id)
|
||||
image_paths = pdf_to_image(filepath=pdf_path)
|
||||
print(f"summarizing {file}")
|
||||
generated_summary = summarize_pdf_image(filepaths=image_paths)
|
||||
file["content"] = generated_summary
|
||||
|
||||
chunk_data(files, simba_docs)
|
||||
chunk_data(files, simba_docs, doctypes=doctypes)
|
||||
|
||||
|
||||
def date_to_epoch(date_str: str) -> float:
|
||||
@@ -71,7 +71,7 @@ def date_to_epoch(date_str: str) -> float:
|
||||
return date.timestamp()
|
||||
|
||||
|
||||
def chunk_data(docs: list[dict[str, Union[str, Any]]], collection, doctypes):
|
||||
def chunk_data(docs, collection, doctypes):
|
||||
# Step 2: Create chunks
|
||||
chunker = Chunker(collection)
|
||||
|
||||
@@ -121,13 +121,15 @@ def consult_oracle(input: str, collection):
|
||||
start_time = time.time()
|
||||
|
||||
# Ask
|
||||
# print("Starting query generation")
|
||||
# qg_start = time.time()
|
||||
# qg = QueryGenerator()
|
||||
print("Starting query generation")
|
||||
qg_start = time.time()
|
||||
qg = QueryGenerator()
|
||||
doctype_query = qg.get_doctype_query(input=input)
|
||||
# metadata_filter = qg.get_query(input)
|
||||
# qg_end = time.time()
|
||||
# print(f"Query generation took {qg_end - qg_start:.2f} seconds")
|
||||
# print(metadata_filter)
|
||||
metadata_filter = {**doctype_query}
|
||||
print(metadata_filter)
|
||||
qg_end = time.time()
|
||||
print(f"Query generation took {qg_end - qg_start:.2f} seconds")
|
||||
|
||||
print("Starting embedding generation")
|
||||
embedding_start = time.time()
|
||||
@@ -140,8 +142,9 @@ def consult_oracle(input: str, collection):
|
||||
results = collection.query(
|
||||
query_texts=[input],
|
||||
query_embeddings=embeddings,
|
||||
# where=metadata_filter,
|
||||
where=metadata_filter,
|
||||
)
|
||||
print(results)
|
||||
query_end = time.time()
|
||||
print(f"Collection query took {query_end - query_start:.2f} seconds")
|
||||
|
||||
@@ -193,18 +196,28 @@ def filter_indexed_files(docs):
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
if args.reindex:
|
||||
with sqlite3.connect("./visited.db") as conn:
|
||||
c = conn.cursor()
|
||||
c.execute("DELETE FROM indexed_documents")
|
||||
|
||||
print("Fetching documents from Paperless-NGX")
|
||||
ppngx = PaperlessNGXService()
|
||||
docs = ppngx.get_data()
|
||||
docs = filter_indexed_files(docs)
|
||||
print(f"Fetched {len(docs)} documents")
|
||||
#
|
||||
|
||||
# Delete all chromadb data
|
||||
ids = simba_docs.get(ids=None, limit=None, offset=0)
|
||||
all_ids = ids["ids"]
|
||||
if len(all_ids) > 0:
|
||||
simba_docs.delete(ids=all_ids)
|
||||
|
||||
# Chunk documents
|
||||
print("Chunking documents now ...")
|
||||
tag_lookup = ppngx.get_tags()
|
||||
doctype_lookup = ppngx.get_doctypes()
|
||||
chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
|
||||
print("Done chunking documents")
|
||||
# index_using_pdf_llm()
|
||||
|
||||
# if args.index:
|
||||
# with open(args.index) as file:
|
||||
|
||||
21
query.py
21
query.py
@@ -9,7 +9,9 @@ from openai import OpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Configure ollama client with URL from environment or default to localhost
|
||||
ollama_client = Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434"))
|
||||
ollama_client = Client(
|
||||
host=os.getenv("OLLAMA_URL", "http://localhost:11434"), timeout=10.0
|
||||
)
|
||||
|
||||
# This uses inferred filters — which means using LLM to create the metadata filters
|
||||
|
||||
@@ -97,7 +99,17 @@ Only return the extracted metadata fields. Make sure the extracted metadata fiel
|
||||
"""
|
||||
|
||||
|
||||
DOCTYPE_PROMPT = f"You are an information specialist that processes user queries. A query can have two tags attached from the following options. Based on the query, determine which of the following options is most appropriate: {','.join(DOCTYPE_OPTIONS)}"
|
||||
DOCTYPE_PROMPT = f"""You are an information specialist that processes user queries. A query can have two tags attached from the following options. Based on the query, determine which of the following options is most appropriate: {",".join(DOCTYPE_OPTIONS)}
|
||||
|
||||
### Example 1
|
||||
Query: "Who is Simba's current vet?"
|
||||
Tags: ["Bill", "Medical Record"]
|
||||
|
||||
|
||||
### Example 2
|
||||
Query: "Who does Simba know?"
|
||||
Tags: ["Letter", "Documentation"]
|
||||
"""
|
||||
|
||||
|
||||
class QueryGenerator:
|
||||
@@ -118,7 +130,6 @@ class QueryGenerator:
|
||||
return date.timestamp()
|
||||
|
||||
def get_doctype_query(self, input: str):
|
||||
print(DOCTYPE_PROMPT)
|
||||
client = OpenAI()
|
||||
response = client.chat.completions.create(
|
||||
messages=[
|
||||
@@ -140,8 +151,8 @@ class QueryGenerator:
|
||||
|
||||
response_json_str = response.choices[0].message.content
|
||||
type_data = json.loads(response_json_str)
|
||||
print(type_data)
|
||||
return type_data
|
||||
metadata_query = {"document_type": {"$in": type_data["type"]}}
|
||||
return metadata_query
|
||||
|
||||
def get_query(self, input: str):
|
||||
client = OpenAI()
|
||||
|
||||
Reference in New Issue
Block a user