Metadata filtering
This commit is contained in:
43
main.py
43
main.py
@@ -2,7 +2,6 @@ import datetime
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from typing import Any, Union
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import chromadb
|
import chromadb
|
||||||
@@ -13,11 +12,12 @@ from request import PaperlessNGXService
|
|||||||
from chunker import Chunker
|
from chunker import Chunker
|
||||||
from cleaner import pdf_to_image, summarize_pdf_image
|
from cleaner import pdf_to_image, summarize_pdf_image
|
||||||
from llm import LLMClient
|
from llm import LLMClient
|
||||||
|
from query import QueryGenerator
|
||||||
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv()
|
_dotenv_loaded = load_dotenv()
|
||||||
|
|
||||||
# Configure ollama client with URL from environment or default to localhost
|
# Configure ollama client with URL from environment or default to localhost
|
||||||
ollama_client = ollama.Client(
|
ollama_client = ollama.Client(
|
||||||
@@ -43,18 +43,18 @@ ppngx = PaperlessNGXService()
|
|||||||
llm_client = LLMClient()
|
llm_client = LLMClient()
|
||||||
|
|
||||||
|
|
||||||
def index_using_pdf_llm():
|
def index_using_pdf_llm(doctypes):
|
||||||
logging.info("reindex data...")
|
logging.info("reindex data...")
|
||||||
files = ppngx.get_data()
|
files = ppngx.get_data()
|
||||||
for file in files:
|
for file in files:
|
||||||
document_id = file["id"]
|
document_id: int = file["id"]
|
||||||
pdf_path = ppngx.download_pdf_from_id(id=document_id)
|
pdf_path = ppngx.download_pdf_from_id(id=document_id)
|
||||||
image_paths = pdf_to_image(filepath=pdf_path)
|
image_paths = pdf_to_image(filepath=pdf_path)
|
||||||
print(f"summarizing {file}")
|
print(f"summarizing {file}")
|
||||||
generated_summary = summarize_pdf_image(filepaths=image_paths)
|
generated_summary = summarize_pdf_image(filepaths=image_paths)
|
||||||
file["content"] = generated_summary
|
file["content"] = generated_summary
|
||||||
|
|
||||||
chunk_data(files, simba_docs)
|
chunk_data(files, simba_docs, doctypes=doctypes)
|
||||||
|
|
||||||
|
|
||||||
def date_to_epoch(date_str: str) -> float:
|
def date_to_epoch(date_str: str) -> float:
|
||||||
@@ -71,7 +71,7 @@ def date_to_epoch(date_str: str) -> float:
|
|||||||
return date.timestamp()
|
return date.timestamp()
|
||||||
|
|
||||||
|
|
||||||
def chunk_data(docs: list[dict[str, Union[str, Any]]], collection, doctypes):
|
def chunk_data(docs, collection, doctypes):
|
||||||
# Step 2: Create chunks
|
# Step 2: Create chunks
|
||||||
chunker = Chunker(collection)
|
chunker = Chunker(collection)
|
||||||
|
|
||||||
@@ -121,13 +121,15 @@ def consult_oracle(input: str, collection):
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Ask
|
# Ask
|
||||||
# print("Starting query generation")
|
print("Starting query generation")
|
||||||
# qg_start = time.time()
|
qg_start = time.time()
|
||||||
# qg = QueryGenerator()
|
qg = QueryGenerator()
|
||||||
|
doctype_query = qg.get_doctype_query(input=input)
|
||||||
# metadata_filter = qg.get_query(input)
|
# metadata_filter = qg.get_query(input)
|
||||||
# qg_end = time.time()
|
metadata_filter = {**doctype_query}
|
||||||
# print(f"Query generation took {qg_end - qg_start:.2f} seconds")
|
print(metadata_filter)
|
||||||
# print(metadata_filter)
|
qg_end = time.time()
|
||||||
|
print(f"Query generation took {qg_end - qg_start:.2f} seconds")
|
||||||
|
|
||||||
print("Starting embedding generation")
|
print("Starting embedding generation")
|
||||||
embedding_start = time.time()
|
embedding_start = time.time()
|
||||||
@@ -140,8 +142,9 @@ def consult_oracle(input: str, collection):
|
|||||||
results = collection.query(
|
results = collection.query(
|
||||||
query_texts=[input],
|
query_texts=[input],
|
||||||
query_embeddings=embeddings,
|
query_embeddings=embeddings,
|
||||||
# where=metadata_filter,
|
where=metadata_filter,
|
||||||
)
|
)
|
||||||
|
print(results)
|
||||||
query_end = time.time()
|
query_end = time.time()
|
||||||
print(f"Collection query took {query_end - query_start:.2f} seconds")
|
print(f"Collection query took {query_end - query_start:.2f} seconds")
|
||||||
|
|
||||||
@@ -193,18 +196,28 @@ def filter_indexed_files(docs):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.reindex:
|
if args.reindex:
|
||||||
|
with sqlite3.connect("./visited.db") as conn:
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute("DELETE FROM indexed_documents")
|
||||||
|
|
||||||
print("Fetching documents from Paperless-NGX")
|
print("Fetching documents from Paperless-NGX")
|
||||||
ppngx = PaperlessNGXService()
|
ppngx = PaperlessNGXService()
|
||||||
docs = ppngx.get_data()
|
docs = ppngx.get_data()
|
||||||
docs = filter_indexed_files(docs)
|
docs = filter_indexed_files(docs)
|
||||||
print(f"Fetched {len(docs)} documents")
|
print(f"Fetched {len(docs)} documents")
|
||||||
#
|
|
||||||
|
# Delete all chromadb data
|
||||||
|
ids = simba_docs.get(ids=None, limit=None, offset=0)
|
||||||
|
all_ids = ids["ids"]
|
||||||
|
if len(all_ids) > 0:
|
||||||
|
simba_docs.delete(ids=all_ids)
|
||||||
|
|
||||||
|
# Chunk documents
|
||||||
print("Chunking documents now ...")
|
print("Chunking documents now ...")
|
||||||
tag_lookup = ppngx.get_tags()
|
tag_lookup = ppngx.get_tags()
|
||||||
doctype_lookup = ppngx.get_doctypes()
|
doctype_lookup = ppngx.get_doctypes()
|
||||||
chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
|
chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
|
||||||
print("Done chunking documents")
|
print("Done chunking documents")
|
||||||
# index_using_pdf_llm()
|
|
||||||
|
|
||||||
# if args.index:
|
# if args.index:
|
||||||
# with open(args.index) as file:
|
# with open(args.index) as file:
|
||||||
|
|||||||
21
query.py
21
query.py
@@ -9,7 +9,9 @@ from openai import OpenAI
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
# Configure ollama client with URL from environment or default to localhost
|
# Configure ollama client with URL from environment or default to localhost
|
||||||
ollama_client = Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434"))
|
ollama_client = Client(
|
||||||
|
host=os.getenv("OLLAMA_URL", "http://localhost:11434"), timeout=10.0
|
||||||
|
)
|
||||||
|
|
||||||
# This uses inferred filters — which means using LLM to create the metadata filters
|
# This uses inferred filters — which means using LLM to create the metadata filters
|
||||||
|
|
||||||
@@ -97,7 +99,17 @@ Only return the extracted metadata fields. Make sure the extracted metadata fiel
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
DOCTYPE_PROMPT = f"You are an information specialist that processes user queries. A query can have two tags attached from the following options. Based on the query, determine which of the following options is most appropriate: {','.join(DOCTYPE_OPTIONS)}"
|
DOCTYPE_PROMPT = f"""You are an information specialist that processes user queries. A query can have two tags attached from the following options. Based on the query, determine which of the following options is most appropriate: {",".join(DOCTYPE_OPTIONS)}
|
||||||
|
|
||||||
|
### Example 1
|
||||||
|
Query: "Who is Simba's current vet?"
|
||||||
|
Tags: ["Bill", "Medical Record"]
|
||||||
|
|
||||||
|
|
||||||
|
### Example 2
|
||||||
|
Query: "Who does Simba know?"
|
||||||
|
Tags: ["Letter", "Documentation"]
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class QueryGenerator:
|
class QueryGenerator:
|
||||||
@@ -118,7 +130,6 @@ class QueryGenerator:
|
|||||||
return date.timestamp()
|
return date.timestamp()
|
||||||
|
|
||||||
def get_doctype_query(self, input: str):
|
def get_doctype_query(self, input: str):
|
||||||
print(DOCTYPE_PROMPT)
|
|
||||||
client = OpenAI()
|
client = OpenAI()
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
messages=[
|
messages=[
|
||||||
@@ -140,8 +151,8 @@ class QueryGenerator:
|
|||||||
|
|
||||||
response_json_str = response.choices[0].message.content
|
response_json_str = response.choices[0].message.content
|
||||||
type_data = json.loads(response_json_str)
|
type_data = json.loads(response_json_str)
|
||||||
print(type_data)
|
metadata_query = {"document_type": {"$in": type_data["type"]}}
|
||||||
return type_data
|
return metadata_query
|
||||||
|
|
||||||
def get_query(self, input: str):
|
def get_query(self, input: str):
|
||||||
client = OpenAI()
|
client = OpenAI()
|
||||||
|
|||||||
Reference in New Issue
Block a user