111 lines
3.4 KiB
Rust
111 lines
3.4 KiB
Rust
use std::{collections::HashMap, sync::Arc, time::Duration};
|
|
|
|
use anyhow::{Context as _, Result};
|
|
use axum::extract::FromRef;
|
|
use sqlx::{pool::PoolConnection, postgres::PgPoolOptions, raw_sql, Executor, PgPool, Postgres};
|
|
use tokio::sync::{OnceCell, RwLock};
|
|
use uuid::Uuid;
|
|
|
|
use crate::{app_state::AppState, bases::Base};
|
|
|
|
const MAX_CONNECTIONS: u32 = 4;
|
|
const IDLE_SECONDS: u64 = 3600;
|
|
|
|
// NOTE: The `Arc<RwLock<...>>` that guards the pool map will probably need to
// be replaced with something cheaper for performance eventually.
|
|
|
|
/// A collection of multiple SQLx Pools.
|
|
#[derive(Clone)]
|
|
pub struct BasePooler {
|
|
pools: Arc<RwLock<HashMap<Uuid, OnceCell<PgPool>>>>,
|
|
app_db: PgPool,
|
|
}
|
|
|
|
impl BasePooler {
|
|
pub fn new_with_app_db(app_db: PgPool) -> Self {
|
|
Self {
|
|
app_db,
|
|
pools: Arc::new(RwLock::new(HashMap::new())),
|
|
}
|
|
}
|
|
|
|
async fn get_pool_for(&mut self, base_id: Uuid) -> Result<PgPool> {
|
|
let init_cell = || async {
|
|
let base = Base::fetch_by_id(base_id, &self.app_db)
|
|
.await?
|
|
.context("no such base")?;
|
|
Ok(PgPoolOptions::new()
|
|
.min_connections(0)
|
|
.max_connections(MAX_CONNECTIONS)
|
|
.idle_timeout(Some(Duration::from_secs(IDLE_SECONDS)))
|
|
.after_release(|conn, _| {
|
|
Box::pin(async move {
|
|
// Essentially "DISCARD ALL" without "DEALLOCATE ALL"
|
|
conn.execute(raw_sql(
|
|
"
|
|
close all;
|
|
set session authorization default;
|
|
reset all;
|
|
unlisten *;
|
|
select pg_advisory_unlock_all();
|
|
discard plans;
|
|
discard temp;
|
|
discard sequences;
|
|
",
|
|
))
|
|
.await?;
|
|
Ok(true)
|
|
})
|
|
})
|
|
.connect(&base.url)
|
|
.await?)
|
|
};
|
|
|
|
// Attempt to get an existing pool without write-locking the map
|
|
let pools = self.pools.read().await;
|
|
if let Some(cell) = pools.get(&base_id) {
|
|
return Ok(cell
|
|
.get_or_try_init::<anyhow::Error, _, _>(init_cell)
|
|
.await?
|
|
.clone());
|
|
}
|
|
drop(pools); // Release read lock
|
|
let mut pools = self.pools.write().await;
|
|
let entry = pools.entry(base_id).or_insert(OnceCell::new());
|
|
Ok(entry
|
|
.get_or_try_init::<anyhow::Error, _, _>(init_cell)
|
|
.await?
|
|
.clone())
|
|
}
|
|
|
|
pub async fn acquire_for(&mut self, base_id: Uuid) -> Result<PoolConnection<Postgres>> {
|
|
let pool = self.get_pool_for(base_id).await?;
|
|
Ok(pool.acquire().await?)
|
|
}
|
|
|
|
pub async fn close_for(&mut self, base_id: Uuid) -> Result<()> {
|
|
let pools = self.pools.read().await;
|
|
if let Some(cell) = pools.get(&base_id) {
|
|
if let Some(pool) = cell.get() {
|
|
let pool = pool.clone();
|
|
drop(pools); // Release read lock
|
|
let mut pools = self.pools.write().await;
|
|
pools.remove(&base_id);
|
|
drop(pools); // Release write lock
|
|
pool.close().await;
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
// TODO: Add a cleanup method to remove entries with no connections
|
|
}
|
|
|
|
impl<S> FromRef<S> for BasePooler
|
|
where
|
|
S: Into<AppState> + Clone,
|
|
{
|
|
fn from_ref(state: &S) -> Self {
|
|
Into::<AppState>::into(state.clone()).base_pooler.clone()
|
|
}
|
|
}
|