01 hotova

This commit is contained in:
Priec
2026-03-07 21:30:55 +01:00
parent 009e5c4925
commit f0b2073caa
15 changed files with 11753 additions and 48 deletions

View File

@@ -0,0 +1,67 @@
#!/usr/bin/env python3
import argparse
import torch
import torchmetrics
import npfl138
from npfl138.datasets.mnist import MNIST
npfl138.require_version("2526.1")
# Parse arguments
# Command-line options for the training run; every option has a default,
# so the script is runnable with no arguments at all.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", default=50, type=int, help="Batch size.")
parser.add_argument("--epochs", default=10, type=int, help="Number of epochs.")
parser.add_argument("--hidden_layer_size", default=100, type=int, help="Size of the hidden layer.")
parser.add_argument("--seed", default=42, type=int, help="Random seed.")
parser.add_argument("--threads", default=1, type=int, help="Maximum number of threads to use.")
class Dataset(npfl138.TransformedDataset):
    """Wraps raw MNIST examples, yielding (input, target) training pairs."""

    def transform(self, example):
        # Pixels arrive as torch.uint8 values in [0, 255]; convert them to
        # float32 and rescale into the [0, 1] range.
        pixels = example["image"].to(torch.float32) / 255
        # The target is the example's integer class label.
        return pixels, example["label"]
def main(args: argparse.Namespace) -> None:
    """Train a one-hidden-layer MLP on MNIST and evaluate it on the test set."""
    # Seed the RNGs and cap the thread count for reproducibility.
    npfl138.startup(args.seed, args.threads)
    npfl138.global_keras_initializers()

    # Build a dataloader over each of the transformed MNIST splits.
    mnist = MNIST()
    train_loader = torch.utils.data.DataLoader(
        Dataset(mnist.train), batch_size=args.batch_size, shuffle=True)
    dev_loader = torch.utils.data.DataLoader(Dataset(mnist.dev), batch_size=args.batch_size)
    test_loader = torch.utils.data.DataLoader(Dataset(mnist.test), batch_size=args.batch_size)

    # The classifier: flatten the image, one ReLU hidden layer, linear output head.
    network = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(MNIST.C * MNIST.H * MNIST.W, args.hidden_layer_size),
        torch.nn.ReLU(),
        torch.nn.Linear(args.hidden_layer_size, MNIST.LABELS),
    )
    print("The following model has been created:", network)

    # Wrap the network and configure the optimizer, loss, and metrics.
    model = npfl138.TrainableModule(network)
    model.configure(
        optimizer=torch.optim.Adam(model.parameters()),
        loss=torch.nn.CrossEntropyLoss(),
        metrics={"accuracy": torchmetrics.Accuracy("multiclass", num_classes=MNIST.LABELS)},
    )

    # Train on the training split while monitoring the development split.
    model.fit(train_loader, dev=dev_loader, epochs=args.epochs)

    # Finally, report the performance on the held-out test data.
    model.evaluate(test_loader)
if __name__ == "__main__":
    # When run interactively ("__file__" undefined), parse only the defaults;
    # otherwise parse the real command line.
    main_args = parser.parse_args([] if "__file__" not in globals() else None)
    main(main_args)

View File

@@ -0,0 +1,68 @@
#!/usr/bin/env python3
import argparse
import torch
import torchmetrics
import npfl138
from npfl138.datasets.mnist import MNIST
npfl138.require_version("2526.1")
# Parse arguments
# Training hyperparameters and reproducibility settings, all with defaults.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", default=50, type=int, help="Batch size.")
parser.add_argument("--epochs", default=10, type=int, help="Number of epochs.")
parser.add_argument("--hidden_layer_size", default=100, type=int, help="Size of the hidden layer.")
parser.add_argument("--seed", default=42, type=int, help="Random seed.")
parser.add_argument("--threads", default=1, type=int, help="Maximum number of threads to use.")
class Dataset(npfl138.TransformedDataset):
    """Transformed MNIST dataset producing normalized images and labels."""

    def transform(self, example):
        # The raw image is a torch.uint8 tensor in [0, 255]; normalize it
        # to a float32 tensor in [0, 1].
        normalized = example["image"].to(torch.float32) / 255
        target = example["label"]
        return normalized, target
