Chapter 1: Burn Foundations — Tensors, Modules & the Training Loop
Burn version: 0.20.x · Backend: NdArray (CPU) only · Assumed knowledge: proficient Rust, ML fundamentals (scikit-learn level), studying deep learning theory.
Cargo.toml you should already have:
[dependencies] burn = { version = "0.20", features = ["ndarray"] }
Step 1 — Tensor Basics: Creation, Shape & Kinds
What you need to know
Everything in Burn starts with Tensor. Unlike ndarray or nalgebra, Burn's tensor is generic
over three things:
Tensor<B, D, K>
│ │ └─ Kind: Float (default), Int, Bool
│ └──── const dimensionality (rank) — a usize, NOT the shape
└─────── Backend (NdArray, Wgpu, …)
The shape (e.g. 3×4) is a runtime value. The rank (e.g. 2) is a compile-time constant. This means the compiler catches dimension-mismatch bugs for you — you cannot accidentally add a 2-D tensor to a 3-D tensor.
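The rank-in-the-type idea can be reproduced in miniature with plain Rust const generics. The following is a conceptual sketch of the mechanism, not Burn's actual implementation:

```rust
#[derive(Debug, PartialEq)]
struct Toy<const D: usize> {
    shape: [usize; D], // runtime shape; its length is fixed by the rank D
}

impl<const D: usize> Toy<D> {
    fn new(shape: [usize; D]) -> Self {
        Toy { shape }
    }

    // Element-wise add only accepts an operand of the SAME rank D.
    fn add(self, other: Toy<D>) -> Toy<D> {
        // Shape equality is still a runtime check, exactly as in Burn.
        assert_eq!(self.shape, other.shape, "shapes must match at runtime");
        self
    }
}

fn main() {
    let a = Toy::<2>::new([3, 4]);
    let b = Toy::<2>::new([3, 4]);
    let c = a.add(b); // fine: both rank 2
    assert_eq!(c.shape, [3, 4]);

    // let d = Toy::<3>::new([1, 3, 4]);
    // let _ = c.add(d); // compile error: expected `Toy<2>`, found `Toy<3>`
}
```

Mismatched ranks fail at compile time; mismatched shapes of equal rank fail at runtime, which is the same split Burn makes.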
The NdArray backend is Burn's pure-Rust CPU backend built on top of the ndarray crate. It needs
no GPU drivers, compiles everywhere (including no_std embedded), and is the simplest backend to
start with.
Creating tensors
use burn::backend::NdArray;
use burn::tensor::{Tensor, TensorData, Int, Bool};
// Type alias — used everywhere in this chapter.
type B = NdArray;
fn main() {
let device = Default::default(); // NdArray only has a Cpu device
// ---- Float tensors (default kind) ----
// From a nested array literal — easiest way for concrete backends:
let a = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
// from_floats — recommended shorthand for f32:
let b = Tensor::<B, 1>::from_floats([10.0, 20.0, 30.0], &device);
// Factories:
let zeros = Tensor::<B, 2>::zeros([3, 4], &device); // 3×4 of 0.0
let ones = Tensor::<B, 2>::ones([2, 2], &device); // 2×2 of 1.0
let full = Tensor::<B, 1>::full([5], 3.14, &device); // [3.14; 5]
// Same shape as another tensor:
let like_a = Tensor::<B, 2>::ones_like(&a); // 2×2 of 1.0
let like_z = Tensor::<B, 2>::zeros_like(&zeros); // 3×4 of 0.0
// Random (uniform or normal):
use burn::tensor::Distribution;
let uniform = Tensor::<B, 2>::random([4, 4], Distribution::Uniform(0.0, 1.0), &device);
let normal = Tensor::<B, 2>::random([4, 4], Distribution::Normal(0.0, 1.0), &device);
// ---- Int tensors ----
let labels = Tensor::<B, 1, Int>::from_data(TensorData::from([0i32, 1, 2, 1]), &device);
let range = Tensor::<B, 1, Int>::arange(0..5, &device); // [0, 1, 2, 3, 4]
// ---- Bool tensors ----
let mask = a.clone().greater_elem(2.0); // Tensor<B, 2, Bool> — true where > 2
println!("{}", a);
println!("{}", labels);
println!("{}", mask);
}
Extracting data back out
// .to_data() clones — tensor is still usable afterward
let data: TensorData = a.to_data();
// .into_data() moves — tensor is consumed
let data: TensorData = a.into_data();
// Get a single scalar
let scalar: f32 = Tensor::<B, 1>::from_floats([42.0], &device)
.into_scalar();
Printing
println!("{}", tensor); // full debug view with shape, backend, dtype
println!("{:.2}", tensor); // limit to 2 decimal places
Exercise 1
Write a program that:
- Creates a 3×3 identity-like float tensor manually with from_data (1s on the diagonal, 0s elsewhere).
- Creates a 1-D Int tensor [10, 20, 30].
- Creates a Bool tensor by checking which elements of the float tensor are greater than 0.
- Prints all three tensors.
Success criterion: compiles, prints the expected values, you understand every generic parameter.
Step 2 — Ownership, Cloning & Tensor Operations
Ownership: the one Burn rule you must internalize
Almost every Burn tensor operation takes ownership (moves) of the input tensor. This is different from PyTorch where tensors are reference-counted behind the scenes.
let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0], &device);
let y = x.exp(); // x is MOVED into exp(), you cannot use x anymore
// println!("{}", x); // ← compile error: value used after move
The fix is .clone(). Cloning a Burn tensor is cheap — it only bumps a reference count, it
does not copy the underlying buffer.
let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0], &device);
let y = x.clone().exp(); // x is still alive
let z = x.clone().log(); // x is still alive
let w = x + y; // x is consumed here — that's fine, last use
Rule of thumb: clone whenever you need the tensor again later; don't clone on its last use. Burn will automatically do in-place operations when it detects a single owner.
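Burn tensors hold their storage behind shared handles, so cloning behaves much like cloning an Arc: the handle is copied, the buffer is not. A plain-std illustration of that cost model:

```rust
use std::sync::Arc;

fn main() {
    // A large "buffer" behind a shared pointer.
    let buffer: Arc<Vec<f32>> = Arc::new(vec![0.0; 1_000_000]);

    // Cloning bumps the reference count; no element is copied.
    let handle = Arc::clone(&buffer);
    assert!(Arc::ptr_eq(&buffer, &handle)); // same allocation
    assert_eq!(Arc::strong_count(&buffer), 2);

    // Dropping a clone just decrements the count.
    drop(handle);
    assert_eq!(Arc::strong_count(&buffer), 1);
}
```

This is also why Burn can reuse buffers in place: when the reference count is 1, no one else can observe the mutation.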
Arithmetic
All standard ops are supported and work element-wise with broadcasting:
let a = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
let b = Tensor::<B, 2>::from_data([[10.0, 20.0], [30.0, 40.0]], &device);
let sum = a.clone() + b.clone(); // element-wise add
let diff = a.clone() - b.clone(); // element-wise sub
let prod = a.clone() * b.clone(); // element-wise mul
let quot = a.clone() / b.clone(); // element-wise div
// Scalar operations
let scaled = a.clone() * 2.0; // multiply every element by 2
let shifted = a.clone() + 1.0; // add 1 to every element
let neg = -a.clone(); // negate
Float-specific math
let x = Tensor::<B, 1>::from_floats([0.0, 1.0, 2.0], &device);
let _ = x.clone().exp(); // e^x
let _ = x.clone().log(); // ln(x)
let _ = x.clone().sqrt(); // √x
let _ = x.clone().sin(); // sin(x)
let _ = x.clone().tanh(); // tanh(x)
let _ = x.clone().powf_scalar(2.0); // x²
let _ = x.clone().recip(); // 1/x
let _ = x.clone().abs();
let _ = x.clone().ceil();
let _ = x.clone().floor();
Matrix multiplication
let a = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
let b = Tensor::<B, 2>::from_data([[5.0, 6.0], [7.0, 8.0]], &device);
let c = a.matmul(b); // standard matrix multiply — [2,2] × [2,2] → [2,2]
println!("{}", c);
matmul works on any rank ≥ 2 with batch-broadcasting semantics (just like PyTorch's @).
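As a sanity check on the 2×2 example above, here is the same multiply done by hand in plain Rust (naive triple loop, no Burn involved):

```rust
// C[i][j] = sum over k of A[i][k] * B[k][j]
fn matmul2(a: [[f32; 2]; 2], b: [[f32; 2]; 2]) -> [[f32; 2]; 2] {
    let mut c = [[0.0; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            for k in 0..2 {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    c
}

fn main() {
    let a = [[1.0, 2.0], [3.0, 4.0]];
    let b = [[5.0, 6.0], [7.0, 8.0]];
    // Row 0: [1*5 + 2*7, 1*6 + 2*8] = [19, 22]
    // Row 1: [3*5 + 4*7, 3*6 + 4*8] = [43, 50]
    assert_eq!(matmul2(a, b), [[19.0, 22.0], [43.0, 50.0]]);
}
```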
Reductions
let t = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
let total = t.clone().sum(); // scalar: 21.0
let col = t.clone().sum_dim(0); // sum along rows → shape [1, 3]
let row = t.clone().sum_dim(1); // sum along cols → shape [2, 1]
let mean_v = t.clone().mean(); // scalar mean
let max_v = t.clone().max(); // scalar max
let argmax = t.clone().argmax(1); // indices of max along dim 1 → Int tensor
Note: sum_dim, mean_dim, max_dim etc. keep the reduced dimension (like PyTorch's
keepdim=True). This is consistent everywhere in Burn.
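To make the keep-dim shapes concrete, here are the two reductions from the snippet computed by hand in plain Rust; note that the reduced axis survives as a size-1 dimension in the return type:

```rust
// sum_dim(0): collapse the row axis of a [2, 3] tensor, keeping shape [1, 3]
fn sum_dim0(t: [[f32; 3]; 2]) -> [[f32; 3]; 1] {
    let mut out = [[0.0f32; 3]; 1];
    for row in &t {
        for (j, v) in row.iter().enumerate() {
            out[0][j] += v;
        }
    }
    out
}

// sum_dim(1): collapse the column axis, keeping shape [2, 1]
fn sum_dim1(t: [[f32; 3]; 2]) -> [[f32; 1]; 2] {
    let mut out = [[0.0f32; 1]; 2];
    for (i, row) in t.iter().enumerate() {
        out[i][0] = row.iter().sum();
    }
    out
}

fn main() {
    let t = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
    assert_eq!(sum_dim0(t), [[5.0, 7.0, 9.0]]); // still rank 2: [1, 3]
    assert_eq!(sum_dim1(t), [[6.0], [15.0]]);   // still rank 2: [2, 1]
}
```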
Reshaping and dimension manipulation
let t = Tensor::<B, 1>::from_floats([1., 2., 3., 4., 5., 6.], &device);
let r = t.clone().reshape([2, 3]);                // 1-D → 2-D (2×3)
let f: Tensor<B, 1> = r.clone().flatten(0, 1);    // 2-D → 1-D; output rank must be annotated
let u: Tensor<B, 3> = r.clone().unsqueeze_dim(0); // [2,3] → [1,2,3]; output rank annotated
let s: Tensor<B, 2> = u.squeeze_dim(0);           // [1,2,3] → [2,3]
let p = r.clone().swap_dims(0, 1);                // transpose: [2,3] → [3,2]
let t2 = r.transpose();                           // same as swap_dims(0, 1) for 2-D
Concatenation and stacking
let a = Tensor::<B, 2>::from_data([[1.0, 2.0]], &device); // [1, 2]
let b = Tensor::<B, 2>::from_data([[3.0, 4.0]], &device); // [1, 2]
// cat: join along an EXISTING dimension
let catted = Tensor::cat(vec![a.clone(), b.clone()], 0); // [2, 2]
// stack: join along a NEW dimension
let stacked = Tensor::stack(vec![a, b], 0); // [2, 1, 2]
Slicing
let t = Tensor::<B, 2>::from_data(
[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
&device,
);
let row0 = t.clone().slice([0..1, 0..3]); // [[1, 2, 3]]
let col01 = t.clone().slice([0..3, 0..2]); // [[1,2],[4,5],[7,8]]
let single = t.clone().slice([1..2, 1..2]); // [[5]]
Comparisons → Bool tensors
let t = Tensor::<B, 1>::from_floats([1.0, 5.0, 3.0], &device);
let gt3: Tensor<B, 1, Bool> = t.clone().greater_elem(3.0); // [false, true, false]
let eq5: Tensor<B, 1, Bool> = t.clone().equal_elem(5.0); // [false, true, false]
Exercise 2
Write a function fn min_max_normalize(t: Tensor<B, 1>) -> Tensor<B, 1> that:
- Computes min and max of the input (remember to clone!).
- Returns (t - min) / (max - min).

Test it with [2.0, 4.0, 6.0, 8.0, 10.0] — expected output: [0.0, 0.25, 0.5, 0.75, 1.0].
Then write a second function that takes a 2-D tensor and returns the softmax along dimension 1
using only basic ops: exp, sum_dim, and division. Compare a few values mentally.
Step 3 — Autodiff: Computing Gradients
How Burn does autodiff
Unlike PyTorch where autograd is baked into every tensor, Burn uses a backend decorator pattern.
You wrap your base backend with Autodiff<...> and that transparently adds gradient tracking:
NdArray → forward-only, no gradients
Autodiff<NdArray> → same ops, but now you can call .backward()
This is a zero-cost abstraction: if you don't need gradients (inference), you don't pay for them.
The AutodiffBackend trait extends Backend with gradient-related methods. When you write
training code, you constrain your generic with B: AutodiffBackend.
A first gradient computation
use burn::backend::ndarray::NdArray;
use burn::backend::Autodiff;
use burn::tensor::{Tensor, backend::AutodiffBackend};
type TrainBackend = Autodiff<NdArray>;
fn main() {
let device = Default::default();
// Create a tensor that will be tracked for gradients
let x = Tensor::<TrainBackend, 1>::from_floats([1.0, 2.0, 3.0], &device)
.require_grad();
// Forward: compute y = x^2, then sum to get a scalar loss
let y = x.clone().powf_scalar(2.0); // [1, 4, 9]
let loss = y.sum(); // 14.0
// Backward: compute gradients
let grads = loss.backward();
// Extract gradient of `loss` w.r.t. `x`
// dy/dx = 2x → expected: [2.0, 4.0, 6.0]
let x_grad = x.grad(&grads).expect("x should have a gradient");
println!("x = {}", x.to_data());
println!("dx = {}", x_grad.to_data());
}
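Whenever you hand-derive a gradient like the dy/dx = 2x above, it is worth cross-checking it against a finite-difference approximation. Here is that check for loss = sum(x²) in plain Rust, without Burn:

```rust
// loss(x) = sum over i of x_i^2, so d(loss)/dx_i = 2 * x_i.
fn loss(x: &[f64]) -> f64 {
    x.iter().map(|v| v * v).sum()
}

fn main() {
    let x = [1.0, 2.0, 3.0];
    let eps = 1e-6;
    for i in 0..x.len() {
        let mut xp = x; xp[i] += eps;
        let mut xm = x; xm[i] -= eps;
        // Central difference approximates d(loss)/dx_i.
        let numeric = (loss(&xp) - loss(&xm)) / (2.0 * eps);
        let analytic = 2.0 * x[i];
        assert!((numeric - analytic).abs() < 1e-5);
    }
}
```

The same trick works for any scalar loss: perturb one input, re-run the forward pass, and compare the slope to what backward() reports.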
The flow in detail
- require_grad() — marks a leaf tensor as "track me". Any tensor derived from it inherits tracking automatically.
- Forward pass — you build a computation graph by chaining normal tensor ops.
- .backward() — typically called on a single-element loss tensor (e.g. after sum()). Returns a Gradients object containing all partial derivatives.
- .grad(&grads) — retrieves the gradient for a specific tensor. Returns Option<Tensor>.
model.valid() — disabling gradients
When you have a model on an Autodiff<...> backend and want to run inference without tracking:
// During training: model is on Autodiff<NdArray>
let model_valid = model.valid(); // returns the same model on plain NdArray
// Now forward passes do not build a graph → faster, less memory
This replaces PyTorch's with torch.no_grad(): and model.eval() (for the grad-tracking part).
Exercise 3
- Create a tensor w = [0.5, -0.3, 0.8] with require_grad() on Autodiff<NdArray>.
- Create an input x = [1.0, 2.0, 3.0] (no grad needed — think of it as data).
- Compute a "prediction" y_hat = (w * x).sum() (a dot product).
- Define a target y = 2.5 and compute the MSE loss: loss = (y_hat - y)^2.
- Call .backward() and print the gradient of w.
- Manually verify the gradient: d(loss)/dw_i = 2 * (y_hat - y) * x_i.
Step 4 — The Config Pattern
Why configs exist
In PyTorch, you create a layer like nn.Linear(784, 128, bias=True) — the constructor directly
builds the parameters. Burn separates configuration from initialization into two steps:
LinearConfig::new(784, 128) // 1. describe WHAT you want
.with_bias(true) // (builder-style tweaks)
.init(&device) // 2. actually allocate weights on a device
This separation matters because:
- Configs are serializable (save/load hyperparameters to JSON).
- Configs are device-agnostic — you can build on CPU today, GPU tomorrow.
- You can inspect or modify a config before committing to allocating memory.
The #[derive(Config)] macro
Burn's Config derive macro generates a builder pattern for you:
use burn::config::Config;
#[derive(Config)]
pub struct MyTrainingConfig {
#[config(default = 64)]
pub batch_size: usize,
#[config(default = 1e-3)]
pub learning_rate: f64,
#[config(default = 10)]
pub num_epochs: usize,
pub model: MyModelConfig, // nested config — no default, must be provided
}
This generates:
- MyTrainingConfig::new(model: MyModelConfig) — a constructor requiring every non-default field.
- .with_batch_size(128) — an optional builder method for each field that has a default.
- Serialization support (JSON via serde).
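Conceptually, the macro output is an ordinary hand-rolled builder. A sketch of the equivalent code (illustrative, not the macro's literal expansion):

```rust
#[derive(Debug, Clone, PartialEq)]
pub struct MyTrainingConfig {
    pub batch_size: usize,
    pub learning_rate: f64,
}

impl MyTrainingConfig {
    // Fields with defaults are filled in by new(); required fields
    // (like the nested model config) would become parameters here.
    pub fn new() -> Self {
        Self { batch_size: 64, learning_rate: 1e-3 }
    }

    // One chainable with_* method per defaulted field.
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }
}

fn main() {
    let config = MyTrainingConfig::new().with_batch_size(128);
    assert_eq!(config.batch_size, 128);
    assert_eq!(config.learning_rate, 1e-3); // untouched default
}
```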
Saving and loading configs
// Save
let config = MyTrainingConfig::new(MyModelConfig::new(10, 512));
config.save("config.json").expect("failed to save config");
// Load
let loaded = MyTrainingConfig::load("config.json").expect("failed to load config");
Exercise 4
Define a Config struct called ExperimentConfig with:
- hidden_size: usize (no default — required)
- dropout: f64 (default 0.5)
- activation: String (default "relu")
- lr: f64 (default 1e-4)
Create an instance, modify dropout to 0.3 via the builder, save it to JSON, load it back, and
assert the values match.
Step 5 — Defining Modules (Your First Neural Network)
The #[derive(Module)] macro
A Burn "module" is any struct that holds learnable parameters (or sub-modules that do). You annotate
it with #[derive(Module)] and make it generic over Backend:
use burn::prelude::*;
use burn::nn::{Linear, LinearConfig, Relu};
#[derive(Module, Debug)]
pub struct Mlp<B: Backend> {
linear1: Linear<B>,
linear2: Linear<B>,
activation: Relu,
}
The derive macro automatically implements the Module trait, which gives you:
- Parameter iteration (for optimizers).
- Device transfer (.to_device()).
- Serialization (save/load weights).
- .valid() to strip autodiff.
- .clone() (cheap, reference-counted).
The Config → init pattern for your own module
Every module you write should have a companion config:
#[derive(Config)]
pub struct MlpConfig {
input_dim: usize,
hidden_dim: usize,
output_dim: usize,
}
impl MlpConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> Mlp<B> {
Mlp {
linear1: LinearConfig::new(self.input_dim, self.hidden_dim)
.init(device),
linear2: LinearConfig::new(self.hidden_dim, self.output_dim)
.init(device),
activation: Relu::new(),
}
}
}
The forward pass
By convention, you implement a forward method. This is NOT a trait method — it's just a
convention. You can name it anything, or have multiple forward methods for different purposes.
impl<B: Backend> Mlp<B> {
pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
let x = self.linear1.forward(x); // [batch, hidden]
let x = self.activation.forward(x);
self.linear2.forward(x) // [batch, output]
}
}
Putting it together
use burn::backend::NdArray;
fn main() {
type B = NdArray;
let device = Default::default();
let config = MlpConfig::new(4, 32, 2);
let model = config.init::<B>(&device);
// Random input: batch of 8 samples, 4 features each
let input = Tensor::<B, 2>::random(
[8, 4],
burn::tensor::Distribution::Normal(0.0, 1.0),
&device,
);
let output = model.forward(input);
println!("Output shape: {:?}", output.dims()); // [8, 2]
println!("{:.4}", output);
}
Exercise 5
Build a 3-layer MLP from scratch:
- Input: 10 features
- Hidden 1: 64 neurons + ReLU
- Hidden 2: 32 neurons + ReLU
- Output: 3 neurons (raw logits for a 3-class problem)
Create the config, init the model, feed a random [16, 10] batch through it, and verify the output
shape is [16, 3].
Step 6 — Built-in Layers Tour
The naming convention
Every built-in layer in burn::nn follows the same pattern:
ThingConfig::new(required_args) // create config
.with_optional_param(value) // builder tweaks
.init(&device) // → Thing<B>
The layer struct has a .forward(input) method.
Linear (fully connected)
use burn::nn::{Linear, LinearConfig};
let layer = LinearConfig::new(784, 128)
.with_bias(true) // default is true
.init::<B>(&device);
// input: [batch, 784] → output: [batch, 128]
let out = layer.forward(input);
Conv2d
use burn::nn::conv::{Conv2d, Conv2dConfig};
use burn::nn::PaddingConfig2d;
let conv = Conv2dConfig::new([1, 32], [3, 3]) // [in_channels, out_channels], [kH, kW]
.with_padding(PaddingConfig2d::Same)
.with_stride([1, 1])
.init::<B>(&device);
// input: [batch, channels, H, W] → output: [batch, 32, H, W] (with Same padding)
let out = conv.forward(input_4d);
BatchNorm
use burn::nn::{BatchNorm, BatchNormConfig};
let bn = BatchNormConfig::new(32) // num_features (channels)
.init::<B>(&device);
// input: [batch, 32, H, W] → same shape, normalized
let out = bn.forward(input);
Dropout
use burn::nn::{Dropout, DropoutConfig};
let dropout = DropoutConfig::new(0.5).init();
// During training, randomly zeros elements.
// Note: Dropout is stateless, no Backend generic needed.
let out = dropout.forward(input);
Burn's dropout is automatically disabled during inference when you use model.valid().
Pooling
use burn::nn::pool::{
MaxPool2d, MaxPool2dConfig,
AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig,
};
let maxpool = MaxPool2dConfig::new([2, 2])
.with_stride([2, 2])
.init();
let adaptive = AdaptiveAvgPool2dConfig::new([1, 1]).init();
let out = maxpool.forward(input); // halves spatial dims
let out = adaptive.forward(input); // spatial → 1×1
Embedding
use burn::nn::{Embedding, EmbeddingConfig};
let embed = EmbeddingConfig::new(10000, 256) // vocab_size, embed_dim
.init::<B>(&device);
// input: Int tensor [batch, seq_len] → output: [batch, seq_len, 256]
let out = embed.forward(token_ids);
Activations (stateless functions)
Most activations are free functions in burn::tensor::activation, not module structs:
use burn::tensor::activation;
let x = Tensor::<B, 2>::random([4, 4], Distribution::Normal(0.0, 1.0), &device);
let _ = activation::relu(x.clone());
let _ = activation::sigmoid(x.clone());
let _ = activation::gelu(x.clone());
let _ = activation::softmax(x.clone(), 1); // softmax along dim 1
let _ = activation::log_softmax(x.clone(), 1);
There are also module wrappers like Relu, Gelu, etc. that just call these functions.
Use whichever style you prefer.
Exercise 6
Build a small CNN module for 1-channel 28×28 images (think MNIST):
- Conv2d: 1 → 16 channels, 3×3 kernel, same padding → ReLU → MaxPool 2×2
- Conv2d: 16 → 32 channels, 3×3 kernel, same padding → ReLU → AdaptiveAvgPool 1×1
- Flatten → Linear(32, 10)
Create the config, init the model, feed a random [4, 1, 28, 28] tensor, verify output is
[4, 10].
Step 7 — Loss Functions
Built-in losses
Burn provides losses in burn::nn::loss:
use burn::nn::loss::{CrossEntropyLoss, CrossEntropyLossConfig, MseLoss};
Cross-entropy (for classification):
let loss_fn = CrossEntropyLossConfig::new()
.init(&device);
// logits: [batch, num_classes] — raw model output, NOT softmax'd
// targets: [batch] — Int tensor of class indices
let loss: Tensor<B, 1> = loss_fn.forward(logits, targets);
Note: CrossEntropyLoss expects logits (raw scores) by default, not probabilities. It applies
log-softmax internally, just like PyTorch's nn.CrossEntropyLoss.
MSE loss (for regression):
let mse = MseLoss::new();
// Both tensors must have the same shape.
let loss = mse.forward(predictions, targets, burn::nn::loss::Reduction::Mean);
Writing a custom loss
A loss is just a function that takes tensors and returns a scalar tensor. There's nothing special about it:
/// Binary cross-entropy from logits (custom implementation).
fn bce_with_logits<B: Backend>(logits: Tensor<B, 2>, targets: Tensor<B, 2>) -> Tensor<B, 1> {
// sigmoid + log trick for numerical stability:
// loss = max(logits, 0) - logits * targets + log(1 + exp(-|logits|))
let zeros = Tensor::zeros_like(&logits);
let max_val = logits.clone().max_pair(zeros);
let abs_logits = logits.clone().abs();
let loss = max_val - logits * targets
+ ((-abs_logits).exp() + 1.0).log();
loss.mean()
}
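A quick scalar check that the stable formula in the comment agrees with the naive -t·ln σ(x) - (1-t)·ln(1-σ(x)), and keeps working where the naive form overflows:

```rust
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

// Stable form: max(x, 0) - x*t + ln(1 + e^(-|x|))
fn bce_stable(x: f64, t: f64) -> f64 {
    x.max(0.0) - x * t + (1.0 + (-x.abs()).exp()).ln()
}

// Textbook form: breaks when sigmoid saturates to exactly 0.0 or 1.0.
fn bce_naive(x: f64, t: f64) -> f64 {
    -t * sigmoid(x).ln() - (1.0 - t) * (1.0 - sigmoid(x)).ln()
}

fn main() {
    // The two agree in the well-behaved range:
    for &x in &[-3.0, -0.5, 0.0, 0.5, 3.0] {
        for &t in &[0.0, 1.0] {
            assert!((bce_stable(x, t) - bce_naive(x, t)).abs() < 1e-9);
        }
    }
    // Extreme logit: sigmoid(100) rounds to 1.0, so the naive form
    // takes ln(0) and blows up; the stable form stays finite.
    assert!(bce_naive(100.0, 0.0).is_infinite());
    assert!(bce_stable(100.0, 0.0).is_finite());
}
```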
Exercise 7
- Create a random logits tensor [8, 5] (batch=8, 5 classes) and a random Int targets tensor [8] with values in 0..5.
- Compute cross-entropy loss using CrossEntropyLossConfig.
- Write your own manual_cross_entropy function that:
  - Computes log-softmax along dim 1 via activation::log_softmax.
  - Gathers the log-prob at the target index for each sample (use a loop or Tensor::select plus some reshaping).
  - Returns the negative mean.
- Compare the two values — they should be very close (use check_closeness or just eyeball the printed scalars).
Step 8 — Datasets & Batching
The Dataset trait
Burn's dataset system lives in burn::data::dataset:
pub trait Dataset<I>: Send + Sync {
fn get(&self, index: usize) -> Option<I>;
fn len(&self) -> usize;
fn is_empty(&self) -> bool { self.len() == 0 }
}
Simple — it's a random-access collection.
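The trait is small enough to exercise in isolation. Here it is reproduced locally (same shape as Burn's definition, minus the blanket impls and transforms) with a trivial Vec-backed implementation:

```rust
// Local copy of the trait's shape, for illustration only.
trait Dataset<I>: Send + Sync {
    fn get(&self, index: usize) -> Option<I>;
    fn len(&self) -> usize;
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

// The simplest possible implementor: a Vec of items.
struct VecDataset<I>(Vec<I>);

impl<I: Clone + Send + Sync> Dataset<I> for VecDataset<I> {
    fn get(&self, index: usize) -> Option<I> {
        // Out-of-range indices return None rather than panicking.
        self.0.get(index).cloned()
    }
    fn len(&self) -> usize {
        self.0.len()
    }
}

fn main() {
    let ds = VecDataset(vec![10, 20, 30]);
    assert_eq!(ds.len(), 3);
    assert_eq!(ds.get(1), Some(20));
    assert_eq!(ds.get(99), None);
    assert!(!ds.is_empty());
}
```

The Send + Sync supertrait is what lets the dataloader hand the dataset to worker threads.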
InMemDataset
The simplest dataset — a Vec in memory:
use burn::data::dataset::InMemDataset;
#[derive(Clone, Debug)]
struct Sample {
features: [f32; 4],
label: u8,
}
let samples = vec![
Sample { features: [1.0, 2.0, 3.0, 4.0], label: 0 },
Sample { features: [5.0, 6.0, 7.0, 8.0], label: 1 },
// ... more
];
let dataset = InMemDataset::new(samples);
println!("Dataset length: {}", dataset.len());
println!("Sample 0: {:?}", dataset.get(0));
The Batcher trait
A batcher converts a Vec<Item> into a batch struct containing tensors:
use burn::data::dataloader::batcher::Batcher;
#[derive(Clone, Debug)]
struct MyBatch<B: Backend> {
inputs: Tensor<B, 2>, // [batch, features]
targets: Tensor<B, 1, Int>, // [batch]
}
#[derive(Clone)]
struct MyBatcher<B: Backend> {
device: B::Device,
}
// NOTE: the Batcher trait signature has changed across Burn releases; in newer
// versions it is also generic over the backend and `batch` receives the device
// as an argument. Check the burn::data::dataloader::batcher docs for your version.
impl<B: Backend> Batcher<Sample, MyBatch<B>> for MyBatcher<B> {
fn batch(&self, items: Vec<Sample>) -> MyBatch<B> {
let inputs: Vec<Tensor<B, 2>> = items.iter()
.map(|s| Tensor::<B, 1>::from_floats(s.features, &self.device))
.map(|t| t.unsqueeze_dim(0)) // [4] → [1, 4]
.collect();
let targets: Vec<Tensor<B, 1, Int>> = items.iter()
.map(|s| Tensor::<B, 1, Int>::from_data(
TensorData::from([s.label as i32]),
&self.device,
))
.collect();
MyBatch {
inputs: Tensor::cat(inputs, 0), // [batch, 4]
targets: Tensor::cat(targets, 0), // [batch]
}
}
}
DataLoaderBuilder
use burn::data::dataloader::DataLoaderBuilder;
let batcher = MyBatcher::<B> { device: device.clone() };
let dataloader = DataLoaderBuilder::new(batcher)
.batch_size(32)
.shuffle(42) // seed for shuffling
.num_workers(4) // parallel data loading threads
.build(dataset);
for batch in dataloader.iter() {
println!("Batch inputs: {:?}", batch.inputs.dims());
println!("Batch targets: {:?}", batch.targets.dims());
}
Exercise 8
Create a synthetic XOR-like dataset:
- 200 samples of [x1, x2] where x1 and x2 are random values in {0, 1}.
- The label is x1 XOR x2 (0 or 1).
- Implement the Sample struct, populate an InMemDataset, write a Batcher, and build a DataLoader with batch size 16.
- Loop over one epoch and print each batch's input shape and target shape.
Step 9 — Optimizers
How optimizers work in Burn
In PyTorch, you register parameters with the optimizer, call loss.backward(), then
optimizer.step(), then optimizer.zero_grad(). Burn is more functional:
let grads = loss.backward(); // 1. get raw gradients
let grads = GradientsParams::from_grads(grads, &model); // 2. map to parameter IDs
model = optim.step(lr, model, grads); // 3. optimizer consumes grads, returns updated model
No zero_grad — gradients are consumed in step 3. No mutation — step takes ownership of the
model and returns an updated one.
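The functional flavor can be sketched with plain SGD over a Vec of weights. This is illustrative only; Burn's real optimizers also carry per-parameter state such as Adam's moment estimates:

```rust
// A "model" is just its parameters here; a step is a pure function:
// consume the old params and gradients, return the updated params.
fn sgd_step(lr: f64, params: Vec<f64>, grads: Vec<f64>) -> Vec<f64> {
    params
        .into_iter()
        .zip(grads) // both are consumed, so there is nothing to zero afterwards
        .map(|(w, g)| w - lr * g)
        .collect()
}

fn main() {
    let model = vec![1.0, 2.0, 3.0];
    let grads = vec![2.0, 2.0, 2.0];
    // `model` is moved in and a new value comes back,
    // mirroring `model = optim.step(lr, model, grads);`.
    let model = sgd_step(0.5, model, grads);
    assert_eq!(model, vec![0.0, 1.0, 2.0]);
}
```

The rebinding `let model = ...` is the same move-and-replace pattern the training loop in Step 10 uses.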
Creating an optimizer
use burn::optim::{AdamConfig, SgdConfig};
// Adam (most common)
let optim = AdamConfig::new()
.with_beta_1(0.9)
.with_beta_2(0.999)
.with_epsilon(1e-8)
.with_weight_decay(None) // Option<WeightDecayConfig>; None is the default
.init();
// SGD with momentum
let optim = SgdConfig::new()
.with_momentum(Some(burn::optim::MomentumConfig {
momentum: 0.9,
dampening: 0.0,
nesterov: false,
}))
.init();
The optimizer is generic — it works with any module on any AutodiffBackend.
Gradient accumulation
If you need to accumulate gradients over multiple mini-batches before stepping:
use burn::optim::GradientsAccumulator;
let mut accumulator = GradientsAccumulator::new();
for batch in mini_batches {
let loss = model.forward(batch);
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
accumulator.accumulate(&model, grads);
}
let accumulated = accumulator.grads();
model = optim.step(lr, model, accumulated);
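Accumulation is sound because gradients are linear: the sum of per-batch gradients equals the gradient of the summed loss. A scalar demonstration in plain Rust, checked against a finite difference:

```rust
// Per-batch loss: loss_b(w) = (w - target_b)^2, so d/dw = 2 * (w - target_b).
fn grad(w: f64, target: f64) -> f64 {
    2.0 * (w - target)
}

fn main() {
    let w = 1.0;
    let targets = [0.0, 2.0, 4.0]; // three "mini-batches"

    // Accumulate per-batch gradients...
    let accumulated: f64 = targets.iter().map(|t| grad(w, *t)).sum();

    // ...and compare against the numeric gradient of the summed loss.
    let total = |w: f64| targets.iter().map(|t| (w - t) * (w - t)).sum::<f64>();
    let eps = 1e-6;
    let numeric = (total(w + eps) - total(w - eps)) / (2.0 * eps);

    assert!((accumulated - numeric).abs() < 1e-5);
}
```

For a mean over batches instead of a sum, divide the accumulated gradients (or the learning rate) by the batch count.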
Exercise 9
Using the XOR dataset from Exercise 8 and the 3-layer MLP from Exercise 5 (adapt dimensions to input=2, hidden=16, output=2):
- Create an Adam optimizer.
- Run one batch through the model.
- Compute cross-entropy loss.
- Call .backward() → GradientsParams::from_grads → optim.step.
- Print the loss before and after the step (run the forward pass again) to confirm it changed.
Don't write the full loop yet — just verify the single-step mechanics compile and run.
Step 10 — The Full Training Loop (Manual)
Wiring everything together
This is the capstone: a complete, working training loop with no Learner abstraction. We'll train
an MLP on the XOR problem.
use burn::backend::ndarray::NdArray;
use burn::backend::Autodiff;
use burn::data::dataloader::DataLoaderBuilder;
use burn::data::dataset::InMemDataset;
use burn::nn::loss::{CrossEntropyLoss, CrossEntropyLossConfig};
use burn::optim::{AdamConfig, GradientsParams};
use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
// ---- Backend aliases ----
type TrainB = Autodiff<NdArray>;
type InferB = NdArray;
// ---- Model (reuse your MLP, adapted) ----
use burn::nn::{Linear, LinearConfig, Relu};
#[derive(Module, Debug)]
pub struct XorModel<B: Backend> {
fc1: Linear<B>,
fc2: Linear<B>,
fc3: Linear<B>,
relu: Relu,
}
#[derive(Config)]
pub struct XorModelConfig {
#[config(default = 2)]
input: usize,
#[config(default = 16)]
hidden: usize,
#[config(default = 2)]
output: usize,
}
impl XorModelConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> XorModel<B> {
XorModel {
fc1: LinearConfig::new(self.input, self.hidden).init(device),
fc2: LinearConfig::new(self.hidden, self.hidden).init(device),
fc3: LinearConfig::new(self.hidden, self.output).init(device),
relu: Relu::new(),
}
}
}
impl<B: Backend> XorModel<B> {
pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
let x = self.relu.forward(self.fc1.forward(x));
let x = self.relu.forward(self.fc2.forward(x));
self.fc3.forward(x)
}
}
// ---- Data ----
#[derive(Clone, Debug)]
struct XorSample {
x: [f32; 2],
y: i32,
}
#[derive(Clone)]
struct XorBatcher<B: Backend> {
device: B::Device,
}
impl<B: Backend> burn::data::dataloader::batcher::Batcher<XorSample, (Tensor<B, 2>, Tensor<B, 1, Int>)>
for XorBatcher<B>
{
fn batch(&self, items: Vec<XorSample>) -> (Tensor<B, 2>, Tensor<B, 1, Int>) {
let xs: Vec<Tensor<B, 2>> = items.iter()
.map(|s| Tensor::<B, 1>::from_floats(s.x, &self.device).unsqueeze_dim(0))
.collect();
let ys: Vec<Tensor<B, 1, Int>> = items.iter()
.map(|s| Tensor::<B, 1, Int>::from_data(TensorData::from([s.y]), &self.device))
.collect();
(Tensor::cat(xs, 0), Tensor::cat(ys, 0))
}
}
fn make_xor_dataset() -> InMemDataset<XorSample> {
let mut samples = Vec::new();
// Repeat patterns many times so the network has enough data
for _ in 0..200 {
samples.push(XorSample { x: [0.0, 0.0], y: 0 });
samples.push(XorSample { x: [0.0, 1.0], y: 1 });
samples.push(XorSample { x: [1.0, 0.0], y: 1 });
samples.push(XorSample { x: [1.0, 1.0], y: 0 });
}
InMemDataset::new(samples)
}
// ---- Training loop ----
fn train() {
let device = Default::default();
// Config
let lr = 1e-3;
let num_epochs = 50;
let batch_size = 32;
// Model + optimizer
let model_config = XorModelConfig::new();
let mut model: XorModel<TrainB> = model_config.init(&device);
let mut optim = AdamConfig::new().init();
// Data
let dataset = make_xor_dataset();
let batcher = XorBatcher::<TrainB> { device: device.clone() };
let dataloader = DataLoaderBuilder::new(batcher)
.batch_size(batch_size)
.shuffle(42)
.build(dataset);
// Loss function
let loss_fn = CrossEntropyLossConfig::new().init(&device);
// Train!
for epoch in 1..=num_epochs {
let mut epoch_loss = 0.0;
let mut n_batches = 0;
for (inputs, targets) in dataloader.iter() {
// Forward
let logits = model.forward(inputs);
let loss = loss_fn.forward(logits.clone(), targets.clone());
// Backward + optimize
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
model = optim.step(lr, model, grads);
epoch_loss += loss.into_scalar().elem::<f32>();
n_batches += 1;
}
if epoch % 10 == 0 || epoch == 1 {
println!(
"Epoch {:>3} | Loss: {:.4}",
epoch,
epoch_loss / n_batches as f32,
);
}
}
// ---- Quick inference check ----
let model_valid = model.valid(); // drop autodiff → runs on NdArray
let test_input = Tensor::<InferB, 2>::from_data(
[[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]],
&device,
);
let predictions = model_valid.forward(test_input);
let classes = predictions.argmax(1);
println!("\nPredictions (should be 0, 1, 1, 0):");
println!("{}", classes.to_data());
}
fn main() {
train();
}
Key points to internalize
- model is reassigned on every optim.step(...) — Burn is functional: the old model is consumed and a new, updated model is returned.
- loss.backward() returns raw gradients. You must wrap them with GradientsParams::from_grads to associate them with parameter IDs.
- No zero_grad — gradients are consumed by step.
- model.valid() strips the Autodiff wrapper. The returned model lives on the inner backend (NdArray). Use it for validation/inference.
- loss.into_scalar() gives you a backend-specific element. Use .elem::<f32>() to get a plain Rust f32 for printing.
Exercise 10
Take the training loop above and extend it:
- Add a validation step after each epoch: call model.valid(), run the 4 XOR inputs through it, and compute accuracy (the percentage of correct predictions).
- Add early stopping: if validation accuracy hits 100%, break out of the epoch loop.
- Experiment: reduce the hidden size to 4 — does it still learn? Increase to 64 — does it learn faster?
Step 11 — Saving & Loading Models
Records and Recorders
Burn saves models as "records" using "recorders". A record is the serializable state of a module (all its parameter tensors). A recorder is a strategy for how to store that record (binary, MessagePack, etc.).
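The record idea, a module's parameters flattened into plain serializable data, can be sketched with a hand-rolled byte format. The layout below is purely illustrative; Burn's recorders use serde-based formats, not this:

```rust
// "Record" a flat parameter vector as little-endian bytes...
fn to_bytes(params: &[f32]) -> Vec<u8> {
    params.iter().flat_map(|p| p.to_le_bytes()).collect()
}

// ...and "load" it back. The shape/architecture information lives in the
// code (the config), exactly as in Burn: only the values are stored.
fn from_bytes(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}

fn main() {
    let weights = vec![0.5, -0.25, 3.0];
    let record = to_bytes(&weights);    // "save"
    let restored = from_bytes(&record); // "load"
    assert_eq!(weights, restored);
}
```

This also shows why "the model architecture must match exactly" when loading: the bytes carry no structure of their own.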
Saving
use burn::record::{CompactRecorder, NamedMpkFileRecorder, FullPrecisionSettings};
let recorder = CompactRecorder::new();
// Save — writes to "my_model.mpk" (extension added automatically)
model
.save_file("my_model", &recorder)
.expect("Failed to save model");
CompactRecorder stores named MessagePack at reduced (half) precision, producing small files. NamedMpkFileRecorder::<FullPrecisionSettings> keeps full precision: larger files, but more portable and easier to inspect.
Loading
// First, create a fresh model with random weights
let mut model = model_config.init::<B>(&device);
// Then load saved weights into it
let recorder = CompactRecorder::new();
model = model
.load_file("my_model", &recorder, &device)
.expect("Failed to load model");
Note: the model architecture must match exactly — same layers, same sizes.
A complete save/load cycle
fn save_model<B: Backend>(model: &XorModel<B>) {
let recorder = CompactRecorder::new();
model
.save_file("xor_trained", &recorder)
.expect("save failed");
println!("Model saved.");
}
fn load_model<B: Backend>(device: &B::Device) -> XorModel<B> {
let model = XorModelConfig::new().init::<B>(device);
let recorder = CompactRecorder::new();
model
.load_file("xor_trained", &recorder, device)
.expect("load failed")
}
// In your training function, after training:
save_model(&model_valid);
// Later, for inference:
let loaded: XorModel<InferB> = load_model(&device);
let output = loaded.forward(test_input);
Checkpoint during training
A common pattern is to save every N epochs:
for epoch in 1..=num_epochs {
// ... training ...
if epoch % 10 == 0 {
let snapshot = model.valid(); // strip autodiff for saving
save_model(&snapshot);
println!("Checkpoint saved at epoch {}", epoch);
}
}
Exercise 11
- Train the XOR model from Step 10 until it converges.
- Save it with CompactRecorder.
- In a separate function (simulating a new program run), load the model on NdArray and run inference on all 4 XOR inputs.
- Verify the predictions match what you got at the end of training.
Bonus: also try NamedMpkFileRecorder::<FullPrecisionSettings> and compare the file sizes.
Quick Reference Card
| Concept | Burn | PyTorch equivalent |
|---|---|---|
| Backend selection | type B = NdArray; | device selection |
| Tensor creation | Tensor::<B,2>::from_data(...) | torch.tensor(...) |
| Autodiff | Autodiff<NdArray> backend wrapper | built-in autograd |
| Disable grad | model.valid() | torch.no_grad() |
| Module definition | #[derive(Module)] | nn.Module subclass |
| Layer construction | LinearConfig::new(a,b).init(&dev) | nn.Linear(a,b) |
| Hyperparameters | #[derive(Config)] | manual / dataclass |
| Loss | CrossEntropyLossConfig::new().init(&dev) | nn.CrossEntropyLoss() |
| Optimizer | AdamConfig::new().init() | optim.Adam(params, lr) |
| Backward | loss.backward() | loss.backward() |
| Param gradients | GradientsParams::from_grads(grads, &model) | automatic |
| Optimizer step | model = optim.step(lr, model, grads) | optimizer.step() |
| Zero grad | automatic (grads consumed) | optimizer.zero_grad() |
| Save model | model.save_file("path", &recorder) | torch.save(state_dict) |
| Load model | model.load_file("path", &recorder, &dev) | model.load_state_dict(...) |
Where to go next
You now have the vocabulary and mechanics to read and write Burn code for deep learning practicals. The topics we deliberately skipped (for a future chapter) include:
- The Learner abstraction — Burn's built-in training loop with metrics, checkpointing, and a TUI dashboard.
- GPU backends — Wgpu, Cuda, LibTorch for actual hardware acceleration.
- ONNX import — loading models from PyTorch/TensorFlow.
- Custom GPU kernels — writing CubeCL compute shaders.
- Quantization — post-training quantization for deployment.
- no_std deployment — running Burn on bare-metal embedded (where NdArray also works).
For now, go do your deep learning exercises in Burn. You have the tools.