# Chapter 1: Burn Foundations — Tensors, Modules & the Training Loop

> **Burn version:** 0.20.x · **Backend:** NdArray (CPU) only · **Assumed knowledge:** proficient Rust, ML fundamentals (scikit-learn level), studying deep learning theory.
>
> **Cargo.toml you should already have:**
>
> ```toml
> [dependencies]
> burn = { version = "0.20", features = ["ndarray"] }
> ```

---

## Step 1 — Tensor Basics: Creation, Shape & Kinds

### What you need to know

Everything in Burn starts with `Tensor<B, D, K>`. Unlike ndarray or nalgebra, Burn's tensor is **generic over three things**:

```
Tensor<B, D, K>
       │  │  │
       │  │  └─ Kind: Float (default), Int, Bool
       │  └──── const dimensionality (rank) — a usize, NOT the shape
       └─────── Backend (NdArray, Wgpu, …)
```

The *shape* (e.g. 3×4) is a runtime value. The *rank* (e.g. 2) is a compile-time constant. This means the compiler catches dimension-mismatch bugs for you — you cannot accidentally add a 2-D tensor to a 3-D tensor.

The `NdArray` backend is Burn's pure-Rust CPU backend built on top of the `ndarray` crate. It needs no GPU drivers, compiles everywhere (including `no_std` embedded), and is the simplest backend to start with.

### Creating tensors

```rust
use burn::backend::NdArray;
use burn::tensor::{Tensor, TensorData, Int, Bool};

// Type alias — used everywhere in this chapter.
type B = NdArray;

fn main() {
    let device = Default::default(); // NdArray only has a Cpu device

    // ---- Float tensors (default kind) ----

    // From a nested array literal — easiest way for concrete backends:
    let a = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);

    // from_floats — recommended shorthand for f32:
    let b = Tensor::<B, 1>::from_floats([10.0, 20.0, 30.0], &device);

    // Factories:
    let zeros = Tensor::<B, 2>::zeros([3, 4], &device);   // 3×4 of 0.0
    let ones  = Tensor::<B, 2>::ones([2, 2], &device);    // 2×2 of 1.0
    let full  = Tensor::<B, 1>::full([5], 3.14, &device); // [3.14; 5]

    // Same shape as another tensor:
    let like_a = Tensor::<B, 2>::ones_like(&a);      // 2×2 of 1.0
    let like_z = Tensor::<B, 2>::zeros_like(&zeros); // 3×4 of 0.0

    // Random (uniform or normal):
    use burn::tensor::Distribution;
    let uniform = Tensor::<B, 2>::random([4, 4], Distribution::Uniform(0.0, 1.0), &device);
    let normal  = Tensor::<B, 2>::random([4, 4], Distribution::Normal(0.0, 1.0), &device);

    // ---- Int tensors ----
    let labels = Tensor::<B, 1, Int>::from_data(TensorData::from([0i32, 1, 2, 1]), &device);
    let range  = Tensor::<B, 1, Int>::arange(0..5, &device); // [0, 1, 2, 3, 4]

    // ---- Bool tensors ----
    let mask = a.clone().greater_elem(2.0); // Tensor<B, 2, Bool> — true where > 2

    println!("{}", a);
    println!("{}", labels);
    println!("{}", mask);
}
```

### Extracting data back out

```rust
// .to_data() clones — tensor is still usable afterward
let data: TensorData = a.to_data();

// .into_data() moves — tensor is consumed
let data: TensorData = a.into_data();

// Get a single scalar
let scalar: f32 = Tensor::<B, 1>::from_floats([42.0], &device)
    .into_scalar();
```

### Printing

```rust
println!("{}", tensor);    // full debug view with shape, backend, dtype
println!("{:.2}", tensor); // limit to 2 decimal places
```

### Exercise 1

Write a program that:

1. Creates a 3×3 identity-like float tensor manually with `from_data` (1s on diagonal, 0s elsewhere).
2. Creates a 1-D Int tensor `[10, 20, 30]`.
3. Creates a Bool tensor by checking which elements of the float tensor are greater than 0.
4. Prints all three tensors.

**Success criterion:** compiles, prints the expected values, you understand every generic parameter.

---

## Step 2 — Ownership, Cloning & Tensor Operations

### Ownership: the one Burn rule you must internalize

Almost every Burn tensor operation **takes ownership** (moves) of the input tensor. This is different from PyTorch, where tensors are reference-counted behind the scenes.

```rust
let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0], &device);
let y = x.exp(); // x is MOVED into exp(), you cannot use x anymore
// println!("{}", x); // ← compile error: value used after move
```

The fix is `.clone()`. Cloning a Burn tensor is **cheap** — it only bumps a reference count, it does not copy the underlying buffer.

```rust
let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0], &device);
let y = x.clone().exp(); // x is still alive
let z = x.clone().log(); // x is still alive
let w = x + y;           // x is consumed here — that's fine, last use
```

**Rule of thumb:** clone whenever you need the tensor again later; don't clone on its last use. Burn will automatically do in-place operations when it detects a single owner.
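The "cloning is cheap" guarantee is the same one `Rc` gives in plain Rust: cloning copies a pointer and bumps a counter; the data itself is never duplicated. A minimal standard-library analogy (no Burn involved):

```rust
use std::rc::Rc;

fn main() {
    // A large buffer behind a reference-counted pointer —
    // roughly how a Burn tensor shares its storage across clones.
    let buffer: Rc<Vec<f32>> = Rc::new(vec![0.0; 1_000_000]);

    let view = Rc::clone(&buffer); // cheap: no element is copied
    assert_eq!(Rc::strong_count(&buffer), 2);

    drop(view); // dropping a clone just decrements the counter
    assert_eq!(Rc::strong_count(&buffer), 1);

    println!("still {} elements, never copied", buffer.len());
}
```

This is also why the "single owner" in-place optimization works: when the reference count is 1, the operation knows nobody else can observe the buffer, so it may mutate it instead of allocating a new one.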
### Arithmetic

All standard ops are supported and work element-wise with broadcasting:

```rust
let a = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
let b = Tensor::<B, 2>::from_data([[10.0, 20.0], [30.0, 40.0]], &device);

let sum  = a.clone() + b.clone(); // element-wise add
let diff = a.clone() - b.clone(); // element-wise sub
let prod = a.clone() * b.clone(); // element-wise mul
let quot = a.clone() / b.clone(); // element-wise div

// Scalar operations
let scaled  = a.clone() * 2.0; // multiply every element by 2
let shifted = a.clone() + 1.0; // add 1 to every element
let neg     = -a.clone();      // negate
```

### Float-specific math

```rust
let x = Tensor::<B, 1>::from_floats([0.0, 1.0, 2.0], &device);

let _ = x.clone().exp();  // e^x
let _ = x.clone().log();  // ln(x)
let _ = x.clone().sqrt(); // √x
let _ = x.clone().sin();  // sin(x)
let _ = x.clone().tanh(); // tanh(x)
let _ = x.clone().powf_scalar(2.0); // x²
let _ = x.clone().recip(); // 1/x
let _ = x.clone().abs();
let _ = x.clone().ceil();
let _ = x.clone().floor();
```

### Matrix multiplication

```rust
let a = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
let b = Tensor::<B, 2>::from_data([[5.0, 6.0], [7.0, 8.0]], &device);

let c = a.matmul(b); // standard matrix multiply — [2,2] × [2,2] → [2,2]
println!("{}", c);
```

`matmul` works on any rank ≥ 2 with batch-broadcasting semantics (just like PyTorch's `@`).

### Reductions

```rust
let t = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);

let total  = t.clone().sum();      // scalar: 21.0
let col    = t.clone().sum_dim(0); // sum along rows → shape [1, 3]
let row    = t.clone().sum_dim(1); // sum along cols → shape [2, 1]
let mean_v = t.clone().mean();     // scalar mean
let max_v  = t.clone().max();      // scalar max
let argmax = t.clone().argmax(1);  // indices of max along dim 1 → Int tensor
```

**Note:** `sum_dim`, `mean_dim`, `max_dim` etc. keep the reduced dimension (like PyTorch's `keepdim=True`). This is consistent everywhere in Burn.
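To confirm the numbers in the reduction comments above, here's the same arithmetic done by hand in plain Rust (no Burn) for that 2×3 matrix:

```rust
fn main() {
    let t = [[1.0_f32, 2.0, 3.0], [4.0, 5.0, 6.0]];

    // sum() — reduce everything to one scalar
    let total: f32 = t.iter().flatten().sum();
    assert_eq!(total, 21.0);

    // sum_dim(0) — collapse the rows; the kept dim gives shape [1, 3]
    let col = [[t[0][0] + t[1][0], t[0][1] + t[1][1], t[0][2] + t[1][2]]];
    assert_eq!(col, [[5.0, 7.0, 9.0]]);

    // sum_dim(1) — collapse the columns; the kept dim gives shape [2, 1]
    let row = [[t[0].iter().sum::<f32>()], [t[1].iter().sum::<f32>()]];
    assert_eq!(row, [[6.0], [15.0]]);

    println!("total = {total}, col sums = {col:?}, row sums = {row:?}");
}
```

Note how the nested-array shapes `[[5.0, 7.0, 9.0]]` and `[[6.0], [15.0]]` mirror the kept-dimension shapes `[1, 3]` and `[2, 1]` that Burn returns.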
### Reshaping and dimension manipulation

```rust
let t = Tensor::<B, 1>::from_floats([1., 2., 3., 4., 5., 6.], &device);

let r  = t.clone().reshape([2, 3]);  // 1-D → 2-D (2×3)
let f  = r.clone().flatten(0, 1);    // 2-D → 1-D (flatten dims 0..1)
let u  = r.clone().unsqueeze_dim(0); // [2,3] → [1,2,3]
let s  = u.squeeze_dim(0);           // [1,2,3] → [2,3]
let p  = r.clone().swap_dims(0, 1);  // transpose: [2,3] → [3,2]
let t2 = r.transpose();              // same as swap_dims(0,1) for 2-D
```

### Concatenation and stacking

```rust
let a = Tensor::<B, 2>::from_data([[1.0, 2.0]], &device); // [1, 2]
let b = Tensor::<B, 2>::from_data([[3.0, 4.0]], &device); // [1, 2]

// cat: join along an EXISTING dimension
let catted = Tensor::cat(vec![a.clone(), b.clone()], 0); // [2, 2]

// stack: join along a NEW dimension
let stacked = Tensor::stack(vec![a, b], 0); // [2, 1, 2]
```

### Slicing

```rust
use burn::tensor::s;

let t = Tensor::<B, 2>::from_data(
    [[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
    &device,
);

let row0   = t.clone().slice([0..1, 0..3]); // [[1, 2, 3]]
let col01  = t.clone().slice([0..3, 0..2]); // [[1,2],[4,5],[7,8]]
let single = t.clone().slice([1..2, 1..2]); // [[5]]
```

### Comparisons → Bool tensors

```rust
let t = Tensor::<B, 1>::from_floats([1.0, 5.0, 3.0], &device);

let gt3: Tensor<B, 1, Bool> = t.clone().greater_elem(3.0); // [false, true, false]
let eq5: Tensor<B, 1, Bool> = t.clone().equal_elem(5.0);   // [false, true, false]
```

### Exercise 2

Write a function `fn min_max_normalize(t: Tensor<B, 1>) -> Tensor<B, 1>` that:

1. Computes the min and max of the input (remember to clone!).
2. Returns `(t - min) / (max - min)`.
3. Test it with `[2.0, 4.0, 6.0, 8.0, 10.0]` — expected output: `[0.0, 0.25, 0.5, 0.75, 1.0]`.

Then write a second function that takes a 2-D tensor and returns the softmax along dimension 1 using only basic ops: `exp`, `sum_dim`, and division. Compare a few values mentally.
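It helps to have a plain-Rust reference to check your Burn outputs against. This sketch implements the same min-max and softmax math on slices, with no Burn types (the softmax here subtracts the row max first for numerical stability — an extra step beyond what the exercise requires, but harmless):

```rust
fn min_max_normalize(v: &[f32]) -> Vec<f32> {
    // fold instead of Iterator::min/max, because f32 is not Ord
    let min = v.iter().copied().fold(f32::INFINITY, f32::min);
    let max = v.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    v.iter().map(|x| (x - min) / (max - min)).collect()
}

fn softmax_row(row: &[f32]) -> Vec<f32> {
    let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let n = min_max_normalize(&[2.0, 4.0, 6.0, 8.0, 10.0]);
    assert_eq!(n, vec![0.0, 0.25, 0.5, 0.75, 1.0]); // matches the exercise

    let s = softmax_row(&[1.0, 1.0, 1.0]);
    assert!((s.iter().sum::<f32>() - 1.0).abs() < 1e-6); // rows sum to 1
    assert!((s[0] - 1.0 / 3.0).abs() < 1e-6);            // equal logits → uniform

    println!("normalized: {n:?}");
}
```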
---

## Step 3 — Autodiff: Computing Gradients

### How Burn does autodiff

Unlike PyTorch, where autograd is baked into every tensor, Burn uses a **backend decorator** pattern. You wrap your base backend with `Autodiff<...>` and that transparently adds gradient tracking:

```
NdArray           → forward-only, no gradients
Autodiff<NdArray> → same ops, but now you can call .backward()
```

This is a zero-cost abstraction: if you don't need gradients (inference), you don't pay for them. The `AutodiffBackend` trait extends `Backend` with gradient-related methods. When you write training code, you constrain your generic with `B: AutodiffBackend`.

### A first gradient computation

```rust
use burn::backend::ndarray::NdArray;
use burn::backend::Autodiff;
use burn::tensor::{Tensor, backend::AutodiffBackend};

type TrainBackend = Autodiff<NdArray>;

fn main() {
    let device = Default::default();

    // Create a tensor that will be tracked for gradients
    let x = Tensor::<TrainBackend, 1>::from_floats([1.0, 2.0, 3.0], &device)
        .require_grad();

    // Forward: compute y = x^2, then sum to get a scalar loss
    let y = x.clone().powf_scalar(2.0); // [1, 4, 9]
    let loss = y.sum();                 // 14.0

    // Backward: compute gradients
    let grads = loss.backward();

    // Extract gradient of `loss` w.r.t. `x`
    // dy/dx = 2x → expected: [2.0, 4.0, 6.0]
    let x_grad = x.grad(&grads).expect("x should have a gradient");

    println!("x  = {}", x.to_data());
    println!("dx = {}", x_grad.to_data());
}
```

### The flow in detail

1. **`require_grad()`** — marks a leaf tensor as "track me". Any tensor derived from it inherits tracking automatically.
2. **Forward pass** — you build a computation graph by chaining normal tensor ops.
3. **`.backward()`** — called on a **scalar** tensor (0-D or summed to 0-D). Returns a `Gradients` object containing all partial derivatives.
4. **`.grad(&grads)`** — retrieves the gradient for a specific tensor. Returns an `Option` (`None` if no gradient was computed for that tensor).
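The expected gradient `[2.0, 4.0, 6.0]` is easy to verify numerically. A quick finite-difference check in plain Rust (no Burn required) for the same loss `L(x) = Σ xᵢ²`:

```rust
// Loss from the example above: L(x) = sum of squares.
fn loss(x: &[f64]) -> f64 {
    x.iter().map(|v| v * v).sum()
}

fn main() {
    let x = [1.0, 2.0, 3.0];
    let eps = 1e-6;

    for i in 0..x.len() {
        // Central difference: (L(x + eps·eᵢ) - L(x - eps·eᵢ)) / (2·eps)
        let mut plus = x;
        plus[i] += eps;
        let mut minus = x;
        minus[i] -= eps;
        let numeric = (loss(&plus) - loss(&minus)) / (2.0 * eps);

        let analytic = 2.0 * x[i]; // d/dxᵢ Σ xⱼ² = 2xᵢ
        assert!((numeric - analytic).abs() < 1e-6);
        println!("dL/dx[{i}] ≈ {numeric:.6} (analytic {analytic})");
    }
}
```

This trick generalizes: any time a Burn gradient surprises you, a central-difference probe on one coordinate is a two-line sanity check.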
### `model.valid()` — disabling gradients

When you have a model on an `Autodiff<...>` backend and want to run inference without tracking:

```rust
// During training: model is on Autodiff<NdArray>
let model_valid = model.valid(); // returns the same model on plain NdArray

// Now forward passes do not build a graph → faster, less memory
```

This replaces PyTorch's `with torch.no_grad():` and `model.eval()` (for the grad-tracking part).

### Exercise 3

1. Create a tensor `w = [0.5, -0.3, 0.8]` with `require_grad()` on `Autodiff<NdArray>`.
2. Create an input `x = [1.0, 2.0, 3.0]` (no grad needed — think of it as data).
3. Compute a "prediction" `y_hat = (w * x).sum()` (dot product).
4. Define a "target" `y = 2.5` and compute the MSE loss: `loss = (y_hat - y)^2`.
5. Call `.backward()` and print the gradient of `w`.
6. Manually verify the gradient: `d(loss)/dw_i = 2 * (y_hat - y) * x_i`.

---

## Step 4 — The Config Pattern

### Why configs exist

In PyTorch, you create a layer like `nn.Linear(784, 128, bias=True)` — the constructor directly builds the parameters. Burn separates **configuration** from **initialization** into two steps:

```
LinearConfig::new(784, 128) // 1. describe WHAT you want
    .with_bias(true)        //    (builder-style tweaks)
    .init(&device)          // 2. actually allocate weights on a device
```

This separation matters because:

- Configs are serializable (save/load hyperparameters to JSON).
- Configs are device-agnostic — you can build on CPU today, GPU tomorrow.
- You can inspect or modify a config before committing to allocating memory.
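To see what the config/init split buys you, here is the pattern written out by hand in plain Rust. The `Config` derive shown in the next step generates essentially this boilerplate for you; all names here (`LinearCfg`, `LinearLayer`) are illustrative, not Burn's:

```rust
// A hand-rolled config: plain data, no allocation yet.
#[derive(Debug, Clone, PartialEq)]
struct LinearCfg {
    d_in: usize,
    d_out: usize,
    bias: bool,
}

// The "layer": owns its (stand-in) weight buffer.
struct LinearLayer {
    weight: Vec<f32>, // d_in × d_out
    bias: Option<Vec<f32>>,
}

impl LinearCfg {
    fn new(d_in: usize, d_out: usize) -> Self {
        Self { d_in, d_out, bias: true } // bias defaults to true
    }
    fn with_bias(mut self, bias: bool) -> Self {
        self.bias = bias;
        self
    }
    // Only here do we pay for memory — the "init on a device" step.
    fn init(&self) -> LinearLayer {
        LinearLayer {
            weight: vec![0.0; self.d_in * self.d_out],
            bias: self.bias.then(|| vec![0.0; self.d_out]),
        }
    }
}

fn main() {
    let cfg = LinearCfg::new(784, 128).with_bias(false);
    assert!(!cfg.bias);     // cheap to inspect or tweak before committing
    let layer = cfg.init(); // allocation happens here, and only here
    assert_eq!(layer.weight.len(), 784 * 128);
    assert!(layer.bias.is_none());
    println!("config inspected, layer allocated");
}
```

Because the config is plain data, serializing it to JSON (as Burn's derive does) is trivial, and the same config can later `init` on a different device.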
### The `#[derive(Config)]` macro

Burn's `Config` derive macro generates a builder pattern for you:

```rust
use burn::config::Config;

#[derive(Config)]
pub struct MyTrainingConfig {
    #[config(default = 64)]
    pub batch_size: usize,
    #[config(default = 1e-3)]
    pub learning_rate: f64,
    #[config(default = 10)]
    pub num_epochs: usize,
    pub model: MyModelConfig, // nested config — no default, must be provided
}
```

This generates:

- `MyTrainingConfig::new(model: MyModelConfig)` — a constructor requiring the non-default fields.
- `.with_batch_size(128)` — an optional builder method for each field with a default.
- Serialization support (JSON via serde).

### Saving and loading configs

```rust
// Save
let config = MyTrainingConfig::new(MyModelConfig::new(10, 512));
config.save("config.json").expect("failed to save config");

// Load
let loaded = MyTrainingConfig::load("config.json").expect("failed to load config");
```

### Exercise 4

Define a `Config` struct called `ExperimentConfig` with:

- `hidden_size: usize` (no default — required)
- `dropout: f64` (default 0.5)
- `activation: String` (default `"relu"`)
- `lr: f64` (default 1e-4)

Create an instance, modify `dropout` to 0.3 via the builder, save it to JSON, load it back, and assert the values match.

---

## Step 5 — Defining Modules (Your First Neural Network)

### The `#[derive(Module)]` macro

A Burn "module" is any struct that holds learnable parameters (or sub-modules that do). You annotate it with `#[derive(Module)]` and make it generic over `Backend`:

```rust
use burn::prelude::*;
use burn::nn::{Linear, LinearConfig, Relu};

#[derive(Module, Debug)]
pub struct Mlp<B: Backend> {
    linear1: Linear<B>,
    linear2: Linear<B>,
    activation: Relu,
}
```

The derive macro automatically implements the `Module` trait, which gives you:

- Parameter iteration (for optimizers).
- Device transfer (`.to_device()`).
- Serialization (save/load weights).
- `.valid()` to strip autodiff.
- `.clone()` (cheap, reference-counted).

### The Config → init pattern for your own module

Every module you write should have a companion config:

```rust
#[derive(Config)]
pub struct MlpConfig {
    input_dim: usize,
    hidden_dim: usize,
    output_dim: usize,
}

impl MlpConfig {
    pub fn init<B: Backend>(&self, device: &B::Device) -> Mlp<B> {
        Mlp {
            linear1: LinearConfig::new(self.input_dim, self.hidden_dim).init(device),
            linear2: LinearConfig::new(self.hidden_dim, self.output_dim).init(device),
            activation: Relu::new(),
        }
    }
}
```

### The forward pass

By convention, you implement a `forward` method. This is NOT a trait method — it's just a convention. You can name it anything, or have multiple forward methods for different purposes.

```rust
impl<B: Backend> Mlp<B> {
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.linear1.forward(x); // [batch, hidden]
        let x = self.activation.forward(x);
        self.linear2.forward(x)          // [batch, output]
    }
}
```

### Putting it together

```rust
use burn::backend::NdArray;

fn main() {
    type B = NdArray;
    let device = Default::default();

    let config = MlpConfig::new(4, 32, 2);
    let model = config.init::<B>(&device);

    // Random input: batch of 8 samples, 4 features each
    let input = Tensor::<B, 2>::random(
        [8, 4],
        burn::tensor::Distribution::Normal(0.0, 1.0),
        &device,
    );

    let output = model.forward(input);
    println!("Output shape: {:?}", output.dims()); // [8, 2]
    println!("{:.4}", output);
}
```

### Exercise 5

Build a 3-layer MLP from scratch:

- Input: 10 features
- Hidden 1: 64 neurons + ReLU
- Hidden 2: 32 neurons + ReLU
- Output: 3 neurons (raw logits for a 3-class problem)

Create the config, init the model, feed a random `[16, 10]` batch through it, and verify the output shape is `[16, 3]`.

---

## Step 6 — Built-in Layers Tour

### The naming convention

Every built-in layer in `burn::nn` follows the same pattern:

```
ThingConfig::new(required_args) // create config
    .with_optional_param(value) // builder tweaks
    .init(&device)              // → Thing
```

The layer struct has a `.forward(input)` method.
### Linear (fully connected)

```rust
use burn::nn::{Linear, LinearConfig};

let layer = LinearConfig::new(784, 128)
    .with_bias(true) // default is true
    .init::<B>(&device);

// input: [batch, 784] → output: [batch, 128]
let out = layer.forward(input);
```

### Conv2d

```rust
use burn::nn::conv::{Conv2d, Conv2dConfig};
use burn::nn::PaddingConfig2d;

let conv = Conv2dConfig::new([1, 32], [3, 3]) // [in_channels, out_channels], [kH, kW]
    .with_padding(PaddingConfig2d::Same)
    .with_stride([1, 1])
    .init::<B>(&device);

// input: [batch, channels, H, W] → output: [batch, 32, H, W] (with Same padding)
let out = conv.forward(input_4d);
```

### BatchNorm

```rust
use burn::nn::{BatchNorm, BatchNormConfig};

let bn = BatchNormConfig::new(32) // num_features (channels)
    .init::<B>(&device);

// input: [batch, 32, H, W] → same shape, normalized
let out = bn.forward(input);
```

### Dropout

```rust
use burn::nn::{Dropout, DropoutConfig};

let dropout = DropoutConfig::new(0.5).init();

// During training, randomly zeros elements.
// Note: Dropout is stateless, no Backend generic needed.
let out = dropout.forward(input);
```

Burn's dropout is automatically disabled during inference when you use `model.valid()`.
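Under the hood, dropout with probability `p` is conventionally implemented in the "inverted" form: surviving activations are scaled by `1/(1-p)` during training so the expected value is unchanged and inference needs no rescaling. A plain-Rust sketch of that arithmetic (an illustration of the standard technique, not Burn's actual implementation; a fixed mask stands in for the random one so the numbers are checkable):

```rust
// Inverted dropout on a slice, with a caller-supplied keep mask.
fn inverted_dropout(x: &[f32], keep: &[bool], p: f32) -> Vec<f32> {
    let scale = 1.0 / (1.0 - p);
    x.iter()
        .zip(keep)
        .map(|(v, &k)| if k { v * scale } else { 0.0 })
        .collect()
}

fn main() {
    let x = [1.0, 2.0, 3.0, 4.0];

    // p = 0.5: drop half the units, double the survivors.
    let y = inverted_dropout(&x, &[true, false, true, false], 0.5);
    assert_eq!(y, vec![2.0, 0.0, 6.0, 0.0]);

    // Expected value is preserved: averaging over the two complementary
    // masks recovers the original activations exactly.
    let z = inverted_dropout(&x, &[false, true, false, true], 0.5);
    let avg: Vec<f32> = y.iter().zip(&z).map(|(a, b)| (a + b) / 2.0).collect();
    assert_eq!(avg, x.to_vec());

    println!("dropout arithmetic checks out");
}
```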
### Pooling

```rust
use burn::nn::pool::{
    MaxPool2d, MaxPool2dConfig,
    AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig,
};

let maxpool = MaxPool2dConfig::new([2, 2])
    .with_stride([2, 2])
    .init();

let adaptive = AdaptiveAvgPool2dConfig::new([1, 1]).init();

let out = maxpool.forward(input);  // halves spatial dims
let out = adaptive.forward(input); // spatial → 1×1
```

### Embedding

```rust
use burn::nn::{Embedding, EmbeddingConfig};

let embed = EmbeddingConfig::new(10000, 256) // vocab_size, embed_dim
    .init::<B>(&device);

// input: Int tensor [batch, seq_len] → output: [batch, seq_len, 256]
let out = embed.forward(token_ids);
```

### Activations (stateless functions)

Most activations are free functions in `burn::tensor::activation`, not module structs:

```rust
use burn::tensor::activation;

let x = Tensor::<B, 2>::random([4, 4], Distribution::Normal(0.0, 1.0), &device);

let _ = activation::relu(x.clone());
let _ = activation::sigmoid(x.clone());
let _ = activation::gelu(x.clone());
let _ = activation::softmax(x.clone(), 1); // softmax along dim 1
let _ = activation::log_softmax(x.clone(), 1);
```

There are also module wrappers like `Relu`, `Gelu`, etc. that just call these functions. Use whichever style you prefer.

### Exercise 6

Build a small CNN module for 1-channel 28×28 images (think MNIST):

- Conv2d: 1 → 16 channels, 3×3 kernel, same padding → ReLU → MaxPool 2×2
- Conv2d: 16 → 32 channels, 3×3 kernel, same padding → ReLU → AdaptiveAvgPool 1×1
- Flatten → Linear(32, 10)

Create the config, init the model, feed a random `[4, 1, 28, 28]` tensor, and verify the output is `[4, 10]`.
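Before running the CNN from Exercise 6, you can sanity-check the spatial dimensions on paper with the standard conv/pool output-size formula. As a small plain-Rust helper (the function name is ours, nothing from Burn):

```rust
// out = floor((in + 2·pad − kernel) / stride) + 1
fn conv_out(input: usize, kernel: usize, stride: usize, pad: usize) -> usize {
    (input + 2 * pad - kernel) / stride + 1
}

fn main() {
    // Conv2d 3×3, "same" padding (pad = 1), stride 1: 28 → 28
    assert_eq!(conv_out(28, 3, 1, 1), 28);

    // MaxPool 2×2, stride 2: 28 → 14
    assert_eq!(conv_out(28, 2, 2, 0), 14);

    // Second conv, same padding again: 14 → 14
    assert_eq!(conv_out(14, 3, 1, 1), 14);

    // AdaptiveAvgPool forces 1×1 regardless of input size, so the
    // flattened feature count is just the channel count: 32.
    // Linear(32, 10) then maps [4, 32] → [4, 10].
    println!("shapes check out");
}
```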
---

## Step 7 — Loss Functions

### Built-in losses

Burn provides losses in `burn::nn::loss`:

```rust
use burn::nn::loss::{CrossEntropyLoss, CrossEntropyLossConfig, MseLoss};
```

**Cross-entropy** (for classification):

```rust
let loss_fn = CrossEntropyLossConfig::new().init(&device);

// logits:  [batch, num_classes] — raw model output, NOT softmax'd
// targets: [batch] — Int tensor of class indices
let loss: Tensor<B, 1> = loss_fn.forward(logits, targets);
```

Note: `CrossEntropyLoss` expects **logits** (raw scores) by default, not probabilities. It applies log-softmax internally, just like PyTorch's `nn.CrossEntropyLoss`.

**MSE loss** (for regression):

```rust
let mse = MseLoss::new();

// Both tensors must have the same shape.
let loss = mse.forward(predictions, targets, burn::nn::loss::Reduction::Mean);
```

### Writing a custom loss

A loss is just a function that takes tensors and returns a scalar tensor. There's nothing special about it:

```rust
/// Binary cross-entropy from logits (custom implementation).
fn bce_with_logits<B: Backend>(
    logits: Tensor<B, 1>,
    targets: Tensor<B, 1>,
) -> Tensor<B, 1> {
    // sigmoid + log trick for numerical stability:
    // loss = max(logits, 0) - logits * targets + log(1 + exp(-|logits|))
    let zeros = Tensor::zeros_like(&logits);
    let max_val = logits.clone().max_pair(zeros);
    let abs_logits = logits.clone().abs();
    let loss = max_val - logits * targets + ((-abs_logits).exp() + 1.0).log();
    loss.mean()
}
```

### Exercise 7

1. Create a random logits tensor `[8, 5]` (batch=8, 5 classes) and a random targets Int tensor `[8]` with values in 0..5.
2. Compute the cross-entropy loss using `CrossEntropyLossConfig`.
3. Write your own `manual_cross_entropy` function that:
   - Computes log-softmax: `log_softmax(x, dim=1)` via `activation::log_softmax`.
   - Gathers the log-prob at the target index for each sample (use a loop or `Tensor::select` + some reshaping).
   - Returns the negative mean.
4. Compare the two values — they should be very close (use `check_closeness` or just eyeball the printed scalars).

---

## Step 8 — Datasets & Batching

### The `Dataset` trait

Burn's dataset system lives in `burn::data::dataset`:

```rust
pub trait Dataset<I>: Send + Sync {
    fn get(&self, index: usize) -> Option<I>;
    fn len(&self) -> usize;
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
```

Simple — it's a random-access collection.

### InMemDataset

The simplest dataset — a `Vec` in memory:

```rust
use burn::data::dataset::InMemDataset;

#[derive(Clone, Debug)]
struct Sample {
    features: [f32; 4],
    label: u8,
}

let samples = vec![
    Sample { features: [1.0, 2.0, 3.0, 4.0], label: 0 },
    Sample { features: [5.0, 6.0, 7.0, 8.0], label: 1 },
    // ... more
];

let dataset = InMemDataset::new(samples);
println!("Dataset length: {}", dataset.len());
println!("Sample 0: {:?}", dataset.get(0));
```

### The `Batcher` trait

A batcher converts a `Vec` of samples into a batch struct containing tensors:

```rust
use burn::data::dataloader::batcher::Batcher;

#[derive(Clone, Debug)]
struct MyBatch<B: Backend> {
    inputs: Tensor<B, 2>,       // [batch, features]
    targets: Tensor<B, 1, Int>, // [batch]
}

#[derive(Clone)]
struct MyBatcher<B: Backend> {
    device: B::Device,
}

impl<B: Backend> Batcher<Sample, MyBatch<B>> for MyBatcher<B> {
    fn batch(&self, items: Vec<Sample>) -> MyBatch<B> {
        let inputs: Vec<Tensor<B, 2>> = items.iter()
            .map(|s| Tensor::<B, 1>::from_floats(s.features, &self.device))
            .map(|t| t.unsqueeze_dim(0)) // [4] → [1, 4]
            .collect();

        let targets: Vec<Tensor<B, 1, Int>> = items.iter()
            .map(|s| Tensor::<B, 1, Int>::from_data(
                TensorData::from([s.label as i32]),
                &self.device,
            ))
            .collect();

        MyBatch {
            inputs: Tensor::cat(inputs, 0),   // [batch, 4]
            targets: Tensor::cat(targets, 0), // [batch]
        }
    }
}
```

### DataLoaderBuilder

```rust
use burn::data::dataloader::DataLoaderBuilder;

let batcher = MyBatcher::<B> { device: device.clone() };

let dataloader = DataLoaderBuilder::new(batcher)
    .batch_size(32)
    .shuffle(42)    // seed for shuffling
    .num_workers(4) // parallel data loading threads
    .build(dataset);

for batch in dataloader.iter() {
    println!("Batch inputs:  {:?}", batch.inputs.dims());
    println!("Batch targets: {:?}", batch.targets.dims());
}
```

### Exercise 8

Create a synthetic XOR-like dataset:

- 200 samples of `[x1, x2]` where x1, x2 are random in {0, 1}.
- The label is `x1 XOR x2` (0 or 1).
- Implement the `Sample` struct, populate an `InMemDataset`, write a `Batcher`, and build a `DataLoader` with batch size 16.
- Loop over one epoch and print each batch's input shape and target shape.

---

## Step 9 — Optimizers

### How optimizers work in Burn

In PyTorch, you register parameters with the optimizer, call `loss.backward()`, then `optimizer.step()`, then `optimizer.zero_grad()`. Burn is more functional:

```rust
let grads = loss.backward();                            // 1. get raw gradients
let grads = GradientsParams::from_grads(grads, &model); // 2. map to parameter IDs
model = optim.step(lr, model, grads);                   // 3. optimizer consumes grads, returns updated model
```

No `zero_grad` — gradients are consumed in step 3. No mutation — `step` takes ownership of the model and returns an updated one.

### Creating an optimizer

```rust
use burn::optim::{AdamConfig, SgdConfig, Adam, Sgd};

// Adam (most common)
let optim = AdamConfig::new()
    .with_beta_1(0.9)
    .with_beta_2(0.999)
    .with_epsilon(1e-8)
    .with_weight_decay_config(None)
    .init();

// SGD with momentum
let optim = SgdConfig::new()
    .with_momentum(Some(burn::optim::MomentumConfig {
        momentum: 0.9,
        dampening: 0.0,
        nesterov: false,
    }))
    .init();
```

The optimizer is generic — it works with any module on any `AutodiffBackend`.
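The momentum update that `step` performs for SGD is a small piece of arithmetic worth seeing once. A plain-Rust sketch of one parameter's update, in the standard formulation (Burn keeps the velocity state per parameter internally; this scalar version is only for intuition):

```rust
/// One SGD-with-momentum step for a single scalar parameter:
/// v ← momentum·v + g;  w ← w − lr·v
fn sgd_momentum_step(w: &mut f64, v: &mut f64, grad: f64, lr: f64, momentum: f64) {
    *v = momentum * *v + grad;
    *w -= lr * *v;
}

fn main() {
    let (mut w, mut v) = (1.0, 0.0);
    let (lr, momentum) = (0.1, 0.9);

    // Constant gradient of 1.0: the velocity builds up step by step.
    sgd_momentum_step(&mut w, &mut v, 1.0, lr, momentum);
    assert!((v - 1.0).abs() < 1e-12);
    assert!((w - 0.9).abs() < 1e-12);

    sgd_momentum_step(&mut w, &mut v, 1.0, lr, momentum);
    assert!((v - 1.9).abs() < 1e-12);  // 0.9·1.0 + 1.0
    assert!((w - 0.71).abs() < 1e-12); // 0.9 − 0.1·1.9

    println!("w = {w}, v = {v}");
}
```

Note how the second step moves the weight further than the first for the same gradient — that accumulated velocity is exactly what momentum contributes.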
### Gradient accumulation

If you need to accumulate gradients over multiple mini-batches before stepping:

```rust
use burn::optim::GradientsAccumulator;

let mut accumulator = GradientsAccumulator::new();

for batch in mini_batches {
    let loss = model.forward(batch);
    let grads = loss.backward();
    let grads = GradientsParams::from_grads(grads, &model);
    accumulator.accumulate(&model, grads);
}

let accumulated = accumulator.grads();
model = optim.step(lr, model, accumulated);
```

### Exercise 9

Using the XOR dataset from Exercise 8 and the 3-layer MLP from Exercise 5 (adapt the dimensions to input=2, hidden=16, output=2):

1. Create an Adam optimizer.
2. Run one batch through the model.
3. Compute the cross-entropy loss.
4. Call `.backward()` → `GradientsParams::from_grads` → `optim.step`.
5. Print the loss before and after the step (run the forward pass again) to confirm it changed.

Don't write the full loop yet — just verify that the single-step mechanics compile and run.

---

## Step 10 — The Full Training Loop (Manual)

### Wiring everything together

This is the capstone: a complete, working training loop with no `Learner` abstraction. We'll train an MLP on the XOR problem.
```rust
use burn::backend::ndarray::NdArray;
use burn::backend::Autodiff;
use burn::data::dataloader::DataLoaderBuilder;
use burn::data::dataset::InMemDataset;
use burn::nn::loss::{CrossEntropyLoss, CrossEntropyLossConfig};
use burn::optim::{AdamConfig, GradientsParams};
use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;

// ---- Backend aliases ----
type TrainB = Autodiff<NdArray>;
type InferB = NdArray;

// ---- Model (reuse your MLP, adapted) ----
use burn::nn::{Linear, LinearConfig, Relu};

#[derive(Module, Debug)]
pub struct XorModel<B: Backend> {
    fc1: Linear<B>,
    fc2: Linear<B>,
    fc3: Linear<B>,
    relu: Relu,
}

#[derive(Config)]
pub struct XorModelConfig {
    #[config(default = 2)]
    input: usize,
    #[config(default = 16)]
    hidden: usize,
    #[config(default = 2)]
    output: usize,
}

impl XorModelConfig {
    pub fn init<B: Backend>(&self, device: &B::Device) -> XorModel<B> {
        XorModel {
            fc1: LinearConfig::new(self.input, self.hidden).init(device),
            fc2: LinearConfig::new(self.hidden, self.hidden).init(device),
            fc3: LinearConfig::new(self.hidden, self.output).init(device),
            relu: Relu::new(),
        }
    }
}

impl<B: Backend> XorModel<B> {
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.relu.forward(self.fc1.forward(x));
        let x = self.relu.forward(self.fc2.forward(x));
        self.fc3.forward(x)
    }
}

// ---- Data ----
#[derive(Clone, Debug)]
struct XorSample {
    x: [f32; 2],
    y: i32,
}

#[derive(Clone)]
struct XorBatcher<B: Backend> {
    device: B::Device,
}

impl<B: Backend>
    burn::data::dataloader::batcher::Batcher<XorSample, (Tensor<B, 2>, Tensor<B, 1, Int>)>
    for XorBatcher<B>
{
    fn batch(&self, items: Vec<XorSample>) -> (Tensor<B, 2>, Tensor<B, 1, Int>) {
        let xs: Vec<Tensor<B, 2>> = items.iter()
            .map(|s| Tensor::<B, 1>::from_floats(s.x, &self.device).unsqueeze_dim(0))
            .collect();
        let ys: Vec<Tensor<B, 1, Int>> = items.iter()
            .map(|s| Tensor::<B, 1, Int>::from_data(TensorData::from([s.y]), &self.device))
            .collect();
        (Tensor::cat(xs, 0), Tensor::cat(ys, 0))
    }
}

fn make_xor_dataset() -> InMemDataset<XorSample> {
    let mut samples = Vec::new();
    // Repeat the patterns many times so the network has enough data
    for _ in 0..200 {
        samples.push(XorSample { x: [0.0, 0.0], y: 0 });
        samples.push(XorSample { x: [0.0, 1.0], y: 1 });
        samples.push(XorSample { x: [1.0, 0.0], y: 1 });
        samples.push(XorSample { x: [1.0, 1.0], y: 0 });
    }
    InMemDataset::new(samples)
}

// ---- Training loop ----
fn train() {
    let device = Default::default();

    // Config
    let lr = 1e-3;
    let num_epochs = 50;
    let batch_size = 32;

    // Model + optimizer
    let model_config = XorModelConfig::new();
    let mut model: XorModel<TrainB> = model_config.init(&device);
    let mut optim = AdamConfig::new().init();

    // Data
    let dataset = make_xor_dataset();
    let batcher = XorBatcher::<TrainB> { device: device.clone() };
    let dataloader = DataLoaderBuilder::new(batcher)
        .batch_size(batch_size)
        .shuffle(42)
        .build(dataset);

    // Loss function
    let loss_fn = CrossEntropyLossConfig::new().init(&device);

    // Train!
    for epoch in 1..=num_epochs {
        let mut epoch_loss = 0.0;
        let mut n_batches = 0;

        for (inputs, targets) in dataloader.iter() {
            // Forward
            let logits = model.forward(inputs);
            let loss = loss_fn.forward(logits.clone(), targets.clone());

            // Backward + optimize
            let grads = loss.backward();
            let grads = GradientsParams::from_grads(grads, &model);
            model = optim.step(lr, model, grads);

            epoch_loss += loss.into_scalar().elem::<f32>();
            n_batches += 1;
        }

        if epoch % 10 == 0 || epoch == 1 {
            println!(
                "Epoch {:>3} | Loss: {:.4}",
                epoch,
                epoch_loss / n_batches as f32,
            );
        }
    }

    // ---- Quick inference check ----
    let model_valid = model.valid(); // drop autodiff → runs on NdArray
    let test_input = Tensor::<InferB, 2>::from_data(
        [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]],
        &device,
    );
    let predictions = model_valid.forward(test_input);
    let classes = predictions.argmax(1);

    println!("\nPredictions (should be 0, 1, 1, 0):");
    println!("{}", classes.to_data());
}

fn main() {
    train();
}
```

### Key points to internalize

- **`model` is reassigned** on every `optim.step(...)` — Burn is functional: the old model is consumed, and a new updated model is returned.
- **`loss.backward()`** returns raw gradients. You must wrap them with `GradientsParams::from_grads` to associate them with parameter IDs.
- **No `zero_grad`** — gradients are consumed by `step`.
- **`model.valid()`** strips the `Autodiff` wrapper. The returned model lives on the inner backend (`NdArray`). Use it for validation/inference.
- **`loss.into_scalar()`** gives you a backend-specific element. Use `.elem::<f32>()` to get a plain Rust `f32` for printing.

### Exercise 10

Take the training loop above and extend it:

1. Add a **validation step** after each epoch: call `model.valid()`, run the 4 XOR inputs through it, and compute accuracy (the percentage of correct predictions).
2. Add **early stopping**: if validation accuracy hits 100%, break out of the epoch loop.
3. Experiment: reduce the hidden size to 4 — does it still learn? Increase it to 64 — does it learn faster?

---

## Step 11 — Saving & Loading Models

### Records and Recorders

Burn saves models as "records" using "recorders". A record is the serializable state of a module (all its parameter tensors). A recorder is a strategy for how to store that record (binary, MessagePack, etc.).

### Saving

```rust
use burn::record::{CompactRecorder, NamedMpkFileRecorder, FullPrecisionSettings};

let recorder = CompactRecorder::new();

// Save — writes to "my_model.bin" (extension added automatically)
model
    .save_file("my_model", &recorder)
    .expect("Failed to save model");
```

`CompactRecorder` uses a compact binary format. `NamedMpkFileRecorder` uses named MessagePack, which is more portable and human-debuggable.

### Loading

```rust
// First, create a fresh model with random weights
let mut model = model_config.init::<B>(&device);

// Then load the saved weights into it
let recorder = CompactRecorder::new();
model = model
    .load_file("my_model", &recorder, &device)
    .expect("Failed to load model");
```

Note: the model architecture must match exactly — same layers, same sizes.
### A complete save/load cycle

```rust
fn save_model<B: Backend>(model: &XorModel<B>) {
    let recorder = CompactRecorder::new();
    model
        .clone() // module clones are cheap; save_file consumes the model
        .save_file("xor_trained", &recorder)
        .expect("save failed");
    println!("Model saved.");
}

fn load_model<B: Backend>(device: &B::Device) -> XorModel<B> {
    let model = XorModelConfig::new().init::<B>(device);
    let recorder = CompactRecorder::new();
    model
        .load_file("xor_trained", &recorder, device)
        .expect("load failed")
}

// In your training function, after training:
save_model(&model_valid);

// Later, for inference:
let loaded: XorModel<InferB> = load_model(&device);
let output = loaded.forward(test_input);
```

### Checkpoint during training

A common pattern is to save every N epochs:

```rust
for epoch in 1..=num_epochs {
    // ... training ...

    if epoch % 10 == 0 {
        let snapshot = model.valid(); // strip autodiff for saving
        save_model(&snapshot);
        println!("Checkpoint saved at epoch {}", epoch);
    }
}
```

### Exercise 11

1. Train the XOR model from Step 10 until it converges.
2. Save it with `CompactRecorder`.
3. In a **separate function** (simulating a new program run), load the model on `NdArray` and run inference on all 4 XOR inputs.
4. Verify the predictions match what you got at the end of training.

**Bonus:** also try `NamedMpkFileRecorder::<FullPrecisionSettings>` and compare the file sizes.
---

## Quick Reference Card

| Concept | Burn | PyTorch equivalent |
|---|---|---|
| Backend selection | `type B = NdArray;` | device selection |
| Tensor creation | `Tensor::<B, D>::from_data(...)` | `torch.tensor(...)` |
| Autodiff | `Autodiff<B>` backend wrapper | built-in autograd |
| Disable grad | `model.valid()` | `torch.no_grad()` |
| Module definition | `#[derive(Module)]` | `nn.Module` subclass |
| Layer construction | `LinearConfig::new(a, b).init(&dev)` | `nn.Linear(a, b)` |
| Hyperparameters | `#[derive(Config)]` | manual / dataclass |
| Loss | `CrossEntropyLossConfig::new().init(&dev)` | `nn.CrossEntropyLoss()` |
| Optimizer | `AdamConfig::new().init()` | `optim.Adam(params, lr)` |
| Backward | `loss.backward()` | `loss.backward()` |
| Param gradients | `GradientsParams::from_grads(grads, &model)` | automatic |
| Optimizer step | `model = optim.step(lr, model, grads)` | `optimizer.step()` |
| Zero grad | automatic (grads consumed) | `optimizer.zero_grad()` |
| Save model | `model.save_file("path", &recorder)` | `torch.save(state_dict)` |
| Load model | `model.load_file("path", &recorder, &dev)` | `model.load_state_dict(...)` |

---

## Where to go next

You now have the vocabulary and mechanics to read and write Burn code for deep learning practicals. The topics we deliberately skipped (for a future chapter) include:

- **The `Learner` abstraction** — Burn's built-in training loop with metrics, checkpointing, and a TUI dashboard.
- **GPU backends** — `Wgpu`, `Cuda`, `LibTorch` for actual hardware acceleration.
- **ONNX import** — loading models from PyTorch/TensorFlow.
- **Custom GPU kernels** — writing CubeCL compute shaders.
- **Quantization** — post-training quantization for deployment.
- **`no_std` deployment** — running Burn on bare-metal embedded (where NdArray also works).

For now, go do your deep learning exercises in Burn. You have the tools.