diff --git a/server/src/steel/server/decimal_math.rs b/server/src/steel/server/decimal_math.rs deleted file mode 100644 index 83dba2b..0000000 --- a/server/src/steel/server/decimal_math.rs +++ /dev/null @@ -1,190 +0,0 @@ -// src/steel/server/decimal_math.rs -use rust_decimal::prelude::*; -use rust_decimal::MathematicalOps; -use steel::rvals::SteelVal; -use std::str::FromStr; -use thiserror::Error; - -#[derive(Debug, Error)] -pub enum DecimalMathError { - #[error("Invalid decimal format: {0}")] - InvalidDecimal(String), - #[error("Math operation failed: {0}")] - MathError(String), - #[error("Division by zero")] - DivisionByZero, -} - -/// Converts a SteelVal to a Decimal -fn steel_val_to_decimal(val: &SteelVal) -> Result { - match val { - SteelVal::StringV(s) => { - Decimal::from_str(&s.to_string()) - .map_err(|e| DecimalMathError::InvalidDecimal(format!("{}: {}", s, e))) - } - SteelVal::NumV(n) => { - Decimal::try_from(*n) - .map_err(|e| DecimalMathError::InvalidDecimal(format!("{}: {}", n, e))) - } - SteelVal::IntV(i) => { - Ok(Decimal::from(*i)) - } - _ => Err(DecimalMathError::InvalidDecimal(format!("Unsupported type: {:?}", val))) - } -} - -/// Converts a Decimal back to a SteelVal string -fn decimal_to_steel_val(decimal: Decimal) -> SteelVal { - SteelVal::StringV(decimal.to_string().into()) -} - -// Basic arithmetic operations -pub fn decimal_add(a: String, b: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; - Ok((a_dec + b_dec).to_string()) -} - -pub fn decimal_sub(a: String, b: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; - Ok((a_dec - b_dec).to_string()) -} - -pub fn decimal_mul(a: String, b: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; - Ok((a_dec * b_dec).to_string()) -} - -pub fn decimal_div(a: String, b: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; - - if b_dec.is_zero() { - return Err("Division by zero".to_string()); - } - - Ok((a_dec / b_dec).to_string()) -} - -// Advanced mathematical functions (requires maths feature) -pub fn decimal_pow(base: String, exp: String) -> Result { - let base_dec = Decimal::from_str(&base).map_err(|e| format!("Invalid decimal '{}': {}", base, e))?; - let exp_dec = Decimal::from_str(&exp).map_err(|e| format!("Invalid decimal '{}': {}", exp, e))?; - - base_dec.checked_powd(exp_dec) - .map(|result| result.to_string()) - .ok_or_else(|| "Power operation failed or overflowed".to_string()) -} - -pub fn decimal_sqrt(a: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - - a_dec.sqrt() - .map(|result| result.to_string()) - .ok_or_else(|| "Square root failed (negative number?)".to_string()) -} - -pub fn decimal_ln(a: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - - a_dec.checked_ln() - .map(|result| result.to_string()) - .ok_or_else(|| "Natural log failed (non-positive number?)".to_string()) -} - -pub fn decimal_log10(a: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - - a_dec.checked_log10() - .map(|result| result.to_string()) - .ok_or_else(|| "Log10 failed (non-positive number?)".to_string()) -} - -pub fn decimal_exp(a: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - - a_dec.checked_exp() - .map(|result| result.to_string()) - .ok_or_else(|| "Exponential failed or overflowed".to_string()) -} - -// Trigonometric functions (input in radians) -pub fn decimal_sin(a: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - - a_dec.checked_sin() - .map(|result| result.to_string()) - .ok_or_else(|| "Sine calculation failed or overflowed".to_string()) -} - -pub fn decimal_cos(a: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - - a_dec.checked_cos() - .map(|result| result.to_string()) - .ok_or_else(|| "Cosine calculation failed or overflowed".to_string()) -} - -pub fn decimal_tan(a: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - - a_dec.checked_tan() - .map(|result| result.to_string()) - .ok_or_else(|| "Tangent calculation failed or overflowed".to_string()) -} - -// Comparison functions -pub fn decimal_gt(a: String, b: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; - Ok(a_dec > b_dec) -} - -pub fn decimal_lt(a: String, b: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; - Ok(a_dec < b_dec) -} - -pub fn decimal_eq(a: String, b: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; - Ok(a_dec == b_dec) -} - -// Utility functions -pub fn decimal_abs(a: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - Ok(a_dec.abs().to_string()) -} - -pub fn decimal_round(a: String, places: i32) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - Ok(a_dec.round_dp(places as u32).to_string()) -} - -pub fn decimal_min(a: String, b: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; - Ok(a_dec.min(b_dec).to_string()) -} - -pub fn decimal_max(a: String, b: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; - Ok(a_dec.max(b_dec).to_string()) -} - -pub fn decimal_gte(a: String, b: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; - Ok(a_dec >= b_dec) -} - -pub fn decimal_lte(a: String, b: String) -> Result { - let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; - let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; - Ok(a_dec <= b_dec) -} diff --git a/server/src/steel/server/execution.rs b/server/src/steel/server/execution.rs index f3ad50b..c2c1268 100644 --- a/server/src/steel/server/execution.rs +++ b/server/src/steel/server/execution.rs @@ -3,7 +3,7 @@ use steel::steel_vm::engine::Engine; use steel::steel_vm::register_fn::RegisterFn; use steel::rvals::SteelVal; use super::functions::SteelContext; -use super::decimal_math::*; +use steel_decimal::registry::FunctionRegistry; use sqlx::PgPool; use std::sync::Arc; use thiserror::Error; @@ -28,7 +28,7 @@ pub enum ExecutionError { pub fn execute_script( script: String, target_type: &str, - _db_pool: Arc, // Passed to the SteelContext + db_pool: Arc, context: SteelContext, ) -> Result { let mut vm = Engine::new(); @@ -36,8 +36,8 @@ pub fn execute_script( // Register existing Steel functions register_steel_functions(&mut vm, context.clone()); - - // Register all decimal math functions + + // Register all decimal math functions using the steel_decimal crate register_decimal_math_functions(&mut vm); // Execute script and process results @@ -81,83 +81,8 @@ fn register_steel_functions(vm: &mut Engine, context: Arc) { } fn register_decimal_math_functions(vm: &mut Engine) { - // Basic arithmetic operations - vm.register_fn("decimal-add", decimal_add); - vm.register_fn("decimal-sub", decimal_sub); - vm.register_fn("decimal-mul", decimal_mul); - vm.register_fn("decimal-div", decimal_div); - - // Advanced mathematical functions - vm.register_fn("decimal-pow", decimal_pow); - vm.register_fn("decimal-sqrt", decimal_sqrt); - vm.register_fn("decimal-ln", decimal_ln); - vm.register_fn("decimal-log10", decimal_log10); - vm.register_fn("decimal-exp", decimal_exp); - - // Trigonometric functions - vm.register_fn("decimal-sin", decimal_sin); - vm.register_fn("decimal-cos", decimal_cos); - vm.register_fn("decimal-tan", decimal_tan); - - // Comparison functions - vm.register_fn("decimal-gt", decimal_gt); - vm.register_fn("decimal-lt", decimal_lt); - vm.register_fn("decimal-eq", decimal_eq); - - // Utility functions - vm.register_fn("decimal-abs", decimal_abs); - vm.register_fn("decimal-round", decimal_round); - vm.register_fn("decimal-min", decimal_min); - vm.register_fn("decimal-max", decimal_max); - - // Additional convenience functions - vm.register_fn("decimal-zero", || "0".to_string()); - vm.register_fn("decimal-one", || "1".to_string()); - vm.register_fn("decimal-pi", || "3.1415926535897932384626433833".to_string()); - vm.register_fn("decimal-e", || "2.7182818284590452353602874714".to_string()); - - // Type conversion helpers - vm.register_fn("to-decimal", |s: String| -> Result { - use rust_decimal::prelude::*; - use std::str::FromStr; - - Decimal::from_str(&s) - .map(|d| d.to_string()) - .map_err(|e| format!("Invalid decimal: {}", e)) - }); - - // Financial functions - vm.register_fn("decimal-percentage", |amount: String, percentage: String| -> Result { - use rust_decimal::prelude::*; - use std::str::FromStr; - - let amount_dec = Decimal::from_str(&amount) - .map_err(|e| format!("Invalid amount: {}", e))?; - let percentage_dec = Decimal::from_str(&percentage) - .map_err(|e| format!("Invalid percentage: {}", e))?; - let hundred = Decimal::from(100); - - Ok((amount_dec * percentage_dec / hundred).to_string()) - }); - - vm.register_fn("decimal-compound", |principal: String, rate: String, time: String| -> Result { - use rust_decimal::prelude::*; - use rust_decimal::MathematicalOps; - use std::str::FromStr; - - let principal_dec = Decimal::from_str(&principal) - .map_err(|e| format!("Invalid principal: {}", e))?; - let rate_dec = Decimal::from_str(&rate) - .map_err(|e| format!("Invalid rate: {}", e))?; - let time_dec = Decimal::from_str(&time) - .map_err(|e| format!("Invalid time: {}", e))?; - - let one = Decimal::ONE; - let compound_factor = (one + rate_dec).checked_powd(time_dec) - .ok_or("Compound calculation overflow")?; - - Ok((principal_dec * compound_factor).to_string()) - }); + // Use the steel_decimal crate's FunctionRegistry to register all functions + FunctionRegistry::register_all(vm); } fn process_string_results(results: Vec) -> Result { diff --git a/server/src/steel/server/mod.rs b/server/src/steel/server/mod.rs index 2c689bc..7b80092 100644 --- a/server/src/steel/server/mod.rs +++ b/server/src/steel/server/mod.rs @@ -1,10 +1,6 @@ // src/steel/server/mod.rs pub mod execution; -pub mod syntax_parser; pub mod functions; -pub mod decimal_math; pub use execution::*; -pub use syntax_parser::*; pub use functions::*; -pub use decimal_math::*; diff --git a/server/src/steel/server/syntax_parser.rs b/server/src/steel/server/syntax_parser.rs deleted file mode 100644 index 2dfba9a..0000000 --- a/server/src/steel/server/syntax_parser.rs +++ /dev/null @@ -1,154 +0,0 @@ -use regex::Regex; -use std::collections::HashSet; - -pub struct SyntaxParser { - // Existing patterns for column/SQL integration - current_table_column_re: Regex, - different_table_column_re: Regex, - one_to_many_indexed_re: Regex, - sql_integration_re: Regex, - - // Simple math operation replacement patterns - math_operators: Vec<(Regex, &'static str)>, - number_literal_re: Regex, -} - -impl SyntaxParser { - pub fn new() -> Self { - // Define math operator replacements - let math_operators = vec![ - // Basic arithmetic - (Regex::new(r"\(\s*\+\s+").unwrap(), "(decimal-add "), - (Regex::new(r"\(\s*-\s+").unwrap(), "(decimal-sub "), - (Regex::new(r"\(\s*\*\s+").unwrap(), "(decimal-mul "), - (Regex::new(r"\(\s*/\s+").unwrap(), "(decimal-div "), - - // Power and advanced operations - (Regex::new(r"\(\s*\^\s+").unwrap(), "(decimal-pow "), - (Regex::new(r"\(\s*\*\*\s+").unwrap(), "(decimal-pow "), - (Regex::new(r"\(\s*pow\s+").unwrap(), "(decimal-pow "), - (Regex::new(r"\(\s*sqrt\s+").unwrap(), "(decimal-sqrt "), - - // Logarithmic functions - (Regex::new(r"\(\s*ln\s+").unwrap(), "(decimal-ln "), - (Regex::new(r"\(\s*log\s+").unwrap(), "(decimal-ln "), - (Regex::new(r"\(\s*log10\s+").unwrap(), "(decimal-log10 "), - (Regex::new(r"\(\s*exp\s+").unwrap(), "(decimal-exp "), - - // Trigonometric functions - (Regex::new(r"\(\s*sin\s+").unwrap(), "(decimal-sin "), - (Regex::new(r"\(\s*cos\s+").unwrap(), "(decimal-cos "), - (Regex::new(r"\(\s*tan\s+").unwrap(), "(decimal-tan "), - - // Comparison operators - (Regex::new(r"\(\s*>\s+").unwrap(), "(decimal-gt "), - (Regex::new(r"\(\s*<\s+").unwrap(), "(decimal-lt "), - (Regex::new(r"\(\s*=\s+").unwrap(), "(decimal-eq "), - (Regex::new(r"\(\s*>=\s+").unwrap(), "(decimal-gte "), - (Regex::new(r"\(\s*<=\s+").unwrap(), "(decimal-lte "), - - // Utility functions - (Regex::new(r"\(\s*abs\s+").unwrap(), "(decimal-abs "), - (Regex::new(r"\(\s*min\s+").unwrap(), "(decimal-min "), - (Regex::new(r"\(\s*max\s+").unwrap(), "(decimal-max "), - (Regex::new(r"\(\s*round\s+").unwrap(), "(decimal-round "), - ]; - - SyntaxParser { - current_table_column_re: Regex::new(r"@(\w+)").unwrap(), - different_table_column_re: Regex::new(r"@(\w+)\.(\w+)").unwrap(), - one_to_many_indexed_re: Regex::new(r"@(\w+)\[(\d+)\]\.(\w+)").unwrap(), - sql_integration_re: Regex::new(r#"@sql\((['"])(.*?)['"]\)"#).unwrap(), - - // FIXED: Match negative numbers and avoid already quoted strings - number_literal_re: Regex::new(r#"(? String { - let mut transformed = script.to_string(); - - // Step 1: Convert all numeric literals to strings (FIXED to handle negative numbers) - transformed = self.convert_numbers_to_strings(&transformed); - - // Step 2: Replace math function calls with decimal equivalents (SIMPLIFIED) - transformed = self.replace_math_functions(&transformed); - - // Step 3: Handle existing column and SQL integrations (unchanged) - transformed = self.process_column_integrations(&transformed, current_table); - - transformed - } - - /// Convert all unquoted numeric literals to quoted strings - fn convert_numbers_to_strings(&self, script: &str) -> String { - // This regex matches numbers that are NOT already inside quotes - self.number_literal_re.replace_all(script, |caps: ®ex::Captures| { - format!("\"{}\"", &caps[1]) - }).to_string() - } - - /// Replace math function calls with decimal equivalents (SIMPLIFIED) - fn replace_math_functions(&self, script: &str) -> String { - let mut result = script.to_string(); - - // Apply all math operator replacements - for (pattern, replacement) in &self.math_operators { - result = pattern.replace_all(&result, *replacement).to_string(); - } - - result - } - - /// Process existing column and SQL integrations (unchanged logic) - fn process_column_integrations(&self, script: &str, current_table: &str) -> String { - let mut transformed = script.to_string(); - - // Process indexed access first to avoid overlap with relationship matches - transformed = self.one_to_many_indexed_re.replace_all(&transformed, |caps: ®ex::Captures| { - format!("(steel_get_column_with_index \"{}\" {} \"{}\")", - &caps[1], &caps[2], &caps[3]) - }).to_string(); - - // Process relationships - transformed = self.different_table_column_re.replace_all(&transformed, |caps: ®ex::Captures| { - format!("(steel_get_column \"{}\" \"{}\")", &caps[1], &caps[2]) - }).to_string(); - - // Process basic column access - transformed = self.current_table_column_re.replace_all(&transformed, |caps: ®ex::Captures| { - format!("(steel_get_column \"{}\" \"{}\")", current_table, &caps[1]) - }).to_string(); - - // Process SQL integration - transformed = self.sql_integration_re.replace_all(&transformed, |caps: ®ex::Captures| { - format!("(steel_query_sql \"{}\")", &caps[2]) - }).to_string(); - - transformed - } - - pub fn extract_dependencies(&self, script: &str, current_table: &str) -> (HashSet, HashSet) { - let mut tables = HashSet::new(); - let mut columns = HashSet::new(); - - for cap in self.current_table_column_re.captures_iter(script) { - tables.insert(current_table.to_string()); - columns.insert(cap[1].to_string()); - } - - for cap in self.different_table_column_re.captures_iter(script) { - tables.insert(cap[1].to_string()); - columns.insert(cap[2].to_string()); - } - - for cap in self.one_to_many_indexed_re.captures_iter(script) { - tables.insert(cap[1].to_string()); - columns.insert(cap[3].to_string()); - } - - (tables, columns) - } -}