Compare commits
38 Commits
data-prepr
...
245db92524
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
245db92524 | ||
|
|
29ac724d50 | ||
|
|
7161c09a4e | ||
|
|
68d73b62e8 | ||
|
|
6b616137d3 | ||
|
|
841b6ebd4f | ||
|
|
45a5e92aee | ||
|
|
8479898cc4 | ||
|
|
acaf681927 | ||
|
|
2bbe33fedc | ||
|
|
b872750444 | ||
|
|
376baccadb | ||
|
|
c978b1a255 | ||
|
|
51b9932389 | ||
|
|
ebf39480b6 | ||
|
|
e4a04331cb | ||
|
|
166ffb4c09 | ||
|
|
64e286e623 | ||
|
|
c6c14729dd | ||
|
|
910097d13b | ||
|
|
0bb3e3172b | ||
|
|
24b30bc8a3 | ||
|
|
3ffc95a1b0 | ||
|
|
c5091dc07a | ||
|
|
c140758560 | ||
|
|
ab3a0eb442 | ||
|
|
c619d78922 | ||
|
|
c20ae0a4b9 | ||
|
|
26cc01b58b | ||
|
|
746b60e070 | ||
|
|
577c9144ac | ||
|
|
2b2891bd79 | ||
|
|
03b033e9a4 | ||
|
|
a640ae5fed | ||
|
|
99c98b7e42 | ||
|
|
a69f7864f3 | ||
|
|
679cfb08e4 | ||
|
|
fc504d3e9c |
16
.dockerignore
Normal file
16
.dockerignore
Normal file
@@ -0,0 +1,16 @@
|
||||
.git
|
||||
.gitignore
|
||||
README.md
|
||||
.env
|
||||
.DS_Store
|
||||
chromadb/
|
||||
chroma_db/
|
||||
raggr-frontend/node_modules/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
.Python
|
||||
.venv/
|
||||
venv/
|
||||
.pytest_cache/
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.13
|
||||
49
Dockerfile
Normal file
49
Dockerfile
Normal file
@@ -0,0 +1,49 @@
|
||||
FROM python:3.13-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies, Node.js, Yarn, and uv
|
||||
RUN apt-get update && apt-get install -y \
|
||||
build-essential \
|
||||
curl \
|
||||
&& curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \
|
||||
&& apt-get install -y nodejs \
|
||||
&& npm install -g yarn \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
# Add uv to PATH
|
||||
ENV PATH="/root/.local/bin:$PATH"
|
||||
|
||||
# Copy dependency files
|
||||
COPY pyproject.toml ./
|
||||
|
||||
# Install Python dependencies using uv
|
||||
RUN uv pip install --system -e .
|
||||
|
||||
# Copy application code
|
||||
COPY *.py ./
|
||||
COPY blueprints ./blueprints
|
||||
COPY aerich.toml ./
|
||||
COPY migrations ./migrations
|
||||
COPY startup.sh ./
|
||||
RUN chmod +x startup.sh
|
||||
|
||||
# Copy frontend code and build
|
||||
COPY raggr-frontend ./raggr-frontend
|
||||
WORKDIR /app/raggr-frontend
|
||||
RUN yarn install && yarn build
|
||||
WORKDIR /app
|
||||
|
||||
# Create ChromaDB directory
|
||||
RUN mkdir -p /app/chromadb
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8080
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONPATH=/app
|
||||
ENV CHROMADB_PATH=/app/chromadb
|
||||
|
||||
# Run the startup script
|
||||
CMD ["./startup.sh"]
|
||||
54
MIGRATIONS.md
Normal file
54
MIGRATIONS.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# Database Migrations with Aerich
|
||||
|
||||
## Initial Setup (Run Once)
|
||||
|
||||
1. Install dependencies:
|
||||
```bash
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
2. Initialize Aerich:
|
||||
```bash
|
||||
aerich init-db
|
||||
```
|
||||
|
||||
This will:
|
||||
- Create a `migrations/` directory
|
||||
- Generate the initial migration based on your models
|
||||
- Create all tables in the database
|
||||
|
||||
## When You Add/Change Models
|
||||
|
||||
1. Generate a new migration:
|
||||
```bash
|
||||
aerich migrate --name "describe_your_changes"
|
||||
```
|
||||
|
||||
Example:
|
||||
```bash
|
||||
aerich migrate --name "add_user_profile_model"
|
||||
```
|
||||
|
||||
2. Apply the migration:
|
||||
```bash
|
||||
aerich upgrade
|
||||
```
|
||||
|
||||
## Common Commands
|
||||
|
||||
- `aerich init-db` - Initialize database (first time only)
|
||||
- `aerich migrate --name "description"` - Generate new migration
|
||||
- `aerich upgrade` - Apply pending migrations
|
||||
- `aerich downgrade` - Rollback last migration
|
||||
- `aerich history` - Show migration history
|
||||
- `aerich heads` - Show current migration heads
|
||||
|
||||
## Docker Setup
|
||||
|
||||
In Docker, migrations run automatically on container startup via the startup script.
|
||||
|
||||
## Notes
|
||||
|
||||
- Migration files are stored in `migrations/models/`
|
||||
- Always commit migration files to version control
|
||||
- Don't modify migration files manually after they're created
|
||||
130
add_user.py
Normal file
130
add_user.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# GENERATED BY CLAUDE
|
||||
|
||||
import sys
|
||||
import uuid
|
||||
import asyncio
|
||||
from tortoise import Tortoise
|
||||
from blueprints.users.models import User
|
||||
|
||||
|
||||
async def add_user(username: str, email: str, password: str):
|
||||
"""Add a new user to the database"""
|
||||
await Tortoise.init(
|
||||
db_url="sqlite://raggr.db",
|
||||
modules={
|
||||
"models": [
|
||||
"blueprints.users.models",
|
||||
"blueprints.conversation.models",
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Check if user already exists
|
||||
existing_user = await User.filter(email=email).first()
|
||||
if existing_user:
|
||||
print(f"Error: User with email '{email}' already exists!")
|
||||
return False
|
||||
|
||||
existing_username = await User.filter(username=username).first()
|
||||
if existing_username:
|
||||
print(f"Error: Username '{username}' is already taken!")
|
||||
return False
|
||||
|
||||
# Create new user
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
username=username,
|
||||
email=email,
|
||||
)
|
||||
user.set_password(password)
|
||||
await user.save()
|
||||
|
||||
print("✓ User created successfully!")
|
||||
print(f" Username: {username}")
|
||||
print(f" Email: {email}")
|
||||
print(f" ID: {user.id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating user: {e}")
|
||||
return False
|
||||
finally:
|
||||
await Tortoise.close_connections()
|
||||
|
||||
|
||||
async def list_users():
|
||||
"""List all users in the database"""
|
||||
await Tortoise.init(
|
||||
db_url="sqlite://raggr.db",
|
||||
modules={
|
||||
"models": [
|
||||
"blueprints.users.models",
|
||||
"blueprints.conversation.models",
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
users = await User.all()
|
||||
if not users:
|
||||
print("No users found in database.")
|
||||
return
|
||||
|
||||
print(f"\nFound {len(users)} user(s):")
|
||||
print("-" * 60)
|
||||
for user in users:
|
||||
print(f"Username: {user.username}")
|
||||
print(f"Email: {user.email}")
|
||||
print(f"ID: {user.id}")
|
||||
print(f"Created: {user.created_at}")
|
||||
print("-" * 60)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error listing users: {e}")
|
||||
finally:
|
||||
await Tortoise.close_connections()
|
||||
|
||||
|
||||
def print_usage():
|
||||
"""Print usage instructions"""
|
||||
print("Usage:")
|
||||
print(" python add_user.py add <username> <email> <password>")
|
||||
print(" python add_user.py list")
|
||||
print("\nExamples:")
|
||||
print(" python add_user.py add ryan ryan@example.com mypassword123")
|
||||
print(" python add_user.py list")
|
||||
|
||||
|
||||
async def main():
|
||||
if len(sys.argv) < 2:
|
||||
print_usage()
|
||||
sys.exit(1)
|
||||
|
||||
command = sys.argv[1].lower()
|
||||
|
||||
if command == "add":
|
||||
if len(sys.argv) != 5:
|
||||
print("Error: Missing arguments for 'add' command")
|
||||
print_usage()
|
||||
sys.exit(1)
|
||||
|
||||
username = sys.argv[2]
|
||||
email = sys.argv[3]
|
||||
password = sys.argv[4]
|
||||
|
||||
success = await add_user(username, email, password)
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
elif command == "list":
|
||||
await list_users()
|
||||
sys.exit(0)
|
||||
|
||||
else:
|
||||
print(f"Error: Unknown command '{command}'")
|
||||
print_usage()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
15
aerich_config.py
Normal file
15
aerich_config.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import os
|
||||
|
||||
TORTOISE_ORM = {
|
||||
"connections": {"default": os.getenv("DATABASE_URL", "sqlite:///app/raggr.db")},
|
||||
"apps": {
|
||||
"models": {
|
||||
"models": [
|
||||
"blueprints.conversation.models",
|
||||
"blueprints.users.models",
|
||||
"aerich.models",
|
||||
],
|
||||
"default_connection": "default",
|
||||
},
|
||||
},
|
||||
}
|
||||
128
app.py
Normal file
128
app.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import os
|
||||
|
||||
from quart import Quart, request, jsonify, render_template, send_from_directory
|
||||
from tortoise.contrib.quart import register_tortoise
|
||||
|
||||
from quart_jwt_extended import JWTManager, jwt_refresh_token_required, get_jwt_identity
|
||||
|
||||
from main import consult_simba_oracle
|
||||
|
||||
import blueprints.users
|
||||
import blueprints.conversation
|
||||
import blueprints.conversation.logic
|
||||
import blueprints.users.models
|
||||
|
||||
app = Quart(
|
||||
__name__,
|
||||
static_folder="raggr-frontend/dist/static",
|
||||
template_folder="raggr-frontend/dist",
|
||||
)
|
||||
|
||||
app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY", "SECRET_KEY")
|
||||
jwt = JWTManager(app)
|
||||
|
||||
# Register blueprints
|
||||
app.register_blueprint(blueprints.users.user_blueprint)
|
||||
app.register_blueprint(blueprints.conversation.conversation_blueprint)
|
||||
|
||||
|
||||
TORTOISE_CONFIG = {
|
||||
"connections": {"default": "sqlite://raggr.db"},
|
||||
"apps": {
|
||||
"models": {
|
||||
"models": [
|
||||
"blueprints.conversation.models",
|
||||
"blueprints.users.models",
|
||||
"aerich.models",
|
||||
]
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Initialize Tortoise ORM
|
||||
register_tortoise(
|
||||
app,
|
||||
config=TORTOISE_CONFIG,
|
||||
generate_schemas=False, # Disabled - using Aerich for migrations
|
||||
)
|
||||
|
||||
|
||||
# Serve React static files
|
||||
@app.route("/static/<path:filename>")
|
||||
async def static_files(filename):
|
||||
return await send_from_directory(app.static_folder, filename)
|
||||
|
||||
|
||||
# Serve the React app for all routes (catch-all)
|
||||
@app.route("/", defaults={"path": ""})
|
||||
@app.route("/<path:path>")
|
||||
async def serve_react_app(path):
|
||||
if path and os.path.exists(os.path.join(app.template_folder, path)):
|
||||
return await send_from_directory(app.template_folder, path)
|
||||
return await render_template("index.html")
|
||||
|
||||
|
||||
@app.route("/api/query", methods=["POST"])
|
||||
@jwt_refresh_token_required
|
||||
async def query():
|
||||
current_user_uuid = get_jwt_identity()
|
||||
user = await blueprints.users.models.User.get(id=current_user_uuid)
|
||||
data = await request.get_json()
|
||||
query = data.get("query")
|
||||
conversation = await blueprints.conversation.logic.get_conversation_for_user(
|
||||
user=user
|
||||
)
|
||||
await blueprints.conversation.logic.add_message_to_conversation(
|
||||
conversation=conversation,
|
||||
message=query,
|
||||
speaker="user",
|
||||
user=user,
|
||||
)
|
||||
|
||||
response = consult_simba_oracle(query)
|
||||
await blueprints.conversation.logic.add_message_to_conversation(
|
||||
conversation=conversation,
|
||||
message=response,
|
||||
speaker="simba",
|
||||
user=user,
|
||||
)
|
||||
return jsonify({"response": response})
|
||||
|
||||
|
||||
@app.route("/api/messages", methods=["GET"])
|
||||
@jwt_refresh_token_required
|
||||
async def get_messages():
|
||||
current_user_uuid = get_jwt_identity()
|
||||
user = await blueprints.users.models.User.get(id=current_user_uuid)
|
||||
|
||||
conversation = await blueprints.conversation.logic.get_conversation_for_user(
|
||||
user=user
|
||||
)
|
||||
# Prefetch related messages
|
||||
await conversation.fetch_related("messages")
|
||||
|
||||
# Manually serialize the conversation with messages
|
||||
messages = []
|
||||
for msg in conversation.messages:
|
||||
messages.append(
|
||||
{
|
||||
"id": str(msg.id),
|
||||
"text": msg.text,
|
||||
"speaker": msg.speaker.value,
|
||||
"created_at": msg.created_at.isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"id": str(conversation.id),
|
||||
"name": conversation.name,
|
||||
"messages": messages,
|
||||
"created_at": conversation.created_at.isoformat(),
|
||||
"updated_at": conversation.updated_at.isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0", port=8080, debug=True)
|
||||
1
blueprints/__init__.py
Normal file
1
blueprints/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Blueprints package
|
||||
17
blueprints/conversation/__init__.py
Normal file
17
blueprints/conversation/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from quart import Blueprint, jsonify
|
||||
from .models import (
|
||||
Conversation,
|
||||
PydConversation,
|
||||
)
|
||||
|
||||
conversation_blueprint = Blueprint(
|
||||
"conversation_api", __name__, url_prefix="/api/conversation"
|
||||
)
|
||||
|
||||
|
||||
@conversation_blueprint.route("/<conversation_id>")
|
||||
async def get_conversation(conversation_id: str):
|
||||
conversation = await Conversation.get(id=conversation_id)
|
||||
serialized_conversation = await PydConversation.from_tortoise_orm(conversation)
|
||||
|
||||
return jsonify(serialized_conversation.model_dump_json())
|
||||
46
blueprints/conversation/logic.py
Normal file
46
blueprints/conversation/logic.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import tortoise.exceptions
|
||||
|
||||
from .models import Conversation, ConversationMessage
|
||||
|
||||
import blueprints.users.models
|
||||
|
||||
|
||||
async def create_conversation(name: str = "") -> Conversation:
|
||||
conversation = await Conversation.create(name=name)
|
||||
return conversation
|
||||
|
||||
|
||||
async def add_message_to_conversation(
|
||||
conversation: Conversation,
|
||||
message: str,
|
||||
speaker: str,
|
||||
user: blueprints.users.models.User,
|
||||
) -> ConversationMessage:
|
||||
print(conversation, message, speaker)
|
||||
message = await ConversationMessage.create(
|
||||
text=message,
|
||||
speaker=speaker,
|
||||
conversation=conversation,
|
||||
)
|
||||
|
||||
return message
|
||||
|
||||
|
||||
async def get_the_only_conversation() -> Conversation:
|
||||
try:
|
||||
conversation = await Conversation.all().first()
|
||||
if conversation is None:
|
||||
conversation = await Conversation.create(name="simba_chat")
|
||||
except Exception as _e:
|
||||
conversation = await Conversation.create(name="simba_chat")
|
||||
|
||||
return conversation
|
||||
|
||||
|
||||
async def get_conversation_for_user(user: blueprints.users.models.User) -> Conversation:
|
||||
try:
|
||||
return await Conversation.get(user=user)
|
||||
except tortoise.exceptions.DoesNotExist:
|
||||
await Conversation.get_or_create(name=f"{user.username}'s chat", user=user)
|
||||
|
||||
return await Conversation.get(user=user)
|
||||
44
blueprints/conversation/models.py
Normal file
44
blueprints/conversation/models.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import enum
|
||||
|
||||
from tortoise.models import Model
|
||||
from tortoise import fields
|
||||
from tortoise.contrib.pydantic import (
|
||||
pydantic_queryset_creator,
|
||||
pydantic_model_creator,
|
||||
)
|
||||
|
||||
|
||||
class Speaker(enum.Enum):
|
||||
USER = "user"
|
||||
SIMBA = "simba"
|
||||
|
||||
|
||||
class Conversation(Model):
|
||||
id = fields.UUIDField(primary_key=True)
|
||||
name = fields.CharField(max_length=255)
|
||||
created_at = fields.DatetimeField(auto_now_add=True)
|
||||
updated_at = fields.DatetimeField(auto_now=True)
|
||||
user: fields.ForeignKeyRelation = fields.ForeignKeyField(
|
||||
"models.User", related_name="conversations", null=True
|
||||
)
|
||||
|
||||
class Meta:
|
||||
table = "conversations"
|
||||
|
||||
|
||||
class ConversationMessage(Model):
|
||||
id = fields.UUIDField(primary_key=True)
|
||||
text = fields.TextField()
|
||||
conversation = fields.ForeignKeyField(
|
||||
"models.Conversation", related_name="messages"
|
||||
)
|
||||
created_at = fields.DatetimeField(auto_now_add=True)
|
||||
speaker = fields.CharEnumField(enum_type=Speaker, max_length=10)
|
||||
|
||||
class Meta:
|
||||
table = "conversation_messages"
|
||||
|
||||
|
||||
PydConversationMessage = pydantic_model_creator(ConversationMessage)
|
||||
PydConversation = pydantic_model_creator(Conversation, name="Conversation")
|
||||
PydListConversationMessage = pydantic_queryset_creator(ConversationMessage)
|
||||
40
blueprints/users/__init__.py
Normal file
40
blueprints/users/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from quart import Blueprint, jsonify, request
|
||||
from quart_jwt_extended import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
jwt_refresh_token_required,
|
||||
get_jwt_identity,
|
||||
)
|
||||
from .models import User
|
||||
|
||||
|
||||
user_blueprint = Blueprint("user_api", __name__, url_prefix="/api/user")
|
||||
|
||||
|
||||
@user_blueprint.route("/login", methods=["POST"])
|
||||
async def login():
|
||||
data = await request.get_json()
|
||||
username = data.get("username")
|
||||
password = data.get("password")
|
||||
|
||||
user = await User.filter(username=username).first()
|
||||
|
||||
if not user or not user.verify_password(password):
|
||||
return jsonify({"msg": "Invalid credentials"}), 401
|
||||
|
||||
access_token = create_access_token(identity=str(user.id))
|
||||
refresh_token = create_refresh_token(identity=str(user.id))
|
||||
|
||||
return jsonify(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
user={"id": user.id, "username": user.username},
|
||||
)
|
||||
|
||||
|
||||
@user_blueprint.route("/refresh", methods=["POST"])
|
||||
@jwt_refresh_token_required
|
||||
async def refresh():
|
||||
user_id = get_jwt_identity()
|
||||
new_token = create_access_token(identity=user_id)
|
||||
return jsonify(access_token=new_token)
|
||||
26
blueprints/users/models.py
Normal file
26
blueprints/users/models.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from tortoise.models import Model
|
||||
from tortoise import fields
|
||||
|
||||
|
||||
import bcrypt
|
||||
|
||||
|
||||
class User(Model):
|
||||
id = fields.UUIDField(primary_key=True)
|
||||
username = fields.CharField(max_length=255)
|
||||
password = fields.BinaryField() # Hashed
|
||||
email = fields.CharField(max_length=100, unique=True)
|
||||
created_at = fields.DatetimeField(auto_now_add=True)
|
||||
updated_at = fields.DatetimeField(auto_now=True)
|
||||
|
||||
class Meta:
|
||||
table = "users"
|
||||
|
||||
def set_password(self, plain_password: str):
|
||||
self.password = bcrypt.hashpw(
|
||||
plain_password.encode("utf-8"),
|
||||
bcrypt.gensalt(),
|
||||
)
|
||||
|
||||
def verify_password(self, plain_password: str):
|
||||
return bcrypt.checkpw(plain_password.encode("utf-8"), self.password)
|
||||
35
chunker.py
35
chunker.py
@@ -1,16 +1,22 @@
|
||||
import os
|
||||
from math import ceil
|
||||
import re
|
||||
from typing import Union
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
from ollama import Client
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from dotenv import load_dotenv
|
||||
from llm import LLMClient
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
ollama_client = Client(
|
||||
host=os.getenv("OLLAMA_HOST", "http://localhost:11434"), timeout=10.0
|
||||
)
|
||||
|
||||
|
||||
def remove_headers_footers(text, header_patterns=None, footer_patterns=None):
|
||||
if header_patterns is None:
|
||||
@@ -79,18 +85,26 @@ class Chunk:
|
||||
|
||||
|
||||
class Chunker:
|
||||
embedding_fx = OllamaEmbeddingFunction(
|
||||
url=os.getenv("OLLAMA_URL", ""),
|
||||
model_name="mxbai-embed-large",
|
||||
)
|
||||
|
||||
def __init__(self, collection) -> None:
|
||||
self.collection = collection
|
||||
self.llm_client = LLMClient()
|
||||
|
||||
def chunk_document(self, document: str, chunk_size: int = 1000) -> list[Chunk]:
|
||||
def embedding_fx(self, inputs):
|
||||
openai_embedding_fx = OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
model_name="text-embedding-3-small",
|
||||
)
|
||||
return openai_embedding_fx(inputs)
|
||||
|
||||
def chunk_document(
|
||||
self,
|
||||
document: str,
|
||||
chunk_size: int = 1000,
|
||||
metadata: dict[str, Union[str, float]] = {},
|
||||
) -> list[Chunk]:
|
||||
doc_uuid = uuid4()
|
||||
|
||||
chunk_size = min(chunk_size, len(document))
|
||||
chunk_size = min(chunk_size, len(document)) or 1
|
||||
|
||||
chunks = []
|
||||
num_chunks = ceil(len(document) / chunk_size)
|
||||
@@ -110,6 +124,7 @@ class Chunker:
|
||||
ids=[str(doc_uuid) + ":" + str(i)],
|
||||
documents=[text_chunk],
|
||||
embeddings=embedding,
|
||||
metadatas=[metadata],
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
165
cleaner.py
Normal file
165
cleaner.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import argparse
|
||||
from dotenv import load_dotenv
|
||||
import ollama
|
||||
from PIL import Image
|
||||
import fitz
|
||||
|
||||
from request import PaperlessNGXService
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Configure ollama client with URL from environment or default to localhost
|
||||
ollama_client = ollama.Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434"))
|
||||
|
||||
parser = argparse.ArgumentParser(description="use llm to clean documents")
|
||||
parser.add_argument("document_id", type=str, help="questions about simba's health")
|
||||
|
||||
|
||||
def pdf_to_image(filepath: str, dpi=300) -> list[str]:
|
||||
"""Returns the filepaths to the created images"""
|
||||
image_temp_files = []
|
||||
try:
|
||||
pdf_document = fitz.open(filepath)
|
||||
print(f"\nConverting '{os.path.basename(filepath)}' to temporary images...")
|
||||
|
||||
for page_num in range(len(pdf_document)):
|
||||
page = pdf_document.load_page(page_num)
|
||||
zoom = dpi / 72
|
||||
mat = fitz.Matrix(zoom, zoom)
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
|
||||
# Create a temporary file for the image. delete=False is crucial.
|
||||
with tempfile.NamedTemporaryFile(
|
||||
delete=False,
|
||||
suffix=".png",
|
||||
prefix=f"pdf_page_{page_num + 1}_",
|
||||
) as temp_image_file:
|
||||
temp_image_path = temp_image_file.name
|
||||
|
||||
# Save the pixel data to the temporary file
|
||||
pix.save(temp_image_path)
|
||||
image_temp_files.append(temp_image_path)
|
||||
print(
|
||||
f" -> Saved page {page_num + 1} to temporary file: '{temp_image_path}'"
|
||||
)
|
||||
|
||||
print("\nConversion successful! ✨")
|
||||
return image_temp_files
|
||||
|
||||
except Exception as e:
|
||||
print(f"An error occurred during PDF conversion: {e}", file=sys.stderr)
|
||||
# Clean up any image files that were created before the error
|
||||
for path in image_temp_files:
|
||||
os.remove(path)
|
||||
return []
|
||||
|
||||
|
||||
def merge_images_vertically_to_tempfile(image_paths):
|
||||
"""
|
||||
Merges a list of images vertically and saves the result to a temporary file.
|
||||
|
||||
Args:
|
||||
image_paths (list): A list of strings, where each string is the
|
||||
filepath to an image.
|
||||
|
||||
Returns:
|
||||
str: The filepath of the temporary merged image file.
|
||||
"""
|
||||
if not image_paths:
|
||||
print("Error: The list of image paths is empty.")
|
||||
return None
|
||||
|
||||
# Open all images and check for consistency
|
||||
try:
|
||||
images = [Image.open(path) for path in image_paths]
|
||||
except FileNotFoundError as e:
|
||||
print(f"Error: Could not find image file: {e}")
|
||||
return None
|
||||
|
||||
widths, heights = zip(*(img.size for img in images))
|
||||
max_width = max(widths)
|
||||
|
||||
# All images must have the same width
|
||||
if not all(width == max_width for width in widths):
|
||||
print("Warning: Images have different widths. They will be resized.")
|
||||
resized_images = []
|
||||
for img in images:
|
||||
if img.size[0] != max_width:
|
||||
img = img.resize(
|
||||
(max_width, int(img.size[1] * (max_width / img.size[0])))
|
||||
)
|
||||
resized_images.append(img)
|
||||
images = resized_images
|
||||
heights = [img.size[1] for img in images]
|
||||
|
||||
# Calculate the total height of the merged image
|
||||
total_height = sum(heights)
|
||||
|
||||
# Create a new blank image with the combined dimensions
|
||||
merged_image = Image.new("RGB", (max_width, total_height))
|
||||
|
||||
# Paste each image onto the new blank image
|
||||
y_offset = 0
|
||||
for img in images:
|
||||
merged_image.paste(img, (0, y_offset))
|
||||
y_offset += img.height
|
||||
|
||||
# Create a temporary file and save the image
|
||||
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
||||
temp_path = temp_file.name
|
||||
merged_image.save(temp_path)
|
||||
temp_file.close()
|
||||
|
||||
print(f"Successfully merged {len(images)} images into temporary file: {temp_path}")
|
||||
return temp_path
|
||||
|
||||
|
||||
OCR_PROMPT = """
|
||||
You job is to extract text from the images I provide you. Extract every bit of the text in the image. Don't say anything just do your job. Text should be same as in the images. If there are multiple images, categorize the transcriptions by page.
|
||||
|
||||
Things to avoid:
|
||||
- Don't miss anything to extract from the images
|
||||
|
||||
Things to include:
|
||||
- Include everything, even anything inside [], (), {} or anything.
|
||||
- Include any repetitive things like "..." or anything
|
||||
- If you think there is any mistake in image just include it too
|
||||
|
||||
Someone will kill the innocent kittens if you don't extract the text exactly. So, make sure you extract every bit of the text. Only output the extracted text.
|
||||
"""
|
||||
|
||||
|
||||
def summarize_pdf_image(filepaths: list[str]):
|
||||
res = ollama_client.chat(
|
||||
model="gemma3:4b",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": OCR_PROMPT,
|
||||
"images": filepaths,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
return res["message"]["content"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
ppngx = PaperlessNGXService()
|
||||
|
||||
if args.document_id:
|
||||
doc_id = args.document_id
|
||||
file = ppngx.get_doc_by_id(doc_id=doc_id)
|
||||
pdf_path = ppngx.download_pdf_from_id(doc_id)
|
||||
print(pdf_path)
|
||||
image_paths = pdf_to_image(filepath=pdf_path)
|
||||
summary = summarize_pdf_image(filepaths=image_paths)
|
||||
print(summary)
|
||||
file["content"] = summary
|
||||
print(file)
|
||||
ppngx.upload_cleaned_content(doc_id, file)
|
||||
17
docker-compose.yml
Normal file
17
docker-compose.yml
Normal file
@@ -0,0 +1,17 @@
|
||||
version: "3.8"
|
||||
|
||||
services:
|
||||
raggr:
|
||||
image: torrtle/simbarag:latest
|
||||
network_mode: host
|
||||
environment:
|
||||
- PAPERLESS_TOKEN=${PAPERLESS_TOKEN}
|
||||
- BASE_URL=${BASE_URL}
|
||||
- OLLAMA_URL=${OLLAMA_URL:-http://localhost:11434}
|
||||
- CHROMADB_PATH=/app/chromadb
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
volumes:
|
||||
- chromadb_data:/app/chromadb
|
||||
|
||||
volumes:
|
||||
chromadb_data:
|
||||
83
image_process.py
Normal file
83
image_process.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from ollama import Client
|
||||
import argparse
|
||||
import os
|
||||
import logging
|
||||
from PIL import Image, ExifTags
|
||||
from pillow_heif import register_heif_opener
|
||||
from pydantic import BaseModel
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
register_heif_opener()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="SimbaImageProcessor",
|
||||
description="What the program does",
|
||||
epilog="Text at the bottom of help",
|
||||
)
|
||||
|
||||
parser.add_argument("filepath")
|
||||
|
||||
client = Client(host=os.getenv("OLLAMA_HOST", "http://localhost:11434"))
|
||||
|
||||
|
||||
class SimbaImageDescription(BaseModel):
|
||||
image_date: str
|
||||
description: str
|
||||
|
||||
|
||||
def describe_simba_image(input):
|
||||
logging.info("Opening image of Simba ...")
|
||||
if "heic" in input.lower() or "heif" in input.lower():
|
||||
new_filepath = input.split(".")[0] + ".jpg"
|
||||
img = Image.open(input)
|
||||
img.save(new_filepath, "JPEG")
|
||||
logging.info("Extracting EXIF...")
|
||||
exif = {
|
||||
ExifTags.TAGS[k]: v for k, v in img.getexif().items() if k in ExifTags.TAGS
|
||||
}
|
||||
img = Image.open(new_filepath)
|
||||
input = new_filepath
|
||||
else:
|
||||
img = Image.open(input)
|
||||
|
||||
logging.info("Extracting EXIF...")
|
||||
exif = {
|
||||
ExifTags.TAGS[k]: v for k, v in img.getexif().items() if k in ExifTags.TAGS
|
||||
}
|
||||
|
||||
if "MakerNote" in exif:
|
||||
exif.pop("MakerNote")
|
||||
|
||||
logging.info(exif)
|
||||
|
||||
prompt = f"Simba is an orange cat belonging to Ryan Chen. In 2025, they lived in New York. In 2024, they lived in California. Analyze the following image and tell me what Simba seems to be doing. Be extremely descriptive about Simba, things in the background, and the setting of the image. I will also include the EXIF data of the image, please use it to help you determine information about Simba. EXIF: {exif}. Put the notes in the description field and the date in the image_date field."
|
||||
|
||||
logging.info("Sending info to Ollama ...")
|
||||
response = client.chat(
|
||||
model="gemma3:4b",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "you are a very shrewd and descriptive note taker. all of your responses will be formatted like notes in bullet points. be very descriptive. do not leave a single thing out.",
|
||||
},
|
||||
{"role": "user", "content": prompt, "images": [input]},
|
||||
],
|
||||
format=SimbaImageDescription.model_json_schema(),
|
||||
)
|
||||
|
||||
result = SimbaImageDescription.model_validate_json(response["message"]["content"])
|
||||
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
if args.filepath:
|
||||
logging.info
|
||||
describe_simba_image(input=args.filepath)
|
||||
115
index_immich.py
Normal file
115
index_immich.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import httpx
|
||||
import os
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import tempfile
|
||||
|
||||
from image_process import describe_simba_image
|
||||
from request import PaperlessNGXService
|
||||
import sqlite3
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Configuration from environment variables
|
||||
IMMICH_URL = os.getenv("IMMICH_URL", "http://localhost:2283")
|
||||
API_KEY = os.getenv("IMMICH_API_KEY")
|
||||
PERSON_NAME = os.getenv("PERSON_NAME", "Simba") # Name of the tagged person/pet
|
||||
DOWNLOAD_DIR = os.getenv("DOWNLOAD_DIR", "./simba_photos")
|
||||
|
||||
# Set up headers
|
||||
headers = {"x-api-key": API_KEY, "Content-Type": "application/json"}
|
||||
|
||||
VISITED = {}
|
||||
|
||||
if __name__ == "__main__":
|
||||
conn = sqlite3.connect("./visited.db")
|
||||
c = conn.cursor()
|
||||
c.execute("select immich_id from visited")
|
||||
rows = c.fetchall()
|
||||
for row in rows:
|
||||
VISITED.add(row[0])
|
||||
|
||||
ppngx = PaperlessNGXService()
|
||||
people_url = f"{IMMICH_URL}/api/search/person?name=Simba"
|
||||
people = httpx.get(people_url, headers=headers).json()
|
||||
|
||||
simba_id = people[0]["id"]
|
||||
|
||||
ids = {}
|
||||
|
||||
asset_search = f"{IMMICH_URL}/api/search/smart"
|
||||
request_body = {"query": "orange cat"}
|
||||
results = httpx.post(asset_search, headers=headers, json=request_body)
|
||||
|
||||
assets = results.json()["assets"]
|
||||
for asset in assets["items"]:
|
||||
if asset["type"] == "IMAGE" and asset["id"] not in VISITED:
|
||||
ids[asset["id"]] = asset.get("originalFileName")
|
||||
nextPage = assets.get("nextPage")
|
||||
|
||||
# while nextPage != None:
|
||||
# logging.info(f"next page: {nextPage}")
|
||||
# request_body["page"] = nextPage
|
||||
# results = httpx.post(asset_search, headers=headers, json=request_body)
|
||||
# assets = results.json()["assets"]
|
||||
|
||||
# for asset in assets["items"]:
|
||||
# if asset["type"] == "IMAGE":
|
||||
# ids.add(asset['id'])
|
||||
|
||||
# nextPage = assets.get("nextPage")
|
||||
|
||||
asset_search = f"{IMMICH_URL}/api/search/smart"
|
||||
request_body = {"query": "simba"}
|
||||
results = httpx.post(asset_search, headers=headers, json=request_body)
|
||||
for asset in results.json()["assets"]["items"]:
|
||||
if asset["type"] == "IMAGE":
|
||||
ids[asset["id"]] = asset.get("originalFileName")
|
||||
|
||||
for immich_asset_id, immich_filename in ids.items():
|
||||
try:
|
||||
response = httpx.get(
|
||||
f"{IMMICH_URL}/api/assets/{immich_asset_id}/original", headers=headers
|
||||
)
|
||||
|
||||
path = os.path.join("/Users/ryanchen/Programs/raggr", immich_filename)
|
||||
file = open(path, "wb+")
|
||||
for chunk in response.iter_bytes(chunk_size=8192):
|
||||
file.write(chunk)
|
||||
|
||||
logging.info("Processing image ...")
|
||||
description = describe_simba_image(path)
|
||||
|
||||
image_description = description.description
|
||||
image_date = description.image_date
|
||||
|
||||
description_filepath = os.path.join(
|
||||
"/Users/ryanchen/Programs/raggr", f"SIMBA_DESCRIBE_001.txt"
|
||||
)
|
||||
file = open(description_filepath, "w+")
|
||||
file.write(image_description)
|
||||
file.close()
|
||||
|
||||
file = open(description_filepath, "rb")
|
||||
ppngx.upload_description(
|
||||
description_filepath=description_filepath,
|
||||
file=file,
|
||||
title="SIMBA_DESCRIBE_001.txt",
|
||||
exif_date=image_date,
|
||||
)
|
||||
file.close()
|
||||
|
||||
c.execute("INSERT INTO visited (immich_id) values (?)", (immich_asset_id,))
|
||||
conn.commit()
|
||||
logging.info("Processing complete. Deleting file.")
|
||||
os.remove(file.name)
|
||||
except Exception as e:
|
||||
logging.info(f"something went wrong for {immich_filename}")
|
||||
logging.info(e)
|
||||
|
||||
conn.close()
|
||||
73
llm.py
Normal file
73
llm.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import os
|
||||
|
||||
from ollama import Client
|
||||
from openai import OpenAI
|
||||
|
||||
import logging
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
TRY_OLLAMA = os.getenv("TRY_OLLAMA", False)
|
||||
|
||||
|
||||
class LLMClient:
|
||||
def __init__(self):
|
||||
try:
|
||||
self.ollama_client = Client(
|
||||
host=os.getenv("OLLAMA_URL", "http://localhost:11434"), timeout=10.0
|
||||
)
|
||||
self.ollama_client.chat(
|
||||
model="gemma3:4b", messages=[{"role": "system", "content": "test"}]
|
||||
)
|
||||
self.PROVIDER = "ollama"
|
||||
logging.info("Using Ollama as LLM backend")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
self.openai_client = OpenAI()
|
||||
self.PROVIDER = "openai"
|
||||
logging.info("Using OpenAI as LLM backend")
|
||||
|
||||
def chat(
|
||||
self,
|
||||
prompt: str,
|
||||
system_prompt: str,
|
||||
):
|
||||
# Instituting a fallback if my gaming PC is not on
|
||||
if self.PROVIDER == "ollama":
|
||||
try:
|
||||
response = self.ollama_client.chat(
|
||||
model="gemma3:4b",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
)
|
||||
output = response.message.content
|
||||
return output
|
||||
except Exception as e:
|
||||
logging.error(f"Could not connect to OLLAMA: {str(e)}")
|
||||
|
||||
response = self.openai_client.responses.create(
|
||||
model="gpt-4o-mini",
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
)
|
||||
output = response.output_text
|
||||
|
||||
return output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
client = Client()
|
||||
client.chat(model="gemma3:4b", messages=[{"role": "system", "promp": "hack"}])
|
||||
201
main.py
201
main.py
@@ -1,5 +1,7 @@
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
|
||||
import argparse
|
||||
import chromadb
|
||||
@@ -8,14 +10,22 @@ import ollama
|
||||
|
||||
from request import PaperlessNGXService
|
||||
from chunker import Chunker
|
||||
from cleaner import pdf_to_image, summarize_pdf_image
|
||||
from llm import LLMClient
|
||||
from query import QueryGenerator
|
||||
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
_dotenv_loaded = load_dotenv()
|
||||
|
||||
# Configure ollama client with URL from environment or default to localhost
|
||||
ollama_client = ollama.Client(
|
||||
host=os.getenv("OLLAMA_URL", "http://localhost:11434"), timeout=10.0
|
||||
)
|
||||
|
||||
client = chromadb.PersistentClient(path=os.getenv("CHROMADB_PATH", ""))
|
||||
simba_docs = client.get_or_create_collection(name="simba_docs")
|
||||
simba_docs = client.get_or_create_collection(name="simba_docs2")
|
||||
feline_vet_lookup = client.get_or_create_collection(name="feline_vet_lookup")
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
@@ -26,59 +36,204 @@ parser.add_argument("query", type=str, help="questions about simba's health")
|
||||
parser.add_argument(
|
||||
"--reindex", action="store_true", help="re-index the simba documents"
|
||||
)
|
||||
parser.add_argument("--index", help="index a file")
|
||||
|
||||
ppngx = PaperlessNGXService()
|
||||
|
||||
llm_client = LLMClient()
|
||||
|
||||
|
||||
def chunk_data(texts: list[str], collection):
|
||||
def index_using_pdf_llm(doctypes):
|
||||
logging.info("reindex data...")
|
||||
files = ppngx.get_data()
|
||||
for file in files:
|
||||
document_id: int = file["id"]
|
||||
pdf_path = ppngx.download_pdf_from_id(id=document_id)
|
||||
image_paths = pdf_to_image(filepath=pdf_path)
|
||||
logging.info(f"summarizing {file}")
|
||||
generated_summary = summarize_pdf_image(filepaths=image_paths)
|
||||
file["content"] = generated_summary
|
||||
|
||||
chunk_data(files, simba_docs, doctypes=doctypes)
|
||||
|
||||
|
||||
def date_to_epoch(date_str: str) -> float:
|
||||
split_date = date_str.split("-")
|
||||
date = datetime.datetime(
|
||||
int(split_date[0]),
|
||||
int(split_date[1]),
|
||||
int(split_date[2]),
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
|
||||
return date.timestamp()
|
||||
|
||||
|
||||
def chunk_data(docs, collection, doctypes):
|
||||
# Step 2: Create chunks
|
||||
chunker = Chunker(collection)
|
||||
|
||||
print(f"chunking {len(texts)} documents")
|
||||
for text in texts:
|
||||
chunker.chunk_document(document=text)
|
||||
logging.info(f"chunking {len(docs)} documents")
|
||||
texts: list[str] = [doc["content"] for doc in docs]
|
||||
with sqlite3.connect("visited.db") as conn:
|
||||
to_insert = []
|
||||
c = conn.cursor()
|
||||
for index, text in enumerate(texts):
|
||||
metadata = {
|
||||
"created_date": date_to_epoch(docs[index]["created_date"]),
|
||||
"filename": docs[index]["original_file_name"],
|
||||
"document_type": doctypes.get(docs[index]["document_type"], ""),
|
||||
}
|
||||
|
||||
if doctypes:
|
||||
metadata["type"] = doctypes.get(docs[index]["document_type"])
|
||||
|
||||
chunker.chunk_document(
|
||||
document=text,
|
||||
metadata=metadata,
|
||||
)
|
||||
to_insert.append((docs[index]["id"],))
|
||||
|
||||
c.executemany(
|
||||
"INSERT INTO indexed_documents (paperless_id) values (?)", to_insert
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def chunk_text(texts: list[str], collection):
|
||||
chunker = Chunker(collection)
|
||||
|
||||
for index, text in enumerate(texts):
|
||||
metadata = {}
|
||||
chunker.chunk_document(
|
||||
document=text,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
def consult_oracle(input: str, collection):
|
||||
# Ask
|
||||
embeddings = Chunker.embedding_fx(input=[input])
|
||||
results = collection.query(query_texts=[input], query_embeddings=embeddings)
|
||||
import time
|
||||
|
||||
# Generate
|
||||
output = ollama.generate(
|
||||
model="gemma3n:e4b",
|
||||
prompt=f"You are a helpful assistant that understandings veterinary terms. Using the following data, help answer the user's query by providing as many details as possible. Using this data: {results}. Respond to this prompt: {input}",
|
||||
chunker = Chunker(collection)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Ask
|
||||
logging.info("Starting query generation")
|
||||
qg_start = time.time()
|
||||
qg = QueryGenerator()
|
||||
doctype_query = qg.get_doctype_query(input=input)
|
||||
# metadata_filter = qg.get_query(input)
|
||||
metadata_filter = {**doctype_query}
|
||||
logging.info(metadata_filter)
|
||||
qg_end = time.time()
|
||||
logging.info(f"Query generation took {qg_end - qg_start:.2f} seconds")
|
||||
|
||||
logging.info("Starting embedding generation")
|
||||
embedding_start = time.time()
|
||||
embeddings = chunker.embedding_fx(inputs=[input])
|
||||
embedding_end = time.time()
|
||||
logging.info(
|
||||
f"Embedding generation took {embedding_end - embedding_start:.2f} seconds"
|
||||
)
|
||||
|
||||
print(output["response"])
|
||||
logging.info("Starting collection query")
|
||||
query_start = time.time()
|
||||
results = collection.query(
|
||||
query_texts=[input],
|
||||
query_embeddings=embeddings,
|
||||
where=metadata_filter,
|
||||
)
|
||||
query_end = time.time()
|
||||
logging.info(f"Collection query took {query_end - query_start:.2f} seconds")
|
||||
|
||||
# Generate
|
||||
logging.info("Starting LLM generation")
|
||||
llm_start = time.time()
|
||||
system_prompt = "You are a helpful assistant that understands veterinary terms."
|
||||
prompt = f"Using the following data, help answer the user's query by providing as many details as possible. Using this data: {results}. Respond to this prompt: {input}"
|
||||
output = llm_client.chat(prompt=prompt, system_prompt=system_prompt)
|
||||
llm_end = time.time()
|
||||
logging.info(f"LLM generation took {llm_end - llm_start:.2f} seconds")
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logging.info(f"Total consult_oracle execution took {total_time:.2f} seconds")
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def paperless_workflow(input):
|
||||
# Step 1: Get the text
|
||||
ppngx = PaperlessNGXService()
|
||||
docs = ppngx.get_data()
|
||||
texts = [doc["content"] for doc in docs]
|
||||
|
||||
chunk_data(texts, collection=simba_docs)
|
||||
chunk_data(docs, collection=simba_docs)
|
||||
consult_oracle(input, simba_docs)
|
||||
|
||||
|
||||
def consult_simba_oracle(input: str):
|
||||
return consult_oracle(
|
||||
input=input,
|
||||
collection=simba_docs,
|
||||
)
|
||||
|
||||
|
||||
def filter_indexed_files(docs):
|
||||
with sqlite3.connect("visited.db") as conn:
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
"CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)"
|
||||
)
|
||||
c.execute("SELECT paperless_id FROM indexed_documents")
|
||||
rows = c.fetchall()
|
||||
conn.commit()
|
||||
|
||||
visited = {row[0] for row in rows}
|
||||
return [doc for doc in docs if doc["id"] not in visited]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
if args.reindex:
|
||||
logging.info(msg="Fetching documents from Paperless-NGX")
|
||||
logging.info("Fetching documents from Paperless-NGX")
|
||||
ppngx = PaperlessNGXService()
|
||||
docs = ppngx.get_data()
|
||||
texts = [doc["content"] for doc in docs]
|
||||
logging.info(msg=f"Fetched {len(texts)} documents")
|
||||
docs = filter_indexed_files(docs)
|
||||
logging.info(f"Fetched {len(docs)} documents")
|
||||
|
||||
logging.info(msg="Chunking documents now ...")
|
||||
chunk_data(texts, collection=simba_docs)
|
||||
logging.info(msg="Done chunking documents")
|
||||
# Delete all chromadb data
|
||||
ids = simba_docs.get(ids=None, limit=None, offset=0)
|
||||
all_ids = ids["ids"]
|
||||
if len(all_ids) > 0:
|
||||
simba_docs.delete(ids=all_ids)
|
||||
|
||||
# Chunk documents
|
||||
logging.info("Chunking documents now ...")
|
||||
tag_lookup = ppngx.get_tags()
|
||||
doctype_lookup = ppngx.get_doctypes()
|
||||
chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
|
||||
logging.info("Done chunking documents")
|
||||
|
||||
# if args.index:
|
||||
# with open(args.index) as file:
|
||||
# extension = args.index.split(".")[-1]
|
||||
# if extension == "pdf":
|
||||
# pdf_path = ppngx.download_pdf_from_id(id=document_id)
|
||||
# image_paths = pdf_to_image(filepath=pdf_path)
|
||||
# print(f"summarizing {file}")
|
||||
# generated_summary = summarize_pdf_image(filepaths=image_paths)
|
||||
# elif extension in [".md", ".txt"]:
|
||||
# chunk_text(texts=[file.readall()], collection=simba_docs)
|
||||
|
||||
if args.query:
|
||||
logging.info("Consulting oracle ...")
|
||||
print(
|
||||
consult_oracle(
|
||||
input=args.query,
|
||||
collection=simba_docs,
|
||||
)
|
||||
)
|
||||
else:
|
||||
print("please provide a query")
|
||||
logging.info("please provide a query")
|
||||
|
||||
63
migrations/models/0_20251025081744_init.py
Normal file
63
migrations/models/0_20251025081744_init.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from tortoise import BaseDBAsyncClient
|
||||
|
||||
RUN_IN_TRANSACTION = True
|
||||
|
||||
|
||||
async def upgrade(db: BaseDBAsyncClient) -> str:
|
||||
return """
|
||||
CREATE TABLE IF NOT EXISTS "conversations" (
|
||||
"id" CHAR(36) NOT NULL PRIMARY KEY,
|
||||
"name" VARCHAR(255) NOT NULL,
|
||||
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS "conversation_messages" (
|
||||
"id" CHAR(36) NOT NULL PRIMARY KEY,
|
||||
"text" TEXT NOT NULL,
|
||||
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"speaker" VARCHAR(10) NOT NULL /* USER: user\nSIMBA: simba */,
|
||||
"conversation_id" CHAR(36) NOT NULL REFERENCES "conversations" ("id") ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS "users" (
|
||||
"id" CHAR(36) NOT NULL PRIMARY KEY,
|
||||
"username" VARCHAR(255) NOT NULL,
|
||||
"password" BLOB NOT NULL,
|
||||
"email" VARCHAR(100) NOT NULL UNIQUE,
|
||||
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS "aerich" (
|
||||
"id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
|
||||
"version" VARCHAR(255) NOT NULL,
|
||||
"app" VARCHAR(100) NOT NULL,
|
||||
"content" JSON NOT NULL
|
||||
);"""
|
||||
|
||||
|
||||
async def downgrade(db: BaseDBAsyncClient) -> str:
|
||||
return """
|
||||
"""
|
||||
|
||||
|
||||
MODELS_STATE = (
|
||||
"eJztmG1v4jgQx79KlFddaa9q2W53VZ1OCpTecrvACcLdPtwqMskAVhMnazvboorvfrbJE4"
|
||||
"kJpWq3UPGmhRkPtn8ztv/2nRmEHvjsuBWSn0AZ4jgk5oVxZxIUgPig9b82TBRFuVcaOBr7"
|
||||
"KsAttFQeNGacIpcL5wT5DITJA+ZSHCWdkdj3pTF0RUNMprkpJvhHDA4Pp8BnQIXj23dhxs"
|
||||
"SDW2Dp1+jamWDwvZVxY0/2rewOn0fKNhp1Lq9US9nd2HFDPw5I3jqa81lIsuZxjL1jGSN9"
|
||||
"UyBAEQevMA05ymTaqWk5YmHgNIZsqF5u8GCCYl/CMH+fxMSVDAzVk/xz9oe5BR6BWqLFhE"
|
||||
"sWd4vlrPI5K6spu2p9sAZHb85fqVmGjE+pcioi5kIFIo6WoYprDlL9r6BszRDVo0zbl2CK"
|
||||
"gT4EY2rIOeY1lIJMAT2MmhmgW8cHMuUz8bXx9m0Nxn+sgSIpWimUoajrZdX3Eldj6ZNIc4"
|
||||
"QuBTllB/EqyEvh4TgAPczVyBJSLwk9Tj/sKGAxB69P/HmyCGr42p1ue2hb3b/lTALGfvgK"
|
||||
"kWW3paehrPOS9ei8lIrsR4x/O/YHQ341vvZ77XLtZ+3sr6YcE4p56JDwxkFeYb2m1hTMSm"
|
||||
"LjyHtgYlcjD4l91sSqwcuTZHJd2AKlYYzc6xtEPWfFUzgdgTE0BVZNfzOJvPo4AD87NkuJ"
|
||||
"1hyu3eUv7mbGF2kZp9YivLARrqNXdQWNoGxBRMzbS/qWPdXQ2aBQChDvJ1ScYiIPgmWvBQ"
|
||||
"uHW812bAurHmXafl8ES9022/5sr+ywqSw56lqfX63ssp/6vT/T5gUZ0/rUbx7Uy0s85Krq"
|
||||
"hUWAroHqxX2bxIHKakfgQMSFSnYL4c+8dMzRsD24MGIG9D8y7HSb1oXBcDBG5gNuAKcn97"
|
||||
"gAnJ6s1f/SVVpAxYNmu21eE/qYe/6zblYbtviKHtMDrdK8CingKfkI80r9bpZfO02xoruE"
|
||||
"maKbTEzoykV8EJMEvlzY1rBlXbbNxXpt+5RKbsSUJKpIN2Wv1WpyaR+02f5rM5nHbR+Uij"
|
||||
"H7otF+waNShBi7CammMpuYIDrXwyxGlWCO53x5/9k9nDX0mlKwFvWWYNbs9KzBF73mTdsX"
|
||||
"C7f5xW5bJbwQIOxvU6ZZwOPU6OYl/5gVenpyP9VTJ3uquudwcXiZF4fDs+eLSOy2z55PKQ"
|
||||
"0toNid6cRh4qmVhyhvszP6sEPWvDdp5aHU9KVqTxL2rIeEemr9rXF69u7s/Zvzs/eiiRpJ"
|
||||
"ZnlXU/2dnr1BDsrLivYOt/6YLYQcxGAGUi6NLSAmzfcT4NNolZBwIJrz7K9hv7f2bSYNKY"
|
||||
"EcETHBbx52+WvDx4x/302sNRTlrOsfkstvxqXDSP5AU/eK8yuPl8X/Etg7Fw=="
|
||||
)
|
||||
60
migrations/models/1_20251025091926_update.py
Normal file
60
migrations/models/1_20251025091926_update.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from tortoise import BaseDBAsyncClient
|
||||
|
||||
RUN_IN_TRANSACTION = True
|
||||
|
||||
|
||||
async def upgrade(db: BaseDBAsyncClient) -> str:
|
||||
return """
|
||||
-- SQLite doesn't support ADD CONSTRAINT, so we need to recreate the table
|
||||
CREATE TABLE "conversations_new" (
|
||||
"id" CHAR(36) NOT NULL PRIMARY KEY,
|
||||
"name" VARCHAR(255) NOT NULL,
|
||||
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"user_id" CHAR(36),
|
||||
FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON DELETE CASCADE
|
||||
);
|
||||
INSERT INTO "conversations_new" ("id", "name", "created_at", "updated_at")
|
||||
SELECT "id", "name", "created_at", "updated_at" FROM "conversations";
|
||||
DROP TABLE "conversations";
|
||||
ALTER TABLE "conversations_new" RENAME TO "conversations";"""
|
||||
|
||||
|
||||
async def downgrade(db: BaseDBAsyncClient) -> str:
|
||||
return """
|
||||
-- Recreate table without user_id column
|
||||
CREATE TABLE "conversations_new" (
|
||||
"id" CHAR(36) NOT NULL PRIMARY KEY,
|
||||
"name" VARCHAR(255) NOT NULL,
|
||||
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
INSERT INTO "conversations_new" ("id", "name", "created_at", "updated_at")
|
||||
SELECT "id", "name", "created_at", "updated_at" FROM "conversations";
|
||||
DROP TABLE "conversations";
|
||||
ALTER TABLE "conversations_new" RENAME TO "conversations";"""
|
||||
|
||||
|
||||
MODELS_STATE = (
|
||||
"eJztmWtP2zAUhv9KlE8gbQg6xhCaJqWlbB20ndp0F9gUuYnbWiROiJ1Bhfjvs91cnMRNKe"
|
||||
"PSon6B9vic2H5s57w+vdU934Eu2Wn4+C8MCaDIx/qRdqtj4EH2Qdn+RtNBEGSt3EDB0BUB"
|
||||
"tuQpWsCQ0BDYlDWOgEsgMzmQ2CEK4s5w5Lrc6NvMEeFxZoowuoqgRf0xpBMYsoaLP8yMsA"
|
||||
"NvIEm+BpfWCEHXyY0bObxvYbfoNBC2waB1fCI8eXdDy/bdyMOZdzClEx+n7lGEnB0ew9vG"
|
||||
"EMMQUOhI0+CjjKedmGYjZgYaRjAdqpMZHDgCkcth6B9HEbY5A030xP/sf9KXwMNQc7QIU8"
|
||||
"7i9m42q2zOwqrzrhpfjN7Wu4NtMUuf0HEoGgUR/U4EAgpmoYJrBlL8L6FsTECoRpn4F2Cy"
|
||||
"gT4EY2LIOGZ7KAGZAHoYNd0DN5YL8ZhO2Nfa+/cVGL8bPUGSeQmUPtvXs13fiZtqszaONE"
|
||||
"Noh5BP2QK0DPKYtVDkQTXMfGQBqROH7iQfVhQwm4PTxe40PgQVfM1Wu9k3jfY3PhOPkCtX"
|
||||
"IDLMJm+pCeu0YN06KCxF+hDtR8v8ovGv2nm30yzu/dTPPNf5mEBEfQv71xZwpPOaWBMwuY"
|
||||
"WNAueBC5uP3Czsiy5sPHhpXQkMreUyiBTyH2kkHtszLuLDkwZPvaNLZc7gMMrwTvwQojE+"
|
||||
"hVOBsMXGAbCtShax6BjEj1lVaJk1G0UIrlM1Im8KNjs2J0hn2dPoN4zjpi4YDoF9eQ1Cx5"
|
||||
"oD04OEgDEkZaD1OPLktAfdVJqpWcoCrj174mq+VeaxFaz8mi8xytErN3k1r2gBmM3bifvm"
|
||||
"PVXQWaCCJYj3E8OWvJAbUbzWopjCG0XKN5lVjTLxXxdRXJXKmz/NXBZPpO9W2/i5ncvkZ9"
|
||||
"3O58RdksqNs259o5Bfo5AqK2QSQHCpEgP8AtnEkVeSArnVlcJf+Ojog36zd6TxjP4b91vt"
|
||||
"unGkEeQNgX6/Jc7dMvd273HJ3Nude8fkTYUDJCea5V7zitDHfOevqYS1CwWv/5SyxfrZyl"
|
||||
"JcqGkV22VZbfuUSk7cGRTSLblLzNdq/GhvtNn6azO+jssWLeWYddFoz1C4DAAh136o2Jl1"
|
||||
"hEE4VcOUowowh1M6u/+sHs4KenUuWGW9xZjVWx2j90uteRN/eePWf5lNo4AXegC5y2zTNO"
|
||||
"Bx9ujiI/+YO3Rv936qp0r2lHXP5uLwOi8Om9L6q1jYtHJXEoCLyp6l35Efp/a5VvXkJ615"
|
||||
"GjBE9kRXaOW4pVItg8xnZeRyC88pvynVMsdc2Azxyr9ozhSV57e1vf0P+4fvDvYPmYsYSW"
|
||||
"r5UPEyaHXMBeqYHwTllXa+6pBCNto4BcmPxhIQY/f1BPg00s3HFGJFev/a73bmlqqSkALI"
|
||||
"AWYTvHCQTd9oLiL0z2piraDIZ11dVy+W0Au5mT+gripqPWch5u4f/FVgYA=="
|
||||
)
|
||||
@@ -4,4 +4,9 @@ version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = []
|
||||
dependencies = ["chromadb>=1.1.0", "python-dotenv>=1.0.0", "flask>=3.1.2", "httpx>=0.28.1", "ollama>=0.6.0", "openai>=2.0.1", "pydantic>=2.11.9", "pillow>=10.0.0", "pymupdf>=1.24.0", "black>=25.9.0", "pillow-heif>=1.1.1", "flask-jwt-extended>=4.7.1", "bcrypt>=5.0.0", "pony>=0.7.19", "flask-login>=0.6.3", "quart>=0.20.0", "tortoise-orm>=0.25.1", "quart-jwt-extended>=0.1.0", "pre-commit>=4.3.0", "tortoise-orm-stubs>=1.0.2", "aerich>=0.8.0", "tomlkit>=0.13.3"]
|
||||
|
||||
[tool.aerich]
|
||||
tortoise_orm = "app.TORTOISE_CONFIG"
|
||||
location = "./migrations"
|
||||
src_folder = "./."
|
||||
|
||||
138
query.py
138
query.py
@@ -1,10 +1,18 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Literal
|
||||
import datetime
|
||||
from ollama import Client
|
||||
|
||||
from ollama import chat, ChatResponse
|
||||
from openai import OpenAI
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Configure ollama client with URL from environment or default to localhost
|
||||
ollama_client = Client(
|
||||
host=os.getenv("OLLAMA_URL", "http://localhost:11434"), timeout=10.0
|
||||
)
|
||||
|
||||
# This uses inferred filters — which means using LLM to create the metadata filters
|
||||
|
||||
|
||||
@@ -28,12 +36,31 @@ class GeneratedQuery(BaseModel):
|
||||
extracted_metadata_fields: str
|
||||
|
||||
|
||||
PROMPT = """
|
||||
You are an information specialist that processes user queries. The user queries are all about
|
||||
a cat, Simba, and its records. The types of records are listed below. Using the query, extract the
|
||||
type of record the user is trying to query and the date range the user is trying to query.
|
||||
class Time(BaseModel):
|
||||
time: int
|
||||
|
||||
|
||||
DOCTYPE_OPTIONS = [
|
||||
"Bill",
|
||||
"Image Description",
|
||||
"Insurance",
|
||||
"Medical Record",
|
||||
"Documentation",
|
||||
"Letter",
|
||||
]
|
||||
|
||||
|
||||
class DocumentType(BaseModel):
|
||||
type: list[str] = Field(description="type of document", enum=DOCTYPE_OPTIONS)
|
||||
|
||||
|
||||
PROMPT = """
|
||||
You are an information specialist that processes user queries. The current year is 2025. The user queries are all about
|
||||
a cat, Simba, and its records. The types of records are listed below. Using the query, extract the
|
||||
the date range the user is trying to query. You should return it as a JSON. The date tag is created_date. Return the date in epoch time.
|
||||
|
||||
If the created_date cannot be ascertained, set it to epoch time start.
|
||||
|
||||
You have several operators at your disposal:
|
||||
- $gt: greater than
|
||||
- $gte: greater than or equal
|
||||
@@ -49,18 +76,18 @@ Logical operators:
|
||||
|
||||
### Example 1
|
||||
Query: "Who is Simba's current vet?"
|
||||
Metadata fields: "{"created_date, tags"}"
|
||||
Extracted metadata fields: {"$and": [{"created_date: {"$gt": "2025-01-01"}, "tags": {"$in": ["bill", "medical records", "aftercare"]}}]}
|
||||
Metadata fields: "{"created_date"}"
|
||||
Extracted metadata fields: {"created_date: {"$gt": "2025-01-01"}}
|
||||
|
||||
### Example 2
|
||||
Query: "How many teeth has Simba had removed?"
|
||||
Metadata fields: {"tags"}
|
||||
Extracted metadata fields: {"tags": "medical records"}
|
||||
Metadata fields: {}
|
||||
Extracted metadata fields: {}
|
||||
|
||||
### Example 3
|
||||
Query: "How many times has Simba been to the vet this year?"
|
||||
Metadata fields: {"tags", "created_date"}
|
||||
Extracted metadata fields: {"$and": [{"created_date": {"gt": "2025-01-01"}, "tags": {"$in": ["bill"]}}]}
|
||||
Metadata fields: {"created_date"}
|
||||
Extracted metadata fields: {"created_date": {"gt": "2025-01-01"}}
|
||||
|
||||
document_types:
|
||||
- aftercare
|
||||
@@ -72,27 +99,96 @@ Only return the extracted metadata fields. Make sure the extracted metadata fiel
|
||||
"""
|
||||
|
||||
|
||||
DOCTYPE_PROMPT = f"""You are an information specialist that processes user queries. A query can have two tags attached from the following options. Based on the query, determine which of the following options is most appropriate: {",".join(DOCTYPE_OPTIONS)}
|
||||
|
||||
### Example 1
|
||||
Query: "Who is Simba's current vet?"
|
||||
Tags: ["Bill", "Medical Record"]
|
||||
|
||||
|
||||
### Example 2
|
||||
Query: "Who does Simba know?"
|
||||
Tags: ["Letter", "Documentation"]
|
||||
"""
|
||||
|
||||
|
||||
class QueryGenerator:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def get_query(self, input: str):
|
||||
response: ChatResponse = chat(
|
||||
model="gemma3n:e4b",
|
||||
def date_to_epoch(self, date_str: str) -> float:
|
||||
split_date = date_str.split("-")
|
||||
date = datetime.datetime(
|
||||
int(split_date[0]),
|
||||
int(split_date[1]),
|
||||
int(split_date[2]),
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
|
||||
return date.timestamp()
|
||||
|
||||
def get_doctype_query(self, input: str):
|
||||
client = OpenAI()
|
||||
response = client.chat.completions.create(
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are an information specialist that is really good at deciding what tags a query should have",
|
||||
},
|
||||
{"role": "user", "content": DOCTYPE_PROMPT + " " + input},
|
||||
],
|
||||
model="gpt-4o",
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "document_type",
|
||||
"schema": DocumentType.model_json_schema(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
response_json_str = response.choices[0].message.content
|
||||
type_data = json.loads(response_json_str)
|
||||
metadata_query = {"document_type": {"$in": type_data["type"]}}
|
||||
return metadata_query
|
||||
|
||||
def get_query(self, input: str):
|
||||
client = OpenAI()
|
||||
response = client.responses.parse(
|
||||
model="gpt-4o",
|
||||
input=[
|
||||
{"role": "system", "content": PROMPT},
|
||||
{"role": "user", "content": input},
|
||||
],
|
||||
format=GeneratedQuery.model_json_schema(),
|
||||
text_format=GeneratedQuery,
|
||||
)
|
||||
print(response.output)
|
||||
query = json.loads(response.output_parsed.extracted_metadata_fields)
|
||||
# response: ChatResponse = ollama_client.chat(
|
||||
# model="gemma3n:e4b",
|
||||
# messages=[
|
||||
# {"role": "system", "content": PROMPT},
|
||||
# {"role": "user", "content": input},
|
||||
# ],
|
||||
# format=GeneratedQuery.model_json_schema(),
|
||||
# )
|
||||
|
||||
print(
|
||||
json.loads(
|
||||
json.loads(response["message"]["content"])["extracted_metadata_fields"]
|
||||
)
|
||||
)
|
||||
# query = json.loads(
|
||||
# json.loads(response["message"]["content"])["extracted_metadata_fields"]
|
||||
# )
|
||||
# date_key = list(query["created_date"].keys())[0]
|
||||
# query["created_date"][date_key] = self.date_to_epoch(
|
||||
# query["created_date"][date_key]
|
||||
# )
|
||||
|
||||
# if "$" not in date_key:
|
||||
# query["created_date"]["$" + date_key] = query["created_date"][date_key]
|
||||
|
||||
return query
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
qg = QueryGenerator()
|
||||
qg.get_query("How old is Simba?")
|
||||
print(qg.get_doctype_query("How heavy is Simba?"))
|
||||
|
||||
16
raggr-frontend/.gitignore
vendored
Normal file
16
raggr-frontend/.gitignore
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
# Local
|
||||
.DS_Store
|
||||
*.local
|
||||
*.log*
|
||||
|
||||
# Dist
|
||||
node_modules
|
||||
dist/
|
||||
|
||||
# Profile
|
||||
.rspack-profile-*/
|
||||
|
||||
# IDE
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
36
raggr-frontend/README.md
Normal file
36
raggr-frontend/README.md
Normal file
@@ -0,0 +1,36 @@
|
||||
# Rsbuild project
|
||||
|
||||
## Setup
|
||||
|
||||
Install the dependencies:
|
||||
|
||||
```bash
|
||||
pnpm install
|
||||
```
|
||||
|
||||
## Get started
|
||||
|
||||
Start the dev server, and the app will be available at [http://localhost:3000](http://localhost:3000).
|
||||
|
||||
```bash
|
||||
pnpm dev
|
||||
```
|
||||
|
||||
Build the app for production:
|
||||
|
||||
```bash
|
||||
pnpm build
|
||||
```
|
||||
|
||||
Preview the production build locally:
|
||||
|
||||
```bash
|
||||
pnpm preview
|
||||
```
|
||||
|
||||
## Learn more
|
||||
|
||||
To learn more about Rsbuild, check out the following resources:
|
||||
|
||||
- [Rsbuild documentation](https://rsbuild.rs) - explore Rsbuild features and APIs.
|
||||
- [Rsbuild GitHub repository](https://github.com/web-infra-dev/rsbuild) - your feedback and contributions are welcome!
|
||||
63
raggr-frontend/TOKEN_REFRESH_IMPLEMENTATION.md
Normal file
63
raggr-frontend/TOKEN_REFRESH_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,63 @@
|
||||
# Token Refresh Implementation
|
||||
|
||||
## Overview
|
||||
|
||||
The API services now automatically handle token refresh when access tokens expire. This provides a seamless user experience without requiring manual re-authentication.
|
||||
|
||||
## How It Works
|
||||
|
||||
### 1. **userService.ts**
|
||||
|
||||
The `userService` now includes:
|
||||
|
||||
- **`refreshToken()`**: Automatically gets the refresh token from localStorage, calls the `/api/user/refresh` endpoint, and updates the access token
|
||||
- **`fetchWithAuth()`**: A wrapper around `fetch()` that:
|
||||
1. Automatically adds the Authorization header with the access token
|
||||
2. Detects 401 (Unauthorized) responses
|
||||
3. Automatically refreshes the token using the refresh token
|
||||
4. Retries the original request with the new access token
|
||||
5. Throws an error if refresh fails (e.g., refresh token expired)
|
||||
|
||||
### 2. **conversationService.ts**
|
||||
|
||||
Now uses `userService.fetchWithAuth()` for all API calls:
|
||||
- `sendQuery()` - No longer needs token parameter
|
||||
- `getMessages()` - No longer needs token parameter
|
||||
|
||||
### 3. **Components Updated**
|
||||
|
||||
**ChatScreen.tsx**:
|
||||
- Removed manual token handling
|
||||
- Now simply calls `conversationService.sendQuery(query)` and `conversationService.getMessages()`
|
||||
|
||||
## Benefits
|
||||
|
||||
✅ **Automatic token refresh** - Users stay logged in longer
|
||||
✅ **Transparent retry logic** - Failed requests due to expired tokens are automatically retried
|
||||
✅ **Cleaner code** - Components don't need to manage tokens
|
||||
✅ **Better UX** - No interruptions when access token expires
|
||||
✅ **Centralized auth logic** - All auth handling in one place
|
||||
|
||||
## Error Handling
|
||||
|
||||
- If refresh token is missing or invalid, the error is thrown
|
||||
- Components can catch these errors and redirect to login
|
||||
- LocalStorage is automatically cleared when refresh fails
|
||||
|
||||
## Usage Example
|
||||
|
||||
```typescript
|
||||
// Old way (manual token management)
|
||||
const token = localStorage.getItem("access_token");
|
||||
const result = await conversationService.sendQuery(query, token);
|
||||
|
||||
// New way (automatic token refresh)
|
||||
const result = await conversationService.sendQuery(query);
|
||||
```
|
||||
|
||||
## Token Storage
|
||||
|
||||
- **Access Token**: `localStorage.getItem("access_token")`
|
||||
- **Refresh Token**: `localStorage.getItem("refresh_token")`
|
||||
|
||||
Both are automatically managed by the services.
|
||||
26
raggr-frontend/package.json
Normal file
26
raggr-frontend/package.json
Normal file
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"name": "raggr-frontend",
|
||||
"version": "1.0.0",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"build": "rsbuild build",
|
||||
"dev": "rsbuild dev --open",
|
||||
"preview": "rsbuild preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"axios": "^1.12.2",
|
||||
"marked": "^16.3.0",
|
||||
"react": "^19.1.1",
|
||||
"react-dom": "^19.1.1",
|
||||
"react-markdown": "^10.1.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@rsbuild/core": "^1.5.6",
|
||||
"@rsbuild/plugin-react": "^1.4.0",
|
||||
"@tailwindcss/postcss": "^4.0.0",
|
||||
"@types/react": "^19.1.13",
|
||||
"@types/react-dom": "^19.1.9",
|
||||
"typescript": "^5.9.2"
|
||||
}
|
||||
}
|
||||
5
raggr-frontend/postcss.config.mjs
Normal file
5
raggr-frontend/postcss.config.mjs
Normal file
@@ -0,0 +1,5 @@
|
||||
export default {
|
||||
plugins: {
|
||||
"@tailwindcss/postcss": {},
|
||||
},
|
||||
};
|
||||
6
raggr-frontend/rsbuild.config.ts
Normal file
6
raggr-frontend/rsbuild.config.ts
Normal file
@@ -0,0 +1,6 @@
|
||||
import { defineConfig } from '@rsbuild/core';
|
||||
import { pluginReact } from '@rsbuild/plugin-react';
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [pluginReact()],
|
||||
});
|
||||
6
raggr-frontend/src/App.css
Normal file
6
raggr-frontend/src/App.css
Normal file
@@ -0,0 +1,6 @@
|
||||
@import "tailwindcss";
|
||||
|
||||
body {
|
||||
margin: 0;
|
||||
font-family: Inter, Avenir, Helvetica, Arial, sans-serif;
|
||||
}
|
||||
72
raggr-frontend/src/App.tsx
Normal file
72
raggr-frontend/src/App.tsx
Normal file
@@ -0,0 +1,72 @@
|
||||
import { useState, useEffect } from "react";
|
||||
|
||||
import "./App.css";
|
||||
import { AuthProvider } from "./contexts/AuthContext";
|
||||
import { ChatScreen } from "./components/ChatScreen";
|
||||
import { LoginScreen } from "./components/LoginScreen";
|
||||
import { conversationService } from "./api/conversationService";
|
||||
|
||||
const AppContainer = () => {
|
||||
const [isAuthenticated, setAuthenticated] = useState<boolean>(false);
|
||||
const [isChecking, setIsChecking] = useState<boolean>(true);
|
||||
|
||||
useEffect(() => {
|
||||
const checkAuth = async () => {
|
||||
const accessToken = localStorage.getItem("access_token");
|
||||
const refreshToken = localStorage.getItem("refresh_token");
|
||||
|
||||
// No tokens at all, not authenticated
|
||||
if (!accessToken && !refreshToken) {
|
||||
setIsChecking(false);
|
||||
setAuthenticated(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to verify token by making a request
|
||||
try {
|
||||
await conversationService.getMessages();
|
||||
// If successful, user is authenticated
|
||||
setAuthenticated(true);
|
||||
} catch (error) {
|
||||
// Token is invalid or expired
|
||||
console.error("Authentication check failed:", error);
|
||||
localStorage.removeItem("access_token");
|
||||
localStorage.removeItem("refresh_token");
|
||||
setAuthenticated(false);
|
||||
} finally {
|
||||
setIsChecking(false);
|
||||
}
|
||||
};
|
||||
|
||||
checkAuth();
|
||||
}, []);
|
||||
|
||||
// Show loading state while checking authentication
|
||||
if (isChecking) {
|
||||
return (
|
||||
<div className="h-screen flex items-center justify-center bg-white/85">
|
||||
<div className="text-xl">Loading...</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
{isAuthenticated ? (
|
||||
<ChatScreen setAuthenticated={setAuthenticated} />
|
||||
) : (
|
||||
<LoginScreen setAuthenticated={setAuthenticated} />
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const App = () => {
|
||||
return (
|
||||
<AuthProvider>
|
||||
<AppContainer />
|
||||
</AuthProvider>
|
||||
);
|
||||
};
|
||||
|
||||
export default App;
|
||||
61
raggr-frontend/src/api/conversationService.ts
Normal file
61
raggr-frontend/src/api/conversationService.ts
Normal file
@@ -0,0 +1,61 @@
|
||||
import { userService } from "./userService";
|
||||
|
||||
interface Message {
|
||||
id: string;
|
||||
text: string;
|
||||
speaker: "user" | "simba";
|
||||
created_at: string;
|
||||
}
|
||||
|
||||
interface Conversation {
|
||||
id: string;
|
||||
name: string;
|
||||
messages: Message[];
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
}
|
||||
|
||||
interface QueryRequest {
|
||||
query: string;
|
||||
}
|
||||
|
||||
interface QueryResponse {
|
||||
response: string;
|
||||
}
|
||||
|
||||
class ConversationService {
|
||||
private baseUrl = "/api";
|
||||
|
||||
async sendQuery(query: string): Promise<QueryResponse> {
|
||||
const response = await userService.fetchWithRefreshToken(
|
||||
`${this.baseUrl}/query`,
|
||||
{
|
||||
method: "POST",
|
||||
body: JSON.stringify({ query }),
|
||||
},
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to send query");
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
async getMessages(): Promise<Conversation> {
|
||||
const response = await userService.fetchWithRefreshToken(
|
||||
`${this.baseUrl}/messages`,
|
||||
{
|
||||
method: "GET",
|
||||
},
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to fetch messages");
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
}
|
||||
}
|
||||
|
||||
export const conversationService = new ConversationService();
|
||||
123
raggr-frontend/src/api/userService.ts
Normal file
123
raggr-frontend/src/api/userService.ts
Normal file
@@ -0,0 +1,123 @@
|
||||
interface LoginResponse {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
user: {
|
||||
id: string;
|
||||
username: string;
|
||||
};
|
||||
}
|
||||
|
||||
interface RefreshResponse {
|
||||
access_token: string;
|
||||
}
|
||||
|
||||
class UserService {
|
||||
private baseUrl = "/api/user";
|
||||
|
||||
async login(username: string, password: string): Promise<LoginResponse> {
|
||||
const response = await fetch(`${this.baseUrl}/login`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ username, password }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error("Invalid credentials");
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
async refreshToken(): Promise<string> {
|
||||
const refreshToken = localStorage.getItem("refresh_token");
|
||||
|
||||
if (!refreshToken) {
|
||||
throw new Error("No refresh token available");
|
||||
}
|
||||
|
||||
const response = await fetch(`${this.baseUrl}/refresh`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${refreshToken}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
// Refresh token is invalid or expired, clear storage
|
||||
localStorage.removeItem("access_token");
|
||||
localStorage.removeItem("refresh_token");
|
||||
throw new Error("Failed to refresh token");
|
||||
}
|
||||
|
||||
const data: RefreshResponse = await response.json();
|
||||
localStorage.setItem("access_token", data.access_token);
|
||||
return data.access_token;
|
||||
}
|
||||
|
||||
async fetchWithAuth(
|
||||
url: string,
|
||||
options: RequestInit = {},
|
||||
): Promise<Response> {
|
||||
const accessToken = localStorage.getItem("access_token");
|
||||
|
||||
// Add authorization header
|
||||
const headers = {
|
||||
"Content-Type": "application/json",
|
||||
...(options.headers || {}),
|
||||
...(accessToken && { Authorization: `Bearer ${accessToken}` }),
|
||||
};
|
||||
|
||||
let response = await fetch(url, { ...options, headers });
|
||||
|
||||
// If unauthorized, try refreshing the token
|
||||
if (response.status === 401) {
|
||||
try {
|
||||
const newAccessToken = await this.refreshToken();
|
||||
|
||||
// Retry the request with new token
|
||||
headers.Authorization = `Bearer ${newAccessToken}`;
|
||||
response = await fetch(url, { ...options, headers });
|
||||
} catch (error) {
|
||||
// Refresh failed, redirect to login or throw error
|
||||
throw new Error("Session expired. Please log in again.");
|
||||
}
|
||||
}
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
async fetchWithRefreshToken(
|
||||
url: string,
|
||||
options: RequestInit = {},
|
||||
): Promise<Response> {
|
||||
const refreshToken = localStorage.getItem("refresh_token");
|
||||
|
||||
// Add authorization header
|
||||
const headers = {
|
||||
"Content-Type": "application/json",
|
||||
...(options.headers || {}),
|
||||
...(refreshToken && { Authorization: `Bearer ${refreshToken}` }),
|
||||
};
|
||||
|
||||
let response = await fetch(url, { ...options, headers });
|
||||
|
||||
// If unauthorized, try refreshing the token
|
||||
if (response.status === 401) {
|
||||
try {
|
||||
const newAccessToken = await this.refreshToken();
|
||||
|
||||
// Retry the request with new token
|
||||
headers.Authorization = `Bearer ${newAccessToken}`;
|
||||
response = await fetch(url, { ...options, headers });
|
||||
} catch (error) {
|
||||
// Refresh failed, redirect to login or throw error
|
||||
throw new Error("Session expired. Please log in again.");
|
||||
}
|
||||
}
|
||||
|
||||
return response;
|
||||
}
|
||||
}
|
||||
|
||||
export const userService = new UserService();
|
||||
29
raggr-frontend/src/components/AnswerBubble.tsx
Normal file
29
raggr-frontend/src/components/AnswerBubble.tsx
Normal file
@@ -0,0 +1,29 @@
|
||||
import ReactMarkdown from "react-markdown";
|
||||
|
||||
type AnswerBubbleProps = {
|
||||
text: string;
|
||||
loading?: boolean;
|
||||
};
|
||||
|
||||
export const AnswerBubble = ({ text, loading }: AnswerBubbleProps) => {
|
||||
return (
|
||||
<div className="rounded-md bg-orange-100 p-3">
|
||||
{loading ? (
|
||||
<div className="flex flex-col w-full animate-pulse gap-2">
|
||||
<div className="flex flex-row gap-2 w-full">
|
||||
<div className="bg-gray-400 w-1/2 p-3 rounded-lg" />
|
||||
<div className="bg-gray-400 w-1/2 p-3 rounded-lg" />
|
||||
</div>
|
||||
<div className="flex flex-row gap-2 w-full">
|
||||
<div className="bg-gray-400 w-1/3 p-3 rounded-lg" />
|
||||
<div className="bg-gray-400 w-2/3 p-3 rounded-lg" />
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex flex-col">
|
||||
<ReactMarkdown>{"🐈: " + text}</ReactMarkdown>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
150
raggr-frontend/src/components/ChatScreen.tsx
Normal file
150
raggr-frontend/src/components/ChatScreen.tsx
Normal file
@@ -0,0 +1,150 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import { conversationService } from "../api/conversationService";
|
||||
import { QuestionBubble } from "./QuestionBubble";
|
||||
import { AnswerBubble } from "./AnswerBubble";
|
||||
|
||||
type Message = {
|
||||
text: string;
|
||||
speaker: "simba" | "user";
|
||||
};
|
||||
|
||||
type QuestionAnswer = {
|
||||
question: string;
|
||||
answer: string;
|
||||
};
|
||||
|
||||
type Conversation = {
|
||||
title: string;
|
||||
id: string;
|
||||
};
|
||||
|
||||
type ChatScreenProps = {
|
||||
setAuthenticated: (isAuth: boolean) => void;
|
||||
};
|
||||
|
||||
export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
||||
const [query, setQuery] = useState<string>("");
|
||||
const [answer, setAnswer] = useState<string>("");
|
||||
const [simbaMode, setSimbaMode] = useState<boolean>(false);
|
||||
const [questionsAnswers, setQuestionsAnswers] = useState<QuestionAnswer[]>(
|
||||
[],
|
||||
);
|
||||
const [messages, setMessages] = useState<Message[]>([]);
|
||||
const [conversations, setConversations] = useState<Conversation[]>([
|
||||
{ title: "simba meow meow", id: "uuid" },
|
||||
]);
|
||||
|
||||
const simbaAnswers = ["meow.", "hiss...", "purrrrrr", "yowOWROWWowowr"];
|
||||
|
||||
useEffect(() => {
|
||||
const loadMessages = async () => {
|
||||
try {
|
||||
const conversation = await conversationService.getMessages();
|
||||
setMessages(
|
||||
conversation.messages.map((message) => ({
|
||||
text: message.text,
|
||||
speaker: message.speaker,
|
||||
})),
|
||||
);
|
||||
} catch (error) {
|
||||
console.error("Failed to load messages:", error);
|
||||
}
|
||||
};
|
||||
loadMessages();
|
||||
}, []);
|
||||
|
||||
const handleQuestionSubmit = async () => {
|
||||
const currMessages = messages.concat([{ text: query, speaker: "user" }]);
|
||||
setMessages(currMessages);
|
||||
|
||||
if (simbaMode) {
|
||||
console.log("simba mode activated");
|
||||
const randomIndex = Math.floor(Math.random() * simbaAnswers.length);
|
||||
const randomElement = simbaAnswers[randomIndex];
|
||||
setAnswer(randomElement);
|
||||
setQuestionsAnswers(
|
||||
questionsAnswers.concat([
|
||||
{
|
||||
question: query,
|
||||
answer: randomElement,
|
||||
},
|
||||
]),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await conversationService.sendQuery(query);
|
||||
setQuestionsAnswers(
|
||||
questionsAnswers.concat([{ question: query, answer: result.response }]),
|
||||
);
|
||||
setMessages(
|
||||
currMessages.concat([{ text: result.response, speaker: "simba" }]),
|
||||
);
|
||||
setQuery(""); // Clear input after successful send
|
||||
} catch (error) {
|
||||
console.error("Failed to send query:", error);
|
||||
// If session expired, redirect to login
|
||||
if (error instanceof Error && error.message.includes("Session expired")) {
|
||||
setAuthenticated(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleQueryChange = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
setQuery(event.target.value);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="h-screen bg-opacity-20">
|
||||
<div className="bg-white/85 h-screen">
|
||||
<div className="flex flex-row justify-center py-4">
|
||||
<div className="flex flex-col gap-4 min-w-xl max-w-xl">
|
||||
<div className="flex flex-row justify-between">
|
||||
<header className="flex flex-row justify-center gap-2 grow sticky top-0 z-10 bg-white">
|
||||
<h1 className="text-3xl">ask simba!</h1>
|
||||
</header>
|
||||
<button
|
||||
className="p-4 border border-red-400 bg-red-200 hover:bg-red-400 cursor-pointer rounded-md"
|
||||
onClick={() => setAuthenticated(false)}
|
||||
>
|
||||
logout
|
||||
</button>
|
||||
</div>
|
||||
{messages.map((msg, index) => {
|
||||
if (msg.speaker === "simba") {
|
||||
return <AnswerBubble key={index} text={msg.text} />;
|
||||
}
|
||||
return <QuestionBubble key={index} text={msg.text} />;
|
||||
})}
|
||||
<footer className="flex flex-col gap-2 sticky bottom-0">
|
||||
<div className="flex flex-row justify-between gap-2 grow">
|
||||
<textarea
|
||||
className="p-4 border border-blue-200 rounded-md grow bg-white"
|
||||
onChange={handleQueryChange}
|
||||
value={query}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-row justify-between gap-2 grow">
|
||||
<button
|
||||
className="p-4 border border-blue-400 bg-blue-200 hover:bg-blue-400 cursor-pointer rounded-md flex-grow"
|
||||
onClick={() => handleQuestionSubmit()}
|
||||
type="submit"
|
||||
>
|
||||
Submit
|
||||
</button>
|
||||
</div>
|
||||
<div className="flex flex-row justify-center gap-2 grow">
|
||||
<input
|
||||
type="checkbox"
|
||||
onChange={(event) => setSimbaMode(event.target.checked)}
|
||||
/>
|
||||
<p>simba mode?</p>
|
||||
</div>
|
||||
</footer>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
24
raggr-frontend/src/components/ConversationMenu.tsx
Normal file
24
raggr-frontend/src/components/ConversationMenu.tsx
Normal file
@@ -0,0 +1,24 @@
|
||||
type Conversation = {
|
||||
title: string;
|
||||
id: string;
|
||||
};
|
||||
|
||||
type ConversationMenuProps = {
|
||||
conversations: Conversation[];
|
||||
};
|
||||
|
||||
export const ConversationMenu = ({ conversations }: ConversationMenuProps) => {
|
||||
return (
|
||||
<div className="absolute bg-white w-md rounded-md shadow-xl m-4 p-4">
|
||||
<p className="py-2 px-4 rounded-md w-full text-xl font-bold">askSimba!</p>
|
||||
{conversations.map((conversation) => (
|
||||
<p
|
||||
key={conversation.id}
|
||||
className="py-2 px-4 rounded-md hover:bg-stone-200 w-full text-xl font-bold cursor-pointer"
|
||||
>
|
||||
{conversation.title}
|
||||
</p>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
80
raggr-frontend/src/components/LoginScreen.tsx
Normal file
80
raggr-frontend/src/components/LoginScreen.tsx
Normal file
@@ -0,0 +1,80 @@
|
||||
import { useState } from "react";
|
||||
import { userService } from "../api/userService";
|
||||
|
||||
type LoginScreenProps = {
|
||||
setAuthenticated: (isAuth: boolean) => void;
|
||||
};
|
||||
|
||||
export const LoginScreen = ({ setAuthenticated }: LoginScreenProps) => {
|
||||
const [username, setUsername] = useState<string>("");
|
||||
const [password, setPassword] = useState<string>("");
|
||||
const [error, setError] = useState<string>("");
|
||||
|
||||
const handleLogin = async () => {
|
||||
if (!username || !password) {
|
||||
setError("Please enter username and password");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await userService.login(username, password);
|
||||
localStorage.setItem("access_token", result.access_token);
|
||||
localStorage.setItem("refresh_token", result.refresh_token);
|
||||
setAuthenticated(true);
|
||||
setError("");
|
||||
} catch (err) {
|
||||
setError("Login failed. Please check your credentials.");
|
||||
console.error("Login error:", err);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="h-screen bg-opacity-20">
|
||||
<div className="bg-white/85 h-screen">
|
||||
<div className="flex flex-row justify-center py-4">
|
||||
<div className="flex flex-col gap-4 min-w-xl max-w-xl">
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="flex flex-grow justify-center w-full bg-amber-400">
|
||||
<h1 className="text-xl font-bold">
|
||||
I AM LOOKING FOR A DESIGNER. THIS APP WILL REMAIN UGLY UNTIL A
|
||||
DESIGNER COMES.
|
||||
</h1>
|
||||
</div>
|
||||
<header className="flex flex-row justify-center gap-2 grow sticky top-0 z-10 bg-white">
|
||||
<h1 className="text-3xl">ask simba!</h1>
|
||||
</header>
|
||||
<label htmlFor="username">username</label>
|
||||
<input
|
||||
type="text"
|
||||
id="username"
|
||||
name="username"
|
||||
value={username}
|
||||
onChange={(e) => setUsername(e.target.value)}
|
||||
className="border border-s-slate-950 p-3 rounded-md"
|
||||
/>
|
||||
<label htmlFor="password">password</label>
|
||||
<input
|
||||
type="password"
|
||||
id="password"
|
||||
name="password"
|
||||
value={password}
|
||||
onChange={(e) => setPassword(e.target.value)}
|
||||
className="border border-s-slate-950 p-3 rounded-md"
|
||||
/>
|
||||
{error && (
|
||||
<div className="text-red-600 font-semibold">{error}</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<button
|
||||
className="p-4 border border-blue-400 bg-blue-200 hover:bg-blue-400 cursor-pointer rounded-md flex-grow"
|
||||
onClick={handleLogin}
|
||||
>
|
||||
login
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
7
raggr-frontend/src/components/QuestionBubble.tsx
Normal file
7
raggr-frontend/src/components/QuestionBubble.tsx
Normal file
@@ -0,0 +1,7 @@
|
||||
type QuestionBubbleProps = {
|
||||
text: string;
|
||||
};
|
||||
|
||||
export const QuestionBubble = ({ text }: QuestionBubbleProps) => {
|
||||
return <div className="rounded-md bg-stone-200 p-3">🤦: {text}</div>;
|
||||
};
|
||||
56
raggr-frontend/src/contexts/AuthContext.tsx
Normal file
56
raggr-frontend/src/contexts/AuthContext.tsx
Normal file
@@ -0,0 +1,56 @@
|
||||
import { createContext, useContext, useState, ReactNode } from "react";
|
||||
import { userService } from "../api/userService";
|
||||
|
||||
interface AuthContextType {
|
||||
token: string | null;
|
||||
login: (username: string, password: string) => Promise<any>;
|
||||
logout: () => void;
|
||||
isAuthenticated: () => boolean;
|
||||
}
|
||||
|
||||
const AuthContext = createContext<AuthContextType | undefined>(undefined);
|
||||
|
||||
interface AuthProviderProps {
|
||||
children: ReactNode;
|
||||
}
|
||||
|
||||
export const AuthProvider = ({ children }: AuthProviderProps) => {
|
||||
const [token, setToken] = useState(localStorage.getItem("access_token"));
|
||||
|
||||
const login = async (username: string, password: string) => {
|
||||
try {
|
||||
const data = await userService.login(username, password);
|
||||
setToken(data.access_token);
|
||||
localStorage.setItem("access_token", data.access_token);
|
||||
localStorage.setItem("refresh_token", data.refresh_token);
|
||||
return data;
|
||||
} catch (error) {
|
||||
console.error("Login failed:", error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
const logout = () => {
|
||||
setToken(null);
|
||||
localStorage.removeItem("access_token");
|
||||
localStorage.removeItem("refresh_token");
|
||||
};
|
||||
|
||||
const isAuthenticated = () => {
|
||||
return token !== null && token !== undefined && token !== "";
|
||||
};
|
||||
|
||||
return (
|
||||
<AuthContext.Provider value={{ token, login, logout, isAuthenticated }}>
|
||||
{children}
|
||||
</AuthContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
export const useAuth = () => {
|
||||
const context = useContext(AuthContext);
|
||||
if (context === undefined) {
|
||||
throw new Error("useAuth must be used within an AuthProvider");
|
||||
}
|
||||
return context;
|
||||
};
|
||||
11
raggr-frontend/src/env.d.ts
vendored
Normal file
11
raggr-frontend/src/env.d.ts
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
/// <reference types="@rsbuild/core/types" />
|
||||
|
||||
/**
|
||||
* Imports the SVG file as a React component.
|
||||
* @requires [@rsbuild/plugin-svgr](https://npmjs.com/package/@rsbuild/plugin-svgr)
|
||||
*/
|
||||
declare module '*.svg?react' {
|
||||
import type React from 'react';
|
||||
const ReactComponent: React.FunctionComponent<React.SVGProps<SVGSVGElement>>;
|
||||
export default ReactComponent;
|
||||
}
|
||||
13
raggr-frontend/src/index.tsx
Normal file
13
raggr-frontend/src/index.tsx
Normal file
@@ -0,0 +1,13 @@
|
||||
import React from 'react';
|
||||
import ReactDOM from 'react-dom/client';
|
||||
import App from './App';
|
||||
|
||||
const rootEl = document.getElementById('root');
|
||||
if (rootEl) {
|
||||
const root = ReactDOM.createRoot(rootEl);
|
||||
root.render(
|
||||
<React.StrictMode>
|
||||
<App />
|
||||
</React.StrictMode>,
|
||||
);
|
||||
}
|
||||
BIN
raggr-frontend/src/simba_cute.jpeg
Normal file
BIN
raggr-frontend/src/simba_cute.jpeg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 3.4 MiB |
BIN
raggr-frontend/src/simba_troll.jpeg
Normal file
BIN
raggr-frontend/src/simba_troll.jpeg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.1 MiB |
25
raggr-frontend/tsconfig.json
Normal file
25
raggr-frontend/tsconfig.json
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"lib": ["DOM", "ES2020"],
|
||||
"jsx": "react-jsx",
|
||||
"target": "ES2020",
|
||||
"noEmit": true,
|
||||
"skipLibCheck": true,
|
||||
"useDefineForClassFields": true,
|
||||
|
||||
/* modules */
|
||||
"module": "ESNext",
|
||||
"moduleDetection": "force",
|
||||
"moduleResolution": "bundler",
|
||||
"verbatimModuleSyntax": true,
|
||||
"resolveJsonModule": true,
|
||||
"allowImportingTsExtensions": true,
|
||||
"noUncheckedSideEffectImports": true,
|
||||
|
||||
/* type checking */
|
||||
"strict": true,
|
||||
"noUnusedLocals": true,
|
||||
"noUnusedParameters": true
|
||||
},
|
||||
"include": ["src"]
|
||||
}
|
||||
1424
raggr-frontend/yarn.lock
Normal file
1424
raggr-frontend/yarn.lock
Normal file
File diff suppressed because it is too large
Load Diff
68
request.py
68
request.py
@@ -1,22 +1,84 @@
|
||||
import os
|
||||
import tempfile
|
||||
import httpx
|
||||
import logging
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class PaperlessNGXService:
|
||||
def __init__(self):
|
||||
self.base_url = os.getenv("BASE_URL")
|
||||
self.token = os.getenv("PAPERLESS_TOKEN")
|
||||
self.url = f"http://{os.getenv("BASE_URL")}/api/documents/?query=simba"
|
||||
self.headers = {"Authorization": f"Token {os.getenv("PAPERLESS_TOKEN")}"}
|
||||
self.url = f"http://{os.getenv('BASE_URL')}/api/documents/?tags__id=8"
|
||||
self.headers = {"Authorization": f"Token {os.getenv('PAPERLESS_TOKEN')}"}
|
||||
|
||||
def get_data(self):
|
||||
print(f"Getting data from: {self.url}")
|
||||
r = httpx.get(self.url, headers=self.headers)
|
||||
return r.json()["results"]
|
||||
results = r.json()["results"]
|
||||
|
||||
nextLink = r.json().get("next")
|
||||
|
||||
while nextLink:
|
||||
r = httpx.get(nextLink, headers=self.headers)
|
||||
results += r.json()["results"]
|
||||
nextLink = r.json().get("next")
|
||||
|
||||
return results
|
||||
|
||||
def get_doc_by_id(self, doc_id: int):
|
||||
url = f"http://{os.getenv('BASE_URL')}/api/documents/{doc_id}/"
|
||||
r = httpx.get(url, headers=self.headers)
|
||||
return r.json()
|
||||
|
||||
def download_pdf_from_id(self, id: int) -> str:
|
||||
download_url = f"http://{os.getenv('BASE_URL')}/api/documents/{id}/download/"
|
||||
response = httpx.get(
|
||||
download_url, headers=self.headers, follow_redirects=True, timeout=30
|
||||
)
|
||||
response.raise_for_status()
|
||||
# Use a temporary file for the downloaded PDF
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
|
||||
temp_file.write(response.content)
|
||||
temp_file.close()
|
||||
temp_pdf_path = temp_file.name
|
||||
pdf_to_process = temp_pdf_path
|
||||
return pdf_to_process
|
||||
|
||||
def upload_cleaned_content(self, document_id, data):
|
||||
PUTS_URL = f"http://{os.getenv('BASE_URL')}/api/documents/{document_id}/"
|
||||
r = httpx.put(PUTS_URL, headers=self.headers, data=data)
|
||||
r.raise_for_status()
|
||||
|
||||
def upload_description(self, description_filepath, file, title, exif_date: str):
|
||||
POST_URL = f"http://{os.getenv('BASE_URL')}/api/documents/post_document/"
|
||||
files = {"document": ("description_filepath", file, "application/txt")}
|
||||
data = {
|
||||
"title": title,
|
||||
"create": exif_date,
|
||||
"document_type": 3,
|
||||
"tags": [7],
|
||||
}
|
||||
|
||||
r = httpx.post(POST_URL, headers=self.headers, data=data, files=files)
|
||||
r.raise_for_status()
|
||||
|
||||
def get_tags(self):
|
||||
GET_URL = f"http://{os.getenv('BASE_URL')}/api/tags/"
|
||||
r = httpx.get(GET_URL, headers=self.headers)
|
||||
data = r.json()
|
||||
return {tag["id"]: tag["name"] for tag in data["results"]}
|
||||
|
||||
def get_doctypes(self):
|
||||
GET_URL = f"http://{os.getenv('BASE_URL')}/api/document_types/"
|
||||
r = httpx.get(GET_URL, headers=self.headers)
|
||||
data = r.json()
|
||||
return {doctype["id"]: doctype["name"] for doctype in data["results"]}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
10
startup.sh
Normal file
10
startup.sh
Normal file
@@ -0,0 +1,10 @@
|
||||
#!/bin/bash
|
||||
|
||||
echo "Running database migrations..."
|
||||
aerich upgrade
|
||||
|
||||
echo "Starting reindex process..."
|
||||
python main.py "" --reindex
|
||||
|
||||
echo "Starting Flask application..."
|
||||
python app.py
|
||||
Reference in New Issue
Block a user