we got a full passer now

This commit is contained in:
filipriec
2025-07-20 14:58:39 +02:00
parent 84871faad4
commit 7e54b2fe43
2 changed files with 79 additions and 62 deletions

View File

@@ -15,13 +15,13 @@ use crate::table_script::handlers::dependency_analyzer::DependencyAnalyzer;
const SYSTEM_COLUMNS: &[&str] = &["id", "deleted", "created_at"]; const SYSTEM_COLUMNS: &[&str] = &["id", "deleted", "created_at"];
// Define prohibited data types for Steel scripts (boolean is explicitly allowed) // Define prohibited data types for Steel scripts (boolean is explicitly allowed)
const PROHIBITED_TYPES: &[&str] = &["DATE", "TIMESTAMPTZ"]; const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"];
const MATH_PROHIBITED_TYPES: &[&str] = &["BIGINT", "TEXT", "BOOLEAN"]; const MATH_PROHIBITED_TYPES: &[&str] = &["BIGINT", "TEXT", "BOOLEAN", "DATE", "TIMESTAMPTZ"];
// Math operations that Steel Decimal will transform // Math operations that Steel Decimal will transform
const MATH_OPERATIONS: &[&str] = &[ const MATH_OPERATIONS: &[&str] = &[
"+", "-", "*", "/", "^", "**", "pow", "sqrt", "+", "-", "*", "/", "^", "**", "pow", "sqrt",
">", "<", "=", ">=", "<=", "min", "max", "abs", ">", "<", "=", ">=", "<=", "min", "max", "abs",
"round", "ln", "log", "log10", "exp", "sin", "cos", "tan" "round", "ln", "log", "log10", "exp", "sin", "cos", "tan"
]; ];
@@ -42,20 +42,20 @@ impl Parser {
let tokens = Self::tokenize(script); let tokens = Self::tokenize(script);
Self { tokens, position: 0 } Self { tokens, position: 0 }
} }
fn tokenize(script: &str) -> Vec<String> { fn tokenize(script: &str) -> Vec<String> {
let mut tokens = Vec::new(); let mut tokens = Vec::new();
let mut current_token = String::new(); let mut current_token = String::new();
let mut in_string = false; let mut in_string = false;
let mut escape_next = false; let mut escape_next = false;
for ch in script.chars() { for ch in script.chars() {
if escape_next { if escape_next {
current_token.push(ch); current_token.push(ch);
escape_next = false; escape_next = false;
continue; continue;
} }
match ch { match ch {
'\\' if in_string => { '\\' if in_string => {
escape_next = true; escape_next = true;
@@ -88,43 +88,43 @@ impl Parser {
} }
} }
} }
if !current_token.is_empty() { if !current_token.is_empty() {
tokens.push(current_token); tokens.push(current_token);
} }
tokens tokens
} }
fn parse(&mut self) -> Result<Vec<SExpr>, String> { fn parse(&mut self) -> Result<Vec<SExpr>, String> {
let mut expressions = Vec::new(); let mut expressions = Vec::new();
while self.position < self.tokens.len() { while self.position < self.tokens.len() {
expressions.push(self.parse_expr()?); expressions.push(self.parse_expr()?);
} }
Ok(expressions) Ok(expressions)
} }
fn parse_expr(&mut self) -> Result<SExpr, String> { fn parse_expr(&mut self) -> Result<SExpr, String> {
if self.position >= self.tokens.len() { if self.position >= self.tokens.len() {
return Err("Unexpected end of input".to_string()); return Err("Unexpected end of input".to_string());
} }
let token = &self.tokens[self.position]; let token = &self.tokens[self.position];
if token == "(" { if token == "(" {
self.position += 1; // consume '(' self.position += 1; // consume '('
let mut elements = Vec::new(); let mut elements = Vec::new();
while self.position < self.tokens.len() && self.tokens[self.position] != ")" { while self.position < self.tokens.len() && self.tokens[self.position] != ")" {
elements.push(self.parse_expr()?); elements.push(self.parse_expr()?);
} }
if self.position >= self.tokens.len() { if self.position >= self.tokens.len() {
return Err("Missing closing parenthesis".to_string()); return Err("Missing closing parenthesis".to_string());
} }
self.position += 1; // consume ')' self.position += 1; // consume ')'
Ok(SExpr::List(elements)) Ok(SExpr::List(elements))
} else { } else {
@@ -145,14 +145,14 @@ impl MathValidator {
column_references: Vec::new(), column_references: Vec::new(),
} }
} }
fn validate_expressions(&mut self, expressions: &[SExpr]) -> Result<(), String> { fn validate_expressions(&mut self, expressions: &[SExpr]) -> Result<(), String> {
for expr in expressions { for expr in expressions {
self.check_expression(expr, false)?; self.check_expression(expr, false)?;
} }
Ok(()) Ok(())
} }
fn check_expression(&mut self, expr: &SExpr, in_math_context: bool) -> Result<(), String> { fn check_expression(&mut self, expr: &SExpr, in_math_context: bool) -> Result<(), String> {
match expr { match expr {
SExpr::Atom(_) => Ok(()), SExpr::Atom(_) => Ok(()),
@@ -160,14 +160,14 @@ impl MathValidator {
if elements.is_empty() { if elements.is_empty() {
return Ok(()); return Ok(());
} }
// Check if this is a math operation // Check if this is a math operation
let is_math = if let SExpr::Atom(op) = &elements[0] { let is_math = if let SExpr::Atom(op) = &elements[0] {
MATH_OPERATIONS.contains(&op.as_str()) MATH_OPERATIONS.contains(&op.as_str())
} else { } else {
false false
}; };
// Check if this is a column access function // Check if this is a column access function
if let SExpr::Atom(func) = &elements[0] { if let SExpr::Atom(func) = &elements[0] {
if func == "steel_get_column" && in_math_context { if func == "steel_get_column" && in_math_context {
@@ -176,17 +176,17 @@ impl MathValidator {
self.extract_column_reference_from_steel_get_column_with_index(elements)?; self.extract_column_reference_from_steel_get_column_with_index(elements)?;
} }
} }
// Recursively check all elements, marking math context appropriately // Recursively check all elements, marking math context appropriately
for element in &elements[1..] { // Skip the operator/function name for element in &elements[1..] { // Skip the operator/function name
self.check_expression(element, in_math_context || is_math)?; self.check_expression(element, in_math_context || is_math)?;
} }
Ok(()) Ok(())
} }
} }
} }
fn extract_column_reference_from_steel_get_column(&mut self, elements: &[SExpr]) -> Result<(), String> { fn extract_column_reference_from_steel_get_column(&mut self, elements: &[SExpr]) -> Result<(), String> {
// (steel_get_column "table" "column") // (steel_get_column "table" "column")
if elements.len() >= 3 { if elements.len() >= 3 {
@@ -198,7 +198,7 @@ impl MathValidator {
} }
Ok(()) Ok(())
} }
fn extract_column_reference_from_steel_get_column_with_index(&mut self, elements: &[SExpr]) -> Result<(), String> { fn extract_column_reference_from_steel_get_column_with_index(&mut self, elements: &[SExpr]) -> Result<(), String> {
// (steel_get_column_with_index "table" index "column") // (steel_get_column_with_index "table" index "column")
if elements.len() >= 4 { if elements.len() >= 4 {
@@ -210,7 +210,7 @@ impl MathValidator {
} }
Ok(()) Ok(())
} }
fn unquote_string(&self, s: &str) -> Result<String, String> { fn unquote_string(&self, s: &str) -> Result<String, String> {
if s.starts_with('"') && s.ends_with('"') && s.len() >= 2 { if s.starts_with('"') && s.ends_with('"') && s.len() >= 2 {
Ok(s[1..s.len()-1].to_string()) Ok(s[1..s.len()-1].to_string())
@@ -223,12 +223,12 @@ impl MathValidator {
/// Valide script is not empty /// Valide script is not empty
fn validate_script_basic_syntax(script: &str) -> Result<(), Status> { fn validate_script_basic_syntax(script: &str) -> Result<(), Status> {
let trimmed = script.trim(); let trimmed = script.trim();
// Check for empty script // Check for empty script
if trimmed.is_empty() { if trimmed.is_empty() {
return Err(Status::invalid_argument("Script cannot be empty")); return Err(Status::invalid_argument("Script cannot be empty"));
} }
// Basic parentheses balance check // Basic parentheses balance check
let mut paren_count = 0; let mut paren_count = 0;
for ch in trimmed.chars() { for ch in trimmed.chars() {
@@ -243,16 +243,16 @@ fn validate_script_basic_syntax(script: &str) -> Result<(), Status> {
_ => {} _ => {}
} }
} }
if paren_count != 0 { if paren_count != 0 {
return Err(Status::invalid_argument("Unbalanced parentheses: missing closing parentheses")); return Err(Status::invalid_argument("Unbalanced parentheses: missing closing parentheses"));
} }
// Check for basic S-expression structure // Check for basic S-expression structure
if !trimmed.starts_with('(') { if !trimmed.starts_with('(') {
return Err(Status::invalid_argument("Script must start with an opening parenthesis '('")); return Err(Status::invalid_argument("Script must start with an opening parenthesis '('"));
} }
Ok(()) Ok(())
} }
@@ -261,11 +261,11 @@ fn extract_math_column_references(script: &str) -> Result<Vec<(String, String)>,
let mut parser = Parser::new(script); let mut parser = Parser::new(script);
let expressions = parser.parse() let expressions = parser.parse()
.map_err(|e| format!("Parse error: {}", e))?; .map_err(|e| format!("Parse error: {}", e))?;
let mut validator = MathValidator::new(); let mut validator = MathValidator::new();
validator.validate_expressions(&expressions) validator.validate_expressions(&expressions)
.map_err(|e| format!("Validation error: {}", e))?; .map_err(|e| format!("Validation error: {}", e))?;
Ok(validator.column_references) Ok(validator.column_references)
} }
@@ -278,19 +278,19 @@ async fn validate_math_operations_column_types(
// Extract column references from mathematical contexts using proper S-expression parsing // Extract column references from mathematical contexts using proper S-expression parsing
let column_refs = extract_math_column_references(script) let column_refs = extract_math_column_references(script)
.map_err(|e| Status::invalid_argument(format!("Script parsing failed: {}", e)))?; .map_err(|e| Status::invalid_argument(format!("Script parsing failed: {}", e)))?;
if column_refs.is_empty() { if column_refs.is_empty() {
return Ok(()); // No column references in math operations return Ok(()); // No column references in math operations
} }
// Get all unique table names referenced in math operations // Get all unique table names referenced in math operations
let table_names: HashSet<String> = column_refs.iter() let table_names: HashSet<String> = column_refs.iter()
.map(|(table, _)| table.clone()) .map(|(table, _)| table.clone())
.collect(); .collect();
// Fetch table definitions for all referenced tables // Fetch table definitions for all referenced tables
let table_definitions = sqlx::query!( let table_definitions = sqlx::query!(
r#"SELECT table_name, columns FROM table_definitions r#"SELECT table_name, columns FROM table_definitions
WHERE schema_id = $1 AND table_name = ANY($2)"#, WHERE schema_id = $1 AND table_name = ANY($2)"#,
schema_id, schema_id,
&table_names.into_iter().collect::<Vec<_>>() &table_names.into_iter().collect::<Vec<_>>()
@@ -298,14 +298,14 @@ async fn validate_math_operations_column_types(
.fetch_all(db_pool) .fetch_all(db_pool)
.await .await
.map_err(|e| Status::internal(format!("Failed to fetch table definitions: {}", e)))?; .map_err(|e| Status::internal(format!("Failed to fetch table definitions: {}", e)))?;
// Build a map of table_name -> column_name -> column_type // Build a map of table_name -> column_name -> column_type
let mut table_column_types: HashMap<String, HashMap<String, String>> = HashMap::new(); let mut table_column_types: HashMap<String, HashMap<String, String>> = HashMap::new();
for table_def in table_definitions { for table_def in table_definitions {
let columns: Vec<String> = serde_json::from_value(table_def.columns) let columns: Vec<String> = serde_json::from_value(table_def.columns)
.map_err(|e| Status::internal(format!("Invalid column data for table '{}': {}", table_def.table_name, e)))?; .map_err(|e| Status::internal(format!("Invalid column data for table '{}': {}", table_def.table_name, e)))?;
let mut column_types = HashMap::new(); let mut column_types = HashMap::new();
for column_def in columns { for column_def in columns {
let mut parts = column_def.split_whitespace(); let mut parts = column_def.split_whitespace();
@@ -316,13 +316,13 @@ async fn validate_math_operations_column_types(
} }
table_column_types.insert(table_def.table_name, column_types); table_column_types.insert(table_def.table_name, column_types);
} }
// Check each column reference in mathematical operations // Check each column reference in mathematical operations
for (table_name, column_name) in column_refs { for (table_name, column_name) in column_refs {
if let Some(table_columns) = table_column_types.get(&table_name) { if let Some(table_columns) = table_column_types.get(&table_name) {
if let Some(column_type) = table_columns.get(&column_name) { if let Some(column_type) = table_columns.get(&column_name) {
let normalized_type = normalize_data_type(column_type); let normalized_type = normalize_data_type(column_type);
// Check if this type is prohibited in math operations // Check if this type is prohibited in math operations
if MATH_PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited)) { if MATH_PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited)) {
return Err(Status::invalid_argument(format!( return Err(Status::invalid_argument(format!(
@@ -347,7 +347,7 @@ async fn validate_math_operations_column_types(
))); )));
} }
} }
Ok(()) Ok(())
} }
@@ -616,12 +616,12 @@ pub async fn post_table_script(
) )
.map_err(|e| Status::invalid_argument(e))?; .map_err(|e| Status::invalid_argument(e))?;
// Validate that script doesn't reference prohibited column types by checking actual DB schema // REORDER: Math validation FIRST so we get specific error messages for math operations
validate_script_column_references(db_pool, table_def.schema_id, &request.script).await?;
// NEW: Validate that mathematical operations don't use TEXT or BOOLEAN columns
validate_math_operations_column_types(db_pool, table_def.schema_id, &request.script).await?; validate_math_operations_column_types(db_pool, table_def.schema_id, &request.script).await?;
// THEN general column validation (catches non-math prohibited access)
validate_script_column_references(db_pool, table_def.schema_id, &request.script).await?;
// Validate SQL queries in script for prohibited type operations // Validate SQL queries in script for prohibited type operations
validate_sql_queries_in_script(&request.script) validate_sql_queries_in_script(&request.script)
.map_err(|e| Status::invalid_argument(e))?; .map_err(|e| Status::invalid_argument(e))?;

View File

@@ -218,59 +218,76 @@ async fn test_advanced_validation_scenarios(
// Don't assert failure here - these are edge cases that might be handled differently // Don't assert failure here - these are edge cases that might be handled differently
} }
async fn create_table_link(pool: &PgPool, source_table_id: i64, linked_table_id: i64) {
sqlx::query!(
"INSERT INTO table_definition_links (source_table_id, linked_table_id, is_required)
VALUES ($1, $2, false)",
source_table_id,
linked_table_id
)
.execute(pool)
.await
.expect("Failed to create table link");
}
#[tokio::test] #[tokio::test]
async fn test_dependency_cycle_detection() { async fn test_dependency_cycle_detection() {
let pool = setup_isolated_db().await; let pool = setup_isolated_db().await;
let schema_id = get_default_schema_id(&pool).await; let schema_id = get_default_schema_id(&pool).await;
// Create two tables for dependency testing // Create table_b first
let table_a_columns = vec![
("value_a", "NUMERIC(10, 2)"),
("result_a", "NUMERIC(10, 2)"),
];
let table_a_id = create_test_table(&pool, schema_id, "table_a", table_a_columns).await;
let table_b_columns = vec![ let table_b_columns = vec![
("value_b", "NUMERIC(10, 2)"), ("value_b", "NUMERIC(10, 2)"),
("result_b", "NUMERIC(10, 2)"), ("result_b", "NUMERIC(10, 2)"),
]; ];
let table_b_id = create_test_table(&pool, schema_id, "table_b", table_b_columns).await; let table_b_id = create_test_table(&pool, schema_id, "table_b", table_b_columns).await;
// Create first dependency: table_a.result_a depends on table_b.value_b // Create table_a
let table_a_columns = vec![
("value_a", "NUMERIC(10, 2)"),
("result_a", "NUMERIC(10, 2)"),
];
let table_a_id = create_test_table(&pool, schema_id, "table_a", table_a_columns).await;
// CREATE BOTH LINKS for circular dependency testing
create_table_link(&pool, table_a_id, table_b_id).await; // table_a -> table_b
create_table_link(&pool, table_b_id, table_a_id).await; // table_b -> table_a
// First dependency should work
let script_a = r#"(+ (steel_get_column "table_b" "value_b") "10")"#; let script_a = r#"(+ (steel_get_column "table_b" "value_b") "10")"#;
let request_a = PostTableScriptRequest { let request_a = PostTableScriptRequest {
table_definition_id: table_a_id, table_definition_id: table_a_id,
target_column: "result_a".to_string(), target_column: "result_a".to_string(),
script: script_a.to_string(), script: script_a.to_string(),
description: "First dependency".to_string(), // Fixed: removed Some() description: "First dependency".to_string(),
}; };
let result_a = post_table_script(&pool, request_a).await; let result_a = post_table_script(&pool, request_a).await;
assert!(result_a.is_ok(), "First dependency should succeed"); assert!(result_a.is_ok(), "First dependency should succeed");
// Try to create circular dependency: table_b.result_b depends on table_a.result_a // Try circular dependency - should now work since links exist both ways
let script_b = r#"(* (steel_get_column "table_a" "result_a") "2")"#; let script_b = r#"(* (steel_get_column "table_a" "result_a") "2")"#;
let request_b = PostTableScriptRequest { let request_b = PostTableScriptRequest {
table_definition_id: table_b_id, table_definition_id: table_b_id,
target_column: "result_b".to_string(), target_column: "result_b".to_string(),
script: script_b.to_string(), script: script_b.to_string(),
description: "Circular dependency attempt".to_string(), // Fixed: removed Some() description: "Circular dependency attempt".to_string(),
}; };
let result_b = post_table_script(&pool, request_b).await; let result_b = post_table_script(&pool, request_b).await;
// Depending on implementation, this should either succeed or detect the cycle // This should either succeed or detect the cycle
match result_b { match result_b {
Ok(_) => { Ok(_) => {
// Implementation allows this pattern // Implementation allows this pattern
} }
Err(error) => { Err(error) => {
// Implementation detects circular dependencies // Implementation detects circular dependencies
let error_msg = error.to_string(); let error_msg = error.to_string().to_lowercase();
assert!( assert!(
error_msg.contains("cycle") || error_msg.contains("circular"), error_msg.contains("cycle") || error_msg.contains("circular"),
"Circular dependency should be detected properly: {}", "Circular dependency should be detected properly: {}",
error_msg error.to_string()
); );
} }
} }