Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 64dab18428 | |||
| b62a8b6b3f |
@@ -19,6 +19,11 @@ BASE_URL=192.168.1.5:8000
|
|||||||
LLAMA_SERVER_URL=http://192.168.1.213:8080/v1
|
LLAMA_SERVER_URL=http://192.168.1.213:8080/v1
|
||||||
LLAMA_MODEL_NAME=llama-3.1-8b-instruct
|
LLAMA_MODEL_NAME=llama-3.1-8b-instruct
|
||||||
|
|
||||||
|
# ChromaDB Configuration
|
||||||
|
# For Docker: This is automatically set to /app/data/chromadb
|
||||||
|
# For local development: Set to a local directory path
|
||||||
|
CHROMADB_PATH=./data/chromadb
|
||||||
|
|
||||||
# OpenAI Configuration
|
# OpenAI Configuration
|
||||||
OPENAI_API_KEY=your-openai-api-key
|
OPENAI_API_KEY=your-openai-api-key
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,9 @@ wheels/
|
|||||||
.env
|
.env
|
||||||
|
|
||||||
# Database files
|
# Database files
|
||||||
|
chromadb/
|
||||||
|
chromadb_openai/
|
||||||
|
chroma_db/
|
||||||
database/
|
database/
|
||||||
*.db
|
*.db
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
|||||||
|
|
||||||
## Project Overview
|
## Project Overview
|
||||||
|
|
||||||
SimbaRAG is a RAG (Retrieval-Augmented Generation) conversational AI system for querying information about Simba (a cat). It ingests documents from Paperless-NGX, stores embeddings in PostgreSQL via pgvector, and uses LLMs (Ollama or OpenAI) to answer questions.
|
SimbaRAG is a RAG (Retrieval-Augmented Generation) conversational AI system for querying information about Simba (a cat). It ingests documents from Paperless-NGX, stores embeddings in ChromaDB, and uses LLMs (Ollama or OpenAI) to answer questions.
|
||||||
|
|
||||||
## Commands
|
## Commands
|
||||||
|
|
||||||
@@ -54,8 +54,9 @@ docker compose up -d
|
|||||||
│ Docker Compose │
|
│ Docker Compose │
|
||||||
├─────────────────────────────────────────────────────────────┤
|
├─────────────────────────────────────────────────────────────┤
|
||||||
│ raggr (port 8080) │ postgres (port 5432) │
|
│ raggr (port 8080) │ postgres (port 5432) │
|
||||||
│ ├── Quart backend │ PostgreSQL 16 + pgvector│
|
│ ├── Quart backend │ PostgreSQL 16 │
|
||||||
│ └── React frontend (served) │ │
|
│ ├── React frontend (served) │ │
|
||||||
|
│ └── ChromaDB (volume) │ │
|
||||||
└─────────────────────────────────────────────────────────────┘
|
└─────────────────────────────────────────────────────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -90,15 +91,6 @@ docker compose up -d
|
|||||||
|
|
||||||
**Auth Flow**: LLDAP → Authelia (OIDC) → Backend JWT → Frontend localStorage
|
**Auth Flow**: LLDAP → Authelia (OIDC) → Backend JWT → Frontend localStorage
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
Always run `make test` before pushing code to ensure all tests pass.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
make test # Run tests
|
|
||||||
make test-cov # Run tests with coverage
|
|
||||||
```
|
|
||||||
|
|
||||||
## Key Patterns
|
## Key Patterns
|
||||||
|
|
||||||
- All endpoints are async (`async def`)
|
- All endpoints are async (`async def`)
|
||||||
|
|||||||
+3
-2
@@ -37,14 +37,15 @@ WORKDIR /app/raggr-frontend
|
|||||||
RUN yarn install && yarn build
|
RUN yarn install && yarn build
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Create database directory
|
# Create ChromaDB and database directories
|
||||||
RUN mkdir -p /app/database
|
RUN mkdir -p /app/chromadb /app/database
|
||||||
|
|
||||||
# Expose port
|
# Expose port
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
|
|
||||||
# Set environment variables
|
# Set environment variables
|
||||||
ENV PYTHONPATH=/app
|
ENV PYTHONPATH=/app
|
||||||
|
ENV CHROMADB_PATH=/app/chromadb
|
||||||
|
|
||||||
# Run the startup script
|
# Run the startup script
|
||||||
CMD ["./startup.sh"]
|
CMD ["./startup.sh"]
|
||||||
|
|||||||
+3
-2
@@ -34,15 +34,16 @@ COPY . .
|
|||||||
WORKDIR /app/raggr-frontend
|
WORKDIR /app/raggr-frontend
|
||||||
RUN yarn build
|
RUN yarn build
|
||||||
|
|
||||||
# Create database directory
|
# Create ChromaDB and database directories
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
RUN mkdir -p /app/database
|
RUN mkdir -p /app/chromadb /app/database
|
||||||
|
|
||||||
# Make startup script executable
|
# Make startup script executable
|
||||||
RUN chmod +x /app/startup-dev.sh
|
RUN chmod +x /app/startup-dev.sh
|
||||||
|
|
||||||
# Set environment variables
|
# Set environment variables
|
||||||
ENV PYTHONPATH=/app
|
ENV PYTHONPATH=/app
|
||||||
|
ENV CHROMADB_PATH=/app/chromadb
|
||||||
ENV PYTHONUNBUFFERED=1
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
# Expose port
|
# Expose port
|
||||||
|
|||||||
@@ -1,11 +1,8 @@
|
|||||||
.PHONY: deploy redeploy build up down restart logs migrate migrate-new frontend test
|
.PHONY: deploy build up down restart logs migrate migrate-new frontend
|
||||||
|
|
||||||
# Build and deploy
|
# Build and deploy
|
||||||
deploy: build up
|
deploy: build up
|
||||||
|
|
||||||
redeploy:
|
|
||||||
git pull && $(MAKE) down && $(MAKE) up
|
|
||||||
|
|
||||||
build:
|
build:
|
||||||
docker compose build raggr
|
docker compose build raggr
|
||||||
|
|
||||||
@@ -32,13 +29,6 @@ migrate-new:
|
|||||||
migrate-history:
|
migrate-history:
|
||||||
docker compose exec raggr aerich history
|
docker compose exec raggr aerich history
|
||||||
|
|
||||||
# Tests
|
|
||||||
test:
|
|
||||||
pytest tests/ -v
|
|
||||||
|
|
||||||
test-cov:
|
|
||||||
pytest tests/ -v --cov
|
|
||||||
|
|
||||||
# Frontend
|
# Frontend
|
||||||
frontend:
|
frontend:
|
||||||
cd raggr-frontend && yarn install && yarn build
|
cd raggr-frontend && yarn install && yarn build
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import timedelta
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from quart import Quart, jsonify, render_template, send_from_directory
|
from quart import Quart, jsonify, render_template, request, send_from_directory
|
||||||
from quart_jwt_extended import JWTManager, get_jwt_identity, jwt_refresh_token_required
|
from quart_jwt_extended import JWTManager, get_jwt_identity, jwt_refresh_token_required
|
||||||
from tortoise import Tortoise
|
from tortoise import Tortoise
|
||||||
|
|
||||||
@@ -15,6 +14,7 @@ import blueprints.users
|
|||||||
import blueprints.whatsapp
|
import blueprints.whatsapp
|
||||||
import blueprints.users.models
|
import blueprints.users.models
|
||||||
from config.db import TORTOISE_CONFIG
|
from config.db import TORTOISE_CONFIG
|
||||||
|
from main import consult_simba_oracle
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
@@ -38,8 +38,6 @@ app = Quart(
|
|||||||
)
|
)
|
||||||
|
|
||||||
app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY", "SECRET_KEY")
|
app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY", "SECRET_KEY")
|
||||||
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(hours=1)
|
|
||||||
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=30)
|
|
||||||
app.config["MAX_CONTENT_LENGTH"] = 10 * 1024 * 1024 # 10 MB upload limit
|
app.config["MAX_CONTENT_LENGTH"] = 10 * 1024 * 1024 # 10 MB upload limit
|
||||||
jwt = JWTManager(app)
|
jwt = JWTManager(app)
|
||||||
|
|
||||||
@@ -77,6 +75,39 @@ async def serve_react_app(path):
|
|||||||
return await render_template("index.html")
|
return await render_template("index.html")
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/api/query", methods=["POST"])
@jwt_refresh_token_required
async def query():
    """Answer a user's question inside an existing conversation.

    Expects a JSON object with ``query`` (the question text) and
    ``conversation_id``. The question is appended to the conversation as a
    "user" message, the conversation transcript is passed to
    ``consult_simba_oracle``, and the answer is appended as a "simba" message.

    Returns:
        ``{"response": <answer>}`` on success, or a 400 JSON error when the
        body is not a JSON object or carries no ``query``.
    """
    current_user_uuid = get_jwt_identity()
    user = await blueprints.users.models.User.get(id=current_user_uuid)

    data = await request.get_json()
    # get_json() yields None for a non-JSON request; reject it up front
    # instead of failing with AttributeError on data.get below.
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    # Named query_text so the local does not shadow this handler's own name.
    query_text = data.get("query")
    if not query_text:
        return jsonify({"error": "Missing 'query' in request body"}), 400
    conversation_id = data.get("conversation_id")

    conversation = await blueprints.conversation.logic.get_conversation_by_id(
        conversation_id
    )
    await conversation.fetch_related("messages")
    await blueprints.conversation.logic.add_message_to_conversation(
        conversation=conversation,
        message=query_text,
        speaker="user",
        user=user,
    )

    # The transcript is built after the user message is stored, so it
    # includes the question being asked.
    transcript = await blueprints.conversation.logic.get_conversation_transcript(
        user=user, conversation=conversation
    )

    # NOTE(review): consult_simba_oracle is called synchronously; if it does
    # network I/O it will block the event loop for its duration -- confirm.
    response = consult_simba_oracle(input=query_text, transcript=transcript)
    await blueprints.conversation.logic.add_message_to_conversation(
        conversation=conversation,
        message=response,
        speaker="simba",
        user=user,
    )
    return jsonify({"response": response})
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/messages", methods=["GET"])
|
@app.route("/api/messages", methods=["GET"])
|
||||||
@jwt_refresh_token_required
|
@jwt_refresh_token_required
|
||||||
async def get_messages():
|
async def get_messages():
|
||||||
@@ -101,10 +132,17 @@ async def get_messages():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
name = conversation.name
|
||||||
|
if len(messages) > 8:
|
||||||
|
name = await blueprints.conversation.logic.rename_conversation(
|
||||||
|
user=user,
|
||||||
|
conversation=conversation,
|
||||||
|
)
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(
|
||||||
{
|
{
|
||||||
"id": str(conversation.id),
|
"id": str(conversation.id),
|
||||||
"name": conversation.name,
|
"name": name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"created_at": conversation.created_at.isoformat(),
|
"created_at": conversation.created_at.isoformat(),
|
||||||
"updated_at": conversation.updated_at.isoformat(),
|
"updated_at": conversation.updated_at.isoformat(),
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import datetime
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
@@ -19,8 +20,8 @@ from .agents import main_agent
|
|||||||
from .logic import (
|
from .logic import (
|
||||||
add_message_to_conversation,
|
add_message_to_conversation,
|
||||||
get_conversation_by_id,
|
get_conversation_by_id,
|
||||||
|
rename_conversation,
|
||||||
)
|
)
|
||||||
from .memory import get_memories_for_user
|
|
||||||
from .models import (
|
from .models import (
|
||||||
Conversation,
|
Conversation,
|
||||||
PydConversation,
|
PydConversation,
|
||||||
@@ -35,27 +36,15 @@ conversation_blueprint = Blueprint(
|
|||||||
_SYSTEM_PROMPT = SIMBA_SYSTEM_PROMPT
|
_SYSTEM_PROMPT = SIMBA_SYSTEM_PROMPT
|
||||||
|
|
||||||
|
|
||||||
async def _build_system_prompt_with_memories(user_id: str) -> str:
    """Return the base system prompt, extended with the user's saved memories.

    When the user has no memories the base prompt is returned unchanged;
    otherwise each memory is appended as a ``- `` bullet under a
    "USER MEMORIES" heading.
    """
    remembered = await get_memories_for_user(user_id)
    if remembered:
        bullet_list = "\n".join([f"- {entry}" for entry in remembered])
        return f"{_SYSTEM_PROMPT}\n\nUSER MEMORIES (facts the user has asked you to remember):\n{bullet_list}"
    return _SYSTEM_PROMPT
|
|
||||||
|
|
||||||
|
|
||||||
def _build_messages_payload(
|
def _build_messages_payload(
|
||||||
conversation,
|
conversation, query_text: str, image_description: str | None = None
|
||||||
query_text: str,
|
|
||||||
image_description: str | None = None,
|
|
||||||
system_prompt: str | None = None,
|
|
||||||
) -> list:
|
) -> list:
|
||||||
recent_messages = (
|
recent_messages = (
|
||||||
conversation.messages[-10:]
|
conversation.messages[-10:]
|
||||||
if len(conversation.messages) > 10
|
if len(conversation.messages) > 10
|
||||||
else conversation.messages
|
else conversation.messages
|
||||||
)
|
)
|
||||||
messages_payload = [{"role": "system", "content": system_prompt or _SYSTEM_PROMPT}]
|
messages_payload = [{"role": "system", "content": _SYSTEM_PROMPT}]
|
||||||
for msg in recent_messages[:-1]: # Exclude the message we just added
|
for msg in recent_messages[:-1]: # Exclude the message we just added
|
||||||
role = "user" if msg.speaker == "user" else "assistant"
|
role = "user" if msg.speaker == "user" else "assistant"
|
||||||
text = msg.text
|
text = msg.text
|
||||||
@@ -91,14 +80,10 @@ async def query():
|
|||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
system_prompt = await _build_system_prompt_with_memories(str(user.id))
|
messages_payload = _build_messages_payload(conversation, query)
|
||||||
messages_payload = _build_messages_payload(
|
|
||||||
conversation, query, system_prompt=system_prompt
|
|
||||||
)
|
|
||||||
payload = {"messages": messages_payload}
|
payload = {"messages": messages_payload}
|
||||||
agent_config = {"configurable": {"user_id": str(user.id)}}
|
|
||||||
|
|
||||||
response = await main_agent.ainvoke(payload, config=agent_config)
|
response = await main_agent.ainvoke(payload)
|
||||||
message = response.get("messages", [])[-1].content
|
message = response.get("messages", [])[-1].content
|
||||||
await add_message_to_conversation(
|
await add_message_to_conversation(
|
||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
@@ -178,19 +163,15 @@ async def stream_query():
|
|||||||
logging.error(f"Failed to analyze image: {e}")
|
logging.error(f"Failed to analyze image: {e}")
|
||||||
image_description = "[Image could not be analyzed]"
|
image_description = "[Image could not be analyzed]"
|
||||||
|
|
||||||
system_prompt = await _build_system_prompt_with_memories(str(user.id))
|
|
||||||
messages_payload = _build_messages_payload(
|
messages_payload = _build_messages_payload(
|
||||||
conversation, query_text or "", image_description, system_prompt=system_prompt
|
conversation, query_text or "", image_description
|
||||||
)
|
)
|
||||||
payload = {"messages": messages_payload}
|
payload = {"messages": messages_payload}
|
||||||
agent_config = {"configurable": {"user_id": str(user.id)}}
|
|
||||||
|
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
final_message = None
|
final_message = None
|
||||||
try:
|
try:
|
||||||
async for event in main_agent.astream_events(
|
async for event in main_agent.astream_events(payload, version="v2"):
|
||||||
payload, version="v2", config=agent_config
|
|
||||||
):
|
|
||||||
event_type = event.get("event")
|
event_type = event.get("event")
|
||||||
|
|
||||||
if event_type == "on_tool_start":
|
if event_type == "on_tool_start":
|
||||||
@@ -240,6 +221,8 @@ async def stream_query():
|
|||||||
@jwt_refresh_token_required
|
@jwt_refresh_token_required
|
||||||
async def get_conversation(conversation_id: str):
|
async def get_conversation(conversation_id: str):
|
||||||
conversation = await Conversation.get(id=conversation_id)
|
conversation = await Conversation.get(id=conversation_id)
|
||||||
|
current_user_uuid = get_jwt_identity()
|
||||||
|
user = await blueprints.users.models.User.get(id=current_user_uuid)
|
||||||
await conversation.fetch_related("messages")
|
await conversation.fetch_related("messages")
|
||||||
|
|
||||||
# Manually serialize the conversation with messages
|
# Manually serialize the conversation with messages
|
||||||
@@ -254,10 +237,18 @@ async def get_conversation(conversation_id: str):
|
|||||||
"image_key": msg.image_key,
|
"image_key": msg.image_key,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
name = conversation.name
|
||||||
|
if len(messages) > 8 and "datetime" in name.lower():
|
||||||
|
name = await rename_conversation(
|
||||||
|
user=user,
|
||||||
|
conversation=conversation,
|
||||||
|
)
|
||||||
|
print(name)
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(
|
||||||
{
|
{
|
||||||
"id": str(conversation.id),
|
"id": str(conversation.id),
|
||||||
"name": conversation.name,
|
"name": name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"created_at": conversation.created_at.isoformat(),
|
"created_at": conversation.created_at.isoformat(),
|
||||||
"updated_at": conversation.updated_at.isoformat(),
|
"updated_at": conversation.updated_at.isoformat(),
|
||||||
@@ -271,7 +262,7 @@ async def create_conversation():
|
|||||||
user_uuid = get_jwt_identity()
|
user_uuid = get_jwt_identity()
|
||||||
user = await blueprints.users.models.User.get(id=user_uuid)
|
user = await blueprints.users.models.User.get(id=user_uuid)
|
||||||
conversation = await Conversation.create(
|
conversation = await Conversation.create(
|
||||||
name="New Conversation",
|
name=f"{user.username} {datetime.datetime.now().timestamp}",
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -284,7 +275,7 @@ async def create_conversation():
|
|||||||
async def get_all_conversations():
|
async def get_all_conversations():
|
||||||
user_uuid = get_jwt_identity()
|
user_uuid = get_jwt_identity()
|
||||||
user = await blueprints.users.models.User.get(id=user_uuid)
|
user = await blueprints.users.models.User.get(id=user_uuid)
|
||||||
conversations = Conversation.filter(user=user).order_by("-updated_at")
|
conversations = Conversation.filter(user=user)
|
||||||
serialized_conversations = await PydListConversation.from_queryset(conversations)
|
serialized_conversations = await PydListConversation.from_queryset(conversations)
|
||||||
|
|
||||||
return jsonify(serialized_conversations.model_dump())
|
return jsonify(serialized_conversations.model_dump())
|
||||||
|
|||||||
@@ -5,11 +5,9 @@ from dotenv import load_dotenv
|
|||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain.chat_models import BaseChatModel
|
from langchain.chat_models import BaseChatModel
|
||||||
from langchain.tools import tool
|
from langchain.tools import tool
|
||||||
from langchain_core.runnables import RunnableConfig
|
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from tavily import AsyncTavilyClient
|
from tavily import AsyncTavilyClient
|
||||||
|
|
||||||
from blueprints.conversation.memory import save_memory
|
|
||||||
from blueprints.rag.logic import query_vector_store
|
from blueprints.rag.logic import query_vector_store
|
||||||
from utils.obsidian_service import ObsidianService
|
from utils.obsidian_service import ObsidianService
|
||||||
from utils.ynab_service import YNABService
|
from utils.ynab_service import YNABService
|
||||||
@@ -328,7 +326,7 @@ async def obsidian_search_notes(query: str) -> str:
|
|||||||
return "Obsidian integration is not configured. Please set OBSIDIAN_VAULT_PATH environment variable."
|
return "Obsidian integration is not configured. Please set OBSIDIAN_VAULT_PATH environment variable."
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Query vector store for obsidian documents
|
# Query ChromaDB for obsidian documents
|
||||||
serialized, docs = await query_vector_store(query=query)
|
serialized, docs = await query_vector_store(query=query)
|
||||||
return serialized
|
return serialized
|
||||||
|
|
||||||
@@ -591,35 +589,8 @@ async def obsidian_create_task(
|
|||||||
return f"Error creating task: {str(e)}"
|
return f"Error creating task: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
@tool
async def save_user_memory(content: str, config: RunnableConfig) -> str:
    """Save a fact or preference about the user for future conversations.

    Use this tool when the user:
    - Explicitly asks you to remember something ("remember that...", "keep in mind...")
    - Shares a personal preference that would be useful in future conversations
      (e.g., "I prefer metric units", "my cat's name is Luna")
    - Tells you a meaningful personal fact (e.g., "I'm allergic to peanuts")

    Do NOT save:
    - Trivial or ephemeral info (e.g., "I'm tired today")
    - Information already in the system prompt or documents
    - Conversation-specific context that won't matter later

    Args:
        content: A concise statement of the fact or preference to remember.
            Write it as a standalone sentence (e.g., "User prefers dark mode"
            rather than "likes dark mode").

    Returns:
        Confirmation that the memory was saved.
    """
    # The docstring above is left word-for-word: @tool exposes it to the model.
    # The owning user is taken from the per-invocation runnable config rather
    # than from a tool argument.
    configurable = config["configurable"]
    return await save_memory(user_id=configurable["user_id"], content=content)
|
|
||||||
|
|
||||||
|
|
||||||
# Create tools list based on what's available
|
# Create tools list based on what's available
|
||||||
tools = [get_current_date, simba_search, web_search, save_user_memory]
|
tools = [get_current_date, simba_search, web_search]
|
||||||
if ynab_enabled:
|
if ynab_enabled:
|
||||||
tools.extend(
|
tools.extend(
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import tortoise.exceptions
|
import tortoise.exceptions
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
import blueprints.users.models
|
import blueprints.users.models
|
||||||
|
|
||||||
from .models import Conversation, ConversationMessage
|
from .models import Conversation, ConversationMessage, RenameConversationOutputSchema
|
||||||
|
|
||||||
|
|
||||||
async def create_conversation(name: str = "") -> Conversation:
|
async def create_conversation(name: str = "") -> Conversation:
|
||||||
@@ -18,12 +19,6 @@ async def add_message_to_conversation(
|
|||||||
image_key: str | None = None,
|
image_key: str | None = None,
|
||||||
) -> ConversationMessage:
|
) -> ConversationMessage:
|
||||||
print(conversation, message, speaker)
|
print(conversation, message, speaker)
|
||||||
|
|
||||||
# Name the conversation after the first user message
|
|
||||||
if speaker == "user" and not await conversation.messages.all().exists():
|
|
||||||
conversation.name = message[:100]
|
|
||||||
await conversation.save()
|
|
||||||
|
|
||||||
message = await ConversationMessage.create(
|
message = await ConversationMessage.create(
|
||||||
text=message,
|
text=message,
|
||||||
speaker=speaker,
|
speaker=speaker,
|
||||||
@@ -66,3 +61,22 @@ async def get_conversation_transcript(
|
|||||||
messages.append(f"{message.speaker} at {message.created_at}: {message.text}")
|
messages.append(f"{message.speaker} at {message.created_at}: {message.text}")
|
||||||
|
|
||||||
return "\n".join(messages)
|
return "\n".join(messages)
|
||||||
|
|
||||||
|
|
||||||
|
async def rename_conversation(
    user: blueprints.users.models.User,
    conversation: Conversation,
) -> str:
    """Ask the LLM for a one-line title and store it as the conversation name.

    Args:
        user: Owner of the conversation; forwarded to the transcript builder.
        conversation: Conversation to summarize and rename.

    Returns:
        The new name, or the conversation's existing name when the model
        returned no usable title (nothing is saved in that case).
    """
    transcript: str = await get_conversation_transcript(
        user=user, conversation=conversation
    )

    llm = ChatOpenAI(model="gpt-4o-mini")
    structured_llm = llm.with_structured_output(RenameConversationOutputSchema)

    prompt = f"Summarize the following conversation into a sassy one-liner title:\n\n{transcript}"
    # Use the async entry point: a blocking invoke() here would stall the
    # event loop for the whole LLM round trip.
    response = await structured_llm.ainvoke(prompt)

    # Structured output can come back as a mapping or as an object with
    # attributes depending on the schema type, so accept both.
    if isinstance(response, dict):
        new_name = response.get("title", "")
    else:
        new_name = getattr(response, "title", "")

    # Never replace an existing name with an empty title.
    if not new_name:
        return conversation.name

    conversation.name = new_name
    await conversation.save()
    return new_name
|
||||||
|
|||||||
@@ -1,19 +0,0 @@
|
|||||||
from .models import UserMemory
|
|
||||||
|
|
||||||
|
|
||||||
async def get_memories_for_user(user_id: str) -> list[str]:
    """Return the content of every memory stored for ``user_id``, most recently updated first."""
    newest_first = UserMemory.filter(user_id=user_id).order_by("-updated_at")
    return [memory.content for memory in await newest_first]
|
|
||||||
|
|
||||||
|
|
||||||
async def save_memory(user_id: str, content: str) -> str:
    """Store ``content`` as a memory for the user, or refresh an identical one.

    Deduplication is by exact content match for the same user.

    Returns:
        A short status message saying whether a memory was created or refreshed.
    """
    duplicate = await UserMemory.filter(user_id=user_id, content=content).first()
    if not duplicate:
        await UserMemory.create(user_id=user_id, content=content)
        return "Memory saved."

    duplicate.updated_at = None  # auto_now=True will refresh it on save
    await duplicate.save(update_fields=["updated_at"])
    return "Memory already exists (refreshed)."
|
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import enum
|
import enum
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
from tortoise.contrib.pydantic import (
|
from tortoise.contrib.pydantic import (
|
||||||
@@ -8,6 +9,12 @@ from tortoise.contrib.pydantic import (
|
|||||||
from tortoise.models import Model
|
from tortoise.models import Model
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RenameConversationOutputSchema:
    """Shape of the structured LLM output requested when renaming a conversation."""

    # Proposed conversation title.
    title: str
    # Presumably the model's reasoning for the title; not read elsewhere in
    # the visible code -- confirm.
    justification: str
|
||||||
|
|
||||||
|
|
||||||
class Speaker(enum.Enum):
|
class Speaker(enum.Enum):
|
||||||
USER = "user"
|
USER = "user"
|
||||||
SIMBA = "simba"
|
SIMBA = "simba"
|
||||||
@@ -40,17 +47,6 @@ class ConversationMessage(Model):
|
|||||||
table = "conversation_messages"
|
table = "conversation_messages"
|
||||||
|
|
||||||
|
|
||||||
class UserMemory(Model):
    """A single remembered fact or preference belonging to one user."""

    # UUID primary key.
    id = fields.UUIDField(primary_key=True)
    # Owning user; reachable from the user side via the "memories" relation.
    user = fields.ForeignKeyField("models.User", related_name="memories")
    # Free-text body of the memory.
    content = fields.TextField()
    # Set once when the row is created.
    created_at = fields.DatetimeField(auto_now_add=True)
    # Refreshed on every save.
    updated_at = fields.DatetimeField(auto_now=True)

    class Meta:
        table = "user_memories"
|
|
||||||
|
|
||||||
|
|
||||||
PydConversationMessage = pydantic_model_creator(ConversationMessage)
|
PydConversationMessage = pydantic_model_creator(ConversationMessage)
|
||||||
PydConversation = pydantic_model_creator(
|
PydConversation = pydantic_model_creator(
|
||||||
Conversation, name="Conversation", allow_cycles=True, exclude=("user",)
|
Conversation, name="Conversation", allow_cycles=True, exclude=("user",)
|
||||||
|
|||||||
@@ -54,7 +54,4 @@ You have access to Ryan's daily journal notes. Each note lives at journal/YYYY/Y
|
|||||||
- Use journal_get_tasks to list tasks (done/pending) for today or a specific date
|
- Use journal_get_tasks to list tasks (done/pending) for today or a specific date
|
||||||
- Use journal_add_task to add a new task to today's (or a given date's) note
|
- Use journal_add_task to add a new task to today's (or a given date's) note
|
||||||
- Use journal_complete_task to check off a task as done
|
- Use journal_complete_task to check off a task as done
|
||||||
Use these tools when Ryan asks about today's tasks, wants to add something to his list, or wants to mark a task complete.
|
Use these tools when Ryan asks about today's tasks, wants to add something to his list, or wants to mark a task complete."""
|
||||||
|
|
||||||
USER MEMORY:
|
|
||||||
You can remember facts about the user across conversations using the save_user_memory tool. When a user explicitly asks you to remember something, or shares a meaningful preference or personal fact, save it. Saved memories will automatically appear at the end of this prompt in future conversations under "USER MEMORIES"."""
|
|
||||||
|
|||||||
@@ -1,12 +1,7 @@
|
|||||||
from quart import Blueprint, jsonify
|
from quart import Blueprint, jsonify
|
||||||
from quart_jwt_extended import jwt_refresh_token_required
|
from quart_jwt_extended import jwt_refresh_token_required
|
||||||
|
|
||||||
from .logic import (
|
from .logic import fetch_obsidian_documents, get_vector_store_stats, index_documents, index_obsidian_documents, vector_store
|
||||||
delete_all_documents,
|
|
||||||
get_vector_store_stats,
|
|
||||||
index_documents,
|
|
||||||
index_obsidian_documents,
|
|
||||||
)
|
|
||||||
from blueprints.users.decorators import admin_required
|
from blueprints.users.decorators import admin_required
|
||||||
|
|
||||||
rag_blueprint = Blueprint("rag_api", __name__, url_prefix="/api/rag")
|
rag_blueprint = Blueprint("rag_api", __name__, url_prefix="/api/rag")
|
||||||
@@ -37,7 +32,14 @@ async def trigger_index():
|
|||||||
async def trigger_reindex():
|
async def trigger_reindex():
|
||||||
"""Clear and reindex all documents. Admin only."""
|
"""Clear and reindex all documents. Admin only."""
|
||||||
try:
|
try:
|
||||||
delete_all_documents()
|
# Clear existing documents
|
||||||
|
collection = vector_store._collection
|
||||||
|
all_docs = collection.get()
|
||||||
|
|
||||||
|
if all_docs["ids"]:
|
||||||
|
collection.delete(ids=all_docs["ids"])
|
||||||
|
|
||||||
|
# Reindex
|
||||||
await index_documents()
|
await index_documents()
|
||||||
stats = get_vector_store_stats()
|
stats = get_vector_store_stats()
|
||||||
return jsonify({"status": "success", "stats": stats})
|
return jsonify({"status": "success", "stats": stats})
|
||||||
|
|||||||
+25
-121
@@ -1,13 +1,11 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from langchain_chroma import Chroma
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_openai import OpenAIEmbeddings
|
from langchain_openai import OpenAIEmbeddings
|
||||||
from langchain_postgres import PGVector
|
|
||||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
from sqlalchemy import create_engine, text
|
|
||||||
|
|
||||||
from .fetchers import PaperlessNGXService
|
from .fetchers import PaperlessNGXService
|
||||||
from utils.obsidian_service import ObsidianService
|
from utils.obsidian_service import ObsidianService
|
||||||
@@ -15,40 +13,13 @@ from utils.obsidian_service import ObsidianService
|
|||||||
# Load environment variables
|
# Load environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
|
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
|
||||||
|
|
||||||
# Convert Tortoise-style postgres:// URL to SQLAlchemy-style postgresql+psycopg://
|
vector_store = Chroma(
|
||||||
_db_url = os.getenv(
|
|
||||||
"DATABASE_URL", "postgres://raggr:raggr_dev_password@localhost:5432/raggr"
|
|
||||||
)
|
|
||||||
_pgvector_url = _db_url.replace("postgres://", "postgresql+psycopg://")
|
|
||||||
|
|
||||||
# Lazy-initialized vector store (defers DB connection to first use)
|
|
||||||
_vector_store = None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_vector_store() -> PGVector:
|
|
||||||
global _vector_store
|
|
||||||
if _vector_store is None:
|
|
||||||
_vector_store = PGVector(
|
|
||||||
embeddings=embeddings,
|
|
||||||
collection_name="simba_docs",
|
collection_name="simba_docs",
|
||||||
connection=_pgvector_url,
|
embedding_function=embeddings,
|
||||||
use_jsonb=True,
|
persist_directory=os.getenv("CHROMADB_PATH", ""),
|
||||||
create_extension=False, # created by docker init script
|
)
|
||||||
async_mode=True,
|
|
||||||
)
|
|
||||||
return _vector_store
|
|
||||||
|
|
||||||
|
|
||||||
def _get_engine():
    """Return the SQLAlchemy engine used for direct queries, creating it on first call.

    The engine is cached as an attribute on this function.
    """
    try:
        return _get_engine._engine
    except AttributeError:
        # First call: nothing cached yet, so build the engine and remember it.
        _get_engine._engine = create_engine(_pgvector_url)
        return _get_engine._engine
|
|
||||||
|
|
||||||
|
|
||||||
text_splitter = RecursiveCharacterTextSplitter(
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
chunk_size=1000, # chunk size (characters)
|
chunk_size=1000, # chunk size (characters)
|
||||||
@@ -57,22 +28,6 @@ text_splitter = RecursiveCharacterTextSplitter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_collection_id():
|
|
||||||
"""Get the UUID of our collection from the langchain_pg_collection table."""
|
|
||||||
engine = _get_engine()
|
|
||||||
try:
|
|
||||||
with engine.connect() as conn:
|
|
||||||
result = conn.execute(
|
|
||||||
text("SELECT uuid FROM langchain_pg_collection WHERE name = :name"),
|
|
||||||
{"name": "simba_docs"},
|
|
||||||
)
|
|
||||||
row = result.fetchone()
|
|
||||||
return row[0] if row else None
|
|
||||||
except Exception:
|
|
||||||
# Table doesn't exist yet (first run before any indexing)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def date_to_epoch(date_str: str) -> float:
|
def date_to_epoch(date_str: str) -> float:
|
||||||
split_date = date_str.split("-")
|
split_date = date_str.split("-")
|
||||||
date = datetime.datetime(
|
date = datetime.datetime(
|
||||||
@@ -108,7 +63,6 @@ async def index_documents():
|
|||||||
documents = await fetch_documents_from_paperless_ngx()
|
documents = await fetch_documents_from_paperless_ngx()
|
||||||
|
|
||||||
splits = text_splitter.split_documents(documents)
|
splits = text_splitter.split_documents(documents)
|
||||||
vector_store = _get_vector_store()
|
|
||||||
await vector_store.aadd_documents(documents=splits)
|
await vector_store.aadd_documents(documents=splits)
|
||||||
|
|
||||||
|
|
||||||
@@ -138,17 +92,13 @@ async def fetch_obsidian_documents() -> list[Document]:
|
|||||||
"filepath": parsed["filepath"],
|
"filepath": parsed["filepath"],
|
||||||
"tags": parsed["tags"],
|
"tags": parsed["tags"],
|
||||||
"created_at": parsed["metadata"].get("created_at"),
|
"created_at": parsed["metadata"].get("created_at"),
|
||||||
**{
|
**{k: v for k, v in parsed["metadata"].items() if k not in ["created_at", "created_by"]},
|
||||||
k: v
|
|
||||||
for k, v in parsed["metadata"].items()
|
|
||||||
if k not in ["created_at", "created_by"]
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
documents.append(document)
|
documents.append(document)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error reading {md_path}: {e}")
|
print(f"Error reading {md_path}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
@@ -159,25 +109,26 @@ async def index_obsidian_documents():
|
|||||||
|
|
||||||
Deletes existing obsidian source chunks before re-indexing.
|
Deletes existing obsidian source chunks before re-indexing.
|
||||||
"""
|
"""
|
||||||
|
obsidian_service = ObsidianService()
|
||||||
documents = await fetch_obsidian_documents()
|
documents = await fetch_obsidian_documents()
|
||||||
|
|
||||||
if not documents:
|
if not documents:
|
||||||
logger.info("No Obsidian documents found to index")
|
print("No Obsidian documents found to index")
|
||||||
return {"indexed": 0}
|
return {"indexed": 0}
|
||||||
|
|
||||||
# Delete existing obsidian chunks
|
# Delete existing obsidian chunks
|
||||||
delete_documents_by_metadata("source", "obsidian")
|
existing_results = vector_store.get(where={"source": "obsidian"})
|
||||||
|
if existing_results.get("ids"):
|
||||||
|
await vector_store.adelete(existing_results["ids"])
|
||||||
|
|
||||||
# Split and index documents
|
# Split and index documents
|
||||||
splits = text_splitter.split_documents(documents)
|
splits = text_splitter.split_documents(documents)
|
||||||
vector_store = _get_vector_store()
|
|
||||||
await vector_store.aadd_documents(documents=splits)
|
await vector_store.aadd_documents(documents=splits)
|
||||||
|
|
||||||
return {"indexed": len(documents)}
|
return {"indexed": len(documents)}
|
||||||
|
|
||||||
|
|
||||||
async def query_vector_store(query: str):
|
async def query_vector_store(query: str):
|
||||||
vector_store = _get_vector_store()
|
|
||||||
retrieved_docs = await vector_store.asimilarity_search(query, k=2)
|
retrieved_docs = await vector_store.asimilarity_search(query, k=2)
|
||||||
serialized = "\n\n".join(
|
serialized = "\n\n".join(
|
||||||
(f"Source: {doc.metadata}\nContent: {doc.page_content}")
|
(f"Source: {doc.metadata}\nContent: {doc.page_content}")
|
||||||
@@ -186,79 +137,32 @@ async def query_vector_store(query: str):
|
|||||||
return serialized, retrieved_docs
|
return serialized, retrieved_docs
|
||||||
|
|
||||||
|
|
||||||
def delete_all_documents():
|
|
||||||
"""Delete all documents from the vector store collection."""
|
|
||||||
collection_id = _get_collection_id()
|
|
||||||
if not collection_id:
|
|
||||||
return
|
|
||||||
engine = _get_engine()
|
|
||||||
with engine.connect() as conn:
|
|
||||||
conn.execute(
|
|
||||||
text("DELETE FROM langchain_pg_embedding WHERE collection_id = :cid"),
|
|
||||||
{"cid": collection_id},
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def delete_documents_by_metadata(key: str, value: str):
|
|
||||||
"""Delete documents matching a metadata key/value pair."""
|
|
||||||
collection_id = _get_collection_id()
|
|
||||||
if not collection_id:
|
|
||||||
return
|
|
||||||
engine = _get_engine()
|
|
||||||
with engine.connect() as conn:
|
|
||||||
conn.execute(
|
|
||||||
text(
|
|
||||||
"DELETE FROM langchain_pg_embedding "
|
|
||||||
"WHERE collection_id = :cid AND cmetadata->>:key = :value"
|
|
||||||
),
|
|
||||||
{"cid": collection_id, "key": key, "value": value},
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def get_vector_store_stats():
|
def get_vector_store_stats():
|
||||||
"""Get statistics about the vector store."""
|
"""Get statistics about the vector store."""
|
||||||
collection_id = _get_collection_id()
|
collection = vector_store._collection
|
||||||
count = 0
|
count = collection.count()
|
||||||
if collection_id:
|
|
||||||
engine = _get_engine()
|
|
||||||
with engine.connect() as conn:
|
|
||||||
result = conn.execute(
|
|
||||||
text(
|
|
||||||
"SELECT COUNT(*) FROM langchain_pg_embedding WHERE collection_id = :cid"
|
|
||||||
),
|
|
||||||
{"cid": collection_id},
|
|
||||||
)
|
|
||||||
count = result.scalar()
|
|
||||||
return {
|
return {
|
||||||
"total_documents": count,
|
"total_documents": count,
|
||||||
"collection_name": "simba_docs",
|
"collection_name": collection.name,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def list_all_documents(limit: int = 10):
|
def list_all_documents(limit: int = 10):
|
||||||
"""List documents in the vector store with their metadata."""
|
"""List documents in the vector store with their metadata."""
|
||||||
collection_id = _get_collection_id()
|
collection = vector_store._collection
|
||||||
if not collection_id:
|
results = collection.get(limit=limit, include=["metadatas", "documents"])
|
||||||
return []
|
|
||||||
|
|
||||||
engine = _get_engine()
|
|
||||||
with engine.connect() as conn:
|
|
||||||
result = conn.execute(
|
|
||||||
text(
|
|
||||||
"SELECT id, document, cmetadata FROM langchain_pg_embedding "
|
|
||||||
"WHERE collection_id = :cid LIMIT :limit"
|
|
||||||
),
|
|
||||||
{"cid": collection_id, "limit": limit},
|
|
||||||
)
|
|
||||||
documents = []
|
documents = []
|
||||||
for row in result:
|
for i, doc_id in enumerate(results["ids"]):
|
||||||
documents.append(
|
documents.append(
|
||||||
{
|
{
|
||||||
"id": str(row[0]),
|
"id": doc_id,
|
||||||
"metadata": row[2],
|
"metadata": results["metadatas"][i]
|
||||||
"content_preview": row[1][:200] if row[1] else None,
|
if results.get("metadatas")
|
||||||
|
else None,
|
||||||
|
"content_preview": results["documents"][i][:200]
|
||||||
|
if results.get("documents")
|
||||||
|
else None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class OIDCUserService:
|
|||||||
claims.get("preferred_username") or claims.get("name") or user.username
|
claims.get("preferred_username") or claims.get("name") or user.username
|
||||||
)
|
)
|
||||||
# Update LDAP groups from claims
|
# Update LDAP groups from claims
|
||||||
user.ldap_groups = claims.get("groups") or []
|
user.ldap_groups = claims.get("groups", [])
|
||||||
await user.save()
|
await user.save()
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@@ -48,7 +48,7 @@ class OIDCUserService:
|
|||||||
user.oidc_subject = oidc_subject
|
user.oidc_subject = oidc_subject
|
||||||
user.auth_provider = "oidc"
|
user.auth_provider = "oidc"
|
||||||
user.password = None # Clear password
|
user.password = None # Clear password
|
||||||
user.ldap_groups = claims.get("groups") or []
|
user.ldap_groups = claims.get("groups", [])
|
||||||
await user.save()
|
await user.save()
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@@ -61,7 +61,7 @@ class OIDCUserService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Extract LDAP groups from claims
|
# Extract LDAP groups from claims
|
||||||
groups = claims.get("groups") or []
|
groups = claims.get("groups", [])
|
||||||
|
|
||||||
user = await User.create(
|
user = await User.create(
|
||||||
id=uuid4(),
|
id=uuid4(),
|
||||||
|
|||||||
+4
-2
@@ -2,7 +2,7 @@ version: "3.8"
|
|||||||
|
|
||||||
services:
|
services:
|
||||||
postgres:
|
postgres:
|
||||||
image: pgvector/pgvector:pg16
|
image: postgres:16-alpine
|
||||||
ports:
|
ports:
|
||||||
- "5432:5432"
|
- "5432:5432"
|
||||||
environment:
|
environment:
|
||||||
@@ -11,7 +11,6 @@ services:
|
|||||||
- POSTGRES_DB=${POSTGRES_DB:-raggr}
|
- POSTGRES_DB=${POSTGRES_DB:-raggr}
|
||||||
volumes:
|
volumes:
|
||||||
- postgres_data:/var/lib/postgresql/data
|
- postgres_data:/var/lib/postgresql/data
|
||||||
- ./docker/init-pgvector.sql:/docker-entrypoint-initdb.d/init-pgvector.sql
|
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-raggr}"]
|
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-raggr}"]
|
||||||
interval: 10s
|
interval: 10s
|
||||||
@@ -30,6 +29,7 @@ services:
|
|||||||
- PAPERLESS_TOKEN=${PAPERLESS_TOKEN}
|
- PAPERLESS_TOKEN=${PAPERLESS_TOKEN}
|
||||||
- BASE_URL=${BASE_URL}
|
- BASE_URL=${BASE_URL}
|
||||||
- OLLAMA_URL=${OLLAMA_URL:-http://localhost:11434}
|
- OLLAMA_URL=${OLLAMA_URL:-http://localhost:11434}
|
||||||
|
- CHROMADB_PATH=/app/data/chromadb
|
||||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||||
- JWT_SECRET_KEY=${JWT_SECRET_KEY}
|
- JWT_SECRET_KEY=${JWT_SECRET_KEY}
|
||||||
- LLAMA_SERVER_URL=${LLAMA_SERVER_URL}
|
- LLAMA_SERVER_URL=${LLAMA_SERVER_URL}
|
||||||
@@ -66,8 +66,10 @@ services:
|
|||||||
postgres:
|
postgres:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
volumes:
|
volumes:
|
||||||
|
- chromadb_data:/app/data/chromadb
|
||||||
- ./obvault:/app/data/obsidian
|
- ./obvault:/app/data/obsidian
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
|
chromadb_data:
|
||||||
postgres_data:
|
postgres_data:
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
CREATE EXTENSION IF NOT EXISTS vector;
|
|
||||||
@@ -0,0 +1,278 @@
|
|||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
from utils.chunker import Chunker
|
||||||
|
from utils.cleaner import pdf_to_image, summarize_pdf_image
|
||||||
|
from llm import LLMClient
|
||||||
|
from scripts.query import QueryGenerator
|
||||||
|
from utils.request import PaperlessNGXService
|
||||||
|
|
||||||
|
_dotenv_loaded = load_dotenv()
|
||||||
|
|
||||||
|
client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
|
||||||
|
simba_docs = client.get_or_create_collection(name="simba_docs2")
|
||||||
|
feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup")
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="An LLM tool to query information about Simba <3"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("query", type=str, help="questions about simba's health")
|
||||||
|
parser.add_argument(
|
||||||
|
"--reindex", action="store_true", help="re-index the simba documents"
|
||||||
|
)
|
||||||
|
parser.add_argument("--classify", action="store_true", help="test classification")
|
||||||
|
parser.add_argument("--index", help="index a file")
|
||||||
|
|
||||||
|
ppngx = PaperlessNGXService()
|
||||||
|
|
||||||
|
llm_client = LLMClient()
|
||||||
|
|
||||||
|
|
||||||
|
def index_using_pdf_llm(doctypes):
|
||||||
|
logging.info("reindex data...")
|
||||||
|
files = ppngx.get_data()
|
||||||
|
for file in files:
|
||||||
|
document_id: int = file["id"]
|
||||||
|
pdf_path = ppngx.download_pdf_from_id(id=document_id)
|
||||||
|
image_paths = pdf_to_image(filepath=pdf_path)
|
||||||
|
logging.info(f"summarizing {file}")
|
||||||
|
generated_summary = summarize_pdf_image(filepaths=image_paths)
|
||||||
|
file["content"] = generated_summary
|
||||||
|
|
||||||
|
chunk_data(files, simba_docs, doctypes=doctypes)
|
||||||
|
|
||||||
|
|
||||||
|
def date_to_epoch(date_str: str) -> float:
|
||||||
|
split_date = date_str.split("-")
|
||||||
|
date = datetime.datetime(
|
||||||
|
int(split_date[0]),
|
||||||
|
int(split_date[1]),
|
||||||
|
int(split_date[2]),
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
return date.timestamp()
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_data(docs, collection, doctypes):
|
||||||
|
# Step 2: Create chunks
|
||||||
|
chunker = Chunker(collection)
|
||||||
|
|
||||||
|
logging.info(f"chunking {len(docs)} documents")
|
||||||
|
texts: list[str] = [doc["content"] for doc in docs]
|
||||||
|
with sqlite3.connect("database/visited.db") as conn:
|
||||||
|
to_insert = []
|
||||||
|
c = conn.cursor()
|
||||||
|
for index, text in enumerate(texts):
|
||||||
|
metadata = {
|
||||||
|
"created_date": date_to_epoch(docs[index]["created_date"]),
|
||||||
|
"filename": docs[index]["original_file_name"],
|
||||||
|
"document_type": doctypes.get(docs[index]["document_type"], ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
if doctypes:
|
||||||
|
metadata["type"] = doctypes.get(docs[index]["document_type"])
|
||||||
|
|
||||||
|
chunker.chunk_document(
|
||||||
|
document=text,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
to_insert.append((docs[index]["id"],))
|
||||||
|
|
||||||
|
c.executemany(
|
||||||
|
"INSERT INTO indexed_documents (paperless_id) values (?)", to_insert
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_text(texts: list[str], collection):
|
||||||
|
chunker = Chunker(collection)
|
||||||
|
|
||||||
|
for index, text in enumerate(texts):
|
||||||
|
metadata = {}
|
||||||
|
chunker.chunk_document(
|
||||||
|
document=text,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def classify_query(query: str, transcript: str) -> bool:
|
||||||
|
logging.info("Starting query generation")
|
||||||
|
qg_start = time.time()
|
||||||
|
qg = QueryGenerator()
|
||||||
|
query_type = qg.get_query_type(input=query, transcript=transcript)
|
||||||
|
logging.info(query_type)
|
||||||
|
qg_end = time.time()
|
||||||
|
logging.info(f"Query generation took {qg_end - qg_start:.2f} seconds")
|
||||||
|
return query_type == "Simba"
|
||||||
|
|
||||||
|
|
||||||
|
def consult_oracle(
|
||||||
|
input: str,
|
||||||
|
collection,
|
||||||
|
transcript: str = "",
|
||||||
|
):
|
||||||
|
chunker = Chunker(collection)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Ask
|
||||||
|
logging.info("Starting query generation")
|
||||||
|
qg_start = time.time()
|
||||||
|
qg = QueryGenerator()
|
||||||
|
doctype_query = qg.get_doctype_query(input=input)
|
||||||
|
# metadata_filter = qg.get_query(input)
|
||||||
|
metadata_filter = {**doctype_query}
|
||||||
|
logging.info(metadata_filter)
|
||||||
|
qg_end = time.time()
|
||||||
|
logging.info(f"Query generation took {qg_end - qg_start:.2f} seconds")
|
||||||
|
|
||||||
|
logging.info("Starting embedding generation")
|
||||||
|
embedding_start = time.time()
|
||||||
|
embeddings = chunker.embedding_fx(inputs=[input])
|
||||||
|
embedding_end = time.time()
|
||||||
|
logging.info(
|
||||||
|
f"Embedding generation took {embedding_end - embedding_start:.2f} seconds"
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Starting collection query")
|
||||||
|
query_start = time.time()
|
||||||
|
results = collection.query(
|
||||||
|
query_texts=[input],
|
||||||
|
query_embeddings=embeddings,
|
||||||
|
where=metadata_filter,
|
||||||
|
)
|
||||||
|
query_end = time.time()
|
||||||
|
logging.info(f"Collection query took {query_end - query_start:.2f} seconds")
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
logging.info("Starting LLM generation")
|
||||||
|
llm_start = time.time()
|
||||||
|
system_prompt = "You are a helpful assistant that understands veterinary terms."
|
||||||
|
transcript_prompt = f"Here is the message transcript thus far {transcript}."
|
||||||
|
prompt = f"""Using the following data, help answer the user's query by providing as many details as possible.
|
||||||
|
Using this data: {results}. {transcript_prompt if len(transcript) > 0 else ""}
|
||||||
|
Respond to this prompt: {input}"""
|
||||||
|
output = llm_client.chat(prompt=prompt, system_prompt=system_prompt)
|
||||||
|
llm_end = time.time()
|
||||||
|
logging.info(f"LLM generation took {llm_end - llm_start:.2f} seconds")
|
||||||
|
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
logging.info(f"Total consult_oracle execution took {total_time:.2f} seconds")
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def llm_chat(input: str, transcript: str = "") -> str:
|
||||||
|
system_prompt = "You are a helpful assistant that understands veterinary terms."
|
||||||
|
transcript_prompt = f"Here is the message transcript thus far {transcript}."
|
||||||
|
prompt = f"""Answer the user in as if you were a cat named Simba. Don't act too catlike. Be assertive.
|
||||||
|
{transcript_prompt if len(transcript) > 0 else ""}
|
||||||
|
Respond to this prompt: {input}"""
|
||||||
|
output = llm_client.chat(prompt=prompt, system_prompt=system_prompt)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def paperless_workflow(input):
|
||||||
|
# Step 1: Get the text
|
||||||
|
ppngx = PaperlessNGXService()
|
||||||
|
docs = ppngx.get_data()
|
||||||
|
|
||||||
|
chunk_data(docs, collection=simba_docs)
|
||||||
|
consult_oracle(input, simba_docs)
|
||||||
|
|
||||||
|
|
||||||
|
def consult_simba_oracle(input: str, transcript: str = ""):
|
||||||
|
is_simba_related = classify_query(query=input, transcript=transcript)
|
||||||
|
|
||||||
|
if is_simba_related:
|
||||||
|
logging.info("Query is related to simba")
|
||||||
|
return consult_oracle(
|
||||||
|
input=input,
|
||||||
|
collection=simba_docs,
|
||||||
|
transcript=transcript,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Query is NOT related to simba")
|
||||||
|
|
||||||
|
return llm_chat(input=input, transcript=transcript)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_indexed_files(docs):
|
||||||
|
with sqlite3.connect("database/visited.db") as conn:
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)"
|
||||||
|
)
|
||||||
|
c.execute("SELECT paperless_id FROM indexed_documents")
|
||||||
|
rows = c.fetchall()
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
visited = {row[0] for row in rows}
|
||||||
|
return [doc for doc in docs if doc["id"] not in visited]
|
||||||
|
|
||||||
|
|
||||||
|
def reindex():
|
||||||
|
with sqlite3.connect("database/visited.db") as conn:
|
||||||
|
c = conn.cursor()
|
||||||
|
# Ensure the table exists before trying to delete from it
|
||||||
|
c.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)"
|
||||||
|
)
|
||||||
|
c.execute("DELETE FROM indexed_documents")
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
# Delete all documents from the collection
|
||||||
|
all_docs = simba_docs.get()
|
||||||
|
if all_docs["ids"]:
|
||||||
|
simba_docs.delete(ids=all_docs["ids"])
|
||||||
|
|
||||||
|
logging.info("Fetching documents from Paperless-NGX")
|
||||||
|
ppngx = PaperlessNGXService()
|
||||||
|
docs = ppngx.get_data()
|
||||||
|
docs = filter_indexed_files(docs)
|
||||||
|
logging.info(f"Fetched {len(docs)} documents")
|
||||||
|
|
||||||
|
# Delete all chromadb data
|
||||||
|
ids = simba_docs.get(ids=None, limit=None, offset=0)
|
||||||
|
all_ids = ids["ids"]
|
||||||
|
if len(all_ids) > 0:
|
||||||
|
simba_docs.delete(ids=all_ids)
|
||||||
|
|
||||||
|
# Chunk documents
|
||||||
|
logging.info("Chunking documents now ...")
|
||||||
|
doctype_lookup = ppngx.get_doctypes()
|
||||||
|
chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
|
||||||
|
logging.info("Done chunking documents")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.reindex:
|
||||||
|
reindex()
|
||||||
|
|
||||||
|
if args.classify:
|
||||||
|
consult_simba_oracle(input="yohohoho testing")
|
||||||
|
consult_simba_oracle(input="write an email")
|
||||||
|
consult_simba_oracle(input="how much does simba weigh")
|
||||||
|
|
||||||
|
if args.query:
|
||||||
|
logging.info("Consulting oracle ...")
|
||||||
|
print(
|
||||||
|
consult_oracle(
|
||||||
|
input=args.query,
|
||||||
|
collection=simba_docs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.info("please provide a query")
|
||||||
@@ -1,112 +0,0 @@
|
|||||||
from tortoise import BaseDBAsyncClient
|
|
||||||
|
|
||||||
RUN_IN_TRANSACTION = True
|
|
||||||
|
|
||||||
|
|
||||||
async def upgrade(db: BaseDBAsyncClient) -> str:
|
|
||||||
return """
|
|
||||||
CREATE TABLE IF NOT EXISTS "user_memories" (
|
|
||||||
"id" UUID NOT NULL PRIMARY KEY,
|
|
||||||
"content" TEXT NOT NULL,
|
|
||||||
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
"user_id" UUID NOT NULL REFERENCES "users" ("id") ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
CREATE TABLE IF NOT EXISTS "email_accounts" (
|
|
||||||
"id" UUID NOT NULL PRIMARY KEY,
|
|
||||||
"email_address" VARCHAR(255) NOT NULL UNIQUE,
|
|
||||||
"display_name" VARCHAR(255),
|
|
||||||
"imap_host" VARCHAR(255) NOT NULL,
|
|
||||||
"imap_port" INT NOT NULL DEFAULT 993,
|
|
||||||
"imap_username" VARCHAR(255) NOT NULL,
|
|
||||||
"imap_password" TEXT NOT NULL,
|
|
||||||
"is_active" BOOL NOT NULL DEFAULT True,
|
|
||||||
"last_error" TEXT,
|
|
||||||
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
"user_id" UUID NOT NULL REFERENCES "users" ("id") ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
COMMENT ON TABLE "email_accounts" IS 'Email account configuration for IMAP connections.';
|
|
||||||
CREATE TABLE IF NOT EXISTS "emails" (
|
|
||||||
"id" UUID NOT NULL PRIMARY KEY,
|
|
||||||
"message_id" VARCHAR(255) NOT NULL UNIQUE,
|
|
||||||
"subject" VARCHAR(500) NOT NULL,
|
|
||||||
"from_address" VARCHAR(255) NOT NULL,
|
|
||||||
"to_address" TEXT NOT NULL,
|
|
||||||
"date" TIMESTAMPTZ NOT NULL,
|
|
||||||
"body_text" TEXT,
|
|
||||||
"body_html" TEXT,
|
|
||||||
"chromadb_doc_id" VARCHAR(255),
|
|
||||||
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
"expires_at" TIMESTAMPTZ NOT NULL,
|
|
||||||
"account_id" UUID NOT NULL REFERENCES "email_accounts" ("id") ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS "idx_emails_message_981ddd" ON "emails" ("message_id");
|
|
||||||
COMMENT ON TABLE "emails" IS 'Email message metadata and content.';
|
|
||||||
CREATE TABLE IF NOT EXISTS "email_sync_status" (
|
|
||||||
"id" UUID NOT NULL PRIMARY KEY,
|
|
||||||
"last_sync_date" TIMESTAMPTZ,
|
|
||||||
"last_message_uid" INT NOT NULL DEFAULT 0,
|
|
||||||
"message_count" INT NOT NULL DEFAULT 0,
|
|
||||||
"consecutive_failures" INT NOT NULL DEFAULT 0,
|
|
||||||
"last_failure_date" TIMESTAMPTZ,
|
|
||||||
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
"account_id" UUID NOT NULL REFERENCES "email_accounts" ("id") ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
COMMENT ON TABLE "email_sync_status" IS 'Tracks sync progress and state per email account.';"""
|
|
||||||
|
|
||||||
|
|
||||||
async def downgrade(db: BaseDBAsyncClient) -> str:
|
|
||||||
return """
|
|
||||||
DROP TABLE IF EXISTS "user_memories";
|
|
||||||
DROP TABLE IF EXISTS "email_accounts";
|
|
||||||
DROP TABLE IF EXISTS "emails";
|
|
||||||
DROP TABLE IF EXISTS "email_sync_status";"""
|
|
||||||
|
|
||||||
|
|
||||||
MODELS_STATE = (
|
|
||||||
"eJztXGtv2zYU/SuCPrVAFjTPbcUwwE7czVudDLGz9ZFCoCXa1ixRGkk1NYr+911Skq0HZV"
|
|
||||||
"t+RUr1oU1C8lLU4SV57tGVvuquZ2GHHV955DOmDHHbI/pr7atOkIvhF2X9kaYj31/UigKO"
|
|
||||||
"ho40MBMtZQ0aMk6RyaFyhByGocjCzKS2H12MBI4jCj0TGtpkvCgKiP1fgA3ujTGfYAoVHz"
|
|
||||||
"9BsU0s/AWz+E9/aoxs7FipcduWuLYsN/jMl2X3993rN7KluNzQMD0ncMmitT/jE4/MmweB"
|
|
||||||
"bR0LG1E3xgRTxLGVuA0xyui246JwxFDAaYDnQ7UWBRYeocARYOi/jAJiCgw0eSXx3/mveg"
|
|
||||||
"l4AGoBrU24wOLrt/CuFvcsS3VxqavfW3cvzi5fyrv0GB9TWSkR0b9JQ8RRaCpxXQApf+ag"
|
|
||||||
"vJogqoYybp8BEwa6CYxxwQLHhQ/FQMYAbYaa7qIvhoPJmE/gz9OLiyUw/t26k0hCKwmlB3"
|
|
||||||
"4dev1NVHUa1glIFxCaFItbNhDPA3kNNdx2sRrMtGUGUisyPY5/qSjAcA/WLXFm0SJYgu+g"
|
|
||||||
"2+v0B63eX+JOXMb+cyRErUFH1JzK0lmm9MVlZirmnWj/dAe/a+JP7cPtTSfr+/N2gw+6GB"
|
|
||||||
"MKuGcQ79FAVmK9xqUxMKmJDXxrw4lNWzYT+6QTGw0+Ma8MU6PcCZIw2eIYicZ2wEnc/NAQ"
|
|
||||||
"R+9oqjwzBBh58N54FNtj8ieeSQi7MA5ETNVhEZGO+6ibqoK2KF2MgqLHORtJOgXcHdwT5u"
|
|
||||||
"Hp2epfta47usRwiMzpI6KWUQCmixlDY8zygLYjyzd/3mFnTs3UWCYJXC/ssZq7ShG2Eivv"
|
|
||||||
"1EtglEIvX+WeutkSROC+reja4kpL0FnBghMgrkeGjeRENqS41qSY4y+KI38ApWoo4/Z1Ic"
|
|
||||||
"XLjvLOu0HqFI+p74te693L1En+9vbmt7h5gipfvb1tNwz5ORKpPENmPkZTFRkQAWSHBG6O"
|
|
||||||
"CqRmN2H+xEtHv+937l5r4kR/IP1ur916rTHbHSJ9vSlORZknr9YIMk9eFcaYoiq9gGwXTh"
|
|
||||||
"ZjimdlQvWU0Ub4Hp56pYG8ODldA0loVQilrMtsRslDu9yRqTDd5flZ03DAzIiHW4YFWS2y"
|
|
||||||
"siiujA8U7lI2TtgnKxbxVw+7Hp3pCjKcqF3KgWUQ5IqGdsN9nwH3hYtwTErR34RJw4AbBv"
|
|
||||||
"xdMeBGI34WE1sdjbham2FdROIKs8AtVOJ9s78i3rea8TVMr/5MT8xj2cf/SZu6cL0DpAD4"
|
|
||||||
"iLFHjyo8s20TRGdqMJNWGTCHMx5GU5VTaJaA1xa8N3m6A2Tt7k3r7r2aOsftk37bfj/otD"
|
|
||||||
"LoYhfZThkvnRvsxkVXr/hdOujJq/Xkw2X6YU5AfJwgzmBLN0jgDosEWzWYCtOdiImHRfVs"
|
|
||||||
"HVDPijE9y0EqnczARNyeauF7noMRWeKgSdvs8gfjfW2mZY/qEuv/9vZtav23u9nQ+L7X7o"
|
|
||||||
"DzSpihkR1Soe7NQAnuxEUmcIQpVuiKK1Z/xraGHntyuc42kI2QErvAZdZjPdsyDRYM/8Wm"
|
|
||||||
"IlotBjRrV0Mw93LqQ/w4MXzqfbatcltqzvBwVEp3PBM5W3DRzBOadbbVi+Jt9SK3rToW8o"
|
|
||||||
"0x9QJfkRLzR//2Rg1pxiwD6D2Bu/xo2SY/0hyb8U97g/fjp/3wfHHny1XJrACZIVaig0aV"
|
|
||||||
"fJbiVaNKPtOJnSfG5VShVVmFudc0dpNaWOWINJ9SmFwRySeUm2ORfihaPc9fC4qQHyPT9A"
|
|
||||||
"JhthUgHdFXK+yqZpDsU1yVsOgKdbUTKxPF8qqcnvX0VV12p0WZp/CTI6H2aYhYWvRQ9ljP"
|
|
||||||
"oLSOzQN5IH3uwbal+YgybGlyUJps+GjzCYTTP1hoplEs2sNgjrW3NpkyjXva1YR6LrpuP5"
|
|
||||||
"CRR7XPEDPAD4YRNSeaiXw0tCHsg5UoR9bowDvih1vowJErKJ92Fccwaas6Cm17iQk3CK+3"
|
|
||||||
"jayfXFK/WEuxvFiiWF7kFcsR7CKCGIEfK86oYjSzdvWEdC++CbyyENAlye1eHeE8dIKPiI"
|
|
||||||
"jKxlqxTT2jrJpEVfFtL42Xh541M8q+9ZEyqkl+9aGXhcRowl3F47sVwMZGDbDqhELJsuGK"
|
|
||||||
"MLSSzE1hWhOQm5f5G+VsU0kUf/Ft6G2DiU1b1nNiazKRax3WkXJVMjszbdUkaMaA5DEsna"
|
|
||||||
"NZXxHwKJOrmXaSKqVrpjAuEhYTc7BCX0zJv+vqjJGNkAlH9jigUhvWhMrX7bX+EsUES7mL"
|
|
||||||
"FamOJXpIaJBzK4otITfCGEMVEhOTznxwNC1OpTtK9KExzDlcnh09EKFuxt2AO/OAHWv9wP"
|
|
||||||
"c9ypnmgovZvoPjFkzzMZXvgjYaZUU0yshpy8tBOcNGqYwVC5v5DpoZZVOAs3ZN7JB4S9s3"
|
|
||||||
"JuDYZeBMGdVFXTsUmGJ/zoPZJQXCQcomg6W9P27y889nW0Apv0Xzw+nJ+Y/nP51dnv8ETe"
|
|
||||||
"RQ5iU/LgE3nzkpMdgktT9n2DhjxhkLk/w7MQ8p1rRyPdQF3UMLWzYE2sBHPit8d2lKdcru"
|
|
||||||
"gOnU84O/wtnUDmLcwJR6iizVYpdNW9XkmG9e7G70wiaFspnY5sXu5sXu6r7YnRE2dpGEWS"
|
|
||||||
"8s0zlTM2IaoSq3AyD60Ft/3lmNINm7fJxApkhBToO3SkTOTNxqHXkA1VOmCTvNp57YdJjM"
|
|
||||||
"PBWdYCm74qRQnNeRS/cgdOSegBn+MU1w2tBYnL1g4/pHYSF0ZkJf2JqnxsJOGCnHI+gwoF"
|
|
||||||
"gTdzeFgYg0Vxaqx5oNsR92MfTvhB0LA8matQn86kDzRkWuiIosIxrptJuka+Wtd8DtqhUh"
|
|
||||||
"VYjKrfUoWE5JnIkcqNZJoVaoMj2cZPhqC2q+Y8EwxqDgaXAhgDm77xI90T02AyE8GdExoS"
|
|
||||||
"AxhSAWmX+XWMolGaGw+Q6d7aDZpJ94k26UlOeppDR5WM8sD2vfSQ31z8JqYWqbE10RPUc1"
|
|
||||||
"R8uCZrRoU5kv5xU/S1+TEUcT+KTJMjvhIcVho3j9Xflt8+KH6QmTujzoPcQHc2BplAAxal"
|
|
||||||
"5PAPfyGbfCj3MXfxin+OPcB/sozt4O3Z19FKfENzZ2f7x8+x8fHBMe"
|
|
||||||
)
|
|
||||||
+2
-13
@@ -5,8 +5,7 @@ description = "Add your description here"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"langchain-postgres>=0.0.13",
|
"chromadb>=1.1.0",
|
||||||
"psycopg[binary]>=3.1.0",
|
|
||||||
"python-dotenv>=1.0.0",
|
"python-dotenv>=1.0.0",
|
||||||
"flask>=3.1.2",
|
"flask>=3.1.2",
|
||||||
"httpx>=0.28.1",
|
"httpx>=0.28.1",
|
||||||
@@ -31,6 +30,7 @@ dependencies = [
|
|||||||
"asyncpg>=0.30.0",
|
"asyncpg>=0.30.0",
|
||||||
"langchain-openai>=1.1.6",
|
"langchain-openai>=1.1.6",
|
||||||
"langchain>=1.2.0",
|
"langchain>=1.2.0",
|
||||||
|
"langchain-chroma>=1.0.0",
|
||||||
"langchain-community>=0.4.1",
|
"langchain-community>=0.4.1",
|
||||||
"jq>=1.10.0",
|
"jq>=1.10.0",
|
||||||
"tavily-python>=0.7.17",
|
"tavily-python>=0.7.17",
|
||||||
@@ -42,17 +42,6 @@ dependencies = [
|
|||||||
"aioboto3>=13.0.0",
|
"aioboto3>=13.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
|
||||||
test = [
|
|
||||||
"pytest>=8.0.0",
|
|
||||||
"pytest-asyncio>=0.25.0",
|
|
||||||
"pytest-cov>=6.0.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
|
||||||
testpaths = ["tests"]
|
|
||||||
asyncio_mode = "auto"
|
|
||||||
|
|
||||||
[tool.aerich]
|
[tool.aerich]
|
||||||
tortoise_orm = "config.db.TORTOISE_CONFIG"
|
tortoise_orm = "config.db.TORTOISE_CONFIG"
|
||||||
location = "./migrations"
|
location = "./migrations"
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { useCallback, useEffect, useState, useRef } from "react";
|
import { useEffect, useState, useRef } from "react";
|
||||||
import { LogOut, Shield, PanelLeftClose, PanelLeftOpen, Menu, X } from "lucide-react";
|
import { LogOut, Shield, PanelLeftClose, PanelLeftOpen, Menu, X } from "lucide-react";
|
||||||
import { conversationService } from "../api/conversationService";
|
import { conversationService } from "../api/conversationService";
|
||||||
import { userService } from "../api/userService";
|
import { userService } from "../api/userService";
|
||||||
@@ -63,13 +63,9 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
|||||||
const abortControllerRef = useRef<AbortController | null>(null);
|
const abortControllerRef = useRef<AbortController | null>(null);
|
||||||
const simbaAnswers = ["meow.", "hiss...", "purrrrrr", "yowOWROWWowowr"];
|
const simbaAnswers = ["meow.", "hiss...", "purrrrrr", "yowOWROWWowowr"];
|
||||||
|
|
||||||
const scrollToBottom = useCallback(() => {
|
const scrollToBottom = () => {
|
||||||
requestAnimationFrame(() => {
|
messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
|
||||||
messagesEndRef.current?.scrollIntoView({
|
};
|
||||||
behavior: isLoading ? "instant" : "smooth",
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}, [isLoading]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
isMountedRef.current = true;
|
isMountedRef.current = true;
|
||||||
@@ -120,7 +116,21 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
|||||||
scrollToBottom();
|
scrollToBottom();
|
||||||
}, [messages]);
|
}, [messages]);
|
||||||
|
|
||||||
const handleQuestionSubmit = useCallback(async () => {
|
useEffect(() => {
|
||||||
|
const load = async () => {
|
||||||
|
if (!selectedConversation) return;
|
||||||
|
try {
|
||||||
|
const conv = await conversationService.getConversation(selectedConversation.id);
|
||||||
|
setSelectedConversation({ id: conv.id, title: conv.name });
|
||||||
|
setMessages(conv.messages.map((m) => ({ text: m.text, speaker: m.speaker, image_key: m.image_key })));
|
||||||
|
} catch (err) {
|
||||||
|
console.error("Failed to load messages:", err);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
load();
|
||||||
|
}, [selectedConversation?.id]);
|
||||||
|
|
||||||
|
const handleQuestionSubmit = async () => {
|
||||||
if ((!query.trim() && !pendingImage) || isLoading) return;
|
if ((!query.trim() && !pendingImage) || isLoading) return;
|
||||||
|
|
||||||
let activeConversation = selectedConversation;
|
let activeConversation = selectedConversation;
|
||||||
@@ -201,28 +211,22 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
if (isMountedRef.current) {
|
if (isMountedRef.current) setIsLoading(false);
|
||||||
setIsLoading(false);
|
|
||||||
loadConversations();
|
|
||||||
}
|
|
||||||
abortControllerRef.current = null;
|
abortControllerRef.current = null;
|
||||||
}
|
}
|
||||||
}, [query, pendingImage, isLoading, selectedConversation, simbaMode, messages, setAuthenticated]);
|
};
|
||||||
|
|
||||||
const handleQueryChange = useCallback((event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
const handleQueryChange = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||||
setQuery(event.target.value);
|
setQuery(event.target.value);
|
||||||
}, []);
|
};
|
||||||
|
|
||||||
const handleKeyDown = useCallback((event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
const handleKeyDown = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||||
const kev = event as unknown as React.KeyboardEvent<HTMLTextAreaElement>;
|
const kev = event as unknown as React.KeyboardEvent<HTMLTextAreaElement>;
|
||||||
if (kev.key === "Enter" && !kev.shiftKey) {
|
if (kev.key === "Enter" && !kev.shiftKey) {
|
||||||
kev.preventDefault();
|
kev.preventDefault();
|
||||||
handleQuestionSubmit();
|
handleQuestionSubmit();
|
||||||
}
|
}
|
||||||
}, [handleQuestionSubmit]);
|
};
|
||||||
|
|
||||||
const handleImageSelect = useCallback((file: File) => setPendingImage(file), []);
|
|
||||||
const handleClearImage = useCallback(() => setPendingImage(null), []);
|
|
||||||
|
|
||||||
const handleLogout = () => {
|
const handleLogout = () => {
|
||||||
localStorage.removeItem("access_token");
|
localStorage.removeItem("access_token");
|
||||||
@@ -376,8 +380,8 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
|||||||
setSimbaMode={setSimbaMode}
|
setSimbaMode={setSimbaMode}
|
||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
pendingImage={pendingImage}
|
pendingImage={pendingImage}
|
||||||
onImageSelect={handleImageSelect}
|
onImageSelect={(file) => setPendingImage(file)}
|
||||||
onClearImage={handleClearImage}
|
onClearImage={() => setPendingImage(null)}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -412,7 +416,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<footer className="border-t border-sand-light/40 bg-cream">
|
<footer className="border-t border-sand-light/40 bg-cream/80 backdrop-blur-sm">
|
||||||
<div className="max-w-2xl mx-auto px-4 py-3">
|
<div className="max-w-2xl mx-auto px-4 py-3">
|
||||||
<MessageInput
|
<MessageInput
|
||||||
query={query}
|
query={query}
|
||||||
@@ -421,9 +425,6 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
|||||||
handleQuestionSubmit={handleQuestionSubmit}
|
handleQuestionSubmit={handleQuestionSubmit}
|
||||||
setSimbaMode={setSimbaMode}
|
setSimbaMode={setSimbaMode}
|
||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
pendingImage={pendingImage}
|
|
||||||
onImageSelect={(file) => setPendingImage(file)}
|
|
||||||
onClearImage={() => setPendingImage(null)}
|
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</footer>
|
</footer>
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import React, { useEffect, useMemo, useRef, useState } from "react";
|
import { useRef, useState } from "react";
|
||||||
import { ArrowUp, ImagePlus, X } from "lucide-react";
|
import { ArrowUp, ImagePlus, X } from "lucide-react";
|
||||||
import { cn } from "../lib/utils";
|
import { cn } from "../lib/utils";
|
||||||
import { Textarea } from "./ui/textarea";
|
import { Textarea } from "./ui/textarea";
|
||||||
@@ -15,7 +15,7 @@ type MessageInputProps = {
|
|||||||
onClearImage: () => void;
|
onClearImage: () => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const MessageInput = React.memo(({
|
export const MessageInput = ({
|
||||||
query,
|
query,
|
||||||
handleKeyDown,
|
handleKeyDown,
|
||||||
handleQueryChange,
|
handleQueryChange,
|
||||||
@@ -29,18 +29,6 @@ export const MessageInput = React.memo(({
|
|||||||
const [simbaMode, setLocalSimbaMode] = useState(false);
|
const [simbaMode, setLocalSimbaMode] = useState(false);
|
||||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||||
|
|
||||||
// Create blob URL once per file, revoke on cleanup
|
|
||||||
const previewUrl = useMemo(
|
|
||||||
() => (pendingImage ? URL.createObjectURL(pendingImage) : null),
|
|
||||||
[pendingImage],
|
|
||||||
);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
return () => {
|
|
||||||
if (previewUrl) URL.revokeObjectURL(previewUrl);
|
|
||||||
};
|
|
||||||
}, [previewUrl]);
|
|
||||||
|
|
||||||
const toggleSimbaMode = () => {
|
const toggleSimbaMode = () => {
|
||||||
const next = !simbaMode;
|
const next = !simbaMode;
|
||||||
setLocalSimbaMode(next);
|
setLocalSimbaMode(next);
|
||||||
@@ -71,7 +59,7 @@ export const MessageInput = React.memo(({
|
|||||||
<div className="px-3 pt-3">
|
<div className="px-3 pt-3">
|
||||||
<div className="relative inline-block">
|
<div className="relative inline-block">
|
||||||
<img
|
<img
|
||||||
src={previewUrl!}
|
src={URL.createObjectURL(pendingImage)}
|
||||||
alt="Pending upload"
|
alt="Pending upload"
|
||||||
className="h-20 rounded-lg object-cover border border-sand"
|
className="h-20 rounded-lg object-cover border border-sand"
|
||||||
/>
|
/>
|
||||||
@@ -157,4 +145,4 @@ export const MessageInput = React.memo(({
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
});
|
};
|
||||||
|
|||||||
@@ -6,19 +6,19 @@ import asyncio
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
from blueprints.rag.logic import (
|
from blueprints.rag.logic import (
|
||||||
delete_all_documents,
|
|
||||||
get_vector_store_stats,
|
get_vector_store_stats,
|
||||||
index_documents,
|
index_documents,
|
||||||
list_all_documents,
|
list_all_documents,
|
||||||
|
vector_store,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def stats():
|
def stats():
|
||||||
"""Show vector store statistics."""
|
"""Show vector store statistics."""
|
||||||
s = get_vector_store_stats()
|
stats = get_vector_store_stats()
|
||||||
print("=== Vector Store Statistics ===")
|
print("=== Vector Store Statistics ===")
|
||||||
print(f"Collection: {s['collection_name']}")
|
print(f"Collection: {stats['collection_name']}")
|
||||||
print(f"Total Documents: {s['total_documents']}")
|
print(f"Total Documents: {stats['total_documents']}")
|
||||||
|
|
||||||
|
|
||||||
async def index():
|
async def index():
|
||||||
@@ -26,15 +26,23 @@ async def index():
|
|||||||
print("Starting indexing process...")
|
print("Starting indexing process...")
|
||||||
print("Fetching documents from Paperless-NGX...")
|
print("Fetching documents from Paperless-NGX...")
|
||||||
await index_documents()
|
await index_documents()
|
||||||
print("Indexing complete!")
|
print("✓ Indexing complete!")
|
||||||
stats()
|
stats()
|
||||||
|
|
||||||
|
|
||||||
async def reindex():
|
async def reindex():
|
||||||
"""Clear and reindex all documents."""
|
"""Clear and reindex all documents."""
|
||||||
print("Clearing existing documents...")
|
print("Clearing existing documents...")
|
||||||
delete_all_documents()
|
collection = vector_store._collection
|
||||||
print("Cleared")
|
all_docs = collection.get()
|
||||||
|
|
||||||
|
if all_docs["ids"]:
|
||||||
|
print(f"Deleting {len(all_docs['ids'])} existing documents...")
|
||||||
|
collection.delete(ids=all_docs["ids"])
|
||||||
|
print("✓ Cleared")
|
||||||
|
else:
|
||||||
|
print("Collection is already empty")
|
||||||
|
|
||||||
await index()
|
await index()
|
||||||
|
|
||||||
|
|
||||||
@@ -105,7 +113,7 @@ Examples:
|
|||||||
print("\n\nOperation cancelled by user")
|
print("\n\nOperation cancelled by user")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\nError: {e}", file=sys.stderr)
|
print(f"\n❌ Error: {e}", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,24 @@
|
|||||||
|
from bs4 import BeautifulSoup
|
||||||
|
import chromadb
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
client = chromadb.PersistentClient(path="/Users/ryanchen/Programs/raggr/chromadb")
|
||||||
|
|
||||||
|
# Scrape
|
||||||
|
BASE_URL = "https://www.vet.cornell.edu"
|
||||||
|
LIST_URL = "/departments-centers-and-institutes/cornell-feline-health-center/health-information/feline-health-topics"
|
||||||
|
|
||||||
|
QUERY_URL = BASE_URL + LIST_URL
|
||||||
|
r = httpx.get(QUERY_URL)
|
||||||
|
soup = BeautifulSoup(r.text)
|
||||||
|
|
||||||
|
container = soup.find("div", class_="field-body")
|
||||||
|
a_s = container.find_all("a", href=True)
|
||||||
|
|
||||||
|
new_texts = []
|
||||||
|
|
||||||
|
for link in a_s:
|
||||||
|
endpoint = link["href"]
|
||||||
|
query_url = BASE_URL + endpoint
|
||||||
|
r2 = httpx.get(query_url)
|
||||||
|
article_soup = BeautifulSoup(r2.text)
|
||||||
@@ -1,6 +1,9 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
|
echo "Initializing directories..."
|
||||||
|
mkdir -p /app/data/chromadb
|
||||||
|
|
||||||
echo "Rebuilding frontend..."
|
echo "Rebuilding frontend..."
|
||||||
cd /app/raggr-frontend
|
cd /app/raggr-frontend
|
||||||
yarn build
|
yarn build
|
||||||
|
|||||||
@@ -1,11 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# Ensure project root is on the path so imports work
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
|
||||||
|
|
||||||
# Set FERNET_KEY for tests that import email models (EncryptedTextField needs it at import time)
|
|
||||||
if "FERNET_KEY" not in os.environ:
|
|
||||||
from cryptography.fernet import Fernet
|
|
||||||
|
|
||||||
os.environ["FERNET_KEY"] = Fernet.generate_key().decode()
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
"""Tests for encryption/decryption in blueprints/email/crypto_service.py."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from cryptography.fernet import Fernet
|
|
||||||
|
|
||||||
|
|
||||||
# Generate a valid key for testing
|
|
||||||
TEST_FERNET_KEY = Fernet.generate_key().decode()
|
|
||||||
|
|
||||||
|
|
||||||
class TestEncryptedTextField:
|
|
||||||
@pytest.fixture
|
|
||||||
def field(self):
|
|
||||||
with patch.dict(os.environ, {"FERNET_KEY": TEST_FERNET_KEY}):
|
|
||||||
from blueprints.email.crypto_service import EncryptedTextField
|
|
||||||
|
|
||||||
return EncryptedTextField()
|
|
||||||
|
|
||||||
def test_encrypt_decrypt_roundtrip(self, field):
|
|
||||||
original = "my secret password"
|
|
||||||
encrypted = field.to_db_value(original, None)
|
|
||||||
decrypted = field.to_python_value(encrypted)
|
|
||||||
assert decrypted == original
|
|
||||||
assert encrypted != original
|
|
||||||
|
|
||||||
def test_none_passthrough(self, field):
|
|
||||||
assert field.to_db_value(None, None) is None
|
|
||||||
assert field.to_python_value(None) is None
|
|
||||||
|
|
||||||
def test_unicode_roundtrip(self, field):
|
|
||||||
original = "Hello 世界 🐱"
|
|
||||||
encrypted = field.to_db_value(original, None)
|
|
||||||
decrypted = field.to_python_value(encrypted)
|
|
||||||
assert decrypted == original
|
|
||||||
|
|
||||||
def test_empty_string_roundtrip(self, field):
|
|
||||||
encrypted = field.to_db_value("", None)
|
|
||||||
decrypted = field.to_python_value(encrypted)
|
|
||||||
assert decrypted == ""
|
|
||||||
|
|
||||||
def test_long_text_roundtrip(self, field):
|
|
||||||
original = "x" * 10000
|
|
||||||
encrypted = field.to_db_value(original, None)
|
|
||||||
decrypted = field.to_python_value(encrypted)
|
|
||||||
assert decrypted == original
|
|
||||||
|
|
||||||
def test_different_encryptions_differ(self, field):
|
|
||||||
"""Fernet includes a timestamp, so two encryptions of the same value differ."""
|
|
||||||
e1 = field.to_db_value("same", None)
|
|
||||||
e2 = field.to_db_value("same", None)
|
|
||||||
assert e1 != e2 # Different ciphertexts
|
|
||||||
assert field.to_python_value(e1) == field.to_python_value(e2) == "same"
|
|
||||||
|
|
||||||
def test_wrong_key_fails(self, field):
|
|
||||||
encrypted = field.to_db_value("secret", None)
|
|
||||||
|
|
||||||
# Create a field with a different key
|
|
||||||
other_key = Fernet.generate_key().decode()
|
|
||||||
with patch.dict(os.environ, {"FERNET_KEY": other_key}):
|
|
||||||
from blueprints.email.crypto_service import EncryptedTextField
|
|
||||||
|
|
||||||
other_field = EncryptedTextField()
|
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
other_field.to_python_value(encrypted)
|
|
||||||
|
|
||||||
|
|
||||||
class TestValidateFernetKey:
|
|
||||||
def test_valid_key(self):
|
|
||||||
with patch.dict(os.environ, {"FERNET_KEY": TEST_FERNET_KEY}):
|
|
||||||
from blueprints.email.crypto_service import validate_fernet_key
|
|
||||||
|
|
||||||
validate_fernet_key() # Should not raise
|
|
||||||
|
|
||||||
def test_missing_key(self):
|
|
||||||
with patch.dict(os.environ, {}, clear=True):
|
|
||||||
os.environ.pop("FERNET_KEY", None)
|
|
||||||
from blueprints.email.crypto_service import validate_fernet_key
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="not set"):
|
|
||||||
validate_fernet_key()
|
|
||||||
|
|
||||||
def test_invalid_key(self):
|
|
||||||
with patch.dict(os.environ, {"FERNET_KEY": "not-a-valid-key"}):
|
|
||||||
from blueprints.email.crypto_service import validate_fernet_key
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="validation failed"):
|
|
||||||
validate_fernet_key()
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
"""Tests for email helper functions in blueprints/email/helpers.py."""
|
|
||||||
|
|
||||||
from blueprints.email.helpers import generate_email_token, get_user_email_address
|
|
||||||
|
|
||||||
|
|
||||||
class TestGenerateEmailToken:
|
|
||||||
def test_returns_16_char_hex(self):
|
|
||||||
token = generate_email_token("user-123", "my-secret")
|
|
||||||
assert len(token) == 16
|
|
||||||
assert all(c in "0123456789abcdef" for c in token)
|
|
||||||
|
|
||||||
def test_deterministic(self):
|
|
||||||
t1 = generate_email_token("user-123", "my-secret")
|
|
||||||
t2 = generate_email_token("user-123", "my-secret")
|
|
||||||
assert t1 == t2
|
|
||||||
|
|
||||||
def test_different_users_different_tokens(self):
|
|
||||||
t1 = generate_email_token("user-1", "secret")
|
|
||||||
t2 = generate_email_token("user-2", "secret")
|
|
||||||
assert t1 != t2
|
|
||||||
|
|
||||||
def test_different_secrets_different_tokens(self):
|
|
||||||
t1 = generate_email_token("user-1", "secret-a")
|
|
||||||
t2 = generate_email_token("user-1", "secret-b")
|
|
||||||
assert t1 != t2
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetUserEmailAddress:
|
|
||||||
def test_formats_correctly(self):
|
|
||||||
addr = get_user_email_address("abc123", "example.com")
|
|
||||||
assert addr == "ask+abc123@example.com"
|
|
||||||
|
|
||||||
def test_preserves_token(self):
|
|
||||||
token = "deadbeef12345678"
|
|
||||||
addr = get_user_email_address(token, "mail.test.org")
|
|
||||||
assert token in addr
|
|
||||||
assert addr.startswith("ask+")
|
|
||||||
assert "@mail.test.org" in addr
|
|
||||||
@@ -1,259 +0,0 @@
|
|||||||
"""Tests for ObsidianService markdown parsing and file operations."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
# Set vault path before importing so __init__ validation passes
|
|
||||||
_test_vault_dir = None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def vault_dir(tmp_path):
|
|
||||||
"""Create a temporary vault directory with a sample .md file."""
|
|
||||||
global _test_vault_dir
|
|
||||||
_test_vault_dir = tmp_path
|
|
||||||
|
|
||||||
# Create a sample markdown file so vault validation passes
|
|
||||||
sample = tmp_path / "sample.md"
|
|
||||||
sample.write_text("# Sample\nHello world")
|
|
||||||
|
|
||||||
with patch.dict(os.environ, {"OBSIDIAN_VAULT_PATH": str(tmp_path)}):
|
|
||||||
yield tmp_path
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def service(vault_dir):
|
|
||||||
from utils.obsidian_service import ObsidianService
|
|
||||||
|
|
||||||
return ObsidianService()
|
|
||||||
|
|
||||||
|
|
||||||
class TestParseMarkdown:
|
|
||||||
def test_extracts_frontmatter(self, service):
|
|
||||||
content = "---\ntitle: Test Note\ntags: [cat, vet]\n---\n\nBody content"
|
|
||||||
result = service.parse_markdown(content)
|
|
||||||
assert result["metadata"]["title"] == "Test Note"
|
|
||||||
assert result["metadata"]["tags"] == ["cat", "vet"]
|
|
||||||
|
|
||||||
def test_no_frontmatter(self, service):
|
|
||||||
content = "Just body content with no frontmatter"
|
|
||||||
result = service.parse_markdown(content)
|
|
||||||
assert result["metadata"] == {}
|
|
||||||
assert "Just body content" in result["content"]
|
|
||||||
|
|
||||||
def test_invalid_yaml_frontmatter(self, service):
|
|
||||||
content = "---\n: invalid: yaml: [[\n---\n\nBody"
|
|
||||||
result = service.parse_markdown(content)
|
|
||||||
assert result["metadata"] == {}
|
|
||||||
|
|
||||||
def test_extracts_tags(self, service):
|
|
||||||
content = "Some text with #tag1 and #tag2 here"
|
|
||||||
result = service.parse_markdown(content)
|
|
||||||
assert "tag1" in result["tags"]
|
|
||||||
assert "tag2" in result["tags"]
|
|
||||||
|
|
||||||
def test_extracts_wikilinks(self, service):
|
|
||||||
content = "Link to [[Other Note]] and [[Another Page]]"
|
|
||||||
result = service.parse_markdown(content)
|
|
||||||
assert "Other Note" in result["wikilinks"]
|
|
||||||
assert "Another Page" in result["wikilinks"]
|
|
||||||
|
|
||||||
def test_extracts_embeds(self, service):
|
|
||||||
content = "An embed [[!my_embed]] here"
|
|
||||||
result = service.parse_markdown(content)
|
|
||||||
assert "my_embed" in result["embeds"]
|
|
||||||
|
|
||||||
def test_cleans_wikilinks_from_content(self, service):
|
|
||||||
content = "Text with [[link]] included"
|
|
||||||
result = service.parse_markdown(content)
|
|
||||||
assert "[[" not in result["content"]
|
|
||||||
assert "]]" not in result["content"]
|
|
||||||
|
|
||||||
def test_filepath_passed_through(self, service):
|
|
||||||
result = service.parse_markdown("text", filepath=Path("/vault/note.md"))
|
|
||||||
assert result["filepath"] == "/vault/note.md"
|
|
||||||
|
|
||||||
def test_filepath_none_by_default(self, service):
|
|
||||||
result = service.parse_markdown("text")
|
|
||||||
assert result["filepath"] is None
|
|
||||||
|
|
||||||
def test_empty_content(self, service):
|
|
||||||
result = service.parse_markdown("")
|
|
||||||
assert result["metadata"] == {}
|
|
||||||
assert result["tags"] == []
|
|
||||||
assert result["wikilinks"] == []
|
|
||||||
assert result["embeds"] == []
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetDailyNotePath:
|
|
||||||
def test_formats_path_correctly(self, service):
|
|
||||||
date = datetime(2026, 3, 15)
|
|
||||||
path = service.get_daily_note_path(date)
|
|
||||||
assert path == "journal/2026/2026-03-15.md"
|
|
||||||
|
|
||||||
def test_defaults_to_today(self, service):
|
|
||||||
path = service.get_daily_note_path()
|
|
||||||
today = datetime.now()
|
|
||||||
assert today.strftime("%Y-%m-%d") in path
|
|
||||||
assert path.startswith(f"journal/{today.strftime('%Y')}/")
|
|
||||||
|
|
||||||
|
|
||||||
class TestWalkVault:
|
|
||||||
def test_finds_markdown_files(self, service, vault_dir):
|
|
||||||
(vault_dir / "note1.md").write_text("# Note 1")
|
|
||||||
(vault_dir / "subdir").mkdir()
|
|
||||||
(vault_dir / "subdir" / "note2.md").write_text("# Note 2")
|
|
||||||
|
|
||||||
files = service.walk_vault()
|
|
||||||
filenames = [f.name for f in files]
|
|
||||||
assert "sample.md" in filenames
|
|
||||||
assert "note1.md" in filenames
|
|
||||||
assert "note2.md" in filenames
|
|
||||||
|
|
||||||
def test_excludes_obsidian_dir(self, service, vault_dir):
|
|
||||||
obsidian_dir = vault_dir / ".obsidian"
|
|
||||||
obsidian_dir.mkdir()
|
|
||||||
(obsidian_dir / "config.md").write_text("config")
|
|
||||||
|
|
||||||
files = service.walk_vault()
|
|
||||||
filenames = [f.name for f in files]
|
|
||||||
assert "config.md" not in filenames
|
|
||||||
|
|
||||||
def test_ignores_non_md_files(self, service, vault_dir):
|
|
||||||
(vault_dir / "image.png").write_bytes(b"\x89PNG")
|
|
||||||
|
|
||||||
files = service.walk_vault()
|
|
||||||
filenames = [f.name for f in files]
|
|
||||||
assert "image.png" not in filenames
|
|
||||||
|
|
||||||
|
|
||||||
class TestCreateNote:
|
|
||||||
def test_creates_file(self, service, vault_dir):
|
|
||||||
path = service.create_note("My Test Note", "Body content")
|
|
||||||
full_path = vault_dir / path
|
|
||||||
assert full_path.exists()
|
|
||||||
|
|
||||||
def test_sanitizes_title(self, service, vault_dir):
|
|
||||||
path = service.create_note("Hello World! @#$", "Body")
|
|
||||||
assert "hello-world" in path
|
|
||||||
assert "@" not in path
|
|
||||||
assert "#" not in path
|
|
||||||
|
|
||||||
def test_includes_frontmatter(self, service, vault_dir):
|
|
||||||
path = service.create_note("Test", "Body", tags=["cat", "vet"])
|
|
||||||
full_path = vault_dir / path
|
|
||||||
content = full_path.read_text()
|
|
||||||
assert "---" in content
|
|
||||||
assert "created_by: simbarag" in content
|
|
||||||
assert "cat" in content
|
|
||||||
assert "vet" in content
|
|
||||||
|
|
||||||
def test_custom_folder(self, service, vault_dir):
|
|
||||||
path = service.create_note("Test", "Body", folder="custom/subfolder")
|
|
||||||
assert path.startswith("custom/subfolder/")
|
|
||||||
assert (vault_dir / path).exists()
|
|
||||||
|
|
||||||
|
|
||||||
class TestDailyNoteTasks:
|
|
||||||
def test_get_tasks_from_daily_note(self, service, vault_dir):
|
|
||||||
# Create a daily note with tasks
|
|
||||||
date = datetime(2026, 1, 15)
|
|
||||||
rel_path = service.get_daily_note_path(date)
|
|
||||||
note_path = vault_dir / rel_path
|
|
||||||
note_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
note_path.write_text(
|
|
||||||
"---\nmodified: 2026-01-15\n---\n"
|
|
||||||
"### tasks\n\n"
|
|
||||||
"- [ ] Feed the cat\n"
|
|
||||||
"- [x] Clean litter box\n"
|
|
||||||
"- [ ] Buy cat food\n\n"
|
|
||||||
"### log\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
result = service.get_daily_tasks(date)
|
|
||||||
assert result["found"] is True
|
|
||||||
assert len(result["tasks"]) == 3
|
|
||||||
assert result["tasks"][0] == {"text": "Feed the cat", "done": False}
|
|
||||||
assert result["tasks"][1] == {"text": "Clean litter box", "done": True}
|
|
||||||
assert result["tasks"][2] == {"text": "Buy cat food", "done": False}
|
|
||||||
|
|
||||||
def test_get_tasks_no_note(self, service):
|
|
||||||
date = datetime(2099, 12, 31)
|
|
||||||
result = service.get_daily_tasks(date)
|
|
||||||
assert result["found"] is False
|
|
||||||
assert result["tasks"] == []
|
|
||||||
|
|
||||||
def test_add_task_creates_note(self, service, vault_dir):
|
|
||||||
date = datetime(2026, 6, 1)
|
|
||||||
result = service.add_task_to_daily_note("Walk the cat", date)
|
|
||||||
assert result["success"] is True
|
|
||||||
assert result["created_note"] is True
|
|
||||||
|
|
||||||
# Verify file was created with the task
|
|
||||||
note_path = vault_dir / result["path"]
|
|
||||||
content = note_path.read_text()
|
|
||||||
assert "- [ ] Walk the cat" in content
|
|
||||||
|
|
||||||
def test_add_task_to_existing_note(self, service, vault_dir):
|
|
||||||
date = datetime(2026, 6, 2)
|
|
||||||
rel_path = service.get_daily_note_path(date)
|
|
||||||
note_path = vault_dir / rel_path
|
|
||||||
note_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
note_path.write_text(
|
|
||||||
"---\nmodified: 2026-06-02\n---\n"
|
|
||||||
"### tasks\n\n"
|
|
||||||
"- [ ] Existing task\n\n"
|
|
||||||
"### log\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
result = service.add_task_to_daily_note("New task", date)
|
|
||||||
assert result["success"] is True
|
|
||||||
assert result["created_note"] is False
|
|
||||||
|
|
||||||
content = note_path.read_text()
|
|
||||||
assert "- [ ] Existing task" in content
|
|
||||||
assert "- [ ] New task" in content
|
|
||||||
|
|
||||||
def test_complete_task_exact_match(self, service, vault_dir):
|
|
||||||
date = datetime(2026, 6, 3)
|
|
||||||
rel_path = service.get_daily_note_path(date)
|
|
||||||
note_path = vault_dir / rel_path
|
|
||||||
note_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
note_path.write_text("### tasks\n\n" "- [ ] Feed the cat\n" "- [ ] Buy food\n")
|
|
||||||
|
|
||||||
result = service.complete_task_in_daily_note("Feed the cat", date)
|
|
||||||
assert result["success"] is True
|
|
||||||
|
|
||||||
content = note_path.read_text()
|
|
||||||
assert "- [x] Feed the cat" in content
|
|
||||||
assert "- [ ] Buy food" in content # Other task unchanged
|
|
||||||
|
|
||||||
def test_complete_task_partial_match(self, service, vault_dir):
|
|
||||||
date = datetime(2026, 6, 4)
|
|
||||||
rel_path = service.get_daily_note_path(date)
|
|
||||||
note_path = vault_dir / rel_path
|
|
||||||
note_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
note_path.write_text("### tasks\n\n- [ ] Feed the cat at 5pm\n")
|
|
||||||
|
|
||||||
result = service.complete_task_in_daily_note("Feed the cat", date)
|
|
||||||
assert result["success"] is True
|
|
||||||
|
|
||||||
def test_complete_task_not_found(self, service, vault_dir):
|
|
||||||
date = datetime(2026, 6, 5)
|
|
||||||
rel_path = service.get_daily_note_path(date)
|
|
||||||
note_path = vault_dir / rel_path
|
|
||||||
note_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
note_path.write_text("### tasks\n\n- [ ] Feed the cat\n")
|
|
||||||
|
|
||||||
result = service.complete_task_in_daily_note("Walk the dog", date)
|
|
||||||
assert result["success"] is False
|
|
||||||
assert "not found" in result["error"]
|
|
||||||
|
|
||||||
def test_complete_task_no_note(self, service):
|
|
||||||
date = datetime(2099, 12, 31)
|
|
||||||
result = service.complete_task_in_daily_note("Something", date)
|
|
||||||
assert result["success"] is False
|
|
||||||
@@ -1,92 +0,0 @@
|
|||||||
"""Tests for rate limiting logic in email and WhatsApp blueprints."""
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
class TestEmailRateLimit:
|
|
||||||
def setup_method(self):
|
|
||||||
"""Reset rate limit store before each test."""
|
|
||||||
from blueprints.email import _rate_limit_store
|
|
||||||
|
|
||||||
_rate_limit_store.clear()
|
|
||||||
|
|
||||||
def test_allows_under_limit(self):
|
|
||||||
from blueprints.email import _check_rate_limit, RATE_LIMIT_MAX
|
|
||||||
|
|
||||||
for _ in range(RATE_LIMIT_MAX):
|
|
||||||
assert _check_rate_limit("sender@test.com") is True
|
|
||||||
|
|
||||||
def test_blocks_at_limit(self):
|
|
||||||
from blueprints.email import _check_rate_limit, RATE_LIMIT_MAX
|
|
||||||
|
|
||||||
for _ in range(RATE_LIMIT_MAX):
|
|
||||||
_check_rate_limit("sender@test.com")
|
|
||||||
|
|
||||||
assert _check_rate_limit("sender@test.com") is False
|
|
||||||
|
|
||||||
def test_different_senders_independent(self):
|
|
||||||
from blueprints.email import _check_rate_limit, RATE_LIMIT_MAX
|
|
||||||
|
|
||||||
for _ in range(RATE_LIMIT_MAX):
|
|
||||||
_check_rate_limit("user1@test.com")
|
|
||||||
|
|
||||||
# user1 is at limit, but user2 should be fine
|
|
||||||
assert _check_rate_limit("user1@test.com") is False
|
|
||||||
assert _check_rate_limit("user2@test.com") is True
|
|
||||||
|
|
||||||
def test_window_expiry(self):
|
|
||||||
from blueprints.email import (
|
|
||||||
_check_rate_limit,
|
|
||||||
_rate_limit_store,
|
|
||||||
RATE_LIMIT_MAX,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fill up the rate limit with timestamps in the past
|
|
||||||
past = time.monotonic() - 999 # Well beyond any window
|
|
||||||
_rate_limit_store["old@test.com"] = [past] * RATE_LIMIT_MAX
|
|
||||||
|
|
||||||
# Should be allowed because all timestamps are expired
|
|
||||||
assert _check_rate_limit("old@test.com") is True
|
|
||||||
|
|
||||||
|
|
||||||
class TestWhatsAppRateLimit:
|
|
||||||
def setup_method(self):
|
|
||||||
"""Reset rate limit store before each test."""
|
|
||||||
from blueprints.whatsapp import _rate_limit_store
|
|
||||||
|
|
||||||
_rate_limit_store.clear()
|
|
||||||
|
|
||||||
def test_allows_under_limit(self):
|
|
||||||
from blueprints.whatsapp import _check_rate_limit, RATE_LIMIT_MAX
|
|
||||||
|
|
||||||
for _ in range(RATE_LIMIT_MAX):
|
|
||||||
assert _check_rate_limit("whatsapp:+1234567890") is True
|
|
||||||
|
|
||||||
def test_blocks_at_limit(self):
|
|
||||||
from blueprints.whatsapp import _check_rate_limit, RATE_LIMIT_MAX
|
|
||||||
|
|
||||||
for _ in range(RATE_LIMIT_MAX):
|
|
||||||
_check_rate_limit("whatsapp:+1234567890")
|
|
||||||
|
|
||||||
assert _check_rate_limit("whatsapp:+1234567890") is False
|
|
||||||
|
|
||||||
def test_different_numbers_independent(self):
|
|
||||||
from blueprints.whatsapp import _check_rate_limit, RATE_LIMIT_MAX
|
|
||||||
|
|
||||||
for _ in range(RATE_LIMIT_MAX):
|
|
||||||
_check_rate_limit("whatsapp:+1111111111")
|
|
||||||
|
|
||||||
assert _check_rate_limit("whatsapp:+1111111111") is False
|
|
||||||
assert _check_rate_limit("whatsapp:+2222222222") is True
|
|
||||||
|
|
||||||
def test_window_expiry(self):
|
|
||||||
from blueprints.whatsapp import (
|
|
||||||
_check_rate_limit,
|
|
||||||
_rate_limit_store,
|
|
||||||
RATE_LIMIT_MAX,
|
|
||||||
)
|
|
||||||
|
|
||||||
past = time.monotonic() - 999
|
|
||||||
_rate_limit_store["whatsapp:+9999999999"] = [past] * RATE_LIMIT_MAX
|
|
||||||
|
|
||||||
assert _check_rate_limit("whatsapp:+9999999999") is True
|
|
||||||
@@ -1,86 +0,0 @@
|
|||||||
"""Tests for User model methods in blueprints/users/models.py."""
|
|
||||||
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import bcrypt
|
|
||||||
|
|
||||||
|
|
||||||
class TestUserModelMethods:
    """Exercise group-membership and password helpers on a stand-in User.

    The stand-in is a ``MagicMock(spec=User)`` whose ``has_group``,
    ``is_admin``, ``set_password`` and ``verify_password`` attributes are
    replaced by local functions defined in ``_make_user``, so neither
    Tortoise ORM initialization nor a database connection is required.
    """

    def _make_user(self, ldap_groups=None, password=None):
        """Build the stand-in user with the given groups and password."""
        from blueprints.users.models import User

        stub = MagicMock(spec=User)
        stub.ldap_groups = ldap_groups
        stub.password = password

        # Local implementations attached to the mock (not the model's own code).
        def has_group(group):
            memberships = stub.ldap_groups or []
            return group in memberships

        def is_admin():
            return stub.has_group("lldap_admin")

        def set_password(plain):
            stub.password = bcrypt.hashpw(plain.encode("utf-8"), bcrypt.gensalt())

        def verify_password(plain):
            if stub.password:
                return bcrypt.checkpw(plain.encode("utf-8"), stub.password)
            return False

        stub.has_group = has_group
        stub.is_admin = is_admin
        stub.set_password = set_password
        stub.verify_password = verify_password

        return stub

    def test_has_group_true(self):
        member = self._make_user(ldap_groups=["lldap_admin", "users"])
        for group in ("lldap_admin", "users"):
            assert member.has_group(group) is True

    def test_has_group_false(self):
        member = self._make_user(ldap_groups=["users"])
        assert member.has_group("lldap_admin") is False

    def test_has_group_empty_list(self):
        member = self._make_user(ldap_groups=[])
        assert member.has_group("anything") is False

    def test_has_group_none(self):
        member = self._make_user(ldap_groups=None)
        assert member.has_group("anything") is False

    def test_is_admin_true(self):
        admin = self._make_user(ldap_groups=["lldap_admin"])
        assert admin.is_admin() is True

    def test_is_admin_false(self):
        regular = self._make_user(ldap_groups=["users"])
        assert regular.is_admin() is False

    def test_is_admin_empty(self):
        groupless = self._make_user(ldap_groups=[])
        assert groupless.is_admin() is False

    def test_set_and_verify_password(self):
        account = self._make_user()
        account.set_password("hunter2")

        assert account.password is not None
        assert account.verify_password("hunter2") is True
        assert account.verify_password("wrong") is False

    def test_verify_password_no_password_set(self):
        account = self._make_user(password=None)
        assert account.verify_password("anything") is False

    def test_password_is_hashed(self):
        account = self._make_user()
        account.set_password("mypassword")

        # The stored value must differ from the plaintext in both bytes and str form.
        for plaintext in (b"mypassword", "mypassword"):
            assert account.password != plaintext
|
|
||||||
@@ -1,254 +0,0 @@
|
|||||||
"""Tests for YNAB service data formatting and filtering logic."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_category(
|
|
||||||
name, budgeted, activity, balance, deleted=False, hidden=False, goal_type=None
|
|
||||||
):
|
|
||||||
cat = MagicMock()
|
|
||||||
cat.name = name
|
|
||||||
cat.budgeted = budgeted
|
|
||||||
cat.activity = activity
|
|
||||||
cat.balance = balance
|
|
||||||
cat.deleted = deleted
|
|
||||||
cat.hidden = hidden
|
|
||||||
cat.goal_type = goal_type
|
|
||||||
return cat
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_transaction(
|
|
||||||
var_date, payee_name, category_name, amount, memo="", deleted=False, approved=True
|
|
||||||
):
|
|
||||||
txn = MagicMock()
|
|
||||||
txn.var_date = var_date
|
|
||||||
txn.payee_name = payee_name
|
|
||||||
txn.category_name = category_name
|
|
||||||
txn.amount = amount
|
|
||||||
txn.memo = memo
|
|
||||||
txn.deleted = deleted
|
|
||||||
txn.approved = approved
|
|
||||||
return txn
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
def ynab_service():
    """Yield a ``YNABService`` built while ``utils.ynab_service.ynab`` is patched.

    ``YNAB_ACCESS_TOKEN`` and ``YNAB_BUDGET_ID`` are set to placeholder values
    in ``os.environ`` for as long as the fixture is active.
    """
    fake_env = {"YNAB_ACCESS_TOKEN": "fake-token", "YNAB_BUDGET_ID": "test-budget"}
    sdk_factories = (
        "Configuration",
        "ApiClient",
        "PlansApi",
        "TransactionsApi",
        "MonthsApi",
        "CategoriesApi",
    )

    with patch.dict(os.environ, fake_env), patch(
        "utils.ynab_service.ynab"
    ) as mock_ynab:
        # Each SDK factory hands back its own fresh mock.
        for factory in sdk_factories:
            getattr(mock_ynab, factory).return_value = MagicMock()

        from utils.ynab_service import YNABService

        yield YNABService()
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetBudgetSummary:
    """Checks on the dict returned by ``YNABService.get_budget_summary``."""

    @staticmethod
    def _install_budget(service, name, to_be_budgeted, categories, currency_format):
        """Make ``plans_api.get_plan_by_id`` return a budget with these fields."""
        month = MagicMock()
        month.to_be_budgeted = to_be_budgeted

        budget = MagicMock()
        budget.name = name
        budget.months = [month]
        budget.categories = categories
        budget.currency_format = currency_format

        response = MagicMock()
        response.data.budget = budget
        service.plans_api.get_plan_by_id.return_value = response

    def test_calculates_totals(self, ynab_service):
        self._install_budget(
            ynab_service,
            name="My Budget",
            to_be_budgeted=200_000,
            categories=[
                _mock_category("Groceries", 500_000, -350_000, 150_000),
                _mock_category("Rent", 1_500_000, -1_500_000, 0),
            ],
            currency_format=MagicMock(iso_code="USD"),
        )

        summary = ynab_service.get_budget_summary()

        assert summary["budget_name"] == "My Budget"
        assert summary["to_be_budgeted"] == 200.0
        assert summary["total_budgeted"] == 2000.0  # (500k + 1500k) / 1000
        assert summary["total_activity"] == -1850.0
        assert summary["currency_format"] == "USD"

    def test_skips_deleted_and_hidden(self, ynab_service):
        self._install_budget(
            ynab_service,
            name="Budget",
            to_be_budgeted=0,
            categories=[
                _mock_category("Active", 100_000, -50_000, 50_000),
                _mock_category("Deleted", 999_000, -999_000, 0, deleted=True),
                _mock_category("Hidden", 999_000, -999_000, 0, hidden=True),
            ],
            currency_format=None,
        )

        summary = ynab_service.get_budget_summary()

        assert summary["total_budgeted"] == 100.0
        assert summary["currency_format"] == "USD"  # Default fallback
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetTransactions:
    """Checks on filtering, conversion and ordering in ``get_transactions``."""

    @staticmethod
    def _query_january(service, transactions, **filters):
        """Serve ``transactions`` from the mocked API and query January 2026."""
        response = MagicMock()
        response.data.transactions = transactions
        service.transactions_api.get_transactions.return_value = response
        return service.get_transactions(
            start_date="2026-01-01", end_date="2026-01-31", **filters
        )

    def test_filters_by_date_range(self, ynab_service):
        outcome = self._query_january(
            ynab_service,
            [
                _mock_transaction("2026-01-05", "Store", "Groceries", -25_000),
                _mock_transaction("2026-01-15", "Gas", "Transport", -40_000),
                # Out of range
                _mock_transaction("2026-02-01", "Store", "Groceries", -30_000),
            ],
        )

        assert outcome["count"] == 2
        assert outcome["total_amount"] == -65.0

    def test_filters_by_category(self, ynab_service):
        outcome = self._query_january(
            ynab_service,
            [
                _mock_transaction("2026-01-05", "Store", "Groceries", -25_000),
                _mock_transaction("2026-01-06", "Gas", "Transport", -40_000),
            ],
            category_name="groceries",  # Case insensitive
        )

        assert outcome["count"] == 1
        assert outcome["transactions"][0]["category"] == "Groceries"

    def test_filters_by_payee(self, ynab_service):
        outcome = self._query_january(
            ynab_service,
            [
                _mock_transaction("2026-01-05", "Whole Foods", "Groceries", -25_000),
                _mock_transaction("2026-01-06", "Shell Gas", "Transport", -40_000),
            ],
            payee_name="whole",
        )

        assert outcome["count"] == 1
        assert outcome["transactions"][0]["payee"] == "Whole Foods"

    def test_skips_deleted(self, ynab_service):
        outcome = self._query_january(
            ynab_service,
            [
                _mock_transaction("2026-01-05", "Store", "Groceries", -25_000),
                _mock_transaction(
                    "2026-01-06", "Deleted", "Other", -10_000, deleted=True
                ),
            ],
        )

        assert outcome["count"] == 1

    def test_converts_milliunits(self, ynab_service):
        outcome = self._query_january(
            ynab_service,
            [_mock_transaction("2026-01-05", "Store", "Groceries", -12_340)],
        )

        assert outcome["transactions"][0]["amount"] == -12.34

    def test_sorts_by_date_descending(self, ynab_service):
        outcome = self._query_january(
            ynab_service,
            [
                _mock_transaction("2026-01-01", "A", "Cat", -10_000),
                _mock_transaction("2026-01-15", "B", "Cat", -20_000),
                _mock_transaction("2026-01-10", "C", "Cat", -30_000),
            ],
        )

        dates = []
        for entry in outcome["transactions"]:
            dates.append(entry["date"])
        assert dates == sorted(dates, reverse=True)
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetCategorySpending:
    """Checks on the dict returned by ``YNABService.get_category_spending``."""

    @staticmethod
    def _serve_month(service, categories):
        """Make ``months_api.get_plan_month`` return a month with ``categories``."""
        month = MagicMock()
        month.categories = categories
        month.to_be_budgeted = 0

        response = MagicMock()
        response.data.month = month
        service.months_api.get_plan_month.return_value = response

    def test_month_format_normalization(self, ynab_service):
        """A ``YYYY-MM`` argument is reported back unchanged as ``result["month"]``.

        The service is expected to normalize it to ``YYYY-MM-01`` internally;
        that normalization itself is not asserted here.
        """
        self._serve_month(
            ynab_service, [_mock_category("Food", 100_000, -50_000, 50_000)]
        )

        report = ynab_service.get_category_spending("2026-03")

        assert report["month"] == "2026-03"

    def test_identifies_overspent(self, ynab_service):
        self._serve_month(
            ynab_service,
            [
                _mock_category("Dining", 200_000, -300_000, -100_000),  # Overspent
                _mock_category("Groceries", 500_000, -400_000, 100_000),  # Fine
            ],
        )

        report = ynab_service.get_category_spending("2026-03")
        overspent = report["overspent_categories"]

        assert len(report["overspent_categories"]) == 1
        assert overspent[0]["name"] == "Dining"
        assert overspent[0]["overspent_by"] == 100.0
|
|
||||||
@@ -0,0 +1,137 @@
|
|||||||
|
import os
|
||||||
|
from math import ceil
|
||||||
|
import re
|
||||||
|
from typing import Union
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||||
|
OpenAIEmbeddingFunction,
|
||||||
|
)
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from llm import LLMClient
|
||||||
|
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
def remove_headers_footers(text, header_patterns=None, footer_patterns=None):
    """Blank out every line matching a header or footer regex, then strip.

    Args:
        text: The text to clean.
        header_patterns: Regexes applied with ``re.MULTILINE``; defaults to
            any line containing ``Header``.
        footer_patterns: Regexes applied with ``re.MULTILINE``; defaults to
            any line containing ``Footer``.

    Returns:
        The text with all matches replaced by the empty string and leading
        and trailing whitespace removed.
    """
    headers = [r"^.*Header.*$"] if header_patterns is None else header_patterns
    footers = [r"^.*Footer.*$"] if footer_patterns is None else footer_patterns

    cleaned = text
    for line_regex in headers + footers:
        cleaned = re.sub(line_regex, "", cleaned, flags=re.MULTILINE)

    return cleaned.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def remove_special_characters(text, special_chars=None):
    """Delete every match of ``special_chars`` from ``text`` and strip it.

    When ``special_chars`` is ``None`` the default regex removes every
    character that is not an ASCII letter, a digit, whitespace, or one of
    ``. , ; : ' " ? ! -``.
    """
    unwanted = (
        r"[^A-Za-z0-9\s\.,;:\'\"\?\!\-]" if special_chars is None else special_chars
    )
    return re.sub(unwanted, "", text).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def remove_repeated_substrings(text, pattern=r"\.{2,}"):
    """Replace each match of ``pattern`` with a single ``.`` and strip.

    The default pattern matches runs of two or more dots.
    """
    collapsed = re.compile(pattern).sub(".", text)
    return collapsed.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def remove_extra_spaces(text):
    """Collapse whitespace in ``text`` and strip the ends.

    Blank-line runs are first reduced to a double newline; then every
    whitespace run (newlines included) becomes a single space, so the
    returned text contains no newlines.
    """
    paragraphs_collapsed = re.sub(r"\n\s*\n", "\n\n", text)
    single_spaced = re.sub(r"\s+", " ", paragraphs_collapsed)
    return single_spaced.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_text(text):
    """Run ``text`` through each cleaning stage in order and strip the result.

    Stages: header/footer removal, special-character removal, collapsing of
    repeated dots, then whitespace normalization.
    """
    stages = (
        remove_headers_footers,
        remove_special_characters,
        remove_repeated_substrings,
        remove_extra_spaces,
    )

    # Additional cleaning stages can be appended to the tuple above.
    for stage in stages:
        text = stage(text)

    return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
class Chunk:
    """Plain record for one slice of a document and its embedding."""

    def __init__(
        self,
        text: str,
        size: int,
        document_id: UUID,
        chunk_id: int,
        embedding,
    ):
        # Chunk text and its size.
        self.text, self.size = text, size
        # Identifiers: the parent document and the chunk's number.
        self.document_id, self.chunk_id = document_id, chunk_id
        # Embedding supplied by the caller for this chunk.
        self.embedding = embedding
|
||||||
|
|
||||||
|
|
||||||
|
class Chunker:
    """Split documents into fixed-size chunks, embed them and store them.

    Every chunk is cleaned, embedded with OpenAI's ``text-embedding-3-small``
    model and added to the collection supplied at construction time.
    """

    def __init__(self, collection) -> None:
        # Any object exposing ``add(ids=, documents=, embeddings=, metadatas=)``
        # (presumably a ChromaDB collection -- confirm against the caller).
        self.collection = collection
        self.llm_client = LLMClient()
        # Built lazily on first use so the client is created once, not per chunk.
        self._openai_embedding_fx = None

    def embedding_fx(self, inputs):
        """Return the result of embedding ``inputs`` (a list of strings).

        The OpenAI embedding function is created on first use, reading
        ``OPENAI_API_KEY`` from the environment, and reused afterwards.
        """
        # getattr keeps this working on instances created without __init__.
        embedder = getattr(self, "_openai_embedding_fx", None)
        if embedder is None:
            embedder = OpenAIEmbeddingFunction(
                api_key=os.getenv("OPENAI_API_KEY"),
                model_name="text-embedding-3-small",
            )
            self._openai_embedding_fx = embedder
        return embedder(inputs)

    def chunk_document(
        self,
        document: str,
        chunk_size: int = 1000,
        metadata: Union[dict[str, Union[str, float]], None] = None,
    ) -> "list[Chunk]":
        """Chunk ``document``, store every chunk and return the chunks.

        Args:
            document: Raw document text.
            chunk_size: Maximum number of raw characters per chunk. It is
                capped at the document length, and 0 is treated as 1.
            metadata: Metadata stored with every chunk; defaults to a fresh
                empty dict on each call.

        Returns:
            One ``Chunk`` per stored slice, in document order. An empty
            document yields an empty list.
        """
        # A fresh dict per call avoids sharing one mutable default between calls.
        metadata = {} if metadata is None else metadata
        doc_uuid = uuid4()

        document_length = len(document)
        chunk_size = min(chunk_size, document_length) or 1

        chunks = []
        # Step by chunk_size so consecutive slices tile the document exactly.
        # (The previous offset, ``i * num_chunks``, made slices overlap and
        # left most of the document unindexed.)
        for i, curr_pos in enumerate(range(0, document_length, chunk_size)):
            text_chunk = self.clean_document(
                document[curr_pos : curr_pos + chunk_size]
            )

            embedding = self.embedding_fx([text_chunk])
            self.collection.add(
                ids=[str(doc_uuid) + ":" + str(i)],
                documents=[text_chunk],
                embeddings=embedding,
                metadatas=[metadata],
            )
            chunks.append(
                Chunk(
                    text=text_chunk,
                    size=len(text_chunk),
                    document_id=doc_uuid,
                    chunk_id=i,
                    # One input was embedded, so the first entry is this
                    # chunk's vector (assumes one embedding per input).
                    embedding=embedding[0],
                )
            )

        return chunks

    def clean_document(self, document: str) -> str:
        """Remove information that is noise or already known.

        Example: everything ingested here is already Simba-related, so
        phrases like "Summary of Simba's visit" add nothing.

        Currently this drops literal backslash-n sequences, strips the ends
        and runs ``preprocess_text``.
        """
        # "\\n" is the two-character sequence backslash + "n", not a newline.
        document = document.replace("\\n", "")
        document = document.strip()

        return preprocess_text(document)
|
||||||
Reference in New Issue
Block a user