# Chapter 1: Burn Foundations — Tensors, Modules & the Training Loop

> **Burn version:** 0.20.x · **Backend:** NdArray (CPU) only · **Assumed knowledge:** proficient Rust, ML fundamentals (scikit-learn level), studying deep learning theory.
>
> **Cargo.toml you should already have:**
>
> ```toml
> [dependencies]
> burn = { version = "0.20", features = ["ndarray"] }
> ```

---

## Step 1 — Tensor Basics: Creation, Shape & Kinds

### What you need to know

Everything in Burn starts with `Tensor`. Unlike ndarray or nalgebra, Burn's tensor is **generic over three things**:

```
Tensor<B, D, K>
       │  │  └─ Kind: Float (default), Int, Bool
       │  └──── const dimensionality (rank) — a usize, NOT the shape
       └─────── Backend (NdArray, Wgpu, …)
```

The *shape* (e.g. 3×4) is a runtime value. The *rank* (e.g. 2) is a compile-time constant. This means the compiler catches dimension-mismatch bugs for you — you cannot accidentally add a 2-D tensor to a 3-D tensor.

The `NdArray` backend is Burn's pure-Rust CPU backend built on top of the `ndarray` crate. It needs no GPU drivers, compiles everywhere (including `no_std` embedded), and is the simplest backend to start with.

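The rank-in-the-type idea can be reproduced in miniature with plain Rust const generics — no Burn required. The `Toy` type below is my own illustrative sketch, not Burn's implementation, but it shows exactly why a rank mismatch is a compile error while a shape mismatch is a runtime error:

```rust
// A toy tensor type that, like Burn's, carries its rank D as a const generic.
// The shape is runtime data; the rank is part of the type.
#[derive(Debug)]
struct Toy<const D: usize> {
    shape: [usize; D],
    data: Vec<f32>,
}

impl<const D: usize> Toy<D> {
    fn zeros(shape: [usize; D]) -> Self {
        let len = shape.iter().product();
        Toy { shape, data: vec![0.0; len] }
    }

    // Addition is only defined between tensors of the SAME rank D —
    // mixing a Toy<2> with a Toy<3> does not type-check.
    fn add(&self, other: &Toy<D>) -> Toy<D> {
        assert_eq!(self.shape, other.shape, "shapes must match at runtime");
        let data = self.data.iter().zip(&other.data).map(|(a, b)| a + b).collect();
        Toy { shape: self.shape, data }
    }
}

fn main() {
    let a = Toy::<2>::zeros([3, 4]);
    let b = Toy::<2>::zeros([3, 4]);
    let c = a.add(&b);
    println!("rank-2 sum has shape {:?}", c.shape);
    // let bad = a.add(&Toy::<3>::zeros([1, 3, 4])); // ← would not compile
}
```
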
### Creating tensors

```rust
use burn::backend::NdArray;
use burn::tensor::{Tensor, TensorData, Int, Bool};

// Type alias — used everywhere in this chapter.
type B = NdArray;

fn main() {
    let device = Default::default(); // NdArray only has a Cpu device

    // ---- Float tensors (default kind) ----

    // From a nested array literal — the easiest way on a concrete backend:
    let a = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);

    // from_floats — recommended shorthand for f32:
    let b = Tensor::<B, 1>::from_floats([10.0, 20.0, 30.0], &device);

    // Factories:
    let zeros = Tensor::<B, 2>::zeros([3, 4], &device);  // 3×4 of 0.0
    let ones = Tensor::<B, 2>::ones([2, 2], &device);    // 2×2 of 1.0
    let full = Tensor::<B, 1>::full([5], 3.14, &device); // [3.14; 5]

    // Same shape as another tensor:
    let like_a = Tensor::<B, 2>::ones_like(&a);      // 2×2 of 1.0
    let like_z = Tensor::<B, 2>::zeros_like(&zeros); // 3×4 of 0.0

    // Random (uniform or normal):
    use burn::tensor::Distribution;
    let uniform = Tensor::<B, 2>::random([4, 4], Distribution::Uniform(0.0, 1.0), &device);
    let normal = Tensor::<B, 2>::random([4, 4], Distribution::Normal(0.0, 1.0), &device);

    // ---- Int tensors ----
    let labels = Tensor::<B, 1, Int>::from_data(TensorData::from([0i32, 1, 2, 1]), &device);
    let range = Tensor::<B, 1, Int>::arange(0..5, &device); // [0, 1, 2, 3, 4]

    // ---- Bool tensors ----
    let mask: Tensor<B, 2, Bool> = a.clone().greater_elem(2.0); // true where > 2

    println!("{}", a);
    println!("{}", labels);
    println!("{}", mask);
}
```

### Extracting data back out

```rust
// .to_data() clones — the tensor is still usable afterward
let data: TensorData = a.to_data();

// .into_data() moves — the tensor is consumed
let data: TensorData = a.into_data();

// Get a single scalar
let scalar: f32 = Tensor::<B, 1>::from_floats([42.0], &device)
    .into_scalar();
```

### Printing

```rust
println!("{}", tensor);    // full view with shape, backend, dtype
println!("{:.2}", tensor); // limit to 2 decimal places
```

### Exercise 1

Write a program that:

1. Creates a 3×3 identity-like float tensor manually with `from_data` (1s on the diagonal, 0s elsewhere).
2. Creates a 1-D Int tensor `[10, 20, 30]`.
3. Creates a Bool tensor by checking which elements of the float tensor are greater than 0.
4. Prints all three tensors.

**Success criterion:** compiles, prints the expected values, and you understand every generic parameter.

---

## Step 2 — Ownership, Cloning & Tensor Operations

### Ownership: the one Burn rule you must internalize

Almost every Burn tensor operation **takes ownership** (moves) of the input tensor. This is different from PyTorch, where tensors are reference-counted behind the scenes.

```rust
let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0], &device);
let y = x.exp(); // x is MOVED into exp(); you cannot use x anymore
// println!("{}", x); // ← compile error: value used after move
```

The fix is `.clone()`. Cloning a Burn tensor is **cheap** — it only bumps a reference count; it does not copy the underlying buffer.

```rust
let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0], &device);
let y = x.clone().exp(); // x is still alive
let z = x.clone().log(); // x is still alive
let w = x + y;           // x is consumed here — that's fine, it's the last use
```

**Rule of thumb:** clone whenever you need the tensor again later; don't clone on its last use. Burn will automatically do in-place operations when it detects a single owner.
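The "clone only bumps a reference count" claim is the same mechanism as `std::rc::Rc` in plain Rust. A minimal sketch of the idea (this is an analogy, not Burn's actual internal type):

```rust
use std::rc::Rc;

fn main() {
    // Cloning behaves like cloning an Rc: the handle is copied, the buffer
    // is shared, and only a reference count changes.
    let buffer: Rc<Vec<f32>> = Rc::new(vec![0.0; 1_000_000]);
    let handle = buffer.clone(); // no 4 MB copy — just a refcount bump

    assert_eq!(Rc::strong_count(&buffer), 2);
    // Both handles point at the exact same allocation:
    assert!(Rc::ptr_eq(&buffer, &handle));
    println!("refcount = {}", Rc::strong_count(&buffer));
}
```

This is also why "single owner → in-place op" works: when the refcount is 1, nobody else can observe the buffer, so it is safe to reuse it.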

### Arithmetic

All standard ops are supported and work element-wise with broadcasting:

```rust
let a = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
let b = Tensor::<B, 2>::from_data([[10.0, 20.0], [30.0, 40.0]], &device);

let sum = a.clone() + b.clone();  // element-wise add
let diff = a.clone() - b.clone(); // element-wise sub
let prod = a.clone() * b.clone(); // element-wise mul
let quot = a.clone() / b.clone(); // element-wise div

// Scalar operations
let scaled = a.clone() * 2.0;  // multiply every element by 2
let shifted = a.clone() + 1.0; // add 1 to every element
let neg = -a.clone();          // negate
```

### Float-specific math

```rust
let x = Tensor::<B, 1>::from_floats([0.0, 1.0, 2.0], &device);

let _ = x.clone().exp();            // e^x
let _ = x.clone().log();            // ln(x)
let _ = x.clone().sqrt();           // √x
let _ = x.clone().sin();            // sin(x)
let _ = x.clone().tanh();           // tanh(x)
let _ = x.clone().powf_scalar(2.0); // x²
let _ = x.clone().recip();          // 1/x
let _ = x.clone().abs();
let _ = x.clone().ceil();
let _ = x.clone().floor();
```

### Matrix multiplication

```rust
let a = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
let b = Tensor::<B, 2>::from_data([[5.0, 6.0], [7.0, 8.0]], &device);

let c = a.matmul(b); // standard matrix multiply — [2,2] × [2,2] → [2,2]
println!("{}", c);
```

`matmul` works on any rank ≥ 2 with batch-broadcasting semantics (just like PyTorch's `@`).
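It's worth checking the 2×2 example above by hand once. A naive plain-Rust matmul (my own throwaway helper, not a Burn API) makes the row-by-column rule explicit:

```rust
// Naive matrix multiply, to verify the 2×2 example by hand:
// [[1,2],[3,4]] × [[5,6],[7,8]] = [[19,22],[43,50]].
fn matmul(a: &[Vec<f32>], b: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let (n, k, m) = (a.len(), b.len(), b[0].len());
    let mut out = vec![vec![0.0; m]; n];
    for i in 0..n {
        for j in 0..m {
            for p in 0..k {
                out[i][j] += a[i][p] * b[p][j]; // dot of row i and column j
            }
        }
    }
    out
}

fn main() {
    let a = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let b = vec![vec![5.0, 6.0], vec![7.0, 8.0]];
    println!("{:?}", matmul(&a, &b)); // [[19.0, 22.0], [43.0, 50.0]]
}
```
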

### Reductions

```rust
let t = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);

let total = t.clone().sum();      // scalar: 21.0
let col = t.clone().sum_dim(0);   // sum over dim 0 (column sums) → shape [1, 3]
let row = t.clone().sum_dim(1);   // sum over dim 1 (row sums) → shape [2, 1]
let mean_v = t.clone().mean();    // scalar mean
let max_v = t.clone().max();      // scalar max
let argmax = t.clone().argmax(1); // indices of max along dim 1 → Int tensor
```

**Note:** `sum_dim`, `mean_dim`, `max_dim` etc. keep the reduced dimension (like PyTorch's `keepdim=True`). This is consistent everywhere in Burn.

### Reshaping and dimension manipulation

```rust
let t = Tensor::<B, 1>::from_floats([1., 2., 3., 4., 5., 6.], &device);

let r = t.clone().reshape([2, 3]);                // 1-D → 2-D (2×3)

// The output rank is a const generic, so annotate it when it can't be inferred:
let f: Tensor<B, 1> = r.clone().flatten(0, 1);    // 2-D → 1-D (flatten dims 0..=1)
let u: Tensor<B, 3> = r.clone().unsqueeze_dim(0); // [2,3] → [1,2,3]
let s: Tensor<B, 2> = u.squeeze_dim(0);           // [1,2,3] → [2,3]

let p = r.clone().swap_dims(0, 1);                // transpose: [2,3] → [3,2]
let t2 = r.transpose();                           // same as swap_dims(0,1) for 2-D
```

### Concatenation and stacking

```rust
let a = Tensor::<B, 2>::from_data([[1.0, 2.0]], &device); // [1, 2]
let b = Tensor::<B, 2>::from_data([[3.0, 4.0]], &device); // [1, 2]

// cat: join along an EXISTING dimension
let catted = Tensor::cat(vec![a.clone(), b.clone()], 0); // [2, 2]

// stack: join along a NEW dimension (output rank grows, so annotate it)
let stacked: Tensor<B, 3> = Tensor::stack(vec![a, b], 0); // [2, 1, 2]
```

### Slicing

```rust
let t = Tensor::<B, 2>::from_data(
    [[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
    &device,
);

let row0 = t.clone().slice([0..1, 0..3]);   // [[1, 2, 3]]
let col01 = t.clone().slice([0..3, 0..2]);  // [[1, 2], [4, 5], [7, 8]]
let single = t.clone().slice([1..2, 1..2]); // [[5]]
```

### Comparisons → Bool tensors

```rust
let t = Tensor::<B, 1>::from_floats([1.0, 5.0, 3.0], &device);

let gt3: Tensor<B, 1, Bool> = t.clone().greater_elem(3.0); // [false, true, false]
let eq5: Tensor<B, 1, Bool> = t.clone().equal_elem(5.0);   // [false, true, false]
```

### Exercise 2

Write a function `fn min_max_normalize(t: Tensor<B, 1>) -> Tensor<B, 1>` that:

1. Computes the min and max of the input (remember to clone!).
2. Returns `(t - min) / (max - min)`.
3. Test it with `[2.0, 4.0, 6.0, 8.0, 10.0]` — expected output: `[0.0, 0.25, 0.5, 0.75, 1.0]`.

Then write a second function that takes a 2-D tensor and returns the softmax along dimension 1 using only basic ops: `exp`, `sum_dim`, and division. Compare a few values mentally.
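To check your Burn softmax against something, here is a plain-Rust reference for one row (my own helper; note it subtracts the row max first for numerical stability — a standard trick beyond the bare ops the exercise asks for, which cancels algebraically but prevents `exp` overflow):

```rust
// Reference row softmax in plain Rust, for checking the exercise's output.
fn softmax_row(row: &[f32]) -> Vec<f32> {
    // Subtract the max before exponentiating: softmax(x) == softmax(x - c).
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let probs = softmax_row(&[1.0, 2.0, 3.0]);
    println!("{:?}", probs);
    // Probabilities sum to 1:
    assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-6);
}
```
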

---

## Step 3 — Autodiff: Computing Gradients

### How Burn does autodiff

Unlike PyTorch, where autograd is baked into every tensor, Burn uses a **backend decorator** pattern. You wrap your base backend with `Autodiff<...>` and that transparently adds gradient tracking:

```
NdArray           → forward-only, no gradients
Autodiff<NdArray> → same ops, but now you can call .backward()
```

This is a zero-cost abstraction: if you don't need gradients (inference), you don't pay for them.

The `AutodiffBackend` trait extends `Backend` with gradient-related methods. When you write training code, you constrain your generic with `B: AutodiffBackend`.

### A first gradient computation

```rust
use burn::backend::ndarray::NdArray;
use burn::backend::Autodiff;
use burn::tensor::Tensor;

type TrainBackend = Autodiff<NdArray>;

fn main() {
    let device = Default::default();

    // Create a tensor that will be tracked for gradients
    let x = Tensor::<TrainBackend, 1>::from_floats([1.0, 2.0, 3.0], &device)
        .require_grad();

    // Forward: compute y = x^2, then sum to get a scalar loss
    let y = x.clone().powf_scalar(2.0); // [1, 4, 9]
    let loss = y.sum();                 // 14.0

    // Backward: compute gradients
    let grads = loss.backward();

    // Extract the gradient of `loss` w.r.t. `x`
    // dy/dx = 2x → expected: [2.0, 4.0, 6.0]
    let x_grad = x.grad(&grads).expect("x should have a gradient");
    println!("x = {}", x);
    println!("dx = {}", x_grad);
}
```

### The flow in detail

1. **`require_grad()`** — marks a leaf tensor as "track me". Any tensor derived from it inherits tracking automatically.
2. **Forward pass** — you build a computation graph by chaining normal tensor ops.
3. **`.backward()`** — called on a **scalar** tensor (0-D or summed to 0-D). Returns a `Gradients` object containing all partial derivatives.
4. **`.grad(&grads)`** — retrieves the gradient for a specific tensor. Returns `Option<Tensor>`.
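Whenever you doubt an analytic gradient, you can sanity-check it numerically with central differences. A plain-Rust check (my own helper functions) for the example above, where loss(x) = Σ xᵢ² and dloss/dxᵢ = 2·xᵢ:

```rust
// Numerical sanity check: loss(x) = Σ xᵢ², analytic gradient dloss/dxᵢ = 2·xᵢ.
fn loss(x: &[f64]) -> f64 {
    x.iter().map(|v| v * v).sum()
}

// Central difference: (loss(x + ε·eᵢ) - loss(x - ε·eᵢ)) / 2ε.
fn numerical_grad(x: &[f64], i: usize, eps: f64) -> f64 {
    let mut plus = x.to_vec();
    let mut minus = x.to_vec();
    plus[i] += eps;
    minus[i] -= eps;
    (loss(&plus) - loss(&minus)) / (2.0 * eps)
}

fn main() {
    let x = [1.0, 2.0, 3.0];
    for i in 0..x.len() {
        let analytic = 2.0 * x[i];
        let numeric = numerical_grad(&x, i, 1e-5);
        println!("dx[{i}]: analytic {analytic:.4}, numeric {numeric:.4}");
        assert!((analytic - numeric).abs() < 1e-6);
    }
}
```

The same trick works against any `.backward()` result: perturb one input element, re-run the forward pass, and compare.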

### `model.valid()` — disabling gradients

When you have a model on an `Autodiff<...>` backend and want to run inference without tracking:

```rust
// During training: model is on Autodiff<NdArray>
let model_valid = model.valid(); // returns the same model on plain NdArray
// Now forward passes do not build a graph → faster, less memory
```

This replaces PyTorch's `with torch.no_grad():` and `model.eval()` (for the grad-tracking part).
### Exercise 3

1. Create a tensor `w = [0.5, -0.3, 0.8]` with `require_grad()` on `Autodiff<NdArray>`.
2. Create an input `x = [1.0, 2.0, 3.0]` (no grad needed — think of it as data).
3. Compute a "prediction" `y_hat = (w * x).sum()` (dot product).
4. Define a "target" `y = 2.5` and compute MSE loss: `loss = (y_hat - y)^2`.
5. Call `.backward()` and print the gradient of `w`.
6. Manually verify the gradient: `d(loss)/dw_i = 2 * (y_hat - y) * x_i`.

---

## Step 4 — The Config Pattern

### Why configs exist

In PyTorch, you create a layer like `nn.Linear(784, 128, bias=True)` — the constructor directly builds the parameters. Burn separates **configuration** from **initialization** into two steps:

```
LinearConfig::new(784, 128) // 1. describe WHAT you want
    .with_bias(true)        // (builder-style tweaks)
    .init(&device)          // 2. actually allocate weights on a device
```

This separation matters because:

- Configs are serializable (save/load hyperparameters to JSON).
- Configs are device-agnostic — you can build on CPU today, GPU tomorrow.
- You can inspect or modify a config before committing to allocating memory.

### The `#[derive(Config)]` macro

Burn's `Config` derive macro generates a builder pattern for you:

```rust
use burn::config::Config;

#[derive(Config)]
pub struct MyTrainingConfig {
    #[config(default = 64)]
    pub batch_size: usize,

    #[config(default = 1e-3)]
    pub learning_rate: f64,

    #[config(default = 10)]
    pub num_epochs: usize,

    pub model: MyModelConfig, // nested config — no default, must be provided
}
```

This generates:

- `MyTrainingConfig::new(model: MyModelConfig)` — a constructor requiring the non-default fields.
- `.with_batch_size(128)` — an optional builder method for each field with a default.
- Serialization support (JSON via serde).
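Roughly, you can picture the expansion like the hand-written sketch below. This is my own illustration of the *shape* of the generated code, not the actual macro output — serde support is omitted and the required `model` field is dropped so the example stays self-contained:

```rust
// Hand-written sketch of what #[derive(Config)] generates (illustrative only).
#[derive(Debug, Clone, PartialEq)]
pub struct MyTrainingConfig {
    pub batch_size: usize,
    pub learning_rate: f64,
    pub num_epochs: usize,
}

impl MyTrainingConfig {
    // The constructor takes only fields WITHOUT a default (none, in this sketch);
    // defaulted fields start at their #[config(default = ...)] values.
    pub fn new() -> Self {
        Self { batch_size: 64, learning_rate: 1e-3, num_epochs: 10 }
    }

    // One chainable builder method per defaulted field:
    pub fn with_batch_size(mut self, v: usize) -> Self { self.batch_size = v; self }
    pub fn with_learning_rate(mut self, v: f64) -> Self { self.learning_rate = v; self }
}

fn main() {
    let cfg = MyTrainingConfig::new().with_batch_size(128);
    println!("{cfg:?}");
    assert_eq!(cfg.batch_size, 128);
    assert_eq!(cfg.num_epochs, 10); // untouched fields keep their defaults
}
```
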

### Saving and loading configs

```rust
// Save
let config = MyTrainingConfig::new(MyModelConfig::new(10, 512));
config.save("config.json").expect("failed to save config");

// Load
let loaded = MyTrainingConfig::load("config.json").expect("failed to load config");
```

### Exercise 4

Define a `Config` struct called `ExperimentConfig` with:

- `hidden_size: usize` (no default — required)
- `dropout: f64` (default 0.5)
- `activation: String` (default `"relu"`)
- `lr: f64` (default 1e-4)

Create an instance, modify `dropout` to 0.3 via the builder, save it to JSON, load it back, and assert the values match.

---

## Step 5 — Defining Modules (Your First Neural Network)

### The `#[derive(Module)]` macro

A Burn "module" is any struct that holds learnable parameters (or sub-modules that do). You annotate it with `#[derive(Module)]` and make it generic over `Backend`:

```rust
use burn::prelude::*;
use burn::nn::{Linear, LinearConfig, Relu};

#[derive(Module, Debug)]
pub struct Mlp<B: Backend> {
    linear1: Linear<B>,
    linear2: Linear<B>,
    activation: Relu,
}
```

The derive macro automatically implements the `Module` trait, which gives you:

- Parameter iteration (for optimizers).
- Device transfer (`.to_device()`).
- Serialization (save/load weights).
- `.valid()` to strip autodiff.
- `.clone()` (cheap, reference-counted).

### The Config → init pattern for your own module

Every module you write should have a companion config:

```rust
#[derive(Config)]
pub struct MlpConfig {
    input_dim: usize,
    hidden_dim: usize,
    output_dim: usize,
}

impl MlpConfig {
    pub fn init<B: Backend>(&self, device: &B::Device) -> Mlp<B> {
        Mlp {
            linear1: LinearConfig::new(self.input_dim, self.hidden_dim).init(device),
            linear2: LinearConfig::new(self.hidden_dim, self.output_dim).init(device),
            activation: Relu::new(),
        }
    }
}
```

### The forward pass

By convention, you implement a `forward` method. This is NOT a trait method — it's just a convention. You can name it anything, or have multiple forward methods for different purposes.

```rust
impl<B: Backend> Mlp<B> {
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.linear1.forward(x); // [batch, hidden]
        let x = self.activation.forward(x);
        self.linear2.forward(x)          // [batch, output]
    }
}
```

### Putting it together

```rust
use burn::backend::NdArray;

fn main() {
    type B = NdArray;
    let device = Default::default();

    let config = MlpConfig::new(4, 32, 2);
    let model = config.init::<B>(&device);

    // Random input: batch of 8 samples, 4 features each
    let input = Tensor::<B, 2>::random(
        [8, 4],
        burn::tensor::Distribution::Normal(0.0, 1.0),
        &device,
    );

    let output = model.forward(input);
    println!("Output shape: {:?}", output.dims()); // [8, 2]
    println!("{:.4}", output);
}
```

### Exercise 5

Build a 3-layer MLP from scratch:

- Input: 10 features
- Hidden 1: 64 neurons + ReLU
- Hidden 2: 32 neurons + ReLU
- Output: 3 neurons (raw logits for a 3-class problem)

Create the config, init the model, feed a random `[16, 10]` batch through it, and verify the output shape is `[16, 3]`.

---

## Step 6 — Built-in Layers Tour

### The naming convention

Every built-in layer in `burn::nn` follows the same pattern:

```
ThingConfig::new(required_args)  // create config
    .with_optional_param(value)  // builder tweaks
    .init(&device)               // → Thing<B>
```

The layer struct has a `.forward(input)` method.

### Linear (fully connected)

```rust
use burn::nn::{Linear, LinearConfig};

let layer = LinearConfig::new(784, 128)
    .with_bias(true) // default is true
    .init::<B>(&device);

// input: [batch, 784] → output: [batch, 128]
let out = layer.forward(input);
```
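Mathematically, a linear layer computes y = x·W + b per sample. A plain-Rust illustration of that arithmetic for a single sample (my own helper, with W stored as `[in][out]`, not Burn's internal layout):

```rust
// What a fully-connected layer computes for one sample: y = x·W + b,
// where W is [in, out] and b is [out].
fn linear_forward(x: &[f32], w: &[Vec<f32>], b: &[f32]) -> Vec<f32> {
    let out_dim = b.len();
    (0..out_dim)
        .map(|j| {
            // Weighted sum of the inputs over column j of W, plus the bias.
            let dot: f32 = x.iter().zip(w).map(|(xi, row)| xi * row[j]).sum();
            dot + b[j]
        })
        .collect()
}

fn main() {
    // 3 inputs → 2 outputs
    let x = [1.0, 2.0, 3.0];
    let w = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]]; // [3, 2]
    let b = [0.5, -0.5];
    println!("{:?}", linear_forward(&x, &w, &b)); // [4.5, 4.5]
}
```

For a `[batch, 784] → [batch, 128]` layer this is done for every row of the batch at once, as one matmul.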

### Conv2d

```rust
use burn::nn::conv::{Conv2d, Conv2dConfig};
use burn::nn::PaddingConfig2d;

let conv = Conv2dConfig::new([1, 32], [3, 3]) // [in_channels, out_channels], [kH, kW]
    .with_padding(PaddingConfig2d::Same)
    .with_stride([1, 1])
    .init::<B>(&device);

// input: [batch, channels, H, W] → output: [batch, 32, H, W] (with Same padding)
let out = conv.forward(input_4d);
```

### BatchNorm

```rust
use burn::nn::{BatchNorm, BatchNormConfig};

let bn = BatchNormConfig::new(32) // num_features (channels)
    .init::<B>(&device);

// input: [batch, 32, H, W] → same shape, normalized
let out = bn.forward(input);
```

### Dropout

```rust
use burn::nn::{Dropout, DropoutConfig};

let dropout = DropoutConfig::new(0.5).init();

// During training, randomly zeros elements.
// Note: Dropout is stateless — no Backend generic needed.
let out = dropout.forward(input);
```

Burn's dropout is automatically disabled during inference when you use `model.valid()`.

### Pooling

```rust
use burn::nn::pool::{
    MaxPool2d, MaxPool2dConfig,
    AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig,
};

let maxpool = MaxPool2dConfig::new([2, 2])
    .with_stride([2, 2])
    .init();

let adaptive = AdaptiveAvgPool2dConfig::new([1, 1]).init();

let out = maxpool.forward(input);  // halves spatial dims
let out = adaptive.forward(input); // spatial → 1×1
```

### Embedding

```rust
use burn::nn::{Embedding, EmbeddingConfig};

let embed = EmbeddingConfig::new(10000, 256) // vocab_size, embed_dim
    .init::<B>(&device);

// input: Int tensor [batch, seq_len] → output: [batch, seq_len, 256]
let out = embed.forward(token_ids);
```

### Activations (stateless functions)

Most activations are free functions in `burn::tensor::activation`, not module structs:

```rust
use burn::tensor::activation;

let x = Tensor::<B, 2>::random([4, 4], Distribution::Normal(0.0, 1.0), &device);

let _ = activation::relu(x.clone());
let _ = activation::sigmoid(x.clone());
let _ = activation::gelu(x.clone());
let _ = activation::softmax(x.clone(), 1); // softmax along dim 1
let _ = activation::log_softmax(x.clone(), 1);
```

There are also module wrappers like `Relu`, `Gelu`, etc. that just call these functions. Use whichever style you prefer.

### Exercise 6

Build a small CNN module for 1-channel 28×28 images (think MNIST):

- Conv2d: 1 → 16 channels, 3×3 kernel, same padding → ReLU → MaxPool 2×2
- Conv2d: 16 → 32 channels, 3×3 kernel, same padding → ReLU → AdaptiveAvgPool 1×1
- Flatten → Linear(32, 10)

Create the config, init the model, feed a random `[4, 1, 28, 28]` tensor through it, and verify the output is `[4, 10]`.

---

## Step 7 — Loss Functions

### Built-in losses

Burn provides losses in `burn::nn::loss`:

```rust
use burn::nn::loss::{CrossEntropyLoss, CrossEntropyLossConfig, MseLoss};
```

**Cross-entropy** (for classification):

```rust
let loss_fn = CrossEntropyLossConfig::new()
    .init(&device);

// logits: [batch, num_classes] — raw model output, NOT softmax'd
// targets: [batch] — Int tensor of class indices
let loss: Tensor<B, 1> = loss_fn.forward(logits, targets);
```

Note: `CrossEntropyLoss` expects **logits** (raw scores) by default, not probabilities. It applies log-softmax internally, just like PyTorch's `nn.CrossEntropyLoss`.

**MSE loss** (for regression):

```rust
let mse = MseLoss::new();
// Both tensors must have the same shape.
let loss = mse.forward(predictions, targets, burn::nn::loss::Reduction::Mean);
```

### Writing a custom loss

A loss is just a function that takes tensors and returns a scalar tensor. There's nothing special about it:

```rust
/// Binary cross-entropy from logits (custom implementation).
fn bce_with_logits<B: Backend>(logits: Tensor<B, 2>, targets: Tensor<B, 2>) -> Tensor<B, 1> {
    // sigmoid + log trick for numerical stability:
    // loss = max(logits, 0) - logits * targets + log(1 + exp(-|logits|))
    let zeros = Tensor::zeros_like(&logits);
    let max_val = logits.clone().max_pair(zeros);
    let abs_logits = logits.clone().abs();

    let loss = max_val - logits * targets
        + ((-abs_logits).exp() + 1.0).log();

    loss.mean()
}
```
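You can convince yourself that the stable formula really equals the textbook definition with scalar arithmetic in plain Rust (my own helper functions). Where the naive form is numerically safe, the two agree to machine precision; at extreme logits, only the stable form survives:

```rust
// Check the stable BCE-with-logits formula against the naive definition.
// Naive:  -(t·ln σ(x) + (1-t)·ln(1-σ(x)))
// Stable: max(x, 0) - x·t + ln(1 + e^(-|x|))
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

fn bce_naive(x: f64, t: f64) -> f64 {
    let p = sigmoid(x);
    -(t * p.ln() + (1.0 - t) * (1.0 - p).ln())
}

fn bce_stable(x: f64, t: f64) -> f64 {
    x.max(0.0) - x * t + (1.0 + (-x.abs()).exp()).ln()
}

fn main() {
    for &(x, t) in &[(0.5, 1.0), (-2.0, 0.0), (3.0, 0.0), (-1.5, 1.0)] {
        let (a, b) = (bce_naive(x, t), bce_stable(x, t));
        println!("x={x:+.1} t={t}: naive {a:.6}, stable {b:.6}");
        assert!((a - b).abs() < 1e-9); // identical where the naive form is safe
    }
    // Where the naive form saturates (σ(x) rounds to exactly 1, so ln(1-σ)
    // blows up), the stable one still returns a finite value:
    assert!(bce_stable(1000.0, 0.0).is_finite());
}
```
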

### Exercise 7

1. Create a random logits tensor `[8, 5]` (batch=8, 5 classes) and a random targets Int tensor `[8]` with values in 0..5.
2. Compute cross-entropy loss using `CrossEntropyLossConfig`.
3. Write your own `manual_cross_entropy` function that:
   - Computes log-softmax: `log_softmax(x, dim=1)` via `activation::log_softmax`.
   - Gathers the log-prob at the target index for each sample (use a loop or `Tensor::select` + some reshaping).
   - Returns the negative mean.
4. Compare the two values — they should be very close (use `check_closeness` or just eyeball the printed scalars).

---

## Step 8 — Datasets & Batching

### The `Dataset` trait

Burn's dataset system lives in `burn::data::dataset`:

```rust
pub trait Dataset<I>: Send + Sync {
    fn get(&self, index: usize) -> Option<I>;
    fn len(&self) -> usize;
    fn is_empty(&self) -> bool { self.len() == 0 }
}
```

Simple — it's a random-access collection.
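The trait is small enough to implement by hand. The sketch below uses a standalone copy of the trait from above plus a Vec-backed implementation (`VecDataset` is my own name) — essentially what `InMemDataset` does for you:

```rust
// Standalone copy of Burn's Dataset trait, plus a Vec-backed implementation.
pub trait Dataset<I>: Send + Sync {
    fn get(&self, index: usize) -> Option<I>;
    fn len(&self) -> usize;
    fn is_empty(&self) -> bool { self.len() == 0 }
}

struct VecDataset<I> {
    items: Vec<I>,
}

impl<I: Clone + Send + Sync> Dataset<I> for VecDataset<I> {
    fn get(&self, index: usize) -> Option<I> {
        self.items.get(index).cloned() // out-of-range → None, no panic
    }
    fn len(&self) -> usize {
        self.items.len()
    }
}

fn main() {
    let ds = VecDataset { items: vec![10, 20, 30] };
    println!("len = {}, first = {:?}", ds.len(), ds.get(0));
    assert_eq!(ds.get(0), Some(10));
    assert_eq!(ds.get(99), None); // random access, not iteration — misses are Ok
}
```

Note that `get` returns an owned item (hence the `Clone` bound here): the dataloader can pull items from worker threads without holding borrows into the dataset.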

### InMemDataset

The simplest dataset — a Vec in memory:

```rust
use burn::data::dataset::InMemDataset;

#[derive(Clone, Debug)]
struct Sample {
    features: [f32; 4],
    label: u8,
}

let samples = vec![
    Sample { features: [1.0, 2.0, 3.0, 4.0], label: 0 },
    Sample { features: [5.0, 6.0, 7.0, 8.0], label: 1 },
    // ... more
];

let dataset = InMemDataset::new(samples);
println!("Dataset length: {}", dataset.len());
println!("Sample 0: {:?}", dataset.get(0));
```

### The `Batcher` trait

A batcher converts a `Vec<Item>` into a batch struct containing tensors:

```rust
use burn::data::dataloader::batcher::Batcher;

#[derive(Clone, Debug)]
struct MyBatch<B: Backend> {
    inputs: Tensor<B, 2>,       // [batch, features]
    targets: Tensor<B, 1, Int>, // [batch]
}

#[derive(Clone)]
struct MyBatcher<B: Backend> {
    device: B::Device,
}

impl<B: Backend> Batcher<Sample, MyBatch<B>> for MyBatcher<B> {
    fn batch(&self, items: Vec<Sample>) -> MyBatch<B> {
        let inputs: Vec<Tensor<B, 2>> = items.iter()
            .map(|s| Tensor::<B, 1>::from_floats(s.features, &self.device))
            .map(|t| t.unsqueeze_dim(0)) // [4] → [1, 4]
            .collect();

        let targets: Vec<Tensor<B, 1, Int>> = items.iter()
            .map(|s| Tensor::<B, 1, Int>::from_data(
                TensorData::from([s.label as i32]),
                &self.device,
            ))
            .collect();

        MyBatch {
            inputs: Tensor::cat(inputs, 0),   // [batch, 4]
            targets: Tensor::cat(targets, 0), // [batch]
        }
    }
}
```

### DataLoaderBuilder

```rust
use burn::data::dataloader::DataLoaderBuilder;

let batcher = MyBatcher::<B> { device: device.clone() };
let dataloader = DataLoaderBuilder::new(batcher)
    .batch_size(32)
    .shuffle(42)    // seed for shuffling
    .num_workers(4) // parallel data-loading threads
    .build(dataset);

for batch in dataloader.iter() {
    println!("Batch inputs: {:?}", batch.inputs.dims());
    println!("Batch targets: {:?}", batch.targets.dims());
}
```

### Exercise 8

Create a synthetic XOR-like dataset:

- 200 samples of `[x1, x2]` where x1, x2 are random in {0, 1}.
- The label is `x1 XOR x2` (0 or 1).
- Implement the `Sample` struct, populate an `InMemDataset`, write a `Batcher`, and build a `DataLoader` with batch size 16.
- Loop over one epoch and print each batch's input shape and target shape.

---

## Step 9 — Optimizers

### How optimizers work in Burn

In PyTorch, you register parameters with the optimizer, call `loss.backward()`, then `optimizer.step()`, then `optimizer.zero_grad()`. Burn is more functional:

```rust
let grads = loss.backward();                            // 1. get raw gradients
let grads = GradientsParams::from_grads(grads, &model); // 2. map to parameter IDs
model = optim.step(lr, model, grads);                   // 3. optimizer consumes grads, returns updated model
```

No `zero_grad` — gradients are consumed in step 3. No mutation — `step` takes ownership of the model and returns an updated one.
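The functional update style can be shown in miniature with plain-Rust SGD on a flat parameter vector (my own sketch — Burn's real `step` additionally threads per-parameter optimizer state like Adam's moment estimates):

```rust
// `step` takes ownership of the parameters AND the gradients, and returns
// new parameters — no mutation, no zero_grad. SGD: p ← p - lr·g.
fn sgd_step(lr: f64, params: Vec<f64>, grads: Vec<f64>) -> Vec<f64> {
    params
        .into_iter()
        .zip(grads) // gradients are consumed here — nothing left to zero out
        .map(|(p, g)| p - lr * g)
        .collect()
}

fn main() {
    let params = vec![1.0, 2.0, 3.0];
    let grads = vec![0.5, -0.5, 1.0];
    // Rebind, just like `model = optim.step(lr, model, grads);`:
    let params = sgd_step(0.1, params, grads);
    println!("{params:?}"); // [0.95, 2.05, 2.9]
}
```

Because the old parameters are moved into `step`, there is no window where stale parameters and fresh ones coexist — the type system enforces what PyTorch enforces by convention.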

### Creating an optimizer

```rust
use burn::optim::{AdamConfig, SgdConfig};

// Adam (most common)
let optim = AdamConfig::new()
    .with_beta_1(0.9)
    .with_beta_2(0.999)
    .with_epsilon(1e-8)
    .init();

// SGD with momentum
let optim = SgdConfig::new()
    .with_momentum(Some(burn::optim::MomentumConfig {
        momentum: 0.9,
        dampening: 0.0,
        nesterov: false,
    }))
    .init();
```

The optimizer is generic — it works with any module on any `AutodiffBackend`.

### Gradient accumulation

If you need to accumulate gradients over multiple mini-batches before stepping:

```rust
use burn::optim::GradientsAccumulator;

let mut accumulator = GradientsAccumulator::new();

for batch in mini_batches {
    let loss = model.forward(batch);
    let grads = loss.backward();
    let grads = GradientsParams::from_grads(grads, &model);
    accumulator.accumulate(&model, grads);
}

let accumulated = accumulator.grads();
model = optim.step(lr, model, accumulated);
```

### Exercise 9

Using the XOR dataset from Exercise 8 and the 3-layer MLP from Exercise 5 (adapt the dimensions to input=2, hidden=16, output=2):

1. Create an Adam optimizer.
2. Run one batch through the model.
3. Compute cross-entropy loss.
4. Call `.backward()` → `GradientsParams::from_grads` → `optim.step`.
5. Print the loss before and after the step (run the forward pass again) to confirm it changed.

Don't write the full loop yet — just verify that the single-step mechanics compile and run.

---

## Step 10 — The Full Training Loop (Manual)
|
||
|
||
### Wiring everything together
|
||
|
||
This is the capstone: a complete, working training loop with no `Learner` abstraction. We'll train
|
||
an MLP on the XOR problem.
|
||
|
||
```rust
use burn::backend::Autodiff;
use burn::backend::NdArray;
use burn::data::dataloader::DataLoaderBuilder;
use burn::data::dataset::InMemDataset;
use burn::module::AutodiffModule; // brings model.valid() into scope
use burn::nn::loss::CrossEntropyLossConfig;
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::prelude::*;

// ---- Backend aliases ----
type TrainB = Autodiff<NdArray>;
type InferB = NdArray;

// ---- Model (reuse your MLP, adapted) ----
use burn::nn::{Linear, LinearConfig, Relu};

#[derive(Module, Debug)]
pub struct XorModel<B: Backend> {
    fc1: Linear<B>,
    fc2: Linear<B>,
    fc3: Linear<B>,
    relu: Relu,
}

#[derive(Config)]
pub struct XorModelConfig {
    #[config(default = 2)]
    input: usize,
    #[config(default = 16)]
    hidden: usize,
    #[config(default = 2)]
    output: usize,
}

impl XorModelConfig {
    pub fn init<B: Backend>(&self, device: &B::Device) -> XorModel<B> {
        XorModel {
            fc1: LinearConfig::new(self.input, self.hidden).init(device),
            fc2: LinearConfig::new(self.hidden, self.hidden).init(device),
            fc3: LinearConfig::new(self.hidden, self.output).init(device),
            relu: Relu::new(),
        }
    }
}

impl<B: Backend> XorModel<B> {
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.relu.forward(self.fc1.forward(x));
        let x = self.relu.forward(self.fc2.forward(x));
        self.fc3.forward(x)
    }
}

// ---- Data ----
#[derive(Clone, Debug)]
struct XorSample {
    x: [f32; 2],
    y: i32,
}

#[derive(Clone)]
struct XorBatcher<B: Backend> {
    device: B::Device,
}

impl<B: Backend> burn::data::dataloader::batcher::Batcher<XorSample, (Tensor<B, 2>, Tensor<B, 1, Int>)>
    for XorBatcher<B>
{
    fn batch(&self, items: Vec<XorSample>) -> (Tensor<B, 2>, Tensor<B, 1, Int>) {
        let xs: Vec<Tensor<B, 2>> = items
            .iter()
            .map(|s| Tensor::<B, 1>::from_floats(s.x, &self.device).unsqueeze_dim(0))
            .collect();
        let ys: Vec<Tensor<B, 1, Int>> = items
            .iter()
            .map(|s| Tensor::<B, 1, Int>::from_data(TensorData::from([s.y]), &self.device))
            .collect();
        (Tensor::cat(xs, 0), Tensor::cat(ys, 0))
    }
}

fn make_xor_dataset() -> InMemDataset<XorSample> {
    let mut samples = Vec::new();
    // Repeat the four patterns many times so each epoch sees enough data
    for _ in 0..200 {
        samples.push(XorSample { x: [0.0, 0.0], y: 0 });
        samples.push(XorSample { x: [0.0, 1.0], y: 1 });
        samples.push(XorSample { x: [1.0, 0.0], y: 1 });
        samples.push(XorSample { x: [1.0, 1.0], y: 0 });
    }
    InMemDataset::new(samples)
}

// ---- Training loop ----
fn train() {
    let device = Default::default();

    // Config
    let lr = 1e-3;
    let num_epochs = 50;
    let batch_size = 32;

    // Model + optimizer
    let model_config = XorModelConfig::new();
    let mut model: XorModel<TrainB> = model_config.init(&device);
    let mut optim = AdamConfig::new().init();

    // Data
    let dataset = make_xor_dataset();
    let batcher = XorBatcher::<TrainB> { device: device.clone() };
    let dataloader = DataLoaderBuilder::new(batcher)
        .batch_size(batch_size)
        .shuffle(42)
        .build(dataset);

    // Loss function
    let loss_fn = CrossEntropyLossConfig::new().init(&device);

    // Train!
    for epoch in 1..=num_epochs {
        let mut epoch_loss = 0.0;
        let mut n_batches = 0;

        for (inputs, targets) in dataloader.iter() {
            // Forward
            let logits = model.forward(inputs);
            let loss = loss_fn.forward(logits.clone(), targets.clone());

            // Backward + optimize
            let grads = loss.backward();
            let grads = GradientsParams::from_grads(grads, &model);
            model = optim.step(lr, model, grads);

            epoch_loss += loss.into_scalar().elem::<f32>();
            n_batches += 1;
        }

        if epoch % 10 == 0 || epoch == 1 {
            println!(
                "Epoch {:>3} | Loss: {:.4}",
                epoch,
                epoch_loss / n_batches as f32,
            );
        }
    }

    // ---- Quick inference check ----
    let model_valid = model.valid(); // drop autodiff → runs on NdArray

    let test_input = Tensor::<InferB, 2>::from_data(
        [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]],
        &device,
    );
    let predictions = model_valid.forward(test_input);
    let classes = predictions.argmax(1); // argmax keeps the reduced dim: shape [4, 1]
    println!("\nPredictions (should be 0, 1, 1, 0):");
    println!("{}", classes);
}

fn main() {
    train();
}
```

### Key points to internalize

- **`model` is reassigned** on every `optim.step(...)` — Burn is functional: the old model is
  consumed and a new, updated model is returned.
- **`loss.backward()`** returns raw gradients. Wrap them with `GradientsParams::from_grads` to
  associate them with the model's parameter IDs.
- **No `zero_grad`** — gradients are consumed by `step`.
- **`model.valid()`** strips the `Autodiff` wrapper. The returned model lives on the inner backend
  (`NdArray`). Use it for validation and inference.
- **`loss.into_scalar()`** gives you a backend-specific element. Use `.elem::<f32>()` to get a
  plain Rust `f32` for printing.

### Exercise 10

Take the training loop above and extend it:

1. Add a **validation step** after each epoch: call `model.valid()`, run the 4 XOR inputs through
   it, and compute the accuracy (percentage of correct predictions).
2. Add **early stopping**: if validation accuracy hits 100%, break out of the epoch loop.
3. Experiment: reduce the hidden size to 4 — does it still learn? Increase it to 64 — does it
   learn faster?
---

## Step 11 — Saving & Loading Models

### Records and Recorders

Burn saves models as "records" using "recorders". A record is the serializable state of a module
(all of its parameter tensors). A recorder is a strategy for how to store that record (binary,
MessagePack, etc.).

### Saving

```rust
use burn::record::CompactRecorder;

let recorder = CompactRecorder::new();

// Save — writes to "my_model" plus the recorder's own file extension.
// Note: save_file consumes the model; clone it first if you still need it.
model
    .save_file("my_model", &recorder)
    .expect("Failed to save model");
```

`CompactRecorder` uses a compact binary format. `NamedMpkFileRecorder` uses named MessagePack,
which is more portable and easier to inspect.

### Loading

```rust
// First, create a fresh model with randomly initialized weights
let model = model_config.init::<B>(&device);

// Then load the saved weights into it (load_file consumes and returns the model)
let recorder = CompactRecorder::new();
let model = model
    .load_file("my_model", &recorder, &device)
    .expect("Failed to load model");
```

Note: the model architecture must match exactly — same layers, same sizes.

### A complete save/load cycle

```rust
fn save_model<B: Backend>(model: XorModel<B>) {
    // save_file takes the model by value — pass a clone if you need to keep it
    let recorder = CompactRecorder::new();
    model
        .save_file("xor_trained", &recorder)
        .expect("save failed");
    println!("Model saved.");
}

fn load_model<B: Backend>(device: &B::Device) -> XorModel<B> {
    let model = XorModelConfig::new().init::<B>(device);
    let recorder = CompactRecorder::new();
    model
        .load_file("xor_trained", &recorder, device)
        .expect("load failed")
}

// In your training function, after training:
save_model(model_valid);

// Later, for inference:
let loaded: XorModel<InferB> = load_model(&device);
let output = loaded.forward(test_input);
```

### Checkpoint during training

A common pattern is to save a checkpoint every N epochs:

```rust
for epoch in 1..=num_epochs {
    // ... training ...

    if epoch % 10 == 0 {
        let snapshot = model.valid(); // strip autodiff for saving
        save_model(snapshot);
        println!("Checkpoint saved at epoch {}", epoch);
    }
}
```

### Exercise 11

1. Train the XOR model from Step 10 until it converges.
2. Save it with `CompactRecorder`.
3. In a **separate function** (simulating a new program run), load the model on `NdArray` and run
   inference on all 4 XOR inputs.
4. Verify the predictions match what you got at the end of training.

**Bonus:** also try `NamedMpkFileRecorder::<FullPrecisionSettings>` and compare the file sizes.

---

## Quick Reference Card

| Concept | Burn | PyTorch equivalent |
|---|---|---|
| Backend selection | `type B = NdArray;` | device selection |
| Tensor creation | `Tensor::<B, 2>::from_data(...)` | `torch.tensor(...)` |
| Autodiff | `Autodiff<NdArray>` backend wrapper | built-in autograd |
| Disable grad | `model.valid()` | `torch.no_grad()` |
| Module definition | `#[derive(Module)]` | `nn.Module` subclass |
| Layer construction | `LinearConfig::new(a, b).init(&dev)` | `nn.Linear(a, b)` |
| Hyperparameters | `#[derive(Config)]` | manual / dataclass |
| Loss | `CrossEntropyLossConfig::new().init(&dev)` | `nn.CrossEntropyLoss()` |
| Optimizer | `AdamConfig::new().init()` | `optim.Adam(params, lr)` |
| Backward | `loss.backward()` | `loss.backward()` |
| Param gradients | `GradientsParams::from_grads(grads, &model)` | automatic |
| Optimizer step | `model = optim.step(lr, model, grads)` | `optimizer.step()` |
| Zero grad | automatic (grads are consumed) | `optimizer.zero_grad()` |
| Save model | `model.save_file("path", &recorder)` | `torch.save(state_dict)` |
| Load model | `model.load_file("path", &recorder, &dev)` | `model.load_state_dict(...)` |

---

## Where to go next

You now have the vocabulary and mechanics to read and write Burn code for deep learning practicals.
The topics we deliberately skipped (for a future chapter) include:

- **The `Learner` abstraction** — Burn's built-in training loop with metrics, checkpointing, and a
  TUI dashboard.
- **GPU backends** — `Wgpu`, `Cuda`, `LibTorch` for actual hardware acceleration.
- **ONNX import** — loading models from PyTorch/TensorFlow.
- **Custom GPU kernels** — writing CubeCL compute shaders.
- **Quantization** — post-training quantization for deployment.
- **`no_std` deployment** — running Burn on bare-metal embedded targets (where NdArray also works).

For now, go do your deep learning exercises in Burn. You have the tools.