Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| db977270a3 | |||
| bac773ae4b | |||
| 564a9b68a5 | |||
| 7742673cc0 | |||
| c157c37cde | |||
| 3b8fa3e7a0 | |||
| 438399646f | |||
| 9ed4ca126a | |||
| f3ae76ce68 | |||
| 7ee3bdef84 | |||
| 500c44feb1 | |||
| 896501deb1 | |||
| c95800e65d | |||
| 90372a6a6d | |||
| c01764243f | |||
| dfaac4caf8 | |||
| 17c3a2f888 | |||
| fa0f68e3b4 | |||
| a6c698c6bd | |||
| 07c272c96a | |||
| 975a337af4 | |||
| e644def141 | |||
| 3671926430 | |||
| be600e78d6 | |||
| b6576fb2fd | |||
| bb3ef4fe95 | |||
| 30db71d134 | |||
| 167d014ca5 | |||
| fa9d5af1fb | |||
| a7726654ff | |||
| c8306e6702 |
@@ -19,11 +19,6 @@ BASE_URL=192.168.1.5:8000
|
|||||||
LLAMA_SERVER_URL=http://192.168.1.213:8080/v1
|
LLAMA_SERVER_URL=http://192.168.1.213:8080/v1
|
||||||
LLAMA_MODEL_NAME=llama-3.1-8b-instruct
|
LLAMA_MODEL_NAME=llama-3.1-8b-instruct
|
||||||
|
|
||||||
# ChromaDB Configuration
|
|
||||||
# For Docker: This is automatically set to /app/data/chromadb
|
|
||||||
# For local development: Set to a local directory path
|
|
||||||
CHROMADB_PATH=./data/chromadb
|
|
||||||
|
|
||||||
# OpenAI Configuration
|
# OpenAI Configuration
|
||||||
OPENAI_API_KEY=your-openai-api-key
|
OPENAI_API_KEY=your-openai-api-key
|
||||||
|
|
||||||
|
|||||||
@@ -13,9 +13,6 @@ wheels/
|
|||||||
.env
|
.env
|
||||||
|
|
||||||
# Database files
|
# Database files
|
||||||
chromadb/
|
|
||||||
chromadb_openai/
|
|
||||||
chroma_db/
|
|
||||||
database/
|
database/
|
||||||
*.db
|
*.db
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
|||||||
|
|
||||||
## Project Overview
|
## Project Overview
|
||||||
|
|
||||||
SimbaRAG is a RAG (Retrieval-Augmented Generation) conversational AI system for querying information about Simba (a cat). It ingests documents from Paperless-NGX, stores embeddings in ChromaDB, and uses LLMs (Ollama or OpenAI) to answer questions.
|
SimbaRAG is a RAG (Retrieval-Augmented Generation) conversational AI system for querying information about Simba (a cat). It ingests documents from Paperless-NGX, stores embeddings in PostgreSQL via pgvector, and uses LLMs (Ollama or OpenAI) to answer questions.
|
||||||
|
|
||||||
## Commands
|
## Commands
|
||||||
|
|
||||||
@@ -54,9 +54,8 @@ docker compose up -d
|
|||||||
│ Docker Compose │
|
│ Docker Compose │
|
||||||
├─────────────────────────────────────────────────────────────┤
|
├─────────────────────────────────────────────────────────────┤
|
||||||
│ raggr (port 8080) │ postgres (port 5432) │
|
│ raggr (port 8080) │ postgres (port 5432) │
|
||||||
│ ├── Quart backend │ PostgreSQL 16 │
|
│ ├── Quart backend │ PostgreSQL 16 + pgvector│
|
||||||
│ ├── React frontend (served) │ │
|
│ └── React frontend (served) │ │
|
||||||
│ └── ChromaDB (volume) │ │
|
|
||||||
└─────────────────────────────────────────────────────────────┘
|
└─────────────────────────────────────────────────────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -91,6 +90,15 @@ docker compose up -d
|
|||||||
|
|
||||||
**Auth Flow**: LLDAP → Authelia (OIDC) → Backend JWT → Frontend localStorage
|
**Auth Flow**: LLDAP → Authelia (OIDC) → Backend JWT → Frontend localStorage
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
Always run `make test` before pushing code to ensure all tests pass.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make test # Run tests
|
||||||
|
make test-cov # Run tests with coverage
|
||||||
|
```
|
||||||
|
|
||||||
## Key Patterns
|
## Key Patterns
|
||||||
|
|
||||||
- All endpoints are async (`async def`)
|
- All endpoints are async (`async def`)
|
||||||
|
|||||||
+2
-3
@@ -37,15 +37,14 @@ WORKDIR /app/raggr-frontend
|
|||||||
RUN yarn install && yarn build
|
RUN yarn install && yarn build
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Create ChromaDB and database directories
|
# Create database directory
|
||||||
RUN mkdir -p /app/chromadb /app/database
|
RUN mkdir -p /app/database
|
||||||
|
|
||||||
# Expose port
|
# Expose port
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
|
|
||||||
# Set environment variables
|
# Set environment variables
|
||||||
ENV PYTHONPATH=/app
|
ENV PYTHONPATH=/app
|
||||||
ENV CHROMADB_PATH=/app/chromadb
|
|
||||||
|
|
||||||
# Run the startup script
|
# Run the startup script
|
||||||
CMD ["./startup.sh"]
|
CMD ["./startup.sh"]
|
||||||
|
|||||||
+2
-3
@@ -34,16 +34,15 @@ COPY . .
|
|||||||
WORKDIR /app/raggr-frontend
|
WORKDIR /app/raggr-frontend
|
||||||
RUN yarn build
|
RUN yarn build
|
||||||
|
|
||||||
# Create ChromaDB and database directories
|
# Create database directory
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
RUN mkdir -p /app/chromadb /app/database
|
RUN mkdir -p /app/database
|
||||||
|
|
||||||
# Make startup script executable
|
# Make startup script executable
|
||||||
RUN chmod +x /app/startup-dev.sh
|
RUN chmod +x /app/startup-dev.sh
|
||||||
|
|
||||||
# Set environment variables
|
# Set environment variables
|
||||||
ENV PYTHONPATH=/app
|
ENV PYTHONPATH=/app
|
||||||
ENV CHROMADB_PATH=/app/chromadb
|
|
||||||
ENV PYTHONUNBUFFERED=1
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
# Expose port
|
# Expose port
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
.PHONY: deploy build up down restart logs migrate migrate-new frontend
|
.PHONY: deploy redeploy build up down restart logs migrate migrate-new frontend test
|
||||||
|
|
||||||
# Build and deploy
|
# Build and deploy
|
||||||
deploy: build up
|
deploy: build up
|
||||||
|
|
||||||
|
redeploy:
|
||||||
|
git pull && $(MAKE) down && $(MAKE) up
|
||||||
|
|
||||||
build:
|
build:
|
||||||
docker compose build raggr
|
docker compose build raggr
|
||||||
|
|
||||||
@@ -29,6 +32,13 @@ migrate-new:
|
|||||||
migrate-history:
|
migrate-history:
|
||||||
docker compose exec raggr aerich history
|
docker compose exec raggr aerich history
|
||||||
|
|
||||||
|
# Tests
|
||||||
|
test:
|
||||||
|
pytest tests/ -v
|
||||||
|
|
||||||
|
test-cov:
|
||||||
|
pytest tests/ -v --cov
|
||||||
|
|
||||||
# Frontend
|
# Frontend
|
||||||
frontend:
|
frontend:
|
||||||
cd raggr-frontend && yarn install && yarn build
|
cd raggr-frontend && yarn install && yarn build
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from quart import Quart, jsonify, render_template, request, send_from_directory
|
from quart import Quart, jsonify, render_template, send_from_directory
|
||||||
from quart_jwt_extended import JWTManager, get_jwt_identity, jwt_refresh_token_required
|
from quart_jwt_extended import JWTManager, get_jwt_identity, jwt_refresh_token_required
|
||||||
from tortoise import Tortoise
|
from tortoise import Tortoise
|
||||||
|
|
||||||
@@ -14,7 +15,6 @@ import blueprints.users
|
|||||||
import blueprints.whatsapp
|
import blueprints.whatsapp
|
||||||
import blueprints.users.models
|
import blueprints.users.models
|
||||||
from config.db import TORTOISE_CONFIG
|
from config.db import TORTOISE_CONFIG
|
||||||
from main import consult_simba_oracle
|
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
@@ -38,6 +38,8 @@ app = Quart(
|
|||||||
)
|
)
|
||||||
|
|
||||||
app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY", "SECRET_KEY")
|
app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY", "SECRET_KEY")
|
||||||
|
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(hours=1)
|
||||||
|
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=30)
|
||||||
app.config["MAX_CONTENT_LENGTH"] = 10 * 1024 * 1024 # 10 MB upload limit
|
app.config["MAX_CONTENT_LENGTH"] = 10 * 1024 * 1024 # 10 MB upload limit
|
||||||
jwt = JWTManager(app)
|
jwt = JWTManager(app)
|
||||||
|
|
||||||
@@ -75,39 +77,6 @@ async def serve_react_app(path):
|
|||||||
return await render_template("index.html")
|
return await render_template("index.html")
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/query", methods=["POST"])
|
|
||||||
@jwt_refresh_token_required
|
|
||||||
async def query():
|
|
||||||
current_user_uuid = get_jwt_identity()
|
|
||||||
user = await blueprints.users.models.User.get(id=current_user_uuid)
|
|
||||||
data = await request.get_json()
|
|
||||||
query = data.get("query")
|
|
||||||
conversation_id = data.get("conversation_id")
|
|
||||||
conversation = await blueprints.conversation.logic.get_conversation_by_id(
|
|
||||||
conversation_id
|
|
||||||
)
|
|
||||||
await conversation.fetch_related("messages")
|
|
||||||
await blueprints.conversation.logic.add_message_to_conversation(
|
|
||||||
conversation=conversation,
|
|
||||||
message=query,
|
|
||||||
speaker="user",
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
|
|
||||||
transcript = await blueprints.conversation.logic.get_conversation_transcript(
|
|
||||||
user=user, conversation=conversation
|
|
||||||
)
|
|
||||||
|
|
||||||
response = consult_simba_oracle(input=query, transcript=transcript)
|
|
||||||
await blueprints.conversation.logic.add_message_to_conversation(
|
|
||||||
conversation=conversation,
|
|
||||||
message=response,
|
|
||||||
speaker="simba",
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
return jsonify({"response": response})
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/messages", methods=["GET"])
|
@app.route("/api/messages", methods=["GET"])
|
||||||
@jwt_refresh_token_required
|
@jwt_refresh_token_required
|
||||||
async def get_messages():
|
async def get_messages():
|
||||||
@@ -132,17 +101,10 @@ async def get_messages():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
name = conversation.name
|
|
||||||
if len(messages) > 8:
|
|
||||||
name = await blueprints.conversation.logic.rename_conversation(
|
|
||||||
user=user,
|
|
||||||
conversation=conversation,
|
|
||||||
)
|
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(
|
||||||
{
|
{
|
||||||
"id": str(conversation.id),
|
"id": str(conversation.id),
|
||||||
"name": name,
|
"name": conversation.name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"created_at": conversation.created_at.isoformat(),
|
"created_at": conversation.created_at.isoformat(),
|
||||||
"updated_at": conversation.updated_at.isoformat(),
|
"updated_at": conversation.updated_at.isoformat(),
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
import datetime
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from quart import Blueprint, Response, jsonify, make_response, request
|
from quart import Blueprint, jsonify, make_response, request
|
||||||
from quart_jwt_extended import (
|
from quart_jwt_extended import (
|
||||||
get_jwt_identity,
|
get_jwt_identity,
|
||||||
jwt_refresh_token_required,
|
jwt_refresh_token_required,
|
||||||
@@ -12,6 +11,7 @@ from quart_jwt_extended import (
|
|||||||
import blueprints.users.models
|
import blueprints.users.models
|
||||||
from utils.image_process import analyze_user_image
|
from utils.image_process import analyze_user_image
|
||||||
from utils.image_upload import ImageValidationError, process_image
|
from utils.image_upload import ImageValidationError, process_image
|
||||||
|
from utils.s3_client import generate_presigned_url as s3_presigned_url
|
||||||
from utils.s3_client import get_image as s3_get_image
|
from utils.s3_client import get_image as s3_get_image
|
||||||
from utils.s3_client import upload_image as s3_upload_image
|
from utils.s3_client import upload_image as s3_upload_image
|
||||||
|
|
||||||
@@ -19,8 +19,8 @@ from .agents import main_agent
|
|||||||
from .logic import (
|
from .logic import (
|
||||||
add_message_to_conversation,
|
add_message_to_conversation,
|
||||||
get_conversation_by_id,
|
get_conversation_by_id,
|
||||||
rename_conversation,
|
|
||||||
)
|
)
|
||||||
|
from .memory import get_memories_for_user
|
||||||
from .models import (
|
from .models import (
|
||||||
Conversation,
|
Conversation,
|
||||||
PydConversation,
|
PydConversation,
|
||||||
@@ -35,15 +35,27 @@ conversation_blueprint = Blueprint(
|
|||||||
_SYSTEM_PROMPT = SIMBA_SYSTEM_PROMPT
|
_SYSTEM_PROMPT = SIMBA_SYSTEM_PROMPT
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_system_prompt_with_memories(user_id: str) -> str:
|
||||||
|
"""Append user memories to the base system prompt."""
|
||||||
|
memories = await get_memories_for_user(user_id)
|
||||||
|
if not memories:
|
||||||
|
return _SYSTEM_PROMPT
|
||||||
|
memory_block = "\n".join(f"- {m}" for m in memories)
|
||||||
|
return f"{_SYSTEM_PROMPT}\n\nUSER MEMORIES (facts the user has asked you to remember):\n{memory_block}"
|
||||||
|
|
||||||
|
|
||||||
def _build_messages_payload(
|
def _build_messages_payload(
|
||||||
conversation, query_text: str, image_description: str | None = None
|
conversation,
|
||||||
|
query_text: str,
|
||||||
|
image_description: str | None = None,
|
||||||
|
system_prompt: str | None = None,
|
||||||
) -> list:
|
) -> list:
|
||||||
recent_messages = (
|
recent_messages = (
|
||||||
conversation.messages[-10:]
|
conversation.messages[-10:]
|
||||||
if len(conversation.messages) > 10
|
if len(conversation.messages) > 10
|
||||||
else conversation.messages
|
else conversation.messages
|
||||||
)
|
)
|
||||||
messages_payload = [{"role": "system", "content": _SYSTEM_PROMPT}]
|
messages_payload = [{"role": "system", "content": system_prompt or _SYSTEM_PROMPT}]
|
||||||
for msg in recent_messages[:-1]: # Exclude the message we just added
|
for msg in recent_messages[:-1]: # Exclude the message we just added
|
||||||
role = "user" if msg.speaker == "user" else "assistant"
|
role = "user" if msg.speaker == "user" else "assistant"
|
||||||
text = msg.text
|
text = msg.text
|
||||||
@@ -79,10 +91,14 @@ async def query():
|
|||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages_payload = _build_messages_payload(conversation, query)
|
system_prompt = await _build_system_prompt_with_memories(str(user.id))
|
||||||
|
messages_payload = _build_messages_payload(
|
||||||
|
conversation, query, system_prompt=system_prompt
|
||||||
|
)
|
||||||
payload = {"messages": messages_payload}
|
payload = {"messages": messages_payload}
|
||||||
|
agent_config = {"configurable": {"user_id": str(user.id)}}
|
||||||
|
|
||||||
response = await main_agent.ainvoke(payload)
|
response = await main_agent.ainvoke(payload, config=agent_config)
|
||||||
message = response.get("messages", [])[-1].content
|
message = response.get("messages", [])[-1].content
|
||||||
await add_message_to_conversation(
|
await add_message_to_conversation(
|
||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
@@ -122,27 +138,14 @@ async def upload_image():
|
|||||||
|
|
||||||
await s3_upload_image(processed_bytes, key, output_content_type)
|
await s3_upload_image(processed_bytes, key, output_content_type)
|
||||||
|
|
||||||
return jsonify(
|
return jsonify({"image_key": key})
|
||||||
{
|
|
||||||
"image_key": key,
|
|
||||||
"image_url": f"/api/conversation/image/{key}",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@conversation_blueprint.get("/image/<path:image_key>")
|
@conversation_blueprint.get("/image/<path:image_key>")
|
||||||
@jwt_refresh_token_required
|
@jwt_refresh_token_required
|
||||||
async def serve_image(image_key: str):
|
async def serve_image(image_key: str):
|
||||||
try:
|
url = await s3_presigned_url(image_key)
|
||||||
image_bytes, content_type = await s3_get_image(image_key)
|
return jsonify({"url": url})
|
||||||
except Exception:
|
|
||||||
return jsonify({"error": "Image not found"}), 404
|
|
||||||
|
|
||||||
return Response(
|
|
||||||
image_bytes,
|
|
||||||
content_type=content_type,
|
|
||||||
headers={"Cache-Control": "private, max-age=3600"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@conversation_blueprint.post("/stream-query")
|
@conversation_blueprint.post("/stream-query")
|
||||||
@@ -175,15 +178,19 @@ async def stream_query():
|
|||||||
logging.error(f"Failed to analyze image: {e}")
|
logging.error(f"Failed to analyze image: {e}")
|
||||||
image_description = "[Image could not be analyzed]"
|
image_description = "[Image could not be analyzed]"
|
||||||
|
|
||||||
|
system_prompt = await _build_system_prompt_with_memories(str(user.id))
|
||||||
messages_payload = _build_messages_payload(
|
messages_payload = _build_messages_payload(
|
||||||
conversation, query_text or "", image_description
|
conversation, query_text or "", image_description, system_prompt=system_prompt
|
||||||
)
|
)
|
||||||
payload = {"messages": messages_payload}
|
payload = {"messages": messages_payload}
|
||||||
|
agent_config = {"configurable": {"user_id": str(user.id)}}
|
||||||
|
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
final_message = None
|
final_message = None
|
||||||
try:
|
try:
|
||||||
async for event in main_agent.astream_events(payload, version="v2"):
|
async for event in main_agent.astream_events(
|
||||||
|
payload, version="v2", config=agent_config
|
||||||
|
):
|
||||||
event_type = event.get("event")
|
event_type = event.get("event")
|
||||||
|
|
||||||
if event_type == "on_tool_start":
|
if event_type == "on_tool_start":
|
||||||
@@ -233,8 +240,6 @@ async def stream_query():
|
|||||||
@jwt_refresh_token_required
|
@jwt_refresh_token_required
|
||||||
async def get_conversation(conversation_id: str):
|
async def get_conversation(conversation_id: str):
|
||||||
conversation = await Conversation.get(id=conversation_id)
|
conversation = await Conversation.get(id=conversation_id)
|
||||||
current_user_uuid = get_jwt_identity()
|
|
||||||
user = await blueprints.users.models.User.get(id=current_user_uuid)
|
|
||||||
await conversation.fetch_related("messages")
|
await conversation.fetch_related("messages")
|
||||||
|
|
||||||
# Manually serialize the conversation with messages
|
# Manually serialize the conversation with messages
|
||||||
@@ -249,18 +254,10 @@ async def get_conversation(conversation_id: str):
|
|||||||
"image_key": msg.image_key,
|
"image_key": msg.image_key,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
name = conversation.name
|
|
||||||
if len(messages) > 8 and "datetime" in name.lower():
|
|
||||||
name = await rename_conversation(
|
|
||||||
user=user,
|
|
||||||
conversation=conversation,
|
|
||||||
)
|
|
||||||
print(name)
|
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(
|
||||||
{
|
{
|
||||||
"id": str(conversation.id),
|
"id": str(conversation.id),
|
||||||
"name": name,
|
"name": conversation.name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"created_at": conversation.created_at.isoformat(),
|
"created_at": conversation.created_at.isoformat(),
|
||||||
"updated_at": conversation.updated_at.isoformat(),
|
"updated_at": conversation.updated_at.isoformat(),
|
||||||
@@ -274,7 +271,7 @@ async def create_conversation():
|
|||||||
user_uuid = get_jwt_identity()
|
user_uuid = get_jwt_identity()
|
||||||
user = await blueprints.users.models.User.get(id=user_uuid)
|
user = await blueprints.users.models.User.get(id=user_uuid)
|
||||||
conversation = await Conversation.create(
|
conversation = await Conversation.create(
|
||||||
name=f"{user.username} {datetime.datetime.now().timestamp}",
|
name="New Conversation",
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -287,7 +284,7 @@ async def create_conversation():
|
|||||||
async def get_all_conversations():
|
async def get_all_conversations():
|
||||||
user_uuid = get_jwt_identity()
|
user_uuid = get_jwt_identity()
|
||||||
user = await blueprints.users.models.User.get(id=user_uuid)
|
user = await blueprints.users.models.User.get(id=user_uuid)
|
||||||
conversations = Conversation.filter(user=user)
|
conversations = Conversation.filter(user=user).order_by("-updated_at")
|
||||||
serialized_conversations = await PydListConversation.from_queryset(conversations)
|
serialized_conversations = await PydListConversation.from_queryset(conversations)
|
||||||
|
|
||||||
return jsonify(serialized_conversations.model_dump())
|
return jsonify(serialized_conversations.model_dump())
|
||||||
|
|||||||
@@ -5,9 +5,11 @@ from dotenv import load_dotenv
|
|||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain.chat_models import BaseChatModel
|
from langchain.chat_models import BaseChatModel
|
||||||
from langchain.tools import tool
|
from langchain.tools import tool
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from tavily import AsyncTavilyClient
|
from tavily import AsyncTavilyClient
|
||||||
|
|
||||||
|
from blueprints.conversation.memory import save_memory
|
||||||
from blueprints.rag.logic import query_vector_store
|
from blueprints.rag.logic import query_vector_store
|
||||||
from utils.obsidian_service import ObsidianService
|
from utils.obsidian_service import ObsidianService
|
||||||
from utils.ynab_service import YNABService
|
from utils.ynab_service import YNABService
|
||||||
@@ -326,7 +328,7 @@ async def obsidian_search_notes(query: str) -> str:
|
|||||||
return "Obsidian integration is not configured. Please set OBSIDIAN_VAULT_PATH environment variable."
|
return "Obsidian integration is not configured. Please set OBSIDIAN_VAULT_PATH environment variable."
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Query ChromaDB for obsidian documents
|
# Query vector store for obsidian documents
|
||||||
serialized, docs = await query_vector_store(query=query)
|
serialized, docs = await query_vector_store(query=query)
|
||||||
return serialized
|
return serialized
|
||||||
|
|
||||||
@@ -589,8 +591,35 @@ async def obsidian_create_task(
|
|||||||
return f"Error creating task: {str(e)}"
|
return f"Error creating task: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def save_user_memory(content: str, config: RunnableConfig) -> str:
|
||||||
|
"""Save a fact or preference about the user for future conversations.
|
||||||
|
|
||||||
|
Use this tool when the user:
|
||||||
|
- Explicitly asks you to remember something ("remember that...", "keep in mind...")
|
||||||
|
- Shares a personal preference that would be useful in future conversations
|
||||||
|
(e.g., "I prefer metric units", "my cat's name is Luna")
|
||||||
|
- Tells you a meaningful personal fact (e.g., "I'm allergic to peanuts")
|
||||||
|
|
||||||
|
Do NOT save:
|
||||||
|
- Trivial or ephemeral info (e.g., "I'm tired today")
|
||||||
|
- Information already in the system prompt or documents
|
||||||
|
- Conversation-specific context that won't matter later
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: A concise statement of the fact or preference to remember.
|
||||||
|
Write it as a standalone sentence (e.g., "User prefers dark mode"
|
||||||
|
rather than "likes dark mode").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Confirmation that the memory was saved.
|
||||||
|
"""
|
||||||
|
user_id = config["configurable"]["user_id"]
|
||||||
|
return await save_memory(user_id=user_id, content=content)
|
||||||
|
|
||||||
|
|
||||||
# Create tools list based on what's available
|
# Create tools list based on what's available
|
||||||
tools = [get_current_date, simba_search, web_search]
|
tools = [get_current_date, simba_search, web_search, save_user_memory]
|
||||||
if ynab_enabled:
|
if ynab_enabled:
|
||||||
tools.extend(
|
tools.extend(
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
import tortoise.exceptions
|
import tortoise.exceptions
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
import blueprints.users.models
|
import blueprints.users.models
|
||||||
|
|
||||||
from .models import Conversation, ConversationMessage, RenameConversationOutputSchema
|
from .models import Conversation, ConversationMessage
|
||||||
|
|
||||||
|
|
||||||
async def create_conversation(name: str = "") -> Conversation:
|
async def create_conversation(name: str = "") -> Conversation:
|
||||||
@@ -19,6 +18,12 @@ async def add_message_to_conversation(
|
|||||||
image_key: str | None = None,
|
image_key: str | None = None,
|
||||||
) -> ConversationMessage:
|
) -> ConversationMessage:
|
||||||
print(conversation, message, speaker)
|
print(conversation, message, speaker)
|
||||||
|
|
||||||
|
# Name the conversation after the first user message
|
||||||
|
if speaker == "user" and not await conversation.messages.all().exists():
|
||||||
|
conversation.name = message[:100]
|
||||||
|
await conversation.save()
|
||||||
|
|
||||||
message = await ConversationMessage.create(
|
message = await ConversationMessage.create(
|
||||||
text=message,
|
text=message,
|
||||||
speaker=speaker,
|
speaker=speaker,
|
||||||
@@ -61,22 +66,3 @@ async def get_conversation_transcript(
|
|||||||
messages.append(f"{message.speaker} at {message.created_at}: {message.text}")
|
messages.append(f"{message.speaker} at {message.created_at}: {message.text}")
|
||||||
|
|
||||||
return "\n".join(messages)
|
return "\n".join(messages)
|
||||||
|
|
||||||
|
|
||||||
async def rename_conversation(
|
|
||||||
user: blueprints.users.models.User,
|
|
||||||
conversation: Conversation,
|
|
||||||
) -> str:
|
|
||||||
messages: str = await get_conversation_transcript(
|
|
||||||
user=user, conversation=conversation
|
|
||||||
)
|
|
||||||
|
|
||||||
llm = ChatOpenAI(model="gpt-4o-mini")
|
|
||||||
structured_llm = llm.with_structured_output(RenameConversationOutputSchema)
|
|
||||||
|
|
||||||
prompt = f"Summarize the following conversation into a sassy one-liner title:\n\n{messages}"
|
|
||||||
response = structured_llm.invoke(prompt)
|
|
||||||
new_name: str = response.get("title", "")
|
|
||||||
conversation.name = new_name
|
|
||||||
await conversation.save()
|
|
||||||
return new_name
|
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
from .models import UserMemory
|
||||||
|
|
||||||
|
|
||||||
|
async def get_memories_for_user(user_id: str) -> list[str]:
|
||||||
|
"""Return all memory content strings for a user, ordered by most recently updated."""
|
||||||
|
memories = await UserMemory.filter(user_id=user_id).order_by("-updated_at")
|
||||||
|
return [m.content for m in memories]
|
||||||
|
|
||||||
|
|
||||||
|
async def save_memory(user_id: str, content: str) -> str:
|
||||||
|
"""Save a new memory or touch an existing one (exact-match dedup)."""
|
||||||
|
existing = await UserMemory.filter(user_id=user_id, content=content).first()
|
||||||
|
if existing:
|
||||||
|
existing.updated_at = None # auto_now=True will refresh it on save
|
||||||
|
await existing.save(update_fields=["updated_at"])
|
||||||
|
return "Memory already exists (refreshed)."
|
||||||
|
|
||||||
|
await UserMemory.create(user_id=user_id, content=content)
|
||||||
|
return "Memory saved."
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
import enum
|
import enum
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
from tortoise.contrib.pydantic import (
|
from tortoise.contrib.pydantic import (
|
||||||
@@ -9,12 +8,6 @@ from tortoise.contrib.pydantic import (
|
|||||||
from tortoise.models import Model
|
from tortoise.models import Model
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RenameConversationOutputSchema:
|
|
||||||
title: str
|
|
||||||
justification: str
|
|
||||||
|
|
||||||
|
|
||||||
class Speaker(enum.Enum):
|
class Speaker(enum.Enum):
|
||||||
USER = "user"
|
USER = "user"
|
||||||
SIMBA = "simba"
|
SIMBA = "simba"
|
||||||
@@ -47,6 +40,17 @@ class ConversationMessage(Model):
|
|||||||
table = "conversation_messages"
|
table = "conversation_messages"
|
||||||
|
|
||||||
|
|
||||||
|
class UserMemory(Model):
|
||||||
|
id = fields.UUIDField(primary_key=True)
|
||||||
|
user = fields.ForeignKeyField("models.User", related_name="memories")
|
||||||
|
content = fields.TextField()
|
||||||
|
created_at = fields.DatetimeField(auto_now_add=True)
|
||||||
|
updated_at = fields.DatetimeField(auto_now=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
table = "user_memories"
|
||||||
|
|
||||||
|
|
||||||
PydConversationMessage = pydantic_model_creator(ConversationMessage)
|
PydConversationMessage = pydantic_model_creator(ConversationMessage)
|
||||||
PydConversation = pydantic_model_creator(
|
PydConversation = pydantic_model_creator(
|
||||||
Conversation, name="Conversation", allow_cycles=True, exclude=("user",)
|
Conversation, name="Conversation", allow_cycles=True, exclude=("user",)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
SIMBA_SYSTEM_PROMPT = """You are a helpful cat assistant named Simba that understands veterinary terms. When there are questions to you specifically, they are referring to Simba the cat. Answer the user in as if you were a cat named Simba. Don't act too catlike. Be assertive.
|
SIMBA_SYSTEM_PROMPT = """You are Simba, Ryan's helpful personal assistant. You're named after his orange cat. You have a warm, friendly personality with a light cat-themed touch, but your priority is always being genuinely useful — give thorough, detailed answers and think things through carefully. When asked about Simba the cat, you speak as him in first person. For everything else, you're just a great assistant who happens to have a cat's name.
|
||||||
|
|
||||||
SIMBA FACTS (as of January 2026):
|
SIMBA FACTS (as of January 2026):
|
||||||
- Name: Simba
|
- Name: Simba
|
||||||
@@ -54,4 +54,7 @@ You have access to Ryan's daily journal notes. Each note lives at journal/YYYY/Y
|
|||||||
- Use journal_get_tasks to list tasks (done/pending) for today or a specific date
|
- Use journal_get_tasks to list tasks (done/pending) for today or a specific date
|
||||||
- Use journal_add_task to add a new task to today's (or a given date's) note
|
- Use journal_add_task to add a new task to today's (or a given date's) note
|
||||||
- Use journal_complete_task to check off a task as done
|
- Use journal_complete_task to check off a task as done
|
||||||
Use these tools when Ryan asks about today's tasks, wants to add something to his list, or wants to mark a task complete."""
|
Use these tools when Ryan asks about today's tasks, wants to add something to his list, or wants to mark a task complete.
|
||||||
|
|
||||||
|
USER MEMORY:
|
||||||
|
You can remember facts about the user across conversations using the save_user_memory tool. When a user explicitly asks you to remember something, or shares a meaningful preference or personal fact, save it. Saved memories will automatically appear at the end of this prompt in future conversations under "USER MEMORIES"."""
|
||||||
|
|||||||
@@ -1,7 +1,12 @@
|
|||||||
from quart import Blueprint, jsonify
|
from quart import Blueprint, jsonify
|
||||||
from quart_jwt_extended import jwt_refresh_token_required
|
from quart_jwt_extended import jwt_refresh_token_required
|
||||||
|
|
||||||
from .logic import fetch_obsidian_documents, get_vector_store_stats, index_documents, index_obsidian_documents, vector_store
|
from .logic import (
|
||||||
|
delete_all_documents,
|
||||||
|
get_vector_store_stats,
|
||||||
|
index_documents,
|
||||||
|
index_obsidian_documents,
|
||||||
|
)
|
||||||
from blueprints.users.decorators import admin_required
|
from blueprints.users.decorators import admin_required
|
||||||
|
|
||||||
rag_blueprint = Blueprint("rag_api", __name__, url_prefix="/api/rag")
|
rag_blueprint = Blueprint("rag_api", __name__, url_prefix="/api/rag")
|
||||||
@@ -32,14 +37,7 @@ async def trigger_index():
|
|||||||
async def trigger_reindex():
|
async def trigger_reindex():
|
||||||
"""Clear and reindex all documents. Admin only."""
|
"""Clear and reindex all documents. Admin only."""
|
||||||
try:
|
try:
|
||||||
# Clear existing documents
|
delete_all_documents()
|
||||||
collection = vector_store._collection
|
|
||||||
all_docs = collection.get()
|
|
||||||
|
|
||||||
if all_docs["ids"]:
|
|
||||||
collection.delete(ids=all_docs["ids"])
|
|
||||||
|
|
||||||
# Reindex
|
|
||||||
await index_documents()
|
await index_documents()
|
||||||
stats = get_vector_store_stats()
|
stats = get_vector_store_stats()
|
||||||
return jsonify({"status": "success", "stats": stats})
|
return jsonify({"status": "success", "stats": stats})
|
||||||
|
|||||||
+125
-29
@@ -1,11 +1,13 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from langchain_chroma import Chroma
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_openai import OpenAIEmbeddings
|
from langchain_openai import OpenAIEmbeddings
|
||||||
|
from langchain_postgres import PGVector
|
||||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
|
||||||
from .fetchers import PaperlessNGXService
|
from .fetchers import PaperlessNGXService
|
||||||
from utils.obsidian_service import ObsidianService
|
from utils.obsidian_service import ObsidianService
|
||||||
@@ -13,13 +15,40 @@ from utils.obsidian_service import ObsidianService
|
|||||||
# Load environment variables
|
# Load environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
|
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
|
||||||
|
|
||||||
vector_store = Chroma(
|
# Convert Tortoise-style postgres:// URL to SQLAlchemy-style postgresql+psycopg://
|
||||||
collection_name="simba_docs",
|
_db_url = os.getenv(
|
||||||
embedding_function=embeddings,
|
"DATABASE_URL", "postgres://raggr:raggr_dev_password@localhost:5432/raggr"
|
||||||
persist_directory=os.getenv("CHROMADB_PATH", ""),
|
|
||||||
)
|
)
|
||||||
|
_pgvector_url = _db_url.replace("postgres://", "postgresql+psycopg://")
|
||||||
|
|
||||||
|
# Lazy-initialized vector store (defers DB connection to first use)
|
||||||
|
_vector_store = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_vector_store() -> PGVector:
|
||||||
|
global _vector_store
|
||||||
|
if _vector_store is None:
|
||||||
|
_vector_store = PGVector(
|
||||||
|
embeddings=embeddings,
|
||||||
|
collection_name="simba_docs",
|
||||||
|
connection=_pgvector_url,
|
||||||
|
use_jsonb=True,
|
||||||
|
create_extension=False, # created by docker init script
|
||||||
|
async_mode=True,
|
||||||
|
)
|
||||||
|
return _vector_store
|
||||||
|
|
||||||
|
|
||||||
|
def _get_engine():
|
||||||
|
"""Get a SQLAlchemy engine for direct queries."""
|
||||||
|
if not hasattr(_get_engine, "_engine"):
|
||||||
|
_get_engine._engine = create_engine(_pgvector_url)
|
||||||
|
return _get_engine._engine
|
||||||
|
|
||||||
|
|
||||||
text_splitter = RecursiveCharacterTextSplitter(
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
chunk_size=1000, # chunk size (characters)
|
chunk_size=1000, # chunk size (characters)
|
||||||
@@ -28,6 +57,22 @@ text_splitter = RecursiveCharacterTextSplitter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_collection_id():
|
||||||
|
"""Get the UUID of our collection from the langchain_pg_collection table."""
|
||||||
|
engine = _get_engine()
|
||||||
|
try:
|
||||||
|
with engine.connect() as conn:
|
||||||
|
result = conn.execute(
|
||||||
|
text("SELECT uuid FROM langchain_pg_collection WHERE name = :name"),
|
||||||
|
{"name": "simba_docs"},
|
||||||
|
)
|
||||||
|
row = result.fetchone()
|
||||||
|
return row[0] if row else None
|
||||||
|
except Exception:
|
||||||
|
# Table doesn't exist yet (first run before any indexing)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def date_to_epoch(date_str: str) -> float:
|
def date_to_epoch(date_str: str) -> float:
|
||||||
split_date = date_str.split("-")
|
split_date = date_str.split("-")
|
||||||
date = datetime.datetime(
|
date = datetime.datetime(
|
||||||
@@ -63,6 +108,7 @@ async def index_documents():
|
|||||||
documents = await fetch_documents_from_paperless_ngx()
|
documents = await fetch_documents_from_paperless_ngx()
|
||||||
|
|
||||||
splits = text_splitter.split_documents(documents)
|
splits = text_splitter.split_documents(documents)
|
||||||
|
vector_store = _get_vector_store()
|
||||||
await vector_store.aadd_documents(documents=splits)
|
await vector_store.aadd_documents(documents=splits)
|
||||||
|
|
||||||
|
|
||||||
@@ -92,13 +138,17 @@ async def fetch_obsidian_documents() -> list[Document]:
|
|||||||
"filepath": parsed["filepath"],
|
"filepath": parsed["filepath"],
|
||||||
"tags": parsed["tags"],
|
"tags": parsed["tags"],
|
||||||
"created_at": parsed["metadata"].get("created_at"),
|
"created_at": parsed["metadata"].get("created_at"),
|
||||||
**{k: v for k, v in parsed["metadata"].items() if k not in ["created_at", "created_by"]},
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in parsed["metadata"].items()
|
||||||
|
if k not in ["created_at", "created_by"]
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
documents.append(document)
|
documents.append(document)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error reading {md_path}: {e}")
|
logger.warning(f"Error reading {md_path}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
@@ -109,26 +159,25 @@ async def index_obsidian_documents():
|
|||||||
|
|
||||||
Deletes existing obsidian source chunks before re-indexing.
|
Deletes existing obsidian source chunks before re-indexing.
|
||||||
"""
|
"""
|
||||||
obsidian_service = ObsidianService()
|
|
||||||
documents = await fetch_obsidian_documents()
|
documents = await fetch_obsidian_documents()
|
||||||
|
|
||||||
if not documents:
|
if not documents:
|
||||||
print("No Obsidian documents found to index")
|
logger.info("No Obsidian documents found to index")
|
||||||
return {"indexed": 0}
|
return {"indexed": 0}
|
||||||
|
|
||||||
# Delete existing obsidian chunks
|
# Delete existing obsidian chunks
|
||||||
existing_results = vector_store.get(where={"source": "obsidian"})
|
delete_documents_by_metadata("source", "obsidian")
|
||||||
if existing_results.get("ids"):
|
|
||||||
await vector_store.adelete(existing_results["ids"])
|
|
||||||
|
|
||||||
# Split and index documents
|
# Split and index documents
|
||||||
splits = text_splitter.split_documents(documents)
|
splits = text_splitter.split_documents(documents)
|
||||||
|
vector_store = _get_vector_store()
|
||||||
await vector_store.aadd_documents(documents=splits)
|
await vector_store.aadd_documents(documents=splits)
|
||||||
|
|
||||||
return {"indexed": len(documents)}
|
return {"indexed": len(documents)}
|
||||||
|
|
||||||
|
|
||||||
async def query_vector_store(query: str):
|
async def query_vector_store(query: str):
|
||||||
|
vector_store = _get_vector_store()
|
||||||
retrieved_docs = await vector_store.asimilarity_search(query, k=2)
|
retrieved_docs = await vector_store.asimilarity_search(query, k=2)
|
||||||
serialized = "\n\n".join(
|
serialized = "\n\n".join(
|
||||||
(f"Source: {doc.metadata}\nContent: {doc.page_content}")
|
(f"Source: {doc.metadata}\nContent: {doc.page_content}")
|
||||||
@@ -137,33 +186,80 @@ async def query_vector_store(query: str):
|
|||||||
return serialized, retrieved_docs
|
return serialized, retrieved_docs
|
||||||
|
|
||||||
|
|
||||||
|
def delete_all_documents():
|
||||||
|
"""Delete all documents from the vector store collection."""
|
||||||
|
collection_id = _get_collection_id()
|
||||||
|
if not collection_id:
|
||||||
|
return
|
||||||
|
engine = _get_engine()
|
||||||
|
with engine.connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
text("DELETE FROM langchain_pg_embedding WHERE collection_id = :cid"),
|
||||||
|
{"cid": collection_id},
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def delete_documents_by_metadata(key: str, value: str):
|
||||||
|
"""Delete documents matching a metadata key/value pair."""
|
||||||
|
collection_id = _get_collection_id()
|
||||||
|
if not collection_id:
|
||||||
|
return
|
||||||
|
engine = _get_engine()
|
||||||
|
with engine.connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"DELETE FROM langchain_pg_embedding "
|
||||||
|
"WHERE collection_id = :cid AND cmetadata->>:key = :value"
|
||||||
|
),
|
||||||
|
{"cid": collection_id, "key": key, "value": value},
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
def get_vector_store_stats():
|
def get_vector_store_stats():
|
||||||
"""Get statistics about the vector store."""
|
"""Get statistics about the vector store."""
|
||||||
collection = vector_store._collection
|
collection_id = _get_collection_id()
|
||||||
count = collection.count()
|
count = 0
|
||||||
|
if collection_id:
|
||||||
|
engine = _get_engine()
|
||||||
|
with engine.connect() as conn:
|
||||||
|
result = conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT COUNT(*) FROM langchain_pg_embedding WHERE collection_id = :cid"
|
||||||
|
),
|
||||||
|
{"cid": collection_id},
|
||||||
|
)
|
||||||
|
count = result.scalar()
|
||||||
return {
|
return {
|
||||||
"total_documents": count,
|
"total_documents": count,
|
||||||
"collection_name": collection.name,
|
"collection_name": "simba_docs",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def list_all_documents(limit: int = 10):
|
def list_all_documents(limit: int = 10):
|
||||||
"""List documents in the vector store with their metadata."""
|
"""List documents in the vector store with their metadata."""
|
||||||
collection = vector_store._collection
|
collection_id = _get_collection_id()
|
||||||
results = collection.get(limit=limit, include=["metadatas", "documents"])
|
if not collection_id:
|
||||||
|
return []
|
||||||
|
|
||||||
documents = []
|
engine = _get_engine()
|
||||||
for i, doc_id in enumerate(results["ids"]):
|
with engine.connect() as conn:
|
||||||
documents.append(
|
result = conn.execute(
|
||||||
{
|
text(
|
||||||
"id": doc_id,
|
"SELECT id, document, cmetadata FROM langchain_pg_embedding "
|
||||||
"metadata": results["metadatas"][i]
|
"WHERE collection_id = :cid LIMIT :limit"
|
||||||
if results.get("metadatas")
|
),
|
||||||
else None,
|
{"cid": collection_id, "limit": limit},
|
||||||
"content_preview": results["documents"][i][:200]
|
|
||||||
if results.get("documents")
|
|
||||||
else None,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
documents = []
|
||||||
|
for row in result:
|
||||||
|
documents.append(
|
||||||
|
{
|
||||||
|
"id": str(row[0]),
|
||||||
|
"metadata": row[2],
|
||||||
|
"content_preview": row[1][:200] if row[1] else None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class OIDCUserService:
|
|||||||
claims.get("preferred_username") or claims.get("name") or user.username
|
claims.get("preferred_username") or claims.get("name") or user.username
|
||||||
)
|
)
|
||||||
# Update LDAP groups from claims
|
# Update LDAP groups from claims
|
||||||
user.ldap_groups = claims.get("groups", [])
|
user.ldap_groups = claims.get("groups") or []
|
||||||
await user.save()
|
await user.save()
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@@ -48,7 +48,7 @@ class OIDCUserService:
|
|||||||
user.oidc_subject = oidc_subject
|
user.oidc_subject = oidc_subject
|
||||||
user.auth_provider = "oidc"
|
user.auth_provider = "oidc"
|
||||||
user.password = None # Clear password
|
user.password = None # Clear password
|
||||||
user.ldap_groups = claims.get("groups", [])
|
user.ldap_groups = claims.get("groups") or []
|
||||||
await user.save()
|
await user.save()
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@@ -61,7 +61,7 @@ class OIDCUserService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Extract LDAP groups from claims
|
# Extract LDAP groups from claims
|
||||||
groups = claims.get("groups", [])
|
groups = claims.get("groups") or []
|
||||||
|
|
||||||
user = await User.create(
|
user = await User.create(
|
||||||
id=uuid4(),
|
id=uuid4(),
|
||||||
|
|||||||
+2
-4
@@ -2,7 +2,7 @@ version: "3.8"
|
|||||||
|
|
||||||
services:
|
services:
|
||||||
postgres:
|
postgres:
|
||||||
image: postgres:16-alpine
|
image: pgvector/pgvector:pg16
|
||||||
ports:
|
ports:
|
||||||
- "5432:5432"
|
- "5432:5432"
|
||||||
environment:
|
environment:
|
||||||
@@ -11,6 +11,7 @@ services:
|
|||||||
- POSTGRES_DB=${POSTGRES_DB:-raggr}
|
- POSTGRES_DB=${POSTGRES_DB:-raggr}
|
||||||
volumes:
|
volumes:
|
||||||
- postgres_data:/var/lib/postgresql/data
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
- ./docker/init-pgvector.sql:/docker-entrypoint-initdb.d/init-pgvector.sql
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-raggr}"]
|
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-raggr}"]
|
||||||
interval: 10s
|
interval: 10s
|
||||||
@@ -29,7 +30,6 @@ services:
|
|||||||
- PAPERLESS_TOKEN=${PAPERLESS_TOKEN}
|
- PAPERLESS_TOKEN=${PAPERLESS_TOKEN}
|
||||||
- BASE_URL=${BASE_URL}
|
- BASE_URL=${BASE_URL}
|
||||||
- OLLAMA_URL=${OLLAMA_URL:-http://localhost:11434}
|
- OLLAMA_URL=${OLLAMA_URL:-http://localhost:11434}
|
||||||
- CHROMADB_PATH=/app/data/chromadb
|
|
||||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||||
- JWT_SECRET_KEY=${JWT_SECRET_KEY}
|
- JWT_SECRET_KEY=${JWT_SECRET_KEY}
|
||||||
- LLAMA_SERVER_URL=${LLAMA_SERVER_URL}
|
- LLAMA_SERVER_URL=${LLAMA_SERVER_URL}
|
||||||
@@ -66,10 +66,8 @@ services:
|
|||||||
postgres:
|
postgres:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
volumes:
|
volumes:
|
||||||
- chromadb_data:/app/data/chromadb
|
|
||||||
- ./obvault:/app/data/obsidian
|
- ./obvault:/app/data/obsidian
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
chromadb_data:
|
|
||||||
postgres_data:
|
postgres_data:
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
CREATE EXTENSION IF NOT EXISTS vector;
|
||||||
@@ -1,278 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import datetime
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sqlite3
|
|
||||||
import time
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
import chromadb
|
|
||||||
from utils.chunker import Chunker
|
|
||||||
from utils.cleaner import pdf_to_image, summarize_pdf_image
|
|
||||||
from llm import LLMClient
|
|
||||||
from scripts.query import QueryGenerator
|
|
||||||
from utils.request import PaperlessNGXService
|
|
||||||
|
|
||||||
_dotenv_loaded = load_dotenv()
|
|
||||||
|
|
||||||
client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
|
|
||||||
simba_docs = client.get_or_create_collection(name="simba_docs2")
|
|
||||||
feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup")
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="An LLM tool to query information about Simba <3"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument("query", type=str, help="questions about simba's health")
|
|
||||||
parser.add_argument(
|
|
||||||
"--reindex", action="store_true", help="re-index the simba documents"
|
|
||||||
)
|
|
||||||
parser.add_argument("--classify", action="store_true", help="test classification")
|
|
||||||
parser.add_argument("--index", help="index a file")
|
|
||||||
|
|
||||||
ppngx = PaperlessNGXService()
|
|
||||||
|
|
||||||
llm_client = LLMClient()
|
|
||||||
|
|
||||||
|
|
||||||
def index_using_pdf_llm(doctypes):
|
|
||||||
logging.info("reindex data...")
|
|
||||||
files = ppngx.get_data()
|
|
||||||
for file in files:
|
|
||||||
document_id: int = file["id"]
|
|
||||||
pdf_path = ppngx.download_pdf_from_id(id=document_id)
|
|
||||||
image_paths = pdf_to_image(filepath=pdf_path)
|
|
||||||
logging.info(f"summarizing {file}")
|
|
||||||
generated_summary = summarize_pdf_image(filepaths=image_paths)
|
|
||||||
file["content"] = generated_summary
|
|
||||||
|
|
||||||
chunk_data(files, simba_docs, doctypes=doctypes)
|
|
||||||
|
|
||||||
|
|
||||||
def date_to_epoch(date_str: str) -> float:
|
|
||||||
split_date = date_str.split("-")
|
|
||||||
date = datetime.datetime(
|
|
||||||
int(split_date[0]),
|
|
||||||
int(split_date[1]),
|
|
||||||
int(split_date[2]),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
|
|
||||||
return date.timestamp()
|
|
||||||
|
|
||||||
|
|
||||||
def chunk_data(docs, collection, doctypes):
|
|
||||||
# Step 2: Create chunks
|
|
||||||
chunker = Chunker(collection)
|
|
||||||
|
|
||||||
logging.info(f"chunking {len(docs)} documents")
|
|
||||||
texts: list[str] = [doc["content"] for doc in docs]
|
|
||||||
with sqlite3.connect("database/visited.db") as conn:
|
|
||||||
to_insert = []
|
|
||||||
c = conn.cursor()
|
|
||||||
for index, text in enumerate(texts):
|
|
||||||
metadata = {
|
|
||||||
"created_date": date_to_epoch(docs[index]["created_date"]),
|
|
||||||
"filename": docs[index]["original_file_name"],
|
|
||||||
"document_type": doctypes.get(docs[index]["document_type"], ""),
|
|
||||||
}
|
|
||||||
|
|
||||||
if doctypes:
|
|
||||||
metadata["type"] = doctypes.get(docs[index]["document_type"])
|
|
||||||
|
|
||||||
chunker.chunk_document(
|
|
||||||
document=text,
|
|
||||||
metadata=metadata,
|
|
||||||
)
|
|
||||||
to_insert.append((docs[index]["id"],))
|
|
||||||
|
|
||||||
c.executemany(
|
|
||||||
"INSERT INTO indexed_documents (paperless_id) values (?)", to_insert
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def chunk_text(texts: list[str], collection):
|
|
||||||
chunker = Chunker(collection)
|
|
||||||
|
|
||||||
for index, text in enumerate(texts):
|
|
||||||
metadata = {}
|
|
||||||
chunker.chunk_document(
|
|
||||||
document=text,
|
|
||||||
metadata=metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def classify_query(query: str, transcript: str) -> bool:
|
|
||||||
logging.info("Starting query generation")
|
|
||||||
qg_start = time.time()
|
|
||||||
qg = QueryGenerator()
|
|
||||||
query_type = qg.get_query_type(input=query, transcript=transcript)
|
|
||||||
logging.info(query_type)
|
|
||||||
qg_end = time.time()
|
|
||||||
logging.info(f"Query generation took {qg_end - qg_start:.2f} seconds")
|
|
||||||
return query_type == "Simba"
|
|
||||||
|
|
||||||
|
|
||||||
def consult_oracle(
|
|
||||||
input: str,
|
|
||||||
collection,
|
|
||||||
transcript: str = "",
|
|
||||||
):
|
|
||||||
chunker = Chunker(collection)
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
# Ask
|
|
||||||
logging.info("Starting query generation")
|
|
||||||
qg_start = time.time()
|
|
||||||
qg = QueryGenerator()
|
|
||||||
doctype_query = qg.get_doctype_query(input=input)
|
|
||||||
# metadata_filter = qg.get_query(input)
|
|
||||||
metadata_filter = {**doctype_query}
|
|
||||||
logging.info(metadata_filter)
|
|
||||||
qg_end = time.time()
|
|
||||||
logging.info(f"Query generation took {qg_end - qg_start:.2f} seconds")
|
|
||||||
|
|
||||||
logging.info("Starting embedding generation")
|
|
||||||
embedding_start = time.time()
|
|
||||||
embeddings = chunker.embedding_fx(inputs=[input])
|
|
||||||
embedding_end = time.time()
|
|
||||||
logging.info(
|
|
||||||
f"Embedding generation took {embedding_end - embedding_start:.2f} seconds"
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info("Starting collection query")
|
|
||||||
query_start = time.time()
|
|
||||||
results = collection.query(
|
|
||||||
query_texts=[input],
|
|
||||||
query_embeddings=embeddings,
|
|
||||||
where=metadata_filter,
|
|
||||||
)
|
|
||||||
query_end = time.time()
|
|
||||||
logging.info(f"Collection query took {query_end - query_start:.2f} seconds")
|
|
||||||
|
|
||||||
# Generate
|
|
||||||
logging.info("Starting LLM generation")
|
|
||||||
llm_start = time.time()
|
|
||||||
system_prompt = "You are a helpful assistant that understands veterinary terms."
|
|
||||||
transcript_prompt = f"Here is the message transcript thus far {transcript}."
|
|
||||||
prompt = f"""Using the following data, help answer the user's query by providing as many details as possible.
|
|
||||||
Using this data: {results}. {transcript_prompt if len(transcript) > 0 else ""}
|
|
||||||
Respond to this prompt: {input}"""
|
|
||||||
output = llm_client.chat(prompt=prompt, system_prompt=system_prompt)
|
|
||||||
llm_end = time.time()
|
|
||||||
logging.info(f"LLM generation took {llm_end - llm_start:.2f} seconds")
|
|
||||||
|
|
||||||
total_time = time.time() - start_time
|
|
||||||
logging.info(f"Total consult_oracle execution took {total_time:.2f} seconds")
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def llm_chat(input: str, transcript: str = "") -> str:
|
|
||||||
system_prompt = "You are a helpful assistant that understands veterinary terms."
|
|
||||||
transcript_prompt = f"Here is the message transcript thus far {transcript}."
|
|
||||||
prompt = f"""Answer the user in as if you were a cat named Simba. Don't act too catlike. Be assertive.
|
|
||||||
{transcript_prompt if len(transcript) > 0 else ""}
|
|
||||||
Respond to this prompt: {input}"""
|
|
||||||
output = llm_client.chat(prompt=prompt, system_prompt=system_prompt)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def paperless_workflow(input):
|
|
||||||
# Step 1: Get the text
|
|
||||||
ppngx = PaperlessNGXService()
|
|
||||||
docs = ppngx.get_data()
|
|
||||||
|
|
||||||
chunk_data(docs, collection=simba_docs)
|
|
||||||
consult_oracle(input, simba_docs)
|
|
||||||
|
|
||||||
|
|
||||||
def consult_simba_oracle(input: str, transcript: str = ""):
|
|
||||||
is_simba_related = classify_query(query=input, transcript=transcript)
|
|
||||||
|
|
||||||
if is_simba_related:
|
|
||||||
logging.info("Query is related to simba")
|
|
||||||
return consult_oracle(
|
|
||||||
input=input,
|
|
||||||
collection=simba_docs,
|
|
||||||
transcript=transcript,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info("Query is NOT related to simba")
|
|
||||||
|
|
||||||
return llm_chat(input=input, transcript=transcript)
|
|
||||||
|
|
||||||
|
|
||||||
def filter_indexed_files(docs):
|
|
||||||
with sqlite3.connect("database/visited.db") as conn:
|
|
||||||
c = conn.cursor()
|
|
||||||
c.execute(
|
|
||||||
"CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)"
|
|
||||||
)
|
|
||||||
c.execute("SELECT paperless_id FROM indexed_documents")
|
|
||||||
rows = c.fetchall()
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
visited = {row[0] for row in rows}
|
|
||||||
return [doc for doc in docs if doc["id"] not in visited]
|
|
||||||
|
|
||||||
|
|
||||||
def reindex():
|
|
||||||
with sqlite3.connect("database/visited.db") as conn:
|
|
||||||
c = conn.cursor()
|
|
||||||
# Ensure the table exists before trying to delete from it
|
|
||||||
c.execute(
|
|
||||||
"CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)"
|
|
||||||
)
|
|
||||||
c.execute("DELETE FROM indexed_documents")
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
# Delete all documents from the collection
|
|
||||||
all_docs = simba_docs.get()
|
|
||||||
if all_docs["ids"]:
|
|
||||||
simba_docs.delete(ids=all_docs["ids"])
|
|
||||||
|
|
||||||
logging.info("Fetching documents from Paperless-NGX")
|
|
||||||
ppngx = PaperlessNGXService()
|
|
||||||
docs = ppngx.get_data()
|
|
||||||
docs = filter_indexed_files(docs)
|
|
||||||
logging.info(f"Fetched {len(docs)} documents")
|
|
||||||
|
|
||||||
# Delete all chromadb data
|
|
||||||
ids = simba_docs.get(ids=None, limit=None, offset=0)
|
|
||||||
all_ids = ids["ids"]
|
|
||||||
if len(all_ids) > 0:
|
|
||||||
simba_docs.delete(ids=all_ids)
|
|
||||||
|
|
||||||
# Chunk documents
|
|
||||||
logging.info("Chunking documents now ...")
|
|
||||||
doctype_lookup = ppngx.get_doctypes()
|
|
||||||
chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
|
|
||||||
logging.info("Done chunking documents")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
args = parser.parse_args()
|
|
||||||
if args.reindex:
|
|
||||||
reindex()
|
|
||||||
|
|
||||||
if args.classify:
|
|
||||||
consult_simba_oracle(input="yohohoho testing")
|
|
||||||
consult_simba_oracle(input="write an email")
|
|
||||||
consult_simba_oracle(input="how much does simba weigh")
|
|
||||||
|
|
||||||
if args.query:
|
|
||||||
logging.info("Consulting oracle ...")
|
|
||||||
print(
|
|
||||||
consult_oracle(
|
|
||||||
input=args.query,
|
|
||||||
collection=simba_docs,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logging.info("please provide a query")
|
|
||||||
@@ -0,0 +1,112 @@
|
|||||||
|
from tortoise import BaseDBAsyncClient
|
||||||
|
|
||||||
|
RUN_IN_TRANSACTION = True
|
||||||
|
|
||||||
|
|
||||||
|
async def upgrade(db: BaseDBAsyncClient) -> str:
|
||||||
|
return """
|
||||||
|
CREATE TABLE IF NOT EXISTS "user_memories" (
|
||||||
|
"id" UUID NOT NULL PRIMARY KEY,
|
||||||
|
"content" TEXT NOT NULL,
|
||||||
|
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"user_id" UUID NOT NULL REFERENCES "users" ("id") ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS "email_accounts" (
|
||||||
|
"id" UUID NOT NULL PRIMARY KEY,
|
||||||
|
"email_address" VARCHAR(255) NOT NULL UNIQUE,
|
||||||
|
"display_name" VARCHAR(255),
|
||||||
|
"imap_host" VARCHAR(255) NOT NULL,
|
||||||
|
"imap_port" INT NOT NULL DEFAULT 993,
|
||||||
|
"imap_username" VARCHAR(255) NOT NULL,
|
||||||
|
"imap_password" TEXT NOT NULL,
|
||||||
|
"is_active" BOOL NOT NULL DEFAULT True,
|
||||||
|
"last_error" TEXT,
|
||||||
|
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"user_id" UUID NOT NULL REFERENCES "users" ("id") ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
COMMENT ON TABLE "email_accounts" IS 'Email account configuration for IMAP connections.';
|
||||||
|
CREATE TABLE IF NOT EXISTS "emails" (
|
||||||
|
"id" UUID NOT NULL PRIMARY KEY,
|
||||||
|
"message_id" VARCHAR(255) NOT NULL UNIQUE,
|
||||||
|
"subject" VARCHAR(500) NOT NULL,
|
||||||
|
"from_address" VARCHAR(255) NOT NULL,
|
||||||
|
"to_address" TEXT NOT NULL,
|
||||||
|
"date" TIMESTAMPTZ NOT NULL,
|
||||||
|
"body_text" TEXT,
|
||||||
|
"body_html" TEXT,
|
||||||
|
"chromadb_doc_id" VARCHAR(255),
|
||||||
|
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"expires_at" TIMESTAMPTZ NOT NULL,
|
||||||
|
"account_id" UUID NOT NULL REFERENCES "email_accounts" ("id") ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS "idx_emails_message_981ddd" ON "emails" ("message_id");
|
||||||
|
COMMENT ON TABLE "emails" IS 'Email message metadata and content.';
|
||||||
|
CREATE TABLE IF NOT EXISTS "email_sync_status" (
|
||||||
|
"id" UUID NOT NULL PRIMARY KEY,
|
||||||
|
"last_sync_date" TIMESTAMPTZ,
|
||||||
|
"last_message_uid" INT NOT NULL DEFAULT 0,
|
||||||
|
"message_count" INT NOT NULL DEFAULT 0,
|
||||||
|
"consecutive_failures" INT NOT NULL DEFAULT 0,
|
||||||
|
"last_failure_date" TIMESTAMPTZ,
|
||||||
|
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"account_id" UUID NOT NULL REFERENCES "email_accounts" ("id") ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
COMMENT ON TABLE "email_sync_status" IS 'Tracks sync progress and state per email account.';"""
|
||||||
|
|
||||||
|
|
||||||
|
async def downgrade(db: BaseDBAsyncClient) -> str:
|
||||||
|
return """
|
||||||
|
DROP TABLE IF EXISTS "user_memories";
|
||||||
|
DROP TABLE IF EXISTS "email_accounts";
|
||||||
|
DROP TABLE IF EXISTS "emails";
|
||||||
|
DROP TABLE IF EXISTS "email_sync_status";"""
|
||||||
|
|
||||||
|
|
||||||
|
MODELS_STATE = (
|
||||||
|
"eJztXGtv2zYU/SuCPrVAFjTPbcUwwE7czVudDLGz9ZFCoCXa1ixRGkk1NYr+911Skq0HZV"
|
||||||
|
"t+RUr1oU1C8lLU4SV57tGVvuquZ2GHHV955DOmDHHbI/pr7atOkIvhF2X9kaYj31/UigKO"
|
||||||
|
"ho40MBMtZQ0aMk6RyaFyhByGocjCzKS2H12MBI4jCj0TGtpkvCgKiP1fgA3ujTGfYAoVHz"
|
||||||
|
"9BsU0s/AWz+E9/aoxs7FipcduWuLYsN/jMl2X3993rN7KluNzQMD0ncMmitT/jE4/MmweB"
|
||||||
|
"bR0LG1E3xgRTxLGVuA0xyui246JwxFDAaYDnQ7UWBRYeocARYOi/jAJiCgw0eSXx3/mveg"
|
||||||
|
"l4AGoBrU24wOLrt/CuFvcsS3VxqavfW3cvzi5fyrv0GB9TWSkR0b9JQ8RRaCpxXQApf+ag"
|
||||||
|
"vJogqoYybp8BEwa6CYxxwQLHhQ/FQMYAbYaa7qIvhoPJmE/gz9OLiyUw/t26k0hCKwmlB3"
|
||||||
|
"4dev1NVHUa1glIFxCaFItbNhDPA3kNNdx2sRrMtGUGUisyPY5/qSjAcA/WLXFm0SJYgu+g"
|
||||||
|
"2+v0B63eX+JOXMb+cyRErUFH1JzK0lmm9MVlZirmnWj/dAe/a+JP7cPtTSfr+/N2gw+6GB"
|
||||||
|
"MKuGcQ79FAVmK9xqUxMKmJDXxrw4lNWzYT+6QTGw0+Ma8MU6PcCZIw2eIYicZ2wEnc/NAQ"
|
||||||
|
"R+9oqjwzBBh58N54FNtj8ieeSQi7MA5ETNVhEZGO+6ibqoK2KF2MgqLHORtJOgXcHdwT5u"
|
||||||
|
"Hp2epfta47usRwiMzpI6KWUQCmixlDY8zygLYjyzd/3mFnTs3UWCYJXC/ssZq7ShG2Eivv"
|
||||||
|
"1EtglEIvX+WeutkSROC+reja4kpL0FnBghMgrkeGjeRENqS41qSY4y+KI38ApWoo4/Z1Ic"
|
||||||
|
"XLjvLOu0HqFI+p74te693L1En+9vbmt7h5gipfvb1tNwz5ORKpPENmPkZTFRkQAWSHBG6O"
|
||||||
|
"CqRmN2H+xEtHv+937l5r4kR/IP1ur916rTHbHSJ9vSlORZknr9YIMk9eFcaYoiq9gGwXTh"
|
||||||
|
"ZjimdlQvWU0Ub4Hp56pYG8ODldA0loVQilrMtsRslDu9yRqTDd5flZ03DAzIiHW4YFWS2y"
|
||||||
|
"siiujA8U7lI2TtgnKxbxVw+7Hp3pCjKcqF3KgWUQ5IqGdsN9nwH3hYtwTErR34RJw4AbBv"
|
||||||
|
"xdMeBGI34WE1sdjbham2FdROIKs8AtVOJ9s78i3rea8TVMr/5MT8xj2cf/SZu6cL0DpAD4"
|
||||||
|
"iLFHjyo8s20TRGdqMJNWGTCHMx5GU5VTaJaA1xa8N3m6A2Tt7k3r7r2aOsftk37bfj/otD"
|
||||||
|
"LoYhfZThkvnRvsxkVXr/hdOujJq/Xkw2X6YU5AfJwgzmBLN0jgDosEWzWYCtOdiImHRfVs"
|
||||||
|
"HVDPijE9y0EqnczARNyeauF7noMRWeKgSdvs8gfjfW2mZY/qEuv/9vZtav23u9nQ+L7X7o"
|
||||||
|
"DzSpihkR1Soe7NQAnuxEUmcIQpVuiKK1Z/xraGHntyuc42kI2QErvAZdZjPdsyDRYM/8Wm"
|
||||||
|
"IlotBjRrV0Mw93LqQ/w4MXzqfbatcltqzvBwVEp3PBM5W3DRzBOadbbVi+Jt9SK3rToW8o"
|
||||||
|
"0x9QJfkRLzR//2Rg1pxiwD6D2Bu/xo2SY/0hyb8U97g/fjp/3wfHHny1XJrACZIVaig0aV"
|
||||||
|
"fJbiVaNKPtOJnSfG5VShVVmFudc0dpNaWOWINJ9SmFwRySeUm2ORfihaPc9fC4qQHyPT9A"
|
||||||
|
"JhthUgHdFXK+yqZpDsU1yVsOgKdbUTKxPF8qqcnvX0VV12p0WZp/CTI6H2aYhYWvRQ9ljP"
|
||||||
|
"oLSOzQN5IH3uwbal+YgybGlyUJps+GjzCYTTP1hoplEs2sNgjrW3NpkyjXva1YR6LrpuP5"
|
||||||
|
"CRR7XPEDPAD4YRNSeaiXw0tCHsg5UoR9bowDvih1vowJErKJ92Fccwaas6Cm17iQk3CK+3"
|
||||||
|
"jayfXFK/WEuxvFiiWF7kFcsR7CKCGIEfK86oYjSzdvWEdC++CbyyENAlye1eHeE8dIKPiI"
|
||||||
|
"jKxlqxTT2jrJpEVfFtL42Xh541M8q+9ZEyqkl+9aGXhcRowl3F47sVwMZGDbDqhELJsuGK"
|
||||||
|
"MLSSzE1hWhOQm5f5G+VsU0kUf/Ft6G2DiU1b1nNiazKRax3WkXJVMjszbdUkaMaA5DEsna"
|
||||||
|
"NZXxHwKJOrmXaSKqVrpjAuEhYTc7BCX0zJv+vqjJGNkAlH9jigUhvWhMrX7bX+EsUES7mL"
|
||||||
|
"FamOJXpIaJBzK4otITfCGEMVEhOTznxwNC1OpTtK9KExzDlcnh09EKFuxt2AO/OAHWv9wP"
|
||||||
|
"c9ypnmgovZvoPjFkzzMZXvgjYaZUU0yshpy8tBOcNGqYwVC5v5DpoZZVOAs3ZN7JB4S9s3"
|
||||||
|
"JuDYZeBMGdVFXTsUmGJ/zoPZJQXCQcomg6W9P27y889nW0Apv0Xzw+nJ+Y/nP51dnv8ETe"
|
||||||
|
"RQ5iU/LgE3nzkpMdgktT9n2DhjxhkLk/w7MQ8p1rRyPdQF3UMLWzYE2sBHPit8d2lKdcru"
|
||||||
|
"gOnU84O/wtnUDmLcwJR6iizVYpdNW9XkmG9e7G70wiaFspnY5sXu5sXu6r7YnRE2dpGEWS"
|
||||||
|
"8s0zlTM2IaoSq3AyD60Ft/3lmNINm7fJxApkhBToO3SkTOTNxqHXkA1VOmCTvNp57YdJjM"
|
||||||
|
"PBWdYCm74qRQnNeRS/cgdOSegBn+MU1w2tBYnL1g4/pHYSF0ZkJf2JqnxsJOGCnHI+gwoF"
|
||||||
|
"gTdzeFgYg0Vxaqx5oNsR92MfTvhB0LA8matQn86kDzRkWuiIosIxrptJuka+Wtd8DtqhUh"
|
||||||
|
"VYjKrfUoWE5JnIkcqNZJoVaoMj2cZPhqC2q+Y8EwxqDgaXAhgDm77xI90T02AyE8GdExoS"
|
||||||
|
"AxhSAWmX+XWMolGaGw+Q6d7aDZpJ94k26UlOeppDR5WM8sD2vfSQ31z8JqYWqbE10RPUc1"
|
||||||
|
"R8uCZrRoU5kv5xU/S1+TEUcT+KTJMjvhIcVho3j9Xflt8+KH6QmTujzoPcQHc2BplAAxal"
|
||||||
|
"5PAPfyGbfCj3MXfxin+OPcB/sozt4O3Z19FKfENzZ2f7x8+x8fHBMe"
|
||||||
|
)
|
||||||
+13
-2
@@ -5,7 +5,8 @@ description = "Add your description here"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"chromadb>=1.1.0",
|
"langchain-postgres>=0.0.13",
|
||||||
|
"psycopg[binary]>=3.1.0",
|
||||||
"python-dotenv>=1.0.0",
|
"python-dotenv>=1.0.0",
|
||||||
"flask>=3.1.2",
|
"flask>=3.1.2",
|
||||||
"httpx>=0.28.1",
|
"httpx>=0.28.1",
|
||||||
@@ -30,7 +31,6 @@ dependencies = [
|
|||||||
"asyncpg>=0.30.0",
|
"asyncpg>=0.30.0",
|
||||||
"langchain-openai>=1.1.6",
|
"langchain-openai>=1.1.6",
|
||||||
"langchain>=1.2.0",
|
"langchain>=1.2.0",
|
||||||
"langchain-chroma>=1.0.0",
|
|
||||||
"langchain-community>=0.4.1",
|
"langchain-community>=0.4.1",
|
||||||
"jq>=1.10.0",
|
"jq>=1.10.0",
|
||||||
"tavily-python>=0.7.17",
|
"tavily-python>=0.7.17",
|
||||||
@@ -42,6 +42,17 @@ dependencies = [
|
|||||||
"aioboto3>=13.0.0",
|
"aioboto3>=13.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
test = [
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
"pytest-asyncio>=0.25.0",
|
||||||
|
"pytest-cov>=6.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
|
||||||
[tool.aerich]
|
[tool.aerich]
|
||||||
tortoise_orm = "config.db.TORTOISE_CONFIG"
|
tortoise_orm = "config.db.TORTOISE_CONFIG"
|
||||||
location = "./migrations"
|
location = "./migrations"
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ class ConversationService {
|
|||||||
async uploadImage(
|
async uploadImage(
|
||||||
file: File,
|
file: File,
|
||||||
conversationId: string,
|
conversationId: string,
|
||||||
): Promise<{ image_key: string; image_url: string }> {
|
): Promise<{ image_key: string }> {
|
||||||
const formData = new FormData();
|
const formData = new FormData();
|
||||||
formData.append("file", file);
|
formData.append("file", file);
|
||||||
formData.append("conversation_id", conversationId);
|
formData.append("conversation_id", conversationId);
|
||||||
@@ -147,8 +147,15 @@ class ConversationService {
|
|||||||
return await response.json();
|
return await response.json();
|
||||||
}
|
}
|
||||||
|
|
||||||
getImageUrl(imageKey: string): string {
|
async getPresignedImageUrl(imageKey: string): Promise<string> {
|
||||||
return `/api/conversation/image/${imageKey}`;
|
const response = await userService.fetchWithRefreshToken(
|
||||||
|
`${this.conversationBaseUrl}/image/${imageKey}`,
|
||||||
|
);
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error("Failed to get image URL");
|
||||||
|
}
|
||||||
|
const data = await response.json();
|
||||||
|
return data.url;
|
||||||
}
|
}
|
||||||
|
|
||||||
async streamQuery(
|
async streamQuery(
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { useEffect, useState, useRef } from "react";
|
import { useCallback, useEffect, useState, useRef } from "react";
|
||||||
import { LogOut, Shield, PanelLeftClose, PanelLeftOpen, Menu, X } from "lucide-react";
|
import { LogOut, Shield, PanelLeftClose, PanelLeftOpen, Menu, X } from "lucide-react";
|
||||||
import { conversationService } from "../api/conversationService";
|
import { conversationService } from "../api/conversationService";
|
||||||
import { userService } from "../api/userService";
|
import { userService } from "../api/userService";
|
||||||
@@ -63,9 +63,13 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
|||||||
const abortControllerRef = useRef<AbortController | null>(null);
|
const abortControllerRef = useRef<AbortController | null>(null);
|
||||||
const simbaAnswers = ["meow.", "hiss...", "purrrrrr", "yowOWROWWowowr"];
|
const simbaAnswers = ["meow.", "hiss...", "purrrrrr", "yowOWROWWowowr"];
|
||||||
|
|
||||||
const scrollToBottom = () => {
|
const scrollToBottom = useCallback(() => {
|
||||||
messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
|
requestAnimationFrame(() => {
|
||||||
};
|
messagesEndRef.current?.scrollIntoView({
|
||||||
|
behavior: isLoading ? "instant" : "smooth",
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}, [isLoading]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
isMountedRef.current = true;
|
isMountedRef.current = true;
|
||||||
@@ -116,21 +120,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
|||||||
scrollToBottom();
|
scrollToBottom();
|
||||||
}, [messages]);
|
}, [messages]);
|
||||||
|
|
||||||
useEffect(() => {
|
const handleQuestionSubmit = useCallback(async () => {
|
||||||
const load = async () => {
|
|
||||||
if (!selectedConversation) return;
|
|
||||||
try {
|
|
||||||
const conv = await conversationService.getConversation(selectedConversation.id);
|
|
||||||
setSelectedConversation({ id: conv.id, title: conv.name });
|
|
||||||
setMessages(conv.messages.map((m) => ({ text: m.text, speaker: m.speaker, image_key: m.image_key })));
|
|
||||||
} catch (err) {
|
|
||||||
console.error("Failed to load messages:", err);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
load();
|
|
||||||
}, [selectedConversation?.id]);
|
|
||||||
|
|
||||||
const handleQuestionSubmit = async () => {
|
|
||||||
if ((!query.trim() && !pendingImage) || isLoading) return;
|
if ((!query.trim() && !pendingImage) || isLoading) return;
|
||||||
|
|
||||||
let activeConversation = selectedConversation;
|
let activeConversation = selectedConversation;
|
||||||
@@ -211,22 +201,28 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
if (isMountedRef.current) setIsLoading(false);
|
if (isMountedRef.current) {
|
||||||
|
setIsLoading(false);
|
||||||
|
loadConversations();
|
||||||
|
}
|
||||||
abortControllerRef.current = null;
|
abortControllerRef.current = null;
|
||||||
}
|
}
|
||||||
};
|
}, [query, pendingImage, isLoading, selectedConversation, simbaMode, messages, setAuthenticated]);
|
||||||
|
|
||||||
const handleQueryChange = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
const handleQueryChange = useCallback((event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||||
setQuery(event.target.value);
|
setQuery(event.target.value);
|
||||||
};
|
}, []);
|
||||||
|
|
||||||
const handleKeyDown = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
const handleKeyDown = useCallback((event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||||
const kev = event as unknown as React.KeyboardEvent<HTMLTextAreaElement>;
|
const kev = event as unknown as React.KeyboardEvent<HTMLTextAreaElement>;
|
||||||
if (kev.key === "Enter" && !kev.shiftKey) {
|
if (kev.key === "Enter" && !kev.shiftKey) {
|
||||||
kev.preventDefault();
|
kev.preventDefault();
|
||||||
handleQuestionSubmit();
|
handleQuestionSubmit();
|
||||||
}
|
}
|
||||||
};
|
}, [handleQuestionSubmit]);
|
||||||
|
|
||||||
|
const handleImageSelect = useCallback((file: File) => setPendingImage(file), []);
|
||||||
|
const handleClearImage = useCallback(() => setPendingImage(null), []);
|
||||||
|
|
||||||
const handleLogout = () => {
|
const handleLogout = () => {
|
||||||
localStorage.removeItem("access_token");
|
localStorage.removeItem("access_token");
|
||||||
@@ -380,8 +376,8 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
|||||||
setSimbaMode={setSimbaMode}
|
setSimbaMode={setSimbaMode}
|
||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
pendingImage={pendingImage}
|
pendingImage={pendingImage}
|
||||||
onImageSelect={(file) => setPendingImage(file)}
|
onImageSelect={handleImageSelect}
|
||||||
onClearImage={() => setPendingImage(null)}
|
onClearImage={handleClearImage}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -416,7 +412,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<footer className="border-t border-sand-light/40 bg-cream/80 backdrop-blur-sm">
|
<footer className="border-t border-sand-light/40 bg-cream">
|
||||||
<div className="max-w-2xl mx-auto px-4 py-3">
|
<div className="max-w-2xl mx-auto px-4 py-3">
|
||||||
<MessageInput
|
<MessageInput
|
||||||
query={query}
|
query={query}
|
||||||
@@ -425,6 +421,9 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
|||||||
handleQuestionSubmit={handleQuestionSubmit}
|
handleQuestionSubmit={handleQuestionSubmit}
|
||||||
setSimbaMode={setSimbaMode}
|
setSimbaMode={setSimbaMode}
|
||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
|
pendingImage={pendingImage}
|
||||||
|
onImageSelect={(file) => setPendingImage(file)}
|
||||||
|
onClearImage={() => setPendingImage(null)}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</footer>
|
</footer>
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { useRef, useState } from "react";
|
import React, { useEffect, useMemo, useRef, useState } from "react";
|
||||||
import { ArrowUp, ImagePlus, X } from "lucide-react";
|
import { ArrowUp, ImagePlus, X } from "lucide-react";
|
||||||
import { cn } from "../lib/utils";
|
import { cn } from "../lib/utils";
|
||||||
import { Textarea } from "./ui/textarea";
|
import { Textarea } from "./ui/textarea";
|
||||||
@@ -15,7 +15,7 @@ type MessageInputProps = {
|
|||||||
onClearImage: () => void;
|
onClearImage: () => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const MessageInput = ({
|
export const MessageInput = React.memo(({
|
||||||
query,
|
query,
|
||||||
handleKeyDown,
|
handleKeyDown,
|
||||||
handleQueryChange,
|
handleQueryChange,
|
||||||
@@ -29,6 +29,18 @@ export const MessageInput = ({
|
|||||||
const [simbaMode, setLocalSimbaMode] = useState(false);
|
const [simbaMode, setLocalSimbaMode] = useState(false);
|
||||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||||
|
|
||||||
|
// Create blob URL once per file, revoke on cleanup
|
||||||
|
const previewUrl = useMemo(
|
||||||
|
() => (pendingImage ? URL.createObjectURL(pendingImage) : null),
|
||||||
|
[pendingImage],
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
return () => {
|
||||||
|
if (previewUrl) URL.revokeObjectURL(previewUrl);
|
||||||
|
};
|
||||||
|
}, [previewUrl]);
|
||||||
|
|
||||||
const toggleSimbaMode = () => {
|
const toggleSimbaMode = () => {
|
||||||
const next = !simbaMode;
|
const next = !simbaMode;
|
||||||
setLocalSimbaMode(next);
|
setLocalSimbaMode(next);
|
||||||
@@ -59,7 +71,7 @@ export const MessageInput = ({
|
|||||||
<div className="px-3 pt-3">
|
<div className="px-3 pt-3">
|
||||||
<div className="relative inline-block">
|
<div className="relative inline-block">
|
||||||
<img
|
<img
|
||||||
src={URL.createObjectURL(pendingImage)}
|
src={previewUrl!}
|
||||||
alt="Pending upload"
|
alt="Pending upload"
|
||||||
className="h-20 rounded-lg object-cover border border-sand"
|
className="h-20 rounded-lg object-cover border border-sand"
|
||||||
/>
|
/>
|
||||||
@@ -145,4 +157,4 @@ export const MessageInput = ({
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import { useEffect, useState } from "react";
|
||||||
import { cn } from "../lib/utils";
|
import { cn } from "../lib/utils";
|
||||||
import { conversationService } from "../api/conversationService";
|
import { conversationService } from "../api/conversationService";
|
||||||
|
|
||||||
@@ -7,6 +8,20 @@ type QuestionBubbleProps = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export const QuestionBubble = ({ text, image_key }: QuestionBubbleProps) => {
|
export const QuestionBubble = ({ text, image_key }: QuestionBubbleProps) => {
|
||||||
|
const [imageUrl, setImageUrl] = useState<string | null>(null);
|
||||||
|
const [imageError, setImageError] = useState(false);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!image_key) return;
|
||||||
|
conversationService
|
||||||
|
.getPresignedImageUrl(image_key)
|
||||||
|
.then(setImageUrl)
|
||||||
|
.catch((err) => {
|
||||||
|
console.error("Failed to load image:", err);
|
||||||
|
setImageError(true);
|
||||||
|
});
|
||||||
|
}, [image_key]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex justify-end message-enter">
|
<div className="flex justify-end message-enter">
|
||||||
<div
|
<div
|
||||||
@@ -17,9 +32,15 @@ export const QuestionBubble = ({ text, image_key }: QuestionBubbleProps) => {
|
|||||||
"shadow-sm shadow-leaf/10",
|
"shadow-sm shadow-leaf/10",
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{image_key && (
|
{imageError && (
|
||||||
|
<div className="flex items-center gap-2 text-xs text-charcoal/50 bg-charcoal/5 rounded-xl px-3 py-2 mb-2">
|
||||||
|
<span>🖼️</span>
|
||||||
|
<span>Image failed to load</span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{imageUrl && (
|
||||||
<img
|
<img
|
||||||
src={conversationService.getImageUrl(image_key)}
|
src={imageUrl}
|
||||||
alt="Uploaded image"
|
alt="Uploaded image"
|
||||||
className="max-w-full rounded-xl mb-2"
|
className="max-w-full rounded-xl mb-2"
|
||||||
/>
|
/>
|
||||||
|
|||||||
@@ -6,19 +6,19 @@ import asyncio
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
from blueprints.rag.logic import (
|
from blueprints.rag.logic import (
|
||||||
|
delete_all_documents,
|
||||||
get_vector_store_stats,
|
get_vector_store_stats,
|
||||||
index_documents,
|
index_documents,
|
||||||
list_all_documents,
|
list_all_documents,
|
||||||
vector_store,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def stats():
|
def stats():
|
||||||
"""Show vector store statistics."""
|
"""Show vector store statistics."""
|
||||||
stats = get_vector_store_stats()
|
s = get_vector_store_stats()
|
||||||
print("=== Vector Store Statistics ===")
|
print("=== Vector Store Statistics ===")
|
||||||
print(f"Collection: {stats['collection_name']}")
|
print(f"Collection: {s['collection_name']}")
|
||||||
print(f"Total Documents: {stats['total_documents']}")
|
print(f"Total Documents: {s['total_documents']}")
|
||||||
|
|
||||||
|
|
||||||
async def index():
|
async def index():
|
||||||
@@ -26,23 +26,15 @@ async def index():
|
|||||||
print("Starting indexing process...")
|
print("Starting indexing process...")
|
||||||
print("Fetching documents from Paperless-NGX...")
|
print("Fetching documents from Paperless-NGX...")
|
||||||
await index_documents()
|
await index_documents()
|
||||||
print("✓ Indexing complete!")
|
print("Indexing complete!")
|
||||||
stats()
|
stats()
|
||||||
|
|
||||||
|
|
||||||
async def reindex():
|
async def reindex():
|
||||||
"""Clear and reindex all documents."""
|
"""Clear and reindex all documents."""
|
||||||
print("Clearing existing documents...")
|
print("Clearing existing documents...")
|
||||||
collection = vector_store._collection
|
delete_all_documents()
|
||||||
all_docs = collection.get()
|
print("Cleared")
|
||||||
|
|
||||||
if all_docs["ids"]:
|
|
||||||
print(f"Deleting {len(all_docs['ids'])} existing documents...")
|
|
||||||
collection.delete(ids=all_docs["ids"])
|
|
||||||
print("✓ Cleared")
|
|
||||||
else:
|
|
||||||
print("Collection is already empty")
|
|
||||||
|
|
||||||
await index()
|
await index()
|
||||||
|
|
||||||
|
|
||||||
@@ -113,7 +105,7 @@ Examples:
|
|||||||
print("\n\nOperation cancelled by user")
|
print("\n\nOperation cancelled by user")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n❌ Error: {e}", file=sys.stderr)
|
print(f"\nError: {e}", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,24 +0,0 @@
|
|||||||
from bs4 import BeautifulSoup
|
|
||||||
import chromadb
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
client = chromadb.PersistentClient(path="/Users/ryanchen/Programs/raggr/chromadb")
|
|
||||||
|
|
||||||
# Scrape
|
|
||||||
BASE_URL = "https://www.vet.cornell.edu"
|
|
||||||
LIST_URL = "/departments-centers-and-institutes/cornell-feline-health-center/health-information/feline-health-topics"
|
|
||||||
|
|
||||||
QUERY_URL = BASE_URL + LIST_URL
|
|
||||||
r = httpx.get(QUERY_URL)
|
|
||||||
soup = BeautifulSoup(r.text)
|
|
||||||
|
|
||||||
container = soup.find("div", class_="field-body")
|
|
||||||
a_s = container.find_all("a", href=True)
|
|
||||||
|
|
||||||
new_texts = []
|
|
||||||
|
|
||||||
for link in a_s:
|
|
||||||
endpoint = link["href"]
|
|
||||||
query_url = BASE_URL + endpoint
|
|
||||||
r2 = httpx.get(query_url)
|
|
||||||
article_soup = BeautifulSoup(r2.text)
|
|
||||||
@@ -1,9 +1,6 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
echo "Initializing directories..."
|
|
||||||
mkdir -p /app/data/chromadb
|
|
||||||
|
|
||||||
echo "Rebuilding frontend..."
|
echo "Rebuilding frontend..."
|
||||||
cd /app/raggr-frontend
|
cd /app/raggr-frontend
|
||||||
yarn build
|
yarn build
|
||||||
|
|||||||
@@ -0,0 +1,11 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Ensure project root is on the path so imports work
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
|
||||||
|
# Set FERNET_KEY for tests that import email models (EncryptedTextField needs it at import time)
|
||||||
|
if "FERNET_KEY" not in os.environ:
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
|
||||||
|
os.environ["FERNET_KEY"] = Fernet.generate_key().decode()
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
"""Tests for encryption/decryption in blueprints/email/crypto_service.py."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
|
||||||
|
|
||||||
|
# Generate a valid key for testing
|
||||||
|
TEST_FERNET_KEY = Fernet.generate_key().decode()
|
||||||
|
|
||||||
|
|
||||||
|
class TestEncryptedTextField:
|
||||||
|
@pytest.fixture
|
||||||
|
def field(self):
|
||||||
|
with patch.dict(os.environ, {"FERNET_KEY": TEST_FERNET_KEY}):
|
||||||
|
from blueprints.email.crypto_service import EncryptedTextField
|
||||||
|
|
||||||
|
return EncryptedTextField()
|
||||||
|
|
||||||
|
def test_encrypt_decrypt_roundtrip(self, field):
|
||||||
|
original = "my secret password"
|
||||||
|
encrypted = field.to_db_value(original, None)
|
||||||
|
decrypted = field.to_python_value(encrypted)
|
||||||
|
assert decrypted == original
|
||||||
|
assert encrypted != original
|
||||||
|
|
||||||
|
def test_none_passthrough(self, field):
|
||||||
|
assert field.to_db_value(None, None) is None
|
||||||
|
assert field.to_python_value(None) is None
|
||||||
|
|
||||||
|
def test_unicode_roundtrip(self, field):
|
||||||
|
original = "Hello 世界 🐱"
|
||||||
|
encrypted = field.to_db_value(original, None)
|
||||||
|
decrypted = field.to_python_value(encrypted)
|
||||||
|
assert decrypted == original
|
||||||
|
|
||||||
|
def test_empty_string_roundtrip(self, field):
|
||||||
|
encrypted = field.to_db_value("", None)
|
||||||
|
decrypted = field.to_python_value(encrypted)
|
||||||
|
assert decrypted == ""
|
||||||
|
|
||||||
|
def test_long_text_roundtrip(self, field):
|
||||||
|
original = "x" * 10000
|
||||||
|
encrypted = field.to_db_value(original, None)
|
||||||
|
decrypted = field.to_python_value(encrypted)
|
||||||
|
assert decrypted == original
|
||||||
|
|
||||||
|
def test_different_encryptions_differ(self, field):
|
||||||
|
"""Fernet includes a timestamp, so two encryptions of the same value differ."""
|
||||||
|
e1 = field.to_db_value("same", None)
|
||||||
|
e2 = field.to_db_value("same", None)
|
||||||
|
assert e1 != e2 # Different ciphertexts
|
||||||
|
assert field.to_python_value(e1) == field.to_python_value(e2) == "same"
|
||||||
|
|
||||||
|
def test_wrong_key_fails(self, field):
|
||||||
|
encrypted = field.to_db_value("secret", None)
|
||||||
|
|
||||||
|
# Create a field with a different key
|
||||||
|
other_key = Fernet.generate_key().decode()
|
||||||
|
with patch.dict(os.environ, {"FERNET_KEY": other_key}):
|
||||||
|
from blueprints.email.crypto_service import EncryptedTextField
|
||||||
|
|
||||||
|
other_field = EncryptedTextField()
|
||||||
|
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
other_field.to_python_value(encrypted)
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateFernetKey:
|
||||||
|
def test_valid_key(self):
|
||||||
|
with patch.dict(os.environ, {"FERNET_KEY": TEST_FERNET_KEY}):
|
||||||
|
from blueprints.email.crypto_service import validate_fernet_key
|
||||||
|
|
||||||
|
validate_fernet_key() # Should not raise
|
||||||
|
|
||||||
|
def test_missing_key(self):
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
os.environ.pop("FERNET_KEY", None)
|
||||||
|
from blueprints.email.crypto_service import validate_fernet_key
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="not set"):
|
||||||
|
validate_fernet_key()
|
||||||
|
|
||||||
|
def test_invalid_key(self):
|
||||||
|
with patch.dict(os.environ, {"FERNET_KEY": "not-a-valid-key"}):
|
||||||
|
from blueprints.email.crypto_service import validate_fernet_key
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="validation failed"):
|
||||||
|
validate_fernet_key()
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
"""Tests for email helper functions in blueprints/email/helpers.py."""
|
||||||
|
|
||||||
|
from blueprints.email.helpers import generate_email_token, get_user_email_address
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerateEmailToken:
|
||||||
|
def test_returns_16_char_hex(self):
|
||||||
|
token = generate_email_token("user-123", "my-secret")
|
||||||
|
assert len(token) == 16
|
||||||
|
assert all(c in "0123456789abcdef" for c in token)
|
||||||
|
|
||||||
|
def test_deterministic(self):
|
||||||
|
t1 = generate_email_token("user-123", "my-secret")
|
||||||
|
t2 = generate_email_token("user-123", "my-secret")
|
||||||
|
assert t1 == t2
|
||||||
|
|
||||||
|
def test_different_users_different_tokens(self):
|
||||||
|
t1 = generate_email_token("user-1", "secret")
|
||||||
|
t2 = generate_email_token("user-2", "secret")
|
||||||
|
assert t1 != t2
|
||||||
|
|
||||||
|
def test_different_secrets_different_tokens(self):
|
||||||
|
t1 = generate_email_token("user-1", "secret-a")
|
||||||
|
t2 = generate_email_token("user-1", "secret-b")
|
||||||
|
assert t1 != t2
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetUserEmailAddress:
|
||||||
|
def test_formats_correctly(self):
|
||||||
|
addr = get_user_email_address("abc123", "example.com")
|
||||||
|
assert addr == "ask+abc123@example.com"
|
||||||
|
|
||||||
|
def test_preserves_token(self):
|
||||||
|
token = "deadbeef12345678"
|
||||||
|
addr = get_user_email_address(token, "mail.test.org")
|
||||||
|
assert token in addr
|
||||||
|
assert addr.startswith("ask+")
|
||||||
|
assert "@mail.test.org" in addr
|
||||||
@@ -0,0 +1,259 @@
|
|||||||
|
"""Tests for ObsidianService markdown parsing and file operations."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Set vault path before importing so __init__ validation passes
|
||||||
|
_test_vault_dir = None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def vault_dir(tmp_path):
|
||||||
|
"""Create a temporary vault directory with a sample .md file."""
|
||||||
|
global _test_vault_dir
|
||||||
|
_test_vault_dir = tmp_path
|
||||||
|
|
||||||
|
# Create a sample markdown file so vault validation passes
|
||||||
|
sample = tmp_path / "sample.md"
|
||||||
|
sample.write_text("# Sample\nHello world")
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"OBSIDIAN_VAULT_PATH": str(tmp_path)}):
|
||||||
|
yield tmp_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def service(vault_dir):
|
||||||
|
from utils.obsidian_service import ObsidianService
|
||||||
|
|
||||||
|
return ObsidianService()
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseMarkdown:
|
||||||
|
def test_extracts_frontmatter(self, service):
|
||||||
|
content = "---\ntitle: Test Note\ntags: [cat, vet]\n---\n\nBody content"
|
||||||
|
result = service.parse_markdown(content)
|
||||||
|
assert result["metadata"]["title"] == "Test Note"
|
||||||
|
assert result["metadata"]["tags"] == ["cat", "vet"]
|
||||||
|
|
||||||
|
def test_no_frontmatter(self, service):
|
||||||
|
content = "Just body content with no frontmatter"
|
||||||
|
result = service.parse_markdown(content)
|
||||||
|
assert result["metadata"] == {}
|
||||||
|
assert "Just body content" in result["content"]
|
||||||
|
|
||||||
|
def test_invalid_yaml_frontmatter(self, service):
|
||||||
|
content = "---\n: invalid: yaml: [[\n---\n\nBody"
|
||||||
|
result = service.parse_markdown(content)
|
||||||
|
assert result["metadata"] == {}
|
||||||
|
|
||||||
|
def test_extracts_tags(self, service):
|
||||||
|
content = "Some text with #tag1 and #tag2 here"
|
||||||
|
result = service.parse_markdown(content)
|
||||||
|
assert "tag1" in result["tags"]
|
||||||
|
assert "tag2" in result["tags"]
|
||||||
|
|
||||||
|
def test_extracts_wikilinks(self, service):
|
||||||
|
content = "Link to [[Other Note]] and [[Another Page]]"
|
||||||
|
result = service.parse_markdown(content)
|
||||||
|
assert "Other Note" in result["wikilinks"]
|
||||||
|
assert "Another Page" in result["wikilinks"]
|
||||||
|
|
||||||
|
def test_extracts_embeds(self, service):
|
||||||
|
content = "An embed [[!my_embed]] here"
|
||||||
|
result = service.parse_markdown(content)
|
||||||
|
assert "my_embed" in result["embeds"]
|
||||||
|
|
||||||
|
def test_cleans_wikilinks_from_content(self, service):
|
||||||
|
content = "Text with [[link]] included"
|
||||||
|
result = service.parse_markdown(content)
|
||||||
|
assert "[[" not in result["content"]
|
||||||
|
assert "]]" not in result["content"]
|
||||||
|
|
||||||
|
def test_filepath_passed_through(self, service):
|
||||||
|
result = service.parse_markdown("text", filepath=Path("/vault/note.md"))
|
||||||
|
assert result["filepath"] == "/vault/note.md"
|
||||||
|
|
||||||
|
def test_filepath_none_by_default(self, service):
|
||||||
|
result = service.parse_markdown("text")
|
||||||
|
assert result["filepath"] is None
|
||||||
|
|
||||||
|
def test_empty_content(self, service):
|
||||||
|
result = service.parse_markdown("")
|
||||||
|
assert result["metadata"] == {}
|
||||||
|
assert result["tags"] == []
|
||||||
|
assert result["wikilinks"] == []
|
||||||
|
assert result["embeds"] == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetDailyNotePath:
|
||||||
|
def test_formats_path_correctly(self, service):
|
||||||
|
date = datetime(2026, 3, 15)
|
||||||
|
path = service.get_daily_note_path(date)
|
||||||
|
assert path == "journal/2026/2026-03-15.md"
|
||||||
|
|
||||||
|
def test_defaults_to_today(self, service):
|
||||||
|
path = service.get_daily_note_path()
|
||||||
|
today = datetime.now()
|
||||||
|
assert today.strftime("%Y-%m-%d") in path
|
||||||
|
assert path.startswith(f"journal/{today.strftime('%Y')}/")
|
||||||
|
|
||||||
|
|
||||||
|
class TestWalkVault:
|
||||||
|
def test_finds_markdown_files(self, service, vault_dir):
|
||||||
|
(vault_dir / "note1.md").write_text("# Note 1")
|
||||||
|
(vault_dir / "subdir").mkdir()
|
||||||
|
(vault_dir / "subdir" / "note2.md").write_text("# Note 2")
|
||||||
|
|
||||||
|
files = service.walk_vault()
|
||||||
|
filenames = [f.name for f in files]
|
||||||
|
assert "sample.md" in filenames
|
||||||
|
assert "note1.md" in filenames
|
||||||
|
assert "note2.md" in filenames
|
||||||
|
|
||||||
|
def test_excludes_obsidian_dir(self, service, vault_dir):
|
||||||
|
obsidian_dir = vault_dir / ".obsidian"
|
||||||
|
obsidian_dir.mkdir()
|
||||||
|
(obsidian_dir / "config.md").write_text("config")
|
||||||
|
|
||||||
|
files = service.walk_vault()
|
||||||
|
filenames = [f.name for f in files]
|
||||||
|
assert "config.md" not in filenames
|
||||||
|
|
||||||
|
def test_ignores_non_md_files(self, service, vault_dir):
|
||||||
|
(vault_dir / "image.png").write_bytes(b"\x89PNG")
|
||||||
|
|
||||||
|
files = service.walk_vault()
|
||||||
|
filenames = [f.name for f in files]
|
||||||
|
assert "image.png" not in filenames
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateNote:
|
||||||
|
def test_creates_file(self, service, vault_dir):
|
||||||
|
path = service.create_note("My Test Note", "Body content")
|
||||||
|
full_path = vault_dir / path
|
||||||
|
assert full_path.exists()
|
||||||
|
|
||||||
|
def test_sanitizes_title(self, service, vault_dir):
|
||||||
|
path = service.create_note("Hello World! @#$", "Body")
|
||||||
|
assert "hello-world" in path
|
||||||
|
assert "@" not in path
|
||||||
|
assert "#" not in path
|
||||||
|
|
||||||
|
def test_includes_frontmatter(self, service, vault_dir):
|
||||||
|
path = service.create_note("Test", "Body", tags=["cat", "vet"])
|
||||||
|
full_path = vault_dir / path
|
||||||
|
content = full_path.read_text()
|
||||||
|
assert "---" in content
|
||||||
|
assert "created_by: simbarag" in content
|
||||||
|
assert "cat" in content
|
||||||
|
assert "vet" in content
|
||||||
|
|
||||||
|
def test_custom_folder(self, service, vault_dir):
|
||||||
|
path = service.create_note("Test", "Body", folder="custom/subfolder")
|
||||||
|
assert path.startswith("custom/subfolder/")
|
||||||
|
assert (vault_dir / path).exists()
|
||||||
|
|
||||||
|
|
||||||
|
class TestDailyNoteTasks:
|
||||||
|
def test_get_tasks_from_daily_note(self, service, vault_dir):
|
||||||
|
# Create a daily note with tasks
|
||||||
|
date = datetime(2026, 1, 15)
|
||||||
|
rel_path = service.get_daily_note_path(date)
|
||||||
|
note_path = vault_dir / rel_path
|
||||||
|
note_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
note_path.write_text(
|
||||||
|
"---\nmodified: 2026-01-15\n---\n"
|
||||||
|
"### tasks\n\n"
|
||||||
|
"- [ ] Feed the cat\n"
|
||||||
|
"- [x] Clean litter box\n"
|
||||||
|
"- [ ] Buy cat food\n\n"
|
||||||
|
"### log\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = service.get_daily_tasks(date)
|
||||||
|
assert result["found"] is True
|
||||||
|
assert len(result["tasks"]) == 3
|
||||||
|
assert result["tasks"][0] == {"text": "Feed the cat", "done": False}
|
||||||
|
assert result["tasks"][1] == {"text": "Clean litter box", "done": True}
|
||||||
|
assert result["tasks"][2] == {"text": "Buy cat food", "done": False}
|
||||||
|
|
||||||
|
def test_get_tasks_no_note(self, service):
|
||||||
|
date = datetime(2099, 12, 31)
|
||||||
|
result = service.get_daily_tasks(date)
|
||||||
|
assert result["found"] is False
|
||||||
|
assert result["tasks"] == []
|
||||||
|
|
||||||
|
def test_add_task_creates_note(self, service, vault_dir):
|
||||||
|
date = datetime(2026, 6, 1)
|
||||||
|
result = service.add_task_to_daily_note("Walk the cat", date)
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["created_note"] is True
|
||||||
|
|
||||||
|
# Verify file was created with the task
|
||||||
|
note_path = vault_dir / result["path"]
|
||||||
|
content = note_path.read_text()
|
||||||
|
assert "- [ ] Walk the cat" in content
|
||||||
|
|
||||||
|
def test_add_task_to_existing_note(self, service, vault_dir):
|
||||||
|
date = datetime(2026, 6, 2)
|
||||||
|
rel_path = service.get_daily_note_path(date)
|
||||||
|
note_path = vault_dir / rel_path
|
||||||
|
note_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
note_path.write_text(
|
||||||
|
"---\nmodified: 2026-06-02\n---\n"
|
||||||
|
"### tasks\n\n"
|
||||||
|
"- [ ] Existing task\n\n"
|
||||||
|
"### log\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = service.add_task_to_daily_note("New task", date)
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["created_note"] is False
|
||||||
|
|
||||||
|
content = note_path.read_text()
|
||||||
|
assert "- [ ] Existing task" in content
|
||||||
|
assert "- [ ] New task" in content
|
||||||
|
|
||||||
|
def test_complete_task_exact_match(self, service, vault_dir):
|
||||||
|
date = datetime(2026, 6, 3)
|
||||||
|
rel_path = service.get_daily_note_path(date)
|
||||||
|
note_path = vault_dir / rel_path
|
||||||
|
note_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
note_path.write_text("### tasks\n\n" "- [ ] Feed the cat\n" "- [ ] Buy food\n")
|
||||||
|
|
||||||
|
result = service.complete_task_in_daily_note("Feed the cat", date)
|
||||||
|
assert result["success"] is True
|
||||||
|
|
||||||
|
content = note_path.read_text()
|
||||||
|
assert "- [x] Feed the cat" in content
|
||||||
|
assert "- [ ] Buy food" in content # Other task unchanged
|
||||||
|
|
||||||
|
def test_complete_task_partial_match(self, service, vault_dir):
|
||||||
|
date = datetime(2026, 6, 4)
|
||||||
|
rel_path = service.get_daily_note_path(date)
|
||||||
|
note_path = vault_dir / rel_path
|
||||||
|
note_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
note_path.write_text("### tasks\n\n- [ ] Feed the cat at 5pm\n")
|
||||||
|
|
||||||
|
result = service.complete_task_in_daily_note("Feed the cat", date)
|
||||||
|
assert result["success"] is True
|
||||||
|
|
||||||
|
def test_complete_task_not_found(self, service, vault_dir):
|
||||||
|
date = datetime(2026, 6, 5)
|
||||||
|
rel_path = service.get_daily_note_path(date)
|
||||||
|
note_path = vault_dir / rel_path
|
||||||
|
note_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
note_path.write_text("### tasks\n\n- [ ] Feed the cat\n")
|
||||||
|
|
||||||
|
result = service.complete_task_in_daily_note("Walk the dog", date)
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "not found" in result["error"]
|
||||||
|
|
||||||
|
def test_complete_task_no_note(self, service):
|
||||||
|
date = datetime(2099, 12, 31)
|
||||||
|
result = service.complete_task_in_daily_note("Something", date)
|
||||||
|
assert result["success"] is False
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
"""Tests for rate limiting logic in email and WhatsApp blueprints."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmailRateLimit:
|
||||||
|
def setup_method(self):
|
||||||
|
"""Reset rate limit store before each test."""
|
||||||
|
from blueprints.email import _rate_limit_store
|
||||||
|
|
||||||
|
_rate_limit_store.clear()
|
||||||
|
|
||||||
|
def test_allows_under_limit(self):
|
||||||
|
from blueprints.email import _check_rate_limit, RATE_LIMIT_MAX
|
||||||
|
|
||||||
|
for _ in range(RATE_LIMIT_MAX):
|
||||||
|
assert _check_rate_limit("sender@test.com") is True
|
||||||
|
|
||||||
|
def test_blocks_at_limit(self):
|
||||||
|
from blueprints.email import _check_rate_limit, RATE_LIMIT_MAX
|
||||||
|
|
||||||
|
for _ in range(RATE_LIMIT_MAX):
|
||||||
|
_check_rate_limit("sender@test.com")
|
||||||
|
|
||||||
|
assert _check_rate_limit("sender@test.com") is False
|
||||||
|
|
||||||
|
def test_different_senders_independent(self):
|
||||||
|
from blueprints.email import _check_rate_limit, RATE_LIMIT_MAX
|
||||||
|
|
||||||
|
for _ in range(RATE_LIMIT_MAX):
|
||||||
|
_check_rate_limit("user1@test.com")
|
||||||
|
|
||||||
|
# user1 is at limit, but user2 should be fine
|
||||||
|
assert _check_rate_limit("user1@test.com") is False
|
||||||
|
assert _check_rate_limit("user2@test.com") is True
|
||||||
|
|
||||||
|
def test_window_expiry(self):
|
||||||
|
from blueprints.email import (
|
||||||
|
_check_rate_limit,
|
||||||
|
_rate_limit_store,
|
||||||
|
RATE_LIMIT_MAX,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fill up the rate limit with timestamps in the past
|
||||||
|
past = time.monotonic() - 999 # Well beyond any window
|
||||||
|
_rate_limit_store["old@test.com"] = [past] * RATE_LIMIT_MAX
|
||||||
|
|
||||||
|
# Should be allowed because all timestamps are expired
|
||||||
|
assert _check_rate_limit("old@test.com") is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestWhatsAppRateLimit:
|
||||||
|
def setup_method(self):
|
||||||
|
"""Reset rate limit store before each test."""
|
||||||
|
from blueprints.whatsapp import _rate_limit_store
|
||||||
|
|
||||||
|
_rate_limit_store.clear()
|
||||||
|
|
||||||
|
def test_allows_under_limit(self):
|
||||||
|
from blueprints.whatsapp import _check_rate_limit, RATE_LIMIT_MAX
|
||||||
|
|
||||||
|
for _ in range(RATE_LIMIT_MAX):
|
||||||
|
assert _check_rate_limit("whatsapp:+1234567890") is True
|
||||||
|
|
||||||
|
def test_blocks_at_limit(self):
|
||||||
|
from blueprints.whatsapp import _check_rate_limit, RATE_LIMIT_MAX
|
||||||
|
|
||||||
|
for _ in range(RATE_LIMIT_MAX):
|
||||||
|
_check_rate_limit("whatsapp:+1234567890")
|
||||||
|
|
||||||
|
assert _check_rate_limit("whatsapp:+1234567890") is False
|
||||||
|
|
||||||
|
def test_different_numbers_independent(self):
|
||||||
|
from blueprints.whatsapp import _check_rate_limit, RATE_LIMIT_MAX
|
||||||
|
|
||||||
|
for _ in range(RATE_LIMIT_MAX):
|
||||||
|
_check_rate_limit("whatsapp:+1111111111")
|
||||||
|
|
||||||
|
assert _check_rate_limit("whatsapp:+1111111111") is False
|
||||||
|
assert _check_rate_limit("whatsapp:+2222222222") is True
|
||||||
|
|
||||||
|
def test_window_expiry(self):
|
||||||
|
from blueprints.whatsapp import (
|
||||||
|
_check_rate_limit,
|
||||||
|
_rate_limit_store,
|
||||||
|
RATE_LIMIT_MAX,
|
||||||
|
)
|
||||||
|
|
||||||
|
past = time.monotonic() - 999
|
||||||
|
_rate_limit_store["whatsapp:+9999999999"] = [past] * RATE_LIMIT_MAX
|
||||||
|
|
||||||
|
assert _check_rate_limit("whatsapp:+9999999999") is True
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
"""Tests for User model methods in blueprints/users/models.py."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserModelMethods:
|
||||||
|
"""Test User model methods without requiring a database connection.
|
||||||
|
|
||||||
|
We instantiate a mock object with the same methods as User
|
||||||
|
to avoid Tortoise ORM initialization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _make_user(self, ldap_groups=None, password=None):
|
||||||
|
"""Create a mock user with real method implementations."""
|
||||||
|
from blueprints.users.models import User
|
||||||
|
|
||||||
|
user = MagicMock(spec=User)
|
||||||
|
user.ldap_groups = ldap_groups
|
||||||
|
user.password = password
|
||||||
|
|
||||||
|
# Bind real methods
|
||||||
|
user.has_group = lambda group: group in (user.ldap_groups or [])
|
||||||
|
user.is_admin = lambda: user.has_group("lldap_admin")
|
||||||
|
|
||||||
|
def set_password(plain):
|
||||||
|
user.password = bcrypt.hashpw(plain.encode("utf-8"), bcrypt.gensalt())
|
||||||
|
|
||||||
|
user.set_password = set_password
|
||||||
|
|
||||||
|
def verify_password(plain):
|
||||||
|
if not user.password:
|
||||||
|
return False
|
||||||
|
return bcrypt.checkpw(plain.encode("utf-8"), user.password)
|
||||||
|
|
||||||
|
user.verify_password = verify_password
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
def test_has_group_true(self):
|
||||||
|
user = self._make_user(ldap_groups=["lldap_admin", "users"])
|
||||||
|
assert user.has_group("lldap_admin") is True
|
||||||
|
assert user.has_group("users") is True
|
||||||
|
|
||||||
|
def test_has_group_false(self):
|
||||||
|
user = self._make_user(ldap_groups=["users"])
|
||||||
|
assert user.has_group("lldap_admin") is False
|
||||||
|
|
||||||
|
def test_has_group_empty_list(self):
|
||||||
|
user = self._make_user(ldap_groups=[])
|
||||||
|
assert user.has_group("anything") is False
|
||||||
|
|
||||||
|
def test_has_group_none(self):
|
||||||
|
user = self._make_user(ldap_groups=None)
|
||||||
|
assert user.has_group("anything") is False
|
||||||
|
|
||||||
|
def test_is_admin_true(self):
|
||||||
|
user = self._make_user(ldap_groups=["lldap_admin"])
|
||||||
|
assert user.is_admin() is True
|
||||||
|
|
||||||
|
def test_is_admin_false(self):
|
||||||
|
user = self._make_user(ldap_groups=["users"])
|
||||||
|
assert user.is_admin() is False
|
||||||
|
|
||||||
|
def test_is_admin_empty(self):
|
||||||
|
user = self._make_user(ldap_groups=[])
|
||||||
|
assert user.is_admin() is False
|
||||||
|
|
||||||
|
def test_set_and_verify_password(self):
|
||||||
|
user = self._make_user()
|
||||||
|
user.set_password("hunter2")
|
||||||
|
assert user.password is not None
|
||||||
|
assert user.verify_password("hunter2") is True
|
||||||
|
assert user.verify_password("wrong") is False
|
||||||
|
|
||||||
|
def test_verify_password_no_password_set(self):
|
||||||
|
user = self._make_user(password=None)
|
||||||
|
assert user.verify_password("anything") is False
|
||||||
|
|
||||||
|
def test_password_is_hashed(self):
|
||||||
|
user = self._make_user()
|
||||||
|
user.set_password("mypassword")
|
||||||
|
# The stored password should not be the plaintext
|
||||||
|
assert user.password != b"mypassword"
|
||||||
|
assert user.password != "mypassword"
|
||||||
@@ -0,0 +1,254 @@
|
|||||||
|
"""Tests for YNAB service data formatting and filtering logic."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_category(
|
||||||
|
name, budgeted, activity, balance, deleted=False, hidden=False, goal_type=None
|
||||||
|
):
|
||||||
|
cat = MagicMock()
|
||||||
|
cat.name = name
|
||||||
|
cat.budgeted = budgeted
|
||||||
|
cat.activity = activity
|
||||||
|
cat.balance = balance
|
||||||
|
cat.deleted = deleted
|
||||||
|
cat.hidden = hidden
|
||||||
|
cat.goal_type = goal_type
|
||||||
|
return cat
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_transaction(
|
||||||
|
var_date, payee_name, category_name, amount, memo="", deleted=False, approved=True
|
||||||
|
):
|
||||||
|
txn = MagicMock()
|
||||||
|
txn.var_date = var_date
|
||||||
|
txn.payee_name = payee_name
|
||||||
|
txn.category_name = category_name
|
||||||
|
txn.amount = amount
|
||||||
|
txn.memo = memo
|
||||||
|
txn.deleted = deleted
|
||||||
|
txn.approved = approved
|
||||||
|
return txn
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ynab_service():
|
||||||
|
"""Create a YNABService with mocked API client."""
|
||||||
|
with patch.dict(
|
||||||
|
os.environ, {"YNAB_ACCESS_TOKEN": "fake-token", "YNAB_BUDGET_ID": "test-budget"}
|
||||||
|
):
|
||||||
|
with patch("utils.ynab_service.ynab") as mock_ynab:
|
||||||
|
# Mock the configuration and API client chain
|
||||||
|
mock_ynab.Configuration.return_value = MagicMock()
|
||||||
|
mock_ynab.ApiClient.return_value = MagicMock()
|
||||||
|
mock_ynab.PlansApi.return_value = MagicMock()
|
||||||
|
mock_ynab.TransactionsApi.return_value = MagicMock()
|
||||||
|
mock_ynab.MonthsApi.return_value = MagicMock()
|
||||||
|
mock_ynab.CategoriesApi.return_value = MagicMock()
|
||||||
|
|
||||||
|
from utils.ynab_service import YNABService
|
||||||
|
|
||||||
|
service = YNABService()
|
||||||
|
yield service
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetBudgetSummary:
|
||||||
|
def test_calculates_totals(self, ynab_service):
|
||||||
|
categories = [
|
||||||
|
_mock_category("Groceries", 500_000, -350_000, 150_000),
|
||||||
|
_mock_category("Rent", 1_500_000, -1_500_000, 0),
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_month = MagicMock()
|
||||||
|
mock_month.to_be_budgeted = 200_000
|
||||||
|
|
||||||
|
mock_budget = MagicMock()
|
||||||
|
mock_budget.name = "My Budget"
|
||||||
|
mock_budget.months = [mock_month]
|
||||||
|
mock_budget.categories = categories
|
||||||
|
mock_budget.currency_format = MagicMock(iso_code="USD")
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data.budget = mock_budget
|
||||||
|
ynab_service.plans_api.get_plan_by_id.return_value = mock_response
|
||||||
|
|
||||||
|
result = ynab_service.get_budget_summary()
|
||||||
|
|
||||||
|
assert result["budget_name"] == "My Budget"
|
||||||
|
assert result["to_be_budgeted"] == 200.0
|
||||||
|
assert result["total_budgeted"] == 2000.0 # (500k + 1500k) / 1000
|
||||||
|
assert result["total_activity"] == -1850.0
|
||||||
|
assert result["currency_format"] == "USD"
|
||||||
|
|
||||||
|
def test_skips_deleted_and_hidden(self, ynab_service):
|
||||||
|
categories = [
|
||||||
|
_mock_category("Active", 100_000, -50_000, 50_000),
|
||||||
|
_mock_category("Deleted", 999_000, -999_000, 0, deleted=True),
|
||||||
|
_mock_category("Hidden", 999_000, -999_000, 0, hidden=True),
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_month = MagicMock()
|
||||||
|
mock_month.to_be_budgeted = 0
|
||||||
|
mock_budget = MagicMock()
|
||||||
|
mock_budget.name = "Budget"
|
||||||
|
mock_budget.months = [mock_month]
|
||||||
|
mock_budget.categories = categories
|
||||||
|
mock_budget.currency_format = None
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data.budget = mock_budget
|
||||||
|
ynab_service.plans_api.get_plan_by_id.return_value = mock_response
|
||||||
|
|
||||||
|
result = ynab_service.get_budget_summary()
|
||||||
|
assert result["total_budgeted"] == 100.0
|
||||||
|
assert result["currency_format"] == "USD" # Default fallback
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetTransactions:
|
||||||
|
def test_filters_by_date_range(self, ynab_service):
|
||||||
|
transactions = [
|
||||||
|
_mock_transaction("2026-01-05", "Store", "Groceries", -25_000),
|
||||||
|
_mock_transaction("2026-01-15", "Gas", "Transport", -40_000),
|
||||||
|
_mock_transaction(
|
||||||
|
"2026-02-01", "Store", "Groceries", -30_000
|
||||||
|
), # Out of range
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data.transactions = transactions
|
||||||
|
ynab_service.transactions_api.get_transactions.return_value = mock_response
|
||||||
|
|
||||||
|
result = ynab_service.get_transactions(
|
||||||
|
start_date="2026-01-01", end_date="2026-01-31"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["count"] == 2
|
||||||
|
assert result["total_amount"] == -65.0
|
||||||
|
|
||||||
|
def test_filters_by_category(self, ynab_service):
|
||||||
|
transactions = [
|
||||||
|
_mock_transaction("2026-01-05", "Store", "Groceries", -25_000),
|
||||||
|
_mock_transaction("2026-01-06", "Gas", "Transport", -40_000),
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data.transactions = transactions
|
||||||
|
ynab_service.transactions_api.get_transactions.return_value = mock_response
|
||||||
|
|
||||||
|
result = ynab_service.get_transactions(
|
||||||
|
start_date="2026-01-01",
|
||||||
|
end_date="2026-01-31",
|
||||||
|
category_name="groceries", # Case insensitive
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["count"] == 1
|
||||||
|
assert result["transactions"][0]["category"] == "Groceries"
|
||||||
|
|
||||||
|
def test_filters_by_payee(self, ynab_service):
|
||||||
|
transactions = [
|
||||||
|
_mock_transaction("2026-01-05", "Whole Foods", "Groceries", -25_000),
|
||||||
|
_mock_transaction("2026-01-06", "Shell Gas", "Transport", -40_000),
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data.transactions = transactions
|
||||||
|
ynab_service.transactions_api.get_transactions.return_value = mock_response
|
||||||
|
|
||||||
|
result = ynab_service.get_transactions(
|
||||||
|
start_date="2026-01-01",
|
||||||
|
end_date="2026-01-31",
|
||||||
|
payee_name="whole",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["count"] == 1
|
||||||
|
assert result["transactions"][0]["payee"] == "Whole Foods"
|
||||||
|
|
||||||
|
def test_skips_deleted(self, ynab_service):
|
||||||
|
transactions = [
|
||||||
|
_mock_transaction("2026-01-05", "Store", "Groceries", -25_000),
|
||||||
|
_mock_transaction("2026-01-06", "Deleted", "Other", -10_000, deleted=True),
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data.transactions = transactions
|
||||||
|
ynab_service.transactions_api.get_transactions.return_value = mock_response
|
||||||
|
|
||||||
|
result = ynab_service.get_transactions(
|
||||||
|
start_date="2026-01-01", end_date="2026-01-31"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["count"] == 1
|
||||||
|
|
||||||
|
def test_converts_milliunits(self, ynab_service):
|
||||||
|
transactions = [
|
||||||
|
_mock_transaction("2026-01-05", "Store", "Groceries", -12_340),
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data.transactions = transactions
|
||||||
|
ynab_service.transactions_api.get_transactions.return_value = mock_response
|
||||||
|
|
||||||
|
result = ynab_service.get_transactions(
|
||||||
|
start_date="2026-01-01", end_date="2026-01-31"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["transactions"][0]["amount"] == -12.34
|
||||||
|
|
||||||
|
def test_sorts_by_date_descending(self, ynab_service):
|
||||||
|
transactions = [
|
||||||
|
_mock_transaction("2026-01-01", "A", "Cat", -10_000),
|
||||||
|
_mock_transaction("2026-01-15", "B", "Cat", -20_000),
|
||||||
|
_mock_transaction("2026-01-10", "C", "Cat", -30_000),
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data.transactions = transactions
|
||||||
|
ynab_service.transactions_api.get_transactions.return_value = mock_response
|
||||||
|
|
||||||
|
result = ynab_service.get_transactions(
|
||||||
|
start_date="2026-01-01", end_date="2026-01-31"
|
||||||
|
)
|
||||||
|
|
||||||
|
dates = [t["date"] for t in result["transactions"]]
|
||||||
|
assert dates == sorted(dates, reverse=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCategorySpending:
|
||||||
|
def test_month_format_normalization(self, ynab_service):
|
||||||
|
"""Passing YYYY-MM should be normalized to YYYY-MM-01."""
|
||||||
|
categories = [_mock_category("Food", 100_000, -50_000, 50_000)]
|
||||||
|
|
||||||
|
mock_month = MagicMock()
|
||||||
|
mock_month.categories = categories
|
||||||
|
mock_month.to_be_budgeted = 0
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data.month = mock_month
|
||||||
|
ynab_service.months_api.get_plan_month.return_value = mock_response
|
||||||
|
|
||||||
|
result = ynab_service.get_category_spending("2026-03")
|
||||||
|
|
||||||
|
assert result["month"] == "2026-03"
|
||||||
|
|
||||||
|
def test_identifies_overspent(self, ynab_service):
|
||||||
|
categories = [
|
||||||
|
_mock_category("Dining", 200_000, -300_000, -100_000), # Overspent
|
||||||
|
_mock_category("Groceries", 500_000, -400_000, 100_000), # Fine
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_month = MagicMock()
|
||||||
|
mock_month.categories = categories
|
||||||
|
mock_month.to_be_budgeted = 0
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data.month = mock_month
|
||||||
|
ynab_service.months_api.get_plan_month.return_value = mock_response
|
||||||
|
|
||||||
|
result = ynab_service.get_category_spending("2026-03")
|
||||||
|
|
||||||
|
assert len(result["overspent_categories"]) == 1
|
||||||
|
assert result["overspent_categories"][0]["name"] == "Dining"
|
||||||
|
assert result["overspent_categories"][0]["overspent_by"] == 100.0
|
||||||
@@ -1,137 +0,0 @@
|
|||||||
import os
|
|
||||||
from math import ceil
|
|
||||||
import re
|
|
||||||
from typing import Union
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
|
||||||
OpenAIEmbeddingFunction,
|
|
||||||
)
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from llm import LLMClient
|
|
||||||
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
|
|
||||||
def remove_headers_footers(text, header_patterns=None, footer_patterns=None):
|
|
||||||
if header_patterns is None:
|
|
||||||
header_patterns = [r"^.*Header.*$"]
|
|
||||||
if footer_patterns is None:
|
|
||||||
footer_patterns = [r"^.*Footer.*$"]
|
|
||||||
|
|
||||||
for pattern in header_patterns + footer_patterns:
|
|
||||||
text = re.sub(pattern, "", text, flags=re.MULTILINE)
|
|
||||||
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def remove_special_characters(text, special_chars=None):
|
|
||||||
if special_chars is None:
|
|
||||||
special_chars = r"[^A-Za-z0-9\s\.,;:\'\"\?\!\-]"
|
|
||||||
|
|
||||||
text = re.sub(special_chars, "", text)
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def remove_repeated_substrings(text, pattern=r"\.{2,}"):
|
|
||||||
text = re.sub(pattern, ".", text)
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def remove_extra_spaces(text):
|
|
||||||
text = re.sub(r"\n\s*\n", "\n\n", text)
|
|
||||||
text = re.sub(r"\s+", " ", text)
|
|
||||||
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_text(text):
|
|
||||||
# Remove headers and footers
|
|
||||||
text = remove_headers_footers(text)
|
|
||||||
|
|
||||||
# Remove special characters
|
|
||||||
text = remove_special_characters(text)
|
|
||||||
|
|
||||||
# Remove repeated substrings like dots
|
|
||||||
text = remove_repeated_substrings(text)
|
|
||||||
|
|
||||||
# Remove extra spaces between lines and within lines
|
|
||||||
text = remove_extra_spaces(text)
|
|
||||||
|
|
||||||
# Additional cleaning steps can be added here
|
|
||||||
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
class Chunk:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
text: str,
|
|
||||||
size: int,
|
|
||||||
document_id: UUID,
|
|
||||||
chunk_id: int,
|
|
||||||
embedding,
|
|
||||||
):
|
|
||||||
self.text = text
|
|
||||||
self.size = size
|
|
||||||
self.document_id = document_id
|
|
||||||
self.chunk_id = chunk_id
|
|
||||||
self.embedding = embedding
|
|
||||||
|
|
||||||
|
|
||||||
class Chunker:
|
|
||||||
def __init__(self, collection) -> None:
|
|
||||||
self.collection = collection
|
|
||||||
self.llm_client = LLMClient()
|
|
||||||
|
|
||||||
def embedding_fx(self, inputs):
|
|
||||||
openai_embedding_fx = OpenAIEmbeddingFunction(
|
|
||||||
api_key=os.getenv("OPENAI_API_KEY"),
|
|
||||||
model_name="text-embedding-3-small",
|
|
||||||
)
|
|
||||||
return openai_embedding_fx(inputs)
|
|
||||||
|
|
||||||
def chunk_document(
|
|
||||||
self,
|
|
||||||
document: str,
|
|
||||||
chunk_size: int = 1000,
|
|
||||||
metadata: dict[str, Union[str, float]] = {},
|
|
||||||
) -> list[Chunk]:
|
|
||||||
doc_uuid = uuid4()
|
|
||||||
|
|
||||||
chunk_size = min(chunk_size, len(document)) or 1
|
|
||||||
|
|
||||||
chunks = []
|
|
||||||
num_chunks = ceil(len(document) / chunk_size)
|
|
||||||
document_length = len(document)
|
|
||||||
|
|
||||||
for i in range(num_chunks):
|
|
||||||
curr_pos = i * num_chunks
|
|
||||||
to_pos = (
|
|
||||||
curr_pos + chunk_size
|
|
||||||
if curr_pos + chunk_size < document_length
|
|
||||||
else document_length
|
|
||||||
)
|
|
||||||
text_chunk = self.clean_document(document[curr_pos:to_pos])
|
|
||||||
|
|
||||||
embedding = self.embedding_fx([text_chunk])
|
|
||||||
self.collection.add(
|
|
||||||
ids=[str(doc_uuid) + ":" + str(i)],
|
|
||||||
documents=[text_chunk],
|
|
||||||
embeddings=embedding,
|
|
||||||
metadatas=[metadata],
|
|
||||||
)
|
|
||||||
|
|
||||||
return chunks
|
|
||||||
|
|
||||||
def clean_document(self, document: str) -> str:
|
|
||||||
"""This function will remove information that is noise or already known.
|
|
||||||
|
|
||||||
Example: We already know all the things in here are Simba-related, so we don't need things like
|
|
||||||
"Sumamry of simba's visit"
|
|
||||||
"""
|
|
||||||
|
|
||||||
document = document.replace("\\n", "")
|
|
||||||
document = document.strip()
|
|
||||||
|
|
||||||
return preprocess_text(document)
|
|
||||||
@@ -47,6 +47,16 @@ async def get_image(key: str) -> tuple[bytes, str]:
|
|||||||
return body, content_type
|
return body, content_type
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_presigned_url(key: str, expires_in: int = 3600) -> str:
|
||||||
|
async with _get_client() as client:
|
||||||
|
url = await client.generate_presigned_url(
|
||||||
|
"get_object",
|
||||||
|
Params={"Bucket": S3_BUCKET_NAME, "Key": key},
|
||||||
|
ExpiresIn=expires_in,
|
||||||
|
)
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
async def delete_image(key: str) -> None:
|
async def delete_image(key: str) -> None:
|
||||||
async with _get_client() as client:
|
async with _get_client() as client:
|
||||||
await client.delete_object(Bucket=S3_BUCKET_NAME, Key=key)
|
await client.delete_object(Bucket=S3_BUCKET_NAME, Key=key)
|
||||||
|
|||||||
Reference in New Issue
Block a user