use std::borrow::Cow; use anyhow::{Context as _, Result}; use axum::{ extract::{Path, Query, State}, response::{IntoResponse, Redirect, Response}, routing::{get, post}, Router, }; use axum_extra::extract::Form; use diesel::prelude::*; use oauth2::{ basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, TokenResponse, TokenUrl, }; use serde::Deserialize; use uuid::Uuid; use crate::{ app_error::AppError, app_state::{AppState, DbConn, ReqwestClient}, channels::{self, BackendConfig, Channel, OAuthTokens, SlackBackendConfig}, guards, settings::{Settings, SlackSettings}, users::CurrentUser, }; /// Creates a new OAuth2 client to be stored in global application state. pub fn new_oauth_client(settings: &Settings) -> Result { Ok(BasicClient::new( ClientId::new(settings.slack.client_id.clone()), Some(ClientSecret::new(settings.slack.client_secret.clone())), AuthUrl::new(settings.slack.auth_url.clone()) .context("failed to create new authorization server URL")?, Some( TokenUrl::new(format!("{}/oauth.v2.access", settings.slack.api_root)) .context("failed to create new token endpoint URL")?, ), )) } /// Creates a router which can be nested within the higher level app router. pub fn new_router() -> Router { Router::new() .route( "/teams/{team_id}/channels/{channel_id}/slack-auth/login", post(start_login), ) .route( "/teams/{team_id}/channels/{channel_id}/slack-auth/callback", get(callback), ) .route( "/teams/{team_id}/channels/{channel_id}/slack-auth/revoke", post(revoke), ) } #[derive(Deserialize)] struct StartLoginFormBody { csrf_token: String, } /// HTTP get handler for /login async fn start_login( State(app_state): State, State(Settings { base_path, frontend_host, .. }): State, DbConn(db_conn): DbConn, CurrentUser(current_user): CurrentUser, Path((team_id, channel_id)): Path<(Uuid, Uuid)>, Form(form): Form, ) -> Result { guards::require_valid_csrf_token(&form.csrf_token, ¤t_user, &db_conn).await?; guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; let channel = db_conn .interact(move |conn| -> Result { Channel::all() .filter(Channel::with_id(&channel_id)) .filter(Channel::with_team(&team_id)) .first(conn) .optional()? .ok_or(AppError::NotFound( "channel with that ID and team not found".to_owned(), )) }) .await .unwrap()?; let csrf_token = CsrfToken::new_random(); let SlackBackendConfig { conversation_id, .. } = channel .backend_config .try_into() .map_err(|_| anyhow::anyhow!("channel does not have a Slack backend"))?; let slack_config = SlackBackendConfig { conversation_id, oauth_state: Some(csrf_token.clone()), oauth_tokens: None, }; const SCOPE_CHANNELS_READ: &str = "channels:read"; const SCOPE_CHAT_WRITE_PUBLIC: &str = "chat:write.public"; let (auth_url, _csrf_token) = app_state .slack_oauth_client .authorize_url(|| csrf_token) .add_scopes([ oauth2::Scope::new(SCOPE_CHANNELS_READ.to_owned()), oauth2::Scope::new(SCOPE_CHAT_WRITE_PUBLIC.to_owned()), ]) .set_redirect_uri(Cow::Owned( RedirectUrl::new(format!( "{}{}/en/teams/{}/channels/{}/slack-auth/callback", frontend_host, base_path, team_id, channel_id )) .context("failed to create redirection URL")?, )) .url(); db_conn .interact(move |conn| -> Result<()> { diesel::update(channels::table.filter(Channel::with_id(&channel.id))) .set(channels::dsl::backend_config.eq(Into::::into(slack_config))) .execute(conn) .map(|_| ()) .map_err(Into::into) }) .await .unwrap()?; Ok(Redirect::to(auth_url.as_ref()).into_response()) } #[derive(Debug, Deserialize)] struct AuthRequestQuery { code: String, /// CSRF token state: String, } /// HTTP get handler for /callback async fn callback( Query(query): Query, State(app_state): State, State(Settings { base_path, frontend_host, .. }): State, DbConn(db_conn): DbConn, CurrentUser(current_user): CurrentUser, Path((team_id, channel_id)): Path<(Uuid, Uuid)>, ) -> Result { guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; let channel = db_conn .interact(move |conn| -> Result { Channel::all() .filter(Channel::with_id(&channel_id)) .filter(Channel::with_team(&team_id)) .first(conn) .optional()? .ok_or(AppError::NotFound( "channel with that ID and team not found".to_owned(), )) }) .await .unwrap()?; let slack_data: SlackBackendConfig = channel .backend_config .try_into() .map_err(|_| anyhow::anyhow!("channel does not have a Slack backend"))?; let true_csrf_token = slack_data.oauth_state.ok_or(AppError::BadRequest( "No active Slack auth flow.".to_owned(), ))?; if true_csrf_token.secret() != &query.state { tracing::debug!("oauth csrf tokens did not match"); return Err(AppError::Forbidden( "Slack OAuth CSRF tokens do not match.".to_owned(), )); } tracing::debug!("exchanging authorization code"); let response = app_state .slack_oauth_client .exchange_code(AuthorizationCode::new(query.code)) .set_redirect_uri(Cow::Owned( RedirectUrl::new(format!( "{}{}/en/teams/{}/channels/{}/slack-auth/callback", frontend_host, base_path, team_id, channel_id )) .context("failed to create redirection URL")?, )) .request_async(async_http_client) .await .context("failed to exchange slack oauth code")?; let slack_data = SlackBackendConfig { conversation_id: slack_data.conversation_id, oauth_state: None, oauth_tokens: Some(OAuthTokens { access_token: response.access_token().to_owned(), refresh_token: response.refresh_token().map(|value| value.to_owned()), }), }; db_conn .interact(move |conn| -> Result<()> { let n_rows = diesel::update(channels::table.filter(Channel::with_id(&channel_id))) .set(channels::dsl::backend_config.eq(BackendConfig::from(slack_data))) .execute(conn)?; tracing::debug!("updated {} rows", n_rows); assert!(n_rows == 1); Ok(()) }) .await .unwrap()?; tracing::debug!("successfully authenticated"); Ok(Redirect::to(&format!( "{}/en/teams/{}/channels/{}", base_path, team_id, channel_id ))) } #[derive(Deserialize)] struct RevokeFormBody { csrf_token: String, } async fn revoke( State(Settings { base_path, slack: SlackSettings { api_root, .. }, .. }): State, State(ReqwestClient(reqwest_client)): State, DbConn(db_conn): DbConn, CurrentUser(current_user): CurrentUser, Path((team_id, channel_id)): Path<(Uuid, Uuid)>, Form(form): Form, ) -> Result { guards::require_valid_csrf_token(&form.csrf_token, ¤t_user, &db_conn).await?; guards::require_team_membership(¤t_user, &team_id, &db_conn).await?; let channel = db_conn .interact(move |conn| -> Result { Channel::all() .filter(Channel::with_id(&channel_id)) .filter(Channel::with_team(&team_id)) .first(conn) .optional()? .ok_or(AppError::NotFound( "channel with that ID and team not found".to_owned(), )) }) .await .unwrap()?; let slack_data: SlackBackendConfig = channel .backend_config .try_into() .map_err(|_| anyhow::anyhow!("channel does not have a Slack backend"))?; if let Some(OAuthTokens { access_token, .. }) = slack_data.oauth_tokens { #[derive(Deserialize)] struct ApiResponse { revoked: Option, error: Option, } tracing::debug!("revoking slack access token via slack api"); let response: ApiResponse = reqwest_client .get(format!("{}/auth.revoke", api_root)) .bearer_auth(access_token.secret()) .send() .await? .error_for_status()? .json() .await?; if response.revoked == Some(true) { tracing::debug!("access token revoked successfully; updating backend config"); let slack_data = SlackBackendConfig { conversation_id: slack_data.conversation_id, oauth_state: None, oauth_tokens: None, }; db_conn .interact(move |conn| -> Result<()> { let n_rows = diesel::update(channels::table.filter(Channel::with_id(&channel_id))) .set(channels::dsl::backend_config.eq(BackendConfig::from(slack_data))) .execute(conn)?; tracing::debug!("updated {} rows", n_rows); assert!(n_rows == 1); Ok(()) }) .await .unwrap()?; tracing::debug!("backend config successfully updated"); Ok(Redirect::to(&format!( "{}/en/teams/{}/channels/{}", base_path, team_id, channel_id )) .into_response()) } else if let Some(message) = response.error { Err(anyhow::anyhow!("error while revoking access token: {}", message).into()) } else { Err(anyhow::anyhow!("unknown error while revoking access token").into()) } } else { Err(AppError::BadRequest( "Channel is not currently authenticated with Slack credentials.".to_owned(), )) } }