251 lines
7.6 KiB
Rust
251 lines
7.6 KiB
Rust
|
|
use std::fmt::Display;
|
||
|
|
|
||
|
|
use interim_pgtypes::escape_identifier;
|
||
|
|
use serde::{Deserialize, Serialize};
|
||
|
|
|
||
|
|
use crate::field::Encodable;
|
||
|
|
|
||
|
|
#[derive(Clone, Debug, PartialEq)]
|
||
|
|
pub struct QueryFragment {
|
||
|
|
/// SQL string, split wherever there is a query parameter. For example,
|
||
|
|
/// `select * from foo where id = $1 and status = $2` is represented along
|
||
|
|
/// the lines of `["select * from foo where id = ", " and status = ", ""]`.
|
||
|
|
/// `plain_sql` should always have exactly one more element than `params`.
|
||
|
|
plain_sql: Vec<String>,
|
||
|
|
params: Vec<Encodable>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl QueryFragment {
|
||
|
|
pub fn to_sql(&self, first_param_idx: usize) -> String {
|
||
|
|
assert!(self.plain_sql.len() == self.params.len() + 1);
|
||
|
|
self.plain_sql
|
||
|
|
.iter()
|
||
|
|
.cloned()
|
||
|
|
.zip((first_param_idx..).map(|n| format!("${n}")))
|
||
|
|
.fold(
|
||
|
|
Vec::with_capacity(2 * self.plain_sql.len()),
|
||
|
|
|mut acc, pair| {
|
||
|
|
acc.extend([pair.0, pair.1]);
|
||
|
|
acc
|
||
|
|
},
|
||
|
|
)
|
||
|
|
.get(0..(2 * self.plain_sql.len() - 1))
|
||
|
|
.expect("already asserted sufficient length")
|
||
|
|
.join("")
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn to_params(&self) -> Vec<Encodable> {
|
||
|
|
self.params.clone()
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn from_sql(sql: &str) -> Self {
|
||
|
|
Self {
|
||
|
|
plain_sql: vec![sql.to_owned()],
|
||
|
|
params: vec![],
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn from_param(param: Encodable) -> Self {
|
||
|
|
Self {
|
||
|
|
plain_sql: vec!["".to_owned(), "".to_owned()],
|
||
|
|
params: vec![param],
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn push(&mut self, mut other: QueryFragment) {
|
||
|
|
assert!(self.plain_sql.len() == self.params.len() + 1);
|
||
|
|
assert!(other.plain_sql.len() == other.params.len() + 1);
|
||
|
|
let tail = self
|
||
|
|
.plain_sql
|
||
|
|
.pop()
|
||
|
|
.expect("already asserted that vec contains at least 1 item");
|
||
|
|
let head = other
|
||
|
|
.plain_sql
|
||
|
|
.first()
|
||
|
|
.expect("already asserted that vec contains at least 1 item");
|
||
|
|
self.plain_sql.push(format!("{tail}{head}"));
|
||
|
|
for value in other.plain_sql.drain(1..) {
|
||
|
|
self.plain_sql.push(value);
|
||
|
|
}
|
||
|
|
self.params.append(&mut other.params);
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Combine multiple QueryFragments with a separator, similar to Vec::join().
|
||
|
|
pub fn join<I: IntoIterator<Item = Self>>(fragments: I, sep: Self) -> Self {
|
||
|
|
let mut acc = QueryFragment::from_sql("");
|
||
|
|
let mut iter = fragments.into_iter();
|
||
|
|
let mut fragment = match iter.next() {
|
||
|
|
Some(value) => value,
|
||
|
|
None => return acc,
|
||
|
|
};
|
||
|
|
for next_fragment in iter {
|
||
|
|
acc.push(fragment);
|
||
|
|
acc.push(sep.clone());
|
||
|
|
fragment = next_fragment;
|
||
|
|
}
|
||
|
|
acc.push(fragment);
|
||
|
|
acc
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Convenience method equivalent to:
|
||
|
|
/// `QueryFragment::concat(fragments, QueryFragment::from_sql(""))`
|
||
|
|
pub fn concat<I: IntoIterator<Item = Self>>(fragments: I) -> Self {
|
||
|
|
Self::join(fragments, Self::from_sql(""))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||
|
|
#[serde(tag = "t", content = "c")]
|
||
|
|
pub enum PgExpressionAny {
|
||
|
|
Comparison(PgComparisonExpression),
|
||
|
|
Identifier(PgIdentifierExpression),
|
||
|
|
Literal(Encodable),
|
||
|
|
ToJson(PgToJsonExpression),
|
||
|
|
}
|
||
|
|
|
||
|
|
impl PgExpressionAny {
|
||
|
|
pub fn into_query_fragment(self) -> QueryFragment {
|
||
|
|
match self {
|
||
|
|
Self::Comparison(expr) => expr.into_query_fragment(),
|
||
|
|
Self::Identifier(expr) => expr.into_query_fragment(),
|
||
|
|
Self::Literal(expr) => {
|
||
|
|
if expr.is_none() {
|
||
|
|
QueryFragment::from_sql("null")
|
||
|
|
} else {
|
||
|
|
QueryFragment::from_param(expr)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
Self::ToJson(expr) => expr.into_query_fragment(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||
|
|
#[serde(tag = "t", content = "c")]
|
||
|
|
pub enum PgComparisonExpression {
|
||
|
|
Infix(PgInfixExpression<PgComparisonOperator>),
|
||
|
|
IsNull(PgIsNullExpression),
|
||
|
|
IsNotNull(PgIsNotNullExpression),
|
||
|
|
}
|
||
|
|
|
||
|
|
impl PgComparisonExpression {
|
||
|
|
fn into_query_fragment(self) -> QueryFragment {
|
||
|
|
match self {
|
||
|
|
Self::Infix(expr) => expr.into_query_fragment(),
|
||
|
|
Self::IsNull(expr) => expr.into_query_fragment(),
|
||
|
|
Self::IsNotNull(expr) => expr.into_query_fragment(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||
|
|
pub struct PgInfixExpression<T: Display> {
|
||
|
|
pub operator: T,
|
||
|
|
pub lhs: Box<PgExpressionAny>,
|
||
|
|
pub rhs: Box<PgExpressionAny>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl<T: Display> PgInfixExpression<T> {
|
||
|
|
fn into_query_fragment(self) -> QueryFragment {
|
||
|
|
QueryFragment::concat([
|
||
|
|
QueryFragment::from_sql("(("),
|
||
|
|
self.lhs.into_query_fragment(),
|
||
|
|
QueryFragment::from_sql(&format!(") {} (", self.operator)),
|
||
|
|
self.rhs.into_query_fragment(),
|
||
|
|
QueryFragment::from_sql("))"),
|
||
|
|
])
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Clone, Debug, strum::Display, Deserialize, PartialEq, Serialize)]
|
||
|
|
pub enum PgComparisonOperator {
|
||
|
|
#[strum(to_string = "and")]
|
||
|
|
And,
|
||
|
|
#[strum(to_string = "=")]
|
||
|
|
Eq,
|
||
|
|
#[strum(to_string = ">")]
|
||
|
|
Gt,
|
||
|
|
#[strum(to_string = "<")]
|
||
|
|
Lt,
|
||
|
|
#[strum(to_string = "<>")]
|
||
|
|
Neq,
|
||
|
|
#[strum(to_string = "or")]
|
||
|
|
Or,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||
|
|
pub struct PgIsNullExpression {
|
||
|
|
lhs: Box<PgExpressionAny>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl PgIsNullExpression {
|
||
|
|
fn into_query_fragment(self) -> QueryFragment {
|
||
|
|
QueryFragment::concat([
|
||
|
|
QueryFragment::from_sql("(("),
|
||
|
|
self.lhs.into_query_fragment(),
|
||
|
|
QueryFragment::from_sql(") is null)"),
|
||
|
|
])
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||
|
|
pub struct PgIsNotNullExpression {
|
||
|
|
lhs: Box<PgExpressionAny>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl PgIsNotNullExpression {
|
||
|
|
fn into_query_fragment(self) -> QueryFragment {
|
||
|
|
QueryFragment::concat([
|
||
|
|
QueryFragment::from_sql("(("),
|
||
|
|
self.lhs.into_query_fragment(),
|
||
|
|
QueryFragment::from_sql(") is not null)"),
|
||
|
|
])
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||
|
|
pub struct PgIdentifierExpression {
|
||
|
|
pub parts_raw: Vec<String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl PgIdentifierExpression {
|
||
|
|
fn into_query_fragment(self) -> QueryFragment {
|
||
|
|
QueryFragment::join(
|
||
|
|
self.parts_raw
|
||
|
|
.iter()
|
||
|
|
.map(|part| QueryFragment::from_sql(&escape_identifier(part))),
|
||
|
|
QueryFragment::from_sql("."),
|
||
|
|
)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||
|
|
pub struct PgToJsonExpression {
|
||
|
|
entries: Vec<(String, PgExpressionAny)>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl PgToJsonExpression {
|
||
|
|
/// Generates a query fragment to the effect of:
|
||
|
|
/// `to_json((select ($expr) as "ident", ($expr2) as "ident2"))`
|
||
|
|
fn into_query_fragment(self) -> QueryFragment {
|
||
|
|
if self.entries.is_empty() {
|
||
|
|
QueryFragment::from_sql("'{}'")
|
||
|
|
} else {
|
||
|
|
QueryFragment::concat([
|
||
|
|
QueryFragment::from_sql("to_json((select "),
|
||
|
|
QueryFragment::join(
|
||
|
|
self.entries.into_iter().map(|(key, value)| {
|
||
|
|
QueryFragment::concat([
|
||
|
|
QueryFragment::from_sql("("),
|
||
|
|
value.into_query_fragment(),
|
||
|
|
QueryFragment::from_sql(&format!(") as {}", escape_identifier(&key))),
|
||
|
|
])
|
||
|
|
}),
|
||
|
|
QueryFragment::from_sql(", "),
|
||
|
|
),
|
||
|
|
QueryFragment::from_sql("))"),
|
||
|
|
])
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|