def main(args: argparse.Namespace) -> None:
    """Train a single-hidden-layer MLP on MNIST with TensorBoard logging."""
    # Fix the seed and thread count, and switch to Keras-style initializers.
    npfl138.startup(args.seed, args.threads)
    npfl138.global_keras_initializers()

    # Prepare a dataloader for each of the three MNIST splits.
    mnist = MNIST()
    batch = args.batch_size
    training = torch.utils.data.DataLoader(Dataset(mnist.train), batch_size=batch, shuffle=True)
    development = torch.utils.data.DataLoader(Dataset(mnist.dev), batch_size=batch)
    testing = torch.utils.data.DataLoader(Dataset(mnist.test), batch_size=batch)

    # The classifier: a flattening layer, one ReLU hidden layer, and an output layer.
    layers = [
        torch.nn.Flatten(),
        torch.nn.Linear(MNIST.C * MNIST.H * MNIST.W, args.hidden_layer_size),
        torch.nn.ReLU(),
        torch.nn.Linear(args.hidden_layer_size, MNIST.LABELS),
    ]
    model = torch.nn.Sequential(*layers)
    print("The following model has been created:", model)

    # Wrap in a TrainableModule; set up optimization, loss, metrics, and the logdir.
    model = npfl138.TrainableModule(model)
    model.configure(
        optimizer=torch.optim.Adam(model.parameters()),
        loss=torch.nn.CrossEntropyLoss(),
        metrics={"accuracy": torchmetrics.Accuracy("multiclass", num_classes=MNIST.LABELS)},
        logdir=npfl138.format_logdir("logs/{file-}{timestamp}{-config}", **vars(args)),
    )

    # Train, logging both the configuration and the computation graph.
    model.fit(training, dev=development, epochs=args.epochs, log_config=vars(args), log_graph=True)

    # Evaluate the trained model on the test set.
    model.evaluate(testing)
if __name__ == "__main__":
    # Use default arguments in interactive sessions, real argv when run as a script.
    main_args = parser.parse_args([] if "__file__" not in globals() else None)
    main(main_args)

View File

@@ -0,0 +1,73 @@
#!/usr/bin/env python3
import argparse
import torch
import torchmetrics
import npfl138
npfl138.require_version("2526.1")
from npfl138.datasets.mnist import MNIST
parser = argparse.ArgumentParser()
# These arguments will be set appropriately by ReCodEx, even if you change them.
parser.add_argument("--activation", default="none", choices=["none", "relu", "tanh", "sigmoid"], help="Activation.")
parser.add_argument("--batch_size", default=50, type=int, help="Batch size.")
parser.add_argument("--epochs", default=10, type=int, help="Number of epochs.")
parser.add_argument("--hidden_layer_size", default=100, type=int, help="Size of the hidden layer.")
parser.add_argument("--hidden_layers", default=1, type=int, help="Number of layers.")
parser.add_argument("--recodex", default=False, action="store_true", help="Evaluation in ReCodEx.")
parser.add_argument("--seed", default=42, type=int, help="Random seed.")
parser.add_argument("--threads", default=1, type=int, help="Maximum number of threads to use.")
# If you add more arguments, ReCodEx will keep them with your default values.
class Dataset(npfl138.TransformedDataset):
    """Maps raw MNIST examples to (image, label) pairs for training."""

    def transform(self, example):
        # Convert uint8 pixels in [0, 255] to float32 in [0, 1] and pair the
        # result with the example's integer label.
        return example["image"].to(torch.float32) / 255, example["label"]
