diff --git a/Dockerfile b/Dockerfile index 3d64261..eb5a966 100644 --- a/Dockerfile +++ b/Dockerfile @@ -24,6 +24,8 @@ RUN uv pip install --system -e . # Copy application code COPY *.py ./ COPY blueprints ./blueprints +COPY aerich.toml ./ +COPY migrations ./migrations COPY startup.sh ./ RUN chmod +x startup.sh diff --git a/MIGRATIONS.md b/MIGRATIONS.md new file mode 100644 index 0000000..b125020 --- /dev/null +++ b/MIGRATIONS.md @@ -0,0 +1,54 @@ +# Database Migrations with Aerich + +## Initial Setup (Run Once) + +1. Install dependencies: +```bash +uv pip install -e . +``` + +2. Initialize Aerich: +```bash +aerich init-db +``` + +This will: +- Create a `migrations/` directory +- Generate the initial migration based on your models +- Create all tables in the database + +## When You Add/Change Models + +1. Generate a new migration: +```bash +aerich migrate --name "describe_your_changes" +``` + +Example: +```bash +aerich migrate --name "add_user_profile_model" +``` + +2. Apply the migration: +```bash +aerich upgrade +``` + +## Common Commands + +- `aerich init-db` - Initialize database (first time only) +- `aerich migrate --name "description"` - Generate new migration +- `aerich upgrade` - Apply pending migrations +- `aerich downgrade` - Rollback last migration +- `aerich history` - Show migration history +- `aerich heads` - Show current migration heads + +## Docker Setup + +In Docker, migrations run automatically on container startup via the startup script. + +## Notes + +- Migration files are stored in `migrations/models/` +- Always commit migration files to version control +- Don't modify migration files manually after they're created diff --git a/add_user.py b/add_user.py new file mode 100644 index 0000000..1961003 --- /dev/null +++ b/add_user.py @@ -0,0 +1,130 @@ +# GENERATED BY CLAUDE + +import sys +import uuid +import asyncio +from tortoise import Tortoise +from blueprints.users.models import User + + +async def add_user(username: str, email: str, password: str): + """Add a new user to the database""" + await Tortoise.init( + db_url="sqlite://raggr.db", + modules={ + "models": [ + "blueprints.users.models", + "blueprints.conversation.models", + ] + }, + ) + + try: + # Check if user already exists + existing_user = await User.filter(email=email).first() + if existing_user: + print(f"Error: User with email '{email}' already exists!") + return False + + existing_username = await User.filter(username=username).first() + if existing_username: + print(f"Error: Username '{username}' is already taken!") + return False + + # Create new user + user = User( + id=uuid.uuid4(), + username=username, + email=email, + ) + user.set_password(password) + await user.save() + + print("✓ User created successfully!") + print(f" Username: {username}") + print(f" Email: {email}") + print(f" ID: {user.id}") + return True + + except Exception as e: + print(f"Error creating user: {e}") + return False + finally: + await Tortoise.close_connections() + + +async def list_users(): + """List all users in the database""" + await Tortoise.init( + db_url="sqlite://raggr.db", + modules={ + "models": [ + "blueprints.users.models", + "blueprints.conversation.models", + ] + }, + ) + + try: + users = await User.all() + if not users: + print("No users found in database.") + return + + print(f"\nFound {len(users)} user(s):") + print("-" * 60) + for user in users: + print(f"Username: {user.username}") + print(f"Email: {user.email}") + print(f"ID: {user.id}") + print(f"Created: {user.created_at}") + print("-" * 60) + + except Exception as e: + print(f"Error listing users: {e}") + finally: + await Tortoise.close_connections() + + +def print_usage(): + """Print usage instructions""" + print("Usage:") + print(" python add_user.py add ") + print(" python add_user.py list") + print("\nExamples:") + print(" python add_user.py add ryan ryan@example.com mypassword123") + print(" python add_user.py list") + + +async def main(): + if len(sys.argv) < 2: + print_usage() + sys.exit(1) + + command = sys.argv[1].lower() + + if command == "add": + if len(sys.argv) != 5: + print("Error: Missing arguments for 'add' command") + print_usage() + sys.exit(1) + + username = sys.argv[2] + email = sys.argv[3] + password = sys.argv[4] + + success = await add_user(username, email, password) + sys.exit(0 if success else 1) + + elif command == "list": + await list_users() + sys.exit(0) + + else: + print(f"Error: Unknown command '{command}'") + print_usage() + sys.exit(1) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/aerich_config.py b/aerich_config.py new file mode 100644 index 0000000..d23a0ba --- /dev/null +++ b/aerich_config.py @@ -0,0 +1,15 @@ +import os + +TORTOISE_ORM = { + "connections": {"default": os.getenv("DATABASE_URL", "sqlite:///app/raggr.db")}, + "apps": { + "models": { + "models": [ + "blueprints.conversation.models", + "blueprints.users.models", + "aerich.models", + ], + "default_connection": "default", + }, + }, +} diff --git a/app.py b/app.py index 9997b1f..bbc6ad7 100644 --- a/app.py +++ b/app.py @@ -3,13 +3,14 @@ import os from quart import Quart, request, jsonify, render_template, send_from_directory from tortoise.contrib.quart import register_tortoise -from quart_jwt_extended import JWTManager +from quart_jwt_extended import JWTManager, jwt_refresh_token_required, get_jwt_identity from main import consult_simba_oracle -from blueprints.conversation.logic import ( - get_the_only_conversation, - add_message_to_conversation, -) + +import blueprints.users +import blueprints.conversation +import blueprints.conversation.logic +import blueprints.users.models app = Quart( __name__, @@ -20,12 +21,29 @@ app = Quart( app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY", "SECRET_KEY") jwt = JWTManager(app) +# Register blueprints +app.register_blueprint(blueprints.users.user_blueprint) +app.register_blueprint(blueprints.conversation.conversation_blueprint) + + +TORTOISE_CONFIG = { + "connections": {"default": "sqlite://raggr.db"}, + "apps": { + "models": { + "models": [ + "blueprints.conversation.models", + "blueprints.users.models", + "aerich.models", + ] + }, + }, +} + # Initialize Tortoise ORM register_tortoise( app, - db_url=os.getenv("DATABASE_URL", "sqlite://raggr.db"), - modules={"models": ["blueprints.conversation.models"]}, - generate_schemas=True, + config=TORTOISE_CONFIG, + generate_schemas=False, # Disabled - using Aerich for migrations ) @@ -45,26 +63,41 @@ async def serve_react_app(path): @app.route("/api/query", methods=["POST"]) +@jwt_refresh_token_required async def query(): + current_user_uuid = get_jwt_identity() + user = await blueprints.users.models.User.get(id=current_user_uuid) data = await request.get_json() query = data.get("query") - # add message to database - conversation = await get_the_only_conversation() - print(conversation) - await add_message_to_conversation( - conversation=conversation, message=query, speaker="user" + conversation = await blueprints.conversation.logic.get_conversation_for_user( + user=user + ) + await blueprints.conversation.logic.add_message_to_conversation( + conversation=conversation, + message=query, + speaker="user", + user=user, ) response = consult_simba_oracle(query) - await add_message_to_conversation( - conversation=conversation, message=response, speaker="simba" + await blueprints.conversation.logic.add_message_to_conversation( + conversation=conversation, + message=response, + speaker="simba", + user=user, ) return jsonify({"response": response}) @app.route("/api/messages", methods=["GET"]) +@jwt_refresh_token_required async def get_messages(): - conversation = await get_the_only_conversation() + current_user_uuid = get_jwt_identity() + user = await blueprints.users.models.User.get(id=current_user_uuid) + + conversation = await blueprints.conversation.logic.get_conversation_for_user( + user=user + ) # Prefetch related messages await conversation.fetch_related("messages") @@ -91,12 +124,5 @@ async def get_messages(): ) -# @app.route("/api/ingest", methods=["POST"]) -# def webhook(): -# data = request.get_json() -# print(data) -# return jsonify({"status": "received"}) - - if __name__ == "__main__": app.run(host="0.0.0.0", port=8080, debug=True) diff --git a/blueprints/conversation/logic.py b/blueprints/conversation/logic.py index cad2815..601bef1 100644 --- a/blueprints/conversation/logic.py +++ b/blueprints/conversation/logic.py @@ -1,5 +1,9 @@ +import tortoise.exceptions + from .models import Conversation, ConversationMessage +import blueprints.users.models + async def create_conversation(name: str = "") -> Conversation: conversation = await Conversation.create(name=name) @@ -10,6 +14,7 @@ async def add_message_to_conversation( conversation: Conversation, message: str, speaker: str, + user: blueprints.users.models.User, ) -> ConversationMessage: print(conversation, message, speaker) message = await ConversationMessage.create( @@ -30,3 +35,12 @@ async def get_the_only_conversation() -> Conversation: conversation = await Conversation.create(name="simba_chat") return conversation + + +async def get_conversation_for_user(user: blueprints.users.models.User) -> Conversation: + try: + return await Conversation.get(user=user) + except tortoise.exceptions.DoesNotExist: + await Conversation.get_or_create(name=f"{user.username}'s chat", user=user) + + return await Conversation.get(user=user) diff --git a/blueprints/conversation/models.py b/blueprints/conversation/models.py index 79fcc93..aadaa4c 100644 --- a/blueprints/conversation/models.py +++ b/blueprints/conversation/models.py @@ -18,6 +18,9 @@ class Conversation(Model): name = fields.CharField(max_length=255) created_at = fields.DatetimeField(auto_now_add=True) updated_at = fields.DatetimeField(auto_now=True) + user: fields.ForeignKeyRelation = fields.ForeignKeyField( + "models.User", related_name="conversations", null=True + ) class Meta: table = "conversations" diff --git a/blueprints/users/__init__.py b/blueprints/users/__init__.py new file mode 100644 index 0000000..12944f7 --- /dev/null +++ b/blueprints/users/__init__.py @@ -0,0 +1,40 @@ +from quart import Blueprint, jsonify, request +from quart_jwt_extended import ( + create_access_token, + create_refresh_token, + jwt_refresh_token_required, + get_jwt_identity, +) +from .models import User + + +user_blueprint = Blueprint("user_api", __name__, url_prefix="/api/user") + + +@user_blueprint.route("/login", methods=["POST"]) +async def login(): + data = await request.get_json() + username = data.get("username") + password = data.get("password") + + user = await User.filter(username=username).first() + + if not user or not user.verify_password(password): + return jsonify({"msg": "Invalid credentials"}), 401 + + access_token = create_access_token(identity=str(user.id)) + refresh_token = create_refresh_token(identity=str(user.id)) + + return jsonify( + access_token=access_token, + refresh_token=refresh_token, + user={"id": user.id, "username": user.username}, + ) + + +@user_blueprint.route("/refresh", methods=["POST"]) +@jwt_refresh_token_required +async def refresh(): + user_id = get_jwt_identity() + new_token = create_access_token(identity=user_id) + return jsonify(access_token=new_token) diff --git a/blueprints/users/models.py b/blueprints/users/models.py new file mode 100644 index 0000000..43930f0 --- /dev/null +++ b/blueprints/users/models.py @@ -0,0 +1,26 @@ +from tortoise.models import Model +from tortoise import fields + + +import bcrypt + + +class User(Model): + id = fields.UUIDField(primary_key=True) + username = fields.CharField(max_length=255) + password = fields.BinaryField() # Hashed + email = fields.CharField(max_length=100, unique=True) + created_at = fields.DatetimeField(auto_now_add=True) + updated_at = fields.DatetimeField(auto_now=True) + + class Meta: + table = "users" + + def set_password(self, plain_password: str): + self.password = bcrypt.hashpw( + plain_password.encode("utf-8"), + bcrypt.gensalt(), + ) + + def verify_password(self, plain_password: str): + return bcrypt.checkpw(plain_password.encode("utf-8"), self.password) diff --git a/migrations/models/0_20251025081744_init.py b/migrations/models/0_20251025081744_init.py new file mode 100644 index 0000000..2c09b45 --- /dev/null +++ b/migrations/models/0_20251025081744_init.py @@ -0,0 +1,63 @@ +from tortoise import BaseDBAsyncClient + +RUN_IN_TRANSACTION = True + + +async def upgrade(db: BaseDBAsyncClient) -> str: + return """ + CREATE TABLE IF NOT EXISTS "conversations" ( + "id" CHAR(36) NOT NULL PRIMARY KEY, + "name" VARCHAR(255) NOT NULL, + "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); +CREATE TABLE IF NOT EXISTS "conversation_messages" ( + "id" CHAR(36) NOT NULL PRIMARY KEY, + "text" TEXT NOT NULL, + "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "speaker" VARCHAR(10) NOT NULL /* USER: user\nSIMBA: simba */, + "conversation_id" CHAR(36) NOT NULL REFERENCES "conversations" ("id") ON DELETE CASCADE +); +CREATE TABLE IF NOT EXISTS "users" ( + "id" CHAR(36) NOT NULL PRIMARY KEY, + "username" VARCHAR(255) NOT NULL, + "password" BLOB NOT NULL, + "email" VARCHAR(100) NOT NULL UNIQUE, + "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); +CREATE TABLE IF NOT EXISTS "aerich" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + "version" VARCHAR(255) NOT NULL, + "app" VARCHAR(100) NOT NULL, + "content" JSON NOT NULL +);""" + + +async def downgrade(db: BaseDBAsyncClient) -> str: + return """ + """ + + +MODELS_STATE = ( + "eJztmG1v4jgQx79KlFddaa9q2W53VZ1OCpTecrvACcLdPtwqMskAVhMnazvboorvfrbJE4" + "kJpWq3UPGmhRkPtn8ztv/2nRmEHvjsuBWSn0AZ4jgk5oVxZxIUgPig9b82TBRFuVcaOBr7" + "KsAttFQeNGacIpcL5wT5DITJA+ZSHCWdkdj3pTF0RUNMprkpJvhHDA4Pp8BnQIXj23dhxs" + "SDW2Dp1+jamWDwvZVxY0/2rewOn0fKNhp1Lq9US9nd2HFDPw5I3jqa81lIsuZxjL1jGSN9" + "UyBAEQevMA05ymTaqWk5YmHgNIZsqF5u8GCCYl/CMH+fxMSVDAzVk/xz9oe5BR6BWqLFhE" + "sWd4vlrPI5K6spu2p9sAZHb85fqVmGjE+pcioi5kIFIo6WoYprDlL9r6BszRDVo0zbl2CK" + "gT4EY2rIOeY1lIJMAT2MmhmgW8cHMuUz8bXx9m0Nxn+sgSIpWimUoajrZdX3Eldj6ZNIc4" + "QuBTllB/EqyEvh4TgAPczVyBJSLwk9Tj/sKGAxB69P/HmyCGr42p1ue2hb3b/lTALGfvgK" + "kWW3paehrPOS9ei8lIrsR4x/O/YHQ341vvZ77XLtZ+3sr6YcE4p56JDwxkFeYb2m1hTMSm" + "LjyHtgYlcjD4l91sSqwcuTZHJd2AKlYYzc6xtEPWfFUzgdgTE0BVZNfzOJvPo4AD87NkuJ" + "1hyu3eUv7mbGF2kZp9YivLARrqNXdQWNoGxBRMzbS/qWPdXQ2aBQChDvJ1ScYiIPgmWvBQ" + "uHW812bAurHmXafl8ES9022/5sr+ywqSw56lqfX63ssp/6vT/T5gUZ0/rUbx7Uy0s85Krq" + "hUWAroHqxX2bxIHKakfgQMSFSnYL4c+8dMzRsD24MGIG9D8y7HSb1oXBcDBG5gNuAKcn97" + "gAnJ6s1f/SVVpAxYNmu21eE/qYe/6zblYbtviKHtMDrdK8CingKfkI80r9bpZfO02xoruE" + "maKbTEzoykV8EJMEvlzY1rBlXbbNxXpt+5RKbsSUJKpIN2Wv1WpyaR+02f5rM5nHbR+Uij" + "H7otF+waNShBi7CammMpuYIDrXwyxGlWCO53x5/9k9nDX0mlKwFvWWYNbs9KzBF73mTdsX" + "C7f5xW5bJbwQIOxvU6ZZwOPU6OYl/5gVenpyP9VTJ3uquudwcXiZF4fDs+eLSOy2z55PKQ" + "0toNid6cRh4qmVhyhvszP6sEPWvDdp5aHU9KVqTxL2rIeEemr9rXF69u7s/Zvzs/eiiRpJ" + "ZnlXU/2dnr1BDsrLivYOt/6YLYQcxGAGUi6NLSAmzfcT4NNolZBwIJrz7K9hv7f2bSYNKY" + "EcETHBbx52+WvDx4x/302sNRTlrOsfkstvxqXDSP5AU/eK8yuPl8X/Etg7Fw==" +) diff --git a/migrations/models/1_20251025091926_update.py b/migrations/models/1_20251025091926_update.py new file mode 100644 index 0000000..3194dc7 --- /dev/null +++ b/migrations/models/1_20251025091926_update.py @@ -0,0 +1,60 @@ +from tortoise import BaseDBAsyncClient + +RUN_IN_TRANSACTION = True + + +async def upgrade(db: BaseDBAsyncClient) -> str: + return """ + -- SQLite doesn't support ADD CONSTRAINT, so we need to recreate the table + CREATE TABLE "conversations_new" ( + "id" CHAR(36) NOT NULL PRIMARY KEY, + "name" VARCHAR(255) NOT NULL, + "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "user_id" CHAR(36), + FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON DELETE CASCADE + ); + INSERT INTO "conversations_new" ("id", "name", "created_at", "updated_at") + SELECT "id", "name", "created_at", "updated_at" FROM "conversations"; + DROP TABLE "conversations"; + ALTER TABLE "conversations_new" RENAME TO "conversations";""" + + +async def downgrade(db: BaseDBAsyncClient) -> str: + return """ + -- Recreate table without user_id column + CREATE TABLE "conversations_new" ( + "id" CHAR(36) NOT NULL PRIMARY KEY, + "name" VARCHAR(255) NOT NULL, + "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + INSERT INTO "conversations_new" ("id", "name", "created_at", "updated_at") + SELECT "id", "name", "created_at", "updated_at" FROM "conversations"; + DROP TABLE "conversations"; + ALTER TABLE "conversations_new" RENAME TO "conversations";""" + + +MODELS_STATE = ( + "eJztmWtP2zAUhv9KlE8gbQg6xhCaJqWlbB20ndp0F9gUuYnbWiROiJ1Bhfjvs91cnMRNKe" + "PSon6B9vic2H5s57w+vdU934Eu2Wn4+C8MCaDIx/qRdqtj4EH2Qdn+RtNBEGSt3EDB0BUB" + "tuQpWsCQ0BDYlDWOgEsgMzmQ2CEK4s5w5Lrc6NvMEeFxZoowuoqgRf0xpBMYsoaLP8yMsA" + "NvIEm+BpfWCEHXyY0bObxvYbfoNBC2waB1fCI8eXdDy/bdyMOZdzClEx+n7lGEnB0ew9vG" + "EMMQUOhI0+CjjKedmGYjZgYaRjAdqpMZHDgCkcth6B9HEbY5A030xP/sf9KXwMNQc7QIU8" + "7i9m42q2zOwqrzrhpfjN7Wu4NtMUuf0HEoGgUR/U4EAgpmoYJrBlL8L6FsTECoRpn4F2Cy" + "gT4EY2LIOGZ7KAGZAHoYNd0DN5YL8ZhO2Nfa+/cVGL8bPUGSeQmUPtvXs13fiZtqszaONE" + "Noh5BP2QK0DPKYtVDkQTXMfGQBqROH7iQfVhQwm4PTxe40PgQVfM1Wu9k3jfY3PhOPkCtX" + "IDLMJm+pCeu0YN06KCxF+hDtR8v8ovGv2nm30yzu/dTPPNf5mEBEfQv71xZwpPOaWBMwuY" + "WNAueBC5uP3Czsiy5sPHhpXQkMreUyiBTyH2kkHtszLuLDkwZPvaNLZc7gMMrwTvwQojE+" + "hVOBsMXGAbCtShax6BjEj1lVaJk1G0UIrlM1Im8KNjs2J0hn2dPoN4zjpi4YDoF9eQ1Cx5" + "oD04OEgDEkZaD1OPLktAfdVJqpWcoCrj174mq+VeaxFaz8mi8xytErN3k1r2gBmM3bifvm" + "PVXQWaCCJYj3E8OWvJAbUbzWopjCG0XKN5lVjTLxXxdRXJXKmz/NXBZPpO9W2/i5ncvkZ9" + "3O58RdksqNs259o5Bfo5AqK2QSQHCpEgP8AtnEkVeSArnVlcJf+Ojog36zd6TxjP4b91vt" + "unGkEeQNgX6/Jc7dMvd273HJ3Nude8fkTYUDJCea5V7zitDHfOevqYS1CwWv/5SyxfrZyl" + "JcqGkV22VZbfuUSk7cGRTSLblLzNdq/GhvtNn6azO+jssWLeWYddFoz1C4DAAh136o2Jl1" + "hEE4VcOUowowh1M6u/+sHs4KenUuWGW9xZjVWx2j90uteRN/eePWf5lNo4AXegC5y2zTNO" + "Bx9ujiI/+YO3Rv936qp0r2lHXP5uLwOi8Om9L6q1jYtHJXEoCLyp6l35Efp/a5VvXkJ615" + "GjBE9kRXaOW4pVItg8xnZeRyC88pvynVMsdc2Azxyr9ozhSV57e1vf0P+4fvDvYPmYsYSW" + "r5UPEyaHXMBeqYHwTllXa+6pBCNto4BcmPxhIQY/f1BPg00s3HFGJFev/a73bmlqqSkALI" + "AWYTvHCQTd9oLiL0z2piraDIZ11dVy+W0Au5mT+gripqPWch5u4f/FVgYA==" +) diff --git a/pyproject.toml b/pyproject.toml index d6ab609..02c15fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,24 +4,9 @@ version = "0.1.0" description = "Add your description here" readme = "README.md" requires-python = ">=3.13" -dependencies = [ - "chromadb>=1.1.0", - "python-dotenv>=1.0.0", - "flask>=3.1.2", - "httpx>=0.28.1", - "ollama>=0.6.0", - "openai>=2.0.1", - "pydantic>=2.11.9", - "pillow>=10.0.0", - "pymupdf>=1.24.0", - "black>=25.9.0", - "pillow-heif>=1.1.1", - "flask-jwt-extended>=4.7.1", - "bcrypt>=5.0.0", - "pony>=0.7.19", - "flask-login>=0.6.3", - "quart>=0.20.0", - "tortoise-orm>=0.25.1", - "quart-jwt-extended>=0.1.0", - "pre-commit>=4.3.0", -] +dependencies = ["chromadb>=1.1.0", "python-dotenv>=1.0.0", "flask>=3.1.2", "httpx>=0.28.1", "ollama>=0.6.0", "openai>=2.0.1", "pydantic>=2.11.9", "pillow>=10.0.0", "pymupdf>=1.24.0", "black>=25.9.0", "pillow-heif>=1.1.1", "flask-jwt-extended>=4.7.1", "bcrypt>=5.0.0", "pony>=0.7.19", "flask-login>=0.6.3", "quart>=0.20.0", "tortoise-orm>=0.25.1", "quart-jwt-extended>=0.1.0", "pre-commit>=4.3.0", "tortoise-orm-stubs>=1.0.2", "aerich>=0.8.0", "tomlkit>=0.13.3"] + +[tool.aerich] +tortoise_orm = "app.TORTOISE_CONFIG" +location = "./migrations" +src_folder = "./." diff --git a/raggr-frontend/TOKEN_REFRESH_IMPLEMENTATION.md b/raggr-frontend/TOKEN_REFRESH_IMPLEMENTATION.md new file mode 100644 index 0000000..ed2dfa1 --- /dev/null +++ b/raggr-frontend/TOKEN_REFRESH_IMPLEMENTATION.md @@ -0,0 +1,63 @@ +# Token Refresh Implementation + +## Overview + +The API services now automatically handle token refresh when access tokens expire. This provides a seamless user experience without requiring manual re-authentication. + +## How It Works + +### 1. **userService.ts** + +The `userService` now includes: + +- **`refreshToken()`**: Automatically gets the refresh token from localStorage, calls the `/api/user/refresh` endpoint, and updates the access token +- **`fetchWithAuth()`**: A wrapper around `fetch()` that: + 1. Automatically adds the Authorization header with the access token + 2. Detects 401 (Unauthorized) responses + 3. Automatically refreshes the token using the refresh token + 4. Retries the original request with the new access token + 5. Throws an error if refresh fails (e.g., refresh token expired) + +### 2. **conversationService.ts** + +Now uses `userService.fetchWithAuth()` for all API calls: +- `sendQuery()` - No longer needs token parameter +- `getMessages()` - No longer needs token parameter + +### 3. **Components Updated** + +**ChatScreen.tsx**: +- Removed manual token handling +- Now simply calls `conversationService.sendQuery(query)` and `conversationService.getMessages()` + +## Benefits + +✅ **Automatic token refresh** - Users stay logged in longer +✅ **Transparent retry logic** - Failed requests due to expired tokens are automatically retried +✅ **Cleaner code** - Components don't need to manage tokens +✅ **Better UX** - No interruptions when access token expires +✅ **Centralized auth logic** - All auth handling in one place + +## Error Handling + +- If refresh token is missing or invalid, the error is thrown +- Components can catch these errors and redirect to login +- LocalStorage is automatically cleared when refresh fails + +## Usage Example + +```typescript +// Old way (manual token management) +const token = localStorage.getItem("access_token"); +const result = await conversationService.sendQuery(query, token); + +// New way (automatic token refresh) +const result = await conversationService.sendQuery(query); +``` + +## Token Storage + +- **Access Token**: `localStorage.getItem("access_token")` +- **Refresh Token**: `localStorage.getItem("refresh_token")` + +Both are automatically managed by the services. diff --git a/raggr-frontend/src/.App.tsx.swp b/raggr-frontend/src/.App.tsx.swp deleted file mode 100644 index 421db95..0000000 Binary files a/raggr-frontend/src/.App.tsx.swp and /dev/null differ diff --git a/raggr-frontend/src/App.tsx b/raggr-frontend/src/App.tsx index a660f12..2e0f206 100644 --- a/raggr-frontend/src/App.tsx +++ b/raggr-frontend/src/App.tsx @@ -1,203 +1,71 @@ -import { useEffect, useState } from "react"; -import axios from "axios"; -import ReactMarkdown from "react-markdown"; +import { useState, useEffect } from "react"; import "./App.css"; +import { AuthProvider } from "./contexts/AuthContext"; +import { ChatScreen } from "./components/ChatScreen"; +import { LoginScreen } from "./components/LoginScreen"; +import { conversationService } from "./api/conversationService"; -type QuestionAnswer = { - question: string; - answer: string; -}; +const AppContainer = () => { + const [isAuthenticated, setAuthenticated] = useState(false); + const [isChecking, setIsChecking] = useState(true); -type QuestionBubbleProps = { - text: string; -}; + useEffect(() => { + const checkAuth = async () => { + const accessToken = localStorage.getItem("access_token"); + const refreshToken = localStorage.getItem("refresh_token"); -type AnswerBubbleProps = { - text: string; - loading: string; -}; + // No tokens at all, not authenticated + if (!accessToken && !refreshToken) { + setIsChecking(false); + setAuthenticated(false); + return; + } -type QuestionAnswerPairProps = { - question: string; - answer: string; - loading: boolean; -}; + // Try to verify token by making a request + try { + await conversationService.getMessages(); + // If successful, user is authenticated + setAuthenticated(true); + } catch (error) { + // Token is invalid or expired + console.error("Authentication check failed:", error); + localStorage.removeItem("access_token"); + localStorage.removeItem("refresh_token"); + setAuthenticated(false); + } finally { + setIsChecking(false); + } + }; -type Conversation = { - title: string; - id: string; -}; + checkAuth(); + }, []); -type Message = { - text: string; - speaker: "simba" | "user"; -}; + // Show loading state while checking authentication + if (isChecking) { + return ( +
+
Loading...
+
+ ); + } -type ConversationMenuProps = { - conversations: Conversation[]; -}; - -const ConversationMenu = ({ conversations }: ConversationMenuProps) => { return ( -
-

askSimba!

- {conversations.map((conversation) => ( -

- {conversation.title} -

- ))} -
- ); -}; - -const QuestionBubble = ({ text }: QuestionBubbleProps) => { - return
🤦: {text}
; -}; - -const AnswerBubble = ({ text, loading }: AnswerBubbleProps) => { - return ( -
- {loading ? ( -
-
-
-
-
-
-
-
-
-
+ <> + {isAuthenticated ? ( + ) : ( -
- {"🐈: " + text} -
+ )} -
- ); -}; - -const QuestionAnswerPair = ({ - question, - answer, - loading, -}: QuestionAnswerPairProps) => { - return ( -
- - -
+ ); }; const App = () => { - const [query, setQuery] = useState(""); - const [answer, setAnswer] = useState(""); - const [simbaMode, setSimbaMode] = useState(false); - const [questionsAnswers, setQuestionsAnswers] = useState( - [], - ); - const [messages, setMessages] = useState([]); - const [conversations, setConversations] = useState([ - { title: "simba meow meow", id: "uuid" }, - ]); - - const simbaAnswers = ["meow.", "hiss...", "purrrrrr", "yowOWROWWowowr"]; - - useEffect(() => { - axios.get("/api/messages").then((result) => { - setMessages( - result.data.messages.map((message) => { - return { - text: message.text, - speaker: message.speaker, - }; - }), - ); - }); - }, []); - - const handleQuestionSubmit = () => { - let currMessages = messages.concat([{ text: query, speaker: "user" }]); - setMessages(currMessages); - if (simbaMode) { - console.log("simba mode activated"); - const randomIndex = Math.floor(Math.random() * simbaAnswers.length); - const randomElement = simbaAnswers[randomIndex]; - setAnswer(randomElement); - setQuestionsAnswers( - questionsAnswers.concat([ - { - question: query, - answer: randomElement, - }, - ]), - ); - return; - } - const payload = { query: query }; - axios.post("/api/query", payload).then((result) => { - setQuestionsAnswers( - questionsAnswers.concat([ - { question: query, answer: result.data.response }, - ]), - ); - setMessages( - currMessages.concat([{ text: result.data.response, speaker: "simba" }]), - ); - }); - }; - const handleQueryChange = (event) => { - setQuery(event.target.value); - }; return ( -
-
-
-
-
-

ask simba!

-
- {/*{questionsAnswers.map((qa) => ( - - ))}*/} - {messages.map((msg) => { - if (msg.speaker == "simba") { - return ; - } - - return ; - })} -