2 Commits

Author SHA1 Message Date
Priec
b86b3334d6 hod2 2026-03-14 08:19:00 +01:00
Priec
8fc8addcac learner used instead of manual version 2026-03-13 21:52:37 +01:00
11 changed files with 7016 additions and 203 deletions

2
.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
*/target/
*/mnist.npz

546
hod_1/Cargo.lock generated
View File

@@ -274,11 +274,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b78ff10ed98b73e1d477ea6e6e1ec1b9cf9f71a17afc3fea9f4dca482d43dcd4" checksum = "b78ff10ed98b73e1d477ea6e6e1ec1b9cf9f71a17afc3fea9f4dca482d43dcd4"
dependencies = [ dependencies = [
"burn-autodiff", "burn-autodiff",
"burn-candle",
"burn-core", "burn-core",
"burn-cpu",
"burn-cuda",
"burn-ndarray", "burn-ndarray",
"burn-nn", "burn-nn",
"burn-optim", "burn-optim",
"burn-rocm",
"burn-router",
"burn-store",
"burn-train", "burn-train",
"burn-wgpu",
] ]
[[package]] [[package]]
@@ -317,6 +324,18 @@ dependencies = [
"thiserror 2.0.18", "thiserror 2.0.18",
] ]
[[package]]
name = "burn-candle"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21c752d5008923eb9299783da5edae3242b94afdb956e88d2b37b025244b5071"
dependencies = [
"burn-backend",
"burn-std",
"candle-core",
"derive-new",
]
[[package]] [[package]]
name = "burn-core" name = "burn-core"
version = "0.20.1" version = "0.20.1"
@@ -347,6 +366,64 @@ dependencies = [
"uuid", "uuid",
] ]
[[package]]
name = "burn-cpu"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60aa53c4536719f1c91c250d4b4348daca473c44cf0c45b81096785f5510c192"
dependencies = [
"burn-backend",
"burn-cubecl",
"cubecl",
]
[[package]]
name = "burn-cubecl"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d6d13aff03fec966da4300459688883f8a1d741dddbf19d1bfc2562656a9a9b"
dependencies = [
"burn-backend",
"burn-cubecl-fusion",
"burn-fusion",
"burn-ir",
"burn-std",
"cubecl",
"cubek",
"derive-new",
"futures-lite",
"log",
"serde",
"text_placeholder",
]
[[package]]
name = "burn-cubecl-fusion"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17d25b2e9fb805931401f79782aabd92462d65e60bc207035a3e554de8d7cd9f"
dependencies = [
"burn-backend",
"burn-fusion",
"burn-ir",
"burn-std",
"cubecl",
"cubek",
"derive-new",
"serde",
]
[[package]]
name = "burn-cuda"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d0c68cf653eb9c27dcbe046bb7b04cc18c6b33afda4c09317c102e6f4ae7cb6"
dependencies = [
"burn-backend",
"burn-cubecl",
"cubecl",
]
[[package]] [[package]]
name = "burn-dataset" name = "burn-dataset"
version = "0.20.1" version = "0.20.1"
@@ -378,6 +455,22 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "burn-fusion"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ea83d7f8574bcc07967291c5bb679ddc0a655c8db0642eca62755e2fffc8047"
dependencies = [
"burn-backend",
"burn-ir",
"derive-new",
"hashbrown 0.16.1",
"log",
"serde",
"spin",
"tracing",
]
[[package]] [[package]]
name = "burn-ir" name = "burn-ir"
version = "0.20.1" version = "0.20.1"
@@ -440,6 +533,31 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "burn-rocm"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73e2abda6ee63bdcb730f1a335349a9ff83f03048130d405b6ecdccd2df3ff23"
dependencies = [
"burn-backend",
"burn-cubecl",
"cubecl",
]
[[package]]
name = "burn-router"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "823ccb88484736a2861d53dc7f67db375ef050b0446bb02dd7cb8783ac6b69a2"
dependencies = [
"burn-backend",
"burn-ir",
"burn-std",
"hashbrown 0.16.1",
"log",
"spin",
]
[[package]] [[package]]
name = "burn-std" name = "burn-std"
version = "0.20.1" version = "0.20.1"
@@ -455,6 +573,25 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "burn-store"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4be80a7b084a19901dc1d0a2e9b77e226c5c575879fe66de891c67062db41a6d"
dependencies = [
"burn-core",
"burn-nn",
"burn-tensor",
"byteorder",
"bytes",
"half",
"hashbrown 0.16.1",
"memmap2",
"regex",
"safetensors",
"textdistance",
]
[[package]] [[package]]
name = "burn-tensor" name = "burn-tensor"
version = "0.20.1" version = "0.20.1"
@@ -489,6 +626,17 @@ dependencies = [
"tracing-subscriber", "tracing-subscriber",
] ]
[[package]]
name = "burn-wgpu"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df78d62afc9b9fbb8ee4e49b72006485bb64f778a790e185a2d919479bcfc008"
dependencies = [
"burn-backend",
"burn-cubecl",
"cubecl",
]
[[package]] [[package]]
name = "bytemuck" name = "bytemuck"
version = "1.25.0" version = "1.25.0"
@@ -553,6 +701,29 @@ dependencies = [
"pkg-config", "pkg-config",
] ]
[[package]]
name = "candle-core"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c15b675b80d994b2eadb20a4bbe434eabeb454eac3ee5e2b4cf6f147ee9be091"
dependencies = [
"byteorder",
"float8 0.6.1",
"gemm",
"half",
"libm",
"memmap2",
"num-traits",
"num_cpus",
"rand",
"rand_distr",
"rayon",
"safetensors",
"thiserror 2.0.18",
"yoke",
"zip 7.2.0",
]
[[package]] [[package]]
name = "caseless" name = "caseless"
version = "0.2.2" version = "0.2.2"
@@ -895,6 +1066,7 @@ dependencies = [
"cubecl-hip", "cubecl-hip",
"cubecl-ir", "cubecl-ir",
"cubecl-runtime", "cubecl-runtime",
"cubecl-std",
"cubecl-wgpu", "cubecl-wgpu",
"half", "half",
] ]
@@ -916,7 +1088,7 @@ dependencies = [
"embassy-futures", "embassy-futures",
"embassy-time", "embassy-time",
"float4", "float4",
"float8", "float8 0.4.2",
"futures-lite", "futures-lite",
"half", "half",
"hashbrown 0.15.5", "hashbrown 0.15.5",
@@ -1199,6 +1371,104 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a0f819071413b19a00b7105497e0f6d2cf3e7e9d65cbb8d4ecf1ddb29c61dc2" checksum = "7a0f819071413b19a00b7105497e0f6d2cf3e7e9d65cbb8d4ecf1ddb29c61dc2"
[[package]]
name = "cubek"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bb1cce47db02017925301bedec92ae84628493df3f9761ea7ac42a60c6146f8"
dependencies = [
"cubecl",
"cubek-attention",
"cubek-convolution",
"cubek-matmul",
"cubek-quant",
"cubek-random",
"cubek-reduce",
]
[[package]]
name = "cubek-attention"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7278bd122b2428af479f9af05285160613733c33c93b63ab3c6d25cd0460c18b"
dependencies = [
"bytemuck",
"cubecl",
"cubecl-common",
"cubek-matmul",
"cubek-random",
"half",
"serde",
]
[[package]]
name = "cubek-convolution"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18eb04bca4ae104d62a56def04b04f3d079c42fe49aac62202c96876f90fa28b"
dependencies = [
"bytemuck",
"cubecl",
"cubecl-common",
"cubek-matmul",
"derive-new",
"enumset",
"half",
"serde",
]
[[package]]
name = "cubek-matmul"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a28f3b04b113760e97c65a8a4dca9afc220744031eeecd5ad6cd0e3be91ba3a9"
dependencies = [
"bytemuck",
"cubecl",
"cubecl-common",
"half",
"serde",
]
[[package]]
name = "cubek-quant"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96ec3ae04af324df2d615c2b394e270d58d6f08cb833d67633e2ba794de75916"
dependencies = [
"cubecl",
"cubecl-common",
"half",
"serde",
]
[[package]]
name = "cubek-random"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65a34844d8b7f739185c1d24896137dcb73f458830444103b45f678585ad983e"
dependencies = [
"cubecl",
"cubecl-common",
"half",
"num-traits",
"rand",
"serde",
]
[[package]]
name = "cubek-reduce"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42397d9ed85bb3084dfb56ed26de75690b5b07caf42a32f4006b57eb23d5b6d6"
dependencies = [
"cubecl",
"half",
"num-traits",
"serde",
"thiserror 2.0.18",
]
[[package]] [[package]]
name = "cudarc" name = "cudarc"
version = "0.18.2" version = "0.18.2"
@@ -1391,6 +1661,22 @@ dependencies = [
"litrs", "litrs",
] ]
[[package]]
name = "dyn-stack"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c4713e43e2886ba72b8271aa66c93d722116acf7a75555cce11dcde84388fe8"
dependencies = [
"bytemuck",
"dyn-stack-macros",
]
[[package]]
name = "dyn-stack-macros"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1d926b4d407d372f141f93bb444696142c29d32962ccbd3531117cf3aa0bfa9"
[[package]] [[package]]
name = "either" name = "either"
version = "1.15.0" version = "1.15.0"
@@ -1459,6 +1745,18 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5320ae4c3782150d900b79807611a59a99fc9a1d61d686faafc24b93fc8d7ca" checksum = "b5320ae4c3782150d900b79807611a59a99fc9a1d61d686faafc24b93fc8d7ca"
[[package]]
name = "enum-as-inner"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "enumset" name = "enumset"
version = "1.1.10" version = "1.1.10"
@@ -1579,6 +1877,18 @@ dependencies = [
"half", "half",
] ]
[[package]]
name = "float8"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719a903cc23e4a89e87962c2a80fdb45cdaad0983a89bd150bb57b4c8571a7d5"
dependencies = [
"half",
"num-traits",
"rand",
"rand_distr",
]
[[package]] [[package]]
name = "fnv" name = "fnv"
version = "1.0.7" version = "1.0.7"
@@ -1713,6 +2023,125 @@ dependencies = [
"slab", "slab",
] ]
[[package]]
name = "gemm"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa0673db364b12263d103b68337a68fbecc541d6f6b61ba72fe438654709eacb"
dependencies = [
"dyn-stack",
"gemm-c32",
"gemm-c64",
"gemm-common",
"gemm-f16",
"gemm-f32",
"gemm-f64",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"seq-macro",
]
[[package]]
name = "gemm-c32"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "086936dbdcb99e37aad81d320f98f670e53c1e55a98bee70573e83f95beb128c"
dependencies = [
"dyn-stack",
"gemm-common",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"seq-macro",
]
[[package]]
name = "gemm-c64"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20c8aeeeec425959bda4d9827664029ba1501a90a0d1e6228e48bef741db3a3f"
dependencies = [
"dyn-stack",
"gemm-common",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"seq-macro",
]
[[package]]
name = "gemm-common"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88027625910cc9b1085aaaa1c4bc46bb3a36aad323452b33c25b5e4e7c8e2a3e"
dependencies = [
"bytemuck",
"dyn-stack",
"half",
"libm",
"num-complex",
"num-traits",
"once_cell",
"paste",
"pulp",
"raw-cpuid",
"rayon",
"seq-macro",
"sysctl",
]
[[package]]
name = "gemm-f16"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3df7a55202e6cd6739d82ae3399c8e0c7e1402859b30e4cb780e61525d9486e"
dependencies = [
"dyn-stack",
"gemm-common",
"gemm-f32",
"half",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"rayon",
"seq-macro",
]
[[package]]
name = "gemm-f32"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02e0b8c9da1fbec6e3e3ab2ce6bc259ef18eb5f6f0d3e4edf54b75f9fd41a81c"
dependencies = [
"dyn-stack",
"gemm-common",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"seq-macro",
]
[[package]]
name = "gemm-f64"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "056131e8f2a521bfab322f804ccd652520c79700d81209e9d9275bbdecaadc6a"
dependencies = [
"dyn-stack",
"gemm-common",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"seq-macro",
]
[[package]] [[package]]
name = "generic-array" name = "generic-array"
version = "0.14.7" version = "0.14.7"
@@ -1889,10 +2318,21 @@ dependencies = [
"cfg-if", "cfg-if",
"crunchy", "crunchy",
"num-traits", "num-traits",
"rand",
"rand_distr",
"serde", "serde",
"zerocopy", "zerocopy",
] ]
[[package]]
name = "hashbrown"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e"
dependencies = [
"ahash",
]
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.15.5" version = "0.15.5"
@@ -1955,6 +2395,7 @@ dependencies = [
"clap", "clap",
"ndarray", "ndarray",
"npyz", "npyz",
"serde",
"zip 8.2.0", "zip 8.2.0",
] ]
@@ -2466,6 +2907,16 @@ version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
[[package]]
name = "memmap2"
version = "0.9.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "744133e4a0e0a658e1374cf3bf8e415c4052a15a111acd372764c55b4177d490"
dependencies = [
"libc",
"stable_deref_trait",
]
[[package]] [[package]]
name = "metal" name = "metal"
version = "0.32.0" version = "0.32.0"
@@ -2650,6 +3101,7 @@ version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [ dependencies = [
"bytemuck",
"num-traits", "num-traits",
] ]
@@ -3013,6 +3465,29 @@ version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3eb8486b569e12e2c32ad3e204dbaba5e4b5b216e9367044f25f1dba42341773" checksum = "3eb8486b569e12e2c32ad3e204dbaba5e4b5b216e9367044f25f1dba42341773"
[[package]]
name = "pulp"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e205bb30d5b916c55e584c22201771bcf2bad9aabd5d4127f38387140c38632"
dependencies = [
"bytemuck",
"cfg-if",
"libm",
"num-complex",
"paste",
"pulp-wasm-simd-flag",
"raw-cpuid",
"reborrow",
"version_check",
]
[[package]]
name = "pulp-wasm-simd-flag"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40e24eee682d89fb193496edf918a7f407d30175b2e785fe057e4392dfd182e0"
[[package]] [[package]]
name = "py_literal" name = "py_literal"
version = "0.4.0" version = "0.4.0"
@@ -3153,6 +3628,15 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca45419789ae5a7899559e9512e58ca889e41f04f1f2445e9f4b290ceccd1d08" checksum = "ca45419789ae5a7899559e9512e58ca889e41f04f1f2445e9f4b290ceccd1d08"
[[package]]
name = "raw-cpuid"
version = "11.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186"
dependencies = [
"bitflags",
]
[[package]] [[package]]
name = "raw-window-handle" name = "raw-window-handle"
version = "0.6.2" version = "0.6.2"
@@ -3185,6 +3669,12 @@ dependencies = [
"crossbeam-utils", "crossbeam-utils",
] ]
[[package]]
name = "reborrow"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.18" version = "0.5.18"
@@ -3445,6 +3935,17 @@ version = "1.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f"
[[package]]
name = "safetensors"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "675656c1eabb620b921efea4f9199f97fc86e36dd6ffd1fbbe48d0f59a4987f5"
dependencies = [
"hashbrown 0.16.1",
"serde",
"serde_json",
]
[[package]] [[package]]
name = "same-file" name = "same-file"
version = "1.0.6" version = "1.0.6"
@@ -3740,6 +4241,20 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "sysctl"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01198a2debb237c62b6826ec7081082d951f46dbb64b0e8c7649a452230d1dfc"
dependencies = [
"bitflags",
"byteorder",
"enum-as-inner",
"libc",
"thiserror 1.0.69",
"walkdir",
]
[[package]] [[package]]
name = "sysinfo" name = "sysinfo"
version = "0.36.1" version = "0.36.1"
@@ -3787,6 +4302,23 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "text_placeholder"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd5008f74a09742486ef0047596cf35df2b914e2a8dca5727fcb6ba6842a766b"
dependencies = [
"hashbrown 0.13.2",
"serde",
"serde_json",
]
[[package]]
name = "textdistance"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa672c55ab69f787dbc9126cc387dbe57fdd595f585e4524cf89018fa44ab819"
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.69" version = "1.0.69"
@@ -5294,6 +5826,18 @@ dependencies = [
"zstd 0.11.2+zstd.1.5.2", "zstd 0.11.2+zstd.1.5.2",
] ]
[[package]]
name = "zip"
version = "7.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c42e33efc22a0650c311c2ef19115ce232583abbe80850bc8b66509ebef02de0"
dependencies = [
"crc32fast",
"indexmap",
"memchr",
"typed-path",
]
[[package]] [[package]]
name = "zip" name = "zip"
version = "8.2.0" version = "8.2.0"

