diff --git a/Cargo.lock b/Cargo.lock index 5a7386f..26779ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3086,6 +3086,7 @@ dependencies = [ "sqlx", "steel-core", "steel-derive 0.5.0 (git+https://github.com/mattwparas/steel.git?branch=master)", + "steel_decimal", "tantivy", "thiserror 2.0.12", "time", diff --git a/server/Cargo.toml b/server/Cargo.toml index 0c9731d..a526868 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -7,6 +7,7 @@ license = "AGPL-3.0-or-later" [dependencies] common = { path = "../common" } search = { path = "../search" } +steel_decimal = { path = "../steel_decimal" } anyhow = { workspace = true } tantivy = { workspace = true } diff --git a/server/src/table_script/handlers/post_table_script.rs b/server/src/table_script/handlers/post_table_script.rs index f590b92..d09e06c 100644 --- a/server/src/table_script/handlers/post_table_script.rs +++ b/server/src/table_script/handlers/post_table_script.rs @@ -3,11 +3,11 @@ use tonic::Status; use sqlx::{PgPool, Error as SqlxError}; use common::proto::multieko2::table_script::{PostTableScriptRequest, TableScriptResponse}; use serde_json::Value; -use crate::steel::server::syntax_parser::SyntaxParser; +use steel_decimal::SteelDecimal; const SYSTEM_COLUMNS: &[&str] = &["id", "deleted", "created_at"]; -// TODO MAKE SCRIPT PUSH ONLY TO THE EMPTY TABLES +// TODO MAKE SCRIPT PUSH ONLY TO THE EMPTY TABLES /// Validates the target column and ensures it is not a system column. /// Returns the column type if valid. fn validate_target_column( @@ -68,9 +68,11 @@ pub async fn post_table_script( ) .map_err(|e| Status::invalid_argument(e))?; - // Parse and transform the script using the syntax parser - let parser = SyntaxParser::new(); - let parsed_script = parser.parse(&request.script, &table_def.table_name); + // Use the steel_decimal for script transformation + let steel_decimal = SteelDecimal::new(); + + // Transform the script using steel_decimal (no context needed for basic transformation) + let parsed_script = steel_decimal.transform(&request.script); // Insert the script into the database let script_record = sqlx::query!( diff --git a/steel_decimal/Cargo.toml b/steel_decimal/Cargo.toml index e1f640e..edefd9c 100644 --- a/steel_decimal/Cargo.toml +++ b/steel_decimal/Cargo.toml @@ -2,7 +2,7 @@ name = "steel_decimal" version.workspace = true edition.workspace = true -license.workspace = true +license = "MIT OR Apache-2.0" authors.workspace = true description.workspace = true readme.workspace = true diff --git a/steel_decimal/src/examples/basic_usage.rs b/steel_decimal/src/examples/basic_usage.rs new file mode 100644 index 0000000..723296b --- /dev/null +++ b/steel_decimal/src/examples/basic_usage.rs @@ -0,0 +1,61 @@ +// examples/basic_usage.rs +use steel_decimal::SteelDecimal; +use steel::steel_vm::engine::Engine; + +fn main() { + // Create a new Steel Decimal engine + let steel_decimal = SteelDecimal::new(); + + // Transform a simple math expression + let script = "(+ 1.5 2.3)"; + let transformed = steel_decimal.transform(script); + println!("Original: {}", script); + println!("Transformed: {}", transformed); + + // Create Steel VM and register functions + let mut vm = Engine::new(); + steel_decimal.register_functions(&mut vm); + + // Execute the transformed script + match vm.compile_and_run_raw_program(transformed) { + Ok(results) => { + println!("Results: {:?}", results); + if let Some(last_result) = results.last() { + println!("Final result: {:?}", last_result); + } + } + Err(e) => { + println!("Error: {}", e); + } + } + + // Try a more complex expression + let complex_script = "(+ (* 2.5 3.0) (/ 15.0 3.0))"; + let complex_transformed = steel_decimal.transform(complex_script); + println!("\nComplex original: {}", complex_script); + println!("Complex transformed: {}", complex_transformed); + + match vm.compile_and_run_raw_program(complex_transformed) { + Ok(results) => { + if let Some(last_result) = results.last() { + println!("Complex result: {:?}", last_result); + } + } + Err(e) => { + println!("Error: {}", e); + } + } + + // Using the convenience method + println!("\nUsing convenience method:"); + match steel_decimal.parse_and_execute("(+ 10.5 20.3)") { + Ok(results) => { + if let Some(last_result) = results.last() { + println!("Convenience result: {:?}", last_result); + } + } + Err(e) => { + println!("Error: {}", e); + } + } +} diff --git a/steel_decimal/src/examples/selective_registration.rs b/steel_decimal/src/examples/selective_registration.rs new file mode 100644 index 0000000..dbe05e0 --- /dev/null +++ b/steel_decimal/src/examples/selective_registration.rs @@ -0,0 +1,131 @@ +// examples/selective_registration.rs +use steel_decimal::{SteelDecimal, FunctionRegistryBuilder}; +use steel::steel_vm::engine::Engine; +use std::collections::HashMap; + +fn main() { + let steel_decimal = SteelDecimal::new(); + + println!("=== Basic Arithmetic Only ==="); + let mut vm1 = Engine::new(); + FunctionRegistryBuilder::new() + .basic_arithmetic(true) + .advanced_math(false) + .trigonometric(false) + .comparison(false) + .utility(false) + .constants(false) + .financial(false) + .register(&mut vm1); + + // Test basic arithmetic + let script = "(+ 10.5 20.3)"; + let transformed = steel_decimal.transform(script); + match vm1.compile_and_run_raw_program(transformed) { + Ok(results) => { + if let Some(result) = results.last() { + println!("Basic arithmetic result: {:?}", result); + } + } + Err(e) => println!("Error: {}", e), + } + + println!("\n=== With Advanced Math ==="); + let mut vm2 = Engine::new(); + FunctionRegistryBuilder::new() + .basic_arithmetic(true) + .advanced_math(true) + .trigonometric(false) + .register(&mut vm2); + + // Test power function + let power_script = "(^ 2.0 3.0)"; + let power_transformed = steel_decimal.transform(power_script); + match vm2.compile_and_run_raw_program(power_transformed) { + Ok(results) => { + if let Some(result) = results.last() { + println!("Power result: {:?}", result); + } + } + Err(e) => println!("Error: {}", e), + } + + // Test square root + let sqrt_script = "(sqrt 16.0)"; + let sqrt_transformed = steel_decimal.transform(sqrt_script); + match vm2.compile_and_run_raw_program(sqrt_transformed) { + Ok(results) => { + if let Some(result) = results.last() { + println!("Square root result: {:?}", result); + } + } + Err(e) => println!("Error: {}", e), + } + + println!("\n=== With Variables ==="); + let mut variables = HashMap::new(); + variables.insert("radius".to_string(), "5.0".to_string()); + variables.insert("pi".to_string(), "3.14159".to_string()); + + let mut vm3 = Engine::new(); + FunctionRegistryBuilder::new() + .basic_arithmetic(true) + .advanced_math(true) + .constants(true) + .with_variables(variables) + .register(&mut vm3); + + // Calculate area of circle using variables + let area_script = "(* $pi (* $radius $radius))"; + let area_transformed = steel_decimal.transform(area_script); + println!("Area script: {}", area_script); + println!("Transformed: {}", area_transformed); + + match vm3.compile_and_run_raw_program(area_transformed) { + Ok(results) => { + if let Some(result) = results.last() { + println!("Circle area result: {:?}", result); + } + } + Err(e) => println!("Error: {}", e), + } + + println!("\n=== Financial Functions ==="); + let mut vm4 = Engine::new(); + FunctionRegistryBuilder::new() + .basic_arithmetic(true) + .financial(true) + .register(&mut vm4); + + // Test percentage calculation + let percent_script = r#"(decimal-percentage "1000.00" "15.0")"#; + match vm4.compile_and_run_raw_program(percent_script.to_string()) { + Ok(results) => { + if let Some(result) = results.last() { + println!("15% of 1000: {:?}", result); + } + } + Err(e) => println!("Error: {}", e), + } + + // Test compound interest + let compound_script = r#"(decimal-compound "1000.00" "0.05" "10.0")"#; + match vm4.compile_and_run_raw_program(compound_script.to_string()) { + Ok(results) => { + if let Some(result) = results.last() { + println!("Compound interest (1000 @ 5% for 10 years): {:?}", result); + } + } + Err(e) => println!("Error: {}", e), + } + + println!("\n=== Available Functions ==="); + let function_names = steel_decimal::FunctionRegistry::get_function_names(); + for (i, name) in function_names.iter().enumerate() { + if i % 5 == 0 { + println!(); + } + print!("{:<18}", name); + } + println!(); +} diff --git a/steel_decimal/src/examples/with_variables.rs b/steel_decimal/src/examples/with_variables.rs new file mode 100644 index 0000000..f0fda0f --- /dev/null +++ b/steel_decimal/src/examples/with_variables.rs @@ -0,0 +1,73 @@ +// examples/with_variables.rs +use steel_decimal::SteelDecimal; +use steel::steel_vm::engine::Engine; +use std::collections::HashMap; + +fn main() { + // Create variables + let mut variables = HashMap::new(); + variables.insert("price".to_string(), "29.99".to_string()); + variables.insert("quantity".to_string(), "5".to_string()); + variables.insert("tax_rate".to_string(), "0.08".to_string()); + + // Create Steel Decimal with variables + let steel_decimal = SteelDecimal::with_variables(variables); + + // Script using variables + let script = "(+ (* $price $quantity) (* (* $price $quantity) $tax_rate))"; + let transformed = steel_decimal.transform(script); + + println!("Original script: {}", script); + println!("Transformed script: {}", transformed); + + // Create VM and register functions + let mut vm = Engine::new(); + steel_decimal.register_functions(&mut vm); + + // Execute the script + match vm.compile_and_run_raw_program(transformed) { + Ok(results) => { + if let Some(last_result) = results.last() { + println!("Total with tax: {:?}", last_result); + } + } + Err(e) => { + println!("Error: {}", e); + } + } + + // Demonstrate adding variables dynamically + let mut steel_decimal = SteelDecimal::new(); + steel_decimal.add_variable("x".to_string(), "10.5".to_string()); + steel_decimal.add_variable("y".to_string(), "20.3".to_string()); + + let simple_script = "(+ $x $y)"; + println!("\nSimple script: {}", simple_script); + + match steel_decimal.parse_and_execute(simple_script) { + Ok(results) => { + if let Some(last_result) = results.last() { + println!("Simple result: {:?}", last_result); + } + } + Err(e) => { + println!("Error: {}", e); + } + } + + // Show variable validation + println!("\nVariable validation:"); + match steel_decimal.validate_script("(+ $x $y)") { + Ok(()) => println!("Script is valid"), + Err(e) => println!("Script error: {}", e), + } + + match steel_decimal.validate_script("(+ $x $undefined_var)") { + Ok(()) => println!("Script is valid"), + Err(e) => println!("Script error: {}", e), + } + + // Extract dependencies + let dependencies = steel_decimal.extract_dependencies("(+ $x $y $z)"); + println!("Dependencies: {:?}", dependencies); +} diff --git a/steel_decimal/src/decimal_math.rs b/steel_decimal/src/functions.rs similarity index 78% rename from steel_decimal/src/decimal_math.rs rename to steel_decimal/src/functions.rs index 0c7c114..84c3b43 100644 --- a/steel_decimal/src/decimal_math.rs +++ b/steel_decimal/src/functions.rs @@ -1,9 +1,9 @@ -// src/decimal_math.rs - +// src/functions.rs use rust_decimal::prelude::*; use rust_decimal::MathematicalOps; use std::str::FromStr; +// Basic arithmetic operations pub fn decimal_add(a: String, b: String) -> Result { let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; @@ -33,6 +33,7 @@ pub fn decimal_div(a: String, b: String) -> Result { Ok((a_dec / b_dec).to_string()) } +// Advanced mathematical functions pub fn decimal_pow(base: String, exp: String) -> Result { let base_dec = Decimal::from_str(&base).map_err(|e| format!("Invalid decimal '{}': {}", base, e))?; let exp_dec = Decimal::from_str(&exp).map_err(|e| format!("Invalid decimal '{}': {}", exp, e))?; @@ -74,6 +75,7 @@ pub fn decimal_exp(a: String) -> Result { .ok_or_else(|| "Exponential failed or overflowed".to_string()) } +// Trigonometric functions pub fn decimal_sin(a: String) -> Result { let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; @@ -98,6 +100,7 @@ pub fn decimal_tan(a: String) -> Result { .ok_or_else(|| "Tangent calculation failed or overflowed".to_string()) } +// Comparison functions pub fn decimal_gt(a: String, b: String) -> Result { let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; @@ -128,6 +131,7 @@ pub fn decimal_eq(a: String, b: String) -> Result { Ok(a_dec == b_dec) } +// Utility functions pub fn decimal_abs(a: String) -> Result { let a_dec = Decimal::from_str(&a).map_err(|e| format!("Invalid decimal '{}': {}", a, e))?; Ok(a_dec.abs().to_string()) @@ -149,3 +153,53 @@ pub fn decimal_max(a: String, b: String) -> Result { let b_dec = Decimal::from_str(&b).map_err(|e| format!("Invalid decimal '{}': {}", b, e))?; Ok(a_dec.max(b_dec).to_string()) } + +// Constants +pub fn decimal_zero() -> String { + "0".to_string() +} + +pub fn decimal_one() -> String { + "1".to_string() +} + +pub fn decimal_pi() -> String { + "3.1415926535897932384626433833".to_string() +} + +pub fn decimal_e() -> String { + "2.7182818284590452353602874714".to_string() +} + +// Financial functions +pub fn decimal_percentage(amount: String, percentage: String) -> Result { + let amount_dec = Decimal::from_str(&amount) + .map_err(|e| format!("Invalid amount: {}", e))?; + let percentage_dec = Decimal::from_str(&percentage) + .map_err(|e| format!("Invalid percentage: {}", e))?; + let hundred = Decimal::from(100); + + Ok((amount_dec * percentage_dec / hundred).to_string()) +} + +pub fn decimal_compound(principal: String, rate: String, time: String) -> Result { + let principal_dec = Decimal::from_str(&principal) + .map_err(|e| format!("Invalid principal: {}", e))?; + let rate_dec = Decimal::from_str(&rate) + .map_err(|e| format!("Invalid rate: {}", e))?; + let time_dec = Decimal::from_str(&time) + .map_err(|e| format!("Invalid time: {}", e))?; + + let one = Decimal::ONE; + let compound_factor = (one + rate_dec).checked_powd(time_dec) + .ok_or("Compound calculation overflow")?; + + Ok((principal_dec * compound_factor).to_string()) +} + +// Type conversion helper +pub fn to_decimal(s: String) -> Result { + Decimal::from_str(&s) + .map(|d| d.to_string()) + .map_err(|e| format!("Invalid decimal: {}", e)) +} diff --git a/steel_decimal/src/lib.rs b/steel_decimal/src/lib.rs index 211a6d9..063a140 100644 --- a/steel_decimal/src/lib.rs +++ b/steel_decimal/src/lib.rs @@ -1,207 +1,279 @@ // src/lib.rs +//! # Steel Decimal +//! +//! A Rust library that provides decimal arithmetic support for the Steel programming language. +//! This crate transforms Steel scripts to use high-precision decimal operations and provides +//! function registration utilities for Steel VMs. +//! +//! ## Quick Start +//! +//! ### Basic Usage +//! ```rust +//! use steel_decimal::SteelDecimal; +//! use steel::steel_vm::engine::Engine; +//! +//! // Create a new Steel Decimal engine +//! let steel_decimal = SteelDecimal::new(); +//! +//! // Transform a script +//! let script = "(+ 1.5 2.3)"; +//! let transformed = steel_decimal.transform(script); +//! // Result: "(decimal-add \"1.5\" \"2.3\")" +//! +//! // Register functions with Steel VM +//! let mut vm = Engine::new(); +//! steel_decimal.register_functions(&mut vm); +//! +//! // Now you can execute the transformed script +//! let result = vm.compile_and_run_raw_program(transformed); +//! ``` +//! +//! ### With Variables +//! ```rust +//! use steel_decimal::SteelDecimal; +//! use steel::steel_vm::engine::Engine; +//! use std::collections::HashMap; +//! +//! let mut variables = HashMap::new(); +//! variables.insert("x".to_string(), "10.5".to_string()); +//! variables.insert("y".to_string(), "20.3".to_string()); +//! +//! let steel_decimal = SteelDecimal::with_variables(variables); +//! +//! let script = "(+ $x $y)"; +//! let transformed = steel_decimal.transform(script); +//! +//! let mut vm = Engine::new(); +//! steel_decimal.register_functions(&mut vm); +//! ``` +//! +//! ### Selective Function Registration +//! ```rust +//! use steel_decimal::FunctionRegistryBuilder; +//! use steel::steel_vm::engine::Engine; +//! +//! let mut vm = Engine::new(); +//! FunctionRegistryBuilder::new() +//! .basic_arithmetic(true) +//! .advanced_math(false) +//! .trigonometric(false) +//! .register(&mut vm); +//! ``` + +pub mod functions; +pub mod parser; +pub mod registry; +pub mod utils; + +pub use functions::*; +pub use parser::ScriptParser; +pub use registry::{FunctionRegistry, FunctionRegistryBuilder}; +pub use utils::{TypeConverter, ScriptAnalyzer, DecimalPrecision, ConversionError}; -use steel::steel_vm::engine::Engine; -use steel::steel_vm::register_fn::RegisterFn; -use steel::rvals::SteelVal; use std::collections::HashMap; -use thiserror::Error; +use steel::steel_vm::engine::Engine; -mod decimal_math; -mod syntax_parser; - -pub use decimal_math::*; -use syntax_parser::*; - -#[derive(Debug, Error)] -pub enum SteelDecimalError { - #[error("Script parsing failed: {0}")] - ParseError(String), - #[error("Script execution failed: {0}")] - RuntimeError(String), - #[error("Type conversion error: {0}")] - TypeConversionError(String), +/// Main entry point for the Steel Decimal library +pub struct SteelDecimal { + parser: ScriptParser, + variables: HashMap, } -#[derive(Clone, Debug)] -pub struct SteelContext { - pub variables: HashMap, - pub current_table: String, -} - -impl SteelContext { - pub fn new(current_table: String) -> Self { - Self { - variables: HashMap::new(), - current_table, - } - } - - pub fn with_variables(current_table: String, variables: HashMap) -> Self { - Self { - variables, - current_table, - } - } - - pub fn add_variable(&mut self, key: String, value: String) { - self.variables.insert(key, value); - } -} - -#[derive(Debug, Clone)] -pub struct SteelResult { - pub result: String, - pub warnings: Vec, -} - -pub struct SteelDecimalEngine { - parser: SyntaxParser, -} - -impl SteelDecimalEngine { +impl SteelDecimal { + /// Create a new SteelDecimal instance pub fn new() -> Self { Self { - parser: SyntaxParser::new(), + parser: ScriptParser::new(), + variables: HashMap::new(), } } - pub fn parse_script(&self, script: &str, context: &SteelContext) -> Result { - Ok(self.parser.parse(script, &context.current_table)) + /// Create a new SteelDecimal instance with variables + pub fn with_variables(variables: HashMap) -> Self { + Self { + parser: ScriptParser::new(), + variables, + } } - pub fn execute_script(&self, script: &str, context: &SteelContext) -> Result { - let transformed_script = self.parser.parse(script, &context.current_table); + /// Transform a script by converting math operations to decimal functions + pub fn transform(&self, script: &str) -> String { + self.parser.transform(script) + } + /// Register all decimal functions with a Steel VM + pub fn register_functions(&self, vm: &mut Engine) { + FunctionRegistry::register_all(vm); + + if !self.variables.is_empty() { + FunctionRegistry::register_variables(vm, self.variables.clone()); + } + } + + /// Register functions selectively using a builder + pub fn register_functions_with_builder(&self, vm: &mut Engine, builder: FunctionRegistryBuilder) { + let builder = if !self.variables.is_empty() { + builder.with_variables(self.variables.clone()) + } else { + builder + }; + + builder.register(vm); + } + + /// Add a variable to the context + pub fn add_variable(&mut self, name: String, value: String) { + self.variables.insert(name, value); + } + + /// Get all variables + pub fn get_variables(&self) -> &HashMap { + &self.variables + } + + /// Extract dependencies from a script + pub fn extract_dependencies(&self, script: &str) -> std::collections::HashSet { + self.parser.extract_dependencies(script) + } + + /// Parse and execute a script in one step (convenience method) + pub fn parse_and_execute(&self, script: &str) -> Result, String> { let mut vm = Engine::new(); - self.register_decimal_functions(&mut vm); - self.register_context_functions(&mut vm, context); - - let results = vm.compile_and_run_raw_program(transformed_script) - .map_err(|e| SteelDecimalError::RuntimeError(e.to_string()))?; - - let result_string = self.convert_result_to_string(results)?; - - Ok(SteelResult { - result: result_string, - warnings: Vec::new(), - }) + self.register_functions(&mut vm); + + let transformed = self.transform(script); + + vm.compile_and_run_raw_program(transformed) + .map_err(|e| e.to_string()) } - fn register_decimal_functions(&self, vm: &mut Engine) { - vm.register_fn("decimal-add", decimal_add); - vm.register_fn("decimal-sub", decimal_sub); - vm.register_fn("decimal-mul", decimal_mul); - vm.register_fn("decimal-div", decimal_div); - vm.register_fn("decimal-pow", decimal_pow); - vm.register_fn("decimal-sqrt", decimal_sqrt); - vm.register_fn("decimal-ln", decimal_ln); - vm.register_fn("decimal-log10", decimal_log10); - vm.register_fn("decimal-exp", decimal_exp); - vm.register_fn("decimal-sin", decimal_sin); - vm.register_fn("decimal-cos", decimal_cos); - vm.register_fn("decimal-tan", decimal_tan); - vm.register_fn("decimal-gt", decimal_gt); - vm.register_fn("decimal-gte", decimal_gte); - vm.register_fn("decimal-lt", decimal_lt); - vm.register_fn("decimal-lte", decimal_lte); - vm.register_fn("decimal-eq", decimal_eq); - vm.register_fn("decimal-abs", decimal_abs); - vm.register_fn("decimal-round", decimal_round); - vm.register_fn("decimal-min", decimal_min); - vm.register_fn("decimal-max", decimal_max); - vm.register_fn("decimal-zero", || "0".to_string()); - vm.register_fn("decimal-one", || "1".to_string()); - vm.register_fn("decimal-pi", || "3.1415926535897932384626433833".to_string()); - vm.register_fn("decimal-e", || "2.7182818284590452353602874714".to_string()); - vm.register_fn("decimal-percentage", decimal_percentage); - vm.register_fn("decimal-compound", decimal_compound); - } - - fn register_context_functions(&self, vm: &mut Engine, context: &SteelContext) { - let variables = context.variables.clone(); - vm.register_fn("get-var", move |var_name: String| -> Result { - variables.get(&var_name) - .cloned() - .ok_or_else(|| format!("Variable '{}' not found", var_name)) - }); - - let variables_check = context.variables.clone(); - vm.register_fn("has-var?", move |var_name: String| -> bool { - variables_check.contains_key(&var_name) - }); - } - - fn convert_result_to_string(&self, results: Vec) -> Result { - if results.is_empty() { - return Ok(String::new()); + /// Validate that a script can be transformed without errors + pub fn validate_script(&self, script: &str) -> Result<(), String> { + // Basic validation - check if transformation succeeds + let transformed = self.transform(script); + + // Check if the script contains balanced parentheses + let open_count = transformed.chars().filter(|c| *c == '(').count(); + let close_count = transformed.chars().filter(|c| *c == ')').count(); + + if open_count != close_count { + return Err("Unbalanced parentheses in script".to_string()); } - - let last_result = &results[results.len() - 1]; - - match last_result { - SteelVal::StringV(s) => Ok(s.to_string()), - SteelVal::NumV(n) => Ok(n.to_string()), - SteelVal::IntV(i) => Ok(i.to_string()), - SteelVal::BoolV(b) => Ok(b.to_string()), - SteelVal::VectorV(v) => { - let string_values: Result, _> = v.iter() - .map(|val| match val { - SteelVal::StringV(s) => Ok(s.to_string()), - SteelVal::NumV(n) => Ok(n.to_string()), - SteelVal::IntV(i) => Ok(i.to_string()), - SteelVal::BoolV(b) => Ok(b.to_string()), - _ => Err(SteelDecimalError::TypeConversionError( - format!("Cannot convert vector element to string: {:?}", val) - )) - }) - .collect(); - - match string_values { - Ok(strings) => Ok(strings.join(",")), - Err(e) => Err(e), - } + + // Check if all variables are defined + let dependencies = self.extract_dependencies(script); + for dep in dependencies { + if !self.variables.contains_key(&dep) { + return Err(format!("Undefined variable: {}", dep)); } - _ => Err(SteelDecimalError::TypeConversionError( - format!("Cannot convert result to string: {:?}", last_result) - )) } + + Ok(()) } } -impl Default for SteelDecimalEngine { +impl Default for SteelDecimal { fn default() -> Self { Self::new() } } -fn decimal_percentage(amount: String, percentage: String) -> Result { - use rust_decimal::prelude::*; - use std::str::FromStr; - - let amount_dec = Decimal::from_str(&amount) - .map_err(|e| format!("Invalid amount: {}", e))?; - let percentage_dec = Decimal::from_str(&percentage) - .map_err(|e| format!("Invalid percentage: {}", e))?; - let hundred = Decimal::from(100); - - Ok((amount_dec * percentage_dec / hundred).to_string()) +/// Convenience functions for quick operations +pub mod prelude { + pub use crate::{SteelDecimal, FunctionRegistry, FunctionRegistryBuilder}; + pub use crate::functions::*; + pub use crate::utils::{TypeConverter, ScriptAnalyzer, DecimalPrecision}; } -fn decimal_compound(principal: String, rate: String, time: String) -> Result { - use rust_decimal::prelude::*; - use rust_decimal::MathematicalOps; - use std::str::FromStr; +#[cfg(test)] +mod tests { + use super::*; - let principal_dec = Decimal::from_str(&principal) - .map_err(|e| format!("Invalid principal: {}", e))?; - let rate_dec = Decimal::from_str(&rate) - .map_err(|e| format!("Invalid rate: {}", e))?; - let time_dec = Decimal::from_str(&time) - .map_err(|e| format!("Invalid time: {}", e))?; + #[test] + fn test_basic_transformation() { + let steel_decimal = SteelDecimal::new(); + let script = "(+ 1.5 2.3)"; + let transformed = steel_decimal.transform(script); + assert_eq!(transformed, "(decimal-add \"1.5\" \"2.3\")"); + } - let one = Decimal::ONE; - let compound_factor = (one + rate_dec).checked_powd(time_dec) - .ok_or("Compound calculation overflow")?; + #[test] + fn test_with_variables() { + let mut variables = HashMap::new(); + variables.insert("x".to_string(), "10.5".to_string()); + variables.insert("y".to_string(), "20.3".to_string()); - Ok((principal_dec * compound_factor).to_string()) + let steel_decimal = SteelDecimal::with_variables(variables); + let script = "(+ $x $y)"; + let transformed = steel_decimal.transform(script); + assert_eq!(transformed, "(decimal-add (get-var \"x\") (get-var \"y\"))"); + } + + #[test] + fn test_function_registration() { + let steel_decimal = SteelDecimal::new(); + let mut vm = Engine::new(); + + steel_decimal.register_functions(&mut vm); + + let script = "(decimal-add \"1.5\" \"2.3\")"; + let result = vm.compile_and_run_raw_program(script.to_string()); + assert!(result.is_ok()); + } + + #[test] + fn test_parse_and_execute() { + let steel_decimal = SteelDecimal::new(); + let script = "(+ 1.5 2.3)"; + + let result = steel_decimal.parse_and_execute(script); + assert!(result.is_ok()); + } + + #[test] + fn test_script_validation() { + let steel_decimal = SteelDecimal::new(); + + // Valid script + assert!(steel_decimal.validate_script("(+ 1.5 2.3)").is_ok()); + + // Invalid script - unbalanced parentheses + assert!(steel_decimal.validate_script("(+ 1.5 2.3").is_err()); + } + + #[test] + fn test_variable_validation() { + let steel_decimal = SteelDecimal::new(); + + // Script with undefined variable + assert!(steel_decimal.validate_script("(+ $x 2.3)").is_err()); + + // Script with defined variable + let mut variables = HashMap::new(); + variables.insert("x".to_string(), "10.5".to_string()); + let steel_decimal = SteelDecimal::with_variables(variables); + assert!(steel_decimal.validate_script("(+ $x 2.3)").is_ok()); + } + + #[test] + fn test_complex_expressions() { + let steel_decimal = SteelDecimal::new(); + + let script = "(+ (* 2.5 3.0) (/ 15.0 3.0))"; + let transformed = steel_decimal.transform(script); + let expected = "(decimal-add (decimal-mul \"2.5\" \"3.0\") (decimal-div \"15.0\" \"3.0\"))"; + assert_eq!(transformed, expected); + } + + #[test] + fn test_dependency_extraction() { + let steel_decimal = SteelDecimal::new(); + let script = "(+ $x $y $z)"; + + let dependencies = steel_decimal.extract_dependencies(script); + assert_eq!(dependencies.len(), 3); + assert!(dependencies.contains("x")); + assert!(dependencies.contains("y")); + assert!(dependencies.contains("z")); + } } diff --git a/steel_decimal/src/parser.rs b/steel_decimal/src/parser.rs new file mode 100644 index 0000000..67ccd7a --- /dev/null +++ b/steel_decimal/src/parser.rs @@ -0,0 +1,176 @@ +// src/parser.rs +use regex::Regex; +use std::collections::HashSet; + +pub struct ScriptParser { + math_operators: Vec<(Regex, &'static str)>, + number_literal_re: Regex, + variable_re: Regex, +} + +impl ScriptParser { + pub fn new() -> Self { + let math_operators = vec![ + // Basic arithmetic + (Regex::new(r"\(\s*\+\s+").unwrap(), "(decimal-add "), + (Regex::new(r"\(\s*-\s+").unwrap(), "(decimal-sub "), + (Regex::new(r"\(\s*\*\s+").unwrap(), "(decimal-mul "), + (Regex::new(r"\(\s*/\s+").unwrap(), "(decimal-div "), + + // Power and advanced operations + (Regex::new(r"\(\s*\^\s+").unwrap(), "(decimal-pow "), + (Regex::new(r"\(\s*\*\*\s+").unwrap(), "(decimal-pow "), + (Regex::new(r"\(\s*pow\s+").unwrap(), "(decimal-pow "), + (Regex::new(r"\(\s*sqrt\s+").unwrap(), "(decimal-sqrt "), + + // Logarithmic functions + (Regex::new(r"\(\s*ln\s+").unwrap(), "(decimal-ln "), + (Regex::new(r"\(\s*log\s+").unwrap(), "(decimal-ln "), + (Regex::new(r"\(\s*log10\s+").unwrap(), "(decimal-log10 "), + (Regex::new(r"\(\s*exp\s+").unwrap(), "(decimal-exp "), + + // Trigonometric functions + (Regex::new(r"\(\s*sin\s+").unwrap(), "(decimal-sin "), + (Regex::new(r"\(\s*cos\s+").unwrap(), "(decimal-cos "), + (Regex::new(r"\(\s*tan\s+").unwrap(), "(decimal-tan "), + + // Comparison operators + (Regex::new(r"\(\s*>\s+").unwrap(), "(decimal-gt "), + (Regex::new(r"\(\s*<\s+").unwrap(), "(decimal-lt "), + (Regex::new(r"\(\s*=\s+").unwrap(), "(decimal-eq "), + (Regex::new(r"\(\s*>=\s+").unwrap(), "(decimal-gte "), + (Regex::new(r"\(\s*<=\s+").unwrap(), "(decimal-lte "), + + // Utility functions + (Regex::new(r"\(\s*abs\s+").unwrap(), "(decimal-abs "), + (Regex::new(r"\(\s*min\s+").unwrap(), "(decimal-min "), + (Regex::new(r"\(\s*max\s+").unwrap(), "(decimal-max "), + (Regex::new(r"\(\s*round\s+").unwrap(), "(decimal-round "), + ]; + + ScriptParser { + math_operators, + number_literal_re: Regex::new(r#"(? String { + let mut transformed = script.to_string(); + + // Step 1: Convert numeric literals to strings + transformed = self.convert_numbers_to_strings(&transformed); + + // Step 2: Replace math function calls with decimal equivalents + transformed = self.replace_math_functions(&transformed); + + // Step 3: Replace variable references + transformed = self.replace_variable_references(&transformed); + + transformed + } + + /// Convert all unquoted numeric literals to quoted strings + fn convert_numbers_to_strings(&self, script: &str) -> String { + self.number_literal_re.replace_all(script, |caps: ®ex::Captures| { + format!("\"{}\"", &caps[1]) + }).to_string() + } + + /// Replace math function calls with decimal equivalents + fn replace_math_functions(&self, script: &str) -> String { + let mut result = script.to_string(); + + for (pattern, replacement) in &self.math_operators { + result = pattern.replace_all(&result, *replacement).to_string(); + } + + result + } + + /// Replace variable references ($var) with function calls + fn replace_variable_references(&self, script: &str) -> String { + self.variable_re.replace_all(script, |caps: ®ex::Captures| { + format!("(get-var \"{}\")", &caps[1]) + }).to_string() + } + + /// Extract dependencies from script (useful for analysis) + pub fn extract_dependencies(&self, script: &str) -> HashSet { + let mut dependencies = HashSet::new(); + + // Extract variable dependencies + for cap in self.variable_re.captures_iter(script) { + dependencies.insert(cap[1].to_string()); + } + + dependencies + } +} + +impl Default for ScriptParser { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic_math_transformation() { + let parser = ScriptParser::new(); + + let input = "(+ 1.5 2.3)"; + let expected = "(decimal-add \"1.5\" \"2.3\")"; + let result = parser.transform(input); + + assert_eq!(result, expected); + } + + #[test] + fn test_complex_expression() { + let parser = ScriptParser::new(); + + let input = "(+ (* 2 3) (/ 10 2))"; + let expected = "(decimal-add (decimal-mul \"2\" \"3\") (decimal-div \"10\" \"2\"))"; + let result = parser.transform(input); + + assert_eq!(result, expected); + } + + #[test] + fn test_variable_replacement() { + let parser = ScriptParser::new(); + + let input = "(+ $x $y)"; + let expected = "(decimal-add (get-var \"x\") (get-var \"y\"))"; + let result = parser.transform(input); + + assert_eq!(result, expected); + } + + #[test] + fn test_negative_numbers() { + let parser = ScriptParser::new(); + + let input = "(+ -1.5 2.3)"; + let expected = "(decimal-add \"-1.5\" \"2.3\")"; + let result = parser.transform(input); + + assert_eq!(result, expected); + } + + #[test] + fn test_scientific_notation() { + let parser = ScriptParser::new(); + + let input = "(+ 1.5e2 2.3E-1)"; + let expected = "(decimal-add \"1.5e2\" \"2.3E-1\")"; + let result = parser.transform(input); + + assert_eq!(result, expected); + } +} diff --git a/steel_decimal/src/registry.rs b/steel_decimal/src/registry.rs new file mode 100644 index 0000000..77bd0f4 --- /dev/null +++ b/steel_decimal/src/registry.rs @@ -0,0 +1,281 @@ +// src/registry.rs +use steel::steel_vm::engine::Engine; +use steel::steel_vm::register_fn::RegisterFn; +use crate::functions::*; +use std::collections::HashMap; + +pub struct FunctionRegistry; + +impl FunctionRegistry { + /// Register all decimal math functions with the Steel VM + pub fn register_all(vm: &mut Engine) { + Self::register_basic_arithmetic(vm); + Self::register_advanced_math(vm); + Self::register_trigonometric(vm); + Self::register_comparison(vm); + Self::register_utility(vm); + Self::register_constants(vm); + Self::register_financial(vm); + Self::register_conversion(vm); + } + + /// Register basic arithmetic functions + pub fn register_basic_arithmetic(vm: &mut Engine) { + vm.register_fn("decimal-add", decimal_add); + vm.register_fn("decimal-sub", decimal_sub); + vm.register_fn("decimal-mul", decimal_mul); + vm.register_fn("decimal-div", decimal_div); + } + + /// Register advanced mathematical functions + pub fn register_advanced_math(vm: &mut Engine) { + vm.register_fn("decimal-pow", decimal_pow); + vm.register_fn("decimal-sqrt", decimal_sqrt); + vm.register_fn("decimal-ln", decimal_ln); + vm.register_fn("decimal-log10", decimal_log10); + vm.register_fn("decimal-exp", decimal_exp); + } + + /// Register trigonometric functions + pub fn register_trigonometric(vm: &mut Engine) { + vm.register_fn("decimal-sin", decimal_sin); + vm.register_fn("decimal-cos", decimal_cos); + vm.register_fn("decimal-tan", decimal_tan); + } + + /// Register comparison functions + pub fn register_comparison(vm: &mut Engine) { + vm.register_fn("decimal-gt", decimal_gt); + vm.register_fn("decimal-gte", decimal_gte); + vm.register_fn("decimal-lt", decimal_lt); + vm.register_fn("decimal-lte", decimal_lte); + vm.register_fn("decimal-eq", decimal_eq); + } + + /// Register utility functions + pub fn register_utility(vm: &mut Engine) { + vm.register_fn("decimal-abs", decimal_abs); + vm.register_fn("decimal-round", decimal_round); + vm.register_fn("decimal-min", decimal_min); + vm.register_fn("decimal-max", decimal_max); + } + + /// Register mathematical constants + pub fn register_constants(vm: &mut Engine) { + vm.register_fn("decimal-zero", decimal_zero); + vm.register_fn("decimal-one", decimal_one); + vm.register_fn("decimal-pi", decimal_pi); + vm.register_fn("decimal-e", decimal_e); + } + + /// Register financial functions + pub fn register_financial(vm: &mut Engine) { + vm.register_fn("decimal-percentage", decimal_percentage); + vm.register_fn("decimal-compound", decimal_compound); + } + + /// Register type conversion functions + pub fn register_conversion(vm: &mut Engine) { + vm.register_fn("to-decimal", to_decimal); + } + + /// Register variable access functions + pub fn register_variables(vm: &mut Engine, variables: HashMap) { + let variables_for_get = variables.clone(); + let variables_for_has = variables; + + vm.register_fn("get-var", move |var_name: String| -> Result { + variables_for_get.get(&var_name) + .cloned() + .ok_or_else(|| format!("Variable '{}' not found", var_name)) + }); + + vm.register_fn("has-var?", move |var_name: String| -> bool { + variables_for_has.contains_key(&var_name) + }); + } + + /// Get a list of all registered function names + pub fn get_function_names() -> Vec<&'static str> { + vec![ + // Basic arithmetic + "decimal-add", "decimal-sub", "decimal-mul", "decimal-div", + // Advanced math + "decimal-pow", "decimal-sqrt", "decimal-ln", "decimal-log10", "decimal-exp", + // Trigonometric + "decimal-sin", "decimal-cos", "decimal-tan", + // Comparison + "decimal-gt", "decimal-gte", "decimal-lt", "decimal-lte", "decimal-eq", + // Utility + "decimal-abs", "decimal-round", "decimal-min", "decimal-max", + // Constants + "decimal-zero", "decimal-one", "decimal-pi", "decimal-e", + // Financial + "decimal-percentage", "decimal-compound", + // Conversion + "to-decimal", + // Variables + "get-var", "has-var?", + ] + } +} + +/// Builder pattern for selective function registration +pub struct FunctionRegistryBuilder { + include_basic: bool, + include_advanced: bool, + include_trigonometric: bool, + include_comparison: bool, + include_utility: bool, + include_constants: bool, + include_financial: bool, + include_conversion: bool, + variables: Option>, +} + +impl FunctionRegistryBuilder { + pub fn new() -> Self { + Self { + include_basic: true, + include_advanced: true, + include_trigonometric: true, + include_comparison: true, + include_utility: true, + include_constants: true, + include_financial: true, + include_conversion: true, + variables: None, + } + } + + pub fn basic_arithmetic(mut self, include: bool) -> Self { + self.include_basic = include; + self + } + + pub fn advanced_math(mut self, include: bool) -> Self { + self.include_advanced = include; + self + } + + pub fn trigonometric(mut self, include: bool) -> Self { + self.include_trigonometric = include; + self + } + + pub fn comparison(mut self, include: bool) -> Self { + self.include_comparison = include; + self + } + + pub fn utility(mut self, include: bool) -> Self { + self.include_utility = include; + self + } + + pub fn constants(mut self, include: bool) -> Self { + self.include_constants = include; + self + } + + pub fn financial(mut self, include: bool) -> Self { + self.include_financial = include; + self + } + + pub fn conversion(mut self, include: bool) -> Self { + self.include_conversion = include; + self + } + + pub fn with_variables(mut self, variables: HashMap) -> Self { + self.variables = Some(variables); + self + } + + pub fn register(self, vm: &mut Engine) { + if self.include_basic { + FunctionRegistry::register_basic_arithmetic(vm); + } + if self.include_advanced { + FunctionRegistry::register_advanced_math(vm); + } + if self.include_trigonometric { + FunctionRegistry::register_trigonometric(vm); + } + if self.include_comparison { + FunctionRegistry::register_comparison(vm); + } + if self.include_utility { + FunctionRegistry::register_utility(vm); + } + if self.include_constants { + FunctionRegistry::register_constants(vm); + } + if self.include_financial { + FunctionRegistry::register_financial(vm); + } + if self.include_conversion { + FunctionRegistry::register_conversion(vm); + } + if let Some(variables) = self.variables { + FunctionRegistry::register_variables(vm, variables); + } + } +} + +impl Default for FunctionRegistryBuilder { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_function_registration() { + let mut vm = Engine::new(); + FunctionRegistry::register_all(&mut vm); + + // Test that functions are registered by running a simple script + let script = r#"(decimal-add "1.5" "2.3")"#; + let result = vm.compile_and_run_raw_program(script.to_string()); + + assert!(result.is_ok()); + } + + #[test] + fn test_selective_registration() { + let mut vm = Engine::new(); + + FunctionRegistryBuilder::new() + .basic_arithmetic(true) + .advanced_math(false) + .trigonometric(false) + .register(&mut vm); + + // Test that basic functions work + let script = r#"(decimal-add "1.5" "2.3")"#; + let result = vm.compile_and_run_raw_program(script.to_string()); + assert!(result.is_ok()); + } + + #[test] + fn test_variable_registration() { + let mut vm = Engine::new(); + let mut variables = HashMap::new(); + variables.insert("x".to_string(), "10.5".to_string()); + variables.insert("y".to_string(), "20.3".to_string()); + + FunctionRegistryBuilder::new() + .with_variables(variables) + .register(&mut vm); + + // Test variable access + let script = r#"(get-var "x")"#; + let result = vm.compile_and_run_raw_program(script.to_string()); + assert!(result.is_ok()); + } +} diff --git a/steel_decimal/src/syntax_parser.rs b/steel_decimal/src/syntax_parser.rs deleted file mode 100644 index 4f05076..0000000 --- a/steel_decimal/src/syntax_parser.rs +++ /dev/null @@ -1,83 +0,0 @@ -// src/syntax_parser.rs - -use regex::Regex; - -pub struct SyntaxParser { - math_operators: Vec<(Regex, &'static str)>, - number_literal_re: Regex, -} - -impl SyntaxParser { - pub fn new() -> Self { - let math_operators = vec![ - (Regex::new(r"\(\s*\+\s+").unwrap(), "(decimal-add "), - (Regex::new(r"\(\s*-\s+").unwrap(), "(decimal-sub "), - (Regex::new(r"\(\s*\*\s+").unwrap(), "(decimal-mul "), - (Regex::new(r"\(\s*/\s+").unwrap(), "(decimal-div "), - (Regex::new(r"\(\s*\^\s+").unwrap(), "(decimal-pow "), - (Regex::new(r"\(\s*\*\*\s+").unwrap(), "(decimal-pow "), - (Regex::new(r"\(\s*pow\s+").unwrap(), "(decimal-pow "), - (Regex::new(r"\(\s*sqrt\s+").unwrap(), "(decimal-sqrt "), - (Regex::new(r"\(\s*ln\s+").unwrap(), "(decimal-ln "), - (Regex::new(r"\(\s*log\s+").unwrap(), "(decimal-ln "), - (Regex::new(r"\(\s*log10\s+").unwrap(), "(decimal-log10 "), - (Regex::new(r"\(\s*exp\s+").unwrap(), "(decimal-exp "), - (Regex::new(r"\(\s*sin\s+").unwrap(), "(decimal-sin "), - (Regex::new(r"\(\s*cos\s+").unwrap(), "(decimal-cos "), - (Regex::new(r"\(\s*tan\s+").unwrap(), "(decimal-tan "), - (Regex::new(r"\(\s*>\s+").unwrap(), "(decimal-gt "), - (Regex::new(r"\(\s*<\s+").unwrap(), "(decimal-lt "), - (Regex::new(r"\(\s*=\s+").unwrap(), "(decimal-eq "), - (Regex::new(r"\(\s*>=\s+").unwrap(), "(decimal-gte "), - (Regex::new(r"\(\s*<=\s+").unwrap(), "(decimal-lte "), - (Regex::new(r"\(\s*abs\s+").unwrap(), "(decimal-abs "), - (Regex::new(r"\(\s*min\s+").unwrap(), "(decimal-min "), - (Regex::new(r"\(\s*max\s+").unwrap(), "(decimal-max "), - (Regex::new(r"\(\s*round\s+").unwrap(), "(decimal-round "), - ]; - - SyntaxParser { - number_literal_re: Regex::new(r#"(? String { - let mut transformed = script.to_string(); - - transformed = self.convert_numbers_to_strings(&transformed); - transformed = self.replace_math_functions(&transformed); - transformed = self.replace_variable_references(&transformed); - - transformed - } - - fn convert_numbers_to_strings(&self, script: &str) -> String { - self.number_literal_re.replace_all(script, |caps: ®ex::Captures| { - format!("\"{}\"", &caps[1]) - }).to_string() - } - - fn replace_math_functions(&self, script: &str) -> String { - let mut result = script.to_string(); - - for (pattern, replacement) in &self.math_operators { - result = pattern.replace_all(&result, *replacement).to_string(); - } - - result - } - - fn replace_variable_references(&self, script: &str) -> String { - let var_re = Regex::new(r"\$(\w+)").unwrap(); - var_re.replace_all(script, |caps: ®ex::Captures| { - format!("(get-var \"{}\")", &caps[1]) - }).to_string() - } -} - -impl Default for SyntaxParser { - fn default() -> Self { - Self::new() - } -} diff --git a/steel_decimal/src/utils.rs b/steel_decimal/src/utils.rs new file mode 100644 index 0000000..63b7547 --- /dev/null +++ b/steel_decimal/src/utils.rs @@ -0,0 +1,227 @@ +// src/utils.rs +use steel::rvals::SteelVal; +use rust_decimal::prelude::*; +use std::str::FromStr; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ConversionError { + #[error("Invalid decimal format: {0}")] + InvalidDecimal(String), + #[error("Unsupported SteelVal type: {0:?}")] + UnsupportedType(SteelVal), + #[error("Type conversion failed: {0}")] + ConversionFailed(String), +} + +/// Utility functions for converting between Rust types and Steel values +pub struct TypeConverter; + +impl TypeConverter { + /// Convert a SteelVal to a Decimal + pub fn steel_val_to_decimal(val: &SteelVal) -> Result { + match val { + SteelVal::StringV(s) => { + Decimal::from_str(&s.to_string()) + .map_err(|e| ConversionError::InvalidDecimal(format!("{}: {}", s, e))) + } + SteelVal::NumV(n) => { + Decimal::try_from(*n) + .map_err(|e| ConversionError::InvalidDecimal(format!("{}: {}", n, e))) + } + SteelVal::IntV(i) => { + Ok(Decimal::from(*i)) + } + _ => Err(ConversionError::UnsupportedType(val.clone())) + } + } + + /// Convert a Decimal to a SteelVal string + pub fn decimal_to_steel_val(decimal: Decimal) -> SteelVal { + SteelVal::StringV(decimal.to_string().into()) + } + + /// Convert multiple SteelVals to strings + pub fn steel_vals_to_strings(vals: Vec) -> Result, ConversionError> { + vals.into_iter() + .map(|val| Self::steel_val_to_string(val)) + .collect() + } + + /// Convert a single SteelVal to a string + pub fn steel_val_to_string(val: SteelVal) -> Result { + match val { + SteelVal::StringV(s) => Ok(s.to_string()), + SteelVal::NumV(n) => Ok(n.to_string()), + SteelVal::IntV(i) => Ok(i.to_string()), + SteelVal::BoolV(b) => Ok(b.to_string()), + SteelVal::VectorV(v) => { + let string_values: Result, _> = v.iter() + .map(|item| Self::steel_val_to_string(item.clone())) + .collect(); + + match string_values { + Ok(strings) => Ok(strings.join(",")), + Err(e) => Err(e), + } + } + _ => Err(ConversionError::UnsupportedType(val)) + } + } + + /// Convert a string to a validated decimal string + pub fn validate_decimal_string(s: &str) -> Result { + Decimal::from_str(s) + .map(|d| d.to_string()) + .map_err(|e| ConversionError::InvalidDecimal(format!("{}: {}", s, e))) + } + + /// Convert common Rust types to decimal strings + pub fn f64_to_decimal_string(f: f64) -> Result { + Decimal::try_from(f) + .map(|d| d.to_string()) + .map_err(|e| ConversionError::ConversionFailed(format!("f64 to decimal: {}", e))) + } + + pub fn i64_to_decimal_string(i: i64) -> String { + Decimal::from(i).to_string() + } + + pub fn u64_to_decimal_string(u: u64) -> String { + Decimal::from(u).to_string() + } +} + +/// Utility functions for script validation and analysis +pub struct ScriptAnalyzer; + +impl ScriptAnalyzer { + /// Check if a string looks like a valid decimal number + pub fn is_decimal_like(s: &str) -> bool { + Decimal::from_str(s.trim()).is_ok() + } + + /// Extract all string literals from a script (simple regex-based approach) + pub fn extract_string_literals(script: &str) -> Vec { + let re = regex::Regex::new(r#""([^"]*)"#).unwrap(); + re.captures_iter(script) + .map(|cap| cap[1].to_string()) + .collect() + } + + /// Count function calls in a script + pub fn count_function_calls(script: &str, function_name: &str) -> usize { + let pattern = format!(r"\({}\s+", regex::escape(function_name)); + let re = regex::Regex::new(&pattern).unwrap(); + re.find_iter(script).count() + } + + /// Check if script contains any decimal functions + pub fn contains_decimal_functions(script: &str) -> bool { + let decimal_functions = [ + "decimal-add", "decimal-sub", "decimal-mul", "decimal-div", + "decimal-pow", "decimal-sqrt", "decimal-ln", "decimal-log10", "decimal-exp", + "decimal-sin", "decimal-cos", "decimal-tan", + "decimal-gt", "decimal-gte", "decimal-lt", "decimal-lte", "decimal-eq", + "decimal-abs", "decimal-round", "decimal-min", "decimal-max", + ]; + + decimal_functions.iter().any(|func| script.contains(func)) + } +} + +/// Utility functions for working with decimal precision +pub struct DecimalPrecision; + +impl DecimalPrecision { + /// Set precision for a decimal string + pub fn set_precision(decimal_str: &str, precision: u32) -> Result { + let decimal = Decimal::from_str(decimal_str) + .map_err(|e| ConversionError::InvalidDecimal(format!("{}: {}", decimal_str, e)))?; + + Ok(decimal.round_dp(precision).to_string()) + } + + /// Get the number of decimal places in a decimal string + pub fn get_decimal_places(decimal_str: &str) -> Result { + let decimal = Decimal::from_str(decimal_str) + .map_err(|e| ConversionError::InvalidDecimal(format!("{}: {}", decimal_str, e)))?; + + Ok(decimal.scale()) + } + + /// Normalize decimal string (remove trailing zeros) + pub fn normalize(decimal_str: &str) -> Result { + let decimal = Decimal::from_str(decimal_str) + .map_err(|e| ConversionError::InvalidDecimal(format!("{}: {}", decimal_str, e)))?; + + Ok(decimal.normalize().to_string()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_steel_val_to_decimal() { + let string_val = SteelVal::StringV("123.456".into()); + let result = TypeConverter::steel_val_to_decimal(&string_val); + assert!(result.is_ok()); + assert_eq!(result.unwrap().to_string(), "123.456"); + + let int_val = SteelVal::IntV(42); + let result = TypeConverter::steel_val_to_decimal(&int_val); + assert!(result.is_ok()); + assert_eq!(result.unwrap().to_string(), "42"); + } + + #[test] + fn test_decimal_to_steel_val() { + let decimal = Decimal::from_str("123.456").unwrap(); + let result = TypeConverter::decimal_to_steel_val(decimal); + + if let SteelVal::StringV(s) = result { + assert_eq!(s.to_string(), "123.456"); + } else { + panic!("Expected StringV"); + } + } + + #[test] + fn test_validate_decimal_string() { + assert!(TypeConverter::validate_decimal_string("123.456").is_ok()); + assert!(TypeConverter::validate_decimal_string("invalid").is_err()); + } + + #[test] + fn test_script_analyzer() { + let script = r#"(decimal-add "1.5" "2.3")"#; + + assert!(ScriptAnalyzer::contains_decimal_functions(script)); + assert_eq!(ScriptAnalyzer::count_function_calls(script, "decimal-add"), 1); + + let literals = ScriptAnalyzer::extract_string_literals(script); + assert_eq!(literals, vec!["1.5", "2.3"]); + } + + #[test] + fn test_decimal_precision() { + let result = DecimalPrecision::set_precision("123.456789", 2); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "123.46"); + + let places = DecimalPrecision::get_decimal_places("123.456"); + assert!(places.is_ok()); + assert_eq!(places.unwrap(), 3); + } + + #[test] + fn test_type_conversions() { + assert_eq!(TypeConverter::i64_to_decimal_string(42), "42"); + assert_eq!(TypeConverter::u64_to_decimal_string(42), "42"); + + let f64_result = TypeConverter::f64_to_decimal_string(123.456); + assert!(f64_result.is_ok()); + } +}