can chat and history and count now
This commit is contained in:
parent
23a83fcc74
commit
0e8f130149
|
@ -91,7 +91,7 @@ class User(BaseModel):
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id')
|
id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id')
|
||||||
conversation_id: ObjectIdField = Field()
|
# conversation_id: ObjectIdField = Field()
|
||||||
user_id: ObjectIdField = Field()
|
user_id: ObjectIdField = Field()
|
||||||
type: MessageRoleType = Field()
|
type: MessageRoleType = Field()
|
||||||
text: str = Field()
|
text: str = Field()
|
||||||
|
@ -118,44 +118,44 @@ class Message(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Conversation(BaseModel):
|
# class Conversation(BaseModel):
|
||||||
id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id')
|
# id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id')
|
||||||
user_id: ObjectIdField = Field()
|
# user_id: ObjectIdField = Field()
|
||||||
title: str = Field()
|
# title: str = Field()
|
||||||
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
|
# created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
|
||||||
created_by: ObjectIdField = Field()
|
# created_by: ObjectIdField = Field()
|
||||||
updated_at: Optional[datetime.datetime] = Field(default=None)
|
# updated_at: Optional[datetime.datetime] = Field(default=None)
|
||||||
updated_by: Optional[ObjectIdField] = Field(default=None)
|
# updated_by: Optional[ObjectIdField] = Field(default=None)
|
||||||
|
#
|
||||||
model_config = ConfigDict(
|
# model_config = ConfigDict(
|
||||||
populate_by_name=True,
|
# populate_by_name=True,
|
||||||
json_schema_extra={
|
# json_schema_extra={
|
||||||
"example": {
|
# "example": {
|
||||||
"_id": "xxx",
|
# "_id": "xxx",
|
||||||
"user_id": "xxx",
|
# "user_id": "xxx",
|
||||||
"title": "xx",
|
# "title": "xx",
|
||||||
"created_at": datetime.datetime.now(),
|
# "created_at": datetime.datetime.now(),
|
||||||
"created_by": "xxx",
|
# "created_by": "xxx",
|
||||||
"updated_at": None,
|
# "updated_at": None,
|
||||||
"updated_by": None,
|
# "updated_by": None,
|
||||||
}
|
# }
|
||||||
},
|
# },
|
||||||
)
|
# )
|
||||||
|
|
||||||
|
|
||||||
# === mongodb documents end ===
|
# === mongodb documents end ===
|
||||||
|
|
||||||
|
|
||||||
class UserConversationMessages(BaseModel):
|
# class UserConversationMessages(BaseModel):
|
||||||
user_id: ObjectIdField = Field()
|
# user_id: ObjectIdField = Field()
|
||||||
user_name: str = Field()
|
# user_name: str = Field()
|
||||||
conversation_id: ObjectIdField = Field()
|
# conversation_id: ObjectIdField = Field(default=None)
|
||||||
title: str = Field()
|
# title: str = Field()
|
||||||
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
|
# created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
|
||||||
created_by: ObjectIdField = Field()
|
# created_by: ObjectIdField = Field()
|
||||||
updated_at: Optional[datetime.datetime] = Field(default=None)
|
# updated_at: Optional[datetime.datetime] = Field(default=None)
|
||||||
updated_by: Optional[ObjectIdField] = Field(default=None)
|
# updated_by: Optional[ObjectIdField] = Field(default=None)
|
||||||
messages: list[Message]
|
# messages: list[Message]
|
||||||
|
|
||||||
|
|
||||||
class Context(BaseModel):
|
class Context(BaseModel):
|
||||||
|
|
|
@ -1,21 +1,62 @@
|
||||||
|
import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import pymongo
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from simplylab.database import Database
|
from simplylab.database import Database
|
||||||
from simplylab.entity import UserConversationMessages
|
from simplylab.entity import ObjectIdField, Message
|
||||||
|
|
||||||
|
|
||||||
class ChatProvider:
|
class ChatProvider:
|
||||||
def __init__(self, db: Database) -> None:
|
def __init__(self, db: Database) -> None:
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
async def check_user_message_limited_in_30_seconds(self, user_id: str) -> bool:
|
async def check_user_message_limited_in_30_seconds(self, user_id: ObjectIdField) -> bool:
|
||||||
...
|
time_start = datetime.datetime.now() - datetime.timedelta(seconds=30)
|
||||||
|
count = await self.db.message.count_documents({
|
||||||
|
"user_id": user_id,
|
||||||
|
"type": "user",
|
||||||
|
"created_at": {"$gte": time_start}
|
||||||
|
})
|
||||||
|
if count > 3:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
async def check_user_message_limited_in_daily(self, user_id: str) -> bool:
|
async def check_user_message_limited_in_daily(self, user_id: ObjectIdField) -> bool:
|
||||||
...
|
today_start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
count = await self.db.message.count_documents({
|
||||||
|
"user_id": user_id,
|
||||||
|
"type": "user",
|
||||||
|
"created_at": {"$gte": today_start}
|
||||||
|
})
|
||||||
|
if count > 20:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
async def get_user_conversation_messages(self, user_id, conversation_id) -> Optional[UserConversationMessages]:
|
async def add_chat_message(self, messages: list[Message]) -> int:
|
||||||
...
|
res = await self.db.message.insert_many(documents=[msg.model_dump(by_alias=True) for msg in messages])
|
||||||
|
return len(res.inserted_ids)
|
||||||
|
|
||||||
async def get_user_chat_messages_count_today(self, user_id: str) -> int:
|
async def get_user_chat_messages(self, user_id: ObjectIdField, limit: int = 10) -> Optional[list[Message]]:
|
||||||
...
|
logger.debug(f"user_id={user_id}, limit={limit}")
|
||||||
|
messages = await (self.db.message.find({"user_id": user_id})
|
||||||
|
.sort("created_at", pymongo.DESCENDING)
|
||||||
|
.limit(limit=limit).to_list(limit))
|
||||||
|
logger.info(f"Found {len(messages)} messages")
|
||||||
|
if not messages:
|
||||||
|
return []
|
||||||
|
msgs = []
|
||||||
|
for message in messages:
|
||||||
|
msgs.append(Message(**message))
|
||||||
|
return msgs
|
||||||
|
|
||||||
|
async def get_user_chat_messages_count_today(self, user_id: ObjectIdField) -> int:
|
||||||
|
today_start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
count = await self.db.message.count_documents({
|
||||||
|
"user_id": user_id,
|
||||||
|
"type": "user",
|
||||||
|
"created_at": {"$gte": today_start}
|
||||||
|
})
|
||||||
|
logger.info(f"count: {count}")
|
||||||
|
return count
|
||||||
|
|
|
@ -30,5 +30,6 @@ class OpenRouterProvider:
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
print(completion.choices[0].message.content)
|
logger.debug(f"request content: {content}")
|
||||||
|
logger.debug(f"response content: {completion.choices[0].message.content}")
|
||||||
return completion.choices[0].message.content
|
return completion.choices[0].message.content
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
import datetime
|
import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from simplylab.database import Database
|
from simplylab.database import Database
|
||||||
from simplylab.entity import User
|
from simplylab.entity import User
|
||||||
|
|
||||||
|
@ -14,6 +16,7 @@ class UserProvider:
|
||||||
user = await self.db.user.find_one({"name": user_name})
|
user = await self.db.user.find_one({"name": user_name})
|
||||||
if not user:
|
if not user:
|
||||||
user = User(name=user_name)
|
user = User(name=user_name)
|
||||||
res = await self.db.user.insert_one(user.model_dump())
|
res = await self.db.user.insert_one(user.model_dump(by_alias=True))
|
||||||
user = await self.db.user.find_one({"_id": res.inserted_id})
|
user = await self.db.user.find_one({"_id": res.inserted_id})
|
||||||
|
logger.debug(f"user: {user}")
|
||||||
return user
|
return user
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from simplylab.entity import GetAiChatResponseInput, GetUserChatHistoryInput, GetChatStatusTodayInput, UserChatMessage, \
|
from simplylab.entity import GetAiChatResponseInput, GetUserChatHistoryInput, GetChatStatusTodayInput, UserChatMessage, \
|
||||||
GetChatStatusTodayOutput, GetAiChatResponseOutput, GetUserChatHistoryOutput, Context
|
GetChatStatusTodayOutput, GetAiChatResponseOutput, GetUserChatHistoryOutput, Context, Message, MessageRoleType
|
||||||
|
from simplylab.error import MessageLimitedInDailyError
|
||||||
from simplylab.providers import Providers
|
from simplylab.providers import Providers
|
||||||
|
|
||||||
|
|
||||||
|
@ -11,15 +14,35 @@ class ChatService:
|
||||||
self.pvd = provider
|
self.pvd = provider
|
||||||
|
|
||||||
async def get_ai_chat_response(self, req: GetAiChatResponseInput) -> GetAiChatResponseOutput:
|
async def get_ai_chat_response(self, req: GetAiChatResponseInput) -> GetAiChatResponseOutput:
|
||||||
message = req.message
|
request_content = req.message
|
||||||
response_content = await self.pvd.openrouter.chat(content=message)
|
# todo: request content middle out
|
||||||
|
response_content = await self.pvd.openrouter.chat(content=request_content)
|
||||||
|
user_message = Message(
|
||||||
|
user_id=self.ctx.user.id,
|
||||||
|
type=MessageRoleType.User,
|
||||||
|
text=request_content,
|
||||||
|
created_by=self.ctx.user.id,
|
||||||
|
)
|
||||||
|
ai_message = Message(
|
||||||
|
user_id=self.ctx.user.id,
|
||||||
|
type=MessageRoleType.Ai,
|
||||||
|
text=response_content,
|
||||||
|
created_by=self.ctx.user.id,
|
||||||
|
)
|
||||||
|
messages = [user_message, ai_message]
|
||||||
|
count = await self.pvd.chat.add_chat_message(messages=messages)
|
||||||
|
logger.debug(f"Added {count} chat messages")
|
||||||
res = GetAiChatResponseOutput(response=response_content)
|
res = GetAiChatResponseOutput(response=response_content)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def get_user_chat_history(self, req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput:
|
async def get_user_chat_history(self, req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput:
|
||||||
res = [UserChatMessage(type="user", text="echo")]
|
messages = await self.pvd.chat.get_user_chat_messages(user_id=self.ctx.user.id, limit=req.last_n)
|
||||||
|
res = []
|
||||||
|
for message in messages:
|
||||||
|
res.append(UserChatMessage(type=message.type.value, text=message.text))
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def get_chat_status_today(self, req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput:
|
async def get_chat_status_today(self, req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput:
|
||||||
res = GetChatStatusTodayOutput(user_name=req.user_name, chat_cnt=0)
|
count = await self.pvd.chat.get_user_chat_messages_count_today(user_id=self.ctx.user.id)
|
||||||
|
res = GetChatStatusTodayOutput(user_name=self.ctx.user.name, chat_cnt=count)
|
||||||
return res
|
return res
|
||||||
|
|
Loading…
Reference in New Issue