fixing warnings and making prod code
@@ -1,4 +1,5 @@
 // src/steel/server/execution.rs
+
 use steel::steel_vm::engine::Engine;
 use steel::steel_vm::register_fn::RegisterFn;
 use steel::rvals::SteelVal;
@@ -10,13 +11,13 @@ use std::collections::HashMap;
 use thiserror::Error;
 use tracing::{debug, error};

+/// Represents different types of values that can be returned from Steel script execution.
 #[derive(Debug)]
 pub enum Value {
     Strings(Vec<String>),
-    Numbers(Vec<i64>),
-    Mixed(Vec<SteelVal>),
 }

+/// Errors that can occur during Steel script execution.
 #[derive(Debug, Error)]
 pub enum ExecutionError {
     #[error("Script execution failed: {0}")]
@@ -27,7 +28,7 @@ pub enum ExecutionError {
     UnsupportedType(String),
 }

-/// Create a SteelContext with boolean conversion applied to row data
+/// Creates a Steel execution context with proper boolean value conversion.
 pub async fn create_steel_context_with_boolean_conversion(
     current_table: String,
     schema_id: i64,
@@ -35,10 +36,6 @@ pub async fn create_steel_context_with_boolean_conversion(
     mut row_data: HashMap<String, String>,
     db_pool: Arc<PgPool>,
 ) -> Result<SteelContext, ExecutionError> {
-    println!("=== CREATING STEEL CONTEXT ===");
-    println!("Table: {}, Schema: {}", current_table, schema_name);
-    println!("Row data BEFORE boolean conversion: {:?}", row_data);
-
     // Convert boolean values in row_data to Steel format
     convert_row_data_for_steel(&db_pool, schema_id, &current_table, &mut row_data)
         .await
@@ -47,9 +44,6 @@ pub async fn create_steel_context_with_boolean_conversion(
             ExecutionError::RuntimeError(format!("Failed to convert row data: {}", e))
         })?;

-    println!("Row data AFTER boolean conversion: {:?}", row_data);
-    println!("=== END CREATING STEEL CONTEXT ===");
-
     Ok(SteelContext {
         current_table,
         schema_id,
@@ -59,7 +53,7 @@ pub async fn create_steel_context_with_boolean_conversion(
     })
 }

-/// Execute script with proper boolean handling
+/// Executes a Steel script with database context and type-safe result processing.
 pub async fn execute_script(
     script: String,
     target_type: &str,
@@ -69,20 +63,9 @@ pub async fn execute_script(
     current_table: String,
     row_data: HashMap<String, String>,
 ) -> Result<Value, ExecutionError> {
-    eprintln!("🚨🚨🚨 EXECUTE_SCRIPT CALLED 🚨🚨🚨");
-    eprintln!("Script: '{}'", script);
-    eprintln!("Table: {}, Target type: {}", current_table, target_type);
-    eprintln!("Input row_data: {:?}", row_data);
-    eprintln!("🚨🚨🚨 END EXECUTE_SCRIPT ENTRY 🚨🚨🚨");
-
-    println!("=== STEEL SCRIPT EXECUTION START ===");
-    println!("Script: '{}'", script);
-    println!("Table: {}, Target type: {}", current_table, target_type);
-    println!("Input row_data: {:?}", row_data);
-
     let mut vm = Engine::new();

-    // Create context with boolean conversion
+    // Create execution context with proper boolean value conversion
     let context = create_steel_context_with_boolean_conversion(
         current_table,
         schema_id,
@@ -93,41 +76,32 @@ pub async fn execute_script(

     let context = Arc::new(context);

-    // Register existing Steel functions
+    // Register database access functions
     register_steel_functions(&mut vm, context.clone());

-    // Register all decimal math functions using the steel_decimal crate
+    // Register decimal math operations
     register_decimal_math_functions(&mut vm);

-    // Register variables from the context with the Steel VM
-    // The row_data now contains Steel-formatted boolean values
-    println!("=== REGISTERING STEEL VARIABLES ===");
-
-    // Manual variable registration using Steel's define mechanism
+    // Register row data as variables in the Steel VM
+    // Both bare names and @-prefixed names are supported for flexibility
     let mut define_script = String::new();

-    println!("Variables being registered with Steel VM:");
     for (key, value) in &context.row_data {
-        println!("    STEEL[{}] = '{}'", key, value);
         // Register both @ prefixed and bare variable names
         define_script.push_str(&format!("(define {} \"{}\")\n", key, value));
         define_script.push_str(&format!("(define @{} \"{}\")\n", key, value));
     }
-    println!("Steel script to execute: {}", script);
-    println!("=== END REGISTERING STEEL VARIABLES ===");

+    // Execute variable definitions if any exist
     if !define_script.is_empty() {
-        println!("Define script: {}", define_script);
         vm.compile_and_run_raw_program(define_script)
             .map_err(|e| ExecutionError::RuntimeError(format!("Failed to register variables: {}", e)))?;
-        println!("Variables defined successfully");
     }

-    // Also try the original method as backup
+    // Also register variables using the decimal registry as backup method
     FunctionRegistry::register_variables(&mut vm, context.row_data.clone());

-    // Execute script and process results
-    println!("Compiling and running Steel script: {}", script);
+    // Execute the main script
     let results = vm.compile_and_run_raw_program(script.clone())
         .map_err(|e| {
             error!("Steel script execution failed: {}", e);
@@ -136,27 +110,18 @@ pub async fn execute_script(
             ExecutionError::RuntimeError(e.to_string())
         })?;

-    println!("Script execution returned {} results", results.len());
-    for (i, result) in results.iter().enumerate() {
-        println!("Result[{}]: {:?}", i, result);
-    }
-
-    // Convert results to target type
+    // Convert results to the requested target type
     match target_type {
-        "STRINGS" => {
-            let result = process_string_results(results);
-            println!("Final processed result: {:?}", result);
-            println!("=== STEEL SCRIPT EXECUTION END ===");
-            result
-        },
+        "STRINGS" => process_string_results(results),
         _ => Err(ExecutionError::UnsupportedType(target_type.into()))
     }
 }

+/// Registers Steel functions for database access within the VM context.
 fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
     debug!("Registering Steel functions with context");

-    // Register steel_get_column with row context
+    // Register column access function for current and related tables
     vm.register_fn("steel_get_column", {
         let ctx = context.clone();
         move |table: String, column: String| {
@@ -169,7 +134,7 @@ fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
         }
     });

-    // Register steel_get_column_with_index
+    // Register indexed column access for comma-separated values
     vm.register_fn("steel_get_column_with_index", {
         let ctx = context.clone();
         move |table: String, index: i64, column: String| {
@@ -182,7 +147,7 @@ fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
         }
     });

-    // SQL query registration
+    // Register safe SQL query execution
     vm.register_fn("steel_query_sql", {
         let ctx = context.clone();
         move |query: String| {
@@ -196,45 +161,31 @@ fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
     });
 }

+/// Registers decimal mathematics functions in the Steel VM.
 fn register_decimal_math_functions(vm: &mut Engine) {
     debug!("Registering decimal math functions");
-    // Use the steel_decimal crate's FunctionRegistry to register all functions
     FunctionRegistry::register_all(vm);
 }

+/// Processes Steel script results into string format for consistent output.
 fn process_string_results(results: Vec<SteelVal>) -> Result<Value, ExecutionError> {
     let mut strings = Vec::new();

     for result in results {
-        match result {
-            SteelVal::StringV(s) => {
-                let result_str = s.to_string();
-                println!("Processing string result: '{}'", result_str);
-                strings.push(result_str);
-            },
-            SteelVal::NumV(n) => {
-                let result_str = n.to_string();
-                println!("Processing number result: '{}'", result_str);
-                strings.push(result_str);
-            },
-            SteelVal::IntV(i) => {
-                let result_str = i.to_string();
-                println!("Processing integer result: '{}'", result_str);
-                strings.push(result_str);
-            },
-            SteelVal::BoolV(b) => {
-                let result_str = b.to_string();
-                println!("Processing boolean result: '{}'", result_str);
-                strings.push(result_str);
-            },
+        let result_str = match result {
+            SteelVal::StringV(s) => s.to_string(),
+            SteelVal::NumV(n) => n.to_string(),
+            SteelVal::IntV(i) => i.to_string(),
+            SteelVal::BoolV(b) => b.to_string(),
             _ => {
                 error!("Unexpected result type: {:?}", result);
                 return Err(ExecutionError::TypeConversionError(
                     format!("Expected string-convertible type, got {:?}", result)
                 ));
             }
-        }
+        };
+        strings.push(result_str);
     }

-    println!("Final processed strings: {:?}", strings);
     Ok(Value::Strings(strings))
 }
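The variable-registration loop above builds a small Steel prelude before the user script runs. As a rough standalone sketch (not part of the commit; the sample row values here are invented), the generated define-script looks like this:

use std::collections::HashMap;

fn main() {
    // Hypothetical row data; real values come from the database row being processed.
    let mut row_data: HashMap<String, String> = HashMap::new();
    row_data.insert("price".to_string(), "42.50".to_string());
    row_data.insert("in_stock".to_string(), "#true".to_string()); // already Steel-formatted boolean

    // Mirrors the loop in execute_script: both bare and @-prefixed names are defined.
    let mut define_script = String::new();
    for (key, value) in &row_data {
        define_script.push_str(&format!("(define {} \"{}\")\n", key, value));
        define_script.push_str(&format!("(define @{} \"{}\")\n", key, value));
    }

    // Prints something like:
    //   (define price "42.50")
    //   (define @price "42.50")
    //   (define in_stock "#true")
    //   (define @in_stock "#true")
    print!("{}", define_script);
}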
@@ -1,4 +1,5 @@
 // src/steel/server/functions.rs
+
 use steel::rvals::SteelVal;
 use sqlx::PgPool;
 use std::collections::HashMap;
@@ -20,9 +21,10 @@ pub enum FunctionError {
     ProhibitedTypeAccess(String),
 }

-// Define prohibited data types (boolean is explicitly allowed)
+/// Data types that Steel scripts are prohibited from accessing for security reasons
 const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"];

+/// Execution context for Steel scripts with database access capabilities.
 #[derive(Clone)]
 pub struct SteelContext {
     pub current_table: String,
@@ -33,6 +35,8 @@ pub struct SteelContext {
 }

 impl SteelContext {
+    /// Resolves a base table name to its full qualified name in the current schema.
+    /// Used for foreign key relationship traversal in Steel scripts.
     pub async fn get_related_table_name(&self, base_name: &str) -> Result<String, FunctionError> {
         let table_def = sqlx::query!(
             r#"SELECT table_name FROM table_definitions
@@ -48,7 +52,8 @@ impl SteelContext {
         Ok(table_def.table_name)
     }

-    /// Get column type for a given table and column
+    /// Retrieves the SQL data type for a specific column in a table.
+    /// Parses the JSON column definitions to find type information.
     async fn get_column_type(&self, table_name: &str, column_name: &str) -> Result<String, FunctionError> {
         let table_def = sqlx::query!(
             r#"SELECT columns FROM table_definitions
@@ -61,11 +66,10 @@ impl SteelContext {
         .map_err(|e| FunctionError::DatabaseError(e.to_string()))?
         .ok_or_else(|| FunctionError::TableNotFound(table_name.to_string()))?;

-        // Parse columns JSON to find the column type
         let columns: Vec<String> = serde_json::from_value(table_def.columns)
             .map_err(|e| FunctionError::DatabaseError(format!("Invalid column data: {}", e)))?;

-        // Find the column and its type
+        // Parse column definitions to find the requested column type
         for column_def in columns {
             let mut parts = column_def.split_whitespace();
             if let (Some(name), Some(data_type)) = (parts.next(), parts.next()) {
@@ -83,31 +87,30 @@ impl SteelContext {
         )))
     }

-    /// Convert database value to Steel format based on column type
+    /// Converts database values to Steel script format based on column type.
+    /// Currently handles boolean conversion to Steel's #true/#false syntax.
     fn convert_value_to_steel_format(&self, value: &str, column_type: &str) -> String {
         let normalized_type = normalize_data_type(column_type);

         match normalized_type.as_str() {
             "BOOLEAN" | "BOOL" => {
-                // Convert database boolean to Steel boolean syntax
+                // Convert database boolean representations to Steel boolean syntax
                 match value.to_lowercase().as_str() {
                     "true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
                     "false" | "f" | "0" | "no" | "off" => "#false".to_string(),
                     _ => value.to_string(), // Return as-is if not a recognized boolean
                 }
             }
-            "INTEGER" => {
-                value.to_string()
-            }
-            _ => value.to_string(), // Return as-is for non-boolean types
+            "INTEGER" => value.to_string(),
+            _ => value.to_string(), // Return as-is for other types
         }
     }

-    /// Validate column type and return the column type if valid
+    /// Validates that a column type is allowed for Steel script access.
+    /// Returns the column type if validation passes, error if prohibited.
     async fn validate_column_type_and_get_type(&self, table_name: &str, column_name: &str) -> Result<String, FunctionError> {
         let column_type = self.get_column_type(table_name, column_name).await?;

-        // Check if this type is prohibited
         if is_prohibited_type(&column_type) {
             return Err(FunctionError::ProhibitedTypeAccess(format!(
                 "Cannot access column '{}' in table '{}' because it has prohibited type '{}'. Steel scripts cannot access columns of type: {}",
@@ -121,15 +124,15 @@ impl SteelContext {
         Ok(column_type)
     }

-    /// Validate column type before access (legacy method for backward compatibility)
-    async fn validate_column_type(&self, table_name: &str, column_name: &str) -> Result<(), FunctionError> {
-        self.validate_column_type_and_get_type(table_name, column_name).await?;
-        Ok(())
-    }
-
+    /// Retrieves column value from current table or related tables via foreign keys.
+    ///
+    /// # Behavior
+    /// - Current table: Returns value directly from row_data with type conversion
+    /// - Related table: Follows foreign key relationship and queries database
+    /// - All accesses are subject to prohibited type validation
     pub fn steel_get_column(&self, table: &str, column: &str) -> Result<SteelVal, SteelVal> {
         if table == self.current_table {
-            // Validate column type for current table access and get the type
+            // Access current table data with type validation
             let column_type = tokio::task::block_in_place(|| {
                 let handle = tokio::runtime::Handle::current();
                 handle.block_on(async {
@@ -150,6 +153,7 @@ impl SteelContext {
                 .ok_or_else(|| SteelVal::StringV(format!("Column {} not found", column).into()));
         }

+        // Access related table via foreign key relationship
        let base_name = table.split_once('_')
             .map(|(_, rest)| rest)
             .unwrap_or(table);
@@ -158,18 +162,17 @@ impl SteelContext {
         let fk_value = self.row_data.get(&fk_column)
             .ok_or_else(|| SteelVal::StringV(format!("Foreign key {} not found", fk_column).into()))?;

-        // Use `tokio::task::block_in_place` to safely block the thread
         let result = tokio::task::block_in_place(|| {
             let handle = tokio::runtime::Handle::current();
             handle.block_on(async {
                 let actual_table = self.get_related_table_name(base_name).await
                     .map_err(|e| SteelVal::StringV(e.to_string().into()))?;

-                // Get column type for validation and conversion
+                // Validate column type and get type information
                 let column_type = self.validate_column_type_and_get_type(&actual_table, column).await
                     .map_err(|e| SteelVal::StringV(e.to_string().into()))?;

-                // Fetch the raw value from database
+                // Query the related table for the column value
                 let raw_value = sqlx::query_scalar::<_, String>(
                     &format!("SELECT {} FROM \"{}\".\"{}\" WHERE id = $1", column, self.schema_name, actual_table)
                 )
@@ -179,7 +182,7 @@ impl SteelContext {
                 .await
                 .map_err(|e| SteelVal::StringV(e.to_string().into()))?;

-                // Convert to Steel format
+                // Convert to appropriate Steel format
                 let converted_value = self.convert_value_to_steel_format(&raw_value, &column_type);
                 Ok(converted_value)
             })
@@ -188,13 +191,15 @@ impl SteelContext {
         result.map(|v| SteelVal::StringV(v.into()))
     }

+    /// Retrieves a specific indexed element from a comma-separated column value.
+    /// Useful for accessing elements from array-like string representations.
     pub fn steel_get_column_with_index(
         &self,
         table: &str,
         index: i64,
         column: &str
     ) -> Result<SteelVal, SteelVal> {
-        // Get the full value first (this already handles type conversion)
+        // Get the full value with proper type conversion
         let value = self.steel_get_column(table, column)?;

         if let SteelVal::StringV(s) = value {
@@ -203,8 +208,7 @@ impl SteelContext {
             if let Some(part) = parts.get(index as usize) {
                 let trimmed_part = part.trim();

-                // If the original column was boolean type, each part should also be treated as boolean
-                // We need to get the column type to determine if conversion is needed
+                // Apply type conversion to the indexed part based on original column type
                 let column_type = tokio::task::block_in_place(|| {
                     let handle = tokio::runtime::Handle::current();
                     handle.block_on(async {
@@ -218,7 +222,7 @@ impl SteelContext {
                         Ok(SteelVal::StringV(converted_part.into()))
                     }
                     Err(_) => {
-                        // If we can't get the type, return as-is
+                        // If type cannot be determined, return value as-is
                         Ok(SteelVal::StringV(trimmed_part.into()))
                     }
                 }
@@ -230,15 +234,19 @@ impl SteelContext {
         }
     }

+    /// Executes read-only SQL queries from Steel scripts with safety restrictions.
+    ///
+    /// # Security Features
+    /// - Only SELECT, SHOW, and EXPLAIN queries allowed
+    /// - Prohibited column type access validation
+    /// - Returns first column of all rows as comma-separated string
     pub fn steel_query_sql(&self, query: &str) -> Result<SteelVal, SteelVal> {
-        // Validate query is read-only
         if !is_read_only_query(query) {
             return Err(SteelVal::StringV(
                 "Only SELECT queries are allowed".into()
             ));
         }

-        // Check if query might access prohibited columns
         if contains_prohibited_column_access(query) {
             return Err(SteelVal::StringV(format!(
                 "SQL query may access prohibited column types. Steel scripts cannot access columns of type: {}",
@@ -248,11 +256,9 @@ impl SteelContext {

         let pool = self.db_pool.clone();

-        // Use `tokio::task::block_in_place` to safely block the thread
         let result = tokio::task::block_in_place(|| {
             let handle = tokio::runtime::Handle::current();
             handle.block_on(async {
-                // Execute and get first column of all rows as strings
                 let rows = sqlx::query(query)
                     .fetch_all(&*pool)
                     .await
@@ -273,13 +279,14 @@ impl SteelContext {
     }
 }

-/// Check if a data type is prohibited for Steel scripts
+/// Checks if a data type is prohibited for Steel script access.
 fn is_prohibited_type(data_type: &str) -> bool {
     let normalized_type = normalize_data_type(data_type);
     PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited))
 }

-/// Normalize data type for comparison (handle NUMERIC variations, etc.)
+/// Normalizes data type strings for consistent comparison.
+/// Handles variations like NUMERIC(10,2) by extracting base type.
 fn normalize_data_type(data_type: &str) -> String {
     data_type.to_uppercase()
         .split('(') // Remove precision/scale from NUMERIC(x,y)
@@ -289,13 +296,11 @@ fn normalize_data_type(data_type: &str) -> String {
         .to_string()
 }

-/// Basic check for prohibited column access in SQL queries
-/// This is a simple heuristic - more sophisticated parsing could be added
+/// Performs basic heuristic check for prohibited column type access in SQL queries.
+/// Looks for common patterns that might indicate access to restricted types.
 fn contains_prohibited_column_access(query: &str) -> bool {
     let query_upper = query.to_uppercase();

-    // Look for common patterns that might indicate prohibited type access
-    // This is a basic implementation - you might want to enhance this
     let patterns = [
         "EXTRACT(",   // Common with DATE/TIMESTAMPTZ
         "DATE_PART(", // Common with DATE/TIMESTAMPTZ
@@ -307,6 +312,7 @@ fn contains_prohibited_column_access(query: &str) -> bool {
     patterns.iter().any(|pattern| query_upper.contains(pattern))
 }

+/// Validates that a query is read-only and safe for Steel script execution.
 fn is_read_only_query(query: &str) -> bool {
     let query = query.trim_start().to_uppercase();
     query.starts_with("SELECT") ||
@@ -314,14 +320,13 @@ fn is_read_only_query(query: &str) -> bool {
     query.starts_with("EXPLAIN")
 }

-/// Helper function to convert initial row data for boolean columns
+/// Converts row data boolean values to Steel script format during context initialization.
 pub async fn convert_row_data_for_steel(
     db_pool: &PgPool,
     schema_id: i64,
     table_name: &str,
     row_data: &mut HashMap<String, String>,
 ) -> Result<(), sqlx::Error> {
-    // Get table definition to check column types
     let table_def = sqlx::query!(
         r#"SELECT columns FROM table_definitions
            WHERE schema_id = $1 AND table_name = $2"#,
@@ -332,7 +337,7 @@ pub async fn convert_row_data_for_steel(
     .await?
     .ok_or_else(|| sqlx::Error::RowNotFound)?;

-    // Parse columns to find boolean types
+    // Parse column definitions to identify boolean columns for conversion
     if let Ok(columns) = serde_json::from_value::<Vec<String>>(table_def.columns) {
         for column_def in columns {
             let mut parts = column_def.split_whitespace();
@@ -340,7 +345,6 @@ pub async fn convert_row_data_for_steel(
                 let column_name = name.trim_matches('"');
                 let normalized_type = normalize_data_type(data_type);

-                // Fixed: Use traditional if let instead of let chains
                 if normalized_type == "BOOLEAN" || normalized_type == "BOOL" {
                     if let Some(value) = row_data.get_mut(column_name) {
                         // Convert boolean value to Steel format
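To make the type rules documented above concrete, here is a minimal, self-contained sketch of the normalization, prohibition check, and boolean conversion behaviour (free-function form for brevity; the middle of normalize_data_type is partly elided in the diff, so this is an approximation, not the author's exact code):

// Sketch of the type handling described in functions.rs above.
const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"];

// Strip precision/scale so NUMERIC(10,2) compares as NUMERIC.
fn normalize_data_type(data_type: &str) -> String {
    data_type
        .to_uppercase()
        .split('(')
        .next()
        .unwrap_or("")
        .trim()
        .to_string()
}

// A column type is prohibited if its normalized form starts with a prohibited base type.
fn is_prohibited_type(data_type: &str) -> bool {
    let normalized = normalize_data_type(data_type);
    PROHIBITED_TYPES.iter().any(|p| normalized.starts_with(p))
}

// Database boolean spellings become Steel's #true/#false; anything else is passed through.
fn to_steel_bool(value: &str) -> String {
    match value.to_lowercase().as_str() {
        "true" | "t" | "1" | "yes" | "on" => "#true".to_string(),
        "false" | "f" | "0" | "no" | "off" => "#false".to_string(),
        _ => value.to_string(),
    }
}

fn main() {
    assert!(is_prohibited_type("TIMESTAMPTZ"));
    assert!(!is_prohibited_type("NUMERIC(10,2)"));
    assert_eq!(to_steel_bool("t"), "#true");
    assert_eq!(to_steel_bool("off"), "#false");
}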
@@ -5,6 +5,7 @@ use tonic::Status;
 use sqlx::PgPool;
 use serde_json::{json, Value};

+/// Represents the state of a node during dependency graph traversal.
 #[derive(Clone, Copy, PartialEq)]
 enum NodeState {
     Unvisited,
@@ -12,6 +13,7 @@ enum NodeState {
     Visited,   // Completely processed
 }

+/// Represents a dependency relationship between tables in Steel scripts.
 #[derive(Debug, Clone)]
 pub struct Dependency {
     pub target_table: String,
@@ -19,14 +21,19 @@ pub struct Dependency {
     pub context: Option<Value>,
 }

+/// Types of dependencies that can exist between tables in Steel scripts.
 #[derive(Debug, Clone)]
 pub enum DependencyType {
+    /// Direct column access via steel_get_column
     ColumnAccess { column: String },
+    /// Indexed column access via steel_get_column_with_index
     IndexedAccess { column: String, index: i64 },
+    /// Raw SQL query access via steel_query_sql
     SqlQuery { query_fragment: String },
 }

 impl DependencyType {
+    /// Returns the string representation used in the database.
     pub fn as_str(&self) -> &'static str {
         match self {
             DependencyType::ColumnAccess { .. } => "column_access",
@@ -35,6 +42,7 @@ impl DependencyType {
         }
     }

+    /// Generates context JSON for database storage.
     pub fn context_json(&self) -> Value {
         match self {
             DependencyType::ColumnAccess { column } => {
@@ -50,6 +58,7 @@ impl DependencyType {
     }
 }

+/// Errors that can occur during dependency analysis.
 #[derive(Debug)]
 pub enum DependencyError {
     CircularDependency {
@@ -94,43 +103,53 @@ impl From<DependencyError> for Status {
     }
 }

+/// Analyzes Steel scripts to extract table dependencies and validate referential integrity.
+///
+/// This analyzer identifies how tables reference each other through Steel function calls
+/// and ensures that dependency graphs remain acyclic while respecting table link constraints.
 pub struct DependencyAnalyzer {
     schema_id: i64,
-    pool: PgPool,
 }

 impl DependencyAnalyzer {
+    /// Creates a new dependency analyzer for the specified schema.
     pub fn new(schema_id: i64, pool: PgPool) -> Self {
-        Self { schema_id, pool }
+        Self { schema_id }
     }

-    /// Analyzes a Steel script to extract all table dependencies
-    /// Uses regex patterns to find function calls that create dependencies
+    /// Analyzes a Steel script to extract all table dependencies.
+    ///
+    /// Uses regex patterns to identify function calls that create dependencies:
+    /// - `steel_get_column` calls for direct column access
+    /// - `steel_get_column_with_index` calls for indexed access
+    /// - `steel_query_sql` calls for raw SQL access
+    /// - Variable references like `@column_name`
+    /// - `get-var` calls in transformed scripts
     ///
     /// # Arguments
-    /// * `script` - The Steel script to analyze
+    /// * `script` - The Steel script code to analyze
     /// * `current_table_name` - Name of the table this script belongs to (for self-references)
+    ///
+    /// # Returns
+    /// * `Ok(Vec<Dependency>)` - List of identified dependencies
+    /// * `Err(DependencyError)` - If script parsing fails
     pub fn analyze_script_dependencies(&self, script: &str, current_table_name: &str) -> Result<Vec<Dependency>, DependencyError> {
         let mut dependencies = Vec::new();

-        // Extract function calls and SQL dependencies using regex
+        // Extract different types of dependencies using regex patterns
         dependencies.extend(self.extract_function_calls(script)?);
         dependencies.extend(self.extract_sql_dependencies(script)?);
-
-        // Extract get-var calls (for transformed scripts with variables)
         dependencies.extend(self.extract_get_var_calls(script, current_table_name)?);
-
-        // Extract direct variable references like @price, @quantity
         dependencies.extend(self.extract_variable_references(script, current_table_name)?);

         Ok(dependencies)
     }

-    /// Extract function calls using regex patterns
+    /// Extracts Steel function calls that create table dependencies.
     fn extract_function_calls(&self, script: &str) -> Result<Vec<Dependency>, DependencyError> {
         let mut dependencies = Vec::new();

-        // Look for steel_get_column patterns
+        // Pattern: (steel_get_column "table" "column")
         let column_pattern = regex::Regex::new(r#"\(\s*steel_get_column\s+"([^"]+)"\s+"([^"]+)""#)
             .map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;

@@ -144,7 +163,7 @@ impl DependencyAnalyzer {
             });
         }

-        // Look for steel_get_column_with_index patterns
+        // Pattern: (steel_get_column_with_index "table" index "column")
         let indexed_pattern = regex::Regex::new(r#"\(\s*steel_get_column_with_index\s+"([^"]+)"\s+(\d+)\s+"([^"]+)""#)
             .map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;

@@ -163,19 +182,19 @@ impl DependencyAnalyzer {
         Ok(dependencies)
     }

-    /// Extract get-var calls as dependencies (for transformed scripts with variables)
-    /// These are self-references to the current table
+    /// Extracts get-var calls as dependencies for transformed scripts.
+    /// These represent self-references to the current table.
     fn extract_get_var_calls(&self, script: &str, current_table_name: &str) -> Result<Vec<Dependency>, DependencyError> {
         let mut dependencies = Vec::new();

-        // Look for get-var patterns in transformed scripts: (get-var "variable")
+        // Pattern: (get-var "variable")
         let get_var_pattern = regex::Regex::new(r#"\(get-var\s+"([^"]+)"\)"#)
             .map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;

         for caps in get_var_pattern.captures_iter(script) {
             let variable_name = caps[1].to_string();
             dependencies.push(Dependency {
-                target_table: current_table_name.to_string(), // Use actual table name
+                target_table: current_table_name.to_string(),
                 dependency_type: DependencyType::ColumnAccess { column: variable_name },
                 context: None,
             });
@@ -184,19 +203,19 @@ impl DependencyAnalyzer {
         Ok(dependencies)
     }

-    /// Extract direct variable references like @price, @quantity
-    /// These are self-references to the current table
+    /// Extracts direct variable references like @price, @quantity.
+    /// These represent self-references to the current table.
     fn extract_variable_references(&self, script: &str, current_table_name: &str) -> Result<Vec<Dependency>, DependencyError> {
         let mut dependencies = Vec::new();

-        // Look for @variable patterns: @price, @quantity, etc.
+        // Pattern: @variable_name
         let variable_pattern = regex::Regex::new(r#"@([a-zA-Z_][a-zA-Z0-9_]*)"#)
             .map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;

         for caps in variable_pattern.captures_iter(script) {
             let variable_name = caps[1].to_string();
             dependencies.push(Dependency {
-                target_table: current_table_name.to_string(), // Use actual table name
+                target_table: current_table_name.to_string(),
                 dependency_type: DependencyType::ColumnAccess { column: variable_name },
                 context: None,
             });
@@ -205,11 +224,11 @@ impl DependencyAnalyzer {
         Ok(dependencies)
     }

-    /// Extract table references from SQL queries in steel_query_sql calls
+    /// Extracts table references from SQL queries in steel_query_sql calls.
     fn extract_sql_dependencies(&self, script: &str) -> Result<Vec<Dependency>, DependencyError> {
         let mut dependencies = Vec::new();

-        // Look for steel_query_sql calls and extract table names from the SQL
+        // Pattern: (steel_query_sql "SELECT ... FROM table ...")
         let sql_pattern = regex::Regex::new(r#"\(\s*steel_query_sql\s+"([^"]+)""#)
             .map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;

@@ -231,12 +250,12 @@ impl DependencyAnalyzer {
         Ok(dependencies)
     }

-    /// Extract table names from SQL query text
+    /// Extracts table names from SQL query text using regex patterns.
+    /// Looks for FROM and JOIN clauses to identify table references.
     fn extract_table_references_from_sql(&self, sql: &str) -> Result<Vec<String>, DependencyError> {
         let mut tables = Vec::new();

-        // Simple extraction - look for FROM and JOIN clauses
-        // This could be made more sophisticated with a proper SQL parser
+        // Pattern: FROM table_name or JOIN table_name
         let table_pattern = regex::Regex::new(r#"(?i)\b(?:FROM|JOIN)\s+(?:"([^"]+)"|(\w+))"#)
             .map_err(|e: regex::Error| DependencyError::ScriptParseError { error: e.to_string() })?;

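As a quick illustration of the extraction patterns above, the following sketch shows the column-access and @-variable regexes picking dependencies out of a script. The sample script, table, and column names are invented; it assumes the same regex crate used by this module:

fn main() {
    let script = r#"(decimal-add @price (steel_get_column "products" "base_price"))"#;

    // Direct column access: (steel_get_column "table" "column")
    let column_pattern =
        regex::Regex::new(r#"\(\s*steel_get_column\s+"([^"]+)"\s+"([^"]+)""#).unwrap();
    for caps in column_pattern.captures_iter(script) {
        println!("column access: table={}, column={}", &caps[1], &caps[2]);
    }

    // Variable references like @price are recorded as self-references to the current table.
    let variable_pattern = regex::Regex::new(r#"@([a-zA-Z_][a-zA-Z0-9_]*)"#).unwrap();
    for caps in variable_pattern.captures_iter(script) {
        println!("self-reference column: {}", &caps[1]);
    }
}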
@@ -253,18 +272,30 @@ impl DependencyAnalyzer {
|
|||||||
Ok(tables)
|
Ok(tables)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check for cycles in the dependency graph using proper DFS
|
/// Checks for circular dependencies in the dependency graph.
|
||||||
/// Self-references are allowed and filtered out from cycle detection
|
///
|
||||||
|
/// This function validates that adding new dependencies won't create cycles
|
||||||
|
/// that could lead to infinite loops during script execution. Self-references
|
||||||
|
/// are explicitly allowed and filtered out from cycle detection.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `tx` - Database transaction for querying existing dependencies
|
||||||
|
/// * `table_id` - ID of the table adding new dependencies
|
||||||
|
/// * `new_dependencies` - Dependencies to be added
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// * `Ok(())` - No cycles detected, safe to add dependencies
|
||||||
|
/// * `Err(DependencyError)` - Cycle detected or validation failed
|
||||||
pub async fn check_for_cycles(
|
pub async fn check_for_cycles(
|
||||||
&self,
|
&self,
|
||||||
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
|
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
|
||||||
table_id: i64,
|
table_id: i64,
|
||||||
new_dependencies: &[Dependency],
|
new_dependencies: &[Dependency],
|
||||||
) -> Result<(), DependencyError> {
|
) -> Result<(), DependencyError> {
|
||||||
// FIRST: Validate that structured table access respects link constraints
|
// Validate that structured table access respects link constraints
|
||||||
self.validate_link_constraints(tx, table_id, new_dependencies).await?;
|
self.validate_link_constraints(tx, table_id, new_dependencies).await?;
|
||||||
|
|
||||||
// Get current dependency graph for this schema
|
// Build current dependency graph excluding self-references
|
||||||
let current_deps = sqlx::query!(
|
let current_deps = sqlx::query!(
|
||||||
r#"SELECT sd.source_table_id, sd.target_table_id, st.table_name as source_name, tt.table_name as target_name
|
r#"SELECT sd.source_table_id, sd.target_table_id, st.table_name as source_name, tt.table_name as target_name
|
||||||
FROM script_dependencies sd
|
FROM script_dependencies sd
|
||||||
@@ -277,12 +308,11 @@ impl DependencyAnalyzer {
|
|||||||
.await
|
.await
|
||||||
.map_err(|e| DependencyError::DatabaseError { error: e.to_string() })?;
|
.map_err(|e| DependencyError::DatabaseError { error: e.to_string() })?;
|
||||||
|
|
||||||
// Build adjacency list - EXCLUDE self-references since they're always allowed
|
|
||||||
let mut graph: HashMap<i64, Vec<i64>> = HashMap::new();
|
let mut graph: HashMap<i64, Vec<i64>> = HashMap::new();
|
||||||
let mut table_names: HashMap<i64, String> = HashMap::new();
|
let mut table_names: HashMap<i64, String> = HashMap::new();
|
||||||
|
|
||||||
|
// Build adjacency list excluding self-references
|
||||||
for dep in current_deps {
|
for dep in current_deps {
|
||||||
// Skip self-references in cycle detection
|
|
||||||
if dep.source_table_id != dep.target_table_id {
|
if dep.source_table_id != dep.target_table_id {
|
||||||
graph.entry(dep.source_table_id).or_default().push(dep.target_table_id);
|
graph.entry(dep.source_table_id).or_default().push(dep.target_table_id);
|
||||||
}
|
}
|
||||||
@@ -290,9 +320,8 @@ impl DependencyAnalyzer {
|
|||||||
table_names.insert(dep.target_table_id, dep.target_name);
|
table_names.insert(dep.target_table_id, dep.target_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add new dependencies to test - EXCLUDE self-references
|
// Add new dependencies to test (excluding self-references)
|
||||||
for dep in new_dependencies {
|
for dep in new_dependencies {
|
||||||
// Look up target table ID using the actual table name
|
|
||||||
let target_id = sqlx::query_scalar!(
|
let target_id = sqlx::query_scalar!(
|
||||||
"SELECT id FROM table_definitions WHERE schema_id = $1 AND table_name = $2",
|
"SELECT id FROM table_definitions WHERE schema_id = $1 AND table_name = $2",
|
||||||
self.schema_id,
|
self.schema_id,
|
||||||
@@ -306,12 +335,12 @@ impl DependencyAnalyzer {
|
|||||||
script_context: format!("table_id_{}", table_id),
|
script_context: format!("table_id_{}", table_id),
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Only add to cycle detection graph if it's NOT a self-reference
|
// Only add to cycle detection if not a self-reference
|
||||||
if table_id != target_id {
|
if table_id != target_id {
|
||||||
graph.entry(table_id).or_default().push(target_id);
|
graph.entry(table_id).or_default().push(target_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get table name for error reporting
|
// Ensure table names are available for error reporting
|
||||||
if !table_names.contains_key(&table_id) {
|
if !table_names.contains_key(&table_id) {
|
||||||
let source_name = sqlx::query_scalar!(
|
let source_name = sqlx::query_scalar!(
|
||||||
"SELECT table_name FROM table_definitions WHERE id = $1",
|
"SELECT table_name FROM table_definitions WHERE id = $1",
|
||||||
@@ -325,12 +354,13 @@ impl DependencyAnalyzer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Detect cycles using proper DFS algorithm (now without self-references)
|
// Detect cycles using DFS algorithm
|
||||||
self.detect_cycles_dfs(&graph, &table_names, table_id)?;
|
self.detect_cycles_dfs(&graph, &table_names, table_id)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Performs depth-first search to detect cycles in the dependency graph.
|
||||||
fn dfs_visit(
|
fn dfs_visit(
|
||||||
&self,
|
&self,
|
||||||
node: i64,
|
node: i64,
|
||||||
@@ -345,21 +375,18 @@ impl DependencyAnalyzer {
|
|||||||
|
|
||||||
if let Some(neighbors) = graph.get(&node) {
|
if let Some(neighbors) = graph.get(&node) {
|
||||||
for &neighbor in neighbors {
|
for &neighbor in neighbors {
|
||||||
// Ensure neighbor is in states map
|
|
||||||
if !states.contains_key(&neighbor) {
|
if !states.contains_key(&neighbor) {
|
||||||
states.insert(neighbor, NodeState::Unvisited);
|
states.insert(neighbor, NodeState::Unvisited);
|
||||||
}
|
}
|
||||||
|
|
||||||
match states.get(&neighbor).copied().unwrap_or(NodeState::Unvisited) {
|
match states.get(&neighbor).copied().unwrap_or(NodeState::Unvisited) {
|
||||||
NodeState::Visiting => {
|
NodeState::Visiting => {
|
||||||
// Check if this is a self-reference (allowed) or a real cycle (not allowed)
|
// Skip self-references as they're allowed
|
||||||
if neighbor == node {
|
if neighbor == node {
|
||||||
// Self-reference: A table referencing itself is allowed
|
|
||||||
// Skip this - it's not a harmful cycle
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Found a real cycle! Build the cycle path
|
// Found a cycle - build the cycle path
|
||||||
let cycle_start_idx = path.iter().position(|&x| x == neighbor).unwrap_or(0);
|
let cycle_start_idx = path.iter().position(|&x| x == neighbor).unwrap_or(0);
|
||||||
let cycle_path: Vec<String> = path[cycle_start_idx..]
|
let cycle_path: Vec<String> = path[cycle_start_idx..]
|
||||||
.iter()
|
.iter()
|
||||||
@@ -367,7 +394,7 @@ impl DependencyAnalyzer {
|
|||||||
.map(|&id| table_names.get(&id).cloned().unwrap_or_else(|| id.to_string()))
|
.map(|&id| table_names.get(&id).cloned().unwrap_or_else(|| id.to_string()))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Only report as error if the cycle involves more than one table
|
// Only report as error if cycle involves multiple tables
|
||||||
if cycle_path.len() > 2 || (cycle_path.len() == 2 && cycle_path[0] != cycle_path[1]) {
|
if cycle_path.len() > 2 || (cycle_path.len() == 2 && cycle_path[0] != cycle_path[1]) {
|
||||||
let involving_script = table_names.get(&starting_table)
|
let involving_script = table_names.get(&starting_table)
|
||||||
.cloned()
|
.cloned()
|
||||||
@@ -380,7 +407,6 @@ impl DependencyAnalyzer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
NodeState::Unvisited => {
|
NodeState::Unvisited => {
|
||||||
// Recursively visit unvisited neighbor
|
|
||||||
self.dfs_visit(neighbor, states, graph, path, table_names, starting_table)?;
|
self.dfs_visit(neighbor, states, graph, path, table_names, starting_table)?;
|
||||||
}
|
}
|
||||||
NodeState::Visited => {
|
NodeState::Visited => {
|
||||||
@@ -395,23 +421,23 @@ impl DependencyAnalyzer {
        Ok(())
    }

-   /// Validates that structured table access (steel_get_column functions) respects link constraints
-   /// Raw SQL access (steel_query_sql) is allowed to reference any table
-   /// SELF-REFERENCES are always allowed (table can access its own columns)
+   /// Validates that structured table access respects table link constraints.
    ///
-   /// Example:
-   /// - Table A can ALWAYS use: (steel_get_column "table_a" "column_name") ✅ (self-reference)
-   /// - Table A is linked to Table B via table_definition_links
-   /// - Script for Table A can use: (steel_get_column "table_b" "column_name") ✅
-   /// - Script for Table A CANNOT use: (steel_get_column "table_c" "column_name") ❌
-   /// - Script for Table A CAN use: (steel_query_sql "SELECT * FROM table_c") ✅
+   /// # Access Rules
+   /// - Self-references are always allowed (table can access its own columns)
+   /// - Structured access (steel_get_column functions) requires explicit table links
+   /// - Raw SQL access (steel_query_sql) is unrestricted
+   ///
+   /// # Arguments
+   /// * `tx` - Database transaction for querying table links
+   /// * `source_table_id` - ID of the table with the script
+   /// * `dependencies` - Dependencies to validate
    async fn validate_link_constraints(
        &self,
        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
        source_table_id: i64,
        dependencies: &[Dependency],
    ) -> Result<(), DependencyError> {
-       // Get the current table name for self-reference checking
        let current_table_name = sqlx::query_scalar!(
            "SELECT table_name FROM table_definitions WHERE id = $1",
            source_table_id
@@ -432,25 +458,24 @@ impl DependencyAnalyzer {
        .await
        .map_err(|e| DependencyError::DatabaseError { error: e.to_string() })?;

-       // Create a set of allowed table names for quick lookup
        let mut allowed_tables: std::collections::HashSet<String> = linked_tables
            .into_iter()
            .map(|row| row.table_name)
            .collect();

-       // ALWAYS allow self-references
+       // Self-references are always allowed
        allowed_tables.insert(current_table_name.clone());

        // Validate each dependency
        for dep in dependencies {
            match &dep.dependency_type {
-               // Structured access must respect link constraints (but self-references are always allowed)
                DependencyType::ColumnAccess { column } | DependencyType::IndexedAccess { column, .. } => {
-                   // Self-references are always allowed (compare table names directly)
+                   // Allow self-references
                    if dep.target_table == current_table_name {
                        continue;
                    }

+                   // Check if table is linked
                    if !allowed_tables.contains(&dep.target_table) {
                        return Err(DependencyError::InvalidTableReference {
                            table_name: dep.target_table.clone(),
@@ -464,9 +489,8 @@ impl DependencyAnalyzer {
                        });
                    }
                }
-               // Raw SQL access is unrestricted
                DependencyType::SqlQuery { .. } => {
-                   // No validation - raw SQL can access any table
+                   // Raw SQL access is unrestricted
                }
            }
        }
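The rules enforced in this match can be summarized as a small pure check. The sketch below uses simplified stand-ins for the dependency types (the `Dep` and `AccessKind` shapes are assumptions for illustration, not the crate's `Dependency`/`DependencyType`) and takes the linked-table set as input instead of querying the database:

use std::collections::HashSet;

// Simplified stand-ins for the analyzer's dependency types.
enum AccessKind {
    ColumnAccess, // structured steel_get_column-style access
    SqlQuery,     // raw steel_query_sql access
}

struct Dep {
    target_table: String,
    kind: AccessKind,
}

// Self-references and raw SQL are always allowed; structured access needs an explicit link.
fn is_allowed(dep: &Dep, current_table: &str, linked: &HashSet<String>) -> bool {
    match dep.kind {
        AccessKind::SqlQuery => true,
        AccessKind::ColumnAccess => {
            dep.target_table == current_table || linked.contains(&dep.target_table)
        }
    }
}

fn main() {
    let linked: HashSet<String> = ["table_b".to_string()].into_iter().collect();
    let deps = [
        Dep { target_table: "table_a".into(), kind: AccessKind::ColumnAccess }, // self-reference: allowed
        Dep { target_table: "table_b".into(), kind: AccessKind::ColumnAccess }, // linked: allowed
        Dep { target_table: "table_c".into(), kind: AccessKind::ColumnAccess }, // not linked: rejected
        Dep { target_table: "table_c".into(), kind: AccessKind::SqlQuery },     // raw SQL: allowed
    ];
    for dep in &deps {
        println!("{}: {}", dep.target_table, is_allowed(dep, "table_a", &linked));
    }
}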
@@ -474,7 +498,7 @@ impl DependencyAnalyzer {
        Ok(())
    }

-   /// Proper DFS-based cycle detection with state tracking
+   /// Runs DFS-based cycle detection on the dependency graph.
    fn detect_cycles_dfs(
        &self,
        graph: &HashMap<i64, Vec<i64>>,
@@ -508,7 +532,16 @@ impl DependencyAnalyzer {
        Ok(())
    }

-   /// Save dependencies to database within an existing transaction
+   /// Saves dependencies to the database within an existing transaction.
+   ///
+   /// This function replaces all existing dependencies for a script with the new set,
+   /// ensuring the database reflects the current script analysis results.
+   ///
+   /// # Arguments
+   /// * `tx` - Database transaction for atomic updates
+   /// * `script_id` - ID of the script these dependencies belong to
+   /// * `table_id` - ID of the table containing the script
+   /// * `dependencies` - Dependencies to save
    pub async fn save_dependencies(
        &self,
        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
@@ -524,7 +557,6 @@ impl DependencyAnalyzer {

        // Insert new dependencies
        for dep in dependencies {
-           // Look up target table ID using actual table name (no magic strings!)
            let target_id = sqlx::query_scalar!(
                "SELECT id FROM table_definitions WHERE schema_id = $1 AND table_name = $2",
                self.schema_id,
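The replace-all behavior described in the doc comment is a delete-then-insert inside one transaction. A minimal sketch of that pattern with runtime-checked sqlx queries (the `script_dependencies` table and its columns are assumed names for illustration, and sqlx 0.7-style `&mut **tx` executors are assumed):

use sqlx::{Postgres, Transaction};

// Hypothetical dependency row used only for this sketch.
struct DepRow {
    target_table_id: i64,
    dependency_type: String,
}

async fn replace_dependencies(
    tx: &mut Transaction<'_, Postgres>,
    script_id: i64,
    deps: &[DepRow],
) -> Result<(), sqlx::Error> {
    // Drop whatever the previous analysis recorded for this script...
    sqlx::query("DELETE FROM script_dependencies WHERE script_id = $1")
        .bind(script_id)
        .execute(&mut **tx)
        .await?;

    // ...then insert the fresh set so the table mirrors the current script.
    for dep in deps {
        sqlx::query(
            "INSERT INTO script_dependencies (script_id, target_table_id, dependency_type) VALUES ($1, $2, $3)",
        )
        .bind(script_id)
        .bind(dep.target_table_id)
        .bind(&dep.dependency_type)
        .execute(&mut **tx)
        .await?;
    }
    Ok(())
}

Because both statements run on the same transaction, a failure in either leaves the previous dependency set intact once the transaction is rolled back.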
@@ -12,6 +12,7 @@ use rust_decimal::Decimal;
use std::str::FromStr;

use crate::steel::server::execution::{self, Value};

use crate::indexer::{IndexCommand, IndexCommandData};
use tokio::sync::mpsc;
use tracing::error;
@@ -153,9 +154,7 @@ pub async fn post_table_data(
            format!("Script execution failed for '{}': {}", target_column, e)
        ))?;

-       let Value::Strings(mut script_output) = script_result else {
-           return Err(Status::internal("Script must return string values"));
-       };
+       let Value::Strings(mut script_output) = script_result;

        let expected_value = script_output.pop()
            .ok_or_else(|| Status::internal("Script returned no values"))?;
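The plain `let` above compiles only because the pattern is irrefutable. A tiny illustration with a stand-in single-variant enum (not the crate's actual `Value`):

// Stand-in for a script-result enum reduced to a single variant.
enum ScriptValue {
    Strings(Vec<String>),
}

fn main() {
    let result = ScriptValue::Strings(vec!["a".into(), "b".into()]);

    // With exactly one variant the pattern cannot fail, so no `let ... else` branch is needed.
    let ScriptValue::Strings(mut output) = result;

    // Mirrors the handler's `.pop().ok_or_else(...)` step for taking the last value.
    let last = output.pop().expect("script returned no values");
    println!("{last}");
}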
@@ -216,9 +216,8 @@ pub async fn put_table_data(
            Status::invalid_argument(format!("Script execution failed for '{}': {}", target_column, e))
        })?;

-       let Value::Strings(mut script_output_vec) = script_result else {
-           return Err(Status::internal("Script must return string values"));
-       };
+       let Value::Strings(mut script_output_vec) = script_result;
+
        let script_output = script_output_vec.pop().ok_or_else(|| Status::internal("Script returned no values"))?;

        if update_data.contains_key(&target_column) {