156 lines
4.5 KiB
Rust
156 lines
4.5 KiB
Rust
use anyhow::Result;
|
|
use async_session::{Session, SessionStore, async_trait};
|
|
use axum::{
|
|
RequestPartsExt as _,
|
|
extract::{FromRef, FromRequestParts},
|
|
http::request::Parts,
|
|
};
|
|
use axum_extra::extract::CookieJar;
|
|
use chrono::{DateTime, TimeDelta, Utc};
|
|
use sqlx::{PgPool, query, query_as};
|
|
use tracing::{Instrument, trace_span};
|
|
|
|
use crate::{app_error::AppError, app_state::AppState};
|
|
|
|
/// Lifetime, in days, given to a session each time the `AppSession` extractor
/// loads it: the extractor calls `expire_in` with this many days and writes the
/// session back, so the expiry slides forward on every request that uses it.
const EXPIRY_DAYS: i64 = 7;
|
|
|
|
/// One row of the `browser_sessions` table, as fetched by
/// `PgStore::load_session` (`select * from browser_sessions where id = $1`).
pub struct BrowserSession {
    /// Session id as produced by `Session::id()`; the lookup key and the
    /// `on conflict (id)` target of the upsert in `store_session`.
    pub id: String,
    /// The whole `async_session::Session`, encoded with `serde_json::to_string`.
    pub serialized: String,
    /// Copy of the session's own expiry (`Session::expiry()`); `None` when the
    /// session has no expiry. Rows with an expiry in the past are deleted at the
    /// start of every `load_session` call.
    pub expiry: Option<DateTime<Utc>>,
    /// Never written by `store_session`; presumably filled by a column default
    /// in the schema — TODO confirm against the migration.
    pub created_at: DateTime<Utc>,
}
|
|
|
|
/// Postgres-backed `async_session::SessionStore`; sessions live in the
/// `browser_sessions` table. Cloning clones the underlying `PgPool` handle.
#[derive(Clone)]
pub struct PgStore {
    /// Pool every session query is executed against.
    pool: PgPool,
}
|
|
|
|
impl PgStore {
|
|
pub fn new(pool: PgPool) -> PgStore {
|
|
Self { pool }
|
|
}
|
|
}
|
|
|
|
impl std::fmt::Debug for PgStore {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(f, "PgStore")?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl FromRef<AppState> for PgStore {
|
|
fn from_ref(state: &AppState) -> Self {
|
|
state.session_store.clone()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl SessionStore for PgStore {
|
|
async fn load_session(&self, cookie_value: String) -> Result<Option<Session>> {
|
|
let session_id = Session::id_from_cookie_value(&cookie_value)?;
|
|
query!("delete from browser_sessions where expiry < now()")
|
|
.execute(&self.pool)
|
|
.await?;
|
|
let row = query_as!(
|
|
BrowserSession,
|
|
"select * from browser_sessions where id = $1",
|
|
session_id
|
|
)
|
|
.fetch_optional(&self.pool)
|
|
.await?;
|
|
Ok(match row {
|
|
Some(session) => Some(serde_json::from_str::<Session>(
|
|
session.serialized.as_str(),
|
|
)?),
|
|
None => None,
|
|
})
|
|
}
|
|
|
|
async fn store_session(&self, session: Session) -> Result<Option<String>> {
|
|
let serialized_data = serde_json::to_string(&session)?;
|
|
let session_id = session.id().to_string();
|
|
let expiry = session.expiry().copied();
|
|
query!(
|
|
"
|
|
insert into browser_sessions
|
|
(id, serialized, expiry)
|
|
values ($1, $2, $3)
|
|
on conflict (id) do update set
|
|
serialized = excluded.serialized,
|
|
expiry = excluded.expiry
|
|
",
|
|
session_id,
|
|
serialized_data,
|
|
expiry
|
|
)
|
|
.execute(&self.pool)
|
|
.await?;
|
|
Ok(session.into_cookie_value())
|
|
}
|
|
|
|
async fn destroy_session(&self, session: Session) -> Result<()> {
|
|
let session_id = session.id().to_owned();
|
|
query!("delete from browser_sessions where id = $1", session_id)
|
|
.execute(&self.pool)
|
|
.await?;
|
|
tracing::debug!("destroyed session {}", session_id);
|
|
Ok(())
|
|
}
|
|
|
|
async fn clear_store(&self) -> Result<()> {
|
|
query!("truncate browser_sessions")
|
|
.execute(&self.pool)
|
|
.await?;
|
|
tracing::info!("cleared session store");
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
/// Extractor output holding the caller's browser session.
///
/// `None` when the request has no session cookie or the cookie matches no
/// stored session; `Some` holds the loaded session after its expiry has been
/// refreshed and written back to the store.
#[derive(Clone)]
pub struct AppSession(pub Option<Session>);
|
|
|
|
impl FromRequestParts<AppState> for AppSession {
|
|
type Rejection = AppError;
|
|
|
|
async fn from_request_parts(
|
|
parts: &mut Parts,
|
|
state: &AppState,
|
|
) -> Result<Self, <Self as FromRequestParts<AppState>>::Rejection> {
|
|
async move {
|
|
let jar = parts.extract::<CookieJar>().await.unwrap();
|
|
let session_cookie = match jar.get(&state.settings.auth.cookie_name) {
|
|
Some(cookie) => cookie,
|
|
None => {
|
|
tracing::debug!("no session cookie present");
|
|
return Ok(AppSession(None));
|
|
}
|
|
};
|
|
tracing::debug!("session cookie loaded");
|
|
let maybe_session = state
|
|
.session_store
|
|
.load_session(session_cookie.value().to_string())
|
|
.await?;
|
|
if let Some(mut session) = maybe_session {
|
|
tracing::debug!("session {} loaded", session.id());
|
|
session.expire_in(TimeDelta::days(EXPIRY_DAYS).to_std()?);
|
|
if state
|
|
.session_store
|
|
.store_session(session.clone())
|
|
.await?
|
|
.is_some()
|
|
{
|
|
return Err(anyhow::anyhow!(
|
|
"expected cookie value returned by store_session() to be None for existing session"
|
|
)
|
|
.into());
|
|
}
|
|
Ok(AppSession(Some(session)))
|
|
} else {
|
|
tracing::debug!("no matching session found in database");
|
|
Ok(AppSession(None))
|
|
}
|
|
// The Span.enter() guard pattern doesn't play nicely async
|
|
}.instrument(trace_span!("AppSession::from_request_parts()")).await
|
|
}
|
|
}
|