├── examples ├── interactive_demo │ ├── Screenshot.png │ ├── Cargo.toml │ ├── run_demo.sh │ ├── README.md │ └── src │ │ └── main.rs └── demo │ ├── Cargo.toml │ ├── run_demo.sh │ ├── README.md │ └── src │ └── main.rs ├── src ├── geometry │ ├── mod.rs │ ├── axis.rs │ ├── point.rs │ ├── helper.rs │ └── distance.rs ├── error.rs ├── map │ ├── nn.rs │ ├── table.rs │ └── mod.rs ├── macros.rs ├── test_utilities.rs ├── models │ ├── mod.rs │ ├── trainer.rs │ ├── stats.rs │ └── linear.rs ├── lib.rs └── hasher │ └── mod.rs ├── .github └── workflows │ └── rust.yml ├── .gitignore ├── Cargo.toml ├── LICENSE-MIT ├── README.md ├── benches └── benchmarks.rs └── LICENSE-APACHE /examples/interactive_demo/Screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackson211/LearnedSPatialHashMap/HEAD/examples/interactive_demo/Screenshot.png -------------------------------------------------------------------------------- /src/geometry/mod.rs: -------------------------------------------------------------------------------- 1 | mod axis; 2 | pub mod distance; 3 | pub mod helper; 4 | mod point; 5 | 6 | pub use axis::*; 7 | pub use helper::*; 8 | pub use point::*; 9 | -------------------------------------------------------------------------------- /src/geometry/axis.rs: -------------------------------------------------------------------------------- 1 | /// Axis defines axis direction in Cartesian coordinate system 2 | #[derive(Debug, Clone)] 3 | pub enum Axis { 4 | X, 5 | Y, 6 | Z, 7 | } 8 | -------------------------------------------------------------------------------- /examples/interactive_demo/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "lsph_interactive_demo" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | lsph = { path = "../.." } 8 | eframe = "0.24" 9 | egui = "0.24" 10 | rand = "0.9" 11 | serde = { version = "1.0", features = ["derive"] } 12 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | /// The kinds of errors that can occur when calculating a linear regression. 2 | #[derive(Copy, Clone, Debug, PartialEq)] 3 | pub enum Error { 4 | /// The slope is too steep to represent, approaching infinity. 5 | SteepSlope, 6 | 7 | /// Different input lenses 8 | DiffLen, 9 | 10 | /// Input was empty 11 | EmptyVal, 12 | } 13 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v3 19 | - name: Build 20 | run: cargo build --verbose 21 | - name: Run tests 22 | run: cargo test --verbose 23 | -------------------------------------------------------------------------------- /examples/demo/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "lsph_demo" 3 | version = "0.1.0" 4 | edition = "2021" 5 | description = "LSPH demonstration with real-world Melbourne geographic data" 6 | authors = ["LSPH Contributors"] 7 | 8 | [dependencies] 9 | lsph = { path = "../.." } 10 | csv = "1.3" 11 | serde = { version = "1.0", features = ["derive"] } 12 | rand = "0.9" 13 | colored = "2.1" 14 | clap = { version = "4.4", features = ["derive"] } 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Rust build artifacts 2 | /target/ 3 | **/target/ 4 | Cargo.lock 5 | 6 | # IDE and editor files 7 | .vscode/ 8 | .idea/ 9 | *.swp 10 | *.swo 11 | *~ 12 | 13 | # OS generated files 14 | .DS_Store 15 | .DS_Store? 16 | ._* 17 | .Spotlight-V100 18 | .Trashes 19 | ehthumbs.db 20 | Thumbs.db 21 | 22 | # Temporary files 23 | *.tmp 24 | *.temp 25 | *.log 26 | 27 | # Backup files 28 | *.bak 29 | *.backup 30 | 31 | # Coverage reports 32 | tarpaulin-report.html 33 | lcov.info 34 | 35 | # Benchmark results 36 | /criterion/ 37 | 38 | # Documentation build 39 | /book/ 40 | /mdbook/ 41 | 42 | # Local configuration 43 | .env 44 | .env.local -------------------------------------------------------------------------------- /src/geometry/point.rs: -------------------------------------------------------------------------------- 1 | use num_traits::float::Float; 2 | 3 | /// Point struct contains id, x and y 4 | #[derive(Debug, Clone, Copy, Eq, PartialEq)] 5 | pub struct Point { 6 | pub(crate) x: T, 7 | pub(crate) y: T, 8 | } 9 | 10 | impl Default for Point 11 | where 12 | T: Float, 13 | { 14 | fn default() -> Self { 15 | Point { 16 | x: T::zero(), 17 | y: T::zero(), 18 | } 19 | } 20 | } 21 | 22 | impl Point 23 | where 24 | T: Float, 25 | { 26 | pub fn new(x: T, y: T) -> Self { 27 | Point { x, y } 28 | } 29 | 30 | pub fn x(&self) -> T { 31 | self.x 32 | } 33 | 34 | pub fn y(&self) -> T { 35 | self.y 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "lsph" 3 | version = "0.1.9" 4 | authors = ["Haozhan Shi "] 5 | repository = "https://github.com/jackson211/lsph" 6 | homepage = "https://github.com/jackson211/lsph" 7 | documentation = "https://docs.rs/lsph" 8 | edition = "2021" 9 | description = "Learned Spatial HashMap" 10 | license = "MIT OR Apache-2.0" 11 | readme = "README.md" 12 | keywords = ["hash", "hashmap", "spatial", "index"] 13 | categories = ["data-structures"] 14 | 15 | [dependencies] 16 | num-traits = "0.2.19" 17 | smallvec = "1.15.1" 18 | 19 | [dev-dependencies] 20 | rand = "0.9.2" 21 | criterion = { version = "0.7.0", features = ["html_reports"] } 22 | 23 | [[bench]] 24 | name = "benchmarks" 25 | harness = false 26 | -------------------------------------------------------------------------------- /src/map/nn.rs: -------------------------------------------------------------------------------- 1 | use crate::geometry::Point; 2 | use num_traits::float::Float; 3 | use std::cmp::Ordering; 4 | 5 | /// State for store nearest neighbors distances and points in min_heap 6 | #[derive(Copy, Clone, PartialEq)] 7 | pub struct NearestNeighborState 8 | where 9 | F: Float, 10 | { 11 | pub distance: F, 12 | pub point: Point, 13 | } 14 | 15 | impl Eq for NearestNeighborState {} 16 | 17 | impl PartialOrd for NearestNeighborState 18 | where 19 | F: Float, 20 | { 21 | fn partial_cmp(&self, other: &Self) -> Option { 22 | Some(self.cmp(other)) 23 | } 24 | } 25 | 26 | impl Ord for NearestNeighborState 27 | where 28 | F: Float, 29 | { 30 | fn cmp(&self, other: &Self) -> Ordering { 31 | // We flip the ordering on distance, so the queue becomes a min-heap 32 | other.distance.partial_cmp(&self.distance).unwrap_or(Ordering::Equal) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/macros.rs: -------------------------------------------------------------------------------- 1 | extern crate num_traits; 2 | // Hepler macro for assert Float values within a delta range, panic if the 3 | // difference between two numbers is exceeds the given threshold. 4 | #[macro_export] 5 | macro_rules! assert_delta { 6 | ($x:expr, $y:expr, $delta:expr) => { 7 | assert!( 8 | ($x - $y).abs() <= num_traits::NumCast::from($delta).unwrap(), 9 | "{} is not within {} of {}", 10 | stringify!($x), 11 | $delta, 12 | stringify!($y) 13 | ); 14 | }; 15 | } 16 | 17 | #[macro_export] 18 | macro_rules! assert_eq_len { 19 | ($a:expr, $b:expr) => { 20 | if $a.len() != $b.len() { 21 | return Err(Error::DiffLen); 22 | } 23 | }; 24 | } 25 | 26 | #[macro_export] 27 | macro_rules! assert_empty { 28 | ($a:expr) => { 29 | if $a.is_empty() { 30 | return Err(Error::EmptyVal); 31 | } 32 | }; 33 | } 34 | -------------------------------------------------------------------------------- /src/test_utilities.rs: -------------------------------------------------------------------------------- 1 | use crate::geometry::*; 2 | use rand::{rngs::SmallRng, Rng, SeedableRng}; 3 | 4 | pub type Seed = [u8; 32]; 5 | 6 | pub const SEED_1: &Seed = b"wPYxAkIiHcEmSBAxQFoXFrpYToCe1B71"; 7 | pub const SEED_2: &Seed = b"4KbTVjPT4DXSwWAsQM5dkWWywPKZRfCX"; 8 | 9 | pub fn create_random_points(num_points: usize, seed: &[u8; 32]) -> Vec<(f64, f64)> { 10 | let mut result = Vec::with_capacity(num_points); 11 | let mut rng = SmallRng::from_seed(*seed); 12 | for _ in 0..num_points { 13 | result.push((rng.random(), rng.random())); 14 | } 15 | result 16 | } 17 | 18 | pub fn create_random_point_type_points(num_points: usize, seed: &[u8; 32]) -> Vec> { 19 | let result = create_random_points(num_points, seed); 20 | 21 | // result.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); 22 | result 23 | .into_iter() 24 | .map(|(x, y)| Point { x, y }) 25 | .collect::>() 26 | } 27 | -------------------------------------------------------------------------------- /src/geometry/helper.rs: -------------------------------------------------------------------------------- 1 | use crate::geometry::Point; 2 | use num_traits::float::Float; 3 | 4 | /// Extract all the x values from a Vec> 5 | pub fn extract_x(ps: &[Point]) -> Vec { 6 | ps.iter().map(|p| p.x).collect() 7 | } 8 | 9 | /// Extract all the y values from a Vec> 10 | pub fn extract_y(ps: &[Point]) -> Vec { 11 | ps.iter().map(|p| p.y).collect() 12 | } 13 | 14 | /// Sort a Vec> based on the x values 15 | pub fn sort_by_x(ps: &mut [Point]) { 16 | ps.sort_by(|a, b| a.x.partial_cmp(&b.x).unwrap()); 17 | } 18 | 19 | /// Sort a Vec> based on the y values 20 | pub fn sort_by_y(ps: &mut [Point]) { 21 | ps.sort_by(|a, b| a.y.partial_cmp(&b.y).unwrap()); 22 | } 23 | 24 | /// Convert a Vec of [F; 2] to a Vec> 25 | pub fn convert_to_points(ps: &[[F; 2]]) -> Option>> { 26 | Some(ps.iter().map(|p| Point::new(p[0], p[1])).collect()) 27 | } 28 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 The LSPH developers 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/interactive_demo/run_demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # LSPH Interactive Demo Runner 4 | # This script builds and runs the interactive demo 5 | 6 | echo "🗺️ Building LSPH Interactive Demo..." 7 | echo "======================================" 8 | 9 | # Build the demo 10 | if cargo build --release; then 11 | echo "✅ Build successful!" 12 | echo "" 13 | echo "🚀 Starting LSPH Interactive Demo..." 14 | echo "=====================================" 15 | echo "" 16 | echo "📋 Demo Features:" 17 | echo " • Manual Point Addition - Click to add points" 18 | echo " • Random Generation - Auto-generate test data" 19 | echo " • Nearest Neighbor Search - Find closest points" 20 | echo " • Range Query - Search within radius" 21 | echo "" 22 | echo "💡 Tips:" 23 | echo " • Start with 'Random Generation' mode" 24 | echo " • Generate 100 points for best experience" 25 | echo " • Try different search modes by clicking on canvas" 26 | echo " • Enable grid for better spatial reference" 27 | echo "" 28 | echo "Press Ctrl+C to exit the demo" 29 | echo "" 30 | 31 | # Run the demo 32 | cargo run --release 33 | else 34 | echo "❌ Build failed! Please check the error messages above." 35 | exit 1 36 | fi -------------------------------------------------------------------------------- /src/models/mod.rs: -------------------------------------------------------------------------------- 1 | mod linear; 2 | mod stats; 3 | mod trainer; 4 | 5 | pub use linear::*; 6 | pub use stats::*; 7 | pub use trainer::*; 8 | 9 | use crate::error::Error; 10 | use core::fmt::Debug; 11 | use num_traits::float::Float; 12 | 13 | /// Model representation, provides common functionalities for model training 14 | pub trait Model { 15 | /// Associated type for float number representation 16 | type F; 17 | /// Prints the name of the model 18 | fn name(&self) -> String; 19 | /// Fit two slices of training data into the model 20 | fn fit(&mut self, xs: &[Self::F], ys: &[Self::F]) -> Result<(), Error>; 21 | /// Fit one slice of training data in tuple format into the model 22 | fn fit_tuple(&mut self, xys: &[(Self::F, Self::F)]) -> Result<(), Error>; 23 | /// Takes one value and returns the predictions of the model 24 | fn predict(&self, x: Self::F) -> Self::F; 25 | /// Takes slice of value and returns the batch predictions from the model 26 | fn batch_predict(&self, xs: &[Self::F]) -> Vec; 27 | /// Evaluate the predictions results from a pair of test sets 28 | fn evaluate(&self, x_test: &[Self::F], y_test: &[Self::F]) -> Self::F; 29 | /// Unpredict provides the ability of reversing the predict operation 30 | /// For a given target value, return the estimate input value 31 | fn unpredict(&self, y: Self::F) -> Self::F; 32 | } 33 | 34 | impl Debug for (dyn Model + 'static) 35 | where 36 | F: Float, 37 | { 38 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 39 | write!(f, "Model {{{}}}", self.name()) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /examples/demo/run_demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # LSPH Geographic Data Demo Runner 4 | # This script builds and runs the geographic data demonstration 5 | 6 | echo "🗺️ LSPH Geographic Data Demo" 7 | echo "==============================" 8 | echo "" 9 | 10 | # Check if melbourne.csv exists 11 | if [ ! -f "melbourne.csv" ]; then 12 | echo "⚠️ Warning: melbourne.csv not found in current directory" 13 | echo " The demo will attempt to load the file, but may fail." 14 | echo " Make sure the CSV file is in the demo directory." 15 | echo "" 16 | fi 17 | 18 | # Build the demo 19 | echo "🔨 Building demo application..." 20 | if cargo build --release; then 21 | echo "✅ Build successful!" 22 | echo "" 23 | else 24 | echo "❌ Build failed! Please check the error messages above." 25 | exit 1 26 | fi 27 | 28 | # Display usage options 29 | echo "🚀 Starting LSPH Demo" 30 | echo "====================" 31 | echo "" 32 | echo "Available options:" 33 | echo " 1. Full demo (default)" 34 | echo " 2. Interactive mode only" 35 | echo " 3. Custom parameters" 36 | echo "" 37 | read -p "Choose option (1-3) or press Enter for default: " choice 38 | 39 | case $choice in 40 | 2) 41 | echo "🎮 Starting interactive mode..." 42 | cargo run --release -- --skip-demo 43 | ;; 44 | 3) 45 | echo "📝 Custom parameters:" 46 | read -p "Number of queries (default 100): " queries 47 | read -p "Run interactive mode after? (y/n): " interactive 48 | 49 | args="" 50 | if [ ! -z "$queries" ]; then 51 | args="$args --queries $queries" 52 | fi 53 | if [ "$interactive" = "y" ] || [ "$interactive" = "Y" ]; then 54 | args="$args --interactive" 55 | fi 56 | 57 | echo "🚀 Running with custom parameters..." 58 | cargo run --release -- $args 59 | ;; 60 | *) 61 | echo "🚀 Running full demo..." 62 | cargo run --release 63 | ;; 64 | esac 65 | 66 | echo "" 67 | echo "🎉 Demo completed!" 68 | echo "Thank you for exploring LSPH capabilities." -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! This crate implements the Learned SPatial HashMap(LSPH), a high performance spatial index uses HashMap with learned model. 2 | //! 3 | //! The original paper of LSPH can be found [here]. 4 | //! 5 | //! [here]: https://minerva-access.unimelb.edu.au/items/beb5c0ee-2a8d-5bd2-b349-1190a335ef1a 6 | //! 7 | //! The LSPH uses a learned model such as a linear regression model as the hash 8 | //! function to predict the index in a hashmap. As a result, the learned model 9 | //! is more fitted to the data that stored in the hashmap, and reduces the 10 | //! chance of hashing collisions. Moreover, if the learned model is monotonic 11 | //! function(e.g. linear regression), the hash indexes are increasing as the 12 | //! input data increases. This property can be used to create a sorted order 13 | //! of buckets in a hashmap, which allow us to do range searches in a hashmap. 14 | //! 15 | //! The LSPH supports: 16 | //! - Point Query 17 | //! - Rectange Query 18 | //! - Radius Range Query 19 | //! - Nearest Neighbor Query 20 | //! 21 | //! Example: 22 | //! ``` 23 | //! use lsph::{LearnedHashMap, LinearModel}; 24 | //! let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 25 | //! let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 26 | //! 27 | //! assert_eq!(map.get(&[1., 1.]).is_some(), true); 28 | //! assert_eq!(map.get(&[3., 1.]).is_none(), true); 29 | //! assert_eq!(map.range_search(&[0., 0.], &[3., 3.]).is_some(), true); 30 | //! assert_eq!(map.radius_range(&[2., 1.], 1.).is_some(), true); 31 | //! assert_eq!(map.nearest_neighbor(&[2., 1.]).is_some(), true); 32 | //! 33 | //! ``` 34 | //! # License 35 | //! 36 | //! Licensed under either of 37 | //! 38 | //! - Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or ) 39 | //! - MIT license ([LICENSE-MIT](LICENSE-MIT) or ) 40 | //! 41 | //! at your option. 42 | 43 | #[macro_use] 44 | mod macros; 45 | mod error; 46 | pub mod geometry; 47 | pub mod hasher; 48 | pub mod map; 49 | pub mod models; 50 | #[cfg(test)] 51 | pub mod test_utilities; 52 | 53 | pub use geometry::*; 54 | pub use hasher::*; 55 | pub use map::*; 56 | pub use models::*; 57 | -------------------------------------------------------------------------------- /src/geometry/distance.rs: -------------------------------------------------------------------------------- 1 | use crate::geometry::Point; 2 | use core::marker::PhantomData; 3 | use num_traits::float::Float; 4 | 5 | /// Distance trait for measuring the distance between two points 6 | pub trait Distance { 7 | type F; 8 | /// Distance between two points in tuple format 9 | fn distance(a: &[Self::F; 2], b: &[Self::F; 2]) -> Self::F; 10 | /// Distance between two points in points format 11 | fn distance_point(a: &Point, b: &Point) -> Self::F; 12 | } 13 | 14 | /// Euclidean Distance 15 | pub struct Euclidean { 16 | _marker: PhantomData, 17 | } 18 | 19 | impl Distance for Euclidean 20 | where 21 | F: Float, 22 | { 23 | type F = F; 24 | fn distance(a: &[F; 2], b: &[F; 2]) -> F { 25 | F::sqrt((a[0] - b[0]).powi(2) + (a[1] - b[1]).powi(2)) 26 | } 27 | 28 | fn distance_point(a: &Point, b: &Point) -> Self::F { 29 | Self::distance(&[a.x, a.y], &[b.x, b.y]) 30 | } 31 | } 32 | 33 | /// Manhattan Distance 34 | pub struct Manhattan { 35 | _marker: PhantomData, 36 | } 37 | 38 | impl Distance for Manhattan 39 | where 40 | F: Float, 41 | { 42 | type F = F; 43 | fn distance(a: &[F; 2], b: &[F; 2]) -> F { 44 | (a[0] - b[0]).abs() + (a[1] - b[1]).abs() 45 | } 46 | 47 | fn distance_point(a: &Point, b: &Point) -> Self::F { 48 | Self::distance(&[a.x, a.y], &[b.x, b.y]) 49 | } 50 | } 51 | 52 | #[cfg(test)] 53 | mod tests { 54 | use super::*; 55 | 56 | #[test] 57 | fn test_euclidean_f32() { 58 | let a = Point:: { x: 0., y: 0. }; 59 | let b = Point:: { x: 1., y: 1. }; 60 | let d = Euclidean::distance_point(&a, &b); 61 | assert_delta!(d, std::f32::consts::SQRT_2, 0.00001); 62 | } 63 | 64 | #[test] 65 | fn test_euclidean_f64() { 66 | let a = Point:: { x: 0., y: 0. }; 67 | let b = Point:: { x: 1., y: 1. }; 68 | let d = Euclidean::distance_point(&a, &b); 69 | assert_delta!(d, std::f64::consts::SQRT_2, 0.00001); 70 | } 71 | 72 | #[test] 73 | fn test_manhattan_f32() { 74 | let a = Point:: { x: 0., y: 0. }; 75 | let b = Point:: { x: 1., y: 1. }; 76 | let d = Manhattan::distance_point(&a, &b); 77 | assert_delta!(d, 2., 0.00001); 78 | } 79 | 80 | #[test] 81 | fn test_manhattan_f64() { 82 | let a = Point:: { x: 0., y: 0. }; 83 | let b = Point:: { x: 1., y: 1. }; 84 | let d = Manhattan::distance_point(&a, &b); 85 | assert_delta!(d, 2., 0.00001); 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /src/map/table.rs: -------------------------------------------------------------------------------- 1 | use core::ops::{Deref, DerefMut}; 2 | use smallvec::SmallVec; 3 | 4 | /// Bucket is the lower unit in the HashMap to store the points 5 | #[derive(Debug, Clone)] 6 | pub(crate) struct Bucket { 7 | entry: SmallVec<[V; 6]>, 8 | } 9 | 10 | impl Bucket { 11 | /// Returns a default Bucket with value type. 12 | #[inline] 13 | pub fn new() -> Self { 14 | Self { 15 | entry: SmallVec::new(), 16 | } 17 | } 18 | 19 | /// Removes an element from the Bucket and returns it. 20 | /// The removed element is replaced by the last element of the Bucket. 21 | #[inline] 22 | pub fn swap_remove(&mut self, index: usize) -> V { 23 | self.entry.swap_remove(index) 24 | } 25 | } 26 | 27 | impl Deref for Bucket { 28 | type Target = SmallVec<[V; 6]>; 29 | fn deref(&self) -> &Self::Target { 30 | &self.entry 31 | } 32 | } 33 | 34 | impl DerefMut for Bucket { 35 | fn deref_mut(&mut self) -> &mut Self::Target { 36 | &mut self.entry 37 | } 38 | } 39 | 40 | /// Table containing a Vec of Bucket to store the values 41 | #[derive(Debug, Clone)] 42 | pub(crate) struct Table { 43 | buckets: Vec>, 44 | } 45 | 46 | impl Table { 47 | /// Returns a default Table with empty Vec. 48 | #[inline] 49 | pub fn new() -> Self { 50 | Self { 51 | buckets: Vec::new(), 52 | } 53 | } 54 | 55 | /// Returns a default Table with Vec that with the given capacity. 56 | /// 57 | /// # Arguments 58 | /// * `capacity` - A capacity size for the Table 59 | #[inline] 60 | pub fn with_capacity(capacity: usize) -> Self { 61 | Self { 62 | buckets: Vec::with_capacity(capacity), 63 | } 64 | } 65 | 66 | /// Returns the capacity of the Table. 67 | #[inline] 68 | pub fn capacity(&self) -> usize { 69 | self.buckets.capacity() 70 | } 71 | 72 | /// Returns the Bucket with given hash value. 73 | /// 74 | /// # Arguments 75 | /// * `hash` - A hash value for indexing the bucket in the table 76 | #[inline] 77 | pub fn bucket(&self, hash: u64) -> usize { 78 | hash as usize % self.buckets.len() 79 | } 80 | } 81 | impl Table 82 | where 83 | V: PartialEq, 84 | { 85 | /// Remove entry with given hash value and entry. 86 | /// 87 | /// # Arguments 88 | /// * `hash` - A hash value for indexing the bucket in the table 89 | /// * `entry` - Entry to remove 90 | #[inline] 91 | pub fn remove_entry(&mut self, hash: u64, entry: V) -> Option { 92 | let index = self.bucket(hash); 93 | let bucket = &mut self.buckets[index]; 94 | let i = bucket.iter().position(|ek| ek == &entry)?; 95 | Some(bucket.swap_remove(i)) 96 | } 97 | } 98 | 99 | impl Deref for Table { 100 | type Target = Vec>; 101 | fn deref(&self) -> &Self::Target { 102 | &self.buckets 103 | } 104 | } 105 | 106 | impl DerefMut for Table { 107 | fn deref_mut(&mut self) -> &mut Self::Target { 108 | &mut self.buckets 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [LSPH](https://crates.io/crates/lsph) - Learned SPatial HashMap (LSPH) 2 | 3 | **fast 2d point query powered by hashmap and statistic model** 4 | 5 | ![Github Workflow](https://github.com/jackson211/lsph/actions/workflows/rust.yml/badge.svg) 6 | [![crates.io version](https://img.shields.io/crates/v/lsph)](https://crates.io/crates/lsph) 7 | [![dos.io](https://img.shields.io/docsrs/lsph)](https://docs.rs/lsph) 8 | [![dependency status](https://deps.rs/repo/github/jackson211/lsph/status.svg)](https://deps.rs/repo/github/jackson211/lsph) 9 | 10 | The original paper of LSPH can be found [here]. 11 | 12 | [here]: https://minerva-access.unimelb.edu.au/items/beb5c0ee-2a8d-5bd2-b349-1190a335ef1a 13 | 14 | The LSPH uses a learned model such as a linear regression model as the hash function to predict the index in a hashmap. As a result, the learned model is more fitted to the data that stored in the hashmap, and reduces the 15 | chance of hashing collisions. Moreover, if the learned model is monotonic function(e.g. linear regression), the hash indexes are increasing as the input data increases. This property can be used to create a sorted order 16 | of buckets in a hashmap, which allow us to do range searches in a hashmap. 17 | 18 | The LSPH supports: 19 | 20 | - Point Query 21 | - Rectange Query 22 | - Radius Range Query 23 | - Nearest Neighbor Query 24 | 25 | ## Example: 26 | 27 | ```rust 28 | use lsph::{LearnedHashMap, LinearModel}; 29 | let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 30 | let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 31 | 32 | assert_eq!(map.get(&[1., 1.]).is_some(), true); 33 | assert_eq!(map.get(&[3., 1.]).is_none(), true); 34 | assert_eq!(map.range_search(&[0., 0.], &[3., 3.]).is_some(), true); 35 | assert_eq!(map.radius_range(&[2., 1.], 1.).is_some(), true); 36 | assert_eq!(map.nearest_neighbor(&[2., 1.]).is_some(), true); 37 | ``` 38 | 39 | ## Running Demos 40 | 41 | LSPH includes two comprehensive demo applications to showcase its capabilities: 42 | 43 | ### Geographic Data Demo 44 | 45 | A command-line demo using real Melbourne geographic data (6,361 points): 46 | 47 | ```bash 48 | cd examples/demo 49 | cargo run --release 50 | ``` 51 | 52 | Features: 53 | 54 | - Real-world geographic data processing 55 | - Performance benchmarking and analysis 56 | - Interactive nearest neighbor queries 57 | - Range query demonstrations 58 | - Memory usage and throughput metrics 59 | 60 | ### Interactive GUI Demo 61 | 62 | A graphical demonstration with visual spatial operations: 63 | 64 | ```bash 65 | cd examples/interactive_demo 66 | cargo run --release 67 | ``` 68 | 69 | ![LSPH Interactive Demo](examples/interactive_demo/Screenshot.png) 70 | 71 | _Interactive demo showing nearest neighbor search with responsive UI and visual feedback_ 72 | 73 | Features: 74 | 75 | - Visual point addition and management 76 | - Interactive nearest neighbor search 77 | - Range query visualization 78 | - Real-time performance metrics 79 | 80 | ## To Run Benchmark: 81 | 82 | ```bash 83 | cargo bench 84 | ``` 85 | 86 | # License 87 | 88 | Licensed under either of 89 | 90 | - Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) 91 | - MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) 92 | 93 | at your option. 94 | -------------------------------------------------------------------------------- /src/hasher/mod.rs: -------------------------------------------------------------------------------- 1 | use crate::models::Model; 2 | use num_traits::cast::{AsPrimitive, FromPrimitive}; 3 | use num_traits::float::Float; 4 | 5 | /// LearnedHasher takes a model and produces hash from the model 6 | #[derive(Debug, Clone)] 7 | pub struct LearnedHasher { 8 | state: u64, 9 | pub model: M, 10 | sort_by_x: bool, 11 | } 12 | 13 | impl Default for LearnedHasher 14 | where 15 | F: Float, 16 | M: Model + Default, 17 | { 18 | #[inline] 19 | fn default() -> Self { 20 | Self { 21 | state: 0, 22 | model: Default::default(), 23 | sort_by_x: true, 24 | } 25 | } 26 | } 27 | impl LearnedHasher 28 | where 29 | F: Float, 30 | M: Model, 31 | { 32 | /// Returns a default LearnedHasher with Model and Float type. 33 | /// 34 | /// # Arguments 35 | /// * `model` - A model that implements Model trait 36 | /// 37 | /// # Examples 38 | /// 39 | /// ``` 40 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 41 | /// let hasher = LearnedHasher::with_model(LinearModel::::new()); 42 | /// ``` 43 | #[inline] 44 | pub fn with_model(model: M) -> Self { 45 | Self { 46 | state: 0, 47 | model, 48 | sort_by_x: true, 49 | } 50 | } 51 | 52 | /// Returns a current Hasher state. 53 | #[inline] 54 | fn finish(&self) -> u64 { 55 | self.state 56 | } 57 | 58 | /// Returns the sorted index base on parameter self.sort_by_x. 59 | #[inline] 60 | pub fn sort_by_x(&self) -> bool { 61 | self.sort_by_x 62 | } 63 | 64 | /// Sets self.sort_by_x to a given boolean value. 65 | #[inline] 66 | pub fn set_sort_by_x(&mut self, x: bool) { 67 | self.sort_by_x = x; 68 | } 69 | } 70 | 71 | impl LearnedHasher 72 | where 73 | F: Float, 74 | M: Model + Default, 75 | { 76 | /// Returns a default LearnedHasher. 77 | #[inline] 78 | pub fn new() -> Self { 79 | Self::default() 80 | } 81 | } 82 | 83 | impl LearnedHasher 84 | where 85 | F: Float + AsPrimitive, 86 | M: Model, 87 | { 88 | /// Writes a data into self.data by inferencing the input data into the trained model. 89 | #[inline] 90 | fn write(&mut self, data: &F) { 91 | self.state = self.model.predict(*data).floor().as_(); 92 | } 93 | } 94 | 95 | impl LearnedHasher 96 | where 97 | F: Float + FromPrimitive, 98 | M: Model + Default, 99 | { 100 | /// Unwrite takes a hash value, and unpredict the hash value to estimate the approximate input 101 | /// data from the target data(e.g. in linear regression, to get the x value for given y). 102 | /// 103 | /// # Arguments 104 | /// * `hash` - An usize hash value 105 | #[inline] 106 | fn unwrite(&mut self, hash: u64) -> F { 107 | let hash = FromPrimitive::from_u64(hash).unwrap(); 108 | self.model.unpredict(hash) 109 | } 110 | } 111 | /// Make hash value from a given hasher, returns a u64 hash value. 112 | /// 113 | /// # Arguments 114 | /// * `hasher` - A LearnedHasher type 115 | #[inline] 116 | pub fn make_hash(hasher: &mut LearnedHasher, p: &F) -> u64 117 | where 118 | F: Float + FromPrimitive + AsPrimitive, 119 | M: Model + Default, 120 | { 121 | hasher.write(p); 122 | hasher.finish() 123 | } 124 | 125 | /// Make hash value from a given hasher, and 2 item array with float data. 126 | /// 127 | /// # Arguments 128 | /// * `hasher` - A LearnedHasher type 129 | /// * `p` - Point data 130 | #[inline] 131 | pub fn make_hash_point(hasher: &mut LearnedHasher, p: &[F; 2]) -> u64 132 | where 133 | F: Float + FromPrimitive + AsPrimitive, 134 | M: Model + Default, 135 | { 136 | if hasher.sort_by_x { 137 | make_hash(hasher, &p[0]) 138 | } else { 139 | make_hash(hasher, &p[1]) 140 | } 141 | } 142 | 143 | /// Unmake hash value from a given hasher, and a u64 hash value. 144 | /// Reverse the hash function, which it takes a hash and returns float 145 | /// 146 | /// # Arguments 147 | /// * `hasher` - A LearnedHasher type 148 | /// * `p` - Point data 149 | #[inline] 150 | pub fn unhash(hasher: &mut LearnedHasher, hash: u64) -> F 151 | where 152 | F: Float + FromPrimitive + AsPrimitive, 153 | M: Model + Default, 154 | { 155 | hasher.unwrite(hash) 156 | } 157 | 158 | #[cfg(test)] 159 | mod tests { 160 | use super::LearnedHasher; 161 | use crate::models::LinearModel; 162 | 163 | #[test] 164 | fn hasher_with_empty_model() { 165 | let mut hasher: LearnedHasher> = LearnedHasher::new(); 166 | hasher.write(&10f64); 167 | assert_eq!(0u64, hasher.finish()); 168 | } 169 | 170 | #[test] 171 | fn unhash() { 172 | let mut hasher: LearnedHasher> = LearnedHasher::with_model(LinearModel { 173 | coefficient: 3., 174 | intercept: 2., 175 | }); 176 | hasher.write(&10.5); 177 | assert_eq!(33u64, hasher.finish()); 178 | assert_delta!(10.33f64, hasher.unwrite(33u64), 0.01); 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /benches/benchmarks.rs: -------------------------------------------------------------------------------- 1 | #[macro_use] 2 | extern crate criterion; 3 | 4 | use criterion::Criterion; 5 | use lsph::geometry::Point; 6 | use lsph::{map::LearnedHashMap, models::LinearModel}; 7 | use rand::{rngs::SmallRng, Rng, SeedableRng}; 8 | 9 | const SEED_1: &[u8; 32] = b"Gv0aHMtHkBGsUXNspGU9fLRuCWkZWHZx"; 10 | const SEED_2: &[u8; 32] = b"km7DO4GeaFZfTcDXVpnO7ZJlgUY7hZiS"; 11 | 12 | const DEFAULT_BENCHMARK_TREE_SIZE: usize = 2000; 13 | 14 | fn bulk_load_baseline(c: &mut Criterion) { 15 | c.bench_function("Bulk load baseline", move |b| { 16 | b.iter(|| { 17 | let mut points: Vec<_> = 18 | create_random_point_type_points(DEFAULT_BENCHMARK_TREE_SIZE, SEED_1); 19 | let mut map = LearnedHashMap::, f64>::new(); 20 | map.batch_insert(&mut points).unwrap(); 21 | }); 22 | }) 23 | .bench_function("Bulk load baseline with f32", move |b| { 24 | b.iter(|| { 25 | let mut points: Vec<_> = 26 | create_random_point_type_points_f32(DEFAULT_BENCHMARK_TREE_SIZE, SEED_1); 27 | let mut map = LearnedHashMap::, f32>::new(); 28 | map.batch_insert(&mut points).unwrap(); 29 | }); 30 | }); 31 | } 32 | 33 | fn locate_successful(c: &mut Criterion) { 34 | let mut points: Vec<_> = create_random_point_type_points(100_000, SEED_1); 35 | let mut points_f32: Vec<_> = create_random_point_type_points_f32(100_000, SEED_1); 36 | let query_point = [points[500].x(), points[500].y()]; 37 | let query_point_f32 = [points_f32[500].x(), points_f32[500].y()]; 38 | 39 | let mut map = LearnedHashMap::, f64>::new(); 40 | let mut map_f32 = LearnedHashMap::, f32>::new(); 41 | map.batch_insert(&mut points).unwrap(); 42 | map_f32.batch_insert(&mut points_f32).unwrap(); 43 | c.bench_function("locate_at_point (successful)", move |b| { 44 | b.iter(|| map.get(&query_point).is_some()) 45 | }) 46 | .bench_function("locate_at_point_f32 (successful)", move |b| { 47 | b.iter(|| map_f32.get(&query_point_f32).is_some()) 48 | }); 49 | } 50 | 51 | fn locate_unsuccessful(c: &mut Criterion) { 52 | let mut points: Vec<_> = create_random_point_type_points(100_000, SEED_1); 53 | let mut points_f32: Vec<_> = create_random_point_type_points_f32(100_000, SEED_1); 54 | let query_point: [f64; 2] = [0.7, 0.7]; 55 | let query_point_f32: [f32; 2] = [0.7, 0.7]; 56 | 57 | let mut map = LearnedHashMap::, f64>::new(); 58 | let mut map_f32 = LearnedHashMap::, f32>::new(); 59 | map.batch_insert(&mut points).unwrap(); 60 | map_f32.batch_insert(&mut points_f32).unwrap(); 61 | c.bench_function("locate_at_point (unsuccessful)", move |b| { 62 | b.iter(|| map.get(&query_point).is_none()) 63 | }) 64 | .bench_function("locate_at_point_f32 (unsuccessful)", move |b| { 65 | b.iter(|| map_f32.get(&query_point_f32).is_none()) 66 | }); 67 | } 68 | 69 | fn nearest_neighbor(c: &mut Criterion) { 70 | const SIZE: usize = 100_000; 71 | let mut points: Vec<_> = create_random_point_type_points(SIZE, SEED_1); 72 | let query_points = create_random_points(100, SEED_2); 73 | 74 | let mut map = LearnedHashMap::, f64>::new(); 75 | map.batch_insert(&mut points).unwrap(); 76 | 77 | c.bench_function("nearest_neigbor", move |b| { 78 | b.iter(|| { 79 | for query_point in &query_points { 80 | map.nearest_neighbor(query_point).unwrap(); 81 | } 82 | }); 83 | }); 84 | } 85 | 86 | fn radius_range(c: &mut Criterion) { 87 | const SIZE: usize = 100_000; 88 | let mut points: Vec<_> = create_random_point_type_points(SIZE, SEED_1); 89 | let query_points = create_random_points(100, SEED_2); 90 | 91 | let mut map = LearnedHashMap::, f64>::new(); 92 | map.batch_insert(&mut points).unwrap(); 93 | 94 | let radiuses = vec![0.01, 0.1, 0.2]; 95 | for radius in radiuses { 96 | let title = format!("radius_range_{radius}"); 97 | c.bench_function(title.as_str(), |b| { 98 | b.iter(|| { 99 | for query_point in &query_points { 100 | map.radius_range(query_point, radius).unwrap(); 101 | } 102 | }); 103 | }); 104 | } 105 | } 106 | 107 | criterion_group!( 108 | benches, 109 | bulk_load_baseline, 110 | locate_successful, 111 | locate_unsuccessful, 112 | radius_range, 113 | nearest_neighbor, 114 | ); 115 | criterion_main!(benches); 116 | 117 | fn create_random_points(num_points: usize, seed: &[u8; 32]) -> Vec<[f64; 2]> { 118 | let mut result = Vec::with_capacity(num_points); 119 | let mut rng = SmallRng::from_seed(*seed); 120 | for _ in 0..num_points { 121 | result.push([rng.random(), rng.random()]); 122 | } 123 | result 124 | } 125 | 126 | fn create_random_point_type_points(num_points: usize, seed: &[u8; 32]) -> Vec> { 127 | let result = create_random_points(num_points, seed); 128 | 129 | // result.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); 130 | result 131 | .into_iter() 132 | .map(|[x, y]| Point::new(x, y)) 133 | .collect::>() 134 | } 135 | 136 | fn create_random_point_type_points_f32(num_points: usize, seed: &[u8; 32]) -> Vec> { 137 | let result = create_random_points(num_points, seed); 138 | 139 | // result.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); 140 | result 141 | .into_iter() 142 | .map(|[x, y]| Point::new(x as f32, y as f32)) 143 | .collect::>() 144 | } 145 | -------------------------------------------------------------------------------- /examples/interactive_demo/README.md: -------------------------------------------------------------------------------- 1 | # LSPH Interactive Demo 2 | 3 | An interactive graphical demonstration of the Learned Spatial HashMap (LSPH) capabilities. 4 | 5 | ## Features 6 | 7 | This demo showcases the core functionality of LSPH through an intuitive graphical interface: 8 | 9 | ### 🎯 Demo Modes 10 | 11 | 1. **Manual Point Addition** 12 | - Add points by clicking on the canvas or entering coordinates 13 | - Specify custom values for each point 14 | - Visual feedback with color-coded points 15 | 16 | 2. **Random Generation** 17 | - Generate random points automatically 18 | - Adjustable generation speed 19 | - Batch generation (10 or 100 points) 20 | - Auto-generation mode for continuous demonstration 21 | 22 | 3. **Nearest Neighbor Search** 23 | - Click anywhere to find the nearest point 24 | - Visual highlighting of the nearest neighbor 25 | - Real-time search performance metrics 26 | - Interactive query point positioning 27 | 28 | 4. **Range Query** 29 | - Define a search center and radius 30 | - Visual circle showing the search area 31 | - Highlight all points within the specified range 32 | - Adjustable search radius with slider 33 | 34 | ### 🎨 Visualization Features 35 | 36 | - **Interactive Canvas**: Click to add points or perform searches 37 | - **Color-coded Points**: Points are colored based on their values 38 | - **Grid Overlay**: Optional grid for better spatial reference 39 | - **Adjustable Point Size**: Customize point visualization 40 | - **Real-time Highlighting**: Visual feedback for search results 41 | - **Performance Metrics**: Display search times and point counts 42 | 43 | ### 📊 Statistics 44 | 45 | - Total number of points in the spatial map 46 | - Search operation timing (in milliseconds) 47 | - Real-time updates during interactions 48 | 49 | ## Running the Demo 50 | 51 | ### Prerequisites 52 | 53 | Make sure you have Rust installed on your system. If not, install it from [rustup.rs](https://rustup.rs/). 54 | 55 | ### Installation and Execution 56 | 57 | 1. Navigate to the demo directory: 58 | ```bash 59 | cd /Users/jackson/Documents/code/project/lsph/examples/interactive_demo 60 | ``` 61 | 62 | 2. Run the demo: 63 | ```bash 64 | cargo run 65 | ``` 66 | 67 | The application will open in a new window with the interactive interface. 68 | 69 | ## How to Use 70 | 71 | ### Getting Started 72 | 73 | 1. **Start with Random Generation**: Select "Random Generation" mode and click "Generate 100 Points" to populate the canvas 74 | 2. **Try Nearest Neighbor**: Switch to "Nearest Neighbor Search" mode and click anywhere on the canvas 75 | 3. **Explore Range Queries**: Use "Range Query" mode to find all points within a specified radius 76 | 4. **Manual Addition**: Add specific points using "Manual Point Addition" mode 77 | 78 | ### Tips for Best Experience 79 | 80 | - **Use the grid**: Enable "Show Grid" for better spatial reference 81 | - **Adjust point size**: Increase point size for better visibility with many points 82 | - **Try different modes**: Each mode demonstrates different LSPH capabilities 83 | - **Watch performance**: Notice how search times remain fast even with many points 84 | - **Interactive exploration**: Click around the canvas to see real-time search results 85 | 86 | ## Technical Details 87 | 88 | ### Dependencies 89 | 90 | - **LSPH**: The core spatial hashmap library 91 | - **eframe/egui**: Modern immediate-mode GUI framework for Rust 92 | - **rand**: Random number generation for demo data 93 | - **serde**: Serialization support (future feature) 94 | 95 | ### Architecture 96 | 97 | The demo is built using the egui immediate-mode GUI framework, providing: 98 | - Cross-platform compatibility (Windows, macOS, Linux) 99 | - Smooth 60fps rendering 100 | - Responsive user interface 101 | - Real-time visualization updates 102 | 103 | ### Performance Characteristics 104 | 105 | The demo demonstrates LSPH's key performance benefits: 106 | - **Fast Insertions**: Add thousands of points quickly 107 | - **Efficient Searches**: Nearest neighbor and range queries in sub-millisecond time 108 | - **Memory Efficiency**: Compact spatial indexing 109 | - **Scalability**: Performance remains consistent as data grows 110 | 111 | ## Educational Value 112 | 113 | This demo is designed to help users understand: 114 | 115 | 1. **Spatial Data Structures**: How points are organized in 2D space 116 | 2. **Search Algorithms**: Visual representation of nearest neighbor and range queries 117 | 3. **Performance Benefits**: Real-time metrics showing LSPH's efficiency 118 | 4. **Interactive Learning**: Hands-on exploration of spatial algorithms 119 | 120 | ## Extending the Demo 121 | 122 | The demo can be extended with additional features: 123 | 124 | - **3D Visualization**: Extend to 3D point clouds 125 | - **Data Import/Export**: Load real-world datasets 126 | - **Algorithm Comparison**: Side-by-side comparison with other spatial structures 127 | - **Advanced Queries**: k-nearest neighbors, polygon queries 128 | - **Animation**: Animated insertions and searches 129 | - **Benchmarking**: Built-in performance testing suite 130 | 131 | ## Troubleshooting 132 | 133 | ### Common Issues 134 | 135 | 1. **Compilation Errors**: Ensure you have the latest Rust toolchain 136 | 2. **Missing Dependencies**: Run `cargo update` to fetch latest dependencies 137 | 3. **Performance Issues**: Try reducing the number of points or disabling auto-generation 138 | 139 | ### System Requirements 140 | 141 | - **Rust**: 1.70 or later 142 | - **Graphics**: OpenGL 3.0+ support 143 | - **Memory**: 512MB RAM minimum 144 | - **Platform**: Windows 10+, macOS 10.15+, or Linux with X11/Wayland 145 | 146 | ## License 147 | 148 | This demo is part of the LSPH project and is licensed under MIT OR Apache-2.0. -------------------------------------------------------------------------------- /src/models/trainer.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | error::Error, 3 | geometry::{helper::*, Axis, Point}, 4 | models::{variance, Model}, 5 | }; 6 | use core::iter::Sum; 7 | use num_traits::{cast::FromPrimitive, Float}; 8 | 9 | /// Preprocessing and prepare data for model training 10 | /// 11 | #[derive(Debug, Clone)] 12 | pub struct Trainer { 13 | train_x: Vec, 14 | train_y: Vec, 15 | axis: Axis, 16 | } 17 | 18 | impl Default for Trainer { 19 | fn default() -> Self { 20 | Self { 21 | train_x: Vec::::new(), 22 | train_y: Vec::::new(), 23 | axis: Axis::X, 24 | } 25 | } 26 | } 27 | 28 | impl Trainer 29 | where 30 | F: Float + Sized, 31 | { 32 | pub fn new() -> Self { 33 | Self::default() 34 | } 35 | 36 | pub fn train_x(&self) -> &Vec { 37 | &self.train_x 38 | } 39 | 40 | pub fn train_y(&self) -> &Vec { 41 | &self.train_y 42 | } 43 | 44 | pub fn axis(&self) -> &Axis { 45 | &self.axis 46 | } 47 | 48 | pub fn set_train_x(&mut self, xs: Vec) { 49 | self.train_x = xs 50 | } 51 | 52 | pub fn set_train_y(&mut self, ys: Vec) { 53 | self.train_y = ys 54 | } 55 | 56 | pub fn set_axis(&mut self, axis: Axis) { 57 | self.axis = axis 58 | } 59 | 60 | /// Training with provided model 61 | /// 62 | /// Returns trained/fitted model on success, otherwise returns an error 63 | pub fn train<'a, M: Model + 'a>(&self, model: &'a mut M) -> Result<(), Error> { 64 | model.fit(&self.train_x, &self.train_y)?; 65 | Ok(()) 66 | } 67 | } 68 | 69 | impl Trainer 70 | where 71 | F: Float + Sum + FromPrimitive, 72 | { 73 | /// Initialize Trainer with two Vec 74 | /// 75 | /// Returns prepared Trainer Ok((Trainer)) on success, otherwise returns an error 76 | pub fn with_data(xs: Vec, ys: Vec) -> Result<(Self, Vec>), Error> { 77 | assert_empty!(xs); 78 | assert_eq_len!(xs, ys); 79 | 80 | let mut trainer = Trainer::new(); 81 | let data = trainer.preprocess(xs, ys)?; 82 | Ok((trainer, data)) 83 | } 84 | 85 | /// Preprocess two Vec that satisfy Trainer's requirements 86 | /// 87 | /// Returns sorted Ok(Vec>) on success, otherwise returns an error 88 | pub fn preprocess(&mut self, xs: Vec, ys: Vec) -> Result>, Error> { 89 | assert_empty!(xs); 90 | assert_eq_len!(xs, ys); 91 | 92 | let mut ps: Vec> = xs 93 | .iter() 94 | .zip(ys.iter()) 95 | .map(|(&x, &y)| Point { x, y }) 96 | .collect(); 97 | 98 | // set train_x to data with larger variance 99 | if variance(&xs) > variance(&ys) { 100 | // sort along x 101 | sort_by_x(&mut ps); 102 | self.set_axis(Axis::X); 103 | self.set_train_x(extract_x(&ps)); 104 | } else { 105 | // sort along y 106 | sort_by_y(&mut ps); 107 | self.set_axis(Axis::Y); 108 | self.set_train_x(extract_y(&ps)); 109 | }; 110 | 111 | let train_y: Vec = (0..ps.len()).map(|id| F::from_usize(id).unwrap()).collect(); 112 | self.set_train_y(train_y); 113 | Ok(ps) 114 | } 115 | 116 | /// Preprocess with Vec> that satisfy Trainer's requirements 117 | /// 118 | /// Returns prepared Trainer Ok((Trainer)) on success, otherwise returns an error 119 | pub fn with_points(ps: &mut [Point]) -> Result { 120 | let px: Vec = extract_x(ps); 121 | let py: Vec = extract_y(ps); 122 | assert_eq_len!(px, py); 123 | let x_variance = variance(&px); 124 | let y_variance = variance(&py); 125 | // set train_x to data with larger variance 126 | let (axis, train_x) = if x_variance > y_variance { 127 | sort_by_x(ps); 128 | (Axis::X, extract_x(ps)) 129 | } else { 130 | sort_by_y(ps); 131 | (Axis::Y, extract_y(ps)) 132 | }; 133 | let train_y: Vec = (0..ps.len()).map(|id| F::from_usize(id).unwrap()).collect(); 134 | Ok(Self { 135 | train_x, 136 | train_y, 137 | axis, 138 | }) 139 | } 140 | } 141 | 142 | #[cfg(test)] 143 | mod tests { 144 | use super::*; 145 | 146 | #[test] 147 | fn sort_by() { 148 | let mut data: Vec> = vec![ 149 | Point { x: 1., y: 1. }, 150 | Point { x: 3., y: 1. }, 151 | Point { x: 2., y: 1. }, 152 | Point { x: 3., y: 2. }, 153 | Point { x: 5., y: 1. }, 154 | ]; 155 | let data_sort_by_x: Vec> = vec![ 156 | Point { x: 1., y: 1. }, 157 | Point { x: 2., y: 1. }, 158 | Point { x: 3., y: 1. }, 159 | Point { x: 3., y: 2. }, 160 | Point { x: 5., y: 1. }, 161 | ]; 162 | sort_by_x(&mut data); 163 | 164 | assert_eq!(data_sort_by_x, data); 165 | } 166 | 167 | #[test] 168 | fn train() { 169 | let mut data: Vec> = vec![ 170 | Point { x: 1., y: 1. }, 171 | Point { x: 3., y: 1. }, 172 | Point { x: 2., y: 1. }, 173 | Point { x: 3., y: 2. }, 174 | Point { x: 5., y: 1. }, 175 | ]; 176 | let trainer = Trainer::with_points(&mut data).unwrap(); 177 | let test_x = vec![1., 2., 3., 3., 5.]; 178 | let test_y = vec![0., 1., 2., 3., 4.]; 179 | 180 | assert_eq!(&test_x, trainer.train_x()); 181 | assert_eq!(&test_y, trainer.train_y()); 182 | } 183 | } 184 | -------------------------------------------------------------------------------- /src/models/stats.rs: -------------------------------------------------------------------------------- 1 | use core::iter::Sum; 2 | use num_traits::{cast::FromPrimitive, float::Float}; 3 | 4 | /// Calculates mean of a slice of values 5 | pub fn mean(values: &[F]) -> F 6 | where 7 | F: Float + Sum, 8 | { 9 | if values.is_empty() { 10 | return F::zero(); 11 | } 12 | let sum: F = values.iter().cloned().map(Into::into).sum(); 13 | sum / F::from(values.len()).unwrap() 14 | } 15 | 16 | /// Calculates variance of a slice of values 17 | pub fn variance(values: &[F]) -> F 18 | where 19 | F: Float + Sum, 20 | { 21 | if values.is_empty() { 22 | return F::zero(); 23 | } 24 | let mean = mean(values); 25 | 26 | let diff_sum: F = values 27 | .iter() 28 | .cloned() 29 | .map(|x| (x - mean).powf(F::from(2.0).unwrap())) 30 | .sum(); 31 | diff_sum / F::from(values.len()).unwrap() 32 | } 33 | 34 | /// Calculates covariance over two slices of values 35 | pub fn covariance(x_values: &[F], y_values: &[F]) -> F 36 | where 37 | F: Float + Sum, 38 | { 39 | if x_values.len() != y_values.len() { 40 | panic!("x_values and y_values must be of equal length."); 41 | } 42 | let length: usize = x_values.len(); 43 | if length == 0usize { 44 | return F::zero(); 45 | } 46 | let mean_x = mean(x_values); 47 | let mean_y = mean(y_values); 48 | 49 | x_values 50 | .iter() 51 | .zip(y_values.iter()) 52 | .fold(F::zero(), |covariance, (&x, &y)| { 53 | covariance + (x - mean_x) * (y - mean_y) 54 | }) 55 | / F::from(length).unwrap() 56 | } 57 | 58 | /// Calculates mean squared error between actual and predict values 59 | pub fn mean_squared_error(actual: &[F], predict: &[F]) -> F 60 | where 61 | F: Float + FromPrimitive, 62 | { 63 | if actual.len() != predict.len() { 64 | panic!("actual and predict must be of equal length."); 65 | } 66 | 67 | actual 68 | .iter() 69 | .zip(predict.iter()) 70 | .fold(F::from_f64(0.0).unwrap(), |sum, (&x, &y)| { 71 | sum + (x - y).powf(F::from_f64(2.0).unwrap()) 72 | }) 73 | / F::from_usize(actual.len()).unwrap() 74 | } 75 | 76 | /// Calculates root mean squared error between actual and predict values 77 | pub fn root_mean_squared_error(actual: &[F], predict: &[F]) -> F 78 | where 79 | F: Float + FromPrimitive, 80 | { 81 | mean_squared_error::(actual, predict).sqrt() 82 | } 83 | 84 | #[cfg(test)] 85 | mod tests { 86 | use super::*; 87 | 88 | #[test] 89 | fn mean_empty_vec() { 90 | let values: Vec = vec![]; 91 | assert_delta!(0., mean(&values), 0.00001); 92 | } 93 | 94 | #[test] 95 | fn mean_empty_vec_f32() { 96 | let values: Vec = vec![]; 97 | assert_delta!(0., mean(&values), 0.00001); 98 | } 99 | 100 | #[test] 101 | fn mean_1_to_5() { 102 | let values = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 103 | assert_delta!(3., mean(&values), 0.00001); 104 | } 105 | 106 | #[test] 107 | fn mean_1_to_5_f32() { 108 | let values: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 109 | assert_delta!(3., mean(&values), 0.00001); 110 | } 111 | 112 | #[test] 113 | fn variance_empty_vec() { 114 | let values: Vec = vec![]; 115 | assert_delta!(0., variance(&values), 0.00001); 116 | } 117 | 118 | #[test] 119 | fn variance_empty_vec_f32() { 120 | let values: Vec = vec![]; 121 | assert_delta!(0., variance(&values), 0.00001); 122 | } 123 | 124 | #[test] 125 | fn variance_1_to_5() { 126 | let values = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 127 | assert_delta!(2., variance(&values), 0.00001); 128 | } 129 | 130 | #[test] 131 | fn variance_1_to_5_f32() { 132 | let values: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 133 | assert_delta!(2., variance(&values), 0.00001); 134 | } 135 | 136 | #[test] 137 | fn covariance_empty_vec() { 138 | let x_values: Vec = vec![]; 139 | let y_values: Vec = vec![]; 140 | assert_delta!(0., covariance(&x_values, &y_values), 0.00001); 141 | } 142 | 143 | #[test] 144 | fn covariance_empty_vec_f32() { 145 | let x_values: Vec = vec![]; 146 | let y_values: Vec = vec![]; 147 | assert_delta!(0., covariance(&x_values, &y_values), 0.00001); 148 | } 149 | 150 | #[test] 151 | fn covariance_1_to_5() { 152 | let x_values = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 153 | let y_values = vec![1.0, 3.0, 2.0, 3.0, 5.0]; 154 | assert_delta!(1.6, covariance(&x_values, &y_values), 0.00001); 155 | } 156 | 157 | #[test] 158 | fn covariance_1_to_5_f32() { 159 | let x_values: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 160 | let y_values: Vec = vec![1.0, 3.0, 2.0, 3.0, 5.0]; 161 | assert_delta!(1.6, covariance(&x_values, &y_values), 0.00001); 162 | } 163 | 164 | #[test] 165 | fn negative_covariance_1_to_5() { 166 | let x_values = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 167 | let y_values = vec![0.5, 4.0, 1.0, -5.0, 4.0]; 168 | assert_delta!(-0.4, covariance(&x_values, &y_values), 0.00001); 169 | } 170 | 171 | #[test] 172 | fn negative_covariance_1_to_5_f32() { 173 | let x_values: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 174 | let y_values: Vec = vec![0.5, 4.0, 1.0, -5.0, 4.0]; 175 | assert_delta!(-0.4, covariance(&x_values, &y_values), 0.00001); 176 | } 177 | 178 | #[test] 179 | fn mean_squared_error_test() { 180 | let actual = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 181 | let predict = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 182 | assert_delta!(0., mean_squared_error(&actual, &predict), 0.00001); 183 | } 184 | 185 | #[test] 186 | fn mean_squared_error_test_f32() { 187 | let actual: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 188 | let predict: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 189 | assert_delta!(0., mean_squared_error(&actual, &predict), 0.00001); 190 | } 191 | 192 | #[test] 193 | fn mean_squared_error_test_2() { 194 | let actual = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 195 | let predict = vec![1.0, 1.6, 3.0, 4.0, 5.0]; 196 | assert_delta!(0.032, mean_squared_error(&actual, &predict), 0.00001); 197 | } 198 | 199 | #[test] 200 | fn mean_squared_error_test_2_f32() { 201 | let actual: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 202 | let predict: Vec = vec![1.0, 1.6, 3.0, 4.0, 5.0]; 203 | assert_delta!(0.032, mean_squared_error(&actual, &predict), 0.00001); 204 | } 205 | } 206 | -------------------------------------------------------------------------------- /examples/demo/README.md: -------------------------------------------------------------------------------- 1 | # LSPH Geographic Data Demo 2 | 3 | A comprehensive demonstration of the Learned Spatial HashMap (LSPH) capabilities using real-world geographic data from Melbourne, Australia. 4 | 5 | ## Overview 6 | 7 | This demo showcases LSPH's spatial indexing and querying capabilities through: 8 | 9 | - **Real-world Data**: 6,361 geographic points from Melbourne 10 | - **Performance Benchmarking**: Detailed timing and throughput analysis 11 | - **Interactive Exploration**: Command-line interface for custom queries 12 | - **Comprehensive Metrics**: Memory usage, query performance, and scalability analysis 13 | 14 | ## Features 15 | 16 | ### 🗺️ Data Processing 17 | - Loads CSV geographic data with latitude, longitude, and zone information 18 | - Displays data distribution and zone statistics 19 | - Robust error handling for malformed data 20 | 21 | ### 🏗️ Spatial Indexing 22 | - Builds LSPH spatial index from geographic points 23 | - Tracks insertion performance and memory usage 24 | - Provides detailed indexing statistics 25 | 26 | ### 🎯 Nearest Neighbor Search 27 | - Performs configurable number of nearest neighbor queries 28 | - Measures query performance in microseconds 29 | - Calculates real-world distances using Haversine formula 30 | - Reports queries per second and success rates 31 | 32 | ### 🌐 Range Queries 33 | - Tests multiple search radii (0.001° to 0.02°) 34 | - Converts degrees to approximate meters for intuitive understanding 35 | - Analyzes average results per query and performance metrics 36 | 37 | ### 📈 Performance Analysis 38 | - Comprehensive timing analysis (min, max, average) 39 | - Memory usage estimation 40 | - Throughput calculations (points/second, queries/second) 41 | - Detailed performance summaries 42 | 43 | ### 🎮 Interactive Mode 44 | - Real-time nearest neighbor queries 45 | - User-friendly coordinate input format 46 | - Instant distance and timing feedback 47 | - Graceful error handling 48 | 49 | ## Installation 50 | 51 | ### Prerequisites 52 | 53 | - Rust 1.70 or later 54 | - Cargo package manager 55 | 56 | ### Building 57 | 58 | ```bash 59 | cd /Users/jackson/Documents/code/project/lsph/examples/demo 60 | cargo build --release 61 | ``` 62 | 63 | ## Usage 64 | 65 | ### Basic Demo 66 | 67 | Run the complete demonstration with default settings: 68 | 69 | ```bash 70 | cargo run --release 71 | ``` 72 | 73 | This will: 74 | 1. Load the Melbourne dataset (6,361 points) 75 | 2. Build the spatial index 76 | 3. Perform 100 nearest neighbor queries 77 | 4. Execute 25 range queries with different radii 78 | 5. Display comprehensive performance summary 79 | 80 | ### Command Line Options 81 | 82 | ```bash 83 | # Custom data file 84 | cargo run --release -- --data custom_data.csv 85 | 86 | # Specify number of test queries 87 | cargo run --release -- --queries 500 88 | 89 | # Run in interactive mode after demo 90 | cargo run --release -- --interactive 91 | 92 | # Skip automated demo, go straight to interactive 93 | cargo run --release -- --skip-demo 94 | 95 | # Combine options 96 | cargo run --release -- --queries 1000 --interactive 97 | ``` 98 | 99 | ### Interactive Mode 100 | 101 | In interactive mode, enter coordinates in the format `lat,lng`: 102 | 103 | ``` 104 | 🔍 Query: -37.8136,144.9631 105 | ✅ Nearest point: (-37.81360, 144.96310) 106 | 📏 Distance: 0.00m | ⏱️ Query time: 12.34μs 107 | 108 | 🔍 Query: quit 109 | 👋 Goodbye! 110 | ``` 111 | 112 | ## Data Format 113 | 114 | The demo expects CSV data with three columns (no headers): 115 | 116 | ```csv 117 | latitude,longitude,zone 118 | -37.82341,144.98905,31 119 | -37.82962,144.98793,31 120 | -37.83119,144.98961,31 121 | ``` 122 | 123 | - **Latitude**: Decimal degrees (negative for Southern Hemisphere) 124 | - **Longitude**: Decimal degrees (positive for Eastern Hemisphere) 125 | - **Zone**: Integer category/zone identifier 126 | 127 | ## Performance Characteristics 128 | 129 | ### Expected Performance (Melbourne Dataset) 130 | 131 | | Metric | Typical Value | 132 | |--------|---------------| 133 | | **Data Loading** | ~5-10ms for 6,361 points | 134 | | **Index Building** | ~15-30ms | 135 | | **Nearest Neighbor** | ~10-50μs per query | 136 | | **Range Queries** | ~20-100μs per query | 137 | | **Memory Usage** | ~0.5-1.0 MB | 138 | | **Throughput** | 20,000-100,000 queries/sec | 139 | 140 | *Performance varies based on hardware and data distribution* 141 | 142 | ### Scalability 143 | 144 | LSPH demonstrates excellent scalability characteristics: 145 | 146 | - **Sub-linear query time** with dataset size 147 | - **Efficient memory usage** (~100-200 bytes per point) 148 | - **Consistent performance** across different query patterns 149 | - **Fast index construction** (200,000+ points/second) 150 | 151 | ## Architecture 152 | 153 | ### Core Components 154 | 155 | ```rust 156 | // Main demo application 157 | struct LSPHDemo { 158 | spatial_map: LearnedHashMap, f64>, 159 | points: Vec, 160 | stats: PerformanceStats, 161 | } 162 | 163 | // Geographic point representation 164 | struct GeoPoint { 165 | latitude: f64, 166 | longitude: f64, 167 | zone: u32, 168 | } 169 | 170 | // Performance tracking 171 | struct PerformanceStats { 172 | data_loading_time: Duration, 173 | index_building_time: Duration, 174 | nearest_neighbor_times: Vec, 175 | range_query_times: Vec, 176 | memory_usage_estimate: usize, 177 | } 178 | ``` 179 | 180 | ### Key Features 181 | 182 | 1. **Error Handling**: Comprehensive error handling for file I/O, parsing, and user input 183 | 2. **Performance Monitoring**: Detailed timing and memory usage tracking 184 | 3. **User Experience**: Colorized output, progress indicators, and clear formatting 185 | 4. **Flexibility**: Configurable parameters and multiple operation modes 186 | 5. **Documentation**: Extensive inline documentation and help text 187 | 188 | ## Example Output 189 | 190 | ``` 191 | 🗺️ LSPH Geographic Data Demo 192 | Learned Spatial HashMap Performance Demonstration 193 | ============================================================ 194 | 195 | 🗺️ Loading Melbourne Geographic Data 196 | ================================================== 197 | ✅ Loaded 6361 points in 8.45ms 198 | 199 | 📊 Zone Distribution: 200 | Zone 31: 6361 points (100.0%) 201 | 202 | 🏗️ Building Spatial Index 203 | ================================================== 204 | ✅ Built spatial index in 23.12ms 205 | 📈 Successful insertions: 6361 206 | 💾 Estimated memory usage: 0.89 MB 207 | 208 | 🎯 Nearest Neighbor Search Demo 209 | ================================================== 210 | 🔍 Query 1: (-37.85234, 144.92156) → Nearest: (-37.85240, 144.92160) | Distance: 8.45m | Time: 15.23μs 211 | 🔍 Query 2: (-37.78945, 145.02341) → Nearest: (-37.78950, 145.02345) | Distance: 6.12m | Time: 12.67μs 212 | ... 213 | 214 | 📊 Nearest Neighbor Results: 215 | Successful queries: 100/100 216 | Average query time: 18.45μs 217 | Queries per second: 54,200 218 | 219 | 📈 Performance Summary 220 | ================================================== 221 | 🗂️ Data Processing: 222 | Total points processed: 6361 223 | Data loading time: 8.45ms 224 | Index building time: 23.12ms 225 | Points per second (indexing): 275,234 226 | 227 | 🎯 Nearest Neighbor Performance: 228 | Average time: 18.45μs 229 | Min time: 8.23μs 230 | Max time: 45.67μs 231 | Queries per second: 54,200 232 | 233 | 💾 Memory Usage: 234 | Estimated total: 0.89 MB 235 | Per point: 143.2 bytes 236 | 237 | 🎉 Demo completed successfully! 238 | Thank you for exploring LSPH capabilities. 239 | ``` 240 | 241 | ## Troubleshooting 242 | 243 | ### Common Issues 244 | 245 | 1. **File Not Found** 246 | ``` 247 | Error: No such file or directory (os error 2) 248 | ``` 249 | - Ensure `melbourne.csv` is in the demo directory 250 | - Use `--data` flag to specify custom file path 251 | 252 | 2. **Parse Errors** 253 | ``` 254 | ⚠️ Error parsing line 42: invalid float literal 255 | ``` 256 | - Check CSV format (no headers, three numeric columns) 257 | - Verify decimal separator (use `.` not `,`) 258 | 259 | 3. **Memory Issues** 260 | - For very large datasets, monitor system memory 261 | - Consider reducing query count with `--queries` flag 262 | 263 | ### Performance Tips 264 | 265 | 1. **Use Release Mode**: Always run with `--release` for accurate performance 266 | 2. **Warm-up Queries**: First few queries may be slower due to CPU caching 267 | 3. **Dataset Size**: Performance scales well, but very large datasets (>1M points) may require more memory 268 | 4. **Query Patterns**: Random queries provide good average-case performance 269 | 270 | ## Contributing 271 | 272 | To extend or modify the demo: 273 | 274 | 1. **Add New Metrics**: Extend `PerformanceStats` struct 275 | 2. **Custom Data Sources**: Modify `load_data()` method 276 | 3. **Additional Query Types**: Add new demo methods 277 | 4. **Visualization**: Consider adding graphical output 278 | 279 | ## License 280 | 281 | This demo is part of the LSPH project and is licensed under MIT OR Apache-2.0. 282 | 283 | ## Related 284 | 285 | - [LSPH Main Documentation](../../README.md) 286 | - [Interactive GUI Demo](../interactive_demo/README.md) 287 | - [Performance Benchmarks](../../benches/) 288 | - [LSPH Paper and Research](../../docs/) -------------------------------------------------------------------------------- /src/models/linear.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | error::*, 3 | models::{stats::root_mean_squared_error, Model}, 4 | }; 5 | 6 | use core::fmt::Debug; 7 | use core::iter::Sum; 8 | use num_traits::{cast::FromPrimitive, float::Float}; 9 | 10 | /// Simple linear regression from tuples. 11 | /// 12 | /// Calculates the simple linear regression from array of tuples, and their means. 13 | /// 14 | /// # Arguments 15 | /// 16 | /// * `xys` - An array of tuples of training data that contains Xs and Ys. 17 | /// 18 | /// * `x_mean` - The mean of Xs training data. 19 | /// 20 | /// * `y_mean` - The mean of Ys target values. 21 | /// 22 | /// Returns `Ok(slope, intercept)` or Err(Error). 23 | /// 24 | /// # Errors 25 | /// 26 | /// Returns an error if 27 | /// 28 | /// * `xs` and `ys` differ in length 29 | /// * `xs` or `ys` are empty 30 | /// * the slope is too steep to represent, approaching infinity 31 | /// * the number of elements cannot be represented as an `F` 32 | fn slr(xys: I, x_mean: F, y_mean: F) -> Result<(F, F), Error> 33 | where 34 | I: Iterator, 35 | F: Float + Debug, 36 | { 37 | // compute the covariance of x and y as well as the variance of x 38 | let (sq_diff_sum, cov_diff_sum) = xys.fold((F::zero(), F::zero()), |(v, c), (x, y)| { 39 | let diff = x - x_mean; 40 | let sq_diff = diff * diff; 41 | let cov_diff = diff * (y - y_mean); 42 | (v + sq_diff, c + cov_diff) 43 | }); 44 | let slope = cov_diff_sum / sq_diff_sum; 45 | if slope.is_nan() { 46 | return Err(Error::SteepSlope); 47 | } 48 | let intercept = y_mean - slope * x_mean; 49 | Ok((slope, intercept)) 50 | } 51 | 52 | /// Two-pass simple linear regression from slices. 53 | /// 54 | /// Calculates the linear regression from two slices, one for x- and one for y-values, by 55 | /// calculating the mean and then calling `lin_reg`. 56 | /// 57 | /// # Arguments 58 | /// 59 | /// * `xs` - An array of tuples of training data. 60 | /// 61 | /// * `ys` - An array of tuples of targeting data. 62 | /// 63 | /// Returns `Ok(slope, intercept)` of the regression line. 64 | /// 65 | /// # Errors 66 | /// 67 | /// Returns an error if 68 | /// 69 | /// * `xs` and `ys` differ in length 70 | /// * `xs` or `ys` are empty 71 | /// * the slope is too steep to represent, approaching infinity 72 | /// * the number of elements cannot be represented as an `F` 73 | fn linear_regression(xs: &[X], ys: &[Y]) -> Result<(F, F), Error> 74 | where 75 | X: Clone + Into, 76 | Y: Clone + Into, 77 | F: Float + Sum + Debug, 78 | { 79 | assert_empty!(xs); 80 | assert_empty!(ys); 81 | assert_eq_len!(xs, ys); 82 | 83 | let n = F::from(xs.len()).ok_or(Error::EmptyVal)?; 84 | 85 | // compute the mean of x and y 86 | let x_sum: F = xs.iter().cloned().map(Into::into).sum(); 87 | let x_mean = x_sum / n; 88 | let y_sum: F = ys.iter().cloned().map(Into::into).sum(); 89 | let y_mean = y_sum / n; 90 | 91 | let data = xs 92 | .iter() 93 | .zip(ys.iter()) 94 | .map(|(x, y)| (x.clone().into(), y.clone().into())); 95 | 96 | slr(data, x_mean, y_mean) 97 | } 98 | 99 | /// Two-pass linear regression from tuples. 100 | /// 101 | /// Calculates the linear regression from a slice of tuple values by first calculating the mean 102 | /// before calling `lin_reg`. 103 | /// 104 | /// Returns `Ok(slope, intercept)` of the regression line. 105 | /// 106 | /// # Errors 107 | /// 108 | /// Returns an error if 109 | /// 110 | /// * `xys` is empty 111 | /// * the slope is too steep to represent, approaching infinity 112 | /// * the number of elements cannot be represented as an `F` 113 | fn linear_regression_tuple(xys: &[(X, Y)]) -> Result<(F, F), Error> 114 | where 115 | X: Clone + Into, 116 | Y: Clone + Into, 117 | F: Float + Debug, 118 | { 119 | assert_empty!(xys); 120 | 121 | // We're handrolling the mean computation here, because our generic implementation can't handle tuples. 122 | // If we ran the generic impl on each tuple field, that would be very cache inefficient 123 | let n = F::from(xys.len()).ok_or(Error::EmptyVal)?; 124 | let (x_sum, y_sum) = xys 125 | .iter() 126 | .cloned() 127 | .fold((F::zero(), F::zero()), |(sx, sy), (x, y)| { 128 | (sx + x.into(), sy + y.into()) 129 | }); 130 | let x_mean = x_sum / n; 131 | let y_mean = y_sum / n; 132 | 133 | slr( 134 | xys.iter() 135 | .map(|(x, y)| (x.clone().into(), y.clone().into())), 136 | x_mean, 137 | y_mean, 138 | ) 139 | } 140 | 141 | /// Linear regression model 142 | #[derive(Copy, Clone, Debug, Default)] 143 | pub struct LinearModel { 144 | pub coefficient: F, 145 | pub intercept: F, 146 | } 147 | 148 | impl LinearModel 149 | where 150 | F: Float, 151 | { 152 | pub fn new() -> LinearModel { 153 | LinearModel { 154 | coefficient: F::zero(), 155 | intercept: F::zero(), 156 | } 157 | } 158 | } 159 | 160 | impl Model for LinearModel 161 | where 162 | F: Float + FromPrimitive + Sum + Debug + Sized, 163 | { 164 | type F = F; 165 | 166 | fn name(&self) -> String { 167 | String::from("linear") 168 | } 169 | 170 | fn fit(&mut self, xs: &[F], ys: &[F]) -> Result<(), Error> { 171 | let (coefficient, intercept): (F, F) = linear_regression(xs, ys).unwrap(); 172 | self.coefficient = coefficient; 173 | self.intercept = intercept; 174 | Ok(()) 175 | } 176 | fn fit_tuple(&mut self, xys: &[(F, F)]) -> Result<(), Error> { 177 | let (coefficient, intercept): (F, F) = linear_regression_tuple(xys).unwrap(); 178 | self.coefficient = coefficient; 179 | self.intercept = intercept; 180 | Ok(()) 181 | } 182 | 183 | fn predict(&self, x: F) -> F { 184 | x * self.coefficient + self.intercept 185 | } 186 | fn batch_predict(&self, xs: &[F]) -> Vec { 187 | (0..xs.len()).map(|i| self.predict(xs[i])).collect() 188 | } 189 | 190 | fn evaluate(&self, x_test: &[F], y_test: &[F]) -> F { 191 | let y_predicted = self.batch_predict(x_test); 192 | root_mean_squared_error(y_test, &y_predicted) 193 | } 194 | 195 | fn unpredict(&self, y: F) -> F { 196 | (y - self.intercept) / self.coefficient 197 | } 198 | } 199 | 200 | #[cfg(test)] 201 | mod tests { 202 | use super::*; 203 | 204 | #[test] 205 | #[should_panic] 206 | fn should_panic_for_empty_vecs() { 207 | let x_values: Vec = vec![]; 208 | let y_values: Vec = vec![]; 209 | let mut model = LinearModel::new(); 210 | model.fit(&x_values, &y_values).unwrap(); 211 | 212 | assert_delta!(0.8f64, model.coefficient, 0.00001); 213 | assert_delta!(0.4, model.intercept, 0.00001); 214 | } 215 | 216 | #[test] 217 | fn fit_coefficients() { 218 | let x_values: Vec = vec![1., 2., 3., 4., 5.]; 219 | let y_values: Vec = vec![1., 3., 2., 3., 5.]; 220 | let mut model = LinearModel::new(); 221 | model.fit(&x_values, &y_values).unwrap(); 222 | 223 | assert_delta!(0.8f64, model.coefficient, 0.00001); 224 | assert_delta!(0.4f64, model.intercept, 0.00001); 225 | } 226 | 227 | #[test] 228 | fn fit_coefficients_f32() { 229 | let x_values: Vec = vec![1., 2., 3., 4., 5.]; 230 | let y_values: Vec = vec![1., 3., 2., 3., 5.]; 231 | let mut model = LinearModel::new(); 232 | model.fit(&x_values, &y_values).unwrap(); 233 | 234 | assert_delta!(0.8, model.coefficient, 0.00001); 235 | assert_delta!(0.4, model.intercept, 0.00001); 236 | } 237 | 238 | #[test] 239 | fn predict() { 240 | let x_values = vec![1f64, 2f64, 3f64, 4f64, 5f64]; 241 | let y_values = vec![1f64, 3f64, 2f64, 3f64, 5f64]; 242 | let mut model = LinearModel::new(); 243 | model.fit(&x_values, &y_values).unwrap(); 244 | 245 | assert_delta!(1.2f64, model.predict(1f64), 0.00001); 246 | assert_delta!(2f64, model.predict(2f64), 0.00001); 247 | assert_delta!(2.8f64, model.predict(3f64), 0.00001); 248 | assert_delta!(3.6f64, model.predict(4f64), 0.00001); 249 | assert_delta!(4.4f64, model.predict(5f64), 0.00001); 250 | } 251 | 252 | #[test] 253 | fn predict_list() { 254 | let x_values = vec![1f64, 2f64, 3f64, 4f64, 5f64]; 255 | let y_values = vec![1f64, 3f64, 2f64, 3f64, 5f64]; 256 | let mut model = LinearModel::new(); 257 | model.fit(&x_values, &y_values).unwrap(); 258 | 259 | let predictions = model.batch_predict(&x_values); 260 | 261 | assert_delta!(1.2f64, predictions[0], 0.00001); 262 | assert_delta!(2f64, predictions[1], 0.00001); 263 | assert_delta!(2.8f64, predictions[2], 0.00001); 264 | assert_delta!(3.6f64, predictions[3], 0.00001); 265 | assert_delta!(4.4f64, predictions[4], 0.00001); 266 | } 267 | 268 | #[test] 269 | fn evaluate() { 270 | let x_values = vec![1f64, 2f64, 3f64, 4f64, 5f64]; 271 | let y_values = vec![1f64, 3f64, 2f64, 3f64, 5f64]; 272 | let mut model = LinearModel::new(); 273 | model.fit(&x_values.clone(), &y_values.clone()).unwrap(); 274 | 275 | let error = model.evaluate(&x_values, &y_values); 276 | assert_delta!(0.69282f64, error, 0.00001); 277 | } 278 | } 279 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 Haozhan Shi 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /examples/demo/src/main.rs: -------------------------------------------------------------------------------- 1 | //! # LSPH Demo Application 2 | //! 3 | //! This demo showcases the capabilities of the Learned Spatial HashMap (LSPH) 4 | //! using real-world geographic data from Melbourne, Australia. 5 | //! 6 | //! Features demonstrated: 7 | //! - Loading and processing geographic CSV data 8 | //! - Spatial indexing with LSPH 9 | //! - Nearest neighbor searches 10 | //! - Range queries 11 | //! - Performance benchmarking 12 | //! - Interactive command-line interface 13 | 14 | use clap::{Arg, Command}; 15 | use colored::*; 16 | use csv::ReaderBuilder; 17 | use lsph::{ 18 | geometry::Point, 19 | map::LearnedHashMap, 20 | models::LinearModel, 21 | }; 22 | use rand::Rng; 23 | use serde::Deserialize; 24 | use std::{ 25 | collections::HashMap, 26 | error::Error, 27 | fs::File, 28 | time::{Duration, Instant}, 29 | }; 30 | 31 | /// Represents a geographic point from the Melbourne dataset 32 | #[derive(Debug, Deserialize, Clone)] 33 | struct GeoPoint { 34 | #[serde(rename = "0")] 35 | latitude: f64, 36 | #[serde(rename = "1")] 37 | longitude: f64, 38 | #[serde(rename = "2")] 39 | zone: u32, 40 | } 41 | 42 | /// Statistics for performance analysis 43 | #[derive(Debug, Default)] 44 | struct PerformanceStats { 45 | data_loading_time: Duration, 46 | index_building_time: Duration, 47 | total_points: usize, 48 | nearest_neighbor_times: Vec, 49 | range_query_times: Vec, 50 | memory_usage_estimate: usize, 51 | } 52 | 53 | /// Main demo application 54 | struct LSPHDemo { 55 | spatial_map: LearnedHashMap, f64>, 56 | points: Vec, 57 | stats: PerformanceStats, 58 | } 59 | 60 | impl LSPHDemo { 61 | /// Create a new demo instance 62 | fn new() -> Self { 63 | Self { 64 | spatial_map: LearnedHashMap::new(), 65 | points: Vec::new(), 66 | stats: PerformanceStats::default(), 67 | } 68 | } 69 | 70 | /// Load geographic data from CSV file 71 | fn load_data(&mut self, file_path: &str) -> Result<(), Box> { 72 | println!( 73 | "{}\n{}", 74 | "🗺️ Loading Melbourne Geographic Data".bright_blue().bold(), 75 | "=".repeat(50).bright_blue() 76 | ); 77 | 78 | let start_time = Instant::now(); 79 | let file = File::open(file_path)?; 80 | let mut reader = ReaderBuilder::new() 81 | .has_headers(false) 82 | .from_reader(file); 83 | 84 | let mut loaded_points = Vec::new(); 85 | let mut zone_counts: HashMap = HashMap::new(); 86 | 87 | for (index, result) in reader.deserialize().enumerate() { 88 | match result { 89 | Ok(point) => { 90 | let geo_point: GeoPoint = point; 91 | *zone_counts.entry(geo_point.zone).or_insert(0) += 1; 92 | loaded_points.push(geo_point); 93 | } 94 | Err(e) => { 95 | eprintln!( 96 | "{} Error parsing line {}: {}", 97 | "⚠️".yellow(), 98 | index + 1, 99 | e 100 | ); 101 | } 102 | } 103 | } 104 | 105 | self.points = loaded_points; 106 | self.stats.data_loading_time = start_time.elapsed(); 107 | self.stats.total_points = self.points.len(); 108 | 109 | println!( 110 | "✅ Loaded {} points in {:.2}ms", 111 | self.stats.total_points.to_string().bright_green(), 112 | self.stats.data_loading_time.as_secs_f64() * 1000.0 113 | ); 114 | 115 | // Display zone distribution 116 | println!("\n📊 Zone Distribution:"); 117 | let mut zones: Vec<_> = zone_counts.into_iter().collect(); 118 | zones.sort_by_key(|&(zone, _)| zone); 119 | for (zone, count) in zones.iter().take(10) { 120 | let percentage = (*count as f64 / self.stats.total_points as f64) * 100.0; 121 | println!( 122 | " Zone {}: {} points ({:.1}%)", 123 | zone.to_string().cyan(), 124 | count.to_string().bright_white(), 125 | percentage 126 | ); 127 | } 128 | if zones.len() > 10 { 129 | println!(" ... and {} more zones", zones.len() - 10); 130 | } 131 | 132 | Ok(()) 133 | } 134 | 135 | /// Build the spatial index 136 | fn build_index(&mut self) -> Result<(), Box> { 137 | println!( 138 | "\n{}\n{}", 139 | "🏗️ Building Spatial Index".bright_blue().bold(), 140 | "=".repeat(50).bright_blue() 141 | ); 142 | 143 | let start_time = Instant::now(); 144 | let mut successful_insertions = 0; 145 | 146 | for geo_point in &self.points { 147 | let point = Point::new(geo_point.latitude, geo_point.longitude); 148 | 149 | match self.spatial_map.insert(point) { 150 | Some(_existing) => { 151 | // Point already existed, this is fine 152 | successful_insertions += 1; 153 | } 154 | None => { 155 | // New point inserted successfully 156 | successful_insertions += 1; 157 | } 158 | } 159 | } 160 | 161 | self.stats.index_building_time = start_time.elapsed(); 162 | self.stats.memory_usage_estimate = self.estimate_memory_usage(); 163 | 164 | println!( 165 | "✅ Built spatial index in {:.2}ms", 166 | self.stats.index_building_time.as_secs_f64() * 1000.0 167 | ); 168 | println!( 169 | "📈 Successful insertions: {}", 170 | successful_insertions.to_string().bright_green() 171 | ); 172 | println!( 173 | "💾 Estimated memory usage: {:.2} MB", 174 | self.stats.memory_usage_estimate as f64 / 1_048_576.0 175 | ); 176 | 177 | Ok(()) 178 | } 179 | 180 | /// Estimate memory usage of the spatial map 181 | fn estimate_memory_usage(&self) -> usize { 182 | // Rough estimation based on point count and structure overhead 183 | let point_size = std::mem::size_of::>(); 184 | let base_overhead = 1024; // Base structure overhead 185 | let per_point_overhead = 64; // Hash table and indexing overhead per point 186 | 187 | base_overhead + (self.stats.total_points * (point_size + per_point_overhead)) 188 | } 189 | 190 | /// Perform nearest neighbor search demonstrations 191 | fn demo_nearest_neighbor(&mut self, num_queries: usize) { 192 | println!( 193 | "\n{}\n{}", 194 | "🎯 Nearest Neighbor Search Demo".bright_blue().bold(), 195 | "=".repeat(50).bright_blue() 196 | ); 197 | 198 | let mut rng = rand::rng(); 199 | let mut successful_queries = 0; 200 | 201 | for i in 0..num_queries { 202 | // Generate random query point within Melbourne bounds 203 | let query_lat = rng.random_range(-37.9..=-37.7); 204 | let query_lng = rng.random_range(144.8..=145.1); 205 | let query_point = [query_lat, query_lng]; 206 | 207 | let start_time = Instant::now(); 208 | let result = self.spatial_map.nearest_neighbor(&query_point); 209 | let query_time = start_time.elapsed(); 210 | 211 | self.stats.nearest_neighbor_times.push(query_time); 212 | 213 | match result { 214 | Some(nearest) => { 215 | successful_queries += 1; 216 | if i < 5 { 217 | // Show details for first few queries 218 | let distance = self.calculate_distance( 219 | query_lat, 220 | query_lng, 221 | nearest.x(), 222 | nearest.y(), 223 | ); 224 | println!( 225 | "🔍 Query {}: ({:.5}, {:.5}) → Nearest: ({:.5}, {:.5}) | Distance: {:.2}m | Time: {:.2}μs", 226 | (i + 1).to_string().cyan(), 227 | query_lat, 228 | query_lng, 229 | nearest.x(), 230 | nearest.y(), 231 | distance, 232 | query_time.as_nanos() as f64 / 1000.0 233 | ); 234 | } 235 | } 236 | None => { 237 | if i < 5 { 238 | println!( 239 | "❌ Query {}: ({:.5}, {:.5}) → No result found", 240 | (i + 1).to_string().red(), 241 | query_lat, 242 | query_lng 243 | ); 244 | } 245 | } 246 | } 247 | } 248 | 249 | let avg_time = self.stats.nearest_neighbor_times.iter().sum::().as_nanos() 250 | / self.stats.nearest_neighbor_times.len() as u128; 251 | 252 | println!( 253 | "\n📊 Nearest Neighbor Results:" 254 | ); 255 | println!( 256 | " Successful queries: {}/{}", 257 | successful_queries.to_string().bright_green(), 258 | num_queries 259 | ); 260 | println!( 261 | " Average query time: {:.2}μs", 262 | avg_time as f64 / 1000.0 263 | ); 264 | println!( 265 | " Queries per second: {:.0}", 266 | 1_000_000.0 / (avg_time as f64 / 1000.0) 267 | ); 268 | } 269 | 270 | /// Perform range query demonstrations 271 | fn demo_range_queries(&mut self, num_queries: usize) { 272 | println!( 273 | "\n{}\n{}", 274 | "🌐 Range Query Demo".bright_blue().bold(), 275 | "=".repeat(50).bright_blue() 276 | ); 277 | 278 | let mut rng = rand::rng(); 279 | let radii = [0.001, 0.005, 0.01, 0.02]; // Different search radii in degrees 280 | 281 | for &radius in &radii { 282 | println!( 283 | "\n🔍 Testing radius: {:.3}° (~{:.0}m)", 284 | radius, 285 | radius * 111_000.0 // Rough conversion to meters 286 | ); 287 | 288 | let mut total_results = 0; 289 | let mut query_times = Vec::new(); 290 | 291 | for i in 0..num_queries { 292 | let query_lat = rng.random_range(-37.9..=-37.7); 293 | let query_lng = rng.random_range(144.8..=145.1); 294 | let query_point = [query_lat, query_lng]; 295 | 296 | let start_time = Instant::now(); 297 | let results = self.spatial_map.radius_range(&query_point, radius); 298 | let query_time = start_time.elapsed(); 299 | 300 | query_times.push(query_time); 301 | 302 | match results { 303 | Some(points) => { 304 | total_results += points.len(); 305 | if i == 0 { 306 | // Show details for first query 307 | println!( 308 | " Sample query: ({:.5}, {:.5}) → {} points found in {:.2}μs", 309 | query_lat, 310 | query_lng, 311 | points.len().to_string().bright_green(), 312 | query_time.as_nanos() as f64 / 1000.0 313 | ); 314 | } 315 | } 316 | None => { 317 | if i == 0 { 318 | println!( 319 | " Sample query: ({:.5}, {:.5}) → No results", 320 | query_lat, query_lng 321 | ); 322 | } 323 | } 324 | } 325 | } 326 | 327 | let avg_time = query_times.iter().sum::().as_nanos() / query_times.len() as u128; 328 | let avg_results = total_results as f64 / num_queries as f64; 329 | 330 | println!( 331 | " Average results per query: {:.1}", 332 | avg_results 333 | ); 334 | println!( 335 | " Average query time: {:.2}μs", 336 | avg_time as f64 / 1000.0 337 | ); 338 | 339 | self.stats.range_query_times.extend(query_times); 340 | } 341 | } 342 | 343 | /// Calculate approximate distance between two geographic points 344 | fn calculate_distance(&self, lat1: f64, lng1: f64, lat2: f64, lng2: f64) -> f64 { 345 | let dlat = (lat2 - lat1).to_radians(); 346 | let dlng = (lng2 - lng1).to_radians(); 347 | let a = (dlat / 2.0).sin().powi(2) 348 | + lat1.to_radians().cos() * lat2.to_radians().cos() * (dlng / 2.0).sin().powi(2); 349 | let c = 2.0 * a.sqrt().atan2((1.0 - a).sqrt()); 350 | 6371000.0 * c // Earth radius in meters 351 | } 352 | 353 | /// Display comprehensive performance summary 354 | fn display_performance_summary(&self) { 355 | println!( 356 | "\n{}\n{}", 357 | "📈 Performance Summary".bright_blue().bold(), 358 | "=".repeat(50).bright_blue() 359 | ); 360 | 361 | println!("🗂️ Data Processing:"); 362 | println!( 363 | " Total points processed: {}", 364 | self.stats.total_points.to_string().bright_green() 365 | ); 366 | println!( 367 | " Data loading time: {:.2}ms", 368 | self.stats.data_loading_time.as_secs_f64() * 1000.0 369 | ); 370 | println!( 371 | " Index building time: {:.2}ms", 372 | self.stats.index_building_time.as_secs_f64() * 1000.0 373 | ); 374 | println!( 375 | " Points per second (indexing): {:.0}", 376 | self.stats.total_points as f64 / self.stats.index_building_time.as_secs_f64() 377 | ); 378 | 379 | if !self.stats.nearest_neighbor_times.is_empty() { 380 | let nn_avg = self.stats.nearest_neighbor_times.iter().sum::().as_nanos() 381 | / self.stats.nearest_neighbor_times.len() as u128; 382 | let nn_min = self.stats.nearest_neighbor_times.iter().min().unwrap().as_nanos(); 383 | let nn_max = self.stats.nearest_neighbor_times.iter().max().unwrap().as_nanos(); 384 | 385 | println!("\n🎯 Nearest Neighbor Performance:"); 386 | println!( 387 | " Average time: {:.2}μs", 388 | nn_avg as f64 / 1000.0 389 | ); 390 | println!( 391 | " Min time: {:.2}μs", 392 | nn_min as f64 / 1000.0 393 | ); 394 | println!( 395 | " Max time: {:.2}μs", 396 | nn_max as f64 / 1000.0 397 | ); 398 | println!( 399 | " Queries per second: {:.0}", 400 | 1_000_000.0 / (nn_avg as f64 / 1000.0) 401 | ); 402 | } 403 | 404 | if !self.stats.range_query_times.is_empty() { 405 | let rq_avg = self.stats.range_query_times.iter().sum::().as_nanos() 406 | / self.stats.range_query_times.len() as u128; 407 | 408 | println!("\n🌐 Range Query Performance:"); 409 | println!( 410 | " Average time: {:.2}μs", 411 | rq_avg as f64 / 1000.0 412 | ); 413 | println!( 414 | " Queries per second: {:.0}", 415 | 1_000_000.0 / (rq_avg as f64 / 1000.0) 416 | ); 417 | } 418 | 419 | println!("\n💾 Memory Usage:"); 420 | println!( 421 | " Estimated total: {:.2} MB", 422 | self.stats.memory_usage_estimate as f64 / 1_048_576.0 423 | ); 424 | println!( 425 | " Per point: {:.1} bytes", 426 | self.stats.memory_usage_estimate as f64 / self.stats.total_points as f64 427 | ); 428 | } 429 | 430 | /// Run interactive mode 431 | fn run_interactive(&mut self) { 432 | println!( 433 | "\n{}\n{}", 434 | "🎮 Interactive Mode".bright_blue().bold(), 435 | "=".repeat(50).bright_blue() 436 | ); 437 | println!("Enter coordinates to find nearest neighbors (format: lat,lng) or 'quit' to exit:"); 438 | 439 | loop { 440 | print!("🔍 Query: "); 441 | use std::io::{self, Write}; 442 | io::stdout().flush().unwrap(); 443 | 444 | let mut input = String::new(); 445 | match io::stdin().read_line(&mut input) { 446 | Ok(_) => { 447 | let input = input.trim(); 448 | if input.eq_ignore_ascii_case("quit") || input.eq_ignore_ascii_case("exit") { 449 | break; 450 | } 451 | 452 | let coords: Vec<&str> = input.split(',').collect(); 453 | if coords.len() != 2 { 454 | println!("❌ Invalid format. Use: lat,lng (e.g., -37.8136,144.9631)"); 455 | continue; 456 | } 457 | 458 | match (coords[0].trim().parse::(), coords[1].trim().parse::()) { 459 | (Ok(lat), Ok(lng)) => { 460 | let query_point = [lat, lng]; 461 | let start_time = Instant::now(); 462 | 463 | match self.spatial_map.nearest_neighbor(&query_point) { 464 | Some(nearest) => { 465 | let query_time = start_time.elapsed(); 466 | let distance = self.calculate_distance(lat, lng, nearest.x(), nearest.y()); 467 | 468 | println!( 469 | "✅ Nearest point: ({:.5}, {:.5})", 470 | nearest.x(), nearest.y() 471 | ); 472 | println!( 473 | "📏 Distance: {:.2}m | ⏱️ Query time: {:.2}μs", 474 | distance, 475 | query_time.as_nanos() as f64 / 1000.0 476 | ); 477 | } 478 | None => { 479 | println!("❌ No nearest neighbor found"); 480 | } 481 | } 482 | } 483 | _ => { 484 | println!("❌ Invalid coordinates. Use decimal format (e.g., -37.8136,144.9631)"); 485 | } 486 | } 487 | } 488 | Err(e) => { 489 | eprintln!("Error reading input: {}", e); 490 | break; 491 | } 492 | } 493 | } 494 | 495 | println!("👋 Goodbye!"); 496 | } 497 | } 498 | 499 | fn main() -> Result<(), Box> { 500 | let matches = Command::new("LSPH Demo") 501 | .version("0.1.0") 502 | .author("LSPH Contributors") 503 | .about("Demonstrates LSPH capabilities with Melbourne geographic data") 504 | .arg( 505 | Arg::new("data") 506 | .short('d') 507 | .long("data") 508 | .value_name("FILE") 509 | .help("Path to the CSV data file") 510 | .default_value("melbourne.csv") 511 | ) 512 | .arg( 513 | Arg::new("queries") 514 | .short('q') 515 | .long("queries") 516 | .value_name("NUMBER") 517 | .help("Number of test queries to perform") 518 | .default_value("100") 519 | ) 520 | .arg( 521 | Arg::new("interactive") 522 | .short('i') 523 | .long("interactive") 524 | .help("Run in interactive mode") 525 | .action(clap::ArgAction::SetTrue) 526 | ) 527 | .arg( 528 | Arg::new("skip-demo") 529 | .long("skip-demo") 530 | .help("Skip automated demo and go straight to interactive mode") 531 | .action(clap::ArgAction::SetTrue) 532 | ) 533 | .get_matches(); 534 | 535 | let data_file = matches.get_one::("data").unwrap(); 536 | let num_queries: usize = matches.get_one::("queries").unwrap().parse()?; 537 | let interactive_mode = matches.get_flag("interactive"); 538 | let skip_demo = matches.get_flag("skip-demo"); 539 | 540 | println!( 541 | "{}\n{}\n{}", 542 | "🗺️ LSPH Geographic Data Demo".bright_blue().bold(), 543 | "Learned Spatial HashMap Performance Demonstration".bright_white(), 544 | "=".repeat(60).bright_blue() 545 | ); 546 | 547 | let mut demo = LSPHDemo::new(); 548 | 549 | // Load data 550 | demo.load_data(data_file)?; 551 | 552 | // Build spatial index 553 | demo.build_index()?; 554 | 555 | if !skip_demo { 556 | // Run performance demonstrations 557 | demo.demo_nearest_neighbor(num_queries); 558 | demo.demo_range_queries(num_queries / 4); // Fewer range queries as they're more expensive 559 | 560 | // Display comprehensive summary 561 | demo.display_performance_summary(); 562 | } 563 | 564 | // Run interactive mode if requested 565 | if interactive_mode || skip_demo { 566 | demo.run_interactive(); 567 | } 568 | 569 | println!( 570 | "\n{}\n{}", 571 | "🎉 Demo completed successfully!".bright_green().bold(), 572 | "Thank you for exploring LSPH capabilities.".bright_white() 573 | ); 574 | 575 | Ok(()) 576 | } 577 | -------------------------------------------------------------------------------- /examples/interactive_demo/src/main.rs: -------------------------------------------------------------------------------- 1 | use eframe::egui; 2 | use lsph::{ 3 | geometry::Point, 4 | map::LearnedHashMap, 5 | models::LinearModel, 6 | }; 7 | use rand::Rng; 8 | use std::collections::HashMap; 9 | 10 | #[derive(Default)] 11 | struct LSPHDemo { 12 | // Core LSPH data structure 13 | spatial_map: LearnedHashMap, f64>, 14 | 15 | // UI state 16 | points: Vec>, 17 | point_colors: HashMap, 18 | 19 | // Input fields 20 | input_x: String, 21 | input_y: String, 22 | input_value: String, 23 | 24 | // Search parameters 25 | search_x: String, 26 | search_y: String, 27 | search_radius: f32, 28 | 29 | // Visualization settings 30 | point_size: f32, 31 | show_grid: bool, 32 | 33 | // Demo modes 34 | demo_mode: DemoMode, 35 | auto_generate: bool, 36 | generation_speed: f32, 37 | 38 | // Search results 39 | nearest_neighbor: Option>, 40 | range_results: Vec>, 41 | 42 | // Statistics 43 | total_points: usize, 44 | last_search_time: Option, 45 | } 46 | 47 | #[derive(Default, PartialEq)] 48 | enum DemoMode { 49 | #[default] 50 | Manual, 51 | RandomGeneration, 52 | NearestNeighbor, 53 | RangeQuery, 54 | } 55 | 56 | impl LSPHDemo { 57 | fn new(_cc: &eframe::CreationContext<'_>) -> Self { 58 | Self { 59 | spatial_map: LearnedHashMap::new(), 60 | point_size: 4.0, 61 | show_grid: true, 62 | search_radius: 50.0, 63 | generation_speed: 10.0, 64 | input_x: "0.5".to_string(), 65 | input_y: "0.5".to_string(), 66 | input_value: "1.0".to_string(), 67 | search_x: "0.5".to_string(), 68 | search_y: "0.5".to_string(), 69 | ..Default::default() 70 | } 71 | } 72 | 73 | fn add_point(&mut self, x: f64, y: f64, value: f64) { 74 | let point = Point::new(x, y); 75 | 76 | // Add to LSPH 77 | let _existing = self.spatial_map.insert(point); 78 | 79 | // Add to visualization 80 | let index = self.points.len(); 81 | self.points.push(point); 82 | 83 | // Assign a color based on value 84 | let color = self.value_to_color(value); 85 | self.point_colors.insert(index, color); 86 | 87 | self.total_points += 1; 88 | } 89 | 90 | fn value_to_color(&self, value: f64) -> egui::Color32 { 91 | let normalized = (value.abs() % 10.0) / 10.0; 92 | let hue = normalized * 360.0; 93 | egui::Color32::from_rgb( 94 | (hue.sin() * 127.0 + 128.0) as u8, 95 | ((hue + 120.0).to_radians().sin() * 127.0 + 128.0) as u8, 96 | ((hue + 240.0).to_radians().sin() * 127.0 + 128.0) as u8, 97 | ) 98 | } 99 | 100 | fn generate_random_points(&mut self, count: usize) { 101 | let mut rng = rand::rng(); 102 | for _ in 0..count { 103 | let x = rng.random_range(0.0..1.0); 104 | let y = rng.random_range(0.0..1.0); 105 | let value = rng.random_range(-10.0..10.0); 106 | self.add_point(x, y, value); 107 | } 108 | } 109 | 110 | fn find_nearest_neighbor(&mut self, x: f64, y: f64) { 111 | let query_point = [x, y]; 112 | let start = std::time::Instant::now(); 113 | 114 | match self.spatial_map.nearest_neighbor(&query_point) { 115 | Some(point) => { 116 | self.nearest_neighbor = Some(point); 117 | } 118 | None => { 119 | self.nearest_neighbor = None; 120 | } 121 | } 122 | 123 | self.last_search_time = Some(start.elapsed()); 124 | } 125 | 126 | fn range_query(&mut self, x: f64, y: f64, radius: f64) { 127 | let query_point = [x, y]; 128 | let start = std::time::Instant::now(); 129 | 130 | match self.spatial_map.radius_range(&query_point, radius) { 131 | Some(results) => { 132 | self.range_results = results; 133 | } 134 | None => { 135 | self.range_results.clear(); 136 | } 137 | } 138 | 139 | self.last_search_time = Some(start.elapsed()); 140 | } 141 | 142 | fn clear_all(&mut self) { 143 | self.spatial_map = LearnedHashMap::new(); 144 | self.points.clear(); 145 | self.point_colors.clear(); 146 | self.nearest_neighbor = None; 147 | self.range_results.clear(); 148 | self.total_points = 0; 149 | self.last_search_time = None; 150 | } 151 | 152 | fn canvas_to_world(&self, canvas_pos: egui::Pos2, canvas_rect: egui::Rect) -> (f64, f64) { 153 | let x = (canvas_pos.x - canvas_rect.min.x) / canvas_rect.width(); 154 | let y = 1.0 - (canvas_pos.y - canvas_rect.min.y) / canvas_rect.height(); 155 | (x as f64, y as f64) 156 | } 157 | 158 | fn world_to_canvas(&self, x: f64, y: f64, canvas_rect: egui::Rect) -> egui::Pos2 { 159 | let canvas_x = canvas_rect.min.x + (x as f32) * canvas_rect.width(); 160 | let canvas_y = canvas_rect.min.y + (1.0 - y as f32) * canvas_rect.height(); 161 | egui::Pos2::new(canvas_x, canvas_y) 162 | } 163 | } 164 | 165 | impl eframe::App for LSPHDemo { 166 | fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) { 167 | // Auto-generation in random mode 168 | if self.auto_generate && self.demo_mode == DemoMode::RandomGeneration { 169 | if ctx.input(|i| i.time) as f32 % (1.0 / self.generation_speed) < 0.016 { 170 | self.generate_random_points(1); 171 | } 172 | } 173 | 174 | egui::CentralPanel::default().show(ctx, |ui| { 175 | ui.heading("🗺️ LSPH Interactive Demo"); 176 | ui.separator(); 177 | 178 | ui.horizontal(|ui| { 179 | // Left panel - Controls 180 | ui.vertical(|ui| { 181 | // Responsive control panel width (25-35% of available width) 182 | let available_width = ui.available_size().x; 183 | let control_width = (available_width * 0.3).clamp(250.0, 400.0); 184 | ui.set_width(control_width); 185 | 186 | // Demo mode selection 187 | ui.group(|ui| { 188 | ui.label("Demo Mode:"); 189 | ui.radio_value(&mut self.demo_mode, DemoMode::Manual, "Manual Point Addition"); 190 | ui.radio_value(&mut self.demo_mode, DemoMode::RandomGeneration, "Random Generation"); 191 | ui.radio_value(&mut self.demo_mode, DemoMode::NearestNeighbor, "Nearest Neighbor Search"); 192 | ui.radio_value(&mut self.demo_mode, DemoMode::RangeQuery, "Range Query"); 193 | }); 194 | 195 | ui.separator(); 196 | 197 | match self.demo_mode { 198 | DemoMode::Manual => { 199 | ui.group(|ui| { 200 | ui.label("Add Point:"); 201 | ui.horizontal(|ui| { 202 | ui.label("X:"); 203 | ui.text_edit_singleline(&mut self.input_x); 204 | }); 205 | ui.horizontal(|ui| { 206 | ui.label("Y:"); 207 | ui.text_edit_singleline(&mut self.input_y); 208 | }); 209 | ui.horizontal(|ui| { 210 | ui.label("Value:"); 211 | ui.text_edit_singleline(&mut self.input_value); 212 | }); 213 | 214 | if ui.button("Add Point").clicked() { 215 | if let (Ok(x), Ok(y), Ok(value)) = ( 216 | self.input_x.parse::(), 217 | self.input_y.parse::(), 218 | self.input_value.parse::(), 219 | ) { 220 | if x >= 0.0 && x <= 1.0 && y >= 0.0 && y <= 1.0 { 221 | self.add_point(x, y, value); 222 | } 223 | } 224 | } 225 | }); 226 | } 227 | 228 | DemoMode::RandomGeneration => { 229 | ui.group(|ui| { 230 | ui.label("Random Generation:"); 231 | ui.checkbox(&mut self.auto_generate, "Auto Generate"); 232 | ui.horizontal(|ui| { 233 | ui.label("Speed:"); 234 | ui.add(egui::Slider::new(&mut self.generation_speed, 0.1..=20.0)); 235 | }); 236 | 237 | if ui.button("Generate 10 Points").clicked() { 238 | self.generate_random_points(10); 239 | } 240 | if ui.button("Generate 100 Points").clicked() { 241 | self.generate_random_points(100); 242 | } 243 | }); 244 | } 245 | 246 | DemoMode::NearestNeighbor => { 247 | ui.group(|ui| { 248 | ui.label("Nearest Neighbor Search:"); 249 | ui.horizontal(|ui| { 250 | ui.label("Query X:"); 251 | ui.text_edit_singleline(&mut self.search_x); 252 | }); 253 | ui.horizontal(|ui| { 254 | ui.label("Query Y:"); 255 | ui.text_edit_singleline(&mut self.search_y); 256 | }); 257 | 258 | if ui.button("Find Nearest").clicked() { 259 | if let (Ok(x), Ok(y)) = ( 260 | self.search_x.parse::(), 261 | self.search_y.parse::(), 262 | ) { 263 | if x >= 0.0 && x <= 1.0 && y >= 0.0 && y <= 1.0 { 264 | self.find_nearest_neighbor(x, y); 265 | } 266 | } 267 | } 268 | 269 | if let Some(nn) = &self.nearest_neighbor { 270 | ui.label(format!("Nearest: ({:.3}, {:.3})", nn.x(), nn.y())); 271 | } 272 | }); 273 | } 274 | 275 | DemoMode::RangeQuery => { 276 | ui.group(|ui| { 277 | ui.label("Range Query:"); 278 | ui.horizontal(|ui| { 279 | ui.label("Center X:"); 280 | ui.text_edit_singleline(&mut self.search_x); 281 | }); 282 | ui.horizontal(|ui| { 283 | ui.label("Center Y:"); 284 | ui.text_edit_singleline(&mut self.search_y); 285 | }); 286 | ui.horizontal(|ui| { 287 | ui.label("Radius:"); 288 | ui.add(egui::Slider::new(&mut self.search_radius, 0.01..=0.5)); 289 | }); 290 | 291 | if ui.button("Search Range").clicked() { 292 | if let (Ok(x), Ok(y)) = ( 293 | self.search_x.parse::(), 294 | self.search_y.parse::(), 295 | ) { 296 | if x >= 0.0 && x <= 1.0 && y >= 0.0 && y <= 1.0 { 297 | self.range_query(x, y, self.search_radius as f64); 298 | } 299 | } 300 | } 301 | 302 | ui.label(format!("Found: {} points", self.range_results.len())); 303 | }); 304 | } 305 | } 306 | 307 | ui.separator(); 308 | 309 | // Visualization settings 310 | ui.group(|ui| { 311 | ui.label("Visualization:"); 312 | ui.checkbox(&mut self.show_grid, "Show Grid"); 313 | ui.horizontal(|ui| { 314 | ui.label("Point Size:"); 315 | ui.add(egui::Slider::new(&mut self.point_size, 1.0..=10.0)); 316 | }); 317 | }); 318 | 319 | ui.separator(); 320 | 321 | // Statistics 322 | ui.group(|ui| { 323 | ui.label("Statistics:"); 324 | ui.label(format!("Total Points: {}", self.total_points)); 325 | if let Some(time) = self.last_search_time { 326 | ui.label(format!("Last Search: {:.2}ms", time.as_secs_f64() * 1000.0)); 327 | } 328 | }); 329 | 330 | ui.separator(); 331 | 332 | if ui.button("Clear All").clicked() { 333 | self.clear_all(); 334 | } 335 | }); 336 | 337 | ui.separator(); 338 | 339 | // Right panel - Canvas 340 | ui.vertical(|ui| { 341 | // Calculate responsive canvas size based on available space 342 | let available_size = ui.available_size(); 343 | let canvas_width = available_size.x - 20.0; // Leave some margin 344 | let canvas_height = available_size.y - 40.0; // Leave space for label 345 | 346 | // Maintain square aspect ratio for better spatial visualization 347 | let canvas_size = canvas_width.min(canvas_height).max(200.0); // Minimum 200px 348 | let canvas_vec = egui::Vec2::splat(canvas_size); 349 | 350 | let (response, painter) = ui.allocate_painter(canvas_vec, egui::Sense::click()); 351 | let canvas_rect = response.rect; 352 | 353 | // Handle canvas clicks 354 | if response.clicked() { 355 | if let Some(click_pos) = response.interact_pointer_pos() { 356 | let (world_x, world_y) = self.canvas_to_world(click_pos, canvas_rect); 357 | 358 | match self.demo_mode { 359 | DemoMode::Manual => { 360 | self.add_point(world_x, world_y, 1.0); 361 | } 362 | DemoMode::NearestNeighbor => { 363 | self.find_nearest_neighbor(world_x, world_y); 364 | self.search_x = format!("{:.3}", world_x); 365 | self.search_y = format!("{:.3}", world_y); 366 | } 367 | DemoMode::RangeQuery => { 368 | self.range_query(world_x, world_y, self.search_radius as f64); 369 | self.search_x = format!("{:.3}", world_x); 370 | self.search_y = format!("{:.3}", world_y); 371 | } 372 | _ => {} 373 | } 374 | } 375 | } 376 | 377 | // Draw background 378 | painter.rect_filled(canvas_rect, 0.0, egui::Color32::WHITE); 379 | painter.rect_stroke(canvas_rect, 0.0, egui::Stroke::new(1.0, egui::Color32::BLACK)); 380 | 381 | // Draw grid 382 | if self.show_grid { 383 | let grid_color = egui::Color32::from_gray(230); 384 | for i in 1..10 { 385 | let x = canvas_rect.min.x + (i as f32 / 10.0) * canvas_rect.width(); 386 | let y = canvas_rect.min.y + (i as f32 / 10.0) * canvas_rect.height(); 387 | 388 | painter.line_segment( 389 | [egui::Pos2::new(x, canvas_rect.min.y), egui::Pos2::new(x, canvas_rect.max.y)], 390 | egui::Stroke::new(0.5, grid_color), 391 | ); 392 | painter.line_segment( 393 | [egui::Pos2::new(canvas_rect.min.x, y), egui::Pos2::new(canvas_rect.max.x, y)], 394 | egui::Stroke::new(0.5, grid_color), 395 | ); 396 | } 397 | } 398 | 399 | // Draw points with responsive sizing 400 | let scale_factor = (canvas_rect.width() / 400.0).clamp(0.5, 2.0); // Scale relative to 400px baseline 401 | let scaled_point_size = self.point_size * scale_factor; 402 | 403 | for (i, point) in self.points.iter().enumerate() { 404 | let canvas_pos = self.world_to_canvas(point.x(), point.y(), canvas_rect); 405 | let color = self.point_colors.get(&i).copied().unwrap_or(egui::Color32::BLUE); 406 | painter.circle_filled(canvas_pos, scaled_point_size, color); 407 | } 408 | 409 | // Draw search query point with responsive sizing 410 | if self.demo_mode == DemoMode::NearestNeighbor || self.demo_mode == DemoMode::RangeQuery { 411 | if let (Ok(x), Ok(y)) = (self.search_x.parse::(), self.search_y.parse::()) { 412 | if x >= 0.0 && x <= 1.0 && y >= 0.0 && y <= 1.0 { 413 | let query_pos = self.world_to_canvas(x, y, canvas_rect); 414 | let scaled_query_size = 8.0 * scale_factor; 415 | let scaled_stroke_width = 2.0 * scale_factor.sqrt(); 416 | painter.circle_stroke(query_pos, scaled_query_size, egui::Stroke::new(scaled_stroke_width, egui::Color32::RED)); 417 | 418 | // Draw range circle for range queries 419 | if self.demo_mode == DemoMode::RangeQuery { 420 | // Scale radius proportionally to canvas size 421 | let radius_pixels = self.search_radius * canvas_rect.width().min(canvas_rect.height()); 422 | painter.circle_stroke( 423 | query_pos, 424 | radius_pixels, 425 | egui::Stroke::new(1.0, egui::Color32::from_rgba_unmultiplied(255, 0, 0, 100)), 426 | ); 427 | } 428 | } 429 | } 430 | } 431 | 432 | // Highlight nearest neighbor with responsive sizing 433 | if let Some(nn) = &self.nearest_neighbor { 434 | let nn_pos = self.world_to_canvas(nn.x(), nn.y(), canvas_rect); 435 | let highlight_size = scaled_point_size + 3.0 * scale_factor; 436 | let highlight_stroke = 2.0 * scale_factor.sqrt(); 437 | painter.circle_stroke(nn_pos, highlight_size, egui::Stroke::new(highlight_stroke, egui::Color32::GREEN)); 438 | } 439 | 440 | // Highlight range query results with responsive sizing 441 | for result in &self.range_results { 442 | let result_pos = self.world_to_canvas(result.x(), result.y(), canvas_rect); 443 | let result_highlight_size = scaled_point_size + 2.0 * scale_factor; 444 | let result_stroke = 1.5 * scale_factor.sqrt(); 445 | painter.circle_stroke(result_pos, result_highlight_size, egui::Stroke::new(result_stroke, egui::Color32::YELLOW)); 446 | } 447 | 448 | ui.label("💡 Click on the canvas to interact!"); 449 | }); 450 | }); 451 | }); 452 | 453 | // Request repaint for animations 454 | if self.auto_generate { 455 | ctx.request_repaint(); 456 | } 457 | } 458 | } 459 | 460 | fn main() -> Result<(), eframe::Error> { 461 | let options = eframe::NativeOptions { 462 | viewport: egui::ViewportBuilder::default() 463 | .with_inner_size([1000.0, 700.0]) // Larger default size 464 | .with_min_inner_size([600.0, 400.0]) // Minimum window size 465 | .with_resizable(true) // Allow resizing 466 | .with_title("LSPH Interactive Demo"), 467 | ..Default::default() 468 | }; 469 | 470 | eframe::run_native( 471 | "LSPH Demo", 472 | options, 473 | Box::new(|cc| Box::new(LSPHDemo::new(cc))), 474 | ) 475 | } -------------------------------------------------------------------------------- /src/map/mod.rs: -------------------------------------------------------------------------------- 1 | mod nn; 2 | mod table; 3 | 4 | use crate::{ 5 | error::*, 6 | geometry::{distance::*, Point}, 7 | hasher::*, 8 | map::{nn::*, table::*}, 9 | models::Model, 10 | }; 11 | use core::{fmt::Debug, iter::Sum, mem}; 12 | use num_traits::{ 13 | cast::{AsPrimitive, FromPrimitive}, 14 | float::Float, 15 | }; 16 | use std::collections::BinaryHeap; 17 | 18 | /// Initial bucket size is set to 1 19 | const INITIAL_NBUCKETS: usize = 1; 20 | 21 | /// LearnedHashMap takes a model instead of an hasher for hashing indexes in the table. 22 | /// 23 | /// Default Model for the LearndedHashMap is Linear regression. 24 | /// In order to build a ordered HashMap, we need to make sure that the model is **monotonic**. 25 | #[derive(Debug, Clone)] 26 | pub struct LearnedHashMap { 27 | hasher: LearnedHasher, 28 | table: Table>, 29 | items: usize, 30 | } 31 | 32 | /// Default for the LearndedHashMap. 33 | impl Default for LearnedHashMap 34 | where 35 | F: Float, 36 | M: Model + Default, 37 | { 38 | fn default() -> Self { 39 | Self { 40 | hasher: LearnedHasher::::new(), 41 | table: Table::new(), 42 | items: 0, 43 | } 44 | } 45 | } 46 | 47 | impl LearnedHashMap 48 | where 49 | F: Float + Default + AsPrimitive + FromPrimitive + Debug + Sum, 50 | M: Model + Default + Clone, 51 | { 52 | /// Returns a default LearnedHashMap with Model and Float type. 53 | /// 54 | /// # Examples 55 | /// 56 | /// ``` 57 | /// use lsph::{LearnedHashMap, LinearModel}; 58 | /// let map = LearnedHashMap::, f64>::new(); 59 | /// ``` 60 | #[inline] 61 | pub fn new() -> Self { 62 | Self::default() 63 | } 64 | 65 | /// Returns a default LearnedHashMap with Model and Float type. 66 | /// 67 | /// # Arguments 68 | /// * `hasher` - A LearnedHasher with model 69 | /// 70 | /// # Examples 71 | /// 72 | /// ``` 73 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 74 | /// let map = LearnedHashMap::, f64>::with_hasher(LearnedHasher::new()); 75 | /// ``` 76 | #[inline] 77 | pub fn with_hasher(hasher: LearnedHasher) -> Self { 78 | Self { 79 | hasher, 80 | table: Table::new(), 81 | items: 0, 82 | } 83 | } 84 | 85 | /// Returns a default LearnedHashMap with Model and Float type. 86 | /// 87 | /// # Arguments 88 | /// * `capacity` - A predefined capacity size for the LearnedHashMap 89 | /// 90 | /// # Examples 91 | /// 92 | /// ``` 93 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 94 | /// let map = LearnedHashMap::, f64>::with_capacity(10usize); 95 | /// ``` 96 | #[inline] 97 | pub fn with_capacity(capacity: usize) -> Self { 98 | Self { 99 | hasher: Default::default(), 100 | table: Table::with_capacity(capacity), 101 | items: 0, 102 | } 103 | } 104 | 105 | /// Returns a default LearnedHashMap with Model and Float type 106 | /// 107 | /// # Arguments 108 | /// * `data` - A Vec<[F; 2]> of 2d points for the map 109 | /// 110 | /// # Examples 111 | /// 112 | /// ``` 113 | /// use lsph::{LearnedHashMap, LinearModel}; 114 | /// let data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 115 | /// let map = LearnedHashMap::, f64>::with_data(&data); 116 | /// ``` 117 | #[inline] 118 | pub fn with_data(data: &[[F; 2]]) -> Result<(Self, Vec>), Error> { 119 | use crate::helper::convert_to_points; 120 | let mut map = LearnedHashMap::with_capacity(data.len()); 121 | let mut ps = convert_to_points(data).unwrap(); 122 | match map.batch_insert(&mut ps) { 123 | Ok(()) => Ok((map, ps)), 124 | Err(err) => Err(err), 125 | } 126 | } 127 | 128 | /// Returns Option> with given point data. 129 | /// 130 | /// # Arguments 131 | /// * `p` - A array slice containing two points for querying 132 | /// 133 | /// # Examples 134 | /// 135 | /// ``` 136 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 137 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 138 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 139 | /// 140 | /// assert_eq!(map.get(&[1., 1.]).is_some(), true); 141 | /// ``` 142 | #[inline] 143 | pub fn get(&mut self, p: &[F; 2]) -> Option<&Point> { 144 | let hash = make_hash_point(&mut self.hasher, p) as usize; 145 | if hash > self.table.capacity() { 146 | return None; 147 | } 148 | self.find_by_hash(hash, p) 149 | } 150 | 151 | /// Returns Option> by hash index, if it exists in the map. 152 | /// 153 | /// # Arguments 154 | /// * `hash` - An usize hash value 155 | /// * `p` - A array slice containing two points for querying 156 | /// 157 | /// # Examples 158 | /// 159 | /// ``` 160 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 161 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 162 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 163 | /// 164 | /// assert_eq!(map.find_by_hash(0, &[1., 1.]).is_some(), true); 165 | /// assert_eq!(map.find_by_hash(1, &[1., 1.]).is_none(), true); 166 | /// ``` 167 | pub fn find_by_hash(&self, hash: usize, p: &[F; 2]) -> Option<&Point> { 168 | self.table[hash] 169 | .iter() 170 | .find(|&ep| ep.x == p[0] && ep.y == p[1]) 171 | } 172 | 173 | /// Returns bool. 174 | /// 175 | /// # Arguments 176 | /// * `p` - A array slice containing two points for querying 177 | /// 178 | /// # Examples 179 | /// 180 | /// ``` 181 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 182 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 183 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 184 | /// 185 | /// assert_eq!(map.contains_points(&[1., 1.]), true); 186 | /// assert_eq!(map.contains_points(&[0., 1.]), false); 187 | /// ``` 188 | #[inline] 189 | pub fn contains_points(&mut self, p: &[F; 2]) -> bool { 190 | self.get(p).is_some() 191 | } 192 | 193 | /// Returns Option> if the map contains a point and successful remove it from the map. 194 | /// 195 | /// # Arguments 196 | /// * `p` - A Point data 197 | /// 198 | /// # Examples 199 | /// 200 | /// ``` 201 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 202 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 203 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 204 | /// 205 | /// let p = points[0]; 206 | /// assert_eq!(map.remove(&p).unwrap(), p); 207 | /// ``` 208 | #[inline] 209 | pub fn remove(&mut self, p: &Point) -> Option> { 210 | let hash = make_hash_point(&mut self.hasher, &[p.x, p.y]); 211 | self.items -= 1; 212 | self.table.remove_entry(hash, *p) 213 | } 214 | 215 | /// Returns usize length. 216 | /// 217 | /// # Examples 218 | /// 219 | /// ``` 220 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 221 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 222 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 223 | /// 224 | /// assert_eq!(map.len(), 4); 225 | /// ``` 226 | #[inline] 227 | pub fn len(&self) -> usize { 228 | self.table.len() 229 | } 230 | 231 | /// Returns usize number of items. 232 | /// 233 | /// # Examples 234 | /// 235 | /// ``` 236 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 237 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 238 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 239 | /// 240 | /// assert_eq!(map.items(), 4); 241 | /// ``` 242 | #[inline] 243 | pub fn items(&self) -> usize { 244 | self.items 245 | } 246 | 247 | /// Returns bool if the map is empty. 248 | /// 249 | /// # Examples 250 | /// 251 | /// ``` 252 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 253 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 254 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 255 | /// 256 | /// assert_eq!(map.is_empty(), false); 257 | /// ``` 258 | #[inline] 259 | pub fn is_empty(&self) -> bool { 260 | self.items == 0 261 | } 262 | 263 | /// Resize the map if needed, it will initialize the map to the INITIAL_NBUCKETS, otherwise it will double the capacity if table is not empty. 264 | fn resize(&mut self) { 265 | let target_size = match self.table.len() { 266 | 0 => INITIAL_NBUCKETS, 267 | n => 2 * n, 268 | }; 269 | self.resize_with_capacity(target_size); 270 | } 271 | 272 | /// Resize the map if needed, it will resize the map to desired capacity. 273 | #[inline] 274 | fn resize_with_capacity(&mut self, target_size: usize) { 275 | let mut new_table = Table::with_capacity(target_size); 276 | new_table.extend((0..target_size).map(|_| Bucket::new())); 277 | 278 | for p in self.table.iter_mut().flat_map(|bucket| bucket.drain(..)) { 279 | let hash = make_hash_point(&mut self.hasher, &[p.x, p.y]) as usize; 280 | new_table[hash].push(p); 281 | } 282 | 283 | self.table = new_table; 284 | } 285 | 286 | /// Rehash the map. 287 | #[inline] 288 | fn rehash(&mut self) -> Result<(), Error> { 289 | let mut old_data = Vec::with_capacity(self.items()); 290 | for p in self.table.iter_mut().flat_map(|bucket| bucket.drain(..)) { 291 | old_data.push(p); 292 | } 293 | self.batch_insert(&mut old_data) 294 | } 295 | 296 | /// Inner function for insert a single point into the map 297 | #[inline] 298 | fn insert_inner(&mut self, p: Point) -> Option> { 299 | // Resize if the table is empty or 3/4 size of the table is full 300 | if self.table.is_empty() || self.items() > 3 * self.table.len() / 4 { 301 | self.resize(); 302 | } 303 | let hash = make_hash_point::(&mut self.hasher, &[p.x, p.y]); 304 | self.insert_with_axis(p, hash) 305 | } 306 | 307 | /// Sequencial insert a point into the map. 308 | /// 309 | /// # Arguments 310 | /// * `p` - A Point with float number 311 | /// 312 | /// # Examples 313 | /// 314 | /// ``` 315 | /// use lsph::{LearnedHashMap, LinearModel, Point}; 316 | /// let a: Point = Point::new(0., 1.); 317 | /// let b: Point = Point::new(1., 0.); 318 | 319 | /// let mut map = LearnedHashMap::, f64>::new(); 320 | /// map.insert(a); 321 | /// map.insert(b); 322 | 323 | /// assert_eq!(map.items(), 2); 324 | /// assert_eq!(map.get(&[0., 1.]).unwrap(), &a); 325 | /// assert_eq!(map.get(&[1., 0.]).unwrap(), &b); 326 | /// ``` 327 | pub fn insert(&mut self, p: Point) -> Option> { 328 | // Resize if the table is empty or 3/4 size of the table is full 329 | if self.table.is_empty() || self.items() > 3 * self.table.len() / 4 { 330 | self.resize(); 331 | } 332 | 333 | let hash = make_hash_point::(&mut self.hasher, &[p.x, p.y]); 334 | // resize if hash index is larger or equal to the table capacity 335 | if hash >= self.table.capacity() as u64 { 336 | self.resize_with_capacity(hash as usize * 2); 337 | self.insert_with_axis(p, hash); 338 | match self.rehash() { 339 | Ok(_) => None, 340 | Err(_err) => { 341 | None 342 | } 343 | } 344 | } else { 345 | self.insert_with_axis(p, hash) 346 | } 347 | } 348 | 349 | /// Insert a point into the map along the given axis. 350 | /// 351 | /// # Arguments 352 | /// * `p_value` - A float number represent the key of a 2d point 353 | #[inline] 354 | fn insert_with_axis(&mut self, p: Point, hash: u64) -> Option> { 355 | let mut insert_index = 0; 356 | let bucket_index = self.table.bucket(hash); 357 | let bucket = &mut self.table[bucket_index]; 358 | if self.hasher.sort_by_x() { 359 | // Get index from the hasher 360 | for ep in bucket.iter_mut() { 361 | if ep == &mut p.clone() { 362 | return Some(mem::replace(ep, p)); 363 | } 364 | if ep.y < p.y() { 365 | insert_index += 1; 366 | } 367 | } 368 | } else { 369 | for ep in bucket.iter_mut() { 370 | if ep == &mut p.clone() { 371 | return Some(mem::replace(ep, p)); 372 | } 373 | if ep.x < p.x() { 374 | insert_index += 1; 375 | } 376 | } 377 | } 378 | bucket.insert(insert_index, p); 379 | self.items += 1; 380 | None 381 | } 382 | 383 | /// Fit the input data into the model of the hasher. Returns Error if error occurred during 384 | /// model fitting. 385 | /// 386 | /// # Arguments 387 | /// 388 | /// * `xs` - A list of tuple of floating number 389 | /// * `ys` - A list of tuple of floating number 390 | #[inline] 391 | pub fn model_fit(&mut self, xs: &[F], ys: &[F]) -> Result<(), Error> { 392 | self.hasher.model.fit(xs, ys) 393 | } 394 | 395 | /// Fit the input data into the model of the hasher. Returns Error if error occurred during 396 | /// model fitting. 397 | /// 398 | /// # Arguments 399 | /// * `data` - A list of tuple of floating number 400 | #[inline] 401 | pub fn model_fit_tuple(&mut self, data: &[(F, F)]) -> Result<(), Error> { 402 | self.hasher.model.fit_tuple(data) 403 | } 404 | 405 | /// Inner function for batch insert 406 | #[inline] 407 | fn batch_insert_inner(&mut self, ps: &[Point]) { 408 | // Allocate table capacity before insert 409 | let n = ps.len(); 410 | self.resize_with_capacity(n); 411 | for p in ps.iter() { 412 | self.insert_inner(*p); 413 | } 414 | } 415 | 416 | /// Batch insert a batch of 2d data into the map. 417 | /// 418 | /// # Arguments 419 | /// * `ps` - A list of point number 420 | /// 421 | /// # Examples 422 | /// 423 | /// ``` 424 | /// use lsph::{LearnedHashMap, LinearModel}; 425 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 426 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 427 | /// 428 | /// assert_eq!(map.get(&[1., 1.]).is_some(), true); 429 | /// ``` 430 | #[inline] 431 | pub fn batch_insert(&mut self, ps: &mut [Point]) -> Result<(), Error> { 432 | // Select suitable axis for training 433 | use crate::geometry::Axis; 434 | use crate::models::Trainer; 435 | 436 | // Loading data into trainer 437 | if let Ok(trainer) = Trainer::with_points(ps) { 438 | trainer.train(&mut self.hasher.model).unwrap(); 439 | let axis = trainer.axis(); 440 | match axis { 441 | Axis::X => self.hasher.set_sort_by_x(true), 442 | _ => self.hasher.set_sort_by_x(false), 443 | }; 444 | 445 | // Fit the data into model 446 | self.model_fit(trainer.train_x(), trainer.train_y()) 447 | .unwrap(); 448 | // Batch insert into the map 449 | self.batch_insert_inner(ps); 450 | } 451 | Ok(()) 452 | } 453 | 454 | /// Range search finds all points for a given 2d range. 455 | /// Returns all the points within the given range. 456 | /// ```text 457 | /// | top right 458 | /// | .-----------* 459 | /// | | . . | 460 | /// | | . . . | 461 | /// | | . | 462 | /// bottom left *-----------. 463 | /// | 464 | /// | | | 465 | /// |________v___________v________ 466 | /// left right 467 | /// hash hash 468 | /// ``` 469 | /// # Arguments 470 | /// 471 | /// * `bottom_left` - A tuple containing a pair of points that represent the bottom left of the 472 | /// range. 473 | /// 474 | /// * `top_right` - A tuple containing a pair of points that represent the top right of the 475 | /// range. 476 | #[inline] 477 | pub fn range_search( 478 | &mut self, 479 | bottom_left: &[F; 2], 480 | top_right: &[F; 2], 481 | ) -> Option>> { 482 | let mut right_hash = make_hash_point(&mut self.hasher, top_right) as usize; 483 | if right_hash > self.table.capacity() { 484 | right_hash = self.table.capacity() - 1; 485 | } 486 | let left_hash = make_hash_point(&mut self.hasher, bottom_left) as usize; 487 | if left_hash > self.table.capacity() || left_hash > right_hash { 488 | return None; 489 | } 490 | let mut result: Vec> = Vec::new(); 491 | for i in left_hash..=right_hash { 492 | let bucket = &self.table[i]; 493 | for item in bucket.iter() { 494 | if item.x >= bottom_left[0] 495 | && item.x <= top_right[0] 496 | && item.y >= bottom_left[1] 497 | && item.y <= top_right[1] 498 | { 499 | result.push(*item); 500 | } 501 | } 502 | } 503 | if result.is_empty() { 504 | return None; 505 | } 506 | Some(result) 507 | } 508 | 509 | /// Returns Option>> if points are found in the map with given range 510 | /// 511 | /// # Arguments 512 | /// * `query_point` - A Point data for querying 513 | /// * `radius` - A radius value 514 | /// 515 | /// # Examples 516 | /// 517 | /// ``` 518 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 519 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 520 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 521 | /// assert_eq!(map.range_search(&[0., 0.], &[3., 3.]).is_some(), true); 522 | /// ``` 523 | #[inline] 524 | pub fn radius_range(&mut self, query_point: &[F; 2], radius: F) -> Option>> { 525 | self.range_search( 526 | &[query_point[0] - radius, query_point[1] - radius], 527 | &[query_point[0] + radius, query_point[1] + radius], 528 | ) 529 | } 530 | 531 | /// Find the local minimum distance between query points and cadidates neighbors, then store 532 | /// the cadidates neighbors in the min_heap. 533 | /// 534 | /// 535 | /// # Arguments 536 | /// * `heap` - mutable borrow of an BinaryHeap 537 | /// * `local_hash` - A hash index of local bucket 538 | /// * `query_point` - A Point data 539 | /// * `min_d` - minimum distance 540 | /// * `nearest_neighbor` - mutable borrow of an point data, which is the nearest neighbor at 541 | /// search index bucket 542 | #[inline] 543 | fn local_min_heap( 544 | &self, 545 | heap: &mut BinaryHeap>, 546 | local_hash: u64, 547 | query_point: &[F; 2], 548 | min_d: &mut F, 549 | nearest_neighbor: &mut Point, 550 | ) { 551 | let bucket = &self.table[local_hash as usize]; 552 | if !bucket.is_empty() { 553 | for p in bucket.iter() { 554 | let d = Euclidean::distance(query_point, &[p.x, p.y]); 555 | heap.push(NearestNeighborState { 556 | distance: d, 557 | point: *p, 558 | }); 559 | } 560 | } 561 | if let Some(v) = heap.pop() { 562 | let local_min_d = v.distance; 563 | // Update the nearest neighbour and minimum distance 564 | if &local_min_d < min_d { 565 | *nearest_neighbor = v.point; 566 | *min_d = local_min_d; 567 | } 568 | } 569 | } 570 | 571 | /// Calculates the horizontal distance between query_point and bucket at index with given hash. 572 | /// 573 | /// # Arguments 574 | /// * `hash` - A hash index of the bucket 575 | /// * `query_point` - A Point data 576 | #[inline] 577 | fn horizontal_distance(&mut self, query_point: &[F; 2], hash: u64) -> F { 578 | let x = unhash(&mut self.hasher, hash); 579 | match self.hasher.sort_by_x() { 580 | true => Euclidean::distance(&[query_point[0], F::zero()], &[x, F::zero()]), 581 | false => Euclidean::distance(&[query_point[1], F::zero()], &[x, F::zero()]), 582 | } 583 | } 584 | 585 | /// Nearest neighbor search for the closest point for given query point 586 | /// Returns the closest point 587 | ///```text 588 | /// | 589 | /// | . 590 | /// | . | 591 | /// | |. | * . <- nearest neighbor 592 | /// | || | | .| 593 | /// | expand <--------> expand 594 | /// | left | right 595 | /// | | 596 | /// |_______________v_____________ 597 | /// query 598 | /// point 599 | ///``` 600 | /// # Arguments 601 | /// 602 | /// * `query_point` - A tuple containing a pair of points for querying 603 | /// 604 | /// # Examples 605 | /// 606 | /// ``` 607 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 608 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 609 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 610 | /// assert_eq!(map.nearest_neighbor(&[2., 1.]).is_some(), true); 611 | /// ``` 612 | #[inline] 613 | pub fn nearest_neighbor(&mut self, query_point: &[F; 2]) -> Option> { 614 | let mut hash = make_hash_point(&mut self.hasher, query_point); 615 | let max_capacity = self.table.capacity() as u64; 616 | 617 | // if hash out of max bound, still search right most bucket 618 | if hash > max_capacity { 619 | hash = max_capacity - 1; 620 | } 621 | 622 | let mut heap = BinaryHeap::new(); 623 | let mut min_d = F::max_value(); 624 | let mut nearest_neighbor = Point::default(); 625 | 626 | // Searching at current hash index 627 | self.local_min_heap( 628 | &mut heap, 629 | hash, 630 | query_point, 631 | &mut min_d, 632 | &mut nearest_neighbor, 633 | ); 634 | 635 | // Measure left horizontal distance from current bucket to left hash bucket 636 | // left hash must >= 0 637 | let mut left_hash = hash.saturating_sub(1); 638 | // Unhash the left_hash, then calculate the vertical distance between 639 | // left hash point and query point 640 | let mut left_hash_d = self.horizontal_distance(query_point, left_hash); 641 | 642 | // Iterate over left 643 | while left_hash_d < min_d { 644 | self.local_min_heap( 645 | &mut heap, 646 | left_hash, 647 | query_point, 648 | &mut min_d, 649 | &mut nearest_neighbor, 650 | ); 651 | 652 | // break before update 653 | if left_hash == 0 { 654 | break; 655 | } 656 | 657 | // Update next right side bucket distance 658 | left_hash = left_hash.saturating_sub(1); 659 | left_hash_d = self.horizontal_distance(query_point, left_hash); 660 | } 661 | 662 | // Measure right vertical distance from current bucket to right hash bucket 663 | let mut right_hash = hash + 1; 664 | // Unhash the right_hash, then calculate the vertical distance between 665 | // right hash point and query point 666 | let mut right_hash_d = self.horizontal_distance(query_point, right_hash); 667 | 668 | // Iterate over right 669 | while right_hash_d < min_d { 670 | self.local_min_heap( 671 | &mut heap, 672 | right_hash, 673 | query_point, 674 | &mut min_d, 675 | &mut nearest_neighbor, 676 | ); 677 | 678 | // Move to next right bucket 679 | right_hash += 1; 680 | 681 | // break after update 682 | if right_hash == self.table.capacity() as u64 { 683 | break; 684 | } 685 | // Update next right side bucket distance 686 | right_hash_d = self.horizontal_distance(query_point, right_hash); 687 | } 688 | 689 | Some(nearest_neighbor) 690 | } 691 | } 692 | 693 | pub struct Iter<'a, M, F> 694 | where 695 | F: Float, 696 | M: Model + Default + Clone, 697 | { 698 | map: &'a LearnedHashMap, 699 | bucket: usize, 700 | at: usize, 701 | } 702 | 703 | impl<'a, M, F> Iterator for Iter<'a, M, F> 704 | where 705 | F: Float, 706 | M: Model + Default + Clone, 707 | { 708 | type Item = &'a Point; 709 | fn next(&mut self) -> Option { 710 | loop { 711 | match self.map.table.get(self.bucket) { 712 | Some(bucket) => { 713 | match bucket.get(self.at) { 714 | Some(p) => { 715 | // move along self.at and self.bucket 716 | self.at += 1; 717 | break Some(p); 718 | } 719 | None => { 720 | self.bucket += 1; 721 | self.at = 0; 722 | continue; 723 | } 724 | } 725 | } 726 | None => break None, 727 | } 728 | } 729 | } 730 | } 731 | 732 | impl<'a, M, F> IntoIterator for &'a LearnedHashMap 733 | where 734 | F: Float, 735 | M: Model + Default + Clone, 736 | { 737 | type Item = &'a Point; 738 | type IntoIter = Iter<'a, M, F>; 739 | fn into_iter(self) -> Self::IntoIter { 740 | Iter { 741 | map: self, 742 | bucket: 0, 743 | at: 0, 744 | } 745 | } 746 | } 747 | 748 | pub struct IntoIter 749 | where 750 | F: Float, 751 | M: Model + Default + Clone, 752 | { 753 | map: LearnedHashMap, 754 | bucket: usize, 755 | } 756 | 757 | impl Iterator for IntoIter 758 | where 759 | F: Float, 760 | M: Model + Default + Clone, 761 | { 762 | type Item = Point; 763 | fn next(&mut self) -> Option { 764 | loop { 765 | match self.map.table.get_mut(self.bucket) { 766 | Some(bucket) => match bucket.pop() { 767 | Some(x) => break Some(x), 768 | None => { 769 | self.bucket += 1; 770 | continue; 771 | } 772 | }, 773 | None => break None, 774 | } 775 | } 776 | } 777 | } 778 | 779 | impl IntoIterator for LearnedHashMap 780 | where 781 | F: Float, 782 | M: Model + Default + Clone, 783 | { 784 | type Item = Point; 785 | type IntoIter = IntoIter; 786 | fn into_iter(self) -> Self::IntoIter { 787 | IntoIter { 788 | map: self, 789 | bucket: 0, 790 | } 791 | } 792 | } 793 | 794 | #[cfg(test)] 795 | mod tests { 796 | use super::*; 797 | use crate::geometry::Point; 798 | use crate::models::LinearModel; 799 | use crate::test_utilities::*; 800 | 801 | #[test] 802 | fn insert() { 803 | let a: Point = Point::new(0., 1.); 804 | let b: Point = Point::new(1., 0.); 805 | 806 | let mut map = LearnedHashMap::, f64>::new(); 807 | map.insert(a); 808 | map.insert(b); 809 | 810 | assert_eq!(map.items(), 2); 811 | assert_eq!(map.get(&[0., 1.]).unwrap(), &a); 812 | assert_eq!(map.get(&[1., 0.]).unwrap(), &b); 813 | } 814 | 815 | #[test] 816 | fn insert_repeated() { 817 | let mut map = LearnedHashMap::, f64>::new(); 818 | let a: Point = Point::new(0., 1.); 819 | let b: Point = Point::new(1., 0.); 820 | let res = map.insert(a); 821 | assert_eq!(map.items(), 1); 822 | assert_eq!(res, None); 823 | 824 | let res = map.insert(b); 825 | assert_eq!(map.items(), 2); 826 | assert_eq!(res, None); 827 | } 828 | 829 | #[test] 830 | fn with_data() { 831 | let data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 832 | let (mut map, _points) = LearnedHashMap::, f64>::with_data(&data).unwrap(); 833 | assert!(map.get(&[1., 1.]).is_some()); 834 | } 835 | 836 | #[test] 837 | fn fit_batch_insert() { 838 | let mut data: Vec> = vec![ 839 | Point::new(1., 1.), 840 | Point::new(3., 1.), 841 | Point::new(2., 1.), 842 | Point::new(3., 2.), 843 | Point::new(5., 1.), 844 | ]; 845 | let mut map = LearnedHashMap::, f64>::new(); 846 | map.batch_insert(&mut data).unwrap(); 847 | 848 | assert_delta!(1.02272, map.hasher.model.coefficient, 0.00001); 849 | assert_delta!(-0.86363, map.hasher.model.intercept, 0.00001); 850 | assert_eq!(Some(&Point::new(1., 1.)), map.get(&[1., 1.])); 851 | assert_eq!(Some(&Point::new(3., 1.,)), map.get(&[3., 1.])); 852 | assert_eq!(Some(&Point::new(5., 1.)), map.get(&[5., 1.])); 853 | 854 | assert_eq!(None, map.get(&[5., 2.])); 855 | assert_eq!(None, map.get(&[2., 2.])); 856 | assert_eq!(None, map.get(&[50., 10.])); 857 | assert_eq!(None, map.get(&[500., 100.])); 858 | } 859 | 860 | #[test] 861 | fn insert_after_batch_insert() { 862 | let mut data: Vec> = vec![ 863 | Point::new(1., 1.), 864 | Point::new(3., 1.), 865 | Point::new(2., 1.), 866 | Point::new(3., 2.), 867 | Point::new(5., 1.), 868 | ]; 869 | let mut map = LearnedHashMap::, f64>::new(); 870 | map.batch_insert(&mut data).unwrap(); 871 | 872 | let a: Point = Point::new(10., 10.); 873 | map.insert(a); 874 | assert_eq!(Some(&a), map.get(&[10., 10.])); 875 | 876 | let b: Point = Point::new(100., 100.); 877 | map.insert(b); 878 | assert_eq!(Some(&b), map.get(&[100., 100.])); 879 | assert_eq!(None, map.get(&[100., 101.])); 880 | } 881 | 882 | #[test] 883 | fn range_search() { 884 | let mut data: Vec> = vec![ 885 | Point::new(1., 1.), 886 | Point::new(2., 2.), 887 | Point::new(3., 3.), 888 | Point::new(4., 4.), 889 | Point::new(5., 5.), 890 | ]; 891 | let mut map = LearnedHashMap::, f64>::new(); 892 | map.batch_insert(&mut data).unwrap(); 893 | 894 | 895 | let found: Vec> = 896 | vec![Point::new(1., 1.), Point::new(2., 2.), Point::new(3., 3.)]; 897 | 898 | assert_eq!(Some(found), map.range_search(&[1., 1.], &[3.5, 3.])); 899 | 900 | let found: Vec> = vec![Point::new(1., 1.)]; 901 | 902 | assert_eq!(Some(found), map.range_search(&[1., 1.], &[3., 1.])); 903 | assert_eq!(None, map.range_search(&[4., 2.], &[5., 3.])); 904 | } 905 | 906 | #[test] 907 | fn test_nearest_neighbor() { 908 | let points = create_random_point_type_points(1000, SEED_1); 909 | let mut map = LearnedHashMap::, f64>::new(); 910 | map.batch_insert(&mut points.clone()).unwrap(); 911 | 912 | let sample_points = create_random_point_type_points(100, SEED_2); 913 | for sample_point in sample_points.iter() { 914 | let mut nearest = None; 915 | let mut closest_dist = f64::INFINITY; 916 | for point in &points { 917 | let new_dist = Euclidean::distance_point(point, sample_point); 918 | if new_dist < closest_dist { 919 | closest_dist = new_dist; 920 | nearest = Some(point); 921 | } 922 | } 923 | let map_nearest = map 924 | .nearest_neighbor(&[sample_point.x, sample_point.y]) 925 | .unwrap(); 926 | assert_eq!(nearest.unwrap(), &map_nearest); 927 | } 928 | } 929 | } 930 | --------------------------------------------------------------------------------