From 14c6b5933c029279ac1ba9dfbfe974d6ab3d32b3 Mon Sep 17 00:00:00 2001 From: Jeremy Yin Date: Fri, 15 Mar 2024 15:01:45 +0800 Subject: [PATCH] pvd as deps arg --- simplylab/entity.py | 4 ++++ simplylab/main.py | 15 +++++++++++---- simplylab/services/__init__.py | 7 +++++-- simplylab/services/chat.py | 6 +++--- 4 files changed, 23 insertions(+), 9 deletions(-) diff --git a/simplylab/entity.py b/simplylab/entity.py index ed402ac..2adfe80 100644 --- a/simplylab/entity.py +++ b/simplylab/entity.py @@ -1,6 +1,10 @@ from pydantic import BaseModel +class Context(BaseModel): + user_name: str + + class GetAiChatResponseInput(BaseModel): message: str user_name: str diff --git a/simplylab/main.py b/simplylab/main.py index a3d75de..db8c431 100644 --- a/simplylab/main.py +++ b/simplylab/main.py @@ -4,13 +4,14 @@ from fastapi import FastAPI, Request from dotenv import load_dotenv from starlette.responses import JSONResponse -from simplylab.entity import GetAiChatResponseInput +from simplylab.entity import GetAiChatResponseInput, Context from simplylab.entity import GetAiChatResponseOutput from simplylab.entity import GetUserChatHistoryInput from simplylab.entity import GetUserChatHistoryOutput from simplylab.entity import GetChatStatusTodayInput from simplylab.entity import GetChatStatusTodayOutput from simplylab.error import Error +from simplylab.providers import Providers from simplylab.services import Services load_dotenv() @@ -32,20 +33,26 @@ async def hi(): @app.post("/api/v1/get_ai_chat_response") async def get_ai_chat_response(req: GetAiChatResponseInput) -> GetAiChatResponseOutput: - svc = Services(req) + ctx = Context(user_name=req.user_name) + pvd = Providers() + svc = Services(ctx, pvd) res = await svc.chat.get_ai_chat_response(req) return res @app.post("/api/v1/get_user_chat_history") async def get_user_chat_history(req: GetUserChatHistoryInput) -> GetUserChatHistoryOutput: - svc = Services(req) + ctx = Context(user_name=req.user_name) + pvd = Providers() + svc = Services(ctx, pvd) res = await svc.chat.get_user_chat_history(req) return res @app.post("/api/v1/get_chat_status_today") async def get_chat_status_today(req: GetChatStatusTodayInput) -> GetChatStatusTodayOutput: - svc = Services(req) + ctx = Context(user_name=req.user_name) + pvd = Providers() + svc = Services(ctx, pvd) res = await svc.chat.get_chat_status_today(req) return res diff --git a/simplylab/services/__init__.py b/simplylab/services/__init__.py index b95a51f..e8ce1ec 100644 --- a/simplylab/services/__init__.py +++ b/simplylab/services/__init__.py @@ -1,12 +1,15 @@ from typing import Any +from simplylab.entity import Context +from simplylab.providers import Providers from simplylab.services.chat import ChatService class Services: - def __init__(self, ctx: Any): + def __init__(self, ctx: Context, providers: Providers): self.ctx = ctx + self.pvd = providers @property def chat(self): - return ChatService(self.ctx) + return ChatService(self.ctx, self.pvd) diff --git a/simplylab/services/chat.py b/simplylab/services/chat.py index 084e909..9d65197 100644 --- a/simplylab/services/chat.py +++ b/simplylab/services/chat.py @@ -1,14 +1,14 @@ from typing import Any from simplylab.entity import GetAiChatResponseInput, GetUserChatHistoryInput, GetChatStatusTodayInput, UserChatMessage, \ - GetChatStatusTodayOutput, GetAiChatResponseOutput, GetUserChatHistoryOutput + GetChatStatusTodayOutput, GetAiChatResponseOutput, GetUserChatHistoryOutput, Context from simplylab.providers import Providers class ChatService: - def __init__(self, ctx: Any): + def __init__(self, ctx: Context, provider: Providers): self.ctx = ctx - self.pvd = Providers() + self.pvd = provider async def get_ai_chat_response(self, req: GetAiChatResponseInput) -> GetAiChatResponseOutput: message = req.message