def main(args: argparse.Namespace) -> dict[str, float]:
    """Train an MLP with a configurable number of hidden layers on MNIST.

    The model is a flattening layer, `args.hidden_layers` fully connected
    layers of `args.hidden_layer_size` units each followed by the chosen
    `args.activation`, and a final linear layer with `MNIST.LABELS` outputs.

    Returns:
      The development-set metrics from the last epoch, for ReCodEx to validate.
    """
    # Set the random seed and the number of threads.
    npfl138.startup(args.seed, args.threads, args.recodex)
    npfl138.global_keras_initializers()

    # Load the data and create dataloaders.
    mnist = MNIST()
    train = torch.utils.data.DataLoader(Dataset(mnist.train), batch_size=args.batch_size, shuffle=True)
    dev = torch.utils.data.DataLoader(Dataset(mnist.dev), batch_size=args.batch_size)

    # Create the model.
    model = torch.nn.Sequential()
    model.append(torch.nn.Flatten())
    # Supported activation names mapped to their module classes; "none"
    # means no nonlinearity, which the identity module implements.
    activations = {
        "none": torch.nn.Identity,
        "relu": torch.nn.ReLU,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
    }
    input_size = MNIST.C * MNIST.H * MNIST.W
    for _ in range(args.hidden_layers):
        model.append(torch.nn.Linear(input_size, args.hidden_layer_size))
        model.append(activations[args.activation]())
        # Subsequent layers consume the previous hidden layer's outputs.
        input_size = args.hidden_layer_size
    model.append(torch.nn.Linear(input_size, MNIST.LABELS))

    # Create the TrainableModule and configure it for training.
    model = npfl138.TrainableModule(model)
    model.configure(
        optimizer=torch.optim.Adam(model.parameters()),
        loss=torch.nn.CrossEntropyLoss(),
        metrics={"accuracy": torchmetrics.Accuracy("multiclass", num_classes=MNIST.LABELS)},
        logdir=npfl138.format_logdir("logs/{file-}{timestamp}{-config}", **vars(args)),
    )

    # Train the model.
    logs = model.fit(train, dev=dev, epochs=args.epochs)

    # Return development metrics for ReCodEx to validate.
    return {metric: value for metric, value in logs.items() if metric.startswith("dev:")}
if __name__ == "__main__":
    # Parse only defaults when "__file__" is undefined (interactive/ReCodEx import).
    main_args = parser.parse_args([] if "__file__" not in globals() else None)
    main(main_args)

View File

@@ -0,0 +1,58 @@
#!/usr/bin/env python3
import argparse
import numpy as np
parser = argparse.ArgumentParser()
# These arguments will be set appropriately by ReCodEx, even if you change them.
parser.add_argument("--data_path", default="numpy_entropy_data.txt", type=str, help="Data distribution path.")
parser.add_argument("--model_path", default="numpy_entropy_model.txt", type=str, help="Model distribution path.")
parser.add_argument("--recodex", default=False, action="store_true", help="Evaluation in ReCodEx.")
# If you add more arguments, ReCodEx will keep them with your default values.
def main(args: argparse.Namespace) -> tuple[float, float, float]:
    """Compute entropy, cross-entropy, and KL-divergence of two distributions.

    The data distribution is estimated from datapoint frequencies in the file
    at `args.data_path` (one string per line); the model distribution is read
    from `args.model_path` (lines of `string \t probability`).

    Returns:
      A triple (entropy, crossentropy, kl_divergence) in nats. When some
      data datapoint is missing from the model distribution, the
      cross-entropy and the KL-divergence are `np.inf`.
    """
    # Load the data distribution, counting the occurrences of every datapoint.
    counts: dict[str, int] = {}
    with open(args.data_path, "r") as data:
        for line in data:
            line = line.rstrip("\n")
            counts[line] = counts.get(line, 0) + 1

    # Fix a deterministic ordering of the datapoints and turn the counts
    # into empirical probabilities.
    datapoints = sorted(counts)
    total = sum(counts.values())
    data_distribution = np.array([counts[d] / total for d in datapoints])

    # Load the model distribution, each line `string \t probability`.
    model_probabilities: dict[str, float] = {}
    with open(args.model_path, "r") as model:
        for line in model:
            line = line.rstrip("\n")
            datapoint, probability = line.split("\t")
            model_probabilities[datapoint] = float(probability)

    # Align the model probabilities with the datapoint ordering; datapoints
    # missing from the model get probability 0, which makes the cross-entropy
    # (and therefore the KL-divergence) infinite below via log(0) = -inf.
    model_distribution = np.array([model_probabilities.get(d, 0.0) for d in datapoints])

    # H(data) = -sum_x p(x) log p(x), computed vectorized without loops.
    # All entries of `data_distribution` are positive by construction.
    entropy = float(-np.sum(data_distribution * np.log(data_distribution)))

    # H(data, model) = -sum_x p(x) log q(x); suppress the divide-by-zero
    # warning that np.log(0) would otherwise emit -- the resulting -inf is
    # exactly what yields the required np.inf cross-entropy.
    with np.errstate(divide="ignore"):
        crossentropy = float(-np.sum(data_distribution * np.log(model_distribution)))

    # D_KL(data || model) = H(data, model) - H(data); inf - finite = inf.
    kl_divergence = crossentropy - entropy

    # Return the computed values for ReCodEx to validate.
    return entropy, crossentropy, kl_divergence
