forked from 2sys/shoutdotdev
implement pg-backed governors for rate limiting
This commit is contained in:
parent
157eb37257
commit
c7fc56cff3
8 changed files with 300 additions and 2 deletions
2
migrations/2025-03-09-042820_init_governors/down.sql
Normal file
2
migrations/2025-03-09-042820_init_governors/down.sql
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
-- Reverts the init_governors migration.
-- governor_entries is dropped first because it holds a foreign key to
-- governors.
DROP TABLE IF EXISTS governor_entries;
DROP TABLE IF EXISTS governors;
|
16
migrations/2025-03-09-042820_init_governors/up.sql
Normal file
16
migrations/2025-03-09-042820_init_governors/up.sql
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
-- A governor is a rate limiter scoped to a team and, optionally, to a single
-- project within that team.
CREATE TABLE governors (
    id            UUID PRIMARY KEY NOT NULL,
    team_id       UUID NOT NULL REFERENCES teams(id) ON DELETE CASCADE,
    -- NULL when the governor is not tied to a specific project
    project_id    UUID REFERENCES projects(id) ON DELETE CASCADE,
    -- length of the rolling window that entries are counted over
    window_size   INTERVAL NOT NULL DEFAULT '1 hour',
    -- upper bound for rolling_count
    max_count     INT NOT NULL,
    -- incremented when an entry is created; decremented when an entry expires
    rolling_count INT NOT NULL DEFAULT 0
);

