Files
komp_ac/server/src/table_script/handlers/post_table_script.rs
2025-07-20 14:58:39 +02:00

748 lines
27 KiB
Rust

// src/table_script/handlers/post_table_script.rs
// TODO MAKE THE SCRIPTS PUSH ONLY TO THE EMPTY FILES
use tonic::Status;
use sqlx::{PgPool, Error as SqlxError};
use common::proto::multieko2::table_script::{PostTableScriptRequest, TableScriptResponse};
use serde_json::Value;
use steel_decimal::SteelDecimal;
use regex::Regex;
use std::collections::HashSet;
use std::collections::HashMap;
use crate::table_script::handlers::dependency_analyzer::DependencyAnalyzer;
/// Column names that `validate_target_column` refuses to accept as a script target.
const SYSTEM_COLUMNS: &[&str] = &["id", "deleted", "created_at"];
// Define prohibited data types for Steel scripts (boolean is explicitly allowed)
/// Type-name prefixes that `is_prohibited_type` rejects; matching is done with
/// `starts_with` against the upper-cased type with any `(...)` suffix removed.
const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"];
/// Type-name prefixes rejected for columns referenced inside a math expression
/// (see `validate_math_operations_column_types`); a superset of `PROHIBITED_TYPES`
/// that additionally lists TEXT and BOOLEAN.
const MATH_PROHIBITED_TYPES: &[&str] = &["BIGINT", "TEXT", "BOOLEAN", "DATE", "TIMESTAMPTZ"];
// Math operations that Steel Decimal will transform
/// Operator/function names that put their arguments into a "math context" for
/// `MathValidator`: any list whose first element is one of these atoms.
const MATH_OPERATIONS: &[&str] = &[
"+", "-", "*", "/", "^", "**", "pow", "sqrt",
">", "<", "=", ">=", "<=", "min", "max", "abs",
"round", "ln", "log", "log10", "exp", "sin", "cos", "tan"
];
/// A parsed S-expression node.
#[derive(Debug, Clone)]
enum SExpr {
    /// A single token: a symbol, number, or string literal. String literals
    /// keep their surrounding double quotes and any backslash escapes verbatim.
    Atom(String),
    /// A parenthesized sequence of nested expressions (possibly empty).
    List(Vec<SExpr>),
}

/// Minimal S-expression reader used to inspect the structure of a script.
///
/// The script is tokenized eagerly in [`Parser::new`]; [`Parser::parse`] then
/// walks the token list with `position` as the cursor.
#[derive(Debug)]
struct Parser {
    tokens: Vec<String>,
    position: usize,
}

impl Parser {
    /// Tokenizes `script` and returns a parser positioned at the first token.
    fn new(script: &str) -> Self {
        Self {
            tokens: Self::tokenize(script),
            position: 0,
        }
    }

    /// Splits `script` into tokens.
    ///
    /// * `(` and `)` outside of strings are always their own tokens.
    /// * Whitespace outside of strings separates tokens.
    /// * A double-quoted string is a single token that includes both quotes;
    ///   inside a string a backslash keeps the following character verbatim
    ///   (so `\"` does not end the string).
    /// * An opening quote also terminates any token in progress, so `f"a"`
    ///   yields `f` and `"a"` rather than one fused token.
    /// * An unterminated string is emitted as a final token without a closing
    ///   quote.
    fn tokenize(script: &str) -> Vec<String> {
        let mut tokens = Vec::new();
        let mut current_token = String::new();
        let mut in_string = false;
        let mut escape_next = false;

        for ch in script.chars() {
            if escape_next {
                // Character following a backslash inside a string: keep as-is.
                current_token.push(ch);
                escape_next = false;
                continue;
            }
            match ch {
                '\\' if in_string => {
                    escape_next = true;
                    current_token.push(ch);
                }
                '"' if in_string => {
                    // Closing quote: the string literal is complete.
                    current_token.push(ch);
                    tokens.push(std::mem::take(&mut current_token));
                    in_string = false;
                }
                '"' => {
                    // Opening quote acts as a delimiter for whatever came before.
                    if !current_token.is_empty() {
                        tokens.push(std::mem::take(&mut current_token));
                    }
                    current_token.push(ch);
                    in_string = true;
                }
                '(' | ')' if !in_string => {
                    if !current_token.is_empty() {
                        tokens.push(std::mem::take(&mut current_token));
                    }
                    tokens.push(ch.to_string());
                }
                ch if ch.is_whitespace() && !in_string => {
                    if !current_token.is_empty() {
                        tokens.push(std::mem::take(&mut current_token));
                    }
                }
                _ => current_token.push(ch),
            }
        }

        if !current_token.is_empty() {
            tokens.push(current_token);
        }
        tokens
    }

