Adding more embeddings
main.py (20 changed lines)
@@ -7,6 +7,10 @@ from math import ceil
 
 import chromadb
+from chromadb.utils.embedding_functions.ollama_embedding_function import (
+    OllamaEmbeddingFunction,
+)
 
 client = chromadb.EphemeralClient()
 collection = client.create_collection(name="docs")
 
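This hunk only adds the import for Chroma's bundled Ollama embedding wrapper; the client and collection lines are unchanged context. For orientation, a minimal standalone sketch of the setup (assuming a local Ollama server on port 11434 with mxbai-embed-large already pulled); binding the embedding function to the collection at creation time is an optional variation, not something this commit does:

import chromadb
from chromadb.utils.embedding_functions.ollama_embedding_function import (
    OllamaEmbeddingFunction,
)

# In-memory client: the collection disappears when the process exits.
client = chromadb.EphemeralClient()

embedding_fx = OllamaEmbeddingFunction(
    url="http://localhost:11434",
    model_name="mxbai-embed-large",
)

# Optional variation (not in this commit): attach the embedding function so
# later add()/query() calls can embed plain text with the same model.
collection = client.create_collection(name="docs", embedding_function=embedding_fx)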
@@ -29,6 +33,11 @@ class Chunk:
 
 class Chunker:
     def __init__(self) -> None:
+        self.embedding_fx = OllamaEmbeddingFunction(
+            url="http://localhost:11434",
+            model_name="mxbai-embed-large",
+        )
+
         pass
 
     def chunk_document(self, document: str, chunk_size: int = 300) -> list[Chunk]:
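With the constructor now doing real work, the trailing pass is redundant and could be dropped. The stored embedding function is callable: given a list of strings it returns one vector per string. A small hypothetical smoke test (the sample text and the 1024-dimension figure for mxbai-embed-large are assumptions, and the Ollama server must be running):

chunker = Chunker()

# Chroma embedding functions take a list of documents and return a list of vectors.
vectors = chunker.embedding_fx(["Simba is a three-year-old tabby."])
print(len(vectors), len(vectors[0]))  # expected: 1 vector of 1024 floats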
@@ -47,15 +56,20 @@ class Chunker:
             )
             text_chunk = document[curr_pos:to_pos]
 
+            embedding = self.embedding_fx([text_chunk])
             collection.add(
                 ids=[str(doc_uuid) + ":" + str(i)],
                 documents=[text_chunk],
+                embeddings=embedding,
             )
 
         return chunks
 
 
-# Setup
+embedding_fx = OllamaEmbeddingFunction(
+    url="http://localhost:11434",
+    model_name="mxbai-embed-large",
+)
 
 # Step 1: Get the text
 ppngx = PaperlessNGXService()
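Two observations on this hunk. First, the module-level embedding_fx duplicates the instance built in Chunker.__init__; sharing one instance would keep the model name configured in a single place. Second, the loop embeds and inserts one chunk per iteration, which means one Ollama round-trip and one collection.add per chunk. A hedged alternative sketch that batches both calls (the chunks list and its text attribute are hypothetical, standing in for whatever chunk_document accumulates):

# Inside chunk_document, after all chunks have been collected.
text_chunks = [chunk.text for chunk in chunks]     # hypothetical Chunk field
embeddings = self.embedding_fx(text_chunks)        # one batched embedding call
collection.add(
    ids=[f"{doc_uuid}:{i}" for i in range(len(text_chunks))],
    documents=text_chunks,
    embeddings=embeddings,
)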
@@ -71,8 +85,8 @@ for text in texts:
 
 # Ask
 input = "How many teeth has Simba had removed?"
-response = ollama.embed(model="mxbai-embed-large", input=input)
-results = collection.query(query_texts=[input], n_results=1)
+embeddings = embedding_fx(input=[input])
+results = collection.query(query_texts=[input], query_embeddings=embeddings)
 print(results)
 # Generate
 output = ollama.generate(
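One caveat on the new query line: recent Chroma releases treat query_texts and query_embeddings as mutually exclusive and may reject the call when both are given, and the old n_results=1 cap is gone, so the default result count applies. A hedged sketch of the ask-and-generate step using only the precomputed embedding (the name question avoids shadowing the input builtin; the generation model and prompt are assumptions, since the original ollama.generate call is truncated in this hunk):

question = "How many teeth has Simba had removed?"
query_embeddings = embedding_fx([question])
results = collection.query(query_embeddings=query_embeddings, n_results=1)

# results["documents"] holds one list per query; take the single best-matching chunk.
context = results["documents"][0][0]
output = ollama.generate(
    model="llama3.2",  # assumption: any locally pulled generation model
    prompt=f"Using this context: {context}\n\nAnswer this question: {question}",
)
print(output["response"])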