From 8edf8f9fd6f13db94ab705aff39bbb4c45519380 Mon Sep 17 00:00:00 2001 From: Jeremy Yin Date: Fri, 15 Mar 2024 17:22:29 +0800 Subject: [PATCH] define user chat func and context be user --- simplylab/entity.py | 62 ++++++++++++++++++++++++++++++--- simplylab/error.py | 7 ++++ simplylab/main.py | 17 ++++++--- simplylab/providers/__init__.py | 10 ++++++ simplylab/providers/chat.py | 20 +++++++++++ simplylab/providers/user.py | 12 +++++++ 6 files changed, 120 insertions(+), 8 deletions(-) create mode 100644 simplylab/providers/chat.py create mode 100644 simplylab/providers/user.py diff --git a/simplylab/entity.py b/simplylab/entity.py index 2adfe80..12085fe 100644 --- a/simplylab/entity.py +++ b/simplylab/entity.py @@ -1,10 +1,10 @@ +import datetime +from enum import Enum + +from bson import ObjectId from pydantic import BaseModel -class Context(BaseModel): - user_name: str - - class GetAiChatResponseInput(BaseModel): message: str user_name: str @@ -34,3 +34,57 @@ class GetChatStatusTodayInput(BaseModel): class GetChatStatusTodayOutput(BaseModel): user_name: str chat_cnt: int + + +# === mongodb documents start === +class MessageRoleType(str, Enum): + User = "user" + Ai = "ai" + + +class User(BaseModel): + _id: ObjectId + name: str + created_at: datetime.datetime + updated_at: datetime.datetime + + +class Message(BaseModel): + _id: ObjectId + conversation_id: ObjectId + user_id: ObjectId + type: MessageRoleType + text: str + created_at: datetime.datetime + created_by: ObjectId + updated_at: datetime.datetime + updated_by: ObjectId + + +class Conversation(BaseModel): + _id: ObjectId + user_id: ObjectId + title: str + created_at: datetime.datetime + created_by: ObjectId + updated_at: datetime.datetime + updated_by: ObjectId + + +# === mongodb documents end === + + +class UserConversationMessages(BaseModel): + user_id: ObjectId + user_name: str + conversation_id: ObjectId + title: str + created_at: datetime.datetime + created_by: ObjectId + updated_at: datetime.datetime + updated_by: ObjectId + messages: list[Message] + + +class Context(BaseModel): + user: User diff --git a/simplylab/error.py b/simplylab/error.py index 0344d70..ffe4560 100644 --- a/simplylab/error.py +++ b/simplylab/error.py @@ -14,3 +14,10 @@ class MessageLimitedInDailyError(Error): def __init__(self): super().__init__("20 messages limited in daily") self.status_code = 401 + + +class UserNotFoundError(Error): + + def __init__(self, user_name: str): + super().__init__("user {} not found".format(user_name)) + self.status_code = 401 diff --git a/simplylab/main.py b/simplylab/main.py index db8c431..c9d6758 100644 --- a/simplylab/main.py +++ b/simplylab/main.py @@ -10,7 +10,7 @@ from simplylab.entity import GetUserChatHistoryInput from simplylab.entity import GetUserChatHistoryOutput from simplylab.entity import GetChatStatusTodayInput from simplylab.entity import GetChatStatusTodayOutput -from simplylab.error import Error +from simplylab.error import Error, UserNotFoundError from simplylab.providers import Providers from simplylab.services import Services @@ -33,8 +33,11 @@ async def hi(): @app.post("/api/v1/get_ai_chat_response") async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponseOutput: - ctx = Context(user_name=req.user_name) pvd = Providers() + user = await pvd.user.get_user_by_name(req.user_name) + if not user: + raise UserNotFoundError(req.user_name) + ctx = Context(user=user) svc = Services(ctx, pvd) res = await svc.chat.get_ai_chat_response(req) return res @@ -42,8 +45,11 @@ async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponse @app.post("/api/v1/get_user_chat_history") async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput: - ctx = Context(user_name=req.user_name) pvd = Providers() + user = await pvd.user.get_user_by_name(req.user_name) + if not user: + raise UserNotFoundError(req.user_name) + ctx = Context(user=user) svc = Services(ctx, pvd) res = await svc.chat.get_user_chat_history(req) return res @@ -51,8 +57,11 @@ async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHist @app.post("/api/v1/get_chat_status_today") async def get_chat_status_today(req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput: - ctx = Context(user_name=req.user_name) pvd = Providers() + user = await pvd.user.get_user_by_name(req.user_name) + if not user: + raise UserNotFoundError(req.user_name) + ctx = Context(user=user) svc = Services(ctx, pvd) res = await svc.chat.get_chat_status_today(req) return res diff --git a/simplylab/providers/__init__.py b/simplylab/providers/__init__.py index f053819..4d1e2ca 100644 --- a/simplylab/providers/__init__.py +++ b/simplylab/providers/__init__.py @@ -1,4 +1,6 @@ +from simplylab.providers.chat import ChatProvider from simplylab.providers.openrouter import OpenRouterProvider +from simplylab.providers.user import UserProvider class Providers: @@ -9,3 +11,11 @@ class Providers: @property def openrouter(self): return OpenRouterProvider() + + @property + def user(self): + return UserProvider() + + @property + def chat(self): + return ChatProvider() diff --git a/simplylab/providers/chat.py b/simplylab/providers/chat.py new file mode 100644 index 0000000..f5a2a1f --- /dev/null +++ b/simplylab/providers/chat.py @@ -0,0 +1,20 @@ +from typing import Optional + +from simplylab.entity import UserConversationMessages + + +class ChatProvider: + def __init__(self): + ... + + async def check_user_message_limited_in_30_seconds(self, user_id: str) -> bool: + ... + + async def check_user_message_limited_in_daily(self, user_id: str) -> bool: + ... + + async def get_user_conversation_messages(self, user_id, conversation_id) -> Optional[UserConversationMessages]: + ... + + async def get_user_chat_messages_count_today(self, user_id: str) -> int: + ... \ No newline at end of file diff --git a/simplylab/providers/user.py b/simplylab/providers/user.py new file mode 100644 index 0000000..0ab1337 --- /dev/null +++ b/simplylab/providers/user.py @@ -0,0 +1,12 @@ +from typing import Optional + +from simplylab.entity import User + + +class UserProvider: + + def __init__(self): + ... + + def get_user_by_name(self, user_name: str) -> Optional[User]: + ...