01 hotova
This commit is contained in:
63
hod_1/src/lib.rs
Normal file
63
hod_1/src/lib.rs
Normal file
@@ -0,0 +1,63 @@
|
||||
use burn::tensor::Tensor;
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader};
|
||||
|
||||
// Backend alias used by every tensor in this crate: burn's CPU ndarray
// backend with f64 elements (double precision for the probability math).
pub type B = burn_ndarray::NdArray<f64>;
|
||||
|
||||
// Funkcia na načítanie a zarovnanie dát
|
||||
// Loads an empirical distribution and a model distribution and returns them
// as index-aligned probability vectors `(p, q)`.
//
// * `data_path`  — one sample point per line; `p` is each distinct point's
//   relative frequency. Blank lines are skipped; surrounding whitespace is
//   trimmed.
// * `model_path` — tab-separated `point<TAB>probability` lines; lines with
//   fewer than two fields are ignored. Points absent from the model get
//   q = 0.0 (cross_entropy then reports infinity if p > 0 there).
//
// The points are sorted before emission so the output is deterministic —
// plain `HashMap` iteration order varies between runs, which previously made
// the returned vectors (and any per-element comparison of them) unstable.
//
// Panics if either file cannot be opened, a line cannot be read, or a model
// probability fails to parse as f64.
pub fn load_and_align(data_path: &str, model_path: &str) -> (Vec<f64>, Vec<f64>) {
    let mut counts: HashMap<String, f64> = HashMap::new();
    let mut total_count = 0.0;

    let file_p = File::open(data_path).expect("Nepodarilo sa otvoriť dáta");
    for line in BufReader::new(file_p).lines() {
        let point = line.unwrap().trim().to_string();
        if point.is_empty() { continue; }
        *counts.entry(point).or_insert(0.0) += 1.0;
        total_count += 1.0;
    }

    let mut model_map = HashMap::new();
    let file_q = File::open(model_path).expect("Nepodarilo sa otvoriť model");
    for line in BufReader::new(file_q).lines() {
        let l = line.unwrap();
        let parts: Vec<&str> = l.split('\t').collect();
        if parts.len() >= 2 {
            model_map.insert(parts[0].to_string(), parts[1].parse::<f64>().unwrap());
        }
    }

    // Stable order: sort the observed points so repeated runs on the same
    // inputs produce identical vectors.
    let mut points: Vec<&String> = counts.keys().collect();
    points.sort();

    let mut p_vals = Vec::with_capacity(points.len());
    let mut q_vals = Vec::with_capacity(points.len());
    for point in points {
        p_vals.push(counts[point] / total_count);
        q_vals.push(*model_map.get(point).unwrap_or(&0.0));
    }
    (p_vals, q_vals)
}
|
||||
|
||||
// Shannon entropy H(p) = -Σ p_i · ln(p_i), in nats.
//
// Entries with p_i = 0 must contribute 0 (the 0·ln 0 convention). To keep the
// log finite, zeros are masked to 1.0 before taking it; the product
// 0 · ln(1) = 0 then yields the correct term.
pub fn entropy(p: Tensor<B, 1>) -> f64 {
    let is_zero = p.clone().equal_elem(0.0);
    let log_input = p.clone().mask_fill(is_zero, 1.0);
    let weighted_logs = p * log_input.log();
    -weighted_logs.sum().into_scalar()
}
|
||||
|
||||
// Cross-entropy H(p, q) = -Σ p_i · ln(q_i), in nats.
//
// If the model assigns zero mass (q_i = 0) to any point the data actually
// contains (p_i > 0), the true cross-entropy is infinite and +∞ is returned
// immediately. Remaining q_i = 0 entries necessarily have p_i = 0, so masking
// them to 1.0 keeps the log finite while their terms contribute 0.
pub fn cross_entropy(p: Tensor<B, 1>, q: Tensor<B, 1>) -> f64 {
    let q_is_zero = q.clone().equal_elem(0.0);
    let p_is_positive = p.clone().greater_elem(0.0);

    // Observed point with zero model probability -> divergence blows up.
    let impossible = p_is_positive.bool_and(q_is_zero.clone());
    if impossible.any().into_scalar() {
        return f64::INFINITY;
    }

    let log_input = q.mask_fill(q_is_zero, 1.0);
    -(p * log_input.log()).sum().into_scalar()
}
|
||||
|
||||
pub fn kl_div2(p: Tensor<B, 1>, q: Tensor<B, 1>) -> f64 {
|
||||
let ce = cross_entropy(p.clone(), q);
|
||||
let e = entropy(p);
|
||||
let result = ce - e;
|
||||
if result < 0.0 { 0.0 } else { result }
|
||||
}
|
||||
@@ -1,56 +1,30 @@
|
||||
use burn::tensor::{backend::Backend, Tensor, TensorData};
|
||||
use burn_ndarray::NdArray;
|
||||
use ndarray::{array, Array1};
|
||||
use burn::tensor::{backend::Backend, Tensor};
|
||||
use clap::Parser;
|
||||
// Nahraď 'hod_1' názvom tvojho projektu v Cargo.toml
|
||||
use hod_1::{load_and_align, entropy, cross_entropy, kl_div2, B};
|
||||
|
||||
type B = NdArray<f64>;
|
||||
// Command-line arguments, parsed by clap's derive API.
// NOTE: plain `//` comments are used deliberately — `///` doc comments on
// clap derive fields become --help text, which would change CLI output.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    // --data_path: file with one observed sample point per line.
    #[arg(long = "data_path")]
    data_path: String,

    // --model_path: tab-separated `point<TAB>probability` model file.
    #[arg(long = "model_path")]
    model_path: String,
}
|
||||
|
||||
fn main() {
|
||||
let args = Args::parse();
|
||||
let device = <B as Backend>::Device::default();
|
||||
|
||||
// ndarray::Array1<f64>
|
||||
let a: Array1<f64> = array![0.4, 0.6];
|
||||
// Použijeme funkciu z lib.rs
|
||||
let (p_vec, q_vec) = load_and_align(&args.data_path, &args.model_path);
|
||||
|
||||
// Convert ndarray -> Burn tensor
|
||||
let p = Tensor::<B, 1>::from_data(
|
||||
TensorData::new(a.to_vec(), [a.len()]),
|
||||
&device,
|
||||
);
|
||||
let p = Tensor::<B, 1>::from_data(p_vec.as_slice(), &device);
|
||||
let q = Tensor::<B, 1>::from_data(q_vec.as_slice(), &device);
|
||||
|
||||
let h = entropy(p);
|
||||
println!("{h}");
|
||||
// Výpočty
|
||||
println!("{}", entropy(p.clone()));
|
||||
println!("{}", cross_entropy(p.clone(), q.clone()));
|
||||
println!("{}", kl_div2(p, q));
|
||||
}
|
||||
|
||||
fn entropy(p: Tensor<B, 1>) -> f64 {
|
||||
// Handle p = 0 safely, because 0 * log(0) should contribute 0
|
||||
let zero_mask = p.clone().equal_elem(0.0);
|
||||
let p_safe = p.clone().mask_fill(zero_mask.clone(), 1.0);
|
||||
|
||||
let terms = (p.clone() * p_safe.log()).mask_fill(zero_mask, 0.0);
|
||||
|
||||
(-terms).sum().into_scalar()
|
||||
}
|
||||
|
||||
// pub fn entropy2(p: Tensor<B, 1>) -> f64 {
|
||||
// if p == 0.0 {
|
||||
// return 0.0;
|
||||
// }
|
||||
// - p * p.ln()
|
||||
// }
|
||||
|
||||
// pub fn cross_entropy(p: Tensor<B, 1>, q: f64) -> f64 {
|
||||
// if p == 0.0 {
|
||||
// return 0.0;
|
||||
// }
|
||||
// - p * q.ln()
|
||||
// }
|
||||
|
||||
// pub fn kl_div(p: f64, q: f64) -> f64 {
|
||||
// if p == 0.0 {
|
||||
// return 0.0;
|
||||
// }
|
||||
// p*(p.ln()-q.ln())
|
||||
// }
|
||||
|
||||
// pub fn kl_div2(p: f64, q: f64) -> f64 {
|
||||
// cross_entropy(p,q) - entropy(p)
|
||||
// }
|
||||
|
||||
Reference in New Issue
Block a user