This commit is contained in:
2025-10-05 20:31:46 -04:00
parent 0bb3e3172b
commit 910097d13b
7 changed files with 146 additions and 98 deletions

View File

@@ -3,15 +3,19 @@ from math import ceil
import re import re
from typing import Union from typing import Union
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from ollama import Client
from chromadb.utils.embedding_functions.openai_embedding_function import ( from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction, OpenAIEmbeddingFunction,
) )
from dotenv import load_dotenv from dotenv import load_dotenv
USE_OPENAI = os.getenv("OPENAI_API_KEY") != None
load_dotenv() load_dotenv()
ollama_client = Client(host=os.getenv("OLLAMA_HOST", "http://localhost:11434"))
def remove_headers_footers(text, header_patterns=None, footer_patterns=None): def remove_headers_footers(text, header_patterns=None, footer_patterns=None):
if header_patterns is None: if header_patterns is None:
@@ -88,6 +92,17 @@ class Chunker:
def __init__(self, collection) -> None: def __init__(self, collection) -> None:
self.collection = collection self.collection = collection
def embedding_fx(self, inputs):
if USE_OPENAI:
openai_embedding_fx = OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"),
model_name="text-embedding-3-small",
)
return openai_embedding_fx(inputs)
else:
response = ollama_client.embed(model="mxbai-embed-large", input=inputs[0])
return response["embeddings"]
def chunk_document( def chunk_document(
self, self,
document: str, document: str,

View File

@@ -1,4 +1,4 @@
version: '3.8' version: "3.8"
services: services:
raggr: raggr:

View File

@@ -25,22 +25,24 @@ parser.add_argument("filepath")
client = Client(host=os.getenv("OLLAMA_HOST", "http://localhost:11434")) client = Client(host=os.getenv("OLLAMA_HOST", "http://localhost:11434"))
class SimbaImageDescription(BaseModel): class SimbaImageDescription(BaseModel):
image_date: str image_date: str
description: str description: str
def describe_simba_image(input): def describe_simba_image(input):
logging.info("Opening image of Simba ...") logging.info("Opening image of Simba ...")
if "heic" in input.lower() or "heif" in input.lower(): if "heic" in input.lower() or "heif" in input.lower():
new_filepath = input.split(".")[0] + ".jpg" new_filepath = input.split(".")[0] + ".jpg"
img = Image.open(input) img = Image.open(input)
img.save(new_filepath, 'JPEG') img.save(new_filepath, "JPEG")
logging.info("Extracting EXIF...") logging.info("Extracting EXIF...")
exif = { exif = {
ExifTags.TAGS[k]: v for k, v in img.getexif().items() if k in ExifTags.TAGS ExifTags.TAGS[k]: v for k, v in img.getexif().items() if k in ExifTags.TAGS
} }
img = Image.open(new_filepath) img = Image.open(new_filepath)
input=new_filepath input = new_filepath
else: else:
img = Image.open(input) img = Image.open(input)
@@ -66,7 +68,7 @@ def describe_simba_image(input):
}, },
{"role": "user", "content": prompt, "images": [input]}, {"role": "user", "content": prompt, "images": [input]},
], ],
format=SimbaImageDescription.model_json_schema() format=SimbaImageDescription.model_json_schema(),
) )
result = SimbaImageDescription.model_validate_json(response["message"]["content"]) result = SimbaImageDescription.model_validate_json(response["message"]["content"])

View File

