diff --git a/chunker.py b/chunker.py
index 0ec6b21..589e348 100644
--- a/chunker.py
+++ b/chunker.py
@@ -3,15 +3,19 @@ from math import ceil
 import re
 from typing import Union
 from uuid import UUID, uuid4
-
+from ollama import Client
 from chromadb.utils.embedding_functions.openai_embedding_function import (
     OpenAIEmbeddingFunction,
 )
 from dotenv import load_dotenv
 
 load_dotenv()
 
+USE_OPENAI = os.getenv("OPENAI_API_KEY") is not None
+
+ollama_client = Client(host=os.getenv("OLLAMA_HOST", "http://localhost:11434"))
+
 
 def remove_headers_footers(text, header_patterns=None, footer_patterns=None):
     if header_patterns is None:
@@ -88,6 +92,17 @@ class Chunker:
     def __init__(self, collection) -> None:
         self.collection = collection
 
+    def embedding_fx(self, inputs):
+        if USE_OPENAI:
+            openai_embedding_fx = OpenAIEmbeddingFunction(
+                api_key=os.getenv("OPENAI_API_KEY"),
+                model_name="text-embedding-3-small",
+            )
+            return openai_embedding_fx(inputs)
+        else:
+            response = ollama_client.embed(model="mxbai-embed-large", input=inputs)
+            return response["embeddings"]
+
     def chunk_document(
         self,
         document: str,
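Review note on chunker.py: USE_OPENAI is now read after load_dotenv() so values from .env are seen, and ollama's Client.embed() accepts a single string or a list of strings and returns {"embeddings": [...]} with one vector per input, so passing the whole inputs list keeps the return shape identical to the OpenAI branch. A minimal usage sketch (the collection and query string are illustrative, not taken from the diff):

    chunker = Chunker(collection)  # any chromadb collection
    vectors = chunker.embedding_fx(inputs=["How much did Simba weigh in June?"])
    assert len(vectors) == 1  # one embedding vector per input string
    collection.query(query_embeddings=vectors, n_results=5)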
diff --git a/docker-compose.yml b/docker-compose.yml
index 64c109f..49980d0 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -1,4 +1,4 @@
-version: '3.8'
+version: "3.8"
 
 services:
   raggr:
@@ -14,4 +14,4 @@ services:
       - chromadb_data:/app/chromadb
 
 volumes:
-  chromadb_data:
\ No newline at end of file
+  chromadb_data:
diff --git a/image_process.py b/image_process.py
index 7cef2e6..41ed868 100644
--- a/image_process.py
+++ b/image_process.py
@@ -25,22 +25,24 @@ parser.add_argument("filepath")
 
 client = Client(host=os.getenv("OLLAMA_HOST", "http://localhost:11434"))
 
+
 class SimbaImageDescription(BaseModel):
     image_date: str
     description: str
 
+
 def describe_simba_image(input):
     logging.info("Opening image of Simba ...")
     if "heic" in input.lower() or "heif" in input.lower():
         new_filepath = input.split(".")[0] + ".jpg"
         img = Image.open(input)
-        img.save(new_filepath, 'JPEG')
+        img.save(new_filepath, "JPEG")
         logging.info("Extracting EXIF...")
         exif = {
             ExifTags.TAGS[k]: v for k, v in img.getexif().items() if k in ExifTags.TAGS
         }
         img = Image.open(new_filepath)
-        input=new_filepath
+        input = new_filepath
     else:
         img = Image.open(input)
 
@@ -66,7 +68,7 @@ def describe_simba_image(input):
             },
             {"role": "user", "content": prompt, "images": [input]},
         ],
-        format=SimbaImageDescription.model_json_schema()
+        format=SimbaImageDescription.model_json_schema(),
     )
 
     result = SimbaImageDescription.model_validate_json(response["message"]["content"])
diff --git a/index_immich.py b/index_immich.py
index ad12e2f..76a67f4 100644
--- a/index_immich.py
+++ b/index_immich.py
@@ -6,6 +6,7 @@ import tempfile
 
 from image_process import describe_simba_image
 from request import PaperlessNGXService
+import sqlite3
 
 logging.basicConfig(level=logging.INFO)
 
@@ -23,8 +24,17 @@ DOWNLOAD_DIR = os.getenv("DOWNLOAD_DIR", "./simba_photos")
 
 # Set up headers
 headers = {"x-api-key": API_KEY, "Content-Type": "application/json"}
+VISITED = set()
 
 if __name__ == "__main__":
+    conn = sqlite3.connect("./visited.db")
+    c = conn.cursor()
+    c.execute("CREATE TABLE IF NOT EXISTS visited (immich_id TEXT)")
+    c.execute("select immich_id from visited")
+    rows = c.fetchall()
+    for row in rows:
+        VISITED.add(row[0])
+
     ppngx = PaperlessNGXService()
     people_url = f"{IMMICH_URL}/api/search/person?name=Simba"
     people = httpx.get(people_url, headers=headers).json()
@@ -39,7 +49,7 @@ if __name__ == "__main__":
         assets = results.json()["assets"]
 
         for asset in assets["items"]:
-            if asset["type"] == "IMAGE":
+            if asset["type"] == "IMAGE" and asset["id"] not in VISITED:
                 ids[asset["id"]] = asset.get("originalFileName")
 
         nextPage = assets.get("nextPage")
@@ -58,41 +68,50 @@ if __name__ == "__main__":
     asset_search = f"{IMMICH_URL}/api/search/smart"
     request_body = {"query": "simba"}
     results = httpx.post(asset_search, headers=headers, json=request_body)
-    print(results.json()["assets"]["total"])
 
     for asset in results.json()["assets"]["items"]:
         if asset["type"] == "IMAGE":
             ids[asset["id"]] = asset.get("originalFileName")
 
-    immich_asset_id = list(ids.keys())[1]
-    immich_filename = ids.get(immich_asset_id)
-    response = httpx.get(
-        f"{IMMICH_URL}/api/assets/{immich_asset_id}/original", headers=headers
-    )
+    for immich_asset_id, immich_filename in ids.items():
+        try:
+            response = httpx.get(
+                f"{IMMICH_URL}/api/assets/{immich_asset_id}/original", headers=headers
+            )
 
-    path = os.path.join("/Users/ryanchen/Programs/raggr", immich_filename)
-    file = open(path, "wb+")
-    for chunk in response.iter_bytes(chunk_size=8192):
-        file.write(chunk)
+            path = os.path.join("/Users/ryanchen/Programs/raggr", immich_filename)
+            file = open(path, "wb+")
+            for chunk in response.iter_bytes(chunk_size=8192):
+                file.write(chunk)
+            file.close()
 
-    logging.info("Processing image ...")
-    description = describe_simba_image(path)
+            logging.info("Processing image ...")
+            description = describe_simba_image(path)
 
-    image_description = description.description
-    image_date = description.image_date
+            image_description = description.description
+            image_date = description.image_date
 
-    description_filepath = os.path.join("/Users/ryanchen/Programs/raggr", f"SIMBA_DESCRIBE_001.txt")
-    file = open(description_filepath, "w+")
-    file.write(image_description)
-    file.close()
+            description_filepath = os.path.join(
+                "/Users/ryanchen/Programs/raggr", "SIMBA_DESCRIBE_001.txt"
+            )
+            file = open(description_filepath, "w+")
+            file.write(image_description)
+            file.close()
 
-    file = open(description_filepath, 'rb')
+            file = open(description_filepath, "rb")
+            ppngx.upload_description(
+                description_filepath=description_filepath,
+                file=file,
+                title="SIMBA_DESCRIBE_001.txt",
+                exif_date=image_date,
+            )
+            file.close()
 
-    ppngx.upload_description(description_filepath=description_filepath, file=file, title="SIMBA_DESCRIBE_001.txt", exif_date=image_date)
-
+            c.execute("INSERT INTO visited (immich_id) values (?)", (immich_asset_id,))
+            conn.commit()
+            logging.info("Processing complete. Deleting file.")
+            os.remove(file.name)
+        except Exception as e:
+            logging.error(f"something went wrong for {immich_filename}")
+            logging.error(e)
 
-    file.close()
-
-
-
-    logging.info("Processing complete. Deleting file.")
-    os.remove(file.name)
+    conn.close()
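Review note on index_immich.py: the SELECT assumes the visited table already exists, hence the CREATE TABLE IF NOT EXISTS added above, and the in-memory collection has to be a set() rather than the {} dict literal for .add() to work. A self-contained sketch of the same bookkeeping pattern (table and column names match the diff; the helper function is illustrative):

    import sqlite3

    conn = sqlite3.connect("./visited.db")
    c = conn.cursor()
    c.execute("CREATE TABLE IF NOT EXISTS visited (immich_id TEXT PRIMARY KEY)")

    # Load previously processed asset ids into an in-memory set.
    visited = {row[0] for row in c.execute("SELECT immich_id FROM visited")}

    def mark_visited(asset_id: str) -> None:
        # INSERT OR IGNORE keeps reruns idempotent if an id is recorded twice.
        c.execute("INSERT OR IGNORE INTO visited (immich_id) VALUES (?)", (asset_id,))
        conn.commit()
        visited.add(asset_id)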
Deleting file.") - os.remove(file.name) + conn.close() diff --git a/main.py b/main.py index 4e58b33..b116fcd 100644 --- a/main.py +++ b/main.py @@ -18,11 +18,13 @@ from dotenv import load_dotenv load_dotenv() +USE_OPENAI = os.getenv("OPENAI_API_KEY") != None + # Configure ollama client with URL from environment or default to localhost ollama_client = ollama.Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434")) client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", "")) -simba_docs = client.get_or_create_collection(name="simba_docs") +simba_docs = client.get_or_create_collection(name="simba_docs2") feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup") parser = argparse.ArgumentParser( @@ -55,7 +57,6 @@ def index_using_pdf_llm(): def date_to_epoch(date_str: str) -> float: split_date = date_str.split("-") - print(split_date) date = datetime.datetime( int(split_date[0]), int(split_date[1]), @@ -73,10 +74,8 @@ def chunk_data(docs: list[dict[str, Union[str, Any]]], collection): chunker = Chunker(collection) print(f"chunking {len(docs)} documents") - print(docs) texts: list[str] = [doc["content"] for doc in docs] for index, text in enumerate(texts): - print(docs[index]["original_file_name"]) metadata = { "created_date": date_to_epoch(docs[index]["created_date"]), "filename": docs[index]["original_file_name"], @@ -101,6 +100,7 @@ def chunk_text(texts: list[str], collection): def consult_oracle(input: str, collection): print(input) import time + chunker = Chunker(collection) start_time = time.time() @@ -115,7 +115,7 @@ def consult_oracle(input: str, collection): print("Starting embedding generation") embedding_start = time.time() - embeddings = Chunker.embedding_fx(input=[input]) + embeddings = chunker.embedding_fx(inputs=[input]) embedding_end = time.time() print(f"Embedding generation took {embedding_end - embedding_start:.2f} seconds") @@ -126,37 +126,40 @@ def consult_oracle(input: str, collection): query_embeddings=embeddings, # where=metadata_filter, ) - print(results) query_end = time.time() print(f"Collection query took {query_end - query_start:.2f} seconds") # Generate print("Starting LLM generation") llm_start = time.time() - # output = ollama_client.generate( - # model="gemma3n:e4b", - # prompt=f"You are a helpful assistant that understandings veterinary terms. Using the following data, help answer the user's query by providing as many details as possible. Using this data: {results}. Respond to this prompt: {input}", - # ) - response = openai_client.chat.completions.create( - model="gpt-4o-mini", - messages=[ - { - "role": "system", - "content": "You are a helpful assistant that understands veterinary terms.", - }, - { - "role": "user", - "content": f"Using the following data, help answer the user's query by providing as many details as possible. Using this data: {results}. Respond to this prompt: {input}", - }, - ], - ) + if USE_OPENAI: + response = openai_client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + { + "role": "system", + "content": "You are a helpful assistant that understands veterinary terms.", + }, + { + "role": "user", + "content": f"Using the following data, help answer the user's query by providing as many details as possible. Using this data: {results}. Respond to this prompt: {input}", + }, + ], + ) + output= response.choices[0].message.content + else: + response = ollama_client.generate( + model="gemma3:4b", + prompt=f"You are a helpful assistant that understandings veterinary terms. 
diff --git a/query.py b/query.py
index 927bbee..defa02a 100644
--- a/query.py
+++ b/query.py
@@ -45,7 +45,6 @@ the date range the user is trying to query. You should return it as a JSON. The
 
 If the created_date cannot be ascertained, set it to epoch time start.
 
-
 You have several operators at your disposal:
 - $gt: greater than
 - $gte: greater than or equal
@@ -83,6 +82,8 @@ document_types:
 Only return the extracted metadata fields. Make sure the extracted metadata fields are valid JSON
 """
 
+USE_OPENAI = os.getenv("OPENAI_API_KEY") is not None
+
 
 class QueryGenerator:
     def __init__(self) -> None:
@@ -102,38 +103,38 @@ class QueryGenerator:
         return date.timestamp()
 
     def get_query(self, input: str):
-        client = OpenAI()
-        print(input)
-        response = client.responses.parse(
-            model="gpt-4o",
-            input=[
-                {"role": "system", "content": PROMPT},
-                {"role": "user", "content": input},
-            ],
-            text_format=Time,
-        )
-        print(response)
-        query = json.loads(response.output_parsed.extracted_metadata_fields)
+        if USE_OPENAI:
+            client = OpenAI()
+            response = client.responses.parse(
+                model="gpt-4o",
+                input=[
+                    {"role": "system", "content": PROMPT},
+                    {"role": "user", "content": input},
+                ],
+                text_format=GeneratedQuery,
+            )
+            print(response.output)
+            query = json.loads(response.output_parsed.extracted_metadata_fields)
+        else:
+            response: ChatResponse = ollama_client.chat(
+                model="gemma3n:e4b",
+                messages=[
+                    {"role": "system", "content": PROMPT},
+                    {"role": "user", "content": input},
+                ],
+                format=GeneratedQuery.model_json_schema(),
+            )
+            query = json.loads(
+                json.loads(response["message"]["content"])["extracted_metadata_fields"]
+            )
 
-        # response: ChatResponse = ollama_client.chat(
-        #     model="gemma3n:e4b",
-        #     messages=[
-        #         {"role": "system", "content": PROMPT},
-        #         {"role": "user", "content": input},
-        #     ],
-        #     format=GeneratedQuery.model_json_schema(),
-        # )
+        date_key = list(query["created_date"].keys())[0]
+        query["created_date"][date_key] = self.date_to_epoch(
+            query["created_date"][date_key]
+        )
 
-        # query = json.loads(
-        #     json.loads(response["message"]["content"])["extracted_metadata_fields"]
-        # )
-        date_key = list(query["created_date"].keys())[0]
-        query["created_date"][date_key] = self.date_to_epoch(
-            query["created_date"][date_key]
-        )
-
-        if "$" not in date_key:
-            query["created_date"]["$" + date_key] = query["created_date"][date_key]
+        if "$" not in date_key:
+            query["created_date"]["$" + date_key] = query["created_date"].pop(date_key)
 
         return query
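Review note on get_query: the bare key is now popped when the $-prefix is added, since Chroma's where filters only accept $-prefixed operators inside a field dict. After the date_to_epoch conversion, both branches should converge on a Chroma-ready filter. An illustrative call (the timestamp and the embeddings variable are assumed values, not computed from the diff):

    qg = QueryGenerator()
    where = qg.get_query("what did the vet say after June 2024?")
    # e.g. {"created_date": {"$gt": 1717214400.0}}
    collection.query(query_embeddings=embeddings, where=where)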
diff --git a/request.py b/request.py
index d6c0759..1eac623 100644
--- a/request.py
+++ b/request.py
@@ -1,11 +1,14 @@
 import os
 import tempfile
 
 import httpx
+import logging
 from dotenv import load_dotenv
 
 load_dotenv()
 
+logging.basicConfig(level=logging.INFO)
+
 
 class PaperlessNGXService:
     def __init__(self):
@@ -17,7 +20,16 @@ class PaperlessNGXService:
     def get_data(self):
         print(f"Getting data from: {self.url}")
         r = httpx.get(self.url, headers=self.headers)
-        return r.json()["results"]
+        results = r.json()["results"]
+
+        nextLink = r.json().get("next")
+
+        while nextLink:
+            r = httpx.get(nextLink, headers=self.headers)
+            results += r.json()["results"]
+            nextLink = r.json().get("next")
+
+        return results
 
     def get_doc_by_id(self, doc_id: int):
         url = f"http://{os.getenv('BASE_URL')}/api/documents/{doc_id}/"
@@ -45,15 +57,15 @@ class PaperlessNGXService:
 
     def upload_description(self, description_filepath, file, title, exif_date: str):
         POST_URL = f"http://{os.getenv('BASE_URL')}/api/documents/post_document/"
-        files = {'document': ('description_filepath', file, 'application/txt')}
+        files = {"document": (title, file, "text/plain")}
         data = {
-            "title": title,
-            "create": exif_date,
-            "document_type": 3
-            "tags": [7]
+            "title": title,
+            "created": exif_date,
+            "document_type": 3,
+            "tags": [7],
         }
 
-        r= httpx.post(POST_URL, headers=self.headers, data=data, files=files)
+        r = httpx.post(POST_URL, headers=self.headers, data=data, files=files)
         r.raise_for_status()
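Review note on get_data: Paperless-NGX list endpoints return a "next" URL until the last page, so the loop buffers every page before returning. The upload now also sends the real filename and a text/plain content type, and uses Paperless-NGX's "created" field rather than "create". If the document set ever gets large, the same pagination logic works as a generator; a sketch, not part of the diff:

    def iter_documents(self):
        # Follow Paperless-NGX "next" links one page at a time.
        url = self.url
        while url:
            page = httpx.get(url, headers=self.headers).json()
            yield from page["results"]
            url = page.get("next")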