CHECK THIS COMMIT I HAVE NO CLUE IF ITS CORRECT
This commit is contained in:
@@ -8,6 +8,7 @@ use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, error};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Value {
|
||||
@@ -34,10 +35,20 @@ pub async fn create_steel_context_with_boolean_conversion(
|
||||
mut row_data: HashMap<String, String>,
|
||||
db_pool: Arc<PgPool>,
|
||||
) -> Result<SteelContext, ExecutionError> {
|
||||
println!("=== CREATING STEEL CONTEXT ===");
|
||||
println!("Table: {}, Schema: {}", current_table, schema_name);
|
||||
println!("Row data BEFORE boolean conversion: {:?}", row_data);
|
||||
|
||||
// Convert boolean values in row_data to Steel format
|
||||
convert_row_data_for_steel(&db_pool, schema_id, ¤t_table, &mut row_data)
|
||||
.await
|
||||
.map_err(|e| ExecutionError::RuntimeError(format!("Failed to convert row data: {}", e)))?;
|
||||
.map_err(|e| {
|
||||
error!("Failed to convert row data for Steel: {}", e);
|
||||
ExecutionError::RuntimeError(format!("Failed to convert row data: {}", e))
|
||||
})?;
|
||||
|
||||
println!("Row data AFTER boolean conversion: {:?}", row_data);
|
||||
println!("=== END CREATING STEEL CONTEXT ===");
|
||||
|
||||
Ok(SteelContext {
|
||||
current_table,
|
||||
@@ -58,8 +69,13 @@ pub async fn execute_script(
|
||||
current_table: String,
|
||||
row_data: HashMap<String, String>,
|
||||
) -> Result<Value, ExecutionError> {
|
||||
println!("=== STEEL SCRIPT EXECUTION START ===");
|
||||
println!("Script: '{}'", script);
|
||||
println!("Table: {}, Target type: {}", current_table, target_type);
|
||||
println!("Input row_data: {:?}", row_data);
|
||||
|
||||
let mut vm = Engine::new();
|
||||
|
||||
|
||||
// Create context with boolean conversion
|
||||
let context = create_steel_context_with_boolean_conversion(
|
||||
current_table,
|
||||
@@ -68,7 +84,7 @@ pub async fn execute_script(
|
||||
row_data,
|
||||
db_pool.clone(),
|
||||
).await?;
|
||||
|
||||
|
||||
let context = Arc::new(context);
|
||||
|
||||
// Register existing Steel functions
|
||||
@@ -79,26 +95,71 @@ pub async fn execute_script(
|
||||
|
||||
// Register variables from the context with the Steel VM
|
||||
// The row_data now contains Steel-formatted boolean values
|
||||
println!("=== REGISTERING STEEL VARIABLES ===");
|
||||
|
||||
// Manual variable registration using Steel's define mechanism
|
||||
let mut define_script = String::new();
|
||||
|
||||
println!("Variables being registered with Steel VM:");
|
||||
for (key, value) in &context.row_data {
|
||||
println!(" STEEL[{}] = '{}'", key, value);
|
||||
// Register both @ prefixed and bare variable names
|
||||
define_script.push_str(&format!("(define {} \"{}\")\n", key, value));
|
||||
define_script.push_str(&format!("(define @{} \"{}\")\n", key, value));
|
||||
}
|
||||
println!("Steel script to execute: {}", script);
|
||||
println!("=== END REGISTERING STEEL VARIABLES ===");
|
||||
|
||||
if !define_script.is_empty() {
|
||||
println!("Define script: {}", define_script);
|
||||
vm.compile_and_run_raw_program(define_script)
|
||||
.map_err(|e| ExecutionError::RuntimeError(format!("Failed to register variables: {}", e)))?;
|
||||
println!("Variables defined successfully");
|
||||
}
|
||||
|
||||
// Also try the original method as backup
|
||||
FunctionRegistry::register_variables(&mut vm, context.row_data.clone());
|
||||
|
||||
// Execute script and process results
|
||||
let results = vm.compile_and_run_raw_program(script)
|
||||
.map_err(|e| ExecutionError::RuntimeError(e.to_string()))?;
|
||||
println!("Compiling and running Steel script: {}", script);
|
||||
let results = vm.compile_and_run_raw_program(script.clone())
|
||||
.map_err(|e| {
|
||||
error!("Steel script execution failed: {}", e);
|
||||
error!("Script was: {}", script);
|
||||
error!("Available variables were: {:?}", context.row_data);
|
||||
ExecutionError::RuntimeError(e.to_string())
|
||||
})?;
|
||||
|
||||
println!("Script execution returned {} results", results.len());
|
||||
for (i, result) in results.iter().enumerate() {
|
||||
println!("Result[{}]: {:?}", i, result);
|
||||
}
|
||||
|
||||
// Convert results to target type
|
||||
match target_type {
|
||||
"STRINGS" => process_string_results(results),
|
||||
"STRINGS" => {
|
||||
let result = process_string_results(results);
|
||||
println!("Final processed result: {:?}", result);
|
||||
println!("=== STEEL SCRIPT EXECUTION END ===");
|
||||
result
|
||||
},
|
||||
_ => Err(ExecutionError::UnsupportedType(target_type.into()))
|
||||
}
|
||||
}
|
||||
|
||||
fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
|
||||
debug!("Registering Steel functions with context");
|
||||
|
||||
// Register steel_get_column with row context
|
||||
vm.register_fn("steel_get_column", {
|
||||
let ctx = context.clone();
|
||||
move |table: String, column: String| {
|
||||
debug!("steel_get_column called with table: '{}', column: '{}'", table, column);
|
||||
ctx.steel_get_column(&table, &column)
|
||||
.map_err(|e| e.to_string())
|
||||
.map_err(|e| {
|
||||
error!("steel_get_column failed: {:?}", e);
|
||||
e.to_string()
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
@@ -106,8 +167,12 @@ fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
|
||||
vm.register_fn("steel_get_column_with_index", {
|
||||
let ctx = context.clone();
|
||||
move |table: String, index: i64, column: String| {
|
||||
debug!("steel_get_column_with_index called with table: '{}', index: {}, column: '{}'", table, index, column);
|
||||
ctx.steel_get_column_with_index(&table, index, &column)
|
||||
.map_err(|e| e.to_string())
|
||||
.map_err(|e| {
|
||||
error!("steel_get_column_with_index failed: {:?}", e);
|
||||
e.to_string()
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
@@ -115,13 +180,18 @@ fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
|
||||
vm.register_fn("steel_query_sql", {
|
||||
let ctx = context.clone();
|
||||
move |query: String| {
|
||||
debug!("steel_query_sql called with query: '{}'", query);
|
||||
ctx.steel_query_sql(&query)
|
||||
.map_err(|e| e.to_string())
|
||||
.map_err(|e| {
|
||||
error!("steel_query_sql failed: {:?}", e);
|
||||
e.to_string()
|
||||
})
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn register_decimal_math_functions(vm: &mut Engine) {
|
||||
debug!("Registering decimal math functions");
|
||||
// Use the steel_decimal crate's FunctionRegistry to register all functions
|
||||
FunctionRegistry::register_all(vm);
|
||||
}
|
||||
@@ -130,16 +200,35 @@ fn process_string_results(results: Vec<SteelVal>) -> Result<Value, ExecutionErro
|
||||
let mut strings = Vec::new();
|
||||
for result in results {
|
||||
match result {
|
||||
SteelVal::StringV(s) => strings.push(s.to_string()),
|
||||
SteelVal::NumV(n) => strings.push(n.to_string()),
|
||||
SteelVal::IntV(i) => strings.push(i.to_string()),
|
||||
SteelVal::BoolV(b) => strings.push(b.to_string()),
|
||||
SteelVal::StringV(s) => {
|
||||
let result_str = s.to_string();
|
||||
println!("Processing string result: '{}'", result_str);
|
||||
strings.push(result_str);
|
||||
},
|
||||
SteelVal::NumV(n) => {
|
||||
let result_str = n.to_string();
|
||||
println!("Processing number result: '{}'", result_str);
|
||||
strings.push(result_str);
|
||||
},
|
||||
SteelVal::IntV(i) => {
|
||||
let result_str = i.to_string();
|
||||
println!("Processing integer result: '{}'", result_str);
|
||||
strings.push(result_str);
|
||||
},
|
||||
SteelVal::BoolV(b) => {
|
||||
let result_str = b.to_string();
|
||||
println!("Processing boolean result: '{}'", result_str);
|
||||
strings.push(result_str);
|
||||
},
|
||||
_ => {
|
||||
error!("Unexpected result type: {:?}", result);
|
||||
return Err(ExecutionError::TypeConversionError(
|
||||
format!("Expected string-convertible type, got {:?}", result)
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("Final processed strings: {:?}", strings);
|
||||
Ok(Value::Strings(strings))
|
||||
}
|
||||
|
||||
@@ -108,22 +108,28 @@ impl DependencyAnalyzer {
|
||||
/// Uses regex patterns to find function calls that create dependencies
|
||||
pub fn analyze_script_dependencies(&self, script: &str) -> Result<Vec<Dependency>, DependencyError> {
|
||||
let mut dependencies = Vec::new();
|
||||
|
||||
|
||||
// Extract function calls and SQL dependencies using regex
|
||||
dependencies.extend(self.extract_function_calls(script)?);
|
||||
dependencies.extend(self.extract_sql_dependencies(script)?);
|
||||
|
||||
|
||||
// Extract get-var calls (for transformed scripts with variables)
|
||||
dependencies.extend(self.extract_get_var_calls(script)?);
|
||||
|
||||
// Extract direct variable references like @price, @quantity
|
||||
dependencies.extend(self.extract_variable_references(script)?);
|
||||
|
||||
Ok(dependencies)
|
||||
}
|
||||
|
||||
/// Extract function calls using regex patterns
|
||||
fn extract_function_calls(&self, script: &str) -> Result<Vec<Dependency>, DependencyError> {
|
||||
let mut dependencies = Vec::new();
|
||||
|
||||
|
||||
// Look for steel_get_column patterns
|
||||
let column_pattern = regex::Regex::new(r#"\(\s*steel_get_column\s+"([^"]+)"\s+"([^"]+)""#)
|
||||
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
|
||||
|
||||
|
||||
for caps in column_pattern.captures_iter(script) {
|
||||
let table = caps[1].to_string();
|
||||
let column = caps[2].to_string();
|
||||
@@ -137,7 +143,7 @@ impl DependencyAnalyzer {
|
||||
// Look for steel_get_column_with_index patterns
|
||||
let indexed_pattern = regex::Regex::new(r#"\(\s*steel_get_column_with_index\s+"([^"]+)"\s+(\d+)\s+"([^"]+)""#)
|
||||
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
|
||||
|
||||
|
||||
for caps in indexed_pattern.captures_iter(script) {
|
||||
let table = caps[1].to_string();
|
||||
let index: i64 = caps[2].parse()
|
||||
@@ -153,23 +159,63 @@ impl DependencyAnalyzer {
|
||||
Ok(dependencies)
|
||||
}
|
||||
|
||||
/// Extract get-var calls as dependencies (for transformed scripts with variables)
|
||||
fn extract_get_var_calls(&self, script: &str) -> Result<Vec<Dependency>, DependencyError> {
|
||||
let mut dependencies = Vec::new();
|
||||
|
||||
// Look for get-var patterns in transformed scripts: (get-var "variable")
|
||||
let get_var_pattern = regex::Regex::new(r#"\(get-var\s+"([^"]+)"\)"#)
|
||||
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
|
||||
|
||||
for caps in get_var_pattern.captures_iter(script) {
|
||||
let variable_name = caps[1].to_string();
|
||||
dependencies.push(Dependency {
|
||||
target_table: "SAME_TABLE".to_string(), // Special marker for same-table references
|
||||
dependency_type: DependencyType::ColumnAccess { column: variable_name },
|
||||
context: None,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(dependencies)
|
||||
}
|
||||
|
||||
/// Extract direct variable references like @price, @quantity
|
||||
fn extract_variable_references(&self, script: &str) -> Result<Vec<Dependency>, DependencyError> {
|
||||
let mut dependencies = Vec::new();
|
||||
|
||||
// Look for @variable patterns: @price, @quantity, etc.
|
||||
let variable_pattern = regex::Regex::new(r#"@([a-zA-Z_][a-zA-Z0-9_]*)"#)
|
||||
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
|
||||
|
||||
for caps in variable_pattern.captures_iter(script) {
|
||||
let variable_name = caps[1].to_string();
|
||||
dependencies.push(Dependency {
|
||||
target_table: "SAME_TABLE".to_string(), // Same table reference
|
||||
dependency_type: DependencyType::ColumnAccess { column: variable_name },
|
||||
context: None,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(dependencies)
|
||||
}
|
||||
|
||||
/// Extract table references from SQL queries in steel_query_sql calls
|
||||
fn extract_sql_dependencies(&self, script: &str) -> Result<Vec<Dependency>, DependencyError> {
|
||||
let mut dependencies = Vec::new();
|
||||
|
||||
|
||||
// Look for steel_query_sql calls and extract table names from the SQL
|
||||
let sql_pattern = regex::Regex::new(r#"\(\s*steel_query_sql\s+"([^"]+)""#)
|
||||
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
|
||||
|
||||
|
||||
for caps in sql_pattern.captures_iter(script) {
|
||||
let query = caps[1].to_string();
|
||||
let table_refs = self.extract_table_references_from_sql(&query)?;
|
||||
|
||||
|
||||
for table in table_refs {
|
||||
dependencies.push(Dependency {
|
||||
target_table: table.clone(),
|
||||
dependency_type: DependencyType::SqlQuery {
|
||||
query_fragment: query.clone()
|
||||
dependency_type: DependencyType::SqlQuery {
|
||||
query_fragment: query.clone()
|
||||
},
|
||||
context: None,
|
||||
});
|
||||
@@ -182,17 +228,17 @@ impl DependencyAnalyzer {
|
||||
/// Extract table names from SQL query text
|
||||
fn extract_table_references_from_sql(&self, sql: &str) -> Result<Vec<String>, DependencyError> {
|
||||
let mut tables = Vec::new();
|
||||
|
||||
|
||||
// Simple extraction - look for FROM and JOIN clauses
|
||||
// This could be made more sophisticated with a proper SQL parser
|
||||
let table_pattern = regex::Regex::new(r#"(?i)\b(?:FROM|JOIN)\s+(?:"([^"]+)"|(\w+))"#)
|
||||
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
|
||||
|
||||
|
||||
for caps in table_pattern.captures_iter(sql) {
|
||||
let table = caps.get(1)
|
||||
.or_else(|| caps.get(2))
|
||||
.map(|m| m.as_str().to_string());
|
||||
|
||||
|
||||
if let Some(table_name) = table {
|
||||
tables.push(table_name);
|
||||
}
|
||||
@@ -211,7 +257,7 @@ impl DependencyAnalyzer {
|
||||
) -> Result<(), DependencyError> {
|
||||
// FIRST: Validate that structured table access respects link constraints
|
||||
self.validate_link_constraints(tx, table_id, new_dependencies).await?;
|
||||
|
||||
|
||||
// Get current dependency graph for this schema
|
||||
let current_deps = sqlx::query!(
|
||||
r#"SELECT sd.source_table_id, sd.target_table_id, st.table_name as source_name, tt.table_name as target_name
|
||||
@@ -241,18 +287,22 @@ impl DependencyAnalyzer {
|
||||
// Add new dependencies to test - EXCLUDE self-references
|
||||
for dep in new_dependencies {
|
||||
// Look up target table ID
|
||||
let target_id = sqlx::query_scalar!(
|
||||
"SELECT id FROM table_definitions WHERE schema_id = $1 AND table_name = $2",
|
||||
self.schema_id,
|
||||
dep.target_table
|
||||
)
|
||||
.fetch_optional(&mut **tx)
|
||||
.await
|
||||
.map_err(|e| DependencyError::DatabaseError { error: e.to_string() })?
|
||||
.ok_or_else(|| DependencyError::InvalidTableReference {
|
||||
table_name: dep.target_table.clone(),
|
||||
script_context: format!("table_id_{}", table_id),
|
||||
})?;
|
||||
let target_id = if dep.target_table == "SAME_TABLE" {
|
||||
table_id // Same table reference
|
||||
} else {
|
||||
sqlx::query_scalar!(
|
||||
"SELECT id FROM table_definitions WHERE schema_id = $1 AND table_name = $2",
|
||||
self.schema_id,
|
||||
dep.target_table
|
||||
)
|
||||
.fetch_optional(&mut **tx)
|
||||
.await
|
||||
.map_err(|e| DependencyError::DatabaseError { error: e.to_string() })?
|
||||
.ok_or_else(|| DependencyError::InvalidTableReference {
|
||||
table_name: dep.target_table.clone(),
|
||||
script_context: format!("table_id_{}", table_id),
|
||||
})?
|
||||
};
|
||||
|
||||
// Only add to cycle detection graph if it's NOT a self-reference
|
||||
if table_id != target_id {
|
||||
@@ -306,7 +356,7 @@ impl DependencyAnalyzer {
|
||||
// Skip this - it's not a harmful cycle
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
// Found a real cycle! Build the cycle path
|
||||
let cycle_start_idx = path.iter().position(|&x| x == neighbor).unwrap_or(0);
|
||||
let cycle_path: Vec<String> = path[cycle_start_idx..]
|
||||
@@ -394,6 +444,11 @@ impl DependencyAnalyzer {
|
||||
match &dep.dependency_type {
|
||||
// Structured access must respect link constraints (but self-references are always allowed)
|
||||
DependencyType::ColumnAccess { column } | DependencyType::IndexedAccess { column, .. } => {
|
||||
// Skip validation for SAME_TABLE marker (these are always allowed)
|
||||
if dep.target_table == "SAME_TABLE" {
|
||||
continue;
|
||||
}
|
||||
|
||||
if !allowed_tables.contains(&dep.target_table) {
|
||||
return Err(DependencyError::InvalidTableReference {
|
||||
table_name: dep.target_table.clone(),
|
||||
@@ -425,7 +480,7 @@ impl DependencyAnalyzer {
|
||||
starting_table: i64,
|
||||
) -> Result<(), DependencyError> {
|
||||
let mut states: HashMap<i64, NodeState> = HashMap::new();
|
||||
|
||||
|
||||
// Initialize all nodes as unvisited
|
||||
for &node in graph.keys() {
|
||||
states.insert(node, NodeState::Unvisited);
|
||||
@@ -436,11 +491,11 @@ impl DependencyAnalyzer {
|
||||
if states.get(&node) == Some(&NodeState::Unvisited) {
|
||||
let mut path = Vec::new();
|
||||
if let Err(cycle_error) = self.dfs_visit(
|
||||
node,
|
||||
&mut states,
|
||||
graph,
|
||||
&mut path,
|
||||
table_names,
|
||||
node,
|
||||
&mut states,
|
||||
graph,
|
||||
&mut path,
|
||||
table_names,
|
||||
starting_table
|
||||
) {
|
||||
return Err(cycle_error);
|
||||
@@ -467,18 +522,22 @@ impl DependencyAnalyzer {
|
||||
|
||||
// Insert new dependencies
|
||||
for dep in dependencies {
|
||||
let target_id = sqlx::query_scalar!(
|
||||
"SELECT id FROM table_definitions WHERE schema_id = $1 AND table_name = $2",
|
||||
self.schema_id,
|
||||
dep.target_table
|
||||
)
|
||||
.fetch_optional(&mut **tx)
|
||||
.await
|
||||
.map_err(|e| DependencyError::DatabaseError { error: e.to_string() })?
|
||||
.ok_or_else(|| DependencyError::InvalidTableReference {
|
||||
table_name: dep.target_table.clone(),
|
||||
script_context: format!("script_id_{}", script_id),
|
||||
})?;
|
||||
let target_id = if dep.target_table == "SAME_TABLE" {
|
||||
table_id // Use the same table as the script
|
||||
} else {
|
||||
sqlx::query_scalar!(
|
||||
"SELECT id FROM table_definitions WHERE schema_id = $1 AND table_name = $2",
|
||||
self.schema_id,
|
||||
dep.target_table
|
||||
)
|
||||
.fetch_optional(&mut **tx)
|
||||
.await
|
||||
.map_err(|e| DependencyError::DatabaseError { error: e.to_string() })?
|
||||
.ok_or_else(|| DependencyError::InvalidTableReference {
|
||||
table_name: dep.target_table.clone(),
|
||||
script_context: format!("script_id_{}", script_id),
|
||||
})?
|
||||
};
|
||||
|
||||
sqlx::query!(
|
||||
r#"INSERT INTO script_dependencies
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
// src/tables_data/handlers/put_table_data.rs
|
||||
|
||||
// TODO WORK ON SCRIPTS INPUT OUTPUT HAS TO BE CHECKED
|
||||
|
||||
use tonic::Status;
|
||||
@@ -14,9 +13,7 @@ use rust_decimal::Decimal;
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::steel::server::execution::{self, Value};
|
||||
use crate::steel::server::functions::SteelContext;
|
||||
use crate::indexer::{IndexCommand, IndexCommandData};
|
||||
use crate::table_script::handlers::dependency_analyzer::DependencyAnalyzer;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::error;
|
||||
|
||||
@@ -29,6 +26,10 @@ pub async fn put_table_data(
|
||||
let table_name = request.table_name;
|
||||
let record_id = request.id;
|
||||
|
||||
println!("=== PUT TABLE DATA START ===");
|
||||
println!("Profile: {}, Table: {}, Record ID: {}", profile_name, table_name, record_id);
|
||||
println!("Request data: {:?}", request.data);
|
||||
|
||||
if request.data.is_empty() {
|
||||
return Ok(PutTableDataResponse {
|
||||
success: true,
|
||||
@@ -56,7 +57,6 @@ pub async fn put_table_data(
|
||||
.map_err(|e| Status::internal(format!("Column parsing error: {}", e)))?;
|
||||
|
||||
let mut columns = Vec::new();
|
||||
let mut user_column_names = Vec::new();
|
||||
for col_def in columns_json {
|
||||
let parts: Vec<&str> = col_def.splitn(2, ' ').collect();
|
||||
if parts.len() != 2 {
|
||||
@@ -64,10 +64,11 @@ pub async fn put_table_data(
|
||||
}
|
||||
let name = parts[0].trim_matches('"').to_string();
|
||||
let sql_type = parts[1].to_string();
|
||||
user_column_names.push(name.clone());
|
||||
columns.push((name, sql_type));
|
||||
}
|
||||
|
||||
println!("Table columns: {:?}", columns);
|
||||
|
||||
// --- Validate Column Permissions ---
|
||||
let fk_columns = sqlx::query!(
|
||||
r#"SELECT ltd.table_name
|
||||
@@ -81,7 +82,6 @@ pub async fn put_table_data(
|
||||
.map_err(|e| Status::internal(format!("Foreign key lookup error: {}", e)))?;
|
||||
|
||||
let mut system_columns = vec!["deleted".to_string()];
|
||||
// FIX 1: Change from `fk_columns` to `&fk_columns` to avoid move
|
||||
for fk in &fk_columns {
|
||||
system_columns.push(format!("{}_id", fk.table_name));
|
||||
}
|
||||
@@ -96,11 +96,16 @@ pub async fn put_table_data(
|
||||
}
|
||||
}
|
||||
|
||||
// --- [OPTIMIZATION] Smart Data Fetching: Only fetch what scripts need ---
|
||||
let scripts = sqlx::query!("SELECT target_column, script FROM table_scripts WHERE table_definitions_id = $1", table_def.id)
|
||||
// --- Smart Data Fetching using script_dependencies table ---
|
||||
let scripts = sqlx::query!("SELECT id, target_column, script FROM table_scripts WHERE table_definitions_id = $1", table_def.id)
|
||||
.fetch_all(db_pool).await
|
||||
.map_err(|e| Status::internal(format!("Failed to fetch scripts: {}", e)))?;
|
||||
|
||||
println!("Found {} scripts for table", scripts.len());
|
||||
for script in &scripts {
|
||||
println!("Script ID {}: target_column='{}', script='{}'", script.id, script.target_column, script.script);
|
||||
}
|
||||
|
||||
let mut required_columns = std::collections::HashSet::new();
|
||||
|
||||
// Always need: id, target columns of scripts, and columns being updated
|
||||
@@ -112,35 +117,45 @@ pub async fn put_table_data(
|
||||
required_columns.insert(key.clone());
|
||||
}
|
||||
|
||||
// Analyze script dependencies to find what columns scripts actually access
|
||||
// Use pre-computed dependencies from script_dependencies table
|
||||
if !scripts.is_empty() {
|
||||
let analyzer = DependencyAnalyzer::new(schema_id, db_pool.clone());
|
||||
let script_ids: Vec<i64> = scripts.iter().map(|s| s.id).collect();
|
||||
|
||||
for script_record in &scripts {
|
||||
let dependencies = analyzer
|
||||
.analyze_script_dependencies(&script_record.script)
|
||||
.map_err(|e| Status::internal(format!("Failed to analyze script dependencies: {:?}", e)))?;
|
||||
let dependencies = sqlx::query!(
|
||||
r#"SELECT sd.target_table_id, sd.dependency_type, sd.context_info, td.table_name as target_table
|
||||
FROM script_dependencies sd
|
||||
JOIN table_definitions td ON sd.target_table_id = td.id
|
||||
WHERE sd.script_id = ANY($1)"#,
|
||||
&script_ids
|
||||
)
|
||||
.fetch_all(db_pool)
|
||||
.await
|
||||
.map_err(|e| Status::internal(format!("Failed to fetch script dependencies: {}", e)))?;
|
||||
|
||||
for dep in dependencies {
|
||||
// If it references this table, add the columns it uses
|
||||
if dep.target_table == table_name {
|
||||
match dep.dependency_type {
|
||||
crate::table_script::handlers::dependency_analyzer::DependencyType::ColumnAccess { column } |
|
||||
crate::table_script::handlers::dependency_analyzer::DependencyType::IndexedAccess { column, .. } => {
|
||||
required_columns.insert(column);
|
||||
for dep in dependencies {
|
||||
// If it references this table, add the columns it uses
|
||||
if dep.target_table == table_name {
|
||||
match dep.dependency_type.as_str() {
|
||||
"column_access" | "indexed_access" => {
|
||||
if let Some(context) = dep.context_info {
|
||||
if let Some(column) = context.get("column").and_then(|v| v.as_str()) {
|
||||
required_columns.insert(column.to_string());
|
||||
}
|
||||
}
|
||||
_ => {} // SQL queries handled differently
|
||||
}
|
||||
_ => {} // SQL queries handled differently
|
||||
}
|
||||
// If it references linked tables, add their foreign key columns
|
||||
else {
|
||||
let fk_column = format!("{}_id", dep.target_table.split('_').last().unwrap_or(&dep.target_table));
|
||||
required_columns.insert(fk_column);
|
||||
}
|
||||
}
|
||||
// If it references linked tables, add their foreign key columns
|
||||
else {
|
||||
let fk_column = format!("{}_id", dep.target_table.split('_').last().unwrap_or(&dep.target_table));
|
||||
required_columns.insert(fk_column);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("Required columns for context: {:?}", required_columns);
|
||||
|
||||
// Build optimized SELECT query with only required columns
|
||||
let qualified_table = crate::shared::schema_qualifier::qualify_table_name_for_data(db_pool, &profile_name, &table_name).await?;
|
||||
|
||||
@@ -151,6 +166,7 @@ pub async fn put_table_data(
|
||||
.join(", ");
|
||||
|
||||
let select_sql = format!("SELECT {} FROM {} WHERE id = $1", columns_clause, qualified_table);
|
||||
println!("SELECT SQL: {}", select_sql);
|
||||
|
||||
let current_row = sqlx::query(&select_sql).bind(record_id).fetch_optional(db_pool).await
|
||||
.map_err(|e| Status::internal(format!("Failed to fetch current row state: {}", e)))?
|
||||
@@ -162,8 +178,12 @@ pub async fn put_table_data(
|
||||
current_row_data.insert(col_name.clone(), value);
|
||||
}
|
||||
|
||||
println!("=== CURRENT ROW DATA FROM DB ===");
|
||||
println!("{:?}", current_row_data);
|
||||
|
||||
// --- Data Merging Logic ---
|
||||
let mut update_data = HashMap::new();
|
||||
println!("=== EXTRACTING UPDATE DATA FROM REQUEST ===");
|
||||
for (key, proto_value) in &request.data {
|
||||
let str_val = match &proto_value.kind {
|
||||
Some(Kind::StringValue(s)) => s.trim().to_string(),
|
||||
@@ -172,18 +192,26 @@ pub async fn put_table_data(
|
||||
Some(Kind::NullValue(_)) | None => String::new(),
|
||||
_ => return Err(Status::invalid_argument(format!("Unsupported type for column '{}'", key))),
|
||||
};
|
||||
if !str_val.is_empty() {
|
||||
update_data.insert(key.clone(), str_val);
|
||||
}
|
||||
println!("UPDATE_DATA[{}] = '{}'", key, str_val);
|
||||
// Always add the value, even if empty (to properly override current values)
|
||||
update_data.insert(key.clone(), str_val);
|
||||
}
|
||||
|
||||
println!("=== UPDATE DATA EXTRACTED ===");
|
||||
println!("{:?}", update_data);
|
||||
|
||||
let mut final_context_data = current_row_data.clone();
|
||||
final_context_data.extend(update_data.clone());
|
||||
|
||||
// FIX 2: Type-aware script validation
|
||||
println!("=== FINAL CONTEXT DATA FOR STEEL ===");
|
||||
println!("{:?}", final_context_data);
|
||||
|
||||
// Script validation with type-aware comparison
|
||||
for script_record in scripts {
|
||||
let target_column = script_record.target_column;
|
||||
|
||||
println!("=== PROCESSING SCRIPT FOR COLUMN: {} ===", target_column);
|
||||
|
||||
// Find the SQL type for this target column
|
||||
let target_sql_type = if let Some((_, stype)) = columns.iter().find(|(name, _)| name == &target_column) {
|
||||
stype.as_str()
|
||||
@@ -191,16 +219,11 @@ pub async fn put_table_data(
|
||||
"TEXT" // Default fallback for system columns
|
||||
};
|
||||
|
||||
let context = SteelContext {
|
||||
current_table: table_name.clone(),
|
||||
schema_id,
|
||||
schema_name: profile_name.clone(),
|
||||
row_data: final_context_data.clone(),
|
||||
db_pool: Arc::new(db_pool.clone()),
|
||||
};
|
||||
println!("Target column SQL type: {}", target_sql_type);
|
||||
println!("Executing script: {}", script_record.script);
|
||||
|
||||
let script_result = execution::execute_script(
|
||||
script_record.script,
|
||||
script_record.script.clone(),
|
||||
"STRINGS",
|
||||
Arc::new(db_pool.clone()),
|
||||
schema_id,
|
||||
@@ -209,82 +232,100 @@ pub async fn put_table_data(
|
||||
final_context_data.clone(),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| Status::invalid_argument(format!("Script execution failed for '{}': {}", target_column, e)))?;
|
||||
.map_err(|e| {
|
||||
error!("Script execution failed for '{}': {:?}", target_column, e);
|
||||
Status::invalid_argument(format!("Script execution failed for '{}': {}", target_column, e))
|
||||
})?;
|
||||
|
||||
let Value::Strings(mut script_output_vec) = script_result else {
|
||||
return Err(Status::internal("Script must return string values"));
|
||||
};
|
||||
let script_output = script_output_vec.pop().ok_or_else(|| Status::internal("Script returned no values"))?;
|
||||
|
||||
println!("Script output: '{}'", script_output);
|
||||
|
||||
if update_data.contains_key(&target_column) {
|
||||
// Case A: Column is being updated. Validate user input against script.
|
||||
let user_value = update_data.get(&target_column).unwrap();
|
||||
|
||||
println!("Case A: Validating user value '{}' against script output '{}'", user_value, script_output);
|
||||
|
||||
// TYPE-AWARE COMPARISON based on SQL type
|
||||
let values_match = match target_sql_type {
|
||||
s if s.starts_with("NUMERIC") => {
|
||||
// For NUMERIC columns, compare as decimals
|
||||
let user_decimal = Decimal::from_str(user_value).map_err(|_| Status::invalid_argument(format!("Invalid decimal format for column '{}'", target_column)))?;
|
||||
let script_decimal = Decimal::from_str(&script_output).map_err(|_| Status::internal(format!("Script for '{}' produced invalid decimal", target_column)))?;
|
||||
println!("Decimal comparison: user={}, script={}", user_decimal, script_decimal);
|
||||
user_decimal == script_decimal
|
||||
},
|
||||
"INTEGER" | "BIGINT" => {
|
||||
// For integer columns, compare as integers
|
||||
let user_int: i64 = user_value.parse().map_err(|_| Status::invalid_argument(format!("Invalid integer format for column '{}'", target_column)))?;
|
||||
let script_int: i64 = script_output.parse().map_err(|_| Status::internal(format!("Script for '{}' produced invalid integer", target_column)))?;
|
||||
println!("Integer comparison: user={}, script={}", user_int, script_int);
|
||||
user_int == script_int
|
||||
},
|
||||
"BOOLEAN" => {
|
||||
// For boolean columns, compare as booleans
|
||||
let user_bool: bool = user_value.parse().map_err(|_| Status::invalid_argument(format!("Invalid boolean format for column '{}'", target_column)))?;
|
||||
let script_bool: bool = script_output.parse().map_err(|_| Status::internal(format!("Script for '{}' produced invalid boolean", target_column)))?;
|
||||
println!("Boolean comparison: user={}, script={}", user_bool, script_bool);
|
||||
user_bool == script_bool
|
||||
},
|
||||
_ => {
|
||||
// For TEXT, TIMESTAMPTZ, DATE, etc. - compare as strings
|
||||
println!("String comparison: user='{}', script='{}'", user_value, script_output);
|
||||
user_value == &script_output
|
||||
}
|
||||
};
|
||||
|
||||
println!("Values match: {}", values_match);
|
||||
|
||||
if !values_match {
|
||||
return Err(Status::invalid_argument(format!("Validation failed for column '{}': Script calculated '{}', but user provided '{}'", target_column, script_output, user_value)));
|
||||
}
|
||||
} else {
|
||||
// Case B: Column is NOT being updated. Prevent unauthorized changes.
|
||||
let current_value = current_row_data.get(&target_column).cloned().unwrap_or_default();
|
||||
|
||||
println!("Case B: Checking if script would change current value '{}' to '{}'", current_value, script_output);
|
||||
|
||||
let values_match = match target_sql_type {
|
||||
s if s.starts_with("NUMERIC") => {
|
||||
let current_decimal = Decimal::from_str(¤t_value).unwrap_or_default();
|
||||
let script_decimal = Decimal::from_str(&script_output).unwrap_or_default();
|
||||
println!("Decimal comparison: current={}, script={}", current_decimal, script_decimal);
|
||||
current_decimal == script_decimal
|
||||
},
|
||||
"INTEGER" | "BIGINT" => {
|
||||
let current_int: i64 = current_value.parse().unwrap_or_default();
|
||||
let script_int: i64 = script_output.parse().unwrap_or_default();
|
||||
println!("Integer comparison: current={}, script={}", current_int, script_int);
|
||||
current_int == script_int
|
||||
},
|
||||
"BOOLEAN" => {
|
||||
let current_bool: bool = current_value.parse().unwrap_or(false);
|
||||
let script_bool: bool = script_output.parse().unwrap_or(false);
|
||||
println!("Boolean comparison: current={}, script={}", current_bool, script_bool);
|
||||
current_bool == script_bool
|
||||
},
|
||||
_ => {
|
||||
println!("String comparison: current='{}', script='{}'", current_value, script_output);
|
||||
current_value == script_output
|
||||
}
|
||||
};
|
||||
|
||||
println!("Values match: {}", values_match);
|
||||
|
||||
if !values_match {
|
||||
return Err(Status::failed_precondition(format!("Script for column '{}' was triggered and would change its value from '{}' to '{}'. To apply this change, please include '{}' in your update request.", target_column, current_value, script_output, target_column)));
|
||||
}
|
||||
}
|
||||
println!("=== END PROCESSING SCRIPT FOR COLUMN: {} ===", target_column);
|
||||
}
|
||||
|
||||
// --- Database Update with Full Validation ---
|
||||
// --- Database Update ---
|
||||
let mut params = PgArguments::default();
|
||||
let mut set_clauses = Vec::new();
|
||||
let mut param_idx = 1;
|
||||
|
||||
println!("=== BUILDING UPDATE QUERY ===");
|
||||
|
||||
for (col, proto_value) in request.data {
|
||||
let sql_type = if system_columns_set.contains(col.as_str()) {
|
||||
match col.as_str() {
|
||||
@@ -321,7 +362,6 @@ pub async fn put_table_data(
|
||||
if sql_type == "TEXT" {
|
||||
if let Kind::StringValue(value) = kind {
|
||||
let trimmed_value = value.trim();
|
||||
|
||||
if trimmed_value.is_empty() {
|
||||
params.add(None::<String>).map_err(|e| Status::internal(format!("Failed to add null parameter for {}: {}", col, e)))?;
|
||||
} else {
|
||||
@@ -424,6 +464,8 @@ pub async fn put_table_data(
|
||||
param_idx
|
||||
);
|
||||
|
||||
println!("UPDATE SQL: {}", sql);
|
||||
|
||||
params.add(record_id).map_err(|e| Status::internal(format!("Failed to add record_id parameter: {}", e)))?;
|
||||
|
||||
let result = sqlx::query_scalar_with::<_, i64, _>(&sql, params)
|
||||
@@ -459,6 +501,9 @@ pub async fn put_table_data(
|
||||
);
|
||||
}
|
||||
|
||||
println!("=== PUT TABLE DATA SUCCESS ===");
|
||||
println!("Updated record ID: {}", updated_id);
|
||||
|
||||
Ok(PutTableDataResponse {
|
||||
success: true,
|
||||
message: "Data updated successfully".into(),
|
||||
|
||||
@@ -117,7 +117,7 @@ async fn test_put_basic_arithmetic_validation_success(pool: PgPool) {
|
||||
let script_request = PostTableScriptRequest {
|
||||
table_definition_id: table_def_id,
|
||||
target_column: "total".to_string(),
|
||||
script: "(* $price $quantity)".to_string(),
|
||||
script: "(* @price @quantity)".to_string(),
|
||||
description: "Total = Price × Quantity".to_string(),
|
||||
};
|
||||
post_table_script(&pool, script_request).await.unwrap();
|
||||
@@ -155,7 +155,7 @@ async fn test_put_basic_arithmetic_validation_failure(pool: PgPool) {
|
||||
let script_request = PostTableScriptRequest {
|
||||
table_definition_id: table_def_id,
|
||||
target_column: "total".to_string(),
|
||||
script: "(* $price $quantity)".to_string(),
|
||||
script: "(* @price @quantity)".to_string(),
|
||||
description: "Total = Price × Quantity".to_string(),
|
||||
};
|
||||
post_table_script(&pool, script_request).await.unwrap();
|
||||
@@ -187,7 +187,7 @@ async fn test_put_basic_arithmetic_validation_failure(pool: PgPool) {
|
||||
assert_eq!(error.code(), tonic::Code::InvalidArgument);
|
||||
let msg = error.message();
|
||||
assert!(msg.contains("Validation failed for column 'total'"));
|
||||
assert!(msg.contains("Script calculated '76.5'"));
|
||||
assert!(msg.contains("Script calculated '76.50'"));
|
||||
assert!(msg.contains("but user provided '70.00'"));
|
||||
}
|
||||
|
||||
@@ -199,7 +199,7 @@ async fn test_put_complex_formula_validation(pool: PgPool) {
|
||||
let script_request = PostTableScriptRequest {
|
||||
table_definition_id: table_def_id,
|
||||
target_column: "total".to_string(),
|
||||
script: "(+ (* $price $quantity) (* (* $price $quantity) 0.08))".to_string(),
|
||||
script: "(+ (* @price @quantity) (* (* @price @quantity) 0.08))".to_string(),
|
||||
description: "Total with 8% tax".to_string(),
|
||||
};
|
||||
post_table_script(&pool, script_request).await.unwrap();
|
||||
@@ -236,7 +236,7 @@ async fn test_put_division_with_precision(pool: PgPool) {
|
||||
let script_request = PostTableScriptRequest {
|
||||
table_definition_id: table_def_id,
|
||||
target_column: "percentage".to_string(),
|
||||
script: "(/ $total $price)".to_string(),
|
||||
script: "(/ @total @price)".to_string(),
|
||||
description: "Percentage = Total / Price".to_string(),
|
||||
};
|
||||
post_table_script(&pool, script_request).await.unwrap();
|
||||
@@ -301,7 +301,7 @@ async fn test_put_advanced_math_functions(pool: PgPool) {
|
||||
let sqrt_script = PostTableScriptRequest {
|
||||
table_definition_id: table_row.id,
|
||||
target_column: "square_root".to_string(),
|
||||
script: "(sqrt $input)".to_string(),
|
||||
script: "(sqrt @input)".to_string(),
|
||||
description: "Square root validation".to_string(),
|
||||
};
|
||||
post_table_script(&pool, sqrt_script).await.unwrap();
|
||||
@@ -309,7 +309,7 @@ async fn test_put_advanced_math_functions(pool: PgPool) {
|
||||
let power_script = PostTableScriptRequest {
|
||||
table_definition_id: table_row.id,
|
||||
target_column: "power_result".to_string(),
|
||||
script: "(^ $input 2.0)".to_string(),
|
||||
script: "(^ @input 2.0)".to_string(),
|
||||
description: "Power function validation".to_string(),
|
||||
};
|
||||
post_table_script(&pool, power_script).await.unwrap();
|
||||
@@ -364,7 +364,7 @@ async fn test_put_financial_calculations(pool: PgPool) {
|
||||
let compound_script = PostTableScriptRequest {
|
||||
table_definition_id: table_row.id,
|
||||
target_column: "compound_result".to_string(),
|
||||
script: "(* $principal (^ (+ 1.0 $rate) $time))".to_string(),
|
||||
script: "(* @principal (^ (+ 1.0 @rate) @time))".to_string(),
|
||||
description: "Compound interest calculation".to_string(),
|
||||
};
|
||||
post_table_script(&pool, compound_script).await.unwrap();
|
||||
@@ -372,7 +372,7 @@ async fn test_put_financial_calculations(pool: PgPool) {
|
||||
let percentage_script = PostTableScriptRequest {
|
||||
table_definition_id: table_row.id,
|
||||
target_column: "percentage_result".to_string(),
|
||||
script: "(* $principal $rate)".to_string(),
|
||||
script: "(* @principal @rate)".to_string(),
|
||||
description: "Percentage calculation".to_string(),
|
||||
};
|
||||
post_table_script(&pool, percentage_script).await.unwrap();
|
||||
@@ -416,17 +416,15 @@ async fn test_put_partial_update_with_validation(pool: PgPool) {
|
||||
let script_request = PostTableScriptRequest {
|
||||
table_definition_id: table_def_id,
|
||||
target_column: "total".to_string(),
|
||||
script: "(* $price $quantity)".to_string(),
|
||||
script: "(* @price @quantity)".to_string(),
|
||||
description: "Total = Price × Quantity".to_string(),
|
||||
};
|
||||
post_table_script(&pool, script_request).await.unwrap();
|
||||
|
||||
let record_id = create_initial_record(&pool, "test_put_partial", "invoice", &indexer_tx).await;
|
||||
|
||||
// Partial update: only update quantity. Validation for total should still run and pass.
|
||||
// The merged context will be { price: 10.00, quantity: 5, total: 10.00, ... }
|
||||
// The script will calculate total as 10.00 * 5 = 50.00.
|
||||
// Since we are not providing 'total' in the update, validation for it is skipped.
|
||||
// Partial update: only update quantity. The script detects this would change total
|
||||
// from 10.00 to 50.00 and requires the user to include 'total' in the update.
|
||||
let mut update_data = HashMap::new();
|
||||
update_data.insert("quantity".to_string(), ProtoValue {
|
||||
kind: Some(Kind::NumberValue(5.0)),
|
||||
@@ -439,8 +437,13 @@ async fn test_put_partial_update_with_validation(pool: PgPool) {
|
||||
data: update_data,
|
||||
};
|
||||
|
||||
let response = put_table_data(&pool, put_request, &indexer_tx).await.unwrap();
|
||||
assert!(response.success);
|
||||
// This should fail because script would change total value
|
||||
let result = put_table_data(&pool, put_request, &indexer_tx).await;
|
||||
assert!(result.is_err());
|
||||
let error = result.unwrap_err();
|
||||
assert_eq!(error.code(), tonic::Code::FailedPrecondition);
|
||||
assert!(error.message().contains("Script for column 'total' was triggered"));
|
||||
assert!(error.message().contains("from '10.00' to '50.00'"));
|
||||
|
||||
// Now, test a partial update that SHOULD fail validation.
|
||||
// We update quantity and provide an incorrect total.
|
||||
@@ -463,7 +466,8 @@ async fn test_put_partial_update_with_validation(pool: PgPool) {
|
||||
assert!(result.is_err());
|
||||
let error = result.unwrap_err();
|
||||
assert_eq!(error.code(), tonic::Code::InvalidArgument);
|
||||
assert!(error.message().contains("Script calculated '30'"));
|
||||
assert!(error.message().contains("Script calculated '30.00'"));
|
||||
assert!(error.message().contains("but user provided '99.99'"));
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
@@ -520,7 +524,7 @@ async fn test_put_steel_script_error_handling(pool: PgPool) {
|
||||
let script_request = PostTableScriptRequest {
|
||||
table_definition_id: table_def_id,
|
||||
target_column: "total".to_string(),
|
||||
script: "(/ $price 0.0)".to_string(),
|
||||
script: "(/ @price 0.0)".to_string(),
|
||||
description: "Error test".to_string(),
|
||||
};
|
||||
post_table_script(&pool, script_request).await.unwrap();
|
||||
|
||||
Reference in New Issue
Block a user