implement pg-backed governors for rate limiting

This commit is contained in:
Brent Schroeter 2025-03-10 14:52:02 -07:00
parent 157eb37257
commit c7fc56cff3
8 changed files with 300 additions and 2 deletions

View file

@ -0,0 +1,2 @@
-- Revert the governors migration. Drop governor_entries first: its
-- governor_id column references governors(id), so the parent table
-- cannot be dropped while the child still exists.
DROP TABLE IF EXISTS governor_entries;
DROP TABLE IF EXISTS governors;

View file

@ -0,0 +1,16 @@
-- Rolling-window rate limiters ("governors"), scoped to a team and
-- optionally narrowed to a single project (project_id IS NULL means the
-- limit applies team-wide).
CREATE TABLE governors (
    id UUID PRIMARY KEY NOT NULL,
    team_id UUID NOT NULL REFERENCES teams(id) ON DELETE CASCADE,
    project_id UUID REFERENCES projects(id) ON DELETE CASCADE,
    window_size INTERVAL NOT NULL DEFAULT '1 hour',
    max_count INT NOT NULL,
    -- incremented when an entry is created; decremented when an entry expires
    rolling_count INT NOT NULL DEFAULT 0
);

-- Governors are looked up by (team_id, project_id); index the FK so the
-- lookup and ON DELETE CASCADE from teams avoid sequential scans.
CREATE INDEX ON governors(team_id, project_id);

-- One row per counted event; swept out once older than the governor's window.
CREATE TABLE governor_entries (
    id UUID PRIMARY KEY NOT NULL,
    governor_id UUID NOT NULL REFERENCES governors(id) ON DELETE CASCADE,
    timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- Supports the global expired-entry scan (filter on timestamp alone).
CREATE INDEX ON governor_entries(timestamp);
-- Index the FK so per-governor sweeps and ON DELETE CASCADE from
-- governors don't scan the whole entries table.
CREATE INDEX ON governor_entries(governor_id, timestamp);

View file

@ -16,6 +16,7 @@ pub enum AppError {
ForbiddenError(String),
NotFoundError(String),
BadRequestError(String),
TooManyRequestsError(String),
AuthRedirect(AuthRedirectInfo),
}
@ -45,6 +46,12 @@ impl IntoResponse for AppError {
tracing::info!("Not found: {}", client_message);
(StatusCode::NOT_FOUND, client_message).into_response()
}
Self::TooManyRequestsError(client_message) => {
// Debug level so that if this is from a runaway loop, it won't
// overwhelm server logs
tracing::debug!("Too many requests: {}", client_message);
(StatusCode::TOO_MANY_REQUESTS, client_message).into_response()
}
Self::BadRequestError(client_message) => {
tracing::info!("Bad user input: {}", client_message);
(StatusCode::BAD_REQUEST, client_message).into_response()
@ -78,6 +85,9 @@ impl Display for AppError {
AppError::BadRequestError(client_message) => {
write!(f, "BadRequestError: {}", client_message)
}
AppError::TooManyRequestsError(client_message) => {
write!(f, "TooManyRequestsError: {}", client_message)
}
}
}
}

175
src/governors.rs Normal file
View file

@ -0,0 +1,175 @@
// Fault tolerant rate limiting backed by Postgres.
use anyhow::Result;
use chrono::{DateTime, TimeDelta, Utc};
use diesel::{
dsl::{auto_type, AsSelect},
pg::Pg,
prelude::*,
sql_types::Timestamptz,
};
use uuid::Uuid;
use crate::schema::{governor_entries, governors};
// Binding for SQL GREATEST(), used to clamp rolling counts at zero
// in-database when decrementing.
define_sql_function! {
    fn greatest(a: diesel::sql_types::Integer, b: diesel::sql_types::Integer) -> Integer
}
/// A rolling-window rate limiter, scoped to a team and optionally narrowed
/// to a single project.
#[derive(Clone, Debug, Identifiable, Insertable, Queryable, Selectable)]
#[diesel(table_name = governors)]
pub struct Governor {
    pub id: Uuid,
    pub team_id: Uuid,
    /// `None` means the governor applies team-wide rather than per-project.
    pub project_id: Option<Uuid>,
    /// Length of the rolling window; entries older than this are reclaimed.
    pub window_size: TimeDelta,
    /// Maximum number of live entries permitted within the window.
    pub max_count: i32,
    /// Incremented when an entry is created; decremented when an entry
    /// expires (see `reclaim()`).
    pub rolling_count: i32,
}
impl Governor {
    /// Base query selecting every governor with `Governor`'s columns.
    #[auto_type(no_type_alias)]
    pub fn all() -> _ {
        let select: AsSelect<Governor, Pg> = Governor::as_select();
        governors::table.select(select)
    }