View File

@@ -4,10 +4,11 @@ version = "0.1.0"
edition = "2024" edition = "2024"
[dependencies] [dependencies]
burn = { version = "0.20.1", default-features = false, features = ["ndarray", "train"] } burn = { version = "0.20.1", default-features = false, features = ["ndarray", "std", "train"] }
burn-autodiff = "0.20.1" burn-autodiff = "0.20.1"
burn-ndarray = "0.20.1" burn-ndarray = "0.20.1"
clap = { version = "4.5.60", features = ["derive"] } clap = { version = "4.5.60", features = ["derive"] }
ndarray = "0.17.2" ndarray = "0.17.2"
npyz = { version = "0.8.4", features = ["npz"] } npyz = { version = "0.8.4", features = ["npz"] }
serde = { version = "1.0.228", features = ["derive"] }
zip = { version = "8.2.0", features = ["deflate"] } zip = { version = "8.2.0", features = ["deflate"] }

View File

@@ -1,17 +1,28 @@
use burn::tensor::Tensor; // src/lib.rs
use burn::optim::Optimizer;
use burn::nn::loss::CrossEntropyLossConfig; use burn::config::Config;
use burn::tensor::Int; use burn::data::dataloader::DataLoaderBuilder;
use burn::tensor::backend::AutodiffBackend; use burn::data::dataloader::batcher::Batcher;
use burn::tensor::backend::Backend; use burn::data::dataset::Dataset;
use burn::optim::GradientsParams;
use burn::tensor::activation;
use std::str::FromStr;
use burn::module::Module; use burn::module::Module;
use burn::nn::loss::CrossEntropyLossConfig;
use burn::nn::{Linear, LinearConfig}; use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamConfig;
use burn::record::CompactRecorder;
use burn::tensor::activation;
use burn::tensor::backend::{AutodiffBackend, Backend};
use burn::tensor::{Int, Tensor};
use burn::train::metric::{AccuracyMetric, LossMetric};
use burn::lr_scheduler::constant::ConstantLr;
use burn::train::{
ClassificationOutput, InferenceStep, Learner, SupervisedTraining,
TrainOutput, TrainStep, TrainingStrategy,
};
use std::str::FromStr;
pub type B = burn_autodiff::Autodiff<burn_ndarray::NdArray<f64>>; pub type B = burn_autodiff::Autodiff<burn_ndarray::NdArray<f64>>;
// Model
#[derive(Module, Debug)] #[derive(Module, Debug)]
pub struct MnistClassifier<B: Backend> { pub struct MnistClassifier<B: Backend> {
hidden: Vec<Linear<B>>, hidden: Vec<Linear<B>>,
@@ -19,7 +30,7 @@ pub struct MnistClassifier<B: Backend> {
activation: Activation, activation: Activation,
} }
impl<B: Backend<FloatElem = f64, IntElem = i64>> MnistClassifier<B> { impl<B: Backend> MnistClassifier<B> {
pub fn new( pub fn new(
device: &B::Device, device: &B::Device,
hidden_layers: usize, hidden_layers: usize,
@@ -27,123 +38,61 @@ impl<B: Backend<FloatElem = f64, IntElem = i64>> MnistClassifier<B> {
activation: Activation, activation: Activation,
) -> Self { ) -> Self {
let mut hidden = Vec::new(); let mut hidden = Vec::new();
let mut current_input_size = 784; let mut in_size = 784;
if hidden_layers > 0 {
hidden.push(LinearConfig::new(current_input_size, hidden_layer_size).init(device));
current_input_size = hidden_layer_size;
for _ in 1..hidden_layers { for _ in 0..hidden_layers {
hidden.push(LinearConfig::new(hidden_layer_size, hidden_layer_size).init(device)); hidden.push(LinearConfig::new(in_size, hidden_layer_size).init(device));
} in_size = hidden_layer_size;
} }
let output = LinearConfig::new(current_input_size, 10).init(device); let output = LinearConfig::new(in_size, 10).init(device);
Self { hidden, output, activation } Self { hidden, output, activation }
} }
pub fn forward(&self, images: Tensor<B, 2>) -> Tensor<B, 2> { pub fn forward(&self, images: Tensor<B, 2>) -> Tensor<B, 2> {
let mut result = images; let mut x = images;
for layer in &self.hidden { for layer in &self.hidden {
result = layer.forward(result); x = layer.forward(x);
result = self.activation.forward(result); x = self.activation.forward(x);
}
self.output.forward(result)
}
pub fn train_step(
&self,
images: Tensor<B, 2>,
labels: Tensor<B, 1, Int>,
optimizer: &mut impl Optimizer<Self, B>,
lr: f64
) -> (Self, f64, usize) where B: AutodiffBackend {
// Forward pass
let logits = self.forward(images);
// Loss calculation
let loss_fn = CrossEntropyLossConfig::new().init(&logits.device());
let loss = loss_fn.forward(logits.clone(), labels.clone());
// Accuracy
let correct = logits.argmax(1)
.flatten::<1>(0, 1)
.equal(labels)
.int()
.sum()
.into_scalar() as usize;
let loss_val = loss.clone().into_scalar();
// Backprop
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, self);
let updated_model = optimizer.step(lr, self.clone(), grads);
(updated_model, loss_val, correct)
}
pub fn train_and_evaluate(
&mut self,
images: Tensor<B, 2>,
labels: Tensor<B, 1, Int>,
optimizer: &mut impl Optimizer<Self, B>,
args_epochs: usize,
args_batch_size: usize,
) where B: AutodiffBackend {
eprintln!("images shape: {:?}", images.shape());
eprintln!("labels shape: {:?}", labels.shape());
let train_size = 50000;
let x_train = images.clone().slice([0..train_size]);
let y_train = labels.clone().slice([0..train_size]);
let x_dev = images.slice([train_size..55000]);
let y_dev = labels.slice([train_size..55000]);
let target_epochs = [1, 5, 10];
for epoch in target_epochs {
let start = std::time::Instant::now();
let mut train_loss = 0.0;
let mut train_correct = 0;
for i in (0..train_size).step_by(args_batch_size) {
let end = (i + args_batch_size).min(train_size);
if i >= end { continue; }
let b_x = x_train.clone().slice([i..end]);
let b_y = y_train.clone().slice([i..end]);
if i == 0 {
eprintln!("first batch shape: {:?}", b_x.shape());
eprintln!("output layer: input={:?} output=10", self.output.weight.shape());
}
let (updated_model, loss_val, correct) = self.train_step(b_x, b_y, optimizer, 1e-3);
*self = updated_model;
train_loss += loss_val;
train_correct += correct;
}
// Dev metrics
let dev_logits = self.forward(x_dev.clone());
let loss_fn = CrossEntropyLossConfig::new().init(&dev_logits.device());
let dev_loss = loss_fn.forward(dev_logits.clone(), y_dev.clone()).into_scalar();
let dev_acc = dev_logits.argmax(1).flatten::<1>(0, 1).equal(y_dev.clone()).int().sum().into_scalar() as f64 / 5000.0;
println!(
"Epoch {:2}/{} {:.1}s loss={:.4} accuracy={:.4} dev:loss={:.4} dev:accuracy={:.4}",
epoch, args_epochs, start.elapsed().as_secs_f32(),
train_loss / (train_size as f64 / args_batch_size as f64),
train_correct as f64 / train_size as f64,
dev_loss, dev_acc
);
} }
self.output.forward(x)
} }
} }
impl<B: AutodiffBackend> TrainStep for MnistClassifier<B> {
type Input = MnistBatch<B>;
type Output = ClassificationOutput<B>;
#[derive(Debug, Clone, Copy, Module, Default)] fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
let output = self.forward(batch.images);
let loss = CrossEntropyLossConfig::new()
.init(&output.device())
.forward(output.clone(), batch.targets.clone());
TrainOutput::new(
self,
loss.backward(),
ClassificationOutput { loss, output, targets: batch.targets },
)
}
}
impl<B: Backend> InferenceStep for MnistClassifier<B> {
type Input = MnistBatch<B>;
type Output = ClassificationOutput<B>;
fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
let output = self.forward(batch.images);
let loss = CrossEntropyLossConfig::new()
.init(&output.device())
.forward(output.clone(), batch.targets.clone());
ClassificationOutput { loss, output, targets: batch.targets }
}
}
// Activation
#[derive(Debug, Clone, Copy, Module, Default, serde::Serialize, serde::Deserialize)]
pub enum Activation { pub enum Activation {
#[default] #[default]
None, None,
@@ -175,3 +124,141 @@ impl Activation {
} }
} }
} }
// Dataset & Batch
#[derive(Clone, Debug)]
pub struct MnistItem {
pub image: [f64; 784],
pub label: u8,
}
pub struct MnistDataset {
items: Vec<MnistItem>,
}
impl MnistDataset {
pub fn new(items: Vec<MnistItem>) -> Self {
Self { items }
}
}
impl Dataset<MnistItem> for MnistDataset {
fn get(&self, index: usize) -> Option<MnistItem> {
self.items.get(index).cloned()
}
fn len(&self) -> usize {
self.items.len()
}
}
#[derive(Clone, Debug)]
pub struct MnistBatch<B: Backend> {
pub images: Tensor<B, 2>,
pub targets: Tensor<B, 1, Int>,
}
#[derive(Clone)]
pub struct MnistBatcher;
impl MnistBatcher {
pub fn new() -> Self {
Self
}
}
impl<B: Backend<FloatElem = f64, IntElem = i64>> Batcher<B, MnistItem, MnistBatch<B>>
for MnistBatcher
{
fn batch(&self, items: Vec<MnistItem>, device: &B::Device) -> MnistBatch<B> {
let n = items.len();
let image_data: Vec<f64> = items.iter().flat_map(|i| i.image).collect();
let label_data: Vec<i64> = items.iter().map(|i| i.label as i64).collect();
let images = Tensor::<B, 2>::from_data(
burn::tensor::TensorData::new(image_data, [n, 784]),
device, // ← use the passed-in device, not self.device
);
let targets = Tensor::<B, 1, Int>::from_data(
burn::tensor::TensorData::new(label_data, [n]),
device,
);
MnistBatch { images, targets }
}
}
// Config
#[derive(Config, Debug)]
pub struct MnistModelConfig {
pub hidden_layers: usize,
pub hidden_layer_size: usize,
pub activation: Activation,
}
impl MnistModelConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> MnistClassifier<B> {
MnistClassifier::new(device, self.hidden_layers, self.hidden_layer_size, self.activation)
}
}
#[derive(Config, Debug)]
pub struct MnistTrainingConfig {
pub model: MnistModelConfig,
pub optimizer: AdamConfig,
#[config(default = 10)]
pub num_epochs: usize,
#[config(default = 64)]
pub batch_size: usize,
#[config(default = 4)]
pub num_workers: usize,
#[config(default = 42)]
pub seed: u64,
#[config(default = 1.0e-4)]
pub learning_rate: f64,
}
// Training
impl MnistTrainingConfig {
pub fn train<B>(
&self,
device: B::Device,
train_dataset: MnistDataset,
valid_dataset: MnistDataset,
) where
B: AutodiffBackend<FloatElem = f64, IntElem = i64>,
B::InnerBackend: Backend<FloatElem = f64, IntElem = i64>,
{
B::seed(&device, self.seed);
let model = self.model.init::<B>(&device);
let optim = self.optimizer.init();
let batcher_train = MnistBatcher::new();
let batcher_valid = MnistBatcher::new();
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(self.batch_size)
.shuffle(self.seed)
.num_workers(self.num_workers)
.build(train_dataset);
let dataloader_valid = DataLoaderBuilder::new(batcher_valid)
.batch_size(self.batch_size)
.num_workers(self.num_workers)
.build(valid_dataset);
let training = SupervisedTraining::new("/tmp/artifacts", dataloader_train, dataloader_valid)
.metrics((AccuracyMetric::new(), LossMetric::new()))
.with_file_checkpointer(CompactRecorder::new())
.num_epochs(self.num_epochs)
.summary()
.with_training_strategy(TrainingStrategy::SingleDevice(device));
let _result = training.launch(Learner::new(
model,
optim,
ConstantLr::new(self.learning_rate), // plain float → constant LR scheduler
));
}
}