@@ -6,6 +6,7 @@ import tempfile
from image_process import describe_simba_image from image_process import describe_simba_image
from request import PaperlessNGXService from request import PaperlessNGXService
import sqlite3
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@@ -23,8 +24,16 @@ DOWNLOAD_DIR = os.getenv("DOWNLOAD_DIR", "./simba_photos")
# Set up headers # Set up headers
headers = {"x-api-key": API_KEY, "Content-Type": "application/json"} headers = {"x-api-key": API_KEY, "Content-Type": "application/json"}
VISITED = {}
if __name__ == "__main__": if __name__ == "__main__":
conn = sqlite3.connect("./visited.db")
c = conn.cursor()
c.execute("select immich_id from visited")
rows = c.fetchall()
for row in rows:
VISITED.add(row[0])
ppngx = PaperlessNGXService() ppngx = PaperlessNGXService()
people_url = f"{IMMICH_URL}/api/search/person?name=Simba" people_url = f"{IMMICH_URL}/api/search/person?name=Simba"
people = httpx.get(people_url, headers=headers).json() people = httpx.get(people_url, headers=headers).json()
@@ -39,7 +48,7 @@ if __name__ == "__main__":
assets = results.json()["assets"] assets = results.json()["assets"]
for asset in assets["items"]: for asset in assets["items"]:
if asset["type"] == "IMAGE": if asset["type"] == "IMAGE" and asset["id"] not in VISITED:
ids[asset["id"]] = asset.get("originalFileName") ids[asset["id"]] = asset.get("originalFileName")
nextPage = assets.get("nextPage") nextPage = assets.get("nextPage")
@@ -58,13 +67,12 @@ if __name__ == "__main__":
asset_search = f"{IMMICH_URL}/api/search/smart" asset_search = f"{IMMICH_URL}/api/search/smart"
request_body = {"query": "simba"} request_body = {"query": "simba"}
results = httpx.post(asset_search, headers=headers, json=request_body) results = httpx.post(asset_search, headers=headers, json=request_body)
print(results.json()["assets"]["total"])
for asset in results.json()["assets"]["items"]: for asset in results.json()["assets"]["items"]:
if asset["type"] == "IMAGE": if asset["type"] == "IMAGE":
ids[asset["id"]] = asset.get("originalFileName") ids[asset["id"]] = asset.get("originalFileName")
immich_asset_id = list(ids.keys())[1] for immich_asset_id, immich_filename in ids.items():
immich_filename = ids.get(immich_asset_id) try:
response = httpx.get( response = httpx.get(
f"{IMMICH_URL}/api/assets/{immich_asset_id}/original", headers=headers f"{IMMICH_URL}/api/assets/{immich_asset_id}/original", headers=headers
) )
@@ -80,19 +88,28 @@ if __name__ == "__main__":
image_description = description.description image_description = description.description
image_date = description.image_date image_date = description.image_date
description_filepath = os.path.join("/Users/ryanchen/Programs/raggr", f"SIMBA_DESCRIBE_001.txt") description_filepath = os.path.join(
"/Users/ryanchen/Programs/raggr", f"SIMBA_DESCRIBE_001.txt"
)
file = open(description_filepath, "w+") file = open(description_filepath, "w+")
file.write(image_description) file.write(image_description)
file.close() file.close()
file = open(description_filepath, 'rb') file = open(description_filepath, "rb")
ppngx.upload_description(
ppngx.upload_description(description_filepath=description_filepath, file=file, title="SIMBA_DESCRIBE_001.txt", exif_date=image_date) description_filepath=description_filepath,
file=file,
title="SIMBA_DESCRIBE_001.txt",
exif_date=image_date,
)
file.close() file.close()
c.execute("INSERT INTO visited (immich_id) values (?)", (immich_asset_id,))
conn.commit()
logging.info("Processing complete. Deleting file.") logging.info("Processing complete. Deleting file.")
os.remove(file.name) os.remove(file.name)
except Exception as e:
logging.info(f"something went wrong for {immich_filename}")
logging.info(e)
conn.close()

27
main.py
View File

