Phase 1: Core Data Structures
src/model.rs - Manual parameter management
- `struct Parameters<B: Backend>`: holds `w1`, `b1`, `w2`, `b2` as `Tensor<B, 2>`
- `impl Parameters`: initialization with `randn(0.1)` for weights, zeros for biases
- No `nn.Linear` - manual tensors to match the Python exercise
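A minimal sketch of what this could look like with Burn, assuming a 784-input, 10-class MLP, reading `randn(0.1)` as a normal distribution with standard deviation 0.1, and using an illustrative `hidden` size argument. The parameters are marked `require_grad()` so that the `backward()` call in Phase 4 can produce gradients for them:

```rust
use burn::tensor::{backend::Backend, Distribution, Tensor};

/// Manually managed parameters for a two-layer MLP (no `nn::Linear`).
pub struct Parameters<B: Backend> {
    pub w1: Tensor<B, 2>, // [784, hidden]
    pub b1: Tensor<B, 2>, // [1, hidden]
    pub w2: Tensor<B, 2>, // [hidden, 10]
    pub b2: Tensor<B, 2>, // [1, 10]
}

impl<B: Backend> Parameters<B> {
    /// Weights ~ N(0, 0.1), biases zero-initialized.
    /// `require_grad()` marks each tensor as a leaf that should receive gradients.
    pub fn new(hidden: usize, device: &B::Device) -> Self {
        Self {
            w1: Tensor::random([784, hidden], Distribution::Normal(0.0, 0.1), device)
                .require_grad(),
            b1: Tensor::zeros([1, hidden], device).require_grad(),
            w2: Tensor::random([hidden, 10], Distribution::Normal(0.0, 0.1), device)
                .require_grad(),
            b2: Tensor::zeros([1, 10], device).require_grad(),
        }
    }
}
```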
Phase 2: Forward Pass
src/forward.rs or in model.rs
- `fn forward<B: Backend>(params: &Parameters<B>, images: Tensor<B, 2>) -> Tensor<B, 2>`
- Cast `uint8` images to `f32`, divide by 255, flatten to `[batch, 784]`
- `hidden = tanh(images @ w1 + b1)`
- `logits = hidden @ w2 + b2`
- Return raw logits (no softmax here)
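A possible shape for the forward pass, reusing the `Parameters` struct sketched in Phase 1 and assuming the batch has already been assembled (u8 pixels cast to `f32` and flattened) into a `[batch, 784]` tensor; the 0-255 scaling is done inside the function:

```rust
use burn::tensor::{backend::Backend, Tensor};

/// logits = tanh(images @ w1 + b1) @ w2 + b2, returned without softmax.
/// `images` is expected as a float [batch, 784] tensor of raw 0..=255 pixel values.
pub fn forward<B: Backend>(params: &Parameters<B>, images: Tensor<B, 2>) -> Tensor<B, 2> {
    let x = images.div_scalar(255.0); // normalize pixels to [0, 1]
    let hidden = x.matmul(params.w1.clone()).add(params.b1.clone()).tanh();
    hidden.matmul(params.w2.clone()).add(params.b2.clone()) // raw logits
}
```

The biases are stored as `[1, n]` tensors so the `add` calls broadcast over the batch dimension.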
Phase 3: Loss Computation
src/loss.rs
- `fn cross_entropy_loss<B: Backend>(logits: Tensor<B, 2>, labels: Tensor<B, 1, Int>) -> Tensor<B, 1>`
- Manual implementation - no `CrossEntropyLoss` module
- `softmax = exp(logits - max) / sum(exp(logits - max))`
- Index `softmax` by gold labels to get `p_correct`
- `loss = -mean(log(p_correct))`
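A hand-rolled, numerically stable sketch under the same assumptions (10 classes, gold labels as a rank-1 `Int` tensor). The max-subtraction, the gather on gold labels, and the final negative mean mirror the bullets above:

```rust
use burn::tensor::{backend::Backend, Int, Tensor};

/// Numerically stable cross-entropy, implemented by hand (no `CrossEntropyLoss`).
/// Returns a single-element rank-1 tensor holding the mean loss.
pub fn cross_entropy_loss<B: Backend>(
    logits: Tensor<B, 2>,      // [batch, 10]
    labels: Tensor<B, 1, Int>, // [batch]
) -> Tensor<B, 1> {
    let [batch, _classes] = logits.dims();
    // Softmax with the per-row max subtracted for stability.
    let max = logits.clone().max_dim(1);            // [batch, 1]
    let shifted = logits - max;                     // broadcast subtract
    let exp = shifted.exp();
    let softmax = exp.clone().div(exp.sum_dim(1));  // [batch, 10]
    // Pick the probability assigned to the gold label of each row.
    let p_correct = softmax
        .gather(1, labels.reshape([batch, 1]))      // [batch, 1]
        .reshape([batch]);
    p_correct.log().mean().neg()                    // -mean(log(p_correct))
}
```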
Phase 4: Backward Pass & SGD
src/train.rs
- `fn train_epoch<B: AutodiffBackend>(params: &mut Parameters<B>, dataset: &[MnistItem], args: &Args)`
- For each batch:
  - `let loss = cross_entropy_loss(forward(&params, images), labels)`
  - `let grads = loss.backward()` - automatic differentiation
- Manual SGD: `param = param - lr * grad` for each parameter
- No `Optimizer` - raw gradient descent like Python
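A sketch of the update step. `loss.backward()` only exists on an autodiff-capable backend, so these helpers are bounded by `AutodiffBackend`; slicing the `MnistItem`s into `images`/`labels` tensors is assumed to happen in `train_epoch` itself and is omitted here. The extra `require_grad()` after each step keeps the updated tensors as gradient-tracked leaves for the next batch:

```rust
use burn::tensor::{backend::AutodiffBackend, Int, Tensor};

/// Manual SGD: p <- p - lr * grad, detached from the previous graph.
fn sgd_step<B: AutodiffBackend>(
    param: Tensor<B, 2>,
    grads: &B::Gradients,
    lr: f64,
) -> Tensor<B, 2> {
    let grad = param.grad(grads).expect("parameter missing from gradients");
    Tensor::from_inner(param.inner() - grad.mul_scalar(lr)).require_grad()
}

/// One optimization step on a prepared batch (images already [batch, 784] floats).
fn train_step<B: AutodiffBackend>(
    params: &mut Parameters<B>,
    images: Tensor<B, 2>,
    labels: Tensor<B, 1, Int>,
    lr: f64,
) {
    let loss = cross_entropy_loss(forward(params, images), labels);
    let grads = loss.backward(); // reverse-mode automatic differentiation
    params.w1 = sgd_step(params.w1.clone(), &grads, lr);
    params.b1 = sgd_step(params.b1.clone(), &grads, lr);
    params.w2 = sgd_step(params.w2.clone(), &grads, lr);
    params.b2 = sgd_step(params.b2.clone(), &grads, lr);
}
```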
Phase 5: Evaluation
src/eval.rs
- `fn evaluate<B: Backend>(params: &Parameters<B>, dataset: &[MnistItem]) -> f64`
- `argmax` on logits, compare to labels, return accuracy
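A sketch of the accuracy computation on one prepared batch; `evaluate` would batch the dataset, build logits via `forward`, and average these per-batch results (the helper name is illustrative):

```rust
use burn::tensor::{backend::Backend, ElementConversion, Int, Tensor};

/// Accuracy over one prepared batch: argmax over logits vs. gold labels.
fn batch_accuracy<B: Backend>(logits: Tensor<B, 2>, labels: Tensor<B, 1, Int>) -> f64 {
    let [batch, _classes] = logits.dims();
    let predicted = logits.argmax(1).reshape([batch]); // [batch], Int
    let correct: f64 = predicted
        .equal(labels) // element-wise comparison, Bool tensor
        .float()
        .sum()
        .into_scalar()
        .elem();
    correct / batch as f64
}
```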
Phase 6: Main Training Loop
Update src/main.rs
- Parse args ✓ (done)
- Load data ✓ (done)
- Initialize `Parameters` with seed
- Loop `args.epochs`: `train_epoch` → `evaluate(dev)` → print
- Final `evaluate(test)`
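A rough orchestration sketch for the loop, assuming the arg struct exposes `seed`, `hidden`, and `epochs` fields and that the dataset splits arrive as plain slices; everything besides `Parameters`, `train_epoch`, and `evaluate` is an illustrative name, and the already-done arg parsing and data loading are left to the existing `main`:

```rust
use burn::tensor::backend::AutodiffBackend;

/// Glue for src/main.rs once args and data are loaded.
/// `Args` fields used here (seed, hidden, epochs) are assumed, not confirmed.
fn run<B: AutodiffBackend>(
    args: &Args,
    train: &[MnistItem],
    dev: &[MnistItem],
    test: &[MnistItem],
    device: &B::Device,
) {
    B::seed(args.seed); // reproducible parameter initialization
    let mut params = Parameters::<B>::new(args.hidden, device);

    for epoch in 1..=args.epochs {
        train_epoch(&mut params, train, args);
        println!("epoch {epoch}: dev accuracy {:.4}", evaluate(&params, dev));
    }
    println!("test accuracy {:.4}", evaluate(&params, test));
}
```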