1
0
Fork 0
forked from 2sys/shoutdotdev
shoutdotdev/src/csrf.rs

90 lines
2.5 KiB
Rust

use chrono::{DateTime, TimeDelta, Utc};
use deadpool_diesel::postgres::Connection;
use diesel::{
dsl::{auto_type, AsSelect, Gt, Select},
pg::Pg,
prelude::*,
};
use uuid::Uuid;
use crate::{app_error::AppError, schema::csrf_tokens};
pub use crate::schema::csrf_tokens::{dsl, table};
/// Prefix prepended to every token string returned by `generate_csrf_token`;
/// `validate_csrf_token` removes it before parsing the remaining UUID.
const TOKEN_PREFIX: &str = "csrf-";
/// A row of the `csrf_tokens` table.
///
/// `Queryable` maps result columns onto fields by position, while `Selectable`
/// (used through `as_select()`) builds the select list from these same fields
/// in declaration order; keep the fields in sync with the schema when editing.
#[derive(Clone, Debug, Identifiable, Queryable, Selectable)]
#[diesel(table_name = csrf_tokens)]
#[diesel(check_for_backend(Pg))]
pub struct CsrfToken {
    /// Token identifier; the client-facing token string is built from it.
    pub id: Uuid,
    /// User the token is bound to, or `None` for a token not tied to a user.
    pub user_id: Option<Uuid>,
    /// Creation timestamp, compared against a cutoff to decide expiry.
    pub created_at: DateTime<Utc>,
}
impl CsrfToken {
    /// How long, in hours, a token stays valid after its `created_at` timestamp.
    const TTL_HOURS: i64 = 24;

    /// Base query selecting every `csrf_tokens` row, mapped onto `CsrfToken`.
    fn all() -> Select<table, AsSelect<CsrfToken, Pg>> {
        table.select(Self::as_select())
    }

    /// Filter matching tokens created within the last `TTL_HOURS` hours.
    ///
    /// The cutoff is computed once, when this function is called, from the
    /// application's clock (`Utc::now()`). NOTE(review): `generate_csrf_token`
    /// stamps `created_at` with the database's `now`, so clock skew between the
    /// application and the database shifts the effective TTL — confirm this is
    /// acceptable.
    pub fn is_not_expired() -> Gt<dsl::created_at, DateTime<Utc>> {
        let oldest_valid: DateTime<Utc> = Utc::now() - TimeDelta::hours(Self::TTL_HOURS);
        dsl::created_at.gt(oldest_valid)
    }

    /// Filter matching tokens bound to `token_user_id`.
    ///
    /// Uses `IS NOT DISTINCT FROM`, so `None` matches rows whose `user_id` is
    /// NULL (a plain `=` comparison never matches NULL).
    #[auto_type(no_type_alias)]
    pub fn with_user_id<'a>(token_user_id: &'a Option<Uuid>) -> _ {
        dsl::user_id.is_not_distinct_from(token_user_id)
    }

    /// Filter matching the token whose `id` equals `token_id`.
    #[auto_type(no_type_alias)]
    pub fn with_token_id<'a>(token_id: &'a Uuid) -> _ {
        dsl::id.eq(token_id)
    }
}
/// Inserts a fresh CSRF token row and returns the token string for the client.
///
/// The row gets a random v4 UUID as its id, `with_user_id` as its owner, and
/// the database's current time (`now`) as `created_at`. The returned string is
/// `TOKEN_PREFIX` followed by the UUID in simple (hyphen-free hex) form.
///
/// # Errors
///
/// Returns an [`AppError`] if the insert fails.
///
/// # Panics
///
/// Panics if the blocking database closure panics or is aborted, because the
/// result of `interact` is unwrapped.
pub async fn generate_csrf_token(
    db_conn: &Connection,
    with_user_id: Option<Uuid>,
) -> Result<String, AppError> {
    let new_id = Uuid::new_v4();
    let insert_outcome = db_conn
        .interact(move |conn| {
            let row = (
                dsl::id.eq(new_id),
                dsl::user_id.eq(with_user_id),
                dsl::created_at.eq(diesel::dsl::now),
            );
            diesel::insert_into(table).values(row).execute(conn)
        })
        .await;
    insert_outcome.unwrap()?;
    Ok(format!("{TOKEN_PREFIX}{}", new_id.simple()))
}
/// Convenience function for validating CSRF tokens against the database.
pub async fn validate_csrf_token(
db_conn: &Connection,
token: &str,
with_user_id: Option<Uuid>,
) -> Result<bool, AppError> {
let token_id = match Uuid::try_parse(&token[TOKEN_PREFIX.len()..]) {
Ok(token_id) => token_id,
Err(_) => return Ok(false),
};
Ok(db_conn
.interact(move |conn| {
CsrfToken::all()
.filter(CsrfToken::with_token_id(&token_id))
.filter(CsrfToken::with_user_id(&with_user_id))
.filter(CsrfToken::is_not_expired())
.first(conn)
.optional()
})
.await
.unwrap()?
.is_some())
}