Starting attempt #2 at metadata filtering
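Attempt #2 wires metadata filtering in two halves: the indexer now stamps each chunk with the Paperless document type (main.py, request.py), and QueryGenerator can pull a document type out of the user's question (query.py). The diff below does not yet join the two at query time; a rough sketch of how that join could look with ChromaDB's metadata `where` filters follows (the wiring, variable names, and n_results value are illustrative assumptions, not code from this commit):

import os
import chromadb

# Assumed wiring, not part of this commit: filter simba_docs chunks by the doctype
# extracted from the question plus the epoch-seconds date range from QueryGenerator.
client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
simba_docs = client.get_or_create_collection(name="simba_docs2")

doctypes = ["Medical Record", "Bill"]  # e.g. QueryGenerator().get_doctype_query(question)["type"]
date_filter = {"created_date": {"$gte": 1735689600.0}}  # e.g. the extracted created_date as an epoch float

results = simba_docs.query(
    query_texts=["How heavy is Simba?"],
    n_results=5,
    # ChromaDB combines metadata conditions with $and; $in matches any of the listed doctypes.
    where={"$and": [date_filter, {"type": {"$in": doctypes}}]},
)
print(results["documents"])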
@@ -13,7 +13,9 @@ from llm import LLMClient

load_dotenv()

-ollama_client = Client(host=os.getenv("OLLAMA_HOST", "http://localhost:11434"))
+ollama_client = Client(
+    host=os.getenv("OLLAMA_HOST", "http://localhost:11434"), timeout=10.0
+)


def remove_headers_footers(text, header_patterns=None, footer_patterns=None):
llm.py (4 changes)
@@ -3,8 +3,6 @@ import os
from ollama import Client
from openai import OpenAI

-import typing
-
import logging

logging.basicConfig(level=logging.INFO)
@@ -14,7 +12,7 @@ class LLMClient:
    def __init__(self):
        try:
            self.ollama_client = Client(
-                host=os.getenv("OLLAMA_URL", "http://localhost:11434")
+                host=os.getenv("OLLAMA_URL", "http://localhost:11434"), timeout=10.0
            )
            self.ollama_client.chat(
                model="gemma3:4b", messages=[{"role": "system", "content": "test"}]
main.py (54 changes)
@@ -7,12 +7,10 @@ from typing import Any, Union
import argparse
import chromadb
import ollama
-from openai import OpenAI
-

from request import PaperlessNGXService
from chunker import Chunker
from query import QueryGenerator
from cleaner import pdf_to_image, summarize_pdf_image
from llm import LLMClient
@@ -21,13 +19,13 @@ from dotenv import load_dotenv

load_dotenv()

USE_OPENAI = os.getenv("OPENAI_API_KEY") != None

# Configure ollama client with URL from environment or default to localhost
-ollama_client = ollama.Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434"))
+ollama_client = ollama.Client(
+    host=os.getenv("OLLAMA_URL", "http://localhost:11434"), timeout=10.0
+)

client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
-simba_docs = client.get_or_create_collection(name="simba_docs3")
+simba_docs = client.get_or_create_collection(name="simba_docs2")
feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup")

parser = argparse.ArgumentParser(
@@ -46,6 +44,7 @@ llm_client = LLMClient()


def index_using_pdf_llm():
+    logging.info("reindex data...")
    files = ppngx.get_data()
    for file in files:
        document_id = file["id"]
@@ -72,7 +71,7 @@ def date_to_epoch(date_str: str) -> float:
    return date.timestamp()


-def chunk_data(docs: list[dict[str, Union[str, Any]]], collection):
+def chunk_data(docs: list[dict[str, Union[str, Any]]], collection, doctypes):
    # Step 2: Create chunks
    chunker = Chunker(collection)
@@ -85,15 +84,22 @@ def chunk_data(docs: list[dict[str, Union[str, Any]]], collection):
        metadata = {
            "created_date": date_to_epoch(docs[index]["created_date"]),
            "filename": docs[index]["original_file_name"],
+            "document_type": doctypes.get(docs[index]["document_type"], ""),
        }

+        if doctypes:
+            metadata["type"] = doctypes.get(docs[index]["document_type"])
+
        chunker.chunk_document(
            document=text,
            metadata=metadata,
        )
        to_insert.append((docs[index]["id"],))

-    c.executemany("INSERT INTO indexed_documents (paperless_id) values (?)", to_insert)
+    c.executemany(
+        "INSERT INTO indexed_documents (paperless_id) values (?)", to_insert
+    )
    conn.commit()


def chunk_text(texts: list[str], collection):
@@ -169,10 +175,13 @@ def consult_simba_oracle(input: str):
        collection=simba_docs,
    )


def filter_indexed_files(docs):
    with sqlite3.connect("visited.db") as conn:
        c = conn.cursor()
-        c.execute("CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)")
+        c.execute(
+            "CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)"
+        )
        c.execute("SELECT paperless_id FROM indexed_documents")
        rows = c.fetchall()
        conn.commit()
@@ -181,7 +190,6 @@ def filter_indexed_files(docs):
    return [doc for doc in docs if doc["id"] not in visited]


-
if __name__ == "__main__":
    args = parser.parse_args()
    if args.reindex:
@@ -192,20 +200,22 @@ if __name__ == "__main__":
        print(f"Fetched {len(docs)} documents")
        #
        print("Chunking documents now ...")
-        chunk_data(docs, collection=simba_docs)
+        tag_lookup = ppngx.get_tags()
+        doctype_lookup = ppngx.get_doctypes()
+        chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
        print("Done chunking documents")
        # index_using_pdf_llm()

-    if args.index:
-        with open(args.index) as file:
-            extension = args.index.split(".")[-1]
-            if extension == "pdf":
-                pdf_path = ppngx.download_pdf_from_id(id=document_id)
-                image_paths = pdf_to_image(filepath=pdf_path)
-                print(f"summarizing {file}")
-                generated_summary = summarize_pdf_image(filepaths=image_paths)
-            elif extension in [".md", ".txt"]:
-                chunk_text(texts=[file.readall()], collection=simba_docs)
+    # if args.index:
+    #     with open(args.index) as file:
+    #         extension = args.index.split(".")[-1]
+    #         if extension == "pdf":
+    #             pdf_path = ppngx.download_pdf_from_id(id=document_id)
+    #             image_paths = pdf_to_image(filepath=pdf_path)
+    #             print(f"summarizing {file}")
+    #             generated_summary = summarize_pdf_image(filepaths=image_paths)
+    #         elif extension in [".md", ".txt"]:
+    #             chunk_text(texts=[file.readall()], collection=simba_docs)

    if args.query:
        print("Consulting oracle ...")
query.py (105 changes)
@@ -2,7 +2,7 @@ import json
import os
from typing import Literal
import datetime
-from ollama import chat, ChatResponse, Client
+from ollama import Client

from openai import OpenAI
@@ -38,6 +38,20 @@ class Time(BaseModel):
    time: int


+DOCTYPE_OPTIONS = [
+    "Bill",
+    "Image Description",
+    "Insurance",
+    "Medical Record",
+    "Documentation",
+    "Letter",
+]
+
+
+class DocumentType(BaseModel):
+    type: list[str] = Field(description="type of document", enum=DOCTYPE_OPTIONS)
+
+
PROMPT = """
You are an information specialist that processes user queries. The current year is 2025. The user queries are all about
a cat, Simba, and its records. The types of records are listed below. Using the query, extract the
@@ -82,7 +96,8 @@ document_types:
Only return the extracted metadata fields. Make sure the extracted metadata fields are valid JSON
"""

USE_OPENAI = os.getenv("OPENAI_API_KEY", None) != None

+DOCTYPE_PROMPT = f"You are an information specialist that processes user queries. A query can have two tags attached from the following options. Based on the query, determine which of the following options is most appropriate: {','.join(DOCTYPE_OPTIONS)}"
+

class QueryGenerator:
@@ -102,43 +117,67 @@ class QueryGenerator:
        return date.timestamp()

+    def get_doctype_query(self, input: str):
+        print(DOCTYPE_PROMPT)
+        client = OpenAI()
+        response = client.chat.completions.create(
+            messages=[
+                {
+                    "role": "system",
+                    "content": "You are an information specialist that is really good at deciding what tags a query should have",
+                },
+                {"role": "user", "content": DOCTYPE_PROMPT + " " + input},
+            ],
+            model="gpt-4o",
+            response_format={
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "document_type",
+                    "schema": DocumentType.model_json_schema(),
+                },
+            },
+        )
+
+        response_json_str = response.choices[0].message.content
+        type_data = json.loads(response_json_str)
+        print(type_data)
+        return type_data
+
    def get_query(self, input: str):
-        if USE_OPENAI:
-            client = OpenAI()
-            response = client.responses.parse(
-                model="gpt-4o",
-                input=[
-                    {"role": "system", "content": PROMPT},
-                    {"role": "user", "content": input},
-                ],
-                text_format=GeneratedQuery,
-            )
-            print(response.output)
-            query = json.loads(response.output_parsed.extracted_metadata_fields)
-        else:
-            response: ChatResponse = ollama_client.chat(
-                model="gemma3n:e4b",
-                messages=[
-                    {"role": "system", "content": PROMPT},
-                    {"role": "user", "content": input},
-                ],
-                format=GeneratedQuery.model_json_schema(),
-            )
+        client = OpenAI()
+        response = client.responses.parse(
+            model="gpt-4o",
+            input=[
+                {"role": "system", "content": PROMPT},
+                {"role": "user", "content": input},
+            ],
+            text_format=GeneratedQuery,
+        )
+        print(response.output)
+        query = json.loads(response.output_parsed.extracted_metadata_fields)
+        # response: ChatResponse = ollama_client.chat(
+        #     model="gemma3n:e4b",
+        #     messages=[
+        #         {"role": "system", "content": PROMPT},
+        #         {"role": "user", "content": input},
+        #     ],
+        #     format=GeneratedQuery.model_json_schema(),
+        # )

-            query = json.loads(
-                json.loads(response["message"]["content"])["extracted_metadata_fields"]
-            )
-            date_key = list(query["created_date"].keys())[0]
-            query["created_date"][date_key] = self.date_to_epoch(
-                query["created_date"][date_key]
-            )
+        # query = json.loads(
+        #     json.loads(response["message"]["content"])["extracted_metadata_fields"]
+        # )
+        # date_key = list(query["created_date"].keys())[0]
+        # query["created_date"][date_key] = self.date_to_epoch(
+        #     query["created_date"][date_key]
+        # )

-        if "$" not in date_key:
-            query["created_date"]["$" + date_key] = query["created_date"][date_key]
+        # if "$" not in date_key:
+        #     query["created_date"]["$" + date_key] = query["created_date"][date_key]

        return query


if __name__ == "__main__":
    qg = QueryGenerator()
    print(qg.get_query("How heavy is Simba?"))
+    print(qg.get_doctype_query("How heavy is Simba?"))
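Side note on the new DocumentType model: extra keyword arguments to Field such as enum=DOCTYPE_OPTIONS end up in the generated JSON schema (the hint the LLM sees) but are not enforced by pydantic when the response is parsed back. If strict validation is wanted, a Literal-typed variant is one option; a minimal sketch assuming pydantic v2, with StrictDocumentType as a hypothetical name that is not part of this commit:

from typing import Literal
from pydantic import BaseModel

# Hypothetical stricter variant (not in this commit): Literal makes pydantic reject
# values outside the allowed set, instead of only hinting them in the JSON schema.
DoctypeName = Literal[
    "Bill",
    "Image Description",
    "Insurance",
    "Medical Record",
    "Documentation",
    "Letter",
]

class StrictDocumentType(BaseModel):
    type: list[DoctypeName]

print(StrictDocumentType.model_validate({"type": ["Bill", "Insurance"]}))  # parses fine
# StrictDocumentType.model_validate({"type": ["Receipt"]})  # would raise ValidationError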
request.py (14 changes)
@@ -14,7 +14,7 @@ class PaperlessNGXService:
    def __init__(self):
        self.base_url = os.getenv("BASE_URL")
        self.token = os.getenv("PAPERLESS_TOKEN")
-        self.url = f"http://{os.getenv('BASE_URL')}/api/documents/?query=simba"
+        self.url = f"http://{os.getenv('BASE_URL')}/api/documents/?tags__id=8"
        self.headers = {"Authorization": f"Token {os.getenv('PAPERLESS_TOKEN')}"}

    def get_data(self):
@@ -68,6 +68,18 @@ class PaperlessNGXService:
        r = httpx.post(POST_URL, headers=self.headers, data=data, files=files)
        r.raise_for_status()

+    def get_tags(self):
+        GET_URL = f"http://{os.getenv('BASE_URL')}/api/tags/"
+        r = httpx.get(GET_URL, headers=self.headers)
+        data = r.json()
+        return {tag["id"]: tag["name"] for tag in data["results"]}
+
+    def get_doctypes(self):
+        GET_URL = f"http://{os.getenv('BASE_URL')}/api/document_types/"
+        r = httpx.get(GET_URL, headers=self.headers)
+        data = r.json()
+        return {doctype["id"]: doctype["name"] for doctype in data["results"]}
+

if __name__ == "__main__":
    pp = PaperlessNGXService()
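One caveat on get_tags and get_doctypes: Paperless-ngx list endpoints are paginated, so a lookup built from a single data["results"] only reflects the first page of tags or document types. A hedged sketch of a paging variant; get_lookup is a hypothetical helper, not code from this commit:

import os
import httpx

# Hypothetical variant (not in this commit): follow the DRF "next" link so the
# id -> name lookup covers every page of /api/tags/ or /api/document_types/.
def get_lookup(endpoint: str, headers: dict) -> dict[int, str]:
    url = f"http://{os.getenv('BASE_URL')}/api/{endpoint}/"
    lookup: dict[int, str] = {}
    while url:
        data = httpx.get(url, headers=headers).json()
        lookup.update({item["id"]: item["name"] for item in data["results"]})
        url = data.get("next")  # None on the last page
    return lookup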