steel with decimal math, saving before separating steel to a separate crate

This commit is contained in:
filipriec
2025-07-02 14:44:37 +02:00
parent 7b7f3ca05a
commit d1ebe4732f
6 changed files with 387 additions and 18 deletions

View File

@@ -33,7 +33,7 @@ validator = { version = "0.20.0", features = ["derive"] }
uuid = { version = "1.16.0", features = ["serde", "v4"] } uuid = { version = "1.16.0", features = ["serde", "v4"] }
jsonwebtoken = "9.3.1" jsonwebtoken = "9.3.1"
rust-stemmers = "1.2.0" rust-stemmers = "1.2.0"
rust_decimal = "1.37.2" rust_decimal = { version = "1.37.2", features = ["maths", "serde"] }
rust_decimal_macros = "1.37.1" rust_decimal_macros = "1.37.1"
[lib] [lib]

View File

@@ -0,0 +1,190 @@
// src/steel/server/decimal_math.rs
use rust_decimal::prelude::*;
use rust_decimal::MathematicalOps;
use steel::rvals::SteelVal;
use std::str::FromStr;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum DecimalMathError {
#[error("Invalid decimal format: {0}")]
InvalidDecimal(String),
#[error("Math operation failed: {0}")]
MathError(String),
#[error("Division by zero")]
DivisionByZero,
}
/// Converts a SteelVal to a Decimal
fn steel_val_to_decimal(val: &SteelVal) -> Result<Decimal, DecimalMathError> {
match val {
SteelVal::StringV(s) => {
Decimal::from_str(&s.to_string())
.map_err(|e| DecimalMathError::InvalidDecimal(format!("{}: {}", s, e)))
}
SteelVal::NumV(n) => {
Decimal::try_from(*n)
.map_err(|e| DecimalMathError::InvalidDecimal(format!("{}: {}", n, e)))
}
SteelVal::IntV(i) => {
Ok(Decimal::from(*i))
}
_ => Err(DecimalMathError::InvalidDecimal(format!("Unsupported type: {:?}", val)))
}
}
/// Converts a Decimal back to a SteelVal string
fn decimal_to_steel_val(decimal: Decimal) -> SteelVal {
SteelVal::StringV(decimal.to_string().into())
}
// Basic arithmetic operations
pub fn decimal_add(a: String, b: String) -> Result<String, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?;
Ok((a_dec + b_dec).to_string())
}
pub fn decimal_sub(a: String, b: String) -> Result<String, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?;
Ok((a_dec - b_dec).to_string())
}
pub fn decimal_mul(a: String, b: String) -> Result<String, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?;
Ok((a_dec * b_dec).to_string())
}
pub fn decimal_div(a: String, b: String) -> Result<String, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?;
if b_dec.is_zero() {
return Err("Division by zero".to_string());
}
Ok((a_dec / b_dec).to_string())
}
// Advanced mathematical functions (requires maths feature)
pub fn decimal_pow(base: String, exp: String) -> Result<String, String> {
let base_dec = Decimal::from_str(&base).map_err(|e| format!("Invalid decimal '{}': {}", base, e))?;
let exp_dec = Decimal::from_str(&exp).map_err(|e| format!("Invalid decimal '{}': {}", exp, e))?;
base_dec.checked_powd(exp_dec)
.map(|result| result.to_string())
.ok_or_else(|| "Power operation failed or overflowed".to_string())
}
pub fn decimal_sqrt(a: String) -> Result<String, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
a_dec.sqrt()
.map(|result| result.to_string())
.ok_or_else(|| "Square root failed (negative number?)".to_string())
}
pub fn decimal_ln(a: String) -> Result<String, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
a_dec.checked_ln()
.map(|result| result.to_string())
.ok_or_else(|| "Natural log failed (non-positive number?)".to_string())
}
pub fn decimal_log10(a: String) -> Result<String, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
a_dec.checked_log10()
.map(|result| result.to_string())
.ok_or_else(|| "Log10 failed (non-positive number?)".to_string())
}
pub fn decimal_exp(a: String) -> Result<String, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
a_dec.checked_exp()
.map(|result| result.to_string())
.ok_or_else(|| "Exponential failed or overflowed".to_string())
}
// Trigonometric functions (input in radians)
pub fn decimal_sin(a: String) -> Result<String, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
a_dec.checked_sin()
.map(|result| result.to_string())
.ok_or_else(|| "Sine calculation failed or overflowed".to_string())
}
pub fn decimal_cos(a: String) -> Result<String, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
a_dec.checked_cos()
.map(|result| result.to_string())
.ok_or_else(|| "Cosine calculation failed or overflowed".to_string())
}
pub fn decimal_tan(a: String) -> Result<String, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
a_dec.checked_tan()
.map(|result| result.to_string())
.ok_or_else(|| "Tangent calculation failed or overflowed".to_string())
}
// Comparison functions
pub fn decimal_gt(a: String, b: String) -> Result<bool, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?;
Ok(a_dec > b_dec)
}
pub fn decimal_lt(a: String, b: String) -> Result<bool, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?;
Ok(a_dec < b_dec)
}
pub fn decimal_eq(a: String, b: String) -> Result<bool, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?;
Ok(a_dec == b_dec)
}
// Utility functions
pub fn decimal_abs(a: String) -> Result<String, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
Ok(a_dec.abs().to_string())
}
pub fn decimal_round(a: String, places: i32) -> Result<String, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
Ok(a_dec.round_dp(places as u32).to_string())
}
pub fn decimal_min(a: String, b: String) -> Result<String, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?;
Ok(a_dec.min(b_dec).to_string())
}
pub fn decimal_max(a: String, b: String) -> Result<String, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?;
Ok(a_dec.max(b_dec).to_string())
}
pub fn decimal_gte(a: String, b: String) -> Result<bool, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?;
Ok(a_dec >= b_dec)
}
pub fn decimal_lte(a: String, b: String) -> Result<bool, String> {
let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?;
let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?;
Ok(a_dec <= b_dec)
}

