68 lines
2.5 KiB
Python
68 lines
2.5 KiB
Python
#!/usr/bin/env python3
|
|
import argparse
|
|
|
|
import torch
|
|
import torchmetrics
|
|
|
|
import npfl138
|
|
from npfl138.datasets.mnist import MNIST
|
|
# NOTE(review): presumably aborts early when the installed npfl138 package does not
# satisfy version "2526.1", which this script was written against — confirm the exact
# check inside npfl138.require_version.
npfl138.require_version("2526.1")
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script.

    Every option is an integer flag with a default value, so the options are
    declared from a table of ``(flag, default, help)`` triples, in order.
    """
    arg_parser = argparse.ArgumentParser()
    integer_options = (
        ("--batch_size", 50, "Batch size."),
        ("--epochs", 10, "Number of epochs."),
        ("--hidden_layer_size", 100, "Size of the hidden layer."),
        ("--seed", 42, "Random seed."),
        ("--threads", 1, "Maximum number of threads to use."),
    )
    for flag, default, description in integer_options:
        arg_parser.add_argument(flag, default=default, type=int, help=description)
    return arg_parser


# Module-level parser; the actual parsing happens in the `__main__` guard.
parser = _build_parser()
|
|
|
|
|
|
class Dataset(npfl138.TransformedDataset):
    """Turn raw MNIST examples into ``(input, target)`` pairs for the model."""

    def transform(self, example):
        """Convert one MNIST example into an ``(image, label)`` pair.

        The ``"image"`` entry (per the original author: a ``torch.uint8`` tensor
        with values in ``[0, 255]``) is cast to ``float32`` and divided by 255,
        giving values in ``[0, 1]``. The ``"label"`` entry (a tensor holding a
        single integer class index) is passed through unchanged.
        """
        pixels = example["image"].to(torch.float32) / 255
        return pixels, example["label"]
|
|
|
|
|
|
def main(args: argparse.Namespace) -> None:
    """Train a one-hidden-layer ReLU network on MNIST and evaluate it on the test split.

    Reads ``args.seed``, ``args.threads``, ``args.batch_size``,
    ``args.hidden_layer_size`` and ``args.epochs``. The dev loader is handed to
    ``fit`` (presumably for evaluation during training). The test loader is used
    once, after training, via ``evaluate``.
    """
    # Seed the random generators and limit the thread count before anything else runs.
    npfl138.startup(args.seed, args.threads)
    # NOTE(review): presumably switches default layer initializers to Keras-style ones.
    npfl138.global_keras_initializers()

    # Data: each split is wrapped in Dataset and batched; only the training split is shuffled.
    mnist = MNIST()

    def make_loader(split, shuffle=False):
        return torch.utils.data.DataLoader(Dataset(split), batch_size=args.batch_size, shuffle=shuffle)

    train = make_loader(mnist.train, shuffle=True)
    dev = make_loader(mnist.dev)
    test = make_loader(mnist.test)

    # Network: flatten the image, one hidden ReLU layer, then one output logit per class.
    layers = [
        torch.nn.Flatten(),
        torch.nn.Linear(MNIST.C * MNIST.H * MNIST.W, args.hidden_layer_size),
        torch.nn.ReLU(),
        torch.nn.Linear(args.hidden_layer_size, MNIST.LABELS),
    ]
    network = torch.nn.Sequential(*layers)
    print("The following model has been created:", network)

    # Wrap the network in TrainableModule and set the optimizer, loss and metrics.
    trainable = npfl138.TrainableModule(network)
    trainable.configure(
        optimizer=torch.optim.Adam(trainable.parameters()),
        loss=torch.nn.CrossEntropyLoss(),
        metrics={"accuracy": torchmetrics.Accuracy("multiclass", num_classes=MNIST.LABELS)},
    )

    # Train with dev monitoring, then report results on the held-out test split.
    trainable.fit(train, dev=dev, epochs=args.epochs)
    trainable.evaluate(test)
|
|
|
|
|
|
if __name__ == "__main__":
    # Without a `__file__` global (e.g. when run interactively), parse an empty argument
    # list instead of sys.argv, so only the defaults apply.
    if "__file__" in globals():
        cli_argv = None
    else:
        cli_argv = []
    main_args = parser.parse_args(cli_argv)
    main(main_args)
|