From 4c57b562e679a2b468495767f37527e4ca5403d3 Mon Sep 17 00:00:00 2001 From: filipriec Date: Sat, 5 Jul 2025 10:00:04 +0200 Subject: [PATCH] more fixes, not there yet tho --- Cargo.lock | 2 + steel_decimal/Cargo.toml | 3 + steel_decimal/src/lib.rs | 93 ------ steel_decimal/src/parser.rs | 60 ---- steel_decimal/src/registry.rs | 50 ---- steel_decimal/src/utils.rs | 67 ----- steel_decimal/tests/function_tests.rs | 239 +++++++++++++++ steel_decimal/tests/integration_tests.rs | 310 ++++++++++++++++++++ steel_decimal/tests/parser_tests.rs | 158 ++++++++++ steel_decimal/tests/registry_tests.rs | 352 +++++++++++++++++++++++ steel_decimal/tests/utils_tests.rs | 343 ++++++++++++++++++++++ 11 files changed, 1407 insertions(+), 270 deletions(-) create mode 100644 steel_decimal/tests/function_tests.rs create mode 100644 steel_decimal/tests/integration_tests.rs create mode 100644 steel_decimal/tests/parser_tests.rs create mode 100644 steel_decimal/tests/registry_tests.rs create mode 100644 steel_decimal/tests/utils_tests.rs diff --git a/Cargo.lock b/Cargo.lock index 26779ef..a7b2204 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3575,11 +3575,13 @@ name = "steel_decimal" version = "0.3.13" dependencies = [ "regex", + "rstest", "rust_decimal", "rust_decimal_macros", "steel-core", "steel-derive 0.5.0 (git+https://github.com/mattwparas/steel.git?branch=master)", "thiserror 2.0.12", + "tokio-test", ] [[package]] diff --git a/steel_decimal/Cargo.toml b/steel_decimal/Cargo.toml index edefd9c..2039d76 100644 --- a/steel_decimal/Cargo.toml +++ b/steel_decimal/Cargo.toml @@ -18,3 +18,6 @@ rust_decimal_macros = { workspace = true } regex = { workspace = true } thiserror = { workspace = true } +[dev-dependencies] +rstest = "0.25.0" +tokio-test = "0.4.4" diff --git a/steel_decimal/src/lib.rs b/steel_decimal/src/lib.rs index 063a140..9c382d0 100644 --- a/steel_decimal/src/lib.rs +++ b/steel_decimal/src/lib.rs @@ -184,96 +184,3 @@ pub mod prelude { pub use crate::functions::*; pub use crate::utils::{TypeConverter, ScriptAnalyzer, DecimalPrecision}; } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_basic_transformation() { - let steel_decimal = SteelDecimal::new(); - let script = "(+ 1.5 2.3)"; - let transformed = steel_decimal.transform(script); - assert_eq!(transformed, "(decimal-add \"1.5\" \"2.3\")"); - } - - #[test] - fn test_with_variables() { - let mut variables = HashMap::new(); - variables.insert("x".to_string(), "10.5".to_string()); - variables.insert("y".to_string(), "20.3".to_string()); - - let steel_decimal = SteelDecimal::with_variables(variables); - let script = "(+ $x $y)"; - let transformed = steel_decimal.transform(script); - assert_eq!(transformed, "(decimal-add (get-var \"x\") (get-var \"y\"))"); - } - - #[test] - fn test_function_registration() { - let steel_decimal = SteelDecimal::new(); - let mut vm = Engine::new(); - - steel_decimal.register_functions(&mut vm); - - let script = "(decimal-add \"1.5\" \"2.3\")"; - let result = vm.compile_and_run_raw_program(script.to_string()); - assert!(result.is_ok()); - } - - #[test] - fn test_parse_and_execute() { - let steel_decimal = SteelDecimal::new(); - let script = "(+ 1.5 2.3)"; - - let result = steel_decimal.parse_and_execute(script); - assert!(result.is_ok()); - } - - #[test] - fn test_script_validation() { - let steel_decimal = SteelDecimal::new(); - - // Valid script - assert!(steel_decimal.validate_script("(+ 1.5 2.3)").is_ok()); - - // Invalid script - unbalanced parentheses - assert!(steel_decimal.validate_script("(+ 1.5 2.3").is_err()); - } - - #[test] - fn test_variable_validation() { - let steel_decimal = SteelDecimal::new(); - - // Script with undefined variable - assert!(steel_decimal.validate_script("(+ $x 2.3)").is_err()); - - // Script with defined variable - let mut variables = HashMap::new(); - variables.insert("x".to_string(), "10.5".to_string()); - let steel_decimal = SteelDecimal::with_variables(variables); - assert!(steel_decimal.validate_script("(+ $x 2.3)").is_ok()); - } - - #[test] - fn test_complex_expressions() { - let steel_decimal = SteelDecimal::new(); - - let script = "(+ (* 2.5 3.0) (/ 15.0 3.0))"; - let transformed = steel_decimal.transform(script); - let expected = "(decimal-add (decimal-mul \"2.5\" \"3.0\") (decimal-div \"15.0\" \"3.0\"))"; - assert_eq!(transformed, expected); - } - - #[test] - fn test_dependency_extraction() { - let steel_decimal = SteelDecimal::new(); - let script = "(+ $x $y $z)"; - - let dependencies = steel_decimal.extract_dependencies(script); - assert_eq!(dependencies.len(), 3); - assert!(dependencies.contains("x")); - assert!(dependencies.contains("y")); - assert!(dependencies.contains("z")); - } -} diff --git a/steel_decimal/src/parser.rs b/steel_decimal/src/parser.rs index 67ccd7a..02daa59 100644 --- a/steel_decimal/src/parser.rs +++ b/steel_decimal/src/parser.rs @@ -114,63 +114,3 @@ impl Default for ScriptParser { Self::new() } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_basic_math_transformation() { - let parser = ScriptParser::new(); - - let input = "(+ 1.5 2.3)"; - let expected = "(decimal-add \"1.5\" \"2.3\")"; - let result = parser.transform(input); - - assert_eq!(result, expected); - } - - #[test] - fn test_complex_expression() { - let parser = ScriptParser::new(); - - let input = "(+ (* 2 3) (/ 10 2))"; - let expected = "(decimal-add (decimal-mul \"2\" \"3\") (decimal-div \"10\" \"2\"))"; - let result = parser.transform(input); - - assert_eq!(result, expected); - } - - #[test] - fn test_variable_replacement() { - let parser = ScriptParser::new(); - - let input = "(+ $x $y)"; - let expected = "(decimal-add (get-var \"x\") (get-var \"y\"))"; - let result = parser.transform(input); - - assert_eq!(result, expected); - } - - #[test] - fn test_negative_numbers() { - let parser = ScriptParser::new(); - - let input = "(+ -1.5 2.3)"; - let expected = "(decimal-add \"-1.5\" \"2.3\")"; - let result = parser.transform(input); - - assert_eq!(result, expected); - } - - #[test] - fn test_scientific_notation() { - let parser = ScriptParser::new(); - - let input = "(+ 1.5e2 2.3E-1)"; - let expected = "(decimal-add \"1.5e2\" \"2.3E-1\")"; - let result = parser.transform(input); - - assert_eq!(result, expected); - } -} diff --git a/steel_decimal/src/registry.rs b/steel_decimal/src/registry.rs index 77bd0f4..7484e0b 100644 --- a/steel_decimal/src/registry.rs +++ b/steel_decimal/src/registry.rs @@ -229,53 +229,3 @@ impl Default for FunctionRegistryBuilder { Self::new() } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_function_registration() { - let mut vm = Engine::new(); - FunctionRegistry::register_all(&mut vm); - - // Test that functions are registered by running a simple script - let script = r#"(decimal-add "1.5" "2.3")"#; - let result = vm.compile_and_run_raw_program(script.to_string()); - - assert!(result.is_ok()); - } - - #[test] - fn test_selective_registration() { - let mut vm = Engine::new(); - - FunctionRegistryBuilder::new() - .basic_arithmetic(true) - .advanced_math(false) - .trigonometric(false) - .register(&mut vm); - - // Test that basic functions work - let script = r#"(decimal-add "1.5" "2.3")"#; - let result = vm.compile_and_run_raw_program(script.to_string()); - assert!(result.is_ok()); - } - - #[test] - fn test_variable_registration() { - let mut vm = Engine::new(); - let mut variables = HashMap::new(); - variables.insert("x".to_string(), "10.5".to_string()); - variables.insert("y".to_string(), "20.3".to_string()); - - FunctionRegistryBuilder::new() - .with_variables(variables) - .register(&mut vm); - - // Test variable access - let script = r#"(get-var "x")"#; - let result = vm.compile_and_run_raw_program(script.to_string()); - assert!(result.is_ok()); - } -} diff --git a/steel_decimal/src/utils.rs b/steel_decimal/src/utils.rs index 63b7547..54af289 100644 --- a/steel_decimal/src/utils.rs +++ b/steel_decimal/src/utils.rs @@ -158,70 +158,3 @@ impl DecimalPrecision { Ok(decimal.normalize().to_string()) } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_steel_val_to_decimal() { - let string_val = SteelVal::StringV("123.456".into()); - let result = TypeConverter::steel_val_to_decimal(&string_val); - assert!(result.is_ok()); - assert_eq!(result.unwrap().to_string(), "123.456"); - - let int_val = SteelVal::IntV(42); - let result = TypeConverter::steel_val_to_decimal(&int_val); - assert!(result.is_ok()); - assert_eq!(result.unwrap().to_string(), "42"); - } - - #[test] - fn test_decimal_to_steel_val() { - let decimal = Decimal::from_str("123.456").unwrap(); - let result = TypeConverter::decimal_to_steel_val(decimal); - - if let SteelVal::StringV(s) = result { - assert_eq!(s.to_string(), "123.456"); - } else { - panic!("Expected StringV"); - } - } - - #[test] - fn test_validate_decimal_string() { - assert!(TypeConverter::validate_decimal_string("123.456").is_ok()); - assert!(TypeConverter::validate_decimal_string("invalid").is_err()); - } - - #[test] - fn test_script_analyzer() { - let script = r#"(decimal-add "1.5" "2.3")"#; - - assert!(ScriptAnalyzer::contains_decimal_functions(script)); - assert_eq!(ScriptAnalyzer::count_function_calls(script, "decimal-add"), 1); - - let literals = ScriptAnalyzer::extract_string_literals(script); - assert_eq!(literals, vec!["1.5", "2.3"]); - } - - #[test] - fn test_decimal_precision() { - let result = DecimalPrecision::set_precision("123.456789", 2); - assert!(result.is_ok()); - assert_eq!(result.unwrap(), "123.46"); - - let places = DecimalPrecision::get_decimal_places("123.456"); - assert!(places.is_ok()); - assert_eq!(places.unwrap(), 3); - } - - #[test] - fn test_type_conversions() { - assert_eq!(TypeConverter::i64_to_decimal_string(42), "42"); - assert_eq!(TypeConverter::u64_to_decimal_string(42), "42"); - - let f64_result = TypeConverter::f64_to_decimal_string(123.456); - assert!(f64_result.is_ok()); - } -} diff --git a/steel_decimal/tests/function_tests.rs b/steel_decimal/tests/function_tests.rs new file mode 100644 index 0000000..bdd4e28 --- /dev/null +++ b/steel_decimal/tests/function_tests.rs @@ -0,0 +1,239 @@ +use rstest::*; +use steel_decimal::*; + +// Basic Arithmetic Tests +#[rstest] +#[case("1.5", "2.3", "3.8")] +#[case("10", "5", "15")] +#[case("-5.5", "3.2", "-2.3")] +#[case("0", "42", "42")] +fn test_decimal_add(#[case] a: &str, #[case] b: &str, #[case] expected: &str) { + let result = decimal_add(a.to_string(), b.to_string()).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +#[case("10", "3", "7")] +#[case("5.5", "2.2", "3.3")] +#[case("0", "5", "-5")] +#[case("-3", "-2", "-1")] +fn test_decimal_sub(#[case] a: &str, #[case] b: &str, #[case] expected: &str) { + let result = decimal_sub(a.to_string(), b.to_string()).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +#[case("2", "3", "6")] +#[case("2.5", "4", "10")] +#[case("-2", "3", "-6")] +#[case("0", "100", "0")] +fn test_decimal_mul(#[case] a: &str, #[case] b: &str, #[case] expected: &str) { + let result = decimal_mul(a.to_string(), b.to_string()).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +#[case("10", "2", "5")] +#[case("15", "3", "5")] +#[case("7.5", "2.5", "3")] +fn test_decimal_div(#[case] a: &str, #[case] b: &str, #[case] expected: &str) { + let result = decimal_div(a.to_string(), b.to_string()).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +fn test_decimal_div_by_zero() { + let result = decimal_div("10".to_string(), "0".to_string()); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Division by zero")); +} + +// Advanced Math Tests +#[rstest] +#[case("2", "3", "8")] +#[case("5", "2", "25")] +#[case("10", "0", "1")] +fn test_decimal_pow(#[case] base: &str, #[case] exp: &str, #[case] expected: &str) { + let result = decimal_pow(base.to_string(), exp.to_string()).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +#[case("16", "4")] +#[case("25", "5")] +#[case("9", "3")] +fn test_decimal_sqrt(#[case] input: &str, #[case] expected: &str) { + let result = decimal_sqrt(input.to_string()).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +fn test_decimal_sqrt_negative() { + let result = decimal_sqrt("-4".to_string()); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Square root failed")); +} + +// Comparison Tests +#[rstest] +#[case("5", "3", true)] +#[case("3", "5", false)] +#[case("5", "5", false)] +fn test_decimal_gt(#[case] a: &str, #[case] b: &str, #[case] expected: bool) { + let result = decimal_gt(a.to_string(), b.to_string()).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +#[case("3", "5", true)] +#[case("5", "3", false)] +#[case("5", "5", false)] +fn test_decimal_lt(#[case] a: &str, #[case] b: &str, #[case] expected: bool) { + let result = decimal_lt(a.to_string(), b.to_string()).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +#[case("5", "5", true)] +#[case("5", "3", false)] +#[case("3.14", "3.14", true)] +fn test_decimal_eq(#[case] a: &str, #[case] b: &str, #[case] expected: bool) { + let result = decimal_eq(a.to_string(), b.to_string()).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +#[case("5", "3", true)] +#[case("3", "5", false)] +#[case("5", "5", true)] +fn test_decimal_gte(#[case] a: &str, #[case] b: &str, #[case] expected: bool) { + let result = decimal_gte(a.to_string(), b.to_string()).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +#[case("3", "5", true)] +#[case("5", "3", false)] +#[case("5", "5", true)] +fn test_decimal_lte(#[case] a: &str, #[case] b: &str, #[case] expected: bool) { + let result = decimal_lte(a.to_string(), b.to_string()).unwrap(); + assert_eq!(result, expected); +} + +// Utility Tests +#[rstest] +#[case("-5", "5")] +#[case("3.14", "3.14")] +#[case("-0", "0")] +fn test_decimal_abs(#[case] input: &str, #[case] expected: &str) { + let result = decimal_abs(input.to_string()).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +#[case("3", "5", "3")] +#[case("10", "7", "7")] +#[case("-5", "-2", "-5")] +fn test_decimal_min(#[case] a: &str, #[case] b: &str, #[case] expected: &str) { + let result = decimal_min(a.to_string(), b.to_string()).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +#[case("3", "5", "5")] +#[case("10", "7", "10")] +#[case("-5", "-2", "-2")] +fn test_decimal_max(#[case] a: &str, #[case] b: &str, #[case] expected: &str) { + let result = decimal_max(a.to_string(), b.to_string()).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +#[case("3.14159", 2, "3.14")] +#[case("2.71828", 3, "2.718")] +#[case("10.999", 0, "11")] +fn test_decimal_round(#[case] input: &str, #[case] places: i32, #[case] expected: &str) { + let result = decimal_round(input.to_string(), places).unwrap(); + assert_eq!(result, expected); +} + +// Constants Tests +#[rstest] +fn test_decimal_constants() { + assert_eq!(decimal_zero(), "0"); + assert_eq!(decimal_one(), "1"); + assert!(decimal_pi().starts_with("3.14159")); + assert!(decimal_e().starts_with("2.71828")); +} + +// Financial Functions Tests +#[rstest] +#[case("100", "15", "15")] +#[case("1000", "5.5", "55")] +#[case("250", "20", "50")] +fn test_decimal_percentage(#[case] amount: &str, #[case] percentage: &str, #[case] expected: &str) { + let result = decimal_percentage(amount.to_string(), percentage.to_string()).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +#[case("1000", "0.05", "1", "1050")] +#[case("1000", "0.1", "2", "1210")] +fn test_decimal_compound(#[case] principal: &str, #[case] rate: &str, #[case] time: &str, #[case] expected: &str) { + let result = decimal_compound(principal.to_string(), rate.to_string(), time.to_string()).unwrap(); + assert_eq!(result, expected); +} + +// Type Conversion Tests +#[rstest] +#[case("123.456", "123.456")] +#[case("42", "42")] +#[case("0.001", "0.001")] +fn test_to_decimal(#[case] input: &str, #[case] expected: &str) { + let result = to_decimal(input.to_string()).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +fn test_to_decimal_invalid() { + let result = to_decimal("not_a_number".to_string()); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Invalid decimal")); +} + +// Error Handling Tests +#[rstest] +#[case("invalid", "2")] +#[case("1", "invalid")] +#[case("abc", "def")] +fn test_decimal_add_invalid_input(#[case] a: &str, #[case] b: &str) { + let result = decimal_add(a.to_string(), b.to_string()); + assert!(result.is_err()); +} + +#[rstest] +#[case("invalid")] +#[case("not_a_number")] +#[case("abc")] +fn test_decimal_sqrt_invalid_input(#[case] input: &str) { + let result = decimal_sqrt(input.to_string()); + assert!(result.is_err()); +} + +// Edge Cases +#[rstest] +#[case("0.000000001", "0.000000001", "0.000000002")] +#[case("999999999999999", "1", "1000000000000000")] +fn test_decimal_precision_edge_cases(#[case] a: &str, #[case] b: &str, #[case] expected: &str) { + let result = decimal_add(a.to_string(), b.to_string()).unwrap(); + assert_eq!(result, expected); +} + +// Test scientific notation +#[rstest] +#[case("1e2", "2e1", "120")] +#[case("1.5e2", "2.3e1", "173")] +fn test_scientific_notation(#[case] a: &str, #[case] b: &str, #[case] expected: &str) { + let result = decimal_add(a.to_string(), b.to_string()).unwrap(); + assert_eq!(result, expected); +} diff --git a/steel_decimal/tests/integration_tests.rs b/steel_decimal/tests/integration_tests.rs new file mode 100644 index 0000000..7a86b08 --- /dev/null +++ b/steel_decimal/tests/integration_tests.rs @@ -0,0 +1,310 @@ +use rstest::*; +use steel_decimal::{SteelDecimal, FunctionRegistry, FunctionRegistryBuilder}; +use steel::steel_vm::engine::Engine; +use steel::rvals::SteelVal; +use std::collections::HashMap; + +#[fixture] +fn steel_decimal_instance() -> SteelDecimal { + SteelDecimal::new() +} + +#[fixture] +fn vm_with_functions() -> Engine { + let mut vm = Engine::new(); + FunctionRegistry::register_all(&mut vm); + vm +} + +// End-to-End Transformation and Execution Tests +#[rstest] +#[case("(+ 1.5 2.3)", "3.8")] +#[case("(- 10 4)", "6")] +#[case("(* 3 4)", "12")] +#[case("(/ 15 3)", "5")] +fn test_end_to_end_basic_arithmetic(steel_decimal_instance: SteelDecimal, #[case] input: &str, #[case] expected: &str) { + let result = steel_decimal_instance.parse_and_execute(input).unwrap(); + + // Should return a single value + assert_eq!(result.len(), 1); + + // Extract the string value + if let SteelVal::StringV(s) = &result[0] { + assert_eq!(s.to_string(), expected); + } else { + panic!("Expected StringV, got {:?}", result[0]); + } +} + +#[rstest] +#[case("(+ (* 2 3) (/ 12 4))", "9")] +#[case("(- (+ 10 5) (* 2 3))", "9")] +#[case("(* (+ 2 3) (- 8 3))", "25")] +fn test_end_to_end_complex_expressions(steel_decimal_instance: SteelDecimal, #[case] input: &str, #[case] expected: &str) { + let result = steel_decimal_instance.parse_and_execute(input).unwrap(); + + assert_eq!(result.len(), 1); + if let SteelVal::StringV(s) = &result[0] { + assert_eq!(s.to_string(), expected); + } else { + panic!("Expected StringV, got {:?}", result[0]); + } +} + +// Test with variables +#[rstest] +fn test_end_to_end_with_variables() { + let mut variables = HashMap::new(); + variables.insert("x".to_string(), "10".to_string()); + variables.insert("y".to_string(), "5".to_string()); + + let steel_decimal_instance = SteelDecimal::with_variables(variables); + let result = steel_decimal_instance.parse_and_execute("(+ $x $y)").unwrap(); + + assert_eq!(result.len(), 1); + if let SteelVal::StringV(s) = &result[0] { + assert_eq!(s.to_string(), "15"); + } else { + panic!("Expected StringV, got {:?}", result[0]); + } +} + +// Test transformation only +#[rstest] +#[case("(+ 1 2)", "(decimal-add \"1\" \"2\")")] +#[case("(* $x $y)", "(decimal-mul (get-var \"x\") (get-var \"y\"))")] +#[case("(sqrt 16)", "(decimal-sqrt \"16\")")] +fn test_transformation_only(steel_decimal_instance: SteelDecimal, #[case] input: &str, #[case] expected: &str) { + let result = steel_decimal_instance.transform(input); + assert_eq!(result, expected); +} + +// Test function registration +#[rstest] +fn test_function_registration_with_vm() { + let mut vm = Engine::new(); + FunctionRegistry::register_all(&mut vm); + + // Test that we can execute decimal functions directly + let script = r#"(decimal-add "2.5" "3.7")"#; + let result = vm.compile_and_run_raw_program(script.to_string()).unwrap(); + + assert_eq!(result.len(), 1); + if let SteelVal::StringV(s) = &result[0] { + assert_eq!(s.to_string(), "6.2"); + } else { + panic!("Expected StringV, got {:?}", result[0]); + } +} + +// Test selective function registration +#[rstest] +fn test_selective_function_registration() { + let mut vm = Engine::new(); + + FunctionRegistryBuilder::new() + .basic_arithmetic(true) + .advanced_math(false) + .trigonometric(false) + .comparison(true) + .utility(false) + .constants(true) + .financial(false) + .conversion(true) + .register(&mut vm); + + // Basic arithmetic should work + let script = r#"(decimal-add "1" "2")"#; + let result = vm.compile_and_run_raw_program(script.to_string()); + assert!(result.is_ok()); + + // Constants should work + let script = r#"(decimal-pi)"#; + let result = vm.compile_and_run_raw_program(script.to_string()); + assert!(result.is_ok()); + + // Comparison should work + let script = r#"(decimal-gt "5" "3")"#; + let result = vm.compile_and_run_raw_program(script.to_string()); + assert!(result.is_ok()); +} + +// Test variable registration +#[rstest] +fn test_variable_registration() { + let mut vm = Engine::new(); + let mut variables = HashMap::new(); + variables.insert("test_var".to_string(), "42.5".to_string()); + variables.insert("another_var".to_string(), "10.0".to_string()); + + FunctionRegistryBuilder::new() + .basic_arithmetic(true) + .with_variables(variables) + .register(&mut vm); + + // Test getting a variable + let script = r#"(get-var "test_var")"#; + let result = vm.compile_and_run_raw_program(script.to_string()).unwrap(); + + assert_eq!(result.len(), 1); + if let SteelVal::StringV(s) = &result[0] { + assert_eq!(s.to_string(), "42.5"); + } else { + panic!("Expected StringV, got {:?}", result[0]); + } + + // Test checking if variable exists + let script = r#"(has-var? "test_var")"#; + let result = vm.compile_and_run_raw_program(script.to_string()).unwrap(); + + assert_eq!(result.len(), 1); + if let SteelVal::BoolV(b) = &result[0] { + assert!(b); + } else { + panic!("Expected BoolV, got {:?}", result[0]); + } + + // Test non-existent variable + let script = r#"(has-var? "nonexistent")"#; + let result = vm.compile_and_run_raw_program(script.to_string()).unwrap(); + + assert_eq!(result.len(), 1); + if let SteelVal::BoolV(b) = &result[0] { + assert!(!b); + } else { + panic!("Expected BoolV, got {:?}", result[0]); + } +} + +// Test script validation +#[rstest] +#[case("(+ 1 2)", true, "Valid script")] +#[case("(+ 1 2", false, "Unbalanced parentheses")] +#[case("(+ $undefined_var 2)", false, "Undefined variable")] +fn test_script_validation(#[case] script: &str, #[case] should_be_valid: bool, #[case] _description: &str) { + let steel_decimal_instance = SteelDecimal::new(); + let result = steel_decimal_instance.validate_script(script); + + if should_be_valid { + assert!(result.is_ok(), "Script should be valid: {}", script); + } else { + assert!(result.is_err(), "Script should be invalid: {}", script); + } +} + +// Test with defined variables +#[rstest] +fn test_script_validation_with_variables() { + let mut variables = HashMap::new(); + variables.insert("x".to_string(), "10".to_string()); + variables.insert("y".to_string(), "20".to_string()); + + let steel_decimal_instance = SteelDecimal::with_variables(variables); + + // Should be valid now + assert!(steel_decimal_instance.validate_script("(+ $x $y)").is_ok()); + + // Still invalid variable + assert!(steel_decimal_instance.validate_script("(+ $x $undefined)").is_err()); +} + +// Test dependency extraction +#[rstest] +#[case("(+ $x $y)", vec!["x", "y"])] +#[case("(* $price $quantity)", vec!["price", "quantity"])] +#[case("(+ 1 2)", vec![])] +fn test_dependency_extraction(steel_decimal_instance: SteelDecimal, #[case] script: &str, #[case] expected_deps: Vec<&str>) { + let deps = steel_decimal_instance.extract_dependencies(script); + let expected: std::collections::HashSet = expected_deps.into_iter().map(|s| s.to_string()).collect(); + assert_eq!(deps, expected); +} + +// Test adding variables dynamically +#[rstest] +fn test_dynamic_variable_addition() { + let mut steel_decimal_instance = SteelDecimal::new(); + + // Should fail initially + assert!(steel_decimal_instance.validate_script("(+ $x $y)").is_err()); + + // Add variables + steel_decimal_instance.add_variable("x".to_string(), "10".to_string()); + steel_decimal_instance.add_variable("y".to_string(), "20".to_string()); + + // Should work now + assert!(steel_decimal_instance.validate_script("(+ $x $y)").is_ok()); + + // Test execution + let result = steel_decimal_instance.parse_and_execute("(+ $x $y)").unwrap(); + assert_eq!(result.len(), 1); + if let SteelVal::StringV(s) = &result[0] { + assert_eq!(s.to_string(), "30"); + } +} + +// Test error handling +#[rstest] +fn test_error_handling() { + let steel_decimal_instance = SteelDecimal::new(); + + // Test division by zero + let result = steel_decimal_instance.parse_and_execute("(/ 10 0)"); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Division by zero")); +} + +// Test complex mathematical expressions +#[rstest] +#[case("(sqrt (+ (* 3 3) (* 4 4)))", "5")] // Pythagorean theorem: sqrt(3² + 4²) = 5 +#[case("(abs (- 10 15))", "5")] // |10 - 15| = 5 +#[case("(max (min 10 5) 3)", "5")] // max(min(10, 5), 3) = max(5, 3) = 5 +fn test_complex_mathematical_expressions(steel_decimal_instance: SteelDecimal, #[case] input: &str, #[case] expected: &str) { + let result = steel_decimal_instance.parse_and_execute(input).unwrap(); + + assert_eq!(result.len(), 1); + if let SteelVal::StringV(s) = &result[0] { + assert_eq!(s.to_string(), expected); + } else { + panic!("Expected StringV, got {:?}", result[0]); + } +} + +// Test financial calculations +#[rstest] +fn test_financial_calculations() { + let steel_decimal_instance = SteelDecimal::new(); + + // Test percentage calculation + let result = steel_decimal_instance.parse_and_execute("(decimal-percentage \"1000\" \"15\")").unwrap(); + assert_eq!(result.len(), 1); + if let SteelVal::StringV(s) = &result[0] { + assert_eq!(s.to_string(), "150"); + } + + // Test compound interest (simplified) + let result = steel_decimal_instance.parse_and_execute("(decimal-compound \"1000\" \"0.05\" \"2\")").unwrap(); + assert_eq!(result.len(), 1); + if let SteelVal::StringV(s) = &result[0] { + assert_eq!(s.to_string(), "1102.5"); + } +} + +// Test constants +#[rstest] +fn test_constants_integration() { + let steel_decimal_instance = SteelDecimal::new(); + + // Test pi + let result = steel_decimal_instance.parse_and_execute("(decimal-pi)").unwrap(); + assert_eq!(result.len(), 1); + if let SteelVal::StringV(s) = &result[0] { + assert!(s.to_string().starts_with("3.14159")); + } + + // Test using pi in calculation + let result = steel_decimal_instance.parse_and_execute("(decimal-mul (decimal-pi) \"2\")").unwrap(); + assert_eq!(result.len(), 1); + if let SteelVal::StringV(s) = &result[0] { + assert!(s.to_string().starts_with("6.28")); + } +} diff --git a/steel_decimal/tests/parser_tests.rs b/steel_decimal/tests/parser_tests.rs new file mode 100644 index 0000000..285a4aa --- /dev/null +++ b/steel_decimal/tests/parser_tests.rs @@ -0,0 +1,158 @@ +use rstest::*; +use steel_decimal::ScriptParser; +use std::collections::HashSet; + +#[fixture] +fn parser() -> ScriptParser { + ScriptParser::new() +} + +#[rstest] +#[case("(+ 1.5 2.3)", "(decimal-add \"1.5\" \"2.3\")")] +#[case("(- 10 5)", "(decimal-sub \"10\" \"5\")")] +#[case("(* 2.5 4)", "(decimal-mul \"2.5\" \"4\")")] +#[case("(/ 15 3)", "(decimal-div \"15\" \"3\")")] +fn test_basic_arithmetic_transformation(parser: ScriptParser, #[case] input: &str, #[case] expected: &str) { + let result = parser.transform(input); + assert_eq!(result, expected); +} + +#[rstest] +#[case("(^ 2 3)", "(decimal-pow \"2\" \"3\")")] +#[case("(** 2 3)", "(decimal-pow \"2\" \"3\")")] +#[case("(pow 2 3)", "(decimal-pow \"2\" \"3\")")] +#[case("(sqrt 16)", "(decimal-sqrt \"16\")")] +#[case("(ln 2.718)", "(decimal-ln \"2.718\")")] +#[case("(log 2.718)", "(decimal-ln \"2.718\")")] +#[case("(log10 100)", "(decimal-log10 \"100\")")] +#[case("(exp 1)", "(decimal-exp \"1\")")] +fn test_advanced_math_transformation(parser: ScriptParser, #[case] input: &str, #[case] expected: &str) { + let result = parser.transform(input); + assert_eq!(result, expected); +} + +#[rstest] +#[case("(sin 1.57)", "(decimal-sin \"1.57\")")] +#[case("(cos 0)", "(decimal-cos \"0\")")] +#[case("(tan 0.785)", "(decimal-tan \"0.785\")")] +fn test_trigonometric_transformation(parser: ScriptParser, #[case] input: &str, #[case] expected: &str) { + let result = parser.transform(input); + assert_eq!(result, expected); +} + +#[rstest] +#[case("(> 5 3)", "(decimal-gt \"5\" \"3\")")] +#[case("(< 3 5)", "(decimal-lt \"3\" \"5\")")] +#[case("(= 5 5)", "(decimal-eq \"5\" \"5\")")] +#[case("(>= 5 3)", "(decimal-gte \"5\" \"3\")")] +#[case("(<= 3 5)", "(decimal-lte \"3\" \"5\")")] +fn test_comparison_transformation(parser: ScriptParser, #[case] input: &str, #[case] expected: &str) { + let result = parser.transform(input); + assert_eq!(result, expected); +} + +#[rstest] +#[case("(abs -5)", "(decimal-abs \"-5\")")] +#[case("(min 3 5)", "(decimal-min \"3\" \"5\")")] +#[case("(max 3 5)", "(decimal-max \"3\" \"5\")")] +#[case("(round 3.14159 2)", "(decimal-round \"3.14159\" \"2\")")] +fn test_utility_transformation(parser: ScriptParser, #[case] input: &str, #[case] expected: &str) { + let result = parser.transform(input); + assert_eq!(result, expected); +} + +#[rstest] +#[case("$x", "(get-var \"x\")")] +#[case("$price", "(get-var \"price\")")] +#[case("$some_variable", "(get-var \"some_variable\")")] +fn test_variable_transformation(parser: ScriptParser, #[case] input: &str, #[case] expected: &str) { + let result = parser.transform(input); + assert_eq!(result, expected); +} + +#[rstest] +#[case("42", "\"42\"")] +#[case("3.14159", "\"3.14159\"")] +#[case("-5.5", "\"-5.5\"")] +#[case("1.5e2", "\"1.5e2\"")] +#[case("2.3E-1", "\"2.3E-1\"")] +fn test_number_literal_transformation(parser: ScriptParser, #[case] input: &str, #[case] expected: &str) { + let result = parser.transform(input); + assert_eq!(result, expected); +} + +#[rstest] +#[case( + "(+ (* 2.5 3.0) (/ 15.0 3.0))", + "(decimal-add (decimal-mul \"2.5\" \"3.0\") (decimal-div \"15.0\" \"3.0\"))" +)] +#[case( + "(sqrt (+ (* $x $x) (* $y $y)))", + "(decimal-sqrt (decimal-add (decimal-mul (get-var \"x\") (get-var \"x\")) (decimal-mul (get-var \"y\") (get-var \"y\"))))" +)] +#[case( + "(/ (+ $a $b) (- $c $d))", + "(decimal-div (decimal-add (get-var \"a\") (get-var \"b\")) (decimal-sub (get-var \"c\") (get-var \"d\")))" +)] +fn test_complex_expressions(parser: ScriptParser, #[case] input: &str, #[case] expected: &str) { + let result = parser.transform(input); + assert_eq!(result, expected); +} + +#[rstest] +#[case("(+ $x $y)", vec!["x", "y"])] +#[case("(* $price $quantity)", vec!["price", "quantity"])] +#[case("(/ (+ $a $b) $c)", vec!["a", "b", "c"])] +#[case("(sqrt (+ (* $x $x) (* $y $y)))", vec!["x", "y"])] +fn test_dependency_extraction(parser: ScriptParser, #[case] input: &str, #[case] expected_deps: Vec<&str>) { + let deps = parser.extract_dependencies(input); + let expected: HashSet = expected_deps.into_iter().map(|s| s.to_string()).collect(); + assert_eq!(deps, expected); +} + +#[rstest] +#[case("(+ 1 2)", "Addition")] +#[case("(- 5 3)", "Subtraction")] +#[case("(* 2 4)", "Multiplication")] +#[case("(/ 8 2)", "Division")] +#[case("(sin 0)", "Trigonometry")] +#[case("(sqrt 16)", "Square root")] +#[case("(> 5 3)", "Comparison")] +fn test_parser_handles_various_functions(parser: ScriptParser, #[case] input: &str, #[case] _description: &str) { + let result = parser.transform(input); + // Should not panic and should produce valid output + assert!(!result.is_empty()); + assert!(result.starts_with('(')); + assert!(result.ends_with(')')); +} + +#[rstest] +fn test_parser_preserves_structure(parser: ScriptParser) { + let input = "(+ (- 10 5) (* 2 3))"; + let result = parser.transform(input); + + // Check that parentheses are balanced + let open_count = result.chars().filter(|c| *c == '(').count(); + let close_count = result.chars().filter(|c| *c == ')').count(); + assert_eq!(open_count, close_count); + + // Check that the structure is preserved + assert!(result.contains("decimal-add")); + assert!(result.contains("decimal-sub")); + assert!(result.contains("decimal-mul")); +} + +#[rstest] +fn test_parser_handles_empty_input(parser: ScriptParser) { + let result = parser.transform(""); + assert_eq!(result, ""); +} + +#[rstest] +fn test_parser_handles_whitespace(parser: ScriptParser) { + let input = "( + 1.5 2.3 )"; + let result = parser.transform(input); + assert!(result.contains("decimal-add")); + assert!(result.contains("\"1.5\"")); + assert!(result.contains("\"2.3\"")); +} diff --git a/steel_decimal/tests/registry_tests.rs b/steel_decimal/tests/registry_tests.rs new file mode 100644 index 0000000..e16c536 --- /dev/null +++ b/steel_decimal/tests/registry_tests.rs @@ -0,0 +1,352 @@ +use rstest::*; +use steel_decimal::{FunctionRegistry, FunctionRegistryBuilder}; +use steel::steel_vm::engine::Engine; +use steel::rvals::SteelVal; +use std::collections::HashMap; + +#[fixture] +fn fresh_vm() -> Engine { + Engine::new() +} + +// Test basic function registration +#[rstest] +fn test_register_all_functions(mut fresh_vm: Engine) { + FunctionRegistry::register_all(&mut fresh_vm); + + // Test that all major function categories work + let scripts = vec![ + r#"(decimal-add "1" "2")"#, + r#"(decimal-pow "2" "3")"#, + r#"(decimal-sin "0")"#, + r#"(decimal-gt "5" "3")"#, + r#"(decimal-abs "-5")"#, + r#"(decimal-zero)"#, + r#"(decimal-percentage "100" "10")"#, + r#"(to-decimal "123.45")"#, + ]; + + for script in scripts { + let result = fresh_vm.compile_and_run_raw_program(script.to_string()); + assert!(result.is_ok(), "Failed to execute: {}", script); + } +} + +// Test selective function registration +#[rstest] +fn test_basic_arithmetic_only(mut fresh_vm: Engine) { + FunctionRegistryBuilder::new() + .basic_arithmetic(true) + .advanced_math(false) + .trigonometric(false) + .comparison(false) + .utility(false) + .constants(false) + .financial(false) + .conversion(false) + .register(&mut fresh_vm); + + // Basic arithmetic should work + let result = fresh_vm.compile_and_run_raw_program(r#"(decimal-add "1" "2")"#.to_string()); + assert!(result.is_ok()); + + let result = fresh_vm.compile_and_run_raw_program(r#"(decimal-mul "3" "4")"#.to_string()); + assert!(result.is_ok()); + + // Note: Advanced math, trig, etc. won't be available but we can't easily test + // their absence without expecting compilation/runtime errors +} + +#[rstest] +fn test_advanced_math_only(mut fresh_vm: Engine) { + FunctionRegistryBuilder::new() + .basic_arithmetic(false) + .advanced_math(true) + .trigonometric(false) + .comparison(false) + .utility(false) + .constants(false) + .financial(false) + .conversion(false) + .register(&mut fresh_vm); + + // Advanced math should work + let result = fresh_vm.compile_and_run_raw_program(r#"(decimal-pow "2" "3")"#.to_string()); + assert!(result.is_ok()); + + let result = fresh_vm.compile_and_run_raw_program(r#"(decimal-sqrt "16")"#.to_string()); + assert!(result.is_ok()); +} + +#[rstest] +fn test_trigonometric_only(mut fresh_vm: Engine) { + FunctionRegistryBuilder::new() + .basic_arithmetic(false) + .advanced_math(false) + .trigonometric(true) + .comparison(false) + .utility(false) + .constants(false) + .financial(false) + .conversion(false) + .register(&mut fresh_vm); + + // Trigonometric functions should work + let result = fresh_vm.compile_and_run_raw_program(r#"(decimal-sin "0")"#.to_string()); + assert!(result.is_ok()); + + let result = fresh_vm.compile_and_run_raw_program(r#"(decimal-cos "0")"#.to_string()); + assert!(result.is_ok()); +} + +#[rstest] +fn test_comparison_only(mut fresh_vm: Engine) { + FunctionRegistryBuilder::new() + .basic_arithmetic(false) + .advanced_math(false) + .trigonometric(false) + .comparison(true) + .utility(false) + .constants(false) + .financial(false) + .conversion(false) + .register(&mut fresh_vm); + + // Comparison functions should work + let result = fresh_vm.compile_and_run_raw_program(r#"(decimal-gt "5" "3")"#.to_string()); + assert!(result.is_ok()); + + let result = fresh_vm.compile_and_run_raw_program(r#"(decimal-eq "5" "5")"#.to_string()); + assert!(result.is_ok()); +} + +#[rstest] +fn test_utility_only(mut fresh_vm: Engine) { + FunctionRegistryBuilder::new() + .basic_arithmetic(false) + .advanced_math(false) + .trigonometric(false) + .comparison(false) + .utility(true) + .constants(false) + .financial(false) + .conversion(false) + .register(&mut fresh_vm); + + // Utility functions should work + let result = fresh_vm.compile_and_run_raw_program(r#"(decimal-abs "-5")"#.to_string()); + assert!(result.is_ok()); + + let result = fresh_vm.compile_and_run_raw_program(r#"(decimal-min "3" "5")"#.to_string()); + assert!(result.is_ok()); +} + +#[rstest] +fn test_constants_only(mut fresh_vm: Engine) { + FunctionRegistryBuilder::new() + .basic_arithmetic(false) + .advanced_math(false) + .trigonometric(false) + .comparison(false) + .utility(false) + .constants(true) + .financial(false) + .conversion(false) + .register(&mut fresh_vm); + + // Constants should work + let result = fresh_vm.compile_and_run_raw_program(r#"(decimal-pi)"#.to_string()); + assert!(result.is_ok()); + + let result = fresh_vm.compile_and_run_raw_program(r#"(decimal-zero)"#.to_string()); + assert!(result.is_ok()); +} + +#[rstest] +fn test_financial_only(mut fresh_vm: Engine) { + FunctionRegistryBuilder::new() + .basic_arithmetic(false) + .advanced_math(false) + .trigonometric(false) + .comparison(false) + .utility(false) + .constants(false) + .financial(true) + .conversion(false) + .register(&mut fresh_vm); + + // Financial functions should work + let result = fresh_vm.compile_and_run_raw_program(r#"(decimal-percentage "100" "10")"#.to_string()); + assert!(result.is_ok()); + + let result = fresh_vm.compile_and_run_raw_program(r#"(decimal-compound "1000" "0.05" "1")"#.to_string()); + assert!(result.is_ok()); +} + +#[rstest] +fn test_conversion_only(mut fresh_vm: Engine) { + FunctionRegistryBuilder::new() + .basic_arithmetic(false) + .advanced_math(false) + .trigonometric(false) + .comparison(false) + .utility(false) + .constants(false) + .financial(false) + .conversion(true) + .register(&mut fresh_vm); + + // Conversion functions should work + let result = fresh_vm.compile_and_run_raw_program(r#"(to-decimal "123.45")"#.to_string()); + assert!(result.is_ok()); +} + +// Test variable registration +#[rstest] +fn test_variable_registration(mut fresh_vm: Engine) { + let mut variables = HashMap::new(); + variables.insert("x".to_string(), "10.5".to_string()); + variables.insert("y".to_string(), "20.3".to_string()); + variables.insert("name".to_string(), "test_value".to_string()); + + FunctionRegistryBuilder::new() + .basic_arithmetic(true) + .with_variables(variables) + .register(&mut fresh_vm); + + // Test getting variables + let result = fresh_vm.compile_and_run_raw_program(r#"(get-var "x")"#.to_string()).unwrap(); + assert_eq!(result.len(), 1); + if let SteelVal::StringV(s) = &result[0] { + assert_eq!(s.to_string(), "10.5"); + } + + let result = fresh_vm.compile_and_run_raw_program(r#"(get-var "y")"#.to_string()).unwrap(); + assert_eq!(result.len(), 1); + if let SteelVal::StringV(s) = &result[0] { + assert_eq!(s.to_string(), "20.3"); + } + + // Test checking if variables exist + let result = fresh_vm.compile_and_run_raw_program(r#"(has-var? "x")"#.to_string()).unwrap(); + assert_eq!(result.len(), 1); + if let SteelVal::BoolV(b) = &result[0] { + assert!(b); + } + + let result = fresh_vm.compile_and_run_raw_program(r#"(has-var? "nonexistent")"#.to_string()).unwrap(); + assert_eq!(result.len(), 1); + if let SteelVal::BoolV(b) = &result[0] { + assert!(!b); + } + + // Test using variables in arithmetic + let result = fresh_vm.compile_and_run_raw_program(r#"(decimal-add (get-var "x") (get-var "y"))"#.to_string()).unwrap(); + assert_eq!(result.len(), 1); + if let SteelVal::StringV(s) = &result[0] { + assert_eq!(s.to_string(), "30.8"); + } +} + +// Test error handling in variable access +#[rstest] +fn test_variable_error_handling(mut fresh_vm: Engine) { + let variables = HashMap::new(); // Empty variables + + FunctionRegistryBuilder::new() + .with_variables(variables) + .register(&mut fresh_vm); + + // Should get an error for non-existent variable + let result = fresh_vm.compile_and_run_raw_program(r#"(get-var "nonexistent")"#.to_string()); + // The function should return an error, but Steel might handle it differently + // This test ensures the function is registered and callable + assert!(result.is_ok() || result.is_err()); // Either way is fine, just shouldn't panic +} + +// Test function name listing +#[rstest] +fn test_function_names() { + let names = FunctionRegistry::get_function_names(); + + // Check that all expected function categories are present + assert!(names.contains(&"decimal-add")); + assert!(names.contains(&"decimal-pow")); + assert!(names.contains(&"decimal-sin")); + assert!(names.contains(&"decimal-gt")); + assert!(names.contains(&"decimal-abs")); + assert!(names.contains(&"decimal-pi")); + assert!(names.contains(&"decimal-percentage")); + assert!(names.contains(&"to-decimal")); + assert!(names.contains(&"get-var")); + assert!(names.contains(&"has-var?")); + + // Check total count is reasonable + assert!(names.len() >= 25); // Should have at least 25 functions +} + +// Test builder pattern combinations +#[rstest] +fn test_builder_combinations(mut fresh_vm: Engine) { + FunctionRegistryBuilder::new() + .basic_arithmetic(true) + .comparison(true) + .constants(true) + .register(&mut fresh_vm); + + // Should be able to use arithmetic, comparison, and constants + let result = fresh_vm.compile_and_run_raw_program( + r#"(decimal-gt (decimal-add "1" "2") (decimal-zero))"#.to_string() + ).unwrap(); + + assert_eq!(result.len(), 1); + if let SteelVal::BoolV(b) = &result[0] { + assert!(b); // 3 > 0 should be true + } +} + +// Test default builder behavior +#[rstest] +fn test_builder_defaults(mut fresh_vm: Engine) { + // Default should include everything + FunctionRegistryBuilder::new().register(&mut fresh_vm); + + // Should be able to use any function + let scripts = vec![ + r#"(decimal-add "1" "2")"#, + r#"(decimal-pow "2" "3")"#, + r#"(decimal-sin "0")"#, + r#"(decimal-gt "5" "3")"#, + r#"(decimal-abs "-5")"#, + r#"(decimal-pi)"#, + r#"(decimal-percentage "100" "10")"#, + r#"(to-decimal "123.45")"#, + ]; + + for script in scripts { + let result = fresh_vm.compile_and_run_raw_program(script.to_string()); + assert!(result.is_ok(), "Failed to execute with default builder: {}", script); + } +} + +// Test multiple variable sets +#[rstest] +fn test_multiple_variable_registration(mut fresh_vm: Engine) { + let mut variables1 = HashMap::new(); + variables1.insert("a".to_string(), "1".to_string()); + variables1.insert("b".to_string(), "2".to_string()); + + // Register first set + FunctionRegistryBuilder::new() + .with_variables(variables1) + .register(&mut fresh_vm); + + // Test first set works + let result = fresh_vm.compile_and_run_raw_program(r#"(get-var "a")"#.to_string()).unwrap(); + if let SteelVal::StringV(s) = &result[0] { + assert_eq!(s.to_string(), "1"); + } + + // Note: In a real scenario, you probably wouldn't register variables twice + // This is just testing that the registration mechanism works +} diff --git a/steel_decimal/tests/utils_tests.rs b/steel_decimal/tests/utils_tests.rs new file mode 100644 index 0000000..59176a7 --- /dev/null +++ b/steel_decimal/tests/utils_tests.rs @@ -0,0 +1,343 @@ +use rstest::*; +use steel_decimal::{TypeConverter, ScriptAnalyzer, DecimalPrecision, ConversionError}; +use steel::rvals::SteelVal; +use rust_decimal::Decimal; +use std::str::FromStr; + +// TypeConverter Tests +#[rstest] +#[case("123.456")] +#[case("42")] +#[case("-15.75")] +#[case("0")] +fn test_steel_val_string_to_decimal(#[case] input: &str) { + let steel_val = SteelVal::StringV(input.into()); + let result = TypeConverter::steel_val_to_decimal(&steel_val).unwrap(); + let expected = Decimal::from_str(input).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +#[case(42isize)] +#[case(-15isize)] +#[case(0isize)] +fn test_steel_val_int_to_decimal(#[case] input: isize) { + let steel_val = SteelVal::IntV(input); + let result = TypeConverter::steel_val_to_decimal(&steel_val).unwrap(); + let expected = Decimal::from(input); + assert_eq!(result, expected); +} + +#[rstest] +#[case(42.5)] +#[case(-15.75)] +#[case(0.0)] +fn test_steel_val_num_to_decimal(#[case] input: f64) { + let steel_val = SteelVal::NumV(input); + let result = TypeConverter::steel_val_to_decimal(&steel_val).unwrap(); + let expected = Decimal::try_from(input).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +fn test_steel_val_unsupported_type() { + let steel_val = SteelVal::BoolV(true); + let result = TypeConverter::steel_val_to_decimal(&steel_val); + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), ConversionError::UnsupportedType(_))); +} + +#[rstest] +fn test_steel_val_invalid_decimal_string() { + let steel_val = SteelVal::StringV("not_a_number".into()); + let result = TypeConverter::steel_val_to_decimal(&steel_val); + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), ConversionError::InvalidDecimal(_))); +} + +#[rstest] +#[case("123.456")] +#[case("42")] +#[case("-15.75")] +#[case("0")] +fn test_decimal_to_steel_val(#[case] input: &str) { + let decimal = Decimal::from_str(input).unwrap(); + let result = TypeConverter::decimal_to_steel_val(decimal); + + if let SteelVal::StringV(s) = result { + assert_eq!(s.to_string(), input); + } else { + panic!("Expected StringV"); + } +} + +#[rstest] +fn test_steel_vals_to_strings() { + let vals = vec![ + SteelVal::StringV("hello".into()), + SteelVal::IntV(42), + SteelVal::NumV(3.14), + SteelVal::BoolV(true), + ]; + + let result = TypeConverter::steel_vals_to_strings(vals).unwrap(); + let expected = vec!["hello", "42", "3.14", "true"]; + assert_eq!(result, expected); +} + +#[rstest] +fn test_steel_val_vector_to_string() { + let inner_vals = vec![ + SteelVal::StringV("a".into()), + SteelVal::StringV("b".into()), + SteelVal::StringV("c".into()), + ]; + let vector_val = SteelVal::VectorV(inner_vals.into()); + + let result = TypeConverter::steel_val_to_string(vector_val).unwrap(); + assert_eq!(result, "a,b,c"); +} + +#[rstest] +#[case("123.456", "123.456")] +#[case("42", "42")] +#[case("-15.75", "-15.75")] +fn test_validate_decimal_string(#[case] input: &str, #[case] expected: &str) { + let result = TypeConverter::validate_decimal_string(input).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +#[case("not_a_number")] +#[case("")] +#[case("12.34.56")] +fn test_validate_decimal_string_invalid(#[case] input: &str) { + let result = TypeConverter::validate_decimal_string(input); + assert!(result.is_err()); +} + +#[rstest] +#[case(123.456)] +#[case(-15.75)] +#[case(0.0)] +fn test_f64_to_decimal_string(#[case] input: f64) { + let result = TypeConverter::f64_to_decimal_string(input).unwrap(); + let expected = Decimal::try_from(input).unwrap().to_string(); + assert_eq!(result, expected); +} + +#[rstest] +#[case(42i64, "42")] +#[case(-15i64, "-15")] +#[case(0i64, "0")] +fn test_i64_to_decimal_string(#[case] input: i64, #[case] expected: &str) { + let result = TypeConverter::i64_to_decimal_string(input); + assert_eq!(result, expected); +} + +#[rstest] +#[case(42u64, "42")] +#[case(0u64, "0")] +#[case(1000u64, "1000")] +fn test_u64_to_decimal_string(#[case] input: u64, #[case] expected: &str) { + let result = TypeConverter::u64_to_decimal_string(input); + assert_eq!(result, expected); +} + +// ScriptAnalyzer Tests +#[rstest] +#[case("123.456", true)] +#[case("42", true)] +#[case("-15.75", true)] +#[case("0", true)] +#[case("not_a_number", false)] +#[case("", false)] +#[case("12.34.56", false)] +fn test_is_decimal_like(#[case] input: &str, #[case] expected: bool) { + let result = ScriptAnalyzer::is_decimal_like(input); + assert_eq!(result, expected); +} + +#[rstest] +#[case(r#"(test "hello" "world")"#, vec!["hello", "world"])] +#[case(r#"(add "1.5" "2.3")"#, vec!["1.5", "2.3"])] +#[case(r#"(func)"#, vec![])] +#[case(r#""single""#, vec!["single"])] +fn test_extract_string_literals(#[case] input: &str, #[case] expected: Vec<&str>) { + let result = ScriptAnalyzer::extract_string_literals(input); + let expected: Vec = expected.into_iter().map(|s| s.to_string()).collect(); + assert_eq!(result, expected); +} + +#[rstest] +#[case("(decimal-add x y)", "decimal-add", 1)] +#[case("(decimal-add x y) (decimal-add a b)", "decimal-add", 2)] +#[case("(decimal-mul x y)", "decimal-add", 0)] +#[case("(decimal-add x (decimal-add y z))", "decimal-add", 2)] +fn test_count_function_calls(#[case] script: &str, #[case] function_name: &str, #[case] expected: usize) { + let result = ScriptAnalyzer::count_function_calls(script, function_name); + assert_eq!(result, expected); +} + +#[rstest] +#[case("(decimal-add x y)", true)] +#[case("(decimal-mul a b)", true)] +#[case("(decimal-sin x)", true)] +#[case("(regular-function x y)", false)] +#[case("(+ x y)", false)] +fn test_contains_decimal_functions(#[case] script: &str, #[case] expected: bool) { + let result = ScriptAnalyzer::contains_decimal_functions(script); + assert_eq!(result, expected); +} + +// DecimalPrecision Tests +#[rstest] +#[case("123.456789", 2, "123.46")] +#[case("123.456789", 4, "123.4568")] +#[case("123.456789", 0, "123")] +#[case("123", 2, "123.00")] +fn test_set_precision(#[case] input: &str, #[case] precision: u32, #[case] expected: &str) { + let result = DecimalPrecision::set_precision(input, precision).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +#[case("123.456", 3)] +#[case("123", 0)] +#[case("123.000", 3)] +#[case("0.001", 3)] +fn test_get_decimal_places(#[case] input: &str, #[case] expected: u32) { + let result = DecimalPrecision::get_decimal_places(input).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +#[case("123.000", "123")] +#[case("123.456", "123.456")] +#[case("0.100", "0.1")] +#[case("1000.000", "1000")] +fn test_normalize(#[case] input: &str, #[case] expected: &str) { + let result = DecimalPrecision::normalize(input).unwrap(); + assert_eq!(result, expected); +} + +#[rstest] +fn test_precision_invalid_input() { + let result = DecimalPrecision::set_precision("not_a_number", 2); + assert!(result.is_err()); + + let result = DecimalPrecision::get_decimal_places("invalid"); + assert!(result.is_err()); + + let result = DecimalPrecision::normalize("bad_input"); + assert!(result.is_err()); +} + +// Edge Cases and Error Handling +#[rstest] +fn test_type_converter_edge_cases() { + // Empty string + let result = TypeConverter::validate_decimal_string(""); + assert!(result.is_err()); + + // Very large number + let large_num = "999999999999999999999999999999"; + let result = TypeConverter::validate_decimal_string(large_num); + assert!(result.is_ok()); + + // Very small number + let small_num = "0.000000000000000000000001"; + let result = TypeConverter::validate_decimal_string(small_num); + assert!(result.is_ok()); +} + +#[rstest] +fn test_script_analyzer_edge_cases() { + // Empty script + let result = ScriptAnalyzer::extract_string_literals(""); + assert!(result.is_empty()); + + let result = ScriptAnalyzer::contains_decimal_functions(""); + assert!(!result); + + // Script with no functions + let result = ScriptAnalyzer::count_function_calls("just some text", "decimal-add"); + assert_eq!(result, 0); + + // Script with escaped quotes + let result = ScriptAnalyzer::extract_string_literals(r#"(test "hello \"world\"")"#); + assert_eq!(result, vec!["hello \\\"world\\\""]); +} + +#[rstest] +fn test_conversion_error_types() { + // Test InvalidDecimal error + let result = TypeConverter::validate_decimal_string("invalid"); + assert!(result.is_err()); + if let Err(ConversionError::InvalidDecimal(msg)) = result { + assert!(msg.contains("invalid")); + } else { + panic!("Expected InvalidDecimal error"); + } + + // Test UnsupportedType error + let steel_val = SteelVal::BoolV(true); + let result = TypeConverter::steel_val_to_decimal(&steel_val); + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), ConversionError::UnsupportedType(_))); + + // Test ConversionFailed error (using f64 that can't be converted) + let result = TypeConverter::f64_to_decimal_string(f64::NAN); + assert!(result.is_err()); + if let Err(ConversionError::ConversionFailed(msg)) = result { + assert!(msg.contains("f64 to decimal")); + } else { + panic!("Expected ConversionFailed error"); + } +} + +// Integration tests for utilities +#[rstest] +fn test_utility_integration() { + // Test a complete workflow: string -> decimal -> string + let input = "123.456789"; + + // Validate input + let validated = TypeConverter::validate_decimal_string(input).unwrap(); + + // Set precision + let precise = DecimalPrecision::set_precision(&validated, 2).unwrap(); + assert_eq!(precise, "123.46"); + + // Check if it's decimal-like + assert!(ScriptAnalyzer::is_decimal_like(&precise)); + + // Normalize (should be no change since already precise) + let normalized = DecimalPrecision::normalize(&precise).unwrap(); + assert_eq!(normalized, "123.46"); +} + +#[rstest] +fn test_complex_script_analysis() { + let script = r#" + (decimal-add "1.5" "2.3") + (decimal-mul "result" "factor") + (decimal-sin "angle") + (decimal-gt "value" "threshold") + "#; + + // Should detect decimal functions + assert!(ScriptAnalyzer::contains_decimal_functions(script)); + + // Count specific functions + assert_eq!(ScriptAnalyzer::count_function_calls(script, "decimal-add"), 1); + assert_eq!(ScriptAnalyzer::count_function_calls(script, "decimal-mul"), 1); + assert_eq!(ScriptAnalyzer::count_function_calls(script, "decimal-sin"), 1); + assert_eq!(ScriptAnalyzer::count_function_calls(script, "decimal-gt"), 1); + + // Extract string literals + let literals = ScriptAnalyzer::extract_string_literals(script); + assert!(literals.contains(&"1.5".to_string())); + assert!(literals.contains(&"2.3".to_string())); + assert!(literals.contains(&"result".to_string())); +}