    /// Parses every remaining top-level expression.
    ///
    /// # Errors
    /// Returns a message when a list is not closed or when a `)` appears
    /// without a matching `(`.
    fn parse(&mut self) -> Result<Vec<SExpr>, String> {
        let mut expressions = Vec::new();
        while self.position < self.tokens.len() {
            expressions.push(self.parse_expr()?);
        }
        Ok(expressions)
    }

    /// Parses one expression starting at the current position and advances
    /// past it.
    fn parse_expr(&mut self) -> Result<SExpr, String> {
        match self.tokens.get(self.position).map(String::as_str) {
            None => Err("Unexpected end of input".to_string()),
            Some("(") => {
                self.position += 1; // consume '('
                let mut elements = Vec::new();
                while self.position < self.tokens.len() && self.tokens[self.position] != ")" {
                    elements.push(self.parse_expr()?);
                }
                if self.position >= self.tokens.len() {
                    return Err("Missing closing parenthesis".to_string());
                }
                self.position += 1; // consume ')'
                Ok(SExpr::List(elements))
            }
            // A ')' can only be reached here when no list is open; inside a
            // list the loop above stops before handing it to this function.
            Some(")") => Err("Unexpected closing parenthesis".to_string()),
            Some(token) => {
                let atom = SExpr::Atom(token.to_string());
                self.position += 1;
                Ok(atom)
            }
        }
    }
}
/// Walks parsed S-expressions and records every column that is read inside a
/// mathematical operation.
#[derive(Debug)]
struct MathValidator {
    /// `(table, column)` pairs collected from column-access calls that appear
    /// somewhere beneath a math operator.
    column_references: Vec<(String, String)>,
}

impl MathValidator {
    /// Creates a validator with no recorded references.
    fn new() -> Self {
        Self {
            column_references: Vec::new(),
        }
    }

    /// Inspects each top-level expression, starting outside any math context.
    fn validate_expressions(&mut self, expressions: &[SExpr]) -> Result<(), String> {
        expressions
            .iter()
            .try_for_each(|expression| self.check_expression(expression, false))
    }

    /// Recursively inspects `expr`.
    ///
    /// `in_math_context` is true when some enclosing list is headed by one of
    /// `MATH_OPERATIONS`. Column-access calls are only recorded in that case.
    /// The head element of a list is never descended into.
    fn check_expression(&mut self, expr: &SExpr, in_math_context: bool) -> Result<(), String> {
        let items = match expr {
            SExpr::List(items) => items,
            SExpr::Atom(_) => return Ok(()),
        };
        let (head, args) = match items.split_first() {
            Some(parts) => parts,
            None => return Ok(()),
        };
        let head_name = match head {
            SExpr::Atom(name) => Some(name.as_str()),
            SExpr::List(_) => None,
        };

        if in_math_context {
            match head_name {
                Some("steel_get_column") => {
                    self.extract_column_reference_from_steel_get_column(items)?
                }
                Some("steel_get_column_with_index") => {
                    self.extract_column_reference_from_steel_get_column_with_index(items)?
                }
                _ => {}
            }
        }

        // Arguments inherit the math context and gain it if this list is a math op.
        let args_in_math = in_math_context
            || head_name.map_or(false, |name| MATH_OPERATIONS.contains(&name));
        for arg in args {
            self.check_expression(arg, args_in_math)?;
        }
        Ok(())
    }

