From 93c67ffa14d32accbe61bd0cbae78efbb0fa6f2b Mon Sep 17 00:00:00 2001 From: filipriec Date: Wed, 2 Jul 2025 16:31:15 +0200 Subject: [PATCH] steel decimal crate implemented --- Cargo.lock | 12 ++ Cargo.toml | 8 +- server/Cargo.toml | 9 +- steel_decimal/Cargo.toml | 20 +++ steel_decimal/src/decimal_math.rs | 151 +++++++++++++++++++++ steel_decimal/src/lib.rs | 207 +++++++++++++++++++++++++++++ steel_decimal/src/syntax_parser.rs | 83 ++++++++++++ 7 files changed, 485 insertions(+), 5 deletions(-) create mode 100644 steel_decimal/Cargo.toml create mode 100644 steel_decimal/src/decimal_math.rs create mode 100644 steel_decimal/src/lib.rs create mode 100644 steel_decimal/src/syntax_parser.rs diff --git a/Cargo.lock b/Cargo.lock index 8f04441..5a7386f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3569,6 +3569,18 @@ dependencies = [ "smallvec", ] +[[package]] +name = "steel_decimal" +version = "0.3.13" +dependencies = [ + "regex", + "rust_decimal", + "rust_decimal_macros", + "steel-core", + "steel-derive 0.5.0 (git+https://github.com/mattwparas/steel.git?branch=master)", + "thiserror 2.0.12", +] + [[package]] name = "stringprep" version = "0.1.5" diff --git a/Cargo.toml b/Cargo.toml index fcd15d1..abdff03 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["client", "server", "common", "search"] +members = ["client", "server", "common", "search", "steel_decimal"] resolver = "2" [workspace.package] @@ -40,4 +40,10 @@ tracing = "0.1.41" # Search crate tantivy = "0.24.1" +# Steel_decimal crate +rust_decimal = { version = "1.37.2", features = ["maths", "serde"] } +rust_decimal_macros = "1.37.1" +thiserror = "2.0.12" +regex = "1.11.1" + common = { path = "./common" } diff --git a/server/Cargo.toml b/server/Cargo.toml index 64b15aa..0c9731d 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -24,17 +24,18 @@ tracing = "0.1.41" time = { version = "0.3.41", features = ["local-offset"] } steel-derive = { git = "https://github.com/mattwparas/steel.git", branch = "master", package = "steel-derive" } steel-core = { git = "https://github.com/mattwparas/steel.git", version = "0.6.0", features = ["anyhow", "dylibs", "sync"] } -thiserror = "2.0.12" dashmap = "6.1.0" lazy_static = "1.5.0" -regex = "1.11.1" bcrypt = "0.17.0" validator = { version = "0.20.0", features = ["derive"] } uuid = { version = "1.16.0", features = ["serde", "v4"] } jsonwebtoken = "9.3.1" rust-stemmers = "1.2.0" -rust_decimal = { version = "1.37.2", features = ["maths", "serde"] } -rust_decimal_macros = "1.37.1" + +rust_decimal = { workspace = true } +rust_decimal_macros = { workspace = true } +regex = { workspace = true } +thiserror = { workspace = true } [lib] name = "server" diff --git a/steel_decimal/Cargo.toml b/steel_decimal/Cargo.toml new file mode 100644 index 0000000..e1f640e --- /dev/null +++ b/steel_decimal/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "steel_decimal" +version.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true +description.workspace = true +readme.workspace = true +repository.workspace = true +categories.workspace = true + +[dependencies] +steel-derive = { git = "https://github.com/mattwparas/steel.git", branch = "master", package = "steel-derive" } +steel-core = { git = "https://github.com/mattwparas/steel.git", version = "0.6.0", features = ["anyhow", "dylibs", "sync"] } + +rust_decimal = { workspace = true } +rust_decimal_macros = { workspace = true } +regex = { workspace = true } +thiserror = { workspace = true } + diff --git a/steel_decimal/src/decimal_math.rs b/steel_decimal/src/decimal_math.rs new file mode 100644 index 0000000..0c7c114 --- /dev/null +++ b/steel_decimal/src/decimal_math.rs @@ -0,0 +1,151 @@ +// src/decimal_math.rs + +use rust_decimal::prelude::*; +use rust_decimal::MathematicalOps; +use std::str::FromStr; + +pub fn decimal_add(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok((a_dec + b_dec).to_string()) +} + +pub fn decimal_sub(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok((a_dec - b_dec).to_string()) +} + +pub fn decimal_mul(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok((a_dec * b_dec).to_string()) +} + +pub fn decimal_div(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + + if b_dec.is_zero() { + return Err("Division by zero".to_string()); + } + + Ok((a_dec / b_dec).to_string()) +} + +pub fn decimal_pow(base: String, exp: String) -> Result { + let base_dec = Decimal::from_str(&base).map_err(|e| format!("Invalid decimal '{}': {}", base, e))?; + let exp_dec = Decimal::from_str(&exp).map_err(|e| format!("Invalid decimal '{}': {}", exp, e))?; + + base_dec.checked_powd(exp_dec) + .map(|result| result.to_string()) + .ok_or_else(|| "Power operation failed or overflowed".to_string()) +} + +pub fn decimal_sqrt(a: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + + a_dec.sqrt() + .map(|result| result.to_string()) + .ok_or_else(|| "Square root failed (negative number?)".to_string()) +} + +pub fn decimal_ln(a: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + + a_dec.checked_ln() + .map(|result| result.to_string()) + .ok_or_else(|| "Natural log failed (non-positive number?)".to_string()) +} + +pub fn decimal_log10(a: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + + a_dec.checked_log10() + .map(|result| result.to_string()) + .ok_or_else(|| "Log10 failed (non-positive number?)".to_string()) +} + +pub fn decimal_exp(a: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + + a_dec.checked_exp() + .map(|result| result.to_string()) + .ok_or_else(|| "Exponential failed or overflowed".to_string()) +} + +pub fn decimal_sin(a: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + + a_dec.checked_sin() + .map(|result| result.to_string()) + .ok_or_else(|| "Sine calculation failed or overflowed".to_string()) +} + +pub fn decimal_cos(a: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + + a_dec.checked_cos() + .map(|result| result.to_string()) + .ok_or_else(|| "Cosine calculation failed or overflowed".to_string()) +} + +pub fn decimal_tan(a: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + + a_dec.checked_tan() + .map(|result| result.to_string()) + .ok_or_else(|| "Tangent calculation failed or overflowed".to_string()) +} + +pub fn decimal_gt(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok(a_dec > b_dec) +} + +pub fn decimal_gte(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok(a_dec >= b_dec) +} + +pub fn decimal_lt(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok(a_dec < b_dec) +} + +pub fn decimal_lte(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok(a_dec <= b_dec) +} + +pub fn decimal_eq(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok(a_dec == b_dec) +} + +pub fn decimal_abs(a: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + Ok(a_dec.abs().to_string()) +} + +pub fn decimal_round(a: String, places: i32) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + Ok(a_dec.round_dp(places as u32).to_string()) +} + +pub fn decimal_min(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok(a_dec.min(b_dec).to_string()) +} + +pub fn decimal_max(a: String, b: String) -> Result { + let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; + let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; + Ok(a_dec.max(b_dec).to_string()) +} diff --git a/steel_decimal/src/lib.rs b/steel_decimal/src/lib.rs new file mode 100644 index 0000000..211a6d9 --- /dev/null +++ b/steel_decimal/src/lib.rs @@ -0,0 +1,207 @@ +// src/lib.rs + +use steel::steel_vm::engine::Engine; +use steel::steel_vm::register_fn::RegisterFn; +use steel::rvals::SteelVal; +use std::collections::HashMap; +use thiserror::Error; + +mod decimal_math; +mod syntax_parser; + +pub use decimal_math::*; +use syntax_parser::*; + +#[derive(Debug, Error)] +pub enum SteelDecimalError { + #[error("Script parsing failed: {0}")] + ParseError(String), + #[error("Script execution failed: {0}")] + RuntimeError(String), + #[error("Type conversion error: {0}")] + TypeConversionError(String), +} + +#[derive(Clone, Debug)] +pub struct SteelContext { + pub variables: HashMap, + pub current_table: String, +} + +impl SteelContext { + pub fn new(current_table: String) -> Self { + Self { + variables: HashMap::new(), + current_table, + } + } + + pub fn with_variables(current_table: String, variables: HashMap) -> Self { + Self { + variables, + current_table, + } + } + + pub fn add_variable(&mut self, key: String, value: String) { + self.variables.insert(key, value); + } +} + +#[derive(Debug, Clone)] +pub struct SteelResult { + pub result: String, + pub warnings: Vec, +} + +pub struct SteelDecimalEngine { + parser: SyntaxParser, +} + +impl SteelDecimalEngine { + pub fn new() -> Self { + Self { + parser: SyntaxParser::new(), + } + } + + pub fn parse_script(&self, script: &str, context: &SteelContext) -> Result { + Ok(self.parser.parse(script, &context.current_table)) + } + + pub fn execute_script(&self, script: &str, context: &SteelContext) -> Result { + let transformed_script = self.parser.parse(script, &context.current_table); + + let mut vm = Engine::new(); + self.register_decimal_functions(&mut vm); + self.register_context_functions(&mut vm, context); + + let results = vm.compile_and_run_raw_program(transformed_script) + .map_err(|e| SteelDecimalError::RuntimeError(e.to_string()))?; + + let result_string = self.convert_result_to_string(results)?; + + Ok(SteelResult { + result: result_string, + warnings: Vec::new(), + }) + } + + fn register_decimal_functions(&self, vm: &mut Engine) { + vm.register_fn("decimal-add", decimal_add); + vm.register_fn("decimal-sub", decimal_sub); + vm.register_fn("decimal-mul", decimal_mul); + vm.register_fn("decimal-div", decimal_div); + vm.register_fn("decimal-pow", decimal_pow); + vm.register_fn("decimal-sqrt", decimal_sqrt); + vm.register_fn("decimal-ln", decimal_ln); + vm.register_fn("decimal-log10", decimal_log10); + vm.register_fn("decimal-exp", decimal_exp); + vm.register_fn("decimal-sin", decimal_sin); + vm.register_fn("decimal-cos", decimal_cos); + vm.register_fn("decimal-tan", decimal_tan); + vm.register_fn("decimal-gt", decimal_gt); + vm.register_fn("decimal-gte", decimal_gte); + vm.register_fn("decimal-lt", decimal_lt); + vm.register_fn("decimal-lte", decimal_lte); + vm.register_fn("decimal-eq", decimal_eq); + vm.register_fn("decimal-abs", decimal_abs); + vm.register_fn("decimal-round", decimal_round); + vm.register_fn("decimal-min", decimal_min); + vm.register_fn("decimal-max", decimal_max); + vm.register_fn("decimal-zero", || "0".to_string()); + vm.register_fn("decimal-one", || "1".to_string()); + vm.register_fn("decimal-pi", || "3.1415926535897932384626433833".to_string()); + vm.register_fn("decimal-e", || "2.7182818284590452353602874714".to_string()); + vm.register_fn("decimal-percentage", decimal_percentage); + vm.register_fn("decimal-compound", decimal_compound); + } + + fn register_context_functions(&self, vm: &mut Engine, context: &SteelContext) { + let variables = context.variables.clone(); + vm.register_fn("get-var", move |var_name: String| -> Result { + variables.get(&var_name) + .cloned() + .ok_or_else(|| format!("Variable '{}' not found", var_name)) + }); + + let variables_check = context.variables.clone(); + vm.register_fn("has-var?", move |var_name: String| -> bool { + variables_check.contains_key(&var_name) + }); + } + + fn convert_result_to_string(&self, results: Vec) -> Result { + if results.is_empty() { + return Ok(String::new()); + } + + let last_result = &results[results.len() - 1]; + + match last_result { + SteelVal::StringV(s) => Ok(s.to_string()), + SteelVal::NumV(n) => Ok(n.to_string()), + SteelVal::IntV(i) => Ok(i.to_string()), + SteelVal::BoolV(b) => Ok(b.to_string()), + SteelVal::VectorV(v) => { + let string_values: Result, _> = v.iter() + .map(|val| match val { + SteelVal::StringV(s) => Ok(s.to_string()), + SteelVal::NumV(n) => Ok(n.to_string()), + SteelVal::IntV(i) => Ok(i.to_string()), + SteelVal::BoolV(b) => Ok(b.to_string()), + _ => Err(SteelDecimalError::TypeConversionError( + format!("Cannot convert vector element to string: {:?}", val) + )) + }) + .collect(); + + match string_values { + Ok(strings) => Ok(strings.join(",")), + Err(e) => Err(e), + } + } + _ => Err(SteelDecimalError::TypeConversionError( + format!("Cannot convert result to string: {:?}", last_result) + )) + } + } +} + +impl Default for SteelDecimalEngine { + fn default() -> Self { + Self::new() + } +} + +fn decimal_percentage(amount: String, percentage: String) -> Result { + use rust_decimal::prelude::*; + use std::str::FromStr; + + let amount_dec = Decimal::from_str(&amount) + .map_err(|e| format!("Invalid amount: {}", e))?; + let percentage_dec = Decimal::from_str(&percentage) + .map_err(|e| format!("Invalid percentage: {}", e))?; + let hundred = Decimal::from(100); + + Ok((amount_dec * percentage_dec / hundred).to_string()) +} + +fn decimal_compound(principal: String, rate: String, time: String) -> Result { + use rust_decimal::prelude::*; + use rust_decimal::MathematicalOps; + use std::str::FromStr; + + let principal_dec = Decimal::from_str(&principal) + .map_err(|e| format!("Invalid principal: {}", e))?; + let rate_dec = Decimal::from_str(&rate) + .map_err(|e| format!("Invalid rate: {}", e))?; + let time_dec = Decimal::from_str(&time) + .map_err(|e| format!("Invalid time: {}", e))?; + + let one = Decimal::ONE; + let compound_factor = (one + rate_dec).checked_powd(time_dec) + .ok_or("Compound calculation overflow")?; + + Ok((principal_dec * compound_factor).to_string()) +} diff --git a/steel_decimal/src/syntax_parser.rs b/steel_decimal/src/syntax_parser.rs new file mode 100644 index 0000000..4f05076 --- /dev/null +++ b/steel_decimal/src/syntax_parser.rs @@ -0,0 +1,83 @@ +// src/syntax_parser.rs + +use regex::Regex; + +pub struct SyntaxParser { + math_operators: Vec<(Regex, &'static str)>, + number_literal_re: Regex, +} + +impl SyntaxParser { + pub fn new() -> Self { + let math_operators = vec![ + (Regex::new(r"\(\s*\+\s+").unwrap(), "(decimal-add "), + (Regex::new(r"\(\s*-\s+").unwrap(), "(decimal-sub "), + (Regex::new(r"\(\s*\*\s+").unwrap(), "(decimal-mul "), + (Regex::new(r"\(\s*/\s+").unwrap(), "(decimal-div "), + (Regex::new(r"\(\s*\^\s+").unwrap(), "(decimal-pow "), + (Regex::new(r"\(\s*\*\*\s+").unwrap(), "(decimal-pow "), + (Regex::new(r"\(\s*pow\s+").unwrap(), "(decimal-pow "), + (Regex::new(r"\(\s*sqrt\s+").unwrap(), "(decimal-sqrt "), + (Regex::new(r"\(\s*ln\s+").unwrap(), "(decimal-ln "), + (Regex::new(r"\(\s*log\s+").unwrap(), "(decimal-ln "), + (Regex::new(r"\(\s*log10\s+").unwrap(), "(decimal-log10 "), + (Regex::new(r"\(\s*exp\s+").unwrap(), "(decimal-exp "), + (Regex::new(r"\(\s*sin\s+").unwrap(), "(decimal-sin "), + (Regex::new(r"\(\s*cos\s+").unwrap(), "(decimal-cos "), + (Regex::new(r"\(\s*tan\s+").unwrap(), "(decimal-tan "), + (Regex::new(r"\(\s*>\s+").unwrap(), "(decimal-gt "), + (Regex::new(r"\(\s*<\s+").unwrap(), "(decimal-lt "), + (Regex::new(r"\(\s*=\s+").unwrap(), "(decimal-eq "), + (Regex::new(r"\(\s*>=\s+").unwrap(), "(decimal-gte "), + (Regex::new(r"\(\s*<=\s+").unwrap(), "(decimal-lte "), + (Regex::new(r"\(\s*abs\s+").unwrap(), "(decimal-abs "), + (Regex::new(r"\(\s*min\s+").unwrap(), "(decimal-min "), + (Regex::new(r"\(\s*max\s+").unwrap(), "(decimal-max "), + (Regex::new(r"\(\s*round\s+").unwrap(), "(decimal-round "), + ]; + + SyntaxParser { + number_literal_re: Regex::new(r#"(? String { + let mut transformed = script.to_string(); + + transformed = self.convert_numbers_to_strings(&transformed); + transformed = self.replace_math_functions(&transformed); + transformed = self.replace_variable_references(&transformed); + + transformed + } + + fn convert_numbers_to_strings(&self, script: &str) -> String { + self.number_literal_re.replace_all(script, |caps: ®ex::Captures| { + format!("\"{}\"", &caps[1]) + }).to_string() + } + + fn replace_math_functions(&self, script: &str) -> String { + let mut result = script.to_string(); + + for (pattern, replacement) in &self.math_operators { + result = pattern.replace_all(&result, *replacement).to_string(); + } + + result + } + + fn replace_variable_references(&self, script: &str) -> String { + let var_re = Regex::new(r"\$(\w+)").unwrap(); + var_re.replace_all(script, |caps: ®ex::Captures| { + format!("(get-var \"{}\")", &caps[1]) + }).to_string() + } +} + +impl Default for SyntaxParser { + fn default() -> Self { + Self::new() + } +}