From 401fcde4ce2422871590fc7bb6e395101bb95421 Mon Sep 17 00:00:00 2001 From: Brent Schroeter Date: Wed, 26 Feb 2025 13:10:45 -0800 Subject: [PATCH] tighten oauth login/logout flows --- migrations/2025-01-08-211839_sessions/up.sql | 4 +- src/app_error.rs | 19 +- src/app_state.rs | 8 +- src/auth.rs | 227 +++++++++++-------- src/main.rs | 17 +- src/schema.rs | 2 +- src/sessions.rs | 107 +++++++-- src/settings.rs | 3 +- templates/nav.html | 2 +- 9 files changed, 260 insertions(+), 129 deletions(-) diff --git a/migrations/2025-01-08-211839_sessions/up.sql b/migrations/2025-01-08-211839_sessions/up.sql index f3a2db6..8cba36c 100644 --- a/migrations/2025-01-08-211839_sessions/up.sql +++ b/migrations/2025-01-08-211839_sessions/up.sql @@ -2,7 +2,7 @@ CREATE TABLE browser_sessions ( id TEXT NOT NULL PRIMARY KEY, serialized TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - last_seen_at TIMESTAMPTZ NOT NULL + expiry TIMESTAMPTZ ); -CREATE INDEX ON browser_sessions (last_seen_at); +CREATE INDEX ON browser_sessions (expiry); CREATE INDEX ON browser_sessions (created_at); diff --git a/src/app_error.rs b/src/app_error.rs index ca02d73..7205ca0 100644 --- a/src/app_error.rs +++ b/src/app_error.rs @@ -1,7 +1,12 @@ use std::fmt::{self, Display}; use axum::http::StatusCode; -use axum::response::{IntoResponse, Response}; +use axum::response::{IntoResponse, Redirect, Response}; + +#[derive(Debug)] +pub struct AuthRedirectInfo { + base_path: String, +} // Use anyhow, define error and enable '?' // For a simplified example of using anyhow in axum check /examples/anyhow-error-response @@ -11,12 +16,23 @@ pub enum AppError { ForbiddenError(String), NotFoundError(String), BadRequestError(String), + AuthRedirect(AuthRedirectInfo), +} + +impl AppError { + pub fn auth_redirect_from_base_path(base_path: String) -> Self { + Self::AuthRedirect(AuthRedirectInfo { base_path }) + } } // Tell axum how to convert `AppError` into a response. impl IntoResponse for AppError { fn into_response(self) -> Response { match self { + Self::AuthRedirect(AuthRedirectInfo { base_path }) => { + tracing::debug!("Handling AuthRedirect"); + Redirect::to(&format!("{}/auth/login", base_path)).into_response() + } Self::InternalServerError(err) => { tracing::error!("Application error: {:?}", err); (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response() @@ -51,6 +67,7 @@ where impl Display for AppError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { + AppError::AuthRedirect(info) => write!(f, "AuthRedirect: {:?}", info), AppError::InternalServerError(inner) => inner.fmt(f), AppError::ForbiddenError(client_message) => { write!(f, "ForbiddenError: {}", client_message) diff --git a/src/app_state.rs b/src/app_state.rs index 8558f98..bd95a1b 100644 --- a/src/app_state.rs +++ b/src/app_state.rs @@ -12,6 +12,7 @@ use crate::{app_error::AppError, sessions::PgStore, settings::Settings}; pub struct AppState { pub db_pool: Pool, pub mailer: Mailer, + pub reqwest_client: reqwest::Client, pub oauth_client: BasicClient, pub session_store: PgStore, pub settings: Settings, @@ -26,9 +27,12 @@ impl FromRef for Mailer { } } -impl FromRef for PgStore { +#[derive(Clone)] +pub struct ReqwestClient(pub reqwest::Client); + +impl FromRef for ReqwestClient { fn from_ref(state: &AppState) -> Self { - state.session_store.clone() + ReqwestClient(state.reqwest_client.clone()) } } diff --git a/src/auth.rs b/src/auth.rs index 3a03387..d3ba448 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -3,23 +3,28 @@ use async_session::{Session, SessionStore as _}; use axum::{ extract::{FromRequestParts, Query, State}, http::request::Parts, - response::{IntoResponse, Redirect, Response}, + response::{IntoResponse, Redirect}, routing::get, RequestPartsExt, Router, }; -use axum_extra::{ - extract::cookie::{Cookie, CookieJar, SameSite}, - headers, TypedHeader, -}; -use diesel::prelude::*; +use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite}; use oauth2::{ basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId, - ClientSecret, CsrfToken, RedirectUrl, TokenResponse, TokenUrl, + ClientSecret, CsrfToken, RedirectUrl, RefreshToken, TokenResponse, TokenUrl, }; use serde::{Deserialize, Serialize}; -use tracing::{debug, trace_span}; +use tracing::trace_span; -use crate::{app_error::AppError, app_state::AppState, schema, settings::Settings}; +use crate::{ + app_error::AppError, + app_state::{AppState, ReqwestClient}, + sessions::{AppSession, PgStore}, + settings::Settings, +}; + +const SESSION_KEY_AUTH_CSRF_TOKEN: &'static str = "oauth_csrf_token"; +const SESSION_KEY_AUTH_REFRESH_TOKEN: &'static str = "oauth_refresh_token"; +const SESSION_KEY_AUTH_INFO: &'static str = "auth"; pub fn new_oauth_client(settings: &Settings) -> Result { Ok(BasicClient::new( @@ -45,38 +50,81 @@ pub fn new_router() -> Router { .route("/logout", get(logout)) } -pub async fn propel_auth(State(state): State) -> impl IntoResponse { +pub async fn propel_auth( + State(state): State, + State(Settings { + auth: auth_settings, + base_path, + .. + }): State, + State(session_store): State, + AppSession(maybe_session): AppSession, + jar: CookieJar, +) -> Result { + if let Some(session) = maybe_session { + if session.get::(SESSION_KEY_AUTH_INFO).is_some() { + tracing::debug!("already logged in, redirecting..."); + return Ok(Redirect::to(&base_path).into_response()); + } + } let csrf_token = CsrfToken::new_random(); let (auth_url, _csrf_token) = state .oauth_client - .authorize_url(|| csrf_token) - // .add_scopes(vec![Scope::new("openid".to_string())]) + .authorize_url(|| csrf_token.clone()) .url(); - // FIXME: check CSRF token - Redirect::to(auth_url.as_ref()) + let mut session = Session::new(); + session.insert(SESSION_KEY_AUTH_CSRF_TOKEN, &csrf_token)?; + let cookie_value = session_store + .store_session(session) + .await? + .ok_or(anyhow::anyhow!("cookie value from store_session() is None"))?; + let jar = jar.add( + Cookie::build((auth_settings.cookie_name.clone(), cookie_value)) + .same_site(SameSite::Lax) + .http_only(true) + .path("/"), + ); + Ok((jar, Redirect::to(&auth_url.to_string())).into_response()) } pub async fn logout( - State(state): State, - TypedHeader(cookies): TypedHeader, + State(Settings { + base_path, + auth: auth_settings, + .. + }): State, + State(ReqwestClient(reqwest_client)): State, + State(session_store): State, + AppSession(session): AppSession, + jar: CookieJar, ) -> Result { - let cookie = cookies - .get(state.settings.auth.cookie_name.as_str()) - .context("couldn't get session cookie")?; - let session_id = Session::id_from_cookie_value(cookie)?; - state - .db_pool - .get() - .await? - .interact(move |conn| { - diesel::delete(schema::browser_sessions::table) - .filter(schema::browser_sessions::id.eq(session_id)) - .execute(conn) - }) - .await - .unwrap()?; - // FIXME: call logout endpoint of OIDC provider - Ok(Redirect::to(&state.settings.base_path)) + if let Some(session) = session { + tracing::debug!("Session {} loaded.", session.id()); + if let Some(logout_url) = auth_settings.logout_url { + tracing::debug!("attempting to send logout request to oauth provider"); + let refresh_token: Option = session.get(SESSION_KEY_AUTH_REFRESH_TOKEN); + if let Some(refresh_token) = refresh_token { + tracing::debug!("Sending logout request to OAuth provider."); + #[derive(Serialize)] + struct LogoutRequestBody { + refresh_token: String, + } + reqwest_client + .post(logout_url) + .json(&LogoutRequestBody { + refresh_token: refresh_token.secret().to_owned(), + }) + .send() + .await? + .error_for_status()?; + tracing::debug!("Sent logout request to OAuth provider successfully."); + } + } + session_store.destroy_session(session).await?; + } + let jar = jar.remove(Cookie::from(auth_settings.cookie_name)); + tracing::debug!("Removed session cookie from jar."); + Ok((jar, Redirect::to(&base_path))) } #[derive(Debug, Deserialize)] @@ -85,96 +133,81 @@ pub struct AuthRequestQuery { state: String, // CSRF token } -pub const AUTH_INFO_SESSION_KEY: &'static str = "user"; - #[derive(Debug, Deserialize, Serialize)] pub struct AuthInfo { pub sub: String, pub email: String, } -async fn get_user_info( - settings: &Settings, - access_token: &oauth2::AccessToken, -) -> Result { - let client = reqwest::Client::new(); - Ok(client - .get(settings.auth.userinfo_url.as_str()) - .bearer_auth(access_token.secret()) - .send() - .await? - .json() - .await?) -} - pub async fn login_authorized( Query(query): Query, State(state): State, - jar: CookieJar, + State(Settings { + auth: auth_settings, + base_path, + .. + }): State, + State(ReqwestClient(reqwest_client)): State, + AppSession(session): AppSession, ) -> Result { + let mut session = if let Some(session) = session { + session + } else { + return Err(AppError::auth_redirect_from_base_path( + state.settings.base_path, + )); + }; + let session_csrf_token: String = session.get(SESSION_KEY_AUTH_CSRF_TOKEN).ok_or_else(|| { + tracing::debug!("oauth csrf token not found on session"); + AppError::auth_redirect_from_base_path(base_path.clone()) + })?; + if session_csrf_token != query.state { + tracing::debug!("oauth csrf tokens did not match"); + return Err(AppError::ForbiddenError( + "OAuth CSRF tokens do not match.".to_string(), + )); + } let response = state .oauth_client .exchange_code(AuthorizationCode::new(query.code.clone())) .request_async(async_http_client) .await?; - let user_info = get_user_info(&state.settings, response.access_token()).await?; - let mut session = Session::new(); - session.insert(AUTH_INFO_SESSION_KEY, &user_info)?; - let cookie_value = state - .session_store - .store_session(session) + let auth_info: AuthInfo = reqwest_client + .get(auth_settings.userinfo_url.as_str()) + .bearer_auth(response.access_token().secret()) + .send() .await? - .context("cookie value from store_session() is None")?; - let jar = jar.add( - Cookie::build((state.settings.auth.cookie_name.clone(), cookie_value)) - .same_site(SameSite::Lax) - .http_only(true) - .path("/"), - ); - Ok((jar, Redirect::to(&format!("{}/", state.settings.base_path)))) -} - -pub struct AuthRedirect { - base_path: String, -} - -impl AuthRedirect { - pub fn new(base_path: &str) -> Self { - Self { - base_path: base_path.to_string(), - } - } -} - -impl IntoResponse for AuthRedirect { - fn into_response(self) -> Response { - Redirect::to(&format!("{}/auth/login", self.base_path)).into_response() + .json() + .await?; + session.insert(SESSION_KEY_AUTH_INFO, &auth_info)?; + session.insert(SESSION_KEY_AUTH_REFRESH_TOKEN, response.refresh_token())?; + if state.session_store.store_session(session).await?.is_some() { + return Err(anyhow::anyhow!( + "expected cookie value returned by store_session() to be None for existing session" + ) + .into()); } + Ok(Redirect::to(&base_path)) } impl FromRequestParts for AuthInfo { - type Rejection = AuthRedirect; + type Rejection = AppError; async fn from_request_parts( parts: &mut Parts, state: &AppState, ) -> Result>::Rejection> { let _ = trace_span!("AuthInfo from_request_parts()").enter(); - let jar = parts.extract::().await.unwrap(); - let session_cookie = jar - .get(&state.settings.auth.cookie_name) - .ok_or(AuthRedirect::new(&state.settings.base_path))?; - debug!("session cookie loaded"); - let session = state - .session_store - .load_session(session_cookie.value().to_string()) - .await - .unwrap() - .ok_or(AuthRedirect::new(&state.settings.base_path))?; - debug!("session loaded"); - let user = session - .get::(AUTH_INFO_SESSION_KEY) - .ok_or(AuthRedirect::new(&state.settings.base_path))?; + let session = parts + .extract_with_state::(state) + .await? + .0 + .ok_or(AppError::auth_redirect_from_base_path( + state.settings.base_path.clone(), + ))?; + let user = session.get::(SESSION_KEY_AUTH_INFO).ok_or( + AppError::auth_redirect_from_base_path(state.settings.base_path.clone()), + )?; Ok(user) } } diff --git a/src/main.rs b/src/main.rs index 0721627..4433825 100644 --- a/src/main.rs +++ b/src/main.rs @@ -41,9 +41,7 @@ async fn main() { let db_pool = deadpool_diesel::postgres::Pool::builder(manager) .build() .unwrap(); - let session_store = PgStore::new(db_pool.clone()); - let mailer_creds = lettre::transport::smtp::authentication::Credentials::new( settings.email.smtp_username.clone(), settings.email.smtp_password.clone(), @@ -54,20 +52,31 @@ async fn main() { .credentials(mailer_creds) .build(), ); - + let reqwest_client = reqwest::ClientBuilder::new() + .https_only(true) + .build() + .unwrap(); let oauth_client = auth::new_oauth_client(&settings).unwrap(); + let app_state = AppState { db_pool, mailer, oauth_client, + reqwest_client, session_store, settings: settings.clone(), }; let router = new_router(app_state); - let listener = tokio::net::TcpListener::bind((settings.host, settings.port)) + let listener = tokio::net::TcpListener::bind((settings.host.clone(), settings.port.clone())) .await .unwrap(); + tracing::info!( + "App running at http://{}:{}{}", + settings.host, + settings.port, + settings.base_path + ); axum::serve(listener, router).await.unwrap(); } diff --git a/src/schema.rs b/src/schema.rs index bf093ae..f77b142 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -14,7 +14,7 @@ diesel::table! { id -> Text, serialized -> Text, created_at -> Timestamptz, - last_seen_at -> Timestamptz, + expiry -> Nullable, } } diff --git a/src/sessions.rs b/src/sessions.rs index aa19379..a02cb7d 100644 --- a/src/sessions.rs +++ b/src/sessions.rs @@ -1,17 +1,26 @@ use anyhow::Result; use async_session::{async_trait, Session, SessionStore}; +use axum::{ + extract::{FromRef, FromRequestParts}, + http::request::Parts, + RequestPartsExt as _, +}; +use axum_extra::extract::CookieJar; use chrono::{DateTime, TimeDelta, Utc}; use diesel::{pg::Pg, prelude::*, upsert::excluded}; +use tracing::trace_span; -use crate::schema::browser_sessions::dsl::*; +use crate::{app_error::AppError, app_state::AppState, schema::browser_sessions}; + +const EXPIRY_DAYS: i64 = 7; #[derive(Clone, Debug, Identifiable, Queryable, Selectable)] -#[diesel(table_name = crate::schema::browser_sessions)] +#[diesel(table_name = browser_sessions)] #[diesel(check_for_backend(Pg))] pub struct BrowserSession { pub id: String, pub serialized: String, - pub last_seen_at: DateTime, + pub expiry: Option>, } #[derive(Clone)] @@ -32,21 +41,28 @@ impl std::fmt::Debug for PgStore { } } +impl FromRef for PgStore { + fn from_ref(state: &AppState) -> Self { + state.session_store.clone() + } +} + #[async_trait] impl SessionStore for PgStore { async fn load_session(&self, cookie_value: String) -> Result> { let session_id = Session::id_from_cookie_value(&cookie_value)?; - let timestamp_stale = Utc::now() - TimeDelta::days(7); let conn = self.pool.get().await?; let row = conn .interact(move |conn| { // Drop all sessions without recent activity - diesel::delete(browser_sessions.filter(last_seen_at.lt(timestamp_stale))) - .execute(conn)?; - diesel::update(browser_sessions.filter(id.eq(session_id))) - .set(last_seen_at.eq(diesel::dsl::now)) - .returning(BrowserSession::as_returning()) - .get_result(conn) + diesel::delete( + browser_sessions::table.filter(browser_sessions::expiry.lt(diesel::dsl::now)), + ) + .execute(conn)?; + browser_sessions::table + .filter(browser_sessions::id.eq(session_id)) + .select(BrowserSession::as_select()) + .first(conn) .optional() }) .await @@ -62,43 +78,94 @@ impl SessionStore for PgStore { async fn store_session(&self, session: Session) -> Result> { let serialized_data = serde_json::to_string(&session)?; let session_id = session.id().to_string(); + let expiry = session.expiry().map(|exp| exp.clone()); let conn = self.pool.get().await?; conn.interact(move |conn| { - diesel::insert_into(browser_sessions) + diesel::insert_into(browser_sessions::table) .values(( - id.eq(session_id), - serialized.eq(serialized_data), - last_seen_at.eq(diesel::dsl::now), + browser_sessions::id.eq(session_id), + browser_sessions::serialized.eq(serialized_data), + browser_sessions::expiry.eq(expiry), )) - .on_conflict(id) + .on_conflict(browser_sessions::id) .do_update() .set(( - serialized.eq(excluded(serialized)), - last_seen_at.eq(excluded(last_seen_at)), + browser_sessions::serialized.eq(excluded(browser_sessions::serialized)), + browser_sessions::expiry.eq(excluded(browser_sessions::expiry)), )) .execute(conn) }) .await .unwrap()?; - session.reset_data_changed(); Ok(session.into_cookie_value()) } async fn destroy_session(&self, session: Session) -> Result<()> { + let session_id = session.id().to_owned(); let conn = self.pool.get().await?; conn.interact(move |conn| { - diesel::delete(browser_sessions.filter(id.eq(session.id().to_string()))).execute(conn) + diesel::delete( + browser_sessions::table.filter(browser_sessions::id.eq(session.id().to_string())), + ) + .execute(conn) }) .await .unwrap()?; + tracing::debug!("destroyed session {}", session_id); Ok(()) } async fn clear_store(&self) -> Result<()> { let conn = self.pool.get().await?; - conn.interact(move |conn| diesel::delete(browser_sessions).execute(conn)) + conn.interact(move |conn| diesel::delete(browser_sessions::table).execute(conn)) .await .unwrap()?; Ok(()) } } + +#[derive(Clone)] +pub struct AppSession(pub Option); + +impl FromRequestParts for AppSession { + type Rejection = AppError; + + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result>::Rejection> { + let _ = trace_span!("AppSession::from_request_parts()").enter(); + let jar = parts.extract::().await.unwrap(); + let session_cookie = match jar.get(&state.settings.auth.cookie_name) { + Some(cookie) => cookie, + None => { + tracing::debug!("no session cookie present"); + return Ok(AppSession(None)); + } + }; + tracing::debug!("session cookie loaded"); + let maybe_session = state + .session_store + .load_session(session_cookie.value().to_string()) + .await?; + if let Some(mut session) = maybe_session { + tracing::debug!("session {} loaded", session.id()); + session.expire_in(TimeDelta::days(EXPIRY_DAYS).to_std()?); + if state + .session_store + .store_session(session.clone()) + .await? + .is_some() + { + return Err(anyhow::anyhow!( + "expected cookie value returned by store_session() to be None for existing session" + ) + .into()); + } + Ok(AppSession(Some(session))) + } else { + tracing::debug!("no matching session found in database"); + Ok(AppSession(None)) + } + } +} diff --git a/src/settings.rs b/src/settings.rs index fa42e00..37cfe36 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -39,6 +39,7 @@ pub struct AuthSettings { pub auth_url: String, pub token_url: String, pub userinfo_url: String, + pub logout_url: Option, #[serde(default = "default_cookie_name")] pub cookie_name: String, @@ -68,7 +69,7 @@ pub struct SlackSettings { impl Settings { pub fn load() -> Result { if let Err(err) = dotenv() { - println!("Couldn't load .env file: {:?}", err); + tracing::warn!("Couldn't load .env file: {:?}", err); } let s = Config::builder() .add_source(Environment::default()) diff --git a/templates/nav.html b/templates/nav.html index 069d8a0..df19e34 100644 --- a/templates/nav.html +++ b/templates/nav.html @@ -52,7 +52,7 @@