fix filter expression sql syntax bug

This commit is contained in:
Brent Schroeter 2026-01-13 18:10:44 +00:00
parent 4ba4e787a2
commit 234e6d6e7e
2 changed files with 42 additions and 31 deletions

View file

@ -5,6 +5,9 @@ use serde::{Deserialize, Serialize};
use crate::datum::Datum; use crate::datum::Datum;
/// Representation of a partial, parameterized SQL query. Allows callers to
/// build queries iteratively and dynamically, handling parameter numbering
/// (`$1`, `$2`, `$3`, ...) automatically.
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub struct QueryFragment { pub struct QueryFragment {
/// SQL string, split wherever there is a query parameter. For example, /// SQL string, split wherever there is a query parameter. For example,
@ -34,10 +37,12 @@ impl QueryFragment {
.join("") .join("")
} }
/// Returns only the parameterized values, in order.
pub fn to_params(&self) -> Vec<Datum> { pub fn to_params(&self) -> Vec<Datum> {
self.params.clone() self.params.clone()
} }
/// Parse from a SQL string with no parameters.
pub fn from_sql(sql: &str) -> Self { pub fn from_sql(sql: &str) -> Self {
Self { Self {
plain_sql: vec![sql.to_owned()], plain_sql: vec![sql.to_owned()],
@ -45,6 +50,8 @@ impl QueryFragment {
} }
} }
/// Parse from a parameter value with no additional SQL. (Renders as `$n`,
/// where`n` is the appropriate parameter index.)
pub fn from_param(param: Datum) -> Self { pub fn from_param(param: Datum) -> Self {
Self { Self {
plain_sql: vec!["".to_owned(), "".to_owned()], plain_sql: vec!["".to_owned(), "".to_owned()],
@ -52,6 +59,7 @@ impl QueryFragment {
} }
} }
/// Append another query fragment to this one.
pub fn push(&mut self, mut other: QueryFragment) { pub fn push(&mut self, mut other: QueryFragment) {
assert!(self.plain_sql.len() == self.params.len() + 1); assert!(self.plain_sql.len() == self.params.len() + 1);
assert!(other.plain_sql.len() == other.params.len() + 1); assert!(other.plain_sql.len() == other.params.len() + 1);
@ -70,7 +78,8 @@ impl QueryFragment {
self.params.append(&mut other.params); self.params.append(&mut other.params);
} }
/// Combine multiple QueryFragments with a separator, similar to Vec::join(). /// Combine multiple QueryFragments with a separator, similar to
/// [`Vec::join`].
pub fn join<I: IntoIterator<Item = Self>>(fragments: I, sep: Self) -> Self { pub fn join<I: IntoIterator<Item = Self>>(fragments: I, sep: Self) -> Self {
let mut acc = QueryFragment::from_sql(""); let mut acc = QueryFragment::from_sql("");
let mut iter = fragments.into_iter(); let mut iter = fragments.into_iter();
@ -94,6 +103,9 @@ impl QueryFragment {
} }
} }
/// Building block of a syntax tree for a constrained subset of SQL that can be
/// statically analyzed, to validate that user-provided expressions perform only
/// operations that are read-only and otherwise safe to execute.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(tag = "t", content = "c")] #[serde(tag = "t", content = "c")]
pub enum PgExpressionAny { pub enum PgExpressionAny {

View file

@ -11,6 +11,7 @@ use phono_backends::{
use phono_models::{ use phono_models::{
accessors::{Accessor, Actor, portal::PortalAccessor}, accessors::{Accessor, Actor, portal::PortalAccessor},
datum::Datum, datum::Datum,
expression::QueryFragment,
field::Field, field::Field,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -96,38 +97,36 @@ pub(super) async fn get(
field_info field_info
}; };
let mut sql_raw = format!( let sql_fragment = {
"select {0} from {1}.{2} order by _id", // Defensive programming: Make `sql_fragment` immutable once built.
let mut sql_fragment = QueryFragment::from_sql(&format!(
"select {0} from {1}",
pkey_attrs pkey_attrs
.iter() .iter()
.chain(attrs.iter()) .chain(attrs.iter())
.map(|attr| escape_identifier(&attr.attname)) .map(|attr| escape_identifier(&attr.attname))
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", "), .join(", "),
escape_identifier(&rel.regnamespace), rel.get_identifier(),
escape_identifier(&rel.relname), ));
); if let Some(filter_expr) = portal.table_filter.0 {
let rows: Vec<PgRow> = if let Some(filter_expr) = portal.table_filter.0 { sql_fragment.push(QueryFragment::from_sql(" where "));
let filter_fragment = filter_expr.into_query_fragment(); sql_fragment.push(filter_expr.into_query_fragment());
let filter_params = filter_fragment.to_params(); }
sql_raw = format!( sql_fragment.push(QueryFragment::from_sql(" order by _id limit "));
"{sql_raw} where {0} limit ${1}", sql_fragment.push(QueryFragment::from_param(Datum::Numeric(Some(
filter_fragment.to_sql(1), FRONTEND_ROW_LIMIT.into(),
filter_params.len() + 1 ))));
); sql_fragment
};
let sql_raw = sql_fragment.to_sql(1);
let mut q = query(&sql_raw); let mut q = query(&sql_raw);
for param in filter_params { for param in sql_fragment.to_params() {
q = param.bind_onto(q); q = param.bind_onto(q);
} }
q = q.bind(FRONTEND_ROW_LIMIT); q = q.bind(FRONTEND_ROW_LIMIT);
q.fetch_all(workspace_client.get_conn()).await? let rows: Vec<PgRow> = q.fetch_all(workspace_client.get_conn()).await?;
} else {
sql_raw = format!("{sql_raw} limit $1");
query(&sql_raw)
.bind(FRONTEND_ROW_LIMIT)
.fetch_all(workspace_client.get_conn())
.await?
};
#[derive(Serialize)] #[derive(Serialize)]
struct DataRow { struct DataRow {