entropy working
This commit is contained in:
61
flake.lock
generated
Normal file
61
flake.lock
generated
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
{
|
||||||
|
"nodes": {
|
||||||
|
"flake-utils": {
|
||||||
|
"inputs": {
|
||||||
|
"systems": "systems"
|
||||||
|
},
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1731533236,
|
||||||
|
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
||||||
|
"owner": "numtide",
|
||||||
|
"repo": "flake-utils",
|
||||||
|
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "numtide",
|
||||||
|
"repo": "flake-utils",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nixpkgs": {
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1772773019,
|
||||||
|
"narHash": "sha256-E1bxHxNKfDoQUuvriG71+f+s/NT0qWkImXsYZNFFfCs=",
|
||||||
|
"owner": "NixOS",
|
||||||
|
"repo": "nixpkgs",
|
||||||
|
"rev": "aca4d95fce4914b3892661bcb80b8087293536c6",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "NixOS",
|
||||||
|
"ref": "nixos-unstable",
|
||||||
|
"repo": "nixpkgs",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root": {
|
||||||
|
"inputs": {
|
||||||
|
"flake-utils": "flake-utils",
|
||||||
|
"nixpkgs": "nixpkgs"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"systems": {
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1681028828,
|
||||||
|
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||||
|
"owner": "nix-systems",
|
||||||
|
"repo": "default",
|
||||||
|
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "nix-systems",
|
||||||
|
"repo": "default",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root": "root",
|
||||||
|
"version": 7
|
||||||
|
}
|
||||||
33
flake.nix
Normal file
33
flake.nix
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
{
|
||||||
|
description = "Pure Rust development environment";
|
||||||
|
|
||||||
|
inputs = {
|
||||||
|
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
|
||||||
|
flake-utils.url = "github:numtide/flake-utils";
|
||||||
|
};
|
||||||
|
|
||||||
|
outputs = { self, nixpkgs, flake-utils }:
|
||||||
|
flake-utils.lib.eachDefaultSystem (system:
|
||||||
|
let
|
||||||
|
pkgs = import nixpkgs { inherit system; };
|
||||||
|
in {
|
||||||
|
devShells.default = pkgs.mkShell {
|
||||||
|
packages = with pkgs; [
|
||||||
|
rustc
|
||||||
|
cargo
|
||||||
|
rust-analyzer
|
||||||
|
clippy
|
||||||
|
rustfmt
|
||||||
|
pkg-config
|
||||||
|
openssl
|
||||||
|
];
|
||||||
|
|
||||||
|
shellHook = ''
|
||||||
|
echo ">>> Pure Rust DevShell (Nix managed toolchain)"
|
||||||
|
cargo --version
|
||||||
|
rustc --version
|
||||||
|
echo "Run: cargo new <project_name>"
|
||||||
|
'';
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
||||||
1
hod_1/.gitignore
vendored
Normal file
1
hod_1/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
target/
|
||||||
2864
hod_1/Cargo.lock
generated
Normal file
2864
hod_1/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
10
hod_1/Cargo.toml
Normal file
10
hod_1/Cargo.toml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
[package]
|
||||||
|
name = "hod_1"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
burn = { version = "0.20.1", default-features = false, features = ["ndarray"] }
|
||||||
|
burn-ndarray = "0.20.1"
|
||||||
|
clap = { version = "4.5.60", features = ["derive"] }
|
||||||
|
ndarray = "0.17.2"
|
||||||
56
hod_1/src/main.rs
Normal file
56
hod_1/src/main.rs
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
use burn::tensor::{backend::Backend, Tensor, TensorData};
|
||||||
|
use burn_ndarray::NdArray;
|
||||||
|
use ndarray::{array, Array1};
|
||||||
|
|
||||||
|
type B = NdArray<f64>;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let device = <B as Backend>::Device::default();
|
||||||
|
|
||||||
|
// ndarray::Array1<f64>
|
||||||
|
let a: Array1<f64> = array![0.4, 0.6];
|
||||||
|
|
||||||
|
// Convert ndarray -> Burn tensor
|
||||||
|
let p = Tensor::<B, 1>::from_data(
|
||||||
|
TensorData::new(a.to_vec(), [a.len()]),
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
|
||||||
|
let h = entropy(p);
|
||||||
|
println!("{h}");
|
||||||
|
}
|
||||||
|
|
||||||
|
fn entropy(p: Tensor<B, 1>) -> f64 {
|
||||||
|
// Handle p = 0 safely, because 0 * log(0) should contribute 0
|
||||||
|
let zero_mask = p.clone().equal_elem(0.0);
|
||||||
|
let p_safe = p.clone().mask_fill(zero_mask.clone(), 1.0);
|
||||||
|
|
||||||
|
let terms = (p.clone() * p_safe.log()).mask_fill(zero_mask, 0.0);
|
||||||
|
|
||||||
|
(-terms).sum().into_scalar()
|
||||||
|
}
|
||||||
|
|
||||||
|
// pub fn entropy2(p: Tensor<B, 1>) -> f64 {
|
||||||
|
// if p == 0.0 {
|
||||||
|
// return 0.0;
|
||||||
|
// }
|
||||||
|
// - p * p.ln()
|
||||||
|
// }
|
||||||
|
|
||||||
|
// pub fn cross_entropy(p: Tensor<B, 1>, q: f64) -> f64 {
|
||||||
|
// if p == 0.0 {
|
||||||
|
// return 0.0;
|
||||||
|
// }
|
||||||
|
// - p * q.ln()
|
||||||
|
// }
|
||||||
|
|
||||||
|
// pub fn kl_div(p: f64, q: f64) -> f64 {
|
||||||
|
// if p == 0.0 {
|
||||||
|
// return 0.0;
|
||||||
|
// }
|
||||||
|
// p*(p.ln()-q.ln())
|
||||||
|
// }
|
||||||
|
|
||||||
|
// pub fn kl_div2(p: f64, q: f64) -> f64 {
|
||||||
|
// cross_entropy(p,q) - entropy(p)
|
||||||
|
// }
|
||||||
Reference in New Issue
Block a user