    /// Filter: the governor with the given primary key.
    #[auto_type(no_type_alias)]
    pub fn with_id(governor_id: Uuid) -> _ {
        governors::id.eq(governor_id)
    }

    /// Filter: governors belonging to the given team.
    #[auto_type(no_type_alias)]
    pub fn with_team(team_id: Uuid) -> _ {
        governors::team_id.eq(team_id)
    }

    /// Filter: governors scoped to the given project. Uses
    /// IS NOT DISTINCT FROM so that `None` matches team-wide governors
    /// (NULL-safe equality).
    #[auto_type(no_type_alias)]
    pub fn with_project(project_id: Option<Uuid>) -> _ {
        governors::project_id.is_not_distinct_from(project_id)
    }

    // TODO: return a custom result enum instead of a Result<Option>, for
    // better readability
    /**
     * Attempt to increment the rolling count. If the governor is not full,
     * returns a GovernorEntry which can be used to cancel the operation and
     * restore the rolling count. If governor is full, returns None.
     *
     * NOTE(review): the insert and increment are separate statements; if
     * the process dies between them, the stray entry will later decrement
     * the count when it expires. The GREATEST(…, 0) clamp and reset_all()
     * bound the resulting drift — confirm callers run this inside a
     * transaction if stricter accounting is needed.
     */
    pub fn create_entry(&self, conn: &mut diesel::PgConnection) -> Result<Option<GovernorEntry>> {
        let entry: GovernorEntry = diesel::insert_into(governor_entries::table)
            .values((
                governor_entries::id.eq(Uuid::now_v7()),
                governor_entries::governor_id.eq(self.id),
            ))
            .get_result(conn)?;
        // Conditionally bump the counter only while it is strictly below
        // max_count; the affected-row count tells us whether we won (1) or
        // the governor was already full (0).
        let n_rows = diesel::update(
            governors::table
                .filter(governors::id.eq(self.id))
                .filter(governors::rolling_count.lt(self.max_count)),
        )
        .set(governors::rolling_count.eq(governors::rolling_count + 1))
        .execute(conn)?;
        // id is the primary key, so the update can match at most one row.
        assert!(n_rows < 2);
        if n_rows == 1 {
            Ok(Some(entry))
        } else {
            // Clean up unused entry, or else it will artificially decrement
            // rolling count when it expires
            diesel::delete(governor_entries::table.filter(GovernorEntry::with_id(entry.id)))
                .execute(conn)?;
            Ok(None)
        }
    }

    /**
     * Governors work by continually incrementing a counter and then
     * periodically decrementing it as entries fall out of the current window of
     * time. This function performs the latter part of the cycle, sweeping out
     * expired entries and adjusting the counter accordingly.
     */
    pub fn reclaim(&self, conn: &mut diesel::PgConnection) -> Result<()> {
        let n_expired_entries: i32 = diesel::delete(
            GovernorEntry::belonging_to(self).filter(
                governor_entries::timestamp
                    .lt(diesel::dsl::now.into_sql::<Timestamptz>() - self.window_size),
            ),
        )
        .execute(conn)?
        .try_into()
        .expect("a governor should never have been allowed enough entries to overflow an i32");
        // Clamp rolling_count >= 0 via GREATEST so accumulated drift can
        // never drive the counter negative. (Uuid is Copy — no clone needed.)
        diesel::update(governors::table.filter(Self::with_id(self.id)))
            .set(
                governors::rolling_count
                    .eq(greatest(governors::rolling_count - n_expired_entries, 0)),
            )
            .execute(conn)?;
        Ok(())
    }

