shoutdotdev/src/auth.rs

214 lines
7.5 KiB
Rust

use anyhow::{Context, Result};
use async_session::{Session, SessionStore};
use axum::{
extract::{Query, State},
response::{IntoResponse, Redirect},
routing::get,
Router,
};
use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite};
use oauth2::{
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
ClientSecret, CsrfToken, RedirectUrl, RefreshToken, TokenResponse, TokenUrl,
};
use serde::{Deserialize, Serialize};
use crate::{
app_error::AppError,
app_state::{AppState, ReqwestClient},
sessions::{AppSession, PgStore},
settings::Settings,
};
/// Session key holding the OAuth2 CSRF token generated by `/login` and checked
/// against the `state` query parameter in `/callback`.
const SESSION_KEY_AUTH_CSRF_TOKEN: &str = "oauth_csrf_token";
/// Session key holding the OAuth2 refresh token obtained in `/callback`; it is
/// read again by `/logout` to notify the provider.
const SESSION_KEY_AUTH_REFRESH_TOKEN: &str = "oauth_refresh_token";
/// Session key holding the authenticated visitor's [`AuthInfo`]; its presence
/// is what `/login` uses to decide the visitor is already logged in.
pub const SESSION_KEY_AUTH_INFO: &str = "auth";
/// Session key holding the URL to redirect to after a successful `/callback`.
/// It is consumed (removed) by `/callback`; NOTE(review): it is presumably set
/// elsewhere in the crate — confirm where.
pub const SESSION_KEY_AUTH_REDIRECT: &str = "post_auth_redirect";
/// Builds the OAuth2 client that lives in global application state.
///
/// The client id/secret and the authorization, token and redirect URLs are
/// all taken from `settings.auth`.
///
/// # Errors
///
/// Fails (with context describing which one) if the authorization, token, or
/// redirect URL is not a valid URL.
pub fn new_oauth_client(settings: &Settings) -> Result<BasicClient> {
    let auth = &settings.auth;

    // Validate the URLs in the same order the original constructor did, so the
    // first invalid one is the one reported.
    let authorization_endpoint = AuthUrl::new(auth.auth_url.clone())
        .context("failed to create new authorization server URL")?;
    let token_endpoint = TokenUrl::new(auth.token_url.clone())
        .context("failed to create new token endpoint URL")?;
    let redirect_endpoint = RedirectUrl::new(auth.redirect_url.clone())
        .context("failed to create new redirection URL")?;

    let client = BasicClient::new(
        ClientId::new(auth.client_id.clone()),
        Some(ClientSecret::new(auth.client_secret.clone())),
        authorization_endpoint,
        Some(token_endpoint),
    )
    .set_redirect_uri(redirect_endpoint);

    Ok(client)
}
/// Builds the authentication router, meant to be nested inside the app router.
///
/// Exposes three GET routes: `/login` (start the OAuth2 flow), `/callback`
/// (OAuth2 redirection endpoint) and `/logout`.
pub fn new_router() -> Router<AppState> {
    let router = Router::new();
    let router = router.route("/login", get(start_login));
    let router = router.route("/callback", get(callback));
    router.route("/logout", get(logout))
}
/// HTTP GET handler for `/login`; starts the OAuth2 authorization-code flow.
///
/// - If the visitor's session already holds [`AuthInfo`], redirects to
///   `{base_path}/`.
/// - Otherwise stores a fresh random CSRF token on the session (creating a
///   new session if none was loaded), persists the session, and redirects to
///   the provider's authorization URL. A session cookie is added to the jar
///   only when the store returns a cookie value (i.e. for a new session).
///
/// # Errors
///
/// Returns an [`AppError`] if the session cannot be serialized or persisted.
async fn start_login(
    State(state): State<AppState>,
    State(Settings {
        auth: auth_settings,
        base_path,
        ..
    }): State<Settings>,
    State(session_store): State<PgStore>,
    AppSession(maybe_session): AppSession,
    jar: CookieJar,
) -> Result<impl IntoResponse, AppError> {
    let mut session = maybe_session.unwrap_or_else(Session::new);

    if session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO).is_some() {
        tracing::debug!("already logged in, redirecting...");
        return Ok(Redirect::to(&format!("{}/", base_path)).into_response());
    }

    // A refresh token without usable auth info means the session is in an
    // inconsistent state (for example, the stored AuthInfo no longer
    // deserializes). Previously this was an `assert!`, which panicked the
    // request handler; discard the stale token and start a clean login instead.
    if session.get_raw(SESSION_KEY_AUTH_REFRESH_TOKEN).is_some() {
        tracing::warn!("discarding stale oauth refresh token from unauthenticated session");
        session.remove(SESSION_KEY_AUTH_REFRESH_TOKEN);
    }

    let csrf_token = CsrfToken::new_random();
    session.insert(SESSION_KEY_AUTH_CSRF_TOKEN, &csrf_token)?;
    let (auth_url, _csrf_token) = state.oauth_client.authorize_url(|| csrf_token).url();

    let jar = if let Some(cookie_value) = session_store.store_session(session).await? {
        tracing::debug!("adding session cookie to jar");
        jar.add(
            Cookie::build((auth_settings.cookie_name.clone(), cookie_value))
                .same_site(SameSite::Lax)
                .http_only(true)
                .path("/"),
        )
    } else {
        // The store only returns a cookie value for sessions it has not seen
        // before, so an existing session's cookie is already on the client.
        tracing::debug!("inferred that session cookie already in jar");
        jar
    };

    Ok((jar, Redirect::to(auth_url.as_ref())).into_response())
}
/// HTTP GET handler for `/logout`.
///
/// If a session was loaded:
/// - when `auth.logout_url` is configured and a refresh token is stored, the
///   OAuth provider is sent a JSON `{"refresh_token": ...}` POST;
/// - the session is then destroyed in the store.
///
/// The session cookie is always removed and the visitor is redirected to
/// `{base_path}/`.
///
/// Notifying the provider is best-effort. A failed request is logged and local
/// logout still proceeds, so a provider outage cannot leave the visitor stuck
/// logged in.
///
/// # Errors
///
/// Returns an [`AppError`] if destroying the session in the store fails.
async fn logout(
    State(Settings {
        base_path,
        auth: auth_settings,
        ..
    }): State<Settings>,
    State(ReqwestClient(reqwest_client)): State<ReqwestClient>,
    State(session_store): State<PgStore>,
    AppSession(session): AppSession,
    jar: CookieJar,
) -> Result<impl IntoResponse, AppError> {
    if let Some(session) = session {
        tracing::debug!("Session {} loaded.", session.id());

        let refresh_token: Option<RefreshToken> = session.get(SESSION_KEY_AUTH_REFRESH_TOKEN);
        if let (Some(logout_url), Some(refresh_token)) = (auth_settings.logout_url, refresh_token)
        {
            #[derive(Serialize)]
            struct LogoutRequestBody<'a> {
                refresh_token: &'a str,
            }

            tracing::debug!("Sending logout request to OAuth provider.");
            let outcome = reqwest_client
                .post(logout_url)
                .json(&LogoutRequestBody {
                    refresh_token: refresh_token.secret().as_str(),
                })
                .send()
                .await
                .and_then(|response| response.error_for_status());
            match outcome {
                Ok(_) => tracing::debug!("Sent logout request to OAuth provider successfully."),
                Err(err) => tracing::warn!(
                    error = %err,
                    "OAuth provider logout request failed; continuing with local logout"
                ),
            }
        }

        session_store.destroy_session(session).await?;
    }

    // Match the path the cookie was set with in `start_login` so the browser
    // actually drops it.
    let jar = jar.remove(Cookie::build(auth_settings.cookie_name).path("/"));
    tracing::debug!("Removed session cookie from jar.");
    Ok((jar, Redirect::to(&format!("{}/", base_path))))
}
/// Query parameters received by the `/callback` handler.
#[derive(Debug, Deserialize)]
struct AuthRequestQuery {
    /// Authorization code, exchanged for tokens via the OAuth2 client.
    code: String,
    /// CSRF token; compared against the token stored on the session by `/login`.
    state: String,
}
/// HTTP GET handler for `/callback`, the OAuth2 redirection endpoint.
///
/// Steps:
/// 1. Check the `state` query parameter against the CSRF token stored by
///    `/login`, then consume that token.
/// 2. Exchange the authorization `code` for tokens.
/// 3. Fetch the visitor's [`AuthInfo`] from the configured userinfo endpoint.
/// 4. Store the auth info and the refresh token (when the provider returned
///    one) on the existing session.
/// 5. Redirect to the post-auth target stored on the session, falling back
///    to `{base_path}/`.
///
/// # Errors
///
/// - [`AppError::Forbidden`] if the session or its CSRF token is missing, or
///   the CSRF token does not match `state`.
/// - Any error from:
///   - the token exchange;
///   - the userinfo request, including a non-success HTTP status;
///   - session serialization or persistence.
async fn callback(
    Query(query): Query<AuthRequestQuery>,
    State(state): State<AppState>,
    State(Settings {
        auth: auth_settings,
        base_path,
        ..
    }): State<Settings>,
    State(ReqwestClient(reqwest_client)): State<ReqwestClient>,
    AppSession(session): AppSession,
) -> Result<impl IntoResponse, AppError> {
    let mut session = session.ok_or_else(|| {
        tracing::debug!("unable to load session");
        AppError::Forbidden(
            "our apologies: authentication session expired or lost, please try again".to_owned(),
        )
    })?;

    let session_csrf_token: String = session.get(SESSION_KEY_AUTH_CSRF_TOKEN).ok_or_else(|| {
        tracing::debug!("oauth csrf token not found on session");
        AppError::Forbidden(
            "our apologies: authentication session expired or lost, please try again".to_owned(),
        )
    })?;
    if session_csrf_token != query.state {
        tracing::debug!("oauth csrf tokens did not match");
        return Err(AppError::Forbidden(
            "OAuth CSRF tokens do not match.".to_string(),
        ));
    }
    // The CSRF token is single-use; once the session is persisted below, a
    // replayed callback carrying the same `state` will be rejected.
    session.remove(SESSION_KEY_AUTH_CSRF_TOKEN);

    tracing::debug!("exchanging authorization code");
    let response = state
        .oauth_client
        .exchange_code(AuthorizationCode::new(query.code))
        .request_async(async_http_client)
        .await?;

    tracing::debug!("fetching user info");
    let auth_info: AuthInfo = reqwest_client
        .get(auth_settings.userinfo_url.as_str())
        .bearer_auth(response.access_token().secret())
        .send()
        .await?
        // Surface HTTP errors directly instead of as a confusing JSON decode error.
        .error_for_status()?
        .json()
        .await?;

    tracing::debug!("updating session");
    let redirect_target: Option<String> = session.get(SESSION_KEY_AUTH_REDIRECT);
    // Remove this since we don't need or want it sticking around, for both UX
    // and security hygiene reasons
    session.remove(SESSION_KEY_AUTH_REDIRECT);
    session.insert(SESSION_KEY_AUTH_INFO, &auth_info)?;
    // Only store a refresh token when the provider issued one, rather than a
    // JSON `null` placeholder.
    match response.refresh_token() {
        Some(refresh_token) => session.insert(SESSION_KEY_AUTH_REFRESH_TOKEN, refresh_token)?,
        None => session.remove(SESSION_KEY_AUTH_REFRESH_TOKEN),
    }

    // NOTE(review): the session id is kept across login. Consider regenerating
    // it (session fixation mitigation); that would require re-issuing the
    // session cookie from this handler.
    if state.session_store.store_session(session).await?.is_some() {
        return Err(anyhow::anyhow!(
            "expected cookie value returned by store_session() to be None for existing session"
        )
        .into());
    }

    tracing::debug!("successfully authenticated");
    // NOTE(review): the redirect target is used verbatim from the session.
    // Confirm that whatever sets SESSION_KEY_AUTH_REDIRECT restricts it to
    // same-site paths (open-redirect risk).
    Ok(Redirect::to(
        &redirect_target.unwrap_or_else(|| format!("{}/", base_path)),
    ))
}
/// Data stored in the visitor's session upon successful authentication.
///
/// Deserialized from the JSON body of the configured userinfo endpoint in
/// `/callback`. It is then serialized into the session under
/// [`SESSION_KEY_AUTH_INFO`].
#[derive(Debug, Deserialize, Serialize)]
pub struct AuthInfo {
    /// The `sub` field of the userinfo response (presumably the provider's
    /// stable subject identifier — confirm against the provider).
    pub sub: String,
    /// The `email` field of the userinfo response.
    pub email: String,
}