query classification
This commit is contained in:
59
main.py
59
main.py
@@ -7,6 +7,8 @@ import argparse
|
||||
import chromadb
|
||||
import ollama
|
||||
|
||||
import time
|
||||
|
||||
|
||||
from request import PaperlessNGXService
|
||||
from chunker import Chunker
|
||||
@@ -36,6 +38,7 @@ parser.add_argument("query", type=str, help="questions about simba's health")
|
||||
parser.add_argument(
|
||||
"--reindex", action="store_true", help="re-index the simba documents"
|
||||
)
|
||||
parser.add_argument("--classify", action="store_true", help="test classification")
|
||||
parser.add_argument("--index", help="index a file")
|
||||
|
||||
ppngx = PaperlessNGXService()
|
||||
@@ -113,13 +116,22 @@ def chunk_text(texts: list[str], collection):
|
||||
)
|
||||
|
||||
|
||||
def classify_query(query: str, transcript: str) -> bool:
    """Decide whether a query is about Simba.

    Asks ``QueryGenerator.get_query_type`` to label the query (with the
    transcript passed along as context), logs the label and how long the
    call took, and reports whether the label is exactly ``"Simba"``.

    Args:
        query: The user's current message.
        transcript: The conversation so far, forwarded unchanged.

    Returns:
        True when the returned label equals ``"Simba"``, False otherwise.
    """
    logging.info("Starting query generation")
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments the way a difference of time.time() values can.
    qg_start = time.perf_counter()
    query_type = QueryGenerator().get_query_type(input=query, transcript=transcript)
    logging.info(query_type)
    qg_end = time.perf_counter()
    logging.info(f"Query generation took {qg_end - qg_start:.2f} seconds")
    return query_type == "Simba"
|
||||
|
||||
|
||||
def consult_oracle(
|
||||
input: str,
|
||||
collection,
|
||||
transcript: str = "",
|
||||
):
|
||||
import time
|
||||
|
||||
chunker = Chunker(collection)
|
||||
|
||||
start_time = time.time()
|
||||
@@ -171,6 +183,16 @@ def consult_oracle(
|
||||
return output
|
||||
|
||||
|
||||
def llm_chat(input: str, transcript: str = "") -> str:
    """Answer a query directly through ``llm_client``, in Simba's voice.

    Builds a prompt that asks for a humorous, coy answer as a cat named
    Simba, optionally including the transcript, and returns whatever
    ``llm_client.chat`` returns.

    Args:
        input: The user's current message.
        transcript: The conversation so far; when empty (or None) the
            transcript sentence is left out of the prompt.

    Returns:
        The value returned by ``llm_client.chat``.
    """
    system_prompt = "You are a helpful assistant that understands veterinary terms."
    # Only build the transcript sentence when there is a transcript. The
    # truthiness check also tolerates None, which len() would reject.
    transcript_prompt = (
        f"Here is the message transcript thus far {transcript}." if transcript else ""
    )
    prompt = f"""Answer the user in a humorous way as if you were a cat named Simba. Be very coy.
{transcript_prompt}
Respond to this prompt: {input}"""
    return llm_client.chat(prompt=prompt, system_prompt=system_prompt)
|
||||
|
||||
|
||||
def paperless_workflow(input):
|
||||
# Step 1: Get the text
|
||||
ppngx = PaperlessNGXService()
|
||||
@@ -181,12 +203,20 @@ def paperless_workflow(input):
|
||||
|
||||
|
||||
def consult_simba_oracle(input: str, transcript: str = ""):
    """Route a query to the document oracle or to plain chat.

    The query is classified first. Simba-related queries are answered by
    ``consult_oracle`` against the ``simba_docs`` collection; everything
    else is answered by ``llm_chat``.

    Args:
        input: The user's current message.
        transcript: The conversation so far, forwarded to every step.

    Returns:
        The result of ``consult_oracle`` or ``llm_chat``, whichever ran.
    """
    if not classify_query(query=input, transcript=transcript):
        logging.info("Query is NOT related to simba")
        return llm_chat(input=input, transcript=transcript)

    logging.info("Query is related to simba")
    return consult_oracle(input=input, collection=simba_docs, transcript=transcript)
|
||||
|
||||
|
||||
def filter_indexed_files(docs):
|
||||
with sqlite3.connect("database/visited.db") as conn:
|
||||
@@ -202,9 +232,7 @@ def filter_indexed_files(docs):
|
||||
return [doc for doc in docs if doc["id"] not in visited]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
if args.reindex:
|
||||
def reindex():
|
||||
with sqlite3.connect("database/visited.db") as conn:
|
||||
c = conn.cursor()
|
||||
c.execute("DELETE FROM indexed_documents")
|
||||
@@ -229,21 +257,20 @@ if __name__ == "__main__":
|
||||
|
||||
# Chunk documents
|
||||
logging.info("Chunking documents now ...")
|
||||
tag_lookup = ppngx.get_tags()
|
||||
doctype_lookup = ppngx.get_doctypes()
|
||||
chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
|
||||
logging.info("Done chunking documents")
|
||||
|
||||
# if args.index:
|
||||
# with open(args.index) as file:
|
||||
# extension = args.index.split(".")[-1]
|
||||
# if extension == "pdf":
|
||||
# pdf_path = ppngx.download_pdf_from_id(id=document_id)
|
||||
# image_paths = pdf_to_image(filepath=pdf_path)
|
||||
# print(f"summarizing {file}")
|
||||
# generated_summary = summarize_pdf_image(filepaths=image_paths)
|
||||
# elif extension in [".md", ".txt"]:
|
||||
# chunk_text(texts=[file.readall()], collection=simba_docs)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
if args.reindex:
|
||||
reindex()
|
||||
|
||||
if args.classify:
|
||||
consult_simba_oracle(input="yohohoho testing")
|
||||
consult_simba_oracle(input="write an email")
|
||||
consult_simba_oracle(input="how much does simba weigh")
|
||||
|
||||
if args.query:
|
||||
logging.info("Consulting oracle ...")
|
||||
|
||||
57
query.py
57
query.py
@@ -49,11 +49,20 @@ DOCTYPE_OPTIONS = [
|
||||
"Letter",
|
||||
]
|
||||
|
||||
# Labels a query can be classified as; these values are offered to the model
# in the classification prompt and as the enum of the QueryType schema.
QUERY_TYPE_OPTIONS = ["Simba", "Other"]
|
||||
|
||||
|
||||
# Structured-output schema for the document types selected for a query.
# Documented with a comment on purpose: a class docstring would be emitted
# as the schema's description and change what is generated from this model.
class DocumentType(BaseModel):
    type: list[str] = Field(
        description="type of document",
        enum=DOCTYPE_OPTIONS,
    )
|
||||
|
||||
|
||||
# Structured-output schema for query classification: a single label drawn
# from QUERY_TYPE_OPTIONS. Documented with a comment on purpose: a class
# docstring would be emitted as the schema's description.
class QueryType(BaseModel):
    # "description" was previously misspelled, so the text never reached the
    # generated JSON schema under the proper key.
    type: str = Field(description="type of query", enum=QUERY_TYPE_OPTIONS)
|
||||
|
||||
|
||||
PROMPT = """
|
||||
You are an information specialist that processes user queries. The current year is 2025. The user queries are all about
|
||||
a cat, Simba, and its records. The types of records are listed below. Using the query, extract the
|
||||
@@ -111,6 +120,27 @@ Query: "Who does Simba know?"
|
||||
Tags: ["Letter", "Documentation"]
|
||||
"""
|
||||
|
||||
# Instruction block for query classification. The transcript and the query
# are appended by the caller under their own "TRANSCRIPT:" / "QUERY:"
# headers, so this text must not end with a header of its own (it used to
# end with a dangling "TRANSCRIPT:", which duplicated the caller's header).
QUERY_TYPE_PROMPT = f"""You are an information specialist that processes user queries.
A query can have one tag attached from the following options. Based on the query and the transcript which is listed below, determine
which of the following options is most appropriate: {",".join(QUERY_TYPE_OPTIONS)}

### Example 1
Query: "Who is Simba's current vet?"
Tags: ["Simba"]


### Example 2
Query: "What is the capital of Tokyo?"
Tags: ["Other"]


### Example 3
Query: "Can you help me write an email?"
Tags: ["Other"]
"""
|
||||
|
||||
|
||||
class QueryGenerator:
|
||||
def __init__(self) -> None:
|
||||
@@ -154,6 +184,33 @@ class QueryGenerator:
|
||||
metadata_query = {"document_type": {"$in": type_data["type"]}}
|
||||
return metadata_query
|
||||
|
||||
def get_query_type(self, input: str, transcript: str) -> str:
    """Classify a query into one of the query-type labels.

    Sends the classification prompt, the transcript and the query to the
    ``gpt-4o`` chat completions endpoint, requesting JSON that follows the
    ``QueryType`` schema, and returns the ``"type"`` value of the reply.

    Args:
        input: The user's current message.
        transcript: The conversation so far, included as context.

    Returns:
        The value of the ``"type"`` key in the model's JSON reply.

    Raises:
        ValueError: If the reply carries no text content to parse.
    """
    client = OpenAI()
    response = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": "You are an information specialist that is really good at deciding what tags a query should have",
            },
            {
                "role": "user",
                "content": f"{QUERY_TYPE_PROMPT}\nTRANSCRIPT:\n{transcript}\nQUERY:{input}",
            },
        ],
        model="gpt-4o",
        response_format={
            "type": "json_schema",
            "json_schema": {
                "name": "query_type",
                "schema": QueryType.model_json_schema(),
            },
        },
    )

    response_json_str = response.choices[0].message.content
    if response_json_str is None:
        # The message content is optional in the API response; fail with a
        # clear message instead of an opaque TypeError from json.loads(None).
        raise ValueError("query type classification returned no content")

    type_data = json.loads(response_json_str)
    return type_data["type"]
|
||||
|
||||
def get_query(self, input: str):
|
||||
client = OpenAI()
|
||||
response = client.responses.parse(
|
||||
|
||||
Reference in New Issue
Block a user