3 Commits

Author SHA1 Message Date
Priec
b86b3334d6 hod2 2026-03-14 08:19:00 +01:00
Priec
8fc8addcac learner used instead of manual version 2026-03-13 21:52:37 +01:00
Priec
f6b9d79062 cvicenie 3 hotove 2026-03-12 22:06:29 +01:00
11 changed files with 8930 additions and 143 deletions

2
.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
*/target/
*/mnist.npz

2420
hod_1/Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -4,9 +4,11 @@ version = "0.1.0"
edition = "2024" edition = "2024"
[dependencies] [dependencies]
burn = { version = "0.20.1", default-features = false, features = ["ndarray"] } burn = { version = "0.20.1", default-features = false, features = ["ndarray", "std", "train"] }
burn-autodiff = "0.20.1"
burn-ndarray = "0.20.1" burn-ndarray = "0.20.1"
clap = { version = "4.5.60", features = ["derive"] } clap = { version = "4.5.60", features = ["derive"] }
ndarray = "0.17.2" ndarray = "0.17.2"
npyz = { version = "0.8.4", features = ["npz"] } npyz = { version = "0.8.4", features = ["npz"] }
serde = { version = "1.0.228", features = ["derive"] }
zip = { version = "8.2.0", features = ["deflate"] } zip = { version = "8.2.0", features = ["deflate"] }

View File

