diff --git a/src/channels.rs b/src/channels.rs index 5bfa340..07d672e 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -42,14 +42,19 @@ impl Channel { } #[auto_type(no_type_alias)] - pub fn with_id(channel_id: Uuid) -> _ { + pub fn with_id<'a>(channel_id: &'a Uuid) -> _ { channels::id.eq(channel_id) } #[auto_type(no_type_alias)] - pub fn with_team(team_id: Uuid) -> _ { + pub fn with_team<'a>(team_id: &'a Uuid) -> _ { channels::team_id.eq(team_id) } + + #[auto_type(no_type_alias)] + pub fn where_enabled_by_default() -> _ { + channels::enable_by_default.eq(true) + } } /** diff --git a/src/governors.rs b/src/governors.rs index c1b7569..d3fc6d1 100644 --- a/src/governors.rs +++ b/src/governors.rs @@ -3,7 +3,7 @@ use anyhow::Result; use chrono::{DateTime, TimeDelta, Utc}; use diesel::{ - dsl::{auto_type, AsSelect}, + dsl::{auto_type, insert_into, AsSelect}, pg::Pg, prelude::*, sql_types::Timestamptz, @@ -28,6 +28,25 @@ pub struct Governor { } impl Governor { + pub fn insert_new<'a>( + db_conn: &mut diesel::PgConnection, + team_id: &'a Uuid, + project_id: Option<&'a Uuid>, + window_size: &'a TimeDelta, + max_count: i32, + ) -> Result { + let id: Uuid = Uuid::now_v7(); + Ok(insert_into(governors::table) + .values(( + governors::team_id.eq(team_id), + governors::id.eq(id), + governors::project_id.eq(project_id), + governors::window_size.eq(window_size), + governors::max_count.eq(max_count), + )) + .get_result(db_conn)?) + } + #[auto_type(no_type_alias)] pub fn all() -> _ { let select: AsSelect = Governor::as_select(); diff --git a/src/projects.rs b/src/projects.rs index cf06f49..18857d3 100644 --- a/src/projects.rs +++ b/src/projects.rs @@ -1,5 +1,6 @@ +use anyhow::Result; use diesel::{ - dsl::{auto_type, AsSelect, Eq}, + dsl::{auto_type, insert_into, AsSelect, Eq}, pg::Pg, prelude::*, }; @@ -21,6 +22,34 @@ pub struct Project { } impl Project { + pub fn insert_new<'a>( + db_conn: &mut diesel::PgConnection, + team_id: &'a Uuid, + name: &'a str, + ) -> Result { + let default_channels = Channel::all() + .filter(Channel::with_team(team_id)) + .filter(Channel::where_enabled_by_default()) + .load(db_conn)?; + let id: Uuid = Uuid::now_v7(); + let project: Self = insert_into(projects::table) + .values(( + projects::id.eq(id), + projects::team_id.eq(team_id), + projects::name.eq(name), + )) + .get_result(db_conn)?; + for channel in default_channels { + insert_into(channel_selections::table) + .values(( + channel_selections::project_id.eq(&project.id), + channel_selections::channel_id.eq(&channel.id), + )) + .execute(db_conn)?; + } + Ok(project) + } + #[auto_type(no_type_alias)] pub fn all() -> _ { let select: AsSelect = Project::as_select(); diff --git a/src/router.rs b/src/router.rs index be42f21..c6a0fe0 100644 --- a/src/router.rs +++ b/src/router.rs @@ -27,7 +27,6 @@ use crate::{ email::{MailSender as _, Mailer}, guards, nav_state::{Breadcrumb, NavState}, - projects::Project, schema::{self, channel_selections, channels}, settings::Settings, team_memberships::TeamMembership, @@ -213,15 +212,15 @@ async fn post_new_team( user_id: current_user.id, }; db_conn - .interact(move |conn| { - conn.transaction(move |conn| { + .interact::<_, Result<(), AppError>>(move |conn| { + conn.transaction::<(), AppError, _>(move |conn| { insert_into(schema::teams::table) - .values(team) + .values(&team) .execute(conn)?; insert_into(schema::team_memberships::table) - .values(team_membership) + .values(&team_membership) .execute(conn)?; - diesel::QueryResult::Ok(()) + Ok(()) }) }) .await @@ -286,12 +285,18 @@ async fn channels_page( ) -> Result { let team = guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; - let team_filter = Channel::with_team(team_id); - let channels = db_conn - .interact(move |conn| Channel::all().filter(team_filter).load(conn)) - .await - .unwrap() - .context("Failed to load channels list.")?; + let channels = { + let team_id = team_id.clone(); + db_conn + .interact(move |conn| { + Channel::all() + .filter(Channel::with_team(&team_id)) + .load(conn) + }) + .await + .unwrap() + .context("Failed to load channels list.")? + }; let csrf_token = generate_csrf_token(&db_conn, Some(current_user.id.clone())).await?; let nav_state = NavState::new() @@ -379,25 +384,27 @@ async fn channel_page( ) -> Result { let team = guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; - let id_filter = Channel::with_id(channel_id); - let team_filter = Channel::with_team(team_id.clone()); - let channel = match db_conn - .interact(move |conn| { - Channel::all() - .filter(id_filter) - .filter(team_filter) - .first(conn) - .optional() - }) - .await - .unwrap()? - { - None => { - return Err(AppError::NotFoundError( - "Channel with that team and ID not found".to_string(), - )); + let channel = { + let channel_id = channel_id.clone(); + let team_id = team_id.clone(); + match db_conn + .interact(move |conn| { + Channel::all() + .filter(Channel::with_id(&channel_id)) + .filter(Channel::with_team(&team_id)) + .first(conn) + .optional() + }) + .await + .unwrap()? + { + None => { + return Err(AppError::NotFoundError( + "Channel with that team and ID not found".to_string(), + )); + } + Some(channel) => channel, } - Some(channel) => channel, }; let csrf_token = generate_csrf_token(&db_conn, Some(current_user.id.clone())).await?; @@ -457,21 +464,27 @@ async fn update_channel( guards::require_valid_csrf_token(&form_body.csrf_token, ¤t_user, &db_conn).await?; guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; - let id_filter = Channel::with_id(channel_id.clone()); - let team_filter = Channel::with_team(team_id.clone()); - let updated_rows = db_conn - .interact(move |conn| { - update(channels::table.filter(id_filter).filter(team_filter)) + let updated_rows = { + let channel_id = channel_id.clone(); + let team_id = team_id.clone(); + db_conn + .interact(move |conn| { + update( + channels::table + .filter(Channel::with_id(&channel_id)) + .filter(Channel::with_team(&team_id)), + ) .set(( channels::name.eq(form_body.name), channels::enable_by_default .eq(form_body.enable_by_default.unwrap_or("false".to_string()) == "true"), )) .execute(conn) - }) - .await - .unwrap() - .context("Failed to load Channel while updating.")?; + }) + .await + .unwrap() + .context("Failed to load Channel while updating.")? + }; if updated_rows != 1 { return Err(AppError::NotFoundError( "Channel with that team and ID not found".to_string(), @@ -490,10 +503,10 @@ async fn update_channel( * Helper function to query a channel from the database by ID and team, and * return an appropriate error if no such channel exists. */ -fn get_channel_by_params( +fn get_channel_by_params<'a>( conn: &mut PgConnection, - team_id: Uuid, - channel_id: Uuid, + team_id: &'a Uuid, + channel_id: &'a Uuid, ) -> Result { match Channel::all() .filter(Channel::with_id(channel_id)) @@ -551,14 +564,14 @@ async fn update_channel_email_recipient( .interact(move |conn| { // TODO: transaction retries conn.transaction::<_, AppError, _>(move |conn| { - let channel = get_channel_by_params(conn, team_id, channel_id)?; + let channel = get_channel_by_params(conn, &team_id, &channel_id)?; let new_config = BackendConfig::Email(EmailBackendConfig { recipient, verification_code, verification_code_guesses: 0, ..channel.backend_config.try_into()? }); - let num_rows = update(channels::table.filter(Channel::with_id(channel.id))) + let num_rows = update(channels::table.filter(Channel::with_id(&channel.id))) .set(channels::backend_config.eq(new_config)) .execute(conn)?; if num_rows != 1 { @@ -641,7 +654,7 @@ async fn verify_email( db_conn .interact(move |conn| { conn.transaction::<(), AppError, _>(move |conn| { - let channel = get_channel_by_params(conn, team_id, channel_id)?; + let channel = get_channel_by_params(conn, &team_id, &channel_id)?; let config: EmailBackendConfig = channel.backend_config.try_into()?; if config.verified { return Err(AppError::BadRequestError( @@ -666,7 +679,7 @@ async fn verify_email( ..config } }; - update(channels::table.filter(Channel::with_id(channel_id))) + update(channels::table.filter(Channel::with_id(&channel_id))) .set(channels::backend_config.eq(Into::::into(new_config))) .execute(conn)?; Ok(()) @@ -722,12 +735,18 @@ async fn project_page( .map(|channel| channel.id) .collect(); - let team_filter = Channel::with_team(team.id.clone()); - let team_channels = db_conn - .interact(move |conn| Channel::all().filter(team_filter).load(conn)) - .await - .unwrap() - .context("failed to load team channels")?; + let team_channels = { + let team_id = team.id.clone(); + db_conn + .interact(move |conn| { + Channel::all() + .filter(Channel::with_team(&team_id)) + .load(conn) + }) + .await + .unwrap() + .context("failed to load team channels")? + }; let csrf_token = generate_csrf_token(&db_conn, Some(current_user.id)).await?; let nav_state = NavState::new() diff --git a/src/settings.rs b/src/settings.rs index fc72aa2..b69979b 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -1,6 +1,6 @@ use anyhow::{Context as _, Result}; use axum::extract::FromRef; -use config::{Config, ConfigError, Environment}; +use config::{Config, Environment}; use dotenvy::dotenv; use serde::Deserialize; diff --git a/src/v0_router.rs b/src/v0_router.rs index 940d525..8895462 100644 --- a/src/v0_router.rs +++ b/src/v0_router.rs @@ -21,7 +21,7 @@ use crate::{ channels::Channel, governors::Governor, projects::Project, - schema::{api_keys, governors, messages, projects}, + schema::{api_keys, messages}, }; const TEAM_GOVERNOR_DEFAULT_WINDOW_SIZE_SEC: i64 = 300; @@ -90,13 +90,7 @@ async fn say_get( .context("failed to load project")? { Some(project) => project, - None => insert_into(projects::table) - .values(( - projects::id.eq(Uuid::now_v7()), - projects::team_id.eq(api_key.team_id), - projects::name.eq(project_name), - )) - .get_result(conn) + None => Project::insert_new(conn, &api_key.team_id, &project_name) .context("failed to insert project")?, }, ) @@ -120,16 +114,14 @@ async fn say_get( diesel::QueryResult::Ok(governor) => Ok(governor), diesel::QueryResult::Err(diesel::result::Error::NotFound) => { // Lazily initialize governor - Ok(diesel::insert_into(governors::table) - .values(( - governors::team_id.eq(team_id), - governors::id.eq(Uuid::now_v7()), - governors::project_id.eq(None as Option), - governors::window_size - .eq(TimeDelta::seconds(TEAM_GOVERNOR_DEFAULT_WINDOW_SIZE_SEC)), - governors::max_count.eq(TEAM_GOVERNOR_DEFAULT_MAX_COUNT), - )) - .get_result(conn)?) + Governor::insert_new( + conn, + &team_id, + None, + &TimeDelta::seconds(TEAM_GOVERNOR_DEFAULT_WINDOW_SIZE_SEC), + TEAM_GOVERNOR_DEFAULT_MAX_COUNT, + ) + .map_err(Into::into) } diesel::QueryResult::Err(err) => Err(err.into()), }