This commit is contained in:
Jeremy Yin 2024-03-15 19:11:17 +08:00
parent cd33a36e28
commit bbb646ba92
7 changed files with 143 additions and 49 deletions

18
simplylab/database.py Normal file
View File

@ -0,0 +1,18 @@
import motor.motor_asyncio
class Database:
def __init__(self, client: motor.motor_asyncio.AsyncIOMotorClient):
self._client = client
@property
def user(self):
return self._client.get_database("simplylab").get_collection("user")
@property
def conversation(self):
return self._client.get_database("simplylab").get_collection("conversation")
@property
def message(self):
return self._client.get_database("simplylab").get_collection("message")

View File

@ -1,8 +1,9 @@
import datetime
from enum import Enum
from typing import Optional, Annotated
from bson import ObjectId
from pydantic import BaseModel
from pydantic import BaseModel, Field, BeforeValidator, ConfigDict
class GetAiChatResponseInput(BaseModel):
@ -37,52 +38,99 @@ class GetChatStatusTodayOutput(BaseModel):
# === mongodb documents start ===
PyObjectId = Annotated[str, BeforeValidator(ObjectId)]
class MessageRoleType(str, Enum):
User = "user"
Ai = "ai"
class User(BaseModel):
_id: ObjectId
name: str
created_at: datetime.datetime
updated_at: datetime.datetime
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
name: str = Field(min_length=3, max_length=100, description="user name")
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
updated_at: Optional[datetime.datetime] = Field(default=None)
model_config = ConfigDict(
populate_by_name=True,
json_schema_extra={
"example": {
"_id": "xxx",
"name": "jdoe",
"created_at": datetime.datetime.now(),
"updated_at": None,
}
},
)
class Message(BaseModel):
_id: ObjectId
conversation_id: ObjectId
user_id: ObjectId
type: MessageRoleType
text: str
created_at: datetime.datetime
created_by: ObjectId
updated_at: datetime.datetime
updated_by: ObjectId
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
conversation_id: PyObjectId = Field()
user_id: PyObjectId = Field()
type: MessageRoleType = Field()
text: str = Field()
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
created_by: PyObjectId = Field()
updated_at: Optional[datetime.datetime] = Field(default=None)
updated_by: Optional[PyObjectId] = Field(default=None)
model_config = ConfigDict(
populate_by_name=True,
json_schema_extra={
"example": {
"_id": "xxx",
"conversation_id": "xxx",
"user_id": "xxx",
"type": MessageRoleType.User,
"text": "xxx",
"created_at": datetime.datetime.now(),
"created_by": "xxx",
"updated_at": None,
"updated_by": None,
}
},
)
class Conversation(BaseModel):
_id: ObjectId
user_id: ObjectId
title: str
created_at: datetime.datetime
created_by: ObjectId
updated_at: datetime.datetime
updated_by: ObjectId
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
user_id: PyObjectId = Field()
title: str = Field()
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
created_by: PyObjectId = Field()
updated_at: Optional[datetime.datetime] = Field(default=None)
updated_by: Optional[PyObjectId] = Field(default=None)
model_config = ConfigDict(
populate_by_name=True,
json_schema_extra={
"example": {
"_id": "xxx",
"user_id": "xxx",
"title": "xx",
"created_at": datetime.datetime.now(),
"created_by": "xxx",
"updated_at": None,
"updated_by": None,
}
},
)
# === mongodb documents end ===
class UserConversationMessages(BaseModel):
user_id: ObjectId
user_name: str
conversation_id: ObjectId
title: str
created_at: datetime.datetime
created_by: ObjectId
updated_at: datetime.datetime
updated_by: ObjectId
user_id: PyObjectId = Field()
user_name: str = Field()
conversation_id: PyObjectId = Field()
title: str = Field()
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
created_by: PyObjectId = Field()
updated_at: Optional[datetime.datetime] = Field(default=None)
updated_by: Optional[PyObjectId] = Field(default=None)
messages: list[Message]

View File

