From d1ebe4732fb1b62acddf10fb6f532426bc6d94f3 Mon Sep 17 00:00:00 2001 From: filipriec Date: Wed, 2 Jul 2025 14:44:37 +0200 Subject: [PATCH] steel with decimal math, saving before separating steel to a separate crate --- server/Cargo.toml | 2 +- server/src/steel/server/decimal_math.rs | 190 +++++++++++++++++++++++ server/src/steel/server/execution.rs | 123 +++++++++++++-- server/src/steel/server/mod.rs | 2 + server/src/steel/server/syntax_parser.rs | 86 +++++++++- server/tests/mod.rs | 2 +- 6 files changed, 387 insertions(+), 18 deletions(-) create mode 100644 server/src/steel/server/decimal_math.rs diff --git a/server/Cargo.toml b/server/Cargo.toml index 927e402..64b15aa 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -33,7 +33,7 @@ validator = { version = "0.20.0", features = ["derive"] } uuid = { version = "1.16.0", features = ["serde", "v4"] } jsonwebtoken = "9.3.1" rust-stemmers = "1.2.0" -rust_decimal = "1.37.2" +rust_decimal = { version = "1.37.2", features = ["maths", "serde"] } rust_decimal_macros = "1.37.1" [lib] diff --git a/server/src/steel/server/decimal_math.rs b/server/src/steel/server/decimal_math.rs new file mode 100644 index 0000000..83dba2b --- /dev/null +++ b/server/src/steel/server/decimal_math.rs @@ -0,0 +1,190 @@ +// src/steel/server/decimal_math.rs +use rust_decimal::prelude::*; +use rust_decimal::MathematicalOps; +use steel::rvals::SteelVal; +use std::str::FromStr; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum DecimalMathError { + #[error("Invalid decimal format: {0}")] + InvalidDecimal(String), + #[error("Math operation failed: {0}")] + MathError(String), + #[error("Division by zero")] + DivisionByZero, +} + +/// Converts a SteelVal to a Decimal +fn steel_val_to_decimal(val: &SteelVal) -> Result { + match val { + SteelVal::StringV(s) => { + Decimal::from_str(&s.to_string()) + .map_err(|e| DecimalMathError::InvalidDecimal(format!("{}: {}", s, e))) + } + SteelVal::NumV(n) => { + Decimal::try_from(*n) + .map_err(|e| DecimalMathError::InvalidDecimal(format!("{}: {}", n, e))) + } + SteelVal::IntV(i) => { + Ok(Decimal::from(*i)) + } + _ => Err(DecimalMathError::InvalidDecimal(format!("Unsupported type: {:?}", val))) + } +} + +/// Converts a Decimal back to a SteelVal string +fn decimal_to_steel_val(decimal: Decimal) -> SteelVal { + SteelVal::StringV(decimal.to_string().into()) +} + +// Basic arithmetic operations +pub fn decimal_add(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok((a_dec + b_dec).to_string()) +} + +pub fn decimal_sub(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok((a_dec - b_dec).to_string()) +} + +pub fn decimal_mul(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok((a_dec * b_dec).to_string()) +} + +pub fn decimal_div(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + + if b_dec.is_zero() { + return Err("Division by zero".to_string()); + } + + Ok((a_dec / b_dec).to_string()) +} + +// Advanced mathematical functions (requires maths feature) +pub fn decimal_pow(base: String, exp: String) -> Result { + let base_dec = Decimal::from_str(&base).map_err(|e| format!("Invalid decimal '{}': {}", base, e))?; + let exp_dec = Decimal::from_str(&exp).map_err(|e| format!("Invalid decimal '{}': {}", exp, e))?; + + base_dec.checked_powd(exp_dec) + .map(|result| result.to_string()) + .ok_or_else(|| "Power operation failed or overflowed".to_string()) +} + +pub fn decimal_sqrt(a: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + + a_dec.sqrt() + .map(|result| result.to_string()) + .ok_or_else(|| "Square root failed (negative number?)".to_string()) +} + +pub fn decimal_ln(a: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + + a_dec.checked_ln() + .map(|result| result.to_string()) + .ok_or_else(|| "Natural log failed (non-positive number?)".to_string()) +} + +pub fn decimal_log10(a: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + + a_dec.checked_log10() + .map(|result| result.to_string()) + .ok_or_else(|| "Log10 failed (non-positive number?)".to_string()) +} + +pub fn decimal_exp(a: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + + a_dec.checked_exp() + .map(|result| result.to_string()) + .ok_or_else(|| "Exponential failed or overflowed".to_string()) +} + +// Trigonometric functions (input in radians) +pub fn decimal_sin(a: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + + a_dec.checked_sin() + .map(|result| result.to_string()) + .ok_or_else(|| "Sine calculation failed or overflowed".to_string()) +} + +pub fn decimal_cos(a: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + + a_dec.checked_cos() + .map(|result| result.to_string()) + .ok_or_else(|| "Cosine calculation failed or overflowed".to_string()) +} + +pub fn decimal_tan(a: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + + a_dec.checked_tan() + .map(|result| result.to_string()) + .ok_or_else(|| "Tangent calculation failed or overflowed".to_string()) +} + +// Comparison functions +pub fn decimal_gt(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok(a_dec > b_dec) +} + +pub fn decimal_lt(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok(a_dec < b_dec) +} + +pub fn decimal_eq(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok(a_dec == b_dec) +} + +// Utility functions +pub fn decimal_abs(a: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + Ok(a_dec.abs().to_string()) +} + +pub fn decimal_round(a: String, places: i32) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + Ok(a_dec.round_dp(places as u32).to_string()) +} + +pub fn decimal_min(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok(a_dec.min(b_dec).to_string()) +} + +pub fn decimal_max(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok(a_dec.max(b_dec).to_string()) +} + +pub fn decimal_gte(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok(a_dec >= b_dec) +} + +pub fn decimal_lte(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok(a_dec <= b_dec) +} diff --git a/server/src/steel/server/execution.rs b/server/src/steel/server/execution.rs index 6bf9df3..f3ad50b 100644 --- a/server/src/steel/server/execution.rs +++ b/server/src/steel/server/execution.rs @@ -1,8 +1,9 @@ -// src/steel/server/execution.rs +// Updated src/steel/server/execution.rs use steel::steel_vm::engine::Engine; use steel::steel_vm::register_fn::RegisterFn; use steel::rvals::SteelVal; use super::functions::SteelContext; +use super::decimal_math::*; use sqlx::PgPool; use std::sync::Arc; use thiserror::Error; @@ -33,6 +34,24 @@ pub fn execute_script( let mut vm = Engine::new(); let context = Arc::new(context); + // Register existing Steel functions + register_steel_functions(&mut vm, context.clone()); + + // Register all decimal math functions + register_decimal_math_functions(&mut vm); + + // Execute script and process results + let results = vm.compile_and_run_raw_program(script) + .map_err(|e| ExecutionError::RuntimeError(e.to_string()))?; + + // Convert results to target type + match target_type { + "STRINGS" => process_string_results(results), + _ => Err(ExecutionError::UnsupportedType(target_type.into())) + } +} + +fn register_steel_functions(vm: &mut Engine, context: Arc) { // Register steel_get_column with row context vm.register_fn("steel_get_column", { let ctx = context.clone(); @@ -59,27 +78,101 @@ pub fn execute_script( .map_err(|e| e.to_string()) } }); +} - // Execute script and process results - let results = vm.compile_and_run_raw_program(script) - .map_err(|e| ExecutionError::RuntimeError(e.to_string()))?; +fn register_decimal_math_functions(vm: &mut Engine) { + // Basic arithmetic operations + vm.register_fn("decimal-add", decimal_add); + vm.register_fn("decimal-sub", decimal_sub); + vm.register_fn("decimal-mul", decimal_mul); + vm.register_fn("decimal-div", decimal_div); - // Convert results to target type - match target_type { - "STRINGS" => process_string_results(results), - _ => Err(ExecutionError::UnsupportedType(target_type.into())) - } + // Advanced mathematical functions + vm.register_fn("decimal-pow", decimal_pow); + vm.register_fn("decimal-sqrt", decimal_sqrt); + vm.register_fn("decimal-ln", decimal_ln); + vm.register_fn("decimal-log10", decimal_log10); + vm.register_fn("decimal-exp", decimal_exp); + + // Trigonometric functions + vm.register_fn("decimal-sin", decimal_sin); + vm.register_fn("decimal-cos", decimal_cos); + vm.register_fn("decimal-tan", decimal_tan); + + // Comparison functions + vm.register_fn("decimal-gt", decimal_gt); + vm.register_fn("decimal-lt", decimal_lt); + vm.register_fn("decimal-eq", decimal_eq); + + // Utility functions + vm.register_fn("decimal-abs", decimal_abs); + vm.register_fn("decimal-round", decimal_round); + vm.register_fn("decimal-min", decimal_min); + vm.register_fn("decimal-max", decimal_max); + + // Additional convenience functions + vm.register_fn("decimal-zero", || "0".to_string()); + vm.register_fn("decimal-one", || "1".to_string()); + vm.register_fn("decimal-pi", || "3.1415926535897932384626433833".to_string()); + vm.register_fn("decimal-e", || "2.7182818284590452353602874714".to_string()); + + // Type conversion helpers + vm.register_fn("to-decimal", |s: String| -> Result { + use rust_decimal::prelude::*; + use std::str::FromStr; + + Decimal::from_str(&s) + .map(|d| d.to_string()) + .map_err(|e| format!("Invalid decimal: {}", e)) + }); + + // Financial functions + vm.register_fn("decimal-percentage", |amount: String, percentage: String| -> Result { + use rust_decimal::prelude::*; + use std::str::FromStr; + + let amount_dec = Decimal::from_str(&amount) + .map_err(|e| format!("Invalid amount: {}", e))?; + let percentage_dec = Decimal::from_str(&percentage) + .map_err(|e| format!("Invalid percentage: {}", e))?; + let hundred = Decimal::from(100); + + Ok((amount_dec * percentage_dec / hundred).to_string()) + }); + + vm.register_fn("decimal-compound", |principal: String, rate: String, time: String| -> Result { + use rust_decimal::prelude::*; + use rust_decimal::MathematicalOps; + use std::str::FromStr; + + let principal_dec = Decimal::from_str(&principal) + .map_err(|e| format!("Invalid principal: {}", e))?; + let rate_dec = Decimal::from_str(&rate) + .map_err(|e| format!("Invalid rate: {}", e))?; + let time_dec = Decimal::from_str(&time) + .map_err(|e| format!("Invalid time: {}", e))?; + + let one = Decimal::ONE; + let compound_factor = (one + rate_dec).checked_powd(time_dec) + .ok_or("Compound calculation overflow")?; + + Ok((principal_dec * compound_factor).to_string()) + }); } fn process_string_results(results: Vec) -> Result { let mut strings = Vec::new(); for result in results { - if let SteelVal::StringV(s) = result { - strings.push(s.to_string()); - } else { - return Err(ExecutionError::TypeConversionError( - format!("Expected string, got {:?}", result) - )); + match result { + SteelVal::StringV(s) => strings.push(s.to_string()), + SteelVal::NumV(n) => strings.push(n.to_string()), + SteelVal::IntV(i) => strings.push(i.to_string()), + SteelVal::BoolV(b) => strings.push(b.to_string()), + _ => { + return Err(ExecutionError::TypeConversionError( + format!("Expected string-convertible type, got {:?}", result) + )); + } } } Ok(Value::Strings(strings)) diff --git a/server/src/steel/server/mod.rs b/server/src/steel/server/mod.rs index 3356077..2c689bc 100644 --- a/server/src/steel/server/mod.rs +++ b/server/src/steel/server/mod.rs @@ -2,7 +2,9 @@ pub mod execution; pub mod syntax_parser; pub mod functions; +pub mod decimal_math; pub use execution::*; pub use syntax_parser::*; pub use functions::*; +pub use decimal_math::*; diff --git a/server/src/steel/server/syntax_parser.rs b/server/src/steel/server/syntax_parser.rs index 33b0c48..2dfba9a 100644 --- a/server/src/steel/server/syntax_parser.rs +++ b/server/src/steel/server/syntax_parser.rs @@ -1,27 +1,111 @@ -// src/steel/server/syntax_parser.rs use regex::Regex; use std::collections::HashSet; pub struct SyntaxParser { + // Existing patterns for column/SQL integration current_table_column_re: Regex, different_table_column_re: Regex, one_to_many_indexed_re: Regex, sql_integration_re: Regex, + + // Simple math operation replacement patterns + math_operators: Vec<(Regex, &'static str)>, + number_literal_re: Regex, } impl SyntaxParser { pub fn new() -> Self { + // Define math operator replacements + let math_operators = vec![ + // Basic arithmetic + (Regex::new(r"\(\s*\+\s+").unwrap(), "(decimal-add "), + (Regex::new(r"\(\s*-\s+").unwrap(), "(decimal-sub "), + (Regex::new(r"\(\s*\*\s+").unwrap(), "(decimal-mul "), + (Regex::new(r"\(\s*/\s+").unwrap(), "(decimal-div "), + + // Power and advanced operations + (Regex::new(r"\(\s*\^\s+").unwrap(), "(decimal-pow "), + (Regex::new(r"\(\s*\*\*\s+").unwrap(), "(decimal-pow "), + (Regex::new(r"\(\s*pow\s+").unwrap(), "(decimal-pow "), + (Regex::new(r"\(\s*sqrt\s+").unwrap(), "(decimal-sqrt "), + + // Logarithmic functions + (Regex::new(r"\(\s*ln\s+").unwrap(), "(decimal-ln "), + (Regex::new(r"\(\s*log\s+").unwrap(), "(decimal-ln "), + (Regex::new(r"\(\s*log10\s+").unwrap(), "(decimal-log10 "), + (Regex::new(r"\(\s*exp\s+").unwrap(), "(decimal-exp "), + + // Trigonometric functions + (Regex::new(r"\(\s*sin\s+").unwrap(), "(decimal-sin "), + (Regex::new(r"\(\s*cos\s+").unwrap(), "(decimal-cos "), + (Regex::new(r"\(\s*tan\s+").unwrap(), "(decimal-tan "), + + // Comparison operators + (Regex::new(r"\(\s*>\s+").unwrap(), "(decimal-gt "), + (Regex::new(r"\(\s*<\s+").unwrap(), "(decimal-lt "), + (Regex::new(r"\(\s*=\s+").unwrap(), "(decimal-eq "), + (Regex::new(r"\(\s*>=\s+").unwrap(), "(decimal-gte "), + (Regex::new(r"\(\s*<=\s+").unwrap(), "(decimal-lte "), + + // Utility functions + (Regex::new(r"\(\s*abs\s+").unwrap(), "(decimal-abs "), + (Regex::new(r"\(\s*min\s+").unwrap(), "(decimal-min "), + (Regex::new(r"\(\s*max\s+").unwrap(), "(decimal-max "), + (Regex::new(r"\(\s*round\s+").unwrap(), "(decimal-round "), + ]; + SyntaxParser { current_table_column_re: Regex::new(r"@(\w+)").unwrap(), different_table_column_re: Regex::new(r"@(\w+)\.(\w+)").unwrap(), one_to_many_indexed_re: Regex::new(r"@(\w+)\[(\d+)\]\.(\w+)").unwrap(), sql_integration_re: Regex::new(r#"@sql\((['"])(.*?)['"]\)"#).unwrap(), + + // FIXED: Match negative numbers and avoid already quoted strings + number_literal_re: Regex::new(r#"(? String { let mut transformed = script.to_string(); + // Step 1: Convert all numeric literals to strings (FIXED to handle negative numbers) + transformed = self.convert_numbers_to_strings(&transformed); + + // Step 2: Replace math function calls with decimal equivalents (SIMPLIFIED) + transformed = self.replace_math_functions(&transformed); + + // Step 3: Handle existing column and SQL integrations (unchanged) + transformed = self.process_column_integrations(&transformed, current_table); + + transformed + } + + /// Convert all unquoted numeric literals to quoted strings + fn convert_numbers_to_strings(&self, script: &str) -> String { + // This regex matches numbers that are NOT already inside quotes + self.number_literal_re.replace_all(script, |caps: ®ex::Captures| { + format!("\"{}\"", &caps[1]) + }).to_string() + } + + /// Replace math function calls with decimal equivalents (SIMPLIFIED) + fn replace_math_functions(&self, script: &str) -> String { + let mut result = script.to_string(); + + // Apply all math operator replacements + for (pattern, replacement) in &self.math_operators { + result = pattern.replace_all(&result, *replacement).to_string(); + } + + result + } + + /// Process existing column and SQL integrations (unchanged logic) + fn process_column_integrations(&self, script: &str, current_table: &str) -> String { + let mut transformed = script.to_string(); + // Process indexed access first to avoid overlap with relationship matches transformed = self.one_to_many_indexed_re.replace_all(&transformed, |caps: ®ex::Captures| { format!("(steel_get_column_with_index \"{}\" {} \"{}\")", diff --git a/server/tests/mod.rs b/server/tests/mod.rs index 09829a3..497c490 100644 --- a/server/tests/mod.rs +++ b/server/tests/mod.rs @@ -1,4 +1,4 @@ // tests/mod.rs pub mod tables_data; pub mod common; -// pub mod table_definition; +pub mod table_definition;