@@ -18,11 +18,13 @@ from dotenv import load_dotenv
load_dotenv() load_dotenv()
USE_OPENAI = os.getenv("OPENAI_API_KEY") != None
# Configure ollama client with URL from environment or default to localhost # Configure ollama client with URL from environment or default to localhost
ollama_client = ollama.Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434")) ollama_client = ollama.Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434"))
client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", "")) client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
simba_docs = client.get_or_create_collection(name="simba_docs") simba_docs = client.get_or_create_collection(name="simba_docs2")
feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup") feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup")
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@@ -55,7 +57,6 @@ def index_using_pdf_llm():
def date_to_epoch(date_str: str) -> float: def date_to_epoch(date_str: str) -> float:
split_date = date_str.split("-") split_date = date_str.split("-")
print(split_date)
date = datetime.datetime( date = datetime.datetime(
int(split_date[0]), int(split_date[0]),
int(split_date[1]), int(split_date[1]),
@@ -73,10 +74,8 @@ def chunk_data(docs: list[dict[str, Union[str, Any]]], collection):
chunker = Chunker(collection) chunker = Chunker(collection)
print(f"chunking {len(docs)} documents") print(f"chunking {len(docs)} documents")
print(docs)
texts: list[str] = [doc["content"] for doc in docs] texts: list[str] = [doc["content"] for doc in docs]
for index, text in enumerate(texts): for index, text in enumerate(texts):
print(docs[index]["original_file_name"])
metadata = { metadata = {
"created_date": date_to_epoch(docs[index]["created_date"]), "created_date": date_to_epoch(docs[index]["created_date"]),
"filename": docs[index]["original_file_name"], "filename": docs[index]["original_file_name"],
@@ -101,6 +100,7 @@ def chunk_text(texts: list[str], collection):
def consult_oracle(input: str, collection): def consult_oracle(input: str, collection):
print(input) print(input)
import time import time
chunker = Chunker(collection)
start_time = time.time() start_time = time.time()
@@ -115,7 +115,7 @@ def consult_oracle(input: str, collection):
print("Starting embedding generation") print("Starting embedding generation")
embedding_start = time.time() embedding_start = time.time()
embeddings = Chunker.embedding_fx(input=[input]) embeddings = chunker.embedding_fx(inputs=[input])
embedding_end = time.time() embedding_end = time.time()
print(f"Embedding generation took {embedding_end - embedding_start:.2f} seconds") print(f"Embedding generation took {embedding_end - embedding_start:.2f} seconds")
@@ -126,17 +126,13 @@ def consult_oracle(input: str, collection):
query_embeddings=embeddings, query_embeddings=embeddings,
# where=metadata_filter, # where=metadata_filter,
) )
print(results)
query_end = time.time() query_end = time.time()
print(f"Collection query took {query_end - query_start:.2f} seconds") print(f"Collection query took {query_end - query_start:.2f} seconds")
# Generate # Generate
print("Starting LLM generation") print("Starting LLM generation")
llm_start = time.time() llm_start = time.time()
# output = ollama_client.generate( if USE_OPENAI:
# model="gemma3n:e4b",
# prompt=f"You are a helpful assistant that understandings veterinary terms. Using the following data, help answer the user's query by providing as many details as possible. Using this data: {results}. Respond to this prompt: {input}",
# )
response = openai_client.chat.completions.create( response = openai_client.chat.completions.create(
model="gpt-4o-mini", model="gpt-4o-mini",
messages=[ messages=[
@@ -150,13 +146,20 @@ def consult_oracle(input: str, collection):
}, },
], ],
) )
output= response.choices[0].message.content
else:
response = ollama_client.generate(
model="gemma3:4b",
prompt=f"You are a helpful assistant that understandings veterinary terms. Using the following data, help answer the user's query by providing as many details as possible. Using this data: {results}. Respond to this prompt: {input}",
)
output = response["response"]
llm_end = time.time() llm_end = time.time()
print(f"LLM generation took {llm_end - llm_start:.2f} seconds") print(f"LLM generation took {llm_end - llm_start:.2f} seconds")
total_time = time.time() - start_time total_time = time.time() - start_time
print(f"Total consult_oracle execution took {total_time:.2f} seconds") print(f"Total consult_oracle execution took {total_time:.2f} seconds")
return response.choices[0].message.content return output
def paperless_workflow(input): def paperless_workflow(input):
@@ -181,7 +184,6 @@ if __name__ == "__main__":
print("Fetching documents from Paperless-NGX") print("Fetching documents from Paperless-NGX")
ppngx = PaperlessNGXService() ppngx = PaperlessNGXService()
docs = ppngx.get_data() docs = ppngx.get_data()
print(docs)
print(f"Fetched {len(docs)} documents") print(f"Fetched {len(docs)} documents")
# #
print("Chunking documents now ...") print("Chunking documents now ...")
@@ -192,7 +194,6 @@ if __name__ == "__main__":
if args.index: if args.index:
with open(args.index) as file: with open(args.index) as file:
extension = args.index.split(".")[-1] extension = args.index.split(".")[-1]
if extension == "pdf": if extension == "pdf":
pdf_path = ppngx.download_pdf_from_id(id=document_id) pdf_path = ppngx.download_pdf_from_id(id=document_id)
image_paths = pdf_to_image(filepath=pdf_path) image_paths = pdf_to_image(filepath=pdf_path)

View File

