shoutdotdev/src/auth.rs

214 lines
7.3 KiB
Rust
Raw Normal View History

2025-02-26 13:10:50 -08:00
use anyhow::{Context, Result};
use async_session::{Session, SessionStore as _};
use axum::{
extract::{FromRequestParts, Query, State},
http::request::Parts,
2025-02-26 13:10:45 -08:00
response::{IntoResponse, Redirect},
2025-02-26 13:10:50 -08:00
routing::get,
RequestPartsExt, Router,
};
2025-02-26 13:10:45 -08:00
use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite};
2025-02-26 13:10:50 -08:00
use oauth2::{
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
2025-02-26 13:10:45 -08:00
ClientSecret, CsrfToken, RedirectUrl, RefreshToken, TokenResponse, TokenUrl,
2025-02-26 13:10:50 -08:00
};
use serde::{Deserialize, Serialize};
2025-02-26 13:10:45 -08:00
use tracing::trace_span;
use crate::{
app_error::AppError,
app_state::{AppState, ReqwestClient},
sessions::{AppSession, PgStore},
settings::Settings,
};
2025-02-26 13:10:50 -08:00
2025-02-26 13:10:45 -08:00
/// Session key under which `propel_auth` stores the OAuth `state` (CSRF) token.
const SESSION_KEY_AUTH_CSRF_TOKEN: &str = "oauth_csrf_token";
/// Session key under which the provider's refresh token (if any) is stored after login.
const SESSION_KEY_AUTH_REFRESH_TOKEN: &str = "oauth_refresh_token";
/// Session key under which the logged-in user's `AuthInfo` is stored.
const SESSION_KEY_AUTH_INFO: &str = "auth";
2025-02-26 13:10:50 -08:00
/// Builds the OAuth2 client from the `auth` section of the settings.
///
/// The client uses the configured client id/secret, authorization endpoint,
/// token endpoint and redirect URL.
///
/// # Errors
///
/// Fails if the authorization, token or redirect URL cannot be parsed; they are
/// checked in that order and the first failure is returned with context.
pub fn new_oauth_client(settings: &Settings) -> Result<BasicClient, AppError> {
    let auth = &settings.auth;

    let authorization_endpoint = AuthUrl::new(auth.auth_url.clone())
        .context("failed to create new authorization server URL")?;
    let token_endpoint = TokenUrl::new(auth.token_url.clone())
        .context("failed to create new token endpoint URL")?;
    let redirect_target = RedirectUrl::new(auth.redirect_url.clone())
        .context("failed to create new redirection URL")?;

    let client = BasicClient::new(
        ClientId::new(auth.client_id.clone()),
        Some(ClientSecret::new(auth.client_secret.clone())),
        authorization_endpoint,
        Some(token_endpoint),
    )
    .set_redirect_uri(redirect_target);

    Ok(client)
}
/// Router for the OAuth flow: `GET /login`, `GET /callback` and `GET /logout`.
pub fn new_router() -> Router<AppState> {
    let router = Router::<AppState>::new().route("/login", get(propel_auth));
    let router = router.route("/callback", get(login_authorized));
    router.route("/logout", get(logout))
}
2025-02-26 13:10:45 -08:00
pub async fn propel_auth(
State(state): State<AppState>,
State(Settings {
auth: auth_settings,
base_path,
..
}): State<Settings>,
State(session_store): State<PgStore>,
AppSession(maybe_session): AppSession,
jar: CookieJar,
) -> Result<impl IntoResponse, AppError> {
if let Some(session) = maybe_session {
if session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO).is_some() {
tracing::debug!("already logged in, redirecting...");
return Ok(Redirect::to(&base_path).into_response());
}
}
2025-02-26 13:10:50 -08:00
let csrf_token = CsrfToken::new_random();
let (auth_url, _csrf_token) = state
.oauth_client
2025-02-26 13:10:45 -08:00
.authorize_url(|| csrf_token.clone())
2025-02-26 13:10:50 -08:00
.url();
2025-02-26 13:10:45 -08:00
let mut session = Session::new();
session.insert(SESSION_KEY_AUTH_CSRF_TOKEN, &csrf_token)?;
let cookie_value = session_store
.store_session(session)
.await?
.ok_or(anyhow::anyhow!("cookie value from store_session() is None"))?;
let jar = jar.add(
Cookie::build((auth_settings.cookie_name.clone(), cookie_value))
.same_site(SameSite::Lax)
.http_only(true)
.path("/"),
);
Ok((jar, Redirect::to(&auth_url.to_string())).into_response())
2025-02-26 13:10:50 -08:00
}
pub async fn logout(
2025-02-26 13:10:45 -08:00
State(Settings {
base_path,
auth: auth_settings,
..
}): State<Settings>,
State(ReqwestClient(reqwest_client)): State<ReqwestClient>,
State(session_store): State<PgStore>,
AppSession(session): AppSession,
jar: CookieJar,
2025-02-26 13:10:50 -08:00
) -> Result<impl IntoResponse, AppError> {
2025-02-26 13:10:45 -08:00
if let Some(session) = session {
tracing::debug!("Session {} loaded.", session.id());
if let Some(logout_url) = auth_settings.logout_url {
tracing::debug!("attempting to send logout request to oauth provider");
let refresh_token: Option<RefreshToken> = session.get(SESSION_KEY_AUTH_REFRESH_TOKEN);
if let Some(refresh_token) = refresh_token {
tracing::debug!("Sending logout request to OAuth provider.");
#[derive(Serialize)]
struct LogoutRequestBody {
refresh_token: String,
}
reqwest_client
.post(logout_url)
.json(&LogoutRequestBody {
refresh_token: refresh_token.secret().to_owned(),
})
.send()
.await?
.error_for_status()?;
tracing::debug!("Sent logout request to OAuth provider successfully.");
}
}
session_store.destroy_session(session).await?;
}
let jar = jar.remove(Cookie::from(auth_settings.cookie_name));
tracing::debug!("Removed session cookie from jar.");
Ok((jar, Redirect::to(&base_path)))
2025-02-26 13:10:50 -08:00
}
/// Query string the OAuth provider appends when redirecting back to `/callback`.
#[derive(Deserialize, Debug)]
pub struct AuthRequestQuery {
    /// Authorization code to exchange for tokens.
    code: String,
    /// Echo of the `state` parameter sent with the authorization request (the CSRF token).
    state: String,
}
/// Logged-in user's identity. It is deserialized from the provider's userinfo
/// response and cached in the session under the `"auth"` key.
#[derive(Serialize, Deserialize, Debug)]
pub struct AuthInfo {
    /// The userinfo response's `sub` field (presumably the provider's user id — confirm).
    pub sub: String,
    /// The userinfo response's `email` field.
    pub email: String,
}
pub async fn login_authorized(
Query(query): Query<AuthRequestQuery>,
State(state): State<AppState>,
2025-02-26 13:10:45 -08:00
State(Settings {
auth: auth_settings,
base_path,
..
}): State<Settings>,
State(ReqwestClient(reqwest_client)): State<ReqwestClient>,
AppSession(session): AppSession,
2025-02-26 13:10:50 -08:00
) -> Result<impl IntoResponse, AppError> {
2025-02-26 13:10:45 -08:00
let mut session = if let Some(session) = session {
session
} else {
return Err(AppError::auth_redirect_from_base_path(
state.settings.base_path,
));
};
let session_csrf_token: String = session.get(SESSION_KEY_AUTH_CSRF_TOKEN).ok_or_else(|| {
tracing::debug!("oauth csrf token not found on session");
AppError::auth_redirect_from_base_path(base_path.clone())
})?;
if session_csrf_token != query.state {
tracing::debug!("oauth csrf tokens did not match");
return Err(AppError::ForbiddenError(
"OAuth CSRF tokens do not match.".to_string(),
));
}
2025-02-26 13:10:50 -08:00
let response = state
.oauth_client
.exchange_code(AuthorizationCode::new(query.code.clone()))
.request_async(async_http_client)
.await?;
2025-02-26 13:10:45 -08:00
let auth_info: AuthInfo = reqwest_client
.get(auth_settings.userinfo_url.as_str())
.bearer_auth(response.access_token().secret())
.send()
2025-02-26 13:10:50 -08:00
.await?
2025-02-26 13:10:45 -08:00
.json()
.await?;
session.insert(SESSION_KEY_AUTH_INFO, &auth_info)?;
session.insert(SESSION_KEY_AUTH_REFRESH_TOKEN, response.refresh_token())?;
if state.session_store.store_session(session).await?.is_some() {
return Err(anyhow::anyhow!(
"expected cookie value returned by store_session() to be None for existing session"
)
.into());
2025-02-26 13:10:50 -08:00
}
2025-02-26 13:10:45 -08:00
Ok(Redirect::to(&base_path))
2025-02-26 13:10:50 -08:00
}
impl FromRequestParts<AppState> for AuthInfo {
2025-02-26 13:10:45 -08:00
type Rejection = AppError;
2025-02-26 13:10:50 -08:00
async fn from_request_parts(
parts: &mut Parts,
state: &AppState,
) -> Result<Self, <Self as FromRequestParts<AppState>>::Rejection> {
let _ = trace_span!("AuthInfo from_request_parts()").enter();
2025-02-26 13:10:45 -08:00
let session = parts
.extract_with_state::<AppSession, AppState>(state)
.await?
.0
.ok_or(AppError::auth_redirect_from_base_path(
state.settings.base_path.clone(),
))?;
let user = session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO).ok_or(
AppError::auth_redirect_from_base_path(state.settings.base_path.clone()),
)?;
2025-02-26 13:10:50 -08:00
Ok(user)
}
}