This commit is contained in:
Jeremy Yin 2024-03-15 19:11:17 +08:00
parent cd33a36e28
commit bbb646ba92
7 changed files with 143 additions and 49 deletions

18
simplylab/database.py Normal file
View File

@ -0,0 +1,18 @@
import motor.motor_asyncio
class Database:
def __init__(self, client: motor.motor_asyncio.AsyncIOMotorClient):
self._client = client
@property
def user(self):
return self._client.get_database("simplylab").get_collection("user")
@property
def conversation(self):
return self._client.get_database("simplylab").get_collection("conversation")
@property
def message(self):
return self._client.get_database("simplylab").get_collection("message")

View File

@ -1,8 +1,9 @@
import datetime import datetime
from enum import Enum from enum import Enum
from typing import Optional, Annotated
from bson import ObjectId from bson import ObjectId
from pydantic import BaseModel from pydantic import BaseModel, Field, BeforeValidator, ConfigDict
class GetAiChatResponseInput(BaseModel): class GetAiChatResponseInput(BaseModel):
@ -37,52 +38,99 @@ class GetChatStatusTodayOutput(BaseModel):
# === mongodb documents start === # === mongodb documents start ===
PyObjectId = Annotated[str, BeforeValidator(ObjectId)]
class MessageRoleType(str, Enum): class MessageRoleType(str, Enum):
User = "user" User = "user"
Ai = "ai" Ai = "ai"
class User(BaseModel): class User(BaseModel):
_id: ObjectId id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
name: str name: str = Field(min_length=3, max_length=100, description="user name")
created_at: datetime.datetime created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
updated_at: datetime.datetime updated_at: Optional[datetime.datetime] = Field(default=None)
model_config = ConfigDict(
populate_by_name=True,
json_schema_extra={
"example": {
"_id": "xxx",
"name": "jdoe",
"created_at": datetime.datetime.now(),
"updated_at": None,
}
},
)
class Message(BaseModel): class Message(BaseModel):
_id: ObjectId id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
conversation_id: ObjectId conversation_id: PyObjectId = Field()
user_id: ObjectId user_id: PyObjectId = Field()
type: MessageRoleType type: MessageRoleType = Field()
text: str text: str = Field()
created_at: datetime.datetime created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
created_by: ObjectId created_by: PyObjectId = Field()
updated_at: datetime.datetime updated_at: Optional[datetime.datetime] = Field(default=None)
updated_by: ObjectId updated_by: Optional[PyObjectId] = Field(default=None)
model_config = ConfigDict(
populate_by_name=True,
json_schema_extra={
"example": {
"_id": "xxx",
"conversation_id": "xxx",
"user_id": "xxx",
"type": MessageRoleType.User,
"text": "xxx",
"created_at": datetime.datetime.now(),
"created_by": "xxx",
"updated_at": None,
"updated_by": None,
}
},
)
class Conversation(BaseModel): class Conversation(BaseModel):
_id: ObjectId id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
user_id: ObjectId user_id: PyObjectId = Field()
title: str title: str = Field()
created_at: datetime.datetime created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
created_by: ObjectId created_by: PyObjectId = Field()
updated_at: datetime.datetime updated_at: Optional[datetime.datetime] = Field(default=None)
updated_by: ObjectId updated_by: Optional[PyObjectId] = Field(default=None)
model_config = ConfigDict(
populate_by_name=True,
json_schema_extra={
"example": {
"_id": "xxx",
"user_id": "xxx",
"title": "xx",
"created_at": datetime.datetime.now(),
"created_by": "xxx",
"updated_at": None,
"updated_by": None,
}
},
)
# === mongodb documents end === # === mongodb documents end ===
class UserConversationMessages(BaseModel): class UserConversationMessages(BaseModel):
user_id: ObjectId user_id: PyObjectId = Field()
user_name: str user_name: str = Field()
conversation_id: ObjectId conversation_id: PyObjectId = Field()
title: str title: str = Field()
created_at: datetime.datetime created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
created_by: ObjectId created_by: PyObjectId = Field()
updated_at: datetime.datetime updated_at: Optional[datetime.datetime] = Field(default=None)
updated_by: ObjectId updated_by: Optional[PyObjectId] = Field(default=None)
messages: list[Message] messages: list[Message]

View File

