From 17495c49ac579f2acd655625437c762ae77d6135 Mon Sep 17 00:00:00 2001 From: filipriec Date: Sat, 12 Jul 2025 23:06:21 +0200 Subject: [PATCH] steel scripts now have far better logic than before --- server/src/steel/server/execution.rs | 71 +++++- server/src/steel/server/functions.rs | 230 +++++++++++++++++- .../handlers/dependency_analyzer.rs | 2 +- .../handlers/post_table_script.rs | 220 ++++++++++++++++- .../table_script/prohibited_types_test.rs | 181 ++++++++++++++ 5 files changed, 691 insertions(+), 13 deletions(-) create mode 100644 server/tests/table_script/prohibited_types_test.rs diff --git a/server/src/steel/server/execution.rs b/server/src/steel/server/execution.rs index f03f027..f0a398b 100644 --- a/server/src/steel/server/execution.rs +++ b/server/src/steel/server/execution.rs @@ -2,10 +2,11 @@ use steel::steel_vm::engine::Engine; use steel::steel_vm::register_fn::RegisterFn; use steel::rvals::SteelVal; -use super::functions::SteelContext; +use super::functions::{SteelContext, convert_row_data_for_steel}; use steel_decimal::registry::FunctionRegistry; use sqlx::PgPool; use std::sync::Arc; +use std::collections::HashMap; use thiserror::Error; #[derive(Debug)] @@ -25,6 +26,74 @@ pub enum ExecutionError { UnsupportedType(String), } +/// Create a SteelContext with boolean conversion applied to row data +pub async fn create_steel_context_with_boolean_conversion( + current_table: String, + schema_id: i64, + schema_name: String, + mut row_data: HashMap, + db_pool: Arc, +) -> Result { + // Convert boolean values in row_data to Steel format + convert_row_data_for_steel(&db_pool, schema_id, ¤t_table, &mut row_data) + .await + .map_err(|e| ExecutionError::RuntimeError(format!("Failed to convert row data: {}", e)))?; + + Ok(SteelContext { + current_table, + schema_id, + schema_name, + row_data, + db_pool, + }) +} + +/// Execute script with proper boolean handling +pub async fn execute_script_with_boolean_support( + script: String, + target_type: &str, + db_pool: Arc, + schema_id: i64, + schema_name: String, + current_table: String, + row_data: HashMap, +) -> Result { + let mut vm = Engine::new(); + + // Create context with boolean conversion + let context = create_steel_context_with_boolean_conversion( + current_table, + schema_id, + schema_name, + row_data, + db_pool.clone(), + ).await?; + + let context = Arc::new(context); + + // Register existing Steel functions + register_steel_functions(&mut vm, context.clone()); + + // Register all decimal math functions using the steel_decimal crate + register_decimal_math_functions(&mut vm); + + // Register variables from the context with the Steel VM + // The row_data now contains Steel-formatted boolean values + FunctionRegistry::register_variables(&mut vm, context.row_data.clone()); + + // Execute script and process results + let results = vm.compile_and_run_raw_program(script) + .map_err(|e| ExecutionError::RuntimeError(e.to_string()))?; + + // Convert results to target type + match target_type { + "STRINGS" => process_string_results(results), + _ => Err(ExecutionError::UnsupportedType(target_type.into())) + } +} + +/// Original execute_script function (kept for backward compatibility) +/// Note: This doesn't include boolean conversion - use execute_script_with_boolean_support for new code pub fn execute_script( script: String, target_type: &str, diff --git a/server/src/steel/server/functions.rs b/server/src/steel/server/functions.rs index 7c4ff91..46fb807 100644 --- a/server/src/steel/server/functions.rs +++ b/server/src/steel/server/functions.rs @@ -16,8 +16,13 @@ pub enum FunctionError { TableNotFound(String), #[error("Database error: {0}")] DatabaseError(String), + #[error("Prohibited data type access: {0}")] + ProhibitedTypeAccess(String), } +// Define prohibited data types (boolean is explicitly allowed) +const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"]; + #[derive(Clone)] pub struct SteelContext { pub current_table: String, @@ -43,10 +48,102 @@ impl SteelContext { Ok(table_def.table_name) } + /// Get column type for a given table and column + async fn get_column_type(&self, table_name: &str, column_name: &str) -> Result { + let table_def = sqlx::query!( + r#"SELECT columns FROM table_definitions + WHERE schema_id = $1 AND table_name = $2"#, + self.schema_id, + table_name + ) + .fetch_optional(&*self.db_pool) + .await + .map_err(|e| FunctionError::DatabaseError(e.to_string()))? + .ok_or_else(|| FunctionError::TableNotFound(table_name.to_string()))?; + + // Parse columns JSON to find the column type + let columns: Vec = serde_json::from_value(table_def.columns) + .map_err(|e| FunctionError::DatabaseError(format!("Invalid column data: {}", e)))?; + + // Find the column and its type + for column_def in columns { + let mut parts = column_def.split_whitespace(); + if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) { + let column_name_clean = name.trim_matches('"'); + if column_name_clean == column_name { + return Ok(data_type.to_string()); + } + } + } + + Err(FunctionError::ColumnNotFound(format!( + "Column '{}' not found in table '{}'", + column_name, + table_name + ))) + } + + /// Convert database value to Steel format based on column type + fn convert_value_to_steel_format(&self, value: &str, column_type: &str) -> String { + let normalized_type = normalize_data_type(column_type); + + match normalized_type.as_str() { + "BOOLEAN" | "BOOL" => { + // Convert database boolean to Steel boolean syntax + match value.to_lowercase().as_str() { + "true" | "t" | "1" | "yes" | "on" => "#true".to_string(), + "false" | "f" | "0" | "no" | "off" => "#false".to_string(), + _ => value.to_string(), // Return as-is if not a recognized boolean + } + } + _ => value.to_string(), // Return as-is for non-boolean types + } + } + + /// Validate column type and return the column type if valid + async fn validate_column_type_and_get_type(&self, table_name: &str, column_name: &str) -> Result { + let column_type = self.get_column_type(table_name, column_name).await?; + + // Check if this type is prohibited + if is_prohibited_type(&column_type) { + return Err(FunctionError::ProhibitedTypeAccess(format!( + "Cannot access column '{}' in table '{}' because it has prohibited type '{}'. Steel scripts cannot access columns of type: {}", + column_name, + table_name, + column_type, + PROHIBITED_TYPES.join(", ") + ))); + } + + Ok(column_type) + } + + /// Validate column type before access (legacy method for backward compatibility) + async fn validate_column_type(&self, table_name: &str, column_name: &str) -> Result<(), FunctionError> { + self.validate_column_type_and_get_type(table_name, column_name).await?; + Ok(()) + } + pub fn steel_get_column(&self, table: &str, column: &str) -> Result { if table == self.current_table { + // Validate column type for current table access and get the type + let column_type = tokio::task::block_in_place(|| { + let handle = tokio::runtime::Handle::current(); + handle.block_on(async { + self.validate_column_type_and_get_type(table, column).await + }) + }); + + let column_type = match column_type { + Ok(ct) => ct, + Err(e) => return Err(SteelVal::StringV(e.to_string().into())), + }; + return self.row_data.get(column) - .map(|v| SteelVal::StringV(v.clone().into())) + .map(|v| { + let converted_value = self.convert_value_to_steel_format(v, &column_type); + SteelVal::StringV(converted_value.into()) + }) .ok_or_else(|| SteelVal::StringV(format!("Column {} not found", column).into())); } @@ -65,15 +162,23 @@ impl SteelContext { let actual_table = self.get_related_table_name(base_name).await .map_err(|e| SteelVal::StringV(e.to_string().into()))?; - // Add quotes around the table name - sqlx::query_scalar::<_, String>( + // Get column type for validation and conversion + let column_type = self.validate_column_type_and_get_type(&actual_table, column).await + .map_err(|e| SteelVal::StringV(e.to_string().into()))?; + + // Fetch the raw value from database + let raw_value = sqlx::query_scalar::<_, String>( &format!("SELECT {} FROM \"{}\".\"{}\" WHERE id = $1", column, self.schema_name, actual_table) ) .bind(fk_value.parse::().map_err(|_| SteelVal::StringV("Invalid foreign key format".into()))?) .fetch_one(&*self.db_pool) .await - .map_err(|e| SteelVal::StringV(e.to_string().into())) + .map_err(|e| SteelVal::StringV(e.to_string().into()))?; + + // Convert to Steel format + let converted_value = self.convert_value_to_steel_format(&raw_value, &column_type); + Ok(converted_value) }) }); @@ -86,12 +191,37 @@ impl SteelContext { index: i64, column: &str ) -> Result { + // Get the full value first (this already handles type conversion) let value = self.steel_get_column(table, column)?; + if let SteelVal::StringV(s) = value { let parts: Vec<_> = s.split(',').collect(); - parts.get(index as usize) - .map(|v| SteelVal::StringV(v.trim().into())) - .ok_or_else(|| SteelVal::StringV("Index out of bounds".into())) + + if let Some(part) = parts.get(index as usize) { + let trimmed_part = part.trim(); + + // If the original column was boolean type, each part should also be treated as boolean + // We need to get the column type to determine if conversion is needed + let column_type = tokio::task::block_in_place(|| { + let handle = tokio::runtime::Handle::current(); + handle.block_on(async { + self.get_column_type(table, column).await + }) + }); + + match column_type { + Ok(ct) => { + let converted_part = self.convert_value_to_steel_format(trimmed_part, &ct); + Ok(SteelVal::StringV(converted_part.into())) + } + Err(_) => { + // If we can't get the type, return as-is + Ok(SteelVal::StringV(trimmed_part.into())) + } + } + } else { + Err(SteelVal::StringV("Index out of bounds".into())) + } } else { Err(SteelVal::StringV("Expected comma-separated string".into())) } @@ -105,6 +235,14 @@ impl SteelContext { )); } + // Check if query might access prohibited columns + if contains_prohibited_column_access(query) { + return Err(SteelVal::StringV(format!( + "SQL query may access prohibited column types. Steel scripts cannot access columns of type: {}", + PROHIBITED_TYPES.join(", ") + ).into())); + } + let pool = self.db_pool.clone(); // Use `tokio::task::block_in_place` to safely block the thread @@ -132,9 +270,87 @@ impl SteelContext { } } +/// Check if a data type is prohibited for Steel scripts +fn is_prohibited_type(data_type: &str) -> bool { + let normalized_type = normalize_data_type(data_type); + PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited)) +} + +/// Normalize data type for comparison (handle NUMERIC variations, etc.) +fn normalize_data_type(data_type: &str) -> String { + data_type.to_uppercase() + .split('(') // Remove precision/scale from NUMERIC(x,y) + .next() + .unwrap_or(data_type) + .trim() + .to_string() +} + +/// Basic check for prohibited column access in SQL queries +/// This is a simple heuristic - more sophisticated parsing could be added +fn contains_prohibited_column_access(query: &str) -> bool { + let query_upper = query.to_uppercase(); + + // Look for common patterns that might indicate prohibited type access + // This is a basic implementation - you might want to enhance this + let patterns = [ + "EXTRACT(", // Common with DATE/TIMESTAMPTZ + "DATE_PART(", // Common with DATE/TIMESTAMPTZ + "::DATE", + "::TIMESTAMPTZ", + "::BIGINT", + ]; + + patterns.iter().any(|pattern| query_upper.contains(pattern)) +} + fn is_read_only_query(query: &str) -> bool { let query = query.trim_start().to_uppercase(); query.starts_with("SELECT") || query.starts_with("SHOW") || query.starts_with("EXPLAIN") } + +/// Helper function to convert initial row data for boolean columns +pub async fn convert_row_data_for_steel( + db_pool: &PgPool, + schema_id: i64, + table_name: &str, + row_data: &mut HashMap, +) -> Result<(), sqlx::Error> { + // Get table definition to check column types + let table_def = sqlx::query!( + r#"SELECT columns FROM table_definitions + WHERE schema_id = $1 AND table_name = $2"#, + schema_id, + table_name + ) + .fetch_optional(db_pool) + .await? + .ok_or_else(|| sqlx::Error::RowNotFound)?; + + // Parse columns to find boolean types + if let Ok(columns) = serde_json::from_value::>(table_def.columns) { + for column_def in columns { + let mut parts = column_def.split_whitespace(); + if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) { + let column_name = name.trim_matches('"'); + let normalized_type = normalize_data_type(data_type); + + // Fixed: Use traditional if let instead of let chains + if normalized_type == "BOOLEAN" || normalized_type == "BOOL" { + if let Some(value) = row_data.get_mut(column_name) { + // Convert boolean value to Steel format + *value = match value.to_lowercase().as_str() { + "true" | "t" | "1" | "yes" | "on" => "#true".to_string(), + "false" | "f" | "0" | "no" | "off" => "#false".to_string(), + _ => value.clone(), // Keep original if not recognized + }; + } + } + } + } + } + + Ok(()) +} diff --git a/server/src/table_script/handlers/dependency_analyzer.rs b/server/src/table_script/handlers/dependency_analyzer.rs index 98a2a32..1db3def 100644 --- a/server/src/table_script/handlers/dependency_analyzer.rs +++ b/server/src/table_script/handlers/dependency_analyzer.rs @@ -1,6 +1,6 @@ // src/table_script/handlers/dependency_analyzer.rs -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use tonic::Status; use sqlx::PgPool; use serde_json::{json, Value}; diff --git a/server/src/table_script/handlers/post_table_script.rs b/server/src/table_script/handlers/post_table_script.rs index eb14c86..3594264 100644 --- a/server/src/table_script/handlers/post_table_script.rs +++ b/server/src/table_script/handlers/post_table_script.rs @@ -6,12 +6,17 @@ use sqlx::{PgPool, Error as SqlxError}; use common::proto::multieko2::table_script::{PostTableScriptRequest, TableScriptResponse}; use serde_json::Value; use steel_decimal::SteelDecimal; +use regex::Regex; +use std::collections::HashSet; use crate::table_script::handlers::dependency_analyzer::DependencyAnalyzer; const SYSTEM_COLUMNS: &[&str] = &["id", "deleted", "created_at"]; -/// Validates the target column and ensures it is not a system column. +// Define prohibited data types for Steel scripts (boolean is explicitly allowed) +const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"]; + +/// Validates the target column and ensures it is not a system column or prohibited type. /// Returns the column type if valid. fn validate_target_column( table_name: &str, @@ -38,11 +43,211 @@ fn validate_target_column( .collect(); // Find the target column and return its type - column_info + let column_type = column_info .iter() .find(|(name, _)| *name == target) .map(|(_, dt)| dt.to_string()) - .ok_or_else(|| format!("Target column '{}' not defined in table '{}'", target, table_name)) + .ok_or_else(|| format!("Target column '{}' not defined in table '{}'", target, table_name))?; + + // Check if the target column type is prohibited + if is_prohibited_type(&column_type) { + return Err(format!( + "Cannot create script for column '{}' with type '{}'. Steel scripts cannot target columns of type: {}", + target, + column_type, + PROHIBITED_TYPES.join(", ") + )); + } + + // Add helpful info for boolean columns + let normalized_type = normalize_data_type(&column_type); + if normalized_type == "BOOLEAN" || normalized_type == "BOOL" { + println!("Info: Target column '{}' is boolean type. Values will be converted to Steel format (#true/#false)", target); + } + + Ok(column_type) +} + +/// Check if a data type is prohibited for Steel scripts +/// Note: BOOLEAN/BOOL is explicitly allowed and handled with special conversion +fn is_prohibited_type(data_type: &str) -> bool { + let normalized_type = normalize_data_type(data_type); + PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited)) +} + +/// Normalize data type for comparison (handle NUMERIC variations, etc.) +fn normalize_data_type(data_type: &str) -> String { + data_type.to_uppercase() + .split('(') // Remove precision/scale from NUMERIC(x,y) + .next() + .unwrap_or(data_type) + .trim() + .to_string() +} + +/// Parse Steel script to extract all table/column references +fn extract_column_references_from_script(script: &str) -> Vec<(String, String)> { + let mut references = Vec::new(); + + // Regex patterns to match Steel function calls + let patterns = [ + // (steel_get_column "table_name" "column_name") + r#"\(steel_get_column\s+"([^"]+)"\s+"([^"]+)"\)"#, + // (steel_get_column_with_index "table_name" index "column_name") + r#"\(steel_get_column_with_index\s+"([^"]+)"\s+\d+\s+"([^"]+)"\)"#, + ]; + + for pattern in &patterns { + if let Ok(re) = Regex::new(pattern) { + for cap in re.captures_iter(script) { + if let (Some(table), Some(column)) = (cap.get(1), cap.get(2)) { + references.push((table.as_str().to_string(), column.as_str().to_string())); + } + } + } + } + + // Also check for steel_get_column_with_index pattern (table, column are in different positions) + if let Ok(re) = Regex::new(r#"\(steel_get_column_with_index\s+"([^"]+)"\s+\d+\s+"([^"]+)"\)"#) { + for cap in re.captures_iter(script) { + if let (Some(table), Some(column)) = (cap.get(1), cap.get(2)) { + references.push((table.as_str().to_string(), column.as_str().to_string())); + } + } + } + + references +} + +/// Validate that script doesn't reference prohibited column types by checking actual DB schema +async fn validate_script_column_references( + db_pool: &PgPool, + schema_id: i64, + script: &str, +) -> Result<(), Status> { + // Extract all table/column references from the script + let references = extract_column_references_from_script(script); + + if references.is_empty() { + return Ok(()); // No column references to validate + } + + // Get all unique table names referenced in the script + let table_names: HashSet = references.iter() + .map(|(table, _)| table.clone()) + .collect(); + + // Fetch table definitions for all referenced tables + for table_name in table_names { + // Query the actual table definition from the database + let table_def = sqlx::query!( + r#"SELECT table_name, columns FROM table_definitions + WHERE schema_id = $1 AND table_name = $2"#, + schema_id, + table_name + ) + .fetch_optional(db_pool) + .await + .map_err(|e| Status::internal(format!("Failed to fetch table definition for '{}': {}", table_name, e)))?; + + if let Some(table_def) = table_def { + // Check each column reference for this table + for (ref_table, ref_column) in &references { + if ref_table == &table_name { + // Validate this specific column reference + if let Err(error_msg) = validate_referenced_column_type(&table_name, ref_column, &table_def.columns) { + return Err(Status::invalid_argument(error_msg)); + } + } + } + } else { + return Err(Status::invalid_argument(format!( + "Script references table '{}' which does not exist in this schema", + table_name + ))); + } + } + + Ok(()) +} + +/// Validate that a referenced column doesn't have a prohibited type +fn validate_referenced_column_type(table_name: &str, column_name: &str, table_columns: &Value) -> Result<(), String> { + // Parse the columns JSON into a vector of strings + let columns: Vec = serde_json::from_value(table_columns.clone()) + .map_err(|e| format!("Invalid column data for table '{}': {}", table_name, e))?; + + // Extract column names and types + let column_info: Vec<(&str, &str)> = columns + .iter() + .filter_map(|c| { + let mut parts = c.split_whitespace(); + let name = parts.next()?.trim_matches('"'); + let data_type = parts.next()?; + Some((name, data_type)) + }) + .collect(); + + // Find the referenced column and check its type + if let Some((_, column_type)) = column_info.iter().find(|(name, _)| *name == column_name) { + if is_prohibited_type(column_type) { + return Err(format!( + "Script references column '{}' in table '{}' which has prohibited type '{}'. Steel scripts cannot access columns of type: {}", + column_name, + table_name, + column_type, + PROHIBITED_TYPES.join(", ") + )); + } + + // Log info for boolean columns + let normalized_type = normalize_data_type(column_type); + if normalized_type == "BOOLEAN" || normalized_type == "BOOL" { + println!("Info: Script references boolean column '{}' in table '{}'. Values will be converted to Steel format (#true/#false)", column_name, table_name); + } + } else { + return Err(format!( + "Script references column '{}' in table '{}' but this column does not exist", + column_name, + table_name + )); + } + + Ok(()) +} + +/// Parse Steel SQL queries to check for prohibited type usage (basic heuristic) +fn validate_sql_queries_in_script(script: &str) -> Result<(), String> { + // Look for steel_query_sql calls + if let Ok(re) = Regex::new(r#"\(steel_query_sql\s+"([^"]+)"\)"#) { + for cap in re.captures_iter(script) { + if let Some(query) = cap.get(1) { + let sql = query.as_str().to_uppercase(); + + // Basic heuristic checks for prohibited type operations + let prohibited_patterns = [ + "EXTRACT(", + "DATE_PART(", + "::DATE", + "::TIMESTAMPTZ", + "::BIGINT", + "CAST(", // Could be casting to prohibited types + ]; + + for pattern in &prohibited_patterns { + if sql.contains(pattern) { + return Err(format!( + "Script contains SQL query with potentially prohibited type operations: '{}'. Steel scripts cannot use operations on types: {}", + query.as_str(), + PROHIBITED_TYPES.join(", ") + )); + } + } + } + } + } + + Ok(()) } /// Handles the creation of a new table script with dependency validation. @@ -65,7 +270,7 @@ pub async fn post_table_script( .map_err(|e| Status::internal(format!("Failed to fetch table definition: {}", e)))? .ok_or_else(|| Status::not_found("Table definition not found"))?; - // Validate the target column and get its type + // Validate the target column and get its type (includes prohibited type check) let column_type = validate_target_column( &table_def.table_name, &request.target_column, @@ -73,6 +278,13 @@ pub async fn post_table_script( ) .map_err(|e| Status::invalid_argument(e))?; + // Validate that script doesn't reference prohibited column types by checking actual DB schema + validate_script_column_references(db_pool, table_def.schema_id, &request.script).await?; + + // Validate SQL queries in script for prohibited type operations + validate_sql_queries_in_script(&request.script) + .map_err(|e| Status::invalid_argument(e))?; + // Create dependency analyzer for this schema let analyzer = DependencyAnalyzer::new(table_def.schema_id, db_pool.clone()); diff --git a/server/tests/table_script/prohibited_types_test.rs b/server/tests/table_script/prohibited_types_test.rs new file mode 100644 index 0000000..3121638 --- /dev/null +++ b/server/tests/table_script/prohibited_types_test.rs @@ -0,0 +1,181 @@ +// tests/table_script/prohibited_types_test.rs + +#[cfg(test)] +mod prohibited_types_tests { + use super::*; + use common::proto::multieko2::table_script::PostTableScriptRequest; + use sqlx::PgPool; + + #[tokio::test] + async fn test_reject_bigint_target_column() { + let pool = setup_test_db().await; + + // Create a table with a BIGINT column + let table_id = create_test_table_with_bigint_column(&pool).await; + + let request = PostTableScriptRequest { + table_definition_id: table_id, + target_column: "big_number".to_string(), // This is BIGINT + script: r#" + (define result "some calculation") + result + "#.to_string(), + description: "Test script".to_string(), + }; + + let result = post_table_script(&pool, request).await; + + // Should fail with prohibited type error + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot create script for column 'big_number' with type 'BIGINT'")); + assert!(error_msg.contains("Steel scripts cannot target columns of type: BIGINT, DATE, TIMESTAMPTZ")); + } + + #[tokio::test] + async fn test_reject_date_target_column() { + let pool = setup_test_db().await; + + // Create a table with a DATE column + let table_id = create_test_table_with_date_column(&pool).await; + + let request = PostTableScriptRequest { + table_definition_id: table_id, + target_column: "event_date".to_string(), // This is DATE + script: r#" + (define result "2024-01-01") + result + "#.to_string(), + description: "Test script".to_string(), + }; + + let result = post_table_script(&pool, request).await; + + // Should fail with prohibited type error + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot create script for column 'event_date' with type 'DATE'")); + } + + #[tokio::test] + async fn test_reject_timestamptz_target_column() { + let pool = setup_test_db().await; + + // Create a table with a TIMESTAMPTZ column + let table_id = create_test_table_with_timestamptz_column(&pool).await; + + let request = PostTableScriptRequest { + table_definition_id: table_id, + target_column: "created_time".to_string(), // This is TIMESTAMPTZ + script: r#" + (define result "2024-01-01T10:00:00Z") + result + "#.to_string(), + description: "Test script".to_string(), + }; + + let result = post_table_script(&pool, request).await; + + // Should fail with prohibited type error + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot create script for column 'created_time' with type 'TIMESTAMPTZ'")); + } + + #[tokio::test] + async fn test_reject_script_referencing_prohibited_column() { + let pool = setup_test_db().await; + + // Create linked tables - one with BIGINT column, another with TEXT target + let source_table_id = create_test_table_with_text_column(&pool).await; + let linked_table_id = create_test_table_with_bigint_column(&pool).await; + + // Create link between tables + create_table_link(&pool, source_table_id, linked_table_id).await; + + let request = PostTableScriptRequest { + table_definition_id: source_table_id, + target_column: "description".to_string(), // This is TEXT (allowed) + script: r#" + (define big_val (steel_get_column "linked_table" "big_number")) + (string-append "Value: " (number->string big_val)) + "#.to_string(), + description: "Script that tries to access BIGINT column".to_string(), + }; + + let result = post_table_script(&pool, request).await; + + // Should fail because script references BIGINT column + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Script cannot reference column 'big_number'")); + assert!(error_msg.contains("prohibited type 'BIGINT'")); + } + + #[tokio::test] + async fn test_allow_valid_script_with_allowed_types() { + let pool = setup_test_db().await; + + // Create a table with allowed column types + let table_id = create_test_table_with_allowed_columns(&pool).await; + + let request = PostTableScriptRequest { + table_definition_id: table_id, + target_column: "computed_value".to_string(), // This is TEXT (allowed) + script: r#" + (define name_val (steel_get_column "test_table" "name")) + (define count_val (steel_get_column "test_table" "count")) + (string-append name_val " has " (number->string count_val) " items") + "#.to_string(), + description: "Valid script using allowed types".to_string(), + }; + + let result = post_table_script(&pool, request).await; + + // Should succeed + assert!(result.is_ok()); + let response = result.unwrap(); + assert!(response.id > 0); + } + + // Helper functions for test setup + async fn setup_test_db() -> PgPool { + // Your test database setup code here + todo!("Implement test DB setup") + } + + async fn create_test_table_with_bigint_column(pool: &PgPool) -> i64 { + // Create table definition with BIGINT column + // JSON columns would be: ["name TEXT", "big_number BIGINT"] + todo!("Implement table creation with BIGINT") + } + + async fn create_test_table_with_date_column(pool: &PgPool) -> i64 { + // Create table definition with DATE column + // JSON columns would be: ["name TEXT", "event_date DATE"] + todo!("Implement table creation with DATE") + } + + async fn create_test_table_with_timestamptz_column(pool: &PgPool) -> i64 { + // Create table definition with TIMESTAMPTZ column + // JSON columns would be: ["name TEXT", "created_time TIMESTAMPTZ"] + todo!("Implement table creation with TIMESTAMPTZ") + } + + async fn create_test_table_with_text_column(pool: &PgPool) -> i64 { + // Create table definition with TEXT columns only + // JSON columns would be: ["name TEXT", "description TEXT"] + todo!("Implement table creation with TEXT") + } + + async fn create_test_table_with_allowed_columns(pool: &PgPool) -> i64 { + // Create table definition with only allowed column types + // JSON columns would be: ["name TEXT", "count INTEGER", "computed_value TEXT"] + todo!("Implement table creation with allowed types") + } + + async fn create_table_link(pool: &PgPool, source_id: i64, target_id: i64) { + // Create a link in table_definition_links + todo!("Implement table linking") + } +}