use chrono::{DateTime, TimeDelta, Utc}; use deadpool_diesel::postgres::Connection; use diesel::{ dsl::{auto_type, AsSelect, Gt, Select}, pg::Pg, prelude::*, }; use uuid::Uuid; use crate::{app_error::AppError, schema::csrf_tokens}; pub use crate::schema::csrf_tokens::{dsl, table}; const TOKEN_PREFIX: &str = "csrf-"; #[derive(Clone, Debug, Identifiable, Queryable, Selectable)] #[diesel(table_name = csrf_tokens)] #[diesel(check_for_backend(Pg))] pub struct CsrfToken { pub id: Uuid, pub user_id: Option, pub created_at: DateTime, } impl CsrfToken { fn all() -> Select> { table.select(Self::as_select()) } pub fn is_not_expired() -> Gt> { let ttl = TimeDelta::hours(24); let min_created_at: DateTime = Utc::now() - ttl; dsl::created_at.gt(min_created_at) } #[auto_type(no_type_alias)] pub fn with_user_id<'a>(token_user_id: &'a Option) -> _ { dsl::user_id.is_not_distinct_from(token_user_id) } #[auto_type(no_type_alias)] pub fn with_token_id<'a>(token_id: &'a Uuid) -> _ { dsl::id.eq(token_id) } } /// Convenience function for creating new CSRF token rows in the database. pub async fn generate_csrf_token( db_conn: &Connection, with_user_id: Option, ) -> Result { let token_id = Uuid::new_v4(); db_conn .interact(move |conn| { diesel::insert_into(table) .values(( dsl::id.eq(token_id), dsl::user_id.eq(with_user_id), dsl::created_at.eq(diesel::dsl::now), )) .execute(conn) }) .await .unwrap()?; Ok(format!("{}{}", TOKEN_PREFIX, token_id.simple())) } /// Convenience function for validating CSRF tokens against the database. pub async fn validate_csrf_token( db_conn: &Connection, token: &str, with_user_id: Option, ) -> Result { let token_id = match Uuid::try_parse(&token[TOKEN_PREFIX.len()..]) { Ok(token_id) => token_id, Err(_) => return Ok(false), }; Ok(db_conn .interact(move |conn| { CsrfToken::all() .filter(CsrfToken::with_token_id(&token_id)) .filter(CsrfToken::with_user_id(&with_user_id)) .filter(CsrfToken::is_not_expired()) .first(conn) .optional() }) .await .unwrap()? .is_some()) }