-- One row per counted event; rows are removed together with their governor.
CREATE TABLE governor_entries (
    id          UUID PRIMARY KEY NOT NULL,
    governor_id UUID NOT NULL REFERENCES governors(id) ON DELETE CASCADE,
    timestamp   TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- Entries are looked up by age when expired ones are swept out.
CREATE INDEX ON governor_entries(timestamp);
|
|
@ -16,6 +16,7 @@ pub enum AppError {
|
||||||
ForbiddenError(String),
|
ForbiddenError(String),
|
||||||
NotFoundError(String),
|
NotFoundError(String),
|
||||||
BadRequestError(String),
|
BadRequestError(String),
|
||||||
|
TooManyRequestsError(String),
|
||||||
AuthRedirect(AuthRedirectInfo),
|
AuthRedirect(AuthRedirectInfo),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,6 +46,12 @@ impl IntoResponse for AppError {
|
||||||
tracing::info!("Not found: {}", client_message);
|
tracing::info!("Not found: {}", client_message);
|
||||||
(StatusCode::NOT_FOUND, client_message).into_response()
|
(StatusCode::NOT_FOUND, client_message).into_response()
|
||||||
}
|
}
|
||||||
|
Self::TooManyRequestsError(client_message) => {
|
||||||
|
// Debug level so that if this is from a runaway loop, it won't
|
||||||
|
// overwhelm server logs
|
||||||
|
tracing::debug!("Too many requests: {}", client_message);
|
||||||
|
(StatusCode::TOO_MANY_REQUESTS, client_message).into_response()
|
||||||
|
}
|
||||||
Self::BadRequestError(client_message) => {
|
Self::BadRequestError(client_message) => {
|
||||||
tracing::info!("Bad user input: {}", client_message);
|
tracing::info!("Bad user input: {}", client_message);
|
||||||
(StatusCode::BAD_REQUEST, client_message).into_response()
|
(StatusCode::BAD_REQUEST, client_message).into_response()
|
||||||
|
@ -78,6 +85,9 @@ impl Display for AppError {
|
||||||
AppError::BadRequestError(client_message) => {
|
AppError::BadRequestError(client_message) => {
|
||||||
write!(f, "BadRequestError: {}", client_message)
|
write!(f, "BadRequestError: {}", client_message)
|
||||||
}
|
}
|
||||||
|
AppError::TooManyRequestsError(client_message) => {
|
||||||
|
write!(f, "TooManyRequestsError: {}", client_message)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
175
src/governors.rs
Normal file
175
src/governors.rs
Normal file
|
@ -0,0 +1,175 @@
|
||||||
|
// Fault tolerant rate limiting backed by Postgres.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use chrono::{DateTime, TimeDelta, Utc};
|
||||||
|
use diesel::{
|
||||||
|
dsl::{auto_type, AsSelect},
|
||||||
|
pg::Pg,
|
||||||
|
prelude::*,
|
||||||
|
sql_types::Timestamptz,
|
||||||
|
};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::schema::{governor_entries, governors};
|
||||||
|
|
||||||
|
define_sql_function! {
    /**
     * Binding for the two-argument integer form of SQL `GREATEST`, used in
     * this module to clamp counters so they never drop below zero.
     */
    fn greatest(
        a: diesel::sql_types::Integer,
        b: diesel::sql_types::Integer,
    ) -> diesel::sql_types::Integer
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Identifiable, Insertable, Queryable, Selectable)]
|
||||||
|
#[diesel(table_name = governors)]
|
||||||
|
pub struct Governor {
|
||||||
|
pub id: Uuid,
|
||||||
|
pub team_id: Uuid,
|
||||||
|
pub project_id: Option<Uuid>,
|
||||||
|
pub window_size: TimeDelta,
|
||||||
|
pub max_count: i32,
|
||||||
|
pub rolling_count: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Governor {
|
||||||
|
#[auto_type(no_type_alias)]
|
||||||
|
pub fn all() -> _ {
|
||||||
|
let select: AsSelect<Governor, Pg> = Governor::as_select();
|
||||||
|
governors::table.select(select)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[auto_type(no_type_alias)]
|
||||||
|
pub fn with_id(governor_id: Uuid) -> _ {
|
||||||
|
governors::id.eq(governor_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[auto_type(no_type_alias)]
|
||||||
|
pub fn with_team(team_id: Uuid) -> _ {
|
||||||
|
governors::team_id.eq(team_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[auto_type(no_type_alias)]
|
||||||
|
pub fn with_project(project_id: Option<Uuid>) -> _ {
|
||||||
|
governors::project_id.is_not_distinct_from(project_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: return a custom result enum instead of a Result<Option>, for
|
||||||
|
// better readability
|
||||||
|
/**
|
||||||
|
* Attempt to increment the rolling count. If the governor is not full,
|
||||||
|
* returns a GovernorEntry which can be used to cancel the operation and
|
||||||
|
* restore the rolling count. If governor is full, returns None.
|
||||||
|
*/
|
||||||
|
pub fn create_entry(&self, conn: &mut diesel::PgConnection) -> Result<Option<GovernorEntry>> {
|
||||||
|
let entry = diesel::insert_into(governor_entries::table)
|
||||||
|
.values((
|
||||||
|
governor_entries::id.eq(Uuid::now_v7()),
|
||||||
|
governor_entries::governor_id.eq(self.id),
|
||||||
|
))
|
||||||
|
.get_result(conn)?;
|
||||||
|
let n_rows = diesel::update(
|
||||||
|
governors::table
|
||||||
|
.filter(governors::id.eq(self.id))
|
||||||
|
.filter(governors::rolling_count.lt(self.max_count)),
|
||||||
|
)
|
||||||
|
.set(governors::rolling_count.eq(governors::rolling_count + 1))
|
||||||
|
.execute(conn)?;
|
||||||
|
assert!(n_rows < 2);
|
||||||
|
if n_rows == 1 {
|
||||||
|
Ok(Some(entry))
|
||||||
|
} else {
|
||||||
|
// Clean up unused entry, or else it will artificially decrement
|
||||||
|
// rolling count when it expires
|
||||||
|
diesel::delete(governor_entries::table.filter(GovernorEntry::with_id(entry.id)))
|
||||||
|
.execute(conn)?;
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Governors work by continually incrementing a counter and then
|
||||||
|
* periodically decrementing it as entries fall out of the current window of
|
||||||
|
* time. This function performs the latter part of the cycle, sweeping out
|
||||||
|
* expired entries and adjusting the counter accordingly.
|
||||||
|
*/
|
||||||
|
pub fn reclaim(&self, conn: &mut diesel::PgConnection) -> Result<()> {
|
||||||
|
let n_expired_entries: i32 = diesel::delete(
|
||||||
|
GovernorEntry::belonging_to(self).filter(
|
||||||
|
governor_entries::timestamp
|
||||||
|
.lt(diesel::dsl::now.into_sql::<Timestamptz>() - self.window_size),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.execute(conn)?
|
||||||
|
.try_into()
|
||||||
|
.expect("a governor should never have been allowed enough entries to overflow an i32");
|
||||||
|
// Clamp rolling_count >= 0
|
||||||
|
diesel::update(governors::table.filter(Self::with_id(self.id.clone())))
|
||||||
|
.set(
|
||||||
|
governors::rolling_count
|
||||||
|
.eq(greatest(governors::rolling_count - n_expired_entries, 0)),
|
||||||
|
)
|
||||||
|
.execute(conn)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reclaim_all(conn: &mut diesel::PgConnection) -> Result<()> {
|
||||||
|
let applicable_governors = governors::table
|
||||||
|
.inner_join(governor_entries::table)
|
||||||
|
.filter(
|
||||||
|
governor_entries::timestamp
|
||||||
|
.lt(diesel::dsl::now.into_sql::<Timestamptz>() - governors::window_size),
|
||||||
|
)
|
||||||
|
.select(Self::as_select())
|
||||||
|
.group_by(governors::id)
|
||||||
|
.load(conn)?;
|
||||||
|
tracing::info!(
|
||||||
|
"reclaiming counts for {} governors",
|
||||||
|
applicable_governors.len()
|
||||||
|
);
|
||||||
|
for governor in applicable_governors {
|
||||||
|
governor.reclaim(conn)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reset all governors to a count of 0, to fix any accumulated error between
|
||||||
|
* rolling counts and number of entries.
|
||||||
|
*/
|
||||||
|
pub fn reset_all(conn: &mut diesel::PgConnection) -> Result<()> {
|
||||||
|
// Delete entries and then reset counts, not vice-versa; otherwise
|
||||||
|
// concurrent inserts could result in rolling counts getting stuck
|
||||||
|
// higher than they should be
|
||||||
|
diesel::delete(governor_entries::table).execute(conn)?;
|
||||||
|
diesel::update(governors::table)
|
||||||
|
.set(governors::rolling_count.eq(0))
|
||||||
|
.execute(conn)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Associations, Clone, Debug, Identifiable, Insertable, Queryable, Selectable)]
|
||||||
|
#[diesel(table_name = governor_entries)]
|
||||||
|
#[diesel(belongs_to(Governor))]
|
||||||
|
pub struct GovernorEntry {
|
||||||
|
pub id: Uuid,
|
||||||
|
pub governor_id: Uuid,
|
||||||
|
pub timestamp: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GovernorEntry {
|
||||||
|
#[auto_type(no_type_alias)]
|
||||||
|
pub fn with_id(entry_id: Uuid) -> _ {
|
||||||
|
governor_entries::id.eq(entry_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Removes this entry from the governor and decrements the overall rolling
|
||||||
|
* count by 1.
|
||||||
|
*/
|
||||||
|
pub fn cancel(&self, conn: &mut diesel::PgConnection) -> Result<()> {
|
||||||
|
let entry_filter = Self::with_id(self.id.clone());
|
||||||
|
let governor_filter = Governor::with_id(self.governor_id.clone());
|
||||||
|
diesel::update(governors::table.filter(governor_filter))
|
||||||
|
.set(governors::rolling_count.eq(greatest(governors::rolling_count - 1, 0)))
|
||||||
|
.execute(conn)?;
|
||||||
|
diesel::delete(governor_entries::table.filter(entry_filter)).execute(conn)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
|
@ -6,6 +6,7 @@ mod channel_selections;
|
||||||
mod channels;
|
mod channels;
|
||||||
mod csrf;
|
mod csrf;
|
||||||
mod email;
|
mod email;
|
||||||
|
mod governors;
|
||||||
mod guards;
|
mod guards;
|
||||||
mod messages;
|
mod messages;
|
||||||
mod nav_state;
|
mod nav_state;
|
||||||
|
@ -50,6 +51,8 @@ enum Commands {
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
auto_loop_seconds: Option<u32>,
|
auto_loop_seconds: Option<u32>,
|
||||||
},
|
},
|
||||||
|
// TODO: add a low-frequency worker task exclusively for self-healing
|
||||||
|
// mechanisms like Governor::reset_all()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
|
|
|
@ -43,6 +43,25 @@ diesel::table! {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
diesel::table! {
|
||||||
|
governor_entries (id) {
|
||||||
|
id -> Uuid,
|
||||||
|
governor_id -> Uuid,
|
||||||
|
timestamp -> Timestamptz,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
diesel::table! {
|
||||||
|
governors (id) {
|
||||||
|
id -> Uuid,
|
||||||
|
team_id -> Uuid,
|
||||||
|
project_id -> Nullable<Uuid>,
|
||||||
|
window_size -> Interval,
|
||||||
|
max_count -> Int4,
|
||||||
|
rolling_count -> Int4,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
diesel::table! {
|
diesel::table! {
|
||||||
messages (id) {
|
messages (id) {
|
||||||
id -> Uuid,
|
id -> Uuid,
|
||||||
|
@ -89,6 +108,9 @@ diesel::joinable!(channel_selections -> channels (channel_id));
|
||||||
diesel::joinable!(channel_selections -> projects (project_id));
|
diesel::joinable!(channel_selections -> projects (project_id));
|
||||||
diesel::joinable!(channels -> teams (team_id));
|
diesel::joinable!(channels -> teams (team_id));
|
||||||
diesel::joinable!(csrf_tokens -> users (user_id));
|
diesel::joinable!(csrf_tokens -> users (user_id));
|
||||||
|
diesel::joinable!(governor_entries -> governors (governor_id));
|
||||||
|
diesel::joinable!(governors -> projects (project_id));
|
||||||
|
diesel::joinable!(governors -> teams (team_id));
|
||||||
diesel::joinable!(messages -> channels (channel_id));
|
diesel::joinable!(messages -> channels (channel_id));
|
||||||
diesel::joinable!(messages -> projects (project_id));
|
diesel::joinable!(messages -> projects (project_id));
|
||||||
diesel::joinable!(projects -> teams (team_id));
|
diesel::joinable!(projects -> teams (team_id));
|
||||||
|
@ -101,6 +123,8 @@ diesel::allow_tables_to_appear_in_same_query!(
|
||||||
channel_selections,
|
channel_selections,
|
||||||
channels,
|
channels,
|
||||||
csrf_tokens,
|
csrf_tokens,
|
||||||
|
governor_entries,
|
||||||
|
governors,
|
||||||
messages,
|
messages,
|
||||||
projects,
|
projects,
|
||||||
team_memberships,
|
team_memberships,
|
||||||
|
|
|
@ -5,6 +5,7 @@ use axum::{
|
||||||
routing::get,
|
routing::get,
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
|
use chrono::TimeDelta;
|
||||||
use diesel::{dsl::insert_into, prelude::*, update};
|
use diesel::{dsl::insert_into, prelude::*, update};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
@ -15,10 +16,14 @@ use crate::{
|
||||||
app_error::AppError,
|
app_error::AppError,
|
||||||
app_state::{AppState, DbConn},
|
app_state::{AppState, DbConn},
|
||||||
channels::Channel,
|
channels::Channel,
|
||||||
|
governors::Governor,
|
||||||
projects::Project,
|
projects::Project,
|
||||||
schema::{api_keys, messages, projects},
|
schema::{api_keys, governors, messages, projects},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const TEAM_GOVERNOR_DEFAULT_WINDOW_SIZE_SEC: i64 = 300;
|
||||||
|
const TEAM_GOVERNOR_DEFAULT_MAX_COUNT: i32 = 50;
|
||||||
|
|
||||||
pub fn new_router(state: AppState) -> Router<AppState> {
|
pub fn new_router(state: AppState) -> Router<AppState> {
|
||||||
Router::new().route("/say", get(say_get)).with_state(state)
|
Router::new().route("/say", get(say_get)).with_state(state)
|
||||||
}
|
}
|
||||||
|
@ -80,6 +85,52 @@ async fn say_get(
|
||||||
.await
|
.await
|
||||||
.unwrap()?
|
.unwrap()?
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let team_governor = {
|
||||||
|
let team_id = project.team_id.clone();
|
||||||
|
db_conn
|
||||||
|
.interact::<_, Result<Governor, AppError>>(move |conn| {
|
||||||
|
// TODO: extract this logic to a method in crate::governors,
|
||||||
|
// and create governor proactively on team creation
|
||||||
|
match Governor::all()
|
||||||
|
.filter(Governor::with_team(team_id.clone()))
|
||||||
|
.filter(Governor::with_project(None))
|
||||||
|
.first(conn)
|
||||||
|
{
|
||||||
|
diesel::QueryResult::Ok(governor) => Ok(governor),
|
||||||
|
diesel::QueryResult::Err(diesel::result::Error::NotFound) => {
|
||||||
|
// Lazily initialize governor
|
||||||
|
Ok(diesel::insert_into(governors::table)
|
||||||
|
.values((
|
||||||
|
governors::team_id.eq(team_id),
|
||||||
|
governors::id.eq(Uuid::now_v7()),
|
||||||
|
governors::project_id.eq(None as Option<Uuid>),
|
||||||
|
governors::window_size
|
||||||
|
.eq(TimeDelta::seconds(TEAM_GOVERNOR_DEFAULT_WINDOW_SIZE_SEC)),
|
||||||
|
governors::max_count.eq(TEAM_GOVERNOR_DEFAULT_MAX_COUNT),
|
||||||
|
))
|
||||||
|
.get_result(conn)?)
|
||||||
|
}
|
||||||
|
diesel::QueryResult::Err(err) => Err(err.into()),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap()?
|
||||||
|
};
|
||||||
|
|
||||||
|
if db_conn
|
||||||
|
.interact::<_, Result<Option<_>, anyhow::Error>>(move |conn| {
|
||||||
|
team_governor.create_entry(conn)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap()?
|
||||||
|
.is_none()
|
||||||
|
{
|
||||||
|
return Err(AppError::TooManyRequestsError(
|
||||||
|
"team rate limit exceeded".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
let selected_channels = {
|
let selected_channels = {
|
||||||
let project = project.clone();
|
let project = project.clone();
|
||||||
db_conn
|
db_conn
|
||||||
|
|
|
@ -7,13 +7,15 @@ use crate::{
|
||||||
app_state::AppState,
|
app_state::AppState,
|
||||||
channels::{Channel, EmailBackendConfig},
|
channels::{Channel, EmailBackendConfig},
|
||||||
email::MailSender,
|
email::MailSender,
|
||||||
|
governors::Governor,
|
||||||
messages::Message,
|
messages::Message,
|
||||||
schema::{channels, messages},
|
schema::{channels, messages},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub async fn run_worker(state: AppState) -> Result<()> {
|
pub async fn run_worker(state: AppState) -> Result<()> {
|
||||||
async move {
|
async move {
|
||||||
process_messages(state).await?;
|
process_messages(state.clone()).await?;
|
||||||
|
reclaim_governor_entries(state.clone()).await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
.instrument(tracing::debug_span!("run_worker()"))
|
.instrument(tracing::debug_span!("run_worker()"))
|
||||||
|
@ -109,3 +111,18 @@ async fn process_messages(state: AppState) -> Result<()> {
|
||||||
.instrument(tracing::debug_span!("process_messages()"))
|
.instrument(tracing::debug_span!("process_messages()"))
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn reclaim_governor_entries(state: AppState) -> Result<()> {
|
||||||
|
async move {
|
||||||
|
let db_conn = state.db_pool.get().await?;
|
||||||
|
db_conn
|
||||||
|
.interact(move |conn| Governor::reclaim_all(conn))
|
||||||
|
.await
|
||||||
|
.unwrap()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
// This doesn't do much, since it seems that tracing spans don't carry
|
||||||
|
// across into the thread spawned by db_conn.interact()
|
||||||
|
.instrument(tracing::debug_span!("reclaim_governor_entries()"))
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue