define user chat func and context be user
This commit is contained in:
parent
14c6b5933c
commit
8edf8f9fd6
|
@ -1,10 +1,10 @@
|
||||||
|
import datetime
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from bson import ObjectId
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class Context(BaseModel):
|
|
||||||
user_name: str
|
|
||||||
|
|
||||||
|
|
||||||
class GetAiChatResponseInput(BaseModel):
|
class GetAiChatResponseInput(BaseModel):
|
||||||
message: str
|
message: str
|
||||||
user_name: str
|
user_name: str
|
||||||
|
@ -34,3 +34,57 @@ class GetChatStatusTodayInput(BaseModel):
|
||||||
class GetChatStatusTodayOutput(BaseModel):
|
class GetChatStatusTodayOutput(BaseModel):
|
||||||
user_name: str
|
user_name: str
|
||||||
chat_cnt: int
|
chat_cnt: int
|
||||||
|
|
||||||
|
|
||||||
|
# === mongodb documents start ===
|
||||||
|
class MessageRoleType(str, Enum):
|
||||||
|
User = "user"
|
||||||
|
Ai = "ai"
|
||||||
|
|
||||||
|
|
||||||
|
class User(BaseModel):
|
||||||
|
_id: ObjectId
|
||||||
|
name: str
|
||||||
|
created_at: datetime.datetime
|
||||||
|
updated_at: datetime.datetime
|
||||||
|
|
||||||
|
|
||||||
|
class Message(BaseModel):
|
||||||
|
_id: ObjectId
|
||||||
|
conversation_id: ObjectId
|
||||||
|
user_id: ObjectId
|
||||||
|
type: MessageRoleType
|
||||||
|
text: str
|
||||||
|
created_at: datetime.datetime
|
||||||
|
created_by: ObjectId
|
||||||
|
updated_at: datetime.datetime
|
||||||
|
updated_by: ObjectId
|
||||||
|
|
||||||
|
|
||||||
|
class Conversation(BaseModel):
|
||||||
|
_id: ObjectId
|
||||||
|
user_id: ObjectId
|
||||||
|
title: str
|
||||||
|
created_at: datetime.datetime
|
||||||
|
created_by: ObjectId
|
||||||
|
updated_at: datetime.datetime
|
||||||
|
updated_by: ObjectId
|
||||||
|
|
||||||
|
|
||||||
|
# === mongodb documents end ===
|
||||||
|
|
||||||
|
|
||||||
|
class UserConversationMessages(BaseModel):
|
||||||
|
user_id: ObjectId
|
||||||
|
user_name: str
|
||||||
|
conversation_id: ObjectId
|
||||||
|
title: str
|
||||||
|
created_at: datetime.datetime
|
||||||
|
created_by: ObjectId
|
||||||
|
updated_at: datetime.datetime
|
||||||
|
updated_by: ObjectId
|
||||||
|
messages: list[Message]
|
||||||
|
|
||||||
|
|
||||||
|
class Context(BaseModel):
|
||||||
|
user: User
|
||||||
|
|
|
@ -14,3 +14,10 @@ class MessageLimitedInDailyError(Error):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("20 messages limited in daily")
|
super().__init__("20 messages limited in daily")
|
||||||
self.status_code = 401
|
self.status_code = 401
|
||||||
|
|
||||||
|
|
||||||
|
class UserNotFoundError(Error):
|
||||||
|
|
||||||
|
def __init__(self, user_name: str):
|
||||||
|
super().__init__("user {} not found".format(user_name))
|
||||||
|
self.status_code = 401
|
||||||
|
|
|
@ -10,7 +10,7 @@ from simplylab.entity import GetUserChatHistoryInput
|
||||||
from simplylab.entity import GetUserChatHistoryOutput
|
from simplylab.entity import GetUserChatHistoryOutput
|
||||||
from simplylab.entity import GetChatStatusTodayInput
|
from simplylab.entity import GetChatStatusTodayInput
|
||||||
from simplylab.entity import GetChatStatusTodayOutput
|
from simplylab.entity import GetChatStatusTodayOutput
|
||||||
from simplylab.error import Error
|
from simplylab.error import Error, UserNotFoundError
|
||||||
from simplylab.providers import Providers
|
from simplylab.providers import Providers
|
||||||
from simplylab.services import Services
|
from simplylab.services import Services
|
||||||
|
|
||||||
|
@ -33,8 +33,11 @@ async def hi():
|
||||||
|
|
||||||
@app.post("/api/v1/get_ai_chat_response")
|
@app.post("/api/v1/get_ai_chat_response")
|
||||||
async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponseOutput:
|
async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponseOutput:
|
||||||
ctx = Context(user_name=req.user_name)
|
|
||||||
pvd = Providers()
|
pvd = Providers()
|
||||||
|
user = await pvd.user.get_user_by_name(req.user_name)
|
||||||
|
if not user:
|
||||||
|
raise UserNotFoundError(req.user_name)
|
||||||
|
ctx = Context(user=user)
|
||||||
svc = Services(ctx, pvd)
|
svc = Services(ctx, pvd)
|
||||||
res = await svc.chat.get_ai_chat_response(req)
|
res = await svc.chat.get_ai_chat_response(req)
|
||||||
return res
|
return res
|
||||||
|
@ -42,8 +45,11 @@ async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponse
|
||||||
|
|
||||||
@app.post("/api/v1/get_user_chat_history")
|
@app.post("/api/v1/get_user_chat_history")
|
||||||
async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput:
|
async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput:
|
||||||
ctx = Context(user_name=req.user_name)
|
|
||||||
pvd = Providers()
|
pvd = Providers()
|
||||||
|
user = await pvd.user.get_user_by_name(req.user_name)
|
||||||
|
if not user:
|
||||||
|
raise UserNotFoundError(req.user_name)
|
||||||
|
ctx = Context(user=user)
|
||||||
svc = Services(ctx, pvd)
|
svc = Services(ctx, pvd)
|
||||||
res = await svc.chat.get_user_chat_history(req)
|
res = await svc.chat.get_user_chat_history(req)
|
||||||
return res
|
return res
|
||||||
|
@ -51,8 +57,11 @@ async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHist
|
||||||
|
|
||||||
@app.post("/api/v1/get_chat_status_today")
|
@app.post("/api/v1/get_chat_status_today")
|
||||||
async def get_chat_status_today(req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput:
|
async def get_chat_status_today(req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput:
|
||||||
ctx = Context(user_name=req.user_name)
|
|
||||||
pvd = Providers()
|
pvd = Providers()
|
||||||
|
user = await pvd.user.get_user_by_name(req.user_name)
|
||||||
|
if not user:
|
||||||
|
raise UserNotFoundError(req.user_name)
|
||||||
|
ctx = Context(user=user)
|
||||||
svc = Services(ctx, pvd)
|
svc = Services(ctx, pvd)
|
||||||
res = await svc.chat.get_chat_status_today(req)
|
res = await svc.chat.get_chat_status_today(req)
|
||||||
return res
|
return res
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
|
from simplylab.providers.chat import ChatProvider
|
||||||
from simplylab.providers.openrouter import OpenRouterProvider
|
from simplylab.providers.openrouter import OpenRouterProvider
|
||||||
|
from simplylab.providers.user import UserProvider
|
||||||
|
|
||||||
|
|
||||||
class Providers:
|
class Providers:
|
||||||
|
@ -9,3 +11,11 @@ class Providers:
|
||||||
@property
|
@property
|
||||||
def openrouter(self):
|
def openrouter(self):
|
||||||
return OpenRouterProvider()
|
return OpenRouterProvider()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user(self):
|
||||||
|
return UserProvider()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def chat(self):
|
||||||
|
return ChatProvider()
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from simplylab.entity import UserConversationMessages
|
||||||
|
|
||||||
|
|
||||||
|
class ChatProvider:
|
||||||
|
def __init__(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
async def check_user_message_limited_in_30_seconds(self, user_id: str) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def check_user_message_limited_in_daily(self, user_id: str) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_user_conversation_messages(self, user_id, conversation_id) -> Optional[UserConversationMessages]:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_user_chat_messages_count_today(self, user_id: str) -> int:
|
||||||
|
...
|
|
@ -0,0 +1,12 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from simplylab.entity import User
|
||||||
|
|
||||||
|
|
||||||
|
class UserProvider:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_user_by_name(self, user_name: str) -> Optional[User]:
|
||||||
|
...
|
Loading…
Reference in New Issue