mod api_keys; mod app_error; mod app_state; mod auth; mod channel_selections; mod channels; mod csrf; mod email; mod governors; mod guards; mod messages; mod nav_state; mod projects; mod router; mod schema; mod sessions; mod settings; mod team_memberships; mod teams; mod users; mod v0_router; mod worker; use std::process::exit; use axum::{extract::Request, middleware::map_request, ServiceExt}; use chrono::{TimeDelta, Utc}; use clap::{Parser, Subcommand}; use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness}; use email::SmtpOptions; use tokio::time::sleep; use tower::ServiceBuilder; use tower_http::{ compression::CompressionLayer, normalize_path::NormalizePathLayer, trace::TraceLayer, }; use tracing_subscriber::EnvFilter; use crate::{ app_state::AppState, email::Mailer, router::new_router, sessions::PgStore, settings::Settings, worker::run_worker, }; #[derive(Parser)] #[command(version, about, long_about = None)] struct Cli { #[command(subcommand)] command: Commands, } #[derive(Subcommand)] enum Commands { /// Run web server Serve, /// Run background worker Worker { /// Loop the every n seconds instead of exiting after execution #[arg(long)] auto_loop_seconds: Option, }, // TODO: add a low-frequency worker task exclusively for self-healing // mechanisms like Governor::reset_all() } #[tokio::main] async fn main() { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .init(); let settings = Settings::load().unwrap(); let cli = Cli::parse(); let database_url = settings.database_url.clone(); let manager = deadpool_diesel::postgres::Manager::new(database_url, deadpool_diesel::Runtime::Tokio1); let db_pool = deadpool_diesel::postgres::Pool::builder(manager) .build() .unwrap(); let session_store = PgStore::new(db_pool.clone()); let reqwest_client = reqwest::ClientBuilder::new() .https_only(true) .build() .unwrap(); let oauth_client = auth::new_oauth_client(&settings).unwrap(); let mailer = if let Some(smtp_settings) = settings.email.smtp.clone() { Mailer::new_smtp(SmtpOptions { server: smtp_settings.server, username: smtp_settings.username, password: smtp_settings.password, }) .unwrap() } else if let Some(postmark_settings) = settings.email.postmark.clone() { Mailer::new_postmark(postmark_settings.server_token) .unwrap() .with_reqwest_client(reqwest_client.clone()) } else { tracing::error!("no email backend settings configured"); exit(1); }; if settings.run_database_migrations == Some(1) { const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/"); // Run migrations on server startup let conn = db_pool.get().await.unwrap(); conn.interact(|conn| conn.run_pending_migrations(MIGRATIONS).map(|_| ())) .await .unwrap() .unwrap(); } let app_state = AppState { db_pool: db_pool.clone(), mailer, oauth_client, reqwest_client, session_store, settings: settings.clone(), }; match &cli.command { Commands::Serve => { let router = new_router(app_state); let listener = tokio::net::TcpListener::bind((settings.host.clone(), settings.port.clone())) .await .unwrap(); tracing::info!( "App running at http://{}:{}{}", settings.host, settings.port, settings.base_path ); let app = ServiceExt::::into_make_service( ServiceBuilder::new() .layer(map_request(lowercase_uri_path)) .layer(TraceLayer::new_for_http()) .layer(CompressionLayer::new()) .layer(NormalizePathLayer::trim_trailing_slash()) .service(router), ); axum::serve(listener, app).await.unwrap(); } Commands::Worker { auto_loop_seconds } => { if let Some(loop_seconds) = auto_loop_seconds { let loop_delta = TimeDelta::seconds(i64::from(*loop_seconds)); loop { let t_next_loop = Utc::now() + loop_delta; if let Err(err) = run_worker(app_state.clone()).await { tracing::error!("{}", err) } let sleep_delta = t_next_loop - Utc::now(); match sleep_delta.to_std() { Ok(duration) => { sleep(duration).await; } Err(_) => { /* sleep_delta was < 0, so don't sleep */ } } } } else { run_worker(app_state).await.unwrap(); } } } } async fn lowercase_uri_path(mut request: Request) -> Request { let path = request.uri().path().to_lowercase(); let path_and_query = match request.uri().query() { Some(query) => format!("{}?{}", path, query), None => path, }; let builder = axum::http::uri::Builder::from(request.uri().clone()).path_and_query(path_and_query); *request.uri_mut() = builder .build() .expect("lowercasing URI path should not break it"); request }