From 47bb893d3ffdec48461466e79b0442a316709c5c Mon Sep 17 00:00:00 2001 From: Brent Schroeter Date: Tue, 28 Jan 2025 18:01:43 -0800 Subject: [PATCH] misc back-end cleanup --- Cargo.lock | 12 ++ Cargo.toml | 2 +- migrations/2024-11-25-232658_init/up.sql | 4 +- migrations/2025-01-08-211839_sessions/up.sql | 6 +- src/api_keys.rs | 34 ++-- src/app_state.rs | 20 +- src/csrf.rs | 96 ++++++---- src/guards.rs | 45 +++++ src/main.rs | 9 +- src/models.rs | 41 ---- src/projects.rs | 19 +- src/router.rs | 190 +++++-------------- src/schema.rs | 4 +- src/sessions.rs | 73 ++++--- src/settings.rs | 9 + src/team_memberships.rs | 43 +++++ src/teams.rs | 22 +++ src/users.rs | 38 ++-- 18 files changed, 366 insertions(+), 301 deletions(-) create mode 100644 src/guards.rs create mode 100644 src/team_memberships.rs create mode 100644 src/teams.rs diff --git a/Cargo.lock b/Cargo.lock index 4743cd8..ef6c906 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -227,6 +227,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" dependencies = [ "axum-core 0.5.0", + "axum-macros", "bytes", "form_urlencoded", "futures-util", @@ -317,6 +318,17 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "backtrace" version = "0.3.74" diff --git a/Cargo.toml b/Cargo.toml index 52d65bd..3c69f3c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ tracing-subscriber = { version = "0.3.19", features = ["chrono", "env-filter"] } tower-http = { version = "0.6.2", features = ["compression-br", "compression-gzip", "fs", "trace", "tracing"] } tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "tracing"] } deadpool-diesel = { version = "0.6.1", features = ["postgres", "serde"] } -axum = "0.8.1" +axum = { version = "0.8.1", features = ["macros"] } axum-extra = { version = "0.10.0", features = ["cookie", "typed-header"] } chrono = { version = "0.4.39", features = ["serde"] } base64 = "0.22.1" diff --git a/migrations/2024-11-25-232658_init/up.sql b/migrations/2024-11-25-232658_init/up.sql index db22957..8a92256 100644 --- a/migrations/2024-11-25-232658_init/up.sql +++ b/migrations/2024-11-25-232658_init/up.sql @@ -8,9 +8,9 @@ CREATE INDEX ON users (uid); CREATE TABLE IF NOT EXISTS csrf_tokens ( id UUID NOT NULL PRIMARY KEY, user_id UUID REFERENCES users(id), - expires_at TIMESTAMPTZ NOT NULL + created_at TIMESTAMPTZ NOT NULL ); -CREATE INDEX ON csrf_tokens (expires_at); +CREATE INDEX ON csrf_tokens (created_at); CREATE TABLE teams ( id UUID NOT NULL PRIMARY KEY, diff --git a/migrations/2025-01-08-211839_sessions/up.sql b/migrations/2025-01-08-211839_sessions/up.sql index 7ff25bb..f3a2db6 100644 --- a/migrations/2025-01-08-211839_sessions/up.sql +++ b/migrations/2025-01-08-211839_sessions/up.sql @@ -1,4 +1,8 @@ CREATE TABLE browser_sessions ( id TEXT NOT NULL PRIMARY KEY, - serialized TEXT NOT NULL + serialized TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + last_seen_at TIMESTAMPTZ NOT NULL ); +CREATE INDEX ON browser_sessions (last_seen_at); +CREATE INDEX ON browser_sessions (created_at); diff --git a/src/api_keys.rs b/src/api_keys.rs index 3a11e68..98736b1 100644 --- a/src/api_keys.rs +++ b/src/api_keys.rs @@ -1,12 +1,11 @@ -use anyhow::Result; -use deadpool_diesel::postgres::Pool; +use deadpool_diesel::postgres::Connection; use diesel::prelude::*; use uuid::Uuid; -use crate::{app_error::AppError, models::Team, schema}; +use crate::{app_error::AppError, schema::api_keys::dsl::*, teams::Team}; -#[derive(Associations, Clone, Debug, Identifiable, Queryable, Selectable)] -#[diesel(table_name = schema::api_keys)] +#[derive(Associations, Clone, Debug, Identifiable, Insertable, Queryable, Selectable)] +#[diesel(table_name = crate::schema::api_keys)] #[diesel(belongs_to(Team))] pub struct ApiKey { pub id: Uuid, @@ -14,19 +13,20 @@ pub struct ApiKey { } impl ApiKey { - pub async fn generate_for_team(db_pool: &Pool, team_id: Uuid) -> Result { - let id = Uuid::new_v4(); - let api_key = db_pool - .get() - .await? + pub async fn generate_for_team( + db_conn: &Connection, + key_team_id: Uuid, + ) -> Result { + let api_key = Self { + id: Uuid::new_v4(), + team_id: key_team_id, + }; + let api_key_copy = api_key.clone(); + db_conn .interact(move |conn| { - diesel::insert_into(schema::api_keys::table) - .values(( - schema::api_keys::id.eq(id), - schema::api_keys::team_id.eq(team_id), - )) - .returning(ApiKey::as_select()) - .get_result(conn) + diesel::insert_into(api_keys) + .values(api_key_copy) + .execute(conn) }) .await .unwrap()?; diff --git a/src/app_state.rs b/src/app_state.rs index 76642b5..5dd1fdf 100644 --- a/src/app_state.rs +++ b/src/app_state.rs @@ -1,8 +1,11 @@ -use axum::extract::FromRef; -use deadpool_diesel::postgres::Pool; +use axum::{ + extract::{FromRef, FromRequestParts}, + http::request::Parts, +}; +use deadpool_diesel::postgres::{Connection, Pool}; use oauth2::basic::BasicClient; -use crate::{sessions::PgStore, settings::Settings}; +use crate::{app_error::AppError, sessions::PgStore, settings::Settings}; #[derive(Clone)] pub(crate) struct AppState { @@ -17,3 +20,14 @@ impl FromRef for PgStore { state.session_store.clone() } } + +pub struct DbConn(pub Connection); + +impl FromRequestParts for DbConn { + type Rejection = AppError; + + async fn from_request_parts(_: &mut Parts, state: &AppState) -> Result { + let conn = state.db_pool.get().await?; + Ok(Self(conn)) + } +} diff --git a/src/csrf.rs b/src/csrf.rs index 730ee9b..1a1277b 100644 --- a/src/csrf.rs +++ b/src/csrf.rs @@ -1,60 +1,84 @@ -use anyhow::{Context, Result}; -use chrono::{TimeDelta, Utc}; -use deadpool_diesel::postgres::Pool; -use diesel::prelude::*; +use chrono::{DateTime, TimeDelta, Utc}; +use deadpool_diesel::postgres::Connection; +use diesel::{ + dsl::{AsSelect, Eq, Gt, IsNotDistinctFrom, Select}, + pg::Pg, + prelude::*, +}; use uuid::Uuid; -use crate::{app_error::AppError, models::CsrfToken, schema}; +use crate::{app_error::AppError, schema::csrf_tokens::dsl::*}; -const TOKEN_PREFIX: &'static str = "csrf__"; -const TTL_SEC: i64 = 60 * 60 * 24 * 7; +const TOKEN_PREFIX: &'static str = "csrf-"; -pub async fn generate_csrf_token_for_user( - db_pool: &Pool, - uid: Option, +#[derive(Clone, Debug, Identifiable, Queryable, Selectable)] +#[diesel(table_name = crate::schema::csrf_tokens)] +#[diesel(check_for_backend(Pg))] +pub struct CsrfToken { + pub id: Uuid, + pub user_id: Option, + pub created_at: DateTime, +} + +impl CsrfToken { + fn all() -> Select> { + csrf_tokens.select(Self::as_select()) + } + + pub fn is_not_expired() -> Gt> { + let ttl = TimeDelta::hours(24); + let min_created_at: DateTime = Utc::now() - ttl; + created_at.gt(min_created_at) + } + + pub fn with_user_id(token_user_id: Option) -> IsNotDistinctFrom> { + user_id.is_not_distinct_from(token_user_id) + } + + pub fn with_token_id(token_id: Uuid) -> Eq { + id.eq(token_id) + } +} + +pub async fn generate_csrf_token( + db_conn: &Connection, + with_user_id: Option, ) -> Result { - let id = Uuid::new_v4(); - let expires_at = - Utc::now() + TimeDelta::new(TTL_SEC, 0).context("Failed to generate TimeDelta")?; - db_pool - .get() - .await? + let token_id = Uuid::new_v4(); + db_conn .interact(move |conn| { - diesel::insert_into(schema::csrf_tokens::table) + diesel::insert_into(csrf_tokens) .values(( - schema::csrf_tokens::id.eq(id), - schema::csrf_tokens::user_id.eq(uid), - schema::csrf_tokens::expires_at.eq(expires_at), + id.eq(token_id), + user_id.eq(with_user_id), + created_at.eq(diesel::dsl::now), )) .execute(conn) }) .await .unwrap()?; - Ok(format!("{}{}", TOKEN_PREFIX, id.hyphenated().to_string())) + Ok(format!("{}{}", TOKEN_PREFIX, token_id.simple().to_string())) } -pub async fn validate_csrf_token_for_user( - db_pool: &Pool, +pub async fn validate_csrf_token( + db_conn: &Connection, token: &str, - uid: Option, + with_user_id: Option, ) -> Result { - let id = match Uuid::try_parse(&token[TOKEN_PREFIX.len()..]) { - Ok(id) => id, + let token_id = match Uuid::try_parse(&token[TOKEN_PREFIX.len()..]) { + Ok(token_id) => token_id, Err(_) => return Ok(false), }; - let row = db_pool - .get() - .await? + Ok(db_conn .interact(move |conn| { - schema::csrf_tokens::table - .select(CsrfToken::as_select()) - .filter(schema::csrf_tokens::id.eq(id)) - .filter(schema::csrf_tokens::expires_at.gt(Utc::now())) - .filter(schema::csrf_tokens::user_id.is_not_distinct_from(uid)) + CsrfToken::all() + .filter(CsrfToken::with_token_id(token_id)) + .filter(CsrfToken::with_user_id(with_user_id)) + .filter(CsrfToken::is_not_expired()) .first(conn) .optional() }) .await - .unwrap()?; - Ok(row.is_some()) + .unwrap()? + .is_some()) } diff --git a/src/guards.rs b/src/guards.rs new file mode 100644 index 0000000..3252f88 --- /dev/null +++ b/src/guards.rs @@ -0,0 +1,45 @@ +macro_rules! require_team_membership { + ($current_user:expr, $team_id:expr, $db_conn:expr) => {{ + let current_user_id = $current_user.id.clone(); + match $db_conn + .interact(move |conn| { + crate::team_memberships::TeamMembership::all() + .filter(crate::team_memberships::TeamMembership::with_user_id( + current_user_id, + )) + .filter(crate::team_memberships::TeamMembership::with_team_id( + $team_id, + )) + .first(conn) + .optional() + }) + .await + .unwrap()? + { + Some((team, _)) => team, + None => { + return Ok(( + axum::http::StatusCode::FORBIDDEN, + "not a member of requested team".to_string(), + ) + .into_response()); + } + } + }}; +} +pub(crate) use require_team_membership; + +macro_rules! require_valid_csrf_token { + ($csrf_token:expr, $current_user:expr, $db_conn:expr) => {{ + if !crate::csrf::validate_csrf_token(&$db_conn, &$csrf_token, Some($current_user.id)) + .await? + { + return Ok(( + axum::http::StatusCode::FORBIDDEN, + "invalid CSRF token".to_string(), + ) + .into_response()); + } + }}; +} +pub(crate) use require_valid_csrf_token; diff --git a/src/main.rs b/src/main.rs index e357556..dbd5545 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,20 +3,19 @@ mod app_error; mod app_state; mod auth; mod csrf; -mod models; +mod guards; mod projects; mod router; mod schema; mod sessions; mod settings; +mod team_memberships; +mod teams; mod users; -use app_state::AppState; -use router::new_router; -use sessions::PgStore; use tracing_subscriber::EnvFilter; -use crate::settings::Settings; +use crate::{app_state::AppState, router::new_router, sessions::PgStore, settings::Settings}; #[tokio::main] async fn main() { diff --git a/src/models.rs b/src/models.rs index d5526a2..8b13789 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,42 +1 @@ -use chrono::{offset::Utc, DateTime}; -use diesel::{pg::Pg, prelude::*}; -use uuid::Uuid; -use crate::schema; - -#[derive(Clone, Debug, Identifiable, Insertable, Queryable, Selectable)] -#[diesel(table_name = schema::teams)] -#[diesel(check_for_backend(Pg))] -pub struct Team { - pub id: Uuid, - pub name: String, -} - -#[derive(Associations, Clone, Debug, Identifiable, Insertable, Queryable, Selectable)] -#[diesel(table_name = schema::team_memberships)] -#[diesel(belongs_to(Team))] -#[diesel(belongs_to(crate::users::User))] -#[diesel(primary_key(team_id, user_id))] -#[diesel(check_for_backend(Pg))] -pub struct TeamMembership { - pub team_id: Uuid, - pub user_id: Uuid, - pub roles: Vec>, -} - -#[derive(Clone, Debug, Identifiable, Queryable, Selectable)] -#[diesel(table_name = schema::browser_sessions)] -#[diesel(check_for_backend(Pg))] -pub struct BrowserSession { - pub id: String, - pub serialized: String, -} - -#[derive(Clone, Debug, Identifiable, Queryable, Selectable)] -#[diesel(table_name = schema::csrf_tokens)] -#[diesel(check_for_backend(Pg))] -pub struct CsrfToken { - pub id: Uuid, - pub user_id: Option, - pub expires_at: DateTime, -} diff --git a/src/projects.rs b/src/projects.rs index 4178539..0c29710 100644 --- a/src/projects.rs +++ b/src/projects.rs @@ -1,7 +1,14 @@ -use diesel::prelude::*; +use diesel::{ + dsl::{auto_type, AsSelect}, + pg::Pg, + prelude::*, +}; use uuid::Uuid; -use crate::{models::Team, schema}; +use crate::{ + schema::{self, projects::dsl::*}, + teams::Team, +}; #[derive(Associations, Clone, Debug, Identifiable, Insertable, Queryable, Selectable)] #[diesel(table_name = schema::projects)] @@ -11,3 +18,11 @@ pub struct Project { pub team_id: Uuid, pub name: String, } + +impl Project { + #[auto_type(no_type_alias)] + pub fn all() -> _ { + let select: AsSelect = Project::as_select(); + projects.select(select) + } +} diff --git a/src/router.rs b/src/router.rs index 2e37bab..22099cf 100644 --- a/src/router.rs +++ b/src/router.rs @@ -1,13 +1,11 @@ -use anyhow::anyhow; use askama_axum::Template; use axum::{ extract::{Path, State}, - http::status::StatusCode, response::{Html, IntoResponse, Redirect}, routing::{get, post}, Form, Router, }; -use diesel::{dsl::insert_into, prelude::*, result::Error::NotFound}; +use diesel::{dsl::insert_into, prelude::*}; use serde::Deserialize; use tower::ServiceBuilder; use tower_http::{ @@ -20,12 +18,15 @@ use uuid::Uuid; use crate::{ api_keys::ApiKey, app_error::AppError, - app_state::AppState, - auth::{self, AuthInfo}, - csrf::{generate_csrf_token_for_user, validate_csrf_token_for_user}, - models::{Team, TeamMembership}, + app_state::{AppState, DbConn}, + auth, + csrf::generate_csrf_token, + guards, projects::Project, schema, + settings::Settings, + team_memberships::TeamMembership, + teams::Team, users::{CurrentUser, User}, }; @@ -59,14 +60,12 @@ async fn landing_page(State(state): State) -> impl IntoResponse { } async fn teams_page( - State(state): State, + State(Settings { base_path, .. }): State, + DbConn(conn): DbConn, CurrentUser(current_user): CurrentUser, ) -> Result { let current_user_id = current_user.id.clone(); - let teams_of_current_user = state - .db_pool - .get() - .await? + let teams_of_current_user = conn .interact(move |conn| { schema::team_memberships::table .inner_join(schema::teams::table) @@ -75,8 +74,7 @@ async fn teams_page( .load(conn) }) .await - .unwrap() - .unwrap(); + .unwrap()?; #[derive(Template)] #[template(path = "teams.html")] struct ResponseTemplate { @@ -86,13 +84,12 @@ async fn teams_page( } Ok(Html( ResponseTemplate { + base_path, current_user, - base_path: state.settings.base_path, teams: teams_of_current_user, } .render()?, - ) - .into_response()) + )) } async fn team_page(State(state): State, Path(team_id): Path) -> impl IntoResponse { @@ -108,76 +105,30 @@ struct PostNewApiKeyForm { } async fn post_new_api_key( - State(state): State, + State(Settings { base_path, .. }): State, + DbConn(db_conn): DbConn, Path(team_id): Path, - user_info: AuthInfo, + CurrentUser(current_user): CurrentUser, Form(form): Form, ) -> Result { - let current_uid = user_info.sub.clone(); - let current_user = state - .db_pool - .get() - .await? - .interact(move |conn| { - schema::users::table - .filter(schema::users::uid.eq(current_uid)) - .select(User::as_select()) - .first(conn) - }) - .await - .unwrap() // Bubble up panic from callback - .map_err(|diesel_err| match diesel_err { - NotFound => AppError::InternalServerError(anyhow!( - "user not found in database for uid {}", - user_info.sub - )), - _ => AppError::InternalServerError(diesel_err.into()), - })?; - if !validate_csrf_token_for_user(&state.db_pool, &form.csrf_token, Some(current_user.id)) - .await? - { - return Ok((StatusCode::FORBIDDEN, "invalid CSRF token".to_string()).into_response()); - } - let team_membership = match state - .db_pool - .get() - .await? - .interact(move |conn| { - schema::team_memberships::table - .filter(schema::team_memberships::team_id.eq(team_id)) - .filter(schema::team_memberships::user_id.eq(current_user.id)) - .select(TeamMembership::as_select()) - .first(conn) - .optional() - }) - .await - .unwrap() - .unwrap() - { - Some(team_membership) => team_membership, - None => { - return Ok(( - StatusCode::FORBIDDEN, - "not a member of requested team".to_string(), - ) - .into_response()); - } - }; - ApiKey::generate_for_team(&state.db_pool, team_membership.team_id.clone()).await?; + guards::require_valid_csrf_token!(form.csrf_token, current_user, db_conn); + let team = guards::require_team_membership!(current_user, team_id, db_conn); + + ApiKey::generate_for_team(&db_conn, team.id.clone()).await?; Ok(Redirect::to(&format!( "{}/teams/{}/projects", - state.settings.base_path, - team_membership.team_id.hyphenated().to_string() + base_path, + team.id.hyphenated().to_string() )) .into_response()) } async fn new_team_page( - State(state): State, + State(Settings { base_path, .. }): State, + DbConn(db_conn): DbConn, CurrentUser(current_user): CurrentUser, ) -> Result { - let csrf_token = - generate_csrf_token_for_user(&state.db_pool, Some(current_user.id.clone())).await?; + let csrf_token = generate_csrf_token(&db_conn, Some(current_user.id)).await?; #[derive(Template)] #[template(path = "new-team.html")] struct ResponseTemplate { @@ -187,9 +138,9 @@ async fn new_team_page( } Ok(Html( ResponseTemplate { + base_path, csrf_token, current_user, - base_path: state.settings.base_path, } .render()?, )) @@ -202,33 +153,13 @@ struct PostNewTeamForm { } async fn post_new_team( - State(state): State, - user_info: AuthInfo, + DbConn(db_conn): DbConn, + State(Settings { base_path, .. }): State, + CurrentUser(current_user): CurrentUser, Form(form): Form, ) -> Result { - let current_uid = user_info.sub.clone(); - let current_user = state - .db_pool - .get() - .await? - .interact(move |conn| { - schema::users::table - .filter(schema::users::uid.eq(current_uid)) - .select(User::as_select()) - .first(conn) - }) - .await - .unwrap() - .unwrap(); - if !validate_csrf_token_for_user( - &state.db_pool, - &form.csrf_token, - Some(current_user.id.clone()), - ) - .await? - { - return Err(anyhow!("Invalid CSRF token").into()); - } + guards::require_valid_csrf_token!(form.csrf_token, current_user, db_conn); + let team_id = Uuid::now_v7(); let team = Team { id: team_id.clone(), @@ -239,10 +170,7 @@ async fn post_new_team( user_id: current_user.id, roles: vec![Some("OWNER".to_string())], }; - state - .db_pool - .get() - .await? + db_conn .interact(move |conn| { conn.transaction(move |conn| { insert_into(schema::teams::table) @@ -257,50 +185,20 @@ async fn post_new_team( .await .unwrap() .unwrap(); - ApiKey::generate_for_team(&state.db_pool, team_id.clone()).await?; - Ok(Redirect::to(&format!( - "{}/teams/{}/projects", - state.settings.base_path, team_id - ))) + ApiKey::generate_for_team(&db_conn, team_id.clone()).await?; + Ok(Redirect::to(&format!("{}/teams/{}/projects", base_path, team_id)).into_response()) } async fn projects_page( - State(state): State, + State(Settings { base_path, .. }): State, + DbConn(db_conn): DbConn, Path(team_id): Path, CurrentUser(current_user): CurrentUser, ) -> Result { - let current_user_id = current_user.id.clone(); - let team = match state - .db_pool - .get() - .await? - .interact(move |conn| { - schema::team_memberships::table - .inner_join(schema::teams::table) - .filter(schema::team_memberships::user_id.eq(current_user_id)) - .filter(schema::teams::id.eq(team_id)) - .select(Team::as_select()) - .first(conn) - .optional() - }) - .await - .unwrap() - .unwrap() - { - Some(team) => team, - None => { - return Ok(( - StatusCode::FORBIDDEN, - "not a member of requested team".to_string(), - ) - .into_response()); - } - }; + let team = guards::require_team_membership!(current_user, team_id, db_conn); + let team_id = team.id.clone(); - let api_keys = state - .db_pool - .get() - .await? + let api_keys = db_conn .interact(move |conn| { schema::api_keys::table .filter(schema::api_keys::team_id.eq(team_id)) @@ -308,8 +206,7 @@ async fn projects_page( .load(conn) }) .await - .unwrap() - .unwrap(); + .unwrap()?; #[derive(Template)] #[template(path = "projects.html")] struct ResponseTemplate { @@ -320,14 +217,13 @@ async fn projects_page( team: Team, current_user: User, } - let csrf_token = - generate_csrf_token_for_user(&state.db_pool, Some(current_user.id.clone())).await?; + let csrf_token = generate_csrf_token(&db_conn, Some(current_user.id.clone())).await?; Ok(Html( ResponseTemplate { + base_path, csrf_token, current_user, team, - base_path: state.settings.base_path, keys: api_keys, projects: vec![], } diff --git a/src/schema.rs b/src/schema.rs index ab042aa..57b0f71 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -11,6 +11,8 @@ diesel::table! { browser_sessions (id) { id -> Text, serialized -> Text, + created_at -> Timestamptz, + last_seen_at -> Timestamptz, } } @@ -18,7 +20,7 @@ diesel::table! { csrf_tokens (id) { id -> Uuid, user_id -> Nullable, - expires_at -> Timestamptz, + created_at -> Timestamptz, } } diff --git a/src/sessions.rs b/src/sessions.rs index c1ebee6..aa19379 100644 --- a/src/sessions.rs +++ b/src/sessions.rs @@ -1,12 +1,21 @@ use anyhow::Result; use async_session::{async_trait, Session, SessionStore}; -use diesel::prelude::*; +use chrono::{DateTime, TimeDelta, Utc}; +use diesel::{pg::Pg, prelude::*, upsert::excluded}; -use crate::{models::BrowserSession, schema}; +use crate::schema::browser_sessions::dsl::*; + +#[derive(Clone, Debug, Identifiable, Queryable, Selectable)] +#[diesel(table_name = crate::schema::browser_sessions)] +#[diesel(check_for_backend(Pg))] +pub struct BrowserSession { + pub id: String, + pub serialized: String, + pub last_seen_at: DateTime, +} #[derive(Clone)] pub struct PgStore { - // TODO: reference instead of clone pool: deadpool_diesel::postgres::Pool, } @@ -18,7 +27,7 @@ impl PgStore { impl std::fmt::Debug for PgStore { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "PgStore {{ pool }}")?; + write!(f, "PgStore")?; Ok(()).into() } } @@ -26,39 +35,47 @@ impl std::fmt::Debug for PgStore { #[async_trait] impl SessionStore for PgStore { async fn load_session(&self, cookie_value: String) -> Result> { - let conn = self.pool.get().await?; let session_id = Session::id_from_cookie_value(&cookie_value)?; - let rows = conn + let timestamp_stale = Utc::now() - TimeDelta::days(7); + let conn = self.pool.get().await?; + let row = conn .interact(move |conn| { - schema::browser_sessions::table - .filter(schema::browser_sessions::id.eq(session_id)) - .select(BrowserSession::as_select()) - .load(conn) + // Drop all sessions without recent activity + diesel::delete(browser_sessions.filter(last_seen_at.lt(timestamp_stale))) + .execute(conn)?; + diesel::update(browser_sessions.filter(id.eq(session_id))) + .set(last_seen_at.eq(diesel::dsl::now)) + .returning(BrowserSession::as_returning()) + .get_result(conn) + .optional() }) .await .unwrap()?; - if rows.len() == 0 { - Ok(None) - } else { - Ok(Some(serde_json::from_str::( - rows[0].serialized.as_str(), - )?)) - } + Ok(match row { + Some(session) => Some(serde_json::from_str::( + session.serialized.as_str(), + )?), + None => None, + }) } async fn store_session(&self, session: Session) -> Result> { - let serialized = serde_json::to_string(&session)?; - let conn = self.pool.get().await?; + let serialized_data = serde_json::to_string(&session)?; let session_id = session.id().to_string(); + let conn = self.pool.get().await?; conn.interact(move |conn| { - diesel::insert_into(schema::browser_sessions::table) + diesel::insert_into(browser_sessions) .values(( - schema::browser_sessions::id.eq(session_id), - schema::browser_sessions::serialized.eq(serialized.clone()), + id.eq(session_id), + serialized.eq(serialized_data), + last_seen_at.eq(diesel::dsl::now), )) - .on_conflict(schema::browser_sessions::id) + .on_conflict(id) .do_update() - .set(schema::browser_sessions::serialized.eq(serialized.clone())) + .set(( + serialized.eq(excluded(serialized)), + last_seen_at.eq(excluded(last_seen_at)), + )) .execute(conn) }) .await @@ -70,11 +87,7 @@ impl SessionStore for PgStore { async fn destroy_session(&self, session: Session) -> Result<()> { let conn = self.pool.get().await?; conn.interact(move |conn| { - diesel::delete( - schema::browser_sessions::table - .filter(schema::browser_sessions::id.eq(session.id().to_string())), - ) - .execute(conn) + diesel::delete(browser_sessions.filter(id.eq(session.id().to_string()))).execute(conn) }) .await .unwrap()?; @@ -83,7 +96,7 @@ impl SessionStore for PgStore { async fn clear_store(&self) -> Result<()> { let conn = self.pool.get().await?; - conn.interact(move |conn| diesel::delete(schema::browser_sessions::table).execute(conn)) + conn.interact(move |conn| diesel::delete(browser_sessions).execute(conn)) .await .unwrap()?; Ok(()) diff --git a/src/settings.rs b/src/settings.rs index 805caf7..2e6de90 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -1,7 +1,10 @@ +use axum::extract::FromRef; use config::{Config, ConfigError, Environment}; use dotenvy::dotenv; use serde::Deserialize; +use crate::app_state::AppState; + #[derive(Clone, Debug, Deserialize)] pub struct Settings { #[serde(default)] @@ -54,3 +57,9 @@ impl Settings { s.try_deserialize() } } + +impl FromRef for Settings { + fn from_ref(state: &AppState) -> Self { + state.settings.clone() + } +} diff --git a/src/team_memberships.rs b/src/team_memberships.rs new file mode 100644 index 0000000..d6392d5 --- /dev/null +++ b/src/team_memberships.rs @@ -0,0 +1,43 @@ +use diesel::{ + dsl::{AsSelect, Eq}, + pg::Pg, + prelude::*, +}; +use uuid::Uuid; + +use crate::{ + schema::{self, team_memberships::dsl::*}, + teams::Team, + users::User, +}; + +#[derive(Associations, Clone, Debug, Identifiable, Insertable, Queryable, Selectable)] +#[diesel(table_name = schema::team_memberships)] +#[diesel(belongs_to(crate::teams::Team))] +#[diesel(belongs_to(crate::users::User))] +#[diesel(primary_key(team_id, user_id))] +#[diesel(check_for_backend(Pg))] +pub struct TeamMembership { + pub team_id: Uuid, + pub user_id: Uuid, + pub roles: Vec>, +} + +impl TeamMembership { + #[diesel::dsl::auto_type(no_type_alias)] + pub fn all() -> _ { + let select: AsSelect<(Team, User), Pg> = <(Team, User)>::as_select(); + team_memberships + .inner_join(schema::teams::table) + .inner_join(schema::users::table) + .select(select) + } + + pub fn with_team_id(team_id_value: Uuid) -> Eq { + team_id.eq(team_id_value) + } + + pub fn with_user_id(user_id_value: Uuid) -> Eq { + user_id.eq(user_id_value) + } +} diff --git a/src/teams.rs b/src/teams.rs new file mode 100644 index 0000000..32423fb --- /dev/null +++ b/src/teams.rs @@ -0,0 +1,22 @@ +use diesel::{ + dsl::{AsSelect, Select}, + pg::Pg, + prelude::*, +}; +use uuid::Uuid; + +use crate::schema::teams::dsl::*; + +#[derive(Clone, Debug, Identifiable, Insertable, Queryable, Selectable)] +#[diesel(table_name = crate::schema::teams)] +#[diesel(check_for_backend(Pg))] +pub struct Team { + pub id: Uuid, + pub name: String, +} + +impl Team { + pub fn all() -> Select> { + teams.select(Team::as_select()) + } +} diff --git a/src/users.rs b/src/users.rs index 9b714ac..f34149c 100644 --- a/src/users.rs +++ b/src/users.rs @@ -5,20 +5,19 @@ use axum::{ RequestPartsExt, }; use diesel::{ - associations::Identifiable, deserialize::Queryable, dsl::insert_into, pg::Pg, prelude::*, + associations::Identifiable, + deserialize::Queryable, + dsl::{insert_into, AsSelect, Eq, Select}, + pg::Pg, + prelude::*, Selectable, }; use uuid::Uuid; -use crate::{ - app_error::AppError, - app_state::AppState, - auth::AuthInfo, - schema::{self, users}, -}; +use crate::{app_error::AppError, app_state::AppState, auth::AuthInfo, schema::users::dsl::*}; #[derive(Clone, Debug, Identifiable, Insertable, Queryable, Selectable)] -#[diesel(table_name = schema::users)] +#[diesel(table_name = crate::schema::users)] #[diesel(check_for_backend(Pg))] pub struct User { pub id: Uuid, @@ -26,6 +25,16 @@ pub struct User { pub email: String, } +impl User { + pub fn all() -> Select> { + users.select(User::as_select()) + } + + pub fn with_uid(uid_value: &str) -> Eq { + uid.eq(uid_value) + } +} + #[derive(Clone, Debug)] pub struct CurrentUser(pub User); @@ -46,9 +55,8 @@ impl FromRequestParts for CurrentUser { .await .map_err(|err| CurrentUserRejection::InternalServerError(err.into()))? .interact(move |conn| { - let maybe_current_user = users::table - .filter(users::uid.eq(auth_info.sub.clone())) - .select(User::as_select()) + let maybe_current_user = User::all() + .filter(User::with_uid(&auth_info.sub)) .first(conn) .optional()?; if let Some(current_user) = maybe_current_user { @@ -59,11 +67,11 @@ impl FromRequestParts for CurrentUser { uid: auth_info.sub, email: auth_info.email, }; - insert_into(users::table) - .values(&new_user) - .returning(User::as_returning()) - .on_conflict(users::uid) + insert_into(users) + .values(new_user) + .on_conflict(uid) .do_nothing() + .returning(User::as_returning()) .get_result(conn) }) .await