use anyhow::{Context, Result}; use async_session::{Session, SessionStore as _}; use axum::{ extract::{FromRequestParts, Query, State}, http::request::Parts, response::{IntoResponse, Redirect}, routing::get, RequestPartsExt, Router, }; use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite}; use oauth2::{ basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RefreshToken, TokenResponse, TokenUrl, }; use serde::{Deserialize, Serialize}; use tracing::trace_span; use crate::{ app_error::AppError, app_state::{AppState, ReqwestClient}, sessions::{AppSession, PgStore}, settings::Settings, }; const SESSION_KEY_AUTH_CSRF_TOKEN: &'static str = "oauth_csrf_token"; const SESSION_KEY_AUTH_REFRESH_TOKEN: &'static str = "oauth_refresh_token"; const SESSION_KEY_AUTH_INFO: &'static str = "auth"; pub fn new_oauth_client(settings: &Settings) -> Result { Ok(BasicClient::new( ClientId::new(settings.auth.client_id.clone()), Some(ClientSecret::new(settings.auth.client_secret.clone())), AuthUrl::new(settings.auth.auth_url.clone()) .context("failed to create new authorization server URL")?, Some( TokenUrl::new(settings.auth.token_url.clone()) .context("failed to create new token endpoint URL")?, ), ) .set_redirect_uri( RedirectUrl::new(settings.auth.redirect_url.clone()) .context("failed to create new redirection URL")?, )) } pub fn new_router() -> Router { Router::new() .route("/login", get(propel_auth)) .route("/callback", get(login_authorized)) .route("/logout", get(logout)) } pub async fn propel_auth( State(state): State, State(Settings { auth: auth_settings, base_path, .. }): State, State(session_store): State, AppSession(maybe_session): AppSession, jar: CookieJar, ) -> Result { if let Some(session) = maybe_session { if session.get::(SESSION_KEY_AUTH_INFO).is_some() { tracing::debug!("already logged in, redirecting..."); return Ok(Redirect::to(&base_path).into_response()); } } let csrf_token = CsrfToken::new_random(); let (auth_url, _csrf_token) = state .oauth_client .authorize_url(|| csrf_token.clone()) .url(); let mut session = Session::new(); session.insert(SESSION_KEY_AUTH_CSRF_TOKEN, &csrf_token)?; let cookie_value = session_store .store_session(session) .await? .ok_or(anyhow::anyhow!("cookie value from store_session() is None"))?; let jar = jar.add( Cookie::build((auth_settings.cookie_name.clone(), cookie_value)) .same_site(SameSite::Lax) .http_only(true) .path("/"), ); Ok((jar, Redirect::to(&auth_url.to_string())).into_response()) } pub async fn logout( State(Settings { base_path, auth: auth_settings, .. }): State, State(ReqwestClient(reqwest_client)): State, State(session_store): State, AppSession(session): AppSession, jar: CookieJar, ) -> Result { if let Some(session) = session { tracing::debug!("Session {} loaded.", session.id()); if let Some(logout_url) = auth_settings.logout_url { tracing::debug!("attempting to send logout request to oauth provider"); let refresh_token: Option = session.get(SESSION_KEY_AUTH_REFRESH_TOKEN); if let Some(refresh_token) = refresh_token { tracing::debug!("Sending logout request to OAuth provider."); #[derive(Serialize)] struct LogoutRequestBody { refresh_token: String, } reqwest_client .post(logout_url) .json(&LogoutRequestBody { refresh_token: refresh_token.secret().to_owned(), }) .send() .await? .error_for_status()?; tracing::debug!("Sent logout request to OAuth provider successfully."); } } session_store.destroy_session(session).await?; } let jar = jar.remove(Cookie::from(auth_settings.cookie_name)); tracing::debug!("Removed session cookie from jar."); Ok((jar, Redirect::to(&base_path))) } #[derive(Debug, Deserialize)] pub struct AuthRequestQuery { code: String, state: String, // CSRF token } #[derive(Debug, Deserialize, Serialize)] pub struct AuthInfo { pub sub: String, pub email: String, } pub async fn login_authorized( Query(query): Query, State(state): State, State(Settings { auth: auth_settings, base_path, .. }): State, State(ReqwestClient(reqwest_client)): State, AppSession(session): AppSession, ) -> Result { let mut session = if let Some(session) = session { session } else { return Err(AppError::auth_redirect_from_base_path( state.settings.base_path, )); }; let session_csrf_token: String = session.get(SESSION_KEY_AUTH_CSRF_TOKEN).ok_or_else(|| { tracing::debug!("oauth csrf token not found on session"); AppError::auth_redirect_from_base_path(base_path.clone()) })?; if session_csrf_token != query.state { tracing::debug!("oauth csrf tokens did not match"); return Err(AppError::ForbiddenError( "OAuth CSRF tokens do not match.".to_string(), )); } let response = state .oauth_client .exchange_code(AuthorizationCode::new(query.code.clone())) .request_async(async_http_client) .await?; let auth_info: AuthInfo = reqwest_client .get(auth_settings.userinfo_url.as_str()) .bearer_auth(response.access_token().secret()) .send() .await? .json() .await?; session.insert(SESSION_KEY_AUTH_INFO, &auth_info)?; session.insert(SESSION_KEY_AUTH_REFRESH_TOKEN, response.refresh_token())?; if state.session_store.store_session(session).await?.is_some() { return Err(anyhow::anyhow!( "expected cookie value returned by store_session() to be None for existing session" ) .into()); } Ok(Redirect::to(&base_path)) } impl FromRequestParts for AuthInfo { type Rejection = AppError; async fn from_request_parts( parts: &mut Parts, state: &AppState, ) -> Result>::Rejection> { let _ = trace_span!("AuthInfo from_request_parts()").enter(); let session = parts .extract_with_state::(state) .await? .0 .ok_or(AppError::auth_redirect_from_base_path( state.settings.base_path.clone(), ))?; let user = session.get::(SESSION_KEY_AUTH_INFO).ok_or( AppError::auth_redirect_from_base_path(state.settings.base_path.clone()), )?; Ok(user) } }