//! Concurrency tests for the steel_decimal crate: thread-local precision
//! isolation, concurrent arithmetic, VM registration, and parser safety.
// tests/concurrency_tests.rs
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Barrier, Mutex};
use std::thread;
use std::time::Duration;

use steel_decimal::*;
// Test precision isolation between threads.
//
// Each thread sets a distinct thread-local precision, all threads then run
// the same calculation simultaneously, and afterwards every thread must
// still observe exactly the precision it set for itself.
#[test]
fn test_precision_thread_isolation() {
    let num_threads = 10;
    let barrier = Arc::new(Barrier::new(num_threads));

    let handles: Vec<_> = (0..num_threads)
        .map(|thread_id| {
            let barrier = barrier.clone();

            thread::spawn(move || {
                // Each thread sets a different precision (0-4). Named
                // `requested` to avoid shadowing the crate's
                // `set_precision` function in the checks below.
                let requested = thread_id as u32 % 5;
                set_precision(requested);

                // Wait until every thread has set its precision before any
                // thread computes, maximizing the chance of catching leaks.
                barrier.wait();

                let result =
                    decimal_add("1.123456789".to_string(), "2.987654321".to_string()).unwrap();

                // Read back the precision as observed by this thread.
                let observed = get_precision();

                // Return results through the join handle instead of a
                // shared Mutex<Vec<_>>: no lock contention, no poisoning.
                (thread_id, requested, result, observed)
            })
        })
        .collect();

    let results: Vec<_> = handles
        .into_iter()
        .map(|handle| handle.join().unwrap())
        .collect();

    // Verify each thread maintained its own precision.
    for (thread_id, requested, result, observed) in &results {
        assert_eq!(
            observed,
            &requested.to_string(),
            "Thread {} precision not isolated",
            thread_id
        );

        // Verify the result respects the requested precision.
        if *requested > 0 {
            let decimal_places = result.split('.').nth(1).map_or(0, |s| s.len());
            assert!(
                decimal_places <= *requested as usize,
                "Thread {} result {} has more than {} decimal places",
                thread_id,
                result,
                requested
            );
        }
    }
}
// Test concurrent arithmetic operations: many threads hammer add/mul/sub
// simultaneously and none of the calls may fail.
#[test]
fn test_concurrent_arithmetic_operations() {
    let num_threads = 20;
    let operations_per_thread = 100;
    let barrier = Arc::new(Barrier::new(num_threads));
    let errors = Arc::new(Mutex::new(Vec::new()));

    let mut handles = Vec::with_capacity(num_threads);
    for thread_id in 0..num_threads {
        let barrier = Arc::clone(&barrier);
        let errors = Arc::clone(&errors);

        handles.push(thread::spawn(move || {
            barrier.wait();

            for i in 0..operations_per_thread {
                let lhs = format!("{}.{}", thread_id, i);
                let rhs = format!("{}.{}", i, thread_id);

                // Run all three operations eagerly so they cannot
                // interfere with each other, then check the outcomes.
                let add_failed = decimal_add(lhs.clone(), rhs.clone()).is_err();
                let mul_failed = decimal_mul(lhs.clone(), rhs.clone()).is_err();
                let sub_failed = decimal_sub(lhs.clone(), rhs.clone()).is_err();

                if add_failed || mul_failed || sub_failed {
                    errors.lock().unwrap().push(format!(
                        "Thread {}, iteration {}: arithmetic error",
                        thread_id, i
                    ));
                }
            }
        }));
    }

    for handle in handles {
        handle.join().unwrap();
    }

    let errors = errors.lock().unwrap();
    assert!(errors.is_empty(), "Concurrent arithmetic errors: {:?}", *errors);
}
// Test Steel VM registration under concurrent load: each thread builds its
// own Engine, registers the decimal functions, and runs a script.
#[test]
fn test_concurrent_vm_registration() {
    use steel::steel_vm::engine::Engine;

    let num_threads = 5;
    let barrier = Arc::new(Barrier::new(num_threads));
    let errors = Arc::new(Mutex::new(Vec::new()));

    let mut handles = Vec::with_capacity(num_threads);
    for thread_id in 0..num_threads {
        let barrier = Arc::clone(&barrier);
        let errors = Arc::clone(&errors);

        handles.push(thread::spawn(move || {
            barrier.wait();

            // Every thread gets an independent VM with the functions
            // registered into it.
            let mut vm = Engine::new();
            FunctionRegistry::register_all(&mut vm);

            // Exercise the freshly registered functions.
            let script = r#"(decimal-add "1.5" "2.3")"#;
            match vm.compile_and_run_raw_program(script.to_string()) {
                Ok(vals) if vals.len() != 1 => {
                    errors.lock().unwrap().push(format!(
                        "Thread {}: Wrong number of results", thread_id
                    ));
                }
                Ok(_) => {}
                Err(e) => {
                    errors.lock().unwrap().push(format!(
                        "Thread {}: VM execution error: {}", thread_id, e
                    ));
                }
            }
        }));
    }

    for handle in handles {
        handle.join().unwrap();
    }

    let errors = errors.lock().unwrap();
    assert!(errors.is_empty(), "Concurrent VM errors: {:?}", *errors);
}
// Test variable access concurrency: each thread registers a private
// variable plus a commonly named one, then reads them back through its VM.
#[test]
fn test_concurrent_variable_access() {
    use steel::steel_vm::engine::Engine;

    let num_threads = 8;
    let barrier = Arc::new(Barrier::new(num_threads));
    let errors = Arc::new(Mutex::new(Vec::new()));

    let mut handles = Vec::with_capacity(num_threads);
    for thread_id in 0..num_threads {
        let barrier = Arc::clone(&barrier);
        let errors = Arc::clone(&errors);

        handles.push(thread::spawn(move || {
            // Each thread owns its variable set: one per-thread name plus a
            // name that every thread registers independently.
            let mut variables = HashMap::new();
            variables.insert(format!("var_{}", thread_id), format!("{}.0", thread_id * 10));
            variables.insert("shared".to_string(), "42.0".to_string());

            let mut vm = Engine::new();
            FunctionRegistry::register_variables(&mut vm, variables);

            barrier.wait();

            // Read the private variable, then the shared one.
            let lookups = [
                format!(r#"(get-var "var_{}")"#, thread_id),
                r#"(get-var "shared")"#.to_string(),
            ];
            for script in lookups {
                if let Err(e) = vm.compile_and_run_raw_program(script) {
                    errors.lock().unwrap().push(format!(
                        "Thread {}: Variable access error: {}", thread_id, e
                    ));
                }
            }

            // The existence check reports with its own error message.
            let has_script = format!(r#"(has-var? "var_{}")"#, thread_id);
            if let Err(e) = vm.compile_and_run_raw_program(has_script) {
                errors.lock().unwrap().push(format!(
                    "Thread {}: Variable check error: {}", thread_id, e
                ));
            }
        }));
    }

    for handle in handles {
        handle.join().unwrap();
    }

    let errors = errors.lock().unwrap();
    assert!(errors.is_empty(), "Concurrent variable access errors: {:?}", *errors);
}
// Test precision state under rapid changes.
//
// Threads repeatedly flip their thread-local precision and immediately
// verify both the reported precision and that calculation output honors
// it; every mismatch counts as an inconsistency, and none are allowed.
#[test]
fn test_rapid_precision_changes() {
    let num_threads = 4;
    let changes_per_thread = 1000;
    let barrier = Arc::new(Barrier::new(num_threads));
    // A plain atomic counter: a Mutex around an integer only adds lock
    // contention to exactly the code path this test is stressing.
    let inconsistencies = Arc::new(AtomicUsize::new(0));

    let handles: Vec<_> = (0..num_threads)
        .map(|_thread_id| {
            let barrier = barrier.clone();
            let inconsistencies = inconsistencies.clone();

            thread::spawn(move || {
                barrier.wait();

                for i in 0..changes_per_thread {
                    let precision = (i % 5) as u32; // Cycle through 0-4

                    set_precision(precision);

                    // The change must be visible immediately in this thread.
                    if get_precision() != precision.to_string() {
                        inconsistencies.fetch_add(1, Ordering::SeqCst);
                    }

                    // Perform a calculation and verify its decimal places.
                    let result =
                        decimal_add("1.123456".to_string(), "2.654321".to_string()).unwrap();

                    if precision > 0 {
                        let decimal_places =
                            result.split('.').nth(1).map_or(0, |s| s.len());
                        if decimal_places > precision as usize {
                            inconsistencies.fetch_add(1, Ordering::SeqCst);
                        }
                    }
                }
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }

    let inconsistencies = inconsistencies.load(Ordering::SeqCst);
    assert_eq!(inconsistencies, 0, "Found {} precision inconsistencies", inconsistencies);
}
// Test parser thread safety: each thread owns a parser and repeatedly
// transforms a fixed set of scripts, validating the output structurally.
#[test]
fn test_parser_thread_safety() {
    let num_threads = 10;
    let transformations_per_thread = 100;
    let barrier = Arc::new(Barrier::new(num_threads));
    let errors = Arc::new(Mutex::new(Vec::new()));

    // &'static strs are Copy, so every thread can take the whole table.
    const TEST_SCRIPTS: [&str; 5] = [
        "(+ 1.5 2.3)",
        "(* $x $y)",
        "(sqrt (+ (* $a $a) (* $b $b)))",
        "(/ (- $max $min) 2)",
        "(abs (- $value $target))",
    ];

    let mut handles = Vec::with_capacity(num_threads);
    for thread_id in 0..num_threads {
        let barrier = Arc::clone(&barrier);
        let errors = Arc::clone(&errors);

        handles.push(thread::spawn(move || {
            let parser = ScriptParser::new();
            barrier.wait();

            for i in 0..transformations_per_thread {
                let script = TEST_SCRIPTS[i % TEST_SCRIPTS.len()];

                let transformed = parser.transform(script);
                let _dependencies = parser.extract_dependencies(script);

                // Structural validation: parentheses must stay balanced.
                let mut open_count = 0usize;
                let mut close_count = 0usize;
                for c in transformed.chars() {
                    match c {
                        '(' => open_count += 1,
                        ')' => close_count += 1,
                        _ => {}
                    }
                }

                if open_count != close_count {
                    errors.lock().unwrap().push(format!(
                        "Thread {}, iteration {}: Unbalanced parentheses in {}",
                        thread_id, i, transformed
                    ));
                }

                // A script using '+' must have been rewritten to a
                // decimal-* call.
                if !transformed.contains("decimal-") && script.contains('+') {
                    errors.lock().unwrap().push(format!(
                        "Thread {}, iteration {}: Transformation failed for {}",
                        thread_id, i, script
                    ));
                }
            }
        }));
    }

    for handle in handles {
        handle.join().unwrap();
    }

    let errors = errors.lock().unwrap();
    assert!(errors.is_empty(), "Parser thread safety errors: {:?}", *errors);
}
// Test memory safety under concurrent load: threads churn through many
// short-lived SteelDecimal instances; finishing without a panic passes.
#[test]
fn test_memory_safety_concurrent_load() {
    let num_threads = 8;
    let iterations = 500;
    let barrier = Arc::new(Barrier::new(num_threads));

    let mut handles = Vec::with_capacity(num_threads);
    for thread_id in 0..num_threads {
        let barrier = Arc::clone(&barrier);

        handles.push(thread::spawn(move || {
            barrier.wait();

            for i in 0..iterations {
                // Fresh instance per iteration to stress alloc/drop.
                let mut steel_decimal = SteelDecimal::new();

                // Add a variable, transform a script, extract dependencies.
                steel_decimal
                    .add_variable(format!("var_{}", i), format!("{}.{}", thread_id, i));

                let script = format!("(+ {} {})", i, thread_id);
                let _ = steel_decimal.transform(&script);
                let _ = steel_decimal.extract_dependencies(&script);

                // Occasional tiny sleep widens the interleaving window and
                // raises the chance of exposing a race.
                if i % 100 == 0 {
                    thread::sleep(Duration::from_micros(1));
                }
            }
        }));
    }

    for handle in handles {
        handle.join().unwrap();
    }

    // Reaching this point without panicking is the assertion.
}
// Test precision cleanup after thread termination: a dead thread's
// precision must not leak into the main thread or into new threads.
#[test]
fn test_precision_cleanup_after_thread_death() {
    // A short-lived thread sets a precision, computes, and exits.
    let first = thread::spawn(|| {
        set_precision(3);
        decimal_add("1.123456".to_string(), "2.654321".to_string()).unwrap()
    });
    let result = first.join().unwrap();

    // Its result must honor the precision it set.
    let decimal_places = result.split('.').nth(1).map(|s| s.len()).unwrap_or(0);
    assert!(decimal_places <= 3);

    // The main thread never set a precision, so it should still report
    // the "full" default.
    let main_precision = get_precision();
    assert_eq!(main_precision, "full");

    // A brand-new thread must also start from the default.
    let second = thread::spawn(|| {
        let precision = get_precision();
        let sum = decimal_add("1.123456".to_string(), "2.654321".to_string()).unwrap();
        (precision, sum)
    });
    let (new_precision, new_result) = second.join().unwrap();
    assert_eq!(new_precision, "full");

    // Full precision keeps more decimal places than the 3 used earlier.
    let new_decimal_places = new_result.split('.').nth(1).map(|s| s.len()).unwrap_or(0);
    assert!(new_decimal_places > 3);
}
// Stress test with mixed operations.
//
// Threads interleave precision changes with a rotating mix of decimal
// operations; the test passes as long as the overall error rate stays
// below 10% of all operations.
#[test]
fn test_concurrent_stress_mixed_operations() {
    let num_threads = 6;
    let operations_per_thread = 200;
    let barrier = Arc::new(Barrier::new(num_threads));
    // Atomic counter instead of Mutex<usize>: it is only incremented and
    // read once at the end, so a lock adds nothing but contention.
    let total_errors = Arc::new(AtomicUsize::new(0));

    let handles: Vec<_> = (0..num_threads)
        .map(|thread_id| {
            let barrier = barrier.clone();
            let total_errors = total_errors.clone();

            thread::spawn(move || {
                let mut errors = 0;
                barrier.wait();

                for i in 0..operations_per_thread {
                    // Periodically change this thread's precision to stress
                    // precision state alongside the arithmetic.
                    if i % 50 == 0 {
                        set_precision((thread_id as u32) % 5);
                    }

                    // Rotate through the available operations.
                    let failed = match i % 6 {
                        0 => decimal_add(format!("{}.{}", thread_id, i), "1.0".to_string())
                            .is_err(),
                        1 => decimal_mul(format!("{}", i), format!("{}.5", thread_id)).is_err(),
                        // i % 6 == 2 implies i >= 2, so the argument i + 1
                        // is always >= 3 and never hits sqrt's zero edge
                        // case; the original `&& i > 0` guard here was
                        // provably always true and has been dropped.
                        2 => decimal_sqrt(format!("{}", i + 1)).is_err(),
                        3 => decimal_abs(format!("-{}.{}", thread_id, i)).is_err(),
                        4 => decimal_gt(format!("{}", i), format!("{}", thread_id)).is_err(),
                        5 => to_decimal(format!("{}.{}e1", thread_id, i)).is_err(),
                        _ => unreachable!(),
                    };
                    if failed {
                        errors += 1;
                    }
                }

                // Fold this thread's tally into the shared counter once.
                total_errors.fetch_add(errors, Ordering::SeqCst);
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }

    let total_errors = total_errors.load(Ordering::SeqCst);

    // Tolerate a small number of failures, but anything above 10% of all
    // operations indicates real concurrency trouble.
    assert!(total_errors < num_threads * operations_per_thread / 10,
            "Too many errors in stress test: {}", total_errors);
}