diff --git a/src/api_keys.rs b/src/api_keys.rs index a08e451..5e6bfeb 100644 --- a/src/api_keys.rs +++ b/src/api_keys.rs @@ -11,6 +11,8 @@ use uuid::Uuid; use crate::{app_error::AppError, schema::api_keys, teams::Team}; +/// A team-scoped application key for authenticating API calls to /say, etc. +/// Does not authorize any administrative functions besides creating projects. #[derive(Associations, Clone, Debug, Identifiable, Insertable, Queryable, Selectable)] #[diesel(table_name = api_keys)] #[diesel(belongs_to(Team))] @@ -46,27 +48,23 @@ impl ApiKey { } #[auto_type(no_type_alias)] - pub fn with_id(id: Uuid) -> _ { + pub fn with_id<'a>(id: &'a Uuid) -> _ { api_keys::id.eq(id) } #[auto_type(no_type_alias)] - pub fn with_team(team_id: Uuid) -> _ { + pub fn with_team<'a>(team_id: &'a Uuid) -> _ { api_keys::team_id.eq(team_id) } } -/** - * Encode big-endian bytes of a UUID as URL-safe base64. - */ +/// Encode big-endian bytes of a UUID as URL-safe base64. pub fn compact_uuid(id: &Uuid) -> String { URL_SAFE_NO_PAD.encode(id.as_bytes()) } -/** - * Attempt to parse a string as either a standard formatted UUID or a big-endian - * base64 encoding of one. - */ +/// Attempt to parse a string as either a standard formatted UUID or a +/// big-endian base64 encoding of one. pub fn try_parse_as_uuid(value: &str) -> Result { if value.len() < 32 { let bytes: Vec = URL_SAFE_NO_PAD diff --git a/src/app_error.rs b/src/app_error.rs index bed2090..46770cd 100644 --- a/src/app_error.rs +++ b/src/app_error.rs @@ -9,8 +9,7 @@ pub struct AuthRedirectInfo { base_path: String, } -// Use anyhow, define error and enable '?' -// For a simplified example of using anyhow in axum check /examples/anyhow-error-response +/// Custom error type that maps to appropriate HTTP responses. #[derive(Debug)] pub enum AppError { InternalServerError(anyhow::Error), @@ -27,13 +26,13 @@ impl AppError { } pub fn from_validation_errors(errs: ValidationErrors) -> Self { + // TODO: customize validation errors formatting Self::BadRequestError( serde_json::to_string(&errs).unwrap_or("validation error".to_string()), ) } } -// Tell axum how to convert `AppError` into a response. impl IntoResponse for AppError { fn into_response(self) -> Response { match self { @@ -67,8 +66,7 @@ impl IntoResponse for AppError { } } -// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into -// `Result<_, AppError>`. That way you don't need to do that manually. +// Easily convert semi-arbitrary errors to InternalServerError impl From for AppError where E: Into, diff --git a/src/app_state.rs b/src/app_state.rs index 258e9ce..bc41e83 100644 --- a/src/app_state.rs +++ b/src/app_state.rs @@ -1,3 +1,6 @@ +use std::sync::Arc; + +use anyhow::Result; use axum::{ extract::{FromRef, FromRequestParts}, http::request::Parts, @@ -5,10 +8,15 @@ use axum::{ use deadpool_diesel::postgres::{Connection, Pool}; use oauth2::basic::BasicClient; -use crate::{app_error::AppError, email::Mailer, sessions::PgStore, settings::Settings}; +use crate::{ + app_error::AppError, + email::{Mailer, SmtpOptions}, + sessions::PgStore, + settings::Settings, +}; -#[derive(Clone)] -pub struct AppState { +/// Global app configuration +pub struct App { pub db_pool: Pool, pub mailer: Mailer, pub reqwest_client: reqwest::Client, @@ -17,6 +25,46 @@ pub struct AppState { pub settings: Settings, } +impl App { + /// Initialize global application functions based on config values + pub async fn from_settings(settings: Settings) -> Result { + let database_url = settings.database_url.clone(); + let manager = + deadpool_diesel::postgres::Manager::new(database_url, deadpool_diesel::Runtime::Tokio1); + let db_pool = deadpool_diesel::postgres::Pool::builder(manager).build()?; + + let session_store = PgStore::new(db_pool.clone()); + let reqwest_client = reqwest::ClientBuilder::new().https_only(true).build()?; + let oauth_client = crate::auth::new_oauth_client(&settings)?; + + let mailer = if let Some(smtp_settings) = settings.email.smtp.clone() { + Mailer::new_smtp(SmtpOptions { + server: smtp_settings.server, + username: smtp_settings.username, + password: smtp_settings.password, + })? + } else if let Some(postmark_settings) = settings.email.postmark.clone() { + Mailer::new_postmark(postmark_settings.server_token)? + .with_reqwest_client(reqwest_client.clone()) + } else { + return Err(anyhow::anyhow!("no email backend settings configured")); + }; + + Ok(Self { + db_pool, + mailer, + oauth_client, + reqwest_client, + session_store, + settings, + }) + } +} + +/// Global app configuration, arced for relatively inexpensive clones +pub type AppState = Arc; + +/// State extractor for shared reqwest client #[derive(Clone)] pub struct ReqwestClient(pub reqwest::Client); @@ -26,6 +74,7 @@ impl FromRef for ReqwestClient { } } +/// Extractor to automatically obtain a Deadpool database connection pub struct DbConn(pub Connection); impl FromRequestParts for DbConn { diff --git a/src/auth.rs b/src/auth.rs index f6c35f1..83ccab7 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -22,11 +22,12 @@ use crate::{ settings::Settings, }; -const SESSION_KEY_AUTH_CSRF_TOKEN: &'static str = "oauth_csrf_token"; -const SESSION_KEY_AUTH_REFRESH_TOKEN: &'static str = "oauth_refresh_token"; -const SESSION_KEY_AUTH_INFO: &'static str = "auth"; +const SESSION_KEY_AUTH_CSRF_TOKEN: &str = "oauth_csrf_token"; +const SESSION_KEY_AUTH_REFRESH_TOKEN: &str = "oauth_refresh_token"; +const SESSION_KEY_AUTH_INFO: &str = "auth"; -pub fn new_oauth_client(settings: &Settings) -> Result { +/// Creates a new OAuth2 client to be stored in global application state. +pub fn new_oauth_client(settings: &Settings) -> Result { Ok(BasicClient::new( ClientId::new(settings.auth.client_id.clone()), Some(ClientSecret::new(settings.auth.client_secret.clone())), @@ -43,14 +44,16 @@ pub fn new_oauth_client(settings: &Settings) -> Result { )) } +/// Creates a router which can be nested within the higher level app router. pub fn new_router() -> Router { Router::new() .route("/login", get(start_login)) - .route("/callback", get(login_authorized)) + .route("/callback", get(callback)) .route("/logout", get(logout)) } -pub async fn start_login( +/// HTTP get handler for /login +async fn start_login( State(state): State, State(Settings { auth: auth_settings, @@ -84,10 +87,11 @@ pub async fn start_login( .http_only(true) .path("/"), ); - Ok((jar, Redirect::to(&auth_url.to_string())).into_response()) + Ok((jar, Redirect::to(auth_url.as_ref())).into_response()) } -pub async fn logout( +/// HTTP get handler for /logout +async fn logout( State(Settings { base_path, auth: auth_settings, @@ -128,18 +132,14 @@ pub async fn logout( } #[derive(Debug, Deserialize)] -pub struct AuthRequestQuery { +struct AuthRequestQuery { code: String, - state: String, // CSRF token + /// CSRF token + state: String, } -#[derive(Debug, Deserialize, Serialize)] -pub struct AuthInfo { - pub sub: String, - pub email: String, -} - -pub async fn login_authorized( +/// HTTP get handler for /callback +async fn callback( Query(query): Query, State(state): State, State(Settings { @@ -153,9 +153,7 @@ pub async fn login_authorized( let mut session = if let Some(session) = session { session } else { - return Err(AppError::auth_redirect_from_base_path( - state.settings.base_path, - )); + return Err(AppError::auth_redirect_from_base_path(base_path)); }; let session_csrf_token: String = session.get(SESSION_KEY_AUTH_CSRF_TOKEN).ok_or_else(|| { tracing::debug!("oauth csrf token not found on session"); @@ -194,6 +192,13 @@ pub async fn login_authorized( Ok(Redirect::to(&format!("{}/", base_path))) } +/// Data stored in the visitor's session upon successful authentication. +#[derive(Debug, Deserialize, Serialize)] +pub struct AuthInfo { + pub sub: String, + pub email: String, +} + impl FromRequestParts for AuthInfo { type Rejection = AppError; @@ -214,7 +219,7 @@ impl FromRequestParts for AuthInfo { )?; Ok(user) } - // The Span.enter() guard pattern doesn't play nicely async + // The Span.enter() guard pattern doesn't play nicely with async .instrument(trace_span!("AuthInfo from_request_parts()")) .await } diff --git a/src/channel_selections.rs b/src/channel_selections.rs index 78f00ae..c88a1f6 100644 --- a/src/channel_selections.rs +++ b/src/channel_selections.rs @@ -25,12 +25,12 @@ impl ChannelSelection { } #[auto_type(no_type_alias)] - pub fn with_channel(channel_id: Uuid) -> _ { + pub fn with_channel<'a>(channel_id: &'a Uuid) -> _ { channel_selections::channel_id.eq(channel_id) } #[auto_type(no_type_alias)] - pub fn with_project(project_id: Uuid) -> _ { + pub fn with_project<'a>(project_id: &'a Uuid) -> _ { channel_selections::project_id.eq(project_id) } } diff --git a/src/channels.rs b/src/channels.rs index 07d672e..33e682e 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -14,15 +14,13 @@ use uuid::Uuid; use crate::{schema::channels, teams::Team}; -pub const CHANNEL_BACKEND_EMAIL: &'static str = "email"; -pub const CHANNEL_BACKEND_SLACK: &'static str = "slack"; +pub const CHANNEL_BACKEND_EMAIL: &str = "email"; +pub const CHANNEL_BACKEND_SLACK: &str = "slack"; -/** - * Represents a target/destination for messages, with the sender configuration - * defined in the backend_config field. A single channel may be attached to - * (in other words, "enabled" or "selected" for) any number of projects within - * the same team. - */ +/// Represents a target/destination for messages, with the sender configuration +/// defined in the backend_config field. A single channel may be attached to +/// (in other words, "enabled" or "selected" for) any number of projects within +/// the same team. #[derive(Associations, Clone, Debug, Identifiable, Queryable, Selectable)] #[diesel(belongs_to(Team))] #[diesel(check_for_backend(Pg))] @@ -57,20 +55,18 @@ impl Channel { } } -/** - * Encapsulates any information that needs to be persisted for setting up or - * using a channel's backend (that is, email sender, Slack app, etc.). This - * configuration is encoded to a jsonb column in the database, which determines - * the channel type along with configuration details. - * - * Note: In a previous implementation, channel configuration was handled by - * creating a dedicated table for each channel type and joining them to the - * `channels` table in order to access configuration fields. The jsonb approach - * simplifies database management and lends itself to a cleaner Rust - * implementation in which this enum can be treated as a column type with - * enforcement of data structure invariants handled entirely in the to_sql() - * and from_sql() serialization/deserialization logic. - */ +// Note: In a previous implementation, channel configuration was handled by +// creating a dedicated table for each channel type and joining them to the +// `channels` table in order to access configuration fields. The jsonb approach +// simplifies database management and lends itself to a cleaner Rust +// implementation in which this enum can be treated as a column type with +// enforcement of data structure invariants handled entirely in the to_sql() +// and from_sql() serialization/deserialization logic. + +/// Encapsulates any information that needs to be persisted for setting up or +/// using a channel's backend (that is, email sender, Slack app, etc.). This +/// configuration is encoded to a jsonb column in the database, which determines +/// the channel type along with configuration details. #[derive(AsExpression, Clone, Debug, FromSqlRow, Deserialize, Serialize)] #[diesel(sql_type = Jsonb)] pub enum BackendConfig { @@ -79,7 +75,7 @@ pub enum BackendConfig { } impl ToSql for BackendConfig { - fn to_sql<'a>(&self, out: &mut Output<'a, '_, Pg>) -> diesel::serialize::Result { + fn to_sql(&self, out: &mut Output<'_, '_, Pg>) -> diesel::serialize::Result { match self.clone() { BackendConfig::Email(config) => ToSql::::to_sql( &json!({ @@ -142,9 +138,9 @@ impl TryFrom for EmailBackendConfig { } } -impl Into for EmailBackendConfig { - fn into(self) -> BackendConfig { - BackendConfig::Email(self) +impl From for BackendConfig { + fn from(value: EmailBackendConfig) -> Self { + Self::Email(value) } } @@ -170,8 +166,8 @@ impl TryFrom for SlackBackendConfig { } } -impl Into for SlackBackendConfig { - fn into(self) -> BackendConfig { - BackendConfig::Slack(self) +impl From for BackendConfig { + fn from(value: SlackBackendConfig) -> Self { + Self::Slack(value) } } diff --git a/src/channels_router.rs b/src/channels_router.rs new file mode 100644 index 0000000..b3e7871 --- /dev/null +++ b/src/channels_router.rs @@ -0,0 +1,455 @@ +use anyhow::Context as _; +use askama::Template; +use axum::{ + extract::{Path, State}, + response::{Html, IntoResponse, Redirect}, + routing::{get, post}, + Router, +}; +use axum_extra::extract::Form; +use diesel::prelude::*; +use rand::Rng as _; +use regex::Regex; +use serde::Deserialize; +use uuid::Uuid; + +use crate::{ + app_error::AppError, + app_state::{AppState, DbConn}, + channels::{BackendConfig, Channel, EmailBackendConfig, CHANNEL_BACKEND_EMAIL}, + csrf::generate_csrf_token, + email::{MailSender as _, Mailer}, + guards, + nav_state::{Breadcrumb, NavState}, + schema::channels, + settings::Settings, + users::CurrentUser, +}; + +const VERIFICATION_CODE_LEN: usize = 6; + +/// Helper function to query a channel from the database by ID and team, and +/// return an appropriate error if no such channel exists. +fn get_channel_by_params<'a>( + conn: &mut PgConnection, + team_id: &'a Uuid, + channel_id: &'a Uuid, +) -> Result { + match Channel::all() + .filter(Channel::with_id(channel_id)) + .filter(Channel::with_team(team_id)) + .first(conn) + { + diesel::QueryResult::Err(diesel::result::Error::NotFound) => Err(AppError::NotFoundError( + "Channel with that team and ID not found.".to_string(), + )), + diesel::QueryResult::Err(err) => Err(err.into()), + diesel::QueryResult::Ok(channel) => Ok(channel), + } +} + +pub fn new_router() -> Router { + Router::new() + .route("/teams/{team_id}/channels", get(channels_page)) + .route("/teams/{team_id}/channels/{channel_id}", get(channel_page)) + .route( + "/teams/{team_id}/channels/{channel_id}/update-channel", + post(update_channel), + ) + .route( + "/teams/{team_id}/channels/{channel_id}/update-email-recipient", + post(update_channel_email_recipient), + ) + .route( + "/teams/{team_id}/channels/{channel_id}/verify-email", + post(verify_email), + ) + .route("/teams/{team_id}/new-channel", post(post_new_channel)) +} + +async fn channels_page( + State(Settings { base_path, .. }): State, + DbConn(db_conn): DbConn, + Path(team_id): Path, + CurrentUser(current_user): CurrentUser, +) -> Result { + let team = guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; + + let channels = { + db_conn + .interact(move |conn| { + Channel::all() + .filter(Channel::with_team(&team_id)) + .load(conn) + }) + .await + .unwrap() + .context("Failed to load channels list.")? + }; + + let csrf_token = generate_csrf_token(&db_conn, Some(current_user.id)).await?; + let nav_state = NavState::new() + .set_base_path(&base_path) + .push_team(&team) + .push_slug(Breadcrumb { + href: "channels".to_string(), + label: "Channels".to_string(), + }) + .set_navbar_active_item("channels"); + #[derive(Template)] + #[template(path = "channels.html")] + struct ResponseTemplate { + base_path: String, + channels: Vec, + csrf_token: String, + nav_state: NavState, + } + Ok(Html( + ResponseTemplate { + base_path, + channels, + csrf_token, + nav_state, + } + .render()?, + ) + .into_response()) +} + +#[derive(Deserialize)] +struct NewChannelPostFormBody { + csrf_token: String, + channel_type: String, +} + +async fn post_new_channel( + State(Settings { base_path, .. }): State, + DbConn(db_conn): DbConn, + Path(team_id): Path, + CurrentUser(current_user): CurrentUser, + Form(form_body): Form, +) -> Result { + let team = guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; + guards::require_valid_csrf_token(&form_body.csrf_token, ¤t_user, &db_conn).await?; + + let channel_id = Uuid::now_v7(); + let channel = match form_body.channel_type.as_str() { + CHANNEL_BACKEND_EMAIL => db_conn + .interact::<_, Result>(move |conn| { + Ok(diesel::insert_into(channels::table) + .values(( + channels::id.eq(channel_id), + channels::team_id.eq(team_id), + channels::name.eq("Untitled Email Channel"), + channels::backend_config + .eq(Into::::into(EmailBackendConfig::default())), + )) + .returning(Channel::as_returning()) + .get_result(conn) + .context("Failed to insert new EmailChannel.")?) + }) + .await + .unwrap()?, + _ => { + return Err(AppError::BadRequestError( + "Channel type not recognized.".to_string(), + )); + } + }; + + Ok(Redirect::to(&format!( + "{}/teams/{}/channels/{}", + base_path, + team.id.simple(), + channel.id.simple() + ))) +} + +async fn channel_page( + State(Settings { base_path, .. }): State, + DbConn(db_conn): DbConn, + Path((team_id, channel_id)): Path<(Uuid, Uuid)>, + CurrentUser(current_user): CurrentUser, +) -> Result { + let team = guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; + + let channel = { + match db_conn + .interact(move |conn| { + Channel::all() + .filter(Channel::with_id(&channel_id)) + .filter(Channel::with_team(&team_id)) + .first(conn) + .optional() + }) + .await + .unwrap()? + { + None => { + return Err(AppError::NotFoundError( + "Channel with that team and ID not found".to_string(), + )); + } + Some(channel) => channel, + } + }; + + let csrf_token = generate_csrf_token(&db_conn, Some(current_user.id)).await?; + let nav_state = NavState::new() + .set_base_path(&base_path) + .push_team(&team) + .push_slug(Breadcrumb { + href: "channels".to_string(), + label: "Channels".to_string(), + }) + .push_slug(Breadcrumb { + href: channel.id.simple().to_string(), + label: channel.name.clone(), + }) + .set_navbar_active_item("channels"); + + match channel.backend_config { + BackendConfig::Email(_) => { + #[derive(Template)] + #[template(path = "channel-email.html")] + struct ResponseTemplate { + base_path: String, + channel: Channel, + csrf_token: String, + nav_state: NavState, + } + Ok(Html( + ResponseTemplate { + base_path, + channel, + csrf_token, + nav_state, + } + .render()?, + )) + } + BackendConfig::Slack(_) => { + Err(anyhow::anyhow!("Slack channel config page is not yet implemented.").into()) + } + } +} + +#[derive(Deserialize)] +struct UpdateChannelFormBody { + csrf_token: String, + name: String, + enable_by_default: Option, +} + +async fn update_channel( + State(Settings { base_path, .. }): State, + DbConn(db_conn): DbConn, + Path((team_id, channel_id)): Path<(Uuid, Uuid)>, + CurrentUser(current_user): CurrentUser, + Form(form_body): Form, +) -> Result { + guards::require_valid_csrf_token(&form_body.csrf_token, ¤t_user, &db_conn).await?; + guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; + + let updated_rows = { + db_conn + .interact(move |conn| { + diesel::update( + channels::table + .filter(Channel::with_id(&channel_id)) + .filter(Channel::with_team(&team_id)), + ) + .set(( + channels::name.eq(form_body.name), + channels::enable_by_default + .eq(form_body.enable_by_default.unwrap_or("false".to_string()) == "true"), + )) + .execute(conn) + }) + .await + .unwrap() + .context("Failed to load Channel while updating.")? + }; + if updated_rows != 1 { + return Err(AppError::NotFoundError( + "Channel with that team and ID not found".to_string(), + )); + } + Ok(Redirect::to(&format!( + "{}/teams/{}/channels/{}", + base_path, + team_id.simple(), + channel_id.simple() + )) + .into_response()) +} + +#[derive(Deserialize)] +struct UpdateChannelEmailRecipientFormBody { + // Yes it's a mouthful, but it's only used twice + csrf_token: String, + recipient: String, +} + +async fn update_channel_email_recipient( + State(Settings { + base_path, + email: email_settings, + .. + }): State, + DbConn(db_conn): DbConn, + State(mailer): State, + Path((team_id, channel_id)): Path<(Uuid, Uuid)>, + CurrentUser(current_user): CurrentUser, + Form(form_body): Form, +) -> Result { + guards::require_valid_csrf_token(&form_body.csrf_token, ¤t_user, &db_conn).await?; + guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; + + if !is_permissible_email(&form_body.recipient) { + return Err(AppError::BadRequestError( + "Unable to validate email address format.".to_string(), + )); + } + let verification_code: String = rand::thread_rng() + .sample_iter(&rand::distributions::Uniform::from(0..9)) + .take(VERIFICATION_CODE_LEN) + .map(|n| n.to_string()) + .collect(); + + { + let verification_code = verification_code.clone(); + let recipient = form_body.recipient.clone(); + db_conn + .interact(move |conn| { + // TODO: transaction retries + conn.transaction::<_, AppError, _>(move |conn| { + let channel = get_channel_by_params(conn, &team_id, &channel_id)?; + let new_config = BackendConfig::Email(EmailBackendConfig { + recipient, + verification_code, + verification_code_guesses: 0, + ..channel.backend_config.try_into()? + }); + let num_rows = diesel::update(channels::table.filter(Channel::with_id(&channel.id))) + .set(channels::backend_config.eq(new_config)) + .execute(conn)?; + if num_rows != 1 { + return Err(anyhow::anyhow!( + "Updating EmailChannel recipient, the channel was found but {} rows were updated.", + num_rows + ) + .into()); + } + Ok(()) + }) + }) + .await + .unwrap()?; + } + + tracing::debug!( + "Email verification code for {} is: {}", + form_body.recipient, + verification_code + ); + tracing::info!( + "Sending email verification code to: {}", + form_body.recipient + ); + let email = crate::email::Message { + from: email_settings.verification_from, + to: form_body.recipient.parse()?, + subject: "Verify Your Email".to_string(), + text_body: format!("Your email verification code is: {}", verification_code), + }; + mailer.send_batch(vec![email]).await.remove(0)?; + + Ok(Redirect::to(&format!( + "{}/teams/{}/channels/{}", + base_path, + team_id.simple(), + channel_id.simple() + ))) +} + +/// Returns true if the email address matches a format recognized as "valid". +/// Not all "legal" email addresses will be accepted, but addresses that are +/// "illegal" and/or could result in unexpected behavior should be rejected. +fn is_permissible_email(address: &str) -> bool { + let re = Regex::new(r"^[a-zA-Z0-9._+-]+@([a-zA-Z0-9_-]+.)+[a-zA-Z]+$") + .expect("email validation regex should parse"); + re.is_match(address) +} + +#[derive(Deserialize)] +struct VerifyEmailFormBody { + csrf_token: String, + code: String, +} + +async fn verify_email( + State(Settings { base_path, .. }): State, + DbConn(db_conn): DbConn, + Path((team_id, channel_id)): Path<(Uuid, Uuid)>, + CurrentUser(current_user): CurrentUser, + Form(form_body): Form, +) -> Result { + guards::require_valid_csrf_token(&form_body.csrf_token, ¤t_user, &db_conn).await?; + guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; + + if form_body.code.len() != VERIFICATION_CODE_LEN { + return Err(AppError::BadRequestError(format!( + "Verification code must be {} characters long.", + VERIFICATION_CODE_LEN + ))); + } + + { + let verification_code = form_body.code; + db_conn + .interact(move |conn| { + conn.transaction::<(), AppError, _>(move |conn| { + let channel = get_channel_by_params(conn, &team_id, &channel_id)?; + let config: EmailBackendConfig = channel.backend_config.try_into()?; + if config.verified { + return Err(AppError::BadRequestError( + "Channel's email address is already verified.".to_string(), + )); + } + const MAX_VERIFICATION_GUESSES: u32 = 100; + if config.verification_code_guesses > MAX_VERIFICATION_GUESSES { + return Err(AppError::BadRequestError( + "Verification expired.".to_string(), + )); + } + let new_config = if config.verification_code == verification_code { + EmailBackendConfig { + verified: true, + verification_code: "".to_string(), + verification_code_guesses: 0, + ..config + } + } else { + EmailBackendConfig { + verification_code_guesses: config.verification_code_guesses + 1, + ..config + } + }; + diesel::update(channels::table.filter(Channel::with_id(&channel_id))) + .set(channels::backend_config.eq(Into::::into(new_config))) + .execute(conn)?; + Ok(()) + }) + }) + .await + .unwrap()?; + }; + + Ok(Redirect::to(&format!( + "{}/teams/{}/channels/{}", + base_path, + team_id.simple(), + channel_id.simple() + ))) +} diff --git a/src/csrf.rs b/src/csrf.rs index 1a1277b..d5270be 100644 --- a/src/csrf.rs +++ b/src/csrf.rs @@ -1,7 +1,7 @@ use chrono::{DateTime, TimeDelta, Utc}; use deadpool_diesel::postgres::Connection; use diesel::{ - dsl::{AsSelect, Eq, Gt, IsNotDistinctFrom, Select}, + dsl::{auto_type, AsSelect, Gt, Select}, pg::Pg, prelude::*, }; @@ -9,7 +9,7 @@ use uuid::Uuid; use crate::{app_error::AppError, schema::csrf_tokens::dsl::*}; -const TOKEN_PREFIX: &'static str = "csrf-"; +const TOKEN_PREFIX: &str = "csrf-"; #[derive(Clone, Debug, Identifiable, Queryable, Selectable)] #[diesel(table_name = crate::schema::csrf_tokens)] @@ -31,15 +31,18 @@ impl CsrfToken { created_at.gt(min_created_at) } - pub fn with_user_id(token_user_id: Option) -> IsNotDistinctFrom> { + #[auto_type(no_type_alias)] + pub fn with_user_id<'a>(token_user_id: &'a Option) -> _ { user_id.is_not_distinct_from(token_user_id) } - pub fn with_token_id(token_id: Uuid) -> Eq { + #[auto_type(no_type_alias)] + pub fn with_token_id<'a>(token_id: &'a Uuid) -> _ { id.eq(token_id) } } +/// Convenience function for creating new CSRF token rows in the database. pub async fn generate_csrf_token( db_conn: &Connection, with_user_id: Option, @@ -57,9 +60,10 @@ pub async fn generate_csrf_token( }) .await .unwrap()?; - Ok(format!("{}{}", TOKEN_PREFIX, token_id.simple().to_string())) + Ok(format!("{}{}", TOKEN_PREFIX, token_id.simple())) } +/// Convenience function for validating CSRF tokens against the database. pub async fn validate_csrf_token( db_conn: &Connection, token: &str, @@ -72,8 +76,8 @@ pub async fn validate_csrf_token( Ok(db_conn .interact(move |conn| { CsrfToken::all() - .filter(CsrfToken::with_token_id(token_id)) - .filter(CsrfToken::with_user_id(with_user_id)) + .filter(CsrfToken::with_token_id(&token_id)) + .filter(CsrfToken::with_user_id(&with_user_id)) .filter(CsrfToken::is_not_expired()) .first(conn) .optional() diff --git a/src/email.rs b/src/email.rs index 870ec61..98e2570 100644 --- a/src/email.rs +++ b/src/email.rs @@ -1,11 +1,12 @@ use anyhow::{Context, Result}; use axum::extract::FromRef; +use futures::Future; use lettre::{AsyncSmtpTransport, AsyncTransport, Tokio1Executor}; use serde::{Serialize, Serializer}; use crate::app_state::AppState; -const POSTMARK_EMAIL_BATCH_URL: &'static str = "https://api.postmarkapp.com/email/batch"; +const POSTMARK_EMAIL_BATCH_URL: &str = "https://api.postmarkapp.com/email/batch"; #[derive(Clone, Serialize)] pub struct Message { @@ -21,11 +22,9 @@ pub struct Message { } pub trait MailSender: Clone + Sync { - /** - * Attempt to send all messages defined by the input Vec. Send as many as - * possible, returning exactly one Result<()> for each message. - */ - async fn send_batch(&self, emails: Vec) -> Vec>; + /// Attempt to send all messages defined by the input Vec. Send as many as + /// possible, returning exactly one Result<()> for each message. + fn send_batch(&self, emails: Vec) -> impl Future>>; } #[derive(Clone, Debug)] @@ -61,7 +60,7 @@ impl MailSender for Mailer { } #[derive(Clone, Debug)] -struct SmtpSender { +pub struct SmtpSender { transport: AsyncSmtpTransport, } @@ -104,7 +103,7 @@ fn serialize_mailboxes(t: &lettre::message::Mailboxes, s: S) -> Result) -> Vec> { - /** - * Constructs a Vec with Ok(()) repeated n times. - */ + /// Constructs a Vec with Ok(()) repeated n times. macro_rules! all_ok { () => {{ let mut collection: Vec> = Vec::with_capacity(emails.len()); @@ -168,10 +163,8 @@ impl MailSender for PostmarkSender { }}; } - /** - * Constructs a Vec with a single specific error, followed by n-1 - * generic errors referring back to it. - */ + /// Constructs a Vec with a single specific error, followed by n-1 + /// generic errors referring back to it. macro_rules! cascade_err { ($err:expr) => {{ let mut collection: Vec> = Vec::with_capacity(emails.len()); @@ -183,15 +176,12 @@ impl MailSender for PostmarkSender { }}; } - /** - * Recursively splits the email batch in half and tries to send each - * half independently, allowing both to run to completion and then - * returning the first error of the two results, if present. - * - * This is implemented as a macro in order to avoid unstable async - * closures. - */ + /// Recursively splits the email batch in half and tries to send each + /// half independently, allowing both to run to completion and then + /// returning the first error of the two results, if present. macro_rules! split_and_retry { + // This is implemented as a macro in order to avoid unstable async + // closures. () => { if emails.len() < 2 { tracing::warn!("Postmark send batch cannot be split any further"); @@ -213,7 +203,7 @@ impl MailSender for PostmarkSender { const POSTMARK_MAX_REQUEST_BYTES: usize = 50 * 1000 * 1000; // TODO: Check email subject and body size against Postmark limits - if emails.len() == 0 { + if emails.is_empty() { tracing::debug!("no Postmark messages to send"); vec![Ok(())] } else if emails.len() > POSTMARK_MAX_BATCH_ENTRIES { @@ -248,13 +238,11 @@ impl MailSender for PostmarkSender { }; if resp.status().is_client_error() && emails.len() > 1 { split_and_retry!() + } else if let Err(err) = resp.error_for_status() { + cascade_err!(err.into()) } else { - if let Err(err) = resp.error_for_status() { - cascade_err!(err.into()) - } else { - tracing::debug!("sent Postmark batch of {} messages", emails.len()); - all_ok!() - } + tracing::debug!("sent Postmark batch of {} messages", emails.len()); + all_ok!() } } } diff --git a/src/governors.rs b/src/governors.rs index d3fc6d1..e4a2808 100644 --- a/src/governors.rs +++ b/src/governors.rs @@ -12,6 +12,7 @@ use uuid::Uuid; use crate::schema::{governor_entries, governors}; +// Expose built-in Postgres GREATEST() function to Diesel define_sql_function! { fn greatest(a: diesel::sql_types::Integer, b: diesel::sql_types::Integer) -> Integer } @@ -54,27 +55,26 @@ impl Governor { } #[auto_type(no_type_alias)] - pub fn with_id(governor_id: Uuid) -> _ { + pub fn with_id<'a>(governor_id: &'a Uuid) -> _ { governors::id.eq(governor_id) } #[auto_type(no_type_alias)] - pub fn with_team(team_id: Uuid) -> _ { + pub fn with_team<'a>(team_id: &'a Uuid) -> _ { governors::team_id.eq(team_id) } #[auto_type(no_type_alias)] - pub fn with_project(project_id: Option) -> _ { + pub fn with_project<'a>(project_id: &'a Option) -> _ { governors::project_id.is_not_distinct_from(project_id) } // TODO: return a custom result enum instead of a Result