learner used instead of manual training loop

This commit is contained in:
Priec
2026-03-13 21:52:37 +01:00
parent f6b9d79062
commit 8fc8addcac
8 changed files with 835 additions and 203 deletions

View File

@@ -1,17 +1,28 @@
use burn::tensor::Tensor;
use burn::optim::Optimizer;
use burn::nn::loss::CrossEntropyLossConfig;
use burn::tensor::Int;
use burn::tensor::backend::AutodiffBackend;
use burn::tensor::backend::Backend;
use burn::optim::GradientsParams;
use burn::tensor::activation;
use std::str::FromStr;
// src/lib.rs
use burn::config::Config;
use burn::data::dataloader::DataLoaderBuilder;
use burn::data::dataloader::batcher::Batcher;
use burn::data::dataset::Dataset;
use burn::module::Module;
use burn::nn::loss::CrossEntropyLossConfig;
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamConfig;
use burn::record::CompactRecorder;
use burn::tensor::activation;
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::{Int, Tensor};
use burn::train::metric::{AccuracyMetric, LossMetric};
use burn::lr_scheduler::constant::ConstantLr;
use burn::train::{
ClassificationOutput, InferenceStep, Learner, SupervisedTraining,
TrainOutput, TrainStep, TrainingStrategy,
};
use std::str::FromStr;
pub type B = burn_autodiff::Autodiff<burn_ndarray::NdArray<f64>>;
// Model
#[derive(Module, Debug)]
pub struct MnistClassifier<B: Backend> {
hidden: Vec<Linear<B>>,
@@ -19,7 +30,7 @@ pub struct MnistClassifier<B: Backend> {
activation: Activation,
}
impl<B: Backend<FloatElem = f64, IntElem = i64>> MnistClassifier<B> {
impl<B: Backend> MnistClassifier<B> {
pub fn new(
device: &B::Device,
hidden_layers: usize,
@@ -27,123 +38,61 @@ impl<B: Backend<FloatElem = f64, IntElem = i64>> MnistClassifier<B> {
activation: Activation,
) -> Self {
let mut hidden = Vec::new();
let mut current_input_size = 784;
if hidden_layers > 0 {
hidden.push(LinearConfig::new(current_input_size, hidden_layer_size).init(device));
current_input_size = hidden_layer_size;
let mut in_size = 784;
for _ in 1..hidden_layers {
hidden.push(LinearConfig::new(hidden_layer_size, hidden_layer_size).init(device));
}
for _ in 0..hidden_layers {
hidden.push(LinearConfig::new(in_size, hidden_layer_size).init(device));
in_size = hidden_layer_size;
}
let output = LinearConfig::new(current_input_size, 10).init(device);
let output = LinearConfig::new(in_size, 10).init(device);
Self { hidden, output, activation }
}
pub fn forward(&self, images: Tensor<B, 2>) -> Tensor<B, 2> {
let mut result = images;
let mut x = images;
for layer in &self.hidden {
result = layer.forward(result);
result = self.activation.forward(result);
}
self.output.forward(result)
}
pub fn train_step(
&self,
images: Tensor<B, 2>,
labels: Tensor<B, 1, Int>,
optimizer: &mut impl Optimizer<Self, B>,
lr: f64
) -> (Self, f64, usize) where B: AutodiffBackend {
// Forward pass
let logits = self.forward(images);
// Loss calculation
let loss_fn = CrossEntropyLossConfig::new().init(&logits.device());
let loss = loss_fn.forward(logits.clone(), labels.clone());
// Accuracy
let correct = logits.argmax(1)
.flatten::<1>(0, 1)
.equal(labels)
.int()
.sum()
.into_scalar() as usize;
let loss_val = loss.clone().into_scalar();
// Backprop
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, self);
let updated_model = optimizer.step(lr, self.clone(), grads);
(updated_model, loss_val, correct)
}
pub fn train_and_evaluate(
&mut self,
images: Tensor<B, 2>,
labels: Tensor<B, 1, Int>,
optimizer: &mut impl Optimizer<Self, B>,
args_epochs: usize,
args_batch_size: usize,
) where B: AutodiffBackend {
eprintln!("images shape: {:?}", images.shape());
eprintln!("labels shape: {:?}", labels.shape());
let train_size = 50000;
let x_train = images.clone().slice([0..train_size]);
let y_train = labels.clone().slice([0..train_size]);
let x_dev = images.slice([train_size..55000]);
let y_dev = labels.slice([train_size..55000]);
let target_epochs = [1, 5, 10];
for epoch in target_epochs {
let start = std::time::Instant::now();
let mut train_loss = 0.0;
let mut train_correct = 0;
for i in (0..train_size).step_by(args_batch_size) {
let end = (i + args_batch_size).min(train_size);
if i >= end { continue; }
let b_x = x_train.clone().slice([i..end]);
let b_y = y_train.clone().slice([i..end]);
if i == 0 {
eprintln!("first batch shape: {:?}", b_x.shape());
eprintln!("output layer: input={:?} output=10", self.output.weight.shape());
}
let (updated_model, loss_val, correct) = self.train_step(b_x, b_y, optimizer, 1e-3);
*self = updated_model;
train_loss += loss_val;
train_correct += correct;
}
// Dev metrics
let dev_logits = self.forward(x_dev.clone());
let loss_fn = CrossEntropyLossConfig::new().init(&dev_logits.device());
let dev_loss = loss_fn.forward(dev_logits.clone(), y_dev.clone()).into_scalar();
let dev_acc = dev_logits.argmax(1).flatten::<1>(0, 1).equal(y_dev.clone()).int().sum().into_scalar() as f64 / 5000.0;
println!(
"Epoch {:2}/{} {:.1}s loss={:.4} accuracy={:.4} dev:loss={:.4} dev:accuracy={:.4}",
epoch, args_epochs, start.elapsed().as_secs_f32(),
train_loss / (train_size as f64 / args_batch_size as f64),
train_correct as f64 / train_size as f64,
dev_loss, dev_acc
);
x = layer.forward(x);
x = self.activation.forward(x);
}
self.output.forward(x)
}
}
impl<B: AutodiffBackend> TrainStep for MnistClassifier<B> {
type Input = MnistBatch<B>;
type Output = ClassificationOutput<B>;
fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
let output = self.forward(batch.images);
let loss = CrossEntropyLossConfig::new()
.init(&output.device())
.forward(output.clone(), batch.targets.clone());
TrainOutput::new(
self,
loss.backward(),
ClassificationOutput { loss, output, targets: batch.targets },
)
}
}
impl<B: Backend> InferenceStep for MnistClassifier<B> {
type Input = MnistBatch<B>;
type Output = ClassificationOutput<B>;
fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
let output = self.forward(batch.images);
let loss = CrossEntropyLossConfig::new()
.init(&output.device())
.forward(output.clone(), batch.targets.clone());
ClassificationOutput { loss, output, targets: batch.targets }
}
}
// Activation
#[derive(Debug, Clone, Copy, Module, Default)]
#[derive(Debug, Clone, Copy, Module, Default, serde::Serialize, serde::Deserialize)]
pub enum Activation {
#[default]
None,
@@ -156,11 +105,11 @@ impl FromStr for Activation {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"none" => Ok(Activation::None),
"relu" => Ok(Activation::ReLU),
"tanh" => Ok(Activation::Tanh),
"none" => Ok(Activation::None),
"relu" => Ok(Activation::ReLU),
"tanh" => Ok(Activation::Tanh),
"sigmoid" => Ok(Activation::Sigmoid),
_ => Err(format!("Unknown activation: {}", s)),
_ => Err(format!("Unknown activation: {}", s)),
}
}
}
@@ -168,10 +117,148 @@ impl FromStr for Activation {
impl Activation {
pub fn forward<B: Backend, const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
match self {
Activation::None => x,
Activation::ReLU => activation::relu(x),
Activation::Tanh => activation::tanh(x),
Activation::None => x,
Activation::ReLU => activation::relu(x),
Activation::Tanh => activation::tanh(x),
Activation::Sigmoid => activation::sigmoid(x),
}
}
}
// Dataset & Batch
#[derive(Clone, Debug)]
pub struct MnistItem {
pub image: [f64; 784],
pub label: u8,
}
pub struct MnistDataset {
items: Vec<MnistItem>,
}
impl MnistDataset {
pub fn new(items: Vec<MnistItem>) -> Self {
Self { items }
}
}
impl Dataset<MnistItem> for MnistDataset {
fn get(&self, index: usize) -> Option<MnistItem> {
self.items.get(index).cloned()
}
fn len(&self) -> usize {
self.items.len()
}
}
#[derive(Clone, Debug)]
pub struct MnistBatch<B: Backend> {
pub images: Tensor<B, 2>,
pub targets: Tensor<B, 1, Int>,
}
#[derive(Clone)]
pub struct MnistBatcher;
impl MnistBatcher {
pub fn new() -> Self {
Self
}
}
impl<B: Backend<FloatElem = f64, IntElem = i64>> Batcher<B, MnistItem, MnistBatch<B>>
for MnistBatcher
{
fn batch(&self, items: Vec<MnistItem>, device: &B::Device) -> MnistBatch<B> {
let n = items.len();
let image_data: Vec<f64> = items.iter().flat_map(|i| i.image).collect();
let label_data: Vec<i64> = items.iter().map(|i| i.label as i64).collect();
let images = Tensor::<B, 2>::from_data(
burn::tensor::TensorData::new(image_data, [n, 784]),
device, // ← use the device passed in by the dataloader; the batcher holds no device of its own
);
let targets = Tensor::<B, 1, Int>::from_data(
burn::tensor::TensorData::new(label_data, [n]),
device,
);
MnistBatch { images, targets }
}
}
// Config
#[derive(Config, Debug)]
pub struct MnistModelConfig {
pub hidden_layers: usize,
pub hidden_layer_size: usize,
pub activation: Activation,
}
impl MnistModelConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> MnistClassifier<B> {
MnistClassifier::new(device, self.hidden_layers, self.hidden_layer_size, self.activation)
}
}
#[derive(Config, Debug)]
pub struct MnistTrainingConfig {
pub model: MnistModelConfig,
pub optimizer: AdamConfig,
#[config(default = 10)]
pub num_epochs: usize,
#[config(default = 64)]
pub batch_size: usize,
#[config(default = 4)]
pub num_workers: usize,
#[config(default = 42)]
pub seed: u64,
#[config(default = 1.0e-4)]
pub learning_rate: f64,
}
// Training
impl MnistTrainingConfig {
pub fn train<B>(
&self,
device: B::Device,
train_dataset: MnistDataset,
valid_dataset: MnistDataset,
) where
B: AutodiffBackend<FloatElem = f64, IntElem = i64>,
B::InnerBackend: Backend<FloatElem = f64, IntElem = i64>,
{
B::seed(&device, self.seed);
let model = self.model.init::<B>(&device);
let optim = self.optimizer.init();
let batcher_train = MnistBatcher::new();
let batcher_valid = MnistBatcher::new();
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(self.batch_size)
.shuffle(self.seed)
.num_workers(self.num_workers)
.build(train_dataset);
let dataloader_valid = DataLoaderBuilder::new(batcher_valid)
.batch_size(self.batch_size)
.num_workers(self.num_workers)
.build(valid_dataset);
let training = SupervisedTraining::new("/tmp/artifacts", dataloader_train, dataloader_valid)
.metrics((AccuracyMetric::new(), LossMetric::new()))
.with_file_checkpointer(CompactRecorder::new())
.num_epochs(self.num_epochs)
.summary()
.with_training_strategy(TrainingStrategy::SingleDevice(device));
let _result = training.launch(Learner::new(
model,
optim,
ConstantLr::new(self.learning_rate), // constant learning-rate scheduler built from the configured rate
));
}
}
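
Not part of the commit, but a minimal sketch of how the refactored classifier could be sanity-checked on its own, for instance as a unit test at the bottom of src/lib.rs. The test module, its name, and the layer sizes are illustrative assumptions; it relies only on the B alias, MnistClassifier::new and Activation defined above.

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn forward_yields_ten_logits_per_image() {
        let device = burn_ndarray::NdArrayDevice::Cpu;
        // Two hidden layers of 16 units with ReLU (arbitrary example values).
        let model = MnistClassifier::<B>::new(&device, 2, 16, Activation::ReLU);
        // A dummy batch of two flattened 28x28 images.
        let images = Tensor::<B, 2>::zeros([2, 784], &device);
        let logits = model.forward(images);
        assert_eq!(logits.dims(), [2, 10]);
    }
}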

View File

@@ -1,114 +1,83 @@
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use clap::Parser;
use hod_1::B;
use hod_1::{Activation, MnistDataset, MnistItem, MnistModelConfig, MnistTrainingConfig, B};
use burn::optim::AdamConfig;
use std::fs::File;
use std::io::Read;
use std::str::FromStr;
use hod_1::*;
use burn::optim::AdamConfig;
use burn::optim::Optimizer;
use burn::nn::loss::CrossEntropyLossConfig;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
#[command(author, version, about)]
struct Args {
#[arg(long = "activation", default_value = "none")]
#[arg(long, default_value = "none")]
activation: String,
#[arg(long = "batch_size", default_value = "50")]
#[arg(long, default_value = "64")]
batch_size: usize,
#[arg(long = "epochs", default_value = "10")]
#[arg(long, default_value = "10")]
epochs: usize,
#[arg(long = "hidden_layer_size", default_value = "100")]
#[arg(long, default_value = "100")]
hidden_layer_size: usize,
#[arg(long = "hidden_layers", default_value = "1")]
#[arg(long, default_value = "1")]
hidden_layers: usize,
#[arg(long = "seed", default_value = "42")]
#[arg(long, default_value = "42")]
seed: u64,
#[arg(long = "threads", default_value = "1")]
threads: usize,
/// Fraction of training data used for validation (e.g. 0.1 = 10 %)
#[arg(long, default_value = "0.1")]
valid_split: f64,
}
/// Load MNIST images and labels for training.
/// Returns (images [N, 784], labels [N]) where labels are class indices 0-9.
fn load_mnist_labeled(
examples: usize,
device: &<B as Backend>::Device,
) -> (Tensor<B, 2>, Tensor<B, 1, burn::tensor::Int>) {
fn load_mnist_items(examples: usize) -> Vec<MnistItem> {
let file = File::open("mnist.npz").expect("Cannot open mnist.npz");
let mut archive = zip::ZipArchive::new(file).expect("Cannot read zip");
// Load images
let image_candidates = [
"train_images.npy",
"train.images.npy",
"x_train.npy",
"images.npy",
];
// images
let image_candidates = ["train_images.npy", "train.images.npy", "x_train.npy", "images.npy"];
let mut image_bytes = Vec::new();
let mut found_images = false;
for name in &image_candidates {
if let Ok(mut entry) = archive.by_name(name) {
entry.read_to_end(&mut image_bytes).expect("Failed to read images");
found_images = true;
entry.read_to_end(&mut image_bytes).expect("read images");
break;
}
}
assert!(found_images, "Could not find train images in mnist.npz");
assert!(!image_bytes.is_empty(), "Could not find train images in mnist.npz");
// Load labels
let label_candidates = [
"train_labels.npy",
"train.labels.npy",
"y_train.npy",
"labels.npy",
];
// labels
let label_candidates = ["train_labels.npy", "train.labels.npy", "y_train.npy", "labels.npy"];
let mut label_bytes = Vec::new();
let mut found_labels = false;
for name in &label_candidates {
if let Ok(mut entry) = archive.by_name(name) {
entry.read_to_end(&mut label_bytes).expect("Failed to read labels");
found_labels = true;
entry.read_to_end(&mut label_bytes).expect("read labels");
break;
}
}
assert!(found_labels, "Could not find train labels in mnist.npz");
assert!(!label_bytes.is_empty(), "Could not find train labels in mnist.npz");
// Parse images
let image_npy = npyz::NpyFile::new(&image_bytes[..]).expect("Cannot parse images npy");
// parse
let image_npy = npyz::NpyFile::new(&image_bytes[..]).expect("parse images");
let image_shape = image_npy.shape().to_vec();
let image_raw: Vec<u8> = image_npy.into_vec().expect("Failed to read images as u8");
let image_raw: Vec<u8> = image_npy.into_vec().expect("images to vec");
let n = examples.min(image_shape[0] as usize);
let pixels = image_raw.len() / image_shape[0] as usize;
let pixels = image_raw.len() / image_shape[0] as usize; // should be 784
assert_eq!(pixels, 784, "Expected 784 pixels per image, got {pixels}");
let image_data: Vec<f64> = image_raw[..n * pixels]
.iter()
.map(|&p| p as f64 / 255.0)
.collect();
let label_npy = npyz::NpyFile::new(&label_bytes[..]).expect("parse labels");
let label_raw: Vec<u8> = label_npy.into_vec().expect("labels to vec");
let image_tensor_data = burn::tensor::TensorData::new(image_data, [n, pixels]);
let images = Tensor::<B, 2>::from_data(image_tensor_data, device);
// Parse labels
let label_npy = npyz::NpyFile::new(&label_bytes[..]).expect("Cannot parse labels npy");
let label_raw: Vec<u8> = label_npy.into_vec().expect("Failed to read labels as u8");
let label_data: Vec<i64> = label_raw[..n]
.iter()
.map(|&p| p as i64)
.collect();
let label_tensor_data = burn::tensor::TensorData::new(label_data, [n]);
let labels = Tensor::<B, 1, burn::tensor::Int>::from_data(label_tensor_data, device);
(images, labels)
// build items
(0..n)
.map(|i| {
let mut image = [0f64; 784];
for (j, &px) in image_raw[i * 784..(i + 1) * 784].iter().enumerate() {
image[j] = px as f64 / 255.0;
}
MnistItem { image, label: label_raw[i] }
})
.collect()
}
fn main() {
@@ -116,18 +85,31 @@ fn main() {
let device = burn_ndarray::NdArrayDevice::Cpu;
let activation = Activation::from_str(&args.activation).unwrap_or_default();
let mut model = MnistClassifier::<B>::new(
&device,
args.hidden_layers,
args.hidden_layer_size,
activation,
);
println!("Loading MNIST…");
let all_items = load_mnist_items(60_000);
let mut optim = AdamConfig::new().init::<B, MnistClassifier<B>>();
let (images, labels) = load_mnist_labeled(60000, &device);
// Split into train / validation
let valid_n = (all_items.len() as f64 * args.valid_split) as usize;
let train_n = all_items.len() - valid_n;
let mut items = all_items;
let valid_items = items.split_off(train_n); // last `valid_n` items
let train_items = items;
println!("Starting training...");
println!("Train: {} Valid: {}", train_items.len(), valid_items.len());
// main only hands the data over; the model runs the whole training loop
model.train_and_evaluate(images, labels, &mut optim, args.epochs, args.batch_size);
let train_dataset = MnistDataset::new(train_items);
let valid_dataset = MnistDataset::new(valid_items);
let config = MnistTrainingConfig::new(
MnistModelConfig::new(args.hidden_layers, args.hidden_layer_size, activation),
AdamConfig::new(),
)
.with_num_epochs(args.epochs)
.with_batch_size(args.batch_size)
.with_num_workers(1) // NdArray backend is single-threaded; keep at 1
.with_seed(args.seed)
.with_learning_rate(1e-3);
println!("Starting training…");
config.train::<B>(device, train_dataset, valid_dataset);
}
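
Also not in the commit: a rough sketch of a hypothetical spot_check helper that reproduces, outside the Learner, the accuracy arithmetic the removed manual loop performed (argmax over the logits, equality against the targets, summed matches). It assumes MnistClassifier and MnistBatcher are imported from hod_1 alongside the items already used in main, plus the Batcher trait.

use burn::data::dataloader::batcher::Batcher;
use hod_1::{MnistBatcher, MnistClassifier};

/// Hypothetical helper: classify a batch of items and return the accuracy.
fn spot_check(
    model: &MnistClassifier<B>,
    items: Vec<MnistItem>,
    device: &burn_ndarray::NdArrayDevice,
) -> f64 {
    let n = items.len();
    // Reuse the commit's batcher to build image / target tensors on the given device.
    let batch = MnistBatcher::new().batch(items, device);
    // Same computation as the old manual loop: argmax, compare with targets, count matches.
    let predictions = model.forward(batch.images).argmax(1).flatten::<1>(0, 1);
    let correct = predictions.equal(batch.targets).int().sum().into_scalar();
    correct as f64 / n as f64
}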