diff --git a/server/src/table_script/handlers/post_table_script.rs b/server/src/table_script/handlers/post_table_script.rs index b10f072..de1caa7 100644 --- a/server/src/table_script/handlers/post_table_script.rs +++ b/server/src/table_script/handlers/post_table_script.rs @@ -15,13 +15,13 @@ use crate::table_script::handlers::dependency_analyzer::DependencyAnalyzer; const SYSTEM_COLUMNS: &[&str] = &["id", "deleted", "created_at"]; // Define prohibited data types for Steel scripts (boolean is explicitly allowed) -const PROHIBITED_TYPES: &[&str] = &["DATE", "TIMESTAMPTZ"]; -const MATH_PROHIBITED_TYPES: &[&str] = &["BIGINT", "TEXT", "BOOLEAN"]; +const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"]; +const MATH_PROHIBITED_TYPES: &[&str] = &["BIGINT", "TEXT", "BOOLEAN", "DATE", "TIMESTAMPTZ"]; // Math operations that Steel Decimal will transform const MATH_OPERATIONS: &[&str] = &[ - "+", "-", "*", "/", "^", "**", "pow", "sqrt", - ">", "<", "=", ">=", "<=", "min", "max", "abs", + "+", "-", "*", "/", "^", "**", "pow", "sqrt", + ">", "<", "=", ">=", "<=", "min", "max", "abs", "round", "ln", "log", "log10", "exp", "sin", "cos", "tan" ]; @@ -42,20 +42,20 @@ impl Parser { let tokens = Self::tokenize(script); Self { tokens, position: 0 } } - + fn tokenize(script: &str) -> Vec { let mut tokens = Vec::new(); let mut current_token = String::new(); let mut in_string = false; let mut escape_next = false; - + for ch in script.chars() { if escape_next { current_token.push(ch); escape_next = false; continue; } - + match ch { '\\' if in_string => { escape_next = true; @@ -88,43 +88,43 @@ impl Parser { } } } - + if !current_token.is_empty() { tokens.push(current_token); } - + tokens } - + fn parse(&mut self) -> Result, String> { let mut expressions = Vec::new(); - + while self.position < self.tokens.len() { expressions.push(self.parse_expr()?); } - + Ok(expressions) } - + fn parse_expr(&mut self) -> Result { if self.position >= self.tokens.len() { return Err("Unexpected end of input".to_string()); } - + let token = &self.tokens[self.position]; - + if token == "(" { self.position += 1; // consume '(' let mut elements = Vec::new(); - + while self.position < self.tokens.len() && self.tokens[self.position] != ")" { elements.push(self.parse_expr()?); } - + if self.position >= self.tokens.len() { return Err("Missing closing parenthesis".to_string()); } - + self.position += 1; // consume ')' Ok(SExpr::List(elements)) } else { @@ -145,14 +145,14 @@ impl MathValidator { column_references: Vec::new(), } } - + fn validate_expressions(&mut self, expressions: &[SExpr]) -> Result<(), String> { for expr in expressions { self.check_expression(expr, false)?; } Ok(()) } - + fn check_expression(&mut self, expr: &SExpr, in_math_context: bool) -> Result<(), String> { match expr { SExpr::Atom(_) => Ok(()), @@ -160,14 +160,14 @@ impl MathValidator { if elements.is_empty() { return Ok(()); } - + // Check if this is a math operation let is_math = if let SExpr::Atom(op) = &elements[0] { MATH_OPERATIONS.contains(&op.as_str()) } else { false }; - + // Check if this is a column access function if let SExpr::Atom(func) = &elements[0] { if func == "steel_get_column" && in_math_context { @@ -176,17 +176,17 @@ impl MathValidator { self.extract_column_reference_from_steel_get_column_with_index(elements)?; } } - + // Recursively check all elements, marking math context appropriately for element in &elements[1..] { // Skip the operator/function name self.check_expression(element, in_math_context || is_math)?; } - + Ok(()) } } } - + fn extract_column_reference_from_steel_get_column(&mut self, elements: &[SExpr]) -> Result<(), String> { // (steel_get_column "table" "column") if elements.len() >= 3 { @@ -198,7 +198,7 @@ impl MathValidator { } Ok(()) } - + fn extract_column_reference_from_steel_get_column_with_index(&mut self, elements: &[SExpr]) -> Result<(), String> { // (steel_get_column_with_index "table" index "column") if elements.len() >= 4 { @@ -210,7 +210,7 @@ impl MathValidator { } Ok(()) } - + fn unquote_string(&self, s: &str) -> Result { if s.starts_with('"') && s.ends_with('"') && s.len() >= 2 { Ok(s[1..s.len()-1].to_string()) @@ -223,12 +223,12 @@ impl MathValidator { /// Valide script is not empty fn validate_script_basic_syntax(script: &str) -> Result<(), Status> { let trimmed = script.trim(); - + // Check for empty script if trimmed.is_empty() { return Err(Status::invalid_argument("Script cannot be empty")); } - + // Basic parentheses balance check let mut paren_count = 0; for ch in trimmed.chars() { @@ -243,16 +243,16 @@ fn validate_script_basic_syntax(script: &str) -> Result<(), Status> { _ => {} } } - + if paren_count != 0 { return Err(Status::invalid_argument("Unbalanced parentheses: missing closing parentheses")); } - + // Check for basic S-expression structure if !trimmed.starts_with('(') { return Err(Status::invalid_argument("Script must start with an opening parenthesis '('")); } - + Ok(()) } @@ -261,11 +261,11 @@ fn extract_math_column_references(script: &str) -> Result, let mut parser = Parser::new(script); let expressions = parser.parse() .map_err(|e| format!("Parse error: {}", e))?; - + let mut validator = MathValidator::new(); validator.validate_expressions(&expressions) .map_err(|e| format!("Validation error: {}", e))?; - + Ok(validator.column_references) } @@ -278,19 +278,19 @@ async fn validate_math_operations_column_types( // Extract column references from mathematical contexts using proper S-expression parsing let column_refs = extract_math_column_references(script) .map_err(|e| Status::invalid_argument(format!("Script parsing failed: {}", e)))?; - + if column_refs.is_empty() { return Ok(()); // No column references in math operations } - + // Get all unique table names referenced in math operations let table_names: HashSet = column_refs.iter() .map(|(table, _)| table.clone()) .collect(); - + // Fetch table definitions for all referenced tables let table_definitions = sqlx::query!( - r#"SELECT table_name, columns FROM table_definitions + r#"SELECT table_name, columns FROM table_definitions WHERE schema_id = $1 AND table_name = ANY($2)"#, schema_id, &table_names.into_iter().collect::>() @@ -298,14 +298,14 @@ async fn validate_math_operations_column_types( .fetch_all(db_pool) .await .map_err(|e| Status::internal(format!("Failed to fetch table definitions: {}", e)))?; - + // Build a map of table_name -> column_name -> column_type let mut table_column_types: HashMap> = HashMap::new(); - + for table_def in table_definitions { let columns: Vec = serde_json::from_value(table_def.columns) .map_err(|e| Status::internal(format!("Invalid column data for table '{}': {}", table_def.table_name, e)))?; - + let mut column_types = HashMap::new(); for column_def in columns { let mut parts = column_def.split_whitespace(); @@ -316,13 +316,13 @@ async fn validate_math_operations_column_types( } table_column_types.insert(table_def.table_name, column_types); } - + // Check each column reference in mathematical operations for (table_name, column_name) in column_refs { if let Some(table_columns) = table_column_types.get(&table_name) { if let Some(column_type) = table_columns.get(&column_name) { let normalized_type = normalize_data_type(column_type); - + // Check if this type is prohibited in math operations if MATH_PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited)) { return Err(Status::invalid_argument(format!( @@ -347,7 +347,7 @@ async fn validate_math_operations_column_types( ))); } } - + Ok(()) } @@ -616,12 +616,12 @@ pub async fn post_table_script( ) .map_err(|e| Status::invalid_argument(e))?; - // Validate that script doesn't reference prohibited column types by checking actual DB schema - validate_script_column_references(db_pool, table_def.schema_id, &request.script).await?; - - // NEW: Validate that mathematical operations don't use TEXT or BOOLEAN columns + // REORDER: Math validation FIRST so we get specific error messages for math operations validate_math_operations_column_types(db_pool, table_def.schema_id, &request.script).await?; + // THEN general column validation (catches non-math prohibited access) + validate_script_column_references(db_pool, table_def.schema_id, &request.script).await?; + // Validate SQL queries in script for prohibited type operations validate_sql_queries_in_script(&request.script) .map_err(|e| Status::invalid_argument(e))?; diff --git a/server/tests/table_script/comprehensive_error_scenarios_tests.rs b/server/tests/table_script/comprehensive_error_scenarios_tests.rs index cda3298..5a64a69 100644 --- a/server/tests/table_script/comprehensive_error_scenarios_tests.rs +++ b/server/tests/table_script/comprehensive_error_scenarios_tests.rs @@ -218,59 +218,76 @@ async fn test_advanced_validation_scenarios( // Don't assert failure here - these are edge cases that might be handled differently } +async fn create_table_link(pool: &PgPool, source_table_id: i64, linked_table_id: i64) { + sqlx::query!( + "INSERT INTO table_definition_links (source_table_id, linked_table_id, is_required) + VALUES ($1, $2, false)", + source_table_id, + linked_table_id + ) + .execute(pool) + .await + .expect("Failed to create table link"); +} + #[tokio::test] async fn test_dependency_cycle_detection() { let pool = setup_isolated_db().await; let schema_id = get_default_schema_id(&pool).await; - // Create two tables for dependency testing - let table_a_columns = vec![ - ("value_a", "NUMERIC(10, 2)"), - ("result_a", "NUMERIC(10, 2)"), - ]; - let table_a_id = create_test_table(&pool, schema_id, "table_a", table_a_columns).await; - + // Create table_b first let table_b_columns = vec![ ("value_b", "NUMERIC(10, 2)"), ("result_b", "NUMERIC(10, 2)"), ]; let table_b_id = create_test_table(&pool, schema_id, "table_b", table_b_columns).await; - // Create first dependency: table_a.result_a depends on table_b.value_b + // Create table_a + let table_a_columns = vec![ + ("value_a", "NUMERIC(10, 2)"), + ("result_a", "NUMERIC(10, 2)"), + ]; + let table_a_id = create_test_table(&pool, schema_id, "table_a", table_a_columns).await; + + // CREATE BOTH LINKS for circular dependency testing + create_table_link(&pool, table_a_id, table_b_id).await; // table_a -> table_b + create_table_link(&pool, table_b_id, table_a_id).await; // table_b -> table_a + + // First dependency should work let script_a = r#"(+ (steel_get_column "table_b" "value_b") "10")"#; let request_a = PostTableScriptRequest { table_definition_id: table_a_id, target_column: "result_a".to_string(), script: script_a.to_string(), - description: "First dependency".to_string(), // Fixed: removed Some() + description: "First dependency".to_string(), }; let result_a = post_table_script(&pool, request_a).await; assert!(result_a.is_ok(), "First dependency should succeed"); - // Try to create circular dependency: table_b.result_b depends on table_a.result_a + // Try circular dependency - should now work since links exist both ways let script_b = r#"(* (steel_get_column "table_a" "result_a") "2")"#; let request_b = PostTableScriptRequest { table_definition_id: table_b_id, target_column: "result_b".to_string(), script: script_b.to_string(), - description: "Circular dependency attempt".to_string(), // Fixed: removed Some() + description: "Circular dependency attempt".to_string(), }; let result_b = post_table_script(&pool, request_b).await; - // Depending on implementation, this should either succeed or detect the cycle + // This should either succeed or detect the cycle match result_b { Ok(_) => { // Implementation allows this pattern } Err(error) => { // Implementation detects circular dependencies - let error_msg = error.to_string(); + let error_msg = error.to_string().to_lowercase(); assert!( error_msg.contains("cycle") || error_msg.contains("circular"), "Circular dependency should be detected properly: {}", - error_msg + error.to_string() ); } }