Starting attempt #2 at metadata filtering
@@ -13,7 +13,9 @@ from llm import LLMClient
 
 load_dotenv()
 
-ollama_client = Client(host=os.getenv("OLLAMA_HOST", "http://localhost:11434"))
+ollama_client = Client(
+    host=os.getenv("OLLAMA_HOST", "http://localhost:11434"), timeout=10.0
+)
 
 
 def remove_headers_footers(text, header_patterns=None, footer_patterns=None):
llm.py (4 changed lines)
@@ -3,8 +3,6 @@ import os
 from ollama import Client
 from openai import OpenAI
 
-import typing
-
 import logging
 
 logging.basicConfig(level=logging.INFO)
@@ -14,7 +12,7 @@ class LLMClient:
     def __init__(self):
        try:
             self.ollama_client = Client(
-                host=os.getenv("OLLAMA_URL", "http://localhost:11434")
+                host=os.getenv("OLLAMA_URL", "http://localhost:11434"), timeout=10.0
             )
             self.ollama_client.chat(
                 model="gemma3:4b", messages=[{"role": "system", "content": "test"}]
main.py (56 changed lines)
@@ -7,12 +7,10 @@ from typing import Any, Union
 import argparse
 import chromadb
 import ollama
-from openai import OpenAI
 
 
 from request import PaperlessNGXService
 from chunker import Chunker
-from query import QueryGenerator
 from cleaner import pdf_to_image, summarize_pdf_image
 from llm import LLMClient
 
@@ -21,13 +19,13 @@ from dotenv import load_dotenv
 
 load_dotenv()
 
-USE_OPENAI = os.getenv("OPENAI_API_KEY") != None
-
 # Configure ollama client with URL from environment or default to localhost
-ollama_client = ollama.Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434"))
+ollama_client = ollama.Client(
+    host=os.getenv("OLLAMA_URL", "http://localhost:11434"), timeout=10.0
+)
 
 client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
-simba_docs = client.get_or_create_collection(name="simba_docs3")
+simba_docs = client.get_or_create_collection(name="simba_docs2")
 feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup")
 
 parser = argparse.ArgumentParser(
@@ -46,6 +44,7 @@ llm_client = LLMClient()
 
 
 def index_using_pdf_llm():
+    logging.info("reindex data...")
     files = ppngx.get_data()
     for file in files:
         document_id = file["id"]
@@ -72,28 +71,35 @@ def date_to_epoch(date_str: str) -> float:
     return date.timestamp()
 
 
-def chunk_data(docs: list[dict[str, Union[str, Any]]], collection):
+def chunk_data(docs: list[dict[str, Union[str, Any]]], collection, doctypes):
     # Step 2: Create chunks
     chunker = Chunker(collection)
 
     print(f"chunking {len(docs)} documents")
     texts: list[str] = [doc["content"] for doc in docs]
     with sqlite3.connect("visited.db") as conn:
         to_insert = []
         c = conn.cursor()
         for index, text in enumerate(texts):
             metadata = {
                 "created_date": date_to_epoch(docs[index]["created_date"]),
                 "filename": docs[index]["original_file_name"],
+                "document_type": doctypes.get(docs[index]["document_type"], ""),
             }
 
+            if doctypes:
+                metadata["type"] = doctypes.get(docs[index]["document_type"])
+
             chunker.chunk_document(
                 document=text,
                 metadata=metadata,
             )
             to_insert.append((docs[index]["id"],))
 
-        c.executemany("INSERT INTO indexed_documents (paperless_id) values (?)", to_insert)
+        c.executemany(
+            "INSERT INTO indexed_documents (paperless_id) values (?)", to_insert
+        )
+        conn.commit()
 
 
 def chunk_text(texts: list[str], collection):
@@ -169,10 +175,13 @@ def consult_simba_oracle(input: str):
         collection=simba_docs,
     )
 
+
 def filter_indexed_files(docs):
     with sqlite3.connect("visited.db") as conn:
         c = conn.cursor()
-        c.execute("CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)")
+        c.execute(
+            "CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)"
+        )
         c.execute("SELECT paperless_id FROM indexed_documents")
         rows = c.fetchall()
         conn.commit()
@@ -181,7 +190,6 @@ def filter_indexed_files(docs):
     return [doc for doc in docs if doc["id"] not in visited]
 
 
-
 if __name__ == "__main__":
     args = parser.parse_args()
     if args.reindex:
@@ -192,20 +200,22 @@ if __name__ == "__main__":
         print(f"Fetched {len(docs)} documents")
         #
         print("Chunking documents now ...")
-        chunk_data(docs, collection=simba_docs)
+        tag_lookup = ppngx.get_tags()
+        doctype_lookup = ppngx.get_doctypes()
+        chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
         print("Done chunking documents")
         # index_using_pdf_llm()
 
-    if args.index:
-        with open(args.index) as file:
-            extension = args.index.split(".")[-1]
-            if extension == "pdf":
-                pdf_path = ppngx.download_pdf_from_id(id=document_id)
-                image_paths = pdf_to_image(filepath=pdf_path)
-                print(f"summarizing {file}")
-                generated_summary = summarize_pdf_image(filepaths=image_paths)
-            elif extension in [".md", ".txt"]:
-                chunk_text(texts=[file.readall()], collection=simba_docs)
+    # if args.index:
+    #     with open(args.index) as file:
+    #         extension = args.index.split(".")[-1]
+    #         if extension == "pdf":
+    #             pdf_path = ppngx.download_pdf_from_id(id=document_id)
+    #             image_paths = pdf_to_image(filepath=pdf_path)
+    #             print(f"summarizing {file}")
+    #             generated_summary = summarize_pdf_image(filepaths=image_paths)
+    #     elif extension in [".md", ".txt"]:
+    #         chunk_text(texts=[file.readall()], collection=simba_docs)
 
     if args.query:
         print("Consulting oracle ...")
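
Note on intent, not part of the commit: chunk_data() now stores a document_type string on every chunk, which is what makes metadata filtering at query time possible. A minimal query-time sketch, assuming ChromaDB's standard where-filter syntax; the collection name matches main.py above and the "Medical Record" value is only illustrative.

import os

import chromadb

# Hypothetical sketch (not code from this commit): narrow retrieval using the
# "document_type" metadata that chunk_data() writes for each chunk.
client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
simba_docs = client.get_or_create_collection(name="simba_docs2")

results = simba_docs.query(
    query_texts=["How heavy is Simba?"],
    n_results=5,
    where={"document_type": {"$in": ["Medical Record"]}},  # illustrative value
)
print(results["documents"])
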
query.py (105 changed lines)
@@ -2,7 +2,7 @@ import json
 import os
 from typing import Literal
 import datetime
-from ollama import chat, ChatResponse, Client
+from ollama import Client
 
 from openai import OpenAI
 
@@ -38,6 +38,20 @@ class Time(BaseModel):
     time: int
 
 
+DOCTYPE_OPTIONS = [
+    "Bill",
+    "Image Description",
+    "Insurance",
+    "Medical Record",
+    "Documentation",
+    "Letter",
+]
+
+
+class DocumentType(BaseModel):
+    type: list[str] = Field(description="type of document", enum=DOCTYPE_OPTIONS)
+
+
 PROMPT = """
 You are an information specialist that processes user queries. The current year is 2025. The user queries are all about
 a cat, Simba, and its records. The types of records are listed below. Using the query, extract the
@@ -82,7 +96,8 @@ document_types:
 Only return the extracted metadata fields. Make sure the extracted metadata fields are valid JSON
 """
 
-USE_OPENAI = os.getenv("OPENAI_API_KEY", None) != None
+
+DOCTYPE_PROMPT = f"You are an information specialist that processes user queries. A query can have two tags attached from the following options. Based on the query, determine which of the following options is most appropriate: {','.join(DOCTYPE_OPTIONS)}"
 
 
 class QueryGenerator:
@@ -102,43 +117,67 @@ class QueryGenerator:
 
         return date.timestamp()
 
+    def get_doctype_query(self, input: str):
+        print(DOCTYPE_PROMPT)
+        client = OpenAI()
+        response = client.chat.completions.create(
+            messages=[
+                {
+                    "role": "system",
+                    "content": "You are an information specialist that is really good at deciding what tags a query should have",
+                },
+                {"role": "user", "content": DOCTYPE_PROMPT + " " + input},
+            ],
+            model="gpt-4o",
+            response_format={
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "document_type",
+                    "schema": DocumentType.model_json_schema(),
+                },
+            },
+        )
+
+        response_json_str = response.choices[0].message.content
+        type_data = json.loads(response_json_str)
+        print(type_data)
+        return type_data
+
     def get_query(self, input: str):
-        if USE_OPENAI:
-            client = OpenAI()
-            response = client.responses.parse(
-                model="gpt-4o",
-                input=[
-                    {"role": "system", "content": PROMPT},
-                    {"role": "user", "content": input},
-                ],
-                text_format=GeneratedQuery,
-            )
-            print(response.output)
-            query = json.loads(response.output_parsed.extracted_metadata_fields)
-        else:
-            response: ChatResponse = ollama_client.chat(
-                model="gemma3n:e4b",
-                messages=[
-                    {"role": "system", "content": PROMPT},
-                    {"role": "user", "content": input},
-                ],
-                format=GeneratedQuery.model_json_schema(),
-            )
+        client = OpenAI()
+        response = client.responses.parse(
+            model="gpt-4o",
+            input=[
+                {"role": "system", "content": PROMPT},
+                {"role": "user", "content": input},
+            ],
+            text_format=GeneratedQuery,
+        )
+        print(response.output)
+        query = json.loads(response.output_parsed.extracted_metadata_fields)
+        # response: ChatResponse = ollama_client.chat(
+        #     model="gemma3n:e4b",
+        #     messages=[
+        #         {"role": "system", "content": PROMPT},
+        #         {"role": "user", "content": input},
+        #     ],
+        #     format=GeneratedQuery.model_json_schema(),
+        # )
 
-            query = json.loads(
-                json.loads(response["message"]["content"])["extracted_metadata_fields"]
-            )
-            date_key = list(query["created_date"].keys())[0]
-            query["created_date"][date_key] = self.date_to_epoch(
-                query["created_date"][date_key]
-            )
+        # query = json.loads(
+        #     json.loads(response["message"]["content"])["extracted_metadata_fields"]
+        # )
+        # date_key = list(query["created_date"].keys())[0]
+        # query["created_date"][date_key] = self.date_to_epoch(
+        #     query["created_date"][date_key]
+        # )
 
-            if "$" not in date_key:
-                query["created_date"]["$" + date_key] = query["created_date"][date_key]
+        # if "$" not in date_key:
+        #     query["created_date"]["$" + date_key] = query["created_date"][date_key]
 
         return query
 
 
 if __name__ == "__main__":
     qg = QueryGenerator()
-    print(qg.get_query("How heavy is Simba?"))
+    print(qg.get_doctype_query("How heavy is Simba?"))
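
The commit adds doctype extraction but does not yet apply it to retrieval. A minimal sketch of how get_doctype_query() and get_query() might be combined into a single ChromaDB where clause; the response shapes in the comments and the $and/$in filter are assumptions, not code from this commit.

# Hypothetical wiring sketch (not part of this commit).
from query import QueryGenerator

qg = QueryGenerator()
question = "What vet bills did Simba have in March 2025?"
date_query = qg.get_query(question)          # assumed shape, e.g. {"created_date": {"$gte": 1740787200.0}}
type_query = qg.get_doctype_query(question)  # assumed shape per DocumentType, e.g. {"type": ["Bill"]}

# Merge both extractions into one metadata filter for collection.query(where=...).
where = {
    "$and": [
        {"created_date": date_query["created_date"]},
        {"document_type": {"$in": type_query["type"]}},
    ]
}
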
request.py (14 changed lines)
@@ -14,7 +14,7 @@ class PaperlessNGXService:
     def __init__(self):
         self.base_url = os.getenv("BASE_URL")
         self.token = os.getenv("PAPERLESS_TOKEN")
-        self.url = f"http://{os.getenv('BASE_URL')}/api/documents/?query=simba"
+        self.url = f"http://{os.getenv('BASE_URL')}/api/documents/?tags__id=8"
         self.headers = {"Authorization": f"Token {os.getenv('PAPERLESS_TOKEN')}"}
 
     def get_data(self):
@@ -68,6 +68,18 @@ class PaperlessNGXService:
         r = httpx.post(POST_URL, headers=self.headers, data=data, files=files)
         r.raise_for_status()
 
+    def get_tags(self):
+        GET_URL = f"http://{os.getenv('BASE_URL')}/api/tags/"
+        r = httpx.get(GET_URL, headers=self.headers)
+        data = r.json()
+        return {tag["id"]: tag["name"] for tag in data["results"]}
+
+    def get_doctypes(self):
+        GET_URL = f"http://{os.getenv('BASE_URL')}/api/document_types/"
+        r = httpx.get(GET_URL, headers=self.headers)
+        data = r.json()
+        return {doctype["id"]: doctype["name"] for doctype in data["results"]}
+
 
 if __name__ == "__main__":
     pp = PaperlessNGXService()