9 Commits

Author SHA1 Message Date
Ryan Chen
e577cb335b query classification 2025-10-26 17:29:00 -04:00
Ryan Chen
591788dfa4 reindex pls 2025-10-26 11:06:32 -04:00
Ryan Chen
561b5bddce reindex pls 2025-10-26 11:04:33 -04:00
Ryan Chen
ddd455a4c6 reindex pls 2025-10-26 11:02:51 -04:00
ryan
07424e77e0 Merge pull request 'favicon' (#7) from update-favicon-and-title into main
Reviewed-on: #7
2025-10-26 10:49:27 -04:00
Ryan Chen
a56f752917 favicon 2025-10-26 10:48:59 -04:00
Ryan Chen
e8264e80ce Changing DB thing 2025-10-26 09:36:33 -04:00
ryan
04350045d3 Merge pull request 'Adding support for conversations and multiple threads' (#6) from conversation-uplift into main
Reviewed-on: #6
2025-10-26 09:25:52 -04:00
Ryan Chen
f16e13fccc big uplift 2025-10-26 09:25:17 -04:00
21 changed files with 3577 additions and 432 deletions

View File

@@ -24,7 +24,6 @@ RUN uv pip install --system -e .
# Copy application code # Copy application code
COPY *.py ./ COPY *.py ./
COPY blueprints ./blueprints COPY blueprints ./blueprints
COPY aerich.toml ./
COPY migrations ./migrations COPY migrations ./migrations
COPY startup.sh ./ COPY startup.sh ./
RUN chmod +x startup.sh RUN chmod +x startup.sh
@@ -35,8 +34,8 @@ WORKDIR /app/raggr-frontend
RUN yarn install && yarn build RUN yarn install && yarn build
WORKDIR /app WORKDIR /app
# Create ChromaDB directory # Create ChromaDB and database directories
RUN mkdir -p /app/chromadb RUN mkdir -p /app/chromadb /app/database
# Expose port # Expose port
EXPOSE 8080 EXPOSE 8080

View File

@@ -10,7 +10,7 @@ from blueprints.users.models import User
async def add_user(username: str, email: str, password: str): async def add_user(username: str, email: str, password: str):
"""Add a new user to the database""" """Add a new user to the database"""
await Tortoise.init( await Tortoise.init(
db_url="sqlite://raggr.db", db_url="sqlite://database/raggr.db",
modules={ modules={
"models": [ "models": [
"blueprints.users.models", "blueprints.users.models",
@@ -56,7 +56,7 @@ async def add_user(username: str, email: str, password: str):
async def list_users(): async def list_users():
"""List all users in the database""" """List all users in the database"""
await Tortoise.init( await Tortoise.init(
db_url="sqlite://raggr.db", db_url="sqlite://database/raggr.db",
modules={ modules={
"models": [ "models": [
"blueprints.users.models", "blueprints.users.models",

View File

@@ -1,7 +1,7 @@
import os import os
TORTOISE_ORM = { TORTOISE_ORM = {
"connections": {"default": os.getenv("DATABASE_URL", "sqlite:///app/raggr.db")}, "connections": {"default": os.getenv("DATABASE_URL", "sqlite:///app/database/raggr.db")},
"apps": { "apps": {
"models": { "models": {
"models": [ "models": [

14
app.py
View File

@@ -27,7 +27,7 @@ app.register_blueprint(blueprints.conversation.conversation_blueprint)
TORTOISE_CONFIG = { TORTOISE_CONFIG = {
"connections": {"default": "sqlite://raggr.db"}, "connections": {"default": "sqlite://database/raggr.db"},
"apps": { "apps": {
"models": { "models": {
"models": [ "models": [
@@ -69,9 +69,11 @@ async def query():
user = await blueprints.users.models.User.get(id=current_user_uuid) user = await blueprints.users.models.User.get(id=current_user_uuid)
data = await request.get_json() data = await request.get_json()
query = data.get("query") query = data.get("query")
conversation = await blueprints.conversation.logic.get_conversation_for_user( conversation_id = data.get("conversation_id")
user=user conversation = await blueprints.conversation.logic.get_conversation_by_id(
conversation_id
) )
await conversation.fetch_related("messages")
await blueprints.conversation.logic.add_message_to_conversation( await blueprints.conversation.logic.add_message_to_conversation(
conversation=conversation, conversation=conversation,
message=query, message=query,
@@ -79,7 +81,11 @@ async def query():
user=user, user=user,
) )
response = consult_simba_oracle(query) transcript = await blueprints.conversation.logic.get_conversation_transcript(
user=user, conversation=conversation
)
response = consult_simba_oracle(input=query, transcript=transcript)
await blueprints.conversation.logic.add_message_to_conversation( await blueprints.conversation.logic.add_message_to_conversation(
conversation=conversation, conversation=conversation,
message=response, message=response,

View File

@@ -1,9 +1,19 @@
import datetime
from quart_jwt_extended import (
jwt_refresh_token_required,
get_jwt_identity,
)
from quart import Blueprint, jsonify from quart import Blueprint, jsonify
from .models import ( from .models import (
Conversation, Conversation,
PydConversation, PydConversation,
PydListConversation,
) )
import blueprints.users.models
conversation_blueprint = Blueprint( conversation_blueprint = Blueprint(
"conversation_api", __name__, url_prefix="/api/conversation" "conversation_api", __name__, url_prefix="/api/conversation"
) )
@@ -12,6 +22,51 @@ conversation_blueprint = Blueprint(
@conversation_blueprint.route("/<conversation_id>") @conversation_blueprint.route("/<conversation_id>")
async def get_conversation(conversation_id: str): async def get_conversation(conversation_id: str):
conversation = await Conversation.get(id=conversation_id) conversation = await Conversation.get(id=conversation_id)
serialized_conversation = await PydConversation.from_tortoise_orm(conversation) await conversation.fetch_related("messages")
return jsonify(serialized_conversation.model_dump_json()) # Manually serialize the conversation with messages
messages = []
for msg in conversation.messages:
messages.append(
{
"id": str(msg.id),
"text": msg.text,
"speaker": msg.speaker.value,
"created_at": msg.created_at.isoformat(),
}
)
return jsonify(
{
"id": str(conversation.id),
"name": conversation.name,
"messages": messages,
"created_at": conversation.created_at.isoformat(),
"updated_at": conversation.updated_at.isoformat(),
}
)
@conversation_blueprint.post("/")
@jwt_refresh_token_required
async def create_conversation():
user_uuid = get_jwt_identity()
user = await blueprints.users.models.User.get(id=user_uuid)
conversation = await Conversation.create(
name=f"{user.username} {datetime.datetime.now().timestamp}",
user=user,
)
serialized_conversation = await PydConversation.from_tortoise_orm(conversation)
return jsonify(serialized_conversation.model_dump())
@conversation_blueprint.get("/")
@jwt_refresh_token_required
async def get_all_conversations():
user_uuid = get_jwt_identity()
user = await blueprints.users.models.User.get(id=user_uuid)
conversations = Conversation.filter(user=user)
serialized_conversations = await PydListConversation.from_queryset(conversations)
return jsonify(serialized_conversations.model_dump())

View File

@@ -44,3 +44,17 @@ async def get_conversation_for_user(user: blueprints.users.models.User) -> Conve
await Conversation.get_or_create(name=f"{user.username}'s chat", user=user) await Conversation.get_or_create(name=f"{user.username}'s chat", user=user)
return await Conversation.get(user=user) return await Conversation.get(user=user)
async def get_conversation_by_id(id: str) -> Conversation:
return await Conversation.get(id=id)
async def get_conversation_transcript(
user: blueprints.users.models.User, conversation: Conversation
) -> str:
messages = []
for message in conversation.messages:
messages.append(f"{message.speaker} at {message.created_at}: {message.text}")
return "\n".join(messages)

View File

@@ -40,5 +40,15 @@ class ConversationMessage(Model):
PydConversationMessage = pydantic_model_creator(ConversationMessage) PydConversationMessage = pydantic_model_creator(ConversationMessage)
PydConversation = pydantic_model_creator(Conversation, name="Conversation") PydConversation = pydantic_model_creator(
Conversation, name="Conversation", allow_cycles=True, exclude=("user",)
)
PydConversationWithMessages = pydantic_model_creator(
Conversation,
name="ConversationWithMessages",
allow_cycles=True,
exclude=("user",),
include=("messages",),
)
PydListConversation = pydantic_queryset_creator(Conversation)
PydListConversationMessage = pydantic_queryset_creator(ConversationMessage) PydListConversationMessage = pydantic_queryset_creator(ConversationMessage)

View File

@@ -14,7 +14,7 @@ from llm import LLMClient
load_dotenv() load_dotenv()
ollama_client = Client( ollama_client = Client(
host=os.getenv("OLLAMA_HOST", "http://localhost:11434"), timeout=10.0 host=os.getenv("OLLAMA_HOST", "http://localhost:11434"), timeout=1.0
) )

View File

@@ -12,6 +12,8 @@ services:
- OPENAI_API_KEY=${OPENAI_API_KEY} - OPENAI_API_KEY=${OPENAI_API_KEY}
volumes: volumes:
- chromadb_data:/app/chromadb - chromadb_data:/app/chromadb
- database_data:/app/database
volumes: volumes:
chromadb_data: chromadb_data:
database_data:

View File

@@ -27,7 +27,7 @@ headers = {"x-api-key": API_KEY, "Content-Type": "application/json"}
VISITED = {} VISITED = {}
if __name__ == "__main__": if __name__ == "__main__":
conn = sqlite3.connect("./visited.db") conn = sqlite3.connect("./database/visited.db")
c = conn.cursor() c = conn.cursor()
c.execute("select immich_id from visited") c.execute("select immich_id from visited")
rows = c.fetchall() rows = c.fetchall()

2
llm.py
View File

@@ -17,7 +17,7 @@ class LLMClient:
def __init__(self): def __init__(self):
try: try:
self.ollama_client = Client( self.ollama_client = Client(
host=os.getenv("OLLAMA_URL", "http://localhost:11434"), timeout=10.0 host=os.getenv("OLLAMA_URL", "http://localhost:11434"), timeout=1.0
) )
self.ollama_client.chat( self.ollama_client.chat(
model="gemma3:4b", messages=[{"role": "system", "content": "test"}] model="gemma3:4b", messages=[{"role": "system", "content": "test"}]

121
main.py
View File

@@ -7,6 +7,8 @@ import argparse
import chromadb import chromadb
import ollama import ollama
import time
from request import PaperlessNGXService from request import PaperlessNGXService
from chunker import Chunker from chunker import Chunker
@@ -36,6 +38,7 @@ parser.add_argument("query", type=str, help="questions about simba's health")
parser.add_argument( parser.add_argument(
"--reindex", action="store_true", help="re-index the simba documents" "--reindex", action="store_true", help="re-index the simba documents"
) )
parser.add_argument("--classify", action="store_true", help="test classification")
parser.add_argument("--index", help="index a file") parser.add_argument("--index", help="index a file")
ppngx = PaperlessNGXService() ppngx = PaperlessNGXService()
@@ -77,7 +80,7 @@ def chunk_data(docs, collection, doctypes):
logging.info(f"chunking {len(docs)} documents") logging.info(f"chunking {len(docs)} documents")
texts: list[str] = [doc["content"] for doc in docs] texts: list[str] = [doc["content"] for doc in docs]
with sqlite3.connect("visited.db") as conn: with sqlite3.connect("database/visited.db") as conn:
to_insert = [] to_insert = []
c = conn.cursor() c = conn.cursor()
for index, text in enumerate(texts): for index, text in enumerate(texts):
@@ -113,9 +116,22 @@ def chunk_text(texts: list[str], collection):
) )
def consult_oracle(input: str, collection): def classify_query(query: str, transcript: str) -> bool:
import time logging.info("Starting query generation")
qg_start = time.time()
qg = QueryGenerator()
query_type = qg.get_query_type(input=query, transcript=transcript)
logging.info(query_type)
qg_end = time.time()
logging.info(f"Query generation took {qg_end - qg_start:.2f} seconds")
return query_type == "Simba"
def consult_oracle(
input: str,
collection,
transcript: str = "",
):
chunker = Chunker(collection) chunker = Chunker(collection)
start_time = time.time() start_time = time.time()
@@ -153,7 +169,10 @@ def consult_oracle(input: str, collection):
logging.info("Starting LLM generation") logging.info("Starting LLM generation")
llm_start = time.time() llm_start = time.time()
system_prompt = "You are a helpful assistant that understands veterinary terms." system_prompt = "You are a helpful assistant that understands veterinary terms."
prompt = f"Using the following data, help answer the user's query by providing as many details as possible. Using this data: {results}. Respond to this prompt: {input}" transcript_prompt = f"Here is the message transcript thus far {transcript}."
prompt = f"""Using the following data, help answer the user's query by providing as many details as possible.
Using this data: {results}. {transcript_prompt if len(transcript) > 0 else ""}
Respond to this prompt: {input}"""
output = llm_client.chat(prompt=prompt, system_prompt=system_prompt) output = llm_client.chat(prompt=prompt, system_prompt=system_prompt)
llm_end = time.time() llm_end = time.time()
logging.info(f"LLM generation took {llm_end - llm_start:.2f} seconds") logging.info(f"LLM generation took {llm_end - llm_start:.2f} seconds")
@@ -164,6 +183,16 @@ def consult_oracle(input: str, collection):
return output return output
def llm_chat(input: str, transcript: str = "") -> str:
system_prompt = "You are a helpful assistant that understands veterinary terms."
transcript_prompt = f"Here is the message transcript thus far {transcript}."
prompt = f"""Answer the user in a humorous way as if you were a cat named Simba. Be very coy.
{transcript_prompt if len(transcript) > 0 else ""}
Respond to this prompt: {input}"""
output = llm_client.chat(prompt=prompt, system_prompt=system_prompt)
return output
def paperless_workflow(input): def paperless_workflow(input):
# Step 1: Get the text # Step 1: Get the text
ppngx = PaperlessNGXService() ppngx = PaperlessNGXService()
@@ -173,15 +202,24 @@ def paperless_workflow(input):
consult_oracle(input, simba_docs) consult_oracle(input, simba_docs)
def consult_simba_oracle(input: str): def consult_simba_oracle(input: str, transcript: str = ""):
return consult_oracle( is_simba_related = classify_query(query=input, transcript=transcript)
input=input,
collection=simba_docs, if is_simba_related:
) logging.info("Query is related to simba")
return consult_oracle(
input=input,
collection=simba_docs,
transcript=transcript,
)
logging.info("Query is NOT related to simba")
return llm_chat(input=input, transcript=transcript)
def filter_indexed_files(docs): def filter_indexed_files(docs):
with sqlite3.connect("visited.db") as conn: with sqlite3.connect("database/visited.db") as conn:
c = conn.cursor() c = conn.cursor()
c.execute( c.execute(
"CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)" "CREATE TABLE IF NOT EXISTS indexed_documents (id INTEGER PRIMARY KEY AUTOINCREMENT, paperless_id INTEGER)"
@@ -194,38 +232,45 @@ def filter_indexed_files(docs):
return [doc for doc in docs if doc["id"] not in visited] return [doc for doc in docs if doc["id"] not in visited]
def reindex():
with sqlite3.connect("database/visited.db") as conn:
c = conn.cursor()
c.execute("DELETE FROM indexed_documents")
conn.commit()
# Delete all documents from the collection
all_docs = simba_docs.get()
if all_docs["ids"]:
simba_docs.delete(ids=all_docs["ids"])
logging.info("Fetching documents from Paperless-NGX")
ppngx = PaperlessNGXService()
docs = ppngx.get_data()
docs = filter_indexed_files(docs)
logging.info(f"Fetched {len(docs)} documents")
# Delete all chromadb data
ids = simba_docs.get(ids=None, limit=None, offset=0)
all_ids = ids["ids"]
if len(all_ids) > 0:
simba_docs.delete(ids=all_ids)
# Chunk documents
logging.info("Chunking documents now ...")
doctype_lookup = ppngx.get_doctypes()
chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
logging.info("Done chunking documents")
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
if args.reindex: if args.reindex:
logging.info("Fetching documents from Paperless-NGX") reindex()
ppngx = PaperlessNGXService()
docs = ppngx.get_data()
docs = filter_indexed_files(docs)
logging.info(f"Fetched {len(docs)} documents")
# Delete all chromadb data if args.classify:
ids = simba_docs.get(ids=None, limit=None, offset=0) consult_simba_oracle(input="yohohoho testing")
all_ids = ids["ids"] consult_simba_oracle(input="write an email")
if len(all_ids) > 0: consult_simba_oracle(input="how much does simba weigh")
simba_docs.delete(ids=all_ids)
# Chunk documents
logging.info("Chunking documents now ...")
tag_lookup = ppngx.get_tags()
doctype_lookup = ppngx.get_doctypes()
chunk_data(docs, collection=simba_docs, doctypes=doctype_lookup)
logging.info("Done chunking documents")
# if args.index:
# with open(args.index) as file:
# extension = args.index.split(".")[-1]
# if extension == "pdf":
# pdf_path = ppngx.download_pdf_from_id(id=document_id)
# image_paths = pdf_to_image(filepath=pdf_path)
# print(f"summarizing {file}")
# generated_summary = summarize_pdf_image(filepaths=image_paths)
# elif extension in [".md", ".txt"]:
# chunk_text(texts=[file.readall()], collection=simba_docs)
if args.query: if args.query:
logging.info("Consulting oracle ...") logging.info("Consulting oracle ...")

View File

@@ -49,11 +49,20 @@ DOCTYPE_OPTIONS = [
"Letter", "Letter",
] ]
QUERY_TYPE_OPTIONS = [
"Simba",
"Other",
]
class DocumentType(BaseModel): class DocumentType(BaseModel):
type: list[str] = Field(description="type of document", enum=DOCTYPE_OPTIONS) type: list[str] = Field(description="type of document", enum=DOCTYPE_OPTIONS)
class QueryType(BaseModel):
type: str = Field(desciption="type of query", enum=QUERY_TYPE_OPTIONS)
PROMPT = """ PROMPT = """
You are an information specialist that processes user queries. The current year is 2025. The user queries are all about You are an information specialist that processes user queries. The current year is 2025. The user queries are all about
a cat, Simba, and its records. The types of records are listed below. Using the query, extract the a cat, Simba, and its records. The types of records are listed below. Using the query, extract the
@@ -111,6 +120,27 @@ Query: "Who does Simba know?"
Tags: ["Letter", "Documentation"] Tags: ["Letter", "Documentation"]
""" """
QUERY_TYPE_PROMPT = f"""You are an information specialist that processes user queries.
A query can have one tag attached from the following options. Based on the query and the transcript which is listed below, determine
which of the following options is most appropriate: {",".join(QUERY_TYPE_OPTIONS)}
### Example 1
Query: "Who is Simba's current vet?"
Tags: ["Simba"]
### Example 2
Query: "What is the capital of Tokyo?"
Tags: ["Other"]
### Example 3
Query: "Can you help me write an email?"
Tags: ["Other"]
TRANSCRIPT:
"""
class QueryGenerator: class QueryGenerator:
def __init__(self) -> None: def __init__(self) -> None:
@@ -154,6 +184,33 @@ class QueryGenerator:
metadata_query = {"document_type": {"$in": type_data["type"]}} metadata_query = {"document_type": {"$in": type_data["type"]}}
return metadata_query return metadata_query
def get_query_type(self, input: str, transcript: str):
client = OpenAI()
response = client.chat.completions.create(
messages=[
{
"role": "system",
"content": "You are an information specialist that is really good at deciding what tags a query should have",
},
{
"role": "user",
"content": f"{QUERY_TYPE_PROMPT}\nTRANSCRIPT:\n{transcript}\nQUERY:{input}",
},
],
model="gpt-4o",
response_format={
"type": "json_schema",
"json_schema": {
"name": "query_type",
"schema": QueryType.model_json_schema(),
},
},
)
response_json_str = response.choices[0].message.content
type_data = json.loads(response_json_str)
return type_data["type"]
def get_query(self, input: str): def get_query(self, input: str):
client = OpenAI() client = OpenAI()
response = client.responses.parse( response = client.responses.parse(

2677
raggr-frontend/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -6,14 +6,18 @@
"scripts": { "scripts": {
"build": "rsbuild build", "build": "rsbuild build",
"dev": "rsbuild dev --open", "dev": "rsbuild dev --open",
"preview": "rsbuild preview" "preview": "rsbuild preview",
"watch": "npm-watch build",
"watch:build": "rsbuild build --watch"
}, },
"dependencies": { "dependencies": {
"axios": "^1.12.2", "axios": "^1.12.2",
"marked": "^16.3.0", "marked": "^16.3.0",
"npm-watch": "^0.13.0",
"react": "^19.1.1", "react": "^19.1.1",
"react-dom": "^19.1.1", "react-dom": "^19.1.1",
"react-markdown": "^10.1.0" "react-markdown": "^10.1.0",
"watch": "^1.0.2"
}, },
"devDependencies": { "devDependencies": {
"@rsbuild/core": "^1.5.6", "@rsbuild/core": "^1.5.6",
@@ -22,5 +26,16 @@
"@types/react": "^19.1.13", "@types/react": "^19.1.13",
"@types/react-dom": "^19.1.9", "@types/react-dom": "^19.1.9",
"typescript": "^5.9.2" "typescript": "^5.9.2"
},
"watch": {
"build": {
"patterns": [
"src"
],
"extensions": "ts,tsx,css,js,jsx",
"delay": 1000,
"quiet": false,
"inherit": true
}
} }
} }

View File

@@ -3,4 +3,8 @@ import { pluginReact } from '@rsbuild/plugin-react';
export default defineConfig({ export default defineConfig({
plugins: [pluginReact()], plugins: [pluginReact()],
html: {
title: 'Raggr',
favicon: './src/assets/favicon.svg',
},
}); });

View File

@@ -10,9 +10,10 @@ interface Message {
interface Conversation { interface Conversation {
id: string; id: string;
name: string; name: string;
messages: Message[]; messages?: Message[];
created_at: string; created_at: string;
updated_at: string; updated_at: string;
user_id?: string;
} }
interface QueryRequest { interface QueryRequest {
@@ -23,15 +24,23 @@ interface QueryResponse {
response: string; response: string;
} }
interface CreateConversationRequest {
user_id: string;
}
class ConversationService { class ConversationService {
private baseUrl = "/api"; private baseUrl = "/api";
private conversationBaseUrl = "/api/conversation";
async sendQuery(query: string): Promise<QueryResponse> { async sendQuery(
query: string,
conversation_id: string,
): Promise<QueryResponse> {
const response = await userService.fetchWithRefreshToken( const response = await userService.fetchWithRefreshToken(
`${this.baseUrl}/query`, `${this.baseUrl}/query`,
{ {
method: "POST", method: "POST",
body: JSON.stringify({ query }), body: JSON.stringify({ query, conversation_id }),
}, },
); );
@@ -56,6 +65,51 @@ class ConversationService {
return await response.json(); return await response.json();
} }
async getConversation(conversationId: string): Promise<Conversation> {
const response = await userService.fetchWithRefreshToken(
`${this.conversationBaseUrl}/${conversationId}`,
{
method: "GET",
},
);
if (!response.ok) {
throw new Error("Failed to fetch conversation");
}
return await response.json();
}
async createConversation(): Promise<Conversation> {
const response = await userService.fetchWithRefreshToken(
`${this.conversationBaseUrl}/`,
{
method: "POST",
},
);
if (!response.ok) {
throw new Error("Failed to create conversation");
}
return await response.json();
}
async getAllConversations(): Promise<Conversation[]> {
const response = await userService.fetchWithRefreshToken(
`${this.conversationBaseUrl}/`,
{
method: "GET",
},
);
if (!response.ok) {
throw new Error("Failed to fetch conversations");
}
return await response.json();
}
} }
export const conversationService = new ConversationService(); export const conversationService = new ConversationService();

View File

@@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 100">
<text y="80" font-size="80" font-family="system-ui, -apple-system, sans-serif">🐱</text>
</svg>

After

Width:  |  Height:  |  Size: 163 B

View File

@@ -2,6 +2,8 @@ import { useEffect, useState } from "react";
import { conversationService } from "../api/conversationService"; import { conversationService } from "../api/conversationService";
import { QuestionBubble } from "./QuestionBubble"; import { QuestionBubble } from "./QuestionBubble";
import { AnswerBubble } from "./AnswerBubble"; import { AnswerBubble } from "./AnswerBubble";
import { ConversationList } from "./ConversationList";
import { parse } from "node:path/win32";
type Message = { type Message = {
text: string; text: string;
@@ -33,13 +35,69 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
const [conversations, setConversations] = useState<Conversation[]>([ const [conversations, setConversations] = useState<Conversation[]>([
{ title: "simba meow meow", id: "uuid" }, { title: "simba meow meow", id: "uuid" },
]); ]);
const [showConversations, setShowConversations] = useState<boolean>(false);
const [selectedConversation, setSelectedConversation] =
useState<Conversation | null>(null);
const simbaAnswers = ["meow.", "hiss...", "purrrrrr", "yowOWROWWowowr"]; const simbaAnswers = ["meow.", "hiss...", "purrrrrr", "yowOWROWWowowr"];
useEffect(() => { const handleSelectConversation = (conversation: Conversation) => {
setShowConversations(false);
setSelectedConversation(conversation);
const loadMessages = async () => { const loadMessages = async () => {
try { try {
const conversation = await conversationService.getMessages(); const fetchedConversation = await conversationService.getConversation(
conversation.id,
);
setMessages(
fetchedConversation.messages.map((message) => ({
text: message.text,
speaker: message.speaker,
})),
);
} catch (error) {
console.error("Failed to load messages:", error);
}
};
loadMessages();
};
const loadConversations = async () => {
try {
const fetchedConversations =
await conversationService.getAllConversations();
const parsedConversations = fetchedConversations.map((conversation) => ({
id: conversation.id,
title: conversation.name,
}));
setConversations(parsedConversations);
setSelectedConversation(parsedConversations[0]);
console.log(parsedConversations);
} catch (error) {
console.error("Failed to load messages:", error);
}
};
const handleCreateNewConversation = async () => {
const newConversation = await conversationService.createConversation();
await loadConversations();
setSelectedConversation({
title: newConversation.name,
id: newConversation.id,
});
};
useEffect(() => {
loadConversations();
}, []);
useEffect(() => {
const loadMessages = async () => {
if (selectedConversation == null) return;
try {
const conversation = await conversationService.getConversation(
selectedConversation.id,
);
setMessages( setMessages(
conversation.messages.map((message) => ({ conversation.messages.map((message) => ({
text: message.text, text: message.text,
@@ -51,7 +109,7 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
} }
}; };
loadMessages(); loadMessages();
}, []); }, [selectedConversation]);
const handleQuestionSubmit = async () => { const handleQuestionSubmit = async () => {
const currMessages = messages.concat([{ text: query, speaker: "user" }]); const currMessages = messages.concat([{ text: query, speaker: "user" }]);
@@ -74,7 +132,10 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
} }
try { try {
const result = await conversationService.sendQuery(query); const result = await conversationService.sendQuery(
query,
selectedConversation.id,
);
setQuestionsAnswers( setQuestionsAnswers(
questionsAnswers.concat([{ question: query, answer: result.response }]), questionsAnswers.concat([{ question: query, answer: result.response }]),
); );
@@ -101,16 +162,33 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
<div className="flex flex-row justify-center py-4"> <div className="flex flex-row justify-center py-4">
<div className="flex flex-col gap-4 min-w-xl max-w-xl"> <div className="flex flex-col gap-4 min-w-xl max-w-xl">
<div className="flex flex-row justify-between"> <div className="flex flex-row justify-between">
<header className="flex flex-row justify-center gap-2 grow sticky top-0 z-10 bg-white"> <header className="flex flex-row justify-center gap-2 sticky top-0 z-10 bg-white">
<h1 className="text-3xl">ask simba!</h1> <h1 className="text-3xl">ask simba!</h1>
</header> </header>
<button <div className="flex flex-row gap-2">
className="p-4 border border-red-400 bg-red-200 hover:bg-red-400 cursor-pointer rounded-md" <button
onClick={() => setAuthenticated(false)} className="p-2 border border-green-400 bg-green-200 hover:bg-green-400 cursor-pointer rounded-md"
> onClick={() => setShowConversations(!showConversations)}
logout >
</button> {showConversations
? "hide conversations"
: "show conversations"}
</button>
<button
className="p-2 border border-red-400 bg-red-200 hover:bg-red-400 cursor-pointer rounded-md"
onClick={() => setAuthenticated(false)}
>
logout
</button>
</div>
</div> </div>
{showConversations && (
<ConversationList
conversations={conversations}
onCreateNewConversation={handleCreateNewConversation}
onSelectConversation={handleSelectConversation}
/>
)}
{messages.map((msg, index) => { {messages.map((msg, index) => {
if (msg.speaker === "simba") { if (msg.speaker === "simba") {
return <AnswerBubble key={index} text={msg.text} />; return <AnswerBubble key={index} text={msg.text} />;

View File

@@ -0,0 +1,60 @@
import { useState, useEffect } from "react";
import { conversationService } from "../api/conversationService";
type Conversation = {
title: string;
id: string;
};
type ConversationProps = {
conversations: Conversation[];
onSelectConversation: (conversation: Conversation) => void;
onCreateNewConversation: () => void;
};
export const ConversationList = ({
conversations,
onSelectConversation,
onCreateNewConversation,
}: ConversationProps) => {
const [conservations, setConversations] = useState(conversations);
useEffect(() => {
const loadConversations = async () => {
try {
const fetchedConversations =
await conversationService.getAllConversations();
setConversations(
fetchedConversations.map((conversation) => ({
id: conversation.id,
title: conversation.name,
})),
);
} catch (error) {
console.error("Failed to load messages:", error);
}
};
loadConversations();
}, []);
return (
<div className="bg-indigo-300 rounded-md p-3 flex flex-col">
{conservations.map((conversation) => {
return (
<div
className="border-blue-400 bg-indigo-300 hover:bg-indigo-200 cursor-pointer rounded-md p-2"
onClick={() => onSelectConversation(conversation)}
>
<p>{conversation.title}</p>
</div>
);
})}
<div
className="border-blue-400 bg-indigo-300 hover:bg-indigo-200 cursor-pointer rounded-md p-2"
onClick={() => onCreateNewConversation()}
>
<p> + Start a new thread</p>
</div>
</div>
);
};

File diff suppressed because it is too large Load Diff