use std::fmt::Display; use interim_pgtypes::escape_identifier; use serde::{Deserialize, Serialize}; use crate::datum::Datum; #[derive(Clone, Debug, PartialEq)] pub struct QueryFragment { /// SQL string, split wherever there is a query parameter. For example, /// `select * from foo where id = $1 and status = $2` is represented along /// the lines of `["select * from foo where id = ", " and status = ", ""]`. /// `plain_sql` should always have exactly one more element than `params`. plain_sql: Vec, params: Vec, } impl QueryFragment { pub fn to_sql(&self, first_param_idx: usize) -> String { assert!(self.plain_sql.len() == self.params.len() + 1); self.plain_sql .iter() .cloned() .zip((first_param_idx..).map(|n| format!("${n}"))) .fold( Vec::with_capacity(2 * self.plain_sql.len()), |mut acc, pair| { acc.extend([pair.0, pair.1]); acc }, ) .get(0..(2 * self.plain_sql.len() - 1)) .expect("already asserted sufficient length") .join("") } pub fn to_params(&self) -> Vec { self.params.clone() } pub fn from_sql(sql: &str) -> Self { Self { plain_sql: vec![sql.to_owned()], params: vec![], } } pub fn from_param(param: Datum) -> Self { Self { plain_sql: vec!["".to_owned(), "".to_owned()], params: vec![param], } } pub fn push(&mut self, mut other: QueryFragment) { assert!(self.plain_sql.len() == self.params.len() + 1); assert!(other.plain_sql.len() == other.params.len() + 1); let tail = self .plain_sql .pop() .expect("already asserted that vec contains at least 1 item"); let head = other .plain_sql .first() .expect("already asserted that vec contains at least 1 item"); self.plain_sql.push(format!("{tail}{head}")); for value in other.plain_sql.drain(1..) { self.plain_sql.push(value); } self.params.append(&mut other.params); } /// Combine multiple QueryFragments with a separator, similar to Vec::join(). pub fn join>(fragments: I, sep: Self) -> Self { let mut acc = QueryFragment::from_sql(""); let mut iter = fragments.into_iter(); let mut fragment = match iter.next() { Some(value) => value, None => return acc, }; for next_fragment in iter { acc.push(fragment); acc.push(sep.clone()); fragment = next_fragment; } acc.push(fragment); acc } /// Convenience method equivalent to: /// `QueryFragment::concat(fragments, QueryFragment::from_sql(""))` pub fn concat>(fragments: I) -> Self { Self::join(fragments, Self::from_sql("")) } } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] #[serde(tag = "t", content = "c")] pub enum PgExpressionAny { Comparison(PgComparisonExpression), Identifier(PgIdentifierExpression), Literal(Datum), ToJson(PgToJsonExpression), } impl PgExpressionAny { pub fn into_query_fragment(self) -> QueryFragment { match self { Self::Comparison(expr) => expr.into_query_fragment(), Self::Identifier(expr) => expr.into_query_fragment(), Self::Literal(expr) => { if expr.is_none() { QueryFragment::from_sql("null") } else { QueryFragment::from_param(expr) } } Self::ToJson(expr) => expr.into_query_fragment(), } } } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] #[serde(tag = "t", content = "c")] pub enum PgComparisonExpression { Infix(PgInfixExpression), IsNull(PgIsNullExpression), IsNotNull(PgIsNotNullExpression), } impl PgComparisonExpression { fn into_query_fragment(self) -> QueryFragment { match self { Self::Infix(expr) => expr.into_query_fragment(), Self::IsNull(expr) => expr.into_query_fragment(), Self::IsNotNull(expr) => expr.into_query_fragment(), } } } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] pub struct PgInfixExpression { pub operator: T, pub lhs: Box, pub rhs: Box, } impl PgInfixExpression { fn into_query_fragment(self) -> QueryFragment { QueryFragment::concat([ QueryFragment::from_sql("(("), self.lhs.into_query_fragment(), QueryFragment::from_sql(&format!(") {} (", self.operator)), self.rhs.into_query_fragment(), QueryFragment::from_sql("))"), ]) } } #[derive(Clone, Debug, strum::Display, Deserialize, PartialEq, Serialize)] pub enum PgComparisonOperator { #[strum(to_string = "and")] And, #[strum(to_string = "=")] Eq, #[strum(to_string = ">")] Gt, #[strum(to_string = "<")] Lt, #[strum(to_string = "<>")] Neq, #[strum(to_string = "or")] Or, } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] pub struct PgIsNullExpression { lhs: Box, } impl PgIsNullExpression { fn into_query_fragment(self) -> QueryFragment { QueryFragment::concat([ QueryFragment::from_sql("(("), self.lhs.into_query_fragment(), QueryFragment::from_sql(") is null)"), ]) } } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] pub struct PgIsNotNullExpression { lhs: Box, } impl PgIsNotNullExpression { fn into_query_fragment(self) -> QueryFragment { QueryFragment::concat([ QueryFragment::from_sql("(("), self.lhs.into_query_fragment(), QueryFragment::from_sql(") is not null)"), ]) } } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] pub struct PgIdentifierExpression { pub parts_raw: Vec, } impl PgIdentifierExpression { fn into_query_fragment(self) -> QueryFragment { QueryFragment::join( self.parts_raw .iter() .map(|part| QueryFragment::from_sql(&escape_identifier(part))), QueryFragment::from_sql("."), ) } } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] pub struct PgToJsonExpression { entries: Vec<(String, PgExpressionAny)>, } impl PgToJsonExpression { /// Generates a query fragment to the effect of: /// `to_json((select ($expr) as "ident", ($expr2) as "ident2"))` fn into_query_fragment(self) -> QueryFragment { if self.entries.is_empty() { QueryFragment::from_sql("'{}'") } else { QueryFragment::concat([ QueryFragment::from_sql("to_json((select "), QueryFragment::join( self.entries.into_iter().map(|(key, value)| { QueryFragment::concat([ QueryFragment::from_sql("("), value.into_query_fragment(), QueryFragment::from_sql(&format!(") as {}", escape_identifier(&key))), ]) }), QueryFragment::from_sql(", "), ), QueryFragment::from_sql("))"), ]) } } }