diff --git a/Cargo.lock b/Cargo.lock index 0f4c1d3..7a1062b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2571,6 +2571,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_variant" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a0068df419f9d9b6488fdded3f1c818522cdea328e02ce9d9f147380265a432" +dependencies = [ + "serde", +] + [[package]] name = "sha1" version = "0.10.6" @@ -2647,6 +2656,7 @@ dependencies = [ "reqwest 0.12.14", "serde", "serde_json", + "serde_variant", "tokio", "tower", "tower-http", diff --git a/Cargo.toml b/Cargo.toml index 056668b..0d28dc9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ regex = "1.11.1" reqwest = { version = "0.12.8", features = ["json"] } serde = { version = "1.0.213", features = ["derive"] } serde_json = "1.0.132" +serde_variant = "0.1.3" tokio = { version = "1.42.0", features = ["full"] } tower = "0.5.2" tower-http = { version = "0.6.2", features = ["compression-gzip", "fs", "normalize-path", "set-header", "trace"] } diff --git a/migrations/2025-04-21-071711_failable_messages/down.sql b/migrations/2025-04-21-071711_failable_messages/down.sql new file mode 100644 index 0000000..b8294df --- /dev/null +++ b/migrations/2025-04-21-071711_failable_messages/down.sql @@ -0,0 +1 @@ +ALTER TABLE messages DROP COLUMN IF EXISTS failed_at; diff --git a/migrations/2025-04-21-071711_failable_messages/up.sql b/migrations/2025-04-21-071711_failable_messages/up.sql new file mode 100644 index 0000000..83076cf --- /dev/null +++ b/migrations/2025-04-21-071711_failable_messages/up.sql @@ -0,0 +1,2 @@ +ALTER TABLE messages ADD COLUMN failed_at TIMESTAMPTZ; +CREATE INDEX ON messages (failed_at); diff --git a/src/app_state.rs b/src/app_state.rs index a92242b..d16d401 100644 --- a/src/app_state.rs +++ b/src/app_state.rs @@ -10,10 +10,12 @@ use oauth2::basic::BasicClient; use crate::{ app_error::AppError, + auth, email::{Mailer, SmtpOptions}, nav::NavbarBuilder, sessions::PgStore, settings::Settings, + slack_auth, }; /// Global app configuration @@ -25,6 +27,7 @@ pub struct App { pub reqwest_client: reqwest::Client, pub session_store: PgStore, pub settings: Settings, + pub slack_oauth_client: BasicClient, } impl App { @@ -37,7 +40,8 @@ impl App { let session_store = PgStore::new(db_pool.clone()); let reqwest_client = reqwest::ClientBuilder::new().https_only(true).build()?; - let oauth_client = crate::auth::new_oauth_client(&settings)?; + let oauth_client = auth::new_oauth_client(&settings)?; + let slack_oauth_client = slack_auth::new_oauth_client(&settings)?; let mailer = if let Some(smtp_settings) = settings.email.smtp.clone() { Mailer::new_smtp(SmtpOptions { @@ -60,6 +64,7 @@ impl App { reqwest_client, session_store, settings, + slack_oauth_client, }) } } diff --git a/src/channels.rs b/src/channels.rs index 33e682e..5bae8cc 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -1,3 +1,5 @@ +use std::fmt::Debug; + use diesel::{ backend::Backend, deserialize::{self, FromSql, FromSqlRow}, @@ -17,6 +19,8 @@ use crate::{schema::channels, teams::Team}; pub const CHANNEL_BACKEND_EMAIL: &str = "email"; pub const CHANNEL_BACKEND_SLACK: &str = "slack"; +pub use crate::schema::channels::{dsl, table}; + /// Represents a target/destination for messages, with the sender configuration /// defined in the backend_config field. A single channel may be attached to /// (in other words, "enabled" or "selected" for) any number of projects within @@ -144,12 +148,11 @@ impl From for BackendConfig { } } -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct SlackBackendConfig { - pub oauth_state: String, - pub access_token: String, - pub refresh_token: String, - pub conversation_id: String, + pub oauth_state: Option, + pub oauth_tokens: Option, + pub conversation_id: Option, } impl TryFrom for SlackBackendConfig { @@ -171,3 +174,9 @@ impl From for BackendConfig { Self::Slack(value) } } + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct OAuthTokens { + pub access_token: oauth2::AccessToken, + pub refresh_token: Option, +} diff --git a/src/channels_router.rs b/src/channels_router.rs index 2fde4fb..2ea4138 100644 --- a/src/channels_router.rs +++ b/src/channels_router.rs @@ -14,14 +14,19 @@ use uuid::Uuid; use crate::{ app_error::AppError, - app_state::{AppState, DbConn}, - channels::{BackendConfig, Channel, EmailBackendConfig, CHANNEL_BACKEND_EMAIL}, + app_state::{AppState, DbConn, ReqwestClient}, + channels::{ + BackendConfig, Channel, EmailBackendConfig, SlackBackendConfig, CHANNEL_BACKEND_EMAIL, + CHANNEL_BACKEND_SLACK, + }, csrf::generate_csrf_token, email::{is_permissible_email, MailSender as _, Mailer}, guards, nav::{BreadcrumbTrail, Navbar, NavbarBuilder, NAVBAR_ITEM_CHANNELS}, schema::channels, - settings::Settings, + settings::{Settings, SlackSettings}, + slack_auth, + slack_utils::{self, ConversationType, SlackClient}, users::CurrentUser, }; @@ -63,7 +68,12 @@ pub fn new_router() -> Router { "/teams/{team_id}/channels/{channel_id}/verify-email", post(verify_email), ) + .route( + "/teams/{team_id}/channels/{channel_id}/update-slack-conversation", + post(update_channel_slack_conversation), + ) .route("/teams/{team_id}/new-channel", post(post_new_channel)) + .merge(slack_auth::new_router()) } async fn channels_page( @@ -151,6 +161,22 @@ async fn post_new_channel( }) .await .unwrap()?, + CHANNEL_BACKEND_SLACK => db_conn + .interact::<_, Result>(move |conn| { + Ok(diesel::insert_into(channels::table) + .values(( + channels::id.eq(channel_id), + channels::team_id.eq(team_id), + channels::name.eq("Untitled Slack Channel"), + channels::backend_config + .eq(Into::::into(SlackBackendConfig::default())), + )) + .returning(Channel::as_returning()) + .get_result(conn) + .context("Failed to insert new SlackChannel.")?) + }) + .await + .unwrap()?, _ => { return Err(AppError::BadRequest( "Channel type not recognized.".to_string(), @@ -167,8 +193,16 @@ async fn post_new_channel( } async fn channel_page( - State(Settings { base_path, .. }): State, + State(Settings { + base_path, + slack: SlackSettings { + api_root: slack_api_root, + .. + }, + .. + }): State, State(navbar_template): State, + State(ReqwestClient(reqwest_client)): State, DbConn(db_conn): DbConn, Path((team_id, channel_id)): Path<(Uuid, Uuid)>, CurrentUser(current_user): CurrentUser, @@ -198,7 +232,7 @@ async fn channel_page( let csrf_token = generate_csrf_token(&db_conn, Some(current_user.id)).await?; - match channel.backend_config { + match channel.backend_config.clone() { BackendConfig::Email(_) => { #[derive(Template)] #[template(path = "channel-email.html")] @@ -228,8 +262,51 @@ async fn channel_page( .render()?, )) } - BackendConfig::Slack(_) => { - Err(anyhow::anyhow!("Slack channel config page is not yet implemented.").into()) + BackendConfig::Slack(slack_data) => { + let slack_client = slack_data.oauth_tokens.map(|tokens| { + SlackClient::new(&tokens.access_token) + .with_reqwest_client(reqwest_client) + .with_api_root(&slack_api_root) + }); + let slack_channels = if let Some(client) = slack_client { + client + .list_conversations() + .with_types([ConversationType::PublicChannel]) + .with_exclude_archived(true) + .load_all() + .await? + } else { + Vec::new() + }; + #[derive(Template)] + #[template(path = "channel-slack.html")] + struct ResponseTemplate { + base_path: String, + breadcrumbs: BreadcrumbTrail, + channel: Channel, + csrf_token: String, + navbar: Navbar, + slack_channels: Vec, + } + Ok(Html( + ResponseTemplate { + breadcrumbs: BreadcrumbTrail::from_base_path(&base_path) + .with_i18n_slug("en") + .push_slug("Teams", "teams") + .push_slug(&team.name, &team.id.simple().to_string()) + .push_slug("Channels", "channels") + .push_slug(&channel.name, &channel.id.simple().to_string()), + base_path, + channel, + csrf_token, + navbar: navbar_template + .with_param("team_id", &team.id.simple().to_string()) + .with_active_item(NAVBAR_ITEM_CHANNELS) + .build(), + slack_channels, + } + .render()?, + )) } } } @@ -373,6 +450,68 @@ async fn update_channel_email_recipient( ))) } +#[derive(Deserialize)] +struct UpdateChannelSlackConversationFormBody { + csrf_token: String, + conversation_id: String, +} + +async fn update_channel_slack_conversation( + State(Settings { base_path, .. }): State, + DbConn(db_conn): DbConn, + Path((team_id, channel_id)): Path<(Uuid, Uuid)>, + CurrentUser(current_user): CurrentUser, + Form(form): Form, +) -> Result { + guards::require_valid_csrf_token(&form.csrf_token, ¤t_user, &db_conn).await?; + guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; + + tracing::debug!("updating conversation id"); + db_conn + .interact(move |conn| -> Result<(), AppError> { + let channel = Channel::all() + .filter(Channel::with_id(&channel_id)) + .filter(Channel::with_team(&team_id)) + .first(conn) + .optional() + .context("failed to load channel")? + .ok_or(AppError::NotFound( + "Channel with that team and ID not found.".to_owned(), + ))?; + tracing::debug!("loaded channel"); + let mut slack_data: SlackBackendConfig = channel + .backend_config + .try_into() + .map_err(|_| AppError::BadRequest("Not a Slack channel.".to_owned()))?; + tracing::debug!("parsed slack config"); + // There should be no need to validate that this is a real + // conversation ID, or one that the end user should have access to, + // since the end user should be allowed to wire up Shout.dev with + // any channel that is in scope for the access token. + // TODO: Ensure this holds true with private channels and groups. + slack_data.conversation_id = Some(form.conversation_id); + let num_rows = diesel::update(channels::table.filter(Channel::with_id(&channel.id))) + .set(channels::backend_config.eq(BackendConfig::from(slack_data))) + .execute(conn)?; + tracing::debug!("updated {} rows", num_rows); + // If the channel is deleted while this db interaction is running, 0 + // rows will be updated, which is technically correct in that case, + // but we should still throw an error because the intended mutation + // has not in fact been performed. + assert_eq!(num_rows, 1); + Ok(()) + }) + .await + .unwrap()?; + + Ok(Redirect::to(&format!( + "{}/en/teams/{}/channels/{}", + base_path, + team_id.simple(), + channel_id.simple() + ))) +} + #[derive(Deserialize)] struct VerifyEmailFormBody { csrf_token: String, diff --git a/src/main.rs b/src/main.rs index 585f6c9..c81c192 100644 --- a/src/main.rs +++ b/src/main.rs @@ -32,6 +32,8 @@ mod router; mod schema; mod sessions; mod settings; +mod slack_auth; +mod slack_utils; mod team_invitations; mod team_memberships; mod teams; diff --git a/src/messages.rs b/src/messages.rs index d828043..886d590 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -8,6 +8,8 @@ use uuid::Uuid; use crate::{channels::Channel, schema::messages}; +pub use crate::schema::messages::{dsl, table}; + /// A "/say" message queued for sending #[derive(Associations, Clone, Debug, Identifiable, Queryable, Selectable)] #[diesel(table_name = messages)] @@ -18,6 +20,7 @@ pub struct Message { pub channel_id: Uuid, pub created_at: DateTime, pub sent_at: Option>, + pub failed_at: Option>, pub message: String, } @@ -25,16 +28,21 @@ impl Message { #[auto_type(no_type_alias)] pub fn all() -> _ { let select: AsSelect = Message::as_select(); - messages::table.select(select) + table.select(select) + } + + #[auto_type(no_type_alias)] + pub fn with_id<'a>(id: &'a Uuid) -> _ { + dsl::id.eq(id) } #[auto_type(no_type_alias)] pub fn with_channel<'a>(channel_id: &'a Uuid) -> _ { - messages::channel_id.eq(channel_id) + dsl::channel_id.eq(channel_id) } #[auto_type(no_type_alias)] - pub fn is_not_sent() -> _ { - messages::sent_at.is_null() + pub fn is_pending() -> _ { + dsl::sent_at.is_null().and(dsl::failed_at.is_null()) } } diff --git a/src/schema.rs b/src/schema.rs index 314b5ff..aea468f 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -70,6 +70,7 @@ diesel::table! { created_at -> Timestamptz, sent_at -> Nullable, message -> Text, + failed_at -> Nullable, } } diff --git a/src/settings.rs b/src/settings.rs index 77cacdc..9c3172e 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -4,7 +4,7 @@ use config::{Config, Environment}; use dotenvy::dotenv; use serde::Deserialize; -use crate::app_state::AppState; +use crate::{app_state::AppState, slack_utils}; #[derive(Clone, Debug, Deserialize)] pub struct Settings { @@ -34,6 +34,8 @@ pub struct Settings { pub auth: AuthSettings, pub email: EmailSettings, + + pub slack: SlackSettings, } fn default_port() -> u16 { @@ -81,12 +83,22 @@ pub struct EmailSettings { pub postmark: Option, } +#[derive(Clone, Debug, Deserialize)] pub struct SlackSettings { pub client_id: String, pub client_secret: String, - pub redirect_url: String, + #[serde(default = "default_slack_auth_url")] pub auth_url: String, - pub token_url: String, + #[serde(default = "default_slack_api_root")] + pub api_root: String, +} + +fn default_slack_auth_url() -> String { + "https://slack.com/oauth/v2/authorize".to_owned() +} + +fn default_slack_api_root() -> String { + slack_utils::DEFAULT_API_ROOT.to_owned() } impl Settings { diff --git a/src/slack_auth.rs b/src/slack_auth.rs new file mode 100644 index 0000000..fc6c46e --- /dev/null +++ b/src/slack_auth.rs @@ -0,0 +1,321 @@ +use std::borrow::Cow; + +use anyhow::{Context as _, Result}; +use axum::{ + extract::{Path, Query, State}, + response::{IntoResponse, Redirect, Response}, + routing::{get, post}, + Router, +}; +use axum_extra::extract::Form; +use diesel::prelude::*; +use oauth2::{ + basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId, + ClientSecret, CsrfToken, RedirectUrl, TokenResponse, TokenUrl, +}; +use serde::Deserialize; +use uuid::Uuid; + +use crate::{ + app_error::AppError, + app_state::{AppState, DbConn, ReqwestClient}, + channels::{self, BackendConfig, Channel, OAuthTokens, SlackBackendConfig}, + guards, + settings::{Settings, SlackSettings}, + users::CurrentUser, +}; + +/// Creates a new OAuth2 client to be stored in global application state. +pub fn new_oauth_client(settings: &Settings) -> Result { + Ok(BasicClient::new( + ClientId::new(settings.slack.client_id.clone()), + Some(ClientSecret::new(settings.slack.client_secret.clone())), + AuthUrl::new(settings.slack.auth_url.clone()) + .context("failed to create new authorization server URL")?, + Some( + TokenUrl::new(format!("{}/oauth.v2.access", settings.slack.api_root)) + .context("failed to create new token endpoint URL")?, + ), + )) +} + +/// Creates a router which can be nested within the higher level app router. +pub fn new_router() -> Router { + Router::new() + .route( + "/teams/{team_id}/channels/{channel_id}/slack-auth/login", + post(start_login), + ) + .route( + "/teams/{team_id}/channels/{channel_id}/slack-auth/callback", + get(callback), + ) + .route( + "/teams/{team_id}/channels/{channel_id}/slack-auth/revoke", + post(revoke), + ) +} + +#[derive(Deserialize)] +struct StartLoginFormBody { + csrf_token: String, +} + +/// HTTP get handler for /login +async fn start_login( + State(app_state): State, + State(Settings { + base_path, + frontend_host, + .. + }): State, + DbConn(db_conn): DbConn, + CurrentUser(current_user): CurrentUser, + Path((team_id, channel_id)): Path<(Uuid, Uuid)>, + Form(form): Form, +) -> Result { + guards::require_valid_csrf_token(&form.csrf_token, ¤t_user, &db_conn).await?; + guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; + + let channel = db_conn + .interact(move |conn| -> Result { + Channel::all() + .filter(Channel::with_id(&channel_id)) + .filter(Channel::with_team(&team_id)) + .first(conn) + .optional()? + .ok_or(AppError::NotFound( + "channel with that ID and team not found".to_owned(), + )) + }) + .await + .unwrap()?; + + let csrf_token = CsrfToken::new_random(); + let SlackBackendConfig { + conversation_id, .. + } = channel + .backend_config + .try_into() + .map_err(|_| anyhow::anyhow!("channel does not have a Slack backend"))?; + let slack_config = SlackBackendConfig { + conversation_id, + oauth_state: Some(csrf_token.clone()), + oauth_tokens: None, + }; + + const SCOPE_CHANNELS_READ: &str = "channels:read"; + const SCOPE_CHAT_WRITE_PUBLIC: &str = "chat:write.public"; + + let (auth_url, _csrf_token) = app_state + .slack_oauth_client + .authorize_url(|| csrf_token) + .add_scopes([ + oauth2::Scope::new(SCOPE_CHANNELS_READ.to_owned()), + oauth2::Scope::new(SCOPE_CHAT_WRITE_PUBLIC.to_owned()), + ]) + .set_redirect_uri(Cow::Owned( + RedirectUrl::new(format!( + "{}{}/en/teams/{}/channels/{}/slack-auth/callback", + frontend_host, base_path, team_id, channel_id + )) + .context("failed to create redirection URL")?, + )) + .url(); + + db_conn + .interact(move |conn| -> Result<()> { + diesel::update(channels::table.filter(Channel::with_id(&channel.id))) + .set(channels::dsl::backend_config.eq(Into::::into(slack_config))) + .execute(conn) + .map(|_| ()) + .map_err(Into::into) + }) + .await + .unwrap()?; + + Ok(Redirect::to(auth_url.as_ref()).into_response()) +} + +#[derive(Debug, Deserialize)] +struct AuthRequestQuery { + code: String, + /// CSRF token + state: String, +} + +/// HTTP get handler for /callback +async fn callback( + Query(query): Query, + State(app_state): State, + State(Settings { + base_path, + frontend_host, + .. + }): State, + DbConn(db_conn): DbConn, + CurrentUser(current_user): CurrentUser, + Path((team_id, channel_id)): Path<(Uuid, Uuid)>, +) -> Result { + guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; + + let channel = db_conn + .interact(move |conn| -> Result { + Channel::all() + .filter(Channel::with_id(&channel_id)) + .filter(Channel::with_team(&team_id)) + .first(conn) + .optional()? + .ok_or(AppError::NotFound( + "channel with that ID and team not found".to_owned(), + )) + }) + .await + .unwrap()?; + + let slack_data: SlackBackendConfig = channel + .backend_config + .try_into() + .map_err(|_| anyhow::anyhow!("channel does not have a Slack backend"))?; + + let true_csrf_token = slack_data.oauth_state.ok_or(AppError::BadRequest( + "No active Slack auth flow.".to_owned(), + ))?; + if true_csrf_token.secret() != &query.state { + tracing::debug!("oauth csrf tokens did not match"); + return Err(AppError::Forbidden( + "Slack OAuth CSRF tokens do not match.".to_owned(), + )); + } + + tracing::debug!("exchanging authorization code"); + let response = app_state + .slack_oauth_client + .exchange_code(AuthorizationCode::new(query.code)) + .set_redirect_uri(Cow::Owned( + RedirectUrl::new(format!( + "{}{}/en/teams/{}/channels/{}/slack-auth/callback", + frontend_host, base_path, team_id, channel_id + )) + .context("failed to create redirection URL")?, + )) + .request_async(async_http_client) + .await + .context("failed to exchange slack oauth code")?; + let slack_data = SlackBackendConfig { + conversation_id: slack_data.conversation_id, + oauth_state: None, + oauth_tokens: Some(OAuthTokens { + access_token: response.access_token().to_owned(), + refresh_token: response.refresh_token().map(|value| value.to_owned()), + }), + }; + db_conn + .interact(move |conn| -> Result<()> { + let n_rows = diesel::update(channels::table.filter(Channel::with_id(&channel_id))) + .set(channels::dsl::backend_config.eq(BackendConfig::from(slack_data))) + .execute(conn)?; + tracing::debug!("updated {} rows", n_rows); + assert!(n_rows == 1); + Ok(()) + }) + .await + .unwrap()?; + + tracing::debug!("successfully authenticated"); + Ok(Redirect::to(&format!( + "{}/en/teams/{}/channels/{}", + base_path, team_id, channel_id + ))) +} + +#[derive(Deserialize)] +struct RevokeFormBody { + csrf_token: String, +} + +async fn revoke( + State(Settings { + base_path, + slack: SlackSettings { api_root, .. }, + .. + }): State, + State(ReqwestClient(reqwest_client)): State, + DbConn(db_conn): DbConn, + CurrentUser(current_user): CurrentUser, + Path((team_id, channel_id)): Path<(Uuid, Uuid)>, + Form(form): Form, +) -> Result { + guards::require_valid_csrf_token(&form.csrf_token, ¤t_user, &db_conn).await?; + guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; + + let channel = db_conn + .interact(move |conn| -> Result { + Channel::all() + .filter(Channel::with_id(&channel_id)) + .filter(Channel::with_team(&team_id)) + .first(conn) + .optional()? + .ok_or(AppError::NotFound( + "channel with that ID and team not found".to_owned(), + )) + }) + .await + .unwrap()?; + + let slack_data: SlackBackendConfig = channel + .backend_config + .try_into() + .map_err(|_| anyhow::anyhow!("channel does not have a Slack backend"))?; + + if let Some(OAuthTokens { access_token, .. }) = slack_data.oauth_tokens { + #[derive(Deserialize)] + struct ApiResponse { + revoked: Option, + error: Option, + } + tracing::debug!("revoking slack access token via slack api"); + let response: ApiResponse = reqwest_client + .get(format!("{}/auth.revoke", api_root)) + .bearer_auth(access_token.secret()) + .send() + .await? + .error_for_status()? + .json() + .await?; + if response.revoked == Some(true) { + tracing::debug!("access token revoked successfully; updating backend config"); + let slack_data = SlackBackendConfig { + conversation_id: slack_data.conversation_id, + oauth_state: None, + oauth_tokens: None, + }; + db_conn + .interact(move |conn| -> Result<()> { + let n_rows = + diesel::update(channels::table.filter(Channel::with_id(&channel_id))) + .set(channels::dsl::backend_config.eq(BackendConfig::from(slack_data))) + .execute(conn)?; + tracing::debug!("updated {} rows", n_rows); + assert!(n_rows == 1); + Ok(()) + }) + .await + .unwrap()?; + tracing::debug!("backend config successfully updated"); + Ok(Redirect::to(&format!( + "{}/en/teams/{}/channels/{}", + base_path, team_id, channel_id + )) + .into_response()) + } else if let Some(message) = response.error { + Err(anyhow::anyhow!("error while revoking access token: {}", message).into()) + } else { + Err(anyhow::anyhow!("unknown error while revoking access token").into()) + } + } else { + Err(AppError::BadRequest( + "Channel is not currently authenticated with Slack credentials.".to_owned(), + )) + } +} diff --git a/src/slack_utils.rs b/src/slack_utils.rs new file mode 100644 index 0000000..218017a --- /dev/null +++ b/src/slack_utils.rs @@ -0,0 +1,464 @@ +use std::{collections::HashSet, fmt::Display}; + +use anyhow::{Context as _, Result}; +use reqwest::RequestBuilder; +use serde::{Deserialize, Serialize}; +use tracing::Instrument; +use validator::Validate; + +use crate::app_error::AppError; + +// ================ Common ================ // + +pub const DEFAULT_API_ROOT: &str = "https://slack.com/api"; + +#[derive(Clone, Debug)] +pub struct SlackClient { + access_token: oauth2::AccessToken, + api_root: String, + reqwest_client: reqwest::Client, +} + +impl SlackClient { + pub fn new(access_token: &oauth2::AccessToken) -> Self { + Self { + access_token: access_token.to_owned(), + api_root: DEFAULT_API_ROOT.to_owned(), + reqwest_client: reqwest::ClientBuilder::new() + .https_only(true) + .build() + .expect("reqwest client is always built with the same options"), + } + } + + /// Sets the API root (for example, "https://slack.com/api") + pub fn with_api_root(mut self, api_root: &str) -> Self { + self.api_root = api_root.to_owned(); + self + } + + /// Use a pre-existing reqwest client for making HTTP requests + pub fn with_reqwest_client(mut self, reqwest_client: reqwest::Client) -> Self { + self.reqwest_client = reqwest_client; + self + } + + /// Create an authenticated reqwest::RequestBuilder for an API endpoint. + fn get(&self, slack_method: &str) -> RequestBuilder { + self.reqwest_client + .get(format!("{}/{}", self.api_root, slack_method)) + .bearer_auth(self.access_token.secret()) + } + + /// Create an authenticated reqwest::RequestBuilder for an API endpoint. + fn post(&self, slack_method: &str) -> RequestBuilder { + self.reqwest_client + .post(format!("{}/{}", self.api_root, slack_method)) + .bearer_auth(self.access_token.secret()) + } + + pub fn list_conversations(&self) -> ListConversationsRequest { + ListConversationsRequest::new(self.clone()) + } + + pub fn post_chat_message(&self) -> PostChatMessageRequest { + PostChatMessageRequest::new(self.clone()) + } +} + +#[derive(Clone, Debug, Deserialize)] +pub struct ResponseMetadata { + pub next_cursor: Option, +} + +#[derive(Clone, Debug, Deserialize)] +#[serde(untagged)] +pub enum ApiResult { + Ok(R), + Err(ApiError), +} + +impl From> for std::result::Result { + fn from(val: ApiResult) -> Self { + match val { + ApiResult::Ok(response) => Ok(response), + ApiResult::Err(error) => Err(SlackError::Api(error)), + } + } +} + +// ================ Conversations ================ // + +#[derive(Clone, Debug, Serialize, Validate)] +pub struct ListConversationsRequest { + #[serde(skip)] + client: SlackClient, + cursor: Option, + exclude_archived: Option, + types: Option, +} + +impl ListConversationsRequest { + pub fn new(client: SlackClient) -> Self { + Self { + client, + cursor: None, + exclude_archived: None, + types: None, + } + } + + // Takes String instead of &str, since pagination will almost always + // provide and consume an owned string value + pub fn with_cursor(mut self, cursor: String) -> Self { + self.cursor = Some(cursor); + self + } + + pub fn with_exclude_archived(mut self, exclude_archived: bool) -> Self { + self.exclude_archived = Some(exclude_archived); + self + } + + pub fn with_types>(mut self, types: I) -> Self { + self.types = Some( + types + .into_iter() + .collect::>() + .into_iter() + .map(|value| value.to_string()) + .collect::>() + .join(","), + ); + self + } + + pub async fn load(self) -> Result { + async { + tracing::debug!("loading page of slack conversations"); + self.validate()?; + tracing::debug!("request structure validated"); + let mut response: ListConversationsResponse = std::result::Result::from( + self.client + .get("conversations.list") + .query(&self) + .send() + .await + .context("error sending request")? + .error_for_status() + .context("bad http status")? + .json::>() + .await + .context("failed to deserialize response")?, + )?; + tracing::debug!("loaded page successfully"); + response.request = Some(self); + Ok(response) + } + .instrument(tracing::debug_span!("ListConversationsRequest::load()")) + .await + } + + pub async fn load_all(self) -> Result, SlackError> { + let mut conversations: Vec = Vec::new(); + let mut response = self.load().await?; + conversations.append(&mut response.channels); + while let Some(request) = response.next_page()? { + response = request.load().await?; + conversations.append(&mut response.channels); + } + Ok(conversations) + } +} + +#[derive(Clone, Debug, Deserialize)] +pub struct ListConversationsResponse { + #[serde(skip)] + request: Option, + + pub channels: Vec, + pub response_metadata: ResponseMetadata, +} + +impl ListConversationsResponse { + pub fn next_page(&self) -> Result> { + if self.response_metadata.next_cursor == Some("".to_owned()) { + Ok(None) + } else { + self.request + .clone() + .ok_or(anyhow::anyhow!( + "original request was not stored with the api response" + )) + .map(|request| { + self.response_metadata + .next_cursor + .clone() + .map(|cursor| request.with_cursor(cursor)) + }) + } + } +} + +#[derive(Clone, Debug, Deserialize)] +#[serde(untagged)] +pub enum Conversation { + Channel(ChannelConversation), + Group(GroupConversation), + Im(ImConversation), +} + +#[derive(Clone, Debug, Deserialize)] +pub struct ChannelConversation { + pub id: String, + pub name: String, + pub is_archived: bool, + pub name_normalized: String, + pub is_member: bool, + pub is_private: bool, +} + +#[derive(Clone, Debug, Deserialize)] +pub struct GroupConversation { + pub id: String, + pub name: String, + pub is_archived: bool, + pub name_normalized: String, + pub is_member: bool, + pub is_mpim: bool, + pub is_open: bool, +} + +#[derive(Clone, Debug, Deserialize)] +pub struct ImConversation { + pub id: String, + pub is_im: bool, + pub user: String, +} + +#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ConversationType { + /// Public channel + PublicChannel, + /// Private channel + PrivateChannel, + /// Individual DM chat + Im, + /// Multi-person DM chat + Mpim, +} + +impl Display for ConversationType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", serde_variant::to_variant_name(self).unwrap()) + } +} + +// ================ Chat ================ // + +#[derive(Clone, Debug, Serialize, Validate)] +pub struct PostChatMessageRequest { + #[serde(skip)] + client: SlackClient, + #[validate(required)] + channel: Option, + #[validate(required, length(max = 4000))] + text: Option, + mrkdwn: Option, + unfurl_links: Option, + unfurl_media: Option, +} + +impl PostChatMessageRequest { + pub fn new(client: SlackClient) -> Self { + Self { + client, + channel: None, + text: None, + mrkdwn: None, + unfurl_links: None, + unfurl_media: None, + } + } + + /// Set the Slack channel ID (not to be confused with Shout.dev's internal + /// channels concept). + pub fn with_channel(mut self, channel: &str) -> Self { + self.channel = Some(channel.to_owned()); + self + } + + pub fn with_text(mut self, text: &str) -> Self { + self.text = Some(text.to_owned()); + self + } + + pub fn with_mrkdwn(mut self, mrkdwn: bool) -> Self { + self.mrkdwn = Some(mrkdwn); + self + } + + pub fn with_unfurl_links(mut self, unfurl_links: bool) -> Self { + self.unfurl_links = Some(unfurl_links); + self + } + + pub fn with_unfurl_media(mut self, unfurl_media: bool) -> Self { + self.unfurl_media = Some(unfurl_media); + self + } + + pub async fn execute(self) -> Result { + async { + tracing::debug!("posting slack message"); + self.validate()?; + tracing::debug!("request structure validated"); + let response: PostChatMessageResponse = std::result::Result::from( + self.client + .post("chat.postMessage") + .json(&self) + .send() + .await + .context("error sending request")? + .error_for_status() + .context("bad http status")? + .json::>() + .await + .context("failed to deserialize response")?, + )?; + tracing::debug!("posted message successfully"); + Ok(response) + } + .instrument(tracing::debug_span!("PostChatMessageRequest::execute()")) + .await + } +} + +#[derive(Clone, Debug, Deserialize)] +pub struct PostChatMessageResponse { + pub ok: bool, + pub channel: String, + pub ts: String, +} + +// ================ Errors ================ // + +#[derive(Debug)] +pub enum SlackError { + Api(ApiError), + Unknown(anyhow::Error), +} + +impl Display for SlackError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Api(error) => write!(f, "API error: {:?}", error), + Self::Unknown(error) => error.fmt(f), + } + } +} + +impl From for SlackError +where + E: Into, +{ + fn from(err: E) -> Self { + Self::Unknown(Into::::into(err)) + } +} + +impl From for AppError { + fn from(value: SlackError) -> Self { + Self::InternalServerError(anyhow::anyhow!("Slack error: {}", value)) + } +} + +impl SlackError { + pub fn into_anyhow(self) -> anyhow::Error { + match self { + Self::Api(error) => anyhow::anyhow!("API error: {:?}", error), + Self::Unknown(error) => error, + } + } +} + +#[derive(Clone, Debug, Deserialize, PartialEq)] +pub struct ApiError { + pub error: ErrorCode, +} + +#[derive(Clone, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum ErrorCode { + AccessDenied, + Accesslimited, + AccountInactive, + AsUserNotSupported, + AttachmentPayloadLimitExceeded, + CannotReplyToMessage, + ChannelNotFound, + DeprecatedEndpoint, + DuplicateChannelNotFound, + DuplicateMessageNotFound, + EkmAccessDenied, + EnterpriseIsRestricted, + FatalError, + InternalError, + InvalidArgName, + InvalidArguments, + InvalidArrayArg, + InvalidAuth, + InvalidBlocks, + InvalidBlocksFormat, + InvalidCharset, + InvalidCursor, + InvalidFormData, + InvalidLimit, + InvalidMetadataFormat, + InvalidMetadataSchema, + InvalidPostType, + InvalidTypes, + IsArchived, + MarkdownTextConflict, + MessageLimitExceeded, + MessagesTabDisabled, + MetadataMustBeSentFromApp, + MetadataTooLarge, + MethodDeprecated, + MethodNotSupportedForChannelType, + MissingArgument, + MissingFileData, + MissingPostType, + MissingScope, + MsgBlocksTooLong, + NoPermission, + NoText, + NotAllowedTokenType, + NotAuthed, + NotInChannel, + OrgLoginRequired, + // Yes, there are two distinct rate limit error codes. "ratelimited" seems + // to be the generic one, and "rate_limited" seems to be specific to + // posting messages. + RateLimited, + Ratelimited, + RequestTimeout, + RestrictedAction, + RestrictedActionNonThreadableChannel, + RestrictedActionReadOnlyChannel, + RestrictedActionThreadLocked, + RestrictedActionThreadOnlyChannel, + ServiceUnavailable, + SlackConnectCanvasSharingBlocked, + SlackConnectFileLinkSharingBlocked, + SlackConnectListsSharingBlocked, + TeamAccessNotGranted, + TeamAddedToOrg, + TeamNotFound, + TokenExpired, + TokenRevoked, + TooManyAttachments, + TooManyContactCards, + TwoFactorSetupRequired, +} diff --git a/src/worker.rs b/src/worker.rs index 0a15077..74576b7 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -5,11 +5,11 @@ use uuid::Uuid; use crate::{ app_state::AppState, - channels::{Channel, EmailBackendConfig}, + channels::{self, BackendConfig, Channel, EmailBackendConfig}, email::MailSender, governors::Governor, - messages::Message, - schema::{channels, messages}, + messages::{self, Message}, + slack_utils::SlackClient, }; pub async fn run_worker(state: AppState) -> Result<()> { @@ -32,55 +32,120 @@ async fn process_messages(state: AppState) -> Result<()> { const MESSAGE_QUEUE_LIMIT: i64 = 250; let db_conn = state.db_pool.get().await?; let queued_messages = db_conn - .interact::<_, Result>>(move |conn| { + .interact(move |conn| -> Result> { messages::table .inner_join(channels::table) .select((Message::as_select(), Channel::as_select())) - .filter(Message::is_not_sent()) - .order(messages::created_at.asc()) + .filter(Message::is_pending()) + .order(messages::dsl::created_at.asc()) .limit(MESSAGE_QUEUE_LIMIT) .load(conn) .context("failed to load queued messages") }) .await .unwrap()?; + // Dispatch email messages together to take advantage of Postmark's // batch send API - let emails: Vec<(Uuid, crate::email::Message)> = queued_messages - .iter() - .filter_map(|(message, channel)| { - if let Ok(backend_config) = - TryInto::::try_into(channel.backend_config.clone()) - { - if backend_config.verified { - let recipient: lettre::message::Mailbox = if let Ok(recipient) = - backend_config.recipient.parse() - { - recipient - } else { - tracing::error!("failed to parse recipient for channel {}", channel.id); - return None; - }; - let email = crate::email::Message { - from: state.settings.email.message_from.clone(), - to: recipient.into(), - subject: "Shout".to_string(), - text_body: message.message.clone(), - }; - tracing::debug!("Sending email to recipient for channel {}", channel.id); - Some((message.id, email)) - } else { - tracing::info!( - "Email recipient for channel {} is not verified", - channel.id - ); - None - } - } else { - None + let mut email_messages: Vec<(Message, Channel, EmailBackendConfig)> = + Vec::with_capacity(queued_messages.len()); + + for (message, channel) in queued_messages { + match channel.backend_config.clone() { + BackendConfig::Email(email_data) => { + email_messages.push((message, channel, email_data)); } - }) - .collect(); + BackendConfig::Slack(slack_data) => { + let result: Result<()> = if let (Some(oauth_tokens), Some(conversation_id)) = + (slack_data.oauth_tokens, slack_data.conversation_id) + { + let slack_client = SlackClient::new(&oauth_tokens.access_token) + .with_reqwest_client(state.reqwest_client.clone()) + .with_api_root(&state.settings.slack.api_root); + slack_client + .post_chat_message() + .with_channel(&conversation_id) + .with_text(&message.message) + .with_mrkdwn(false) + .with_unfurl_links(false) + .with_unfurl_media(false) + .execute() + .await + .map(|_| ()) + .map_err(|err| err.into_anyhow()) + } else { + Err(anyhow::anyhow!("slack channel is not fully configured")) + }; + if let Err(err) = result { + tracing::warn!("error sending message {}: {:?}", message.id, err); + db_conn + .interact(move |conn| -> Result<_> { + diesel::update( + messages::table.filter(Message::with_id(&message.id)), + ) + .set(messages::dsl::failed_at.eq(diesel::dsl::now)) + .execute(conn) + .context("failed to update message resolution") + }) + .await + .unwrap()?; + } else { + db_conn + .interact(move |conn| -> Result<_> { + diesel::update( + messages::table.filter(Message::with_id(&message.id)), + ) + .set(messages::dsl::sent_at.eq(diesel::dsl::now)) + .execute(conn) + .context("failed to update message resolution") + }) + .await + .unwrap()?; + } + } + } + } + + let mut emails: Vec<(Uuid, crate::email::Message)> = + Vec::with_capacity(email_messages.len()); + for (message, channel, email_data) in email_messages { + let result = if email_data.verified { + if let Ok(recipient) = email_data.recipient.parse::() { + let email = crate::email::Message { + from: state.settings.email.message_from.clone(), + to: recipient.into(), + subject: "Shout".to_string(), + text_body: message.message.clone(), + }; + tracing::debug!("prepared email to recipient for channel {}", channel.id); + emails.push((message.id, email)); + Ok(()) + } else { + Err(anyhow::anyhow!( + "failed to parse recipient for channel {}", + channel.id + )) + } + } else { + Err(anyhow::anyhow!( + "Email recipient for channel {} is not verified", + channel.id + )) + }; + if let Err(err) = result { + tracing::warn!("error sending message {}: {:?}", message.id, err); + db_conn + .interact(move |conn| -> Result<_> { + diesel::update(messages::table.filter(Message::with_id(&message.id))) + .set(messages::dsl::failed_at.eq(diesel::dsl::now)) + .execute(conn) + .context("failed to update message resolution") + }) + .await + .unwrap()?; + } + } + if !emails.is_empty() { let message_ids: Vec = emails.iter().map(|(id, _)| *id).collect(); let results = state @@ -90,13 +155,16 @@ async fn process_messages(state: AppState) -> Result<()> { assert!(results.len() == message_ids.len()); let results_by_id = message_ids.into_iter().zip(results.into_iter()); db_conn - .interact::<_, Result<_>>(move |conn| { + .interact(move |conn| -> Result<()> { for (id, result) in results_by_id { if let Err(err) = result { - tracing::error!("error sending message {}: {:?}", id, err); + tracing::warn!("error sending message {}: {:?}", id, err); + diesel::update(messages::table.filter(Message::with_id(&id))) + .set(messages::dsl::failed_at.eq(diesel::dsl::now)) + .execute(conn)?; } else { - diesel::update(messages::table.filter(messages::id.eq(id))) - .set(messages::sent_at.eq(diesel::dsl::now)) + diesel::update(messages::table.filter(Message::with_id(&id))) + .set(messages::dsl::sent_at.eq(diesel::dsl::now)) .execute(conn)?; } } @@ -105,7 +173,9 @@ async fn process_messages(state: AppState) -> Result<()> { .await .unwrap()?; } - tracing::info!("finished processing messages"); + tracing::info!("finished processing email messages"); + + tracing::info!("finished processing all messages in batch"); Ok(()) } .instrument(tracing::debug_span!("process_messages()")) diff --git a/templates/channel-base.html b/templates/channel-base.html new file mode 100644 index 0000000..b59f4c5 --- /dev/null +++ b/templates/channel-base.html @@ -0,0 +1,52 @@ +{% extends "base.html" %} + +{% block title %}Shout.dev: Channels{% endblock %} + +{% block main %} + {% include "breadcrumbs.html" %} +
+
+

Channel Configuration

+
+
+
+
+ + +
+
+
+ + +
+
+
+ + +
+
+
+ {% block extra_config %}{% endblock %} +
+{% endblock %} diff --git a/templates/channel-email.html b/templates/channel-email.html index 4fb39a2..4f7db21 100644 --- a/templates/channel-email.html +++ b/templates/channel-email.html @@ -1,53 +1,7 @@ -{% extends "base.html" %} +{% extends "channel-base.html" %} -{% block title %}Shout.dev: Channels{% endblock %} - -{% block main %} -{% if let BackendConfig::Email(email_data) = channel.backend_config %} - {% include "breadcrumbs.html" %} -
-
-

Channel Configuration

-
-
-
-
- - -
-
-
- - -
-
-
- - -
-
-
+{% block extra_config %} + {% if let BackendConfig::Email(email_data) = channel.backend_config %}
{% endif %} -
-{% endif %} + {% endif %} {% endblock %} diff --git a/templates/channel-slack.html b/templates/channel-slack.html new file mode 100644 index 0000000..5c82078 --- /dev/null +++ b/templates/channel-slack.html @@ -0,0 +1,62 @@ +{% extends "channel-base.html" %} + +{% block extra_config %} +
+ {% if let BackendConfig::Slack(slack_data) = channel.backend_config %} + {% if slack_data.oauth_tokens.is_none() %} +
+ + + + +
+ {% else %} +
+
+
+ + +
+
+ + +
+
+
+
+
+
+ + +
+
+ {% endif %} + {% endif %} +{% endblock %} diff --git a/templates/channels.html b/templates/channels.html index b155023..35e0ef3 100644 --- a/templates/channels.html +++ b/templates/channels.html @@ -20,7 +20,6 @@ New Channel