if __name__ == "__main__":
    # Use default arguments in interactive sessions, real argv otherwise.
    main_args = parser.parse_args([] if "__file__" not in globals() else None)
    entropy, crossentropy, kl_divergence = main(main_args)
    print(f"Entropy: {entropy:.2f} nats")
    print(f"Crossentropy: {crossentropy:.2f} nats")
    print(f"KL divergence: {kl_divergence:.2f} nats")

View File

@@ -0,0 +1,7 @@
A
BB
A
A
BB
A
CCC

View File

@@ -0,0 +1,7 @@
A
BB
A
A
BB
A
CCC

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,3 @@
BB 0.4
A 0.5
CCC 0.1

View File

@@ -0,0 +1,3 @@
BB 0.4
A 0.5
D 0.1

View File

@@ -0,0 +1,100 @@
ttvhw 0.0094312126288478
jvgsz 0.0010542688790326
dppgn 0.00155178394881855
gsrfj 0.00115959330379828
gonfc 0.0125594135756661
taplx 0.0130168483718944
grjbd 0.0118157969957567
ezlzh 0.000884396746443509
etbqm 0.0028883276682822
hbezl 0.0477149325045959
izndo 0.0136057551818149
emvms 0.00348555623538066
vvbgz 0.011074503885834
qlvmh 0.0324508321082636
xmjnf 0.00731021437537367
mzzkf 0.00587917090038903
qfdhm 0.0245761056970893
qwmks 0.00849433418640691
osehz 0.00337991233869235
dsgct 0.00303651707035691
mpubo 0.0133586334433875
fpnuy 4.5815030263525e-05
icwiv 0.0104882479671
usmyd 0.0143965053434422
tahse 0.00826831006992072
kpewx 0.00251578239615924
xtkrt 0.0107339052903878
iougj 0.0342521998718883
guhsk 0.0185906728923218
tykoo 0.0197009412084226
ehhoq 0.0353696768415327
vlquz 0.00115966059630282
wocje 0.0141062590065673
cpbgj 0.00219079990051944
ctugy 0.00378729517828768
aessn 0.000587148370086414
sqxvv 0.0116458117148786
rcruz 0.00226519368027114
bevod 0.00718954957464239
fufwi 0.0252245678495015
mkcbp 0.0146700693446102
fuayx 0.0131555291338195
gcltr 0.00174610698744372
esgda 0.00815650534385205
tajwq 0.00445130730008939
qcrek 0.00272711711160869
bulfd 0.00193564926804655
oyqll 0.00407246240697133
nfloq 0.00982791783138974
aiphx 0.00102055661540108
xiohq 0.000665516553375884
jnzwv 0.0159815501657243
ydniv 0.00740460630409081
zntnp 0.001347520113701
hpzog 0.00828471646149118
zpfwu 0.00259323364561152
fgnfx 0.00281632291480477
zexjt 0.00460980407970152
cqjir 0.00449780441058927
raabk 0.0168867125862489
pojke 0.000629586226187955
qhsix 0.0131909362211278
zxzaj 0.00643340353139025
dtpqp 0.0306079363057332
cgkms 0.0142372189818962
dpuri 0.0183070308798131
ysuqz 0.0121665423788785
tcclk 0.0170144702078789
tqwbr 0.00377004404546414
izoar 0.0407864649522166
egbxl 0.000800271672355991
nsjfo 0.00449051924902884
rgmnm 0.0506655416330107
lvynz 0.0159326679515115
rwlqi 0.000972123481280115
ypygl 0.00490944226804078
suddn 0.00796616445877217
yzgcc 0.000398622353240487
wtxyq 0.0067836100325681
huchd 0.00622457548430983
zhrco 0.00380565517326084
raffb 0.00216123484164195
xumnc 1.76957759973509e-05
peyxt 0.00331881717875047
keunp 0.00108034331566177
euofa 0.00273925508171951
rzacp 0.0412127910997975
mcktt 0.0122506696757427
boyoj 0.0123878873436728
wuidn 0.0109507122517976
omyfi 0.00125088875317538
fxwih 0.00289331010631559
zajpa 0.00207816129029445
qilpf 0.00757622675637339
eghhy 0.00147889398768911
xzxgz 0.0228182364589342
lfgnw 0.00765022250913878
wxaxe 0.0118444674613856
qpeyg 0.0164611528436398
ydmkt 0.00163874437311358

