diff --git a/pdm.lock b/pdm.lock index 1cb7186..bc0434c 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:e5b52083ef7f859b75eea32132ef301095ac77b239a08c49229d4a37474b980b" +content_hash = "sha256:68aec295b130841c79040d9aae40daa7c15c204a7b5756850923308c9460d5b4" [[package]] name = "annotated-types" @@ -270,6 +270,21 @@ files = [ {file = "pydantic_core-2.16.3.tar.gz", hash = "sha256:1cac689f80a3abab2d3c0048b29eea5751114054f032a941a32de4c852c59cad"}, ] +[[package]] +name = "pydantic-mongo" +version = "2.1.2" +requires_python = ">=3.8" +summary = "Document object mapper for pydantic and pymongo" +groups = ["default"] +dependencies = [ + "pydantic<3.0.0,>=2.0.2", + "pymongo<5.0,>=4.3", +] +files = [ + {file = "pydantic-mongo-2.1.2.tar.gz", hash = "sha256:3eb4db8f2eb5abb1e4af4d005c92a2ca0c386c3c8eff125a9174edfe4ffe6633"}, + {file = "pydantic_mongo-2.1.2-py2.py3-none-any.whl", hash = "sha256:44cc02484eb87e812064cf85130e87b8e5e0fa660cd7188e46f153b7eaf58617"}, +] + [[package]] name = "pymongo" version = "4.6.2" diff --git a/pyproject.toml b/pyproject.toml index 7ff6aa1..b617b11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "python-dotenv>=1.0.1", "motor>=3.3.2", "sentry-sdk>=1.42.0", + "pydantic-mongo>=2.1.2", ] requires-python = "==3.12.*" readme = "README.md" diff --git a/simplylab/entity.py b/simplylab/entity.py index 8eab3cd..b99d800 100644 --- a/simplylab/entity.py +++ b/simplylab/entity.py @@ -1,6 +1,6 @@ import datetime from enum import Enum -from typing import Optional, Annotated +from typing import Optional, Annotated, Any from bson import ObjectId from pydantic import BaseModel, Field, BeforeValidator, ConfigDict @@ -39,7 +39,26 @@ class GetChatStatusTodayOutput(BaseModel): # === mongodb documents start === -PyObjectId = Annotated[str, BeforeValidator(ObjectId)] + +from pydantic import (AfterValidator, GetPydanticSchema, + PlainSerializer, WithJsonSchema) +from pydantic_mongo import ObjectIdField as _objectIdField + +ObjectIdField = Annotated[ + _objectIdField, + AfterValidator(lambda id: _objectIdField(id)), + PlainSerializer(lambda id: str(id), return_type=str, when_used='json-unless-none'), + WithJsonSchema({'type': 'string'}, mode='serialization'), + WithJsonSchema({'type': 'string'}, mode='validation'), + GetPydanticSchema(lambda _s, h: h(Any)) +] + + +class Test(BaseModel): + id: ObjectIdField = Field(default_factory=ObjectIdField, alias='_id', title='_id') + + +# ObjectIdField = Annotated[str, BeforeValidator(ObjectId)] class MessageRoleType(str, Enum): @@ -48,13 +67,17 @@ class MessageRoleType(str, Enum): class User(BaseModel): - id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") + id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id') name: str = Field(min_length=3, max_length=100, description="user name") created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) updated_at: Optional[datetime.datetime] = Field(default=None) model_config = ConfigDict( populate_by_name=True, + # json_encoders={ + # ObjectId: str, + # datetime: lambda dt: dt.isoformat() + # }, json_schema_extra={ "example": { "_id": "xxx", @@ -67,15 +90,15 @@ class User(BaseModel): class Message(BaseModel): - id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") - conversation_id: PyObjectId = Field() - user_id: PyObjectId = Field() + id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id') + conversation_id: ObjectIdField = Field() + user_id: ObjectIdField = Field() type: MessageRoleType = Field() text: str = Field() created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) - created_by: PyObjectId = Field() + created_by: ObjectIdField = Field() updated_at: Optional[datetime.datetime] = Field(default=None) - updated_by: Optional[PyObjectId] = Field(default=None) + updated_by: Optional[ObjectIdField] = Field(default=None) model_config = ConfigDict( populate_by_name=True, @@ -96,13 +119,13 @@ class Message(BaseModel): class Conversation(BaseModel): - id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") - user_id: PyObjectId = Field() + id: ObjectIdField = Field(default_factory=ObjectIdField, alias="_id", title='_id') + user_id: ObjectIdField = Field() title: str = Field() created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) - created_by: PyObjectId = Field() + created_by: ObjectIdField = Field() updated_at: Optional[datetime.datetime] = Field(default=None) - updated_by: Optional[PyObjectId] = Field(default=None) + updated_by: Optional[ObjectIdField] = Field(default=None) model_config = ConfigDict( populate_by_name=True, @@ -119,18 +142,19 @@ class Conversation(BaseModel): }, ) + # === mongodb documents end === class UserConversationMessages(BaseModel): - user_id: PyObjectId = Field() + user_id: ObjectIdField = Field() user_name: str = Field() - conversation_id: PyObjectId = Field() + conversation_id: ObjectIdField = Field() title: str = Field() created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) - created_by: PyObjectId = Field() + created_by: ObjectIdField = Field() updated_at: Optional[datetime.datetime] = Field(default=None) - updated_by: Optional[PyObjectId] = Field(default=None) + updated_by: Optional[ObjectIdField] = Field(default=None) messages: list[Message] diff --git a/simplylab/main.py b/simplylab/main.py index 675bf19..f4adb31 100644 --- a/simplylab/main.py +++ b/simplylab/main.py @@ -48,6 +48,14 @@ def shutdown_db_client(): logger.info("Disconnected to the MongoDB database!") +@app.exception_handler(Exception) +async def exception_handler(request: Request, exc: Error): + return JSONResponse( + status_code=500, + content={"message": str(exc)} + ) + + @app.exception_handler(Error) async def error_handler(request: Request, exc: Error): return JSONResponse( diff --git a/simplylab/providers/user.py b/simplylab/providers/user.py index 9cf3d01..87d976a 100644 --- a/simplylab/providers/user.py +++ b/simplylab/providers/user.py @@ -14,5 +14,6 @@ class UserProvider: user = await self.db.user.find_one({"name": user_name}) if not user: user = User(name=user_name) - user = await self.db.user.insert_one(user.model_dump()) + res = await self.db.user.insert_one(user.model_dump()) + user = await self.db.user.find_one({"_id": res.inserted_id}) return user