//! Browser session management via [`async_session`]. use anyhow::Result; use async_session::{Session, SessionStore, async_trait}; use axum::{ RequestPartsExt as _, extract::{FromRef, FromRequestParts}, http::request::Parts, }; use axum_extra::extract::CookieJar; use chrono::{DateTime, TimeDelta, Utc}; use sqlx::{PgPool, query, query_as}; use tracing::{Instrument, trace_span}; use crate::{app_error::AppError, app_state::App}; const EXPIRY_DAYS: i64 = 7; pub struct BrowserSession { pub id: String, pub serialized: String, pub expiry: Option>, pub created_at: DateTime, } #[derive(Clone)] pub struct PgStore { pool: PgPool, } impl PgStore { pub fn new(pool: PgPool) -> PgStore { Self { pool } } } impl std::fmt::Debug for PgStore { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "PgStore")?; Ok(()) } } impl FromRef for PgStore { fn from_ref(state: &App) -> Self { state.session_store.clone() } } #[async_trait] impl SessionStore for PgStore { async fn load_session(&self, cookie_value: String) -> Result> { let session_id = Session::id_from_cookie_value(&cookie_value)?; query!("delete from browser_sessions where expiry < now()") .execute(&self.pool) .await?; let row = query_as!( BrowserSession, "select * from browser_sessions where id = $1", session_id ) .fetch_optional(&self.pool) .await?; Ok(match row { Some(session) => Some(serde_json::from_str::( session.serialized.as_str(), )?), None => None, }) } async fn store_session(&self, session: Session) -> Result> { let serialized_data = serde_json::to_string(&session)?; let session_id = session.id().to_string(); let expiry = session.expiry().copied(); query!( " insert into browser_sessions (id, serialized, expiry) values ($1, $2, $3) on conflict (id) do update set serialized = excluded.serialized, expiry = excluded.expiry ", session_id, serialized_data, expiry ) .execute(&self.pool) .await?; Ok(session.into_cookie_value()) } async fn destroy_session(&self, session: Session) -> Result<()> { let session_id = session.id().to_owned(); query!("delete from browser_sessions where id = $1", session_id) .execute(&self.pool) .await?; tracing::debug!("destroyed session {}", session_id); Ok(()) } async fn clear_store(&self) -> Result<()> { query!("truncate browser_sessions") .execute(&self.pool) .await?; tracing::info!("cleared session store"); Ok(()) } } #[derive(Clone)] pub struct AppSession(pub Option); impl FromRequestParts for AppSession { type Rejection = AppError; async fn from_request_parts( parts: &mut Parts, state: &App, ) -> Result>::Rejection> { async move { let jar = parts.extract::().await.unwrap(); let session_cookie = match jar.get(&state.settings.auth.cookie_name) { Some(cookie) => cookie, None => { tracing::debug!("no session cookie present"); return Ok(AppSession(None)); } }; tracing::debug!("session cookie loaded"); let maybe_session = state .session_store .load_session(session_cookie.value().to_string()) .await?; if let Some(mut session) = maybe_session { tracing::debug!("session {} loaded", session.id()); session.expire_in(TimeDelta::days(EXPIRY_DAYS).to_std()?); if state .session_store .store_session(session.clone()) .await? .is_some() { return Err(anyhow::anyhow!( "expected cookie value returned by store_session() to be None for existing session" ) .into()); } Ok(AppSession(Some(session))) } else { tracing::debug!("no matching session found in database"); Ok(AppSession(None)) } // The Span.enter() guard pattern doesn't play nicely async }.instrument(trace_span!("AppSession::from_request_parts()")).await } }