Fix warnings and make the Steel server code production-ready

filipriec
2025-07-24 23:57:21 +02:00
parent c82813185f
commit c58ce52b33
5 changed files with 173 additions and 188 deletions

View File

@@ -1,4 +1,5 @@
// src/steel/server/execution.rs
use steel::steel_vm::engine::Engine;
use steel::steel_vm::register_fn::RegisterFn;
use steel::rvals::SteelVal;
@@ -10,13 +11,13 @@ use std::collections::HashMap;
use thiserror::Error;
use tracing::{debug, error};
/// Represents different types of values that can be returned from Steel script execution.
#[derive(Debug)]
pub enum Value {
Strings(Vec<String>),
Numbers(Vec<i64>),
Mixed(Vec<SteelVal>),
}
/// Errors that can occur during Steel script execution.
#[derive(Debug, Error)]
pub enum ExecutionError {
#[error("Script execution failed: {0}")]
@@ -27,7 +28,7 @@ pub enum ExecutionError {
UnsupportedType(String),
}
/// Create a SteelContext with boolean conversion applied to row data
/// Creates a Steel execution context with proper boolean value conversion.
pub async fn create_steel_context_with_boolean_conversion(
current_table: String,
schema_id: i64,
@@ -35,10 +36,6 @@ pub async fn create_steel_context_with_boolean_conversion(
mut row_data: HashMap<String, String>,
db_pool: Arc<PgPool>,
) -> Result<SteelContext, ExecutionError> {
println!("=== CREATING STEEL CONTEXT ===");
println!("Table: {}, Schema: {}", current_table, schema_name);
println!("Row data BEFORE boolean conversion: {:?}", row_data);
// Convert boolean values in row_data to Steel format
convert_row_data_for_steel(&db_pool, schema_id, &current_table, &mut row_data)
.await
@@ -47,9 +44,6 @@ pub async fn create_steel_context_with_boolean_conversion(
ExecutionError::RuntimeError(format!("Failed to convert row data: {}", e))
})?;
println!("Row data AFTER boolean conversion: {:?}", row_data);
println!("=== END CREATING STEEL CONTEXT ===");
Ok(SteelContext {
current_table,
schema_id,
@@ -59,7 +53,7 @@ pub async fn create_steel_context_with_boolean_conversion(
})
}
/// Execute script with proper boolean handling
/// Executes a Steel script with database context and type-safe result processing.
pub async fn execute_script(
script: String,
target_type: &str,
@@ -69,20 +63,9 @@ pub async fn execute_script(
current_table: String,
row_data: HashMap<String, String>,
) -> Result<Value, ExecutionError> {
eprintln!("🚨🚨🚨 EXECUTE_SCRIPT CALLED 🚨🚨🚨");
eprintln!("Script: '{}'", script);
eprintln!("Table: {}, Target type: {}", current_table, target_type);
eprintln!("Input row_data: {:?}", row_data);
eprintln!("🚨🚨🚨 END EXECUTE_SCRIPT ENTRY 🚨🚨🚨");
println!("=== STEEL SCRIPT EXECUTION START ===");
println!("Script: '{}'", script);
println!("Table: {}, Target type: {}", current_table, target_type);
println!("Input row_data: {:?}", row_data);
let mut vm = Engine::new();
// Create context with boolean conversion
// Create execution context with proper boolean value conversion
let context = create_steel_context_with_boolean_conversion(
current_table,
schema_id,
@@ -93,41 +76,32 @@ pub async fn execute_script(
let context = Arc::new(context);
// Register existing Steel functions
// Register database access functions
register_steel_functions(&mut vm, context.clone());
// Register all decimal math functions using the steel_decimal crate
// Register decimal math operations
register_decimal_math_functions(&mut vm);
// Register variables from the context with the Steel VM
// The row_data now contains Steel-formatted boolean values
println!("=== REGISTERING STEEL VARIABLES ===");
// Manual variable registration using Steel's define mechanism
// Register row data as variables in the Steel VM
// Both bare names and @-prefixed names are supported for flexibility
let mut define_script = String::new();
println!("Variables being registered with Steel VM:");
for (key, value) in &context.row_data {
println!(" STEEL[{}] = '{}'", key, value);
// Register both @ prefixed and bare variable names
define_script.push_str(&format!("(define {} \"{}\")\n", key, value));
define_script.push_str(&format!("(define @{} \"{}\")\n", key, value));
}
println!("Steel script to execute: {}", script);
println!("=== END REGISTERING STEEL VARIABLES ===");
// Execute variable definitions if any exist
if !define_script.is_empty() {
println!("Define script: {}", define_script);
vm.compile_and_run_raw_program(define_script)
.map_err(|e| ExecutionError::RuntimeError(format!("Failed to register variables: {}", e)))?;
println!("Variables defined successfully");
}
// Also try the original method as backup
// Also register variables using the decimal registry as backup method
FunctionRegistry::register_variables(&mut vm, context.row_data.clone());
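// A minimal, self-contained sketch of the define-script generation above; the
// helper name and the sample `price` entry are illustrative, not part of this commit.
fn build_define_script(row_data: &std::collections::HashMap<String, String>) -> String {
    let mut script = String::new();
    for (key, value) in row_data {
        // Each row value is exposed under both a bare and an @-prefixed name.
        script.push_str(&format!("(define {} \"{}\")\n", key, value));
        script.push_str(&format!("(define @{} \"{}\")\n", key, value));
    }
    script
}

#[test]
fn define_script_sketch() {
    let mut row = std::collections::HashMap::new();
    row.insert("price".to_string(), "10.50".to_string());
    let script = build_define_script(&row);
    assert!(script.contains("(define price \"10.50\")"));
    assert!(script.contains("(define @price \"10.50\")"));
}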
// Execute script and process results
println!("Compiling and running Steel script: {}", script);
// Execute the main script
let results = vm.compile_and_run_raw_program(script.clone())
.map_err(|e| {
error!("Steel script execution failed: {}", e);
@@ -136,27 +110,18 @@ pub async fn execute_script(
ExecutionError::RuntimeError(e.to_string())
})?;
println!("Script execution returned {} results", results.len());
for (i, result) in results.iter().enumerate() {
println!("Result[{}]: {:?}", i, result);
}
// Convert results to target type
// Convert results to the requested target type
match target_type {
"STRINGS" => {
let result = process_string_results(results);
println!("Final processed result: {:?}", result);
println!("=== STEEL SCRIPT EXECUTION END ===");
result
},
"STRINGS" => process_string_results(results),
_ => Err(ExecutionError::UnsupportedType(target_type.into()))
}
}
/// Registers Steel functions for database access within the VM context.
fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
debug!("Registering Steel functions with context");
// Register steel_get_column with row context
// Register column access function for current and related tables
vm.register_fn("steel_get_column", {
let ctx = context.clone();
move |table: String, column: String| {
@@ -169,7 +134,7 @@ fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
}
});
// Register steel_get_column_with_index
// Register indexed column access for comma-separated values
vm.register_fn("steel_get_column_with_index", {
let ctx = context.clone();
move |table: String, index: i64, column: String| {
@@ -182,7 +147,7 @@ fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
}
});
// SQL query registration
// Register safe SQL query execution
vm.register_fn("steel_query_sql", {
let ctx = context.clone();
move |query: String| {
@@ -196,45 +161,31 @@ fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
});
}
/// Registers decimal mathematics functions in the Steel VM.
fn register_decimal_math_functions(vm: &mut Engine) {
debug!("Registering decimal math functions");
// Use the steel_decimal crate's FunctionRegistry to register all functions
FunctionRegistry::register_all(vm);
}
/// Processes Steel script results into string format for consistent output.
fn process_string_results(results: Vec<SteelVal>) -> Result<Value, ExecutionError> {
let mut strings = Vec::new();
for result in results {
match result {
SteelVal::StringV(s) => {
let result_str = s.to_string();
println!("Processing string result: '{}'", result_str);
strings.push(result_str);
},
SteelVal::NumV(n) => {
let result_str = n.to_string();
println!("Processing number result: '{}'", result_str);
strings.push(result_str);
},
SteelVal::IntV(i) => {
let result_str = i.to_string();
println!("Processing integer result: '{}'", result_str);
strings.push(result_str);
},
SteelVal::BoolV(b) => {
let result_str = b.to_string();
println!("Processing boolean result: '{}'", result_str);
strings.push(result_str);
},
let result_str = match result {
SteelVal::StringV(s) => s.to_string(),
SteelVal::NumV(n) => n.to_string(),
SteelVal::IntV(i) => i.to_string(),
SteelVal::BoolV(b) => b.to_string(),
_ => {
error!("Unexpected result type: {:?}", result);
return Err(ExecutionError::TypeConversionError(
format!("Expected string-convertible type, got {:?}", result)
));
}
}
};
strings.push(result_str);
}
println!("Final processed strings: {:?}", strings);
Ok(Value::Strings(strings))
}
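// A dependency-free sketch of the result-processing pattern above: every value
// must convert to a string, otherwise the whole batch fails. The local enum
// stands in for SteelVal and is illustrative only.
#[derive(Debug)]
enum SketchVal { Str(String), Int(i64), Bool(bool), Other }

fn to_strings(results: Vec<SketchVal>) -> Result<Vec<String>, String> {
    results.into_iter().map(|value| match value {
        SketchVal::Str(s) => Ok(s),
        SketchVal::Int(i) => Ok(i.to_string()),
        SketchVal::Bool(b) => Ok(b.to_string()),
        other => Err(format!("Expected string-convertible type, got {:?}", other)),
    }).collect()
}

#[test]
fn process_results_sketch() {
    let ok = to_strings(vec![SketchVal::Str("a".into()), SketchVal::Int(3), SketchVal::Bool(true)]).unwrap();
    assert_eq!(ok, vec!["a".to_string(), "3".to_string(), "true".to_string()]);
    assert!(to_strings(vec![SketchVal::Other]).is_err());
}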

View File

@@ -1,4 +1,5 @@
// src/steel/server/functions.rs
use steel::rvals::SteelVal;
use sqlx::PgPool;
use std::collections::HashMap;
@@ -20,9 +21,10 @@ pub enum FunctionError {
ProhibitedTypeAccess(String),
}
// Define prohibited data types (boolean is explicitly allowed)
/// Data types that Steel scripts are prohibited from accessing for security reasons
const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"];
/// Execution context for Steel scripts with database access capabilities.
#[derive(Clone)]
pub struct SteelContext {
pub current_table: String,
@@ -33,6 +35,8 @@ pub struct SteelContext {
}
impl SteelContext {
/// Resolves a base table name to its full qualified name in the current schema.
/// Used for foreign key relationship traversal in Steel scripts.
pub async fn get_related_table_name(&self, base_name: &str) -> Result<String, FunctionError> {
let table_def = sqlx::query!(
r#"SELECT table_name FROM table_definitions
@@ -48,7 +52,8 @@ impl SteelContext {
Ok(table_def.table_name)
}
/// Get column type for a given table and column
/// Retrieves the SQL data type for a specific column in a table.
/// Parses the JSON column definitions to find type information.
async fn get_column_type(&self, table_name: &str, column_name: &str) -> Result<String, FunctionError> {
let table_def = sqlx::query!(
r#"SELECT columns FROM table_definitions
@@ -61,11 +66,10 @@ impl SteelContext {
.map_err(|e| FunctionError::DatabaseError(e.to_string()))?
.ok_or_else(|| FunctionError::TableNotFound(table_name.to_string()))?;
// Parse columns JSON to find the column type
let columns: Vec<String> = serde_json::from_value(table_def.columns)
.map_err(|e| FunctionError::DatabaseError(format!("Invalid column data: {}", e)))?;
// Find the column and its type
// Parse column definitions to find the requested column type
for column_def in columns {
let mut parts = column_def.split_whitespace();
if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) {
@@ -83,31 +87,30 @@ impl SteelContext {
)))
}
/// Convert database value to Steel format based on column type
/// Converts database values to Steel script format based on column type.
/// Currently handles boolean conversion to Steel's #true/#false syntax.
fn convert_value_to_steel_format(&self, value: &str, column_type: &str) -> String {
let normalized_type = normalize_data_type(column_type);
match normalized_type.as_str() {
"BOOLEAN" | "BOOL" => {
// Convert database boolean to Steel boolean syntax
// Convert database boolean representations to Steel boolean syntax
match value.to_lowercase().as_str() {
"true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
"false" | "f" | "0" | "no" | "off" => "#false".to_string(),
_ => value.to_string(), // Return as-is if not a recognized boolean
}
}
"INTEGER" => {
value.to_string()
}
_ => value.to_string(), // Return as-is for non-boolean types
"INTEGER" => value.to_string(),
_ => value.to_string(), // Return as-is for other types
}
}
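// A standalone sketch of the boolean mapping above, assuming the same accepted
// spellings; `to_steel_bool` is an illustrative helper, not part of this commit.
fn to_steel_bool(value: &str) -> String {
    match value.to_lowercase().as_str() {
        "true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
        "false" | "f" | "0" | "no" | "off" => "#false".to_string(),
        other => other.to_string(), // unrecognized values pass through unchanged
    }
}

#[test]
fn steel_bool_sketch() {
    assert_eq!(to_steel_bool("t"), "#true");
    assert_eq!(to_steel_bool("OFF"), "#false");
    assert_eq!(to_steel_bool("maybe"), "maybe");
}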
/// Validate column type and return the column type if valid
/// Validates that a column type is allowed for Steel script access.
/// Returns the column type if validation passes, or an error if the type is prohibited.
async fn validate_column_type_and_get_type(&self, table_name: &str, column_name: &str) -> Result<String, FunctionError> {
let column_type = self.get_column_type(table_name, column_name).await?;
// Check if this type is prohibited
if is_prohibited_type(&column_type) {
return Err(FunctionError::ProhibitedTypeAccess(format!(
"Cannot access column '{}' in table '{}' because it has prohibited type '{}'. Steel scripts cannot access columns of type: {}",
@@ -121,15 +124,15 @@ impl SteelContext {
Ok(column_type)
}
/// Validate column type before access (legacy method for backward compatibility)
async fn validate_column_type(&self, table_name: &str, column_name: &str) -> Result<(), FunctionError> {
self.validate_column_type_and_get_type(table_name, column_name).await?;
Ok(())
}
/// Retrieves column value from current table or related tables via foreign keys.
///
/// # Behavior
/// - Current table: Returns value directly from row_data with type conversion
/// - Related table: Follows foreign key relationship and queries database
/// - All accesses are subject to prohibited type validation
pub fn steel_get_column(&self, table: &str, column: &str) -> Result<SteelVal, SteelVal> {
if table == self.current_table {
// Validate column type for current table access and get the type
// Access current table data with type validation
let column_type = tokio::task::block_in_place(|| {
let handle = tokio::runtime::Handle::current();
handle.block_on(async {
@@ -150,6 +153,7 @@ impl SteelContext {
.ok_or_else(|| SteelVal::StringV(format!("Column {} not found", column).into()));
}
// Access related table via foreign key relationship
let base_name = table.split_once('_')
.map(|(_, rest)| rest)
.unwrap_or(table);
@@ -158,18 +162,17 @@ impl SteelContext {
let fk_value = self.row_data.get(&fk_column)
.ok_or_else(|| SteelVal::StringV(format!("Foreign key {} not found", fk_column).into()))?;
// Use `tokio::task::block_in_place` to safely block the thread
let result = tokio::task::block_in_place(|| {
let handle = tokio::runtime::Handle::current();
handle.block_on(async {
let actual_table = self.get_related_table_name(base_name).await
.map_err(|e| SteelVal::StringV(e.to_string().into()))?;
// Get column type for validation and conversion
// Validate column type and get type information
let column_type = self.validate_column_type_and_get_type(&actual_table, column).await
.map_err(|e| SteelVal::StringV(e.to_string().into()))?;
// Fetch the raw value from database
// Query the related table for the column value
let raw_value = sqlx::query_scalar::<_, String>(
&format!("SELECT {} FROM \"{}\".\"{}\" WHERE id = $1", column, self.schema_name, actual_table)
)
@@ -179,7 +182,7 @@ impl SteelContext {
.await
.map_err(|e| SteelVal::StringV(e.to_string().into()))?;
// Convert to Steel format
// Convert to appropriate Steel format
let converted_value = self.convert_value_to_steel_format(&raw_value, &column_type);
Ok(converted_value)
})
@@ -188,13 +191,15 @@ impl SteelContext {
result.map(|v| SteelVal::StringV(v.into()))
}
/// Retrieves a specific indexed element from a comma-separated column value.
/// Useful for accessing elements from array-like string representations.
pub fn steel_get_column_with_index(
&self,
table: &str,
index: i64,
column: &str
) -> Result<SteelVal, SteelVal> {
// Get the full value first (this already handles type conversion)
// Get the full value with proper type conversion
let value = self.steel_get_column(table, column)?;
if let SteelVal::StringV(s) = value {
@@ -203,8 +208,7 @@ impl SteelContext {
if let Some(part) = parts.get(index as usize) {
let trimmed_part = part.trim();
// If the original column was boolean type, each part should also be treated as boolean
// We need to get the column type to determine if conversion is needed
// Apply type conversion to the indexed part based on original column type
let column_type = tokio::task::block_in_place(|| {
let handle = tokio::runtime::Handle::current();
handle.block_on(async {
@@ -218,7 +222,7 @@ impl SteelContext {
Ok(SteelVal::StringV(converted_part.into()))
}
Err(_) => {
// If we can't get the type, return as-is
// If type cannot be determined, return value as-is
Ok(SteelVal::StringV(trimmed_part.into()))
}
}
@@ -230,15 +234,19 @@ impl SteelContext {
}
}
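// A minimal sketch of the comma-separated indexing performed above; the helper
// name and sample value are illustrative only.
fn nth_csv_part(value: &str, index: usize) -> Option<String> {
    value.split(',').nth(index).map(|part| part.trim().to_string())
}

#[test]
fn indexed_access_sketch() {
    assert_eq!(nth_csv_part("red, green, blue", 1).as_deref(), Some("green"));
    assert_eq!(nth_csv_part("red", 3), None);
}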
/// Executes read-only SQL queries from Steel scripts with safety restrictions.
///
/// # Security Features
/// - Only SELECT, SHOW, and EXPLAIN queries allowed
/// - Prohibited column type access validation
/// - Returns first column of all rows as comma-separated string
pub fn steel_query_sql(&self, query: &str) -> Result<SteelVal, SteelVal> {
// Validate query is read-only
if !is_read_only_query(query) {
return Err(SteelVal::StringV(
"Only SELECT queries are allowed".into()
));
}
// Check if query might access prohibited columns
if contains_prohibited_column_access(query) {
return Err(SteelVal::StringV(format!(
"SQL query may access prohibited column types. Steel scripts cannot access columns of type: {}",
@@ -248,11 +256,9 @@ impl SteelContext {
let pool = self.db_pool.clone();
// Use `tokio::task::block_in_place` to safely block the thread
let result = tokio::task::block_in_place(|| {
let handle = tokio::runtime::Handle::current();
handle.block_on(async {
// Execute and get first column of all rows as strings
let rows = sqlx::query(query)
.fetch_all(&*pool)
.await
@@ -273,13 +279,14 @@ impl SteelContext {
}
}
/// Check if a data type is prohibited for Steel scripts
/// Checks if a data type is prohibited for Steel script access.
fn is_prohibited_type(data_type: &str) -> bool {
let normalized_type = normalize_data_type(data_type);
PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited))
}
/// Normalize data type for comparison (handle NUMERIC variations, etc.)
/// Normalizes data type strings for consistent comparison.
/// Handles variations like NUMERIC(10,2) by extracting base type.
fn normalize_data_type(data_type: &str) -> String {
data_type.to_uppercase()
.split('(') // Remove precision/scale from NUMERIC(x,y)
@@ -289,13 +296,11 @@ fn normalize_data_type(data_type: &str) -> String {
.to_string()
}
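// A standalone sketch of the normalization above: precision/scale suffixes are
// stripped and the base type is upper-cased. Helper name is illustrative.
fn normalize_sketch(data_type: &str) -> String {
    data_type.to_uppercase()
        .split('(')
        .next()
        .unwrap_or("")
        .trim()
        .to_string()
}

#[test]
fn normalize_data_type_sketch() {
    assert_eq!(normalize_sketch("numeric(10,2)"), "NUMERIC");
    assert_eq!(normalize_sketch(" boolean "), "BOOLEAN");
}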
/// Basic check for prohibited column access in SQL queries
/// This is a simple heuristic - more sophisticated parsing could be added
/// Performs basic heuristic check for prohibited column type access in SQL queries.
/// Looks for common patterns that might indicate access to restricted types.
fn contains_prohibited_column_access(query: &str) -> bool {
let query_upper = query.to_uppercase();
// Look for common patterns that might indicate prohibited type access
// This is a basic implementation - you might want to enhance this
let patterns = [
"EXTRACT(", // Common with DATE/TIMESTAMPTZ
"DATE_PART(", // Common with DATE/TIMESTAMPTZ
@@ -307,6 +312,7 @@ fn contains_prohibited_column_access(query: &str) -> bool {
patterns.iter().any(|pattern| query_upper.contains(pattern))
}
/// Validates that a query is read-only and safe for Steel script execution.
fn is_read_only_query(query: &str) -> bool {
let query = query.trim_start().to_uppercase();
query.starts_with("SELECT") ||
@@ -314,14 +320,13 @@ fn is_read_only_query(query: &str) -> bool {
query.starts_with("EXPLAIN")
}
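// A compact sketch of the read-only check above (SELECT/SHOW/EXPLAIN allowed,
// per the security notes earlier); the helper name is illustrative.
fn is_read_only_sketch(query: &str) -> bool {
    let q = query.trim_start().to_uppercase();
    q.starts_with("SELECT") || q.starts_with("SHOW") || q.starts_with("EXPLAIN")
}

#[test]
fn read_only_query_sketch() {
    assert!(is_read_only_sketch("  select * from customers"));
    assert!(!is_read_only_sketch("DELETE FROM customers"));
}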
/// Helper function to convert initial row data for boolean columns
/// Converts row data boolean values to Steel script format during context initialization.
pub async fn convert_row_data_for_steel(
db_pool: &PgPool,
schema_id: i64,
table_name: &str,
row_data: &mut HashMap<String, String>,
) -> Result<(), sqlx::Error> {
// Get table definition to check column types
let table_def = sqlx::query!(
r#"SELECT columns FROM table_definitions
WHERE schema_id = $1 AND table_name = $2"#,
@@ -332,7 +337,7 @@ pub async fn convert_row_data_for_steel(
.await?
.ok_or_else(|| sqlx::Error::RowNotFound)?;
// Parse columns to find boolean types
// Parse column definitions to identify boolean columns for conversion
if let Ok(columns) = serde_json::from_value::<Vec<String>>(table_def.columns) {
for column_def in columns {
let mut parts = column_def.split_whitespace();
@@ -340,7 +345,6 @@ pub async fn convert_row_data_for_steel(
let column_name = name.trim_matches('"');
let normalized_type = normalize_data_type(data_type);
// Fixed: Use traditional if let instead of let chains
if normalized_type == "BOOLEAN" || normalized_type == "BOOL" {
if let Some(value) = row_data.get_mut(column_name) {
// Convert boolean value to Steel format

View File

@@ -5,6 +5,7 @@ use tonic::Status;
use sqlx::PgPool;
use serde_json::{json, Value};
/// Represents the state of a node during dependency graph traversal.
#[derive(Clone, Copy, PartialEq)]
enum NodeState {
Unvisited,
@@ -12,6 +13,7 @@ enum NodeState {
Visited, // Completely processed
}
/// Represents a dependency relationship between tables in Steel scripts.
#[derive(Debug, Clone)]
pub struct Dependency {
pub target_table: String,
@@ -19,14 +21,19 @@ pub struct Dependency {
pub context: Option<Value>,
}
/// Types of dependencies that can exist between tables in Steel scripts.
#[derive(Debug, Clone)]
pub enum DependencyType {
/// Direct column access via steel_get_column
ColumnAccess { column: String },
/// Indexed column access via steel_get_column_with_index
IndexedAccess { column: String, index: i64 },
/// Raw SQL query access via steel_query_sql
SqlQuery { query_fragment: String },
}
impl DependencyType {
/// Returns the string representation used in the database.
pub fn as_str(&self) -> &'static str {
match self {
DependencyType::ColumnAccess { .. } => "column_access",
@@ -35,6 +42,7 @@ impl DependencyType {
}
}
/// Generates context JSON for database storage.
pub fn context_json(&self) -> Value {
match self {
DependencyType::ColumnAccess { column } => {
@@ -50,6 +58,7 @@ impl DependencyType {
}
}
/// Errors that can occur during dependency analysis.
#[derive(Debug)]
pub enum DependencyError {
CircularDependency {
@@ -94,43 +103,53 @@ impl From<DependencyError> for Status {
}
}
/// Analyzes Steel scripts to extract table dependencies and validate referential integrity.
///
/// This analyzer identifies how tables reference each other through Steel function calls
/// and ensures that dependency graphs remain acyclic while respecting table link constraints.
pub struct DependencyAnalyzer {
schema_id: i64,
pool: PgPool,
}
impl DependencyAnalyzer {
/// Creates a new dependency analyzer for the specified schema.
pub fn new(schema_id: i64, pool: PgPool) -> Self {
Self { schema_id, pool }
Self { schema_id }
}
/// Analyzes a Steel script to extract all table dependencies
/// Uses regex patterns to find function calls that create dependencies
/// Analyzes a Steel script to extract all table dependencies.
///
/// Uses regex patterns to identify function calls that create dependencies:
/// - `steel_get_column` calls for direct column access
/// - `steel_get_column_with_index` calls for indexed access
/// - `steel_query_sql` calls for raw SQL access
/// - Variable references like `@column_name`
/// - `get-var` calls in transformed scripts
///
/// # Arguments
/// * `script` - The Steel script to analyze
/// * `script` - The Steel script code to analyze
/// * `current_table_name` - Name of the table this script belongs to (for self-references)
///
/// # Returns
/// * `Ok(Vec<Dependency>)` - List of identified dependencies
/// * `Err(DependencyError)` - If script parsing fails
pub fn analyze_script_dependencies(&self, script: &str, current_table_name: &str) -> Result<Vec<Dependency>, DependencyError> {
let mut dependencies = Vec::new();
// Extract function calls and SQL dependencies using regex
// Extract different types of dependencies using regex patterns
dependencies.extend(self.extract_function_calls(script)?);
dependencies.extend(self.extract_sql_dependencies(script)?);
// Extract get-var calls (for transformed scripts with variables)
dependencies.extend(self.extract_get_var_calls(script, current_table_name)?);
// Extract direct variable references like @price, @quantity
dependencies.extend(self.extract_variable_references(script, current_table_name)?);
Ok(dependencies)
}
/// Extract function calls using regex patterns
/// Extracts Steel function calls that create table dependencies.
fn extract_function_calls(&self, script: &str) -> Result<Vec<Dependency>, DependencyError> {
let mut dependencies = Vec::new();
// Look for steel_get_column patterns
// Pattern: (steel_get_column "table" "column")
let column_pattern = regex::Regex::new(r#"\(\s*steel_get_column\s+"([^"]+)"\s+"([^"]+)""#)
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
@@ -144,7 +163,7 @@ impl DependencyAnalyzer {
});
}
// Look for steel_get_column_with_index patterns
// Pattern: (steel_get_column_with_index "table" index "column")
let indexed_pattern = regex::Regex::new(r#"\(\s*steel_get_column_with_index\s+"([^"]+)"\s+(\d+)\s+"([^"]+)""#)
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
@@ -163,19 +182,19 @@ impl DependencyAnalyzer {
Ok(dependencies)
}
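// Sketch of the structured-call extraction above, reusing the same steel_get_column
// regex; the sample script is illustrative.
#[test]
fn column_access_extraction_sketch() {
    let script = r#"(steel_get_column "customers" "name")"#;
    let pattern = regex::Regex::new(r#"\(\s*steel_get_column\s+"([^"]+)"\s+"([^"]+)""#).unwrap();
    let caps = pattern.captures(script).expect("pattern should match");
    assert_eq!(&caps[1], "customers");
    assert_eq!(&caps[2], "name");
}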
/// Extract get-var calls as dependencies (for transformed scripts with variables)
/// These are self-references to the current table
/// Extracts get-var calls as dependencies for transformed scripts.
/// These represent self-references to the current table.
fn extract_get_var_calls(&self, script: &str, current_table_name: &str) -> Result<Vec<Dependency>, DependencyError> {
let mut dependencies = Vec::new();
// Look for get-var patterns in transformed scripts: (get-var "variable")
// Pattern: (get-var "variable")
let get_var_pattern = regex::Regex::new(r#"\(get-var\s+"([^"]+)"\)"#)
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
for caps in get_var_pattern.captures_iter(script) {
let variable_name = caps[1].to_string();
dependencies.push(Dependency {
target_table: current_table_name.to_string(), // Use actual table name
target_table: current_table_name.to_string(),
dependency_type: DependencyType::ColumnAccess { column: variable_name },
context: None,
});
@@ -184,19 +203,19 @@ impl DependencyAnalyzer {
Ok(dependencies)
}
/// Extract direct variable references like @price, @quantity
/// These are self-references to the current table
/// Extracts direct variable references like @price, @quantity.
/// These represent self-references to the current table.
fn extract_variable_references(&self, script: &str, current_table_name: &str) -> Result<Vec<Dependency>, DependencyError> {
let mut dependencies = Vec::new();
// Look for @variable patterns: @price, @quantity, etc.
// Pattern: @variable_name
let variable_pattern = regex::Regex::new(r#"@([a-zA-Z_][a-zA-Z0-9_]*)"#)
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
for caps in variable_pattern.captures_iter(script) {
let variable_name = caps[1].to_string();
dependencies.push(Dependency {
target_table: current_table_name.to_string(), // Use actual table name
target_table: current_table_name.to_string(),
dependency_type: DependencyType::ColumnAccess { column: variable_name },
context: None,
});
@@ -205,11 +224,11 @@ impl DependencyAnalyzer {
Ok(dependencies)
}
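// Sketch of the @variable extraction above, using the same regex; the sample
// script is illustrative.
#[test]
fn variable_reference_sketch() {
    let script = "(* @price @quantity)";
    let pattern = regex::Regex::new(r"@([a-zA-Z_][a-zA-Z0-9_]*)").unwrap();
    let vars: Vec<String> = pattern.captures_iter(script).map(|caps| caps[1].to_string()).collect();
    assert_eq!(vars, vec!["price".to_string(), "quantity".to_string()]);
}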
/// Extract table references from SQL queries in steel_query_sql calls
/// Extracts table references from SQL queries in steel_query_sql calls.
fn extract_sql_dependencies(&self, script: &str) -> Result<Vec<Dependency>, DependencyError> {
let mut dependencies = Vec::new();
// Look for steel_query_sql calls and extract table names from the SQL
// Pattern: (steel_query_sql "SELECT ... FROM table ...")
let sql_pattern = regex::Regex::new(r#"\(\s*steel_query_sql\s+"([^"]+)""#)
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
@@ -231,12 +250,12 @@ impl DependencyAnalyzer {
Ok(dependencies)
}
/// Extract table names from SQL query text
/// Extracts table names from SQL query text using regex patterns.
/// Looks for FROM and JOIN clauses to identify table references.
fn extract_table_references_from_sql(&self, sql: &str) -> Result<Vec<String>, DependencyError> {
let mut tables = Vec::new();
// Simple extraction - look for FROM and JOIN clauses
// This could be made more sophisticated with a proper SQL parser
// Pattern: FROM table_name or JOIN table_name
let table_pattern = regex::Regex::new(r#"(?i)\b(?:FROM|JOIN)\s+(?:"([^"]+)"|(\w+))"#)
.map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;
@@ -253,18 +272,30 @@ impl DependencyAnalyzer {
Ok(tables)
}
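// Sketch of the FROM/JOIN table extraction above, reusing the same regex; the
// sample SQL is illustrative.
#[test]
fn sql_table_extraction_sketch() {
    let sql = r#"SELECT id FROM orders JOIN "2024_customers" ON true"#;
    let pattern = regex::Regex::new(r#"(?i)\b(?:FROM|JOIN)\s+(?:"([^"]+)"|(\w+))"#).unwrap();
    let tables: Vec<String> = pattern.captures_iter(sql)
        .filter_map(|caps| caps.get(1).or_else(|| caps.get(2)).map(|m| m.as_str().to_string()))
        .collect();
    assert_eq!(tables, vec!["orders".to_string(), "2024_customers".to_string()]);
}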
/// Check for cycles in the dependency graph using proper DFS
/// Self-references are allowed and filtered out from cycle detection
/// Checks for circular dependencies in the dependency graph.
///
/// This function validates that adding new dependencies won't create cycles
/// that could lead to infinite loops during script execution. Self-references
/// are explicitly allowed and filtered out from cycle detection.
///
/// # Arguments
/// * `tx` - Database transaction for querying existing dependencies
/// * `table_id` - ID of the table adding new dependencies
/// * `new_dependencies` - Dependencies to be added
///
/// # Returns
/// * `Ok(())` - No cycles detected, safe to add dependencies
/// * `Err(DependencyError)` - Cycle detected or validation failed
pub async fn check_for_cycles(
&self,
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
table_id: i64,
new_dependencies: &[Dependency],
) -> Result<(), DependencyError> {
// FIRST: Validate that structured table access respects link constraints
// Validate that structured table access respects link constraints
self.validate_link_constraints(tx, table_id, new_dependencies).await?;
// Get current dependency graph for this schema
// Build current dependency graph excluding self-references
let current_deps = sqlx::query!(
r#"SELECT sd.source_table_id, sd.target_table_id, st.table_name as source_name, tt.table_name as target_name
FROM script_dependencies sd
@@ -277,12 +308,11 @@ impl DependencyAnalyzer {
.await
.map_err(|e| DependencyError::DatabaseError { error: e.to_string() })?;
// Build adjacency list - EXCLUDE self-references since they're always allowed
let mut graph: HashMap<i64, Vec<i64>> = HashMap::new();
let mut table_names: HashMap<i64, String> = HashMap::new();
// Build adjacency list excluding self-references
for dep in current_deps {
// Skip self-references in cycle detection
if dep.source_table_id != dep.target_table_id {
graph.entry(dep.source_table_id).or_default().push(dep.target_table_id);
}
@@ -290,9 +320,8 @@ impl DependencyAnalyzer {
table_names.insert(dep.target_table_id, dep.target_name);
}
// Add new dependencies to test - EXCLUDE self-references
// Add new dependencies to test (excluding self-references)
for dep in new_dependencies {
// Look up target table ID using the actual table name
let target_id = sqlx::query_scalar!(
"SELECT id FROM table_definitions WHERE schema_id = $1 AND table_name = $2",
self.schema_id,
@@ -306,12 +335,12 @@ impl DependencyAnalyzer {
script_context: format!("table_id_{}", table_id),
})?;
// Only add to cycle detection graph if it's NOT a self-reference
// Only add to cycle detection if not a self-reference
if table_id != target_id {
graph.entry(table_id).or_default().push(target_id);
}
// Get table name for error reporting
// Ensure table names are available for error reporting
if !table_names.contains_key(&table_id) {
let source_name = sqlx::query_scalar!(
"SELECT table_name FROM table_definitions WHERE id = $1",
@@ -325,12 +354,13 @@ impl DependencyAnalyzer {
}
}
// Detect cycles using proper DFS algorithm (now without self-references)
// Detect cycles using DFS algorithm
self.detect_cycles_dfs(&graph, &table_names, table_id)?;
Ok(())
}
/// Performs depth-first search to detect cycles in the dependency graph.
fn dfs_visit(
&self,
node: i64,
@@ -345,21 +375,18 @@ impl DependencyAnalyzer {
if let Some(neighbors) = graph.get(&node) {
for &neighbor in neighbors {
// Ensure neighbor is in states map
if !states.contains_key(&neighbor) {
states.insert(neighbor, NodeState::Unvisited);
}
match states.get(&neighbor).copied().unwrap_or(NodeState::Unvisited) {
NodeState::Visiting => {
// Check if this is a self-reference (allowed) or a real cycle (not allowed)
// Skip self-references as they're allowed
if neighbor == node {
// Self-reference: A table referencing itself is allowed
// Skip this - it's not a harmful cycle
continue;
}
// Found a real cycle! Build the cycle path
// Found a cycle - build the cycle path
let cycle_start_idx = path.iter().position(|&x| x == neighbor).unwrap_or(0);
let cycle_path: Vec<String> = path[cycle_start_idx..]
.iter()
@@ -367,7 +394,7 @@ impl DependencyAnalyzer {
.map(|&id| table_names.get(&id).cloned().unwrap_or_else(|| id.to_string()))
.collect();
// Only report as error if the cycle involves more than one table
// Only report as error if cycle involves multiple tables
if cycle_path.len() > 2 || (cycle_path.len() == 2 && cycle_path[0] != cycle_path[1]) {
let involving_script = table_names.get(&starting_table)
.cloned()
@@ -380,7 +407,6 @@ impl DependencyAnalyzer {
}
}
NodeState::Unvisited => {
// Recursively visit unvisited neighbor
self.dfs_visit(neighbor, states, graph, path, table_names, starting_table)?;
}
NodeState::Visited => {
@@ -395,23 +421,23 @@ impl DependencyAnalyzer {
Ok(())
}
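// A dependency-free sketch of the three-state DFS used above: Visiting marks the
// current path, a back edge to a Visiting node signals a cycle, and self-loops
// are skipped. All names here are illustrative.
#[derive(Clone, Copy, PartialEq)]
enum SketchState { Unvisited, Visiting, Visited }

fn sketch_has_cycle(graph: &std::collections::HashMap<i64, Vec<i64>>) -> bool {
    fn visit(
        node: i64,
        graph: &std::collections::HashMap<i64, Vec<i64>>,
        states: &mut std::collections::HashMap<i64, SketchState>,
    ) -> bool {
        states.insert(node, SketchState::Visiting);
        for &next in graph.get(&node).map(|v| v.as_slice()).unwrap_or(&[]) {
            if next == node { continue; } // self-references are allowed
            match states.get(&next).copied().unwrap_or(SketchState::Unvisited) {
                SketchState::Visiting => return true, // back edge: real cycle
                SketchState::Unvisited => { if visit(next, graph, states) { return true; } }
                SketchState::Visited => {}
            }
        }
        states.insert(node, SketchState::Visited);
        false
    }
    let mut states = std::collections::HashMap::new();
    graph.keys().any(|&node| {
        states.get(&node).copied().unwrap_or(SketchState::Unvisited) == SketchState::Unvisited
            && visit(node, graph, &mut states)
    })
}

#[test]
fn cycle_detection_sketch() {
    let self_loop_only: std::collections::HashMap<i64, Vec<i64>> =
        [(1, vec![2]), (2, vec![2])].into_iter().collect();
    let real_cycle: std::collections::HashMap<i64, Vec<i64>> =
        [(1, vec![2]), (2, vec![1])].into_iter().collect();
    assert!(!sketch_has_cycle(&self_loop_only)); // self-loop on 2 is tolerated
    assert!(sketch_has_cycle(&real_cycle));
}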
/// Validates that structured table access (steel_get_column functions) respects link constraints
/// Raw SQL access (steel_query_sql) is allowed to reference any table
/// SELF-REFERENCES are always allowed (table can access its own columns)
/// Validates that structured table access respects table link constraints.
///
/// Example:
/// - Table A can ALWAYS use: (steel_get_column "table_a" "column_name") ✅ (self-reference)
/// - Table A is linked to Table B via table_definition_links
/// - Script for Table A can use: (steel_get_column "table_b" "column_name") ✅
/// - Script for Table A CANNOT use: (steel_get_column "table_c" "column_name") ❌
/// - Script for Table A CAN use: (steel_query_sql "SELECT * FROM table_c") ✅
/// # Access Rules
/// - Self-references are always allowed (table can access its own columns)
/// - Structured access (steel_get_column functions) requires explicit table links
/// - Raw SQL access (steel_query_sql) is unrestricted
///
/// # Arguments
/// * `tx` - Database transaction for querying table links
/// * `source_table_id` - ID of the table with the script
/// * `dependencies` - Dependencies to validate
async fn validate_link_constraints(
&self,
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
source_table_id: i64,
dependencies: &[Dependency],
) -> Result<(), DependencyError> {
// Get the current table name for self-reference checking
let current_table_name = sqlx::query_scalar!(
"SELECT table_name FROM table_definitions WHERE id = $1",
source_table_id
@@ -432,25 +458,24 @@ impl DependencyAnalyzer {
.await
.map_err(|e| DependencyError::DatabaseError { error: e.to_string() })?;
// Create a set of allowed table names for quick lookup
let mut allowed_tables: std::collections::HashSet<String> = linked_tables
.into_iter()
.map(|row| row.table_name)
.collect();
// ALWAYS allow self-references
// Self-references are always allowed
allowed_tables.insert(current_table_name.clone());
// Validate each dependency
for dep in dependencies {
match &dep.dependency_type {
// Structured access must respect link constraints (but self-references are always allowed)
DependencyType::ColumnAccess { column } | DependencyType::IndexedAccess { column, .. } => {
// Self-references are always allowed (compare table names directly)
// Allow self-references
if dep.target_table == current_table_name {
continue;
}
// Check if table is linked
if !allowed_tables.contains(&dep.target_table) {
return Err(DependencyError::InvalidTableReference {
table_name: dep.target_table.clone(),
@@ -464,9 +489,8 @@ impl DependencyAnalyzer {
});
}
}
// Raw SQL access is unrestricted
DependencyType::SqlQuery { .. } => {
// No validation - raw SQL can access any table
// Raw SQL access is unrestricted
}
}
}
@@ -474,7 +498,7 @@ impl DependencyAnalyzer {
Ok(())
}
/// Proper DFS-based cycle detection with state tracking
/// Runs DFS-based cycle detection on the dependency graph.
fn detect_cycles_dfs(
&self,
graph: &HashMap<i64, Vec<i64>>,
@@ -508,7 +532,16 @@ impl DependencyAnalyzer {
Ok(())
}
/// Save dependencies to database within an existing transaction
/// Saves dependencies to the database within an existing transaction.
///
/// This function replaces all existing dependencies for a script with the new set,
/// ensuring the database reflects the current script analysis results.
///
/// # Arguments
/// * `tx` - Database transaction for atomic updates
/// * `script_id` - ID of the script these dependencies belong to
/// * `table_id` - ID of the table containing the script
/// * `dependencies` - Dependencies to save
pub async fn save_dependencies(
&self,
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
@@ -524,7 +557,6 @@ impl DependencyAnalyzer {
// Insert new dependencies
for dep in dependencies {
// Look up target table ID using actual table name (no magic strings!)
let target_id = sqlx::query_scalar!(
"SELECT id FROM table_definitions WHERE schema_id = $1 AND table_name = $2",
self.schema_id,

View File

@@ -12,6 +12,7 @@ use rust_decimal::Decimal;
use std::str::FromStr;
use crate::steel::server::execution::{self, Value};
use crate::indexer::{IndexCommand, IndexCommandData};
use tokio::sync::mpsc;
use tracing::error;
@@ -153,9 +154,7 @@ pub async fn post_table_data(
format!("Script execution failed for '{}': {}", target_column, e)
))?;
let Value::Strings(mut script_output) = script_result else {
return Err(Status::internal("Script must return string values"));
};
let Value::Strings(mut script_output) = script_result;
let expected_value = script_output.pop()
.ok_or_else(|| Status::internal("Script returned no values"))?;

View File

@@ -216,9 +216,8 @@ pub async fn put_table_data(
Status::invalid_argument(format!("Script execution failed for '{}': {}", target_column, e))
})?;
let Value::Strings(mut script_output_vec) = script_result else {
return Err(Status::internal("Script must return string values"));
};
let Value::Strings(mut script_output_vec) = script_result;
let script_output = script_output_vec.pop().ok_or_else(|| Status::internal("Script returned no values"))?;
if update_data.contains_key(&target_column) {