1
0
Fork 0
forked from 2sys/shoutdotdev

restore originally requested uri path after login

This commit is contained in:
Brent Schroeter 2025-04-11 23:03:39 -07:00
parent 588bf33d6e
commit 1f08b5a590
3 changed files with 133 additions and 89 deletions

View file

@ -1,14 +1,9 @@
use std::fmt::{self, Display};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Redirect, Response};
use axum::response::{IntoResponse, Response};
use validator::ValidationErrors;
#[derive(Debug)]
pub struct AuthRedirectInfo {
base_path: String,
}
/// Custom error type that maps to appropriate HTTP responses.
#[derive(Debug)]
pub enum AppError {
@ -17,14 +12,9 @@ pub enum AppError {
NotFoundError(String),
BadRequestError(String),
TooManyRequestsError(String),
AuthRedirect(AuthRedirectInfo),
}
impl AppError {
pub fn auth_redirect_from_base_path(base_path: String) -> Self {
Self::AuthRedirect(AuthRedirectInfo { base_path })
}
pub fn from_validation_errors(errs: ValidationErrors) -> Self {
// TODO: customize validation errors formatting
Self::BadRequestError(
@ -36,10 +26,6 @@ impl AppError {
impl IntoResponse for AppError {
fn into_response(self) -> Response {
match self {
Self::AuthRedirect(AuthRedirectInfo { base_path }) => {
tracing::debug!("Handling AuthRedirect");
Redirect::to(&format!("{}/auth/login", base_path)).into_response()
}
Self::InternalServerError(err) => {
tracing::error!("Application error: {:?}", err);
(StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response()
@ -79,7 +65,6 @@ where
impl Display for AppError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AppError::AuthRedirect(info) => write!(f, "AuthRedirect: {:?}", info),
AppError::InternalServerError(inner) => inner.fmt(f),
AppError::ForbiddenError(client_message) => {
write!(f, "ForbiddenError: {}", client_message)

View file

@ -1,11 +1,10 @@
use anyhow::{Context, Result};
use async_session::{Session, SessionStore as _};
use async_session::{Session, SessionStore};
use axum::{
extract::{FromRequestParts, Query, State},
http::request::Parts,
extract::{Query, State},
response::{IntoResponse, Redirect},
routing::get,
RequestPartsExt, Router,
Router,
};
use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite};
use oauth2::{
@ -13,7 +12,6 @@ use oauth2::{
ClientSecret, CsrfToken, RedirectUrl, RefreshToken, TokenResponse, TokenUrl,
};
use serde::{Deserialize, Serialize};
use tracing::{trace_span, Instrument};
use crate::{
app_error::AppError,
@ -24,7 +22,8 @@ use crate::{
const SESSION_KEY_AUTH_CSRF_TOKEN: &str = "oauth_csrf_token";
const SESSION_KEY_AUTH_REFRESH_TOKEN: &str = "oauth_refresh_token";
const SESSION_KEY_AUTH_INFO: &str = "auth";
pub const SESSION_KEY_AUTH_INFO: &str = "auth";
pub const SESSION_KEY_AUTH_REDIRECT: &str = "post_auth_redirect";
/// Creates a new OAuth2 client to be stored in global application state.
pub fn new_oauth_client(settings: &Settings) -> Result<BasicClient> {
@ -64,29 +63,33 @@ async fn start_login(
AppSession(maybe_session): AppSession,
jar: CookieJar,
) -> Result<impl IntoResponse, AppError> {
if let Some(session) = maybe_session {
if session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO).is_some() {
tracing::debug!("already logged in, redirecting...");
return Ok(Redirect::to(&format!("{}/", base_path)).into_response());
}
let mut session = if let Some(value) = maybe_session {
value
} else {
Session::new()
};
if session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO).is_some() {
tracing::debug!("already logged in, redirecting...");
return Ok(Redirect::to(&format!("{}/", base_path)).into_response());
}
assert!(session.get_raw(SESSION_KEY_AUTH_REFRESH_TOKEN).is_none());
let csrf_token = CsrfToken::new_random();
let (auth_url, _csrf_token) = state
.oauth_client
.authorize_url(|| csrf_token.clone())
.url();
let mut session = Session::new();
session.insert(SESSION_KEY_AUTH_CSRF_TOKEN, &csrf_token)?;
let cookie_value = session_store
.store_session(session)
.await?
.ok_or(anyhow::anyhow!("cookie value from store_session() is None"))?;
let jar = jar.add(
Cookie::build((auth_settings.cookie_name.clone(), cookie_value))
.same_site(SameSite::Lax)
.http_only(true)
.path("/"),
);
let (auth_url, _csrf_token) = state.oauth_client.authorize_url(|| csrf_token).url();
let jar = if let Some(cookie_value) = session_store.store_session(session).await? {
tracing::debug!("adding session cookie to jar");
jar.add(
Cookie::build((auth_settings.cookie_name.clone(), cookie_value))
.same_site(SameSite::Lax)
.http_only(true)
.path("/"),
)
} else {
tracing::debug!("inferred that session cookie already in jar");
jar
};
Ok((jar, Redirect::to(auth_url.as_ref())).into_response())
}
@ -150,14 +153,17 @@ async fn callback(
State(ReqwestClient(reqwest_client)): State<ReqwestClient>,
AppSession(session): AppSession,
) -> Result<impl IntoResponse, AppError> {
let mut session = if let Some(session) = session {
session
} else {
return Err(AppError::auth_redirect_from_base_path(base_path));
};
let mut session = session.ok_or_else(|| {
tracing::debug!("unable to load session");
AppError::Forbidden(
"our apologies: authentication session expired or lost, please try again".to_owned(),
)
})?;
let session_csrf_token: String = session.get(SESSION_KEY_AUTH_CSRF_TOKEN).ok_or_else(|| {
tracing::debug!("oauth csrf token not found on session");
AppError::auth_redirect_from_base_path(base_path.clone())
AppError::Forbidden(
"our apologies: authentication session expired or lost, please try again".to_owned(),
)
})?;
if session_csrf_token != query.state {
tracing::debug!("oauth csrf tokens did not match");
@ -180,6 +186,12 @@ async fn callback(
.json()
.await?;
tracing::debug!("updating session");
let redirect_target: Option<String> = session.get(SESSION_KEY_AUTH_REDIRECT);
// Remove this since we don't need or want it sticking around, for both UX
// and security hygiene reasons
session.remove(SESSION_KEY_AUTH_REDIRECT);
session.insert(SESSION_KEY_AUTH_INFO, &auth_info)?;
session.insert(SESSION_KEY_AUTH_REFRESH_TOKEN, response.refresh_token())?;
if state.session_store.store_session(session).await?.is_some() {
@ -189,7 +201,9 @@ async fn callback(
.into());
}
tracing::debug!("successfully authenticated");
Ok(Redirect::to(&format!("{}/", base_path)))
Ok(Redirect::to(
&redirect_target.unwrap_or(format!("{}/", base_path)),
))
}
/// Data stored in the visitor's session upon successful authentication.
@ -198,29 +212,3 @@ pub struct AuthInfo {
pub sub: String,
pub email: String,
}
impl FromRequestParts<AppState> for AuthInfo {
type Rejection = AppError;
async fn from_request_parts(
parts: &mut Parts,
state: &AppState,
) -> Result<Self, <Self as FromRequestParts<AppState>>::Rejection> {
async move {
let session = parts
.extract_with_state::<AppSession, AppState>(state)
.await?
.0
.ok_or(AppError::auth_redirect_from_base_path(
state.settings.base_path.clone(),
))?;
let user = session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO).ok_or(
AppError::auth_redirect_from_base_path(state.settings.base_path.clone()),
)?;
Ok(user)
}
// The Span.enter() guard pattern doesn't play nicely with async
.instrument(trace_span!("AuthInfo from_request_parts()"))
.await
}
}

View file

@ -1,5 +1,15 @@
use anyhow::Context;
use axum::{extract::FromRequestParts, http::request::Parts, RequestPartsExt};
use async_session::{Session, SessionStore as _};
use axum::{
extract::{FromRequestParts, OriginalUri},
http::{request::Parts, Method},
response::{IntoResponse, Redirect, Response},
RequestPartsExt,
};
use axum_extra::extract::{
cookie::{Cookie, SameSite},
CookieJar,
};
use diesel::{
associations::Identifiable,
deserialize::Queryable,
@ -13,8 +23,9 @@ use uuid::Uuid;
use crate::{
app_error::AppError,
app_state::AppState,
auth::AuthInfo,
auth::{AuthInfo, SESSION_KEY_AUTH_INFO, SESSION_KEY_AUTH_REDIRECT},
schema::{team_memberships, teams, users},
sessions::AppSession,
team_memberships::TeamMembership,
teams::Team,
};
@ -57,20 +68,54 @@ impl<S> FromRequestParts<S> for CurrentUser
where
S: Into<AppState> + Clone + Sync,
{
type Rejection = AppError;
type Rejection = CurrentUserRejection;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let state: AppState = state.clone().into();
let auth_info = parts
.extract_with_state::<AuthInfo, AppState>(&state)
.await
.map_err(|_| {
AppError::auth_redirect_from_base_path(state.settings.base_path.clone())
})?;
let current_user = state
.db_pool
.get()
.await?
let app_state: AppState = state.clone().into();
let mut session =
if let AppSession(Some(value)) = parts.extract_with_state(&app_state).await? {
value
} else {
Session::new()
};
let auth_info = if let Some(value) = session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO) {
value
} else {
let jar: CookieJar = parts.extract().await?;
let method: Method = parts.extract().await?;
let jar = if method == Method::GET {
let OriginalUri(uri) = parts.extract().await?;
session.insert(
SESSION_KEY_AUTH_REDIRECT,
uri.path_and_query()
.map(|value| value.to_string())
.unwrap_or(format!("{}/", app_state.settings.base_path)),
)?;
if let Some(cookie_value) = app_state.session_store.store_session(session).await? {
tracing::debug!("adding session cookie to jar");
jar.add(
Cookie::build((app_state.settings.auth.cookie_name.clone(), cookie_value))
.same_site(SameSite::Lax)
.http_only(true)
.path("/"),
)
} else {
tracing::debug!("inferred that session cookie already in jar");
jar
}
} else {
// If request method is not GET then do not attempt to infer the
// redirect target, as there may be no GET handler defined for
// it.
jar
};
return Err(Self::Rejection::SetCookiesAndRedirect(
jar,
format!("{}/auth/login", app_state.settings.base_path),
));
};
let db_conn = app_state.db_pool.get().await?;
let current_user = db_conn
.interact(move |conn| {
let maybe_current_user = User::all()
.filter(User::with_uid(&auth_info.sub))
@ -112,3 +157,29 @@ where
Ok(CurrentUser(current_user))
}
}
pub enum CurrentUserRejection {
AppError(AppError),
SetCookiesAndRedirect(CookieJar, String),
}
// Easily convert semi-arbitrary errors to InternalServerError
impl<E> From<E> for CurrentUserRejection
where
E: Into<AppError>,
{
fn from(err: E) -> Self {
Self::AppError(err.into())
}
}
impl IntoResponse for CurrentUserRejection {
fn into_response(self) -> Response {
match self {
Self::AppError(err) => err.into_response(),
Self::SetCookiesAndRedirect(jar, redirect_to) => {
(jar, Redirect::to(&redirect_to)).into_response()
}
}
}
}