@@ -1,57 +1,264 @@
use burn::tensor::Tensor; // src/lib.rs
use burn::tensor::linalg::diag;
use burn::tensor::Shape;
pub type B = burn_ndarray::NdArray<f64>; use burn::config::Config;
use burn::data::dataloader::DataLoaderBuilder;
use burn::data::dataloader::batcher::Batcher;
use burn::data::dataset::Dataset;
use burn::module::Module;
use burn::nn::loss::CrossEntropyLossConfig;
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamConfig;
use burn::record::CompactRecorder;
use burn::tensor::activation;
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::{Int, Tensor};
use burn::train::metric::{AccuracyMetric, LossMetric};
use burn::lr_scheduler::constant::ConstantLr;
use burn::train::{
ClassificationOutput, InferenceStep, Learner, SupervisedTraining,
TrainOutput, TrainStep, TrainingStrategy,
};
use std::str::FromStr;
fn l2_norm(v: Tensor<B, 1>) -> f64 { pub type B = burn_autodiff::Autodiff<burn_ndarray::NdArray<f64>>;
v.clone()
.mul(v) // element-wise: v_i * v_i // Model
.sum() // suma všetkých v_i^2 #[derive(Module, Debug)]
.sqrt() // odmocnina pub struct MnistClassifier<B: Backend> {
.into_scalar() // na f32 hidden: Vec<Linear<B>>,
output: Linear<B>,
activation: Activation,
} }
/// Input: [N, 784], Output: [N, 784] impl<B: Backend> MnistClassifier<B> {
pub fn center(x: Tensor<B, 2>) -> Tensor<B, 2> { pub fn new(
let mean = x.clone().mean_dim(0); device: &B::Device,
x.sub(mean) hidden_layers: usize,
} hidden_layer_size: usize,
activation: Activation,
) -> Self {
let mut hidden = Vec::new();
let mut in_size = 784;
/// Input: [N, 784], Output: [784, 784] for _ in 0..hidden_layers {
pub fn covariance(x: Tensor<B, 2>) -> Tensor<B, 2> { hidden.push(LinearConfig::new(in_size, hidden_layer_size).init(device));
let cen = center(x); in_size = hidden_layer_size;
let transpose = cen.clone().transpose(); }
let n = cen.dims()[0] as f64;
let mul = transpose.matmul(cen);
mul.div_scalar(n - 1.0)
} let output = LinearConfig::new(in_size, 10).init(device);
Self { hidden, output, activation }
pub fn total_variance(x: Tensor<B, 2>) -> f64 { }
let cov: Tensor<B, 2> = covariance(x);
let diag: Tensor<B, 1> = diag(cov); pub fn forward(&self, images: Tensor<B, 2>) -> Tensor<B, 2> {
let sum = diag.sum().into_scalar(); let mut x = images;
sum for layer in &self.hidden {
} x = layer.forward(x);
x = self.activation.forward(x);
/// Input: [784, 784], scalar, Output: [784] }
pub fn power_iteration(cov: Tensor<B, 2>, iterations: usize) -> (Tensor<B, 1>, f64) { self.output.forward(x)
let n = cov.dims()[0];
let device = cov.device();
let mut v: Tensor<B, 1> = Tensor::ones(Shape::new([n]), &device);
let mut s: f64 = 0.0;
for _ in 0..iterations {
let v_new_2d = cov.clone().matmul(v.reshape([n, 1]));
let v_new = v_new_2d.squeeze::<1>();
s = l2_norm(v_new.clone());
v = v_new.div_scalar(s);
} }
return (v, s);
} }
/// Input: [784, 784], [784], Output: f32 impl<B: AutodiffBackend> TrainStep for MnistClassifier<B> {
pub fn explained_variance(total_var: f64, s: f64) -> f64 { type Input = MnistBatch<B>;
s / total_var type Output = ClassificationOutput<B>;
fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
let output = self.forward(batch.images);
let loss = CrossEntropyLossConfig::new()
.init(&output.device())
.forward(output.clone(), batch.targets.clone());
TrainOutput::new(
self,
loss.backward(),
ClassificationOutput { loss, output, targets: batch.targets },
)
}
}
impl<B: Backend> InferenceStep for MnistClassifier<B> {
type Input = MnistBatch<B>;
type Output = ClassificationOutput<B>;
fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
let output = self.forward(batch.images);
let loss = CrossEntropyLossConfig::new()
.init(&output.device())
.forward(output.clone(), batch.targets.clone());
ClassificationOutput { loss, output, targets: batch.targets }
}
}
// Activation
#[derive(Debug, Clone, Copy, Module, Default, serde::Serialize, serde::Deserialize)]
pub enum Activation {
#[default]
None,
ReLU,
Tanh,
Sigmoid,
}
impl FromStr for Activation {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"none" => Ok(Activation::None),
"relu" => Ok(Activation::ReLU),
"tanh" => Ok(Activation::Tanh),
"sigmoid" => Ok(Activation::Sigmoid),
_ => Err(format!("Unknown activation: {}", s)),
}
}
}
impl Activation {
pub fn forward<B: Backend, const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
match self {
Activation::None => x,
Activation::ReLU => activation::relu(x),
Activation::Tanh => activation::tanh(x),
Activation::Sigmoid => activation::sigmoid(x),
}
}
}
// Dataset & Batch
#[derive(Clone, Debug)]
pub struct MnistItem {
pub image: [f64; 784],
pub label: u8,
}
pub struct MnistDataset {
items: Vec<MnistItem>,
}
impl MnistDataset {
pub fn new(items: Vec<MnistItem>) -> Self {
Self { items }
}
}
impl Dataset<MnistItem> for MnistDataset {
fn get(&self, index: usize) -> Option<MnistItem> {
self.items.get(index).cloned()
}
fn len(&self) -> usize {
self.items.len()
}
}
#[derive(Clone, Debug)]
pub struct MnistBatch<B: Backend> {
pub images: Tensor<B, 2>,
pub targets: Tensor<B, 1, Int>,
}
#[derive(Clone)]
pub struct MnistBatcher;
impl MnistBatcher {
pub fn new() -> Self {
Self
}
}
impl<B: Backend<FloatElem = f64, IntElem = i64>> Batcher<B, MnistItem, MnistBatch<B>>
for MnistBatcher
{
fn batch(&self, items: Vec<MnistItem>, device: &B::Device) -> MnistBatch<B> {
let n = items.len();
let image_data: Vec<f64> = items.iter().flat_map(|i| i.image).collect();
let label_data: Vec<i64> = items.iter().map(|i| i.label as i64).collect();
let images = Tensor::<B, 2>::from_data(
burn::tensor::TensorData::new(image_data, [n, 784]),
device, // ← use the passed-in device, not self.device
);
let targets = Tensor::<B, 1, Int>::from_data(
burn::tensor::TensorData::new(label_data, [n]),
device,
);
MnistBatch { images, targets }
}
}
// Config
#[derive(Config, Debug)]
pub struct MnistModelConfig {
pub hidden_layers: usize,
pub hidden_layer_size: usize,
pub activation: Activation,
}
impl MnistModelConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> MnistClassifier<B> {
MnistClassifier::new(device, self.hidden_layers, self.hidden_layer_size, self.activation)
}
}
#[derive(Config, Debug)]
pub struct MnistTrainingConfig {
pub model: MnistModelConfig,
pub optimizer: AdamConfig,
#[config(default = 10)]
pub num_epochs: usize,
#[config(default = 64)]
pub batch_size: usize,
#[config(default = 4)]
pub num_workers: usize,
#[config(default = 42)]
pub seed: u64,
#[config(default = 1.0e-4)]
pub learning_rate: f64,
}
// Training
impl MnistTrainingConfig {
pub fn train<B>(
&self,
device: B::Device,
train_dataset: MnistDataset,
valid_dataset: MnistDataset,
) where
B: AutodiffBackend<FloatElem = f64, IntElem = i64>,
B::InnerBackend: Backend<FloatElem = f64, IntElem = i64>,
{
B::seed(&device, self.seed);
let model = self.model.init::<B>(&device);
let optim = self.optimizer.init();
let batcher_train = MnistBatcher::new();
let batcher_valid = MnistBatcher::new();
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(self.batch_size)
.shuffle(self.seed)
.num_workers(self.num_workers)
.build(train_dataset);
let dataloader_valid = DataLoaderBuilder::new(batcher_valid)
.batch_size(self.batch_size)
.num_workers(self.num_workers)
.build(valid_dataset);
let training = SupervisedTraining::new("/tmp/artifacts", dataloader_train, dataloader_valid)
.metrics((AccuracyMetric::new(), LossMetric::new()))
.with_file_checkpointer(CompactRecorder::new())
.num_epochs(self.num_epochs)
.summary()
.with_training_strategy(TrainingStrategy::SingleDevice(device));
let _result = training.launch(Learner::new(
model,
optim,
ConstantLr::new(self.learning_rate), // plain float → constant LR scheduler
));
}
} }

