Compare commits

..

2 Commits

Author SHA1 Message Date
Ryan Chen 64dab18428 Clean up presigned URL implementation: remove dead fields, fix error handling
- Remove unused image_url from upload response and TS type
- Remove bare except in serve_image that masked config errors as 404s
- Add error state and broken-image placeholder in QuestionBubble

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-04 08:49:01 -04:00
Ryan Chen b62a8b6b3f Use presigned S3 URLs for serving images instead of proxying bytes
Browser <img> tags can't attach JWT headers, causing 401s. The image
endpoint now returns a time-limited presigned S3 URL via authenticated
API call, which the frontend fetches and uses directly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-04 08:45:35 -04:00
38 changed files with 1604 additions and 1643 deletions
+5
View File
@@ -19,6 +19,11 @@ BASE_URL=192.168.1.5:8000
LLAMA_SERVER_URL=http://192.168.1.213:8080/v1 LLAMA_SERVER_URL=http://192.168.1.213:8080/v1
LLAMA_MODEL_NAME=llama-3.1-8b-instruct LLAMA_MODEL_NAME=llama-3.1-8b-instruct
# ChromaDB Configuration
# For Docker: This is automatically set to /app/data/chromadb
# For local development: Set to a local directory path
CHROMADB_PATH=./data/chromadb
# OpenAI Configuration # OpenAI Configuration
OPENAI_API_KEY=your-openai-api-key OPENAI_API_KEY=your-openai-api-key
+3
View File
@@ -13,6 +13,9 @@ wheels/
.env .env
# Database files # Database files
chromadb/
chromadb_openai/
chroma_db/
database/ database/
*.db *.db
+4 -12
View File
@@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
## Project Overview ## Project Overview
SimbaRAG is a RAG (Retrieval-Augmented Generation) conversational AI system for querying information about Simba (a cat). It ingests documents from Paperless-NGX, stores embeddings in PostgreSQL via pgvector, and uses LLMs (Ollama or OpenAI) to answer questions. SimbaRAG is a RAG (Retrieval-Augmented Generation) conversational AI system for querying information about Simba (a cat). It ingests documents from Paperless-NGX, stores embeddings in ChromaDB, and uses LLMs (Ollama or OpenAI) to answer questions.
## Commands ## Commands
@@ -54,8 +54,9 @@ docker compose up -d
│ Docker Compose │ │ Docker Compose │
├─────────────────────────────────────────────────────────────┤ ├─────────────────────────────────────────────────────────────┤
│ raggr (port 8080) │ postgres (port 5432) │ │ raggr (port 8080) │ postgres (port 5432) │
│ ├── Quart backend │ PostgreSQL 16 + pgvector │ ├── Quart backend │ PostgreSQL 16
── React frontend (served) │ │ ── React frontend (served) │ │
│ └── ChromaDB (volume) │ │
└─────────────────────────────────────────────────────────────┘ └─────────────────────────────────────────────────────────────┘
``` ```
@@ -90,15 +91,6 @@ docker compose up -d
**Auth Flow**: LLDAP → Authelia (OIDC) → Backend JWT → Frontend localStorage **Auth Flow**: LLDAP → Authelia (OIDC) → Backend JWT → Frontend localStorage
## Testing
Always run `make test` before pushing code to ensure all tests pass.
```bash
make test # Run tests
make test-cov # Run tests with coverage
```
## Key Patterns ## Key Patterns
- All endpoints are async (`async def`) - All endpoints are async (`async def`)
+3 -2
View File
@@ -37,14 +37,15 @@ WORKDIR /app/raggr-frontend
RUN yarn install && yarn build RUN yarn install && yarn build
WORKDIR /app WORKDIR /app
# Create database directory # Create ChromaDB and database directories
RUN mkdir -p /app/database RUN mkdir -p /app/chromadb /app/database
# Expose port # Expose port
EXPOSE 8080 EXPOSE 8080
# Set environment variables # Set environment variables
ENV PYTHONPATH=/app ENV PYTHONPATH=/app
ENV CHROMADB_PATH=/app/chromadb
# Run the startup script # Run the startup script
CMD ["./startup.sh"] CMD ["./startup.sh"]
+3 -2
View File
@@ -34,15 +34,16 @@ COPY . .
WORKDIR /app/raggr-frontend WORKDIR /app/raggr-frontend
RUN yarn build RUN yarn build
# Create database directory # Create ChromaDB and database directories
WORKDIR /app WORKDIR /app
RUN mkdir -p /app/database RUN mkdir -p /app/chromadb /app/database
# Make startup script executable # Make startup script executable
RUN chmod +x /app/startup-dev.sh RUN chmod +x /app/startup-dev.sh
# Set environment variables # Set environment variables
ENV PYTHONPATH=/app ENV PYTHONPATH=/app
ENV CHROMADB_PATH=/app/chromadb
ENV PYTHONUNBUFFERED=1 ENV PYTHONUNBUFFERED=1
# Expose port # Expose port
+1 -11
View File
@@ -1,11 +1,8 @@
.PHONY: deploy redeploy build up down restart logs migrate migrate-new frontend test .PHONY: deploy build up down restart logs migrate migrate-new frontend
# Build and deploy # Build and deploy
deploy: build up deploy: build up
redeploy:
git pull && $(MAKE) down && $(MAKE) up
build: build:
docker compose build raggr docker compose build raggr
@@ -32,13 +29,6 @@ migrate-new:
migrate-history: migrate-history:
docker compose exec raggr aerich history docker compose exec raggr aerich history
# Tests
test:
pytest tests/ -v
test-cov:
pytest tests/ -v --cov
# Frontend # Frontend
frontend: frontend:
cd raggr-frontend && yarn install && yarn build cd raggr-frontend && yarn install && yarn build
+43 -5
View File
@@ -1,9 +1,8 @@
import logging import logging
import os import os
from datetime import timedelta
from dotenv import load_dotenv from dotenv import load_dotenv
from quart import Quart, jsonify, render_template, send_from_directory from quart import Quart, jsonify, render_template, request, send_from_directory
from quart_jwt_extended import JWTManager, get_jwt_identity, jwt_refresh_token_required from quart_jwt_extended import JWTManager, get_jwt_identity, jwt_refresh_token_required
from tortoise import Tortoise from tortoise import Tortoise
@@ -15,6 +14,7 @@ import blueprints.users
import blueprints.whatsapp import blueprints.whatsapp
import blueprints.users.models import blueprints.users.models
from config.db import TORTOISE_CONFIG from config.db import TORTOISE_CONFIG
from main import consult_simba_oracle
# Load environment variables # Load environment variables
load_dotenv() load_dotenv()
@@ -38,8 +38,6 @@ app = Quart(
) )
app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY", "SECRET_KEY") app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY", "SECRET_KEY")
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(hours=1)
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=30)
app.config["MAX_CONTENT_LENGTH"] = 10 * 1024 * 1024 # 10 MB upload limit app.config["MAX_CONTENT_LENGTH"] = 10 * 1024 * 1024 # 10 MB upload limit
jwt = JWTManager(app) jwt = JWTManager(app)
@@ -77,6 +75,39 @@ async def serve_react_app(path):
return await render_template("index.html") return await render_template("index.html")
@app.route("/api/query", methods=["POST"])
@jwt_refresh_token_required
async def query():
    """Answer a question inside an existing conversation.

    Reads ``query`` and ``conversation_id`` from the JSON request body, stores
    the question as a "user" message, asks ``consult_simba_oracle`` for an
    answer based on the conversation transcript, stores that answer as a
    "simba" message, and returns it as ``{"response": ...}``.
    """
    conversation_logic = blueprints.conversation.logic

    user = await blueprints.users.models.User.get(id=get_jwt_identity())

    body = await request.get_json()
    question = body.get("query")
    conversation = await conversation_logic.get_conversation_by_id(
        body.get("conversation_id")
    )
    await conversation.fetch_related("messages")

    # Persist the question first so it is part of the transcript built below.
    await conversation_logic.add_message_to_conversation(
        conversation=conversation,
        message=question,
        speaker="user",
        user=user,
    )
    transcript = await conversation_logic.get_conversation_transcript(
        user=user, conversation=conversation
    )

    answer = consult_simba_oracle(input=question, transcript=transcript)
    await conversation_logic.add_message_to_conversation(
        conversation=conversation,
        message=answer,
        speaker="simba",
        user=user,
    )
    return jsonify({"response": answer})
