can chat and history and count now

This commit is contained in:
Jeremy Yin 2024-03-15 22:34:25 +08:00
parent 23a83fcc74
commit 0e8f130149
5 changed files with 118 additions and 50 deletions

View File

@ -91,7 +91,7 @@ class User(BaseModel):
class Message(BaseModel): class Message(BaseModel):
id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id') id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id')
conversation_id: ObjectIdField = Field() # conversation_id: ObjectIdField = Field()
user_id: ObjectIdField = Field() user_id: ObjectIdField = Field()
type: MessageRoleType = Field() type: MessageRoleType = Field()
text: str = Field() text: str = Field()
@ -118,44 +118,44 @@ class Message(BaseModel):
) )
class Conversation(BaseModel): # class Conversation(BaseModel):
id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id') # id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id')
user_id: ObjectIdField = Field() # user_id: ObjectIdField = Field()
title: str = Field() # title: str = Field()
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) # created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
created_by: ObjectIdField = Field() # created_by: ObjectIdField = Field()
updated_at: Optional[datetime.datetime] = Field(default=None) # updated_at: Optional[datetime.datetime] = Field(default=None)
updated_by: Optional[ObjectIdField] = Field(default=None) # updated_by: Optional[ObjectIdField] = Field(default=None)
#
model_config = ConfigDict( # model_config = ConfigDict(
populate_by_name=True, # populate_by_name=True,
json_schema_extra={ # json_schema_extra={
"example": { # "example": {
"_id": "xxx", # "_id": "xxx",
"user_id": "xxx", # "user_id": "xxx",
"title": "xx", # "title": "xx",
"created_at": datetime.datetime.now(), # "created_at": datetime.datetime.now(),
"created_by": "xxx", # "created_by": "xxx",
"updated_at": None, # "updated_at": None,
"updated_by": None, # "updated_by": None,
} # }
}, # },
) # )
# === mongodb documents end === # === mongodb documents end ===
class UserConversationMessages(BaseModel): # class UserConversationMessages(BaseModel):
user_id: ObjectIdField = Field() # user_id: ObjectIdField = Field()
user_name: str = Field() # user_name: str = Field()
conversation_id: ObjectIdField = Field() # conversation_id: ObjectIdField = Field(default=None)
title: str = Field() # title: str = Field()
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) # created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
created_by: ObjectIdField = Field() # created_by: ObjectIdField = Field()
updated_at: Optional[datetime.datetime] = Field(default=None) # updated_at: Optional[datetime.datetime] = Field(default=None)
updated_by: Optional[ObjectIdField] = Field(default=None) # updated_by: Optional[ObjectIdField] = Field(default=None)
messages: list[Message] # messages: list[Message]
class Context(BaseModel): class Context(BaseModel):

View File

@ -1,21 +1,62 @@
import datetime
from typing import Optional from typing import Optional
import pymongo
from loguru import logger
from simplylab.database import Database from simplylab.database import Database
from simplylab.entity import UserConversationMessages from simplylab.entity import ObjectIdField, Message
class ChatProvider: class ChatProvider:
def __init__(self, db: Database) -> None: def __init__(self, db: Database) -> None:
self.db = db self.db = db
async def check_user_message_limited_in_30_seconds(self, user_id: str) -> bool: async def check_user_message_limited_in_30_seconds(self, user_id: ObjectIdField) -> bool:
... time_start = datetime.datetime.now() - datetime.timedelta(seconds=30)
count = await self.db.message.count_documents({
"user_id": user_id,
"type": "user",
"created_at": {"$gte": time_start}
})
if count > 3:
return True
return False
async def check_user_message_limited_in_daily(self, user_id: str) -> bool: async def check_user_message_limited_in_daily(self, user_id: ObjectIdField) -> bool:
... today_start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
count = await self.db.message.count_documents({
"user_id": user_id,
"type": "user",
"created_at": {"$gte": today_start}
})
if count > 20:
return True
return False
async def get_user_conversation_messages(self, user_id, conversation_id) -> Optional[UserConversationMessages]: async def add_chat_message(self, messages: list[Message]) -> int:
... res = await self.db.message.insert_many(documents=[msg.model_dump(by_alias=True) for msg in messages])
return len(res.inserted_ids)
async def get_user_chat_messages_count_today(self, user_id: str) -> int: async def get_user_chat_messages(self, user_id: ObjectIdField, limit: int = 10) -> Optional[list[Message]]:
... logger.debug(f"user_id={user_id}, limit={limit}")
messages = await (self.db.message.find({"user_id": user_id})
.sort("created_at", pymongo.DESCENDING)
.limit(limit=limit).to_list(limit))
logger.info(f"Found {len(messages)} messages")
if not messages:
return []
msgs = []
for message in messages:
msgs.append(Message(**message))
return msgs
async def get_user_chat_messages_count_today(self, user_id: ObjectIdField) -> int:
today_start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
count = await self.db.message.count_documents({
"user_id": user_id,
"type": "user",
"created_at": {"$gte": today_start}
})
logger.info(f"count: {count}")
return count

