Deep-Learning/burn_tutorial/ch1.md
2026-03-14 19:11:21 +01:00

Chapter 1: Burn Foundations — Tensors, Modules & the Training Loop

Burn version: 0.20.x · Backend: NdArray (CPU) only · Assumed knowledge: proficient Rust, ML fundamentals (scikit-learn level), studying deep learning theory.

Cargo.toml you should already have:

[dependencies]
burn = { version = "0.20", features = ["ndarray"] }

Step 1 — Tensor Basics: Creation, Shape & Kinds

What you need to know

Everything in Burn starts with Tensor. Unlike ndarray or nalgebra, Burn's tensor is generic over three things:

Tensor<B, D, K>
       │  │  └─ Kind: Float (default), Int, Bool
       │  └──── const dimensionality (rank) — a usize, NOT the shape
       └─────── Backend (NdArray, Wgpu, …)

The shape (e.g. 3×4) is a runtime value. The rank (e.g. 2) is a compile-time constant. This means the compiler catches dimension-mismatch bugs for you — you cannot accidentally add a 2-D tensor to a 3-D tensor.

The NdArray backend is Burn's pure-Rust CPU backend built on top of the ndarray crate. It needs no GPU drivers, compiles everywhere (including no_std embedded), and is the simplest backend to start with.

Creating tensors

use burn::backend::NdArray;
use burn::tensor::{Tensor, TensorData, Int, Bool};

// Type alias — used everywhere in this chapter.
type B = NdArray;

fn main() {
    let device = Default::default(); // NdArray only has a Cpu device

    // ---- Float tensors (default kind) ----

    // From a nested array literal — easiest way for concrete backends:
    let a = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);

    // from_floats — recommended shorthand for f32:
    let b = Tensor::<B, 1>::from_floats([10.0, 20.0, 30.0], &device);

    // Factories:
    let zeros = Tensor::<B, 2>::zeros([3, 4], &device);       // 3×4 of 0.0
    let ones  = Tensor::<B, 2>::ones([2, 2], &device);         // 2×2 of 1.0
    let full  = Tensor::<B, 1>::full([5], 3.14, &device);      // [3.14; 5]

    // Same shape as another tensor:
    let like_a = Tensor::<B, 2>::ones_like(&a);                // 2×2 of 1.0
    let like_z = Tensor::<B, 2>::zeros_like(&zeros);           // 3×4 of 0.0

    // Random (uniform or normal):
    use burn::tensor::Distribution;
    let uniform = Tensor::<B, 2>::random([4, 4], Distribution::Uniform(0.0, 1.0), &device);
    let normal  = Tensor::<B, 2>::random([4, 4], Distribution::Normal(0.0, 1.0), &device);

    // ---- Int tensors ----
    let labels = Tensor::<B, 1, Int>::from_data(TensorData::from([0i32, 1, 2, 1]), &device);
    let range  = Tensor::<B, 1, Int>::arange(0..5, &device);   // [0, 1, 2, 3, 4]

    // ---- Bool tensors ----
    let mask = a.clone().greater_elem(2.0); // Tensor<B, 2, Bool> — true where > 2

    println!("{}", a);
    println!("{}", labels);
    println!("{}", mask);
}

Extracting data back out

// .to_data() clones — tensor is still usable afterward
let data: TensorData = a.to_data();

// .into_data() moves — tensor is consumed
let data: TensorData = a.into_data();

// Get a single scalar
let scalar: f32 = Tensor::<B, 1>::from_floats([42.0], &device)
    .into_scalar();

Printing

println!("{}", tensor);         // full debug view with shape, backend, dtype
println!("{:.2}", tensor);      // limit to 2 decimal places

Exercise 1

Write a program that:

  1. Creates a 3×3 identity-like float tensor manually with from_data (1s on diagonal, 0s elsewhere).
  2. Creates a 1-D Int tensor [10, 20, 30].
  3. Creates a Bool tensor by checking which elements of the float tensor are greater than 0.
  4. Prints all three tensors.

Success criterion: compiles, prints the expected values, you understand every generic parameter.


Step 2 — Ownership, Cloning & Tensor Operations

Ownership: the one Burn rule you must internalize

Almost every Burn tensor operation takes ownership of (moves) its input tensor. This is different from PyTorch, where Python variables are merely references to tensors and you never think about moves.

let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0], &device);
let y = x.exp();  // x is MOVED into exp(), you cannot use x anymore
// println!("{}", x);  // ← compile error: value used after move

The fix is .clone(). Cloning a Burn tensor is cheap — it only bumps a reference count, it does not copy the underlying buffer.

let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0], &device);
let y = x.clone().exp();   // x is still alive
let z = x.clone().log();   // x is still alive
let w = x + y;             // x is consumed here — that's fine, last use

Rule of thumb: clone whenever you need the tensor again later; don't clone on its last use. Burn will automatically do in-place operations when it detects a single owner.

Arithmetic

All standard ops are supported and work element-wise with broadcasting:

let a = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
let b = Tensor::<B, 2>::from_data([[10.0, 20.0], [30.0, 40.0]], &device);

let sum  = a.clone() + b.clone();         // element-wise add
let diff = a.clone() - b.clone();         // element-wise sub
let prod = a.clone() * b.clone();         // element-wise mul
let quot = a.clone() / b.clone();         // element-wise div

// Scalar operations
let scaled = a.clone() * 2.0;            // multiply every element by 2
let shifted = a.clone() + 1.0;           // add 1 to every element
let neg = -a.clone();                     // negate

Float-specific math

let x = Tensor::<B, 1>::from_floats([0.0, 1.0, 2.0], &device);

let _ = x.clone().exp();       // e^x
let _ = x.clone().log();       // ln(x)
let _ = x.clone().sqrt();      // √x
let _ = x.clone().sin();       // sin(x)
let _ = x.clone().tanh();      // tanh(x)
let _ = x.clone().powf_scalar(2.0);  // x²
let _ = x.clone().recip();     // 1/x
let _ = x.clone().abs();
let _ = x.clone().ceil();
let _ = x.clone().floor();

Matrix multiplication

let a = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
let b = Tensor::<B, 2>::from_data([[5.0, 6.0], [7.0, 8.0]], &device);

let c = a.matmul(b); // standard matrix multiply — [2,2] × [2,2] → [2,2]
println!("{}", c);

matmul works on any rank ≥ 2 with batch-broadcasting semantics (just like PyTorch's @).

Reductions

let t = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);

let total  = t.clone().sum();            // scalar: 21.0
let col    = t.clone().sum_dim(0);       // sum along rows → shape [1, 3]
let row    = t.clone().sum_dim(1);       // sum along cols → shape [2, 1]
let mean_v = t.clone().mean();           // scalar mean
let max_v  = t.clone().max();            // scalar max
let argmax = t.clone().argmax(1);        // indices of max along dim 1 → Int tensor

Note: sum_dim, mean_dim, max_dim etc. keep the reduced dimension (like PyTorch's keepdim=True). This is consistent everywhere in Burn.

Reshaping and dimension manipulation

let t = Tensor::<B, 1>::from_floats([1., 2., 3., 4., 5., 6.], &device);

let r = t.clone().reshape([2, 3]);       // 1-D → 2-D (2×3)
let f = r.clone().flatten(0, 1);         // 2-D → 1-D (flatten dims 0..1)

let u = r.clone().unsqueeze_dim(0);      // [2,3] → [1,2,3]
let s = u.squeeze_dim(0);                // [1,2,3] → [2,3]
let p = r.clone().swap_dims(0, 1);       // transpose: [2,3] → [3,2]
let t2 = r.transpose();                  // same as swap_dims(0,1) for 2-D

Concatenation and stacking

let a = Tensor::<B, 2>::from_data([[1.0, 2.0]], &device);   // [1, 2]
let b = Tensor::<B, 2>::from_data([[3.0, 4.0]], &device);   // [1, 2]

// cat: join along an EXISTING dimension
let catted = Tensor::cat(vec![a.clone(), b.clone()], 0);     // [2, 2]

// stack: join along a NEW dimension
let stacked = Tensor::stack(vec![a, b], 0);                  // [2, 1, 2]

Slicing

let t = Tensor::<B, 2>::from_data(
    [[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
    &device,
);

let row0   = t.clone().slice([0..1, 0..3]);      // [[1, 2, 3]]
let col01  = t.clone().slice([0..3, 0..2]);       // [[1,2],[4,5],[7,8]]
let single = t.clone().slice([1..2, 1..2]);       // [[5]]

Comparisons → Bool tensors

let t = Tensor::<B, 1>::from_floats([1.0, 5.0, 3.0], &device);

let gt3: Tensor<B, 1, Bool> = t.clone().greater_elem(3.0);   // [false, true, false]
let eq5: Tensor<B, 1, Bool> = t.clone().equal_elem(5.0);     // [false, true, false]

Exercise 2

Write a function fn min_max_normalize(t: Tensor<B, 1>) -> Tensor<B, 1> that:

  1. Computes min and max of the input (remember to clone!).
  2. Returns (t - min) / (max - min).
  3. Test it with [2.0, 4.0, 6.0, 8.0, 10.0] — expected output: [0.0, 0.25, 0.5, 0.75, 1.0].

Then write a second function that takes a 2-D tensor and returns the softmax along dimension 1 using only basic ops: exp, sum_dim, and division. Compare a few values mentally.


Step 3 — Autodiff: Computing Gradients

How Burn does autodiff

Unlike PyTorch where autograd is baked into every tensor, Burn uses a backend decorator pattern. You wrap your base backend with Autodiff<...> and that transparently adds gradient tracking:

NdArray                → forward-only, no gradients
Autodiff<NdArray>      → same ops, but now you can call .backward()

This is a zero-cost abstraction: if you don't need gradients (inference), you don't pay for them.

The AutodiffBackend trait extends Backend with gradient-related methods. When you write training code, you constrain your generic with B: AutodiffBackend.

A first gradient computation

use burn::backend::ndarray::NdArray;
use burn::backend::Autodiff;
use burn::tensor::{Tensor, backend::AutodiffBackend};

type TrainBackend = Autodiff<NdArray>;

fn main() {
    let device = Default::default();

    // Create a tensor that will be tracked for gradients
    let x = Tensor::<TrainBackend, 1>::from_floats([1.0, 2.0, 3.0], &device)
        .require_grad();

    // Forward: compute y = x^2, then sum to get a scalar loss
    let y = x.clone().powf_scalar(2.0);   // [1, 4, 9]
    let loss = y.sum();                     // 14.0

    // Backward: compute gradients
    let grads = loss.backward();

    // Extract gradient of `loss` w.r.t. `x`
    // dy/dx = 2x → expected: [2.0, 4.0, 6.0]
    let x_grad = x.grad(&grads).expect("x should have a gradient");
    println!("x     = {}", x.to_data());
    println!("dx    = {}", x_grad.to_data());
}

The flow in detail

  1. require_grad() — marks a leaf tensor as "track me". Any tensor derived from it inherits tracking automatically.
  2. Forward pass — you build a computation graph by chaining normal tensor ops.
  3. .backward() — called on a scalar tensor (0-D or summed to 0-D). Returns a Gradients object containing all partial derivatives.
  4. .grad(&grads) — retrieves the gradient for a specific tensor. Returns Option<Tensor>.

model.valid() — disabling gradients

When you have a model on an Autodiff<...> backend and want to run inference without tracking:

// During training: model is on Autodiff<NdArray>
let model_valid = model.valid(); // returns the same model on plain NdArray
// Now forward passes do not build a graph → faster, less memory

This plays the role of PyTorch's with torch.no_grad(): and of the grad-tracking part of model.eval().

Exercise 3

  1. Create a tensor w = [0.5, -0.3, 0.8] with require_grad() on Autodiff<NdArray>.
  2. Create an input x = [1.0, 2.0, 3.0] (no grad needed — think of it as data).
  3. Compute a "prediction" y_hat = (w * x).sum() (dot product).
  4. Define a "target" y = 2.5 and compute MSE loss: loss = (y_hat - y)^2.
  5. Call .backward() and print the gradient of w.
  6. Manually verify the gradient: d(loss)/dw_i = 2 * (y_hat - y) * x_i.

Step 4 — The Config Pattern

Why configs exist

In PyTorch, you create a layer like nn.Linear(784, 128, bias=True) — the constructor directly builds the parameters. Burn separates configuration from initialization into two steps:

LinearConfig::new(784, 128)      // 1. describe WHAT you want
    .with_bias(true)              //    (builder-style tweaks)
    .init(&device)                // 2. actually allocate weights on a device

This separation matters because:

  • Configs are serializable (save/load hyperparameters to JSON).
  • Configs are device-agnostic — you can build on CPU today, GPU tomorrow.
  • You can inspect or modify a config before committing to allocating memory.

The #[derive(Config)] macro

Burn's Config derive macro generates a builder pattern for you:

use burn::config::Config;

#[derive(Config)]
pub struct MyTrainingConfig {
    #[config(default = 64)]
    pub batch_size: usize,

    #[config(default = 1e-3)]
    pub learning_rate: f64,

    #[config(default = 10)]
    pub num_epochs: usize,

    pub model: MyModelConfig,    // nested config — no default, must be provided
}

This generates:

  • MyTrainingConfig::new(model: MyModelConfig) — constructor requiring non-default fields.
  • .with_batch_size(128) — optional builder method for each field with a default.
  • Serialization support (JSON via serde).
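Usage of the generated API might look like this (a sketch — MyModelConfig and its constructor arguments are assumptions carried over from the nested-config field above):

```rust
// Construct with required fields, then tweak defaulted ones via the builder:
let config = MyTrainingConfig::new(MyModelConfig::new(10, 512))
    .with_batch_size(128)       // override a defaulted field
    .with_learning_rate(3e-4);  // builder methods are chainable

assert_eq!(config.batch_size, 128);
assert_eq!(config.num_epochs, 10); // untouched fields keep their defaults
```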

Saving and loading configs

// Save
let config = MyTrainingConfig::new(MyModelConfig::new(10, 512));
config.save("config.json").expect("failed to save config");

// Load
let loaded = MyTrainingConfig::load("config.json").expect("failed to load config");

Exercise 4

Define a Config struct called ExperimentConfig with:

  • hidden_size: usize (no default — required)
  • dropout: f64 (default 0.5)
  • activation: String (default "relu")
  • lr: f64 (default 1e-4)

Create an instance, modify dropout to 0.3 via the builder, save it to JSON, load it back, and assert the values match.


Step 5 — Defining Modules (Your First Neural Network)

The #[derive(Module)] macro

A Burn "module" is any struct that holds learnable parameters (or sub-modules that do). You annotate it with #[derive(Module)] and make it generic over Backend:

use burn::prelude::*;
use burn::nn::{Linear, LinearConfig, Relu};

#[derive(Module, Debug)]
pub struct Mlp<B: Backend> {
    linear1: Linear<B>,
    linear2: Linear<B>,
    activation: Relu,
}

The derive macro automatically implements the Module trait, which gives you:

  • Parameter iteration (for optimizers).
  • Device transfer (.to_device()).
  • Serialization (save/load weights).
  • .valid() to strip autodiff.
  • .clone() (cheap, reference-counted).

The Config → init pattern for your own module

Every module you write should have a companion config:

#[derive(Config)]
pub struct MlpConfig {
    input_dim: usize,
    hidden_dim: usize,
    output_dim: usize,
}

impl MlpConfig {
    pub fn init<B: Backend>(&self, device: &B::Device) -> Mlp<B> {
        Mlp {
            linear1: LinearConfig::new(self.input_dim, self.hidden_dim)
                .init(device),
            linear2: LinearConfig::new(self.hidden_dim, self.output_dim)
                .init(device),
            activation: Relu::new(),
        }
    }
}

The forward pass

By convention, you implement a forward method. This is NOT a trait method — it's just a convention. You can name it anything, or have multiple forward methods for different purposes.

impl<B: Backend> Mlp<B> {
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.linear1.forward(x);  // [batch, hidden]
        let x = self.activation.forward(x);
        self.linear2.forward(x)            // [batch, output]
    }
}

Putting it together

use burn::backend::NdArray;

fn main() {
    type B = NdArray;
    let device = Default::default();

    let config = MlpConfig::new(4, 32, 2);
    let model = config.init::<B>(&device);

    // Random input: batch of 8 samples, 4 features each
    let input = Tensor::<B, 2>::random(
        [8, 4],
        burn::tensor::Distribution::Normal(0.0, 1.0),
        &device,
    );

    let output = model.forward(input);
    println!("Output shape: {:?}", output.dims()); // [8, 2]
    println!("{:.4}", output);
}

Exercise 5

Build a 3-layer MLP from scratch:

  • Input: 10 features
  • Hidden 1: 64 neurons + ReLU
  • Hidden 2: 32 neurons + ReLU
  • Output: 3 neurons (raw logits for a 3-class problem)

Create the config, init the model, feed a random [16, 10] batch through it, and verify the output shape is [16, 3].


Step 6 — Built-in Layers Tour

The naming convention

Every built-in layer in burn::nn follows the same pattern:

ThingConfig::new(required_args)     // create config
    .with_optional_param(value)     // builder tweaks
    .init(&device)                  // → Thing<B>

The layer struct has a .forward(input) method.

Linear (fully connected)

use burn::nn::{Linear, LinearConfig};

let layer = LinearConfig::new(784, 128)
    .with_bias(true)       // default is true
    .init::<B>(&device);

// input: [batch, 784] → output: [batch, 128]
let out = layer.forward(input);

Conv2d

use burn::nn::conv::{Conv2d, Conv2dConfig};
use burn::nn::PaddingConfig2d;

let conv = Conv2dConfig::new([1, 32], [3, 3])   // [in_channels, out_channels], [kH, kW]
    .with_padding(PaddingConfig2d::Same)
    .with_stride([1, 1])
    .init::<B>(&device);

// input: [batch, channels, H, W] → output: [batch, 32, H, W]  (with Same padding)
let out = conv.forward(input_4d);

BatchNorm

use burn::nn::{BatchNorm, BatchNormConfig};

let bn = BatchNormConfig::new(32)   // num_features (channels)
    .init::<B>(&device);

// input: [batch, 32, H, W] → same shape, normalized
let out = bn.forward(input);

Dropout

use burn::nn::{Dropout, DropoutConfig};

let dropout = DropoutConfig::new(0.5).init();

// During training, randomly zeros elements.
// Note: Dropout is stateless, no Backend generic needed.
let out = dropout.forward(input);

Burn's dropout is automatically disabled during inference when you use model.valid().

Pooling

use burn::nn::pool::{
    MaxPool2d, MaxPool2dConfig,
    AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig,
};

let maxpool = MaxPool2dConfig::new([2, 2])
    .with_stride([2, 2])
    .init();

let adaptive = AdaptiveAvgPool2dConfig::new([1, 1]).init();

let out = maxpool.forward(input);     // halves spatial dims
let out = adaptive.forward(input);    // spatial → 1×1

Embedding

use burn::nn::{Embedding, EmbeddingConfig};

let embed = EmbeddingConfig::new(10000, 256)  // vocab_size, embed_dim
    .init::<B>(&device);

// input: Int tensor [batch, seq_len] → output: [batch, seq_len, 256]
let out = embed.forward(token_ids);

Activations (stateless functions)

Most activations are free functions in burn::tensor::activation, not module structs:

use burn::tensor::activation;

let x = Tensor::<B, 2>::random([4, 4], Distribution::Normal(0.0, 1.0), &device);

let _ = activation::relu(x.clone());
let _ = activation::sigmoid(x.clone());
let _ = activation::gelu(x.clone());
let _ = activation::softmax(x.clone(), 1);     // softmax along dim 1
let _ = activation::log_softmax(x.clone(), 1);

There are also module wrappers like Relu, Gelu, etc. that just call these functions. Use whichever style you prefer.

Exercise 6

Build a small CNN module for 1-channel 28×28 images (think MNIST):

  • Conv2d: 1 → 16 channels, 3×3 kernel, same padding → ReLU → MaxPool 2×2
  • Conv2d: 16 → 32 channels, 3×3 kernel, same padding → ReLU → AdaptiveAvgPool 1×1
  • Flatten → Linear(32, 10)

Create the config, init the model, feed a random [4, 1, 28, 28] tensor, verify output is [4, 10].


Step 7 — Loss Functions

Built-in losses

Burn provides losses in burn::nn::loss:

use burn::nn::loss::{CrossEntropyLoss, CrossEntropyLossConfig, MseLoss};

Cross-entropy (for classification):

let loss_fn = CrossEntropyLossConfig::new()
    .init(&device);

// logits: [batch, num_classes]   — raw model output, NOT softmax'd
// targets: [batch]               — Int tensor of class indices
let loss: Tensor<B, 1> = loss_fn.forward(logits, targets);

Note: CrossEntropyLoss expects logits (raw scores) by default, not probabilities. It applies log-softmax internally, just like PyTorch's nn.CrossEntropyLoss.

MSE loss (for regression):

let mse = MseLoss::new();
// Both tensors must have the same shape.
let loss = mse.forward(predictions, targets, burn::nn::loss::Reduction::Mean);

Writing a custom loss

A loss is just a function that takes tensors and returns a scalar tensor. There's nothing special about it:

/// Binary cross-entropy from logits (custom implementation).
fn bce_with_logits<B: Backend>(logits: Tensor<B, 2>, targets: Tensor<B, 2>) -> Tensor<B, 1> {
    // sigmoid + log trick for numerical stability:
    // loss = max(logits, 0) - logits * targets + log(1 + exp(-|logits|))
    let zeros = Tensor::zeros_like(&logits);
    let max_val = logits.clone().max_pair(zeros);
    let abs_logits = logits.clone().abs();

    let loss = max_val - logits * targets
        + ((-abs_logits).exp() + 1.0).log();

    loss.mean()
}

Exercise 7

  1. Create a random logits tensor [8, 5] (batch=8, 5 classes) and a random targets Int tensor [8] with values in 0..5.
  2. Compute cross-entropy loss using CrossEntropyLossConfig.
  3. Write your own manual_cross_entropy function that:
    • Computes log-softmax: log_softmax(x, dim=1) via activation::log_softmax.
    • Gathers the log-prob at the target index for each sample (use a loop or Tensor::select + some reshaping).
    • Returns the negative mean.
  4. Compare the two values — they should be very close (use check_closeness or just eyeball the printed scalars).

Step 8 — Datasets & Batching

The Dataset trait

Burn's dataset system lives in burn::data::dataset:

pub trait Dataset<I>: Send + Sync {
    fn get(&self, index: usize) -> Option<I>;
    fn len(&self) -> usize;
    fn is_empty(&self) -> bool { self.len() == 0 }
}

Simple — it's a random-access collection.

InMemDataset

The simplest dataset — a Vec in memory:

use burn::data::dataset::InMemDataset;

#[derive(Clone, Debug)]
struct Sample {
    features: [f32; 4],
    label: u8,
}

let samples = vec![
    Sample { features: [1.0, 2.0, 3.0, 4.0], label: 0 },
    Sample { features: [5.0, 6.0, 7.0, 8.0], label: 1 },
    // ... more
];

let dataset = InMemDataset::new(samples);
println!("Dataset length: {}", dataset.len());
println!("Sample 0: {:?}", dataset.get(0));

The Batcher trait

A batcher converts a Vec<Item> into a batch struct containing tensors:

use burn::data::dataloader::batcher::Batcher;

#[derive(Clone, Debug)]
struct MyBatch<B: Backend> {
    inputs:  Tensor<B, 2>,       // [batch, features]
    targets: Tensor<B, 1, Int>,  // [batch]
}

#[derive(Clone)]
struct MyBatcher<B: Backend> {
    device: B::Device,
}

impl<B: Backend> Batcher<Sample, MyBatch<B>> for MyBatcher<B> {
    fn batch(&self, items: Vec<Sample>) -> MyBatch<B> {
        let inputs: Vec<Tensor<B, 2>> = items.iter()
            .map(|s| Tensor::<B, 1>::from_floats(s.features, &self.device))
            .map(|t| t.unsqueeze_dim(0))   // [4] → [1, 4]
            .collect();

        let targets: Vec<Tensor<B, 1, Int>> = items.iter()
            .map(|s| Tensor::<B, 1, Int>::from_data(
                TensorData::from([s.label as i32]),
                &self.device,
            ))
            .collect();

        MyBatch {
            inputs: Tensor::cat(inputs, 0),     // [batch, 4]
            targets: Tensor::cat(targets, 0),   // [batch]
        }
    }
}

DataLoaderBuilder

use burn::data::dataloader::DataLoaderBuilder;

let batcher = MyBatcher::<B> { device: device.clone() };
let dataloader = DataLoaderBuilder::new(batcher)
    .batch_size(32)
    .shuffle(42)       // seed for shuffling
    .num_workers(4)    // parallel data loading threads
    .build(dataset);

for batch in dataloader.iter() {
    println!("Batch inputs: {:?}", batch.inputs.dims());
    println!("Batch targets: {:?}", batch.targets.dims());
}

Exercise 8

Create a synthetic XOR-like dataset:

  • 200 samples of [x1, x2] where x1, x2 are random in {0, 1}.
  • Label is x1 XOR x2 (0 or 1).
  • Implement the Sample struct, populate an InMemDataset, write a Batcher, and build a DataLoader with batch size 16.
  • Loop over one epoch and print each batch's input shape and target shape.

Step 9 — Optimizers

How optimizers work in Burn

In PyTorch, you register parameters with the optimizer, call loss.backward(), then optimizer.step(), then optimizer.zero_grad(). Burn is more functional:

let grads = loss.backward();                              // 1. get raw gradients
let grads = GradientsParams::from_grads(grads, &model);   // 2. map to parameter IDs
model = optim.step(lr, model, grads);                      // 3. optimizer consumes grads, returns updated model

No zero_grad — gradients are consumed in step 3. No mutation — step takes ownership of the model and returns an updated one.

Creating an optimizer

use burn::optim::{AdamConfig, SgdConfig};

// Adam (most common)
let optim = AdamConfig::new()
    .with_beta_1(0.9)
    .with_beta_2(0.999)
    .with_epsilon(1e-8)
    .with_weight_decay_config(None)
    .init();

// SGD with momentum
let optim = SgdConfig::new()
    .with_momentum(Some(burn::optim::MomentumConfig {
        momentum: 0.9,
        dampening: 0.0,
        nesterov: false,
    }))
    .init();

The optimizer is generic — it works with any module on any AutodiffBackend.

Gradient accumulation

If you need to accumulate gradients over multiple mini-batches before stepping:

use burn::optim::GradientsAccumulator;

let mut accumulator = GradientsAccumulator::new();

for batch in mini_batches {
    let loss = model.forward(batch);
    let grads = loss.backward();
    let grads = GradientsParams::from_grads(grads, &model);
    accumulator.accumulate(&model, grads);
}

let accumulated = accumulator.grads();
model = optim.step(lr, model, accumulated);

Exercise 9

Using the XOR dataset from Exercise 8 and the 3-layer MLP from Exercise 5 (adapt dimensions to input=2, hidden=16, output=2):

  1. Create an Adam optimizer.
  2. Run one batch through the model.
  3. Compute cross-entropy loss.
  4. Call .backward(), then GradientsParams::from_grads, then optim.step.
  5. Print the loss before and after the step (run the forward pass again) to confirm it changed.

Don't write the full loop yet — just verify the single-step mechanics compile and run.


Step 10 — The Full Training Loop (Manual)

Wiring everything together

This is the capstone: a complete, working training loop with no Learner abstraction. We'll train an MLP on the XOR problem.

use burn::backend::ndarray::NdArray;
use burn::backend::Autodiff;
use burn::data::dataloader::DataLoaderBuilder;
use burn::data::dataset::InMemDataset;
use burn::nn::loss::{CrossEntropyLoss, CrossEntropyLossConfig};
use burn::optim::{AdamConfig, GradientsParams};
use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;

// ---- Backend aliases ----
type TrainB = Autodiff<NdArray>;
type InferB = NdArray;

// ---- Model (reuse your MLP, adapted) ----
use burn::nn::{Linear, LinearConfig, Relu};

#[derive(Module, Debug)]
pub struct XorModel<B: Backend> {
    fc1: Linear<B>,
    fc2: Linear<B>,
    fc3: Linear<B>,
    relu: Relu,
}

#[derive(Config)]
pub struct XorModelConfig {
    #[config(default = 2)]
    input: usize,
    #[config(default = 16)]
    hidden: usize,
    #[config(default = 2)]
    output: usize,
}

impl XorModelConfig {
    pub fn init<B: Backend>(&self, device: &B::Device) -> XorModel<B> {
        XorModel {
            fc1: LinearConfig::new(self.input, self.hidden).init(device),
            fc2: LinearConfig::new(self.hidden, self.hidden).init(device),
            fc3: LinearConfig::new(self.hidden, self.output).init(device),
            relu: Relu::new(),
        }
    }
}

impl<B: Backend> XorModel<B> {
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.relu.forward(self.fc1.forward(x));
        let x = self.relu.forward(self.fc2.forward(x));
        self.fc3.forward(x)
    }
}

// ---- Data ----
#[derive(Clone, Debug)]
struct XorSample {
    x: [f32; 2],
    y: i32,
}

#[derive(Clone)]
struct XorBatcher<B: Backend> {
    device: B::Device,
}

impl<B: Backend> burn::data::dataloader::batcher::Batcher<XorSample, (Tensor<B, 2>, Tensor<B, 1, Int>)>
    for XorBatcher<B>
{
    fn batch(&self, items: Vec<XorSample>) -> (Tensor<B, 2>, Tensor<B, 1, Int>) {
        let xs: Vec<Tensor<B, 2>> = items.iter()
            .map(|s| Tensor::<B, 1>::from_floats(s.x, &self.device).unsqueeze_dim(0))
            .collect();
        let ys: Vec<Tensor<B, 1, Int>> = items.iter()
            .map(|s| Tensor::<B, 1, Int>::from_data(TensorData::from([s.y]), &self.device))
            .collect();
        (Tensor::cat(xs, 0), Tensor::cat(ys, 0))
    }
}

fn make_xor_dataset() -> InMemDataset<XorSample> {
    let mut samples = Vec::new();
    // Repeat patterns many times so the network has enough data
    for _ in 0..200 {
        samples.push(XorSample { x: [0.0, 0.0], y: 0 });
        samples.push(XorSample { x: [0.0, 1.0], y: 1 });
        samples.push(XorSample { x: [1.0, 0.0], y: 1 });
        samples.push(XorSample { x: [1.0, 1.0], y: 0 });
    }
    InMemDataset::new(samples)
}

// ---- Training loop ----
fn train() {
    let device = Default::default();

    // Config
    let lr = 1e-3;
    let num_epochs = 50;
    let batch_size = 32;

    // Model + optimizer
    let model_config = XorModelConfig::new();
    let mut model: XorModel<TrainB> = model_config.init(&device);
    let mut optim = AdamConfig::new().init();

    // Data
    let dataset = make_xor_dataset();
    let batcher = XorBatcher::<TrainB> { device: device.clone() };
    let dataloader = DataLoaderBuilder::new(batcher)
        .batch_size(batch_size)
        .shuffle(42)
        .build(dataset);

    // Loss function
    let loss_fn = CrossEntropyLossConfig::new().init(&device);

    // Train!
    for epoch in 1..=num_epochs {
        let mut epoch_loss = 0.0;
        let mut n_batches = 0;

        for (inputs, targets) in dataloader.iter() {
            // Forward
            let logits = model.forward(inputs);
            let loss = loss_fn.forward(logits.clone(), targets.clone());

            // Backward + optimize
            let grads = loss.backward();
            let grads = GradientsParams::from_grads(grads, &model);
            model = optim.step(lr, model, grads);

            epoch_loss += loss.into_scalar().elem::<f32>();
            n_batches += 1;
        }

        if epoch % 10 == 0 || epoch == 1 {
            println!(
                "Epoch {:>3} | Loss: {:.4}",
                epoch,
                epoch_loss / n_batches as f32,
            );
        }
    }

    // ---- Quick inference check ----
    let model_valid = model.valid();  // drop autodiff → runs on NdArray

    let test_input = Tensor::<InferB, 2>::from_data(
        [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]],
        &device,
    );
    let predictions = model_valid.forward(test_input);
    let classes = predictions.argmax(1);
    println!("\nPredictions (should be 0, 1, 1, 0):");
    println!("{}", classes.to_data());
}

fn main() {
    train();
}

Key points to internalize

  • model is reassigned on every optim.step(...) — Burn is functional: the old model is consumed, a new updated model is returned.
  • loss.backward() returns raw gradients. You must wrap them with GradientsParams::from_grads to associate them with parameter IDs.
  • No zero_grad — gradients are consumed by step.
  • model.valid() strips the Autodiff wrapper. The returned model lives on the inner backend (NdArray). Use it for validation/inference.
  • loss.into_scalar() gives you a backend-specific element. Use .elem::<f32>() to get a plain Rust f32 for printing.

Exercise 10

Take the training loop above and extend it:

  1. Add a validation step after each epoch: call model.valid(), run the 4 XOR inputs through it, compute accuracy (percentage of correct predictions).
  2. Add early stopping: if validation accuracy hits 100%, break out of the epoch loop.
  3. Experiment: reduce the hidden size to 4 — does it still learn? Increase to 64 — does it learn faster?

Step 11 — Saving & Loading Models

Records and Recorders

Burn saves models as "records" using "recorders". A record is the serializable state of a module (all its parameter tensors). A recorder is a strategy for how to store that record (binary, MessagePack, etc.).

### Saving

```rust
use burn::record::{CompactRecorder, NamedMpkFileRecorder, FullPrecisionSettings};

let recorder = CompactRecorder::new();

// Save — writes to "my_model.mpk" (extension added automatically)
model
    .save_file("my_model", &recorder)
    .expect("Failed to save model");
```

`CompactRecorder` uses a compact binary format that stores parameters in half precision to keep files small. `NamedMpkFileRecorder` uses named MessagePack, which is more portable and human-debuggable; with `FullPrecisionSettings` it preserves the exact weight values.
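To see the size/precision trade-off behind half-precision records, here is a toy illustration in plain Rust — bf16-style truncation that keeps only the upper 16 bits of each `f32`. This is an analogy of the idea, not Burn's actual storage code:

```rust
// Toy illustration (not Burn's implementation): why half-precision records
// are half the size. Truncating each f32 to its upper 16 bits (bf16-style)
// halves storage at the cost of mantissa precision.

fn to_bf16_bits(x: f32) -> u16 {
    (x.to_bits() >> 16) as u16
}

fn from_bf16_bits(b: u16) -> f32 {
    f32::from_bits((b as u32) << 16)
}

fn main() {
    let weights = [0.1f32, -2.5, 3.14159];
    let full_size = weights.len() * 4; // 4 bytes per f32
    let half_size = weights.len() * 2; // 2 bytes per bf16
    assert_eq!(half_size * 2, full_size);

    for w in weights {
        let round_tripped = from_bf16_bits(to_bf16_bits(w));
        // Values survive approximately, not exactly:
        assert!((w - round_tripped).abs() < 0.02 * w.abs().max(1.0));
        println!("{w:>10.6} -> {round_tripped:>10.6}");
    }
}
```

For trained network weights this small rounding error is usually harmless, which is why a compact half-precision recorder is a reasonable default.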

### Loading

```rust
// First, create a fresh model with random weights
let mut model = model_config.init::<B>(&device);

// Then load saved weights into it
let recorder = CompactRecorder::new();
model = model
    .load_file("my_model", &recorder, &device)
    .expect("Failed to load model");
```

Note: the model architecture must match exactly — same layers, same sizes.

### A complete save/load cycle

```rust
fn save_model<B: Backend>(model: &XorModel<B>) {
    let recorder = CompactRecorder::new();
    model
        .clone() // save_file consumes the module, so save a clone
        .save_file("xor_trained", &recorder)
        .expect("save failed");
    println!("Model saved.");
}

fn load_model<B: Backend>(device: &B::Device) -> XorModel<B> {
    let model = XorModelConfig::new().init::<B>(device);
    let recorder = CompactRecorder::new();
    model
        .load_file("xor_trained", &recorder, device)
        .expect("load failed")
}

// In your training function, after training:
save_model(&model_valid);

// Later, for inference:
let loaded: XorModel<InferB> = load_model(&device);
let output = loaded.forward(test_input);
```

### Checkpoint during training

A common pattern is to save a snapshot every N epochs:

```rust
for epoch in 1..=num_epochs {
    // ... training ...

    if epoch % 10 == 0 {
        let snapshot = model.valid(); // strip autodiff for saving
        save_model(&snapshot);
        println!("Checkpoint saved at epoch {}", epoch);
    }
}
```
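A common refinement of save-every-N is to checkpoint only when validation accuracy improves. A minimal Burn-free sketch of that policy (`CheckpointTracker` and `should_save` are hypothetical names, not Burn API — the real call to `save_model` would go where the comment indicates):

```rust
// Toy checkpoint policy (plain Rust, no Burn): "save" only when this
// epoch's validation accuracy beats the best accuracy seen so far.

struct CheckpointTracker {
    best_accuracy: f32,
}

impl CheckpointTracker {
    fn new() -> Self {
        Self { best_accuracy: f32::NEG_INFINITY }
    }

    /// Returns true when `accuracy` improves on the best seen so far;
    /// the caller would then checkpoint (e.g. call `save_model(...)`).
    fn should_save(&mut self, accuracy: f32) -> bool {
        if accuracy > self.best_accuracy {
            self.best_accuracy = accuracy;
            true
        } else {
            false
        }
    }
}

fn main() {
    let mut tracker = CheckpointTracker::new();
    let epoch_accuracies = [0.50, 0.75, 0.75, 1.00, 0.90];
    let saved: Vec<bool> = epoch_accuracies
        .iter()
        .map(|&acc| tracker.should_save(acc))
        .collect();
    // Checkpoints happen only on strict improvements:
    assert_eq!(saved, vec![true, true, false, true, false]);
    println!("checkpointed at: {saved:?}");
}
```

This guarantees the file on disk always holds the best-performing snapshot, even if training later degrades.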

### Exercise 11

  1. Train the XOR model from Step 10 until it converges.
  2. Save it with CompactRecorder.
  3. In a separate function (simulating a new program run), load the model on NdArray and run inference on all 4 XOR inputs.
  4. Verify the predictions match what you got at the end of training.

Bonus: also try `NamedMpkFileRecorder::<FullPrecisionSettings>` and compare the file sizes.


## Quick Reference Card

| Concept | Burn | PyTorch equivalent |
| --- | --- | --- |
| Backend selection | `type B = NdArray;` | device selection |
| Tensor creation | `Tensor::<B, 2>::from_data(...)` | `torch.tensor(...)` |
| Autodiff | `Autodiff<NdArray>` backend wrapper | built-in autograd |
| Disable grad | `model.valid()` | `torch.no_grad()` |
| Module definition | `#[derive(Module)]` | `nn.Module` subclass |
| Layer construction | `LinearConfig::new(a, b).init(&dev)` | `nn.Linear(a, b)` |
| Hyperparameters | `#[derive(Config)]` | manual / dataclass |
| Loss | `CrossEntropyLossConfig::new().init(&dev)` | `nn.CrossEntropyLoss()` |
| Optimizer | `AdamConfig::new().init()` | `optim.Adam(params, lr)` |
| Backward | `loss.backward()` | `loss.backward()` |
| Param gradients | `GradientsParams::from_grads(grads, &model)` | automatic |
| Optimizer step | `model = optim.step(lr, model, grads)` | `optimizer.step()` |
| Zero grad | automatic (grads consumed) | `optimizer.zero_grad()` |
| Save model | `model.save_file("path", &recorder)` | `torch.save(state_dict)` |
| Load model | `model.load_file("path", &recorder, &dev)` | `model.load_state_dict(...)` |

## Where to go next

You now have the vocabulary and mechanics to read and write Burn code for deep learning practicals. The topics we deliberately skipped (for a future chapter) include:

- The `Learner` abstraction — Burn's built-in training loop with metrics, checkpointing, and a TUI dashboard.
- GPU backends — `Wgpu`, `Cuda`, `LibTorch` for actual hardware acceleration.
- ONNX import — loading models from PyTorch/TensorFlow.
- Custom GPU kernels — writing CubeCL compute shaders.
- Quantization — post-training quantization for deployment.
- `no_std` deployment — running Burn on bare-metal embedded (where `NdArray` also works).

For now, go do your deep learning exercises in Burn. You have the tools.