Add image upload and vision analysis to Ask Simba chat
Users can now attach images in the web chat for Simba to analyze using Ollama's gemma3 vision model. Images are stored in Garage (S3-compatible) and displayed in chat history. Also fixes aerich migration config by extracting TORTOISE_CONFIG into a standalone config/db.py module, removing the stale aerich_config.py, and adding missing MODELS_STATE to migration 3. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
20
app.py
20
app.py
@@ -13,6 +13,7 @@ import blueprints.users
|
||||
import blueprints.whatsapp
|
||||
import blueprints.email
|
||||
import blueprints.users.models
|
||||
from config.db import TORTOISE_CONFIG
|
||||
from main import consult_simba_oracle
|
||||
|
||||
# Load environment variables
|
||||
@@ -28,6 +29,7 @@ app = Quart(
|
||||
)
|
||||
|
||||
app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY", "SECRET_KEY")
|
||||
app.config["MAX_CONTENT_LENGTH"] = 10 * 1024 * 1024 # 10 MB upload limit
|
||||
jwt = JWTManager(app)
|
||||
|
||||
# Register blueprints
|
||||
@@ -38,24 +40,6 @@ app.register_blueprint(blueprints.whatsapp.whatsapp_blueprint)
|
||||
app.register_blueprint(blueprints.email.email_blueprint)
|
||||
|
||||
|
||||
# Database configuration with environment variable support
|
||||
DATABASE_URL = os.getenv(
|
||||
"DATABASE_URL", "postgres://raggr:raggr_dev_password@localhost:5432/raggr"
|
||||
)
|
||||
|
||||
TORTOISE_CONFIG = {
|
||||
"connections": {"default": DATABASE_URL},
|
||||
"apps": {
|
||||
"models": {
|
||||
"models": [
|
||||
"blueprints.conversation.models",
|
||||
"blueprints.users.models",
|
||||
"aerich.models",
|
||||
]
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Initialize Tortoise ORM with lifecycle hooks
|
||||
@app.while_serving
|
||||
async def lifespan():
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from quart import Blueprint, jsonify, make_response, request
|
||||
from quart import Blueprint, Response, jsonify, make_response, request
|
||||
from quart_jwt_extended import (
|
||||
get_jwt_identity,
|
||||
jwt_refresh_token_required,
|
||||
)
|
||||
|
||||
import blueprints.users.models
|
||||
from utils.image_process import analyze_user_image
|
||||
from utils.image_upload import ImageValidationError, process_image
|
||||
from utils.s3_client import get_image as s3_get_image
|
||||
from utils.s3_client import upload_image as s3_upload_image
|
||||
|
||||
from .agents import main_agent
|
||||
from .logic import (
|
||||
@@ -29,7 +35,9 @@ conversation_blueprint = Blueprint(
|
||||
_SYSTEM_PROMPT = SIMBA_SYSTEM_PROMPT
|
||||
|
||||
|
||||
def _build_messages_payload(conversation, query_text: str) -> list:
|
||||
def _build_messages_payload(
|
||||
conversation, query_text: str, image_description: str | None = None
|
||||
) -> list:
|
||||
recent_messages = (
|
||||
conversation.messages[-10:]
|
||||
if len(conversation.messages) > 10
|
||||
@@ -38,8 +46,19 @@ def _build_messages_payload(conversation, query_text: str) -> list:
|
||||
messages_payload = [{"role": "system", "content": _SYSTEM_PROMPT}]
|
||||
for msg in recent_messages[:-1]: # Exclude the message we just added
|
||||
role = "user" if msg.speaker == "user" else "assistant"
|
||||
messages_payload.append({"role": role, "content": msg.text})
|
||||
messages_payload.append({"role": "user", "content": query_text})
|
||||
text = msg.text
|
||||
if msg.image_key and role == "user":
|
||||
text = f"[User sent an image]\n{text}"
|
||||
messages_payload.append({"role": role, "content": text})
|
||||
|
||||
# Build the current user message with optional image description
|
||||
if image_description:
|
||||
content = f"[Image analysis: {image_description}]"
|
||||
if query_text:
|
||||
content = f"{query_text}\n\n{content}"
|
||||
else:
|
||||
content = query_text
|
||||
messages_payload.append({"role": "user", "content": content})
|
||||
return messages_payload
|
||||
|
||||
|
||||
@@ -74,6 +93,58 @@ async def query():
|
||||
return jsonify({"response": message})
|
||||
|
||||
|
||||
@conversation_blueprint.post("/upload-image")
|
||||
@jwt_refresh_token_required
|
||||
async def upload_image():
|
||||
current_user_uuid = get_jwt_identity()
|
||||
await blueprints.users.models.User.get(id=current_user_uuid)
|
||||
|
||||
files = await request.files
|
||||
form = await request.form
|
||||
file = files.get("file")
|
||||
conversation_id = form.get("conversation_id")
|
||||
|
||||
if not file or not conversation_id:
|
||||
return jsonify({"error": "file and conversation_id are required"}), 400
|
||||
|
||||
file_bytes = file.read()
|
||||
content_type = file.content_type or "image/jpeg"
|
||||
|
||||
try:
|
||||
processed_bytes, output_content_type = process_image(file_bytes, content_type)
|
||||
except ImageValidationError as e:
|
||||
return jsonify({"error": str(e)}), 400
|
||||
|
||||
ext = output_content_type.split("/")[-1]
|
||||
if ext == "jpeg":
|
||||
ext = "jpg"
|
||||
key = f"conversations/{conversation_id}/{uuid.uuid4()}.{ext}"
|
||||
|
||||
await s3_upload_image(processed_bytes, key, output_content_type)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"image_key": key,
|
||||
"image_url": f"/api/conversation/image/{key}",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@conversation_blueprint.get("/image/<path:image_key>")
|
||||
@jwt_refresh_token_required
|
||||
async def serve_image(image_key: str):
|
||||
try:
|
||||
image_bytes, content_type = await s3_get_image(image_key)
|
||||
except Exception:
|
||||
return jsonify({"error": "Image not found"}), 404
|
||||
|
||||
return Response(
|
||||
image_bytes,
|
||||
content_type=content_type,
|
||||
headers={"Cache-Control": "private, max-age=3600"},
|
||||
)
|
||||
|
||||
|
||||
@conversation_blueprint.post("/stream-query")
|
||||
@jwt_refresh_token_required
|
||||
async def stream_query():
|
||||
@@ -82,16 +153,31 @@ async def stream_query():
|
||||
data = await request.get_json()
|
||||
query_text = data.get("query")
|
||||
conversation_id = data.get("conversation_id")
|
||||
image_key = data.get("image_key")
|
||||
conversation = await get_conversation_by_id(conversation_id)
|
||||
await conversation.fetch_related("messages")
|
||||
await add_message_to_conversation(
|
||||
conversation=conversation,
|
||||
message=query_text,
|
||||
message=query_text or "",
|
||||
speaker="user",
|
||||
user=user,
|
||||
image_key=image_key,
|
||||
)
|
||||
|
||||
messages_payload = _build_messages_payload(conversation, query_text)
|
||||
# If an image was uploaded, analyze it with the vision model
|
||||
image_description = None
|
||||
if image_key:
|
||||
try:
|
||||
image_bytes, _ = await s3_get_image(image_key)
|
||||
image_description = await analyze_user_image(image_bytes)
|
||||
logging.info(f"Image analysis complete for {image_key}")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to analyze image: {e}")
|
||||
image_description = "[Image could not be analyzed]"
|
||||
|
||||
messages_payload = _build_messages_payload(
|
||||
conversation, query_text or "", image_description
|
||||
)
|
||||
payload = {"messages": messages_payload}
|
||||
|
||||
async def event_generator():
|
||||
@@ -160,6 +246,7 @@ async def get_conversation(conversation_id: str):
|
||||
"text": msg.text,
|
||||
"speaker": msg.speaker.value,
|
||||
"created_at": msg.created_at.isoformat(),
|
||||
"image_key": msg.image_key,
|
||||
}
|
||||
)
|
||||
name = conversation.name
|
||||
|
||||
@@ -16,12 +16,14 @@ async def add_message_to_conversation(
|
||||
message: str,
|
||||
speaker: str,
|
||||
user: blueprints.users.models.User,
|
||||
image_key: str | None = None,
|
||||
) -> ConversationMessage:
|
||||
print(conversation, message, speaker)
|
||||
message = await ConversationMessage.create(
|
||||
text=message,
|
||||
speaker=speaker,
|
||||
conversation=conversation,
|
||||
image_key=image_key,
|
||||
)
|
||||
|
||||
return message
|
||||
|
||||
@@ -41,6 +41,7 @@ class ConversationMessage(Model):
|
||||
)
|
||||
created_at = fields.DatetimeField(auto_now_add=True)
|
||||
speaker = fields.CharEnumField(enum_type=Speaker, max_length=10)
|
||||
image_key = fields.CharField(max_length=512, null=True, default=None)
|
||||
|
||||
class Meta:
|
||||
table = "conversation_messages"
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Database configuration with environment variable support
|
||||
# Use DATABASE_PATH for relative paths or DATABASE_URL for full connection strings
|
||||
DATABASE_PATH = os.getenv("DATABASE_PATH", "database/raggr.db")
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", f"sqlite://{DATABASE_PATH}")
|
||||
DATABASE_URL = os.getenv(
|
||||
"DATABASE_URL", "postgres://raggr:raggr_dev_password@localhost:5432/raggr"
|
||||
)
|
||||
|
||||
TORTOISE_ORM = {
|
||||
TORTOISE_CONFIG = {
|
||||
"connections": {"default": DATABASE_URL},
|
||||
"apps": {
|
||||
"models": {
|
||||
@@ -55,6 +55,12 @@ services:
|
||||
- OBSIDIAN_DEVICE_NAME=${OBSIDIAN_DEVICE_NAME}
|
||||
- OBSIDIAN_CONTINUOUS_SYNC=${OBSIDIAN_CONTINUOUS_SYNC:-false}
|
||||
- OBSIDIAN_VAULT_PATH=${OBSIDIAN_VAULT_PATH:-/app/data/obsidian}
|
||||
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL}
|
||||
- S3_ACCESS_KEY_ID=${S3_ACCESS_KEY_ID}
|
||||
- S3_SECRET_ACCESS_KEY=${S3_SECRET_ACCESS_KEY}
|
||||
- S3_BUCKET_NAME=${S3_BUCKET_NAME:-asksimba-images}
|
||||
- S3_REGION=${S3_REGION:-garage}
|
||||
- OLLAMA_HOST=${OLLAMA_HOST:-http://localhost:11434}
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
|
||||
@@ -15,3 +15,32 @@ async def downgrade(db: BaseDBAsyncClient) -> str:
|
||||
DROP INDEX IF EXISTS "idx_users_email_h_a1b2c3";
|
||||
ALTER TABLE "users" DROP COLUMN "email_hmac_token";
|
||||
ALTER TABLE "users" DROP COLUMN "email_enabled";"""
|
||||
|
||||
|
||||
MODELS_STATE = (
|
||||
"eJztmm1v4jgQx78Kyquu1KtaKN1VdTopUHrLbYEThX3qVZFJXMg1sbOJsxRV/e5nm4Q4jg"
|
||||
"OEAoU93rRl7CH2z2PP35M+ay62oBOc1DH6Cf0AEBsj7bL0rCHgQvqHsv24pAHPS1qZgYCB"
|
||||
"wx1MoSdvAYOA+MAktPEBOAGkJgsGpm970cNQ6DjMiE3a0UbDxBQi+0cIDYKHkIygTxvu7q"
|
||||
"nZRhZ8gkH80Xs0HmzoWKlx2xZ7NrcbZOJxW7/fvLrmPdnjBoaJndBFSW9vQkYYzbqHoW2d"
|
||||
"MB/WNoQI+oBAS5gGG2U07dg0HTE1ED+Es6FaicGCDyB0GAzt94cQmYxBiT+J/Tj/QyuAh6"
|
||||
"JmaG1EGIvnl+mskjlzq8YeVf+od48qF+/4LHFAhj5v5ES0F+4ICJi6cq4JSP47g7I+Ar4a"
|
||||
"ZdxfgkkHugrG2JBwTGIoBhkDWo2a5oInw4FoSEb0Y7lanYPxs97lJGkvjhLTuJ5GfTtqKk"
|
||||
"/bGNIEoelDNmUDkCzIK9pCbBeqYaY9JaRW5HoS/7GjgOkcrA5yJtEmmMO312w1bnt66282"
|
||||
"EzcIfjgckd5rsJYyt04k69GFtBSzLyl9afY+ltjH0vdOuyHH/qxf77vGxgRCgg2ExwawhP"
|
||||
"0aW2MwqYUNPWvFhU17Hhb2TRc2GrywrgH0jWIZRHB5RRqJxrbFRVw9abDU+/CozBkMRhbe"
|
||||
"NfahPUSf4IQjbNJxAGSqkkUkOvrR1+wqtMSajMIH45kaEYOCzo7OCZJp9tRv6/pVQ+MMB8"
|
||||
"B8HAPfMnJgujAIwBAGWaC1yPP6Uxc6M2mmZikKuNb0G3fzVMljy1nhMhYYpehlm9yyK1sA"
|
||||
"ovO2omezJ82hs0AFCxCXE8OGuJAHUbzXopjAJ0XK71GrGmXcf19E8bxU3vjaS2XxWPoetf"
|
||||
"Sv71KZ/KbT/jPuLkjl+k2ndlDIv6KQyirkwIPgUSUG2AWygUI3IwVSqyu4v/HW0fq3je5l"
|
||||
"iWX0f9Bts1XTL0uB7Q6AttwSp26ZZ6dLXDLPTnPvmKxJ2kBioil2zCtc13nm76mENaWC1y"
|
||||
"ulrFw/21mKCzWtIlyKattNKjl+Z1BIt/guka/V2NY+aLP912ZsHYsWLUWffdFoWyhceiAI"
|
||||
"xthXRGbNRsCfqGGKXhLMwYRM7z+7eqVXwasxvSrKLYqs1mzr3W9qyRv3F+O29q3X0CW60A"
|
||||
"W2UyRKZw7rCdHFO36dAXp2upzomad6MrJnPAIkoEe6QZXkIE9mqmEqXFfCKofqdqlWloFa"
|
||||
"yWdaySDlQWZAxKan2vgYOxCgOQEq+srbnzpv6jAtmqoL7P9O5ya1/2tN+Urbb9UaNHg5Zt"
|
||||
"rJnkqhZrunhDtygUk1wiNUKMsFu1/y3cOIPbtY5hiQr6zCKXAhRyy2LdMIwsG/0FSUD/KB"
|
||||
"yn57CHMjWZ9e6EeG5+OftlXsSM04bk9KaQ42gfMKLZrmWl3mWK3mH6vVzLHqWMAzhj4OPU"
|
||||
"Uh/6/bTluNVHKTgPYRneWdZZvkuOTYAbnfGN67+83ofDbz+dVEuXAoCSv2BYdq4v+kmnh4"
|
||||
"3/5LLOzsdV6mKrToXWjmn8vW80J0l2+k230RqkPfNkeaooAWtRzPK6GBpM/O1NCaKOednL"
|
||||
"KExjBLwRCt/JvepPnr6N/KZ+fvzz9ULs4/0C58JDPL+zmHQXwNyS+ZsY2grHPnaz3B5VAw"
|
||||
"S6Qz3RpFBPO0+34C3EhBhz6RQKRI7/kSWXB5K3m8sdLj2uRxgWy7/vTy8h9Mf/k3"
|
||||
)
|
||||
|
||||
43
migrations/models/4_20260404080201_add_image_key.py
Normal file
43
migrations/models/4_20260404080201_add_image_key.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from tortoise import BaseDBAsyncClient
|
||||
|
||||
RUN_IN_TRANSACTION = True
|
||||
|
||||
|
||||
async def upgrade(db: BaseDBAsyncClient) -> str:
|
||||
return """
|
||||
ALTER TABLE "conversation_messages" ADD "image_key" VARCHAR(512);"""
|
||||
|
||||
|
||||
async def downgrade(db: BaseDBAsyncClient) -> str:
|
||||
return """
|
||||
ALTER TABLE "conversation_messages" DROP COLUMN "image_key";"""
|
||||
|
||||
|
||||
MODELS_STATE = (
|
||||
"eJztmmtv4jgUhv8KyqeO1K0KvcyoWq0UWrrDToFVC3PrVpFJXPCS2JnEGYqq/ve1TUIcx6"
|
||||
"GkBQqzfGnLsQ+xH1/Oe076aHjEgW54cE7wTxiEgCKCjbPKo4GBB9kf2vb9igF8P23lBgr6"
|
||||
"rnCwpZ6iBfRDGgCbssZ74IaQmRwY2gHy44fhyHW5kdisI8KD1BRh9COCFiUDSIcwYA23d8"
|
||||
"yMsAMfYJh89EfWPYKukxk3cvizhd2iE1/Yer3mxaXoyR/Xt2ziRh5Oe/sTOiR41j2KkHPA"
|
||||
"fXjbAGIYAAodaRp8lPG0E9N0xMxAgwjOhuqkBgfeg8jlMIzf7yNscwYV8ST+4/gPowQehp"
|
||||
"qjRZhyFo9P01mlcxZWgz/q/KN5vXd0+k7MkoR0EIhGQcR4Eo6Agqmr4JqCFL9zKM+HINCj"
|
||||
"TPorMNlAX4IxMaQc0z2UgEwAvYya4YEHy4V4QIfsY+3kZA7Gz+a1IMl6CZSE7evprm/HTb"
|
||||
"VpG0eaIrQDyKdsAZoHecFaKPKgHmbWU0HqxK4HyR8bCpjNwelgdxIfgjl8u81W46Zrtv7m"
|
||||
"M/HC8IcrEJndBm+pCetEse6dKksx+5LKl2b3Y4V/rHzvtBvq3p/16343+JhARImFydgCjn"
|
||||
"ReE2sCJrOwke+8cGGznruFfdOFjQcvrWsIA6tcBJFcXhFG4rGtcRFfHjR46L0faWMGh5GH"
|
||||
"d0kCiAb4E5wIhE02DoBtXbCIRUcv/ppNhZZa01EEYDxTI/KmYLNjc4J0Gj3Nm3PzomEIhn"
|
||||
"1gj8YgcKwCmB4MQzCAYR5oPfa8/HQN3Zk007OUBVxr+o2beasUsRWsSI1IjDL08k1ezVMt"
|
||||
"ALN5O/Gz+ZPm0HlGBUsQFxPDlryQO1G81aKYwgdNyO8yqx5l0n9bRPG8UN742s1E8UT67r"
|
||||
"XMr+8ykfyq0/4z6S5J5fOrTn2nkH9FIZVXyKEPwUgnBngC2cCRl5MCmdWV3N/46Bi9m8b1"
|
||||
"WYVH9H/wTbNVN88qIfL6wFhsiTNZZvVwgSSzeliYY/Km7AFCHoss1ghOyqTqGacX8V2/9M"
|
||||
"qCPKnWFiDJehWiFG3KZSQH7XIhU+O6zPi5pemArRQPX5kWqLXIjaX4bH6g2S5l84RVqmKR"
|
||||
"f2lkcJKXFetefk3udO7261y+jmULwLLPtujdNRSBfRCGYxJodmYdYRBM9DBlLwVmf0Knue"
|
||||
"TGxeg58Opc+8vSlSGrN9vm9Td9+pD0l/dt/Vu3YSp0oQeQW2aXzhyWs0WfP/HL3KDVw8UE"
|
||||
"5DwFmZOQ4yGgIbvSLabK+0WSXQ9T47oUObleqkeLQD0qZnqUQyo2mQUxn57u4BPiQoDnbF"
|
||||
"DZVz3+zHlVl2nZUF3i/Hc6V5nzX2+q5YFeq95gm1dgZp3QVAo1210t3KEHbKYRRlCjLJ85"
|
||||
"/YrvFu7Y6uki14Ca/ku3wKm6YwlybCuM+v9CW1OKKQaq+m0hzJVEfRDRoeUH5Cdyyl2pOc"
|
||||
"f1SSnDJTZwX6FFlRx9kWv1pPhaPcldq64DfGsQkMjXvBT566bT1iNV3BSgPcxmeesgm+5X"
|
||||
"XBTSu5Xhvb1bjc7nM59fmVWLsIqw4l+wq8z+Tyqzu/9d+CUWdvZqNFcVeu69cu4f9Zbzcn"
|
||||
"mTM9L1vlQ2YYDsoaEpoMUt+/NKaCDtszE1tCYueL+pLaFxzMpmiFf+TTNp8Wr/t1r1+P3x"
|
||||
"h6PT4w+sixjJzPJ+zmWQpCHFJTN+ELR17mKtJ7nsCmapdGZHo4xgnnbfToArKeiwJ1KINe"
|
||||
"G9WCJLLm8lj1dWelyaPC4RbZcfXp7+AzcBYwM="
|
||||
)
|
||||
@@ -37,9 +37,10 @@ dependencies = [
|
||||
"ynab>=1.3.0",
|
||||
"ollama>=0.6.1",
|
||||
"twilio>=9.10.2",
|
||||
"aioboto3>=13.0.0",
|
||||
]
|
||||
|
||||
[tool.aerich]
|
||||
tortoise_orm = "app.TORTOISE_CONFIG"
|
||||
tortoise_orm = "config.db.TORTOISE_CONFIG"
|
||||
location = "./migrations"
|
||||
src_folder = "./."
|
||||
|
||||
@@ -13,6 +13,7 @@ interface Message {
|
||||
text: string;
|
||||
speaker: "user" | "simba";
|
||||
created_at: string;
|
||||
image_key?: string | null;
|
||||
}
|
||||
|
||||
interface Conversation {
|
||||
@@ -121,17 +122,52 @@ class ConversationService {
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
async uploadImage(
|
||||
file: File,
|
||||
conversationId: string,
|
||||
): Promise<{ image_key: string; image_url: string }> {
|
||||
const formData = new FormData();
|
||||
formData.append("file", file);
|
||||
formData.append("conversation_id", conversationId);
|
||||
|
||||
const response = await userService.fetchWithRefreshToken(
|
||||
`${this.conversationBaseUrl}/upload-image`,
|
||||
{
|
||||
method: "POST",
|
||||
body: formData,
|
||||
},
|
||||
{ skipContentType: true },
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.error || "Failed to upload image");
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
getImageUrl(imageKey: string): string {
|
||||
return `/api/conversation/image/${imageKey}`;
|
||||
}
|
||||
|
||||
async streamQuery(
|
||||
query: string,
|
||||
conversation_id: string,
|
||||
onEvent: SSEEventCallback,
|
||||
signal?: AbortSignal,
|
||||
imageKey?: string,
|
||||
): Promise<void> {
|
||||
const body: Record<string, string> = { query, conversation_id };
|
||||
if (imageKey) {
|
||||
body.image_key = imageKey;
|
||||
}
|
||||
|
||||
const response = await userService.fetchWithRefreshToken(
|
||||
`${this.conversationBaseUrl}/stream-query`,
|
||||
{
|
||||
method: "POST",
|
||||
body: JSON.stringify({ query, conversation_id }),
|
||||
body: JSON.stringify(body),
|
||||
signal,
|
||||
},
|
||||
);
|
||||
|
||||
@@ -106,14 +106,15 @@ class UserService {
|
||||
async fetchWithRefreshToken(
|
||||
url: string,
|
||||
options: RequestInit = {},
|
||||
{ skipContentType = false }: { skipContentType?: boolean } = {},
|
||||
): Promise<Response> {
|
||||
const refreshToken = localStorage.getItem("refresh_token");
|
||||
|
||||
// Add authorization header
|
||||
const headers = {
|
||||
"Content-Type": "application/json",
|
||||
...(options.headers || {}),
|
||||
...(refreshToken && { Authorization: `Bearer ${refreshToken}` }),
|
||||
const headers: Record<string, string> = {
|
||||
...(skipContentType ? {} : { "Content-Type": "application/json" }),
|
||||
...((options.headers as Record<string, string>) || {}),
|
||||
...(refreshToken ? { Authorization: `Bearer ${refreshToken}` } : {}),
|
||||
};
|
||||
|
||||
let response = await fetch(url, { ...options, headers });
|
||||
|
||||
@@ -14,6 +14,7 @@ import catIcon from "../assets/cat.png";
|
||||
type Message = {
|
||||
text: string;
|
||||
speaker: "simba" | "user" | "tool";
|
||||
image_key?: string | null;
|
||||
};
|
||||
|
||||
type Conversation = {
|
||||
@@ -55,6 +56,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
const [isLoading, setIsLoading] = useState<boolean>(false);
|
||||
const [isAdmin, setIsAdmin] = useState<boolean>(false);
|
||||
const [showAdminPanel, setShowAdminPanel] = useState<boolean>(false);
|
||||
const [pendingImage, setPendingImage] = useState<File | null>(null);
|
||||
|
||||
const messagesEndRef = useRef<HTMLDivElement>(null);
|
||||
const isMountedRef = useRef<boolean>(true);
|
||||
@@ -80,7 +82,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
try {
|
||||
const fetched = await conversationService.getConversation(conversation.id);
|
||||
setMessages(
|
||||
fetched.messages.map((m) => ({ text: m.text, speaker: m.speaker })),
|
||||
fetched.messages.map((m) => ({ text: m.text, speaker: m.speaker, image_key: m.image_key })),
|
||||
);
|
||||
} catch (err) {
|
||||
console.error("Failed to load messages:", err);
|
||||
@@ -120,7 +122,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
try {
|
||||
const conv = await conversationService.getConversation(selectedConversation.id);
|
||||
setSelectedConversation({ id: conv.id, title: conv.name });
|
||||
setMessages(conv.messages.map((m) => ({ text: m.text, speaker: m.speaker })));
|
||||
setMessages(conv.messages.map((m) => ({ text: m.text, speaker: m.speaker, image_key: m.image_key })));
|
||||
} catch (err) {
|
||||
console.error("Failed to load messages:", err);
|
||||
}
|
||||
@@ -129,7 +131,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
}, [selectedConversation?.id]);
|
||||
|
||||
const handleQuestionSubmit = async () => {
|
||||
if (!query.trim() || isLoading) return;
|
||||
if ((!query.trim() && !pendingImage) || isLoading) return;
|
||||
|
||||
let activeConversation = selectedConversation;
|
||||
if (!activeConversation) {
|
||||
@@ -139,9 +141,13 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
setConversations((prev) => [activeConversation!, ...prev]);
|
||||
}
|
||||
|
||||
// Capture pending image before clearing state
|
||||
const imageFile = pendingImage;
|
||||
|
||||
const currMessages = messages.concat([{ text: query, speaker: "user" }]);
|
||||
setMessages(currMessages);
|
||||
setQuery("");
|
||||
setPendingImage(null);
|
||||
setIsLoading(true);
|
||||
|
||||
if (simbaMode) {
|
||||
@@ -155,6 +161,29 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
abortControllerRef.current = abortController;
|
||||
|
||||
try {
|
||||
// Upload image first if present
|
||||
let imageKey: string | undefined;
|
||||
if (imageFile) {
|
||||
const uploadResult = await conversationService.uploadImage(
|
||||
imageFile,
|
||||
activeConversation.id,
|
||||
);
|
||||
imageKey = uploadResult.image_key;
|
||||
|
||||
// Update the user message with the image key
|
||||
setMessages((prev) => {
|
||||
const updated = [...prev];
|
||||
// Find the last user message we just added
|
||||
for (let i = updated.length - 1; i >= 0; i--) {
|
||||
if (updated[i].speaker === "user") {
|
||||
updated[i] = { ...updated[i], image_key: imageKey };
|
||||
break;
|
||||
}
|
||||
}
|
||||
return updated;
|
||||
});
|
||||
}
|
||||
|
||||
await conversationService.streamQuery(
|
||||
query,
|
||||
activeConversation.id,
|
||||
@@ -170,6 +199,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
}
|
||||
},
|
||||
abortController.signal,
|
||||
imageKey,
|
||||
);
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
@@ -349,6 +379,9 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
handleQuestionSubmit={handleQuestionSubmit}
|
||||
setSimbaMode={setSimbaMode}
|
||||
isLoading={isLoading}
|
||||
pendingImage={pendingImage}
|
||||
onImageSelect={(file) => setPendingImage(file)}
|
||||
onClearImage={() => setPendingImage(null)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
@@ -375,7 +408,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
return <ToolBubble key={index} text={msg.text} />;
|
||||
if (msg.speaker === "simba")
|
||||
return <AnswerBubble key={index} text={msg.text} />;
|
||||
return <QuestionBubble key={index} text={msg.text} />;
|
||||
return <QuestionBubble key={index} text={msg.text} image_key={msg.image_key} />;
|
||||
})}
|
||||
|
||||
{isLoading && <AnswerBubble text="" loading={true} />}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useState } from "react";
|
||||
import { ArrowUp } from "lucide-react";
|
||||
import { useRef, useState } from "react";
|
||||
import { ArrowUp, ImagePlus, X } from "lucide-react";
|
||||
import { cn } from "../lib/utils";
|
||||
import { Textarea } from "./ui/textarea";
|
||||
|
||||
@@ -10,6 +10,9 @@ type MessageInputProps = {
|
||||
setSimbaMode: (val: boolean) => void;
|
||||
query: string;
|
||||
isLoading: boolean;
|
||||
pendingImage: File | null;
|
||||
onImageSelect: (file: File) => void;
|
||||
onClearImage: () => void;
|
||||
};
|
||||
|
||||
export const MessageInput = ({
|
||||
@@ -19,8 +22,12 @@ export const MessageInput = ({
|
||||
handleQuestionSubmit,
|
||||
setSimbaMode,
|
||||
isLoading,
|
||||
pendingImage,
|
||||
onImageSelect,
|
||||
onClearImage,
|
||||
}: MessageInputProps) => {
|
||||
const [simbaMode, setLocalSimbaMode] = useState(false);
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const toggleSimbaMode = () => {
|
||||
const next = !simbaMode;
|
||||
@@ -28,6 +35,17 @@ export const MessageInput = ({
|
||||
setSimbaMode(next);
|
||||
};
|
||||
|
||||
const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const file = e.target.files?.[0];
|
||||
if (file) {
|
||||
onImageSelect(file);
|
||||
}
|
||||
// Reset so the same file can be re-selected
|
||||
e.target.value = "";
|
||||
};
|
||||
|
||||
const canSend = !isLoading && (query.trim() || pendingImage);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
@@ -36,6 +54,26 @@ export const MessageInput = ({
|
||||
"focus-within:border-amber-soft/60",
|
||||
)}
|
||||
>
|
||||
{/* Image preview */}
|
||||
{pendingImage && (
|
||||
<div className="px-3 pt-3">
|
||||
<div className="relative inline-block">
|
||||
<img
|
||||
src={URL.createObjectURL(pendingImage)}
|
||||
alt="Pending upload"
|
||||
className="h-20 rounded-lg object-cover border border-sand"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
onClick={onClearImage}
|
||||
className="absolute -top-1.5 -right-1.5 w-5 h-5 rounded-full bg-charcoal text-white flex items-center justify-center hover:bg-charcoal/80 transition-colors cursor-pointer"
|
||||
>
|
||||
<X size={12} />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Textarea */}
|
||||
<Textarea
|
||||
onChange={handleQueryChange}
|
||||
@@ -46,32 +84,58 @@ export const MessageInput = ({
|
||||
className="min-h-[60px] max-h-40"
|
||||
/>
|
||||
|
||||
{/* Hidden file input */}
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
accept="image/*"
|
||||
onChange={handleFileChange}
|
||||
className="hidden"
|
||||
/>
|
||||
|
||||
{/* Bottom toolbar */}
|
||||
<div className="flex items-center justify-between px-3 pb-2.5 pt-1">
|
||||
{/* Simba mode toggle */}
|
||||
<button
|
||||
type="button"
|
||||
onClick={toggleSimbaMode}
|
||||
className="flex items-center gap-2 group cursor-pointer select-none"
|
||||
>
|
||||
<div className={cn("toggle-track", simbaMode && "checked")}>
|
||||
<div className="toggle-thumb" />
|
||||
</div>
|
||||
<span className="text-xs text-warm-gray group-hover:text-charcoal transition-colors">
|
||||
simba mode
|
||||
</span>
|
||||
</button>
|
||||
<div className="flex items-center gap-3">
|
||||
{/* Simba mode toggle */}
|
||||
<button
|
||||
type="button"
|
||||
onClick={toggleSimbaMode}
|
||||
className="flex items-center gap-2 group cursor-pointer select-none"
|
||||
>
|
||||
<div className={cn("toggle-track", simbaMode && "checked")}>
|
||||
<div className="toggle-thumb" />
|
||||
</div>
|
||||
<span className="text-xs text-warm-gray group-hover:text-charcoal transition-colors">
|
||||
simba mode
|
||||
</span>
|
||||
</button>
|
||||
|
||||
{/* Image attach button */}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => fileInputRef.current?.click()}
|
||||
disabled={isLoading}
|
||||
className={cn(
|
||||
"w-7 h-7 rounded-lg flex items-center justify-center transition-all cursor-pointer",
|
||||
isLoading
|
||||
? "text-warm-gray/40 cursor-not-allowed"
|
||||
: "text-warm-gray hover:text-charcoal hover:bg-cream-dark",
|
||||
)}
|
||||
>
|
||||
<ImagePlus size={16} />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Send button */}
|
||||
<button
|
||||
type="submit"
|
||||
onClick={handleQuestionSubmit}
|
||||
disabled={isLoading || !query.trim()}
|
||||
disabled={!canSend}
|
||||
className={cn(
|
||||
"w-8 h-8 rounded-full flex items-center justify-center",
|
||||
"transition-all duration-200 cursor-pointer",
|
||||
"shadow-sm",
|
||||
isLoading || !query.trim()
|
||||
!canSend
|
||||
? "bg-sand text-warm-gray/50 cursor-not-allowed shadow-none"
|
||||
: "bg-amber-glow text-white hover:bg-amber-dark hover:shadow-md hover:shadow-amber-glow/30 active:scale-95",
|
||||
)}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import { cn } from "../lib/utils";
|
||||
import { conversationService } from "../api/conversationService";
|
||||
|
||||
type QuestionBubbleProps = {
|
||||
text: string;
|
||||
image_key?: string | null;
|
||||
};
|
||||
|
||||
export const QuestionBubble = ({ text }: QuestionBubbleProps) => {
|
||||
export const QuestionBubble = ({ text, image_key }: QuestionBubbleProps) => {
|
||||
return (
|
||||
<div className="flex justify-end message-enter">
|
||||
<div
|
||||
@@ -15,6 +17,13 @@ export const QuestionBubble = ({ text }: QuestionBubbleProps) => {
|
||||
"shadow-sm shadow-leaf/10",
|
||||
)}
|
||||
>
|
||||
{image_key && (
|
||||
<img
|
||||
src={conversationService.getImageUrl(image_key)}
|
||||
alt="Uploaded image"
|
||||
className="max-w-full rounded-xl mb-2"
|
||||
/>
|
||||
)}
|
||||
{text}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -76,6 +76,39 @@ def describe_simba_image(input):
|
||||
return result
|
||||
|
||||
|
||||
async def analyze_user_image(file_bytes: bytes) -> str:
|
||||
"""Analyze an image uploaded by a user and return a text description.
|
||||
|
||||
Uses Ollama vision model to describe the image contents.
|
||||
Works with JPEG, PNG, WebP bytes (HEIC should be converted before calling).
|
||||
"""
|
||||
import tempfile
|
||||
|
||||
# Write to temp file since ollama client expects a file path
|
||||
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
|
||||
f.write(file_bytes)
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
response = client.chat(
|
||||
model="gemma3:4b",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful image analyst. Describe what you see in the image in detail. Be thorough but concise.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Please describe this image in detail.",
|
||||
"images": [temp_path],
|
||||
},
|
||||
],
|
||||
)
|
||||
return response["message"]["content"]
|
||||
finally:
|
||||
os.remove(temp_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
if args.filepath:
|
||||
|
||||
62
utils/image_upload.py
Normal file
62
utils/image_upload.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import io
|
||||
import logging
|
||||
|
||||
from PIL import Image
|
||||
from pillow_heif import register_heif_opener
|
||||
|
||||
register_heif_opener()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
ALLOWED_TYPES = {"image/jpeg", "image/png", "image/webp", "image/heic", "image/heif"}
|
||||
MAX_DIMENSION = 1920
|
||||
|
||||
|
||||
class ImageValidationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def process_image(file_bytes: bytes, content_type: str) -> tuple[bytes, str]:
|
||||
"""Validate, resize, and strip EXIF from an uploaded image.
|
||||
|
||||
Returns processed bytes and the output content type (always image/jpeg or image/png or image/webp).
|
||||
"""
|
||||
if content_type not in ALLOWED_TYPES:
|
||||
raise ImageValidationError(
|
||||
f"Unsupported image type: {content_type}. "
|
||||
f"Allowed: JPEG, PNG, WebP, HEIC"
|
||||
)
|
||||
|
||||
img = Image.open(io.BytesIO(file_bytes))
|
||||
|
||||
# Resize if too large
|
||||
width, height = img.size
|
||||
if max(width, height) > MAX_DIMENSION:
|
||||
ratio = MAX_DIMENSION / max(width, height)
|
||||
new_size = (int(width * ratio), int(height * ratio))
|
||||
img = img.resize(new_size, Image.LANCZOS)
|
||||
logging.info(
|
||||
f"Resized image from {width}x{height} to {new_size[0]}x{new_size[1]}"
|
||||
)
|
||||
|
||||
# Strip EXIF by copying pixel data to a new image
|
||||
clean_img = Image.new(img.mode, img.size)
|
||||
clean_img.putdata(list(img.getdata()))
|
||||
|
||||
# Convert HEIC/HEIF to JPEG; otherwise keep original format
|
||||
if content_type in {"image/heic", "image/heif"}:
|
||||
output_format = "JPEG"
|
||||
output_content_type = "image/jpeg"
|
||||
elif content_type == "image/png":
|
||||
output_format = "PNG"
|
||||
output_content_type = "image/png"
|
||||
elif content_type == "image/webp":
|
||||
output_format = "WEBP"
|
||||
output_content_type = "image/webp"
|
||||
else:
|
||||
output_format = "JPEG"
|
||||
output_content_type = "image/jpeg"
|
||||
|
||||
buf = io.BytesIO()
|
||||
clean_img.save(buf, format=output_format, quality=85)
|
||||
return buf.getvalue(), output_content_type
|
||||
53
utils/s3_client.py
Normal file
53
utils/s3_client.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
import aioboto3
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL")
|
||||
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID")
|
||||
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")
|
||||
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "asksimba-images")
|
||||
S3_REGION = os.getenv("S3_REGION", "garage")
|
||||
|
||||
session = aioboto3.Session()
|
||||
|
||||
|
||||
def _get_client():
|
||||
return session.client(
|
||||
"s3",
|
||||
endpoint_url=S3_ENDPOINT_URL,
|
||||
aws_access_key_id=S3_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
|
||||
region_name=S3_REGION,
|
||||
)
|
||||
|
||||
|
||||
async def upload_image(file_bytes: bytes, key: str, content_type: str) -> str:
|
||||
async with _get_client() as client:
|
||||
await client.put_object(
|
||||
Bucket=S3_BUCKET_NAME,
|
||||
Key=key,
|
||||
Body=file_bytes,
|
||||
ContentType=content_type,
|
||||
)
|
||||
logging.info(f"Uploaded image to S3: {key}")
|
||||
return key
|
||||
|
||||
|
||||
async def get_image(key: str) -> tuple[bytes, str]:
|
||||
async with _get_client() as client:
|
||||
response = await client.get_object(Bucket=S3_BUCKET_NAME, Key=key)
|
||||
body = await response["Body"].read()
|
||||
content_type = response.get("ContentType", "image/jpeg")
|
||||
return body, content_type
|
||||
|
||||
|
||||
async def delete_image(key: str) -> None:
|
||||
async with _get_client() as client:
|
||||
await client.delete_object(Bucket=S3_BUCKET_NAME, Key=key)
|
||||
logging.info(f"Deleted image from S3: {key}")
|
||||
Reference in New Issue
Block a user