View File

@@ -0,0 +1,200 @@
aysirskl 0.000465969985327225
icjgrrej 0.00125399262934441
rpwecdml 0.00302176884392648
knqqumtg 0.00414001389703705
rijwsqio 0.00363108825124935
pjkouayb 0.000590694273926047
yhppippt 0.00281383589518343
gviqjnka 0.00596853741622785
rblbftch 0.0118925148629093
rpwoeola 0.00137284781891226
zkrxhkny 0.00514371341483877
evaytxev 0.000424463748174283
uoasbjrz 0.0368666520055359
faxbpnva 0.00413936957451408
ziyvyejl 0.00149656681939826
litrmtxc 0.00619374172395002
hlfnhahj 0.00338459417211952
xntjdryl 0.00966638405208417
gkqprfoc 0.00560911891937092
grxldlyk 0.00409549456009234
lklsancl 0.00323231239406044
syvktodv 0.0052132145962036
rqludekg 0.00644484194616499
uqsvrodw 0.00123606973097706
zryqedcy 0.00196828327955391
pslumxli 0.00183063547110444
tcuxhxaz 0.00583574881980228
nptbmfae 0.00358244966827256
wmrryjbh 0.00358044531741753
zusfguab 0.00275169767409786
nxknvpfz 0.00218936428227388
adydcnwt 0.00191009117751983
bgsrnutc 0.00243875740436847
aktfqryi 0.00274622822697218
czsayrkz 0.00245129923038723
snfnnaow 0.0170456098040683
fjffynzj 0.00460651599373612
uwkawwfh 0.000604574134472098
gdpurstk 0.0118908587120531
dqxddeox 0.00644999423047188
gjrnpmcc 0.0116106131733171
sghuynub 0.00426256348792279
vldohfjo 0.00356866558256043
mzwexwte 1.07961561013392e-05
dyslcyus 0.00128908650106376
pkonquix 0.000220433673243581
komsctkv 0.0023824578761889
opbtajsx 0.00398981801318325
nvgkafec 0.000529321417752203
bqekebug 0.00822060594854104
vzbhuxdb 0.012716052012467
onjszcec 0.0070519552054062
uyzkiyny 0.00663642136106938
lvpfctfp 0.00739145880453734
jqpyomdo 0.00602519295566007
dibzffyd 0.00166506991120574
bsvflmhn 0.00280000060286557
qeabeirk 0.00323521527495261
stnxvfgm 0.00209650697187905
pgtznfnf 0.00198824215693064
qxpfvmqq 0.00632503971872179
bqjirtjb 0.00710032224860597
ebnqbokg 0.0105976764736864
ifdqnine 0.00626844145375311
oarfqwxd 0.0061165404206265
jllhtiip 0.0132176299369683
sumgousm 0.0054742969637985
gifksbue 0.0105812688731108
zokasczg 0.00606818562710146
fwrqxmxw 0.000169467389301491
wqpfwzbj 0.000537291085070445
klhjcomg 0.0114437454569636
nfkgpovg 0.00246715557607813
vrjiljsq 0.00508852153669522
modgjwsr 0.019505200896841
lkwpaiik 0.00772415799941654
hpuguqca 0.000346480864675375
rodlhbtu 0.0119742929904463
sbvqorso 0.00291444366504115
utqcywou 0.000497243075444838
rcxcoclf 0.0183219799199796
clzquaoh 0.00875008206537465
hnpikhxo 0.00793749915555697
slfuffmf 0.00664340904976409
hmpjprva 0.0109556882297664
wrwlwwgs 0.0103183805128625
ocfnlnyk 0.00532839764792461
fgshzlkz 0.0137498707566837
llwmrzjl 0.00078175543757186
bnqqgrlp 0.0128729356486284
hcnggcmw 0.00143409059189367
wwatmyfh 0.0123189657702296
ybgaiwqm 0.00140099211351084
fpzxralq 0.00279045651387019
xlrcyrio 0.00227845165607012
cxzsyfuo 0.00267141299994114
gjhikwpn 0.000883501012657947
wlpbnexn 0.0027562191128744
jemonchd 0.00126825844735878
canuyblm 0.0216756574568504
pdbdosic 0.00180608741766137
cxorzqix 0.00212926341416582
rtledujy 0.00129871885173409
hejxlqrt 0.00573128888947967
kavwcfmu 3.54587299894723e-05
fatcfvpk 0.00723990686470467
bpplofjm 0.0041596511618454
posryrwj 0.0117680238680073
bvpnukag 0.00404497449067731
lbelzpbs 0.0062178236771778
gribnjjy 0.0109197239989149
stkfzrxm 0.00514831697438371
inbikasp 0.00284453578841593
cdxltwqx 0.00878268364294129
nvgiglpg 0.014784030751736
fmaafgvi 0.00204418798526165
qtyafvix 0.0137826948872999
erhbdnlg 0.00187229855731789
ulesdghl 0.00316583831999121
ebheqjpk 0.00266386141319316
gsmutwgq 0.00193798388627551
hfjbohcn 0.00723777956971534
kvanptal 0.00527843140542493
hsbfrqwc 0.00546974392745142
ruiquets 0.00168512464647176
veourctx 0.00180094566860724
nnshafed 0.00172147543036827
irohhwfg 0.00243495609454943
ivxoaqvg 0.000906434346262496
wnddbfsg 0.000226736591764822
osegqybm 0.00057255386567979
lrjysmbk 0.000139930203234663
sakueadt 0.00131143263623671
dhmytfyl 0.00102515492571535
fcctcaro 0.00722697313689681
bvggwzqp 0.0111959052719189
ajtczmtc 0.00214380561744269
dklpprvj 0.000723659838801874
scvfioku 0.00874660366923869
egrqonlm 0.000188777762207543
hpzlpxzg 0.000962056196345939
ntdwlyyu 0.00496070367855473
ygedcrzi 0.00182670850712668
fswersgr 0.00622626521429683
kndcdnkp 0.000674263651837786
qaoatkzw 0.003087492612335
krwmiegb 0.00855596722752998
rxpjkrob 0.00261830406705972
twtxssyg 0.013091523141205
qyalrqha 0.00503801619670401
ixmclekj 0.00107631312343053
tdfpaezp 0.0010257675376158
hvjrkgvn 0.00955381062591057
kqznhktp 0.00400096482424042
nuqwzjqi 0.000819740767639069
xjoydjkd 0.00395214578889989
zqcetrfj 0.00319829249514583
amhehfow 1.88587714546712e-05
sxmtlzlw 0.00663259259665281
qrrwjspo 0.00712841972374552
szhtpjyv 0.00156278271170568
verttiuu 0.00188167145628583
cxoftctg 0.00106191765565093
htuvrriy 0.00364447520438559
xqfvyctk 0.00449839948971503
kuipdodd 0.000411362629861999
gdqunmio 0.00264740859722539
sjyqvgwo 0.000428065184436695
rqgdglno 0.00175138550652958
aylkedwc 0.0119582316725472
ikutiphu 0.00576611726594471
qegytrys 0.018424201625129
zjfehspd 0.00139828990383295
hdbfrnha 0.00293768717793717
ssrktoar 0.000223708711653422
wcnhypxx 0.000694563090736388
egdyjbri 0.00107871204886407
stfwtrsa 0.00293697893838925
yaqvpcri 0.00303605758519147
irdnmbpc 0.0048176606001274
wtnecrbu 0.00459253359839141
lsijbdpo 0.00285817743547033
emvfrril 0.00422099180331
lmzuxenv 0.00898921604201908
ahpqjvoy 0.000549047186779937
mbpcgqyi 0.0139713446598273
cbutcdlm 0.00210694004038603
aybiktvj 0.00482170683266502
yhjoghiv 0.0024614271271427
umlwokvi 0.0012906962733002
pqxekaan 0.00585376668463907
xzdhqifz 0.00518329508250998
xpncwfyz 0.00616017953348638
hpuoygat 0.00626962514542393
ghvjpnkn 0.00674336788737429
elmvmqfx 0.00298649404059306
ewtkqwqc 0.00278733029070604
tmmjvjvw 0.0143408324546755
mxzsltag 0.00367522668308852
barlvanl 0.00468990081850378

