01 hotova

This commit is contained in:
Priec
2026-03-07 21:30:55 +01:00
parent 009e5c4925
commit f0b2073caa
15 changed files with 11753 additions and 48 deletions

63
hod_1/src/lib.rs Normal file
View File

@@ -0,0 +1,63 @@
use burn::tensor::Tensor;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader};
pub type B = burn_ndarray::NdArray<f64>;
// Funkcia na načítanie a zarovnanie dát
pub fn load_and_align(data_path: &str, model_path: &str) -> (Vec<f64>, Vec<f64>) {
let mut counts = HashMap::new();
let mut total_count = 0.0;
let file_p = File::open(data_path).expect("Nepodarilo sa otvoriť dáta");
for line in BufReader::new(file_p).lines() {
let point = line.unwrap().trim().to_string();
if point.is_empty() { continue; }
*counts.entry(point).or_insert(0.0) += 1.0;
total_count += 1.0;
}
let mut model_map = HashMap::new();
let file_q = File::open(model_path).expect("Nepodarilo sa otvoriť model");
for line in BufReader::new(file_q).lines() {
let l = line.unwrap();
let parts: Vec<&str> = l.split('\t').collect();
if parts.len() >= 2 {
model_map.insert(parts[0].to_string(), parts[1].parse::<f64>().unwrap());
}
}
let mut p_vals = Vec::new();
let mut q_vals = Vec::new();
for (point, count) in counts.iter() {
p_vals.push(count / total_count);
q_vals.push(*model_map.get(point).unwrap_or(&0.0));
}
(p_vals, q_vals)
}
pub fn entropy(p: Tensor<B, 1>) -> f64 {
let zero_mask = p.clone().equal_elem(0.0);
let p_safe = p.clone().mask_fill(zero_mask, 1.0);
let terms = p * p_safe.log();
-terms.sum().into_scalar()
}
pub fn cross_entropy(p: Tensor<B, 1>, q: Tensor<B, 1>) -> f64 {
let zero_mask_q = q.clone().equal_elem(0.0);
let p_exists = p.clone().greater_elem(0.0);
if p_exists.bool_and(zero_mask_q.clone()).any().into_scalar() {
return f64::INFINITY;
}
let q_safe = q.mask_fill(zero_mask_q, 1.0);
let terms = p * q_safe.log();
-terms.sum().into_scalar()
}
pub fn kl_div2(p: Tensor<B, 1>, q: Tensor<B, 1>) -> f64 {
let ce = cross_entropy(p.clone(), q);
let e = entropy(p);
let result = ce - e;
if result < 0.0 { 0.0 } else { result }
}

View File

@@ -1,56 +1,30 @@
use burn::tensor::{backend::Backend, Tensor, TensorData};
use burn_ndarray::NdArray;
use ndarray::{array, Array1};
use burn::tensor::{backend::Backend, Tensor};
use clap::Parser;
// Nahraď 'hod_1' názvom tvojho projektu v Cargo.toml
use hod_1::{load_and_align, entropy, cross_entropy, kl_div2, B};
type B = NdArray<f64>;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
#[arg(long = "data_path")]
data_path: String,
#[arg(long = "model_path")]
model_path: String,
}
fn main() {
let args = Args::parse();
let device = <B as Backend>::Device::default();
// ndarray::Array1<f64>
let a: Array1<f64> = array![0.4, 0.6];
// Použijeme funkciu z lib.rs
let (p_vec, q_vec) = load_and_align(&args.data_path, &args.model_path);
// Convert ndarray -> Burn tensor
let p = Tensor::<B, 1>::from_data(
TensorData::new(a.to_vec(), [a.len()]),
&device,
);
let p = Tensor::<B, 1>::from_data(p_vec.as_slice(), &device);
let q = Tensor::<B, 1>::from_data(q_vec.as_slice(), &device);
let h = entropy(p);
println!("{h}");
// Výpočty
println!("{}", entropy(p.clone()));
println!("{}", cross_entropy(p.clone(), q.clone()));
println!("{}", kl_div2(p, q));
}
fn entropy(p: Tensor<B, 1>) -> f64 {
// Handle p = 0 safely, because 0 * log(0) should contribute 0
let zero_mask = p.clone().equal_elem(0.0);
let p_safe = p.clone().mask_fill(zero_mask.clone(), 1.0);
let terms = (p.clone() * p_safe.log()).mask_fill(zero_mask, 0.0);
(-terms).sum().into_scalar()
}
// pub fn entropy2(p: Tensor<B, 1>) -> f64 {
// if p == 0.0 {
// return 0.0;
// }
// - p * p.ln()
// }
// pub fn cross_entropy(p: Tensor<B, 1>, q: f64) -> f64 {
// if p == 0.0 {
// return 0.0;
// }
// - p * q.ln()
// }
// pub fn kl_div(p: f64, q: f64) -> f64 {
// if p == 0.0 {
// return 0.0;
// }
// p*(p.ln()-q.ln())
// }
// pub fn kl_div2(p: f64, q: f64) -> f64 {
// cross_entropy(p,q) - entropy(p)
// }