// tests/security_tests.rs
use rstest::*;
use steel_decimal::*;
use steel::steel_vm::engine::Engine;
use std::collections::HashMap;
// Test stack overflow protection with deeply nested expressions
#[rstest]
fn test_stack_overflow_protection() {
    let parser = ScriptParser::new();

    // Build `(+ (+ (+ 1 0) 1) ... 9999)`: DEPTH levels of left-nested additions.
    // The string is assembled in one pass (all openers, the seed `1`, then each
    // ` i)` closer). Re-`format!`ing the whole accumulated string on every
    // iteration would copy it DEPTH times, which is O(n²).
    const DEPTH: usize = 10_000;
    let mut expr = "(+ ".repeat(DEPTH);
    expr.push('1');
    for i in 0..DEPTH {
        expr.push(' ');
        expr.push_str(&i.to_string());
        expr.push(')');
    }

    // A panic inside the parser is acceptable and is caught here. What this test guards
    // against is the process dying. A genuine stack overflow aborts the process instead
    // of unwinding, so `catch_unwind` cannot intercept it. If the parser recurses too
    // deeply on this input, the test binary still crashes, which is exactly the failure
    // this test exists to expose.
    let outcome = std::panic::catch_unwind(|| parser.transform(&expr));

    // Either a transformed string or a caught panic counts as a pass.
    let _ = outcome;
}
// Test memory exhaustion protection
#[rstest]
fn test_memory_exhaustion_protection() {
    // A single one-megabyte identifier; the parser must cope with it without panicking.
    let huge_name = "x".repeat(1_000_000);
    let script = format!("(+ ${} 1)", huge_name);

    let parser = ScriptParser::new();
    let outcome = std::panic::catch_unwind(|| parser.transform(&script));

    // A caught panic would mean the oversized input was not handled gracefully.
    assert!(outcome.is_ok());
}
// Test injection attacks through variable names
#[rstest]
#[case("'; DROP TABLE users; --")] // SQL injection style
#[case("$(rm -rf /)")] // Shell injection style
#[case("<script>alert('xss')</script>")] // XSS style
#[case("../../etc/passwd")] // Path traversal style
#[case("${system('rm -rf /')}")] // Template injection style
#[case("{{7*7}}")] // Template injection
#[case("__proto__")] // Prototype pollution
#[case("constructor")] // Constructor pollution
#[case("\\x00\\x01\\x02")] // Escaped (literal backslash) null/control-char sequences
fn test_variable_name_injection(#[case] malicious_var: &str) {
    // Guard the fixture itself. An empty payload yields `(+ $ 1)`, which names no
    // variable, so the dependency check below would fail for a reason unrelated to
    // injection handling.
    assert!(!malicious_var.is_empty(), "Injection payload must not be empty");

    let parser = ScriptParser::new();
    // Attempt injection through variable name
    let expr = format!("(+ ${} 1)", malicious_var);
    let transformed = parser.transform(&expr);
    // Should transform without executing malicious code
    assert!(transformed.contains("get-var"));

    // Extract what the parser actually captured as the variable name
    let deps = parser.extract_dependencies(&expr);
    assert!(!deps.is_empty(), "Should extract at least one dependency");
    let captured_var = deps
        .iter()
        .next()
        .expect("dependency set was just checked to be non-empty");
    // The captured variable name should be in the transformed output
    assert!(transformed.contains(captured_var));

    // Security check: for inputs with characters that delimit tokens (spaces, parens),
    // verify that the parser truncated the variable name instead of swallowing the
    // whole payload.
    let has_delimiters =
        malicious_var.contains(' ') || malicious_var.contains('(') || malicious_var.contains(')');
    if has_delimiters {
        // Variable should be truncated, not the full malicious string
        assert_ne!(captured_var, malicious_var,
            "Parser should truncate variable names with dangerous characters");
        assert!(!transformed.contains(malicious_var),
            "Full malicious string should not appear in transformed output");
    } else {
        // If no dangerous characters, full variable name should be preserved
        assert_eq!(captured_var, malicious_var);
    }
}
// Test malicious Steel expressions
#[rstest]
#[case("(eval '(system \"rm -rf /\"))")] // Code execution attempt
#[case("(load \"../../etc/passwd\")")] // File access attempt
#[case("(require 'os) (os/execute \"malicious-command\")")] // Module injection
#[case("(define loop (lambda () (loop))) (loop)")] // Infinite recursion
#[case("(define mem-bomb (lambda () (cons 1 (mem-bomb)))) (mem-bomb)")] // Memory bomb
fn test_malicious_steel_expressions(#[case] malicious_expr: &str) {
    let decimal = SteelDecimal::new();

    // Transformation only rewrites text; nothing is evaluated here, so it must finish
    // and produce output without running the payload.
    let transformed = decimal.transform(malicious_expr);
    assert!(!transformed.is_empty());

    // Code-execution and file-loading forms must not be rewritten into decimal arithmetic.
    let is_exec_or_load = ["eval", "load"]
        .iter()
        .any(|&keyword| malicious_expr.contains(keyword));
    if is_exec_or_load {
        assert!(!transformed.contains("decimal-"));
    }
}
// Test parser regex exploitation
#[rstest]
#[case("((((((((((a")] // Unbalanced parentheses
fn test_parser_regex_exploitation_simple(#[case] malicious_input: &str) {
    let parser = ScriptParser::new();

    // Time a single transform so catastrophic backtracking (ReDoS) would show up.
    let timer = std::time::Instant::now();
    let outcome = std::panic::catch_unwind(|| parser.transform(malicious_input));
    let elapsed = timer.elapsed();

    assert!(elapsed.as_secs() < 5, "Parser took too long: {:?}", elapsed);
    // Malformed input must be handled without panicking.
    assert!(outcome.is_ok());
}
#[rstest]
fn test_parser_regex_exploitation_large_inputs() {
    let parser = ScriptParser::new();

    // Feeds one input to the parser and asserts that it finishes in under 5 s without
    // panicking. `label` prefixes the timeout message.
    let check = |input: &str, label: &str| {
        let timer = std::time::Instant::now();
        let outcome = std::panic::catch_unwind(|| parser.transform(input));
        let elapsed = timer.elapsed();
        assert!(elapsed.as_secs() < 5, "{} parsing took too long: {:?}", label, elapsed);
        assert!(outcome.is_ok());
    };

    // Extremely long variable reference
    let long_variable = format!("${}", "a".repeat(100000));
    check(&long_variable, "Large variable");

    // Long run of repeated operators
    let operator_run = format!("({}{})", "+".repeat(100000), " 1 2)");
    check(&operator_run, "Repeated operators");

    // Huge string literal
    let big_literal = format!("\"{}\"", "a".repeat(1000000));
    check(&big_literal, "Huge string");
}
// Test Steel VM security integration
#[rstest]
fn test_steel_vm_security_integration() {
    let mut engine = Engine::new();
    let decimal = SteelDecimal::new();
    decimal.register_functions(&mut engine);

    // Attempts to reach arbitrary code, files, or modules through the
    // decimal-enabled engine.
    let escape_attempts = [
        r#"(eval "(system \"echo pwned\")")"#,
        r#"(load "../../etc/passwd")"#,
        r#"(define dangerous (lambda () (system "rm -rf /")))"#,
        r#"(require 'steel/core)"#, // Try to access core modules
    ];

    for attempt in escape_attempts.iter() {
        // A compile or run error is the expected outcome. Success is tolerated because
        // any side effects cannot be observed from here.
        // NOTE(review): the result is never inspected, so this test only verifies that
        // the engine does not panic on these inputs.
        let _outcome = engine.compile_and_run_raw_program(attempt.to_string());
    }
}
// Test variable access security
#[rstest]
fn test_variable_access_security() {
    // One harmless variable alongside two that look sensitive.
    let variables: HashMap<String, String> = [
        ("safe_var", "42"),
        ("password", "secret123"),
        ("api_key", "key_abc123"),
    ]
    .iter()
    .map(|&(name, value)| (name.to_string(), value.to_string()))
    .collect();

    let mut engine = Engine::new();
    FunctionRegistry::register_variables(&mut engine, variables);

    // Probes that try to enumerate or bulk-read registered variables.
    let probes = [
        r#"(map get-var (list "password" "api_key" "secret"))"#,
        r#"(get-var "")"#, // Empty variable name
        r#"(get-var nil)"#, // Nil variable name
    ];

    for probe in probes.iter() {
        // Either outcome is accepted. The result is not inspected, so this test only
        // checks that probing does not panic.
        // NOTE(review): it does not assert that secret values stay unrevealed.
        let _outcome = engine.compile_and_run_raw_program(probe.to_string());
    }
}
// Test format string attacks through decimal formatting
#[rstest]
#[case("%s%s%s%s")] // Format string attack
#[case("%n")] // Write to memory attempt
#[case("%x%x%x%x")] // Memory reading attempt
#[case("\\x41\\x41\\x41\\x41")] // Buffer overflow attempt
fn test_format_string_attacks(#[case] format_attack: &str) {
    let attack = format_attack.to_string();

    // Push the attacker-controlled text through parsing and arithmetic.
    let _ = to_decimal(attack.clone());
    let _ = decimal_add(attack, "1".to_string());

    // The format call uses fixed arguments only; user input never reaches it.
    let _ = decimal_format("123.456".to_string(), 2);

    // Reaching this point without a crash is the pass condition.
}
// Test buffer overflow attempts
#[rstest]
fn test_buffer_overflow_attempts() {
    // Oversized inputs: 100k non-numeric characters, and a decimal with 10k digits on
    // each side of the point.
    let long_input = "A".repeat(100_000);
    let mut long_number = "1".repeat(10_000);
    long_number.push('.');
    long_number.push_str(&"2".repeat(10_000));

    let _ = to_decimal(long_input);
    let _ = to_decimal(long_number.clone());
    let _ = decimal_add(long_number.clone(), "1".to_string());
    let _ = decimal_sqrt(long_number);

    // Completing every call without crashing is the pass condition.
}
// Test denial of service through resource exhaustion
#[rstest]
fn test_resource_exhaustion_protection() {
    // CPU exhaustion: 10k nested additions around a single literal.
    let transformer = SteelDecimal::new();
    let cpu_bomb = format!("{}1{}", "(+ ".repeat(10000), ")".repeat(10000));

    let started = std::time::Instant::now();
    let _ = transformer.transform(&cpu_bomb);
    assert!(started.elapsed().as_secs() < 10, "CPU exhaustion detected");

    // Memory exhaustion: register 100k variables, then reference the first and last.
    let mut crowded = SteelDecimal::new();
    for n in 0..100_000 {
        crowded.add_variable(format!("var_{}", n), "1".to_string());
    }
    let _ = crowded.transform("(+ $var_0 $var_99999)");
}
// Test integer overflow/underflow in precision settings
#[rstest]
#[case(u32::MAX)]
#[case(u32::MAX - 1)]
fn test_integer_overflow_in_precision(#[case] overflow_value: u32) {
    // An absurdly large precision must be refused with an explanatory message.
    let response = set_precision(overflow_value);
    assert!(response.contains("Error") || response.contains("Maximum"));

    // The rejected value must not have been stored.
    assert_ne!(get_precision(), overflow_value.to_string());
}
// Test race conditions in precision settings (security through thread safety)
#[rstest]
fn test_precision_race_conditions() {
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Barrier};
    use std::thread;

    const THREADS: usize = 10;
    const ROUNDS: usize = 1000;

    let start_line = Arc::new(Barrier::new(THREADS));
    let successes = Arc::new(AtomicU32::new(0));

    let workers: Vec<_> = (0..THREADS)
        .map(|thread_id| {
            let start_line = Arc::clone(&start_line);
            let successes = Arc::clone(&successes);
            thread::spawn(move || {
                // Release all workers together to maximise contention.
                start_line.wait();
                for round in 0..ROUNDS {
                    // Change the precision, then compute with it immediately.
                    set_precision(((thread_id + round) % 5) as u32);
                    let sum = decimal_add("1.123456789".to_string(), "2.987654321".to_string());
                    if sum.is_ok() {
                        successes.fetch_add(1, Ordering::Relaxed);
                    }
                }
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }

    // Nearly all operations should succeed; races would surface as failures.
    let total = successes.load(Ordering::Relaxed);
    assert!(total > (THREADS * 900) as u32, "Too many race condition failures: {}", total);
}
// Test SQL injection style attacks through numeric inputs
#[rstest]
#[case("1; DROP TABLE decimals; --")]
#[case("1' OR '1'='1")]
#[case("1 UNION SELECT * FROM passwords")]
#[case("1; exec('rm -rf /')")]
fn test_sql_injection_style_attacks(#[case] injection_attempt: &str) {
    let payload = injection_attempt.to_string();

    // A numeric prefix followed by SQL-ish text must not parse as a decimal...
    let parsed = to_decimal(payload.clone());
    assert!(parsed.is_err(), "SQL injection attempt should fail: {}", injection_attempt);

    // ...and must be rejected by arithmetic as well.
    let summed = decimal_add(payload, "1".to_string());
    assert!(summed.is_err(), "Arithmetic with injection should fail");
}
// Test path traversal through variable names
#[rstest]
#[case("../../../etc/passwd")]
#[case("..\\..\\..\\windows\\system32\\config\\sam")]
#[case("/etc/passwd")]
#[case("C:\\Windows\\System32\\config\\SAM")]
#[case("file:///etc/passwd")]
#[case("data:text/plain;base64,cm9vdDp4OjA6MA==")]
fn test_path_traversal_attacks(#[case] path_attack: &str) {
    // Register the path-like string purely as a variable name.
    let mut decimal = SteelDecimal::new();
    decimal.add_variable(path_attack.to_string(), "42".to_string());

    let transformed = decimal.transform(&format!("(+ ${} 1)", path_attack));

    // It must come through as an ordinary variable lookup that carries the name
    // verbatim, not as any kind of file access.
    assert!(transformed.contains("get-var"));
    assert!(transformed.contains(path_attack));
}
// Test XML/HTML injection through variable values
#[rstest]
#[case("<![CDATA[content]]>")] // CDATA section
#[case("<script>alert('xss')</script>")] // Script tag
#[case("<!DOCTYPE foo [<!ENTITY xxe SYSTEM \"file:///etc/passwd\">]>")] // External entity (XXE)
fn test_xml_html_injection(#[case] xml_attack: &str) {
    // Fixture sanity checks. A payload with no markup tests nothing here, and an empty
    // one makes the round-trip assertion below vacuous.
    assert!(!xml_attack.is_empty());
    assert!(xml_attack.contains('<'), "XML/HTML payload should contain markup: {:?}", xml_attack);

    let mut steel_decimal = SteelDecimal::new();
    // Should treat as string value, not parse as XML/HTML
    steel_decimal.add_variable("test_var".to_string(), xml_attack.to_string());

    // The stored value must round-trip byte-for-byte; it must not be interpreted as markup.
    let vars = steel_decimal.get_variables();
    assert_eq!(vars.get("test_var").unwrap(), xml_attack);
}
// Test deserialization attacks
#[rstest]
fn test_deserialization_attacks() {
    // Strings that resemble serialized objects (base64 and hex forms of a Java
    // serialization stream, and a Python pickle fragment).
    let payloads = [
        "rO0ABXNyABFqYXZhLnV0aWwuSGFzaE1hcAUH2sHDFmDRAwACRgAKbG9hZEZhY3RvckkACXRocmVzaG9sZHhwP0AAAAAAAAx3CAAAABAAAAABdAABYXQAAWJ4",
        "AC ED 00 05 73 72",
        "pickle\\x80\\x03]q\\x00.",
    ];

    for payload in payloads.iter().copied() {
        // Must be rejected as a decimal rather than decoded.
        let parsed = to_decimal(payload.to_string());
        assert!(parsed.is_err(), "Serialized payload should not be valid decimal");

        // Stored as a variable, it must come back as the same plain string.
        let mut store = SteelDecimal::new();
        store.add_variable("payload".to_string(), payload.to_string());
        assert_eq!(store.get_variables().get("payload").unwrap(), payload);
    }
}
// Test timing attacks
#[rstest]
fn test_timing_attack_resistance() {
    // Checks that comparison operations don't leak information through timing.
    //
    // Each input is timed over many calls rather than one. A single call is so short
    // that a one-shot measurement is dominated by clock resolution, scheduler noise and
    // any first-call initialisation cost. That made the ratio below flaky, and a
    // zero-length measurement could even produce an inf/NaN ratio.
    const ITERATIONS: u32 = 200;
    let values = ["1", "1.0", "1.00", "1.000"];

    // Warm-up pass so one-time setup is not charged to whichever input is timed first.
    for value in values.iter() {
        let _ = decimal_eq(value.to_string(), "1".to_string());
    }

    let times: Vec<u128> = values
        .iter()
        .map(|value| {
            let start = std::time::Instant::now();
            for _ in 0..ITERATIONS {
                let _ = decimal_eq(value.to_string(), "1".to_string());
            }
            // Clamp to 1 ns so a coarse clock can never yield a zero divisor.
            start.elapsed().as_nanos().max(1)
        })
        .collect();

    // Times should be relatively similar (not vulnerable to timing attacks)
    let max_time = *times.iter().max().unwrap();
    let min_time = *times.iter().min().unwrap();
    let ratio = max_time as f64 / min_time as f64;
    // Allow for reasonable variance but not massive differences
    assert!(ratio < 10.0, "Timing attack vulnerability detected: ratio = {}", ratio);
}