hod2
50
hod_2/plan.md
Normal file
@@ -0,0 +1,50 @@
## Phase 1: Core Data Structures

**`src/model.rs`** - Manual parameter management

- `struct Parameters<B: Backend>`: holds `w1`, `b1`, `w2`, `b2` as `Tensor<B, 2>`
- `impl Parameters`: initialization with `randn(0.1)` for weights, zeros for biases
- No `nn.Linear`; manual tensors to match the Python exercise

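
The parameter layout and initialization above can be sketched without any tensor library. This is only an illustration under stated assumptions: it uses flat row-major `Vec<f32>` buffers instead of `Tensor<B, 2>`, reads `randn(0.1)` as "normal with standard deviation 0.1", and introduces a hypothetical `Rng` helper (SplitMix64 plus Box-Muller) so that no external crate is needed. The constructor arguments, including the hidden size, are not specified in the plan.

```rust
/// Hypothetical plain-Rust stand-in for the planned `Parameters<B>`.
/// Matrices are stored row-major in flat vectors.
pub struct Parameters {
    pub w1: Vec<f32>, // [input, hidden]
    pub b1: Vec<f32>, // [hidden]
    pub w2: Vec<f32>, // [hidden, classes]
    pub b2: Vec<f32>, // [classes]
}

/// Small deterministic generator (SplitMix64), an assumption of this sketch.
pub struct Rng(u64);

impl Rng {
    pub fn new(seed: u64) -> Self {
        Rng(seed)
    }

    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Uniform sample in the open interval (0, 1).
    fn uniform(&mut self) -> f64 {
        ((self.next_u64() >> 11) as f64 + 0.5) / (1u64 << 53) as f64
    }

    /// Standard normal sample via the Box-Muller transform.
    pub fn normal(&mut self) -> f64 {
        let (u1, u2) = (self.uniform(), self.uniform());
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }
}

impl Parameters {
    /// Weights ~ N(0, 0.1^2), biases zero; the same seed gives the same weights.
    pub fn new(input: usize, hidden: usize, classes: usize, seed: u64) -> Self {
        let mut rng = Rng::new(seed);
        let mut randn = |n: usize| -> Vec<f32> {
            (0..n).map(|_| (0.1 * rng.normal()) as f32).collect()
        };
        let w1 = randn(input * hidden);
        let w2 = randn(hidden * classes);
        Parameters {
            w1,
            b1: vec![0.0; hidden],
            w2,
            b2: vec![0.0; classes],
        }
    }
}
```

Calling `Parameters::new(784, 100, 10, 42)` would give a 784-100-10 network; the value 100 is only an example. In the actual plan the same shapes would live in backend tensors so that automatic differentiation can track them.
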
## Phase 2: Forward Pass

**`src/forward.rs`** or in `model.rs`

- `fn forward<B: Backend>(params: &Parameters<B>, images: Tensor<B, 2>) -> Tensor<B, 2>`
- Cast `uint8` images to `f32`, divide by 255, flatten to `[batch, 784]`
- `hidden = tanh(images @ w1 + b1)`
- `logits = hidden @ w2 + b2`
- Return raw logits (no softmax here)

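
To make the forward-pass arithmetic concrete, here is a minimal sketch on plain slices rather than backend tensors. The `Mlp` struct and the `affine` helper are hypothetical names introduced for this example; the images are assumed to arrive as an already flattened `&[u8]` of length `batch * input`.

```rust
/// Hypothetical plain-Rust stand-in for `Parameters<B>`; matrices are row-major.
pub struct Mlp {
    pub w1: Vec<f32>, // [input, hidden]
    pub b1: Vec<f32>, // [hidden]
    pub w2: Vec<f32>, // [hidden, classes]
    pub b2: Vec<f32>, // [classes]
}

/// Computes `x @ w + bias` for row-major `x` of shape [rows, inner]
/// and `w` of shape [inner, cols]; the bias is broadcast over rows.
fn affine(x: &[f32], w: &[f32], bias: &[f32], rows: usize) -> Vec<f32> {
    let cols = bias.len();
    let inner = w.len() / cols;
    assert_eq!(x.len(), rows * inner, "input shape mismatch");
    let mut out = Vec::with_capacity(rows * cols);
    for r in 0..rows {
        for c in 0..cols {
            let mut acc = bias[c];
            for k in 0..inner {
                acc += x[r * inner + k] * w[k * cols + c];
            }
            out.push(acc);
        }
    }
    out
}

/// Forward pass: returns raw logits of shape [batch, classes], no softmax.
pub fn forward(params: &Mlp, images: &[u8], batch: usize) -> Vec<f32> {
    // Cast uint8 pixels to f32 and scale into [0, 1].
    let x: Vec<f32> = images.iter().map(|&p| f32::from(p) / 255.0).collect();
    // hidden = tanh(images @ w1 + b1)
    let hidden: Vec<f32> = affine(&x, &params.w1, &params.b1, batch)
        .into_iter()
        .map(f32::tanh)
        .collect();
    // logits = hidden @ w2 + b2
    affine(&hidden, &params.w2, &params.b2, batch)
}
```

Returning logits rather than probabilities keeps the numerically sensitive softmax inside the loss function, where the max-subtraction trick can be applied.
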
## Phase 3: Loss Computation

**`src/loss.rs`**

- `fn cross_entropy_loss<B: Backend>(logits: Tensor<B, 2>, labels: Tensor<B, 1, Int>) -> Tensor<B, 1>` (Burn's `mean()` returns a single-element rank-1 tensor, so the scalar loss is a `Tensor<B, 1>` rather than a `Tensor<B, 0>`)
- Manual implementation, no `CrossEntropyLoss` module
- `softmax = exp(logits - max) / sum(exp(logits - max))`, with `max` and `sum` taken per row over the class dimension
- Index `softmax` by gold labels to get `p_correct`
- `loss = -mean(log(p_correct))`

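
The loss steps above can be sketched on plain slices as follows. This is an illustration, not the planned tensor implementation: logits are assumed to be a row-major `[batch, classes]` buffer, labels are plain `usize` indices, and the `classes` argument is added for this example.

```rust
/// Mean cross-entropy over a batch of row-major logits [batch, classes].
/// `labels[i]` is the gold class index of row `i`.
pub fn cross_entropy_loss(logits: &[f32], labels: &[usize], classes: usize) -> f32 {
    assert_eq!(logits.len(), labels.len() * classes, "shape mismatch");
    let mut total = 0.0f32;
    for (row, &label) in logits.chunks(classes).zip(labels) {
        // Subtract the row maximum so exp() cannot overflow.
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = row.iter().map(|&z| (z - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        // Index the softmax by the gold label to get p_correct.
        let p_correct = exps[label] / sum;
        total -= p_correct.ln();
    }
    // loss = -mean(log(p_correct))
    total / labels.len() as f32
}
```

With all-zero logits over two classes the loss is `ln 2`, which is a handy sanity check for an untrained model. If `p_correct` underflows to zero the logarithm becomes infinite; computing `(z[label] - max) - ln(sum)` directly is a common way to avoid that.
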
## Phase 4: Backward Pass & SGD

**`src/train.rs`**

- `fn train_epoch<B: Backend>(params: &mut Parameters<B>, dataset: &[MnistItem], args: &Args)`
- For each batch:
  1. `let loss = cross_entropy_loss(forward(&params, images), labels)`
  2. `let grads = loss.backward()`: automatic differentiation (in Burn this requires `B: AutodiffBackend`, not just `B: Backend`)
  3. **Manual SGD**: `param = param - lr * grad` for each parameter
  4. No `Optimizer`: raw gradient descent like Python

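
The manual SGD rule `param = param - lr * grad` can be illustrated in isolation. Since plain Rust has no automatic differentiation, this sketch replaces `loss.backward()` with a hand-written gradient of a toy quadratic objective; `quadratic_grad` and `minimize` are hypothetical helpers that exist only to exercise the update.

```rust
/// In-place SGD update: param <- param - lr * grad, element by element.
pub fn sgd_step(param: &mut [f32], grad: &[f32], lr: f32) {
    assert_eq!(param.len(), grad.len(), "gradient shape mismatch");
    for (p, g) in param.iter_mut().zip(grad) {
        *p -= lr * g;
    }
}

/// Gradient of the toy objective f(w) = sum_i (w_i - target_i)^2,
/// written by hand in place of `loss.backward()`.
pub fn quadratic_grad(w: &[f32], target: &[f32]) -> Vec<f32> {
    w.iter().zip(target).map(|(wi, ti)| 2.0 * (wi - ti)).collect()
}

/// Runs `steps` SGD updates starting from `w` and returns the result.
pub fn minimize(mut w: Vec<f32>, target: &[f32], lr: f32, steps: usize) -> Vec<f32> {
    for _ in 0..steps {
        let grad = quadratic_grad(&w, target);
        sgd_step(&mut w, &grad, lr);
    }
    w
}
```

In the training loop the same `sgd_step` logic would be applied to each of `w1`, `b1`, `w2`, and `b2` once per batch, using the gradients produced by the backward pass.
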
## Phase 5: Evaluation

**`src/eval.rs`**

- `fn evaluate<B: Backend>(params: &Parameters<B>, dataset: &[MnistItem]) -> f64`
- `argmax` on logits, compare to labels, return accuracy

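
The accuracy computation can be sketched as below, assuming the logits have already been produced by the forward pass and flattened into a row-major `[batch, classes]` slice. The function name `accuracy` and its signature are illustrative; the planned `evaluate` would wrap this logic around the model and dataset.

```rust
/// Index of the largest value in `row` (the first one wins on ties).
fn argmax(row: &[f32]) -> usize {
    let mut best = 0;
    for (i, &v) in row.iter().enumerate() {
        if v > row[best] {
            best = i;
        }
    }
    best
}

/// Fraction of rows whose argmax equals the gold label.
pub fn accuracy(logits: &[f32], labels: &[usize], classes: usize) -> f64 {
    assert_eq!(logits.len(), labels.len() * classes, "shape mismatch");
    if labels.is_empty() {
        return 0.0;
    }
    let mut correct = 0usize;
    for (row, &label) in logits.chunks(classes).zip(labels) {
        if argmax(row) == label {
            correct += 1;
        }
    }
    correct as f64 / labels.len() as f64
}
```

Because softmax is monotonic, taking `argmax` over raw logits gives the same prediction as taking it over probabilities, so evaluation needs no softmax.
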
## Phase 6: Main Training Loop

**Update `src/main.rs`**

- Parse args ✓ (done)
- Load data ✓ (done)
- Initialize `Parameters` with seed
- Loop `args.epochs`: `train_epoch` → `evaluate(dev)` → print
- Final `evaluate(test)`