View File

@@ -1,8 +1,9 @@
// src/steel/server/execution.rs // Updated src/steel/server/execution.rs
use steel::steel_vm::engine::Engine; use steel::steel_vm::engine::Engine;
use steel::steel_vm::register_fn::RegisterFn; use steel::steel_vm::register_fn::RegisterFn;
use steel::rvals::SteelVal; use steel::rvals::SteelVal;
use super::functions::SteelContext; use super::functions::SteelContext;
use super::decimal_math::*;
use sqlx::PgPool; use sqlx::PgPool;
use std::sync::Arc; use std::sync::Arc;
use thiserror::Error; use thiserror::Error;
@@ -33,6 +34,24 @@ pub fn execute_script(
let mut vm = Engine::new(); let mut vm = Engine::new();
let context = Arc::new(context); let context = Arc::new(context);
// Register existing Steel functions
register_steel_functions(&mut vm, context.clone());
// Register all decimal math functions
register_decimal_math_functions(&mut vm);
// Execute script and process results
let results = vm.compile_and_run_raw_program(script)
.map_err(|e| ExecutionError::RuntimeError(e.to_string()))?;
// Convert results to target type
match target_type {
"STRINGS" => process_string_results(results),
_ => Err(ExecutionError::UnsupportedType(target_type.into()))
}
}
fn register_steel_functions(vm: &mut Engine, context: Arc<SteelContext>) {
// Register steel_get_column with row context // Register steel_get_column with row context
vm.register_fn("steel_get_column", { vm.register_fn("steel_get_column", {
let ctx = context.clone(); let ctx = context.clone();
@@ -59,28 +78,102 @@ pub fn execute_script(
.map_err(|e| e.to_string()) .map_err(|e| e.to_string())
} }
}); });
// Execute script and process results
let results = vm.compile_and_run_raw_program(script)
.map_err(|e| ExecutionError::RuntimeError(e.to_string()))?;
// Convert results to target type
match target_type {
"STRINGS" => process_string_results(results),
_ => Err(ExecutionError::UnsupportedType(target_type.into()))
} }
fn register_decimal_math_functions(vm: &mut Engine) {
// Basic arithmetic operations
vm.register_fn("decimal-add", decimal_add);
vm.register_fn("decimal-sub", decimal_sub);
vm.register_fn("decimal-mul", decimal_mul);
vm.register_fn("decimal-div", decimal_div);
// Advanced mathematical functions
vm.register_fn("decimal-pow", decimal_pow);
vm.register_fn("decimal-sqrt", decimal_sqrt);
vm.register_fn("decimal-ln", decimal_ln);
vm.register_fn("decimal-log10", decimal_log10);
vm.register_fn("decimal-exp", decimal_exp);
// Trigonometric functions
vm.register_fn("decimal-sin", decimal_sin);
vm.register_fn("decimal-cos", decimal_cos);
vm.register_fn("decimal-tan", decimal_tan);
// Comparison functions
vm.register_fn("decimal-gt", decimal_gt);
vm.register_fn("decimal-lt", decimal_lt);
vm.register_fn("decimal-eq", decimal_eq);
// Utility functions
vm.register_fn("decimal-abs", decimal_abs);
vm.register_fn("decimal-round", decimal_round);
vm.register_fn("decimal-min", decimal_min);
vm.register_fn("decimal-max", decimal_max);
// Additional convenience functions
vm.register_fn("decimal-zero", || "0".to_string());
vm.register_fn("decimal-one", || "1".to_string());
vm.register_fn("decimal-pi", || "3.1415926535897932384626433833".to_string());
vm.register_fn("decimal-e", || "2.7182818284590452353602874714".to_string());
// Type conversion helpers
vm.register_fn("to-decimal", |s: String| -> Result<String, String> {
use rust_decimal::prelude::*;
use std::str::FromStr;
Decimal::from_str(&s)
.map(|d| d.to_string())
.map_err(|e| format!("Invalid decimal: {}", e))
});
// Financial functions
vm.register_fn("decimal-percentage", |amount: String, percentage: String| -> Result<String, String> {
use rust_decimal::prelude::*;
use std::str::FromStr;
let amount_dec = Decimal::from_str(&amount)
.map_err(|e| format!("Invalid amount: {}", e))?;
let percentage_dec = Decimal::from_str(&percentage)
.map_err(|e| format!("Invalid percentage: {}", e))?;
let hundred = Decimal::from(100);
Ok((amount_dec * percentage_dec / hundred).to_string())
});
vm.register_fn("decimal-compound", |principal: String, rate: String, time: String| -> Result<String, String> {
use rust_decimal::prelude::*;
use rust_decimal::MathematicalOps;
use std::str::FromStr;
let principal_dec = Decimal::from_str(&principal)
.map_err(|e| format!("Invalid principal: {}", e))?;
let rate_dec = Decimal::from_str(&rate)
.map_err(|e| format!("Invalid rate: {}", e))?;
let time_dec = Decimal::from_str(&time)
.map_err(|e| format!("Invalid time: {}", e))?;
let one = Decimal::ONE;
let compound_factor = (one + rate_dec).checked_powd(time_dec)
.ok_or("Compound calculation overflow")?;
Ok((principal_dec * compound_factor).to_string())
});
} }
fn process_string_results(results: Vec<SteelVal>) -> Result<Value, ExecutionError> { fn process_string_results(results: Vec<SteelVal>) -> Result<Value, ExecutionError> {
let mut strings = Vec::new(); let mut strings = Vec::new();
for result in results { for result in results {
if let SteelVal::StringV(s) = result { match result {
strings.push(s.to_string()); SteelVal::StringV(s) => strings.push(s.to_string()),
} else { SteelVal::NumV(n) => strings.push(n.to_string()),
SteelVal::IntV(i) => strings.push(i.to_string()),
SteelVal::BoolV(b) => strings.push(b.to_string()),
_ => {
return Err(ExecutionError::TypeConversionError( return Err(ExecutionError::TypeConversionError(
format!("Expected string, got {:?}", result) format!("Expected string-convertible type, got {:?}", result)
)); ));
} }
} }
}
Ok(Value::Strings(strings)) Ok(Value::Strings(strings))
} }

