refactor: model

This commit is contained in:
Jeremy Yin 2024-03-19 20:29:22 +08:00
parent 683b276c5b
commit aae5851178
10 changed files with 71 additions and 75 deletions

View File

@ -8,12 +8,10 @@ from loguru import logger
from starlette.responses import JSONResponse
from simplylab.database import Database
from simplylab.entity import GetAiChatResponseInput, Context
from simplylab.entity import GetAiChatResponseOutput
from simplylab.entity import GetUserChatHistoryInput
from simplylab.entity import GetUserChatHistoryOutput
from simplylab.entity import GetChatStatusTodayInput
from simplylab.entity import GetChatStatusTodayOutput
from simplylab.model.entity import Context
from simplylab.model.req import GetAiChatResponseInput, GetUserChatHistoryInput, GetChatStatusTodayInput
from simplylab.model.res import GetAiChatResponseOutput, GetChatStatusTodayOutput
from simplylab.model.res import GetUserChatHistoryOutput
from simplylab.error import Error, UserNotFoundError
from simplylab.providers import Providers
from simplylab.services import Services

View File

View File

@ -0,0 +1,7 @@
from pydantic import BaseModel
from simplylab.model.table import User
class Context(BaseModel):
user: User

15
simplylab/model/req.py Normal file
View File

@ -0,0 +1,15 @@
from pydantic import BaseModel
class GetAiChatResponseInput(BaseModel):
message: str
user_name: str
class GetUserChatHistoryInput(BaseModel):
user_name: str
last_n: int
class GetChatStatusTodayInput(BaseModel):
user_name: str

30
simplylab/model/res.py Normal file
View File

@ -0,0 +1,30 @@
from pydantic import BaseModel
class GetAiChatResponseOutput(BaseModel):
response: str
class GetChatStatusTodayOutput(BaseModel):
user_name: str
chat_cnt: int
class UserChatMessage(BaseModel):
type: str
text: str
type GetUserChatHistoryOutput = list[UserChatMessage]
# class UserConversationMessages(BaseModel):
# user_id: ObjectIdField = Field()
# user_name: str = Field()
# conversation_id: ObjectIdField = Field(default=None)
# title: str = Field()
# created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
# created_by: ObjectIdField = Field()
# updated_at: Optional[datetime.datetime] = Field(default=None)
# updated_by: Optional[ObjectIdField] = Field(default=None)
# messages: list[Message]

View File

@ -1,49 +1,13 @@
import datetime
from enum import Enum
from typing import Optional, Annotated, Any
from typing import Annotated, Any, Optional
from bson import ObjectId
from pydantic import BaseModel, Field, BeforeValidator, ConfigDict
class GetAiChatResponseInput(BaseModel):
message: str
user_name: str
class GetAiChatResponseOutput(BaseModel):
response: str
class GetUserChatHistoryInput(BaseModel):
user_name: str
last_n: int
class UserChatMessage(BaseModel):
type: str
text: str
type GetUserChatHistoryOutput = list[UserChatMessage]
class GetChatStatusTodayInput(BaseModel):
user_name: str
class GetChatStatusTodayOutput(BaseModel):
user_name: str
chat_cnt: int
# === mongodb documents start ===
from pydantic import (AfterValidator, GetPydanticSchema,
PlainSerializer, WithJsonSchema)
from pydantic import AfterValidator, PlainSerializer, WithJsonSchema, GetPydanticSchema, BaseModel, Field, ConfigDict
from pydantic_mongo import ObjectIdField as _objectIdField
# ObjectIdField = Annotated[str, BeforeValidator(ObjectId)]
ObjectIdField = Annotated[
_objectIdField,
AfterValidator(lambda id: _objectIdField(id)),
@ -58,9 +22,6 @@ class Test(BaseModel):
id: ObjectIdField = Field(default_factory=ObjectIdField, alias='_id', title='_id')
# ObjectIdField = Annotated[str, BeforeValidator(ObjectId)]
class MessageRoleType(str, Enum):
User = "user"
Ai = "ai"
@ -117,7 +78,6 @@ class Message(BaseModel):
},
)
# class Conversation(BaseModel):
# id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id')
# user_id: ObjectIdField = Field()
@ -141,22 +101,3 @@ class Message(BaseModel):
# }
# },
# )
# === mongodb documents end ===
# class UserConversationMessages(BaseModel):
# user_id: ObjectIdField = Field()
# user_name: str = Field()
# conversation_id: ObjectIdField = Field(default=None)
# title: str = Field()
# created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
# created_by: ObjectIdField = Field()
# updated_at: Optional[datetime.datetime] = Field(default=None)
# updated_by: Optional[ObjectIdField] = Field(default=None)
# messages: list[Message]
class Context(BaseModel):
user: User

View File

@ -5,7 +5,8 @@ import pymongo
from loguru import logger
from simplylab.database import Database
from simplylab.entity import ObjectIdField, Message
from simplylab.model.table import ObjectIdField
from simplylab.model.table import Message
class ChatProvider:

View File

@ -4,7 +4,7 @@ from typing import Optional
from loguru import logger
from simplylab.database import Database
from simplylab.entity import User
from simplylab.model.table import User
class UserProvider:

View File

@ -1,6 +1,6 @@
from typing import Any
from simplylab.entity import Context
from simplylab.model.entity import Context
from simplylab.providers import Providers
from simplylab.services.chat import ChatService

View File

@ -2,8 +2,12 @@ from typing import Any
from loguru import logger
from simplylab.entity import GetAiChatResponseInput, GetUserChatHistoryInput, GetChatStatusTodayInput, UserChatMessage, \
GetChatStatusTodayOutput, GetAiChatResponseOutput, GetUserChatHistoryOutput, Context, Message, MessageRoleType
from simplylab.model.res import UserChatMessage, \
GetUserChatHistoryOutput
from simplylab.model.entity import Context
from simplylab.model.res import GetAiChatResponseOutput, GetChatStatusTodayOutput
from simplylab.model.req import GetAiChatResponseInput, GetUserChatHistoryInput, GetChatStatusTodayInput
from simplylab.model.table import MessageRoleType, Message
from simplylab.error import MessageLimitedInDailyError, MessageLimitedIn30SecondsError
from simplylab.providers import Providers