add db
This commit is contained in:
parent
cd33a36e28
commit
bbb646ba92
|
@ -0,0 +1,18 @@
|
|||
import motor.motor_asyncio
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, client: motor.motor_asyncio.AsyncIOMotorClient):
|
||||
self._client = client
|
||||
|
||||
@property
|
||||
def user(self):
|
||||
return self._client.get_database("simplylab").get_collection("user")
|
||||
|
||||
@property
|
||||
def conversation(self):
|
||||
return self._client.get_database("simplylab").get_collection("conversation")
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return self._client.get_database("simplylab").get_collection("message")
|
|
@ -1,8 +1,9 @@
|
|||
import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional, Annotated
|
||||
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field, BeforeValidator, ConfigDict
|
||||
|
||||
|
||||
class GetAiChatResponseInput(BaseModel):
|
||||
|
@ -37,52 +38,99 @@ class GetChatStatusTodayOutput(BaseModel):
|
|||
|
||||
|
||||
# === mongodb documents start ===
|
||||
|
||||
PyObjectId = Annotated[str, BeforeValidator(ObjectId)]
|
||||
|
||||
|
||||
class MessageRoleType(str, Enum):
|
||||
User = "user"
|
||||
Ai = "ai"
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
_id: ObjectId
|
||||
name: str
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
|
||||
name: str = Field(min_length=3, max_length=100, description="user name")
|
||||
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
|
||||
updated_at: Optional[datetime.datetime] = Field(default=None)
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"_id": "xxx",
|
||||
"name": "jdoe",
|
||||
"created_at": datetime.datetime.now(),
|
||||
"updated_at": None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
_id: ObjectId
|
||||
conversation_id: ObjectId
|
||||
user_id: ObjectId
|
||||
type: MessageRoleType
|
||||
text: str
|
||||
created_at: datetime.datetime
|
||||
created_by: ObjectId
|
||||
updated_at: datetime.datetime
|
||||
updated_by: ObjectId
|
||||
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
|
||||
conversation_id: PyObjectId = Field()
|
||||
user_id: PyObjectId = Field()
|
||||
type: MessageRoleType = Field()
|
||||
text: str = Field()
|
||||
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
|
||||
created_by: PyObjectId = Field()
|
||||
updated_at: Optional[datetime.datetime] = Field(default=None)
|
||||
updated_by: Optional[PyObjectId] = Field(default=None)
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"_id": "xxx",
|
||||
"conversation_id": "xxx",
|
||||
"user_id": "xxx",
|
||||
"type": MessageRoleType.User,
|
||||
"text": "xxx",
|
||||
"created_at": datetime.datetime.now(),
|
||||
"created_by": "xxx",
|
||||
"updated_at": None,
|
||||
"updated_by": None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
_id: ObjectId
|
||||
user_id: ObjectId
|
||||
title: str
|
||||
created_at: datetime.datetime
|
||||
created_by: ObjectId
|
||||
updated_at: datetime.datetime
|
||||
updated_by: ObjectId
|
||||
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
|
||||
user_id: PyObjectId = Field()
|
||||
title: str = Field()
|
||||
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
|
||||
created_by: PyObjectId = Field()
|
||||
updated_at: Optional[datetime.datetime] = Field(default=None)
|
||||
updated_by: Optional[PyObjectId] = Field(default=None)
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"_id": "xxx",
|
||||
"user_id": "xxx",
|
||||
"title": "xx",
|
||||
"created_at": datetime.datetime.now(),
|
||||
"created_by": "xxx",
|
||||
"updated_at": None,
|
||||
"updated_by": None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# === mongodb documents end ===
|
||||
|
||||
|
||||
class UserConversationMessages(BaseModel):
|
||||
user_id: ObjectId
|
||||
user_name: str
|
||||
conversation_id: ObjectId
|
||||
title: str
|
||||
created_at: datetime.datetime
|
||||
created_by: ObjectId
|
||||
updated_at: datetime.datetime
|
||||
updated_by: ObjectId
|
||||
user_id: PyObjectId = Field()
|
||||
user_name: str = Field()
|
||||
conversation_id: PyObjectId = Field()
|
||||
title: str = Field()
|
||||
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
|
||||
created_by: PyObjectId = Field()
|
||||
updated_at: Optional[datetime.datetime] = Field(default=None)
|
||||
updated_by: Optional[PyObjectId] = Field(default=None)
|
||||
messages: list[Message]
|
||||
|
||||
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
import os
|
||||
from typing import Union
|
||||
|
||||
import sentry_sdk
|
||||
import motor.motor_asyncio
|
||||
from fastapi import FastAPI, Request
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from simplylab.database import Database
|
||||
from simplylab.entity import GetAiChatResponseInput, Context
|
||||
from simplylab.entity import GetAiChatResponseOutput
|
||||
from simplylab.entity import GetUserChatHistoryInput
|
||||
|
@ -30,6 +32,22 @@ sentry_sdk.init(
|
|||
app = FastAPI()
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def startup_db_client():
|
||||
mongo_username = os.getenv("MONGO_USERNAME")
|
||||
mongo_password = os.getenv("MONGO_PASSWORD")
|
||||
mongo_uri = f"mongodb://{mongo_username}:{mongo_password}@mongodb:27017/"
|
||||
app.mongodb_client = motor.motor_asyncio.AsyncIOMotorClient(mongo_uri)
|
||||
app.db = Database(client=app.mongodb_client)
|
||||
logger.info("Connected to the MongoDB database!")
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
def shutdown_db_client():
|
||||
app.mongodb_client.close()
|
||||
logger.info("Disconnected to the MongoDB database!")
|
||||
|
||||
|
||||
@app.exception_handler(Error)
|
||||
async def error_handler(request: Request, exc: Error):
|
||||
return JSONResponse(
|
||||
|
@ -44,8 +62,8 @@ async def hi():
|
|||
|
||||
|
||||
@app.post("/api/v1/get_ai_chat_response")
|
||||
async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponseOutput:
|
||||
pvd = Providers()
|
||||
async def get_ai_chat_response(request: Request, req: GetAiChatResponseInput) -> GetAiChatResponseOutput:
|
||||
pvd = Providers(db=request.app.db)
|
||||
user = await pvd.user.get_user_by_name(req.user_name)
|
||||
if not user:
|
||||
raise UserNotFoundError(req.user_name)
|
||||
|
@ -56,8 +74,8 @@ async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponse
|
|||
|
||||
|
||||
@app.post("/api/v1/get_user_chat_history")
|
||||
async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput:
|
||||
pvd = Providers()
|
||||
async def get_user_chat_history(request: Request, req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput:
|
||||
pvd = Providers(db=request.app.db)
|
||||
user = await pvd.user.get_user_by_name(req.user_name)
|
||||
if not user:
|
||||
raise UserNotFoundError(req.user_name)
|
||||
|
@ -68,8 +86,8 @@ async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHist
|
|||
|
||||
|
||||
@app.post("/api/v1/get_chat_status_today")
|
||||
async def get_chat_status_today(req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput:
|
||||
pvd = Providers()
|
||||
async def get_chat_status_today(request: Request, req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput:
|
||||
pvd = Providers(db=request.app.db)
|
||||
user = await pvd.user.get_user_by_name(req.user_name)
|
||||
if not user:
|
||||
raise UserNotFoundError(req.user_name)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from simplylab.database import Database
|
||||
from simplylab.providers.chat import ChatProvider
|
||||
from simplylab.providers.openrouter import OpenRouterProvider
|
||||
from simplylab.providers.user import UserProvider
|
||||
|
@ -5,17 +6,17 @@ from simplylab.providers.user import UserProvider
|
|||
|
||||
class Providers:
|
||||
|
||||
def __init__(self):
|
||||
...
|
||||
def __init__(self, db: Database) -> None:
|
||||
self.db = db
|
||||
|
||||
@property
|
||||
def openrouter(self):
|
||||
return OpenRouterProvider()
|
||||
return OpenRouterProvider(db=self.db)
|
||||
|
||||
@property
|
||||
def user(self):
|
||||
return UserProvider()
|
||||
return UserProvider(db=self.db)
|
||||
|
||||
@property
|
||||
def chat(self):
|
||||
return ChatProvider()
|
||||
return ChatProvider(db=self.db)
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
from typing import Optional
|
||||
|
||||
from simplylab.database import Database
|
||||
from simplylab.entity import UserConversationMessages
|
||||
|
||||
|
||||
class ChatProvider:
|
||||
def __init__(self):
|
||||
...
|
||||
def __init__(self, db: Database) -> None:
|
||||
self.db = db
|
||||
|
||||
async def check_user_message_limited_in_30_seconds(self, user_id: str) -> bool:
|
||||
...
|
||||
|
|
|
@ -2,10 +2,12 @@ from openai import OpenAI
|
|||
from os import getenv
|
||||
from loguru import logger
|
||||
|
||||
from simplylab.database import Database
|
||||
|
||||
|
||||
class OpenRouterProvider:
|
||||
def __init__(self):
|
||||
...
|
||||
def __init__(self, db: Database) -> None:
|
||||
self.db = db
|
||||
|
||||
async def chat(self, content: str) -> str:
|
||||
# gets API Key from environment variable OPENAI_API_KEY
|
||||
|
|
|
@ -1,12 +1,18 @@
|
|||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
from simplylab.database import Database
|
||||
from simplylab.entity import User
|
||||
|
||||
|
||||
class UserProvider:
|
||||
|
||||
def __init__(self):
|
||||
...
|
||||
def __init__(self, db: Database) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_user_by_name(self, user_name: str) -> Optional[User]:
|
||||
...
|
||||
async def get_user_by_name(self, user_name: str) -> Optional[User]:
|
||||
user = await self.db.user.find_one({"name": user_name})
|
||||
if not user:
|
||||
user = User(name=user_name)
|
||||
user = await self.db.user.insert_one(user.model_dump())
|
||||
return user
|
||||
|
|
Loading…
Reference in New Issue