    /// Run `reclaim()` for every governor that currently has at least one
    /// expired entry.
    pub fn reclaim_all(conn: &mut diesel::PgConnection) -> Result<()> {
        let applicable_governors = governors::table
            .inner_join(governor_entries::table)
            .filter(
                governor_entries::timestamp
                    .lt(diesel::dsl::now.into_sql::<Timestamptz>() - governors::window_size),
            )
            .select(Self::as_select())
            // Grouping by the primary key deduplicates the join while still
            // allowing all governor columns to be selected.
            .group_by(governors::id)
            .load(conn)?;
        tracing::info!(
            "reclaiming counts for {} governors",
            applicable_governors.len()
        );
        for governor in applicable_governors {
            governor.reclaim(conn)?;
        }
        Ok(())
    }

    /**
     * Reset all governors to a count of 0, to fix any accumulated error between
     * rolling counts and number of entries.
     */
    pub fn reset_all(conn: &mut diesel::PgConnection) -> Result<()> {
        // Delete entries and then reset counts, not vice-versa; otherwise
        // concurrent inserts could result in rolling counts getting stuck
        // higher than they should be
        diesel::delete(governor_entries::table).execute(conn)?;
        diesel::update(governors::table)
            .set(governors::rolling_count.eq(0))
            .execute(conn)?;
        Ok(())
    }
}
/// One counted event against a governor. Created by
/// `Governor::create_entry()`, removed either by expiry (`reclaim`) or
/// explicit cancellation (`cancel`).
#[derive(Associations, Clone, Debug, Identifiable, Insertable, Queryable, Selectable)]
#[diesel(table_name = governor_entries)]
#[diesel(belongs_to(Governor))]
pub struct GovernorEntry {
    pub id: Uuid,
    pub governor_id: Uuid,
    /// When the event occurred; drives expiry relative to the governor's
    /// window_size.
    pub timestamp: DateTime<Utc>,
}
impl GovernorEntry {
    /// Filter: the entry with the given primary key.
    #[auto_type(no_type_alias)]
    pub fn with_id(entry_id: Uuid) -> _ {
        governor_entries::id.eq(entry_id)
    }

    /**
     * Removes this entry from the governor and decrements the overall rolling
     * count by 1.
     *
     * The delete runs first, and the decrement only happens when a row was
     * actually removed. The previous order (decrement, then delete) would
     * decrement unconditionally: if the entry had already expired and been
     * reclaimed, the count was decremented a second time, silently raising
     * the governor's effective capacity above max_count.
     */
    pub fn cancel(&self, conn: &mut diesel::PgConnection) -> Result<()> {
        // Uuid is Copy, so no clone() is needed for the filters.
        let n_deleted = diesel::delete(governor_entries::table.filter(Self::with_id(self.id)))
            .execute(conn)?;
        if n_deleted == 1 {
            // Clamp at 0 so concurrent sweeps can't drive the count negative.
            diesel::update(governors::table.filter(Governor::with_id(self.governor_id)))
                .set(governors::rolling_count.eq(greatest(governors::rolling_count - 1, 0)))
                .execute(conn)?;
        }
        Ok(())
    }
}

View file

@ -6,6 +6,7 @@ mod channel_selections;
mod channels;
mod csrf;
mod email;
mod governors;
mod guards;
mod messages;
mod nav_state;
@ -50,6 +51,8 @@ enum Commands {
#[arg(long)]
auto_loop_seconds: Option<u32>,
},
// TODO: add a low-frequency worker task exclusively for self-healing
// mechanisms like Governor::reset_all()
}
#[tokio::main]

View file

