Compare commits

..

4 commits

8 changed files with 105 additions and 104 deletions

41
Cargo.lock generated
View file

@ -38,21 +38,6 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "alloc-no-stdlib"
version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3"
[[package]]
name = "alloc-stdlib"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece"
dependencies = [
"alloc-no-stdlib",
]
[[package]] [[package]]
name = "allocator-api2" name = "allocator-api2"
version = "0.2.18" version = "0.2.18"
@ -205,11 +190,10 @@ dependencies = [
[[package]] [[package]]
name = "async-compression" name = "async-compression"
version = "0.4.18" version = "0.4.20"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df895a515f70646414f4b45c0b79082783b80552b373a68283012928df56f522" checksum = "310c9bcae737a48ef5cdee3174184e6d548b292739ede61a1f955ef76a738861"
dependencies = [ dependencies = [
"brotli",
"flate2", "flate2",
"futures-core", "futures-core",
"memchr", "memchr",
@ -481,27 +465,6 @@ dependencies = [
"generic-array", "generic-array",
] ]
[[package]]
name = "brotli"
version = "7.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd"
dependencies = [
"alloc-no-stdlib",
"alloc-stdlib",
"brotli-decompressor",
]
[[package]]
name = "brotli-decompressor"
version = "4.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362"
dependencies = [
"alloc-no-stdlib",
"alloc-stdlib",
]
[[package]] [[package]]
name = "bumpalo" name = "bumpalo"
version = "3.16.0" version = "3.16.0"

View file

@ -21,7 +21,7 @@ futures = "0.3.31"
uuid = { version = "1.11.0", features = ["js", "serde", "v4", "v7"] } uuid = { version = "1.11.0", features = ["js", "serde", "v4", "v7"] }
rand = "0.8.5" rand = "0.8.5"
tracing-subscriber = { version = "0.3.19", features = ["chrono", "env-filter"] } tracing-subscriber = { version = "0.3.19", features = ["chrono", "env-filter"] }
tower-http = { version = "0.6.2", features = ["compression-br", "compression-gzip", "fs", "trace", "tracing"] } tower-http = { version = "0.6.2", features = ["compression-gzip", "fs", "normalize-path", "trace"] }
tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "tracing"] } tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "tracing"] }
deadpool-diesel = { version = "0.6.1", features = ["postgres", "serde"] } deadpool-diesel = { version = "0.6.1", features = ["postgres", "serde"] }
axum = { version = "0.8.1", features = ["macros"] } axum = { version = "0.8.1", features = ["macros"] }

View file

@ -1,15 +1,15 @@
RUST_LOG=debug RUST_LOG=debug
DATABASE_URL=postgresql://shoutdotdev:callous@127.0.0.1:5447/shoutdotdev DATABASE_URL=postgresql://shoutdotdev:callous@127.0.0.1:5447/shoutdotdev
AUTH.CLIENT_ID= AUTH__CLIENT_ID=
AUTH.CLIENT_SECRET= AUTH__CLIENT_SECRET=
AUTH.REDIRECT_URL=http://localhost:3000/auth/callback AUTH__REDIRECT_URL=http://localhost:3000/auth/callback
AUTH.AUTH_URL=https://example.com/authorize AUTH__AUTH_URL=https://example.com/authorize
AUTH.TOKEN_URL=https://example.com/token AUTH__TOKEN_URL=https://example.com/token
AUTH.USERINFO_URL=https://example.com/userinfo AUTH__USERINFO_URL=https://example.com/userinfo
# The .env parser (dotenvy) requires quotes around any value with spaces. Note # The .env parser (dotenvy) requires quotes around any value with spaces. Note
# that in this regard it is incompatible with Docker's --env-file parser. # that in this regard it is incompatible with Docker's --env-file parser.
EMAIL.VERIFICATION_FROM=no-reply@shout.dev EMAIL__VERIFICATION_FROM=no-reply@shout.dev
EMAIL.MESSAGE_FROM=no-reply@shout.dev EMAIL__MESSAGE_FROM=no-reply@shout.dev
EMAIL.SMTP.SERVER=smtp.example.com EMAIL__SMTP__SERVER=smtp.example.com
EMAIL.SMTP.USERNAME= EMAIL__SMTP__USERNAME=
EMAIL.SMTP.PASSWORD= EMAIL__SMTP__PASSWORD=

View file

@ -45,12 +45,12 @@ pub fn new_oauth_client(settings: &Settings) -> Result<BasicClient, AppError> {
pub fn new_router() -> Router<AppState> { pub fn new_router() -> Router<AppState> {
Router::new() Router::new()
.route("/login", get(propel_auth)) .route("/login", get(start_login))
.route("/callback", get(login_authorized)) .route("/callback", get(login_authorized))
.route("/logout", get(logout)) .route("/logout", get(logout))
} }
pub async fn propel_auth( pub async fn start_login(
State(state): State<AppState>, State(state): State<AppState>,
State(Settings { State(Settings {
auth: auth_settings, auth: auth_settings,
@ -64,7 +64,7 @@ pub async fn propel_auth(
if let Some(session) = maybe_session { if let Some(session) = maybe_session {
if session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO).is_some() { if session.get::<AuthInfo>(SESSION_KEY_AUTH_INFO).is_some() {
tracing::debug!("already logged in, redirecting..."); tracing::debug!("already logged in, redirecting...");
return Ok(Redirect::to(&base_path).into_response()); return Ok(Redirect::to(&format!("{}/", base_path)).into_response());
} }
} }
let csrf_token = CsrfToken::new_random(); let csrf_token = CsrfToken::new_random();
@ -124,7 +124,7 @@ pub async fn logout(
} }
let jar = jar.remove(Cookie::from(auth_settings.cookie_name)); let jar = jar.remove(Cookie::from(auth_settings.cookie_name));
tracing::debug!("Removed session cookie from jar."); tracing::debug!("Removed session cookie from jar.");
Ok((jar, Redirect::to(&base_path))) Ok((jar, Redirect::to(&format!("{}/", base_path))))
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -167,11 +167,13 @@ pub async fn login_authorized(
"OAuth CSRF tokens do not match.".to_string(), "OAuth CSRF tokens do not match.".to_string(),
)); ));
} }
tracing::debug!("exchanging authorization code");
let response = state let response = state
.oauth_client .oauth_client
.exchange_code(AuthorizationCode::new(query.code.clone())) .exchange_code(AuthorizationCode::new(query.code.clone()))
.request_async(async_http_client) .request_async(async_http_client)
.await?; .await?;
tracing::debug!("fetching user info");
let auth_info: AuthInfo = reqwest_client let auth_info: AuthInfo = reqwest_client
.get(auth_settings.userinfo_url.as_str()) .get(auth_settings.userinfo_url.as_str())
.bearer_auth(response.access_token().secret()) .bearer_auth(response.access_token().secret())
@ -179,6 +181,7 @@ pub async fn login_authorized(
.await? .await?
.json() .json()
.await?; .await?;
tracing::debug!("updating session");
session.insert(SESSION_KEY_AUTH_INFO, &auth_info)?; session.insert(SESSION_KEY_AUTH_INFO, &auth_info)?;
session.insert(SESSION_KEY_AUTH_REFRESH_TOKEN, response.refresh_token())?; session.insert(SESSION_KEY_AUTH_REFRESH_TOKEN, response.refresh_token())?;
if state.session_store.store_session(session).await?.is_some() { if state.session_store.store_session(session).await?.is_some() {
@ -187,7 +190,8 @@ pub async fn login_authorized(
) )
.into()); .into());
} }
Ok(Redirect::to(&base_path)) tracing::debug!("successfully authenticated");
Ok(Redirect::to(&format!("{}/", base_path)))
} }
impl FromRequestParts<AppState> for AuthInfo { impl FromRequestParts<AppState> for AuthInfo {

View file

@ -23,11 +23,16 @@ mod worker;
use std::process::exit; use std::process::exit;
use axum::{extract::Request, middleware::map_request, ServiceExt};
use chrono::{TimeDelta, Utc}; use chrono::{TimeDelta, Utc};
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness}; use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
use email::SmtpOptions; use email::SmtpOptions;
use tokio::time::sleep; use tokio::time::sleep;
use tower::ServiceBuilder;
use tower_http::{
compression::CompressionLayer, normalize_path::NormalizePathLayer, trace::TraceLayer,
};
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
use crate::{ use crate::{
@ -129,7 +134,16 @@ async fn main() {
settings.port, settings.port,
settings.base_path settings.base_path
); );
axum::serve(listener, router).await.unwrap();
let app = ServiceExt::<Request>::into_make_service(
ServiceBuilder::new()
.layer(map_request(lowercase_uri_path))
.layer(TraceLayer::new_for_http())
.layer(CompressionLayer::new())
.layer(NormalizePathLayer::trim_trailing_slash())
.service(router),
);
axum::serve(listener, app).await.unwrap();
} }
Commands::Worker { auto_loop_seconds } => { Commands::Worker { auto_loop_seconds } => {
if let Some(loop_seconds) = auto_loop_seconds { if let Some(loop_seconds) = auto_loop_seconds {
@ -155,3 +169,17 @@ async fn main() {
} }
} }
} }
async fn lowercase_uri_path<B>(mut request: Request<B>) -> Request<B> {
let path = request.uri().path().to_lowercase();
let path_and_query = match request.uri().query() {
Some(query) => format!("{}?{}", path, query),
None => path,
};
let builder =
axum::http::uri::Builder::from(request.uri().clone()).path_and_query(path_and_query);
*request.uri_mut() = builder
.build()
.expect("lowercasing URI path should not break it");
request
}

