can chat and history and count now

This commit is contained in:
Jeremy Yin 2024-03-15 22:34:25 +08:00
parent 23a83fcc74
commit 0e8f130149
5 changed files with 118 additions and 50 deletions

View File

@ -91,7 +91,7 @@ class User(BaseModel):
class Message(BaseModel):
id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id')
conversation_id: ObjectIdField = Field()
# conversation_id: ObjectIdField = Field()
user_id: ObjectIdField = Field()
type: MessageRoleType = Field()
text: str = Field()
@ -118,44 +118,44 @@ class Message(BaseModel):
)
class Conversation(BaseModel):
id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id')
user_id: ObjectIdField = Field()
title: str = Field()
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
created_by: ObjectIdField = Field()
updated_at: Optional[datetime.datetime] = Field(default=None)
updated_by: Optional[ObjectIdField] = Field(default=None)
model_config = ConfigDict(
populate_by_name=True,
json_schema_extra={
"example": {
"_id": "xxx",
"user_id": "xxx",
"title": "xx",
"created_at": datetime.datetime.now(),
"created_by": "xxx",
"updated_at": None,
"updated_by": None,
}
},
)
# class Conversation(BaseModel):
# id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id')
# user_id: ObjectIdField = Field()
# title: str = Field()
# created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
# created_by: ObjectIdField = Field()
# updated_at: Optional[datetime.datetime] = Field(default=None)
# updated_by: Optional[ObjectIdField] = Field(default=None)
#
# model_config = ConfigDict(
# populate_by_name=True,
# json_schema_extra={
# "example": {
# "_id": "xxx",
# "user_id": "xxx",
# "title": "xx",
# "created_at": datetime.datetime.now(),
# "created_by": "xxx",
# "updated_at": None,
# "updated_by": None,
# }
# },
# )
# === mongodb documents end ===
class UserConversationMessages(BaseModel):
user_id: ObjectIdField = Field()
user_name: str = Field()
conversation_id: ObjectIdField = Field()
title: str = Field()
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
created_by: ObjectIdField = Field()
updated_at: Optional[datetime.datetime] = Field(default=None)
updated_by: Optional[ObjectIdField] = Field(default=None)
messages: list[Message]
# class UserConversationMessages(BaseModel):
# user_id: ObjectIdField = Field()
# user_name: str = Field()
# conversation_id: ObjectIdField = Field(default=None)
# title: str = Field()
# created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
# created_by: ObjectIdField = Field()
# updated_at: Optional[datetime.datetime] = Field(default=None)
# updated_by: Optional[ObjectIdField] = Field(default=None)
# messages: list[Message]
class Context(BaseModel):

View File

@ -1,21 +1,62 @@
import datetime
from typing import Optional
import pymongo
from loguru import logger
from simplylab.database import Database
from simplylab.entity import UserConversationMessages
from simplylab.entity import ObjectIdField, Message
class ChatProvider:
def __init__(self, db: Database) -> None:
self.db = db
async def check_user_message_limited_in_30_seconds(self, user_id: str) -> bool:
...
async def check_user_message_limited_in_30_seconds(self, user_id: ObjectIdField) -> bool:
time_start = datetime.datetime.now() - datetime.timedelta(seconds=30)
count = await self.db.message.count_documents({
"user_id": user_id,
"type": "user",
"created_at": {"$gte": time_start}
})
if count > 3:
return True
return False
async def check_user_message_limited_in_daily(self, user_id: str) -> bool:
...
async def check_user_message_limited_in_daily(self, user_id: ObjectIdField) -> bool:
today_start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
count = await self.db.message.count_documents({
"user_id": user_id,
"type": "user",
"created_at": {"$gte": today_start}
})
if count > 20:
return True
return False
async def get_user_conversation_messages(self, user_id, conversation_id) -> Optional[UserConversationMessages]:
...
async def add_chat_message(self, messages: list[Message]) -> int:
res = await self.db.message.insert_many(documents=[msg.model_dump(by_alias=True) for msg in messages])
return len(res.inserted_ids)
async def get_user_chat_messages_count_today(self, user_id: str) -> int:
...
async def get_user_chat_messages(self, user_id: ObjectIdField, limit: int = 10) -> Optional[list[Message]]:
logger.debug(f"user_id={user_id}, limit={limit}")
messages = await (self.db.message.find({"user_id": user_id})
.sort("created_at", pymongo.DESCENDING)
.limit(limit=limit).to_list(limit))
logger.info(f"Found {len(messages)} messages")
if not messages:
return []
msgs = []
for message in messages:
msgs.append(Message(**message))
return msgs
async def get_user_chat_messages_count_today(self, user_id: ObjectIdField) -> int:
today_start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
count = await self.db.message.count_documents({
"user_id": user_id,
"type": "user",
"created_at": {"$gte": today_start}
})
logger.info(f"count: {count}")
return count