View File

@ -30,5 +30,6 @@ class OpenRouterProvider:
}, },
], ],
) )
print(completion.choices[0].message.content) logger.debug(f"request content: {content}")
logger.debug(f"response content: {completion.choices[0].message.content}")
return completion.choices[0].message.content return completion.choices[0].message.content

View File

@ -1,6 +1,8 @@
import datetime import datetime
from typing import Optional from typing import Optional
from loguru import logger
from simplylab.database import Database from simplylab.database import Database
from simplylab.entity import User from simplylab.entity import User
@ -14,6 +16,7 @@ class UserProvider:
user = await self.db.user.find_one({"name": user_name}) user = await self.db.user.find_one({"name": user_name})
if not user: if not user:
user = User(name=user_name) user = User(name=user_name)
res = await self.db.user.insert_one(user.model_dump()) res = await self.db.user.insert_one(user.model_dump(by_alias=True))
user = await self.db.user.find_one({"_id": res.inserted_id}) user = await self.db.user.find_one({"_id": res.inserted_id})
logger.debug(f"user: {user}")
return user return user

View File

@ -1,7 +1,10 @@
from typing import Any from typing import Any
from loguru import logger
from simplylab.entity import GetAiChatResponseInput, GetUserChatHistoryInput, GetChatStatusTodayInput, UserChatMessage, \ from simplylab.entity import GetAiChatResponseInput, GetUserChatHistoryInput, GetChatStatusTodayInput, UserChatMessage, \
GetChatStatusTodayOutput, GetAiChatResponseOutput, GetUserChatHistoryOutput, Context GetChatStatusTodayOutput, GetAiChatResponseOutput, GetUserChatHistoryOutput, Context, Message, MessageRoleType
from simplylab.error import MessageLimitedInDailyError
from simplylab.providers import Providers from simplylab.providers import Providers
@ -11,15 +14,35 @@ class ChatService:
self.pvd = provider self.pvd = provider
async def get_ai_chat_response(self, req: GetAiChatResponseInput) -> GetAiChatResponseOutput: async def get_ai_chat_response(self, req: GetAiChatResponseInput) -> GetAiChatResponseOutput:
message = req.message request_content = req.message
response_content = await self.pvd.openrouter.chat(content=message) # todo: request content middle out
response_content = await self.pvd.openrouter.chat(content=request_content)
user_message = Message(
user_id=self.ctx.user.id,
type=MessageRoleType.User,
text=request_content,
created_by=self.ctx.user.id,
)
ai_message = Message(
user_id=self.ctx.user.id,
type=MessageRoleType.Ai,
text=response_content,
created_by=self.ctx.user.id,
)
messages = [user_message, ai_message]
count = await self.pvd.chat.add_chat_message(messages=messages)
logger.debug(f"Added {count} chat messages")
res = GetAiChatResponseOutput(response=response_content) res = GetAiChatResponseOutput(response=response_content)
return res return res
async def get_user_chat_history(self, req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput: async def get_user_chat_history(self, req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput:
res = [UserChatMessage(type="user", text="echo")] messages = await self.pvd.chat.get_user_chat_messages(user_id=self.ctx.user.id, limit=req.last_n)
res = []
for message in messages:
res.append(UserChatMessage(type=message.type.value, text=message.text))
return res return res
async def get_chat_status_today(self, req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput: async def get_chat_status_today(self, req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput:
res = GetChatStatusTodayOutput(user_name=req.user_name, chat_cnt=0) count = await self.pvd.chat.get_user_chat_messages_count_today(user_id=self.ctx.user.id)
res = GetChatStatusTodayOutput(user_name=self.ctx.user.name, chat_cnt=count)
return res return res