reject boolean and text in math functions at the post table script
This commit is contained in:
@@ -96,6 +96,9 @@ impl SteelContext {
|
||||
_ => value.to_string(), // Return as-is if not a recognized boolean
|
||||
}
|
||||
}
|
||||
"INTEGER" => {
|
||||
value.to_string()
|
||||
}
|
||||
_ => value.to_string(), // Return as-is for non-boolean types
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ use serde_json::Value;
|
||||
use steel_decimal::SteelDecimal;
|
||||
use regex::Regex;
|
||||
use std::collections::HashSet;
|
||||
use std::collections::HashMap;
|
||||
|
||||
|
||||
use crate::table_script::handlers::dependency_analyzer::DependencyAnalyzer;
|
||||
|
||||
@@ -15,6 +17,169 @@ const SYSTEM_COLUMNS: &[&str] = &["id", "deleted", "created_at"];
|
||||
|
||||
// Define prohibited data types for Steel scripts (boolean is explicitly allowed)
|
||||
const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"];
|
||||
const MATH_PROHIBITED_TYPES: &[&str] = &["TEXT", "BOOLEAN"];
|
||||
|
||||
|
||||
/// Extract mathematical expressions from the original script (before steel_decimal transformation)
|
||||
fn extract_math_operations_with_operands(script: &str) -> Vec<String> {
|
||||
let mut math_operands = Vec::new();
|
||||
|
||||
// Define math operation patterns that steel_decimal will transform
|
||||
let math_patterns = [
|
||||
r"\(\s*\+\s+([^)]+)\)", // (+ operands)
|
||||
r"\(\s*-\s+([^)]+)\)", // (- operands)
|
||||
r"\(\s*\*\s+([^)]+)\)", // (* operands)
|
||||
r"\(\s*/\s+([^)]+)\)", // (/ operands)
|
||||
r"\(\s*\^\s+([^)]+)\)", // (^ operands)
|
||||
r"\(\s*\*\*\s+([^)]+)\)", // (** operands)
|
||||
r"\(\s*pow\s+([^)]+)\)", // (pow operands)
|
||||
r"\(\s*sqrt\s+([^)]+)\)", // (sqrt operands)
|
||||
r"\(\s*>\s+([^)]+)\)", // (> operands)
|
||||
r"\(\s*<\s+([^)]+)\)", // (< operands)
|
||||
r"\(\s*=\s+([^)]+)\)", // (= operands)
|
||||
r"\(\s*>=\s+([^)]+)\)", // (>= operands)
|
||||
r"\(\s*<=\s+([^)]+)\)", // (<= operands)
|
||||
r"\(\s*min\s+([^)]+)\)", // (min operands)
|
||||
r"\(\s*max\s+([^)]+)\)", // (max operands)
|
||||
r"\(\s*abs\s+([^)]+)\)", // (abs operands)
|
||||
r"\(\s*round\s+([^)]+)\)", // (round operands)
|
||||
r"\(\s*ln\s+([^)]+)\)", // (ln operands)
|
||||
r"\(\s*log\s+([^)]+)\)", // (log operands)
|
||||
r"\(\s*log10\s+([^)]+)\)", // (log10 operands)
|
||||
r"\(\s*exp\s+([^)]+)\)", // (exp operands)
|
||||
r"\(\s*sin\s+([^)]+)\)", // (sin operands)
|
||||
r"\(\s*cos\s+([^)]+)\)", // (cos operands)
|
||||
r"\(\s*tan\s+([^)]+)\)", // (tan operands)
|
||||
];
|
||||
|
||||
for pattern in &math_patterns {
|
||||
if let Ok(re) = Regex::new(pattern) {
|
||||
for cap in re.captures_iter(script) {
|
||||
if let Some(operands_str) = cap.get(1) {
|
||||
// Add all operands from this math operation
|
||||
math_operands.push(operands_str.as_str().to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
math_operands
|
||||
}
|
||||
|
||||
/// Extract column references from mathematical operands
|
||||
fn extract_column_references_from_math_operands(operands: &[String]) -> Vec<(String, String)> {
|
||||
let mut references = Vec::new();
|
||||
|
||||
for operand_str in operands {
|
||||
// Check for steel_get_column calls: (steel_get_column "table" "column")
|
||||
if let Ok(re) = Regex::new(r#"\(steel_get_column\s+"([^"]+)"\s+"([^"]+)"\)"#) {
|
||||
for cap in re.captures_iter(operand_str) {
|
||||
if let (Some(table), Some(column)) = (cap.get(1), cap.get(2)) {
|
||||
references.push((table.as_str().to_string(), column.as_str().to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for steel_get_column_with_index calls
|
||||
if let Ok(re) = Regex::new(r#"\(steel_get_column_with_index\s+"([^"]+)"\s+\d+\s+"([^"]+)"\)"#) {
|
||||
for cap in re.captures_iter(operand_str) {
|
||||
if let (Some(table), Some(column)) = (cap.get(1), cap.get(2)) {
|
||||
references.push((table.as_str().to_string(), column.as_str().to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
references
|
||||
}
|
||||
|
||||
/// Validate that mathematical operations don't use TEXT or BOOLEAN columns
|
||||
async fn validate_math_operations_column_types(
|
||||
db_pool: &PgPool,
|
||||
schema_id: i64,
|
||||
script: &str,
|
||||
) -> Result<(), Status> {
|
||||
// Extract all mathematical operations and their operands
|
||||
let math_operands = extract_math_operations_with_operands(script);
|
||||
|
||||
if math_operands.is_empty() {
|
||||
return Ok(()); // No math operations to validate
|
||||
}
|
||||
|
||||
// Extract column references from math operands
|
||||
let column_refs = extract_column_references_from_math_operands(&math_operands);
|
||||
|
||||
if column_refs.is_empty() {
|
||||
return Ok(()); // No column references in math operations
|
||||
}
|
||||
|
||||
// Get all unique table names referenced in math operations
|
||||
let table_names: HashSet<String> = column_refs.iter()
|
||||
.map(|(table, _)| table.clone())
|
||||
.collect();
|
||||
|
||||
// Fetch table definitions for all referenced tables
|
||||
let table_definitions = sqlx::query!(
|
||||
r#"SELECT table_name, columns FROM table_definitions
|
||||
WHERE schema_id = $1 AND table_name = ANY($2)"#,
|
||||
schema_id,
|
||||
&table_names.into_iter().collect::<Vec<_>>()
|
||||
)
|
||||
.fetch_all(db_pool)
|
||||
.await
|
||||
.map_err(|e| Status::internal(format!("Failed to fetch table definitions: {}", e)))?;
|
||||
|
||||
// Build a map of table_name -> column_name -> column_type
|
||||
let mut table_column_types: HashMap<String, HashMap<String, String>> = HashMap::new();
|
||||
|
||||
for table_def in table_definitions {
|
||||
let columns: Vec<String> = serde_json::from_value(table_def.columns)
|
||||
.map_err(|e| Status::internal(format!("Invalid column data for table '{}': {}", table_def.table_name, e)))?;
|
||||
|
||||
let mut column_types = HashMap::new();
|
||||
for column_def in columns {
|
||||
let mut parts = column_def.split_whitespace();
|
||||
if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) {
|
||||
let column_name = name.trim_matches('"');
|
||||
column_types.insert(column_name.to_string(), data_type.to_string());
|
||||
}
|
||||
}
|
||||
table_column_types.insert(table_def.table_name, column_types);
|
||||
}
|
||||
|
||||
// Check each column reference in mathematical operations
|
||||
for (table_name, column_name) in column_refs {
|
||||
if let Some(table_columns) = table_column_types.get(&table_name) {
|
||||
if let Some(column_type) = table_columns.get(&column_name) {
|
||||
let normalized_type = normalize_data_type(column_type);
|
||||
|
||||
// Check if this type is prohibited in math operations
|
||||
if MATH_PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited)) {
|
||||
return Err(Status::invalid_argument(format!(
|
||||
"Cannot use column '{}' of type '{}' from table '{}' in mathematical operations. Mathematical operations cannot use columns of type: {}",
|
||||
column_name,
|
||||
column_type,
|
||||
table_name,
|
||||
MATH_PROHIBITED_TYPES.join(", ")
|
||||
)));
|
||||
}
|
||||
} else {
|
||||
return Err(Status::invalid_argument(format!(
|
||||
"Script references column '{}' in table '{}' but this column does not exist",
|
||||
column_name,
|
||||
table_name
|
||||
)));
|
||||
}
|
||||
} else {
|
||||
return Err(Status::invalid_argument(format!(
|
||||
"Script references table '{}' in mathematical operations but this table does not exist in this schema",
|
||||
table_name
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validates the target column and ensures it is not a system column or prohibited type.
|
||||
/// Returns the column type if valid.
|
||||
@@ -281,6 +446,9 @@ pub async fn post_table_script(
|
||||
// Validate that script doesn't reference prohibited column types by checking actual DB schema
|
||||
validate_script_column_references(db_pool, table_def.schema_id, &request.script).await?;
|
||||
|
||||
// NEW: Validate that mathematical operations don't use TEXT or BOOLEAN columns
|
||||
validate_math_operations_column_types(db_pool, table_def.schema_id, &request.script).await?;
|
||||
|
||||
// Validate SQL queries in script for prohibited type operations
|
||||
validate_sql_queries_in_script(&request.script)
|
||||
.map_err(|e| Status::invalid_argument(e))?;
|
||||
@@ -300,7 +468,7 @@ pub async fn post_table_script(
|
||||
.await
|
||||
.map_err(|e| Status::from(e))?;
|
||||
|
||||
// Transform the script using steel_decimal
|
||||
// Transform the script using steel_decimal (this happens AFTER validation)
|
||||
let steel_decimal = SteelDecimal::new();
|
||||
let parsed_script = steel_decimal.transform(&request.script);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user