View File

@ -30,5 +30,6 @@ class OpenRouterProvider:
},
],
)
print(completion.choices[0].message.content)
logger.debug(f"request content: {content}")
logger.debug(f"response content: {completion.choices[0].message.content}")
return completion.choices[0].message.content

View File

@ -1,6 +1,8 @@
import datetime
from typing import Optional
from loguru import logger
from simplylab.database import Database
from simplylab.entity import User
@ -14,6 +16,7 @@ class UserProvider:
user = await self.db.user.find_one({"name": user_name})
if not user:
user = User(name=user_name)
res = await self.db.user.insert_one(user.model_dump())
res = await self.db.user.insert_one(user.model_dump(by_alias=True))
user = await self.db.user.find_one({"_id": res.inserted_id})
logger.debug(f"user: {user}")
return user

View File

@ -1,7 +1,10 @@
from typing import Any
from loguru import logger
from simplylab.entity import GetAiChatResponseInput, GetUserChatHistoryInput, GetChatStatusTodayInput, UserChatMessage, \
GetChatStatusTodayOutput, GetAiChatResponseOutput, GetUserChatHistoryOutput, Context
GetChatStatusTodayOutput, GetAiChatResponseOutput, GetUserChatHistoryOutput, Context, Message, MessageRoleType
from simplylab.error import MessageLimitedInDailyError
from simplylab.providers import Providers
@ -11,15 +14,35 @@ class ChatService:
self.pvd = provider
async def get_ai_chat_response(self, req: GetAiChatResponseInput) -> GetAiChatResponseOutput:
message = req.message
response_content = await self.pvd.openrouter.chat(content=message)
request_content = req.message
# todo: request content middle out
response_content = await self.pvd.openrouter.chat(content=request_content)
user_message = Message(
user_id=self.ctx.user.id,
type=MessageRoleType.User,
text=request_content,
created_by=self.ctx.user.id,
)
ai_message = Message(
user_id=self.ctx.user.id,
type=MessageRoleType.Ai,
text=response_content,
created_by=self.ctx.user.id,
)
messages = [user_message, ai_message]
count = await self.pvd.chat.add_chat_message(messages=messages)
logger.debug(f"Added {count} chat messages")
res = GetAiChatResponseOutput(response=response_content)
return res
async def get_user_chat_history(self, req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput:
res = [UserChatMessage(type="user", text="echo")]
messages = await self.pvd.chat.get_user_chat_messages(user_id=self.ctx.user.id, limit=req.last_n)
res = []
for message in messages:
res.append(UserChatMessage(type=message.type.value, text=message.text))
return res
async def get_chat_status_today(self, req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput:
res = GetChatStatusTodayOutput(user_name=req.user_name, chat_cnt=0)
count = await self.pvd.chat.get_user_chat_messages_count_today(user_id=self.ctx.user.id)
res = GetChatStatusTodayOutput(user_name=self.ctx.user.name, chat_cnt=count)
return res