use std::fmt::Display; use phono_backends::escape_identifier; use serde::{Deserialize, Serialize}; use sqlx::{Postgres, QueryBuilder}; use crate::datum::Datum; /// Representation of a partial, parameterized SQL query. Allows callers to /// build queries iteratively and dynamically, handling parameter numbering /// (`$1`, `$2`, `$3`, ...) automatically. /// /// This is similar to [`sqlx::QueryBuilder`], except that [`QueryFragment`] /// objects are composable and may be concatenated to each other. #[derive(Clone, Debug, PartialEq)] pub struct QueryFragment { /// SQL string, split wherever there is a query parameter. For example, /// `select * from foo where id = $1 and status = $2` is represented along /// the lines of `["select * from foo where id = ", " and status = ", ""]`. /// `plain_sql` should always have exactly one more element than `params`. plain_sql: Vec, params: Vec, } impl QueryFragment { /// Validate invariants. Should be run immediately before returning any /// useful output. fn gut_checks(&self) { assert!(self.plain_sql.len() == self.params.len() + 1); } pub fn to_sql(&self, first_param_idx: usize) -> String { self.gut_checks(); self.plain_sql .iter() .cloned() .zip((first_param_idx..).map(|n| format!("${n}"))) .fold( Vec::with_capacity(2 * self.plain_sql.len()), |mut acc, pair| { acc.extend([pair.0, pair.1]); acc }, ) .get(0..(2 * self.plain_sql.len() - 1)) .expect("already asserted sufficient length") .join("") } /// Returns only the parameterized values, in order. pub fn to_params(&self) -> Vec { self.gut_checks(); self.params.clone() } /// Parse from a SQL string with no parameters. pub fn from_sql(sql: &str) -> Self { Self { plain_sql: vec![sql.to_owned()], params: vec![], } } /// Convenience function to construct an empty value. pub fn empty() -> Self { Self::from_sql("") } /// Parse from a parameter value with no additional SQL. (Renders as `$n`, /// where`n` is the appropriate parameter index.) pub fn from_param(param: Datum) -> Self { Self { plain_sql: vec!["".to_owned(), "".to_owned()], params: vec![param], } } /// Append another query fragment to this one. pub fn push(&mut self, mut other: QueryFragment) { let tail = self .plain_sql .pop() .expect("already asserted that vec contains at least 1 item"); let head = other .plain_sql .first() .expect("already asserted that vec contains at least 1 item"); self.plain_sql.push(format!("{tail}{head}")); for value in other.plain_sql.drain(1..) { self.plain_sql.push(value); } self.params.append(&mut other.params); } /// Combine multiple QueryFragments with a separator, similar to /// [`Vec::join`]. pub fn join>(fragments: I, sep: Self) -> Self { let mut acc = QueryFragment::from_sql(""); let mut iter = fragments.into_iter(); let mut fragment = match iter.next() { Some(value) => value, None => return acc, }; for next_fragment in iter { acc.push(fragment); acc.push(sep.clone()); fragment = next_fragment; } acc.push(fragment); acc } /// Convenience method equivalent to: /// `QueryFragment::concat(fragments, QueryFragment::from_sql(""))` pub fn concat>(fragments: I) -> Self { Self::join(fragments, Self::from_sql("")) } /// Checks whether value is empty. A value is considered empty if the /// resulting SQL code is 0 characters long. pub fn is_empty(&self) -> bool { self.gut_checks(); self.plain_sql.len() == 1 && self .plain_sql .first() .expect("already checked that len == 1") .is_empty() } } impl From for QueryBuilder<'_, Postgres> { fn from(value: QueryFragment) -> Self { value.gut_checks(); let mut builder = QueryBuilder::new(""); let mut param_iter = value.params.into_iter(); for plain_sql in value.plain_sql { builder.push(plain_sql); if let Some(param) = param_iter.next() { param.push_bind_onto(&mut builder); } } builder } } /// Building block of a syntax tree for a constrained subset of SQL that can be /// statically analyzed, to validate that user-provided expressions perform only /// operations that are read-only and otherwise safe to execute. #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] #[serde(tag = "t", content = "c")] pub enum PgExpressionAny { Comparison(PgComparisonExpression), Identifier(PgIdentifierExpression), Literal(Datum), ToJson(PgToJsonExpression), } impl PgExpressionAny { pub fn into_query_fragment(self) -> QueryFragment { match self { Self::Comparison(expr) => expr.into_query_fragment(), Self::Identifier(expr) => expr.into_query_fragment(), Self::Literal(expr) => { if expr.is_none() { QueryFragment::from_sql("null") } else { QueryFragment::from_param(expr) } } Self::ToJson(expr) => expr.into_query_fragment(), } } } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] #[serde(tag = "t", content = "c")] pub enum PgComparisonExpression { Infix(PgInfixExpression), IsNull(PgIsNullExpression), IsNotNull(PgIsNotNullExpression), } impl PgComparisonExpression { fn into_query_fragment(self) -> QueryFragment { match self { Self::Infix(expr) => expr.into_query_fragment(), Self::IsNull(expr) => expr.into_query_fragment(), Self::IsNotNull(expr) => expr.into_query_fragment(), } } } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] pub struct PgInfixExpression { pub operator: T, pub lhs: Box, pub rhs: Box, } impl PgInfixExpression { fn into_query_fragment(self) -> QueryFragment { QueryFragment::concat([ QueryFragment::from_sql("(("), self.lhs.into_query_fragment(), QueryFragment::from_sql(&format!(") {} (", self.operator)), self.rhs.into_query_fragment(), QueryFragment::from_sql("))"), ]) } } #[derive(Clone, Debug, strum::Display, Deserialize, PartialEq, Serialize)] pub enum PgComparisonOperator { #[strum(to_string = "and")] And, #[strum(to_string = "=")] Eq, #[strum(to_string = ">")] Gt, #[strum(to_string = "<")] Lt, #[strum(to_string = "<>")] Neq, #[strum(to_string = "or")] Or, } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] pub struct PgIsNullExpression { lhs: Box, } impl PgIsNullExpression { fn into_query_fragment(self) -> QueryFragment { QueryFragment::concat([ QueryFragment::from_sql("(("), self.lhs.into_query_fragment(), QueryFragment::from_sql(") is null)"), ]) } } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] pub struct PgIsNotNullExpression { lhs: Box, } impl PgIsNotNullExpression { fn into_query_fragment(self) -> QueryFragment { QueryFragment::concat([ QueryFragment::from_sql("(("), self.lhs.into_query_fragment(), QueryFragment::from_sql(") is not null)"), ]) } } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] pub struct PgIdentifierExpression { pub parts_raw: Vec, } impl PgIdentifierExpression { fn into_query_fragment(self) -> QueryFragment { QueryFragment::join( self.parts_raw .iter() .map(|part| QueryFragment::from_sql(&escape_identifier(part))), QueryFragment::from_sql("."), ) } } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] pub struct PgToJsonExpression { entries: Vec<(String, PgExpressionAny)>, } impl PgToJsonExpression { /// Generates a query fragment to the effect of: /// `to_json((select ($expr) as "ident", ($expr2) as "ident2"))` fn into_query_fragment(self) -> QueryFragment { if self.entries.is_empty() { QueryFragment::from_sql("'{}'") } else { QueryFragment::concat([ QueryFragment::from_sql("to_json((select "), QueryFragment::join( self.entries.into_iter().map(|(key, value)| { QueryFragment::concat([ QueryFragment::from_sql("("), value.into_query_fragment(), QueryFragment::from_sql(&format!(") as {}", escape_identifier(&key))), ]) }), QueryFragment::from_sql(", "), ), QueryFragment::from_sql("))"), ]) } } }