forked from 2sys/shoutdotdev
restore originally requested uri path after login
This commit is contained in:
parent
588bf33d6e
commit
1f08b5a590
3 changed files with 133 additions and 89 deletions
|
@ -1,14 +1,9 @@
|
|||
use std::fmt::{self, Display};
|
||||
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{IntoResponse, Redirect, Response};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use validator::ValidationErrors;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthRedirectInfo {
|
||||
base_path: String,
|
||||
}
|
||||
|
||||
/// Custom error type that maps to appropriate HTTP responses.
|
||||
#[derive(Debug)]
|
||||
pub enum AppError {
|
||||
|
@ -17,14 +12,9 @@ pub enum AppError {
|
|||
NotFoundError(String),
|
||||
BadRequestError(String),
|
||||
TooManyRequestsError(String),
|
||||
AuthRedirect(AuthRedirectInfo),
|
||||
}
|
||||
|
||||
impl AppError {
|
||||
pub fn auth_redirect_from_base_path(base_path: String) -> Self {
|
||||
Self::AuthRedirect(AuthRedirectInfo { base_path })
|
||||
}
|
||||
|
||||
pub fn from_validation_errors(errs: ValidationErrors) -> Self {
|
||||
// TODO: customize validation errors formatting
|
||||
Self::BadRequestError(
|
||||
|
@ -36,10 +26,6 @@ impl AppError {
|
|||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> Response {
|
||||
match self {
|
||||
Self::AuthRedirect(AuthRedirectInfo { base_path }) => {
|
||||
tracing::debug!("Handling AuthRedirect");
|
||||
Redirect::to(&format!("{}/auth/login", base_path)).into_response()
|
||||
}
|
||||
Self::InternalServerError(err) => {
|
||||
tracing::error!("Application error: {:?}", err);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response()
|
||||
|
@ -79,7 +65,6 @@ where
|
|||
impl Display for AppError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
AppError::AuthRedirect(info) => write!(f, "AuthRedirect: {:?}", info),
|
||||
AppError::InternalServerError(inner) => inner.fmt(f),
|
||||
AppError::ForbiddenError(client_message) => {
|
||||
write!(f, "ForbiddenError: {}", client_message)
|
||||
|
|
106
src/auth.rs
106
src/auth.rs
|
@ -1,11 +1,10 @@
|
|||
use anyhow::{Context, Result};
|
||||
use async_session::{Session, SessionStore as _};
|
||||
use async_session::{Session, SessionStore};
|
||||
use axum::{
|
||||
extract::{FromRequestParts, Query, State},
|
||||
http::request::Parts,
|
||||
extract::{Query, State},
|
||||
response::{IntoResponse, Redirect},
|
||||
routing::get,
|
||||
RequestPartsExt, Router,
|
||||
Router,
|
||||
};
|
||||
use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite};
|
||||
use oauth2::{
|
||||
|
@ -13,7 +12,6 @@ use oauth2::{
|
|||
ClientSecret, CsrfToken, RedirectUrl, RefreshToken, TokenResponse, TokenUrl,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{trace_span, Instrument};
|
||||
|
||||
use crate::{
|
||||
app_error::AppError,
|
||||
|
@ -24,7 +22,8 @@ use crate::{
|
|||
|
||||
const SESSION_KEY_AUTH_CSRF_TOKEN: &str = "oauth_csrf_token";
|
||||
const SESSION_KEY_AUTH_REFRESH_TOKEN: &str = "oauth_refresh_token";
|
||||
const SESSION_KEY_AUTH_INFO: &str = "auth";
|
||||
pub const SESSION_KEY_AUTH_INFO: &str = "auth";
|
||||
pub const SESSION_KEY_AUTH_REDIRECT: &str = "post_auth_redirect";
|
||||
|
||||
/// Creates a new OAuth2 client to be stored in global application state.
|
||||
pub fn new_oauth_client(settings: &Settings) -> Result<BasicClient> {
|
||||
|
@ -64,29 +63,33 @@ async fn start_login(
|
|||
AppSession(maybe_session): AppSession,
|
||||
jar: CookieJar,
|
||||
) -> Result<impl IntoResponse, AppError> {
|
||||
if let Some(session) = maybe_session {
|
||||
if session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO).is_some() {
|
||||
tracing::debug!("already logged in, redirecting...");
|
||||
return Ok(Redirect::to(&format!("{}/", base_path)).into_response());
|
||||
}
|
||||
let mut session = if let Some(value) = maybe_session {
|
||||
value
|
||||
} else {
|
||||
Session::new()
|
||||
};
|
||||
|
||||
if session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO).is_some() {
|
||||
tracing::debug!("already logged in, redirecting...");
|
||||
return Ok(Redirect::to(&format!("{}/", base_path)).into_response());
|
||||
}
|
||||
assert!(session.get_raw(SESSION_KEY_AUTH_REFRESH_TOKEN).is_none());
|
||||
|
||||
let csrf_token = CsrfToken::new_random();
|
||||
let (auth_url, _csrf_token) = state
|
||||
.oauth_client
|
||||
.authorize_url(|| csrf_token.clone())
|
||||
.url();
|
||||
let mut session = Session::new();
|
||||
session.insert(SESSION_KEY_AUTH_CSRF_TOKEN, &csrf_token)?;
|
||||
let cookie_value = session_store
|
||||
.store_session(session)
|
||||
.await?
|
||||
.ok_or(anyhow::anyhow!("cookie value from store_session() is None"))?;
|
||||
let jar = jar.add(
|
||||
Cookie::build((auth_settings.cookie_name.clone(), cookie_value))
|
||||
.same_site(SameSite::Lax)
|
||||
.http_only(true)
|
||||
.path("/"),
|
||||
);
|
||||
let (auth_url, _csrf_token) = state.oauth_client.authorize_url(|| csrf_token).url();
|
||||
let jar = if let Some(cookie_value) = session_store.store_session(session).await? {
|
||||
tracing::debug!("adding session cookie to jar");
|
||||
jar.add(
|
||||
Cookie::build((auth_settings.cookie_name.clone(), cookie_value))
|
||||
.same_site(SameSite::Lax)
|
||||
.http_only(true)
|
||||
.path("/"),
|
||||
)
|
||||
} else {
|
||||
tracing::debug!("inferred that session cookie already in jar");
|
||||
jar
|
||||
};
|
||||
Ok((jar, Redirect::to(auth_url.as_ref())).into_response())
|
||||
}
|
||||
|
||||
|
@ -150,14 +153,17 @@ async fn callback(
|
|||
State(ReqwestClient(reqwest_client)): State<ReqwestClient>,
|
||||
AppSession(session): AppSession,
|
||||
) -> Result<impl IntoResponse, AppError> {
|
||||
let mut session = if let Some(session) = session {
|
||||
session
|
||||
} else {
|
||||
return Err(AppError::auth_redirect_from_base_path(base_path));
|
||||
};
|
||||
let mut session = session.ok_or_else(|| {
|
||||
tracing::debug!("unable to load session");
|
||||
AppError::Forbidden(
|
||||
"our apologies: authentication session expired or lost, please try again".to_owned(),
|
||||
)
|
||||
})?;
|
||||
let session_csrf_token: String = session.get(SESSION_KEY_AUTH_CSRF_TOKEN).ok_or_else(|| {
|
||||
tracing::debug!("oauth csrf token not found on session");
|
||||
AppError::auth_redirect_from_base_path(base_path.clone())
|
||||
AppError::Forbidden(
|
||||
"our apologies: authentication session expired or lost, please try again".to_owned(),
|
||||
)
|
||||
})?;
|
||||
if session_csrf_token != query.state {
|
||||
tracing::debug!("oauth csrf tokens did not match");
|
||||
|
@ -180,6 +186,12 @@ async fn callback(
|
|||
.json()
|
||||
.await?;
|
||||
tracing::debug!("updating session");
|
||||
|
||||
let redirect_target: Option<String> = session.get(SESSION_KEY_AUTH_REDIRECT);
|
||||
// Remove this since we don't need or want it sticking around, for both UX
|
||||
// and security hygiene reasons
|
||||
session.remove(SESSION_KEY_AUTH_REDIRECT);
|
||||
|
||||
session.insert(SESSION_KEY_AUTH_INFO, &auth_info)?;
|
||||
session.insert(SESSION_KEY_AUTH_REFRESH_TOKEN, response.refresh_token())?;
|
||||
if state.session_store.store_session(session).await?.is_some() {
|
||||
|
@ -189,7 +201,9 @@ async fn callback(
|
|||
.into());
|
||||
}
|
||||
tracing::debug!("successfully authenticated");
|
||||
Ok(Redirect::to(&format!("{}/", base_path)))
|
||||
Ok(Redirect::to(
|
||||
&redirect_target.unwrap_or(format!("{}/", base_path)),
|
||||
))
|
||||
}
|
||||
|
||||
/// Data stored in the visitor's session upon successful authentication.
|
||||
|
@ -198,29 +212,3 @@ pub struct AuthInfo {
|
|||
pub sub: String,
|
||||
pub email: String,
|
||||
}
|
||||
|
||||
impl FromRequestParts<AppState> for AuthInfo {
|
||||
type Rejection = AppError;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
state: &AppState,
|
||||
) -> Result<Self, <Self as FromRequestParts<AppState>>::Rejection> {
|
||||
async move {
|
||||
let session = parts
|
||||
.extract_with_state::<AppSession, AppState>(state)
|
||||
.await?
|
||||
.0
|
||||
.ok_or(AppError::auth_redirect_from_base_path(
|
||||
state.settings.base_path.clone(),
|
||||
))?;
|
||||
let user = session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO).ok_or(
|
||||
AppError::auth_redirect_from_base_path(state.settings.base_path.clone()),
|
||||
)?;
|
||||
Ok(user)
|
||||
}
|
||||
// The Span.enter() guard pattern doesn't play nicely with async
|
||||
.instrument(trace_span!("AuthInfo from_request_parts()"))
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
|
99
src/users.rs
99
src/users.rs
|
@ -1,5 +1,15 @@
|
|||
use anyhow::Context;
|
||||
use axum::{extract::FromRequestParts, http::request::Parts, RequestPartsExt};
|
||||
use async_session::{Session, SessionStore as _};
|
||||
use axum::{
|
||||
extract::{FromRequestParts, OriginalUri},
|
||||
http::{request::Parts, Method},
|
||||
response::{IntoResponse, Redirect, Response},
|
||||
RequestPartsExt,
|
||||
};
|
||||
use axum_extra::extract::{
|
||||
cookie::{Cookie, SameSite},
|
||||
CookieJar,
|
||||
};
|
||||
use diesel::{
|
||||
associations::Identifiable,
|
||||
deserialize::Queryable,
|
||||
|
@ -13,8 +23,9 @@ use uuid::Uuid;
|
|||
use crate::{
|
||||
app_error::AppError,
|
||||
app_state::AppState,
|
||||
auth::AuthInfo,
|
||||
auth::{AuthInfo, SESSION_KEY_AUTH_INFO, SESSION_KEY_AUTH_REDIRECT},
|
||||
schema::{team_memberships, teams, users},
|
||||
sessions::AppSession,
|
||||
team_memberships::TeamMembership,
|
||||
teams::Team,
|
||||
};
|
||||
|
@ -57,20 +68,54 @@ impl<S> FromRequestParts<S> for CurrentUser
|
|||
where
|
||||
S: Into<AppState> + Clone + Sync,
|
||||
{
|
||||
type Rejection = AppError;
|
||||
type Rejection = CurrentUserRejection;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let state: AppState = state.clone().into();
|
||||
let auth_info = parts
|
||||
.extract_with_state::<AuthInfo, AppState>(&state)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
AppError::auth_redirect_from_base_path(state.settings.base_path.clone())
|
||||
})?;
|
||||
let current_user = state
|
||||
.db_pool
|
||||
.get()
|
||||
.await?
|
||||
let app_state: AppState = state.clone().into();
|
||||
let mut session =
|
||||
if let AppSession(Some(value)) = parts.extract_with_state(&app_state).await? {
|
||||
value
|
||||
} else {
|
||||
Session::new()
|
||||
};
|
||||
let auth_info = if let Some(value) = session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO) {
|
||||
value
|
||||
} else {
|
||||
let jar: CookieJar = parts.extract().await?;
|
||||
let method: Method = parts.extract().await?;
|
||||
let jar = if method == Method::GET {
|
||||
let OriginalUri(uri) = parts.extract().await?;
|
||||
session.insert(
|
||||
SESSION_KEY_AUTH_REDIRECT,
|
||||
uri.path_and_query()
|
||||
.map(|value| value.to_string())
|
||||
.unwrap_or(format!("{}/", app_state.settings.base_path)),
|
||||
)?;
|
||||
if let Some(cookie_value) = app_state.session_store.store_session(session).await? {
|
||||
tracing::debug!("adding session cookie to jar");
|
||||
jar.add(
|
||||
Cookie::build((app_state.settings.auth.cookie_name.clone(), cookie_value))
|
||||
.same_site(SameSite::Lax)
|
||||
.http_only(true)
|
||||
.path("/"),
|
||||
)
|
||||
} else {
|
||||
tracing::debug!("inferred that session cookie already in jar");
|
||||
jar
|
||||
}
|
||||
} else {
|
||||
// If request method is not GET then do not attempt to infer the
|
||||
// redirect target, as there may be no GET handler defined for
|
||||
// it.
|
||||
jar
|
||||
};
|
||||
return Err(Self::Rejection::SetCookiesAndRedirect(
|
||||
jar,
|
||||
format!("{}/auth/login", app_state.settings.base_path),
|
||||
));
|
||||
};
|
||||
let db_conn = app_state.db_pool.get().await?;
|
||||
let current_user = db_conn
|
||||
.interact(move |conn| {
|
||||
let maybe_current_user = User::all()
|
||||
.filter(User::with_uid(&auth_info.sub))
|
||||
|
@ -112,3 +157,29 @@ where
|
|||
Ok(CurrentUser(current_user))
|
||||
}
|
||||
}
|
||||
|
||||
pub enum CurrentUserRejection {
|
||||
AppError(AppError),
|
||||
SetCookiesAndRedirect(CookieJar, String),
|
||||
}
|
||||
|
||||
// Easily convert semi-arbitrary errors to InternalServerError
|
||||
impl<E> From<E> for CurrentUserRejection
|
||||
where
|
||||
E: Into<AppError>,
|
||||
{
|
||||
fn from(err: E) -> Self {
|
||||
Self::AppError(err.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for CurrentUserRejection {
|
||||
fn into_response(self) -> Response {
|
||||
match self {
|
||||
Self::AppError(err) => err.into_response(),
|
||||
Self::SetCookiesAndRedirect(jar, redirect_to) => {
|
||||
(jar, Redirect::to(&redirect_to)).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue