Reducing startup time/cost
This commit is contained in:
@@ -8,6 +8,7 @@ from chromadb.utils.embedding_functions.openai_embedding_function import (
|
|||||||
OpenAIEmbeddingFunction,
|
OpenAIEmbeddingFunction,
|
||||||
)
|
)
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from llm import LLMClient
|
||||||
|
|
||||||
|
|
||||||
USE_OPENAI = os.getenv("OPENAI_API_KEY") != None
|
USE_OPENAI = os.getenv("OPENAI_API_KEY") != None
|
||||||
@@ -91,9 +92,10 @@ class Chunker:
|
|||||||
|
|
||||||
def __init__(self, collection) -> None:
|
def __init__(self, collection) -> None:
|
||||||
self.collection = collection
|
self.collection = collection
|
||||||
|
self.llm_client = LLMClient()
|
||||||
|
|
||||||
def embedding_fx(self, inputs):
|
def embedding_fx(self, inputs):
|
||||||
if USE_OPENAI:
|
if self.llm_client.PROVIDER == "openai":
|
||||||
openai_embedding_fx = OpenAIEmbeddingFunction(
|
openai_embedding_fx = OpenAIEmbeddingFunction(
|
||||||
api_key=os.getenv("OPENAI_API_KEY"),
|
api_key=os.getenv("OPENAI_API_KEY"),
|
||||||
model_name="text-embedding-3-small",
|
model_name="text-embedding-3-small",
|
||||||
|
|||||||
15
llm.py
15
llm.py
@@ -16,7 +16,7 @@ class LLMClient:
|
|||||||
self.ollama_client = Client(
|
self.ollama_client = Client(
|
||||||
host=os.getenv("OLLAMA_URL", "http://localhost:11434")
|
host=os.getenv("OLLAMA_URL", "http://localhost:11434")
|
||||||
)
|
)
|
||||||
client.chat(
|
self.ollama_client.chat(
|
||||||
model="gemma3:4b", messages=[{"role": "system", "content": "test"}]
|
model="gemma3:4b", messages=[{"role": "system", "content": "test"}]
|
||||||
)
|
)
|
||||||
self.PROVIDER = "ollama"
|
self.PROVIDER = "ollama"
|
||||||
@@ -35,9 +35,16 @@ class LLMClient:
|
|||||||
if self.PROVIDER == "ollama":
|
if self.PROVIDER == "ollama":
|
||||||
response = self.ollama_client.chat(
|
response = self.ollama_client.chat(
|
||||||
model="gemma3:4b",
|
model="gemma3:4b",
|
||||||
prompt=prompt,
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": system_prompt,
|
||||||
|
},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
],
|
||||||
)
|
)
|
||||||
output = response["response"]
|
print(response)
|
||||||
|
output = response.message.content
|
||||||
elif self.PROVIDER == "openai":
|
elif self.PROVIDER == "openai":
|
||||||
response = self.openai_client.responses.create(
|
response = self.openai_client.responses.create(
|
||||||
model="gpt-4o-mini",
|
model="gpt-4o-mini",
|
||||||
@@ -51,6 +58,8 @@ class LLMClient:
|
|||||||
)
|
)
|
||||||
output = response.output_text
|
output = response.output_text
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
client = Client()
|
client = Client()
|
||||||
|
|||||||
42
main.py
42
main.py
@@ -1,6 +1,7 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import sqlite3
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@@ -15,6 +16,7 @@ from query import QueryGenerator
|
|||||||
from cleaner import pdf_to_image, summarize_pdf_image
|
from cleaner import pdf_to_image, summarize_pdf_image
|
||||||
from llm import LLMClient
|
from llm import LLMClient
|
||||||
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
@@ -25,7 +27,7 @@ USE_OPENAI = os.getenv("OPENAI_API_KEY") != None
|
|||||||
ollama_client = ollama.Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434"))
|
ollama_client = ollama.Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434"))
|
||||||
|
|
||||||
client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
|
client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
|
||||||
simba_docs = client.get_or_create_collection(name="simba_docs2")
|
simba_docs = client.get_or_create_collection(name="simba_docs3")
|
||||||
feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup")
|
feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup")
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@@ -76,15 +78,22 @@ def chunk_data(docs: list[dict[str, Union[str, Any]]], collection):
|
|||||||
|
|
||||||
print(f"chunking {len(docs)} documents")
|
print(f"chunking {len(docs)} documents")
|
||||||
texts: list[str] = [doc["content"] for doc in docs]
|
texts: list[str] = [doc["content"] for doc in docs]
|
||||||
for index, text in enumerate(texts):
|
with sqlite3.connect("visited.db") as conn:
|
||||||
metadata = {
|
to_insert = []
|
||||||
"created_date": date_to_epoch(docs[index]["created_date"]),
|
c = conn.cursor()
|
||||||
"filename": docs[index]["original_file_name"],
|
for index, text in enumerate(texts):
|
||||||
}
|
metadata = {
|
||||||
chunker.chunk_document(
|
"created_date": date_to_epoch(docs[index]["created_date"]),
|
||||||
document=text,
|
"filename": docs[index]["original_file_name"],
|
||||||
metadata=metadata,
|
}
|
||||||
)
|
chunker.chunk_document(
|
||||||
|
document=text,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
to_insert.append((docs[index]["id"],))
|
||||||
|
|
||||||
|
c.executemany("INSERT INTO indexed_documents (paperless_id) values (?)", to_insert)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def chunk_text(texts: list[str], collection):
|
def chunk_text(texts: list[str], collection):
|
||||||
@@ -160,6 +169,18 @@ def consult_simba_oracle(input: str):
|
|||||||
collection=simba_docs,
|
collection=simba_docs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def filter_indexed_files(docs):
|
||||||
|
with sqlite3.connect("visited.db") as conn:
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute("CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)")
|
||||||
|
c.execute("SELECT paperless_id FROM indexed_documents")
|
||||||
|
rows = c.fetchall()
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
visited = {row[0] for row in rows}
|
||||||
|
return [doc for doc in docs if doc["id"] not in visited]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -167,6 +188,7 @@ if __name__ == "__main__":
|
|||||||
print("Fetching documents from Paperless-NGX")
|
print("Fetching documents from Paperless-NGX")
|
||||||
ppngx = PaperlessNGXService()
|
ppngx = PaperlessNGXService()
|
||||||
docs = ppngx.get_data()
|
docs = ppngx.get_data()
|
||||||
|
docs = filter_indexed_files(docs)
|
||||||
print(f"Fetched {len(docs)} documents")
|
print(f"Fetched {len(docs)} documents")
|
||||||
#
|
#
|
||||||
print("Chunking documents now ...")
|
print("Chunking documents now ...")
|
||||||
|
|||||||
Reference in New Issue
Block a user