// src/table_script/handlers/post_table_script.rs // TODO MAKE THE SCRIPTS PUSH ONLY TO THE EMPTY FILES use tonic::Status; use sqlx::{PgPool, Error as SqlxError}; use common::proto::multieko2::table_script::{PostTableScriptRequest, TableScriptResponse}; use serde_json::Value; use steel_decimal::SteelDecimal; use regex::Regex; use std::collections::HashSet; use std::collections::HashMap; use crate::table_script::handlers::dependency_analyzer::DependencyAnalyzer; const SYSTEM_COLUMNS: &[&str] = &["id", "deleted", "created_at"]; // Define prohibited data types for Steel scripts (boolean is explicitly allowed) const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"]; const MATH_PROHIBITED_TYPES: &[&str] = &["TEXT", "BOOLEAN"]; /// Extract mathematical expressions from the original script (before steel_decimal transformation) fn extract_math_operations_with_operands(script: &str) -> Vec { let mut math_operands = Vec::new(); // Define math operation patterns that steel_decimal will transform let math_patterns = [ r"\(\s*\+\s+([^)]+)\)", // (+ operands) r"\(\s*-\s+([^)]+)\)", // (- operands) r"\(\s*\*\s+([^)]+)\)", // (* operands) r"\(\s*/\s+([^)]+)\)", // (/ operands) r"\(\s*\^\s+([^)]+)\)", // (^ operands) r"\(\s*\*\*\s+([^)]+)\)", // (** operands) r"\(\s*pow\s+([^)]+)\)", // (pow operands) r"\(\s*sqrt\s+([^)]+)\)", // (sqrt operands) r"\(\s*>\s+([^)]+)\)", // (> operands) r"\(\s*<\s+([^)]+)\)", // (< operands) r"\(\s*=\s+([^)]+)\)", // (= operands) r"\(\s*>=\s+([^)]+)\)", // (>= operands) r"\(\s*<=\s+([^)]+)\)", // (<= operands) r"\(\s*min\s+([^)]+)\)", // (min operands) r"\(\s*max\s+([^)]+)\)", // (max operands) r"\(\s*abs\s+([^)]+)\)", // (abs operands) r"\(\s*round\s+([^)]+)\)", // (round operands) r"\(\s*ln\s+([^)]+)\)", // (ln operands) r"\(\s*log\s+([^)]+)\)", // (log operands) r"\(\s*log10\s+([^)]+)\)", // (log10 operands) r"\(\s*exp\s+([^)]+)\)", // (exp operands) r"\(\s*sin\s+([^)]+)\)", // (sin operands) r"\(\s*cos\s+([^)]+)\)", // (cos operands) r"\(\s*tan\s+([^)]+)\)", // (tan operands) ]; for pattern in &math_patterns { if let Ok(re) = Regex::new(pattern) { for cap in re.captures_iter(script) { if let Some(operands_str) = cap.get(1) { // Add all operands from this math operation math_operands.push(operands_str.as_str().to_string()); } } } } math_operands } /// Extract column references from mathematical operands fn extract_column_references_from_math_operands(operands: &[String]) -> Vec<(String, String)> { let mut references = Vec::new(); for operand_str in operands { // Check for steel_get_column calls: (steel_get_column "table" "column") if let Ok(re) = Regex::new(r#"\(steel_get_column\s+"([^"]+)"\s+"([^"]+)"\)"#) { for cap in re.captures_iter(operand_str) { if let (Some(table), Some(column)) = (cap.get(1), cap.get(2)) { references.push((table.as_str().to_string(), column.as_str().to_string())); } } } // Check for steel_get_column_with_index calls if let Ok(re) = Regex::new(r#"\(steel_get_column_with_index\s+"([^"]+)"\s+\d+\s+"([^"]+)"\)"#) { for cap in re.captures_iter(operand_str) { if let (Some(table), Some(column)) = (cap.get(1), cap.get(2)) { references.push((table.as_str().to_string(), column.as_str().to_string())); } } } } references } /// Validate that mathematical operations don't use TEXT or BOOLEAN columns async fn validate_math_operations_column_types( db_pool: &PgPool, schema_id: i64, script: &str, ) -> Result<(), Status> { // Extract all mathematical operations and their operands let math_operands = extract_math_operations_with_operands(script); if math_operands.is_empty() { return Ok(()); // No math operations to validate } // Extract column references from math operands let column_refs = extract_column_references_from_math_operands(&math_operands); if column_refs.is_empty() { return Ok(()); // No column references in math operations } // Get all unique table names referenced in math operations let table_names: HashSet = column_refs.iter() .map(|(table, _)| table.clone()) .collect(); // Fetch table definitions for all referenced tables let table_definitions = sqlx::query!( r#"SELECT table_name, columns FROM table_definitions WHERE schema_id = $1 AND table_name = ANY($2)"#, schema_id, &table_names.into_iter().collect::>() ) .fetch_all(db_pool) .await .map_err(|e| Status::internal(format!("Failed to fetch table definitions: {}", e)))?; // Build a map of table_name -> column_name -> column_type let mut table_column_types: HashMap> = HashMap::new(); for table_def in table_definitions { let columns: Vec = serde_json::from_value(table_def.columns) .map_err(|e| Status::internal(format!("Invalid column data for table '{}': {}", table_def.table_name, e)))?; let mut column_types = HashMap::new(); for column_def in columns { let mut parts = column_def.split_whitespace(); if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) { let column_name = name.trim_matches('"'); column_types.insert(column_name.to_string(), data_type.to_string()); } } table_column_types.insert(table_def.table_name, column_types); } // Check each column reference in mathematical operations for (table_name, column_name) in column_refs { if let Some(table_columns) = table_column_types.get(&table_name) { if let Some(column_type) = table_columns.get(&column_name) { let normalized_type = normalize_data_type(column_type); // Check if this type is prohibited in math operations if MATH_PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited)) { return Err(Status::invalid_argument(format!( "Cannot use column '{}' of type '{}' from table '{}' in mathematical operations. Mathematical operations cannot use columns of type: {}", column_name, column_type, table_name, MATH_PROHIBITED_TYPES.join(", ") ))); } } else { return Err(Status::invalid_argument(format!( "Script references column '{}' in table '{}' but this column does not exist", column_name, table_name ))); } } else { return Err(Status::invalid_argument(format!( "Script references table '{}' in mathematical operations but this table does not exist in this schema", table_name ))); } } Ok(()) } /// Validates the target column and ensures it is not a system column or prohibited type. /// Returns the column type if valid. fn validate_target_column( table_name: &str, target: &str, table_columns: &Value, ) -> Result { if SYSTEM_COLUMNS.contains(&target) { return Err(format!("Cannot override system column: {}", target)); } // Parse the columns JSON into a vector of strings let columns: Vec = serde_json::from_value(table_columns.clone()) .map_err(|e| format!("Invalid column data: {}", e))?; // Extract column names and types let column_info: Vec<(&str, &str)> = columns .iter() .filter_map(|c| { let mut parts = c.split_whitespace(); let name = parts.next()?.trim_matches('"'); let data_type = parts.next()?; Some((name, data_type)) }) .collect(); // Find the target column and return its type let column_type = column_info .iter() .find(|(name, _)| *name == target) .map(|(_, dt)| dt.to_string()) .ok_or_else(|| format!("Target column '{}' not defined in table '{}'", target, table_name))?; // Check if the target column type is prohibited if is_prohibited_type(&column_type) { return Err(format!( "Cannot create script for column '{}' with type '{}'. Steel scripts cannot target columns of type: {}", target, column_type, PROHIBITED_TYPES.join(", ") )); } // Add helpful info for boolean columns let normalized_type = normalize_data_type(&column_type); if normalized_type == "BOOLEAN" || normalized_type == "BOOL" { println!("Info: Target column '{}' is boolean type. Values will be converted to Steel format (#true/#false)", target); } Ok(column_type) } /// Check if a data type is prohibited for Steel scripts /// Note: BOOLEAN/BOOL is explicitly allowed and handled with special conversion fn is_prohibited_type(data_type: &str) -> bool { let normalized_type = normalize_data_type(data_type); PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited)) } /// Normalize data type for comparison (handle NUMERIC variations, etc.) fn normalize_data_type(data_type: &str) -> String { data_type.to_uppercase() .split('(') // Remove precision/scale from NUMERIC(x,y) .next() .unwrap_or(data_type) .trim() .to_string() } /// Parse Steel script to extract all table/column references fn extract_column_references_from_script(script: &str) -> Vec<(String, String)> { let mut references = Vec::new(); // Regex patterns to match Steel function calls let patterns = [ // (steel_get_column "table_name" "column_name") r#"\(steel_get_column\s+"([^"]+)"\s+"([^"]+)"\)"#, // (steel_get_column_with_index "table_name" index "column_name") r#"\(steel_get_column_with_index\s+"([^"]+)"\s+\d+\s+"([^"]+)"\)"#, ]; for pattern in &patterns { if let Ok(re) = Regex::new(pattern) { for cap in re.captures_iter(script) { if let (Some(table), Some(column)) = (cap.get(1), cap.get(2)) { references.push((table.as_str().to_string(), column.as_str().to_string())); } } } } // Also check for steel_get_column_with_index pattern (table, column are in different positions) if let Ok(re) = Regex::new(r#"\(steel_get_column_with_index\s+"([^"]+)"\s+\d+\s+"([^"]+)"\)"#) { for cap in re.captures_iter(script) { if let (Some(table), Some(column)) = (cap.get(1), cap.get(2)) { references.push((table.as_str().to_string(), column.as_str().to_string())); } } } references } /// Validate that script doesn't reference prohibited column types by checking actual DB schema async fn validate_script_column_references( db_pool: &PgPool, schema_id: i64, script: &str, ) -> Result<(), Status> { // Extract all table/column references from the script let references = extract_column_references_from_script(script); if references.is_empty() { return Ok(()); // No column references to validate } // Get all unique table names referenced in the script let table_names: HashSet = references.iter() .map(|(table, _)| table.clone()) .collect(); // Fetch table definitions for all referenced tables for table_name in table_names { // Query the actual table definition from the database let table_def = sqlx::query!( r#"SELECT table_name, columns FROM table_definitions WHERE schema_id = $1 AND table_name = $2"#, schema_id, table_name ) .fetch_optional(db_pool) .await .map_err(|e| Status::internal(format!("Failed to fetch table definition for '{}': {}", table_name, e)))?; if let Some(table_def) = table_def { // Check each column reference for this table for (ref_table, ref_column) in &references { if ref_table == &table_name { // Validate this specific column reference if let Err(error_msg) = validate_referenced_column_type(&table_name, ref_column, &table_def.columns) { return Err(Status::invalid_argument(error_msg)); } } } } else { return Err(Status::invalid_argument(format!( "Script references table '{}' which does not exist in this schema", table_name ))); } } Ok(()) } /// Validate that a referenced column doesn't have a prohibited type fn validate_referenced_column_type(table_name: &str, column_name: &str, table_columns: &Value) -> Result<(), String> { // Parse the columns JSON into a vector of strings let columns: Vec = serde_json::from_value(table_columns.clone()) .map_err(|e| format!("Invalid column data for table '{}': {}", table_name, e))?; // Extract column names and types let column_info: Vec<(&str, &str)> = columns .iter() .filter_map(|c| { let mut parts = c.split_whitespace(); let name = parts.next()?.trim_matches('"'); let data_type = parts.next()?; Some((name, data_type)) }) .collect(); // Find the referenced column and check its type if let Some((_, column_type)) = column_info.iter().find(|(name, _)| *name == column_name) { if is_prohibited_type(column_type) { return Err(format!( "Script references column '{}' in table '{}' which has prohibited type '{}'. Steel scripts cannot access columns of type: {}", column_name, table_name, column_type, PROHIBITED_TYPES.join(", ") )); } // Log info for boolean columns let normalized_type = normalize_data_type(column_type); if normalized_type == "BOOLEAN" || normalized_type == "BOOL" { println!("Info: Script references boolean column '{}' in table '{}'. Values will be converted to Steel format (#true/#false)", column_name, table_name); } } else { return Err(format!( "Script references column '{}' in table '{}' but this column does not exist", column_name, table_name )); } Ok(()) } /// Parse Steel SQL queries to check for prohibited type usage (basic heuristic) fn validate_sql_queries_in_script(script: &str) -> Result<(), String> { // Look for steel_query_sql calls if let Ok(re) = Regex::new(r#"\(steel_query_sql\s+"([^"]+)"\)"#) { for cap in re.captures_iter(script) { if let Some(query) = cap.get(1) { let sql = query.as_str().to_uppercase(); // Basic heuristic checks for prohibited type operations let prohibited_patterns = [ "EXTRACT(", "DATE_PART(", "::DATE", "::TIMESTAMPTZ", "::BIGINT", "CAST(", // Could be casting to prohibited types ]; for pattern in &prohibited_patterns { if sql.contains(pattern) { return Err(format!( "Script contains SQL query with potentially prohibited type operations: '{}'. Steel scripts cannot use operations on types: {}", query.as_str(), PROHIBITED_TYPES.join(", ") )); } } } } } Ok(()) } /// Handles the creation of a new table script with dependency validation. pub async fn post_table_script( db_pool: &PgPool, request: PostTableScriptRequest, ) -> Result { // Start a transaction for ALL operations - critical for atomicity let mut tx = db_pool.begin().await .map_err(|e| Status::internal(format!("Failed to start transaction: {}", e)))?; // Fetch the table definition let table_def = sqlx::query!( r#"SELECT id, table_name, columns, schema_id FROM table_definitions WHERE id = $1"#, request.table_definition_id ) .fetch_optional(&mut *tx) .await .map_err(|e| Status::internal(format!("Failed to fetch table definition: {}", e)))? .ok_or_else(|| Status::not_found("Table definition not found"))?; // Validate the target column and get its type (includes prohibited type check) let column_type = validate_target_column( &table_def.table_name, &request.target_column, &table_def.columns, ) .map_err(|e| Status::invalid_argument(e))?; // Validate that script doesn't reference prohibited column types by checking actual DB schema validate_script_column_references(db_pool, table_def.schema_id, &request.script).await?; // NEW: Validate that mathematical operations don't use TEXT or BOOLEAN columns validate_math_operations_column_types(db_pool, table_def.schema_id, &request.script).await?; // Validate SQL queries in script for prohibited type operations validate_sql_queries_in_script(&request.script) .map_err(|e| Status::invalid_argument(e))?; // Create dependency analyzer for this schema let analyzer = DependencyAnalyzer::new(table_def.schema_id, db_pool.clone()); // Analyze script dependencies let dependencies = analyzer .analyze_script_dependencies(&request.script) .map_err(|e| Status::from(e))?; // Check for circular dependencies BEFORE making any changes // Pass the transaction to ensure we see any existing dependencies analyzer .check_for_cycles(&mut tx, table_def.id, &dependencies) .await .map_err(|e| Status::from(e))?; // Transform the script using steel_decimal (this happens AFTER validation) let steel_decimal = SteelDecimal::new(); let parsed_script = steel_decimal.transform(&request.script); // Insert or update the script let script_record = sqlx::query!( r#"INSERT INTO table_scripts (table_definitions_id, target_table, target_column, target_column_type, script, description, schema_id) VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (table_definitions_id, target_column) DO UPDATE SET script = EXCLUDED.script, description = EXCLUDED.description, target_column_type = EXCLUDED.target_column_type RETURNING id"#, request.table_definition_id, table_def.table_name, request.target_column, column_type, parsed_script, request.description, table_def.schema_id ) .fetch_one(&mut *tx) .await .map_err(|e| { match e { SqlxError::Database(db_err) if db_err.constraint() == Some("table_scripts_table_definitions_id_target_column_key") => { Status::already_exists("Script already exists for this column") } _ => Status::internal(format!("Failed to insert script: {}", e)), } })?; // Save the dependencies within the same transaction analyzer .save_dependencies(&mut tx, script_record.id, table_def.id, &dependencies) .await .map_err(|e| Status::from(e))?; // Only now commit the entire transaction - script + dependencies together tx.commit().await .map_err(|e| Status::internal(format!("Failed to commit transaction: {}", e)))?; // Generate warnings for potential issues let warnings = generate_warnings(&dependencies, &table_def.table_name); Ok(TableScriptResponse { id: script_record.id, warnings, }) } /// Generate helpful warnings for script dependencies fn generate_warnings(dependencies: &[crate::table_script::handlers::dependency_analyzer::Dependency], table_name: &str) -> String { let mut warnings = Vec::new(); // Check for self-references if dependencies.iter().any(|d| d.target_table == table_name) { warnings.push("Warning: Script references its own table, which may cause issues during initial population.".to_string()); } // Check for complex SQL queries let sql_deps_count = dependencies.iter() .filter(|d| matches!(d.dependency_type, crate::table_script::handlers::dependency_analyzer::DependencyType::SqlQuery { .. })) .count(); if sql_deps_count > 0 { warnings.push(format!( "Warning: Script contains {} raw SQL quer{}, ensure they are read-only and reference valid tables.", sql_deps_count, if sql_deps_count == 1 { "y" } else { "ies" } )); } // Check for many dependencies if dependencies.len() > 5 { warnings.push(format!( "Warning: Script depends on {} tables, which may affect processing performance.", dependencies.len() )); } // Count structured access dependencies let structured_deps_count = dependencies.iter() .filter(|d| matches!( d.dependency_type, crate::table_script::handlers::dependency_analyzer::DependencyType::ColumnAccess { .. } | crate::table_script::handlers::dependency_analyzer::DependencyType::IndexedAccess { .. } )) .count(); if structured_deps_count > 0 { warnings.push(format!( "Info: Script uses {} linked table{} via steel_get_column functions.", structured_deps_count, if structured_deps_count == 1 { "" } else { "s" } )); } warnings.join(" ") }