diff --git a/server/src/steel/server/execution.rs b/server/src/steel/server/execution.rs index 34ec379..a021bda 100644 --- a/server/src/steel/server/execution.rs +++ b/server/src/steel/server/execution.rs @@ -3,7 +3,7 @@ use steel::steel_vm::engine::Engine; use steel::steel_vm::register_fn::RegisterFn; use steel::rvals::SteelVal; -use super::functions::{SteelContext, convert_row_data_for_steel}; +use super::functions::SteelContext; use steel_decimal::registry::FunctionRegistry; use sqlx::PgPool; use std::sync::Arc; @@ -62,6 +62,48 @@ fn auto_promote_with_index( .into_owned() } +use common::proto::komp_ac::table_definition::ColumnDefinition; + +// Converts row data boolean values to Steel script format during context initialization. +pub async fn convert_row_data_for_steel( + db_pool: &PgPool, + schema_id: i64, + table_name: &str, + row_data: &mut HashMap, +) -> Result<(), sqlx::Error> { + let table_def = sqlx::query!( + r#" + SELECT columns FROM table_definitions + WHERE schema_id = $1 AND table_name = $2 + "#, + schema_id, + table_name + ) + .fetch_optional(db_pool) + .await? + .ok_or_else(|| sqlx::Error::RowNotFound)?; + + // Parse column definitions to identify boolean columns + if let Ok(columns) = serde_json::from_value::>(table_def.columns) { + for col_def in columns { + let normalized_type = + col_def.field_type.to_uppercase().split('(').next().unwrap().to_string(); + + if normalized_type == "BOOLEAN" || normalized_type == "BOOL" { + if let Some(value) = row_data.get_mut(&col_def.name) { + *value = match value.to_lowercase().as_str() { + "true" | "t" | "1" | "yes" | "on" => "#true".to_string(), + "false" | "f" | "0" | "no" | "off" => "#false".to_string(), + _ => value.clone(), + }; + } + } + } + } + + Ok(()) +} + pub async fn create_steel_context_with_boolean_conversion( current_table: String, schema_id: i64, diff --git a/server/src/steel/server/functions.rs b/server/src/steel/server/functions.rs index 01367b3..54d7776 100644 --- a/server/src/steel/server/functions.rs +++ b/server/src/steel/server/functions.rs @@ -22,10 +22,8 @@ pub enum FunctionError { ProhibitedTypeAccess(String), } -/// Data types that Steel scripts are prohibited from accessing for security reasons const PROHIBITED_TYPES: &[&str] = &["BIGINT", "DATE", "TIMESTAMPTZ"]; -/// Execution context for Steel scripts with database access capabilities. #[derive(Clone)] pub struct SteelContext { pub current_table: String, @@ -36,26 +34,11 @@ pub struct SteelContext { } impl SteelContext { - /// Resolves a base table name to its full qualified name in the current schema. - /// Used for foreign key relationship traversal in Steel scripts. - pub async fn get_related_table_name(&self, base_name: &str) -> Result { - let table_def = sqlx::query!( - r#"SELECT table_name FROM table_definitions - WHERE schema_id = $1 AND table_name LIKE $2"#, - self.schema_id, - format!("%_{}", base_name) - ) - .fetch_optional(&*self.db_pool) - .await - .map_err(|e| FunctionError::DatabaseError(e.to_string()))? - .ok_or_else(|| FunctionError::TableNotFound(base_name.to_string()))?; - - Ok(table_def.table_name) - } - - /// Retrieves the SQL data type for a specific column in a table. - /// Parses the JSON column definitions to find type information. - async fn get_column_type(&self, table_name: &str, column_name: &str) -> Result { + async fn get_column_type( + &self, + table_name: &str, + column_name: &str, + ) -> Result { let table_def = sqlx::query!( r#"SELECT columns FROM table_definitions WHERE schema_id = $1 AND table_name = $2"#, @@ -68,7 +51,10 @@ impl SteelContext { .ok_or_else(|| FunctionError::TableNotFound(table_name.to_string()))?; let columns: Vec = serde_json::from_value(table_def.columns) - .map_err(|e| FunctionError::DatabaseError(format!("Invalid column data: {}", e)))?; + .map_err(|e| FunctionError::DatabaseError(format!( + "Invalid column data: {}", + e + )))?; for col_def in columns { if col_def.name == column_name { @@ -78,33 +64,29 @@ impl SteelContext { Err(FunctionError::ColumnNotFound(format!( "Column '{}' not found in table '{}'", - column_name, - table_name + column_name, table_name ))) } - /// Converts database values to Steel script format based on column type. - /// Currently handles boolean conversion to Steel's #true/#false syntax. fn convert_value_to_steel_format(&self, value: &str, column_type: &str) -> String { let normalized_type = normalize_data_type(column_type); match normalized_type.as_str() { - "BOOLEAN" | "BOOL" => { - // Convert database boolean representations to Steel boolean syntax - match value.to_lowercase().as_str() { - "true" | "t" | "1" | "yes" | "on" => "#true".to_string(), - "false" | "f" | "0" | "no" | "off" => "#false".to_string(), - _ => value.to_string(), // Return as-is if not a recognized boolean - } - } + "BOOLEAN" | "BOOL" => match value.to_lowercase().as_str() { + "true" | "t" | "1" | "yes" | "on" => "#true".to_string(), + "false" | "f" | "0" | "no" | "off" => "#false".to_string(), + _ => value.to_string(), + }, "INTEGER" => value.to_string(), - _ => value.to_string(), // Return as-is for other types + _ => value.to_string(), } } - /// Validates that a column type is allowed for Steel script access. - /// Returns the column type if validation passes, error if prohibited. - async fn validate_column_type_and_get_type(&self, table_name: &str, column_name: &str) -> Result { + async fn validate_column_type_and_get_type( + &self, + table_name: &str, + column_name: &str, + ) -> Result { let column_type = self.get_column_type(table_name, column_name).await?; if is_prohibited_type(&column_type) { @@ -120,15 +102,13 @@ impl SteelContext { Ok(column_type) } - /// Retrieves column value from current table or related tables via foreign keys. - /// - /// # Behavior - /// - Current table: Returns value directly from row_data with type conversion - /// - Related table: Follows foreign key relationship and queries database - /// - All accesses are subject to prohibited type validation - pub fn steel_get_column(&self, table: &str, column: &str) -> Result { + pub fn steel_get_column( + &self, + table: &str, + column: &str, + ) -> Result { if table == self.current_table { - // Access current table data with type validation + // current table let column_type = tokio::task::block_in_place(|| { let handle = tokio::runtime::Handle::current(); handle.block_on(async { @@ -141,71 +121,112 @@ impl SteelContext { Err(e) => return Err(SteelVal::StringV(e.to_string().into())), }; - return self.row_data.get(column) + return self + .row_data + .get(column) .map(|v| { - let converted_value = self.convert_value_to_steel_format(v, &column_type); - SteelVal::StringV(converted_value.into()) + let converted = + self.convert_value_to_steel_format(v, &column_type); + SteelVal::StringV(converted.into()) }) - .ok_or_else(|| SteelVal::StringV(format!("Column {} not found", column).into())); + .ok_or_else(|| { + SteelVal::StringV( + format!("Column {} not found", column).into(), + ) + }); } - // Access related table via foreign key relationship - // TODO REDO REMOVE YEAR PREFIX DEPRECATION - let base_name = table.split_once('_') - .map(|(_, rest)| rest) - .unwrap_or(table); - - let fk_column = format!("{}_id", base_name); - let fk_value = self.row_data.get(&fk_column) - .ok_or_else(|| SteelVal::StringV(format!("Foreign key {} not found", fk_column).into()))?; - + // Cross-table via FK: use exact table name FK convention: "_id" let result = tokio::task::block_in_place(|| { let handle = tokio::runtime::Handle::current(); handle.block_on(async { - let actual_table = self.get_related_table_name(base_name).await - .map_err(|e| SteelVal::StringV(e.to_string().into()))?; + let fk_column = format!("{}_id", table); + let fk_value = self + .row_data + .get(&fk_column) + .ok_or_else(|| { + FunctionError::ForeignKeyNotFound(format!( + "Foreign key column '{}' not found on '{}'", + fk_column, self.current_table + )) + })?; - // Validate column type and get type information - let column_type = self.validate_column_type_and_get_type(&actual_table, column).await - .map_err(|e| SteelVal::StringV(e.to_string().into()))?; + let column_type = + self.validate_column_type_and_get_type(table, column) + .await?; - // Query the related table for the column value - let raw_value = sqlx::query_scalar::<_, String>( - &format!("SELECT {} FROM \"{}\".\"{}\" WHERE id = $1", column, self.schema_name, actual_table) + let raw_value = sqlx::query_scalar::<_, String>(&format!( + "SELECT \"{}\" FROM \"{}\".\"{}\" WHERE id = $1", + column, self.schema_name, table + )) + .bind( + fk_value + .parse::() + .map_err(|_| { + FunctionError::DatabaseError( + "Invalid foreign key format".into(), + ) + })?, ) - .bind(fk_value.parse::().map_err(|_| - SteelVal::StringV("Invalid foreign key format".into()))?) .fetch_one(&*self.db_pool) .await - .map_err(|e| SteelVal::StringV(e.to_string().into()))?; + .map_err(|e| FunctionError::DatabaseError(e.to_string()))?; - // Convert to appropriate Steel format - let converted_value = self.convert_value_to_steel_format(&raw_value, &column_type); - Ok(converted_value) + let converted = + self.convert_value_to_steel_format(&raw_value, &column_type); + Ok::(converted) }) }); - result.map(|v| SteelVal::StringV(v.into())) + match result { + Ok(v) => Ok(SteelVal::StringV(v.into())), + Err(e) => Err(SteelVal::StringV(e.to_string().into())), + } } - /// Retrieves a specific indexed element from a comma-separated column value. - /// Useful for accessing elements from array-like string representations. pub fn steel_get_column_with_index( &self, table: &str, index: i64, - column: &str + column: &str, ) -> Result { - // Get the full value with proper type conversion - let value = self.steel_get_column(table, column)?; + // Cross-table: interpret 'index' as the row id to fetch directly + if table != self.current_table { + let result = tokio::task::block_in_place(|| { + let handle = tokio::runtime::Handle::current(); + handle.block_on(async { + let column_type = + self.validate_column_type_and_get_type(table, column) + .await?; + let raw_value = sqlx::query_scalar::<_, String>(&format!( + "SELECT \"{}\" FROM \"{}\".\"{}\" WHERE id = $1", + column, self.schema_name, table + )) + .bind(index) + .fetch_one(&*self.db_pool) + .await + .map_err(|e| FunctionError::DatabaseError(e.to_string()))?; + + let converted = self + .convert_value_to_steel_format(&raw_value, &column_type); + Ok::(converted) + }) + }); + + return match result { + Ok(v) => Ok(SteelVal::StringV(v.into())), + Err(e) => Err(SteelVal::StringV(e.to_string().into())), + }; + } + + // Current table: existing behavior (index in comma-separated string) + let value = self.steel_get_column(table, column)?; if let SteelVal::StringV(s) = value { let parts: Vec<_> = s.split(',').collect(); - if let Some(part) = parts.get(index as usize) { - let trimmed_part = part.trim(); + let trimmed = part.trim(); - // Apply type conversion to the indexed part based on original column type let column_type = tokio::task::block_in_place(|| { let handle = tokio::runtime::Handle::current(); handle.block_on(async { @@ -215,40 +236,35 @@ impl SteelContext { match column_type { Ok(ct) => { - let converted_part = self.convert_value_to_steel_format(trimmed_part, &ct); - Ok(SteelVal::StringV(converted_part.into())) - } - Err(_) => { - // If type cannot be determined, return value as-is - Ok(SteelVal::StringV(trimmed_part.into())) + let converted = + self.convert_value_to_steel_format(trimmed, &ct); + Ok(SteelVal::StringV(converted.into())) } + Err(_) => Ok(SteelVal::StringV(trimmed.into())), } } else { Err(SteelVal::StringV("Index out of bounds".into())) } } else { - Err(SteelVal::StringV("Expected comma-separated string".into())) + Err(SteelVal::StringV( + "Expected comma-separated string".into(), + )) } } - /// Executes read-only SQL queries from Steel scripts with safety restrictions. - /// - /// # Security Features - /// - Only SELECT, SHOW, and EXPLAIN queries allowed - /// - Prohibited column type access validation - /// - Returns first column of all rows as comma-separated string pub fn steel_query_sql(&self, query: &str) -> Result { if !is_read_only_query(query) { - return Err(SteelVal::StringV( - "Only SELECT queries are allowed".into() - )); + return Err(SteelVal::StringV("Only SELECT queries are allowed".into())); } if contains_prohibited_column_access(query) { - return Err(SteelVal::StringV(format!( - "SQL query may access prohibited column types. Steel scripts cannot access columns of type: {}", - PROHIBITED_TYPES.join(", ") - ).into())); + return Err(SteelVal::StringV( + format!( + "SQL query may access prohibited column types. Steel scripts cannot access columns of type: {}", + PROHIBITED_TYPES.join(", ") + ) + .into(), + )); } let pool = self.db_pool.clone(); @@ -263,7 +279,8 @@ impl SteelContext { let mut results = Vec::new(); for row in rows { - let val: String = row.try_get(0) + let val: String = row + .try_get(0) .map_err(|e| SteelVal::StringV(e.to_string().into()))?; results.push(val); } @@ -276,80 +293,30 @@ impl SteelContext { } } -/// Checks if a data type is prohibited for Steel script access. fn is_prohibited_type(data_type: &str) -> bool { let normalized_type = normalize_data_type(data_type); - PROHIBITED_TYPES.iter().any(|&prohibited| normalized_type.starts_with(prohibited)) + PROHIBITED_TYPES + .iter() + .any(|&prohibited| normalized_type.starts_with(prohibited)) } -/// Normalizes data type strings for consistent comparison. -/// Handles variations like NUMERIC(10,2) by extracting base type. fn normalize_data_type(data_type: &str) -> String { - data_type.to_uppercase() - .split('(') // Remove precision/scale from NUMERIC(x,y) + data_type + .to_uppercase() + .split('(') .next() .unwrap_or(data_type) .trim() .to_string() } -/// Performs basic heuristic check for prohibited column type access in SQL queries. -/// Looks for common patterns that might indicate access to restricted types. fn contains_prohibited_column_access(query: &str) -> bool { let query_upper = query.to_uppercase(); - - let patterns = [ - "EXTRACT(", // Common with DATE/TIMESTAMPTZ - "DATE_PART(", // Common with DATE/TIMESTAMPTZ - "::DATE", - "::TIMESTAMPTZ", - "::BIGINT", - ]; - - patterns.iter().any(|pattern| query_upper.contains(pattern)) + let patterns = ["EXTRACT(", "DATE_PART(", "::DATE", "::TIMESTAMPTZ", "::BIGINT"]; + patterns.iter().any(|p| query_upper.contains(p)) } -/// Validates that a query is read-only and safe for Steel script execution. fn is_read_only_query(query: &str) -> bool { let query = query.trim_start().to_uppercase(); - query.starts_with("SELECT") || - query.starts_with("SHOW") || - query.starts_with("EXPLAIN") -} - -/// Converts row data boolean values to Steel script format during context initialization. -pub async fn convert_row_data_for_steel( - db_pool: &PgPool, - schema_id: i64, - table_name: &str, - row_data: &mut HashMap, -) -> Result<(), sqlx::Error> { - let table_def = sqlx::query!( - r#"SELECT columns FROM table_definitions - WHERE schema_id = $1 AND table_name = $2"#, - schema_id, - table_name - ) - .fetch_optional(db_pool) - .await? - .ok_or_else(|| sqlx::Error::RowNotFound)?; - - // Parse column definitions to identify boolean columns for conversion - if let Ok(columns) = serde_json::from_value::>(table_def.columns) { - for col_def in columns { - let normalized_type = normalize_data_type(&col_def.field_type); - - if normalized_type == "BOOLEAN" || normalized_type == "BOOL" { - if let Some(value) = row_data.get_mut(&col_def.name) { - *value = match value.to_lowercase().as_str() { - "true" | "t" | "1" | "yes" | "on" => "#true".to_string(), - "false" | "f" | "0" | "no" | "off" => "#false".to_string(), - _ => value.clone(), - }; - } - } - } - } - - Ok(()) + query.starts_with("SELECT") || query.starts_with("SHOW") || query.starts_with("EXPLAIN") }