diff --git a/main.py b/main.py index 16d0313..6c88ded 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,8 @@ import argparse import chromadb import ollama +import time + from request import PaperlessNGXService from chunker import Chunker @@ -36,6 +38,7 @@ parser.add_argument("query", type=str, help="questions about simba's health") parser.add_argument( "--reindex", action="store_true", help="re-index the simba documents" ) +parser.add_argument("--classify", action="store_true", help="test classification") parser.add_argument("--index", help="index a file") ppngx = PaperlessNGXService() @@ -113,13 +116,22 @@ def chunk_text(texts: list[str], collection): ) +def classify_query(query: str, transcript: str) -> bool: + logging.info("Starting query generation") + qg_start = time.time() + qg = QueryGenerator() + query_type = qg.get_query_type(input=query, transcript=transcript) + logging.info(query_type) + qg_end = time.time() + logging.info(f"Query generation took {qg_end - qg_start:.2f} seconds") + return query_type == "Simba" + + def consult_oracle( input: str, collection, transcript: str = "", ): - import time - chunker = Chunker(collection) start_time = time.time() @@ -171,6 +183,16 @@ def consult_oracle( return output +def llm_chat(input: str, transcript: str = "") -> str: + system_prompt = "You are a helpful assistant that understands veterinary terms." + transcript_prompt = f"Here is the message transcript thus far {transcript}." + prompt = f"""Answer the user in a humorous way as if you were a cat named Simba. Be very coy. 
+ {transcript_prompt if len(transcript) > 0 else ""} + Respond to this prompt: {input}""" + output = llm_client.chat(prompt=prompt, system_prompt=system_prompt) + return output + + def paperless_workflow(input): # Step 1: Get the text ppngx = PaperlessNGXService() @@ -181,11 +203,19 @@ def paperless_workflow(input): def consult_simba_oracle(input: str, transcript: str = ""): - return consult_oracle( - input=input, - collection=simba_docs, - transcript=transcript, - ) + is_simba_related = classify_query(query=input, transcript=transcript) + + if is_simba_related: + logging.info("Query is related to simba") + return consult_oracle( + input=input, + collection=simba_docs, + transcript=transcript, + ) + + logging.info("Query is NOT related to simba") + + return llm_chat(input=input, transcript=transcript) def filter_indexed_files(docs): @@ -202,48 +232,45 @@ def filter_indexed_files(docs): return [doc for doc in docs if doc["id"] not in visited] +def reindex(): + with sqlite3.connect("database/visited.db") as conn: + c = conn.cursor() + c.execute("DELETE FROM indexed_documents") + conn.commit() + + # Delete all documents from the collection + all_docs = simba_docs.get() + if all_docs["ids"]: + simba_docs.delete(ids=all_docs["ids"]) + + logging.info("Fetching documents from Paperless-NGX") + ppngx = PaperlessNGXService() + docs = ppngx.get_data() + docs = filter_indexed_files(docs) + logging.info(f"Fetched {len(docs)} documents") + + # Delete all chromadb data + ids = simba_docs.get(ids=None, limit=None, offset=0) + all_ids = ids["ids"] + if len(all_ids) > 0: + simba_docs.delete(ids=all_ids) + + # Chunk documents + logging.info("Chunking documents now ...") + doctype_lookup = ppngx.get_doctypes() + chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup) + logging.info("Done chunking documents") + + if __name__ == "__main__": args = parser.parse_args() if args.reindex: - with sqlite3.connect("database/visited.db") as conn: - c = conn.cursor() - c.execute("DELETE 
FROM indexed_documents") - conn.commit() + reindex() - # Delete all documents from the collection - all_docs = simba_docs.get() - if all_docs["ids"]: - simba_docs.delete(ids=all_docs["ids"]) - - logging.info("Fetching documents from Paperless-NGX") - ppngx = PaperlessNGXService() - docs = ppngx.get_data() - docs = filter_indexed_files(docs) - logging.info(f"Fetched {len(docs)} documents") - - # Delete all chromadb data - ids = simba_docs.get(ids=None, limit=None, offset=0) - all_ids = ids["ids"] - if len(all_ids) > 0: - simba_docs.delete(ids=all_ids) - - # Chunk documents - logging.info("Chunking documents now ...") - tag_lookup = ppngx.get_tags() - doctype_lookup = ppngx.get_doctypes() - chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup) - logging.info("Done chunking documents") - - # if args.index: - # with open(args.index) as file: - # extension = args.index.split(".")[-1] - # if extension == "pdf": - # pdf_path = ppngx.download_pdf_from_id(id=document_id) - # image_paths = pdf_to_image(filepath=pdf_path) - # print(f"summarizing {file}") - # generated_summary = summarize_pdf_image(filepaths=image_paths) - # elif extension in [".md", ".txt"]: - # chunk_text(texts=[file.readall()], collection=simba_docs) + if args.classify: + consult_simba_oracle(input="yohohoho testing") + consult_simba_oracle(input="write an email") + consult_simba_oracle(input="how much does simba weigh") if args.query: logging.info("Consulting oracle ...") diff --git a/query.py b/query.py index 0435a62..974b824 100644 --- a/query.py +++ b/query.py @@ -49,11 +49,20 @@ DOCTYPE_OPTIONS = [ "Letter", ] +QUERY_TYPE_OPTIONS = [ + "Simba", + "Other", +] + class DocumentType(BaseModel): type: list[str] = Field(description="type of document", enum=DOCTYPE_OPTIONS) +class QueryType(BaseModel): + type: str = Field(description="type of query", enum=QUERY_TYPE_OPTIONS) + + PROMPT = """ You are an information specialist that processes user queries. The current year is 2025. 
The user queries are all about a cat, Simba, and its records. The types of records are listed below. Using the query, extract the @@ -111,6 +120,27 @@ Query: "Who does Simba know?" Tags: ["Letter", "Documentation"] """ +QUERY_TYPE_PROMPT = f"""You are an information specialist that processes user queries. +A query can have one tag attached from the following options. Based on the query and the transcript which is listed below, determine + which of the following options is most appropriate: {",".join(QUERY_TYPE_OPTIONS)} + +### Example 1 +Query: "Who is Simba's current vet?" +Tags: ["Simba"] + + +### Example 2 +Query: "What is the capital of Japan?" +Tags: ["Other"] + + +### Example 3 +Query: "Can you help me write an email?" +Tags: ["Other"] + +TRANSCRIPT: +""" + class QueryGenerator: def __init__(self) -> None: @@ -154,6 +184,33 @@ class QueryGenerator: metadata_query = {"document_type": {"$in": type_data["type"]}} return metadata_query + def get_query_type(self, input: str, transcript: str): + client = OpenAI() + response = client.chat.completions.create( + messages=[ + { + "role": "system", + "content": "You are an information specialist that is really good at deciding what tags a query should have", + }, + { + "role": "user", + "content": f"{QUERY_TYPE_PROMPT}\nTRANSCRIPT:\n{transcript}\nQUERY:{input}", + }, + ], + model="gpt-4o", + response_format={ + "type": "json_schema", + "json_schema": { + "name": "query_type", + "schema": QueryType.model_json_schema(), + }, + }, + ) + + response_json_str = response.choices[0].message.content + type_data = json.loads(response_json_str) + return type_data["type"] + def get_query(self, input: str): client = OpenAI() response = client.responses.parse(