@ -43,6 +43,25 @@ diesel::table! {
}
}
// Mirrors the governor_entries migration; presumably regenerated by the
// diesel CLI, so keep in sync with the SQL rather than editing by hand.
diesel::table! {
    governor_entries (id) {
        id -> Uuid,
        governor_id -> Uuid,
        timestamp -> Timestamptz,
    }
}
// Mirrors the governors migration; presumably regenerated by the diesel
// CLI, so keep in sync with the SQL rather than editing by hand.
diesel::table! {
    governors (id) {
        id -> Uuid,
        team_id -> Uuid,
        project_id -> Nullable<Uuid>,
        window_size -> Interval,
        max_count -> Int4,
        rolling_count -> Int4,
    }
}
diesel::table! {
messages (id) {
id -> Uuid,
@ -89,6 +108,9 @@ diesel::joinable!(channel_selections -> channels (channel_id));
diesel::joinable!(channel_selections -> projects (project_id));
diesel::joinable!(channels -> teams (team_id));
diesel::joinable!(csrf_tokens -> users (user_id));
diesel::joinable!(governor_entries -> governors (governor_id));
diesel::joinable!(governors -> projects (project_id));
diesel::joinable!(governors -> teams (team_id));
diesel::joinable!(messages -> channels (channel_id));
diesel::joinable!(messages -> projects (project_id));
diesel::joinable!(projects -> teams (team_id));
@ -101,6 +123,8 @@ diesel::allow_tables_to_appear_in_same_query!(
channel_selections,
channels,
csrf_tokens,
governor_entries,
governors,
messages,
projects,
team_memberships,

View file

@ -5,6 +5,7 @@ use axum::{
routing::get,
Router,
};
use chrono::TimeDelta;
use diesel::{dsl::insert_into, prelude::*, update};
use serde::Deserialize;
use serde_json::json;
@ -15,10 +16,14 @@ use crate::{
app_error::AppError,
app_state::{AppState, DbConn},
channels::Channel,
governors::Governor,
projects::Project,
schema::{api_keys, messages, projects},
schema::{api_keys, governors, messages, projects},
};
// Defaults for the lazily-created team-wide governor: 50 requests per
// rolling 5-minute window.
const TEAM_GOVERNOR_DEFAULT_WINDOW_SIZE_SEC: i64 = 300;
const TEAM_GOVERNOR_DEFAULT_MAX_COUNT: i32 = 50;

/// Builds the router exposing the message-submission endpoint (GET /say).
pub fn new_router(state: AppState) -> Router<AppState> {
    Router::new().route("/say", get(say_get)).with_state(state)
}
@ -80,6 +85,52 @@ async fn say_get(
.await
.unwrap()?
};
let team_governor = {
let team_id = project.team_id.clone();
db_conn
.interact::<_, Result<Governor, AppError>>(move |conn| {
// TODO: extract this logic to a method in crate::governors,
// and create governor proactively on team creation
match Governor::all()
.filter(Governor::with_team(team_id.clone()))
.filter(Governor::with_project(None))
.first(conn)
{
diesel::QueryResult::Ok(governor) => Ok(governor),
diesel::QueryResult::Err(diesel::result::Error::NotFound) => {
// Lazily initialize governor
Ok(diesel::insert_into(governors::table)
.values((
governors::team_id.eq(team_id),
governors::id.eq(Uuid::now_v7()),
governors::project_id.eq(None as Option<Uuid>),
governors::window_size
.eq(TimeDelta::seconds(TEAM_GOVERNOR_DEFAULT_WINDOW_SIZE_SEC)),
governors::max_count.eq(TEAM_GOVERNOR_DEFAULT_MAX_COUNT),
))
.get_result(conn)?)
}
diesel::QueryResult::Err(err) => Err(err.into()),
}
})
.await
.unwrap()?
};
if db_conn
.interact::<_, Result<Option<_>, anyhow::Error>>(move |conn| {
team_governor.create_entry(conn)
})
.await
.unwrap()?
.is_none()
{
return Err(AppError::TooManyRequestsError(
"team rate limit exceeded".to_string(),
));
}
let selected_channels = {
let project = project.clone();
db_conn

View file

@ -7,13 +7,15 @@ use crate::{
app_state::AppState,
channels::{Channel, EmailBackendConfig},
email::MailSender,
governors::Governor,
messages::Message,
schema::{channels, messages},
};
pub async fn run_worker(state: AppState) -> Result<()> {
async move {
process_messages(state).await?;
process_messages(state.clone()).await?;
reclaim_governor_entries(state.clone()).await?;
Ok(())
}
.instrument(tracing::debug_span!("run_worker()"))
@ -109,3 +111,18 @@ async fn process_messages(state: AppState) -> Result<()> {
.instrument(tracing::debug_span!("process_messages()"))
.await
}
/// Worker task: sweep expired governor entries and adjust rolling counts
/// via `Governor::reclaim_all()`.
async fn reclaim_governor_entries(state: AppState) -> Result<()> {
    // This span doesn't do much, since it seems that tracing spans don't
    // carry across into the thread spawned by db_conn.interact()
    let span = tracing::debug_span!("reclaim_governor_entries()");
    async move {
        let db_conn = state.db_pool.get().await?;
        db_conn.interact(Governor::reclaim_all).await.unwrap()?;
        Ok(())
    }
    .instrument(span)
    .await
}