define user chat func and context be user

This commit is contained in:
Jeremy Yin 2024-03-15 17:22:29 +08:00
parent 14c6b5933c
commit 8edf8f9fd6
6 changed files with 120 additions and 8 deletions

View File

@ -1,10 +1,10 @@
import datetime
from enum import Enum
from bson import ObjectId
from pydantic import BaseModel from pydantic import BaseModel
class Context(BaseModel):
user_name: str
class GetAiChatResponseInput(BaseModel): class GetAiChatResponseInput(BaseModel):
message: str message: str
user_name: str user_name: str
@ -34,3 +34,57 @@ class GetChatStatusTodayInput(BaseModel):
class GetChatStatusTodayOutput(BaseModel): class GetChatStatusTodayOutput(BaseModel):
user_name: str user_name: str
chat_cnt: int chat_cnt: int
# === mongodb documents start ===
class MessageRoleType(str, Enum):
User = "user"
Ai = "ai"
class User(BaseModel):
_id: ObjectId
name: str
created_at: datetime.datetime
updated_at: datetime.datetime
class Message(BaseModel):
_id: ObjectId
conversation_id: ObjectId
user_id: ObjectId
type: MessageRoleType
text: str
created_at: datetime.datetime
created_by: ObjectId
updated_at: datetime.datetime
updated_by: ObjectId
class Conversation(BaseModel):
_id: ObjectId
user_id: ObjectId
title: str
created_at: datetime.datetime
created_by: ObjectId
updated_at: datetime.datetime
updated_by: ObjectId
# === mongodb documents end ===
class UserConversationMessages(BaseModel):
user_id: ObjectId
user_name: str
conversation_id: ObjectId
title: str
created_at: datetime.datetime
created_by: ObjectId
updated_at: datetime.datetime
updated_by: ObjectId
messages: list[Message]
class Context(BaseModel):
user: User

View File

@ -14,3 +14,10 @@ class MessageLimitedInDailyError(Error):
def __init__(self): def __init__(self):
super().__init__("20 messages limited in daily") super().__init__("20 messages limited in daily")
self.status_code = 401 self.status_code = 401
class UserNotFoundError(Error):
def __init__(self, user_name: str):
super().__init__("user {} not found".format(user_name))
self.status_code = 401

View File

@ -10,7 +10,7 @@ from simplylab.entity import GetUserChatHistoryInput
from simplylab.entity import GetUserChatHistoryOutput from simplylab.entity import GetUserChatHistoryOutput
from simplylab.entity import GetChatStatusTodayInput from simplylab.entity import GetChatStatusTodayInput
from simplylab.entity import GetChatStatusTodayOutput from simplylab.entity import GetChatStatusTodayOutput
from simplylab.error import Error from simplylab.error import Error, UserNotFoundError
from simplylab.providers import Providers from simplylab.providers import Providers
from simplylab.services import Services from simplylab.services import Services
@ -33,8 +33,11 @@ async def hi():
@app.post("/api/v1/get_ai_chat_response") @app.post("/api/v1/get_ai_chat_response")
async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponseOutput: async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponseOutput:
ctx = Context(user_name=req.user_name)
pvd = Providers() pvd = Providers()
user = await pvd.user.get_user_by_name(req.user_name)
if not user:
raise UserNotFoundError(req.user_name)
ctx = Context(user=user)
svc = Services(ctx, pvd) svc = Services(ctx, pvd)
res = await svc.chat.get_ai_chat_response(req) res = await svc.chat.get_ai_chat_response(req)
return res return res
@ -42,8 +45,11 @@ async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponse
@app.post("/api/v1/get_user_chat_history") @app.post("/api/v1/get_user_chat_history")
async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput: async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput:
ctx = Context(user_name=req.user_name)
pvd = Providers() pvd = Providers()
user = await pvd.user.get_user_by_name(req.user_name)
if not user:
raise UserNotFoundError(req.user_name)
ctx = Context(user=user)
svc = Services(ctx, pvd) svc = Services(ctx, pvd)
res = await svc.chat.get_user_chat_history(req) res = await svc.chat.get_user_chat_history(req)
return res return res
@ -51,8 +57,11 @@ async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHist
@app.post("/api/v1/get_chat_status_today") @app.post("/api/v1/get_chat_status_today")
async def get_chat_status_today(req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput: async def get_chat_status_today(req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput:
ctx = Context(user_name=req.user_name)
pvd = Providers() pvd = Providers()
user = await pvd.user.get_user_by_name(req.user_name)
if not user:
raise UserNotFoundError(req.user_name)
ctx = Context(user=user)
svc = Services(ctx, pvd) svc = Services(ctx, pvd)
res = await svc.chat.get_chat_status_today(req) res = await svc.chat.get_chat_status_today(req)
return res return res

View File

@ -1,4 +1,6 @@
from simplylab.providers.chat import ChatProvider
from simplylab.providers.openrouter import OpenRouterProvider from simplylab.providers.openrouter import OpenRouterProvider
from simplylab.providers.user import UserProvider
class Providers: class Providers:
@ -9,3 +11,11 @@ class Providers:
@property @property
def openrouter(self): def openrouter(self):
return OpenRouterProvider() return OpenRouterProvider()
@property
def user(self):
return UserProvider()
@property
def chat(self):
return ChatProvider()

View File

@ -0,0 +1,20 @@
from typing import Optional
from simplylab.entity import UserConversationMessages
class ChatProvider:
def __init__(self):
...
async def check_user_message_limited_in_30_seconds(self, user_id: str) -> bool:
...
async def check_user_message_limited_in_daily(self, user_id: str) -> bool:
...
async def get_user_conversation_messages(self, user_id, conversation_id) -> Optional[UserConversationMessages]:
...
async def get_user_chat_messages_count_today(self, user_id: str) -> int:
...

View File

@ -0,0 +1,12 @@
from typing import Optional
from simplylab.entity import User
class UserProvider:
def __init__(self):
...
def get_user_by_name(self, user_name: str) -> Optional[User]:
...