Compare commits
3 Commits
5c5125e662 ... b698109183
| Author | SHA1 | Date |
|---|---|---|
| | b698109183 | |
| | 943a22401b | |
| | 994b3fdf1f | |
main.py (27 changed lines)
@@ -1,4 +1,5 @@
import ollama
import os
from uuid import uuid4, UUID

from request import PaperlessNGXService
@@ -7,9 +8,17 @@ from math import ceil

import chromadb

from chromadb.utils.embedding_functions.ollama_embedding_function import (
    OllamaEmbeddingFunction,
)

from dotenv import load_dotenv

client = chromadb.EphemeralClient()
collection = client.create_collection(name="docs")

load_dotenv()


class Chunk:
    def __init__(
@@ -29,6 +38,11 @@ class Chunk:

class Chunker:
    def __init__(self) -> None:
        self.embedding_fx = OllamaEmbeddingFunction(
            url=os.getenv("OLLAMA_URL", ""),
            model_name="mxbai-embed-large",
        )

        pass

    def chunk_document(self, document: str, chunk_size: int = 300) -> list[Chunk]:
@@ -47,15 +61,20 @@ class Chunker:
            )
            text_chunk = document[curr_pos:to_pos]

            embedding = self.embedding_fx([text_chunk])
            collection.add(
                ids=[str(doc_uuid) + ":" + str(i)],
                documents=[text_chunk],
                embeddings=embedding,
            )

        return chunks


# Setup
embedding_fx = OllamaEmbeddingFunction(
    url=os.getenv("OLLAMA_URL", ""),
    model_name="mxbai-embed-large",
)

# Step 1: Get the text
ppngx = PaperlessNGXService()
@@ -70,9 +89,9 @@ for text in texts:
    chunker.chunk_document(document=text)

# Ask
input = "How many teeth has Simba had removed?"
response = ollama.embed(model="mxbai-embed-large", input=input)
results = collection.query(query_texts=[input], n_results=1)
input = "How many teeth has Simba had removed? Who is his current vet?"
embeddings = embedding_fx(input=[input])
results = collection.query(query_texts=[input], query_embeddings=embeddings)
print(results)
# Generate
output = ollama.generate(
request.py

@@ -17,3 +17,8 @@ class PaperlessNGXService:
        print(f"Getting data from: {self.url}")
        r = httpx.get(self.url, headers=self.headers)
        return r.json()["results"]


if __name__ == "__main__":
    pp = PaperlessNGXService()
    print(pp.get_data()[0].keys())
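Taken together, the main.py changes swap the earlier ollama.embed call for Chroma's OllamaEmbeddingFunction and query the collection with explicit embeddings. The following is a rough, self-contained sketch of that retrieval step, not the committed code: it reuses only calls visible in the diff, assumes the collection has already been populated via collection.add, and the question string and n_results=1 are just illustrative.

```python
import os

import chromadb
from chromadb.utils.embedding_functions.ollama_embedding_function import (
    OllamaEmbeddingFunction,
)

# Embedding function backed by an Ollama server; the OLLAMA_URL env var and
# model name are taken from the diff.
embedding_fx = OllamaEmbeddingFunction(
    url=os.getenv("OLLAMA_URL", ""),
    model_name="mxbai-embed-large",
)

# In-memory Chroma collection, as created in main.py.
client = chromadb.EphemeralClient()
collection = client.create_collection(name="docs")

# ... documents would be chunked, embedded, and added via collection.add(...) ...

# Retrieval: embed the question, then query by embedding only.
question = "How many teeth has Simba had removed? Who is his current vet?"
query_embeddings = embedding_fx(input=[question])
results = collection.query(query_embeddings=query_embeddings, n_results=1)
print(results)
```

Chroma expects exactly one of query_texts or query_embeddings per query, which is why the sketch passes only the precomputed embeddings.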