2025-10-05 20:31:46 -04:00
parent 0bb3e3172b
commit 910097d13b
7 changed files with 146 additions and 98 deletions


@@ -3,15 +3,19 @@ from math import ceil
import re
from typing import Union
from uuid import UUID, uuid4
from ollama import Client
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from dotenv import load_dotenv
load_dotenv()
USE_OPENAI = os.getenv("OPENAI_API_KEY") is not None
ollama_client = Client(host=os.getenv("OLLAMA_HOST", "http://localhost:11434"))
def remove_headers_footers(text, header_patterns=None, footer_patterns=None):
if header_patterns is None:
@@ -88,6 +92,17 @@ class Chunker:
def __init__(self, collection) -> None:
self.collection = collection
def embedding_fx(self, inputs):
if USE_OPENAI:
openai_embedding_fx = OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"),
model_name="text-embedding-3-small",
)
return openai_embedding_fx(inputs)
else:
response = ollama_client.embed(model="mxbai-embed-large", input=inputs)
return response["embeddings"]
def chunk_document(
self,
document: str,

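Note: the hunk above gives Chunker a dual-backend embedding_fx, routing to OpenAI's text-embedding-3-small when OPENAI_API_KEY is set and falling back to Ollama's mxbai-embed-large otherwise. A minimal usage sketch, assuming the module is importable as `chunker` (the module name is not shown in this diff):

import chromadb

from chunker import Chunker  # module name assumed, not confirmed by the diff

client = chromadb.PersistentClient(path="./chroma")
collection = client.get_or_create_collection(name="simba_docs")

chunker = Chunker(collection)
# Returns OpenAI embeddings when OPENAI_API_KEY is set, Ollama otherwise.
vectors = chunker.embedding_fx(inputs=["Simba's vaccination record, 2024"])
print(len(vectors[0]))  # dimensionality differs between the two backends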

@@ -1,4 +1,4 @@
version: '3.8'
version: "3.8"
services:
raggr:

image_process.py

@@ -25,16 +25,18 @@ parser.add_argument("filepath")
client = Client(host=os.getenv("OLLAMA_HOST", "http://localhost:11434"))
class SimbaImageDescription(BaseModel):
image_date: str
description: str
def describe_simba_image(input):
logging.info("Opening image of Simba ...")
if "heic" in input.lower() or "heif" in input.lower():
new_filepath = input.split(".")[0] + ".jpg"
img = Image.open(input)
img.save(new_filepath, 'JPEG')
img.save(new_filepath, "JPEG")
logging.info("Extracting EXIF...")
exif = {
ExifTags.TAGS[k]: v for k, v in img.getexif().items() if k in ExifTags.TAGS
@@ -66,7 +68,7 @@ def describe_simba_image(input):
},
{"role": "user", "content": prompt, "images": [input]},
],
format=SimbaImageDescription.model_json_schema()
format=SimbaImageDescription.model_json_schema(),
)
result = SimbaImageDescription.model_validate_json(response["message"]["content"])

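Note: the formatting fix above touches the schema-constrained chat call; ollama's format= parameter takes the Pydantic model's JSON schema, so the reply can be validated straight into SimbaImageDescription. A stripped-down sketch of the pattern; the model name and image path are placeholders, not values from this diff:

from ollama import Client
from pydantic import BaseModel

class SimbaImageDescription(BaseModel):
    image_date: str
    description: str

client = Client(host="http://localhost:11434")
response = client.chat(
    model="llava",  # placeholder vision model; the script's model is outside this hunk
    messages=[
        {"role": "user", "content": "Describe this photo of Simba.", "images": ["simba.jpg"]},
    ],
    format=SimbaImageDescription.model_json_schema(),  # constrains output to the schema
)
result = SimbaImageDescription.model_validate_json(response["message"]["content"])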

@@ -6,6 +6,7 @@ import tempfile
from image_process import describe_simba_image
from request import PaperlessNGXService
import sqlite3
logging.basicConfig(level=logging.INFO)
@@ -23,8 +24,16 @@ DOWNLOAD_DIR = os.getenv("DOWNLOAD_DIR", "./simba_photos")
# Set up headers
headers = {"x-api-key": API_KEY, "Content-Type": "application/json"}
VISITED = set()
if __name__ == "__main__":
conn = sqlite3.connect("./visited.db")
c = conn.cursor()
c.execute("select immich_id from visited")
rows = c.fetchall()
for row in rows:
VISITED.add(row[0])
ppngx = PaperlessNGXService()
people_url = f"{IMMICH_URL}/api/search/person?name=Simba"
people = httpx.get(people_url, headers=headers).json()
@@ -39,7 +48,7 @@ if __name__ == "__main__":
assets = results.json()["assets"]
for asset in assets["items"]:
if asset["type"] == "IMAGE":
if asset["type"] == "IMAGE" and asset["id"] not in VISITED:
ids[asset["id"]] = asset.get("originalFileName")
nextPage = assets.get("nextPage")
@@ -58,13 +67,12 @@ if __name__ == "__main__":
asset_search = f"{IMMICH_URL}/api/search/smart"
request_body = {"query": "simba"}
results = httpx.post(asset_search, headers=headers, json=request_body)
print(results.json()["assets"]["total"])
for asset in results.json()["assets"]["items"]:
if asset["type"] == "IMAGE":
ids[asset["id"]] = asset.get("originalFileName")
immich_asset_id = list(ids.keys())[1]
immich_filename = ids.get(immich_asset_id)
for immich_asset_id, immich_filename in ids.items():
try:
response = httpx.get(
f"{IMMICH_URL}/api/assets/{immich_asset_id}/original", headers=headers
)
@@ -80,19 +88,28 @@ if __name__ == "__main__":
image_description = description.description
image_date = description.image_date
description_filepath = os.path.join("/Users/ryanchen/Programs/raggr", f"SIMBA_DESCRIBE_001.txt")
description_filepath = os.path.join(
"/Users/ryanchen/Programs/raggr", f"SIMBA_DESCRIBE_001.txt"
)
file = open(description_filepath, "w+")
file.write(image_description)
file.close()
file = open(description_filepath, 'rb')
ppngx.upload_description(description_filepath=description_filepath, file=file, title="SIMBA_DESCRIBE_001.txt", exif_date=image_date)
file = open(description_filepath, "rb")
ppngx.upload_description(
description_filepath=description_filepath,
file=file,
title="SIMBA_DESCRIBE_001.txt",
exif_date=image_date,
)
file.close()
c.execute("INSERT INTO visited (immich_id) values (?)", (immich_asset_id,))
conn.commit()
logging.info("Processing complete. Deleting file.")
os.remove(file.name)
except Exception as e:
logging.info(f"something went wrong for {immich_filename}")
logging.info(e)
conn.close()

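Note: the new sqlite bookkeeping above assumes a visited table already exists; only the SELECT and INSERT appear in the diff. A one-time setup sketch, with the single-column schema inferred from the INSERT statement:

import sqlite3

# Schema assumed from: INSERT INTO visited (immich_id) values (?)
conn = sqlite3.connect("./visited.db")
conn.execute("CREATE TABLE IF NOT EXISTS visited (immich_id TEXT PRIMARY KEY)")
conn.commit()
conn.close()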
main.py

@@ -18,11 +18,13 @@ from dotenv import load_dotenv
load_dotenv()
USE_OPENAI = os.getenv("OPENAI_API_KEY") is not None
# Configure ollama client with URL from environment or default to localhost
ollama_client = ollama.Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434"))
client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
simba_docs = client.get_or_create_collection(name="simba_docs")
simba_docs = client.get_or_create_collection(name="simba_docs2")
feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup")
parser = argparse.ArgumentParser(
@@ -55,7 +57,6 @@ def index_using_pdf_llm():
def date_to_epoch(date_str: str) -> float:
split_date = date_str.split("-")
print(split_date)
date = datetime.datetime(
int(split_date[0]),
int(split_date[1]),
@@ -73,10 +74,8 @@ def chunk_data(docs: list[dict[str, Union[str, Any]]], collection):
chunker = Chunker(collection)
print(f"chunking {len(docs)} documents")
print(docs)
texts: list[str] = [doc["content"] for doc in docs]
for index, text in enumerate(texts):
print(docs[index]["original_file_name"])
metadata = {
"created_date": date_to_epoch(docs[index]["created_date"]),
"filename": docs[index]["original_file_name"],
@@ -101,6 +100,7 @@ def chunk_text(texts: list[str], collection):
def consult_oracle(input: str, collection):
print(input)
import time
chunker = Chunker(collection)
start_time = time.time()
@@ -115,7 +115,7 @@ def consult_oracle(input: str, collection):
print("Starting embedding generation")
embedding_start = time.time()
embeddings = Chunker.embedding_fx(input=[input])
embeddings = chunker.embedding_fx(inputs=[input])
embedding_end = time.time()
print(f"Embedding generation took {embedding_end - embedding_start:.2f} seconds")
@@ -126,17 +126,13 @@ def consult_oracle(input: str, collection):
query_embeddings=embeddings,
# where=metadata_filter,
)
print(results)
query_end = time.time()
print(f"Collection query took {query_end - query_start:.2f} seconds")
# Generate
print("Starting LLM generation")
llm_start = time.time()
# output = ollama_client.generate(
# model="gemma3n:e4b",
# prompt=f"You are a helpful assistant that understandings veterinary terms. Using the following data, help answer the user's query by providing as many details as possible. Using this data: {results}. Respond to this prompt: {input}",
# )
if USE_OPENAI:
response = openai_client.chat.completions.create(
model="gpt-4o-mini",
messages=[
@@ -150,13 +146,20 @@ def consult_oracle(input: str, collection):
},
],
)
output = response.choices[0].message.content
else:
response = ollama_client.generate(
model="gemma3:4b",
prompt=f"You are a helpful assistant that understandings veterinary terms. Using the following data, help answer the user's query by providing as many details as possible. Using this data: {results}. Respond to this prompt: {input}",
)
output = response["response"]
llm_end = time.time()
print(f"LLM generation took {llm_end - llm_start:.2f} seconds")
total_time = time.time() - start_time
print(f"Total consult_oracle execution took {total_time:.2f} seconds")
return response.choices[0].message.content
return output
def paperless_workflow(input):
@@ -181,7 +184,6 @@ if __name__ == "__main__":
print("Fetching documents from Paperless-NGX")
ppngx = PaperlessNGXService()
docs = ppngx.get_data()
print(docs)
print(f"Fetched {len(docs)} documents")
#
print("Chunking documents now ...")
@@ -192,7 +194,6 @@ if __name__ == "__main__":
if args.index:
with open(args.index) as file:
extension = args.index.split(".")[-1]
if extension == "pdf":
pdf_path = ppngx.download_pdf_from_id(id=document_id)
image_paths = pdf_to_image(filepath=pdf_path)

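Note: consult_oracle now embeds the query with the same backend used at index time, runs a chromadb similarity search, and hands the hits to whichever LLM is configured, timing each stage. A condensed sketch of the flow, assuming chunker, collection, and ollama_client are wired up as in this diff:

import time

query = "When was Simba's last rabies shot?"

start = time.time()
embeddings = chunker.embedding_fx(inputs=[query])        # OpenAI or Ollama
results = collection.query(query_embeddings=embeddings)  # chromadb similarity search
output = ollama_client.generate(
    model="gemma3:4b",
    prompt=f"Using this data: {results}. Respond to this prompt: {query}",
)["response"]
print(f"consult_oracle took {time.time() - start:.2f} seconds")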

@@ -45,7 +45,6 @@ the date range the user is trying to query. You should return it as a JSON. The
If the created_date cannot be ascertained, set it to epoch time start.
You have several operators at your disposal:
- $gt: greater than
- $gte: greater than or equal
@@ -83,6 +82,8 @@ document_types:
Only return the extracted metadata fields. Make sure the extracted metadata fields are valid JSON
"""
USE_OPENAI = os.getenv("OPENAI_API_KEY") is not None
class QueryGenerator:
def __init__(self) -> None:
@@ -102,31 +103,31 @@ class QueryGenerator:
return date.timestamp()
def get_query(self, input: str):
if USE_OPENAI:
client = OpenAI()
print(input)
response = client.responses.parse(
model="gpt-4o",
input=[
{"role": "system", "content": PROMPT},
{"role": "user", "content": input},
],
text_format=Time,
text_format=GeneratedQuery,
)
print(response)
print(response.output)
query = json.loads(response.output_parsed.extracted_metadata_fields)
else:
response: ChatResponse = ollama_client.chat(
model="gemma3n:e4b",
messages=[
{"role": "system", "content": PROMPT},
{"role": "user", "content": input},
],
format=GeneratedQuery.model_json_schema(),
)
# response: ChatResponse = ollama_client.chat(
# model="gemma3n:e4b",
# messages=[
# {"role": "system", "content": PROMPT},
# {"role": "user", "content": input},
# ],
# format=GeneratedQuery.model_json_schema(),
# )
# query = json.loads(
# json.loads(response["message"]["content"])["extracted_metadata_fields"]
# )
query = json.loads(
json.loads(response["message"]["content"])["extracted_metadata_fields"]
)
date_key = list(query["created_date"].keys())[0]
query["created_date"][date_key] = self.date_to_epoch(
query["created_date"][date_key]

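Note: get_query decodes the Ollama reply twice on purpose; the chat message content is JSON shaped like GeneratedQuery, and its extracted_metadata_fields value is itself a JSON string holding the actual filter. A small sketch of the double decode with an illustrative payload:

import json

# Illustrative payload shaped like GeneratedQuery; not taken from the diff.
message_content = '{"extracted_metadata_fields": "{\\"created_date\\": {\\"$gte\\": \\"2024-01-01\\"}}"}'

outer = json.loads(message_content)                     # GeneratedQuery-shaped dict
query = json.loads(outer["extracted_metadata_fields"])  # the actual filter dict
assert query["created_date"]["$gte"] == "2024-01-01"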
request.py

@@ -1,11 +1,14 @@
import os
import tempfile
import httpx
import logging
from dotenv import load_dotenv
load_dotenv()
logging.basicConfig(level=logging.INFO)
class PaperlessNGXService:
def __init__(self):
@@ -17,7 +20,16 @@ class PaperlessNGXService:
def get_data(self):
print(f"Getting data from: {self.url}")
r = httpx.get(self.url, headers=self.headers)
return r.json()["results"]
results = r.json()["results"]
nextLink = r.json().get("next")
while nextLink:
r = httpx.get(nextLink, headers=self.headers)
results += r.json()["results"]
nextLink = r.json().get("next")
return results
def get_doc_by_id(self, doc_id: int):
url = f"http://{os.getenv('BASE_URL')}/api/documents/{doc_id}/"
@@ -45,12 +57,12 @@ class PaperlessNGXService:
def upload_description(self, description_filepath, file, title, exif_date: str):
POST_URL = f"http://{os.getenv('BASE_URL')}/api/documents/post_document/"
files = {'document': ('description_filepath', file, 'application/txt')}
files = {"document": ("description_filepath", file, "application/txt")}
data = {
"title": title,
"create": exif_date,
"document_type": 3
"tags": [7]
"document_type": 3,
"tags": [7],
}
r = httpx.post(POST_URL, headers=self.headers, data=data, files=files)
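Note: get_data now walks Paperless-NGX's paginated listing instead of returning only the first page. The same loop in isolation, a sketch assuming the DRF-style {"results": [...], "next": url-or-null} envelope the diff relies on:

import httpx

def fetch_all(url: str, headers: dict) -> list:
    """Follow "next" links until the API reports no further pages."""
    results = []
    while url:
        page = httpx.get(url, headers=headers).json()
        results += page["results"]
        url = page.get("next")  # None on the last page ends the loop
    return results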