From da38946dbd3589f6d2a17d43ada96da7d66000e8 Mon Sep 17 00:00:00 2001 From: Brent Schroeter Date: Wed, 26 Feb 2025 13:10:47 -0800 Subject: [PATCH] re-implement guards using AppError returns instead of macros --- src/app_error.rs | 7 +++- src/guards.rs | 86 ++++++++++++++++++++++++------------------------ src/router.rs | 8 ++--- 3 files changed, 53 insertions(+), 48 deletions(-) diff --git a/src/app_error.rs b/src/app_error.rs index 3c054fe..88274a4 100644 --- a/src/app_error.rs +++ b/src/app_error.rs @@ -7,6 +7,7 @@ use axum::response::{IntoResponse, Response}; #[derive(Debug)] pub enum AppError { InternalServerError(Error), + ForbiddenError(String), } // Tell axum how to convert `AppError` into a response. @@ -14,9 +15,13 @@ impl IntoResponse for AppError { fn into_response(self) -> Response { match self { Self::InternalServerError(err) => { - tracing::error!("Application error: {:#}", err); + tracing::error!("Application error: {}", err); (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response() } + Self::ForbiddenError(client_message) => { + tracing::info!("Forbidden: {}", client_message); + (StatusCode::FORBIDDEN, client_message).into_response() + } } } } diff --git a/src/guards.rs b/src/guards.rs index 3252f88..0c3410e 100644 --- a/src/guards.rs +++ b/src/guards.rs @@ -1,45 +1,45 @@ -macro_rules! require_team_membership { - ($current_user:expr, $team_id:expr, $db_conn:expr) => {{ - let current_user_id = $current_user.id.clone(); - match $db_conn - .interact(move |conn| { - crate::team_memberships::TeamMembership::all() - .filter(crate::team_memberships::TeamMembership::with_user_id( - current_user_id, - )) - .filter(crate::team_memberships::TeamMembership::with_team_id( - $team_id, - )) - .first(conn) - .optional() - }) - .await - .unwrap()? - { - Some((team, _)) => team, - None => { - return Ok(( - axum::http::StatusCode::FORBIDDEN, - "not a member of requested team".to_string(), - ) - .into_response()); - } - } - }}; -} -pub(crate) use require_team_membership; +use deadpool_diesel::postgres::Connection; +use diesel::prelude::*; +use uuid::Uuid; -macro_rules! require_valid_csrf_token { - ($csrf_token:expr, $current_user:expr, $db_conn:expr) => {{ - if !crate::csrf::validate_csrf_token(&$db_conn, &$csrf_token, Some($current_user.id)) - .await? - { - return Ok(( - axum::http::StatusCode::FORBIDDEN, - "invalid CSRF token".to_string(), - ) - .into_response()); - } - }}; +use crate::{ + app_error::AppError, csrf::validate_csrf_token, team_memberships::TeamMembership, teams::Team, + users::User, +}; + +pub async fn require_team_membership( + current_user: &User, + team_id: &Uuid, + db_conn: &Connection, +) -> Result { + let current_user_id = current_user.id.clone(); + let team_id = team_id.clone(); + match db_conn + .interact(move |conn| { + TeamMembership::all() + .filter(TeamMembership::with_user_id(current_user_id)) + .filter(TeamMembership::with_team_id(team_id)) + .first(conn) + .optional() + }) + .await + .unwrap()? + { + Some((team, _)) => Ok(team), + None => Err(AppError::ForbiddenError( + "not a member of requested team".to_string(), + )), + } +} + +pub async fn require_valid_csrf_token( + csrf_token: &str, + current_user: &User, + db_conn: &Connection, +) -> Result<(), AppError> { + if validate_csrf_token(db_conn, csrf_token, Some(current_user.id.clone())).await? { + Ok(()) + } else { + Err(AppError::ForbiddenError("invalid CSRF token".to_string())) + } } -pub(crate) use require_valid_csrf_token; diff --git a/src/router.rs b/src/router.rs index 22099cf..6f7fbcb 100644 --- a/src/router.rs +++ b/src/router.rs @@ -111,8 +111,8 @@ async fn post_new_api_key( CurrentUser(current_user): CurrentUser, Form(form): Form, ) -> Result { - guards::require_valid_csrf_token!(form.csrf_token, current_user, db_conn); - let team = guards::require_team_membership!(current_user, team_id, db_conn); + guards::require_valid_csrf_token(&form.csrf_token, ¤t_user, &db_conn).await?; + let team = guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; ApiKey::generate_for_team(&db_conn, team.id.clone()).await?; Ok(Redirect::to(&format!( @@ -158,7 +158,7 @@ async fn post_new_team( CurrentUser(current_user): CurrentUser, Form(form): Form, ) -> Result { - guards::require_valid_csrf_token!(form.csrf_token, current_user, db_conn); + guards::require_valid_csrf_token(&form.csrf_token, ¤t_user, &db_conn).await?; let team_id = Uuid::now_v7(); let team = Team { @@ -195,7 +195,7 @@ async fn projects_page( Path(team_id): Path, CurrentUser(current_user): CurrentUser, ) -> Result { - let team = guards::require_team_membership!(current_user, team_id, db_conn); + let team = guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; let team_id = team.id.clone(); let api_keys = db_conn