@app.route("/api/messages", methods=["GET"]) @app.route("/api/messages", methods=["GET"])
@jwt_refresh_token_required @jwt_refresh_token_required
async def get_messages(): async def get_messages():
@@ -101,10 +132,17 @@ async def get_messages():
} }
) )
name = conversation.name
if len(messages) > 8:
name = await blueprints.conversation.logic.rename_conversation(
user=user,
conversation=conversation,
)
return jsonify( return jsonify(
{ {
"id": str(conversation.id), "id": str(conversation.id),
"name": conversation.name, "name": name,
"messages": messages, "messages": messages,
"created_at": conversation.created_at.isoformat(), "created_at": conversation.created_at.isoformat(),
"updated_at": conversation.updated_at.isoformat(), "updated_at": conversation.updated_at.isoformat(),
+21 -30
View File
@@ -1,3 +1,4 @@
import datetime
import json import json
import logging import logging
import uuid import uuid
@@ -19,8 +20,8 @@ from .agents import main_agent
from .logic import ( from .logic import (
add_message_to_conversation, add_message_to_conversation,
get_conversation_by_id, get_conversation_by_id,
rename_conversation,
) )
from .memory import get_memories_for_user
from .models import ( from .models import (
Conversation, Conversation,
PydConversation, PydConversation,
@@ -35,27 +36,15 @@ conversation_blueprint = Blueprint(
_SYSTEM_PROMPT = SIMBA_SYSTEM_PROMPT _SYSTEM_PROMPT = SIMBA_SYSTEM_PROMPT
async def _build_system_prompt_with_memories(user_id: str) -> str:
    """Return the base system prompt, extended with the user's saved memories.

    When the user has no memories the base prompt is returned unchanged;
    otherwise every memory is appended as a "- " bullet under a
    "USER MEMORIES" heading.
    """
    remembered = await get_memories_for_user(user_id)
    if remembered:
        bullets = "\n".join(f"- {fact}" for fact in remembered)
        return f"{_SYSTEM_PROMPT}\n\nUSER MEMORIES (facts the user has asked you to remember):\n{bullets}"
    return _SYSTEM_PROMPT
def _build_messages_payload( def _build_messages_payload(
conversation, conversation, query_text: str, image_description: str | None = None
query_text: str,
image_description: str | None = None,
system_prompt: str | None = None,
) -> list: ) -> list:
recent_messages = ( recent_messages = (
conversation.messages[-10:] conversation.messages[-10:]
if len(conversation.messages) > 10 if len(conversation.messages) > 10
else conversation.messages else conversation.messages
) )
messages_payload = [{"role": "system", "content": system_prompt or _SYSTEM_PROMPT}] messages_payload = [{"role": "system", "content": _SYSTEM_PROMPT}]
for msg in recent_messages[:-1]: # Exclude the message we just added for msg in recent_messages[:-1]: # Exclude the message we just added
role = "user" if msg.speaker == "user" else "assistant" role = "user" if msg.speaker == "user" else "assistant"
text = msg.text text = msg.text
@@ -91,14 +80,10 @@ async def query():
user=user, user=user,
) )
system_prompt = await _build_system_prompt_with_memories(str(user.id)) messages_payload = _build_messages_payload(conversation, query)
messages_payload = _build_messages_payload(
conversation, query, system_prompt=system_prompt
)
payload = {"messages": messages_payload} payload = {"messages": messages_payload}
agent_config = {"configurable": {"user_id": str(user.id)}}
response = await main_agent.ainvoke(payload, config=agent_config) response = await main_agent.ainvoke(payload)
message = response.get("messages", [])[-1].content message = response.get("messages", [])[-1].content
await add_message_to_conversation( await add_message_to_conversation(
conversation=conversation, conversation=conversation,
@@ -178,19 +163,15 @@ async def stream_query():
logging.error(f"Failed to analyze image: {e}") logging.error(f"Failed to analyze image: {e}")
image_description = "[Image could not be analyzed]" image_description = "[Image could not be analyzed]"
system_prompt = await _build_system_prompt_with_memories(str(user.id))
messages_payload = _build_messages_payload( messages_payload = _build_messages_payload(
conversation, query_text or "", image_description, system_prompt=system_prompt conversation, query_text or "", image_description
) )
payload = {"messages": messages_payload} payload = {"messages": messages_payload}
agent_config = {"configurable": {"user_id": str(user.id)}}
async def event_generator(): async def event_generator():
final_message = None final_message = None
try: try:
async for event in main_agent.astream_events( async for event in main_agent.astream_events(payload, version="v2"):
payload, version="v2", config=agent_config
):
event_type = event.get("event") event_type = event.get("event")
if event_type == "on_tool_start": if event_type == "on_tool_start":
@@ -240,6 +221,8 @@ async def stream_query():
@jwt_refresh_token_required @jwt_refresh_token_required
async def get_conversation(conversation_id: str): async def get_conversation(conversation_id: str):
conversation = await Conversation.get(id=conversation_id) conversation = await Conversation.get(id=conversation_id)
current_user_uuid = get_jwt_identity()
user = await blueprints.users.models.User.get(id=current_user_uuid)
await conversation.fetch_related("messages") await conversation.fetch_related("messages")
# Manually serialize the conversation with messages # Manually serialize the conversation with messages
@@ -254,10 +237,18 @@ async def get_conversation(conversation_id: str):
"image_key": msg.image_key, "image_key": msg.image_key,
} }
) )
name = conversation.name
if len(messages) > 8 and "datetime" in name.lower():
name = await rename_conversation(
user=user,
conversation=conversation,
)
print(name)
return jsonify( return jsonify(
{ {
"id": str(conversation.id), "id": str(conversation.id),
"name": conversation.name, "name": name,
"messages": messages, "messages": messages,
"created_at": conversation.created_at.isoformat(), "created_at": conversation.created_at.isoformat(),
"updated_at": conversation.updated_at.isoformat(), "updated_at": conversation.updated_at.isoformat(),
@@ -271,7 +262,7 @@ async def create_conversation():
user_uuid = get_jwt_identity() user_uuid = get_jwt_identity()
user = await blueprints.users.models.User.get(id=user_uuid) user = await blueprints.users.models.User.get(id=user_uuid)
conversation = await Conversation.create( conversation = await Conversation.create(
name="New Conversation", name=f"{user.username} {datetime.datetime.now().timestamp}",
user=user, user=user,
) )
@@ -284,7 +275,7 @@ async def create_conversation():
async def get_all_conversations(): async def get_all_conversations():
user_uuid = get_jwt_identity() user_uuid = get_jwt_identity()
user = await blueprints.users.models.User.get(id=user_uuid) user = await blueprints.users.models.User.get(id=user_uuid)
conversations = Conversation.filter(user=user).order_by("-updated_at") conversations = Conversation.filter(user=user)
serialized_conversations = await PydListConversation.from_queryset(conversations) serialized_conversations = await PydListConversation.from_queryset(conversations)
return jsonify(serialized_conversations.model_dump()) return jsonify(serialized_conversations.model_dump())
+2 -31
View File
@@ -5,11 +5,9 @@ from dotenv import load_dotenv
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain.chat_models import BaseChatModel from langchain.chat_models import BaseChatModel
from langchain.tools import tool from langchain.tools import tool
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from tavily import AsyncTavilyClient from tavily import AsyncTavilyClient
from blueprints.conversation.memory import save_memory
from blueprints.rag.logic import query_vector_store from blueprints.rag.logic import query_vector_store
from utils.obsidian_service import ObsidianService from utils.obsidian_service import ObsidianService
from utils.ynab_service import YNABService from utils.ynab_service import YNABService
@@ -328,7 +326,7 @@ async def obsidian_search_notes(query: str) -> str:
return "Obsidian integration is not configured. Please set OBSIDIAN_VAULT_PATH environment variable." return "Obsidian integration is not configured. Please set OBSIDIAN_VAULT_PATH environment variable."
try: try:
# Query vector store for obsidian documents # Query ChromaDB for obsidian documents
serialized, docs = await query_vector_store(query=query) serialized, docs = await query_vector_store(query=query)
return serialized return serialized
@@ -591,35 +589,8 @@ async def obsidian_create_task(
return f"Error creating task: {str(e)}" return f"Error creating task: {str(e)}"
@tool
async def save_user_memory(content: str, config: RunnableConfig) -> str:
"""Save a fact or preference about the user for future conversations.
Use this tool when the user:
- Explicitly asks you to remember something ("remember that...", "keep in mind...")
- Shares a personal preference that would be useful in future conversations
(e.g., "I prefer metric units", "my cat's name is Luna")
- Tells you a meaningful personal fact (e.g., "I'm allergic to peanuts")
Do NOT save:
- Trivial or ephemeral info (e.g., "I'm tired today")
- Information already in the system prompt or documents
- Conversation-specific context that won't matter later
Args:
content: A concise statement of the fact or preference to remember.
Write it as a standalone sentence (e.g., "User prefers dark mode"
rather than "likes dark mode").
Returns:
Confirmation that the memory was saved.
"""
user_id = config["configurable"]["user_id"]
return await save_memory(user_id=user_id, content=content)
# Create tools list based on what's available # Create tools list based on what's available
tools = [get_current_date, simba_search, web_search, save_user_memory] tools = [get_current_date, simba_search, web_search]
if ynab_enabled: if ynab_enabled:
tools.extend( tools.extend(
[ [
+21 -7
View File
@@ -1,8 +1,9 @@
import tortoise.exceptions import tortoise.exceptions
from langchain_openai import ChatOpenAI
import blueprints.users.models import blueprints.users.models
from .models import Conversation, ConversationMessage from .models import Conversation, ConversationMessage, RenameConversationOutputSchema
async def create_conversation(name: str = "") -> Conversation: async def create_conversation(name: str = "") -> Conversation:
@@ -18,12 +19,6 @@ async def add_message_to_conversation(
image_key: str | None = None, image_key: str | None = None,
) -> ConversationMessage: ) -> ConversationMessage:
print(conversation, message, speaker) print(conversation, message, speaker)
# Name the conversation after the first user message
if speaker == "user" and not await conversation.messages.all().exists():
conversation.name = message[:100]
await conversation.save()
message = await ConversationMessage.create( message = await ConversationMessage.create(
text=message, text=message,
speaker=speaker, speaker=speaker,
@@ -66,3 +61,22 @@ async def get_conversation_transcript(
messages.append(f"{message.speaker} at {message.created_at}: {message.text}") messages.append(f"{message.speaker} at {message.created_at}: {message.text}")
return "\n".join(messages) return "\n".join(messages)
async def rename_conversation(
    user: blueprints.users.models.User,
    conversation: Conversation,
) -> str:
    """Ask the LLM for a one-line title and store it as the conversation name.

    Args:
        user: Owner of the conversation; forwarded to the transcript builder.
        conversation: Conversation to rename. Its ``name`` is updated and
            saved only when the model returns a non-empty title.

    Returns:
        The new name, or the conversation's current name when the model did
        not produce a usable title.
    """
    transcript: str = await get_conversation_transcript(
        user=user, conversation=conversation
    )

    llm = ChatOpenAI(model="gpt-4o-mini")
    structured_llm = llm.with_structured_output(RenameConversationOutputSchema)

    prompt = f"Summarize the following conversation into a sassy one-liner title:\n\n{transcript}"
    # Use the async variant: the synchronous invoke() would block the event
    # loop for the whole duration of the HTTP round-trip to the model.
    response = await structured_llm.ainvoke(prompt)

    # Structured output can come back as a mapping or as an object with
    # attributes depending on the schema type, so accept both shapes.
    if isinstance(response, dict):
        title = response.get("title")
    else:
        title = getattr(response, "title", None)

    new_name = str(title or "").strip()
    if not new_name:
        # Keep the existing name rather than blanking it out.
        return conversation.name

    conversation.name = new_name
    await conversation.save()
    return new_name
-19
View File
@@ -1,19 +0,0 @@
from .models import UserMemory
async def get_memories_for_user(user_id: str) -> list[str]:
    """Fetch the content of every memory stored for ``user_id``.

    Rows are ordered by ``updated_at`` descending, so the most recently
    updated memory comes first.
    """
    rows = await UserMemory.filter(user_id=user_id).order_by("-updated_at")
    contents: list[str] = []
    for row in rows:
        contents.append(row.content)
    return contents
async def save_memory(user_id: str, content: str) -> str:
    """Store ``content`` as a memory for ``user_id`` unless an identical one exists.

    An exact duplicate is not inserted again; its ``updated_at`` column is
    re-saved instead. Returns a short status message describing which of the
    two happened.
    """
    duplicate = await UserMemory.filter(user_id=user_id, content=content).first()
    if not duplicate:
        await UserMemory.create(user_id=user_id, content=content)
        return "Memory saved."

    duplicate.updated_at = None  # auto_now=True will refresh it on save
    await duplicate.save(update_fields=["updated_at"])
    return "Memory already exists (refreshed)."
+7 -11
View File
@@ -1,4 +1,5 @@
import enum import enum
from dataclasses import dataclass
from tortoise import fields from tortoise import fields
from tortoise.contrib.pydantic import ( from tortoise.contrib.pydantic import (
@@ -8,6 +9,12 @@ from tortoise.contrib.pydantic import (
from tortoise.models import Model from tortoise.models import Model
# Schema passed to the LLM's structured-output call when a conversation is renamed.
@dataclass
class RenameConversationOutputSchema:
# Proposed conversation title; the only field the rename logic reads back.
title: str
# NOTE(review): presumably the model's reasoning for the chosen title — it is not
# read anywhere in the visible code; confirm it is still needed.
justification: str
class Speaker(enum.Enum): class Speaker(enum.Enum):
USER = "user" USER = "user"
SIMBA = "simba" SIMBA = "simba"
@@ -40,17 +47,6 @@ class ConversationMessage(Model):
table = "conversation_messages" table = "conversation_messages"
# ORM model for one remembered fact about a user, stored in the "user_memories" table.
class UserMemory(Model):
id = fields.UUIDField(primary_key=True)
# Owning user; the reverse relation on the user side is named "memories".
user = fields.ForeignKeyField("models.User", related_name="memories")
# The remembered text itself.
content = fields.TextField()
# Declared with auto_now_add=True.
created_at = fields.DatetimeField(auto_now_add=True)
# Declared with auto_now=True; used as the sort key when memories are listed.
updated_at = fields.DatetimeField(auto_now=True)
class Meta:
table = "user_memories"
PydConversationMessage = pydantic_model_creator(ConversationMessage) PydConversationMessage = pydantic_model_creator(ConversationMessage)
PydConversation = pydantic_model_creator( PydConversation = pydantic_model_creator(
Conversation, name="Conversation", allow_cycles=True, exclude=("user",) Conversation, name="Conversation", allow_cycles=True, exclude=("user",)
+1 -4
View File
@@ -54,7 +54,4 @@ You have access to Ryan's daily journal notes. Each note lives at journal/YYYY/Y
- Use journal_get_tasks to list tasks (done/pending) for today or a specific date - Use journal_get_tasks to list tasks (done/pending) for today or a specific date
- Use journal_add_task to add a new task to today's (or a given date's) note - Use journal_add_task to add a new task to today's (or a given date's) note
- Use journal_complete_task to check off a task as done - Use journal_complete_task to check off a task as done
Use these tools when Ryan asks about today's tasks, wants to add something to his list, or wants to mark a task complete. Use these tools when Ryan asks about today's tasks, wants to add something to his list, or wants to mark a task complete."""
USER MEMORY:
You can remember facts about the user across conversations using the save_user_memory tool. When a user explicitly asks you to remember something, or shares a meaningful preference or personal fact, save it. Saved memories will automatically appear at the end of this prompt in future conversations under "USER MEMORIES"."""
+9 -7
View File
@@ -1,12 +1,7 @@
from quart import Blueprint, jsonify from quart import Blueprint, jsonify
from quart_jwt_extended import jwt_refresh_token_required from quart_jwt_extended import jwt_refresh_token_required
from .logic import ( from .logic import fetch_obsidian_documents, get_vector_store_stats, index_documents, index_obsidian_documents, vector_store
delete_all_documents,
get_vector_store_stats,
index_documents,
index_obsidian_documents,
)
from blueprints.users.decorators import admin_required from blueprints.users.decorators import admin_required
rag_blueprint = Blueprint("rag_api", __name__, url_prefix="/api/rag") rag_blueprint = Blueprint("rag_api", __name__, url_prefix="/api/rag")
@@ -37,7 +32,14 @@ async def trigger_index():
async def trigger_reindex(): async def trigger_reindex():
"""Clear and reindex all documents. Admin only.""" """Clear and reindex all documents. Admin only."""
try: try:
delete_all_documents() # Clear existing documents
collection = vector_store._collection
all_docs = collection.get()
if all_docs["ids"]:
collection.delete(ids=all_docs["ids"])
# Reindex
await index_documents() await index_documents()
stats = get_vector_store_stats() stats = get_vector_store_stats()
return jsonify({"status": "success", "stats": stats}) return jsonify({"status": "success", "stats": stats})
+25 -121
View File
@@ -1,13 +1,11 @@
import datetime import datetime
import logging
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings from langchain_openai import OpenAIEmbeddings
from langchain_postgres import PGVector
from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_text_splitters import RecursiveCharacterTextSplitter
from sqlalchemy import create_engine, text
from .fetchers import PaperlessNGXService from .fetchers import PaperlessNGXService
from utils.obsidian_service import ObsidianService from utils.obsidian_service import ObsidianService
@@ -15,40 +13,13 @@ from utils.obsidian_service import ObsidianService
# Load environment variables # Load environment variables
load_dotenv() load_dotenv()
logger = logging.getLogger(__name__)
embeddings = OpenAIEmbeddings(model="text-embedding-3-small") embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
# Convert Tortoise-style postgres:// URL to SQLAlchemy-style postgresql+psycopg:// vector_store = Chroma(
_db_url = os.getenv(
"DATABASE_URL", "postgres://raggr:raggr_dev_password@localhost:5432/raggr"
)
_pgvector_url = _db_url.replace("postgres://", "postgresql+psycopg://")
# Lazy-initialized vector store (defers DB connection to first use)
_vector_store = None
def _get_vector_store() -> PGVector:
global _vector_store
if _vector_store is None:
_vector_store = PGVector(
embeddings=embeddings,
collection_name="simba_docs", collection_name="simba_docs",
connection=_pgvector_url, embedding_function=embeddings,
use_jsonb=True, persist_directory=os.getenv("CHROMADB_PATH", ""),
create_extension=False, # created by docker init script )
async_mode=True,
)
return _vector_store
def _get_engine():
    """Get a SQLAlchemy engine for direct queries.

    The engine is created from ``_pgvector_url`` on the first call and then
    memoized as the ``_engine`` attribute of this function.
    """
    try:
        return _get_engine._engine
    except AttributeError:
        _get_engine._engine = create_engine(_pgvector_url)
        return _get_engine._engine
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000, # chunk size (characters) chunk_size=1000, # chunk size (characters)
@@ -57,22 +28,6 @@ text_splitter = RecursiveCharacterTextSplitter(
) )
def _get_collection_id():
    """Get the UUID of our collection from the langchain_pg_collection table.

    Returns:
        The collection UUID, or ``None`` when the collection row (or the
        table itself) does not exist, or when the lookup fails.
    """
    # Local import: sqlalchemy is already a dependency of this module.
    from sqlalchemy.exc import ProgrammingError

    engine = _get_engine()
    try:
        with engine.connect() as conn:
            result = conn.execute(
                text("SELECT uuid FROM langchain_pg_collection WHERE name = :name"),
                {"name": "simba_docs"},
            )
            row = result.fetchone()
            return row[0] if row else None
    except ProgrammingError:
        # Table doesn't exist yet (first run before any indexing)
        return None
    except Exception:
        # Still best-effort for callers, but do not let a connection or
        # credentials problem silently look like "no collection".
        logger.warning("Could not look up vector store collection id", exc_info=True)
        return None
def date_to_epoch(date_str: str) -> float: def date_to_epoch(date_str: str) -> float:
split_date = date_str.split("-") split_date = date_str.split("-")
date = datetime.datetime( date = datetime.datetime(
@@ -108,7 +63,6 @@ async def index_documents():
documents = await fetch_documents_from_paperless_ngx() documents = await fetch_documents_from_paperless_ngx()
splits = text_splitter.split_documents(documents) splits = text_splitter.split_documents(documents)
vector_store = _get_vector_store()
await vector_store.aadd_documents(documents=splits) await vector_store.aadd_documents(documents=splits)
@@ -138,17 +92,13 @@ async def fetch_obsidian_documents() -> list[Document]:
"filepath": parsed["filepath"], "filepath": parsed["filepath"],
"tags": parsed["tags"], "tags": parsed["tags"],
"created_at": parsed["metadata"].get("created_at"), "created_at": parsed["metadata"].get("created_at"),
**{ **{k: v for k, v in parsed["metadata"].items() if k not in ["created_at", "created_by"]},
k: v
for k, v in parsed["metadata"].items()
if k not in ["created_at", "created_by"]
},
}, },
) )
documents.append(document) documents.append(document)
except Exception as e: except Exception as e:
logger.warning(f"Error reading {md_path}: {e}") print(f"Error reading {md_path}: {e}")
continue continue
return documents return documents
@@ -159,25 +109,26 @@ async def index_obsidian_documents():
Deletes existing obsidian source chunks before re-indexing. Deletes existing obsidian source chunks before re-indexing.
""" """
obsidian_service = ObsidianService()
documents = await fetch_obsidian_documents() documents = await fetch_obsidian_documents()
if not documents: if not documents:
logger.info("No Obsidian documents found to index") print("No Obsidian documents found to index")
return {"indexed": 0} return {"indexed": 0}
# Delete existing obsidian chunks # Delete existing obsidian chunks
delete_documents_by_metadata("source", "obsidian") existing_results = vector_store.get(where={"source": "obsidian"})
if existing_results.get("ids"):
await vector_store.adelete(existing_results["ids"])
# Split and index documents # Split and index documents
splits = text_splitter.split_documents(documents) splits = text_splitter.split_documents(documents)
vector_store = _get_vector_store()
await vector_store.aadd_documents(documents=splits) await vector_store.aadd_documents(documents=splits)
return {"indexed": len(documents)} return {"indexed": len(documents)}
async def query_vector_store(query: str): async def query_vector_store(query: str):
vector_store = _get_vector_store()
retrieved_docs = await vector_store.asimilarity_search(query, k=2) retrieved_docs = await vector_store.asimilarity_search(query, k=2)
serialized = "\n\n".join( serialized = "\n\n".join(
(f"Source: {doc.metadata}\nContent: {doc.page_content}") (f"Source: {doc.metadata}\nContent: {doc.page_content}")
@@ -186,79 +137,32 @@ async def query_vector_store(query: str):
return serialized, retrieved_docs return serialized, retrieved_docs
def delete_all_documents():
    """Delete all documents from the vector store collection.

    Does nothing when the collection id cannot be resolved.
    """
    cid = _get_collection_id()
    if cid:
        statement = text("DELETE FROM langchain_pg_embedding WHERE collection_id = :cid")
        with _get_engine().connect() as connection:
            connection.execute(statement, {"cid": cid})
            connection.commit()
def delete_documents_by_metadata(key: str, value: str):
    """Delete documents matching a metadata key/value pair.

    Only rows of our collection whose ``cmetadata`` entry ``key`` equals
    ``value`` (compared as text) are removed. Does nothing when the
    collection id cannot be resolved.
    """
    cid = _get_collection_id()
    if cid:
        statement = text(
            "DELETE FROM langchain_pg_embedding "
            "WHERE collection_id = :cid AND cmetadata->>:key = :value"
        )
        bind_params = {"cid": cid, "key": key, "value": value}
        with _get_engine().connect() as connection:
            connection.execute(statement, bind_params)
            connection.commit()
def get_vector_store_stats(): def get_vector_store_stats():
"""Get statistics about the vector store.""" """Get statistics about the vector store."""
collection_id = _get_collection_id() collection = vector_store._collection
count = 0 count = collection.count()
if collection_id:
engine = _get_engine()
with engine.connect() as conn:
result = conn.execute(
text(
"SELECT COUNT(*) FROM langchain_pg_embedding WHERE collection_id = :cid"
),
{"cid": collection_id},
)
count = result.scalar()
return { return {
"total_documents": count, "total_documents": count,
"collection_name": "simba_docs", "collection_name": collection.name,
} }
def list_all_documents(limit: int = 10): def list_all_documents(limit: int = 10):
"""List documents in the vector store with their metadata.""" """List documents in the vector store with their metadata."""
collection_id = _get_collection_id() collection = vector_store._collection
if not collection_id: results = collection.get(limit=limit, include=["metadatas", "documents"])
return []
engine = _get_engine()
with engine.connect() as conn:
result = conn.execute(
text(
"SELECT id, document, cmetadata FROM langchain_pg_embedding "
"WHERE collection_id = :cid LIMIT :limit"
),
{"cid": collection_id, "limit": limit},
)
documents = [] documents = []
for row in result: for i, doc_id in enumerate(results["ids"]):
documents.append( documents.append(
{ {
"id": str(row[0]), "id": doc_id,
"metadata": row[2], "metadata": results["metadatas"][i]
"content_preview": row[1][:200] if row[1] else None, if results.get("metadatas")
else None,
"content_preview": results["documents"][i][:200]
if results.get("documents")
else None,
} }
) )
+3 -3
View File
@@ -35,7 +35,7 @@ class OIDCUserService:
claims.get("preferred_username") or claims.get("name") or user.username claims.get("preferred_username") or claims.get("name") or user.username
) )
# Update LDAP groups from claims # Update LDAP groups from claims
user.ldap_groups = claims.get("groups") or [] user.ldap_groups = claims.get("groups", [])
await user.save() await user.save()
return user return user
@@ -48,7 +48,7 @@ class OIDCUserService:
user.oidc_subject = oidc_subject user.oidc_subject = oidc_subject
user.auth_provider = "oidc" user.auth_provider = "oidc"
user.password = None # Clear password user.password = None # Clear password
user.ldap_groups = claims.get("groups") or [] user.ldap_groups = claims.get("groups", [])
await user.save() await user.save()
return user return user
@@ -61,7 +61,7 @@ class OIDCUserService:
) )
# Extract LDAP groups from claims # Extract LDAP groups from claims
groups = claims.get("groups") or [] groups = claims.get("groups", [])
user = await User.create( user = await User.create(
id=uuid4(), id=uuid4(),
+4 -2
View File
@@ -2,7 +2,7 @@ version: "3.8"
services: services:
postgres: postgres:
image: pgvector/pgvector:pg16 image: postgres:16-alpine
ports: ports:
- "5432:5432" - "5432:5432"
environment: environment:
@@ -11,7 +11,6 @@ services:
- POSTGRES_DB=${POSTGRES_DB:-raggr} - POSTGRES_DB=${POSTGRES_DB:-raggr}
volumes: volumes:
- postgres_data:/var/lib/postgresql/data - postgres_data:/var/lib/postgresql/data
- ./docker/init-pgvector.sql:/docker-entrypoint-initdb.d/init-pgvector.sql
healthcheck: healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-raggr}"] test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-raggr}"]
interval: 10s interval: 10s
@@ -30,6 +29,7 @@ services:
- PAPERLESS_TOKEN=${PAPERLESS_TOKEN} - PAPERLESS_TOKEN=${PAPERLESS_TOKEN}
- BASE_URL=${BASE_URL} - BASE_URL=${BASE_URL}
- OLLAMA_URL=${OLLAMA_URL:-http://localhost:11434} - OLLAMA_URL=${OLLAMA_URL:-http://localhost:11434}
- CHROMADB_PATH=/app/data/chromadb
- OPENAI_API_KEY=${OPENAI_API_KEY} - OPENAI_API_KEY=${OPENAI_API_KEY}
- JWT_SECRET_KEY=${JWT_SECRET_KEY} - JWT_SECRET_KEY=${JWT_SECRET_KEY}
- LLAMA_SERVER_URL=${LLAMA_SERVER_URL} - LLAMA_SERVER_URL=${LLAMA_SERVER_URL}
@@ -66,8 +66,10 @@ services:
postgres: postgres:
condition: service_healthy condition: service_healthy
volumes: volumes:
- chromadb_data:/app/data/chromadb
- ./obvault:/app/data/obsidian - ./obvault:/app/data/obsidian
restart: unless-stopped restart: unless-stopped
volumes: volumes:
chromadb_data:
postgres_data: postgres_data:
-1
View File
@@ -1 +0,0 @@
CREATE EXTENSION IF NOT EXISTS vector;
+278
View File
@@ -0,0 +1,278 @@
import argparse
import datetime
import logging
import os
import sqlite3
import time
from dotenv import load_dotenv
import chromadb
from utils.chunker import Chunker
from utils.cleaner import pdf_to_image, summarize_pdf_image
from llm import LLMClient
from scripts.query import QueryGenerator
from utils.request import PaperlessNGXService
# Load variables from a .env file before any os.getenv lookup below.
_dotenv_loaded = load_dotenv()
# Persistent ChromaDB client; the path comes from CHROMADB_PATH and falls back
# to an empty string when the variable is unset.
client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
# Collections are created on first access. Note the stored collection name for
# simba_docs is "simba_docs2".
simba_docs = client.get_or_create_collection(name="simba_docs2")
feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup")
# Command-line interface; arguments are parsed only under the __main__ guard.
parser = argparse.ArgumentParser(
    description="An LLM tool to query information about Simba <3"
)
# `query` is a positional argument, so argparse requires it on every run.
parser.add_argument("query", type=str, help="questions about simba's health")
parser.add_argument(
    "--reindex", action="store_true", help="re-index the simba documents"
)
parser.add_argument("--classify", action="store_true", help="test classification")
parser.add_argument("--index", help="index a file")
# Module-level service clients shared by the functions in this script.
ppngx = PaperlessNGXService()
llm_client = LLMClient()
def index_using_pdf_llm(doctypes):
    """Replace each Paperless document's content with a PDF summary, then index.

    Every document returned by ``ppngx.get_data()`` has its PDF downloaded,
    passed through ``pdf_to_image`` and ``summarize_pdf_image``, and the
    resulting summary stored under ``"content"``. All documents are then
    handed to ``chunk_data`` for the ``simba_docs`` collection.

    Args:
        doctypes: Document-type lookup forwarded unchanged to ``chunk_data``.
    """
    logging.info("reindex data...")
    documents = ppngx.get_data()
    for entry in documents:
        paperless_id: int = entry["id"]
        local_pdf = ppngx.download_pdf_from_id(id=paperless_id)
        page_images = pdf_to_image(filepath=local_pdf)
        logging.info(f"summarizing {entry}")
        entry["content"] = summarize_pdf_image(filepaths=page_images)
    chunk_data(documents, simba_docs, doctypes=doctypes)
def date_to_epoch(date_str: str) -> float:
    """Convert a dash-separated ``YEAR-MONTH-DAY`` string to a POSIX timestamp.

    Only the first three dash-separated fields are used. The resulting
    datetime is naive and set to midnight, so ``timestamp()`` interprets it
    in the local timezone of the machine running the code.

    Raises:
        ValueError: If a field is not an integer or is out of range.
        IndexError: If fewer than three fields are present.
    """
    fields = date_str.split("-")
    # Index explicitly instead of unpacking the list so trailing fields are ignored.
    year, month, day = (int(fields[position]) for position in range(3))
    return datetime.datetime(year, month, day, 0, 0, 0).timestamp()
def chunk_data(docs, collection, doctypes=None):
    """Chunk Paperless documents into ``collection`` and record them as indexed.

    Args:
        docs: Paperless document dicts. Each must provide ``content``,
            ``created_date``, ``original_file_name``, ``document_type``
            and ``id``.
        collection: Collection handed to ``Chunker``.
        doctypes: Optional mapping from a document's ``document_type`` value
            to a label. When missing or empty, the ``document_type`` metadata
            is ``""`` and no ``type`` key is written.

    Side effects:
        Inserts one ``indexed_documents`` row per document into
        ``database/visited.db``. The table must already exist.
    """
    # Step 2: Create chunks
    chunker = Chunker(collection)
    logging.info(f"chunking {len(docs)} documents")
    texts: list[str] = [doc["content"] for doc in docs]
    # Normalise once so the lookups below never run against None.
    doctype_lookup = doctypes or {}

    conn = sqlite3.connect("database/visited.db")
    try:
        # `with conn` commits on success and rolls back on error; it does not
        # close the connection, hence the surrounding try/finally.
        with conn:
            to_insert = []
            c = conn.cursor()
            for doc, text in zip(docs, texts):
                doctype_key = doc["document_type"]
                metadata = {
                    "created_date": date_to_epoch(doc["created_date"]),
                    "filename": doc["original_file_name"],
                    "document_type": doctype_lookup.get(doctype_key, ""),
                }
                if doctype_lookup:
                    # NOTE(review): this is None for an unknown document type;
                    # confirm the collection accepts None metadata values.
                    metadata["type"] = doctype_lookup.get(doctype_key)
                chunker.chunk_document(
                    document=text,
                    metadata=metadata,
                )
                to_insert.append((doc["id"],))
            c.executemany(
                "INSERT INTO indexed_documents (paperless_id) values (?)", to_insert
            )
    finally:
        conn.close()
def chunk_text(texts: list[str], collection):
    """Chunk each string in ``texts`` into ``collection`` with empty metadata."""
    chunker = Chunker(collection)
    for body in texts:
        # A fresh dict per call, so documents never share a metadata object.
        chunker.chunk_document(document=body, metadata={})
def classify_query(query: str, transcript: str) -> bool:
    """Return True when the query generator labels the query ``"Simba"``.

    The label returned by ``QueryGenerator.get_query_type`` and the time the
    classification took are both logged at INFO level.
    """
    logging.info("Starting query generation")
    started = time.time()
    label = QueryGenerator().get_query_type(input=query, transcript=transcript)
    logging.info(label)
    finished = time.time()
    logging.info(f"Query generation took {finished - started:.2f} seconds")
    return label == "Simba"
def consult_oracle(
    input: str,
    collection,
    transcript: str = "",
):
    """Answer ``input`` using documents retrieved from ``collection``.

    Steps: build a metadata filter with ``QueryGenerator.get_doctype_query``,
    embed the question, query the collection, then ask the LLM to answer from
    the raw query results. Each stage's duration is logged at INFO level.

    Args:
        input: The user's question.
        collection: Collection to query; also handed to ``Chunker`` for its
            embedding function.
        transcript: Prior conversation text. It is added to the prompt only
            when non-empty.

    Returns:
        Whatever ``llm_client.chat`` returns for the assembled prompt.
    """
    chunker = Chunker(collection)
    start_time = time.time()

    # Ask
    logging.info("Starting query generation")
    qg_start = time.time()
    qg = QueryGenerator()
    doctype_query = qg.get_doctype_query(input=input)
    # Copy so the generator's dict is never mutated; tolerate a None result.
    metadata_filter = dict(doctype_query or {})
    logging.info(metadata_filter)
    qg_end = time.time()
    logging.info(f"Query generation took {qg_end - qg_start:.2f} seconds")

    logging.info("Starting embedding generation")
    embedding_start = time.time()
    embeddings = chunker.embedding_fx(inputs=[input])
    embedding_end = time.time()
    logging.info(
        f"Embedding generation took {embedding_end - embedding_start:.2f} seconds"
    )

    logging.info("Starting collection query")
    query_start = time.time()
    # NOTE(review): both query_texts and query_embeddings are supplied; some
    # Chroma releases accept only one of them. Confirm against the pinned version.
    results = collection.query(
        query_texts=[input],
        query_embeddings=embeddings,
        # An empty dict is not a valid Chroma filter; None means "no filter".
        where=metadata_filter or None,
    )
    query_end = time.time()
    logging.info(f"Collection query took {query_end - query_start:.2f} seconds")

    # Generate
    logging.info("Starting LLM generation")
    llm_start = time.time()
    system_prompt = "You are a helpful assistant that understands veterinary terms."
    transcript_prompt = f"Here is the message transcript thus far {transcript}."
    prompt = f"""Using the following data, help answer the user's query by providing as many details as possible.
Using this data: {results}. {transcript_prompt if transcript else ""}
Respond to this prompt: {input}"""
    output = llm_client.chat(prompt=prompt, system_prompt=system_prompt)
    llm_end = time.time()
    logging.info(f"LLM generation took {llm_end - llm_start:.2f} seconds")

    total_time = time.time() - start_time
    logging.info(f"Total consult_oracle execution took {total_time:.2f} seconds")
    return output
def llm_chat(input: str, transcript: str = "") -> str:
    """Answer ``input`` in Simba's voice without retrieving any documents.

    The transcript sentence is included in the prompt only when
    ``transcript`` is non-empty.
    """
    persona = "You are a helpful assistant that understands veterinary terms."
    if len(transcript) > 0:
        history = f"Here is the message transcript thus far {transcript}."
    else:
        history = ""
    prompt = f"""Answer the user in as if you were a cat named Simba. Don't act too catlike. Be assertive.
{history}
Respond to this prompt: {input}"""
    return llm_client.chat(prompt=prompt, system_prompt=persona)
def paperless_workflow(input):
    """Index everything Paperless returns, then answer ``input`` from the index.

    Args:
        input: The user's question.

    Returns:
        The answer produced by ``consult_oracle``.
    """
    # Step 1: Get the text
    paperless = PaperlessNGXService()
    docs = paperless.get_data()
    # chunk_data needs the doctype lookup; calling it without one raised
    # TypeError before any document was indexed.
    chunk_data(docs, collection=simba_docs, doctypes=paperless.get_doctypes())
    # Return the answer instead of discarding it so callers can use it.
    return consult_oracle(input, simba_docs)
def consult_simba_oracle(input: str, transcript: str = ""):
    """Route a question to document retrieval or to plain chat.

    Questions classified as Simba-related are answered by ``consult_oracle``
    against ``simba_docs``; everything else goes to ``llm_chat``.
    """
    if not classify_query(query=input, transcript=transcript):
        logging.info("Query is NOT related to simba")
        return llm_chat(input=input, transcript=transcript)

    logging.info("Query is related to simba")
    return consult_oracle(input=input, collection=simba_docs, transcript=transcript)
def filter_indexed_files(docs):
    """Return the documents whose ``id`` is not yet recorded as indexed.

    The ``indexed_documents`` table in ``database/visited.db`` is created if
    it does not exist, so the first call on a fresh database returns every
    document.

    Args:
        docs: Iterable of dicts, each with an ``id`` key.

    Returns:
        A new list containing the not-yet-indexed documents, in input order.
    """
    conn = sqlite3.connect("database/visited.db")
    try:
        # `with conn` commits the CREATE TABLE but does not close the
        # connection, so it is closed explicitly below.
        with conn:
            c = conn.cursor()
            c.execute(
                "CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)"
            )
            c.execute("SELECT paperless_id FROM indexed_documents")
            rows = c.fetchall()
    finally:
        conn.close()
    visited = {row[0] for row in rows}
    return [doc for doc in docs if doc["id"] not in visited]
def reindex():
    """Wipe the index bookkeeping and the collection, then re-index everything.

    Clears the ``indexed_documents`` table, deletes every entry from
    ``simba_docs``, fetches all documents from Paperless-NGX and chunks them
    back into the collection.
    """
    conn = sqlite3.connect("database/visited.db")
    try:
        # `with conn` commits on success; the connection is closed explicitly.
        with conn:
            c = conn.cursor()
            # Ensure the table exists before trying to delete from it
            c.execute(
                "CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)"
            )
            c.execute("DELETE FROM indexed_documents")
    finally:
        conn.close()

    # Delete all documents from the collection. Only the ids are needed, so
    # documents and metadata are not requested. One wipe is enough; the
    # previous second pass always found the collection already empty.
    existing_ids = simba_docs.get(include=[])["ids"]
    if existing_ids:
        simba_docs.delete(ids=existing_ids)

    logging.info("Fetching documents from Paperless-NGX")
    ppngx = PaperlessNGXService()
    docs = ppngx.get_data()
    # The bookkeeping table was just emptied, so this keeps every document.
    docs = filter_indexed_files(docs)
    logging.info(f"Fetched {len(docs)} documents")

    # Chunk documents
    logging.info("Chunking documents now ...")
    doctype_lookup = ppngx.get_doctypes()
    chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
    logging.info("Done chunking documents")
if __name__ == "__main__":
    # Script entry point: run the optional maintenance flags first, then
    # answer the positional query.
    args = parser.parse_args()

    if args.reindex:
        reindex()

    if args.classify:
        # Smoke-test the classifier with three fixed sample questions.
        for sample in (
            "yohohoho testing",
            "write an email",
            "how much does simba weigh",
        ):
            consult_simba_oracle(input=sample)

    if not args.query:
        logging.info("please provide a query")
    else:
        logging.info("Consulting oracle ...")
        answer = consult_oracle(
            input=args.query,
            collection=simba_docs,
        )
        print(answer)
@@ -1,112 +0,0 @@
from tortoise import BaseDBAsyncClient
RUN_IN_TRANSACTION = True
async def upgrade(db: BaseDBAsyncClient) -> str:
    """Return the SQL that creates the memory and email tables.

    Creates ``user_memories``, ``email_accounts``, ``emails`` (plus an index
    on ``message_id``) and ``email_sync_status``. Every statement uses
    ``IF NOT EXISTS``. ``db`` is not used by this function.
    """
    return """
CREATE TABLE IF NOT EXISTS "user_memories" (
"id" UUID NOT NULL PRIMARY KEY,
"content" TEXT NOT NULL,
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"user_id" UUID NOT NULL REFERENCES "users" ("id") ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS "email_accounts" (
"id" UUID NOT NULL PRIMARY KEY,
"email_address" VARCHAR(255) NOT NULL UNIQUE,
"display_name" VARCHAR(255),
"imap_host" VARCHAR(255) NOT NULL,
"imap_port" INT NOT NULL DEFAULT 993,
"imap_username" VARCHAR(255) NOT NULL,
"imap_password" TEXT NOT NULL,
"is_active" BOOL NOT NULL DEFAULT True,
"last_error" TEXT,
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"user_id" UUID NOT NULL REFERENCES "users" ("id") ON DELETE CASCADE
);
COMMENT ON TABLE "email_accounts" IS 'Email account configuration for IMAP connections.';
CREATE TABLE IF NOT EXISTS "emails" (
"id" UUID NOT NULL PRIMARY KEY,
"message_id" VARCHAR(255) NOT NULL UNIQUE,
"subject" VARCHAR(500) NOT NULL,
"from_address" VARCHAR(255) NOT NULL,
"to_address" TEXT NOT NULL,
"date" TIMESTAMPTZ NOT NULL,
"body_text" TEXT,
"body_html" TEXT,
"chromadb_doc_id" VARCHAR(255),
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"expires_at" TIMESTAMPTZ NOT NULL,
"account_id" UUID NOT NULL REFERENCES "email_accounts" ("id") ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS "idx_emails_message_981ddd" ON "emails" ("message_id");
COMMENT ON TABLE "emails" IS 'Email message metadata and content.';
CREATE TABLE IF NOT EXISTS "email_sync_status" (
"id" UUID NOT NULL PRIMARY KEY,
"last_sync_date" TIMESTAMPTZ,
"last_message_uid" INT NOT NULL DEFAULT 0,
"message_count" INT NOT NULL DEFAULT 0,
"consecutive_failures" INT NOT NULL DEFAULT 0,
"last_failure_date" TIMESTAMPTZ,
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"account_id" UUID NOT NULL REFERENCES "email_accounts" ("id") ON DELETE CASCADE
);
COMMENT ON TABLE "email_sync_status" IS 'Tracks sync progress and state per email account.';"""
async def downgrade(db: BaseDBAsyncClient) -> str:
    """Return the SQL that drops the memory and email tables.

    ``emails`` and ``email_sync_status`` both hold a foreign key to
    ``email_accounts``, so they are dropped first. PostgreSQL refuses to drop
    a table that other tables still reference unless ``CASCADE`` is given.
    ``db`` is not used by this function.
    """
    return """
DROP TABLE IF EXISTS "emails";
DROP TABLE IF EXISTS "email_sync_status";
DROP TABLE IF EXISTS "email_accounts";
DROP TABLE IF EXISTS "user_memories";"""
MODELS_STATE = (
"eJztXGtv2zYU/SuCPrVAFjTPbcUwwE7czVudDLGz9ZFCoCXa1ixRGkk1NYr+911Skq0HZV"
"t+RUr1oU1C8lLU4SV57tGVvuquZ2GHHV955DOmDHHbI/pr7atOkIvhF2X9kaYj31/UigKO"
"ho40MBMtZQ0aMk6RyaFyhByGocjCzKS2H12MBI4jCj0TGtpkvCgKiP1fgA3ujTGfYAoVHz"
"9BsU0s/AWz+E9/aoxs7FipcduWuLYsN/jMl2X3993rN7KluNzQMD0ncMmitT/jE4/MmweB"
"bR0LG1E3xgRTxLGVuA0xyui246JwxFDAaYDnQ7UWBRYeocARYOi/jAJiCgw0eSXx3/mveg"
"l4AGoBrU24wOLrt/CuFvcsS3VxqavfW3cvzi5fyrv0GB9TWSkR0b9JQ8RRaCpxXQApf+ag"
"vJogqoYybp8BEwa6CYxxwQLHhQ/FQMYAbYaa7qIvhoPJmE/gz9OLiyUw/t26k0hCKwmlB3"
"4dev1NVHUa1glIFxCaFItbNhDPA3kNNdx2sRrMtGUGUisyPY5/qSjAcA/WLXFm0SJYgu+g"
"2+v0B63eX+JOXMb+cyRErUFH1JzK0lmm9MVlZirmnWj/dAe/a+JP7cPtTSfr+/N2gw+6GB"
"MKuGcQ79FAVmK9xqUxMKmJDXxrw4lNWzYT+6QTGw0+Ma8MU6PcCZIw2eIYicZ2wEnc/NAQ"
"R+9oqjwzBBh58N54FNtj8ieeSQi7MA5ETNVhEZGO+6ibqoK2KF2MgqLHORtJOgXcHdwT5u"
"Hp2epfta47usRwiMzpI6KWUQCmixlDY8zygLYjyzd/3mFnTs3UWCYJXC/ssZq7ShG2Eivv"
"1EtglEIvX+WeutkSROC+reja4kpL0FnBghMgrkeGjeRENqS41qSY4y+KI38ApWoo4/Z1Ic"
"XLjvLOu0HqFI+p74te693L1En+9vbmt7h5gipfvb1tNwz5ORKpPENmPkZTFRkQAWSHBG6O"
"CqRmN2H+xEtHv+937l5r4kR/IP1ur916rTHbHSJ9vSlORZknr9YIMk9eFcaYoiq9gGwXTh"
"ZjimdlQvWU0Ub4Hp56pYG8ODldA0loVQilrMtsRslDu9yRqTDd5flZ03DAzIiHW4YFWS2y"
"siiujA8U7lI2TtgnKxbxVw+7Hp3pCjKcqF3KgWUQ5IqGdsN9nwH3hYtwTErR34RJw4AbBv"
"xdMeBGI34WE1sdjbham2FdROIKs8AtVOJ9s78i3rea8TVMr/5MT8xj2cf/SZu6cL0DpAD4"
"iLFHjyo8s20TRGdqMJNWGTCHMx5GU5VTaJaA1xa8N3m6A2Tt7k3r7r2aOsftk37bfj/otD"
"LoYhfZThkvnRvsxkVXr/hdOujJq/Xkw2X6YU5AfJwgzmBLN0jgDosEWzWYCtOdiImHRfVs"
"HVDPijE9y0EqnczARNyeauF7noMRWeKgSdvs8gfjfW2mZY/qEuv/9vZtav23u9nQ+L7X7o"
"DzSpihkR1Soe7NQAnuxEUmcIQpVuiKK1Z/xraGHntyuc42kI2QErvAZdZjPdsyDRYM/8Wm"
"IlotBjRrV0Mw93LqQ/w4MXzqfbatcltqzvBwVEp3PBM5W3DRzBOadbbVi+Jt9SK3rToW8o"
"0x9QJfkRLzR//2Rg1pxiwD6D2Bu/xo2SY/0hyb8U97g/fjp/3wfHHny1XJrACZIVaig0aV"
"fJbiVaNKPtOJnSfG5VShVVmFudc0dpNaWOWINJ9SmFwRySeUm2ORfihaPc9fC4qQHyPT9A"
"JhthUgHdFXK+yqZpDsU1yVsOgKdbUTKxPF8qqcnvX0VV12p0WZp/CTI6H2aYhYWvRQ9ljP"
"oLSOzQN5IH3uwbal+YgybGlyUJps+GjzCYTTP1hoplEs2sNgjrW3NpkyjXva1YR6LrpuP5"
"CRR7XPEDPAD4YRNSeaiXw0tCHsg5UoR9bowDvih1vowJErKJ92Fccwaas6Cm17iQk3CK+3"
"jayfXFK/WEuxvFiiWF7kFcsR7CKCGIEfK86oYjSzdvWEdC++CbyyENAlye1eHeE8dIKPiI"
"jKxlqxTT2jrJpEVfFtL42Xh541M8q+9ZEyqkl+9aGXhcRowl3F47sVwMZGDbDqhELJsuGK"
"MLSSzE1hWhOQm5f5G+VsU0kUf/Ft6G2DiU1b1nNiazKRax3WkXJVMjszbdUkaMaA5DEsna"
"NZXxHwKJOrmXaSKqVrpjAuEhYTc7BCX0zJv+vqjJGNkAlH9jigUhvWhMrX7bX+EsUES7mL"
"FamOJXpIaJBzK4otITfCGEMVEhOTznxwNC1OpTtK9KExzDlcnh09EKFuxt2AO/OAHWv9wP"
"c9ypnmgovZvoPjFkzzMZXvgjYaZUU0yshpy8tBOcNGqYwVC5v5DpoZZVOAs3ZN7JB4S9s3"
"JuDYZeBMGdVFXTsUmGJ/zoPZJQXCQcomg6W9P27y889nW0Apv0Xzw+nJ+Y/nP51dnv8ETe"
"RQ5iU/LgE3nzkpMdgktT9n2DhjxhkLk/w7MQ8p1rRyPdQF3UMLWzYE2sBHPit8d2lKdcru"
"gOnU84O/wtnUDmLcwJR6iizVYpdNW9XkmG9e7G70wiaFspnY5sXu5sXu6r7YnRE2dpGEWS"
"8s0zlTM2IaoSq3AyD60Ft/3lmNINm7fJxApkhBToO3SkTOTNxqHXkA1VOmCTvNp57YdJjM"
"PBWdYCm74qRQnNeRS/cgdOSegBn+MU1w2tBYnL1g4/pHYSF0ZkJf2JqnxsJOGCnHI+gwoF"
"gTdzeFgYg0Vxaqx5oNsR92MfTvhB0LA8matQn86kDzRkWuiIosIxrptJuka+Wtd8DtqhUh"
"VYjKrfUoWE5JnIkcqNZJoVaoMj2cZPhqC2q+Y8EwxqDgaXAhgDm77xI90T02AyE8GdExoS"
"AxhSAWmX+XWMolGaGw+Q6d7aDZpJ94k26UlOeppDR5WM8sD2vfSQ31z8JqYWqbE10RPUc1"
"R8uCZrRoU5kv5xU/S1+TEUcT+KTJMjvhIcVho3j9Xflt8+KH6QmTujzoPcQHc2BplAAxal"
"5PAPfyGbfCj3MXfxin+OPcB/sozt4O3Z19FKfENzZ2f7x8+x8fHBMe"
)
+2 -13
View File
@@ -5,8 +5,7 @@ description = "Add your description here"
readme = "README.md" readme = "README.md"
requires-python = ">=3.13" requires-python = ">=3.13"
dependencies = [ dependencies = [
"langchain-postgres>=0.0.13", "chromadb>=1.1.0",
"psycopg[binary]>=3.1.0",
"python-dotenv>=1.0.0", "python-dotenv>=1.0.0",
"flask>=3.1.2", "flask>=3.1.2",
"httpx>=0.28.1", "httpx>=0.28.1",
@@ -31,6 +30,7 @@ dependencies = [
"asyncpg>=0.30.0", "asyncpg>=0.30.0",
"langchain-openai>=1.1.6", "langchain-openai>=1.1.6",
"langchain>=1.2.0", "langchain>=1.2.0",
"langchain-chroma>=1.0.0",
"langchain-community>=0.4.1", "langchain-community>=0.4.1",
"jq>=1.10.0", "jq>=1.10.0",
"tavily-python>=0.7.17", "tavily-python>=0.7.17",
@@ -42,17 +42,6 @@ dependencies = [
"aioboto3>=13.0.0", "aioboto3>=13.0.0",
] ]
[project.optional-dependencies]
test = [
"pytest>=8.0.0",
"pytest-asyncio>=0.25.0",
"pytest-cov>=6.0.0",
]
[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_mode = "auto"
[tool.aerich] [tool.aerich]
tortoise_orm = "config.db.TORTOISE_CONFIG" tortoise_orm = "config.db.TORTOISE_CONFIG"
location = "./migrations" location = "./migrations"
+28 -27
View File
@@ -1,4 +1,4 @@
import { useCallback, useEffect, useState, useRef } from "react"; import { useEffect, useState, useRef } from "react";
import { LogOut, Shield, PanelLeftClose, PanelLeftOpen, Menu, X } from "lucide-react"; import { LogOut, Shield, PanelLeftClose, PanelLeftOpen, Menu, X } from "lucide-react";
import { conversationService } from "../api/conversationService"; import { conversationService } from "../api/conversationService";
import { userService } from "../api/userService"; import { userService } from "../api/userService";
@@ -63,13 +63,9 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
const abortControllerRef = useRef<AbortController | null>(null); const abortControllerRef = useRef<AbortController | null>(null);
const simbaAnswers = ["meow.", "hiss...", "purrrrrr", "yowOWROWWowowr"]; const simbaAnswers = ["meow.", "hiss...", "purrrrrr", "yowOWROWWowowr"];
const scrollToBottom = useCallback(() => { const scrollToBottom = () => {
requestAnimationFrame(() => { messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
messagesEndRef.current?.scrollIntoView({ };
behavior: isLoading ? "instant" : "smooth",
});
});
}, [isLoading]);
useEffect(() => { useEffect(() => {
isMountedRef.current = true; isMountedRef.current = true;
@@ -120,7 +116,21 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
scrollToBottom(); scrollToBottom();
}, [messages]); }, [messages]);
const handleQuestionSubmit = useCallback(async () => { useEffect(() => {
const load = async () => {
if (!selectedConversation) return;
try {
const conv = await conversationService.getConversation(selectedConversation.id);
setSelectedConversation({ id: conv.id, title: conv.name });
setMessages(conv.messages.map((m) => ({ text: m.text, speaker: m.speaker, image_key: m.image_key })));
} catch (err) {
console.error("Failed to load messages:", err);
}
};
load();
}, [selectedConversation?.id]);
const handleQuestionSubmit = async () => {
if ((!query.trim() && !pendingImage) || isLoading) return; if ((!query.trim() && !pendingImage) || isLoading) return;
let activeConversation = selectedConversation; let activeConversation = selectedConversation;
@@ -201,28 +211,22 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
} }
} }
} finally { } finally {
if (isMountedRef.current) { if (isMountedRef.current) setIsLoading(false);
setIsLoading(false);
loadConversations();
}
abortControllerRef.current = null; abortControllerRef.current = null;
} }
}, [query, pendingImage, isLoading, selectedConversation, simbaMode, messages, setAuthenticated]); };
const handleQueryChange = useCallback((event: React.ChangeEvent<HTMLTextAreaElement>) => { const handleQueryChange = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
setQuery(event.target.value); setQuery(event.target.value);
}, []); };
const handleKeyDown = useCallback((event: React.ChangeEvent<HTMLTextAreaElement>) => { const handleKeyDown = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
const kev = event as unknown as React.KeyboardEvent<HTMLTextAreaElement>; const kev = event as unknown as React.KeyboardEvent<HTMLTextAreaElement>;
if (kev.key === "Enter" && !kev.shiftKey) { if (kev.key === "Enter" && !kev.shiftKey) {
kev.preventDefault(); kev.preventDefault();
handleQuestionSubmit(); handleQuestionSubmit();
} }
}, [handleQuestionSubmit]); };
const handleImageSelect = useCallback((file: File) => setPendingImage(file), []);
const handleClearImage = useCallback(() => setPendingImage(null), []);
const handleLogout = () => { const handleLogout = () => {
localStorage.removeItem("access_token"); localStorage.removeItem("access_token");
@@ -376,8 +380,8 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
setSimbaMode={setSimbaMode} setSimbaMode={setSimbaMode}
isLoading={isLoading} isLoading={isLoading}
pendingImage={pendingImage} pendingImage={pendingImage}
onImageSelect={handleImageSelect} onImageSelect={(file) => setPendingImage(file)}
onClearImage={handleClearImage} onClearImage={() => setPendingImage(null)}
/> />
</div> </div>
</div> </div>
@@ -412,7 +416,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
</div> </div>
</div> </div>
<footer className="border-t border-sand-light/40 bg-cream"> <footer className="border-t border-sand-light/40 bg-cream/80 backdrop-blur-sm">
<div className="max-w-2xl mx-auto px-4 py-3"> <div className="max-w-2xl mx-auto px-4 py-3">
<MessageInput <MessageInput
query={query} query={query}
@@ -421,9 +425,6 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
handleQuestionSubmit={handleQuestionSubmit} handleQuestionSubmit={handleQuestionSubmit}
setSimbaMode={setSimbaMode} setSimbaMode={setSimbaMode}
isLoading={isLoading} isLoading={isLoading}
pendingImage={pendingImage}
onImageSelect={(file) => setPendingImage(file)}
onClearImage={() => setPendingImage(null)}
/> />
</div> </div>
</footer> </footer>
+4 -16
View File
@@ -1,4 +1,4 @@
import React, { useEffect, useMemo, useRef, useState } from "react"; import { useRef, useState } from "react";
import { ArrowUp, ImagePlus, X } from "lucide-react"; import { ArrowUp, ImagePlus, X } from "lucide-react";
import { cn } from "../lib/utils"; import { cn } from "../lib/utils";
import { Textarea } from "./ui/textarea"; import { Textarea } from "./ui/textarea";
@@ -15,7 +15,7 @@ type MessageInputProps = {
onClearImage: () => void; onClearImage: () => void;
}; };
export const MessageInput = React.memo(({ export const MessageInput = ({
query, query,
handleKeyDown, handleKeyDown,
handleQueryChange, handleQueryChange,
@@ -29,18 +29,6 @@ export const MessageInput = React.memo(({
const [simbaMode, setLocalSimbaMode] = useState(false); const [simbaMode, setLocalSimbaMode] = useState(false);
const fileInputRef = useRef<HTMLInputElement>(null); const fileInputRef = useRef<HTMLInputElement>(null);
// Create blob URL once per file, revoke on cleanup
const previewUrl = useMemo(
() => (pendingImage ? URL.createObjectURL(pendingImage) : null),
[pendingImage],
);
useEffect(() => {
return () => {
if (previewUrl) URL.revokeObjectURL(previewUrl);
};
}, [previewUrl]);
const toggleSimbaMode = () => { const toggleSimbaMode = () => {
const next = !simbaMode; const next = !simbaMode;
setLocalSimbaMode(next); setLocalSimbaMode(next);
@@ -71,7 +59,7 @@ export const MessageInput = React.memo(({
<div className="px-3 pt-3"> <div className="px-3 pt-3">
<div className="relative inline-block"> <div className="relative inline-block">
<img <img
src={previewUrl!} src={URL.createObjectURL(pendingImage)}
alt="Pending upload" alt="Pending upload"
className="h-20 rounded-lg object-cover border border-sand" className="h-20 rounded-lg object-cover border border-sand"
/> />
@@ -157,4 +145,4 @@ export const MessageInput = React.memo(({
</div> </div>
</div> </div>
); );
}); };
+16 -8
View File
@@ -6,19 +6,19 @@ import asyncio
import sys import sys
from blueprints.rag.logic import ( from blueprints.rag.logic import (
delete_all_documents,
get_vector_store_stats, get_vector_store_stats,
index_documents, index_documents,
list_all_documents, list_all_documents,
vector_store,
) )
def stats(): def stats():
"""Show vector store statistics.""" """Show vector store statistics."""
s = get_vector_store_stats() stats = get_vector_store_stats()
print("=== Vector Store Statistics ===") print("=== Vector Store Statistics ===")
print(f"Collection: {s['collection_name']}") print(f"Collection: {stats['collection_name']}")
print(f"Total Documents: {s['total_documents']}") print(f"Total Documents: {stats['total_documents']}")
async def index(): async def index():
@@ -26,15 +26,23 @@ async def index():
print("Starting indexing process...") print("Starting indexing process...")
print("Fetching documents from Paperless-NGX...") print("Fetching documents from Paperless-NGX...")
await index_documents() await index_documents()
print("Indexing complete!") print("Indexing complete!")
stats() stats()
async def reindex(): async def reindex():
"""Clear and reindex all documents.""" """Clear and reindex all documents."""
print("Clearing existing documents...") print("Clearing existing documents...")
delete_all_documents() collection = vector_store._collection
print("Cleared") all_docs = collection.get()
if all_docs["ids"]:
print(f"Deleting {len(all_docs['ids'])} existing documents...")
collection.delete(ids=all_docs["ids"])
print("✓ Cleared")
else:
print("Collection is already empty")
await index() await index()
@@ -105,7 +113,7 @@ Examples:
print("\n\nOperation cancelled by user") print("\n\nOperation cancelled by user")
sys.exit(1) sys.exit(1)
except Exception as e: except Exception as e:
print(f"\nError: {e}", file=sys.stderr) print(f"\nError: {e}", file=sys.stderr)
sys.exit(1) sys.exit(1)
+24
View File
@@ -0,0 +1,24 @@
from bs4 import BeautifulSoup
import chromadb
import httpx
import os

# Storage location: honour CHROMADB_PATH when it is set, otherwise fall back
# to the original developer-machine path.
client = chromadb.PersistentClient(
    path=os.getenv("CHROMADB_PATH", "/Users/ryanchen/Programs/raggr/chromadb")
)

# Scrape
BASE_URL = "https://www.vet.cornell.edu"
LIST_URL = "/departments-centers-and-institutes/cornell-feline-health-center/health-information/feline-health-topics"
QUERY_URL = BASE_URL + LIST_URL

r = httpx.get(QUERY_URL)
# Fail loudly on an HTTP error instead of parsing an error page.
r.raise_for_status()
# Name the parser explicitly so results do not depend on which optional
# parsers happen to be installed.
soup = BeautifulSoup(r.text, "html.parser")
container = soup.find("div", class_="field-body")
if container is None:
    raise RuntimeError(f"No 'field-body' container found at {QUERY_URL}")

a_s = container.find_all("a", href=True)
new_texts = []
for link in a_s:
    endpoint = link["href"]
    query_url = BASE_URL + endpoint
    r2 = httpx.get(query_url)
    article_soup = BeautifulSoup(r2.text, "html.parser")
+3
View File
@@ -1,6 +1,9 @@
#!/bin/bash #!/bin/bash
set -e set -e
echo "Initializing directories..."
mkdir -p /app/data/chromadb
echo "Rebuilding frontend..." echo "Rebuilding frontend..."
cd /app/raggr-frontend cd /app/raggr-frontend
yarn build yarn build
+1 -18
View File
@@ -8,25 +8,8 @@ mkdir -p /app/data/obsidian
# Start continuous Obsidian sync if enabled # Start continuous Obsidian sync if enabled
if [ "${OBSIDIAN_CONTINUOUS_SYNC}" = "true" ]; then if [ "${OBSIDIAN_CONTINUOUS_SYNC}" = "true" ]; then
VAULT_PATH="${OBSIDIAN_VAULT_PATH:-/app/data/obsidian}"
# Setup sync if not already configured
if ! ob sync-status --path "$VAULT_PATH" > /dev/null 2>&1; then
echo "Configuring Obsidian sync..."
ob sync-setup \
--vault "${OBSIDIAN_VAULT_ID}" \
--path "$VAULT_PATH" \
--password "${OBSIDIAN_E2E_PASSWORD}" \
--device-name "${OBSIDIAN_DEVICE_NAME:-simbarag}" || {
echo "Failed to configure Obsidian sync. Skipping."
OBSIDIAN_CONTINUOUS_SYNC=false
}
fi
if [ "${OBSIDIAN_CONTINUOUS_SYNC}" = "true" ]; then
echo "Starting Obsidian continuous sync in background..." echo "Starting Obsidian continuous sync in background..."
ob sync --path "$VAULT_PATH" --continuous & ob sync --continuous &
fi
fi fi
echo "Starting application..." echo "Starting application..."
View File
-11
View File
@@ -1,11 +0,0 @@
import os
import sys
# Ensure project root is on the path so imports work
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
# Set FERNET_KEY for tests that import email models (EncryptedTextField needs it at import time)
# A key already present in the environment is left untouched.
if "FERNET_KEY" not in os.environ:
    # Imported lazily: cryptography is only needed when a key must be generated.
    from cryptography.fernet import Fernet

    os.environ["FERNET_KEY"] = Fernet.generate_key().decode()
View File
-91
View File
@@ -1,91 +0,0 @@
"""Tests for encryption/decryption in blueprints/email/crypto_service.py."""
import os
from unittest.mock import patch
import pytest
from cryptography.fernet import Fernet
# Generate a valid key for testing
TEST_FERNET_KEY = Fernet.generate_key().decode()
class TestEncryptedTextField:
    """Round-trip and failure-mode tests for ``EncryptedTextField``."""

    @pytest.fixture
    def field(self):
        """Build a field while ``FERNET_KEY`` is patched to the test key."""
        with patch.dict(os.environ, {"FERNET_KEY": TEST_FERNET_KEY}):
            from blueprints.email.crypto_service import EncryptedTextField

            return EncryptedTextField()

    def test_encrypt_decrypt_roundtrip(self, field):
        plaintext = "my secret password"
        ciphertext = field.to_db_value(plaintext, None)
        recovered = field.to_python_value(ciphertext)
        assert recovered == plaintext
        assert ciphertext != plaintext

    def test_none_passthrough(self, field):
        assert field.to_db_value(None, None) is None
        assert field.to_python_value(None) is None

    def test_unicode_roundtrip(self, field):
        plaintext = "Hello 世界 🐱"
        ciphertext = field.to_db_value(plaintext, None)
        assert field.to_python_value(ciphertext) == plaintext

    def test_empty_string_roundtrip(self, field):
        ciphertext = field.to_db_value("", None)
        assert field.to_python_value(ciphertext) == ""

    def test_long_text_roundtrip(self, field):
        plaintext = "x" * 10000
        ciphertext = field.to_db_value(plaintext, None)
        assert field.to_python_value(ciphertext) == plaintext

    def test_different_encryptions_differ(self, field):
        """Fernet includes a timestamp, so two encryptions of the same value differ."""
        first = field.to_db_value("same", None)
        second = field.to_db_value("same", None)
        assert first != second  # Different ciphertexts
        assert field.to_python_value(first) == field.to_python_value(second) == "same"

    def test_wrong_key_fails(self, field):
        ciphertext = field.to_db_value("secret", None)
        # Create a field with a different key
        foreign_key = Fernet.generate_key().decode()
        with patch.dict(os.environ, {"FERNET_KEY": foreign_key}):
            from blueprints.email.crypto_service import EncryptedTextField

            foreign_field = EncryptedTextField()
        # NOTE(review): assumes the key is captured when the field is built,
        # as the `field` fixture also relies on. Confirm against
        # crypto_service.
        with pytest.raises(Exception):
            foreign_field.to_python_value(ciphertext)
class TestValidateFernetKey:
    """Tests for ``validate_fernet_key`` with valid, missing and malformed keys."""

    def test_valid_key(self):
        with patch.dict(os.environ, {"FERNET_KEY": TEST_FERNET_KEY}):
            from blueprints.email.crypto_service import validate_fernet_key

            validate_fernet_key()  # Should not raise

    def test_missing_key(self):
        # clear=True empties the environment for the duration of the block;
        # the pop is kept from the original as a belt-and-braces step.
        with patch.dict(os.environ, {}, clear=True):
            os.environ.pop("FERNET_KEY", None)
            from blueprints.email.crypto_service import validate_fernet_key

            with pytest.raises(ValueError, match="not set"):
                validate_fernet_key()

    def test_invalid_key(self):
        with patch.dict(os.environ, {"FERNET_KEY": "not-a-valid-key"}):
            from blueprints.email.crypto_service import validate_fernet_key

            with pytest.raises(ValueError, match="validation failed"):
                validate_fernet_key()
-38
View File
@@ -1,38 +0,0 @@
"""Tests for email helper functions in blueprints/email/helpers.py."""
from blueprints.email.helpers import generate_email_token, get_user_email_address
class TestGenerateEmailToken:
    """Tests for the shape and determinism of ``generate_email_token``."""

    def test_returns_16_char_hex(self):
        token = generate_email_token("user-123", "my-secret")
        assert len(token) == 16
        hex_digits = "0123456789abcdef"
        assert all(char in hex_digits for char in token)

    def test_deterministic(self):
        first = generate_email_token("user-123", "my-secret")
        second = generate_email_token("user-123", "my-secret")
        assert first == second

    def test_different_users_different_tokens(self):
        token_a = generate_email_token("user-1", "secret")
        token_b = generate_email_token("user-2", "secret")
        assert token_a != token_b

    def test_different_secrets_different_tokens(self):
        token_a = generate_email_token("user-1", "secret-a")
        token_b = generate_email_token("user-1", "secret-b")
        assert token_a != token_b
class TestGetUserEmailAddress:
    """Tests for the ``ask+<token>@<domain>`` address format."""

    def test_formats_correctly(self):
        address = get_user_email_address("abc123", "example.com")
        assert address == "ask+abc123@example.com"

    def test_preserves_token(self):
        token = "deadbeef12345678"
        address = get_user_email_address(token, "mail.test.org")
        assert token in address
        assert address.startswith("ask+")
        assert "@mail.test.org" in address
-259
View File
@@ -1,259 +0,0 @@
"""Tests for ObsidianService markdown parsing and file operations."""
import os
from datetime import datetime
from pathlib import Path
from unittest.mock import patch
import pytest
# Set vault path before importing so __init__ validation passes
_test_vault_dir = None
@pytest.fixture(autouse=True)
def vault_dir(tmp_path):
    """Yield a temporary vault directory that holds one sample .md file.

    OBSIDIAN_VAULT_PATH is patched to the temporary directory for the
    duration of the test, and the directory is also recorded in the
    module-level ``_test_vault_dir``.
    """
    global _test_vault_dir
    _test_vault_dir = tmp_path

    # A sample markdown file is needed so vault validation passes.
    sample_note = tmp_path / "sample.md"
    sample_note.write_text("# Sample\nHello world")

    vault_env = {"OBSIDIAN_VAULT_PATH": str(tmp_path)}
    with patch.dict(os.environ, vault_env):
        yield tmp_path
@pytest.fixture
def service(vault_dir):
    """Build an ObsidianService once the vault fixture has run."""
    # Imported lazily so the module loads with OBSIDIAN_VAULT_PATH patched.
    from utils.obsidian_service import ObsidianService

    obsidian = ObsidianService()
    return obsidian
class TestParseMarkdown:
    """Checks on the dictionary returned by ObsidianService.parse_markdown."""

    def test_extracts_frontmatter(self, service):
        """YAML frontmatter is exposed under the "metadata" key."""
        note = "---\ntitle: Test Note\ntags: [cat, vet]\n---\n\nBody content"
        parsed = service.parse_markdown(note)
        metadata = parsed["metadata"]
        assert metadata["title"] == "Test Note"
        assert metadata["tags"] == ["cat", "vet"]

    def test_no_frontmatter(self, service):
        """Without frontmatter the metadata is empty and the body is kept."""
        note = "Just body content with no frontmatter"
        parsed = service.parse_markdown(note)
        assert parsed["metadata"] == {}
        assert "Just body content" in parsed["content"]

    def test_invalid_yaml_frontmatter(self, service):
        """Unparseable frontmatter results in empty metadata."""
        note = "---\n: invalid: yaml: [[\n---\n\nBody"
        parsed = service.parse_markdown(note)
        assert parsed["metadata"] == {}

    def test_extracts_tags(self, service):
        """Inline #tags are collected without the leading hash."""
        note = "Some text with #tag1 and #tag2 here"
        tags = service.parse_markdown(note)["tags"]
        assert "tag1" in tags
        assert "tag2" in tags

    def test_extracts_wikilinks(self, service):
        """[[Wikilink]] targets are collected under "wikilinks"."""
        note = "Link to [[Other Note]] and [[Another Page]]"
        wikilinks = service.parse_markdown(note)["wikilinks"]
        assert "Other Note" in wikilinks
        assert "Another Page" in wikilinks

    def test_extracts_embeds(self, service):
        """[[!embed]] targets are collected under "embeds"."""
        note = "An embed [[!my_embed]] here"
        parsed = service.parse_markdown(note)
        assert "my_embed" in parsed["embeds"]

    def test_cleans_wikilinks_from_content(self, service):
        """Wikilink brackets do not survive into the cleaned content."""
        note = "Text with [[link]] included"
        cleaned = service.parse_markdown(note)["content"]
        assert "[[" not in cleaned
        assert "]]" not in cleaned

    def test_filepath_passed_through(self, service):
        """A supplied filepath is returned as its string form."""
        note_path = Path("/vault/note.md")
        parsed = service.parse_markdown("text", filepath=note_path)
        assert parsed["filepath"] == "/vault/note.md"

    def test_filepath_none_by_default(self, service):
        """Omitting filepath leaves the "filepath" entry as None."""
        parsed = service.parse_markdown("text")
        assert parsed["filepath"] is None

    def test_empty_content(self, service):
        """Empty input gives empty metadata, tags, wikilinks and embeds."""
        parsed = service.parse_markdown("")
        assert parsed["metadata"] == {}
        assert parsed["tags"] == []
        assert parsed["wikilinks"] == []
        assert parsed["embeds"] == []
class TestGetDailyNotePath:
    """Checks on the relative path built by get_daily_note_path."""

    def test_formats_path_correctly(self, service):
        """An explicit date maps to journal/<year>/<YYYY-MM-DD>.md."""
        note_date = datetime(2026, 3, 15)
        assert service.get_daily_note_path(note_date) == "journal/2026/2026-03-15.md"

    def test_defaults_to_today(self, service):
        """With no argument, the path is built from the current date."""
        path = service.get_daily_note_path()
        now = datetime.now()
        expected_prefix = f"journal/{now.strftime('%Y')}/"
        assert now.strftime("%Y-%m-%d") in path
        assert path.startswith(expected_prefix)
class TestWalkVault:
    """Checks on which files ObsidianService.walk_vault returns."""

    def test_finds_markdown_files(self, service, vault_dir):
        """Markdown files at the root and in subdirectories are returned."""
        (vault_dir / "note1.md").write_text("# Note 1")
        nested_dir = vault_dir / "subdir"
        nested_dir.mkdir()
        (nested_dir / "note2.md").write_text("# Note 2")

        found_names = [entry.name for entry in service.walk_vault()]
        assert "sample.md" in found_names
        assert "note1.md" in found_names
        assert "note2.md" in found_names

    def test_excludes_obsidian_dir(self, service, vault_dir):
        """Files under the .obsidian directory are left out."""
        config_dir = vault_dir / ".obsidian"
        config_dir.mkdir()
        (config_dir / "config.md").write_text("config")

        found_names = [entry.name for entry in service.walk_vault()]
        assert "config.md" not in found_names

    def test_ignores_non_md_files(self, service, vault_dir):
        """Files without the .md extension are not returned."""
        (vault_dir / "image.png").write_bytes(b"\x89PNG")

        found_names = [entry.name for entry in service.walk_vault()]
        assert "image.png" not in found_names
class TestCreateNote:
    """Checks on the files written by ObsidianService.create_note."""

    def test_creates_file(self, service, vault_dir):
        """The returned relative path points at a file inside the vault."""
        relative_path = service.create_note("My Test Note", "Body content")
        assert (vault_dir / relative_path).exists()

    def test_sanitizes_title(self, service, vault_dir):
        """Special characters in the title do not reach the path."""
        relative_path = service.create_note("Hello World! @#$", "Body")
        assert "hello-world" in relative_path
        assert "@" not in relative_path
        assert "#" not in relative_path

    def test_includes_frontmatter(self, service, vault_dir):
        """The written note has frontmatter with the creator and the tags."""
        relative_path = service.create_note("Test", "Body", tags=["cat", "vet"])
        written = (vault_dir / relative_path).read_text()
        assert "---" in written
        assert "created_by: simbarag" in written
        assert "cat" in written
        assert "vet" in written

    def test_custom_folder(self, service, vault_dir):
        """An explicit folder becomes the prefix of the returned path."""
        relative_path = service.create_note("Test", "Body", folder="custom/subfolder")
        assert relative_path.startswith("custom/subfolder/")
        assert (vault_dir / relative_path).exists()
class TestDailyNoteTasks:
    """Reading, adding and completing tasks in daily notes."""

    @staticmethod
    def _write_daily_note(service, vault_dir, date, body):
        """Create the daily note for *date* containing *body*; return its path."""
        note_path = vault_dir / service.get_daily_note_path(date)
        note_path.parent.mkdir(parents=True, exist_ok=True)
        note_path.write_text(body)
        return note_path

    def test_get_tasks_from_daily_note(self, service, vault_dir):
        """Open and completed tasks are returned in file order."""
        # Create a daily note with tasks
        date = datetime(2026, 1, 15)
        self._write_daily_note(
            service,
            vault_dir,
            date,
            "---\nmodified: 2026-01-15\n---\n"
            "### tasks\n\n"
            "- [ ] Feed the cat\n"
            "- [x] Clean litter box\n"
            "- [ ] Buy cat food\n\n"
            "### log\n",
        )

        outcome = service.get_daily_tasks(date)
        tasks = outcome["tasks"]
        assert outcome["found"] is True
        assert len(tasks) == 3
        assert tasks[0] == {"text": "Feed the cat", "done": False}
        assert tasks[1] == {"text": "Clean litter box", "done": True}
        assert tasks[2] == {"text": "Buy cat food", "done": False}

    def test_get_tasks_no_note(self, service):
        """A date with no note reports found=False and no tasks."""
        outcome = service.get_daily_tasks(datetime(2099, 12, 31))
        assert outcome["found"] is False
        assert outcome["tasks"] == []

    def test_add_task_creates_note(self, service, vault_dir):
        """Adding a task for a date without a note creates the note."""
        outcome = service.add_task_to_daily_note("Walk the cat", datetime(2026, 6, 1))
        assert outcome["success"] is True
        assert outcome["created_note"] is True

        # Verify file was created with the task
        written = (vault_dir / outcome["path"]).read_text()
        assert "- [ ] Walk the cat" in written

    def test_add_task_to_existing_note(self, service, vault_dir):
        """Adding to an existing note keeps the earlier tasks."""
        date = datetime(2026, 6, 2)
        note_path = self._write_daily_note(
            service,
            vault_dir,
            date,
            "---\nmodified: 2026-06-02\n---\n"
            "### tasks\n\n"
            "- [ ] Existing task\n\n"
            "### log\n",
        )

        outcome = service.add_task_to_daily_note("New task", date)
        assert outcome["success"] is True
        assert outcome["created_note"] is False

        written = note_path.read_text()
        assert "- [ ] Existing task" in written
        assert "- [ ] New task" in written

    def test_complete_task_exact_match(self, service, vault_dir):
        """Completing a task by its exact text ticks only that task."""
        date = datetime(2026, 6, 3)
        note_path = self._write_daily_note(
            service,
            vault_dir,
            date,
            "### tasks\n\n- [ ] Feed the cat\n- [ ] Buy food\n",
        )

        outcome = service.complete_task_in_daily_note("Feed the cat", date)
        assert outcome["success"] is True

        written = note_path.read_text()
        assert "- [x] Feed the cat" in written
        assert "- [ ] Buy food" in written  # Other task unchanged

    def test_complete_task_partial_match(self, service, vault_dir):
        """A prefix of the task text is enough to complete it."""
        date = datetime(2026, 6, 4)
        self._write_daily_note(
            service, vault_dir, date, "### tasks\n\n- [ ] Feed the cat at 5pm\n"
        )

        outcome = service.complete_task_in_daily_note("Feed the cat", date)
        assert outcome["success"] is True

    def test_complete_task_not_found(self, service, vault_dir):
        """Completing a task that is not in the note fails with an error."""
        date = datetime(2026, 6, 5)
        self._write_daily_note(
            service, vault_dir, date, "### tasks\n\n- [ ] Feed the cat\n"
        )

        outcome = service.complete_task_in_daily_note("Walk the dog", date)
        assert outcome["success"] is False
        assert "not found" in outcome["error"]

    def test_complete_task_no_note(self, service):
        """Completing a task when the daily note does not exist fails."""
        outcome = service.complete_task_in_daily_note("Something", datetime(2099, 12, 31))
        assert outcome["success"] is False
-92
View File
@@ -1,92 +0,0 @@
"""Tests for rate limiting logic in email and WhatsApp blueprints."""
import time
class TestEmailRateLimit:
    """Behaviour of the email blueprint's per-sender rate limiter."""

    def setup_method(self):
        """Reset rate limit store before each test."""
        from blueprints.email import _rate_limit_store

        _rate_limit_store.clear()

    def test_allows_under_limit(self):
        """Each of the first RATE_LIMIT_MAX calls is allowed."""
        from blueprints.email import _check_rate_limit, RATE_LIMIT_MAX

        sender = "sender@test.com"
        for _attempt in range(RATE_LIMIT_MAX):
            assert _check_rate_limit(sender) is True

    def test_blocks_at_limit(self):
        """The call after RATE_LIMIT_MAX allowed calls is rejected."""
        from blueprints.email import _check_rate_limit, RATE_LIMIT_MAX

        sender = "sender@test.com"
        for _attempt in range(RATE_LIMIT_MAX):
            _check_rate_limit(sender)
        assert _check_rate_limit(sender) is False

    def test_different_senders_independent(self):
        """Exhausting one sender's quota leaves another sender unaffected."""
        from blueprints.email import _check_rate_limit, RATE_LIMIT_MAX

        for _attempt in range(RATE_LIMIT_MAX):
            _check_rate_limit("user1@test.com")

        # user1 is at limit, but user2 should be fine
        assert _check_rate_limit("user1@test.com") is False
        assert _check_rate_limit("user2@test.com") is True

    def test_window_expiry(self):
        """Timestamps far in the past no longer count against the sender."""
        from blueprints.email import (
            _check_rate_limit,
            _rate_limit_store,
            RATE_LIMIT_MAX,
        )

        # Fill up the rate limit with timestamps in the past
        long_ago = time.monotonic() - 999  # Well beyond any window
        _rate_limit_store["old@test.com"] = [long_ago] * RATE_LIMIT_MAX

        # Should be allowed because all timestamps are expired
        assert _check_rate_limit("old@test.com") is True
class TestWhatsAppRateLimit:
    """Behaviour of the WhatsApp blueprint's per-number rate limiter."""

    def setup_method(self):
        """Reset rate limit store before each test."""
        from blueprints.whatsapp import _rate_limit_store

        _rate_limit_store.clear()

    def test_allows_under_limit(self):
        """Each of the first RATE_LIMIT_MAX calls is allowed."""
        from blueprints.whatsapp import _check_rate_limit, RATE_LIMIT_MAX

        number = "whatsapp:+1234567890"
        for _attempt in range(RATE_LIMIT_MAX):
            assert _check_rate_limit(number) is True

    def test_blocks_at_limit(self):
        """The call after RATE_LIMIT_MAX allowed calls is rejected."""
        from blueprints.whatsapp import _check_rate_limit, RATE_LIMIT_MAX

        number = "whatsapp:+1234567890"
        for _attempt in range(RATE_LIMIT_MAX):
            _check_rate_limit(number)
        assert _check_rate_limit(number) is False

    def test_different_numbers_independent(self):
        """Exhausting one number's quota leaves another number unaffected."""
        from blueprints.whatsapp import _check_rate_limit, RATE_LIMIT_MAX

        for _attempt in range(RATE_LIMIT_MAX):
            _check_rate_limit("whatsapp:+1111111111")

        assert _check_rate_limit("whatsapp:+1111111111") is False
        assert _check_rate_limit("whatsapp:+2222222222") is True

    def test_window_expiry(self):
        """Timestamps far in the past no longer count against the number."""
        from blueprints.whatsapp import (
            _check_rate_limit,
            _rate_limit_store,
            RATE_LIMIT_MAX,
        )

        long_ago = time.monotonic() - 999
        _rate_limit_store["whatsapp:+9999999999"] = [long_ago] * RATE_LIMIT_MAX

        assert _check_rate_limit("whatsapp:+9999999999") is True
-86
View File
@@ -1,86 +0,0 @@
"""Tests for User model methods in blueprints/users/models.py."""
from unittest.mock import MagicMock
import bcrypt
class TestUserModelMethods:
    """Test User model methods without requiring a database connection.

    A mock object specced on User is used instead of a real instance
    to avoid Tortoise ORM initialization.
    """

    def _make_user(self, ldap_groups=None, password=None):
        """Create a User-specced mock with group and password helpers attached."""
        from blueprints.users.models import User

        user = MagicMock(spec=User)
        user.ldap_groups = ldap_groups
        user.password = password

        # NOTE(review): the helpers below are local re-implementations
        # attached to the mock, so these tests exercise them rather than
        # the method bodies defined on User — confirm this is intended.
        def has_group(group):
            return group in (user.ldap_groups or [])

        def is_admin():
            return user.has_group("lldap_admin")

        def set_password(plain):
            user.password = bcrypt.hashpw(plain.encode("utf-8"), bcrypt.gensalt())

        def verify_password(plain):
            if not user.password:
                return False
            return bcrypt.checkpw(plain.encode("utf-8"), user.password)

        user.has_group = has_group
        user.is_admin = is_admin
        user.set_password = set_password
        user.verify_password = verify_password
        return user

    def test_has_group_true(self):
        """Every group the user belongs to is reported as present."""
        member = self._make_user(ldap_groups=["lldap_admin", "users"])
        assert member.has_group("lldap_admin") is True
        assert member.has_group("users") is True

    def test_has_group_false(self):
        """A group the user does not belong to is reported as absent."""
        member = self._make_user(ldap_groups=["users"])
        assert member.has_group("lldap_admin") is False

    def test_has_group_empty_list(self):
        """An empty group list matches nothing."""
        member = self._make_user(ldap_groups=[])
        assert member.has_group("anything") is False

    def test_has_group_none(self):
        """A missing (None) group list matches nothing."""
        member = self._make_user(ldap_groups=None)
        assert member.has_group("anything") is False

    def test_is_admin_true(self):
        """Membership of lldap_admin makes the user an admin."""
        member = self._make_user(ldap_groups=["lldap_admin"])
        assert member.is_admin() is True

    def test_is_admin_false(self):
        """Without lldap_admin the user is not an admin."""
        member = self._make_user(ldap_groups=["users"])
        assert member.is_admin() is False

    def test_is_admin_empty(self):
        """A user with no groups is not an admin."""
        member = self._make_user(ldap_groups=[])
        assert member.is_admin() is False

    def test_set_and_verify_password(self):
        """A password that was set verifies; a different one does not."""
        member = self._make_user()
        member.set_password("hunter2")
        assert member.password is not None
        assert member.verify_password("hunter2") is True
        assert member.verify_password("wrong") is False

    def test_verify_password_no_password_set(self):
        """Verification fails when no password has been stored."""
        member = self._make_user(password=None)
        assert member.verify_password("anything") is False

    def test_password_is_hashed(self):
        """The stored password differs from the plaintext."""
        member = self._make_user()
        member.set_password("mypassword")
        # The stored password should not be the plaintext
        assert member.password != b"mypassword"
        assert member.password != "mypassword"
-254
View File
@@ -1,254 +0,0 @@
"""Tests for YNAB service data formatting and filtering logic."""
import os
from unittest.mock import MagicMock, patch
import pytest
def _mock_category(
name, budgeted, activity, balance, deleted=False, hidden=False, goal_type=None
):
cat = MagicMock()
cat.name = name
cat.budgeted = budgeted
cat.activity = activity
cat.balance = balance
cat.deleted = deleted
cat.hidden = hidden
cat.goal_type = goal_type
return cat
def _mock_transaction(
var_date, payee_name, category_name, amount, memo="", deleted=False, approved=True
):
txn = MagicMock()
txn.var_date = var_date
txn.payee_name = payee_name
txn.category_name = category_name
txn.amount = amount
txn.memo = memo
txn.deleted = deleted
txn.approved = approved
return txn
@pytest.fixture
def ynab_service():
    """Create a YNABService with mocked API client."""
    fake_env = {"YNAB_ACCESS_TOKEN": "fake-token", "YNAB_BUDGET_ID": "test-budget"}
    with patch.dict(os.environ, fake_env), patch("utils.ynab_service.ynab") as mock_ynab:
        # Mock the configuration and API client chain
        for factory_name in (
            "Configuration",
            "ApiClient",
            "PlansApi",
            "TransactionsApi",
            "MonthsApi",
            "CategoriesApi",
        ):
            getattr(mock_ynab, factory_name).return_value = MagicMock()

        from utils.ynab_service import YNABService

        yield YNABService()
class TestGetBudgetSummary:
    """Totals and fallbacks reported by YNABService.get_budget_summary."""

    @staticmethod
    def _install_budget(
        ynab_service, *, name, categories, to_be_budgeted, currency_format
    ):
        """Make plans_api.get_plan_by_id return a budget with these fields."""
        month = MagicMock()
        month.to_be_budgeted = to_be_budgeted

        budget = MagicMock()
        budget.name = name
        budget.months = [month]
        budget.categories = categories
        budget.currency_format = currency_format

        response = MagicMock()
        response.data.budget = budget
        ynab_service.plans_api.get_plan_by_id.return_value = response

    def test_calculates_totals(self, ynab_service):
        """Milliunit amounts are summed and reported divided by 1000."""
        self._install_budget(
            ynab_service,
            name="My Budget",
            categories=[
                _mock_category("Groceries", 500_000, -350_000, 150_000),
                _mock_category("Rent", 1_500_000, -1_500_000, 0),
            ],
            to_be_budgeted=200_000,
            currency_format=MagicMock(iso_code="USD"),
        )

        summary = ynab_service.get_budget_summary()

        assert summary["budget_name"] == "My Budget"
        assert summary["to_be_budgeted"] == 200.0
        assert summary["total_budgeted"] == 2000.0  # (500k + 1500k) / 1000
        assert summary["total_activity"] == -1850.0
        assert summary["currency_format"] == "USD"

    def test_skips_deleted_and_hidden(self, ynab_service):
        """Deleted and hidden categories are left out of the totals."""
        self._install_budget(
            ynab_service,
            name="Budget",
            categories=[
                _mock_category("Active", 100_000, -50_000, 50_000),
                _mock_category("Deleted", 999_000, -999_000, 0, deleted=True),
                _mock_category("Hidden", 999_000, -999_000, 0, hidden=True),
            ],
            to_be_budgeted=0,
            currency_format=None,
        )

        summary = ynab_service.get_budget_summary()

        assert summary["total_budgeted"] == 100.0
        assert summary["currency_format"] == "USD"  # Default fallback
class TestGetTransactions:
    """Filtering, conversion and ordering in YNABService.get_transactions."""

    @staticmethod
    def _serve_transactions(ynab_service, transactions):
        """Make transactions_api.get_transactions return *transactions*."""
        response = MagicMock()
        response.data.transactions = transactions
        ynab_service.transactions_api.get_transactions.return_value = response

    def test_filters_by_date_range(self, ynab_service):
        """Transactions dated after end_date are excluded."""
        self._serve_transactions(
            ynab_service,
            [
                _mock_transaction("2026-01-05", "Store", "Groceries", -25_000),
                _mock_transaction("2026-01-15", "Gas", "Transport", -40_000),
                # Out of range
                _mock_transaction("2026-02-01", "Store", "Groceries", -30_000),
            ],
        )

        outcome = ynab_service.get_transactions(
            start_date="2026-01-01", end_date="2026-01-31"
        )

        assert outcome["count"] == 2
        assert outcome["total_amount"] == -65.0

    def test_filters_by_category(self, ynab_service):
        """The category filter matches regardless of letter case."""
        self._serve_transactions(
            ynab_service,
            [
                _mock_transaction("2026-01-05", "Store", "Groceries", -25_000),
                _mock_transaction("2026-01-06", "Gas", "Transport", -40_000),
            ],
        )

        outcome = ynab_service.get_transactions(
            start_date="2026-01-01",
            end_date="2026-01-31",
            category_name="groceries",  # Case insensitive
        )

        assert outcome["count"] == 1
        assert outcome["transactions"][0]["category"] == "Groceries"

    def test_filters_by_payee(self, ynab_service):
        """The payee filter matches on a lowercase fragment of the name."""
        self._serve_transactions(
            ynab_service,
            [
                _mock_transaction("2026-01-05", "Whole Foods", "Groceries", -25_000),
                _mock_transaction("2026-01-06", "Shell Gas", "Transport", -40_000),
            ],
        )

        outcome = ynab_service.get_transactions(
            start_date="2026-01-01",
            end_date="2026-01-31",
            payee_name="whole",
        )

        assert outcome["count"] == 1
        assert outcome["transactions"][0]["payee"] == "Whole Foods"

    def test_skips_deleted(self, ynab_service):
        """Deleted transactions are not counted."""
        self._serve_transactions(
            ynab_service,
            [
                _mock_transaction("2026-01-05", "Store", "Groceries", -25_000),
                _mock_transaction(
                    "2026-01-06", "Deleted", "Other", -10_000, deleted=True
                ),
            ],
        )

        outcome = ynab_service.get_transactions(
            start_date="2026-01-01", end_date="2026-01-31"
        )

        assert outcome["count"] == 1

    def test_converts_milliunits(self, ynab_service):
        """Amounts are reported as the milliunit value divided by 1000."""
        self._serve_transactions(
            ynab_service,
            [_mock_transaction("2026-01-05", "Store", "Groceries", -12_340)],
        )

        outcome = ynab_service.get_transactions(
            start_date="2026-01-01", end_date="2026-01-31"
        )

        assert outcome["transactions"][0]["amount"] == -12.34

    def test_sorts_by_date_descending(self, ynab_service):
        """Returned transactions are ordered newest first."""
        self._serve_transactions(
            ynab_service,
            [
                _mock_transaction("2026-01-01", "A", "Cat", -10_000),
                _mock_transaction("2026-01-15", "B", "Cat", -20_000),
                _mock_transaction("2026-01-10", "C", "Cat", -30_000),
            ],
        )

        outcome = ynab_service.get_transactions(
            start_date="2026-01-01", end_date="2026-01-31"
        )

        returned_dates = [entry["date"] for entry in outcome["transactions"]]
        assert returned_dates == sorted(returned_dates, reverse=True)
class TestGetCategorySpending:
    """Month handling and overspend detection in get_category_spending."""

    @staticmethod
    def _serve_month(ynab_service, categories):
        """Make months_api.get_plan_month return a month with *categories*."""
        month = MagicMock()
        month.categories = categories
        month.to_be_budgeted = 0

        response = MagicMock()
        response.data.month = month
        ynab_service.months_api.get_plan_month.return_value = response

    def test_month_format_normalization(self, ynab_service):
        """Passing YYYY-MM should be normalized to YYYY-MM-01."""
        self._serve_month(
            ynab_service, [_mock_category("Food", 100_000, -50_000, 50_000)]
        )

        outcome = ynab_service.get_category_spending("2026-03")

        assert outcome["month"] == "2026-03"

    def test_identifies_overspent(self, ynab_service):
        """Only categories with a negative balance are listed as overspent."""
        self._serve_month(
            ynab_service,
            [
                _mock_category("Dining", 200_000, -300_000, -100_000),  # Overspent
                _mock_category("Groceries", 500_000, -400_000, 100_000),  # Fine
            ],
        )

        outcome = ynab_service.get_category_spending("2026-03")
        overspent = outcome["overspent_categories"]

        assert len(overspent) == 1
        assert overspent[0]["name"] == "Dining"
        assert overspent[0]["overspent_by"] == 100.0
+137
View File
@@ -0,0 +1,137 @@
import os
from math import ceil
import re
from typing import Union
from uuid import UUID, uuid4
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from dotenv import load_dotenv
from llm import LLMClient
load_dotenv()
def remove_headers_footers(text, header_patterns=None, footer_patterns=None):
    """Blank out every line matching a header or footer pattern.

    Patterns are regular expressions applied in MULTILINE mode. When a
    pattern list is None, a default matching any line that contains
    "Header" (respectively "Footer") is used. Matched lines are replaced
    with empty strings, and the result is stripped of leading and trailing
    whitespace.
    """
    headers = [r"^.*Header.*$"] if header_patterns is None else header_patterns
    footers = [r"^.*Footer.*$"] if footer_patterns is None else footer_patterns

    cleaned = text
    for line_pattern in headers + footers:
        cleaned = re.sub(line_pattern, "", cleaned, flags=re.MULTILINE)
    return cleaned.strip()
def remove_special_characters(text, special_chars=None):
    """Delete every match of *special_chars* from *text* and strip the result.

    *special_chars* is a regular expression. When it is None, the default
    removes everything except ASCII letters, digits, whitespace and the
    punctuation ``. , ; : ' " ? ! -``.
    """
    pattern = (
        r"[^A-Za-z0-9\s\.,;:\'\"\?\!\-]" if special_chars is None else special_chars
    )
    return re.sub(pattern, "", text).strip()
def remove_repeated_substrings(text, pattern=r"\.{2,}"):
    """Replace each match of *pattern* with a single "." and strip the result.

    The default pattern matches runs of two or more dots.
    """
    collapsed = re.sub(pattern, ".", text)
    return collapsed.strip()
def remove_extra_spaces(text):
    """Collapse whitespace in *text* and strip the result."""
    # Pass 1: normalise blank-line runs to exactly one empty line.
    paragraphs_normalised = re.sub(r"\n\s*\n", "\n\n", text)
    # Pass 2: collapse every whitespace run to one space. ``\s`` also
    # matches newlines, so the line breaks kept by pass 1 are flattened.
    single_spaced = re.sub(r"\s+", " ", paragraphs_normalised)
    return single_spaced.strip()
def preprocess_text(text):
    """Run *text* through the cleaning steps in order and strip the result."""
    cleaning_steps = (
        remove_headers_footers,  # Remove headers and footers
        remove_special_characters,  # Remove special characters
        remove_repeated_substrings,  # Remove repeated substrings like dots
        remove_extra_spaces,  # Remove extra spaces between and within lines
    )
    # Additional cleaning steps can be added to the tuple above.
    for clean in cleaning_steps:
        text = clean(text)
    return text.strip()
class Chunk:
    """One slice of a document together with its embedding.

    The constructor stores its arguments unchanged; nothing is validated.
    """

    def __init__(
        self,
        text: str,
        size: int,
        document_id: UUID,
        chunk_id: int,
        embedding,
    ):
        self.text = text
        self.size = size
        self.document_id = document_id
        self.chunk_id = chunk_id
        self.embedding = embedding


class Chunker:
    """Split documents into chunks, embed each chunk and store it.

    Chunks are added to the ``collection`` object passed to the constructor.
    """

    def __init__(self, collection) -> None:
        self.collection = collection
        self.llm_client = LLMClient()

    def embedding_fx(self, inputs):
        """Return embeddings for *inputs* from OpenAI's text-embedding-3-small.

        The API key is read from the OPENAI_API_KEY environment variable on
        every call.
        """
        openai_embedding_fx = OpenAIEmbeddingFunction(
            api_key=os.getenv("OPENAI_API_KEY"),
            model_name="text-embedding-3-small",
        )
        return openai_embedding_fx(inputs)

    def chunk_document(
        self,
        document: str,
        chunk_size: int = 1000,
        metadata: Union[dict[str, Union[str, float]], None] = None,
    ) -> list[Chunk]:
        """Split *document* into consecutive chunks, embed and store each one.

        Each chunk covers ``chunk_size`` characters of the raw document (the
        last one may be shorter), is cleaned with ``clean_document``, embedded
        with ``embedding_fx`` and added to the collection under the id
        ``"<document uuid>:<chunk index>"``.

        Args:
            document: Raw document text.
            chunk_size: Maximum number of raw characters per chunk. It is
                capped at the document length, and a value of 0 is treated
                as 1.
            metadata: Metadata attached to every stored chunk. Defaults to an
                empty dict.

        Returns:
            One ``Chunk`` per stored chunk, in document order. The list is
            empty for an empty document.
        """
        # A fresh dict per call avoids sharing one mutable default between calls.
        if metadata is None:
            metadata = {}

        document_length = len(document)
        chunk_size = min(chunk_size, document_length) or 1
        doc_uuid = uuid4()

        chunks: list[Chunk] = []
        # Step through the document by chunk_size so slices neither overlap
        # nor skip text (the offset must not depend on the number of chunks).
        for i, start in enumerate(range(0, document_length, chunk_size)):
            text_chunk = self.clean_document(document[start : start + chunk_size])
            embedding = self.embedding_fx([text_chunk])
            self.collection.add(
                ids=[str(doc_uuid) + ":" + str(i)],
                documents=[text_chunk],
                embeddings=embedding,
                metadatas=[metadata],
            )
            chunks.append(
                Chunk(
                    text=text_chunk,
                    size=len(text_chunk),
                    document_id=doc_uuid,
                    chunk_id=i,
                    # embedding_fx was given one input, so the first entry
                    # is this chunk's vector.
                    embedding=embedding[0],
                )
            )
        return chunks

    def clean_document(self, document: str) -> str:
        """Remove information that is noise or already known.

        Example: We already know all the things in here are Simba-related,
        so we don't need things like "Summary of Simba's visit".

        Literal backslash-n sequences (the two characters ``\\n``) are
        removed, the text is stripped, and the result is passed through
        ``preprocess_text``.
        """
        document = document.replace("\\n", "")
        document = document.strip()
        return preprocess_text(document)
Generated
+951 -345
View File
File diff suppressed because it is too large Load Diff