diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e1c6da3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +*/target/ +*/mnist.npz diff --git a/hod_1/Cargo.lock b/hod_1/Cargo.lock index 1b5d829..2e70aae 100644 --- a/hod_1/Cargo.lock +++ b/hod_1/Cargo.lock @@ -274,11 +274,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b78ff10ed98b73e1d477ea6e6e1ec1b9cf9f71a17afc3fea9f4dca482d43dcd4" dependencies = [ "burn-autodiff", + "burn-candle", "burn-core", + "burn-cpu", + "burn-cuda", "burn-ndarray", "burn-nn", "burn-optim", + "burn-rocm", + "burn-router", + "burn-store", "burn-train", + "burn-wgpu", ] [[package]] @@ -317,6 +324,18 @@ dependencies = [ "thiserror 2.0.18", ] +[[package]] +name = "burn-candle" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21c752d5008923eb9299783da5edae3242b94afdb956e88d2b37b025244b5071" +dependencies = [ + "burn-backend", + "burn-std", + "candle-core", + "derive-new", +] + [[package]] name = "burn-core" version = "0.20.1" @@ -347,6 +366,64 @@ dependencies = [ "uuid", ] +[[package]] +name = "burn-cpu" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60aa53c4536719f1c91c250d4b4348daca473c44cf0c45b81096785f5510c192" +dependencies = [ + "burn-backend", + "burn-cubecl", + "cubecl", +] + +[[package]] +name = "burn-cubecl" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d6d13aff03fec966da4300459688883f8a1d741dddbf19d1bfc2562656a9a9b" +dependencies = [ + "burn-backend", + "burn-cubecl-fusion", + "burn-fusion", + "burn-ir", + "burn-std", + "cubecl", + "cubek", + "derive-new", + "futures-lite", + "log", + "serde", + "text_placeholder", +] + +[[package]] +name = "burn-cubecl-fusion" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17d25b2e9fb805931401f79782aabd92462d65e60bc207035a3e554de8d7cd9f" +dependencies = [ + "burn-backend", + "burn-fusion", + "burn-ir", + "burn-std", + "cubecl", + "cubek", + "derive-new", + "serde", +] + +[[package]] +name = "burn-cuda" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d0c68cf653eb9c27dcbe046bb7b04cc18c6b33afda4c09317c102e6f4ae7cb6" +dependencies = [ + "burn-backend", + "burn-cubecl", + "cubecl", +] + [[package]] name = "burn-dataset" version = "0.20.1" @@ -378,6 +455,22 @@ dependencies = [ "syn", ] +[[package]] +name = "burn-fusion" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ea83d7f8574bcc07967291c5bb679ddc0a655c8db0642eca62755e2fffc8047" +dependencies = [ + "burn-backend", + "burn-ir", + "derive-new", + "hashbrown 0.16.1", + "log", + "serde", + "spin", + "tracing", +] + [[package]] name = "burn-ir" version = "0.20.1" @@ -440,6 +533,31 @@ dependencies = [ "serde", ] +[[package]] +name = "burn-rocm" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73e2abda6ee63bdcb730f1a335349a9ff83f03048130d405b6ecdccd2df3ff23" +dependencies = [ + "burn-backend", + "burn-cubecl", + "cubecl", +] + +[[package]] +name = "burn-router" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "823ccb88484736a2861d53dc7f67db375ef050b0446bb02dd7cb8783ac6b69a2" +dependencies = [ + "burn-backend", + "burn-ir", + "burn-std", + "hashbrown 0.16.1", + "log", + "spin", +] + [[package]] name = "burn-std" version = "0.20.1" @@ -455,6 +573,25 @@ dependencies = [ "serde", ] +[[package]] +name = "burn-store" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4be80a7b084a19901dc1d0a2e9b77e226c5c575879fe66de891c67062db41a6d" +dependencies = [ + "burn-core", + "burn-nn", + "burn-tensor", + "byteorder", + "bytes", + "half", + "hashbrown 0.16.1", + "memmap2", + "regex", + "safetensors", + "textdistance", +] + [[package]] name = "burn-tensor" version = "0.20.1" @@ -489,6 +626,17 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "burn-wgpu" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df78d62afc9b9fbb8ee4e49b72006485bb64f778a790e185a2d919479bcfc008" +dependencies = [ + "burn-backend", + "burn-cubecl", + "cubecl", +] + [[package]] name = "bytemuck" version = "1.25.0" @@ -553,6 +701,29 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "candle-core" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c15b675b80d994b2eadb20a4bbe434eabeb454eac3ee5e2b4cf6f147ee9be091" +dependencies = [ + "byteorder", + "float8 0.6.1", + "gemm", + "half", + "libm", + "memmap2", + "num-traits", + "num_cpus", + "rand", + "rand_distr", + "rayon", + "safetensors", + "thiserror 2.0.18", + "yoke", + "zip 7.2.0", +] + [[package]] name = "caseless" version = "0.2.2" @@ -895,6 +1066,7 @@ dependencies = [ "cubecl-hip", "cubecl-ir", "cubecl-runtime", + "cubecl-std", "cubecl-wgpu", "half", ] @@ -916,7 +1088,7 @@ dependencies = [ "embassy-futures", "embassy-time", "float4", - "float8", + "float8 0.4.2", "futures-lite", "half", "hashbrown 0.15.5", @@ -1199,6 +1371,104 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a0f819071413b19a00b7105497e0f6d2cf3e7e9d65cbb8d4ecf1ddb29c61dc2" +[[package]] +name = "cubek" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bb1cce47db02017925301bedec92ae84628493df3f9761ea7ac42a60c6146f8" +dependencies = [ + "cubecl", + "cubek-attention", + "cubek-convolution", + "cubek-matmul", + "cubek-quant", + "cubek-random", + "cubek-reduce", +] + +[[package]] +name = "cubek-attention" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7278bd122b2428af479f9af05285160613733c33c93b63ab3c6d25cd0460c18b" +dependencies = [ + "bytemuck", + "cubecl", + "cubecl-common", + "cubek-matmul", + "cubek-random", + "half", + "serde", +] + +[[package]] +name = "cubek-convolution" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18eb04bca4ae104d62a56def04b04f3d079c42fe49aac62202c96876f90fa28b" +dependencies = [ + "bytemuck", + "cubecl", + "cubecl-common", + "cubek-matmul", + "derive-new", + "enumset", + "half", + "serde", +] + +[[package]] +name = "cubek-matmul" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28f3b04b113760e97c65a8a4dca9afc220744031eeecd5ad6cd0e3be91ba3a9" +dependencies = [ + "bytemuck", + "cubecl", + "cubecl-common", + "half", + "serde", +] + +[[package]] +name = "cubek-quant" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96ec3ae04af324df2d615c2b394e270d58d6f08cb833d67633e2ba794de75916" +dependencies = [ + "cubecl", + "cubecl-common", + "half", + "serde", +] + +[[package]] +name = "cubek-random" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65a34844d8b7f739185c1d24896137dcb73f458830444103b45f678585ad983e" +dependencies = [ + "cubecl", + "cubecl-common", + "half", + "num-traits", + "rand", + "serde", +] + +[[package]] +name = "cubek-reduce" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42397d9ed85bb3084dfb56ed26de75690b5b07caf42a32f4006b57eb23d5b6d6" +dependencies = [ + "cubecl", + "half", + "num-traits", + "serde", + "thiserror 2.0.18", +] + [[package]] name = "cudarc" version = "0.18.2" @@ -1391,6 +1661,22 @@ dependencies = [ "litrs", ] +[[package]] +name = "dyn-stack" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c4713e43e2886ba72b8271aa66c93d722116acf7a75555cce11dcde84388fe8" +dependencies = [ + "bytemuck", + "dyn-stack-macros", +] + +[[package]] +name = "dyn-stack-macros" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1d926b4d407d372f141f93bb444696142c29d32962ccbd3531117cf3aa0bfa9" + [[package]] name = "either" version = "1.15.0" @@ -1459,6 +1745,18 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5320ae4c3782150d900b79807611a59a99fc9a1d61d686faafc24b93fc8d7ca" +[[package]] +name = "enum-as-inner" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "enumset" version = "1.1.10" @@ -1579,6 +1877,18 @@ dependencies = [ "half", ] +[[package]] +name = "float8" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719a903cc23e4a89e87962c2a80fdb45cdaad0983a89bd150bb57b4c8571a7d5" +dependencies = [ + "half", + "num-traits", + "rand", + "rand_distr", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1713,6 +2023,125 @@ dependencies = [ "slab", ] +[[package]] +name = "gemm" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa0673db364b12263d103b68337a68fbecc541d6f6b61ba72fe438654709eacb" +dependencies = [ + "dyn-stack", + "gemm-c32", + "gemm-c64", + "gemm-common", + "gemm-f16", + "gemm-f32", + "gemm-f64", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c32" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "086936dbdcb99e37aad81d320f98f670e53c1e55a98bee70573e83f95beb128c" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c64" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20c8aeeeec425959bda4d9827664029ba1501a90a0d1e6228e48bef741db3a3f" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-common" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88027625910cc9b1085aaaa1c4bc46bb3a36aad323452b33c25b5e4e7c8e2a3e" +dependencies = [ + "bytemuck", + "dyn-stack", + "half", + "libm", + "num-complex", + "num-traits", + "once_cell", + "paste", + "pulp", + "raw-cpuid", + "rayon", + "seq-macro", + "sysctl", +] + +[[package]] +name = "gemm-f16" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3df7a55202e6cd6739d82ae3399c8e0c7e1402859b30e4cb780e61525d9486e" +dependencies = [ + "dyn-stack", + "gemm-common", + "gemm-f32", + "half", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "gemm-f32" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02e0b8c9da1fbec6e3e3ab2ce6bc259ef18eb5f6f0d3e4edf54b75f9fd41a81c" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-f64" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "056131e8f2a521bfab322f804ccd652520c79700d81209e9d9275bbdecaadc6a" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -1889,10 +2318,21 @@ dependencies = [ "cfg-if", "crunchy", "num-traits", + "rand", + "rand_distr", "serde", "zerocopy", ] +[[package]] +name = "hashbrown" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" +dependencies = [ + "ahash", +] + [[package]] name = "hashbrown" version = "0.15.5" @@ -1955,6 +2395,7 @@ dependencies = [ "clap", "ndarray", "npyz", + "serde", "zip 8.2.0", ] @@ -2466,6 +2907,16 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +[[package]] +name = "memmap2" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "744133e4a0e0a658e1374cf3bf8e415c4052a15a111acd372764c55b4177d490" +dependencies = [ + "libc", + "stable_deref_trait", +] + [[package]] name = "metal" version = "0.32.0" @@ -2650,6 +3101,7 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" dependencies = [ + "bytemuck", "num-traits", ] @@ -3013,6 +3465,29 @@ version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3eb8486b569e12e2c32ad3e204dbaba5e4b5b216e9367044f25f1dba42341773" +[[package]] +name = "pulp" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e205bb30d5b916c55e584c22201771bcf2bad9aabd5d4127f38387140c38632" +dependencies = [ + "bytemuck", + "cfg-if", + "libm", + "num-complex", + "paste", + "pulp-wasm-simd-flag", + "raw-cpuid", + "reborrow", + "version_check", +] + +[[package]] +name = "pulp-wasm-simd-flag" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40e24eee682d89fb193496edf918a7f407d30175b2e785fe057e4392dfd182e0" + [[package]] name = "py_literal" version = "0.4.0" @@ -3153,6 +3628,15 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca45419789ae5a7899559e9512e58ca889e41f04f1f2445e9f4b290ceccd1d08" +[[package]] +name = "raw-cpuid" +version = "11.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" +dependencies = [ + "bitflags", +] + [[package]] name = "raw-window-handle" version = "0.6.2" @@ -3185,6 +3669,12 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "reborrow" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" + [[package]] name = "redox_syscall" version = "0.5.18" @@ -3445,6 +3935,17 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" +[[package]] +name = "safetensors" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "675656c1eabb620b921efea4f9199f97fc86e36dd6ffd1fbbe48d0f59a4987f5" +dependencies = [ + "hashbrown 0.16.1", + "serde", + "serde_json", +] + [[package]] name = "same-file" version = "1.0.6" @@ -3740,6 +4241,20 @@ dependencies = [ "syn", ] +[[package]] +name = "sysctl" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01198a2debb237c62b6826ec7081082d951f46dbb64b0e8c7649a452230d1dfc" +dependencies = [ + "bitflags", + "byteorder", + "enum-as-inner", + "libc", + "thiserror 1.0.69", + "walkdir", +] + [[package]] name = "sysinfo" version = "0.36.1" @@ -3787,6 +4302,23 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "text_placeholder" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd5008f74a09742486ef0047596cf35df2b914e2a8dca5727fcb6ba6842a766b" +dependencies = [ + "hashbrown 0.13.2", + "serde", + "serde_json", +] + +[[package]] +name = "textdistance" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa672c55ab69f787dbc9126cc387dbe57fdd595f585e4524cf89018fa44ab819" + [[package]] name = "thiserror" version = "1.0.69" @@ -5294,6 +5826,18 @@ dependencies = [ "zstd 0.11.2+zstd.1.5.2", ] +[[package]] +name = "zip" +version = "7.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c42e33efc22a0650c311c2ef19115ce232583abbe80850bc8b66509ebef02de0" +dependencies = [ + "crc32fast", + "indexmap", + "memchr", + "typed-path", +] + [[package]] name = "zip" version = "8.2.0" diff --git a/hod_1/Cargo.toml b/hod_1/Cargo.toml index 52934e1..4ca19e4 100644 --- a/hod_1/Cargo.toml +++ b/hod_1/Cargo.toml @@ -4,10 +4,11 @@ version = "0.1.0" edition = "2024" [dependencies] -burn = { version = "0.20.1", default-features = false, features = ["ndarray", "train"] } +burn = { version = "0.20.1", default-features = false, features = ["ndarray", "std", "train"] } burn-autodiff = "0.20.1" burn-ndarray = "0.20.1" clap = { version = "4.5.60", features = ["derive"] } ndarray = "0.17.2" npyz = { version = "0.8.4", features = ["npz"] } +serde = { version = "1.0.228", features = ["derive"] } zip = { version = "8.2.0", features = ["deflate"] } diff --git a/hod_1/src/lib.rs b/hod_1/src/lib.rs index a4262c5..041de14 100644 --- a/hod_1/src/lib.rs +++ b/hod_1/src/lib.rs @@ -1,17 +1,28 @@ -use burn::tensor::Tensor; -use burn::optim::Optimizer; -use burn::nn::loss::CrossEntropyLossConfig; -use burn::tensor::Int; -use burn::tensor::backend::AutodiffBackend; -use burn::tensor::backend::Backend; -use burn::optim::GradientsParams; -use burn::tensor::activation; -use std::str::FromStr; +// src/lib.rs + +use burn::config::Config; +use burn::data::dataloader::DataLoaderBuilder; +use burn::data::dataloader::batcher::Batcher; +use burn::data::dataset::Dataset; use burn::module::Module; +use burn::nn::loss::CrossEntropyLossConfig; use burn::nn::{Linear, LinearConfig}; +use burn::optim::AdamConfig; +use burn::record::CompactRecorder; +use burn::tensor::activation; +use burn::tensor::backend::{AutodiffBackend, Backend}; +use burn::tensor::{Int, Tensor}; +use burn::train::metric::{AccuracyMetric, LossMetric}; +use burn::lr_scheduler::constant::ConstantLr; +use burn::train::{ + ClassificationOutput, InferenceStep, Learner, SupervisedTraining, + TrainOutput, TrainStep, TrainingStrategy, +}; +use std::str::FromStr; pub type B = burn_autodiff::Autodiff>; +// Model #[derive(Module, Debug)] pub struct MnistClassifier { hidden: Vec>, @@ -19,7 +30,7 @@ pub struct MnistClassifier { activation: Activation, } -impl> MnistClassifier { +impl MnistClassifier { pub fn new( device: &B::Device, hidden_layers: usize, @@ -27,123 +38,61 @@ impl> MnistClassifier { activation: Activation, ) -> Self { let mut hidden = Vec::new(); - let mut current_input_size = 784; - if hidden_layers > 0 { - hidden.push(LinearConfig::new(current_input_size, hidden_layer_size).init(device)); - current_input_size = hidden_layer_size; + let mut in_size = 784; - for _ in 1..hidden_layers { - hidden.push(LinearConfig::new(hidden_layer_size, hidden_layer_size).init(device)); - } + for _ in 0..hidden_layers { + hidden.push(LinearConfig::new(in_size, hidden_layer_size).init(device)); + in_size = hidden_layer_size; } - let output = LinearConfig::new(current_input_size, 10).init(device); - + let output = LinearConfig::new(in_size, 10).init(device); Self { hidden, output, activation } } pub fn forward(&self, images: Tensor) -> Tensor { - let mut result = images; + let mut x = images; for layer in &self.hidden { - result = layer.forward(result); - result = self.activation.forward(result); - } - self.output.forward(result) - } - - pub fn train_step( - &self, - images: Tensor, - labels: Tensor, - optimizer: &mut impl Optimizer, - lr: f64 - ) -> (Self, f64, usize) where B: AutodiffBackend { - // Forward pass - let logits = self.forward(images); - - // Loss calculation - let loss_fn = CrossEntropyLossConfig::new().init(&logits.device()); - let loss = loss_fn.forward(logits.clone(), labels.clone()); - - // Accuracy - let correct = logits.argmax(1) - .flatten::<1>(0, 1) - .equal(labels) - .int() - .sum() - .into_scalar() as usize; - - let loss_val = loss.clone().into_scalar(); - - // Backprop - let grads = loss.backward(); - let grads = GradientsParams::from_grads(grads, self); - let updated_model = optimizer.step(lr, self.clone(), grads); - - (updated_model, loss_val, correct) - } - - pub fn train_and_evaluate( - &mut self, - images: Tensor, - labels: Tensor, - optimizer: &mut impl Optimizer, - args_epochs: usize, - args_batch_size: usize, - ) where B: AutodiffBackend { - eprintln!("images shape: {:?}", images.shape()); - eprintln!("labels shape: {:?}", labels.shape()); - - let train_size = 50000; - let x_train = images.clone().slice([0..train_size]); - let y_train = labels.clone().slice([0..train_size]); - let x_dev = images.slice([train_size..55000]); - let y_dev = labels.slice([train_size..55000]); - - let target_epochs = [1, 5, 10]; - for epoch in target_epochs { - let start = std::time::Instant::now(); - let mut train_loss = 0.0; - let mut train_correct = 0; - - for i in (0..train_size).step_by(args_batch_size) { - let end = (i + args_batch_size).min(train_size); - if i >= end { continue; } - - let b_x = x_train.clone().slice([i..end]); - let b_y = y_train.clone().slice([i..end]); - - if i == 0 { - eprintln!("first batch shape: {:?}", b_x.shape()); - eprintln!("output layer: input={:?} output=10", self.output.weight.shape()); - } - - let (updated_model, loss_val, correct) = self.train_step(b_x, b_y, optimizer, 1e-3); - *self = updated_model; - - train_loss += loss_val; - train_correct += correct; - } - - // Dev metrics - let dev_logits = self.forward(x_dev.clone()); - let loss_fn = CrossEntropyLossConfig::new().init(&dev_logits.device()); - let dev_loss = loss_fn.forward(dev_logits.clone(), y_dev.clone()).into_scalar(); - let dev_acc = dev_logits.argmax(1).flatten::<1>(0, 1).equal(y_dev.clone()).int().sum().into_scalar() as f64 / 5000.0; - - println!( - "Epoch {:2}/{} {:.1}s loss={:.4} accuracy={:.4} dev:loss={:.4} dev:accuracy={:.4}", - epoch, args_epochs, start.elapsed().as_secs_f32(), - train_loss / (train_size as f64 / args_batch_size as f64), - train_correct as f64 / train_size as f64, - dev_loss, dev_acc - ); + x = layer.forward(x); + x = self.activation.forward(x); } + self.output.forward(x) } } +impl TrainStep for MnistClassifier { + type Input = MnistBatch; + type Output = ClassificationOutput; -#[derive(Debug, Clone, Copy, Module, Default)] + fn step(&self, batch: MnistBatch) -> TrainOutput> { + let output = self.forward(batch.images); + let loss = CrossEntropyLossConfig::new() + .init(&output.device()) + .forward(output.clone(), batch.targets.clone()); + + TrainOutput::new( + self, + loss.backward(), + ClassificationOutput { loss, output, targets: batch.targets }, + ) + } +} + +impl InferenceStep for MnistClassifier { + type Input = MnistBatch; + type Output = ClassificationOutput; + + fn step(&self, batch: MnistBatch) -> ClassificationOutput { + let output = self.forward(batch.images); + let loss = CrossEntropyLossConfig::new() + .init(&output.device()) + .forward(output.clone(), batch.targets.clone()); + + ClassificationOutput { loss, output, targets: batch.targets } + } +} + +// Activation +#[derive(Debug, Clone, Copy, Module, Default, serde::Serialize, serde::Deserialize)] pub enum Activation { #[default] None, @@ -156,11 +105,11 @@ impl FromStr for Activation { type Err = String; fn from_str(s: &str) -> Result { match s { - "none" => Ok(Activation::None), - "relu" => Ok(Activation::ReLU), - "tanh" => Ok(Activation::Tanh), + "none" => Ok(Activation::None), + "relu" => Ok(Activation::ReLU), + "tanh" => Ok(Activation::Tanh), "sigmoid" => Ok(Activation::Sigmoid), - _ => Err(format!("Unknown activation: {}", s)), + _ => Err(format!("Unknown activation: {}", s)), } } } @@ -168,10 +117,148 @@ impl FromStr for Activation { impl Activation { pub fn forward(&self, x: Tensor) -> Tensor { match self { - Activation::None => x, - Activation::ReLU => activation::relu(x), - Activation::Tanh => activation::tanh(x), + Activation::None => x, + Activation::ReLU => activation::relu(x), + Activation::Tanh => activation::tanh(x), Activation::Sigmoid => activation::sigmoid(x), } } } + +// Dataset & Batch +#[derive(Clone, Debug)] +pub struct MnistItem { + pub image: [f64; 784], + pub label: u8, +} + +pub struct MnistDataset { + items: Vec, +} + +impl MnistDataset { + pub fn new(items: Vec) -> Self { + Self { items } + } +} + +impl Dataset for MnistDataset { + fn get(&self, index: usize) -> Option { + self.items.get(index).cloned() + } + fn len(&self) -> usize { + self.items.len() + } +} + +#[derive(Clone, Debug)] +pub struct MnistBatch { + pub images: Tensor, + pub targets: Tensor, +} + +#[derive(Clone)] +pub struct MnistBatcher; + +impl MnistBatcher { + pub fn new() -> Self { + Self + } +} + +impl> Batcher> + for MnistBatcher +{ + fn batch(&self, items: Vec, device: &B::Device) -> MnistBatch { + let n = items.len(); + let image_data: Vec = items.iter().flat_map(|i| i.image).collect(); + let label_data: Vec = items.iter().map(|i| i.label as i64).collect(); + + let images = Tensor::::from_data( + burn::tensor::TensorData::new(image_data, [n, 784]), + device, // ← use the passed-in device, not self.device + ); + let targets = Tensor::::from_data( + burn::tensor::TensorData::new(label_data, [n]), + device, + ); + + MnistBatch { images, targets } + } +} + +// Config +#[derive(Config, Debug)] +pub struct MnistModelConfig { + pub hidden_layers: usize, + pub hidden_layer_size: usize, + pub activation: Activation, +} + +impl MnistModelConfig { + pub fn init(&self, device: &B::Device) -> MnistClassifier { + MnistClassifier::new(device, self.hidden_layers, self.hidden_layer_size, self.activation) + } +} + +#[derive(Config, Debug)] +pub struct MnistTrainingConfig { + pub model: MnistModelConfig, + pub optimizer: AdamConfig, + + #[config(default = 10)] + pub num_epochs: usize, + #[config(default = 64)] + pub batch_size: usize, + #[config(default = 4)] + pub num_workers: usize, + #[config(default = 42)] + pub seed: u64, + #[config(default = 1.0e-4)] + pub learning_rate: f64, +} + +// Training +impl MnistTrainingConfig { + pub fn train( + &self, + device: B::Device, + train_dataset: MnistDataset, + valid_dataset: MnistDataset, + ) where + B: AutodiffBackend, + B::InnerBackend: Backend, + { + B::seed(&device, self.seed); + + let model = self.model.init::(&device); + let optim = self.optimizer.init(); + + let batcher_train = MnistBatcher::new(); + let batcher_valid = MnistBatcher::new(); + + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(self.batch_size) + .shuffle(self.seed) + .num_workers(self.num_workers) + .build(train_dataset); + + let dataloader_valid = DataLoaderBuilder::new(batcher_valid) + .batch_size(self.batch_size) + .num_workers(self.num_workers) + .build(valid_dataset); + + let training = SupervisedTraining::new("/tmp/artifacts", dataloader_train, dataloader_valid) + .metrics((AccuracyMetric::new(), LossMetric::new())) + .with_file_checkpointer(CompactRecorder::new()) + .num_epochs(self.num_epochs) + .summary() + .with_training_strategy(TrainingStrategy::SingleDevice(device)); + + let _result = training.launch(Learner::new( + model, + optim, + ConstantLr::new(self.learning_rate), // plain float → constant LR scheduler + )); + } +} diff --git a/hod_1/src/main.rs b/hod_1/src/main.rs index 3f486e2..9195c10 100644 --- a/hod_1/src/main.rs +++ b/hod_1/src/main.rs @@ -1,114 +1,83 @@ -use burn::tensor::backend::Backend; -use burn::tensor::Tensor; use clap::Parser; -use hod_1::B; +use hod_1::{Activation, MnistDataset, MnistItem, MnistModelConfig, MnistTrainingConfig, B}; +use burn::optim::AdamConfig; use std::fs::File; use std::io::Read; use std::str::FromStr; -use hod_1::*; - -use burn::optim::AdamConfig; -use burn::optim::Optimizer; -use burn::nn::loss::CrossEntropyLossConfig; #[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] +#[command(author, version, about)] struct Args { - #[arg(long = "activation", default_value = "none")] + #[arg(long, default_value = "none")] activation: String, - #[arg(long = "batch_size", default_value = "50")] + #[arg(long, default_value = "64")] batch_size: usize, - #[arg(long = "epochs", default_value = "10")] + #[arg(long, default_value = "10")] epochs: usize, - #[arg(long = "hidden_layer_size", default_value = "100")] + #[arg(long, default_value = "100")] hidden_layer_size: usize, - #[arg(long = "hidden_layers", default_value = "1")] + #[arg(long, default_value = "1")] hidden_layers: usize, - #[arg(long = "seed", default_value = "42")] + #[arg(long, default_value = "42")] seed: u64, - #[arg(long = "threads", default_value = "1")] - threads: usize, + /// Fraction of training data used for validation (e.g. 0.1 = 10 %) + #[arg(long, default_value = "0.1")] + valid_split: f64, } -/// Load MNIST images and labels for training. -/// Returns (images [N, 784], labels [N]) where labels are class indices 0-9. -fn load_mnist_labeled( - examples: usize, - device: &::Device, -) -> (Tensor, Tensor) { +fn load_mnist_items(examples: usize) -> Vec { let file = File::open("mnist.npz").expect("Cannot open mnist.npz"); let mut archive = zip::ZipArchive::new(file).expect("Cannot read zip"); - // Load images - let image_candidates = [ - "train_images.npy", - "train.images.npy", - "x_train.npy", - "images.npy", - ]; + // images + let image_candidates = ["train_images.npy", "train.images.npy", "x_train.npy", "images.npy"]; let mut image_bytes = Vec::new(); - let mut found_images = false; for name in &image_candidates { if let Ok(mut entry) = archive.by_name(name) { - entry.read_to_end(&mut image_bytes).expect("Failed to read images"); - found_images = true; + entry.read_to_end(&mut image_bytes).expect("read images"); break; } } - assert!(found_images, "Could not find train images in mnist.npz"); + assert!(!image_bytes.is_empty(), "Could not find train images in mnist.npz"); - // Load labels - let label_candidates = [ - "train_labels.npy", - "train.labels.npy", - "y_train.npy", - "labels.npy", - ]; + // labels + let label_candidates = ["train_labels.npy", "train.labels.npy", "y_train.npy", "labels.npy"]; let mut label_bytes = Vec::new(); - let mut found_labels = false; for name in &label_candidates { if let Ok(mut entry) = archive.by_name(name) { - entry.read_to_end(&mut label_bytes).expect("Failed to read labels"); - found_labels = true; + entry.read_to_end(&mut label_bytes).expect("read labels"); break; } } - assert!(found_labels, "Could not find train labels in mnist.npz"); + assert!(!label_bytes.is_empty(), "Could not find train labels in mnist.npz"); - // Parse images - let image_npy = npyz::NpyFile::new(&image_bytes[..]).expect("Cannot parse images npy"); + // parse + let image_npy = npyz::NpyFile::new(&image_bytes[..]).expect("parse images"); let image_shape = image_npy.shape().to_vec(); - let image_raw: Vec = image_npy.into_vec().expect("Failed to read images as u8"); + let image_raw: Vec = image_npy.into_vec().expect("images to vec"); let n = examples.min(image_shape[0] as usize); - let pixels = image_raw.len() / image_shape[0] as usize; + let pixels = image_raw.len() / image_shape[0] as usize; // should be 784 + assert_eq!(pixels, 784, "Expected 784 pixels per image, got {pixels}"); - let image_data: Vec = image_raw[..n * pixels] - .iter() - .map(|&p| p as f64 / 255.0) - .collect(); + let label_npy = npyz::NpyFile::new(&label_bytes[..]).expect("parse labels"); + let label_raw: Vec = label_npy.into_vec().expect("labels to vec"); - let image_tensor_data = burn::tensor::TensorData::new(image_data, [n, pixels]); - let images = Tensor::::from_data(image_tensor_data, device); - - // Parse labels - let label_npy = npyz::NpyFile::new(&label_bytes[..]).expect("Cannot parse labels npy"); - let label_raw: Vec = label_npy.into_vec().expect("Failed to read labels as u8"); - - let label_data: Vec = label_raw[..n] - .iter() - .map(|&p| p as i64) - .collect(); - - let label_tensor_data = burn::tensor::TensorData::new(label_data, [n]); - let labels = Tensor::::from_data(label_tensor_data, device); - - (images, labels) + // build items + (0..n) + .map(|i| { + let mut image = [0f64; 784]; + for (j, &px) in image_raw[i * 784..(i + 1) * 784].iter().enumerate() { + image[j] = px as f64 / 255.0; + } + MnistItem { image, label: label_raw[i] } + }) + .collect() } fn main() { @@ -116,18 +85,31 @@ fn main() { let device = burn_ndarray::NdArrayDevice::Cpu; let activation = Activation::from_str(&args.activation).unwrap_or_default(); - let mut model = MnistClassifier::::new( - &device, - args.hidden_layers, - args.hidden_layer_size, - activation, - ); + println!("Loading MNIST…"); + let all_items = load_mnist_items(60_000); - let mut optim = AdamConfig::new().init::>(); - let (images, labels) = load_mnist_labeled(60000, &device); + // Split into train / validation + let valid_n = (all_items.len() as f64 * args.valid_split) as usize; + let train_n = all_items.len() - valid_n; + let mut items = all_items; + let valid_items = items.split_off(train_n); // last `valid_n` items + let train_items = items; - println!("Starting training..."); + println!("Train: {} Valid: {}", train_items.len(), valid_items.len()); - // Main just tells the model to run the process - model.train_and_evaluate(images, labels, &mut optim, args.epochs, args.batch_size); + let train_dataset = MnistDataset::new(train_items); + let valid_dataset = MnistDataset::new(valid_items); + + let config = MnistTrainingConfig::new( + MnistModelConfig::new(args.hidden_layers, args.hidden_layer_size, activation), + AdamConfig::new(), + ) + .with_num_epochs(args.epochs) + .with_batch_size(args.batch_size) + .with_num_workers(1) // NdArray backend is single-threaded; keep at 1 + .with_seed(args.seed) + .with_learning_rate(1e-3); + + println!("Starting training…"); + config.train::(device, train_dataset, valid_dataset); } diff --git a/hod_2/Cargo.lock b/hod_2/Cargo.lock new file mode 100644 index 0000000..004da3a --- /dev/null +++ b/hod_2/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "hod_2" +version = "0.1.0" diff --git a/hod_2/Cargo.toml b/hod_2/Cargo.toml new file mode 100644 index 0000000..877c0e5 --- /dev/null +++ b/hod_2/Cargo.toml @@ -0,0 +1,6 @@ +[package] +name = "hod_2" +version = "0.1.0" +edition = "2024" + +[dependencies] diff --git a/hod_2/src/main.rs b/hod_2/src/main.rs new file mode 100644 index 0000000..e7a11a9 --- /dev/null +++ b/hod_2/src/main.rs @@ -0,0 +1,3 @@ +fn main() { + println!("Hello, world!"); +}