use axum::{ extract::FromRequestParts, http::request::Parts, response::{IntoResponse, Redirect, Response}, RequestPartsExt, }; use diesel::{ associations::Identifiable, deserialize::Queryable, dsl::{insert_into, AsSelect, Eq, Select}, pg::Pg, prelude::*, Selectable, }; use uuid::Uuid; use crate::{app_error::AppError, app_state::AppState, auth::AuthInfo, schema::users::dsl::*}; #[derive(Clone, Debug, Identifiable, Insertable, Queryable, Selectable)] #[diesel(table_name = crate::schema::users)] #[diesel(check_for_backend(Pg))] pub struct User { pub id: Uuid, pub uid: String, pub email: String, } impl User { pub fn all() -> Select> { users.select(User::as_select()) } pub fn with_uid(uid_value: &str) -> Eq { uid.eq(uid_value) } } #[derive(Clone, Debug)] pub struct CurrentUser(pub User); impl FromRequestParts for CurrentUser { type Rejection = CurrentUserRejection; async fn from_request_parts( parts: &mut Parts, state: &AppState, ) -> Result>::Rejection> { let auth_info = parts .extract_with_state::(state) .await .map_err(|_| CurrentUserRejection::AuthRequired(state.settings.base_path.clone()))?; let current_user = state .db_pool .get() .await .map_err(|err| CurrentUserRejection::InternalServerError(err.into()))? .interact(move |conn| { let maybe_current_user = User::all() .filter(User::with_uid(&auth_info.sub)) .first(conn) .optional()?; if let Some(current_user) = maybe_current_user { return Ok(current_user); } let new_user = User { id: Uuid::now_v7(), uid: auth_info.sub, email: auth_info.email, }; insert_into(users) .values(new_user) .on_conflict(uid) .do_nothing() .returning(User::as_returning()) .get_result(conn) }) .await .unwrap() .map_err(|err| CurrentUserRejection::InternalServerError(err.into()))?; Ok(CurrentUser(current_user)) } } pub enum CurrentUserRejection { AuthRequired(String), InternalServerError(AppError), } impl IntoResponse for CurrentUserRejection { fn into_response(self) -> Response { match self { Self::AuthRequired(base_path) => { Redirect::to(&format!("{}/auth/login", base_path)).into_response() } Self::InternalServerError(err) => err.into_response(), } } }