phonograph/phono-pestgros/src/query_builders.rs

223 lines
7.9 KiB
Rust
Raw Normal View History

2026-02-13 08:00:23 +00:00
//! Assorted utilities for dynamically constructing and manipulating [`sqlx`]
//! queries.
use sqlx::{Postgres, QueryBuilder};
use crate::{ArithOp, BoolOp, Datum, Expr, FnArgs, InfixOp, escape_identifier};
/// Representation of a partial, parameterized SQL query. Allows callers to
/// build queries iteratively and dynamically, handling parameter numbering
/// (`$1`, `$2`, `$3`, ...) automatically.
///
/// This is similar to [`sqlx::QueryBuilder`], except that [`QueryFragment`]
/// objects are composable and may be concatenated to each other.
///
/// The SQL text and the parameter values are kept in two separate vectors:
/// `plain_sql[i]` is the text that precedes `params[i]`, and the final element
/// of `plain_sql` is the text that follows the last parameter. No `$n`
/// placeholders are stored in the text itself.
#[derive(Clone, Debug, PartialEq)]
pub struct QueryFragment {
/// SQL string, split wherever there is a query parameter. For example,
/// `select * from foo where id = $1 and status = $2` is represented along
/// the lines of `["select * from foo where id = ", " and status = ", ""]`.
/// `plain_sql` should always have exactly one more element than `params`
/// (and therefore is never empty).
plain_sql: Vec<String>,
/// Parameter values, in the order in which they appear in the query.
/// `params[i]` sits between `plain_sql[i]` and `plain_sql[i + 1]`.
params: Vec<Datum>,
}
impl QueryFragment {
/// Validate invariants. Should be run immediately before returning any
/// useful output.
fn gut_checks(&self) {
assert!(self.plain_sql.len() == self.params.len() + 1);
}
/// Parse from a SQL string with no parameters.
pub fn from_sql(sql: &str) -> Self {
Self {
plain_sql: vec![sql.to_owned()],
params: vec![],
}
}
/// Convenience function to construct an empty value.
pub fn empty() -> Self {
Self::from_sql("")
}
/// Parse from a parameter value with no additional SQL. (Renders as `$n`,
/// where`n` is the appropriate parameter index.)
pub fn from_param(param: Datum) -> Self {
Self {
plain_sql: vec!["".to_owned(), "".to_owned()],
params: vec![param],
}
}
/// Append another query fragment to this one.
pub fn push(&mut self, mut other: QueryFragment) {
let tail = self
.plain_sql
.pop()
.expect("already asserted that vec contains at least 1 item");
let head = other
.plain_sql
.first()
.expect("already asserted that vec contains at least 1 item");
self.plain_sql.push(format!("{tail}{head}"));
for value in other.plain_sql.drain(1..) {
self.plain_sql.push(value);
}
self.params.append(&mut other.params);
}
/// Combine multiple QueryFragments with a separator, similar to
/// [`Vec::join`].
pub fn join<I: IntoIterator<Item = Self>>(fragments: I, sep: Self) -> Self {
let mut acc = QueryFragment::from_sql("");
let mut iter = fragments.into_iter();
let mut fragment = match iter.next() {
Some(value) => value,
None => return acc,
};
for next_fragment in iter {
acc.push(fragment);
acc.push(sep.clone());
fragment = next_fragment;
}
acc.push(fragment);
acc
}
/// Convenience method equivalent to:
/// `QueryFragment::concat(fragments, QueryFragment::from_sql(""))`
pub fn concat<I: IntoIterator<Item = Self>>(fragments: I) -> Self {
Self::join(fragments, Self::from_sql(""))
}
/// Checks whether value is empty. A value is considered empty if the
/// resulting SQL code is 0 characters long.
pub fn is_empty(&self) -> bool {
self.gut_checks();
self.plain_sql.len() == 1
&& self
.plain_sql
.first()
.expect("already checked that len == 1")
.is_empty()
}
}
impl From<Expr> for QueryFragment {
fn from(value: Expr) -> Self {
match value {
Expr::Infix { lhs, op, rhs } => Self::concat([
// RHS and LHS must be explicitly wrapped in parentheses to
// ensure correct precedence, because parentheses are taken
// into account **but not preserved** when parsing.
Self::from_sql("("),
(*lhs).into(),
Self::from_sql(") "),
op.into(),
Self::from_sql(" ("),
(*rhs).into(),
Self::from_sql(")"),
]),
Expr::Literal(datum) => Self::from_param(datum),
Expr::ObjName(idents) => Self::join(
idents
.iter()
.map(|ident| Self::from_sql(&escape_identifier(ident))),
Self::from_sql("."),
),
Expr::Not(expr) => {
Self::concat([Self::from_sql("not ("), (*expr).into(), Self::from_sql(")")])
}
Expr::Nullness { is_null, expr } => Self::concat([
Self::from_sql("("),
(*expr).into(),
Self::from_sql(if is_null {
") is null"
} else {
") is not null"
}),
]),
Expr::FnCall { name, args } => {
let mut fragment = Self::empty();
fragment.push(Self::join(
name.iter()
.map(|ident| Self::from_sql(&escape_identifier(ident))),
Self::from_sql("."),
));
fragment.push(Self::from_sql("("));
match args {
FnArgs::CountAsterisk => {
fragment.push(Self::from_sql("*"));
}
FnArgs::Exprs {
distinct_flag,
exprs,
} => {
if distinct_flag {
fragment.push(Self::from_sql("distinct "));
}
fragment.push(Self::join(
exprs.into_iter().map(|expr| {
// Wrap arguments in parentheses to ensure they
// are appropriately distinguishable from each
// other regardless of the presence of extra
// commas.
Self::concat([
Self::from_sql("("),
expr.into(),
Self::from_sql(")"),
])
}),
Self::from_sql(", "),
));
}
}
fragment.push(Self::from_sql(")"));
fragment
}
}
}
}
impl From<InfixOp> for QueryFragment {
    /// Render an infix operator as its SQL spelling, with no surrounding
    /// whitespace and no parameters.
    fn from(value: InfixOp) -> Self {
        let symbol = match value {
            InfixOp::ArithInfix(arith) => match arith {
                ArithOp::Add => "+",
                ArithOp::Concat => "||",
                ArithOp::Div => "/",
                ArithOp::Mult => "*",
                ArithOp::Sub => "-",
            },
            InfixOp::BoolInfix(boolean) => match boolean {
                BoolOp::And => "and",
                BoolOp::Or => "or",
                BoolOp::Eq => "=",
                BoolOp::Gt => ">",
                BoolOp::Gte => ">=",
                BoolOp::Lt => "<",
                BoolOp::Lte => "<=",
                BoolOp::Neq => "<>",
            },
        };
        Self::from_sql(symbol)
    }
}
impl From<QueryFragment> for QueryBuilder<'_, Postgres> {
fn from(value: QueryFragment) -> Self {
value.gut_checks();
let mut builder = QueryBuilder::new("");
let mut param_iter = value.params.into_iter();
for plain_sql in value.plain_sql {
builder.push(plain_sql);
if let Some(param) = param_iter.next() {
param.push_bind_onto(&mut builder);
}
}
builder
}
}
impl From<Expr> for QueryBuilder<'_, Postgres> {
fn from(value: Expr) -> Self {
Self::from(QueryFragment::from(value))
}
}