chunker.py (17 changed lines)

@@ -3,15 +3,19 @@ from math import ceil
 import re
 from typing import Union
 from uuid import UUID, uuid4
+from ollama import Client
 from chromadb.utils.embedding_functions.openai_embedding_function import (
     OpenAIEmbeddingFunction,
 )
 from dotenv import load_dotenv
 
 
+USE_OPENAI = os.getenv("OPENAI_API_KEY") != None
+
 load_dotenv()
 
+ollama_client = Client(host=os.getenv("OLLAMA_HOST", "http://localhost:11434"))
 
 
 def remove_headers_footers(text, header_patterns=None, footer_patterns=None):
     if header_patterns is None:
@@ -88,6 +92,17 @@ class Chunker:
     def __init__(self, collection) -> None:
         self.collection = collection
 
+    def embedding_fx(self, inputs):
+        if USE_OPENAI:
+            openai_embedding_fx = OpenAIEmbeddingFunction(
+                api_key=os.getenv("OPENAI_API_KEY"),
+                model_name="text-embedding-3-small",
+            )
+            return openai_embedding_fx(inputs)
+        else:
+            response = ollama_client.embed(model="mxbai-embed-large", input=inputs[0])
+            return response["embeddings"]
+
     def chunk_document(
         self,
         document: str,
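For context, consult_oracle in main.py (further down in this diff) calls the new method as chunker.embedding_fx(inputs=[input]). A minimal usage sketch under that assumption; the local path and query string here are hypothetical:

import chromadb

from chunker import Chunker

# hypothetical local setup; main.py reads CHROMADB_PATH from the environment instead
client = chromadb.PersistentClient(path="./chroma")
collection = client.get_or_create_collection(name="simba_docs2")

chunker = Chunker(collection)
# the Ollama branch only embeds inputs[0], so pass a single-element list
embeddings = chunker.embedding_fx(inputs=["How old is Simba?"])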
@@ -1,4 +1,4 @@
-version: '3.8'
+version: "3.8"
 
 services:
   raggr:
image_process.py

@@ -25,16 +25,18 @@ parser.add_argument("filepath")
 client = Client(host=os.getenv("OLLAMA_HOST", "http://localhost:11434"))
+
+
 class SimbaImageDescription(BaseModel):
     image_date: str
     description: str
 
 
 def describe_simba_image(input):
     logging.info("Opening image of Simba ...")
     if "heic" in input.lower() or "heif" in input.lower():
         new_filepath = input.split(".")[0] + ".jpg"
         img = Image.open(input)
-        img.save(new_filepath, 'JPEG')
+        img.save(new_filepath, "JPEG")
     logging.info("Extracting EXIF...")
     exif = {
         ExifTags.TAGS[k]: v for k, v in img.getexif().items() if k in ExifTags.TAGS
@@ -66,7 +68,7 @@ def describe_simba_image(input):
             },
             {"role": "user", "content": prompt, "images": [input]},
         ],
-        format=SimbaImageDescription.model_json_schema()
+        format=SimbaImageDescription.model_json_schema(),
     )
 
     result = SimbaImageDescription.model_validate_json(response["message"]["content"])
@@ -6,6 +6,7 @@ import tempfile
 
 from image_process import describe_simba_image
 from request import PaperlessNGXService
+import sqlite3
 
 logging.basicConfig(level=logging.INFO)
 
@@ -23,8 +24,16 @@ DOWNLOAD_DIR = os.getenv("DOWNLOAD_DIR", "./simba_photos")
 # Set up headers
 headers = {"x-api-key": API_KEY, "Content-Type": "application/json"}
 
+VISITED = set()
+
 if __name__ == "__main__":
+    conn = sqlite3.connect("./visited.db")
+    c = conn.cursor()
+    c.execute("select immich_id from visited")
+    rows = c.fetchall()
+    for row in rows:
+        VISITED.add(row[0])
 
     ppngx = PaperlessNGXService()
     people_url = f"{IMMICH_URL}/api/search/person?name=Simba"
     people = httpx.get(people_url, headers=headers).json()
@@ -39,7 +48,7 @@ if __name__ == "__main__":
 
     assets = results.json()["assets"]
     for asset in assets["items"]:
-        if asset["type"] == "IMAGE":
+        if asset["type"] == "IMAGE" and asset["id"] not in VISITED:
             ids[asset["id"]] = asset.get("originalFileName")
     nextPage = assets.get("nextPage")
@@ -58,13 +67,12 @@ if __name__ == "__main__":
     asset_search = f"{IMMICH_URL}/api/search/smart"
     request_body = {"query": "simba"}
     results = httpx.post(asset_search, headers=headers, json=request_body)
-    print(results.json()["assets"]["total"])
     for asset in results.json()["assets"]["items"]:
         if asset["type"] == "IMAGE":
             ids[asset["id"]] = asset.get("originalFileName")
 
-    immich_asset_id = list(ids.keys())[1]
-    immich_filename = ids.get(immich_asset_id)
-    response = httpx.get(
-        f"{IMMICH_URL}/api/assets/{immich_asset_id}/original", headers=headers
-    )
+    for immich_asset_id, immich_filename in ids.items():
+        try:
+            response = httpx.get(
+                f"{IMMICH_URL}/api/assets/{immich_asset_id}/original", headers=headers
+            )
@@ -80,19 +88,28 @@ if __name__ == "__main__":
             image_description = description.description
             image_date = description.image_date
 
-    description_filepath = os.path.join("/Users/ryanchen/Programs/raggr", f"SIMBA_DESCRIBE_001.txt")
+            description_filepath = os.path.join(
+                "/Users/ryanchen/Programs/raggr", f"SIMBA_DESCRIBE_001.txt"
+            )
             file = open(description_filepath, "w+")
             file.write(image_description)
             file.close()
 
-    file = open(description_filepath, 'rb')
-    ppngx.upload_description(description_filepath=description_filepath, file=file, title="SIMBA_DESCRIBE_001.txt", exif_date=image_date)
+            file = open(description_filepath, "rb")
+            ppngx.upload_description(
+                description_filepath=description_filepath,
+                file=file,
+                title="SIMBA_DESCRIBE_001.txt",
+                exif_date=image_date,
+            )
             file.close()
 
+            c.execute("INSERT INTO visited (immich_id) values (?)", (immich_asset_id,))
+            conn.commit()
             logging.info("Processing complete. Deleting file.")
             os.remove(file.name)
+        except Exception as e:
+            logging.info(f"something went wrong for {immich_filename}")
+            logging.info(e)
 
+    conn.close()
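Note that this script selects from and inserts into a visited table in visited.db but never creates it, so a one-time setup along these lines is assumed; the schema is inferred from the two queries in this diff:

import sqlite3

conn = sqlite3.connect("./visited.db")
# inferred from `select immich_id from visited` and the INSERT above
conn.execute("CREATE TABLE IF NOT EXISTS visited (immich_id TEXT PRIMARY KEY)")
conn.commit()
conn.close()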
main.py (27 changed lines)
@@ -18,11 +18,13 @@ from dotenv import load_dotenv
 
 load_dotenv()
 
+USE_OPENAI = os.getenv("OPENAI_API_KEY") != None
+
 # Configure ollama client with URL from environment or default to localhost
 ollama_client = ollama.Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434"))
 
 client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
-simba_docs = client.get_or_create_collection(name="simba_docs")
+simba_docs = client.get_or_create_collection(name="simba_docs2")
 feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup")
 
 parser = argparse.ArgumentParser(
@@ -55,7 +57,6 @@ def index_using_pdf_llm():
 
 def date_to_epoch(date_str: str) -> float:
     split_date = date_str.split("-")
-    print(split_date)
     date = datetime.datetime(
         int(split_date[0]),
         int(split_date[1]),
@@ -73,10 +74,8 @@ def chunk_data(docs: list[dict[str, Union[str, Any]]], collection):
     chunker = Chunker(collection)
 
     print(f"chunking {len(docs)} documents")
-    print(docs)
     texts: list[str] = [doc["content"] for doc in docs]
     for index, text in enumerate(texts):
-        print(docs[index]["original_file_name"])
         metadata = {
             "created_date": date_to_epoch(docs[index]["created_date"]),
             "filename": docs[index]["original_file_name"],
@@ -101,6 +100,7 @@ def chunk_text(texts: list[str], collection):
 def consult_oracle(input: str, collection):
     print(input)
     import time
+    chunker = Chunker(collection)
 
     start_time = time.time()
 
@@ -115,7 +115,7 @@ def consult_oracle(input: str, collection):
 
     print("Starting embedding generation")
     embedding_start = time.time()
-    embeddings = Chunker.embedding_fx(input=[input])
+    embeddings = chunker.embedding_fx(inputs=[input])
     embedding_end = time.time()
     print(f"Embedding generation took {embedding_end - embedding_start:.2f} seconds")
@@ -126,17 +126,13 @@ def consult_oracle(input: str, collection):
         query_embeddings=embeddings,
         # where=metadata_filter,
     )
-    print(results)
     query_end = time.time()
     print(f"Collection query took {query_end - query_start:.2f} seconds")
 
     # Generate
     print("Starting LLM generation")
     llm_start = time.time()
-    # output = ollama_client.generate(
-    #     model="gemma3n:e4b",
-    #     prompt=f"You are a helpful assistant that understandings veterinary terms. Using the following data, help answer the user's query by providing as many details as possible. Using this data: {results}. Respond to this prompt: {input}",
-    # )
+    if USE_OPENAI:
         response = openai_client.chat.completions.create(
             model="gpt-4o-mini",
             messages=[
@@ -150,13 +146,20 @@ def consult_oracle(input: str, collection):
                 },
             ],
         )
+        output = response.choices[0].message.content
+    else:
+        response = ollama_client.generate(
+            model="gemma3:4b",
+            prompt=f"You are a helpful assistant that understands veterinary terms. Using the following data, help answer the user's query by providing as many details as possible. Using this data: {results}. Respond to this prompt: {input}",
+        )
+        output = response["response"]
     llm_end = time.time()
     print(f"LLM generation took {llm_end - llm_start:.2f} seconds")
 
     total_time = time.time() - start_time
     print(f"Total consult_oracle execution took {total_time:.2f} seconds")
 
-    return response.choices[0].message.content
+    return output
 
 
 def paperless_workflow(input):
@@ -181,7 +184,6 @@ if __name__ == "__main__":
     print("Fetching documents from Paperless-NGX")
     ppngx = PaperlessNGXService()
     docs = ppngx.get_data()
-    print(docs)
     print(f"Fetched {len(docs)} documents")
     #
     print("Chunking documents now ...")
@@ -192,7 +194,6 @@ if __name__ == "__main__":
     if args.index:
         with open(args.index) as file:
             extension = args.index.split(".")[-1]
 
     if extension == "pdf":
         pdf_path = ppngx.download_pdf_from_id(id=document_id)
         image_paths = pdf_to_image(filepath=pdf_path)
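consult_oracle now branches on USE_OPENAI, but the openai_client it calls is not defined in any hunk shown here, so a setup along these lines is assumed earlier in main.py:

from openai import OpenAI

# assumed client setup; OpenAI() reads OPENAI_API_KEY from the environment
openai_client = OpenAI()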
query.py (33 changed lines)

@@ -45,7 +45,6 @@ the date range the user is trying to query. You should return it as a JSON. The
 
 If the created_date cannot be ascertained, set it to epoch time start.
-
 
 You have several operators at your disposal:
 - $gt: greater than
 - $gte: greater than or equal
@@ -83,6 +82,8 @@ document_types:
 Only return the extracted metadata fields. Make sure the extracted metadata fields are valid JSON
 """
 
+USE_OPENAI = os.getenv("OPENAI_API_KEY", None) != None
+
 
 class QueryGenerator:
     def __init__(self) -> None:
@@ -102,31 +103,31 @@ class QueryGenerator:
         return date.timestamp()
 
     def get_query(self, input: str):
+        if USE_OPENAI:
             client = OpenAI()
-        print(input)
             response = client.responses.parse(
                 model="gpt-4o",
                 input=[
                     {"role": "system", "content": PROMPT},
                     {"role": "user", "content": input},
                 ],
-                text_format=Time,
+                text_format=GeneratedQuery,
             )
-        print(response)
+            print(response.output)
             query = json.loads(response.output_parsed.extracted_metadata_fields)
-        # response: ChatResponse = ollama_client.chat(
-        #     model="gemma3n:e4b",
-        #     messages=[
-        #         {"role": "system", "content": PROMPT},
-        #         {"role": "user", "content": input},
-        #     ],
-        #     format=GeneratedQuery.model_json_schema(),
-        # )
-
-        # query = json.loads(
-        #     json.loads(response["message"]["content"])["extracted_metadata_fields"]
-        # )
+        else:
+            response: ChatResponse = ollama_client.chat(
+                model="gemma3n:e4b",
+                messages=[
+                    {"role": "system", "content": PROMPT},
+                    {"role": "user", "content": input},
+                ],
+                format=GeneratedQuery.model_json_schema(),
+            )
+            query = json.loads(
+                json.loads(response["message"]["content"])["extracted_metadata_fields"]
+            )
         date_key = list(query["created_date"].keys())[0]
         query["created_date"][date_key] = self.date_to_epoch(
             query["created_date"][date_key]
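get_query relies on a GeneratedQuery model that this diff does not show. Judging from format=GeneratedQuery.model_json_schema() and the json.loads call on extracted_metadata_fields, a plausible shape would be:

from pydantic import BaseModel

# hypothetical reconstruction; the real model is defined elsewhere in query.py
class GeneratedQuery(BaseModel):
    # JSON-encoded Chroma filter, e.g. '{"created_date": {"$gte": "2025-01-01"}}'
    extracted_metadata_fields: str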
request.py (20 changed lines)

@@ -1,11 +1,14 @@
 import os
 import tempfile
 import httpx
+import logging
 
 from dotenv import load_dotenv
 
 load_dotenv()
 
+logging.basicConfig(level=logging.INFO)
+
 
 class PaperlessNGXService:
     def __init__(self):
@@ -17,7 +20,16 @@ class PaperlessNGXService:
     def get_data(self):
         print(f"Getting data from: {self.url}")
         r = httpx.get(self.url, headers=self.headers)
-        return r.json()["results"]
+        results = r.json()["results"]
+
+        nextLink = r.json().get("next")
+
+        while nextLink:
+            r = httpx.get(nextLink, headers=self.headers)
+            results += r.json()["results"]
+            nextLink = r.json().get("next")
+
+        return results
 
     def get_doc_by_id(self, doc_id: int):
         url = f"http://{os.getenv('BASE_URL')}/api/documents/{doc_id}/"
@@ -45,12 +57,12 @@ class PaperlessNGXService:
 
     def upload_description(self, description_filepath, file, title, exif_date: str):
         POST_URL = f"http://{os.getenv('BASE_URL')}/api/documents/post_document/"
-        files = {'document': ('description_filepath', file, 'application/txt')}
+        files = {"document": ("description_filepath", file, "application/txt")}
         data = {
             "title": title,
             "create": exif_date,
-            "document_type": 3
-            "tags": [7]
+            "document_type": 3,
+            "tags": [7],
         }
 
         r = httpx.post(POST_URL, headers=self.headers, data=data, files=files)
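A quick usage sketch of the now-paginated get_data, assuming the Paperless-NGX credentials that __init__ already reads from the environment:

from request import PaperlessNGXService

ppngx = PaperlessNGXService()
docs = ppngx.get_data()  # follows each "next" link until the API is exhausted
print(f"fetched {len(docs)} documents")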