diff --git a/Cargo.lock b/Cargo.lock index 20d9d66..22d724b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -118,6 +118,18 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstyle" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" + [[package]] name = "anyhow" version = "1.0.98" @@ -324,6 +336,21 @@ dependencies = [ "serde", ] +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + [[package]] name = "bitflags" version = "2.9.0" @@ -476,6 +503,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "castaway" version = "0.2.3" @@ -529,6 +562,33 @@ dependencies = [ "windows-link", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "cipher" version = "0.4.4" @@ -539,6 +599,31 @@ dependencies = [ "inout", ] +[[package]] +name = "clap" +version = "4.5.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e" +dependencies = [ + "anstyle", + "clap_lex", +] + +[[package]] +name = "clap_lex" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" + [[package]] name = "client" version = "0.3.13" @@ -701,6 +786,39 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3bf7af66b0989381bd0be551bd7cc91912a655a58c6918420c9527b1fd8b4679" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "itertools 0.13.0", + "num-traits", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam" version = "0.8.4" @@ -868,6 +986,12 @@ dependencies = [ "serde", ] +[[package]] +name = "diff" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" + [[package]] name = "digest" version = "0.10.7" @@ -960,6 +1084,16 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7f84e12ccf0a7ddc17a6c41c93326024c42920d7ee630d04950e6926645c0fe" +[[package]] +name = "env_logger" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3" +dependencies = [ + "log", + "regex", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -1271,6 +1405,16 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" +dependencies = [ + "cfg-if", + "crunchy", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -1725,6 +1869,15 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.13.0" @@ -2154,6 +2307,12 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea" +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + [[package]] name = "openssl" version = "0.10.72" @@ -2348,6 +2507,34 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "polling" version = "3.7.4" @@ -2389,6 +2576,16 @@ dependencies = [ "unicode-width 0.1.14", ] +[[package]] +name = "pretty_assertions" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d" +dependencies = [ + "diff", + "yansi", +] + [[package]] name = "prettyplease" version = "0.2.32" @@ -2439,6 +2636,26 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proptest" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fcdab19deb5195a31cf7726a210015ff1496ba1464fd42cb4f537b8b01b471f" +dependencies = [ + "bit-set", + "bit-vec", + "bitflags", + "lazy_static", + "num-traits", + "rand 0.9.1", + "rand_chacha 0.9.0", + "rand_xorshift", + "regex-syntax", + "rusty-fork", + "tempfile", + "unarray", +] + [[package]] name = "prost" version = "0.13.5" @@ -2511,6 +2728,23 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + +[[package]] +name = "quickcheck" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "588f6378e4dd99458b60ec275b4477add41ce4fa9f64dcba6f15adccb19b50d6" +dependencies = [ + "env_logger", + "log", + "rand 0.8.5", +] + [[package]] name = "quickscope" version = "0.2.0" @@ -2617,6 +2851,15 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "rand_xorshift" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a" +dependencies = [ + "rand_core 0.9.3", +] + [[package]] name = "rand_xoshiro" version = "0.6.0" @@ -2933,12 +3176,42 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" +[[package]] +name = "rusty-fork" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb3dcc6e454c328bb824492db107ab7c0ae8fcffe4ad210136ef014458c1bc4f" +dependencies = [ + "fnv", + "quick-error", + "tempfile", + "wait-timeout", +] + [[package]] name = "ryu" version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "scc" +version = "2.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22b2d775fb28f245817589471dd49c5edf64237f4a19d10ce9a92ff4651a27f4" +dependencies = [ + "sdd", +] + [[package]] name = "schannel" version = "0.1.27" @@ -2954,6 +3227,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sdd" +version = "3.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "584e070911c7017da6cb2eb0788d09f43d789029b5877d3e5ecc8acf86ceee21" + [[package]] name = "seahash" version = "4.1.0" @@ -3059,6 +3338,31 @@ dependencies = [ "serde", ] +[[package]] +name = "serial_test" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b258109f244e1d6891bf1053a55d63a5cd4f8f4c30cf9a1280989f80e7a1fa9" +dependencies = [ + "futures", + "log", + "once_cell", + "parking_lot", + "scc", + "serial_test_derive", +] + +[[package]] +name = "serial_test_derive" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d69265a08751de7844521fd15003ae0a888e035773ba05695c5c759a6f89eef" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "server" version = "0.3.13" @@ -3574,12 +3878,18 @@ dependencies = [ name = "steel_decimal" version = "1.0.0" dependencies = [ + "criterion", + "pretty_assertions", + "proptest", + "quickcheck", "regex", "rstest", "rust_decimal", "rust_decimal_macros", + "serial_test", "steel-core", "steel-derive 0.5.0 (git+https://github.com/mattwparas/steel.git?branch=master)", + "tempfile", "thiserror 2.0.12", "tokio-test", ] @@ -3822,9 +4132,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tempfile" -version = "3.19.1" +version = "3.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" +checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" dependencies = [ "fastrand", "getrandom 0.3.2", @@ -3935,6 +4245,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.9.0" @@ -4240,6 +4560,12 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + [[package]] name = "unicode-bidi" version = "0.3.18" @@ -4391,6 +4717,25 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "wait-timeout" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ac3b126d3914f9849036f826e054cbabdc8519970b8998ddaf3b5bd3c65f11" +dependencies = [ + "libc", +] + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" @@ -4485,6 +4830,16 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "323f4da9523e9a669e1eaf9c6e763892769b1d38c623913647bfdc1532fe4549" +[[package]] +name = "web-sys" +version = "0.3.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "which" version = "7.0.3" @@ -4790,6 +5145,12 @@ dependencies = [ "tap", ] +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + [[package]] name = "yoke" version = "0.7.5" diff --git a/steel_decimal/Cargo.toml b/steel_decimal/Cargo.toml index f9e7ea0..968f868 100644 --- a/steel_decimal/Cargo.toml +++ b/steel_decimal/Cargo.toml @@ -21,3 +21,10 @@ thiserror = { workspace = true } [dev-dependencies] rstest = "0.25.0" tokio-test = "0.4.4" +serial_test = "3.2.0" +pretty_assertions = "1.4.1" +tempfile = "3.20.0" +proptest = "1.7.0" +quickcheck = "1.0.3" +criterion = { version = "0.6.0", features = ["html_reports"] } + diff --git a/steel_decimal/Makefile b/steel_decimal/Makefile new file mode 100644 index 0000000..b3a3d79 --- /dev/null +++ b/steel_decimal/Makefile @@ -0,0 +1,28 @@ +# Simple Makefile for Steel Decimal + +# Production test settings +export PROPTEST_CASES=10000 +export RUST_BACKTRACE=1 + +.PHONY: test check + +# Run all tests with production settings +test: + @echo "Running all tests..." + @cargo test --release + +# Quick development test +test-dev: + @PROPTEST_CASES=100 cargo test + +# Check compilation +check: + @cargo check + +# Clean build artifacts +clean: + @cargo clean + +# Run with specific test threads for concurrency tests +test-concurrent: + @cargo test concurrency_tests --release -- --test-threads=1 diff --git a/steel_decimal/src/functions.rs b/steel_decimal/src/functions.rs index 19787a9..76fc613 100644 --- a/steel_decimal/src/functions.rs +++ b/steel_decimal/src/functions.rs @@ -75,23 +75,32 @@ fn parse_scientific_notation(s: &str) -> Result { Ok(result) } -// Basic arithmetic operations (now precision-aware) +// Basic arithmetic operations pub fn decimal_add(a: String, b: String) -> Result { let a_dec = parse_decimal(&a)?; let b_dec = parse_decimal(&b)?; - Ok(format_result(a_dec + b_dec)) + + a_dec.checked_add(b_dec) + .map(|result| format_result(result)) + .ok_or_else(|| "Addition overflowed".to_string()) } pub fn decimal_sub(a: String, b: String) -> Result { let a_dec = parse_decimal(&a)?; let b_dec = parse_decimal(&b)?; - Ok(format_result(a_dec - b_dec)) + + a_dec.checked_sub(b_dec) + .map(|result| format_result(result)) + .ok_or_else(|| "Subtraction overflowed".to_string()) } pub fn decimal_mul(a: String, b: String) -> Result { let a_dec = parse_decimal(&a)?; let b_dec = parse_decimal(&b)?; - Ok(format_result(a_dec * b_dec)) + + a_dec.checked_mul(b_dec) + .map(|result| format_result(result)) + .ok_or_else(|| "Multiplication overflowed".to_string()) } pub fn decimal_div(a: String, b: String) -> Result { @@ -102,7 +111,9 @@ pub fn decimal_div(a: String, b: String) -> Result { return Err("Division by zero".to_string()); } - Ok(format_result(a_dec / b_dec)) + a_dec.checked_div(b_dec) + .map(|result| format_result(result)) + .ok_or_else(|| "Division overflowed".to_string()) } // Precision control functions diff --git a/steel_decimal/tests/boundary_tests.rs b/steel_decimal/tests/boundary_tests.rs new file mode 100644 index 0000000..cd47d64 --- /dev/null +++ b/steel_decimal/tests/boundary_tests.rs @@ -0,0 +1,324 @@ +// tests/boundary_tests.rs +use rstest::*; +use steel_decimal::*; +use rust_decimal::Decimal; +use std::str::FromStr; + +// Test extreme decimal values +#[rstest] +#[case("79228162514264337593543950335")] // Max decimal value +#[case("-79228162514264337593543950335")] // Min decimal value +#[case("0.0000000000000000000000000001")] // Smallest positive decimal (28 decimal places) +#[case("-0.0000000000000000000000000001")] // Smallest negative decimal +#[case("999999999999999999999999999.9999")] // Near maximum with precision +fn test_extreme_decimal_values(#[case] extreme_value: &str) { + // These should not panic, but may return errors for unsupported ranges + let add_result = decimal_add(extreme_value.to_string(), "1".to_string()); + let sub_result = decimal_sub(extreme_value.to_string(), "1".to_string()); + let abs_result = decimal_abs(extreme_value.to_string()); + let conversion_result = to_decimal(extreme_value.to_string()); + + // At minimum, conversion should work for valid decimals + if let Ok(_) = Decimal::from_str(extreme_value) { + assert!(conversion_result.is_ok(), "Valid decimal should convert: {}", extreme_value); + } +} + +// Test maximum precision scenarios +#[rstest] +#[case(0)] +#[case(28)] // Maximum precision +fn test_precision_boundaries(#[case] precision: u32) { + let test_value = "123.456789012345678901234567890123456789"; + + if precision <= 28 { + let result = decimal_format(test_value.to_string(), precision); + assert!(result.is_ok(), "Precision {} should be valid", precision); + + if let Ok(formatted) = result { + if precision == 0 { + assert!(!formatted.contains('.'), "Precision 0 should not have decimal point"); + } else { + let decimal_places = formatted.split('.').nth(1).map(|s| s.len()).unwrap_or(0); + assert!(decimal_places <= precision as usize, + "Result should have at most {} decimal places, got {}", + precision, decimal_places); + } + } + } +} + +// Test precision setting boundaries +#[rstest] +#[case(29)] // One over maximum +#[case(100)] // Way over maximum +#[case(u32::MAX)] // Maximum u32 +fn test_invalid_precision_values(#[case] invalid_precision: u32) { + let result = set_precision(invalid_precision); + assert!(result.contains("Error"), "Should reject precision {}", invalid_precision); + + // Verify precision wasn't actually set + let current = get_precision(); + assert_ne!(current, invalid_precision.to_string()); +} + +// Test very long input strings +#[rstest] +fn test_very_long_inputs() { + // Create very long but valid decimal string + let long_integer = "1".repeat(1000); + let long_decimal = format!("{}.{}", "1".repeat(500), "2".repeat(28)); // Respect max precision + let very_long_decimal = format!("{}.{}", "9".repeat(2000), "1".repeat(28)); + + // These might fail due to decimal limits, but shouldn't panic + let _ = to_decimal(long_integer); + let _ = to_decimal(long_decimal); + let _ = to_decimal(very_long_decimal); + + // Operations on long strings + let _ = decimal_add("1".repeat(100), "2".repeat(100)); + let _ = decimal_mul("1".repeat(50), "3".repeat(50)); +} + +// Test scientific notation boundaries +#[rstest] +#[case("1e308")] // Near f64 max +#[case("1e-324")] // Near f64 min +#[case("1e1000")] // Way beyond f64 +#[case("1e-1000")] // Way beyond f64 +#[case("1.5e100")] +#[case("9.999e99")] +#[case("1.23456789e-50")] +fn test_extreme_scientific_notation(#[case] sci_notation: &str) { + let result = to_decimal(sci_notation.to_string()); + + // Should either succeed or fail gracefully + match result { + Ok(converted) => { + // If successful, should be a valid decimal + assert!(Decimal::from_str(&converted).is_ok(), + "Converted result should be valid decimal: {}", converted); + } + Err(_) => { + // Failure is acceptable for extreme values + } + } +} + +// Test edge cases in arithmetic operations +#[rstest] +fn test_arithmetic_edge_cases() { + let max_decimal = "79228162514264337593543950335"; + let min_decimal = "-79228162514264337593543950335"; + let tiny_decimal = "0.0000000000000000000000000001"; + + // Addition near overflow + let _result = decimal_add(max_decimal.to_string(), "1".to_string()); + // May overflow, but shouldn't panic + + // Subtraction near underflow + let _result = decimal_sub(min_decimal.to_string(), "1".to_string()); + // May underflow, but shouldn't panic + + // Multiplication that could overflow + let _result = decimal_mul(max_decimal.to_string(), "2".to_string()); + // May overflow, but shouldn't panic + + // Division by very small number + let _result = decimal_div("1".to_string(), tiny_decimal.to_string()); + // May be very large, but shouldn't panic + + // All operations should complete without panicking +} + +// Test malformed but potentially parseable inputs +#[rstest] +#[case("1.2.3")] // Multiple decimal points +#[case("1..2")] // Double decimal point +#[case(".123")] // Leading decimal point +#[case("123.")] // Trailing decimal point +#[case("1.23e")] // Incomplete scientific notation +#[case("1.23e+")] // Incomplete positive exponent +#[case("1.23e-")] // Incomplete negative exponent +#[case("e5")] // Missing mantissa +#[case("1e1e1")] // Multiple exponents +#[case("++1")] // Multiple signs +#[case("--1")] // Multiple negative signs +#[case("1.23.45e6")] // Decimal in mantissa and base +fn test_malformed_decimal_inputs(#[case] malformed: &str) { + // These should all fail gracefully, not panic + let result = to_decimal(malformed.to_string()); + assert!(result.is_err(), "Malformed input should be rejected: {}", malformed); + + // Test in arithmetic operations too + let _ = decimal_add(malformed.to_string(), "1".to_string()); + let _ = decimal_sub("1".to_string(), malformed.to_string()); + let _ = decimal_mul(malformed.to_string(), "2".to_string()); + let _ = decimal_abs(malformed.to_string()); +} + +// Test edge cases in comparison operations +#[rstest] +fn test_comparison_edge_cases() { + // Test comparisons at boundaries + let results = [ + decimal_eq("0".to_string(), "-0".to_string()), + decimal_eq("0.0".to_string(), "0.00".to_string()), + decimal_gt("0.0000000000000000000000000001".to_string(), "0".to_string()), + decimal_lt("-0.0000000000000000000000000001".to_string(), "0".to_string()), + ]; + + for result in results { + assert!(result.is_ok(), "Comparison should not fail"); + } + + // Test with extreme values + let max_val = "79228162514264337593543950335"; + let min_val = "-79228162514264337593543950335"; + + assert!(decimal_gt(max_val.to_string(), min_val.to_string()).unwrap_or(false)); + assert!(decimal_lt(min_val.to_string(), max_val.to_string()).unwrap_or(false)); +} + +// Test trigonometric functions at boundaries +#[rstest] +#[case("0")] // sin(0) = 0, cos(0) = 1 +#[case("1.5707963267948966")] // π/2 +#[case("3.1415926535897932")] // π +#[case("6.2831853071795865")] // 2π +#[case("100")] // Large angle +#[case("-100")] // Large negative angle +fn test_trig_function_boundaries(#[case] angle: &str) { + let sin_result = decimal_sin(angle.to_string()); + let cos_result = decimal_cos(angle.to_string()); + let tan_result = decimal_tan(angle.to_string()); + + // These should all complete without panicking + // Results may be imprecise for large angles, but should be finite + if let Ok(sin_val) = sin_result { + let sin_decimal = Decimal::from_str(&sin_val).unwrap(); + assert!(sin_decimal.abs() <= Decimal::from(2), "Sin should be bounded: {}", sin_val); + } + + if let Ok(cos_val) = cos_result { + let cos_decimal = Decimal::from_str(&cos_val).unwrap(); + assert!(cos_decimal.abs() <= Decimal::from(2), "Cos should be bounded: {}", cos_val); + } +} + +// Test logarithmic functions at boundaries +#[rstest] +#[case("1")] // ln(1) = 0 +#[case("2.718281828459045")] // ln(e) = 1 +#[case("0.0000000000000000000000000001")] // Very small positive +#[case("79228162514264337593543950335")] // Very large +fn test_log_function_boundaries(#[case] value: &str) { + let ln_result = decimal_ln(value.to_string()); + let log10_result = decimal_log10(value.to_string()); + + // Should not panic, may return errors for invalid domains + if Decimal::from_str(value).unwrap() > Decimal::ZERO { + // Positive values should potentially work + match ln_result { + Ok(_) => {}, // Success is fine + Err(_) => {}, // Failure is also acceptable for extreme values + } + } else { + // Zero or negative should fail + assert!(ln_result.is_err(), "ln of non-positive should fail"); + } +} + +// Test square root at boundaries +#[rstest] +#[case("0")] // sqrt(0) = 0 +#[case("1")] // sqrt(1) = 1 +#[case("4")] // sqrt(4) = 2 +#[case("0.0000000000000000000000000001")] // Very small +#[case("79228162514264337593543950335")] // Very large +fn test_sqrt_boundaries(#[case] value: &str) { + let result = decimal_sqrt(value.to_string()); + + if Decimal::from_str(value).unwrap() >= Decimal::ZERO { + match result { + Ok(sqrt_val) => { + let sqrt_decimal = Decimal::from_str(&sqrt_val).unwrap(); + assert!(sqrt_decimal >= Decimal::ZERO, "Square root should be non-negative"); + } + Err(_) => { + // May fail for very large values + } + } + } else { + assert!(result.is_err(), "Square root of negative should fail"); + } +} + +// Test power function boundaries +#[rstest] +#[case("2", "0")] // 2^0 = 1 +#[case("2", "1")] // 2^1 = 2 +#[case("2", "10")] // 2^10 = 1024 +#[case("0", "5")] // 0^5 = 0 +#[case("1", "1000")] // 1^1000 = 1 +#[case("2", "100")] // Large exponent +#[case("10", "20")] // Another large case +fn test_pow_boundaries(#[case] base: &str, #[case] exponent: &str) { + let result = decimal_pow(base.to_string(), exponent.to_string()); + + // Should not panic, may overflow for large exponents + match &result { + Ok(_) => {}, // Success is fine + Err(_) => {}, // Overflow/underflow acceptable for extreme cases + } + + // Special cases that should always work + if base == "1" { + // 1^anything = 1 + if let Ok(ref val) = result { + assert_eq!(val, "1"); + } + } + + if exponent == "0" && base != "0" { + // anything^0 = 1 (except 0^0 which is undefined) + if let Ok(ref val) = result { + assert_eq!(val, "1"); + } + } +} + +// Test financial functions with boundary values +#[rstest] +fn test_financial_boundaries() { + // Test percentage calculations + let percentage_tests = [ + ("0", "50"), // 0% of 50 + ("100", "0"), // 100% of 0 + ("100", "100"), // 100% of 100 + ("1000000", "0.001"), // Large amount, tiny percentage + ("0.001", "1000000"), // Tiny amount, huge percentage + ]; + + for (amount, percentage) in percentage_tests { + let result = decimal_percentage(amount.to_string(), percentage.to_string()); + assert!(result.is_ok(), "Percentage calculation should work: {}% of {}", percentage, amount); + } + + // Test compound interest edge cases + let compound_tests = [ + ("1000", "0", "10"), // 0% interest + ("1000", "0.05", "0"), // 0 time periods + ("0", "0.05", "10"), // 0 principal + ("1", "2", "10"), // 200% interest (extreme but valid) + ]; + + for (principal, rate, time) in compound_tests { + let result = decimal_compound(principal.to_string(), rate.to_string(), time.to_string()); + // Some extreme cases may overflow, but shouldn't panic + match result { + Ok(_) => {}, + Err(_) => {}, // Acceptable for extreme cases + } + } +} diff --git a/steel_decimal/tests/concurrency_tests.rs b/steel_decimal/tests/concurrency_tests.rs new file mode 100644 index 0000000..aa9b692 --- /dev/null +++ b/steel_decimal/tests/concurrency_tests.rs @@ -0,0 +1,478 @@ +// tests/concurrency_tests.rs +use steel_decimal::*; +use std::sync::{Arc, Barrier, Mutex}; +use std::thread; +use std::time::Duration; +use std::collections::HashMap; + +// Test precision isolation between threads +#[test] +fn test_precision_thread_isolation() { + let num_threads = 10; + let barrier = Arc::new(Barrier::new(num_threads)); + let results = Arc::new(Mutex::new(Vec::new())); + + let handles: Vec<_> = (0..num_threads) + .map(|thread_id| { + let barrier = barrier.clone(); + let results = results.clone(); + + thread::spawn(move || { + // Each thread sets different precision + let precision = thread_id as u32 % 5; // 0-4 + set_precision(precision); + + // Wait for all threads to set their precision + barrier.wait(); + + // Perform calculation + let result = decimal_add("1.123456789".to_string(), "2.987654321".to_string()).unwrap(); + + // Verify precision is maintained in this thread + let current_precision = get_precision(); + + results.lock().unwrap().push((thread_id, precision, result, current_precision)); + }) + }) + .collect(); + + for handle in handles { + handle.join().unwrap(); + } + + let results = results.lock().unwrap(); + + // Verify each thread maintained its own precision + for (thread_id, set_precision, result, current_precision) in results.iter() { + assert_eq!(current_precision, &set_precision.to_string(), + "Thread {} precision not isolated", thread_id); + + // Verify result respects the precision + if *set_precision > 0 { + let decimal_places = result.split('.').nth(1).map(|s| s.len()).unwrap_or(0); + assert!(decimal_places <= *set_precision as usize, + "Thread {} result {} has more than {} decimal places", + thread_id, result, set_precision); + } + } +} + +// Test concurrent arithmetic operations +#[test] +fn test_concurrent_arithmetic_operations() { + let num_threads = 20; + let operations_per_thread = 100; + let barrier = Arc::new(Barrier::new(num_threads)); + let errors = Arc::new(Mutex::new(Vec::new())); + + let handles: Vec<_> = (0..num_threads) + .map(|thread_id| { + let barrier = barrier.clone(); + let errors = errors.clone(); + + thread::spawn(move || { + barrier.wait(); + + for i in 0..operations_per_thread { + let a = format!("{}.{}", thread_id, i); + let b = format!("{}.{}", i, thread_id); + + // Test various operations don't interfere + let add_result = decimal_add(a.clone(), b.clone()); + let mul_result = decimal_mul(a.clone(), b.clone()); + let sub_result = decimal_sub(a.clone(), b.clone()); + + if add_result.is_err() || mul_result.is_err() || sub_result.is_err() { + errors.lock().unwrap().push(format!( + "Thread {}, iteration {}: arithmetic error", + thread_id, i + )); + } + } + }) + }) + .collect(); + + for handle in handles { + handle.join().unwrap(); + } + + let errors = errors.lock().unwrap(); + assert!(errors.is_empty(), "Concurrent arithmetic errors: {:?}", *errors); +} + +// Test Steel VM registration under concurrent load +#[test] +fn test_concurrent_vm_registration() { + use steel::steel_vm::engine::Engine; + + let num_threads = 5; + let barrier = Arc::new(Barrier::new(num_threads)); + let errors = Arc::new(Mutex::new(Vec::new())); + + let handles: Vec<_> = (0..num_threads) + .map(|thread_id| { + let barrier = barrier.clone(); + let errors = errors.clone(); + + thread::spawn(move || { + barrier.wait(); + + // Each thread creates its own VM and registers functions + let mut vm = Engine::new(); + FunctionRegistry::register_all(&mut vm); + + // Test execution + let script = r#"(decimal-add "1.5" "2.3")"#; + let result = vm.compile_and_run_raw_program(script.to_string()); + + match result { + Ok(vals) => { + if vals.len() != 1 { + errors.lock().unwrap().push(format!( + "Thread {}: Wrong number of results", thread_id + )); + } + } + Err(e) => { + errors.lock().unwrap().push(format!( + "Thread {}: VM execution error: {}", thread_id, e + )); + } + } + }) + }) + .collect(); + + for handle in handles { + handle.join().unwrap(); + } + + let errors = errors.lock().unwrap(); + assert!(errors.is_empty(), "Concurrent VM errors: {:?}", *errors); +} + +// Test variable access concurrency +#[test] +fn test_concurrent_variable_access() { + use steel::steel_vm::engine::Engine; + + let num_threads = 8; + let barrier = Arc::new(Barrier::new(num_threads)); + let errors = Arc::new(Mutex::new(Vec::new())); + + let handles: Vec<_> = (0..num_threads) + .map(|thread_id| { + let barrier = barrier.clone(); + let errors = errors.clone(); + + thread::spawn(move || { + // Each thread has its own variable set + let mut variables = HashMap::new(); + variables.insert(format!("var_{}", thread_id), format!("{}.0", thread_id * 10)); + variables.insert("shared".to_string(), "42.0".to_string()); + + let mut vm = Engine::new(); + FunctionRegistry::register_variables(&mut vm, variables); + + barrier.wait(); + + // Test variable access + let get_script = format!(r#"(get-var "var_{}")"#, thread_id); + let has_script = format!(r#"(has-var? "var_{}")"#, thread_id); + let shared_script = r#"(get-var "shared")"#.to_string(); + + for script in [get_script, shared_script] { + match vm.compile_and_run_raw_program(script) { + Ok(_) => {} + Err(e) => { + errors.lock().unwrap().push(format!( + "Thread {}: Variable access error: {}", thread_id, e + )); + } + } + } + + match vm.compile_and_run_raw_program(has_script) { + Ok(_) => {} + Err(e) => { + errors.lock().unwrap().push(format!( + "Thread {}: Variable check error: {}", thread_id, e + )); + } + } + }) + }) + .collect(); + + for handle in handles { + handle.join().unwrap(); + } + + let errors = errors.lock().unwrap(); + assert!(errors.is_empty(), "Concurrent variable access errors: {:?}", *errors); +} + +// Test precision state under rapid changes +#[test] +fn test_rapid_precision_changes() { + let num_threads = 4; + let changes_per_thread = 1000; + let barrier = Arc::new(Barrier::new(num_threads)); + let inconsistencies = Arc::new(Mutex::new(0)); + + let handles: Vec<_> = (0..num_threads) + .map(|_thread_id| { + let barrier = barrier.clone(); + let inconsistencies = inconsistencies.clone(); + + thread::spawn(move || { + barrier.wait(); + + for i in 0..changes_per_thread { + let precision = (i % 5) as u32; // Cycle through 0-4 + + set_precision(precision); + + // Immediately check precision + let current = get_precision(); + if current != precision.to_string() { + *inconsistencies.lock().unwrap() += 1; + } + + // Perform calculation and verify + let result = decimal_add("1.123456".to_string(), "2.654321".to_string()).unwrap(); + + if precision > 0 { + let decimal_places = result.split('.').nth(1).map(|s| s.len()).unwrap_or(0); + if decimal_places > precision as usize { + *inconsistencies.lock().unwrap() += 1; + } + } + } + }) + }) + .collect(); + + for handle in handles { + handle.join().unwrap(); + } + + let inconsistencies = *inconsistencies.lock().unwrap(); + assert_eq!(inconsistencies, 0, "Found {} precision inconsistencies", inconsistencies); +} + +// Test parser thread safety +#[test] +fn test_parser_thread_safety() { + let num_threads = 10; + let transformations_per_thread = 100; + let barrier = Arc::new(Barrier::new(num_threads)); + let errors = Arc::new(Mutex::new(Vec::new())); + + let test_scripts = vec![ + "(+ 1.5 2.3)", + "(* $x $y)", + "(sqrt (+ (* $a $a) (* $b $b)))", + "(/ (- $max $min) 2)", + "(abs (- $value $target))", + ]; + + let handles: Vec<_> = (0..num_threads) + .map(|thread_id| { + let barrier = barrier.clone(); + let errors = errors.clone(); + let scripts = test_scripts.clone(); + + thread::spawn(move || { + let parser = ScriptParser::new(); + barrier.wait(); + + for i in 0..transformations_per_thread { + let script = &scripts[i % scripts.len()]; + + let transformed = parser.transform(script); + let _dependencies = parser.extract_dependencies(script); + + // Basic validation + let open_count = transformed.chars().filter(|c| *c == '(').count(); + let close_count = transformed.chars().filter(|c| *c == ')').count(); + + if open_count != close_count { + errors.lock().unwrap().push(format!( + "Thread {}, iteration {}: Unbalanced parentheses in {}", + thread_id, i, transformed + )); + } + + if !transformed.contains("decimal-") && script.contains('+') { + errors.lock().unwrap().push(format!( + "Thread {}, iteration {}: Transformation failed for {}", + thread_id, i, script + )); + } + } + }) + }) + .collect(); + + for handle in handles { + handle.join().unwrap(); + } + + let errors = errors.lock().unwrap(); + assert!(errors.is_empty(), "Parser thread safety errors: {:?}", *errors); +} + +// Test memory safety under concurrent load +#[test] +fn test_memory_safety_concurrent_load() { + let num_threads = 8; + let iterations = 500; + let barrier = Arc::new(Barrier::new(num_threads)); + + let handles: Vec<_> = (0..num_threads) + .map(|thread_id| { + let barrier = barrier.clone(); + + thread::spawn(move || { + barrier.wait(); + + // Create many SteelDecimal instances + for i in 0..iterations { + let mut steel_decimal = SteelDecimal::new(); + + // Add variables + steel_decimal.add_variable(format!("var_{}", i), format!("{}.{}", thread_id, i)); + + // Transform scripts + let script = format!("(+ {} {})", i, thread_id); + let _ = steel_decimal.transform(&script); + + // Extract dependencies + let _ = steel_decimal.extract_dependencies(&script); + + // Small delay to increase chance of race conditions + if i % 100 == 0 { + thread::sleep(Duration::from_micros(1)); + } + } + }) + }) + .collect(); + + for handle in handles { + handle.join().unwrap(); + } + + // If we get here without panicking, memory safety is maintained +} + +// Test precision cleanup after thread termination +#[test] +fn test_precision_cleanup_after_thread_death() { + // Create thread that sets precision and dies + let handle = thread::spawn(|| { + set_precision(3); + decimal_add("1.123456".to_string(), "2.654321".to_string()).unwrap() + }); + + let result = handle.join().unwrap(); + + // Verify the result had the precision applied + let decimal_places = result.split('.').nth(1).map(|s| s.len()).unwrap_or(0); + assert!(decimal_places <= 3); + + // In main thread, precision should be unaffected + let main_precision = get_precision(); + // Should be "full" (default) since we haven't set it in main thread + assert_eq!(main_precision, "full"); + + // Create another thread - should start fresh + let handle2 = thread::spawn(|| { + let precision = get_precision(); + (precision, decimal_add("1.123456".to_string(), "2.654321".to_string()).unwrap()) + }); + + let (new_precision, new_result) = handle2.join().unwrap(); + assert_eq!(new_precision, "full"); + + // This result should use full precision + let new_decimal_places = new_result.split('.').nth(1).map(|s| s.len()).unwrap_or(0); + assert!(new_decimal_places > 3); // Should be more than the previous thread's precision +} + +// Stress test with mixed operations +#[test] +fn test_concurrent_stress_mixed_operations() { + let num_threads = 6; + let operations_per_thread = 200; + let barrier = Arc::new(Barrier::new(num_threads)); + let total_errors = Arc::new(Mutex::new(0)); + + let handles: Vec<_> = (0..num_threads) + .map(|thread_id| { + let barrier = barrier.clone(); + let total_errors = total_errors.clone(); + + thread::spawn(move || { + let mut errors = 0; + barrier.wait(); + + for i in 0..operations_per_thread { + // Mix of precision settings + if i % 50 == 0 { + set_precision((thread_id as u32) % 5); + } + + // Mix of operations + match i % 6 { + 0 => { + if decimal_add(format!("{}.{}", thread_id, i), "1.0".to_string()).is_err() { + errors += 1; + } + } + 1 => { + if decimal_mul(format!("{}", i), format!("{}.5", thread_id)).is_err() { + errors += 1; + } + } + 2 => { + if decimal_sqrt(format!("{}", i + 1)).is_err() && i > 0 { + errors += 1; + } + } + 3 => { + if decimal_abs(format!("-{}.{}", thread_id, i)).is_err() { + errors += 1; + } + } + 4 => { + if decimal_gt(format!("{}", i), format!("{}", thread_id)).is_err() { + errors += 1; + } + } + 5 => { + if to_decimal(format!("{}.{}e1", thread_id, i)).is_err() { + errors += 1; + } + } + _ => unreachable!() + } + } + + *total_errors.lock().unwrap() += errors; + }) + }) + .collect(); + + for handle in handles { + handle.join().unwrap(); + } + + let total_errors = *total_errors.lock().unwrap(); + + // Allow some errors for edge cases (like sqrt of 0), but not too many + assert!(total_errors < num_threads * operations_per_thread / 10, + "Too many errors in stress test: {}", total_errors); +} diff --git a/steel_decimal/tests/property_tests.rs b/steel_decimal/tests/property_tests.rs new file mode 100644 index 0000000..029298d --- /dev/null +++ b/steel_decimal/tests/property_tests.rs @@ -0,0 +1,338 @@ +// tests/property_tests.rs +use proptest::prelude::*; +use steel_decimal::*; +use rust_decimal::Decimal; +use std::str::FromStr; + +// Strategy for generating valid decimal strings +fn decimal_string() -> impl Strategy { + prop_oneof![ + // Small integers + (-1000i32..1000i32).prop_map(|i| i.to_string()), + // Small decimals with 1-6 decimal places + ( + -1000i32..1000i32, + 1..1000000u32 + ).prop_map(|(whole, frac)| { + let frac_str = format!("{:06}", frac); + format!("{}.{}", whole, frac_str.trim_end_matches('0')) + }), + // Scientific notation + ( + -100i32..100i32, + -10i32..10i32 + ).prop_map(|(mantissa, exp)| format!("{}e{}", mantissa, exp)), + // Very small numbers + Just("0.000000000000000001".to_string()), + Just("0.000000000000000000000000001".to_string()), + // Numbers at decimal precision limits + Just("99999999999999999999999999.9999".to_string()), + ] +} + +// Strategy for generating valid precision values +fn precision_value() -> impl Strategy { + 0..=28u32 +} + +// Property: Basic arithmetic operations preserve decimal precision semantics +proptest! { + #[test] + fn test_arithmetic_commutativity( + a in decimal_string(), + b in decimal_string() + ) { + // Addition should be commutative: a + b = b + a + let result1 = decimal_add(a.clone(), b.clone()); + let result2 = decimal_add(b, a); + + match (result1, result2) { + (Ok(r1), Ok(r2)) => { + // Parse both results and compare as decimals + let d1 = Decimal::from_str(&r1).unwrap(); + let d2 = Decimal::from_str(&r2).unwrap(); + prop_assert_eq!(d1, d2); + } + (Err(_), Err(_)) => { + // Both should fail in the same way for invalid inputs + } + _ => prop_assert!(false, "Inconsistent error handling") + } + } + + #[test] + fn test_multiplication_commutativity( + a in decimal_string(), + b in decimal_string() + ) { + let result1 = decimal_mul(a.clone(), b.clone()); + let result2 = decimal_mul(b, a); + + match (result1, result2) { + (Ok(r1), Ok(r2)) => { + let d1 = Decimal::from_str(&r1).unwrap(); + let d2 = Decimal::from_str(&r2).unwrap(); + prop_assert_eq!(d1, d2); + } + (Err(_), Err(_)) => {} + _ => prop_assert!(false, "Inconsistent error handling") + } + } + + #[test] + fn test_addition_associativity( + a in decimal_string(), + b in decimal_string(), + c in decimal_string() + ) { + // (a + b) + c = a + (b + c) + let ab = decimal_add(a.clone(), b.clone()); + let bc = decimal_add(b, c.clone()); + + if let (Ok(ab_result), Ok(bc_result)) = (ab, bc) { + let left = decimal_add(ab_result, c); + let right = decimal_add(a, bc_result); + + if let (Ok(left_final), Ok(right_final)) = (left, right) { + let d1 = Decimal::from_str(&left_final).unwrap(); + let d2 = Decimal::from_str(&right_final).unwrap(); + prop_assert_eq!(d1, d2); + } + } + } + + #[test] + fn test_multiplication_by_zero(a in decimal_string()) { + let result = decimal_mul(a, "0".to_string()); + if let Ok(r) = result { + let d = Decimal::from_str(&r).unwrap(); + prop_assert!(d.is_zero()); + } + } + + #[test] + fn test_addition_with_zero_identity(a in decimal_string()) { + let result = decimal_add(a.clone(), "0".to_string()); + match result { + Ok(r) => { + // Converting through decimal and back should give equivalent result + if let Ok(original) = Decimal::from_str(&a) { + let result_decimal = Decimal::from_str(&r).unwrap(); + prop_assert_eq!(original, result_decimal); + } + } + Err(_) => { + // If a is invalid, this is expected + prop_assert!(Decimal::from_str(&a).is_err()); + } + } + } + + #[test] + fn test_division_then_multiplication_inverse( + a in decimal_string(), + b in decimal_string().prop_filter("b != 0", |b| b != "0") + ) { + // (a / b) * b should approximately equal a + let div_result = decimal_div(a.clone(), b.clone()); + if let Ok(quotient) = div_result { + let mul_result = decimal_mul(quotient, b); + if let Ok(final_result) = mul_result { + if let (Ok(original), Ok(final_decimal)) = + (Decimal::from_str(&a), Decimal::from_str(&final_result)) { + // Allow for small rounding differences + let diff = (original - final_decimal).abs(); + let tolerance = Decimal::from_str("0.000000000001").unwrap(); + prop_assert!(diff <= tolerance, + "Division-multiplication not inverse: {} vs {}", + original, final_decimal); + } + } + } + } + + #[test] + fn test_absolute_value_properties(a in decimal_string()) { + let abs_result = decimal_abs(a.clone()); + if let Ok(abs_val) = abs_result { + let abs_decimal = Decimal::from_str(&abs_val).unwrap(); + + // abs(x) >= 0 + prop_assert!(abs_decimal >= Decimal::ZERO); + + // abs(abs(x)) = abs(x) + let double_abs = decimal_abs(abs_val); + if let Ok(double_abs_val) = double_abs { + let double_abs_decimal = Decimal::from_str(&double_abs_val).unwrap(); + prop_assert_eq!(abs_decimal, double_abs_decimal); + } + } + } + + #[test] + fn test_comparison_transitivity( + a in decimal_string(), + b in decimal_string(), + c in decimal_string() + ) { + // If a > b and b > c, then a > c + let ab = decimal_gt(a.clone(), b.clone()); + let bc = decimal_gt(b, c.clone()); + let ac = decimal_gt(a, c); + + if let (Ok(true), Ok(true), Ok(ac_result)) = (ab, bc, ac) { + prop_assert!(ac_result, "Transitivity violated for > comparison"); + } + } + + #[test] + fn test_min_max_properties( + a in decimal_string(), + b in decimal_string() + ) { + let min_result = decimal_min(a.clone(), b.clone()); + let max_result = decimal_max(a.clone(), b.clone()); + + if let (Ok(min_val), Ok(max_val)) = (min_result, max_result) { + let min_decimal = Decimal::from_str(&min_val).unwrap(); + let max_decimal = Decimal::from_str(&max_val).unwrap(); + + // min(a,b) <= max(a,b) + prop_assert!(min_decimal <= max_decimal); + + // min(a,b) should equal either a or b + if let (Ok(a_decimal), Ok(b_decimal)) = + (Decimal::from_str(&a), Decimal::from_str(&b)) { + prop_assert!(min_decimal == a_decimal || min_decimal == b_decimal); + prop_assert!(max_decimal == a_decimal || max_decimal == b_decimal); + } + } + } + + #[test] + fn test_round_trip_conversion(a in decimal_string()) { + // to_decimal should be idempotent for valid decimals + let first_conversion = to_decimal(a.clone()); + if let Ok(converted) = first_conversion { + let second_conversion = to_decimal(converted.clone()); + prop_assert_eq!(Ok(converted), second_conversion); + } + } + + #[test] + fn test_precision_formatting_consistency( + a in decimal_string(), + precision in precision_value() + ) { + let formatted = decimal_format(a.clone(), precision); + if let Ok(result) = formatted { + // Formatting again with same precision should be idempotent + let reformatted = decimal_format(result.clone(), precision); + prop_assert_eq!(Ok(result.clone()), reformatted); + + // Result should have at most 'precision' decimal places + if let Some(dot_pos) = result.find('.') { + let decimal_part = &result[dot_pos + 1..]; + prop_assert!(decimal_part.len() <= precision as usize); + } + } + } + + #[test] + fn test_sqrt_then_square_approximate_inverse( + a in decimal_string().prop_filter("positive", |s| { + Decimal::from_str(s).map(|d| d >= Decimal::ZERO).unwrap_or(false) + }) + ) { + let sqrt_result = decimal_sqrt(a.clone()); + if let Ok(sqrt_val) = sqrt_result { + let square_result = decimal_mul(sqrt_val.clone(), sqrt_val); + if let Ok(square_val) = square_result { + if let (Ok(original), Ok(squared)) = + (Decimal::from_str(&a), Decimal::from_str(&square_val)) { + // Allow for rounding differences in sqrt + let diff = (original - squared).abs(); + let tolerance = Decimal::from_str("0.0001").unwrap(); + prop_assert!(diff <= tolerance, + "sqrt-square not approximate inverse: {} vs {}", + original, squared); + } + } + } + } +} + +// Property tests for parser transformation +proptest! { + #[test] + fn test_parser_transformation_preserves_structure( + operations in prop::collection::vec( + prop_oneof!["+" , "-", "*", "/", "sqrt", "abs"], + 1..5usize + ) + ) { + let parser = ScriptParser::new(); + + // Generate a simple expression + let expr = format!("({} 1 2)", operations[0]); + let transformed = parser.transform(&expr); + + // Transformed should be balanced parentheses + let open_count = transformed.chars().filter(|c| *c == '(').count(); + let close_count = transformed.chars().filter(|c| *c == ')').count(); + prop_assert_eq!(open_count, close_count); + + // Should contain decimal function + prop_assert!(transformed.contains("decimal-")); + } + + #[test] + fn test_variable_extraction_correctness( + var_names in prop::collection::vec("[a-zA-Z][a-zA-Z0-9_]*", 1..10) + ) { + let parser = ScriptParser::new(); + + // Create expression with variables + let expr = format!("(+ ${})", var_names.join(" $")); + let dependencies = parser.extract_dependencies(&expr); + + // Should extract all variable names + for name in &var_names { + prop_assert!(dependencies.contains(name)); + } + + // Should not extract extra variables + prop_assert_eq!(dependencies.len(), var_names.len()); + } +} + +// Fuzzing-style tests for edge cases +proptest! { + #[test] + fn test_no_panics_on_random_input( + input in ".*" + ) { + // These operations should never panic, only return errors + let _ = to_decimal(input.clone()); + let _ = decimal_add(input.clone(), "1".to_string()); + let _ = decimal_abs(input.clone()); + + let parser = ScriptParser::new(); + let _ = parser.transform(&input); + let _ = parser.extract_dependencies(&input); + } + + #[test] + fn test_scientific_notation_consistency( + mantissa in -1000f64..1000f64, + exponent in -10i32..10i32 + ) { + let sci_notation = format!("{}e{}", mantissa, exponent); + let conversion_result = to_decimal(sci_notation); + + // If conversion succeeds, result should be a valid decimal + if let Ok(result) = conversion_result { + prop_assert!(Decimal::from_str(&result).is_ok()); + } + } +} diff --git a/steel_decimal/tests/security_tests.rs b/steel_decimal/tests/security_tests.rs new file mode 100644 index 0000000..4e1025c --- /dev/null +++ b/steel_decimal/tests/security_tests.rs @@ -0,0 +1,424 @@ +// tests/security_tests.rs +use rstest::*; +use steel_decimal::*; +use steel::steel_vm::engine::Engine; +use std::collections::HashMap; + +// Test stack overflow protection with deeply nested expressions +#[rstest] +fn test_stack_overflow_protection() { + let parser = ScriptParser::new(); + + // Create extremely deep nesting (potential stack overflow) + let mut expr = "1".to_string(); + for i in 0..10000 { + expr = format!("(+ {} {})", expr, i); + } + + // Should not crash the process + let result = std::panic::catch_unwind(|| { + parser.transform(&expr) + }); + + // Either succeeds or panics gracefully, but shouldn't segfault + match result { + Ok(_) => {}, // Transformation succeeded + Err(_) => {}, // Panic caught, which is acceptable + } +} + +// Test memory exhaustion protection +#[rstest] +fn test_memory_exhaustion_protection() { + let parser = ScriptParser::new(); + + // Create expression designed to consume lots of memory + let large_var_name = "x".repeat(1_000_000); // 1MB variable name + let expr = format!("(+ ${} 1)", large_var_name); + + // Should not consume unlimited memory + let result = std::panic::catch_unwind(|| { + parser.transform(&expr) + }); + + // Should handle gracefully + assert!(result.is_ok()); +} + +// Test injection attacks through variable names +#[rstest] +#[case("'; DROP TABLE users; --")] // SQL injection style +#[case("$(rm -rf /)")] // Shell injection style +#[case("")] // XSS style +#[case("../../etc/passwd")] // Path traversal style +#[case("${system('rm -rf /')}")] // Template injection style +#[case("{{7*7}}")] // Template injection +#[case("__proto__")] // Prototype pollution +#[case("constructor")] // Constructor pollution +#[case("\\x00\\x01\\x02")] // Null bytes and control chars +fn test_variable_name_injection(#[case] malicious_var: &str) { + let parser = ScriptParser::new(); + + // Attempt injection through variable name + let expr = format!("(+ ${} 1)", malicious_var); + let transformed = parser.transform(&expr); + + // Should transform without executing malicious code + assert!(transformed.contains("get-var")); + assert!(transformed.contains(malicious_var)); + + // Should extract as dependency without side effects + let deps = parser.extract_dependencies(&expr); + assert!(deps.contains(malicious_var)); +} + +// Test malicious Steel expressions +#[rstest] +#[case("(eval '(system \"rm -rf /\"))")] // Code execution attempt +#[case("(load \"../../etc/passwd\")")] // File access attempt +#[case("(require 'os) (os/execute \"malicious-command\")")] // Module injection +#[case("(define loop (lambda () (loop))) (loop)")] // Infinite recursion +#[case("(define mem-bomb (lambda () (cons 1 (mem-bomb)))) (mem-bomb)")] // Memory bomb +fn test_malicious_steel_expressions(#[case] malicious_expr: &str) { + let steel_decimal = SteelDecimal::new(); + + // Should not execute malicious Steel code during transformation + let transformed = steel_decimal.transform(malicious_expr); + + // Transformation should complete without side effects + assert!(!transformed.is_empty()); + + // Should not contain the original malicious functions if transformed + if malicious_expr.contains("eval") || malicious_expr.contains("load") { + // These shouldn't be transformed into decimal operations + assert!(!transformed.contains("decimal-")); + } +} + +// Test parser regex exploitation +#[rstest] +#[case("((((((((((a")] // Unbalanced parentheses +fn test_parser_regex_exploitation_simple(#[case] malicious_input: &str) { + let parser = ScriptParser::new(); + + // Should not hang or consume excessive CPU + let start = std::time::Instant::now(); + let result = std::panic::catch_unwind(|| { + parser.transform(malicious_input) + }); + let duration = start.elapsed(); + + // Should complete within reasonable time (not ReDoS) + assert!(duration.as_secs() < 5, "Parser took too long: {:?}", duration); + + // Should not crash + assert!(result.is_ok()); +} + +#[rstest] +fn test_parser_regex_exploitation_large_inputs() { + let parser = ScriptParser::new(); + + // Test extremely long variable reference + let large_var = format!("${}", "a".repeat(100000)); + let start = std::time::Instant::now(); + let result = std::panic::catch_unwind(|| { + parser.transform(&large_var) + }); + let duration = start.elapsed(); + assert!(duration.as_secs() < 5, "Large variable parsing took too long: {:?}", duration); + assert!(result.is_ok()); + + // Test repeated operators + let repeated_ops = format!("({}{})", "+".repeat(100000), " 1 2)"); + let start = std::time::Instant::now(); + let result = std::panic::catch_unwind(|| { + parser.transform(&repeated_ops) + }); + let duration = start.elapsed(); + assert!(duration.as_secs() < 5, "Repeated operators parsing took too long: {:?}", duration); + assert!(result.is_ok()); + + // Test huge string literals + let huge_string = format!("\"{}\"", "a".repeat(1000000)); + let start = std::time::Instant::now(); + let result = std::panic::catch_unwind(|| { + parser.transform(&huge_string) + }); + let duration = start.elapsed(); + assert!(duration.as_secs() < 5, "Huge string parsing took too long: {:?}", duration); + assert!(result.is_ok()); +} + +// Test Steel VM security integration +#[rstest] +fn test_steel_vm_security_integration() { + let mut vm = Engine::new(); + let steel_decimal = SteelDecimal::new(); + steel_decimal.register_functions(&mut vm); + + // Test that we can't escape decimal functions to execute arbitrary code + let malicious_scripts = vec![ + r#"(eval "(system \"echo pwned\")")"#, + r#"(load "../../etc/passwd")"#, + r#"(define dangerous (lambda () (system "rm -rf /")))"#, + r#"(require 'steel/core)"#, // Try to access core modules + ]; + + for script in malicious_scripts { + let result = vm.compile_and_run_raw_program(script.to_string()); + + // These should fail to compile or execute, not succeed + match result { + Ok(_) => { + // If it succeeds, verify it didn't do anything dangerous + // (We can't really test this without side effects, so we assume it's safe) + } + Err(_) => { + // Expected - should fail to execute dangerous code + } + } + } +} + +// Test variable access security +#[rstest] +fn test_variable_access_security() { + let mut variables = HashMap::new(); + variables.insert("safe_var".to_string(), "42".to_string()); + variables.insert("password".to_string(), "secret123".to_string()); + variables.insert("api_key".to_string(), "key_abc123".to_string()); + + let mut vm = Engine::new(); + FunctionRegistry::register_variables(&mut vm, variables); + + // Test that we can't enumerate all variables + let enumeration_attempts = vec![ + r#"(map get-var (list "password" "api_key" "secret"))"#, + r#"(get-var "")"#, // Empty variable name + r#"(get-var nil)"#, // Nil variable name + ]; + + for attempt in enumeration_attempts { + let result = vm.compile_and_run_raw_program(attempt.to_string()); + // Should either fail or not reveal sensitive information + match result { + Ok(_) => {}, // If succeeds, assume it's safe + Err(_) => {}, // Expected failure + } + } +} + +// Test format string attacks through decimal formatting +#[rstest] +#[case("%s%s%s%s")] // Format string attack +#[case("%n")] // Write to memory attempt +#[case("%x%x%x%x")] // Memory reading attempt +#[case("\\x41\\x41\\x41\\x41")] // Buffer overflow attempt +fn test_format_string_attacks(#[case] format_attack: &str) { + // Test in various contexts where user input might be formatted + let _ = to_decimal(format_attack.to_string()); + let _ = decimal_add(format_attack.to_string(), "1".to_string()); + let _ = decimal_format("123.456".to_string(), 2); // Shouldn't use user input as format + + // Should not crash or leak memory +} + +// Test buffer overflow attempts +#[rstest] +fn test_buffer_overflow_attempts() { + // Test with very long inputs that might cause buffer overflows in C libraries + let long_input = "A".repeat(100_000); + let long_number = "1".repeat(10_000) + "." + &"2".repeat(10_000); + + // Should handle gracefully without buffer overflows + let _ = to_decimal(long_input); + let _ = to_decimal(long_number.clone()); + let _ = decimal_add(long_number.clone(), "1".to_string()); + let _ = decimal_sqrt(long_number); + + // If we get here without crashing, buffer overflow protection works +} + +// Test denial of service through resource exhaustion +#[rstest] +fn test_resource_exhaustion_protection() { + let steel_decimal = SteelDecimal::new(); + + // Test CPU exhaustion + let cpu_bomb = "(+ ".repeat(10000) + "1" + &")".repeat(10000); + let start = std::time::Instant::now(); + let _ = steel_decimal.transform(&cpu_bomb); + let duration = start.elapsed(); + + // Should not take excessive time + assert!(duration.as_secs() < 10, "CPU exhaustion detected"); + + // Test memory exhaustion through many variables + let mut steel_decimal = SteelDecimal::new(); + for i in 0..100_000 { + steel_decimal.add_variable(format!("var_{}", i), "1".to_string()); + } + + // Should handle many variables without exhausting memory + let expr = "(+ $var_0 $var_99999)"; + let _ = steel_decimal.transform(expr); +} + +// Test integer overflow/underflow in precision settings +#[rstest] +#[case(u32::MAX)] +#[case(u32::MAX - 1)] +fn test_integer_overflow_in_precision(#[case] overflow_value: u32) { + // Should handle overflow gracefully + let result = set_precision(overflow_value); + assert!(result.contains("Error") || result.contains("Maximum")); + + // Should not set invalid precision + let current = get_precision(); + assert_ne!(current, overflow_value.to_string()); +} + +// Test race conditions in precision settings (security through thread safety) +#[rstest] +fn test_precision_race_conditions() { + use std::sync::{Arc, Barrier}; + use std::thread; + + let num_threads = 10; + let barrier = Arc::new(Barrier::new(num_threads)); + let success_count = Arc::new(std::sync::atomic::AtomicU32::new(0)); + + let handles: Vec<_> = (0..num_threads) + .map(|thread_id| { + let barrier = barrier.clone(); + let success_count = success_count.clone(); + + thread::spawn(move || { + barrier.wait(); + + // Try to cause race condition + for i in 0..1000 { + let precision = (thread_id + i) % 5; + set_precision(precision as u32); + + // Immediately use precision + let result = decimal_add("1.123456789".to_string(), "2.987654321".to_string()); + if result.is_ok() { + success_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + } + }) + }) + .collect(); + + for handle in handles { + handle.join().unwrap(); + } + + // Should have high success rate (race conditions would cause failures) + let successes = success_count.load(std::sync::atomic::Ordering::Relaxed); + assert!(successes > (num_threads * 900) as u32, "Too many race condition failures: {}", successes); +} + +// Test SQL injection style attacks through numeric inputs +#[rstest] +#[case("1; DROP TABLE decimals; --")] +#[case("1' OR '1'='1")] +#[case("1 UNION SELECT * FROM passwords")] +#[case("1; exec('rm -rf /')")] +fn test_sql_injection_style_attacks(#[case] injection_attempt: &str) { + // These should be treated as invalid decimal formats + let result = to_decimal(injection_attempt.to_string()); + assert!(result.is_err(), "SQL injection attempt should fail: {}", injection_attempt); + + // Should also fail in arithmetic + let add_result = decimal_add(injection_attempt.to_string(), "1".to_string()); + assert!(add_result.is_err(), "Arithmetic with injection should fail"); +} + +// Test path traversal through variable names +#[rstest] +#[case("../../../etc/passwd")] +#[case("..\\..\\..\\windows\\system32\\config\\sam")] +#[case("/etc/passwd")] +#[case("C:\\Windows\\System32\\config\\SAM")] +#[case("file:///etc/passwd")] +#[case("data:text/plain;base64,cm9vdDp4OjA6MA==")] +fn test_path_traversal_attacks(#[case] path_attack: &str) { + let mut steel_decimal = SteelDecimal::new(); + + // Should treat as normal variable name, not file path + steel_decimal.add_variable(path_attack.to_string(), "42".to_string()); + + let expr = format!("(+ ${} 1)", path_attack); + let transformed = steel_decimal.transform(&expr); + + // Should treat as variable reference, not attempt file access + assert!(transformed.contains("get-var")); + assert!(transformed.contains(path_attack)); +} + +// Test XML/HTML injection through variable values +#[rstest] +#[case("content")] +#[case("")] +#[case("]>")] +fn test_xml_html_injection(#[case] xml_attack: &str) { + let mut steel_decimal = SteelDecimal::new(); + + // Should treat as string value, not parse as XML/HTML + steel_decimal.add_variable("test_var".to_string(), xml_attack.to_string()); + + let vars = steel_decimal.get_variables(); + assert_eq!(vars.get("test_var").unwrap(), xml_attack); + + // Should not interpret as markup + assert!(!xml_attack.is_empty()); // Basic sanity check +} + +// Test deserialization attacks +#[rstest] +fn test_deserialization_attacks() { + // Test with serialized data that might trigger deserialization vulnerabilities + let malicious_serialized = vec![ + "rO0ABXNyABFqYXZhLnV0aWwuSGFzaE1hcAUH2sHDFmDRAwACRgAKbG9hZEZhY3RvckkACXRocmVzaG9sZHhwP0AAAAAAAAx3CAAAABAAAAABdAABYXQAAWJ4", + "AC ED 00 05 73 72", + "pickle\\x80\\x03]q\\x00.", + ]; + + for payload in malicious_serialized { + // Should treat as regular string, not attempt deserialization + let result = to_decimal(payload.to_string()); + assert!(result.is_err(), "Serialized payload should not be valid decimal"); + + let mut steel_decimal = SteelDecimal::new(); + steel_decimal.add_variable("payload".to_string(), payload.to_string()); + + // Should store as string value + assert_eq!(steel_decimal.get_variables().get("payload").unwrap(), payload); + } +} + +// Test timing attacks +#[rstest] +fn test_timing_attack_resistance() { + // Test that comparison operations don't leak information through timing + let values = vec!["1", "1.0", "1.00", "1.000"]; + let mut times = Vec::new(); + + for value in values { + let start = std::time::Instant::now(); + let _ = decimal_eq(value.to_string(), "1".to_string()); + let duration = start.elapsed(); + times.push(duration); + } + + // Times should be relatively similar (not vulnerable to timing attacks) + let max_time = times.iter().max().unwrap(); + let min_time = times.iter().min().unwrap(); + let ratio = max_time.as_nanos() as f64 / min_time.as_nanos() as f64; + + // Allow for reasonable variance but not massive differences + assert!(ratio < 10.0, "Timing attack vulnerability detected: ratio = {}", ratio); +}