1
0
Fork 0
forked from 2sys/shoutdotdev

restore originally requested uri path after login

This commit is contained in:
Brent Schroeter 2025-04-11 23:03:39 -07:00
parent 588bf33d6e
commit 1f08b5a590
3 changed files with 133 additions and 89 deletions

View file

@ -1,14 +1,9 @@
use std::fmt::{self, Display}; use std::fmt::{self, Display};
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::response::{IntoResponse, Redirect, Response}; use axum::response::{IntoResponse, Response};
use validator::ValidationErrors; use validator::ValidationErrors;
#[derive(Debug)]
pub struct AuthRedirectInfo {
base_path: String,
}
/// Custom error type that maps to appropriate HTTP responses. /// Custom error type that maps to appropriate HTTP responses.
#[derive(Debug)] #[derive(Debug)]
pub enum AppError { pub enum AppError {
@ -17,14 +12,9 @@ pub enum AppError {
NotFoundError(String), NotFoundError(String),
BadRequestError(String), BadRequestError(String),
TooManyRequestsError(String), TooManyRequestsError(String),
AuthRedirect(AuthRedirectInfo),
} }
impl AppError { impl AppError {
pub fn auth_redirect_from_base_path(base_path: String) -> Self {
Self::AuthRedirect(AuthRedirectInfo { base_path })
}
pub fn from_validation_errors(errs: ValidationErrors) -> Self { pub fn from_validation_errors(errs: ValidationErrors) -> Self {
// TODO: customize validation errors formatting // TODO: customize validation errors formatting
Self::BadRequestError( Self::BadRequestError(
@ -36,10 +26,6 @@ impl AppError {
impl IntoResponse for AppError { impl IntoResponse for AppError {
fn into_response(self) -> Response { fn into_response(self) -> Response {
match self { match self {
Self::AuthRedirect(AuthRedirectInfo { base_path }) => {
tracing::debug!("Handling AuthRedirect");
Redirect::to(&format!("{}/auth/login", base_path)).into_response()
}
Self::InternalServerError(err) => { Self::InternalServerError(err) => {
tracing::error!("Application error: {:?}", err); tracing::error!("Application error: {:?}", err);
(StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response() (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response()
@ -79,7 +65,6 @@ where
impl Display for AppError { impl Display for AppError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
AppError::AuthRedirect(info) => write!(f, "AuthRedirect: {:?}", info),
AppError::InternalServerError(inner) => inner.fmt(f), AppError::InternalServerError(inner) => inner.fmt(f),
AppError::ForbiddenError(client_message) => { AppError::ForbiddenError(client_message) => {
write!(f, "ForbiddenError: {}", client_message) write!(f, "ForbiddenError: {}", client_message)

View file

@ -1,11 +1,10 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use async_session::{Session, SessionStore as _}; use async_session::{Session, SessionStore};
use axum::{ use axum::{
extract::{FromRequestParts, Query, State}, extract::{Query, State},
http::request::Parts,
response::{IntoResponse, Redirect}, response::{IntoResponse, Redirect},
routing::get, routing::get,
RequestPartsExt, Router, Router,
}; };
use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite}; use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite};
use oauth2::{ use oauth2::{
@ -13,7 +12,6 @@ use oauth2::{
ClientSecret, CsrfToken, RedirectUrl, RefreshToken, TokenResponse, TokenUrl, ClientSecret, CsrfToken, RedirectUrl, RefreshToken, TokenResponse, TokenUrl,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing::{trace_span, Instrument};
use crate::{ use crate::{
app_error::AppError, app_error::AppError,
@ -24,7 +22,8 @@ use crate::{
const SESSION_KEY_AUTH_CSRF_TOKEN: &str = "oauth_csrf_token"; const SESSION_KEY_AUTH_CSRF_TOKEN: &str = "oauth_csrf_token";
const SESSION_KEY_AUTH_REFRESH_TOKEN: &str = "oauth_refresh_token"; const SESSION_KEY_AUTH_REFRESH_TOKEN: &str = "oauth_refresh_token";
const SESSION_KEY_AUTH_INFO: &str = "auth"; pub const SESSION_KEY_AUTH_INFO: &str = "auth";
pub const SESSION_KEY_AUTH_REDIRECT: &str = "post_auth_redirect";
/// Creates a new OAuth2 client to be stored in global application state. /// Creates a new OAuth2 client to be stored in global application state.
pub fn new_oauth_client(settings: &Settings) -> Result<BasicClient> { pub fn new_oauth_client(settings: &Settings) -> Result<BasicClient> {
@ -64,29 +63,33 @@ async fn start_login(
AppSession(maybe_session): AppSession, AppSession(maybe_session): AppSession,
jar: CookieJar, jar: CookieJar,
) -> Result<impl IntoResponse, AppError> { ) -> Result<impl IntoResponse, AppError> {
if let Some(session) = maybe_session { let mut session = if let Some(value) = maybe_session {
if session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO).is_some() { value
tracing::debug!("already logged in, redirecting..."); } else {
return Ok(Redirect::to(&format!("{}/", base_path)).into_response()); Session::new()
} };
if session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO).is_some() {
tracing::debug!("already logged in, redirecting...");
return Ok(Redirect::to(&format!("{}/", base_path)).into_response());
} }
assert!(session.get_raw(SESSION_KEY_AUTH_REFRESH_TOKEN).is_none());
let csrf_token = CsrfToken::new_random(); let csrf_token = CsrfToken::new_random();
let (auth_url, _csrf_token) = state
.oauth_client
.authorize_url(|| csrf_token.clone())
.url();
let mut session = Session::new();
session.insert(SESSION_KEY_AUTH_CSRF_TOKEN, &csrf_token)?; session.insert(SESSION_KEY_AUTH_CSRF_TOKEN, &csrf_token)?;
let cookie_value = session_store let (auth_url, _csrf_token) = state.oauth_client.authorize_url(|| csrf_token).url();
.store_session(session) let jar = if let Some(cookie_value) = session_store.store_session(session).await? {
.await? tracing::debug!("adding session cookie to jar");
.ok_or(anyhow::anyhow!("cookie value from store_session() is None"))?; jar.add(
let jar = jar.add( Cookie::build((auth_settings.cookie_name.clone(), cookie_value))
Cookie::build((auth_settings.cookie_name.clone(), cookie_value)) .same_site(SameSite::Lax)
.same_site(SameSite::Lax) .http_only(true)
.http_only(true) .path("/"),
.path("/"), )
); } else {
tracing::debug!("inferred that session cookie already in jar");
jar
};
Ok((jar, Redirect::to(auth_url.as_ref())).into_response()) Ok((jar, Redirect::to(auth_url.as_ref())).into_response())
} }
@ -150,14 +153,17 @@ async fn callback(
State(ReqwestClient(reqwest_client)): State<ReqwestClient>, State(ReqwestClient(reqwest_client)): State<ReqwestClient>,
AppSession(session): AppSession, AppSession(session): AppSession,
) -> Result<impl IntoResponse, AppError> { ) -> Result<impl IntoResponse, AppError> {
let mut session = if let Some(session) = session { let mut session = session.ok_or_else(|| {
session tracing::debug!("unable to load session");
} else { AppError::Forbidden(
return Err(AppError::auth_redirect_from_base_path(base_path)); "our apologies: authentication session expired or lost, please try again".to_owned(),
}; )
})?;
let session_csrf_token: String = session.get(SESSION_KEY_AUTH_CSRF_TOKEN).ok_or_else(|| { let session_csrf_token: String = session.get(SESSION_KEY_AUTH_CSRF_TOKEN).ok_or_else(|| {
tracing::debug!("oauth csrf token not found on session"); tracing::debug!("oauth csrf token not found on session");
AppError::auth_redirect_from_base_path(base_path.clone()) AppError::Forbidden(
"our apologies: authentication session expired or lost, please try again".to_owned(),
)
})?; })?;
if session_csrf_token != query.state { if session_csrf_token != query.state {
tracing::debug!("oauth csrf tokens did not match"); tracing::debug!("oauth csrf tokens did not match");
@ -180,6 +186,12 @@ async fn callback(
.json() .json()
.await?; .await?;
tracing::debug!("updating session"); tracing::debug!("updating session");
let redirect_target: Option<String> = session.get(SESSION_KEY_AUTH_REDIRECT);
// Remove this since we don't need or want it sticking around, for both UX
// and security hygiene reasons
session.remove(SESSION_KEY_AUTH_REDIRECT);
session.insert(SESSION_KEY_AUTH_INFO, &auth_info)?; session.insert(SESSION_KEY_AUTH_INFO, &auth_info)?;
session.insert(SESSION_KEY_AUTH_REFRESH_TOKEN, response.refresh_token())?; session.insert(SESSION_KEY_AUTH_REFRESH_TOKEN, response.refresh_token())?;
if state.session_store.store_session(session).await?.is_some() { if state.session_store.store_session(session).await?.is_some() {
@ -189,7 +201,9 @@ async fn callback(
.into()); .into());
} }
tracing::debug!("successfully authenticated"); tracing::debug!("successfully authenticated");
Ok(Redirect::to(&format!("{}/", base_path))) Ok(Redirect::to(
&redirect_target.unwrap_or(format!("{}/", base_path)),
))
} }
/// Data stored in the visitor's session upon successful authentication. /// Data stored in the visitor's session upon successful authentication.
@ -198,29 +212,3 @@ pub struct AuthInfo {
pub sub: String, pub sub: String,
pub email: String, pub email: String,
} }
impl FromRequestParts<AppState> for AuthInfo {
type Rejection = AppError;
async fn from_request_parts(
parts: &mut Parts,
state: &AppState,
) -> Result<Self, <Self as FromRequestParts<AppState>>::Rejection> {
async move {
let session = parts
.extract_with_state::<AppSession, AppState>(state)
.await?
.0
.ok_or(AppError::auth_redirect_from_base_path(
state.settings.base_path.clone(),
))?;
let user = session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO).ok_or(
AppError::auth_redirect_from_base_path(state.settings.base_path.clone()),
)?;
Ok(user)
}
// The Span.enter() guard pattern doesn't play nicely with async
.instrument(trace_span!("AuthInfo from_request_parts()"))
.await
}
}

View file

@ -1,5 +1,15 @@
use anyhow::Context; use anyhow::Context;
use axum::{extract::FromRequestParts, http::request::Parts, RequestPartsExt}; use async_session::{Session, SessionStore as _};
use axum::{
extract::{FromRequestParts, OriginalUri},
http::{request::Parts, Method},
response::{IntoResponse, Redirect, Response},
RequestPartsExt,
};
use axum_extra::extract::{
cookie::{Cookie, SameSite},
CookieJar,
};
use diesel::{ use diesel::{
associations::Identifiable, associations::Identifiable,
deserialize::Queryable, deserialize::Queryable,
@ -13,8 +23,9 @@ use uuid::Uuid;
use crate::{ use crate::{
app_error::AppError, app_error::AppError,
app_state::AppState, app_state::AppState,
auth::AuthInfo, auth::{AuthInfo, SESSION_KEY_AUTH_INFO, SESSION_KEY_AUTH_REDIRECT},
schema::{team_memberships, teams, users}, schema::{team_memberships, teams, users},
sessions::AppSession,
team_memberships::TeamMembership, team_memberships::TeamMembership,
teams::Team, teams::Team,
}; };
@ -57,20 +68,54 @@ impl<S> FromRequestParts<S> for CurrentUser
where where
S: Into<AppState> + Clone + Sync, S: Into<AppState> + Clone + Sync,
{ {
type Rejection = AppError; type Rejection = CurrentUserRejection;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let state: AppState = state.clone().into(); let app_state: AppState = state.clone().into();
let auth_info = parts let mut session =
.extract_with_state::<AuthInfo, AppState>(&state) if let AppSession(Some(value)) = parts.extract_with_state(&app_state).await? {
.await value
.map_err(|_| { } else {
AppError::auth_redirect_from_base_path(state.settings.base_path.clone()) Session::new()
})?; };
let current_user = state let auth_info = if let Some(value) = session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO) {
.db_pool value
.get() } else {
.await? let jar: CookieJar = parts.extract().await?;
let method: Method = parts.extract().await?;
let jar = if method == Method::GET {
let OriginalUri(uri) = parts.extract().await?;
session.insert(
SESSION_KEY_AUTH_REDIRECT,
uri.path_and_query()
.map(|value| value.to_string())
.unwrap_or(format!("{}/", app_state.settings.base_path)),
)?;
if let Some(cookie_value) = app_state.session_store.store_session(session).await? {
tracing::debug!("adding session cookie to jar");
jar.add(
Cookie::build((app_state.settings.auth.cookie_name.clone(), cookie_value))
.same_site(SameSite::Lax)
.http_only(true)
.path("/"),
)
} else {
tracing::debug!("inferred that session cookie already in jar");
jar
}
} else {
// If request method is not GET then do not attempt to infer the
// redirect target, as there may be no GET handler defined for
// it.
jar
};
return Err(Self::Rejection::SetCookiesAndRedirect(
jar,
format!("{}/auth/login", app_state.settings.base_path),
));
};
let db_conn = app_state.db_pool.get().await?;
let current_user = db_conn
.interact(move |conn| { .interact(move |conn| {
let maybe_current_user = User::all() let maybe_current_user = User::all()
.filter(User::with_uid(&auth_info.sub)) .filter(User::with_uid(&auth_info.sub))
@ -112,3 +157,29 @@ where
Ok(CurrentUser(current_user)) Ok(CurrentUser(current_user))
} }
} }
pub enum CurrentUserRejection {
AppError(AppError),
SetCookiesAndRedirect(CookieJar, String),
}
// Easily convert semi-arbitrary errors to InternalServerError
impl<E> From<E> for CurrentUserRejection
where
E: Into<AppError>,
{
fn from(err: E) -> Self {
Self::AppError(err.into())
}
}
impl IntoResponse for CurrentUserRejection {
fn into_response(self) -> Response {
match self {
Self::AppError(err) => err.into_response(),
Self::SetCookiesAndRedirect(jar, redirect_to) => {
(jar, Redirect::to(&redirect_to)).into_response()
}
}
}
}