82
hod_1/data/pca_first.py Normal file
View File

@@ -0,0 +1,82 @@
#!/usr/bin/env python3
import argparse
import numpy as np
import torch
import npfl138
from npfl138.datasets.mnist import MNIST
npfl138.require_version("2526.1")
parser = argparse.ArgumentParser()
# These arguments will be set appropriately by ReCodEx, even if you change them.
parser.add_argument("--examples", default=256, type=int, help="MNIST examples to use.")
parser.add_argument("--iterations", default=100, type=int, help="Iterations of the power algorithm.")
parser.add_argument("--recodex", default=False, action="store_true", help="Evaluation in ReCodEx.")
parser.add_argument("--seed", default=42, type=int, help="Random seed.")
parser.add_argument("--threads", default=1, type=int, help="Maximum number of threads to use.")
# If you add more arguments, ReCodEx will keep them with your default values.
def main(args: argparse.Namespace) -> tuple[float, float]:
    """Estimate the variance explained by the first principal component.

    Flattens `args.examples` MNIST training images, computes their covariance
    matrix, and approximates its largest eigenvalue with `args.iterations`
    steps of the power iteration algorithm.

    Returns:
      A pair (total variance, explained variance in percent) for ReCodEx
      to validate.
    """
    # Set the random seed and the number of threads.
    npfl138.startup(args.seed, args.threads, args.recodex)
    npfl138.global_keras_initializers()

    # Prepare the data.
    mnist = MNIST()
    data_indices = np.random.choice(len(mnist.train), size=args.examples, replace=False)
    data = mnist.train.data["images"][data_indices].to(torch.float32) / 255

    # Reshape [args.examples, MNIST.C, MNIST.H, MNIST.W] into a matrix of
    # flattened feature vectors [args.examples, MNIST.C * MNIST.H * MNIST.W].
    data = torch.reshape(data, [data.shape[0], data.shape[1] * data.shape[2] * data.shape[3]])

    # Mean of every feature, computed across the examples (the first dimension).
    mean = torch.mean(data, dim=0)

    # Covariance matrix: (data - mean)^T @ (data - mean) / data.shape[0].
    centered = data - mean
    cov = centered.t() @ centered / data.shape[0]

    # Total variance is the sum of the diagonal of the covariance matrix.
    total_variance = torch.sum(torch.diagonal(cov))

    # Power iteration, starting from a vector of ones. `s` is initialized
    # before the loop so the function is well-defined even for 0 iterations.
    v = torch.ones(cov.shape[0], dtype=torch.float32)
    s = torch.linalg.vector_norm(v)
    for _ in range(args.iterations):
        # Multiply by the covariance matrix, then renormalize to unit length;
        # the norm `s` converges to the largest eigenvalue.
        v = torch.mv(cov, v)
        s = torch.linalg.vector_norm(v)
        v = v / s

    # The `v` is now approximately the eigenvector of the largest eigenvalue, `s`.
    # We now compute the explained variance, which is the ratio of `s` and `total_variance`.
    explained_variance = s / total_variance

    # Return the total and explained variance for ReCodEx to validate
    return total_variance, 100 * explained_variance
if __name__ == "__main__":
    # Parse only defaults when "__file__" is undefined (interactive session).
    main_args = parser.parse_args([] if "__file__" not in globals() else None)
    total_variance, explained_variance = main(main_args)
    print(f"Total variance: {total_variance:.2f}")
    print(f"Explained variance: {explained_variance:.2f}%")