phonograph/interim-server/src/base_pooler.rs

112 lines
3.4 KiB
Rust
Raw Normal View History

2025-05-26 22:08:21 -07:00
use std::{collections::HashMap, sync::Arc, time::Duration};
use anyhow::{Context as _, Result};
use axum::extract::FromRef;
use sqlx::{pool::PoolConnection, postgres::PgPoolOptions, raw_sql, Executor, PgPool, Postgres};
use tokio::sync::{OnceCell, RwLock};
use uuid::Uuid;
use crate::{app_state::AppState, bases::Base};
/// Upper bound passed to `max_connections` for every per-base pool.
const MAX_CONNECTIONS: u32 = 4;
/// Idle timeout, in seconds, passed to `idle_timeout` for every per-base pool (one hour).
const IDLE_SECONDS: u64 = 60 * 60;
// NOTE: The shared map is only locked briefly to look up or insert a cell;
// connecting a new pool happens with no map lock held, so a slow base does not
// stall lookups for other bases. The Arc<RwLock> may still need revisiting for
// performance eventually.
/// A collection of multiple SQLx Pools, one per base, created lazily on first use.
#[derive(Clone)]
pub struct BasePooler {
    /// Pools keyed by base id. Each cell sits behind its own `Arc` so it can be
    /// initialized after the map lock has been released.
    pools: Arc<RwLock<HashMap<Uuid, Arc<OnceCell<PgPool>>>>>,
    /// Application database, used to look up each base's connection URL.
    app_db: PgPool,
}

impl BasePooler {
    /// Creates an empty pooler that resolves bases through `app_db`.
    pub fn new_with_app_db(app_db: PgPool) -> Self {
        Self {
            app_db,
            pools: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns the pool for `base_id`, connecting it on first use.
    ///
    /// # Errors
    ///
    /// Fails if the base is not found in the application database or if the
    /// initial connection to the base's URL fails. A failed initialization is
    /// not cached; a later call retries.
    async fn get_pool_for(&self, base_id: Uuid) -> Result<PgPool> {
        let cell = self.cell_for(base_id).await;
        // No map lock is held here: connecting can be slow and must not block
        // lookups or closes for unrelated bases.
        match cell.get_or_try_init(|| self.connect_pool(base_id)).await {
            Ok(pool) => Ok(pool.clone()),
            Err(err) => {
                self.forget_uninitialized(base_id, &cell).await;
                Err(err)
            }
        }
    }

    /// Returns the (possibly still uninitialized) cell for `base_id`,
    /// inserting an empty one if the map has none yet.
    async fn cell_for(&self, base_id: Uuid) -> Arc<OnceCell<PgPool>> {
        {
            // Fast path: an existing cell only needs the read lock.
            let pools = self.pools.read().await;
            if let Some(cell) = pools.get(&base_id) {
                return Arc::clone(cell);
            }
        } // Read lock released before taking the write lock.
        let mut pools = self.pools.write().await;
        Arc::clone(
            pools
                .entry(base_id)
                .or_insert_with(|| Arc::new(OnceCell::new())),
        )
    }

    /// Drops `cell` from the map if it is still the entry for `base_id` and was
    /// never initialized, so failed lookups (e.g. unknown base ids) do not
    /// accumulate empty entries.
    ///
    /// If another task is concurrently retrying the same cell, its pool simply
    /// won't be cached; the next call creates a fresh cell.
    async fn forget_uninitialized(&self, base_id: Uuid, cell: &Arc<OnceCell<PgPool>>) {
        let mut pools = self.pools.write().await;
        let is_same_empty_cell = matches!(
            pools.get(&base_id),
            Some(existing) if Arc::ptr_eq(existing, cell) && !existing.initialized()
        );
        if is_same_empty_cell {
            pools.remove(&base_id);
        }
    }

    /// Looks up `base_id` in the application database and opens a new pool
    /// against the base's URL.
    async fn connect_pool(&self, base_id: Uuid) -> Result<PgPool> {
        let base = Base::fetch_by_id(base_id, &self.app_db)
            .await?
            .context("no such base")?;
        let pool = PgPoolOptions::new()
            .min_connections(0)
            .max_connections(MAX_CONNECTIONS)
            .idle_timeout(Some(Duration::from_secs(IDLE_SECONDS)))
            .after_release(|conn, _| {
                Box::pin(async move {
                    // Essentially "DISCARD ALL" without "DEALLOCATE ALL".
                    // NOTE(review): presumably DEALLOCATE is skipped so SQLx's
                    // cached prepared statements stay valid — confirm.
                    conn.execute(raw_sql(
                        "
close all;
set session authorization default;
reset all;
unlisten *;
select pg_advisory_unlock_all();
discard plans;
discard temp;
discard sequences;
",
                    ))
                    .await?;
                    Ok(true)
                })
            })
            .connect(&base.url)
            .await?;
        Ok(pool)
    }

    /// Acquires a connection from the pool for `base_id`, creating the pool on
    /// first use.
    ///
    /// # Errors
    ///
    /// Propagates pool-creation failures (see `get_pool_for`) and errors from
    /// `PgPool::acquire`.
    pub async fn acquire_for(&self, base_id: Uuid) -> Result<PoolConnection<Postgres>> {
        let pool = self.get_pool_for(base_id).await?;
        Ok(pool.acquire().await?)
    }

    /// Removes the entry for `base_id` (if any) and, if its pool had been
    /// created, awaits `PgPool::close` on it.
    pub async fn close_for(&self, base_id: Uuid) -> Result<()> {
        // Remove under a single write lock (released at the end of this
        // statement) so a concurrently re-created entry cannot be removed
        // without being the one that gets closed.
        let removed = self.pools.write().await.remove(&base_id);
        if let Some(pool) = removed.as_deref().and_then(|cell| cell.get()) {
            pool.close().await;
        }
        Ok(())
    }

    // TODO: Add a cleanup method to remove entries with no connections
}
/// Lets axum extract the shared [`BasePooler`] from any state type that can be
/// converted into [`AppState`].
impl<S> FromRef<S> for BasePooler
where
    S: Into<AppState> + Clone,
{
    fn from_ref(state: &S) -> Self {
        // NOTE(review): this clones the entire state on every extraction;
        // the pooler itself is cheap to clone (Arc-backed map + PgPool).
        let app_state: AppState = state.clone().into();
        app_state.base_pooler.clone()
    }
}