add support for any() and all() array comparisons

This commit is contained in:
Brent Schroeter 2026-02-16 04:31:55 +00:00
parent 928f6cb759
commit b2257ab1c0
6 changed files with 258 additions and 89 deletions

View file

@ -0,0 +1,52 @@
use crate::{Datum, Expr, InfixOp};
#[test]
fn empty_array_parses() {
assert_eq!(Expr::try_from("array[]"), Ok(Expr::Array(vec![])));
}
#[test]
fn array_of_literals_parses() {
assert_eq!(
Expr::try_from("array[1, 2, 3]"),
Ok(Expr::Array(vec![
Expr::Literal(Datum::Numeric(Some(1.into()))),
Expr::Literal(Datum::Numeric(Some(2.into()))),
Expr::Literal(Datum::Numeric(Some(3.into()))),
])),
);
}
#[test]
fn array_of_exprs_parses() {
assert_eq!(
Expr::try_from("array[(1), 2 + 3]"),
Ok(Expr::Array(vec![
Expr::Literal(Datum::Numeric(Some(1.into()))),
Expr::Infix {
lhs: Box::new(Expr::Literal(Datum::Numeric(Some(2.into())))),
op: InfixOp::Add,
rhs: Box::new(Expr::Literal(Datum::Numeric(Some(3.into())))),
},
])),
);
}
#[test]
fn array_cmp_modifier_parses() {
assert_eq!(
Expr::try_from("3 = any(array[3])"),
Ok(Expr::Infix {
lhs: Box::new(Expr::Literal(Datum::Numeric(Some(3.into())))),
op: InfixOp::WithCmpModifierAny(Box::new(InfixOp::Eq)),
rhs: Box::new(Expr::Array(vec![Expr::Literal(Datum::Numeric(Some(
3.into()
)))]))
}),
);
}
#[test]
fn non_parenthesized_array_cmp_modifier_fails() {
assert!(Expr::try_from("3 = any array[3]").is_err());
}

View file

@ -13,3 +13,13 @@ fn sql_converts_to_query_builder() -> Result<(), Box<dyn Error>> {
);
Ok(())
}
#[test]
fn cmp_array_modifier_round_trips() -> Result<(), Box<dyn Error>> {
let expr = Expr::try_from("1 = 2 and 3 < any(array[4])")?;
assert_eq!(
QueryBuilder::<'_, Postgres>::from(expr).sql(),
"(($1) = ($2)) and (($3) < any (array[($4)]))",
);
Ok(())
}

View file

