define user chat func and context be user
This commit is contained in:
parent
14c6b5933c
commit
8edf8f9fd6
|
@ -1,10 +1,10 @@
|
|||
import datetime
|
||||
from enum import Enum
|
||||
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Context(BaseModel):
|
||||
user_name: str
|
||||
|
||||
|
||||
class GetAiChatResponseInput(BaseModel):
|
||||
message: str
|
||||
user_name: str
|
||||
|
@ -34,3 +34,57 @@ class GetChatStatusTodayInput(BaseModel):
|
|||
class GetChatStatusTodayOutput(BaseModel):
|
||||
user_name: str
|
||||
chat_cnt: int
|
||||
|
||||
|
||||
# === mongodb documents start ===
|
||||
class MessageRoleType(str, Enum):
|
||||
User = "user"
|
||||
Ai = "ai"
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
_id: ObjectId
|
||||
name: str
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
_id: ObjectId
|
||||
conversation_id: ObjectId
|
||||
user_id: ObjectId
|
||||
type: MessageRoleType
|
||||
text: str
|
||||
created_at: datetime.datetime
|
||||
created_by: ObjectId
|
||||
updated_at: datetime.datetime
|
||||
updated_by: ObjectId
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
_id: ObjectId
|
||||
user_id: ObjectId
|
||||
title: str
|
||||
created_at: datetime.datetime
|
||||
created_by: ObjectId
|
||||
updated_at: datetime.datetime
|
||||
updated_by: ObjectId
|
||||
|
||||
|
||||
# === mongodb documents end ===
|
||||
|
||||
|
||||
class UserConversationMessages(BaseModel):
|
||||
user_id: ObjectId
|
||||
user_name: str
|
||||
conversation_id: ObjectId
|
||||
title: str
|
||||
created_at: datetime.datetime
|
||||
created_by: ObjectId
|
||||
updated_at: datetime.datetime
|
||||
updated_by: ObjectId
|
||||
messages: list[Message]
|
||||
|
||||
|
||||
class Context(BaseModel):
|
||||
user: User
|
||||
|
|
|
@ -14,3 +14,10 @@ class MessageLimitedInDailyError(Error):
|
|||
def __init__(self):
|
||||
super().__init__("20 messages limited in daily")
|
||||
self.status_code = 401
|
||||
|
||||
|
||||
class UserNotFoundError(Error):
|
||||
|
||||
def __init__(self, user_name: str):
|
||||
super().__init__("user {} not found".format(user_name))
|
||||
self.status_code = 401
|
||||
|
|
|
@ -10,7 +10,7 @@ from simplylab.entity import GetUserChatHistoryInput
|
|||
from simplylab.entity import GetUserChatHistoryOutput
|
||||
from simplylab.entity import GetChatStatusTodayInput
|
||||
from simplylab.entity import GetChatStatusTodayOutput
|
||||
from simplylab.error import Error
|
||||
from simplylab.error import Error, UserNotFoundError
|
||||
from simplylab.providers import Providers
|
||||
from simplylab.services import Services
|
||||
|
||||
|
@ -33,8 +33,11 @@ async def hi():
|
|||
|
||||
@app.post("/api/v1/get_ai_chat_response")
|
||||
async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponseOutput:
|
||||
ctx = Context(user_name=req.user_name)
|
||||
pvd = Providers()
|
||||
user = await pvd.user.get_user_by_name(req.user_name)
|
||||
if not user:
|
||||
raise UserNotFoundError(req.user_name)
|
||||
ctx = Context(user=user)
|
||||
svc = Services(ctx, pvd)
|
||||
res = await svc.chat.get_ai_chat_response(req)
|
||||
return res
|
||||
|
@ -42,8 +45,11 @@ async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponse
|
|||
|
||||
@app.post("/api/v1/get_user_chat_history")
|
||||
async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput:
|
||||
ctx = Context(user_name=req.user_name)
|
||||
pvd = Providers()
|
||||
user = await pvd.user.get_user_by_name(req.user_name)
|
||||
if not user:
|
||||
raise UserNotFoundError(req.user_name)
|
||||
ctx = Context(user=user)
|
||||
svc = Services(ctx, pvd)
|
||||
res = await svc.chat.get_user_chat_history(req)
|
||||
return res
|
||||
|
@ -51,8 +57,11 @@ async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHist
|
|||
|
||||
@app.post("/api/v1/get_chat_status_today")
|
||||
async def get_chat_status_today(req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput:
|
||||
ctx = Context(user_name=req.user_name)
|
||||
pvd = Providers()
|
||||
user = await pvd.user.get_user_by_name(req.user_name)
|
||||
if not user:
|
||||
raise UserNotFoundError(req.user_name)
|
||||
ctx = Context(user=user)
|
||||
svc = Services(ctx, pvd)
|
||||
res = await svc.chat.get_chat_status_today(req)
|
||||
return res
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
from simplylab.providers.chat import ChatProvider
|
||||
from simplylab.providers.openrouter import OpenRouterProvider
|
||||
from simplylab.providers.user import UserProvider
|
||||
|
||||
|
||||
class Providers:
|
||||
|
@ -9,3 +11,11 @@ class Providers:
|
|||
@property
|
||||
def openrouter(self):
|
||||
return OpenRouterProvider()
|
||||
|
||||
@property
|
||||
def user(self):
|
||||
return UserProvider()
|
||||
|
||||
@property
|
||||
def chat(self):
|
||||
return ChatProvider()
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
from typing import Optional
|
||||
|
||||
from simplylab.entity import UserConversationMessages
|
||||
|
||||
|
||||
class ChatProvider:
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
async def check_user_message_limited_in_30_seconds(self, user_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def check_user_message_limited_in_daily(self, user_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def get_user_conversation_messages(self, user_id, conversation_id) -> Optional[UserConversationMessages]:
|
||||
...
|
||||
|
||||
async def get_user_chat_messages_count_today(self, user_id: str) -> int:
|
||||
...
|
|
@ -0,0 +1,12 @@
|
|||
from typing import Optional
|
||||
|
||||
from simplylab.entity import User
|
||||
|
||||
|
||||
class UserProvider:
|
||||
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
def get_user_by_name(self, user_name: str) -> Optional[User]:
|
||||
...
|
Loading…
Reference in New Issue