learner used instead of manual version

This commit is contained in:
Priec
2026-03-13 21:52:37 +01:00
parent f6b9d79062
commit 8fc8addcac
8 changed files with 835 additions and 203 deletions

546
hod_1/Cargo.lock generated
View File

@@ -274,11 +274,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b78ff10ed98b73e1d477ea6e6e1ec1b9cf9f71a17afc3fea9f4dca482d43dcd4"
dependencies = [
"burn-autodiff",
"burn-candle",
"burn-core",
"burn-cpu",
"burn-cuda",
"burn-ndarray",
"burn-nn",
"burn-optim",
"burn-rocm",
"burn-router",
"burn-store",
"burn-train",
"burn-wgpu",
]
[[package]]
@@ -317,6 +324,18 @@ dependencies = [
"thiserror 2.0.18",
]
[[package]]
name = "burn-candle"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21c752d5008923eb9299783da5edae3242b94afdb956e88d2b37b025244b5071"
dependencies = [
"burn-backend",
"burn-std",
"candle-core",
"derive-new",
]
[[package]]
name = "burn-core"
version = "0.20.1"
@@ -347,6 +366,64 @@ dependencies = [
"uuid",
]
[[package]]
name = "burn-cpu"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60aa53c4536719f1c91c250d4b4348daca473c44cf0c45b81096785f5510c192"
dependencies = [
"burn-backend",
"burn-cubecl",
"cubecl",
]
[[package]]
name = "burn-cubecl"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d6d13aff03fec966da4300459688883f8a1d741dddbf19d1bfc2562656a9a9b"
dependencies = [
"burn-backend",
"burn-cubecl-fusion",
"burn-fusion",
"burn-ir",
"burn-std",
"cubecl",
"cubek",
"derive-new",
"futures-lite",
"log",
"serde",
"text_placeholder",
]
[[package]]
name = "burn-cubecl-fusion"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17d25b2e9fb805931401f79782aabd92462d65e60bc207035a3e554de8d7cd9f"
dependencies = [
"burn-backend",
"burn-fusion",
"burn-ir",
"burn-std",
"cubecl",
"cubek",
"derive-new",
"serde",
]
[[package]]
name = "burn-cuda"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d0c68cf653eb9c27dcbe046bb7b04cc18c6b33afda4c09317c102e6f4ae7cb6"
dependencies = [
"burn-backend",
"burn-cubecl",
"cubecl",
]
[[package]]
name = "burn-dataset"
version = "0.20.1"
@@ -378,6 +455,22 @@ dependencies = [
"syn",
]
[[package]]
name = "burn-fusion"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ea83d7f8574bcc07967291c5bb679ddc0a655c8db0642eca62755e2fffc8047"
dependencies = [
"burn-backend",
"burn-ir",
"derive-new",
"hashbrown 0.16.1",
"log",
"serde",
"spin",
"tracing",
]
[[package]]
name = "burn-ir"
version = "0.20.1"
@@ -440,6 +533,31 @@ dependencies = [
"serde",
]
[[package]]
name = "burn-rocm"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73e2abda6ee63bdcb730f1a335349a9ff83f03048130d405b6ecdccd2df3ff23"
dependencies = [
"burn-backend",
"burn-cubecl",
"cubecl",
]
[[package]]
name = "burn-router"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "823ccb88484736a2861d53dc7f67db375ef050b0446bb02dd7cb8783ac6b69a2"
dependencies = [
"burn-backend",
"burn-ir",
"burn-std",
"hashbrown 0.16.1",
"log",
"spin",
]
[[package]]
name = "burn-std"
version = "0.20.1"
@@ -455,6 +573,25 @@ dependencies = [
"serde",
]
[[package]]
name = "burn-store"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4be80a7b084a19901dc1d0a2e9b77e226c5c575879fe66de891c67062db41a6d"
dependencies = [
"burn-core",
"burn-nn",
"burn-tensor",
"byteorder",
"bytes",
"half",
"hashbrown 0.16.1",
"memmap2",
"regex",
"safetensors",
"textdistance",
]
[[package]]
name = "burn-tensor"
version = "0.20.1"
@@ -489,6 +626,17 @@ dependencies = [
"tracing-subscriber",
]
[[package]]
name = "burn-wgpu"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df78d62afc9b9fbb8ee4e49b72006485bb64f778a790e185a2d919479bcfc008"
dependencies = [
"burn-backend",
"burn-cubecl",
"cubecl",
]
[[package]]
name = "bytemuck"
version = "1.25.0"
@@ -553,6 +701,29 @@ dependencies = [
"pkg-config",
]
[[package]]
name = "candle-core"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c15b675b80d994b2eadb20a4bbe434eabeb454eac3ee5e2b4cf6f147ee9be091"
dependencies = [
"byteorder",
"float8 0.6.1",
"gemm",
"half",
"libm",
"memmap2",
"num-traits",
"num_cpus",
"rand",
"rand_distr",
"rayon",
"safetensors",
"thiserror 2.0.18",
"yoke",
"zip 7.2.0",
]
[[package]]
name = "caseless"
version = "0.2.2"
@@ -895,6 +1066,7 @@ dependencies = [
"cubecl-hip",
"cubecl-ir",
"cubecl-runtime",
"cubecl-std",
"cubecl-wgpu",
"half",
]
@@ -916,7 +1088,7 @@ dependencies = [
"embassy-futures",
"embassy-time",
"float4",
"float8",
"float8 0.4.2",
"futures-lite",
"half",
"hashbrown 0.15.5",
@@ -1199,6 +1371,104 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a0f819071413b19a00b7105497e0f6d2cf3e7e9d65cbb8d4ecf1ddb29c61dc2"
[[package]]
name = "cubek"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bb1cce47db02017925301bedec92ae84628493df3f9761ea7ac42a60c6146f8"
dependencies = [
"cubecl",
"cubek-attention",
"cubek-convolution",
"cubek-matmul",
"cubek-quant",
"cubek-random",
"cubek-reduce",
]
[[package]]
name = "cubek-attention"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7278bd122b2428af479f9af05285160613733c33c93b63ab3c6d25cd0460c18b"
dependencies = [
"bytemuck",
"cubecl",
"cubecl-common",
"cubek-matmul",
"cubek-random",
"half",
"serde",
]
[[package]]
name = "cubek-convolution"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18eb04bca4ae104d62a56def04b04f3d079c42fe49aac62202c96876f90fa28b"
dependencies = [
"bytemuck",
"cubecl",
"cubecl-common",
"cubek-matmul",
"derive-new",
"enumset",
"half",
"serde",
]
[[package]]
name = "cubek-matmul"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a28f3b04b113760e97c65a8a4dca9afc220744031eeecd5ad6cd0e3be91ba3a9"
dependencies = [
"bytemuck",
"cubecl",
"cubecl-common",
"half",
"serde",
]
[[package]]
name = "cubek-quant"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96ec3ae04af324df2d615c2b394e270d58d6f08cb833d67633e2ba794de75916"
dependencies = [
"cubecl",
"cubecl-common",
"half",
"serde",
]
[[package]]
name = "cubek-random"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65a34844d8b7f739185c1d24896137dcb73f458830444103b45f678585ad983e"
dependencies = [
"cubecl",
"cubecl-common",
"half",
"num-traits",
"rand",
"serde",
]
[[package]]
name = "cubek-reduce"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42397d9ed85bb3084dfb56ed26de75690b5b07caf42a32f4006b57eb23d5b6d6"
dependencies = [
"cubecl",
"half",
"num-traits",
"serde",
"thiserror 2.0.18",
]
[[package]]
name = "cudarc"
version = "0.18.2"
@@ -1391,6 +1661,22 @@ dependencies = [
"litrs",
]
[[package]]
name = "dyn-stack"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c4713e43e2886ba72b8271aa66c93d722116acf7a75555cce11dcde84388fe8"
dependencies = [
"bytemuck",
"dyn-stack-macros",
]
[[package]]
name = "dyn-stack-macros"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1d926b4d407d372f141f93bb444696142c29d32962ccbd3531117cf3aa0bfa9"
[[package]]
name = "either"
version = "1.15.0"
@@ -1459,6 +1745,18 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5320ae4c3782150d900b79807611a59a99fc9a1d61d686faafc24b93fc8d7ca"
[[package]]
name = "enum-as-inner"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "enumset"
version = "1.1.10"
@@ -1579,6 +1877,18 @@ dependencies = [
"half",
]
[[package]]
name = "float8"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719a903cc23e4a89e87962c2a80fdb45cdaad0983a89bd150bb57b4c8571a7d5"
dependencies = [
"half",
"num-traits",
"rand",
"rand_distr",
]
[[package]]
name = "fnv"
version = "1.0.7"
@@ -1713,6 +2023,125 @@ dependencies = [
"slab",
]
[[package]]
name = "gemm"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa0673db364b12263d103b68337a68fbecc541d6f6b61ba72fe438654709eacb"
dependencies = [
"dyn-stack",
"gemm-c32",
"gemm-c64",
"gemm-common",
"gemm-f16",
"gemm-f32",
"gemm-f64",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"seq-macro",
]
[[package]]
name = "gemm-c32"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "086936dbdcb99e37aad81d320f98f670e53c1e55a98bee70573e83f95beb128c"
dependencies = [
"dyn-stack",
"gemm-common",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"seq-macro",
]
[[package]]
name = "gemm-c64"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20c8aeeeec425959bda4d9827664029ba1501a90a0d1e6228e48bef741db3a3f"
dependencies = [
"dyn-stack",
"gemm-common",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"seq-macro",
]
[[package]]
name = "gemm-common"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88027625910cc9b1085aaaa1c4bc46bb3a36aad323452b33c25b5e4e7c8e2a3e"
dependencies = [
"bytemuck",
"dyn-stack",
"half",
"libm",
"num-complex",
"num-traits",
"once_cell",
"paste",
"pulp",
"raw-cpuid",
"rayon",
"seq-macro",
"sysctl",
]
[[package]]
name = "gemm-f16"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3df7a55202e6cd6739d82ae3399c8e0c7e1402859b30e4cb780e61525d9486e"
dependencies = [
"dyn-stack",
"gemm-common",
"gemm-f32",
"half",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"rayon",
"seq-macro",
]
[[package]]
name = "gemm-f32"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02e0b8c9da1fbec6e3e3ab2ce6bc259ef18eb5f6f0d3e4edf54b75f9fd41a81c"
dependencies = [
"dyn-stack",
"gemm-common",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"seq-macro",
]
[[package]]
name = "gemm-f64"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "056131e8f2a521bfab322f804ccd652520c79700d81209e9d9275bbdecaadc6a"
dependencies = [
"dyn-stack",
"gemm-common",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"seq-macro",
]
[[package]]
name = "generic-array"
version = "0.14.7"
@@ -1889,10 +2318,21 @@ dependencies = [
"cfg-if",
"crunchy",
"num-traits",
"rand",
"rand_distr",
"serde",
"zerocopy",
]
[[package]]
name = "hashbrown"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e"
dependencies = [
"ahash",
]
[[package]]
name = "hashbrown"
version = "0.15.5"
@@ -1955,6 +2395,7 @@ dependencies = [
"clap",
"ndarray",
"npyz",
"serde",
"zip 8.2.0",
]
@@ -2466,6 +2907,16 @@ version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
[[package]]
name = "memmap2"
version = "0.9.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "744133e4a0e0a658e1374cf3bf8e415c4052a15a111acd372764c55b4177d490"
dependencies = [
"libc",
"stable_deref_trait",
]
[[package]]
name = "metal"
version = "0.32.0"
@@ -2650,6 +3101,7 @@ version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"bytemuck",
"num-traits",
]
@@ -3013,6 +3465,29 @@ version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3eb8486b569e12e2c32ad3e204dbaba5e4b5b216e9367044f25f1dba42341773"
[[package]]
name = "pulp"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e205bb30d5b916c55e584c22201771bcf2bad9aabd5d4127f38387140c38632"
dependencies = [
"bytemuck",
"cfg-if",
"libm",
"num-complex",
"paste",
"pulp-wasm-simd-flag",
"raw-cpuid",
"reborrow",
"version_check",
]
[[package]]
name = "pulp-wasm-simd-flag"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40e24eee682d89fb193496edf918a7f407d30175b2e785fe057e4392dfd182e0"
[[package]]
name = "py_literal"
version = "0.4.0"
@@ -3153,6 +3628,15 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca45419789ae5a7899559e9512e58ca889e41f04f1f2445e9f4b290ceccd1d08"
[[package]]
name = "raw-cpuid"
version = "11.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186"
dependencies = [
"bitflags",
]
[[package]]
name = "raw-window-handle"
version = "0.6.2"
@@ -3185,6 +3669,12 @@ dependencies = [
"crossbeam-utils",
]
[[package]]
name = "reborrow"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
[[package]]
name = "redox_syscall"
version = "0.5.18"
@@ -3445,6 +3935,17 @@ version = "1.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f"
[[package]]
name = "safetensors"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "675656c1eabb620b921efea4f9199f97fc86e36dd6ffd1fbbe48d0f59a4987f5"
dependencies = [
"hashbrown 0.16.1",
"serde",
"serde_json",
]
[[package]]
name = "same-file"
version = "1.0.6"
@@ -3740,6 +4241,20 @@ dependencies = [
"syn",
]
[[package]]
name = "sysctl"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01198a2debb237c62b6826ec7081082d951f46dbb64b0e8c7649a452230d1dfc"
dependencies = [
"bitflags",
"byteorder",
"enum-as-inner",
"libc",
"thiserror 1.0.69",
"walkdir",
]
[[package]]
name = "sysinfo"
version = "0.36.1"
@@ -3787,6 +4302,23 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "text_placeholder"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd5008f74a09742486ef0047596cf35df2b914e2a8dca5727fcb6ba6842a766b"
dependencies = [
"hashbrown 0.13.2",
"serde",
"serde_json",
]
[[package]]
name = "textdistance"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa672c55ab69f787dbc9126cc387dbe57fdd595f585e4524cf89018fa44ab819"
[[package]]
name = "thiserror"
version = "1.0.69"
@@ -5294,6 +5826,18 @@ dependencies = [
"zstd 0.11.2+zstd.1.5.2",
]
[[package]]
name = "zip"
version = "7.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c42e33efc22a0650c311c2ef19115ce232583abbe80850bc8b66509ebef02de0"
dependencies = [
"crc32fast",
"indexmap",
"memchr",
"typed-path",
]
[[package]]
name = "zip"
version = "8.2.0"