working pca_first

This commit is contained in:
Priec
2026-03-12 15:20:07 +01:00
parent f0b2073caa
commit 2174d4e506
4 changed files with 844 additions and 69 deletions

View File

@@ -1,63 +1,57 @@
use burn::tensor::Tensor;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader};
use burn::tensor::linalg::diag;
use burn::tensor::Shape;
pub type B = burn_ndarray::NdArray<f64>;
/// Load an empirical distribution from `data_path` (one data point per line)
/// and a model distribution from `model_path` (tab-separated `point\tprob`
/// lines), returning probability vectors (p, q) aligned point-by-point.
///
/// Points present in the data but missing from the model get q = 0.0.
/// Note: iteration order of the HashMap is unspecified, so the order of the
/// returned pairs is arbitrary (but p[i] and q[i] always refer to the same point).
pub fn load_and_align(data_path: &str, model_path: &str) -> (Vec<f64>, Vec<f64>) {
    // Count occurrences of each point in the data file.
    let mut counts = HashMap::new();
    let mut total_count = 0.0;
    let file_p = File::open(data_path).expect("Nepodarilo sa otvoriť dáta");
    for line in BufReader::new(file_p).lines() {
        let point = line.unwrap().trim().to_string();
        if point.is_empty() {
            continue;
        }
        *counts.entry(point).or_insert(0.0) += 1.0;
        total_count += 1.0;
    }
    // Parse the model file: "point<TAB>probability" per line; extra columns ignored.
    let mut model_map = HashMap::new();
    let file_q = File::open(model_path).expect("Nepodarilo sa otvoriť model");
    for line in BufReader::new(file_q).lines() {
        let l = line.unwrap();
        let parts: Vec<&str> = l.split('\t').collect();
        if parts.len() >= 2 {
            model_map.insert(parts[0].to_string(), parts[1].parse::<f64>().unwrap());
        }
    }
    // Build aligned vectors: empirical frequency vs. model probability.
    let mut p_vals = Vec::new();
    let mut q_vals = Vec::new();
    for (point, count) in counts.iter() {
        p_vals.push(count / total_count);
        q_vals.push(*model_map.get(point).unwrap_or(&0.0));
    }
    (p_vals, q_vals)
}
/// Euclidean (L2) norm of a 1-D tensor: sqrt(Σ v_i²).
fn l2_norm(v: Tensor<B, 1>) -> f64 {
    v.clone()
        .mul(v) // element-wise: v_i * v_i
        .sum() // sum of all v_i^2
        .sqrt() // square root
        .into_scalar() // extract as an f64 scalar (backend element type is f64)
}
pub fn entropy(p: Tensor<B, 1>) -> f64 {
let zero_mask = p.clone().equal_elem(0.0);
let p_safe = p.clone().mask_fill(zero_mask, 1.0);
let terms = p * p_safe.log();
-terms.sum().into_scalar()
/// Input: [N, 784], Output: [N, 784]
///
/// Subtract each column's mean from the column, so every feature
/// averages to zero across the N samples.
pub fn center(x: Tensor<B, 2>) -> Tensor<B, 2> {
    // mean_dim keeps the reduced axis: shape [1, 784], broadcast on subtraction.
    let col_means = x.clone().mean_dim(0);
    x - col_means
}
pub fn cross_entropy(p: Tensor<B, 1>, q: Tensor<B, 1>) -> f64 {
let zero_mask_q = q.clone().equal_elem(0.0);
let p_exists = p.clone().greater_elem(0.0);
if p_exists.bool_and(zero_mask_q.clone()).any().into_scalar() {
return f64::INFINITY;
/// Input: [N, 784], Output: [784, 784]
///
/// Unbiased sample covariance of the rows: Xᶜᵀ Xᶜ / (n − 1),
/// where Xᶜ is the column-centered data.
pub fn covariance(x: Tensor<B, 2>) -> Tensor<B, 2> {
    let centered = center(x);
    let samples = centered.dims()[0] as f64;
    centered
        .clone()
        .transpose()
        .matmul(centered)
        .div_scalar(samples - 1.0)
}
/// Total variance of the data: the trace of the covariance matrix,
/// i.e. the sum of per-feature variances.
pub fn total_variance(x: Tensor<B, 2>) -> f64 {
    let cov: Tensor<B, 2> = covariance(x);
    let variances: Tensor<B, 1> = diag(cov);
    variances.sum().into_scalar()
}
/// Power iteration on a symmetric matrix (here: a covariance matrix).
/// Input: [784, 784] matrix and an iteration count.
/// Output: ([784] dominant eigenvector approximation, dominant eigenvalue).
///
/// NOTE(review): with `iterations == 0` this returns the all-ones vector and
/// an eigenvalue of 0.0 — callers should pass at least one iteration.
pub fn power_iteration(cov: Tensor<B, 2>, iterations: usize) -> (Tensor<B, 1>, f64) {
    let n = cov.dims()[0];
    let device = cov.device();
    // Start from all ones: any vector with a nonzero component along the
    // dominant eigenvector converges.
    let mut v: Tensor<B, 1> = Tensor::ones(Shape::new([n]), &device);
    let mut s: f64 = 0.0;
    for _ in 0..iterations {
        // v ← A·v, done as a [n, 1] matmul and squeezed back to 1-D.
        let v_new_2d = cov.clone().matmul(v.reshape([n, 1]));
        let v_new = v_new_2d.squeeze::<1>();
        // Once v is (near-)normalised, ||A·v|| approximates the dominant eigenvalue.
        s = l2_norm(v_new.clone());
        v = v_new.div_scalar(s);
    }
    (v, s)
}
pub fn kl_div2(p: Tensor<B, 1>, q: Tensor<B, 1>) -> f64 {
let ce = cross_entropy(p.clone(), q);
let e = entropy(p);
let result = ce - e;
if result < 0.0 { 0.0 } else { result }
/// Fraction of the total variance captured by the dominant component.
/// Input: total variance (scalar) and the dominant eigenvalue (scalar).
/// Output: the ratio s / total_var as f64 (e.g. 0.25 for 25%).
pub fn explained_variance(total_var: f64, s: f64) -> f64 {
    s / total_var
}

View File

@@ -1,30 +1,86 @@
use burn::tensor::{backend::Backend, Tensor};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use clap::Parser;
// Replace 'hod_1' with your project's name from Cargo.toml
use hod_1::{load_and_align, entropy, cross_entropy, kl_div2, B};
use hod_1::{covariance, explained_variance, power_iteration, total_variance, B};
use std::fs::File;
use std::io::Read;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
#[arg(long = "data_path")]
data_path: String,
#[arg(long = "examples", default_value = "1024")]
examples: usize,
#[arg(long = "iterations", default_value = "64")]
iterations: usize,
}
#[arg(long = "model_path")]
model_path: String,
/// Load up to `examples` MNIST training images from `mnist.npz` in the
/// current directory, returning an [N, 784] tensor with pixel values
/// normalised from u8 into [0.0, 1.0].
///
/// Panics if the archive is missing, unreadable, or contains none of the
/// candidate array names.
fn load_mnist(examples: usize, device: &<B as Backend>::Device) -> Tensor<B, 2> {
    let file = File::open("mnist.npz").expect("Cannot open mnist.npz");
    let mut archive = zip::ZipArchive::new(file).expect("Cannot read zip");
    // Print all available array names so you can see what's inside.
    eprintln!("Arrays in mnist.npz:");
    for i in 0..archive.len() {
        eprintln!("  {}", archive.by_index(i).unwrap().name());
    }
    // Try the most common key names used for MNIST train images.
    let candidates = [
        "train_images.npy",
        "train.images.npy",
        "x_train.npy",
        "images.npy",
    ];
    let mut bytes = Vec::new();
    let mut found_name = "";
    for name in &candidates {
        // Single lookup: by_name returns Err when the entry is missing,
        // so there is no need for a separate existence probe.
        if let Ok(mut entry) = archive.by_name(name) {
            entry.read_to_end(&mut bytes).expect("Failed to read entry");
            found_name = name;
            break;
        }
    }
    assert!(
        !bytes.is_empty(),
        "Could not find train images — check the printed names above and update candidates[]"
    );
    eprintln!("Loaded from: {found_name}");
    // Parse the .npy header to get the shape.
    let npy = npyz::NpyFile::new(&bytes[..]).expect("Cannot parse npy");
    let shape = npy.shape().to_vec();
    eprintln!("Raw array shape: {shape:?}");
    // MNIST is stored as uint8 (0–255); we normalise to [0.0, 1.0].
    let raw: Vec<u8> = npy.into_vec().expect("Failed to read as u8 — dtype mismatch?");
    let n = examples.min(shape[0] as usize);
    let pixels = raw.len() / shape[0] as usize; // 784 = 1*28*28, regardless of how axes are ordered
    let data: Vec<f64> = raw[..n * pixels]
        .iter()
        .map(|&p| p as f64 / 255.0)
        .collect();
    eprintln!("Loaded {n} examples, {pixels} pixels each");
    let tensor_data = burn::tensor::TensorData::new(data, [n, pixels]);
    Tensor::<B, 2>::from_data(tensor_data, device)
}
/// Entry point: load MNIST, run one round of power-iteration PCA, and report
/// how much of the total variance the first principal component explains.
fn main() {
    let args = Args::parse();
    let device = <B as Backend>::Device::default();
    let x = load_mnist(args.examples, &device);
    // Covariance and total variance both need x; total_variance takes the
    // last use, so no clone is required there.
    let cov = covariance(x.clone());
    let total_var = total_variance(x);
    let (_pc, s) = power_iteration(cov, args.iterations);
    let ev = explained_variance(total_var, s);
    println!("Total variance: {:.2}", total_var);
    println!("Explained variance: {:.2}%", 100.0 * ev);
}