├── .gitignore ├── mnist.png ├── plot.png ├── losses.png ├── losses_mnist.png ├── src ├── kernels │ ├── knn │ │ ├── mod.rs │ │ ├── forward.rs │ │ ├── backward.rs │ │ └── kernel.rs │ ├── euclidean │ │ ├── mod.rs │ │ ├── forward.rs │ │ ├── backward.rs │ │ └── kernel.rs │ └── mod.rs ├── macros.rs ├── normalizer.rs ├── backend.rs ├── prelude.rs ├── train │ ├── get_distance_by_metric.rs │ ├── config.rs │ └── mod.rs ├── lib.rs ├── model.rs ├── distances.rs ├── chart.rs └── utils.rs ├── .vscode └── settings.json ├── LICENSE ├── Cargo.toml ├── examples ├── simple.rs ├── advanced.rs ├── mnist.rs └── mnist_benchmark.rs └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | .DS_Store 3 | data/* -------------------------------------------------------------------------------- /mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eugenehp/fast-umap/HEAD/mnist.png -------------------------------------------------------------------------------- /plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eugenehp/fast-umap/HEAD/plot.png -------------------------------------------------------------------------------- /losses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eugenehp/fast-umap/HEAD/losses.png -------------------------------------------------------------------------------- /losses_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eugenehp/fast-umap/HEAD/losses_mnist.png -------------------------------------------------------------------------------- /src/kernels/knn/mod.rs: -------------------------------------------------------------------------------- 1 | pub(crate) mod backward; 2 | pub(crate) mod forward; 3 | mod kernel; 4 | -------------------------------------------------------------------------------- /src/kernels/euclidean/mod.rs: -------------------------------------------------------------------------------- 1 | pub(crate) mod backward; 2 | pub(crate) mod forward; 3 | mod kernel; 4 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.formatOnSave": true, 3 | "editor.defaultFormatter": "rust-lang.rust-analyzer" 4 | } -------------------------------------------------------------------------------- /src/macros.rs: -------------------------------------------------------------------------------- 1 | #[macro_export] 2 | macro_rules! 
print_if { 3 | ($cond:expr, $($arg:tt)*) => { 4 | if $cond { 5 | println!($($arg)*); 6 | } 7 | }; 8 | } 9 | -------------------------------------------------------------------------------- /src/normalizer.rs: -------------------------------------------------------------------------------- 1 | use burn::prelude::*; 2 | 3 | pub fn normalize(input: Tensor) -> Tensor { 4 | let mean = input.clone().mean_dim(1); // Mean along the feature dimension 5 | let var = input.clone().var(1); // Variance along the feature dimension 6 | let std = var.sqrt() + 1e-5; // Standard deviation with epsilon for numerical stability 7 | 8 | let normalized = (input - mean) / std; // Normalize the input 9 | 10 | normalized 11 | } 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2024 Eugene Hauptmann 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "fast-umap" 3 | description = "Configurable UMAP (Uniform Manifold Approximation and Projection) in Rust" 4 | keywords = ["UMAP", "dimensionality", "manifold", "machine-learning", "GPU"] 5 | categories = ["science", "mathematics", "visualization"] 6 | authors = ["Eugene Hauptmann"] 7 | repository = "https://github.com/eugenehp/fast-umap" 8 | homepage = "https://github.com/eugenehp/fast-umap" 9 | license = "MIT" 10 | version = "0.0.2" 11 | edition = "2021" 12 | 13 | [dependencies] 14 | burn = { version = "0.18", features = ["train", "wgpu", "autodiff", "autotune"] } 15 | burn-cubecl = { version = "0.18.0", features = ["burn-autodiff"] } 16 | # burn-jit = { version = "0.16.1", features = ["burn-autodiff"] } 17 | crossbeam-channel = "0.5.15" 18 | ctrlc = "3.4.7" 19 | cubecl = { version = "0.6.0", features = ["wgpu"] } 20 | hsl = "0.1.1" 21 | indicatif = "0.18" 22 | ndarray = "0.16.1" 23 | num = "0.4.3" 24 | num-traits = "0.2.19" 25 | plotters = "0.3.7" 26 | prettytable = "0.10.0" 27 | rand = "0.9.2" 28 | rayon = "1.10.0" 29 | serde = "1.0.216" 30 | 31 | [dev-dependencies] 32 | mnist = { version = "0.6.1", git = "https://github.com/eugenehp/mnist.git", features = ["download"] } 33 | 34 | [features] 35 | default = ["verbose"] 36 | verbose = [] 37 | -------------------------------------------------------------------------------- /src/backend.rs: -------------------------------------------------------------------------------- 1 | use burn::{ 2 | backend::Autodiff, 3 | tensor::ops::{FloatTensor, IntTensor}, 4 | }; 5 | // use burn_jit::JitBackend; 6 | use burn_cubecl::{BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement}; 7 | // use cubecl::wgpu::WgpuRuntime; 8 | 9 | /// We create our own Backend trait that extends the Burn backend trait. 10 | pub trait Backend: burn::tensor::backend::Backend { 11 | fn euclidean_pairwise_distance(x: FloatTensor) -> FloatTensor; 12 | fn euclidean_pairwise_distance_backward( 13 | grad_x: FloatTensor, 14 | output: FloatTensor, 15 | ) -> FloatTensor; 16 | 17 | // TODO: return IntTensor for indices 18 | /// Returns indices, distances 19 | fn knn(pairwise_distances: FloatTensor, k: u32) -> (IntTensor, FloatTensor); 20 | 21 | fn knn_backward( 22 | pairwise_distances: FloatTensor, // Pairwise distance matrix (n, n) 23 | k: u32, // Number of nearest neighbors 24 | grad_output: FloatTensor, // Gradient of the loss w.r.t the output 25 | ) -> FloatTensor; 26 | } 27 | 28 | /// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait.
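///
/// A sketch of what the marker trait buys us: downstream code can take a single
/// bound and get both the custom kernels and autodiff. The helper below is
/// illustrative only, not part of this crate:
///
/// ```rust,ignore
/// fn pairwise_then_knn<B: AutodiffBackend>(
///     x: FloatTensor<B>,
///     k: u32,
/// ) -> FloatTensor<B> {
///     let distances = B::euclidean_pairwise_distance(x);
///     let (_indices, knn_distances) = B::knn(distances, k);
///     knn_distances
/// }
/// ```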
29 | pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {} 30 | 31 | // this line along with the `backward` module is what's needed to enable support for a particular device below 32 | // impl AutodiffBackend for Autodiff> {} 33 | 34 | impl AutodiffBackend 35 | for Autodiff> 36 | { 37 | } 38 | -------------------------------------------------------------------------------- /examples/simple.rs: -------------------------------------------------------------------------------- 1 | use cubecl::wgpu::WgpuRuntime; 2 | use fast_umap::prelude::*; 3 | use rand::Rng; 4 | 5 | fn main() { 6 | // Number of samples in the dataset 7 | let num_samples = 100; 8 | 9 | // Number of features (dimensions) for each sample 10 | let num_features = 3; 11 | 12 | // Create a random number generator for generating random values 13 | let mut rng = rand::rng(); 14 | 15 | // Generate a dataset of random values with `num_samples` rows and `num_features` columns 16 | let data: Vec> = (0..num_samples * num_features) 17 | .map(|_| rng.random::()) // Random number generation for each feature 18 | .collect::>() // Collect all random values into a vector 19 | .chunks_exact(num_features) // Chunk the vector into rows of length `num_features` 20 | .map(|chunk| chunk.to_vec()) // Convert each chunk into a Vec 21 | .collect(); // Collect the rows into a vector of vectors 22 | 23 | type MyBackend = burn::backend::wgpu::CubeBackend; 24 | type MyAutodiffBackend = burn::backend::Autodiff; 25 | 26 | // Fit the UMAP model to the data and reduce the data to a lower-dimensional space (default: 2D) 27 | let umap: fast_umap::UMAP = umap(data.clone()); 28 | // let umap = umap_size(data.clone(), 3); // where 3 is the output size of projected dimensions 29 | 30 | // Transform the data using the trained UMAP model to reduce its dimensions 31 | let reduced_dimensions_vector = umap.transform(data.clone()); 32 | 33 | // Visualize the reduced dimensions as a vector, plots only 2D for now 34 | chart_vector(reduced_dimensions_vector, None, None); 35 | 36 | // Optionally, you can also visualize the reduced dimensions as a tensor 37 | // let reduced_dimensions_tensor = umap.transform_to_tensor(data.clone()); 38 | // print_tensor_with_title("reduced_dimensions", &reduced_dimensions_tensor); 39 | // chart_tensor(reduced_dimensions_tensor, None); 40 | } 41 | -------------------------------------------------------------------------------- /src/prelude.rs: -------------------------------------------------------------------------------- 1 | use crate::backend::AutodiffBackend; 2 | use crate::{chart, train, utils, UMAP}; 3 | 4 | // Re-export common utilities for easier use 5 | pub use chart::{chart_tensor, chart_vector}; 6 | 7 | use crossbeam_channel::unbounded; 8 | use num::Float; 9 | pub use train::Metric; 10 | pub use train::{TrainingConfig, TrainingConfigBuilder}; 11 | pub use utils::generate_test_data; 12 | 13 | /// Convenience function for running UMAP with the WGPU backend. 14 | /// 15 | /// # Arguments 16 | /// * `data` - A vector of vectors, where each inner vector represents a data sample with multiple features. 17 | /// 18 | /// # Returns 19 | /// A trained `UMAP` model that has been fitted to the input data, using the WGPU backend for computation. 20 | /// 21 | /// This function wraps the `UMAP::fit` method and provides a simplified way to fit UMAP models with the WGPU backend. 22 | /// The resulting model will have 2-dimensional output by default. 
23 | /// 24 | /// # Example 25 | /// ```rust 26 | /// let data = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]; 27 | /// let model = umap(data); 28 | /// ``` 29 | pub fn umap(data: Vec>) -> UMAP 30 | where 31 | F: num::FromPrimitive + burn::tensor::Element, 32 | { 33 | let (exit_tx, exit_rx) = unbounded(); 34 | 35 | ctrlc::set_handler(move || exit_tx.send(()).expect("Could not send signal on channel.")) 36 | .expect("Error setting Ctrl-C handler"); 37 | 38 | let output_size = 2; 39 | let device = Default::default(); 40 | let model = UMAP::::fit(data, device, output_size, exit_rx); 41 | model 42 | } 43 | 44 | /// Convenience function for running UMAP with the WGPU backend and a custom output size. 45 | /// 46 | /// # Arguments 47 | /// * `data` - A vector of vectors, where each inner vector represents a data sample with multiple features. 48 | /// * `output_size` - The number of dimensions for the reduced output. This controls the dimensionality of the embedding space. 49 | /// 50 | /// # Returns 51 | /// A trained `UMAP` model that has been fitted to the input data, using the WGPU backend for computation and the specified output size. 52 | /// 53 | /// This function wraps the `UMAP::fit` method, providing a way to fit UMAP models with the WGPU backend and a customizable number of output dimensions. 54 | /// 55 | /// # Example 56 | /// ```rust 57 | /// let data = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]; 58 | /// let output_size = 3; 59 | /// let model = umap_size(data, output_size); 60 | /// ``` 61 | pub fn umap_size(data: Vec>, output_size: usize) -> UMAP 62 | where 63 | F: num::FromPrimitive + burn::tensor::Element, 64 | { 65 | let (exit_tx, exit_rx) = unbounded(); 66 | 67 | ctrlc::set_handler(move || exit_tx.send(()).expect("Could not send signal on channel.")) 68 | .expect("Error setting Ctrl-C handler"); 69 | 70 | let device = Default::default(); 71 | let model = UMAP::::fit(data, device, output_size, exit_rx); 72 | model 73 | } 74 | -------------------------------------------------------------------------------- /src/train/get_distance_by_metric.rs: -------------------------------------------------------------------------------- 1 | use burn::tensor::{Float, Tensor, TensorPrimitive}; 2 | 3 | use crate::backend::Backend; 4 | 5 | use super::*; 6 | 7 | /// Computes the distance metric for the given data. 8 | /// 9 | /// This function calculates the distance between points in the provided `data` tensor 10 | /// according to the metric specified in the `config`. It currently supports the following 11 | /// metrics: 12 | /// 13 | /// - `Euclidean`: Computes the Euclidean distance between points. 14 | /// - `EuclideanKNN`: Computes the Euclidean distance using k-nearest neighbors (KNN), with `k_neighbors` 15 | /// determining the number of nearest neighbors to consider. 16 | /// 17 | /// # Parameters 18 | /// - `data`: A 2D tensor representing the data points, where each row is a point and each column is a feature. 19 | /// - `config`: The training configuration, which specifies the metric and other parameters like `k_neighbors`. 20 | /// 21 | /// # Returns 22 | /// A 1D tensor containing the computed distances for each point based on the selected metric. 23 | /// 24 | /// # Type Parameters 25 | /// - `B`: The backend type used for automatic differentiation (AutodiffBackend), which enables GPU or CPU computations. 
26 | /// 27 | /// # Example 28 | /// ```rust 29 | /// let data = Tensor::from(...); // Some 2D tensor of data points 30 | /// let config = TrainingConfig::builder().with_metric(Metric::Euclidean.into()).with_k_neighbors(5).build().unwrap(); 31 | /// let distances = get_distance_by_metric(data, &config); 32 | /// ``` 33 | #[allow(unused)] 34 | pub fn get_distance_by_metric( 35 | data: Tensor, 36 | config: &TrainingConfig, 37 | verbose: Option, 38 | ) -> Tensor { 39 | type F = f32; 40 | // let verbose = verbose.unwrap_or("".into()); 41 | // let before = data.to_data().to_vec::().unwrap(); 42 | // println!( 43 | // "get_distance_by_metric - before - {verbose} - {:?}", 44 | // data.shape(), 45 | // ); 46 | 47 | let distance = match config.metric { 48 | // Metric::Euclidean => euclidean(data), 49 | // Metric::Euclidean => { 50 | _ => { 51 | let x = data.clone().into_primitive().tensor(); 52 | let pairwise_distances = B::euclidean_pairwise_distance(x); 53 | 54 | let (indices, distances) = 55 | B::knn(pairwise_distances.clone(), config.k_neighbors as u32); 56 | 57 | // TODO: don't clone later, to optimize the speed 58 | let pairwise_distances: Tensor = 59 | Tensor::from_primitive(TensorPrimitive::Float(pairwise_distances)); 60 | let distances = Tensor::from_primitive(TensorPrimitive::Float(distances)); 61 | 62 | distances 63 | } // Metric::EuclideanKNN => euclidean_knn(data, config.k_neighbors), 64 | // Metric::Manhattan => manhattan(data), 65 | // Metric::Cosine => cosine(data), 66 | // Metric::Minkowski => minkowski(data, config.minkowski_p), 67 | // _ => euclidean(data), 68 | }; 69 | 70 | // let after = distance.to_data().to_vec::().unwrap(); 71 | // println!( 72 | // "get_distance_by_metric - after - {verbose} - {:?}", 73 | // distance.shape(), 74 | // ); 75 | 76 | // match config.normalized { 77 | // true => normalize_tensor(distance), 78 | // false => distance, 79 | // } 80 | distance 81 | } 82 | -------------------------------------------------------------------------------- /src/kernels/mod.rs: -------------------------------------------------------------------------------- 1 | use burn::{ 2 | backend::{autodiff::checkpoint::strategy::CheckpointStrategy, Autodiff}, 3 | tensor::ops::{FloatTensor, IntTensor}, 4 | }; 5 | // use burn_jit::{FloatElement, IntElement, JitBackend, JitRuntime}; 6 | use cubecl::CubeDim; 7 | 8 | use crate::backend::Backend; 9 | use burn_cubecl::{BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement}; 10 | 11 | mod euclidean; 12 | mod knn; 13 | 14 | /// Example cube size 15 | pub const DEFAULT_CUBE_DIM: CubeDim = CubeDim { x: 32, y: 32, z: 1 }; 16 | 17 | // impl Backend for JitBackend { 18 | impl Backend 19 | for CubeBackend 20 | { 21 | fn euclidean_pairwise_distance(x: FloatTensor) -> FloatTensor { 22 | euclidean::forward::forward::(x) 23 | } 24 | 25 | fn euclidean_pairwise_distance_backward( 26 | grad_x: FloatTensor, 27 | output: FloatTensor, 28 | ) -> FloatTensor { 29 | // TODO: this is confusing naming, FIXME 30 | euclidean::forward::backward::(grad_x, output) 31 | } 32 | 33 | fn knn(pairwise_distances: FloatTensor, k: u32) -> (IntTensor, FloatTensor) { 34 | knn::forward::forward::(pairwise_distances, k) 35 | } 36 | 37 | fn knn_backward( 38 | pairwise_distances: FloatTensor, // Pairwise distance matrix (n, n) 39 | k: u32, // Number of nearest neighbors 40 | grad_output: FloatTensor, // Gradient of the loss w.r.t the output 41 | ) -> FloatTensor { 42 | knn::forward::backward::(pairwise_distances, k, grad_output) 43 | } 44 | } 45 | 46 | // Forward 47 | // CubeBackend ->
euclidean_pairwise_distance 48 | 49 | // Backward 50 | // Autodiff -> euclidean_pairwise_distance -> CubeBackend -> euclidean_pairwise_distance_backward 51 | 52 | // TODO: FIXME 53 | 54 | impl Backend for Autodiff { 55 | fn euclidean_pairwise_distance(x: FloatTensor) -> FloatTensor { 56 | euclidean::backward::backward::(x) 57 | } 58 | 59 | fn euclidean_pairwise_distance_backward( 60 | _grad_x: FloatTensor, 61 | _output: FloatTensor, 62 | ) -> FloatTensor { 63 | unimplemented!("This method is only triggered from the `CubeBackend` implementation above, since we didn't find a nicer way to call the kernel from `euclidean_pairwise_distance` in the `Autodiff` implementation."); 64 | } 65 | 66 | fn knn(pairwise_distances: FloatTensor, k: u32) -> (IntTensor, FloatTensor) { 67 | // todo!("We need to implement backward kernel for the KNN") 68 | // (pairwise_distances.clone(), pairwise_distances) 69 | knn::backward::backward::(pairwise_distances, k) 70 | } 71 | 72 | fn knn_backward( 73 | _pairwise_distances: FloatTensor, // Pairwise distance matrix (n, n) 74 | _k: u32, // Number of nearest neighbors 75 | _grad_output: FloatTensor, // Gradient of the loss w.r.t the output 76 | ) -> FloatTensor { 77 | unimplemented!("This method is only triggered from the `CubeBackend` implementation above, since we didn't find a nicer way to call the kernel from `knn` in the `Autodiff` implementation."); 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /src/kernels/euclidean/forward.rs: -------------------------------------------------------------------------------- 1 | use super::kernel::*; 2 | use crate::kernels::DEFAULT_CUBE_DIM; 3 | use burn::tensor::{ops::FloatTensor, Shape}; 4 | use burn_cubecl::{ 5 | kernel::into_contiguous, tensor::CubeTensor, BoolElement, CubeBackend, CubeRuntime, 6 | FloatElement, IntElement, 7 | }; 8 | use cubecl::prelude::*; 9 | 10 | pub fn forward( 11 | x: FloatTensor>, 12 | ) -> FloatTensor> { 13 | let xx = into_contiguous(x.clone()); 14 | let client = xx.client; 15 | let device = xx.device; 16 | let dims = xx.shape.dims; 17 | let n = dims[0]; 18 | 19 | // Allocate output tensor of shape (N, N) to hold pairwise distances 20 | let output_shape = Shape::from(vec![n, n]); 21 | let buffer = client.empty(output_shape.num_elements() * std::mem::size_of::()); 22 | let output = CubeTensor::new_contiguous( 23 | client.clone(), 24 | device.clone(), 25 | output_shape, 26 | buffer, 27 | F::dtype(), 28 | ); 29 | 30 | // Launch the Euclidean pairwise distance kernel 31 | let cube_dim = DEFAULT_CUBE_DIM; 32 | let cubes_needed_in_x = (n as f32 / cube_dim.x as f32).ceil() as u32; 33 | let cubes_needed_in_y = (n as f32 / cube_dim.y as f32).ceil() as u32; 34 | let cube_count = CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, 1); 35 | 36 | // Launch the kernel 37 | euclidean_pairwise_distance_kernel::launch::( 38 | &client, 39 | cube_count, 40 | cube_dim, 41 | x.as_tensor_arg::(1), 42 | output.as_tensor_arg::(1), 43 | ); 44 | 45 | #[cfg(feature = "verbose")] 46 | { 47 | // println!("euclidean_pairwise_distance[{n}] x - {x:?}"); 48 | // println!("euclidean_pairwise_distance[{n}] output - {output:?}"); 49 | } 50 | 51 | output 52 | } 53 | 54 | pub fn backward( 55 | grad_x: FloatTensor>, 56 | output: FloatTensor>, 57 | ) -> FloatTensor> { 58 | // println!("backend - euclidean_pairwise_distance_backward"); 59 | let output = into_contiguous(output); 60 | let n = output.shape.dims[0]; 61 | let d = output.shape.dims[1]; 62 | 63 | let grad_output_shape = Shape::from(vec![n, d]); 64 | let buffer = output 65 | .client 66 |
.empty(grad_output_shape.num_elements() * std::mem::size_of::()); 67 | let grad_output: CubeTensor = CubeTensor::new_contiguous( 68 | output.client.clone(), 69 | output.device.clone(), 70 | grad_output_shape, 71 | buffer, 72 | F::dtype(), 73 | ); 74 | 75 | // Launch the Euclidean pairwise distance kernel 76 | let cube_dim = DEFAULT_CUBE_DIM; 77 | let cubes_needed_in_x = (n as f32 / cube_dim.x as f32).ceil() as u32; 78 | let cubes_needed_in_y = (d as f32 / cube_dim.y as f32).ceil() as u32; 79 | let cube_count = CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, 1); 80 | 81 | let vectorisation = 1; 82 | 83 | // Launch the kernel 84 | euclidean_pairwise_distance_backward_kernel::launch::( 85 | &output.client, 86 | cube_count, 87 | cube_dim, 88 | output.as_tensor_arg::(vectorisation), 89 | grad_output.as_tensor_arg::(vectorisation), 90 | grad_x.as_tensor_arg::(vectorisation), 91 | ); 92 | 93 | grad_output 94 | } 95 | -------------------------------------------------------------------------------- /src/kernels/euclidean/backward.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | 3 | use burn::{ 4 | backend::{ 5 | autodiff::{ 6 | checkpoint::{base::Checkpointer, strategy::CheckpointStrategy}, 7 | grads::Gradients, 8 | ops::{Backward, Ops, OpsKind}, 9 | NodeID, 10 | }, 11 | Autodiff, 12 | }, 13 | tensor::ops::FloatTensor, 14 | }; 15 | 16 | use crate::{backend::Backend, print_if, print_primitive_tensor}; 17 | 18 | const VERBOSE: bool = false; 19 | 20 | pub fn backward( 21 | x: FloatTensor>, 22 | ) -> FloatTensor> { 23 | // println!("euclidean_pairwise_distance"); 24 | // Create zero-sized struct for backward computation 25 | #[derive(Debug)] 26 | struct EuclideanPairwiseDistanceBackward; 27 | 28 | // Implement the backward trait for the given backend B 29 | impl Backward for EuclideanPairwiseDistanceBackward { 30 | type State = NodeID; // , FloatTensor 31 | 32 | fn backward( 33 | self, 34 | ops: Ops, 35 | grads: &mut Gradients, 36 | checkpointer: &mut Checkpointer, 37 | ) { 38 | let node_x = ops.state; // Retrieve x and output from the state 39 | 40 | // Fetch the gradient for the current node. 
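// `grads.consume` takes ownership of the gradient that autodiff accumulated for
// this node's output, and `retrieve_node_output` re-materializes the forward
// output checkpointed via `prep.checkpoint(&x)` below. (This reflects our
// understanding of burn's autodiff checkpointing; the exact recompute-vs-store
// behavior depends on the chosen CheckpointStrategy.)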
41 | let grad_x = grads.consume::(&ops.node); 42 | let output: FloatTensor = checkpointer.retrieve_node_output(node_x); 43 | 44 | if VERBOSE { 45 | println!("grad_x {grad_x:?}"); 46 | print_primitive_tensor::(&grad_x, 10, 10); 47 | println!("output {output:?}"); 48 | print_primitive_tensor::(&output, 10, 10); 49 | } 50 | 51 | let grad_output = B::euclidean_pairwise_distance_backward(grad_x, output); 52 | 53 | if VERBOSE { 54 | println!("===grad_output=== {grad_output:?}"); 55 | print_primitive_tensor::(&grad_output, 0, 0); 56 | } 57 | 58 | // let grad_output = B::float_matmul(grad_x, output); 59 | // println!("===grad_output=== {:?}", grad_output); 60 | // print_primitive_tensor::(&grad_output, 10, 10); 61 | grads.register::(node_x, grad_output); 62 | } 63 | } 64 | 65 | // Prepare the stateful operation 66 | match EuclideanPairwiseDistanceBackward 67 | .prepare::([x.node.clone()]) 68 | .compute_bound() 69 | .stateful() 70 | { 71 | OpsKind::Tracked(mut prep) => { 72 | // When at least one node is tracked, register the backward function 73 | let x_state = prep.checkpoint(&x); // Checkpoint x for future retrieval during the backward pass 74 | 75 | let output = B::euclidean_pairwise_distance(x.clone().primitive); // Forward pass calculation 76 | print_if!(VERBOSE, "Forward pass output (Tracked): {:?}", output); // Debug: Print output shape 77 | 78 | let state = x_state; 79 | 80 | // The state now includes the checkpointed x and the output 81 | prep.finish(state, output) // Finish with the computed output 82 | } 83 | OpsKind::UnTracked(prep) => { 84 | // If no node is tracked, just do the forward calculation 85 | let output = B::euclidean_pairwise_distance(x.primitive); 86 | print_if!(VERBOSE, "Forward pass output (UnTracked): {:?}", output); // Debug: Print output shape 87 | prep.finish(output) // No need for state here 88 | } 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /examples/advanced.rs: -------------------------------------------------------------------------------- 1 | use burn::{module::*, prelude::*}; 2 | use crossbeam_channel::unbounded; 3 | use cubecl::wgpu::WgpuRuntime; 4 | use fast_umap::{chart, model::*, prelude::*, train::train, utils::*}; 5 | 6 | fn main() { 7 | let (exit_tx, exit_rx) = unbounded(); 8 | 9 | ctrlc::set_handler(move || exit_tx.send(()).expect("Could not send signal on channel.")) 10 | .expect("Error setting Ctrl-C handler"); 11 | 12 | type F = f32; 13 | // Define a custom backend type using Wgpu with 32-bit floating point precision and 32-bit integer type 14 | type MyBackend = burn::backend::wgpu::CubeBackend; 15 | 16 | // Define the AutodiffBackend based on the custom MyBackend type 17 | type MyAutodiffBackend = burn::backend::Autodiff; 18 | 19 | // Initialize the GPU device for computation 20 | // let device = burn::backend::wgpu::WgpuDevice::default(); 21 | let device = Default::default(); 22 | 23 | // Set training hyperparameters 24 | let batch_size = 1000; // Number of samples per batch during training 25 | let num_samples = 1000; // Total number of samples in the dataset 26 | let num_features = 100; // Number of features (dimensions) for each sample 27 | let k_neighbors = 10; // Number of nearest neighbors for the UMAP algorithm 28 | let output_size = 2; // Number of output dimensions (e.g., 2D for embeddings) 29 | let hidden_sizes = vec![100, 100, 100]; // Size of the hidden layer in the neural network 30 | let learning_rate = 0.001; // Learning rate for optimization 31 | let beta1 = 0.9; // Beta1 
parameter for the Adam optimizer 32 | let beta2 = 0.999; // Beta2 parameter for the Adam optimizer 33 | 34 | // let epochs = 400; // Number of training epochs 35 | let epochs = 100; // Number of training epochs 36 | let seed = 9999; // Random seed to ensure reproducibility 37 | let verbose = true; // Whether to enable the progress bar during training 38 | 39 | // let patience = 10; // Number of epochs without improvement before early stopping 40 | let min_desired_loss = 0.001; // Minimum loss threshold for early stopping 41 | let timeout = 60; 42 | 43 | // let metric = "euclidean_knn"; // Distance metric used for the nearest neighbor search 44 | let metric = Metric::Euclidean; // Alternative metric for neighbors search 45 | 46 | // Seed the random number generator to ensure reproducibility 47 | MyAutodiffBackend::seed(seed); 48 | 49 | // Generate random test data for training 50 | let train_data: Vec = generate_test_data(num_samples, num_features); 51 | 52 | // Configure the UMAP model with the specified input size, hidden layer size, and output size 53 | let model_config = UMAPModelConfigBuilder::default() 54 | .input_size(num_features) 55 | .hidden_sizes(hidden_sizes) 56 | .output_size(output_size) 57 | .build() 58 | .unwrap(); 59 | 60 | // Initialize the UMAP model with the defined configuration and the selected device 61 | // let model: UMAPModel = UMAPModel::new(&model_config, &device); 62 | let model: UMAPModel = UMAPModel::new(&model_config, &device); 63 | 64 | // Set up the training configuration with the specified hyperparameters 65 | let config = TrainingConfig::builder() 66 | .with_epochs(epochs) // Set the number of epochs for training 67 | .with_batch_size(batch_size) // Set the batch size for training 68 | .with_learning_rate(learning_rate) // Set the learning rate for the optimizer 69 | .with_beta1(beta1) // Set the beta1 parameter for the Adam optimizer 70 | .with_beta2(beta2) // Set the beta2 parameter for the Adam optimizer 71 | .with_verbose(verbose) // Enable or disable the progress bar 72 | // .with_patience(patience) // Set the patience for early stopping 73 | .with_metric(metric.into()) // Set the metric for nearest neighbors (e.g., Euclidean) 74 | .with_k_neighbors(k_neighbors) // Set the number of neighbors to consider for UMAP 75 | .with_min_desired_loss(min_desired_loss) // Set the minimum desired loss for early stopping 76 | .with_timeout(timeout) // set timeout in seconds 77 | .build() 78 | .expect("Failed to build TrainingConfig"); 79 | 80 | // Start training the UMAP model with the specified training data and configuration 81 | let (model, _, _) = train( 82 | "advanced", 83 | model, // The model to train 84 | num_samples, // Total number of training samples 85 | num_features, // Number of features per sample 86 | train_data.clone(), // The training data 87 | &config, // The training configuration 88 | device.clone(), 89 | exit_rx, 90 | ); 91 | 92 | // Validate the trained model after training 93 | let model = model.valid(); 94 | 95 | // Convert the training data into a tensor for model input 96 | let global = convert_vector_to_tensor(train_data, num_samples, num_features, &device); 97 | 98 | // Perform a forward pass through the model to obtain the low-dimensional (local) representation 99 | let local = model.forward(global.clone()); 100 | 101 | // Optionally, print the global and local tensors for inspection (currently commented out) 102 | // if verbose { 103 | // print_tensor_with_title("global", &global); 104 | // print_tensor_with_title("local", &local); 
105 | // } 106 | 107 | // Visualize the 2D embedding (local representation) using a chart 108 | chart::chart_tensor(local, None, None); 109 | } 110 | -------------------------------------------------------------------------------- /src/kernels/euclidean/kernel.rs: -------------------------------------------------------------------------------- 1 | use cubecl::{cube, prelude::*}; 2 | 3 | #[cube(launch)] 4 | pub fn euclidean_pairwise_distance_kernel( 5 | x: &Tensor, // Input tensor of shape (n, d) representing n vectors of dimension d 6 | output: &mut Tensor, // Output tensor of shape (n, n) to store pairwise distances 7 | ) { 8 | let row = ABSOLUTE_POS_X; // Row index for the pairwise computation 9 | let col = ABSOLUTE_POS_Y; // Column index for the pairwise computation 10 | 11 | let n = x.shape(0); // Number of vectors (rows) in the output tensor 12 | let d = x.shape(1); // Dimension of each vector (features) in the input tensor 13 | 14 | let mut exit_early = false; 15 | 16 | if row >= n || col >= n || row > col { 17 | // Skip threads that are out of bounds or handle only the upper triangular matrix 18 | exit_early = true; 19 | } 20 | 21 | // Edge case 1: Handle empty input tensor (n == 0 or d == 0) 22 | if n == 0 || d == 0 { 23 | // No computation needed for empty tensor 24 | exit_early = true; 25 | } 26 | 27 | // Edge case 2: Handle single vector case (n == 1) 28 | if n == 1 { 29 | output[0] = F::new(0.0); // Distance between the only vector and itself is 0 30 | exit_early = true; 31 | } 32 | 33 | // Edge case 3: Handle zero-dimensional vectors (d == 0) 34 | if d == 0 && !exit_early { 35 | // If vectors have 0 dimensions, the distance between any two vectors is trivially 0 36 | for i in 0..n { 37 | for j in i..n { 38 | output[i * n + j] = F::new(0.0); 39 | output[j * n + i] = F::new(0.0); // Symmetry: dist(i, j) = dist(j, i) 40 | } 41 | } 42 | exit_early = true; 43 | } 44 | 45 | if !exit_early { 46 | let mut sum = F::new(0.0); // Sum of squared differences 47 | 48 | // Compute the squared differences between vectors row and col 49 | for i in 0..d { 50 | let index_row = row * d + i; // Linear index for row, dimension i 51 | let index_col = col * d + i; // Linear index for col, dimension i 52 | 53 | let diff = x[index_row] - x[index_col]; 54 | sum += diff * diff; 55 | } 56 | 57 | // Calculate Euclidean distance (square root of sum of squared differences) 58 | let dist = F::sqrt(sum); 59 | 60 | // Linear index for the output tensor 61 | let output_index = row * n + col; 62 | 63 | // Store the pairwise Euclidean distance in the output tensor 64 | output[output_index] = dist; 65 | 66 | // Symmetry: dist(i, j) = dist(j, i) 67 | if row != col { 68 | // Avoid redundant assignments when row == col 69 | let output_index_sym = col * n + row; 70 | output[output_index_sym] = dist; 71 | } 72 | } 73 | } 74 | 75 | #[cube(launch)] 76 | pub fn euclidean_pairwise_distance_backward_kernel( 77 | output: &Tensor, // Output tensor (n, d), pairwise distances 78 | grad_output: &mut Tensor, // Gradient of the loss with respect to output tensor (n, d) 79 | grad_x: &Tensor, // Gradient of the loss with respect to input tensor (n, n) 80 | ) { 81 | let row = ABSOLUTE_POS_X; // Row index for the pairwise computation 82 | let col = ABSOLUTE_POS_Y; // Column index for the pairwise computation 83 | 84 | // Get the number of vectors (n) and the dimension (d) of each vector 85 | let n = output.shape(0); // Number of vectors (rows) in the input tensor 86 | let d = output.shape(1); // Dimension of each vector (features) in 
the input tensor 87 | 88 | let mut exit_early = false; 89 | 90 | // Edge case 1: Handle empty input tensor (n == 0 or d == 0) 91 | if n == 0 || d == 0 { 92 | // No computation needed for empty tensor 93 | exit_early = true; 94 | } 95 | 96 | // Edge case 2: Handle zero-dimensional vectors (d == 0) 97 | if d == 0 { 98 | // grad_output should already be zeroed out 99 | exit_early = true; 100 | } 101 | 102 | // Edge case: Ensure row and col are within bounds 103 | if row >= n || col >= n || row > col { 104 | // Skip threads that are out of bounds 105 | exit_early = true; 106 | } 107 | 108 | // Get the pairwise distance between vectors row and col 109 | let dist = output[row * n + col]; 110 | 111 | // Handle small distances (to avoid division by zero) 112 | let epsilon = F::new(1e-8); // Define a small epsilon value 113 | let dist = F::max(dist, epsilon); // Ensure dist is never less than epsilon 114 | 115 | // Skip if the distance is 0 (identical vectors) 116 | if dist < epsilon && !exit_early { 117 | for i in 0..d { 118 | let index_row = row * d + i; // Linear index for row, dimension i 119 | let index_col = col * d + i; // Linear index for col, dimension i 120 | grad_output[index_row] = F::new(0.0); 121 | grad_output[index_col] = F::new(0.0); 122 | } 123 | 124 | // No gradient to propagate for identical vectors 125 | exit_early = true; 126 | } 127 | 128 | if !exit_early { 129 | if row != col { 130 | // Compute the gradient of the pairwise distance w.r.t the input vectors 131 | for i in 0..d { 132 | let index_row = row * d + i; // Linear index for row, dimension i 133 | let index_col = col * d + i; // Linear index for col, dimension i 134 | 135 | let diff = output[index_row] - output[index_col]; // Difference between the vectors 136 | 137 | // Gradient of the distance w.r.t x_{i,k} 138 | let grad_dist_i = grad_x[row * n + col] * (diff / dist); // Scale the gradient 139 | 140 | // Propagate the gradient to the input tensor 141 | // grad_output is the gradient of the loss with respect to input tensor 142 | grad_output[index_row] += grad_dist_i; // Gradient w.r.t row vector (x_i) 143 | grad_output[index_col] -= grad_dist_i; // Gradient w.r.t col vector (x_j) 144 | } 145 | } 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | use backend::AutodiffBackend; 2 | use burn::module::AutodiffModule; 3 | 4 | use crossbeam_channel::Receiver; 5 | use model::{UMAPModel, UMAPModelConfigBuilder}; 6 | use num::Float; 7 | use train::*; 8 | use utils::*; 9 | 10 | use burn::tensor::{Device, Tensor}; 11 | 12 | pub mod backend; 13 | pub mod chart; 14 | pub mod distances; 15 | pub mod kernels; 16 | pub mod macros; 17 | pub mod model; 18 | pub mod normalizer; 19 | pub mod prelude; 20 | pub mod train; 21 | pub mod utils; 22 | 23 | /// Struct representing the UMAP (Uniform Manifold Approximation and Projection) model. 24 | /// 25 | /// This struct contains the model and the device (e.g., CPU or GPU) used for computation. 26 | /// The `fit` method trains the model, and the `transform` method projects the data into a lower-dimensional space. 27 | pub struct UMAP { 28 | model: UMAPModel, // UMAP model that performs dimensionality reduction 29 | device: Device, // Device to run the computation on (CPU, GPU) 30 | } 31 | 32 | impl UMAP { 33 | /// Trains the UMAP model on the given data and returns a fitted UMAP model. 
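///
/// Note that `fit` trains with fixed internal defaults (batch size 1, one hidden
/// layer of 100 units, learning rate 1e-3, Adam betas 0.9/0.999, 100 epochs,
/// seed 9999, as set in the body below); call `train::train` with your own
/// `TrainingConfig`, as examples/advanced.rs does, for full control.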
34 | /// 35 | /// # Arguments 36 | /// * `data` - A vector of vectors, where each inner vector represents a data sample with multiple features. 37 | /// * `device` - The device (CPU or GPU) on which to perform training. 38 | /// 39 | /// # Returns 40 | /// A trained `UMAP` model. 41 | /// 42 | /// This method initializes the model configuration, sets up the training parameters (like batch size, learning rate, etc.), 43 | /// and runs the training process using the provided data. It returns an instance of the `UMAP` struct containing 44 | /// the trained model and the device. 45 | pub fn fit( 46 | data: Vec>, 47 | device: Device, 48 | output_size: usize, 49 | exit_rx: Receiver<()>, 50 | ) -> Self 51 | where 52 | F: num::FromPrimitive + burn::tensor::Element, 53 | { 54 | let default_name = "model"; 55 | // Set training parameters 56 | let batch_size = 1; 57 | let num_samples = data.len(); 58 | let num_features = data[0].len(); 59 | // let output_size = 2; // UMAP typically reduces the data to 2 dimensions 60 | let hidden_sizes = vec![100]; // Size of the hidden layers in the model 61 | let learning_rate = 0.001; // Learning rate for optimization 62 | let beta1 = 0.9; // Beta1 parameter for Adam optimizer 63 | let beta2 = 0.999; // Beta2 parameter for Adam optimizer 64 | let epochs = 100; // Number of epochs for training 65 | let seed = 9999; // Random seed for reproducibility 66 | 67 | B::seed(seed); // Set the seed for the backend 68 | 69 | // Flatten the input data into a single vector of f64 values 70 | let train_data: Vec = data.into_iter().flatten().map(|f| f).collect(); 71 | 72 | // Build the model configuration 73 | let model_config = UMAPModelConfigBuilder::default() 74 | .input_size(num_features) 75 | .hidden_sizes(hidden_sizes) 76 | .output_size(output_size) 77 | .build() 78 | .unwrap(); 79 | 80 | // Initialize the UMAP model 81 | let model: UMAPModel = UMAPModel::new(&model_config, &device); 82 | 83 | // Build the training configuration 84 | let config = TrainingConfig::builder() 85 | .with_epochs(epochs) 86 | .with_batch_size(batch_size) 87 | .with_learning_rate(learning_rate) 88 | .with_beta1(beta1) 89 | .with_beta2(beta2) 90 | .build() 91 | .expect("Failed to build TrainingConfig"); 92 | 93 | // Start training 94 | let (model, _losses, _best_loss): (UMAPModel, Vec, F) = train( 95 | default_name, 96 | model, 97 | num_samples, 98 | num_features, 99 | train_data.clone(), 100 | &config, 101 | device.clone(), 102 | exit_rx, 103 | ); 104 | 105 | // Validate the trained model 106 | let model: UMAPModel = model.valid(); 107 | 108 | // Return the fitted UMAP model wrapped in the UMAP struct 109 | let umap = UMAP { model, device }; 110 | 111 | umap 112 | } 113 | 114 | /// Transforms the input data to a tensor using the trained UMAP model. 115 | /// 116 | /// # Arguments 117 | /// * `data` - A vector of vectors, where each inner vector represents a data sample with multiple features. 118 | /// 119 | /// # Returns 120 | /// A tensor of shape `[num_samples, num_output_features]` representing the transformed data. 121 | /// 122 | /// This method converts the input data into a tensor, passes it through the model to obtain 123 | /// the low-dimensional representation (local space), and returns the result as a tensor. 
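///
/// # Example
/// A usage sketch, mirroring the `umap` helper's doc example in src/prelude.rs:
/// ```rust,ignore
/// let data = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
/// let model = umap(data.clone()); // 2-D output by default
/// let embedding = model.transform_to_tensor(data); // tensor of shape [2, 2]
/// ```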
124 | pub fn transform_to_tensor(&self, data: Vec>) -> Tensor { 125 | let num_samples = data.len(); 126 | let num_features = data[0].len(); 127 | 128 | // Flatten the input data into a vector of f64 values 129 | let train_data: Vec = data.into_iter().flatten().map(|f| f64::from(f)).collect(); 130 | 131 | // Convert the data into a tensor 132 | let global = convert_vector_to_tensor(train_data, num_samples, num_features, &self.device); 133 | 134 | // Perform the forward pass to get the local (low-dimensional) representation 135 | let local = self.model.forward(global); 136 | 137 | local 138 | } 139 | 140 | /// Transforms the input data into a lower-dimensional space and returns it as a vector of vectors. 141 | /// 142 | /// # Arguments 143 | /// * `data` - A vector of vectors, where each inner vector represents a data sample with multiple features. 144 | /// 145 | /// # Returns 146 | /// A vector of vectors, where each inner vector represents a low-dimensional representation of a data sample. 147 | /// 148 | /// This method is a higher-level abstraction that calls `transform_to_tensor`, converts the result 149 | /// back to a vector format for easier inspection, and returns the transformed data. 150 | pub fn transform(&self, data: Vec>) -> Vec> { 151 | let local = self.transform_to_tensor(data); 152 | let result = convert_tensor_to_vector(local); 153 | 154 | result 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /src/kernels/knn/forward.rs: -------------------------------------------------------------------------------- 1 | use super::kernel::*; 2 | use crate::kernels::DEFAULT_CUBE_DIM; 3 | use burn::tensor::{ops::FloatTensor, Shape}; 4 | use burn_cubecl::{ 5 | kernel::into_contiguous, tensor::CubeTensor, BoolElement, CubeBackend, CubeRuntime, 6 | FloatElement, IntElement, 7 | }; 8 | use cubecl::prelude::*; 9 | 10 | pub fn forward( 11 | pairwise_distances: FloatTensor>, 12 | k: u32, 13 | ) -> (CubeTensor, CubeTensor) { 14 | let pairwise_distances = into_contiguous(pairwise_distances.clone()); 15 | let client = pairwise_distances.client.clone(); 16 | let device = pairwise_distances.device.clone(); 17 | let dims = pairwise_distances.shape.dims.clone(); 18 | let n = dims[0]; 19 | 20 | // Allocate output tensors for indices and distances 21 | let indices_shape = Shape::from(vec![n, k as usize]); 22 | let distances_shape = Shape::from(vec![n, k as usize]); 23 | 24 | let indices_buffer = client.empty(indices_shape.num_elements() * std::mem::size_of::()); 25 | let distances_buffer = client.empty(distances_shape.num_elements() * std::mem::size_of::()); 26 | 27 | let indices: CubeTensor = CubeTensor::new_contiguous( 28 | client.clone(), 29 | device.clone(), 30 | indices_shape, 31 | indices_buffer, 32 | burn::tensor::DType::I64, 33 | ); 34 | let distances: CubeTensor = CubeTensor::new_contiguous( 35 | client.clone(), 36 | device.clone(), 37 | distances_shape, 38 | distances_buffer, 39 | F::dtype(), 40 | ); 41 | 42 | let local_shape = Shape::from(vec![k as usize]); // Local shape for k neighbors 43 | 44 | // Create the buffer and grad_pairwise_distances tensor 45 | let local_buffer = pairwise_distances 46 | .client 47 | .empty(local_shape.num_elements() * std::mem::size_of::()); 48 | 49 | let local_distances: CubeTensor = CubeTensor::new_contiguous( 50 | pairwise_distances.client.clone(), 51 | pairwise_distances.device.clone(), 52 | pairwise_distances.shape.clone(), 53 | local_buffer.clone(), 54 | F::dtype(), 55 | ); 56 | 57 | let local_indices: 
CubeTensor = CubeTensor::new_contiguous( 58 | pairwise_distances.client.clone(), 59 | pairwise_distances.device.clone(), 60 | pairwise_distances.shape.clone(), 61 | local_buffer, 62 | burn::tensor::DType::I64, 63 | ); 64 | 65 | // Launch the k-NN kernel 66 | let cube_dim = DEFAULT_CUBE_DIM; 67 | let cubes_needed_in_x = (n as f32 / cube_dim.x as f32).ceil() as u32; 68 | let cubes_needed_in_y = (k as f32 / cube_dim.y as f32).ceil() as u32; 69 | let cube_count = CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, 1); 70 | 71 | let vectorisation = 1; 72 | 73 | // Launch the k-NN kernel 74 | knn_kernel::launch::( 75 | &client, 76 | cube_count, 77 | cube_dim, 78 | pairwise_distances.as_tensor_arg::(vectorisation), // Pairwise distance matrix 79 | ScalarArg::new(k), // Number of neighbors 80 | local_distances.as_tensor_arg::(vectorisation), 81 | local_indices.as_tensor_arg::(vectorisation), 82 | indices.as_tensor_arg::(vectorisation), // Indices tensor 83 | distances.as_tensor_arg::(vectorisation), // Distances tensor 84 | ); 85 | 86 | (indices, distances) 87 | } 88 | 89 | pub fn backward( 90 | pairwise_distances: FloatTensor>, // Pairwise distance matrix (n, n) 91 | k: u32, // Number of nearest neighbors 92 | grad_output: FloatTensor>, // Gradient of the loss w.r.t the output 93 | ) -> FloatTensor> { 94 | // Convert the output tensor to a contiguous format for efficient access 95 | let pairwise_distances = into_contiguous(pairwise_distances); 96 | let n = pairwise_distances.shape.dims[0]; // Number of vectors 97 | let grad_output_shape = Shape::from(vec![n, k as usize]); // Gradient output shape for k neighbors 98 | 99 | // Create the buffer and grad_pairwise_distances tensor 100 | let buffer = pairwise_distances 101 | .client 102 | .empty(grad_output_shape.num_elements() * std::mem::size_of::()); 103 | 104 | let grad_pairwise_distances: CubeTensor = CubeTensor::new_contiguous( 105 | pairwise_distances.client.clone(), 106 | pairwise_distances.device.clone(), 107 | pairwise_distances.shape.clone(), 108 | buffer, 109 | F::dtype(), 110 | ); 111 | 112 | let local_shape = Shape::from(vec![k as usize]); // Local shape for k neighbors 113 | 114 | // Create the buffer and grad_pairwise_distances tensor 115 | let local_buffer = pairwise_distances 116 | .client 117 | .empty(local_shape.num_elements() * std::mem::size_of::()); 118 | 119 | let local_distances: CubeTensor = CubeTensor::new_contiguous( 120 | pairwise_distances.client.clone(), 121 | pairwise_distances.device.clone(), 122 | pairwise_distances.shape.clone(), 123 | local_buffer.clone(), 124 | F::dtype(), 125 | ); 126 | 127 | let local_indices: CubeTensor = CubeTensor::new_contiguous( 128 | pairwise_distances.client.clone(), 129 | pairwise_distances.device.clone(), 130 | pairwise_distances.shape.clone(), 131 | local_buffer, 132 | F::dtype(), 133 | ); 134 | 135 | // Calculate the number of blocks needed for the kernel launch 136 | let cube_dim = DEFAULT_CUBE_DIM; 137 | let cubes_needed_in_x = (n as f32 / cube_dim.x as f32).ceil() as u32; 138 | let cubes_needed_in_y = (n as f32 / cube_dim.y as f32).ceil() as u32; 139 | let cube_count = CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, 1); 140 | 141 | let vectorization = 1; // Use 1 for no vectorization 142 | 143 | // Launch the KNN backward kernel 144 | knn_backward_kernel::launch::( 145 | &pairwise_distances.client, 146 | cube_count, 147 | cube_dim, 148 | pairwise_distances.as_tensor_arg::(vectorization), 149 | ScalarArg::new(k), // Pass the value of k as an argument 150 | 
local_distances.as_tensor_arg::(vectorization), 151 | local_indices.as_tensor_arg::(vectorization), 152 | grad_output.as_tensor_arg::(vectorization), 153 | grad_pairwise_distances.as_tensor_arg::(vectorization), 154 | ); 155 | 156 | // Return the gradient of the pairwise distances 157 | grad_pairwise_distances 158 | } 159 | -------------------------------------------------------------------------------- /src/model.rs: -------------------------------------------------------------------------------- 1 | use burn::prelude::*; 2 | use nn::{Linear, LinearConfig, Relu}; 3 | use serde::{Deserialize, Serialize}; 4 | 5 | /// A neural network model with a configurable number of layers and dynamic sizes. 6 | /// The model can have multiple hidden layers, with each layer having its own configurable size. 7 | /// 8 | /// # Arguments 9 | /// * `B` - The backend type for tensor operations (e.g., `AutodiffBackend`) 10 | #[derive(Module, Debug)] 11 | pub struct UMAPModel { 12 | layers: Vec>, // Vector to store dynamic layers 13 | activation: Relu, // ReLU activation function 14 | } 15 | 16 | impl UMAPModel { 17 | /// Creates a new instance of `UMAPModel` with the specified configuration and device. 18 | /// 19 | /// # Arguments 20 | /// * `config` - Configuration struct containing the input size, hidden layer sizes, and output size. 21 | /// * `device` - The device on which the model should be initialized (e.g., CPU or GPU). 22 | /// 23 | /// # Returns 24 | /// A new `UMAPModel` instance initialized with the provided configuration. 25 | pub fn new(config: &UMAPModelConfig, device: &Device) -> Self { 26 | // Build the layers dynamically based on the hidden layer sizes. 27 | let mut layers = Vec::new(); 28 | let mut input_size = config.input_size; 29 | 30 | // Create hidden layers 31 | for &hidden_size in &config.hidden_sizes { 32 | layers.push( 33 | LinearConfig::new(input_size, hidden_size) 34 | .with_bias(true) 35 | .init(device), 36 | ); 37 | input_size = hidden_size; // Update input size for the next layer 38 | } 39 | 40 | // Add the output layer 41 | layers.push( 42 | LinearConfig::new(input_size, config.output_size) 43 | .with_bias(true) 44 | .init(device), 45 | ); 46 | 47 | // Initialize ReLU activation function 48 | let activation = Relu::new(); 49 | 50 | // Return the UMAPModel with the initialized layers 51 | UMAPModel { layers, activation } 52 | } 53 | 54 | /// Perform a forward pass through the model. 55 | /// 56 | /// # Arguments 57 | /// * `input` - A 2D tensor representing the input data, with shape (n_samples, n_features). 58 | /// 59 | /// # Returns 60 | /// A 2D tensor representing the output after passing through all the layers and activations. 61 | pub fn forward(&self, input: Tensor) -> Tensor { 62 | let mut x = input; 63 | 64 | // Forward pass through each layer with activation 65 | for (i, layer) in self.layers.iter().enumerate() { 66 | x = layer.forward(x); // Apply linear transformation 67 | 68 | // Apply activation only if it's not the last layer 69 | if i < self.layers.len() - 1 { 70 | x = self.activation.forward(x); // Apply ReLU activation 71 | } 72 | } 73 | 74 | x 75 | } 76 | } 77 | 78 | /// Configuration structure for creating a `UMAPModel`. 79 | /// 80 | /// # Fields 81 | /// * `input_size` - Number of input features. 82 | /// * `hidden_sizes` - Vector of sizes for the hidden layers. 83 | /// * `output_size` - Number of output features. 
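///
/// # Example
/// A sketch using the builder defined below:
/// ```rust,ignore
/// let config = UMAPModelConfig::builder()
///     .input_size(784) // e.g. flattened 28x28 MNIST images
///     .hidden_sizes(vec![100, 100])
///     .output_size(2)
///     .build()
///     .unwrap();
/// ```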
84 | #[derive(Debug, Clone, Serialize, Deserialize)] 85 | pub struct UMAPModelConfig { 86 | pub input_size: usize, // Number of input features 87 | pub hidden_sizes: Vec, // Sizes of hidden layers 88 | pub output_size: usize, // Number of output features 89 | } 90 | 91 | impl UMAPModelConfig { 92 | /// Creates a new builder for the `UMAPModelConfig`. 93 | /// 94 | /// # Returns 95 | /// A new `UMAPModelConfigBuilder` to configure the model. 96 | pub fn builder() -> UMAPModelConfigBuilder { 97 | UMAPModelConfigBuilder::default() 98 | } 99 | } 100 | 101 | /// Builder pattern for the `UMAPModelConfig` struct. 102 | /// 103 | /// # Fields 104 | /// * `input_size` - Option for the number of input features. 105 | /// * `hidden_sizes` - Option for the sizes of the hidden layers. 106 | /// * `output_size` - Option for the number of output features. 107 | #[derive(Debug, Clone)] 108 | pub struct UMAPModelConfigBuilder { 109 | input_size: Option, 110 | hidden_sizes: Option>, 111 | output_size: Option, 112 | } 113 | 114 | impl Default for UMAPModelConfigBuilder { 115 | fn default() -> Self { 116 | UMAPModelConfigBuilder { 117 | input_size: Some(100), 118 | hidden_sizes: Some(vec![100, 100, 100]), // Default to 3 hidden layers of size 100 119 | output_size: Some(2), 120 | } 121 | } 122 | } 123 | 124 | impl UMAPModelConfigBuilder { 125 | /// Set the input size for the model. 126 | /// 127 | /// # Arguments 128 | /// * `input_size` - The number of input features. 129 | /// 130 | /// # Returns 131 | /// The updated `UMAPModelConfigBuilder` with the specified input size. 132 | pub fn input_size(mut self, input_size: usize) -> Self { 133 | self.input_size = Some(input_size); 134 | self 135 | } 136 | 137 | /// Set the hidden layer sizes for the model. 138 | /// 139 | /// # Arguments 140 | /// * `hidden_sizes` - The sizes of the hidden layers. 141 | /// 142 | /// # Returns 143 | /// The updated `UMAPModelConfigBuilder` with the specified hidden sizes. 144 | pub fn hidden_sizes(mut self, hidden_sizes: Vec) -> Self { 145 | self.hidden_sizes = Some(hidden_sizes); 146 | self 147 | } 148 | 149 | /// Set the output size for the model. 150 | /// 151 | /// # Arguments 152 | /// * `output_size` - The number of output features. 153 | /// 154 | /// # Returns 155 | /// The updated `UMAPModelConfigBuilder` with the specified output size. 156 | pub fn output_size(mut self, output_size: usize) -> Self { 157 | self.output_size = Some(output_size); 158 | self 159 | } 160 | 161 | /// Build and return the final `UMAPModelConfig`. 162 | /// 163 | /// # Returns 164 | /// A `Result` containing the built `UMAPModelConfig` or an error message if required fields are missing. 
165 | pub fn build(self) -> Result { 166 | // Ensure that all required fields are set 167 | Ok(UMAPModelConfig { 168 | input_size: self 169 | .input_size 170 | .ok_or_else(|| "Input size must be set".to_string())?, 171 | hidden_sizes: self 172 | .hidden_sizes 173 | .ok_or_else(|| "Hidden sizes must be set".to_string())?, 174 | output_size: self 175 | .output_size 176 | .ok_or_else(|| "Output size must be set".to_string())?, 177 | }) 178 | } 179 | } 180 | -------------------------------------------------------------------------------- /examples/mnist.rs: -------------------------------------------------------------------------------- 1 | use burn::{backend::*, module::*, prelude::*}; 2 | use crossbeam_channel::unbounded; 3 | use fast_umap::chart; 4 | #[allow(unused)] 5 | use fast_umap::{ 6 | chart::*, 7 | model::*, 8 | prelude::*, 9 | train::{train, LossReduction}, 10 | utils::*, 11 | }; 12 | use mnist::*; 13 | use wgpu::WgpuRuntime; 14 | 15 | fn main() { 16 | let (exit_tx, exit_rx) = unbounded(); 17 | 18 | ctrlc::set_handler(move || exit_tx.send(()).expect("Could not send signal on channel.")) 19 | .expect("Error setting Ctrl-C handler"); 20 | 21 | // Define a custom backend type using Wgpu with 32-bit floating point precision and 32-bit integer type 22 | type MyBackend = burn::backend::wgpu::CubeBackend; 23 | 24 | // Define the AutodiffBackend based on the custom MyBackend type 25 | type MyAutodiffBackend = burn::backend::Autodiff; 26 | 27 | // Initialize the GPU device for computation 28 | let device = burn::backend::wgpu::WgpuDevice::default(); 29 | 30 | // Set training hyperparameters 31 | let batch_size = 1_000; // Number of samples per batch during training 32 | let num_samples = 10_000 as usize; // Total number of samples in the dataset 33 | 34 | // let num_samples = 50_000 as usize; // Total number of samples in the dataset 35 | 36 | let num_features = 28 * 28; // Number of features (dimensions) for each sample, size of each mnist image 37 | let k_neighbors = 15; // Number of nearest neighbors for the UMAP algorithm 38 | let output_size = 2; // Number of output dimensions (e.g., 2D for embeddings) 39 | let hidden_sizes = vec![1000]; // Size of the hidden layer in the neural network 40 | let learning_rate = 1e-4; // Learning rate for optimization 41 | let penalty = 1e-6; // penalty for the Adam optimizer 42 | let beta1 = 0.9; // Beta1 parameter for the Adam optimizer 43 | let beta2 = 0.999; // Beta2 parameter for the Adam optimizer 44 | let epochs = 1_000; // Number of training epochs 45 | let seed = 9999; // Random seed to ensure reproducibility 46 | let verbose = true; // Whether to enable the progress bar during training 47 | 48 | // let patience = 100; // Number of epochs without improvement before early stopping 49 | let min_desired_loss = 1e-4; // Minimum loss threshold for early stopping 50 | let metric = Metric::Euclidean; // Alternative metric for neighbors search 51 | let loss_reduction = LossReduction::Mean; 52 | // below 1.0 gives NaN loss. Mind that it's rounded to integer inside the function 53 | // let minkowski_p = 3.0; // 1 is manhattan, 2 is Euclidean 54 | let normalized = true; // to reduce math, and keep it at float 55 | 56 | // let timeout = 30; // timeout in seconds 57 | 58 | // Seed the random number generator to ensure reproducibility 59 | MyBackend::seed(seed); 60 | 61 | let Mnist { 62 | trn_img, 63 | trn_lbl, 64 | // tst_img, 65 | // tst_lbl, 66 | .. 
61 |     let Mnist {
62 |         trn_img,
63 |         trn_lbl,
64 |         // tst_img,
65 |         // tst_lbl,
66 |         ..
67 |     } = MnistBuilder::new()
68 |         .download_and_extract()
69 |         .label_format_digit()
70 |         .training_set_length(num_samples as u32)
71 |         // .validation_set_length(10_000)
72 |         // .test_set_length(10_000)
73 |         .finalize();
74 | 
75 |     // Alternatively, generate random test data for training:
76 |     // let train_data = generate_test_data(num_samples, num_features);
77 |     let train_data: Vec<f64> = trn_img.into_iter().map(|byte| byte as f64).collect();
78 | 
79 |     // Configure the UMAP model with the specified input size, hidden layer sizes, and output size
80 |     let model_config = UMAPModelConfigBuilder::default()
81 |         .input_size(num_features)
82 |         .hidden_sizes(hidden_sizes)
83 |         .output_size(output_size)
84 |         .build()
85 |         .unwrap();
86 | 
87 |     // Initialize the UMAP model with the defined configuration and the selected device
88 |     let model: UMAPModel<MyAutodiffBackend> = UMAPModel::new(&model_config, &device);
89 | 
90 |     // Set up the training configuration with the specified hyperparameters
91 |     let config = TrainingConfig::builder()
92 |         .with_epochs(epochs) // Set the number of epochs for training
93 |         .with_batch_size(batch_size) // Set the batch size for training
94 |         .with_learning_rate(learning_rate) // Set the learning rate for the optimizer
95 |         .with_beta1(beta1) // Set the beta1 parameter for the Adam optimizer
96 |         .with_beta2(beta2) // Set the beta2 parameter for the Adam optimizer
97 |         .with_verbose(verbose) // Enable or disable the progress bar
98 |         // .with_patience(patience) // Set the patience for early stopping
99 |         .with_metric(metric.into()) // Set the metric for nearest neighbors (e.g., Euclidean)
100 |         .with_k_neighbors(k_neighbors) // Set the number of neighbors to consider for UMAP
101 |         .with_min_desired_loss(min_desired_loss) // Set the minimum desired loss for early stopping
102 |         .with_loss_reduction(loss_reduction)
103 |         // .with_timeout(timeout) // Set the timeout in seconds
104 |         // .with_minkowski_p(minkowski_p)
105 |         .with_normalized(normalized)
106 |         .with_penalty(penalty)
107 |         .build()
108 |         .expect("Failed to build TrainingConfig");
109 | 
110 |     // Start training the UMAP model with the specified training data and configuration
111 |     let (model, _, _) = train(
112 |         "mnist",
113 |         model,              // The model to train
114 |         num_samples,        // Total number of training samples
115 |         num_features,       // Number of features per sample
116 |         train_data.clone(), // The training data
117 |         &config,            // The training configuration
118 |         device.clone(),
119 |         exit_rx,
120 |     );
121 | 
122 |     // Switch the trained model to validation (inference) mode
123 |     let model = model.valid();
124 | 
125 |     // Convert the training data into a tensor for model input
126 |     let global = convert_vector_to_tensor(train_data, num_samples, num_features, &device);
127 | 
128 |     // Perform a forward pass through the model to obtain the low-dimensional (local) representation
129 |     let local = model.forward(global.clone());
130 | 
131 |     // Optionally, print the global and local tensors for inspection (currently commented out)
132 |     // if verbose {
133 |     //     print_tensor_with_title("global", &global);
134 |     //     print_tensor_with_title("local", &local);
135 |     // }
136 | 
137 |     let chart_config = ChartConfigBuilder::default()
138 |         .caption("MNIST")
139 |         .path("mnist.png")
140 |         .build();
141 | 
142 |     let labels: Vec<String> = trn_lbl.iter().map(|digit| format!("{digit}")).collect();
143 | 
144 |     // Visualize the 2D embedding (local representation) using a chart
145 |     chart::chart_tensor(local, Some(labels), Some(chart_config));
146 | }
147 | 
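The same forward path used at the end of the example also serves for embedding unseen samples. A hypothetical continuation (it assumes `tst_img` was destructured from the `Mnist` struct instead of being commented out, and reuses `model`, `device`, and `num_features` from above):

```rust
// Embed a held-out batch with the already-trained model.
let test_data: Vec<f64> = tst_img.into_iter().map(|byte| byte as f64).collect();
let n_test = test_data.len() / num_features;

let test_tensor = convert_vector_to_tensor(test_data, n_test, num_features, &device);
let test_embedding = model.forward(test_tensor); // shape: (n_test, 2)
```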
-------------------------------------------------------------------------------- /src/kernels/knn/backward.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | 3 | use burn::{ 4 | backend::{ 5 | autodiff::{ 6 | checkpoint::{base::Checkpointer, strategy::CheckpointStrategy}, 7 | grads::Gradients, 8 | ops::{Backward, Ops, OpsKind}, 9 | NodeID, 10 | }, 11 | Autodiff, 12 | }, 13 | tensor::ops::{FloatTensor, IntTensor}, 14 | }; 15 | 16 | use crate::{backend::*, print_if, print_primitive_tensor}; 17 | 18 | const VERBOSE: bool = false; 19 | 20 | pub fn backward( 21 | pairwise_distances: FloatTensor>, 22 | k: u32, // Number of nearest neighbors 23 | ) -> (IntTensor>, FloatTensor>) { 24 | // println!("knn_backward"); 25 | // Create zero-sized struct for backward computation 26 | #[derive(Debug)] 27 | struct KnnBackward; 28 | 29 | // Implement the backward trait for the given backend B 30 | impl Backward for KnnBackward { 31 | type State = (NodeID, u32); // FloatTensor, 32 | 33 | fn backward( 34 | self, 35 | ops: Ops, 36 | grads: &mut Gradients, 37 | checkpointer: &mut Checkpointer, 38 | ) { 39 | let (node_pairwise_distances, k) = ops.state; // Retrieve pairwise_distances and output from the state 40 | 41 | // Fetch the gradient for the current node. 42 | let grad_output = grads.consume::(&ops.node); 43 | let pairwise_distances: FloatTensor = 44 | checkpointer.retrieve_node_output(node_pairwise_distances); 45 | 46 | if VERBOSE { 47 | println!("grad_output {grad_output:?}"); 48 | print_primitive_tensor::(&grad_output, 10, 10); 49 | println!("pairwise_distances {pairwise_distances:?}"); 50 | print_primitive_tensor::(&pairwise_distances, 10, 10); 51 | } 52 | 53 | // Perform the backward pass for the KNN operation 54 | let grad_pairwise_distances = B::knn_backward(pairwise_distances, k, grad_output); 55 | 56 | if VERBOSE { 57 | println!("===grad_pairwise_distances=== {grad_pairwise_distances:?}"); 58 | print_primitive_tensor::(&grad_pairwise_distances, 0, 0); 59 | } 60 | 61 | // Register the gradient for the pairwise_distances tensor 62 | grads.register::(node_pairwise_distances, grad_pairwise_distances); 63 | } 64 | } 65 | 66 | // Prepare the stateful operation 67 | let indicies = match KnnBackward 68 | .prepare::([pairwise_distances.node.clone()]) 69 | .compute_bound() 70 | .stateful() 71 | { 72 | OpsKind::Tracked(mut prep) => { 73 | // When at least one node is tracked, register the backward function 74 | let pairwise_distances_state = prep.checkpoint(&pairwise_distances); // Checkpoint pairwise_distances for future retrieval during the backward pass 75 | 76 | let (indicies, distances) = B::knn(pairwise_distances.clone().primitive, k); // Forward pass calculation 77 | print_if!(VERBOSE, "Forward pass indicies (Tracked): {:?}", indicies); // Debug: Print indicies shape 78 | print_if!(VERBOSE, "Forward pass distances (Tracked): {:?}", distances); // Debug: Print distances shape 79 | 80 | let state = (pairwise_distances_state, k); 81 | 82 | // TODO: this is a strange way to convert it 83 | let indicies = B::int_into_float(indicies); 84 | 85 | // The state now includes the checkpointed pairwise_distances and the output 86 | let indicies = prep.finish(state, indicies); // Finish with the computed output 87 | 88 | indicies 89 | } 90 | OpsKind::UnTracked(prep) => { 91 | // If no node is tracked, just do the forward calculation 92 | let output = B::knn(pairwise_distances.clone().primitive, k); // Forward pass calculation 93 | let (indicies, 
distances) = output; 94 | 95 | print_if!(VERBOSE, "Forward pass indicies (UnTracked): {:?}", indicies); // Debug: Print indicies shape 96 | print_if!( 97 | VERBOSE, 98 | "Forward pass distances (UnTracked): {:?}", 99 | distances 100 | ); // Debug: Print distances shape 101 | 102 | // TODO: this is a strange way to convert it 103 | let indicies = B::int_into_float(indicies); 104 | 105 | let indicies = prep.finish(indicies); 106 | 107 | indicies 108 | } 109 | }; 110 | 111 | let distances = match KnnBackward 112 | .prepare::([pairwise_distances.node.clone()]) 113 | .compute_bound() 114 | .stateful() 115 | { 116 | OpsKind::Tracked(mut prep) => { 117 | // When at least one node is tracked, register the backward function 118 | let pairwise_distances_state = prep.checkpoint(&pairwise_distances); // Checkpoint pairwise_distances for future retrieval during the backward pass 119 | 120 | let (indicies, distances) = B::knn(pairwise_distances.clone().primitive, k); // Forward pass calculation 121 | print_if!(VERBOSE, "Forward pass indicies (Tracked): {:?}", indicies); // Debug: Print indicies shape 122 | print_if!(VERBOSE, "Forward pass distances (Tracked): {:?}", distances); // Debug: Print distances shape 123 | 124 | let state = (pairwise_distances_state, k); 125 | 126 | // The state now includes the checkpointed pairwise_distances and the output 127 | let distances = prep.finish(state, distances); // Finish with the computed output 128 | 129 | distances 130 | } 131 | OpsKind::UnTracked(prep) => { 132 | // If no node is tracked, just do the forward calculation 133 | let output = B::knn(pairwise_distances.clone().primitive, k); // Forward pass calculation 134 | let (indicies, distances) = output; 135 | 136 | print_if!(VERBOSE, "Forward pass indicies (UnTracked): {:?}", indicies); // Debug: Print indicies shape 137 | print_if!( 138 | VERBOSE, 139 | "Forward pass distances (UnTracked): {:?}", 140 | distances 141 | ); // Debug: Print distances shape 142 | 143 | let distances = prep.finish(distances); 144 | 145 | distances 146 | } 147 | }; 148 | 149 | // Extract the inner tensor 150 | let inner_tensor = indicies.into_primitive(); 151 | let int_tensor = B::float_into_int(inner_tensor); 152 | 153 | // Convert the inner tensor to the autodiff backend 154 | let indicies: IntTensor> = IntTensor::>::from(int_tensor); 155 | 156 | (indicies, distances) 157 | } 158 | -------------------------------------------------------------------------------- /src/kernels/knn/kernel.rs: -------------------------------------------------------------------------------- 1 | use core::f32; 2 | use cubecl::{cube, prelude::*}; 3 | 4 | #[cube] 5 | pub fn u32_to_float(x: u32) -> f32 { 6 | f32::cast_from(x) 7 | } 8 | 9 | #[cube] 10 | fn float_to_int_kernel(input: &Tensor, output: &mut Tensor) { 11 | let size = input.len(); 12 | for i in 0..size { 13 | output[i] = I::cast_from(input[i]); 14 | } 15 | } 16 | 17 | const INFINITY: f32 = 3.40282347e+38; // Maximum value for f32, used as infinity 18 | 19 | #[cube(launch)] 20 | pub fn knn_kernel( 21 | pairwise_distances: &Tensor, // Pairwise distance matrix (n, n) 22 | k: u32, // Number of nearest neighbors to find 23 | local_distances: &mut Tensor, // for local distances storage, size of k 24 | local_indices: &mut Tensor, // for local indices storage, size of k 25 | indices: &mut Tensor, // Output tensor of shape (n, k) storing the indices of k nearest neighbors 26 | distances: &mut Tensor, // Output tensor of shape (n, k) storing the distances of k nearest neighbors 27 | ) { 28 | let row = 
ABSOLUTE_POS_X; // Row index for the pairwise computation 29 | let n = pairwise_distances.shape(0); // Number of vectors 30 | 31 | // Edge case: skip if the row is out of bounds 32 | if row >= n { 33 | // do nothing 34 | } else { 35 | // Pre-allocate arrays to store the k smallest distances and corresponding indices (as F) 36 | // let mut local_distances = Array::::new(k); // Array for storing k smallest distances 37 | // let mut local_indices = Array::::new(k); // Array for storing k smallest indices 38 | 39 | // Initialize arrays with values that will be replaced by actual data 40 | for i in 0..k { 41 | // Initialize distances to infinity and indices to an invalid value 42 | local_distances[i] = F::new(INFINITY); // f32::INFINITY Use F::infinity() to represent infinity 43 | local_indices[i] = I::cast_from(k); // Set to an invalid index (out of range) 44 | 45 | // local_indices[i] = F::cast_from(u32_to_float(k)); // Set to an invalid index (out of range) 46 | } 47 | 48 | // Iterate through all the pairwise distances for the current row 49 | // for col in 0..n { 50 | // if row != col { 51 | // // Skip self-comparison 52 | // let dist = pairwise_distances[row * n + col]; 53 | 54 | // // Find where to insert this distance in the sorted array of top-k distances 55 | // if dist < local_distances[k - 1] { 56 | // let mut i = k - 1; // Start from the last index 57 | 58 | // // Shift larger distances one step to the right to make space for the new distance 59 | // while i > 0 { 60 | // if dist < local_distances[i] { 61 | // local_distances[i] = local_distances[i - 1]; 62 | // local_indices[i] = local_indices[i - 1]; 63 | // } else { 64 | // break; 65 | // } 66 | // i -= 1; // Move to the previous index 67 | // } 68 | 69 | // // Insert the new distance at the correct position 70 | // local_distances[i] = dist; 71 | // // local_indices[i] = F::cast_from(u32_to_float(col)); // Store the corresponding index 72 | // } 73 | // } 74 | // } 75 | 76 | // Copy the results from local arrays into the output tensors 77 | for i in 0..k { 78 | distances[row * k + i] = local_distances[i]; 79 | indices[row * k + i] = local_indices[i]; 80 | } 81 | } 82 | } 83 | 84 | #[cube(launch)] 85 | pub fn knn_backward_kernel( 86 | pairwise_distances: &Tensor, // Pairwise distance matrix (n, n) 87 | k: u32, // Number of nearest neighbors to find 88 | local_distances: &mut Tensor, // for local distances storage, size of k 89 | local_indices: &mut Tensor, // for local indices storage, size of k 90 | grad_output: &Tensor, // Gradient of the loss w.r.t the output (distances and indices) 91 | grad_pairwise_distances: &mut Tensor, // Gradient of the loss w.r.t the input (pairwise distances) 92 | ) { 93 | let row = ABSOLUTE_POS_X; // Row index for the pairwise computation 94 | let n = pairwise_distances.shape(0); // Number of vectors 95 | 96 | // Edge case: skip if the row is out of bounds 97 | if row >= n { 98 | // do nothing 99 | } else { 100 | // Pre-allocate arrays to store the k smallest distances and corresponding indices 101 | // let mut local_distances = Array::::new(k); // Array for storing k smallest distances 102 | // let mut local_indices = Array::::new(k); // Array for storing k smallest indices 103 | 104 | // Initialize arrays with values that will be replaced by actual data 105 | for i in 0..k { 106 | local_distances[i] = F::new(INFINITY); // Use F::infinity() to represent infinity 107 | 108 | // local_indices[i] = F::from_int(k as i64); // Set to an invalid index (out of range) 109 | } 110 | 111 | // Retrieve k 
nearest neighbors' indices and distances for the current row 112 | for col in 0..n { 113 | if row != col { 114 | // Skip self-comparison 115 | let dist = pairwise_distances[row * n + col]; 116 | 117 | // Find where to insert this distance in the sorted array of top-k distances 118 | if dist < local_distances[k - 1] { 119 | let mut i = k - 1; // Start from the last index 120 | 121 | // Shift larger distances one step to the right to make space for the new distance 122 | while i > 0 { 123 | if dist < local_distances[i] { 124 | local_distances[i] = local_distances[i - 1]; 125 | local_indices[i] = local_indices[i - 1]; 126 | } else { 127 | break; 128 | } 129 | i -= 1; // Move to the previous index 130 | } 131 | 132 | // Insert the new distance at the correct position 133 | local_distances[i] = dist; 134 | // local_indices[i] = F::from_int(col as i64); // Store the corresponding index 135 | } 136 | } 137 | } 138 | 139 | // Compute gradients with respect to the pairwise distances 140 | for i in 0..k { 141 | let _neighbor_index = local_indices[i]; // Get the index of the neighbor (column) 142 | let grad_value = grad_output[row * k + i]; // Get the gradient from the output tensor 143 | 144 | // TODO: once we move indices to IntTensor, refactor the types 145 | // let neighbor_index: u32 = u32::cast_from(neighbor_index); 146 | 147 | // If grad_value is non-zero, propagate the gradient to the pairwise distance 148 | if grad_value != F::new(0.0) { 149 | let dist = local_distances[i]; // The distance between row and neighbor_index 150 | let epsilon = F::new(1e-8); // Small epsilon to avoid division by zero 151 | 152 | // To avoid division by zero, we ensure that the distance is never too small 153 | let dist = F::max(dist, epsilon); 154 | 155 | // Gradient of the pairwise distance with respect to the input tensor 156 | // The gradient is proportional to the inverse of the distance 157 | let grad_pairwise = grad_value / dist; 158 | 159 | // Propagate the gradient back to the pairwise distance matrix 160 | grad_pairwise_distances[row * n] += grad_pairwise; 161 | grad_pairwise_distances[n + row] += grad_pairwise; // Symmetry 162 | 163 | // grad_pairwise_distances[row * n + neighbor_index] += grad_pairwise; 164 | // grad_pairwise_distances[neighbor_index * n + row] += grad_pairwise; // Symmetry 165 | } 166 | } 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /examples/mnist_benchmark.rs: -------------------------------------------------------------------------------- 1 | use burn::{backend::*, module::AutodiffModule as _, prelude::*}; 2 | use crossbeam_channel::{unbounded, Receiver}; 3 | use fast_umap::{backend::AutodiffBackend, chart}; 4 | #[allow(unused)] 5 | use fast_umap::{ 6 | chart::*, 7 | model::*, 8 | prelude::*, 9 | train::{train, LossReduction}, 10 | utils::*, 11 | }; 12 | use mnist::*; 13 | use wgpu::WgpuRuntime; 14 | 15 | fn generate_model_name( 16 | prefix: &str, 17 | learning_rate: f64, 18 | batch_size: usize, 19 | penalty: f64, 20 | hidden_sizes: &Vec, 21 | epochs: usize, 22 | ) -> String { 23 | let hidden_sizes = hidden_sizes 24 | .iter() 25 | .map(|size| format!("{size}")) 26 | .collect::>() 27 | .join("_"); 28 | format!( 29 | "{prefix}_lr_{:.0e}_bs_{:04}_pen_{:.0e}_hs_{hidden_sizes}_ep_{epochs}", 30 | learning_rate, batch_size, penalty 31 | ) 32 | } 33 | 34 | fn execute( 35 | name: String, 36 | num_features: usize, 37 | num_samples: usize, 38 | hidden_sizes: Vec, 39 | output_size: usize, 40 | device: Device, 41 | train_data: Vec, 42 | 
labels: Vec, 43 | config: TrainingConfig, 44 | exit_rx: Receiver<()>, 45 | ) -> f64 { 46 | // Configure the UMAP model with the specified input size, hidden layer size, and output size 47 | let model_config = UMAPModelConfigBuilder::default() 48 | .input_size(num_features) 49 | .hidden_sizes(hidden_sizes) 50 | .output_size(output_size) 51 | .build() 52 | .unwrap(); 53 | 54 | // Initialize the UMAP model with the defined configuration and the selected device 55 | let model: UMAPModel = UMAPModel::new(&model_config, &device); 56 | 57 | // Start training the UMAP model with the specified training data and configuration 58 | let (model, _, best_loss) = train( 59 | name.as_str(), 60 | model, // The model to train 61 | num_samples, // Total number of training samples 62 | num_features, // Number of features per sample 63 | train_data.clone(), // The training data 64 | &config, // The training configuration 65 | device.clone(), 66 | exit_rx, 67 | ); 68 | 69 | // Validate the trained model after training 70 | let model = model.valid(); 71 | 72 | // Convert the training data into a tensor for model input 73 | let global = convert_vector_to_tensor(train_data, num_samples, num_features, &device); 74 | 75 | // Perform a forward pass through the model to obtain the low-dimensional (local) representation 76 | let local = model.forward(global.clone()); 77 | 78 | let chart_config = ChartConfigBuilder::default() 79 | .caption(name.as_str()) 80 | .path(format!("{name}.png").as_str()) 81 | .build(); 82 | 83 | // Visualize the 2D embedding (local representation) using a chart 84 | chart::chart_tensor(local, Some(labels), Some(chart_config)); 85 | 86 | best_loss 87 | } 88 | 89 | fn find_best_hyperparameters( 90 | num_features: usize, 91 | num_samples: usize, 92 | train_data: Vec, 93 | labels: Vec, 94 | device: Device, 95 | config: TrainingConfig, 96 | learning_rates: Vec, 97 | batch_sizes: Vec, 98 | penalties: Vec, 99 | hidden_size_options: Vec>, 100 | epochs_options: Vec, // Added epochs as an array of values 101 | exit_rx: Receiver<()>, 102 | ) -> (String, f64) { 103 | let mut best_loss = f64::MAX; 104 | let mut best_config = String::new(); 105 | let output_size = 2; // 2D UMAP 2 dimensions 106 | 107 | // let (exit_tx, exit_rx) = channel(); 108 | 109 | // ctrlc::set_handler(move || exit_tx.send(()).expect("Could not send signal on channel.")) 110 | // .expect("Error setting Ctrl-C handler"); 111 | 112 | // Iterate over all combinations of the hyperparameters 113 | for &learning_rate in &learning_rates { 114 | for &batch_size in &batch_sizes { 115 | for &penalty in &penalties { 116 | for hidden_sizes in &hidden_size_options { 117 | for &epochs in &epochs_options { 118 | // Generate a model name for the current combination 119 | let model_name = generate_model_name( 120 | "mnist", 121 | learning_rate, 122 | batch_size, 123 | penalty, 124 | hidden_sizes, 125 | epochs, 126 | ); 127 | 128 | println!("{model_name}"); 129 | 130 | // Modify the configuration with the current batch_size and epochs 131 | let mut current_config = config.clone(); // Copy the config for each iteration 132 | current_config.batch_size = batch_size; 133 | current_config.epochs = epochs; 134 | 135 | // Execute training and get the loss for this configuration 136 | let loss = execute::( 137 | model_name.clone(), 138 | num_features, 139 | num_samples, 140 | hidden_sizes.clone(), 141 | output_size, 142 | device.clone(), 143 | train_data.clone(), 144 | labels.clone(), 145 | current_config, 146 | exit_rx.clone(), 147 | ); 148 | 149 | // Update 
the best configuration if the current loss is smaller 150 | if loss < best_loss { 151 | best_loss = loss; 152 | best_config = model_name; 153 | } 154 | } 155 | } 156 | } 157 | } 158 | } 159 | 160 | (best_config, best_loss) 161 | } 162 | 163 | // Example usage in main: 164 | 165 | fn main() { 166 | let (exit_tx, exit_rx) = unbounded(); 167 | 168 | ctrlc::set_handler(move || exit_tx.send(()).expect("Could not send signal on channel.")) 169 | .expect("Error setting Ctrl-C handler"); 170 | 171 | // Define a custom backend type using Wgpu with 32-bit floating point precision and 32-bit integer type 172 | type MyBackend = burn::backend::wgpu::CubeBackend; 173 | 174 | // Define the AutodiffBackend based on the custom MyBackend type 175 | type MyAutodiffBackend = burn::backend::Autodiff; 176 | 177 | // Initialize the GPU device for computation 178 | let device = burn::backend::wgpu::WgpuDevice::default(); 179 | 180 | // Set training parameters and configuration 181 | let num_samples = 10_000 as usize; 182 | let num_features = 28 * 28; 183 | let k_neighbors = 15; 184 | let learning_rate = 1e-4; 185 | let penalty = 1e-6; 186 | let beta1 = 0.9; 187 | let beta2 = 0.999; 188 | let seed = 9999; 189 | let verbose = true; 190 | let min_desired_loss = 1e-4; 191 | let metric = Metric::Euclidean; 192 | let loss_reduction = LossReduction::Mean; 193 | let normalized = true; 194 | 195 | // Seed the random number generator to ensure reproducibility 196 | MyBackend::seed(seed); 197 | 198 | let Mnist { 199 | trn_img, trn_lbl, .. 200 | } = MnistBuilder::new() 201 | .download_and_extract() 202 | .label_format_digit() 203 | .training_set_length(num_samples as u32) 204 | .finalize(); 205 | 206 | let train_data: Vec = trn_img.into_iter().map(|byte| byte as f64).collect(); 207 | let labels: Vec = trn_lbl.iter().map(|digit| format!("{digit}")).collect(); 208 | 209 | // Set up the training configuration with the specified hyperparameters 210 | let config = TrainingConfig::builder() // This is a default value; will be overridden in the search 211 | .with_learning_rate(learning_rate) 212 | .with_beta1(beta1) 213 | .with_beta2(beta2) 214 | .with_verbose(verbose) 215 | .with_metric(metric.into()) 216 | .with_k_neighbors(k_neighbors) 217 | .with_min_desired_loss(min_desired_loss) 218 | .with_loss_reduction(loss_reduction) 219 | .with_normalized(normalized) 220 | .with_penalty(penalty) 221 | .build() 222 | .expect("Failed to build TrainingConfig"); 223 | 224 | // Define the arrays of hyperparameters to search 225 | let learning_rates = vec![1e-4, 1e-3, 1e-5]; 226 | let batch_sizes = vec![500, 1000, 2000]; 227 | let penalties = vec![1e-6, 1e-7, 1e-8]; 228 | let hidden_size_options = vec![ 229 | vec![100], 230 | vec![200], 231 | vec![300], 232 | vec![500], // One hidden layer with 500 neurons 233 | vec![1000], // One hidden layer with 1000 neurons 234 | vec![1500], // One hidden layer with 1500 neurons 235 | vec![100, 100], 236 | vec![200, 200], 237 | vec![300, 300], 238 | vec![500, 500], // Two hidden layers, each with 500 neurons 239 | vec![1000, 1000], // One hidden layer with 1000 neurons, another with 500 240 | vec![1500, 1500], // One hidden layer with 1000 neurons, another with 500 241 | vec![100, 100, 100], 242 | vec![200, 200, 200], 243 | vec![300, 300, 300], 244 | vec![500, 500, 500], 245 | ]; 246 | // let epochs_options = vec![100, 200, 500, 1000, 2000, 5000]; // Different epochs options to test 247 | let epochs_options = vec![500]; 248 | 249 | // Find the best hyperparameters 250 | let (best_config, best_loss) = 
find_best_hyperparameters::( 251 | num_features, 252 | num_samples, 253 | train_data, 254 | labels, 255 | device, 256 | config, 257 | learning_rates, 258 | batch_sizes, 259 | penalties, 260 | hidden_size_options, 261 | epochs_options, 262 | exit_rx, 263 | ); 264 | 265 | // Print the best configuration and its corresponding loss 266 | println!("Best model configuration: {}", best_config); 267 | println!("Best loss: {}", best_loss); 268 | } 269 | -------------------------------------------------------------------------------- /src/train/config.rs: -------------------------------------------------------------------------------- 1 | use std::fmt; 2 | 3 | #[derive(Debug, Clone)] 4 | pub enum LossReduction { 5 | Mean, 6 | Sum, 7 | } 8 | 9 | #[derive(Debug, Clone, PartialEq)] 10 | pub enum Metric { 11 | Euclidean, 12 | EuclideanKNN, 13 | // EuclideanWeighted, 14 | Manhattan, 15 | Cosine, 16 | // Correlation, 17 | // Hamming, 18 | // Jaccard, 19 | Minkowski, 20 | // Chebyshev, 21 | // Mahalnobis, 22 | // Spearman, // Spearman’s Rank Correlation Distance 23 | } 24 | 25 | // Implement From<&str> for Metric 26 | impl From<&str> for Metric { 27 | fn from(s: &str) -> Self { 28 | match s.to_lowercase().as_str() { 29 | "euclidean" => Metric::Euclidean, 30 | "euclideanknn" | "euclidean_knn" => Metric::EuclideanKNN, 31 | "manhattan" => Metric::Manhattan, 32 | "cosine" => Metric::Cosine, 33 | "minkowski" => Metric::Minkowski, 34 | _ => panic!("Invalid metric type: {}", s), 35 | } 36 | } 37 | } 38 | 39 | // Implement Display for Metric 40 | impl fmt::Display for Metric { 41 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 42 | match self { 43 | Metric::Euclidean => write!(f, "Euclidean"), 44 | Metric::EuclideanKNN => write!(f, "Euclidean KNN"), 45 | Metric::Manhattan => write!(f, "Manhattan"), 46 | Metric::Cosine => write!(f, "cosine"), 47 | Metric::Minkowski => write!(f, "minkowski"), 48 | } 49 | } 50 | } 51 | 52 | /// Configuration for training the UMAP model. 53 | /// 54 | /// This struct contains the hyperparameters and settings required to train the UMAP model. 55 | /// It includes options for the optimizer (e.g., learning rate, batch size, and beta parameters), 56 | /// device configuration (e.g., CPU or GPU), and additional features like verbosity, early stopping, 57 | /// and time limits for training. 58 | #[derive(Debug, Clone)] 59 | pub struct TrainingConfig { 60 | /// The distance metric to use for training the model (e.g., "euclidean", "manhattan"). 61 | pub metric: Metric, 62 | 63 | /// The total number of epochs to run during training. 64 | pub epochs: usize, 65 | 66 | /// The number of samples to process in each training batch. 67 | pub batch_size: usize, 68 | 69 | /// The learning rate for the optimizer (controls the step size for parameter updates). 70 | pub learning_rate: f64, 71 | 72 | /// The Beta1 parameter for the Adam optimizer (controls the first moment estimate). 73 | pub beta1: f64, 74 | 75 | /// The Beta2 parameter for the Adam optimizer (controls the second moment estimate). 76 | pub beta2: f64, 77 | 78 | /// The L2 regularization (weight decay) penalty to apply during training. 79 | pub penalty: f32, 80 | 81 | /// Whether to show detailed progress information during training (e.g., loss values, progress bars). 82 | pub verbose: bool, 83 | 84 | /// The number of epochs to wait for improvement before triggering early stopping. 85 | /// `None` disables early stopping. 
86 |     pub patience: Option<i32>,
87 | 
88 |     /// The method used to reduce the loss during training (e.g., mean or sum).
89 |     pub loss_reduction: LossReduction,
90 | 
91 |     /// The number of nearest neighbors to consider in the UMAP algorithm.
92 |     pub k_neighbors: usize,
93 | 
94 |     /// Optionally, the minimum desired loss to achieve before stopping early.
95 |     pub min_desired_loss: Option<f64>,
96 | 
97 |     /// The maximum time (in seconds) to allow for training. If `None`, there is no time limit.
98 |     pub timeout: Option<u64>,
99 | 
100 |     /// Whether to normalize the distance output.
101 |     pub normalized: bool,
102 | 
103 |     pub minkowski_p: f64, // Order `p` of the Minkowski metric (1 = Manhattan, 2 = Euclidean)
104 | }
105 | 
106 | impl TrainingConfig {
107 |     /// Creates a new builder for constructing a `TrainingConfig`.
108 |     ///
109 |     /// This method allows you to incrementally build a `TrainingConfig` by setting its fields.
110 |     pub fn builder() -> TrainingConfigBuilder {
111 |         TrainingConfigBuilder::default()
112 |     }
113 | }
114 | 
115 | /// Builder pattern for constructing a `TrainingConfig` with optional parameters.
116 | #[derive(Default)]
117 | pub struct TrainingConfigBuilder {
118 |     metric: Option<Metric>,
119 |     epochs: Option<usize>,
120 |     batch_size: Option<usize>,
121 |     learning_rate: Option<f64>,
122 |     beta1: Option<f64>,
123 |     beta2: Option<f64>,
124 |     penalty: Option<f32>,
125 |     verbose: Option<bool>,
126 |     patience: Option<i32>,
127 |     loss_reduction: Option<LossReduction>,
128 |     k_neighbors: Option<usize>,
129 |     min_desired_loss: Option<f64>,
130 |     timeout: Option<u64>,
131 |     normalized: Option<bool>,
132 |     minkowski_p: Option<f64>,
133 | }
134 | 
135 | impl TrainingConfigBuilder {
136 |     /// Set the distance metric for training (e.g., "Euclidean", "Manhattan").
137 |     pub fn with_metric(mut self, metric: Metric) -> Self {
138 |         self.metric = Some(metric);
139 |         self
140 |     }
141 | 
142 |     /// Set the number of epochs to train the model.
143 |     /// This defines how many times the entire dataset will be processed.
144 |     pub fn with_epochs(mut self, epochs: usize) -> Self {
145 |         self.epochs = Some(epochs);
146 |         self
147 |     }
148 | 
149 |     /// Set the batch size used during training.
150 |     /// The batch size determines how many samples are processed before updating the model weights.
151 |     pub fn with_batch_size(mut self, batch_size: usize) -> Self {
152 |         self.batch_size = Some(batch_size);
153 |         self
154 |     }
155 | 
156 |     /// Set the learning rate for the optimizer.
157 |     /// The learning rate controls the step size for each parameter update during training.
158 |     pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
159 |         self.learning_rate = Some(learning_rate);
160 |         self
161 |     }
162 | 
163 |     /// Set the beta1 parameter for the Adam optimizer.
164 |     /// Beta1 controls the moving average of the first moment (mean) of the gradients.
165 |     pub fn with_beta1(mut self, beta1: f64) -> Self {
166 |         self.beta1 = Some(beta1);
167 |         self
168 |     }
169 | 
170 |     /// Set the beta2 parameter for the Adam optimizer.
171 |     /// Beta2 controls the moving average of the second moment (uncentered variance) of the gradients.
172 |     pub fn with_beta2(mut self, beta2: f64) -> Self {
173 |         self.beta2 = Some(beta2);
174 |         self
175 |     }
176 | 
177 |     /// Set the L2 regularization (weight decay) penalty for the optimizer.
178 |     /// This helps prevent overfitting by penalizing large weights.
179 |     pub fn with_penalty(mut self, penalty: f32) -> Self {
180 |         self.penalty = Some(penalty);
181 |         self
182 |     }
183 | 
184 |     /// Set whether verbose output should be shown during training.
185 |     /// If `true`, detailed progress (e.g., loss, metrics) will be displayed during training.
186 |     pub fn with_verbose(mut self, verbose: bool) -> Self {
187 |         self.verbose = Some(verbose);
188 |         self
189 |     }
190 | 
191 |     /// Set the patience value for early stopping.
192 |     ///
193 |     /// If training does not improve the loss for `patience` consecutive epochs, training will stop early.
194 |     /// **Warning!** Setting a `patience` value will override the `epochs` parameter.
195 |     pub fn with_patience(mut self, patience: i32) -> Self {
196 |         self.patience = Some(patience);
197 |         self
198 |     }
199 | 
200 |     /// Set the loss reduction method.
201 |     /// This defines how the loss is reduced across batches (e.g., sum or mean).
202 |     pub fn with_loss_reduction(mut self, loss_reduction: LossReduction) -> Self {
203 |         self.loss_reduction = Some(loss_reduction);
204 |         self
205 |     }
206 | 
207 |     /// Set the number of nearest neighbors to use in the UMAP algorithm.
208 |     /// This parameter controls the neighborhood size used in the model's calculations.
209 |     pub fn with_k_neighbors(mut self, k_neighbors: usize) -> Self {
210 |         self.k_neighbors = Some(k_neighbors);
211 |         self
212 |     }
213 | 
214 |     /// Set the minimum desired loss for early stopping.
215 |     /// If the model reaches this loss value, training will stop early.
216 |     pub fn with_min_desired_loss(mut self, min_desired_loss: f64) -> Self {
217 |         self.min_desired_loss = Some(min_desired_loss);
218 |         self
219 |     }
220 | 
221 |     /// Set the maximum training time in seconds.
222 |     /// Training will be stopped once this time is exceeded.
223 |     pub fn with_timeout(mut self, timeout: u64) -> Self {
224 |         self.timeout = Some(timeout);
225 |         self
226 |     }
227 | 
228 |     pub fn with_normalized(mut self, normalized: bool) -> Self {
229 |         self.normalized = Some(normalized);
230 |         self
231 |     }
232 | 
233 |     /// The Minkowski function supports any positive value of `p`. For `p = 1`,
234 |     /// it computes Manhattan distance, and for `p = 2`, it computes Euclidean distance.
235 |     pub fn with_minkowski_p(mut self, minkowski_p: f64) -> Self {
236 |         self.minkowski_p = Some(minkowski_p);
237 |         self
238 |     }
239 | 
240 |     /// Finalize and create a `TrainingConfig` with the specified options.
241 |     ///
242 |     /// This method returns an `Option<TrainingConfig>`. In its current form it always returns
243 |     /// `Some`: every parameter that was not set explicitly falls back to a sensible default.
244 |     pub fn build(self) -> Option<TrainingConfig> {
245 |         Some(TrainingConfig {
246 |             metric: self.metric.unwrap_or(Metric::Euclidean), // Default to Euclidean if not set
247 |             epochs: self.epochs.unwrap_or(1000),              // Default to 1000 epochs if not set
248 |             batch_size: self.batch_size.unwrap_or(1000),      // Default to a batch size of 1000 if not set
249 |             learning_rate: self.learning_rate.unwrap_or(0.001), // Default to 0.001 if not set
250 |             beta1: self.beta1.unwrap_or(0.9),                 // Default beta1 if not set
251 |             beta2: self.beta2.unwrap_or(0.999),               // Default beta2 if not set
252 |             penalty: self.penalty.unwrap_or(1e-5),            // Default penalty if not set
253 |             verbose: self.verbose.unwrap_or(false),           // Default to false if not set
254 |             patience: self.patience,                          // Optional, no default
255 |             loss_reduction: self.loss_reduction.unwrap_or(LossReduction::Sum), // Default to Sum if not set
256 |             k_neighbors: self.k_neighbors.unwrap_or(15),      // Default to 15 if not set
257 |             min_desired_loss: self.min_desired_loss,          // Optional, no default
258 |             timeout: self.timeout,                            // Optional, no default
259 |             normalized: self.normalized.unwrap_or(true),      // Normalize distance output by default
260 |             minkowski_p: self.minkowski_p.unwrap_or(1.0),     // Default to 1.0, i.e. Manhattan distance
261 |         })
262 |     }
263 | }
264 | 
--------------------------------------------------------------------------------
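Since every builder field is optional, a sparse configuration is enough; `build()` fills in the defaults listed above. A short sketch (the `fast_umap::train` import path is inferred from the examples, which pull `LossReduction` from that module):

```rust
use fast_umap::train::{Metric, TrainingConfig};

// Override only what matters for this run.
let config = TrainingConfig::builder()
    .with_metric(Metric::EuclideanKNN)
    .with_k_neighbors(10)
    .with_epochs(200)
    .with_patience(20) // stop after 20 epochs without improvement
    .build()
    .expect("build() currently always returns Some");

assert_eq!(config.batch_size, 1000); // default applied
assert_eq!(config.minkowski_p, 1.0); // default order: Manhattan
```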
/src/distances.rs:
--------------------------------------------------------------------------------
1 | use burn::tensor::Tensor;
2 | 
3 | use crate::backend::Backend;
4 | 
5 | /// Calculate the pairwise Euclidean distance matrix for a given 2D tensor
6 | ///
7 | /// # Arguments
8 | /// * `x` - A 2D tensor of shape (n_samples, n_features) where each row is a sample and each column is a feature
9 | ///
10 | /// # Returns
11 | /// A 1D tensor of shape `(n_samples,)` taken from the upper-triangular part of the squared-distance matrix
12 | ///
13 | /// This function computes the pairwise squared Euclidean distances between samples by using broadcasting
14 | /// to efficiently subtract the samples from each other, squaring the differences, and summing across the features.
15 | pub fn euclidean<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 1> {
16 |     let n_samples = x.dims()[0]; // Number of samples (rows)
17 |     let _n_features = x.dims()[1]; // Number of features (columns)
18 | 
19 |     // Expand x to shapes that allow broadcasting for pairwise subtraction
20 |     let x_expanded = x.clone().unsqueeze::<3>(); // Shape: (1, n_samples, n_features)
21 |     let x_transposed = x.clone().unsqueeze_dim(1); // Shape: (n_samples, 1, n_features)
22 | 
23 |     // Compute pairwise differences using broadcasting
24 |     let diff = x_expanded - x_transposed; // Shape: (n_samples, n_samples, n_features)
25 | 
26 |     // Square the differences element-wise using powi_scalar
27 |     let squared_diff = diff.powi_scalar(2); // Element-wise squared differences
28 | 
29 |     // Sum across the feature dimension (axis 2), producing a shape of (n_samples, n_samples)
30 |     let pairwise_squared_distances = squared_diff.sum_dim(2); // Sum across the feature dimension
31 | 
32 |     // Keep only the upper triangular part (including the diagonal) of the distance matrix
33 |     let pairwise_distances = pairwise_squared_distances.triu(0);
34 | 
35 |     // Extract the first column (distances from the first sample to all others)
36 |     let distances = pairwise_distances
37 |         .slice([0..n_samples, 0..1])
38 |         .reshape([n_samples]);
39 | 
40 |     distances
41 | }
42 | 
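The heart of `euclidean` is the broadcast subtraction. The same pattern in isolation, as a sketch (`MyBackend` and `device` stand in for whichever burn backend the caller has configured):

```rust
use burn::tensor::Tensor;

// Two 2-D points: (0, 0) and (3, 4), whose squared distance is 25.
let x = Tensor::<MyBackend, 2>::from_floats([[0.0, 0.0], [3.0, 4.0]], &device);

let a = x.clone().unsqueeze::<3>(); // shape (1, 2, 2): a[0][j] = x[j]
let b = x.unsqueeze_dim(1);         // shape (2, 1, 2): b[i][0] = x[i]
let diff = a - b;                   // broadcasts to (2, 2, 2): diff[i][j] = x[j] - x[i]

// Square and sum over the feature axis: 25.0 off the diagonal, 0.0 on it.
let d2 = diff.powi_scalar(2).sum_dim(2); // shape (2, 2, 1)
```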
55 | /// 56 | /// # Example 57 | /// ```rust 58 | /// let x = Tensor::from([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]); 59 | /// let k = 2; 60 | /// let result = euclidean_knn(x, k); 61 | /// println!("{:?}", result); // Output: sum of squared distances for each sample to its 2 nearest neighbors 62 | /// ``` 63 | pub fn euclidean_knn(x: Tensor, k: usize) -> Tensor { 64 | let n_samples = x.dims()[0]; // Number of samples (rows) 65 | 66 | // Expand x to shapes that allow broadcasting for pairwise subtraction: 67 | let x_expanded = x.clone().unsqueeze::<3>(); // Shape: (1, n_samples, n_features) 68 | let x_transposed = x.clone().unsqueeze_dim(1); // Shape: (n_samples, 1, n_features) 69 | 70 | // Compute pairwise differences using broadcasting: 71 | let diff = x_expanded - x_transposed; 72 | 73 | // Element-wise square the differences: 74 | let squared_diff = diff.powi_scalar(2); // Shape: (n_samples, n_samples, n_features) 75 | 76 | // Sum along the feature dimension (axis 2) to get squared Euclidean distance: 77 | let pairwise_squared_distances = squared_diff.sum_dim(2); // Shape: (n_samples, n_samples) 78 | 79 | // Extract the upper triangular part (without diagonal) for efficient KNN calculation: 80 | let pairwise_distances = pairwise_squared_distances.triu(0); // Shape: (n_samples, n_samples) 81 | 82 | // Get the top K smallest distances for each sample (along axis 1): 83 | let (top_k_distances, _top_k_indices) = pairwise_distances.topk_with_indices(k, 1); 84 | 85 | // Sum the top K distances for each sample: 86 | let sum_of_top_k_distances: Tensor = top_k_distances.sum_dim(1).reshape([n_samples]); // Shape: (n_samples) 87 | 88 | sum_of_top_k_distances 89 | } 90 | 91 | pub fn manhattan(tensor: Tensor) -> Tensor { 92 | let n_samples = tensor.dims()[0]; 93 | // Sum the absolute difference along the rows (axis 1) 94 | let x = tensor 95 | .abs() // Take absolute value 96 | .sum_dim(1) 97 | .reshape([n_samples]); // Sum along axis 1 (columns) 98 | 99 | x 100 | } 101 | 102 | /// Computes the cosine similarity between each row of a 2D tensor and the first row. 103 | /// 104 | /// This function calculates the cosine similarity between each sample (row) in the input tensor 105 | /// and the first sample (first row). The cosine similarity is defined as the dot product of two 106 | /// vectors divided by the product of their magnitudes (L2 norms). The result is a 1D tensor where 107 | /// each element represents the cosine similarity between the corresponding row and the first row. 108 | /// 109 | /// # Arguments 110 | /// * `tensor` - A 2D tensor of shape `(n_samples, n_features)` representing the data. The function 111 | /// computes cosine similarity between each row (sample) and the first row. 112 | /// 113 | /// # Returns 114 | /// A 1D tensor of shape `(n_samples,)` containing the cosine similarities between the first row and 115 | /// each of the other rows in the input tensor. The values are in the range [-1, 1], where 1 indicates 116 | /// identical orientation, 0 indicates orthogonality, and -1 indicates opposite orientation. 117 | /// 118 | /// # Example 119 | /// ``` 120 | /// let tensor = Tensor::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]); 121 | /// let similarities = cosine(tensor); 122 | /// // `similarities` is a 1D tensor of cosine similarities between the first row and all other rows 123 | /// ``` 124 | /// 125 | /// # Notes 126 | /// The function uses the following steps: 127 | /// 1. Computes the L2 norm (magnitude) of the first row. 128 | /// 2. 
102 | /// Computes the cosine similarity between each row of a 2D tensor and the first row.
103 | ///
104 | /// This function calculates the cosine similarity between each sample (row) in the input tensor
105 | /// and the first sample (first row). The cosine similarity is defined as the dot product of two
106 | /// vectors divided by the product of their magnitudes (L2 norms). The result is a 1D tensor where
107 | /// each element represents the cosine similarity between the corresponding row and the first row.
108 | ///
109 | /// # Arguments
110 | /// * `tensor` - A 2D tensor of shape `(n_samples, n_features)` representing the data. The function
111 | ///   computes cosine similarity between each row (sample) and the first row.
112 | ///
113 | /// # Returns
114 | /// A 1D tensor of shape `(n_samples,)` containing the cosine similarities between the first row and
115 | /// each of the other rows in the input tensor. The values are in the range [-1, 1], where 1 indicates
116 | /// identical orientation, 0 indicates orthogonality, and -1 indicates opposite orientation.
117 | ///
118 | /// # Example
119 | /// ```
120 | /// let tensor = Tensor::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
121 | /// let similarities = cosine(tensor);
122 | /// // `similarities` is a 1D tensor of cosine similarities between the first row and all other rows
123 | /// ```
124 | ///
125 | /// # Notes
126 | /// The function uses the following steps:
127 | /// 1. Computes the L2 norm (magnitude) of the first row.
128 | /// 2. Computes the dot product of each row with the first row.
129 | /// 3. Computes the L2 norm of each row.
130 | /// 4. Divides the dot product by the product of the norms to compute cosine similarity.
131 | ///
132 | /// # Performance
133 | /// This function clones the tensor multiple times, which may impact performance for large tensors.
134 | /// Optimizations could be made to minimize memory allocations and cloning.
135 | pub fn cosine<B: Backend>(tensor: Tensor<B, 2>) -> Tensor<B, 1> {
136 |     let n_samples = tensor.dims()[0];
137 |     let n_features = tensor.dims()[1];
138 |     // First, get the first row to compare against
139 |     let first_row = tensor.clone().slice([0..1, 0..n_features]); // Select the first row
140 | 
141 |     // Compute the L2 norm of the first row manually (sqrt(sum(x^2)))
142 |     let first_row_norm = first_row.clone().powi_scalar(2).sum_dim(1).sqrt();
143 | 
144 |     // Compute the dot product of each row with the first row
145 |     let dot_product: Tensor<B, 2> = tensor.clone().mul(first_row.clone()); // Element-wise product with the first row
146 |     let dot_product: Tensor<B, 3> = dot_product.unsqueeze_dim(2);
147 |     let dot_product: Tensor<B, 2> = dot_product.sum_dim(1).reshape([n_samples, 1]); // Reshape to a column vector
148 | 
149 |     // Compute the L2 norm (magnitude) of each row manually (sqrt(sum(x^2)))
150 |     let row_norms = tensor.clone().powi_scalar(2).sum_dim(1).sqrt();
151 | 
152 |     // Compute the cosine similarity
153 |     let x = dot_product.div(row_norms).div(first_row_norm);
154 | 
155 |     x.reshape([n_samples])
156 | }
157 | 
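Because `cosine` compares every row against row 0, its output is easy to predict by hand. A sketch (placeholder backend and device):

```rust
use burn::tensor::Tensor;

// Row 1 is parallel to row 0; row 2 is orthogonal to it.
let t = Tensor::<MyBackend, 2>::from_floats([[1.0, 0.0], [2.0, 0.0], [0.0, 1.0]], &device);
let sims = cosine(t); // shape (3,): approximately [1.0, 1.0, 0.0]
```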
158 | /// Computes the Minkowski distance between each row of a tensor and the first row.
159 | ///
160 | /// The Minkowski distance is a generalized distance metric defined as:
161 | ///
162 | /// D(x, y) = (sum_i |x_i - y_i|^p)^(1/p)
163 | ///
164 | /// Where `x` and `y` are vectors (rows), and `p` is the order of the distance. When `p = 1`,
165 | /// this becomes the **Manhattan distance**, and when `p = 2`, it becomes the **Euclidean distance**.
166 | ///
167 | /// This function calculates the Minkowski distance between each row of the input tensor and the
168 | /// first row of the tensor. It returns a 1D tensor containing the computed distances for each row.
169 | ///
170 | /// # Arguments
171 | /// * `tensor` - A 2D tensor of shape `(n_samples, n_features)`, where each row represents a sample.
172 | /// * `p` - A scalar value representing the order of the Minkowski distance. `p` must be a positive number.
173 | ///
174 | /// # Returns
175 | /// A 1D tensor of shape `(n_samples,)` where each element is the Minkowski distance between the
176 | /// corresponding row of the input tensor and the first row.
177 | ///
178 | /// # Example
179 | /// ```
180 | /// let tensor = Tensor::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
181 | /// let distances = minkowski(tensor, 2.0);
182 | /// // `distances` will contain the Euclidean distances between each row and the first row.
183 | /// ```
184 | ///
185 | /// # Notes
186 | /// - The first row of the tensor is used as the reference row to compute distances.
187 | /// - The function supports any positive value of `p`. For `p = 1`, it computes Manhattan distance,
188 | ///   and for `p = 2`, it computes Euclidean distance.
189 | /// - The function works element-wise along rows and sums over features (columns) to compute the distance.
190 | ///
191 | /// # Performance
192 | /// The function clones the tensor to avoid modifying the original data. For large tensors, this may
193 | /// incur some overhead due to memory allocation. You may want to explore optimization techniques like
194 | /// in-place operations if memory usage is a concern.
195 | pub fn minkowski<B: Backend>(tensor: Tensor<B, 2>, p: f64) -> Tensor<B, 1> {
196 |     let n_samples = tensor.dims()[0];
197 |     let n_features = tensor.dims()[1];
198 | 
199 |     // Use the first row as the reference row
200 |     let reference_row = tensor.clone().slice([0..1, 0..n_features]);
201 | 
202 |     // Compute the element-wise absolute difference between the reference row and all rows
203 |     let diff = tensor.clone().sub(reference_row.clone()).abs();
204 | 
205 |     // Raise the absolute differences to the power of p
206 |     let diff_p = diff.powi_scalar(p);
207 | 
208 |     // Sum over the features (dim 1) for each row
209 |     let sum_p = diff_p.sum_dim(1);
210 | 
211 |     // Take the p-th root of the sum
212 |     let distances = sum_p.powi_scalar((1.0 / p).ceil() as i32); // The exponent must be an integer; a fractional exponent returns NaN here
213 | 
214 |     // Return the distances as a 1D tensor
215 |     distances.reshape([n_samples])
216 | }
217 | 
--------------------------------------------------------------------------------
/src/chart.rs:
--------------------------------------------------------------------------------
1 | use crate::utils::*;
2 | use burn::prelude::*;
3 | use hsl::HSL;
4 | use plotters::{
5 |     prelude::*,
6 |     style::text_anchor::{HPos, Pos, VPos},
7 | };
8 | use std::collections::HashSet;
9 | 
10 | /// The default caption for the chart
11 | const CAPTION: &str = "fast-umap";
12 | 
13 | /// The default path where the plot will be saved
14 | const PATH: &str = "plot.png";
15 | 
16 | /// Configuration structure for the chart, including caption, path, width, and height
17 | #[derive(Debug, Clone)]
18 | pub struct ChartConfig {
19 |     pub caption: String,
20 |     pub path: String,
21 |     pub width: u32,
22 |     pub height: u32,
23 | }
24 | 
25 | impl ChartConfig {
26 |     /// Builder pattern for configuring the chart
27 |     pub fn builder() -> ChartConfigBuilder {
28 |         ChartConfigBuilder {
29 |             caption: Some(CAPTION.to_string()),
30 |             path: Some(PATH.to_string()),
31 |             width: Some(1000),
32 |             height: Some(1000),
33 |         }
34 |     }
35 | }
36 | 
37 | impl Default for ChartConfig {
38 |     /// Default implementation for ChartConfig with preset values
39 |     fn default() -> Self {
40 |         ChartConfig {
41 |             caption: CAPTION.to_string(),
42 |             path: PATH.to_string(),
43 |             width: 1000,
44 |             height: 1000,
45 |         }
46 |     }
47 | }
48 | 
49 | /// Builder pattern for the `ChartConfig` struct to allow flexible configuration
50 | pub struct ChartConfigBuilder {
51 |     caption: Option<String>,
52 |     path: Option<String>,
53 |     width: Option<u32>,
54 |     height: Option<u32>,
55 | }
56 | 
57 | impl Default for ChartConfigBuilder {
58 |     fn default() -> Self {
59 |         ChartConfigBuilder {
60 |             caption: Some(CAPTION.into()),
61 |             path: Some(PATH.into()),
62 |             width: None,
63 |             height: None,
64 |         }
65 |     }
66 | }
67 | 
68 | impl ChartConfigBuilder {
69 |     /// Set the caption for the chart
70 |     pub fn caption(mut self, caption: &str) -> Self {
71 |         self.caption = Some(caption.to_string());
72 |         self
73 |     }
74 | 
75 |     /// Set the path where the chart will be saved
76 |     pub fn path(mut self, path: &str) -> Self {
77 |         self.path = Some(path.to_string());
78 |         self
79 |     }
80 | 
81 |     /// Set the width of the chart
82 |     pub fn width(mut self, width: u32) -> Self {
83 |         self.width = Some(width);
84 |         self
85 |     }
86 | 
87 |     /// Set the height of the chart
88 |     pub fn height(mut self, height: u32) -> Self {
89 |         self.height = Some(height);
90 |         self
91 |     }
92 | 
93 |     /// Build and return the final `ChartConfig`
94 |     pub fn build(self) -> ChartConfig {
95 |         ChartConfig {
96 |             caption: self.caption.unwrap_or_else(|| CAPTION.to_string()),
97 |             path: self.path.unwrap_or_else(|| PATH.to_string()),
98 |             width: self.width.unwrap_or(1000),
99 |             height: self.height.unwrap_or(1000),
100 |         }
101 |     }
102 | }
103 | 
104 | type Float = f64;
105 | 
106 | /// Plot the 2D chart using the given tensor data and optional chart configuration
107 | ///
108 | /// # Arguments
109 | /// * `data` - A 2D tensor of data points to plot
110 | /// * `config` - Optional custom chart configuration
111 | pub fn chart_tensor<B: Backend>(
112 |     data: Tensor<B, 2>,
113 |     labels: Option<Vec<String>>,
114 |     config: Option<ChartConfig>,
115 | ) {
116 |     // pub fn chart_tensor(data: Tensor<B, 2>, config: Option<ChartConfig>) {
117 |     let data: Vec<Vec<Float>> = convert_tensor_to_vector(data);
118 |     chart_vector(data, labels, config);
119 | }
120 | 
121 | /// Plot the loss curve over epochs and save it to a file
122 | ///
123 | /// # Arguments
124 | /// * `losses` - A vector of loss values over multiple epochs
125 | /// * `output_path` - Path where the plot will be saved
126 | pub fn plot_loss<F>(
127 |     losses: Vec<F>,
128 |     output_path: &str,
129 | ) -> Result<(), Box<dyn std::error::Error>>
130 | where
131 |     F: num::Float,
132 | {
133 |     // Calculate the min and max loss values
134 |     let min_loss = losses.iter().cloned().fold(F::infinity(), F::min);
135 |     let max_loss = losses.iter().cloned().fold(F::neg_infinity(), F::max);
136 | 
137 |     // Add padding to the min and max values for better visualization
138 |     let padding = F::from(0.1).unwrap(); // 10% padding, adjust as needed
139 |     let min_loss_with_padding = min_loss - padding * min_loss.abs();
140 |     let max_loss_with_padding = max_loss + padding * max_loss.abs();
141 |     let min_loss_with_padding = min_loss_with_padding.to_f64().unwrap();
142 |     let max_loss_with_padding = max_loss_with_padding.to_f64().unwrap();
143 | 
144 |     // Create a drawing area with a width of 800px and a height of 600px
145 |     let root = BitMapBackend::new(output_path, (800, 600)).into_drawing_area();
146 |     root.fill(&WHITE)?;
147 | 
148 |     // Create a chart builder with a padded Y-axis range
149 |     let mut chart = ChartBuilder::on(&root)
150 |         .caption("Loss Over Epochs", ("sans-serif", 30))
151 |         .set_label_area_size(LabelAreaPosition::Left, 80)
152 |         .set_label_area_size(LabelAreaPosition::Bottom, 50)
153 |         .build_cartesian_2d(
154 |             0..losses.len() as u32,
155 |             min_loss_with_padding..max_loss_with_padding,
156 |         )?;
157 | 
158 |     // Draw the chart axes and grid
159 |     chart
160 |         .configure_mesh()
161 |         .y_desc("Loss")
162 |         .x_desc("Epochs")
163 |         .draw()?;
164 | 
165 |     // Plot the losses as a line
166 |     chart
167 |         .draw_series(LineSeries::new(
168 |             (0..losses.len()).map(|x| (x as u32, losses[x].to_f64().unwrap())),
169 |             &BLUE,
170 |         ))?
171 | .label("Loss") 172 | .legend(move |(x, y)| PathElement::new(vec![(x, y)], &RED)); 173 | 174 | // Draw the legend 175 | chart.configure_series_labels().draw()?; 176 | 177 | // Format Y-axis labels to handle small floats 178 | chart.configure_mesh().y_labels(10).draw()?; 179 | 180 | Ok(()) 181 | } 182 | 183 | pub fn chart_vector( 184 | data: Vec>, 185 | labels: Option>, 186 | config: Option, 187 | ) { 188 | let config = config.unwrap_or(ChartConfig::default()); 189 | 190 | // Create the drawing area 191 | let root = BitMapBackend::new(&config.path, (config.width, config.height)).into_drawing_area(); 192 | root.fill(&WHITE).unwrap(); 193 | 194 | // Calculate min and max for x and y axes 195 | let min_x = data 196 | .iter() 197 | .flat_map(|v| v.iter().step_by(2)) 198 | .cloned() 199 | .min_by(|a, b| a.partial_cmp(b).unwrap()) 200 | .unwrap() as Float; 201 | let max_x = data 202 | .iter() 203 | .flat_map(|v| v.iter().step_by(2)) 204 | .cloned() 205 | .max_by(|a, b| a.partial_cmp(b).unwrap()) 206 | .unwrap() as Float; 207 | let min_y = data 208 | .iter() 209 | .flat_map(|v| v.iter().skip(1).step_by(2)) 210 | .cloned() 211 | .min_by(|a, b| a.partial_cmp(b).unwrap()) 212 | .unwrap() as Float; 213 | let max_y = data 214 | .iter() 215 | .flat_map(|v| v.iter().skip(1).step_by(2)) 216 | .cloned() 217 | .max_by(|a, b| a.partial_cmp(b).unwrap()) 218 | .unwrap() as Float; 219 | 220 | // Assign colors to unique labels if provided 221 | let mut label_colors: Vec<(String, RGBColor)> = Vec::new(); 222 | if let Some(labels) = labels.clone() { 223 | let unique_labels: Vec = labels 224 | .into_iter() 225 | .collect::>() 226 | .into_iter() 227 | .collect(); 228 | for (i, label) in unique_labels.iter().enumerate() { 229 | let hue = i as f64 * 360.0 / unique_labels.len() as f64; // Even distribution of hues 230 | let color = HSL { 231 | h: hue, 232 | s: 0.7, 233 | l: 0.6, 234 | } 235 | .to_rgb(); 236 | label_colors.push((label.clone(), RGBColor(color.0, color.1, color.2))); 237 | } 238 | } 239 | 240 | // Build chart 241 | let mut chart = ChartBuilder::on(&root) 242 | .caption(config.caption, ("sans-serif", 30)) 243 | .margin(40) 244 | .x_label_area_size(30) 245 | .y_label_area_size(30) 246 | .build_cartesian_2d(min_x..max_x, min_y..max_y) 247 | .unwrap(); 248 | 249 | // Configure and draw the mesh (axes) 250 | chart 251 | .configure_mesh() 252 | .x_desc("X Axis") 253 | .y_desc("Y Axis") 254 | .x_labels(10) 255 | .y_labels(10) 256 | .draw() 257 | .unwrap(); 258 | 259 | // Store series for later adding to the legend 260 | let mut series_list: Vec<(String, Vec<(f64, f64)>, RGBColor)> = Vec::new(); 261 | 262 | // Draw data points and labels 263 | chart 264 | .draw_series(data.iter().enumerate().map(|(i, values)| { 265 | let label = labels 266 | .clone() 267 | .map(|l| l.get(i).cloned()) 268 | .flatten() 269 | .unwrap_or_else(|| "".into()); 270 | let color = label_colors 271 | .iter() 272 | .find(|(l, _)| *l == label) 273 | .map(|(_, c)| *c) 274 | .unwrap_or(RED); 275 | 276 | // Store series data for the legend 277 | if !label.is_empty() { 278 | let series_data = series_list.iter_mut().find(|(l, _, _)| *l == label); 279 | match series_data { 280 | Some((_, series_points, _)) => series_points.push((values[0], values[1])), 281 | None => series_list.push((label.clone(), vec![(values[0], values[1])], color)), 282 | } 283 | } 284 | 285 | // Draw circle for each point 286 | Circle::new( 287 | (values[0], values[1]), 288 | 3, 289 | ShapeStyle { 290 | color: color.into(), 291 | filled: false, 292 | stroke_width: 1, 293 
| }, 294 | ) 295 | })) 296 | .unwrap(); 297 | 298 | // Add the legend manually 299 | if labels.is_some() { 300 | // Sort the series list alphabetically by label 301 | series_list.sort_by(|a, b| { 302 | let a = a.0.parse::().unwrap(); 303 | let b = b.0.parse::().unwrap(); 304 | a.cmp(&b) 305 | // a.0.cmp(&b.0) 306 | }); 307 | 308 | let spacing_y = (max_y - min_y) / (series_list.len() * 2) as f64; 309 | 310 | // Define the starting position for the legend 311 | let mut legend_position = (min_x + (max_x - min_x) * 0.8, max_y - (max_y - min_y) * 0.1); 312 | // let spacing = 10.0; // Increase the spacing 313 | let size = 5.0; // Make the circles slightly larger 314 | let font_size = 15.0; 315 | 316 | for (label, _, color) in series_list { 317 | // Draw a colored circle for each label in the legend 318 | chart 319 | .draw_series(std::iter::once(Circle::new( 320 | legend_position, 321 | size, 322 | ShapeStyle { 323 | color: color.into(), 324 | filled: true, 325 | stroke_width: 1, 326 | }, 327 | ))) 328 | .unwrap(); 329 | 330 | let style = TextStyle { 331 | font: ("sans-serif", font_size).into_font(), 332 | color: BLACK.to_backend_color(), 333 | pos: Pos::new(HPos::Left, VPos::Center), 334 | }; 335 | 336 | // Draw the label text next to the circle 337 | chart 338 | .draw_series(std::iter::once(Text::new( 339 | label, 340 | (legend_position.0 + spacing_y / 4.0, legend_position.1), 341 | style, 342 | ))) 343 | .unwrap(); 344 | 345 | // Move the position for the next legend item downwards 346 | legend_position.1 -= spacing_y; 347 | } 348 | } 349 | 350 | // Save the chart to file 351 | root.present().unwrap(); 352 | } 353 | -------------------------------------------------------------------------------- /src/utils.rs: -------------------------------------------------------------------------------- 1 | use core::f64; 2 | 3 | use burn::{ 4 | prelude::Backend, 5 | tensor::{ops::FloatTensor, Device, Tensor, TensorData, TensorPrimitive}, 6 | }; 7 | use num::{Float, FromPrimitive}; 8 | // use prettytable::{row, Table}; 9 | use rand::{distr::uniform::SampleUniform, Rng}; 10 | use rayon::prelude::*; 11 | 12 | /// Generates random test data with the given number of samples and features. 13 | /// 14 | /// # Arguments 15 | /// * `num_samples` - The number of samples to generate. 16 | /// * `num_features` - The number of features (columns) per sample. 17 | /// 18 | /// # Returns 19 | /// A `Vec` containing the randomly generated data. 20 | /// 21 | /// This function uses the `rand` crate to generate a flat vector of random floating-point values. 22 | pub fn generate_test_data( 23 | num_samples: usize, // Number of samples 24 | num_features: usize, // Number of features (columns) per sample 25 | ) -> Vec { 26 | let mut rng = rand::rng(); 27 | 28 | // Define the range for random numbers (e.g., [0.0, 1.0)) 29 | let zero = F::from_f64(0.0).unwrap(); // 0.0 as a `F` type 30 | let one = F::from_f64(1.0).unwrap(); // 1.0 as a `F` type 31 | 32 | // Generate random data for the tensor (size = num_samples * num_features) 33 | let data: Vec = (0..num_samples * num_features) 34 | .map(|_| rng.random_range(zero..one)) // Generate random number from the range [0.0, 1.0) 35 | .collect(); 36 | 37 | data 38 | } 39 | 40 | /// Converts a vector of `f64` values into a `Tensor` for the specified backend. 41 | /// 42 | /// # Arguments 43 | /// * `data` - A vector of `f64` values representing the data to convert. 44 | /// * `num_samples` - The number of samples (rows). 45 | /// * `num_features` - The number of features (columns). 
46 | /// * `device` - The device to place the tensor on (e.g., CPU, GPU).
47 | ///
48 | /// # Returns
49 | /// A `Tensor<B, 2>` containing the data arranged as samples x features.
50 | ///
51 | /// This function uses the `TensorData` struct to create a tensor from the given data, then places it
52 | /// on the specified device (`CPU` or `GPU`).
53 | pub fn convert_vector_to_tensor<B: Backend, F>(
54 |     data: Vec<F>,
55 |     num_samples: usize,  // Number of samples
56 |     num_features: usize, // Number of features (columns) per sample
57 |     device: &Device<B>,  // Device to place the tensor (CPU, GPU)
58 | ) -> Tensor<B, 2>
59 | where
60 |     F: burn::tensor::Element,
61 | {
62 |     let tensor_data = TensorData::new(data, [num_samples, num_features]);
63 |     Tensor::<B, 2>::from_data(tensor_data, device)
64 | }
65 | 
66 | /// Prints the content of a tensor in a table format with index and tensor values.
67 | ///
68 | /// # Arguments
69 | /// * `data` - The tensor to print, with a generic backend and dimensionality `D`.
70 | ///
71 | /// This function prints the tensor's data in a table with each row corresponding to one sample.
72 | /// The tensor data is printed in a format that makes it easy to inspect.
73 | // pub fn print_tensor<B: Backend, const D: usize>(data: &Tensor<B, D>) {
74 | //     let dims = data.dims();
75 | //     let n_samples = match dims.len() > 0 {
76 | //         true => dims[0],
77 | //         false => 0,
78 | //     };
79 | 
80 | //     let mut table = Table::new();
81 | //     table.add_row(row!["Index", "Tensor"]);
82 | 
83 | //     for index in 0..n_samples {
84 | //         let row = data.clone().slice([index..index + 1]);
85 | //         let row = row.to_data().to_vec::<f32>().unwrap();
86 | //         let row = format!("{row:?}");
87 | //         table.add_row(row![index, format!("{:?}", row)]);
88 | //     }
89 | 
90 | //     if dims.len() == 0 {
91 | //         let row = data.to_data().to_vec::<f32>().unwrap();
92 | //         let row = row.get(0).unwrap();
93 | //         table.add_row(row![0, format!("{:?}", row)]);
94 | //     }
95 | 
96 | //     table.printstd();
97 | // }
98 | 
99 | /// Prints the content of a tensor with a title.
100 | ///
101 | /// # Arguments
102 | /// * `title` - A string title to print before displaying the tensor data.
103 | /// * `data` - The tensor to print.
104 | ///
105 | /// This function is similar to `print_tensor`, but with an added title to help distinguish different tensor prints.
106 | // pub fn print_tensor_with_title<B: Backend, const D: usize>(title: &str, data: &Tensor<B, D>) {
107 | //     println!("{title}");
108 | //     print_tensor(data);
109 | // }
110 | 
111 | /// Converts a 2D tensor into a `Vec<Vec<F>>` for easier inspection or manipulation.
112 | ///
113 | /// # Arguments
114 | /// * `data` - A 2D tensor (samples x features) to convert into a vector of vectors.
115 | ///
116 | /// # Returns
117 | /// A `Vec<Vec<F>>` where each inner `Vec<F>` represents a row (sample) of the tensor.
118 | ///
119 | /// This function extracts the data from a tensor and converts it into a `Vec<Vec<F>>` format. The conversion
120 | /// assumes that the tensor is 2D and that the tensor's internal precision is `f32`.
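///
/// # Example
///
/// A minimal sketch (not run as a doctest), assuming a `Wgpu` backend and an existing
/// `device`; it round-trips data through `convert_vector_to_tensor` above:
///
/// ```rust,ignore
/// // Build a 2x2 tensor from a flat vector, then convert it back into rows.
/// let tensor = convert_vector_to_tensor::<Wgpu, f32>(vec![1.0, 2.0, 3.0, 4.0], 2, 2, &device);
/// let rows: Vec<Vec<f32>> = convert_tensor_to_vector(tensor);
/// assert_eq!(rows, vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
/// ```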
121 | pub fn convert_tensor_to_vector<B: Backend, F>(data: Tensor<B, 2>) -> Vec<Vec<F>>
122 | where
123 |     F: burn::tensor::Element + Float,
124 | {
125 |     let n_components = data.dims()[1]; // usually 2 dimensional
126 | 
127 |     // Burn tensors are exported with `f32` precision when calling `to_data`
128 |     let data = data.to_data().to_vec::<f32>().unwrap();
129 | 
130 |     let data: Vec<Vec<f32>> = data
131 |         .chunks(n_components)
132 |         .map(|chunk| chunk.to_vec())
133 |         .collect();
134 | 
135 |     let data: Vec<Vec<F>> = data
136 |         .into_iter()
137 |         .map(|v| {
138 |             v.into_iter()
139 |                 .map(|v| {
140 |                     if v.is_nan() {
141 |                         // TODO: fix this on a different level in the loss function!
142 |                         // replace NaN values with 0
143 |                         F::from(0.0).unwrap()
144 |                     } else {
145 |                         F::from(v).unwrap()
146 |                     }
147 |                 })
148 |                 .collect()
149 |         })
150 |         .collect();
151 | 
152 |     data
153 | }
154 | 
155 | /// Formats a `Duration` into a human-readable string in hours, minutes, and seconds format.
156 | ///
157 | /// # Arguments
158 | /// * `duration` - The duration to format.
159 | ///
160 | /// # Returns
161 | /// A formatted string representing the duration in the format `HH:MM:SS`.
162 | ///
163 | /// This function is useful for displaying elapsed times or durations in a more readable format.
164 | pub fn format_duration(duration: std::time::Duration) -> String {
165 |     let secs = duration.as_secs();
166 |     let hours = secs / 3600;
167 |     let minutes = (secs % 3600) / 60;
168 |     let seconds = secs % 60;
169 |     format!("{:02}:{:02}:{:02}", hours, minutes, seconds)
170 | }
171 | 
172 | // a constant used to offset division by zero in the normalization function below
173 | const SMALL_STD_DEV: f64 = 1e-6;
174 | 
175 | /// Normalizes the given dataset by centering each feature (column) to have mean 0
176 | /// and standard deviation 1.
177 | ///
178 | /// # Arguments
179 | /// * `data` - A mutable slice representing the dataset, where each row is a sample,
180 | ///   and each column represents a feature. The data is assumed to be stored in
181 | ///   row-major order (i.e., `data[sample_idx * num_features + feature_idx]`).
182 | /// * `num_samples` - The number of samples (rows) in the dataset.
183 | /// * `num_features` - The number of features (columns) in the dataset.
184 | ///
185 | /// # Example
186 | /// ```ignore
187 | /// let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
188 | /// let num_samples = 2;
189 | /// let num_features = 3;
190 | /// normalize_data(&mut data, num_samples, num_features);
191 | /// ```
192 | /// The function will normalize each feature (column) across all samples (rows).
193 | ///
194 | /// # Note
195 | /// This function assumes that the dataset has at least one sample and one feature.
196 | /// The data is normalized in-place, meaning the original data is modified directly.
197 | pub fn normalize_data<F>(data: &mut [F], num_samples: usize, num_features: usize)
198 | where
199 |     F: Float + num::FromPrimitive + Send + Sync,
200 | {
201 |     // Iterate over the features; the per-feature reduction below runs in parallel
202 |     (0..num_features).into_iter().for_each(|feature_idx| {
203 |         // Calculate mean and standard deviation for the current feature
204 |         let (sum, sum_sq) = (0..num_samples)
205 |             .into_par_iter()
206 |             .fold(
207 |                 || (F::zero(), F::zero()), // Initial value for fold: (sum, sum_sq)
208 |                 |(acc_sum, acc_sum_sq), sample_idx| {
209 |                     let value = data[sample_idx * num_features + feature_idx];
210 |                     (acc_sum + value, acc_sum_sq + value * value)
211 |                 },
212 |             )
213 |             .reduce(
214 |                 || (F::zero(), F::zero()),
215 |                 |(sum1, sum_sq1), (sum2, sum_sq2)| (sum1 + sum2, sum_sq1 + sum_sq2),
216 |             );
217 | 
218 |         let mean = sum / F::from_usize(num_samples).unwrap();
219 |         let variance = (sum_sq / F::from_usize(num_samples).unwrap()) - (mean * mean);
220 |         let std_dev = variance.sqrt();
221 | 
222 |         // Avoid division by zero by adding SMALL_STD_DEV
223 |         let safe_std_dev = std_dev + F::from_f64(SMALL_STD_DEV).unwrap();
224 | 
225 |         // Normalize the feature in place (sequentially, since `data` is mutated)
226 |         (0..num_samples).into_iter().for_each(|sample_idx| {
227 |             let idx = sample_idx * num_features + feature_idx;
228 |             let value = data[idx];
229 |             let normalized_value = (value - mean) / safe_std_dev;
230 |             // Directly update the value in the `data` array
231 |             data[idx] = normalized_value;
232 |         });
233 |     });
234 | }
235 | 
236 | /// Normalizes a 1D tensor using min-max normalization.
237 | ///
238 | /// This function performs min-max normalization on a 1D tensor, scaling its values
239 | /// to a range between 0 and 1. If the minimum and maximum values of the tensor are
240 | /// equal (i.e., the tensor has no variance), the original tensor is returned unmodified.
241 | ///
242 | /// # Type Parameters
243 | ///
244 | /// - `B`: The autodiff backend type. This should implement the `AutodiffBackend` trait,
245 | ///   which provides support for automatic differentiation.
246 | ///
247 | /// # Arguments
248 | ///
249 | /// - `tensor`: A 1D tensor of type `Tensor<B, 1>`, which represents the input data to be normalized.
250 | ///
251 | /// # Returns
252 | ///
253 | /// - A new tensor of the same type and shape as the input tensor, with normalized values.
254 | ///   If the minimum and maximum values of the tensor are equal, the original tensor is returned unchanged.
255 | ///
256 | /// # Explanation
257 | ///
258 | /// The normalization is performed using the following formula:
259 | ///
260 | /// ```text
261 | /// normalized = (tensor - min) / (max - min + epsilon)
262 | /// ```
263 | ///
264 | /// Where:
265 | /// - `min`: The minimum value in the tensor.
266 | /// - `max`: The maximum value in the tensor.
267 | /// - `epsilon`: A small constant (`1e-6`) added to prevent division by zero, ensuring numerical stability.
268 | ///
269 | /// The function first checks if the minimum and maximum values are equal. If they are, it avoids division by zero
270 | /// and simply returns the original tensor. If they are not equal, it applies the min-max normalization formula.
271 | ///
272 | /// # Example
273 | ///
274 | /// ```rust,ignore
275 | /// let tensor = Tensor::<B, 1>::from_data(vec![1.0, 2.0, 3.0], &device);
276 | /// let normalized_tensor = normalize_tensor(tensor);
277 | /// ```
278 | pub fn normalize_tensor<B: crate::backend::AutodiffBackend>(tensor: Tensor<B, 1>) -> Tensor<B, 1> {
279 |     let device = tensor.device();
280 | 
281 |     // Normalize the result using min-max normalization:
282 |     let min_val = tensor.clone().min(); // Find the minimum value
283 |     let max_val = tensor.clone().max(); // Find the maximum value
284 | 
285 |     // this is to prevent dividing by zero
286 |     let offset_val = Tensor::<B, 1>::from_data(TensorData::new(vec![1e-6], [1]), &device);
287 | 
288 |     let are_equal = max_val
289 |         .clone()
290 |         .equal(min_val.clone())
291 |         .to_data()
292 |         .to_vec::<bool>()
293 |         .unwrap();
294 | 
295 |     let are_equal = are_equal.first().unwrap();
296 | 
297 |     // Avoid division by zero by ensuring max_val != min_val
298 |     let normalized = if !are_equal {
299 |         (tensor - min_val.clone()) / (max_val - min_val + offset_val)
300 |     } else {
301 |         tensor.clone() // If all values are the same, return the original
302 |     };
303 | 
304 |     // Return the normalized tensor
305 |     normalized
306 | }
307 | 
308 | pub fn print_primitive_tensor<B: Backend>(tensor: &FloatTensor<B>, rows: usize, cols: usize) {
309 |     let tensor: Tensor<B, 2> = Tensor::from_primitive(TensorPrimitive::Float(tensor.clone()));
310 |     print_tensor(&tensor, rows, cols)
311 | }
312 | 
313 | const MIN_SIZE: usize = 10;
314 | pub fn print_tensor<B: Backend>(tensor: &Tensor<B, 2>, rows: usize, cols: usize) {
315 |     let shape = tensor.shape().dims;
316 |     let n = shape[0]; // Number of rows
317 |     let d = shape[1]; // Number of columns
318 | 
319 |     let data = tensor.to_data().to_vec::<f32>().unwrap();
320 | 
321 |     let nn = n.min(rows.max(MIN_SIZE)); // Show at least MIN_SIZE rows, but never more than n
322 |     let dd = d.min(cols.max(MIN_SIZE)); // Show at least MIN_SIZE columns, but never more than d
323 | 
324 |     // Print the first few rows and columns in scientific notation
325 |     for i in 0..nn {
326 |         for j in 0..dd {
327 |             // Print each element in scientific notation, limited to 3 decimal places
328 |             print!("{:10.3e}", data[i * d + j]);
329 |         }
330 |         println!();
331 |     }
332 | }
333 | 
--------------------------------------------------------------------------------
/src/train/mod.rs:
--------------------------------------------------------------------------------
1 | mod config;
2 | mod get_distance_by_metric;
3 | 
4 | use crate::{
5 |     backend::AutodiffBackend,
6 |     chart::{self, plot_loss, ChartConfigBuilder},
7 |     format_duration,
8 |     model::UMAPModel,
9 |     normalize_data,
10 |     utils::convert_vector_to_tensor,
11 | };
12 | use burn::{
13 |     module::{AutodiffModule, Module},
14 |     nn::loss::MseLoss,
15 |     optim::{
16 |         decay::WeightDecayConfig, AdamConfig, GradientsAccumulator, GradientsParams, Optimizer,
17 |     },
18 |     record::{BinFileRecorder, FullPrecisionSettings},
19 |     tensor::{cast::ToElement, Device, Shape, Tensor},
20 | };
21 | pub use config::*;
22 | 
23 | use crossbeam_channel::Receiver;
24 | use get_distance_by_metric::*;
25 | use indicatif::{ProgressBar, ProgressStyle};
26 | use num::{Float, FromPrimitive};
27 | use std::time::Duration;
28 | use std::{thread, time::Instant};
29 | 
30 | /// Train the UMAP model over multiple epochs.
31 | ///
32 | /// This function trains the UMAP model by iterating over the dataset for the specified
33 | /// number of epochs. The model's parameters are updated using the Adam optimizer with
34 | /// the specified learning rate, weight decay, and beta parameters. The loss is computed
35 | /// at each epoch, and progress is displayed via a progress bar if verbose mode is enabled.
36 | ///
37 | /// # Arguments
38 | /// * `model`: The UMAP model to be trained.
39 | /// * `num_samples`: The number of samples in the training data.
40 | /// * `num_features`: The number of features per sample (columns in the data).
41 | /// * `data`: The training data as a flat `Vec<F>`, where each sample is represented as a
42 | ///   sequence of `num_features` values.
43 | /// * `config`: The `TrainingConfig` containing training hyperparameters and options.
44 | pub fn train<B: AutodiffBackend, F>(
45 |     name: &str,
46 |     mut model: UMAPModel<B>,
47 |     num_samples: usize,  // Number of samples in the dataset.
48 |     num_features: usize, // Number of features (columns) in each sample.
49 |     mut data: Vec<F>,    // Training data.
50 |     config: &TrainingConfig<B>, // Configuration parameters for training.
51 |     device: Device<B>,
52 |     exit_rx: Receiver<()>,
53 | ) -> (UMAPModel<B>, Vec<F>, F)
54 | where
55 |     F: Float + FromPrimitive + Send + Sync + burn::tensor::Element,
56 | {
57 |     if config.metric == Metric::EuclideanKNN && config.k_neighbors > num_samples {
58 |         panic!("When using the Euclidean KNN distance, `k_neighbors` must not exceed the number of samples!")
59 |     }
60 | 
61 |     // you can also store in memory using BytesRecorder
62 |     let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
63 |     let model_path = format!("./{name}.bin");
64 | 
65 |     let batch_size = config.batch_size;
66 | 
67 |     #[cfg(feature = "verbose")]
68 |     {
69 |         println!("config - {config:#?}");
70 |     }
71 | 
72 |     if batch_size == 1 {
73 |         panic!("You cannot train with a batch size of 1, because UMAP has no other samples in the batch to compare against!")
74 |     }
75 | 
76 |     // Normalize the input data (Z-score normalization).
77 |     normalize_data(&mut data, num_samples, num_features);
78 | 
79 |     // Step 1: Split the data into batches (Vec<Vec<F>>).
80 |     let mut batches: Vec<Vec<F>> = Vec::new();
81 |     for batch_start in (0..num_samples).step_by(config.batch_size) {
82 |         let batch_end = std::cmp::min(batch_start + config.batch_size, num_samples);
83 |         // Create a batch by extracting `batch_size * num_features` elements
84 |         let mut batch = Vec::new();
85 |         for i in batch_start..batch_end {
86 |             let start_idx = i * num_features;
87 |             let end_idx = start_idx + num_features;
88 |             batch.extend_from_slice(&data[start_idx..end_idx]);
89 |         }
90 |         batches.push(batch);
91 |     }
92 | 
93 |     // Step 2: Precompute the tensor representations and original-space distances for each batch.
94 |     let mut tensor_batches: Vec<Tensor<B, 2>> = Vec::new();
95 |     let mut global_distances_batches: Vec<Tensor<B, 2>> = Vec::new();
96 | 
97 |     // store the size of the Tensor after the distance has been calculated
98 |     let mut global_distance_size: Shape = Shape::from([0, 0]);
99 | 
100 |     for batch_data in &batches {
101 |         // Convert each batch to tensor format.
102 |         let tensor_batch =
103 |             convert_vector_to_tensor(batch_data.clone(), batch_size, num_features, &device);
104 | 
105 |         tensor_batches.push(tensor_batch);
106 | 
107 |         // Compute the distances between samples in the original (high-dimensional) space for this batch.
108 |         let global_tensor_data =
109 |             convert_vector_to_tensor(batch_data.clone(), batch_size, num_features, &device);
110 |         let global_distances =
111 |             get_distance_by_metric(global_tensor_data.clone(), config, Some("global".into()));
112 | 
113 |         global_distance_size = global_distances.shape();
114 |         global_distances_batches.push(global_distances);
115 |     }
116 | 
117 |     let global_distances_all = Tensor::cat(global_distances_batches, 0); // Concatenate along the 0-axis
118 |     let tensor_batches_all = Tensor::cat(tensor_batches, 0); // Concatenate along the 0-axis
119 | 
120 |     // Initialize the Adam optimizer with weight decay (L2 regularization).
121 |     let config_optimizer = AdamConfig::new()
122 |         .with_weight_decay(Some(WeightDecayConfig::new(config.penalty)))
123 |         .with_beta_1(config.beta1 as f32)
124 |         .with_beta_2(config.beta2 as f32);
125 |     let mut optim = config_optimizer.init();
126 | 
127 |     let mut accumulator = GradientsAccumulator::new();
128 | 
129 |     // Start the timer to track training duration.
130 |     let start_time = Instant::now();
131 | 
132 |     // Initialize a progress bar for verbose output, if enabled.
133 |     let pb = match config.verbose {
134 |         true => {
135 |             let pb = ProgressBar::new(config.epochs as u64);
136 |             pb.set_style(
137 |                 ProgressStyle::default_bar()
138 |                     .template("{bar:40} | {msg}")
139 |                     .unwrap()
140 |                     .progress_chars("=>-"),
141 |             );
142 |             Some(pb)
143 |         }
144 |         false => None,
145 |     };
146 | 
147 |     let mut epoch = 0;
148 |     let mut losses: Vec<F> = vec![];
149 |     let mut best_loss = F::infinity();
150 |     let mut epochs_without_improvement = 0;
151 | 
152 |     let mse_loss = MseLoss::new();
153 | 
154 |     'main: loop {
155 |         // println!("batch {}", format_duration(start_time.elapsed()));
156 |         for (batch_idx, _) in batches.iter().enumerate() {
157 |             if exit_rx.try_recv().is_ok() {
158 |                 break 'main;
159 |             }
160 | 
161 |             // Slice the corresponding part of the global_distances_all tensor for this batch
162 |             let start_idx = batch_idx * batch_size * num_features; // Calculate the starting index
163 |             let end_idx = (batch_idx + 1) * batch_size * num_features; // Calculate the ending index
164 |             let end_idx = end_idx.min(tensor_batches_all.shape().dims[0]); // Clip to the size of the tensor
165 | 
166 |             // Skip an incomplete trailing batch
167 |             if start_idx > end_idx {
168 |                 continue;
169 |             }
170 | 
171 |             let batch_start_idx = batch_idx * batch_size;
172 |             let batch_end_idx = (batch_idx + 1) * batch_size;
173 | 
174 |             let global_start_idx = batch_idx * global_distance_size.dims[0];
175 |             let global_end_idx = (batch_idx + 1) * global_distance_size.dims[0];
176 | 
177 |             let global_distances = global_distances_all
178 |                 .clone()
179 |                 .slice([global_start_idx..global_end_idx]); // Slice the tensor
180 | 
181 |             let tensor_batch = tensor_batches_all
182 |                 .clone()
183 |                 .slice([batch_start_idx..batch_end_idx, 0..num_features]); // Slice the tensor
184 | 
185 |             // Forward pass to get the local (low-dimensional) representation.
186 |             let local = model.forward(tensor_batch.clone());
187 | 
188 |             // Compute the loss for the batch.
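            //
            // The objective is a simplified UMAP-style structure-preservation loss:
            // pairwise distances are computed for this batch in the original
            // high-dimensional space ("global") and in the learned low-dimensional
            // embedding ("local"), and the mean squared error between the two
            // pushes the embedding to preserve those distances.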
189 | let local_distances = 190 | get_distance_by_metric(local.clone(), config, Some("local".into())); 191 | 192 | let loss = mse_loss.forward( 193 | global_distances.clone(), 194 | local_distances.clone(), 195 | burn::nn::loss::Reduction::Mean, 196 | ); 197 | 198 | let current_loss = F::from(loss.clone().into_scalar().to_f64()).unwrap(); 199 | 200 | losses.push(current_loss); 201 | 202 | #[cfg(feature = "verbose")] 203 | { 204 | // println!("global_distances {:?}", global_distances.to_data()); 205 | // println!("local_distances {:?}", local_distances.to_data()); 206 | // println!("loss {:?}", loss.to_data()); 207 | if current_loss.is_nan() { 208 | panic!("current loss is NaN") 209 | } 210 | } 211 | 212 | // TODO: if loss is NaN, do something else. FIXME 213 | let grads = loss.backward(); 214 | 215 | // Compute gradients and update the model parameters using the optimizer. 216 | let batch_grads = GradientsParams::from_grads(grads, &model); 217 | 218 | // Accumulate gradients. 219 | accumulator.accumulate(&model, batch_grads); 220 | } 221 | 222 | let current_loss = losses.last().unwrap().clone(); 223 | 224 | let grads = accumulator.grads(); // Pop the accumulated gradients. 225 | 226 | // Perform an optimization step to update model parameters. 227 | model = optim.step(config.learning_rate, model, grads); 228 | 229 | // Track elapsed time and update the progress bar. 230 | let elapsed = start_time.elapsed(); 231 | if let Some(pb) = &pb { 232 | pb.set_message(format!( 233 | "Elapsed: {} | Epoch: {} | Loss: {:.4} | Best loss: {:.4}", 234 | format_duration(elapsed), 235 | epoch, 236 | current_loss, 237 | best_loss, 238 | )); 239 | } 240 | 241 | if let Some(timeout) = config.timeout { 242 | if elapsed >= Duration::from_secs(timeout) { 243 | break; 244 | } 245 | } 246 | 247 | // Track improvements in loss for early stopping. 248 | if current_loss <= best_loss { 249 | best_loss = current_loss; 250 | epochs_without_improvement = 0; 251 | 252 | model 253 | .clone() 254 | .save_file(model_path.clone(), &recorder) 255 | .expect("Should be able to save the model"); 256 | } else { 257 | epochs_without_improvement += 1; 258 | } 259 | 260 | // Check for early stopping based on patience or number of epochs. 261 | if let Some(patience) = config.patience { 262 | if epochs_without_improvement >= patience && epoch >= config.epochs { 263 | break; // Stop training if patience is exceeded. 264 | } 265 | } else if epoch >= config.epochs { 266 | break; // Stop training if the specified number of epochs is reached. 267 | } 268 | 269 | // Stop early if we reach the desired loss. 
270 |         if let Some(min_desired_loss) = config.min_desired_loss {
271 |             if current_loss < F::from(min_desired_loss).unwrap() {
272 |                 break;
273 |             }
274 |         }
275 | 
276 |         let output_path = format!("losses_{name}.png");
277 | 
278 |         #[cfg(feature = "verbose")]
279 |         {
280 |             const STEP: usize = 100;
281 |             let name = name.to_string();
282 |             if epoch > 0 && epoch % STEP == 0 {
283 |                 let losses = losses.clone();
284 |                 let model = &model.valid();
285 |                 let tensor_data =
286 |                     convert_vector_to_tensor(data.clone(), num_samples, num_features, &device);
287 |                 // this is still slow
288 | 
289 |                 let embeddings_for_entire_dataset = model.forward(tensor_data);
290 |                 let output_path = output_path.clone();
291 | 
292 |                 thread::spawn(move || {
293 |                     let chart_config = ChartConfigBuilder::default()
294 |                         .caption(format!("{name}_{epoch}").as_str())
295 |                         .path(format!("{name}_{epoch}.png").as_str())
296 |                         .build();
297 | 
298 |                     // Visualize the 2D embedding (local representation) using a chart
299 |                     chart::chart_tensor(embeddings_for_entire_dataset, None, Some(chart_config));
300 |                     // Plot only the losses recorded after the first STEP epochs
301 |                     plot_loss(losses[STEP..].to_vec(), &output_path).unwrap();
302 |                 });
303 |             }
304 |         }
305 | 
306 |         epoch += 1;
307 | 
308 |         // Check for early stopping based on patience or number of epochs.
309 |         if let Some(patience) = config.patience {
310 |             if epochs_without_improvement >= patience && epoch >= config.epochs {
311 |                 break; // Stop training if patience is exceeded.
312 |             }
313 |         } else if epoch >= config.epochs {
314 |             break; // Stop training if the specified number of epochs is reached.
315 |         }
316 | 
317 |         // If verbose mode is enabled, refresh the loss curve at the end of each epoch.
318 |         if config.verbose {
319 |             plot_loss(losses.clone(), &output_path).unwrap();
320 |         }
321 | 
322 |         // NOTE: the progress bar is intentionally not finished here; finishing it
323 |         // inside the loop would freeze it after the first epoch. It is finished
324 |         // once, after the training loop below.
325 | 
326 |     }
327 | 
328 |     // If verbose mode is enabled, plot the loss curve after training.
329 |     #[cfg(feature = "verbose")]
330 |     {
331 |         let name = format!("losses_{name}.png");
332 |         plot_loss(losses.clone(), name.as_str()).unwrap();
333 |     }
334 | 
335 |     // Finish the progress bar if it was used.
336 |     if let Some(pb) = pb {
337 |         pb.finish();
338 |     }
339 | 
340 |     // let record = BinFileRecorder::<FullPrecisionSettings>::default();
341 |     model = model
342 |         .load_file(model_path, &recorder, &device)
343 |         .expect("Load model from the best weights file");
344 | 
345 |     // Return best trained model
346 |     (model, losses, best_loss)
347 | }
348 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # fast-umap
2 | 
3 | ⚠️Not ready for production⚠️
4 | 
5 | UMAP (Uniform Manifold Approximation and Projection) in Rust
6 | 
7 | This repository contains a Rust implementation of **UMAP** (Uniform Manifold Approximation and Projection), a dimensionality reduction algorithm that preserves both the local and global structure of data. UMAP is widely used for visualizing high-dimensional data in 2D or 3D space.
8 | 
9 | This implementation leverages the [burn](https://github.com/tracel-ai/burn) machine learning framework, which provides automatic differentiation and GPU support, allowing you to train and apply UMAP models on high-dimensional datasets efficiently.
10 | 
11 | See [documentation](https://docs.rs/crate/fast-umap/latest) for more.
12 | 
13 | ## Features
14 | 
15 | - **Dimensionality Reduction**: Reduces high-dimensional data to a lower-dimensional space (e.g., 2D or 3D) for visualization or further analysis.
16 | - **Customizable UMAP Model**: The model architecture can be configured with different numbers of input features, hidden layer sizes, and output dimensions.
17 | - **GPU Support**: Powered by the `burn` framework with support for training on CPU and GPU using the `wgpu` backend.
18 | - **Flexible Data Handling**: Functions for converting between vectors and tensors, generating synthetic data, and more.
19 | 
20 | 
21 | ## Roadmap
22 | 
23 | - [x] Add MNIST dataset example; add intermediary plots.
24 | - [x] Move charting into a feature.
25 | - [x] Add labels to the plots.
26 | - [x] Implement batches, **accumulated gradient**
27 | - [ ] Precompute initial local fitting with PCA
28 | - [x] Implement distances in CubeCL kernels
29 | - [x] Create testbench to compare different hyperparameters (`patience` vs `n_features` vs `epochs`, etc.)
30 | 
31 | ## Installation
32 | 
33 | ```shell
34 | cargo add fast-umap
35 | ```
36 | 
37 | ## Example Usage
38 | 
39 | ### 1. Fitting a UMAP Model
40 | 
41 | You can fit a UMAP model to your data using the `UMAP::fit` function. Here's how to do it:
42 | 
43 | ```rust
44 | use burn::backend::Autodiff;
45 | use burn::backend::wgpu::{Wgpu, WgpuDevice};
46 | use fast_umap::prelude::*;
47 | 
48 | fn main() {
49 |     // Example data (list of samples, each with a list of features)
50 |     let data: Vec<Vec<f64>> = vec![
51 |         vec![1.0, 2.0, 3.0],
52 |         vec![4.0, 5.0, 6.0],
53 |         vec![7.0, 8.0, 9.0],
54 |         // Add more samples...
55 |     ];
56 | 
57 |     // Fit the UMAP model
58 |     let model = umap(data.clone());
59 | 
60 |     // You can now use the model to transform new data
61 |     let transformed = model.transform(data);
62 | 
63 |     // Print the transformed data (low-dimensional representation)
64 |     for sample in transformed {
65 |         println!("{:?}", sample);
66 |     }
67 | }
68 | ```
69 | 
70 | ### 2. Transforming Data
71 | 
72 | Once the UMAP model is trained, you can transform new high-dimensional data into its low-dimensional representation:
73 | 
74 | ```rust
75 | let transformed_data = model.transform(new_data);
76 | ```
77 | 
78 | This function takes `new_data` in the form of `Vec<Vec<f64>>` and returns its 2D or 3D representation, suitable for visualization.
79 | 
80 | ### 3. Generating Test Data
81 | 
82 | You can generate synthetic test data to experiment with the UMAP model using the `generate_test_data` function:
83 | 
84 | ```rust
85 | let data = generate_test_data(100, 50); // 100 samples, each with 50 features
86 | ```
87 | 
88 | ### 4. Visualizing Data
89 | 
90 | After transforming the data to a 2D or 3D space, you can use external charting libraries (e.g., `plotters` in Rust or `matplotlib` in Python) to visualize the results.
91 | 
92 | ## Model Configuration
93 | 
94 | The UMAP model configuration is customizable through the `UMAPModelConfigBuilder`. You can set the following parameters:
95 | 
96 | - **input_size**: Number of input features (i.e., the dimensionality of the data).
97 | - **hidden_size**: The number of neurons in the hidden layers.
98 | - **output_size**: The target number of dimensions (typically 2 or 3 for visualization).
99 | 100 | Here's how to configure and build the model: 101 | 102 | ```rust 103 | let model_config = UMAPModelConfigBuilder::default() 104 | .input_size(50) // Input features: 50 dimensions 105 | .hidden_size(100) // Hidden layer size: 100 neurons 106 | .output_size(2) // Output size: 2 (for 2D visualization) 107 | .build() 108 | .unwrap(); 109 | ``` 110 | 111 | 112 | ## Training the UMAP Model 113 | 114 | To train the UMAP model on your dataset, you can use the `fit` method of the `UMAP` struct. The training process involves optimizing the model's weights to effectively reduce the dimensionality of the data while preserving its underlying structure. 115 | 116 | The training process is governed by several key configuration parameters, which control aspects of optimization, regularization, and performance. These parameters can be customized to suit your specific training needs: 117 | 118 | ### Configuration Parameters: 119 | 120 | - **`epochs`**: 121 | Specifies the total number of epochs (iterations over the entire dataset) to run during training. Increasing the number of epochs allows the model more time to converge but may also increase the risk of overfitting. 122 | 123 | - **`batch_size`**: 124 | Defines the number of samples to process in each training batch. Larger batch sizes typically speed up training but require more memory. A smaller batch size may help the model generalize better but might take longer to train. 125 | 126 | - **`learning_rate`**: 127 | The learning rate determines the size of the steps the optimizer will take during parameter updates. A higher learning rate can speed up training but may cause instability, while a smaller learning rate ensures more gradual updates but might result in slower convergence. 128 | 129 | - **`beta1`** and **`beta2`**: 130 | These are the hyperparameters for the Adam optimizer: 131 | - **`beta1`** controls the decay rate of the first moment estimate (the moving average of the gradients). 132 | - **`beta2`** controls the decay rate of the second moment estimate (the moving average of the squared gradients). 133 | Both parameters influence how the optimizer adapts the learning rate over time, with typical values being `beta1 = 0.9` and `beta2 = 0.999`. 134 | 135 | - **`penalty`**: 136 | L2 regularization (weight decay) applied during training to prevent overfitting by penalizing large model weights. This can help improve the generalization ability of the model. 137 | 138 | - **`metric`**: 139 | The distance metric used during training, such as `"euclidean"`, `"manhattan"`, or `"cosine"`. This metric is used to measure the similarity between points and plays a crucial role in the dimensionality reduction process. 140 | 141 | - **`verbose`**: 142 | A flag indicating whether detailed progress information (e.g., loss values, training status) should be printed during training. This can be useful for monitoring the model's progress. 143 | 144 | - **`patience`**: 145 | The number of epochs to wait for improvement in the loss before triggering early stopping. If the model's performance doesn't improve after `patience` epochs, training will be stopped early to avoid unnecessary computations and prevent overfitting. If set to `None`, early stopping is disabled. 146 | 147 | - **`loss_reduction`**: 148 | Specifies how to reduce the loss during training (e.g., "mean" or "sum"). This parameter determines how the loss is aggregated across the training batch. 
149 | 
150 | - **`k_neighbors`**: 
151 | The number of nearest neighbors to consider in the UMAP algorithm. This parameter is crucial for determining the local structure of the data and can impact the quality of the dimensionality reduction.
152 | 
153 | - **`min_desired_loss`**: 
154 | An optional parameter that specifies the minimum acceptable loss. If the model's loss reaches this threshold, training will stop. This can be useful if you want to set a target performance level before halting training.
155 | 
156 | - **`timeout`**: 
157 | The maximum amount of time (in seconds) to allow for training. If `None`, there is no time limit. This can be useful for controlling long-running training sessions in resource-constrained environments.
158 | 
159 | ### Summary of Key Parameters:
160 | - **`epochs`**: Number of epochs to train.
161 | - **`batch_size`**: Number of samples per batch.
162 | - **`learning_rate`**: Step size for gradient updates.
163 | - **`beta1`**, **`beta2`**: Adam optimizer parameters.
164 | - **`penalty`**: L2 regularization strength.
165 | - **`verbose`**: Whether to display detailed training progress.
166 | - **`patience`**: Early stopping criterion.
167 | - **`loss_reduction`**: Method to reduce loss.
168 | - **`k_neighbors`**: Number of nearest neighbors in the UMAP algorithm.
169 | - **`min_desired_loss`**: Optional target loss for stopping early.
170 | - **`timeout`**: Maximum training time.
171 | 
172 | By carefully tuning these parameters, you can optimize the UMAP model to better capture the underlying structure of your high-dimensional data while balancing training time, memory usage, and generalization performance.
173 | 
174 | 
175 | For example:
176 | 
177 | ```rust
178 | let model = UMAP::<Autodiff<Wgpu>>::fit(data, WgpuDevice::default());
179 | ```
180 | 
181 | ## Examples
182 | 
183 | ### Simple
184 | 
185 | ```shell
186 | cargo run --example simple
187 | ```
188 | 
189 | Sample code:
190 | 
191 | ```rust
192 | use fast_umap::prelude::*;
193 | use rand::Rng;
194 | 
195 | fn main() {
196 |     // Number of samples in the dataset
197 |     let num_samples = 100;
198 | 
199 |     // Number of features (dimensions) for each sample
200 |     let num_features = 3;
201 | 
202 |     // Create a random number generator for generating random values
203 |     let mut rng = rand::rng();
204 | 
205 |     // Generate a dataset of random values with `num_samples` rows and `num_features` columns
206 |     let data: Vec<Vec<f64>> = (0..num_samples * num_features)
207 |         .map(|_| rng.random::<f64>()) // Random number generation for each feature
208 |         .collect::<Vec<f64>>() // Collect all random values into a vector
209 |         .chunks_exact(num_features) // Chunk the vector into rows of length `num_features`
210 |         .map(|chunk| chunk.to_vec()) // Convert each chunk into a Vec<f64>
211 |         .collect(); // Collect the rows into a vector of vectors
212 | 
213 |     // Fit the UMAP model to the data and reduce the data to a lower-dimensional space (default: 2D)
214 |     let umap = umap(data.clone());
215 | 
216 |     // Transform the data using the trained UMAP model to reduce its dimensions
217 |     let reduced_dimensions_vector = umap.transform(data.clone());
218 | 
219 |     // Visualize the reduced dimensions as a vector
220 |     chart_vector(reduced_dimensions_vector, None, None);
221 | 
222 |     // Optionally, you can also visualize the reduced dimensions as a tensor
223 |     // let reduced_dimensions_tensor = umap.transform_to_tensor(data.clone());
224 |     // print_tensor_with_title("reduced_dimensions", &reduced_dimensions_tensor);
225 |     // chart_tensor(reduced_dimensions_tensor, None, None);
226 | }
227 | ```
228 | 
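By default, `chart_vector` writes the scatter plot using a default `ChartConfig`; the third argument can override it. Below is a minimal sketch, assuming `ChartConfigBuilder` (the builder used internally in `src/train/mod.rs`) is exported through the crate's prelude — the caption and output path here are placeholder values:

```rust
use fast_umap::prelude::*;

// Hypothetical customization: set the chart caption and output file.
// `caption` and `path` mirror the builder calls in src/train/mod.rs.
let chart_config = ChartConfigBuilder::default()
    .caption("my_embedding")
    .path("my_embedding.png")
    .build();

// Pass labels (None here) and the custom config to chart_vector.
chart_vector(reduced_dimensions_vector, None, Some(chart_config));
```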
229 | The simple example generates this plot:
230 | 
231 | ![plot](https://github.com/eugenehp/fast-umap/raw/refs/heads/master/plot.png)
232 | 
233 | ### Advanced
234 | 
235 | ```shell
236 | cargo run --example advanced
237 | ```
238 | 
239 | Sample code:
240 | 
241 | ```rust
242 | use burn::{backend::*, module::*, prelude::*};
243 | use fast_umap::{chart, model::*, prelude::*, train::train, utils::*};
244 | 
245 | fn main() {
246 |     // Define a custom backend type using Wgpu with 32-bit floating point precision and 32-bit integer type
247 |     type MyBackend = Wgpu<f32, i32>;
248 | 
249 |     // Define the AutodiffBackend based on the custom MyBackend type
250 |     type MyAutodiffBackend = Autodiff<MyBackend>;
251 | 
252 |     // Initialize the GPU device for computation
253 |     let device = burn::backend::wgpu::WgpuDevice::default();
254 | 
255 |     // Set training hyperparameters
256 |     let batch_size = 100; // Number of samples per batch during training (must be greater than 1)
257 |     let num_samples = 1000; // Total number of samples in the dataset
258 |     let num_features = 100; // Number of features (dimensions) for each sample
259 |     let k_neighbors = 10; // Number of nearest neighbors for the UMAP algorithm
260 |     let output_size = 2; // Number of output dimensions (e.g., 2D for embeddings)
261 |     let hidden_sizes = vec![100, 100, 100]; // Sizes of the hidden layers in the neural network
262 |     let learning_rate = 0.001; // Learning rate for optimization
263 |     let beta1 = 0.9; // Beta1 parameter for the Adam optimizer
264 |     let beta2 = 0.999; // Beta2 parameter for the Adam optimizer
265 |     let epochs = 400; // Number of training epochs
266 |     let seed = 9999; // Random seed to ensure reproducibility
267 |     let verbose = true; // Whether to enable the progress bar during training
268 |     let patience = 10; // Number of epochs without improvement before early stopping
269 |     let min_desired_loss = 0.001; // Minimum loss threshold for early stopping
270 |     let timeout = 60;
271 |     let (_exit_tx, exit_rx) = crossbeam_channel::unbounded(); // Channel that can signal training to stop early
272 | 
273 |     // let metric = Metric::EuclideanKNN; // Alternative metric for neighbors search
274 |     let metric = "euclidean_knn"; // Distance metric used for the nearest neighbor search
275 | 
276 |     // Seed the random number generator to ensure reproducibility
277 |     MyBackend::seed(seed);
278 | 
279 |     // Generate random test data for training
280 |     let train_data = generate_test_data(num_samples, num_features);
281 | 
282 |     // Configure the UMAP model with the specified input size, hidden layer sizes, and output size
283 |     let model_config = UMAPModelConfigBuilder::default()
284 |         .input_size(num_features)
285 |         .hidden_sizes(hidden_sizes)
286 |         .output_size(output_size)
287 |         .build()
288 |         .unwrap();
289 | 
290 |     // Initialize the UMAP model with the defined configuration and the selected device
291 |     let model: UMAPModel<MyAutodiffBackend> = UMAPModel::new(&model_config, &device);
292 | 
293 |     // Set up the training configuration with the specified hyperparameters
294 |     let config = TrainingConfig::<MyAutodiffBackend>::builder()
295 |         .with_epochs(epochs) // Set the number of epochs for training
296 |         .with_batch_size(batch_size) // Set the batch size for training
297 |         .with_learning_rate(learning_rate) // Set the learning rate for the optimizer
298 |         .with_device(device.clone()) // Specify the device (GPU) for computation
299 |         .with_beta1(beta1) // Set the beta1 parameter for the Adam optimizer
300 |         .with_beta2(beta2) // Set the beta2 parameter for the Adam optimizer
301 |         .with_verbose(verbose) // Enable or disable the progress bar
302 |         .with_patience(patience) // Set the patience for early stopping
303 |         .with_metric(metric.into()) // Set the metric for nearest neighbors (e.g., Euclidean)
304 |         .with_k_neighbors(k_neighbors) // Set the number of neighbors to consider for UMAP
305 |         .with_min_desired_loss(min_desired_loss) // Set the minimum desired loss for early stopping
306 |         .with_timeout(timeout) // Set the timeout in seconds
307 |         .build()
308 |         .expect("Failed to build TrainingConfig");
309 | 
310 |     // Start training the UMAP model with the specified training data and configuration
311 |     let (model, _losses, _best_loss) = train::<MyAutodiffBackend, f32>(
312 |         "advanced", // Name used for the saved best-weights file and intermediary charts
313 |         model, // The model to train
314 |         num_samples, // Total number of training samples
315 |         num_features, // Number of features per sample
316 |         train_data.clone(), // The training data
317 |         &config, // The training configuration
318 |         device.clone(), // The device (GPU) to train on
319 |         exit_rx, // Receiver that can signal training to stop early
320 |     );
321 | 
322 |     // Switch the trained model into validation (inference) mode
323 |     let model = model.valid();
324 | 
325 |     // Convert the training data into a tensor for model input
326 |     let global = convert_vector_to_tensor(train_data, num_samples, num_features, &device);
327 | 
328 |     // Perform a forward pass through the model to obtain the low-dimensional (local) representation
329 |     let local = model.forward(global.clone());
330 | 
331 |     // Optionally, print the global and local tensors for inspection (currently commented out)
332 |     // if verbose {
333 |     //     print_tensor_with_title("global", &global);
334 |     //     print_tensor_with_title("local", &local);
335 |     // }
336 | 
337 |     // Visualize the 2D embedding (local representation) using a chart
338 |     chart::chart_tensor(local, None, None);
339 | }
340 | ```
341 | 
342 | It also generates a 2D plot and a loss chart:
343 | 
344 | ![loss](https://github.com/eugenehp/fast-umap/raw/refs/heads/master/losses.png)
345 | 
346 | ## License
347 | 
348 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
349 | 
350 | ## Copyright
351 | 
352 | 2024-2025, Eugene Hauptmann
353 | 
354 | ## Thank you
355 | 
356 | Inspired by original UMAP [paper](https://arxiv.org/abs/1802.03426)
357 | 
--------------------------------------------------------------------------------