View File

@@ -1,86 +1,115 @@
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use clap::Parser; use clap::Parser;
use hod_1::{covariance, explained_variance, power_iteration, total_variance, B}; use hod_1::{Activation, MnistDataset, MnistItem, MnistModelConfig, MnistTrainingConfig, B};
use burn::optim::AdamConfig;
use std::fs::File; use std::fs::File;
use std::io::Read; use std::io::Read;
use std::str::FromStr;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)] #[command(author, version, about)]
struct Args { struct Args {
#[arg(long = "examples", default_value = "1024")] #[arg(long, default_value = "none")]
examples: usize, activation: String,
#[arg(long = "iterations", default_value = "64")]
iterations: usize, #[arg(long, default_value = "64")]
batch_size: usize,
#[arg(long, default_value = "10")]
epochs: usize,
#[arg(long, default_value = "100")]
hidden_layer_size: usize,
#[arg(long, default_value = "1")]
hidden_layers: usize,
#[arg(long, default_value = "42")]
seed: u64,
/// Fraction of training data used for validation (e.g. 0.1 = 10 %)
#[arg(long, default_value = "0.1")]
valid_split: f64,
} }
fn load_mnist(examples: usize, device: &<B as Backend>::Device) -> Tensor<B, 2> { fn load_mnist_items(examples: usize) -> Vec<MnistItem> {
let file = File::open("mnist.npz").expect("Cannot open mnist.npz"); let file = File::open("mnist.npz").expect("Cannot open mnist.npz");
let mut archive = zip::ZipArchive::new(file).expect("Cannot read zip"); let mut archive = zip::ZipArchive::new(file).expect("Cannot read zip");
// Print all available array names so you can see what's inside // images
eprintln!("Arrays in mnist.npz:"); let image_candidates = ["train_images.npy", "train.images.npy", "x_train.npy", "images.npy"];
for i in 0..archive.len() { let mut image_bytes = Vec::new();
eprintln!(" {}", archive.by_index(i).unwrap().name()); for name in &image_candidates {
} if let Ok(mut entry) = archive.by_name(name) {
entry.read_to_end(&mut image_bytes).expect("read images");
// Try the most common key names used for MNIST train images
let candidates = [
"train_images.npy",
"train.images.npy",
"x_train.npy",
"images.npy",
];
let mut bytes = Vec::new();
let mut found_name = "";
for name in &candidates {
if archive.by_name(name).is_ok() {
archive
.by_name(name)
.unwrap()
.read_to_end(&mut bytes)
.expect("Failed to read entry");
found_name = name;
break; break;
} }
} }
assert!(!bytes.is_empty(), "Could not find train images — check the printed names above and update candidates[]"); assert!(!image_bytes.is_empty(), "Could not find train images in mnist.npz");
eprintln!("Loaded from: {found_name}");
// Parse the .npy header to get the shape // labels
let npy = npyz::NpyFile::new(&bytes[..]).expect("Cannot parse npy"); let label_candidates = ["train_labels.npy", "train.labels.npy", "y_train.npy", "labels.npy"];
let shape = npy.shape().to_vec(); let mut label_bytes = Vec::new();
eprintln!("Raw array shape: {shape:?}"); for name in &label_candidates {
if let Ok(mut entry) = archive.by_name(name) {
entry.read_to_end(&mut label_bytes).expect("read labels");
break;
}
}
assert!(!label_bytes.is_empty(), "Could not find train labels in mnist.npz");
// MNIST is stored as uint8 (0255); we normalise to [0.0, 1.0] // parse
let raw: Vec<u8> = npy.into_vec().expect("Failed to read as u8 — dtype mismatch?"); let image_npy = npyz::NpyFile::new(&image_bytes[..]).expect("parse images");
let image_shape = image_npy.shape().to_vec();
let image_raw: Vec<u8> = image_npy.into_vec().expect("images to vec");
let n = examples.min(image_shape[0] as usize);
let pixels = image_raw.len() / image_shape[0] as usize; // should be 784
assert_eq!(pixels, 784, "Expected 784 pixels per image, got {pixels}");
let n = examples.min(shape[0] as usize); let label_npy = npyz::NpyFile::new(&label_bytes[..]).expect("parse labels");
let pixels = raw.len() / shape[0] as usize; // 784 = 1*28*28, regardless of how axes are ordered let label_raw: Vec<u8> = label_npy.into_vec().expect("labels to vec");
let data: Vec<f64> = raw[..n * pixels] // build items
.iter() (0..n)
.map(|&p| p as f64 / 255.0) .map(|i| {
.collect(); let mut image = [0f64; 784];
for (j, &px) in image_raw[i * 784..(i + 1) * 784].iter().enumerate() {
eprintln!("Loaded {n} examples, {pixels} pixels each"); image[j] = px as f64 / 255.0;
}
let tensor_data = burn::tensor::TensorData::new(data, [n, pixels]); MnistItem { image, label: label_raw[i] }
Tensor::<B, 2>::from_data(tensor_data, device) })
.collect()
} }
fn main() { fn main() {
let args = Args::parse(); let args = Args::parse();
let device = <B as Backend>::Device::default(); let device = burn_ndarray::NdArrayDevice::Cpu;
let activation = Activation::from_str(&args.activation).unwrap_or_default();
let x = load_mnist(args.examples, &device); println!("Loading MNIST…");
let all_items = load_mnist_items(60_000);
let cov = covariance(x.clone()); // Split into train / validation
let total_var = total_variance(x.clone()); let valid_n = (all_items.len() as f64 * args.valid_split) as usize;
let (_pc, s) = power_iteration(cov, args.iterations); let train_n = all_items.len() - valid_n;
let ev = explained_variance(total_var, s); let mut items = all_items;
let valid_items = items.split_off(train_n); // last `valid_n` items
let train_items = items;
println!("Total variance: {:.2}", total_var); println!("Train: {} Valid: {}", train_items.len(), valid_items.len());
println!("Explained variance: {:.2}%", 100.0 * ev);
let train_dataset = MnistDataset::new(train_items);
let valid_dataset = MnistDataset::new(valid_items);
let config = MnistTrainingConfig::new(
MnistModelConfig::new(args.hidden_layers, args.hidden_layer_size, activation),
AdamConfig::new(),
)
.with_num_epochs(args.epochs)
.with_batch_size(args.batch_size)
.with_num_workers(1) // NdArray backend is single-threaded; keep at 1
.with_seed(args.seed)
.with_learning_rate(1e-3);
println!("Starting training…");
config.train::<B>(device, train_dataset, valid_dataset);
} }

