use anyhow::Result; use async_session::{async_trait, Session, SessionStore}; use axum::{ extract::{FromRef, FromRequestParts}, http::request::Parts, RequestPartsExt as _, }; use axum_extra::extract::CookieJar; use chrono::{DateTime, TimeDelta, Utc}; use diesel::{pg::Pg, prelude::*, upsert::excluded}; use tracing::{trace_span, Instrument}; use crate::{app_error::AppError, app_state::AppState, schema::browser_sessions}; const EXPIRY_DAYS: i64 = 7; #[derive(Clone, Debug, Identifiable, Queryable, Selectable)] #[diesel(table_name = browser_sessions)] #[diesel(check_for_backend(Pg))] pub struct BrowserSession { pub id: String, pub serialized: String, pub expiry: Option>, } #[derive(Clone)] pub struct PgStore { pool: deadpool_diesel::postgres::Pool, } impl PgStore { pub fn new(pool: deadpool_diesel::postgres::Pool) -> PgStore { Self { pool } } } impl std::fmt::Debug for PgStore { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "PgStore")?; Ok(()).into() } } impl FromRef for PgStore { fn from_ref(state: &AppState) -> Self { state.session_store.clone() } } #[async_trait] impl SessionStore for PgStore { async fn load_session(&self, cookie_value: String) -> Result> { let session_id = Session::id_from_cookie_value(&cookie_value)?; let conn = self.pool.get().await?; let row = conn .interact(move |conn| { // Drop all sessions without recent activity diesel::delete( browser_sessions::table.filter(browser_sessions::expiry.lt(diesel::dsl::now)), ) .execute(conn)?; browser_sessions::table .filter(browser_sessions::id.eq(session_id)) .select(BrowserSession::as_select()) .first(conn) .optional() }) .await .unwrap()?; Ok(match row { Some(session) => Some(serde_json::from_str::( session.serialized.as_str(), )?), None => None, }) } async fn store_session(&self, session: Session) -> Result> { let serialized_data = serde_json::to_string(&session)?; let session_id = session.id().to_string(); let expiry = session.expiry().map(|exp| exp.clone()); let conn = self.pool.get().await?; conn.interact(move |conn| { diesel::insert_into(browser_sessions::table) .values(( browser_sessions::id.eq(session_id), browser_sessions::serialized.eq(serialized_data), browser_sessions::expiry.eq(expiry), )) .on_conflict(browser_sessions::id) .do_update() .set(( browser_sessions::serialized.eq(excluded(browser_sessions::serialized)), browser_sessions::expiry.eq(excluded(browser_sessions::expiry)), )) .execute(conn) }) .await .unwrap()?; Ok(session.into_cookie_value()) } async fn destroy_session(&self, session: Session) -> Result<()> { let session_id = session.id().to_owned(); let conn = self.pool.get().await?; conn.interact(move |conn| { diesel::delete( browser_sessions::table.filter(browser_sessions::id.eq(session.id().to_string())), ) .execute(conn) }) .await .unwrap()?; tracing::debug!("destroyed session {}", session_id); Ok(()) } async fn clear_store(&self) -> Result<()> { let conn = self.pool.get().await?; conn.interact(move |conn| diesel::delete(browser_sessions::table).execute(conn)) .await .unwrap()?; Ok(()) } } #[derive(Clone)] pub struct AppSession(pub Option); impl FromRequestParts for AppSession { type Rejection = AppError; async fn from_request_parts( parts: &mut Parts, state: &AppState, ) -> Result>::Rejection> { async move { let jar = parts.extract::().await.unwrap(); let session_cookie = match jar.get(&state.settings.auth.cookie_name) { Some(cookie) => cookie, None => { tracing::debug!("no session cookie present"); return Ok(AppSession(None)); } }; tracing::debug!("session cookie loaded"); let maybe_session = state .session_store .load_session(session_cookie.value().to_string()) .await?; if let Some(mut session) = maybe_session { tracing::debug!("session {} loaded", session.id()); session.expire_in(TimeDelta::days(EXPIRY_DAYS).to_std()?); if state .session_store .store_session(session.clone()) .await? .is_some() { return Err(anyhow::anyhow!( "expected cookie value returned by store_session() to be None for existing session" ) .into()); } Ok(AppSession(Some(session))) } else { tracing::debug!("no matching session found in database"); Ok(AppSession(None)) } // The Span.enter() guard pattern doesn't play nicely async }.instrument(trace_span!("AppSession::from_request_parts()")).await } }