@@ -45,7 +45,6 @@ the date range the user is trying to query. You should return it as a JSON. The
If the created_date cannot be ascertained, set it to epoch time start. If the created_date cannot be ascertained, set it to epoch time start.
You have several operators at your disposal: You have several operators at your disposal:
- $gt: greater than - $gt: greater than
- $gte: greater than or equal - $gte: greater than or equal
@@ -83,6 +82,8 @@ document_types:
Only return the extracted metadata fields. Make sure the extracted metadata fields are valid JSON Only return the extracted metadata fields. Make sure the extracted metadata fields are valid JSON
""" """
USE_OPENAI = os.getenv("OPENAI_API_KEY", None) != None
class QueryGenerator: class QueryGenerator:
def __init__(self) -> None: def __init__(self) -> None:
@@ -102,31 +103,31 @@ class QueryGenerator:
return date.timestamp() return date.timestamp()
def get_query(self, input: str): def get_query(self, input: str):
if USE_OPENAI:
client = OpenAI() client = OpenAI()
print(input)
response = client.responses.parse( response = client.responses.parse(
model="gpt-4o", model="gpt-4o",
input=[ input=[
{"role": "system", "content": PROMPT}, {"role": "system", "content": PROMPT},
{"role": "user", "content": input}, {"role": "user", "content": input},
], ],
text_format=Time, text_format=GeneratedQuery,
) )
print(response) print(response.output)
query = json.loads(response.output_parsed.extracted_metadata_fields) query = json.loads(response.output_parsed.extracted_metadata_fields)
else:
response: ChatResponse = ollama_client.chat(
model="gemma3n:e4b",
messages=[
{"role": "system", "content": PROMPT},
{"role": "user", "content": input},
],
format=GeneratedQuery.model_json_schema(),
)
# response: ChatResponse = ollama_client.chat( query = json.loads(
# model="gemma3n:e4b", json.loads(response["message"]["content"])["extracted_metadata_fields"]
# messages=[ )
# {"role": "system", "content": PROMPT},
# {"role": "user", "content": input},
# ],
# format=GeneratedQuery.model_json_schema(),
# )
# query = json.loads(
# json.loads(response["message"]["content"])["extracted_metadata_fields"]
# )
date_key = list(query["created_date"].keys())[0] date_key = list(query["created_date"].keys())[0]
query["created_date"][date_key] = self.date_to_epoch( query["created_date"][date_key] = self.date_to_epoch(
query["created_date"][date_key] query["created_date"][date_key]

View File

@@ -1,11 +1,14 @@
import os import os
import tempfile import tempfile
import httpx import httpx
import logging
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
logging.basicConfig(level=logging.INFO)
class PaperlessNGXService: class PaperlessNGXService:
def __init__(self): def __init__(self):
@@ -17,7 +20,16 @@ class PaperlessNGXService:
def get_data(self): def get_data(self):
print(f"Getting data from: {self.url}") print(f"Getting data from: {self.url}")
r = httpx.get(self.url, headers=self.headers) r = httpx.get(self.url, headers=self.headers)
return r.json()["results"] results = r.json()["results"]
nextLink = r.json().get("next")
while nextLink:
r = httpx.get(nextLink, headers=self.headers)
results += r.json()["results"]
nextLink = r.json().get("next")
return results
def get_doc_by_id(self, doc_id: int): def get_doc_by_id(self, doc_id: int):
url = f"http://{os.getenv('BASE_URL')}/api/documents/{doc_id}/" url = f"http://{os.getenv('BASE_URL')}/api/documents/{doc_id}/"
@@ -45,15 +57,15 @@ class PaperlessNGXService:
def upload_description(self, description_filepath, file, title, exif_date: str): def upload_description(self, description_filepath, file, title, exif_date: str):
POST_URL = f"http://{os.getenv('BASE_URL')}/api/documents/post_document/" POST_URL = f"http://{os.getenv('BASE_URL')}/api/documents/post_document/"
files = {'document': ('description_filepath', file, 'application/txt')} files = {"document": ("description_filepath", file, "application/txt")}
data = { data = {
"title": title, "title": title,
"create": exif_date, "create": exif_date,
"document_type": 3 "document_type": 3,
"tags": [7] "tags": [7],
} }
r= httpx.post(POST_URL, headers=self.headers, data=data, files=files) r = httpx.post(POST_URL, headers=self.headers, data=data, files=files)
r.raise_for_status() r.raise_for_status()