use std::sync::LazyLock; use anyhow::Context; use axum::{ extract::Query, response::{IntoResponse, Json}, routing::get, Router, }; use chrono::TimeDelta; use diesel::{dsl::insert_into, prelude::*, update}; use regex::Regex; use serde::Deserialize; use serde_json::json; use uuid::Uuid; use validator::Validate; use crate::{ api_keys::{try_parse_as_uuid, ApiKey}, app_error::AppError, app_state::{AppState, DbConn}, channels::Channel, governors::Governor, projects::{Project, DEFAULT_PROJECT_NAME}, schema::{api_keys, messages}, }; const TEAM_GOVERNOR_DEFAULT_WINDOW_SIZE_SEC: i64 = 300; const TEAM_GOVERNOR_DEFAULT_MAX_COUNT: i32 = 50; static RE_PROJECT_NAME: LazyLock = LazyLock::new(|| Regex::new(r"^[a-z0-9_-]{1,100}$").unwrap()); pub fn new_router() -> Router { Router::new().route("/say", get(say_get)) } #[derive(Deserialize, Validate)] struct SayQuery { #[serde(alias = "k")] key: String, #[serde(alias = "p")] #[serde(default = "default_project")] #[validate(regex( path = *RE_PROJECT_NAME, message = "may be no more than 100 characters and contain only alphanumerics, -, and _", ))] project: String, #[serde(alias = "m")] #[validate(length( min = 1, max = 2048, message = "message must be non-empty and no larger than 2KiB" ))] message: String, } fn default_project() -> String { DEFAULT_PROJECT_NAME.to_string() } async fn say_get( DbConn(db_conn): DbConn, Query(mut query): Query, ) -> Result { query.project = query.project.to_lowercase().replace(" ", "_"); query.validate().map_err(AppError::from_validation_errors)?; let api_key = { let query_key = try_parse_as_uuid(&query.key) .or(Err(AppError::Forbidden("key not accepted".to_string())))?; db_conn .interact::<_, Result>(move |conn| { update(api_keys::table.filter(ApiKey::with_id(&query_key))) .set(api_keys::last_used_at.eq(diesel::dsl::now)) .returning(ApiKey::as_returning()) .get_result(conn) .optional() .context("failed to get API key")? .ok_or(AppError::Forbidden("key not accepted.".to_string())) }) .await .unwrap()? }; let project = { let project_name = query.project.clone(); db_conn .interact::<_, Result>(move |conn| { conn.transaction(move |conn| { Ok( match Project::all() .filter(Project::with_team(&api_key.team_id)) .filter(Project::with_name(&project_name)) .first(conn) .optional() .context("failed to load project")? { Some(project) => project, None => Project::insert_new(conn, &api_key.team_id, &project_name) .context("failed to insert project")?, }, ) }) }) .await .unwrap()? }; let team_governor = { let team_id = project.team_id; db_conn .interact::<_, Result>(move |conn| { // TODO: extract this logic to a method in crate::governors, // and create governor proactively on team creation match Governor::all() .filter(Governor::with_team(&team_id)) .filter(Governor::with_project(&None)) .first(conn) { diesel::QueryResult::Ok(governor) => Ok(governor), diesel::QueryResult::Err(diesel::result::Error::NotFound) => { // Lazily initialize governor Governor::insert_new( conn, &team_id, None, &TimeDelta::seconds(TEAM_GOVERNOR_DEFAULT_WINDOW_SIZE_SEC), TEAM_GOVERNOR_DEFAULT_MAX_COUNT, ) .map_err(Into::into) } diesel::QueryResult::Err(err) => Err(err.into()), } }) .await .unwrap()? }; if db_conn .interact::<_, Result, anyhow::Error>>(move |conn| { team_governor.create_entry(conn) }) .await .unwrap()? .is_none() { return Err(AppError::TooManyRequests( "team rate limit exceeded".to_string(), )); } let selected_channels = { let project = project.clone(); db_conn .interact::<_, Result, AppError>>(move |conn| { Ok(project .selected_channels() .load(conn) .context("failed to load selected channels")?) }) .await .unwrap()? }; { let selected_channels = selected_channels.clone(); db_conn .interact::<_, Result<_, AppError>>(move |conn| { for channel in selected_channels { insert_into(messages::table) .values(( messages::id.eq(Uuid::now_v7()), messages::channel_id.eq(&channel.id), messages::project_id.eq(&project.id), messages::message.eq(&query.message), )) .execute(conn)?; } Ok(()) }) .await .unwrap()?; } tracing::debug!("queued {} messages", selected_channels.len()); Ok(Json(json!({ "ok": true }))) }