diff --git a/server/src/steel/server/functions.rs b/server/src/steel/server/functions.rs
index 46fb807..5f92df6 100644
--- a/server/src/steel/server/functions.rs
+++ b/server/src/steel/server/functions.rs
@@ -96,6 +96,9 @@ impl SteelContext {
                     _ => value.to_string(), // Return as-is if not a recognized boolean
                 }
             }
+            "INTEGER" => {
+                value.to_string()
+            }
             _ => value.to_string(), // Return as-is for non-boolean types
         }
     }
diff --git a/server/src/table_script/handlers/post_table_script.rs b/server/src/table_script/handlers/post_table_script.rs
index 3594264..22558aa 100644
--- a/server/src/table_script/handlers/post_table_script.rs
+++ b/server/src/table_script/handlers/post_table_script.rs
@@ -8,6 +8,8 @@ use serde_json::Value;
 use steel_decimal::SteelDecimal;
 use regex::Regex;
 use std::collections::HashSet;
+use std::collections::HashMap;
+
 
 use crate::table_script::handlers::dependency_analyzer::DependencyAnalyzer;
 
@@ -15,6 +17,193 @@ const SYSTEM_COLUMNS: &[&str] = &["id", "deleted", "created_at"];
 
 // Define prohibited data types for Steel scripts (boolean is explicitly allowed)
 const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"];
