Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4ac0754ea7 | |||
| bac773ae4b | |||
| 564a9b68a5 | |||
| 7742673cc0 | |||
| c157c37cde | |||
| 3b8fa3e7a0 | |||
| 438399646f | |||
| 9ed4ca126a | |||
| f3ae76ce68 | |||
| 7ee3bdef84 | |||
| 500c44feb1 | |||
| 896501deb1 | |||
| c95800e65d | |||
| 90372a6a6d | |||
| c01764243f | |||
| dfaac4caf8 | |||
| 17c3a2f888 | |||
| fa0f68e3b4 | |||
| a6c698c6bd | |||
| 3671926430 |
@@ -19,11 +19,6 @@ BASE_URL=192.168.1.5:8000
|
||||
LLAMA_SERVER_URL=http://192.168.1.213:8080/v1
|
||||
LLAMA_MODEL_NAME=llama-3.1-8b-instruct
|
||||
|
||||
# ChromaDB Configuration
|
||||
# For Docker: This is automatically set to /app/data/chromadb
|
||||
# For local development: Set to a local directory path
|
||||
CHROMADB_PATH=./data/chromadb
|
||||
|
||||
# OpenAI Configuration
|
||||
OPENAI_API_KEY=your-openai-api-key
|
||||
|
||||
|
||||
@@ -13,9 +13,6 @@ wheels/
|
||||
.env
|
||||
|
||||
# Database files
|
||||
chromadb/
|
||||
chromadb_openai/
|
||||
chroma_db/
|
||||
database/
|
||||
*.db
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
||||
|
||||
## Project Overview
|
||||
|
||||
SimbaRAG is a RAG (Retrieval-Augmented Generation) conversational AI system for querying information about Simba (a cat). It ingests documents from Paperless-NGX, stores embeddings in ChromaDB, and uses LLMs (Ollama or OpenAI) to answer questions.
|
||||
SimbaRAG is a RAG (Retrieval-Augmented Generation) conversational AI system for querying information about Simba (a cat). It ingests documents from Paperless-NGX, stores embeddings in PostgreSQL via pgvector, and uses LLMs (Ollama or OpenAI) to answer questions.
|
||||
|
||||
## Commands
|
||||
|
||||
@@ -54,9 +54,8 @@ docker compose up -d
|
||||
│ Docker Compose │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ raggr (port 8080) │ postgres (port 5432) │
|
||||
│ ├── Quart backend │ PostgreSQL 16 │
|
||||
│ ├── React frontend (served) │ │
|
||||
│ └── ChromaDB (volume) │ │
|
||||
│ ├── Quart backend │ PostgreSQL 16 + pgvector│
|
||||
│ └── React frontend (served) │ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
|
||||
+2
-3
@@ -37,15 +37,14 @@ WORKDIR /app/raggr-frontend
|
||||
RUN yarn install && yarn build
|
||||
WORKDIR /app
|
||||
|
||||
# Create ChromaDB and database directories
|
||||
RUN mkdir -p /app/chromadb /app/database
|
||||
# Create database directory
|
||||
RUN mkdir -p /app/database
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8080
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONPATH=/app
|
||||
ENV CHROMADB_PATH=/app/chromadb
|
||||
|
||||
# Run the startup script
|
||||
CMD ["./startup.sh"]
|
||||
|
||||
+2
-3
@@ -34,16 +34,15 @@ COPY . .
|
||||
WORKDIR /app/raggr-frontend
|
||||
RUN yarn build
|
||||
|
||||
# Create ChromaDB and database directories
|
||||
# Create database directory
|
||||
WORKDIR /app
|
||||
RUN mkdir -p /app/chromadb /app/database
|
||||
RUN mkdir -p /app/database
|
||||
|
||||
# Make startup script executable
|
||||
RUN chmod +x /app/startup-dev.sh
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONPATH=/app
|
||||
ENV CHROMADB_PATH=/app/chromadb
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
# Expose port
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
.PHONY: deploy build up down restart logs migrate migrate-new frontend test
|
||||
.PHONY: deploy redeploy build up down restart logs migrate migrate-new frontend test
|
||||
|
||||
# Build and deploy
|
||||
deploy: build up
|
||||
|
||||
redeploy:
|
||||
git pull && $(MAKE) down && $(MAKE) up
|
||||
|
||||
build:
|
||||
docker compose build raggr
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
from datetime import timedelta
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from quart import Quart, jsonify, render_template, request, send_from_directory
|
||||
from quart import Quart, jsonify, render_template, send_from_directory
|
||||
from quart_jwt_extended import JWTManager, get_jwt_identity, jwt_refresh_token_required
|
||||
from tortoise import Tortoise
|
||||
|
||||
@@ -15,7 +15,6 @@ import blueprints.users
|
||||
import blueprints.whatsapp
|
||||
import blueprints.users.models
|
||||
from config.db import TORTOISE_CONFIG
|
||||
from main import consult_simba_oracle
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
@@ -78,39 +77,6 @@ async def serve_react_app(path):
|
||||
return await render_template("index.html")
|
||||
|
||||
|
||||
@app.route("/api/query", methods=["POST"])
|
||||
@jwt_refresh_token_required
|
||||
async def query():
|
||||
current_user_uuid = get_jwt_identity()
|
||||
user = await blueprints.users.models.User.get(id=current_user_uuid)
|
||||
data = await request.get_json()
|
||||
query = data.get("query")
|
||||
conversation_id = data.get("conversation_id")
|
||||
conversation = await blueprints.conversation.logic.get_conversation_by_id(
|
||||
conversation_id
|
||||
)
|
||||
await conversation.fetch_related("messages")
|
||||
await blueprints.conversation.logic.add_message_to_conversation(
|
||||
conversation=conversation,
|
||||
message=query,
|
||||
speaker="user",
|
||||
user=user,
|
||||
)
|
||||
|
||||
transcript = await blueprints.conversation.logic.get_conversation_transcript(
|
||||
user=user, conversation=conversation
|
||||
)
|
||||
|
||||
response = consult_simba_oracle(input=query, transcript=transcript)
|
||||
await blueprints.conversation.logic.add_message_to_conversation(
|
||||
conversation=conversation,
|
||||
message=response,
|
||||
speaker="simba",
|
||||
user=user,
|
||||
)
|
||||
return jsonify({"response": response})
|
||||
|
||||
|
||||
@app.route("/api/messages", methods=["GET"])
|
||||
@jwt_refresh_token_required
|
||||
async def get_messages():
|
||||
@@ -135,17 +101,10 @@ async def get_messages():
|
||||
}
|
||||
)
|
||||
|
||||
name = conversation.name
|
||||
if len(messages) > 8:
|
||||
name = await blueprints.conversation.logic.rename_conversation(
|
||||
user=user,
|
||||
conversation=conversation,
|
||||
)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"id": str(conversation.id),
|
||||
"name": name,
|
||||
"name": conversation.name,
|
||||
"messages": messages,
|
||||
"created_at": conversation.created_at.isoformat(),
|
||||
"updated_at": conversation.updated_at.isoformat(),
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
@@ -20,8 +19,8 @@ from .agents import main_agent
|
||||
from .logic import (
|
||||
add_message_to_conversation,
|
||||
get_conversation_by_id,
|
||||
rename_conversation,
|
||||
)
|
||||
from .memory import get_memories_for_user
|
||||
from .models import (
|
||||
Conversation,
|
||||
PydConversation,
|
||||
@@ -36,15 +35,27 @@ conversation_blueprint = Blueprint(
|
||||
_SYSTEM_PROMPT = SIMBA_SYSTEM_PROMPT
|
||||
|
||||
|
||||
async def _build_system_prompt_with_memories(user_id: str) -> str:
|
||||
"""Append user memories to the base system prompt."""
|
||||
memories = await get_memories_for_user(user_id)
|
||||
if not memories:
|
||||
return _SYSTEM_PROMPT
|
||||
memory_block = "\n".join(f"- {m}" for m in memories)
|
||||
return f"{_SYSTEM_PROMPT}\n\nUSER MEMORIES (facts the user has asked you to remember):\n{memory_block}"
|
||||
|
||||
|
||||
def _build_messages_payload(
|
||||
conversation, query_text: str, image_description: str | None = None
|
||||
conversation,
|
||||
query_text: str,
|
||||
image_description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
) -> list:
|
||||
recent_messages = (
|
||||
conversation.messages[-10:]
|
||||
if len(conversation.messages) > 10
|
||||
else conversation.messages
|
||||
)
|
||||
messages_payload = [{"role": "system", "content": _SYSTEM_PROMPT}]
|
||||
messages_payload = [{"role": "system", "content": system_prompt or _SYSTEM_PROMPT}]
|
||||
for msg in recent_messages[:-1]: # Exclude the message we just added
|
||||
role = "user" if msg.speaker == "user" else "assistant"
|
||||
text = msg.text
|
||||
@@ -80,10 +91,14 @@ async def query():
|
||||
user=user,
|
||||
)
|
||||
|
||||
messages_payload = _build_messages_payload(conversation, query)
|
||||
system_prompt = await _build_system_prompt_with_memories(str(user.id))
|
||||
messages_payload = _build_messages_payload(
|
||||
conversation, query, system_prompt=system_prompt
|
||||
)
|
||||
payload = {"messages": messages_payload}
|
||||
agent_config = {"configurable": {"user_id": str(user.id)}}
|
||||
|
||||
response = await main_agent.ainvoke(payload)
|
||||
response = await main_agent.ainvoke(payload, config=agent_config)
|
||||
message = response.get("messages", [])[-1].content
|
||||
await add_message_to_conversation(
|
||||
conversation=conversation,
|
||||
@@ -163,15 +178,19 @@ async def stream_query():
|
||||
logging.error(f"Failed to analyze image: {e}")
|
||||
image_description = "[Image could not be analyzed]"
|
||||
|
||||
system_prompt = await _build_system_prompt_with_memories(str(user.id))
|
||||
messages_payload = _build_messages_payload(
|
||||
conversation, query_text or "", image_description
|
||||
conversation, query_text or "", image_description, system_prompt=system_prompt
|
||||
)
|
||||
payload = {"messages": messages_payload}
|
||||
agent_config = {"configurable": {"user_id": str(user.id)}}
|
||||
|
||||
async def event_generator():
|
||||
final_message = None
|
||||
try:
|
||||
async for event in main_agent.astream_events(payload, version="v2"):
|
||||
async for event in main_agent.astream_events(
|
||||
payload, version="v2", config=agent_config
|
||||
):
|
||||
event_type = event.get("event")
|
||||
|
||||
if event_type == "on_tool_start":
|
||||
@@ -221,8 +240,6 @@ async def stream_query():
|
||||
@jwt_refresh_token_required
|
||||
async def get_conversation(conversation_id: str):
|
||||
conversation = await Conversation.get(id=conversation_id)
|
||||
current_user_uuid = get_jwt_identity()
|
||||
user = await blueprints.users.models.User.get(id=current_user_uuid)
|
||||
await conversation.fetch_related("messages")
|
||||
|
||||
# Manually serialize the conversation with messages
|
||||
@@ -237,18 +254,10 @@ async def get_conversation(conversation_id: str):
|
||||
"image_key": msg.image_key,
|
||||
}
|
||||
)
|
||||
name = conversation.name
|
||||
if len(messages) > 8 and "datetime" in name.lower():
|
||||
name = await rename_conversation(
|
||||
user=user,
|
||||
conversation=conversation,
|
||||
)
|
||||
print(name)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"id": str(conversation.id),
|
||||
"name": name,
|
||||
"name": conversation.name,
|
||||
"messages": messages,
|
||||
"created_at": conversation.created_at.isoformat(),
|
||||
"updated_at": conversation.updated_at.isoformat(),
|
||||
@@ -262,7 +271,7 @@ async def create_conversation():
|
||||
user_uuid = get_jwt_identity()
|
||||
user = await blueprints.users.models.User.get(id=user_uuid)
|
||||
conversation = await Conversation.create(
|
||||
name=f"{user.username} {datetime.datetime.now().timestamp}",
|
||||
name="New Conversation",
|
||||
user=user,
|
||||
)
|
||||
|
||||
@@ -275,7 +284,7 @@ async def create_conversation():
|
||||
async def get_all_conversations():
|
||||
user_uuid = get_jwt_identity()
|
||||
user = await blueprints.users.models.User.get(id=user_uuid)
|
||||
conversations = Conversation.filter(user=user)
|
||||
conversations = Conversation.filter(user=user).order_by("-updated_at")
|
||||
serialized_conversations = await PydListConversation.from_queryset(conversations)
|
||||
|
||||
return jsonify(serialized_conversations.model_dump())
|
||||
|
||||
@@ -5,9 +5,11 @@ from dotenv import load_dotenv
|
||||
from langchain.agents import create_agent
|
||||
from langchain.chat_models import BaseChatModel
|
||||
from langchain.tools import tool
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_openai import ChatOpenAI
|
||||
from tavily import AsyncTavilyClient
|
||||
|
||||
from blueprints.conversation.memory import save_memory
|
||||
from blueprints.rag.logic import query_vector_store
|
||||
from utils.obsidian_service import ObsidianService
|
||||
from utils.ynab_service import YNABService
|
||||
@@ -326,7 +328,7 @@ async def obsidian_search_notes(query: str) -> str:
|
||||
return "Obsidian integration is not configured. Please set OBSIDIAN_VAULT_PATH environment variable."
|
||||
|
||||
try:
|
||||
# Query ChromaDB for obsidian documents
|
||||
# Query vector store for obsidian documents
|
||||
serialized, docs = await query_vector_store(query=query)
|
||||
return serialized
|
||||
|
||||
@@ -589,8 +591,35 @@ async def obsidian_create_task(
|
||||
return f"Error creating task: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
async def save_user_memory(content: str, config: RunnableConfig) -> str:
|
||||
"""Save a fact or preference about the user for future conversations.
|
||||
|
||||
Use this tool when the user:
|
||||
- Explicitly asks you to remember something ("remember that...", "keep in mind...")
|
||||
- Shares a personal preference that would be useful in future conversations
|
||||
(e.g., "I prefer metric units", "my cat's name is Luna")
|
||||
- Tells you a meaningful personal fact (e.g., "I'm allergic to peanuts")
|
||||
|
||||
Do NOT save:
|
||||
- Trivial or ephemeral info (e.g., "I'm tired today")
|
||||
- Information already in the system prompt or documents
|
||||
- Conversation-specific context that won't matter later
|
||||
|
||||
Args:
|
||||
content: A concise statement of the fact or preference to remember.
|
||||
Write it as a standalone sentence (e.g., "User prefers dark mode"
|
||||
rather than "likes dark mode").
|
||||
|
||||
Returns:
|
||||
Confirmation that the memory was saved.
|
||||
"""
|
||||
user_id = config["configurable"]["user_id"]
|
||||
return await save_memory(user_id=user_id, content=content)
|
||||
|
||||
|
||||
# Create tools list based on what's available
|
||||
tools = [get_current_date, simba_search, web_search]
|
||||
tools = [get_current_date, simba_search, web_search, save_user_memory]
|
||||
if ynab_enabled:
|
||||
tools.extend(
|
||||
[
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import tortoise.exceptions
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
import blueprints.users.models
|
||||
|
||||
from .models import Conversation, ConversationMessage, RenameConversationOutputSchema
|
||||
from .models import Conversation, ConversationMessage
|
||||
|
||||
|
||||
async def create_conversation(name: str = "") -> Conversation:
|
||||
@@ -19,6 +18,12 @@ async def add_message_to_conversation(
|
||||
image_key: str | None = None,
|
||||
) -> ConversationMessage:
|
||||
print(conversation, message, speaker)
|
||||
|
||||
# Name the conversation after the first user message
|
||||
if speaker == "user" and not await conversation.messages.all().exists():
|
||||
conversation.name = message[:100]
|
||||
await conversation.save()
|
||||
|
||||
message = await ConversationMessage.create(
|
||||
text=message,
|
||||
speaker=speaker,
|
||||
@@ -61,22 +66,3 @@ async def get_conversation_transcript(
|
||||
messages.append(f"{message.speaker} at {message.created_at}: {message.text}")
|
||||
|
||||
return "\n".join(messages)
|
||||
|
||||
|
||||
async def rename_conversation(
|
||||
user: blueprints.users.models.User,
|
||||
conversation: Conversation,
|
||||
) -> str:
|
||||
messages: str = await get_conversation_transcript(
|
||||
user=user, conversation=conversation
|
||||
)
|
||||
|
||||
llm = ChatOpenAI(model="gpt-4o-mini")
|
||||
structured_llm = llm.with_structured_output(RenameConversationOutputSchema)
|
||||
|
||||
prompt = f"Summarize the following conversation into a sassy one-liner title:\n\n{messages}"
|
||||
response = structured_llm.invoke(prompt)
|
||||
new_name: str = response.get("title", "")
|
||||
conversation.name = new_name
|
||||
await conversation.save()
|
||||
return new_name
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
from .models import UserMemory
|
||||
|
||||
|
||||
async def get_memories_for_user(user_id: str) -> list[str]:
|
||||
"""Return all memory content strings for a user, ordered by most recently updated."""
|
||||
memories = await UserMemory.filter(user_id=user_id).order_by("-updated_at")
|
||||
return [m.content for m in memories]
|
||||
|
||||
|
||||
async def save_memory(user_id: str, content: str) -> str:
|
||||
"""Save a new memory or touch an existing one (exact-match dedup)."""
|
||||
existing = await UserMemory.filter(user_id=user_id, content=content).first()
|
||||
if existing:
|
||||
existing.updated_at = None # auto_now=True will refresh it on save
|
||||
await existing.save(update_fields=["updated_at"])
|
||||
return "Memory already exists (refreshed)."
|
||||
|
||||
await UserMemory.create(user_id=user_id, content=content)
|
||||
return "Memory saved."
|
||||
@@ -1,5 +1,4 @@
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
|
||||
from tortoise import fields
|
||||
from tortoise.contrib.pydantic import (
|
||||
@@ -9,12 +8,6 @@ from tortoise.contrib.pydantic import (
|
||||
from tortoise.models import Model
|
||||
|
||||
|
||||
@dataclass
|
||||
class RenameConversationOutputSchema:
|
||||
title: str
|
||||
justification: str
|
||||
|
||||
|
||||
class Speaker(enum.Enum):
|
||||
USER = "user"
|
||||
SIMBA = "simba"
|
||||
@@ -47,6 +40,17 @@ class ConversationMessage(Model):
|
||||
table = "conversation_messages"
|
||||
|
||||
|
||||
class UserMemory(Model):
|
||||
id = fields.UUIDField(primary_key=True)
|
||||
user = fields.ForeignKeyField("models.User", related_name="memories")
|
||||
content = fields.TextField()
|
||||
created_at = fields.DatetimeField(auto_now_add=True)
|
||||
updated_at = fields.DatetimeField(auto_now=True)
|
||||
|
||||
class Meta:
|
||||
table = "user_memories"
|
||||
|
||||
|
||||
PydConversationMessage = pydantic_model_creator(ConversationMessage)
|
||||
PydConversation = pydantic_model_creator(
|
||||
Conversation, name="Conversation", allow_cycles=True, exclude=("user",)
|
||||
|
||||
@@ -54,4 +54,7 @@ You have access to Ryan's daily journal notes. Each note lives at journal/YYYY/Y
|
||||
- Use journal_get_tasks to list tasks (done/pending) for today or a specific date
|
||||
- Use journal_add_task to add a new task to today's (or a given date's) note
|
||||
- Use journal_complete_task to check off a task as done
|
||||
Use these tools when Ryan asks about today's tasks, wants to add something to his list, or wants to mark a task complete."""
|
||||
Use these tools when Ryan asks about today's tasks, wants to add something to his list, or wants to mark a task complete.
|
||||
|
||||
USER MEMORY:
|
||||
You can remember facts about the user across conversations using the save_user_memory tool. When a user explicitly asks you to remember something, or shares a meaningful preference or personal fact, save it. Saved memories will automatically appear at the end of this prompt in future conversations under "USER MEMORIES"."""
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
from quart import Blueprint, jsonify
|
||||
from quart_jwt_extended import jwt_refresh_token_required
|
||||
|
||||
from .logic import fetch_obsidian_documents, get_vector_store_stats, index_documents, index_obsidian_documents, vector_store
|
||||
from .logic import (
|
||||
delete_all_documents,
|
||||
get_vector_store_stats,
|
||||
index_documents,
|
||||
index_obsidian_documents,
|
||||
)
|
||||
from blueprints.users.decorators import admin_required
|
||||
|
||||
rag_blueprint = Blueprint("rag_api", __name__, url_prefix="/api/rag")
|
||||
@@ -32,14 +37,7 @@ async def trigger_index():
|
||||
async def trigger_reindex():
|
||||
"""Clear and reindex all documents. Admin only."""
|
||||
try:
|
||||
# Clear existing documents
|
||||
collection = vector_store._collection
|
||||
all_docs = collection.get()
|
||||
|
||||
if all_docs["ids"]:
|
||||
collection.delete(ids=all_docs["ids"])
|
||||
|
||||
# Reindex
|
||||
delete_all_documents()
|
||||
await index_documents()
|
||||
stats = get_vector_store_stats()
|
||||
return jsonify({"status": "success", "stats": stats})
|
||||
|
||||
+125
-29
@@ -1,11 +1,13 @@
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from langchain_chroma import Chroma
|
||||
from langchain_core.documents import Document
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_postgres import PGVector
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
from .fetchers import PaperlessNGXService
|
||||
from utils.obsidian_service import ObsidianService
|
||||
@@ -13,13 +15,40 @@ from utils.obsidian_service import ObsidianService
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
|
||||
|
||||
vector_store = Chroma(
|
||||
collection_name="simba_docs",
|
||||
embedding_function=embeddings,
|
||||
persist_directory=os.getenv("CHROMADB_PATH", ""),
|
||||
# Convert Tortoise-style postgres:// URL to SQLAlchemy-style postgresql+psycopg://
|
||||
_db_url = os.getenv(
|
||||
"DATABASE_URL", "postgres://raggr:raggr_dev_password@localhost:5432/raggr"
|
||||
)
|
||||
_pgvector_url = _db_url.replace("postgres://", "postgresql+psycopg://")
|
||||
|
||||
# Lazy-initialized vector store (defers DB connection to first use)
|
||||
_vector_store = None
|
||||
|
||||
|
||||
def _get_vector_store() -> PGVector:
|
||||
global _vector_store
|
||||
if _vector_store is None:
|
||||
_vector_store = PGVector(
|
||||
embeddings=embeddings,
|
||||
collection_name="simba_docs",
|
||||
connection=_pgvector_url,
|
||||
use_jsonb=True,
|
||||
create_extension=False, # created by docker init script
|
||||
async_mode=True,
|
||||
)
|
||||
return _vector_store
|
||||
|
||||
|
||||
def _get_engine():
|
||||
"""Get a SQLAlchemy engine for direct queries."""
|
||||
if not hasattr(_get_engine, "_engine"):
|
||||
_get_engine._engine = create_engine(_pgvector_url)
|
||||
return _get_engine._engine
|
||||
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=1000, # chunk size (characters)
|
||||
@@ -28,6 +57,22 @@ text_splitter = RecursiveCharacterTextSplitter(
|
||||
)
|
||||
|
||||
|
||||
def _get_collection_id():
|
||||
"""Get the UUID of our collection from the langchain_pg_collection table."""
|
||||
engine = _get_engine()
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(
|
||||
text("SELECT uuid FROM langchain_pg_collection WHERE name = :name"),
|
||||
{"name": "simba_docs"},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row[0] if row else None
|
||||
except Exception:
|
||||
# Table doesn't exist yet (first run before any indexing)
|
||||
return None
|
||||
|
||||
|
||||
def date_to_epoch(date_str: str) -> float:
|
||||
split_date = date_str.split("-")
|
||||
date = datetime.datetime(
|
||||
@@ -63,6 +108,7 @@ async def index_documents():
|
||||
documents = await fetch_documents_from_paperless_ngx()
|
||||
|
||||
splits = text_splitter.split_documents(documents)
|
||||
vector_store = _get_vector_store()
|
||||
await vector_store.aadd_documents(documents=splits)
|
||||
|
||||
|
||||
@@ -92,13 +138,17 @@ async def fetch_obsidian_documents() -> list[Document]:
|
||||
"filepath": parsed["filepath"],
|
||||
"tags": parsed["tags"],
|
||||
"created_at": parsed["metadata"].get("created_at"),
|
||||
**{k: v for k, v in parsed["metadata"].items() if k not in ["created_at", "created_by"]},
|
||||
**{
|
||||
k: v
|
||||
for k, v in parsed["metadata"].items()
|
||||
if k not in ["created_at", "created_by"]
|
||||
},
|
||||
},
|
||||
)
|
||||
documents.append(document)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading {md_path}: {e}")
|
||||
logger.warning(f"Error reading {md_path}: {e}")
|
||||
continue
|
||||
|
||||
return documents
|
||||
@@ -109,26 +159,25 @@ async def index_obsidian_documents():
|
||||
|
||||
Deletes existing obsidian source chunks before re-indexing.
|
||||
"""
|
||||
obsidian_service = ObsidianService()
|
||||
documents = await fetch_obsidian_documents()
|
||||
|
||||
if not documents:
|
||||
print("No Obsidian documents found to index")
|
||||
logger.info("No Obsidian documents found to index")
|
||||
return {"indexed": 0}
|
||||
|
||||
# Delete existing obsidian chunks
|
||||
existing_results = vector_store.get(where={"source": "obsidian"})
|
||||
if existing_results.get("ids"):
|
||||
await vector_store.adelete(existing_results["ids"])
|
||||
delete_documents_by_metadata("source", "obsidian")
|
||||
|
||||
# Split and index documents
|
||||
splits = text_splitter.split_documents(documents)
|
||||
vector_store = _get_vector_store()
|
||||
await vector_store.aadd_documents(documents=splits)
|
||||
|
||||
return {"indexed": len(documents)}
|
||||
|
||||
|
||||
async def query_vector_store(query: str):
|
||||
vector_store = _get_vector_store()
|
||||
retrieved_docs = await vector_store.asimilarity_search(query, k=2)
|
||||
serialized = "\n\n".join(
|
||||
(f"Source: {doc.metadata}\nContent: {doc.page_content}")
|
||||
@@ -137,33 +186,80 @@ async def query_vector_store(query: str):
|
||||
return serialized, retrieved_docs
|
||||
|
||||
|
||||
def delete_all_documents():
|
||||
"""Delete all documents from the vector store collection."""
|
||||
collection_id = _get_collection_id()
|
||||
if not collection_id:
|
||||
return
|
||||
engine = _get_engine()
|
||||
with engine.connect() as conn:
|
||||
conn.execute(
|
||||
text("DELETE FROM langchain_pg_embedding WHERE collection_id = :cid"),
|
||||
{"cid": collection_id},
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def delete_documents_by_metadata(key: str, value: str):
|
||||
"""Delete documents matching a metadata key/value pair."""
|
||||
collection_id = _get_collection_id()
|
||||
if not collection_id:
|
||||
return
|
||||
engine = _get_engine()
|
||||
with engine.connect() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
"DELETE FROM langchain_pg_embedding "
|
||||
"WHERE collection_id = :cid AND cmetadata->>:key = :value"
|
||||
),
|
||||
{"cid": collection_id, "key": key, "value": value},
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def get_vector_store_stats():
|
||||
"""Get statistics about the vector store."""
|
||||
collection = vector_store._collection
|
||||
count = collection.count()
|
||||
collection_id = _get_collection_id()
|
||||
count = 0
|
||||
if collection_id:
|
||||
engine = _get_engine()
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(
|
||||
text(
|
||||
"SELECT COUNT(*) FROM langchain_pg_embedding WHERE collection_id = :cid"
|
||||
),
|
||||
{"cid": collection_id},
|
||||
)
|
||||
count = result.scalar()
|
||||
return {
|
||||
"total_documents": count,
|
||||
"collection_name": collection.name,
|
||||
"collection_name": "simba_docs",
|
||||
}
|
||||
|
||||
|
||||
def list_all_documents(limit: int = 10):
|
||||
"""List documents in the vector store with their metadata."""
|
||||
collection = vector_store._collection
|
||||
results = collection.get(limit=limit, include=["metadatas", "documents"])
|
||||
collection_id = _get_collection_id()
|
||||
if not collection_id:
|
||||
return []
|
||||
|
||||
documents = []
|
||||
for i, doc_id in enumerate(results["ids"]):
|
||||
documents.append(
|
||||
{
|
||||
"id": doc_id,
|
||||
"metadata": results["metadatas"][i]
|
||||
if results.get("metadatas")
|
||||
else None,
|
||||
"content_preview": results["documents"][i][:200]
|
||||
if results.get("documents")
|
||||
else None,
|
||||
}
|
||||
engine = _get_engine()
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(
|
||||
text(
|
||||
"SELECT id, document, cmetadata FROM langchain_pg_embedding "
|
||||
"WHERE collection_id = :cid LIMIT :limit"
|
||||
),
|
||||
{"cid": collection_id, "limit": limit},
|
||||
)
|
||||
documents = []
|
||||
for row in result:
|
||||
documents.append(
|
||||
{
|
||||
"id": str(row[0]),
|
||||
"metadata": row[2],
|
||||
"content_preview": row[1][:200] if row[1] else None,
|
||||
}
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
@@ -35,7 +35,7 @@ class OIDCUserService:
|
||||
claims.get("preferred_username") or claims.get("name") or user.username
|
||||
)
|
||||
# Update LDAP groups from claims
|
||||
user.ldap_groups = claims.get("groups", [])
|
||||
user.ldap_groups = claims.get("groups") or []
|
||||
await user.save()
|
||||
return user
|
||||
|
||||
@@ -48,7 +48,7 @@ class OIDCUserService:
|
||||
user.oidc_subject = oidc_subject
|
||||
user.auth_provider = "oidc"
|
||||
user.password = None # Clear password
|
||||
user.ldap_groups = claims.get("groups", [])
|
||||
user.ldap_groups = claims.get("groups") or []
|
||||
await user.save()
|
||||
return user
|
||||
|
||||
@@ -61,7 +61,7 @@ class OIDCUserService:
|
||||
)
|
||||
|
||||
# Extract LDAP groups from claims
|
||||
groups = claims.get("groups", [])
|
||||
groups = claims.get("groups") or []
|
||||
|
||||
user = await User.create(
|
||||
id=uuid4(),
|
||||
|
||||
+2
-4
@@ -2,7 +2,7 @@ version: "3.8"
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:16-alpine
|
||||
image: pgvector/pgvector:pg16
|
||||
ports:
|
||||
- "5432:5432"
|
||||
environment:
|
||||
@@ -11,6 +11,7 @@ services:
|
||||
- POSTGRES_DB=${POSTGRES_DB:-raggr}
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
- ./docker/init-pgvector.sql:/docker-entrypoint-initdb.d/init-pgvector.sql
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-raggr}"]
|
||||
interval: 10s
|
||||
@@ -29,7 +30,6 @@ services:
|
||||
- PAPERLESS_TOKEN=${PAPERLESS_TOKEN}
|
||||
- BASE_URL=${BASE_URL}
|
||||
- OLLAMA_URL=${OLLAMA_URL:-http://localhost:11434}
|
||||
- CHROMADB_PATH=/app/data/chromadb
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
- JWT_SECRET_KEY=${JWT_SECRET_KEY}
|
||||
- LLAMA_SERVER_URL=${LLAMA_SERVER_URL}
|
||||
@@ -66,10 +66,8 @@ services:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
volumes:
|
||||
- chromadb_data:/app/data/chromadb
|
||||
- ./obvault:/app/data/obsidian
|
||||
restart: unless-stopped
|
||||
|
||||
volumes:
|
||||
chromadb_data:
|
||||
postgres_data:
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
@@ -1,278 +0,0 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import chromadb
|
||||
from utils.chunker import Chunker
|
||||
from utils.cleaner import pdf_to_image, summarize_pdf_image
|
||||
from llm import LLMClient
|
||||
from scripts.query import QueryGenerator
|
||||
from utils.request import PaperlessNGXService
|
||||
|
||||
_dotenv_loaded = load_dotenv()
|
||||
|
||||
client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
|
||||
simba_docs = client.get_or_create_collection(name="simba_docs2")
|
||||
feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup")
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="An LLM tool to query information about Simba <3"
|
||||
)
|
||||
|
||||
parser.add_argument("query", type=str, help="questions about simba's health")
|
||||
parser.add_argument(
|
||||
"--reindex", action="store_true", help="re-index the simba documents"
|
||||
)
|
||||
parser.add_argument("--classify", action="store_true", help="test classification")
|
||||
parser.add_argument("--index", help="index a file")
|
||||
|
||||
ppngx = PaperlessNGXService()
|
||||
|
||||
llm_client = LLMClient()
|
||||
|
||||
|
||||
def index_using_pdf_llm(doctypes):
|
||||
logging.info("reindex data...")
|
||||
files = ppngx.get_data()
|
||||
for file in files:
|
||||
document_id: int = file["id"]
|
||||
pdf_path = ppngx.download_pdf_from_id(id=document_id)
|
||||
image_paths = pdf_to_image(filepath=pdf_path)
|
||||
logging.info(f"summarizing {file}")
|
||||
generated_summary = summarize_pdf_image(filepaths=image_paths)
|
||||
file["content"] = generated_summary
|
||||
|
||||
chunk_data(files, simba_docs, doctypes=doctypes)
|
||||
|
||||
|
||||
def date_to_epoch(date_str: str) -> float:
|
||||
split_date = date_str.split("-")
|
||||
date = datetime.datetime(
|
||||
int(split_date[0]),
|
||||
int(split_date[1]),
|
||||
int(split_date[2]),
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
|
||||
return date.timestamp()
|
||||
|
||||
|
||||
def chunk_data(docs, collection, doctypes):
|
||||
# Step 2: Create chunks
|
||||
chunker = Chunker(collection)
|
||||
|
||||
logging.info(f"chunking {len(docs)} documents")
|
||||
texts: list[str] = [doc["content"] for doc in docs]
|
||||
with sqlite3.connect("database/visited.db") as conn:
|
||||
to_insert = []
|
||||
c = conn.cursor()
|
||||
for index, text in enumerate(texts):
|
||||
metadata = {
|
||||
"created_date": date_to_epoch(docs[index]["created_date"]),
|
||||
"filename": docs[index]["original_file_name"],
|
||||
"document_type": doctypes.get(docs[index]["document_type"], ""),
|
||||
}
|
||||
|
||||
if doctypes:
|
||||
metadata["type"] = doctypes.get(docs[index]["document_type"])
|
||||
|
||||
chunker.chunk_document(
|
||||
document=text,
|
||||
metadata=metadata,
|
||||
)
|
||||
to_insert.append((docs[index]["id"],))
|
||||
|
||||
c.executemany(
|
||||
"INSERT INTO indexed_documents (paperless_id) values (?)", to_insert
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def chunk_text(texts: list[str], collection):
|
||||
chunker = Chunker(collection)
|
||||
|
||||
for index, text in enumerate(texts):
|
||||
metadata = {}
|
||||
chunker.chunk_document(
|
||||
document=text,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
def classify_query(query: str, transcript: str) -> bool:
|
||||
logging.info("Starting query generation")
|
||||
qg_start = time.time()
|
||||
qg = QueryGenerator()
|
||||
query_type = qg.get_query_type(input=query, transcript=transcript)
|
||||
logging.info(query_type)
|
||||
qg_end = time.time()
|
||||
logging.info(f"Query generation took {qg_end - qg_start:.2f} seconds")
|
||||
return query_type == "Simba"
|
||||
|
||||
|
||||
def consult_oracle(
|
||||
input: str,
|
||||
collection,
|
||||
transcript: str = "",
|
||||
):
|
||||
chunker = Chunker(collection)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Ask
|
||||
logging.info("Starting query generation")
|
||||
qg_start = time.time()
|
||||
qg = QueryGenerator()
|
||||
doctype_query = qg.get_doctype_query(input=input)
|
||||
# metadata_filter = qg.get_query(input)
|
||||
metadata_filter = {**doctype_query}
|
||||
logging.info(metadata_filter)
|
||||
qg_end = time.time()
|
||||
logging.info(f"Query generation took {qg_end - qg_start:.2f} seconds")
|
||||
|
||||
logging.info("Starting embedding generation")
|
||||
embedding_start = time.time()
|
||||
embeddings = chunker.embedding_fx(inputs=[input])
|
||||
embedding_end = time.time()
|
||||
logging.info(
|
||||
f"Embedding generation took {embedding_end - embedding_start:.2f} seconds"
|
||||
)
|
||||
|
||||
logging.info("Starting collection query")
|
||||
query_start = time.time()
|
||||
results = collection.query(
|
||||
query_texts=[input],
|
||||
query_embeddings=embeddings,
|
||||
where=metadata_filter,
|
||||
)
|
||||
query_end = time.time()
|
||||
logging.info(f"Collection query took {query_end - query_start:.2f} seconds")
|
||||
|
||||
# Generate
|
||||
logging.info("Starting LLM generation")
|
||||
llm_start = time.time()
|
||||
system_prompt = "You are a helpful assistant that understands veterinary terms."
|
||||
transcript_prompt = f"Here is the message transcript thus far {transcript}."
|
||||
prompt = f"""Using the following data, help answer the user's query by providing as many details as possible.
|
||||
Using this data: {results}. {transcript_prompt if len(transcript) > 0 else ""}
|
||||
Respond to this prompt: {input}"""
|
||||
output = llm_client.chat(prompt=prompt, system_prompt=system_prompt)
|
||||
llm_end = time.time()
|
||||
logging.info(f"LLM generation took {llm_end - llm_start:.2f} seconds")
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logging.info(f"Total consult_oracle execution took {total_time:.2f} seconds")
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def llm_chat(input: str, transcript: str = "") -> str:
|
||||
system_prompt = "You are a helpful assistant that understands veterinary terms."
|
||||
transcript_prompt = f"Here is the message transcript thus far {transcript}."
|
||||
prompt = f"""Answer the user in as if you were a cat named Simba. Don't act too catlike. Be assertive.
|
||||
{transcript_prompt if len(transcript) > 0 else ""}
|
||||
Respond to this prompt: {input}"""
|
||||
output = llm_client.chat(prompt=prompt, system_prompt=system_prompt)
|
||||
return output
|
||||
|
||||
|
||||
def paperless_workflow(input):
|
||||
# Step 1: Get the text
|
||||
ppngx = PaperlessNGXService()
|
||||
docs = ppngx.get_data()
|
||||
|
||||
chunk_data(docs, collection=simba_docs)
|
||||
consult_oracle(input, simba_docs)
|
||||
|
||||
|
||||
def consult_simba_oracle(input: str, transcript: str = ""):
|
||||
is_simba_related = classify_query(query=input, transcript=transcript)
|
||||
|
||||
if is_simba_related:
|
||||
logging.info("Query is related to simba")
|
||||
return consult_oracle(
|
||||
input=input,
|
||||
collection=simba_docs,
|
||||
transcript=transcript,
|
||||
)
|
||||
|
||||
logging.info("Query is NOT related to simba")
|
||||
|
||||
return llm_chat(input=input, transcript=transcript)
|
||||
|
||||
|
||||
def filter_indexed_files(docs):
|
||||
with sqlite3.connect("database/visited.db") as conn:
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
"CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)"
|
||||
)
|
||||
c.execute("SELECT paperless_id FROM indexed_documents")
|
||||
rows = c.fetchall()
|
||||
conn.commit()
|
||||
|
||||
visited = {row[0] for row in rows}
|
||||
return [doc for doc in docs if doc["id"] not in visited]
|
||||
|
||||
|
||||
def reindex():
|
||||
with sqlite3.connect("database/visited.db") as conn:
|
||||
c = conn.cursor()
|
||||
# Ensure the table exists before trying to delete from it
|
||||
c.execute(
|
||||
"CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)"
|
||||
)
|
||||
c.execute("DELETE FROM indexed_documents")
|
||||
conn.commit()
|
||||
|
||||
# Delete all documents from the collection
|
||||
all_docs = simba_docs.get()
|
||||
if all_docs["ids"]:
|
||||
simba_docs.delete(ids=all_docs["ids"])
|
||||
|
||||
logging.info("Fetching documents from Paperless-NGX")
|
||||
ppngx = PaperlessNGXService()
|
||||
docs = ppngx.get_data()
|
||||
docs = filter_indexed_files(docs)
|
||||
logging.info(f"Fetched {len(docs)} documents")
|
||||
|
||||
# Delete all chromadb data
|
||||
ids = simba_docs.get(ids=None, limit=None, offset=0)
|
||||
all_ids = ids["ids"]
|
||||
if len(all_ids) > 0:
|
||||
simba_docs.delete(ids=all_ids)
|
||||
|
||||
# Chunk documents
|
||||
logging.info("Chunking documents now ...")
|
||||
doctype_lookup = ppngx.get_doctypes()
|
||||
chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
|
||||
logging.info("Done chunking documents")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
if args.reindex:
|
||||
reindex()
|
||||
|
||||
if args.classify:
|
||||
consult_simba_oracle(input="yohohoho testing")
|
||||
consult_simba_oracle(input="write an email")
|
||||
consult_simba_oracle(input="how much does simba weigh")
|
||||
|
||||
if args.query:
|
||||
logging.info("Consulting oracle ...")
|
||||
print(
|
||||
consult_oracle(
|
||||
input=args.query,
|
||||
collection=simba_docs,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("please provide a query")
|
||||
@@ -0,0 +1,112 @@
|
||||
from tortoise import BaseDBAsyncClient
|
||||
|
||||
RUN_IN_TRANSACTION = True
|
||||
|
||||
|
||||
async def upgrade(db: BaseDBAsyncClient) -> str:
|
||||
return """
|
||||
CREATE TABLE IF NOT EXISTS "user_memories" (
|
||||
"id" UUID NOT NULL PRIMARY KEY,
|
||||
"content" TEXT NOT NULL,
|
||||
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"user_id" UUID NOT NULL REFERENCES "users" ("id") ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS "email_accounts" (
|
||||
"id" UUID NOT NULL PRIMARY KEY,
|
||||
"email_address" VARCHAR(255) NOT NULL UNIQUE,
|
||||
"display_name" VARCHAR(255),
|
||||
"imap_host" VARCHAR(255) NOT NULL,
|
||||
"imap_port" INT NOT NULL DEFAULT 993,
|
||||
"imap_username" VARCHAR(255) NOT NULL,
|
||||
"imap_password" TEXT NOT NULL,
|
||||
"is_active" BOOL NOT NULL DEFAULT True,
|
||||
"last_error" TEXT,
|
||||
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"user_id" UUID NOT NULL REFERENCES "users" ("id") ON DELETE CASCADE
|
||||
);
|
||||
COMMENT ON TABLE "email_accounts" IS 'Email account configuration for IMAP connections.';
|
||||
CREATE TABLE IF NOT EXISTS "emails" (
|
||||
"id" UUID NOT NULL PRIMARY KEY,
|
||||
"message_id" VARCHAR(255) NOT NULL UNIQUE,
|
||||
"subject" VARCHAR(500) NOT NULL,
|
||||
"from_address" VARCHAR(255) NOT NULL,
|
||||
"to_address" TEXT NOT NULL,
|
||||
"date" TIMESTAMPTZ NOT NULL,
|
||||
"body_text" TEXT,
|
||||
"body_html" TEXT,
|
||||
"chromadb_doc_id" VARCHAR(255),
|
||||
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"expires_at" TIMESTAMPTZ NOT NULL,
|
||||
"account_id" UUID NOT NULL REFERENCES "email_accounts" ("id") ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS "idx_emails_message_981ddd" ON "emails" ("message_id");
|
||||
COMMENT ON TABLE "emails" IS 'Email message metadata and content.';
|
||||
CREATE TABLE IF NOT EXISTS "email_sync_status" (
|
||||
"id" UUID NOT NULL PRIMARY KEY,
|
||||
"last_sync_date" TIMESTAMPTZ,
|
||||
"last_message_uid" INT NOT NULL DEFAULT 0,
|
||||
"message_count" INT NOT NULL DEFAULT 0,
|
||||
"consecutive_failures" INT NOT NULL DEFAULT 0,
|
||||
"last_failure_date" TIMESTAMPTZ,
|
||||
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"account_id" UUID NOT NULL REFERENCES "email_accounts" ("id") ON DELETE CASCADE
|
||||
);
|
||||
COMMENT ON TABLE "email_sync_status" IS 'Tracks sync progress and state per email account.';"""
|
||||
|
||||
|
||||
async def downgrade(db: BaseDBAsyncClient) -> str:
|
||||
return """
|
||||
DROP TABLE IF EXISTS "user_memories";
|
||||
DROP TABLE IF EXISTS "email_accounts";
|
||||
DROP TABLE IF EXISTS "emails";
|
||||
DROP TABLE IF EXISTS "email_sync_status";"""
|
||||
|
||||
|
||||
MODELS_STATE = (
|
||||
"eJztXGtv2zYU/SuCPrVAFjTPbcUwwE7czVudDLGz9ZFCoCXa1ixRGkk1NYr+911Skq0HZV"
|
||||
"t+RUr1oU1C8lLU4SV57tGVvuquZ2GHHV955DOmDHHbI/pr7atOkIvhF2X9kaYj31/UigKO"
|
||||
"ho40MBMtZQ0aMk6RyaFyhByGocjCzKS2H12MBI4jCj0TGtpkvCgKiP1fgA3ujTGfYAoVHz"
|
||||
"9BsU0s/AWz+E9/aoxs7FipcduWuLYsN/jMl2X3993rN7KluNzQMD0ncMmitT/jE4/MmweB"
|
||||
"bR0LG1E3xgRTxLGVuA0xyui246JwxFDAaYDnQ7UWBRYeocARYOi/jAJiCgw0eSXx3/mveg"
|
||||
"l4AGoBrU24wOLrt/CuFvcsS3VxqavfW3cvzi5fyrv0GB9TWSkR0b9JQ8RRaCpxXQApf+ag"
|
||||
"vJogqoYybp8BEwa6CYxxwQLHhQ/FQMYAbYaa7qIvhoPJmE/gz9OLiyUw/t26k0hCKwmlB3"
|
||||
"4dev1NVHUa1glIFxCaFItbNhDPA3kNNdx2sRrMtGUGUisyPY5/qSjAcA/WLXFm0SJYgu+g"
|
||||
"2+v0B63eX+JOXMb+cyRErUFH1JzK0lmm9MVlZirmnWj/dAe/a+JP7cPtTSfr+/N2gw+6GB"
|
||||
"MKuGcQ79FAVmK9xqUxMKmJDXxrw4lNWzYT+6QTGw0+Ma8MU6PcCZIw2eIYicZ2wEnc/NAQ"
|
||||
"R+9oqjwzBBh58N54FNtj8ieeSQi7MA5ETNVhEZGO+6ibqoK2KF2MgqLHORtJOgXcHdwT5u"
|
||||
"Hp2epfta47usRwiMzpI6KWUQCmixlDY8zygLYjyzd/3mFnTs3UWCYJXC/ssZq7ShG2Eivv"
|
||||
"1EtglEIvX+WeutkSROC+reja4kpL0FnBghMgrkeGjeRENqS41qSY4y+KI38ApWoo4/Z1Ic"
|
||||
"XLjvLOu0HqFI+p74te693L1En+9vbmt7h5gipfvb1tNwz5ORKpPENmPkZTFRkQAWSHBG6O"
|
||||
"CqRmN2H+xEtHv+937l5r4kR/IP1ur916rTHbHSJ9vSlORZknr9YIMk9eFcaYoiq9gGwXTh"
|
||||
"ZjimdlQvWU0Ub4Hp56pYG8ODldA0loVQilrMtsRslDu9yRqTDd5flZ03DAzIiHW4YFWS2y"
|
||||
"siiujA8U7lI2TtgnKxbxVw+7Hp3pCjKcqF3KgWUQ5IqGdsN9nwH3hYtwTErR34RJw4AbBv"
|
||||
"xdMeBGI34WE1sdjbham2FdROIKs8AtVOJ9s78i3rea8TVMr/5MT8xj2cf/SZu6cL0DpAD4"
|
||||
"iLFHjyo8s20TRGdqMJNWGTCHMx5GU5VTaJaA1xa8N3m6A2Tt7k3r7r2aOsftk37bfj/otD"
|
||||
"LoYhfZThkvnRvsxkVXr/hdOujJq/Xkw2X6YU5AfJwgzmBLN0jgDosEWzWYCtOdiImHRfVs"
|
||||
"HVDPijE9y0EqnczARNyeauF7noMRWeKgSdvs8gfjfW2mZY/qEuv/9vZtav23u9nQ+L7X7o"
|
||||
"DzSpihkR1Soe7NQAnuxEUmcIQpVuiKK1Z/xraGHntyuc42kI2QErvAZdZjPdsyDRYM/8Wm"
|
||||
"IlotBjRrV0Mw93LqQ/w4MXzqfbatcltqzvBwVEp3PBM5W3DRzBOadbbVi+Jt9SK3rToW8o"
|
||||
"0x9QJfkRLzR//2Rg1pxiwD6D2Bu/xo2SY/0hyb8U97g/fjp/3wfHHny1XJrACZIVaig0aV"
|
||||
"fJbiVaNKPtOJnSfG5VShVVmFudc0dpNaWOWINJ9SmFwRySeUm2ORfihaPc9fC4qQHyPT9A"
|
||||
"JhthUgHdFXK+yqZpDsU1yVsOgKdbUTKxPF8qqcnvX0VV12p0WZp/CTI6H2aYhYWvRQ9ljP"
|
||||
"oLSOzQN5IH3uwbal+YgybGlyUJps+GjzCYTTP1hoplEs2sNgjrW3NpkyjXva1YR6LrpuP5"
|
||||
"CRR7XPEDPAD4YRNSeaiXw0tCHsg5UoR9bowDvih1vowJErKJ92Fccwaas6Cm17iQk3CK+3"
|
||||
"jayfXFK/WEuxvFiiWF7kFcsR7CKCGIEfK86oYjSzdvWEdC++CbyyENAlye1eHeE8dIKPiI"
|
||||
"jKxlqxTT2jrJpEVfFtL42Xh541M8q+9ZEyqkl+9aGXhcRowl3F47sVwMZGDbDqhELJsuGK"
|
||||
"MLSSzE1hWhOQm5f5G+VsU0kUf/Ft6G2DiU1b1nNiazKRax3WkXJVMjszbdUkaMaA5DEsna"
|
||||
"NZXxHwKJOrmXaSKqVrpjAuEhYTc7BCX0zJv+vqjJGNkAlH9jigUhvWhMrX7bX+EsUES7mL"
|
||||
"FamOJXpIaJBzK4otITfCGEMVEhOTznxwNC1OpTtK9KExzDlcnh09EKFuxt2AO/OAHWv9wP"
|
||||
"c9ypnmgovZvoPjFkzzMZXvgjYaZUU0yshpy8tBOcNGqYwVC5v5DpoZZVOAs3ZN7JB4S9s3"
|
||||
"JuDYZeBMGdVFXTsUmGJ/zoPZJQXCQcomg6W9P27y889nW0Apv0Xzw+nJ+Y/nP51dnv8ETe"
|
||||
"RQ5iU/LgE3nzkpMdgktT9n2DhjxhkLk/w7MQ8p1rRyPdQF3UMLWzYE2sBHPit8d2lKdcru"
|
||||
"gOnU84O/wtnUDmLcwJR6iizVYpdNW9XkmG9e7G70wiaFspnY5sXu5sXu6r7YnRE2dpGEWS"
|
||||
"8s0zlTM2IaoSq3AyD60Ft/3lmNINm7fJxApkhBToO3SkTOTNxqHXkA1VOmCTvNp57YdJjM"
|
||||
"PBWdYCm74qRQnNeRS/cgdOSegBn+MU1w2tBYnL1g4/pHYSF0ZkJf2JqnxsJOGCnHI+gwoF"
|
||||
"gTdzeFgYg0Vxaqx5oNsR92MfTvhB0LA8matQn86kDzRkWuiIosIxrptJuka+Wtd8DtqhUh"
|
||||
"VYjKrfUoWE5JnIkcqNZJoVaoMj2cZPhqC2q+Y8EwxqDgaXAhgDm77xI90T02AyE8GdExoS"
|
||||
"AxhSAWmX+XWMolGaGw+Q6d7aDZpJ94k26UlOeppDR5WM8sD2vfSQ31z8JqYWqbE10RPUc1"
|
||||
"R8uCZrRoU5kv5xU/S1+TEUcT+KTJMjvhIcVho3j9Xflt8+KH6QmTujzoPcQHc2BplAAxal"
|
||||
"5PAPfyGbfCj3MXfxin+OPcB/sozt4O3Z19FKfENzZ2f7x8+x8fHBMe"
|
||||
)
|
||||
+2
-2
@@ -5,7 +5,8 @@ description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"chromadb>=1.1.0",
|
||||
"langchain-postgres>=0.0.13",
|
||||
"psycopg[binary]>=3.1.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
"flask>=3.1.2",
|
||||
"httpx>=0.28.1",
|
||||
@@ -30,7 +31,6 @@ dependencies = [
|
||||
"asyncpg>=0.30.0",
|
||||
"langchain-openai>=1.1.6",
|
||||
"langchain>=1.2.0",
|
||||
"langchain-chroma>=1.0.0",
|
||||
"langchain-community>=0.4.1",
|
||||
"jq>=1.10.0",
|
||||
"tavily-python>=0.7.17",
|
||||
|
||||
@@ -1,48 +1,13 @@
|
||||
import { useState, useEffect } from "react";
|
||||
|
||||
import "./App.css";
|
||||
import { AuthProvider } from "./contexts/AuthContext";
|
||||
import { ChatScreen } from "./components/ChatScreen";
|
||||
import { LoginScreen } from "./components/LoginScreen";
|
||||
import { conversationService } from "./api/conversationService";
|
||||
import { useAuthCheck } from "./hooks/useAuthCheck";
|
||||
import catIcon from "./assets/cat.png";
|
||||
|
||||
const AppContainer = () => {
|
||||
const [isAuthenticated, setAuthenticated] = useState<boolean>(false);
|
||||
const [isChecking, setIsChecking] = useState<boolean>(true);
|
||||
const { isAuthenticated, isChecking, isAdmin, setAuthenticated } = useAuthCheck();
|
||||
|
||||
useEffect(() => {
|
||||
const checkAuth = async () => {
|
||||
const accessToken = localStorage.getItem("access_token");
|
||||
const refreshToken = localStorage.getItem("refresh_token");
|
||||
|
||||
// No tokens at all, not authenticated
|
||||
if (!accessToken && !refreshToken) {
|
||||
setIsChecking(false);
|
||||
setAuthenticated(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to verify token by making a request
|
||||
try {
|
||||
await conversationService.getAllConversations();
|
||||
// If successful, user is authenticated
|
||||
setAuthenticated(true);
|
||||
} catch (error) {
|
||||
// Token is invalid or expired
|
||||
console.error("Authentication check failed:", error);
|
||||
localStorage.removeItem("access_token");
|
||||
localStorage.removeItem("refresh_token");
|
||||
setAuthenticated(false);
|
||||
} finally {
|
||||
setIsChecking(false);
|
||||
}
|
||||
};
|
||||
|
||||
checkAuth();
|
||||
}, []);
|
||||
|
||||
// Show loading state while checking authentication
|
||||
if (isChecking) {
|
||||
return (
|
||||
<div className="h-screen flex flex-col items-center justify-center bg-cream gap-4">
|
||||
@@ -61,7 +26,7 @@ const AppContainer = () => {
|
||||
return (
|
||||
<>
|
||||
{isAuthenticated ? (
|
||||
<ChatScreen setAuthenticated={setAuthenticated} />
|
||||
<ChatScreen setAuthenticated={setAuthenticated} isAdmin={isAdmin} />
|
||||
) : (
|
||||
<LoginScreen setAuthenticated={setAuthenticated} />
|
||||
)}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import { useState } from "react";
|
||||
import { X, Phone, PhoneOff, Pencil, Check, Mail, Copy } from "lucide-react";
|
||||
import { userService, type AdminUserRecord } from "../api/userService";
|
||||
import { cn } from "../lib/utils";
|
||||
@@ -12,27 +12,19 @@ import {
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "./ui/table";
|
||||
import { useAdminUsers } from "../hooks/useAdminUsers";
|
||||
|
||||
type Props = {
|
||||
onClose: () => void;
|
||||
};
|
||||
|
||||
export const AdminPanel = ({ onClose }: Props) => {
|
||||
const [users, setUsers] = useState<AdminUserRecord[]>([]);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const { users, loading, updateUser } = useAdminUsers();
|
||||
const [editingId, setEditingId] = useState<string | null>(null);
|
||||
const [editValue, setEditValue] = useState("");
|
||||
const [rowError, setRowError] = useState<Record<string, string>>({});
|
||||
const [rowSuccess, setRowSuccess] = useState<Record<string, string>>({});
|
||||
|
||||
useEffect(() => {
|
||||
userService
|
||||
.adminListUsers()
|
||||
.then(setUsers)
|
||||
.catch(() => {})
|
||||
.finally(() => setLoading(false));
|
||||
}, []);
|
||||
|
||||
const startEdit = (user: AdminUserRecord) => {
|
||||
setEditingId(user.id);
|
||||
setEditValue(user.whatsapp_number ?? "");
|
||||
@@ -49,8 +41,8 @@ export const AdminPanel = ({ onClose }: Props) => {
|
||||
setRowError((p) => ({ ...p, [userId]: "" }));
|
||||
try {
|
||||
const updated = await userService.adminSetWhatsapp(userId, editValue);
|
||||
setUsers((p) => p.map((u) => (u.id === userId ? updated : u)));
|
||||
setRowSuccess((p) => ({ ...p, [userId]: "Saved ✓" }));
|
||||
updateUser(userId, () => updated);
|
||||
setRowSuccess((p) => ({ ...p, [userId]: "Saved" }));
|
||||
setEditingId(null);
|
||||
setTimeout(() => setRowSuccess((p) => ({ ...p, [userId]: "" })), 2000);
|
||||
} catch (err) {
|
||||
@@ -65,10 +57,8 @@ export const AdminPanel = ({ onClose }: Props) => {
|
||||
setRowError((p) => ({ ...p, [userId]: "" }));
|
||||
try {
|
||||
await userService.adminUnlinkWhatsapp(userId);
|
||||
setUsers((p) =>
|
||||
p.map((u) => (u.id === userId ? { ...u, whatsapp_number: null } : u)),
|
||||
);
|
||||
setRowSuccess((p) => ({ ...p, [userId]: "Unlinked ✓" }));
|
||||
updateUser(userId, (u) => ({ ...u, whatsapp_number: null }));
|
||||
setRowSuccess((p) => ({ ...p, [userId]: "Unlinked" }));
|
||||
setTimeout(() => setRowSuccess((p) => ({ ...p, [userId]: "" })), 2000);
|
||||
} catch (err) {
|
||||
setRowError((p) => ({
|
||||
@@ -82,8 +72,8 @@ export const AdminPanel = ({ onClose }: Props) => {
|
||||
setRowError((p) => ({ ...p, [userId]: "" }));
|
||||
try {
|
||||
const updated = await userService.adminToggleEmail(userId);
|
||||
setUsers((p) => p.map((u) => (u.id === userId ? updated : u)));
|
||||
setRowSuccess((p) => ({ ...p, [userId]: "Email enabled ✓" }));
|
||||
updateUser(userId, () => updated);
|
||||
setRowSuccess((p) => ({ ...p, [userId]: "Email enabled" }));
|
||||
setTimeout(() => setRowSuccess((p) => ({ ...p, [userId]: "" })), 2000);
|
||||
} catch (err) {
|
||||
setRowError((p) => ({
|
||||
@@ -97,10 +87,8 @@ export const AdminPanel = ({ onClose }: Props) => {
|
||||
setRowError((p) => ({ ...p, [userId]: "" }));
|
||||
try {
|
||||
await userService.adminDisableEmail(userId);
|
||||
setUsers((p) =>
|
||||
p.map((u) => (u.id === userId ? { ...u, email_enabled: false, email_address: null } : u)),
|
||||
);
|
||||
setRowSuccess((p) => ({ ...p, [userId]: "Email disabled ✓" }));
|
||||
updateUser(userId, (u) => ({ ...u, email_enabled: false, email_address: null }));
|
||||
setRowSuccess((p) => ({ ...p, [userId]: "Email disabled" }));
|
||||
setTimeout(() => setRowSuccess((p) => ({ ...p, [userId]: "" })), 2000);
|
||||
} catch (err) {
|
||||
setRowError((p) => ({
|
||||
@@ -112,7 +100,7 @@ export const AdminPanel = ({ onClose }: Props) => {
|
||||
|
||||
const copyToClipboard = (text: string, userId: string) => {
|
||||
navigator.clipboard.writeText(text);
|
||||
setRowSuccess((p) => ({ ...p, [userId]: "Copied ✓" }));
|
||||
setRowSuccess((p) => ({ ...p, [userId]: "Copied" }));
|
||||
setTimeout(() => setRowSuccess((p) => ({ ...p, [userId]: "" })), 2000);
|
||||
};
|
||||
|
||||
@@ -128,7 +116,6 @@ export const AdminPanel = ({ onClose }: Props) => {
|
||||
"border border-sand-light/60",
|
||||
)}
|
||||
>
|
||||
{/* Header */}
|
||||
<div className="flex items-center justify-between px-6 py-4 border-b border-sand-light/60">
|
||||
<div className="flex items-center gap-2.5">
|
||||
<div className="w-8 h-8 rounded-xl bg-leaf-pale flex items-center justify-center">
|
||||
@@ -146,7 +133,6 @@ export const AdminPanel = ({ onClose }: Props) => {
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Body */}
|
||||
<div className="overflow-y-auto flex-1 rounded-b-3xl">
|
||||
{loading ? (
|
||||
<div className="px-6 py-12 text-center text-warm-gray text-sm">
|
||||
@@ -155,7 +141,7 @@ export const AdminPanel = ({ onClose }: Props) => {
|
||||
<span className="loading-dot w-2 h-2 rounded-full bg-amber-soft inline-block" />
|
||||
<span className="loading-dot w-2 h-2 rounded-full bg-amber-soft inline-block" />
|
||||
</div>
|
||||
Loading users…
|
||||
Loading users...
|
||||
</div>
|
||||
) : (
|
||||
<Table>
|
||||
@@ -204,7 +190,7 @@ export const AdminPanel = ({ onClose }: Props) => {
|
||||
: "text-warm-gray/40 italic",
|
||||
)}
|
||||
>
|
||||
{user.whatsapp_number ?? "—"}
|
||||
{user.whatsapp_number ?? "\u2014"}
|
||||
</span>
|
||||
{rowSuccess[user.id] && (
|
||||
<span className="text-xs text-leaf-dark">
|
||||
@@ -235,7 +221,7 @@ export const AdminPanel = ({ onClose }: Props) => {
|
||||
</button>
|
||||
</div>
|
||||
) : (
|
||||
<span className="text-sm text-warm-gray/40 italic">—</span>
|
||||
<span className="text-sm text-warm-gray/40 italic">\u2014</span>
|
||||
)}
|
||||
</div>
|
||||
</TableCell>
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import React from "react";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import { cn } from "../lib/utils";
|
||||
|
||||
@@ -6,7 +7,7 @@ type AnswerBubbleProps = {
|
||||
loading?: boolean;
|
||||
};
|
||||
|
||||
export const AnswerBubble = ({ text, loading }: AnswerBubbleProps) => {
|
||||
export const AnswerBubble = React.memo(({ text, loading }: AnswerBubbleProps) => {
|
||||
return (
|
||||
<div className="flex justify-start message-enter">
|
||||
<div
|
||||
@@ -17,7 +18,6 @@ export const AnswerBubble = ({ text, loading }: AnswerBubbleProps) => {
|
||||
"overflow-hidden",
|
||||
)}
|
||||
>
|
||||
{/* amber accent bar */}
|
||||
<div className="h-0.5 w-full bg-gradient-to-r from-amber-soft via-amber-glow/50 to-transparent" />
|
||||
|
||||
<div className="px-4 py-3">
|
||||
@@ -36,4 +36,4 @@ export const AnswerBubble = ({ text, loading }: AnswerBubbleProps) => {
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import { useCallback, useEffect, useState, useRef } from "react";
|
||||
import { useCallback, useState, useRef } from "react";
|
||||
import { LogOut, Shield, PanelLeftClose, PanelLeftOpen, Menu, X } from "lucide-react";
|
||||
import { conversationService } from "../api/conversationService";
|
||||
import { userService } from "../api/userService";
|
||||
import { QuestionBubble } from "./QuestionBubble";
|
||||
import { AnswerBubble } from "./AnswerBubble";
|
||||
import { ToolBubble } from "./ToolBubble";
|
||||
@@ -9,216 +7,79 @@ import { MessageInput } from "./MessageInput";
|
||||
import { ConversationList } from "./ConversationList";
|
||||
import { AdminPanel } from "./AdminPanel";
|
||||
import { cn } from "../lib/utils";
|
||||
import { useConversations } from "../hooks/useConversations";
|
||||
import { useChat } from "../hooks/useChat";
|
||||
import catIcon from "../assets/cat.png";
|
||||
|
||||
type Message = {
|
||||
text: string;
|
||||
speaker: "simba" | "user" | "tool";
|
||||
image_key?: string | null;
|
||||
};
|
||||
|
||||
type Conversation = {
|
||||
title: string;
|
||||
id: string;
|
||||
};
|
||||
|
||||
type ChatScreenProps = {
|
||||
setAuthenticated: (isAuth: boolean) => void;
|
||||
isAdmin: boolean;
|
||||
};
|
||||
|
||||
const TOOL_MESSAGES: Record<string, string> = {
|
||||
simba_search: "🔍 Searching Simba's records...",
|
||||
web_search: "🌐 Searching the web...",
|
||||
get_current_date: "📅 Checking today's date...",
|
||||
ynab_budget_summary: "💰 Checking budget summary...",
|
||||
ynab_search_transactions: "💳 Looking up transactions...",
|
||||
ynab_category_spending: "📊 Analyzing category spending...",
|
||||
ynab_insights: "📈 Generating budget insights...",
|
||||
obsidian_search_notes: "📝 Searching notes...",
|
||||
obsidian_read_note: "📖 Reading note...",
|
||||
obsidian_create_note: "✏️ Saving note...",
|
||||
obsidian_create_task: "✅ Creating task...",
|
||||
journal_get_today: "📔 Reading today's journal...",
|
||||
journal_get_tasks: "📋 Getting tasks...",
|
||||
journal_add_task: "➕ Adding task...",
|
||||
journal_complete_task: "✔️ Completing task...",
|
||||
};
|
||||
|
||||
export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
const [query, setQuery] = useState<string>("");
|
||||
const [simbaMode, setSimbaMode] = useState<boolean>(false);
|
||||
const [messages, setMessages] = useState<Message[]>([]);
|
||||
const [conversations, setConversations] = useState<Conversation[]>([]);
|
||||
const [showConversations, setShowConversations] = useState<boolean>(false);
|
||||
const [selectedConversation, setSelectedConversation] =
|
||||
useState<Conversation | null>(null);
|
||||
const [sidebarCollapsed, setSidebarCollapsed] = useState<boolean>(false);
|
||||
const [isLoading, setIsLoading] = useState<boolean>(false);
|
||||
const [isAdmin, setIsAdmin] = useState<boolean>(false);
|
||||
const [showAdminPanel, setShowAdminPanel] = useState<boolean>(false);
|
||||
const [pendingImage, setPendingImage] = useState<File | null>(null);
|
||||
export const ChatScreen = ({ setAuthenticated, isAdmin }: ChatScreenProps) => {
|
||||
const [query, setQuery] = useState("");
|
||||
const [simbaMode, setSimbaMode] = useState(false);
|
||||
const [showConversations, setShowConversations] = useState(false);
|
||||
const [sidebarCollapsed, setSidebarCollapsed] = useState(false);
|
||||
const [showAdminPanel, setShowAdminPanel] = useState(false);
|
||||
|
||||
const messagesEndRef = useRef<HTMLDivElement>(null);
|
||||
const isMountedRef = useRef<boolean>(true);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
const simbaAnswers = ["meow.", "hiss...", "purrrrrr", "yowOWROWWowowr"];
|
||||
const isLoadingRef = useRef(false);
|
||||
|
||||
const scrollToBottom = useCallback(() => {
|
||||
requestAnimationFrame(() => {
|
||||
messagesEndRef.current?.scrollIntoView({
|
||||
behavior: isLoading ? "instant" : "smooth",
|
||||
behavior: isLoadingRef.current ? "instant" : "smooth",
|
||||
});
|
||||
});
|
||||
}, [isLoading]);
|
||||
|
||||
useEffect(() => {
|
||||
isMountedRef.current = true;
|
||||
return () => {
|
||||
isMountedRef.current = false;
|
||||
abortControllerRef.current?.abort();
|
||||
};
|
||||
}, []);
|
||||
|
||||
const handleSelectConversation = (conversation: Conversation) => {
|
||||
setShowConversations(false);
|
||||
setSelectedConversation(conversation);
|
||||
const load = async () => {
|
||||
try {
|
||||
const fetched = await conversationService.getConversation(conversation.id);
|
||||
setMessages(
|
||||
fetched.messages.map((m) => ({ text: m.text, speaker: m.speaker, image_key: m.image_key })),
|
||||
);
|
||||
} catch (err) {
|
||||
console.error("Failed to load messages:", err);
|
||||
}
|
||||
};
|
||||
load();
|
||||
};
|
||||
const {
|
||||
conversations,
|
||||
selectedConversation,
|
||||
selectConversation,
|
||||
createConversation,
|
||||
refreshConversations,
|
||||
} = useConversations();
|
||||
|
||||
const loadConversations = async () => {
|
||||
try {
|
||||
const fetched = await conversationService.getAllConversations();
|
||||
const parsed = fetched.map((c) => ({ id: c.id, title: c.name }));
|
||||
setConversations(parsed);
|
||||
} catch (err) {
|
||||
console.error("Failed to load conversations:", err);
|
||||
}
|
||||
};
|
||||
const onSessionExpired = useCallback(() => setAuthenticated(false), [setAuthenticated]);
|
||||
|
||||
const handleCreateNewConversation = async () => {
|
||||
const newConv = await conversationService.createConversation();
|
||||
await loadConversations();
|
||||
setSelectedConversation({ title: newConv.name, id: newConv.id });
|
||||
};
|
||||
const {
|
||||
messages,
|
||||
setMessages,
|
||||
isLoading,
|
||||
pendingImage,
|
||||
setPendingImage,
|
||||
sendMessage,
|
||||
} = useChat({
|
||||
selectedConversation,
|
||||
createConversation,
|
||||
refreshConversations,
|
||||
onSessionExpired,
|
||||
scrollToBottom,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
loadConversations();
|
||||
userService.getMe().then((me) => setIsAdmin(me.is_admin)).catch(() => {});
|
||||
}, []);
|
||||
// Keep ref in sync for scrollToBottom behavior
|
||||
isLoadingRef.current = isLoading;
|
||||
|
||||
useEffect(() => {
|
||||
scrollToBottom();
|
||||
}, [messages]);
|
||||
const handleSelectConversation = useCallback(
|
||||
async (conversation: { title: string; id: string }) => {
|
||||
setShowConversations(false);
|
||||
const loaded = await selectConversation(conversation);
|
||||
setMessages(loaded);
|
||||
},
|
||||
[selectConversation, setMessages],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const load = async () => {
|
||||
if (!selectedConversation) return;
|
||||
try {
|
||||
const conv = await conversationService.getConversation(selectedConversation.id);
|
||||
setSelectedConversation({ id: conv.id, title: conv.name });
|
||||
setMessages(conv.messages.map((m) => ({ text: m.text, speaker: m.speaker, image_key: m.image_key })));
|
||||
} catch (err) {
|
||||
console.error("Failed to load messages:", err);
|
||||
}
|
||||
};
|
||||
load();
|
||||
}, [selectedConversation?.id]);
|
||||
const handleCreateNewConversation = useCallback(async () => {
|
||||
await createConversation();
|
||||
setMessages([]);
|
||||
}, [createConversation, setMessages]);
|
||||
|
||||
const handleQuestionSubmit = useCallback(async () => {
|
||||
if ((!query.trim() && !pendingImage) || isLoading) return;
|
||||
|
||||
let activeConversation = selectedConversation;
|
||||
if (!activeConversation) {
|
||||
const newConv = await conversationService.createConversation();
|
||||
activeConversation = { title: newConv.name, id: newConv.id };
|
||||
setSelectedConversation(activeConversation);
|
||||
setConversations((prev) => [activeConversation!, ...prev]);
|
||||
}
|
||||
|
||||
// Capture pending image before clearing state
|
||||
const imageFile = pendingImage;
|
||||
|
||||
const currMessages = messages.concat([{ text: query, speaker: "user" }]);
|
||||
setMessages(currMessages);
|
||||
const handleQuestionSubmit = useCallback(() => {
|
||||
sendMessage(query, simbaMode);
|
||||
setQuery("");
|
||||
setPendingImage(null);
|
||||
setIsLoading(true);
|
||||
|
||||
if (simbaMode) {
|
||||
const randomElement = simbaAnswers[Math.floor(Math.random() * simbaAnswers.length)];
|
||||
setMessages((prev) => prev.concat([{ text: randomElement, speaker: "simba" }]));
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const abortController = new AbortController();
|
||||
abortControllerRef.current = abortController;
|
||||
|
||||
try {
|
||||
// Upload image first if present
|
||||
let imageKey: string | undefined;
|
||||
if (imageFile) {
|
||||
const uploadResult = await conversationService.uploadImage(
|
||||
imageFile,
|
||||
activeConversation.id,
|
||||
);
|
||||
imageKey = uploadResult.image_key;
|
||||
|
||||
// Update the user message with the image key
|
||||
setMessages((prev) => {
|
||||
const updated = [...prev];
|
||||
// Find the last user message we just added
|
||||
for (let i = updated.length - 1; i >= 0; i--) {
|
||||
if (updated[i].speaker === "user") {
|
||||
updated[i] = { ...updated[i], image_key: imageKey };
|
||||
break;
|
||||
}
|
||||
}
|
||||
return updated;
|
||||
});
|
||||
}
|
||||
|
||||
await conversationService.streamQuery(
|
||||
query,
|
||||
activeConversation.id,
|
||||
(event) => {
|
||||
if (!isMountedRef.current) return;
|
||||
if (event.type === "tool_start") {
|
||||
const friendly = TOOL_MESSAGES[event.tool] ?? `🔧 Using ${event.tool}...`;
|
||||
setMessages((prev) => prev.concat([{ text: friendly, speaker: "tool" }]));
|
||||
} else if (event.type === "response") {
|
||||
setMessages((prev) => prev.concat([{ text: event.message, speaker: "simba" }]));
|
||||
} else if (event.type === "error") {
|
||||
console.error("Stream error:", event.message);
|
||||
}
|
||||
},
|
||||
abortController.signal,
|
||||
imageKey,
|
||||
);
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
console.log("Request was aborted");
|
||||
} else {
|
||||
console.error("Failed to send query:", error);
|
||||
if (error instanceof Error && error.message.includes("Session expired")) {
|
||||
setAuthenticated(false);
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
if (isMountedRef.current) setIsLoading(false);
|
||||
abortControllerRef.current = null;
|
||||
}
|
||||
}, [query, pendingImage, isLoading, selectedConversation, simbaMode, messages, setAuthenticated]);
|
||||
}, [query, simbaMode, sendMessage]);
|
||||
|
||||
const handleQueryChange = useCallback((event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
setQuery(event.target.value);
|
||||
@@ -232,8 +93,8 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
}
|
||||
}, [handleQuestionSubmit]);
|
||||
|
||||
const handleImageSelect = useCallback((file: File) => setPendingImage(file), []);
|
||||
const handleClearImage = useCallback(() => setPendingImage(null), []);
|
||||
const handleImageSelect = useCallback((file: File) => setPendingImage(file), [setPendingImage]);
|
||||
const handleClearImage = useCallback(() => setPendingImage(null), [setPendingImage]);
|
||||
|
||||
const handleLogout = () => {
|
||||
localStorage.removeItem("access_token");
|
||||
@@ -243,7 +104,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
|
||||
return (
|
||||
<div className="h-screen h-[100dvh] flex flex-row bg-cream overflow-hidden">
|
||||
{/* ── Desktop Sidebar ─────────────────────────────── */}
|
||||
{/* Desktop Sidebar */}
|
||||
<aside
|
||||
className={cn(
|
||||
"hidden md:flex md:flex-col",
|
||||
@@ -252,7 +113,6 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
)}
|
||||
>
|
||||
{sidebarCollapsed ? (
|
||||
/* Collapsed state */
|
||||
<div className="flex flex-col items-center py-4 gap-4 h-full">
|
||||
<button
|
||||
onClick={() => setSidebarCollapsed(false)}
|
||||
@@ -267,9 +127,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
/* Expanded state */
|
||||
<div className="flex flex-col h-full">
|
||||
{/* Header */}
|
||||
<div className="flex items-center justify-between px-4 py-4 border-b border-white/8">
|
||||
<div className="flex items-center gap-2.5">
|
||||
<img src={catIcon} alt="Simba" className="w-12 h-12" />
|
||||
@@ -288,7 +146,6 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Conversations */}
|
||||
<div className="flex-1 overflow-y-auto px-2 py-3">
|
||||
<ConversationList
|
||||
conversations={conversations}
|
||||
@@ -298,7 +155,6 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Footer */}
|
||||
<div className="px-2 pb-3 pt-2 border-t border-white/8 flex flex-col gap-0.5">
|
||||
{isAdmin && (
|
||||
<button
|
||||
@@ -321,12 +177,9 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
)}
|
||||
</aside>
|
||||
|
||||
{/* Admin Panel modal */}
|
||||
{showAdminPanel && <AdminPanel onClose={() => setShowAdminPanel(false)} />}
|
||||
|
||||
{/* ── Main chat area ──────────────────────────────── */}
|
||||
<div className="flex-1 flex flex-col h-full overflow-hidden min-w-0">
|
||||
{/* Mobile header */}
|
||||
<header className="md:hidden flex items-center justify-between px-4 py-3 bg-warm-white border-b border-sand-light/60">
|
||||
<div className="flex items-center gap-2">
|
||||
<img src={catIcon} alt="Simba" className="w-12 h-12" />
|
||||
@@ -354,9 +207,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
</header>
|
||||
|
||||
{messages.length === 0 ? (
|
||||
/* ── Empty / homepage state ── */
|
||||
<div className="flex-1 flex flex-col items-center justify-center px-4 gap-6">
|
||||
{/* Mobile conversation drawer */}
|
||||
{showConversations && (
|
||||
<div className="md:hidden w-full max-w-2xl bg-warm-white rounded-2xl border border-sand-light p-3 shadow-sm">
|
||||
<ConversationList
|
||||
@@ -393,11 +244,9 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
/* ── Active chat state ── */
|
||||
<>
|
||||
<div className="flex-1 overflow-y-auto px-4 py-6">
|
||||
<div className="max-w-2xl mx-auto flex flex-col gap-3">
|
||||
{/* Mobile conversation drawer */}
|
||||
{showConversations && (
|
||||
<div className="md:hidden mb-3 bg-warm-white rounded-2xl border border-sand-light p-3 shadow-sm">
|
||||
<ConversationList
|
||||
@@ -433,8 +282,8 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
setSimbaMode={setSimbaMode}
|
||||
isLoading={isLoading}
|
||||
pendingImage={pendingImage}
|
||||
onImageSelect={(file) => setPendingImage(file)}
|
||||
onClearImage={() => setPendingImage(null)}
|
||||
onImageSelect={handleImageSelect}
|
||||
onClearImage={handleClearImage}
|
||||
/>
|
||||
</div>
|
||||
</footer>
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import { Plus } from "lucide-react";
|
||||
import { cn } from "../lib/utils";
|
||||
import { conversationService } from "../api/conversationService";
|
||||
|
||||
type Conversation = {
|
||||
title: string;
|
||||
@@ -23,32 +21,8 @@ export const ConversationList = ({
|
||||
selectedId,
|
||||
variant = "dark",
|
||||
}: ConversationProps) => {
|
||||
const [items, setItems] = useState(conversations);
|
||||
|
||||
useEffect(() => {
|
||||
const load = async () => {
|
||||
try {
|
||||
let fetched = await conversationService.getAllConversations();
|
||||
if (fetched.length === 0) {
|
||||
await conversationService.createConversation();
|
||||
fetched = await conversationService.getAllConversations();
|
||||
}
|
||||
setItems(fetched.map((c) => ({ id: c.id, title: c.name })));
|
||||
} catch (err) {
|
||||
console.error("Failed to load conversations:", err);
|
||||
}
|
||||
};
|
||||
load();
|
||||
}, []);
|
||||
|
||||
// Keep in sync when parent updates conversations
|
||||
useEffect(() => {
|
||||
setItems(conversations);
|
||||
}, [conversations]);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-1">
|
||||
{/* New thread button */}
|
||||
<button
|
||||
onClick={onCreateNewConversation}
|
||||
className={cn(
|
||||
@@ -63,8 +37,7 @@ export const ConversationList = ({
|
||||
<span>New thread</span>
|
||||
</button>
|
||||
|
||||
{/* Conversation items */}
|
||||
{items.map((conv) => {
|
||||
{conversations.map((conv) => {
|
||||
const isActive = conv.id === selectedId;
|
||||
return (
|
||||
<button
|
||||
|
||||
@@ -1,66 +1,19 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import { userService } from "../api/userService";
|
||||
import { oidcService } from "../api/oidcService";
|
||||
import catIcon from "../assets/cat.png";
|
||||
import { cn } from "../lib/utils";
|
||||
import { useOIDCAuth } from "../hooks/useOIDCAuth";
|
||||
|
||||
type LoginScreenProps = {
|
||||
setAuthenticated: (isAuth: boolean) => void;
|
||||
};
|
||||
|
||||
export const LoginScreen = ({ setAuthenticated }: LoginScreenProps) => {
|
||||
const [error, setError] = useState<string>("");
|
||||
const [isChecking, setIsChecking] = useState<boolean>(true);
|
||||
const [isLoggingIn, setIsLoggingIn] = useState<boolean>(false);
|
||||
|
||||
useEffect(() => {
|
||||
const initAuth = async () => {
|
||||
const callbackParams = oidcService.getCallbackParamsFromURL();
|
||||
if (callbackParams) {
|
||||
try {
|
||||
setIsLoggingIn(true);
|
||||
const result = await oidcService.handleCallback(
|
||||
callbackParams.code,
|
||||
callbackParams.state,
|
||||
);
|
||||
localStorage.setItem("access_token", result.access_token);
|
||||
localStorage.setItem("refresh_token", result.refresh_token);
|
||||
oidcService.clearCallbackParams();
|
||||
setAuthenticated(true);
|
||||
setIsChecking(false);
|
||||
return;
|
||||
} catch (err) {
|
||||
console.error("OIDC callback error:", err);
|
||||
setError("Login failed. Please try again.");
|
||||
oidcService.clearCallbackParams();
|
||||
setIsLoggingIn(false);
|
||||
setIsChecking(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
const isValid = await userService.validateToken();
|
||||
if (isValid) setAuthenticated(true);
|
||||
setIsChecking(false);
|
||||
};
|
||||
initAuth();
|
||||
}, [setAuthenticated]);
|
||||
|
||||
const handleOIDCLogin = async () => {
|
||||
try {
|
||||
setIsLoggingIn(true);
|
||||
setError("");
|
||||
const authUrl = await oidcService.initiateLogin();
|
||||
window.location.href = authUrl;
|
||||
} catch {
|
||||
setError("Failed to initiate login. Please try again.");
|
||||
setIsLoggingIn(false);
|
||||
}
|
||||
};
|
||||
const { isChecking, isLoggingIn, error, handleLogin } = useOIDCAuth({
|
||||
setAuthenticated,
|
||||
});
|
||||
|
||||
if (isChecking || isLoggingIn) {
|
||||
return (
|
||||
<div className="h-screen flex flex-col items-center justify-center bg-cream gap-4">
|
||||
{/* Subtle dot grid */}
|
||||
<div
|
||||
className="fixed inset-0 pointer-events-none opacity-[0.035]"
|
||||
style={{
|
||||
@@ -85,7 +38,6 @@ export const LoginScreen = ({ setAuthenticated }: LoginScreenProps) => {
|
||||
|
||||
return (
|
||||
<div className="h-screen bg-cream flex items-center justify-center p-4 relative overflow-hidden">
|
||||
{/* Background dot texture */}
|
||||
<div
|
||||
className="fixed inset-0 pointer-events-none opacity-[0.04]"
|
||||
style={{
|
||||
@@ -94,12 +46,10 @@ export const LoginScreen = ({ setAuthenticated }: LoginScreenProps) => {
|
||||
}}
|
||||
/>
|
||||
|
||||
{/* Decorative background blobs */}
|
||||
<div className="absolute top-1/4 -left-20 w-72 h-72 rounded-full bg-leaf-pale/60 blur-3xl pointer-events-none" />
|
||||
<div className="absolute bottom-1/4 -right-20 w-64 h-64 rounded-full bg-amber-pale/70 blur-3xl pointer-events-none" />
|
||||
|
||||
<div className="relative w-full max-w-sm">
|
||||
{/* Branding */}
|
||||
<div className="flex flex-col items-center mb-8">
|
||||
<div className="relative mb-5">
|
||||
<div className="absolute -inset-5 bg-amber-soft/30 rounded-full blur-2xl" />
|
||||
@@ -120,7 +70,6 @@ export const LoginScreen = ({ setAuthenticated }: LoginScreenProps) => {
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Card */}
|
||||
<div
|
||||
className={cn(
|
||||
"bg-warm-white rounded-3xl border border-sand-light",
|
||||
@@ -138,7 +87,7 @@ export const LoginScreen = ({ setAuthenticated }: LoginScreenProps) => {
|
||||
</p>
|
||||
|
||||
<button
|
||||
onClick={handleOIDCLogin}
|
||||
onClick={handleLogin}
|
||||
disabled={isLoggingIn}
|
||||
className={cn(
|
||||
"w-full py-3.5 px-4 rounded-2xl text-sm font-semibold tracking-wide",
|
||||
@@ -154,7 +103,7 @@ export const LoginScreen = ({ setAuthenticated }: LoginScreenProps) => {
|
||||
</div>
|
||||
|
||||
<p className="text-center text-sand mt-5 text-xs tracking-widest select-none">
|
||||
✦ meow ✦
|
||||
* meow *
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,26 +1,14 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import React from "react";
|
||||
import { cn } from "../lib/utils";
|
||||
import { conversationService } from "../api/conversationService";
|
||||
import { usePresignedUrl } from "../hooks/usePresignedUrl";
|
||||
|
||||
type QuestionBubbleProps = {
|
||||
text: string;
|
||||
image_key?: string | null;
|
||||
};
|
||||
|
||||
export const QuestionBubble = ({ text, image_key }: QuestionBubbleProps) => {
|
||||
const [imageUrl, setImageUrl] = useState<string | null>(null);
|
||||
const [imageError, setImageError] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (!image_key) return;
|
||||
conversationService
|
||||
.getPresignedImageUrl(image_key)
|
||||
.then(setImageUrl)
|
||||
.catch((err) => {
|
||||
console.error("Failed to load image:", err);
|
||||
setImageError(true);
|
||||
});
|
||||
}, [image_key]);
|
||||
export const QuestionBubble = React.memo(({ text, image_key }: QuestionBubbleProps) => {
|
||||
const { imageUrl, imageError } = usePresignedUrl(image_key);
|
||||
|
||||
return (
|
||||
<div className="flex justify-end message-enter">
|
||||
@@ -34,7 +22,6 @@ export const QuestionBubble = ({ text, image_key }: QuestionBubbleProps) => {
|
||||
>
|
||||
{imageError && (
|
||||
<div className="flex items-center gap-2 text-xs text-charcoal/50 bg-charcoal/5 rounded-xl px-3 py-2 mb-2">
|
||||
<span>🖼️</span>
|
||||
<span>Image failed to load</span>
|
||||
</div>
|
||||
)}
|
||||
@@ -49,4 +36,4 @@ export const QuestionBubble = ({ text, image_key }: QuestionBubbleProps) => {
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import React from "react";
|
||||
import { cn } from "../lib/utils";
|
||||
|
||||
export const ToolBubble = ({ text }: { text: string }) => (
|
||||
export const ToolBubble = React.memo(({ text }: { text: string }) => (
|
||||
<div className="flex justify-center message-enter">
|
||||
<div
|
||||
className={cn(
|
||||
@@ -12,4 +13,4 @@ export const ToolBubble = ({ text }: { text: string }) => (
|
||||
{text}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
));
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import { userService, type AdminUserRecord } from "../api/userService";
|
||||
|
||||
export function useAdminUsers() {
|
||||
const [users, setUsers] = useState<AdminUserRecord[]>([]);
|
||||
const [loading, setLoading] = useState(true);
|
||||
|
||||
useEffect(() => {
|
||||
userService
|
||||
.adminListUsers()
|
||||
.then(setUsers)
|
||||
.catch(() => {})
|
||||
.finally(() => setLoading(false));
|
||||
}, []);
|
||||
|
||||
const updateUser = (userId: string, updater: (u: AdminUserRecord) => AdminUserRecord) => {
|
||||
setUsers((prev) => prev.map((u) => (u.id === userId ? updater(u) : u)));
|
||||
};
|
||||
|
||||
return { users, loading, updateUser };
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import { userService } from "../api/userService";
|
||||
|
||||
export function useAuthCheck() {
|
||||
const [isAuthenticated, setAuthenticated] = useState(false);
|
||||
const [isChecking, setIsChecking] = useState(true);
|
||||
const [isAdmin, setIsAdmin] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
const checkAuth = async () => {
|
||||
const accessToken = localStorage.getItem("access_token");
|
||||
const refreshToken = localStorage.getItem("refresh_token");
|
||||
|
||||
if (!accessToken && !refreshToken) {
|
||||
setIsChecking(false);
|
||||
setAuthenticated(false);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const me = await userService.getMe();
|
||||
setAuthenticated(true);
|
||||
setIsAdmin(me.is_admin);
|
||||
} catch {
|
||||
localStorage.removeItem("access_token");
|
||||
localStorage.removeItem("refresh_token");
|
||||
setAuthenticated(false);
|
||||
} finally {
|
||||
setIsChecking(false);
|
||||
}
|
||||
};
|
||||
|
||||
checkAuth();
|
||||
}, []);
|
||||
|
||||
return { isAuthenticated, isChecking, isAdmin, setAuthenticated };
|
||||
}
|
||||
@@ -0,0 +1,183 @@
|
||||
import { useState, useCallback, useEffect, useRef } from "react";
|
||||
import { conversationService } from "../api/conversationService";
|
||||
import type { Conversation } from "./useConversations";
|
||||
|
||||
type Message = {
|
||||
text: string;
|
||||
speaker: "simba" | "user" | "tool";
|
||||
image_key?: string | null;
|
||||
};
|
||||
|
||||
const TOOL_MESSAGES: Record<string, string> = {
|
||||
simba_search: "Searching Simba's records...",
|
||||
web_search: "Searching the web...",
|
||||
get_current_date: "Checking today's date...",
|
||||
ynab_budget_summary: "Checking budget summary...",
|
||||
ynab_search_transactions: "Looking up transactions...",
|
||||
ynab_category_spending: "Analyzing category spending...",
|
||||
ynab_insights: "Generating budget insights...",
|
||||
obsidian_search_notes: "Searching notes...",
|
||||
obsidian_read_note: "Reading note...",
|
||||
obsidian_create_note: "Saving note...",
|
||||
obsidian_create_task: "Creating task...",
|
||||
journal_get_today: "Reading today's journal...",
|
||||
journal_get_tasks: "Getting tasks...",
|
||||
journal_add_task: "Adding task...",
|
||||
journal_complete_task: "Completing task...",
|
||||
};
|
||||
|
||||
const simbaAnswers = ["meow.", "hiss...", "purrrrrr", "yowOWROWWowowr"];
|
||||
|
||||
type UseChatOptions = {
|
||||
selectedConversation: Conversation | null;
|
||||
createConversation: () => Promise<Conversation>;
|
||||
refreshConversations: () => Promise<void>;
|
||||
onSessionExpired: () => void;
|
||||
scrollToBottom: () => void;
|
||||
};
|
||||
|
||||
export function useChat({
|
||||
selectedConversation,
|
||||
createConversation,
|
||||
refreshConversations,
|
||||
onSessionExpired,
|
||||
scrollToBottom,
|
||||
}: UseChatOptions) {
|
||||
const [messages, setMessages] = useState<Message[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [pendingImage, setPendingImage] = useState<File | null>(null);
|
||||
|
||||
const isMountedRef = useRef(true);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
isMountedRef.current = true;
|
||||
return () => {
|
||||
isMountedRef.current = false;
|
||||
abortControllerRef.current?.abort();
|
||||
};
|
||||
}, []);
|
||||
|
||||
const updateMessages = useCallback(
|
||||
(updater: Message[] | ((prev: Message[]) => Message[])) => {
|
||||
setMessages(updater);
|
||||
scrollToBottom();
|
||||
},
|
||||
[scrollToBottom],
|
||||
);
|
||||
|
||||
const sendMessage = useCallback(
|
||||
async (query: string, simbaMode: boolean) => {
|
||||
if ((!query.trim() && !pendingImage) || isLoading) return;
|
||||
|
||||
let activeConversation = selectedConversation;
|
||||
let createdNew = false;
|
||||
if (!activeConversation) {
|
||||
activeConversation = await createConversation();
|
||||
createdNew = true;
|
||||
}
|
||||
|
||||
const imageFile = pendingImage;
|
||||
|
||||
updateMessages((prev) => prev.concat([{ text: query, speaker: "user" }]));
|
||||
setPendingImage(null);
|
||||
setIsLoading(true);
|
||||
|
||||
if (simbaMode) {
|
||||
const randomElement =
|
||||
simbaAnswers[Math.floor(Math.random() * simbaAnswers.length)];
|
||||
updateMessages((prev) =>
|
||||
prev.concat([{ text: randomElement, speaker: "simba" }]),
|
||||
);
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const abortController = new AbortController();
|
||||
abortControllerRef.current = abortController;
|
||||
|
||||
try {
|
||||
let imageKey: string | undefined;
|
||||
if (imageFile) {
|
||||
const uploadResult = await conversationService.uploadImage(
|
||||
imageFile,
|
||||
activeConversation.id,
|
||||
);
|
||||
imageKey = uploadResult.image_key;
|
||||
|
||||
updateMessages((prev) => {
|
||||
const updated = [...prev];
|
||||
for (let i = updated.length - 1; i >= 0; i--) {
|
||||
if (updated[i].speaker === "user") {
|
||||
updated[i] = { ...updated[i], image_key: imageKey };
|
||||
break;
|
||||
}
|
||||
}
|
||||
return updated;
|
||||
});
|
||||
}
|
||||
|
||||
await conversationService.streamQuery(
|
||||
query,
|
||||
activeConversation.id,
|
||||
(event) => {
|
||||
if (!isMountedRef.current) return;
|
||||
if (event.type === "tool_start") {
|
||||
const friendly =
|
||||
TOOL_MESSAGES[event.tool] ?? `Using ${event.tool}...`;
|
||||
updateMessages((prev) =>
|
||||
prev.concat([{ text: friendly, speaker: "tool" }]),
|
||||
);
|
||||
} else if (event.type === "response") {
|
||||
updateMessages((prev) =>
|
||||
prev.concat([{ text: event.message, speaker: "simba" }]),
|
||||
);
|
||||
} else if (event.type === "error") {
|
||||
console.error("Stream error:", event.message);
|
||||
}
|
||||
},
|
||||
abortController.signal,
|
||||
imageKey,
|
||||
);
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
console.log("Request was aborted");
|
||||
} else {
|
||||
console.error("Failed to send query:", error);
|
||||
if (
|
||||
error instanceof Error &&
|
||||
error.message.includes("Session expired")
|
||||
) {
|
||||
onSessionExpired();
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
if (isMountedRef.current) {
|
||||
setIsLoading(false);
|
||||
if (createdNew) {
|
||||
refreshConversations();
|
||||
}
|
||||
}
|
||||
abortControllerRef.current = null;
|
||||
}
|
||||
},
|
||||
[
|
||||
pendingImage,
|
||||
isLoading,
|
||||
selectedConversation,
|
||||
createConversation,
|
||||
refreshConversations,
|
||||
onSessionExpired,
|
||||
updateMessages,
|
||||
],
|
||||
);
|
||||
|
||||
return {
|
||||
messages,
|
||||
setMessages: updateMessages,
|
||||
isLoading,
|
||||
pendingImage,
|
||||
setPendingImage,
|
||||
sendMessage,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
import { useState, useCallback, useEffect } from "react";
|
||||
import { conversationService } from "../api/conversationService";
|
||||
|
||||
export type Conversation = {
|
||||
title: string;
|
||||
id: string;
|
||||
};
|
||||
|
||||
type Message = {
|
||||
text: string;
|
||||
speaker: "simba" | "user" | "tool";
|
||||
image_key?: string | null;
|
||||
};
|
||||
|
||||
export function useConversations() {
|
||||
const [conversations, setConversations] = useState<Conversation[]>([]);
|
||||
const [selectedConversation, setSelectedConversation] =
|
||||
useState<Conversation | null>(null);
|
||||
|
||||
const refreshConversations = useCallback(async () => {
|
||||
try {
|
||||
const fetched = await conversationService.getAllConversations();
|
||||
setConversations(fetched.map((c) => ({ id: c.id, title: c.name })));
|
||||
} catch (err) {
|
||||
console.error("Failed to load conversations:", err);
|
||||
}
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
refreshConversations();
|
||||
}, [refreshConversations]);
|
||||
|
||||
const selectConversation = useCallback(
|
||||
async (conversation: Conversation): Promise<Message[]> => {
|
||||
setSelectedConversation(conversation);
|
||||
try {
|
||||
const fetched = await conversationService.getConversation(
|
||||
conversation.id,
|
||||
);
|
||||
return fetched.messages.map((m) => ({
|
||||
text: m.text,
|
||||
speaker: m.speaker,
|
||||
image_key: m.image_key,
|
||||
}));
|
||||
} catch (err) {
|
||||
console.error("Failed to load messages:", err);
|
||||
return [];
|
||||
}
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
const createConversation = useCallback(async (): Promise<Conversation> => {
|
||||
const newConv = await conversationService.createConversation();
|
||||
const conversation = { title: newConv.name, id: newConv.id };
|
||||
setConversations((prev) => [conversation, ...prev]);
|
||||
setSelectedConversation(conversation);
|
||||
return conversation;
|
||||
}, []);
|
||||
|
||||
return {
|
||||
conversations,
|
||||
selectedConversation,
|
||||
setSelectedConversation,
|
||||
selectConversation,
|
||||
createConversation,
|
||||
refreshConversations,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import { userService } from "../api/userService";
|
||||
import { oidcService } from "../api/oidcService";
|
||||
|
||||
type UseOIDCAuthOptions = {
|
||||
setAuthenticated: (isAuth: boolean) => void;
|
||||
};
|
||||
|
||||
export function useOIDCAuth({ setAuthenticated }: UseOIDCAuthOptions) {
|
||||
const [isChecking, setIsChecking] = useState(true);
|
||||
const [isLoggingIn, setIsLoggingIn] = useState(false);
|
||||
const [error, setError] = useState("");
|
||||
|
||||
useEffect(() => {
|
||||
const initAuth = async () => {
|
||||
const callbackParams = oidcService.getCallbackParamsFromURL();
|
||||
if (callbackParams) {
|
||||
try {
|
||||
setIsLoggingIn(true);
|
||||
const result = await oidcService.handleCallback(
|
||||
callbackParams.code,
|
||||
callbackParams.state,
|
||||
);
|
||||
localStorage.setItem("access_token", result.access_token);
|
||||
localStorage.setItem("refresh_token", result.refresh_token);
|
||||
oidcService.clearCallbackParams();
|
||||
setAuthenticated(true);
|
||||
setIsChecking(false);
|
||||
return;
|
||||
} catch (err) {
|
||||
console.error("OIDC callback error:", err);
|
||||
setError("Login failed. Please try again.");
|
||||
oidcService.clearCallbackParams();
|
||||
setIsLoggingIn(false);
|
||||
setIsChecking(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
const isValid = await userService.validateToken();
|
||||
if (isValid) setAuthenticated(true);
|
||||
setIsChecking(false);
|
||||
};
|
||||
initAuth();
|
||||
}, [setAuthenticated]);
|
||||
|
||||
const handleLogin = async () => {
|
||||
try {
|
||||
setIsLoggingIn(true);
|
||||
setError("");
|
||||
const authUrl = await oidcService.initiateLogin();
|
||||
window.location.href = authUrl;
|
||||
} catch {
|
||||
setError("Failed to initiate login. Please try again.");
|
||||
setIsLoggingIn(false);
|
||||
}
|
||||
};
|
||||
|
||||
return { isChecking, isLoggingIn, error, handleLogin };
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import { conversationService } from "../api/conversationService";
|
||||
|
||||
const urlCache = new Map<string, string>();
|
||||
|
||||
export function usePresignedUrl(imageKey: string | null | undefined) {
|
||||
const [imageUrl, setImageUrl] = useState<string | null>(
|
||||
imageKey ? (urlCache.get(imageKey) ?? null) : null,
|
||||
);
|
||||
const [imageError, setImageError] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (!imageKey) return;
|
||||
|
||||
const cached = urlCache.get(imageKey);
|
||||
if (cached) {
|
||||
setImageUrl(cached);
|
||||
return;
|
||||
}
|
||||
|
||||
conversationService
|
||||
.getPresignedImageUrl(imageKey)
|
||||
.then((url) => {
|
||||
urlCache.set(imageKey, url);
|
||||
setImageUrl(url);
|
||||
})
|
||||
.catch((err) => {
|
||||
console.error("Failed to load image:", err);
|
||||
setImageError(true);
|
||||
});
|
||||
}, [imageKey]);
|
||||
|
||||
return { imageUrl, imageError };
|
||||
}
|
||||
@@ -6,19 +6,19 @@ import asyncio
|
||||
import sys
|
||||
|
||||
from blueprints.rag.logic import (
|
||||
delete_all_documents,
|
||||
get_vector_store_stats,
|
||||
index_documents,
|
||||
list_all_documents,
|
||||
vector_store,
|
||||
)
|
||||
|
||||
|
||||
def stats():
|
||||
"""Show vector store statistics."""
|
||||
stats = get_vector_store_stats()
|
||||
s = get_vector_store_stats()
|
||||
print("=== Vector Store Statistics ===")
|
||||
print(f"Collection: {stats['collection_name']}")
|
||||
print(f"Total Documents: {stats['total_documents']}")
|
||||
print(f"Collection: {s['collection_name']}")
|
||||
print(f"Total Documents: {s['total_documents']}")
|
||||
|
||||
|
||||
async def index():
|
||||
@@ -26,23 +26,15 @@ async def index():
|
||||
print("Starting indexing process...")
|
||||
print("Fetching documents from Paperless-NGX...")
|
||||
await index_documents()
|
||||
print("✓ Indexing complete!")
|
||||
print("Indexing complete!")
|
||||
stats()
|
||||
|
||||
|
||||
async def reindex():
|
||||
"""Clear and reindex all documents."""
|
||||
print("Clearing existing documents...")
|
||||
collection = vector_store._collection
|
||||
all_docs = collection.get()
|
||||
|
||||
if all_docs["ids"]:
|
||||
print(f"Deleting {len(all_docs['ids'])} existing documents...")
|
||||
collection.delete(ids=all_docs["ids"])
|
||||
print("✓ Cleared")
|
||||
else:
|
||||
print("Collection is already empty")
|
||||
|
||||
delete_all_documents()
|
||||
print("Cleared")
|
||||
await index()
|
||||
|
||||
|
||||
@@ -113,7 +105,7 @@ Examples:
|
||||
print("\n\nOperation cancelled by user")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error: {e}", file=sys.stderr)
|
||||
print(f"\nError: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
from bs4 import BeautifulSoup
|
||||
import chromadb
|
||||
import httpx
|
||||
|
||||
client = chromadb.PersistentClient(path="/Users/ryanchen/Programs/raggr/chromadb")
|
||||
|
||||
# Scrape
|
||||
BASE_URL = "https://www.vet.cornell.edu"
|
||||
LIST_URL = "/departments-centers-and-institutes/cornell-feline-health-center/health-information/feline-health-topics"
|
||||
|
||||
QUERY_URL = BASE_URL + LIST_URL
|
||||
r = httpx.get(QUERY_URL)
|
||||
soup = BeautifulSoup(r.text)
|
||||
|
||||
container = soup.find("div", class_="field-body")
|
||||
a_s = container.find_all("a", href=True)
|
||||
|
||||
new_texts = []
|
||||
|
||||
for link in a_s:
|
||||
endpoint = link["href"]
|
||||
query_url = BASE_URL + endpoint
|
||||
r2 = httpx.get(query_url)
|
||||
article_soup = BeautifulSoup(r2.text)
|
||||
@@ -1,9 +1,6 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "Initializing directories..."
|
||||
mkdir -p /app/data/chromadb
|
||||
|
||||
echo "Rebuilding frontend..."
|
||||
cd /app/raggr-frontend
|
||||
yarn build
|
||||
|
||||
@@ -1,139 +0,0 @@
|
||||
"""Tests for text preprocessing functions in utils/chunker.py."""
|
||||
|
||||
from utils.chunker import (
|
||||
remove_headers_footers,
|
||||
remove_special_characters,
|
||||
remove_repeated_substrings,
|
||||
remove_extra_spaces,
|
||||
preprocess_text,
|
||||
)
|
||||
|
||||
|
||||
class TestRemoveHeadersFooters:
|
||||
def test_removes_default_header(self):
|
||||
text = "Header Line\nActual content here"
|
||||
result = remove_headers_footers(text)
|
||||
assert "Header" not in result
|
||||
assert "Actual content here" in result
|
||||
|
||||
def test_removes_default_footer(self):
|
||||
text = "Actual content\nFooter Line"
|
||||
result = remove_headers_footers(text)
|
||||
assert "Footer" not in result
|
||||
assert "Actual content" in result
|
||||
|
||||
def test_custom_patterns(self):
|
||||
text = "PAGE 1\nContent\nCopyright 2024"
|
||||
result = remove_headers_footers(
|
||||
text,
|
||||
header_patterns=[r"^PAGE \d+$"],
|
||||
footer_patterns=[r"^Copyright.*$"],
|
||||
)
|
||||
assert "PAGE 1" not in result
|
||||
assert "Copyright" not in result
|
||||
assert "Content" in result
|
||||
|
||||
def test_no_match_preserves_text(self):
|
||||
text = "Just normal content"
|
||||
result = remove_headers_footers(text)
|
||||
assert result == "Just normal content"
|
||||
|
||||
def test_empty_string(self):
|
||||
assert remove_headers_footers("") == ""
|
||||
|
||||
|
||||
class TestRemoveSpecialCharacters:
|
||||
def test_removes_special_chars(self):
|
||||
text = "Hello @world #test $100"
|
||||
result = remove_special_characters(text)
|
||||
assert "@" not in result
|
||||
assert "#" not in result
|
||||
assert "$" not in result
|
||||
|
||||
def test_preserves_allowed_chars(self):
|
||||
text = "Hello, world! How's it going? Yes-no."
|
||||
result = remove_special_characters(text)
|
||||
assert "," in result
|
||||
assert "!" in result
|
||||
assert "'" in result
|
||||
assert "?" in result
|
||||
assert "-" in result
|
||||
assert "." in result
|
||||
|
||||
def test_custom_pattern(self):
|
||||
text = "keep @this but not #that"
|
||||
result = remove_special_characters(text, special_chars=r"[#]")
|
||||
assert "@this" in result
|
||||
assert "#" not in result
|
||||
|
||||
def test_empty_string(self):
|
||||
assert remove_special_characters("") == ""
|
||||
|
||||
|
||||
class TestRemoveRepeatedSubstrings:
|
||||
def test_collapses_dots(self):
|
||||
text = "Item.....Value"
|
||||
result = remove_repeated_substrings(text)
|
||||
assert result == "Item.Value"
|
||||
|
||||
def test_single_dot_preserved(self):
|
||||
text = "End of sentence."
|
||||
result = remove_repeated_substrings(text)
|
||||
assert result == "End of sentence."
|
||||
|
||||
def test_custom_pattern(self):
|
||||
text = "hello---world"
|
||||
result = remove_repeated_substrings(text, pattern=r"-{2,}")
|
||||
# Function always replaces matched pattern with "."
|
||||
assert result == "hello.world"
|
||||
|
||||
def test_empty_string(self):
|
||||
assert remove_repeated_substrings("") == ""
|
||||
|
||||
|
||||
class TestRemoveExtraSpaces:
|
||||
def test_collapses_multiple_blank_lines(self):
|
||||
text = "Line 1\n\n\n\nLine 2"
|
||||
result = remove_extra_spaces(text)
|
||||
# After collapsing newlines to \n\n, then \s+ collapses everything to single spaces
|
||||
assert "\n\n\n" not in result
|
||||
|
||||
def test_collapses_multiple_spaces(self):
|
||||
text = "Hello world"
|
||||
result = remove_extra_spaces(text)
|
||||
assert result == "Hello world"
|
||||
|
||||
def test_strips_whitespace(self):
|
||||
text = " Hello world "
|
||||
result = remove_extra_spaces(text)
|
||||
assert result == "Hello world"
|
||||
|
||||
def test_empty_string(self):
|
||||
assert remove_extra_spaces("") == ""
|
||||
|
||||
|
||||
class TestPreprocessText:
|
||||
def test_full_pipeline(self):
|
||||
text = "Header Info\nHello @world... with spaces\nFooter Info"
|
||||
result = preprocess_text(text)
|
||||
assert "Header" not in result
|
||||
assert "Footer" not in result
|
||||
assert "@" not in result
|
||||
assert "..." not in result
|
||||
assert " " not in result
|
||||
|
||||
def test_preserves_meaningful_content(self):
|
||||
text = "The cat weighs 10 pounds."
|
||||
result = preprocess_text(text)
|
||||
assert "cat" in result
|
||||
assert "10" in result
|
||||
assert "pounds" in result
|
||||
|
||||
def test_empty_string(self):
|
||||
assert preprocess_text("") == ""
|
||||
|
||||
def test_already_clean(self):
|
||||
text = "Simple clean text here."
|
||||
result = preprocess_text(text)
|
||||
assert "Simple" in result
|
||||
assert "clean" in result
|
||||
@@ -1,137 +0,0 @@
|
||||
import os
|
||||
from math import ceil
|
||||
import re
|
||||
from typing import Union
|
||||
from uuid import UUID, uuid4
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from dotenv import load_dotenv
|
||||
from llm import LLMClient
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def remove_headers_footers(text, header_patterns=None, footer_patterns=None):
|
||||
if header_patterns is None:
|
||||
header_patterns = [r"^.*Header.*$"]
|
||||
if footer_patterns is None:
|
||||
footer_patterns = [r"^.*Footer.*$"]
|
||||
|
||||
for pattern in header_patterns + footer_patterns:
|
||||
text = re.sub(pattern, "", text, flags=re.MULTILINE)
|
||||
|
||||
return text.strip()
|
||||
|
||||
|
||||
def remove_special_characters(text, special_chars=None):
|
||||
if special_chars is None:
|
||||
special_chars = r"[^A-Za-z0-9\s\.,;:\'\"\?\!\-]"
|
||||
|
||||
text = re.sub(special_chars, "", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def remove_repeated_substrings(text, pattern=r"\.{2,}"):
|
||||
text = re.sub(pattern, ".", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def remove_extra_spaces(text):
|
||||
text = re.sub(r"\n\s*\n", "\n\n", text)
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
|
||||
return text.strip()
|
||||
|
||||
|
||||
def preprocess_text(text):
|
||||
# Remove headers and footers
|
||||
text = remove_headers_footers(text)
|
||||
|
||||
# Remove special characters
|
||||
text = remove_special_characters(text)
|
||||
|
||||
# Remove repeated substrings like dots
|
||||
text = remove_repeated_substrings(text)
|
||||
|
||||
# Remove extra spaces between lines and within lines
|
||||
text = remove_extra_spaces(text)
|
||||
|
||||
# Additional cleaning steps can be added here
|
||||
|
||||
return text.strip()
|
||||
|
||||
|
||||
class Chunk:
|
||||
def __init__(
|
||||
self,
|
||||
text: str,
|
||||
size: int,
|
||||
document_id: UUID,
|
||||
chunk_id: int,
|
||||
embedding,
|
||||
):
|
||||
self.text = text
|
||||
self.size = size
|
||||
self.document_id = document_id
|
||||
self.chunk_id = chunk_id
|
||||
self.embedding = embedding
|
||||
|
||||
|
||||
class Chunker:
|
||||
def __init__(self, collection) -> None:
|
||||
self.collection = collection
|
||||
self.llm_client = LLMClient()
|
||||
|
||||
def embedding_fx(self, inputs):
|
||||
openai_embedding_fx = OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
model_name="text-embedding-3-small",
|
||||
)
|
||||
return openai_embedding_fx(inputs)
|
||||
|
||||
def chunk_document(
|
||||
self,
|
||||
document: str,
|
||||
chunk_size: int = 1000,
|
||||
metadata: dict[str, Union[str, float]] = {},
|
||||
) -> list[Chunk]:
|
||||
doc_uuid = uuid4()
|
||||
|
||||
chunk_size = min(chunk_size, len(document)) or 1
|
||||
|
||||
chunks = []
|
||||
num_chunks = ceil(len(document) / chunk_size)
|
||||
document_length = len(document)
|
||||
|
||||
for i in range(num_chunks):
|
||||
curr_pos = i * num_chunks
|
||||
to_pos = (
|
||||
curr_pos + chunk_size
|
||||
if curr_pos + chunk_size < document_length
|
||||
else document_length
|
||||
)
|
||||
text_chunk = self.clean_document(document[curr_pos:to_pos])
|
||||
|
||||
embedding = self.embedding_fx([text_chunk])
|
||||
self.collection.add(
|
||||
ids=[str(doc_uuid) + ":" + str(i)],
|
||||
documents=[text_chunk],
|
||||
embeddings=embedding,
|
||||
metadatas=[metadata],
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
def clean_document(self, document: str) -> str:
|
||||
"""This function will remove information that is noise or already known.
|
||||
|
||||
Example: We already know all the things in here are Simba-related, so we don't need things like
|
||||
"Sumamry of simba's visit"
|
||||
"""
|
||||
|
||||
document = document.replace("\\n", "")
|
||||
document = document.strip()
|
||||
|
||||
return preprocess_text(document)
|
||||
Reference in New Issue
Block a user