use anyhow::{Context, Result}; use async_session::{Session, SessionStore}; use axum::{ extract::{Query, State}, response::{IntoResponse, Redirect}, routing::get, Router, }; use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite}; use oauth2::{ basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RefreshToken, TokenResponse, TokenUrl, }; use serde::{Deserialize, Serialize}; use crate::{ app_error::AppError, app_state::{AppState, ReqwestClient}, sessions::{AppSession, PgStore}, settings::Settings, }; const SESSION_KEY_AUTH_CSRF_TOKEN: &str = "oauth_csrf_token"; const SESSION_KEY_AUTH_REFRESH_TOKEN: &str = "oauth_refresh_token"; pub const SESSION_KEY_AUTH_INFO: &str = "auth"; pub const SESSION_KEY_AUTH_REDIRECT: &str = "post_auth_redirect"; /// Creates a new OAuth2 client to be stored in global application state. pub fn new_oauth_client(settings: &Settings) -> Result { Ok(BasicClient::new( ClientId::new(settings.auth.client_id.clone()), Some(ClientSecret::new(settings.auth.client_secret.clone())), AuthUrl::new(settings.auth.auth_url.clone()) .context("failed to create new authorization server URL")?, Some( TokenUrl::new(settings.auth.token_url.clone()) .context("failed to create new token endpoint URL")?, ), ) .set_redirect_uri( RedirectUrl::new(settings.auth.redirect_url.clone()) .context("failed to create new redirection URL")?, )) } /// Creates a router which can be nested within the higher level app router. pub fn new_router() -> Router { Router::new() .route("/login", get(start_login)) .route("/callback", get(callback)) .route("/logout", get(logout)) } /// HTTP get handler for /login async fn start_login( State(state): State, State(Settings { auth: auth_settings, base_path, .. }): State, State(session_store): State, AppSession(maybe_session): AppSession, jar: CookieJar, ) -> Result { let mut session = if let Some(value) = maybe_session { value } else { Session::new() }; if session.get::(SESSION_KEY_AUTH_INFO).is_some() { tracing::debug!("already logged in, redirecting..."); return Ok(Redirect::to(&format!("{}/", base_path)).into_response()); } assert!(session.get_raw(SESSION_KEY_AUTH_REFRESH_TOKEN).is_none()); let csrf_token = CsrfToken::new_random(); session.insert(SESSION_KEY_AUTH_CSRF_TOKEN, &csrf_token)?; let (auth_url, _csrf_token) = state.oauth_client.authorize_url(|| csrf_token).url(); let jar = if let Some(cookie_value) = session_store.store_session(session).await? { tracing::debug!("adding session cookie to jar"); jar.add( Cookie::build((auth_settings.cookie_name.clone(), cookie_value)) .same_site(SameSite::Lax) .http_only(true) .path("/"), ) } else { tracing::debug!("inferred that session cookie already in jar"); jar }; Ok((jar, Redirect::to(auth_url.as_ref())).into_response()) } /// HTTP get handler for /logout async fn logout( State(Settings { base_path, auth: auth_settings, .. }): State, State(ReqwestClient(reqwest_client)): State, State(session_store): State, AppSession(session): AppSession, jar: CookieJar, ) -> Result { if let Some(session) = session { tracing::debug!("Session {} loaded.", session.id()); if let Some(logout_url) = auth_settings.logout_url { tracing::debug!("attempting to send logout request to oauth provider"); let refresh_token: Option = session.get(SESSION_KEY_AUTH_REFRESH_TOKEN); if let Some(refresh_token) = refresh_token { tracing::debug!("Sending logout request to OAuth provider."); #[derive(Serialize)] struct LogoutRequestBody { refresh_token: String, } reqwest_client .post(logout_url) .json(&LogoutRequestBody { refresh_token: refresh_token.secret().to_owned(), }) .send() .await? .error_for_status()?; tracing::debug!("Sent logout request to OAuth provider successfully."); } } session_store.destroy_session(session).await?; } let jar = jar.remove(Cookie::from(auth_settings.cookie_name)); tracing::debug!("Removed session cookie from jar."); Ok((jar, Redirect::to(&format!("{}/", base_path)))) } #[derive(Debug, Deserialize)] struct AuthRequestQuery { code: String, /// CSRF token state: String, } /// HTTP get handler for /callback async fn callback( Query(query): Query, State(state): State, State(Settings { auth: auth_settings, base_path, .. }): State, State(ReqwestClient(reqwest_client)): State, AppSession(session): AppSession, ) -> Result { let mut session = session.ok_or_else(|| { tracing::debug!("unable to load session"); AppError::Forbidden( "our apologies: authentication session expired or lost, please try again".to_owned(), ) })?; let session_csrf_token: String = session.get(SESSION_KEY_AUTH_CSRF_TOKEN).ok_or_else(|| { tracing::debug!("oauth csrf token not found on session"); AppError::Forbidden( "our apologies: authentication session expired or lost, please try again".to_owned(), ) })?; if session_csrf_token != query.state { tracing::debug!("oauth csrf tokens did not match"); return Err(AppError::ForbiddenError( "OAuth CSRF tokens do not match.".to_string(), )); } tracing::debug!("exchanging authorization code"); let response = state .oauth_client .exchange_code(AuthorizationCode::new(query.code.clone())) .request_async(async_http_client) .await?; tracing::debug!("fetching user info"); let auth_info: AuthInfo = reqwest_client .get(auth_settings.userinfo_url.as_str()) .bearer_auth(response.access_token().secret()) .send() .await? .json() .await?; tracing::debug!("updating session"); let redirect_target: Option = session.get(SESSION_KEY_AUTH_REDIRECT); // Remove this since we don't need or want it sticking around, for both UX // and security hygiene reasons session.remove(SESSION_KEY_AUTH_REDIRECT); session.insert(SESSION_KEY_AUTH_INFO, &auth_info)?; session.insert(SESSION_KEY_AUTH_REFRESH_TOKEN, response.refresh_token())?; if state.session_store.store_session(session).await?.is_some() { return Err(anyhow::anyhow!( "expected cookie value returned by store_session() to be None for existing session" ) .into()); } tracing::debug!("successfully authenticated"); Ok(Redirect::to( &redirect_target.unwrap_or(format!("{}/", base_path)), )) } /// Data stored in the visitor's session upon successful authentication. #[derive(Debug, Deserialize, Serialize)] pub struct AuthInfo { pub sub: String, pub email: String, }