5987
hod_2/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

16
hod_2/Cargo.toml Normal file
View File

@@ -0,0 +1,16 @@
[package]
name = "hod_2"
version = "0.1.0"
edition = "2024"
[dependencies]
burn = { version = "0.20.1", default-features = false, features = ["ndarray", "std", "train"] }
burn-autodiff = "0.20.1"
burn-ndarray = "0.20.1"
clap = { version = "4.5.60", features = ["derive"] }
ndarray = "0.17.2"
npyz = { version = "0.8.4", features = ["npz"] }
rand = "0.10.0"
rand_distr = "0.6.0"
serde = { version = "1.0.228", features = ["derive"] }
zip = { version = "8.2.0", features = ["deflate"] }

50
hod_2/plan.md Normal file
View File

@@ -0,0 +1,50 @@
## Phase 1: Core Data Structures
**`src/model.rs`** - Manual parameter management
- `struct Parameters<B: Backend>`: holds `w1, b1, w2, b2` as `Tensor<B, 2>`
- `impl Parameters`: initialization with `randn(0.1)` for weights, zeros for biases
- No `nn.Linear`—manual tensors to match the Python exercise
## Phase 2: Forward Pass
**`src/forward.rs`** or in `model.rs`
- `fn forward<B: Backend>(params: &Parameters<B>, images: Tensor<B, 2>) -> Tensor<B, 2>`
- Cast `uint8` images to `f32`, divide by 255, flatten to `[batch, 784]`
- `hidden = tanh(images @ w1 + b1)`
- `logits = hidden @ w2 + b2`
- Return raw logits (no softmax here)
## Phase 3: Loss Computation
**`src/loss.rs`**
- `fn cross_entropy_loss<B: Backend>(logits: Tensor<B, 2>, labels: Tensor<B, 1, Int>) -> Tensor<B, 0>`
- Manual implementation—no `CrossEntropyLoss` module
- `softmax = exp(logits - max) / sum(exp(logits - max))`
- Index `softmax` by gold labels to get `p_correct`
- `loss = -mean(log(p_correct))`
## Phase 4: Backward Pass & SGD
**`src/train.rs`**
- `fn train_epoch<B: Backend>(params: &mut Parameters<B>, dataset: &[MnistItem], args: &Args)`
- For each batch:
1. `let loss = cross_entropy_loss(forward(&params, images), labels)`
2. `let grads = loss.backward()` — automatic differentiation
3. **Manual SGD**: `param = param - lr * grad` for each parameter
4. No `Optimizer`—raw gradient descent like Python
## Phase 5: Evaluation
**`src/eval.rs`**
- `fn evaluate<B: Backend>(params: &Parameters<B>, dataset: &[MnistItem]) -> f64`
- `argmax` on logits, compare to labels, return accuracy
## Phase 6: Main Training Loop
**Update `src/main.rs`**
- Parse args ✓ (done)
- Load data ✓ (done)
- Initialize `Parameters` with seed
- Loop `args.epochs`: `train_epoch` → `evaluate(dev)` → print
- Final `evaluate(test)`