@ -2,16 +2,14 @@
//! https://github.com/pest-parser/pest/blob/master/grammars/src/grammars/sql.pest.
//! (Original is dual-licensed under MIT/Apache-2.0.)
//!
//! Postgres largely conforms to the SQLite flavored dialect captured by the
//! original grammar, but its rules for identifiers differ:
//! PostgreSQL departs extensively from the SQLite flavored dialect captured in
//! the original grammar. For example, rules for identifiers/object names
//! differ, as do keywords, built-in types, and syntax for specifying function
//! arguments, type modifiers, CTEs, and so on.
//!
//! > SQL identifiers and key words must begin with a letter (a-z, but also
//! > letters with diacritical marks and non-Latin letters) or an underscore
//! > (_). Subsequent characters in an identifier or key word can be letters,
//! > underscores, digits (0-9), or dollar signs ($). Note that dollar signs are
//! > not allowed in identifiers according to the letter of the SQL standard,
//! > so their use might render applications less portable.
//! -- https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
//! This grammar covers a larger subset of the Postgres SQL dialect, but it is a
//! work in progress and is far from complete. It should only be used to parse
//! input that is "PostgreSQL-esque", not input that expects spec compliance.
Command = _{ SOI ~ (Query | ExplainQuery | DDL | ACL) ~ EOF }
@ -65,7 +63,7 @@ DDL = _{ CreateTable | DropTable | CreateProc }
Distribution
}
Columns = { ColumnDef ~ ("," ~ ColumnDef)* }
ColumnDef = { Identifier ~ ColumnDefType ~ ColumnDefIsNull? }
ColumnDef = { Identifier ~ TypeCast ~ ColumnDefIsNull? }
ColumnDefIsNull = { NotFlag? ~ ^"null" }
PrimaryKey = {
^"primary" ~ ^"key" ~
@ -82,7 +80,7 @@ DDL = _{ CreateTable | DropTable | CreateProc }
((^"as" ~ "$$" ~ ProcBody ~ "$$") | (^"begin" ~ "atomic" ~ ProcBody ~ "end"))
}
ProcParams = { ProcParamDef ~ ("," ~ ProcParamDef)* }
ProcParamDef = { ColumnDefType }
ProcParamDef = { TypeCast }
ProcLanguage = { SQL }
SQL = { ^"sql" }
ProcBody = { (Insert | Update | Delete) }
@ -130,38 +128,48 @@ Query = { (SelectWithOptionalContinuation | Values | Insert | Update | Delete) }
Identifier = ${ DoubleQuotedIdentifier | UnquotedIdentifier }
DoubleQuotedIdentifier = @{ "\"" ~ ("\"\"" | '\u{01}'..'\u{21}' | '\u{23}'..'\u{10FFFF}')+ ~ "\"" }
UnquotedIdentifier = @{ !(Keyword ~ ("(" | WHITESPACE | "," | EOF)) ~ (UnquotedIdentifierStart ~ UnquotedIdentifierRemainder*) }
UnquotedIdentifier = @{ !(Keyword ~ ("(" | "[" | WHITESPACE | "," | EOF)) ~ (UnquotedIdentifierStart ~ UnquotedIdentifierRemainder*) }
UnquotedIdentifierStart = _{ 'a'..'я' | 'A'..'Я' | "_" }
UnquotedIdentifierRemainder = _{ UnquotedIdentifierStart | "$" | ASCII_DIGIT }
Keyword = { ^"left" | ^"having" | ^"not" | ^"inner" | ^"group"
| ^"on" | ^"join" | ^"from" | ^"exists" | ^"except"
| ^"union" | ^"where" | ^"distinct" | ^"between" | ^"option"
| ^"values"}
| ^"values" | ^"with" | ^"as" | ^"array" | ^"any" | ^"some"
| ^"all" | ^"in" }
ExprRoot = _{ &SOI ~ Expr ~ &EOI }
Expr = { ExprAtomValue ~ (ExprInfixOp ~ ExprAtomValue)* }
ExprInfixOp = _{ Between | ArithInfixOp | CmpInfixOp | ConcatInfixOp | And | Or }
ExprInfixOp = _{ Between | NonCmpInfixOp | CmpInfixOp | ConcatInfixOp | And | Or }
Between = { NotFlag? ~ ^"between" }
And = { ^"and" }
Or = { ^"or" }
ConcatInfixOp = { "||" }
ArithInfixOp = _{ Add | Subtract | Multiply | Divide }
Add = { "+" }
Subtract = { "-" }
Multiply = { "*" }
Divide = { "/" }
CmpInfixOp = _{ NotEq | GtEq | Gt | LtEq | Lt | Eq | Lt | In }
CmpInfixOp = { (NotEq | GtEq | Gt | LtEq | Lt | Eq | Lt) ~ (CmpArrayModifier ~ &ExpressionInParentheses)? }
Eq = { "=" }
Gt = { ">" }
GtEq = { ">=" }
Lt = { "<" }
LtEq = { "<=" }
NotEq = { "<>" | "!=" }
NonCmpInfixOp = _{ Add | Subtract | Multiply | Divide | In }
Add = { "+" }
Subtract = { "-" }
Multiply = { "*" }
Divide = { "/" }
In = { NotFlag? ~ ^"in" }
CmpArrayModifier = { CmpModifierAny | CmpModifierAll }
CmpModifierAny = { ^"any" | ^"some "}
CmpModifierAll = { ^"all" }
ExprAtomValue = _{ UnaryNot* ~ AtomicExpr ~ IsNullPostfix? }
UnaryNot = @{ NotFlag }
IsNullPostfix = { ^"is" ~ NotFlag? ~ ^"null" }
AtomicExpr = _{ Literal | Parameter | Cast | IdentifierWithOptionalContinuation | ExpressionInParentheses | UnaryOperator | SubQuery | Row }
AtomicExpr = _{ Literal | Parameter | IdentifierWithOptionalContinuation | ExpressionInParentheses | UnaryOperator | SubQuery | Row | SquareBracketArray }
// TODO: Empty arrays don't parse without the `!"]"` prefix in the
// optional sequence of sub-expressions, but the reason is not
// immediately clear: the ']' character doesn't seem like it should
// be compatible with the beginning of any `AtomicExpr`. This may
// be worth investigating.
SquareBracketArray = { ^"array" ~ "[" ~ (!"]" ~ (Expr ~ ("," ~ Expr)*))? ~ "]" }
Literal = _{ True | False | Null | Double | Decimal | Unsigned | Integer | SingleQuotedString }
True = { ^"true" }
False = { ^"false" }
@ -184,24 +192,34 @@ Expr = { ExprAtomValue ~ (ExprInfixOp ~ ExprAtomValue)* }
FunctionInvocationContinuation = { "(" ~ (CountAsterisk | FunctionArgs)? ~ ")" }
// TODO: Support named argument notation
// (`my_func(name => value)`).
// TODO: Support keywords within args list as applicable.
FunctionArgs = { Distinct? ~ (Expr ~ ("," ~ Expr)*)? }
CountAsterisk = { "*" }
ExpressionInParentheses = { "(" ~ Expr ~ ")" }
Cast = { ^"cast" ~ "(" ~ Expr ~ ^"as" ~ TypeCast ~ ")" }
TypeCast = _{ TypeAny | ColumnDefType }
ColumnDefType = { TypeBool | TypeDecimal | TypeDouble | TypeInt | TypeNumber
| TypeScalar | TypeString | TypeText | TypeUnsigned | TypeVarchar }
TypeAny = { ^"any" }
CastInfix = { Expr ~ "::" ~ TypeCast }
TypeCast = {
TypeBool
| TypeDecimal
| TypeDouble
| TypeInt
| TypeNumeric
| TypeText
| TypeVarchar
}
TypeBool = { (^"boolean" | ^"bool") }
TypeDecimal = { ^"decimal" }
TypeDouble = { ^"double" }
TypeInt = { (^"integer" | ^"int") }
TypeNumber = { ^"number" }
TypeScalar = { ^"scalar" }
TypeString = { ^"string" }
TypeNumeric = { ^"numeric" }
TypeText = { ^"text" }
TypeUnsigned = { ^"unsigned" }
TypeVarchar = { ^"varchar" ~ "(" ~ Unsigned ~ ")" }
TypeDate = { ^"date" }
TypeTime = { ^"time" ~ Unsigned? ~ (WithTimeZone | WithoutTimeZone)? }
TypeTimestamp = { ^"timestamp" ~ Unsigned? ~ (WithTimeZone | WithoutTimeZone)? }
WithTimeZone = { ^"with" ~ ^"time" ~ ^"zone" }
WithoutTimeZone = { ^"without" ~ ^"time" ~ ^"zone" }
UnaryOperator = _{ Exists }
Exists = { NotFlag? ~ ^"exists" ~ SubQuery }
Row = { "(" ~ Expr ~ ("," ~ Expr)* ~ ")" }