@ -1,11 +1,13 @@
import os import os
from typing import Union from typing import Union
import sentry_sdk import sentry_sdk
import motor.motor_asyncio
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from dotenv import load_dotenv from dotenv import load_dotenv
from loguru import logger
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from simplylab.database import Database
from simplylab.entity import GetAiChatResponseInput, Context from simplylab.entity import GetAiChatResponseInput, Context
from simplylab.entity import GetAiChatResponseOutput from simplylab.entity import GetAiChatResponseOutput
from simplylab.entity import GetUserChatHistoryInput from simplylab.entity import GetUserChatHistoryInput
@ -30,6 +32,22 @@ sentry_sdk.init(
app = FastAPI() app = FastAPI()
@app.on_event("startup")
def startup_db_client():
mongo_username = os.getenv("MONGO_USERNAME")
mongo_password = os.getenv("MONGO_PASSWORD")
mongo_uri = f"mongodb://{mongo_username}:{mongo_password}@mongodb:27017/"
app.mongodb_client = motor.motor_asyncio.AsyncIOMotorClient(mongo_uri)
app.db = Database(client=app.mongodb_client)
logger.info("Connected to the MongoDB database!")
@app.on_event("shutdown")
def shutdown_db_client():
app.mongodb_client.close()
logger.info("Disconnected to the MongoDB database!")
@app.exception_handler(Error) @app.exception_handler(Error)
async def error_handler(request: Request, exc: Error): async def error_handler(request: Request, exc: Error):
return JSONResponse( return JSONResponse(
@ -44,8 +62,8 @@ async def hi():
@app.post("/api/v1/get_ai_chat_response") @app.post("/api/v1/get_ai_chat_response")
async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponseOutput: async def get_ai_chat_response(request: Request, req: GetAiChatResponseInput) -> GetAiChatResponseOutput:
pvd = Providers() pvd = Providers(db=request.app.db)
user = await pvd.user.get_user_by_name(req.user_name) user = await pvd.user.get_user_by_name(req.user_name)
if not user: if not user:
raise UserNotFoundError(req.user_name) raise UserNotFoundError(req.user_name)
@ -56,8 +74,8 @@ async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponse
@app.post("/api/v1/get_user_chat_history") @app.post("/api/v1/get_user_chat_history")
async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput: async def get_user_chat_history(request: Request, req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput:
pvd = Providers() pvd = Providers(db=request.app.db)
user = await pvd.user.get_user_by_name(req.user_name) user = await pvd.user.get_user_by_name(req.user_name)
if not user: if not user:
raise UserNotFoundError(req.user_name) raise UserNotFoundError(req.user_name)
@ -68,8 +86,8 @@ async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHist
@app.post("/api/v1/get_chat_status_today") @app.post("/api/v1/get_chat_status_today")
async def get_chat_status_today(req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput: async def get_chat_status_today(request: Request, req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput:
pvd = Providers() pvd = Providers(db=request.app.db)
user = await pvd.user.get_user_by_name(req.user_name) user = await pvd.user.get_user_by_name(req.user_name)
if not user: if not user:
raise UserNotFoundError(req.user_name) raise UserNotFoundError(req.user_name)

View File

@ -1,3 +1,4 @@
from simplylab.database import Database
from simplylab.providers.chat import ChatProvider from simplylab.providers.chat import ChatProvider
from simplylab.providers.openrouter import OpenRouterProvider from simplylab.providers.openrouter import OpenRouterProvider
from simplylab.providers.user import UserProvider from simplylab.providers.user import UserProvider
@ -5,17 +6,17 @@ from simplylab.providers.user import UserProvider
class Providers: class Providers:
def __init__(self): def __init__(self, db: Database) -> None:
... self.db = db
@property @property
def openrouter(self): def openrouter(self):
return OpenRouterProvider() return OpenRouterProvider(db=self.db)
@property @property
def user(self): def user(self):
return UserProvider() return UserProvider(db=self.db)
@property @property
def chat(self): def chat(self):
return ChatProvider() return ChatProvider(db=self.db)

View File

@ -1,11 +1,12 @@
from typing import Optional from typing import Optional
from simplylab.database import Database
from simplylab.entity import UserConversationMessages from simplylab.entity import UserConversationMessages
class ChatProvider: class ChatProvider:
def __init__(self): def __init__(self, db: Database) -> None:
... self.db = db
async def check_user_message_limited_in_30_seconds(self, user_id: str) -> bool: async def check_user_message_limited_in_30_seconds(self, user_id: str) -> bool:
... ...

View File

@ -2,10 +2,12 @@ from openai import OpenAI
from os import getenv from os import getenv
from loguru import logger from loguru import logger
from simplylab.database import Database
class OpenRouterProvider: class OpenRouterProvider:
def __init__(self): def __init__(self, db: Database) -> None:
... self.db = db
async def chat(self, content: str) -> str: async def chat(self, content: str) -> str:
# gets API Key from environment variable OPENAI_API_KEY # gets API Key from environment variable OPENAI_API_KEY

View File

@ -1,12 +1,18 @@
import datetime
from typing import Optional from typing import Optional
from simplylab.database import Database
from simplylab.entity import User from simplylab.entity import User
class UserProvider: class UserProvider:
def __init__(self): def __init__(self, db: Database) -> None:
... self.db = db
def get_user_by_name(self, user_name: str) -> Optional[User]: async def get_user_by_name(self, user_name: str) -> Optional[User]:
... user = await self.db.user.find_one({"name": user_name})
if not user:
user = User(name=user_name)
user = await self.db.user.insert_one(user.model_dump())
return user