phonograph/phono-pestgros/src/lib.rs

371 lines
13 KiB
Rust
Raw Normal View History

2026-02-13 08:00:23 +00:00
//! Incomplete but useful parser and generator for Postgres flavored SQL
//! expressions and more, based on a modified version of the
//! [official Pest SQL grammar](https://github.com/pest-parser/pest/blob/79dd30d11aab6f0fba3cd79bd48f456209b966b3/grammars/src/grammars/sql.pest).
//!
//! ## Example
//!
//! ```
//! use phono_pestgros::{ArithOp, BoolOp, Datum, Expr, InfixOp};
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let expr = Expr::try_from("3 + 5 < 10")?;
//!
//! assert_eq!(expr, Expr::Infix {
//! lhs: Box::new(Expr::Infix {
//! lhs: Box::new(Expr::Literal(Datum::Numeric(Some(3.into())))),
//! op: InfixOp::ArithInfix(ArithOp::Add),
//! rhs: Box::new(Expr::Literal(Datum::Numeric(Some(5.into())))),
//! }),
//! op: InfixOp::BoolInfix(BoolOp::Lt),
//! rhs: Box::new(Expr::Literal(Datum::Numeric(Some(10.into())))),
//! });
//!
//! assert_eq!(QueryBuilder::try_from(expr).sql(), "(($1) + ($2)) < ($3)");
//! # Ok(())
//! # }
//! ```
use std::{str::FromStr, sync::LazyLock};
use bigdecimal::BigDecimal;
use pest::{
Parser as _,
iterators::{Pair, Pairs},
pratt_parser::PrattParser,
};
use pest_derive::Parser;
pub use crate::datum::Datum;
mod datum;
mod query_builders;
#[cfg(test)]
mod fragment_tests;
#[cfg(test)]
mod func_invocation_tests;
#[cfg(test)]
mod identifier_tests;
#[cfg(test)]
mod literal_tests;
#[cfg(test)]
mod op_tests;
/// Wraps a raw identifier (table name, column name, and so on) in double
/// quotes so that it can be interpolated into a SQL query safely.
///
/// PostgreSQL folds unquoted identifiers to lowercase, whereas quoted
/// identifiers keep their case. The output of this function is always quoted,
/// so a caller that wants case-insensitive matching is responsible for
/// lowercasing the input beforehand.
pub fn escape_identifier(identifier: &str) -> String {
    // Given input that is already valid text, the only character needing
    // attention is the double quote, which is escaped by doubling it (`"`
    // becomes `""`). Backslashes pass through untouched. Compare
    // PQescapeInternal() in libpq (fe-exec.c) and Diesel's
    // PgQueryBuilder::push_identifier().
    let mut quoted = String::with_capacity(identifier.len() + 2);
    quoted.push('"');
    for ch in identifier.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
/// Decodes a SQL representation of an identifier. If the input is unquoted, it
/// is converted to lowercase. If it is double quoted, the surrounding quotes
/// are stripped and escaped inner double quotes (`""`) are collapsed to a
/// single `"`. Roughly the inverse of [`escape_identifier`].
///
/// Assumes that the provided identifier is well-formed. Basic gut checks are
/// performed, but they are non-exhaustive.
///
/// `U&"..."`-style escaped Unicode identifiers are not yet supported.
///
/// # Panics
///
/// Panics if `value` starts with `u&`/`U&`, or if it opens with a double quote
/// but lacks a separate closing double quote (this includes the
/// single-character input `"`).
fn parse_ident(value: &str) -> String {
    // Only the first two bytes matter here, so inspect them directly rather
    // than allocating a lowercased copy of the whole input.
    assert!(
        !matches!(value.as_bytes(), [b'u' | b'U', b'&', ..]),
        "escaped Unicode identifiers are not supported"
    );
    if let Some(rest) = value.strip_prefix('"') {
        // Stripping the prefix first guarantees that the opening and closing
        // quotes are two distinct characters, so a lone `"` is rejected.
        let Some(inner) = rest.strip_suffix('"') else {
            panic!("malformed double-quoted identifier");
        };
        inner.replace("\"\"", "\"")
    } else {
        // TODO: assert validity with regex
        // NOTE(review): PostgreSQL's folding of non-ASCII characters depends on
        // the database encoding; confirm that full Unicode lowercasing is the
        // desired behavior here.
        value.to_lowercase()
    }
}
/// Decodes a single-quoted string literal. Removes the surrounding quotes and
/// collapses each embedded escaped quote (`''`) into a single `'`.
///
/// Assumes that the provided literal is well-formed. Basic gut checks are
/// performed, but they are non-exhaustive.
///
/// `E'...'`-style, dollar-quoted, and other (relatively) uncommon formats for
/// text literals are not yet supported.
///
/// # Panics
///
/// Panics if `value` is not wrapped in a distinct opening and closing single
/// quote (this includes the single-character input `'`).
fn parse_text_literal(value: &str) -> String {
    // Strip the quotes one after the other so that a lone `'` cannot satisfy
    // both the "starts with" and "ends with" checks at once.
    let inner = value
        .strip_prefix('\'')
        .and_then(|rest| rest.strip_suffix('\''))
        .expect("text literal should be wrapped in single quotes");
    inner.replace("''", "'")
}
/// Primary parser and code generation for [`Rule`] types.
///
/// This is a field-less marker type: the `Parser` derive builds the parsing
/// machinery from the grammar file named in the `grammar` attribute below.
/// The only use in this file is `PsqlParser::parse(Rule::ExprRoot, ...)`.
#[derive(Parser)]
#[grammar = "src/grammar.pest"]
struct PsqlParser;
/// Secondary parser configuration for handling operator precedence.
///
/// Each `.op(...)` call registers one precedence tier. Tiers are listed from
/// loosest-binding (first) to tightest-binding (last), and operators joined
/// with `|` share a tier. The table is built lazily on first use.
static PRATT_PARSER: LazyLock<PrattParser<Rule>> = LazyLock::new(|| {
    use pest::pratt_parser::{
        Assoc::{Left, Right},
        Op,
    };
    PrattParser::new()
        .op(Op::infix(Rule::Or, Left))
        .op(Op::infix(Rule::Between, Left))
        .op(Op::infix(Rule::And, Left))
        .op(Op::prefix(Rule::UnaryNot))
        // PostgreSQL ranks `IS [NOT] NULL` above `NOT` but below the comparison
        // operators, so `a = b IS NULL` means `(a = b) IS NULL` and
        // `a + b IS NULL` means `(a + b) IS NULL`. Registering the postfix as
        // the tightest-binding operator would instead attach it to the nearest
        // operand only.
        .op(Op::postfix(Rule::IsNullPostfix))
        .op(Op::infix(Rule::Eq, Right)
            | Op::infix(Rule::NotEq, Right)
            | Op::infix(Rule::Gt, Right)
            | Op::infix(Rule::GtEq, Right)
            | Op::infix(Rule::Lt, Right)
            | Op::infix(Rule::LtEq, Right)
            | Op::infix(Rule::In, Right))
        // Official Pest example overstates the concat operator's precedence. It
        // should be lower precedence than add/subtract.
        .op(Op::infix(Rule::ConcatInfixOp, Left))
        .op(Op::infix(Rule::Add, Left) | Op::infix(Rule::Subtract, Left))
        .op(Op::infix(Rule::Multiply, Left) | Op::infix(Rule::Divide, Left))
});
/// Represents a SQL expression. An expression is a collection of values and
/// operators that theoretically evaluates to some value, such as a boolean
/// condition, an object name, or a string dynamically derived from other
/// values. An expression is *not* a complete SQL statement, command, or query.
///
/// Values are normally obtained by parsing text through the `TryFrom<&str>`
/// implementation.
#[derive(Clone, Debug, PartialEq)]
pub enum Expr {
    /// A binary operation: `lhs op rhs`.
    Infix {
        /// Left-hand operand.
        lhs: Box<Expr>,
        /// The operator joining the two operands.
        op: InfixOp,
        /// Right-hand operand.
        rhs: Box<Expr>,
    },
    /// A literal value. The parser in this file produces numeric and text
    /// literals.
    Literal(Datum),
    /// A possibly dot-qualified object name, stored as one decoded segment per
    /// identifier.
    ObjName(Vec<String>),
    /// A function invocation.
    FnCall {
        /// The possibly dot-qualified function name, one decoded segment per
        /// identifier.
        name: Vec<String>,
        /// The arguments supplied to the function.
        args: FnArgs,
    },
    /// Logical negation of the inner expression (parsed from the `UnaryNot`
    /// prefix rule).
    Not(Box<Expr>),
    /// A nullness test on `expr` (parsed from the `IsNullPostfix` rule).
    Nullness {
        /// `true` unless the postfix operator carried a `NotFlag`.
        is_null: bool,
        /// The expression being tested.
        expr: Box<Expr>,
    },
}
impl TryFrom<&str> for Expr {
type Error = ParseError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
// `ExprRoot` is a silent rule which simply dictates that the inner
// `Expr` rule must consume the entire input.
let pairs = PsqlParser::parse(Rule::ExprRoot, value)?;
parse_expr_pairs(pairs)
}
}
/// The operator of an [`Expr::Infix`] node, grouped by operator family.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum InfixOp {
    /// One of the [`ArithOp`] operators.
    ArithInfix(ArithOp),
    /// One of the [`BoolOp`] operators.
    BoolInfix(BoolOp),
}
/// Infix operators wrapped by [`InfixOp::ArithInfix`]. Despite the name, this
/// family also contains the concatenation operator.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ArithOp {
    /// Parsed from the grammar's `Add` rule.
    Add,
    /// Parsed from the grammar's `ConcatInfixOp` rule.
    Concat,
    /// Parsed from the grammar's `Divide` rule.
    Div,
    /// Parsed from the grammar's `Multiply` rule.
    Mult,
    /// Parsed from the grammar's `Subtract` rule.
    Sub,
}
/// Infix operators wrapped by [`InfixOp::BoolInfix`]: the logical connectives
/// and the comparison operators.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum BoolOp {
    /// Parsed from the grammar's `And` rule.
    And,
    /// Parsed from the grammar's `Or` rule.
    Or,
    /// Parsed from the grammar's `Eq` rule.
    Eq,
    /// Parsed from the grammar's `Gt` rule.
    Gt,
    /// Parsed from the grammar's `GtEq` rule.
    Gte,
    /// Parsed from the grammar's `Lt` rule.
    Lt,
    /// Parsed from the grammar's `LtEq` rule.
    Lte,
    /// Parsed from the grammar's `NotEq` rule.
    Neq,
}
/// The argument list of an [`Expr::FnCall`].
#[derive(Clone, Debug, PartialEq)]
pub enum FnArgs {
    /// The argument list matched the grammar's `CountAsterisk` rule instead of
    /// a list of expressions.
    CountAsterisk,
    /// Zero or more expression arguments. An invocation with no arguments at
    /// all is represented as an empty `exprs` with `distinct_flag` unset.
    Exprs {
        /// `true` for aggregator invocations with the `DISTINCT` keyword
        /// specified.
        distinct_flag: bool,
        /// The argument expressions, in source order.
        exprs: Vec<Expr>,
    },
}
/// Recursive helper, which does most of the work to convert [`pest`]'s pattern
/// matching output to a usable syntax tree.
fn parse_expr_pairs(expr_pairs: Pairs<'_, Rule>) -> Result<Expr, ParseError> {
PRATT_PARSER
.map_primary(|pair| match pair.as_rule() {
Rule::Expr | Rule::ExpressionInParentheses => parse_expr_pairs(pair.into_inner()),
Rule::Decimal | Rule::Double | Rule::Integer | Rule::Unsigned => Ok(Expr::Literal(
Datum::Numeric(Some(BigDecimal::from_str(pair.as_str()).expect(
"parsed numeric values should always be convertible to BigDecimal",
))),
)),
Rule::SingleQuotedString => Ok(Expr::Literal(Datum::Text(Some(parse_text_literal(pair.as_str()))))),
Rule::IdentifierWithOptionalContinuation => {
let mut name: Vec<String> = vec![];
let mut fn_args: Option<FnArgs> = None;
let inner = pair.into_inner();
for inner_pair in inner {
match inner_pair.as_rule() {
Rule::Identifier => {
name.push(parse_ident(inner_pair.as_str()));
}
Rule::QualifiedIdentifierContinuation => {
let ident_cont = inner_pair.as_str();
assert!(
ident_cont.starts_with('.'),
"QualifiedIdentifierContinuation should always start with the infix dot",
);
name.push(parse_ident({
// Strip leading dot.
let mut chars = ident_cont.chars();
chars.next();
chars.as_str()
}));
}
Rule::FunctionInvocationContinuation => {
fn_args = Some(parse_function_invocation_continuation(inner_pair)?);
}
_ => unreachable!(
"IdentifierWithOptionalContinuation has only 3 valid child rules",
),
}
}
Ok(if let Some(fn_args) = fn_args {
Expr::FnCall { name, args: fn_args }
} else {
Expr::ObjName(name)
})
}
rule => Err(ParseError::UnknownRule(rule)),
})
.map_infix(|lhs, op, rhs| Ok(Expr::Infix {
lhs: Box::new(lhs?),
op: match op.as_rule() {
Rule::Add => InfixOp::ArithInfix(ArithOp::Add),
Rule::ConcatInfixOp => InfixOp::ArithInfix(ArithOp::Concat),
Rule::Divide => InfixOp::ArithInfix(ArithOp::Div),
Rule::Multiply => InfixOp::ArithInfix(ArithOp::Mult),
Rule::Subtract => InfixOp::ArithInfix(ArithOp::Sub),
Rule::And => InfixOp::BoolInfix(BoolOp::And),
Rule::Eq => InfixOp::BoolInfix(BoolOp::Eq),
Rule::Gt => InfixOp::BoolInfix(BoolOp::Gt),
Rule::GtEq => InfixOp::BoolInfix(BoolOp::Gte),
Rule::Lt => InfixOp::BoolInfix(BoolOp::Lt),
Rule::LtEq => InfixOp::BoolInfix(BoolOp::Lte),
Rule::NotEq => InfixOp::BoolInfix(BoolOp::Neq),
Rule::Or => InfixOp::BoolInfix(BoolOp::Or),
rule => Err(ParseError::UnknownRule(rule))?,
},
rhs: Box::new(rhs?),
}))
.map_prefix(|op, child| Ok(match op.as_rule() {
Rule::UnaryNot => Expr::Not(Box::new(child?)),
rule => Err(ParseError::UnknownRule(rule))?,
}))
.map_postfix(|child, op| Ok(match op.as_rule() {
Rule::IsNullPostfix => Expr::Nullness {
is_null: op
.into_inner()
.next()
.map(|inner| inner.as_rule()) != Some(Rule::NotFlag),
expr: Box::new(child?),
},
rule => Err(ParseError::UnknownRule(rule))?,
}))
.parse(expr_pairs)
}
/// Converts a `FunctionInvocationContinuation` pair (the argument list that
/// follows a function name) into [`FnArgs`].
///
/// A continuation with no inner pair yields an empty, non-`DISTINCT` argument
/// list.
///
/// # Errors
///
/// Propagates any error raised while parsing an argument expression.
///
/// # Panics
///
/// Panics if the continuation holds more than one inner pair, or if an inner
/// pair carries a rule the grammar is not expected to produce at that
/// position.
fn parse_function_invocation_continuation(pair: Pair<'_, Rule>) -> Result<FnArgs, ParseError> {
    let mut children = pair.into_inner();
    let fn_args = match children.next() {
        None => FnArgs::Exprs {
            distinct_flag: false,
            exprs: Vec::new(),
        },
        Some(child) => match child.as_rule() {
            Rule::CountAsterisk => FnArgs::CountAsterisk,
            Rule::FunctionArgs => {
                let mut distinct_flag = false;
                let mut exprs = Vec::new();
                for arg in child.into_inner() {
                    match arg.as_rule() {
                        Rule::Distinct => distinct_flag = true,
                        Rule::Expr => exprs.push(parse_expr_pairs(arg.into_inner())?),
                        _ => unreachable!(
                            "only valid children of FunctionArgs are Distinct and Expr"
                        ),
                    }
                }
                FnArgs::Exprs {
                    distinct_flag,
                    exprs,
                }
            }
            _ => unreachable!(
                "only valid children of FunctionInvocationContinuation are FunctionArgs and CountAsterisk"
            ),
        },
    };
    // Anything left over means the first child did not account for the whole
    // continuation.
    assert!(
        children.next().is_none(),
        "function should have consumed entire FunctionInvocationContinuation pair",
    );
    Ok(fn_args)
}
/// Errors produced while converting SQL text into an [`Expr`].
#[derive(Clone, Debug, PartialEq, thiserror::Error)]
#[error("parse error")]
pub enum ParseError {
    /// The grammar matched the input, but the resulting tree contained a rule
    /// that the tree-building code does not handle. The offending rule is
    /// carried as the payload; note that the display message does not include
    /// it.
    #[error("unknown rule")]
    UnknownRule(Rule),
    /// The input was rejected by the [`pest`] grammar itself; the payload is
    /// pest's own error value.
    #[error("pest failed to parse: {0}")]
    Pest(pest::error::Error<Rule>),
}
impl From<pest::error::Error<Rule>> for ParseError {
fn from(value: pest::error::Error<Rule>) -> Self {
Self::Pest(value)
}
}