Adding some funny stuff
main.py
@@ -1,102 +1,84 @@
+import logging
+import argparse
+import chromadb
 import ollama
 import os
+from uuid import uuid4, UUID
+
+from request import PaperlessNGXService
+from chunker import Chunker
-from math import ceil
-import chromadb
-from chromadb.utils.embedding_functions.ollama_embedding_function import (
-    OllamaEmbeddingFunction,
-)
 from dotenv import load_dotenv

-client = chromadb.EphemeralClient()
-collection = client.create_collection(name="docs")
+client = chromadb.PersistentClient(path="/Users/ryanchen/Programs/raggr/chromadb")
+simba_docs = client.get_or_create_collection(name="simba_docs")
+feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup")
+
+parser = argparse.ArgumentParser(
+    description="An LLM tool to query information about Simba <3"
+)
+
+parser.add_argument("query", type=str, help="questions about simba's health")
+parser.add_argument(
+    "--reindex", action="store_true", help="re-index the simba documents"
+)

 load_dotenv()

+class Chunk:
+    def __init__(
+        self,
+        text: str,
+        size: int,
+        document_id: UUID,
+        chunk_id: int,
+        embedding,
+    ):
+        self.text = text
+        self.size = size
+        self.document_id = document_id
+        self.chunk_id = chunk_id
+        self.embedding = embedding
+
+
+def chunk_data(texts: list[str], collection):
+    # Step 2: Create chunks
+    chunker = Chunker(collection)
+
+    print(f"chunking {len(texts)} documents")
+    for text in texts[: len(texts) // 2]:
+        chunker.chunk_document(document=text)

-class Chunker:
-    def __init__(self) -> None:
-        self.embedding_fx = OllamaEmbeddingFunction(
-            url=os.getenv("OLLAMA_URL", ""),
-            model_name="mxbai-embed-large",
-        )
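main.py now pulls Chunker from a separate chunker module that is not part of this diff, replacing the in-file class deleted above. Judging from the call sites (Chunker(collection) in chunk_data, chunk_document(document=...), and Chunker.embedding_fx in consult_oracle below), the module presumably looks something like the following sketch; the constructor signature and the class-level embedding_fx are inferences from those call sites, not code from the commit.

# Hypothetical chunker.py, inferred from how main.py calls it.
import os
from uuid import uuid4

from chromadb.utils.embedding_functions.ollama_embedding_function import (
    OllamaEmbeddingFunction,
)


class Chunker:
    # Class-level attribute, since consult_oracle reads Chunker.embedding_fx
    # without an instance (assumption).
    embedding_fx = OllamaEmbeddingFunction(
        url=os.getenv("OLLAMA_URL", ""),
        model_name="mxbai-embed-large",
    )

    def __init__(self, collection) -> None:
        # chunk_data passes the target Chroma collection positionally.
        self.collection = collection

    def chunk_document(self, document: str, chunk_size: int = 300) -> None:
        # Split the document into fixed-size character windows and index each.
        doc_uuid = uuid4()
        for i in range(0, len(document), chunk_size):
            text_chunk = document[i : i + chunk_size]
            embedding = self.embedding_fx([text_chunk])
            self.collection.add(
                ids=[f"{doc_uuid}:{i // chunk_size}"],
                documents=[text_chunk],
                embeddings=embedding,
            )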

+def consult_oracle(input: str, collection):
+    # Ask
+    embeddings = Chunker.embedding_fx(input=[input])
+    results = collection.query(query_texts=[input], query_embeddings=embeddings)
+    print(results)
+
+    # Generate
+    output = ollama.generate(
+        model="gemma3n:e4b",
+        prompt=f"You are a helpful assistant that understands veterinary terms. Using the following data, help answer the user's query by providing as many details as possible. Using this data: {results}. Respond to this prompt: {input}",
+    )
+
+    print(output["response"])
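One caveat on the query call above: recent chromadb versions expect exactly one of query_texts or query_embeddings in Collection.query and raise an error when both are supplied. If that holds for the version in use here, the lookup would reduce to something like this sketch:

# Sketch: pass only the precomputed embeddings (assumes chromadb enforces
# "exactly one of query_texts / query_embeddings").
embeddings = Chunker.embedding_fx(input=[input])
results = collection.query(
    query_embeddings=embeddings,
    n_results=5,  # hypothetical cap; the commit relies on the default
)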

+def paperless_workflow(input):
+    # Step 1: Get the text
+    ppngx = PaperlessNGXService()
+    docs = ppngx.get_data()
+    texts = [doc["content"] for doc in docs]
+
+    chunk_data(texts, collection=simba_docs)
+    consult_oracle(input, simba_docs)
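The PaperlessNGXService imported from the local request module is likewise absent from this diff. Since get_data() yields dicts carrying a "content" key, it is presumably a thin client for the Paperless-ngx documents API; a minimal sketch, assuming token auth against the standard /api/documents/ endpoint, with PAPERLESS_URL and PAPERLESS_TOKEN as hypothetical environment variable names:

# Hypothetical request.py, not part of this diff.
import os

import requests


class PaperlessNGXService:
    def __init__(self) -> None:
        self.base_url = os.getenv("PAPERLESS_URL", "http://localhost:8000")
        self.token = os.getenv("PAPERLESS_TOKEN", "")

    def get_data(self) -> list[dict]:
        # Paperless-ngx paginates /api/documents/; follow "next" links and
        # collect every document dict (each includes a "content" field).
        docs: list[dict] = []
        url = f"{self.base_url}/api/documents/"
        while url:
            resp = requests.get(
                url, headers={"Authorization": f"Token {self.token}"}
            )
            resp.raise_for_status()
            payload = resp.json()
            docs.extend(payload["results"])
            url = payload["next"]
        return docs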

+if __name__ == "__main__":
+    args = parser.parse_args()
+    if args.reindex:
+        logging.info(msg="Fetching documents from Paperless-NGX")
+        ppngx = PaperlessNGXService()
+        docs = ppngx.get_data()
+        texts = [doc["content"] for doc in docs]
+        logging.info(msg=f"Fetched {len(texts)} documents")
+
+        logging.info(msg="Chunking documents now ...")
+        chunk_data(texts, collection=simba_docs)
+        logging.info(msg="Done chunking documents")
+
+    if args.query:
+        logging.info("Consulting oracle ...")
+        consult_oracle(
+            input=args.query,
+            collection=simba_docs,
+        )
+
+        pass

-    def chunk_document(self, document: str, chunk_size: int = 300) -> list[Chunk]:
-        doc_uuid = uuid4()
-
-        chunks = []
-        num_chunks = ceil(len(document) / chunk_size)
-        document_length = len(document)
-
-        for i in range(num_chunks):
-            curr_pos = i * num_chunks
-            to_pos = (
-                curr_pos + num_chunks
-                if curr_pos + num_chunks < document_length
-                else document_length
-            )
-            text_chunk = document[curr_pos:to_pos]
-
-            embedding = self.embedding_fx([text_chunk])
-            collection.add(
-                ids=[str(doc_uuid) + ":" + str(i)],
-                documents=[text_chunk],
-                embeddings=embedding,
-            )
-
-        return chunks
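Worth recording as this method is deleted: it strides by num_chunks instead of chunk_size, so a 3,000-character document with chunk_size=300 produces ten 10-character slices covering only the first 100 characters, and the chunks list is returned without ever being filled. The intended windowing, sketched:

# Sketch of the intended windowing: stride by chunk_size, not num_chunks.
from math import ceil


def chunk_positions(document_length: int, chunk_size: int = 300) -> list[tuple[int, int]]:
    # Each window starts at i * chunk_size and is clamped to the document end.
    num_chunks = ceil(document_length / chunk_size)
    return [
        (i * chunk_size, min((i + 1) * chunk_size, document_length))
        for i in range(num_chunks)
    ]


# chunk_positions(700, 300) -> [(0, 300), (300, 600), (600, 700)]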

-embedding_fx = OllamaEmbeddingFunction(
-    url=os.getenv("OLLAMA_URL", ""),
-    model_name="mxbai-embed-large",
-)
-
-# Step 1: Get the text
-ppngx = PaperlessNGXService()
-docs = ppngx.get_data()
-texts = [doc["content"] for doc in docs]
-
-# Step 2: Create chunks
-chunker = Chunker()
-
-print(f"chunking {len(texts)} documents")
-for text in texts:
-    chunker.chunk_document(document=text)
-
-# Ask
-input = "How many teeth has Simba had removed? Who is his current vet?"
-embeddings = embedding_fx(input=[input])
-results = collection.query(query_texts=[input], query_embeddings=embeddings)
-print(results)
-
-# Generate
-output = ollama.generate(
-    model="gemma3n:e4b",
-    prompt=f"Using this data: {results}. Respond to this prompt: {input}",
-)
-
-print(output["response"])
+    else:
+        print("please provide a query")
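With the argparse setup above, the script now runs as a small CLI. Because query is a required positional argument, a query string must accompany --reindex runs as well:

# Re-index the Paperless-ngx documents, then ask a question:
python main.py --reindex "How many teeth has Simba had removed?"

# Query the existing index only:
python main.py "Who is Simba's current vet?"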