query classification

2025-10-26 17:29:00 -04:00
parent 591788dfa4
commit e577cb335b
2 changed files with 129 additions and 45 deletions

main.py (117 changes)

@@ -7,6 +7,8 @@ import argparse
 import chromadb
 import ollama
+import time
+
 from request import PaperlessNGXService
 from chunker import Chunker
@@ -36,6 +38,7 @@ parser.add_argument("query", type=str, help="questions about simba's health")
 parser.add_argument(
     "--reindex", action="store_true", help="re-index the simba documents"
 )
+parser.add_argument("--classify", action="store_true", help="test classification")
 parser.add_argument("--index", help="index a file")
 
 ppngx = PaperlessNGXService()
@@ -113,13 +116,22 @@ def chunk_text(texts: list[str], collection):
     )
 
 
+def classify_query(query: str, transcript: str) -> bool:
+    logging.info("Starting query generation")
+    qg_start = time.time()
+    qg = QueryGenerator()
+    query_type = qg.get_query_type(input=query, transcript=transcript)
+    logging.info(query_type)
+    qg_end = time.time()
+    logging.info(f"Query generation took {qg_end - qg_start:.2f} seconds")
+    return query_type == "Simba"
+
+
 def consult_oracle(
     input: str,
     collection,
     transcript: str = "",
 ):
-    import time
-
     chunker = Chunker(collection)
     start_time = time.time()
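
Note: QueryGenerator itself is not part of this diff; only its interface and the "Simba" label appear in classify_query above. A minimal sketch of what get_query_type might look like, assuming it asks a local ollama model for a one-word label; the model name and prompt wording are assumptions.

# Hypothetical sketch of the QueryGenerator used by classify_query.
# Only the get_query_type interface and the "Simba" label come from
# this commit; the model name and prompt are assumptions.
import ollama

class QueryGenerator:
    def get_query_type(self, input: str, transcript: str = "") -> str:
        prompt = (
            "Classify the user's query. Reply with exactly one word: "
            "'Simba' if it is about Simba the cat (health, vet records, "
            "weight, medication), otherwise 'General'.\n"
            f"Transcript so far: {transcript}\n"
            f"Query: {input}"
        )
        response = ollama.chat(
            model="llama3.2",  # assumed model
            messages=[{"role": "user", "content": prompt}],
        )
        return response["message"]["content"].strip()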
@@ -171,6 +183,16 @@ def consult_oracle(
     return output
 
 
+def llm_chat(input: str, transcript: str = "") -> str:
+    system_prompt = "You are a helpful assistant that understands veterinary terms."
+    transcript_prompt = f"Here is the message transcript thus far {transcript}."
+    prompt = f"""Answer the user in a humorous way as if you were a cat named Simba. Be very coy.
+{transcript_prompt if len(transcript) > 0 else ""}
+Respond to this prompt: {input}"""
+    output = llm_client.chat(prompt=prompt, system_prompt=system_prompt)
+    return output
+
+
 def paperless_workflow(input):
     # Step 1: Get the text
     ppngx = PaperlessNGXService()
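
Similarly, llm_client and its chat(prompt=..., system_prompt=...) signature are defined elsewhere in the repo; a compatible stand-in over ollama could look like this (the model name is assumed).

# Hypothetical stand-in for the project's llm_client; it matches the
# chat(prompt=..., system_prompt=...) call in llm_chat above but is
# otherwise an assumption.
import ollama

class LLMClient:
    def __init__(self, model: str = "llama3.2"):  # assumed model
        self.model = model

    def chat(self, prompt: str, system_prompt: str = "") -> str:
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        response = ollama.chat(model=self.model, messages=messages)
        return response["message"]["content"]

llm_client = LLMClient()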
@@ -181,11 +203,19 @@ def paperless_workflow(input)
 def consult_simba_oracle(input: str, transcript: str = ""):
-    return consult_oracle(
-        input=input,
-        collection=simba_docs,
-        transcript=transcript,
-    )
+    is_simba_related = classify_query(query=input, transcript=transcript)
+    if is_simba_related:
+        logging.info("Query is related to simba")
+        return consult_oracle(
+            input=input,
+            collection=simba_docs,
+            transcript=transcript,
+        )
+
+    logging.info("Query is NOT related to simba")
+    return llm_chat(input=input, transcript=transcript)
 
 
 def filter_indexed_files(docs):
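
The next hunk shows only the return line of filter_indexed_files, so `visited` is defined off-screen. A sketch of how it is presumably loaded; the table name matches the DELETE FROM indexed_documents statement below, but the column name is an assumption.

# Hypothetical reconstruction of the off-screen `visited` lookup.
# The indexed_documents table name comes from this diff; the
# document_id column is an assumption.
import sqlite3

def load_visited(db_path: str = "database/visited.db") -> set[str]:
    with sqlite3.connect(db_path) as conn:
        c = conn.cursor()
        c.execute("SELECT document_id FROM indexed_documents")
        return {row[0] for row in c.fetchall()}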
@@ -202,48 +232,45 @@ def filter_indexed_files(docs):
     return [doc for doc in docs if doc["id"] not in visited]
 
 
+def reindex():
+    with sqlite3.connect("database/visited.db") as conn:
+        c = conn.cursor()
+        c.execute("DELETE FROM indexed_documents")
+        conn.commit()
+
+    # Delete all documents from the collection
+    all_docs = simba_docs.get()
+    if all_docs["ids"]:
+        simba_docs.delete(ids=all_docs["ids"])
+
+    logging.info("Fetching documents from Paperless-NGX")
+    ppngx = PaperlessNGXService()
+    docs = ppngx.get_data()
+    docs = filter_indexed_files(docs)
+    logging.info(f"Fetched {len(docs)} documents")
+
+    # Delete all chromadb data
+    ids = simba_docs.get(ids=None, limit=None, offset=0)
+    all_ids = ids["ids"]
+    if len(all_ids) > 0:
+        simba_docs.delete(ids=all_ids)
+
+    # Chunk documents
+    logging.info("Chunking documents now ...")
+    doctype_lookup = ppngx.get_doctypes()
+    chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
+    logging.info("Done chunking documents")
+
+
 if __name__ == "__main__":
     args = parser.parse_args()
 
     if args.reindex:
-        with sqlite3.connect("database/visited.db") as conn:
-            c = conn.cursor()
-            c.execute("DELETE FROM indexed_documents")
-            conn.commit()
+        reindex()
-
-        # Delete all documents from the collection
-        all_docs = simba_docs.get()
-        if all_docs["ids"]:
-            simba_docs.delete(ids=all_docs["ids"])
-
-        logging.info("Fetching documents from Paperless-NGX")
-        ppngx = PaperlessNGXService()
-        docs = ppngx.get_data()
-        docs = filter_indexed_files(docs)
-        logging.info(f"Fetched {len(docs)} documents")
-
-        # Delete all chromadb data
-        ids = simba_docs.get(ids=None, limit=None, offset=0)
-        all_ids = ids["ids"]
-        if len(all_ids) > 0:
-            simba_docs.delete(ids=all_ids)
-
-        # Chunk documents
-        logging.info("Chunking documents now ...")
-        tag_lookup = ppngx.get_tags()
-        doctype_lookup = ppngx.get_doctypes()
-        chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
-        logging.info("Done chunking documents")
 
     # if args.index:
     #     with open(args.index) as file:
     #         extension = args.index.split(".")[-1]
     #         if extension == "pdf":
     #             pdf_path = ppngx.download_pdf_from_id(id=document_id)
     #             image_paths = pdf_to_image(filepath=pdf_path)
     #             print(f"summarizing {file}")
     #             generated_summary = summarize_pdf_image(filepaths=image_paths)
     #         elif extension in [".md", ".txt"]:
     #             chunk_text(texts=[file.readall()], collection=simba_docs)
 
+    if args.classify:
+        consult_simba_oracle(input="yohohoho testing")
+        consult_simba_oracle(input="write an email")
+        consult_simba_oracle(input="how much does simba weigh")
+
     if args.query:
         logging.info("Consulting oracle ...")
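
The three --classify probes above are smoke tests for both routing branches: two queries that should fall through to llm_chat and one that should hit the simba_docs collection. A hedged pytest sketch of that contract, assuming main.py imports cleanly and that non-Simba queries receive some other label such as "General".

# Hypothetical test of classify_query's routing contract; the module
# name (main) and the "General" label are assumptions.
from unittest.mock import patch

import main

def make_fake_qg(label):
    class FakeQueryGenerator:
        def get_query_type(self, input, transcript):
            return label

    return FakeQueryGenerator

def test_classify_query_routes_on_label():
    with patch.object(main, "QueryGenerator", make_fake_qg("Simba")):
        assert main.classify_query(query="how much does simba weigh", transcript="")
    with patch.object(main, "QueryGenerator", make_fake_qg("General")):
        assert not main.classify_query(query="write an email", transcript="")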