@ -1,11 +1,13 @@
import os
from typing import Union
import sentry_sdk
import motor.motor_asyncio
from fastapi import FastAPI, Request
from dotenv import load_dotenv
from loguru import logger
from starlette.responses import JSONResponse
from simplylab.database import Database
from simplylab.entity import GetAiChatResponseInput, Context
from simplylab.entity import GetAiChatResponseOutput
from simplylab.entity import GetUserChatHistoryInput
@ -30,6 +32,22 @@ sentry_sdk.init(
app = FastAPI()
@app.on_event("startup")
def startup_db_client():
mongo_username = os.getenv("MONGO_USERNAME")
mongo_password = os.getenv("MONGO_PASSWORD")
mongo_uri = f"mongodb://{mongo_username}:{mongo_password}@mongodb:27017/"
app.mongodb_client = motor.motor_asyncio.AsyncIOMotorClient(mongo_uri)
app.db = Database(client=app.mongodb_client)
logger.info("Connected to the MongoDB database!")
@app.on_event("shutdown")
def shutdown_db_client():
app.mongodb_client.close()
logger.info("Disconnected to the MongoDB database!")
@app.exception_handler(Error)
async def error_handler(request: Request, exc: Error):
return JSONResponse(
@ -44,8 +62,8 @@ async def hi():
@app.post("/api/v1/get_ai_chat_response")
async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponseOutput:
pvd = Providers()
async def get_ai_chat_response(request: Request, req: GetAiChatResponseInput) -> GetAiChatResponseOutput:
pvd = Providers(db=request.app.db)
user = await pvd.user.get_user_by_name(req.user_name)
if not user:
raise UserNotFoundError(req.user_name)
@ -56,8 +74,8 @@ async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponse
@app.post("/api/v1/get_user_chat_history")
async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput:
pvd = Providers()
async def get_user_chat_history(request: Request, req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput:
pvd = Providers(db=request.app.db)
user = await pvd.user.get_user_by_name(req.user_name)
if not user:
raise UserNotFoundError(req.user_name)
@ -68,8 +86,8 @@ async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHist
@app.post("/api/v1/get_chat_status_today")
async def get_chat_status_today(req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput:
pvd = Providers()
async def get_chat_status_today(request: Request, req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput:
pvd = Providers(db=request.app.db)
user = await pvd.user.get_user_by_name(req.user_name)
if not user:
raise UserNotFoundError(req.user_name)

View File

@ -1,3 +1,4 @@
from simplylab.database import Database
from simplylab.providers.chat import ChatProvider
from simplylab.providers.openrouter import OpenRouterProvider
from simplylab.providers.user import UserProvider
@ -5,17 +6,17 @@ from simplylab.providers.user import UserProvider
class Providers:
def __init__(self):
...
def __init__(self, db: Database) -> None:
self.db = db
@property
def openrouter(self):
return OpenRouterProvider()
return OpenRouterProvider(db=self.db)
@property
def user(self):
return UserProvider()
return UserProvider(db=self.db)
@property
def chat(self):
return ChatProvider()
return ChatProvider(db=self.db)

View File

@ -1,11 +1,12 @@
from typing import Optional
from simplylab.database import Database
from simplylab.entity import UserConversationMessages
class ChatProvider:
def __init__(self):
...
def __init__(self, db: Database) -> None:
self.db = db
async def check_user_message_limited_in_30_seconds(self, user_id: str) -> bool:
...

View File

@ -2,10 +2,12 @@ from openai import OpenAI
from os import getenv
from loguru import logger
from simplylab.database import Database
class OpenRouterProvider:
def __init__(self):
...
def __init__(self, db: Database) -> None:
self.db = db
async def chat(self, content: str) -> str:
# gets API Key from environment variable OPENAI_API_KEY

View File

@ -1,12 +1,18 @@
import datetime
from typing import Optional
from simplylab.database import Database
from simplylab.entity import User
class UserProvider:
def __init__(self):
...
def __init__(self, db: Database) -> None:
self.db = db
def get_user_by_name(self, user_name: str) -> Optional[User]:
...
async def get_user_by_name(self, user_name: str) -> Optional[User]:
user = await self.db.user.find_one({"name": user_name})
if not user:
user = User(name=user_name)
user = await self.db.user.insert_one(user.model_dump())
return user