From 4d5d22d0c20b4119d8df45fffa359a2d0f0eb181 Mon Sep 17 00:00:00 2001 From: filipriec Date: Mon, 7 Jul 2025 00:31:13 +0200 Subject: [PATCH] precision to steel decimal crate implemented --- steel_decimal/docs/rust_decimal.txt | 366 ++++++++++++++++++++++++++ steel_decimal/src/functions.rs | 167 +++++++++--- steel_decimal/src/registry.rs | 18 ++ steel_decimal/tests/function_tests.rs | 52 ++++ 4 files changed, 566 insertions(+), 37 deletions(-) create mode 100644 steel_decimal/docs/rust_decimal.txt diff --git a/steel_decimal/docs/rust_decimal.txt b/steel_decimal/docs/rust_decimal.txt new file mode 100644 index 0000000..747c6bd --- /dev/null +++ b/steel_decimal/docs/rust_decimal.txt @@ -0,0 +1,366 @@ +# rust_decimal for Financial Applications: Complete Guide + +rust_decimal provides a 128-bit fixed-precision decimal implementation designed specifically for financial calculations, eliminating floating-point rounding errors that plague traditional financial software. With a 96-bit mantissa and support for up to 28 decimal places, it offers the exact precision required for accounting and monetary calculations while maintaining performance suitable for high-throughput financial systems. + +## Input handling best practices + +### String parsing and validation patterns + +rust_decimal provides multiple parsing methods optimized for different input scenarios. The **most robust approach for financial applications** uses `from_str_exact()` for strict validation combined with comprehensive error handling: + +```rust +use rust_decimal::{Decimal, Error}; + +fn parse_financial_amount(input: &str) -> Result { + let trimmed = input.trim(); + + // Pre-validation checks + if trimmed.is_empty() { + return Err(AmountError::EmptyInput); + } + + if trimmed.len() > 50 { + return Err(AmountError::InputTooLong); + } + + // Use from_str_exact for strict parsing + Decimal::from_str_exact(trimmed) + .map_err(|e| match e { + Error::InvalidOperation => AmountError::InvalidFormat, + Error::Underflow => AmountError::Underflow, + Error::Overflow => AmountError::Overflow, + _ => AmountError::ParseError, + }) +} +``` + +The library supports **multiple input formats automatically**: standard decimal notation (`"123.45"`), scientific notation via `from_scientific("2.512e1")`, and different radix bases through `from_str_radix()`. For **compile-time optimization**, use the `dec!()` macro which parses literals at compile time with zero runtime cost. + +### Precision and currency considerations + +Different financial contexts require specific precision strategies. **Standard recommendations** include 2-4 decimal places for retail/e-commerce, 4-6 for forex trading, 8-18 for cryptocurrency, and 4-6 for GAAP accounting compliance. The library's maximum scale of 28 decimal places accommodates even the most demanding financial calculations. + +**Currency-specific validation patterns** should enforce appropriate ranges and scales: + +```rust +pub struct ValidatedAmount(Decimal); + +impl ValidatedAmount { + pub fn new(value: Decimal) -> Result { + // Range validation + if value < Decimal::MIN || value > Decimal::MAX { + return Err(FinancialError::OutOfRange); + } + + // Precision validation for currency context + if value.scale() > 28 { + return Err(FinancialError::ScaleExceeded); + } + + Ok(ValidatedAmount(value)) + } +} +``` + +### Handling different input formats + +For **production systems processing various input formats**, implement a unified parsing strategy that handles integers, decimals, and scientific notation: + +```rust +// Automatic conversion from integers +let amount = Decimal::from(12345_i64); // 12345 + +// Float conversion with precision control +let price = Decimal::from_f64(123.45).unwrap(); + +// Scientific notation parsing +let large_amount = Decimal::from_scientific("1.23e6").unwrap(); // 1230000 + +// String parsing with validation +let user_input = "99.99"; +let parsed = Decimal::from_str(user_input)?; +``` + +## Output formatting best practices + +### Display and precision control + +rust_decimal provides multiple formatting approaches optimized for different financial contexts. The **standard approach** uses the `Display` trait for human-readable output, while **precision-controlled formatting** uses `round_dp()` for specific decimal places: + +```rust +let amount = dec!(123.456789); + +// Standard string representation +let display = amount.to_string(); // "123.456789" + +// Precision-controlled output +let currency_format = amount.round_dp(2).to_string(); // "123.46" + +// Scientific notation for large numbers +let scientific = format!("{:e}", amount); // "1.23456789e2" +``` + +### Rounding strategies for accounting + +The library implements **comprehensive rounding strategies** including banker's rounding (IEEE 754 compliant) which eliminates systematic bias in large datasets: + +```rust +use rust_decimal::RoundingStrategy; + +let tax = dec!(3.4395); + +// Banker's rounding (default) - preferred for financial compliance +let rounded = tax.round_dp(2); // Uses MidpointNearestEven + +// Explicit rounding strategies +let away_from_zero = tax.round_dp_with_strategy(2, RoundingStrategy::MidpointAwayFromZero); +let truncated = tax.round_dp_with_strategy(2, RoundingStrategy::ToZero); +``` + +### Currency formatting and localization + +For **multi-currency applications**, implement currency-aware formatting that maintains precision requirements: + +```rust +#[derive(Debug, Clone)] +pub struct Money { + amount: Decimal, + currency: Currency, +} + +impl Money { + pub fn format_for_display(&self, precision: u32) -> String { + match self.currency { + Currency::USD => format!("${}", self.amount.round_dp(precision)), + Currency::EUR => format!("€{}", self.amount.round_dp(precision)), + Currency::BTC => format!("₿{}", self.amount.round_dp(8)), + } + } +} +``` + +## Robust conversion patterns + +### String-to-Decimal-to-String pipeline + +The **most efficient conversion pipeline** for financial applications uses compile-time optimization where possible and validated parsing for runtime inputs: + +```rust +// Compile-time optimization for known values +const COMMISSION_RATE: Decimal = dec!(0.0025); +const TAX_RATE: Decimal = dec!(0.15); + +// Runtime parsing with validation +fn process_transaction(amount_str: &str) -> Result { + let amount = parse_financial_amount(amount_str)?; + let commission = amount * COMMISSION_RATE; + let total = amount + commission; + + Ok(total.round_dp(2).to_string()) +} +``` + +### Edge case handling + +**Production-ready edge case handling** requires comprehensive validation and error recovery: + +```rust +pub trait SafeDecimalOps { + fn safe_add(&self, other: Self) -> Result; + fn safe_multiply(&self, other: Self) -> Result; + fn safe_divide(&self, other: Self) -> Result; +} + +impl SafeDecimalOps for Decimal { + fn safe_add(&self, other: Self) -> Result { + self.checked_add(other).ok_or(FinancialError::Overflow) + } + + fn safe_divide(&self, other: Self) -> Result { + if other.is_zero() { + return Err(FinancialError::DivisionByZero); + } + self.checked_div(other).ok_or(FinancialError::Overflow) + } +} +``` + +### Performance considerations + +For **high-frequency financial calculations**, rust_decimal offers significant advantages over floating-point arithmetic despite being 2-6x slower. **Key performance characteristics** include 10-20ns for addition/subtraction, 50-100ns for multiplication, and 100-200ns for division. The library uses **stack allocation** (16 bytes per Decimal) and provides **zero-cost abstractions** through compile-time macros. + +**Memory optimization strategies** include using the `Copy` trait for efficient stack-based operations, implementing batch processing patterns, and pre-allocating constants: + +```rust +// Efficient batch processing +fn calculate_portfolio_value(positions: &[Position]) -> Decimal { + positions.iter() + .map(|pos| pos.quantity * pos.average_price) + .sum() // Decimal implements Sum trait +} +``` + +## Financial-specific features + +### Scale and precision handling + +rust_decimal's **128-bit architecture** provides optimal balance between precision and performance for financial applications. The **96-bit mantissa** supports approximately 28-29 significant digits, while the **32-bit metadata** handles scale (0-28) and sign information. + +**Database integration patterns** vary by storage backend: +- **PostgreSQL**: Use `NUMERIC(19,4)` for standard applications, `NUMERIC(28,8)` for high precision +- **MySQL**: Use `DECIMAL(13,4)` for general use, `DECIMAL(19,4)` for large amounts +- **GAAP compliance**: Often requires 4-6 decimal places with specific rounding rules + +### Monetary arithmetic best practices + +**Core principles for financial calculations** include always using rust_decimal instead of floating-point, maintaining consistent scale throughout calculations, using explicit rounding strategies, and implementing overflow protection: + +```rust +// Safe financial calculation with tax +let subtotal = dec!(199.99); +let tax_rate = dec!(0.0875); // 8.75% +let tax = (subtotal * tax_rate).round_dp(2); +let total = subtotal.checked_add(tax).expect("Calculation overflow"); + +// Interest calculation with proper rounding +let principal = dec!(10000.00); +let rate = dec!(0.045); // 4.5% annual +let compound_interest = principal * (dec!(1) + rate / dec!(12)).powi(12); +let final_amount = compound_interest.round_dp(2); +``` + +### Integration with accounting systems + +**Database integration** requires appropriate feature flags and schema design: + +```rust +// PostgreSQL integration +[dependencies] +rust_decimal = { version = "1.37", features = ["db-postgres"] } + +// Usage with tokio-postgres +let amount: Decimal = dec!(1234.56); +client.execute( + "INSERT INTO transactions (amount) VALUES ($1)", + &[&amount] +)?; +``` + +**JSON serialization** maintains precision through string-based representation: + +```rust +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +struct Invoice { + #[serde(with = "rust_decimal::serde::str")] + total: Decimal, + #[serde(with = "rust_decimal::serde::arbitrary_precision")] + tax: Decimal, +} +``` + +## Integration patterns + +### Codebase integration strategies + +**Domain-driven design patterns** provide robust abstractions for financial applications: + +```rust +// Value object pattern for type safety +#[derive(Debug, Clone, PartialEq)] +pub struct Balance(Decimal); + +impl Balance { + pub fn new(amount: Decimal) -> Result { + if amount < Decimal::ZERO { + return Err(Error::NegativeBalance); + } + Ok(Balance(amount)) + } + + pub fn add(&self, other: &Balance) -> Result { + let new_amount = self.0.checked_add(other.0) + .ok_or(Error::Overflow)?; + Ok(Balance(new_amount)) + } +} +``` + +**Aggregate root patterns** encapsulate business logic and maintain consistency: + +```rust +pub struct Account { + id: AccountId, + balance: Balance, + transactions: Vec, +} + +impl Account { + pub fn debit(&mut self, amount: Decimal) -> Result<(), DomainError> { + if self.balance.value() < amount { + return Err(DomainError::InsufficientFunds); + } + + self.balance = Balance::new(self.balance.value() - amount)?; + self.transactions.push(Transaction::debit(amount)); + Ok(()) + } +} +``` + +### Testing approaches + +**Property-based testing** with proptest ensures mathematical correctness: + +```rust +use proptest::prelude::*; + +proptest! { + #[test] + fn test_addition_commutative(a in any::(), b in any::()) { + prop_assert_eq!(a + b, b + a); + } + + #[test] + fn test_compound_interest_monotonic( + principal in 0.01f64..1000000.0, + rate in 0.001f64..0.5, + periods in 1u32..100 + ) { + let p = Decimal::from_f64(principal).unwrap(); + let r = Decimal::from_f64(rate).unwrap(); + let result = p * (Decimal::ONE + r).powi(periods as i64); + + // Property: compound interest should always be >= principal + prop_assert!(result >= p); + } +} +``` + +### Error handling strategies + +**Comprehensive error handling** uses the `thiserror` crate for production systems: + +```rust +#[derive(Debug, thiserror::Error)] +pub enum FinancialError { + #[error("Insufficient funds: available {available}, required {required}")] + InsufficientFunds { available: Decimal, required: Decimal }, + + #[error("Currency mismatch: expected {expected}, got {actual}")] + CurrencyMismatch { expected: String, actual: String }, + + #[error("Precision overflow in calculation")] + PrecisionOverflow, + + #[error("Division by zero")] + DivisionByZero, +} +``` + +## Conclusion + +rust_decimal provides a mature, production-ready foundation for financial applications requiring exact precision. Its 128-bit fixed-precision architecture, comprehensive rounding strategies, and extensive ecosystem integration make it ideal for everything from simple e-commerce transactions to complex multi-currency trading platforms. The library's emphasis on correctness over raw performance, combined with Rust's memory safety guarantees, creates a robust platform for mission-critical financial systems where precision errors can have significant monetary consequences. + +The key to successful implementation lies in proper domain modeling, comprehensive error handling, appropriate precision management, and thorough testing including property-based testing for financial invariants. With these practices, rust_decimal serves as a reliable foundation for financial software systems handling high-throughput transaction processing while maintaining the exact precision required for regulatory compliance and accounting accuracy. diff --git a/steel_decimal/src/functions.rs b/steel_decimal/src/functions.rs index 04f400b..0f7da65 100644 --- a/steel_decimal/src/functions.rs +++ b/steel_decimal/src/functions.rs @@ -3,6 +3,29 @@ use rust_decimal::prelude::*; use rust_decimal::MathematicalOps; use std::str::FromStr; +/// Global precision setting for the current Steel execution context +thread_local! { + static PRECISION_CONTEXT: std::cell::RefCell> = std::cell::RefCell::new(None); +} + +/// Set execution precision for all decimal operations in current thread +pub fn set_execution_precision(precision: Option) { + PRECISION_CONTEXT.with(|p| *p.borrow_mut() = precision); +} + +/// Get current execution precision +pub fn get_execution_precision() -> Option { + PRECISION_CONTEXT.with(|p| *p.borrow()) +} + +/// Format decimal according to current execution context +fn format_result(decimal: Decimal) -> String { + match get_execution_precision() { + Some(precision) => decimal.round_dp(precision).to_string(), + None => decimal.to_string(), // Full precision (default behavior) + } +} + /// Helper function to parse decimals with strict accounting precision /// Supports both standard decimal notation AND scientific notation fn parse_decimal(s: &str) -> Result { @@ -10,42 +33,37 @@ fn parse_decimal(s: &str) -> Result { if let Ok(decimal) = Decimal::from_str(s) { return Ok(decimal); } - + // Check for scientific notation if s.contains('e') || s.contains('E') { return parse_scientific_notation(s); } - + Err(format!("Invalid decimal '{}': unknown format", s)) } /// Parse scientific notation (e.g., "1e2", "1.5e-3") using decimal arithmetic fn parse_scientific_notation(s: &str) -> Result { - // Split on 'e' or 'E' (case insensitive) let lower_s = s.to_lowercase(); let parts: Vec<&str> = lower_s.split('e').collect(); if parts.len() != 2 { return Err(format!("Invalid scientific notation '{}': malformed", s)); } - - // Parse mantissa and exponent + let mantissa = Decimal::from_str(parts[0]) .map_err(|_| format!("Invalid mantissa in '{}': {}", s, parts[0]))?; let exponent: i32 = parts[1].parse() .map_err(|_| format!("Invalid exponent in '{}': {}", s, parts[1]))?; - - // Handle exponent using decimal arithmetic to maintain precision + let result = if exponent == 0 { mantissa } else if exponent > 0 { - // Multiply by 10^exponent let ten = Decimal::from(10); let power_of_ten = ten.checked_powi(exponent as i64) .ok_or_else(|| format!("Exponent too large in '{}': {}", s, exponent))?; mantissa.checked_mul(power_of_ten) .ok_or_else(|| format!("Scientific notation result overflow in '{}'", s))? } else { - // Divide by 10^|exponent| for negative exponents let ten = Decimal::from(10); let positive_exp = (-exponent) as i64; let divisor = ten.checked_powi(positive_exp) @@ -53,27 +71,27 @@ fn parse_scientific_notation(s: &str) -> Result { mantissa.checked_div(divisor) .ok_or_else(|| format!("Scientific notation result underflow in '{}'", s))? }; - + Ok(result) } -// Basic arithmetic operations +// Basic arithmetic operations (now precision-aware) pub fn decimal_add(a: String, b: String) -> Result { let a_dec = parse_decimal(&a)?; let b_dec = parse_decimal(&b)?; - Ok((a_dec + b_dec).to_string()) + Ok(format_result(a_dec + b_dec)) } pub fn decimal_sub(a: String, b: String) -> Result { let a_dec = parse_decimal(&a)?; let b_dec = parse_decimal(&b)?; - Ok((a_dec - b_dec).to_string()) + Ok(format_result(a_dec - b_dec)) } pub fn decimal_mul(a: String, b: String) -> Result { let a_dec = parse_decimal(&a)?; let b_dec = parse_decimal(&b)?; - Ok((a_dec * b_dec).to_string()) + Ok(format_result(a_dec * b_dec)) } pub fn decimal_div(a: String, b: String) -> Result { @@ -84,16 +102,87 @@ pub fn decimal_div(a: String, b: String) -> Result { return Err("Division by zero".to_string()); } - Ok((a_dec / b_dec).to_string()) + Ok(format_result(a_dec / b_dec)) } -// Advanced mathematical functions +// Precision-specific operations (explicit precision override) +pub fn decimal_add_p(a: String, b: String, precision: u32) -> Result { + let a_dec = parse_decimal(&a)?; + let b_dec = parse_decimal(&b)?; + let result = a_dec + b_dec; + + Ok(result.round_dp(precision).to_string()) +} + +pub fn decimal_sub_p(a: String, b: String, precision: u32) -> Result { + let a_dec = parse_decimal(&a)?; + let b_dec = parse_decimal(&b)?; + let result = a_dec - b_dec; + + Ok(result.round_dp(precision).to_string()) +} + +pub fn decimal_mul_p(a: String, b: String, precision: u32) -> Result { + let a_dec = parse_decimal(&a)?; + let b_dec = parse_decimal(&b)?; + let result = a_dec * b_dec; + + Ok(result.round_dp(precision).to_string()) +} + +pub fn decimal_div_p(a: String, b: String, precision: u32) -> Result { + let a_dec = parse_decimal(&a)?; + let b_dec = parse_decimal(&b)?; + + if b_dec.is_zero() { + return Err("Division by zero".to_string()); + } + + let result = a_dec / b_dec; + + Ok(result.round_dp(precision).to_string()) +} + +// Precision control functions +pub fn set_precision(precision: u32) -> String { + if precision > 28 { + "Error: Maximum precision is 28 decimal places".to_string() + } else { + set_execution_precision(Some(precision as u32)); + format!("Precision set to {} decimal places", precision) + } +} + +pub fn get_precision() -> String { + match get_execution_precision() { + Some(p) => p.to_string(), + None => "full".to_string(), + } +} + +pub fn clear_precision() -> String { + set_execution_precision(None); + "Precision cleared - using full precision".to_string() +} + +// Format functions with explicit precision +pub fn decimal_format(value: String, precision: u32) -> Result { + let decimal = parse_decimal(&value)?; + + if precision > 28 { + Err("Maximum precision is 28 decimal places".to_string()) + } else { + Ok(decimal.round_dp(precision as u32).to_string()) + } +} + +// Advanced mathematical functions (updated to use format_result) pub fn decimal_pow(base: String, exp: String) -> Result { let base_dec = parse_decimal(&base)?; let exp_dec = parse_decimal(&exp)?; base_dec.checked_powd(exp_dec) - .map(|result| result.to_string()) + .map(|result| format_result(result)) .ok_or_else(|| "Power operation failed or overflowed".to_string()) } @@ -101,7 +190,7 @@ pub fn decimal_sqrt(a: String) -> Result { let a_dec = parse_decimal(&a)?; a_dec.sqrt() - .map(|result| result.to_string()) + .map(|result| format_result(result)) .ok_or_else(|| "Square root failed (negative number?)".to_string()) } @@ -109,7 +198,7 @@ pub fn decimal_ln(a: String) -> Result { let a_dec = parse_decimal(&a)?; a_dec.checked_ln() - .map(|result| result.to_string()) + .map(|result| format_result(result)) .ok_or_else(|| "Natural log failed (non-positive number?)".to_string()) } @@ -117,7 +206,7 @@ pub fn decimal_log10(a: String) -> Result { let a_dec = parse_decimal(&a)?; a_dec.checked_log10() - .map(|result| result.to_string()) + .map(|result| format_result(result)) .ok_or_else(|| "Log10 failed (non-positive number?)".to_string()) } @@ -125,16 +214,16 @@ pub fn decimal_exp(a: String) -> Result { let a_dec = parse_decimal(&a)?; a_dec.checked_exp() - .map(|result| result.to_string()) + .map(|result| format_result(result)) .ok_or_else(|| "Exponential failed or overflowed".to_string()) } -// Trigonometric functions +// Trigonometric functions (updated to use format_result) pub fn decimal_sin(a: String) -> Result { let a_dec = parse_decimal(&a)?; a_dec.checked_sin() - .map(|result| result.to_string()) + .map(|result| format_result(result)) .ok_or_else(|| "Sine calculation failed or overflowed".to_string()) } @@ -142,7 +231,7 @@ pub fn decimal_cos(a: String) -> Result { let a_dec = parse_decimal(&a)?; a_dec.checked_cos() - .map(|result| result.to_string()) + .map(|result| format_result(result)) .ok_or_else(|| "Cosine calculation failed or overflowed".to_string()) } @@ -150,11 +239,11 @@ pub fn decimal_tan(a: String) -> Result { let a_dec = parse_decimal(&a)?; a_dec.checked_tan() - .map(|result| result.to_string()) + .map(|result| format_result(result)) .ok_or_else(|| "Tangent calculation failed or overflowed".to_string()) } -// Comparison functions +// Comparison functions (unchanged) pub fn decimal_gt(a: String, b: String) -> Result { let a_dec = parse_decimal(&a)?; let b_dec = parse_decimal(&b)?; @@ -185,30 +274,34 @@ pub fn decimal_eq(a: String, b: String) -> Result { Ok(a_dec == b_dec) } -// Utility functions +// Utility functions (updated to use format_result) pub fn decimal_abs(a: String) -> Result { let a_dec = parse_decimal(&a)?; - Ok(a_dec.abs().to_string()) + Ok(format_result(a_dec.abs())) } pub fn decimal_round(a: String, places: i32) -> Result { let a_dec = parse_decimal(&a)?; - Ok(a_dec.round_dp(places as u32).to_string()) + if places < 0 { + Ok(a_dec.to_string()) + } else { + Ok(a_dec.round_dp(places as u32).to_string()) + } } pub fn decimal_min(a: String, b: String) -> Result { let a_dec = parse_decimal(&a)?; let b_dec = parse_decimal(&b)?; - Ok(a_dec.min(b_dec).to_string()) + Ok(format_result(a_dec.min(b_dec))) } pub fn decimal_max(a: String, b: String) -> Result { let a_dec = parse_decimal(&a)?; let b_dec = parse_decimal(&b)?; - Ok(a_dec.max(b_dec).to_string()) + Ok(format_result(a_dec.max(b_dec))) } -// Constants +// Constants (unchanged) pub fn decimal_zero() -> String { "0".to_string() } @@ -225,13 +318,13 @@ pub fn decimal_e() -> String { "2.7182818284590452353602874714".to_string() } -// Financial functions +// Financial functions (updated to use format_result) pub fn decimal_percentage(amount: String, percentage: String) -> Result { let amount_dec = parse_decimal(&amount)?; let percentage_dec = parse_decimal(&percentage)?; let hundred = Decimal::from(100); - Ok((amount_dec * percentage_dec / hundred).to_string()) + Ok(format_result(amount_dec * percentage_dec / hundred)) } pub fn decimal_compound(principal: String, rate: String, time: String) -> Result { @@ -243,12 +336,12 @@ pub fn decimal_compound(principal: String, rate: String, time: String) -> Result let compound_factor = (one + rate_dec).checked_powd(time_dec) .ok_or("Compound calculation overflow")?; - Ok((principal_dec * compound_factor).to_string()) + Ok(format_result(principal_dec * compound_factor)) } -// Type conversion helper +// Type conversion helper (updated to use format_result) pub fn to_decimal(s: String) -> Result { parse_decimal(&s) - .map(|d| d.to_string()) + .map(|d| format_result(d)) .map_err(|e| format!("Invalid decimal: {}", e)) } diff --git a/steel_decimal/src/registry.rs b/steel_decimal/src/registry.rs index 7484e0b..6254a04 100644 --- a/steel_decimal/src/registry.rs +++ b/steel_decimal/src/registry.rs @@ -10,6 +10,8 @@ impl FunctionRegistry { /// Register all decimal math functions with the Steel VM pub fn register_all(vm: &mut Engine) { Self::register_basic_arithmetic(vm); + Self::register_precision_arithmetic(vm); + Self::register_precision_control(vm); Self::register_advanced_math(vm); Self::register_trigonometric(vm); Self::register_comparison(vm); @@ -27,6 +29,22 @@ impl FunctionRegistry { vm.register_fn("decimal-div", decimal_div); } + /// Register precision-specific arithmetic functions + pub fn register_precision_arithmetic(vm: &mut Engine) { + vm.register_fn("decimal-add-p", decimal_add_p); + vm.register_fn("decimal-sub-p", decimal_sub_p); + vm.register_fn("decimal-mul-p", decimal_mul_p); + vm.register_fn("decimal-div-p", decimal_div_p); + } + + /// Register precision control functions + pub fn register_precision_control(vm: &mut Engine) { + vm.register_fn("set-precision", set_precision); + vm.register_fn("get-precision", get_precision); + vm.register_fn("clear-precision", clear_precision); + vm.register_fn("decimal-format", decimal_format); + } + /// Register advanced mathematical functions pub fn register_advanced_math(vm: &mut Engine) { vm.register_fn("decimal-pow", decimal_pow); diff --git a/steel_decimal/tests/function_tests.rs b/steel_decimal/tests/function_tests.rs index 7945179..28903de 100644 --- a/steel_decimal/tests/function_tests.rs +++ b/steel_decimal/tests/function_tests.rs @@ -243,3 +243,55 @@ fn test_scientific_notation(#[case] a: &str, #[case] b: &str, #[case] expected: let result = decimal_add(a.to_string(), b.to_string()).unwrap(); assert_eq!(result, expected); } + +// Test precision behavior +#[rstest] +#[case("5", "0", "5")] // Integer + integer = integer +#[case("5.0", "0", "5.0")] // Decimal + integer = decimal +#[case("5.00", "0.00", "5.00")] // Preserves highest precision +#[case("5.1", "0.23", "5.33")] // Normal decimal arithmetic +fn test_precision_preservation(#[case] a: &str, #[case] b: &str, #[case] expected: &str) { + let result = decimal_add(a.to_string(), b.to_string()).unwrap(); + assert_eq!(result, expected); +} + +// Test explicit precision functions +#[rstest] +#[case("5.123", "2.456", 0, "8")] // 0 decimal places +#[case("5.123", "2.456", 2, "7.58")] // 2 decimal places +#[case("5.123", "2.456", 4, "7.5790")] // 4 decimal places +fn test_explicit_precision(#[case] a: &str, #[case] b: &str, #[case] precision: u32, #[case] expected: &str) { + let result = decimal_add_p(a.to_string(), b.to_string(), precision).unwrap(); + assert_eq!(result, expected); +} + +// Test scientific notation edge cases +#[rstest] +#[case("1e0", "1")] // Simple case +#[case("1.0e0", "1.0")] // Preserves decimal +#[case("1e-2", "0.01")] // Negative exponent +#[case("1.5e-3", "0.0015")] // Decimal + negative exponent +#[case("2.5e2", "250.0")] // Decimal + positive exponent +fn test_scientific_edge_cases(#[case] input: &str, #[case] expected: &str) { + let result = to_decimal(input.to_string()).unwrap(); + assert_eq!(result, expected); +} + +// Test precision functions +#[test] +fn test_precision_functions() { + // Test setting precision + assert_eq!(set_precision(2), "Precision set to 2 decimal places"); + assert_eq!(get_precision(), "2"); + + // Test with precision set + let result = decimal_add("1.567".to_string(), "2.891".to_string()).unwrap(); + assert_eq!(result, "4.46"); + + // Test clearing precision + assert_eq!(get_precision(), "full"); + + // Test with full precision + let result = decimal_add("1.567".to_string(), "2.891".to_string()).unwrap(); + assert_eq!(result, "4.458"); +}