From 1f08b5a590c54dee64aee42b6bd341d8b3bdecf0 Mon Sep 17 00:00:00 2001 From: Brent Schroeter Date: Fri, 11 Apr 2025 23:03:39 -0700 Subject: [PATCH] restore originally requested uri path after login --- src/app_error.rs | 17 +------- src/auth.rs | 106 +++++++++++++++++++++-------------------------- src/users.rs | 99 ++++++++++++++++++++++++++++++++++++------- 3 files changed, 133 insertions(+), 89 deletions(-) diff --git a/src/app_error.rs b/src/app_error.rs index 46770cd..4d1f9d7 100644 --- a/src/app_error.rs +++ b/src/app_error.rs @@ -1,14 +1,9 @@ use std::fmt::{self, Display}; use axum::http::StatusCode; -use axum::response::{IntoResponse, Redirect, Response}; +use axum::response::{IntoResponse, Response}; use validator::ValidationErrors; -#[derive(Debug)] -pub struct AuthRedirectInfo { - base_path: String, -} - /// Custom error type that maps to appropriate HTTP responses. #[derive(Debug)] pub enum AppError { @@ -17,14 +12,9 @@ pub enum AppError { NotFoundError(String), BadRequestError(String), TooManyRequestsError(String), - AuthRedirect(AuthRedirectInfo), } impl AppError { - pub fn auth_redirect_from_base_path(base_path: String) -> Self { - Self::AuthRedirect(AuthRedirectInfo { base_path }) - } - pub fn from_validation_errors(errs: ValidationErrors) -> Self { // TODO: customize validation errors formatting Self::BadRequestError( @@ -36,10 +26,6 @@ impl AppError { impl IntoResponse for AppError { fn into_response(self) -> Response { match self { - Self::AuthRedirect(AuthRedirectInfo { base_path }) => { - tracing::debug!("Handling AuthRedirect"); - Redirect::to(&format!("{}/auth/login", base_path)).into_response() - } Self::InternalServerError(err) => { tracing::error!("Application error: {:?}", err); (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response() @@ -79,7 +65,6 @@ where impl Display for AppError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - AppError::AuthRedirect(info) => write!(f, "AuthRedirect: {:?}", info), AppError::InternalServerError(inner) => inner.fmt(f), AppError::ForbiddenError(client_message) => { write!(f, "ForbiddenError: {}", client_message) diff --git a/src/auth.rs b/src/auth.rs index 83ccab7..453a80e 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,11 +1,10 @@ use anyhow::{Context, Result}; -use async_session::{Session, SessionStore as _}; +use async_session::{Session, SessionStore}; use axum::{ - extract::{FromRequestParts, Query, State}, - http::request::Parts, + extract::{Query, State}, response::{IntoResponse, Redirect}, routing::get, - RequestPartsExt, Router, + Router, }; use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite}; use oauth2::{ @@ -13,7 +12,6 @@ use oauth2::{ ClientSecret, CsrfToken, RedirectUrl, RefreshToken, TokenResponse, TokenUrl, }; use serde::{Deserialize, Serialize}; -use tracing::{trace_span, Instrument}; use crate::{ app_error::AppError, @@ -24,7 +22,8 @@ use crate::{ const SESSION_KEY_AUTH_CSRF_TOKEN: &str = "oauth_csrf_token"; const SESSION_KEY_AUTH_REFRESH_TOKEN: &str = "oauth_refresh_token"; -const SESSION_KEY_AUTH_INFO: &str = "auth"; +pub const SESSION_KEY_AUTH_INFO: &str = "auth"; +pub const SESSION_KEY_AUTH_REDIRECT: &str = "post_auth_redirect"; /// Creates a new OAuth2 client to be stored in global application state. pub fn new_oauth_client(settings: &Settings) -> Result { @@ -64,29 +63,33 @@ async fn start_login( AppSession(maybe_session): AppSession, jar: CookieJar, ) -> Result { - if let Some(session) = maybe_session { - if session.get::(SESSION_KEY_AUTH_INFO).is_some() { - tracing::debug!("already logged in, redirecting..."); - return Ok(Redirect::to(&format!("{}/", base_path)).into_response()); - } + let mut session = if let Some(value) = maybe_session { + value + } else { + Session::new() + }; + + if session.get::(SESSION_KEY_AUTH_INFO).is_some() { + tracing::debug!("already logged in, redirecting..."); + return Ok(Redirect::to(&format!("{}/", base_path)).into_response()); } + assert!(session.get_raw(SESSION_KEY_AUTH_REFRESH_TOKEN).is_none()); + let csrf_token = CsrfToken::new_random(); - let (auth_url, _csrf_token) = state - .oauth_client - .authorize_url(|| csrf_token.clone()) - .url(); - let mut session = Session::new(); session.insert(SESSION_KEY_AUTH_CSRF_TOKEN, &csrf_token)?; - let cookie_value = session_store - .store_session(session) - .await? - .ok_or(anyhow::anyhow!("cookie value from store_session() is None"))?; - let jar = jar.add( - Cookie::build((auth_settings.cookie_name.clone(), cookie_value)) - .same_site(SameSite::Lax) - .http_only(true) - .path("/"), - ); + let (auth_url, _csrf_token) = state.oauth_client.authorize_url(|| csrf_token).url(); + let jar = if let Some(cookie_value) = session_store.store_session(session).await? { + tracing::debug!("adding session cookie to jar"); + jar.add( + Cookie::build((auth_settings.cookie_name.clone(), cookie_value)) + .same_site(SameSite::Lax) + .http_only(true) + .path("/"), + ) + } else { + tracing::debug!("inferred that session cookie already in jar"); + jar + }; Ok((jar, Redirect::to(auth_url.as_ref())).into_response()) } @@ -150,14 +153,17 @@ async fn callback( State(ReqwestClient(reqwest_client)): State, AppSession(session): AppSession, ) -> Result { - let mut session = if let Some(session) = session { - session - } else { - return Err(AppError::auth_redirect_from_base_path(base_path)); - }; + let mut session = session.ok_or_else(|| { + tracing::debug!("unable to load session"); + AppError::Forbidden( + "our apologies: authentication session expired or lost, please try again".to_owned(), + ) + })?; let session_csrf_token: String = session.get(SESSION_KEY_AUTH_CSRF_TOKEN).ok_or_else(|| { tracing::debug!("oauth csrf token not found on session"); - AppError::auth_redirect_from_base_path(base_path.clone()) + AppError::Forbidden( + "our apologies: authentication session expired or lost, please try again".to_owned(), + ) })?; if session_csrf_token != query.state { tracing::debug!("oauth csrf tokens did not match"); @@ -180,6 +186,12 @@ async fn callback( .json() .await?; tracing::debug!("updating session"); + + let redirect_target: Option = session.get(SESSION_KEY_AUTH_REDIRECT); + // Remove this since we don't need or want it sticking around, for both UX + // and security hygiene reasons + session.remove(SESSION_KEY_AUTH_REDIRECT); + session.insert(SESSION_KEY_AUTH_INFO, &auth_info)?; session.insert(SESSION_KEY_AUTH_REFRESH_TOKEN, response.refresh_token())?; if state.session_store.store_session(session).await?.is_some() { @@ -189,7 +201,9 @@ async fn callback( .into()); } tracing::debug!("successfully authenticated"); - Ok(Redirect::to(&format!("{}/", base_path))) + Ok(Redirect::to( + &redirect_target.unwrap_or(format!("{}/", base_path)), + )) } /// Data stored in the visitor's session upon successful authentication. @@ -198,29 +212,3 @@ pub struct AuthInfo { pub sub: String, pub email: String, } - -impl FromRequestParts for AuthInfo { - type Rejection = AppError; - - async fn from_request_parts( - parts: &mut Parts, - state: &AppState, - ) -> Result>::Rejection> { - async move { - let session = parts - .extract_with_state::(state) - .await? - .0 - .ok_or(AppError::auth_redirect_from_base_path( - state.settings.base_path.clone(), - ))?; - let user = session.get::(SESSION_KEY_AUTH_INFO).ok_or( - AppError::auth_redirect_from_base_path(state.settings.base_path.clone()), - )?; - Ok(user) - } - // The Span.enter() guard pattern doesn't play nicely with async - .instrument(trace_span!("AuthInfo from_request_parts()")) - .await - } -} diff --git a/src/users.rs b/src/users.rs index 79ea8f8..91756ae 100644 --- a/src/users.rs +++ b/src/users.rs @@ -1,5 +1,15 @@ use anyhow::Context; -use axum::{extract::FromRequestParts, http::request::Parts, RequestPartsExt}; +use async_session::{Session, SessionStore as _}; +use axum::{ + extract::{FromRequestParts, OriginalUri}, + http::{request::Parts, Method}, + response::{IntoResponse, Redirect, Response}, + RequestPartsExt, +}; +use axum_extra::extract::{ + cookie::{Cookie, SameSite}, + CookieJar, +}; use diesel::{ associations::Identifiable, deserialize::Queryable, @@ -13,8 +23,9 @@ use uuid::Uuid; use crate::{ app_error::AppError, app_state::AppState, - auth::AuthInfo, + auth::{AuthInfo, SESSION_KEY_AUTH_INFO, SESSION_KEY_AUTH_REDIRECT}, schema::{team_memberships, teams, users}, + sessions::AppSession, team_memberships::TeamMembership, teams::Team, }; @@ -57,20 +68,54 @@ impl FromRequestParts for CurrentUser where S: Into + Clone + Sync, { - type Rejection = AppError; + type Rejection = CurrentUserRejection; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - let state: AppState = state.clone().into(); - let auth_info = parts - .extract_with_state::(&state) - .await - .map_err(|_| { - AppError::auth_redirect_from_base_path(state.settings.base_path.clone()) - })?; - let current_user = state - .db_pool - .get() - .await? + let app_state: AppState = state.clone().into(); + let mut session = + if let AppSession(Some(value)) = parts.extract_with_state(&app_state).await? { + value + } else { + Session::new() + }; + let auth_info = if let Some(value) = session.get::(SESSION_KEY_AUTH_INFO) { + value + } else { + let jar: CookieJar = parts.extract().await?; + let method: Method = parts.extract().await?; + let jar = if method == Method::GET { + let OriginalUri(uri) = parts.extract().await?; + session.insert( + SESSION_KEY_AUTH_REDIRECT, + uri.path_and_query() + .map(|value| value.to_string()) + .unwrap_or(format!("{}/", app_state.settings.base_path)), + )?; + if let Some(cookie_value) = app_state.session_store.store_session(session).await? { + tracing::debug!("adding session cookie to jar"); + jar.add( + Cookie::build((app_state.settings.auth.cookie_name.clone(), cookie_value)) + .same_site(SameSite::Lax) + .http_only(true) + .path("/"), + ) + } else { + tracing::debug!("inferred that session cookie already in jar"); + jar + } + } else { + // If request method is not GET then do not attempt to infer the + // redirect target, as there may be no GET handler defined for + // it. + jar + }; + return Err(Self::Rejection::SetCookiesAndRedirect( + jar, + format!("{}/auth/login", app_state.settings.base_path), + )); + }; + let db_conn = app_state.db_pool.get().await?; + let current_user = db_conn .interact(move |conn| { let maybe_current_user = User::all() .filter(User::with_uid(&auth_info.sub)) @@ -112,3 +157,29 @@ where Ok(CurrentUser(current_user)) } } + +pub enum CurrentUserRejection { + AppError(AppError), + SetCookiesAndRedirect(CookieJar, String), +} + +// Easily convert semi-arbitrary errors to InternalServerError +impl From for CurrentUserRejection +where + E: Into, +{ + fn from(err: E) -> Self { + Self::AppError(err.into()) + } +} + +impl IntoResponse for CurrentUserRejection { + fn into_response(self) -> Response { + match self { + Self::AppError(err) => err.into_response(), + Self::SetCookiesAndRedirect(jar, redirect_to) => { + (jar, Redirect::to(&redirect_to)).into_response() + } + } + } +}