# Chapter 1: Burn Foundations — Tensors, Modules & the Training Loop
> **Burn version:** 0.20.x · **Backend:** NdArray (CPU) only · **Assumed knowledge:** proficient Rust, ML fundamentals (scikit-learn level), studying deep learning theory.
>
> **Cargo.toml you should already have:**
> ```toml
> [dependencies]
> burn = { version = "0.20", features = ["ndarray"] }
> ```
---
## Step 1 — Tensor Basics: Creation, Shape & Kinds
### What you need to know
Everything in Burn starts with `Tensor`. Unlike ndarray or nalgebra, Burn's tensor is **generic
over three things**:
```
Tensor<B, D, K>
│ │ └─ Kind: Float (default), Int, Bool
│ └──── const dimensionality (rank) — a usize, NOT the shape
└─────── Backend (NdArray, Wgpu, …)
```
The *shape* (e.g. 3×4) is a runtime value. The *rank* (e.g. 2) is a compile-time constant. This
means the compiler catches dimension-mismatch bugs for you — you cannot accidentally add a 2-D
tensor to a 3-D tensor.
The `NdArray` backend is Burn's pure-Rust CPU backend built on top of the `ndarray` crate. It needs
no GPU drivers, compiles everywhere (including `no_std` embedded), and is the simplest backend to
start with.
### Creating tensors
```rust
use burn::backend::NdArray;
use burn::tensor::{Tensor, TensorData, Int, Bool};
// Type alias — used everywhere in this chapter.
type B = NdArray;
fn main() {
    let device = Default::default(); // NdArray only has a Cpu device
    // ---- Float tensors (default kind) ----
    // From a nested array literal — easiest way for concrete backends:
    let a = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    // from_floats — recommended shorthand for f32:
    let b = Tensor::<B, 1>::from_floats([10.0, 20.0, 30.0], &device);
    // Factories:
    let zeros = Tensor::<B, 2>::zeros([3, 4], &device); // 3×4 of 0.0
    let ones = Tensor::<B, 2>::ones([2, 2], &device); // 2×2 of 1.0
    let full = Tensor::<B, 1>::full([5], 3.14, &device); // [3.14; 5]
    // Same shape as another tensor:
    let like_a = Tensor::<B, 2>::ones_like(&a); // 2×2 of 1.0
    let like_z = Tensor::<B, 2>::zeros_like(&zeros); // 3×4 of 0.0
    // Random (uniform or normal):
    use burn::tensor::Distribution;
    let uniform = Tensor::<B, 2>::random([4, 4], Distribution::Uniform(0.0, 1.0), &device);
    let normal = Tensor::<B, 2>::random([4, 4], Distribution::Normal(0.0, 1.0), &device);
    // ---- Int tensors ----
    let labels = Tensor::<B, 1, Int>::from_data(TensorData::from([0i32, 1, 2, 1]), &device);
    let range = Tensor::<B, 1, Int>::arange(0..5, &device); // [0, 1, 2, 3, 4]
    // ---- Bool tensors ----
    let mask = a.clone().greater_elem(2.0); // Tensor<B, 2, Bool> — true where > 2
    println!("{}", a);
    println!("{}", labels);
    println!("{}", mask);
}
```
### Extracting data back out
```rust
// .to_data() clones — tensor is still usable afterward
let data: TensorData = a.to_data();
// .into_data() moves — tensor is consumed
let data: TensorData = a.into_data();
// Get a single scalar
let scalar: f32 = Tensor::<B, 1>::from_floats([42.0], &device)
    .into_scalar();
```
### Printing
```rust
println!("{}", tensor); // full view: values plus shape, backend, dtype
println!("{:.2}", tensor); // limit to 2 decimal places
```
### Exercise 1
Write a program that:
1. Creates a 3×3 identity-like float tensor manually with `from_data` (1s on diagonal, 0s
elsewhere).
2. Creates a 1-D Int tensor `[10, 20, 30]`.
3. Creates a Bool tensor by checking which elements of the float tensor are greater than 0.
4. Prints all three tensors.
**Success criterion:** compiles, prints the expected values, you understand every generic parameter.
---
## Step 2 — Ownership, Cloning & Tensor Operations
### Ownership: the one Burn rule you must internalize
Almost every Burn tensor operation **takes ownership** (moves) of the input tensor. This is
different from PyTorch where tensors are reference-counted behind the scenes.
```rust
let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0], &device);
let y = x.exp(); // x is MOVED into exp(), you cannot use x anymore
// println!("{}", x); // ← compile error: value used after move
```
The fix is `.clone()`. Cloning a Burn tensor is **cheap** — it only bumps a reference count, it
does not copy the underlying buffer.
```rust
let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0], &device);
let y = x.clone().exp(); // x is still alive
let z = x.clone().log(); // x is still alive
let w = x + y; // x is consumed here — that's fine, last use
```
**Rule of thumb:** clone whenever you need the tensor again later; don't clone on its last use.
Burn will automatically do in-place operations when it detects a single owner.
### Arithmetic
All standard ops are supported and work element-wise with broadcasting:
```rust
let a = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
let b = Tensor::<B, 2>::from_data([[10.0, 20.0], [30.0, 40.0]], &device);
let sum = a.clone() + b.clone(); // element-wise add
let diff = a.clone() - b.clone(); // element-wise sub
let prod = a.clone() * b.clone(); // element-wise mul
let quot = a.clone() / b.clone(); // element-wise div
// Scalar operations
let scaled = a.clone() * 2.0; // multiply every element by 2
let shifted = a.clone() + 1.0; // add 1 to every element
let neg = -a.clone(); // negate
```
### Float-specific math
```rust
let x = Tensor::<B, 1>::from_floats([0.0, 1.0, 2.0], &device);
let _ = x.clone().exp(); // e^x
let _ = x.clone().log(); // ln(x)
let _ = x.clone().sqrt(); // √x
let _ = x.clone().sin(); // sin(x)
let _ = x.clone().tanh(); // tanh(x)
let _ = x.clone().powf_scalar(2.0); // x²
let _ = x.clone().recip(); // 1/x
let _ = x.clone().abs();
let _ = x.clone().ceil();
let _ = x.clone().floor();
```
### Matrix multiplication
```rust
let a = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
let b = Tensor::<B, 2>::from_data([[5.0, 6.0], [7.0, 8.0]], &device);
let c = a.matmul(b); // standard matrix multiply — [2,2] × [2,2] → [2,2]
println!("{}", c);
```
`matmul` works on any rank ≥ 2 with batch-broadcasting semantics (just like PyTorch's `@`).
### Reductions
```rust
let t = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
let total = t.clone().sum(); // scalar: 21.0
let col = t.clone().sum_dim(0); // sum along rows → shape [1, 3]
let row = t.clone().sum_dim(1); // sum along cols → shape [2, 1]
let mean_v = t.clone().mean(); // scalar mean
let max_v = t.clone().max(); // scalar max
let argmax = t.clone().argmax(1); // indices of max along dim 1 → Int tensor
```
**Note:** `sum_dim`, `mean_dim`, `max_dim` etc. keep the reduced dimension (like PyTorch's
`keepdim=True`). This is consistent everywhere in Burn.
### Reshaping and dimension manipulation
```rust
let t = Tensor::<B, 1>::from_floats([1., 2., 3., 4., 5., 6.], &device);
let r = t.clone().reshape([2, 3]); // 1-D → 2-D (2×3)
let f: Tensor<B, 1> = r.clone().flatten(0, 1); // 2-D → 1-D (flattens dims 0..=1 into one; target rank must be annotated)
let u: Tensor<B, 3> = r.clone().unsqueeze_dim(0); // [2,3] → [1,2,3]
let s: Tensor<B, 2> = u.squeeze_dim(0); // [1,2,3] → [2,3]
let p = r.clone().swap_dims(0, 1); // transpose: [2,3] → [3,2]
let t2 = r.transpose(); // same as swap_dims(0,1) for 2-D
```
### Concatenation and stacking
```rust
let a = Tensor::<B, 2>::from_data([[1.0, 2.0]], &device); // [1, 2]
let b = Tensor::<B, 2>::from_data([[3.0, 4.0]], &device); // [1, 2]
// cat: join along an EXISTING dimension
let catted = Tensor::cat(vec![a.clone(), b.clone()], 0); // [2, 2]
// stack: join along a NEW dimension
let stacked = Tensor::stack(vec![a, b], 0); // [2, 1, 2]
```
### Slicing
```rust
let t = Tensor::<B, 2>::from_data(
    [[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
    &device,
);
let row0 = t.clone().slice([0..1, 0..3]); // [[1, 2, 3]]
let col01 = t.clone().slice([0..3, 0..2]); // [[1,2],[4,5],[7,8]]
let single = t.clone().slice([1..2, 1..2]); // [[5]]
```
### Comparisons → Bool tensors
```rust
let t = Tensor::<B, 1>::from_floats([1.0, 5.0, 3.0], &device);
let gt3: Tensor<B, 1, Bool> = t.clone().greater_elem(3.0); // [false, true, false]
let eq5: Tensor<B, 1, Bool> = t.clone().equal_elem(5.0); // [false, true, false]
```
### Exercise 2
Write a function `fn min_max_normalize(t: Tensor<B, 1>) -> Tensor<B, 1>` that:
1. Computes min and max of the input (remember to clone!).
2. Returns `(t - min) / (max - min)`.
3. Test it with `[2.0, 4.0, 6.0, 8.0, 10.0]` — expected output: `[0.0, 0.25, 0.5, 0.75, 1.0]`.
Then write a second function that takes a 2-D tensor and returns the softmax along dimension 1
using only basic ops: `exp`, `sum_dim`, and division. Compare a few values mentally.
---
## Step 3 — Autodiff: Computing Gradients
### How Burn does autodiff
Unlike PyTorch where autograd is baked into every tensor, Burn uses a **backend decorator** pattern.
You wrap your base backend with `Autodiff<...>` and that transparently adds gradient tracking:
```
NdArray → forward-only, no gradients
Autodiff<NdArray> → same ops, but now you can call .backward()
```
This is a zero-cost abstraction: if you don't need gradients (inference), you don't pay for them.
The `AutodiffBackend` trait extends `Backend` with gradient-related methods. When you write
training code, you constrain your generic with `B: AutodiffBackend`.
### A first gradient computation
```rust
use burn::backend::ndarray::NdArray;
use burn::backend::Autodiff;
use burn::tensor::{Tensor, backend::AutodiffBackend};
type TrainBackend = Autodiff<NdArray>;
fn main() {
    let device = Default::default();
    // Create a tensor that will be tracked for gradients
    let x = Tensor::<TrainBackend, 1>::from_floats([1.0, 2.0, 3.0], &device)
        .require_grad();
    // Forward: compute y = x^2, then sum to get a scalar loss
    let y = x.clone().powf_scalar(2.0); // [1, 4, 9]
    let loss = y.sum(); // 14.0
    // Backward: compute gradients
    let grads = loss.backward();
    // Extract gradient of `loss` w.r.t. `x`
    // dy/dx = 2x → expected: [2.0, 4.0, 6.0]
    let x_grad = x.grad(&grads).expect("x should have a gradient");
    println!("x = {}", x.to_data());
    println!("dx = {}", x_grad.to_data());
}
```
### The flow in detail
1. **`require_grad()`** — marks a leaf tensor as "track me". Any tensor derived from it inherits
tracking automatically.
2. **Forward pass** — you build a computation graph by chaining normal tensor ops.
3. **`.backward()`** — called on a single-element tensor (typically the result of `sum()` or
`mean()`; Burn has no true 0-D tensors). Returns a `Gradients` object containing all partial
derivatives.
4. **`.grad(&grads)`** — retrieves the gradient for a specific tensor. Returns `Option<Tensor>`.
### `model.valid()` — disabling gradients
When you have a model on an `Autodiff<...>` backend and want to run inference without tracking:
```rust
// During training: model is on Autodiff<NdArray>
let model_valid = model.valid(); // returns the same model on plain NdArray
// Now forward passes do not build a graph → faster, less memory
```
This replaces PyTorch's `with torch.no_grad():` and `model.eval()` (for the grad-tracking part).
### Exercise 3
1. Create a tensor `w = [0.5, -0.3, 0.8]` with `require_grad()` on `Autodiff<NdArray>`.
2. Create an input `x = [1.0, 2.0, 3.0]` (no grad needed — think of it as data).
3. Compute a "prediction" `y_hat = (w * x).sum()` (dot product).
4. Define a "target" `y = 2.5` and compute MSE loss: `loss = (y_hat - y)^2`.
5. Call `.backward()` and print the gradient of `w`.
6. Manually verify the gradient: `d(loss)/dw_i = 2 * (y_hat - y) * x_i`.
---
## Step 4 — The Config Pattern
### Why configs exist
In PyTorch, you create a layer like `nn.Linear(784, 128, bias=True)` — the constructor directly
builds the parameters. Burn separates **configuration** from **initialization** into two steps:
```
LinearConfig::new(784, 128) // 1. describe WHAT you want
.with_bias(true) // (builder-style tweaks)
.init(&device) // 2. actually allocate weights on a device
```
This separation matters because:
- Configs are serializable (save/load hyperparameters to JSON).
- Configs are device-agnostic — you can build on CPU today, GPU tomorrow.
- You can inspect or modify a config before committing to allocating memory.
### The `#[derive(Config)]` macro
Burn's `Config` derive macro generates a builder pattern for you:
```rust
use burn::config::Config;
#[derive(Config)]
pub struct MyTrainingConfig {
    #[config(default = 64)]
    pub batch_size: usize,
    #[config(default = 1e-3)]
    pub learning_rate: f64,
    #[config(default = 10)]
    pub num_epochs: usize,
    pub model: MyModelConfig, // nested config — no default, must be provided
}
```
This generates:
- `MyTrainingConfig::new(model: MyModelConfig)` — constructor requiring non-default fields.
- `.with_batch_size(128)` — optional builder method for each field with a default.
- Serialization support (JSON via serde).
### Saving and loading configs
```rust
// Save
let config = MyTrainingConfig::new(MyModelConfig::new(10, 512));
config.save("config.json").expect("failed to save config");
// Load
let loaded = MyTrainingConfig::load("config.json").expect("failed to load config");
```
### Exercise 4
Define a `Config` struct called `ExperimentConfig` with:
- `hidden_size: usize` (no default — required)
- `dropout: f64` (default 0.5)
- `activation: String` (default `"relu"`)
- `lr: f64` (default 1e-4)
Create an instance, modify `dropout` to 0.3 via the builder, save it to JSON, load it back, and
assert the values match.
---
## Step 5 — Defining Modules (Your First Neural Network)
### The `#[derive(Module)]` macro
A Burn "module" is any struct that holds learnable parameters (or sub-modules that do). You annotate
it with `#[derive(Module)]` and make it generic over `Backend`:
```rust
use burn::prelude::*;
use burn::nn::{Linear, LinearConfig, Relu};
#[derive(Module, Debug)]
pub struct Mlp<B: Backend> {
    linear1: Linear<B>,
    linear2: Linear<B>,
    activation: Relu,
}
```
The derive macro automatically implements the `Module` trait, which gives you:
- Parameter iteration (for optimizers).
- Device transfer (`.to_device()`).
- Serialization (save/load weights).
- `.valid()` to strip autodiff.
- `.clone()` (cheap, reference-counted).
### The Config → init pattern for your own module
Every module you write should have a companion config:
```rust
#[derive(Config)]
pub struct MlpConfig {
    input_dim: usize,
    hidden_dim: usize,
    output_dim: usize,
}

impl MlpConfig {
    pub fn init<B: Backend>(&self, device: &B::Device) -> Mlp<B> {
        Mlp {
            linear1: LinearConfig::new(self.input_dim, self.hidden_dim)
                .init(device),
            linear2: LinearConfig::new(self.hidden_dim, self.output_dim)
                .init(device),
            activation: Relu::new(),
        }
    }
}
```
### The forward pass
By convention, you implement a `forward` method. This is NOT a trait method — it's just a
convention. You can name it anything, or have multiple forward methods for different purposes.
```rust
impl<B: Backend> Mlp<B> {
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.linear1.forward(x); // [batch, hidden]
        let x = self.activation.forward(x);
        self.linear2.forward(x) // [batch, output]
    }
}
```
### Putting it together
```rust
use burn::backend::NdArray;
fn main() {
    type B = NdArray;
    let device = Default::default();
    let config = MlpConfig::new(4, 32, 2);
    let model = config.init::<B>(&device);
    // Random input: batch of 8 samples, 4 features each
    let input = Tensor::<B, 2>::random(
        [8, 4],
        burn::tensor::Distribution::Normal(0.0, 1.0),
        &device,
    );
    let output = model.forward(input);
    println!("Output shape: {:?}", output.dims()); // [8, 2]
    println!("{:.4}", output);
}
```
### Exercise 5
Build a 3-layer MLP from scratch:
- Input: 10 features
- Hidden 1: 64 neurons + ReLU
- Hidden 2: 32 neurons + ReLU
- Output: 3 neurons (raw logits for a 3-class problem)
Create the config, init the model, feed a random `[16, 10]` batch through it, and verify the output
shape is `[16, 3]`.
---
## Step 6 — Built-in Layers Tour
### The naming convention
Every built-in layer in `burn::nn` follows the same pattern:
```
ThingConfig::new(required_args) // create config
.with_optional_param(value) // builder tweaks
.init(&device) // → Thing<B>
```
The layer struct has a `.forward(input)` method.
### Linear (fully connected)
```rust
use burn::nn::{Linear, LinearConfig};
let layer = LinearConfig::new(784, 128)
    .with_bias(true) // default is true
    .init::<B>(&device);
// input: [batch, 784] → output: [batch, 128]
let out = layer.forward(input);
```
### Conv2d
```rust
use burn::nn::conv::{Conv2d, Conv2dConfig};
use burn::nn::PaddingConfig2d;
let conv = Conv2dConfig::new([1, 32], [3, 3]) // [in_channels, out_channels], [kH, kW]
    .with_padding(PaddingConfig2d::Same)
    .with_stride([1, 1])
    .init::<B>(&device);
// input: [batch, channels, H, W] → output: [batch, 32, H, W] (with Same padding)
let out = conv.forward(input_4d);
```
### BatchNorm
```rust
use burn::nn::{BatchNorm, BatchNormConfig};
let bn: BatchNorm<B, 2> = BatchNormConfig::new(32) // num_features (channels)
    .init(&device); // the rank annotation picks 2 spatial dims → rank-4 input
// input: [batch, 32, H, W] → same shape, normalized
let out = bn.forward(input);
```
### Dropout
```rust
use burn::nn::{Dropout, DropoutConfig};
let dropout = DropoutConfig::new(0.5).init();
// During training, randomly zeros elements.
// Note: Dropout is stateless, no Backend generic needed.
let out = dropout.forward(input);
```
Burn's dropout is automatically disabled during inference when you use `model.valid()`.
### Pooling
```rust
use burn::nn::pool::{
    MaxPool2d, MaxPool2dConfig,
    AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig,
};

let maxpool = MaxPool2dConfig::new([2, 2])
    .with_stride([2, 2])
    .init();
let adaptive = AdaptiveAvgPool2dConfig::new([1, 1]).init();

let out = maxpool.forward(input); // halves spatial dims
let out = adaptive.forward(input); // spatial → 1×1
```
### Embedding
```rust
use burn::nn::{Embedding, EmbeddingConfig};
let embed = EmbeddingConfig::new(10000, 256) // vocab_size, embed_dim
    .init::<B>(&device);
// input: Int tensor [batch, seq_len] → output: [batch, seq_len, 256]
let out = embed.forward(token_ids);
```
### Activations (stateless functions)
Most activations are free functions in `burn::tensor::activation`, not module structs:
```rust
use burn::tensor::activation;
let x = Tensor::<B, 2>::random([4, 4], Distribution::Normal(0.0, 1.0), &device);
let _ = activation::relu(x.clone());
let _ = activation::sigmoid(x.clone());
let _ = activation::gelu(x.clone());
let _ = activation::softmax(x.clone(), 1); // softmax along dim 1
let _ = activation::log_softmax(x.clone(), 1);
```
There are also module wrappers like `Relu`, `Gelu`, etc. that just call these functions.
Use whichever style you prefer.
### Exercise 6
Build a small CNN module for 1-channel 28×28 images (think MNIST):
- Conv2d: 1 → 16 channels, 3×3 kernel, same padding → ReLU → MaxPool 2×2
- Conv2d: 16 → 32 channels, 3×3 kernel, same padding → ReLU → AdaptiveAvgPool 1×1
- Flatten → Linear(32, 10)
Create the config, init the model, feed a random `[4, 1, 28, 28]` tensor, verify output is
`[4, 10]`.
---
## Step 7 — Loss Functions
### Built-in losses
Burn provides losses in `burn::nn::loss`:
```rust
use burn::nn::loss::{CrossEntropyLoss, CrossEntropyLossConfig, MseLoss};
```
**Cross-entropy** (for classification):
```rust
let loss_fn = CrossEntropyLossConfig::new()
    .init(&device);
// logits: [batch, num_classes] — raw model output, NOT softmax'd
// targets: [batch] — Int tensor of class indices
let loss: Tensor<B, 1> = loss_fn.forward(logits, targets);
```
Note: `CrossEntropyLoss` expects **logits** (raw scores) by default, not probabilities. It applies
log-softmax internally, just like PyTorch's `nn.CrossEntropyLoss`.
**MSE loss** (for regression):
```rust
let mse = MseLoss::new();
// Both tensors must have the same shape.
let loss = mse.forward(predictions, targets, burn::nn::loss::Reduction::Mean);
```
### Writing a custom loss
A loss is just a function that takes tensors and returns a scalar tensor. There's nothing special
about it:
```rust
/// Binary cross-entropy from logits (custom implementation).
fn bce_with_logits<B: Backend>(logits: Tensor<B, 2>, targets: Tensor<B, 2>) -> Tensor<B, 1> {
    // sigmoid + log trick for numerical stability:
    // loss = max(logits, 0) - logits * targets + log(1 + exp(-|logits|))
    let zeros = Tensor::zeros_like(&logits);
    let max_val = logits.clone().max_pair(zeros);
    let abs_logits = logits.clone().abs();
    let loss = max_val - logits * targets
        + ((-abs_logits).exp() + 1.0).log();
    loss.mean()
}
```
### Exercise 7
1. Create a random logits tensor `[8, 5]` (batch=8, 5 classes) and a random targets Int tensor
`[8]` with values in 0..5.
2. Compute cross-entropy loss using `CrossEntropyLossConfig`.
3. Write your own `manual_cross_entropy` function that:
- Computes log-softmax: `log_softmax(x, dim=1)` via `activation::log_softmax`.
- Gathers the log-prob at the target index for each sample (use a loop or
`Tensor::select` + some reshaping).
- Returns the negative mean.
4. Compare the two values — they should be very close (use `check_closeness` or just eyeball the
printed scalars).
---
## Step 8 — Datasets & Batching
### The `Dataset` trait
Burn's dataset system lives in `burn::data::dataset`:
```rust
pub trait Dataset<I>: Send + Sync {
    fn get(&self, index: usize) -> Option<I>;
    fn len(&self) -> usize;
    fn is_empty(&self) -> bool { self.len() == 0 }
}
```
Simple — it's a random-access collection.
### InMemDataset
The simplest dataset — a Vec in memory:
```rust
use burn::data::dataset::InMemDataset;
#[derive(Clone, Debug)]
struct Sample {
    features: [f32; 4],
    label: u8,
}

let samples = vec![
    Sample { features: [1.0, 2.0, 3.0, 4.0], label: 0 },
    Sample { features: [5.0, 6.0, 7.0, 8.0], label: 1 },
    // ... more
];
let dataset = InMemDataset::new(samples);
println!("Dataset length: {}", dataset.len());
println!("Sample 0: {:?}", dataset.get(0));
```
### The `Batcher` trait
A batcher converts a `Vec<Item>` into a batch struct containing tensors:
```rust
use burn::data::dataloader::batcher::Batcher;
#[derive(Clone, Debug)]
struct MyBatch<B: Backend> {
inputs: Tensor<B, 2>, // [batch, features]
targets: Tensor<B, 1, Int>, // [batch]
}
#[derive(Clone)]
struct MyBatcher<B: Backend> {
device: B::Device,
}
impl<B: Backend> Batcher<Sample, MyBatch<B>> for MyBatcher<B> {
fn batch(&self, items: Vec<Sample>) -> MyBatch<B> {
let inputs: Vec<Tensor<B, 2>> = items.iter()
.map(|s| Tensor::<B, 1>::from_floats(s.features, &self.device))
.map(|t| t.unsqueeze_dim(0)) // [4] → [1, 4]
.collect();
let targets: Vec<Tensor<B, 1, Int>> = items.iter()
.map(|s| Tensor::<B, 1, Int>::from_data(
TensorData::from([s.label as i32]),
&self.device,
))
.collect();
MyBatch {
inputs: Tensor::cat(inputs, 0), // [batch, 4]
targets: Tensor::cat(targets, 0), // [batch]
}
}
}
```
### DataLoaderBuilder
```rust
use burn::data::dataloader::DataLoaderBuilder;
let batcher = MyBatcher::<B> { device: device.clone() };
let dataloader = DataLoaderBuilder::new(batcher)
    .batch_size(32)
    .shuffle(42) // seed for shuffling
    .num_workers(4) // parallel data loading threads
    .build(dataset);

for batch in dataloader.iter() {
    println!("Batch inputs: {:?}", batch.inputs.dims());
    println!("Batch targets: {:?}", batch.targets.dims());
}
```
### Exercise 8
Create a synthetic XOR-like dataset:
- 200 samples of `[x1, x2]` where x1, x2 are random in {0, 1}.
- Label is `x1 XOR x2` (0 or 1).
- Implement the `Sample` struct, populate an `InMemDataset`, write a `Batcher`, and build a
`DataLoader` with batch size 16.
- Loop over one epoch and print each batch's input shape and target shape.
---
## Step 9 — Optimizers
### How optimizers work in Burn
In PyTorch, you register parameters with the optimizer, call `loss.backward()`, then
`optimizer.step()`, then `optimizer.zero_grad()`. Burn is more functional:
```rust
let grads = loss.backward(); // 1. get raw gradients
let grads = GradientsParams::from_grads(grads, &model); // 2. map to parameter IDs
model = optim.step(lr, model, grads); // 3. optimizer consumes grads, returns updated model
```
No `zero_grad` — gradients are consumed in step 3. No mutation — `step` takes ownership of the
model and returns an updated one.
### Creating an optimizer
```rust
use burn::optim::{AdamConfig, SgdConfig};

// Adam (most common)
let optim = AdamConfig::new()
    .with_beta_1(0.9)
    .with_beta_2(0.999)
    .with_epsilon(1e-8)
    .with_weight_decay(None) // Option<WeightDecayConfig>; None is the default
    .init();

// SGD with momentum
let optim = SgdConfig::new()
    .with_momentum(Some(burn::optim::MomentumConfig {
        momentum: 0.9,
        dampening: 0.0,
        nesterov: false,
    }))
    .init();
```
The optimizer is generic — it works with any module on any `AutodiffBackend`.
### Gradient accumulation
If you need to accumulate gradients over multiple mini-batches before stepping:
```rust
use burn::optim::GradientsAccumulator;
let mut accumulator = GradientsAccumulator::new();

for batch in mini_batches {
    let loss = model.forward(batch); // sketch — assume forward produces the scalar loss here
    let grads = loss.backward();
    let grads = GradientsParams::from_grads(grads, &model);
    accumulator.accumulate(&model, grads);
}

let accumulated = accumulator.grads();
model = optim.step(lr, model, accumulated);
```
### Exercise 9
Using the XOR dataset from Exercise 8 and the 3-layer MLP from Exercise 5 (adapt dimensions to
input=2, hidden=16, output=2):
1. Create an Adam optimizer.
2. Run one batch through the model.
3. Compute cross-entropy loss.
4. Call `.backward()`, then `GradientsParams::from_grads`, then `optim.step`.
5. Print the loss before and after the step (run the forward pass again) to confirm it changed.
Don't write the full loop yet — just verify the single-step mechanics compile and run.
---
## Step 10 — The Full Training Loop (Manual)
### Wiring everything together
This is the capstone: a complete, working training loop with no `Learner` abstraction. We'll train
an MLP on the XOR problem.
```rust
use burn::backend::ndarray::NdArray;
use burn::backend::Autodiff;
use burn::data::dataloader::DataLoaderBuilder;
use burn::data::dataset::InMemDataset;
use burn::nn::loss::CrossEntropyLossConfig;
use burn::optim::{AdamConfig, GradientsParams};
use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
// ---- Backend aliases ----
type TrainB = Autodiff<NdArray>;
type InferB = NdArray;
// ---- Model (reuse your MLP, adapted) ----
use burn::nn::{Linear, LinearConfig, Relu};
#[derive(Module, Debug)]
pub struct XorModel<B: Backend> {
    fc1: Linear<B>,
    fc2: Linear<B>,
    fc3: Linear<B>,
    relu: Relu,
}

#[derive(Config)]
pub struct XorModelConfig {
    #[config(default = 2)]
    input: usize,
    #[config(default = 16)]
    hidden: usize,
    #[config(default = 2)]
    output: usize,
}

impl XorModelConfig {
    pub fn init<B: Backend>(&self, device: &B::Device) -> XorModel<B> {
        XorModel {
            fc1: LinearConfig::new(self.input, self.hidden).init(device),
            fc2: LinearConfig::new(self.hidden, self.hidden).init(device),
            fc3: LinearConfig::new(self.hidden, self.output).init(device),
            relu: Relu::new(),
        }
    }
}

impl<B: Backend> XorModel<B> {
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.relu.forward(self.fc1.forward(x));
        let x = self.relu.forward(self.fc2.forward(x));
        self.fc3.forward(x)
    }
}
// ---- Data ----
#[derive(Clone, Debug)]
struct XorSample {
    x: [f32; 2],
    y: i32,
}

#[derive(Clone)]
struct XorBatcher<B: Backend> {
    device: B::Device,
}

impl<B: Backend> burn::data::dataloader::batcher::Batcher<XorSample, (Tensor<B, 2>, Tensor<B, 1, Int>)>
    for XorBatcher<B>
{
    fn batch(&self, items: Vec<XorSample>) -> (Tensor<B, 2>, Tensor<B, 1, Int>) {
        let xs: Vec<Tensor<B, 2>> = items.iter()
            .map(|s| Tensor::<B, 1>::from_floats(s.x, &self.device).unsqueeze_dim(0))
            .collect();
        let ys: Vec<Tensor<B, 1, Int>> = items.iter()
            .map(|s| Tensor::<B, 1, Int>::from_data(TensorData::from([s.y]), &self.device))
            .collect();
        (Tensor::cat(xs, 0), Tensor::cat(ys, 0))
    }
}

fn make_xor_dataset() -> InMemDataset<XorSample> {
    let mut samples = Vec::new();
    // Repeat patterns many times so the network has enough data
    for _ in 0..200 {
        samples.push(XorSample { x: [0.0, 0.0], y: 0 });
        samples.push(XorSample { x: [0.0, 1.0], y: 1 });
        samples.push(XorSample { x: [1.0, 0.0], y: 1 });
        samples.push(XorSample { x: [1.0, 1.0], y: 0 });
    }
    InMemDataset::new(samples)
}
// ---- Training loop ----
fn train() {
    let device = Default::default();

    // Config
    let lr = 1e-3;
    let num_epochs = 50;
    let batch_size = 32;

    // Model + optimizer
    let model_config = XorModelConfig::new();
    let mut model: XorModel<TrainB> = model_config.init(&device);
    let mut optim = AdamConfig::new().init();

    // Data
    let dataset = make_xor_dataset();
    let batcher = XorBatcher::<TrainB> { device: device.clone() };
    let dataloader = DataLoaderBuilder::new(batcher)
        .batch_size(batch_size)
        .shuffle(42)
        .build(dataset);

    // Loss function
    let loss_fn = CrossEntropyLossConfig::new().init(&device);

    // Train!
    for epoch in 1..=num_epochs {
        let mut epoch_loss = 0.0;
        let mut n_batches = 0;
        for (inputs, targets) in dataloader.iter() {
            // Forward
            let logits = model.forward(inputs);
            let loss = loss_fn.forward(logits.clone(), targets.clone());
            // Backward + optimize
            let grads = loss.backward();
            let grads = GradientsParams::from_grads(grads, &model);
            model = optim.step(lr, model, grads);
            epoch_loss += loss.into_scalar().elem::<f32>();
            n_batches += 1;
        }
        if epoch % 10 == 0 || epoch == 1 {
            println!(
                "Epoch {:>3} | Loss: {:.4}",
                epoch,
                epoch_loss / n_batches as f32,
            );
        }
    }

    // ---- Quick inference check ----
    let model_valid = model.valid(); // drop autodiff → runs on NdArray
    let test_input = Tensor::<InferB, 2>::from_data(
        [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]],
        &device,
    );
    let predictions = model_valid.forward(test_input);
    let classes = predictions.argmax(1);
    println!("\nPredictions (should be 0, 1, 1, 0):");
    println!("{}", classes.to_data());
}

fn main() {
    train();
}
```
### Key points to internalize
- **`model` is reassigned** on every `optim.step(...)` — Burn is functional: the old model is
consumed, a new updated model is returned.
- **`loss.backward()`** returns raw gradients. You must wrap them with `GradientsParams::from_grads`
to associate them with parameter IDs.
- **No `zero_grad`** — gradients are consumed by `step`.
- **`model.valid()`** strips the `Autodiff` wrapper. The returned model lives on the inner backend
(`NdArray`). Use it for validation/inference.
- **`loss.into_scalar()`** gives you a backend-specific element. Use `.elem::<f32>()` to get a
plain Rust `f32` for printing.
### Exercise 10
Take the training loop above and extend it:
1. Add a **validation step** after each epoch: call `model.valid()`, run the 4 XOR inputs through
it, compute accuracy (percentage of correct predictions).
2. Add **early stopping**: if validation accuracy hits 100%, break out of the epoch loop.
3. Experiment: reduce the hidden size to 4 — does it still learn? Increase to 64 — does it learn
faster?
---
## Step 11 — Saving & Loading Models
### Records and Recorders
Burn saves models as "records" using "recorders". A record is the serializable state of a module
(all its parameter tensors). A recorder is a strategy for how to store that record (binary,
MessagePack, etc.).
### Saving
```rust
use burn::record::{CompactRecorder, NamedMpkFileRecorder, FullPrecisionSettings};

let recorder = CompactRecorder::new();

// Save — the recorder appends its matching file extension automatically
model
    .save_file("my_model", &recorder)
    .expect("Failed to save model");
```
`CompactRecorder` uses a compact binary format. `NamedMpkFileRecorder` uses named MessagePack,
which is more portable and human-debuggable.
### Loading
```rust
// First, create a fresh model with random weights
let mut model = model_config.init::<B>(&device);

// Then load saved weights into it
let recorder = CompactRecorder::new();
model = model
    .load_file("my_model", &recorder, &device)
    .expect("Failed to load model");
```
Note: the model architecture must match exactly — same layers, same sizes.
### A complete save/load cycle
```rust
fn save_model<B: Backend>(model: &XorModel<B>) {
    let recorder = CompactRecorder::new();
    model
        .clone() // save_file consumes the model, so clone (cheap) first
        .save_file("xor_trained", &recorder)
        .expect("save failed");
    println!("Model saved.");
}

fn load_model<B: Backend>(device: &B::Device) -> XorModel<B> {
    let model = XorModelConfig::new().init::<B>(device);
    let recorder = CompactRecorder::new();
    model
        .load_file("xor_trained", &recorder, device)
        .expect("load failed")
}
// In your training function, after training:
save_model(&model_valid);
// Later, for inference:
let loaded: XorModel<InferB> = load_model(&device);
let output = loaded.forward(test_input);
```
### Checkpoint during training
A common pattern is to save every N epochs:
```rust
for epoch in 1..=num_epochs {
    // ... training ...
    if epoch % 10 == 0 {
        let snapshot = model.valid(); // strip autodiff for saving
        save_model(&snapshot);
        println!("Checkpoint saved at epoch {}", epoch);
    }
}
```
### Exercise 11
1. Train the XOR model from Step 10 until it converges.
2. Save it with `CompactRecorder`.
3. In a **separate function** (simulating a new program run), load the model on `NdArray` and run
inference on all 4 XOR inputs.
4. Verify the predictions match what you got at the end of training.
**Bonus:** also try `NamedMpkFileRecorder::<FullPrecisionSettings>` and compare the file sizes.
---
## Quick Reference Card
| Concept | Burn | PyTorch equivalent |
|---|---|---|
| Backend selection | `type B = NdArray;` | device selection |
| Tensor creation | `Tensor::<B,2>::from_data(...)` | `torch.tensor(...)` |
| Autodiff | `Autodiff<NdArray>` backend wrapper | built-in autograd |
| Disable grad | `model.valid()` | `torch.no_grad()` |
| Module definition | `#[derive(Module)]` | `nn.Module` subclass |
| Layer construction | `LinearConfig::new(a,b).init(&dev)` | `nn.Linear(a,b)` |
| Hyperparameters | `#[derive(Config)]` | manual / dataclass |
| Loss | `CrossEntropyLossConfig::new().init(&dev)` | `nn.CrossEntropyLoss()` |
| Optimizer | `AdamConfig::new().init()` | `optim.Adam(params, lr)` |
| Backward | `loss.backward()` | `loss.backward()` |
| Param gradients | `GradientsParams::from_grads(grads, &model)` | automatic |
| Optimizer step | `model = optim.step(lr, model, grads)` | `optimizer.step()` |
| Zero grad | automatic (grads consumed) | `optimizer.zero_grad()` |
| Save model | `model.save_file("path", &recorder)` | `torch.save(state_dict)` |
| Load model | `model.load_file("path", &recorder, &dev)` | `model.load_state_dict(...)` |
---
## Where to go next
You now have the vocabulary and mechanics to read and write Burn code for deep learning practicals.
The topics we deliberately skipped (for a future chapter) include:
- **The `Learner` abstraction** — Burn's built-in training loop with metrics, checkpointing, and a
TUI dashboard.
- **GPU backends** — `Wgpu`, `Cuda`, `LibTorch` for actual hardware acceleration.
- **ONNX import** — loading models from PyTorch/TensorFlow.
- **Custom GPU kernels** — writing CubeCL compute shaders.
- **Quantization** — post-training quantization for deployment.
- **`no_std` deployment** — running Burn on bare-metal embedded (where NdArray also works).
For now, go do your deep learning exercises in Burn. You have the tools.