2025-08-07 17:43:24 -04:00
parent fc504d3e9c
commit 679cfb08e4
5 changed files with 294 additions and 32 deletions

chunker.py

@@ -1,6 +1,7 @@
import os
from math import ceil
import re
from typing import Union
from uuid import UUID, uuid4
from chromadb.utils.embedding_functions.ollama_embedding_function import (
@@ -87,7 +88,12 @@ class Chunker:
def __init__(self, collection) -> None:
self.collection = collection
def chunk_document(self, document: str, chunk_size: int = 1000) -> list[Chunk]:
def chunk_document(
self,
document: str,
chunk_size: int = 1000,
metadata: dict[str, Union[str, float]] = {},
) -> list[Chunk]:
doc_uuid = uuid4()
chunk_size = min(chunk_size, len(document))
@@ -110,6 +116,7 @@ class Chunker:
ids=[str(doc_uuid) + ":" + str(i)],
documents=[text_chunk],
embeddings=embedding,
metadatas=[metadata],
)
return chunks
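
With the extra parameter, each chunk can now carry per-document metadata that Chroma stores alongside the embedding. A minimal sketch of the intended call, run in this module's context (the collection handle, sample text, and date value are illustrative):

chunker = Chunker(collection)  # any Chroma collection handle
chunker.chunk_document(
    document="Simba was seen for a dental cleaning on 2025-01-01 ...",
    metadata={"created_date": 1735689600.0},  # epoch seconds, e.g. from main.py's date_to_epoch()
)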

cleaner.py (new file, 162 lines)

@@ -0,0 +1,162 @@
import os
import sys
import tempfile
import argparse
from dotenv import load_dotenv
import ollama
from PIL import Image
import fitz
from request import PaperlessNGXService
load_dotenv()
parser = argparse.ArgumentParser(description="use an LLM to clean documents")
parser.add_argument("document_id", type=str, help="id of the Paperless-NGX document to clean")
def pdf_to_image(filepath: str, dpi=300) -> list[str]:
"""Returns the filepaths to the created images"""
image_temp_files = []
try:
pdf_document = fitz.open(filepath)
print(f"\nConverting '{os.path.basename(filepath)}' to temporary images...")
for page_num in range(len(pdf_document)):
page = pdf_document.load_page(page_num)
zoom = dpi / 72
mat = fitz.Matrix(zoom, zoom)
pix = page.get_pixmap(matrix=mat)
# Create a temporary file for the image. delete=False is crucial.
with tempfile.NamedTemporaryFile(
delete=False,
suffix=".png",
prefix=f"pdf_page_{page_num + 1}_",
) as temp_image_file:
temp_image_path = temp_image_file.name
# Save the pixel data to the temporary file
pix.save(temp_image_path)
image_temp_files.append(temp_image_path)
print(
f" -> Saved page {page_num + 1} to temporary file: '{temp_image_path}'"
)
print("\nConversion successful! ✨")
return image_temp_files
except Exception as e:
print(f"An error occurred during PDF conversion: {e}", file=sys.stderr)
# Clean up any image files that were created before the error
for path in image_temp_files:
os.remove(path)
return []
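
Because the page images are created with delete=False, they survive after the function returns; the caller is expected to remove them once OCR is done. A purely illustrative cleanup sketch, using names from this module (the input path is hypothetical):

image_paths = pdf_to_image(filepath="/tmp/example.pdf")  # hypothetical input PDF
try:
    text = summarize_pdf_image(filepaths=image_paths)
finally:
    for path in image_paths:
        os.remove(path)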
def merge_images_vertically_to_tempfile(image_paths):
"""
Merges a list of images vertically and saves the result to a temporary file.
Args:
image_paths (list): A list of strings, where each string is the
filepath to an image.
Returns:
str: The filepath of the temporary merged image file.
"""
if not image_paths:
print("Error: The list of image paths is empty.")
return None
# Open all images and check for consistency
try:
images = [Image.open(path) for path in image_paths]
except FileNotFoundError as e:
print(f"Error: Could not find image file: {e}")
return None
widths, heights = zip(*(img.size for img in images))
max_width = max(widths)
# All images must have the same width
if not all(width == max_width for width in widths):
print("Warning: Images have different widths. They will be resized.")
resized_images = []
for img in images:
if img.size[0] != max_width:
img = img.resize(
(max_width, int(img.size[1] * (max_width / img.size[0])))
)
resized_images.append(img)
images = resized_images
heights = [img.size[1] for img in images]
# Calculate the total height of the merged image
total_height = sum(heights)
# Create a new blank image with the combined dimensions
merged_image = Image.new("RGB", (max_width, total_height))
# Paste each image onto the new blank image
y_offset = 0
for img in images:
merged_image.paste(img, (0, y_offset))
y_offset += img.height
# Create a temporary file and save the image
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
temp_path = temp_file.name
merged_image.save(temp_path)
temp_file.close()
print(f"Successfully merged {len(images)} images into temporary file: {temp_path}")
return temp_path
OCR_PROMPT = """
You job is to extract text from the images I provide you. Extract every bit of the text in the image. Don't say anything just do your job. Text should be same as in the images. If there are multiple images, categorize the transcriptions by page.
Things to avoid:
- Don't miss anything to extract from the images
Things to include:
- Include everything, even anything inside [], (), {} or anything.
- Include any repetitive things like "..." or anything
- If you think there is any mistake in image just include it too
Someone will kill the innocent kittens if you don't extract the text exactly. So, make sure you extract every bit of the text. Only output the extracted text.
"""
def summarize_pdf_image(filepaths: list[str]):
res = ollama.chat(
model="gemma3:4b",
messages=[
{
"role": "user",
"content": OCR_PROMPT,
"images": filepaths,
}
],
)
return res["message"]["content"]
if __name__ == "__main__":
args = parser.parse_args()
ppngx = PaperlessNGXService()
if args.document_id:
doc_id = args.document_id
file = ppngx.get_doc_by_id(doc_id=doc_id)
pdf_path = ppngx.download_pdf_from_id(doc_id)
print(pdf_path)
image_paths = pdf_to_image(filepath=pdf_path)
summary = summarize_pdf_image(filepaths=image_paths)
print(summary)
file["content"] = summary
print(file)
ppngx.upload_cleaned_content(doc_id, file)
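
As wired up above, the script is meant to be run as "python cleaner.py <document_id>": it downloads the document's PDF from Paperless-NGX, renders each page to an image, asks the gemma3:4b model to transcribe them, and PUTs the transcription back as the document's content.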

main.py (82 lines changed)

@@ -1,5 +1,7 @@
import datetime
import logging
import os
from typing import Any, Union
import argparse
import chromadb
@@ -8,7 +10,8 @@ import ollama
from request import PaperlessNGXService
from chunker import Chunker
from query import QueryGenerator
from cleaner import pdf_to_image, summarize_pdf_image
from dotenv import load_dotenv
@@ -27,20 +30,66 @@ parser.add_argument(
"--reindex", action="store_true", help="re-index the simba documents"
)
ppngx = PaperlessNGXService()
def chunk_data(texts: list[str], collection):
def index_using_pdf_llm():
files = ppngx.get_data()
for file in files:
document_id = file["id"]
pdf_path = ppngx.download_pdf_from_id(id=document_id)
image_paths = pdf_to_image(filepath=pdf_path)
generated_summary = summarize_pdf_image(filepaths=image_paths)
file["content"] = generated_summary
chunk_data(files, simba_docs)
def date_to_epoch(date_str: str) -> float:
split_date = date_str.split("-")
print(split_date)
date = datetime.datetime(
int(split_date[0]),
int(split_date[1]),
int(split_date[2]),
0,
0,
0,
)
return date.timestamp()
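
A worked example of the helper (datetime.timestamp() is timezone-dependent, so the exact value varies by machine):

date_to_epoch("2025-01-01")  # -> 1735689600.0 on a machine running in UTC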
def chunk_data(docs: list[dict[str, Union[str, Any]]], collection):
# Step 2: Create chunks
chunker = Chunker(collection)
print(f"chunking {len(texts)} documents")
for text in texts:
chunker.chunk_document(document=text)
print(f"chunking {len(docs)} documents")
print(docs)
texts: list[str] = [doc["content"] for doc in docs]
for index, text in enumerate(texts):
metadata = {
"created_date": date_to_epoch(docs[index]["created_date"]),
}
chunker.chunk_document(
document=text,
metadata=metadata,
)
def consult_oracle(input: str, collection):
# Ask
qg = QueryGenerator()
metadata_filter = qg.get_query(input)
print(metadata_filter)
embeddings = Chunker.embedding_fx(input=[input])
results = collection.query(query_texts=[input], query_embeddings=embeddings)
results = collection.query(
query_texts=[input],
query_embeddings=embeddings,
where=metadata_filter,
)
print(results)
# Generate
output = ollama.generate(
@@ -55,24 +104,23 @@ def paperless_workflow(input):
# Step 1: Get the text
ppngx = PaperlessNGXService()
docs = ppngx.get_data()
texts = [doc["content"] for doc in docs]
chunk_data(texts, collection=simba_docs)
chunk_data(docs, collection=simba_docs)
consult_oracle(input, simba_docs)
if __name__ == "__main__":
args = parser.parse_args()
if args.reindex:
logging.info(msg="Fetching documents from Paperless-NGX")
ppngx = PaperlessNGXService()
docs = ppngx.get_data()
texts = [doc["content"] for doc in docs]
logging.info(msg=f"Fetched {len(texts)} documents")
logging.info(msg="Chunking documents now ...")
chunk_data(texts, collection=simba_docs)
logging.info(msg="Done chunking documents")
# logging.info(msg="Fetching documents from Paperless-NGX")
# ppngx = PaperlessNGXService()
# docs = ppngx.get_data()
# logging.info(msg=f"Fetched {len(docs)} documents")
#
# logging.info(msg="Chunking documents now ...")
# chunk_data(docs, collection=simba_docs)
# logging.info(msg="Done chunking documents")
index_using_pdf_llm()
if args.query:
logging.info("Consulting oracle ...")

query.py

@@ -1,6 +1,6 @@
import json
from typing import Literal
import datetime
from ollama import chat, ChatResponse
from pydantic import BaseModel, Field
@@ -29,9 +29,9 @@ class GeneratedQuery(BaseModel):
PROMPT = """
You are an information specialist that processes user queries. The user queries are all about
You are an information specialist that processes user queries. The current year is 2025. The user queries are all about
a cat, Simba, and its records. The types of records are listed below. Using the query, extract the
type of record the user is trying to query and the date range the user is trying to query.
date range the user is trying to query. You should return it as JSON. The date tag is created_date. Return the date in epoch time.
You have several operators at your disposal:
@@ -49,18 +49,18 @@ Logical operators:
### Example 1
Query: "Who is Simba's current vet?"
Metadata fields: "{"created_date, tags"}"
Extracted metadata fields: {"$and": [{"created_date: {"$gt": "2025-01-01"}, "tags": {"$in": ["bill", "medical records", "aftercare"]}}]}
Metadata fields: "{"created_date"}"
Extracted metadata fields: {"created_date: {"$gt": "2025-01-01"}}
### Example 2
Query: "How many teeth has Simba had removed?"
Metadata fields: {"tags"}
Extracted metadata fields: {"tags": "medical records"}
Metadata fields: {}
Extracted metadata fields: {}
### Example 3
Query: "How many times has Simba been to the vet this year?"
Metadata fields: {"tags", "created_date"}
Extracted metadata fields: {"$and": [{"created_date": {"gt": "2025-01-01"}, "tags": {"$in": ["bill"]}}]}
Metadata fields: {"created_date"}
Extracted metadata fields: {"created_date": {"gt": "2025-01-01"}}
document_types:
- aftercare
@@ -76,6 +76,19 @@ class QueryGenerator:
def __init__(self) -> None:
pass
def date_to_epoch(self, date_str: str) -> float:
split_date = date_str.split("-")
date = datetime.datetime(
int(split_date[0]),
int(split_date[1]),
int(split_date[2]),
0,
0,
0,
)
return date.timestamp()
def get_query(self, input: str):
response: ChatResponse = chat(
model="gemma3n:e4b",
@@ -86,13 +99,20 @@ class QueryGenerator:
format=GeneratedQuery.model_json_schema(),
)
print(
json.loads(
json.loads(response["message"]["content"])["extracted_metadata_fields"]
)
query = json.loads(
    json.loads(response["message"]["content"])["extracted_metadata_fields"]
)
# Guard against queries with no date filter (Example 2 returns {})
if query.get("created_date"):
    date_key = list(query["created_date"].keys())[0]
    query["created_date"][date_key] = self.date_to_epoch(
        query["created_date"][date_key]
    )
    # Chroma operators need a "$" prefix; normalize a bare key like "gt"
    if "$" not in date_key:
        query["created_date"]["$" + date_key] = query["created_date"].pop(date_key)
return query
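
A hypothetical round trip, assuming the model answers in the Example 3 shape; the ISO date becomes epoch seconds and a bare operator key gains its "$" prefix:

qg = QueryGenerator()
qg.get_query("How many times has Simba been to the vet this year?")
# e.g. {"created_date": {"$gt": 1735689600.0}}  (exact value depends on the machine's timezone)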
if __name__ == "__main__":
qg = QueryGenerator()
qg.get_query("How old is Simba?")
print(qg.get_query("How heavy is Simba?"))

request.py

@@ -1,4 +1,5 @@
import os
import tempfile
import httpx
from dotenv import load_dotenv
@@ -18,6 +19,30 @@ class PaperlessNGXService:
r = httpx.get(self.url, headers=self.headers)
return r.json()["results"]
def get_doc_by_id(self, doc_id: int):
url = f"http://{os.getenv("BASE_URL")}/api/documents/{doc_id}/"
r = httpx.get(url, headers=self.headers)
return r.json()
def download_pdf_from_id(self, id: int) -> str:
download_url = f"http://{os.getenv("BASE_URL")}/api/documents/{id}/download/"
response = httpx.get(
download_url, headers=self.headers, follow_redirects=True, timeout=30
)
response.raise_for_status()
# Use a temporary file for the downloaded PDF
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
temp_file.write(response.content)
temp_file.close()
temp_pdf_path = temp_file.name
pdf_to_process = temp_pdf_path
return pdf_to_process
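
A minimal sketch of how these helpers fit together, assuming a reachable Paperless-NGX instance and a valid token in the environment (the document id is illustrative):

ppngx = PaperlessNGXService()
doc = ppngx.get_doc_by_id(doc_id=42)
pdf_path = ppngx.download_pdf_from_id(42)  # temp .pdf left on disk; the caller cleans it up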
def upload_cleaned_content(self, document_id, data):
PUTS_URL = f"http://{os.getenv("BASE_URL")}/api/documents/{document_id}/"
r = httpx.put(PUTS_URL, headers=self.headers, data=data)
r.raise_for_status()
if __name__ == "__main__":
pp = PaperlessNGXService()