diff --git a/app.py b/app.py index 13ba978..0cdbb95 100644 --- a/app.py +++ b/app.py @@ -13,6 +13,7 @@ import blueprints.users import blueprints.whatsapp import blueprints.email import blueprints.users.models +from config.db import TORTOISE_CONFIG from main import consult_simba_oracle # Load environment variables @@ -28,6 +29,7 @@ app = Quart( ) app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY", "SECRET_KEY") +app.config["MAX_CONTENT_LENGTH"] = 10 * 1024 * 1024 # 10 MB upload limit jwt = JWTManager(app) # Register blueprints @@ -38,24 +40,6 @@ app.register_blueprint(blueprints.whatsapp.whatsapp_blueprint) app.register_blueprint(blueprints.email.email_blueprint) -# Database configuration with environment variable support -DATABASE_URL = os.getenv( - "DATABASE_URL", "postgres://raggr:raggr_dev_password@localhost:5432/raggr" -) - -TORTOISE_CONFIG = { - "connections": {"default": DATABASE_URL}, - "apps": { - "models": { - "models": [ - "blueprints.conversation.models", - "blueprints.users.models", - "aerich.models", - ] - }, - }, -} - # Initialize Tortoise ORM with lifecycle hooks @app.while_serving async def lifespan(): diff --git a/blueprints/conversation/__init__.py b/blueprints/conversation/__init__.py index 28b6765..d19ace2 100644 --- a/blueprints/conversation/__init__.py +++ b/blueprints/conversation/__init__.py @@ -1,13 +1,19 @@ import datetime import json +import logging +import uuid -from quart import Blueprint, jsonify, make_response, request +from quart import Blueprint, Response, jsonify, make_response, request from quart_jwt_extended import ( get_jwt_identity, jwt_refresh_token_required, ) import blueprints.users.models +from utils.image_process import analyze_user_image +from utils.image_upload import ImageValidationError, process_image +from utils.s3_client import get_image as s3_get_image +from utils.s3_client import upload_image as s3_upload_image from .agents import main_agent from .logic import ( @@ -29,7 +35,9 @@ conversation_blueprint = 
Blueprint( _SYSTEM_PROMPT = SIMBA_SYSTEM_PROMPT -def _build_messages_payload(conversation, query_text: str) -> list: +def _build_messages_payload( + conversation, query_text: str, image_description: str | None = None +) -> list: recent_messages = ( conversation.messages[-10:] if len(conversation.messages) > 10 @@ -38,8 +46,19 @@ def _build_messages_payload(conversation, query_text: str) -> list: messages_payload = [{"role": "system", "content": _SYSTEM_PROMPT}] for msg in recent_messages[:-1]: # Exclude the message we just added role = "user" if msg.speaker == "user" else "assistant" - messages_payload.append({"role": role, "content": msg.text}) - messages_payload.append({"role": "user", "content": query_text}) + text = msg.text + if msg.image_key and role == "user": + text = f"[User sent an image]\n{text}" + messages_payload.append({"role": role, "content": text}) + + # Build the current user message with optional image description + if image_description: + content = f"[Image analysis: {image_description}]" + if query_text: + content = f"{query_text}\n\n{content}" + else: + content = query_text + messages_payload.append({"role": "user", "content": content}) return messages_payload @@ -74,6 +93,58 @@ async def query(): return jsonify({"response": message}) +@conversation_blueprint.post("/upload-image") +@jwt_refresh_token_required +async def upload_image(): + current_user_uuid = get_jwt_identity() + await blueprints.users.models.User.get(id=current_user_uuid) + + files = await request.files + form = await request.form + file = files.get("file") + conversation_id = form.get("conversation_id") + + if not file or not conversation_id: + return jsonify({"error": "file and conversation_id are required"}), 400 + + file_bytes = file.read() + content_type = file.content_type or "image/jpeg" + + try: + processed_bytes, output_content_type = process_image(file_bytes, content_type) + except ImageValidationError as e: + return jsonify({"error": str(e)}), 400 + + ext = 
output_content_type.split("/")[-1] + if ext == "jpeg": + ext = "jpg" + key = f"conversations/{conversation_id}/{uuid.uuid4()}.{ext}" + + await s3_upload_image(processed_bytes, key, output_content_type) + + return jsonify( + { + "image_key": key, + "image_url": f"/api/conversation/image/{key}", + } + ) + + +@conversation_blueprint.get("/image/<path:image_key>") +@jwt_refresh_token_required +async def serve_image(image_key: str): + try: + image_bytes, content_type = await s3_get_image(image_key) + except Exception: + return jsonify({"error": "Image not found"}), 404 + + return Response( + image_bytes, + content_type=content_type, + headers={"Cache-Control": "private, max-age=3600"}, + ) + + @conversation_blueprint.post("/stream-query") @jwt_refresh_token_required async def stream_query(): @@ -82,16 +153,31 @@ async def stream_query(): data = await request.get_json() query_text = data.get("query") conversation_id = data.get("conversation_id") + image_key = data.get("image_key") conversation = await get_conversation_by_id(conversation_id) await conversation.fetch_related("messages") await add_message_to_conversation( conversation=conversation, - message=query_text, + message=query_text or "", speaker="user", user=user, + image_key=image_key, ) - messages_payload = _build_messages_payload(conversation, query_text) + # If an image was uploaded, analyze it with the vision model + image_description = None + if image_key: + try: + image_bytes, _ = await s3_get_image(image_key) + image_description = await analyze_user_image(image_bytes) + logging.info(f"Image analysis complete for {image_key}") + except Exception as e: + logging.error(f"Failed to analyze image: {e}") + image_description = "[Image could not be analyzed]" + + messages_payload = _build_messages_payload( + conversation, query_text or "", image_description + ) payload = {"messages": messages_payload} async def event_generator(): @@ -160,6 +246,7 @@ async def get_conversation(conversation_id: str): "text": msg.text, "speaker": 
msg.speaker.value, "created_at": msg.created_at.isoformat(), + "image_key": msg.image_key, } ) name = conversation.name diff --git a/blueprints/conversation/logic.py b/blueprints/conversation/logic.py index 8129ffc..4586d19 100644 --- a/blueprints/conversation/logic.py +++ b/blueprints/conversation/logic.py @@ -16,12 +16,14 @@ async def add_message_to_conversation( message: str, speaker: str, user: blueprints.users.models.User, + image_key: str | None = None, ) -> ConversationMessage: print(conversation, message, speaker) message = await ConversationMessage.create( text=message, speaker=speaker, conversation=conversation, + image_key=image_key, ) return message diff --git a/blueprints/conversation/models.py b/blueprints/conversation/models.py index e0e5ad1..1a73b6b 100644 --- a/blueprints/conversation/models.py +++ b/blueprints/conversation/models.py @@ -41,6 +41,7 @@ class ConversationMessage(Model): ) created_at = fields.DatetimeField(auto_now_add=True) speaker = fields.CharEnumField(enum_type=Speaker, max_length=10) + image_key = fields.CharField(max_length=512, null=True, default=None) class Meta: table = "conversation_messages" diff --git a/aerich_config.py b/config/db.py similarity index 53% rename from aerich_config.py rename to config/db.py index bfacaa9..3e4aa6d 100644 --- a/aerich_config.py +++ b/config/db.py @@ -1,15 +1,14 @@ import os + from dotenv import load_dotenv -# Load environment variables load_dotenv() -# Database configuration with environment variable support -# Use DATABASE_PATH for relative paths or DATABASE_URL for full connection strings -DATABASE_PATH = os.getenv("DATABASE_PATH", "database/raggr.db") -DATABASE_URL = os.getenv("DATABASE_URL", f"sqlite://{DATABASE_PATH}") +DATABASE_URL = os.getenv( + "DATABASE_URL", "postgres://raggr:raggr_dev_password@localhost:5432/raggr" +) -TORTOISE_ORM = { +TORTOISE_CONFIG = { "connections": {"default": DATABASE_URL}, "apps": { "models": { diff --git a/docker-compose.yml b/docker-compose.yml index 
5fa3914..88f91a5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -55,6 +55,12 @@ services: - OBSIDIAN_DEVICE_NAME=${OBSIDIAN_DEVICE_NAME} - OBSIDIAN_CONTINUOUS_SYNC=${OBSIDIAN_CONTINUOUS_SYNC:-false} - OBSIDIAN_VAULT_PATH=${OBSIDIAN_VAULT_PATH:-/app/data/obsidian} + - S3_ENDPOINT_URL=${S3_ENDPOINT_URL} + - S3_ACCESS_KEY_ID=${S3_ACCESS_KEY_ID} + - S3_SECRET_ACCESS_KEY=${S3_SECRET_ACCESS_KEY} + - S3_BUCKET_NAME=${S3_BUCKET_NAME:-asksimba-images} + - S3_REGION=${S3_REGION:-garage} + - OLLAMA_HOST=${OLLAMA_HOST:-http://localhost:11434} depends_on: postgres: condition: service_healthy diff --git a/migrations/models/3_20260313000000_add_email_fields.py b/migrations/models/3_20260313000000_add_email_fields.py index c4a6e11..b4e5651 100644 --- a/migrations/models/3_20260313000000_add_email_fields.py +++ b/migrations/models/3_20260313000000_add_email_fields.py @@ -15,3 +15,32 @@ async def downgrade(db: BaseDBAsyncClient) -> str: DROP INDEX IF EXISTS "idx_users_email_h_a1b2c3"; ALTER TABLE "users" DROP COLUMN "email_hmac_token"; ALTER TABLE "users" DROP COLUMN "email_enabled";""" + + +MODELS_STATE = ( + "eJztmm1v4jgQx78Kyquu1KtaKN1VdTopUHrLbYEThX3qVZFJXMg1sbOJsxRV/e5nm4Q4jg" + "OEAoU93rRl7CH2z2PP35M+ay62oBOc1DH6Cf0AEBsj7bL0rCHgQvqHsv24pAHPS1qZgYCB" + "wx1MoSdvAYOA+MAktPEBOAGkJgsGpm970cNQ6DjMiE3a0UbDxBQi+0cIDYKHkIygTxvu7q" + "nZRhZ8gkH80Xs0HmzoWKlx2xZ7NrcbZOJxW7/fvLrmPdnjBoaJndBFSW9vQkYYzbqHoW2d" + "MB/WNoQI+oBAS5gGG2U07dg0HTE1ED+Es6FaicGCDyB0GAzt94cQmYxBiT+J/Tj/QyuAh6" + "JmaG1EGIvnl+mskjlzq8YeVf+od48qF+/4LHFAhj5v5ES0F+4ICJi6cq4JSP47g7I+Ar4a" + "ZdxfgkkHugrG2JBwTGIoBhkDWo2a5oInw4FoSEb0Y7lanYPxs97lJGkvjhLTuJ5GfTtqKk" + "/bGNIEoelDNmUDkCzIK9pCbBeqYaY9JaRW5HoS/7GjgOkcrA5yJtEmmMO312w1bnt66282" + "EzcIfjgckd5rsJYyt04k69GFtBSzLyl9afY+ltjH0vdOuyHH/qxf77vGxgRCgg2ExwawhP" + "0aW2MwqYUNPWvFhU17Hhb2TRc2GrywrgH0jWIZRHB5RRqJxrbFRVw9abDU+/CozBkMRhbe" + "NfahPUSf4IQjbNJxAGSqkkUkOvrR1+wqtMSajMIH45kaEYOCzo7OCZJp9tRv6/pVQ+MMB8" + 
"B8HAPfMnJgujAIwBAGWaC1yPP6Uxc6M2mmZikKuNb0G3fzVMljy1nhMhYYpehlm9yyK1sA" + "ovO2omezJ82hs0AFCxCXE8OGuJAHUbzXopjAJ0XK71GrGmXcf19E8bxU3vjaS2XxWPoetf" + "Sv71KZ/KbT/jPuLkjl+k2ndlDIv6KQyirkwIPgUSUG2AWygUI3IwVSqyu4v/HW0fq3je5l" + "iWX0f9Bts1XTL0uB7Q6AttwSp26ZZ6dLXDLPTnPvmKxJ2kBioil2zCtc13nm76mENaWC1y" + "ulrFw/21mKCzWtIlyKattNKjl+Z1BIt/guka/V2NY+aLP912ZsHYsWLUWffdFoWyhceiAI" + "xthXRGbNRsCfqGGKXhLMwYRM7z+7eqVXwasxvSrKLYqs1mzr3W9qyRv3F+O29q3X0CW60A" + "W2UyRKZw7rCdHFO36dAXp2upzomad6MrJnPAIkoEe6QZXkIE9mqmEqXFfCKofqdqlWloFa" + "yWdaySDlQWZAxKan2vgYOxCgOQEq+srbnzpv6jAtmqoL7P9O5ya1/2tN+Urbb9UaNHg5Zt" + "rJnkqhZrunhDtygUk1wiNUKMsFu1/y3cOIPbtY5hiQr6zCKXAhRyy2LdMIwsG/0FSUD/KB" + "yn57CHMjWZ9e6EeG5+OftlXsSM04bk9KaQ42gfMKLZrmWl3mWK3mH6vVzLHqWMAzhj4OPU" + "Uh/6/bTluNVHKTgPYRneWdZZvkuOTYAbnfGN67+83ofDbz+dVEuXAoCSv2BYdq4v+kmnh4" + "3/5LLOzsdV6mKrToXWjmn8vW80J0l2+k230RqkPfNkeaooAWtRzPK6GBpM/O1NCaKOednL" + "KExjBLwRCt/JvepPnr6N/KZ+fvzz9ULs4/0C58JDPL+zmHQXwNyS+ZsY2grHPnaz3B5VAw" + "S6Qz3RpFBPO0+34C3EhBhz6RQKRI7/kSWXB5K3m8sdLj2uRxgWy7/vTy8h9Mf/k3" +) diff --git a/migrations/models/4_20260404080201_add_image_key.py b/migrations/models/4_20260404080201_add_image_key.py new file mode 100644 index 0000000..e15fdbc --- /dev/null +++ b/migrations/models/4_20260404080201_add_image_key.py @@ -0,0 +1,43 @@ +from tortoise import BaseDBAsyncClient + +RUN_IN_TRANSACTION = True + + +async def upgrade(db: BaseDBAsyncClient) -> str: + return """ + ALTER TABLE "conversation_messages" ADD "image_key" VARCHAR(512);""" + + +async def downgrade(db: BaseDBAsyncClient) -> str: + return """ + ALTER TABLE "conversation_messages" DROP COLUMN "image_key";""" + + +MODELS_STATE = ( + "eJztmmtv4jgUhv8KyqeO1K0KvcyoWq0UWrrDToFVC3PrVpFJXPCS2JnEGYqq/ve1TUIcx6" + "GkBQqzfGnLsQ+xH1/Oe076aHjEgW54cE7wTxiEgCKCjbPKo4GBB9kf2vb9igF8P23lBgr6" + "rnCwpZ6iBfRDGgCbssZ74IaQmRwY2gHy44fhyHW5kdisI8KD1BRh9COCFiUDSIcwYA23d8" + "yMsAMfYJh89EfWPYKukxk3cvizhd2iE1/Yer3mxaXoyR/Xt2ziRh5Oe/sTOiR41j2KkHPA" + 
"fXjbAGIYAAodaRp8lPG0E9N0xMxAgwjOhuqkBgfeg8jlMIzf7yNscwYV8ST+4/gPowQehp" + "qjRZhyFo9P01mlcxZWgz/q/KN5vXd0+k7MkoR0EIhGQcR4Eo6Agqmr4JqCFL9zKM+HINCj" + "TPorMNlAX4IxMaQc0z2UgEwAvYya4YEHy4V4QIfsY+3kZA7Gz+a1IMl6CZSE7evprm/HTb" + "VpG0eaIrQDyKdsAZoHecFaKPKgHmbWU0HqxK4HyR8bCpjNwelgdxIfgjl8u81W46Zrtv7m" + "M/HC8IcrEJndBm+pCetEse6dKksx+5LKl2b3Y4V/rHzvtBvq3p/16343+JhARImFydgCjn" + "ReE2sCJrOwke+8cGGznruFfdOFjQcvrWsIA6tcBJFcXhFG4rGtcRFfHjR46L0faWMGh5GH" + "d0kCiAb4E5wIhE02DoBtXbCIRUcv/ppNhZZa01EEYDxTI/KmYLNjc4J0Gj3Nm3PzomEIhn" + "1gj8YgcKwCmB4MQzCAYR5oPfa8/HQN3Zk007OUBVxr+o2beasUsRWsSI1IjDL08k1ezVMt" + "ALN5O/Gz+ZPm0HlGBUsQFxPDlryQO1G81aKYwgdNyO8yqx5l0n9bRPG8UN742s1E8UT67r" + "XMr+8ykfyq0/4z6S5J5fOrTn2nkH9FIZVXyKEPwUgnBngC2cCRl5MCmdWV3N/46Bi9m8b1" + "WYVH9H/wTbNVN88qIfL6wFhsiTNZZvVwgSSzeliYY/Km7AFCHoss1ghOyqTqGacX8V2/9M" + "qCPKnWFiDJehWiFG3KZSQH7XIhU+O6zPi5pemArRQPX5kWqLXIjaX4bH6g2S5l84RVqmKR" + "f2lkcJKXFetefk3udO7261y+jmULwLLPtujdNRSBfRCGYxJodmYdYRBM9DBlLwVmf0Knue" + "TGxeg58Opc+8vSlSGrN9vm9Td9+pD0l/dt/Vu3YSp0oQeQW2aXzhyWs0WfP/HL3KDVw8UE" + "5DwFmZOQ4yGgIbvSLabK+0WSXQ9T47oUObleqkeLQD0qZnqUQyo2mQUxn57u4BPiQoDnbF" + "DZVz3+zHlVl2nZUF3i/Hc6V5nzX2+q5YFeq95gm1dgZp3QVAo1210t3KEHbKYRRlCjLJ85" + "/YrvFu7Y6uki14Ca/ku3wKm6YwlybCuM+v9CW1OKKQaq+m0hzJVEfRDRoeUH5Cdyyl2pOc" + "f1SSnDJTZwX6FFlRx9kWv1pPhaPcldq64DfGsQkMjXvBT566bT1iNV3BSgPcxmeesgm+5X" + "XBTSu5Xhvb1bjc7nM59fmVWLsIqw4l+wq8z+Tyqzu/9d+CUWdvZqNFcVeu69cu4f9Zbzcn" + "mTM9L1vlQ2YYDsoaEpoMUt+/NKaCDtszE1tCYueL+pLaFxzMpmiFf+TTNp8Wr/t1r1+P3x" + "h6PT4w+sixjJzPJ+zmWQpCHFJTN+ELR17mKtJ7nsCmapdGZHo4xgnnbfToArKeiwJ1KINe" + "G9WCJLLm8lj1dWelyaPC4RbZcfXp7+AzcBYwM=" +) diff --git a/pyproject.toml b/pyproject.toml index 0827bbb..4e88e98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,9 +37,10 @@ dependencies = [ "ynab>=1.3.0", "ollama>=0.6.1", "twilio>=9.10.2", + "aioboto3>=13.0.0", ] [tool.aerich] -tortoise_orm = "app.TORTOISE_CONFIG" +tortoise_orm = "config.db.TORTOISE_CONFIG" location = "./migrations" src_folder = "./." 
diff --git a/raggr-frontend/src/api/conversationService.ts b/raggr-frontend/src/api/conversationService.ts index f80eb6a..788089e 100644 --- a/raggr-frontend/src/api/conversationService.ts +++ b/raggr-frontend/src/api/conversationService.ts @@ -13,6 +13,7 @@ interface Message { text: string; speaker: "user" | "simba"; created_at: string; + image_key?: string | null; } interface Conversation { @@ -121,17 +122,52 @@ class ConversationService { return await response.json(); } + async uploadImage( + file: File, + conversationId: string, + ): Promise<{ image_key: string; image_url: string }> { + const formData = new FormData(); + formData.append("file", file); + formData.append("conversation_id", conversationId); + + const response = await userService.fetchWithRefreshToken( + `${this.conversationBaseUrl}/upload-image`, + { + method: "POST", + body: formData, + }, + { skipContentType: true }, + ); + + if (!response.ok) { + const data = await response.json(); + throw new Error(data.error || "Failed to upload image"); + } + + return await response.json(); + } + + getImageUrl(imageKey: string): string { + return `/api/conversation/image/${imageKey}`; + } + async streamQuery( query: string, conversation_id: string, onEvent: SSEEventCallback, signal?: AbortSignal, + imageKey?: string, ): Promise { + const body: Record = { query, conversation_id }; + if (imageKey) { + body.image_key = imageKey; + } + const response = await userService.fetchWithRefreshToken( `${this.conversationBaseUrl}/stream-query`, { method: "POST", - body: JSON.stringify({ query, conversation_id }), + body: JSON.stringify(body), signal, }, ); diff --git a/raggr-frontend/src/api/userService.ts b/raggr-frontend/src/api/userService.ts index 8f816c5..52e4f4c 100644 --- a/raggr-frontend/src/api/userService.ts +++ b/raggr-frontend/src/api/userService.ts @@ -106,14 +106,15 @@ class UserService { async fetchWithRefreshToken( url: string, options: RequestInit = {}, + { skipContentType = false }: { skipContentType?: 
boolean } = {}, ): Promise { const refreshToken = localStorage.getItem("refresh_token"); // Add authorization header - const headers = { - "Content-Type": "application/json", - ...(options.headers || {}), - ...(refreshToken && { Authorization: `Bearer ${refreshToken}` }), + const headers: Record = { + ...(skipContentType ? {} : { "Content-Type": "application/json" }), + ...((options.headers as Record) || {}), + ...(refreshToken ? { Authorization: `Bearer ${refreshToken}` } : {}), }; let response = await fetch(url, { ...options, headers }); diff --git a/raggr-frontend/src/components/ChatScreen.tsx b/raggr-frontend/src/components/ChatScreen.tsx index 289a3a2..b1c6216 100644 --- a/raggr-frontend/src/components/ChatScreen.tsx +++ b/raggr-frontend/src/components/ChatScreen.tsx @@ -14,6 +14,7 @@ import catIcon from "../assets/cat.png"; type Message = { text: string; speaker: "simba" | "user" | "tool"; + image_key?: string | null; }; type Conversation = { @@ -55,6 +56,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { const [isLoading, setIsLoading] = useState(false); const [isAdmin, setIsAdmin] = useState(false); const [showAdminPanel, setShowAdminPanel] = useState(false); + const [pendingImage, setPendingImage] = useState(null); const messagesEndRef = useRef(null); const isMountedRef = useRef(true); @@ -80,7 +82,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { try { const fetched = await conversationService.getConversation(conversation.id); setMessages( - fetched.messages.map((m) => ({ text: m.text, speaker: m.speaker })), + fetched.messages.map((m) => ({ text: m.text, speaker: m.speaker, image_key: m.image_key })), ); } catch (err) { console.error("Failed to load messages:", err); @@ -120,7 +122,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { try { const conv = await conversationService.getConversation(selectedConversation.id); setSelectedConversation({ id: conv.id, title: conv.name }); - 
setMessages(conv.messages.map((m) => ({ text: m.text, speaker: m.speaker }))); + setMessages(conv.messages.map((m) => ({ text: m.text, speaker: m.speaker, image_key: m.image_key }))); } catch (err) { console.error("Failed to load messages:", err); } @@ -129,7 +131,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { }, [selectedConversation?.id]); const handleQuestionSubmit = async () => { - if (!query.trim() || isLoading) return; + if ((!query.trim() && !pendingImage) || isLoading) return; let activeConversation = selectedConversation; if (!activeConversation) { @@ -139,9 +141,13 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { setConversations((prev) => [activeConversation!, ...prev]); } + // Capture pending image before clearing state + const imageFile = pendingImage; + const currMessages = messages.concat([{ text: query, speaker: "user" }]); setMessages(currMessages); setQuery(""); + setPendingImage(null); setIsLoading(true); if (simbaMode) { @@ -155,6 +161,29 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { abortControllerRef.current = abortController; try { + // Upload image first if present + let imageKey: string | undefined; + if (imageFile) { + const uploadResult = await conversationService.uploadImage( + imageFile, + activeConversation.id, + ); + imageKey = uploadResult.image_key; + + // Update the user message with the image key + setMessages((prev) => { + const updated = [...prev]; + // Find the last user message we just added + for (let i = updated.length - 1; i >= 0; i--) { + if (updated[i].speaker === "user") { + updated[i] = { ...updated[i], image_key: imageKey }; + break; + } + } + return updated; + }); + } + await conversationService.streamQuery( query, activeConversation.id, @@ -170,6 +199,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { } }, abortController.signal, + imageKey, ); } catch (error) { if (error instanceof Error && error.name === 
"AbortError") { @@ -349,6 +379,9 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { handleQuestionSubmit={handleQuestionSubmit} setSimbaMode={setSimbaMode} isLoading={isLoading} + pendingImage={pendingImage} + onImageSelect={(file) => setPendingImage(file)} + onClearImage={() => setPendingImage(null)} /> @@ -375,7 +408,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => { return ; if (msg.speaker === "simba") return ; - return ; + return ; })} {isLoading && } diff --git a/raggr-frontend/src/components/MessageInput.tsx b/raggr-frontend/src/components/MessageInput.tsx index 1e0afa4..bd9af95 100644 --- a/raggr-frontend/src/components/MessageInput.tsx +++ b/raggr-frontend/src/components/MessageInput.tsx @@ -1,5 +1,5 @@ -import { useState } from "react"; -import { ArrowUp } from "lucide-react"; +import { useRef, useState } from "react"; +import { ArrowUp, ImagePlus, X } from "lucide-react"; import { cn } from "../lib/utils"; import { Textarea } from "./ui/textarea"; @@ -10,6 +10,9 @@ type MessageInputProps = { setSimbaMode: (val: boolean) => void; query: string; isLoading: boolean; + pendingImage: File | null; + onImageSelect: (file: File) => void; + onClearImage: () => void; }; export const MessageInput = ({ @@ -19,8 +22,12 @@ export const MessageInput = ({ handleQuestionSubmit, setSimbaMode, isLoading, + pendingImage, + onImageSelect, + onClearImage, }: MessageInputProps) => { const [simbaMode, setLocalSimbaMode] = useState(false); + const fileInputRef = useRef(null); const toggleSimbaMode = () => { const next = !simbaMode; @@ -28,6 +35,17 @@ export const MessageInput = ({ setSimbaMode(next); }; + const handleFileChange = (e: React.ChangeEvent) => { + const file = e.target.files?.[0]; + if (file) { + onImageSelect(file); + } + // Reset so the same file can be re-selected + e.target.value = ""; + }; + + const canSend = !isLoading && (query.trim() || pendingImage); + return (
+ {/* Image preview */} + {pendingImage && ( +
+
+ Pending upload + +
+
+ )} + {/* Textarea */}