View File

@@ -2,7 +2,9 @@
pub mod execution; pub mod execution;
pub mod syntax_parser; pub mod syntax_parser;
pub mod functions; pub mod functions;
pub mod decimal_math;
pub use execution::*; pub use execution::*;
pub use syntax_parser::*; pub use syntax_parser::*;
pub use functions::*; pub use functions::*;
pub use decimal_math::*;

View File

@@ -1,27 +1,111 @@
// src/steel/server/syntax_parser.rs
use regex::Regex; use regex::Regex;
use std::collections::HashSet; use std::collections::HashSet;
pub struct SyntaxParser { pub struct SyntaxParser {
// Existing patterns for column/SQL integration
current_table_column_re: Regex, current_table_column_re: Regex,
different_table_column_re: Regex, different_table_column_re: Regex,
one_to_many_indexed_re: Regex, one_to_many_indexed_re: Regex,
sql_integration_re: Regex, sql_integration_re: Regex,
// Simple math operation replacement patterns
math_operators: Vec<(Regex, &'static str)>,
number_literal_re: Regex,
} }
impl SyntaxParser { impl SyntaxParser {
pub fn new() -> Self { pub fn new() -> Self {
// Define math operator replacements
let math_operators = vec![
// Basic arithmetic
(Regex::new(r"\(\s*\+\s+").unwrap(), "(decimal-add "),
(Regex::new(r"\(\s*-\s+").unwrap(), "(decimal-sub "),
(Regex::new(r"\(\s*\*\s+").unwrap(), "(decimal-mul "),
(Regex::new(r"\(\s*/\s+").unwrap(), "(decimal-div "),
// Power and advanced operations
(Regex::new(r"\(\s*\^\s+").unwrap(), "(decimal-pow "),
(Regex::new(r"\(\s*\*\*\s+").unwrap(), "(decimal-pow "),
(Regex::new(r"\(\s*pow\s+").unwrap(), "(decimal-pow "),
(Regex::new(r"\(\s*sqrt\s+").unwrap(), "(decimal-sqrt "),
// Logarithmic functions
(Regex::new(r"\(\s*ln\s+").unwrap(), "(decimal-ln "),
(Regex::new(r"\(\s*log\s+").unwrap(), "(decimal-ln "),
(Regex::new(r"\(\s*log10\s+").unwrap(), "(decimal-log10 "),
(Regex::new(r"\(\s*exp\s+").unwrap(), "(decimal-exp "),
// Trigonometric functions
(Regex::new(r"\(\s*sin\s+").unwrap(), "(decimal-sin "),
(Regex::new(r"\(\s*cos\s+").unwrap(), "(decimal-cos "),
(Regex::new(r"\(\s*tan\s+").unwrap(), "(decimal-tan "),
// Comparison operators
(Regex::new(r"\(\s*>\s+").unwrap(), "(decimal-gt "),
(Regex::new(r"\(\s*<\s+").unwrap(), "(decimal-lt "),
(Regex::new(r"\(\s*=\s+").unwrap(), "(decimal-eq "),
(Regex::new(r"\(\s*>=\s+").unwrap(), "(decimal-gte "),
(Regex::new(r"\(\s*<=\s+").unwrap(), "(decimal-lte "),
// Utility functions
(Regex::new(r"\(\s*abs\s+").unwrap(), "(decimal-abs "),
(Regex::new(r"\(\s*min\s+").unwrap(), "(decimal-min "),
(Regex::new(r"\(\s*max\s+").unwrap(), "(decimal-max "),
(Regex::new(r"\(\s*round\s+").unwrap(), "(decimal-round "),
];
SyntaxParser { SyntaxParser {
current_table_column_re: Regex::new(r"@(\w+)").unwrap(), current_table_column_re: Regex::new(r"@(\w+)").unwrap(),
different_table_column_re: Regex::new(r"@(\w+)\.(\w+)").unwrap(), different_table_column_re: Regex::new(r"@(\w+)\.(\w+)").unwrap(),
one_to_many_indexed_re: Regex::new(r"@(\w+)\[(\d+)\]\.(\w+)").unwrap(), one_to_many_indexed_re: Regex::new(r"@(\w+)\[(\d+)\]\.(\w+)").unwrap(),
sql_integration_re: Regex::new(r#"@sql\((['"])(.*?)['"]\)"#).unwrap(), sql_integration_re: Regex::new(r#"@sql\((['"])(.*?)['"]\)"#).unwrap(),
// FIXED: Match negative numbers and avoid already quoted strings
number_literal_re: Regex::new(r#"(?<!")(-?\d+\.?\d*(?:[eE][+-]?\d+)?)(?!")"#).unwrap(),
math_operators,
} }
} }
pub fn parse(&self, script: &str, current_table: &str) -> String { pub fn parse(&self, script: &str, current_table: &str) -> String {
let mut transformed = script.to_string(); let mut transformed = script.to_string();
// Step 1: Convert all numeric literals to strings (FIXED to handle negative numbers)
transformed = self.convert_numbers_to_strings(&transformed);
// Step 2: Replace math function calls with decimal equivalents (SIMPLIFIED)
transformed = self.replace_math_functions(&transformed);
// Step 3: Handle existing column and SQL integrations (unchanged)
transformed = self.process_column_integrations(&transformed, current_table);
transformed
}
/// Convert all unquoted numeric literals to quoted strings
fn convert_numbers_to_strings(&self, script: &str) -> String {
// This regex matches numbers that are NOT already inside quotes
self.number_literal_re.replace_all(script, |caps: &regex::Captures| {
format!("\"{}\"", &caps[1])
}).to_string()
}
/// Replace math function calls with decimal equivalents (SIMPLIFIED)
fn replace_math_functions(&self, script: &str) -> String {
let mut result = script.to_string();
// Apply all math operator replacements
for (pattern, replacement) in &self.math_operators {
result = pattern.replace_all(&result, *replacement).to_string();
}
result
}
/// Process existing column and SQL integrations (unchanged logic)
fn process_column_integrations(&self, script: &str, current_table: &str) -> String {
let mut transformed = script.to_string();
// Process indexed access first to avoid overlap with relationship matches // Process indexed access first to avoid overlap with relationship matches
transformed = self.one_to_many_indexed_re.replace_all(&transformed, |caps: &regex::Captures| { transformed = self.one_to_many_indexed_re.replace_all(&transformed, |caps: &regex::Captures| {
format!("(steel_get_column_with_index \"{}\" {} \"{}\")", format!("(steel_get_column_with_index \"{}\" {} \"{}\")",

View File

@@ -1,4 +1,4 @@
// tests/mod.rs // tests/mod.rs
pub mod tables_data; pub mod tables_data;
pub mod common; pub mod common;
// pub mod table_definition; pub mod table_definition;