Reject BOOLEAN and TEXT columns in math functions when posting a table script

This commit is contained in:
filipriec
2025-07-16 22:31:55 +02:00
parent de42bb48aa
commit fe246b1fe6
2 changed files with 172 additions and 1 deletions

View File

@@ -96,6 +96,9 @@ impl SteelContext {
_ => value.to_string(), // Return as-is if not a recognized boolean
}
}
"INTEGER" => {
value.to_string()
}
_ => value.to_string(), // Return as-is for non-boolean types
}
}

View File

@@ -8,6 +8,8 @@ use serde_json::Value;
use steel_decimal::SteelDecimal;
use regex::Regex;
use std::collections::HashSet;
use std::collections::HashMap;
use crate::table_script::handlers::dependency_analyzer::DependencyAnalyzer;
@@ -15,6 +17,169 @@ const SYSTEM_COLUMNS: &[&str] = &["id", "deleted", "created_at"];
// Define prohibited data types for Steel scripts (boolean is explicitly allowed)
const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"];
const MATH_PROHIBITED_TYPES: &[&str] = &["TEXT", "BOOLEAN"];
/// Extract the operand text of every math operation in the original script
/// (before the steel_decimal transformation).
///
/// A math operation is a parenthesised form whose head symbol is one of the
/// operators steel_decimal rewrites (`+`, `-`, `*`, `/`, `^`, `**`, `pow`,
/// `sqrt`, comparisons, `min`, `max`, `abs`, `round`, `ln`, `log`, `log10`,
/// `exp`, `sin`, `cos`, `tan`) followed by whitespace and at least one operand.
///
/// Parentheses are matched with a balanced scan rather than a regex, so nested
/// operands such as `(+ (steel_get_column "t" "a") 1)` are returned intact,
/// including their closing parentheses. String literals and `;` line comments
/// are skipped, so parentheses inside them do not affect the nesting.
///
/// Returns one trimmed operand string per math operation. Nested operations
/// each produce their own entry, inner forms before the form enclosing them.
/// A math form that is never closed yields everything up to the end of the
/// script, so its operands are still available for validation.
fn extract_math_operations_with_operands(script: &str) -> Vec<String> {
    /// Head symbols that steel_decimal treats as math operations.
    const MATH_OPERATORS: &[&str] = &[
        "+", "-", "*", "/", "^", "**", "pow", "sqrt", ">", "<", "=", ">=", "<=", "min", "max",
        "abs", "round", "ln", "log", "log10", "exp", "sin", "cos", "tan",
    ];

    /// True for bytes that terminate a symbol.
    fn is_delimiter(byte: u8) -> bool {
        byte.is_ascii_whitespace() || matches!(byte, b'(' | b')' | b'"' | b';')
    }

    /// Returns the index just past the string literal opening at `open_quote`
    /// (or the end of input when the literal is unterminated).
    fn skip_string_literal(bytes: &[u8], open_quote: usize) -> usize {
        let mut pos = open_quote + 1;
        while pos < bytes.len() {
            match bytes[pos] {
                b'\\' => pos += 2, // skip the escaped character as well
                b'"' => return pos + 1,
                _ => pos += 1,
            }
        }
        bytes.len()
    }

    /// Stores the trimmed operand text unless it is empty.
    fn record(out: &mut Vec<String>, operands: &str) {
        let operands = operands.trim();
        if !operands.is_empty() {
            out.push(operands.to_string());
        }
    }

    // All structural characters are ASCII, so every index used for slicing
    // below sits on a UTF-8 character boundary.
    let bytes = script.as_bytes();
    let mut math_operands = Vec::new();
    // One entry per currently open form: `Some(start)` holds the byte offset
    // where the operands of a math form begin, `None` marks any other form.
    let mut open_forms: Vec<Option<usize>> = Vec::new();
    let mut pos = 0;

    while pos < bytes.len() {
        match bytes[pos] {
            b'"' => pos = skip_string_literal(bytes, pos),
            b';' => {
                // Line comment: resume after the next newline.
                pos = bytes[pos..]
                    .iter()
                    .position(|&b| b == b'\n')
                    .map_or(bytes.len(), |offset| pos + offset + 1);
            }
            b'(' => {
                let head_start = bytes[pos + 1..]
                    .iter()
                    .position(|b| !b.is_ascii_whitespace())
                    .map_or(bytes.len(), |offset| pos + 1 + offset);
                let head_end = bytes[head_start..]
                    .iter()
                    .position(|&b| is_delimiter(b))
                    .map_or(bytes.len(), |offset| head_start + offset);
                let head = &script[head_start..head_end];
                // The operator must be followed by whitespace, so `(-5)` or
                // `(+)` are not treated as math operations.
                let followed_by_space = bytes
                    .get(head_end)
                    .map_or(false, |b| b.is_ascii_whitespace());
                let is_math = followed_by_space && MATH_OPERATORS.contains(&head);
                open_forms.push(if is_math { Some(head_end) } else { None });
                pos = head_end;
            }
            b')' => {
                if let Some(Some(operands_start)) = open_forms.pop() {
                    record(&mut math_operands, &script[operands_start..pos]);
                }
                pos += 1;
            }
            _ => pos += 1,
        }
    }

    // Unbalanced script: keep the operands of math forms that were never closed.
    for operands_start in open_forms.into_iter().flatten() {
        record(&mut math_operands, &script[operands_start..]);
    }

    math_operands
}
/// Extract `(table, column)` references from mathematical operands.
///
/// Each operand string is searched for `(steel_get_column "table" "column")`
/// and `(steel_get_column_with_index "table" <index> "column")` calls. For
/// every operand, the plain calls are reported first, followed by the indexed
/// calls; duplicates are kept.
///
/// The closing parenthesis of a call is deliberately not required: operand
/// text that was cut at the first `)` ends right after the quoted column
/// name, and a pattern demanding `)` would never match it.
fn extract_column_references_from_math_operands(operands: &[String]) -> Vec<(String, String)> {
    // Compile both patterns once instead of once per operand.
    let column_call = Regex::new(r#"\(\s*steel_get_column\s+"([^"]+)"\s+"([^"]+)""#)
        .expect("steel_get_column pattern is a constant, valid regex");
    let indexed_call =
        Regex::new(r#"\(\s*steel_get_column_with_index\s+"([^"]+)"\s+\d+\s+"([^"]+)""#)
            .expect("steel_get_column_with_index pattern is a constant, valid regex");

    let mut references = Vec::new();
    for operand_str in operands {
        for pattern in [&column_call, &indexed_call] {
            references.extend(pattern.captures_iter(operand_str).filter_map(|cap| {
                // Group 1 is the table name, group 2 the column name.
                let table = cap.get(1)?;
                let column = cap.get(2)?;
                Some((table.as_str().to_string(), column.as_str().to_string()))
            }));
        }
    }
    references
}
/// Ensures that no mathematical operation in `script` reads a TEXT or BOOLEAN column.
///
/// The script's math operands are scanned for column references; the
/// definitions of the referenced tables are loaded from `table_definitions`
/// for `schema_id`, and each referenced column's declared type is checked
/// against `MATH_PROHIBITED_TYPES`.
///
/// # Errors
/// * `Status::internal` when the table definitions cannot be fetched or a
///   table's column list cannot be deserialized.
/// * `Status::invalid_argument` when a referenced table or column does not
///   exist, or when a referenced column has a prohibited type.
async fn validate_math_operations_column_types(
    db_pool: &PgPool,
    schema_id: i64,
    script: &str,
) -> Result<(), Status> {
    // Nothing to validate without math operations...
    let operand_groups = extract_math_operations_with_operands(script);
    if operand_groups.is_empty() {
        return Ok(());
    }

    // ...or when none of them reads a column.
    let referenced_columns = extract_column_references_from_math_operands(&operand_groups);
    if referenced_columns.is_empty() {
        return Ok(());
    }

    // De-duplicate the table names so each definition is requested once.
    let mut unique_tables: HashSet<String> = HashSet::new();
    for (table, _) in &referenced_columns {
        unique_tables.insert(table.clone());
    }

    let definition_rows = sqlx::query!(
        r#"SELECT table_name, columns FROM table_definitions
WHERE schema_id = $1 AND table_name = ANY($2)"#,
        schema_id,
        &unique_tables.into_iter().collect::<Vec<_>>()
    )
    .fetch_all(db_pool)
    .await
    .map_err(|e| Status::internal(format!("Failed to fetch table definitions: {}", e)))?;

    // table name -> (column name -> declared type)
    let mut types_by_table: HashMap<String, HashMap<String, String>> = HashMap::new();
    for row in definition_rows {
        let definitions: Vec<String> = serde_json::from_value(row.columns)
            .map_err(|e| Status::internal(format!("Invalid column data for table '{}': {}", row.table_name, e)))?;

        // Each definition is read as `<name> <type> ...`; the name may be wrapped
        // in double quotes. Definitions with fewer than two tokens are ignored.
        let types_by_column: HashMap<String, String> = definitions
            .iter()
            .filter_map(|definition| {
                let mut tokens = definition.split_whitespace();
                match (tokens.next(), tokens.next()) {
                    (Some(name), Some(data_type)) => {
                        Some((name.trim_matches('"').to_string(), data_type.to_string()))
                    }
                    _ => None,
                }
            })
            .collect();

        types_by_table.insert(row.table_name, types_by_column);
    }

    // Check the references in the order they were found; the first problem wins.
    for (table_name, column_name) in referenced_columns {
        let known_columns = match types_by_table.get(&table_name) {
            Some(known_columns) => known_columns,
            None => {
                return Err(Status::invalid_argument(format!(
                    "Script references table '{}' in mathematical operations but this table does not exist in this schema",
                    table_name
                )));
            }
        };

        let column_type = match known_columns.get(&column_name) {
            Some(column_type) => column_type,
            None => {
                return Err(Status::invalid_argument(format!(
                    "Script references column '{}' in table '{}' but this column does not exist",
                    column_name,
                    table_name
                )));
            }
        };

        let normalized_type = normalize_data_type(column_type);
        let is_prohibited = MATH_PROHIBITED_TYPES
            .iter()
            .copied()
            .any(|prohibited| normalized_type.starts_with(prohibited));
        if is_prohibited {
            return Err(Status::invalid_argument(format!(
                "Cannot use column '{}' of type '{}' from table '{}' in mathematical operations. Mathematical operations cannot use columns of type: {}",
                column_name,
                column_type,
                table_name,
                MATH_PROHIBITED_TYPES.join(", ")
            )));
        }
    }

    Ok(())
}
/// Validates the target column and ensures it is not a system column or prohibited type.
/// Returns the column type if valid.
@@ -281,6 +446,9 @@ pub async fn post_table_script(
// Validate that script doesn't reference prohibited column types by checking actual DB schema
validate_script_column_references(db_pool, table_def.schema_id, &request.script).await?;
// NEW: Validate that mathematical operations don't use TEXT or BOOLEAN columns
validate_math_operations_column_types(db_pool, table_def.schema_id, &request.script).await?;
// Validate SQL queries in script for prohibited type operations
validate_sql_queries_in_script(&request.script)
.map_err(|e| Status::invalid_argument(e))?;
@@ -300,7 +468,7 @@ pub async fn post_table_script(
.await
.map_err(|e| Status::from(e))?;
// Transform the script using steel_decimal
// Transform the script using steel_decimal (this happens AFTER validation)
let steel_decimal = SteelDecimal::new();
let parsed_script = steel_decimal.transform(&request.script);