steel scripts now have far better logic than before
This commit is contained in:
@@ -2,10 +2,11 @@
|
||||
use steel::steel_vm::engine::Engine;
|
||||
use steel::steel_vm::register_fn::RegisterFn;
|
||||
use steel::rvals::SteelVal;
|
||||
use super::functions::SteelContext;
|
||||
use super::functions::{SteelContext, convert_row_data_for_steel};
|
||||
use steel_decimal::registry::FunctionRegistry;
|
||||
use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -25,6 +26,74 @@ pub enum ExecutionError {
|
||||
UnsupportedType(String),
|
||||
}
|
||||
|
||||
/// Create a SteelContext with boolean conversion applied to row data
|
||||
pub async fn create_steel_context_with_boolean_conversion(
|
||||
current_table: String,
|
||||
schema_id: i64,
|
||||
schema_name: String,
|
||||
mut row_data: HashMap<String, String>,
|
||||
db_pool: Arc<PgPool>,
|
||||
) -> Result<SteelContext, ExecutionError> {
|
||||
// Convert boolean values in row_data to Steel format
|
||||
convert_row_data_for_steel(&db_pool, schema_id, ¤t_table, &mut row_data)
|
||||
.await
|
||||
.map_err(|e| ExecutionError::RuntimeError(format!("Failed to convert row data: {}", e)))?;
|
||||
|
||||
Ok(SteelContext {
|
||||
current_table,
|
||||
schema_id,
|
||||
schema_name,
|
||||
row_data,
|
||||
db_pool,
|
||||
})
|
||||
}
|
||||
|
||||
/// Execute script with proper boolean handling
|
||||
pub async fn execute_script_with_boolean_support(
|
||||
script: String,
|
||||
target_type: &str,
|
||||
db_pool: Arc<PgPool>,
|
||||
schema_id: i64,
|
||||
schema_name: String,
|
||||
current_table: String,
|
||||
row_data: HashMap<String, String>,
|
||||
) -> Result<Value, ExecutionError> {
|
||||
let mut vm = Engine::new();
|
||||
|
||||
// Create context with boolean conversion
|
||||
let context = create_steel_context_with_boolean_conversion(
|
||||
current_table,
|
||||
schema_id,
|
||||
schema_name,
|
||||
row_data,
|
||||
db_pool.clone(),
|
||||
).await?;
|
||||
|
||||
let context = Arc::new(context);
|
||||
|
||||
// Register existing Steel functions
|
||||
register_steel_functions(&mut vm, context.clone());
|
||||
|
||||
// Register all decimal math functions using the steel_decimal crate
|
||||
register_decimal_math_functions(&mut vm);
|
||||
|
||||
// Register variables from the context with the Steel VM
|
||||
// The row_data now contains Steel-formatted boolean values
|
||||
FunctionRegistry::register_variables(&mut vm, context.row_data.clone());
|
||||
|
||||
// Execute script and process results
|
||||
let results = vm.compile_and_run_raw_program(script)
|
||||
.map_err(|e| ExecutionError::RuntimeError(e.to_string()))?;
|
||||
|
||||
// Convert results to target type
|
||||
match target_type {
|
||||
"STRINGS" => process_string_results(results),
|
||||
_ => Err(ExecutionError::UnsupportedType(target_type.into()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Original execute_script function (kept for backward compatibility)
|
||||
/// Note: This doesn't include boolean conversion - use execute_script_with_boolean_support for new code
|
||||
pub fn execute_script(
|
||||
script: String,
|
||||
target_type: &str,
|
||||
|
||||
@@ -16,8 +16,13 @@ pub enum FunctionError {
|
||||
TableNotFound(String),
|
||||
#[error("Database error: {0}")]
|
||||
DatabaseError(String),
|
||||
#[error("Prohibited data type access: {0}")]
|
||||
ProhibitedTypeAccess(String),
|
||||
}
|
||||
|
||||
/// Column data types that Steel scripts are not allowed to access.
///
/// Boolean types are intentionally absent from this list.
const PROHIBITED_TYPES: &[&str] = &[
    "BIGINT",
    "DATE",
    "TIMESTAMPTZ",
];
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SteelContext {
|
||||
pub current_table: String,
|
||||
@@ -43,10 +48,102 @@ impl SteelContext {
|
||||
Ok(table_def.table_name)
|
||||
}
|
||||
|
||||
/// Get column type for a given table and column
|
||||
async fn get_column_type(&self, table_name: &str, column_name: &str) -> Result<String, FunctionError> {
|
||||
let table_def = sqlx::query!(
|
||||
r#"SELECT columns FROM table_definitions
|
||||
WHERE schema_id = $1 AND table_name = $2"#,
|
||||
self.schema_id,
|
||||
table_name
|
||||
)
|
||||
.fetch_optional(&*self.db_pool)
|
||||
.await
|
||||
.map_err(|e| FunctionError::DatabaseError(e.to_string()))?
|
||||
.ok_or_else(|| FunctionError::TableNotFound(table_name.to_string()))?;
|
||||
|
||||
// Parse columns JSON to find the column type
|
||||
let columns: Vec<String> = serde_json::from_value(table_def.columns)
|
||||
.map_err(|e| FunctionError::DatabaseError(format!("Invalid column data: {}", e)))?;
|
||||
|
||||
// Find the column and its type
|
||||
for column_def in columns {
|
||||
let mut parts = column_def.split_whitespace();
|
||||
if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) {
|
||||
let column_name_clean = name.trim_matches('"');
|
||||
if column_name_clean == column_name {
|
||||
return Ok(data_type.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(FunctionError::ColumnNotFound(format!(
|
||||
"Column '{}' not found in table '{}'",
|
||||
column_name,
|
||||
table_name
|
||||
)))
|
||||
}
|
||||
|
||||
/// Convert database value to Steel format based on column type
|
||||
fn convert_value_to_steel_format(&self, value: &str, column_type: &str) -> String {
|
||||
let normalized_type = normalize_data_type(column_type);
|
||||
|
||||
match normalized_type.as_str() {
|
||||
"BOOLEAN" | "BOOL" => {
|
||||
// Convert database boolean to Steel boolean syntax
|
||||
match value.to_lowercase().as_str() {
|
||||
"true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
|
||||
"false" | "f" | "0" | "no" | "off" => "#false".to_string(),
|
||||
_ => value.to_string(), // Return as-is if not a recognized boolean
|
||||
}
|
||||
}
|
||||
_ => value.to_string(), // Return as-is for non-boolean types
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate column type and return the column type if valid
|
||||
async fn validate_column_type_and_get_type(&self, table_name: &str, column_name: &str) -> Result<String, FunctionError> {
|
||||
let column_type = self.get_column_type(table_name, column_name).await?;
|
||||
|
||||
// Check if this type is prohibited
|
||||
if is_prohibited_type(&column_type) {
|
||||
return Err(FunctionError::ProhibitedTypeAccess(format!(
|
||||
"Cannot access column '{}' in table '{}' because it has prohibited type '{}'. Steel scripts cannot access columns of type: {}",
|
||||
column_name,
|
||||
table_name,
|
||||
column_type,
|
||||
PROHIBITED_TYPES.join(", ")
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(column_type)
|
||||
}
|
||||
|
||||
/// Validate column type before access (legacy method for backward compatibility)
|
||||
async fn validate_column_type(&self, table_name: &str, column_name: &str) -> Result<(), FunctionError> {
|
||||
self.validate_column_type_and_get_type(table_name, column_name).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn steel_get_column(&self, table: &str, column: &str) -> Result<SteelVal, SteelVal> {
|
||||
if table == self.current_table {
|
||||
// Validate column type for current table access and get the type
|
||||
let column_type = tokio::task::block_in_place(|| {
|
||||
let handle = tokio::runtime::Handle::current();
|
||||
handle.block_on(async {
|
||||
self.validate_column_type_and_get_type(table, column).await
|
||||
})
|
||||
});
|
||||
|
||||
let column_type = match column_type {
|
||||
Ok(ct) => ct,
|
||||
Err(e) => return Err(SteelVal::StringV(e.to_string().into())),
|
||||
};
|
||||
|
||||
return self.row_data.get(column)
|
||||
.map(|v| SteelVal::StringV(v.clone().into()))
|
||||
.map(|v| {
|
||||
let converted_value = self.convert_value_to_steel_format(v, &column_type);
|
||||
SteelVal::StringV(converted_value.into())
|
||||
})
|
||||
.ok_or_else(|| SteelVal::StringV(format!("Column {} not found", column).into()));
|
||||
}
|
||||
|
||||
@@ -65,15 +162,23 @@ impl SteelContext {
|
||||
let actual_table = self.get_related_table_name(base_name).await
|
||||
.map_err(|e| SteelVal::StringV(e.to_string().into()))?;
|
||||
|
||||
// Add quotes around the table name
|
||||
sqlx::query_scalar::<_, String>(
|
||||
// Get column type for validation and conversion
|
||||
let column_type = self.validate_column_type_and_get_type(&actual_table, column).await
|
||||
.map_err(|e| SteelVal::StringV(e.to_string().into()))?;
|
||||
|
||||
// Fetch the raw value from database
|
||||
let raw_value = sqlx::query_scalar::<_, String>(
|
||||
&format!("SELECT {} FROM \"{}\".\"{}\" WHERE id = $1", column, self.schema_name, actual_table)
|
||||
)
|
||||
.bind(fk_value.parse::<i64>().map_err(|_|
|
||||
SteelVal::StringV("Invalid foreign key format".into()))?)
|
||||
.fetch_one(&*self.db_pool)
|
||||
.await
|
||||
.map_err(|e| SteelVal::StringV(e.to_string().into()))
|
||||
.map_err(|e| SteelVal::StringV(e.to_string().into()))?;
|
||||
|
||||
// Convert to Steel format
|
||||
let converted_value = self.convert_value_to_steel_format(&raw_value, &column_type);
|
||||
Ok(converted_value)
|
||||
})
|
||||
});
|
||||
|
||||
@@ -86,12 +191,37 @@ impl SteelContext {
|
||||
index: i64,
|
||||
column: &str
|
||||
) -> Result<SteelVal, SteelVal> {
|
||||
// Get the full value first (this already handles type conversion)
|
||||
let value = self.steel_get_column(table, column)?;
|
||||
|
||||
if let SteelVal::StringV(s) = value {
|
||||
let parts: Vec<_> = s.split(',').collect();
|
||||
parts.get(index as usize)
|
||||
.map(|v| SteelVal::StringV(v.trim().into()))
|
||||
.ok_or_else(|| SteelVal::StringV("Index out of bounds".into()))
|
||||
|
||||
if let Some(part) = parts.get(index as usize) {
|
||||
let trimmed_part = part.trim();
|
||||
|
||||
// If the original column was boolean type, each part should also be treated as boolean
|
||||
// We need to get the column type to determine if conversion is needed
|
||||
let column_type = tokio::task::block_in_place(|| {
|
||||
let handle = tokio::runtime::Handle::current();
|
||||
handle.block_on(async {
|
||||
self.get_column_type(table, column).await
|
||||
})
|
||||
});
|
||||
|
||||
match column_type {
|
||||
Ok(ct) => {
|
||||
let converted_part = self.convert_value_to_steel_format(trimmed_part, &ct);
|
||||
Ok(SteelVal::StringV(converted_part.into()))
|
||||
}
|
||||
Err(_) => {
|
||||
// If we can't get the type, return as-is
|
||||
Ok(SteelVal::StringV(trimmed_part.into()))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Err(SteelVal::StringV("Index out of bounds".into()))
|
||||
}
|
||||
} else {
|
||||
Err(SteelVal::StringV("Expected comma-separated string".into()))
|
||||
}
|
||||
@@ -105,6 +235,14 @@ impl SteelContext {
|
||||
));
|
||||
}
|
||||
|
||||
// Check if query might access prohibited columns
|
||||
if contains_prohibited_column_access(query) {
|
||||
return Err(SteelVal::StringV(format!(
|
||||
"SQL query may access prohibited column types. Steel scripts cannot access columns of type: {}",
|
||||
PROHIBITED_TYPES.join(", ")
|
||||
).into()));
|
||||
}
|
||||
|
||||
let pool = self.db_pool.clone();
|
||||
|
||||
// Use `tokio::task::block_in_place` to safely block the thread
|
||||
@@ -132,9 +270,87 @@ impl SteelContext {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a data type is prohibited for Steel scripts
|
||||
fn is_prohibited_type(data_type: &str) -> bool {
|
||||
let normalized_type = normalize_data_type(data_type);
|
||||
PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited))
|
||||
}
|
||||
|
||||
/// Canonicalizes a column type name for comparisons.
///
/// The name is upper-cased, everything from the first `(` onwards is dropped
/// (so `NUMERIC(10,2)` becomes `NUMERIC`), and surrounding whitespace is trimmed.
fn normalize_data_type(data_type: &str) -> String {
    let upper = data_type.to_uppercase();
    let base = match upper.find('(') {
        Some(paren) => &upper[..paren],
        None => upper.as_str(),
    };
    base.trim().to_string()
}
|
||||
|
||||
/// Heuristically detects SQL that may touch prohibited column types.
///
/// The query is upper-cased and searched for a fixed set of fragments
/// (date-extraction calls and `::` casts to the prohibited types). This is a
/// plain substring test, not a SQL parser, so it can over- and under-match.
fn contains_prohibited_column_access(query: &str) -> bool {
    const SUSPICIOUS_FRAGMENTS: [&str; 5] = [
        "EXTRACT(",   // typically applied to DATE/TIMESTAMPTZ values
        "DATE_PART(", // typically applied to DATE/TIMESTAMPTZ values
        "::DATE",
        "::TIMESTAMPTZ",
        "::BIGINT",
    ];

    let upper = query.to_uppercase();
    for fragment in SUSPICIOUS_FRAGMENTS.iter() {
        if upper.contains(*fragment) {
            return true;
        }
    }
    false
}
|
||||
|
||||
/// Reports whether `query` looks like a statement that only reads data.
///
/// Leading whitespace is ignored and matching is case-insensitive. Statements
/// starting with `SELECT` or `SHOW` are accepted. `EXPLAIN` is accepted only when
/// the statement does not contain the `ANALYZE`/`ANALYSE` keyword: with that
/// option PostgreSQL actually executes the explained statement, so
/// `EXPLAIN ANALYZE DELETE ...` would modify data.
///
/// This is a prefix/keyword heuristic, not a SQL parser; the `ANALYZE` check is
/// deliberately conservative and also rejects `EXPLAIN ANALYZE SELECT ...`.
fn is_read_only_query(query: &str) -> bool {
    let query = query.trim_start().to_uppercase();

    if query.starts_with("SELECT") || query.starts_with("SHOW") {
        return true;
    }
    if !query.starts_with("EXPLAIN") {
        return false;
    }

    // Split into identifier-like words so `(ANALYZE, VERBOSE)` is detected too.
    let executes_statement = query
        .split(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
        .any(|word| word == "ANALYZE" || word == "ANALYSE");

    !executes_statement
}
|
||||
|
||||
/// Helper function to convert initial row data for boolean columns
|
||||
pub async fn convert_row_data_for_steel(
|
||||
db_pool: &PgPool,
|
||||
schema_id: i64,
|
||||
table_name: &str,
|
||||
row_data: &mut HashMap<String, String>,
|
||||
) -> Result<(), sqlx::Error> {
|
||||
// Get table definition to check column types
|
||||
let table_def = sqlx::query!(
|
||||
r#"SELECT columns FROM table_definitions
|
||||
WHERE schema_id = $1 AND table_name = $2"#,
|
||||
schema_id,
|
||||
table_name
|
||||
)
|
||||
.fetch_optional(db_pool)
|
||||
.await?
|
||||
.ok_or_else(|| sqlx::Error::RowNotFound)?;
|
||||
|
||||
// Parse columns to find boolean types
|
||||
if let Ok(columns) = serde_json::from_value::<Vec<String>>(table_def.columns) {
|
||||
for column_def in columns {
|
||||
let mut parts = column_def.split_whitespace();
|
||||
if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) {
|
||||
let column_name = name.trim_matches('"');
|
||||
let normalized_type = normalize_data_type(data_type);
|
||||
|
||||
// Fixed: Use traditional if let instead of let chains
|
||||
if normalized_type == "BOOLEAN" || normalized_type == "BOOL" {
|
||||
if let Some(value) = row_data.get_mut(column_name) {
|
||||
// Convert boolean value to Steel format
|
||||
*value = match value.to_lowercase().as_str() {
|
||||
"true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
|
||||
"false" | "f" | "0" | "no" | "off" => "#false".to_string(),
|
||||
_ => value.clone(), // Keep original if not recognized
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// src/table_script/handlers/dependency_analyzer.rs
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::collections::HashMap;
|
||||
use tonic::Status;
|
||||
use sqlx::PgPool;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
@@ -6,12 +6,17 @@ use sqlx::{PgPool, Error as SqlxError};
|
||||
use common::proto::multieko2::table_script::{PostTableScriptRequest, TableScriptResponse};
|
||||
use serde_json::Value;
|
||||
use steel_decimal::SteelDecimal;
|
||||
use regex::Regex;
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::table_script::handlers::dependency_analyzer::DependencyAnalyzer;
|
||||
|
||||
/// Names of the system columns.
const SYSTEM_COLUMNS: &[&str] = &[
    "id",
    "deleted",
    "created_at",
];
|
||||
|
||||
/// Validates the target column and ensures it is not a system column.
|
||||
/// Column data types that Steel scripts may neither target nor reference.
///
/// Boolean types are intentionally absent from this list.
const PROHIBITED_TYPES: &[&str] = &[
    "BIGINT",
    "DATE",
    "TIMESTAMPTZ",
];
|
||||
|
||||
/// Validates the target column and ensures it is not a system column or prohibited type.
|
||||
/// Returns the column type if valid.
|
||||
fn validate_target_column(
|
||||
table_name: &str,
|
||||
@@ -38,11 +43,211 @@ fn validate_target_column(
|
||||
.collect();
|
||||
|
||||
// Find the target column and return its type
|
||||
column_info
|
||||
let column_type = column_info
|
||||
.iter()
|
||||
.find(|(name, _)| *name == target)
|
||||
.map(|(_, dt)| dt.to_string())
|
||||
.ok_or_else(|| format!("Target column '{}' not defined in table '{}'", target, table_name))
|
||||
.ok_or_else(|| format!("Target column '{}' not defined in table '{}'", target, table_name))?;
|
||||
|
||||
// Check if the target column type is prohibited
|
||||
if is_prohibited_type(&column_type) {
|
||||
return Err(format!(
|
||||
"Cannot create script for column '{}' with type '{}'. Steel scripts cannot target columns of type: {}",
|
||||
target,
|
||||
column_type,
|
||||
PROHIBITED_TYPES.join(", ")
|
||||
));
|
||||
}
|
||||
|
||||
// Add helpful info for boolean columns
|
||||
let normalized_type = normalize_data_type(&column_type);
|
||||
if normalized_type == "BOOLEAN" || normalized_type == "BOOL" {
|
||||
println!("Info: Target column '{}' is boolean type. Values will be converted to Steel format (#true/#false)", target);
|
||||
}
|
||||
|
||||
Ok(column_type)
|
||||
}
|
||||
|
||||
/// Check if a data type is prohibited for Steel scripts
|
||||
/// Note: BOOLEAN/BOOL is explicitly allowed and handled with special conversion
|
||||
fn is_prohibited_type(data_type: &str) -> bool {
|
||||
let normalized_type = normalize_data_type(data_type);
|
||||
PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited))
|
||||
}
|
||||
|
||||
/// Canonicalizes a column type name for comparisons.
///
/// The name is upper-cased, everything from the first `(` onwards is dropped
/// (so `NUMERIC(10,2)` becomes `NUMERIC`), and surrounding whitespace is trimmed.
fn normalize_data_type(data_type: &str) -> String {
    let upper = data_type.to_uppercase();
    let base = match upper.find('(') {
        Some(paren) => &upper[..paren],
        None => upper.as_str(),
    };
    base.trim().to_string()
}
|
||||
|
||||
/// Parse Steel script to extract all table/column references
|
||||
fn extract_column_references_from_script(script: &str) -> Vec<(String, String)> {
|
||||
let mut references = Vec::new();
|
||||
|
||||
// Regex patterns to match Steel function calls
|
||||
let patterns = [
|
||||
// (steel_get_column "table_name" "column_name")
|
||||
r#"\(steel_get_column\s+"([^"]+)"\s+"([^"]+)"\)"#,
|
||||
// (steel_get_column_with_index "table_name" index "column_name")
|
||||
r#"\(steel_get_column_with_index\s+"([^"]+)"\s+\d+\s+"([^"]+)"\)"#,
|
||||
];
|
||||
|
||||
for pattern in &patterns {
|
||||
if let Ok(re) = Regex::new(pattern) {
|
||||
for cap in re.captures_iter(script) {
|
||||
if let (Some(table), Some(column)) = (cap.get(1), cap.get(2)) {
|
||||
references.push((table.as_str().to_string(), column.as_str().to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check for steel_get_column_with_index pattern (table, column are in different positions)
|
||||
if let Ok(re) = Regex::new(r#"\(steel_get_column_with_index\s+"([^"]+)"\s+\d+\s+"([^"]+)"\)"#) {
|
||||
for cap in re.captures_iter(script) {
|
||||
if let (Some(table), Some(column)) = (cap.get(1), cap.get(2)) {
|
||||
references.push((table.as_str().to_string(), column.as_str().to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
references
|
||||
}
|
||||
|
||||
/// Validate that script doesn't reference prohibited column types by checking actual DB schema
|
||||
async fn validate_script_column_references(
|
||||
db_pool: &PgPool,
|
||||
schema_id: i64,
|
||||
script: &str,
|
||||
) -> Result<(), Status> {
|
||||
// Extract all table/column references from the script
|
||||
let references = extract_column_references_from_script(script);
|
||||
|
||||
if references.is_empty() {
|
||||
return Ok(()); // No column references to validate
|
||||
}
|
||||
|
||||
// Get all unique table names referenced in the script
|
||||
let table_names: HashSet<String> = references.iter()
|
||||
.map(|(table, _)| table.clone())
|
||||
.collect();
|
||||
|
||||
// Fetch table definitions for all referenced tables
|
||||
for table_name in table_names {
|
||||
// Query the actual table definition from the database
|
||||
let table_def = sqlx::query!(
|
||||
r#"SELECT table_name, columns FROM table_definitions
|
||||
WHERE schema_id = $1 AND table_name = $2"#,
|
||||
schema_id,
|
||||
table_name
|
||||
)
|
||||
.fetch_optional(db_pool)
|
||||
.await
|
||||
.map_err(|e| Status::internal(format!("Failed to fetch table definition for '{}': {}", table_name, e)))?;
|
||||
|
||||
if let Some(table_def) = table_def {
|
||||
// Check each column reference for this table
|
||||
for (ref_table, ref_column) in &references {
|
||||
if ref_table == &table_name {
|
||||
// Validate this specific column reference
|
||||
if let Err(error_msg) = validate_referenced_column_type(&table_name, ref_column, &table_def.columns) {
|
||||
return Err(Status::invalid_argument(error_msg));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(Status::invalid_argument(format!(
|
||||
"Script references table '{}' which does not exist in this schema",
|
||||
table_name
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate that a referenced column doesn't have a prohibited type
|
||||
fn validate_referenced_column_type(table_name: &str, column_name: &str, table_columns: &Value) -> Result<(), String> {
|
||||
// Parse the columns JSON into a vector of strings
|
||||
let columns: Vec<String> = serde_json::from_value(table_columns.clone())
|
||||
.map_err(|e| format!("Invalid column data for table '{}': {}", table_name, e))?;
|
||||
|
||||
// Extract column names and types
|
||||
let column_info: Vec<(&str, &str)> = columns
|
||||
.iter()
|
||||
.filter_map(|c| {
|
||||
let mut parts = c.split_whitespace();
|
||||
let name = parts.next()?.trim_matches('"');
|
||||
let data_type = parts.next()?;
|
||||
Some((name, data_type))
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Find the referenced column and check its type
|
||||
if let Some((_, column_type)) = column_info.iter().find(|(name, _)| *name == column_name) {
|
||||
if is_prohibited_type(column_type) {
|
||||
return Err(format!(
|
||||
"Script references column '{}' in table '{}' which has prohibited type '{}'. Steel scripts cannot access columns of type: {}",
|
||||
column_name,
|
||||
table_name,
|
||||
column_type,
|
||||
PROHIBITED_TYPES.join(", ")
|
||||
));
|
||||
}
|
||||
|
||||
// Log info for boolean columns
|
||||
let normalized_type = normalize_data_type(column_type);
|
||||
if normalized_type == "BOOLEAN" || normalized_type == "BOOL" {
|
||||
println!("Info: Script references boolean column '{}' in table '{}'. Values will be converted to Steel format (#true/#false)", column_name, table_name);
|
||||
}
|
||||
} else {
|
||||
return Err(format!(
|
||||
"Script references column '{}' in table '{}' but this column does not exist",
|
||||
column_name,
|
||||
table_name
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Parse Steel SQL queries to check for prohibited type usage (basic heuristic)
|
||||
fn validate_sql_queries_in_script(script: &str) -> Result<(), String> {
|
||||
// Look for steel_query_sql calls
|
||||
if let Ok(re) = Regex::new(r#"\(steel_query_sql\s+"([^"]+)"\)"#) {
|
||||
for cap in re.captures_iter(script) {
|
||||
if let Some(query) = cap.get(1) {
|
||||
let sql = query.as_str().to_uppercase();
|
||||
|
||||
// Basic heuristic checks for prohibited type operations
|
||||
let prohibited_patterns = [
|
||||
"EXTRACT(",
|
||||
"DATE_PART(",
|
||||
"::DATE",
|
||||
"::TIMESTAMPTZ",
|
||||
"::BIGINT",
|
||||
"CAST(", // Could be casting to prohibited types
|
||||
];
|
||||
|
||||
for pattern in &prohibited_patterns {
|
||||
if sql.contains(pattern) {
|
||||
return Err(format!(
|
||||
"Script contains SQL query with potentially prohibited type operations: '{}'. Steel scripts cannot use operations on types: {}",
|
||||
query.as_str(),
|
||||
PROHIBITED_TYPES.join(", ")
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handles the creation of a new table script with dependency validation.
|
||||
@@ -65,7 +270,7 @@ pub async fn post_table_script(
|
||||
.map_err(|e| Status::internal(format!("Failed to fetch table definition: {}", e)))?
|
||||
.ok_or_else(|| Status::not_found("Table definition not found"))?;
|
||||
|
||||
// Validate the target column and get its type
|
||||
// Validate the target column and get its type (includes prohibited type check)
|
||||
let column_type = validate_target_column(
|
||||
&table_def.table_name,
|
||||
&request.target_column,
|
||||
@@ -73,6 +278,13 @@ pub async fn post_table_script(
|
||||
)
|
||||
.map_err(|e| Status::invalid_argument(e))?;
|
||||
|
||||
// Validate that script doesn't reference prohibited column types by checking actual DB schema
|
||||
validate_script_column_references(db_pool, table_def.schema_id, &request.script).await?;
|
||||
|
||||
// Validate SQL queries in script for prohibited type operations
|
||||
validate_sql_queries_in_script(&request.script)
|
||||
.map_err(|e| Status::invalid_argument(e))?;
|
||||
|
||||
// Create dependency analyzer for this schema
|
||||
let analyzer = DependencyAnalyzer::new(table_def.schema_id, db_pool.clone());
|
||||
|
||||
|
||||
181
server/tests/table_script/prohibited_types_test.rs
Normal file
181
server/tests/table_script/prohibited_types_test.rs
Normal file
@@ -0,0 +1,181 @@
|
||||
// tests/table_script/prohibited_types_test.rs
|
||||
|
||||
/// Tests for the prohibited-type rules enforced when a table script is posted.
///
/// NOTE: the database helpers at the bottom of this module are still `todo!()`
/// placeholders, so every test here panics until they are implemented.
#[cfg(test)]
mod prohibited_types_tests {
    use super::*;
    use common::proto::multieko2::table_script::PostTableScriptRequest;
    use sqlx::PgPool;

    /// A script may not target a BIGINT column.
    #[tokio::test]
    async fn test_reject_bigint_target_column() {
        let pool = setup_test_db().await;

        // Create a table with a BIGINT column
        let table_id = create_test_table_with_bigint_column(&pool).await;

        let request = PostTableScriptRequest {
            table_definition_id: table_id,
            target_column: "big_number".to_string(), // This is BIGINT
            script: r#"
                (define result "some calculation")
                result
            "#.to_string(),
            description: "Test script".to_string(),
        };

        let result = post_table_script(&pool, request).await;

        // Should fail with prohibited type error
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Cannot create script for column 'big_number' with type 'BIGINT'"));
        assert!(error_msg.contains("Steel scripts cannot target columns of type: BIGINT, DATE, TIMESTAMPTZ"));
    }

    /// A script may not target a DATE column.
    #[tokio::test]
    async fn test_reject_date_target_column() {
        let pool = setup_test_db().await;

        // Create a table with a DATE column
        let table_id = create_test_table_with_date_column(&pool).await;

        let request = PostTableScriptRequest {
            table_definition_id: table_id,
            target_column: "event_date".to_string(), // This is DATE
            script: r#"
                (define result "2024-01-01")
                result
            "#.to_string(),
            description: "Test script".to_string(),
        };

        let result = post_table_script(&pool, request).await;

        // Should fail with prohibited type error
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Cannot create script for column 'event_date' with type 'DATE'"));
    }

    /// A script may not target a TIMESTAMPTZ column.
    #[tokio::test]
    async fn test_reject_timestamptz_target_column() {
        let pool = setup_test_db().await;

        // Create a table with a TIMESTAMPTZ column
        let table_id = create_test_table_with_timestamptz_column(&pool).await;

        let request = PostTableScriptRequest {
            table_definition_id: table_id,
            target_column: "created_time".to_string(), // This is TIMESTAMPTZ
            script: r#"
                (define result "2024-01-01T10:00:00Z")
                result
            "#.to_string(),
            description: "Test script".to_string(),
        };

        let result = post_table_script(&pool, request).await;

        // Should fail with prohibited type error
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Cannot create script for column 'created_time' with type 'TIMESTAMPTZ'"));
    }

    /// A script with an allowed target may still not read a BIGINT column.
    #[tokio::test]
    async fn test_reject_script_referencing_prohibited_column() {
        let pool = setup_test_db().await;

        // Create linked tables - one with BIGINT column, another with TEXT target
        let source_table_id = create_test_table_with_text_column(&pool).await;
        let linked_table_id = create_test_table_with_bigint_column(&pool).await;

        // Create link between tables
        create_table_link(&pool, source_table_id, linked_table_id).await;

        let request = PostTableScriptRequest {
            table_definition_id: source_table_id,
            target_column: "description".to_string(), // This is TEXT (allowed)
            script: r#"
                (define big_val (steel_get_column "linked_table" "big_number"))
                (string-append "Value: " (number->string big_val))
            "#.to_string(),
            description: "Script that tries to access BIGINT column".to_string(),
        };

        let result = post_table_script(&pool, request).await;

        // Should fail because script references BIGINT column
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        // The handler words this as "Script references column '<name>' in table
        // '<table>' which has prohibited type '<type>'".
        assert!(error_msg.contains("Script references column 'big_number'"));
        assert!(error_msg.contains("prohibited type 'BIGINT'"));
    }

    /// A script that only touches allowed column types is accepted.
    #[tokio::test]
    async fn test_allow_valid_script_with_allowed_types() {
        let pool = setup_test_db().await;

        // Create a table with allowed column types
        let table_id = create_test_table_with_allowed_columns(&pool).await;

        let request = PostTableScriptRequest {
            table_definition_id: table_id,
            target_column: "computed_value".to_string(), // This is TEXT (allowed)
            script: r#"
                (define name_val (steel_get_column "test_table" "name"))
                (define count_val (steel_get_column "test_table" "count"))
                (string-append name_val " has " (number->string count_val) " items")
            "#.to_string(),
            description: "Valid script using allowed types".to_string(),
        };

        let result = post_table_script(&pool, request).await;

        // Should succeed
        assert!(result.is_ok());
        let response = result.unwrap();
        assert!(response.id > 0);
    }

    // Helper functions for test setup (all still unimplemented placeholders)

    /// Placeholder: should return a pool connected to the test database.
    async fn setup_test_db() -> PgPool {
        // Your test database setup code here
        todo!("Implement test DB setup")
    }

    /// Placeholder: should create a table definition containing a BIGINT column.
    async fn create_test_table_with_bigint_column(pool: &PgPool) -> i64 {
        // Create table definition with BIGINT column
        // JSON columns would be: ["name TEXT", "big_number BIGINT"]
        todo!("Implement table creation with BIGINT")
    }

    /// Placeholder: should create a table definition containing a DATE column.
    async fn create_test_table_with_date_column(pool: &PgPool) -> i64 {
        // Create table definition with DATE column
        // JSON columns would be: ["name TEXT", "event_date DATE"]
        todo!("Implement table creation with DATE")
    }

    /// Placeholder: should create a table definition containing a TIMESTAMPTZ column.
    async fn create_test_table_with_timestamptz_column(pool: &PgPool) -> i64 {
        // Create table definition with TIMESTAMPTZ column
        // JSON columns would be: ["name TEXT", "created_time TIMESTAMPTZ"]
        todo!("Implement table creation with TIMESTAMPTZ")
    }

    /// Placeholder: should create a table definition with TEXT columns only.
    async fn create_test_table_with_text_column(pool: &PgPool) -> i64 {
        // Create table definition with TEXT columns only
        // JSON columns would be: ["name TEXT", "description TEXT"]
        todo!("Implement table creation with TEXT")
    }

    /// Placeholder: should create a table definition using only allowed types.
    async fn create_test_table_with_allowed_columns(pool: &PgPool) -> i64 {
        // Create table definition with only allowed column types
        // JSON columns would be: ["name TEXT", "count INTEGER", "computed_value TEXT"]
        todo!("Implement table creation with allowed types")
    }

    /// Placeholder: should link `source_id` to `target_id` in `table_definition_links`.
    async fn create_table_link(pool: &PgPool, source_id: i64, target_id: i64) {
        // Create a link in table_definition_links
        todo!("Implement table linking")
    }
}
|
||||
Reference in New Issue
Block a user