1
0
Fork 0
forked from 2sys/shoutdotdev

tighten oauth login/logout flows

This commit is contained in:
Brent Schroeter 2025-02-26 13:10:45 -08:00
parent d051b97810
commit 401fcde4ce
9 changed files with 260 additions and 129 deletions

View file

@ -2,7 +2,7 @@ CREATE TABLE browser_sessions (
id TEXT NOT NULL PRIMARY KEY,
serialized TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
last_seen_at TIMESTAMPTZ NOT NULL
expiry TIMESTAMPTZ
);
CREATE INDEX ON browser_sessions (last_seen_at);
CREATE INDEX ON browser_sessions (expiry);
CREATE INDEX ON browser_sessions (created_at);

View file

@ -1,7 +1,12 @@
use std::fmt::{self, Display};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::response::{IntoResponse, Redirect, Response};
#[derive(Debug)]
pub struct AuthRedirectInfo {
base_path: String,
}
// Use anyhow, define error and enable '?'
// For a simplified example of using anyhow in axum check /examples/anyhow-error-response
@ -11,12 +16,23 @@ pub enum AppError {
ForbiddenError(String),
NotFoundError(String),
BadRequestError(String),
AuthRedirect(AuthRedirectInfo),
}
impl AppError {
pub fn auth_redirect_from_base_path(base_path: String) -> Self {
Self::AuthRedirect(AuthRedirectInfo { base_path })
}
}
// Tell axum how to convert `AppError` into a response.
impl IntoResponse for AppError {
fn into_response(self) -> Response {
match self {
Self::AuthRedirect(AuthRedirectInfo { base_path }) => {
tracing::debug!("Handling AuthRedirect");
Redirect::to(&format!("{}/auth/login", base_path)).into_response()
}
Self::InternalServerError(err) => {
tracing::error!("Application error: {:?}", err);
(StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response()
@ -51,6 +67,7 @@ where
impl Display for AppError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AppError::AuthRedirect(info) => write!(f, "AuthRedirect: {:?}", info),
AppError::InternalServerError(inner) => inner.fmt(f),
AppError::ForbiddenError(client_message) => {
write!(f, "ForbiddenError: {}", client_message)

View file

@ -12,6 +12,7 @@ use crate::{app_error::AppError, sessions::PgStore, settings::Settings};
pub struct AppState {
pub db_pool: Pool,
pub mailer: Mailer,
pub reqwest_client: reqwest::Client,
pub oauth_client: BasicClient,
pub session_store: PgStore,
pub settings: Settings,
@ -26,9 +27,12 @@ impl FromRef<AppState> for Mailer {
}
}
impl FromRef<AppState> for PgStore {
#[derive(Clone)]
pub struct ReqwestClient(pub reqwest::Client);
impl FromRef<AppState> for ReqwestClient {
fn from_ref(state: &AppState) -> Self {
state.session_store.clone()
ReqwestClient(state.reqwest_client.clone())
}
}

View file

@ -3,23 +3,28 @@ use async_session::{Session, SessionStore as _};
use axum::{
extract::{FromRequestParts, Query, State},
http::request::Parts,
response::{IntoResponse, Redirect, Response},
response::{IntoResponse, Redirect},
routing::get,
RequestPartsExt, Router,
};
use axum_extra::{
extract::cookie::{Cookie, CookieJar, SameSite},
headers, TypedHeader,
};
use diesel::prelude::*;
use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite};
use oauth2::{
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
ClientSecret, CsrfToken, RedirectUrl, TokenResponse, TokenUrl,
ClientSecret, CsrfToken, RedirectUrl, RefreshToken, TokenResponse, TokenUrl,
};
use serde::{Deserialize, Serialize};
use tracing::{debug, trace_span};
use tracing::trace_span;
use crate::{app_error::AppError, app_state::AppState, schema, settings::Settings};
use crate::{
app_error::AppError,
app_state::{AppState, ReqwestClient},
sessions::{AppSession, PgStore},
settings::Settings,
};
const SESSION_KEY_AUTH_CSRF_TOKEN: &'static str = "oauth_csrf_token";
const SESSION_KEY_AUTH_REFRESH_TOKEN: &'static str = "oauth_refresh_token";
const SESSION_KEY_AUTH_INFO: &'static str = "auth";
pub fn new_oauth_client(settings: &Settings) -> Result<BasicClient, AppError> {
Ok(BasicClient::new(
@ -45,38 +50,81 @@ pub fn new_router() -> Router<AppState> {
.route("/logout", get(logout))
}
pub async fn propel_auth(State(state): State<AppState>) -> impl IntoResponse {
pub async fn propel_auth(
State(state): State<AppState>,
State(Settings {
auth: auth_settings,
base_path,
..
}): State<Settings>,
State(session_store): State<PgStore>,
AppSession(maybe_session): AppSession,
jar: CookieJar,
) -> Result<impl IntoResponse, AppError> {
if let Some(session) = maybe_session {
if session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO).is_some() {
tracing::debug!("already logged in, redirecting...");
return Ok(Redirect::to(&base_path).into_response());
}
}
let csrf_token = CsrfToken::new_random();
let (auth_url, _csrf_token) = state
.oauth_client
.authorize_url(|| csrf_token)
// .add_scopes(vec![Scope::new("openid".to_string())])
.authorize_url(|| csrf_token.clone())
.url();
// FIXME: check CSRF token
Redirect::to(auth_url.as_ref())
let mut session = Session::new();
session.insert(SESSION_KEY_AUTH_CSRF_TOKEN, &csrf_token)?;
let cookie_value = session_store
.store_session(session)
.await?
.ok_or(anyhow::anyhow!("cookie value from store_session() is None"))?;
let jar = jar.add(
Cookie::build((auth_settings.cookie_name.clone(), cookie_value))
.same_site(SameSite::Lax)
.http_only(true)
.path("/"),
);
Ok((jar, Redirect::to(&auth_url.to_string())).into_response())
}
pub async fn logout(
State(state): State<AppState>,
TypedHeader(cookies): TypedHeader<headers::Cookie>,
State(Settings {
base_path,
auth: auth_settings,
..
}): State<Settings>,
State(ReqwestClient(reqwest_client)): State<ReqwestClient>,
State(session_store): State<PgStore>,
AppSession(session): AppSession,
jar: CookieJar,
) -> Result<impl IntoResponse, AppError> {
let cookie = cookies
.get(state.settings.auth.cookie_name.as_str())
.context("couldn't get session cookie")?;
let session_id = Session::id_from_cookie_value(cookie)?;
state
.db_pool
.get()
.await?
.interact(move |conn| {
diesel::delete(schema::browser_sessions::table)
.filter(schema::browser_sessions::id.eq(session_id))
.execute(conn)
})
.await
.unwrap()?;
// FIXME: call logout endpoint of OIDC provider
Ok(Redirect::to(&state.settings.base_path))
if let Some(session) = session {
tracing::debug!("Session {} loaded.", session.id());
if let Some(logout_url) = auth_settings.logout_url {
tracing::debug!("attempting to send logout request to oauth provider");
let refresh_token: Option<RefreshToken> = session.get(SESSION_KEY_AUTH_REFRESH_TOKEN);
if let Some(refresh_token) = refresh_token {
tracing::debug!("Sending logout request to OAuth provider.");
#[derive(Serialize)]
struct LogoutRequestBody {
refresh_token: String,
}
reqwest_client
.post(logout_url)
.json(&LogoutRequestBody {
refresh_token: refresh_token.secret().to_owned(),
})
.send()
.await?
.error_for_status()?;
tracing::debug!("Sent logout request to OAuth provider successfully.");
}
}
session_store.destroy_session(session).await?;
}
let jar = jar.remove(Cookie::from(auth_settings.cookie_name));
tracing::debug!("Removed session cookie from jar.");
Ok((jar, Redirect::to(&base_path)))
}
#[derive(Debug, Deserialize)]
@ -85,96 +133,81 @@ pub struct AuthRequestQuery {
state: String, // CSRF token
}
pub const AUTH_INFO_SESSION_KEY: &'static str = "user";
#[derive(Debug, Deserialize, Serialize)]
pub struct AuthInfo {
pub sub: String,
pub email: String,
}
async fn get_user_info(
settings: &Settings,
access_token: &oauth2::AccessToken,
) -> Result<AuthInfo, AppError> {
let client = reqwest::Client::new();
Ok(client
.get(settings.auth.userinfo_url.as_str())
.bearer_auth(access_token.secret())
.send()
.await?
.json()
.await?)
}
pub async fn login_authorized(
Query(query): Query<AuthRequestQuery>,
State(state): State<AppState>,
jar: CookieJar,
State(Settings {
auth: auth_settings,
base_path,
..
}): State<Settings>,
State(ReqwestClient(reqwest_client)): State<ReqwestClient>,
AppSession(session): AppSession,
) -> Result<impl IntoResponse, AppError> {
let mut session = if let Some(session) = session {
session
} else {
return Err(AppError::auth_redirect_from_base_path(
state.settings.base_path,
));
};
let session_csrf_token: String = session.get(SESSION_KEY_AUTH_CSRF_TOKEN).ok_or_else(|| {
tracing::debug!("oauth csrf token not found on session");
AppError::auth_redirect_from_base_path(base_path.clone())
})?;
if session_csrf_token != query.state {
tracing::debug!("oauth csrf tokens did not match");
return Err(AppError::ForbiddenError(
"OAuth CSRF tokens do not match.".to_string(),
));
}
let response = state
.oauth_client
.exchange_code(AuthorizationCode::new(query.code.clone()))
.request_async(async_http_client)
.await?;
let user_info = get_user_info(&state.settings, response.access_token()).await?;
let mut session = Session::new();
session.insert(AUTH_INFO_SESSION_KEY, &user_info)?;
let cookie_value = state
.session_store
.store_session(session)
let auth_info: AuthInfo = reqwest_client
.get(auth_settings.userinfo_url.as_str())
.bearer_auth(response.access_token().secret())
.send()
.await?
.context("cookie value from store_session() is None")?;
let jar = jar.add(
Cookie::build((state.settings.auth.cookie_name.clone(), cookie_value))
.same_site(SameSite::Lax)
.http_only(true)
.path("/"),
);
Ok((jar, Redirect::to(&format!("{}/", state.settings.base_path))))
}
pub struct AuthRedirect {
base_path: String,
}
impl AuthRedirect {
pub fn new(base_path: &str) -> Self {
Self {
base_path: base_path.to_string(),
}
}
}
impl IntoResponse for AuthRedirect {
fn into_response(self) -> Response {
Redirect::to(&format!("{}/auth/login", self.base_path)).into_response()
.json()
.await?;
session.insert(SESSION_KEY_AUTH_INFO, &auth_info)?;
session.insert(SESSION_KEY_AUTH_REFRESH_TOKEN, response.refresh_token())?;
if state.session_store.store_session(session).await?.is_some() {
return Err(anyhow::anyhow!(
"expected cookie value returned by store_session() to be None for existing session"
)
.into());
}
Ok(Redirect::to(&base_path))
}
impl FromRequestParts<AppState> for AuthInfo {
type Rejection = AuthRedirect;
type Rejection = AppError;
async fn from_request_parts(
parts: &mut Parts,
state: &AppState,
) -> Result<Self, <Self as FromRequestParts<AppState>>::Rejection> {
let _ = trace_span!("AuthInfo from_request_parts()").enter();
let jar = parts.extract::<CookieJar>().await.unwrap();
let session_cookie = jar
.get(&state.settings.auth.cookie_name)
.ok_or(AuthRedirect::new(&state.settings.base_path))?;
debug!("session cookie loaded");
let session = state
.session_store
.load_session(session_cookie.value().to_string())
.await
.unwrap()
.ok_or(AuthRedirect::new(&state.settings.base_path))?;
debug!("session loaded");
let user = session
.get::<AuthInfo>(AUTH_INFO_SESSION_KEY)
.ok_or(AuthRedirect::new(&state.settings.base_path))?;
let session = parts
.extract_with_state::<AppSession, AppState>(state)
.await?
.0
.ok_or(AppError::auth_redirect_from_base_path(
state.settings.base_path.clone(),
))?;
let user = session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO).ok_or(
AppError::auth_redirect_from_base_path(state.settings.base_path.clone()),
)?;
Ok(user)
}
}

View file

@ -41,9 +41,7 @@ async fn main() {
let db_pool = deadpool_diesel::postgres::Pool::builder(manager)
.build()
.unwrap();
let session_store = PgStore::new(db_pool.clone());
let mailer_creds = lettre::transport::smtp::authentication::Credentials::new(
settings.email.smtp_username.clone(),
settings.email.smtp_password.clone(),
@ -54,20 +52,31 @@ async fn main() {
.credentials(mailer_creds)
.build(),
);
let reqwest_client = reqwest::ClientBuilder::new()
.https_only(true)
.build()
.unwrap();
let oauth_client = auth::new_oauth_client(&settings).unwrap();
let app_state = AppState {
db_pool,
mailer,
oauth_client,
reqwest_client,
session_store,
settings: settings.clone(),
};
let router = new_router(app_state);
let listener = tokio::net::TcpListener::bind((settings.host, settings.port))
let listener = tokio::net::TcpListener::bind((settings.host.clone(), settings.port.clone()))
.await
.unwrap();
tracing::info!(
"App running at http://{}:{}{}",
settings.host,
settings.port,
settings.base_path
);
axum::serve(listener, router).await.unwrap();
}

View file

@ -14,7 +14,7 @@ diesel::table! {
id -> Text,
serialized -> Text,
created_at -> Timestamptz,
last_seen_at -> Timestamptz,
expiry -> Nullable<Timestamptz>,
}
}

View file

@ -1,17 +1,26 @@
use anyhow::Result;
use async_session::{async_trait, Session, SessionStore};
use axum::{
extract::{FromRef, FromRequestParts},
http::request::Parts,
RequestPartsExt as _,
};
use axum_extra::extract::CookieJar;
use chrono::{DateTime, TimeDelta, Utc};
use diesel::{pg::Pg, prelude::*, upsert::excluded};
use tracing::trace_span;
use crate::schema::browser_sessions::dsl::*;
use crate::{app_error::AppError, app_state::AppState, schema::browser_sessions};
const EXPIRY_DAYS: i64 = 7;
#[derive(Clone, Debug, Identifiable, Queryable, Selectable)]
#[diesel(table_name = crate::schema::browser_sessions)]
#[diesel(table_name = browser_sessions)]
#[diesel(check_for_backend(Pg))]
pub struct BrowserSession {
pub id: String,
pub serialized: String,
pub last_seen_at: DateTime<Utc>,
pub expiry: Option<DateTime<Utc>>,
}
#[derive(Clone)]
@ -32,21 +41,28 @@ impl std::fmt::Debug for PgStore {
}
}
impl FromRef<AppState> for PgStore {
fn from_ref(state: &AppState) -> Self {
state.session_store.clone()
}
}
#[async_trait]
impl SessionStore for PgStore {
async fn load_session(&self, cookie_value: String) -> Result<Option<Session>> {
let session_id = Session::id_from_cookie_value(&cookie_value)?;
let timestamp_stale = Utc::now() - TimeDelta::days(7);
let conn = self.pool.get().await?;
let row = conn
.interact(move |conn| {
// Drop all sessions without recent activity
diesel::delete(browser_sessions.filter(last_seen_at.lt(timestamp_stale)))
.execute(conn)?;
diesel::update(browser_sessions.filter(id.eq(session_id)))
.set(last_seen_at.eq(diesel::dsl::now))
.returning(BrowserSession::as_returning())
.get_result(conn)
diesel::delete(
browser_sessions::table.filter(browser_sessions::expiry.lt(diesel::dsl::now)),
)
.execute(conn)?;
browser_sessions::table
.filter(browser_sessions::id.eq(session_id))
.select(BrowserSession::as_select())
.first(conn)
.optional()
})
.await
@ -62,43 +78,94 @@ impl SessionStore for PgStore {
async fn store_session(&self, session: Session) -> Result<Option<String>> {
let serialized_data = serde_json::to_string(&session)?;
let session_id = session.id().to_string();
let expiry = session.expiry().map(|exp| exp.clone());
let conn = self.pool.get().await?;
conn.interact(move |conn| {
diesel::insert_into(browser_sessions)
diesel::insert_into(browser_sessions::table)
.values((
id.eq(session_id),
serialized.eq(serialized_data),
last_seen_at.eq(diesel::dsl::now),
browser_sessions::id.eq(session_id),
browser_sessions::serialized.eq(serialized_data),
browser_sessions::expiry.eq(expiry),
))
.on_conflict(id)
.on_conflict(browser_sessions::id)
.do_update()
.set((
serialized.eq(excluded(serialized)),
last_seen_at.eq(excluded(last_seen_at)),
browser_sessions::serialized.eq(excluded(browser_sessions::serialized)),
browser_sessions::expiry.eq(excluded(browser_sessions::expiry)),
))
.execute(conn)
})
.await
.unwrap()?;
session.reset_data_changed();
Ok(session.into_cookie_value())
}
async fn destroy_session(&self, session: Session) -> Result<()> {
let session_id = session.id().to_owned();
let conn = self.pool.get().await?;
conn.interact(move |conn| {
diesel::delete(browser_sessions.filter(id.eq(session.id().to_string()))).execute(conn)
diesel::delete(
browser_sessions::table.filter(browser_sessions::id.eq(session.id().to_string())),
)
.execute(conn)
})
.await
.unwrap()?;
tracing::debug!("destroyed session {}", session_id);
Ok(())
}
async fn clear_store(&self) -> Result<()> {
let conn = self.pool.get().await?;
conn.interact(move |conn| diesel::delete(browser_sessions).execute(conn))
conn.interact(move |conn| diesel::delete(browser_sessions::table).execute(conn))
.await
.unwrap()?;
Ok(())
}
}
#[derive(Clone)]
pub struct AppSession(pub Option<Session>);
impl FromRequestParts<AppState> for AppSession {
type Rejection = AppError;
async fn from_request_parts(
parts: &mut Parts,
state: &AppState,
) -> Result<Self, <Self as FromRequestParts<AppState>>::Rejection> {
let _ = trace_span!("AppSession::from_request_parts()").enter();
let jar = parts.extract::<CookieJar>().await.unwrap();
let session_cookie = match jar.get(&state.settings.auth.cookie_name) {
Some(cookie) => cookie,
None => {
tracing::debug!("no session cookie present");
return Ok(AppSession(None));
}
};
tracing::debug!("session cookie loaded");
let maybe_session = state
.session_store
.load_session(session_cookie.value().to_string())
.await?;
if let Some(mut session) = maybe_session {
tracing::debug!("session {} loaded", session.id());
session.expire_in(TimeDelta::days(EXPIRY_DAYS).to_std()?);
if state
.session_store
.store_session(session.clone())
.await?
.is_some()
{
return Err(anyhow::anyhow!(
"expected cookie value returned by store_session() to be None for existing session"
)
.into());
}
Ok(AppSession(Some(session)))
} else {
tracing::debug!("no matching session found in database");
Ok(AppSession(None))
}
}
}

View file

@ -39,6 +39,7 @@ pub struct AuthSettings {
pub auth_url: String,
pub token_url: String,
pub userinfo_url: String,
pub logout_url: Option<String>,
#[serde(default = "default_cookie_name")]
pub cookie_name: String,
@ -68,7 +69,7 @@ pub struct SlackSettings {
impl Settings {
pub fn load() -> Result<Self, ConfigError> {
if let Err(err) = dotenv() {
println!("Couldn't load .env file: {:?}", err);
tracing::warn!("Couldn't load .env file: {:?}", err);
}
let s = Config::builder()
.add_source(Environment::default())

View file

@ -52,7 +52,7 @@
<ul class="dropdown-menu">
<li><a class="dropdown-item" href="#">Settings</a></li>
<li><hr class="dropdown-divider"></li>
<li><a class="dropdown-item" href="#">Log out</a></li>
<li><a class="dropdown-item" href="{{ base_path }}/auth/logout">Log out</a></li>
</ul>
</li>
</ul>