Replace ChromaDB with pgvector for vector storage
Consolidate onto PostgreSQL by using pgvector instead of a separate ChromaDB instance. This removes a Docker volume, a large dependency, and simplifies the stack without meaningful performance impact at our document scale. - Swap langchain-chroma for langchain-postgres (PGVector) - Use pgvector/pgvector:pg16 Docker image with init script - Lazy-initialize vector store to avoid eager DB connections - Add SQL helpers for stats/delete/list (replacing _collection access) - Remove legacy main.py, chunker, petmd scraper, and /api/query endpoint Re-index required after deploy (POST /api/rag/index + /index-obsidian). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -19,11 +19,6 @@ BASE_URL=192.168.1.5:8000
|
|||||||
LLAMA_SERVER_URL=http://192.168.1.213:8080/v1
|
LLAMA_SERVER_URL=http://192.168.1.213:8080/v1
|
||||||
LLAMA_MODEL_NAME=llama-3.1-8b-instruct
|
LLAMA_MODEL_NAME=llama-3.1-8b-instruct
|
||||||
|
|
||||||
# ChromaDB Configuration
|
|
||||||
# For Docker: This is automatically set to /app/data/chromadb
|
|
||||||
# For local development: Set to a local directory path
|
|
||||||
CHROMADB_PATH=./data/chromadb
|
|
||||||
|
|
||||||
# OpenAI Configuration
|
# OpenAI Configuration
|
||||||
OPENAI_API_KEY=your-openai-api-key
|
OPENAI_API_KEY=your-openai-api-key
|
||||||
|
|
||||||
|
|||||||
@@ -13,9 +13,6 @@ wheels/
|
|||||||
.env
|
.env
|
||||||
|
|
||||||
# Database files
|
# Database files
|
||||||
chromadb/
|
|
||||||
chromadb_openai/
|
|
||||||
chroma_db/
|
|
||||||
database/
|
database/
|
||||||
*.db
|
*.db
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
|||||||
|
|
||||||
## Project Overview
|
## Project Overview
|
||||||
|
|
||||||
SimbaRAG is a RAG (Retrieval-Augmented Generation) conversational AI system for querying information about Simba (a cat). It ingests documents from Paperless-NGX, stores embeddings in ChromaDB, and uses LLMs (Ollama or OpenAI) to answer questions.
|
SimbaRAG is a RAG (Retrieval-Augmented Generation) conversational AI system for querying information about Simba (a cat). It ingests documents from Paperless-NGX, stores embeddings in PostgreSQL via pgvector, and uses LLMs (Ollama or OpenAI) to answer questions.
|
||||||
|
|
||||||
## Commands
|
## Commands
|
||||||
|
|
||||||
@@ -54,9 +54,8 @@ docker compose up -d
|
|||||||
│ Docker Compose │
|
│ Docker Compose │
|
||||||
├─────────────────────────────────────────────────────────────┤
|
├─────────────────────────────────────────────────────────────┤
|
||||||
│ raggr (port 8080) │ postgres (port 5432) │
|
│ raggr (port 8080) │ postgres (port 5432) │
|
||||||
│ ├── Quart backend │ PostgreSQL 16 │
|
│ ├── Quart backend │ PostgreSQL 16 + pgvector│
|
||||||
│ ├── React frontend (served) │ │
|
│ └── React frontend (served) │ │
|
||||||
│ └── ChromaDB (volume) │ │
|
|
||||||
└─────────────────────────────────────────────────────────────┘
|
└─────────────────────────────────────────────────────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
+2
-3
@@ -37,15 +37,14 @@ WORKDIR /app/raggr-frontend
|
|||||||
RUN yarn install && yarn build
|
RUN yarn install && yarn build
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Create ChromaDB and database directories
|
# Create database directory
|
||||||
RUN mkdir -p /app/chromadb /app/database
|
RUN mkdir -p /app/database
|
||||||
|
|
||||||
# Expose port
|
# Expose port
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
|
|
||||||
# Set environment variables
|
# Set environment variables
|
||||||
ENV PYTHONPATH=/app
|
ENV PYTHONPATH=/app
|
||||||
ENV CHROMADB_PATH=/app/chromadb
|
|
||||||
|
|
||||||
# Run the startup script
|
# Run the startup script
|
||||||
CMD ["./startup.sh"]
|
CMD ["./startup.sh"]
|
||||||
|
|||||||
+2
-3
@@ -34,16 +34,15 @@ COPY . .
|
|||||||
WORKDIR /app/raggr-frontend
|
WORKDIR /app/raggr-frontend
|
||||||
RUN yarn build
|
RUN yarn build
|
||||||
|
|
||||||
# Create ChromaDB and database directories
|
# Create database directory
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
RUN mkdir -p /app/chromadb /app/database
|
RUN mkdir -p /app/database
|
||||||
|
|
||||||
# Make startup script executable
|
# Make startup script executable
|
||||||
RUN chmod +x /app/startup-dev.sh
|
RUN chmod +x /app/startup-dev.sh
|
||||||
|
|
||||||
# Set environment variables
|
# Set environment variables
|
||||||
ENV PYTHONPATH=/app
|
ENV PYTHONPATH=/app
|
||||||
ENV CHROMADB_PATH=/app/chromadb
|
|
||||||
ENV PYTHONUNBUFFERED=1
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
# Expose port
|
# Expose port
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import os
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from quart import Quart, jsonify, render_template, request, send_from_directory
|
from quart import Quart, jsonify, render_template, send_from_directory
|
||||||
from quart_jwt_extended import JWTManager, get_jwt_identity, jwt_refresh_token_required
|
from quart_jwt_extended import JWTManager, get_jwt_identity, jwt_refresh_token_required
|
||||||
from tortoise import Tortoise
|
from tortoise import Tortoise
|
||||||
|
|
||||||
@@ -15,7 +15,6 @@ import blueprints.users
|
|||||||
import blueprints.whatsapp
|
import blueprints.whatsapp
|
||||||
import blueprints.users.models
|
import blueprints.users.models
|
||||||
from config.db import TORTOISE_CONFIG
|
from config.db import TORTOISE_CONFIG
|
||||||
from main import consult_simba_oracle
|
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
@@ -78,39 +77,6 @@ async def serve_react_app(path):
|
|||||||
return await render_template("index.html")
|
return await render_template("index.html")
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/query", methods=["POST"])
|
|
||||||
@jwt_refresh_token_required
|
|
||||||
async def query():
|
|
||||||
current_user_uuid = get_jwt_identity()
|
|
||||||
user = await blueprints.users.models.User.get(id=current_user_uuid)
|
|
||||||
data = await request.get_json()
|
|
||||||
query = data.get("query")
|
|
||||||
conversation_id = data.get("conversation_id")
|
|
||||||
conversation = await blueprints.conversation.logic.get_conversation_by_id(
|
|
||||||
conversation_id
|
|
||||||
)
|
|
||||||
await conversation.fetch_related("messages")
|
|
||||||
await blueprints.conversation.logic.add_message_to_conversation(
|
|
||||||
conversation=conversation,
|
|
||||||
message=query,
|
|
||||||
speaker="user",
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
|
|
||||||
transcript = await blueprints.conversation.logic.get_conversation_transcript(
|
|
||||||
user=user, conversation=conversation
|
|
||||||
)
|
|
||||||
|
|
||||||
response = consult_simba_oracle(input=query, transcript=transcript)
|
|
||||||
await blueprints.conversation.logic.add_message_to_conversation(
|
|
||||||
conversation=conversation,
|
|
||||||
message=response,
|
|
||||||
speaker="simba",
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
return jsonify({"response": response})
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/messages", methods=["GET"])
|
@app.route("/api/messages", methods=["GET"])
|
||||||
@jwt_refresh_token_required
|
@jwt_refresh_token_required
|
||||||
async def get_messages():
|
async def get_messages():
|
||||||
|
|||||||
@@ -328,7 +328,7 @@ async def obsidian_search_notes(query: str) -> str:
|
|||||||
return "Obsidian integration is not configured. Please set OBSIDIAN_VAULT_PATH environment variable."
|
return "Obsidian integration is not configured. Please set OBSIDIAN_VAULT_PATH environment variable."
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Query ChromaDB for obsidian documents
|
# Query vector store for obsidian documents
|
||||||
serialized, docs = await query_vector_store(query=query)
|
serialized, docs = await query_vector_store(query=query)
|
||||||
return serialized
|
return serialized
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,12 @@
|
|||||||
from quart import Blueprint, jsonify
|
from quart import Blueprint, jsonify
|
||||||
from quart_jwt_extended import jwt_refresh_token_required
|
from quart_jwt_extended import jwt_refresh_token_required
|
||||||
|
|
||||||
from .logic import fetch_obsidian_documents, get_vector_store_stats, index_documents, index_obsidian_documents, vector_store
|
from .logic import (
|
||||||
|
delete_all_documents,
|
||||||
|
get_vector_store_stats,
|
||||||
|
index_documents,
|
||||||
|
index_obsidian_documents,
|
||||||
|
)
|
||||||
from blueprints.users.decorators import admin_required
|
from blueprints.users.decorators import admin_required
|
||||||
|
|
||||||
rag_blueprint = Blueprint("rag_api", __name__, url_prefix="/api/rag")
|
rag_blueprint = Blueprint("rag_api", __name__, url_prefix="/api/rag")
|
||||||
@@ -32,14 +37,7 @@ async def trigger_index():
|
|||||||
async def trigger_reindex():
|
async def trigger_reindex():
|
||||||
"""Clear and reindex all documents. Admin only."""
|
"""Clear and reindex all documents. Admin only."""
|
||||||
try:
|
try:
|
||||||
# Clear existing documents
|
delete_all_documents()
|
||||||
collection = vector_store._collection
|
|
||||||
all_docs = collection.get()
|
|
||||||
|
|
||||||
if all_docs["ids"]:
|
|
||||||
collection.delete(ids=all_docs["ids"])
|
|
||||||
|
|
||||||
# Reindex
|
|
||||||
await index_documents()
|
await index_documents()
|
||||||
stats = get_vector_store_stats()
|
stats = get_vector_store_stats()
|
||||||
return jsonify({"status": "success", "stats": stats})
|
return jsonify({"status": "success", "stats": stats})
|
||||||
|
|||||||
+116
-25
@@ -1,11 +1,13 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from langchain_chroma import Chroma
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_openai import OpenAIEmbeddings
|
from langchain_openai import OpenAIEmbeddings
|
||||||
|
from langchain_postgres import PGVector
|
||||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
|
||||||
from .fetchers import PaperlessNGXService
|
from .fetchers import PaperlessNGXService
|
||||||
from utils.obsidian_service import ObsidianService
|
from utils.obsidian_service import ObsidianService
|
||||||
@@ -13,13 +15,39 @@ from utils.obsidian_service import ObsidianService
|
|||||||
# Load environment variables
|
# Load environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
|
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
|
||||||
|
|
||||||
vector_store = Chroma(
|
# Convert Tortoise-style postgres:// URL to SQLAlchemy-style postgresql+psycopg://
|
||||||
collection_name="simba_docs",
|
_db_url = os.getenv(
|
||||||
embedding_function=embeddings,
|
"DATABASE_URL", "postgres://raggr:raggr_dev_password@localhost:5432/raggr"
|
||||||
persist_directory=os.getenv("CHROMADB_PATH", ""),
|
|
||||||
)
|
)
|
||||||
|
_pgvector_url = _db_url.replace("postgres://", "postgresql+psycopg://")
|
||||||
|
|
||||||
|
# Lazy-initialized vector store (defers DB connection to first use)
|
||||||
|
_vector_store = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_vector_store() -> PGVector:
|
||||||
|
global _vector_store
|
||||||
|
if _vector_store is None:
|
||||||
|
_vector_store = PGVector(
|
||||||
|
embeddings=embeddings,
|
||||||
|
collection_name="simba_docs",
|
||||||
|
connection=_pgvector_url,
|
||||||
|
use_jsonb=True,
|
||||||
|
create_extension=False, # created by docker init script
|
||||||
|
)
|
||||||
|
return _vector_store
|
||||||
|
|
||||||
|
|
||||||
|
def _get_engine():
|
||||||
|
"""Get a SQLAlchemy engine for direct queries."""
|
||||||
|
if not hasattr(_get_engine, "_engine"):
|
||||||
|
_get_engine._engine = create_engine(_pgvector_url)
|
||||||
|
return _get_engine._engine
|
||||||
|
|
||||||
|
|
||||||
text_splitter = RecursiveCharacterTextSplitter(
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
chunk_size=1000, # chunk size (characters)
|
chunk_size=1000, # chunk size (characters)
|
||||||
@@ -28,6 +56,18 @@ text_splitter = RecursiveCharacterTextSplitter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_collection_id():
|
||||||
|
"""Get the UUID of our collection from the langchain_pg_collection table."""
|
||||||
|
engine = _get_engine()
|
||||||
|
with engine.connect() as conn:
|
||||||
|
result = conn.execute(
|
||||||
|
text("SELECT uuid FROM langchain_pg_collection WHERE name = :name"),
|
||||||
|
{"name": "simba_docs"},
|
||||||
|
)
|
||||||
|
row = result.fetchone()
|
||||||
|
return row[0] if row else None
|
||||||
|
|
||||||
|
|
||||||
def date_to_epoch(date_str: str) -> float:
|
def date_to_epoch(date_str: str) -> float:
|
||||||
split_date = date_str.split("-")
|
split_date = date_str.split("-")
|
||||||
date = datetime.datetime(
|
date = datetime.datetime(
|
||||||
@@ -63,6 +103,7 @@ async def index_documents():
|
|||||||
documents = await fetch_documents_from_paperless_ngx()
|
documents = await fetch_documents_from_paperless_ngx()
|
||||||
|
|
||||||
splits = text_splitter.split_documents(documents)
|
splits = text_splitter.split_documents(documents)
|
||||||
|
vector_store = _get_vector_store()
|
||||||
await vector_store.aadd_documents(documents=splits)
|
await vector_store.aadd_documents(documents=splits)
|
||||||
|
|
||||||
|
|
||||||
@@ -92,13 +133,17 @@ async def fetch_obsidian_documents() -> list[Document]:
|
|||||||
"filepath": parsed["filepath"],
|
"filepath": parsed["filepath"],
|
||||||
"tags": parsed["tags"],
|
"tags": parsed["tags"],
|
||||||
"created_at": parsed["metadata"].get("created_at"),
|
"created_at": parsed["metadata"].get("created_at"),
|
||||||
**{k: v for k, v in parsed["metadata"].items() if k not in ["created_at", "created_by"]},
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in parsed["metadata"].items()
|
||||||
|
if k not in ["created_at", "created_by"]
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
documents.append(document)
|
documents.append(document)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error reading {md_path}: {e}")
|
logger.warning(f"Error reading {md_path}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
@@ -109,26 +154,25 @@ async def index_obsidian_documents():
|
|||||||
|
|
||||||
Deletes existing obsidian source chunks before re-indexing.
|
Deletes existing obsidian source chunks before re-indexing.
|
||||||
"""
|
"""
|
||||||
obsidian_service = ObsidianService()
|
|
||||||
documents = await fetch_obsidian_documents()
|
documents = await fetch_obsidian_documents()
|
||||||
|
|
||||||
if not documents:
|
if not documents:
|
||||||
print("No Obsidian documents found to index")
|
logger.info("No Obsidian documents found to index")
|
||||||
return {"indexed": 0}
|
return {"indexed": 0}
|
||||||
|
|
||||||
# Delete existing obsidian chunks
|
# Delete existing obsidian chunks
|
||||||
existing_results = vector_store.get(where={"source": "obsidian"})
|
delete_documents_by_metadata("source", "obsidian")
|
||||||
if existing_results.get("ids"):
|
|
||||||
await vector_store.adelete(existing_results["ids"])
|
|
||||||
|
|
||||||
# Split and index documents
|
# Split and index documents
|
||||||
splits = text_splitter.split_documents(documents)
|
splits = text_splitter.split_documents(documents)
|
||||||
|
vector_store = _get_vector_store()
|
||||||
await vector_store.aadd_documents(documents=splits)
|
await vector_store.aadd_documents(documents=splits)
|
||||||
|
|
||||||
return {"indexed": len(documents)}
|
return {"indexed": len(documents)}
|
||||||
|
|
||||||
|
|
||||||
async def query_vector_store(query: str):
|
async def query_vector_store(query: str):
|
||||||
|
vector_store = _get_vector_store()
|
||||||
retrieved_docs = await vector_store.asimilarity_search(query, k=2)
|
retrieved_docs = await vector_store.asimilarity_search(query, k=2)
|
||||||
serialized = "\n\n".join(
|
serialized = "\n\n".join(
|
||||||
(f"Source: {doc.metadata}\nContent: {doc.page_content}")
|
(f"Source: {doc.metadata}\nContent: {doc.page_content}")
|
||||||
@@ -137,32 +181,79 @@ async def query_vector_store(query: str):
|
|||||||
return serialized, retrieved_docs
|
return serialized, retrieved_docs
|
||||||
|
|
||||||
|
|
||||||
|
def delete_all_documents():
|
||||||
|
"""Delete all documents from the vector store collection."""
|
||||||
|
collection_id = _get_collection_id()
|
||||||
|
if not collection_id:
|
||||||
|
return
|
||||||
|
engine = _get_engine()
|
||||||
|
with engine.connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
text("DELETE FROM langchain_pg_embedding WHERE collection_id = :cid"),
|
||||||
|
{"cid": collection_id},
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def delete_documents_by_metadata(key: str, value: str):
|
||||||
|
"""Delete documents matching a metadata key/value pair."""
|
||||||
|
collection_id = _get_collection_id()
|
||||||
|
if not collection_id:
|
||||||
|
return
|
||||||
|
engine = _get_engine()
|
||||||
|
with engine.connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"DELETE FROM langchain_pg_embedding "
|
||||||
|
"WHERE collection_id = :cid AND cmetadata->>:key = :value"
|
||||||
|
),
|
||||||
|
{"cid": collection_id, "key": key, "value": value},
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
def get_vector_store_stats():
|
def get_vector_store_stats():
|
||||||
"""Get statistics about the vector store."""
|
"""Get statistics about the vector store."""
|
||||||
collection = vector_store._collection
|
collection_id = _get_collection_id()
|
||||||
count = collection.count()
|
count = 0
|
||||||
|
if collection_id:
|
||||||
|
engine = _get_engine()
|
||||||
|
with engine.connect() as conn:
|
||||||
|
result = conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT COUNT(*) FROM langchain_pg_embedding WHERE collection_id = :cid"
|
||||||
|
),
|
||||||
|
{"cid": collection_id},
|
||||||
|
)
|
||||||
|
count = result.scalar()
|
||||||
return {
|
return {
|
||||||
"total_documents": count,
|
"total_documents": count,
|
||||||
"collection_name": collection.name,
|
"collection_name": "simba_docs",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def list_all_documents(limit: int = 10):
|
def list_all_documents(limit: int = 10):
|
||||||
"""List documents in the vector store with their metadata."""
|
"""List documents in the vector store with their metadata."""
|
||||||
collection = vector_store._collection
|
collection_id = _get_collection_id()
|
||||||
results = collection.get(limit=limit, include=["metadatas", "documents"])
|
if not collection_id:
|
||||||
|
return []
|
||||||
|
|
||||||
|
engine = _get_engine()
|
||||||
|
with engine.connect() as conn:
|
||||||
|
result = conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT id, document, cmetadata FROM langchain_pg_embedding "
|
||||||
|
"WHERE collection_id = :cid LIMIT :limit"
|
||||||
|
),
|
||||||
|
{"cid": collection_id, "limit": limit},
|
||||||
|
)
|
||||||
documents = []
|
documents = []
|
||||||
for i, doc_id in enumerate(results["ids"]):
|
for row in result:
|
||||||
documents.append(
|
documents.append(
|
||||||
{
|
{
|
||||||
"id": doc_id,
|
"id": str(row[0]),
|
||||||
"metadata": results["metadatas"][i]
|
"metadata": row[2],
|
||||||
if results.get("metadatas")
|
"content_preview": row[1][:200] if row[1] else None,
|
||||||
else None,
|
|
||||||
"content_preview": results["documents"][i][:200]
|
|
||||||
if results.get("documents")
|
|
||||||
else None,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+2
-4
@@ -2,7 +2,7 @@ version: "3.8"
|
|||||||
|
|
||||||
services:
|
services:
|
||||||
postgres:
|
postgres:
|
||||||
image: postgres:16-alpine
|
image: pgvector/pgvector:pg16
|
||||||
ports:
|
ports:
|
||||||
- "5432:5432"
|
- "5432:5432"
|
||||||
environment:
|
environment:
|
||||||
@@ -11,6 +11,7 @@ services:
|
|||||||
- POSTGRES_DB=${POSTGRES_DB:-raggr}
|
- POSTGRES_DB=${POSTGRES_DB:-raggr}
|
||||||
volumes:
|
volumes:
|
||||||
- postgres_data:/var/lib/postgresql/data
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
- ./docker/init-pgvector.sql:/docker-entrypoint-initdb.d/init-pgvector.sql
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-raggr}"]
|
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-raggr}"]
|
||||||
interval: 10s
|
interval: 10s
|
||||||
@@ -29,7 +30,6 @@ services:
|
|||||||
- PAPERLESS_TOKEN=${PAPERLESS_TOKEN}
|
- PAPERLESS_TOKEN=${PAPERLESS_TOKEN}
|
||||||
- BASE_URL=${BASE_URL}
|
- BASE_URL=${BASE_URL}
|
||||||
- OLLAMA_URL=${OLLAMA_URL:-http://localhost:11434}
|
- OLLAMA_URL=${OLLAMA_URL:-http://localhost:11434}
|
||||||
- CHROMADB_PATH=/app/data/chromadb
|
|
||||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||||
- JWT_SECRET_KEY=${JWT_SECRET_KEY}
|
- JWT_SECRET_KEY=${JWT_SECRET_KEY}
|
||||||
- LLAMA_SERVER_URL=${LLAMA_SERVER_URL}
|
- LLAMA_SERVER_URL=${LLAMA_SERVER_URL}
|
||||||
@@ -66,10 +66,8 @@ services:
|
|||||||
postgres:
|
postgres:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
volumes:
|
volumes:
|
||||||
- chromadb_data:/app/data/chromadb
|
|
||||||
- ./obvault:/app/data/obsidian
|
- ./obvault:/app/data/obsidian
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
chromadb_data:
|
|
||||||
postgres_data:
|
postgres_data:
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
CREATE EXTENSION IF NOT EXISTS vector;
|
||||||
@@ -1,278 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import datetime
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sqlite3
|
|
||||||
import time
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
import chromadb
|
|
||||||
from utils.chunker import Chunker
|
|
||||||
from utils.cleaner import pdf_to_image, summarize_pdf_image
|
|
||||||
from llm import LLMClient
|
|
||||||
from scripts.query import QueryGenerator
|
|
||||||
from utils.request import PaperlessNGXService
|
|
||||||
|
|
||||||
_dotenv_loaded = load_dotenv()
|
|
||||||
|
|
||||||
client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
|
|
||||||
simba_docs = client.get_or_create_collection(name="simba_docs2")
|
|
||||||
feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup")
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="An LLM tool to query information about Simba <3"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument("query", type=str, help="questions about simba's health")
|
|
||||||
parser.add_argument(
|
|
||||||
"--reindex", action="store_true", help="re-index the simba documents"
|
|
||||||
)
|
|
||||||
parser.add_argument("--classify", action="store_true", help="test classification")
|
|
||||||
parser.add_argument("--index", help="index a file")
|
|
||||||
|
|
||||||
ppngx = PaperlessNGXService()
|
|
||||||
|
|
||||||
llm_client = LLMClient()
|
|
||||||
|
|
||||||
|
|
||||||
def index_using_pdf_llm(doctypes):
|
|
||||||
logging.info("reindex data...")
|
|
||||||
files = ppngx.get_data()
|
|
||||||
for file in files:
|
|
||||||
document_id: int = file["id"]
|
|
||||||
pdf_path = ppngx.download_pdf_from_id(id=document_id)
|
|
||||||
image_paths = pdf_to_image(filepath=pdf_path)
|
|
||||||
logging.info(f"summarizing {file}")
|
|
||||||
generated_summary = summarize_pdf_image(filepaths=image_paths)
|
|
||||||
file["content"] = generated_summary
|
|
||||||
|
|
||||||
chunk_data(files, simba_docs, doctypes=doctypes)
|
|
||||||
|
|
||||||
|
|
||||||
def date_to_epoch(date_str: str) -> float:
|
|
||||||
split_date = date_str.split("-")
|
|
||||||
date = datetime.datetime(
|
|
||||||
int(split_date[0]),
|
|
||||||
int(split_date[1]),
|
|
||||||
int(split_date[2]),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
|
|
||||||
return date.timestamp()
|
|
||||||
|
|
||||||
|
|
||||||
def chunk_data(docs, collection, doctypes):
|
|
||||||
# Step 2: Create chunks
|
|
||||||
chunker = Chunker(collection)
|
|
||||||
|
|
||||||
logging.info(f"chunking {len(docs)} documents")
|
|
||||||
texts: list[str] = [doc["content"] for doc in docs]
|
|
||||||
with sqlite3.connect("database/visited.db") as conn:
|
|
||||||
to_insert = []
|
|
||||||
c = conn.cursor()
|
|
||||||
for index, text in enumerate(texts):
|
|
||||||
metadata = {
|
|
||||||
"created_date": date_to_epoch(docs[index]["created_date"]),
|
|
||||||
"filename": docs[index]["original_file_name"],
|
|
||||||
"document_type": doctypes.get(docs[index]["document_type"], ""),
|
|
||||||
}
|
|
||||||
|
|
||||||
if doctypes:
|
|
||||||
metadata["type"] = doctypes.get(docs[index]["document_type"])
|
|
||||||
|
|
||||||
chunker.chunk_document(
|
|
||||||
document=text,
|
|
||||||
metadata=metadata,
|
|
||||||
)
|
|
||||||
to_insert.append((docs[index]["id"],))
|
|
||||||
|
|
||||||
c.executemany(
|
|
||||||
"INSERT INTO indexed_documents (paperless_id) values (?)", to_insert
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def chunk_text(texts: list[str], collection):
|
|
||||||
chunker = Chunker(collection)
|
|
||||||
|
|
||||||
for index, text in enumerate(texts):
|
|
||||||
metadata = {}
|
|
||||||
chunker.chunk_document(
|
|
||||||
document=text,
|
|
||||||
metadata=metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def classify_query(query: str, transcript: str) -> bool:
|
|
||||||
logging.info("Starting query generation")
|
|
||||||
qg_start = time.time()
|
|
||||||
qg = QueryGenerator()
|
|
||||||
query_type = qg.get_query_type(input=query, transcript=transcript)
|
|
||||||
logging.info(query_type)
|
|
||||||
qg_end = time.time()
|
|
||||||
logging.info(f"Query generation took {qg_end - qg_start:.2f} seconds")
|
|
||||||
return query_type == "Simba"
|
|
||||||
|
|
||||||
|
|
||||||
def consult_oracle(
|
|
||||||
input: str,
|
|
||||||
collection,
|
|
||||||
transcript: str = "",
|
|
||||||
):
|
|
||||||
chunker = Chunker(collection)
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
# Ask
|
|
||||||
logging.info("Starting query generation")
|
|
||||||
qg_start = time.time()
|
|
||||||
qg = QueryGenerator()
|
|
||||||
doctype_query = qg.get_doctype_query(input=input)
|
|
||||||
# metadata_filter = qg.get_query(input)
|
|
||||||
metadata_filter = {**doctype_query}
|
|
||||||
logging.info(metadata_filter)
|
|
||||||
qg_end = time.time()
|
|
||||||
logging.info(f"Query generation took {qg_end - qg_start:.2f} seconds")
|
|
||||||
|
|
||||||
logging.info("Starting embedding generation")
|
|
||||||
embedding_start = time.time()
|
|
||||||
embeddings = chunker.embedding_fx(inputs=[input])
|
|
||||||
embedding_end = time.time()
|
|
||||||
logging.info(
|
|
||||||
f"Embedding generation took {embedding_end - embedding_start:.2f} seconds"
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info("Starting collection query")
|
|
||||||
query_start = time.time()
|
|
||||||
results = collection.query(
|
|
||||||
query_texts=[input],
|
|
||||||
query_embeddings=embeddings,
|
|
||||||
where=metadata_filter,
|
|
||||||
)
|
|
||||||
query_end = time.time()
|
|
||||||
logging.info(f"Collection query took {query_end - query_start:.2f} seconds")
|
|
||||||
|
|
||||||
# Generate
|
|
||||||
logging.info("Starting LLM generation")
|
|
||||||
llm_start = time.time()
|
|
||||||
system_prompt = "You are a helpful assistant that understands veterinary terms."
|
|
||||||
transcript_prompt = f"Here is the message transcript thus far {transcript}."
|
|
||||||
prompt = f"""Using the following data, help answer the user's query by providing as many details as possible.
|
|
||||||
Using this data: {results}. {transcript_prompt if len(transcript) > 0 else ""}
|
|
||||||
Respond to this prompt: {input}"""
|
|
||||||
output = llm_client.chat(prompt=prompt, system_prompt=system_prompt)
|
|
||||||
llm_end = time.time()
|
|
||||||
logging.info(f"LLM generation took {llm_end - llm_start:.2f} seconds")
|
|
||||||
|
|
||||||
total_time = time.time() - start_time
|
|
||||||
logging.info(f"Total consult_oracle execution took {total_time:.2f} seconds")
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def llm_chat(input: str, transcript: str = "") -> str:
|
|
||||||
system_prompt = "You are a helpful assistant that understands veterinary terms."
|
|
||||||
transcript_prompt = f"Here is the message transcript thus far {transcript}."
|
|
||||||
prompt = f"""Answer the user in as if you were a cat named Simba. Don't act too catlike. Be assertive.
|
|
||||||
{transcript_prompt if len(transcript) > 0 else ""}
|
|
||||||
Respond to this prompt: {input}"""
|
|
||||||
output = llm_client.chat(prompt=prompt, system_prompt=system_prompt)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def paperless_workflow(input):
|
|
||||||
# Step 1: Get the text
|
|
||||||
ppngx = PaperlessNGXService()
|
|
||||||
docs = ppngx.get_data()
|
|
||||||
|
|
||||||
chunk_data(docs, collection=simba_docs)
|
|
||||||
consult_oracle(input, simba_docs)
|
|
||||||
|
|
||||||
|
|
||||||
def consult_simba_oracle(input: str, transcript: str = ""):
|
|
||||||
is_simba_related = classify_query(query=input, transcript=transcript)
|
|
||||||
|
|
||||||
if is_simba_related:
|
|
||||||
logging.info("Query is related to simba")
|
|
||||||
return consult_oracle(
|
|
||||||
input=input,
|
|
||||||
collection=simba_docs,
|
|
||||||
transcript=transcript,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info("Query is NOT related to simba")
|
|
||||||
|
|
||||||
return llm_chat(input=input, transcript=transcript)
|
|
||||||
|
|
||||||
|
|
||||||
def filter_indexed_files(docs):
|
|
||||||
with sqlite3.connect("database/visited.db") as conn:
|
|
||||||
c = conn.cursor()
|
|
||||||
c.execute(
|
|
||||||
"CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)"
|
|
||||||
)
|
|
||||||
c.execute("SELECT paperless_id FROM indexed_documents")
|
|
||||||
rows = c.fetchall()
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
visited = {row[0] for row in rows}
|
|
||||||
return [doc for doc in docs if doc["id"] not in visited]
|
|
||||||
|
|
||||||
|
|
||||||
def reindex():
|
|
||||||
with sqlite3.connect("database/visited.db") as conn:
|
|
||||||
c = conn.cursor()
|
|
||||||
# Ensure the table exists before trying to delete from it
|
|
||||||
c.execute(
|
|
||||||
"CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)"
|
|
||||||
)
|
|
||||||
c.execute("DELETE FROM indexed_documents")
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
# Delete all documents from the collection
|
|
||||||
all_docs = simba_docs.get()
|
|
||||||
if all_docs["ids"]:
|
|
||||||
simba_docs.delete(ids=all_docs["ids"])
|
|
||||||
|
|
||||||
logging.info("Fetching documents from Paperless-NGX")
|
|
||||||
ppngx = PaperlessNGXService()
|
|
||||||
docs = ppngx.get_data()
|
|
||||||
docs = filter_indexed_files(docs)
|
|
||||||
logging.info(f"Fetched {len(docs)} documents")
|
|
||||||
|
|
||||||
# Delete all chromadb data
|
|
||||||
ids = simba_docs.get(ids=None, limit=None, offset=0)
|
|
||||||
all_ids = ids["ids"]
|
|
||||||
if len(all_ids) > 0:
|
|
||||||
simba_docs.delete(ids=all_ids)
|
|
||||||
|
|
||||||
# Chunk documents
|
|
||||||
logging.info("Chunking documents now ...")
|
|
||||||
doctype_lookup = ppngx.get_doctypes()
|
|
||||||
chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
|
|
||||||
logging.info("Done chunking documents")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
args = parser.parse_args()
|
|
||||||
if args.reindex:
|
|
||||||
reindex()
|
|
||||||
|
|
||||||
if args.classify:
|
|
||||||
consult_simba_oracle(input="yohohoho testing")
|
|
||||||
consult_simba_oracle(input="write an email")
|
|
||||||
consult_simba_oracle(input="how much does simba weigh")
|
|
||||||
|
|
||||||
if args.query:
|
|
||||||
logging.info("Consulting oracle ...")
|
|
||||||
print(
|
|
||||||
consult_oracle(
|
|
||||||
input=args.query,
|
|
||||||
collection=simba_docs,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logging.info("please provide a query")
|
|
||||||
+2
-2
@@ -5,7 +5,8 @@ description = "Add your description here"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"chromadb>=1.1.0",
|
"langchain-postgres>=0.0.13",
|
||||||
|
"psycopg[binary]>=3.1.0",
|
||||||
"python-dotenv>=1.0.0",
|
"python-dotenv>=1.0.0",
|
||||||
"flask>=3.1.2",
|
"flask>=3.1.2",
|
||||||
"httpx>=0.28.1",
|
"httpx>=0.28.1",
|
||||||
@@ -30,7 +31,6 @@ dependencies = [
|
|||||||
"asyncpg>=0.30.0",
|
"asyncpg>=0.30.0",
|
||||||
"langchain-openai>=1.1.6",
|
"langchain-openai>=1.1.6",
|
||||||
"langchain>=1.2.0",
|
"langchain>=1.2.0",
|
||||||
"langchain-chroma>=1.0.0",
|
|
||||||
"langchain-community>=0.4.1",
|
"langchain-community>=0.4.1",
|
||||||
"jq>=1.10.0",
|
"jq>=1.10.0",
|
||||||
"tavily-python>=0.7.17",
|
"tavily-python>=0.7.17",
|
||||||
|
|||||||
@@ -6,19 +6,19 @@ import asyncio
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
from blueprints.rag.logic import (
|
from blueprints.rag.logic import (
|
||||||
|
delete_all_documents,
|
||||||
get_vector_store_stats,
|
get_vector_store_stats,
|
||||||
index_documents,
|
index_documents,
|
||||||
list_all_documents,
|
list_all_documents,
|
||||||
vector_store,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def stats():
|
def stats():
|
||||||
"""Show vector store statistics."""
|
"""Show vector store statistics."""
|
||||||
stats = get_vector_store_stats()
|
s = get_vector_store_stats()
|
||||||
print("=== Vector Store Statistics ===")
|
print("=== Vector Store Statistics ===")
|
||||||
print(f"Collection: {stats['collection_name']}")
|
print(f"Collection: {s['collection_name']}")
|
||||||
print(f"Total Documents: {stats['total_documents']}")
|
print(f"Total Documents: {s['total_documents']}")
|
||||||
|
|
||||||
|
|
||||||
async def index():
|
async def index():
|
||||||
@@ -26,23 +26,15 @@ async def index():
|
|||||||
print("Starting indexing process...")
|
print("Starting indexing process...")
|
||||||
print("Fetching documents from Paperless-NGX...")
|
print("Fetching documents from Paperless-NGX...")
|
||||||
await index_documents()
|
await index_documents()
|
||||||
print("✓ Indexing complete!")
|
print("Indexing complete!")
|
||||||
stats()
|
stats()
|
||||||
|
|
||||||
|
|
||||||
async def reindex():
|
async def reindex():
|
||||||
"""Clear and reindex all documents."""
|
"""Clear and reindex all documents."""
|
||||||
print("Clearing existing documents...")
|
print("Clearing existing documents...")
|
||||||
collection = vector_store._collection
|
delete_all_documents()
|
||||||
all_docs = collection.get()
|
print("Cleared")
|
||||||
|
|
||||||
if all_docs["ids"]:
|
|
||||||
print(f"Deleting {len(all_docs['ids'])} existing documents...")
|
|
||||||
collection.delete(ids=all_docs["ids"])
|
|
||||||
print("✓ Cleared")
|
|
||||||
else:
|
|
||||||
print("Collection is already empty")
|
|
||||||
|
|
||||||
await index()
|
await index()
|
||||||
|
|
||||||
|
|
||||||
@@ -113,7 +105,7 @@ Examples:
|
|||||||
print("\n\nOperation cancelled by user")
|
print("\n\nOperation cancelled by user")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n❌ Error: {e}", file=sys.stderr)
|
print(f"\nError: {e}", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,24 +0,0 @@
|
|||||||
from bs4 import BeautifulSoup
|
|
||||||
import chromadb
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
client = chromadb.PersistentClient(path="/Users/ryanchen/Programs/raggr/chromadb")
|
|
||||||
|
|
||||||
# Scrape
|
|
||||||
BASE_URL = "https://www.vet.cornell.edu"
|
|
||||||
LIST_URL = "/departments-centers-and-institutes/cornell-feline-health-center/health-information/feline-health-topics"
|
|
||||||
|
|
||||||
QUERY_URL = BASE_URL + LIST_URL
|
|
||||||
r = httpx.get(QUERY_URL)
|
|
||||||
soup = BeautifulSoup(r.text)
|
|
||||||
|
|
||||||
container = soup.find("div", class_="field-body")
|
|
||||||
a_s = container.find_all("a", href=True)
|
|
||||||
|
|
||||||
new_texts = []
|
|
||||||
|
|
||||||
for link in a_s:
|
|
||||||
endpoint = link["href"]
|
|
||||||
query_url = BASE_URL + endpoint
|
|
||||||
r2 = httpx.get(query_url)
|
|
||||||
article_soup = BeautifulSoup(r2.text)
|
|
||||||
@@ -1,9 +1,6 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
echo "Initializing directories..."
|
|
||||||
mkdir -p /app/data/chromadb
|
|
||||||
|
|
||||||
echo "Rebuilding frontend..."
|
echo "Rebuilding frontend..."
|
||||||
cd /app/raggr-frontend
|
cd /app/raggr-frontend
|
||||||
yarn build
|
yarn build
|
||||||
|
|||||||
@@ -1,139 +0,0 @@
|
|||||||
"""Tests for text preprocessing functions in utils/chunker.py."""
|
|
||||||
|
|
||||||
from utils.chunker import (
|
|
||||||
remove_headers_footers,
|
|
||||||
remove_special_characters,
|
|
||||||
remove_repeated_substrings,
|
|
||||||
remove_extra_spaces,
|
|
||||||
preprocess_text,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestRemoveHeadersFooters:
|
|
||||||
def test_removes_default_header(self):
|
|
||||||
text = "Header Line\nActual content here"
|
|
||||||
result = remove_headers_footers(text)
|
|
||||||
assert "Header" not in result
|
|
||||||
assert "Actual content here" in result
|
|
||||||
|
|
||||||
def test_removes_default_footer(self):
|
|
||||||
text = "Actual content\nFooter Line"
|
|
||||||
result = remove_headers_footers(text)
|
|
||||||
assert "Footer" not in result
|
|
||||||
assert "Actual content" in result
|
|
||||||
|
|
||||||
def test_custom_patterns(self):
|
|
||||||
text = "PAGE 1\nContent\nCopyright 2024"
|
|
||||||
result = remove_headers_footers(
|
|
||||||
text,
|
|
||||||
header_patterns=[r"^PAGE \d+$"],
|
|
||||||
footer_patterns=[r"^Copyright.*$"],
|
|
||||||
)
|
|
||||||
assert "PAGE 1" not in result
|
|
||||||
assert "Copyright" not in result
|
|
||||||
assert "Content" in result
|
|
||||||
|
|
||||||
def test_no_match_preserves_text(self):
|
|
||||||
text = "Just normal content"
|
|
||||||
result = remove_headers_footers(text)
|
|
||||||
assert result == "Just normal content"
|
|
||||||
|
|
||||||
def test_empty_string(self):
|
|
||||||
assert remove_headers_footers("") == ""
|
|
||||||
|
|
||||||
|
|
||||||
class TestRemoveSpecialCharacters:
|
|
||||||
def test_removes_special_chars(self):
|
|
||||||
text = "Hello @world #test $100"
|
|
||||||
result = remove_special_characters(text)
|
|
||||||
assert "@" not in result
|
|
||||||
assert "#" not in result
|
|
||||||
assert "$" not in result
|
|
||||||
|
|
||||||
def test_preserves_allowed_chars(self):
|
|
||||||
text = "Hello, world! How's it going? Yes-no."
|
|
||||||
result = remove_special_characters(text)
|
|
||||||
assert "," in result
|
|
||||||
assert "!" in result
|
|
||||||
assert "'" in result
|
|
||||||
assert "?" in result
|
|
||||||
assert "-" in result
|
|
||||||
assert "." in result
|
|
||||||
|
|
||||||
def test_custom_pattern(self):
|
|
||||||
text = "keep @this but not #that"
|
|
||||||
result = remove_special_characters(text, special_chars=r"[#]")
|
|
||||||
assert "@this" in result
|
|
||||||
assert "#" not in result
|
|
||||||
|
|
||||||
def test_empty_string(self):
|
|
||||||
assert remove_special_characters("") == ""
|
|
||||||
|
|
||||||
|
|
||||||
class TestRemoveRepeatedSubstrings:
|
|
||||||
def test_collapses_dots(self):
|
|
||||||
text = "Item.....Value"
|
|
||||||
result = remove_repeated_substrings(text)
|
|
||||||
assert result == "Item.Value"
|
|
||||||
|
|
||||||
def test_single_dot_preserved(self):
|
|
||||||
text = "End of sentence."
|
|
||||||
result = remove_repeated_substrings(text)
|
|
||||||
assert result == "End of sentence."
|
|
||||||
|
|
||||||
def test_custom_pattern(self):
|
|
||||||
text = "hello---world"
|
|
||||||
result = remove_repeated_substrings(text, pattern=r"-{2,}")
|
|
||||||
# Function always replaces matched pattern with "."
|
|
||||||
assert result == "hello.world"
|
|
||||||
|
|
||||||
def test_empty_string(self):
|
|
||||||
assert remove_repeated_substrings("") == ""
|
|
||||||
|
|
||||||
|
|
||||||
class TestRemoveExtraSpaces:
|
|
||||||
def test_collapses_multiple_blank_lines(self):
|
|
||||||
text = "Line 1\n\n\n\nLine 2"
|
|
||||||
result = remove_extra_spaces(text)
|
|
||||||
# After collapsing newlines to \n\n, then \s+ collapses everything to single spaces
|
|
||||||
assert "\n\n\n" not in result
|
|
||||||
|
|
||||||
def test_collapses_multiple_spaces(self):
|
|
||||||
text = "Hello world"
|
|
||||||
result = remove_extra_spaces(text)
|
|
||||||
assert result == "Hello world"
|
|
||||||
|
|
||||||
def test_strips_whitespace(self):
|
|
||||||
text = " Hello world "
|
|
||||||
result = remove_extra_spaces(text)
|
|
||||||
assert result == "Hello world"
|
|
||||||
|
|
||||||
def test_empty_string(self):
|
|
||||||
assert remove_extra_spaces("") == ""
|
|
||||||
|
|
||||||
|
|
||||||
class TestPreprocessText:
|
|
||||||
def test_full_pipeline(self):
|
|
||||||
text = "Header Info\nHello @world... with spaces\nFooter Info"
|
|
||||||
result = preprocess_text(text)
|
|
||||||
assert "Header" not in result
|
|
||||||
assert "Footer" not in result
|
|
||||||
assert "@" not in result
|
|
||||||
assert "..." not in result
|
|
||||||
assert " " not in result
|
|
||||||
|
|
||||||
def test_preserves_meaningful_content(self):
|
|
||||||
text = "The cat weighs 10 pounds."
|
|
||||||
result = preprocess_text(text)
|
|
||||||
assert "cat" in result
|
|
||||||
assert "10" in result
|
|
||||||
assert "pounds" in result
|
|
||||||
|
|
||||||
def test_empty_string(self):
|
|
||||||
assert preprocess_text("") == ""
|
|
||||||
|
|
||||||
def test_already_clean(self):
|
|
||||||
text = "Simple clean text here."
|
|
||||||
result = preprocess_text(text)
|
|
||||||
assert "Simple" in result
|
|
||||||
assert "clean" in result
|
|
||||||
@@ -1,137 +0,0 @@
|
|||||||
import os
|
|
||||||
from math import ceil
|
|
||||||
import re
|
|
||||||
from typing import Union
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
|
||||||
OpenAIEmbeddingFunction,
|
|
||||||
)
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from llm import LLMClient
|
|
||||||
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
|
|
||||||
def remove_headers_footers(text, header_patterns=None, footer_patterns=None):
|
|
||||||
if header_patterns is None:
|
|
||||||
header_patterns = [r"^.*Header.*$"]
|
|
||||||
if footer_patterns is None:
|
|
||||||
footer_patterns = [r"^.*Footer.*$"]
|
|
||||||
|
|
||||||
for pattern in header_patterns + footer_patterns:
|
|
||||||
text = re.sub(pattern, "", text, flags=re.MULTILINE)
|
|
||||||
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def remove_special_characters(text, special_chars=None):
|
|
||||||
if special_chars is None:
|
|
||||||
special_chars = r"[^A-Za-z0-9\s\.,;:\'\"\?\!\-]"
|
|
||||||
|
|
||||||
text = re.sub(special_chars, "", text)
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def remove_repeated_substrings(text, pattern=r"\.{2,}"):
|
|
||||||
text = re.sub(pattern, ".", text)
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def remove_extra_spaces(text):
|
|
||||||
text = re.sub(r"\n\s*\n", "\n\n", text)
|
|
||||||
text = re.sub(r"\s+", " ", text)
|
|
||||||
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_text(text):
|
|
||||||
# Remove headers and footers
|
|
||||||
text = remove_headers_footers(text)
|
|
||||||
|
|
||||||
# Remove special characters
|
|
||||||
text = remove_special_characters(text)
|
|
||||||
|
|
||||||
# Remove repeated substrings like dots
|
|
||||||
text = remove_repeated_substrings(text)
|
|
||||||
|
|
||||||
# Remove extra spaces between lines and within lines
|
|
||||||
text = remove_extra_spaces(text)
|
|
||||||
|
|
||||||
# Additional cleaning steps can be added here
|
|
||||||
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
class Chunk:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
text: str,
|
|
||||||
size: int,
|
|
||||||
document_id: UUID,
|
|
||||||
chunk_id: int,
|
|
||||||
embedding,
|
|
||||||
):
|
|
||||||
self.text = text
|
|
||||||
self.size = size
|
|
||||||
self.document_id = document_id
|
|
||||||
self.chunk_id = chunk_id
|
|
||||||
self.embedding = embedding
|
|
||||||
|
|
||||||
|
|
||||||
class Chunker:
|
|
||||||
def __init__(self, collection) -> None:
|
|
||||||
self.collection = collection
|
|
||||||
self.llm_client = LLMClient()
|
|
||||||
|
|
||||||
def embedding_fx(self, inputs):
|
|
||||||
openai_embedding_fx = OpenAIEmbeddingFunction(
|
|
||||||
api_key=os.getenv("OPENAI_API_KEY"),
|
|
||||||
model_name="text-embedding-3-small",
|
|
||||||
)
|
|
||||||
return openai_embedding_fx(inputs)
|
|
||||||
|
|
||||||
def chunk_document(
|
|
||||||
self,
|
|
||||||
document: str,
|
|
||||||
chunk_size: int = 1000,
|
|
||||||
metadata: dict[str, Union[str, float]] = {},
|
|
||||||
) -> list[Chunk]:
|
|
||||||
doc_uuid = uuid4()
|
|
||||||
|
|
||||||
chunk_size = min(chunk_size, len(document)) or 1
|
|
||||||
|
|
||||||
chunks = []
|
|
||||||
num_chunks = ceil(len(document) / chunk_size)
|
|
||||||
document_length = len(document)
|
|
||||||
|
|
||||||
for i in range(num_chunks):
|
|
||||||
curr_pos = i * num_chunks
|
|
||||||
to_pos = (
|
|
||||||
curr_pos + chunk_size
|
|
||||||
if curr_pos + chunk_size < document_length
|
|
||||||
else document_length
|
|
||||||
)
|
|
||||||
text_chunk = self.clean_document(document[curr_pos:to_pos])
|
|
||||||
|
|
||||||
embedding = self.embedding_fx([text_chunk])
|
|
||||||
self.collection.add(
|
|
||||||
ids=[str(doc_uuid) + ":" + str(i)],
|
|
||||||
documents=[text_chunk],
|
|
||||||
embeddings=embedding,
|
|
||||||
metadatas=[metadata],
|
|
||||||
)
|
|
||||||
|
|
||||||
return chunks
|
|
||||||
|
|
||||||
def clean_document(self, document: str) -> str:
|
|
||||||
"""This function will remove information that is noise or already known.
|
|
||||||
|
|
||||||
Example: We already know all the things in here are Simba-related, so we don't need things like
|
|
||||||
"Sumamry of simba's visit"
|
|
||||||
"""
|
|
||||||
|
|
||||||
document = document.replace("\\n", "")
|
|
||||||
document = document.strip()
|
|
||||||
|
|
||||||
return preprocess_text(document)
|
|
||||||
Reference in New Issue
Block a user