refactor
This commit is contained in:
72
services/raggr/blueprints/conversation/__init__.py
Normal file
72
services/raggr/blueprints/conversation/__init__.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import datetime
|
||||
|
||||
from quart_jwt_extended import (
|
||||
jwt_refresh_token_required,
|
||||
get_jwt_identity,
|
||||
)
|
||||
|
||||
from quart import Blueprint, jsonify
|
||||
from .models import (
|
||||
Conversation,
|
||||
PydConversation,
|
||||
PydListConversation,
|
||||
)
|
||||
|
||||
import blueprints.users.models
|
||||
|
||||
conversation_blueprint = Blueprint(
|
||||
"conversation_api", __name__, url_prefix="/api/conversation"
|
||||
)
|
||||
|
||||
|
||||
@conversation_blueprint.route("/<conversation_id>")
|
||||
async def get_conversation(conversation_id: str):
|
||||
conversation = await Conversation.get(id=conversation_id)
|
||||
await conversation.fetch_related("messages")
|
||||
|
||||
# Manually serialize the conversation with messages
|
||||
messages = []
|
||||
for msg in conversation.messages:
|
||||
messages.append(
|
||||
{
|
||||
"id": str(msg.id),
|
||||
"text": msg.text,
|
||||
"speaker": msg.speaker.value,
|
||||
"created_at": msg.created_at.isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"id": str(conversation.id),
|
||||
"name": conversation.name,
|
||||
"messages": messages,
|
||||
"created_at": conversation.created_at.isoformat(),
|
||||
"updated_at": conversation.updated_at.isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@conversation_blueprint.post("/")
|
||||
@jwt_refresh_token_required
|
||||
async def create_conversation():
|
||||
user_uuid = get_jwt_identity()
|
||||
user = await blueprints.users.models.User.get(id=user_uuid)
|
||||
conversation = await Conversation.create(
|
||||
name=f"{user.username} {datetime.datetime.now().timestamp}",
|
||||
user=user,
|
||||
)
|
||||
|
||||
serialized_conversation = await PydConversation.from_tortoise_orm(conversation)
|
||||
return jsonify(serialized_conversation.model_dump())
|
||||
|
||||
|
||||
@conversation_blueprint.get("/")
|
||||
@jwt_refresh_token_required
|
||||
async def get_all_conversations():
|
||||
user_uuid = get_jwt_identity()
|
||||
user = await blueprints.users.models.User.get(id=user_uuid)
|
||||
conversations = Conversation.filter(user=user)
|
||||
serialized_conversations = await PydListConversation.from_queryset(conversations)
|
||||
|
||||
return jsonify(serialized_conversations.model_dump())
|
||||
60
services/raggr/blueprints/conversation/logic.py
Normal file
60
services/raggr/blueprints/conversation/logic.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import tortoise.exceptions
|
||||
|
||||
from .models import Conversation, ConversationMessage
|
||||
|
||||
import blueprints.users.models
|
||||
|
||||
|
||||
async def create_conversation(name: str = "") -> Conversation:
|
||||
conversation = await Conversation.create(name=name)
|
||||
return conversation
|
||||
|
||||
|
||||
async def add_message_to_conversation(
|
||||
conversation: Conversation,
|
||||
message: str,
|
||||
speaker: str,
|
||||
user: blueprints.users.models.User,
|
||||
) -> ConversationMessage:
|
||||
print(conversation, message, speaker)
|
||||
message = await ConversationMessage.create(
|
||||
text=message,
|
||||
speaker=speaker,
|
||||
conversation=conversation,
|
||||
)
|
||||
|
||||
return message
|
||||
|
||||
|
||||
async def get_the_only_conversation() -> Conversation:
|
||||
try:
|
||||
conversation = await Conversation.all().first()
|
||||
if conversation is None:
|
||||
conversation = await Conversation.create(name="simba_chat")
|
||||
except Exception as _e:
|
||||
conversation = await Conversation.create(name="simba_chat")
|
||||
|
||||
return conversation
|
||||
|
||||
|
||||
async def get_conversation_for_user(user: blueprints.users.models.User) -> Conversation:
|
||||
try:
|
||||
return await Conversation.get(user=user)
|
||||
except tortoise.exceptions.DoesNotExist:
|
||||
await Conversation.get_or_create(name=f"{user.username}'s chat", user=user)
|
||||
|
||||
return await Conversation.get(user=user)
|
||||
|
||||
|
||||
async def get_conversation_by_id(id: str) -> Conversation:
|
||||
return await Conversation.get(id=id)
|
||||
|
||||
|
||||
async def get_conversation_transcript(
|
||||
user: blueprints.users.models.User, conversation: Conversation
|
||||
) -> str:
|
||||
messages = []
|
||||
for message in conversation.messages:
|
||||
messages.append(f"{message.speaker} at {message.created_at}: {message.text}")
|
||||
|
||||
return "\n".join(messages)
|
||||
54
services/raggr/blueprints/conversation/models.py
Normal file
54
services/raggr/blueprints/conversation/models.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import enum
|
||||
|
||||
from tortoise.models import Model
|
||||
from tortoise import fields
|
||||
from tortoise.contrib.pydantic import (
|
||||
pydantic_queryset_creator,
|
||||
pydantic_model_creator,
|
||||
)
|
||||
|
||||
|
||||
class Speaker(enum.Enum):
|
||||
USER = "user"
|
||||
SIMBA = "simba"
|
||||
|
||||
|
||||
class Conversation(Model):
|
||||
id = fields.UUIDField(primary_key=True)
|
||||
name = fields.CharField(max_length=255)
|
||||
created_at = fields.DatetimeField(auto_now_add=True)
|
||||
updated_at = fields.DatetimeField(auto_now=True)
|
||||
user: fields.ForeignKeyRelation = fields.ForeignKeyField(
|
||||
"models.User", related_name="conversations", null=True
|
||||
)
|
||||
|
||||
class Meta:
|
||||
table = "conversations"
|
||||
|
||||
|
||||
class ConversationMessage(Model):
|
||||
id = fields.UUIDField(primary_key=True)
|
||||
text = fields.TextField()
|
||||
conversation = fields.ForeignKeyField(
|
||||
"models.Conversation", related_name="messages"
|
||||
)
|
||||
created_at = fields.DatetimeField(auto_now_add=True)
|
||||
speaker = fields.CharEnumField(enum_type=Speaker, max_length=10)
|
||||
|
||||
class Meta:
|
||||
table = "conversation_messages"
|
||||
|
||||
|
||||
PydConversationMessage = pydantic_model_creator(ConversationMessage)
|
||||
PydConversation = pydantic_model_creator(
|
||||
Conversation, name="Conversation", allow_cycles=True, exclude=("user",)
|
||||
)
|
||||
PydConversationWithMessages = pydantic_model_creator(
|
||||
Conversation,
|
||||
name="ConversationWithMessages",
|
||||
allow_cycles=True,
|
||||
exclude=("user",),
|
||||
include=("messages",),
|
||||
)
|
||||
PydListConversation = pydantic_queryset_creator(Conversation)
|
||||
PydListConversationMessage = pydantic_queryset_creator(ConversationMessage)
|
||||
Reference in New Issue
Block a user