// src/table_script/handlers/post_table_script.rs // TODO MAKE THE SCRIPTS PUSH ONLY TO THE EMPTY FILES use tonic::Status; use sqlx::{PgPool, Error as SqlxError}; use common::proto::multieko2::table_script::{PostTableScriptRequest, TableScriptResponse}; use serde_json::Value; use steel_decimal::SteelDecimal; use regex::Regex; use std::collections::HashSet; use std::collections::HashMap; use crate::table_script::handlers::dependency_analyzer::DependencyAnalyzer; const SYSTEM_COLUMNS: &[&str] = &["id", "deleted", "created_at"]; // Define prohibited data types for Steel scripts (boolean is explicitly allowed) const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"]; const MATH_PROHIBITED_TYPES: &[&str] = &["BIGINT", "TEXT", "BOOLEAN", "DATE", "TIMESTAMPTZ"]; // Math operations that Steel Decimal will transform const MATH_OPERATIONS: &[&str] = &[ "+", "-", "*", "/", "^", "**", "pow", "sqrt", ">", "<", "=", ">=", "<=", "min", "max", "abs", "round", "ln", "log", "log10", "exp", "sin", "cos", "tan" ]; #[derive(Debug, Clone)] enum SExpr { Atom(String), List(Vec), } #[derive(Debug)] struct Parser { tokens: Vec, position: usize, } impl Parser { fn new(script: &str) -> Self { let tokens = Self::tokenize(script); Self { tokens, position: 0 } } fn tokenize(script: &str) -> Vec { let mut tokens = Vec::new(); let mut current_token = String::new(); let mut in_string = false; let mut escape_next = false; for ch in script.chars() { if escape_next { current_token.push(ch); escape_next = false; continue; } match ch { '\\' if in_string => { escape_next = true; current_token.push(ch); } '"' => { current_token.push(ch); if in_string { // End of string - push the complete string token tokens.push(current_token.clone()); current_token.clear(); } in_string = !in_string; } '(' | ')' if !in_string => { if !current_token.is_empty() { tokens.push(current_token.clone()); current_token.clear(); } tokens.push(ch.to_string()); } ch if ch.is_whitespace() && !in_string => { if !current_token.is_empty() { tokens.push(current_token.clone()); current_token.clear(); } } _ => { current_token.push(ch); } } } if !current_token.is_empty() { tokens.push(current_token); } tokens } fn parse(&mut self) -> Result, String> { let mut expressions = Vec::new(); while self.position < self.tokens.len() { expressions.push(self.parse_expr()?); } Ok(expressions) } fn parse_expr(&mut self) -> Result { if self.position >= self.tokens.len() { return Err("Unexpected end of input".to_string()); } let token = &self.tokens[self.position]; if token == "(" { self.position += 1; // consume '(' let mut elements = Vec::new(); while self.position < self.tokens.len() && self.tokens[self.position] != ")" { elements.push(self.parse_expr()?); } if self.position >= self.tokens.len() { return Err("Missing closing parenthesis".to_string()); } self.position += 1; // consume ')' Ok(SExpr::List(elements)) } else { self.position += 1; Ok(SExpr::Atom(token.clone())) } } } #[derive(Debug)] struct MathValidator { column_references: Vec<(String, String)>, // (table, column) pairs found in math contexts } impl MathValidator { fn new() -> Self { Self { column_references: Vec::new(), } } fn validate_expressions(&mut self, expressions: &[SExpr]) -> Result<(), String> { for expr in expressions { self.check_expression(expr, false)?; } Ok(()) } fn check_expression(&mut self, expr: &SExpr, in_math_context: bool) -> Result<(), String> { match expr { SExpr::Atom(_) => Ok(()), SExpr::List(elements) => { if elements.is_empty() { return Ok(()); } // Check if this is a math operation let is_math = if let SExpr::Atom(op) = &elements[0] { MATH_OPERATIONS.contains(&op.as_str()) } else { false }; // Check if this is a column access function if let SExpr::Atom(func) = &elements[0] { if func == "steel_get_column" && in_math_context { self.extract_column_reference_from_steel_get_column(elements)?; } else if func == "steel_get_column_with_index" && in_math_context { self.extract_column_reference_from_steel_get_column_with_index(elements)?; } } // Recursively check all elements, marking math context appropriately for element in &elements[1..] { // Skip the operator/function name self.check_expression(element, in_math_context || is_math)?; } Ok(()) } } } fn extract_column_reference_from_steel_get_column(&mut self, elements: &[SExpr]) -> Result<(), String> { // (steel_get_column "table" "column") if elements.len() >= 3 { if let (SExpr::Atom(table), SExpr::Atom(column)) = (&elements[1], &elements[2]) { let table_name = self.unquote_string(table)?; let column_name = self.unquote_string(column)?; self.column_references.push((table_name, column_name)); } } Ok(()) } fn extract_column_reference_from_steel_get_column_with_index(&mut self, elements: &[SExpr]) -> Result<(), String> { // (steel_get_column_with_index "table" index "column") if elements.len() >= 4 { if let (SExpr::Atom(table), SExpr::Atom(column)) = (&elements[1], &elements[3]) { let table_name = self.unquote_string(table)?; let column_name = self.unquote_string(column)?; self.column_references.push((table_name, column_name)); } } Ok(()) } fn unquote_string(&self, s: &str) -> Result { if s.starts_with('"') && s.ends_with('"') && s.len() >= 2 { Ok(s[1..s.len()-1].to_string()) } else { Err(format!("Expected quoted string, got: {}", s)) } } } /// Valide script is not empty fn validate_script_basic_syntax(script: &str) -> Result<(), Status> { let trimmed = script.trim(); // Check for empty script if trimmed.is_empty() { return Err(Status::invalid_argument("Script cannot be empty")); } // Basic parentheses balance check let mut paren_count = 0; for ch in trimmed.chars() { match ch { '(' => paren_count += 1, ')' => { paren_count -= 1; if paren_count < 0 { return Err(Status::invalid_argument("Unbalanced parentheses: closing ')' without matching opening '('")); } }, _ => {} } } if paren_count != 0 { return Err(Status::invalid_argument("Unbalanced parentheses: missing closing parentheses")); } // Check for basic S-expression structure if !trimmed.starts_with('(') { return Err(Status::invalid_argument("Script must start with an opening parenthesis '('")); } Ok(()) } /// Parse Steel script and extract column references used in mathematical contexts fn extract_math_column_references(script: &str) -> Result, String> { let mut parser = Parser::new(script); let expressions = parser.parse() .map_err(|e| format!("Parse error: {}", e))?; let mut validator = MathValidator::new(); validator.validate_expressions(&expressions) .map_err(|e| format!("Validation error: {}", e))?; Ok(validator.column_references) } /// Validate that mathematical operations don't use TEXT or BOOLEAN columns async fn validate_math_operations_column_types( db_pool: &PgPool, schema_id: i64, script: &str, ) -> Result<(), Status> { // Extract column references from mathematical contexts using proper S-expression parsing let column_refs = extract_math_column_references(script) .map_err(|e| Status::invalid_argument(format!("Script parsing failed: {}", e)))?; if column_refs.is_empty() { return Ok(()); // No column references in math operations } // Get all unique table names referenced in math operations let table_names: HashSet = column_refs.iter() .map(|(table, _)| table.clone()) .collect(); // Fetch table definitions for all referenced tables let table_definitions = sqlx::query!( r#"SELECT table_name, columns FROM table_definitions WHERE schema_id = $1 AND table_name = ANY($2)"#, schema_id, &table_names.into_iter().collect::>() ) .fetch_all(db_pool) .await .map_err(|e| Status::internal(format!("Failed to fetch table definitions: {}", e)))?; // Build a map of table_name -> column_name -> column_type let mut table_column_types: HashMap> = HashMap::new(); for table_def in table_definitions { let columns: Vec = serde_json::from_value(table_def.columns) .map_err(|e| Status::internal(format!("Invalid column data for table '{}': {}", table_def.table_name, e)))?; let mut column_types = HashMap::new(); for column_def in columns { let mut parts = column_def.split_whitespace(); if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) { let column_name = name.trim_matches('"'); column_types.insert(column_name.to_string(), data_type.to_string()); } } table_column_types.insert(table_def.table_name, column_types); } // Check each column reference in mathematical operations for (table_name, column_name) in column_refs { if let Some(table_columns) = table_column_types.get(&table_name) { if let Some(column_type) = table_columns.get(&column_name) { let normalized_type = normalize_data_type(column_type); // Check if this type is prohibited in math operations if MATH_PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited)) { return Err(Status::invalid_argument(format!( "Cannot use column '{}' of type '{}' from table '{}' in mathematical operations. Mathematical operations cannot use columns of type: {}", column_name, column_type, table_name, MATH_PROHIBITED_TYPES.join(", ") ))); } } else { return Err(Status::invalid_argument(format!( "Script references column '{}' in table '{}' but this column does not exist", column_name, table_name ))); } } else { return Err(Status::invalid_argument(format!( "Script references table '{}' in mathematical operations but this table does not exist in this schema", table_name ))); } } Ok(()) } /// Validates the target column and ensures it is not a system column or prohibited type. /// Returns the column type if valid. fn validate_target_column( table_name: &str, target: &str, table_columns: &Value, ) -> Result { if SYSTEM_COLUMNS.contains(&target) { return Err(format!("Cannot override system column: {}", target)); } // Parse the columns JSON into a vector of strings let columns: Vec = serde_json::from_value(table_columns.clone()) .map_err(|e| format!("Invalid column data: {}", e))?; // Extract column names and types let column_info: Vec<(&str, &str)> = columns .iter() .filter_map(|c| { let mut parts = c.split_whitespace(); let name = parts.next()?.trim_matches('"'); let data_type = parts.next()?; Some((name, data_type)) }) .collect(); // Find the target column and return its type let column_type = column_info .iter() .find(|(name, _)| *name == target) .map(|(_, dt)| dt.to_string()) .ok_or_else(|| format!("Target column '{}' not defined in table '{}'", target, table_name))?; // Check if the target column type is prohibited if is_prohibited_type(&column_type) { return Err(format!( "Cannot create script for column '{}' with type '{}'. Steel scripts cannot target columns of type: {}", target, column_type, PROHIBITED_TYPES.join(", ") )); } // Add helpful info for boolean columns let normalized_type = normalize_data_type(&column_type); if normalized_type == "BOOLEAN" || normalized_type == "BOOL" { println!("Info: Target column '{}' is boolean type. Values will be converted to Steel format (#true/#false)", target); } Ok(column_type) } /// Check if a data type is prohibited for Steel scripts /// Note: BOOLEAN/BOOL is explicitly allowed and handled with special conversion fn is_prohibited_type(data_type: &str) -> bool { let normalized_type = normalize_data_type(data_type); PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited)) } /// Normalize data type for comparison (handle NUMERIC variations, etc.) fn normalize_data_type(data_type: &str) -> String { data_type.to_uppercase() .split('(') // Remove precision/scale from NUMERIC(x,y) .next() .unwrap_or(data_type) .trim() .to_string() } /// Parse Steel script to extract all table/column references fn extract_column_references_from_script(script: &str) -> Vec<(String, String)> { let mut references = Vec::new(); // Regex patterns to match Steel function calls let patterns = [ // (steel_get_column "table_name" "column_name") r#"\(steel_get_column\s+"([^"]+)"\s+"([^"]+)"\)"#, // (steel_get_column_with_index "table_name" index "column_name") r#"\(steel_get_column_with_index\s+"([^"]+)"\s+\d+\s+"([^"]+)"\)"#, ]; for pattern in &patterns { if let Ok(re) = Regex::new(pattern) { for cap in re.captures_iter(script) { if let (Some(table), Some(column)) = (cap.get(1), cap.get(2)) { references.push((table.as_str().to_string(), column.as_str().to_string())); } } } } // Also check for steel_get_column_with_index pattern (table, column are in different positions) if let Ok(re) = Regex::new(r#"\(steel_get_column_with_index\s+"([^"]+)"\s+\d+\s+"([^"]+)"\)"#) { for cap in re.captures_iter(script) { if let (Some(table), Some(column)) = (cap.get(1), cap.get(2)) { references.push((table.as_str().to_string(), column.as_str().to_string())); } } } references } /// Validate that script doesn't reference prohibited column types by checking actual DB schema async fn validate_script_column_references( db_pool: &PgPool, schema_id: i64, script: &str, ) -> Result<(), Status> { // Extract all table/column references from the script let references = extract_column_references_from_script(script); if references.is_empty() { return Ok(()); // No column references to validate } // Get all unique table names referenced in the script let table_names: HashSet = references.iter() .map(|(table, _)| table.clone()) .collect(); // Fetch table definitions for all referenced tables for table_name in table_names { // Query the actual table definition from the database let table_def = sqlx::query!( r#"SELECT table_name, columns FROM table_definitions WHERE schema_id = $1 AND table_name = $2"#, schema_id, table_name ) .fetch_optional(db_pool) .await .map_err(|e| Status::internal(format!("Failed to fetch table definition for '{}': {}", table_name, e)))?; if let Some(table_def) = table_def { // Check each column reference for this table for (ref_table, ref_column) in &references { if ref_table == &table_name { // Validate this specific column reference if let Err(error_msg) = validate_referenced_column_type(&table_name, ref_column, &table_def.columns) { return Err(Status::invalid_argument(error_msg)); } } } } else { return Err(Status::invalid_argument(format!( "Script references table '{}' which does not exist in this schema", table_name ))); } } Ok(()) } /// Validate that a referenced column doesn't have a prohibited type fn validate_referenced_column_type(table_name: &str, column_name: &str, table_columns: &Value) -> Result<(), String> { // Parse the columns JSON into a vector of strings let columns: Vec = serde_json::from_value(table_columns.clone()) .map_err(|e| format!("Invalid column data for table '{}': {}", table_name, e))?; // Extract column names and types let column_info: Vec<(&str, &str)> = columns .iter() .filter_map(|c| { let mut parts = c.split_whitespace(); let name = parts.next()?.trim_matches('"'); let data_type = parts.next()?; Some((name, data_type)) }) .collect(); // Find the referenced column and check its type if let Some((_, column_type)) = column_info.iter().find(|(name, _)| *name == column_name) { if is_prohibited_type(column_type) { return Err(format!( "Script references column '{}' in table '{}' which has prohibited type '{}'. Steel scripts cannot access columns of type: {}", column_name, table_name, column_type, PROHIBITED_TYPES.join(", ") )); } // Log info for boolean columns let normalized_type = normalize_data_type(column_type); if normalized_type == "BOOLEAN" || normalized_type == "BOOL" { println!("Info: Script references boolean column '{}' in table '{}'. Values will be converted to Steel format (#true/#false)", column_name, table_name); } } else { return Err(format!( "Script references column '{}' in table '{}' but this column does not exist", column_name, table_name )); } Ok(()) } /// Parse Steel SQL queries to check for prohibited type usage (basic heuristic) fn validate_sql_queries_in_script(script: &str) -> Result<(), String> { // Look for steel_query_sql calls if let Ok(re) = Regex::new(r#"\(steel_query_sql\s+"([^"]+)"\)"#) { for cap in re.captures_iter(script) { if let Some(query) = cap.get(1) { let sql = query.as_str().to_uppercase(); // Basic heuristic checks for prohibited type operations let prohibited_patterns = [ "EXTRACT(", "DATE_PART(", "::DATE", "::TIMESTAMPTZ", "::BIGINT", "CAST(", // Could be casting to prohibited types ]; for pattern in &prohibited_patterns { if sql.contains(pattern) { return Err(format!( "Script contains SQL query with potentially prohibited type operations: '{}'. Steel scripts cannot use operations on types: {}", query.as_str(), PROHIBITED_TYPES.join(", ") )); } } } } } Ok(()) } /// Handles the creation of a new table script with dependency validation. pub async fn post_table_script( db_pool: &PgPool, request: PostTableScriptRequest, ) -> Result { // Basic script validation first validate_script_basic_syntax(&request.script)?; // Start a transaction for ALL operations - critical for atomicity let mut tx = db_pool.begin().await .map_err(|e| Status::internal(format!("Failed to start transaction: {}", e)))?; // Fetch the table definition let table_def = sqlx::query!( r#"SELECT id, table_name, columns, schema_id FROM table_definitions WHERE id = $1"#, request.table_definition_id ) .fetch_optional(&mut *tx) .await .map_err(|e| Status::internal(format!("Failed to fetch table definition: {}", e)))? .ok_or_else(|| Status::not_found("Table definition not found"))?; // Validate the target column and get its type (includes prohibited type check) let column_type = validate_target_column( &table_def.table_name, &request.target_column, &table_def.columns, ) .map_err(|e| Status::invalid_argument(e))?; // REORDER: Math validation FIRST so we get specific error messages for math operations validate_math_operations_column_types(db_pool, table_def.schema_id, &request.script).await?; // THEN general column validation (catches non-math prohibited access) validate_script_column_references(db_pool, table_def.schema_id, &request.script).await?; // Validate SQL queries in script for prohibited type operations validate_sql_queries_in_script(&request.script) .map_err(|e| Status::invalid_argument(e))?; // Create dependency analyzer for this schema let analyzer = DependencyAnalyzer::new(table_def.schema_id, db_pool.clone()); // Analyze script dependencies let dependencies = analyzer .analyze_script_dependencies(&request.script) .map_err(|e| Status::from(e))?; // Check for circular dependencies BEFORE making any changes // Pass the transaction to ensure we see any existing dependencies analyzer .check_for_cycles(&mut tx, table_def.id, &dependencies) .await .map_err(|e| Status::from(e))?; // Transform the script using steel_decimal (this happens AFTER validation) let steel_decimal = SteelDecimal::new(); let parsed_script = steel_decimal.transform(&request.script); // Insert or update the script let script_record = sqlx::query!( r#"INSERT INTO table_scripts (table_definitions_id, target_table, target_column, target_column_type, script, description, schema_id) VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (table_definitions_id, target_column) DO UPDATE SET script = EXCLUDED.script, description = EXCLUDED.description, target_column_type = EXCLUDED.target_column_type RETURNING id"#, request.table_definition_id, table_def.table_name, request.target_column, column_type, parsed_script, request.description, table_def.schema_id ) .fetch_one(&mut *tx) .await .map_err(|e| { match e { SqlxError::Database(db_err) if db_err.constraint() == Some("table_scripts_table_definitions_id_target_column_key") => { Status::already_exists("Script already exists for this column") } _ => Status::internal(format!("Failed to insert script: {}", e)), } })?; // Save the dependencies within the same transaction analyzer .save_dependencies(&mut tx, script_record.id, table_def.id, &dependencies) .await .map_err(|e| Status::from(e))?; // Only now commit the entire transaction - script + dependencies together tx.commit().await .map_err(|e| Status::internal(format!("Failed to commit transaction: {}", e)))?; // Generate warnings for potential issues let warnings = generate_warnings(&dependencies, &table_def.table_name); Ok(TableScriptResponse { id: script_record.id, warnings, }) } /// Generate helpful warnings for script dependencies fn generate_warnings(dependencies: &[crate::table_script::handlers::dependency_analyzer::Dependency], table_name: &str) -> String { let mut warnings = Vec::new(); // Check for self-references if dependencies.iter().any(|d| d.target_table == table_name) { warnings.push("Warning: Script references its own table, which may cause issues during initial population.".to_string()); } // Check for complex SQL queries let sql_deps_count = dependencies.iter() .filter(|d| matches!(d.dependency_type, crate::table_script::handlers::dependency_analyzer::DependencyType::SqlQuery { .. })) .count(); if sql_deps_count > 0 { warnings.push(format!( "Warning: Script contains {} raw SQL quer{}, ensure they are read-only and reference valid tables.", sql_deps_count, if sql_deps_count == 1 { "y" } else { "ies" } )); } // Check for many dependencies if dependencies.len() > 5 { warnings.push(format!( "Warning: Script depends on {} tables, which may affect processing performance.", dependencies.len() )); } // Count structured access dependencies let structured_deps_count = dependencies.iter() .filter(|d| matches!( d.dependency_type, crate::table_script::handlers::dependency_analyzer::DependencyType::ColumnAccess { .. } | crate::table_script::handlers::dependency_analyzer::DependencyType::IndexedAccess { .. } )) .count(); if structured_deps_count > 0 { warnings.push(format!( "Info: Script uses {} linked table{} via steel_get_column functions.", structured_deps_count, if structured_deps_count == 1 { "" } else { "s" } )); } warnings.join(" ") }