add db
This commit is contained in:
parent
cd33a36e28
commit
bbb646ba92
|
@ -0,0 +1,18 @@
|
||||||
|
import motor.motor_asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class Database:
|
||||||
|
def __init__(self, client: motor.motor_asyncio.AsyncIOMotorClient):
|
||||||
|
self._client = client
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user(self):
|
||||||
|
return self._client.get_database("simplylab").get_collection("user")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def conversation(self):
|
||||||
|
return self._client.get_database("simplylab").get_collection("conversation")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def message(self):
|
||||||
|
return self._client.get_database("simplylab").get_collection("message")
|
|
@ -1,8 +1,9 @@
|
||||||
import datetime
|
import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Optional, Annotated
|
||||||
|
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field, BeforeValidator, ConfigDict
|
||||||
|
|
||||||
|
|
||||||
class GetAiChatResponseInput(BaseModel):
|
class GetAiChatResponseInput(BaseModel):
|
||||||
|
@ -37,52 +38,99 @@ class GetChatStatusTodayOutput(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
# === mongodb documents start ===
|
# === mongodb documents start ===
|
||||||
|
|
||||||
|
PyObjectId = Annotated[str, BeforeValidator(ObjectId)]
|
||||||
|
|
||||||
|
|
||||||
class MessageRoleType(str, Enum):
|
class MessageRoleType(str, Enum):
|
||||||
User = "user"
|
User = "user"
|
||||||
Ai = "ai"
|
Ai = "ai"
|
||||||
|
|
||||||
|
|
||||||
class User(BaseModel):
|
class User(BaseModel):
|
||||||
_id: ObjectId
|
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
|
||||||
name: str
|
name: str = Field(min_length=3, max_length=100, description="user name")
|
||||||
created_at: datetime.datetime
|
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
|
||||||
updated_at: datetime.datetime
|
updated_at: Optional[datetime.datetime] = Field(default=None)
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
populate_by_name=True,
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"_id": "xxx",
|
||||||
|
"name": "jdoe",
|
||||||
|
"created_at": datetime.datetime.now(),
|
||||||
|
"updated_at": None,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
_id: ObjectId
|
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
|
||||||
conversation_id: ObjectId
|
conversation_id: PyObjectId = Field()
|
||||||
user_id: ObjectId
|
user_id: PyObjectId = Field()
|
||||||
type: MessageRoleType
|
type: MessageRoleType = Field()
|
||||||
text: str
|
text: str = Field()
|
||||||
created_at: datetime.datetime
|
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
|
||||||
created_by: ObjectId
|
created_by: PyObjectId = Field()
|
||||||
updated_at: datetime.datetime
|
updated_at: Optional[datetime.datetime] = Field(default=None)
|
||||||
updated_by: ObjectId
|
updated_by: Optional[PyObjectId] = Field(default=None)
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
populate_by_name=True,
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"_id": "xxx",
|
||||||
|
"conversation_id": "xxx",
|
||||||
|
"user_id": "xxx",
|
||||||
|
"type": MessageRoleType.User,
|
||||||
|
"text": "xxx",
|
||||||
|
"created_at": datetime.datetime.now(),
|
||||||
|
"created_by": "xxx",
|
||||||
|
"updated_at": None,
|
||||||
|
"updated_by": None,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Conversation(BaseModel):
|
class Conversation(BaseModel):
|
||||||
_id: ObjectId
|
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
|
||||||
user_id: ObjectId
|
user_id: PyObjectId = Field()
|
||||||
title: str
|
title: str = Field()
|
||||||
created_at: datetime.datetime
|
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
|
||||||
created_by: ObjectId
|
created_by: PyObjectId = Field()
|
||||||
updated_at: datetime.datetime
|
updated_at: Optional[datetime.datetime] = Field(default=None)
|
||||||
updated_by: ObjectId
|
updated_by: Optional[PyObjectId] = Field(default=None)
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
populate_by_name=True,
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"_id": "xxx",
|
||||||
|
"user_id": "xxx",
|
||||||
|
"title": "xx",
|
||||||
|
"created_at": datetime.datetime.now(),
|
||||||
|
"created_by": "xxx",
|
||||||
|
"updated_at": None,
|
||||||
|
"updated_by": None,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# === mongodb documents end ===
|
# === mongodb documents end ===
|
||||||
|
|
||||||
|
|
||||||
class UserConversationMessages(BaseModel):
|
class UserConversationMessages(BaseModel):
|
||||||
user_id: ObjectId
|
user_id: PyObjectId = Field()
|
||||||
user_name: str
|
user_name: str = Field()
|
||||||
conversation_id: ObjectId
|
conversation_id: PyObjectId = Field()
|
||||||
title: str
|
title: str = Field()
|
||||||
created_at: datetime.datetime
|
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
|
||||||
created_by: ObjectId
|
created_by: PyObjectId = Field()
|
||||||
updated_at: datetime.datetime
|
updated_at: Optional[datetime.datetime] = Field(default=None)
|
||||||
updated_by: ObjectId
|
updated_by: Optional[PyObjectId] = Field(default=None)
|
||||||
messages: list[Message]
|
messages: list[Message]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
import os
|
import os
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
|
import motor.motor_asyncio
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from loguru import logger
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
|
from simplylab.database import Database
|
||||||
from simplylab.entity import GetAiChatResponseInput, Context
|
from simplylab.entity import GetAiChatResponseInput, Context
|
||||||
from simplylab.entity import GetAiChatResponseOutput
|
from simplylab.entity import GetAiChatResponseOutput
|
||||||
from simplylab.entity import GetUserChatHistoryInput
|
from simplylab.entity import GetUserChatHistoryInput
|
||||||
|
@ -30,6 +32,22 @@ sentry_sdk.init(
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
def startup_db_client():
|
||||||
|
mongo_username = os.getenv("MONGO_USERNAME")
|
||||||
|
mongo_password = os.getenv("MONGO_PASSWORD")
|
||||||
|
mongo_uri = f"mongodb://{mongo_username}:{mongo_password}@mongodb:27017/"
|
||||||
|
app.mongodb_client = motor.motor_asyncio.AsyncIOMotorClient(mongo_uri)
|
||||||
|
app.db = Database(client=app.mongodb_client)
|
||||||
|
logger.info("Connected to the MongoDB database!")
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
def shutdown_db_client():
|
||||||
|
app.mongodb_client.close()
|
||||||
|
logger.info("Disconnected to the MongoDB database!")
|
||||||
|
|
||||||
|
|
||||||
@app.exception_handler(Error)
|
@app.exception_handler(Error)
|
||||||
async def error_handler(request: Request, exc: Error):
|
async def error_handler(request: Request, exc: Error):
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
@ -44,8 +62,8 @@ async def hi():
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/v1/get_ai_chat_response")
|
@app.post("/api/v1/get_ai_chat_response")
|
||||||
async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponseOutput:
|
async def get_ai_chat_response(request: Request, req: GetAiChatResponseInput) -> GetAiChatResponseOutput:
|
||||||
pvd = Providers()
|
pvd = Providers(db=request.app.db)
|
||||||
user = await pvd.user.get_user_by_name(req.user_name)
|
user = await pvd.user.get_user_by_name(req.user_name)
|
||||||
if not user:
|
if not user:
|
||||||
raise UserNotFoundError(req.user_name)
|
raise UserNotFoundError(req.user_name)
|
||||||
|
@ -56,8 +74,8 @@ async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponse
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/v1/get_user_chat_history")
|
@app.post("/api/v1/get_user_chat_history")
|
||||||
async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput:
|
async def get_user_chat_history(request: Request, req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput:
|
||||||
pvd = Providers()
|
pvd = Providers(db=request.app.db)
|
||||||
user = await pvd.user.get_user_by_name(req.user_name)
|
user = await pvd.user.get_user_by_name(req.user_name)
|
||||||
if not user:
|
if not user:
|
||||||
raise UserNotFoundError(req.user_name)
|
raise UserNotFoundError(req.user_name)
|
||||||
|
@ -68,8 +86,8 @@ async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHist
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/v1/get_chat_status_today")
|
@app.post("/api/v1/get_chat_status_today")
|
||||||
async def get_chat_status_today(req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput:
|
async def get_chat_status_today(request: Request, req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput:
|
||||||
pvd = Providers()
|
pvd = Providers(db=request.app.db)
|
||||||
user = await pvd.user.get_user_by_name(req.user_name)
|
user = await pvd.user.get_user_by_name(req.user_name)
|
||||||
if not user:
|
if not user:
|
||||||
raise UserNotFoundError(req.user_name)
|
raise UserNotFoundError(req.user_name)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from simplylab.database import Database
|
||||||
from simplylab.providers.chat import ChatProvider
|
from simplylab.providers.chat import ChatProvider
|
||||||
from simplylab.providers.openrouter import OpenRouterProvider
|
from simplylab.providers.openrouter import OpenRouterProvider
|
||||||
from simplylab.providers.user import UserProvider
|
from simplylab.providers.user import UserProvider
|
||||||
|
@ -5,17 +6,17 @@ from simplylab.providers.user import UserProvider
|
||||||
|
|
||||||
class Providers:
|
class Providers:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, db: Database) -> None:
|
||||||
...
|
self.db = db
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def openrouter(self):
|
def openrouter(self):
|
||||||
return OpenRouterProvider()
|
return OpenRouterProvider(db=self.db)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def user(self):
|
def user(self):
|
||||||
return UserProvider()
|
return UserProvider(db=self.db)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chat(self):
|
def chat(self):
|
||||||
return ChatProvider()
|
return ChatProvider(db=self.db)
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from simplylab.database import Database
|
||||||
from simplylab.entity import UserConversationMessages
|
from simplylab.entity import UserConversationMessages
|
||||||
|
|
||||||
|
|
||||||
class ChatProvider:
|
class ChatProvider:
|
||||||
def __init__(self):
|
def __init__(self, db: Database) -> None:
|
||||||
...
|
self.db = db
|
||||||
|
|
||||||
async def check_user_message_limited_in_30_seconds(self, user_id: str) -> bool:
|
async def check_user_message_limited_in_30_seconds(self, user_id: str) -> bool:
|
||||||
...
|
...
|
||||||
|
|
|
@ -2,10 +2,12 @@ from openai import OpenAI
|
||||||
from os import getenv
|
from os import getenv
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from simplylab.database import Database
|
||||||
|
|
||||||
|
|
||||||
class OpenRouterProvider:
|
class OpenRouterProvider:
|
||||||
def __init__(self):
|
def __init__(self, db: Database) -> None:
|
||||||
...
|
self.db = db
|
||||||
|
|
||||||
async def chat(self, content: str) -> str:
|
async def chat(self, content: str) -> str:
|
||||||
# gets API Key from environment variable OPENAI_API_KEY
|
# gets API Key from environment variable OPENAI_API_KEY
|
||||||
|
|
|
@ -1,12 +1,18 @@
|
||||||
|
import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from simplylab.database import Database
|
||||||
from simplylab.entity import User
|
from simplylab.entity import User
|
||||||
|
|
||||||
|
|
||||||
class UserProvider:
|
class UserProvider:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, db: Database) -> None:
|
||||||
...
|
self.db = db
|
||||||
|
|
||||||
def get_user_by_name(self, user_name: str) -> Optional[User]:
|
async def get_user_by_name(self, user_name: str) -> Optional[User]:
|
||||||
...
|
user = await self.db.user.find_one({"name": user_name})
|
||||||
|
if not user:
|
||||||
|
user = User(name=user_name)
|
||||||
|
user = await self.db.user.insert_one(user.model_dump())
|
||||||
|
return user
|
||||||
|
|
Loading…
Reference in New Issue