1
hod_2/src/lib.rs Normal file
View File

@@ -0,0 +1 @@
pub mod model;

79
hod_2/src/main.rs Normal file
View File

@@ -0,0 +1,79 @@
use clap::Parser;
use std::fs::File;
use std::io::{Cursor, Read};
/// Command-line hyper-parameters for the manual-SGD MNIST trainer.
#[derive(Parser, Debug)]
#[command(author, version, about)]
struct Args {
    /// Number of examples per SGD mini-batch.
    #[arg(long, default_value_t = 50)]
    batch_size: usize,
    /// Number of full passes over the training set.
    #[arg(long, default_value_t = 10)]
    epochs: usize,
    /// Width of the hidden layer.
    #[arg(long, default_value_t = 100)]
    hidden_layer_size: usize,
    /// SGD step size.
    #[arg(long, default_value_t = 0.1)]
    learning_rate: f64,
    /// RNG seed for parameter initialization.
    #[arg(long, default_value_t = 42)]
    seed: u64,
    /// Worker thread count.
    // NOTE(review): not read anywhere in the visible main() — confirm it gets wired up.
    #[arg(long, default_value_t = 1)]
    threads: usize,
}
/// Load up to `examples` `(image, label)` pairs from an MNIST `.npz` archive.
///
/// Each image is flattened to `pixels` `f32` values normalised to `[0.0, 1.0]`;
/// labels are kept as raw `u8`. Panics with a descriptive message when the
/// archive does not contain the expected arrays (the previous version fell
/// through to an opaque `unwrap` on empty bytes).
fn load_mnist_items(path: &str, examples: usize) -> Vec<(Vec<f32>, u8)> {
    let file = File::open(path).expect("Cannot open mnist.npz");
    let mut archive = zip::ZipArchive::new(file).expect("Cannot read zip");

    // Common key names used by different MNIST npz exporters.
    let image_names = ["train_images.npy", "train.images.npy", "x_train.npy", "images.npy"];
    let mut image_bytes = Vec::new();
    for name in &image_names {
        if let Ok(mut entry) = archive.by_name(name) {
            entry.read_to_end(&mut image_bytes).expect("read image entry");
            break;
        }
    }
    assert!(!image_bytes.is_empty(), "Could not find train images in {path}");

    let label_names = ["train_labels.npy", "train.labels.npy", "y_train.npy", "labels.npy"];
    let mut label_bytes = Vec::new();
    for name in &label_names {
        if let Ok(mut entry) = archive.by_name(name) {
            entry.read_to_end(&mut label_bytes).expect("read label entry");
            break;
        }
    }
    assert!(!label_bytes.is_empty(), "Could not find train labels in {path}");

    let images_npy = npyz::NpyFile::new(Cursor::new(&image_bytes)).expect("parse images npy");
    let shape = images_npy.shape().to_vec();
    let n = shape[0] as usize;
    // Pixels per image regardless of whether the array is [n, 784] or [n, 28, 28].
    let pixels = shape[1..].iter().product::<u64>() as usize;
    let image_raw: Vec<u8> = images_npy.into_vec().expect("images dtype should be u8");

    let labels_npy = npyz::NpyFile::new(Cursor::new(&label_bytes)).expect("parse labels npy");
    let label_raw: Vec<u8> = labels_npy.into_vec().expect("labels dtype should be u8");

    let take = n.min(examples);
    assert!(label_raw.len() >= take, "label array shorter than image array");

    (0..take)
        .map(|i| {
            // Normalise each uint8 pixel to [0.0, 1.0].
            let image: Vec<f32> = image_raw[i * pixels..(i + 1) * pixels]
                .iter()
                .map(|&p| p as f32 / 255.0)
                .collect();
            (image, label_raw[i])
        })
        .collect()
}
/// Entry point: parse CLI args and build disjoint train/dev/test splits.
fn main() {
    let args = Args::parse();

    println!("Loading MNIST...");
    // Load the archive exactly once. The previous version opened and parsed
    // it three times AND produced dev/test sets that were prefixes of the
    // training set (all three reads hit the same train arrays) — i.e. the
    // evaluation data fully overlapped the training data.
    let mut items = load_mnist_items("mnist.npz", usize::MAX);

    // Carve disjoint splits off the end: 10k test, 5k dev, the rest train.
    // NOTE(review): ideally the test split should come from the archive's
    // dedicated test arrays (e.g. `test_images.npy`) — confirm whether the
    // npz ships them and prefer those if so.
    let test_items = items.split_off(items.len().saturating_sub(10_000));
    let dev_items = items.split_off(items.len().saturating_sub(5_000));
    let train_items = items;

    println!(
        "Train: {}, Dev: {}, Test: {}",
        train_items.len(),
        dev_items.len(),
        test_items.len()
    );
    println!("Args: {:?}", args);
}

