From 0415610d647e908ef85610899fa762a35757229b Mon Sep 17 00:00:00 2001 From: Ryan Chen Date: Sat, 4 Apr 2026 08:03:19 -0400 Subject: [PATCH] Add image upload and vision analysis to Ask Simba chat Users can now attach images in the web chat for Simba to analyze using Ollama's gemma3 vision model. Images are stored in Garage (S3-compatible) and displayed in chat history. Also fixes aerich migration config by extracting TORTOISE_CONFIG into a standalone config/db.py module, removing the stale aerich_config.py, and adding missing MODELS_STATE to migration 3. Co-Authored-By: Claude Opus 4.6 --- app.py | 20 +--- blueprints/conversation/__init__.py | 99 +++++++++++++++++-- blueprints/conversation/logic.py | 2 + blueprints/conversation/models.py | 1 + aerich_config.py => config/db.py | 11 +-- docker-compose.yml | 6 ++ .../3_20260313000000_add_email_fields.py | 29 ++++++ .../models/4_20260404080201_add_image_key.py | 43 ++++++++ pyproject.toml | 3 +- raggr-frontend/src/api/conversationService.ts | 38 ++++++- raggr-frontend/src/api/userService.ts | 9 +- raggr-frontend/src/components/ChatScreen.tsx | 41 +++++++- .../src/components/MessageInput.tsx | 98 ++++++++++++++---- .../src/components/QuestionBubble.tsx | 11 ++- utils/image_process.py | 33 +++++++ utils/image_upload.py | 62 ++++++++++++ utils/s3_client.py | 53 ++++++++++ 17 files changed, 501 insertions(+), 58 deletions(-) rename aerich_config.py => config/db.py (53%) create mode 100644 migrations/models/4_20260404080201_add_image_key.py create mode 100644 utils/image_upload.py create mode 100644 utils/s3_client.py diff --git a/app.py b/app.py index 13ba978..0cdbb95 100644 --- a/app.py +++ b/app.py @@ -13,6 +13,7 @@ import blueprints.users import blueprints.whatsapp import blueprints.email import blueprints.users.models +from config.db import TORTOISE_CONFIG from main import consult_simba_oracle # Load environment variables @@ -28,6 +29,7 @@ app = Quart( ) app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY", "SECRET_KEY") +app.config["MAX_CONTENT_LENGTH"] = 10 * 1024 * 1024 # 10 MB upload limit jwt = JWTManager(app) # Register blueprints @@ -38,24 +40,6 @@ app.register_blueprint(blueprints.whatsapp.whatsapp_blueprint) app.register_blueprint(blueprints.email.email_blueprint) -# Database configuration with environment variable support -DATABASE_URL = os.getenv( - "DATABASE_URL", "postgres://raggr:raggr_dev_password@localhost:5432/raggr" -) - -TORTOISE_CONFIG = { - "connections": {"default": DATABASE_URL}, - "apps": { - "models": { - "models": [ - "blueprints.conversation.models", - "blueprints.users.models", - "aerich.models", - ] - }, - }, -} - # Initialize Tortoise ORM with lifecycle hooks @app.while_serving async def lifespan(): diff --git a/blueprints/conversation/__init__.py b/blueprints/conversation/__init__.py index 28b6765..d19ace2 100644 --- a/blueprints/conversation/__init__.py +++ b/blueprints/conversation/__init__.py @@ -1,13 +1,19 @@ import datetime import json +import logging +import uuid -from quart import Blueprint, jsonify, make_response, request +from quart import Blueprint, Response, jsonify, make_response, request from quart_jwt_extended import ( get_jwt_identity, jwt_refresh_token_required, ) import blueprints.users.models +from utils.image_process import analyze_user_image +from utils.image_upload import ImageValidationError, process_image +from utils.s3_client import get_image as s3_get_image +from utils.s3_client import upload_image as s3_upload_image from .agents import main_agent from .logic import ( @@ -29,7 +35,9 @@ conversation_blueprint = Blueprint( _SYSTEM_PROMPT = SIMBA_SYSTEM_PROMPT -def _build_messages_payload(conversation, query_text: str) -> list: +def _build_messages_payload( + conversation, query_text: str, image_description: str | None = None +) -> list: recent_messages = ( conversation.messages[-10:] if len(conversation.messages) > 10 @@ -38,8 +46,19 @@ def _build_messages_payload(conversation, query_text: str) -> list: messages_payload = [{"role": "system", "content": _SYSTEM_PROMPT}] for msg in recent_messages[:-1]: # Exclude the message we just added role = "user" if msg.speaker == "user" else "assistant" - messages_payload.append({"role": role, "content": msg.text}) - messages_payload.append({"role": "user", "content": query_text}) + text = msg.text + if msg.image_key and role == "user": + text = f"[User sent an image]\n{text}" + messages_payload.append({"role": role, "content": text}) + + # Build the current user message with optional image description + if image_description: + content = f"[Image analysis: {image_description}]" + if query_text: + content = f"{query_text}\n\n{content}" + else: + content = query_text + messages_payload.append({"role": "user", "content": content}) return messages_payload @@ -74,6 +93,58 @@ async def query(): return jsonify({"response": message}) +@conversation_blueprint.post("/upload-image") +@jwt_refresh_token_required +async def upload_image(): + current_user_uuid = get_jwt_identity() + await blueprints.users.models.User.get(id=current_user_uuid) + + files = await request.files + form = await request.form + file = files.get("file") + conversation_id = form.get("conversation_id") + + if not file or not conversation_id: + return jsonify({"error": "file and conversation_id are required"}), 400 + + file_bytes = file.read() + content_type = file.content_type or "image/jpeg" + + try: + processed_bytes, output_content_type = process_image(file_bytes, content_type) + except ImageValidationError as e: + return jsonify({"error": str(e)}), 400 + + ext = output_content_type.split("/")[-1] + if ext == "jpeg": + ext = "jpg" + key = f"conversations/{conversation_id}/{uuid.uuid4()}.{ext}" + + await s3_upload_image(processed_bytes, key, output_content_type) + + return jsonify( + { + "image_key": key, + "image_url": f"/api/conversation/image/{key}", + } + ) + + +@conversation_blueprint.get("/image/") +@jwt_refresh_token_required +async def serve_image(image_key: str): + try: + image_bytes, content_type = await s3_get_image(image_key) + except Exception: + return jsonify({"error": "Image not found"}), 404 + + return Response( + image_bytes, + content_type=content_type, + headers={"Cache-Control": "private, max-age=3600"}, + ) + + @conversation_blueprint.post("/stream-query") @jwt_refresh_token_required async def stream_query(): @@ -82,16 +153,31 @@ async def stream_query(): data = await request.get_json() query_text = data.get("query") conversation_id = data.get("conversation_id") + image_key = data.get("image_key") conversation = await get_conversation_by_id(conversation_id) await conversation.fetch_related("messages") await add_message_to_conversation( conversation=conversation, - message=query_text, + message=query_text or "", speaker="user", user=user, + image_key=image_key, ) - messages_payload = _build_messages_payload(conversation, query_text) + # If an image was uploaded, analyze it with the vision model + image_description = None + if image_key: + try: + image_bytes, _ = await s3_get_image(image_key) + image_description = await analyze_user_image(image_bytes) + logging.info(f"Image analysis complete for {image_key}") + except Exception as e: + logging.error(f"Failed to analyze image: {e}") + image_description = "[Image could not be analyzed]" + + messages_payload = _build_messages_payload( + conversation, query_text or "", image_description + ) payload = {"messages": messages_payload} async def event_generator(): @@ -160,6 +246,7 @@ async def get_conversation(conversation_id: str): "text": msg.text, "speaker": msg.speaker.value, "created_at": msg.created_at.isoformat(), + "image_key": msg.image_key, } ) name = conversation.name diff --git a/blueprints/conversation/logic.py b/blueprints/conversation/logic.py index 8129ffc..4586d19 100644 --- a/blueprints/conversation/logic.py +++ b/blueprints/conversation/logic.py @@ -16,12 +16,14 @@ async def add_message_to_conversation( message: str, speaker: str, user: blueprints.users.models.User, + image_key: str | None = None, ) -> ConversationMessage: print(conversation, message, speaker) message = await ConversationMessage.create( text=message, speaker=speaker, conversation=conversation, + image_key=image_key, ) return message diff --git a/blueprints/conversation/models.py b/blueprints/conversation/models.py index e0e5ad1..1a73b6b 100644 --- a/blueprints/conversation/models.py +++ b/blueprints/conversation/models.py @@ -41,6 +41,7 @@ class ConversationMessage(Model): ) created_at = fields.DatetimeField(auto_now_add=True) speaker = fields.CharEnumField(enum_type=Speaker, max_length=10) + image_key = fields.CharField(max_length=512, null=True, default=None) class Meta: table = "conversation_messages" diff --git a/aerich_config.py b/config/db.py similarity index 53% rename from aerich_config.py rename to config/db.py index bfacaa9..3e4aa6d 100644 --- a/aerich_config.py +++ b/config/db.py @@ -1,15 +1,14 @@ import os + from dotenv import load_dotenv -# Load environment variables load_dotenv() -# Database configuration with environment variable support -# Use DATABASE_PATH for relative paths or DATABASE_URL for full connection strings -DATABASE_PATH = os.getenv("DATABASE_PATH", "database/raggr.db") -DATABASE_URL = os.getenv("DATABASE_URL", f"sqlite://{DATABASE_PATH}") +DATABASE_URL = os.getenv( + "DATABASE_URL", "postgres://raggr:raggr_dev_password@localhost:5432/raggr" +) -TORTOISE_ORM = { +TORTOISE_CONFIG = { "connections": {"default": DATABASE_URL}, "apps": { "models": { diff --git a/docker-compose.yml b/docker-compose.yml index 5fa3914..88f91a5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -55,6 +55,12 @@ services: - OBSIDIAN_DEVICE_NAME=${OBSIDIAN_DEVICE_NAME} - OBSIDIAN_CONTINUOUS_SYNC=${OBSIDIAN_CONTINUOUS_SYNC:-false} - OBSIDIAN_VAULT_PATH=${OBSIDIAN_VAULT_PATH:-/app/data/obsidian} + - S3_ENDPOINT_URL=${S3_ENDPOINT_URL} + - S3_ACCESS_KEY_ID=${S3_ACCESS_KEY_ID} + - S3_SECRET_ACCESS_KEY=${S3_SECRET_ACCESS_KEY} + - S3_BUCKET_NAME=${S3_BUCKET_NAME:-asksimba-images} + - S3_REGION=${S3_REGION:-garage} + - OLLAMA_HOST=${OLLAMA_HOST:-http://localhost:11434} depends_on: postgres: condition: service_healthy diff --git a/migrations/models/3_20260313000000_add_email_fields.py b/migrations/models/3_20260313000000_add_email_fields.py index c4a6e11..b4e5651 100644 --- a/migrations/models/3_20260313000000_add_email_fields.py +++ b/migrations/models/3_20260313000000_add_email_fields.py @@ -15,3 +15,32 @@ async def downgrade(db: BaseDBAsyncClient) -> str: DROP INDEX IF EXISTS "idx_users_email_h_a1b2c3"; ALTER TABLE "users" DROP COLUMN "email_hmac_token"; ALTER TABLE "users" DROP COLUMN "email_enabled";""" + + +MODELS_STATE = ( + "eJztmm1v4jgQx78Kyquu1KtaKN1VdTopUHrLbYEThX3qVZFJXMg1sbOJsxRV/e5nm4Q4jg" + "OEAoU93rRl7CH2z2PP35M+ay62oBOc1DH6Cf0AEBsj7bL0rCHgQvqHsv24pAHPS1qZgYCB" + "wx1MoSdvAYOA+MAktPEBOAGkJgsGpm970cNQ6DjMiE3a0UbDxBQi+0cIDYKHkIygTxvu7q" + "nZRhZ8gkH80Xs0HmzoWKlx2xZ7NrcbZOJxW7/fvLrmPdnjBoaJndBFSW9vQkYYzbqHoW2d" + "MB/WNoQI+oBAS5gGG2U07dg0HTE1ED+Es6FaicGCDyB0GAzt94cQmYxBiT+J/Tj/QyuAh6" + "JmaG1EGIvnl+mskjlzq8YeVf+od48qF+/4LHFAhj5v5ES0F+4ICJi6cq4JSP47g7I+Ar4a" + "ZdxfgkkHugrG2JBwTGIoBhkDWo2a5oInw4FoSEb0Y7lanYPxs97lJGkvjhLTuJ5GfTtqKk" + "/bGNIEoelDNmUDkCzIK9pCbBeqYaY9JaRW5HoS/7GjgOkcrA5yJtEmmMO312w1bnt66282" + "EzcIfjgckd5rsJYyt04k69GFtBSzLyl9afY+ltjH0vdOuyHH/qxf77vGxgRCgg2ExwawhP" + "0aW2MwqYUNPWvFhU17Hhb2TRc2GrywrgH0jWIZRHB5RRqJxrbFRVw9abDU+/CozBkMRhbe" + "NfahPUSf4IQjbNJxAGSqkkUkOvrR1+wqtMSajMIH45kaEYOCzo7OCZJp9tRv6/pVQ+MMB8" + "B8HAPfMnJgujAIwBAGWaC1yPP6Uxc6M2mmZikKuNb0G3fzVMljy1nhMhYYpehlm9yyK1sA" + "ovO2omezJ82hs0AFCxCXE8OGuJAHUbzXopjAJ0XK71GrGmXcf19E8bxU3vjaS2XxWPoetf" + "Sv71KZ/KbT/jPuLkjl+k2ndlDIv6KQyirkwIPgUSUG2AWygUI3IwVSqyu4v/HW0fq3je5l" + "iWX0f9Bts1XTL0uB7Q6AttwSp26ZZ6dLXDLPTnPvmKxJ2kBioil2zCtc13nm76mENaWC1y" + "ulrFw/21mKCzWtIlyKattNKjl+Z1BIt/guka/V2NY+aLP912ZsHYsWLUWffdFoWyhceiAI" + "xthXRGbNRsCfqGGKXhLMwYRM7z+7eqVXwasxvSrKLYqs1mzr3W9qyRv3F+O29q3X0CW60A" + "W2UyRKZw7rCdHFO36dAXp2upzomad6MrJnPAIkoEe6QZXkIE9mqmEqXFfCKofqdqlWloFa" + "yWdaySDlQWZAxKan2vgYOxCgOQEq+srbnzpv6jAtmqoL7P9O5ya1/2tN+Urbb9UaNHg5Zt" + "rJnkqhZrunhDtygUk1wiNUKMsFu1/y3cOIPbtY5hiQr6zCKXAhRyy2LdMIwsG/0FSUD/KB" + "yn57CHMjWZ9e6EeG5+OftlXsSM04bk9KaQ42gfMKLZrmWl3mWK3mH6vVzLHqWMAzhj4OPU" + "Uh/6/bTluNVHKTgPYRneWdZZvkuOTYAbnfGN67+83ofDbz+dVEuXAoCSv2BYdq4v+kmnh4" + "3/5LLOzsdV6mKrToXWjmn8vW80J0l2+k230RqkPfNkeaooAWtRzPK6GBpM/O1NCaKOednL" + "KExjBLwRCt/JvepPnr6N/KZ+fvzz9ULs4/0C58JDPL+zmHQXwNyS+ZsY2grHPnaz3B5VAw" + "S6Qz3RpFBPO0+34C3EhBhz6RQKRI7/kSWXB5K3m8sdLj2uRxgWy7/vTy8h9Mf/k3" +) diff --git a/migrations/models/4_20260404080201_add_image_key.py b/migrations/models/4_20260404080201_add_image_key.py new file mode 100644 index 0000000..e15fdbc --- /dev/null +++ b/migrations/models/4_20260404080201_add_image_key.py @@ -0,0 +1,43 @@ +from tortoise import BaseDBAsyncClient + +RUN_IN_TRANSACTION = True + + +async def upgrade(db: BaseDBAsyncClient) -> str: + return """ + ALTER TABLE "conversation_messages" ADD "image_key" VARCHAR(512);""" + + +async def downgrade(db: BaseDBAsyncClient) -> str: + return """ + ALTER TABLE "conversation_messages" DROP COLUMN "image_key";""" + + +MODELS_STATE = ( + "eJztmmtv4jgUhv8KyqeO1K0KvcyoWq0UWrrDToFVC3PrVpFJXPCS2JnEGYqq/ve1TUIcx6" + "GkBQqzfGnLsQ+xH1/Oe076aHjEgW54cE7wTxiEgCKCjbPKo4GBB9kf2vb9igF8P23lBgr6" + "rnCwpZ6iBfRDGgCbssZ74IaQmRwY2gHy44fhyHW5kdisI8KD1BRh9COCFiUDSIcwYA23d8" + "yMsAMfYJh89EfWPYKukxk3cvizhd2iE1/Yer3mxaXoyR/Xt2ziRh5Oe/sTOiR41j2KkHPA" + "fXjbAGIYAAodaRp8lPG0E9N0xMxAgwjOhuqkBgfeg8jlMIzf7yNscwYV8ST+4/gPowQehp" + "qjRZhyFo9P01mlcxZWgz/q/KN5vXd0+k7MkoR0EIhGQcR4Eo6Agqmr4JqCFL9zKM+HINCj" + "TPorMNlAX4IxMaQc0z2UgEwAvYya4YEHy4V4QIfsY+3kZA7Gz+a1IMl6CZSE7evprm/HTb" + "VpG0eaIrQDyKdsAZoHecFaKPKgHmbWU0HqxK4HyR8bCpjNwelgdxIfgjl8u81W46Zrtv7m" + "M/HC8IcrEJndBm+pCetEse6dKksx+5LKl2b3Y4V/rHzvtBvq3p/16343+JhARImFydgCjn" + "ReE2sCJrOwke+8cGGznruFfdOFjQcvrWsIA6tcBJFcXhFG4rGtcRFfHjR46L0faWMGh5GH" + "d0kCiAb4E5wIhE02DoBtXbCIRUcv/ppNhZZa01EEYDxTI/KmYLNjc4J0Gj3Nm3PzomEIhn" + "1gj8YgcKwCmB4MQzCAYR5oPfa8/HQN3Zk007OUBVxr+o2beasUsRWsSI1IjDL08k1ezVMt" + "ALN5O/Gz+ZPm0HlGBUsQFxPDlryQO1G81aKYwgdNyO8yqx5l0n9bRPG8UN742s1E8UT67r" + "XMr+8ykfyq0/4z6S5J5fOrTn2nkH9FIZVXyKEPwUgnBngC2cCRl5MCmdWV3N/46Bi9m8b1" + "WYVH9H/wTbNVN88qIfL6wFhsiTNZZvVwgSSzeliYY/Km7AFCHoss1ghOyqTqGacX8V2/9M" + "qCPKnWFiDJehWiFG3KZSQH7XIhU+O6zPi5pemArRQPX5kWqLXIjaX4bH6g2S5l84RVqmKR" + "f2lkcJKXFetefk3udO7261y+jmULwLLPtujdNRSBfRCGYxJodmYdYRBM9DBlLwVmf0Knue" + "TGxeg58Opc+8vSlSGrN9vm9Td9+pD0l/dt/Vu3YSp0oQeQW2aXzhyWs0WfP/HL3KDVw8UE" + "5DwFmZOQ4yGgIbvSLabK+0WSXQ9T47oUObleqkeLQD0qZnqUQyo2mQUxn57u4BPiQoDnbF" + "DZVz3+zHlVl2nZUF3i/Hc6V5nzX2+q5YFeq95gm1dgZp3QVAo1210t3KEHbKYRRlCjLJ85" + "/YrvFu7Y6uki14Ca/ku3wKm6YwlybCuM+v9CW1OKKQaq+m0hzJVEfRDRoeUH5Cdyyl2pOc" + "f1SSnDJTZwX6FFlRx9kWv1pPhaPcldq64DfGsQkMjXvBT566bT1iNV3BSgPcxmeesgm+5X" + "XBTSu5Xhvb1bjc7nM59fmVWLsIqw4l+wq8z+Tyqzu/9d+CUWdvZqNFcVeu69cu4f9Zbzcn" + "mTM9L1vlQ2YYDsoaEpoMUt+/NKaCDtszE1tCYueL+pLaFxzMpmiFf+TTNp8Wr/t1r1+P3x" + "h6PT4w+sixjJzPJ+zmWQpCHFJTN+ELR17mKtJ7nsCmapdGZHo4xgnnbfToArKeiwJ1KINe" + "G9WCJLLm8lj1dWelyaPC4RbZcfXp7+AzcBYwM=" +) diff --git a/pyproject.toml b/pyproject.toml index 0827bbb..4e88e98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,9 +37,10 @@ dependencies = [ "ynab>=1.3.0", "ollama>=0.6.1", "twilio>=9.10.2", + "aioboto3>=13.0.0", ] [tool.aerich] -tortoise_orm = "app.TORTOISE_CONFIG" +tortoise_orm = "config.db.TORTOISE_CONFIG" location = "./migrations" src_folder = "./." diff --git a/raggr-frontend/src/api/conversationService.ts b/raggr-frontend/src/api/conversationService.ts index f80eb6a..788089e 100644 --- a/raggr-frontend/src/api/conversationService.ts +++ b/raggr-frontend/src/api/conversationService.ts @@ -13,6 +13,7 @@ interface Message { text: string; speaker: "user" | "simba"; created_at: string; + image_key?: string | null; } interface Conversation { @@ -121,17 +122,52 @@ class ConversationService { return await response.json(); } + async uploadImage( + file: File, + conversationId: string, + ): Promise<{ image_key: string; image_url: string }> { + const formData = new FormData(); + formData.append("file", file); + formData.append("conversation_id", conversationId); + + const response = await userService.fetchWithRefreshToken( + `${this.conversationBaseUrl}/upload-image`, + { + method: "POST", + body: formData, + }, + { skipContentType: true }, + ); + + if (!response.ok) { + const data = await response.json(); + throw new Error(data.error || "Failed to upload image"); + } + + return await response.json(); + } + + getImageUrl(imageKey: string): string { + return `/api/conversation/image/${imageKey}`; + } + async streamQuery( query: string, conversation_id: string, onEvent: SSEEventCallback, signal?: AbortSignal, + imageKey?: string, ): Promise { + const body: Record = { query, conversation_id }; + if (imageKey) { + body.image_key = imageKey; + } + const response = await userService.fetchWithRefreshToken( `${this.conversationBaseUrl}/stream-query`, { method: "POST", - body: JSON.stringify({ query, conversation_id }), + body: JSON.stringify(body), signal, }, ); diff --git a/raggr-frontend/src/api/userService.ts b/raggr-frontend/src/api/userService.ts index 8f816c5..52e4f4c 100644 --- a/raggr-frontend/src/api/userService.ts +++ b/raggr-frontend/src/api/userService.ts @@ -106,14 +106,15 @@ class UserService { async fetchWithRefreshToken( url: string, options: RequestInit = {}, + { skipContentType = false }: { skipContentType?: boolean } = {}, ): Promise { const refreshToken = localStorage.getItem("refresh_token"); // Add authorization header - const headers = { - "Content-Type": "application/json", - ...(options.headers || {}), - ...(refreshToken && { Authorization: `Bearer ${refreshToken}` }), + const headers: Record = { + ...(skipContentType ? {} : { "Content-Type": "application/json" }), + ...((options.headers as Record) || {}), + ...(refreshToken ? { Authorization: `Bearer ${refreshToken}` } : {}), }; let response = await fetch(url, { ...options, headers }); diff --git a/raggr-frontend/src/components/ChatScreen.tsx b/raggr-frontend/src/components/ChatScreen.tsx index 289a3a2..b1c6216 100644 --- a/raggr-frontend/src/components/ChatScreen.tsx +++ b/raggr-frontend/src/components/ChatScreen.tsx @@ -14,6 +14,7 @@ import catIcon from "../assets/cat.png"; type Message = { text: string; speaker: "simba" | "user" | "tool"; + image_key?: string | null; }; type Conversation = { @@ -55,6 +56,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { const [isLoading, setIsLoading] = useState(false); const [isAdmin, setIsAdmin] = useState(false); const [showAdminPanel, setShowAdminPanel] = useState(false); + const [pendingImage, setPendingImage] = useState(null); const messagesEndRef = useRef(null); const isMountedRef = useRef(true); @@ -80,7 +82,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { try { const fetched = await conversationService.getConversation(conversation.id); setMessages( - fetched.messages.map((m) => ({ text: m.text, speaker: m.speaker })), + fetched.messages.map((m) => ({ text: m.text, speaker: m.speaker, image_key: m.image_key })), ); } catch (err) { console.error("Failed to load messages:", err); @@ -120,7 +122,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { try { const conv = await conversationService.getConversation(selectedConversation.id); setSelectedConversation({ id: conv.id, title: conv.name }); - setMessages(conv.messages.map((m) => ({ text: m.text, speaker: m.speaker }))); + setMessages(conv.messages.map((m) => ({ text: m.text, speaker: m.speaker, image_key: m.image_key }))); } catch (err) { console.error("Failed to load messages:", err); } @@ -129,7 +131,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { }, [selectedConversation?.id]); const handleQuestionSubmit = async () => { - if (!query.trim() || isLoading) return; + if ((!query.trim() && !pendingImage) || isLoading) return; let activeConversation = selectedConversation; if (!activeConversation) { @@ -139,9 +141,13 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { setConversations((prev) => [activeConversation!, ...prev]); } + // Capture pending image before clearing state + const imageFile = pendingImage; + const currMessages = messages.concat([{ text: query, speaker: "user" }]); setMessages(currMessages); setQuery(""); + setPendingImage(null); setIsLoading(true); if (simbaMode) { @@ -155,6 +161,29 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { abortControllerRef.current = abortController; try { + // Upload image first if present + let imageKey: string | undefined; + if (imageFile) { + const uploadResult = await conversationService.uploadImage( + imageFile, + activeConversation.id, + ); + imageKey = uploadResult.image_key; + + // Update the user message with the image key + setMessages((prev) => { + const updated = [...prev]; + // Find the last user message we just added + for (let i = updated.length - 1; i >= 0; i--) { + if (updated[i].speaker === "user") { + updated[i] = { ...updated[i], image_key: imageKey }; + break; + } + } + return updated; + }); + } + await conversationService.streamQuery( query, activeConversation.id, @@ -170,6 +199,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { } }, abortController.signal, + imageKey, ); } catch (error) { if (error instanceof Error && error.name === "AbortError") { @@ -349,6 +379,9 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { handleQuestionSubmit={handleQuestionSubmit} setSimbaMode={setSimbaMode} isLoading={isLoading} + pendingImage={pendingImage} + onImageSelect={(file) => setPendingImage(file)} + onClearImage={() => setPendingImage(null)} /> @@ -375,7 +408,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { return ; if (msg.speaker === "simba") return ; - return ; + return ; })} {isLoading && } diff --git a/raggr-frontend/src/components/MessageInput.tsx b/raggr-frontend/src/components/MessageInput.tsx index 1e0afa4..bd9af95 100644 --- a/raggr-frontend/src/components/MessageInput.tsx +++ b/raggr-frontend/src/components/MessageInput.tsx @@ -1,5 +1,5 @@ -import { useState } from "react"; -import { ArrowUp } from "lucide-react"; +import { useRef, useState } from "react"; +import { ArrowUp, ImagePlus, X } from "lucide-react"; import { cn } from "../lib/utils"; import { Textarea } from "./ui/textarea"; @@ -10,6 +10,9 @@ type MessageInputProps = { setSimbaMode: (val: boolean) => void; query: string; isLoading: boolean; + pendingImage: File | null; + onImageSelect: (file: File) => void; + onClearImage: () => void; }; export const MessageInput = ({ @@ -19,8 +22,12 @@ export const MessageInput = ({ handleQuestionSubmit, setSimbaMode, isLoading, + pendingImage, + onImageSelect, + onClearImage, }: MessageInputProps) => { const [simbaMode, setLocalSimbaMode] = useState(false); + const fileInputRef = useRef(null); const toggleSimbaMode = () => { const next = !simbaMode; @@ -28,6 +35,17 @@ export const MessageInput = ({ setSimbaMode(next); }; + const handleFileChange = (e: React.ChangeEvent) => { + const file = e.target.files?.[0]; + if (file) { + onImageSelect(file); + } + // Reset so the same file can be re-selected + e.target.value = ""; + }; + + const canSend = !isLoading && (query.trim() || pendingImage); + return (
+ {/* Image preview */} + {pendingImage && ( +
+
+ Pending upload + +
+
+ )} + {/* Textarea */}