Compare commits
4 Commits
0f88d211de
...
2fcf84f5d2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2fcf84f5d2 | ||
|
|
142fac3a84 | ||
|
|
0415610d64 | ||
|
|
ac9c821ec7 |
20
app.py
20
app.py
@@ -13,6 +13,7 @@ import blueprints.users
|
||||
import blueprints.whatsapp
|
||||
import blueprints.email
|
||||
import blueprints.users.models
|
||||
from config.db import TORTOISE_CONFIG
|
||||
from main import consult_simba_oracle
|
||||
|
||||
# Load environment variables
|
||||
@@ -28,6 +29,7 @@ app = Quart(
|
||||
)
|
||||
|
||||
app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY", "SECRET_KEY")
|
||||
app.config["MAX_CONTENT_LENGTH"] = 10 * 1024 * 1024 # 10 MB upload limit
|
||||
jwt = JWTManager(app)
|
||||
|
||||
# Register blueprints
|
||||
@@ -38,24 +40,6 @@ app.register_blueprint(blueprints.whatsapp.whatsapp_blueprint)
|
||||
app.register_blueprint(blueprints.email.email_blueprint)
|
||||
|
||||
|
||||
# Database configuration with environment variable support
|
||||
DATABASE_URL = os.getenv(
|
||||
"DATABASE_URL", "postgres://raggr:raggr_dev_password@localhost:5432/raggr"
|
||||
)
|
||||
|
||||
TORTOISE_CONFIG = {
|
||||
"connections": {"default": DATABASE_URL},
|
||||
"apps": {
|
||||
"models": {
|
||||
"models": [
|
||||
"blueprints.conversation.models",
|
||||
"blueprints.users.models",
|
||||
"aerich.models",
|
||||
]
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Initialize Tortoise ORM with lifecycle hooks
|
||||
@app.while_serving
|
||||
async def lifespan():
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from quart import Blueprint, jsonify, make_response, request
|
||||
from quart import Blueprint, Response, jsonify, make_response, request
|
||||
from quart_jwt_extended import (
|
||||
get_jwt_identity,
|
||||
jwt_refresh_token_required,
|
||||
)
|
||||
|
||||
import blueprints.users.models
|
||||
from utils.image_process import analyze_user_image
|
||||
from utils.image_upload import ImageValidationError, process_image
|
||||
from utils.s3_client import get_image as s3_get_image
|
||||
from utils.s3_client import upload_image as s3_upload_image
|
||||
|
||||
from .agents import main_agent
|
||||
from .logic import (
|
||||
@@ -29,7 +35,9 @@ conversation_blueprint = Blueprint(
|
||||
_SYSTEM_PROMPT = SIMBA_SYSTEM_PROMPT
|
||||
|
||||
|
||||
def _build_messages_payload(conversation, query_text: str) -> list:
|
||||
def _build_messages_payload(
|
||||
conversation, query_text: str, image_description: str | None = None
|
||||
) -> list:
|
||||
recent_messages = (
|
||||
conversation.messages[-10:]
|
||||
if len(conversation.messages) > 10
|
||||
@@ -38,8 +46,19 @@ def _build_messages_payload(conversation, query_text: str) -> list:
|
||||
messages_payload = [{"role": "system", "content": _SYSTEM_PROMPT}]
|
||||
for msg in recent_messages[:-1]: # Exclude the message we just added
|
||||
role = "user" if msg.speaker == "user" else "assistant"
|
||||
messages_payload.append({"role": role, "content": msg.text})
|
||||
messages_payload.append({"role": "user", "content": query_text})
|
||||
text = msg.text
|
||||
if msg.image_key and role == "user":
|
||||
text = f"[User sent an image]\n{text}"
|
||||
messages_payload.append({"role": role, "content": text})
|
||||
|
||||
# Build the current user message with optional image description
|
||||
if image_description:
|
||||
content = f"[Image analysis: {image_description}]"
|
||||
if query_text:
|
||||
content = f"{query_text}\n\n{content}"
|
||||
else:
|
||||
content = query_text
|
||||
messages_payload.append({"role": "user", "content": content})
|
||||
return messages_payload
|
||||
|
||||
|
||||
@@ -74,6 +93,58 @@ async def query():
|
||||
return jsonify({"response": message})
|
||||
|
||||
|
||||
@conversation_blueprint.post("/upload-image")
|
||||
@jwt_refresh_token_required
|
||||
async def upload_image():
|
||||
current_user_uuid = get_jwt_identity()
|
||||
await blueprints.users.models.User.get(id=current_user_uuid)
|
||||
|
||||
files = await request.files
|
||||
form = await request.form
|
||||
file = files.get("file")
|
||||
conversation_id = form.get("conversation_id")
|
||||
|
||||
if not file or not conversation_id:
|
||||
return jsonify({"error": "file and conversation_id are required"}), 400
|
||||
|
||||
file_bytes = file.read()
|
||||
content_type = file.content_type or "image/jpeg"
|
||||
|
||||
try:
|
||||
processed_bytes, output_content_type = process_image(file_bytes, content_type)
|
||||
except ImageValidationError as e:
|
||||
return jsonify({"error": str(e)}), 400
|
||||
|
||||
ext = output_content_type.split("/")[-1]
|
||||
if ext == "jpeg":
|
||||
ext = "jpg"
|
||||
key = f"conversations/{conversation_id}/{uuid.uuid4()}.{ext}"
|
||||
|
||||
await s3_upload_image(processed_bytes, key, output_content_type)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"image_key": key,
|
||||
"image_url": f"/api/conversation/image/{key}",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@conversation_blueprint.get("/image/<path:image_key>")
|
||||
@jwt_refresh_token_required
|
||||
async def serve_image(image_key: str):
|
||||
try:
|
||||
image_bytes, content_type = await s3_get_image(image_key)
|
||||
except Exception:
|
||||
return jsonify({"error": "Image not found"}), 404
|
||||
|
||||
return Response(
|
||||
image_bytes,
|
||||
content_type=content_type,
|
||||
headers={"Cache-Control": "private, max-age=3600"},
|
||||
)
|
||||
|
||||
|
||||
@conversation_blueprint.post("/stream-query")
|
||||
@jwt_refresh_token_required
|
||||
async def stream_query():
|
||||
@@ -82,16 +153,31 @@ async def stream_query():
|
||||
data = await request.get_json()
|
||||
query_text = data.get("query")
|
||||
conversation_id = data.get("conversation_id")
|
||||
image_key = data.get("image_key")
|
||||
conversation = await get_conversation_by_id(conversation_id)
|
||||
await conversation.fetch_related("messages")
|
||||
await add_message_to_conversation(
|
||||
conversation=conversation,
|
||||
message=query_text,
|
||||
message=query_text or "",
|
||||
speaker="user",
|
||||
user=user,
|
||||
image_key=image_key,
|
||||
)
|
||||
|
||||
messages_payload = _build_messages_payload(conversation, query_text)
|
||||
# If an image was uploaded, analyze it with the vision model
|
||||
image_description = None
|
||||
if image_key:
|
||||
try:
|
||||
image_bytes, _ = await s3_get_image(image_key)
|
||||
image_description = await analyze_user_image(image_bytes)
|
||||
logging.info(f"Image analysis complete for {image_key}")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to analyze image: {e}")
|
||||
image_description = "[Image could not be analyzed]"
|
||||
|
||||
messages_payload = _build_messages_payload(
|
||||
conversation, query_text or "", image_description
|
||||
)
|
||||
payload = {"messages": messages_payload}
|
||||
|
||||
async def event_generator():
|
||||
@@ -160,6 +246,7 @@ async def get_conversation(conversation_id: str):
|
||||
"text": msg.text,
|
||||
"speaker": msg.speaker.value,
|
||||
"created_at": msg.created_at.isoformat(),
|
||||
"image_key": msg.image_key,
|
||||
}
|
||||
)
|
||||
name = conversation.name
|
||||
|
||||
@@ -16,12 +16,14 @@ async def add_message_to_conversation(
|
||||
message: str,
|
||||
speaker: str,
|
||||
user: blueprints.users.models.User,
|
||||
image_key: str | None = None,
|
||||
) -> ConversationMessage:
|
||||
print(conversation, message, speaker)
|
||||
message = await ConversationMessage.create(
|
||||
text=message,
|
||||
speaker=speaker,
|
||||
conversation=conversation,
|
||||
image_key=image_key,
|
||||
)
|
||||
|
||||
return message
|
||||
|
||||
@@ -41,6 +41,7 @@ class ConversationMessage(Model):
|
||||
)
|
||||
created_at = fields.DatetimeField(auto_now_add=True)
|
||||
speaker = fields.CharEnumField(enum_type=Speaker, max_length=10)
|
||||
image_key = fields.CharField(max_length=512, null=True, default=None)
|
||||
|
||||
class Meta:
|
||||
table = "conversation_messages"
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Database configuration with environment variable support
|
||||
# Use DATABASE_PATH for relative paths or DATABASE_URL for full connection strings
|
||||
DATABASE_PATH = os.getenv("DATABASE_PATH", "database/raggr.db")
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", f"sqlite://{DATABASE_PATH}")
|
||||
DATABASE_URL = os.getenv(
|
||||
"DATABASE_URL", "postgres://raggr:raggr_dev_password@localhost:5432/raggr"
|
||||
)
|
||||
|
||||
TORTOISE_ORM = {
|
||||
TORTOISE_CONFIG = {
|
||||
"connections": {"default": DATABASE_URL},
|
||||
"apps": {
|
||||
"models": {
|
||||
@@ -55,6 +55,12 @@ services:
|
||||
- OBSIDIAN_DEVICE_NAME=${OBSIDIAN_DEVICE_NAME}
|
||||
- OBSIDIAN_CONTINUOUS_SYNC=${OBSIDIAN_CONTINUOUS_SYNC:-false}
|
||||
- OBSIDIAN_VAULT_PATH=${OBSIDIAN_VAULT_PATH:-/app/data/obsidian}
|
||||
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL}
|
||||
- S3_ACCESS_KEY_ID=${S3_ACCESS_KEY_ID}
|
||||
- S3_SECRET_ACCESS_KEY=${S3_SECRET_ACCESS_KEY}
|
||||
- S3_BUCKET_NAME=${S3_BUCKET_NAME:-asksimba-images}
|
||||
- S3_REGION=${S3_REGION:-garage}
|
||||
- OLLAMA_HOST=${OLLAMA_HOST:-http://localhost:11434}
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
|
||||
@@ -15,3 +15,32 @@ async def downgrade(db: BaseDBAsyncClient) -> str:
|
||||
DROP INDEX IF EXISTS "idx_users_email_h_a1b2c3";
|
||||
ALTER TABLE "users" DROP COLUMN "email_hmac_token";
|
||||
ALTER TABLE "users" DROP COLUMN "email_enabled";"""
|
||||
|
||||
|
||||
MODELS_STATE = (
|
||||
"eJztmm1v4jgQx78Kyquu1KtaKN1VdTopUHrLbYEThX3qVZFJXMg1sbOJsxRV/e5nm4Q4jg"
|
||||
"OEAoU93rRl7CH2z2PP35M+ay62oBOc1DH6Cf0AEBsj7bL0rCHgQvqHsv24pAHPS1qZgYCB"
|
||||
"wx1MoSdvAYOA+MAktPEBOAGkJgsGpm970cNQ6DjMiE3a0UbDxBQi+0cIDYKHkIygTxvu7q"
|
||||
"nZRhZ8gkH80Xs0HmzoWKlx2xZ7NrcbZOJxW7/fvLrmPdnjBoaJndBFSW9vQkYYzbqHoW2d"
|
||||
"MB/WNoQI+oBAS5gGG2U07dg0HTE1ED+Es6FaicGCDyB0GAzt94cQmYxBiT+J/Tj/QyuAh6"
|
||||
"JmaG1EGIvnl+mskjlzq8YeVf+od48qF+/4LHFAhj5v5ES0F+4ICJi6cq4JSP47g7I+Ar4a"
|
||||
"ZdxfgkkHugrG2JBwTGIoBhkDWo2a5oInw4FoSEb0Y7lanYPxs97lJGkvjhLTuJ5GfTtqKk"
|
||||
"/bGNIEoelDNmUDkCzIK9pCbBeqYaY9JaRW5HoS/7GjgOkcrA5yJtEmmMO312w1bnt66282"
|
||||
"EzcIfjgckd5rsJYyt04k69GFtBSzLyl9afY+ltjH0vdOuyHH/qxf77vGxgRCgg2ExwawhP"
|
||||
"0aW2MwqYUNPWvFhU17Hhb2TRc2GrywrgH0jWIZRHB5RRqJxrbFRVw9abDU+/CozBkMRhbe"
|
||||
"NfahPUSf4IQjbNJxAGSqkkUkOvrR1+wqtMSajMIH45kaEYOCzo7OCZJp9tRv6/pVQ+MMB8"
|
||||
"B8HAPfMnJgujAIwBAGWaC1yPP6Uxc6M2mmZikKuNb0G3fzVMljy1nhMhYYpehlm9yyK1sA"
|
||||
"ovO2omezJ82hs0AFCxCXE8OGuJAHUbzXopjAJ0XK71GrGmXcf19E8bxU3vjaS2XxWPoetf"
|
||||
"Sv71KZ/KbT/jPuLkjl+k2ndlDIv6KQyirkwIPgUSUG2AWygUI3IwVSqyu4v/HW0fq3je5l"
|
||||
"iWX0f9Bts1XTL0uB7Q6AttwSp26ZZ6dLXDLPTnPvmKxJ2kBioil2zCtc13nm76mENaWC1y"
|
||||
"ulrFw/21mKCzWtIlyKattNKjl+Z1BIt/guka/V2NY+aLP912ZsHYsWLUWffdFoWyhceiAI"
|
||||
"xthXRGbNRsCfqGGKXhLMwYRM7z+7eqVXwasxvSrKLYqs1mzr3W9qyRv3F+O29q3X0CW60A"
|
||||
"W2UyRKZw7rCdHFO36dAXp2upzomad6MrJnPAIkoEe6QZXkIE9mqmEqXFfCKofqdqlWloFa"
|
||||
"yWdaySDlQWZAxKan2vgYOxCgOQEq+srbnzpv6jAtmqoL7P9O5ya1/2tN+Urbb9UaNHg5Zt"
|
||||
"rJnkqhZrunhDtygUk1wiNUKMsFu1/y3cOIPbtY5hiQr6zCKXAhRyy2LdMIwsG/0FSUD/KB"
|
||||
"yn57CHMjWZ9e6EeG5+OftlXsSM04bk9KaQ42gfMKLZrmWl3mWK3mH6vVzLHqWMAzhj4OPU"
|
||||
"Uh/6/bTluNVHKTgPYRneWdZZvkuOTYAbnfGN67+83ofDbz+dVEuXAoCSv2BYdq4v+kmnh4"
|
||||
"3/5LLOzsdV6mKrToXWjmn8vW80J0l2+k230RqkPfNkeaooAWtRzPK6GBpM/O1NCaKOednL"
|
||||
"KExjBLwRCt/JvepPnr6N/KZ+fvzz9ULs4/0C58JDPL+zmHQXwNyS+ZsY2grHPnaz3B5VAw"
|
||||
"S6Qz3RpFBPO0+34C3EhBhz6RQKRI7/kSWXB5K3m8sdLj2uRxgWy7/vTy8h9Mf/k3"
|
||||
)
|
||||
|
||||
43
migrations/models/4_20260404080201_add_image_key.py
Normal file
43
migrations/models/4_20260404080201_add_image_key.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from tortoise import BaseDBAsyncClient
|
||||
|
||||
RUN_IN_TRANSACTION = True
|
||||
|
||||
|
||||
async def upgrade(db: BaseDBAsyncClient) -> str:
|
||||
return """
|
||||
ALTER TABLE "conversation_messages" ADD "image_key" VARCHAR(512);"""
|
||||
|
||||
|
||||
async def downgrade(db: BaseDBAsyncClient) -> str:
|
||||
return """
|
||||
ALTER TABLE "conversation_messages" DROP COLUMN "image_key";"""
|
||||
|
||||
|
||||
MODELS_STATE = (
|
||||
"eJztmmtv4jgUhv8KyqeO1K0KvcyoWq0UWrrDToFVC3PrVpFJXPCS2JnEGYqq/ve1TUIcx6"
|
||||
"GkBQqzfGnLsQ+xH1/Oe076aHjEgW54cE7wTxiEgCKCjbPKo4GBB9kf2vb9igF8P23lBgr6"
|
||||
"rnCwpZ6iBfRDGgCbssZ74IaQmRwY2gHy44fhyHW5kdisI8KD1BRh9COCFiUDSIcwYA23d8"
|
||||
"yMsAMfYJh89EfWPYKukxk3cvizhd2iE1/Yer3mxaXoyR/Xt2ziRh5Oe/sTOiR41j2KkHPA"
|
||||
"fXjbAGIYAAodaRp8lPG0E9N0xMxAgwjOhuqkBgfeg8jlMIzf7yNscwYV8ST+4/gPowQehp"
|
||||
"qjRZhyFo9P01mlcxZWgz/q/KN5vXd0+k7MkoR0EIhGQcR4Eo6Agqmr4JqCFL9zKM+HINCj"
|
||||
"TPorMNlAX4IxMaQc0z2UgEwAvYya4YEHy4V4QIfsY+3kZA7Gz+a1IMl6CZSE7evprm/HTb"
|
||||
"VpG0eaIrQDyKdsAZoHecFaKPKgHmbWU0HqxK4HyR8bCpjNwelgdxIfgjl8u81W46Zrtv7m"
|
||||
"M/HC8IcrEJndBm+pCetEse6dKksx+5LKl2b3Y4V/rHzvtBvq3p/16343+JhARImFydgCjn"
|
||||
"ReE2sCJrOwke+8cGGznruFfdOFjQcvrWsIA6tcBJFcXhFG4rGtcRFfHjR46L0faWMGh5GH"
|
||||
"d0kCiAb4E5wIhE02DoBtXbCIRUcv/ppNhZZa01EEYDxTI/KmYLNjc4J0Gj3Nm3PzomEIhn"
|
||||
"1gj8YgcKwCmB4MQzCAYR5oPfa8/HQN3Zk007OUBVxr+o2beasUsRWsSI1IjDL08k1ezVMt"
|
||||
"ALN5O/Gz+ZPm0HlGBUsQFxPDlryQO1G81aKYwgdNyO8yqx5l0n9bRPG8UN742s1E8UT67r"
|
||||
"XMr+8ykfyq0/4z6S5J5fOrTn2nkH9FIZVXyKEPwUgnBngC2cCRl5MCmdWV3N/46Bi9m8b1"
|
||||
"WYVH9H/wTbNVN88qIfL6wFhsiTNZZvVwgSSzeliYY/Km7AFCHoss1ghOyqTqGacX8V2/9M"
|
||||
"qCPKnWFiDJehWiFG3KZSQH7XIhU+O6zPi5pemArRQPX5kWqLXIjaX4bH6g2S5l84RVqmKR"
|
||||
"f2lkcJKXFetefk3udO7261y+jmULwLLPtujdNRSBfRCGYxJodmYdYRBM9DBlLwVmf0Knue"
|
||||
"TGxeg58Opc+8vSlSGrN9vm9Td9+pD0l/dt/Vu3YSp0oQeQW2aXzhyWs0WfP/HL3KDVw8UE"
|
||||
"5DwFmZOQ4yGgIbvSLabK+0WSXQ9T47oUObleqkeLQD0qZnqUQyo2mQUxn57u4BPiQoDnbF"
|
||||
"DZVz3+zHlVl2nZUF3i/Hc6V5nzX2+q5YFeq95gm1dgZp3QVAo1210t3KEHbKYRRlCjLJ85"
|
||||
"/YrvFu7Y6uki14Ca/ku3wKm6YwlybCuM+v9CW1OKKQaq+m0hzJVEfRDRoeUH5Cdyyl2pOc"
|
||||
"f1SSnDJTZwX6FFlRx9kWv1pPhaPcldq64DfGsQkMjXvBT566bT1iNV3BSgPcxmeesgm+5X"
|
||||
"XBTSu5Xhvb1bjc7nM59fmVWLsIqw4l+wq8z+Tyqzu/9d+CUWdvZqNFcVeu69cu4f9Zbzcn"
|
||||
"mTM9L1vlQ2YYDsoaEpoMUt+/NKaCDtszE1tCYueL+pLaFxzMpmiFf+TTNp8Wr/t1r1+P3x"
|
||||
"h6PT4w+sixjJzPJ+zmWQpCHFJTN+ELR17mKtJ7nsCmapdGZHo4xgnnbfToArKeiwJ1KINe"
|
||||
"G9WCJLLm8lj1dWelyaPC4RbZcfXp7+AzcBYwM="
|
||||
)
|
||||
@@ -37,9 +37,10 @@ dependencies = [
|
||||
"ynab>=1.3.0",
|
||||
"ollama>=0.6.1",
|
||||
"twilio>=9.10.2",
|
||||
"aioboto3>=13.0.0",
|
||||
]
|
||||
|
||||
[tool.aerich]
|
||||
tortoise_orm = "app.TORTOISE_CONFIG"
|
||||
tortoise_orm = "config.db.TORTOISE_CONFIG"
|
||||
location = "./migrations"
|
||||
src_folder = "./."
|
||||
|
||||
@@ -13,6 +13,7 @@ interface Message {
|
||||
text: string;
|
||||
speaker: "user" | "simba";
|
||||
created_at: string;
|
||||
image_key?: string | null;
|
||||
}
|
||||
|
||||
interface Conversation {
|
||||
@@ -121,17 +122,52 @@ class ConversationService {
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
async uploadImage(
|
||||
file: File,
|
||||
conversationId: string,
|
||||
): Promise<{ image_key: string; image_url: string }> {
|
||||
const formData = new FormData();
|
||||
formData.append("file", file);
|
||||
formData.append("conversation_id", conversationId);
|
||||
|
||||
const response = await userService.fetchWithRefreshToken(
|
||||
`${this.conversationBaseUrl}/upload-image`,
|
||||
{
|
||||
method: "POST",
|
||||
body: formData,
|
||||
},
|
||||
{ skipContentType: true },
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.error || "Failed to upload image");
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
getImageUrl(imageKey: string): string {
|
||||
return `/api/conversation/image/${imageKey}`;
|
||||
}
|
||||
|
||||
async streamQuery(
|
||||
query: string,
|
||||
conversation_id: string,
|
||||
onEvent: SSEEventCallback,
|
||||
signal?: AbortSignal,
|
||||
imageKey?: string,
|
||||
): Promise<void> {
|
||||
const body: Record<string, string> = { query, conversation_id };
|
||||
if (imageKey) {
|
||||
body.image_key = imageKey;
|
||||
}
|
||||
|
||||
const response = await userService.fetchWithRefreshToken(
|
||||
`${this.conversationBaseUrl}/stream-query`,
|
||||
{
|
||||
method: "POST",
|
||||
body: JSON.stringify({ query, conversation_id }),
|
||||
body: JSON.stringify(body),
|
||||
signal,
|
||||
},
|
||||
);
|
||||
|
||||
@@ -106,14 +106,15 @@ class UserService {
|
||||
async fetchWithRefreshToken(
|
||||
url: string,
|
||||
options: RequestInit = {},
|
||||
{ skipContentType = false }: { skipContentType?: boolean } = {},
|
||||
): Promise<Response> {
|
||||
const refreshToken = localStorage.getItem("refresh_token");
|
||||
|
||||
// Add authorization header
|
||||
const headers = {
|
||||
"Content-Type": "application/json",
|
||||
...(options.headers || {}),
|
||||
...(refreshToken && { Authorization: `Bearer ${refreshToken}` }),
|
||||
const headers: Record<string, string> = {
|
||||
...(skipContentType ? {} : { "Content-Type": "application/json" }),
|
||||
...((options.headers as Record<string, string>) || {}),
|
||||
...(refreshToken ? { Authorization: `Bearer ${refreshToken}` } : {}),
|
||||
};
|
||||
|
||||
let response = await fetch(url, { ...options, headers });
|
||||
|
||||
@@ -14,6 +14,7 @@ import catIcon from "../assets/cat.png";
|
||||
type Message = {
|
||||
text: string;
|
||||
speaker: "simba" | "user" | "tool";
|
||||
image_key?: string | null;
|
||||
};
|
||||
|
||||
type Conversation = {
|
||||
@@ -55,6 +56,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
const [isLoading, setIsLoading] = useState<boolean>(false);
|
||||
const [isAdmin, setIsAdmin] = useState<boolean>(false);
|
||||
const [showAdminPanel, setShowAdminPanel] = useState<boolean>(false);
|
||||
const [pendingImage, setPendingImage] = useState<File | null>(null);
|
||||
|
||||
const messagesEndRef = useRef<HTMLDivElement>(null);
|
||||
const isMountedRef = useRef<boolean>(true);
|
||||
@@ -80,7 +82,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
try {
|
||||
const fetched = await conversationService.getConversation(conversation.id);
|
||||
setMessages(
|
||||
fetched.messages.map((m) => ({ text: m.text, speaker: m.speaker })),
|
||||
fetched.messages.map((m) => ({ text: m.text, speaker: m.speaker, image_key: m.image_key })),
|
||||
);
|
||||
} catch (err) {
|
||||
console.error("Failed to load messages:", err);
|
||||
@@ -120,7 +122,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
try {
|
||||
const conv = await conversationService.getConversation(selectedConversation.id);
|
||||
setSelectedConversation({ id: conv.id, title: conv.name });
|
||||
setMessages(conv.messages.map((m) => ({ text: m.text, speaker: m.speaker })));
|
||||
setMessages(conv.messages.map((m) => ({ text: m.text, speaker: m.speaker, image_key: m.image_key })));
|
||||
} catch (err) {
|
||||
console.error("Failed to load messages:", err);
|
||||
}
|
||||
@@ -129,7 +131,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
}, [selectedConversation?.id]);
|
||||
|
||||
const handleQuestionSubmit = async () => {
|
||||
if (!query.trim() || isLoading) return;
|
||||
if ((!query.trim() && !pendingImage) || isLoading) return;
|
||||
|
||||
let activeConversation = selectedConversation;
|
||||
if (!activeConversation) {
|
||||
@@ -139,9 +141,13 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
setConversations((prev) => [activeConversation!, ...prev]);
|
||||
}
|
||||
|
||||
// Capture pending image before clearing state
|
||||
const imageFile = pendingImage;
|
||||
|
||||
const currMessages = messages.concat([{ text: query, speaker: "user" }]);
|
||||
setMessages(currMessages);
|
||||
setQuery("");
|
||||
setPendingImage(null);
|
||||
setIsLoading(true);
|
||||
|
||||
if (simbaMode) {
|
||||
@@ -155,6 +161,29 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
abortControllerRef.current = abortController;
|
||||
|
||||
try {
|
||||
// Upload image first if present
|
||||
let imageKey: string | undefined;
|
||||
if (imageFile) {
|
||||
const uploadResult = await conversationService.uploadImage(
|
||||
imageFile,
|
||||
activeConversation.id,
|
||||
);
|
||||
imageKey = uploadResult.image_key;
|
||||
|
||||
// Update the user message with the image key
|
||||
setMessages((prev) => {
|
||||
const updated = [...prev];
|
||||
// Find the last user message we just added
|
||||
for (let i = updated.length - 1; i >= 0; i--) {
|
||||
if (updated[i].speaker === "user") {
|
||||
updated[i] = { ...updated[i], image_key: imageKey };
|
||||
break;
|
||||
}
|
||||
}
|
||||
return updated;
|
||||
});
|
||||
}
|
||||
|
||||
await conversationService.streamQuery(
|
||||
query,
|
||||
activeConversation.id,
|
||||
@@ -170,6 +199,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
}
|
||||
},
|
||||
abortController.signal,
|
||||
imageKey,
|
||||
);
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
@@ -349,6 +379,9 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
handleQuestionSubmit={handleQuestionSubmit}
|
||||
setSimbaMode={setSimbaMode}
|
||||
isLoading={isLoading}
|
||||
pendingImage={pendingImage}
|
||||
onImageSelect={(file) => setPendingImage(file)}
|
||||
onClearImage={() => setPendingImage(null)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
@@ -375,7 +408,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
return <ToolBubble key={index} text={msg.text} />;
|
||||
if (msg.speaker === "simba")
|
||||
return <AnswerBubble key={index} text={msg.text} />;
|
||||
return <QuestionBubble key={index} text={msg.text} />;
|
||||
return <QuestionBubble key={index} text={msg.text} image_key={msg.image_key} />;
|
||||
})}
|
||||
|
||||
{isLoading && <AnswerBubble text="" loading={true} />}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useState } from "react";
|
||||
import { ArrowUp } from "lucide-react";
|
||||
import { useRef, useState } from "react";
|
||||
import { ArrowUp, ImagePlus, X } from "lucide-react";
|
||||
import { cn } from "../lib/utils";
|
||||
import { Textarea } from "./ui/textarea";
|
||||
|
||||
@@ -10,6 +10,9 @@ type MessageInputProps = {
|
||||
setSimbaMode: (val: boolean) => void;
|
||||
query: string;
|
||||
isLoading: boolean;
|
||||
pendingImage: File | null;
|
||||
onImageSelect: (file: File) => void;
|
||||
onClearImage: () => void;
|
||||
};
|
||||
|
||||
export const MessageInput = ({
|
||||
@@ -19,8 +22,12 @@ export const MessageInput = ({
|
||||
handleQuestionSubmit,
|
||||
setSimbaMode,
|
||||
isLoading,
|
||||
pendingImage,
|
||||
onImageSelect,
|
||||
onClearImage,
|
||||
}: MessageInputProps) => {
|
||||
const [simbaMode, setLocalSimbaMode] = useState(false);
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const toggleSimbaMode = () => {
|
||||
const next = !simbaMode;
|
||||
@@ -28,6 +35,17 @@ export const MessageInput = ({
|
||||
setSimbaMode(next);
|
||||
};
|
||||
|
||||
const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const file = e.target.files?.[0];
|
||||
if (file) {
|
||||
onImageSelect(file);
|
||||
}
|
||||
// Reset so the same file can be re-selected
|
||||
e.target.value = "";
|
||||
};
|
||||
|
||||
const canSend = !isLoading && (query.trim() || pendingImage);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
@@ -36,6 +54,26 @@ export const MessageInput = ({
|
||||
"focus-within:border-amber-soft/60",
|
||||
)}
|
||||
>
|
||||
{/* Image preview */}
|
||||
{pendingImage && (
|
||||
<div className="px-3 pt-3">
|
||||
<div className="relative inline-block">
|
||||
<img
|
||||
src={URL.createObjectURL(pendingImage)}
|
||||
alt="Pending upload"
|
||||
className="h-20 rounded-lg object-cover border border-sand"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
onClick={onClearImage}
|
||||
className="absolute -top-1.5 -right-1.5 w-5 h-5 rounded-full bg-charcoal text-white flex items-center justify-center hover:bg-charcoal/80 transition-colors cursor-pointer"
|
||||
>
|
||||
<X size={12} />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Textarea */}
|
||||
<Textarea
|
||||
onChange={handleQueryChange}
|
||||
@@ -46,8 +84,18 @@ export const MessageInput = ({
|
||||
className="min-h-[60px] max-h-40"
|
||||
/>
|
||||
|
||||
{/* Hidden file input */}
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
accept="image/*"
|
||||
onChange={handleFileChange}
|
||||
className="hidden"
|
||||
/>
|
||||
|
||||
{/* Bottom toolbar */}
|
||||
<div className="flex items-center justify-between px-3 pb-2.5 pt-1">
|
||||
<div className="flex items-center gap-3">
|
||||
{/* Simba mode toggle */}
|
||||
<button
|
||||
type="button"
|
||||
@@ -62,16 +110,32 @@ export const MessageInput = ({
|
||||
</span>
|
||||
</button>
|
||||
|
||||
{/* Image attach button */}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => fileInputRef.current?.click()}
|
||||
disabled={isLoading}
|
||||
className={cn(
|
||||
"w-7 h-7 rounded-lg flex items-center justify-center transition-all cursor-pointer",
|
||||
isLoading
|
||||
? "text-warm-gray/40 cursor-not-allowed"
|
||||
: "text-warm-gray hover:text-charcoal hover:bg-cream-dark",
|
||||
)}
|
||||
>
|
||||
<ImagePlus size={16} />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Send button */}
|
||||
<button
|
||||
type="submit"
|
||||
onClick={handleQuestionSubmit}
|
||||
disabled={isLoading || !query.trim()}
|
||||
disabled={!canSend}
|
||||
className={cn(
|
||||
"w-8 h-8 rounded-full flex items-center justify-center",
|
||||
"transition-all duration-200 cursor-pointer",
|
||||
"shadow-sm",
|
||||
isLoading || !query.trim()
|
||||
!canSend
|
||||
? "bg-sand text-warm-gray/50 cursor-not-allowed shadow-none"
|
||||
: "bg-amber-glow text-white hover:bg-amber-dark hover:shadow-md hover:shadow-amber-glow/30 active:scale-95",
|
||||
)}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import { cn } from "../lib/utils";
|
||||
import { conversationService } from "../api/conversationService";
|
||||
|
||||
type QuestionBubbleProps = {
|
||||
text: string;
|
||||
image_key?: string | null;
|
||||
};
|
||||
|
||||
export const QuestionBubble = ({ text }: QuestionBubbleProps) => {
|
||||
export const QuestionBubble = ({ text, image_key }: QuestionBubbleProps) => {
|
||||
return (
|
||||
<div className="flex justify-end message-enter">
|
||||
<div
|
||||
@@ -15,6 +17,13 @@ export const QuestionBubble = ({ text }: QuestionBubbleProps) => {
|
||||
"shadow-sm shadow-leaf/10",
|
||||
)}
|
||||
>
|
||||
{image_key && (
|
||||
<img
|
||||
src={conversationService.getImageUrl(image_key)}
|
||||
alt="Uploaded image"
|
||||
className="max-w-full rounded-xl mb-2"
|
||||
/>
|
||||
)}
|
||||
{text}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -76,6 +76,50 @@ def describe_simba_image(input):
|
||||
return result
|
||||
|
||||
|
||||
async def analyze_user_image(file_bytes: bytes) -> str:
|
||||
"""Analyze an image uploaded by a user and return a text description.
|
||||
|
||||
Uses llama-server (OpenAI-compatible API) with vision support.
|
||||
Falls back to OpenAI if llama-server is not configured.
|
||||
"""
|
||||
import base64
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
llama_url = os.getenv("LLAMA_SERVER_URL")
|
||||
if llama_url:
|
||||
aclient = AsyncOpenAI(base_url=llama_url, api_key="not-needed")
|
||||
model = os.getenv("LLAMA_MODEL_NAME", "llama-3.1-8b-instruct")
|
||||
else:
|
||||
aclient = AsyncOpenAI()
|
||||
model = "gpt-4o-mini"
|
||||
|
||||
b64 = base64.b64encode(file_bytes).decode("utf-8")
|
||||
|
||||
response = await aclient.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful image analyst. Describe what you see in the image in detail. Be thorough but concise.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Please describe this image in detail."},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{b64}",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
if args.filepath:
|
||||
|
||||
62
utils/image_upload.py
Normal file
62
utils/image_upload.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import io
|
||||
import logging
|
||||
|
||||
from PIL import Image
|
||||
from pillow_heif import register_heif_opener
|
||||
|
||||
register_heif_opener()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
ALLOWED_TYPES = {"image/jpeg", "image/png", "image/webp", "image/heic", "image/heif"}
|
||||
MAX_DIMENSION = 1920
|
||||
|
||||
|
||||
class ImageValidationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def process_image(file_bytes: bytes, content_type: str) -> tuple[bytes, str]:
|
||||
"""Validate, resize, and strip EXIF from an uploaded image.
|
||||
|
||||
Returns processed bytes and the output content type (always image/jpeg or image/png or image/webp).
|
||||
"""
|
||||
if content_type not in ALLOWED_TYPES:
|
||||
raise ImageValidationError(
|
||||
f"Unsupported image type: {content_type}. "
|
||||
f"Allowed: JPEG, PNG, WebP, HEIC"
|
||||
)
|
||||
|
||||
img = Image.open(io.BytesIO(file_bytes))
|
||||
|
||||
# Resize if too large
|
||||
width, height = img.size
|
||||
if max(width, height) > MAX_DIMENSION:
|
||||
ratio = MAX_DIMENSION / max(width, height)
|
||||
new_size = (int(width * ratio), int(height * ratio))
|
||||
img = img.resize(new_size, Image.LANCZOS)
|
||||
logging.info(
|
||||
f"Resized image from {width}x{height} to {new_size[0]}x{new_size[1]}"
|
||||
)
|
||||
|
||||
# Strip EXIF by copying pixel data to a new image
|
||||
clean_img = Image.new(img.mode, img.size)
|
||||
clean_img.putdata(list(img.getdata()))
|
||||
|
||||
# Convert HEIC/HEIF to JPEG; otherwise keep original format
|
||||
if content_type in {"image/heic", "image/heif"}:
|
||||
output_format = "JPEG"
|
||||
output_content_type = "image/jpeg"
|
||||
elif content_type == "image/png":
|
||||
output_format = "PNG"
|
||||
output_content_type = "image/png"
|
||||
elif content_type == "image/webp":
|
||||
output_format = "WEBP"
|
||||
output_content_type = "image/webp"
|
||||
else:
|
||||
output_format = "JPEG"
|
||||
output_content_type = "image/jpeg"
|
||||
|
||||
buf = io.BytesIO()
|
||||
clean_img.save(buf, format=output_format, quality=85)
|
||||
return buf.getvalue(), output_content_type
|
||||
53
utils/s3_client.py
Normal file
53
utils/s3_client.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
import aioboto3
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL")
|
||||
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID")
|
||||
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")
|
||||
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "asksimba-images")
|
||||
S3_REGION = os.getenv("S3_REGION", "garage")
|
||||
|
||||
session = aioboto3.Session()
|
||||
|
||||
|
||||
def _get_client():
|
||||
return session.client(
|
||||
"s3",
|
||||
endpoint_url=S3_ENDPOINT_URL,
|
||||
aws_access_key_id=S3_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
|
||||
region_name=S3_REGION,
|
||||
)
|
||||
|
||||
|
||||
async def upload_image(file_bytes: bytes, key: str, content_type: str) -> str:
|
||||
async with _get_client() as client:
|
||||
await client.put_object(
|
||||
Bucket=S3_BUCKET_NAME,
|
||||
Key=key,
|
||||
Body=file_bytes,
|
||||
ContentType=content_type,
|
||||
)
|
||||
logging.info(f"Uploaded image to S3: {key}")
|
||||
return key
|
||||
|
||||
|
||||
async def get_image(key: str) -> tuple[bytes, str]:
|
||||
async with _get_client() as client:
|
||||
response = await client.get_object(Bucket=S3_BUCKET_NAME, Key=key)
|
||||
body = await response["Body"].read()
|
||||
content_type = response.get("ContentType", "image/jpeg")
|
||||
return body, content_type
|
||||
|
||||
|
||||
async def delete_image(key: str) -> None:
|
||||
async with _get_client() as client:
|
||||
await client.delete_object(Bucket=S3_BUCKET_NAME, Key=key)
|
||||
logging.info(f"Deleted image from S3: {key}")
|
||||
@@ -82,7 +82,6 @@ class YNABService:
|
||||
end_date: Optional[str] = None,
|
||||
category_name: Optional[str] = None,
|
||||
payee_name: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
) -> dict[str, Any]:
|
||||
"""Get transactions filtered by date range, category, or payee.
|
||||
|
||||
@@ -91,7 +90,6 @@ class YNABService:
|
||||
end_date: End date in YYYY-MM-DD format (defaults to today)
|
||||
category_name: Filter by category name (case-insensitive partial match)
|
||||
payee_name: Filter by payee name (case-insensitive partial match)
|
||||
limit: Maximum number of transactions to return (default 50)
|
||||
|
||||
Returns:
|
||||
Dictionary containing matching transactions and summary statistics.
|
||||
@@ -145,9 +143,8 @@ class YNABService:
|
||||
)
|
||||
total_amount += amount
|
||||
|
||||
# Sort by date (most recent first) and limit
|
||||
# Sort by date (most recent first)
|
||||
filtered_transactions.sort(key=lambda x: x["date"], reverse=True)
|
||||
filtered_transactions = filtered_transactions[:limit]
|
||||
|
||||
return {
|
||||
"transactions": filtered_transactions,
|
||||
|
||||
Reference in New Issue
Block a user