diff --git a/simplylab/entity.py b/simplylab/entity.py index b99d800..bdfc43c 100644 --- a/simplylab/entity.py +++ b/simplylab/entity.py @@ -91,7 +91,7 @@ class User(BaseModel): class Message(BaseModel): id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id') - conversation_id: ObjectIdField = Field() + # conversation_id: ObjectIdField = Field() user_id: ObjectIdField = Field() type: MessageRoleType = Field() text: str = Field() @@ -118,44 +118,44 @@ class Message(BaseModel): ) -class Conversation(BaseModel): - id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id') - user_id: ObjectIdField = Field() - title: str = Field() - created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) - created_by: ObjectIdField = Field() - updated_at: Optional[datetime.datetime] = Field(default=None) - updated_by: Optional[ObjectIdField] = Field(default=None) - - model_config = ConfigDict( - populate_by_name=True, - json_schema_extra={ - "example": { - "_id": "xxx", - "user_id": "xxx", - "title": "xx", - "created_at": datetime.datetime.now(), - "created_by": "xxx", - "updated_at": None, - "updated_by": None, - } - }, - ) +# class Conversation(BaseModel): +# id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id') +# user_id: ObjectIdField = Field() +# title: str = Field() +# created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) +# created_by: ObjectIdField = Field() +# updated_at: Optional[datetime.datetime] = Field(default=None) +# updated_by: Optional[ObjectIdField] = Field(default=None) +# +# model_config = ConfigDict( +# populate_by_name=True, +# json_schema_extra={ +# "example": { +# "_id": "xxx", +# "user_id": "xxx", +# "title": "xx", +# "created_at": datetime.datetime.now(), +# "created_by": "xxx", +# "updated_at": None, +# "updated_by": None, +# } +# }, +# ) # === mongodb documents end === -class UserConversationMessages(BaseModel): - user_id: ObjectIdField = Field() - user_name: str = Field() - conversation_id: ObjectIdField = Field() - title: str = Field() - created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) - created_by: ObjectIdField = Field() - updated_at: Optional[datetime.datetime] = Field(default=None) - updated_by: Optional[ObjectIdField] = Field(default=None) - messages: list[Message] +# class UserConversationMessages(BaseModel): +# user_id: ObjectIdField = Field() +# user_name: str = Field() +# conversation_id: ObjectIdField = Field(default=None) +# title: str = Field() +# created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) +# created_by: ObjectIdField = Field() +# updated_at: Optional[datetime.datetime] = Field(default=None) +# updated_by: Optional[ObjectIdField] = Field(default=None) +# messages: list[Message] class Context(BaseModel): diff --git a/simplylab/providers/chat.py b/simplylab/providers/chat.py index 7949abe..3ea77a9 100644 --- a/simplylab/providers/chat.py +++ b/simplylab/providers/chat.py @@ -1,21 +1,62 @@ +import datetime from typing import Optional +import pymongo +from loguru import logger + from simplylab.database import Database -from simplylab.entity import UserConversationMessages +from simplylab.entity import ObjectIdField, Message class ChatProvider: def __init__(self, db: Database) -> None: self.db = db - async def check_user_message_limited_in_30_seconds(self, user_id: str) -> bool: - ... + async def check_user_message_limited_in_30_seconds(self, user_id: ObjectIdField) -> bool: + time_start = datetime.datetime.now() - datetime.timedelta(seconds=30) + count = await self.db.message.count_documents({ + "user_id": user_id, + "type": "user", + "created_at": {"$gte": time_start} + }) + if count > 3: + return True + return False - async def check_user_message_limited_in_daily(self, user_id: str) -> bool: - ... + async def check_user_message_limited_in_daily(self, user_id: ObjectIdField) -> bool: + today_start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) + count = await self.db.message.count_documents({ + "user_id": user_id, + "type": "user", + "created_at": {"$gte": today_start} + }) + if count > 20: + return True + return False - async def get_user_conversation_messages(self, user_id, conversation_id) -> Optional[UserConversationMessages]: - ... + async def add_chat_message(self, messages: list[Message]) -> int: + res = await self.db.message.insert_many(documents=[msg.model_dump(by_alias=True) for msg in messages]) + return len(res.inserted_ids) - async def get_user_chat_messages_count_today(self, user_id: str) -> int: - ... \ No newline at end of file + async def get_user_chat_messages(self, user_id: ObjectIdField, limit: int = 10) -> Optional[list[Message]]: + logger.debug(f"user_id={user_id}, limit={limit}") + messages = await (self.db.message.find({"user_id": user_id}) + .sort("created_at", pymongo.DESCENDING) + .limit(limit=limit).to_list(limit)) + logger.info(f"Found {len(messages)} messages") + if not messages: + return [] + msgs = [] + for message in messages: + msgs.append(Message(**message)) + return msgs + + async def get_user_chat_messages_count_today(self, user_id: ObjectIdField) -> int: + today_start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) + count = await self.db.message.count_documents({ + "user_id": user_id, + "type": "user", + "created_at": {"$gte": today_start} + }) + logger.info(f"count: {count}") + return count diff --git a/simplylab/providers/openrouter.py b/simplylab/providers/openrouter.py index 52c8e0f..53b16e7 100644 --- a/simplylab/providers/openrouter.py +++ b/simplylab/providers/openrouter.py @@ -30,5 +30,6 @@ class OpenRouterProvider: }, ], ) - print(completion.choices[0].message.content) + logger.debug(f"request content: {content}") + logger.debug(f"response content: {completion.choices[0].message.content}") return completion.choices[0].message.content diff --git a/simplylab/providers/user.py b/simplylab/providers/user.py index 87d976a..e230219 100644 --- a/simplylab/providers/user.py +++ b/simplylab/providers/user.py @@ -1,6 +1,8 @@ import datetime from typing import Optional +from loguru import logger + from simplylab.database import Database from simplylab.entity import User @@ -14,6 +16,7 @@ class UserProvider: user = await self.db.user.find_one({"name": user_name}) if not user: user = User(name=user_name) - res = await self.db.user.insert_one(user.model_dump()) + res = await self.db.user.insert_one(user.model_dump(by_alias=True)) user = await self.db.user.find_one({"_id": res.inserted_id}) + logger.debug(f"user: {user}") return user diff --git a/simplylab/services/chat.py b/simplylab/services/chat.py index 9d65197..2c417a8 100644 --- a/simplylab/services/chat.py +++ b/simplylab/services/chat.py @@ -1,7 +1,10 @@ from typing import Any +from loguru import logger + from simplylab.entity import GetAiChatResponseInput, GetUserChatHistoryInput, GetChatStatusTodayInput, UserChatMessage, \ - GetChatStatusTodayOutput, GetAiChatResponseOutput, GetUserChatHistoryOutput, Context + GetChatStatusTodayOutput, GetAiChatResponseOutput, GetUserChatHistoryOutput, Context, Message, MessageRoleType +from simplylab.error import MessageLimitedInDailyError from simplylab.providers import Providers @@ -11,15 +14,35 @@ class ChatService: self.pvd = provider async def get_ai_chat_response(self, req: GetAiChatResponseInput) -> GetAiChatResponseOutput: - message = req.message - response_content = await self.pvd.openrouter.chat(content=message) + request_content = req.message + # todo: request content middle out + response_content = await self.pvd.openrouter.chat(content=request_content) + user_message = Message( + user_id=self.ctx.user.id, + type=MessageRoleType.User, + text=request_content, + created_by=self.ctx.user.id, + ) + ai_message = Message( + user_id=self.ctx.user.id, + type=MessageRoleType.Ai, + text=response_content, + created_by=self.ctx.user.id, + ) + messages = [user_message, ai_message] + count = await self.pvd.chat.add_chat_message(messages=messages) + logger.debug(f"Added {count} chat messages") res = GetAiChatResponseOutput(response=response_content) return res async def get_user_chat_history(self, req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput: - res = [UserChatMessage(type="user", text="echo")] + messages = await self.pvd.chat.get_user_chat_messages(user_id=self.ctx.user.id, limit=req.last_n) + res = [] + for message in messages: + res.append(UserChatMessage(type=message.type.value, text=message.text)) return res async def get_chat_status_today(self, req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput: - res = GetChatStatusTodayOutput(user_name=req.user_name, chat_cnt=0) + count = await self.pvd.chat.get_user_chat_messages_count_today(user_id=self.ctx.user.id) + res = GetChatStatusTodayOutput(user_name=self.ctx.user.name, chat_cnt=count) return res