Fix warnings and clean up debug output for production

filipriec
2025-07-24 23:57:21 +02:00
parent c82813185f
commit c58ce52b33
5 changed files with 173 additions and 188 deletions

View File

@@ -1,4 +1,5 @@
// src/steel/server/execution.rs // src/steel/server/execution.rs
use steel::steel_vm::engine::Engine; use steel::steel_vm::engine::Engine;
use steel::steel_vm::register_fn::RegisterFn; use steel::steel_vm::register_fn::RegisterFn;
use steel::rvals::SteelVal; use steel::rvals::SteelVal;
@@ -10,13 +11,13 @@ use std::collections::HashMap;
use thiserror::Error; use thiserror::Error;
use tracing::{debug, error}; use tracing::{debug, error};
/// Represents different types of values that can be returned from Steel script execution.
#[derive(Debug)] #[derive(Debug)]
pub enum Value { pub enum Value {
Strings(Vec<String>), Strings(Vec<String>),
Numbers(Vec<i64>),
Mixed(Vec<SteelVal>),
} }
/// Errors that can occur during Steel script execution.
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum ExecutionError { pub enum ExecutionError {
#[error("Script execution failed: {0}")] #[error("Script execution failed: {0}")]
@@ -27,7 +28,7 @@ pub enum ExecutionError {
UnsupportedType(String), UnsupportedType(String),
} }
/// Create a SteelContext with boolean conversion applied to row data /// Creates a Steel execution context with proper boolean value conversion.
pub async fn create_steel_context_with_boolean_conversion( pub async fn create_steel_context_with_boolean_conversion(
current_table: String, current_table: String,
schema_id: i64, schema_id: i64,
@@ -35,10 +36,6 @@ pub async fn create_steel_context_with_boolean_conversion(
mut row_data: HashMap<String, String>, mut row_data: HashMap<String, String>,
db_pool: Arc<PgPool>, db_pool: Arc<PgPool>,
) -> Result<SteelContext, ExecutionError> { ) -> Result<SteelContext, ExecutionError> {
println!("=== CREATING STEEL CONTEXT ===");
println!("Table: {}, Schema: {}", current_table, schema_name);
println!("Row data BEFORE boolean conversion: {:?}", row_data);
// Convert boolean values in row_data to Steel format // Convert boolean values in row_data to Steel format
convert_row_data_for_steel(&db_pool, schema_id, &current_table, &mut row_data) convert_row_data_for_steel(&db_pool, schema_id, &current_table, &mut row_data)
.await .await
@@ -47,9 +44,6 @@ pub async fn create_steel_context_with_boolean_conversion(
ExecutionError::RuntimeError(format!("Failed to convert row data: {}", e)) ExecutionError::RuntimeError(format!("Failed to convert row data: {}", e))
})?; })?;
println!("Row data AFTER boolean conversion: {:?}", row_data);
println!("=== END CREATING STEEL CONTEXT ===");
Ok(SteelContext { Ok(SteelContext {
current_table, current_table,
schema_id, schema_id,
@@ -59,7 +53,7 @@ pub async fn create_steel_context_with_boolean_conversion(
}) })
} }
/// Execute script with proper boolean handling /// Executes a Steel script with database context and type-safe result processing.
pub async fn execute_script( pub async fn execute_script(
script: String, script: String,
target_type: &str, target_type: &str,
@@ -69,20 +63,9 @@ pub async fn execute_script(
current_table: String, current_table: String,
row_data: HashMap<String, String>, row_data: HashMap<String, String>,
) -> Result<Value, ExecutionError> { ) -> Result<Value, ExecutionError> {
eprintln!("🚨🚨🚨 EXECUTE_SCRIPT CALLED 🚨🚨🚨");
eprintln!("Script: '{}'", script);
eprintln!("Table: {}, Target type: {}", current_table, target_type);
eprintln!("Input row_data: {:?}", row_data);
eprintln!("🚨🚨🚨 END EXECUTE_SCRIPT ENTRY 🚨🚨🚨");
println!("=== STEEL SCRIPT EXECUTION START ===");
println!("Script: '{}'", script);
println!("Table: {}, Target type: {}", current_table, target_type);
println!("Input row_data: {:?}", row_data);
let mut vm = Engine::new(); let mut vm = Engine::new();
// Create context with boolean conversion // Create execution context with proper boolean value conversion
let context = create_steel_context_with_boolean_conversion( let context = create_steel_context_with_boolean_conversion(
current_table, current_table,
schema_id, schema_id,
@@ -93,41 +76,32 @@ pub async fn execute_script(
let context = Arc::new(context); let context = Arc::new(context);
// Register existing Steel functions // Register database access functions
register_steel_functions(&mut vm, context.clone()); register_steel_functions(&mut vm, context.clone());
// Register all decimal math functions using the steel_decimal crate // Register decimal math operations
register_decimal_math_functions(&mut vm); register_decimal_math_functions(&mut vm);
// Register variables from the context with the Steel VM // Register row data as variables in the Steel VM
// The row_data now contains Steel-formatted boolean values // Both bare names and @-prefixed names are supported for flexibility
println!("=== REGISTERING STEEL VARIABLES ===");
// Manual variable registration using Steel's define mechanism
let mut define_script = String::new(); let mut define_script = String::new();
println!("Variables being registered with Steel VM:");
for (key, value) in &context.row_data { for (key, value) in &context.row_data {
println!(" STEEL[{}] = '{}'", key, value);
// Register both @ prefixed and bare variable names // Register both @ prefixed and bare variable names
define_script.push_str(&format!("(define {} \"{}\")\n", key, value)); define_script.push_str(&format!("(define {} \"{}\")\n", key, value));
define_script.push_str(&format!("(define @{} \"{}\")\n", key, value)); define_script.push_str(&format!("(define @{} \"{}\")\n", key, value));
} }
println!("Steel script to execute: {}", script);
println!("=== END REGISTERING STEEL VARIABLES ===");
// Execute variable definitions if any exist
if !define_script.is_empty() { if !define_script.is_empty() {
println!("Define script: {}", define_script);
vm.compile_and_run_raw_program(define_script) vm.compile_and_run_raw_program(define_script)
.map_err(|e| ExecutionError::RuntimeError(format!("Failed to register variables: {}", e)))?; .map_err(|e| ExecutionError::RuntimeError(format!("Failed to register variables: {}", e)))?;
println!("Variables defined successfully");
} }
// Also try the original method as backup // Also register variables using the decimal registry as backup method
FunctionRegistry::register_variables(&mut vm, context.row_data.clone()); FunctionRegistry::register_variables(&mut vm, context.row_data.clone());
// Execute script and process results // Execute the main script
println!("Compiling and running Steel script: {}", script);
let results = vm.compile_and_run_raw_program(script.clone()) let results = vm.compile_and_run_raw_program(script.clone())
.map_err(|e| { .map_err(|e| {
error!("Steel script execution failed: {}", e); error!("Steel script execution failed: {}", e);
@@ -136,27 +110,18 @@ pub async fn execute_script(
ExecutionError::RuntimeError(e.to_string()) ExecutionError::RuntimeError(e.to_string())
})?; })?;
println!("Script execution returned {} results", results.len()); // Convert results to the requested target type
for (i, result) in results.iter().enumerate() {
println!("Result[{}]: {:?}", i, result);
}
// Convert results to target type
match target_type { match target_type {
"STRINGS" => { "STRINGS" => process_string_results(results),
let result = process_string_results(results);
println!("Final processed result: {:?}", result);
println!("=== STEEL SCRIPT EXECUTION END ===");
result
},
_ => Err(ExecutionError::UnsupportedType(target_type.into())) _ => Err(ExecutionError::UnsupportedType(target_type.into()))
} }
} }
/// Registers Steel functions for database access within the VM context.
fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) { fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
debug!("Registering Steel functions with context"); debug!("Registering Steel functions with context");
// Register steel_get_column with row context // Register column access function for current and related tables
vm.register_fn("steel_get_column", { vm.register_fn("steel_get_column", {
let ctx = context.clone(); let ctx = context.clone();
move |table: String, column: String| { move |table: String, column: String| {
@@ -169,7 +134,7 @@ fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
} }
}); });
// Register steel_get_column_with_index // Register indexed column access for comma-separated values
vm.register_fn("steel_get_column_with_index", { vm.register_fn("steel_get_column_with_index", {
let ctx = context.clone(); let ctx = context.clone();
move |table: String, index: i64, column: String| { move |table: String, index: i64, column: String| {
@@ -182,7 +147,7 @@ fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
} }
}); });
// SQL query registration // Register safe SQL query execution
vm.register_fn("steel_query_sql", { vm.register_fn("steel_query_sql", {
let ctx = context.clone(); let ctx = context.clone();
move |query: String| { move |query: String| {
@@ -196,45 +161,31 @@ fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
}); });
} }
/// Registers decimal mathematics functions in the Steel VM.
fn register_decimal_math_functions(vm: &mut Engine) { fn register_decimal_math_functions(vm: &mut Engine) {
debug!("Registering decimal math functions"); debug!("Registering decimal math functions");
// Use the steel_decimal crate's FunctionRegistry to register all functions
FunctionRegistry::register_all(vm); FunctionRegistry::register_all(vm);
} }
/// Processes Steel script results into string format for consistent output.
fn process_string_results(results: Vec<SteelVal>) -> Result<Value, ExecutionError> { fn process_string_results(results: Vec<SteelVal>) -> Result<Value, ExecutionError> {
let mut strings = Vec::new(); let mut strings = Vec::new();
for result in results { for result in results {
match result { let result_str = match result {
SteelVal::StringV(s) => { SteelVal::StringV(s) => s.to_string(),
let result_str = s.to_string(); SteelVal::NumV(n) => n.to_string(),
println!("Processing string result: '{}'", result_str); SteelVal::IntV(i) => i.to_string(),
strings.push(result_str); SteelVal::BoolV(b) => b.to_string(),
},
SteelVal::NumV(n) => {
let result_str = n.to_string();
println!("Processing number result: '{}'", result_str);
strings.push(result_str);
},
SteelVal::IntV(i) => {
let result_str = i.to_string();
println!("Processing integer result: '{}'", result_str);
strings.push(result_str);
},
SteelVal::BoolV(b) => {
let result_str = b.to_string();
println!("Processing boolean result: '{}'", result_str);
strings.push(result_str);
},
_ => { _ => {
error!("Unexpected result type: {:?}", result); error!("Unexpected result type: {:?}", result);
return Err(ExecutionError::TypeConversionError( return Err(ExecutionError::TypeConversionError(
format!("Expected string-convertible type, got {:?}", result) format!("Expected string-convertible type, got {:?}", result)
)); ));
} }
} };
strings.push(result_str);
} }
println!("Final processed strings: {:?}", strings);
Ok(Value::Strings(strings)) Ok(Value::Strings(strings))
} }
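A minimal, self-contained sketch of the variable-registration step in execute_script above, assuming row values are already Steel-formatted strings; the helper name build_define_script is illustrative and not part of this commit.

use std::collections::HashMap;

// Builds the same kind of define script that execute_script generates above:
// each column is defined under both its bare name and an @-prefixed alias.
fn build_define_script(row_data: &HashMap<String, String>) -> String {
    let mut script = String::new();
    for (key, value) in row_data {
        script.push_str(&format!("(define {} \"{}\")\n", key, value));
        script.push_str(&format!("(define @{} \"{}\")\n", key, value));
    }
    script
}

fn main() {
    let mut row = HashMap::new();
    row.insert("price".to_string(), "100".to_string());
    row.insert("active".to_string(), "#true".to_string());
    // Produces lines like: (define price "100") and (define @price "100")
    println!("{}", build_define_script(&row));
}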

View File

@@ -1,4 +1,5 @@
// src/steel/server/functions.rs // src/steel/server/functions.rs
use steel::rvals::SteelVal; use steel::rvals::SteelVal;
use sqlx::PgPool; use sqlx::PgPool;
use std::collections::HashMap; use std::collections::HashMap;
@@ -20,9 +21,10 @@ pub enum FunctionError {
ProhibitedTypeAccess(String), ProhibitedTypeAccess(String),
} }
// Define prohibited data types (boolean is explicitly allowed) /// Data types that Steel scripts are prohibited from accessing for security reasons
const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"]; const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"];
/// Execution context for Steel scripts with database access capabilities.
#[derive(Clone)] #[derive(Clone)]
pub struct SteelContext { pub struct SteelContext {
pub current_table: String, pub current_table: String,
@@ -33,6 +35,8 @@ pub struct SteelContext {
} }
impl SteelContext { impl SteelContext {
/// Resolves a base table name to its full qualified name in the current schema.
/// Used for foreign key relationship traversal in Steel scripts.
pub async fn get_related_table_name(&self, base_name: &str) -> Result<String, FunctionError> { pub async fn get_related_table_name(&self, base_name: &str) -> Result<String, FunctionError> {
let table_def = sqlx::query!( let table_def = sqlx::query!(
r#"SELECT table_name FROM table_definitions r#"SELECT table_name FROM table_definitions
@@ -48,7 +52,8 @@ impl SteelContext {
Ok(table_def.table_name) Ok(table_def.table_name)
} }
/// Get column type for a given table and column /// Retrieves the SQL data type for a specific column in a table.
/// Parses the JSON column definitions to find type information.
async fn get_column_type(&self, table_name: &str, column_name: &str) -> Result<String, FunctionError> { async fn get_column_type(&self, table_name: &str, column_name: &str) -> Result<String, FunctionError> {
let table_def = sqlx::query!( let table_def = sqlx::query!(
r#"SELECT columns FROM table_definitions r#"SELECT columns FROM table_definitions
@@ -61,11 +66,10 @@ impl SteelContext {
.map_err(|e| FunctionError::DatabaseError(e.to_string()))? .map_err(|e| FunctionError::DatabaseError(e.to_string()))?
.ok_or_else(|| FunctionError::TableNotFound(table_name.to_string()))?; .ok_or_else(|| FunctionError::TableNotFound(table_name.to_string()))?;
// Parse columns JSON to find the column type
let columns: Vec<String> = serde_json::from_value(table_def.columns) let columns: Vec<String> = serde_json::from_value(table_def.columns)
.map_err(|e| FunctionError::DatabaseError(format!("Invalid column data: {}", e)))?; .map_err(|e| FunctionError::DatabaseError(format!("Invalid column data: {}", e)))?;
// Find the column and its type // Parse column definitions to find the requested column type
for column_def in columns { for column_def in columns {
let mut parts = column_def.split_whitespace(); let mut parts = column_def.split_whitespace();
if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) { if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) {
@@ -83,31 +87,30 @@ impl SteelContext {
))) )))
} }
/// Convert database value to Steel format based on column type /// Converts database values to Steel script format based on column type.
/// Currently handles boolean conversion to Steel's #true/#false syntax.
fn convert_value_to_steel_format(&self, value: &str, column_type: &str) -> String { fn convert_value_to_steel_format(&self, value: &str, column_type: &str) -> String {
let normalized_type = normalize_data_type(column_type); let normalized_type = normalize_data_type(column_type);
match normalized_type.as_str() { match normalized_type.as_str() {
"BOOLEAN" | "BOOL" => { "BOOLEAN" | "BOOL" => {
// Convert database boolean to Steel boolean syntax // Convert database boolean representations to Steel boolean syntax
match value.to_lowercase().as_str() { match value.to_lowercase().as_str() {
"true" | "t" | "1" | "yes" | "on" => "#true".to_string(), "true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
"false" | "f" | "0" | "no" | "off" => "#false".to_string(), "false" | "f" | "0" | "no" | "off" => "#false".to_string(),
_ => value.to_string(), // Return as-is if not a recognized boolean _ => value.to_string(), // Return as-is if not a recognized boolean
} }
} }
"INTEGER" => { "INTEGER" => value.to_string(),
value.to_string() _ => value.to_string(), // Return as-is for other types
}
_ => value.to_string(), // Return as-is for non-boolean types
} }
} }
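A standalone sketch of the boolean mapping that convert_value_to_steel_format applies, shown without the database-backed type lookup; the function name to_steel_bool is illustrative.

// Maps common database boolean spellings to Steel's #true/#false syntax;
// unrecognized values pass through unchanged, mirroring the method above.
fn to_steel_bool(value: &str) -> String {
    match value.to_lowercase().as_str() {
        "true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
        "false" | "f" | "0" | "no" | "off" => "#false".to_string(),
        _ => value.to_string(),
    }
}

fn main() {
    assert_eq!(to_steel_bool("t"), "#true");
    assert_eq!(to_steel_bool("off"), "#false");
    assert_eq!(to_steel_bool("maybe"), "maybe");
}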
/// Validate column type and return the column type if valid /// Validates that a column type is allowed for Steel script access.
/// Returns the column type if validation passes, error if prohibited.
async fn validate_column_type_and_get_type(&self, table_name: &str, column_name: &str) -> Result<String, FunctionError> { async fn validate_column_type_and_get_type(&self, table_name: &str, column_name: &str) -> Result<String, FunctionError> {
let column_type = self.get_column_type(table_name, column_name).await?; let column_type = self.get_column_type(table_name, column_name).await?;
// Check if this type is prohibited
if is_prohibited_type(&column_type) { if is_prohibited_type(&column_type) {
return Err(FunctionError::ProhibitedTypeAccess(format!( return Err(FunctionError::ProhibitedTypeAccess(format!(
"Cannot access column '{}' in table '{}' because it has prohibited type '{}'. Steel scripts cannot access columns of type: {}", "Cannot access column '{}' in table '{}' because it has prohibited type '{}'. Steel scripts cannot access columns of type: {}",
@@ -121,15 +124,15 @@ impl SteelContext {
Ok(column_type) Ok(column_type)
} }
/// Validate column type before access (legacy method for backward compatibility) /// Retrieves column value from current table or related tables via foreign keys.
async fn validate_column_type(&self, table_name: &str, column_name: &str) -> Result<(), FunctionError> { ///
self.validate_column_type_and_get_type(table_name, column_name).await?; /// # Behavior
Ok(()) /// - Current table: Returns value directly from row_data with type conversion
} /// - Related table: Follows foreign key relationship and queries database
/// - All accesses are subject to prohibited type validation
pub fn steel_get_column(&self, table: &str, column: &str) -> Result<SteelVal, SteelVal> { pub fn steel_get_column(&self, table: &str, column: &str) -> Result<SteelVal, SteelVal> {
if table == self.current_table { if table == self.current_table {
// Validate column type for current table access and get the type // Access current table data with type validation
let column_type = tokio::task::block_in_place(|| { let column_type = tokio::task::block_in_place(|| {
let handle = tokio::runtime::Handle::current(); let handle = tokio::runtime::Handle::current();
handle.block_on(async { handle.block_on(async {
@@ -150,6 +153,7 @@ impl SteelContext {
.ok_or_else(|| SteelVal::StringV(format!("Column {} not found", column).into())); .ok_or_else(|| SteelVal::StringV(format!("Column {} not found", column).into()));
} }
// Access related table via foreign key relationship
let base_name = table.split_once('_') let base_name = table.split_once('_')
.map(|(_, rest)| rest) .map(|(_, rest)| rest)
.unwrap_or(table); .unwrap_or(table);
@@ -158,18 +162,17 @@ impl SteelContext {
let fk_value = self.row_data.get(&fk_column) let fk_value = self.row_data.get(&fk_column)
.ok_or_else(|| SteelVal::StringV(format!("Foreign key {} not found", fk_column).into()))?; .ok_or_else(|| SteelVal::StringV(format!("Foreign key {} not found", fk_column).into()))?;
// Use `tokio::task::block_in_place` to safely block the thread
let result = tokio::task::block_in_place(|| { let result = tokio::task::block_in_place(|| {
let handle = tokio::runtime::Handle::current(); let handle = tokio::runtime::Handle::current();
handle.block_on(async { handle.block_on(async {
let actual_table = self.get_related_table_name(base_name).await let actual_table = self.get_related_table_name(base_name).await
.map_err(|e| SteelVal::StringV(e.to_string().into()))?; .map_err(|e| SteelVal::StringV(e.to_string().into()))?;
// Get column type for validation and conversion // Validate column type and get type information
let column_type = self.validate_column_type_and_get_type(&actual_table, column).await let column_type = self.validate_column_type_and_get_type(&actual_table, column).await
.map_err(|e| SteelVal::StringV(e.to_string().into()))?; .map_err(|e| SteelVal::StringV(e.to_string().into()))?;
// Fetch the raw value from database // Query the related table for the column value
let raw_value = sqlx::query_scalar::<_, String>( let raw_value = sqlx::query_scalar::<_, String>(
&format!("SELECT {} FROM \"{}\".\"{}\" WHERE id = $1", column, self.schema_name, actual_table) &format!("SELECT {} FROM \"{}\".\"{}\" WHERE id = $1", column, self.schema_name, actual_table)
) )
@@ -179,7 +182,7 @@ impl SteelContext {
.await .await
.map_err(|e| SteelVal::StringV(e.to_string().into()))?; .map_err(|e| SteelVal::StringV(e.to_string().into()))?;
// Convert to Steel format // Convert to appropriate Steel format
let converted_value = self.convert_value_to_steel_format(&raw_value, &column_type); let converted_value = self.convert_value_to_steel_format(&raw_value, &column_type);
Ok(converted_value) Ok(converted_value)
}) })
@@ -188,13 +191,15 @@ impl SteelContext {
result.map(|v| SteelVal::StringV(v.into())) result.map(|v| SteelVal::StringV(v.into()))
} }
/// Retrieves a specific indexed element from a comma-separated column value.
/// Useful for accessing elements from array-like string representations.
pub fn steel_get_column_with_index( pub fn steel_get_column_with_index(
&self, &self,
table: &str, table: &str,
index: i64, index: i64,
column: &str column: &str
) -> Result<SteelVal, SteelVal> { ) -> Result<SteelVal, SteelVal> {
// Get the full value first (this already handles type conversion) // Get the full value with proper type conversion
let value = self.steel_get_column(table, column)?; let value = self.steel_get_column(table, column)?;
if let SteelVal::StringV(s) = value { if let SteelVal::StringV(s) = value {
@@ -203,8 +208,7 @@ impl SteelContext {
if let Some(part) = parts.get(index as usize) { if let Some(part) = parts.get(index as usize) {
let trimmed_part = part.trim(); let trimmed_part = part.trim();
// If the original column was boolean type, each part should also be treated as boolean // Apply type conversion to the indexed part based on original column type
// We need to get the column type to determine if conversion is needed
let column_type = tokio::task::block_in_place(|| { let column_type = tokio::task::block_in_place(|| {
let handle = tokio::runtime::Handle::current(); let handle = tokio::runtime::Handle::current();
handle.block_on(async { handle.block_on(async {
@@ -218,7 +222,7 @@ impl SteelContext {
Ok(SteelVal::StringV(converted_part.into())) Ok(SteelVal::StringV(converted_part.into()))
} }
Err(_) => { Err(_) => {
// If we can't get the type, return as-is // If type cannot be determined, return value as-is
Ok(SteelVal::StringV(trimmed_part.into())) Ok(SteelVal::StringV(trimmed_part.into()))
} }
} }
@@ -230,15 +234,19 @@ impl SteelContext {
} }
} }
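A reduced sketch of the comma-splitting lookup inside steel_get_column_with_index, before any type conversion is applied; pick_indexed_part is an illustrative stand-in.

// Splits a comma-separated column value and returns the trimmed element at `index`,
// the same lookup steel_get_column_with_index performs before converting the part.
fn pick_indexed_part(full_value: &str, index: usize) -> Option<String> {
    full_value
        .split(',')
        .map(|part| part.trim().to_string())
        .nth(index)
}

fn main() {
    let value = "red, green, blue";
    assert_eq!(pick_indexed_part(value, 1).as_deref(), Some("green"));
    assert_eq!(pick_indexed_part(value, 5), None);
}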
/// Executes read-only SQL queries from Steel scripts with safety restrictions.
///
/// # Security Features
/// - Only SELECT, SHOW, and EXPLAIN queries allowed
/// - Prohibited column type access validation
/// - Returns first column of all rows as comma-separated string
pub fn steel_query_sql(&self, query: &str) -> Result<SteelVal, SteelVal> { pub fn steel_query_sql(&self, query: &str) -> Result<SteelVal, SteelVal> {
// Validate query is read-only
if !is_read_only_query(query) { if !is_read_only_query(query) {
return Err(SteelVal::StringV( return Err(SteelVal::StringV(
"Only SELECT queries are allowed".into() "Only SELECT queries are allowed".into()
)); ));
} }
// Check if query might access prohibited columns
if contains_prohibited_column_access(query) { if contains_prohibited_column_access(query) {
return Err(SteelVal::StringV(format!( return Err(SteelVal::StringV(format!(
"SQL query may access prohibited column types. Steel scripts cannot access columns of type: {}", "SQL query may access prohibited column types. Steel scripts cannot access columns of type: {}",
@@ -248,11 +256,9 @@ impl SteelContext {
let pool = self.db_pool.clone(); let pool = self.db_pool.clone();
// Use `tokio::task::block_in_place` to safely block the thread
let result = tokio::task::block_in_place(|| { let result = tokio::task::block_in_place(|| {
let handle = tokio::runtime::Handle::current(); let handle = tokio::runtime::Handle::current();
handle.block_on(async { handle.block_on(async {
// Execute and get first column of all rows as strings
let rows = sqlx::query(query) let rows = sqlx::query(query)
.fetch_all(&*pool) .fetch_all(&*pool)
.await .await
@@ -273,13 +279,14 @@ impl SteelContext {
} }
} }
/// Check if a data type is prohibited for Steel scripts /// Checks if a data type is prohibited for Steel script access.
fn is_prohibited_type(data_type: &str) -> bool { fn is_prohibited_type(data_type: &str) -> bool {
let normalized_type = normalize_data_type(data_type); let normalized_type = normalize_data_type(data_type);
PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited)) PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited))
} }
/// Normalize data type for comparison (handle NUMERIC variations, etc.) /// Normalizes data type strings for consistent comparison.
/// Handles variations like NUMERIC(10,2) by extracting base type.
fn normalize_data_type(data_type: &str) -> String { fn normalize_data_type(data_type: &str) -> String {
data_type.to_uppercase() data_type.to_uppercase()
.split('(') // Remove precision/scale from NUMERIC(x,y) .split('(') // Remove precision/scale from NUMERIC(x,y)
@@ -289,13 +296,11 @@ fn normalize_data_type(data_type: &str) -> String {
.to_string() .to_string()
} }
/// Basic check for prohibited column access in SQL queries /// Performs basic heuristic check for prohibited column type access in SQL queries.
/// This is a simple heuristic - more sophisticated parsing could be added /// Looks for common patterns that might indicate access to restricted types.
fn contains_prohibited_column_access(query: &str) -> bool { fn contains_prohibited_column_access(query: &str) -> bool {
let query_upper = query.to_uppercase(); let query_upper = query.to_uppercase();
// Look for common patterns that might indicate prohibited type access
// This is a basic implementation - you might want to enhance this
let patterns = [ let patterns = [
"EXTRACT(", // Common with DATE/TIMESTAMPTZ "EXTRACT(", // Common with DATE/TIMESTAMPTZ
"DATE_PART(", // Common with DATE/TIMESTAMPTZ "DATE_PART(", // Common with DATE/TIMESTAMPTZ
@@ -307,6 +312,7 @@ fn contains_prohibited_column_access(query: &str) -> bool {
patterns.iter().any(|pattern| query_upper.contains(pattern)) patterns.iter().any(|pattern| query_upper.contains(pattern))
} }
/// Validates that a query is read-only and safe for Steel script execution.
fn is_read_only_query(query: &str) -> bool { fn is_read_only_query(query: &str) -> bool {
let query = query.trim_start().to_uppercase(); let query = query.trim_start().to_uppercase();
query.starts_with("SELECT") || query.starts_with("SELECT") ||
@@ -314,14 +320,13 @@ fn is_read_only_query(query: &str) -> bool {
query.starts_with("EXPLAIN") query.starts_with("EXPLAIN")
} }
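The read-only guard above, restated as a runnable check; it assumes the same SELECT/SHOW/EXPLAIN whitelist documented on steel_query_sql.

// Mirrors the query guard above: only read-only statements are accepted.
fn is_read_only_query(query: &str) -> bool {
    let query = query.trim_start().to_uppercase();
    query.starts_with("SELECT") || query.starts_with("SHOW") || query.starts_with("EXPLAIN")
}

fn main() {
    assert!(is_read_only_query("  select name from users"));
    assert!(is_read_only_query("EXPLAIN SELECT 1"));
    assert!(!is_read_only_query("DELETE FROM users"));
}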
/// Helper function to convert initial row data for boolean columns /// Converts row data boolean values to Steel script format during context initialization.
pub async fn convert_row_data_for_steel( pub async fn convert_row_data_for_steel(
db_pool: &PgPool, db_pool: &PgPool,
schema_id: i64, schema_id: i64,
table_name: &str, table_name: &str,
row_data: &mut HashMap<String, String>, row_data: &mut HashMap<String, String>,
) -> Result<(), sqlx::Error> { ) -> Result<(), sqlx::Error> {
// Get table definition to check column types
let table_def = sqlx::query!( let table_def = sqlx::query!(
r#"SELECT columns FROM table_definitions r#"SELECT columns FROM table_definitions
WHERE schema_id = $1 AND table_name = $2"#, WHERE schema_id = $1 AND table_name = $2"#,
@@ -332,7 +337,7 @@ pub async fn convert_row_data_for_steel(
.await? .await?
.ok_or_else(|| sqlx::Error::RowNotFound)?; .ok_or_else(|| sqlx::Error::RowNotFound)?;
// Parse columns to find boolean types // Parse column definitions to identify boolean columns for conversion
if let Ok(columns) = serde_json::from_value::<Vec<String>>(table_def.columns) { if let Ok(columns) = serde_json::from_value::<Vec<String>>(table_def.columns) {
for column_def in columns { for column_def in columns {
let mut parts = column_def.split_whitespace(); let mut parts = column_def.split_whitespace();
@@ -340,7 +345,6 @@ pub async fn convert_row_data_for_steel(
let column_name = name.trim_matches('"'); let column_name = name.trim_matches('"');
let normalized_type = normalize_data_type(data_type); let normalized_type = normalize_data_type(data_type);
// Fixed: Use traditional if let instead of let chains
if normalized_type == "BOOLEAN" || normalized_type == "BOOL" { if normalized_type == "BOOLEAN" || normalized_type == "BOOL" {
if let Some(value) = row_data.get_mut(column_name) { if let Some(value) = row_data.get_mut(column_name) {
// Convert boolean value to Steel format // Convert boolean value to Steel format

View File

@@ -5,6 +5,7 @@ use tonic::Status;
use sqlx::PgPool; use sqlx::PgPool;
use serde_json::{json, Value}; use serde_json::{json, Value};
/// Represents the state of a node during dependency graph traversal.
#[derive(Clone, Copy, PartialEq)] #[derive(Clone, Copy, PartialEq)]
enum NodeState { enum NodeState {
Unvisited, Unvisited,
@@ -12,6 +13,7 @@ enum NodeState {
Visited, // Completely processed Visited, // Completely processed
} }
/// Represents a dependency relationship between tables in Steel scripts.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Dependency { pub struct Dependency {
pub target_table: String, pub target_table: String,
@@ -19,14 +21,19 @@ pub struct Dependency {
pub context: Option<Value>, pub context: Option<Value>,
} }
/// Types of dependencies that can exist between tables in Steel scripts.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum DependencyType { pub enum DependencyType {
/// Direct column access via steel_get_column
ColumnAccess { column: String }, ColumnAccess { column: String },
/// Indexed column access via steel_get_column_with_index
IndexedAccess { column: String, index: i64 }, IndexedAccess { column: String, index: i64 },
/// Raw SQL query access via steel_query_sql
SqlQuery { query_fragment: String }, SqlQuery { query_fragment: String },
} }
impl DependencyType { impl DependencyType {
/// Returns the string representation used in the database.
pub fn as_str(&self) -> &'static str { pub fn as_str(&self) -> &'static str {
match self { match self {
DependencyType::ColumnAccess { .. } => "column_access", DependencyType::ColumnAccess { .. } => "column_access",
@@ -35,6 +42,7 @@ impl DependencyType {
} }
} }
/// Generates context JSON for database storage.
pub fn context_json(&self) -> Value { pub fn context_json(&self) -> Value {
match self { match self {
DependencyType::ColumnAccess { column } => { DependencyType::ColumnAccess { column } => {
@@ -50,6 +58,7 @@ impl DependencyType {
} }
} }
/// Errors that can occur during dependency analysis.
#[derive(Debug)] #[derive(Debug)]
pub enum DependencyError { pub enum DependencyError {
CircularDependency { CircularDependency {
@@ -94,43 +103,53 @@ impl From<DependencyError> for Status {
} }
} }
/// Analyzes Steel scripts to extract table dependencies and validate referential integrity.
///
/// This analyzer identifies how tables reference each other through Steel function calls
/// and ensures that dependency graphs remain acyclic while respecting table link constraints.
pub struct DependencyAnalyzer { pub struct DependencyAnalyzer {
schema_id: i64, schema_id: i64,
pool: PgPool,
} }
impl DependencyAnalyzer { impl DependencyAnalyzer {
/// Creates a new dependency analyzer for the specified schema.
pub fn new(schema_id: i64, pool: PgPool) -> Self { pub fn new(schema_id: i64, pool: PgPool) -> Self {
Self { schema_id, pool } Self { schema_id }
} }
/// Analyzes a Steel script to extract all table dependencies /// Analyzes a Steel script to extract all table dependencies.
/// Uses regex patterns to find function calls that create dependencies ///
/// Uses regex patterns to identify function calls that create dependencies:
/// - `steel_get_column` calls for direct column access
/// - `steel_get_column_with_index` calls for indexed access
/// - `steel_query_sql` calls for raw SQL access
/// - Variable references like `@column_name`
/// - `get-var` calls in transformed scripts
/// ///
/// # Arguments /// # Arguments
/// * `script` - The Steel script to analyze /// * `script` - The Steel script code to analyze
/// * `current_table_name` - Name of the table this script belongs to (for self-references) /// * `current_table_name` - Name of the table this script belongs to (for self-references)
///
/// # Returns
/// * `Ok(Vec<Dependency>)` - List of identified dependencies
/// * `Err(DependencyError)` - If script parsing fails
pub fn analyze_script_dependencies(&self, script: &str, current_table_name: &str) -> Result<Vec<Dependency>, DependencyError> { pub fn analyze_script_dependencies(&self, script: &str, current_table_name: &str) -> Result<Vec<Dependency>, DependencyError> {
let mut dependencies = Vec::new(); let mut dependencies = Vec::new();
// Extract function calls and SQL dependencies using regex // Extract different types of dependencies using regex patterns
dependencies.extend(self.extract_function_calls(script)?); dependencies.extend(self.extract_function_calls(script)?);
dependencies.extend(self.extract_sql_dependencies(script)?); dependencies.extend(self.extract_sql_dependencies(script)?);
// Extract get-var calls (for transformed scripts with variables)
dependencies.extend(self.extract_get_var_calls(script, current_table_name)?); dependencies.extend(self.extract_get_var_calls(script, current_table_name)?);
// Extract direct variable references like @price, @quantity
dependencies.extend(self.extract_variable_references(script, current_table_name)?); dependencies.extend(self.extract_variable_references(script, current_table_name)?);
Ok(dependencies) Ok(dependencies)
} }
/// Extract function calls using regex patterns /// Extracts Steel function calls that create table dependencies.
fn extract_function_calls(&self, script: &str) -> Result<Vec<Dependency>, DependencyError> { fn extract_function_calls(&self, script: &str) -> Result<Vec<Dependency>, DependencyError> {
let mut dependencies = Vec::new(); let mut dependencies = Vec::new();
// Look for steel_get_column patterns // Pattern: (steel_get_column "table" "column")
let column_pattern = regex::Regex::new(r#"\(\s*steel_get_column\s+"([^"]+)"\s+"([^"]+)""#) let column_pattern = regex::Regex::new(r#"\(\s*steel_get_column\s+"([^"]+)"\s+"([^"]+)""#)
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?; .map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
@@ -144,7 +163,7 @@ impl DependencyAnalyzer {
}); });
} }
// Look for steel_get_column_with_index patterns // Pattern: (steel_get_column_with_index "table" index "column")
let indexed_pattern = regex::Regex::new(r#"\(\s*steel_get_column_with_index\s+"([^"]+)"\s+(\d+)\s+"([^"]+)""#) let indexed_pattern = regex::Regex::new(r#"\(\s*steel_get_column_with_index\s+"([^"]+)"\s+(\d+)\s+"([^"]+)""#)
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?; .map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
@@ -163,19 +182,19 @@ impl DependencyAnalyzer {
Ok(dependencies) Ok(dependencies)
} }
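A quick way to exercise the steel_get_column pattern above in isolation, assuming the regex crate is available as a dependency; the table and column names are illustrative.

fn main() {
    let pattern = regex::Regex::new(r#"\(\s*steel_get_column\s+"([^"]+)"\s+"([^"]+)""#).unwrap();
    let script = r#"(steel_get_column "customers" "name")"#;
    for caps in pattern.captures_iter(script) {
        // Prints: depends on table `customers`, column `name`
        println!("depends on table `{}`, column `{}`", &caps[1], &caps[2]);
    }
}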
/// Extract get-var calls as dependencies (for transformed scripts with variables) /// Extracts get-var calls as dependencies for transformed scripts.
/// These are self-references to the current table /// These represent self-references to the current table.
fn extract_get_var_calls(&self, script: &str, current_table_name: &str) -> Result<Vec<Dependency>, DependencyError> { fn extract_get_var_calls(&self, script: &str, current_table_name: &str) -> Result<Vec<Dependency>, DependencyError> {
let mut dependencies = Vec::new(); let mut dependencies = Vec::new();
// Look for get-var patterns in transformed scripts: (get-var "variable") // Pattern: (get-var "variable")
let get_var_pattern = regex::Regex::new(r#"\(get-var\s+"([^"]+)"\)"#) let get_var_pattern = regex::Regex::new(r#"\(get-var\s+"([^"]+)"\)"#)
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?; .map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
for caps in get_var_pattern.captures_iter(script) { for caps in get_var_pattern.captures_iter(script) {
let variable_name = caps[1].to_string(); let variable_name = caps[1].to_string();
dependencies.push(Dependency { dependencies.push(Dependency {
target_table: current_table_name.to_string(), // Use actual table name target_table: current_table_name.to_string(),
dependency_type: DependencyType::ColumnAccess { column: variable_name }, dependency_type: DependencyType::ColumnAccess { column: variable_name },
context: None, context: None,
}); });
@@ -184,19 +203,19 @@ impl DependencyAnalyzer {
Ok(dependencies) Ok(dependencies)
} }
/// Extract direct variable references like @price, @quantity /// Extracts direct variable references like @price, @quantity.
/// These are self-references to the current table /// These represent self-references to the current table.
fn extract_variable_references(&self, script: &str, current_table_name: &str) -> Result<Vec<Dependency>, DependencyError> { fn extract_variable_references(&self, script: &str, current_table_name: &str) -> Result<Vec<Dependency>, DependencyError> {
let mut dependencies = Vec::new(); let mut dependencies = Vec::new();
// Look for @variable patterns: @price, @quantity, etc. // Pattern: @variable_name
let variable_pattern = regex::Regex::new(r#"@([a-zA-Z_][a-zA-Z0-9_]*)"#) let variable_pattern = regex::Regex::new(r#"@([a-zA-Z_][a-zA-Z0-9_]*)"#)
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?; .map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
for caps in variable_pattern.captures_iter(script) { for caps in variable_pattern.captures_iter(script) {
let variable_name = caps[1].to_string(); let variable_name = caps[1].to_string();
dependencies.push(Dependency { dependencies.push(Dependency {
target_table: current_table_name.to_string(), // Use actual table name target_table: current_table_name.to_string(),
dependency_type: DependencyType::ColumnAccess { column: variable_name }, dependency_type: DependencyType::ColumnAccess { column: variable_name },
context: None, context: None,
}); });
@@ -205,11 +224,11 @@ impl DependencyAnalyzer {
Ok(dependencies) Ok(dependencies)
} }
/// Extract table references from SQL queries in steel_query_sql calls /// Extracts table references from SQL queries in steel_query_sql calls.
fn extract_sql_dependencies(&self, script: &str) -> Result<Vec<Dependency>, DependencyError> { fn extract_sql_dependencies(&self, script: &str) -> Result<Vec<Dependency>, DependencyError> {
let mut dependencies = Vec::new(); let mut dependencies = Vec::new();
// Look for steel_query_sql calls and extract table names from the SQL // Pattern: (steel_query_sql "SELECT ... FROM table ...")
let sql_pattern = regex::Regex::new(r#"\(\s*steel_query_sql\s+"([^"]+)""#) let sql_pattern = regex::Regex::new(r#"\(\s*steel_query_sql\s+"([^"]+)""#)
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?; .map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
@@ -231,12 +250,12 @@ impl DependencyAnalyzer {
Ok(dependencies) Ok(dependencies)
} }
/// Extract table names from SQL query text /// Extracts table names from SQL query text using regex patterns.
/// Looks for FROM and JOIN clauses to identify table references.
fn extract_table_references_from_sql(&self, sql: &str) -> Result<Vec<String>, DependencyError> { fn extract_table_references_from_sql(&self, sql: &str) -> Result<Vec<String>, DependencyError> {
let mut tables = Vec::new(); let mut tables = Vec::new();
// Simple extraction - look for FROM and JOIN clauses // Pattern: FROM table_name or JOIN table_name
// This could be made more sophisticated with a proper SQL parser
let table_pattern = regex::Regex::new(r#"(?i)\b(?:FROM|JOIN)\s+(?:"([^"]+)"|(\w+))"#) let table_pattern = regex::Regex::new(r#"(?i)\b(?:FROM|JOIN)\s+(?:"([^"]+)"|(\w+))"#)
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?; .map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
@@ -253,18 +272,30 @@ impl DependencyAnalyzer {
Ok(tables) Ok(tables)
} }
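A standalone check of the FROM/JOIN extraction regex above, again assuming the regex crate; the sample query and table names are illustrative.

fn main() {
    let pattern = regex::Regex::new(r#"(?i)\b(?:FROM|JOIN)\s+(?:"([^"]+)"|(\w+))"#).unwrap();
    let sql = r#"SELECT o.id FROM orders o JOIN "customers" c ON c.id = o.customer_id"#;
    let tables: Vec<&str> = pattern
        .captures_iter(sql)
        .filter_map(|caps| caps.get(1).or_else(|| caps.get(2)).map(|m| m.as_str()))
        .collect();
    assert_eq!(tables, vec!["orders", "customers"]);
}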
/// Check for cycles in the dependency graph using proper DFS /// Checks for circular dependencies in the dependency graph.
/// Self-references are allowed and filtered out from cycle detection ///
/// This function validates that adding new dependencies won't create cycles
/// that could lead to infinite loops during script execution. Self-references
/// are explicitly allowed and filtered out from cycle detection.
///
/// # Arguments
/// * `tx` - Database transaction for querying existing dependencies
/// * `table_id` - ID of the table adding new dependencies
/// * `new_dependencies` - Dependencies to be added
///
/// # Returns
/// * `Ok(())` - No cycles detected, safe to add dependencies
/// * `Err(DependencyError)` - Cycle detected or validation failed
pub async fn check_for_cycles( pub async fn check_for_cycles(
&self, &self,
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>, tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
table_id: i64, table_id: i64,
new_dependencies: &[Dependency], new_dependencies: &[Dependency],
) -> Result<(), DependencyError> { ) -> Result<(), DependencyError> {
// FIRST: Validate that structured table access respects link constraints // Validate that structured table access respects link constraints
self.validate_link_constraints(tx, table_id, new_dependencies).await?; self.validate_link_constraints(tx, table_id, new_dependencies).await?;
// Get current dependency graph for this schema // Build current dependency graph excluding self-references
let current_deps = sqlx::query!( let current_deps = sqlx::query!(
r#"SELECT sd.source_table_id, sd.target_table_id, st.table_name as source_name, tt.table_name as target_name r#"SELECT sd.source_table_id, sd.target_table_id, st.table_name as source_name, tt.table_name as target_name
FROM script_dependencies sd FROM script_dependencies sd
@@ -277,12 +308,11 @@ impl DependencyAnalyzer {
.await .await
.map_err(|e| DependencyError::DatabaseError { error: e.to_string() })?; .map_err(|e| DependencyError::DatabaseError { error: e.to_string() })?;
// Build adjacency list - EXCLUDE self-references since they're always allowed
let mut graph: HashMap<i64, Vec<i64>> = HashMap::new(); let mut graph: HashMap<i64, Vec<i64>> = HashMap::new();
let mut table_names: HashMap<i64, String> = HashMap::new(); let mut table_names: HashMap<i64, String> = HashMap::new();
// Build adjacency list excluding self-references
for dep in current_deps { for dep in current_deps {
// Skip self-references in cycle detection
if dep.source_table_id != dep.target_table_id { if dep.source_table_id != dep.target_table_id {
graph.entry(dep.source_table_id).or_default().push(dep.target_table_id); graph.entry(dep.source_table_id).or_default().push(dep.target_table_id);
} }
@@ -290,9 +320,8 @@ impl DependencyAnalyzer {
table_names.insert(dep.target_table_id, dep.target_name); table_names.insert(dep.target_table_id, dep.target_name);
} }
// Add new dependencies to test - EXCLUDE self-references // Add new dependencies to test (excluding self-references)
for dep in new_dependencies { for dep in new_dependencies {
// Look up target table ID using the actual table name
let target_id = sqlx::query_scalar!( let target_id = sqlx::query_scalar!(
"SELECT id FROM table_definitions WHERE schema_id = $1 AND table_name = $2", "SELECT id FROM table_definitions WHERE schema_id = $1 AND table_name = $2",
self.schema_id, self.schema_id,
@@ -306,12 +335,12 @@ impl DependencyAnalyzer {
script_context: format!("table_id_{}", table_id), script_context: format!("table_id_{}", table_id),
})?; })?;
// Only add to cycle detection graph if it's NOT a self-reference // Only add to cycle detection if not a self-reference
if table_id != target_id { if table_id != target_id {
graph.entry(table_id).or_default().push(target_id); graph.entry(table_id).or_default().push(target_id);
} }
// Get table name for error reporting // Ensure table names are available for error reporting
if !table_names.contains_key(&table_id) { if !table_names.contains_key(&table_id) {
let source_name = sqlx::query_scalar!( let source_name = sqlx::query_scalar!(
"SELECT table_name FROM table_definitions WHERE id = $1", "SELECT table_name FROM table_definitions WHERE id = $1",
@@ -325,12 +354,13 @@ impl DependencyAnalyzer {
} }
} }
// Detect cycles using proper DFS algorithm (now without self-references) // Detect cycles using DFS algorithm
self.detect_cycles_dfs(&graph, &table_names, table_id)?; self.detect_cycles_dfs(&graph, &table_names, table_id)?;
Ok(()) Ok(())
} }
/// Performs depth-first search to detect cycles in the dependency graph.
fn dfs_visit( fn dfs_visit(
&self, &self,
node: i64, node: i64,
@@ -345,21 +375,18 @@ impl DependencyAnalyzer {
if let Some(neighbors) = graph.get(&node) { if let Some(neighbors) = graph.get(&node) {
for &neighbor in neighbors { for &neighbor in neighbors {
// Ensure neighbor is in states map
if !states.contains_key(&neighbor) { if !states.contains_key(&neighbor) {
states.insert(neighbor, NodeState::Unvisited); states.insert(neighbor, NodeState::Unvisited);
} }
match states.get(&neighbor).copied().unwrap_or(NodeState::Unvisited) { match states.get(&neighbor).copied().unwrap_or(NodeState::Unvisited) {
NodeState::Visiting => { NodeState::Visiting => {
// Check if this is a self-reference (allowed) or a real cycle (not allowed) // Skip self-references as they're allowed
if neighbor == node { if neighbor == node {
// Self-reference: A table referencing itself is allowed
// Skip this - it's not a harmful cycle
continue; continue;
} }
// Found a real cycle! Build the cycle path // Found a cycle - build the cycle path
let cycle_start_idx = path.iter().position(|&x| x == neighbor).unwrap_or(0); let cycle_start_idx = path.iter().position(|&x| x == neighbor).unwrap_or(0);
let cycle_path: Vec<String> = path[cycle_start_idx..] let cycle_path: Vec<String> = path[cycle_start_idx..]
.iter() .iter()
@@ -367,7 +394,7 @@ impl DependencyAnalyzer {
.map(|&id| table_names.get(&id).cloned().unwrap_or_else(|| id.to_string())) .map(|&id| table_names.get(&id).cloned().unwrap_or_else(|| id.to_string()))
.collect(); .collect();
// Only report as error if the cycle involves more than one table // Only report as error if cycle involves multiple tables
if cycle_path.len() > 2 || (cycle_path.len() == 2 && cycle_path[0] != cycle_path[1]) { if cycle_path.len() > 2 || (cycle_path.len() == 2 && cycle_path[0] != cycle_path[1]) {
let involving_script = table_names.get(&starting_table) let involving_script = table_names.get(&starting_table)
.cloned() .cloned()
@@ -380,7 +407,6 @@ impl DependencyAnalyzer {
} }
} }
NodeState::Unvisited => { NodeState::Unvisited => {
// Recursively visit unvisited neighbor
self.dfs_visit(neighbor, states, graph, path, table_names, starting_table)?; self.dfs_visit(neighbor, states, graph, path, table_names, starting_table)?;
} }
NodeState::Visited => { NodeState::Visited => {
@@ -395,23 +421,23 @@ impl DependencyAnalyzer {
Ok(()) Ok(())
} }
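A reduced sketch of the cycle check implemented by dfs_visit: self-edges are skipped because self-references are allowed, and any other edge back to a node still on the DFS stack is reported as a cycle. The has_cycle helper below is illustrative and simplifies the three-state tracking to two sets.

use std::collections::{HashMap, HashSet};

fn has_cycle(graph: &HashMap<i64, Vec<i64>>) -> bool {
    fn visit(node: i64, graph: &HashMap<i64, Vec<i64>>, visiting: &mut HashSet<i64>, done: &mut HashSet<i64>) -> bool {
        if done.contains(&node) { return false; }
        visiting.insert(node);
        for &next in graph.get(&node).into_iter().flatten() {
            if next == node { continue; }                // self-reference, allowed
            if visiting.contains(&next) { return true; } // back edge: real cycle
            if visit(next, graph, visiting, done) { return true; }
        }
        visiting.remove(&node);
        done.insert(node);
        false
    }
    let mut visiting = HashSet::new();
    let mut done = HashSet::new();
    graph.keys().any(|&n| visit(n, graph, &mut visiting, &mut done))
}

fn main() {
    let mut graph = HashMap::new();
    graph.insert(1, vec![1, 2]); // 1 -> 1 (self-reference) and 1 -> 2
    graph.insert(2, vec![1]);    // 2 -> 1 closes a real cycle
    assert!(has_cycle(&graph));
}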
/// Validates that structured table access (steel_get_column functions) respects link constraints /// Validates that structured table access respects table link constraints.
/// Raw SQL access (steel_query_sql) is allowed to reference any table
/// SELF-REFERENCES are always allowed (table can access its own columns)
/// ///
/// Example: /// # Access Rules
/// - Table A can ALWAYS use: (steel_get_column "table_a" "column_name") ✅ (self-reference) /// - Self-references are always allowed (table can access its own columns)
/// - Table A is linked to Table B via table_definition_links /// - Structured access (steel_get_column functions) requires explicit table links
/// - Script for Table A can use: (steel_get_column "table_b" "column_name") ✅ /// - Raw SQL access (steel_query_sql) is unrestricted
/// - Script for Table A CANNOT use: (steel_get_column "table_c" "column_name") ❌ ///
/// - Script for Table A CAN use: (steel_query_sql "SELECT * FROM table_c") ✅ /// # Arguments
/// * `tx` - Database transaction for querying table links
/// * `source_table_id` - ID of the table with the script
/// * `dependencies` - Dependencies to validate
async fn validate_link_constraints( async fn validate_link_constraints(
&self, &self,
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>, tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
source_table_id: i64, source_table_id: i64,
dependencies: &[Dependency], dependencies: &[Dependency],
) -> Result<(), DependencyError> { ) -> Result<(), DependencyError> {
// Get the current table name for self-reference checking
let current_table_name = sqlx::query_scalar!( let current_table_name = sqlx::query_scalar!(
"SELECT table_name FROM table_definitions WHERE id = $1", "SELECT table_name FROM table_definitions WHERE id = $1",
source_table_id source_table_id
@@ -432,25 +458,24 @@ impl DependencyAnalyzer {
.await .await
.map_err(|e| DependencyError::DatabaseError { error: e.to_string() })?; .map_err(|e| DependencyError::DatabaseError { error: e.to_string() })?;
// Create a set of allowed table names for quick lookup
let mut allowed_tables: std::collections::HashSet<String> = linked_tables let mut allowed_tables: std::collections::HashSet<String> = linked_tables
.into_iter() .into_iter()
.map(|row| row.table_name) .map(|row| row.table_name)
.collect(); .collect();
// ALWAYS allow self-references // Self-references are always allowed
allowed_tables.insert(current_table_name.clone()); allowed_tables.insert(current_table_name.clone());
// Validate each dependency // Validate each dependency
for dep in dependencies { for dep in dependencies {
match &dep.dependency_type { match &dep.dependency_type {
// Structured access must respect link constraints (but self-references are always allowed)
DependencyType::ColumnAccess { column } | DependencyType::IndexedAccess { column, .. } => { DependencyType::ColumnAccess { column } | DependencyType::IndexedAccess { column, .. } => {
// Self-references are always allowed (compare table names directly) // Allow self-references
if dep.target_table == current_table_name { if dep.target_table == current_table_name {
continue; continue;
} }
// Check if table is linked
if !allowed_tables.contains(&dep.target_table) { if !allowed_tables.contains(&dep.target_table) {
return Err(DependencyError::InvalidTableReference { return Err(DependencyError::InvalidTableReference {
table_name: dep.target_table.clone(), table_name: dep.target_table.clone(),
@@ -464,9 +489,8 @@ impl DependencyAnalyzer {
}); });
} }
} }
// Raw SQL access is unrestricted
DependencyType::SqlQuery { .. } => { DependencyType::SqlQuery { .. } => {
// No validation - raw SQL can access any table // Raw SQL access is unrestricted
} }
} }
} }
@@ -474,7 +498,7 @@ impl DependencyAnalyzer {
Ok(()) Ok(())
} }
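The access rules above, restated with the concrete table names used in the doc comment this commit replaces (table_a, table_b, table_c are illustrative); structured_access_allowed is a stand-in for the link check, not the actual implementation.

use std::collections::HashSet;

// table_a may always read itself, may read linked table_b via structured access,
// and may not read unlinked table_c via structured access; raw SQL is unrestricted.
fn structured_access_allowed(current: &str, linked: &HashSet<&str>, target: &str) -> bool {
    target == current || linked.contains(target)
}

fn main() {
    let linked: HashSet<&str> = ["table_b"].into_iter().collect();
    assert!(structured_access_allowed("table_a", &linked, "table_a"));  // self-reference
    assert!(structured_access_allowed("table_a", &linked, "table_b"));  // linked table
    assert!(!structured_access_allowed("table_a", &linked, "table_c")); // not linked
    // (steel_query_sql "SELECT * FROM table_c") would still be permitted.
}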
/// Proper DFS-based cycle detection with state tracking /// Runs DFS-based cycle detection on the dependency graph.
fn detect_cycles_dfs( fn detect_cycles_dfs(
&self, &self,
graph: &HashMap<i64, Vec<i64>>, graph: &HashMap<i64, Vec<i64>>,
@@ -508,7 +532,16 @@ impl DependencyAnalyzer {
Ok(()) Ok(())
} }
/// Save dependencies to database within an existing transaction /// Saves dependencies to the database within an existing transaction.
///
/// This function replaces all existing dependencies for a script with the new set,
/// ensuring the database reflects the current script analysis results.
///
/// # Arguments
/// * `tx` - Database transaction for atomic updates
/// * `script_id` - ID of the script these dependencies belong to
/// * `table_id` - ID of the table containing the script
/// * `dependencies` - Dependencies to save
pub async fn save_dependencies( pub async fn save_dependencies(
&self, &self,
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>, tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
@@ -524,7 +557,6 @@ impl DependencyAnalyzer {
// Insert new dependencies // Insert new dependencies
for dep in dependencies { for dep in dependencies {
// Look up target table ID using actual table name (no magic strings!)
let target_id = sqlx::query_scalar!( let target_id = sqlx::query_scalar!(
"SELECT id FROM table_definitions WHERE schema_id = $1 AND table_name = $2", "SELECT id FROM table_definitions WHERE schema_id = $1 AND table_name = $2",
self.schema_id, self.schema_id,

View File

@@ -12,6 +12,7 @@ use rust_decimal::Decimal;
use std::str::FromStr; use std::str::FromStr;
use crate::steel::server::execution::{self, Value}; use crate::steel::server::execution::{self, Value};
use crate::indexer::{IndexCommand, IndexCommandData}; use crate::indexer::{IndexCommand, IndexCommandData};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tracing::error; use tracing::error;
@@ -153,9 +154,7 @@ pub async fn post_table_data(
format!("Script execution failed for '{}': {}", target_column, e) format!("Script execution failed for '{}': {}", target_column, e)
))?; ))?;
let Value::Strings(mut script_output) = script_result else { let Value::Strings(mut script_output) = script_result;
return Err(Status::internal("Script must return string values"));
};
let expected_value = script_output.pop() let expected_value = script_output.pop()
.ok_or_else(|| Status::internal("Script returned no values"))?; .ok_or_else(|| Status::internal("Script returned no values"))?;
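Dropping the `let ... else` fallback compiles here because this commit also removes the Numbers and Mixed variants, leaving Value with only the Strings variant, which makes the pattern irrefutable; a minimal sketch of that shape:

#[derive(Debug)]
enum Value {
    Strings(Vec<String>),
}

fn main() {
    let script_result = Value::Strings(vec!["42".to_string()]);
    // With a single-variant enum this pattern is irrefutable, so no `else` arm is needed.
    let Value::Strings(mut script_output) = script_result;
    let expected_value = script_output.pop().expect("script returned no values");
    println!("{}", expected_value);
}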

View File

@@ -216,9 +216,8 @@ pub async fn put_table_data(
Status::invalid_argument(format!("Script execution failed for '{}': {}", target_column, e)) Status::invalid_argument(format!("Script execution failed for '{}': {}", target_column, e))
})?; })?;
let Value::Strings(mut script_output_vec) = script_result else { let Value::Strings(mut script_output_vec) = script_result;
return Err(Status::internal("Script must return string values"));
};
let script_output = script_output_vec.pop().ok_or_else(|| Status::internal("Script returned no values"))?; let script_output = script_output_vec.pop().ok_or_else(|| Status::internal("Script returned no values"))?;
if update_data.contains_key(&target_column) { if update_data.contains_key(&target_column) {