phonograph/interim-models/src/expression.rs

251 lines
7.5 KiB
Rust
Raw Normal View History

2025-08-24 23:24:01 -07:00
use std::fmt::Display;
use interim_pgtypes::escape_identifier;
use serde::{Deserialize, Serialize};
use crate::datum::Datum;
2025-08-24 23:24:01 -07:00
#[derive(Clone, Debug, PartialEq)]
pub struct QueryFragment {
/// SQL string, split wherever there is a query parameter. For example,
/// `select * from foo where id = $1 and status = $2` is represented along
/// the lines of `["select * from foo where id = ", " and status = ", ""]`.
/// `plain_sql` should always have exactly one more element than `params`.
plain_sql: Vec<String>,
params: Vec<Datum>,
2025-08-24 23:24:01 -07:00
}
impl QueryFragment {
pub fn to_sql(&self, first_param_idx: usize) -> String {
assert!(self.plain_sql.len() == self.params.len() + 1);
self.plain_sql
.iter()
.cloned()
.zip((first_param_idx..).map(|n| format!("${n}")))
.fold(
Vec::with_capacity(2 * self.plain_sql.len()),
|mut acc, pair| {
acc.extend([pair.0, pair.1]);
acc
},
)
.get(0..(2 * self.plain_sql.len() - 1))
.expect("already asserted sufficient length")
.join("")
}
pub fn to_params(&self) -> Vec<Datum> {
2025-08-24 23:24:01 -07:00
self.params.clone()
}
pub fn from_sql(sql: &str) -> Self {
Self {
plain_sql: vec![sql.to_owned()],
params: vec![],
}
}
pub fn from_param(param: Datum) -> Self {
2025-08-24 23:24:01 -07:00
Self {
plain_sql: vec!["".to_owned(), "".to_owned()],
params: vec![param],
}
}
pub fn push(&mut self, mut other: QueryFragment) {
assert!(self.plain_sql.len() == self.params.len() + 1);
assert!(other.plain_sql.len() == other.params.len() + 1);
let tail = self
.plain_sql
.pop()
.expect("already asserted that vec contains at least 1 item");
let head = other
.plain_sql
.first()
.expect("already asserted that vec contains at least 1 item");
self.plain_sql.push(format!("{tail}{head}"));
for value in other.plain_sql.drain(1..) {
self.plain_sql.push(value);
}
self.params.append(&mut other.params);
}
/// Combine multiple QueryFragments with a separator, similar to Vec::join().
pub fn join<I: IntoIterator<Item = Self>>(fragments: I, sep: Self) -> Self {
let mut acc = QueryFragment::from_sql("");
let mut iter = fragments.into_iter();
let mut fragment = match iter.next() {
Some(value) => value,
None => return acc,
};
for next_fragment in iter {
acc.push(fragment);
acc.push(sep.clone());
fragment = next_fragment;
}
acc.push(fragment);
acc
}
/// Convenience method equivalent to:
/// `QueryFragment::concat(fragments, QueryFragment::from_sql(""))`
pub fn concat<I: IntoIterator<Item = Self>>(fragments: I) -> Self {
Self::join(fragments, Self::from_sql(""))
}
}
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(tag = "t", content = "c")]
pub enum PgExpressionAny {
Comparison(PgComparisonExpression),
Identifier(PgIdentifierExpression),
Literal(Datum),
2025-08-24 23:24:01 -07:00
ToJson(PgToJsonExpression),
}
impl PgExpressionAny {
    /// Consumes the expression and renders it as a [`QueryFragment`].
    ///
    /// Literals become a single bound parameter, except for datums that
    /// report `is_none()`, which are written as the SQL keyword `null`.
    /// Every other variant delegates to its own renderer.
    pub fn into_query_fragment(self) -> QueryFragment {
        match self {
            Self::Literal(datum) if datum.is_none() => QueryFragment::from_sql("null"),
            Self::Literal(datum) => QueryFragment::from_param(datum),
            Self::Comparison(comparison) => comparison.into_query_fragment(),
            Self::Identifier(identifier) => identifier.into_query_fragment(),
            Self::ToJson(to_json) => to_json.into_query_fragment(),
        }
    }
}
/// Comparison-style expressions: infix operators and null checks.
///
/// The serde attributes store the variant name under `"t"` and the payload
/// under `"c"`.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(tag = "t", content = "c")]
pub enum PgComparisonExpression {
/// Two operands combined with a [`PgComparisonOperator`].
Infix(PgInfixExpression<PgComparisonOperator>),
/// An `is null` test on a single operand.
IsNull(PgIsNullExpression),
/// An `is not null` test on a single operand.
IsNotNull(PgIsNotNullExpression),
}
impl PgComparisonExpression {
    /// Renders the comparison by handing off to the wrapped expression's own
    /// `into_query_fragment`.
    fn into_query_fragment(self) -> QueryFragment {
        match self {
            Self::IsNotNull(not_null_check) => not_null_check.into_query_fragment(),
            Self::IsNull(null_check) => null_check.into_query_fragment(),
            Self::Infix(infix) => infix.into_query_fragment(),
        }
    }
}
/// A binary expression of the form `lhs <operator> rhs`.
///
/// `T` is the operator type; its `Display` output is what gets written
/// between the two operands when the expression is rendered.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct PgInfixExpression<T: Display> {
/// Operator placed between the operands.
pub operator: T,
/// Left-hand operand.
pub lhs: Box<PgExpressionAny>,
/// Right-hand operand.
pub rhs: Box<PgExpressionAny>,
}
impl<T: Display> PgInfixExpression<T> {
    /// Renders the expression as `((lhs) operator (rhs))`.
    ///
    /// Each operand is wrapped in its own parentheses and the whole
    /// expression in another pair.
    fn into_query_fragment(self) -> QueryFragment {
        let Self { operator, lhs, rhs } = self;
        let mut fragment = QueryFragment::from_sql("((");
        fragment.push(lhs.into_query_fragment());
        fragment.push(QueryFragment::from_sql(&format!(") {operator} (")));
        fragment.push(rhs.into_query_fragment());
        fragment.push(QueryFragment::from_sql("))"));
        fragment
    }
}
/// Operators usable in a [`PgInfixExpression`].
///
/// The `strum` `to_string` attribute on each variant gives the text produced
/// by its `Display` implementation, i.e. the SQL token for that operator.
#[derive(Clone, Debug, strum::Display, Deserialize, PartialEq, Serialize)]
pub enum PgComparisonOperator {
/// Displayed as `and`.
#[strum(to_string = "and")]
And,
/// Displayed as `=`.
#[strum(to_string = "=")]
Eq,
/// Displayed as `>`.
#[strum(to_string = ">")]
Gt,
/// Displayed as `<`.
#[strum(to_string = "<")]
Lt,
/// Displayed as `<>`.
#[strum(to_string = "<>")]
Neq,
/// Displayed as `or`.
#[strum(to_string = "or")]
Or,
}
/// An `is null` test applied to a single operand.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct PgIsNullExpression {
/// Expression being tested.
lhs: Box<PgExpressionAny>,
}
impl PgIsNullExpression {
    /// Renders the test as `((lhs) is null)`.
    fn into_query_fragment(self) -> QueryFragment {
        let mut fragment = QueryFragment::from_sql("((");
        fragment.push(self.lhs.into_query_fragment());
        fragment.push(QueryFragment::from_sql(") is null)"));
        fragment
    }
}
/// An `is not null` test applied to a single operand.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct PgIsNotNullExpression {
/// Expression being tested.
lhs: Box<PgExpressionAny>,
}
impl PgIsNotNullExpression {
    /// Renders the test as `((lhs) is not null)`.
    fn into_query_fragment(self) -> QueryFragment {
        let operand = self.lhs.into_query_fragment();
        let opening = QueryFragment::from_sql("((");
        let closing = QueryFragment::from_sql(") is not null)");
        QueryFragment::concat([opening, operand, closing])
    }
}
/// An identifier made of one or more parts, rendered joined by `.`.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct PgIdentifierExpression {
/// Unescaped identifier parts, in order. Each part is passed through
/// `escape_identifier` when the expression is rendered.
pub parts_raw: Vec<String>,
}
impl PgIdentifierExpression {
    /// Renders the identifier by escaping every raw part and joining the
    /// results with `.`.
    ///
    /// An empty `parts_raw` produces an empty fragment.
    fn into_query_fragment(self) -> QueryFragment {
        let escaped_parts = self.parts_raw.iter().map(|raw_part| {
            let escaped = escape_identifier(raw_part);
            QueryFragment::from_sql(&escaped)
        });
        let dot = QueryFragment::from_sql(".");
        QueryFragment::join(escaped_parts, dot)
    }
}
/// A `to_json(...)` expression built from named sub-expressions.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct PgToJsonExpression {
/// `(name, expression)` pairs; each expression is rendered with the name
/// (escaped as an identifier) as its column alias.
entries: Vec<(String, PgExpressionAny)>,
}
impl PgToJsonExpression {
    /// Produces SQL shaped like
    /// `to_json((select ($expr) as "ident", ($expr2) as "ident2"))`,
    /// with one aliased column per entry. With no entries the result is the
    /// literal `'{}'`.
    ///
    /// NOTE(review): a parenthesized `select` with several output columns is
    /// used here in expression position — confirm PostgreSQL accepts this
    /// form when there is more than one entry.
    fn into_query_fragment(self) -> QueryFragment {
        if self.entries.is_empty() {
            return QueryFragment::from_sql("'{}'");
        }

        // One `(<expr>) as <alias>` column per entry, built lazily.
        let columns = self.entries.into_iter().map(|(alias, expr)| {
            let mut column = QueryFragment::from_sql("(");
            column.push(expr.into_query_fragment());
            column.push(QueryFragment::from_sql(&format!(
                ") as {}",
                escape_identifier(&alias)
            )));
            column
        });

        let mut fragment = QueryFragment::from_sql("to_json((select ");
        fragment.push(QueryFragment::join(columns, QueryFragment::from_sql(", ")));
        fragment.push(QueryFragment::from_sql("))"));
        fragment
    }
}