View file

@ -2,6 +2,10 @@
//! expressions and more, based on a modified version of the
//! [official Pest SQL grammar](https://github.com/pest-parser/pest/blob/79dd30d11aab6f0fba3cd79bd48f456209b966b3/grammars/src/grammars/sql.pest).
//!
//! This grammar covers a larger subset of the Postgres SQL dialect, but it is a
//! work in progress and is far from complete. It should only be used to parse
//! input that is "PostgreSQL-esque", not input that expects spec compliance.
//!
//! ## Example
//!
//! ```
@ -40,6 +44,8 @@ pub use crate::{datum::Datum, query_builders::QueryFragment};
mod datum;
mod query_builders;
#[cfg(test)]
mod array_tests;
#[cfg(test)]
mod fragment_tests;
#[cfg(test)]
@ -136,18 +142,13 @@ static PRATT_PARSER: LazyLock<PrattParser<Rule>> = LazyLock::new(|| {
.op(Op::infix(Rule::Between, Left))
.op(Op::infix(Rule::And, Left))
.op(Op::prefix(Rule::UnaryNot))
.op(Op::infix(Rule::Eq, Right)
| Op::infix(Rule::NotEq, Right)
| Op::infix(Rule::Gt, Right)
| Op::infix(Rule::GtEq, Right)
| Op::infix(Rule::Lt, Right)
| Op::infix(Rule::LtEq, Right)
| Op::infix(Rule::In, Right))
.op(Op::infix(Rule::CmpInfixOp, Right))
// Official Pest example overstates the concat operator's precedence. It
// should be lower precedence than add/subtract.
.op(Op::infix(Rule::ConcatInfixOp, Left))
.op(Op::infix(Rule::Add, Left) | Op::infix(Rule::Subtract, Left))
.op(Op::infix(Rule::Multiply, Left) | Op::infix(Rule::Divide, Left))
.op(Op::infix(Rule::CastInfix, Left))
.op(Op::postfix(Rule::IsNullPostfix))
});
@ -174,6 +175,7 @@ pub enum Expr {
is_null: bool,
expr: Box<Expr>,
},
Array(Vec<Expr>),
}
impl TryFrom<&str> for Expr {
@ -188,7 +190,7 @@ impl TryFrom<&str> for Expr {
}
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq)]
pub enum InfixOp {
// Arithmetic ops:
Add,
@ -206,6 +208,18 @@ pub enum InfixOp {
Lt,
Lte,
Neq,
// Miscellaneous ops:
Cast,
// Array comparison modifiers (such as `= any(array[])`):
// TODO: This is an awkward pattern, which is capable of representing
// invalid expressions (such as `3 + any(array[])`). I expect it'll need to
// be rewritten at some point anyways to handle other keyword-driven infix
// syntax, but for expediency I'm leaving a more robust solution as a
// challenge for another day.
WithCmpModifierAny(Box<Self>),
WithCmpModifierAll(Box<Self>),
}
#[derive(Clone, Debug, PartialEq)]
@ -267,6 +281,18 @@ fn parse_expr_pairs(expr_pairs: Pairs<'_, Rule>) -> Result<Expr, ParseError> {
Expr::ObjName(name)
})
}
Rule::SquareBracketArray => {
let mut arr_items: Vec<Expr> = vec![];
for inner_pair in pair.into_inner() {
match inner_pair.as_rule() {
Rule::Expr => {arr_items.push(parse_expr_pairs(inner_pair.into_inner())?);}
_ => unreachable!(
"SquareBracketArray has only Exprs as direct child rules",
),
}
}
Ok(Expr::Array(arr_items))
}
rule => Err(ParseError::UnknownRule(rule)),
})
.map_infix(|lhs, op, rhs| Ok(Expr::Infix {
@ -278,13 +304,9 @@ fn parse_expr_pairs(expr_pairs: Pairs<'_, Rule>) -> Result<Expr, ParseError> {
Rule::Multiply => InfixOp::Mult,
Rule::Subtract => InfixOp::Sub,
Rule::And => InfixOp::And,
Rule::Eq => InfixOp::Eq,
Rule::Gt => InfixOp::Gt,
Rule::GtEq => InfixOp::Gte,
Rule::Lt => InfixOp::Lt,
Rule::LtEq => InfixOp::Lte,
Rule::NotEq => InfixOp::Neq,
Rule::CmpInfixOp => parse_cmp_op(op)?,
Rule::Or => InfixOp::Or,
Rule::CastInfix => InfixOp::Cast,
rule => Err(ParseError::UnknownRule(rule))?,
},
rhs: Box::new(rhs?),
@ -306,6 +328,52 @@ fn parse_expr_pairs(expr_pairs: Pairs<'_, Rule>) -> Result<Expr, ParseError> {
.parse(expr_pairs)
}
fn parse_cmp_op(op: Pair<'_, Rule>) -> Result<InfixOp, ParseError> {
let mut base_op: Option<InfixOp> = None;
for inner in op.into_inner() {
match inner.as_rule() {
Rule::Eq => {
base_op = Some(InfixOp::Eq);
}
Rule::Gt => {
base_op = Some(InfixOp::Gt);
}
Rule::GtEq => {
base_op = Some(InfixOp::Gte);
}
Rule::Lt => {
base_op = Some(InfixOp::Lt);
}
Rule::LtEq => {
base_op = Some(InfixOp::Lte);
}
Rule::NotEq => {
base_op = Some(InfixOp::Neq);
}
Rule::CmpArrayModifier => {
if let Some(base_op) = base_op {
return Ok(
match inner
.into_inner()
.next()
.expect("CmpArrayModifier should be a simple enumeration")
.as_rule()
{
Rule::CmpModifierAny => InfixOp::WithCmpModifierAny(Box::new(base_op)),
Rule::CmpModifierAll => InfixOp::WithCmpModifierAll(Box::new(base_op)),
rule => Err(ParseError::UnknownRule(rule))?,
},
);
} else {
return Err(ParseError::UnknownRule(Rule::CmpArrayModifier));
}
}
rule => Err(ParseError::UnknownRule(rule))?,
}
}
Ok(base_op.expect("CmpInfixOp always has at least one child"))
}
fn parse_function_invocation_continuation(pair: Pair<'_, Rule>) -> Result<FnArgs, ParseError> {
let mut cont_inner_iter = pair.into_inner();
let fn_args = if let Some(cont_inner) = cont_inner_iter.next() {

View file

@ -115,6 +115,8 @@ impl From<Expr> for QueryFragment {
(*lhs).into(),
Self::from_sql(") "),
op.into(),
// The RHS expression **must** be parenthesized to correctly
// reconstruct syntax like `= any (array[...])`.
Self::from_sql(" ("),
(*rhs).into(),
Self::from_sql(")"),
@ -176,27 +178,44 @@ impl From<Expr> for QueryFragment {
fragment.push(Self::from_sql(")"));
fragment
}
Expr::Array(arr_items) => Self::concat([
Self::from_sql("array["),
Self::join(
arr_items.into_iter().map(|item| {
Self::concat([Self::from_sql("("), item.into(), Self::from_sql(")")])
}),
Self::from_sql(", "),
),
Self::from_sql("]"),
]),
}
}
}
impl From<InfixOp> for QueryFragment {
fn from(value: InfixOp) -> Self {
Self::from_sql(match value {
InfixOp::Add => "+",
InfixOp::Concat => "||",
InfixOp::Div => "/",
InfixOp::Mult => "*",
InfixOp::Sub => "-",
InfixOp::And => "and",
InfixOp::Or => "or",
InfixOp::Eq => "=",
InfixOp::Gt => ">",
InfixOp::Gte => ">=",
InfixOp::Lt => "<",
InfixOp::Lte => "<=",
InfixOp::Neq => "<>",
})
match value {
InfixOp::Add => Self::from_sql("+"),
InfixOp::Concat => Self::from_sql("||"),
InfixOp::Div => Self::from_sql("/"),
InfixOp::Mult => Self::from_sql("*"),
InfixOp::Sub => Self::from_sql("-"),
InfixOp::And => Self::from_sql("and"),
InfixOp::Or => Self::from_sql("or"),
InfixOp::Eq => Self::from_sql("="),
InfixOp::Gt => Self::from_sql(">"),
InfixOp::Gte => Self::from_sql(">="),
InfixOp::Lt => Self::from_sql("<"),
InfixOp::Lte => Self::from_sql("<="),
InfixOp::Neq => Self::from_sql("<>"),
InfixOp::Cast => Self::from_sql("::"),
InfixOp::WithCmpModifierAny(inner) => {
Self::concat([(*inner).into(), Self::from_sql(" any")])
}
InfixOp::WithCmpModifierAll(inner) => {
Self::concat([(*inner).into(), Self::from_sql(" all")])
}
}
}
}

View file

@ -270,33 +270,11 @@ fn into_safe_filter_sql(expr_text: &str) -> Option<QueryFragment> {
fn is_safe_filter_expr(expr: &Expr) -> bool {
match expr {
&Expr::Literal(_) | &Expr::ObjName(_) => true,
&Expr::Infix {
ref lhs,
op,
ref rhs,
} => match op {
// Most if not all infix operators should be safe, but enumerate
// them just in case.
// Arithmetic:
InfixOp::Add
| InfixOp::Concat
| InfixOp::Div
| InfixOp::Mult
| InfixOp::Sub
// Boolean:
| InfixOp::And
| InfixOp::Or
| InfixOp::Eq
| InfixOp::Gt
| InfixOp::Gte
| InfixOp::Lt
| InfixOp::Lte
| InfixOp::Neq => is_safe_filter_expr(lhs) && is_safe_filter_expr(rhs),
_ => false,
},
&Expr::FnCall { ref name, ref args } => match name
Expr::Literal(_) | &Expr::ObjName(_) => true,
Expr::Infix { lhs, op, rhs } => {
is_safe_filter_infix_op(op) && is_safe_filter_expr(lhs) && is_safe_filter_expr(rhs)
}
Expr::FnCall { name, args } => match name
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>()
@ -339,11 +317,35 @@ fn is_safe_filter_expr(expr: &Expr) -> bool {
},
_ => false,
},
&Expr::Not(ref inner) => is_safe_filter_expr(inner),
&Expr::Nullness {
Expr::Not(inner) => is_safe_filter_expr(inner),
Expr::Nullness {
is_null: _,
expr: ref inner,
expr: inner,
} => is_safe_filter_expr(inner),
Expr::Array(arr_items) => arr_items.iter().all(is_safe_filter_expr),
_ => false,
}
}
fn is_safe_filter_infix_op(op: &InfixOp) -> bool {
match op {
InfixOp::Add
| InfixOp::Concat
| InfixOp::Div
| InfixOp::Mult
| InfixOp::Sub
// Boolean:
| InfixOp::And
| InfixOp::Or
| InfixOp::Eq
| InfixOp::Gt
| InfixOp::Gte
| InfixOp::Lt
| InfixOp::Lte
| InfixOp::Neq
// Miscellaneous:
| InfixOp::Cast => true,
InfixOp::WithCmpModifierAny(inner) | InfixOp::WithCmpModifierAll(inner) => is_safe_filter_infix_op(inner),
_ => false
}
}