steel scripts now have far better logic than before

This commit is contained in:
filipriec
2025-07-12 23:06:21 +02:00
parent 0e3a7a06a3
commit 17495c49ac
5 changed files with 691 additions and 13 deletions

View File

@@ -6,12 +6,17 @@ use sqlx::{PgPool, Error as SqlxError};
use common::proto::multieko2::table_script::{PostTableScriptRequest, TableScriptResponse};
use serde_json::Value;
use steel_decimal::SteelDecimal;
use regex::Regex;
use std::collections::HashSet;
use crate::table_script::handlers::dependency_analyzer::DependencyAnalyzer;
const SYSTEM_COLUMNS: &[&str] = &["id", "deleted", "created_at"];
/// Validates the target column and ensures it is not a system column.
// Define prohibited data types for Steel scripts (boolean is explicitly allowed)
const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"];
/// Validates the target column and ensures it is not a system column or prohibited type.
/// Returns the column type if valid.
fn validate_target_column(
table_name: &str,
@@ -38,11 +43,211 @@ fn validate_target_column(
.collect();
// Find the target column and return its type
column_info
let column_type = column_info
.iter()
.find(|(name, _)| *name == target)
.map(|(_, dt)| dt.to_string())
.ok_or_else(|| format!("Target column '{}' not defined in table '{}'", target, table_name))
.ok_or_else(|| format!("Target column '{}' not defined in table '{}'", target, table_name))?;
// Check if the target column type is prohibited
if is_prohibited_type(&column_type) {
return Err(format!(
"Cannot create script for column '{}' with type '{}'. Steel scripts cannot target columns of type: {}",
target,
column_type,
PROHIBITED_TYPES.join(", ")
));
}
// Add helpful info for boolean columns
let normalized_type = normalize_data_type(&column_type);
if normalized_type == "BOOLEAN" || normalized_type == "BOOL" {
println!("Info: Target column '{}' is boolean type. Values will be converted to Steel format (#true/#false)", target);
}
Ok(column_type)
}
/// Reports whether `data_type` is one of the types Steel scripts may not use.
///
/// The type is first normalized with `normalize_data_type` and then compared
/// against each entry of `PROHIBITED_TYPES` by *prefix*, so any normalized
/// type that merely begins with a prohibited name (for example `DATERANGE`,
/// which begins with `DATE`) is reported as prohibited as well.
///
/// BOOLEAN/BOOL is not in the list and is therefore allowed.
fn is_prohibited_type(data_type: &str) -> bool {
    let canonical = normalize_data_type(data_type);
    for prohibited in PROHIBITED_TYPES {
        if canonical.starts_with(*prohibited) {
            return true;
        }
    }
    false
}
/// Produces the canonical form of a SQL data type name for comparisons.
///
/// The input is upper-cased, everything from the first `(` onwards is dropped
/// (so `numeric(10,2)` becomes `NUMERIC`), and surrounding whitespace is
/// trimmed.
fn normalize_data_type(data_type: &str) -> String {
    let upper = data_type.to_uppercase();
    // Keep only the part before any precision/scale suffix.
    let base = match upper.find('(') {
        Some(paren_at) => &upper[..paren_at],
        None => upper.as_str(),
    };
    base.trim().to_string()
}
/// Parse Steel script to extract all table/column references
fn extract_column_references_from_script(script: &str) -> Vec<(String, String)> {
let mut references = Vec::new();
// Regex patterns to match Steel function calls
let patterns = [
// (steel_get_column "table_name" "column_name")
r#"\(steel_get_column\s+"([^"]+)"\s+"([^"]+)"\)"#,
// (steel_get_column_with_index "table_name" index "column_name")
r#"\(steel_get_column_with_index\s+"([^"]+)"\s+\d+\s+"([^"]+)"\)"#,
];
for pattern in &patterns {
if let Ok(re) = Regex::new(pattern) {
for cap in re.captures_iter(script) {
if let (Some(table), Some(column)) = (cap.get(1), cap.get(2)) {
references.push((table.as_str().to_string(), column.as_str().to_string()));
}
}
}
}
// Also check for steel_get_column_with_index pattern (table, column are in different positions)
if let Ok(re) = Regex::new(r#"\(steel_get_column_with_index\s+"([^"]+)"\s+\d+\s+"([^"]+)"\)"#) {
for cap in re.captures_iter(script) {
if let (Some(table), Some(column)) = (cap.get(1), cap.get(2)) {
references.push((table.as_str().to_string(), column.as_str().to_string()));
}
}
}
references
}
/// Checks every table/column the script reads against the stored table
/// definitions of the given schema.
///
/// # Errors
///
/// * `Status::internal` if fetching a table definition fails.
/// * `Status::invalid_argument` if a referenced table has no definition in
///   this schema, or if `validate_referenced_column_type` rejects one of the
///   referenced columns.
///
/// Returns `Ok(())` immediately when the script contains no column references.
async fn validate_script_column_references(
    db_pool: &PgPool,
    schema_id: i64,
    script: &str,
) -> Result<(), Status> {
    let column_refs = extract_column_references_from_script(script);
    if column_refs.is_empty() {
        // Nothing to validate.
        return Ok(());
    }

    // One lookup per distinct table, regardless of how many columns are used.
    let referenced_tables: HashSet<String> = column_refs
        .iter()
        .map(|(table, _)| table.clone())
        .collect();

    for table_name in referenced_tables {
        let lookup = sqlx::query!(
            r#"SELECT table_name, columns FROM table_definitions
WHERE schema_id = $1 AND table_name = $2"#,
            schema_id,
            table_name
        )
        .fetch_optional(db_pool)
        .await
        .map_err(|e| Status::internal(format!("Failed to fetch table definition for '{}': {}", table_name, e)))?;

        let table_def = match lookup {
            Some(row) => row,
            None => {
                return Err(Status::invalid_argument(format!(
                    "Script references table '{}' which does not exist in this schema",
                    table_name
                )));
            }
        };

        // Validate only the references that belong to the table just loaded.
        for (_, ref_column) in column_refs
            .iter()
            .filter(|(ref_table, _)| *ref_table == table_name)
        {
            validate_referenced_column_type(&table_name, ref_column, &table_def.columns)
                .map_err(|error_msg| Status::invalid_argument(error_msg))?;
        }
    }

    Ok(())
}
/// Ensures that `column_name` exists in the given table definition and that
/// its type is not prohibited for Steel scripts.
///
/// `table_columns` is deserialized as a JSON array of strings. For each entry
/// the first whitespace-separated token (with surrounding `"` removed) is
/// taken as the column name and the second token as its data type; entries
/// with fewer than two tokens are ignored.
///
/// # Errors
///
/// Returns a descriptive message when the JSON cannot be deserialized, when
/// the column is not found, or when `is_prohibited_type` rejects its type.
fn validate_referenced_column_type(table_name: &str, column_name: &str, table_columns: &Value) -> Result<(), String> {
    let definitions: Vec<String> = serde_json::from_value(table_columns.clone())
        .map_err(|e| format!("Invalid column data for table '{}': {}", table_name, e))?;

    // First well-formed definition whose name matches the referenced column.
    let declared_type = definitions.iter().find_map(|definition| {
        let mut tokens = definition.split_whitespace();
        let name = tokens.next()?.trim_matches('"');
        let data_type = tokens.next()?;
        if name == column_name {
            Some(data_type)
        } else {
            None
        }
    });

    let column_type = match declared_type {
        Some(found) => found,
        None => {
            return Err(format!(
                "Script references column '{}' in table '{}' but this column does not exist",
                column_name,
                table_name
            ));
        }
    };

    if is_prohibited_type(column_type) {
        return Err(format!(
            "Script references column '{}' in table '{}' which has prohibited type '{}'. Steel scripts cannot access columns of type: {}",
            column_name,
            table_name,
            column_type,
            PROHIBITED_TYPES.join(", ")
        ));
    }

    // Boolean columns are allowed; just emit an informational line.
    let normalized = normalize_data_type(column_type);
    if matches!(normalized.as_str(), "BOOLEAN" | "BOOL") {
        println!("Info: Script references boolean column '{}' in table '{}'. Values will be converted to Steel format (#true/#false)", column_name, table_name);
    }

    Ok(())
}
/// Heuristically rejects scripts whose embedded SQL appears to operate on
/// prohibited types.
///
/// Every `(steel_query_sql "<sql>")` call with a string-literal argument is
/// located, its SQL text is upper-cased, and the text is searched for a fixed
/// list of substrings (date extraction functions, casts to prohibited types,
/// and any `CAST(`). The first query containing one of them produces an
/// error; otherwise `Ok(())` is returned.
///
/// This is a plain substring check, not a SQL parser.
fn validate_sql_queries_in_script(script: &str) -> Result<(), String> {
    // Substrings that suggest an operation on a prohibited type.
    const SUSPECT_SQL_FRAGMENTS: [&str; 6] = [
        "EXTRACT(",
        "DATE_PART(",
        "::DATE",
        "::TIMESTAMPTZ",
        "::BIGINT",
        "CAST(", // Could be casting to prohibited types
    ];

    // If the pattern cannot be compiled there is nothing to scan.
    let call_re = match Regex::new(r#"\(steel_query_sql\s+"([^"]+)"\)"#) {
        Ok(re) => re,
        Err(_) => return Ok(()),
    };

    for captures in call_re.captures_iter(script) {
        let raw_query = match captures.get(1) {
            Some(matched) => matched.as_str(),
            None => continue,
        };

        let upper_sql = raw_query.to_uppercase();
        let is_suspect = SUSPECT_SQL_FRAGMENTS
            .iter()
            .any(|fragment| upper_sql.contains(*fragment));

        if is_suspect {
            return Err(format!(
                "Script contains SQL query with potentially prohibited type operations: '{}'. Steel scripts cannot use operations on types: {}",
                raw_query,
                PROHIBITED_TYPES.join(", ")
            ));
        }
    }

    Ok(())
}
/// Handles the creation of a new table script with dependency validation.
@@ -65,7 +270,7 @@ pub async fn post_table_script(
.map_err(|e| Status::internal(format!("Failed to fetch table definition: {}", e)))?
.ok_or_else(|| Status::not_found("Table definition not found"))?;
// Validate the target column and get its type
// Validate the target column and get its type (includes prohibited type check)
let column_type = validate_target_column(
&table_def.table_name,
&request.target_column,
@@ -73,6 +278,13 @@ pub async fn post_table_script(
)
.map_err(|e| Status::invalid_argument(e))?;
// Validate that script doesn't reference prohibited column types by checking actual DB schema
validate_script_column_references(db_pool, table_def.schema_id, &request.script).await?;
// Validate SQL queries in script for prohibited type operations
validate_sql_queries_in_script(&request.script)
.map_err(|e| Status::invalid_argument(e))?;
// Create dependency analyzer for this schema
let analyzer = DependencyAnalyzer::new(table_def.schema_id, db_pool.clone());