64
hod_2/src/model.rs Normal file
View File

@@ -0,0 +1,64 @@
use burn::tensor::{backend::Backend, Tensor};
use rand::{rngs::StdRng, SeedableRng};
use rand_distr::{Distribution, Normal};
/// Manual neural network parameters for SGD backpropagation.
/// No nn.Linear — just raw tensors to match the Python exercise.
///
/// Layout is a fixed two-layer MLP: input (784) → hidden → output (10).
pub struct Parameters<B: Backend> {
    /// First layer weights: [784, hidden_layer_size]
    pub w1: Tensor<B, 2>,
    /// First layer biases: [hidden_layer_size]
    pub b1: Tensor<B, 1>,
    /// Second layer weights: [hidden_layer_size, 10]
    pub w2: Tensor<B, 2>,
    /// Second layer biases: [10]
    pub b2: Tensor<B, 1>,
}
impl<B: Backend> Parameters<B> {
    /// Build a fresh two-layer parameter set.
    /// Weights are drawn from a normal distribution scaled by 0.1
    /// (w2 uses `seed + 1` so the two layers get independent draws);
    /// biases start at zero.
    pub fn new(device: &B::Device, hidden_size: usize, seed: u64) -> Self {
        Self {
            w1: random_tensor([784, hidden_size], 0.1, seed, device),
            b1: Tensor::zeros([hidden_size], device),
            w2: random_tensor([hidden_size, 10], 0.1, seed.wrapping_add(1), device),
            b2: Tensor::zeros([10], device),
        }
    }