+const MATH_PROHIBITED_TYPES: &[&str] = &["TEXT", "BOOLEAN"];
+
+/// Operators and functions treated as mathematical operations (the forms steel_decimal transforms).
+const MATH_OPERATORS: &[&str] = &[
+    "+", "-", "*", "/", "^", "**", "pow", "sqrt", ">", "<", "=", ">=", "<=", "min", "max", "abs",
+    "round", "ln", "log", "log10", "exp", "sin", "cos", "tan",
+];
+
+/// Returns the operand text of `inner` (an expression without its enclosing parentheses) when
+/// it starts with one of `MATH_OPERATORS` followed by whitespace and at least one operand.
+fn math_operands_of(inner: &str) -> Option<&str> {
+    let inner = inner.trim_start();
+    let operator_end = inner.find(char::is_whitespace)?;
+    let (operator, operands) = inner.split_at(operator_end);
+    let operands = operands.trim();
+    if MATH_OPERATORS.contains(&operator) && !operands.is_empty() {
+        Some(operands)
+    } else {
+        None
+    }
+}
+
+/// Extract the operands of every mathematical expression in the original script
+/// (before steel_decimal transformation).
+///
+/// Parentheses are matched with a stack so nested calls stay inside the operand text. A
+/// `[^)]+` regex capture stops at the first `)`, which truncates every nested
+/// `(steel_get_column ...)` operand and leaves the column-type check nothing to match.
+/// Parentheses inside string literals and `;` line comments are ignored.
+fn extract_math_operations_with_operands(script: &str) -> Vec<String> {
+    let mut math_operands = Vec::new();
+    let mut open_parens: Vec<usize> = Vec::new();
+    let mut in_string = false;
+    let mut escaped = false;
+    let mut in_comment = false;
+
+    for (idx, ch) in script.char_indices() {
+        if in_comment {
+            in_comment = ch != '\n';
+            continue;
+        }
+        if in_string {
+            if escaped {
+                escaped = false;
+            } else if ch == '\\' {
+                escaped = true;
+            } else if ch == '"' {
+                in_string = false;
+            }
+            continue;
+        }
+        match ch {
+            '"' => in_string = true,
+            ';' => in_comment = true,
+            '(' => open_parens.push(idx),
+            ')' => {
+                // An unmatched `)` has no expression to inspect, so it is skipped.
+                if let Some(start) = open_parens.pop() {
+                    if let Some(operands) = math_operands_of(&script[start + 1..idx]) {
+                        math_operands.push(operands.to_string());
+                    }
+                }
+            }
+            _ => {}
+        }
+    }
+
+    math_operands
+}
+
+/// Extract `(table, column)` pairs from every `steel_get_column` and
+/// `steel_get_column_with_index` call found inside the given math operands.
+fn extract_column_references_from_math_operands(operands: &[String]) -> Vec<(String, String)> {
+    // The patterns are constant, so compile them once instead of once per operand.
+    let column_regexes: Vec<Regex> = [
+        r#"\(steel_get_column\s+"([^"]+)"\s+"([^"]+)"\)"#,
+        r#"\(steel_get_column_with_index\s+"([^"]+)"\s+\d+\s+"([^"]+)"\)"#,
+    ]
+    .iter()
+    .filter_map(|pattern| Regex::new(pattern).ok())
+    .collect();
+
+    let mut references = Vec::new();
+    for operand_str in operands {
+        for re in &column_regexes {
+            for cap in re.captures_iter(operand_str) {
+                if let (Some(table), Some(column)) = (cap.get(1), cap.get(2)) {
+                    references.push((table.as_str().to_string(), column.as_str().to_string()));
+                }
+            }
+        }
+    }
+
+    references
+}
+
+/// Validate that mathematical operations don't use TEXT or BOOLEAN columns.
+///
+/// Returns `Status::invalid_argument` when a referenced table or column is missing or the
+/// column type is listed in `MATH_PROHIBITED_TYPES`, and `Status::internal` when the table
+/// definitions cannot be fetched or decoded.
+async fn validate_math_operations_column_types(
+    db_pool: &PgPool,
+    schema_id: i64,
+    script: &str,
+) -> Result<(), Status> {
+    // Extract all mathematical operations and their operands
+    let math_operands = extract_math_operations_with_operands(script);
+
+    if math_operands.is_empty() {
+        return Ok(()); // No math operations to validate
+    }
+
+    // Extract column references from math operands
+    let column_refs = extract_column_references_from_math_operands(&math_operands);
+
+    if column_refs.is_empty() {
+        return Ok(()); // No column references in math operations
+    }
+
+    // Get all unique table names referenced in math operations
+    let table_names: HashSet<String> = column_refs.iter()
+        .map(|(table, _)| table.clone())
+        .collect();
+
+    // Fetch table definitions for all referenced tables
+    let table_definitions = sqlx::query!(
+        r#"SELECT table_name, columns FROM table_definitions
+           WHERE schema_id = $1 AND table_name = ANY($2)"#,
+        schema_id,
+        &table_names.into_iter().collect::<Vec<String>>()
+    )
+    .fetch_all(db_pool)
+    .await
+    .map_err(|e| Status::internal(format!("Failed to fetch table definitions: {}", e)))?;
+
+    // Build a map of table_name -> column_name -> column_type
+    let mut table_column_types: HashMap<String, HashMap<String, String>> = HashMap::new();
+
+    for table_def in table_definitions {
+        let columns: Vec<String> = serde_json::from_value(table_def.columns)
+            .map_err(|e| Status::internal(format!("Invalid column data for table '{}': {}", table_def.table_name, e)))?;
+
+        let mut column_types = HashMap::new();
+        for column_def in columns {
+            let mut parts = column_def.split_whitespace();
+            if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) {
+                let column_name = name.trim_matches('"');
+                column_types.insert(column_name.to_string(), data_type.to_string());
+            }
+        }
+        table_column_types.insert(table_def.table_name, column_types);
+    }
+
+    // Check each column reference in mathematical operations
+    for (table_name, column_name) in column_refs {
+        if let Some(table_columns) = table_column_types.get(&table_name) {
+            if let Some(column_type) = table_columns.get(&column_name) {
+                let normalized_type = normalize_data_type(column_type);
+
+                // Check if this type is prohibited in math operations
+                if MATH_PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited)) {
+                    return Err(Status::invalid_argument(format!(
+                        "Cannot use column '{}' of type '{}' from table '{}' in mathematical operations. Mathematical operations cannot use columns of type: {}",
+                        column_name,
+                        column_type,
+                        table_name,
+                        MATH_PROHIBITED_TYPES.join(", ")
+                    )));
+                }
+            } else {
+                return Err(Status::invalid_argument(format!(
+                    "Script references column '{}' in table '{}' but this column does not exist",
+                    column_name,
+                    table_name
+                )));
+            }
+        } else {
+            return Err(Status::invalid_argument(format!(
+                "Script references table '{}' in mathematical operations but this table does not exist in this schema",
+                table_name
+            )));
+        }
+    }
+
+    Ok(())
+}
 
 /// Validates the target column and ensures it is not a system column or prohibited type.
 /// Returns the column type if valid.
@@ -281,6 +470,9 @@ pub async fn post_table_script(
     // Validate that script doesn't reference prohibited column types by checking actual DB schema
     validate_script_column_references(db_pool, table_def.schema_id, &request.script).await?;
 
+    // NEW: Validate that mathematical operations don't use TEXT or BOOLEAN columns
+    validate_math_operations_column_types(db_pool, table_def.schema_id, &request.script).await?;
+
     // Validate SQL queries in script for prohibited type operations
     validate_sql_queries_in_script(&request.script)
         .map_err(|e| Status::invalid_argument(e))?;
@@ -300,7 +492,7 @@ pub async fn post_table_script(
         .await
         .map_err(|e| Status::from(e))?;
 
-    // Transform the script using steel_decimal
+    // Transform the script using steel_decimal (this happens AFTER validation)
     let steel_decimal = SteelDecimal::new();
     let parsed_script = steel_decimal.transform(&request.script);
 