use std::{collections::HashMap, sync::Arc, time::Duration}; use anyhow::{Context as _, Result}; use axum::extract::FromRef; use sqlx::{pool::PoolConnection, postgres::PgPoolOptions, raw_sql, Executor, PgPool, Postgres}; use tokio::sync::{OnceCell, RwLock}; use uuid::Uuid; use crate::{app_state::AppState, bases::Base}; const MAX_CONNECTIONS: u32 = 4; const IDLE_SECONDS: u64 = 3600; // NOTE: The Arc this uses will probably need to be cleaned up for // performance eventually. /// A collection of multiple SQLx Pools. #[derive(Clone)] pub struct BasePooler { pools: Arc>>>, app_db: PgPool, } impl BasePooler { pub fn new_with_app_db(app_db: PgPool) -> Self { Self { app_db, pools: Arc::new(RwLock::new(HashMap::new())), } } async fn get_pool_for(&mut self, base_id: Uuid) -> Result { let init_cell = || async { let base = Base::fetch_by_id(base_id, &self.app_db) .await? .context("no such base")?; Ok(PgPoolOptions::new() .min_connections(0) .max_connections(MAX_CONNECTIONS) .idle_timeout(Some(Duration::from_secs(IDLE_SECONDS))) .after_release(|conn, _| { Box::pin(async move { // Essentially "DISCARD ALL" without "DEALLOCATE ALL" conn.execute(raw_sql( " close all; set session authorization default; reset all; unlisten *; select pg_advisory_unlock_all(); discard plans; discard temp; discard sequences; ", )) .await?; Ok(true) }) }) .connect(&base.url) .await?) }; // Attempt to get an existing pool without write-locking the map let pools = self.pools.read().await; if let Some(cell) = pools.get(&base_id) { return Ok(cell .get_or_try_init::(init_cell) .await? .clone()); } drop(pools); // Release read lock let mut pools = self.pools.write().await; let entry = pools.entry(base_id).or_insert(OnceCell::new()); Ok(entry .get_or_try_init::(init_cell) .await? .clone()) } pub async fn acquire_for(&mut self, base_id: Uuid) -> Result> { let pool = self.get_pool_for(base_id).await?; Ok(pool.acquire().await?) } pub async fn close_for(&mut self, base_id: Uuid) -> Result<()> { let pools = self.pools.read().await; if let Some(cell) = pools.get(&base_id) { if let Some(pool) = cell.get() { let pool = pool.clone(); drop(pools); // Release read lock let mut pools = self.pools.write().await; pools.remove(&base_id); drop(pools); // Release write lock pool.close().await; } } Ok(()) } // TODO: Add a cleanup method to remove entries with no connections } impl FromRef for BasePooler where S: Into + Clone, { fn from_ref(state: &S) -> Self { Into::::into(state.clone()).base_pooler.clone() } }