From bbb646ba922feb550e6f00e18b04b10c932df937 Mon Sep 17 00:00:00 2001 From: Jeremy Yin Date: Fri, 15 Mar 2024 19:11:17 +0800 Subject: [PATCH] add db --- simplylab/database.py | 18 +++++ simplylab/entity.py | 106 ++++++++++++++++++++++-------- simplylab/main.py | 32 +++++++-- simplylab/providers/__init__.py | 11 ++-- simplylab/providers/chat.py | 5 +- simplylab/providers/openrouter.py | 6 +- simplylab/providers/user.py | 14 ++-- 7 files changed, 143 insertions(+), 49 deletions(-) create mode 100644 simplylab/database.py diff --git a/simplylab/database.py b/simplylab/database.py new file mode 100644 index 0000000..216bcd9 --- /dev/null +++ b/simplylab/database.py @@ -0,0 +1,18 @@ +import motor.motor_asyncio + + +class Database: + def __init__(self, client: motor.motor_asyncio.AsyncIOMotorClient): + self._client = client + + @property + def user(self): + return self._client.get_database("simplylab").get_collection("user") + + @property + def conversation(self): + return self._client.get_database("simplylab").get_collection("conversation") + + @property + def message(self): + return self._client.get_database("simplylab").get_collection("message") diff --git a/simplylab/entity.py b/simplylab/entity.py index 12085fe..8eab3cd 100644 --- a/simplylab/entity.py +++ b/simplylab/entity.py @@ -1,8 +1,9 @@ import datetime from enum import Enum +from typing import Optional, Annotated from bson import ObjectId -from pydantic import BaseModel +from pydantic import BaseModel, Field, BeforeValidator, ConfigDict class GetAiChatResponseInput(BaseModel): @@ -37,52 +38,99 @@ class GetChatStatusTodayOutput(BaseModel): # === mongodb documents start === + +PyObjectId = Annotated[str, BeforeValidator(ObjectId)] + + class MessageRoleType(str, Enum): User = "user" Ai = "ai" class User(BaseModel): - _id: ObjectId - name: str - created_at: datetime.datetime - updated_at: datetime.datetime + id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") + name: str = Field(min_length=3, max_length=100, description="user name") + created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) + updated_at: Optional[datetime.datetime] = Field(default=None) + + model_config = ConfigDict( + populate_by_name=True, + json_schema_extra={ + "example": { + "_id": "xxx", + "name": "jdoe", + "created_at": datetime.datetime.now(), + "updated_at": None, + } + }, + ) class Message(BaseModel): - _id: ObjectId - conversation_id: ObjectId - user_id: ObjectId - type: MessageRoleType - text: str - created_at: datetime.datetime - created_by: ObjectId - updated_at: datetime.datetime - updated_by: ObjectId + id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") + conversation_id: PyObjectId = Field() + user_id: PyObjectId = Field() + type: MessageRoleType = Field() + text: str = Field() + created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) + created_by: PyObjectId = Field() + updated_at: Optional[datetime.datetime] = Field(default=None) + updated_by: Optional[PyObjectId] = Field(default=None) + + model_config = ConfigDict( + populate_by_name=True, + json_schema_extra={ + "example": { + "_id": "xxx", + "conversation_id": "xxx", + "user_id": "xxx", + "type": MessageRoleType.User, + "text": "xxx", + "created_at": datetime.datetime.now(), + "created_by": "xxx", + "updated_at": None, + "updated_by": None, + } + }, + ) class Conversation(BaseModel): - _id: ObjectId - user_id: ObjectId - title: str - created_at: datetime.datetime - created_by: ObjectId - updated_at: datetime.datetime - updated_by: ObjectId + id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") + user_id: PyObjectId = Field() + title: str = Field() + created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) + created_by: PyObjectId = Field() + updated_at: Optional[datetime.datetime] = Field(default=None) + updated_by: Optional[PyObjectId] = Field(default=None) + model_config = ConfigDict( + populate_by_name=True, + json_schema_extra={ + "example": { + "_id": "xxx", + "user_id": "xxx", + "title": "xx", + "created_at": datetime.datetime.now(), + "created_by": "xxx", + "updated_at": None, + "updated_by": None, + } + }, + ) # === mongodb documents end === class UserConversationMessages(BaseModel): - user_id: ObjectId - user_name: str - conversation_id: ObjectId - title: str - created_at: datetime.datetime - created_by: ObjectId - updated_at: datetime.datetime - updated_by: ObjectId + user_id: PyObjectId = Field() + user_name: str = Field() + conversation_id: PyObjectId = Field() + title: str = Field() + created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) + created_by: PyObjectId = Field() + updated_at: Optional[datetime.datetime] = Field(default=None) + updated_by: Optional[PyObjectId] = Field(default=None) messages: list[Message] diff --git a/simplylab/main.py b/simplylab/main.py index b9e1c9c..675bf19 100644 --- a/simplylab/main.py +++ b/simplylab/main.py @@ -1,11 +1,13 @@ import os from typing import Union - import sentry_sdk +import motor.motor_asyncio from fastapi import FastAPI, Request from dotenv import load_dotenv +from loguru import logger from starlette.responses import JSONResponse +from simplylab.database import Database from simplylab.entity import GetAiChatResponseInput, Context from simplylab.entity import GetAiChatResponseOutput from simplylab.entity import GetUserChatHistoryInput @@ -30,6 +32,22 @@ sentry_sdk.init( app = FastAPI() +@app.on_event("startup") +def startup_db_client(): + mongo_username = os.getenv("MONGO_USERNAME") + mongo_password = os.getenv("MONGO_PASSWORD") + mongo_uri = f"mongodb://{mongo_username}:{mongo_password}@mongodb:27017/" + app.mongodb_client = motor.motor_asyncio.AsyncIOMotorClient(mongo_uri) + app.db = Database(client=app.mongodb_client) + logger.info("Connected to the MongoDB database!") + + +@app.on_event("shutdown") +def shutdown_db_client(): + app.mongodb_client.close() + logger.info("Disconnected to the MongoDB database!") + + @app.exception_handler(Error) async def error_handler(request: Request, exc: Error): return JSONResponse( @@ -44,8 +62,8 @@ async def hi(): @app.post("/api/v1/get_ai_chat_response") -async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponseOutput: - pvd = Providers() +async def get_ai_chat_response(request: Request, req: GetAiChatResponseInput) -> GetAiChatResponseOutput: + pvd = Providers(db=request.app.db) user = await pvd.user.get_user_by_name(req.user_name) if not user: raise UserNotFoundError(req.user_name) @@ -56,8 +74,8 @@ async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponse @app.post("/api/v1/get_user_chat_history") -async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput: - pvd = Providers() +async def get_user_chat_history(request: Request, req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput: + pvd = Providers(db=request.app.db) user = await pvd.user.get_user_by_name(req.user_name) if not user: raise UserNotFoundError(req.user_name) @@ -68,8 +86,8 @@ async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHist @app.post("/api/v1/get_chat_status_today") -async def get_chat_status_today(req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput: - pvd = Providers() +async def get_chat_status_today(request: Request, req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput: + pvd = Providers(db=request.app.db) user = await pvd.user.get_user_by_name(req.user_name) if not user: raise UserNotFoundError(req.user_name) diff --git a/simplylab/providers/__init__.py b/simplylab/providers/__init__.py index 4d1e2ca..9061650 100644 --- a/simplylab/providers/__init__.py +++ b/simplylab/providers/__init__.py @@ -1,3 +1,4 @@ +from simplylab.database import Database from simplylab.providers.chat import ChatProvider from simplylab.providers.openrouter import OpenRouterProvider from simplylab.providers.user import UserProvider @@ -5,17 +6,17 @@ from simplylab.providers.user import UserProvider class Providers: - def __init__(self): - ... + def __init__(self, db: Database) -> None: + self.db = db @property def openrouter(self): - return OpenRouterProvider() + return OpenRouterProvider(db=self.db) @property def user(self): - return UserProvider() + return UserProvider(db=self.db) @property def chat(self): - return ChatProvider() + return ChatProvider(db=self.db) diff --git a/simplylab/providers/chat.py b/simplylab/providers/chat.py index f5a2a1f..7949abe 100644 --- a/simplylab/providers/chat.py +++ b/simplylab/providers/chat.py @@ -1,11 +1,12 @@ from typing import Optional +from simplylab.database import Database from simplylab.entity import UserConversationMessages class ChatProvider: - def __init__(self): - ... + def __init__(self, db: Database) -> None: + self.db = db async def check_user_message_limited_in_30_seconds(self, user_id: str) -> bool: ... diff --git a/simplylab/providers/openrouter.py b/simplylab/providers/openrouter.py index 3d2cd7a..52c8e0f 100644 --- a/simplylab/providers/openrouter.py +++ b/simplylab/providers/openrouter.py @@ -2,10 +2,12 @@ from openai import OpenAI from os import getenv from loguru import logger +from simplylab.database import Database + class OpenRouterProvider: - def __init__(self): - ... + def __init__(self, db: Database) -> None: + self.db = db async def chat(self, content: str) -> str: # gets API Key from environment variable OPENAI_API_KEY diff --git a/simplylab/providers/user.py b/simplylab/providers/user.py index 0ab1337..9cf3d01 100644 --- a/simplylab/providers/user.py +++ b/simplylab/providers/user.py @@ -1,12 +1,18 @@ +import datetime from typing import Optional +from simplylab.database import Database from simplylab.entity import User class UserProvider: - def __init__(self): - ... + def __init__(self, db: Database) -> None: + self.db = db - def get_user_by_name(self, user_name: str) -> Optional[User]: - ... + async def get_user_by_name(self, user_name: str) -> Optional[User]: + user = await self.db.user.find_one({"name": user_name}) + if not user: + user = User(name=user_name) + user = await self.db.user.insert_one(user.model_dump()) + return user