entropy working
This commit is contained in:
61
flake.lock
generated
Normal file
61
flake.lock
generated
Normal file
@@ -0,0 +1,61 @@
|
||||
{
|
||||
"nodes": {
|
||||
"flake-utils": {
|
||||
"inputs": {
|
||||
"systems": "systems"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1731533236,
|
||||
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1772773019,
|
||||
"narHash": "sha256-E1bxHxNKfDoQUuvriG71+f+s/NT0qWkImXsYZNFFfCs=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "aca4d95fce4914b3892661bcb80b8087293536c6",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"flake-utils": "flake-utils",
|
||||
"nixpkgs": "nixpkgs"
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
"version": 7
|
||||
}
|
||||
33
flake.nix
Normal file
33
flake.nix
Normal file
@@ -0,0 +1,33 @@
|
||||
{
|
||||
description = "Pure Rust development environment";
|
||||
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
};
|
||||
|
||||
outputs = { self, nixpkgs, flake-utils }:
|
||||
flake-utils.lib.eachDefaultSystem (system:
|
||||
let
|
||||
pkgs = import nixpkgs { inherit system; };
|
||||
in {
|
||||
devShells.default = pkgs.mkShell {
|
||||
packages = with pkgs; [
|
||||
rustc
|
||||
cargo
|
||||
rust-analyzer
|
||||
clippy
|
||||
rustfmt
|
||||
pkg-config
|
||||
openssl
|
||||
];
|
||||
|
||||
shellHook = ''
|
||||
echo ">>> Pure Rust DevShell (Nix managed toolchain)"
|
||||
cargo --version
|
||||
rustc --version
|
||||
echo "Run: cargo new <project_name>"
|
||||
'';
|
||||
};
|
||||
});
|
||||
}
|
||||
1
hod_1/.gitignore
vendored
Normal file
1
hod_1/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
target/
|
||||
2864
hod_1/Cargo.lock
generated
Normal file
2864
hod_1/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
10
hod_1/Cargo.toml
Normal file
10
hod_1/Cargo.toml
Normal file
@@ -0,0 +1,10 @@
|
||||
[package]
|
||||
name = "hod_1"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
burn = { version = "0.20.1", default-features = false, features = ["ndarray"] }
|
||||
burn-ndarray = "0.20.1"
|
||||
clap = { version = "4.5.60", features = ["derive"] }
|
||||
ndarray = "0.17.2"
|
||||
56
hod_1/src/main.rs
Normal file
56
hod_1/src/main.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
use burn::tensor::{backend::Backend, Tensor, TensorData};
|
||||
use burn_ndarray::NdArray;
|
||||
use ndarray::{array, Array1};
|
||||
|
||||
type B = NdArray<f64>;
|
||||
|
||||
fn main() {
|
||||
let device = <B as Backend>::Device::default();
|
||||
|
||||
// ndarray::Array1<f64>
|
||||
let a: Array1<f64> = array![0.4, 0.6];
|
||||
|
||||
// Convert ndarray -> Burn tensor
|
||||
let p = Tensor::<B, 1>::from_data(
|
||||
TensorData::new(a.to_vec(), [a.len()]),
|
||||
&device,
|
||||
);
|
||||
|
||||
let h = entropy(p);
|
||||
println!("{h}");
|
||||
}
|
||||
|
||||
fn entropy(p: Tensor<B, 1>) -> f64 {
|
||||
// Handle p = 0 safely, because 0 * log(0) should contribute 0
|
||||
let zero_mask = p.clone().equal_elem(0.0);
|
||||
let p_safe = p.clone().mask_fill(zero_mask.clone(), 1.0);
|
||||
|
||||
let terms = (p.clone() * p_safe.log()).mask_fill(zero_mask, 0.0);
|
||||
|
||||
(-terms).sum().into_scalar()
|
||||
}
|
||||
|
||||
// pub fn entropy2(p: Tensor<B, 1>) -> f64 {
|
||||
// if p == 0.0 {
|
||||
// return 0.0;
|
||||
// }
|
||||
// - p * p.ln()
|
||||
// }
|
||||
|
||||
// pub fn cross_entropy(p: Tensor<B, 1>, q: f64) -> f64 {
|
||||
// if p == 0.0 {
|
||||
// return 0.0;
|
||||
// }
|
||||
// - p * q.ln()
|
||||
// }
|
||||
|
||||
// pub fn kl_div(p: f64, q: f64) -> f64 {
|
||||
// if p == 0.0 {
|
||||
// return 0.0;
|
||||
// }
|
||||
// p*(p.ln()-q.ln())
|
||||
// }
|
||||
|
||||
// pub fn kl_div2(p: f64, q: f64) -> f64 {
|
||||
// cross_entropy(p,q) - entropy(p)
|
||||
// }
|
||||
Reference in New Issue
Block a user