feat: rust version of simplylab
This commit is contained in:
parent
38e837a7b8
commit
92437ec9e9
|
@ -0,0 +1,4 @@
|
|||
/target/debug/.fingerprint
|
||||
/target/release/.fingerprint
|
||||
.idea
|
||||
*.DS_Store
|
|
@ -0,0 +1,10 @@
|
|||
MONGO_HOST=localhost
|
||||
MONGO_PORT=27017
|
||||
MONGO_USERNAME=root
|
||||
MONGO_PASSWORD=root123456
|
||||
OPENROUTER_API_KEY=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
|
||||
SENTRY_DSN=https://3f355956f8c5f65ae5a29a5fc665c0c0@o4506630800539648.ingest.us.sentry.io/4506914597568512
|
||||
VIRTUAL_HOST=simplylab-rs.isyin.cn
|
||||
VIRTUAL_PORT=8002
|
||||
LETSENCRYPT_HOST=simplylab-rs.isyin.cn
|
||||
LETSENCRYPT_EMAIL=jeremyyin2012@gmail.com
|
|
@ -0,0 +1,10 @@
|
|||
MONGO_HOST=
|
||||
MONGO_PORT=
|
||||
MONGO_USERNAME=
|
||||
MONGO_PASSWORD=
|
||||
OPENROUTER_API_KEY=
|
||||
SENTRY_DSN=
|
||||
VIRTUAL_HOST=
|
||||
VIRTUAL_PORT=
|
||||
LETSENCRYPT_HOST=
|
||||
LETSENCRYPT_EMAIL=
|
|
@ -5,7 +5,7 @@ target/
|
|||
|
||||
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
|
||||
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
|
||||
Cargo.lock
|
||||
#Cargo.lock
|
||||
|
||||
# These are backup files generated by rustfmt
|
||||
**/*.rs.bk
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,53 @@
|
|||
[package]
|
||||
name = "simplylab"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.35.1", features = ["full"] }
|
||||
chrono = { version = "0.4.31", features = ["serde"] }
|
||||
serde = { version = "1.0.195", features = ["derive"] }
|
||||
serde_json = "1.0.111"
|
||||
schemars = { version = "0.8.16", features = ["uuid1", "chrono", "bigdecimal03"] }
|
||||
dotenv = "0.15.0"
|
||||
dotenvy = "0.15.7"
|
||||
futures = "0.3.30"
|
||||
thiserror = "1.0.56"
|
||||
anyhow = "1.0.79"
|
||||
lazy_static = "1.4.0"
|
||||
rocket = { version = "0.5.0", features = ["json", "uuid"] }
|
||||
okapi = "0.7.0"
|
||||
rocket_okapi = { version = "0.8.0", features = ["uuid", "rocket_db_pools", "rocket_ws", "swagger", "rapidoc"] }
|
||||
redis = { version = "0.24.0", features = ["aio", "tokio-comp", "connection-manager"] }
|
||||
reqwest = { version = "0.11.23", features = ["json"] }
|
||||
url = { version = "2.5.0", features = ["serde"] }
|
||||
sea-query = { version = "0.30.7", features = ["derive", "attr", "thread-safe", "backend-mysql", "with-chrono", "with-time", "with-json", "with-rust_decimal", "with-bigdecimal", "with-uuid"] }
|
||||
uuid7 = { version = "0.7.2", features = ["serde", "uuid"] }
|
||||
bigdecimal = { version = "0.3.1", features = ["serde"] }
|
||||
sqlx = { version = "0.7.3", features = ["postgres", "sqlx-postgres", "uuid", "chrono", "bigdecimal", "runtime-tokio-rustls"] }
|
||||
sqlx-postgres = { version = "0.7.3", features = ["bigdecimal"] }
|
||||
sentry = { version = "0.32.1", default-features = false, features = ["reqwest", "rustls", "backtrace", "contexts", "panic", "transport"] }
|
||||
regex = "1.10.2"
|
||||
clap = { version = "4.4.16", features = ["derive"] }
|
||||
calamine = "0.23.1"
|
||||
csv = "1.3.0"
|
||||
proc-macro2 = "1.0.76"
|
||||
mongodb = {version = "2.8.2", features = ["bson-chrono-0_4", "bson-serde_with"]}
|
||||
log = "0.4.17"
|
||||
async-openai = "0.19.1"
|
||||
|
||||
[dependencies.rocket_db_pools]
|
||||
version = "0.1.0"
|
||||
features = ["sqlx_mysql"]
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1.6.1"
|
||||
features = [
|
||||
"v4", # Lets you generate random UUIDs
|
||||
"serde",
|
||||
"fast-rng", # Use a faster (but still sufficiently random) RNG
|
||||
"macro-diagnostics", # Enable better diagnostics for compile-time UUIDs
|
||||
]
|
|
@ -0,0 +1,7 @@
|
|||
FROM debian:bullseye
|
||||
ADD ./target/release/simplylab /app/
|
||||
ADD ./Rocket.toml /app/
|
||||
|
||||
WORKDIR /app/
|
||||
|
||||
CMD ["./simplylab"]
|
|
@ -0,0 +1,11 @@
|
|||
[default]
|
||||
address = "0.0.0.0"
|
||||
|
||||
[debug]
|
||||
port = 8002
|
||||
workers = 2
|
||||
log_level = "debug"
|
||||
|
||||
[release]
|
||||
port = 8002
|
||||
workers = 4
|
|
@ -0,0 +1,47 @@
|
|||
version: "3"
|
||||
|
||||
services:
|
||||
mongodb:
|
||||
image: mongo:7.0.6
|
||||
restart: always
|
||||
volumes:
|
||||
- "./data/mongo/configdb:/data/configdb"
|
||||
- "./data/mongo/db:/data/db"
|
||||
environment:
|
||||
TZ: "Asia/Shanghai"
|
||||
MONGO_INITDB_ROOT_USERNAME: ${MONGO_USERNAME}
|
||||
MONGO_INITDB_ROOT_PASSWORD: ${MONGO_PASSWORD}
|
||||
ports:
|
||||
- 27017:27017
|
||||
tty: true
|
||||
stdin_open: true
|
||||
networks:
|
||||
- simplylab
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "2m"
|
||||
max-file: "10"
|
||||
|
||||
mongo-express:
|
||||
image: mongo-express:1.0.2
|
||||
restart: always
|
||||
ports:
|
||||
- 8081:8081
|
||||
environment:
|
||||
TZ: "Asia/Shanghai"
|
||||
ME_CONFIG_MONGODB_ADMINUSERNAME: ${MONGO_USERNAME}
|
||||
ME_CONFIG_MONGODB_ADMINPASSWORD: ${MONGO_PASSWORD}
|
||||
ME_CONFIG_MONGODB_URL: mongodb://${MONGO_USERNAME}:${MONGO_PASSWORD}@mongodb:27017/
|
||||
tty: true
|
||||
stdin_open: true
|
||||
networks:
|
||||
- simplylab
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "2m"
|
||||
max-file: "10"
|
||||
|
||||
networks:
|
||||
simplylab:
|
|
@ -0,0 +1,32 @@
|
|||
version: "3"
|
||||
|
||||
services:
|
||||
simplylab-rs:
|
||||
build:
|
||||
dockerfile: Dockerfile
|
||||
context: .
|
||||
restart: always
|
||||
environment:
|
||||
TZ: "Asia/Shanghai"
|
||||
VIRTUAL_HOST: ${VIRTUAL_HOST}
|
||||
VIRTUAL_PORT: ${VIRTUAL_PORT}
|
||||
LETSENCRYPT_HOST: ${LETSENCRYPT_HOST}
|
||||
LETSENCRYPT_EMAIL: ${LETSENCRYPT_EMAIL}
|
||||
expose:
|
||||
- 8002
|
||||
tty: true
|
||||
stdin_open: true
|
||||
networks:
|
||||
- simplylab
|
||||
- nginx-proxy
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "2m"
|
||||
max-file: "10"
|
||||
|
||||
networks:
|
||||
simplylab:
|
||||
nginx-proxy:
|
||||
external:
|
||||
name: nginx-proxy_default
|
|
@ -0,0 +1,2 @@
|
|||
[toolchain]
|
||||
channel = "stable"
|
|
@ -0,0 +1,67 @@
|
|||
use std::env;
|
||||
|
||||
use dotenvy::dotenv;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub app_env: String,
|
||||
pub debug: bool,
|
||||
pub sentry_dsn: String,
|
||||
pub database_url: String,
|
||||
pub max_size: u32,
|
||||
pub redis_url: String,
|
||||
pub openrouter_api_key: String,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
app_env: "dev".to_string(),
|
||||
debug: true,
|
||||
sentry_dsn: "".to_string(),
|
||||
database_url: "".to_string(),
|
||||
max_size: 10,
|
||||
redis_url: "".to_string(),
|
||||
openrouter_api_key: "".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub async fn new() -> Self {
|
||||
println!("Config init");
|
||||
let config = connect().await;
|
||||
println!("{:?}", config);
|
||||
config
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect() -> Config {
|
||||
dotenv().ok();
|
||||
let app_env = env::var("APP_ENV").unwrap_or("dev".to_string());
|
||||
|
||||
let debug = env::var("DEBUG").unwrap_or("true".to_string())
|
||||
.parse::<bool>()
|
||||
.unwrap();
|
||||
|
||||
let sentry_dsn = env::var("SENTRY_DSN").unwrap_or("".to_string());
|
||||
|
||||
let database_url = env::var("DATABASE_URL").unwrap_or("".to_string());
|
||||
let max_size = env::var("DATABASE_POOL_MAX_SIZE").unwrap_or("10".to_string())
|
||||
.parse::<u32>()
|
||||
.unwrap();
|
||||
|
||||
let redis_url = env::var("REDIS_URL").unwrap_or("".to_string());
|
||||
let openrouter_api_key = env::var("OPENROUTER_API_KEY").unwrap_or("".to_string());
|
||||
|
||||
Config {
|
||||
app_env,
|
||||
debug,
|
||||
sentry_dsn,
|
||||
database_url,
|
||||
max_size,
|
||||
redis_url,
|
||||
openrouter_api_key,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,200 @@
|
|||
use rocket::response::Responder;
|
||||
use rocket::{Request, response, Response};
|
||||
use serde_json::{json, Value};
|
||||
use rocket::http::{ContentType, Status};
|
||||
use thiserror::Error;
|
||||
use std::io;
|
||||
use okapi::openapi3::{MediaType, RefOr, Responses};
|
||||
use reqwest::header::InvalidHeaderValue;
|
||||
use rocket_okapi::gen::OpenApiGenerator;
|
||||
use rocket_okapi::response::OpenApiResponderInner;
|
||||
use schemars::JsonSchema;
|
||||
use schemars::schema::SchemaObject;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::task;
|
||||
|
||||
#[derive(Error, Debug, serde::Serialize, schemars::JsonSchema, Copy, Clone)]
|
||||
pub enum Code {
|
||||
#[error("示例")]
|
||||
Example = 10000,
|
||||
#[error("未找到记录")]
|
||||
RecordNotFound,
|
||||
#[error("未找到用户")]
|
||||
UserNotFound,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("此功能暂未实现")]
|
||||
NotImplemented,
|
||||
#[error("未认证")]
|
||||
Unauthorized,
|
||||
#[error("无权限")]
|
||||
Forbidden,
|
||||
#[error("参数错误: {0}")]
|
||||
ParamsError(String),
|
||||
#[error("服务报错: {0}")]
|
||||
ServerError(String),
|
||||
#[error("数据库连接报错: 无法获取数据库连接: {0}")]
|
||||
DatabaseConnectionError(String),
|
||||
#[error("数据库Ping报错: {0}")]
|
||||
DatabasePingError(String),
|
||||
#[error("上游服务报错: {0}")]
|
||||
UpstreamError(String),
|
||||
#[error("服务反馈: {0}")]
|
||||
Feedback(Code),
|
||||
// 以下是由 thiserror 提供的自动错误转换
|
||||
#[error("EnvVarError: {0}")]
|
||||
EnvVarError(#[from] std::env::VarError),
|
||||
#[error("UuidError: {0}")]
|
||||
UuidError(#[from] uuid::Error),
|
||||
#[error("ChronoParseError: {0}")]
|
||||
ChronoParseError(#[from] chrono::ParseError),
|
||||
#[error("RedisError: {0}")]
|
||||
RedisError(#[from] redis::RedisError),
|
||||
#[error("SerdeJsonError: {0}")]
|
||||
SerdeJsonError(#[from] serde_json::Error),
|
||||
#[error("UrlError: {0}")]
|
||||
UrlError(#[from] url::ParseError),
|
||||
#[error("ReqwestError: {0}")]
|
||||
ReqwestError(#[from] reqwest::Error),
|
||||
#[error("SqlxError: {0}")]
|
||||
SqlxError(#[from] sqlx::Error),
|
||||
#[error("SeaQueryError: {0}")]
|
||||
SeaQueryError(#[from] sea_query::error::Error),
|
||||
#[error("IOError: {0}")]
|
||||
IOError(#[from] io::Error),
|
||||
#[error("TaskJoinError: {0}")]
|
||||
TaskJoinError(#[from] task::JoinError),
|
||||
#[error("mongodb::bson::datetime::Error: {0}")]
|
||||
BsonDatetimeError(#[from] mongodb::bson::datetime::Error),
|
||||
#[error("mongodb::error::Error: {0}")]
|
||||
MongodbError(#[from] mongodb::error::Error),
|
||||
#[error("mongodb::bson::oid::Error: {0}")]
|
||||
MongodbObjectIdError(#[from] mongodb::bson::oid::Error),
|
||||
#[error("InvalidHeaderValue: {0}")]
|
||||
InvalidHeaderValue(#[from] InvalidHeaderValue),
|
||||
// 其他任何错误
|
||||
#[error("AnyhowError: {0}")]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl Error {
|
||||
fn get_http_status(&self) -> Status {
|
||||
match self {
|
||||
Error::Unauthorized => Status::Unauthorized,
|
||||
Error::Forbidden => Status::Forbidden,
|
||||
Error::ParamsError(_) => Status::BadRequest,
|
||||
Error::ServerError(_) => Status::InternalServerError,
|
||||
Error::DatabaseConnectionError(_) => Status::ServiceUnavailable,
|
||||
Error::DatabasePingError(_) => Status::ServiceUnavailable,
|
||||
Error::UpstreamError(_) => Status::InternalServerError,
|
||||
Error::Feedback(_) => Status::Ok,
|
||||
_ => Status::InternalServerError,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
struct ErrorMessage {
|
||||
message: String,
|
||||
}
|
||||
|
||||
impl<'r> Responder<'r, 'static> for Error {
|
||||
fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> {
|
||||
let errormsg = self.to_string();
|
||||
println!(
|
||||
"Error: Request: {} Response: {:#}",
|
||||
req, errormsg
|
||||
);
|
||||
// 发往sentry的如果有Backtrace则会包含Backtrace,返回给用户的则一定不会有Backtrace
|
||||
sentry::capture_error(&self);
|
||||
let resp = ErrorMessage {
|
||||
message: errormsg
|
||||
.split(", Backtrace")
|
||||
.next()
|
||||
.unwrap_or_default()
|
||||
.parse()
|
||||
.unwrap()
|
||||
};
|
||||
let err_response = serde_json::to_string(&resp).unwrap();
|
||||
Response::build()
|
||||
.status(self.get_http_status())
|
||||
.header(ContentType::JSON)
|
||||
.sized_body(err_response.len(), std::io::Cursor::new(err_response))
|
||||
.ok()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn response_err(
|
||||
_gen: &mut OpenApiGenerator,
|
||||
schema: SchemaObject,
|
||||
desc: String,
|
||||
example: Option<Value>,
|
||||
) -> okapi::openapi3::Response {
|
||||
okapi::openapi3::Response {
|
||||
description: desc.to_owned(),
|
||||
content: okapi::map! {
|
||||
"application/json".to_owned() => MediaType{
|
||||
schema: Some(schema),
|
||||
example: example,
|
||||
..Default::default()
|
||||
}
|
||||
},
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenApiResponderInner for Error {
|
||||
fn responses(gen: &mut OpenApiGenerator) -> rocket_okapi::Result<Responses> {
|
||||
let schema = gen.json_schema::<ErrorMessage>();
|
||||
Ok(Responses {
|
||||
responses: okapi::map! {
|
||||
Error::Feedback(Code::Example).get_http_status().to_string() + ": Feedback" => RefOr::Object(
|
||||
response_err(gen, schema.clone(), Error::Feedback(Code::Example).to_string(),
|
||||
Some(json!({
|
||||
"message": Error::Feedback(Code::Example).to_string(),
|
||||
})))),
|
||||
|
||||
Error::Unauthorized.get_http_status().to_string() => RefOr::Object(
|
||||
response_err(gen, schema.clone(), Error::Unauthorized.to_string(),
|
||||
Some(json!({
|
||||
"message": Error::Unauthorized.to_string(),
|
||||
})))),
|
||||
|
||||
Error::Forbidden.get_http_status().to_string() => RefOr::Object(
|
||||
response_err(gen, schema.clone(), Error::Forbidden.to_string(),
|
||||
Some(json!({
|
||||
"message": Error::Forbidden.to_string(),
|
||||
})))),
|
||||
|
||||
Error::ParamsError("".to_string()).get_http_status().to_string() + ": ParamsError" => RefOr::Object(
|
||||
response_err(gen, schema.clone(), Error::ParamsError("".to_string()).to_string(),
|
||||
Some(json!({
|
||||
"message": Error::ParamsError("".to_string()).to_string(),
|
||||
})))),
|
||||
|
||||
Error::ServerError("".to_string()).get_http_status().to_string() + ": ServerError" => RefOr::Object(
|
||||
response_err(gen, schema.clone(), Error::ServerError("".to_string()).to_string(),
|
||||
Some(json!({
|
||||
"message": Error::ServerError("".to_string()).to_string(),
|
||||
})))),
|
||||
|
||||
Error::UpstreamError("".to_string()).get_http_status().to_string() + ": UpstreamError" => RefOr::Object(
|
||||
response_err(gen, schema.clone(), Error::UpstreamError("".to_string()).to_string(),
|
||||
Some(json!({
|
||||
"message": Error::UpstreamError("".to_string()).to_string(),
|
||||
})))),
|
||||
|
||||
Error::NotImplemented.get_http_status().to_string() => RefOr::Object(
|
||||
response_err(gen, schema.clone(), Error::NotImplemented.to_string(),
|
||||
Some(json!({
|
||||
"message": Error::NotImplemented.to_string(),
|
||||
})))),
|
||||
|
||||
},
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
#![allow(dead_code)]
|
||||
#![allow(unused_variables)]
|
||||
#![allow(unused_imports)]
|
||||
|
||||
extern crate core;
|
||||
extern crate dotenv;
|
||||
#[macro_use]
|
||||
extern crate rocket;
|
||||
|
||||
use std::env;
|
||||
use dotenv::dotenv;
|
||||
use rocket_okapi::openapi_get_routes;
|
||||
use rocket_okapi::rapidoc::GeneralConfig;
|
||||
use rocket_okapi::rapidoc::make_rapidoc;
|
||||
use rocket_okapi::rapidoc::RapiDocConfig;
|
||||
use rocket_okapi::settings::UrlObject;
|
||||
use rocket_okapi::swagger_ui::make_swagger_ui;
|
||||
use rocket_okapi::swagger_ui::SwaggerUIConfig;
|
||||
|
||||
use crate::services::Services;
|
||||
use crate::store::Store;
|
||||
|
||||
mod services;
|
||||
mod conf;
|
||||
mod model;
|
||||
mod route;
|
||||
mod providers;
|
||||
mod store;
|
||||
mod error;
|
||||
|
||||
|
||||
pub fn get_docs() -> SwaggerUIConfig {
|
||||
SwaggerUIConfig {
|
||||
url: "/openapi.json".to_string(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_rapidoc() -> RapiDocConfig {
|
||||
RapiDocConfig {
|
||||
general: GeneralConfig {
|
||||
spec_urls: vec![UrlObject::new("General", "/openapi.json")],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[rocket::main]
|
||||
async fn main() -> Result<(), rocket::Error> {
|
||||
env::set_var("RUST_BACKTRACE", "1");
|
||||
dotenv().ok();
|
||||
let routes = openapi_get_routes![
|
||||
route::index,
|
||||
route::favicon,
|
||||
route::get_ai_chat_response,
|
||||
route::get_user_chat_history,
|
||||
route::get_chat_status_today,
|
||||
];
|
||||
let store = Store::new().await;
|
||||
let sentry_dsn = store.config.sentry_dsn.clone();
|
||||
let app_env = store.config.app_env.clone();
|
||||
let _guard = sentry::init((
|
||||
sentry_dsn,
|
||||
sentry::ClientOptions {
|
||||
release: sentry::release_name!(),
|
||||
environment: Some(app_env.into()),
|
||||
send_default_pii: true,
|
||||
..Default::default()
|
||||
},
|
||||
));
|
||||
let _rocket = rocket::build()
|
||||
.manage(store)
|
||||
.mount("/", routes)
|
||||
.mount("/docs", make_swagger_ui(&get_docs()))
|
||||
.mount("/rapidoc", make_rapidoc(&get_rapidoc()))
|
||||
.launch()
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
use crate::error::Error;
|
||||
use serde::{Serialize, Serializer};
|
||||
use std::str::FromStr;
|
||||
use uuid::Uuid;
|
||||
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
use std::collections::HashMap;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::str::FromStr;
|
||||
use bigdecimal::BigDecimal;
|
||||
use chrono::NaiveDateTime;
|
||||
use mongodb::bson;
|
||||
use mongodb::bson::oid::ObjectId;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use uuid::Uuid;
|
||||
use crate::error::Error;
|
||||
use bson::serde_helpers::hex_string_as_object_id;
|
||||
|
||||
|
||||
pub type UserId = String;
|
||||
pub type UserName = String;
|
||||
pub type CreatedAt = NaiveDateTime;
|
||||
pub type CreatedBy = UserId;
|
||||
pub type UpdatedAt = Option<NaiveDateTime>;
|
||||
pub type UpdatedBy = Option<UserId>;
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
|
||||
pub struct User {
|
||||
pub id: UserId,
|
||||
pub name: UserName,
|
||||
pub created_at: CreatedAt,
|
||||
pub updated_at: UpdatedAt,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
|
||||
pub enum MessageRoleType {
|
||||
#[serde(rename="user")]
|
||||
User,
|
||||
#[serde(rename="ai")]
|
||||
AI,
|
||||
}
|
||||
|
||||
impl FromStr for MessageRoleType {
|
||||
type Err = Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"user" => Ok(Self::User),
|
||||
"ai" => Ok(Self::AI),
|
||||
_ => Err(Error::ParamsError("ai/user pls".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for MessageRoleType {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::User => f.write_str("user"),
|
||||
Self::AI => f.write_str("ai"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
pub type MessageId = String;
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
|
||||
pub struct Message {
|
||||
pub id: MessageId,
|
||||
pub user_id: UserId,
|
||||
#[serde(rename="type")]
|
||||
pub type_: MessageRoleType,
|
||||
pub text: String,
|
||||
pub created_at: CreatedAt,
|
||||
pub created_by: CreatedBy,
|
||||
pub updated_at: UpdatedAt,
|
||||
pub updated_by: UpdatedBy,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
|
||||
pub struct Context {
|
||||
pub user: User,
|
||||
}
|
||||
|
||||
impl Context {
|
||||
pub fn new(user: User) -> Self {
|
||||
Self {
|
||||
user
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
use std::collections::HashMap;
|
||||
use std::fmt::{Display, Formatter, Write};
|
||||
use std::ops::Deref;
|
||||
use std::str::FromStr;
|
||||
|
||||
use rocket::serde::{Deserialize, Serialize};
|
||||
use schemars::JsonSchema;
|
||||
use schemars::_private::NoSerialize;
|
||||
use serde::{Deserializer, Serializer};
|
||||
use serde_json::Value;
|
||||
use sqlx::types::Uuid;
|
||||
|
||||
pub use req::*;
|
||||
pub use res::*;
|
||||
pub use table::*;
|
||||
pub use entity::*;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::error::Error::ParamsError;
|
||||
|
||||
mod deser;
|
||||
mod entity;
|
||||
mod req;
|
||||
mod res;
|
||||
mod table;
|
|
@ -0,0 +1,32 @@
|
|||
use chrono::NaiveDateTime;
|
||||
use mongodb::bson::oid::ObjectId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use schemars::JsonSchema;
|
||||
use rocket::form::FromForm;
|
||||
use crate::model::MessageRoleType;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, FromForm)]
|
||||
pub struct GetAiChatResponseInput {
|
||||
pub message: String,
|
||||
pub user_name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct GetUserChatHistoryInput {
|
||||
pub user_name: String,
|
||||
pub last_n: i8,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct GetChatStatusTodayInput {
|
||||
pub user_name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct NewMessage {
|
||||
pub user_id: String,
|
||||
#[serde(rename="type")]
|
||||
pub type_: MessageRoleType,
|
||||
pub text: String,
|
||||
}
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use schemars::JsonSchema;
|
||||
use crate::model::MessageRoleType;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct GetAiChatResponseOutput {
|
||||
pub response: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct GetChatStatusTodayOutput {
|
||||
pub user_name: String,
|
||||
pub chat_cnt: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct UserChatMessage {
|
||||
#[serde(rename="type")]
|
||||
pub type_: MessageRoleType,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
pub type GetUserChatHistoryOutput = Vec<UserChatMessage>;
|
|
@ -0,0 +1,62 @@
|
|||
use mongodb::bson::DateTime;
|
||||
use mongodb::bson::oid::ObjectId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::error::Error;
|
||||
|
||||
use crate::model::{Message, User};
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UserDoc {
|
||||
pub _id: ObjectId,
|
||||
pub name: String,
|
||||
pub created_at: DateTime,
|
||||
pub updated_at: Option<DateTime>,
|
||||
}
|
||||
|
||||
impl UserDoc {
|
||||
pub fn to_entity(self) -> Result<User, Error> {
|
||||
let user = User {
|
||||
id: self._id.to_hex(),
|
||||
name: self.name,
|
||||
created_at: self.created_at.to_chrono().naive_utc(),
|
||||
updated_at: if let Some(updated_at) = self.updated_at {
|
||||
Some(updated_at.to_chrono().naive_utc())
|
||||
} else { None },
|
||||
};
|
||||
Ok(user)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MessageDoc {
|
||||
pub _id: ObjectId,
|
||||
pub user_id: ObjectId,
|
||||
#[serde(rename="type")]
|
||||
pub type_: String,
|
||||
pub text: String,
|
||||
pub created_at: DateTime,
|
||||
pub created_by: ObjectId,
|
||||
pub updated_at: Option<DateTime>,
|
||||
pub updated_by: Option<ObjectId>,
|
||||
}
|
||||
|
||||
impl MessageDoc {
|
||||
pub fn to_entity(self) -> Result<Message, Error> {
|
||||
let msg = Message {
|
||||
id: self._id.to_hex(),
|
||||
user_id: self.user_id.to_hex(),
|
||||
type_: self.type_.parse()?,
|
||||
text: self.text,
|
||||
created_at: self.created_at.to_chrono().naive_utc(),
|
||||
created_by: self.created_by.to_hex(),
|
||||
updated_at: if let Some(updated_at) = self.updated_at {
|
||||
Some(updated_at.to_chrono().naive_utc())
|
||||
} else { None },
|
||||
updated_by: if let Some(updated_by) = self.updated_by {
|
||||
Some(updated_by.to_hex())
|
||||
} else { None },
|
||||
};
|
||||
Ok(msg)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,132 @@
|
|||
use std::cmp::max;
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
use anyhow::Context;
|
||||
use chrono::{DateTime, NaiveDateTime, NaiveTime, Utc};
|
||||
use futures::TryStreamExt;
|
||||
use mongodb::bson::{DateTime as BsonDateTime, doc};
|
||||
use mongodb::bson::oid::ObjectId;
|
||||
use mongodb::options::{FindOneOptions, FindOptions};
|
||||
use crate::error::Error;
|
||||
|
||||
use crate::model::{Message, MessageDoc, MessageRoleType, NewMessage, User};
|
||||
use crate::store::api_client::ApiClients;
|
||||
use crate::store::cache::Caches;
|
||||
use crate::store::database::Databases;
|
||||
use crate::store::Store;
|
||||
|
||||
pub struct ChatProvider {
|
||||
store: Store,
|
||||
db: Databases,
|
||||
cache: Caches,
|
||||
api: ApiClients,
|
||||
}
|
||||
|
||||
|
||||
impl ChatProvider {
|
||||
pub fn new(store: Store) -> Self {
|
||||
Self {
|
||||
store: store.clone(),
|
||||
db: store.databases.clone(),
|
||||
cache: store.caches.clone(),
|
||||
api: store.api_clients.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ChatProvider {
|
||||
pub async fn check_user_message_limited_in_30_seconds(&self, user: User) -> Result<bool, Error> {
|
||||
let now = Utc::now();
|
||||
let dt_start = now - Duration::from_secs(30);
|
||||
let start = BsonDateTime::from_chrono(now);
|
||||
let filter = doc! {
|
||||
"user_id": ObjectId::from_str(user.id.as_str()).with_context(||format!("parse oid error: {}", user.id))?,
|
||||
"type": MessageRoleType::User.to_string(),
|
||||
"created_at": {"$gte": start}
|
||||
};
|
||||
debug!("filter: {}", filter);
|
||||
let count = self.db.message().count_documents(filter, None).await
|
||||
.with_context(|| "count_documents".to_string())?;
|
||||
debug!("count: {}", count);
|
||||
if count > 3 {
|
||||
Ok(true)
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn check_user_message_limited_in_daily(&self, user: User) -> Result<bool, Error> {
|
||||
let now = Utc::now();
|
||||
let dt_start = NaiveDateTime::new(now.date_naive(), NaiveTime::default()).and_utc();
|
||||
let start = BsonDateTime::from_chrono(dt_start);
|
||||
let filter = doc! {
|
||||
"user_id": ObjectId::from_str(user.id.as_str()).with_context(||format!("parse oid error: {}", user.id))?,
|
||||
"type": MessageRoleType::User.to_string(),
|
||||
"created_at": {"$gte": start}
|
||||
};
|
||||
debug!("filter: {}", filter);
|
||||
let count = self.db.message().count_documents(filter, None).await
|
||||
.with_context(|| "count_documents".to_string())?;
|
||||
debug!("count: {}", count);
|
||||
if count > 20 {
|
||||
Ok(true)
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_chat_message(&self, messages: Vec<NewMessage>) -> Result<usize, Error> {
|
||||
let mut docs = vec![];
|
||||
for message in messages.iter() {
|
||||
let doc = MessageDoc {
|
||||
_id: ObjectId::new(),
|
||||
user_id: ObjectId::from_str(message.user_id.as_str()).with_context(||format!("parse oid error: {}", message.user_id))?,
|
||||
type_: message.type_.to_string(),
|
||||
text: message.text.to_owned(),
|
||||
created_at: BsonDateTime::now(),
|
||||
created_by: ObjectId::from_str(message.user_id.as_str()).with_context(||format!("parse oid error: {}", message.user_id))?,
|
||||
updated_at: None,
|
||||
updated_by: None,
|
||||
};
|
||||
docs.push(doc);
|
||||
}
|
||||
let res = self.db.message().insert_many(docs, None).await
|
||||
.with_context(|| "insert_many".to_string())?;
|
||||
debug!("inserted: {:?}", res);
|
||||
Ok(res.inserted_ids.len())
|
||||
}
|
||||
|
||||
pub async fn get_user_chat_messages(&self, user: User, limit: i64) -> Result<Vec<Message>, Error> {
|
||||
let limit = max(limit, 10);
|
||||
let opts = FindOptions::builder().sort(doc! {"created_at": -1}).limit(limit).build();
|
||||
let filter = doc! {
|
||||
"user_id": ObjectId::from_str(user.id.as_str())?,
|
||||
};
|
||||
debug!("filter: {}", filter);
|
||||
let mut cursor = self.db.message().find(filter, opts).await
|
||||
.with_context(|| "find".to_string())?;
|
||||
let mut res = vec![];
|
||||
while let Some(doc) = cursor.try_next().await
|
||||
.with_context(|| "try_next".to_string())? {
|
||||
res.push(doc.to_entity()?)
|
||||
}
|
||||
debug!("messages: {:?}", res);
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
pub async fn get_user_chat_messages_count_today(&self, user: User) -> Result<u64, Error> {
|
||||
let now = Utc::now();
|
||||
let dt_start = NaiveDateTime::new(now.date_naive(), NaiveTime::default()).and_utc();
|
||||
let start = BsonDateTime::from_chrono(dt_start);
|
||||
let filter = doc! {
|
||||
"user_id": ObjectId::from_str(user.id.as_str()).with_context(||format!("parse oid error: {}", user.id))?,
|
||||
"type": MessageRoleType::User.to_string(),
|
||||
"created_at": {"$gte": start}
|
||||
};
|
||||
debug!("filter: {}", filter);
|
||||
let count = self.db.message().count_documents(filter, None).await
|
||||
.with_context(|| "count_documents".to_string())?;
|
||||
debug!("count: {}", count);
|
||||
Ok(count)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
use crate::providers::ping::PingProvider;
|
||||
use crate::providers::chat::ChatProvider;
|
||||
use crate::providers::openrouter::OpenRouterProvider;
|
||||
use crate::providers::user::UserProvider;
|
||||
use crate::store::Store;
|
||||
|
||||
mod ping;
|
||||
mod chat;
|
||||
mod openrouter;
|
||||
mod user;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Providers {
|
||||
store: Store,
|
||||
}
|
||||
|
||||
impl Providers {
|
||||
pub fn new(store: &Store) -> Self {
|
||||
Self {
|
||||
store: store.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ping(&self) -> PingProvider {
|
||||
PingProvider::new(self.store.clone())
|
||||
}
|
||||
|
||||
pub fn openrouter(&self) -> OpenRouterProvider {
|
||||
OpenRouterProvider::new(self.store.clone())
|
||||
}
|
||||
|
||||
pub fn user(&self) -> UserProvider {
|
||||
UserProvider::new(self.store.clone())
|
||||
}
|
||||
|
||||
pub fn chat(&self) -> ChatProvider {
|
||||
ChatProvider::new(self.store.clone())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,118 @@
|
|||
use anyhow::Context;
|
||||
use async_openai::Client;
|
||||
use async_openai::config::OpenAIConfig;
|
||||
use async_openai::types::{ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent, CompletionUsage, CreateChatCompletionRequestArgs, CreateChatCompletionResponse, CreateCompletionRequestArgs, Role};
|
||||
use reqwest::header::HeaderMap;
|
||||
use rocket_okapi::hash_map;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::error::Error;
|
||||
use crate::store::api_client::ApiClients;
|
||||
use crate::store::cache::Caches;
|
||||
use crate::store::database::Databases;
|
||||
use crate::store::Store;
|
||||
|
||||
pub struct OpenRouterProvider {
|
||||
store: Store,
|
||||
db: Databases,
|
||||
cache: Caches,
|
||||
api: ApiClients,
|
||||
}
|
||||
|
||||
|
||||
impl OpenRouterProvider {
|
||||
pub fn new(store: Store) -> Self {
|
||||
Self {
|
||||
store: store.clone(),
|
||||
db: store.databases.clone(),
|
||||
cache: store.caches.clone(),
|
||||
api: store.api_clients.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenRouterProvider {
|
||||
pub async fn chat(self, content: String) -> Result<String, Error> {
|
||||
// let config = OpenAIConfig::default()
|
||||
// .with_api_base("https://openrouter.ai/api/v1")
|
||||
// .with_api_key(self.store.config.openrouter_api_key);
|
||||
// let client = Client::with_config(config);
|
||||
// let request = CreateChatCompletionRequestArgs::default()
|
||||
// .model("mistralai/mistral-7b-instruct:free")
|
||||
// .messages(vec![
|
||||
// ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
|
||||
// content: ChatCompletionRequestUserMessageContent::Text(content),
|
||||
// role: Role::User,
|
||||
// name: None,
|
||||
// })
|
||||
// ]).build().with_context(|| "build CreateChatCompletionRequestArgs".to_string())?;
|
||||
//
|
||||
// println!("{}", serde_json::to_string(&request).unwrap());
|
||||
// let response = client.chat().create(request).await
|
||||
// .with_context(|| "chat create".to_string())?;
|
||||
let url = "https://openrouter.ai/api/v1/chat/completions";
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("Authorization", format!("Bearer {}", self.store.config.openrouter_api_key).parse()?);
|
||||
headers.insert("Content-Type", "application/json".parse()?);
|
||||
let body = OpenRouterCreateChatCompletionRequestArgs {
|
||||
model: "mistralai/mistral-7b-instruct:free".to_string(),
|
||||
messages: vec![OpenRouterCreateChatCompletionRequestArgsMessage {
|
||||
role: Role::User,
|
||||
content: content,
|
||||
}],
|
||||
};
|
||||
let client = reqwest::Client::new();
|
||||
let response: OpenRouterCreateChatCompletionResponse = client.post(url).headers(headers).json(&body)
|
||||
.send().await.with_context(|| "send request to openrouter".to_string())?
|
||||
.json().await.with_context(|| "deserialize from openrouter".to_string())?;
|
||||
debug!("response: {:?}", response);
|
||||
let choice = response.choices[0].clone();
|
||||
if let Some(response_content) = choice.message.content {
|
||||
Ok(response_content)
|
||||
} else {
|
||||
Ok("todo".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct OpenRouterCreateChatCompletionRequestArgsMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct OpenRouterCreateChatCompletionRequestArgs {
|
||||
pub model: String,
|
||||
pub messages: Vec<OpenRouterCreateChatCompletionRequestArgsMessage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct OpenRouterChatChoiceMessage {
|
||||
pub role: Role,
|
||||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct OpenRouterChatChoice {
|
||||
pub message: OpenRouterChatChoiceMessage,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct OpenRouterCompletionUsage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
pub total_cost: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct OpenRouterCreateChatCompletionResponse {
|
||||
pub id: String,
|
||||
pub model: String,
|
||||
pub created: u32,
|
||||
pub object: String,
|
||||
pub choices: Vec<OpenRouterChatChoice>,
|
||||
pub usage: Option<OpenRouterCompletionUsage>,
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
use sqlx::Connection;
|
||||
use crate::error::Error;
|
||||
|
||||
use crate::store::api_client::ApiClients;
|
||||
use crate::store::cache::Caches;
|
||||
use crate::store::database::Databases;
|
||||
use crate::store::Store;
|
||||
|
||||
pub struct PingProvider {
|
||||
store: Store,
|
||||
db: Databases,
|
||||
cache: Caches,
|
||||
api: ApiClients,
|
||||
}
|
||||
|
||||
impl PingProvider {
|
||||
pub fn new(store: Store) -> Self {
|
||||
Self {
|
||||
store: store.clone(),
|
||||
db: store.databases.clone(),
|
||||
cache: store.caches.clone(),
|
||||
api: store.api_clients.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PingProvider {
|
||||
pub async fn ping_mysql(&self) -> Result<String, Error> {
|
||||
// let mut conn = self.db.user();
|
||||
// conn.ping().await?;
|
||||
Ok("pong".to_string())
|
||||
}
|
||||
|
||||
pub async fn ping_pgsql(&self) -> Result<String, Error> {
|
||||
// let mut conn = self.db.default.acquire().await?;
|
||||
// conn.ping().await?;
|
||||
Ok("pong".to_string())
|
||||
}
|
||||
|
||||
pub async fn ping_redis(&self) -> Result<String, Error> {
|
||||
// let mut conn = self.cache.default.clone();
|
||||
// let reply: RedisResult<String> = redis::cmd("PING").query_async(&mut conn).await;
|
||||
// assert_eq!("PONG", reply.unwrap());
|
||||
Ok("pong".to_string())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
use anyhow::Context;
|
||||
use chrono::{NaiveDateTime, Utc};
|
||||
use mongodb::bson::{DateTime, doc};
|
||||
use mongodb::bson::oid::ObjectId;
|
||||
use crate::error::Error;
|
||||
use crate::model::{User, UserDoc};
|
||||
use crate::store::api_client::ApiClients;
|
||||
use crate::store::cache::Caches;
|
||||
use crate::store::database::Databases;
|
||||
use crate::store::Store;
|
||||
|
||||
pub struct UserProvider {
|
||||
store: Store,
|
||||
db: Databases,
|
||||
cache: Caches,
|
||||
api: ApiClients,
|
||||
}
|
||||
|
||||
|
||||
impl UserProvider {
|
||||
pub fn new(store: Store) -> Self {
|
||||
Self {
|
||||
store: store.clone(),
|
||||
db: store.databases.clone(),
|
||||
cache: store.caches.clone(),
|
||||
api: store.api_clients.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserProvider {
|
||||
pub async fn get_user_by_name(self, user_name: String) -> Result<Option<User>, Error>{
|
||||
let user = self.db.user().find_one(doc! {"name": user_name.clone()}, None).await
|
||||
.with_context(|| format!("find_one by name: {}", user_name))?;
|
||||
if let Some(user) = user {
|
||||
Ok(Some(user.clone().to_entity().with_context(||format!("found user to_entity: {:?}", user))?))
|
||||
} else {
|
||||
let user = UserDoc {
|
||||
_id: ObjectId::new(),
|
||||
name: user_name,
|
||||
created_at: DateTime::now(),
|
||||
updated_at: None,
|
||||
};
|
||||
let res = self.db.user().insert_one(user, None).await
|
||||
.with_context(|| "insert_one".to_string())?;
|
||||
let user = self.db.user().find_one(doc! {"_id": res.inserted_id.clone()}, None).await
|
||||
.with_context(|| format!("find_one by _id {}", res.inserted_id))?;
|
||||
if let Some(user) = user {
|
||||
Ok(Some(user.clone().to_entity().with_context(||format!("new user to_entity: {:?}", user))?))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
use std::ops::Deref;
|
||||
use rocket::form::Form;
|
||||
use rocket::serde::json::Json;
|
||||
|
||||
use rocket::State;
|
||||
use rocket_okapi::openapi;
|
||||
use crate::error::{Code, Error};
|
||||
use crate::model::{Context, GetAiChatResponseInput, GetAiChatResponseOutput, GetChatStatusTodayOutput, GetUserChatHistoryOutput};
|
||||
|
||||
use crate::services::Services;
|
||||
use crate::error::Error::ParamsError;
|
||||
use crate::providers::Providers;
|
||||
use crate::store::Store;
|
||||
|
||||
|
||||
#[openapi(tag = "Hello World")]
|
||||
#[get("/")]
|
||||
pub async fn index() -> &'static str {
|
||||
"Hello, world!"
|
||||
}
|
||||
|
||||
#[openapi(tag = "favicon.ico")]
|
||||
#[get("/favicon.ico")]
|
||||
pub async fn favicon() -> &'static str {
|
||||
"favicon.ico"
|
||||
}
|
||||
|
||||
|
||||
/// # Get AI Chat Response
|
||||
#[openapi(tag = "Chat")]
|
||||
#[post("/api/v1/get_ai_chat_response", data="<req>")]
|
||||
pub async fn get_ai_chat_response(store: &State<Store>, req: Json<GetAiChatResponseInput>) -> Result<Json<GetAiChatResponseOutput>, Error> {
|
||||
let req = req.into_inner();
|
||||
let pvd = Providers::new(store);
|
||||
let user = pvd.user().get_user_by_name(req.user_name.clone()).await?;
|
||||
if let Some(user) = user {
|
||||
let ctx = Context::new(user);
|
||||
let svc = Services::new(ctx, pvd);
|
||||
let res = svc.chat().get_ai_chat_response(req).await?;
|
||||
Ok(Json(res))
|
||||
} else {
|
||||
Err(Error::Feedback(Code::UserNotFound))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// # Get User Chat History
|
||||
#[openapi(tag = "Chat")]
|
||||
#[get("/api/v1/get_user_chat_history?<user_name>&<last_n>")]
|
||||
pub async fn get_user_chat_history(store: &State<Store>, user_name: String, last_n: i64) -> Result<Json<GetUserChatHistoryOutput>, Error> {
|
||||
let pvd = Providers::new(store);
|
||||
let user = pvd.user().get_user_by_name(user_name).await?;
|
||||
if let Some(user) = user {
|
||||
let ctx = Context::new(user);
|
||||
let svc = Services::new(ctx, pvd);
|
||||
let res = svc.chat().get_user_chat_history(last_n).await?;
|
||||
Ok(Json(res))
|
||||
} else {
|
||||
Err(Error::Feedback(Code::UserNotFound))
|
||||
}
|
||||
}
|
||||
|
||||
/// # Get Chat Status Today
|
||||
#[openapi(tag = "Chat")]
|
||||
#[get("/api/v1/get_chat_status_today?<user_name>")]
|
||||
pub async fn get_chat_status_today(store: &State<Store>, user_name: String) -> Result<Json<GetChatStatusTodayOutput>, Error> {
|
||||
let pvd = Providers::new(store);
|
||||
let user = pvd.user().get_user_by_name(user_name).await?;
|
||||
if let Some(user) = user {
|
||||
let ctx = Context::new(user);
|
||||
let svc = Services::new(ctx, pvd);
|
||||
let res = svc.chat().get_chat_status_today().await?;
|
||||
Ok(Json(res))
|
||||
} else {
|
||||
Err(Error::Feedback(Code::UserNotFound))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
use std::fmt::format;
|
||||
use anyhow::Context as AnyhowContext;
|
||||
use chrono::{NaiveDateTime, Utc};
|
||||
use log::debug;
|
||||
use mongodb::bson;
|
||||
use mongodb::bson::oid::ObjectId;
|
||||
use redis::ToRedisArgs;
|
||||
use crate::error::{Code, Error};
|
||||
use crate::model::{Context, GetAiChatResponseInput, GetAiChatResponseOutput, GetChatStatusTodayOutput, GetUserChatHistoryOutput, Message, MessageRoleType, NewMessage, UserChatMessage};
|
||||
use crate::providers::Providers;
|
||||
|
||||
pub struct ChatService {
|
||||
ctx: Context,
|
||||
pvd: Providers,
|
||||
}
|
||||
|
||||
impl ChatService {
|
||||
pub fn new(context: Context, providers: Providers) -> Self {
|
||||
Self {
|
||||
ctx: context,
|
||||
pvd: providers,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl ChatService {
|
||||
pub async fn get_ai_chat_response(&self, req: GetAiChatResponseInput) -> Result<GetAiChatResponseOutput, Error> {
|
||||
let limited = self.pvd.chat().check_user_message_limited_in_30_seconds(self.ctx.user.clone()).await
|
||||
.with_context(||format!("check_user_message_limited_in_30_seconds: {:?}", self.ctx.user.clone()))?;
|
||||
if limited {
|
||||
return Err(Error::Unauthorized);
|
||||
}
|
||||
let limited = self.pvd.chat().check_user_message_limited_in_daily(self.ctx.user.clone()).await
|
||||
.with_context(||format!("check_user_message_limited_in_daily: {:?}", self.ctx.user.clone()))?;
|
||||
if limited {
|
||||
return Err(Error::Unauthorized);
|
||||
}
|
||||
|
||||
let request_content = req.message;
|
||||
// todo: request conent middle out
|
||||
let response_content = self.pvd.openrouter().chat(request_content.clone()).await
|
||||
.with_context(|| format!("chat: {}", request_content.clone()))?;
|
||||
let now = Utc::now();
|
||||
let created_at = NaiveDateTime::new(now.date_naive(), now.time());
|
||||
let user_message = NewMessage {
|
||||
user_id: self.ctx.user.id.to_string(),
|
||||
type_: MessageRoleType::User,
|
||||
text: request_content.to_string(),
|
||||
};
|
||||
let ai_message = NewMessage {
|
||||
user_id: self.ctx.user.id.to_string(),
|
||||
type_: MessageRoleType::AI,
|
||||
text: response_content.to_string(),
|
||||
};
|
||||
let messages = vec![user_message, ai_message];
|
||||
let count = self.pvd.chat().add_chat_message(messages).await
|
||||
.with_context(|| "add_chat_message".to_string())?;
|
||||
debug!("Added {count} chat messages");
|
||||
let res = GetAiChatResponseOutput {
|
||||
response: response_content,
|
||||
};
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
pub async fn get_user_chat_history(&self, last_n: i64) -> Result<GetUserChatHistoryOutput, Error> {
|
||||
let messages = self.pvd.chat().get_user_chat_messages(self.ctx.user.clone(), last_n).await
|
||||
.with_context(||format!("get_user_chat_messages: {:?}", self.ctx.user.clone()))?;
|
||||
let mut res = vec![];
|
||||
for msg in messages.iter() {
|
||||
res.push(UserChatMessage {
|
||||
type_: msg.type_.clone(),
|
||||
text: msg.text.clone(),
|
||||
});
|
||||
}
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
pub async fn get_chat_status_today(&self) -> Result<GetChatStatusTodayOutput, Error> {
|
||||
let count = self.pvd.chat().get_user_chat_messages_count_today(self.ctx.user.clone()).await
|
||||
.with_context(||format!("get_user_chat_messages_count_today: {:?}", self.ctx.user.clone()))?;
|
||||
let res = GetChatStatusTodayOutput {
|
||||
user_name: self.ctx.user.name.clone(),
|
||||
chat_cnt: count,
|
||||
};
|
||||
Ok(res)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
use crate::model::Context;
|
||||
use crate::providers::Providers;
|
||||
use crate::services::ping::PingService;
|
||||
use crate::services::chat::ChatService;
|
||||
use crate::store::Store;
|
||||
|
||||
mod ping;
|
||||
mod chat;
|
||||
|
||||
pub struct Services {
|
||||
ctx: Context,
|
||||
pvd: Providers,
|
||||
}
|
||||
|
||||
impl Services {
|
||||
pub fn new(context: Context, providers: Providers) -> Self {
|
||||
Self {
|
||||
ctx: context,
|
||||
pvd: providers,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ping(&self) -> PingService {
|
||||
PingService::new(self.ctx.clone(), self.pvd.clone())
|
||||
}
|
||||
|
||||
pub fn chat(&self) -> ChatService {
|
||||
ChatService::new(self.ctx.clone(), self.pvd.clone())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
use crate::error::Error;
|
||||
use crate::model::Context;
|
||||
use crate::providers::Providers;
|
||||
|
||||
pub struct PingService {
|
||||
ctx: Context,
|
||||
pvd: Providers,
|
||||
}
|
||||
|
||||
impl PingService {
|
||||
pub fn new(context: Context, providers: Providers) -> Self {
|
||||
Self {
|
||||
ctx: context,
|
||||
pvd: providers,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PingService {
|
||||
pub async fn do_ping(&self) -> Result<String, Error> {
|
||||
self.pvd.ping().ping_pgsql().await?;
|
||||
self.pvd.ping().ping_redis().await?;
|
||||
Ok("pong".to_string())
|
||||
}
|
||||
pub async fn do_ping_mysql(&self) -> Result<String, Error> {
|
||||
self.pvd.ping().ping_mysql().await?;
|
||||
Ok("pong".to_string())
|
||||
}
|
||||
pub async fn do_ping_pgsql(&self) -> Result<String, Error> {
|
||||
self.pvd.ping().ping_pgsql().await?;
|
||||
Ok("pong".to_string())
|
||||
}
|
||||
pub async fn do_ping_redis(&self) -> Result<String, Error> {
|
||||
self.pvd.ping().ping_redis().await?;
|
||||
Ok("pong".to_string())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
use std::time::Duration;
|
||||
|
||||
use reqwest;
|
||||
use url::Url;
|
||||
use crate::conf::Config;
|
||||
use crate::error::Error;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ApiClients {
|
||||
}
|
||||
|
||||
impl ApiClients {
|
||||
pub fn new(config: Config) -> Self {
|
||||
ApiClients {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct HttpClient {
|
||||
host: Url,
|
||||
location: String,
|
||||
pub client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl HttpClient {
|
||||
pub fn new(host: String, location: String) -> Self {
|
||||
let host = Url::parse(host.as_str()).expect("Invalid host");
|
||||
println!("{:?}", host);
|
||||
HttpClient {
|
||||
host,
|
||||
location,
|
||||
client: reqwest::Client::builder()
|
||||
.tcp_keepalive(Some(Duration::from_secs(3600)))
|
||||
.build()
|
||||
.expect("Build reqwest::Client failed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpClient {
|
||||
pub fn make_url(&self, path: String) -> Result<Url, Error> {
|
||||
let url = format!(
|
||||
"{}{}{}",
|
||||
self.host
|
||||
.as_str()
|
||||
.strip_suffix("/")
|
||||
.unwrap_or(self.host.as_str()),
|
||||
self.location.as_str(),
|
||||
path.as_str()
|
||||
);
|
||||
let url = Url::parse(url.as_str())?;
|
||||
println!("{:?}", url.to_string());
|
||||
Ok(url)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
use std::env;
|
||||
|
||||
use redis::aio::ConnectionManager;
|
||||
use redis::Client;
|
||||
use crate::conf::Config;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Caches {
|
||||
// pub default: ConnectionManager,
|
||||
}
|
||||
|
||||
impl Caches {
|
||||
pub async fn new(config: Config) -> Self {
|
||||
println!("Caches init");
|
||||
Caches {
|
||||
// default: connect(config).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect(config: Config) -> ConnectionManager {
|
||||
let client = Client::open(config.redis_url).unwrap();
|
||||
ConnectionManager::new(client).await.unwrap()
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
use std::env;
|
||||
use rocket::http::Status;
|
||||
use rocket::request::FromRequest;
|
||||
use rocket::{request, Request};
|
||||
use mongodb::options::ClientOptions;
|
||||
use mongodb::{Client, Collection, Database};
|
||||
use crate::conf::Config;
|
||||
use crate::error::Error;
|
||||
use crate::model::{MessageDoc, UserDoc};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Databases {
|
||||
pub default: Database,
|
||||
}
|
||||
|
||||
impl Databases {
|
||||
pub async fn new(config: Config) -> Self {
|
||||
println!("Databases init");
|
||||
let db = Databases {
|
||||
default: connect(config).await.expect("can not connect to mongodb."),
|
||||
};
|
||||
println!("{db:?}");
|
||||
db
|
||||
}
|
||||
|
||||
pub fn user(&self) -> Collection<UserDoc> {
|
||||
return self.default.collection::<UserDoc>("user")
|
||||
}
|
||||
|
||||
pub fn message(&self) -> Collection<MessageDoc> {
|
||||
return self.default.collection::<MessageDoc>("message")
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn connect(config: Config) -> mongodb::error::Result<Database> {
|
||||
let mongo_host = env::var("MONGO_HOST").unwrap_or("localhost".to_string());
|
||||
let mongo_port = env::var("MONGO_PORT").unwrap_or("27017".to_string());
|
||||
let mongo_username = env::var("MONGO_USERNAME").expect("MONGO_USERNAME is not set in .env");
|
||||
let mongo_password = env::var("MONGO_PASSWORD").expect("MONGO_PASSWORD is not set in .env");
|
||||
// mongodb://{MONGO_USERNAME}:{MONGO_PASSWORD}@mongodb:27017/
|
||||
let mongo_uri = format!("mongodb://{mongo_username}:{mongo_password}@{mongo_host}:{mongo_port}/");
|
||||
println!("mongo_uri: {mongo_uri}");
|
||||
let mongo_db_name = env::var("MONGO_DB_NAME").unwrap_or("simplylab".to_string());
|
||||
|
||||
let client_options = ClientOptions::parse(mongo_uri).await?;
|
||||
let client = Client::with_options(client_options)?;
|
||||
let dbs = client.list_databases(None, None).await?;
|
||||
println!("databases: {dbs:?}");
|
||||
let database = client.database(mongo_db_name.as_str());
|
||||
Ok(database)
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for &'r Databases {
|
||||
type Error = Error;
|
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Error> {
|
||||
let db = request.rocket().state::<Databases>();
|
||||
match db {
|
||||
Some(db) => request::Outcome::Success(db),
|
||||
None => request::Outcome::Error((
|
||||
Status::ServiceUnavailable,
|
||||
Error::DatabaseConnectionError("从 State 获取连接池失败".to_string()),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
use crate::conf::Config;
|
||||
use crate::store::api_client::ApiClients;
|
||||
use crate::store::cache::Caches;
|
||||
use crate::store::database::Databases;
|
||||
|
||||
pub mod api_client;
|
||||
pub mod cache;
|
||||
pub mod database;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Store {
|
||||
pub config: Config,
|
||||
pub databases: Databases,
|
||||
pub caches: Caches,
|
||||
pub api_clients: ApiClients,
|
||||
}
|
||||
|
||||
impl Store {
|
||||
pub async fn new() -> Self {
|
||||
let config = Config::new().await;
|
||||
Store {
|
||||
config: config.clone(),
|
||||
databases: Databases::new(config.clone()).await,
|
||||
caches: Caches::new(config.clone()).await,
|
||||
api_clients: ApiClients::new(config.clone()),
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue