Compare commits

..

2 Commits

Author SHA1 Message Date
Ryan Chen 64dab18428 Clean up presigned URL implementation: remove dead fields, fix error handling
- Remove unused image_url from upload response and TS type
- Remove bare except in serve_image that masked config errors as 404s
- Add error state and broken-image placeholder in QuestionBubble

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-04 08:49:01 -04:00
Ryan Chen b62a8b6b3f Use presigned S3 URLs for serving images instead of proxying bytes
Browser <img> tags can't attach JWT headers, causing 401s. The image
endpoint now returns a time-limited presigned S3 URL via authenticated
API call, which the frontend fetches and uses directly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-04 08:45:35 -04:00
50 changed files with 1963 additions and 2160 deletions
+4 -5
View File
@@ -19,11 +19,10 @@ BASE_URL=192.168.1.5:8000
LLAMA_SERVER_URL=http://192.168.1.213:8080/v1 LLAMA_SERVER_URL=http://192.168.1.213:8080/v1
LLAMA_MODEL_NAME=llama-3.1-8b-instruct LLAMA_MODEL_NAME=llama-3.1-8b-instruct
# Embedding Server Configuration # ChromaDB Configuration
# If set, uses a custom OpenAI-compatible embedding server (e.g. llama-server) # For Docker: This is automatically set to /app/data/chromadb
# Falls back to OpenAI embeddings if not set # For local development: Set to a local directory path
EMBEDDING_SERVER_URL=http://192.168.1.7:8086/v1 CHROMADB_PATH=./data/chromadb
EMBEDDING_MODEL_NAME=all-minilm
# OpenAI Configuration # OpenAI Configuration
OPENAI_API_KEY=your-openai-api-key OPENAI_API_KEY=your-openai-api-key
+3
View File
@@ -13,6 +13,9 @@ wheels/
.env .env
# Database files # Database files
chromadb/
chromadb_openai/
chroma_db/
database/ database/
*.db *.db
+4 -12
View File
@@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
## Project Overview ## Project Overview
SimbaRAG is a RAG (Retrieval-Augmented Generation) conversational AI system for querying information about Simba (a cat). It ingests documents from Paperless-NGX, stores embeddings in PostgreSQL via pgvector, and uses LLMs (Ollama or OpenAI) to answer questions. SimbaRAG is a RAG (Retrieval-Augmented Generation) conversational AI system for querying information about Simba (a cat). It ingests documents from Paperless-NGX, stores embeddings in ChromaDB, and uses LLMs (Ollama or OpenAI) to answer questions.
## Commands ## Commands
@@ -54,8 +54,9 @@ docker compose up -d
│ Docker Compose │ │ Docker Compose │
├─────────────────────────────────────────────────────────────┤ ├─────────────────────────────────────────────────────────────┤
│ raggr (port 8080) │ postgres (port 5432) │ │ raggr (port 8080) │ postgres (port 5432) │
│ ├── Quart backend │ PostgreSQL 16 + pgvector │ ├── Quart backend │ PostgreSQL 16
── React frontend (served) │ │ ── React frontend (served) │ │
│ └── ChromaDB (volume) │ │
└─────────────────────────────────────────────────────────────┘ └─────────────────────────────────────────────────────────────┘
``` ```
@@ -90,15 +91,6 @@ docker compose up -d
**Auth Flow**: LLDAP → Authelia (OIDC) → Backend JWT → Frontend localStorage **Auth Flow**: LLDAP → Authelia (OIDC) → Backend JWT → Frontend localStorage
## Testing
Always run `make test` before pushing code to ensure all tests pass.
```bash
make test # Run tests
make test-cov # Run tests with coverage
```
## Key Patterns ## Key Patterns
- All endpoints are async (`async def`) - All endpoints are async (`async def`)
+3 -2
View File
@@ -37,14 +37,15 @@ WORKDIR /app/raggr-frontend
RUN yarn install && yarn build RUN yarn install && yarn build
WORKDIR /app WORKDIR /app
# Create database directory # Create ChromaDB and database directories
RUN mkdir -p /app/database RUN mkdir -p /app/chromadb /app/database
# Expose port # Expose port
EXPOSE 8080 EXPOSE 8080
# Set environment variables # Set environment variables
ENV PYTHONPATH=/app ENV PYTHONPATH=/app
ENV CHROMADB_PATH=/app/chromadb
# Run the startup script # Run the startup script
CMD ["./startup.sh"] CMD ["./startup.sh"]
+3 -2
View File
@@ -34,15 +34,16 @@ COPY . .
WORKDIR /app/raggr-frontend WORKDIR /app/raggr-frontend
RUN yarn build RUN yarn build
# Create database directory # Create ChromaDB and database directories
WORKDIR /app WORKDIR /app
RUN mkdir -p /app/database RUN mkdir -p /app/chromadb /app/database
# Make startup script executable # Make startup script executable
RUN chmod +x /app/startup-dev.sh RUN chmod +x /app/startup-dev.sh
# Set environment variables # Set environment variables
ENV PYTHONPATH=/app ENV PYTHONPATH=/app
ENV CHROMADB_PATH=/app/chromadb
ENV PYTHONUNBUFFERED=1 ENV PYTHONUNBUFFERED=1
# Expose port # Expose port
+1 -11
View File
@@ -1,11 +1,8 @@
.PHONY: deploy redeploy build up down restart logs migrate migrate-new frontend test .PHONY: deploy build up down restart logs migrate migrate-new frontend
# Build and deploy # Build and deploy
deploy: build up deploy: build up
redeploy:
git pull && $(MAKE) down && $(MAKE) up
build: build:
docker compose build raggr docker compose build raggr
@@ -32,13 +29,6 @@ migrate-new:
migrate-history: migrate-history:
docker compose exec raggr aerich history docker compose exec raggr aerich history
# Tests
test:
pytest tests/ -v
test-cov:
pytest tests/ -v --cov
# Frontend # Frontend
frontend: frontend:
cd raggr-frontend && yarn install && yarn build cd raggr-frontend && yarn install && yarn build
+43 -5
View File
@@ -1,9 +1,8 @@
import logging import logging
import os import os
from datetime import timedelta
from dotenv import load_dotenv from dotenv import load_dotenv
from quart import Quart, jsonify, render_template, send_from_directory from quart import Quart, jsonify, render_template, request, send_from_directory
from quart_jwt_extended import JWTManager, get_jwt_identity, jwt_refresh_token_required from quart_jwt_extended import JWTManager, get_jwt_identity, jwt_refresh_token_required
from tortoise import Tortoise from tortoise import Tortoise
@@ -15,6 +14,7 @@ import blueprints.users
import blueprints.whatsapp import blueprints.whatsapp
import blueprints.users.models import blueprints.users.models
from config.db import TORTOISE_CONFIG from config.db import TORTOISE_CONFIG
from main import consult_simba_oracle
# Load environment variables # Load environment variables
load_dotenv() load_dotenv()
@@ -38,8 +38,6 @@ app = Quart(
) )
app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY", "SECRET_KEY") app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY", "SECRET_KEY")
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(hours=1)
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=30)
app.config["MAX_CONTENT_LENGTH"] = 10 * 1024 * 1024 # 10 MB upload limit app.config["MAX_CONTENT_LENGTH"] = 10 * 1024 * 1024 # 10 MB upload limit
jwt = JWTManager(app) jwt = JWTManager(app)
@@ -77,6 +75,39 @@ async def serve_react_app(path):
return await render_template("index.html") return await render_template("index.html")
@app.route("/api/query", methods=["POST"])
@jwt_refresh_token_required
async def query():
current_user_uuid = get_jwt_identity()
user = await blueprints.users.models.User.get(id=current_user_uuid)
data = await request.get_json()
query = data.get("query")
conversation_id = data.get("conversation_id")
conversation = await blueprints.conversation.logic.get_conversation_by_id(
conversation_id
)
await conversation.fetch_related("messages")
await blueprints.conversation.logic.add_message_to_conversation(
conversation=conversation,
message=query,
speaker="user",
user=user,
)
transcript = await blueprints.conversation.logic.get_conversation_transcript(
user=user, conversation=conversation
)
response = consult_simba_oracle(input=query, transcript=transcript)
await blueprints.conversation.logic.add_message_to_conversation(
conversation=conversation,
message=response,
speaker="simba",
user=user,
)
return jsonify({"response": response})
@app.route("/api/messages", methods=["GET"]) @app.route("/api/messages", methods=["GET"])
@jwt_refresh_token_required @jwt_refresh_token_required
async def get_messages(): async def get_messages():
@@ -101,10 +132,17 @@ async def get_messages():
} }
) )
name = conversation.name
if len(messages) > 8:
name = await blueprints.conversation.logic.rename_conversation(
user=user,
conversation=conversation,
)
return jsonify( return jsonify(
{ {
"id": str(conversation.id), "id": str(conversation.id),
"name": conversation.name, "name": name,
"messages": messages, "messages": messages,
"created_at": conversation.created_at.isoformat(), "created_at": conversation.created_at.isoformat(),
"updated_at": conversation.updated_at.isoformat(), "updated_at": conversation.updated_at.isoformat(),
+21 -30
View File
@@ -1,3 +1,4 @@
import datetime
import json import json
import logging import logging
import uuid import uuid
@@ -19,8 +20,8 @@ from .agents import main_agent
from .logic import ( from .logic import (
add_message_to_conversation, add_message_to_conversation,
get_conversation_by_id, get_conversation_by_id,
rename_conversation,
) )
from .memory import get_memories_for_user
from .models import ( from .models import (
Conversation, Conversation,
PydConversation, PydConversation,
@@ -35,27 +36,15 @@ conversation_blueprint = Blueprint(
_SYSTEM_PROMPT = SIMBA_SYSTEM_PROMPT _SYSTEM_PROMPT = SIMBA_SYSTEM_PROMPT
async def _build_system_prompt_with_memories(user_id: str) -> str:
"""Append user memories to the base system prompt."""
memories = await get_memories_for_user(user_id)
if not memories:
return _SYSTEM_PROMPT
memory_block = "\n".join(f"- {m}" for m in memories)
return f"{_SYSTEM_PROMPT}\n\nUSER MEMORIES (facts the user has asked you to remember):\n{memory_block}"
def _build_messages_payload( def _build_messages_payload(
conversation, conversation, query_text: str, image_description: str | None = None
query_text: str,
image_description: str | None = None,
system_prompt: str | None = None,
) -> list: ) -> list:
recent_messages = ( recent_messages = (
conversation.messages[-10:] conversation.messages[-10:]
if len(conversation.messages) > 10 if len(conversation.messages) > 10
else conversation.messages else conversation.messages
) )
messages_payload = [{"role": "system", "content": system_prompt or _SYSTEM_PROMPT}] messages_payload = [{"role": "system", "content": _SYSTEM_PROMPT}]
for msg in recent_messages[:-1]: # Exclude the message we just added for msg in recent_messages[:-1]: # Exclude the message we just added
role = "user" if msg.speaker == "user" else "assistant" role = "user" if msg.speaker == "user" else "assistant"
text = msg.text text = msg.text
@@ -91,14 +80,10 @@ async def query():
user=user, user=user,
) )
system_prompt = await _build_system_prompt_with_memories(str(user.id)) messages_payload = _build_messages_payload(conversation, query)
messages_payload = _build_messages_payload(
conversation, query, system_prompt=system_prompt
)
payload = {"messages": messages_payload} payload = {"messages": messages_payload}
agent_config = {"configurable": {"user_id": str(user.id)}}
response = await main_agent.ainvoke(payload, config=agent_config) response = await main_agent.ainvoke(payload)
message = response.get("messages", [])[-1].content message = response.get("messages", [])[-1].content
await add_message_to_conversation( await add_message_to_conversation(
conversation=conversation, conversation=conversation,
@@ -178,19 +163,15 @@ async def stream_query():
logging.error(f"Failed to analyze image: {e}") logging.error(f"Failed to analyze image: {e}")
image_description = "[Image could not be analyzed]" image_description = "[Image could not be analyzed]"
system_prompt = await _build_system_prompt_with_memories(str(user.id))
messages_payload = _build_messages_payload( messages_payload = _build_messages_payload(
conversation, query_text or "", image_description, system_prompt=system_prompt conversation, query_text or "", image_description
) )
payload = {"messages": messages_payload} payload = {"messages": messages_payload}
agent_config = {"configurable": {"user_id": str(user.id)}}
async def event_generator(): async def event_generator():
final_message = None final_message = None
try: try:
async for event in main_agent.astream_events( async for event in main_agent.astream_events(payload, version="v2"):
payload, version="v2", config=agent_config
):
event_type = event.get("event") event_type = event.get("event")
if event_type == "on_tool_start": if event_type == "on_tool_start":
@@ -240,6 +221,8 @@ async def stream_query():
@jwt_refresh_token_required @jwt_refresh_token_required
async def get_conversation(conversation_id: str): async def get_conversation(conversation_id: str):
conversation = await Conversation.get(id=conversation_id) conversation = await Conversation.get(id=conversation_id)
current_user_uuid = get_jwt_identity()
user = await blueprints.users.models.User.get(id=current_user_uuid)
await conversation.fetch_related("messages") await conversation.fetch_related("messages")
# Manually serialize the conversation with messages # Manually serialize the conversation with messages
@@ -254,10 +237,18 @@ async def get_conversation(conversation_id: str):
"image_key": msg.image_key, "image_key": msg.image_key,
} }
) )
name = conversation.name
if len(messages) > 8 and "datetime" in name.lower():
name = await rename_conversation(
user=user,
conversation=conversation,
)
print(name)
return jsonify( return jsonify(
{ {
"id": str(conversation.id), "id": str(conversation.id),
"name": conversation.name, "name": name,
"messages": messages, "messages": messages,
"created_at": conversation.created_at.isoformat(), "created_at": conversation.created_at.isoformat(),
"updated_at": conversation.updated_at.isoformat(), "updated_at": conversation.updated_at.isoformat(),
@@ -271,7 +262,7 @@ async def create_conversation():
user_uuid = get_jwt_identity() user_uuid = get_jwt_identity()
user = await blueprints.users.models.User.get(id=user_uuid) user = await blueprints.users.models.User.get(id=user_uuid)
conversation = await Conversation.create( conversation = await Conversation.create(
name="New Conversation", name=f"{user.username} {datetime.datetime.now().timestamp}",
user=user, user=user,
) )
@@ -284,7 +275,7 @@ async def create_conversation():
async def get_all_conversations(): async def get_all_conversations():
user_uuid = get_jwt_identity() user_uuid = get_jwt_identity()
user = await blueprints.users.models.User.get(id=user_uuid) user = await blueprints.users.models.User.get(id=user_uuid)
conversations = Conversation.filter(user=user).order_by("-updated_at") conversations = Conversation.filter(user=user)
serialized_conversations = await PydListConversation.from_queryset(conversations) serialized_conversations = await PydListConversation.from_queryset(conversations)
return jsonify(serialized_conversations.model_dump()) return jsonify(serialized_conversations.model_dump())
+2 -31
View File
@@ -5,11 +5,9 @@ from dotenv import load_dotenv
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain.chat_models import BaseChatModel from langchain.chat_models import BaseChatModel
from langchain.tools import tool from langchain.tools import tool
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from tavily import AsyncTavilyClient from tavily import AsyncTavilyClient
from blueprints.conversation.memory import save_memory
from blueprints.rag.logic import query_vector_store from blueprints.rag.logic import query_vector_store
from utils.obsidian_service import ObsidianService from utils.obsidian_service import ObsidianService
from utils.ynab_service import YNABService from utils.ynab_service import YNABService
@@ -328,7 +326,7 @@ async def obsidian_search_notes(query: str) -> str:
return "Obsidian integration is not configured. Please set OBSIDIAN_VAULT_PATH environment variable." return "Obsidian integration is not configured. Please set OBSIDIAN_VAULT_PATH environment variable."
try: try:
# Query vector store for obsidian documents # Query ChromaDB for obsidian documents
serialized, docs = await query_vector_store(query=query) serialized, docs = await query_vector_store(query=query)
return serialized return serialized
@@ -591,35 +589,8 @@ async def obsidian_create_task(
return f"Error creating task: {str(e)}" return f"Error creating task: {str(e)}"
@tool
async def save_user_memory(content: str, config: RunnableConfig) -> str:
"""Save a fact or preference about the user for future conversations.
Use this tool when the user:
- Explicitly asks you to remember something ("remember that...", "keep in mind...")
- Shares a personal preference that would be useful in future conversations
(e.g., "I prefer metric units", "my cat's name is Luna")
- Tells you a meaningful personal fact (e.g., "I'm allergic to peanuts")
Do NOT save:
- Trivial or ephemeral info (e.g., "I'm tired today")
- Information already in the system prompt or documents
- Conversation-specific context that won't matter later
Args:
content: A concise statement of the fact or preference to remember.
Write it as a standalone sentence (e.g., "User prefers dark mode"
rather than "likes dark mode").
Returns:
Confirmation that the memory was saved.
"""
user_id = config["configurable"]["user_id"]
return await save_memory(user_id=user_id, content=content)
# Create tools list based on what's available # Create tools list based on what's available
tools = [get_current_date, simba_search, web_search, save_user_memory] tools = [get_current_date, simba_search, web_search]
if ynab_enabled: if ynab_enabled:
tools.extend( tools.extend(
[ [
+21 -7
View File
@@ -1,8 +1,9 @@
import tortoise.exceptions import tortoise.exceptions
from langchain_openai import ChatOpenAI
import blueprints.users.models import blueprints.users.models
from .models import Conversation, ConversationMessage from .models import Conversation, ConversationMessage, RenameConversationOutputSchema
async def create_conversation(name: str = "") -> Conversation: async def create_conversation(name: str = "") -> Conversation:
@@ -18,12 +19,6 @@ async def add_message_to_conversation(
image_key: str | None = None, image_key: str | None = None,
) -> ConversationMessage: ) -> ConversationMessage:
print(conversation, message, speaker) print(conversation, message, speaker)
# Name the conversation after the first user message
if speaker == "user" and not await conversation.messages.all().exists():
conversation.name = message[:100]
await conversation.save()
message = await ConversationMessage.create( message = await ConversationMessage.create(
text=message, text=message,
speaker=speaker, speaker=speaker,
@@ -66,3 +61,22 @@ async def get_conversation_transcript(
messages.append(f"{message.speaker} at {message.created_at}: {message.text}") messages.append(f"{message.speaker} at {message.created_at}: {message.text}")
return "\n".join(messages) return "\n".join(messages)
async def rename_conversation(
user: blueprints.users.models.User,
conversation: Conversation,
) -> str:
messages: str = await get_conversation_transcript(
user=user, conversation=conversation
)
llm = ChatOpenAI(model="gpt-4o-mini")
structured_llm = llm.with_structured_output(RenameConversationOutputSchema)
prompt = f"Summarize the following conversation into a sassy one-liner title:\n\n{messages}"
response = structured_llm.invoke(prompt)
new_name: str = response.get("title", "")
conversation.name = new_name
await conversation.save()
return new_name
-19
View File
@@ -1,19 +0,0 @@
from .models import UserMemory
async def get_memories_for_user(user_id: str) -> list[str]:
"""Return all memory content strings for a user, ordered by most recently updated."""
memories = await UserMemory.filter(user_id=user_id).order_by("-updated_at")
return [m.content for m in memories]
async def save_memory(user_id: str, content: str) -> str:
"""Save a new memory or touch an existing one (exact-match dedup)."""
existing = await UserMemory.filter(user_id=user_id, content=content).first()
if existing:
existing.updated_at = None # auto_now=True will refresh it on save
await existing.save(update_fields=["updated_at"])
return "Memory already exists (refreshed)."
await UserMemory.create(user_id=user_id, content=content)
return "Memory saved."
+7 -11
View File
@@ -1,4 +1,5 @@
import enum import enum
from dataclasses import dataclass
from tortoise import fields from tortoise import fields
from tortoise.contrib.pydantic import ( from tortoise.contrib.pydantic import (
@@ -8,6 +9,12 @@ from tortoise.contrib.pydantic import (
from tortoise.models import Model from tortoise.models import Model
@dataclass
class RenameConversationOutputSchema:
title: str
justification: str
class Speaker(enum.Enum): class Speaker(enum.Enum):
USER = "user" USER = "user"
SIMBA = "simba" SIMBA = "simba"
@@ -40,17 +47,6 @@ class ConversationMessage(Model):
table = "conversation_messages" table = "conversation_messages"
class UserMemory(Model):
id = fields.UUIDField(primary_key=True)
user = fields.ForeignKeyField("models.User", related_name="memories")
content = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True)
updated_at = fields.DatetimeField(auto_now=True)
class Meta:
table = "user_memories"
PydConversationMessage = pydantic_model_creator(ConversationMessage) PydConversationMessage = pydantic_model_creator(ConversationMessage)
PydConversation = pydantic_model_creator( PydConversation = pydantic_model_creator(
Conversation, name="Conversation", allow_cycles=True, exclude=("user",) Conversation, name="Conversation", allow_cycles=True, exclude=("user",)
+2 -5
View File
@@ -1,4 +1,4 @@
SIMBA_SYSTEM_PROMPT = """You are Simba, Ryan's helpful personal assistant. You're named after his orange cat. You have a warm, friendly personality with a light cat-themed touch, but your priority is always being genuinely useful — give thorough, detailed answers and think things through carefully. When asked about Simba the cat, you speak as him in first person. For everything else, you're just a great assistant who happens to have a cat's name. SIMBA_SYSTEM_PROMPT = """You are a helpful cat assistant named Simba that understands veterinary terms. When there are questions to you specifically, they are referring to Simba the cat. Answer the user in as if you were a cat named Simba. Don't act too catlike. Be assertive.
SIMBA FACTS (as of January 2026): SIMBA FACTS (as of January 2026):
- Name: Simba - Name: Simba
@@ -54,7 +54,4 @@ You have access to Ryan's daily journal notes. Each note lives at journal/YYYY/Y
- Use journal_get_tasks to list tasks (done/pending) for today or a specific date - Use journal_get_tasks to list tasks (done/pending) for today or a specific date
- Use journal_add_task to add a new task to today's (or a given date's) note - Use journal_add_task to add a new task to today's (or a given date's) note
- Use journal_complete_task to check off a task as done - Use journal_complete_task to check off a task as done
Use these tools when Ryan asks about today's tasks, wants to add something to his list, or wants to mark a task complete. Use these tools when Ryan asks about today's tasks, wants to add something to his list, or wants to mark a task complete."""
USER MEMORY:
You can remember facts about the user across conversations using the save_user_memory tool. When a user explicitly asks you to remember something, or shares a meaningful preference or personal fact, save it. Saved memories will automatically appear at the end of this prompt in future conversations under "USER MEMORIES"."""
+9 -7
View File
@@ -1,12 +1,7 @@
from quart import Blueprint, jsonify from quart import Blueprint, jsonify
from quart_jwt_extended import jwt_refresh_token_required from quart_jwt_extended import jwt_refresh_token_required
from .logic import ( from .logic import fetch_obsidian_documents, get_vector_store_stats, index_documents, index_obsidian_documents, vector_store
delete_all_documents,
get_vector_store_stats,
index_documents,
index_obsidian_documents,
)
from blueprints.users.decorators import admin_required from blueprints.users.decorators import admin_required
rag_blueprint = Blueprint("rag_api", __name__, url_prefix="/api/rag") rag_blueprint = Blueprint("rag_api", __name__, url_prefix="/api/rag")
@@ -37,7 +32,14 @@ async def trigger_index():
async def trigger_reindex(): async def trigger_reindex():
"""Clear and reindex all documents. Admin only.""" """Clear and reindex all documents. Admin only."""
try: try:
delete_all_documents() # Clear existing documents
collection = vector_store._collection
all_docs = collection.get()
if all_docs["ids"]:
collection.delete(ids=all_docs["ids"])
# Reindex
await index_documents() await index_documents()
stats = get_vector_store_stats() stats = get_vector_store_stats()
return jsonify({"status": "success", "stats": stats}) return jsonify({"status": "success", "stats": stats})
+28 -167
View File
@@ -1,14 +1,11 @@
import datetime import datetime
import logging
import os import os
import re
from dotenv import load_dotenv from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings from langchain_openai import OpenAIEmbeddings
from langchain_postgres import PGVector
from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_text_splitters import RecursiveCharacterTextSplitter
from sqlalchemy import create_engine, text
from .fetchers import PaperlessNGXService from .fetchers import PaperlessNGXService
from utils.obsidian_service import ObsidianService from utils.obsidian_service import ObsidianService
@@ -16,51 +13,13 @@ from utils.obsidian_service import ObsidianService
# Load environment variables # Load environment variables
load_dotenv() load_dotenv()
logger = logging.getLogger(__name__) embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
_embedding_server_url = os.getenv("EMBEDDING_SERVER_URL") vector_store = Chroma(
_embedding_model = os.getenv("EMBEDDING_MODEL_NAME", "text-embedding-3-small")
if _embedding_server_url:
embeddings = OpenAIEmbeddings(
model=_embedding_model,
base_url=_embedding_server_url,
api_key="not-needed",
check_embedding_ctx_length=False,
)
else:
embeddings = OpenAIEmbeddings(model=_embedding_model)
# Convert Tortoise-style postgres:// URL to SQLAlchemy-style postgresql+psycopg://
_db_url = os.getenv(
"DATABASE_URL", "postgres://raggr:raggr_dev_password@localhost:5432/raggr"
)
_pgvector_url = _db_url.replace("postgres://", "postgresql+psycopg://")
# Lazy-initialized vector store (defers DB connection to first use)
_vector_store = None
def _get_vector_store() -> PGVector:
global _vector_store
if _vector_store is None:
_vector_store = PGVector(
embeddings=embeddings,
collection_name="simba_docs", collection_name="simba_docs",
connection=_pgvector_url, embedding_function=embeddings,
use_jsonb=True, persist_directory=os.getenv("CHROMADB_PATH", ""),
create_extension=False, # created by docker init script )
async_mode=True,
)
return _vector_store
def _get_engine():
"""Get a SQLAlchemy engine for direct queries."""
if not hasattr(_get_engine, "_engine"):
_get_engine._engine = create_engine(_pgvector_url)
return _get_engine._engine
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000, # chunk size (characters) chunk_size=1000, # chunk size (characters)
@@ -69,22 +28,6 @@ text_splitter = RecursiveCharacterTextSplitter(
) )
def _get_collection_id():
"""Get the UUID of our collection from the langchain_pg_collection table."""
engine = _get_engine()
try:
with engine.connect() as conn:
result = conn.execute(
text("SELECT uuid FROM langchain_pg_collection WHERE name = :name"),
{"name": "simba_docs"},
)
row = result.fetchone()
return row[0] if row else None
except Exception:
# Table doesn't exist yet (first run before any indexing)
return None
def date_to_epoch(date_str: str) -> float: def date_to_epoch(date_str: str) -> float:
split_date = date_str.split("-") split_date = date_str.split("-")
date = datetime.datetime( date = datetime.datetime(
@@ -115,43 +58,12 @@ async def fetch_documents_from_paperless_ngx() -> list[Document]:
return documents return documents
def _sanitize_text(text_content: str) -> str:
"""Strip non-printable and invalid characters that break embedding tokenizers."""
# Remove null bytes and control characters (keep newlines and tabs)
text_content = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f]", "", text_content)
# Remove Unicode surrogates and other problematic Unicode
text_content = re.sub(r"[\ud800-\udfff\ufffe\uffff]", "", text_content)
# Remove replacement character clusters
text_content = text_content.replace("\ufffd", "")
# Collapse excessive whitespace
text_content = re.sub(r" {3,}", " ", text_content)
return text_content.strip()
def _sanitize_documents(documents: list[Document]) -> list[Document]:
"""Sanitize page_content of all documents for embedding compatibility."""
for doc in documents:
doc.page_content = _sanitize_text(doc.page_content)
return [doc for doc in documents if doc.page_content]
async def index_documents(): async def index_documents():
"""Index Paperless-NGX documents into vector store.""" """Index Paperless-NGX documents into vector store."""
documents = await fetch_documents_from_paperless_ngx() documents = await fetch_documents_from_paperless_ngx()
splits = text_splitter.split_documents(documents) splits = text_splitter.split_documents(documents)
splits = _sanitize_documents(splits) await vector_store.aadd_documents(documents=splits)
logger.info(f"Indexing {len(splits)} chunks from {len(documents)} documents")
vector_store = _get_vector_store()
for i, split in enumerate(splits):
try:
await vector_store.aadd_documents(documents=[split])
except Exception as e:
logger.error(
f"Failed to embed chunk {i} from {split.metadata.get('filename', 'unknown')}: {e}"
)
logger.debug(f"Chunk content preview: {split.page_content[:200]!r}")
raise
async def fetch_obsidian_documents() -> list[Document]: async def fetch_obsidian_documents() -> list[Document]:
@@ -180,17 +92,13 @@ async def fetch_obsidian_documents() -> list[Document]:
"filepath": parsed["filepath"], "filepath": parsed["filepath"],
"tags": parsed["tags"], "tags": parsed["tags"],
"created_at": parsed["metadata"].get("created_at"), "created_at": parsed["metadata"].get("created_at"),
**{ **{k: v for k, v in parsed["metadata"].items() if k not in ["created_at", "created_by"]},
k: v
for k, v in parsed["metadata"].items()
if k not in ["created_at", "created_by"]
},
}, },
) )
documents.append(document) documents.append(document)
except Exception as e: except Exception as e:
logger.warning(f"Error reading {md_path}: {e}") print(f"Error reading {md_path}: {e}")
continue continue
return documents return documents
@@ -201,26 +109,26 @@ async def index_obsidian_documents():
Deletes existing obsidian source chunks before re-indexing. Deletes existing obsidian source chunks before re-indexing.
""" """
obsidian_service = ObsidianService()
documents = await fetch_obsidian_documents() documents = await fetch_obsidian_documents()
if not documents: if not documents:
logger.info("No Obsidian documents found to index") print("No Obsidian documents found to index")
return {"indexed": 0} return {"indexed": 0}
# Delete existing obsidian chunks # Delete existing obsidian chunks
delete_documents_by_metadata("source", "obsidian") existing_results = vector_store.get(where={"source": "obsidian"})
if existing_results.get("ids"):
await vector_store.adelete(existing_results["ids"])
# Split, sanitize, and index documents # Split and index documents
splits = text_splitter.split_documents(documents) splits = text_splitter.split_documents(documents)
splits = _sanitize_documents(splits)
vector_store = _get_vector_store()
await vector_store.aadd_documents(documents=splits) await vector_store.aadd_documents(documents=splits)
return {"indexed": len(documents)} return {"indexed": len(documents)}
async def query_vector_store(query: str): async def query_vector_store(query: str):
vector_store = _get_vector_store()
retrieved_docs = await vector_store.asimilarity_search(query, k=2) retrieved_docs = await vector_store.asimilarity_search(query, k=2)
serialized = "\n\n".join( serialized = "\n\n".join(
(f"Source: {doc.metadata}\nContent: {doc.page_content}") (f"Source: {doc.metadata}\nContent: {doc.page_content}")
@@ -229,79 +137,32 @@ async def query_vector_store(query: str):
return serialized, retrieved_docs return serialized, retrieved_docs
def delete_all_documents():
"""Delete all documents from the vector store collection."""
collection_id = _get_collection_id()
if not collection_id:
return
engine = _get_engine()
with engine.connect() as conn:
conn.execute(
text("DELETE FROM langchain_pg_embedding WHERE collection_id = :cid"),
{"cid": collection_id},
)
conn.commit()
def delete_documents_by_metadata(key: str, value: str):
"""Delete documents matching a metadata key/value pair."""
collection_id = _get_collection_id()
if not collection_id:
return
engine = _get_engine()
with engine.connect() as conn:
conn.execute(
text(
"DELETE FROM langchain_pg_embedding "
"WHERE collection_id = :cid AND cmetadata->>:key = :value"
),
{"cid": collection_id, "key": key, "value": value},
)
conn.commit()
def get_vector_store_stats(): def get_vector_store_stats():
"""Get statistics about the vector store.""" """Get statistics about the vector store."""
collection_id = _get_collection_id() collection = vector_store._collection
count = 0 count = collection.count()
if collection_id:
engine = _get_engine()
with engine.connect() as conn:
result = conn.execute(
text(
"SELECT COUNT(*) FROM langchain_pg_embedding WHERE collection_id = :cid"
),
{"cid": collection_id},
)
count = result.scalar()
return { return {
"total_documents": count, "total_documents": count,
"collection_name": "simba_docs", "collection_name": collection.name,
} }
def list_all_documents(limit: int = 10): def list_all_documents(limit: int = 10):
"""List documents in the vector store with their metadata.""" """List documents in the vector store with their metadata."""
collection_id = _get_collection_id() collection = vector_store._collection
if not collection_id: results = collection.get(limit=limit, include=["metadatas", "documents"])
return []
engine = _get_engine()
with engine.connect() as conn:
result = conn.execute(
text(
"SELECT id, document, cmetadata FROM langchain_pg_embedding "
"WHERE collection_id = :cid LIMIT :limit"
),
{"cid": collection_id, "limit": limit},
)
documents = [] documents = []
for row in result: for i, doc_id in enumerate(results["ids"]):
documents.append( documents.append(
{ {
"id": str(row[0]), "id": doc_id,
"metadata": row[2], "metadata": results["metadatas"][i]
"content_preview": row[1][:200] if row[1] else None, if results.get("metadatas")
else None,
"content_preview": results["documents"][i][:200]
if results.get("documents")
else None,
} }
) )
+3 -3
View File
@@ -35,7 +35,7 @@ class OIDCUserService:
claims.get("preferred_username") or claims.get("name") or user.username claims.get("preferred_username") or claims.get("name") or user.username
) )
# Update LDAP groups from claims # Update LDAP groups from claims
user.ldap_groups = claims.get("groups") or [] user.ldap_groups = claims.get("groups", [])
await user.save() await user.save()
return user return user
@@ -48,7 +48,7 @@ class OIDCUserService:
user.oidc_subject = oidc_subject user.oidc_subject = oidc_subject
user.auth_provider = "oidc" user.auth_provider = "oidc"
user.password = None # Clear password user.password = None # Clear password
user.ldap_groups = claims.get("groups") or [] user.ldap_groups = claims.get("groups", [])
await user.save() await user.save()
return user return user
@@ -61,7 +61,7 @@ class OIDCUserService:
) )
# Extract LDAP groups from claims # Extract LDAP groups from claims
groups = claims.get("groups") or [] groups = claims.get("groups", [])
user = await User.create( user = await User.create(
id=uuid4(), id=uuid4(),
+4 -4
View File
@@ -2,7 +2,7 @@ version: "3.8"
services: services:
postgres: postgres:
image: pgvector/pgvector:pg16 image: postgres:16-alpine
ports: ports:
- "5432:5432" - "5432:5432"
environment: environment:
@@ -11,7 +11,6 @@ services:
- POSTGRES_DB=${POSTGRES_DB:-raggr} - POSTGRES_DB=${POSTGRES_DB:-raggr}
volumes: volumes:
- postgres_data:/var/lib/postgresql/data - postgres_data:/var/lib/postgresql/data
- ./docker/init-pgvector.sql:/docker-entrypoint-initdb.d/init-pgvector.sql
healthcheck: healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-raggr}"] test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-raggr}"]
interval: 10s interval: 10s
@@ -30,9 +29,8 @@ services:
- PAPERLESS_TOKEN=${PAPERLESS_TOKEN} - PAPERLESS_TOKEN=${PAPERLESS_TOKEN}
- BASE_URL=${BASE_URL} - BASE_URL=${BASE_URL}
- OLLAMA_URL=${OLLAMA_URL:-http://localhost:11434} - OLLAMA_URL=${OLLAMA_URL:-http://localhost:11434}
- CHROMADB_PATH=/app/data/chromadb
- OPENAI_API_KEY=${OPENAI_API_KEY} - OPENAI_API_KEY=${OPENAI_API_KEY}
- EMBEDDING_SERVER_URL=${EMBEDDING_SERVER_URL}
- EMBEDDING_MODEL_NAME=${EMBEDDING_MODEL_NAME}
- JWT_SECRET_KEY=${JWT_SECRET_KEY} - JWT_SECRET_KEY=${JWT_SECRET_KEY}
- LLAMA_SERVER_URL=${LLAMA_SERVER_URL} - LLAMA_SERVER_URL=${LLAMA_SERVER_URL}
- LLAMA_MODEL_NAME=${LLAMA_MODEL_NAME} - LLAMA_MODEL_NAME=${LLAMA_MODEL_NAME}
@@ -68,8 +66,10 @@ services:
postgres: postgres:
condition: service_healthy condition: service_healthy
volumes: volumes:
- chromadb_data:/app/data/chromadb
- ./obvault:/app/data/obsidian - ./obvault:/app/data/obsidian
restart: unless-stopped restart: unless-stopped
volumes: volumes:
chromadb_data:
postgres_data: postgres_data:
-1
View File
@@ -1 +0,0 @@
CREATE EXTENSION IF NOT EXISTS vector;
+278
View File
@@ -0,0 +1,278 @@
import argparse
import datetime
import logging
import os
import sqlite3
import time
from dotenv import load_dotenv
import chromadb
from utils.chunker import Chunker
from utils.cleaner import pdf_to_image, summarize_pdf_image
from llm import LLMClient
from scripts.query import QueryGenerator
from utils.request import PaperlessNGXService
_dotenv_loaded = load_dotenv()
client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
simba_docs = client.get_or_create_collection(name="simba_docs2")
feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup")
parser = argparse.ArgumentParser(
description="An LLM tool to query information about Simba <3"
)
parser.add_argument("query", type=str, help="questions about simba's health")
parser.add_argument(
"--reindex", action="store_true", help="re-index the simba documents"
)
parser.add_argument("--classify", action="store_true", help="test classification")
parser.add_argument("--index", help="index a file")
ppngx = PaperlessNGXService()
llm_client = LLMClient()
def index_using_pdf_llm(doctypes):
logging.info("reindex data...")
files = ppngx.get_data()
for file in files:
document_id: int = file["id"]
pdf_path = ppngx.download_pdf_from_id(id=document_id)
image_paths = pdf_to_image(filepath=pdf_path)
logging.info(f"summarizing {file}")
generated_summary = summarize_pdf_image(filepaths=image_paths)
file["content"] = generated_summary
chunk_data(files, simba_docs, doctypes=doctypes)
def date_to_epoch(date_str: str) -> float:
split_date = date_str.split("-")
date = datetime.datetime(
int(split_date[0]),
int(split_date[1]),
int(split_date[2]),
0,
0,
0,
)
return date.timestamp()
def chunk_data(docs, collection, doctypes):
# Step 2: Create chunks
chunker = Chunker(collection)
logging.info(f"chunking {len(docs)} documents")
texts: list[str] = [doc["content"] for doc in docs]
with sqlite3.connect("database/visited.db") as conn:
to_insert = []
c = conn.cursor()
for index, text in enumerate(texts):
metadata = {
"created_date": date_to_epoch(docs[index]["created_date"]),
"filename": docs[index]["original_file_name"],
"document_type": doctypes.get(docs[index]["document_type"], ""),
}
if doctypes:
metadata["type"] = doctypes.get(docs[index]["document_type"])
chunker.chunk_document(
document=text,
metadata=metadata,
)
to_insert.append((docs[index]["id"],))
c.executemany(
"INSERT INTO indexed_documents (paperless_id) values (?)", to_insert
)
conn.commit()
def chunk_text(texts: list[str], collection):
chunker = Chunker(collection)
for index, text in enumerate(texts):
metadata = {}
chunker.chunk_document(
document=text,
metadata=metadata,
)
def classify_query(query: str, transcript: str) -> bool:
logging.info("Starting query generation")
qg_start = time.time()
qg = QueryGenerator()
query_type = qg.get_query_type(input=query, transcript=transcript)
logging.info(query_type)
qg_end = time.time()
logging.info(f"Query generation took {qg_end - qg_start:.2f} seconds")
return query_type == "Simba"
def consult_oracle(
input: str,
collection,
transcript: str = "",
):
chunker = Chunker(collection)
start_time = time.time()
# Ask
logging.info("Starting query generation")
qg_start = time.time()
qg = QueryGenerator()
doctype_query = qg.get_doctype_query(input=input)
# metadata_filter = qg.get_query(input)
metadata_filter = {**doctype_query}
logging.info(metadata_filter)
qg_end = time.time()
logging.info(f"Query generation took {qg_end - qg_start:.2f} seconds")
logging.info("Starting embedding generation")
embedding_start = time.time()
embeddings = chunker.embedding_fx(inputs=[input])
embedding_end = time.time()
logging.info(
f"Embedding generation took {embedding_end - embedding_start:.2f} seconds"
)
logging.info("Starting collection query")
query_start = time.time()
results = collection.query(
query_texts=[input],
query_embeddings=embeddings,
where=metadata_filter,
)
query_end = time.time()
logging.info(f"Collection query took {query_end - query_start:.2f} seconds")
# Generate
logging.info("Starting LLM generation")
llm_start = time.time()
system_prompt = "You are a helpful assistant that understands veterinary terms."
transcript_prompt = f"Here is the message transcript thus far {transcript}."
prompt = f"""Using the following data, help answer the user's query by providing as many details as possible.
Using this data: {results}. {transcript_prompt if len(transcript) > 0 else ""}
Respond to this prompt: {input}"""
output = llm_client.chat(prompt=prompt, system_prompt=system_prompt)
llm_end = time.time()
logging.info(f"LLM generation took {llm_end - llm_start:.2f} seconds")
total_time = time.time() - start_time
logging.info(f"Total consult_oracle execution took {total_time:.2f} seconds")
return output
def llm_chat(input: str, transcript: str = "") -> str:
system_prompt = "You are a helpful assistant that understands veterinary terms."
transcript_prompt = f"Here is the message transcript thus far {transcript}."
prompt = f"""Answer the user in as if you were a cat named Simba. Don't act too catlike. Be assertive.
{transcript_prompt if len(transcript) > 0 else ""}
Respond to this prompt: {input}"""
output = llm_client.chat(prompt=prompt, system_prompt=system_prompt)
return output
def paperless_workflow(input):
# Step 1: Get the text
ppngx = PaperlessNGXService()
docs = ppngx.get_data()
chunk_data(docs, collection=simba_docs)
consult_oracle(input, simba_docs)
def consult_simba_oracle(input: str, transcript: str = ""):
is_simba_related = classify_query(query=input, transcript=transcript)
if is_simba_related:
logging.info("Query is related to simba")
return consult_oracle(
input=input,
collection=simba_docs,
transcript=transcript,
)
logging.info("Query is NOT related to simba")
return llm_chat(input=input, transcript=transcript)
def filter_indexed_files(docs):
with sqlite3.connect("database/visited.db") as conn:
c = conn.cursor()
c.execute(
"CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)"
)
c.execute("SELECT paperless_id FROM indexed_documents")
rows = c.fetchall()
conn.commit()
visited = {row[0] for row in rows}
return [doc for doc in docs if doc["id"] not in visited]
def reindex():
with sqlite3.connect("database/visited.db") as conn:
c = conn.cursor()
# Ensure the table exists before trying to delete from it
c.execute(
"CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)"
)
c.execute("DELETE FROM indexed_documents")
conn.commit()
# Delete all documents from the collection
all_docs = simba_docs.get()
if all_docs["ids"]:
simba_docs.delete(ids=all_docs["ids"])
logging.info("Fetching documents from Paperless-NGX")
ppngx = PaperlessNGXService()
docs = ppngx.get_data()
docs = filter_indexed_files(docs)
logging.info(f"Fetched {len(docs)} documents")
# Delete all chromadb data
ids = simba_docs.get(ids=None, limit=None, offset=0)
all_ids = ids["ids"]
if len(all_ids) > 0:
simba_docs.delete(ids=all_ids)
# Chunk documents
logging.info("Chunking documents now ...")
doctype_lookup = ppngx.get_doctypes()
chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
logging.info("Done chunking documents")
if __name__ == "__main__":
args = parser.parse_args()
if args.reindex:
reindex()
if args.classify:
consult_simba_oracle(input="yohohoho testing")
consult_simba_oracle(input="write an email")
consult_simba_oracle(input="how much does simba weigh")
if args.query:
logging.info("Consulting oracle ...")
print(
consult_oracle(
input=args.query,
collection=simba_docs,
)
)
else:
logging.info("please provide a query")
@@ -1,112 +0,0 @@
from tortoise import BaseDBAsyncClient
RUN_IN_TRANSACTION = True
async def upgrade(db: BaseDBAsyncClient) -> str:
return """
CREATE TABLE IF NOT EXISTS "user_memories" (
"id" UUID NOT NULL PRIMARY KEY,
"content" TEXT NOT NULL,
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"user_id" UUID NOT NULL REFERENCES "users" ("id") ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS "email_accounts" (
"id" UUID NOT NULL PRIMARY KEY,
"email_address" VARCHAR(255) NOT NULL UNIQUE,
"display_name" VARCHAR(255),
"imap_host" VARCHAR(255) NOT NULL,
"imap_port" INT NOT NULL DEFAULT 993,
"imap_username" VARCHAR(255) NOT NULL,
"imap_password" TEXT NOT NULL,
"is_active" BOOL NOT NULL DEFAULT True,
"last_error" TEXT,
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"user_id" UUID NOT NULL REFERENCES "users" ("id") ON DELETE CASCADE
);
COMMENT ON TABLE "email_accounts" IS 'Email account configuration for IMAP connections.';
CREATE TABLE IF NOT EXISTS "emails" (
"id" UUID NOT NULL PRIMARY KEY,
"message_id" VARCHAR(255) NOT NULL UNIQUE,
"subject" VARCHAR(500) NOT NULL,
"from_address" VARCHAR(255) NOT NULL,
"to_address" TEXT NOT NULL,
"date" TIMESTAMPTZ NOT NULL,
"body_text" TEXT,
"body_html" TEXT,
"chromadb_doc_id" VARCHAR(255),
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"expires_at" TIMESTAMPTZ NOT NULL,
"account_id" UUID NOT NULL REFERENCES "email_accounts" ("id") ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS "idx_emails_message_981ddd" ON "emails" ("message_id");
COMMENT ON TABLE "emails" IS 'Email message metadata and content.';
CREATE TABLE IF NOT EXISTS "email_sync_status" (
"id" UUID NOT NULL PRIMARY KEY,
"last_sync_date" TIMESTAMPTZ,
"last_message_uid" INT NOT NULL DEFAULT 0,
"message_count" INT NOT NULL DEFAULT 0,
"consecutive_failures" INT NOT NULL DEFAULT 0,
"last_failure_date" TIMESTAMPTZ,
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"account_id" UUID NOT NULL REFERENCES "email_accounts" ("id") ON DELETE CASCADE
);
COMMENT ON TABLE "email_sync_status" IS 'Tracks sync progress and state per email account.';"""
async def downgrade(db: BaseDBAsyncClient) -> str:
return """
DROP TABLE IF EXISTS "user_memories";
DROP TABLE IF EXISTS "email_accounts";
DROP TABLE IF EXISTS "emails";
DROP TABLE IF EXISTS "email_sync_status";"""
MODELS_STATE = (
"eJztXGtv2zYU/SuCPrVAFjTPbcUwwE7czVudDLGz9ZFCoCXa1ixRGkk1NYr+911Skq0HZV"
"t+RUr1oU1C8lLU4SV57tGVvuquZ2GHHV955DOmDHHbI/pr7atOkIvhF2X9kaYj31/UigKO"
"ho40MBMtZQ0aMk6RyaFyhByGocjCzKS2H12MBI4jCj0TGtpkvCgKiP1fgA3ujTGfYAoVHz"
"9BsU0s/AWz+E9/aoxs7FipcduWuLYsN/jMl2X3993rN7KluNzQMD0ncMmitT/jE4/MmweB"
"bR0LG1E3xgRTxLGVuA0xyui246JwxFDAaYDnQ7UWBRYeocARYOi/jAJiCgw0eSXx3/mveg"
"l4AGoBrU24wOLrt/CuFvcsS3VxqavfW3cvzi5fyrv0GB9TWSkR0b9JQ8RRaCpxXQApf+ag"
"vJogqoYybp8BEwa6CYxxwQLHhQ/FQMYAbYaa7qIvhoPJmE/gz9OLiyUw/t26k0hCKwmlB3"
"4dev1NVHUa1glIFxCaFItbNhDPA3kNNdx2sRrMtGUGUisyPY5/qSjAcA/WLXFm0SJYgu+g"
"2+v0B63eX+JOXMb+cyRErUFH1JzK0lmm9MVlZirmnWj/dAe/a+JP7cPtTSfr+/N2gw+6GB"
"MKuGcQ79FAVmK9xqUxMKmJDXxrw4lNWzYT+6QTGw0+Ma8MU6PcCZIw2eIYicZ2wEnc/NAQ"
"R+9oqjwzBBh58N54FNtj8ieeSQi7MA5ETNVhEZGO+6ibqoK2KF2MgqLHORtJOgXcHdwT5u"
"Hp2epfta47usRwiMzpI6KWUQCmixlDY8zygLYjyzd/3mFnTs3UWCYJXC/ssZq7ShG2Eivv"
"1EtglEIvX+WeutkSROC+reja4kpL0FnBghMgrkeGjeRENqS41qSY4y+KI38ApWoo4/Z1Ic"
"XLjvLOu0HqFI+p74te693L1En+9vbmt7h5gipfvb1tNwz5ORKpPENmPkZTFRkQAWSHBG6O"
"CqRmN2H+xEtHv+937l5r4kR/IP1ur916rTHbHSJ9vSlORZknr9YIMk9eFcaYoiq9gGwXTh"
"ZjimdlQvWU0Ub4Hp56pYG8ODldA0loVQilrMtsRslDu9yRqTDd5flZ03DAzIiHW4YFWS2y"
"siiujA8U7lI2TtgnKxbxVw+7Hp3pCjKcqF3KgWUQ5IqGdsN9nwH3hYtwTErR34RJw4AbBv"
"xdMeBGI34WE1sdjbham2FdROIKs8AtVOJ9s78i3rea8TVMr/5MT8xj2cf/SZu6cL0DpAD4"
"iLFHjyo8s20TRGdqMJNWGTCHMx5GU5VTaJaA1xa8N3m6A2Tt7k3r7r2aOsftk37bfj/otD"
"LoYhfZThkvnRvsxkVXr/hdOujJq/Xkw2X6YU5AfJwgzmBLN0jgDosEWzWYCtOdiImHRfVs"
"HVDPijE9y0EqnczARNyeauF7noMRWeKgSdvs8gfjfW2mZY/qEuv/9vZtav23u9nQ+L7X7o"
"DzSpihkR1Soe7NQAnuxEUmcIQpVuiKK1Z/xraGHntyuc42kI2QErvAZdZjPdsyDRYM/8Wm"
"IlotBjRrV0Mw93LqQ/w4MXzqfbatcltqzvBwVEp3PBM5W3DRzBOadbbVi+Jt9SK3rToW8o"
"0x9QJfkRLzR//2Rg1pxiwD6D2Bu/xo2SY/0hyb8U97g/fjp/3wfHHny1XJrACZIVaig0aV"
"fJbiVaNKPtOJnSfG5VShVVmFudc0dpNaWOWINJ9SmFwRySeUm2ORfihaPc9fC4qQHyPT9A"
"JhthUgHdFXK+yqZpDsU1yVsOgKdbUTKxPF8qqcnvX0VV12p0WZp/CTI6H2aYhYWvRQ9ljP"
"oLSOzQN5IH3uwbal+YgybGlyUJps+GjzCYTTP1hoplEs2sNgjrW3NpkyjXva1YR6LrpuP5"
"CRR7XPEDPAD4YRNSeaiXw0tCHsg5UoR9bowDvih1vowJErKJ92Fccwaas6Cm17iQk3CK+3"
"jayfXFK/WEuxvFiiWF7kFcsR7CKCGIEfK86oYjSzdvWEdC++CbyyENAlye1eHeE8dIKPiI"
"jKxlqxTT2jrJpEVfFtL42Xh541M8q+9ZEyqkl+9aGXhcRowl3F47sVwMZGDbDqhELJsuGK"
"MLSSzE1hWhOQm5f5G+VsU0kUf/Ft6G2DiU1b1nNiazKRax3WkXJVMjszbdUkaMaA5DEsna"
"NZXxHwKJOrmXaSKqVrpjAuEhYTc7BCX0zJv+vqjJGNkAlH9jigUhvWhMrX7bX+EsUES7mL"
"FamOJXpIaJBzK4otITfCGEMVEhOTznxwNC1OpTtK9KExzDlcnh09EKFuxt2AO/OAHWv9wP"
"c9ypnmgovZvoPjFkzzMZXvgjYaZUU0yshpy8tBOcNGqYwVC5v5DpoZZVOAs3ZN7JB4S9s3"
"JuDYZeBMGdVFXTsUmGJ/zoPZJQXCQcomg6W9P27y889nW0Apv0Xzw+nJ+Y/nP51dnv8ETe"
"RQ5iU/LgE3nzkpMdgktT9n2DhjxhkLk/w7MQ8p1rRyPdQF3UMLWzYE2sBHPit8d2lKdcru"
"gOnU84O/wtnUDmLcwJR6iizVYpdNW9XkmG9e7G70wiaFspnY5sXu5sXu6r7YnRE2dpGEWS"
"8s0zlTM2IaoSq3AyD60Ft/3lmNINm7fJxApkhBToO3SkTOTNxqHXkA1VOmCTvNp57YdJjM"
"PBWdYCm74qRQnNeRS/cgdOSegBn+MU1w2tBYnL1g4/pHYSF0ZkJf2JqnxsJOGCnHI+gwoF"
"gTdzeFgYg0Vxaqx5oNsR92MfTvhB0LA8matQn86kDzRkWuiIosIxrptJuka+Wtd8DtqhUh"
"VYjKrfUoWE5JnIkcqNZJoVaoMj2cZPhqC2q+Y8EwxqDgaXAhgDm77xI90T02AyE8GdExoS"
"AxhSAWmX+XWMolGaGw+Q6d7aDZpJ94k26UlOeppDR5WM8sD2vfSQ31z8JqYWqbE10RPUc1"
"R8uCZrRoU5kv5xU/S1+TEUcT+KTJMjvhIcVho3j9Xflt8+KH6QmTujzoPcQHc2BplAAxal"
"5PAPfyGbfCj3MXfxin+OPcB/sozt4O3Z19FKfENzZ2f7x8+x8fHBMe"
)
+2 -13
View File
@@ -5,8 +5,7 @@ description = "Add your description here"
readme = "README.md" readme = "README.md"
requires-python = ">=3.13" requires-python = ">=3.13"
dependencies = [ dependencies = [
"langchain-postgres>=0.0.13", "chromadb>=1.1.0",
"psycopg[binary]>=3.1.0",
"python-dotenv>=1.0.0", "python-dotenv>=1.0.0",
"flask>=3.1.2", "flask>=3.1.2",
"httpx>=0.28.1", "httpx>=0.28.1",
@@ -31,6 +30,7 @@ dependencies = [
"asyncpg>=0.30.0", "asyncpg>=0.30.0",
"langchain-openai>=1.1.6", "langchain-openai>=1.1.6",
"langchain>=1.2.0", "langchain>=1.2.0",
"langchain-chroma>=1.0.0",
"langchain-community>=0.4.1", "langchain-community>=0.4.1",
"jq>=1.10.0", "jq>=1.10.0",
"tavily-python>=0.7.17", "tavily-python>=0.7.17",
@@ -42,17 +42,6 @@ dependencies = [
"aioboto3>=13.0.0", "aioboto3>=13.0.0",
] ]
[project.optional-dependencies]
test = [
"pytest>=8.0.0",
"pytest-asyncio>=0.25.0",
"pytest-cov>=6.0.0",
]
[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_mode = "auto"
[tool.aerich] [tool.aerich]
tortoise_orm = "config.db.TORTOISE_CONFIG" tortoise_orm = "config.db.TORTOISE_CONFIG"
location = "./migrations" location = "./migrations"
+38 -3
View File
@@ -1,13 +1,48 @@
import { useState, useEffect } from "react";
import "./App.css"; import "./App.css";
import { AuthProvider } from "./contexts/AuthContext"; import { AuthProvider } from "./contexts/AuthContext";
import { ChatScreen } from "./components/ChatScreen"; import { ChatScreen } from "./components/ChatScreen";
import { LoginScreen } from "./components/LoginScreen"; import { LoginScreen } from "./components/LoginScreen";
import { useAuthCheck } from "./hooks/useAuthCheck"; import { conversationService } from "./api/conversationService";
import catIcon from "./assets/cat.png"; import catIcon from "./assets/cat.png";
const AppContainer = () => { const AppContainer = () => {
const { isAuthenticated, isChecking, isAdmin, setAuthenticated } = useAuthCheck(); const [isAuthenticated, setAuthenticated] = useState<boolean>(false);
const [isChecking, setIsChecking] = useState<boolean>(true);
useEffect(() => {
const checkAuth = async () => {
const accessToken = localStorage.getItem("access_token");
const refreshToken = localStorage.getItem("refresh_token");
// No tokens at all, not authenticated
if (!accessToken && !refreshToken) {
setIsChecking(false);
setAuthenticated(false);
return;
}
// Try to verify token by making a request
try {
await conversationService.getAllConversations();
// If successful, user is authenticated
setAuthenticated(true);
} catch (error) {
// Token is invalid or expired
console.error("Authentication check failed:", error);
localStorage.removeItem("access_token");
localStorage.removeItem("refresh_token");
setAuthenticated(false);
} finally {
setIsChecking(false);
}
};
checkAuth();
}, []);
// Show loading state while checking authentication
if (isChecking) { if (isChecking) {
return ( return (
<div className="h-screen flex flex-col items-center justify-center bg-cream gap-4"> <div className="h-screen flex flex-col items-center justify-center bg-cream gap-4">
@@ -26,7 +61,7 @@ const AppContainer = () => {
return ( return (
<> <>
{isAuthenticated ? ( {isAuthenticated ? (
<ChatScreen setAuthenticated={setAuthenticated} isAdmin={isAdmin} /> <ChatScreen setAuthenticated={setAuthenticated} />
) : ( ) : (
<LoginScreen setAuthenticated={setAuthenticated} /> <LoginScreen setAuthenticated={setAuthenticated} />
)} )}
+29 -15
View File
@@ -1,4 +1,4 @@
import { useState } from "react"; import { useEffect, useState } from "react";
import { X, Phone, PhoneOff, Pencil, Check, Mail, Copy } from "lucide-react"; import { X, Phone, PhoneOff, Pencil, Check, Mail, Copy } from "lucide-react";
import { userService, type AdminUserRecord } from "../api/userService"; import { userService, type AdminUserRecord } from "../api/userService";
import { cn } from "../lib/utils"; import { cn } from "../lib/utils";
@@ -12,19 +12,27 @@ import {
TableHeader, TableHeader,
TableRow, TableRow,
} from "./ui/table"; } from "./ui/table";
import { useAdminUsers } from "../hooks/useAdminUsers";
type Props = { type Props = {
onClose: () => void; onClose: () => void;
}; };
export const AdminPanel = ({ onClose }: Props) => { export const AdminPanel = ({ onClose }: Props) => {
const { users, loading, updateUser } = useAdminUsers(); const [users, setUsers] = useState<AdminUserRecord[]>([]);
const [loading, setLoading] = useState(true);
const [editingId, setEditingId] = useState<string | null>(null); const [editingId, setEditingId] = useState<string | null>(null);
const [editValue, setEditValue] = useState(""); const [editValue, setEditValue] = useState("");
const [rowError, setRowError] = useState<Record<string, string>>({}); const [rowError, setRowError] = useState<Record<string, string>>({});
const [rowSuccess, setRowSuccess] = useState<Record<string, string>>({}); const [rowSuccess, setRowSuccess] = useState<Record<string, string>>({});
useEffect(() => {
userService
.adminListUsers()
.then(setUsers)
.catch(() => {})
.finally(() => setLoading(false));
}, []);
const startEdit = (user: AdminUserRecord) => { const startEdit = (user: AdminUserRecord) => {
setEditingId(user.id); setEditingId(user.id);
setEditValue(user.whatsapp_number ?? ""); setEditValue(user.whatsapp_number ?? "");
@@ -41,8 +49,8 @@ export const AdminPanel = ({ onClose }: Props) => {
setRowError((p) => ({ ...p, [userId]: "" })); setRowError((p) => ({ ...p, [userId]: "" }));
try { try {
const updated = await userService.adminSetWhatsapp(userId, editValue); const updated = await userService.adminSetWhatsapp(userId, editValue);
updateUser(userId, () => updated); setUsers((p) => p.map((u) => (u.id === userId ? updated : u)));
setRowSuccess((p) => ({ ...p, [userId]: "Saved" })); setRowSuccess((p) => ({ ...p, [userId]: "Saved" }));
setEditingId(null); setEditingId(null);
setTimeout(() => setRowSuccess((p) => ({ ...p, [userId]: "" })), 2000); setTimeout(() => setRowSuccess((p) => ({ ...p, [userId]: "" })), 2000);
} catch (err) { } catch (err) {
@@ -57,8 +65,10 @@ export const AdminPanel = ({ onClose }: Props) => {
setRowError((p) => ({ ...p, [userId]: "" })); setRowError((p) => ({ ...p, [userId]: "" }));
try { try {
await userService.adminUnlinkWhatsapp(userId); await userService.adminUnlinkWhatsapp(userId);
updateUser(userId, (u) => ({ ...u, whatsapp_number: null })); setUsers((p) =>
setRowSuccess((p) => ({ ...p, [userId]: "Unlinked" })); p.map((u) => (u.id === userId ? { ...u, whatsapp_number: null } : u)),
);
setRowSuccess((p) => ({ ...p, [userId]: "Unlinked ✓" }));
setTimeout(() => setRowSuccess((p) => ({ ...p, [userId]: "" })), 2000); setTimeout(() => setRowSuccess((p) => ({ ...p, [userId]: "" })), 2000);
} catch (err) { } catch (err) {
setRowError((p) => ({ setRowError((p) => ({
@@ -72,8 +82,8 @@ export const AdminPanel = ({ onClose }: Props) => {
setRowError((p) => ({ ...p, [userId]: "" })); setRowError((p) => ({ ...p, [userId]: "" }));
try { try {
const updated = await userService.adminToggleEmail(userId); const updated = await userService.adminToggleEmail(userId);
updateUser(userId, () => updated); setUsers((p) => p.map((u) => (u.id === userId ? updated : u)));
setRowSuccess((p) => ({ ...p, [userId]: "Email enabled" })); setRowSuccess((p) => ({ ...p, [userId]: "Email enabled" }));
setTimeout(() => setRowSuccess((p) => ({ ...p, [userId]: "" })), 2000); setTimeout(() => setRowSuccess((p) => ({ ...p, [userId]: "" })), 2000);
} catch (err) { } catch (err) {
setRowError((p) => ({ setRowError((p) => ({
@@ -87,8 +97,10 @@ export const AdminPanel = ({ onClose }: Props) => {
setRowError((p) => ({ ...p, [userId]: "" })); setRowError((p) => ({ ...p, [userId]: "" }));
try { try {
await userService.adminDisableEmail(userId); await userService.adminDisableEmail(userId);
updateUser(userId, (u) => ({ ...u, email_enabled: false, email_address: null })); setUsers((p) =>
setRowSuccess((p) => ({ ...p, [userId]: "Email disabled" })); p.map((u) => (u.id === userId ? { ...u, email_enabled: false, email_address: null } : u)),
);
setRowSuccess((p) => ({ ...p, [userId]: "Email disabled ✓" }));
setTimeout(() => setRowSuccess((p) => ({ ...p, [userId]: "" })), 2000); setTimeout(() => setRowSuccess((p) => ({ ...p, [userId]: "" })), 2000);
} catch (err) { } catch (err) {
setRowError((p) => ({ setRowError((p) => ({
@@ -100,7 +112,7 @@ export const AdminPanel = ({ onClose }: Props) => {
const copyToClipboard = (text: string, userId: string) => { const copyToClipboard = (text: string, userId: string) => {
navigator.clipboard.writeText(text); navigator.clipboard.writeText(text);
setRowSuccess((p) => ({ ...p, [userId]: "Copied" })); setRowSuccess((p) => ({ ...p, [userId]: "Copied" }));
setTimeout(() => setRowSuccess((p) => ({ ...p, [userId]: "" })), 2000); setTimeout(() => setRowSuccess((p) => ({ ...p, [userId]: "" })), 2000);
}; };
@@ -116,6 +128,7 @@ export const AdminPanel = ({ onClose }: Props) => {
"border border-sand-light/60", "border border-sand-light/60",
)} )}
> >
{/* Header */}
<div className="flex items-center justify-between px-6 py-4 border-b border-sand-light/60"> <div className="flex items-center justify-between px-6 py-4 border-b border-sand-light/60">
<div className="flex items-center gap-2.5"> <div className="flex items-center gap-2.5">
<div className="w-8 h-8 rounded-xl bg-leaf-pale flex items-center justify-center"> <div className="w-8 h-8 rounded-xl bg-leaf-pale flex items-center justify-center">
@@ -133,6 +146,7 @@ export const AdminPanel = ({ onClose }: Props) => {
</button> </button>
</div> </div>
{/* Body */}
<div className="overflow-y-auto flex-1 rounded-b-3xl"> <div className="overflow-y-auto flex-1 rounded-b-3xl">
{loading ? ( {loading ? (
<div className="px-6 py-12 text-center text-warm-gray text-sm"> <div className="px-6 py-12 text-center text-warm-gray text-sm">
@@ -141,7 +155,7 @@ export const AdminPanel = ({ onClose }: Props) => {
<span className="loading-dot w-2 h-2 rounded-full bg-amber-soft inline-block" /> <span className="loading-dot w-2 h-2 rounded-full bg-amber-soft inline-block" />
<span className="loading-dot w-2 h-2 rounded-full bg-amber-soft inline-block" /> <span className="loading-dot w-2 h-2 rounded-full bg-amber-soft inline-block" />
</div> </div>
Loading users... Loading users
</div> </div>
) : ( ) : (
<Table> <Table>
@@ -190,7 +204,7 @@ export const AdminPanel = ({ onClose }: Props) => {
: "text-warm-gray/40 italic", : "text-warm-gray/40 italic",
)} )}
> >
{user.whatsapp_number ?? "\u2014"} {user.whatsapp_number ?? ""}
</span> </span>
{rowSuccess[user.id] && ( {rowSuccess[user.id] && (
<span className="text-xs text-leaf-dark"> <span className="text-xs text-leaf-dark">
@@ -221,7 +235,7 @@ export const AdminPanel = ({ onClose }: Props) => {
</button> </button>
</div> </div>
) : ( ) : (
<span className="text-sm text-warm-gray/40 italic">\u2014</span> <span className="text-sm text-warm-gray/40 italic"></span>
)} )}
</div> </div>
</TableCell> </TableCell>
@@ -1,4 +1,3 @@
import React from "react";
import ReactMarkdown from "react-markdown"; import ReactMarkdown from "react-markdown";
import { cn } from "../lib/utils"; import { cn } from "../lib/utils";
@@ -7,7 +6,7 @@ type AnswerBubbleProps = {
loading?: boolean; loading?: boolean;
}; };
export const AnswerBubble = React.memo(({ text, loading }: AnswerBubbleProps) => { export const AnswerBubble = ({ text, loading }: AnswerBubbleProps) => {
return ( return (
<div className="flex justify-start message-enter"> <div className="flex justify-start message-enter">
<div <div
@@ -18,6 +17,7 @@ export const AnswerBubble = React.memo(({ text, loading }: AnswerBubbleProps) =>
"overflow-hidden", "overflow-hidden",
)} )}
> >
{/* amber accent bar */}
<div className="h-0.5 w-full bg-gradient-to-r from-amber-soft via-amber-glow/50 to-transparent" /> <div className="h-0.5 w-full bg-gradient-to-r from-amber-soft via-amber-glow/50 to-transparent" />
<div className="px-4 py-3"> <div className="px-4 py-3">
@@ -36,4 +36,4 @@ export const AnswerBubble = React.memo(({ text, loading }: AnswerBubbleProps) =>
</div> </div>
</div> </div>
); );
}); };
+214 -73
View File
@@ -1,5 +1,7 @@
import { useCallback, useState, useRef } from "react"; import { useEffect, useState, useRef } from "react";
import { LogOut, Shield, PanelLeftClose, PanelLeftOpen, Menu, X } from "lucide-react"; import { LogOut, Shield, PanelLeftClose, PanelLeftOpen, Menu, X } from "lucide-react";
import { conversationService } from "../api/conversationService";
import { userService } from "../api/userService";
import { QuestionBubble } from "./QuestionBubble"; import { QuestionBubble } from "./QuestionBubble";
import { AnswerBubble } from "./AnswerBubble"; import { AnswerBubble } from "./AnswerBubble";
import { ToolBubble } from "./ToolBubble"; import { ToolBubble } from "./ToolBubble";
@@ -7,94 +9,224 @@ import { MessageInput } from "./MessageInput";
import { ConversationList } from "./ConversationList"; import { ConversationList } from "./ConversationList";
import { AdminPanel } from "./AdminPanel"; import { AdminPanel } from "./AdminPanel";
import { cn } from "../lib/utils"; import { cn } from "../lib/utils";
import { useConversations } from "../hooks/useConversations";
import { useChat } from "../hooks/useChat";
import catIcon from "../assets/cat.png"; import catIcon from "../assets/cat.png";
type Message = {
text: string;
speaker: "simba" | "user" | "tool";
image_key?: string | null;
};
type Conversation = {
title: string;
id: string;
};
type ChatScreenProps = { type ChatScreenProps = {
setAuthenticated: (isAuth: boolean) => void; setAuthenticated: (isAuth: boolean) => void;
isAdmin: boolean;
}; };
export const ChatScreen = ({ setAuthenticated, isAdmin }: ChatScreenProps) => { const TOOL_MESSAGES: Record<string, string> = {
const [query, setQuery] = useState(""); simba_search: "🔍 Searching Simba's records...",
const [simbaMode, setSimbaMode] = useState(false); web_search: "🌐 Searching the web...",
const [showConversations, setShowConversations] = useState(false); get_current_date: "📅 Checking today's date...",
const [sidebarCollapsed, setSidebarCollapsed] = useState(false); ynab_budget_summary: "💰 Checking budget summary...",
const [showAdminPanel, setShowAdminPanel] = useState(false); ynab_search_transactions: "💳 Looking up transactions...",
ynab_category_spending: "📊 Analyzing category spending...",
ynab_insights: "📈 Generating budget insights...",
obsidian_search_notes: "📝 Searching notes...",
obsidian_read_note: "📖 Reading note...",
obsidian_create_note: "✏️ Saving note...",
obsidian_create_task: "✅ Creating task...",
journal_get_today: "📔 Reading today's journal...",
journal_get_tasks: "📋 Getting tasks...",
journal_add_task: " Adding task...",
journal_complete_task: "✔️ Completing task...",
};
export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
const [query, setQuery] = useState<string>("");
const [simbaMode, setSimbaMode] = useState<boolean>(false);
const [messages, setMessages] = useState<Message[]>([]);
const [conversations, setConversations] = useState<Conversation[]>([]);
const [showConversations, setShowConversations] = useState<boolean>(false);
const [selectedConversation, setSelectedConversation] =
useState<Conversation | null>(null);
const [sidebarCollapsed, setSidebarCollapsed] = useState<boolean>(false);
const [isLoading, setIsLoading] = useState<boolean>(false);
const [isAdmin, setIsAdmin] = useState<boolean>(false);
const [showAdminPanel, setShowAdminPanel] = useState<boolean>(false);
const [pendingImage, setPendingImage] = useState<File | null>(null);
const messagesEndRef = useRef<HTMLDivElement>(null); const messagesEndRef = useRef<HTMLDivElement>(null);
const isLoadingRef = useRef(false); const isMountedRef = useRef<boolean>(true);
const abortControllerRef = useRef<AbortController | null>(null);
const simbaAnswers = ["meow.", "hiss...", "purrrrrr", "yowOWROWWowowr"];
const scrollToBottom = useCallback(() => { const scrollToBottom = () => {
requestAnimationFrame(() => { messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
messagesEndRef.current?.scrollIntoView({ };
behavior: isLoadingRef.current ? "instant" : "smooth",
}); useEffect(() => {
}); isMountedRef.current = true;
return () => {
isMountedRef.current = false;
abortControllerRef.current?.abort();
};
}, []); }, []);
const { const handleSelectConversation = (conversation: Conversation) => {
conversations,
selectedConversation,
selectConversation,
createConversation,
refreshConversations,
} = useConversations();
const onSessionExpired = useCallback(() => setAuthenticated(false), [setAuthenticated]);
const {
messages,
setMessages,
isLoading,
pendingImage,
setPendingImage,
sendMessage,
} = useChat({
selectedConversation,
createConversation,
refreshConversations,
onSessionExpired,
scrollToBottom,
});
// Keep ref in sync for scrollToBottom behavior
isLoadingRef.current = isLoading;
const handleSelectConversation = useCallback(
async (conversation: { title: string; id: string }) => {
setShowConversations(false); setShowConversations(false);
const loaded = await selectConversation(conversation); setSelectedConversation(conversation);
setMessages(loaded); const load = async () => {
}, try {
[selectConversation, setMessages], const fetched = await conversationService.getConversation(conversation.id);
setMessages(
fetched.messages.map((m) => ({ text: m.text, speaker: m.speaker, image_key: m.image_key })),
); );
} catch (err) {
console.error("Failed to load messages:", err);
}
};
load();
};
const handleCreateNewConversation = useCallback(async () => { const loadConversations = async () => {
await createConversation(); try {
setMessages([]); const fetched = await conversationService.getAllConversations();
}, [createConversation, setMessages]); const parsed = fetched.map((c) => ({ id: c.id, title: c.name }));
setConversations(parsed);
} catch (err) {
console.error("Failed to load conversations:", err);
}
};
const handleQuestionSubmit = useCallback(() => { const handleCreateNewConversation = async () => {
sendMessage(query, simbaMode); const newConv = await conversationService.createConversation();
setQuery(""); await loadConversations();
}, [query, simbaMode, sendMessage]); setSelectedConversation({ title: newConv.name, id: newConv.id });
};
const handleQueryChange = useCallback((event: React.ChangeEvent<HTMLTextAreaElement>) => { useEffect(() => {
setQuery(event.target.value); loadConversations();
userService.getMe().then((me) => setIsAdmin(me.is_admin)).catch(() => {});
}, []); }, []);
const handleKeyDown = useCallback((event: React.ChangeEvent<HTMLTextAreaElement>) => { useEffect(() => {
scrollToBottom();
}, [messages]);
useEffect(() => {
const load = async () => {
if (!selectedConversation) return;
try {
const conv = await conversationService.getConversation(selectedConversation.id);
setSelectedConversation({ id: conv.id, title: conv.name });
setMessages(conv.messages.map((m) => ({ text: m.text, speaker: m.speaker, image_key: m.image_key })));
} catch (err) {
console.error("Failed to load messages:", err);
}
};
load();
}, [selectedConversation?.id]);
const handleQuestionSubmit = async () => {
if ((!query.trim() && !pendingImage) || isLoading) return;
let activeConversation = selectedConversation;
if (!activeConversation) {
const newConv = await conversationService.createConversation();
activeConversation = { title: newConv.name, id: newConv.id };
setSelectedConversation(activeConversation);
setConversations((prev) => [activeConversation!, ...prev]);
}
// Capture pending image before clearing state
const imageFile = pendingImage;
const currMessages = messages.concat([{ text: query, speaker: "user" }]);
setMessages(currMessages);
setQuery("");
setPendingImage(null);
setIsLoading(true);
if (simbaMode) {
const randomElement = simbaAnswers[Math.floor(Math.random() * simbaAnswers.length)];
setMessages((prev) => prev.concat([{ text: randomElement, speaker: "simba" }]));
setIsLoading(false);
return;
}
const abortController = new AbortController();
abortControllerRef.current = abortController;
try {
// Upload image first if present
let imageKey: string | undefined;
if (imageFile) {
const uploadResult = await conversationService.uploadImage(
imageFile,
activeConversation.id,
);
imageKey = uploadResult.image_key;
// Update the user message with the image key
setMessages((prev) => {
const updated = [...prev];
// Find the last user message we just added
for (let i = updated.length - 1; i >= 0; i--) {
if (updated[i].speaker === "user") {
updated[i] = { ...updated[i], image_key: imageKey };
break;
}
}
return updated;
});
}
await conversationService.streamQuery(
query,
activeConversation.id,
(event) => {
if (!isMountedRef.current) return;
if (event.type === "tool_start") {
const friendly = TOOL_MESSAGES[event.tool] ?? `🔧 Using ${event.tool}...`;
setMessages((prev) => prev.concat([{ text: friendly, speaker: "tool" }]));
} else if (event.type === "response") {
setMessages((prev) => prev.concat([{ text: event.message, speaker: "simba" }]));
} else if (event.type === "error") {
console.error("Stream error:", event.message);
}
},
abortController.signal,
imageKey,
);
} catch (error) {
if (error instanceof Error && error.name === "AbortError") {
console.log("Request was aborted");
} else {
console.error("Failed to send query:", error);
if (error instanceof Error && error.message.includes("Session expired")) {
setAuthenticated(false);
}
}
} finally {
if (isMountedRef.current) setIsLoading(false);
abortControllerRef.current = null;
}
};
const handleQueryChange = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
setQuery(event.target.value);
};
const handleKeyDown = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
const kev = event as unknown as React.KeyboardEvent<HTMLTextAreaElement>; const kev = event as unknown as React.KeyboardEvent<HTMLTextAreaElement>;
if (kev.key === "Enter" && !kev.shiftKey) { if (kev.key === "Enter" && !kev.shiftKey) {
kev.preventDefault(); kev.preventDefault();
handleQuestionSubmit(); handleQuestionSubmit();
} }
}, [handleQuestionSubmit]); };
const handleImageSelect = useCallback((file: File) => setPendingImage(file), [setPendingImage]);
const handleClearImage = useCallback(() => setPendingImage(null), [setPendingImage]);
const handleLogout = () => { const handleLogout = () => {
localStorage.removeItem("access_token"); localStorage.removeItem("access_token");
@@ -104,7 +236,7 @@ export const ChatScreen = ({ setAuthenticated, isAdmin }: ChatScreenProps) => {
return ( return (
<div className="h-screen h-[100dvh] flex flex-row bg-cream overflow-hidden"> <div className="h-screen h-[100dvh] flex flex-row bg-cream overflow-hidden">
{/* Desktop Sidebar */} {/* ── Desktop Sidebar ─────────────────────────────── */}
<aside <aside
className={cn( className={cn(
"hidden md:flex md:flex-col", "hidden md:flex md:flex-col",
@@ -113,6 +245,7 @@ export const ChatScreen = ({ setAuthenticated, isAdmin }: ChatScreenProps) => {
)} )}
> >
{sidebarCollapsed ? ( {sidebarCollapsed ? (
/* Collapsed state */
<div className="flex flex-col items-center py-4 gap-4 h-full"> <div className="flex flex-col items-center py-4 gap-4 h-full">
<button <button
onClick={() => setSidebarCollapsed(false)} onClick={() => setSidebarCollapsed(false)}
@@ -127,7 +260,9 @@ export const ChatScreen = ({ setAuthenticated, isAdmin }: ChatScreenProps) => {
/> />
</div> </div>
) : ( ) : (
/* Expanded state */
<div className="flex flex-col h-full"> <div className="flex flex-col h-full">
{/* Header */}
<div className="flex items-center justify-between px-4 py-4 border-b border-white/8"> <div className="flex items-center justify-between px-4 py-4 border-b border-white/8">
<div className="flex items-center gap-2.5"> <div className="flex items-center gap-2.5">
<img src={catIcon} alt="Simba" className="w-12 h-12" /> <img src={catIcon} alt="Simba" className="w-12 h-12" />
@@ -146,6 +281,7 @@ export const ChatScreen = ({ setAuthenticated, isAdmin }: ChatScreenProps) => {
</button> </button>
</div> </div>
{/* Conversations */}
<div className="flex-1 overflow-y-auto px-2 py-3"> <div className="flex-1 overflow-y-auto px-2 py-3">
<ConversationList <ConversationList
conversations={conversations} conversations={conversations}
@@ -155,6 +291,7 @@ export const ChatScreen = ({ setAuthenticated, isAdmin }: ChatScreenProps) => {
/> />
</div> </div>
{/* Footer */}
<div className="px-2 pb-3 pt-2 border-t border-white/8 flex flex-col gap-0.5"> <div className="px-2 pb-3 pt-2 border-t border-white/8 flex flex-col gap-0.5">
{isAdmin && ( {isAdmin && (
<button <button
@@ -177,9 +314,12 @@ export const ChatScreen = ({ setAuthenticated, isAdmin }: ChatScreenProps) => {
)} )}
</aside> </aside>
{/* Admin Panel modal */}
{showAdminPanel && <AdminPanel onClose={() => setShowAdminPanel(false)} />} {showAdminPanel && <AdminPanel onClose={() => setShowAdminPanel(false)} />}
{/* ── Main chat area ──────────────────────────────── */}
<div className="flex-1 flex flex-col h-full overflow-hidden min-w-0"> <div className="flex-1 flex flex-col h-full overflow-hidden min-w-0">
{/* Mobile header */}
<header className="md:hidden flex items-center justify-between px-4 py-3 bg-warm-white border-b border-sand-light/60"> <header className="md:hidden flex items-center justify-between px-4 py-3 bg-warm-white border-b border-sand-light/60">
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
<img src={catIcon} alt="Simba" className="w-12 h-12" /> <img src={catIcon} alt="Simba" className="w-12 h-12" />
@@ -207,7 +347,9 @@ export const ChatScreen = ({ setAuthenticated, isAdmin }: ChatScreenProps) => {
</header> </header>
{messages.length === 0 ? ( {messages.length === 0 ? (
/* ── Empty / homepage state ── */
<div className="flex-1 flex flex-col items-center justify-center px-4 gap-6"> <div className="flex-1 flex flex-col items-center justify-center px-4 gap-6">
{/* Mobile conversation drawer */}
{showConversations && ( {showConversations && (
<div className="md:hidden w-full max-w-2xl bg-warm-white rounded-2xl border border-sand-light p-3 shadow-sm"> <div className="md:hidden w-full max-w-2xl bg-warm-white rounded-2xl border border-sand-light p-3 shadow-sm">
<ConversationList <ConversationList
@@ -238,15 +380,17 @@ export const ChatScreen = ({ setAuthenticated, isAdmin }: ChatScreenProps) => {
setSimbaMode={setSimbaMode} setSimbaMode={setSimbaMode}
isLoading={isLoading} isLoading={isLoading}
pendingImage={pendingImage} pendingImage={pendingImage}
onImageSelect={handleImageSelect} onImageSelect={(file) => setPendingImage(file)}
onClearImage={handleClearImage} onClearImage={() => setPendingImage(null)}
/> />
</div> </div>
</div> </div>
) : ( ) : (
/* ── Active chat state ── */
<> <>
<div className="flex-1 overflow-y-auto px-4 py-6"> <div className="flex-1 overflow-y-auto px-4 py-6">
<div className="max-w-2xl mx-auto flex flex-col gap-3"> <div className="max-w-2xl mx-auto flex flex-col gap-3">
{/* Mobile conversation drawer */}
{showConversations && ( {showConversations && (
<div className="md:hidden mb-3 bg-warm-white rounded-2xl border border-sand-light p-3 shadow-sm"> <div className="md:hidden mb-3 bg-warm-white rounded-2xl border border-sand-light p-3 shadow-sm">
<ConversationList <ConversationList
@@ -272,7 +416,7 @@ export const ChatScreen = ({ setAuthenticated, isAdmin }: ChatScreenProps) => {
</div> </div>
</div> </div>
<footer className="border-t border-sand-light/40 bg-cream"> <footer className="border-t border-sand-light/40 bg-cream/80 backdrop-blur-sm">
<div className="max-w-2xl mx-auto px-4 py-3"> <div className="max-w-2xl mx-auto px-4 py-3">
<MessageInput <MessageInput
query={query} query={query}
@@ -281,9 +425,6 @@ export const ChatScreen = ({ setAuthenticated, isAdmin }: ChatScreenProps) => {
handleQuestionSubmit={handleQuestionSubmit} handleQuestionSubmit={handleQuestionSubmit}
setSimbaMode={setSimbaMode} setSimbaMode={setSimbaMode}
isLoading={isLoading} isLoading={isLoading}
pendingImage={pendingImage}
onImageSelect={handleImageSelect}
onClearImage={handleClearImage}
/> />
</div> </div>
</footer> </footer>
@@ -1,5 +1,7 @@
import { useState, useEffect } from "react";
import { Plus } from "lucide-react"; import { Plus } from "lucide-react";
import { cn } from "../lib/utils"; import { cn } from "../lib/utils";
import { conversationService } from "../api/conversationService";
type Conversation = { type Conversation = {
title: string; title: string;
@@ -21,8 +23,32 @@ export const ConversationList = ({
selectedId, selectedId,
variant = "dark", variant = "dark",
}: ConversationProps) => { }: ConversationProps) => {
const [items, setItems] = useState(conversations);
useEffect(() => {
const load = async () => {
try {
let fetched = await conversationService.getAllConversations();
if (fetched.length === 0) {
await conversationService.createConversation();
fetched = await conversationService.getAllConversations();
}
setItems(fetched.map((c) => ({ id: c.id, title: c.name })));
} catch (err) {
console.error("Failed to load conversations:", err);
}
};
load();
}, []);
// Keep in sync when parent updates conversations
useEffect(() => {
setItems(conversations);
}, [conversations]);
return ( return (
<div className="flex flex-col gap-1"> <div className="flex flex-col gap-1">
{/* New thread button */}
<button <button
onClick={onCreateNewConversation} onClick={onCreateNewConversation}
className={cn( className={cn(
@@ -37,7 +63,8 @@ export const ConversationList = ({
<span>New thread</span> <span>New thread</span>
</button> </button>
{conversations.map((conv) => { {/* Conversation items */}
{items.map((conv) => {
const isActive = conv.id === selectedId; const isActive = conv.id === selectedId;
return ( return (
<button <button
+57 -6
View File
@@ -1,19 +1,66 @@
import { useState, useEffect } from "react";
import { userService } from "../api/userService";
import { oidcService } from "../api/oidcService";
import catIcon from "../assets/cat.png"; import catIcon from "../assets/cat.png";
import { cn } from "../lib/utils"; import { cn } from "../lib/utils";
import { useOIDCAuth } from "../hooks/useOIDCAuth";
type LoginScreenProps = { type LoginScreenProps = {
setAuthenticated: (isAuth: boolean) => void; setAuthenticated: (isAuth: boolean) => void;
}; };
export const LoginScreen = ({ setAuthenticated }: LoginScreenProps) => { export const LoginScreen = ({ setAuthenticated }: LoginScreenProps) => {
const { isChecking, isLoggingIn, error, handleLogin } = useOIDCAuth({ const [error, setError] = useState<string>("");
setAuthenticated, const [isChecking, setIsChecking] = useState<boolean>(true);
}); const [isLoggingIn, setIsLoggingIn] = useState<boolean>(false);
useEffect(() => {
const initAuth = async () => {
const callbackParams = oidcService.getCallbackParamsFromURL();
if (callbackParams) {
try {
setIsLoggingIn(true);
const result = await oidcService.handleCallback(
callbackParams.code,
callbackParams.state,
);
localStorage.setItem("access_token", result.access_token);
localStorage.setItem("refresh_token", result.refresh_token);
oidcService.clearCallbackParams();
setAuthenticated(true);
setIsChecking(false);
return;
} catch (err) {
console.error("OIDC callback error:", err);
setError("Login failed. Please try again.");
oidcService.clearCallbackParams();
setIsLoggingIn(false);
setIsChecking(false);
return;
}
}
const isValid = await userService.validateToken();
if (isValid) setAuthenticated(true);
setIsChecking(false);
};
initAuth();
}, [setAuthenticated]);
const handleOIDCLogin = async () => {
try {
setIsLoggingIn(true);
setError("");
const authUrl = await oidcService.initiateLogin();
window.location.href = authUrl;
} catch {
setError("Failed to initiate login. Please try again.");
setIsLoggingIn(false);
}
};
if (isChecking || isLoggingIn) { if (isChecking || isLoggingIn) {
return ( return (
<div className="h-screen flex flex-col items-center justify-center bg-cream gap-4"> <div className="h-screen flex flex-col items-center justify-center bg-cream gap-4">
{/* Subtle dot grid */}
<div <div
className="fixed inset-0 pointer-events-none opacity-[0.035]" className="fixed inset-0 pointer-events-none opacity-[0.035]"
style={{ style={{
@@ -38,6 +85,7 @@ export const LoginScreen = ({ setAuthenticated }: LoginScreenProps) => {
return ( return (
<div className="h-screen bg-cream flex items-center justify-center p-4 relative overflow-hidden"> <div className="h-screen bg-cream flex items-center justify-center p-4 relative overflow-hidden">
{/* Background dot texture */}
<div <div
className="fixed inset-0 pointer-events-none opacity-[0.04]" className="fixed inset-0 pointer-events-none opacity-[0.04]"
style={{ style={{
@@ -46,10 +94,12 @@ export const LoginScreen = ({ setAuthenticated }: LoginScreenProps) => {
}} }}
/> />
{/* Decorative background blobs */}
<div className="absolute top-1/4 -left-20 w-72 h-72 rounded-full bg-leaf-pale/60 blur-3xl pointer-events-none" /> <div className="absolute top-1/4 -left-20 w-72 h-72 rounded-full bg-leaf-pale/60 blur-3xl pointer-events-none" />
<div className="absolute bottom-1/4 -right-20 w-64 h-64 rounded-full bg-amber-pale/70 blur-3xl pointer-events-none" /> <div className="absolute bottom-1/4 -right-20 w-64 h-64 rounded-full bg-amber-pale/70 blur-3xl pointer-events-none" />
<div className="relative w-full max-w-sm"> <div className="relative w-full max-w-sm">
{/* Branding */}
<div className="flex flex-col items-center mb-8"> <div className="flex flex-col items-center mb-8">
<div className="relative mb-5"> <div className="relative mb-5">
<div className="absolute -inset-5 bg-amber-soft/30 rounded-full blur-2xl" /> <div className="absolute -inset-5 bg-amber-soft/30 rounded-full blur-2xl" />
@@ -70,6 +120,7 @@ export const LoginScreen = ({ setAuthenticated }: LoginScreenProps) => {
</p> </p>
</div> </div>
{/* Card */}
<div <div
className={cn( className={cn(
"bg-warm-white rounded-3xl border border-sand-light", "bg-warm-white rounded-3xl border border-sand-light",
@@ -87,7 +138,7 @@ export const LoginScreen = ({ setAuthenticated }: LoginScreenProps) => {
</p> </p>
<button <button
onClick={handleLogin} onClick={handleOIDCLogin}
disabled={isLoggingIn} disabled={isLoggingIn}
className={cn( className={cn(
"w-full py-3.5 px-4 rounded-2xl text-sm font-semibold tracking-wide", "w-full py-3.5 px-4 rounded-2xl text-sm font-semibold tracking-wide",
@@ -103,7 +154,7 @@ export const LoginScreen = ({ setAuthenticated }: LoginScreenProps) => {
</div> </div>
<p className="text-center text-sand mt-5 text-xs tracking-widest select-none"> <p className="text-center text-sand mt-5 text-xs tracking-widest select-none">
* meow * meow
</p> </p>
</div> </div>
</div> </div>
+4 -16
View File
@@ -1,4 +1,4 @@
import React, { useEffect, useMemo, useRef, useState } from "react"; import { useRef, useState } from "react";
import { ArrowUp, ImagePlus, X } from "lucide-react"; import { ArrowUp, ImagePlus, X } from "lucide-react";
import { cn } from "../lib/utils"; import { cn } from "../lib/utils";
import { Textarea } from "./ui/textarea"; import { Textarea } from "./ui/textarea";
@@ -15,7 +15,7 @@ type MessageInputProps = {
onClearImage: () => void; onClearImage: () => void;
}; };
export const MessageInput = React.memo(({ export const MessageInput = ({
query, query,
handleKeyDown, handleKeyDown,
handleQueryChange, handleQueryChange,
@@ -29,18 +29,6 @@ export const MessageInput = React.memo(({
const [simbaMode, setLocalSimbaMode] = useState(false); const [simbaMode, setLocalSimbaMode] = useState(false);
const fileInputRef = useRef<HTMLInputElement>(null); const fileInputRef = useRef<HTMLInputElement>(null);
// Create blob URL once per file, revoke on cleanup
const previewUrl = useMemo(
() => (pendingImage ? URL.createObjectURL(pendingImage) : null),
[pendingImage],
);
useEffect(() => {
return () => {
if (previewUrl) URL.revokeObjectURL(previewUrl);
};
}, [previewUrl]);
const toggleSimbaMode = () => { const toggleSimbaMode = () => {
const next = !simbaMode; const next = !simbaMode;
setLocalSimbaMode(next); setLocalSimbaMode(next);
@@ -71,7 +59,7 @@ export const MessageInput = React.memo(({
<div className="px-3 pt-3"> <div className="px-3 pt-3">
<div className="relative inline-block"> <div className="relative inline-block">
<img <img
src={previewUrl!} src={URL.createObjectURL(pendingImage)}
alt="Pending upload" alt="Pending upload"
className="h-20 rounded-lg object-cover border border-sand" className="h-20 rounded-lg object-cover border border-sand"
/> />
@@ -157,4 +145,4 @@ export const MessageInput = React.memo(({
</div> </div>
</div> </div>
); );
}); };
@@ -1,14 +1,26 @@
import React from "react"; import { useEffect, useState } from "react";
import { cn } from "../lib/utils"; import { cn } from "../lib/utils";
import { usePresignedUrl } from "../hooks/usePresignedUrl"; import { conversationService } from "../api/conversationService";
type QuestionBubbleProps = { type QuestionBubbleProps = {
text: string; text: string;
image_key?: string | null; image_key?: string | null;
}; };
export const QuestionBubble = React.memo(({ text, image_key }: QuestionBubbleProps) => { export const QuestionBubble = ({ text, image_key }: QuestionBubbleProps) => {
const { imageUrl, imageError } = usePresignedUrl(image_key); const [imageUrl, setImageUrl] = useState<string | null>(null);
const [imageError, setImageError] = useState(false);
useEffect(() => {
if (!image_key) return;
conversationService
.getPresignedImageUrl(image_key)
.then(setImageUrl)
.catch((err) => {
console.error("Failed to load image:", err);
setImageError(true);
});
}, [image_key]);
return ( return (
<div className="flex justify-end message-enter"> <div className="flex justify-end message-enter">
@@ -22,6 +34,7 @@ export const QuestionBubble = React.memo(({ text, image_key }: QuestionBubblePro
> >
{imageError && ( {imageError && (
<div className="flex items-center gap-2 text-xs text-charcoal/50 bg-charcoal/5 rounded-xl px-3 py-2 mb-2"> <div className="flex items-center gap-2 text-xs text-charcoal/50 bg-charcoal/5 rounded-xl px-3 py-2 mb-2">
<span>🖼</span>
<span>Image failed to load</span> <span>Image failed to load</span>
</div> </div>
)} )}
@@ -36,4 +49,4 @@ export const QuestionBubble = React.memo(({ text, image_key }: QuestionBubblePro
</div> </div>
</div> </div>
); );
}); };
+2 -3
View File
@@ -1,7 +1,6 @@
import React from "react";
import { cn } from "../lib/utils"; import { cn } from "../lib/utils";
export const ToolBubble = React.memo(({ text }: { text: string }) => ( export const ToolBubble = ({ text }: { text: string }) => (
<div className="flex justify-center message-enter"> <div className="flex justify-center message-enter">
<div <div
className={cn( className={cn(
@@ -13,4 +12,4 @@ export const ToolBubble = React.memo(({ text }: { text: string }) => (
{text} {text}
</div> </div>
</div> </div>
)); );
-21
View File
@@ -1,21 +0,0 @@
import { useState, useEffect } from "react";
import { userService, type AdminUserRecord } from "../api/userService";
export function useAdminUsers() {
const [users, setUsers] = useState<AdminUserRecord[]>([]);
const [loading, setLoading] = useState(true);
useEffect(() => {
userService
.adminListUsers()
.then(setUsers)
.catch(() => {})
.finally(() => setLoading(false));
}, []);
const updateUser = (userId: string, updater: (u: AdminUserRecord) => AdminUserRecord) => {
setUsers((prev) => prev.map((u) => (u.id === userId ? updater(u) : u)));
};
return { users, loading, updateUser };
}
-37
View File
@@ -1,37 +0,0 @@
import { useState, useEffect } from "react";
import { userService } from "../api/userService";
export function useAuthCheck() {
const [isAuthenticated, setAuthenticated] = useState(false);
const [isChecking, setIsChecking] = useState(true);
const [isAdmin, setIsAdmin] = useState(false);
useEffect(() => {
const checkAuth = async () => {
const accessToken = localStorage.getItem("access_token");
const refreshToken = localStorage.getItem("refresh_token");
if (!accessToken && !refreshToken) {
setIsChecking(false);
setAuthenticated(false);
return;
}
try {
const me = await userService.getMe();
setAuthenticated(true);
setIsAdmin(me.is_admin);
} catch {
localStorage.removeItem("access_token");
localStorage.removeItem("refresh_token");
setAuthenticated(false);
} finally {
setIsChecking(false);
}
};
checkAuth();
}, []);
return { isAuthenticated, isChecking, isAdmin, setAuthenticated };
}
-183
View File
@@ -1,183 +0,0 @@
import { useState, useCallback, useEffect, useRef } from "react";
import { conversationService } from "../api/conversationService";
import type { Conversation } from "./useConversations";
type Message = {
text: string;
speaker: "simba" | "user" | "tool";
image_key?: string | null;
};
const TOOL_MESSAGES: Record<string, string> = {
simba_search: "Searching Simba's records...",
web_search: "Searching the web...",
get_current_date: "Checking today's date...",
ynab_budget_summary: "Checking budget summary...",
ynab_search_transactions: "Looking up transactions...",
ynab_category_spending: "Analyzing category spending...",
ynab_insights: "Generating budget insights...",
obsidian_search_notes: "Searching notes...",
obsidian_read_note: "Reading note...",
obsidian_create_note: "Saving note...",
obsidian_create_task: "Creating task...",
journal_get_today: "Reading today's journal...",
journal_get_tasks: "Getting tasks...",
journal_add_task: "Adding task...",
journal_complete_task: "Completing task...",
};
const simbaAnswers = ["meow.", "hiss...", "purrrrrr", "yowOWROWWowowr"];
type UseChatOptions = {
selectedConversation: Conversation | null;
createConversation: () => Promise<Conversation>;
refreshConversations: () => Promise<void>;
onSessionExpired: () => void;
scrollToBottom: () => void;
};
export function useChat({
selectedConversation,
createConversation,
refreshConversations,
onSessionExpired,
scrollToBottom,
}: UseChatOptions) {
const [messages, setMessages] = useState<Message[]>([]);
const [isLoading, setIsLoading] = useState(false);
const [pendingImage, setPendingImage] = useState<File | null>(null);
const isMountedRef = useRef(true);
const abortControllerRef = useRef<AbortController | null>(null);
useEffect(() => {
isMountedRef.current = true;
return () => {
isMountedRef.current = false;
abortControllerRef.current?.abort();
};
}, []);
const updateMessages = useCallback(
(updater: Message[] | ((prev: Message[]) => Message[])) => {
setMessages(updater);
scrollToBottom();
},
[scrollToBottom],
);
const sendMessage = useCallback(
async (query: string, simbaMode: boolean) => {
if ((!query.trim() && !pendingImage) || isLoading) return;
let activeConversation = selectedConversation;
let createdNew = false;
if (!activeConversation) {
activeConversation = await createConversation();
createdNew = true;
}
const imageFile = pendingImage;
updateMessages((prev) => prev.concat([{ text: query, speaker: "user" }]));
setPendingImage(null);
setIsLoading(true);
if (simbaMode) {
const randomElement =
simbaAnswers[Math.floor(Math.random() * simbaAnswers.length)];
updateMessages((prev) =>
prev.concat([{ text: randomElement, speaker: "simba" }]),
);
setIsLoading(false);
return;
}
const abortController = new AbortController();
abortControllerRef.current = abortController;
try {
let imageKey: string | undefined;
if (imageFile) {
const uploadResult = await conversationService.uploadImage(
imageFile,
activeConversation.id,
);
imageKey = uploadResult.image_key;
updateMessages((prev) => {
const updated = [...prev];
for (let i = updated.length - 1; i >= 0; i--) {
if (updated[i].speaker === "user") {
updated[i] = { ...updated[i], image_key: imageKey };
break;
}
}
return updated;
});
}
await conversationService.streamQuery(
query,
activeConversation.id,
(event) => {
if (!isMountedRef.current) return;
if (event.type === "tool_start") {
const friendly =
TOOL_MESSAGES[event.tool] ?? `Using ${event.tool}...`;
updateMessages((prev) =>
prev.concat([{ text: friendly, speaker: "tool" }]),
);
} else if (event.type === "response") {
updateMessages((prev) =>
prev.concat([{ text: event.message, speaker: "simba" }]),
);
} else if (event.type === "error") {
console.error("Stream error:", event.message);
}
},
abortController.signal,
imageKey,
);
} catch (error) {
if (error instanceof Error && error.name === "AbortError") {
console.log("Request was aborted");
} else {
console.error("Failed to send query:", error);
if (
error instanceof Error &&
error.message.includes("Session expired")
) {
onSessionExpired();
}
}
} finally {
if (isMountedRef.current) {
setIsLoading(false);
if (createdNew) {
refreshConversations();
}
}
abortControllerRef.current = null;
}
},
[
pendingImage,
isLoading,
selectedConversation,
createConversation,
refreshConversations,
onSessionExpired,
updateMessages,
],
);
return {
messages,
setMessages: updateMessages,
isLoading,
pendingImage,
setPendingImage,
sendMessage,
};
}
@@ -1,69 +0,0 @@
import { useState, useCallback, useEffect } from "react";
import { conversationService } from "../api/conversationService";
export type Conversation = {
title: string;
id: string;
};
type Message = {
text: string;
speaker: "simba" | "user" | "tool";
image_key?: string | null;
};
export function useConversations() {
const [conversations, setConversations] = useState<Conversation[]>([]);
const [selectedConversation, setSelectedConversation] =
useState<Conversation | null>(null);
const refreshConversations = useCallback(async () => {
try {
const fetched = await conversationService.getAllConversations();
setConversations(fetched.map((c) => ({ id: c.id, title: c.name })));
} catch (err) {
console.error("Failed to load conversations:", err);
}
}, []);
useEffect(() => {
refreshConversations();
}, [refreshConversations]);
const selectConversation = useCallback(
async (conversation: Conversation): Promise<Message[]> => {
setSelectedConversation(conversation);
try {
const fetched = await conversationService.getConversation(
conversation.id,
);
return fetched.messages.map((m) => ({
text: m.text,
speaker: m.speaker,
image_key: m.image_key,
}));
} catch (err) {
console.error("Failed to load messages:", err);
return [];
}
},
[],
);
const createConversation = useCallback(async (): Promise<Conversation> => {
const newConv = await conversationService.createConversation();
const conversation = { title: newConv.name, id: newConv.id };
setConversations((prev) => [conversation, ...prev]);
setSelectedConversation(conversation);
return conversation;
}, []);
return {
conversations,
selectedConversation,
setSelectedConversation,
selectConversation,
createConversation,
refreshConversations,
};
}
-59
View File
@@ -1,59 +0,0 @@
import { useState, useEffect } from "react";
import { userService } from "../api/userService";
import { oidcService } from "../api/oidcService";
type UseOIDCAuthOptions = {
setAuthenticated: (isAuth: boolean) => void;
};
export function useOIDCAuth({ setAuthenticated }: UseOIDCAuthOptions) {
const [isChecking, setIsChecking] = useState(true);
const [isLoggingIn, setIsLoggingIn] = useState(false);
const [error, setError] = useState("");
useEffect(() => {
const initAuth = async () => {
const callbackParams = oidcService.getCallbackParamsFromURL();
if (callbackParams) {
try {
setIsLoggingIn(true);
const result = await oidcService.handleCallback(
callbackParams.code,
callbackParams.state,
);
localStorage.setItem("access_token", result.access_token);
localStorage.setItem("refresh_token", result.refresh_token);
oidcService.clearCallbackParams();
setAuthenticated(true);
setIsChecking(false);
return;
} catch (err) {
console.error("OIDC callback error:", err);
setError("Login failed. Please try again.");
oidcService.clearCallbackParams();
setIsLoggingIn(false);
setIsChecking(false);
return;
}
}
const isValid = await userService.validateToken();
if (isValid) setAuthenticated(true);
setIsChecking(false);
};
initAuth();
}, [setAuthenticated]);
const handleLogin = async () => {
try {
setIsLoggingIn(true);
setError("");
const authUrl = await oidcService.initiateLogin();
window.location.href = authUrl;
} catch {
setError("Failed to initiate login. Please try again.");
setIsLoggingIn(false);
}
};
return { isChecking, isLoggingIn, error, handleLogin };
}
@@ -1,34 +0,0 @@
import { useState, useEffect } from "react";
import { conversationService } from "../api/conversationService";
const urlCache = new Map<string, string>();
export function usePresignedUrl(imageKey: string | null | undefined) {
const [imageUrl, setImageUrl] = useState<string | null>(
imageKey ? (urlCache.get(imageKey) ?? null) : null,
);
const [imageError, setImageError] = useState(false);
useEffect(() => {
if (!imageKey) return;
const cached = urlCache.get(imageKey);
if (cached) {
setImageUrl(cached);
return;
}
conversationService
.getPresignedImageUrl(imageKey)
.then((url) => {
urlCache.set(imageKey, url);
setImageUrl(url);
})
.catch((err) => {
console.error("Failed to load image:", err);
setImageError(true);
});
}, [imageKey]);
return { imageUrl, imageError };
}
+16 -8
View File
@@ -6,19 +6,19 @@ import asyncio
import sys import sys
from blueprints.rag.logic import ( from blueprints.rag.logic import (
delete_all_documents,
get_vector_store_stats, get_vector_store_stats,
index_documents, index_documents,
list_all_documents, list_all_documents,
vector_store,
) )
def stats(): def stats():
"""Show vector store statistics.""" """Show vector store statistics."""
s = get_vector_store_stats() stats = get_vector_store_stats()
print("=== Vector Store Statistics ===") print("=== Vector Store Statistics ===")
print(f"Collection: {s['collection_name']}") print(f"Collection: {stats['collection_name']}")
print(f"Total Documents: {s['total_documents']}") print(f"Total Documents: {stats['total_documents']}")
async def index(): async def index():
@@ -26,15 +26,23 @@ async def index():
print("Starting indexing process...") print("Starting indexing process...")
print("Fetching documents from Paperless-NGX...") print("Fetching documents from Paperless-NGX...")
await index_documents() await index_documents()
print("Indexing complete!") print("Indexing complete!")
stats() stats()
async def reindex(): async def reindex():
"""Clear and reindex all documents.""" """Clear and reindex all documents."""
print("Clearing existing documents...") print("Clearing existing documents...")
delete_all_documents() collection = vector_store._collection
print("Cleared") all_docs = collection.get()
if all_docs["ids"]:
print(f"Deleting {len(all_docs['ids'])} existing documents...")
collection.delete(ids=all_docs["ids"])
print("✓ Cleared")
else:
print("Collection is already empty")
await index() await index()
@@ -105,7 +113,7 @@ Examples:
print("\n\nOperation cancelled by user") print("\n\nOperation cancelled by user")
sys.exit(1) sys.exit(1)
except Exception as e: except Exception as e:
print(f"\nError: {e}", file=sys.stderr) print(f"\nError: {e}", file=sys.stderr)
sys.exit(1) sys.exit(1)
+24
View File
@@ -0,0 +1,24 @@
from bs4 import BeautifulSoup
import chromadb
import httpx
client = chromadb.PersistentClient(path="/Users/ryanchen/Programs/raggr/chromadb")
# Scrape
BASE_URL = "https://www.vet.cornell.edu"
LIST_URL = "/departments-centers-and-institutes/cornell-feline-health-center/health-information/feline-health-topics"
QUERY_URL = BASE_URL + LIST_URL
r = httpx.get(QUERY_URL)
soup = BeautifulSoup(r.text)
container = soup.find("div", class_="field-body")
a_s = container.find_all("a", href=True)
new_texts = []
for link in a_s:
endpoint = link["href"]
query_url = BASE_URL + endpoint
r2 = httpx.get(query_url)
article_soup = BeautifulSoup(r2.text)
+3
View File
@@ -1,6 +1,9 @@
#!/bin/bash #!/bin/bash
set -e set -e
echo "Initializing directories..."
mkdir -p /app/data/chromadb
echo "Rebuilding frontend..." echo "Rebuilding frontend..."
cd /app/raggr-frontend cd /app/raggr-frontend
yarn build yarn build
View File
-11
View File
@@ -1,11 +0,0 @@
import os
import sys
# Ensure project root is on the path so imports work
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
# Set FERNET_KEY for tests that import email models (EncryptedTextField needs it at import time)
if "FERNET_KEY" not in os.environ:
from cryptography.fernet import Fernet
os.environ["FERNET_KEY"] = Fernet.generate_key().decode()
View File
-91
View File
@@ -1,91 +0,0 @@
"""Tests for encryption/decryption in blueprints/email/crypto_service.py."""
import os
from unittest.mock import patch
import pytest
from cryptography.fernet import Fernet
# Generate a valid key for testing
TEST_FERNET_KEY = Fernet.generate_key().decode()
class TestEncryptedTextField:
@pytest.fixture
def field(self):
with patch.dict(os.environ, {"FERNET_KEY": TEST_FERNET_KEY}):
from blueprints.email.crypto_service import EncryptedTextField
return EncryptedTextField()
def test_encrypt_decrypt_roundtrip(self, field):
original = "my secret password"
encrypted = field.to_db_value(original, None)
decrypted = field.to_python_value(encrypted)
assert decrypted == original
assert encrypted != original
def test_none_passthrough(self, field):
assert field.to_db_value(None, None) is None
assert field.to_python_value(None) is None
def test_unicode_roundtrip(self, field):
original = "Hello 世界 🐱"
encrypted = field.to_db_value(original, None)
decrypted = field.to_python_value(encrypted)
assert decrypted == original
def test_empty_string_roundtrip(self, field):
encrypted = field.to_db_value("", None)
decrypted = field.to_python_value(encrypted)
assert decrypted == ""
def test_long_text_roundtrip(self, field):
original = "x" * 10000
encrypted = field.to_db_value(original, None)
decrypted = field.to_python_value(encrypted)
assert decrypted == original
def test_different_encryptions_differ(self, field):
"""Fernet includes a timestamp, so two encryptions of the same value differ."""
e1 = field.to_db_value("same", None)
e2 = field.to_db_value("same", None)
assert e1 != e2 # Different ciphertexts
assert field.to_python_value(e1) == field.to_python_value(e2) == "same"
def test_wrong_key_fails(self, field):
encrypted = field.to_db_value("secret", None)
# Create a field with a different key
other_key = Fernet.generate_key().decode()
with patch.dict(os.environ, {"FERNET_KEY": other_key}):
from blueprints.email.crypto_service import EncryptedTextField
other_field = EncryptedTextField()
with pytest.raises(Exception):
other_field.to_python_value(encrypted)
class TestValidateFernetKey:
def test_valid_key(self):
with patch.dict(os.environ, {"FERNET_KEY": TEST_FERNET_KEY}):
from blueprints.email.crypto_service import validate_fernet_key
validate_fernet_key() # Should not raise
def test_missing_key(self):
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("FERNET_KEY", None)
from blueprints.email.crypto_service import validate_fernet_key
with pytest.raises(ValueError, match="not set"):
validate_fernet_key()
def test_invalid_key(self):
with patch.dict(os.environ, {"FERNET_KEY": "not-a-valid-key"}):
from blueprints.email.crypto_service import validate_fernet_key
with pytest.raises(ValueError, match="validation failed"):
validate_fernet_key()
-38
View File
@@ -1,38 +0,0 @@
"""Tests for email helper functions in blueprints/email/helpers.py."""
from blueprints.email.helpers import generate_email_token, get_user_email_address
class TestGenerateEmailToken:
def test_returns_16_char_hex(self):
token = generate_email_token("user-123", "my-secret")
assert len(token) == 16
assert all(c in "0123456789abcdef" for c in token)
def test_deterministic(self):
t1 = generate_email_token("user-123", "my-secret")
t2 = generate_email_token("user-123", "my-secret")
assert t1 == t2
def test_different_users_different_tokens(self):
t1 = generate_email_token("user-1", "secret")
t2 = generate_email_token("user-2", "secret")
assert t1 != t2
def test_different_secrets_different_tokens(self):
t1 = generate_email_token("user-1", "secret-a")
t2 = generate_email_token("user-1", "secret-b")
assert t1 != t2
class TestGetUserEmailAddress:
def test_formats_correctly(self):
addr = get_user_email_address("abc123", "example.com")
assert addr == "ask+abc123@example.com"
def test_preserves_token(self):
token = "deadbeef12345678"
addr = get_user_email_address(token, "mail.test.org")
assert token in addr
assert addr.startswith("ask+")
assert "@mail.test.org" in addr
-259
View File
@@ -1,259 +0,0 @@
"""Tests for ObsidianService markdown parsing and file operations."""
import os
from datetime import datetime
from pathlib import Path
from unittest.mock import patch
import pytest
# Set vault path before importing so __init__ validation passes
_test_vault_dir = None
@pytest.fixture(autouse=True)
def vault_dir(tmp_path):
"""Create a temporary vault directory with a sample .md file."""
global _test_vault_dir
_test_vault_dir = tmp_path
# Create a sample markdown file so vault validation passes
sample = tmp_path / "sample.md"
sample.write_text("# Sample\nHello world")
with patch.dict(os.environ, {"OBSIDIAN_VAULT_PATH": str(tmp_path)}):
yield tmp_path
@pytest.fixture
def service(vault_dir):
from utils.obsidian_service import ObsidianService
return ObsidianService()
class TestParseMarkdown:
def test_extracts_frontmatter(self, service):
content = "---\ntitle: Test Note\ntags: [cat, vet]\n---\n\nBody content"
result = service.parse_markdown(content)
assert result["metadata"]["title"] == "Test Note"
assert result["metadata"]["tags"] == ["cat", "vet"]
def test_no_frontmatter(self, service):
content = "Just body content with no frontmatter"
result = service.parse_markdown(content)
assert result["metadata"] == {}
assert "Just body content" in result["content"]
def test_invalid_yaml_frontmatter(self, service):
content = "---\n: invalid: yaml: [[\n---\n\nBody"
result = service.parse_markdown(content)
assert result["metadata"] == {}
def test_extracts_tags(self, service):
content = "Some text with #tag1 and #tag2 here"
result = service.parse_markdown(content)
assert "tag1" in result["tags"]
assert "tag2" in result["tags"]
def test_extracts_wikilinks(self, service):
content = "Link to [[Other Note]] and [[Another Page]]"
result = service.parse_markdown(content)
assert "Other Note" in result["wikilinks"]
assert "Another Page" in result["wikilinks"]
def test_extracts_embeds(self, service):
content = "An embed [[!my_embed]] here"
result = service.parse_markdown(content)
assert "my_embed" in result["embeds"]
def test_cleans_wikilinks_from_content(self, service):
content = "Text with [[link]] included"
result = service.parse_markdown(content)
assert "[[" not in result["content"]
assert "]]" not in result["content"]
def test_filepath_passed_through(self, service):
result = service.parse_markdown("text", filepath=Path("/vault/note.md"))
assert result["filepath"] == "/vault/note.md"
def test_filepath_none_by_default(self, service):
result = service.parse_markdown("text")
assert result["filepath"] is None
def test_empty_content(self, service):
result = service.parse_markdown("")
assert result["metadata"] == {}
assert result["tags"] == []
assert result["wikilinks"] == []
assert result["embeds"] == []
class TestGetDailyNotePath:
def test_formats_path_correctly(self, service):
date = datetime(2026, 3, 15)
path = service.get_daily_note_path(date)
assert path == "journal/2026/2026-03-15.md"
def test_defaults_to_today(self, service):
path = service.get_daily_note_path()
today = datetime.now()
assert today.strftime("%Y-%m-%d") in path
assert path.startswith(f"journal/{today.strftime('%Y')}/")
class TestWalkVault:
def test_finds_markdown_files(self, service, vault_dir):
(vault_dir / "note1.md").write_text("# Note 1")
(vault_dir / "subdir").mkdir()
(vault_dir / "subdir" / "note2.md").write_text("# Note 2")
files = service.walk_vault()
filenames = [f.name for f in files]
assert "sample.md" in filenames
assert "note1.md" in filenames
assert "note2.md" in filenames
def test_excludes_obsidian_dir(self, service, vault_dir):
obsidian_dir = vault_dir / ".obsidian"
obsidian_dir.mkdir()
(obsidian_dir / "config.md").write_text("config")
files = service.walk_vault()
filenames = [f.name for f in files]
assert "config.md" not in filenames
def test_ignores_non_md_files(self, service, vault_dir):
(vault_dir / "image.png").write_bytes(b"\x89PNG")
files = service.walk_vault()
filenames = [f.name for f in files]
assert "image.png" not in filenames
class TestCreateNote:
def test_creates_file(self, service, vault_dir):
path = service.create_note("My Test Note", "Body content")
full_path = vault_dir / path
assert full_path.exists()
def test_sanitizes_title(self, service, vault_dir):
path = service.create_note("Hello World! @#$", "Body")
assert "hello-world" in path
assert "@" not in path
assert "#" not in path
def test_includes_frontmatter(self, service, vault_dir):
path = service.create_note("Test", "Body", tags=["cat", "vet"])
full_path = vault_dir / path
content = full_path.read_text()
assert "---" in content
assert "created_by: simbarag" in content
assert "cat" in content
assert "vet" in content
def test_custom_folder(self, service, vault_dir):
path = service.create_note("Test", "Body", folder="custom/subfolder")
assert path.startswith("custom/subfolder/")
assert (vault_dir / path).exists()
class TestDailyNoteTasks:
def test_get_tasks_from_daily_note(self, service, vault_dir):
# Create a daily note with tasks
date = datetime(2026, 1, 15)
rel_path = service.get_daily_note_path(date)
note_path = vault_dir / rel_path
note_path.parent.mkdir(parents=True, exist_ok=True)
note_path.write_text(
"---\nmodified: 2026-01-15\n---\n"
"### tasks\n\n"
"- [ ] Feed the cat\n"
"- [x] Clean litter box\n"
"- [ ] Buy cat food\n\n"
"### log\n"
)
result = service.get_daily_tasks(date)
assert result["found"] is True
assert len(result["tasks"]) == 3
assert result["tasks"][0] == {"text": "Feed the cat", "done": False}
assert result["tasks"][1] == {"text": "Clean litter box", "done": True}
assert result["tasks"][2] == {"text": "Buy cat food", "done": False}
def test_get_tasks_no_note(self, service):
date = datetime(2099, 12, 31)
result = service.get_daily_tasks(date)
assert result["found"] is False
assert result["tasks"] == []
def test_add_task_creates_note(self, service, vault_dir):
date = datetime(2026, 6, 1)
result = service.add_task_to_daily_note("Walk the cat", date)
assert result["success"] is True
assert result["created_note"] is True
# Verify file was created with the task
note_path = vault_dir / result["path"]
content = note_path.read_text()
assert "- [ ] Walk the cat" in content
def test_add_task_to_existing_note(self, service, vault_dir):
date = datetime(2026, 6, 2)
rel_path = service.get_daily_note_path(date)
note_path = vault_dir / rel_path
note_path.parent.mkdir(parents=True, exist_ok=True)
note_path.write_text(
"---\nmodified: 2026-06-02\n---\n"
"### tasks\n\n"
"- [ ] Existing task\n\n"
"### log\n"
)
result = service.add_task_to_daily_note("New task", date)
assert result["success"] is True
assert result["created_note"] is False
content = note_path.read_text()
assert "- [ ] Existing task" in content
assert "- [ ] New task" in content
def test_complete_task_exact_match(self, service, vault_dir):
date = datetime(2026, 6, 3)
rel_path = service.get_daily_note_path(date)
note_path = vault_dir / rel_path
note_path.parent.mkdir(parents=True, exist_ok=True)
note_path.write_text("### tasks\n\n" "- [ ] Feed the cat\n" "- [ ] Buy food\n")
result = service.complete_task_in_daily_note("Feed the cat", date)
assert result["success"] is True
content = note_path.read_text()
assert "- [x] Feed the cat" in content
assert "- [ ] Buy food" in content # Other task unchanged
def test_complete_task_partial_match(self, service, vault_dir):
date = datetime(2026, 6, 4)
rel_path = service.get_daily_note_path(date)
note_path = vault_dir / rel_path
note_path.parent.mkdir(parents=True, exist_ok=True)
note_path.write_text("### tasks\n\n- [ ] Feed the cat at 5pm\n")
result = service.complete_task_in_daily_note("Feed the cat", date)
assert result["success"] is True
def test_complete_task_not_found(self, service, vault_dir):
date = datetime(2026, 6, 5)
rel_path = service.get_daily_note_path(date)
note_path = vault_dir / rel_path
note_path.parent.mkdir(parents=True, exist_ok=True)
note_path.write_text("### tasks\n\n- [ ] Feed the cat\n")
result = service.complete_task_in_daily_note("Walk the dog", date)
assert result["success"] is False
assert "not found" in result["error"]
def test_complete_task_no_note(self, service):
date = datetime(2099, 12, 31)
result = service.complete_task_in_daily_note("Something", date)
assert result["success"] is False
-92
View File
@@ -1,92 +0,0 @@
"""Tests for rate limiting logic in email and WhatsApp blueprints."""
import time
class TestEmailRateLimit:
def setup_method(self):
"""Reset rate limit store before each test."""
from blueprints.email import _rate_limit_store
_rate_limit_store.clear()
def test_allows_under_limit(self):
from blueprints.email import _check_rate_limit, RATE_LIMIT_MAX
for _ in range(RATE_LIMIT_MAX):
assert _check_rate_limit("sender@test.com") is True
def test_blocks_at_limit(self):
from blueprints.email import _check_rate_limit, RATE_LIMIT_MAX
for _ in range(RATE_LIMIT_MAX):
_check_rate_limit("sender@test.com")
assert _check_rate_limit("sender@test.com") is False
def test_different_senders_independent(self):
from blueprints.email import _check_rate_limit, RATE_LIMIT_MAX
for _ in range(RATE_LIMIT_MAX):
_check_rate_limit("user1@test.com")
# user1 is at limit, but user2 should be fine
assert _check_rate_limit("user1@test.com") is False
assert _check_rate_limit("user2@test.com") is True
def test_window_expiry(self):
from blueprints.email import (
_check_rate_limit,
_rate_limit_store,
RATE_LIMIT_MAX,
)
# Fill up the rate limit with timestamps in the past
past = time.monotonic() - 999 # Well beyond any window
_rate_limit_store["old@test.com"] = [past] * RATE_LIMIT_MAX
# Should be allowed because all timestamps are expired
assert _check_rate_limit("old@test.com") is True
class TestWhatsAppRateLimit:
def setup_method(self):
"""Reset rate limit store before each test."""
from blueprints.whatsapp import _rate_limit_store
_rate_limit_store.clear()
def test_allows_under_limit(self):
from blueprints.whatsapp import _check_rate_limit, RATE_LIMIT_MAX
for _ in range(RATE_LIMIT_MAX):
assert _check_rate_limit("whatsapp:+1234567890") is True
def test_blocks_at_limit(self):
from blueprints.whatsapp import _check_rate_limit, RATE_LIMIT_MAX
for _ in range(RATE_LIMIT_MAX):
_check_rate_limit("whatsapp:+1234567890")
assert _check_rate_limit("whatsapp:+1234567890") is False
def test_different_numbers_independent(self):
from blueprints.whatsapp import _check_rate_limit, RATE_LIMIT_MAX
for _ in range(RATE_LIMIT_MAX):
_check_rate_limit("whatsapp:+1111111111")
assert _check_rate_limit("whatsapp:+1111111111") is False
assert _check_rate_limit("whatsapp:+2222222222") is True
def test_window_expiry(self):
from blueprints.whatsapp import (
_check_rate_limit,
_rate_limit_store,
RATE_LIMIT_MAX,
)
past = time.monotonic() - 999
_rate_limit_store["whatsapp:+9999999999"] = [past] * RATE_LIMIT_MAX
assert _check_rate_limit("whatsapp:+9999999999") is True
-86
View File
@@ -1,86 +0,0 @@
"""Tests for User model methods in blueprints/users/models.py."""
from unittest.mock import MagicMock
import bcrypt
class TestUserModelMethods:
"""Test User model methods without requiring a database connection.
We instantiate a mock object with the same methods as User
to avoid Tortoise ORM initialization.
"""
def _make_user(self, ldap_groups=None, password=None):
"""Create a mock user with real method implementations."""
from blueprints.users.models import User
user = MagicMock(spec=User)
user.ldap_groups = ldap_groups
user.password = password
# Bind real methods
user.has_group = lambda group: group in (user.ldap_groups or [])
user.is_admin = lambda: user.has_group("lldap_admin")
def set_password(plain):
user.password = bcrypt.hashpw(plain.encode("utf-8"), bcrypt.gensalt())
user.set_password = set_password
def verify_password(plain):
if not user.password:
return False
return bcrypt.checkpw(plain.encode("utf-8"), user.password)
user.verify_password = verify_password
return user
def test_has_group_true(self):
user = self._make_user(ldap_groups=["lldap_admin", "users"])
assert user.has_group("lldap_admin") is True
assert user.has_group("users") is True
def test_has_group_false(self):
user = self._make_user(ldap_groups=["users"])
assert user.has_group("lldap_admin") is False
def test_has_group_empty_list(self):
user = self._make_user(ldap_groups=[])
assert user.has_group("anything") is False
def test_has_group_none(self):
user = self._make_user(ldap_groups=None)
assert user.has_group("anything") is False
def test_is_admin_true(self):
user = self._make_user(ldap_groups=["lldap_admin"])
assert user.is_admin() is True
def test_is_admin_false(self):
user = self._make_user(ldap_groups=["users"])
assert user.is_admin() is False
def test_is_admin_empty(self):
user = self._make_user(ldap_groups=[])
assert user.is_admin() is False
def test_set_and_verify_password(self):
user = self._make_user()
user.set_password("hunter2")
assert user.password is not None
assert user.verify_password("hunter2") is True
assert user.verify_password("wrong") is False
def test_verify_password_no_password_set(self):
user = self._make_user(password=None)
assert user.verify_password("anything") is False
def test_password_is_hashed(self):
user = self._make_user()
user.set_password("mypassword")
# The stored password should not be the plaintext
assert user.password != b"mypassword"
assert user.password != "mypassword"
-254
View File
@@ -1,254 +0,0 @@
"""Tests for YNAB service data formatting and filtering logic."""
import os
from unittest.mock import MagicMock, patch
import pytest
def _mock_category(
name, budgeted, activity, balance, deleted=False, hidden=False, goal_type=None
):
cat = MagicMock()
cat.name = name
cat.budgeted = budgeted
cat.activity = activity
cat.balance = balance
cat.deleted = deleted
cat.hidden = hidden
cat.goal_type = goal_type
return cat
def _mock_transaction(
var_date, payee_name, category_name, amount, memo="", deleted=False, approved=True
):
txn = MagicMock()
txn.var_date = var_date
txn.payee_name = payee_name
txn.category_name = category_name
txn.amount = amount
txn.memo = memo
txn.deleted = deleted
txn.approved = approved
return txn
@pytest.fixture
def ynab_service():
"""Create a YNABService with mocked API client."""
with patch.dict(
os.environ, {"YNAB_ACCESS_TOKEN": "fake-token", "YNAB_BUDGET_ID": "test-budget"}
):
with patch("utils.ynab_service.ynab") as mock_ynab:
# Mock the configuration and API client chain
mock_ynab.Configuration.return_value = MagicMock()
mock_ynab.ApiClient.return_value = MagicMock()
mock_ynab.PlansApi.return_value = MagicMock()
mock_ynab.TransactionsApi.return_value = MagicMock()
mock_ynab.MonthsApi.return_value = MagicMock()
mock_ynab.CategoriesApi.return_value = MagicMock()
from utils.ynab_service import YNABService
service = YNABService()
yield service
class TestGetBudgetSummary:
def test_calculates_totals(self, ynab_service):
categories = [
_mock_category("Groceries", 500_000, -350_000, 150_000),
_mock_category("Rent", 1_500_000, -1_500_000, 0),
]
mock_month = MagicMock()
mock_month.to_be_budgeted = 200_000
mock_budget = MagicMock()
mock_budget.name = "My Budget"
mock_budget.months = [mock_month]
mock_budget.categories = categories
mock_budget.currency_format = MagicMock(iso_code="USD")
mock_response = MagicMock()
mock_response.data.budget = mock_budget
ynab_service.plans_api.get_plan_by_id.return_value = mock_response
result = ynab_service.get_budget_summary()
assert result["budget_name"] == "My Budget"
assert result["to_be_budgeted"] == 200.0
assert result["total_budgeted"] == 2000.0 # (500k + 1500k) / 1000
assert result["total_activity"] == -1850.0
assert result["currency_format"] == "USD"
def test_skips_deleted_and_hidden(self, ynab_service):
categories = [
_mock_category("Active", 100_000, -50_000, 50_000),
_mock_category("Deleted", 999_000, -999_000, 0, deleted=True),
_mock_category("Hidden", 999_000, -999_000, 0, hidden=True),
]
mock_month = MagicMock()
mock_month.to_be_budgeted = 0
mock_budget = MagicMock()
mock_budget.name = "Budget"
mock_budget.months = [mock_month]
mock_budget.categories = categories
mock_budget.currency_format = None
mock_response = MagicMock()
mock_response.data.budget = mock_budget
ynab_service.plans_api.get_plan_by_id.return_value = mock_response
result = ynab_service.get_budget_summary()
assert result["total_budgeted"] == 100.0
assert result["currency_format"] == "USD" # Default fallback
class TestGetTransactions:
def test_filters_by_date_range(self, ynab_service):
transactions = [
_mock_transaction("2026-01-05", "Store", "Groceries", -25_000),
_mock_transaction("2026-01-15", "Gas", "Transport", -40_000),
_mock_transaction(
"2026-02-01", "Store", "Groceries", -30_000
), # Out of range
]
mock_response = MagicMock()
mock_response.data.transactions = transactions
ynab_service.transactions_api.get_transactions.return_value = mock_response
result = ynab_service.get_transactions(
start_date="2026-01-01", end_date="2026-01-31"
)
assert result["count"] == 2
assert result["total_amount"] == -65.0
def test_filters_by_category(self, ynab_service):
transactions = [
_mock_transaction("2026-01-05", "Store", "Groceries", -25_000),
_mock_transaction("2026-01-06", "Gas", "Transport", -40_000),
]
mock_response = MagicMock()
mock_response.data.transactions = transactions
ynab_service.transactions_api.get_transactions.return_value = mock_response
result = ynab_service.get_transactions(
start_date="2026-01-01",
end_date="2026-01-31",
category_name="groceries", # Case insensitive
)
assert result["count"] == 1
assert result["transactions"][0]["category"] == "Groceries"
def test_filters_by_payee(self, ynab_service):
transactions = [
_mock_transaction("2026-01-05", "Whole Foods", "Groceries", -25_000),
_mock_transaction("2026-01-06", "Shell Gas", "Transport", -40_000),
]
mock_response = MagicMock()
mock_response.data.transactions = transactions
ynab_service.transactions_api.get_transactions.return_value = mock_response
result = ynab_service.get_transactions(
start_date="2026-01-01",
end_date="2026-01-31",
payee_name="whole",
)
assert result["count"] == 1
assert result["transactions"][0]["payee"] == "Whole Foods"
def test_skips_deleted(self, ynab_service):
transactions = [
_mock_transaction("2026-01-05", "Store", "Groceries", -25_000),
_mock_transaction("2026-01-06", "Deleted", "Other", -10_000, deleted=True),
]
mock_response = MagicMock()
mock_response.data.transactions = transactions
ynab_service.transactions_api.get_transactions.return_value = mock_response
result = ynab_service.get_transactions(
start_date="2026-01-01", end_date="2026-01-31"
)
assert result["count"] == 1
def test_converts_milliunits(self, ynab_service):
transactions = [
_mock_transaction("2026-01-05", "Store", "Groceries", -12_340),
]
mock_response = MagicMock()
mock_response.data.transactions = transactions
ynab_service.transactions_api.get_transactions.return_value = mock_response
result = ynab_service.get_transactions(
start_date="2026-01-01", end_date="2026-01-31"
)
assert result["transactions"][0]["amount"] == -12.34
def test_sorts_by_date_descending(self, ynab_service):
transactions = [
_mock_transaction("2026-01-01", "A", "Cat", -10_000),
_mock_transaction("2026-01-15", "B", "Cat", -20_000),
_mock_transaction("2026-01-10", "C", "Cat", -30_000),
]
mock_response = MagicMock()
mock_response.data.transactions = transactions
ynab_service.transactions_api.get_transactions.return_value = mock_response
result = ynab_service.get_transactions(
start_date="2026-01-01", end_date="2026-01-31"
)
dates = [t["date"] for t in result["transactions"]]
assert dates == sorted(dates, reverse=True)
class TestGetCategorySpending:
def test_month_format_normalization(self, ynab_service):
"""Passing YYYY-MM should be normalized to YYYY-MM-01."""
categories = [_mock_category("Food", 100_000, -50_000, 50_000)]
mock_month = MagicMock()
mock_month.categories = categories
mock_month.to_be_budgeted = 0
mock_response = MagicMock()
mock_response.data.month = mock_month
ynab_service.months_api.get_plan_month.return_value = mock_response
result = ynab_service.get_category_spending("2026-03")
assert result["month"] == "2026-03"
def test_identifies_overspent(self, ynab_service):
categories = [
_mock_category("Dining", 200_000, -300_000, -100_000), # Overspent
_mock_category("Groceries", 500_000, -400_000, 100_000), # Fine
]
mock_month = MagicMock()
mock_month.categories = categories
mock_month.to_be_budgeted = 0
mock_response = MagicMock()
mock_response.data.month = mock_month
ynab_service.months_api.get_plan_month.return_value = mock_response
result = ynab_service.get_category_spending("2026-03")
assert len(result["overspent_categories"]) == 1
assert result["overspent_categories"][0]["name"] == "Dining"
assert result["overspent_categories"][0]["overspent_by"] == 100.0
+137
View File
@@ -0,0 +1,137 @@
import os
from math import ceil
import re
from typing import Union
from uuid import UUID, uuid4
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from dotenv import load_dotenv
from llm import LLMClient
load_dotenv()
def remove_headers_footers(text, header_patterns=None, footer_patterns=None):
if header_patterns is None:
header_patterns = [r"^.*Header.*$"]
if footer_patterns is None:
footer_patterns = [r"^.*Footer.*$"]
for pattern in header_patterns + footer_patterns:
text = re.sub(pattern, "", text, flags=re.MULTILINE)
return text.strip()
def remove_special_characters(text, special_chars=None):
if special_chars is None:
special_chars = r"[^A-Za-z0-9\s\.,;:\'\"\?\!\-]"
text = re.sub(special_chars, "", text)
return text.strip()
def remove_repeated_substrings(text, pattern=r"\.{2,}"):
text = re.sub(pattern, ".", text)
return text.strip()
def remove_extra_spaces(text):
text = re.sub(r"\n\s*\n", "\n\n", text)
text = re.sub(r"\s+", " ", text)
return text.strip()
def preprocess_text(text):
# Remove headers and footers
text = remove_headers_footers(text)
# Remove special characters
text = remove_special_characters(text)
# Remove repeated substrings like dots
text = remove_repeated_substrings(text)
# Remove extra spaces between lines and within lines
text = remove_extra_spaces(text)
# Additional cleaning steps can be added here
return text.strip()
class Chunk:
def __init__(
self,
text: str,
size: int,
document_id: UUID,
chunk_id: int,
embedding,
):
self.text = text
self.size = size
self.document_id = document_id
self.chunk_id = chunk_id
self.embedding = embedding
class Chunker:
def __init__(self, collection) -> None:
self.collection = collection
self.llm_client = LLMClient()
def embedding_fx(self, inputs):
openai_embedding_fx = OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"),
model_name="text-embedding-3-small",
)
return openai_embedding_fx(inputs)
def chunk_document(
self,
document: str,
chunk_size: int = 1000,
metadata: dict[str, Union[str, float]] = {},
) -> list[Chunk]:
doc_uuid = uuid4()
chunk_size = min(chunk_size, len(document)) or 1
chunks = []
num_chunks = ceil(len(document) / chunk_size)
document_length = len(document)
for i in range(num_chunks):
curr_pos = i * num_chunks
to_pos = (
curr_pos + chunk_size
if curr_pos + chunk_size < document_length
else document_length
)
text_chunk = self.clean_document(document[curr_pos:to_pos])
embedding = self.embedding_fx([text_chunk])
self.collection.add(
ids=[str(doc_uuid) + ":" + str(i)],
documents=[text_chunk],
embeddings=embedding,
metadatas=[metadata],
)
return chunks
def clean_document(self, document: str) -> str:
"""This function will remove information that is noise or already known.
Example: We already know all the things in here are Simba-related, so we don't need things like
"Sumamry of simba's visit"
"""
document = document.replace("\\n", "")
document = document.strip()
return preprocess_text(document)
Generated
+951 -345
View File
File diff suppressed because it is too large Load Diff