// tests/concurrency_tests.rs use steel_decimal::*; use std::sync::{Arc, Barrier, Mutex}; use std::thread; use std::time::Duration; use std::collections::HashMap; // Test precision isolation between threads #[test] fn test_precision_thread_isolation() { let num_threads = 10; let barrier = Arc::new(Barrier::new(num_threads)); let results = Arc::new(Mutex::new(Vec::new())); let handles: Vec<_> = (0..num_threads) .map(|thread_id| { let barrier = barrier.clone(); let results = results.clone(); thread::spawn(move || { // Each thread sets different precision let precision = thread_id as u32 % 5; // 0-4 set_precision(precision); // Wait for all threads to set their precision barrier.wait(); // Perform calculation let result = decimal_add("1.123456789".to_string(), "2.987654321".to_string()).unwrap(); // Verify precision is maintained in this thread let current_precision = get_precision(); results.lock().unwrap().push((thread_id, precision, result, current_precision)); }) }) .collect(); for handle in handles { handle.join().unwrap(); } let results = results.lock().unwrap(); // Verify each thread maintained its own precision for (thread_id, set_precision, result, current_precision) in results.iter() { assert_eq!(current_precision, &set_precision.to_string(), "Thread {} precision not isolated", thread_id); // Verify result respects the precision if *set_precision > 0 { let decimal_places = result.split('.').nth(1).map(|s| s.len()).unwrap_or(0); assert!(decimal_places <= *set_precision as usize, "Thread {} result {} has more than {} decimal places", thread_id, result, set_precision); } } } // Test concurrent arithmetic operations #[test] fn test_concurrent_arithmetic_operations() { let num_threads = 20; let operations_per_thread = 100; let barrier = Arc::new(Barrier::new(num_threads)); let errors = Arc::new(Mutex::new(Vec::new())); let handles: Vec<_> = (0..num_threads) .map(|thread_id| { let barrier = barrier.clone(); let errors = errors.clone(); thread::spawn(move || { barrier.wait(); for i in 0..operations_per_thread { let a = format!("{}.{}", thread_id, i); let b = format!("{}.{}", i, thread_id); // Test various operations don't interfere let add_result = decimal_add(a.clone(), b.clone()); let mul_result = decimal_mul(a.clone(), b.clone()); let sub_result = decimal_sub(a.clone(), b.clone()); if add_result.is_err() || mul_result.is_err() || sub_result.is_err() { errors.lock().unwrap().push(format!( "Thread {}, iteration {}: arithmetic error", thread_id, i )); } } }) }) .collect(); for handle in handles { handle.join().unwrap(); } let errors = errors.lock().unwrap(); assert!(errors.is_empty(), "Concurrent arithmetic errors: {:?}", *errors); } // Test Steel VM registration under concurrent load #[test] fn test_concurrent_vm_registration() { use steel::steel_vm::engine::Engine; let num_threads = 5; let barrier = Arc::new(Barrier::new(num_threads)); let errors = Arc::new(Mutex::new(Vec::new())); let handles: Vec<_> = (0..num_threads) .map(|thread_id| { let barrier = barrier.clone(); let errors = errors.clone(); thread::spawn(move || { barrier.wait(); // Each thread creates its own VM and registers functions let mut vm = Engine::new(); FunctionRegistry::register_all(&mut vm); // Test execution let script = r#"(decimal-add "1.5" "2.3")"#; let result = vm.compile_and_run_raw_program(script.to_string()); match result { Ok(vals) => { if vals.len() != 1 { errors.lock().unwrap().push(format!( "Thread {}: Wrong number of results", thread_id )); } } Err(e) => { errors.lock().unwrap().push(format!( "Thread {}: VM execution error: {}", thread_id, e )); } } }) }) .collect(); for handle in handles { handle.join().unwrap(); } let errors = errors.lock().unwrap(); assert!(errors.is_empty(), "Concurrent VM errors: {:?}", *errors); } // Test variable access concurrency #[test] fn test_concurrent_variable_access() { use steel::steel_vm::engine::Engine; let num_threads = 8; let barrier = Arc::new(Barrier::new(num_threads)); let errors = Arc::new(Mutex::new(Vec::new())); let handles: Vec<_> = (0..num_threads) .map(|thread_id| { let barrier = barrier.clone(); let errors = errors.clone(); thread::spawn(move || { // Each thread has its own variable set let mut variables = HashMap::new(); variables.insert(format!("var_{}", thread_id), format!("{}.0", thread_id * 10)); variables.insert("shared".to_string(), "42.0".to_string()); let mut vm = Engine::new(); FunctionRegistry::register_variables(&mut vm, variables); barrier.wait(); // Test variable access let get_script = format!(r#"(get-var "var_{}")"#, thread_id); let has_script = format!(r#"(has-var? "var_{}")"#, thread_id); let shared_script = r#"(get-var "shared")"#.to_string(); for script in [get_script, shared_script] { match vm.compile_and_run_raw_program(script) { Ok(_) => {} Err(e) => { errors.lock().unwrap().push(format!( "Thread {}: Variable access error: {}", thread_id, e )); } } } match vm.compile_and_run_raw_program(has_script) { Ok(_) => {} Err(e) => { errors.lock().unwrap().push(format!( "Thread {}: Variable check error: {}", thread_id, e )); } } }) }) .collect(); for handle in handles { handle.join().unwrap(); } let errors = errors.lock().unwrap(); assert!(errors.is_empty(), "Concurrent variable access errors: {:?}", *errors); } // Test precision state under rapid changes #[test] fn test_rapid_precision_changes() { let num_threads = 4; let changes_per_thread = 1000; let barrier = Arc::new(Barrier::new(num_threads)); let inconsistencies = Arc::new(Mutex::new(0)); let handles: Vec<_> = (0..num_threads) .map(|_thread_id| { let barrier = barrier.clone(); let inconsistencies = inconsistencies.clone(); thread::spawn(move || { barrier.wait(); for i in 0..changes_per_thread { let precision = (i % 5) as u32; // Cycle through 0-4 set_precision(precision); // Immediately check precision let current = get_precision(); if current != precision.to_string() { *inconsistencies.lock().unwrap() += 1; } // Perform calculation and verify let result = decimal_add("1.123456".to_string(), "2.654321".to_string()).unwrap(); if precision > 0 { let decimal_places = result.split('.').nth(1).map(|s| s.len()).unwrap_or(0); if decimal_places > precision as usize { *inconsistencies.lock().unwrap() += 1; } } } }) }) .collect(); for handle in handles { handle.join().unwrap(); } let inconsistencies = *inconsistencies.lock().unwrap(); assert_eq!(inconsistencies, 0, "Found {} precision inconsistencies", inconsistencies); } // Test parser thread safety #[test] fn test_parser_thread_safety() { let num_threads = 10; let transformations_per_thread = 100; let barrier = Arc::new(Barrier::new(num_threads)); let errors = Arc::new(Mutex::new(Vec::new())); let test_scripts = vec![ "(+ 1.5 2.3)", "(* $x $y)", "(sqrt (+ (* $a $a) (* $b $b)))", "(/ (- $max $min) 2)", "(abs (- $value $target))", ]; let handles: Vec<_> = (0..num_threads) .map(|thread_id| { let barrier = barrier.clone(); let errors = errors.clone(); let scripts = test_scripts.clone(); thread::spawn(move || { let parser = ScriptParser::new(); barrier.wait(); for i in 0..transformations_per_thread { let script = &scripts[i % scripts.len()]; let transformed = parser.transform(script); let _dependencies = parser.extract_dependencies(script); // Basic validation let open_count = transformed.chars().filter(|c| *c == '(').count(); let close_count = transformed.chars().filter(|c| *c == ')').count(); if open_count != close_count { errors.lock().unwrap().push(format!( "Thread {}, iteration {}: Unbalanced parentheses in {}", thread_id, i, transformed )); } if !transformed.contains("decimal-") && script.contains('+') { errors.lock().unwrap().push(format!( "Thread {}, iteration {}: Transformation failed for {}", thread_id, i, script )); } } }) }) .collect(); for handle in handles { handle.join().unwrap(); } let errors = errors.lock().unwrap(); assert!(errors.is_empty(), "Parser thread safety errors: {:?}", *errors); } // Test memory safety under concurrent load #[test] fn test_memory_safety_concurrent_load() { let num_threads = 8; let iterations = 500; let barrier = Arc::new(Barrier::new(num_threads)); let handles: Vec<_> = (0..num_threads) .map(|thread_id| { let barrier = barrier.clone(); thread::spawn(move || { barrier.wait(); // Create many SteelDecimal instances for i in 0..iterations { let mut steel_decimal = SteelDecimal::new(); // Add variables steel_decimal.add_variable(format!("var_{}", i), format!("{}.{}", thread_id, i)); // Transform scripts let script = format!("(+ {} {})", i, thread_id); let _ = steel_decimal.transform(&script); // Extract dependencies let _ = steel_decimal.extract_dependencies(&script); // Small delay to increase chance of race conditions if i % 100 == 0 { thread::sleep(Duration::from_micros(1)); } } }) }) .collect(); for handle in handles { handle.join().unwrap(); } // If we get here without panicking, memory safety is maintained } // Test precision cleanup after thread termination #[test] fn test_precision_cleanup_after_thread_death() { // Create thread that sets precision and dies let handle = thread::spawn(|| { set_precision(3); decimal_add("1.123456".to_string(), "2.654321".to_string()).unwrap() }); let result = handle.join().unwrap(); // Verify the result had the precision applied let decimal_places = result.split('.').nth(1).map(|s| s.len()).unwrap_or(0); assert!(decimal_places <= 3); // In main thread, precision should be unaffected let main_precision = get_precision(); // Should be "full" (default) since we haven't set it in main thread assert_eq!(main_precision, "full"); // Create another thread - should start fresh let handle2 = thread::spawn(|| { let precision = get_precision(); (precision, decimal_add("1.123456".to_string(), "2.654321".to_string()).unwrap()) }); let (new_precision, new_result) = handle2.join().unwrap(); assert_eq!(new_precision, "full"); // This result should use full precision let new_decimal_places = new_result.split('.').nth(1).map(|s| s.len()).unwrap_or(0); assert!(new_decimal_places > 3); // Should be more than the previous thread's precision } // Stress test with mixed operations #[test] fn test_concurrent_stress_mixed_operations() { let num_threads = 6; let operations_per_thread = 200; let barrier = Arc::new(Barrier::new(num_threads)); let total_errors = Arc::new(Mutex::new(0)); let handles: Vec<_> = (0..num_threads) .map(|thread_id| { let barrier = barrier.clone(); let total_errors = total_errors.clone(); thread::spawn(move || { let mut errors = 0; barrier.wait(); for i in 0..operations_per_thread { // Mix of precision settings if i % 50 == 0 { set_precision((thread_id as u32) % 5); } // Mix of operations match i % 6 { 0 => { if decimal_add(format!("{}.{}", thread_id, i), "1.0".to_string()).is_err() { errors += 1; } } 1 => { if decimal_mul(format!("{}", i), format!("{}.5", thread_id)).is_err() { errors += 1; } } 2 => { if decimal_sqrt(format!("{}", i + 1)).is_err() && i > 0 { errors += 1; } } 3 => { if decimal_abs(format!("-{}.{}", thread_id, i)).is_err() { errors += 1; } } 4 => { if decimal_gt(format!("{}", i), format!("{}", thread_id)).is_err() { errors += 1; } } 5 => { if to_decimal(format!("{}.{}e1", thread_id, i)).is_err() { errors += 1; } } _ => unreachable!() } } *total_errors.lock().unwrap() += errors; }) }) .collect(); for handle in handles { handle.join().unwrap(); } let total_errors = *total_errors.lock().unwrap(); // Allow some errors for edge cases (like sqrt of 0), but not too many assert!(total_errors < num_threads * operations_per_thread / 10, "Too many errors in stress test: {}", total_errors); }