we got a full passer now

This commit is contained in:
filipriec
2025-07-20 14:58:39 +02:00
parent 84871faad4
commit 7e54b2fe43
2 changed files with 79 additions and 62 deletions

View File

@@ -15,13 +15,13 @@ use crate::table_script::handlers::dependency_analyzer::DependencyAnalyzer;
const SYSTEM_COLUMNS: &[&str] = &["id", "deleted", "created_at"];
// Define prohibited data types for Steel scripts (boolean is explicitly allowed)
const PROHIBITED_TYPES: &[&str] = &["DATE", "TIMESTAMPTZ"];
const MATH_PROHIBITED_TYPES: &[&str] = &["BIGINT", "TEXT", "BOOLEAN"];
const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"];
const MATH_PROHIBITED_TYPES: &[&str] = &["BIGINT", "TEXT", "BOOLEAN", "DATE", "TIMESTAMPTZ"];
// Math operations that Steel Decimal will transform
const MATH_OPERATIONS: &[&str] = &[
"+", "-", "*", "/", "^", "**", "pow", "sqrt",
">", "<", "=", ">=", "<=", "min", "max", "abs",
"+", "-", "*", "/", "^", "**", "pow", "sqrt",
">", "<", "=", ">=", "<=", "min", "max", "abs",
"round", "ln", "log", "log10", "exp", "sin", "cos", "tan"
];
@@ -42,20 +42,20 @@ impl Parser {
let tokens = Self::tokenize(script);
Self { tokens, position: 0 }
}
fn tokenize(script: &str) -> Vec<String> {
let mut tokens = Vec::new();
let mut current_token = String::new();
let mut in_string = false;
let mut escape_next = false;
for ch in script.chars() {
if escape_next {
current_token.push(ch);
escape_next = false;
continue;
}
match ch {
'\\' if in_string => {
escape_next = true;
@@ -88,43 +88,43 @@ impl Parser {
}
}
}
if !current_token.is_empty() {
tokens.push(current_token);
}
tokens
}
fn parse(&mut self) -> Result<Vec<SExpr>, String> {
let mut expressions = Vec::new();
while self.position < self.tokens.len() {
expressions.push(self.parse_expr()?);
}
Ok(expressions)
}
fn parse_expr(&mut self) -> Result<SExpr, String> {
if self.position >= self.tokens.len() {
return Err("Unexpected end of input".to_string());
}
let token = &self.tokens[self.position];
if token == "(" {
self.position += 1; // consume '('
let mut elements = Vec::new();
while self.position < self.tokens.len() && self.tokens[self.position] != ")" {
elements.push(self.parse_expr()?);
}
if self.position >= self.tokens.len() {
return Err("Missing closing parenthesis".to_string());
}
self.position += 1; // consume ')'
Ok(SExpr::List(elements))
} else {
@@ -145,14 +145,14 @@ impl MathValidator {
column_references: Vec::new(),
}
}
fn validate_expressions(&mut self, expressions: &[SExpr]) -> Result<(), String> {
for expr in expressions {
self.check_expression(expr, false)?;
}
Ok(())
}
fn check_expression(&mut self, expr: &SExpr, in_math_context: bool) -> Result<(), String> {
match expr {
SExpr::Atom(_) => Ok(()),
@@ -160,14 +160,14 @@ impl MathValidator {
if elements.is_empty() {
return Ok(());
}
// Check if this is a math operation
let is_math = if let SExpr::Atom(op) = &elements[0] {
MATH_OPERATIONS.contains(&op.as_str())
} else {
false
};
// Check if this is a column access function
if let SExpr::Atom(func) = &elements[0] {
if func == "steel_get_column" && in_math_context {
@@ -176,17 +176,17 @@ impl MathValidator {
self.extract_column_reference_from_steel_get_column_with_index(elements)?;
}
}
// Recursively check all elements, marking math context appropriately
for element in &elements[1..] { // Skip the operator/function name
self.check_expression(element, in_math_context || is_math)?;
}
Ok(())
}
}
}
fn extract_column_reference_from_steel_get_column(&mut self, elements: &[SExpr]) -> Result<(), String> {
// (steel_get_column "table" "column")
if elements.len() >= 3 {
@@ -198,7 +198,7 @@ impl MathValidator {
}
Ok(())
}
fn extract_column_reference_from_steel_get_column_with_index(&mut self, elements: &[SExpr]) -> Result<(), String> {
// (steel_get_column_with_index "table" index "column")
if elements.len() >= 4 {
@@ -210,7 +210,7 @@ impl MathValidator {
}
Ok(())
}
fn unquote_string(&self, s: &str) -> Result<String, String> {
if s.starts_with('"') && s.ends_with('"') && s.len() >= 2 {
Ok(s[1..s.len()-1].to_string())
@@ -223,12 +223,12 @@ impl MathValidator {
/// Valide script is not empty
fn validate_script_basic_syntax(script: &str) -> Result<(), Status> {
let trimmed = script.trim();
// Check for empty script
if trimmed.is_empty() {
return Err(Status::invalid_argument("Script cannot be empty"));
}
// Basic parentheses balance check
let mut paren_count = 0;
for ch in trimmed.chars() {
@@ -243,16 +243,16 @@ fn validate_script_basic_syntax(script: &str) -> Result<(), Status> {
_ => {}
}
}
if paren_count != 0 {
return Err(Status::invalid_argument("Unbalanced parentheses: missing closing parentheses"));
}
// Check for basic S-expression structure
if !trimmed.starts_with('(') {
return Err(Status::invalid_argument("Script must start with an opening parenthesis '('"));
}
Ok(())
}
@@ -261,11 +261,11 @@ fn extract_math_column_references(script: &str) -> Result<Vec<(String, String)>,
let mut parser = Parser::new(script);
let expressions = parser.parse()
.map_err(|e| format!("Parse error: {}", e))?;
let mut validator = MathValidator::new();
validator.validate_expressions(&expressions)
.map_err(|e| format!("Validation error: {}", e))?;
Ok(validator.column_references)
}
@@ -278,19 +278,19 @@ async fn validate_math_operations_column_types(
// Extract column references from mathematical contexts using proper S-expression parsing
let column_refs = extract_math_column_references(script)
.map_err(|e| Status::invalid_argument(format!("Script parsing failed: {}", e)))?;
if column_refs.is_empty() {
return Ok(()); // No column references in math operations
}
// Get all unique table names referenced in math operations
let table_names: HashSet<String> = column_refs.iter()
.map(|(table, _)| table.clone())
.collect();
// Fetch table definitions for all referenced tables
let table_definitions = sqlx::query!(
r#"SELECT table_name, columns FROM table_definitions
r#"SELECT table_name, columns FROM table_definitions
WHERE schema_id = $1 AND table_name = ANY($2)"#,
schema_id,
&table_names.into_iter().collect::<Vec<_>>()
@@ -298,14 +298,14 @@ async fn validate_math_operations_column_types(
.fetch_all(db_pool)
.await
.map_err(|e| Status::internal(format!("Failed to fetch table definitions: {}", e)))?;
// Build a map of table_name -> column_name -> column_type
let mut table_column_types: HashMap<String, HashMap<String, String>> = HashMap::new();
for table_def in table_definitions {
let columns: Vec<String> = serde_json::from_value(table_def.columns)
.map_err(|e| Status::internal(format!("Invalid column data for table '{}': {}", table_def.table_name, e)))?;
let mut column_types = HashMap::new();
for column_def in columns {
let mut parts = column_def.split_whitespace();
@@ -316,13 +316,13 @@ async fn validate_math_operations_column_types(
}
table_column_types.insert(table_def.table_name, column_types);
}
// Check each column reference in mathematical operations
for (table_name, column_name) in column_refs {
if let Some(table_columns) = table_column_types.get(&table_name) {
if let Some(column_type) = table_columns.get(&column_name) {
let normalized_type = normalize_data_type(column_type);
// Check if this type is prohibited in math operations
if MATH_PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited)) {
return Err(Status::invalid_argument(format!(
@@ -347,7 +347,7 @@ async fn validate_math_operations_column_types(
)));
}
}
Ok(())
}
@@ -616,12 +616,12 @@ pub async fn post_table_script(
)
.map_err(|e| Status::invalid_argument(e))?;
// Validate that script doesn't reference prohibited column types by checking actual DB schema
validate_script_column_references(db_pool, table_def.schema_id, &request.script).await?;
// NEW: Validate that mathematical operations don't use TEXT or BOOLEAN columns
// REORDER: Math validation FIRST so we get specific error messages for math operations
validate_math_operations_column_types(db_pool, table_def.schema_id, &request.script).await?;
// THEN general column validation (catches non-math prohibited access)
validate_script_column_references(db_pool, table_def.schema_id, &request.script).await?;
// Validate SQL queries in script for prohibited type operations
validate_sql_queries_in_script(&request.script)
.map_err(|e| Status::invalid_argument(e))?;