forked from 2sys/shoutdotdev
90 lines
2.5 KiB
Rust
90 lines
2.5 KiB
Rust
use chrono::{DateTime, TimeDelta, Utc};
|
|
use deadpool_diesel::postgres::Connection;
|
|
use diesel::{
|
|
dsl::{auto_type, AsSelect, Gt, Select},
|
|
pg::Pg,
|
|
prelude::*,
|
|
};
|
|
use uuid::Uuid;
|
|
|
|
use crate::{app_error::AppError, schema::csrf_tokens};
|
|
|
|
pub use crate::schema::csrf_tokens::{dsl, table};
|
|
|
|
/// Literal prefix that `generate_csrf_token` places in front of the UUID when
/// it formats the token string handed back to the caller.
const TOKEN_PREFIX: &str = "csrf-";
|
|
|
|
#[derive(Clone, Debug, Identifiable, Queryable, Selectable)]
|
|
#[diesel(table_name = csrf_tokens)]
|
|
#[diesel(check_for_backend(Pg))]
|
|
pub struct CsrfToken {
|
|
pub id: Uuid,
|
|
pub user_id: Option<Uuid>,
|
|
pub created_at: DateTime<Utc>,
|
|
}
|
|
|
|
impl CsrfToken {
|
|
fn all() -> Select<table, AsSelect<CsrfToken, Pg>> {
|
|
table.select(Self::as_select())
|
|
}
|
|
|
|
pub fn is_not_expired() -> Gt<dsl::created_at, DateTime<Utc>> {
|
|
let ttl = TimeDelta::hours(24);
|
|
let min_created_at: DateTime<Utc> = Utc::now() - ttl;
|
|
dsl::created_at.gt(min_created_at)
|
|
}
|
|
|
|
#[auto_type(no_type_alias)]
|
|
pub fn with_user_id<'a>(token_user_id: &'a Option<Uuid>) -> _ {
|
|
dsl::user_id.is_not_distinct_from(token_user_id)
|
|
}
|
|
|
|
#[auto_type(no_type_alias)]
|
|
pub fn with_token_id<'a>(token_id: &'a Uuid) -> _ {
|
|
dsl::id.eq(token_id)
|
|
}
|
|
}
|
|
|
|
/// Convenience function for creating new CSRF token rows in the database.
|
|
pub async fn generate_csrf_token(
|
|
db_conn: &Connection,
|
|
with_user_id: Option<Uuid>,
|
|
) -> Result<String, AppError> {
|
|
let token_id = Uuid::new_v4();
|
|
db_conn
|
|
.interact(move |conn| {
|
|
diesel::insert_into(table)
|
|
.values((
|
|
dsl::id.eq(token_id),
|
|
dsl::user_id.eq(with_user_id),
|
|
dsl::created_at.eq(diesel::dsl::now),
|
|
))
|
|
.execute(conn)
|
|
})
|
|
.await
|
|
.unwrap()?;
|
|
Ok(format!("{}{}", TOKEN_PREFIX, token_id.simple()))
|
|
}
|
|
|
|
/// Convenience function for validating CSRF tokens against the database.
|
|
pub async fn validate_csrf_token(
|
|
db_conn: &Connection,
|
|
token: &str,
|
|
with_user_id: Option<Uuid>,
|
|
) -> Result<bool, AppError> {
|
|
let token_id = match Uuid::try_parse(&token[TOKEN_PREFIX.len()..]) {
|
|
Ok(token_id) => token_id,
|
|
Err(_) => return Ok(false),
|
|
};
|
|
Ok(db_conn
|
|
.interact(move |conn| {
|
|
CsrfToken::all()
|
|
.filter(CsrfToken::with_token_id(&token_id))
|
|
.filter(CsrfToken::with_user_id(&with_user_id))
|
|
.filter(CsrfToken::is_not_expired())
|
|
.first(conn)
|
|
.optional()
|
|
})
|
|
.await
|
|
.unwrap()?
|
|
.is_some())
|
|
}
|