    /// Records the reference from `(steel_get_column "table" "column")`.
    ///
    /// Calls with fewer than two arguments, or whose table/column argument is
    /// a nested list, are ignored. An atom that is not a quoted string is an
    /// error.
    fn extract_column_reference_from_steel_get_column(&mut self, elements: &[SExpr]) -> Result<(), String> {
        if let [_, SExpr::Atom(raw_table), SExpr::Atom(raw_column), ..] = elements {
            let table = self.unquote_string(raw_table)?;
            let column = self.unquote_string(raw_column)?;
            self.column_references.push((table, column));
        }
        Ok(())
    }

    /// Records the reference from
    /// `(steel_get_column_with_index "table" index "column")`.
    ///
    /// Same tolerance rules as the non-indexed variant; the index argument is
    /// not examined.
    fn extract_column_reference_from_steel_get_column_with_index(&mut self, elements: &[SExpr]) -> Result<(), String> {
        if let [_, SExpr::Atom(raw_table), _, SExpr::Atom(raw_column), ..] = elements {
            let table = self.unquote_string(raw_table)?;
            let column = self.unquote_string(raw_column)?;
            self.column_references.push((table, column));
        }
        Ok(())
    }

    /// Strips one pair of surrounding double quotes from `s`, or fails if `s`
    /// is not wrapped in them.
    fn unquote_string(&self, s: &str) -> Result<String, String> {
        s.strip_prefix('"')
            .and_then(|rest| rest.strip_suffix('"'))
            .map(str::to_string)
            .ok_or_else(|| format!("Expected quoted string, got: {}", s))
    }
}
/// Performs cheap structural checks on a script before any deeper validation.
///
/// Checks, in order:
/// 1. the script is not empty after trimming whitespace;
/// 2. parentheses are balanced, ignoring any that appear inside double-quoted
///    string literals (a backslash inside a string escapes the next
///    character, mirroring `Parser::tokenize`);
/// 3. the trimmed script starts with `(`.
///
/// # Errors
/// Returns `Status::invalid_argument` describing the first check that fails.
fn validate_script_basic_syntax(script: &str) -> Result<(), Status> {
    let trimmed = script.trim();
    if trimmed.is_empty() {
        return Err(Status::invalid_argument("Script cannot be empty"));
    }

    let mut depth: usize = 0;
    let mut in_string = false;
    let mut escape_next = false;
    for ch in trimmed.chars() {
        if in_string {
            // Inside a string literal nothing counts as structure; only track
            // where the literal ends.
            if escape_next {
                escape_next = false;
            } else if ch == '\\' {
                escape_next = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }
        match ch {
            '"' => in_string = true,
            '(' => depth += 1,
            ')' => {
                depth = depth.checked_sub(1).ok_or_else(|| {
                    Status::invalid_argument("Unbalanced parentheses: closing ')' without matching opening '('")
                })?;
            }
            _ => {}
        }
    }
    if depth != 0 {
        return Err(Status::invalid_argument("Unbalanced parentheses: missing closing parentheses"));
    }

    if !trimmed.starts_with('(') {
        return Err(Status::invalid_argument("Script must start with an opening parenthesis '('"));
    }
    Ok(())
}
/// Parses `script` as S-expressions and returns the `(table, column)` pairs
/// that are read inside mathematical operations.
///
/// # Errors
/// A parser failure is reported as `"Parse error: ..."`; a failure while
/// collecting references is reported as `"Validation error: ..."`.
fn extract_math_column_references(script: &str) -> Result<Vec<(String, String)>, String> {
    let expressions = match Parser::new(script).parse() {
        Ok(parsed) => parsed,
        Err(e) => return Err(format!("Parse error: {}", e)),
    };

    let mut validator = MathValidator::new();
    match validator.validate_expressions(&expressions) {
        Ok(()) => Ok(validator.column_references),
        Err(e) => Err(format!("Validation error: {}", e)),
    }
}
/// Ensures every column read inside a mathematical operation has a type that
/// is allowed in math (i.e. does not start with one of `MATH_PROHIBITED_TYPES`).
///
/// Referenced tables are looked up in `table_definitions` for `schema_id` with
/// a single query. Each stored column definition is split on whitespace: the
/// first word (with surrounding `"` removed) is the column name, the second is
/// its type.
///
/// # Errors
/// * `invalid_argument` if the script cannot be parsed, if it references a
///   table or column that is not found, or if a referenced column has a
///   prohibited type.
/// * `internal` if the query fails or stored column data cannot be decoded.
async fn validate_math_operations_column_types(
    db_pool: &PgPool,
    schema_id: i64,
    script: &str,
) -> Result<(), Status> {
    let math_refs = match extract_math_column_references(script) {
        Ok(refs) => refs,
        Err(e) => {
            return Err(Status::invalid_argument(format!("Script parsing failed: {}", e)));
        }
    };
    if math_refs.is_empty() {
        // Nothing column-related happens inside math operations.
        return Ok(());
    }

    // De-duplicate table names before handing them to the query.
    let referenced_tables: Vec<String> = math_refs
        .iter()
        .map(|(table, _)| table.clone())
        .collect::<HashSet<String>>()
        .into_iter()
        .collect();

    // The SQL literal's continuation line is deliberately left unindented so
    // the query text itself is unchanged.
    let rows = sqlx::query!(
        r#"SELECT table_name, columns FROM table_definitions
WHERE schema_id = $1 AND table_name = ANY($2)"#,
        schema_id,
        &referenced_tables
    )
    .fetch_all(db_pool)
    .await
    .map_err(|e| Status::internal(format!("Failed to fetch table definitions: {}", e)))?;

    // table name -> (column name -> declared type)
    let mut types_by_table: HashMap<String, HashMap<String, String>> = HashMap::new();
    for row in rows {
        let table_name = row.table_name;
        let column_defs: Vec<String> = serde_json::from_value(row.columns)
            .map_err(|e| Status::internal(format!("Invalid column data for table '{}': {}", table_name, e)))?;

        let column_types: HashMap<String, String> = column_defs
            .iter()
            .filter_map(|definition| {
                let mut words = definition.split_whitespace();
                let name = words.next()?;
                let data_type = words.next()?;
                Some((name.trim_matches('"').to_string(), data_type.to_string()))
            })
            .collect();
        types_by_table.insert(table_name, column_types);
    }

    for (table_name, column_name) in math_refs {
        let table_columns = match types_by_table.get(&table_name) {
            Some(columns) => columns,
            None => {
                return Err(Status::invalid_argument(format!(
                    "Script references table '{}' in mathematical operations but this table does not exist in this schema",
                    table_name
                )));
            }
        };
        let column_type = match table_columns.get(&column_name) {
            Some(declared) => declared,
            None => {
                return Err(Status::invalid_argument(format!(
                    "Script references column '{}' in table '{}' but this column does not exist",
                    column_name,
                    table_name
                )));
            }
        };

        let normalized_type = normalize_data_type(column_type);
        let prohibited_in_math = MATH_PROHIBITED_TYPES
            .iter()
            .any(|&prohibited| normalized_type.starts_with(prohibited));
        if prohibited_in_math {
            return Err(Status::invalid_argument(format!(
                "Cannot use column '{}' of type '{}' from table '{}' in mathematical operations. Mathematical operations cannot use columns of type: {}",
                column_name,
                column_type,
                table_name,
                MATH_PROHIBITED_TYPES.join(", ")
            )));
        }
    }
    Ok(())
}
/// Checks that `target` may receive a script and returns its declared type.
///
/// `table_columns` is expected to be a JSON array of strings, each of the form
/// `<name> <type> ...`; the name may be wrapped in double quotes. Entries with
/// fewer than two whitespace-separated words are ignored.
///
/// # Errors
/// Returns a message when `target` is a system column, when the column data
/// cannot be decoded, when `target` is not defined in the table, or when its
/// type is prohibited.
fn validate_target_column(
    table_name: &str,
    target: &str,
    table_columns: &Value,
) -> Result<String, String> {
    if SYSTEM_COLUMNS.iter().any(|&system| system == target) {
        return Err(format!("Cannot override system column: {}", target));
    }

    let column_defs: Vec<String> = match serde_json::from_value(table_columns.clone()) {
        Ok(defs) => defs,
        Err(e) => return Err(format!("Invalid column data: {}", e)),
    };

    // First well-formed definition whose name equals the target wins.
    let declared_type = column_defs.iter().find_map(|definition| {
        let mut words = definition.split_whitespace();
        let name = words.next()?.trim_matches('"');
        let data_type = words.next()?;
        if name == target {
            Some(data_type.to_string())
        } else {
            None
        }
    });
    let column_type = match declared_type {
        Some(found) => found,
        None => {
            return Err(format!("Target column '{}' not defined in table '{}'", target, table_name));
        }
    };

    if is_prohibited_type(&column_type) {
        return Err(format!(
            "Cannot create script for column '{}' with type '{}'. Steel scripts cannot target columns of type: {}",
            target,
            column_type,
            PROHIBITED_TYPES.join(", ")
        ));
    }

    // Boolean targets are allowed; just leave a note on stdout.
    if matches!(normalize_data_type(&column_type).as_str(), "BOOLEAN" | "BOOL") {
        println!("Info: Target column '{}' is boolean type. Values will be converted to Steel format (#true/#false)", target);
    }

    Ok(column_type)
}
/// Returns `true` when the normalized form of `data_type` begins with any
/// entry of `PROHIBITED_TYPES`.
///
/// BOOLEAN/BOOL is not in that list and therefore passes.
fn is_prohibited_type(data_type: &str) -> bool {
    let normalized = normalize_data_type(data_type);
    for prefix in PROHIBITED_TYPES.iter().copied() {
        if normalized.starts_with(prefix) {
            return true;
        }
    }
    false
}
/// Produces a canonical form of a type name for comparisons: upper-cased,
/// truncated at the first `(` (dropping e.g. the precision/scale of
/// `NUMERIC(x,y)`), and trimmed of surrounding whitespace.
fn normalize_data_type(data_type: &str) -> String {
    let upper = data_type.to_uppercase();
    let base = match upper.find('(') {
        Some(paren_at) => &upper[..paren_at],
        None => upper.as_str(),
    };
    base.trim().to_string()
}
/// Parse Steel script to extract all table/column references
fn extract_column_references_from_script(script: &str) -> Vec<(String, String)> {
let mut references = Vec::new();
// Regex patterns to match Steel function calls
let patterns = [
// (steel_get_column "table_name" "column_name")
r#"\(steel_get_column\s+"([^"]+)"\s+"([^"]+)"\)"#,
// (steel_get_column_with_index "table_name" index "column_name")
r#"\(steel_get_column_with_index\s+"([^"]+)"\s+\d+\s+"([^"]+)"\)"#,
];
for pattern in &patterns {
if let Ok(re) = Regex::new(pattern) {
for cap in re.captures_iter(script) {
if let (Some(table), Some(column)) = (cap.get(1), cap.get(2)) {
references.push((table.as_str().to_string(), column.as_str().to_string()));
}
}
}
}
// Also check for steel_get_column_with_index pattern (table, column are in different positions)
if let Ok(re) = Regex::new(r#"\(steel_get_column_with_index\s+"([^"]+)"\s+\d+\s+"([^"]+)"\)"#) {
for cap in re.captures_iter(script) {
if let (Some(table), Some(column)) = (cap.get(1), cap.get(2)) {
references.push((table.as_str().to_string(), column.as_str().to_string()));
}
}
}
references
}
/// Verifies that every table/column the script accesses exists in the schema
/// and that none of those columns has a prohibited type.
///
/// Each distinct referenced table is looked up in `table_definitions` with its
/// own query; the columns referenced for that table are then checked against
/// the stored definition.
///
/// # Errors
/// * `invalid_argument` if a referenced table is not found or a column fails
///   `validate_referenced_column_type`.
/// * `internal` if a lookup query fails.
async fn validate_script_column_references(
    db_pool: &PgPool,
    schema_id: i64,
    script: &str,
) -> Result<(), Status> {
    let references = extract_column_references_from_script(script);
    if references.is_empty() {
        // The script does not touch any columns.
        return Ok(());
    }

    let distinct_tables: HashSet<String> = references
        .iter()
        .map(|(table, _)| table.clone())
        .collect();

    for table_name in distinct_tables {
        // The SQL literal's continuation line is deliberately left unindented
        // so the query text itself is unchanged.
        let lookup = sqlx::query!(
            r#"SELECT table_name, columns FROM table_definitions
WHERE schema_id = $1 AND table_name = $2"#,
            schema_id,
            table_name
        )
        .fetch_optional(db_pool)
        .await
        .map_err(|e| Status::internal(format!("Failed to fetch table definition for '{}': {}", table_name, e)))?;

        let table_def = match lookup {
            Some(found) => found,
            None => {
                return Err(Status::invalid_argument(format!(
                    "Script references table '{}' which does not exist in this schema",
                    table_name
                )));
            }
        };

        // Only the references that point at this table are relevant here.
        let columns_in_table = references
            .iter()
            .filter(|(ref_table, _)| ref_table == &table_name)
            .map(|(_, ref_column)| ref_column);
        for ref_column in columns_in_table {
            validate_referenced_column_type(&table_name, ref_column, &table_def.columns)
                .map_err(|error_msg| Status::invalid_argument(error_msg))?;
        }
    }
    Ok(())
}
/// Checks a single column reference against a table's stored column list.
///
/// `table_columns` is expected to be a JSON array of strings, each of the form
/// `<name> <type> ...`; the name may be wrapped in double quotes. Entries with
/// fewer than two whitespace-separated words are ignored.
///
/// # Errors
/// Returns a message when the column data cannot be decoded, when the column
/// is not found, or when its type is prohibited.
fn validate_referenced_column_type(table_name: &str, column_name: &str, table_columns: &Value) -> Result<(), String> {
    let column_defs: Vec<String> = serde_json::from_value(table_columns.clone())
        .map_err(|e| format!("Invalid column data for table '{}': {}", table_name, e))?;

    // First well-formed definition whose name equals the reference wins.
    let declared_type = column_defs.iter().find_map(|definition| {
        let mut words = definition.split_whitespace();
        let name = words.next()?.trim_matches('"');
        let data_type = words.next()?;
        if name == column_name {
            Some(data_type)
        } else {
            None
        }
    });

    let column_type = match declared_type {
        Some(found) => found,
        None => {
            return Err(format!(
                "Script references column '{}' in table '{}' but this column does not exist",
                column_name,
                table_name
            ));
        }
    };

    if is_prohibited_type(column_type) {
        return Err(format!(
            "Script references column '{}' in table '{}' which has prohibited type '{}'. Steel scripts cannot access columns of type: {}",
            column_name,
            table_name,
            column_type,
            PROHIBITED_TYPES.join(", ")
        ));
    }

    // Boolean columns are allowed; just leave a note on stdout.
    if matches!(normalize_data_type(column_type).as_str(), "BOOLEAN" | "BOOL") {
        println!("Info: Script references boolean column '{}' in table '{}'. Values will be converted to Steel format (#true/#false)", column_name, table_name);
    }
    Ok(())
}
/// Heuristically rejects scripts whose embedded SQL appears to operate on
/// prohibited types.
///
/// Every `(steel_query_sql "<query>")` call with a literal query is
/// upper-cased and searched for a fixed list of substrings (date extraction
/// functions, casts to prohibited types, and any `CAST(`). The first query
/// containing one of them causes an error that quotes the original query.
fn validate_sql_queries_in_script(script: &str) -> Result<(), String> {
    // Substrings that suggest an operation on a prohibited type.
    const SUSPICIOUS_FRAGMENTS: [&str; 6] = [
        "EXTRACT(",
        "DATE_PART(",
        "::DATE",
        "::TIMESTAMPTZ",
        "::BIGINT",
        "CAST(", // Could be casting to prohibited types
    ];

    let call_pattern = match Regex::new(r#"\(steel_query_sql\s+"([^"]+)"\)"#) {
        Ok(compiled) => compiled,
        // Without a usable pattern there is nothing to inspect.
        Err(_) => return Ok(()),
    };

    for captures in call_pattern.captures_iter(script) {
        let query = match captures.get(1) {
            Some(matched) => matched.as_str(),
            None => continue,
        };
        let upper_sql = query.to_uppercase();
        let looks_prohibited = SUSPICIOUS_FRAGMENTS
            .iter()
            .any(|&fragment| upper_sql.contains(fragment));
        if looks_prohibited {
            return Err(format!(
                "Script contains SQL query with potentially prohibited type operations: '{}'. Steel scripts cannot use operations on types: {}",
                query,
                PROHIBITED_TYPES.join(", ")
            ));
        }
    }
    Ok(())
}
/// Handles the creation of a new table script with dependency validation.
///
/// Steps, in order:
/// 1. basic syntax check of `request.script`;
/// 2. open a transaction and load the table definition identified by
///    `request.table_definition_id`;
/// 3. validate the target column, math-context column types, general column
///    references, and embedded SQL;
/// 4. analyze the script's dependencies and check them for cycles;
/// 5. transform the script with `SteelDecimal` and upsert it into
///    `table_scripts`;
/// 6. save the dependencies and commit;
/// 7. build the warning text returned alongside the new script id.
///
/// # Errors
/// * `invalid_argument` for any validation failure.
/// * `not_found` when the table definition does not exist.
/// * `internal` for transaction, query, or commit failures.
/// * Errors from the dependency analyzer are converted with `Status::from`.
pub async fn post_table_script(
db_pool: &PgPool,
request: PostTableScriptRequest,
) -> Result<TableScriptResponse, Status> {
// Basic script validation first
validate_script_basic_syntax(&request.script)?;
// Start a transaction for ALL operations - critical for atomicity
let mut tx = db_pool.begin().await
.map_err(|e| Status::internal(format!("Failed to start transaction: {}", e)))?;
// Fetch the table definition
let table_def = sqlx::query!(
r#"SELECT id, table_name, columns, schema_id
FROM table_definitions WHERE id = $1"#,
request.table_definition_id
)
.fetch_optional(&mut *tx)
.await
.map_err(|e| Status::internal(format!("Failed to fetch table definition: {}", e)))?
.ok_or_else(|| Status::not_found("Table definition not found"))?;
// Validate the target column and get its type (includes prohibited type check)
let column_type = validate_target_column(
&table_def.table_name,
&request.target_column,
&table_def.columns,
)
.map_err(|e| Status::invalid_argument(e))?;
// REORDER: Math validation FIRST so we get specific error messages for math operations
// NOTE(review): this and the next validation are handed `db_pool`, not `tx`,
// so their lookups run outside the transaction opened above — confirm that is intended.
validate_math_operations_column_types(db_pool, table_def.schema_id, &request.script).await?;
// THEN general column validation (catches non-math prohibited access)
validate_script_column_references(db_pool, table_def.schema_id, &request.script).await?;
// Validate SQL queries in script for prohibited type operations
validate_sql_queries_in_script(&request.script)
.map_err(|e| Status::invalid_argument(e))?;
// Create dependency analyzer for this schema
let analyzer = DependencyAnalyzer::new(table_def.schema_id, db_pool.clone());
// Analyze script dependencies
let dependencies = analyzer
.analyze_script_dependencies(&request.script)
.map_err(|e| Status::from(e))?;
// Check for circular dependencies BEFORE making any changes
// Pass the transaction to ensure we see any existing dependencies
analyzer
.check_for_cycles(&mut tx, table_def.id, &dependencies)
.await
.map_err(|e| Status::from(e))?;
// Transform the script using steel_decimal (this happens AFTER validation)
// The transformed text, not the submitted one, is what gets stored below.
let steel_decimal = SteelDecimal::new();
let parsed_script = steel_decimal.transform(&request.script);
// Insert or update the script
// On a (table_definitions_id, target_column) conflict the existing row's
// script, description and target_column_type are overwritten.
let script_record = sqlx::query!(
r#"INSERT INTO table_scripts
(table_definitions_id, target_table, target_column,
target_column_type, script, description, schema_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (table_definitions_id, target_column)
DO UPDATE SET
script = EXCLUDED.script,
description = EXCLUDED.description,
target_column_type = EXCLUDED.target_column_type
RETURNING id"#,
request.table_definition_id,
table_def.table_name,
request.target_column,
column_type,
parsed_script,
request.description,
table_def.schema_id
)
.fetch_one(&mut *tx)
.await
.map_err(|e| {
match e {
// NOTE(review): the ON CONFLICT clause above targets the same column pair,
// so this branch presumably only matters if the constraint is hit some
// other way — verify whether it is still reachable.
SqlxError::Database(db_err)
if db_err.constraint() == Some("table_scripts_table_definitions_id_target_column_key") => {
Status::already_exists("Script already exists for this column")
}
_ => Status::internal(format!("Failed to insert script: {}", e)),
}
})?;
// Save the dependencies within the same transaction
analyzer
.save_dependencies(&mut tx, script_record.id, table_def.id, &dependencies)
.await
.map_err(|e| Status::from(e))?;
// Only now commit the entire transaction - script + dependencies together
tx.commit().await
.map_err(|e| Status::internal(format!("Failed to commit transaction: {}", e)))?;
// Generate warnings for potential issues
let warnings = generate_warnings(&dependencies, &table_def.table_name);
Ok(TableScriptResponse {
id: script_record.id,
warnings,
})
}
/// Builds the advisory text returned with a saved script.
///
/// Messages are emitted in a fixed order and joined with single spaces:
/// a self-reference warning, a raw-SQL warning, a warning when there are more
/// than five dependencies, and an informational note about structured
/// (column / indexed) access. Returns an empty string when none apply.
fn generate_warnings(dependencies: &[crate::table_script::handlers::dependency_analyzer::Dependency], table_name: &str) -> String {
    use crate::table_script::handlers::dependency_analyzer::DependencyType;

    let mut messages: Vec<String> = Vec::new();

    // Self-reference: some dependency targets the script's own table.
    let references_own_table = dependencies
        .iter()
        .any(|dep| dep.target_table == table_name);
    if references_own_table {
        messages.push("Warning: Script references its own table, which may cause issues during initial population.".to_string());
    }

    // Raw SQL usage.
    let raw_sql_total = dependencies
        .iter()
        .filter(|dep| matches!(dep.dependency_type, DependencyType::SqlQuery { .. }))
        .count();
    if raw_sql_total > 0 {
        let plural_suffix = if raw_sql_total == 1 { "y" } else { "ies" };
        messages.push(format!(
            "Warning: Script contains {} raw SQL quer{}, ensure they are read-only and reference valid tables.",
            raw_sql_total,
            plural_suffix
        ));
    }

    // Large dependency count.
    let dependency_total = dependencies.len();
    if dependency_total > 5 {
        messages.push(format!(
            "Warning: Script depends on {} tables, which may affect processing performance.",
            dependency_total
        ));
    }

    // Structured access through the column helper functions.
    let structured_total = dependencies
        .iter()
        .filter(|dep| {
            matches!(
                dep.dependency_type,
                DependencyType::ColumnAccess { .. } | DependencyType::IndexedAccess { .. }
            )
        })
        .count();
    if structured_total > 0 {
        let plural_suffix = if structured_total == 1 { "" } else { "s" };
        messages.push(format!(
            "Info: Script uses {} linked table{} via steel_get_column functions.",
            structured_total,
            plural_suffix
        ));
    }

    messages.join(" ")
}