From a4ffb44f4d3a990526a0aa671e2fc3ed16284e75 Mon Sep 17 00:00:00 2001 From: Brent Schroeter Date: Fri, 13 Feb 2026 08:00:23 +0000 Subject: [PATCH] add phono-pestgros crate --- Cargo.lock | 142 +++++++- Cargo.toml | 1 + phono-pestgros/Cargo.toml | 15 + phono-pestgros/src/datum.rs | 71 ++++ phono-pestgros/src/fragment_tests.rs | 15 + phono-pestgros/src/func_invocation_tests.rs | 76 ++++ phono-pestgros/src/grammar.pest | 212 +++++++++++ phono-pestgros/src/identifier_tests.rs | 37 ++ phono-pestgros/src/lib.rs | 370 ++++++++++++++++++++ phono-pestgros/src/literal_tests.rs | 30 ++ phono-pestgros/src/op_tests.rs | 106 ++++++ phono-pestgros/src/query_builders.rs | 273 +++++++++++++++ 12 files changed, 1337 insertions(+), 11 deletions(-) create mode 100644 phono-pestgros/Cargo.toml create mode 100644 phono-pestgros/src/datum.rs create mode 100644 phono-pestgros/src/fragment_tests.rs create mode 100644 phono-pestgros/src/func_invocation_tests.rs create mode 100644 phono-pestgros/src/grammar.pest create mode 100644 phono-pestgros/src/identifier_tests.rs create mode 100644 phono-pestgros/src/lib.rs create mode 100644 phono-pestgros/src/literal_tests.rs create mode 100644 phono-pestgros/src/op_tests.rs create mode 100644 phono-pestgros/src/query_builders.rs diff --git a/Cargo.lock b/Cargo.lock index be42df4..7f3a314 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -364,6 +364,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "backtrace-ext" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "537beee3be4a18fb023b570f80e3ae28003db9167a751266b259926e25539d50" +dependencies = [ + "backtrace", +] + [[package]] name = "base64" version = "0.13.1" @@ -1222,7 +1231,7 @@ version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfe4fbac503b8d1f88e6676011885f34b7174f46e59956bba534ba83abded4df" dependencies = [ - "unicode-width", + "unicode-width 0.2.2", ] [[package]] @@ -1756,6 +1765,12 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +[[package]] +name = "is_ci" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7655c9839580ee829dfacba1d1278c2b7883e50a277ff7541299489d6bdfdc45" + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -1916,6 +1931,36 @@ version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" +[[package]] +name = "miette" +version = "7.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f98efec8807c63c752b5bd61f862c165c115b0a35685bdcfd9238c7aeb592b7" +dependencies = [ + "backtrace", + "backtrace-ext", + "cfg-if 1.0.3", + "miette-derive", + "owo-colors", + "supports-color", + "supports-hyperlinks", + "supports-unicode", + "terminal_size", + "textwrap", + "unicode-width 0.1.14", +] + +[[package]] +name = "miette-derive" +version = "7.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db5b29714e950dbb20d5e6f74f9dcec4edbcc1067bb7f8ed198c097b8c1a818b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "mime" version = "0.3.17" @@ -2174,6 +2219,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "owo-colors" +version = "4.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c6901729fa79e91a0913333229e9ca5dc725089d1c363b2f4b4760709dc4a52" + [[package]] name = "parking" version = "2.2.1" @@ -2226,20 +2277,22 @@ checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" [[package]] name = "pest" -version = "2.8.0" +version = "2.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "198db74531d58c70a361c42201efde7e2591e976d518caf7662a47dc5720e7b6" +checksum = "e0848c601009d37dfa3430c4666e147e49cdcf1b92ecd3e63657d8a5f19da662" dependencies = [ "memchr", - "thiserror 2.0.12", + "miette", + "serde", + "serde_json", "ucd-trie", ] [[package]] name = "pest_derive" -version = "2.8.0" +version = "2.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d725d9cfd79e87dccc9341a2ef39d1b6f6353d68c4b33c177febbe1a402c97c5" +checksum = "11f486f1ea21e6c10ed15d5a7c77165d0ee443402f0780849d1768e7d9d6fe77" dependencies = [ "pest", "pest_generator", @@ -2247,9 +2300,9 @@ dependencies = [ [[package]] name = "pest_generator" -version = "2.8.0" +version = "2.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db7d01726be8ab66ab32f9df467ae8b1148906685bbe75c82d1e65d7f5b3f841" +checksum = "8040c4647b13b210a963c1ed407c1ff4fdfa01c31d6d2a098218702e6664f94f" dependencies = [ "pest", "pest_meta", @@ -2260,11 +2313,10 @@ dependencies = [ [[package]] name = "pest_meta" -version = "2.8.0" +version = "2.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f9f832470494906d1fca5329f8ab5791cc60beb230c74815dff541cbd2b5ca0" +checksum = "89815c69d36021a140146f26659a81d6c2afa33d216d736dd4be5381a7362220" dependencies = [ - "once_cell", "pest", "sha2 0.10.9", ] @@ -2366,6 +2418,21 @@ dependencies = [ "thiserror 2.0.12", ] +[[package]] +name = "phono-pestgros" +version = "0.0.1" +dependencies = [ + "bigdecimal", + "chrono", + "pest", + "pest_derive", + "serde", + "serde_json", + "sqlx", + "thiserror 2.0.12", + "uuid", +] + [[package]] name = "phono-server" version = "0.0.1" @@ -3489,6 +3556,27 @@ version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" +[[package]] +name = "supports-color" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c64fc7232dd8d2e4ac5ce4ef302b1d81e0b80d055b9d77c7c4f51f6aa4c867d6" +dependencies = [ + "is_ci", +] + +[[package]] +name = "supports-hyperlinks" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e396b6523b11ccb83120b115a0b7366de372751aa6edf19844dfb13a6af97e91" + +[[package]] +name = "supports-unicode" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7401a30af6cb5818bb64852270bb722533397edcfc7344954a38f420819ece2" + [[package]] name = "syn" version = "2.0.101" @@ -3592,6 +3680,26 @@ dependencies = [ "utf-8", ] +[[package]] +name = "terminal_size" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45c6481c4829e4cc63825e62c49186a34538b7b2750b73b266581ffb612fb5ed" +dependencies = [ + "rustix", + "windows-sys 0.59.0", +] + +[[package]] +name = "textwrap" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c13547615a44dc9c452a8a534638acdf07120d4b6847c8178705da06306a3057" +dependencies = [ + "unicode-linebreak", + "unicode-width 0.2.2", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -4020,6 +4128,12 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +[[package]] +name = "unicode-linebreak" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f" + [[package]] name = "unicode-normalization" version = "0.1.24" @@ -4041,6 +4155,12 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +[[package]] +name = "unicode-width" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + [[package]] name = "unicode-width" version = "0.2.2" diff --git a/Cargo.toml b/Cargo.toml index 81facb8..ed3f5aa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ futures = "0.3.31" phono-backends = { path = "./phono-backends" } phono-models = { path = "./phono-models" } phono-namegen = { path = "./phono-namegen" } +phono-pestgros = { path = "./phono-pestgros" } rand = "0.8.5" redact = { version = "0.1.11", features = ["serde", "zeroize"] } regex = "1.11.1" diff --git a/phono-pestgros/Cargo.toml b/phono-pestgros/Cargo.toml new file mode 100644 index 0000000..3be9672 --- /dev/null +++ b/phono-pestgros/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "phono-pestgros" +edition.workspace = true +version.workspace = true + +[dependencies] +bigdecimal = { workspace = true } +chrono = { workspace = true } +pest = { version = "2.8.6", features = ["miette-error"] } +pest_derive = "2.8.6" +serde = { workspace = true } +serde_json = { workspace = true } +sqlx = { workspace = true } +thiserror = { workspace = true } +uuid = { workspace = true } diff --git a/phono-pestgros/src/datum.rs b/phono-pestgros/src/datum.rs new file mode 100644 index 0000000..a86d7c7 --- /dev/null +++ b/phono-pestgros/src/datum.rs @@ -0,0 +1,71 @@ +use bigdecimal::BigDecimal; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::{Postgres, QueryBuilder}; +use uuid::Uuid; + +/// Enum representing all supported literal types, providing convenience +/// methods for working with them in [`sqlx`] queries, and defining a [`serde`] +/// encoding for use across the application stack. +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] +#[serde(tag = "t", content = "c")] +pub enum Datum { + // BigDecimal is used because a user may insert a value directly via SQL + // which overflows the representational space of `rust_decimal::Decimal`. + // Note that by default, [`BigDecimal`] serializes to JSON as a string. This + // behavior can be modified, but it's a pain when paired with the [`Option`] + // type. String representation should be acceptable for the UI, as [`Datum`] + // values should always be parsed through Zod, which can coerce the value to + // a number transparently. + Numeric(Option), + Text(Option), + Timestamp(Option>), + Uuid(Option), +} + +// TODO: Should sqlx helpers be moved to a separate crate? +impl Datum { + // TODO: Can something similar be achieved with a generic return type? + /// Bind this as a parameter to a sqlx query. + pub fn bind_onto<'a>( + self, + query: sqlx::query::Query<'a, Postgres, ::Arguments<'a>>, + ) -> sqlx::query::Query<'a, Postgres, ::Arguments<'a>> { + match self { + Self::Numeric(value) => query.bind(value), + Self::Text(value) => query.bind(value), + Self::Timestamp(value) => query.bind(value), + Self::Uuid(value) => query.bind(value), + } + } + + /// Push this as a parameter to a [`QueryBuilder`]. + pub fn push_bind_onto(self, builder: &mut QueryBuilder<'_, Postgres>) { + match self { + Self::Numeric(value) => builder.push_bind(value), + Self::Text(value) => builder.push_bind(value), + Self::Timestamp(value) => builder.push_bind(value), + Self::Uuid(value) => builder.push_bind(value), + }; + } + + /// Transform the contained value into a serde_json::Value. + pub fn inner_as_value(&self) -> serde_json::Value { + let serialized = serde_json::to_value(self).unwrap(); + #[derive(Deserialize)] + struct Tagged { + c: serde_json::Value, + } + let deserialized: Tagged = serde_json::from_value(serialized).unwrap(); + deserialized.c + } + + pub fn is_none(&self) -> bool { + match self { + Self::Numeric(None) | Self::Text(None) | Self::Timestamp(None) | Self::Uuid(None) => { + true + } + Self::Numeric(_) | Self::Text(_) | Self::Timestamp(_) | Self::Uuid(_) => false, + } + } +} diff --git a/phono-pestgros/src/fragment_tests.rs b/phono-pestgros/src/fragment_tests.rs new file mode 100644 index 0000000..5f50c03 --- /dev/null +++ b/phono-pestgros/src/fragment_tests.rs @@ -0,0 +1,15 @@ +use std::error::Error; + +use sqlx::{Postgres, QueryBuilder}; + +use crate::Expr; + +#[test] +fn sql_converts_to_query_builder() -> Result<(), Box> { + let expr = Expr::try_from("3 + 5 < 10")?; + assert_eq!( + QueryBuilder::<'_, Postgres>::from(expr).sql(), + "(($1) + ($2)) < ($3)", + ); + Ok(()) +} diff --git a/phono-pestgros/src/func_invocation_tests.rs b/phono-pestgros/src/func_invocation_tests.rs new file mode 100644 index 0000000..4372ecf --- /dev/null +++ b/phono-pestgros/src/func_invocation_tests.rs @@ -0,0 +1,76 @@ +use std::error::Error; + +use crate::{ArithOp, Datum, Expr, FnArgs, InfixOp}; + +#[test] +fn parses_without_args() -> Result<(), Box> { + assert_eq!( + Expr::try_from("now()")?, + Expr::FnCall { + name: vec!["now".to_owned()], + args: FnArgs::Exprs { + distinct_flag: false, + exprs: vec![], + }, + } + ); + Ok(()) +} + +#[test] +fn parses_with_args() -> Result<(), Box> { + assert_eq!( + Expr::try_from("repeat('hello!', 1 + 2)")?, + Expr::FnCall { + name: vec!["repeat".to_owned()], + args: FnArgs::Exprs { + distinct_flag: false, + exprs: vec![ + Expr::Literal(Datum::Text(Some("hello!".to_owned()))), + Expr::Infix { + lhs: Box::new(Expr::Literal(Datum::Numeric(Some(1.into())))), + op: InfixOp::ArithInfix(ArithOp::Add), + rhs: Box::new(Expr::Literal(Datum::Numeric(Some(2.into())))), + } + ], + }, + } + ); + Ok(()) +} + +#[test] +fn schema_qualified() -> Result<(), Box> { + assert_eq!( + Expr::try_from(r#"my_schema."MyFunc"('hello!', 1)"#)?, + Expr::FnCall { + name: vec!["my_schema".to_owned(), "MyFunc".to_owned()], + args: FnArgs::Exprs { + distinct_flag: false, + exprs: vec![ + Expr::Literal(Datum::Text(Some("hello!".to_owned()))), + Expr::Literal(Datum::Numeric(Some(1.into()))), + ], + }, + } + ); + Ok(()) +} + +#[test] +fn distinct_aggregate() -> Result<(), Box> { + assert_eq!( + Expr::try_from(r#"AGGREGATOR(DISTINCT a."Col 1", b."Col 2")"#)?, + Expr::FnCall { + name: vec!["aggregator".to_owned()], + args: FnArgs::Exprs { + distinct_flag: true, + exprs: vec![ + Expr::ObjName(vec!["a".to_owned(), "Col 1".to_owned()]), + Expr::ObjName(vec!["b".to_owned(), "Col 2".to_owned()]), + ], + }, + } + ); + Ok(()) +} diff --git a/phono-pestgros/src/grammar.pest b/phono-pestgros/src/grammar.pest new file mode 100644 index 0000000..5b2cbaf --- /dev/null +++ b/phono-pestgros/src/grammar.pest @@ -0,0 +1,212 @@ +//! Based on +//! https://github.com/pest-parser/pest/blob/master/grammars/src/grammars/sql.pest. +//! (Original is dual-licensed under MIT/Apache-2.0.) +//! +//! Postgres largely conforms to the SQLite flavored dialect captured by the +//! original grammar, but its rules for identifiers differ: +//! +//! > SQL identifiers and key words must begin with a letter (a-z, but also +//! > letters with diacritical marks and non-Latin letters) or an underscore +//! > (_). Subsequent characters in an identifier or key word can be letters, +//! > underscores, digits (0-9), or dollar signs ($). Note that dollar signs are +//! > not allowed in identifiers according to the letter of the SQL standard, +//! > so their use might render applications less portable. +//! -- https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS + +Command = _{ SOI ~ (Query | ExplainQuery | DDL | ACL) ~ EOF } + +ACL = _{ DropRole | DropUser | CreateRole | CreateUser | AlterUser | GrantPrivilege | RevokePrivilege } + CreateUser = { + ^"create" ~ ^"user" ~ Identifier ~ (^"with")? ~ ^"password" ~ SingleQuotedString ~ + AuthMethod? + } + AlterUser = { + ^"alter" ~ ^"user" ~ Identifier ~ (^"with")? ~ AlterOption + } + AlterOption = _{ AlterLogin | AlterNoLogin | AlterPassword } + AlterLogin = { ^"login" } + AlterNoLogin = { ^"nologin" } + AlterPassword = { ^"password" ~ SingleQuotedString ~ AuthMethod? } + AuthMethod = { ^"using" ~ (ChapSha1 | Md5 | Ldap) } + ChapSha1 = { ^"chap-sha1" } + Md5 = { ^"md5" } + Ldap = { ^"ldap" } + DropUser = { ^"drop" ~ ^"user" ~ Identifier } + CreateRole = { ^"create" ~ ^"role" ~ Identifier } + DropRole = { ^"drop" ~ ^"role" ~ Identifier } + GrantPrivilege = { ^"grant" ~ PrivBlock ~ ^"to" ~ Identifier } + RevokePrivilege = { ^"revoke" ~ PrivBlock ~ ^"from" ~ Identifier } + PrivBlock = _{ PrivBlockPrivilege | PrivBlockRolePass } + PrivBlockPrivilege = {Privilege ~ (PrivBlockUser | PrivBlockSpecificUser | PrivBlockRole + | PrivBlockSpecificRole | PrivBlockTable | PrivBlockSpecificTable)} + PrivBlockUser = { ^"user" } + PrivBlockSpecificUser = { ^"on" ~ ^"user" ~ Identifier } + PrivBlockRole = { ^"role" } + PrivBlockSpecificRole = { ^"on" ~ ^"role" ~ Identifier } + PrivBlockTable = { ^"table" } + PrivBlockSpecificTable = { ^"on" ~ ^"table" ~ Identifier } + PrivBlockRolePass = { Identifier } + Privilege = _{ PrivilegeRead | PrivilegeWrite | PrivilegeExecute | + PrivilegeCreate | PrivilegeAlter | PrivilegeDrop | + PrivilegeSession | PrivilegeUsage } + PrivilegeAlter = { ^"alter" } + PrivilegeCreate = { ^"create" } + PrivilegeDrop = { ^"drop" } + PrivilegeExecute = { ^"execute" } + PrivilegeRead = { ^"read" } + PrivilegeSession = { ^"session" } + PrivilegeUsage = { ^"usage" } + PrivilegeWrite = { ^"write" } + +DDL = _{ CreateTable | DropTable | CreateProc } + CreateTable = { + ^"create" ~ ^"table" ~ Identifier ~ + "(" ~ Columns ~ "," ~ PrimaryKey ~ ")" ~ + Distribution + } + Columns = { ColumnDef ~ ("," ~ ColumnDef)* } + ColumnDef = { Identifier ~ ColumnDefType ~ ColumnDefIsNull? } + ColumnDefIsNull = { NotFlag? ~ ^"null" } + PrimaryKey = { + ^"primary" ~ ^"key" ~ + "(" ~ Identifier ~ ("," ~ Identifier)* ~ ")" + } + Distribution = { ^"distributed" ~ (Global | Sharding) } + Global = { ^"globally" } + Sharding = { ^"by" ~ "(" ~ Identifier ~ ("," ~ Identifier)* ~ ")"} + DropTable = { ^"drop" ~ ^"table" ~ Identifier } + + CreateProc = { + ^"create" ~ ^"procedure" ~ Identifier ~ + "(" ~ ProcParams? ~ ")" ~ (^"language" ~ ProcLanguage)? ~ + ((^"as" ~ "$$" ~ ProcBody ~ "$$") | (^"begin" ~ "atomic" ~ ProcBody ~ "end")) + } + ProcParams = { ProcParamDef ~ ("," ~ ProcParamDef)* } + ProcParamDef = { ColumnDefType } + ProcLanguage = { SQL } + SQL = { ^"sql" } + ProcBody = { (Insert | Update | Delete) } + +ExplainQuery = _{ Explain } + Explain = { ^"explain" ~ Query } + +Query = { (SelectWithOptionalContinuation | Values | Insert | Update | Delete) } + SelectWithOptionalContinuation = { Select ~ (ExceptContinuation | UnionAllContinuation)? } + ExceptContinuation = { ((^"except" ~ ^"distinct") | ^"except") ~ Select } + UnionAllContinuation = { ^"union" ~ ^"all" ~ Select } + Select = { + ^"select" ~ Projection ~ ^"from" ~ Scan ~ + Join? ~ WhereClause? ~ + (^"group" ~ ^"by" ~ GroupBy)? ~ + (^"having" ~ Having)? + } + Projection = { Distinct? ~ ProjectionElement ~ ("," ~ ProjectionElement)* } + ProjectionElement = _{ Asterisk | Column } + Column = { Expr ~ ((^"as")? ~ Identifier)? } + Asterisk = { "*" } + WhereClause = _{ ^"where" ~ Selection } + Selection = { Expr } + Scan = { (Identifier | SubQuery) ~ ((^"as")? ~ Identifier)? } + Join = { JoinKind? ~ ^"join" ~ Scan ~ ^"on" ~ Expr } + JoinKind = _{ ( InnerJoinKind | LeftJoinKind ) } + InnerJoinKind = { ^"inner" } + LeftJoinKind = { ^"left" ~ (^"outer")? } + GroupBy = { Expr ~ ("," ~ Expr)* } + Having = { Expr } + SubQuery = { "(" ~ (SelectWithOptionalContinuation | Values) ~ ")" } + Insert = { ^"insert" ~ ^"into" ~ Identifier ~ ("(" ~ TargetColumns ~ ")")? ~ (Values | Select) ~ OnConflict? } + TargetColumns = { Identifier ~ ("," ~ Identifier)* } + OnConflict = _{ ^"on conflict" ~ ^"do" ~ (DoNothing | DoReplace | DoFail) } + DoReplace = { ^"replace" } + DoNothing = { ^"nothing" } + DoFail = { ^"fail" } + Update = { ^"update" ~ Identifier ~ ^"set" ~ UpdateList ~ (UpdateFrom | WhereClause)? } + UpdateList = { UpdateItem ~ ("," ~ UpdateItem)* } + UpdateItem = { Identifier ~ "=" ~ Expr } + UpdateFrom = _{ ^"from" ~ Scan ~ (^"where" ~ Expr)? } + Values = { ^"values" ~ Row ~ ("," ~ Row)* } + Delete = { ^"delete" ~ ^"from" ~ Identifier ~ (^"where" ~ DeleteFilter)? } + DeleteFilter = { Expr } + +Identifier = ${ DoubleQuotedIdentifier | UnquotedIdentifier } + DoubleQuotedIdentifier = @{ "\"" ~ ("\"\"" | '\u{01}'..'\u{21}' | '\u{23}'..'\u{10FFFF}')+ ~ "\"" } + UnquotedIdentifier = @{ !(Keyword ~ ("(" | WHITESPACE | "," | EOF)) ~ (UnquotedIdentifierStart ~ UnquotedIdentifierRemainder*) } + UnquotedIdentifierStart = _{ 'a'..'я' | 'A'..'Я' | "_" } + UnquotedIdentifierRemainder = _{ UnquotedIdentifierStart | "$" | ASCII_DIGIT } + Keyword = { ^"left" | ^"having" | ^"not" | ^"inner" | ^"group" + | ^"on" | ^"join" | ^"from" | ^"exists" | ^"except" + | ^"union" | ^"where" | ^"distinct" | ^"between" | ^"option" + | ^"values"} + +ExprRoot = _{ &SOI ~ Expr ~ &EOI } +Expr = { ExprAtomValue ~ (ExprInfixOp ~ ExprAtomValue)* } + ExprInfixOp = _{ Between | ArithInfixOp | CmpInfixOp | ConcatInfixOp | And | Or } + Between = { NotFlag? ~ ^"between" } + And = { ^"and" } + Or = { ^"or" } + ConcatInfixOp = { "||" } + ArithInfixOp = _{ Add | Subtract | Multiply | Divide } + Add = { "+" } + Subtract = { "-" } + Multiply = { "*" } + Divide = { "/" } + CmpInfixOp = _{ NotEq | GtEq | Gt | LtEq | Lt | Eq | Lt | In } + Eq = { "=" } + Gt = { ">" } + GtEq = { ">=" } + Lt = { "<" } + LtEq = { "<=" } + NotEq = { "<>" | "!=" } + In = { NotFlag? ~ ^"in" } + ExprAtomValue = _{ UnaryNot* ~ AtomicExpr ~ IsNullPostfix? } + UnaryNot = @{ NotFlag } + IsNullPostfix = { ^"is" ~ NotFlag? ~ ^"null" } + AtomicExpr = _{ Literal | Parameter | Cast | IdentifierWithOptionalContinuation | ExpressionInParentheses | UnaryOperator | SubQuery | Row } + Literal = _{ True | False | Null | Double | Decimal | Unsigned | Integer | SingleQuotedString } + True = { ^"true" } + False = { ^"false" } + Null = { ^"null" } + Decimal = @{ Integer ~ ("." ~ ASCII_DIGIT*) } + Double = @{ Integer ~ ("." ~ ASCII_DIGIT*)? ~ (^"e" ~ Integer) } + Integer = @{ ("+" | "-")? ~ ASCII_DIGIT+ } + Unsigned = @{ ASCII_DIGIT+ } + // TODO: Handle dollar-quoted string literals. + SingleQuotedString = @{ "'" ~ ("''" | (!("'") ~ ANY))* ~ "'" } + Parameter = @{ "$" ~ Unsigned } + // Postgres permits qualified object names with a single identifier + // part, 2 parts plus a function invocation, 3 parts, or 3 parts + // plus a function invocation. For simplicity, assume that an + // arbitrary number of qualifications (e.g. "a.b.c.d[...]") are + // supported. + // TODO: Disallow whitespace where it shouldn't be. + IdentifierWithOptionalContinuation = { Identifier ~ QualifiedIdentifierContinuation* ~ FunctionInvocationContinuation? } + QualifiedIdentifierContinuation = ${ "." ~ Identifier } + FunctionInvocationContinuation = { "(" ~ (CountAsterisk | FunctionArgs)? ~ ")" } + // TODO: Support named argument notation + // (`my_func(name => value)`). + FunctionArgs = { Distinct? ~ (Expr ~ ("," ~ Expr)*)? } + CountAsterisk = { "*" } + ExpressionInParentheses = { "(" ~ Expr ~ ")" } + Cast = { ^"cast" ~ "(" ~ Expr ~ ^"as" ~ TypeCast ~ ")" } + TypeCast = _{ TypeAny | ColumnDefType } + ColumnDefType = { TypeBool | TypeDecimal | TypeDouble | TypeInt | TypeNumber + | TypeScalar | TypeString | TypeText | TypeUnsigned | TypeVarchar } + TypeAny = { ^"any" } + TypeBool = { (^"boolean" | ^"bool") } + TypeDecimal = { ^"decimal" } + TypeDouble = { ^"double" } + TypeInt = { (^"integer" | ^"int") } + TypeNumber = { ^"number" } + TypeScalar = { ^"scalar" } + TypeString = { ^"string" } + TypeText = { ^"text" } + TypeUnsigned = { ^"unsigned" } + TypeVarchar = { ^"varchar" ~ "(" ~ Unsigned ~ ")" } + UnaryOperator = _{ Exists } + Exists = { NotFlag? ~ ^"exists" ~ SubQuery } + Row = { "(" ~ Expr ~ ("," ~ Expr)* ~ ")" } + +Distinct = { ^"distinct" } +NotFlag = { ^"not" } +EOF = { EOI | ";" } +WHITESPACE = _{ " " | "\t" | "\n" | "\r\n" } diff --git a/phono-pestgros/src/identifier_tests.rs b/phono-pestgros/src/identifier_tests.rs new file mode 100644 index 0000000..08f7d54 --- /dev/null +++ b/phono-pestgros/src/identifier_tests.rs @@ -0,0 +1,37 @@ +//! Unit tests for identifier and object name parsing within expressions. + +use crate::{Expr, escape_identifier}; + +#[test] +fn escaper_escapes() { + assert_eq!(escape_identifier("hello"), r#""hello""#); + assert_eq!(escape_identifier("hello world"), r#""hello world""#); + assert_eq!( + escape_identifier(r#""hello" "world""#), + r#""""hello"" ""world""""# + ); +} + +#[test] +fn qualified_obj_name_parses() { + assert_eq!( + Expr::try_from(r#""""Hello"", World! 四十二".deep_thought"#), + Ok(Expr::ObjName(vec![ + r#""Hello", World! 四十二"#.to_owned(), + "deep_thought".to_owned(), + ])), + ); +} + +#[test] +fn misquoted_ident_fails_to_parse() { + assert!(Expr::try_from(r#""Hello, "World!""#).is_err()); +} + +#[test] +fn unquoted_ident_lowercased() { + assert_eq!( + Expr::try_from("HeLlO_WoRlD"), + Ok(Expr::ObjName(vec!["hello_world".to_owned()])), + ); +} diff --git a/phono-pestgros/src/lib.rs b/phono-pestgros/src/lib.rs new file mode 100644 index 0000000..5462dab --- /dev/null +++ b/phono-pestgros/src/lib.rs @@ -0,0 +1,370 @@ +//! Incomplete but useful parser and generator for Postgres flavored SQL +//! expressions and more, based on a modified version of the +//! [official Pest SQL grammar](https://github.com/pest-parser/pest/blob/79dd30d11aab6f0fba3cd79bd48f456209b966b3/grammars/src/grammars/sql.pest). +//! +//! ## Example +//! +//! ``` +//! use phono_pestgros::{ArithOp, BoolOp, Datum, Expr, InfixOp}; +//! +//! # fn main() -> Result<(), Box> { +//! let expr = Expr::try_from("3 + 5 < 10")?; +//! +//! assert_eq!(expr, Expr::Infix { +//! lhs: Box::new(Expr::Infix { +//! lhs: Box::new(Expr::Literal(Datum::Numeric(Some(3.into())))), +//! op: InfixOp::ArithInfix(ArithOp::Add), +//! rhs: Box::new(Expr::Literal(Datum::Numeric(Some(5.into())))), +//! }), +//! op: InfixOp::BoolInfix(BoolOp::Lt), +//! rhs: Box::new(Expr::Literal(Datum::Numeric(Some(10.into())))), +//! }); +//! +//! assert_eq!(QueryBuilder::try_from(expr).sql(), "(($1) + ($2)) < ($3)"); +//! # Ok(()) +//! # } +//! ``` + +use std::{str::FromStr, sync::LazyLock}; + +use bigdecimal::BigDecimal; +use pest::{ + Parser as _, + iterators::{Pair, Pairs}, + pratt_parser::PrattParser, +}; +use pest_derive::Parser; + +pub use crate::datum::Datum; + +mod datum; +mod query_builders; + +#[cfg(test)] +mod fragment_tests; +#[cfg(test)] +mod func_invocation_tests; +#[cfg(test)] +mod identifier_tests; +#[cfg(test)] +mod literal_tests; +#[cfg(test)] +mod op_tests; + +/// Given a raw identifier (such as a table name, column name, etc.), format it +/// so that it may be safely interpolated into a SQL query. +/// +/// Note that in PostgreSQL, unquoted identifiers are case-insensitive (or, +/// rather, they are always implicitly converted to lowercase), while quoted +/// identifiers are case-sensitive. The caller of this function is responsible +/// for performing conversion to lowercase as appropriate. +pub fn escape_identifier(identifier: &str) -> String { + // Escaping identifiers for Postgres is fairly easy, provided that the input is + // already known to contain no invalid multi-byte sequences. Backslashes may + // remain as-is, and embedded double quotes are escaped simply by doubling + // them (`"` becomes `""`). Refer to the PQescapeInternal() function in + // libpq (fe-exec.c) and Diesel's PgQueryBuilder::push_identifier(). + format!("\"{0}\"", identifier.replace('"', "\"\"")) +} + +/// Decodes a SQL representation of an identifier. If the input is unquoted, it +/// is converted to lowercase. If it is double quoted, the surrounding quotes +/// are stripped and escaped inner double quotes (double-double quotes, if you +/// will) are converted to single-double quotes. The opposite of +/// [`escape_identifier`], sort of. +/// +/// Assumes that the provided identifier is well-formed. Basic gut checks are +/// performed, but they are non-exhaustive. +/// +/// `U&"..."`-style escaped Unicode identifiers are not yet supported. +fn parse_ident(value: &str) -> String { + assert!( + !value.to_lowercase().starts_with("u&"), + "escaped Unicode identifiers are not supported" + ); + if value.starts_with('"') { + assert!(value.ends_with('"'), "malformed double-quoted identifier"); + { + // Strip first and last characters. + let mut chars = value.chars(); + chars.next(); + chars.next_back(); + chars.as_str() + } + .replace(r#""""#, r#"""#) + } else { + // TODO: assert validity with regex + value.to_lowercase() + } +} + +/// Decodes a single-quoted string literal. Removes surrounding quotes and +/// replaces embedded single quotes (double-single quotes) with single-single +/// quotes. +/// +/// Assumes that the provided identifier is well-formed. Basic gut checks are +/// performed, but they are non-exhaustive. +/// +/// `E'...'`-style, dollar-quoted, and other (relatively) uncommon formats for +/// text literals are not yet supported. +fn parse_text_literal(value: &str) -> String { + assert!(value.starts_with('\'') && value.ends_with('\'')); + { + // Strip first and last characters. + let mut chars = value.chars(); + chars.next(); + chars.next_back(); + chars.as_str() + } + .replace("''", "'") +} + +/// Primary parser and code generation for [`Rule`] types. +#[derive(Parser)] +#[grammar = "src/grammar.pest"] +struct PsqlParser; + +/// Secondary parser configuration for handling operator precedence. +static PRATT_PARSER: LazyLock> = LazyLock::new(|| { + use pest::pratt_parser::{ + Assoc::{Left, Right}, + Op, + }; + + PrattParser::new() + .op(Op::infix(Rule::Or, Left)) + .op(Op::infix(Rule::Between, Left)) + .op(Op::infix(Rule::And, Left)) + .op(Op::prefix(Rule::UnaryNot)) + .op(Op::infix(Rule::Eq, Right) + | Op::infix(Rule::NotEq, Right) + | Op::infix(Rule::Gt, Right) + | Op::infix(Rule::GtEq, Right) + | Op::infix(Rule::Lt, Right) + | Op::infix(Rule::LtEq, Right) + | Op::infix(Rule::In, Right)) + // Official Pest example overstates the concat operator's precedence. It + // should be lower precedence than add/subtract. + .op(Op::infix(Rule::ConcatInfixOp, Left)) + .op(Op::infix(Rule::Add, Left) | Op::infix(Rule::Subtract, Left)) + .op(Op::infix(Rule::Multiply, Left) | Op::infix(Rule::Divide, Left)) + .op(Op::postfix(Rule::IsNullPostfix)) +}); + +/// Represents a SQL expression. An expression is a collection of values and +/// operators that theoretically evaluates to some value, such as a boolean +/// condition, an object name, or a string dynamically derived from other +/// values. An expression is *not* a complete SQL statement, command, or query. +#[derive(Clone, Debug, PartialEq)] +pub enum Expr { + Infix { + lhs: Box, + op: InfixOp, + rhs: Box, + }, + Literal(Datum), + ObjName(Vec), + FnCall { + name: Vec, + args: FnArgs, + }, + Not(Box), + Nullness { + is_null: bool, + expr: Box, + }, +} + +impl TryFrom<&str> for Expr { + type Error = ParseError; + + fn try_from(value: &str) -> Result { + // `ExprRoot` is a silent rule which simply dictates that the inner + // `Expr` rule must consume the entire input. + let pairs = PsqlParser::parse(Rule::ExprRoot, value)?; + parse_expr_pairs(pairs) + } +} + +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum InfixOp { + ArithInfix(ArithOp), + BoolInfix(BoolOp), +} + +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum ArithOp { + Add, + Concat, + Div, + Mult, + Sub, +} + +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum BoolOp { + And, + Or, + Eq, + Gt, + Gte, + Lt, + Lte, + Neq, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum FnArgs { + CountAsterisk, + Exprs { + /// `true` for aggregator invocations with the `DISTINCT` keyword + /// specified. + distinct_flag: bool, + exprs: Vec, + }, +} + +/// Recursive helper, which does most of the work to convert [`pest`]'s pattern +/// matching output to a usable syntax tree. +fn parse_expr_pairs(expr_pairs: Pairs<'_, Rule>) -> Result { + PRATT_PARSER + .map_primary(|pair| match pair.as_rule() { + Rule::Expr | Rule::ExpressionInParentheses => parse_expr_pairs(pair.into_inner()), + Rule::Decimal | Rule::Double | Rule::Integer | Rule::Unsigned => Ok(Expr::Literal( + Datum::Numeric(Some(BigDecimal::from_str(pair.as_str()).expect( + "parsed numeric values should always be convertible to BigDecimal", + ))), + )), + Rule::SingleQuotedString => Ok(Expr::Literal(Datum::Text(Some(parse_text_literal(pair.as_str()))))), + Rule::IdentifierWithOptionalContinuation => { + let mut name: Vec = vec![]; + let mut fn_args: Option = None; + let inner = pair.into_inner(); + for inner_pair in inner { + match inner_pair.as_rule() { + Rule::Identifier => { + name.push(parse_ident(inner_pair.as_str())); + } + Rule::QualifiedIdentifierContinuation => { + let ident_cont = inner_pair.as_str(); + assert!( + ident_cont.starts_with('.'), + "QualifiedIdentifierContinuation should always start with the infix dot", + ); + name.push(parse_ident({ + // Strip leading dot. + let mut chars = ident_cont.chars(); + chars.next(); + chars.as_str() + })); + } + Rule::FunctionInvocationContinuation => { + fn_args = Some(parse_function_invocation_continuation(inner_pair)?); + } + _ => unreachable!( + "IdentifierWithOptionalContinuation has only 3 valid child rules", + ), + } + } + Ok(if let Some(fn_args) = fn_args { + Expr::FnCall { name, args: fn_args } + } else { + Expr::ObjName(name) + }) + } + rule => Err(ParseError::UnknownRule(rule)), + }) + .map_infix(|lhs, op, rhs| Ok(Expr::Infix { + lhs: Box::new(lhs?), + op: match op.as_rule() { + Rule::Add => InfixOp::ArithInfix(ArithOp::Add), + Rule::ConcatInfixOp => InfixOp::ArithInfix(ArithOp::Concat), + Rule::Divide => InfixOp::ArithInfix(ArithOp::Div), + Rule::Multiply => InfixOp::ArithInfix(ArithOp::Mult), + Rule::Subtract => InfixOp::ArithInfix(ArithOp::Sub), + Rule::And => InfixOp::BoolInfix(BoolOp::And), + Rule::Eq => InfixOp::BoolInfix(BoolOp::Eq), + Rule::Gt => InfixOp::BoolInfix(BoolOp::Gt), + Rule::GtEq => InfixOp::BoolInfix(BoolOp::Gte), + Rule::Lt => InfixOp::BoolInfix(BoolOp::Lt), + Rule::LtEq => InfixOp::BoolInfix(BoolOp::Lte), + Rule::NotEq => InfixOp::BoolInfix(BoolOp::Neq), + Rule::Or => InfixOp::BoolInfix(BoolOp::Or), + rule => Err(ParseError::UnknownRule(rule))?, + }, + rhs: Box::new(rhs?), + })) + .map_prefix(|op, child| Ok(match op.as_rule() { + Rule::UnaryNot => Expr::Not(Box::new(child?)), + rule => Err(ParseError::UnknownRule(rule))?, + })) + .map_postfix(|child, op| Ok(match op.as_rule() { + Rule::IsNullPostfix => Expr::Nullness { + is_null: op + .into_inner() + .next() + .map(|inner| inner.as_rule()) != Some(Rule::NotFlag), + expr: Box::new(child?), + }, + rule => Err(ParseError::UnknownRule(rule))?, + })) + .parse(expr_pairs) +} + +fn parse_function_invocation_continuation(pair: Pair<'_, Rule>) -> Result { + let mut cont_inner_iter = pair.into_inner(); + let fn_args = if let Some(cont_inner) = cont_inner_iter.next() { + match cont_inner.as_rule() { + Rule::FunctionArgs => { + let mut distinct_flag = false; + let mut exprs: Vec = vec![]; + for arg_inner in cont_inner.into_inner() { + match arg_inner.as_rule() { + Rule::Distinct => { + distinct_flag = true; + } + Rule::Expr => { + exprs.push(parse_expr_pairs(arg_inner.into_inner())?); + } + _ => unreachable!( + "only valid children of FunctionArgs are Distinct and Expr" + ), + } + } + FnArgs::Exprs { + distinct_flag, + exprs, + } + } + Rule::CountAsterisk => FnArgs::CountAsterisk, + _ => unreachable!( + "only valid children of FunctionInvocationContinuation are FunctionArgs and CountAsterisk" + ), + } + } else { + FnArgs::Exprs { + distinct_flag: false, + exprs: vec![], + } + }; + assert!( + cont_inner_iter.next().is_none(), + "function should have consumed entire FunctionInvocationContinuation pair", + ); + Ok(fn_args) +} + +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +#[error("parse error")] +pub enum ParseError { + #[error("unknown rule")] + UnknownRule(Rule), + #[error("pest failed to parse: {0}")] + Pest(pest::error::Error), +} + +impl From> for ParseError { + fn from(value: pest::error::Error) -> Self { + Self::Pest(value) + } +} diff --git a/phono-pestgros/src/literal_tests.rs b/phono-pestgros/src/literal_tests.rs new file mode 100644 index 0000000..8bca20c --- /dev/null +++ b/phono-pestgros/src/literal_tests.rs @@ -0,0 +1,30 @@ +use std::error::Error; + +use crate::{Datum, Expr}; + +#[test] +fn text_parses() -> Result<(), Box> { + assert_eq!( + Expr::try_from("'Hello, World!'")?, + Expr::Literal(Datum::Text(Some("Hello, World!".to_owned()))) + ); + Ok(()) +} + +#[test] +fn escaped_quotes_parse() -> Result<(), Box> { + assert_eq!( + Expr::try_from("'''Hello, World!'''")?, + Expr::Literal(Datum::Text(Some("'Hello, World!'".to_owned()))) + ); + Ok(()) +} + +#[test] +fn numeric_parses() -> Result<(), Box> { + assert_eq!( + Expr::try_from("1234.56")?, + Expr::Literal(Datum::Numeric(Some("1234.56".parse()?))) + ); + Ok(()) +} diff --git a/phono-pestgros/src/op_tests.rs b/phono-pestgros/src/op_tests.rs new file mode 100644 index 0000000..fbeab91 --- /dev/null +++ b/phono-pestgros/src/op_tests.rs @@ -0,0 +1,106 @@ +//! Unit tests for infix operator parsing within expressions. + +use crate::{ArithOp, Datum, Expr, InfixOp}; + +#[test] +fn add_op_parses() { + assert_eq!( + // https://xkcd.com/3184/ + Expr::try_from("six + 7"), + Ok(Expr::Infix { + lhs: Box::new(Expr::ObjName(vec!["six".to_owned()])), + op: InfixOp::ArithInfix(ArithOp::Add), + rhs: Box::new(Expr::Literal(Datum::Numeric(Some(7.into())))), + }) + ); +} + +#[test] +fn mult_op_parses() { + assert_eq!( + Expr::try_from("six * 7"), + Ok(Expr::Infix { + lhs: Box::new(Expr::ObjName(vec!["six".to_owned()])), + op: InfixOp::ArithInfix(ArithOp::Mult), + rhs: Box::new(Expr::Literal(Datum::Numeric(Some(7.into())))), + }) + ); +} + +#[test] +fn arith_precedence() { + assert_eq!( + Expr::try_from("(1 + 2) * 3 + 4"), + Ok(Expr::Infix { + lhs: Box::new(Expr::Infix { + lhs: Box::new(Expr::Infix { + lhs: Box::new(Expr::Literal(Datum::Numeric(Some(1.into())))), + op: InfixOp::ArithInfix(ArithOp::Add), + rhs: Box::new(Expr::Literal(Datum::Numeric(Some(2.into())))), + }), + op: InfixOp::ArithInfix(ArithOp::Mult), + rhs: Box::new(Expr::Literal(Datum::Numeric(Some(3.into())))), + }), + op: InfixOp::ArithInfix(ArithOp::Add), + rhs: Box::new(Expr::Literal(Datum::Numeric(Some(4.into())))), + }) + ); + assert_eq!( + Expr::try_from("1 - 2 / (3 - 4)"), + Ok(Expr::Infix { + lhs: Box::new(Expr::Literal(Datum::Numeric(Some(1.into())))), + op: InfixOp::ArithInfix(ArithOp::Sub), + rhs: Box::new(Expr::Infix { + lhs: Box::new(Expr::Literal(Datum::Numeric(Some(2.into())))), + op: InfixOp::ArithInfix(ArithOp::Div), + rhs: Box::new(Expr::Infix { + lhs: Box::new(Expr::Literal(Datum::Numeric(Some(3.into())))), + op: InfixOp::ArithInfix(ArithOp::Sub), + rhs: Box::new(Expr::Literal(Datum::Numeric(Some(4.into())))), + }), + }) + }) + ); +} + +#[test] +fn is_null_parses() { + assert_eq!( + Expr::try_from("my_var is null"), + Ok(Expr::Nullness { + is_null: true, + expr: Box::new(Expr::ObjName(vec!["my_var".to_owned()])) + }), + ); +} + +#[test] +fn is_not_null_parses() { + assert_eq!( + Expr::try_from("my_var is not null"), + Ok(Expr::Nullness { + is_null: false, + expr: Box::new(Expr::ObjName(vec!["my_var".to_owned()])) + }), + ); +} + +#[test] +fn not_parses() { + assert_eq!( + Expr::try_from("not my_var"), + Ok(Expr::Not(Box::new(Expr::ObjName(vec![ + "my_var".to_owned() + ])))), + ); +} + +#[test] +fn repeated_nots_parse() { + assert_eq!( + Expr::try_from("not not my_var"), + Ok(Expr::Not(Box::new(Expr::Not(Box::new(Expr::ObjName( + vec!["my_var".to_owned()] + )))))), + ); +} diff --git a/phono-pestgros/src/query_builders.rs b/phono-pestgros/src/query_builders.rs new file mode 100644 index 0000000..a8669e6 --- /dev/null +++ b/phono-pestgros/src/query_builders.rs @@ -0,0 +1,273 @@ +//! Assorted utilities for dynamically constructing and manipulating [`sqlx`] +//! queries. + +use sqlx::{Postgres, QueryBuilder}; + +use crate::{ArithOp, BoolOp, Datum, Expr, FnArgs, InfixOp, escape_identifier}; + +/// Representation of a partial, parameterized SQL query. Allows callers to +/// build queries iteratively and dynamically, handling parameter numbering +/// (`$1`, `$2`, `$3`, ...) automatically. +/// +/// This is similar to [`sqlx::QueryBuilder`], except that [`QueryFragment`] +/// objects are composable and may be concatenated to each other. +#[derive(Clone, Debug, PartialEq)] +pub struct QueryFragment { + /// SQL string, split wherever there is a query parameter. For example, + /// `select * from foo where id = $1 and status = $2` is represented along + /// the lines of `["select * from foo where id = ", " and status = ", ""]`. + /// `plain_sql` should always have exactly one more element than `params`. + plain_sql: Vec, + params: Vec, +} + +impl QueryFragment { + /// Validate invariants. Should be run immediately before returning any + /// useful output. + fn gut_checks(&self) { + assert!(self.plain_sql.len() == self.params.len() + 1); + } + + /// Parse from a SQL string with no parameters. + pub fn from_sql(sql: &str) -> Self { + Self { + plain_sql: vec![sql.to_owned()], + params: vec![], + } + } + + /// Convenience function to construct an empty value. + pub fn empty() -> Self { + Self::from_sql("") + } + + /// Parse from a parameter value with no additional SQL. (Renders as `$n`, + /// where`n` is the appropriate parameter index.) + pub fn from_param(param: Datum) -> Self { + Self { + plain_sql: vec!["".to_owned(), "".to_owned()], + params: vec![param], + } + } + + /// Append another query fragment to this one. + pub fn push(&mut self, mut other: QueryFragment) { + let tail = self + .plain_sql + .pop() + .expect("already asserted that vec contains at least 1 item"); + let head = other + .plain_sql + .first() + .expect("already asserted that vec contains at least 1 item"); + self.plain_sql.push(format!("{tail}{head}")); + for value in other.plain_sql.drain(1..) { + self.plain_sql.push(value); + } + self.params.append(&mut other.params); + } + + /// Combine multiple QueryFragments with a separator, similar to + /// [`Vec::join`]. + pub fn join>(fragments: I, sep: Self) -> Self { + let mut acc = QueryFragment::from_sql(""); + let mut iter = fragments.into_iter(); + let mut fragment = match iter.next() { + Some(value) => value, + None => return acc, + }; + for next_fragment in iter { + acc.push(fragment); + acc.push(sep.clone()); + fragment = next_fragment; + } + acc.push(fragment); + acc + } + + /// Convenience method equivalent to: + /// `QueryFragment::concat(fragments, QueryFragment::from_sql(""))` + pub fn concat>(fragments: I) -> Self { + Self::join(fragments, Self::from_sql("")) + } + + /// Checks whether value is empty. A value is considered empty if the + /// resulting SQL code is 0 characters long. + pub fn is_empty(&self) -> bool { + self.gut_checks(); + self.plain_sql.len() == 1 + && self + .plain_sql + .first() + .expect("already checked that len == 1") + .is_empty() + } +} + +impl From for QueryFragment { + fn from(value: Expr) -> Self { + match value { + Expr::Infix { lhs, op, rhs } => Self::concat([ + // RHS and LHS must be explicitly wrapped in parentheses to + // ensure correct precedence, because parentheses are taken + // into account **but not preserved** when parsing. + Self::from_sql("("), + (*lhs).into(), + Self::from_sql(") "), + op.into(), + Self::from_sql(" ("), + (*rhs).into(), + Self::from_sql(")"), + ]), + Expr::Literal(datum) => Self::from_param(datum), + Expr::ObjName(idents) => Self::join( + idents + .iter() + .map(|ident| Self::from_sql(&escape_identifier(ident))), + Self::from_sql("."), + ), + Expr::Not(expr) => { + Self::concat([Self::from_sql("not ("), (*expr).into(), Self::from_sql(")")]) + } + Expr::Nullness { is_null, expr } => Self::concat([ + Self::from_sql("("), + (*expr).into(), + Self::from_sql(if is_null { + ") is null" + } else { + ") is not null" + }), + ]), + Expr::FnCall { name, args } => { + let mut fragment = Self::empty(); + fragment.push(Self::join( + name.iter() + .map(|ident| Self::from_sql(&escape_identifier(ident))), + Self::from_sql("."), + )); + fragment.push(Self::from_sql("(")); + match args { + FnArgs::CountAsterisk => { + fragment.push(Self::from_sql("*")); + } + FnArgs::Exprs { + distinct_flag, + exprs, + } => { + if distinct_flag { + fragment.push(Self::from_sql("distinct ")); + } + fragment.push(Self::join( + exprs.into_iter().map(|expr| { + // Wrap arguments in parentheses to ensure they + // are appropriately distinguishable from each + // other regardless of the presence of extra + // commas. + Self::concat([ + Self::from_sql("("), + expr.into(), + Self::from_sql(")"), + ]) + }), + Self::from_sql(", "), + )); + } + } + fragment.push(Self::from_sql(")")); + fragment + } + } + } +} + +impl From for QueryFragment { + fn from(value: InfixOp) -> Self { + Self::from_sql(match value { + InfixOp::ArithInfix(ArithOp::Add) => "+", + InfixOp::ArithInfix(ArithOp::Concat) => "||", + InfixOp::ArithInfix(ArithOp::Div) => "/", + InfixOp::ArithInfix(ArithOp::Mult) => "*", + InfixOp::ArithInfix(ArithOp::Sub) => "-", + InfixOp::BoolInfix(BoolOp::And) => "and", + InfixOp::BoolInfix(BoolOp::Or) => "or", + InfixOp::BoolInfix(BoolOp::Eq) => "=", + InfixOp::BoolInfix(BoolOp::Gt) => ">", + InfixOp::BoolInfix(BoolOp::Gte) => ">=", + InfixOp::BoolInfix(BoolOp::Lt) => "<", + InfixOp::BoolInfix(BoolOp::Lte) => "<=", + InfixOp::BoolInfix(BoolOp::Neq) => "<>", + }) + } +} + +impl From for QueryBuilder<'_, Postgres> { + fn from(value: QueryFragment) -> Self { + value.gut_checks(); + let mut builder = QueryBuilder::new(""); + let mut param_iter = value.params.into_iter(); + for plain_sql in value.plain_sql { + builder.push(plain_sql); + if let Some(param) = param_iter.next() { + param.push_bind_onto(&mut builder); + } + } + builder + } +} + +impl From for QueryBuilder<'_, Postgres> { + fn from(value: Expr) -> Self { + Self::from(QueryFragment::from(value)) + } +} + +/// Helper type to make it easier to build and reason about multiple related SQL +/// queries. +#[derive(Clone, Debug)] +pub struct SelectQuery { + /// Query fragment following (not including) "select ". + pub selection: QueryFragment, + + /// Query fragment following (not including) "from ". + pub source: QueryFragment, + + /// Query fragment following (not including) "where ", or empty if not + /// applicable. + pub filters: QueryFragment, + + /// Query fragment following (not including) "order by ", or empty if not + /// applicable. + pub order: QueryFragment, + + /// Query fragment following (not including) "limit ", or empty if not + /// applicable. + pub limit: QueryFragment, +} + +impl From for QueryFragment { + fn from(value: SelectQuery) -> Self { + let mut result = QueryFragment::from_sql("select "); + result.push(value.selection); + result.push(QueryFragment::from_sql(" from ")); + result.push(value.source); + if !value.filters.is_empty() { + result.push(QueryFragment::from_sql(" where ")); + result.push(value.filters); + } + if !value.order.is_empty() { + result.push(QueryFragment::from_sql(" order by ")); + result.push(value.order); + } + if !value.limit.is_empty() { + result.push(QueryFragment::from_sql(" limit ")); + result.push(value.limit); + } + result + } +} + +impl From for QueryBuilder<'_, Postgres> { + fn from(value: SelectQuery) -> Self { + QueryFragment::from(value).into() + } +}