View File

@@ -1,114 +1,83 @@
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use clap::Parser; use clap::Parser;
use hod_1::B; use hod_1::{Activation, MnistDataset, MnistItem, MnistModelConfig, MnistTrainingConfig, B};
use burn::optim::AdamConfig;
use std::fs::File; use std::fs::File;
use std::io::Read; use std::io::Read;
use std::str::FromStr; use std::str::FromStr;
use hod_1::*;
use burn::optim::AdamConfig;
use burn::optim::Optimizer;
use burn::nn::loss::CrossEntropyLossConfig;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)] #[command(author, version, about)]
struct Args { struct Args {
#[arg(long = "activation", default_value = "none")] #[arg(long, default_value = "none")]
activation: String, activation: String,
#[arg(long = "batch_size", default_value = "50")] #[arg(long, default_value = "64")]
batch_size: usize, batch_size: usize,
#[arg(long = "epochs", default_value = "10")] #[arg(long, default_value = "10")]
epochs: usize, epochs: usize,
#[arg(long = "hidden_layer_size", default_value = "100")] #[arg(long, default_value = "100")]
hidden_layer_size: usize, hidden_layer_size: usize,
#[arg(long = "hidden_layers", default_value = "1")] #[arg(long, default_value = "1")]
hidden_layers: usize, hidden_layers: usize,
#[arg(long = "seed", default_value = "42")] #[arg(long, default_value = "42")]
seed: u64, seed: u64,
#[arg(long = "threads", default_value = "1")] /// Fraction of training data used for validation (e.g. 0.1 = 10 %)
threads: usize, #[arg(long, default_value = "0.1")]
valid_split: f64,
} }
/// Load MNIST images and labels for training. fn load_mnist_items(examples: usize) -> Vec<MnistItem> {
/// Returns (images [N, 784], labels [N]) where labels are class indices 0-9.
fn load_mnist_labeled(
examples: usize,
device: &<B as Backend>::Device,
) -> (Tensor<B, 2>, Tensor<B, 1, burn::tensor::Int>) {
let file = File::open("mnist.npz").expect("Cannot open mnist.npz"); let file = File::open("mnist.npz").expect("Cannot open mnist.npz");
let mut archive = zip::ZipArchive::new(file).expect("Cannot read zip"); let mut archive = zip::ZipArchive::new(file).expect("Cannot read zip");
// Load images // images
let image_candidates = [ let image_candidates = ["train_images.npy", "train.images.npy", "x_train.npy", "images.npy"];
"train_images.npy",
"train.images.npy",
"x_train.npy",
"images.npy",
];
let mut image_bytes = Vec::new(); let mut image_bytes = Vec::new();
let mut found_images = false;
for name in &image_candidates { for name in &image_candidates {
if let Ok(mut entry) = archive.by_name(name) { if let Ok(mut entry) = archive.by_name(name) {
entry.read_to_end(&mut image_bytes).expect("Failed to read images"); entry.read_to_end(&mut image_bytes).expect("read images");
found_images = true;
break; break;
} }
} }
assert!(found_images, "Could not find train images in mnist.npz"); assert!(!image_bytes.is_empty(), "Could not find train images in mnist.npz");
// Load labels // labels
let label_candidates = [ let label_candidates = ["train_labels.npy", "train.labels.npy", "y_train.npy", "labels.npy"];
"train_labels.npy",
"train.labels.npy",
"y_train.npy",
"labels.npy",
];
let mut label_bytes = Vec::new(); let mut label_bytes = Vec::new();
let mut found_labels = false;
for name in &label_candidates { for name in &label_candidates {
if let Ok(mut entry) = archive.by_name(name) { if let Ok(mut entry) = archive.by_name(name) {
entry.read_to_end(&mut label_bytes).expect("Failed to read labels"); entry.read_to_end(&mut label_bytes).expect("read labels");
found_labels = true;
break; break;
} }
} }
assert!(found_labels, "Could not find train labels in mnist.npz"); assert!(!label_bytes.is_empty(), "Could not find train labels in mnist.npz");
// Parse images // parse
let image_npy = npyz::NpyFile::new(&image_bytes[..]).expect("Cannot parse images npy"); let image_npy = npyz::NpyFile::new(&image_bytes[..]).expect("parse images");
let image_shape = image_npy.shape().to_vec(); let image_shape = image_npy.shape().to_vec();
let image_raw: Vec<u8> = image_npy.into_vec().expect("Failed to read images as u8"); let image_raw: Vec<u8> = image_npy.into_vec().expect("images to vec");
let n = examples.min(image_shape[0] as usize); let n = examples.min(image_shape[0] as usize);
let pixels = image_raw.len() / image_shape[0] as usize; let pixels = image_raw.len() / image_shape[0] as usize; // should be 784
assert_eq!(pixels, 784, "Expected 784 pixels per image, got {pixels}");
let image_data: Vec<f64> = image_raw[..n * pixels] let label_npy = npyz::NpyFile::new(&label_bytes[..]).expect("parse labels");
.iter() let label_raw: Vec<u8> = label_npy.into_vec().expect("labels to vec");
.map(|&p| p as f64 / 255.0)
.collect();
let image_tensor_data = burn::tensor::TensorData::new(image_data, [n, pixels]); // build items
let images = Tensor::<B, 2>::from_data(image_tensor_data, device); (0..n)
.map(|i| {
// Parse labels let mut image = [0f64; 784];
let label_npy = npyz::NpyFile::new(&label_bytes[..]).expect("Cannot parse labels npy"); for (j, &px) in image_raw[i * 784..(i + 1) * 784].iter().enumerate() {
let label_raw: Vec<u8> = label_npy.into_vec().expect("Failed to read labels as u8"); image[j] = px as f64 / 255.0;
}
let label_data: Vec<i64> = label_raw[..n] MnistItem { image, label: label_raw[i] }
.iter() })
.map(|&p| p as i64) .collect()
.collect();
let label_tensor_data = burn::tensor::TensorData::new(label_data, [n]);
let labels = Tensor::<B, 1, burn::tensor::Int>::from_data(label_tensor_data, device);
(images, labels)
} }
fn main() { fn main() {
@@ -116,18 +85,31 @@ fn main() {
let device = burn_ndarray::NdArrayDevice::Cpu; let device = burn_ndarray::NdArrayDevice::Cpu;
let activation = Activation::from_str(&args.activation).unwrap_or_default(); let activation = Activation::from_str(&args.activation).unwrap_or_default();
let mut model = MnistClassifier::<B>::new( println!("Loading MNIST…");
&device, let all_items = load_mnist_items(60_000);
args.hidden_layers,
args.hidden_layer_size,
activation,
);
let mut optim = AdamConfig::new().init::<B, MnistClassifier<B>>(); // Split into train / validation
let (images, labels) = load_mnist_labeled(60000, &device); let valid_n = (all_items.len() as f64 * args.valid_split) as usize;
let train_n = all_items.len() - valid_n;
let mut items = all_items;
let valid_items = items.split_off(train_n); // last `valid_n` items
let train_items = items;
println!("Starting training..."); println!("Train: {} Valid: {}", train_items.len(), valid_items.len());
// Main just tells the model to run the process let train_dataset = MnistDataset::new(train_items);
model.train_and_evaluate(images, labels, &mut optim, args.epochs, args.batch_size); let valid_dataset = MnistDataset::new(valid_items);
let config = MnistTrainingConfig::new(
MnistModelConfig::new(args.hidden_layers, args.hidden_layer_size, activation),
AdamConfig::new(),
)
.with_num_epochs(args.epochs)
.with_batch_size(args.batch_size)
.with_num_workers(1) // NdArray backend is single-threaded; keep at 1
.with_seed(args.seed)
.with_learning_rate(1e-3);
println!("Starting training…");
config.train::<B>(device, train_dataset, valid_dataset);
} }

5987
hod_2/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

16
hod_2/Cargo.toml Normal file
View File

@@ -0,0 +1,16 @@
[package]
name = "hod_2"
version = "0.1.0"
edition = "2024"
[dependencies]
burn = { version = "0.20.1", default-features = false, features = ["ndarray", "std", "train"] }
burn-autodiff = "0.20.1"
burn-ndarray = "0.20.1"
clap = { version = "4.5.60", features = ["derive"] }
ndarray = "0.17.2"
npyz = { version = "0.8.4", features = ["npz"] }
rand = "0.10.0"
rand_distr = "0.6.0"
serde = { version = "1.0.228", features = ["derive"] }
zip = { version = "8.2.0", features = ["deflate"] }

50
hod_2/plan.md Normal file
View File

@@ -0,0 +1,50 @@
## Phase 1: Core Data Structures
**`src/model.rs`** - Manual parameter management
- `struct Parameters<B: Backend>`: holds `w1, b1, w2, b2` as `Tensor<B, 2>`
- `impl Parameters`: initialization with `randn(0.1)` for weights, zeros for biases
- No `nn.Linear`—manual tensors to match the Python exercise
## Phase 2: Forward Pass
**`src/forward.rs`** or in `model.rs`
- `fn forward<B: Backend>(params: &Parameters<B>, images: Tensor<B, 2>) -> Tensor<B, 2>`
- Cast `uint8` images to `f32`, divide by 255, flatten to `[batch, 784]`
- `hidden = tanh(images @ w1 + b1)`
- `logits = hidden @ w2 + b2`
- Return raw logits (no softmax here)
## Phase 3: Loss Computation
**`src/loss.rs`**
- `fn cross_entropy_loss<B: Backend>(logits: Tensor<B, 2>, labels: Tensor<B, 1, Int>) -> Tensor<B, 0>`
- Manual implementation—no `CrossEntropyLoss` module
- `softmax = exp(logits - max) / sum(exp(logits - max))`
- Index `softmax` by gold labels to get `p_correct`
- `loss = -mean(log(p_correct))`
## Phase 4: Backward Pass & SGD
**`src/train.rs`**
- `fn train_epoch<B: Backend>(params: &mut Parameters<B>, dataset: &[MnistItem], args: &Args)`
- For each batch:
1. `let loss = cross_entropy_loss(forward(&params, images), labels)`
2. `let grads = loss.backward()` — automatic differentiation
3. **Manual SGD**: `param = param - lr * grad` for each parameter
4. No `Optimizer`—raw gradient descent like Python
## Phase 5: Evaluation
**`src/eval.rs`**
- `fn evaluate<B: Backend>(params: &Parameters<B>, dataset: &[MnistItem]) -> f64`
- `argmax` on logits, compare to labels, return accuracy
## Phase 6: Main Training Loop
**Update `src/main.rs`**
- Parse args ✓ (done)
- Load data ✓ (done)
- Initialize `Parameters` with seed
- Loop `args.epochs`: `train_epoch``evaluate(dev)` → print
- Final `evaluate(test)`

1
hod_2/src/lib.rs Normal file
View File

@@ -0,0 +1 @@
pub mod model;

79
hod_2/src/main.rs Normal file
View File

@@ -0,0 +1,79 @@
use clap::Parser;
use std::fs::File;
use std::io::{Cursor, Read};
#[derive(Parser, Debug)]
#[command(author, version, about)]
struct Args {
#[arg(long, default_value_t = 50)]
batch_size: usize,
#[arg(long, default_value_t = 10)]
epochs: usize,
#[arg(long, default_value_t = 100)]
hidden_layer_size: usize,
#[arg(long, default_value_t = 0.1)]
learning_rate: f64,
#[arg(long, default_value_t = 42)]
seed: u64,
#[arg(long, default_value_t = 1)]
threads: usize,
}
fn load_mnist_items(path: &str, examples: usize) -> Vec<(Vec<f32>, u8)> {
let file = File::open(path).expect("Cannot open mnist.npz");
let mut archive = zip::ZipArchive::new(file).expect("Cannot read zip");
let image_names = ["train_images.npy", "train.images.npy", "x_train.npy", "images.npy"];
let mut image_bytes = Vec::new();
for name in &image_names {
if let Ok(mut entry) = archive.by_name(name) {
entry.read_to_end(&mut image_bytes).unwrap();
break;
}
}
let label_names = ["train_labels.npy", "train.labels.npy", "y_train.npy", "labels.npy"];
let mut label_bytes = Vec::new();
for name in &label_names {
if let Ok(mut entry) = archive.by_name(name) {
entry.read_to_end(&mut label_bytes).unwrap();
break;
}
}
let images_npy = npyz::NpyFile::new(Cursor::new(&image_bytes)).unwrap();
let shape = images_npy.shape().to_vec();
let n = shape[0] as usize;
let pixels = shape[1..].iter().product::<u64>() as usize;
let image_raw: Vec<u8> = images_npy.into_vec().unwrap();
let labels_npy = npyz::NpyFile::new(Cursor::new(&label_bytes)).unwrap();
let label_raw: Vec<u8> = labels_npy.into_vec().unwrap();
(0..n.min(examples))
.map(|i| {
let image: Vec<f32> = image_raw[i * pixels..(i + 1) * pixels]
.iter()
.map(|&p| p as f32 / 255.0)
.collect();
(image, label_raw[i])
})
.collect()
}
fn main() {
let args = Args::parse();
println!("Loading MNIST...");
let train_items = load_mnist_items("mnist.npz", 55_000);
let dev_items = load_mnist_items("mnist.npz", 5_000);
let test_items = load_mnist_items("mnist.npz", 10_000);
println!("Train: {}, Dev: {}, Test: {}", train_items.len(), dev_items.len(), test_items.len());
println!("Args: {:?}", args);
}

64
hod_2/src/model.rs Normal file
View File

@@ -0,0 +1,64 @@
use burn::tensor::{backend::Backend, Tensor};
use rand::{rngs::StdRng, SeedableRng};
use rand_distr::{Distribution, Normal};
/// Manual neural network parameters for SGD backpropagation.
/// No nn.Linear — just raw tensors to match the Python exercise.
pub struct Parameters<B: Backend> {
/// First layer weights: [784, hidden_layer_size]
pub w1: Tensor<B, 2>,
/// First layer biases: [hidden_layer_size]
pub b1: Tensor<B, 1>,
/// Second layer weights: [hidden_layer_size, 10]
pub w2: Tensor<B, 2>,
/// Second layer biases: [10]
pub b2: Tensor<B, 1>,
}
impl<B: Backend> Parameters<B> {
/// Initialize parameters with given hidden size and random seed.
/// Weights: randn * 0.1, Biases: zeros
pub fn new(device: &B::Device, hidden_size: usize, seed: u64) -> Self {
let w1 = random_tensor([784, hidden_size], 0.1, seed, device);
let b1 = Tensor::zeros([hidden_size], device);
let w2 = random_tensor([hidden_size, 10], 0.1, seed.wrapping_add(1), device);
let b2 = Tensor::zeros([10], device);
Self { w1, b1, w2, b2 }
}
/// Get all parameters as a vector for gradient updates.
/// Order: w1, b1, w2, b2
pub fn to_vec(&self) -> Vec<ParamRef<B>> {
vec![
ParamRef::TwoD(self.w1.clone()),
ParamRef::OneD(self.b1.clone()),
ParamRef::TwoD(self.w2.clone()),
ParamRef::OneD(self.b2.clone()),
]
}
}
/// Helper enum to handle 1D and 2D parameters uniformly.
pub enum ParamRef<B: Backend> {
OneD(Tensor<B, 1>),
TwoD(Tensor<B, 2>),
}
/// Create a random tensor with normal distribution, scaled by std_dev.
fn random_tensor<B: Backend, const D: usize>(
shape: [usize; D],
std_dev: f64,
seed: u64,
device: &B::Device,
) -> Tensor<B, D> {
let dist = Normal::new(0.0, std_dev).unwrap();
let mut rng = StdRng::seed_from_u64(seed);
let total: usize = shape.iter().product();
let data: Vec<f64> = (0..total).map(|_| dist.sample(&mut rng)).collect();
Tensor::from_floats(burn::tensor::TensorData::new(data, shape), device)
}