Compare commits
6 Commits
feat/user-
...
feat/renam
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ffbe992f64 | ||
|
|
9ed4ca126a | ||
|
|
f3ae76ce68 | ||
|
|
7ee3bdef84 | ||
|
|
500c44feb1 | ||
|
|
896501deb1 |
9
app.py
9
app.py
@@ -135,17 +135,10 @@ async def get_messages():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
name = conversation.name
|
|
||||||
if len(messages) > 8:
|
|
||||||
name = await blueprints.conversation.logic.rename_conversation(
|
|
||||||
user=user,
|
|
||||||
conversation=conversation,
|
|
||||||
)
|
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(
|
||||||
{
|
{
|
||||||
"id": str(conversation.id),
|
"id": str(conversation.id),
|
||||||
"name": name,
|
"name": conversation.name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"created_at": conversation.created_at.isoformat(),
|
"created_at": conversation.created_at.isoformat(),
|
||||||
"updated_at": conversation.updated_at.isoformat(),
|
"updated_at": conversation.updated_at.isoformat(),
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import datetime
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
@@ -20,7 +19,6 @@ from .agents import main_agent
|
|||||||
from .logic import (
|
from .logic import (
|
||||||
add_message_to_conversation,
|
add_message_to_conversation,
|
||||||
get_conversation_by_id,
|
get_conversation_by_id,
|
||||||
rename_conversation,
|
|
||||||
)
|
)
|
||||||
from .memory import get_memories_for_user
|
from .memory import get_memories_for_user
|
||||||
from .models import (
|
from .models import (
|
||||||
@@ -242,8 +240,6 @@ async def stream_query():
|
|||||||
@jwt_refresh_token_required
|
@jwt_refresh_token_required
|
||||||
async def get_conversation(conversation_id: str):
|
async def get_conversation(conversation_id: str):
|
||||||
conversation = await Conversation.get(id=conversation_id)
|
conversation = await Conversation.get(id=conversation_id)
|
||||||
current_user_uuid = get_jwt_identity()
|
|
||||||
user = await blueprints.users.models.User.get(id=current_user_uuid)
|
|
||||||
await conversation.fetch_related("messages")
|
await conversation.fetch_related("messages")
|
||||||
|
|
||||||
# Manually serialize the conversation with messages
|
# Manually serialize the conversation with messages
|
||||||
@@ -258,18 +254,10 @@ async def get_conversation(conversation_id: str):
|
|||||||
"image_key": msg.image_key,
|
"image_key": msg.image_key,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
name = conversation.name
|
|
||||||
if len(messages) > 8 and "datetime" in name.lower():
|
|
||||||
name = await rename_conversation(
|
|
||||||
user=user,
|
|
||||||
conversation=conversation,
|
|
||||||
)
|
|
||||||
print(name)
|
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(
|
||||||
{
|
{
|
||||||
"id": str(conversation.id),
|
"id": str(conversation.id),
|
||||||
"name": name,
|
"name": conversation.name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"created_at": conversation.created_at.isoformat(),
|
"created_at": conversation.created_at.isoformat(),
|
||||||
"updated_at": conversation.updated_at.isoformat(),
|
"updated_at": conversation.updated_at.isoformat(),
|
||||||
@@ -283,7 +271,7 @@ async def create_conversation():
|
|||||||
user_uuid = get_jwt_identity()
|
user_uuid = get_jwt_identity()
|
||||||
user = await blueprints.users.models.User.get(id=user_uuid)
|
user = await blueprints.users.models.User.get(id=user_uuid)
|
||||||
conversation = await Conversation.create(
|
conversation = await Conversation.create(
|
||||||
name=f"{user.username} {datetime.datetime.now().timestamp}",
|
name="New Conversation",
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
import tortoise.exceptions
|
import tortoise.exceptions
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
import blueprints.users.models
|
import blueprints.users.models
|
||||||
|
|
||||||
from .models import Conversation, ConversationMessage, RenameConversationOutputSchema
|
from .models import Conversation, ConversationMessage
|
||||||
|
|
||||||
|
|
||||||
async def create_conversation(name: str = "") -> Conversation:
|
async def create_conversation(name: str = "") -> Conversation:
|
||||||
@@ -67,22 +66,3 @@ async def get_conversation_transcript(
|
|||||||
messages.append(f"{message.speaker} at {message.created_at}: {message.text}")
|
messages.append(f"{message.speaker} at {message.created_at}: {message.text}")
|
||||||
|
|
||||||
return "\n".join(messages)
|
return "\n".join(messages)
|
||||||
|
|
||||||
|
|
||||||
async def rename_conversation(
|
|
||||||
user: blueprints.users.models.User,
|
|
||||||
conversation: Conversation,
|
|
||||||
) -> str:
|
|
||||||
messages: str = await get_conversation_transcript(
|
|
||||||
user=user, conversation=conversation
|
|
||||||
)
|
|
||||||
|
|
||||||
llm = ChatOpenAI(model="gpt-4o-mini")
|
|
||||||
structured_llm = llm.with_structured_output(RenameConversationOutputSchema)
|
|
||||||
|
|
||||||
prompt = f"Summarize the following conversation into a sassy one-liner title:\n\n{messages}"
|
|
||||||
response = structured_llm.invoke(prompt)
|
|
||||||
new_name: str = response.get("title", "")
|
|
||||||
conversation.name = new_name
|
|
||||||
await conversation.save()
|
|
||||||
return new_name
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import enum
|
import enum
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
from tortoise.contrib.pydantic import (
|
from tortoise.contrib.pydantic import (
|
||||||
@@ -9,12 +8,6 @@ from tortoise.contrib.pydantic import (
|
|||||||
from tortoise.models import Model
|
from tortoise.models import Model
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RenameConversationOutputSchema:
|
|
||||||
title: str
|
|
||||||
justification: str
|
|
||||||
|
|
||||||
|
|
||||||
class Speaker(enum.Enum):
|
class Speaker(enum.Enum):
|
||||||
USER = "user"
|
USER = "user"
|
||||||
SIMBA = "simba"
|
SIMBA = "simba"
|
||||||
|
|||||||
@@ -120,20 +120,6 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
|||||||
scrollToBottom();
|
scrollToBottom();
|
||||||
}, [messages]);
|
}, [messages]);
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
const load = async () => {
|
|
||||||
if (!selectedConversation) return;
|
|
||||||
try {
|
|
||||||
const conv = await conversationService.getConversation(selectedConversation.id);
|
|
||||||
setSelectedConversation({ id: conv.id, title: conv.name });
|
|
||||||
setMessages(conv.messages.map((m) => ({ text: m.text, speaker: m.speaker, image_key: m.image_key })));
|
|
||||||
} catch (err) {
|
|
||||||
console.error("Failed to load messages:", err);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
load();
|
|
||||||
}, [selectedConversation?.id]);
|
|
||||||
|
|
||||||
const handleQuestionSubmit = useCallback(async () => {
|
const handleQuestionSubmit = useCallback(async () => {
|
||||||
if ((!query.trim() && !pendingImage) || isLoading) return;
|
if ((!query.trim() && !pendingImage) || isLoading) return;
|
||||||
|
|
||||||
@@ -215,7 +201,10 @@ export const ChatScreen = ({ setAuthenticated }: ChatScreenProps) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
if (isMountedRef.current) setIsLoading(false);
|
if (isMountedRef.current) {
|
||||||
|
setIsLoading(false);
|
||||||
|
loadConversations();
|
||||||
|
}
|
||||||
abortControllerRef.current = null;
|
abortControllerRef.current = null;
|
||||||
}
|
}
|
||||||
}, [query, pendingImage, isLoading, selectedConversation, simbaMode, messages, setAuthenticated]);
|
}, [query, pendingImage, isLoading, selectedConversation, simbaMode, messages, setAuthenticated]);
|
||||||
|
|||||||
97
scripts/rename_conversations.py
Normal file
97
scripts/rename_conversations.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Management command to rename all conversations.
|
||||||
|
|
||||||
|
- Conversations with >10 messages: renamed to an LLM-generated summary
|
||||||
|
- Conversations with <=10 messages: renamed to a truncation of the first user message
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
from tortoise import Tortoise
|
||||||
|
|
||||||
|
from blueprints.conversation.models import Conversation, Speaker
|
||||||
|
from llm import LLMClient
|
||||||
|
|
||||||
|
|
||||||
|
async def rename_conversations(dry_run: bool = False):
|
||||||
|
"""Rename all conversations based on message count."""
|
||||||
|
|
||||||
|
database_url = os.getenv("DATABASE_URL", "sqlite://raggr.db")
|
||||||
|
await Tortoise.init(
|
||||||
|
db_url=database_url,
|
||||||
|
modules={
|
||||||
|
"models": [
|
||||||
|
"blueprints.users.models",
|
||||||
|
"blueprints.conversation.models",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
llm = LLMClient()
|
||||||
|
conversations = await Conversation.all().prefetch_related("messages")
|
||||||
|
|
||||||
|
renamed = 0
|
||||||
|
skipped = 0
|
||||||
|
|
||||||
|
for conversation in conversations:
|
||||||
|
messages = sorted(conversation.messages, key=lambda m: m.created_at)
|
||||||
|
user_messages = [m for m in messages if m.speaker == Speaker.USER]
|
||||||
|
|
||||||
|
if not user_messages:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(messages) > 10:
|
||||||
|
# Summarize via LLM
|
||||||
|
message_text = "\n".join(
|
||||||
|
f"{m.speaker.value}: {m.text}" for m in messages[:30]
|
||||||
|
)
|
||||||
|
new_name = llm.chat(
|
||||||
|
prompt=message_text,
|
||||||
|
system_prompt=(
|
||||||
|
"You are naming a conversation. Given the messages below, "
|
||||||
|
"produce a short, descriptive title (max 8 words). "
|
||||||
|
"Reply with ONLY the title, nothing else."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
new_name = new_name.strip().strip('"').strip("'")[:100]
|
||||||
|
else:
|
||||||
|
# Truncate first user message
|
||||||
|
new_name = user_messages[0].text[:100]
|
||||||
|
|
||||||
|
old_name = conversation.name
|
||||||
|
if old_name == new_name:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if dry_run:
|
||||||
|
print(f" [dry-run] '{old_name}' -> '{new_name}'")
|
||||||
|
else:
|
||||||
|
conversation.name = new_name
|
||||||
|
await conversation.save()
|
||||||
|
print(f" '{old_name}' -> '{new_name}'")
|
||||||
|
|
||||||
|
renamed += 1
|
||||||
|
|
||||||
|
print(f"\nRenamed: {renamed} Skipped: {skipped}")
|
||||||
|
if dry_run:
|
||||||
|
print("(dry run — no changes were saved)")
|
||||||
|
finally:
|
||||||
|
await Tortoise.close_connections()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Rename conversations based on message count"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dry-run",
|
||||||
|
action="store_true",
|
||||||
|
help="Preview renames without saving",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
asyncio.run(rename_conversations(dry_run=args.dry_run))
|
||||||
Reference in New Issue
Block a user