├── .github └── workflows │ └── rust.yml ├── .gitignore ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── benches └── benchmarks.rs └── src ├── error.rs ├── geometry ├── axis.rs ├── distance.rs ├── helper.rs ├── mod.rs └── point.rs ├── hasher └── mod.rs ├── lib.rs ├── macros.rs ├── map ├── mod.rs ├── nn.rs └── table.rs ├── models ├── linear.rs ├── mod.rs ├── stats.rs └── trainer.rs └── test_utilities.rs /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v3 19 | - name: Build 20 | run: cargo build --verbose 21 | - name: Run tests 22 | run: cargo test --verbose 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | .DS_Store 4 | .vscode -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "lsph" 3 | version = "0.1.8" 4 | authors = ["Haozhan Shi "] 5 | repository = "https://github.com/jackson211/lsph" 6 | homepage = "https://github.com/jackson211/lsph" 7 | documentation = "https://docs.rs/lsph" 8 | edition = "2021" 9 | description = "Learned Spatial HashMap" 10 | license = "MIT OR Apache-2.0" 11 | readme = "README.md" 12 | keywords = ["hash", "hashmap", "spatial", "index"] 13 | categories = ["data-structures"] 14 | 15 | [dependencies] 16 | num-traits = "0.2.18" 17 | smallvec = "1.13.2" 18 | 19 | [dev-dependencies] 20 | rand = "0.8.5" 21 | rand_hc = "0.3.2" 22 | criterion = { version = "0.5.1", features = ["html_reports"] } 23 | 24 | [[bench]] 25 | name = "benchmarks" 26 | harness = false 27 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 Haozhan Shi 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 The LSPH developers 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [LSPH](https://crates.io/crates/lsph) - Learned SPatial HashMap 2 | 3 | **fast 2d point query powered by hashmap and statistic model** 4 | 5 | ![Github Workflow](https://github.com/jackson211/lsph/actions/workflows/rust.yml/badge.svg) 6 | [![crates.io version](https://img.shields.io/crates/v/lsph)](https://crates.io/crates/lsph) 7 | [![dos.io](https://img.shields.io/docsrs/lsph)](https://docs.rs/lsph) 8 | [![dependency status](https://deps.rs/repo/github/jackson211/lsph/status.svg)](https://deps.rs/repo/github/jackson211/lsph) 9 | 10 | The original paper of LSPH can be found [here]. 11 | 12 | [here]: https://minerva-access.unimelb.edu.au/items/beb5c0ee-2a8d-5bd2-b349-1190a335ef1a 13 | 14 | The LSPH uses a learned model such as a linear regression model as the hash function to predict the index in a hashmap. As a result, the learned model is more fitted to the data that stored in the hashmap, and reduces the 15 | chance of hashing collisions. Moreover, if the learned model is monotonic function(e.g. linear regression), the hash indexes are increasing as the input data increases. This property can be used to create a sorted order 16 | of buckets in a hashmap, which allow us to do range searches in a hashmap. 17 | 18 | The LSPH supports: 19 | 20 | - Point Query 21 | - Rectange Query 22 | - Radius Range Query 23 | - Nearest Neighbor Query 24 | 25 | ## Example: 26 | 27 | ```rust 28 | use lsph::{LearnedHashMap, LinearModel}; 29 | let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 30 | let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 31 | 32 | assert_eq!(map.get(&[1., 1.]).is_some(), true); 33 | assert_eq!(map.get(&[3., 1.]).is_none(), true); 34 | assert_eq!(map.range_search(&[0., 0.], &[3., 3.]).is_some(), true); 35 | assert_eq!(map.radius_range(&[2., 1.], 1.).is_some(), true); 36 | assert_eq!(map.nearest_neighbor(&[2., 1.]).is_some(), true); 37 | ``` 38 | 39 | ## To Run Benchmark: 40 | 41 | ```bash 42 | cargo bench 43 | ``` 44 | 45 | # License 46 | 47 | Licensed under either of 48 | 49 | - Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) 50 | - MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) 51 | 52 | at your option. 53 | -------------------------------------------------------------------------------- /benches/benchmarks.rs: -------------------------------------------------------------------------------- 1 | #[macro_use] 2 | extern crate criterion; 3 | 4 | use criterion::Criterion; 5 | use lsph::geometry::Point; 6 | use lsph::{map::LearnedHashMap, models::LinearModel}; 7 | use rand::{Rng, SeedableRng}; 8 | use rand_hc::Hc128Rng; 9 | 10 | const SEED_1: &[u8; 32] = b"Gv0aHMtHkBGsUXNspGU9fLRuCWkZWHZx"; 11 | const SEED_2: &[u8; 32] = b"km7DO4GeaFZfTcDXVpnO7ZJlgUY7hZiS"; 12 | 13 | const DEFAULT_BENCHMARK_TREE_SIZE: usize = 2000; 14 | 15 | fn bulk_load_baseline(c: &mut Criterion) { 16 | c.bench_function("Bulk load baseline", move |b| { 17 | let mut points: Vec<_> = 18 | create_random_point_type_points(DEFAULT_BENCHMARK_TREE_SIZE, SEED_1); 19 | let mut map = LearnedHashMap::, f64>::new(); 20 | 21 | b.iter(|| { 22 | map.batch_insert(&mut points).unwrap(); 23 | }); 24 | }) 25 | .bench_function("Bulk load baseline with f32", move |b| { 26 | let mut points: Vec<_> = 27 | create_random_point_type_points_f32(DEFAULT_BENCHMARK_TREE_SIZE, SEED_1); 28 | let mut map = LearnedHashMap::, f32>::new(); 29 | 30 | b.iter(|| { 31 | map.batch_insert(&mut points).unwrap(); 32 | }); 33 | }); 34 | } 35 | 36 | fn locate_successful(c: &mut Criterion) { 37 | let mut points: Vec<_> = create_random_point_type_points(100_000, SEED_1); 38 | let mut points_f32: Vec<_> = create_random_point_type_points_f32(100_000, SEED_1); 39 | let query_point = [points[500].x(), points[500].y()]; 40 | let query_point_f32 = [points_f32[500].x(), points_f32[500].y()]; 41 | 42 | let mut map = LearnedHashMap::, f64>::new(); 43 | let mut map_f32 = LearnedHashMap::, f32>::new(); 44 | map.batch_insert(&mut points).unwrap(); 45 | map_f32.batch_insert(&mut points_f32).unwrap(); 46 | c.bench_function("locate_at_point (successful)", move |b| { 47 | b.iter(|| map.get(&query_point).is_some()) 48 | }) 49 | .bench_function("locate_at_point_f32 (successful)", move |b| { 50 | b.iter(|| map_f32.get(&query_point_f32).is_some()) 51 | }); 52 | } 53 | 54 | fn locate_unsuccessful(c: &mut Criterion) { 55 | let mut points: Vec<_> = create_random_point_type_points(100_000, SEED_1); 56 | let mut points_f32: Vec<_> = create_random_point_type_points_f32(100_000, SEED_1); 57 | let query_point: [f64; 2] = [0.7, 0.7]; 58 | let query_point_f32: [f32; 2] = [0.7, 0.7]; 59 | 60 | let mut map = LearnedHashMap::, f64>::new(); 61 | let mut map_f32 = LearnedHashMap::, f32>::new(); 62 | map.batch_insert(&mut points).unwrap(); 63 | map_f32.batch_insert(&mut points_f32).unwrap(); 64 | c.bench_function("locate_at_point (unsuccessful)", move |b| { 65 | b.iter(|| map.get(&query_point).is_none()) 66 | }) 67 | .bench_function("locate_at_point_f32 (unsuccessful)", move |b| { 68 | b.iter(|| map_f32.get(&query_point_f32).is_none()) 69 | }); 70 | } 71 | 72 | fn nearest_neighbor(c: &mut Criterion) { 73 | const SIZE: usize = 100_000; 74 | let mut points: Vec<_> = create_random_point_type_points(SIZE, SEED_1); 75 | let query_points = create_random_points(100, SEED_2); 76 | 77 | let mut map = LearnedHashMap::, f64>::new(); 78 | map.batch_insert(&mut points).unwrap(); 79 | 80 | c.bench_function("nearest_neigbor", move |b| { 81 | b.iter(|| { 82 | for query_point in &query_points { 83 | map.nearest_neighbor(&query_point).unwrap(); 84 | } 85 | }); 86 | }); 87 | } 88 | 89 | fn radius_range(c: &mut Criterion) { 90 | const SIZE: usize = 100_000; 91 | let mut points: Vec<_> = create_random_point_type_points(SIZE, SEED_1); 92 | let query_points = create_random_points(100, SEED_2); 93 | 94 | let mut map = LearnedHashMap::, f64>::new(); 95 | map.batch_insert(&mut points).unwrap(); 96 | 97 | let radiuses = vec![0.01, 0.1, 0.2]; 98 | for radius in radiuses { 99 | let title = format!("radius_range_{radius}"); 100 | c.bench_function(title.as_str(), |b| { 101 | b.iter(|| { 102 | for query_point in &query_points { 103 | map.radius_range(&query_point, radius).unwrap(); 104 | } 105 | }); 106 | }); 107 | } 108 | } 109 | 110 | criterion_group!( 111 | benches, 112 | bulk_load_baseline, 113 | locate_successful, 114 | locate_unsuccessful, 115 | radius_range, 116 | nearest_neighbor, 117 | ); 118 | criterion_main!(benches); 119 | 120 | fn create_random_points(num_points: usize, seed: &[u8; 32]) -> Vec<[f64; 2]> { 121 | let mut result = Vec::with_capacity(num_points); 122 | let mut rng = Hc128Rng::from_seed(*seed); 123 | for _ in 0..num_points { 124 | result.push([rng.gen(), rng.gen()]); 125 | } 126 | result 127 | } 128 | 129 | fn create_random_point_type_points(num_points: usize, seed: &[u8; 32]) -> Vec> { 130 | let result = create_random_points(num_points, seed); 131 | 132 | // result.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); 133 | result 134 | .into_iter() 135 | .map(|[x, y]| Point::new(x, y)) 136 | .collect::>() 137 | } 138 | 139 | fn create_random_point_type_points_f32(num_points: usize, seed: &[u8; 32]) -> Vec> { 140 | let result = create_random_points(num_points, seed); 141 | 142 | // result.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); 143 | result 144 | .into_iter() 145 | .map(|[x, y]| Point::new(x as f32, y as f32)) 146 | .collect::>() 147 | } 148 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | /// The kinds of errors that can occur when calculating a linear regression. 2 | #[derive(Copy, Clone, Debug, PartialEq)] 3 | pub enum Error { 4 | /// The slope is too steep to represent, approaching infinity. 5 | SteepSlope, 6 | 7 | /// Different input lenses 8 | DiffLen, 9 | 10 | /// Input was empty 11 | EmptyVal, 12 | } 13 | -------------------------------------------------------------------------------- /src/geometry/axis.rs: -------------------------------------------------------------------------------- 1 | /// Axis defines axis direction in Cartesian coordinate system 2 | #[derive(Debug, Clone)] 3 | pub enum Axis { 4 | X, 5 | Y, 6 | Z, 7 | } 8 | -------------------------------------------------------------------------------- /src/geometry/distance.rs: -------------------------------------------------------------------------------- 1 | use crate::geometry::Point; 2 | use core::marker::PhantomData; 3 | use num_traits::float::Float; 4 | 5 | /// Distance trait for measuring the distance between two points 6 | pub trait Distance { 7 | type F; 8 | /// Distance between two points in tuple format 9 | fn distance(a: &[Self::F; 2], b: &[Self::F; 2]) -> Self::F; 10 | /// Distance between two points in points format 11 | fn distance_point(a: &Point, b: &Point) -> Self::F; 12 | } 13 | 14 | /// Euclidean Distance 15 | pub struct Euclidean { 16 | _marker: PhantomData, 17 | } 18 | 19 | impl Distance for Euclidean 20 | where 21 | F: Float, 22 | { 23 | type F = F; 24 | fn distance(a: &[F; 2], b: &[F; 2]) -> F { 25 | F::sqrt((a[0] - b[0]).powi(2) + (a[1] - b[1]).powi(2)) 26 | } 27 | 28 | fn distance_point(a: &Point, b: &Point) -> Self::F { 29 | Self::distance(&[a.x, a.y], &[b.x, b.y]) 30 | } 31 | } 32 | 33 | /// Manhattan Distance 34 | pub struct Manhattan { 35 | _marker: PhantomData, 36 | } 37 | 38 | impl Distance for Manhattan 39 | where 40 | F: Float, 41 | { 42 | type F = F; 43 | fn distance(a: &[F; 2], b: &[F; 2]) -> F { 44 | (a[0] - b[0]).abs() + (a[1] - b[1]).abs() 45 | } 46 | 47 | fn distance_point(a: &Point, b: &Point) -> Self::F { 48 | Self::distance(&[a.x, a.y], &[b.x, b.y]) 49 | } 50 | } 51 | 52 | #[cfg(test)] 53 | mod tests { 54 | use super::*; 55 | 56 | #[test] 57 | fn test_euclidean_f32() { 58 | let a = Point:: { x: 0., y: 0. }; 59 | let b = Point:: { x: 1., y: 1. }; 60 | let d = Euclidean::distance_point(&a, &b); 61 | assert_delta!(d, 1.4142135, 0.00001); 62 | } 63 | 64 | #[test] 65 | fn test_euclidean_f64() { 66 | let a = Point:: { x: 0., y: 0. }; 67 | let b = Point:: { x: 1., y: 1. }; 68 | let d = Euclidean::distance_point(&a, &b); 69 | assert_delta!(d, 1.4142135, 0.00001); 70 | } 71 | 72 | #[test] 73 | fn test_manhattan_f32() { 74 | let a = Point:: { x: 0., y: 0. }; 75 | let b = Point:: { x: 1., y: 1. }; 76 | let d = Manhattan::distance_point(&a, &b); 77 | assert_delta!(d, 2., 0.00001); 78 | } 79 | 80 | #[test] 81 | fn test_manhattan_f64() { 82 | let a = Point:: { x: 0., y: 0. }; 83 | let b = Point:: { x: 1., y: 1. }; 84 | let d = Manhattan::distance_point(&a, &b); 85 | assert_delta!(d, 2., 0.00001); 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /src/geometry/helper.rs: -------------------------------------------------------------------------------- 1 | use crate::geometry::Point; 2 | use num_traits::float::Float; 3 | 4 | /// Extract all the x values from a Vec> 5 | pub fn extract_x(ps: &[Point]) -> Vec { 6 | ps.iter().map(|p| p.x).collect() 7 | } 8 | 9 | /// Extract all the y values from a Vec> 10 | pub fn extract_y(ps: &[Point]) -> Vec { 11 | ps.iter().map(|p| p.y).collect() 12 | } 13 | 14 | /// Sort a Vec> based on the x values 15 | pub fn sort_by_x(ps: &mut [Point]) { 16 | ps.sort_by(|a, b| a.x.partial_cmp(&b.x).unwrap()); 17 | } 18 | 19 | /// Sort a Vec> based on the y values 20 | pub fn sort_by_y(ps: &mut [Point]) { 21 | ps.sort_by(|a, b| a.y.partial_cmp(&b.y).unwrap()); 22 | } 23 | 24 | /// Convert a Vec of [F; 2] to a Vec> 25 | pub fn convert_to_points(ps: &[[F; 2]]) -> Option>> { 26 | Some(ps.iter().map(|p| Point::new(p[0], p[1])).collect()) 27 | } 28 | -------------------------------------------------------------------------------- /src/geometry/mod.rs: -------------------------------------------------------------------------------- 1 | mod axis; 2 | pub mod distance; 3 | pub mod helper; 4 | mod point; 5 | 6 | pub use axis::*; 7 | pub use helper::*; 8 | pub use point::*; 9 | -------------------------------------------------------------------------------- /src/geometry/point.rs: -------------------------------------------------------------------------------- 1 | use num_traits::float::Float; 2 | 3 | /// Point struct contains id, x and y 4 | #[derive(Debug, Clone, Copy, Eq, PartialEq)] 5 | pub struct Point { 6 | pub(crate) x: T, 7 | pub(crate) y: T, 8 | } 9 | 10 | impl Default for Point 11 | where 12 | T: Float, 13 | { 14 | fn default() -> Self { 15 | Point { 16 | x: T::zero(), 17 | y: T::zero(), 18 | } 19 | } 20 | } 21 | 22 | impl Point 23 | where 24 | T: Float, 25 | { 26 | pub fn new(x: T, y: T) -> Self { 27 | Point { x, y } 28 | } 29 | 30 | pub fn x(&self) -> T { 31 | self.x 32 | } 33 | 34 | pub fn y(&self) -> T { 35 | self.y 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/hasher/mod.rs: -------------------------------------------------------------------------------- 1 | use crate::models::Model; 2 | use num_traits::cast::{AsPrimitive, FromPrimitive}; 3 | use num_traits::float::Float; 4 | 5 | /// LearnedHasher takes a model and produces hash from the model 6 | #[derive(Debug, Clone)] 7 | pub struct LearnedHasher { 8 | state: u64, 9 | pub model: M, 10 | sort_by_x: bool, 11 | } 12 | 13 | impl Default for LearnedHasher 14 | where 15 | F: Float, 16 | M: Model + Default, 17 | { 18 | #[inline] 19 | fn default() -> Self { 20 | Self { 21 | state: 0, 22 | model: Default::default(), 23 | sort_by_x: true, 24 | } 25 | } 26 | } 27 | impl LearnedHasher 28 | where 29 | F: Float, 30 | M: Model, 31 | { 32 | /// Returns a default LearnedHasher with Model and Float type. 33 | /// 34 | /// # Arguments 35 | /// * `model` - A model that implements Model trait 36 | /// 37 | /// # Examples 38 | /// 39 | /// ``` 40 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 41 | /// let hasher = LearnedHasher::with_model(LinearModel::::new()); 42 | /// ``` 43 | #[inline] 44 | pub fn with_model(model: M) -> Self { 45 | Self { 46 | state: 0, 47 | model, 48 | sort_by_x: true, 49 | } 50 | } 51 | 52 | /// Returns a current Hasher state. 53 | #[inline] 54 | fn finish(&self) -> u64 { 55 | self.state 56 | } 57 | 58 | /// Returns the sorted index base on parameter self.sort_by_x. 59 | #[inline] 60 | pub fn sort_by_x(&self) -> bool { 61 | self.sort_by_x 62 | } 63 | 64 | /// Sets self.sort_by_x to a given boolean value. 65 | #[inline] 66 | pub fn set_sort_by_x(&mut self, x: bool) { 67 | self.sort_by_x = x; 68 | } 69 | } 70 | 71 | impl LearnedHasher 72 | where 73 | F: Float, 74 | M: Model + Default, 75 | { 76 | /// Returns a default LearnedHasher. 77 | #[inline] 78 | pub fn new() -> Self { 79 | Self::default() 80 | } 81 | } 82 | 83 | impl LearnedHasher 84 | where 85 | F: Float + AsPrimitive, 86 | M: Model, 87 | { 88 | /// Writes a data into self.data by inferencing the input data into the trained model. 89 | #[inline] 90 | fn write(&mut self, data: &F) { 91 | self.state = self.model.predict(*data).floor().as_(); 92 | } 93 | } 94 | 95 | impl LearnedHasher 96 | where 97 | F: Float + FromPrimitive, 98 | M: Model + Default, 99 | { 100 | /// Unwrite takes a hash value, and unpredict the hash value to estimate the approximate input 101 | /// data from the target data(e.g. in linear regression, to get the x value for given y). 102 | /// 103 | /// # Arguments 104 | /// * `hash` - An usize hash value 105 | #[inline] 106 | fn unwrite(&mut self, hash: u64) -> F { 107 | let hash = FromPrimitive::from_u64(hash).unwrap(); 108 | self.model.unpredict(hash) 109 | } 110 | } 111 | /// Make hash value from a given hasher, returns a u64 hash value. 112 | /// 113 | /// # Arguments 114 | /// * `hasher` - A LearnedHasher type 115 | #[inline] 116 | pub fn make_hash(hasher: &mut LearnedHasher, p: &F) -> u64 117 | where 118 | F: Float + FromPrimitive + AsPrimitive, 119 | M: Model + Default, 120 | { 121 | hasher.write(p); 122 | hasher.finish() 123 | } 124 | 125 | /// Make hash value from a given hasher, and 2 item array with float data. 126 | /// 127 | /// # Arguments 128 | /// * `hasher` - A LearnedHasher type 129 | /// * `p` - Point data 130 | #[inline] 131 | pub fn make_hash_point(hasher: &mut LearnedHasher, p: &[F; 2]) -> u64 132 | where 133 | F: Float + FromPrimitive + AsPrimitive, 134 | M: Model + Default, 135 | { 136 | if hasher.sort_by_x { 137 | make_hash(hasher, &p[0]) 138 | } else { 139 | make_hash(hasher, &p[1]) 140 | } 141 | } 142 | 143 | /// Unmake hash value from a given hasher, and a u64 hash value. 144 | /// Reverse the hash function, which it takes a hash and returns float 145 | /// 146 | /// # Arguments 147 | /// * `hasher` - A LearnedHasher type 148 | /// * `p` - Point data 149 | #[inline] 150 | pub fn unhash(hasher: &mut LearnedHasher, hash: u64) -> F 151 | where 152 | F: Float + FromPrimitive + AsPrimitive, 153 | M: Model + Default, 154 | { 155 | hasher.unwrite(hash) 156 | } 157 | 158 | #[cfg(test)] 159 | mod tests { 160 | use super::LearnedHasher; 161 | use crate::models::LinearModel; 162 | 163 | #[test] 164 | fn hasher_with_empty_model() { 165 | let mut hasher: LearnedHasher> = LearnedHasher::new(); 166 | hasher.write(&10f64); 167 | assert_eq!(0u64, hasher.finish()); 168 | } 169 | 170 | #[test] 171 | fn unhash() { 172 | let mut hasher: LearnedHasher> = LearnedHasher::with_model(LinearModel { 173 | coefficient: 3., 174 | intercept: 2., 175 | }); 176 | hasher.write(&10.5); 177 | assert_eq!(33u64, hasher.finish()); 178 | assert_delta!(10.33f64, hasher.unwrite(33u64), 0.01); 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! This crate implements the Learned SPatial HashMap(LSPH), a high performance spatial index uses HashMap with learned model. 2 | //! 3 | //! The original paper of LSPH can be found [here]. 4 | //! 5 | //! [here]: https://minerva-access.unimelb.edu.au/items/beb5c0ee-2a8d-5bd2-b349-1190a335ef1a 6 | //! 7 | //! The LSPH uses a learned model such as a linear regression model as the hash 8 | //! function to predict the index in a hashmap. As a result, the learned model 9 | //! is more fitted to the data that stored in the hashmap, and reduces the 10 | //! chance of hashing collisions. Moreover, if the learned model is monotonic 11 | //! function(e.g. linear regression), the hash indexes are increasing as the 12 | //! input data increases. This property can be used to create a sorted order 13 | //! of buckets in a hashmap, which allow us to do range searches in a hashmap. 14 | //! 15 | //! The LSPH supports: 16 | //! - Point Query 17 | //! - Rectange Query 18 | //! - Radius Range Query 19 | //! - Nearest Neighbor Query 20 | //! 21 | //! Example: 22 | //! ``` 23 | //! use lsph::{LearnedHashMap, LinearModel}; 24 | //! let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 25 | //! let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 26 | //! 27 | //! assert_eq!(map.get(&[1., 1.]).is_some(), true); 28 | //! assert_eq!(map.get(&[3., 1.]).is_none(), true); 29 | //! assert_eq!(map.range_search(&[0., 0.], &[3., 3.]).is_some(), true); 30 | //! assert_eq!(map.radius_range(&[2., 1.], 1.).is_some(), true); 31 | //! assert_eq!(map.nearest_neighbor(&[2., 1.]).is_some(), true); 32 | //! 33 | //! ``` 34 | //! # License 35 | //! 36 | //! Licensed under either of 37 | //! 38 | //! - Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or ) 39 | //! - MIT license ([LICENSE-MIT](LICENSE-MIT) or ) 40 | //! 41 | //! at your option. 42 | 43 | #[macro_use] 44 | mod macros; 45 | mod error; 46 | pub mod geometry; 47 | pub mod hasher; 48 | pub mod map; 49 | pub mod models; 50 | #[cfg(test)] 51 | pub mod test_utilities; 52 | 53 | pub use geometry::*; 54 | pub use hasher::*; 55 | pub use map::*; 56 | pub use models::*; 57 | -------------------------------------------------------------------------------- /src/macros.rs: -------------------------------------------------------------------------------- 1 | extern crate num_traits; 2 | // Hepler macro for assert Float values within a delta range, panic if the 3 | // difference between two numbers is exceeds the given threshold. 4 | #[macro_export] 5 | macro_rules! assert_delta { 6 | ($x:expr, $y:expr, $delta:expr) => { 7 | assert!( 8 | ($x - $y).abs() <= num_traits::NumCast::from($delta).unwrap(), 9 | "{} is not within {} of {}", 10 | stringify!($x), 11 | $delta, 12 | stringify!($y) 13 | ); 14 | }; 15 | } 16 | 17 | #[macro_export] 18 | macro_rules! assert_eq_len { 19 | ($a:expr, $b:expr) => { 20 | if $a.len() != $b.len() { 21 | return Err(Error::DiffLen); 22 | } 23 | }; 24 | } 25 | 26 | #[macro_export] 27 | macro_rules! assert_empty { 28 | ($a:expr) => { 29 | if $a.is_empty() { 30 | return Err(Error::EmptyVal); 31 | } 32 | }; 33 | } 34 | -------------------------------------------------------------------------------- /src/map/mod.rs: -------------------------------------------------------------------------------- 1 | mod nn; 2 | mod table; 3 | 4 | use crate::{ 5 | error::*, 6 | geometry::{distance::*, Point}, 7 | hasher::*, 8 | map::{nn::*, table::*}, 9 | models::Model, 10 | }; 11 | use core::{fmt::Debug, iter::Sum, mem}; 12 | use num_traits::{ 13 | cast::{AsPrimitive, FromPrimitive}, 14 | float::Float, 15 | }; 16 | use std::collections::BinaryHeap; 17 | 18 | /// Initial bucket size is set to 1 19 | const INITIAL_NBUCKETS: usize = 1; 20 | 21 | /// LearnedHashMap takes a model instead of an hasher for hashing indexes in the table. 22 | /// 23 | /// Default Model for the LearndedHashMap is Linear regression. 24 | /// In order to build a ordered HashMap, we need to make sure that the model is **monotonic**. 25 | #[derive(Debug, Clone)] 26 | pub struct LearnedHashMap { 27 | hasher: LearnedHasher, 28 | table: Table>, 29 | items: usize, 30 | } 31 | 32 | /// Default for the LearndedHashMap. 33 | impl Default for LearnedHashMap 34 | where 35 | F: Float, 36 | M: Model + Default, 37 | { 38 | fn default() -> Self { 39 | Self { 40 | hasher: LearnedHasher::::new(), 41 | table: Table::new(), 42 | items: 0, 43 | } 44 | } 45 | } 46 | 47 | impl LearnedHashMap 48 | where 49 | F: Float + Default + AsPrimitive + FromPrimitive + Debug + Sum, 50 | M: Model + Default + Clone, 51 | { 52 | /// Returns a default LearnedHashMap with Model and Float type. 53 | /// 54 | /// # Examples 55 | /// 56 | /// ``` 57 | /// use lsph::{LearnedHashMap, LinearModel}; 58 | /// let map = LearnedHashMap::, f64>::new(); 59 | /// ``` 60 | #[inline] 61 | pub fn new() -> Self { 62 | Self::default() 63 | } 64 | 65 | /// Returns a default LearnedHashMap with Model and Float type. 66 | /// 67 | /// # Arguments 68 | /// * `hasher` - A LearnedHasher with model 69 | /// 70 | /// # Examples 71 | /// 72 | /// ``` 73 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 74 | /// let map = LearnedHashMap::, f64>::with_hasher(LearnedHasher::new()); 75 | /// ``` 76 | #[inline] 77 | pub fn with_hasher(hasher: LearnedHasher) -> Self { 78 | Self { 79 | hasher, 80 | table: Table::new(), 81 | items: 0, 82 | } 83 | } 84 | 85 | /// Returns a default LearnedHashMap with Model and Float type. 86 | /// 87 | /// # Arguments 88 | /// * `capacity` - A predefined capacity size for the LearnedHashMap 89 | /// 90 | /// # Examples 91 | /// 92 | /// ``` 93 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 94 | /// let map = LearnedHashMap::, f64>::with_capacity(10usize); 95 | /// ``` 96 | #[inline] 97 | pub fn with_capacity(capacity: usize) -> Self { 98 | Self { 99 | hasher: Default::default(), 100 | table: Table::with_capacity(capacity), 101 | items: 0, 102 | } 103 | } 104 | 105 | /// Returns a default LearnedHashMap with Model and Float type 106 | /// 107 | /// # Arguments 108 | /// * `data` - A Vec<[F; 2]> of 2d points for the map 109 | /// 110 | /// # Examples 111 | /// 112 | /// ``` 113 | /// use lsph::{LearnedHashMap, LinearModel}; 114 | /// let data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 115 | /// let map = LearnedHashMap::, f64>::with_data(&data); 116 | /// ``` 117 | #[inline] 118 | pub fn with_data(data: &[[F; 2]]) -> Result<(Self, Vec>), Error> { 119 | use crate::helper::convert_to_points; 120 | let mut map = LearnedHashMap::with_capacity(data.len()); 121 | let mut ps = convert_to_points(data).unwrap(); 122 | match map.batch_insert(&mut ps) { 123 | Ok(()) => Ok((map, ps)), 124 | Err(err) => Err(err), 125 | } 126 | } 127 | 128 | /// Returns Option> with given point data. 129 | /// 130 | /// # Arguments 131 | /// * `p` - A array slice containing two points for querying 132 | /// 133 | /// # Examples 134 | /// 135 | /// ``` 136 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 137 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 138 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 139 | /// 140 | /// assert_eq!(map.get(&[1., 1.]).is_some(), true); 141 | /// ``` 142 | #[inline] 143 | pub fn get(&mut self, p: &[F; 2]) -> Option<&Point> { 144 | let hash = make_hash_point(&mut self.hasher, p) as usize; 145 | if hash > self.table.capacity() { 146 | return None; 147 | } 148 | self.find_by_hash(hash, p) 149 | } 150 | 151 | /// Returns Option> by hash index, if it exists in the map. 152 | /// 153 | /// # Arguments 154 | /// * `hash` - An usize hash value 155 | /// * `p` - A array slice containing two points for querying 156 | /// 157 | /// # Examples 158 | /// 159 | /// ``` 160 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 161 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 162 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 163 | /// 164 | /// assert_eq!(map.find_by_hash(0, &[1., 1.]).is_some(), true); 165 | /// assert_eq!(map.find_by_hash(1, &[1., 1.]).is_none(), true); 166 | /// ``` 167 | pub fn find_by_hash(&self, hash: usize, p: &[F; 2]) -> Option<&Point> { 168 | self.table[hash] 169 | .iter() 170 | .find(|&ep| ep.x == p[0] && ep.y == p[1]) 171 | } 172 | 173 | /// Returns bool. 174 | /// 175 | /// # Arguments 176 | /// * `p` - A array slice containing two points for querying 177 | /// 178 | /// # Examples 179 | /// 180 | /// ``` 181 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 182 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 183 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 184 | /// 185 | /// assert_eq!(map.contains_points(&[1., 1.]), true); 186 | /// assert_eq!(map.contains_points(&[0., 1.]), false); 187 | /// ``` 188 | #[inline] 189 | pub fn contains_points(&mut self, p: &[F; 2]) -> bool { 190 | self.get(p).is_some() 191 | } 192 | 193 | /// Returns Option> if the map contains a point and successful remove it from the map. 194 | /// 195 | /// # Arguments 196 | /// * `p` - A Point data 197 | /// 198 | /// # Examples 199 | /// 200 | /// ``` 201 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 202 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 203 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 204 | /// 205 | /// let p = points[0]; 206 | /// assert_eq!(map.remove(&p).unwrap(), p); 207 | /// ``` 208 | #[inline] 209 | pub fn remove(&mut self, p: &Point) -> Option> { 210 | let hash = make_hash_point(&mut self.hasher, &[p.x, p.y]); 211 | self.items -= 1; 212 | self.table.remove_entry(hash, *p) 213 | } 214 | 215 | /// Returns usize length. 216 | /// 217 | /// # Examples 218 | /// 219 | /// ``` 220 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 221 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 222 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 223 | /// 224 | /// assert_eq!(map.len(), 4); 225 | /// ``` 226 | #[inline] 227 | pub fn len(&self) -> usize { 228 | self.table.len() 229 | } 230 | 231 | /// Returns usize number of items. 232 | /// 233 | /// # Examples 234 | /// 235 | /// ``` 236 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 237 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 238 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 239 | /// 240 | /// assert_eq!(map.items(), 4); 241 | /// ``` 242 | #[inline] 243 | pub fn items(&self) -> usize { 244 | self.items 245 | } 246 | 247 | /// Returns bool if the map is empty. 248 | /// 249 | /// # Examples 250 | /// 251 | /// ``` 252 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 253 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 254 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 255 | /// 256 | /// assert_eq!(map.is_empty(), false); 257 | /// ``` 258 | #[inline] 259 | pub fn is_empty(&self) -> bool { 260 | self.items == 0 261 | } 262 | 263 | /// Resize the map if needed, it will initialize the map to the INITIAL_NBUCKETS, otherwise it will double the capacity if table is not empty. 264 | fn resize(&mut self) { 265 | let target_size = match self.table.len() { 266 | 0 => INITIAL_NBUCKETS, 267 | n => 2 * n, 268 | }; 269 | self.resize_with_capacity(target_size); 270 | } 271 | 272 | /// Resize the map if needed, it will resize the map to desired capacity. 273 | #[inline] 274 | fn resize_with_capacity(&mut self, target_size: usize) { 275 | let mut new_table = Table::with_capacity(target_size); 276 | new_table.extend((0..target_size).map(|_| Bucket::new())); 277 | 278 | for p in self.table.iter_mut().flat_map(|bucket| bucket.drain(..)) { 279 | let hash = make_hash_point(&mut self.hasher, &[p.x, p.y]) as usize; 280 | new_table[hash].push(p); 281 | } 282 | 283 | self.table = new_table; 284 | } 285 | 286 | /// Rehash the map. 287 | #[inline] 288 | fn rehash(&mut self) -> Result<(), Error> { 289 | let mut old_data = Vec::with_capacity(self.items()); 290 | for p in self.table.iter_mut().flat_map(|bucket| bucket.drain(..)) { 291 | old_data.push(p); 292 | } 293 | self.batch_insert(&mut old_data) 294 | } 295 | 296 | /// Inner function for insert a single point into the map 297 | #[inline] 298 | fn insert_inner(&mut self, p: Point) -> Option> { 299 | // Resize if the table is empty or 3/4 size of the table is full 300 | if self.table.is_empty() || self.items() > 3 * self.table.len() / 4 { 301 | self.resize(); 302 | } 303 | let hash = make_hash_point::(&mut self.hasher, &[p.x, p.y]); 304 | self.insert_with_axis(p, hash) 305 | } 306 | 307 | /// Sequencial insert a point into the map. 308 | /// 309 | /// # Arguments 310 | /// * `p` - A Point with float number 311 | /// 312 | /// # Examples 313 | /// 314 | /// ``` 315 | /// use lsph::{LearnedHashMap, LinearModel, Point}; 316 | /// let a: Point = Point::new(0., 1.); 317 | /// let b: Point = Point::new(1., 0.); 318 | 319 | /// let mut map = LearnedHashMap::, f64>::new(); 320 | /// map.insert(a); 321 | /// map.insert(b); 322 | 323 | /// assert_eq!(map.items(), 2); 324 | /// assert_eq!(map.get(&[0., 1.]).unwrap(), &a); 325 | /// assert_eq!(map.get(&[1., 0.]).unwrap(), &b); 326 | /// ``` 327 | pub fn insert(&mut self, p: Point) -> Option> { 328 | // Resize if the table is empty or 3/4 size of the table is full 329 | if self.table.is_empty() || self.items() > 3 * self.table.len() / 4 { 330 | self.resize(); 331 | } 332 | 333 | let hash = make_hash_point::(&mut self.hasher, &[p.x, p.y]); 334 | // resize if hash index is larger or equal to the table capacity 335 | if hash >= self.table.capacity() as u64 { 336 | self.resize_with_capacity(hash as usize * 2); 337 | self.insert_with_axis(p, hash); 338 | match self.rehash() { 339 | Ok(_) => None, 340 | Err(err) => { 341 | eprintln!("{:?}", err); 342 | None 343 | } 344 | } 345 | } else { 346 | self.insert_with_axis(p, hash) 347 | } 348 | } 349 | 350 | /// Insert a point into the map along the given axis. 351 | /// 352 | /// # Arguments 353 | /// * `p_value` - A float number represent the key of a 2d point 354 | #[inline] 355 | fn insert_with_axis(&mut self, p: Point, hash: u64) -> Option> { 356 | let mut insert_index = 0; 357 | let bucket_index = self.table.bucket(hash); 358 | let bucket = &mut self.table[bucket_index]; 359 | if self.hasher.sort_by_x() { 360 | // Get index from the hasher 361 | for ep in bucket.iter_mut() { 362 | if ep == &mut p.clone() { 363 | return Some(mem::replace(ep, p)); 364 | } 365 | if ep.y < p.y() { 366 | insert_index += 1; 367 | } 368 | } 369 | } else { 370 | for ep in bucket.iter_mut() { 371 | if ep == &mut p.clone() { 372 | return Some(mem::replace(ep, p)); 373 | } 374 | if ep.x < p.x() { 375 | insert_index += 1; 376 | } 377 | } 378 | } 379 | bucket.insert(insert_index, p); 380 | self.items += 1; 381 | None 382 | } 383 | 384 | /// Fit the input data into the model of the hasher. Returns Error if error occurred during 385 | /// model fitting. 386 | /// 387 | /// # Arguments 388 | /// 389 | /// * `xs` - A list of tuple of floating number 390 | /// * `ys` - A list of tuple of floating number 391 | #[inline] 392 | pub fn model_fit(&mut self, xs: &[F], ys: &[F]) -> Result<(), Error> { 393 | self.hasher.model.fit(xs, ys) 394 | } 395 | 396 | /// Fit the input data into the model of the hasher. Returns Error if error occurred during 397 | /// model fitting. 398 | /// 399 | /// # Arguments 400 | /// * `data` - A list of tuple of floating number 401 | #[inline] 402 | pub fn model_fit_tuple(&mut self, data: &[(F, F)]) -> Result<(), Error> { 403 | self.hasher.model.fit_tuple(data) 404 | } 405 | 406 | /// Inner function for batch insert 407 | #[inline] 408 | fn batch_insert_inner(&mut self, ps: &[Point]) { 409 | // Allocate table capacity before insert 410 | let n = ps.len(); 411 | self.resize_with_capacity(n); 412 | for p in ps.iter() { 413 | self.insert_inner(*p); 414 | } 415 | } 416 | 417 | /// Batch insert a batch of 2d data into the map. 418 | /// 419 | /// # Arguments 420 | /// * `ps` - A list of point number 421 | /// 422 | /// # Examples 423 | /// 424 | /// ``` 425 | /// use lsph::{LearnedHashMap, LinearModel}; 426 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 427 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 428 | /// 429 | /// assert_eq!(map.get(&[1., 1.]).is_some(), true); 430 | /// ``` 431 | #[inline] 432 | pub fn batch_insert(&mut self, ps: &mut [Point]) -> Result<(), Error> { 433 | // Select suitable axis for training 434 | use crate::geometry::Axis; 435 | use crate::models::Trainer; 436 | 437 | // Loading data into trainer 438 | if let Ok(trainer) = Trainer::with_points(ps) { 439 | trainer.train(&mut self.hasher.model).unwrap(); 440 | let axis = trainer.axis(); 441 | match axis { 442 | Axis::X => self.hasher.set_sort_by_x(true), 443 | _ => self.hasher.set_sort_by_x(false), 444 | }; 445 | 446 | // Fit the data into model 447 | self.model_fit(trainer.train_x(), trainer.train_y()) 448 | .unwrap(); 449 | // Batch insert into the map 450 | self.batch_insert_inner(ps); 451 | } 452 | Ok(()) 453 | } 454 | 455 | /// Range search finds all points for a given 2d range 456 | /// Returns all the points within the given range 457 | /// ```text 458 | /// | top right 459 | /// | .-----------* 460 | /// | | . . | 461 | /// | | . . . | 462 | /// | | . | 463 | /// bottom left *-----------. 464 | /// | 465 | /// | | | 466 | /// |________v___________v________ 467 | /// left right 468 | /// hash hash 469 | /// ``` 470 | /// # Arguments 471 | /// 472 | /// * `bottom_left` - A tuple containing a pair of points that represent the bottom left of the 473 | /// range. 474 | /// 475 | /// * `top_right` - A tuple containing a pair of points that represent the top right of the 476 | /// range. 477 | #[inline] 478 | pub fn range_search( 479 | &mut self, 480 | bottom_left: &[F; 2], 481 | top_right: &[F; 2], 482 | ) -> Option>> { 483 | let mut right_hash = make_hash_point(&mut self.hasher, top_right) as usize; 484 | if right_hash > self.table.capacity() { 485 | right_hash = self.table.capacity() as usize - 1; 486 | } 487 | let left_hash = make_hash_point(&mut self.hasher, bottom_left) as usize; 488 | if left_hash > self.table.capacity() || left_hash > right_hash { 489 | return None; 490 | } 491 | let mut result: Vec> = Vec::new(); 492 | for i in left_hash..=right_hash { 493 | let bucket = &self.table[i]; 494 | for item in bucket.iter() { 495 | if item.x >= bottom_left[0] 496 | && item.x <= top_right[0] 497 | && item.y >= bottom_left[1] 498 | && item.y <= top_right[1] 499 | { 500 | result.push(*item); 501 | } 502 | } 503 | } 504 | if result.is_empty() { 505 | return None; 506 | } 507 | Some(result) 508 | } 509 | 510 | /// Returns Option>> if points are found in the map with given range 511 | /// 512 | /// # Arguments 513 | /// * `query_point` - A Point data for querying 514 | /// * `radius` - A radius value 515 | /// 516 | /// # Examples 517 | /// 518 | /// ``` 519 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 520 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 521 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 522 | /// assert_eq!(map.range_search(&[0., 0.], &[3., 3.]).is_some(), true); 523 | /// ``` 524 | #[inline] 525 | pub fn radius_range(&mut self, query_point: &[F; 2], radius: F) -> Option>> { 526 | self.range_search( 527 | &[query_point[0] - radius, query_point[1] - radius], 528 | &[query_point[0] + radius, query_point[1] + radius], 529 | ) 530 | } 531 | 532 | /// Find the local minimum distance between query points and cadidates neighbors, then store 533 | /// the cadidates neighbors in the min_heap. 534 | /// 535 | /// 536 | /// # Arguments 537 | /// * `heap` - mutable borrow of an BinaryHeap 538 | /// * `local_hash` - A hash index of local bucket 539 | /// * `query_point` - A Point data 540 | /// * `min_d` - minimum distance 541 | /// * `nearest_neighbor` - mutable borrow of an point data, which is the nearest neighbor at 542 | /// search index bucket 543 | #[inline] 544 | fn local_min_heap( 545 | &self, 546 | heap: &mut BinaryHeap>, 547 | local_hash: u64, 548 | query_point: &[F; 2], 549 | min_d: &mut F, 550 | nearest_neighbor: &mut Point, 551 | ) { 552 | let bucket = &self.table[local_hash as usize]; 553 | if !bucket.is_empty() { 554 | for p in bucket.iter() { 555 | let d = Euclidean::distance(query_point, &[p.x, p.y]); 556 | heap.push(NearestNeighborState { 557 | distance: d, 558 | point: *p, 559 | }); 560 | } 561 | } 562 | match heap.pop() { 563 | Some(v) => { 564 | let local_min_d = v.distance; 565 | // Update the nearest neighbour and minimum distance 566 | if &local_min_d < min_d { 567 | *nearest_neighbor = v.point; 568 | *min_d = local_min_d; 569 | } 570 | } 571 | None => (), 572 | } 573 | } 574 | 575 | /// Calculates the horizontal distance between query_point and bucket at index with given hash. 576 | /// 577 | /// # Arguments 578 | /// * `hash` - A hash index of the bucket 579 | /// * `query_point` - A Point data 580 | #[inline] 581 | fn horizontal_distance(&mut self, query_point: &[F; 2], hash: u64) -> F { 582 | let x = unhash(&mut self.hasher, hash); 583 | match self.hasher.sort_by_x() { 584 | true => Euclidean::distance(&[query_point[0], F::zero()], &[x, F::zero()]), 585 | false => Euclidean::distance(&[query_point[1], F::zero()], &[x, F::zero()]), 586 | } 587 | } 588 | 589 | /// Nearest neighbor search for the closest point for given query point 590 | /// Returns the closest point 591 | ///```text 592 | /// | 593 | /// | . 594 | /// | . | 595 | /// | |. | * . <- nearest neighbor 596 | /// | || | | .| 597 | /// | expand <--------> expand 598 | /// | left | right 599 | /// | | 600 | /// |_______________v_____________ 601 | /// query 602 | /// point 603 | ///``` 604 | /// # Arguments 605 | /// 606 | /// * `query_point` - A tuple containing a pair of points for querying 607 | /// 608 | /// # Examples 609 | /// 610 | /// ``` 611 | /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher}; 612 | /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 613 | /// let (mut map, points) = LearnedHashMap::, f64>::with_data(&point_data).unwrap(); 614 | /// assert_eq!(map.nearest_neighbor(&[2., 1.]).is_some(), true); 615 | /// ``` 616 | #[inline] 617 | pub fn nearest_neighbor(&mut self, query_point: &[F; 2]) -> Option> { 618 | let mut hash = make_hash_point(&mut self.hasher, query_point); 619 | let max_capacity = self.table.capacity() as u64; 620 | 621 | // if hash out of max bound, still search right most bucket 622 | if hash > max_capacity { 623 | hash = max_capacity - 1; 624 | } 625 | 626 | let mut heap = BinaryHeap::new(); 627 | let mut min_d = F::max_value(); 628 | let mut nearest_neighbor = Point::default(); 629 | 630 | // Searching at current hash index 631 | self.local_min_heap( 632 | &mut heap, 633 | hash, 634 | query_point, 635 | &mut min_d, 636 | &mut nearest_neighbor, 637 | ); 638 | 639 | // Measure left horizontal distance from current bucket to left hash bucket 640 | // left hash must >= 0 641 | let mut left_hash = hash.saturating_sub(1); 642 | // Unhash the left_hash, then calculate the vertical distance between 643 | // left hash point and query point 644 | let mut left_hash_d = self.horizontal_distance(query_point, left_hash); 645 | 646 | // Iterate over left 647 | while left_hash_d < min_d { 648 | self.local_min_heap( 649 | &mut heap, 650 | left_hash, 651 | query_point, 652 | &mut min_d, 653 | &mut nearest_neighbor, 654 | ); 655 | 656 | // break before update 657 | if left_hash == 0 { 658 | break; 659 | } 660 | 661 | // Update next right side bucket distance 662 | left_hash = left_hash.saturating_sub(1); 663 | left_hash_d = self.horizontal_distance(query_point, left_hash); 664 | } 665 | 666 | // Measure right vertical distance from current bucket to right hash bucket 667 | let mut right_hash = hash + 1; 668 | // Unhash the right_hash, then calculate the vertical distance between 669 | // right hash point and query point 670 | let mut right_hash_d = self.horizontal_distance(query_point, right_hash); 671 | 672 | // Iterate over right 673 | while right_hash_d < min_d { 674 | self.local_min_heap( 675 | &mut heap, 676 | right_hash, 677 | query_point, 678 | &mut min_d, 679 | &mut nearest_neighbor, 680 | ); 681 | 682 | // Move to next right bucket 683 | right_hash += 1; 684 | 685 | // break after update 686 | if right_hash == self.table.capacity() as u64 { 687 | break; 688 | } 689 | // Update next right side bucket distance 690 | right_hash_d = self.horizontal_distance(query_point, right_hash); 691 | } 692 | 693 | Some(nearest_neighbor) 694 | } 695 | } 696 | 697 | pub struct Iter<'a, M, F> 698 | where 699 | F: Float, 700 | M: Model + Default + Clone, 701 | { 702 | map: &'a LearnedHashMap, 703 | bucket: usize, 704 | at: usize, 705 | } 706 | 707 | impl<'a, M, F> Iterator for Iter<'a, M, F> 708 | where 709 | F: Float, 710 | M: Model + Default + Clone, 711 | { 712 | type Item = &'a Point; 713 | fn next(&mut self) -> Option { 714 | loop { 715 | match self.map.table.get(self.bucket) { 716 | Some(bucket) => { 717 | match bucket.get(self.at) { 718 | Some(p) => { 719 | // move along self.at and self.bucket 720 | self.at += 1; 721 | break Some(p); 722 | } 723 | None => { 724 | self.bucket += 1; 725 | self.at = 0; 726 | continue; 727 | } 728 | } 729 | } 730 | None => break None, 731 | } 732 | } 733 | } 734 | } 735 | 736 | impl<'a, M, F> IntoIterator for &'a LearnedHashMap 737 | where 738 | F: Float, 739 | M: Model + Default + Clone, 740 | { 741 | type Item = &'a Point; 742 | type IntoIter = Iter<'a, M, F>; 743 | fn into_iter(self) -> Self::IntoIter { 744 | Iter { 745 | map: self, 746 | bucket: 0, 747 | at: 0, 748 | } 749 | } 750 | } 751 | 752 | pub struct IntoIter 753 | where 754 | F: Float, 755 | M: Model + Default + Clone, 756 | { 757 | map: LearnedHashMap, 758 | bucket: usize, 759 | } 760 | 761 | impl Iterator for IntoIter 762 | where 763 | F: Float, 764 | M: Model + Default + Clone, 765 | { 766 | type Item = Point; 767 | fn next(&mut self) -> Option { 768 | loop { 769 | match self.map.table.get_mut(self.bucket) { 770 | Some(bucket) => match bucket.pop() { 771 | Some(x) => break Some(x), 772 | None => { 773 | self.bucket += 1; 774 | continue; 775 | } 776 | }, 777 | None => break None, 778 | } 779 | } 780 | } 781 | } 782 | 783 | impl IntoIterator for LearnedHashMap 784 | where 785 | F: Float, 786 | M: Model + Default + Clone, 787 | { 788 | type Item = Point; 789 | type IntoIter = IntoIter; 790 | fn into_iter(self) -> Self::IntoIter { 791 | IntoIter { 792 | map: self, 793 | bucket: 0, 794 | } 795 | } 796 | } 797 | 798 | #[cfg(test)] 799 | mod tests { 800 | use super::*; 801 | use crate::geometry::Point; 802 | use crate::models::LinearModel; 803 | use crate::test_utilities::*; 804 | 805 | #[test] 806 | fn insert() { 807 | let a: Point = Point::new(0., 1.); 808 | let b: Point = Point::new(1., 0.); 809 | 810 | let mut map = LearnedHashMap::, f64>::new(); 811 | map.insert(a); 812 | map.insert(b); 813 | 814 | assert_eq!(map.items(), 2); 815 | assert_eq!(map.get(&[0., 1.]).unwrap(), &a); 816 | assert_eq!(map.get(&[1., 0.]).unwrap(), &b); 817 | } 818 | 819 | #[test] 820 | fn insert_repeated() { 821 | let mut map = LearnedHashMap::, f64>::new(); 822 | let a: Point = Point::new(0., 1.); 823 | let b: Point = Point::new(1., 0.); 824 | let res = map.insert(a); 825 | assert_eq!(map.items(), 1); 826 | assert_eq!(res, None); 827 | 828 | let res = map.insert(b); 829 | assert_eq!(map.items(), 2); 830 | assert_eq!(res, None); 831 | } 832 | 833 | #[test] 834 | fn with_data() { 835 | let data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]]; 836 | let (mut map, _points) = LearnedHashMap::, f64>::with_data(&data).unwrap(); 837 | assert_eq!(map.get(&[1., 1.]).is_some(), true); 838 | } 839 | 840 | #[test] 841 | fn fit_batch_insert() { 842 | let mut data: Vec> = vec![ 843 | Point::new(1., 1.), 844 | Point::new(3., 1.), 845 | Point::new(2., 1.), 846 | Point::new(3., 2.), 847 | Point::new(5., 1.), 848 | ]; 849 | let mut map = LearnedHashMap::, f64>::new(); 850 | map.batch_insert(&mut data).unwrap(); 851 | dbg!(&map); 852 | 853 | assert_delta!(1.02272, map.hasher.model.coefficient, 0.00001); 854 | assert_delta!(-0.86363, map.hasher.model.intercept, 0.00001); 855 | assert_eq!(Some(&Point::new(1., 1.)), map.get(&[1., 1.])); 856 | assert_eq!(Some(&Point::new(3., 1.,)), map.get(&[3., 1.])); 857 | assert_eq!(Some(&Point::new(5., 1.)), map.get(&[5., 1.])); 858 | 859 | assert_eq!(None, map.get(&[5., 2.])); 860 | assert_eq!(None, map.get(&[2., 2.])); 861 | assert_eq!(None, map.get(&[50., 10.])); 862 | assert_eq!(None, map.get(&[500., 100.])); 863 | } 864 | 865 | #[test] 866 | fn insert_after_batch_insert() { 867 | let mut data: Vec> = vec![ 868 | Point::new(1., 1.), 869 | Point::new(3., 1.), 870 | Point::new(2., 1.), 871 | Point::new(3., 2.), 872 | Point::new(5., 1.), 873 | ]; 874 | let mut map = LearnedHashMap::, f64>::new(); 875 | map.batch_insert(&mut data).unwrap(); 876 | dbg!(&map); 877 | 878 | let a: Point = Point::new(10., 10.); 879 | map.insert(a.clone()); 880 | assert_eq!(Some(&a), map.get(&[10., 10.])); 881 | 882 | let b: Point = Point::new(100., 100.); 883 | map.insert(b.clone()); 884 | assert_eq!(Some(&b), map.get(&[100., 100.])); 885 | assert_eq!(None, map.get(&[100., 101.])); 886 | } 887 | 888 | #[test] 889 | fn range_search() { 890 | let mut data: Vec> = vec![ 891 | Point::new(1., 1.), 892 | Point::new(2., 2.), 893 | Point::new(3., 3.), 894 | Point::new(4., 4.), 895 | Point::new(5., 5.), 896 | ]; 897 | let mut map = LearnedHashMap::, f64>::new(); 898 | map.batch_insert(&mut data).unwrap(); 899 | // dbg!(&map); 900 | 901 | let found: Vec> = 902 | vec![Point::new(1., 1.), Point::new(2., 2.), Point::new(3., 3.)]; 903 | 904 | assert_eq!(Some(found), map.range_search(&[1., 1.], &[3.5, 3.])); 905 | 906 | let found: Vec> = vec![Point::new(1., 1.)]; 907 | 908 | assert_eq!(Some(found), map.range_search(&[1., 1.], &[3., 1.])); 909 | assert_eq!(None, map.range_search(&[4., 2.], &[5., 3.])); 910 | } 911 | 912 | #[test] 913 | fn test_nearest_neighbor() { 914 | let points = create_random_point_type_points(1000, SEED_1); 915 | let mut map = LearnedHashMap::, f64>::new(); 916 | map.batch_insert(&mut points.clone()).unwrap(); 917 | 918 | let sample_points = create_random_point_type_points(100, SEED_2); 919 | let mut i = 0; 920 | for sample_point in &sample_points { 921 | let mut nearest = None; 922 | let mut closest_dist = ::core::f64::INFINITY; 923 | for point in &points { 924 | let new_dist = Euclidean::distance_point(&point, &sample_point); 925 | if new_dist < closest_dist { 926 | closest_dist = new_dist; 927 | nearest = Some(point); 928 | } 929 | } 930 | let map_nearest = map 931 | .nearest_neighbor(&[sample_point.x, sample_point.y]) 932 | .unwrap(); 933 | assert_eq!(nearest.unwrap(), &map_nearest); 934 | i = i + 1; 935 | } 936 | } 937 | } 938 | -------------------------------------------------------------------------------- /src/map/nn.rs: -------------------------------------------------------------------------------- 1 | use crate::geometry::Point; 2 | use num_traits::float::Float; 3 | use std::cmp::Ordering; 4 | 5 | /// State for store nearest neighbors distances and points in min_heap 6 | #[derive(Copy, Clone, PartialEq)] 7 | pub struct NearestNeighborState 8 | where 9 | F: Float, 10 | { 11 | pub distance: F, 12 | pub point: Point, 13 | } 14 | 15 | impl Eq for NearestNeighborState {} 16 | 17 | impl PartialOrd for NearestNeighborState 18 | where 19 | F: Float, 20 | { 21 | fn partial_cmp(&self, other: &Self) -> Option { 22 | // We flip the ordering on distance, so the queue becomes a min-heap 23 | other.distance.partial_cmp(&self.distance) 24 | } 25 | } 26 | 27 | impl Ord for NearestNeighborState 28 | where 29 | F: Float, 30 | { 31 | fn cmp(&self, other: &Self) -> Ordering { 32 | self.partial_cmp(other).unwrap() 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/map/table.rs: -------------------------------------------------------------------------------- 1 | use core::ops::{Deref, DerefMut}; 2 | use smallvec::SmallVec; 3 | 4 | /// Bucket is the lower unit in the HashMap to store the points 5 | #[derive(Debug, Clone)] 6 | pub(crate) struct Bucket { 7 | entry: SmallVec<[V; 6]>, 8 | } 9 | 10 | impl Bucket { 11 | /// Returns a default Bucket with value type. 12 | #[inline] 13 | pub fn new() -> Self { 14 | Self { 15 | entry: SmallVec::new(), 16 | } 17 | } 18 | 19 | /// Removes an element from the Bucket and returns it. 20 | /// The removed element is replaced by the last element of the Bucket. 21 | #[inline] 22 | pub fn swap_remove(&mut self, index: usize) -> V { 23 | self.entry.swap_remove(index) 24 | } 25 | } 26 | 27 | impl Deref for Bucket { 28 | type Target = SmallVec<[V; 6]>; 29 | fn deref(&self) -> &Self::Target { 30 | &self.entry 31 | } 32 | } 33 | 34 | impl DerefMut for Bucket { 35 | fn deref_mut(&mut self) -> &mut Self::Target { 36 | &mut self.entry 37 | } 38 | } 39 | 40 | /// Table containing a Vec of Bucket to store the values 41 | #[derive(Debug, Clone)] 42 | pub(crate) struct Table { 43 | buckets: Vec>, 44 | } 45 | 46 | impl Table { 47 | /// Returns a default Table with empty Vec. 48 | #[inline] 49 | pub fn new() -> Self { 50 | Self { 51 | buckets: Vec::new(), 52 | } 53 | } 54 | 55 | /// Returns a default Table with Vec that with the given capacity. 56 | /// 57 | /// # Arguments 58 | /// * `capacity` - A capacity size for the Table 59 | #[inline] 60 | pub fn with_capacity(capacity: usize) -> Self { 61 | Self { 62 | buckets: Vec::with_capacity(capacity), 63 | } 64 | } 65 | 66 | /// Returns the capacity of the Table. 67 | #[inline] 68 | pub fn capacity(&self) -> usize { 69 | self.buckets.capacity() 70 | } 71 | 72 | /// Returns the Bucket with given hash value. 73 | /// 74 | /// # Arguments 75 | /// * `hash` - A hash value for indexing the bucket in the table 76 | #[inline] 77 | pub fn bucket(&self, hash: u64) -> usize { 78 | hash as usize % self.buckets.len() 79 | } 80 | } 81 | impl Table 82 | where 83 | V: PartialEq, 84 | { 85 | /// Remove entry with given hash value and entry. 86 | /// 87 | /// # Arguments 88 | /// * `hash` - A hash value for indexing the bucket in the table 89 | /// * `entry` - Entry to remove 90 | #[inline] 91 | pub fn remove_entry(&mut self, hash: u64, entry: V) -> Option { 92 | let index = self.bucket(hash); 93 | let bucket = &mut self.buckets[index]; 94 | let i = bucket.iter().position(|ek| ek == &entry)?; 95 | Some(bucket.swap_remove(i)) 96 | } 97 | } 98 | 99 | impl Deref for Table { 100 | type Target = Vec>; 101 | fn deref(&self) -> &Self::Target { 102 | &self.buckets 103 | } 104 | } 105 | 106 | impl DerefMut for Table { 107 | fn deref_mut(&mut self) -> &mut Self::Target { 108 | &mut self.buckets 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /src/models/linear.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | error::*, 3 | models::{stats::root_mean_squared_error, Model}, 4 | }; 5 | 6 | use core::fmt::Debug; 7 | use core::iter::Sum; 8 | use num_traits::{cast::FromPrimitive, float::Float}; 9 | 10 | /// Simple linear regression from tuples. 11 | /// 12 | /// Calculates the simple linear regression from array of tuples, and their means. 13 | /// 14 | /// # Arguments 15 | /// 16 | /// * `xys` - An array of tuples of training data that contains Xs and Ys. 17 | /// 18 | /// * `x_mean` - The mean of Xs training data. 19 | /// 20 | /// * `y_mean` - The mean of Ys target values. 21 | /// 22 | /// Returns `Ok(slope, intercept)` or Err(Error). 23 | /// 24 | /// # Errors 25 | /// 26 | /// Returns an error if 27 | /// 28 | /// * `xs` and `ys` differ in length 29 | /// * `xs` or `ys` are empty 30 | /// * the slope is too steep to represent, approaching infinity 31 | /// * the number of elements cannot be represented as an `F` 32 | fn slr(xys: I, x_mean: F, y_mean: F) -> Result<(F, F), Error> 33 | where 34 | I: Iterator, 35 | F: Float + Debug, 36 | { 37 | // compute the covariance of x and y as well as the variance of x 38 | let (sq_diff_sum, cov_diff_sum) = xys.fold((F::zero(), F::zero()), |(v, c), (x, y)| { 39 | let diff = x - x_mean; 40 | let sq_diff = diff * diff; 41 | let cov_diff = diff * (y - y_mean); 42 | (v + sq_diff, c + cov_diff) 43 | }); 44 | let slope = cov_diff_sum / sq_diff_sum; 45 | if slope.is_nan() { 46 | return Err(Error::SteepSlope); 47 | } 48 | let intercept = y_mean - slope * x_mean; 49 | Ok((slope, intercept)) 50 | } 51 | 52 | /// Two-pass simple linear regression from slices. 53 | /// 54 | /// Calculates the linear regression from two slices, one for x- and one for y-values, by 55 | /// calculating the mean and then calling `lin_reg`. 56 | /// 57 | /// # Arguments 58 | /// 59 | /// * `xs` - An array of tuples of training data. 60 | /// 61 | /// * `ys` - An array of tuples of targeting data. 62 | /// 63 | /// Returns `Ok(slope, intercept)` of the regression line. 64 | /// 65 | /// # Errors 66 | /// 67 | /// Returns an error if 68 | /// 69 | /// * `xs` and `ys` differ in length 70 | /// * `xs` or `ys` are empty 71 | /// * the slope is too steep to represent, approaching infinity 72 | /// * the number of elements cannot be represented as an `F` 73 | fn linear_regression(xs: &[X], ys: &[Y]) -> Result<(F, F), Error> 74 | where 75 | X: Clone + Into, 76 | Y: Clone + Into, 77 | F: Float + Sum + Debug, 78 | { 79 | assert_empty!(xs); 80 | assert_empty!(ys); 81 | assert_eq_len!(xs, ys); 82 | 83 | let n = F::from(xs.len()).ok_or(Error::EmptyVal)?; 84 | 85 | // compute the mean of x and y 86 | let x_sum: F = xs.iter().cloned().map(Into::into).sum(); 87 | let x_mean = x_sum / n; 88 | let y_sum: F = ys.iter().cloned().map(Into::into).sum(); 89 | let y_mean = y_sum / n; 90 | 91 | let data = xs 92 | .iter() 93 | .zip(ys.iter()) 94 | .map(|(x, y)| (x.clone().into(), y.clone().into())); 95 | 96 | slr(data, x_mean, y_mean) 97 | } 98 | 99 | /// Two-pass linear regression from tuples. 100 | /// 101 | /// Calculates the linear regression from a slice of tuple values by first calculating the mean 102 | /// before calling `lin_reg`. 103 | /// 104 | /// Returns `Ok(slope, intercept)` of the regression line. 105 | /// 106 | /// # Errors 107 | /// 108 | /// Returns an error if 109 | /// 110 | /// * `xys` is empty 111 | /// * the slope is too steep to represent, approaching infinity 112 | /// * the number of elements cannot be represented as an `F` 113 | fn linear_regression_tuple(xys: &[(X, Y)]) -> Result<(F, F), Error> 114 | where 115 | X: Clone + Into, 116 | Y: Clone + Into, 117 | F: Float + Debug, 118 | { 119 | assert_empty!(xys); 120 | 121 | // We're handrolling the mean computation here, because our generic implementation can't handle tuples. 122 | // If we ran the generic impl on each tuple field, that would be very cache inefficient 123 | let n = F::from(xys.len()).ok_or(Error::EmptyVal)?; 124 | let (x_sum, y_sum) = xys 125 | .iter() 126 | .cloned() 127 | .fold((F::zero(), F::zero()), |(sx, sy), (x, y)| { 128 | (sx + x.into(), sy + y.into()) 129 | }); 130 | let x_mean = x_sum / n; 131 | let y_mean = y_sum / n; 132 | 133 | slr( 134 | xys.iter() 135 | .map(|(x, y)| (x.clone().into(), y.clone().into())), 136 | x_mean, 137 | y_mean, 138 | ) 139 | } 140 | 141 | /// Linear regression model 142 | #[derive(Copy, Clone, Debug, Default)] 143 | pub struct LinearModel { 144 | pub coefficient: F, 145 | pub intercept: F, 146 | } 147 | 148 | impl LinearModel 149 | where 150 | F: Float, 151 | { 152 | pub fn new() -> LinearModel { 153 | LinearModel { 154 | coefficient: F::zero(), 155 | intercept: F::zero(), 156 | } 157 | } 158 | } 159 | 160 | impl Model for LinearModel 161 | where 162 | F: Float + FromPrimitive + Sum + Debug + Sized, 163 | { 164 | type F = F; 165 | 166 | fn name(&self) -> String { 167 | String::from("linear") 168 | } 169 | 170 | fn fit(&mut self, xs: &[F], ys: &[F]) -> Result<(), Error> { 171 | let (coefficient, intercept): (F, F) = linear_regression(xs, ys).unwrap(); 172 | self.coefficient = coefficient; 173 | self.intercept = intercept; 174 | Ok(()) 175 | } 176 | fn fit_tuple(&mut self, xys: &[(F, F)]) -> Result<(), Error> { 177 | let (coefficient, intercept): (F, F) = linear_regression_tuple(xys).unwrap(); 178 | self.coefficient = coefficient; 179 | self.intercept = intercept; 180 | Ok(()) 181 | } 182 | 183 | fn predict(&self, x: F) -> F { 184 | x * self.coefficient + self.intercept 185 | } 186 | fn batch_predict(&self, xs: &[F]) -> Vec { 187 | (0..xs.len()).map(|i| self.predict(xs[i])).collect() 188 | } 189 | 190 | fn evaluate(&self, x_test: &[F], y_test: &[F]) -> F { 191 | let y_predicted = self.batch_predict(x_test); 192 | root_mean_squared_error(y_test, &y_predicted) 193 | } 194 | 195 | fn unpredict(&self, y: F) -> F { 196 | (y - self.intercept) / self.coefficient 197 | } 198 | } 199 | 200 | #[cfg(test)] 201 | mod tests { 202 | use super::*; 203 | 204 | #[test] 205 | #[should_panic] 206 | fn should_panic_for_empty_vecs() { 207 | let x_values: Vec = vec![]; 208 | let y_values: Vec = vec![]; 209 | let mut model = LinearModel::new(); 210 | model.fit(&x_values, &y_values).unwrap(); 211 | 212 | assert_delta!(0.8f64, model.coefficient, 0.00001); 213 | assert_delta!(0.4, model.intercept, 0.00001); 214 | } 215 | 216 | #[test] 217 | fn fit_coefficients() { 218 | let x_values: Vec = vec![1., 2., 3., 4., 5.]; 219 | let y_values: Vec = vec![1., 3., 2., 3., 5.]; 220 | let mut model = LinearModel::new(); 221 | model.fit(&x_values, &y_values).unwrap(); 222 | 223 | assert_delta!(0.8f64, model.coefficient, 0.00001); 224 | assert_delta!(0.4f64, model.intercept, 0.00001); 225 | } 226 | 227 | #[test] 228 | fn fit_coefficients_f32() { 229 | let x_values: Vec = vec![1., 2., 3., 4., 5.]; 230 | let y_values: Vec = vec![1., 3., 2., 3., 5.]; 231 | let mut model = LinearModel::new(); 232 | model.fit(&x_values, &y_values).unwrap(); 233 | 234 | assert_delta!(0.8, model.coefficient, 0.00001); 235 | assert_delta!(0.4, model.intercept, 0.00001); 236 | } 237 | 238 | #[test] 239 | fn predict() { 240 | let x_values = vec![1f64, 2f64, 3f64, 4f64, 5f64]; 241 | let y_values = vec![1f64, 3f64, 2f64, 3f64, 5f64]; 242 | let mut model = LinearModel::new(); 243 | model.fit(&x_values, &y_values).unwrap(); 244 | 245 | assert_delta!(1.2f64, model.predict(1f64), 0.00001); 246 | assert_delta!(2f64, model.predict(2f64), 0.00001); 247 | assert_delta!(2.8f64, model.predict(3f64), 0.00001); 248 | assert_delta!(3.6f64, model.predict(4f64), 0.00001); 249 | assert_delta!(4.4f64, model.predict(5f64), 0.00001); 250 | } 251 | 252 | #[test] 253 | fn predict_list() { 254 | let x_values = vec![1f64, 2f64, 3f64, 4f64, 5f64]; 255 | let y_values = vec![1f64, 3f64, 2f64, 3f64, 5f64]; 256 | let mut model = LinearModel::new(); 257 | model.fit(&x_values, &y_values).unwrap(); 258 | 259 | let predictions = model.batch_predict(&x_values); 260 | 261 | assert_delta!(1.2f64, predictions[0], 0.00001); 262 | assert_delta!(2f64, predictions[1], 0.00001); 263 | assert_delta!(2.8f64, predictions[2], 0.00001); 264 | assert_delta!(3.6f64, predictions[3], 0.00001); 265 | assert_delta!(4.4f64, predictions[4], 0.00001); 266 | } 267 | 268 | #[test] 269 | fn evaluate() { 270 | let x_values = vec![1f64, 2f64, 3f64, 4f64, 5f64]; 271 | let y_values = vec![1f64, 3f64, 2f64, 3f64, 5f64]; 272 | let mut model = LinearModel::new(); 273 | model.fit(&x_values.clone(), &y_values.clone()).unwrap(); 274 | 275 | let error = model.evaluate(&x_values, &y_values); 276 | assert_delta!(0.69282f64, error, 0.00001); 277 | } 278 | } 279 | -------------------------------------------------------------------------------- /src/models/mod.rs: -------------------------------------------------------------------------------- 1 | mod linear; 2 | mod stats; 3 | mod trainer; 4 | 5 | pub use linear::*; 6 | pub use stats::*; 7 | pub use trainer::*; 8 | 9 | use crate::error::Error; 10 | use core::fmt::Debug; 11 | use num_traits::float::Float; 12 | 13 | /// Model representation, provides common functionalities for model training 14 | pub trait Model { 15 | /// Associated type for float number representation 16 | type F; 17 | /// Prints the name of the model 18 | fn name(&self) -> String; 19 | /// Fit two slices of training data into the model 20 | fn fit(&mut self, xs: &[Self::F], ys: &[Self::F]) -> Result<(), Error>; 21 | /// Fit one slice of training data in tuple format into the model 22 | fn fit_tuple(&mut self, xys: &[(Self::F, Self::F)]) -> Result<(), Error>; 23 | /// Takes one value and returns the predictions of the model 24 | fn predict(&self, x: Self::F) -> Self::F; 25 | /// Takes slice of value and returns the batch predictions from the model 26 | fn batch_predict(&self, xs: &[Self::F]) -> Vec; 27 | /// Evaluate the predictions results from a pair of test sets 28 | fn evaluate(&self, x_test: &[Self::F], y_test: &[Self::F]) -> Self::F; 29 | /// Unpredict provides the ability of reversing the predict operation 30 | /// For a given target value, return the estimate input value 31 | fn unpredict(&self, y: Self::F) -> Self::F; 32 | } 33 | 34 | impl Debug for (dyn Model + 'static) 35 | where 36 | F: Float, 37 | { 38 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 39 | write!(f, "Model {{{}}}", self.name()) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /src/models/stats.rs: -------------------------------------------------------------------------------- 1 | use core::iter::Sum; 2 | use num_traits::{cast::FromPrimitive, float::Float}; 3 | 4 | /// Calculates mean of a slice of values 5 | pub fn mean(values: &[F]) -> F 6 | where 7 | F: Float + Sum, 8 | { 9 | if values.is_empty() { 10 | return F::zero(); 11 | } 12 | let sum: F = values.iter().cloned().map(Into::into).sum(); 13 | sum / F::from(values.len()).unwrap() 14 | } 15 | 16 | /// Calculates variance of a slice of values 17 | pub fn variance(values: &[F]) -> F 18 | where 19 | F: Float + Sum, 20 | { 21 | if values.is_empty() { 22 | return F::zero(); 23 | } 24 | let mean = mean(values); 25 | 26 | let diff_sum: F = values 27 | .iter() 28 | .cloned() 29 | .map(|x| (x - mean).powf(F::from(2.0).unwrap())) 30 | .sum(); 31 | diff_sum / F::from(values.len()).unwrap() 32 | } 33 | 34 | /// Calculates covariance over two slices of values 35 | pub fn covariance(x_values: &[F], y_values: &[F]) -> F 36 | where 37 | F: Float + Sum, 38 | { 39 | if x_values.len() != y_values.len() { 40 | panic!("x_values and y_values must be of equal length."); 41 | } 42 | let length: usize = x_values.len(); 43 | if length == 0usize { 44 | return F::zero(); 45 | } 46 | let mean_x = mean(x_values); 47 | let mean_y = mean(y_values); 48 | 49 | x_values 50 | .iter() 51 | .zip(y_values.iter()) 52 | .fold(F::zero(), |covariance, (&x, &y)| { 53 | covariance + (x - mean_x) * (y - mean_y) 54 | }) 55 | / F::from(length).unwrap() 56 | } 57 | 58 | /// Calculates mean squared error between actual and predict values 59 | pub fn mean_squared_error(actual: &[F], predict: &[F]) -> F 60 | where 61 | F: Float + FromPrimitive, 62 | { 63 | if actual.len() != predict.len() { 64 | panic!("actual and predict must be of equal length."); 65 | } 66 | 67 | actual 68 | .iter() 69 | .zip(predict.iter()) 70 | .fold(F::from_f64(0.0).unwrap(), |sum, (&x, &y)| { 71 | sum + (x - y).powf(F::from_f64(2.0).unwrap()) 72 | }) 73 | / F::from_usize(actual.len()).unwrap() 74 | } 75 | 76 | /// Calculates root mean squared error between actual and predict values 77 | pub fn root_mean_squared_error(actual: &[F], predict: &[F]) -> F 78 | where 79 | F: Float + FromPrimitive, 80 | { 81 | mean_squared_error::(actual, predict).sqrt() 82 | } 83 | 84 | #[cfg(test)] 85 | mod tests { 86 | use super::*; 87 | 88 | #[test] 89 | fn mean_empty_vec() { 90 | let values: Vec = vec![]; 91 | assert_delta!(0., mean(&values), 0.00001); 92 | } 93 | 94 | #[test] 95 | fn mean_empty_vec_f32() { 96 | let values: Vec = vec![]; 97 | assert_delta!(0., mean(&values), 0.00001); 98 | } 99 | 100 | #[test] 101 | fn mean_1_to_5() { 102 | let values = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 103 | assert_delta!(3., mean(&values), 0.00001); 104 | } 105 | 106 | #[test] 107 | fn mean_1_to_5_f32() { 108 | let values: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 109 | assert_delta!(3., mean(&values), 0.00001); 110 | } 111 | 112 | #[test] 113 | fn variance_empty_vec() { 114 | let values: Vec = vec![]; 115 | assert_delta!(0., variance(&values), 0.00001); 116 | } 117 | 118 | #[test] 119 | fn variance_empty_vec_f32() { 120 | let values: Vec = vec![]; 121 | assert_delta!(0., variance(&values), 0.00001); 122 | } 123 | 124 | #[test] 125 | fn variance_1_to_5() { 126 | let values = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 127 | assert_delta!(2., variance(&values), 0.00001); 128 | } 129 | 130 | #[test] 131 | fn variance_1_to_5_f32() { 132 | let values: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 133 | assert_delta!(2., variance(&values), 0.00001); 134 | } 135 | 136 | #[test] 137 | fn covariance_empty_vec() { 138 | let x_values: Vec = vec![]; 139 | let y_values: Vec = vec![]; 140 | assert_delta!(0., covariance(&x_values, &y_values), 0.00001); 141 | } 142 | 143 | #[test] 144 | fn covariance_empty_vec_f32() { 145 | let x_values: Vec = vec![]; 146 | let y_values: Vec = vec![]; 147 | assert_delta!(0., covariance(&x_values, &y_values), 0.00001); 148 | } 149 | 150 | #[test] 151 | fn covariance_1_to_5() { 152 | let x_values = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 153 | let y_values = vec![1.0, 3.0, 2.0, 3.0, 5.0]; 154 | assert_delta!(1.6, covariance(&x_values, &y_values), 0.00001); 155 | } 156 | 157 | #[test] 158 | fn covariance_1_to_5_f32() { 159 | let x_values: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 160 | let y_values: Vec = vec![1.0, 3.0, 2.0, 3.0, 5.0]; 161 | assert_delta!(1.6, covariance(&x_values, &y_values), 0.00001); 162 | } 163 | 164 | #[test] 165 | fn negative_covariance_1_to_5() { 166 | let x_values = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 167 | let y_values = vec![0.5, 4.0, 1.0, -5.0, 4.0]; 168 | assert_delta!(-0.4, covariance(&x_values, &y_values), 0.00001); 169 | } 170 | 171 | #[test] 172 | fn negative_covariance_1_to_5_f32() { 173 | let x_values: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 174 | let y_values: Vec = vec![0.5, 4.0, 1.0, -5.0, 4.0]; 175 | assert_delta!(-0.4, covariance(&x_values, &y_values), 0.00001); 176 | } 177 | 178 | #[test] 179 | fn mean_squared_error_test() { 180 | let actual = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 181 | let predict = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 182 | assert_delta!(0., mean_squared_error(&actual, &predict), 0.00001); 183 | } 184 | 185 | #[test] 186 | fn mean_squared_error_test_f32() { 187 | let actual: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 188 | let predict: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 189 | assert_delta!(0., mean_squared_error(&actual, &predict), 0.00001); 190 | } 191 | 192 | #[test] 193 | fn mean_squared_error_test_2() { 194 | let actual = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 195 | let predict = vec![1.0, 1.6, 3.0, 4.0, 5.0]; 196 | assert_delta!(0.032, mean_squared_error(&actual, &predict), 0.00001); 197 | } 198 | 199 | #[test] 200 | fn mean_squared_error_test_2_f32() { 201 | let actual: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0]; 202 | let predict: Vec = vec![1.0, 1.6, 3.0, 4.0, 5.0]; 203 | assert_delta!(0.032, mean_squared_error(&actual, &predict), 0.00001); 204 | } 205 | } 206 | -------------------------------------------------------------------------------- /src/models/trainer.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | error::Error, 3 | geometry::{helper::*, Axis, Point}, 4 | models::{variance, Model}, 5 | }; 6 | use core::iter::Sum; 7 | use num_traits::{cast::FromPrimitive, Float}; 8 | 9 | /// Preprocessing and prepare data for model training 10 | /// 11 | #[derive(Debug, Clone)] 12 | pub struct Trainer { 13 | train_x: Vec, 14 | train_y: Vec, 15 | axis: Axis, 16 | } 17 | 18 | impl Default for Trainer { 19 | fn default() -> Self { 20 | Self { 21 | train_x: Vec::::new(), 22 | train_y: Vec::::new(), 23 | axis: Axis::X, 24 | } 25 | } 26 | } 27 | 28 | impl Trainer 29 | where 30 | F: Float + Sized, 31 | { 32 | pub fn new() -> Self { 33 | Self::default() 34 | } 35 | 36 | pub fn train_x(&self) -> &Vec { 37 | &self.train_x 38 | } 39 | 40 | pub fn train_y(&self) -> &Vec { 41 | &self.train_y 42 | } 43 | 44 | pub fn axis(&self) -> &Axis { 45 | &self.axis 46 | } 47 | 48 | pub fn set_train_x(&mut self, xs: Vec) { 49 | self.train_x = xs 50 | } 51 | 52 | pub fn set_train_y(&mut self, ys: Vec) { 53 | self.train_y = ys 54 | } 55 | 56 | pub fn set_axis(&mut self, axis: Axis) { 57 | self.axis = axis 58 | } 59 | 60 | /// Training with provided model 61 | /// 62 | /// Returns trained/fitted model on success, otherwise returns an error 63 | pub fn train<'a, M: Model + 'a>(&self, model: &'a mut M) -> Result<(), Error> { 64 | model.fit(&self.train_x, &self.train_y)?; 65 | Ok(()) 66 | } 67 | } 68 | 69 | impl Trainer 70 | where 71 | F: Float + Sum + FromPrimitive, 72 | { 73 | /// Initialize Trainer with two Vec 74 | /// 75 | /// Returns prepared Trainer Ok((Trainer)) on success, otherwise returns an error 76 | pub fn with_data(xs: Vec, ys: Vec) -> Result<(Self, Vec>), Error> { 77 | assert_empty!(xs); 78 | assert_eq_len!(xs, ys); 79 | 80 | let mut trainer = Trainer::new(); 81 | let data = trainer.preprocess(xs, ys)?; 82 | Ok((trainer, data)) 83 | } 84 | 85 | /// Preprocess two Vec that satisfy Trainer's requirements 86 | /// 87 | /// Returns sorted Ok(Vec>) on success, otherwise returns an error 88 | pub fn preprocess(&mut self, xs: Vec, ys: Vec) -> Result>, Error> { 89 | assert_empty!(xs); 90 | assert_eq_len!(xs, ys); 91 | 92 | let mut ps: Vec> = xs 93 | .iter() 94 | .zip(ys.iter()) 95 | .map(|(&x, &y)| Point { x, y }) 96 | .collect(); 97 | 98 | // set train_x to data with larger variance 99 | if variance(&xs) > variance(&ys) { 100 | // sort along x 101 | sort_by_x(&mut ps); 102 | self.set_axis(Axis::X); 103 | self.set_train_x(extract_x(&ps)); 104 | } else { 105 | // sort along y 106 | sort_by_y(&mut ps); 107 | self.set_axis(Axis::Y); 108 | self.set_train_x(extract_y(&ps)); 109 | }; 110 | 111 | let train_y: Vec = (0..ps.len()).map(|id| F::from_usize(id).unwrap()).collect(); 112 | self.set_train_y(train_y); 113 | Ok(ps) 114 | } 115 | 116 | /// Preprocess with Vec> that satisfy Trainer's requirements 117 | /// 118 | /// Returns prepared Trainer Ok((Trainer)) on success, otherwise returns an error 119 | pub fn with_points(ps: &mut [Point]) -> Result { 120 | let px: Vec = extract_x(ps); 121 | let py: Vec = extract_y(ps); 122 | assert_eq_len!(px, py); 123 | let x_variance = variance(&px); 124 | let y_variance = variance(&py); 125 | // set train_x to data with larger variance 126 | let (axis, train_x) = if x_variance > y_variance { 127 | sort_by_x(ps); 128 | (Axis::X, extract_x(ps)) 129 | } else { 130 | sort_by_y(ps); 131 | (Axis::Y, extract_y(ps)) 132 | }; 133 | let train_y: Vec = (0..ps.len()).map(|id| F::from_usize(id).unwrap()).collect(); 134 | Ok(Self { 135 | train_x, 136 | train_y, 137 | axis, 138 | }) 139 | } 140 | } 141 | 142 | #[cfg(test)] 143 | mod tests { 144 | use super::*; 145 | 146 | #[test] 147 | fn sort_by() { 148 | let mut data: Vec> = vec![ 149 | Point { x: 1., y: 1. }, 150 | Point { x: 3., y: 1. }, 151 | Point { x: 2., y: 1. }, 152 | Point { x: 3., y: 2. }, 153 | Point { x: 5., y: 1. }, 154 | ]; 155 | let data_sort_by_x: Vec> = vec![ 156 | Point { x: 1., y: 1. }, 157 | Point { x: 2., y: 1. }, 158 | Point { x: 3., y: 1. }, 159 | Point { x: 3., y: 2. }, 160 | Point { x: 5., y: 1. }, 161 | ]; 162 | sort_by_x(&mut data); 163 | 164 | assert_eq!(data_sort_by_x, data); 165 | } 166 | 167 | #[test] 168 | fn train() { 169 | let mut data: Vec> = vec![ 170 | Point { x: 1., y: 1. }, 171 | Point { x: 3., y: 1. }, 172 | Point { x: 2., y: 1. }, 173 | Point { x: 3., y: 2. }, 174 | Point { x: 5., y: 1. }, 175 | ]; 176 | let trainer = Trainer::with_points(&mut data).unwrap(); 177 | let test_x = vec![1., 2., 3., 3., 5.]; 178 | let test_y = vec![0., 1., 2., 3., 4.]; 179 | 180 | assert_eq!(&test_x, trainer.train_x()); 181 | assert_eq!(&test_y, trainer.train_y()); 182 | } 183 | } 184 | -------------------------------------------------------------------------------- /src/test_utilities.rs: -------------------------------------------------------------------------------- 1 | use crate::geometry::*; 2 | use rand::{Rng, SeedableRng}; 3 | use rand_hc::Hc128Rng; 4 | 5 | pub type Seed = [u8; 32]; 6 | 7 | pub const SEED_1: &Seed = b"wPYxAkIiHcEmSBAxQFoXFrpYToCe1B71"; 8 | pub const SEED_2: &Seed = b"4KbTVjPT4DXSwWAsQM5dkWWywPKZRfCX"; 9 | 10 | pub fn create_random_points(num_points: usize, seed: &[u8; 32]) -> Vec<(f64, f64)> { 11 | let mut result = Vec::with_capacity(num_points); 12 | let mut rng = Hc128Rng::from_seed(*seed); 13 | for _ in 0..num_points { 14 | result.push((rng.gen(), rng.gen())); 15 | } 16 | result 17 | } 18 | 19 | pub fn create_random_point_type_points(num_points: usize, seed: &[u8; 32]) -> Vec> { 20 | let result = create_random_points(num_points, seed); 21 | 22 | // result.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); 23 | result 24 | .into_iter() 25 | .map(|(x, y)| Point { x, y }) 26 | .collect::>() 27 | } 28 | --------------------------------------------------------------------------------