    /// Snapshot all parameters (cloned) in update order: w1, b1, w2, b2.
    pub fn to_vec(&self) -> Vec<ParamRef<B>> {
        let Self { w1, b1, w2, b2 } = self;
        vec![
            ParamRef::TwoD(w1.clone()),
            ParamRef::OneD(b1.clone()),
            ParamRef::TwoD(w2.clone()),
            ParamRef::OneD(b2.clone()),
        ]
    }
}
/// Helper enum to handle 1D and 2D parameters uniformly.
pub enum ParamRef<B: Backend> {
    /// Rank-1 tensor (the bias vectors b1, b2).
    OneD(Tensor<B, 1>),
    /// Rank-2 tensor (the weight matrices w1, w2).
    TwoD(Tensor<B, 2>),
}
/// Create a random tensor with normal distribution, scaled by std_dev.
/// Deterministic for a given `seed` (seeded `StdRng`).
fn random_tensor<B: Backend, const D: usize>(
    shape: [usize; D],
    std_dev: f64,
    seed: u64,
    device: &B::Device,
) -> Tensor<B, D> {
    let mut rng = StdRng::seed_from_u64(seed);
    let normal = Normal::new(0.0, std_dev).unwrap();
    let count: usize = shape.iter().product();
    // Draw exactly `count` samples in row-major order.
    let samples: Vec<f64> = std::iter::repeat_with(|| normal.sample(&mut rng))
        .take(count)
        .collect();
    Tensor::from_floats(burn::tensor::TensorData::new(samples, shape), device)
}