View file

@ -13,12 +13,7 @@ use diesel::{delete, dsl::insert_into, prelude::*, update};
use rand::{distributions::Uniform, Rng}; use rand::{distributions::Uniform, Rng};
use regex::Regex; use regex::Regex;
use serde::Deserialize; use serde::Deserialize;
use tower::ServiceBuilder; use tower_http::services::{ServeDir, ServeFile};
use tower_http::{
compression::CompressionLayer,
services::{ServeDir, ServeFile},
trace::TraceLayer,
};
use uuid::Uuid; use uuid::Uuid;
use crate::{ use crate::{
@ -46,48 +41,50 @@ const MAX_VERIFICATION_GUESSES: u32 = 100;
pub fn new_router(state: AppState) -> Router<()> { pub fn new_router(state: AppState) -> Router<()> {
let base_path = state.settings.base_path.clone(); let base_path = state.settings.base_path.clone();
Router::new().nest( let app = Router::new()
base_path.as_str(), .route("/", get(landing_page))
Router::new() .merge(v0_router::new_router(state.clone()))
.route("/", get(landing_page)) .route("/teams", get(teams_page))
.merge(v0_router::new_router(state.clone())) .route("/teams/{team_id}", get(team_page))
.route("/teams", get(teams_page)) .route("/teams/{team_id}/projects", get(projects_page))
.route("/teams/{team_id}", get(team_page)) .route("/teams/{team_id}/projects/{project_id}", get(project_page))
.route("/teams/{team_id}/projects", get(projects_page)) .route(
.route("/teams/{team_id}/projects/{project_id}", get(project_page)) "/teams/{team_id}/projects/{project_id}/update-enabled-channels",
.route( post(update_enabled_channels),
"/teams/{team_id}/projects/{project_id}/update-enabled-channels", )
post(update_enabled_channels), .route("/teams/{team_id}/new-api-key", post(post_new_api_key))
) .route("/teams/{team_id}/channels", get(channels_page))
.route("/teams/{team_id}/new-api-key", post(post_new_api_key)) .route("/teams/{team_id}/channels/{channel_id}", get(channel_page))
.route("/teams/{team_id}/channels", get(channels_page)) .route(
.route("/teams/{team_id}/channels/{channel_id}", get(channel_page)) "/teams/{team_id}/channels/{channel_id}/update-channel",
.route( post(update_channel),
"/teams/{team_id}/channels/{channel_id}/update-channel", )
post(update_channel), .route(
) "/teams/{team_id}/channels/{channel_id}/update-email-recipient",
.route( post(update_channel_email_recipient),
"/teams/{team_id}/channels/{channel_id}/update-email-recipient", )
post(update_channel_email_recipient), .route(
) "/teams/{team_id}/channels/{channel_id}/verify-email",
.route( post(verify_email),
"/teams/{team_id}/channels/{channel_id}/verify-email", )
post(verify_email), .route("/teams/{team_id}/new-channel", post(post_new_channel))
) .route("/new-team", get(new_team_page))
.route("/teams/{team_id}/new-channel", post(post_new_channel)) .route("/new-team", post(post_new_team))
.route("/new-team", get(new_team_page)) .nest("/auth", auth::new_router())
.route("/new-team", post(post_new_team)) .fallback_service(
.nest("/auth", auth::new_router()) ServeDir::new("static").not_found_service(ServeFile::new("static/_404.html")),
.fallback_service( )
.with_state(state);
let app = {
if base_path == "" {
app
} else {
Router::new().nest(&base_path, app).fallback_service(
ServeDir::new("static").not_found_service(ServeFile::new("static/_404.html")), ServeDir::new("static").not_found_service(ServeFile::new("static/_404.html")),
) )
.layer( }
ServiceBuilder::new() };
.layer(TraceLayer::new_for_http()) app
.layer(CompressionLayer::new()),
)
.with_state(state),
)
} }
async fn landing_page(State(state): State<AppState>) -> impl IntoResponse { async fn landing_page(State(state): State<AppState>) -> impl IntoResponse {

View file

@ -103,7 +103,7 @@ impl Settings {
} }
} }
let s = Config::builder() let s = Config::builder()
.add_source(Environment::default()) .add_source(Environment::default().separator("__"))
.build() .build()
.context("config error")?; .context("config error")?;
Ok(s.try_deserialize().context("deserialize error")?) Ok(s.try_deserialize().context("deserialize error")?)

9
static/_404.html Normal file
View file

@ -0,0 +1,9 @@
<!doctype html>
<html>
<head>
<title>Not found</title>
</head>
<body>
Page not found.
</body>
</html>