├── .github └── workflows │ ├── ci.yml │ └── presubmits.yml ├── .gitignore ├── LICENSE ├── NOTICE ├── README.md ├── container ├── bruteforce │ ├── bruteforce.go │ └── bruteforce_test.go ├── container.go ├── kd │ ├── kd.go │ └── kd_test.go └── kyroy │ ├── kyroy.go │ └── kyroy_test.go ├── filter └── filter.go ├── go.mod ├── go.sum ├── internal ├── knn │ ├── knn.go │ └── knn_test.go ├── node │ ├── node.go │ ├── tree │ │ ├── tree.go │ │ └── tree_test.go │ └── util │ │ └── util.go ├── perf │ ├── perf_test.go │ ├── results │ │ └── v0.5.5.txt │ └── util │ │ ├── util.go │ │ └── util_test.go └── rangesearch │ ├── rangesearch.go │ └── rangesearch_test.go ├── kd ├── kd.go └── kd_test.go ├── point ├── mock │ └── mock.go └── point.go ├── vector └── vector.go └── x ├── README.md ├── go.mod └── go.sum /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI Tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | 7 | jobs: 8 | 9 | presubmit: 10 | name: CI Tests 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | 16 | - uses: actions/setup-go@v2 17 | with: 18 | go-version: 1.19 19 | 20 | - name: Build github.com/downflux/go-kd 21 | run: go build github.com/downflux/go-kd/... 22 | 23 | - name: Vet github.com/downflux/go-kd 24 | run: go vet github.com/downflux/go-kd/... 25 | 26 | - name: Benchmark github.com/downflux/go-kd 27 | run: go test github.com/downflux/go-kd/... -run ^$ -bench .
-benchmem 28 | -------------------------------------------------------------------------------- /.github/workflows/presubmits.yml: -------------------------------------------------------------------------------- 1 | name: CI Presubmits 2 | 3 | on: 4 | pull_request: 5 | branches: [ main ] 6 | push: 7 | branches: [ "*" ] 8 | 9 | jobs: 10 | 11 | presubmit: 12 | name: CI Presubmits 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | 18 | - uses: actions/setup-go@v2 19 | with: 20 | go-version: 1.19 21 | 22 | - name: Test github.com/downflux/go-kd 23 | run: go test github.com/downflux/go-kd/... 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | DownFlux 2 | Copyright 2021 DownFlux 3 | 4 | This product includes software from 5 | 6 | * the "kyroy/kdtree" library (https://github.com/kyroy/kdtree). 7 | 8 | Some parts of the framework, including but not limited to 9 | 10 | * internal/knn/... 11 | * internal/rangesearch/... 12 | 13 | are copied from, derived from, or inspired by the kyroy/kdtree library. 14 | Copyright 2018 - 2021 Dennis Kuhnert. All Rights Reserved. 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-kd 2 | 3 | Golang K-D tree implementation with duplicate coordinate support 4 | 5 | See [Wikipedia](https://en.wikipedia.org/wiki/K-d_tree) for more information. 6 | 7 | ## Testing 8 | 9 | ```bash 10 | go test github.com/downflux/go-kd/... 11 | go test github.com/downflux/go-kd/internal/perf \ 12 | -bench . 
\ 13 | -benchmem \ 14 | -timeout=60m \ 15 | -args -performance_test_size=large 16 | ``` 17 | 18 | ## Example 19 | 20 | ```golang 21 | package main 22 | 23 | import ( 24 | "fmt" 25 | 26 | "github.com/downflux/go-geometry/nd/hyperrectangle" 27 | "github.com/downflux/go-geometry/nd/vector" 28 | "github.com/downflux/go-kd/point" 29 | 30 | "github.com/downflux/go-kd/kd" 31 | ) 32 | 33 | // P implements the point.P interface, which needs to provide a coordinate 34 | // vector function P(). 35 | var _ point.P = &P{} 36 | 37 | type P struct { 38 | p vector.V 39 | tag string 40 | } 41 | 42 | func (p *P) P() vector.V { return p.p } 43 | func (p *P) Equal(q *P) bool { return vector.Within(p.P(), q.P()) && p.tag == q.tag } 44 | 45 | func main() { 46 | data := []*P{ 47 | &P{p: vector.V{1, 2}, tag: "A"}, 48 | &P{p: vector.V{2, 100}, tag: "B"}, 49 | } 50 | 51 | // Data is copy-constructed and may be read from outside the k-D tree. 52 | t := kd.New[*P](kd.O[*P]{ 53 | Data: data, 54 | K: 2, 55 | N: 1, 56 | }) 57 | 58 | fmt.Println("KNN search") 59 | for _, p := range kd.KNN( 60 | t, 61 | /* v = */ vector.V{0, 0}, 62 | /* k = */ 2, 63 | func(p *P) bool { return true }) { 64 | fmt.Println(p) 65 | } 66 | 67 | // Remove deletes the first data point at the given input coordinate and 68 | // matches the input check function. 69 | p, ok := t.Remove(data[0].P(), data[0].Equal) 70 | fmt.Printf("removed %v (found = %v)\n", p, ok) 71 | 72 | // RangeSearch returns all points within the k-D bounds and matches the 73 | // input filter function. 
74 | fmt.Println("range search") 75 | for _, p := range kd.RangeSearch( 76 | t, 77 | *hyperrectangle.New( 78 | /* min = */ vector.V{0, 0}, 79 | /* max = */ vector.V{100, 100}, 80 | ), 81 | func(p *P) bool { return true }, 82 | ) { 83 | fmt.Println(p) 84 | } 85 | } 86 | ``` 87 | 88 | ## Performance (@v1.0.0) 89 | 90 | This k-D tree implementation was compared against a brute force method, as well 91 | as with the leading Golang k-D tree implementation 92 | (http://github.com/kyroy/kdtree). Overall, we have found that 93 | 94 | * tree construction is about 10x faster for large N. 95 | 96 | ``` 97 | BenchmarkNew/kyroy/K=16/N=1000-8 758980 ns/op 146777 B/op 98 | BenchmarkNew/Real/K=16/N=1000/LeafSize=16-8 200749 ns/op 32637 B/op 99 | 100 | BenchmarkNew/kyroy/K=16/N=1000000-8 7407144200 ns/op 184813784 B/op 101 | BenchmarkNew/Real/K=16/N=1000000/LeafSize=256-8 588456300 ns/op 12462912 B/op 102 | ``` 103 | 104 | * KNN is significantly faster; for small N, we have found our implementation is 105 | ~10x faster than the reference implementation and ~20x faster than brute 106 | force. For large N, we have found up to ~15x faster than brute force, and a 107 | staggering _~1500x_ speedup when compared to the reference implementation. 108 | 109 | ``` 110 | BenchmarkKNN/BruteForce/K=16/N=1000-8 1563019 ns/op 2220712 B/op 111 | BenchmarkKNN/kyroy/K=16/N=1000/KNN=0.05-8 791415 ns/op 21960 B/op 112 | BenchmarkKNN/Real/K=16/N=1000/LeafSize=16/KNN=0.05-8 69537 ns/op 12024 B/op 113 | 114 | BenchmarkKNN/BruteForce/K=16/N=1000000-8 5030811400 ns/op 5347687464 B/op 115 | BenchmarkKNN/kyroy/K=16/N=1000000/KNN=0.05-8 529703585200 ns/op 23755688 B/op 116 | BenchmarkKNN/Real/K=16/N=1000000/LeafSize=256/KNN=0.05-8 335845533 ns/op 6044016 B/op 117 | ``` 118 | 119 | * RangeSearch is slower for small N -- we are approximately at parity for brute 120 | force, and ~10x slower than the reference implementation. 
However, at large N, 121 | we are ~300x faster than brute force, and ~100x faster than the reference 122 | implementation. 123 | 124 | ``` 125 | BenchmarkRangeSearch/BruteForce/K=16/N=1000-8 154712 ns/op 25208 B/op 126 | BenchmarkRangeSearch/kyroy/K=16/N=1000/Coverage=0.05-8 13373 ns/op 496 B/op 127 | BenchmarkRangeSearch/Real/K=16/N=1000/LeafSize=16/Coverage=0.05-8 193276 ns/op 101603 B/op 128 | 129 | BenchmarkRangeSearch/BruteForce/K=16/N=1000000-8 173427000 ns/op 41678072 B/op 130 | BenchmarkRangeSearch/kyroy/K=16/N=1000000/Coverage=0.05-8 56820240 ns/op 496 B/op 131 | BenchmarkRangeSearch/Real/K=16/N=1000000/LeafSize=256/Coverage=0.05-8 530937 ns/op 212134 B/op 132 | ``` 133 | 134 | Raw data on these results may be found [here](/internal/perf/results/v0.5.5.txt). 135 | -------------------------------------------------------------------------------- /container/bruteforce/bruteforce.go: -------------------------------------------------------------------------------- 1 | package bruteforce 2 | 3 | import ( 4 | "sort" 5 | 6 | "github.com/downflux/go-geometry/nd/hyperrectangle" 7 | "github.com/downflux/go-geometry/nd/vector" 8 | "github.com/downflux/go-kd/filter" 9 | "github.com/downflux/go-kd/internal/perf/util" 10 | "github.com/downflux/go-kd/point" 11 | ) 12 | 13 | type L[T point.P] []T 14 | 15 | func New[T point.P](d []T) *L[T] { 16 | data := make([]T, len(d)) 17 | if l := copy(data, d); l != len(d) { 18 | panic("could not copy data into brute force list") 19 | } 20 | m := L[T](data) 21 | return &m 22 | } 23 | 24 | func (m *L[T]) KNN(p vector.V, k int, f filter.F[T]) []T { 25 | sort.Sort(util.L[T]{ 26 | Data: *m, 27 | P: p, 28 | }) 29 | 30 | var data []T 31 | for _, p := range *m { 32 | if f(p) { 33 | data = append(data, p) 34 | } 35 | if len(data) == k { 36 | return data 37 | } 38 | } 39 | return data 40 | } 41 | 42 | func (m *L[T]) RangeSearch(q hyperrectangle.R, f filter.F[T]) []T { 43 | var data []T 44 | for _, p := range m.Data() { 45 | if q.In(p.P()) && 
f(p) { 46 | data = append(data, p) 47 | } 48 | } 49 | return data 50 | } 51 | 52 | func (m *L[T]) Insert(p T) { *m = append(*m, p) } 53 | func (m *L[T]) Remove(p vector.V, f filter.F[T]) (T, bool) { 54 | var blank T 55 | for i, q := range *m { 56 | if vector.Within(p, q.P()) && f(q) { 57 | (*m)[i], (*m)[len(*m)-1] = (*m)[len(*m)-1], blank 58 | *m = (*m)[:len(*m)-1] 59 | return q, true 60 | } 61 | } 62 | return blank, false 63 | } 64 | 65 | func (m *L[T]) Data() []T { return *m } 66 | func (m *L[T]) Balance() {} 67 | -------------------------------------------------------------------------------- /container/bruteforce/bruteforce_test.go: -------------------------------------------------------------------------------- 1 | package bruteforce 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/downflux/go-geometry/nd/vector" 7 | "github.com/downflux/go-kd/container" 8 | "github.com/downflux/go-kd/point/mock" 9 | "github.com/google/go-cmp/cmp" 10 | ) 11 | 12 | var _ container.C[mock.P] = &L[mock.P]{} 13 | 14 | func TestDelete(t *testing.T) { 15 | type config struct { 16 | name string 17 | data []*mock.P 18 | vs []vector.V 19 | 20 | want []*mock.P 21 | } 22 | 23 | configs := []config{ 24 | { 25 | name: "Nil", 26 | data: nil, 27 | vs: []vector.V{ 28 | mock.U(100), 29 | }, 30 | want: []*mock.P{}, 31 | }, 32 | { 33 | name: "Simple", 34 | data: []*mock.P{ 35 | &mock.P{X: mock.U(50)}, 36 | &mock.P{X: mock.U(100)}, 37 | }, 38 | vs: []vector.V{ 39 | mock.U(100), 40 | }, 41 | want: []*mock.P{ 42 | &mock.P{X: mock.U(50)}, 43 | }, 44 | }, 45 | { 46 | name: "Degenerate", 47 | data: []*mock.P{ 48 | &mock.P{X: mock.U(100), Data: "A"}, 49 | &mock.P{X: mock.U(100), Data: "B"}, 50 | }, 51 | vs: []vector.V{ 52 | mock.U(100), 53 | }, 54 | want: []*mock.P{ 55 | &mock.P{X: mock.U(100), Data: "B"}, 56 | }, 57 | }, 58 | } 59 | 60 | for _, c := range configs { 61 | t.Run(c.name, func(t *testing.T) { 62 | l := New(c.data) 63 | for _, v := range c.vs { 64 | l.Remove(v, func(p *mock.P) bool { 
return vector.Within(v, p.P()) }) 65 | } 66 | 67 | got := l.Data() 68 | if diff := cmp.Diff(c.want, got); diff != "" { 69 | t.Errorf("Data() mismatch (-want +got):\n%v", diff) 70 | } 71 | 72 | }) 73 | } 74 | } 75 | 76 | func TestInsert(t *testing.T) { 77 | type config struct { 78 | name string 79 | data []*mock.P 80 | ps []*mock.P 81 | 82 | want []*mock.P 83 | } 84 | 85 | configs := []config{ 86 | { 87 | name: "Trivial", 88 | data: nil, 89 | ps: []*mock.P{ 90 | &mock.P{X: mock.U(100)}, 91 | }, 92 | want: []*mock.P{ 93 | &mock.P{X: mock.U(100)}, 94 | }, 95 | }, 96 | { 97 | name: "MultipleInsert", 98 | data: nil, 99 | ps: []*mock.P{ 100 | &mock.P{X: mock.U(101)}, 101 | &mock.P{X: mock.U(100)}, 102 | &mock.P{X: mock.U(202)}, 103 | }, 104 | want: []*mock.P{ 105 | &mock.P{X: mock.U(101)}, 106 | &mock.P{X: mock.U(100)}, 107 | &mock.P{X: mock.U(202)}, 108 | }, 109 | }, 110 | { 111 | name: "MultipleInsert/NonNil", 112 | data: []*mock.P{ 113 | &mock.P{X: mock.U(4)}, 114 | &mock.P{X: mock.U(5)}, 115 | }, 116 | ps: []*mock.P{ 117 | &mock.P{X: mock.U(101)}, 118 | &mock.P{X: mock.U(100)}, 119 | &mock.P{X: mock.U(202)}, 120 | }, 121 | want: []*mock.P{ 122 | &mock.P{X: mock.U(4)}, 123 | &mock.P{X: mock.U(5)}, 124 | &mock.P{X: mock.U(101)}, 125 | &mock.P{X: mock.U(100)}, 126 | &mock.P{X: mock.U(202)}, 127 | }, 128 | }, 129 | } 130 | 131 | for _, c := range configs { 132 | t.Run(c.name, func(t *testing.T) { 133 | l := New(c.data) 134 | for _, p := range c.ps { 135 | l.Insert(p) 136 | } 137 | 138 | got := l.Data() 139 | if diff := cmp.Diff(c.want, got); diff != "" { 140 | t.Errorf("Data() mismatch (-want +got):\n%v", diff) 141 | } 142 | }) 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /container/container.go: -------------------------------------------------------------------------------- 1 | // Package container exports the expected storage API used for querying a set of 2 | // objects in a system. 
This may be used to more freely move between different 3 | // implementations as the conditions of the system change, e.g. when the number 4 | // or density of agents reaches some threshold. 5 | package container 6 | 7 | import ( 8 | "github.com/downflux/go-geometry/nd/hyperrectangle" 9 | "github.com/downflux/go-geometry/nd/vector" 10 | "github.com/downflux/go-kd/filter" 11 | "github.com/downflux/go-kd/point" 12 | ) 13 | 14 | type C[T point.P] interface { 15 | // KNN returns the k-nearest neighbors of the given search coordinates. 16 | // 17 | // N.B.: KNN will return at most k neighbors; in the degenerate case that 18 | // multiple data points reside at the same spatial coordinate, this 19 | // function will arbitrarily return a subset of these to fulfill the 20 | // k-neighbors constraint. 21 | KNN(p vector.V, k int, f filter.F[T]) []T 22 | 23 | // Data returns all data stored in the container. 24 | Data() []T 25 | 26 | // RangeSearch returns a set of data points in the given bounding box. 27 | // Data points are added to the returned set if they fall inside the 28 | // bounding box and pass the given filter function. 29 | RangeSearch(q hyperrectangle.R, f filter.F[T]) []T 30 | 31 | // Balance updates the container after a set of mutations. For a k-D 32 | // tree, this is a rebalance operation. 33 | Balance() 34 | 35 | // Insert adds a new data point into the container. 36 | Insert(p T) 37 | 38 | // Remove deletes an existing data point from the container. This 39 | // function will delete an arbitrary point at the given coordinates 40 | // which passes the filter f, and reports whether one was found.
41 | Remove(p vector.V, f filter.F[T]) (T, bool) 42 | } 43 | -------------------------------------------------------------------------------- /container/kd/kd.go: -------------------------------------------------------------------------------- 1 | package kd 2 | 3 | import ( 4 | "github.com/downflux/go-geometry/nd/hyperrectangle" 5 | "github.com/downflux/go-geometry/nd/vector" 6 | "github.com/downflux/go-kd/filter" 7 | "github.com/downflux/go-kd/kd" 8 | "github.com/downflux/go-kd/point" 9 | ) 10 | 11 | type KD[T point.P] kd.KD[T] 12 | 13 | func (t *KD[T]) KNN(p vector.V, k int, f filter.F[T]) []T { return kd.KNN((*kd.KD[T])(t), p, k, f) } 14 | func (t *KD[T]) RangeSearch(q hyperrectangle.R, f filter.F[T]) []T { 15 | return kd.RangeSearch((*kd.KD[T])(t), q, f) 16 | } 17 | func (t *KD[T]) Data() []T { return kd.Data((*kd.KD[T])(t)) } 18 | func (t *KD[T]) Balance() { (*kd.KD[T])(t).Balance() } 19 | func (t *KD[T]) Insert(p T) { (*kd.KD[T])(t).Insert(p) } 20 | func (t *KD[T]) Remove(v vector.V, f filter.F[T]) (T, bool) { return (*kd.KD[T])(t).Remove(v, f) } 21 | -------------------------------------------------------------------------------- /container/kd/kd_test.go: -------------------------------------------------------------------------------- 1 | package kd 2 | 3 | import ( 4 | "github.com/downflux/go-kd/container" 5 | "github.com/downflux/go-kd/point/mock" 6 | ) 7 | 8 | var _ container.C[mock.P] = &KD[mock.P]{} 9 | -------------------------------------------------------------------------------- /container/kyroy/kyroy.go: -------------------------------------------------------------------------------- 1 | // Package kyroy is a wrapper around the @kyroy k-D tree implementation. This is 2 | // used for performance testing. 
3 | package kyroy 4 | 5 | import ( 6 | "github.com/downflux/go-geometry/nd/hyperrectangle" 7 | "github.com/downflux/go-geometry/nd/vector" 8 | "github.com/downflux/go-kd/filter" 9 | "github.com/downflux/go-kd/point" 10 | "github.com/kyroy/kdtree" 11 | "github.com/kyroy/kdtree/points" 12 | ) 13 | 14 | type KD[T point.P] kdtree.KDTree 15 | 16 | func New[T point.P](data []T) *KD[T] { 17 | var ps []kdtree.Point 18 | for _, p := range data { 19 | ps = append(ps, points.NewPoint([]float64(p.P()), p)) 20 | } 21 | return (*KD[T])(kdtree.New(ps)) 22 | } 23 | 24 | func (t *KD[T]) KNN(p vector.V, k int, f filter.F[T]) []T { 25 | var data []T 26 | for _, p := range (*kdtree.KDTree)(t).KNN( 27 | points.NewPoint([]float64(p), nil), 28 | k, 29 | ) { 30 | if f(p.(*points.Point).Data.(T)) { 31 | data = append(data, p.(*points.Point).Data.(T)) 32 | } 33 | } 34 | return data 35 | } 36 | 37 | func (t *KD[T]) Data() []T { 38 | var data []T 39 | for _, p := range (*kdtree.KDTree)(t).Points() { 40 | data = append(data, p.(T)) 41 | } 42 | return data 43 | } 44 | 45 | func (t *KD[T]) RangeSearch(q hyperrectangle.R, f filter.F[T]) []T { 46 | var r [][2]float64 47 | for i := vector.D(0); i < q.Min().Dimension(); i++ { 48 | r = append(r, [2]float64{q.Min().X(i), q.Max().X(i)}) 49 | } 50 | 51 | var data []T 52 | for _, p := range (*kdtree.KDTree)(t).RangeSearch(r) { 53 | if f(p.(*points.Point).Data.(T)) { 54 | data = append(data, p.(*points.Point).Data.(T)) 55 | } 56 | } 57 | return data 58 | } 59 | 60 | func (t *KD[T]) Balance() { (*kdtree.KDTree)(t).Balance() } 61 | func (t *KD[T]) Insert(p T) { (*kdtree.KDTree)(t).Insert(points.NewPoint([]float64(p.P()), p)) } 62 | 63 | func (t *KD[T]) Remove(p vector.V, f filter.F[T]) (T, bool) { 64 | v := (*kdtree.KDTree)(t).Remove(points.NewPoint([]float64(p), nil)).(*points.Point).Data.(T) 65 | if !f(v) { 66 | var blank T 67 | (*kdtree.KDTree)(t).Insert(points.NewPoint([]float64(p), v)) 68 | return blank, false 69 | } 70 | return v, true 71 | } 72 | 
-------------------------------------------------------------------------------- /container/kyroy/kyroy_test.go: -------------------------------------------------------------------------------- 1 | package kyroy 2 | 3 | import ( 4 | "github.com/downflux/go-kd/container" 5 | "github.com/downflux/go-kd/point/mock" 6 | ) 7 | 8 | var _ container.C[mock.P] = &KD[mock.P]{} 9 | -------------------------------------------------------------------------------- /filter/filter.go: -------------------------------------------------------------------------------- 1 | package filter 2 | 3 | import ( 4 | "github.com/downflux/go-kd/point" 5 | ) 6 | 7 | type F[T point.P] func(p T) bool 8 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/downflux/go-kd 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/downflux/go-geometry v0.15.0 7 | github.com/downflux/go-pq v0.3.0 8 | github.com/google/go-cmp v0.5.9 9 | github.com/kyroy/kdtree v0.0.0-20200419114247-70830f883f1d 10 | ) 11 | 12 | require github.com/kyroy/priority-queue v0.0.0-20180327160706-6e21825e7e0c // indirect 13 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/downflux/go-geometry v0.10.2 h1:Z79Khzl6AKMSMLnM5xG75fEOL1fmIWlF14+8j+r01D0= 4 | github.com/downflux/go-geometry v0.10.2/go.mod h1:XWTzSaMiRMAxupAR+cXAsa1Q75TCSp1Shc/ydsJ0xVE= 5 | github.com/downflux/go-geometry v0.13.0 h1:MWGPpr9ZLMPh/oQLYWwHQG6LpM7KLKn8pfkxIzGFCbw= 6 | github.com/downflux/go-geometry v0.13.0/go.mod h1:ZJcto0QwYRdoIbi5G4mh5y6v2xUS+d++/cANaO1F9+8= 7 | github.com/downflux/go-geometry v0.15.0 
h1:rgx7x/t2Yo9OqhjSONhFS74QQlx/UQouEb4qmmPxM2w= 8 | github.com/downflux/go-geometry v0.15.0/go.mod h1:ZJcto0QwYRdoIbi5G4mh5y6v2xUS+d++/cANaO1F9+8= 9 | github.com/downflux/go-pq v0.1.4 h1:SHFeyU+DNtx6gcsmBYPx+EekKIsmib8k7U39UGpH/7g= 10 | github.com/downflux/go-pq v0.1.4/go.mod h1:vkc6UAQ+TBoNdTwDm5akDexE1auN2kQcR8BFw3hNCiM= 11 | github.com/downflux/go-pq v0.3.0 h1:oWLx7rzsD4fv1f2kp33NUq63CJVQvXZORkcpHr6bp9g= 12 | github.com/downflux/go-pq v0.3.0/go.mod h1:vkc6UAQ+TBoNdTwDm5akDexE1auN2kQcR8BFw3hNCiM= 13 | github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= 14 | github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 15 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 16 | github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 17 | github.com/jupp0r/go-priority-queue v0.0.0-20160601094913-ab1073853bde h1:+5PMaaQtDUwOcJIUlmX89P0J3iwTvErTmyn5WghzXAQ= 18 | github.com/jupp0r/go-priority-queue v0.0.0-20160601094913-ab1073853bde/go.mod h1:RDgD/dfPmIwFH0qdUOjw71HjtWg56CtyLIoHL+R1wJw= 19 | github.com/kyroy/kdtree v0.0.0-20200419114247-70830f883f1d h1:1n5M/49q9H6QtNJiiVL/W5mqgT1UdlGQ7oLP+DkJ1vs= 20 | github.com/kyroy/kdtree v0.0.0-20200419114247-70830f883f1d/go.mod h1:6oJGQK7VSg3RxSQ7QspgqpCmKjIbAslgT2wBXbFJUZw= 21 | github.com/kyroy/priority-queue v0.0.0-20180327160706-6e21825e7e0c h1:1c7+XOOGQ19cXjZ1Ss/irljQxgPvb+8z+jNEprCXl20= 22 | github.com/kyroy/priority-queue v0.0.0-20180327160706-6e21825e7e0c/go.mod h1:R477L6j2/dUcE0q0aftk0kR5Xt93W7g1066AodcJhEo= 23 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 24 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 25 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 26 | github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= 27 | github.com/stretchr/testify 
v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 28 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 29 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= 30 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 31 | -------------------------------------------------------------------------------- /internal/knn/knn.go: -------------------------------------------------------------------------------- 1 | package knn 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/downflux/go-kd/internal/node" 7 | "github.com/downflux/go-kd/point" 8 | "github.com/downflux/go-kd/vector" 9 | "github.com/downflux/go-pq/pq" 10 | 11 | vnd "github.com/downflux/go-geometry/nd/vector" 12 | ) 13 | 14 | func path[T point.P](n node.N[T], p vnd.V) []node.N[T] { 15 | if n.Nil() { 16 | return nil 17 | } 18 | if n.Leaf() { 19 | return []node.N[T]{n} 20 | } 21 | 22 | // Note that we are bypassing the v == n.Pivot() stop condition check -- 23 | // we are always continuing to the leaf ndoe. This is necessary for 24 | // finding multiple closest neighbors, as we care about points in the 25 | // tree which do not have to coincide with the point coordinates. 
26 | if vector.Comparator(n.Axis()).Less(p, n.Pivot()) { 27 | return append(path(n.L(), p), n) 28 | } 29 | return append(path(n.R(), p), n) 30 | } 31 | 32 | // KNN returns up to k points from the tree rooted at n which are closest to 33 | // p by squared Euclidean distance and for which f returns true. 34 | func KNN[T point.P](n node.N[T], p vnd.V, k int, f func(p T) bool) []T { 35 | q := pq.New[T](k, pq.PMax) 36 | knn(n, p, q, vnd.M(make([]float64, p.Dimension())), f) 37 | 38 | ps := make([]T, q.Len()) 39 | for i := q.Len() - 1; i >= 0; i-- { 40 | ps[i], _ = q.Pop() 41 | } 42 | return ps 43 | } 44 | 45 | // knn walks from the leaf nearest p back up to n, pushing candidates into q, 46 | // and searches the far side of a split only if it may hold a qualifying point. 47 | func knn[T point.P](n node.N[T], p vnd.V, q *pq.PQ[T], buf vnd.M, f func(p T) bool) { 48 | for _, n := range path[T](n, p) { 49 | for _, datum := range n.Data() { 50 | buf.Copy(p) 51 | buf.Sub(datum.P()) 52 | 53 | if d := vnd.SquaredMagnitude(buf.V()); (!q.Full() || d < q.Priority()) && f(datum) { 54 | q.Push(datum, d) 55 | } 56 | } 57 | 58 | if !n.Leaf() { 59 | buf.Copy(p) 60 | buf.Sub(n.Pivot()) 61 | 62 | // The far side may only be pruned once q already holds k 63 | // candidates; until then any point there could qualify. 64 | if !q.Full() || q.Priority() > math.Pow(buf.X(n.Axis()), 2) { 65 | if vector.Comparator(n.Axis()).Less(p, n.Pivot()) { 66 | knn(n.R(), p, q, buf, f) 67 | } else { 68 | knn(n.L(), p, q, buf, f) 69 | } 70 | } 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /internal/knn/knn_test.go: -------------------------------------------------------------------------------- 1 | package knn 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/downflux/go-kd/internal/node" 7 | "github.com/downflux/go-kd/internal/node/tree" 8 | "github.com/downflux/go-kd/point" 9 | "github.com/downflux/go-kd/point/mock" 10 | "github.com/downflux/go-kd/vector" 11 | "github.com/google/go-cmp/cmp" 12 | "github.com/google/go-cmp/cmp/cmpopts" 13 | 14 | vnd "github.com/downflux/go-geometry/nd/vector" 15 | ) 16 | 17 | func TestKNN(t *testing.T) { 18 | type config[T point.P] struct { 19 | name string 20 | n node.N[T] 21 | p vnd.V 22 | k int 23 | want []T 24 | } 25 | 26 | configs := []config[mock.P]{ 27 | { 28 | name: "Trivial", 29 | n: tree.New[mock.P](tree.O[mock.P]{ 30 | Data: nil, 31 | K: 1, 32 | N: 10,
33 | }), 34 | p: mock.V(*vnd.New(100, 200)), 35 | k: 100, 36 | want: []mock.P{}, 37 | }, 38 | { 39 | name: "SmallD", 40 | n: tree.New[mock.P](tree.O[mock.P]{ 41 | Data: []mock.P{ 42 | mock.P{X: mock.U(0.1)}, 43 | mock.P{X: mock.U(0.01)}, 44 | }, 45 | K: 1, 46 | N: 10, 47 | }), 48 | p: mock.U(0), 49 | k: 1, 50 | want: []mock.P{ 51 | mock.P{X: mock.U(0.01)}, 52 | }, 53 | }, 54 | { 55 | name: "Simple", 56 | n: tree.New[mock.P](tree.O[mock.P]{ 57 | Data: []mock.P{ 58 | mock.P{X: mock.U(10)}, 59 | }, 60 | K: 1, 61 | N: 10, 62 | }), 63 | p: mock.U(-1000), 64 | k: 100, 65 | want: []mock.P{ 66 | mock.P{X: mock.U(10)}, 67 | }, 68 | }, 69 | { 70 | name: "Simple/2D", 71 | n: tree.New[mock.P](tree.O[mock.P]{ 72 | Data: []mock.P{ 73 | mock.P{X: mock.V(*vnd.New(100, 1))}, 74 | }, 75 | K: 2, 76 | N: 1, 77 | }), 78 | p: mock.V(*vnd.New(0, -100)), 79 | k: 100, 80 | want: []mock.P{ 81 | mock.P{X: mock.V(*vnd.New(100, 1))}, 82 | }, 83 | }, 84 | { 85 | name: "Simple/MultiK", 86 | n: tree.New[mock.P](tree.O[mock.P]{ 87 | Data: []mock.P{ 88 | mock.P{X: mock.U(101)}, 89 | mock.P{X: mock.U(102)}, 90 | mock.P{X: mock.U(103)}, 91 | mock.P{X: mock.U(99)}, 92 | }, 93 | K: 1, 94 | N: 1, 95 | }), 96 | p: mock.U(100), 97 | k: 2, 98 | want: []mock.P{ 99 | mock.P{X: mock.U(101)}, 100 | mock.P{X: mock.U(99)}, 101 | }, 102 | }, 103 | { 104 | name: "Simple/MultiK/Degenerate", 105 | n: tree.New[mock.P](tree.O[mock.P]{ 106 | Data: []mock.P{ 107 | mock.P{X: mock.U(99), Data: "A"}, 108 | mock.P{X: mock.U(99), Data: "B"}, 109 | mock.P{X: mock.U(99), Data: "C"}, 110 | mock.P{X: mock.U(99), Data: "D"}, 111 | }, 112 | K: 1, 113 | N: 1, 114 | }), 115 | p: mock.U(100), 116 | k: 2, 117 | want: []mock.P{ 118 | // We don't care what data we get here, as long 119 | // as it's two of the input set. The ordering 120 | // was matched manually. 
121 | mock.P{X: mock.U(99), Data: "C"}, 122 | mock.P{X: mock.U(99), Data: "B"}, 123 | }, 124 | }, 125 | { 126 | name: "Simple/MultiK/2D/Degenerate", 127 | n: tree.New[mock.P](tree.O[mock.P]{ 128 | Data: []mock.P{ 129 | mock.P{X: mock.V(*vnd.New(99, 100)), Data: "A"}, 130 | mock.P{X: mock.V(*vnd.New(99, 100)), Data: "B"}, 131 | mock.P{X: mock.V(*vnd.New(99, 100)), Data: "C"}, 132 | mock.P{X: mock.V(*vnd.New(99, 100)), Data: "D"}, 133 | }, 134 | K: 2, 135 | N: 1, 136 | }), 137 | p: mock.V(*vnd.New(0, 0)), 138 | k: 2, 139 | want: []mock.P{ 140 | mock.P{X: mock.V(*vnd.New(99, 100)), Data: "C"}, 141 | mock.P{X: mock.V(*vnd.New(99, 100)), Data: "B"}, 142 | }, 143 | }, 144 | } 145 | 146 | for _, c := range configs { 147 | t.Run(c.name, func(t *testing.T) { 148 | got := KNN(c.n, c.p, c.k, func(mock.P) bool { return true }) 149 | if diff := cmp.Diff(c.want, got, cmpopts.SortSlices( 150 | func(p, q mock.P) bool { 151 | return vector.Less(p.P(), q.P()) 152 | })); diff != "" { 153 | t.Errorf("KNN mismatch (-want +got):\n%v", diff) 154 | } 155 | }) 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /internal/node/node.go: -------------------------------------------------------------------------------- 1 | package node 2 | 3 | import ( 4 | "github.com/downflux/go-geometry/nd/vector" 5 | "github.com/downflux/go-kd/point" 6 | ) 7 | 8 | type N[T point.P] interface { 9 | // L consists of points strictly less than the current pivot for the 10 | // current axis. 11 | L() N[T] 12 | 13 | // R consists of points greater than or equal to the current pivot for 14 | // the current axis. 15 | R() N[T] 16 | 17 | // Data returns the points stored in the current node -- note that this 18 | // does not include data from child nodes. 
19 | Data() []T 20 | 21 | Insert(p T) 22 | Remove(v vector.V, f func(p T) bool) (T, bool) 23 | 24 | Pivot() vector.V 25 | K() vector.D 26 | Axis() vector.D 27 | 28 | Leaf() bool 29 | Nil() bool 30 | } 31 | -------------------------------------------------------------------------------- /internal/node/tree/tree.go: -------------------------------------------------------------------------------- 1 | package tree 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | 7 | "github.com/downflux/go-kd/internal/node" 8 | "github.com/downflux/go-kd/point" 9 | "github.com/downflux/go-kd/vector" 10 | 11 | vnd "github.com/downflux/go-geometry/nd/vector" 12 | ) 13 | 14 | const ( 15 | NLargeData = 128 16 | ) 17 | 18 | type O[T point.P] struct { 19 | Data []T 20 | K vnd.D 21 | 22 | // N is the nominal leaf size of a node. 23 | N int 24 | 25 | Axis vnd.D 26 | 27 | // inorder, when set, skips the random shuffle New otherwise applies to 28 | // the incoming data before partitioning; this is a test-only option 29 | // useful for comparing expected states. 30 | inorder bool 31 | } 32 | 33 | type N[T point.P] struct { 34 | data []T 35 | 36 | k vnd.D 37 | pivot vnd.V 38 | axis vnd.D 39 | left *N[T] 40 | right *N[T] 41 | } 42 | 43 | func validate[T point.P](o O[T]) error { 44 | if o.Axis >= o.K { 45 | return fmt.Errorf("given node dimension greater than or equal to vnd dimension: %v >= %v", o.Axis, o.K) 46 | } 47 | if o.N < 1 { 48 | return fmt.Errorf("given leaf node size must be a positive integer") 49 | } 50 | return nil 51 | } 52 | 53 | // New recursively constructs a node object given the input data. 54 | // 55 | // An input of an empty dataset will result in a leaf node being returned. 56 | func New[T point.P](o O[T]) *N[T] { 57 | if err := validate(o); err != nil { 58 | panic(fmt.Sprintf("could not construct node: %v", err)) 59 | } 60 | 61 | // hoare is known to be quadratic in the worst case (i.e. when data is 62 | // in-sequence).
Therefore we want to ensure this is not the case by 63 | // randomly shuffling the data ordering first. 64 | if !o.inorder && len(o.Data) > 0 { 65 | rand.Shuffle(len(o.Data), func(i, j int) { o.Data[i], o.Data[j] = o.Data[j], o.Data[i] }) 66 | } 67 | 68 | if len(o.Data) <= o.N { 69 | return &N[T]{ 70 | data: o.Data, 71 | axis: o.Axis, 72 | k: o.K, 73 | } 74 | } 75 | pivot := hoare(o.Data, 0, 0, len(o.Data), func(a vnd.V, b vnd.V) bool { return a.X(o.Axis) < b.X(o.Axis) }) 76 | 77 | node := &N[T]{ 78 | data: []T{o.Data[pivot]}, 79 | pivot: o.Data[pivot].P(), 80 | axis: o.Axis, 81 | k: o.K, 82 | } 83 | 84 | ol := O[T]{ 85 | Data: o.Data[0:pivot], 86 | Axis: (o.Axis + 1) % o.K, 87 | K: o.K, 88 | N: o.N, 89 | // There is no need to continue shuffling data in child 90 | // nodes, as we are making the assumption the root 91 | // shuffle was random enough. 92 | inorder: true, 93 | } 94 | or := O[T]{ 95 | Data: o.Data[pivot+1 : len(o.Data)], 96 | Axis: (o.Axis + 1) % o.K, 97 | K: o.K, 98 | N: o.N, 99 | inorder: true, 100 | } 101 | 102 | // Channel overhead is too high relative to the stack when the dataset 103 | // is small. Quick experimentation has determined a value which seems 104 | // "good enough" to branch on, though this figure can probably be dialed 105 | // in more as architectures change. 106 | if len(o.Data) < NLargeData { 107 | if n := New[T](ol); len(n.Data()) > 0 { 108 | node.left = n 109 | } 110 | if n := New[T](or); len(n.Data()) > 0 { 111 | node.right = n 112 | } 113 | } else { 114 | // Node construction can be concurrent since we guarantee child nodes 115 | // will never access data across the high / low boundary. Note that this 116 | // does increase the number of allocs ~3x, and is less performant for 117 | // low data sizes (i.e. for less than ~10k points). 
118 | l := make(chan *N[T]) 119 | r := make(chan *N[T]) 120 | 121 | go func(ch chan<- *N[T]) { 122 | ch <- New[T](ol) 123 | close(ch) 124 | }(l) 125 | go func(ch chan<- *N[T]) { 126 | ch <- New[T](or) 127 | close(ch) 128 | }(r) 129 | 130 | // Skip adding child nodes if they do not contain data -- this prevents 131 | // extraneous leaves from being added. 132 | if n := <-l; len(n.Data()) > 0 { 133 | node.left = n 134 | } 135 | if n := <-r; len(n.Data()) > 0 { 136 | node.right = n 137 | } 138 | } 139 | return node 140 | } 141 | 142 | func (n *N[T]) Nil() bool { return n == nil } 143 | func (n *N[T]) L() node.N[T] { return n.left } 144 | func (n *N[T]) R() node.N[T] { return n.right } 145 | func (n *N[T]) Leaf() bool { return n.pivot == nil } 146 | func (n *N[T]) Axis() vnd.D { return n.axis } 147 | func (n *N[T]) Pivot() vnd.V { return n.pivot } 148 | func (n *N[T]) Data() []T { return n.data } 149 | func (n *N[T]) K() vnd.D { return n.k } 150 | 151 | func (n *N[T]) Insert(p T) { 152 | if n.Leaf() || vnd.Within(n.Pivot(), p.P()) { 153 | n.data = append(n.data, p) 154 | return 155 | } 156 | c := &n.right 157 | if vector.Comparator(n.Axis()).Less(p.P(), n.Pivot()) { 158 | c = &n.left 159 | } 160 | if *c == nil { // New skips empty children, so create one on demand. 161 | *c = &N[T]{k: n.k, axis: (n.axis + 1) % n.k} 162 | } 163 | (*c).Insert(p) 164 | } 165 | 166 | func (n *N[T]) Remove(v vnd.V, f func(p T) bool) (T, bool) { 167 | var blank T 168 | if n == nil { // A child skipped by New holds nothing to remove. 169 | return blank, false 170 | } 171 | if n.Leaf() || vnd.Within(n.Pivot(), v) { 172 | for i, p := range n.Data() { 173 | if f(p) { 174 | n.data[i], n.data[len(n.data)-1] = n.data[len(n.data)-1], blank 175 | n.data = n.data[:len(n.data)-1] 176 | return p, true 177 | } 178 | } 179 | return blank, false 180 | } 181 | if vector.Comparator(n.Axis()).Less(v, n.Pivot()) { 182 | return n.L().Remove(v, f) 183 | } 184 | return n.R().Remove(v, f) 185 | } 186 | 187 | // hoare partitions the input data by the pivot. 188 | // 189 | // N.B.: The high index is exclusive -- that is, when partitioning an entire 190 | // array, high should be set to len(data).
191 | func hoare[T point.P](data []T, pivot int, low int, high int, less func(a vnd.V, b vnd.V) bool) int { 192 | if pivot < 0 || low < 0 || high < 0 || pivot >= len(data) || low >= len(data) || high > len(data) { 193 | return -1 194 | } 195 | 196 | // hoare partitioning requires the pivot at the beginning of the array. 197 | data[pivot], data[low] = data[low], data[pivot] 198 | 199 | // i and j are the left and right tracker indices, respectively. i is 200 | // strictly increasing, while j is strictly decreasing and never leaves the partition. 201 | i := low + 1 202 | j := high - 1 203 | 204 | for i <= j { 205 | // Skip array elements which are already sorted. Note that we 206 | // are partitioning such that 207 | // 208 | // l < p <= r 209 | for ; less(data[i].P(), data[low].P()) && i < j; i++ {} 210 | for ; !less(data[j].P(), data[low].P()) && j > low; j-- {} 211 | 212 | if i > j { 213 | break 214 | } 215 | 216 | data[i], data[j] = data[j], data[i] 217 | 218 | i++ 219 | j-- 220 | } 221 | 222 | // Since the pivot is stored at the beginning of the array, we need to 223 | // do a final swap to ensure the pivot is at the right position.
224 | data[low], data[i-1] = data[i-1], data[low] 225 | return i - 1 226 | } 227 | -------------------------------------------------------------------------------- /internal/node/tree/tree_test.go: -------------------------------------------------------------------------------- 1 | package tree 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/downflux/go-kd/internal/node" 7 | "github.com/downflux/go-kd/internal/node/util" 8 | "github.com/downflux/go-kd/point/mock" 9 | "github.com/downflux/go-kd/vector" 10 | "github.com/google/go-cmp/cmp" 11 | 12 | vnd "github.com/downflux/go-geometry/nd/vector" 13 | ) 14 | 15 | var _ node.N[mock.P] = &N[mock.P]{} 16 | 17 | func TestInsert(t *testing.T) { 18 | type config struct { 19 | name string 20 | opts O[mock.P] 21 | ps []mock.P 22 | 23 | want *N[mock.P] 24 | } 25 | 26 | configs := []config{ 27 | { 28 | name: "Nil", 29 | opts: O[mock.P]{ 30 | Data: nil, 31 | K: 1, 32 | N: 1, 33 | inorder: true, 34 | }, 35 | ps: []mock.P{ 36 | mock.P{X: mock.U(1)}, 37 | mock.P{X: mock.U(-50)}, 38 | mock.P{X: mock.U(100)}, 39 | }, 40 | want: &N[mock.P]{ 41 | k: 1, 42 | axis: 0, 43 | data: []mock.P{ 44 | mock.P{X: mock.U(1)}, 45 | mock.P{X: mock.U(-50)}, 46 | mock.P{X: mock.U(100)}, 47 | }, 48 | }, 49 | }, 50 | { 51 | name: "L", 52 | opts: O[mock.P]{ 53 | Data: []mock.P{ 54 | mock.P{X: mock.U(1)}, 55 | mock.P{X: mock.U(-50)}, 56 | }, 57 | K: 1, 58 | N: 1, 59 | inorder: true, 60 | }, 61 | ps: []mock.P{ 62 | mock.P{X: mock.U(0)}, 63 | mock.P{X: mock.U(-55)}, 64 | }, 65 | want: &N[mock.P]{ 66 | k: 1, 67 | axis: 0, 68 | pivot: mock.U(1), 69 | data: []mock.P{ 70 | mock.P{X: mock.U(1)}, 71 | }, 72 | left: &N[mock.P]{ 73 | k: 1, 74 | axis: 0, 75 | data: []mock.P{ 76 | mock.P{X: mock.U(-50)}, 77 | mock.P{X: mock.U(0)}, 78 | mock.P{X: mock.U(-55)}, 79 | }, 80 | }, 81 | }, 82 | }, 83 | { 84 | name: "R", 85 | opts: O[mock.P]{ 86 | Data: []mock.P{ 87 | mock.P{X: mock.U(1)}, 88 | mock.P{X: mock.U(50)}, 89 | }, 90 | K: 1, 91 | N: 1, 92 | inorder: true, 93 | 
}, 94 | ps: []mock.P{ 95 | mock.P{X: mock.U(2)}, 96 | mock.P{X: mock.U(100)}, 97 | }, 98 | want: &N[mock.P]{ 99 | k: 1, 100 | axis: 0, 101 | pivot: mock.U(1), 102 | data: []mock.P{ 103 | mock.P{X: mock.U(1)}, 104 | }, 105 | right: &N[mock.P]{ 106 | k: 1, 107 | axis: 0, 108 | data: []mock.P{ 109 | mock.P{X: mock.U(50)}, 110 | mock.P{X: mock.U(2)}, 111 | mock.P{X: mock.U(100)}, 112 | }, 113 | }, 114 | }, 115 | }, 116 | { 117 | name: "Pivot", 118 | opts: O[mock.P]{ 119 | Data: []mock.P{ 120 | mock.P{X: mock.U(1)}, 121 | mock.P{X: mock.U(50)}, 122 | }, 123 | K: 1, 124 | N: 1, 125 | inorder: true, 126 | }, 127 | ps: []mock.P{ 128 | mock.P{X: mock.U(1), Data: "B"}, 129 | mock.P{X: mock.U(1), Data: "C"}, 130 | }, 131 | want: &N[mock.P]{ 132 | k: 1, 133 | axis: 0, 134 | pivot: mock.U(1), 135 | data: []mock.P{ 136 | mock.P{X: mock.U(1)}, 137 | mock.P{X: mock.U(1), Data: "B"}, 138 | mock.P{X: mock.U(1), Data: "C"}, 139 | }, 140 | right: &N[mock.P]{ 141 | k: 1, 142 | axis: 0, 143 | data: []mock.P{ 144 | mock.P{X: mock.U(50)}, 145 | }, 146 | }, 147 | }, 148 | }, 149 | { 150 | name: "L/LargeK", 151 | opts: O[mock.P]{ 152 | Data: []mock.P{ 153 | mock.P{X: mock.V([]float64{1, 40})}, 154 | mock.P{X: mock.V([]float64{-50, 70})}, 155 | }, 156 | K: 2, 157 | N: 1, 158 | inorder: true, 159 | }, 160 | ps: []mock.P{ 161 | mock.P{X: mock.V([]float64{-55, 100})}, 162 | mock.P{X: mock.V([]float64{0, 2})}, 163 | }, 164 | want: &N[mock.P]{ 165 | k: 2, 166 | axis: 0, 167 | pivot: mock.V([]float64{1, 40}), 168 | data: []mock.P{ 169 | mock.P{X: mock.V([]float64{1, 40})}, 170 | }, 171 | left: &N[mock.P]{ 172 | k: 2, 173 | axis: 1, 174 | data: []mock.P{ 175 | mock.P{X: mock.V([]float64{-50, 70})}, 176 | mock.P{X: mock.V([]float64{-55, 100})}, 177 | mock.P{X: mock.V([]float64{0, 2})}, 178 | }, 179 | }, 180 | }, 181 | }, 182 | { 183 | name: "R/LargeK", 184 | opts: O[mock.P]{ 185 | Data: []mock.P{ 186 | mock.P{X: mock.V([]float64{-50, 70})}, 187 | mock.P{X: mock.V([]float64{1, 40})}, 188 | }, 189 | 
K: 2, 190 | N: 1, 191 | inorder: true, 192 | }, 193 | ps: []mock.P{ 194 | mock.P{X: mock.V([]float64{-49, 100})}, 195 | mock.P{X: mock.V([]float64{1, 100})}, 196 | mock.P{X: mock.V([]float64{49, 2})}, 197 | }, 198 | want: &N[mock.P]{ 199 | k: 2, 200 | axis: 0, 201 | pivot: mock.V([]float64{-50, 70}), 202 | data: []mock.P{ 203 | mock.P{X: mock.V([]float64{-50, 70})}, 204 | }, 205 | right: &N[mock.P]{ 206 | k: 2, 207 | axis: 1, 208 | data: []mock.P{ 209 | mock.P{X: mock.V([]float64{1, 40})}, 210 | mock.P{X: mock.V([]float64{-49, 100})}, 211 | mock.P{X: mock.V([]float64{1, 100})}, 212 | mock.P{X: mock.V([]float64{49, 2})}, 213 | }, 214 | }, 215 | }, 216 | }, 217 | } 218 | 219 | for _, c := range configs { 220 | t.Run(c.name, func(t *testing.T) { 221 | kd := New(c.opts) 222 | for _, p := range c.ps { 223 | kd.Insert(p) 224 | } 225 | 226 | if diff := cmp.Diff(c.want, kd, cmp.AllowUnexported(N[mock.P]{})); diff != "" { 227 | t.Errorf("Insert() mismatch(-want +got):\n%v", diff) 228 | } 229 | }) 230 | } 231 | } 232 | 233 | func TestRemove(t *testing.T) { 234 | type config struct { 235 | name string 236 | opts O[mock.P] 237 | ps []mock.P 238 | 239 | want *N[mock.P] 240 | } 241 | 242 | configs := []config{ 243 | { 244 | name: "Nil", 245 | opts: O[mock.P]{ 246 | Data: nil, 247 | K: 1, 248 | N: 1, 249 | inorder: true, 250 | }, 251 | ps: []mock.P{ 252 | mock.P{X: mock.U(1)}, 253 | }, 254 | want: &N[mock.P]{ 255 | k: 1, 256 | }, 257 | }, 258 | { 259 | name: "Simple", 260 | opts: O[mock.P]{ 261 | Data: []mock.P{ 262 | mock.P{X: mock.U(100)}, 263 | }, 264 | K: 1, 265 | N: 1, 266 | inorder: true, 267 | }, 268 | ps: []mock.P{ 269 | mock.P{X: mock.U(100)}, 270 | }, 271 | want: &N[mock.P]{ 272 | k: 1, 273 | axis: 0, 274 | data: []mock.P{}, 275 | }, 276 | }, 277 | { 278 | name: "L", 279 | opts: O[mock.P]{ 280 | Data: []mock.P{ 281 | mock.P{X: mock.U(100)}, 282 | mock.P{X: mock.U(-50)}, 283 | }, 284 | K: 1, 285 | N: 1, 286 | inorder: true, 287 | }, 288 | ps: []mock.P{ 289 | mock.P{X: 
mock.U(-50)}, 290 | }, 291 | want: &N[mock.P]{ 292 | k: 1, 293 | pivot: mock.U(100), 294 | data: []mock.P{ 295 | mock.P{X: mock.U(100)}, 296 | }, 297 | axis: 0, 298 | left: &N[mock.P]{ 299 | k: 1, 300 | axis: 0, 301 | data: []mock.P{}, 302 | }, 303 | }, 304 | }, 305 | { 306 | name: "L/LargeK", 307 | opts: O[mock.P]{ 308 | Data: []mock.P{ 309 | mock.P{X: mock.V([]float64{100, 1})}, 310 | mock.P{X: mock.V([]float64{-50, 100})}, 311 | }, 312 | K: 2, 313 | N: 1, 314 | inorder: true, 315 | }, 316 | ps: []mock.P{ 317 | mock.P{X: mock.V([]float64{-50, 100})}, 318 | }, 319 | want: &N[mock.P]{ 320 | k: 2, 321 | pivot: mock.V([]float64{100, 1}), 322 | data: []mock.P{ 323 | mock.P{X: mock.V([]float64{100, 1})}, 324 | }, 325 | axis: 0, 326 | left: &N[mock.P]{ 327 | k: 2, 328 | axis: 1, 329 | data: []mock.P{}, 330 | }, 331 | }, 332 | }, 333 | { 334 | name: "R", 335 | opts: O[mock.P]{ 336 | Data: []mock.P{ 337 | mock.P{X: mock.U(-50)}, 338 | mock.P{X: mock.U(100)}, 339 | }, 340 | K: 1, 341 | N: 1, 342 | inorder: true, 343 | }, 344 | ps: []mock.P{ 345 | mock.P{X: mock.U(100)}, 346 | }, 347 | want: &N[mock.P]{ 348 | k: 1, 349 | pivot: mock.U(-50), 350 | data: []mock.P{ 351 | mock.P{X: mock.U(-50)}, 352 | }, 353 | axis: 0, 354 | right: &N[mock.P]{ 355 | k: 1, 356 | axis: 0, 357 | data: []mock.P{}, 358 | }, 359 | }, 360 | }, 361 | { 362 | name: "R/LargeK", 363 | opts: O[mock.P]{ 364 | Data: []mock.P{ 365 | mock.P{X: mock.V([]float64{-50, 100})}, 366 | mock.P{X: mock.V([]float64{-50, 101})}, 367 | mock.P{X: mock.V([]float64{100, 500})}, 368 | }, 369 | K: 2, 370 | N: 1, 371 | inorder: true, 372 | }, 373 | ps: []mock.P{ 374 | mock.P{X: mock.V([]float64{-50, 101})}, 375 | mock.P{X: mock.V([]float64{100, 500})}, 376 | }, 377 | want: &N[mock.P]{ 378 | k: 2, 379 | pivot: mock.V([]float64{-50, 100}), 380 | data: []mock.P{ 381 | mock.P{X: mock.V([]float64{-50, 100})}, 382 | }, 383 | axis: 0, 384 | right: &N[mock.P]{ 385 | k: 2, 386 | axis: 1, 387 | pivot: mock.V([]float64{-50, 101}), 388 | 
data: []mock.P{}, 389 | right: &N[mock.P]{ 390 | k: 2, 391 | axis: 0, 392 | data: []mock.P{}, 393 | }, 394 | }, 395 | }, 396 | }, 397 | { 398 | name: "Pivot", 399 | opts: O[mock.P]{ 400 | Data: []mock.P{ 401 | mock.P{X: mock.U(-50)}, 402 | mock.P{X: mock.U(100)}, 403 | }, 404 | K: 1, 405 | N: 1, 406 | inorder: true, 407 | }, 408 | ps: []mock.P{ 409 | mock.P{X: mock.U(-50)}, 410 | }, 411 | want: &N[mock.P]{ 412 | k: 1, 413 | pivot: mock.U(-50), 414 | data: []mock.P{}, 415 | axis: 0, 416 | right: &N[mock.P]{ 417 | k: 1, 418 | axis: 0, 419 | data: []mock.P{ 420 | mock.P{X: mock.U(100)}, 421 | }, 422 | }, 423 | }, 424 | }, 425 | } 426 | 427 | for _, c := range configs { 428 | t.Run(c.name, func(t *testing.T) { 429 | kd := New(c.opts) 430 | for _, p := range c.ps { 431 | kd.Remove(p.P(), func(q mock.P) bool { return mock.Equal(p, q) }) 432 | } 433 | 434 | if diff := cmp.Diff(c.want, kd, cmp.AllowUnexported(N[mock.P]{})); diff != "" { 435 | t.Errorf("Remove() mismatch(-want +got):\n%v", diff) 436 | } 437 | }) 438 | } 439 | } 440 | 441 | func TestNew(t *testing.T) { 442 | type config struct { 443 | name string 444 | opts O[mock.P] 445 | 446 | want *N[mock.P] 447 | } 448 | 449 | configs := []config{ 450 | { 451 | name: "NullNode", 452 | opts: O[mock.P]{ 453 | Data: nil, 454 | K: 2, 455 | N: 1, 456 | Axis: 0, 457 | inorder: true, 458 | }, 459 | want: &N[mock.P]{ 460 | k: 2, 461 | }, 462 | }, 463 | { 464 | name: "SingleElement", 465 | opts: O[mock.P]{ 466 | Data: []mock.P{ 467 | { 468 | X: mock.U(1), 469 | Data: "foo", 470 | }, 471 | }, 472 | K: 1, 473 | N: 1, 474 | Axis: 0, 475 | inorder: true, 476 | }, 477 | want: &N[mock.P]{ 478 | data: []mock.P{ 479 | { 480 | X: mock.U(1), 481 | Data: "foo", 482 | }, 483 | }, 484 | axis: 0, 485 | k: 1, 486 | }, 487 | }, 488 | { 489 | name: "DoubleElement", 490 | opts: O[mock.P]{ 491 | Data: []mock.P{ 492 | { 493 | X: mock.U(1), 494 | Data: "bar", 495 | }, 496 | { 497 | X: mock.U(-100), 498 | Data: "foo", 499 | }, 500 | }, 501 | K: 1, 502 
| N: 1, 503 | Axis: 0, 504 | inorder: true, 505 | }, 506 | want: &N[mock.P]{ 507 | data: []mock.P{ 508 | { 509 | X: mock.U(1), 510 | Data: "bar", 511 | }, 512 | }, 513 | pivot: mock.U(1), 514 | k: 1, 515 | axis: 0, 516 | left: &N[mock.P]{ 517 | data: []mock.P{ 518 | { 519 | X: mock.U(-100), 520 | Data: "foo", 521 | }, 522 | }, 523 | k: 1, 524 | axis: 0, 525 | }, 526 | }, 527 | }, 528 | { 529 | // Check that elements right of the pivot are greater 530 | // than or equal on the same axis. 531 | name: "Equal/Right", 532 | opts: O[mock.P]{ 533 | Data: []mock.P{ 534 | mock.P{ 535 | X: mock.U(100), 536 | Data: "B", 537 | }, 538 | mock.P{ 539 | X: mock.U(100), 540 | Data: "A", 541 | }, 542 | }, 543 | K: 1, 544 | N: 1, 545 | Axis: 0, 546 | inorder: true, 547 | }, 548 | want: &N[mock.P]{ 549 | data: []mock.P{ 550 | 551 | mock.P{ 552 | X: mock.U(100), 553 | Data: "B", 554 | }, 555 | }, 556 | k: 1, 557 | pivot: mock.U(100), 558 | axis: 0, 559 | right: &N[mock.P]{ 560 | data: []mock.P{ 561 | mock.P{ 562 | X: mock.U(100), 563 | Data: "A", 564 | }, 565 | }, 566 | k: 1, 567 | axis: 0, 568 | }, 569 | }, 570 | }, 571 | { 572 | name: "TripleElement/Unbalanced/BigLeaf", 573 | opts: O[mock.P]{ 574 | Data: []mock.P{ 575 | { 576 | X: mock.U(-100), 577 | Data: "foo", 578 | }, 579 | { 580 | X: mock.U(1), 581 | Data: "bar", 582 | }, 583 | { 584 | X: mock.U(0), 585 | Data: "baz", 586 | }, 587 | }, 588 | K: 1, 589 | N: 2, 590 | Axis: 0, 591 | inorder: true, 592 | }, 593 | want: &N[mock.P]{ 594 | data: []mock.P{ 595 | { 596 | X: mock.U(-100), 597 | Data: "foo", 598 | }, 599 | }, 600 | k: 1, 601 | pivot: mock.U(-100), 602 | axis: 0, 603 | right: &N[mock.P]{ 604 | data: []mock.P{ 605 | { 606 | X: mock.U(1), 607 | Data: "bar", 608 | }, 609 | { 610 | X: mock.U(0), 611 | Data: "baz", 612 | }, 613 | }, 614 | k: 1, 615 | axis: 0, 616 | }, 617 | }, 618 | }, 619 | { 620 | name: "TripleElement/Unbalanced/BigLeaf/BigK", 621 | opts: O[mock.P]{ 622 | Data: []mock.P{ 623 | { 624 | X: mock.V(*vnd.New(-100, 
1)), 625 | Data: "foo", 626 | }, 627 | { 628 | X: mock.V(*vnd.New(1, 50)), 629 | Data: "bar", 630 | }, 631 | { 632 | X: mock.V(*vnd.New(0, 75)), 633 | Data: "baz", 634 | }, 635 | }, 636 | K: 2, 637 | N: 2, 638 | Axis: 0, 639 | inorder: true, 640 | }, 641 | want: &N[mock.P]{ 642 | data: []mock.P{ 643 | { 644 | X: mock.V(*vnd.New(-100, 1)), 645 | Data: "foo", 646 | }, 647 | }, 648 | k: 2, 649 | pivot: mock.V(*vnd.New(-100, 1)), 650 | axis: 0, 651 | right: &N[mock.P]{ 652 | data: []mock.P{ 653 | { 654 | X: mock.V(*vnd.New(1, 50)), 655 | Data: "bar", 656 | }, 657 | { 658 | X: mock.V(*vnd.New(0, 75)), 659 | Data: "baz", 660 | }, 661 | }, 662 | k: 2, 663 | axis: 1, 664 | }, 665 | }, 666 | }, 667 | { 668 | name: "TripleElement/Unbalanced", 669 | opts: O[mock.P]{ 670 | Data: []mock.P{ 671 | { 672 | X: mock.U(-100), 673 | Data: "foo", 674 | }, 675 | { 676 | X: mock.U(1), 677 | Data: "bar", 678 | }, 679 | { 680 | X: mock.U(0), 681 | Data: "baz", 682 | }, 683 | }, 684 | K: 1, 685 | N: 1, 686 | Axis: 0, 687 | inorder: true, 688 | }, 689 | want: &N[mock.P]{ 690 | data: []mock.P{ 691 | { 692 | X: mock.U(-100), 693 | Data: "foo", 694 | }, 695 | }, 696 | k: 1, 697 | pivot: mock.U(-100), 698 | axis: 0, 699 | right: &N[mock.P]{ 700 | data: []mock.P{ 701 | { 702 | X: mock.U(1), 703 | Data: "bar", 704 | }, 705 | }, 706 | k: 1, 707 | pivot: mock.U(1), 708 | axis: 0, 709 | left: &N[mock.P]{ 710 | data: []mock.P{ 711 | { 712 | X: mock.U(0), 713 | Data: "baz", 714 | }, 715 | }, 716 | k: 1, 717 | axis: 0, 718 | }, 719 | }, 720 | }, 721 | }, 722 | { 723 | name: "TripleElement/Unbalanced/BigK", 724 | opts: O[mock.P]{ 725 | Data: []mock.P{ 726 | { 727 | X: mock.V(*vnd.New(-100, 1)), 728 | Data: "foo", 729 | }, 730 | { 731 | X: mock.V(*vnd.New(1, 50)), 732 | Data: "bar", 733 | }, 734 | { 735 | X: mock.V(*vnd.New(0, 75)), 736 | Data: "baz", 737 | }, 738 | }, 739 | K: 2, 740 | N: 1, 741 | Axis: 0, 742 | inorder: true, 743 | }, 744 | want: &N[mock.P]{ 745 | pivot: mock.V(*vnd.New(-100, 1)), 746 | 
data: []mock.P{ 747 | { 748 | X: mock.V(*vnd.New(-100, 1)), 749 | Data: "foo", 750 | }, 751 | }, 752 | k: 2, 753 | axis: 0, 754 | right: &N[mock.P]{ 755 | pivot: mock.V(*vnd.New(1, 50)), 756 | data: []mock.P{ 757 | { 758 | X: mock.V(*vnd.New(1, 50)), 759 | Data: "bar", 760 | }, 761 | }, 762 | k: 2, 763 | axis: 1, 764 | right: &N[mock.P]{ 765 | data: []mock.P{ 766 | { 767 | X: mock.V(*vnd.New(0, 75)), 768 | Data: "baz", 769 | }, 770 | }, 771 | k: 2, 772 | axis: 0, 773 | }, 774 | }, 775 | }, 776 | }, 777 | } 778 | 779 | for _, c := range configs { 780 | t.Run(c.name, func(t *testing.T) { 781 | got := New[mock.P](c.opts) 782 | if diff := cmp.Diff(c.want, got, cmp.AllowUnexported(N[mock.P]{})); diff != "" { 783 | t.Errorf("New() mismatch (-want, +got):\n%v", diff) 784 | } 785 | 786 | if util.Validate[mock.P](got) != true { 787 | t.Errorf("Validate() = %v, want = %v", false, true) 788 | } 789 | }) 790 | } 791 | } 792 | 793 | func TestHoare(t *testing.T) { 794 | type result struct { 795 | data []mock.P 796 | pivot int 797 | } 798 | 799 | type config struct { 800 | name string 801 | 802 | data []mock.P 803 | pivot int 804 | low int 805 | high int 806 | less func(a vnd.V, b vnd.V) bool 807 | 808 | want result 809 | } 810 | 811 | configs := []config{ 812 | { 813 | name: "Trivial", 814 | data: []mock.P{ 815 | mock.P{ 816 | X: mock.V(*vnd.New(100, 80)), 817 | Data: "foo", 818 | }, 819 | }, 820 | pivot: 0, 821 | low: 0, 822 | high: 1, 823 | less: vector.Comparator(vnd.AXIS_X).Less, 824 | want: result{ 825 | data: []mock.P{ 826 | mock.P{ 827 | X: mock.V(*vnd.New(100, 80)), 828 | Data: "foo", 829 | }, 830 | }, 831 | pivot: 0, 832 | }, 833 | }, 834 | { 835 | name: "Simple/NoSwap", 836 | data: []mock.P{ 837 | mock.P{ 838 | X: mock.U(0), 839 | Data: "foo", 840 | }, 841 | mock.P{ 842 | X: mock.U(1), 843 | Data: "bar", 844 | }, 845 | }, 846 | pivot: 0, 847 | low: 0, 848 | high: 2, 849 | less: vector.Comparator(vnd.AXIS_X).Less, 850 | want: result{ 851 | data: []mock.P{ 852 | mock.P{ 
853 | X: mock.U(0), 854 | Data: "foo", 855 | }, 856 | mock.P{ 857 | X: mock.U(1), 858 | Data: "bar", 859 | }, 860 | }, 861 | pivot: 0, 862 | }, 863 | }, 864 | { 865 | name: "Simple/Swap", 866 | data: []mock.P{ 867 | mock.P{ 868 | X: mock.U(1), 869 | Data: "bar", 870 | }, 871 | mock.P{ 872 | X: mock.U(0), 873 | Data: "foo", 874 | }, 875 | }, 876 | pivot: 0, 877 | low: 0, 878 | high: 2, 879 | less: vector.Comparator(vnd.AXIS_X).Less, 880 | want: result{ 881 | data: []mock.P{ 882 | mock.P{ 883 | X: mock.U(0), 884 | Data: "foo", 885 | }, 886 | mock.P{ 887 | X: mock.U(1), 888 | Data: "bar", 889 | }, 890 | }, 891 | pivot: 1, 892 | }, 893 | }, 894 | { 895 | name: "Pivot", 896 | data: []mock.P{ 897 | mock.P{ 898 | X: mock.U(100), 899 | Data: "2", 900 | }, 901 | mock.P{ 902 | X: mock.U(0), 903 | Data: "0", 904 | }, 905 | mock.P{ 906 | X: mock.U(50), 907 | Data: "1", 908 | }, 909 | }, 910 | pivot: 1, 911 | low: 0, 912 | high: 3, 913 | less: vector.Comparator(vnd.AXIS_X).Less, 914 | want: result{ 915 | data: []mock.P{ 916 | mock.P{ 917 | X: mock.U(0), 918 | Data: "0", 919 | }, 920 | mock.P{ 921 | X: mock.U(100), 922 | Data: "2", 923 | }, 924 | mock.P{ 925 | X: mock.U(50), 926 | Data: "1", 927 | }, 928 | }, 929 | pivot: 0, 930 | }, 931 | }, 932 | { 933 | name: "Pivot/Partial", 934 | data: []mock.P{ 935 | mock.P{ 936 | X: mock.U(100), 937 | Data: "2", 938 | }, 939 | mock.P{ 940 | X: mock.U(0), 941 | Data: "0", 942 | }, 943 | mock.P{ 944 | X: mock.U(50), 945 | Data: "1", 946 | }, 947 | }, 948 | pivot: 2, 949 | low: 1, 950 | high: 3, 951 | less: vector.Comparator(vnd.AXIS_X).Less, 952 | want: result{ 953 | data: []mock.P{ 954 | mock.P{ 955 | X: mock.U(100), 956 | Data: "2", 957 | }, 958 | mock.P{ 959 | X: mock.U(0), 960 | Data: "0", 961 | }, 962 | mock.P{ 963 | X: mock.U(50), 964 | Data: "1", 965 | }, 966 | }, 967 | pivot: 2, 968 | }, 969 | }, 970 | 971 | { 972 | name: "Pivot/Equal", 973 | data: []mock.P{ 974 | mock.P{X: mock.U(0)}, 975 | mock.P{ 976 | X: mock.U(100), 977 | Data: 
"B", 978 | }, 979 | mock.P{X: mock.U(50)}, 980 | mock.P{X: mock.U(150)}, 981 | mock.P{ 982 | X: mock.U(100), 983 | Data: "A", 984 | }, 985 | }, 986 | pivot: 1, 987 | low: 0, 988 | high: 5, 989 | less: vector.Comparator(vnd.AXIS_X).Less, 990 | want: result{ 991 | data: []mock.P{ 992 | mock.P{X: mock.U(50)}, 993 | mock.P{X: mock.U(0)}, 994 | mock.P{ 995 | X: mock.U(100), 996 | Data: "B", 997 | }, 998 | mock.P{X: mock.U(150)}, 999 | mock.P{ 1000 | X: mock.U(100), 1001 | Data: "A", 1002 | }, 1003 | }, 1004 | pivot: 2, 1005 | }, 1006 | }, 1007 | } 1008 | 1009 | for _, c := range configs { 1010 | t.Run(c.name, func(t *testing.T) { 1011 | if got := hoare(c.data, c.pivot, c.low, c.high, c.less); got != c.want.pivot { 1012 | t.Errorf("hoare() = %v, want = %v", got, c.want.pivot) 1013 | } 1014 | 1015 | if diff := cmp.Diff(c.want.data, c.data); diff != "" { 1016 | t.Errorf("hoare() mismatch (-want +got):\n%v", diff) 1017 | } 1018 | }) 1019 | } 1020 | } 1021 | -------------------------------------------------------------------------------- /internal/node/util/util.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "github.com/downflux/go-kd/internal/node" 5 | "github.com/downflux/go-kd/point" 6 | "github.com/downflux/go-kd/vector" 7 | ) 8 | 9 | func Map[T point.P](n node.N[T], f func(n node.N[T])) { 10 | open := []node.N[T]{n} 11 | for len(open) > 0 { 12 | var n node.N[T] 13 | n, open = open[0], open[1:] 14 | 15 | if n.Nil() { 16 | continue 17 | } 18 | 19 | if !n.L().Nil() { 20 | open = append(open, n.L()) 21 | } 22 | if !n.R().Nil() { 23 | open = append(open, n.R()) 24 | } 25 | 26 | f(n) 27 | } 28 | } 29 | 30 | func Validate[T point.P](t node.N[T]) bool { 31 | equal := true 32 | f := func(n node.N[T]) { 33 | if n.Nil() { 34 | return 35 | } 36 | 37 | if !n.L().Nil() { 38 | for _, p := range n.L().Data() { 39 | equal = equal && vector.Comparator(n.Axis()).Less(p.P(), n.Pivot()) 40 | } 41 | } 42 | if 
!n.R().Nil() { 43 | for _, p := range n.R().Data() { 44 | equal = equal && !vector.Comparator(n.Axis()).Less(p.P(), n.Pivot()) 45 | } 46 | } 47 | } 48 | Map[T](t, f) 49 | return equal 50 | } 51 | -------------------------------------------------------------------------------- /internal/perf/perf_test.go: -------------------------------------------------------------------------------- 1 | // Package perf runs a suite of perf tests. 2 | // 3 | // CI tests are run against a smaller set of configurations in order to fit into 4 | // computational time constraints. To run the full set of tests (which make take 5 | // up to an hour), run 6 | // 7 | // go test github.com/downflux/go-kd/internal/perf \ 8 | // -bench . -benchmem -timeout=60m \ 9 | // -args -performance_test_size=large 10 | package perf 11 | 12 | import ( 13 | "flag" 14 | "fmt" 15 | "os" 16 | "testing" 17 | "unsafe" 18 | 19 | "github.com/downflux/go-geometry/nd/hyperrectangle" 20 | "github.com/downflux/go-geometry/nd/vector" 21 | "github.com/downflux/go-kd/container" 22 | "github.com/downflux/go-kd/container/bruteforce" 23 | "github.com/downflux/go-kd/container/kyroy" 24 | "github.com/downflux/go-kd/internal/perf/util" 25 | "github.com/downflux/go-kd/kd" 26 | "github.com/downflux/go-kd/point/mock" 27 | 28 | ckd "github.com/downflux/go-kd/container/kd" 29 | ) 30 | 31 | var ( 32 | SuiteSize = util.SizeSmall 33 | ) 34 | 35 | func TestMain(m *testing.M) { 36 | flag.Var(&SuiteSize, "performance_test_size", "performance test size, one of (unit | small | large)") 37 | flag.Parse() 38 | 39 | os.Exit(m.Run()) 40 | } 41 | 42 | func BenchmarkNew(b *testing.B) { 43 | type config struct { 44 | name string 45 | k vector.D 46 | n int 47 | 48 | // kyroy implementation does not take a leaf-size parameter. 
49 | kyroy bool 50 | 51 | size int 52 | } 53 | 54 | var configs []config 55 | for _, k := range SuiteSize.K() { 56 | for _, n := range SuiteSize.N() { 57 | configs = append(configs, config{ 58 | name: fmt.Sprintf("kyroy/K=%v/N=%v", k, n), 59 | k: k, 60 | n: n, 61 | kyroy: true, 62 | }) 63 | for _, size := range SuiteSize.LeafSize() { 64 | configs = append(configs, config{ 65 | name: fmt.Sprintf("Real/K=%v/N=%v/LeafSize=%v", k, n, size), 66 | k: k, 67 | n: n, 68 | size: size, 69 | }) 70 | } 71 | } 72 | } 73 | 74 | for _, c := range configs { 75 | ps := util.Generate(c.n, c.k) 76 | 77 | if c.kyroy { 78 | b.Run(c.name, func(b *testing.B) { 79 | for i := 0; i < b.N; i++ { 80 | kyroy.New[*mock.P](ps) 81 | } 82 | }) 83 | } else { 84 | b.Run(c.name, func(b *testing.B) { 85 | for i := 0; i < b.N; i++ { 86 | kd.New[*mock.P](kd.O[*mock.P]{ 87 | Data: ps, 88 | K: c.k, 89 | N: c.size, 90 | }) 91 | } 92 | }) 93 | } 94 | } 95 | } 96 | 97 | func BenchmarkKNN(b *testing.B) { 98 | type config struct { 99 | name string 100 | t container.C[*mock.P] 101 | p vector.V 102 | knn int 103 | } 104 | 105 | var configs []config 106 | for _, k := range SuiteSize.K() { 107 | for _, n := range SuiteSize.N() { 108 | ps := util.Generate(n, k) 109 | 110 | // Brute force approach sorts all data, meaning that the 111 | // KNN factor does not matter. 112 | configs = append(configs, config{ 113 | name: fmt.Sprintf("BruteForce/K=%v/N=%v", k, n), 114 | t: bruteforce.New[*mock.P](ps), 115 | p: vector.V(make([]float64, k)), 116 | knn: n, 117 | }) 118 | 119 | for _, f := range SuiteSize.F() { 120 | knn := int(float64(n) * f) 121 | 122 | // kyroy implementation does not take a 123 | // leaf-size parameter. 
124 | configs = append(configs, config{ 125 | name: fmt.Sprintf("kyroy/K=%v/N=%v/KNN=%v", k, n, f), 126 | t: kyroy.New[*mock.P](ps), 127 | p: vector.V(make([]float64, k)), 128 | knn: knn, 129 | }) 130 | 131 | for _, size := range SuiteSize.LeafSize() { 132 | configs = append(configs, config{ 133 | name: fmt.Sprintf("Real/K=%v/N=%v/LeafSize=%v/KNN=%v", k, n, size, f), 134 | t: (*ckd.KD[*mock.P])(unsafe.Pointer( 135 | kd.New[*mock.P](kd.O[*mock.P]{ 136 | Data: ps, 137 | K: k, 138 | N: size, 139 | }), 140 | )), 141 | p: vector.V(make([]float64, k)), 142 | knn: knn, 143 | }) 144 | } 145 | 146 | } 147 | } 148 | } 149 | 150 | for _, c := range configs { 151 | b.Run(c.name, func(b *testing.B) { 152 | for i := 0; i < b.N; i++ { 153 | c.t.KNN(c.p, c.knn, util.TrivialFilter) 154 | } 155 | }) 156 | } 157 | } 158 | 159 | func BenchmarkRangeSearch(b *testing.B) { 160 | type config struct { 161 | name string 162 | t container.C[*mock.P] 163 | q hyperrectangle.R 164 | } 165 | 166 | var configs []config 167 | for _, k := range SuiteSize.K() { 168 | for _, n := range SuiteSize.N() { 169 | ps := util.Generate(n, k) 170 | 171 | // Brute force approach sorts all data, meaning that the 172 | // query range factor does not matter. 173 | configs = append(configs, config{ 174 | name: fmt.Sprintf("BruteForce/K=%v/N=%v", k, n), 175 | t: bruteforce.New[*mock.P](ps), 176 | q: util.RH(k, 1), 177 | }) 178 | 179 | for _, f := range SuiteSize.F() { 180 | q := util.RH(k, f) 181 | 182 | // kyroy implementation does not take a 183 | // leaf-size parameter. 
184 | configs = append(configs, config{ 185 | name: fmt.Sprintf("kyroy/K=%v/N=%v/Coverage=%v", k, n, f), 186 | t: kyroy.New[*mock.P](ps), 187 | q: q, 188 | }) 189 | 190 | for _, size := range SuiteSize.LeafSize() { 191 | configs = append(configs, config{ 192 | name: fmt.Sprintf("Real/K=%v/N=%v/LeafSize=%v/Coverage=%v", k, n, size, f), 193 | t: (*ckd.KD[*mock.P])(unsafe.Pointer( 194 | kd.New[*mock.P](kd.O[*mock.P]{ 195 | Data: ps, 196 | K: k, 197 | N: size, 198 | }), 199 | )), 200 | q: q, 201 | }) 202 | } 203 | } 204 | } 205 | } 206 | 207 | for _, c := range configs { 208 | b.Run(c.name, func(b *testing.B) { 209 | for i := 0; i < b.N; i++ { 210 | c.t.RangeSearch(c.q, util.TrivialFilter) 211 | } 212 | }) 213 | } 214 | } 215 | -------------------------------------------------------------------------------- /internal/perf/results/v0.5.5.txt: -------------------------------------------------------------------------------- 1 | goos: linux 2 | goarch: amd64 3 | pkg: github.com/downflux/go-kd/internal/perf 4 | cpu: Intel(R) Core(TM) i7-6700K CPU @ 4.00GHz 5 | BenchmarkNew/kyroy/K=16/N=1000-8 1528 758980 ns/op 146777 B/op 2524 allocs/op 6 | BenchmarkNew/Real/K=16/N=1000/LeafSize=1-8 3805 276313 ns/op 126098 B/op 2089 allocs/op 7 | BenchmarkNew/Real/K=16/N=1000/LeafSize=16-8 6034 200749 ns/op 32637 B/op 420 allocs/op 8 | BenchmarkNew/Real/K=16/N=1000/LeafSize=256-8 10000 113851 ns/op 12155 B/op 63 allocs/op 9 | BenchmarkNew/kyroy/K=16/N=10000-8 79 15089373 ns/op 1674236 B/op 25924 allocs/op 10 | BenchmarkNew/Real/K=16/N=10000/LeafSize=1-8 514 2201218 ns/op 1263945 B/op 20928 allocs/op 11 | BenchmarkNew/Real/K=16/N=10000/LeafSize=16-8 751 1599132 ns/op 330730 B/op 4264 allocs/op 12 | BenchmarkNew/Real/K=16/N=10000/LeafSize=256-8 886 1273534 ns/op 125601 B/op 692 allocs/op 13 | BenchmarkNew/kyroy/K=16/N=1000000-8 1 7407144200 ns/op 184813784 B/op 2524327 allocs/op 14 | BenchmarkNew/Real/K=16/N=1000000/LeafSize=1-8 2 735249000 ns/op 127022260 B/op 2096135 allocs/op 15 | 
BenchmarkNew/Real/K=16/N=1000000/LeafSize=16-8 2 559409550 ns/op 33078812 B/op 428590 allocs/op 16 | BenchmarkNew/Real/K=16/N=1000000/LeafSize=256-8 2 588456300 ns/op 12462912 B/op 70330 allocs/op 17 | BenchmarkKNN/BruteForce/K=16/N=1000-8 956 1563019 ns/op 2220712 B/op 17165 allocs/op 18 | BenchmarkKNN/kyroy/K=16/N=1000/KNN=0.05-8 1501 791415 ns/op 21960 B/op 1116 allocs/op 19 | BenchmarkKNN/Real/K=16/N=1000/LeafSize=1/KNN=0.05-8 6880 176106 ns/op 37984 B/op 972 allocs/op 20 | BenchmarkKNN/Real/K=16/N=1000/LeafSize=16/KNN=0.05-8 17564 69537 ns/op 12024 B/op 330 allocs/op 21 | BenchmarkKNN/Real/K=16/N=1000/LeafSize=256/KNN=0.05-8 22638 53922 ns/op 6880 B/op 209 allocs/op 22 | BenchmarkKNN/kyroy/K=16/N=1000/KNN=0.1-8 996 1194847 ns/op 27880 B/op 1242 allocs/op 23 | BenchmarkKNN/Real/K=16/N=1000/LeafSize=1/KNN=0.1-8 6176 196038 ns/op 44184 B/op 1102 allocs/op 24 | BenchmarkKNN/Real/K=16/N=1000/LeafSize=16/KNN=0.1-8 10000 101893 ns/op 17896 B/op 489 allocs/op 25 | BenchmarkKNN/Real/K=16/N=1000/LeafSize=256/KNN=0.1-8 16645 70664 ns/op 10784 B/op 295 allocs/op 26 | BenchmarkKNN/BruteForce/K=16/N=10000-8 74 25007432 ns/op 30633256 B/op 236548 allocs/op 27 | BenchmarkKNN/kyroy/K=16/N=10000/KNN=0.05-8 37 30799189 ns/op 223040 B/op 10906 allocs/op 28 | BenchmarkKNN/Real/K=16/N=10000/LeafSize=1/KNN=0.05-8 654 2057458 ns/op 373568 B/op 9747 allocs/op 29 | BenchmarkKNN/Real/K=16/N=10000/LeafSize=16/KNN=0.05-8 1303 889883 ns/op 118112 B/op 3294 allocs/op 30 | BenchmarkKNN/Real/K=16/N=10000/LeafSize=256/KNN=0.05-8 1663 679360 ns/op 58024 B/op 1741 allocs/op 31 | BenchmarkKNN/kyroy/K=16/N=10000/KNN=0.1-8 13 91103708 ns/op 297008 B/op 12232 allocs/op 32 | BenchmarkKNN/Real/K=16/N=10000/LeafSize=1/KNN=0.1-8 562 2202105 ns/op 413840 B/op 10845 allocs/op 33 | BenchmarkKNN/Real/K=16/N=10000/LeafSize=16/KNN=0.1-8 961 1215787 ns/op 165600 B/op 4681 allocs/op 34 | BenchmarkKNN/Real/K=16/N=10000/LeafSize=256/KNN=0.1-8 1220 984166 ns/op 100272 B/op 2923 allocs/op 35 | 
BenchmarkKNN/BruteForce/K=16/N=1000000-8 1 5030811400 ns/op 5347687464 B/op 41453237 allocs/op 36 | BenchmarkKNN/kyroy/K=16/N=1000000/KNN=0.05-8 1 529703585200 ns/op 23755688 B/op 1107742 allocs/op 37 | BenchmarkKNN/Real/K=16/N=1000000/LeafSize=1/KNN=0.05-8 3 464044100 ns/op 36143720 B/op 1001542 allocs/op 38 | BenchmarkKNN/Real/K=16/N=1000000/LeafSize=16/KNN=0.05-8 3 347817233 ns/op 11420744 B/op 333388 allocs/op 39 | BenchmarkKNN/Real/K=16/N=1000000/LeafSize=256/KNN=0.05-8 3 335845533 ns/op 6044016 B/op 190971 allocs/op 40 | BenchmarkKNN/kyroy/K=16/N=1000000/KNN=0.1-8 1 1694060569900 ns/op 31972504 B/op 1237806 allocs/op 41 | BenchmarkKNN/Real/K=16/N=1000000/LeafSize=1/KNN=0.1-8 3 501073000 ns/op 40388328 B/op 1130901 allocs/op 42 | BenchmarkKNN/Real/K=16/N=1000000/LeafSize=16/KNN=0.1-8 3 394814333 ns/op 16062312 B/op 473830 allocs/op 43 | BenchmarkKNN/Real/K=16/N=1000000/LeafSize=256/KNN=0.1-8 3 365633867 ns/op 10085976 B/op 304736 allocs/op 44 | BenchmarkRangeSearch/BruteForce/K=16/N=1000-8 7825 154712 ns/op 25208 B/op 12 allocs/op 45 | BenchmarkRangeSearch/kyroy/K=16/N=1000/Coverage=0.05-8 89456 13373 ns/op 496 B/op 5 allocs/op 46 | BenchmarkRangeSearch/Real/K=16/N=1000/LeafSize=1/Coverage=0.05-8 5394 314928 ns/op 207113 B/op 1978 allocs/op 47 | BenchmarkRangeSearch/Real/K=16/N=1000/LeafSize=16/Coverage=0.05-8 7376 193276 ns/op 101603 B/op 970 allocs/op 48 | BenchmarkRangeSearch/Real/K=16/N=1000/LeafSize=256/Coverage=0.05-8 15967 75247 ns/op 21216 B/op 202 allocs/op 49 | BenchmarkRangeSearch/kyroy/K=16/N=1000/Coverage=0.1-8 58239 20985 ns/op 496 B/op 5 allocs/op 50 | BenchmarkRangeSearch/Real/K=16/N=1000/LeafSize=1/Coverage=0.1-8 5154 288420 ns/op 179478 B/op 1714 allocs/op 51 | BenchmarkRangeSearch/Real/K=16/N=1000/LeafSize=16/Coverage=0.1-8 6628 237190 ns/op 121699 B/op 1162 allocs/op 52 | BenchmarkRangeSearch/Real/K=16/N=1000/LeafSize=256/Coverage=0.1-8 16291 69358 ns/op 21216 B/op 202 allocs/op 53 | BenchmarkRangeSearch/BruteForce/K=16/N=10000-8 774 
1594897 ns/op 357624 B/op 19 allocs/op 54 | BenchmarkRangeSearch/kyroy/K=16/N=10000/Coverage=0.05-8 5510 205729 ns/op 496 B/op 5 allocs/op 55 | BenchmarkRangeSearch/Real/K=16/N=10000/LeafSize=1/Coverage=0.05-8 4323 332339 ns/op 202086 B/op 1930 allocs/op 56 | BenchmarkRangeSearch/Real/K=16/N=10000/LeafSize=16/Coverage=0.05-8 4491 336055 ns/op 141795 B/op 1354 allocs/op 57 | BenchmarkRangeSearch/Real/K=16/N=10000/LeafSize=256/Coverage=0.05-8 6256 187946 ns/op 46337 B/op 442 allocs/op 58 | BenchmarkRangeSearch/kyroy/K=16/N=10000/Coverage=0.1-8 3904 288862 ns/op 496 B/op 5 allocs/op 59 | BenchmarkRangeSearch/Real/K=16/N=10000/LeafSize=1/Coverage=0.1-8 643 2387566 ns/op 2355150 B/op 22500 allocs/op 60 | BenchmarkRangeSearch/Real/K=16/N=10000/LeafSize=16/Coverage=0.1-8 2816 614839 ns/op 523648 B/op 5002 allocs/op 61 | BenchmarkRangeSearch/Real/K=16/N=10000/LeafSize=256/Coverage=0.1-8 5074 258066 ns/op 101605 B/op 970 allocs/op 62 | BenchmarkRangeSearch/BruteForce/K=16/N=1000000-8 7 173427000 ns/op 41678072 B/op 38 allocs/op 63 | BenchmarkRangeSearch/kyroy/K=16/N=1000000/Coverage=0.05-8 20 56820240 ns/op 496 B/op 5 allocs/op 64 | BenchmarkRangeSearch/Real/K=16/N=1000000/LeafSize=1/Coverage=0.05-8 266 5463061 ns/op 5008653 B/op 47853 allocs/op 65 | BenchmarkRangeSearch/Real/K=16/N=1000000/LeafSize=16/Coverage=0.05-8 698 2587562 ns/op 2242039 B/op 21420 allocs/op 66 | BenchmarkRangeSearch/Real/K=16/N=1000000/LeafSize=256/Coverage=0.05-8 2593 530937 ns/op 212134 B/op 2026 allocs/op 67 | BenchmarkRangeSearch/kyroy/K=16/N=1000000/Coverage=0.1-8 15 76181887 ns/op 496 B/op 5 allocs/op 68 | BenchmarkRangeSearch/Real/K=16/N=1000000/LeafSize=1/Coverage=0.1-8 82 18895179 ns/op 17748509 B/op 169579 allocs/op 69 | BenchmarkRangeSearch/Real/K=16/N=1000000/LeafSize=16/Coverage=0.1-8 150 6825001 ns/op 6734713 B/op 64344 allocs/op 70 | BenchmarkRangeSearch/Real/K=16/N=1000000/LeafSize=256/Coverage=0.1-8 298 4691212 ns/op 2920521 B/op 27902 allocs/op 71 | PASS 72 | ok 
github.com/downflux/go-kd/internal/perf 2466.847s 73 | -------------------------------------------------------------------------------- /internal/perf/util/util.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "math/rand" 7 | "runtime" 8 | "sort" 9 | 10 | "github.com/downflux/go-geometry/nd/hyperrectangle" 11 | "github.com/downflux/go-geometry/nd/vector" 12 | "github.com/downflux/go-kd/point" 13 | "github.com/downflux/go-kd/point/mock" 14 | "github.com/google/go-cmp/cmp" 15 | ) 16 | 17 | type PerfTestSize int 18 | 19 | const ( 20 | SizeUnknown PerfTestSize = iota 21 | SizeUnit 22 | SizeSmall 23 | SizeLarge 24 | ) 25 | 26 | func (s *PerfTestSize) String() string { 27 | return map[PerfTestSize]string{ 28 | SizeUnit: "unit", 29 | SizeSmall: "small", 30 | SizeLarge: "large", 31 | }[*s] 32 | } 33 | 34 | func (s *PerfTestSize) Set(v string) error { 35 | size, ok := map[string]PerfTestSize{ 36 | "unit": SizeUnit, 37 | "small": SizeSmall, 38 | "large": SizeLarge, 39 | }[v] 40 | if !ok { 41 | return fmt.Errorf("invalid test size value: %v", v) 42 | } 43 | *s = size 44 | return nil 45 | } 46 | 47 | func (s PerfTestSize) F() []float64 { 48 | return map[PerfTestSize][]float64{ 49 | SizeUnit: []float64{0.05}, 50 | SizeSmall: []float64{0.05}, 51 | SizeLarge: []float64{0.05, 0.1}, 52 | }[s] 53 | } 54 | 55 | func (s PerfTestSize) LeafSize() []int { 56 | return map[PerfTestSize][]int{ 57 | SizeUnit: []int{1, 16}, 58 | SizeSmall: []int{1, 32, 512}, 59 | SizeLarge: []int{1, 16, 256}, 60 | }[s] 61 | } 62 | 63 | func (s PerfTestSize) N() []int { 64 | return map[PerfTestSize][]int{ 65 | SizeUnit: []int{1e3}, 66 | SizeSmall: []int{1e3, 1e4}, 67 | SizeLarge: []int{1e3, 1e4, 1e6}, 68 | }[s] 69 | } 70 | 71 | func (s PerfTestSize) K() []vector.D { 72 | return map[PerfTestSize][]vector.D{ 73 | SizeUnit: []vector.D{2}, 74 | SizeSmall: []vector.D{2, 16}, 75 | 76 | // Large tests phyically cannot store 
enough point data in 77 | // memory with high-dimensional data. 78 | SizeLarge: []vector.D{16}, 79 | }[s] 80 | } 81 | 82 | func TrivialFilter(p *mock.P) bool { return true } 83 | 84 | // Transformer sorts a list of points. 85 | func Transformer(p vector.V) cmp.Option { 86 | return cmp.Transformer("Sort", func(in []*mock.P) []*mock.P { 87 | out := append([]*mock.P(nil), in...) 88 | sort.Sort(L[*mock.P]{ 89 | Data: out, 90 | P: p, 91 | }) 92 | return out 93 | }) 94 | } 95 | 96 | type L[T point.P] struct { 97 | Data []T 98 | P vector.V 99 | } 100 | 101 | func (l L[T]) Len() int { return len(l.Data) } 102 | func (l L[T]) Swap(i, j int) { l.Data[i], l.Data[j] = l.Data[j], l.Data[i] } 103 | 104 | func (l L[T]) Less(i, j int) bool { 105 | return vector.SquaredMagnitude( 106 | vector.Sub(l.Data[i].P(), l.P), 107 | ) < vector.SquaredMagnitude( 108 | vector.Sub(l.Data[j].P(), l.P), 109 | ) 110 | } 111 | 112 | func RH(k vector.D, f float64) hyperrectangle.R { 113 | min := make([]float64, k) 114 | max := make([]float64, k) 115 | for i := vector.D(0); i < k; i++ { 116 | min[i] = -100 * math.Sqrt(f) 117 | max[i] = 100 * math.Sqrt(f) 118 | } 119 | return *hyperrectangle.New(vector.V(min), vector.V(max)) 120 | } 121 | 122 | func RV(k vector.D, min float64, max float64) vector.V { 123 | var xs []float64 124 | for i := 0; i < int(k); i++ { 125 | xs = append(xs, rand.Float64()*(max-min)+min) 126 | } 127 | return vector.V(xs) 128 | } 129 | 130 | func Generate(n int, k vector.D) []*mock.P { 131 | // Generating large number of points in tests will mess with data 132 | // collection figures. We should ignore these allocs. 
133 | // Deferred arguments are evaluated immediately, so this restores the 134 | // caller's rate (e.g. one set via -test.memprofilerate) on return. 135 | defer func(rate int) { runtime.MemProfileRate = rate }(runtime.MemProfileRate) 136 | runtime.MemProfileRate = 0 137 | 138 | var ps []*mock.P 139 | for i := 0; i < n; i++ { 140 | ps = append(ps, &mock.P{ 141 | X: RV(k, -100, 100), 142 | }) 143 | } 144 | 145 | return ps 146 | } 147 | -------------------------------------------------------------------------------- /internal/perf/util/util_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "flag" 5 | ) 6 | 7 | var ( 8 | s = SizeSmall 9 | _ flag.Value = &s 10 | ) 11 | -------------------------------------------------------------------------------- /internal/rangesearch/rangesearch.go: -------------------------------------------------------------------------------- 1 | package rangesearch 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/downflux/go-geometry/nd/hyperrectangle" 7 | "github.com/downflux/go-geometry/nd/vector" 8 | "github.com/downflux/go-kd/internal/node" 9 | "github.com/downflux/go-kd/point" 10 | ) 11 | 12 | func RangeSearch[T point.P](n node.N[T], q hyperrectangle.R, f func(p T) bool) []T { 13 | if n.Nil() { 14 | return nil 15 | } 16 | 17 | min := make([]float64, n.K()) 18 | max := make([]float64, n.K()) 19 | 20 | for i := vector.D(0); i < n.K(); i++ { 21 | min[i] = math.Inf(-1) 22 | max[i] = math.Inf(1) 23 | } 24 | 25 | return rangesearch(n, q, *hyperrectangle.New(vector.V(min), vector.V(max)), f) 26 | } 27 | 28 | func rangesearch[T point.P](n node.N[T], q hyperrectangle.R, bound hyperrectangle.R, f func(p T) bool) []T { 29 | if n.Nil() || hyperrectangle.Disjoint(q, bound) { 30 | return nil 31 | } 32 | 33 | var data []T 34 | for _, p := range n.Data() { 35 | if q.In(p.P()) && f(p) { 36 | data = append(data, p) 37 | } 38 | } 39 | 40 | if n.Leaf() { 41 | return data 42 | } 43 | 44 | l := make(chan []T) 45 | r := make(chan []T) 46 | 47 | go func(ch chan<- []T) { 48 | max := make([]float64, n.K()) 49 | copy(max, bound.Max()) 50 | max[n.Axis()] =
n.Pivot().X(n.Axis()) 51 | 52 | bound := *hyperrectangle.New(bound.Min(), max) 53 | ch <- rangesearch(n.L(), q, bound, f) 54 | close(ch) 55 | }(l) 56 | go func(ch chan<- []T) { 57 | min := make([]float64, n.K()) 58 | copy(min, bound.Min()) 59 | min[n.Axis()] = n.Pivot().X(n.Axis()) 60 | 61 | bound := *hyperrectangle.New(min, bound.Max()) 62 | ch <- rangesearch(n.R(), q, bound, f) 63 | close(ch) 64 | }(r) 65 | 66 | data = append(data, <-l...) 67 | data = append(data, <-r...) 68 | 69 | return data 70 | 71 | } 72 | -------------------------------------------------------------------------------- /internal/rangesearch/rangesearch_test.go: -------------------------------------------------------------------------------- 1 | package rangesearch 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/downflux/go-geometry/nd/hyperrectangle" 7 | "github.com/downflux/go-geometry/nd/vector" 8 | "github.com/downflux/go-kd/internal/node/tree" 9 | "github.com/downflux/go-kd/point/mock" 10 | "github.com/google/go-cmp/cmp" 11 | ) 12 | 13 | func TestRangeSearch(t *testing.T) { 14 | type config struct { 15 | name string 16 | data []*mock.P 17 | k vector.D 18 | n int 19 | q hyperrectangle.R 20 | want []*mock.P 21 | } 22 | 23 | configs := []config{ 24 | { 25 | name: "Nil", 26 | data: nil, 27 | k: 1, 28 | n: 1, 29 | q: *hyperrectangle.New(mock.U(1), mock.U(2)), 30 | want: nil, 31 | }, 32 | { 33 | name: "Simple", 34 | data: []*mock.P{ 35 | &mock.P{X: mock.U(1.5)}, 36 | }, 37 | k: 1, 38 | n: 1, 39 | q: *hyperrectangle.New(mock.U(1), mock.U(2)), 40 | want: []*mock.P{ 41 | &mock.P{X: mock.U(1.5)}, 42 | }, 43 | }, 44 | { 45 | name: "LR", 46 | data: []*mock.P{ 47 | &mock.P{X: mock.U(1.5)}, 48 | &mock.P{X: mock.U(1)}, 49 | &mock.P{X: mock.U(2)}, 50 | }, 51 | k: 1, 52 | n: 1, 53 | q: *hyperrectangle.New(mock.U(1), mock.U(2)), 54 | want: []*mock.P{ 55 | &mock.P{X: mock.U(1.5)}, 56 | &mock.P{X: mock.U(1)}, 57 | &mock.P{X: mock.U(2)}, 58 | }, 59 | }, 60 | { 61 | name: "Partial", 62 | data: []*mock.P{ 63 
| &mock.P{X: mock.U(1.5)}, 64 | &mock.P{X: mock.U(1)}, 65 | &mock.P{X: mock.U(2)}, 66 | }, 67 | k: 1, 68 | n: 1, 69 | q: *hyperrectangle.New(mock.U(1.9), mock.U(2.1)), 70 | want: []*mock.P{ 71 | &mock.P{X: mock.U(2)}, 72 | }, 73 | }, 74 | } 75 | 76 | for _, c := range configs { 77 | t.Run(c.name, func(t *testing.T) { 78 | got := RangeSearch[*mock.P]( 79 | tree.New[*mock.P](tree.O[*mock.P]{ 80 | Data: c.data, 81 | K: c.k, 82 | N: c.n, 83 | Axis: vector.AXIS_X, 84 | }), 85 | c.q, 86 | func(*mock.P) bool { return true }, 87 | ) 88 | if diff := cmp.Diff(c.want, got); diff != "" { 89 | t.Errorf("RangeSearch() mismatch (-want +got):\n%v", diff) 90 | } 91 | }) 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /kd/kd.go: -------------------------------------------------------------------------------- 1 | // Package kd implements a k-D tree with arbitrary data packing and duplicate 2 | // data coordinate support. 3 | // 4 | // k-D trees are generally a cacheing layer representation of the local state -- 5 | // we do not expect to be making frequent mutations to this tree once 6 | // constructed. 7 | // 8 | // Read operations on this k-D tree may be done in parallel. Mutations on the 9 | // k-D tree must be done serially. 10 | // 11 | // N.B.: Mutating the data point positions must be accompanied by mutating the 12 | // k-D tree. For large numbers of points, and for a large number of queries, the 13 | // time taken to build the tree will be offset by the speedup of subsequent 14 | // reads. 
15 | package kd 16 | 17 | import ( 18 | "github.com/downflux/go-geometry/nd/hyperrectangle" 19 | "github.com/downflux/go-geometry/nd/vector" 20 | "github.com/downflux/go-kd/filter" 21 | "github.com/downflux/go-kd/internal/knn" 22 | "github.com/downflux/go-kd/internal/node" 23 | "github.com/downflux/go-kd/internal/node/tree" 24 | "github.com/downflux/go-kd/internal/rangesearch" 25 | "github.com/downflux/go-kd/point" 26 | ) 27 | 28 | type O[T point.P] struct { 29 | Data []T 30 | K vector.D 31 | 32 | // N is the nominal leaf size of the k-D tree. Leaf nodes are checked 33 | // via bruteforce methods. 34 | // 35 | // Note that individual nodes (including non-leaf nodes) may contain 36 | // elements that exceed this size constraint after inserts and removes. 37 | // 38 | // Leaf size will significantly impact performance -- users should 39 | // tailor this value to their specific use-case. We recommend setting 40 | // this value to 16 and up as the size of the data set increases. 41 | N int 42 | } 43 | 44 | type KD[T point.P] struct { 45 | k vector.D 46 | n int 47 | 48 | root node.N[T] 49 | } 50 | 51 | func New[T point.P](o O[T]) *KD[T] { 52 | data := make([]T, len(o.Data)) 53 | if l := copy(data, o.Data); l != len(o.Data) { 54 | panic("could not copy data into k-D tree") 55 | } 56 | if o.K < 1 { 57 | panic("k-D tree must contain points with non-zero length vectors") 58 | } 59 | if o.N < 1 { 60 | panic("k-D tree minimum leaf node size must be positive") 61 | } 62 | 63 | t := &KD[T]{ 64 | k: o.K, 65 | n: o.N, 66 | root: tree.New[T](tree.O[T]{ 67 | Data: data, 68 | Axis: vector.AXIS_X, 69 | K: o.K, 70 | N: o.N, 71 | }), 72 | } 73 | 74 | return t 75 | } 76 | 77 | // Balance reconstructs the k-D tree. 78 | // 79 | // This k-D tree implementation does not support concurrent mutations. 
80 | func (t *KD[T]) Balance() { 81 | t.root = tree.New[T](tree.O[T]{ 82 | Data: Data(t), 83 | Axis: vector.AXIS_X, 84 | K: t.k, 85 | N: t.n, 86 | }) 87 | } 88 | 89 | // Insert adds a new point into the k-D tree. 90 | // 91 | // Insert is not a balanced operation -- after many mutations, the tree should 92 | // be explicitly reconstructed. 93 | // 94 | // This k-D tree implementation does not support concurrent mutations. 95 | func (t *KD[T]) Insert(p T) { t.root.Insert(p) } 96 | 97 | // Remove pops a point from the k-D tree which lies at the input vector v and 98 | // matches the filter. Note that if multiple points match both the location 99 | // vector and the filter, an arbitrary one will be removed. This function will 100 | // pop at most one element from the k-D tree. 101 | // 102 | // Remove is not a balanced operation -- after many mutations, the tree should 103 | // be explicitly reconstructed. 104 | // 105 | // This k-D tree implementation does not support concurrent mutations. 106 | // 107 | // If there is no matching point, the returned bool will be false. 108 | func (t *KD[T]) Remove(v vector.V, f filter.F[T]) (T, bool) { return t.root.Remove(v, f) } 109 | 110 | // KNN returns the k nearest neighbors to the input vector p and matches the 111 | // filter function. 112 | // 113 | // This k-D tree implementation supports concurrent read operations. 114 | func KNN[T point.P](t *KD[T], p vector.V, k int, f filter.F[T]) []T { 115 | return knn.KNN(t.root, p, k, f) 116 | } 117 | 118 | // RangeSearch returns all points which are found in the given bounds and 119 | // matches the filter function. 120 | // 121 | // This k-D tree implementation supports concurrent read operations. 122 | func RangeSearch[T point.P](t *KD[T], q hyperrectangle.R, f filter.F[T]) []T { 123 | return rangesearch.RangeSearch(t.root, q, f) 124 | } 125 | 126 | // Data returns all points in the k-D tree. 127 | // 128 | // This k-D tree implementation supports concurrent read operations. 
129 | func Data[T point.P](t *KD[T]) []T { 130 | if t.root.Nil() { 131 | return nil 132 | } 133 | var data []T 134 | 135 | var n node.N[T] 136 | open := []node.N[T]{t.root} 137 | for len(open) > 0 { 138 | n, open = open[0], open[1:] 139 | 140 | data = append(data, n.Data()...) 141 | if !n.L().Nil() { 142 | open = append(open, n.L()) 143 | } 144 | if !n.R().Nil() { 145 | open = append(open, n.R()) 146 | } 147 | } 148 | 149 | return data 150 | } 151 | -------------------------------------------------------------------------------- /kd/kd_test.go: -------------------------------------------------------------------------------- 1 | package kd 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/downflux/go-geometry/nd/hyperrectangle" 8 | "github.com/downflux/go-kd/container/bruteforce" 9 | "github.com/downflux/go-kd/internal/node/util" 10 | "github.com/downflux/go-kd/point/mock" 11 | "github.com/downflux/go-kd/vector" 12 | "github.com/google/go-cmp/cmp" 13 | "github.com/google/go-cmp/cmp/cmpopts" 14 | 15 | vnd "github.com/downflux/go-geometry/nd/vector" 16 | putil "github.com/downflux/go-kd/internal/perf/util" 17 | ) 18 | 19 | func TestNew(t *testing.T) { 20 | type config struct { 21 | name string 22 | k vnd.D 23 | n int 24 | 25 | size int 26 | } 27 | 28 | var configs []config 29 | for _, k := range putil.PerfTestSize(putil.SizeUnit).K() { 30 | for _, n := range putil.PerfTestSize(putil.SizeUnit).N() { 31 | for _, size := range putil.PerfTestSize(putil.SizeUnit).LeafSize() { 32 | configs = append(configs, config{ 33 | name: fmt.Sprintf("K=%v/N=%v/LeafSize=%v", k, n, size), 34 | k: k, 35 | n: n, 36 | size: size, 37 | }) 38 | } 39 | } 40 | } 41 | 42 | for _, c := range configs { 43 | ps := putil.Generate(c.n, c.k) 44 | t.Run(c.name, func(t *testing.T) { 45 | tree := New[*mock.P](O[*mock.P]{ 46 | Data: ps, 47 | K: c.k, 48 | N: c.size, 49 | }) 50 | if !util.Validate(tree.root) { 51 | t.Errorf("validate() = %v, want = %v", false, true) 52 | } 53 | }) 54 | } 55 | } 
56 | 57 | func TestData(t *testing.T) { 58 | type config struct { 59 | name string 60 | data []*mock.P 61 | k vnd.D 62 | want []*mock.P 63 | } 64 | 65 | configs := []config{ 66 | { 67 | name: "Nil", 68 | data: nil, 69 | want: nil, 70 | k: 1, 71 | }, 72 | { 73 | name: "Simple", 74 | data: []*mock.P{ 75 | &mock.P{X: mock.U(1)}, 76 | }, 77 | want: []*mock.P{ 78 | &mock.P{X: mock.U(1)}, 79 | }, 80 | k: 1, 81 | }, 82 | { 83 | name: "LR", 84 | data: []*mock.P{ 85 | &mock.P{X: mock.U(1)}, 86 | &mock.P{X: mock.U(0)}, 87 | &mock.P{X: mock.U(2)}, 88 | }, 89 | want: []*mock.P{ 90 | &mock.P{X: mock.U(1)}, 91 | &mock.P{X: mock.U(0)}, 92 | &mock.P{X: mock.U(2)}, 93 | }, 94 | k: 1, 95 | }, 96 | } 97 | 98 | for _, c := range configs { 99 | t.Run(c.name, func(t *testing.T) { 100 | kd := New(O[*mock.P]{ 101 | Data: c.data, 102 | K: c.k, 103 | N: 1, 104 | }) 105 | got := Data(kd) 106 | if diff := cmp.Diff(c.want, got, cmpopts.SortSlices( 107 | func(p, q *mock.P) bool { 108 | return vector.Less(p.P(), q.P()) 109 | })); diff != "" { 110 | t.Errorf("KNN mismatch (-want +got):\n%v", diff) 111 | } 112 | }) 113 | } 114 | } 115 | func TestKNN(t *testing.T) { 116 | type config struct { 117 | name string 118 | k vnd.D 119 | n int 120 | size int 121 | 122 | knn int 123 | } 124 | 125 | var configs []config 126 | for _, k := range putil.PerfTestSize(putil.SizeUnit).K() { 127 | for _, n := range putil.PerfTestSize(putil.SizeUnit).N() { 128 | for _, size := range putil.PerfTestSize(putil.SizeUnit).LeafSize() { 129 | for _, f := range putil.PerfTestSize(putil.SizeUnit).F() { 130 | configs = append(configs, config{ 131 | name: fmt.Sprintf("K=%v/N=%v/LeafSize=%v/KNN=%v", k, n, size, f), 132 | k: k, 133 | n: n, 134 | knn: int(float64(n) * f), 135 | size: size, 136 | }) 137 | } 138 | } 139 | } 140 | } 141 | 142 | for _, c := range configs { 143 | ps := putil.Generate(c.n, c.k) 144 | t.Run(c.name, func(t *testing.T) { 145 | p := vnd.V(make([]float64, c.k)) 146 | 147 | got := KNN( 148 | 
New[*mock.P](O[*mock.P]{ 149 | Data: ps, 150 | K: c.k, 151 | N: c.size, 152 | }), 153 | p, 154 | c.knn, 155 | putil.TrivialFilter, 156 | ) 157 | want := bruteforce.New[*mock.P](ps).KNN(p, c.knn, putil.TrivialFilter) 158 | if diff := cmp.Diff(want, got); diff != "" { 159 | t.Errorf("KNN mismatch (-want +got):\n%v", diff) 160 | } 161 | }) 162 | } 163 | } 164 | 165 | func TestRangeSearch(t *testing.T) { 166 | type config struct { 167 | name string 168 | k vnd.D 169 | n int 170 | size int 171 | q hyperrectangle.R 172 | } 173 | 174 | var configs []config 175 | for _, k := range putil.PerfTestSize(putil.SizeUnit).K() { 176 | for _, n := range putil.PerfTestSize(putil.SizeUnit).N() { 177 | for _, size := range putil.PerfTestSize(putil.SizeUnit).LeafSize() { 178 | for _, f := range putil.PerfTestSize(putil.SizeUnit).F() { 179 | configs = append(configs, config{ 180 | name: fmt.Sprintf("K=%v/N=%v/LeafSize=%v/Coverage=%v", k, n, size, f), 181 | k: k, 182 | n: n, 183 | size: size, 184 | q: putil.RH(k, f), 185 | }) 186 | } 187 | } 188 | } 189 | } 190 | 191 | for _, c := range configs { 192 | ps := putil.Generate(c.n, c.k) 193 | t.Run(c.name, func(t *testing.T) { 194 | got := RangeSearch( 195 | New[*mock.P](O[*mock.P]{ 196 | Data: ps, 197 | K: c.k, 198 | N: c.size, 199 | }), 200 | c.q, 201 | putil.TrivialFilter, 202 | ) 203 | want := bruteforce.New[*mock.P](ps).RangeSearch(c.q, putil.TrivialFilter) 204 | 205 | if diff := cmp.Diff(want, got, putil.Transformer(vnd.V(make([]float64, c.k)))); diff != "" { 206 | t.Errorf("RangeSearch mismatch (-want +got):\n%v", diff) 207 | } 208 | }) 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /point/mock/mock.go: -------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import ( 4 | "github.com/downflux/go-geometry/nd/vector" 5 | "github.com/downflux/go-kd/point" 6 | ) 7 | 8 | var _ point.P = P{} 9 | 10 | type P struct { 11 | X vector.V 12 | 
Data string 13 | } 14 | 15 | func (p P) P() vector.V { return p.X } 16 | 17 | func Equal(a P, b P) bool { 18 | return a.Data == b.Data && vector.Within(a.P(), b.P()) 19 | } 20 | 21 | func U(x float64) vector.V { return vector.V([]float64{x}) } 22 | func V(v []float64) vector.V { return vector.V(v) } 23 | -------------------------------------------------------------------------------- /point/point.go: -------------------------------------------------------------------------------- 1 | package point 2 | 3 | import ( 4 | "github.com/downflux/go-geometry/nd/vector" 5 | ) 6 | 7 | type P interface { 8 | P() vector.V 9 | } 10 | -------------------------------------------------------------------------------- /vector/vector.go: -------------------------------------------------------------------------------- 1 | package vector 2 | 3 | import ( 4 | "github.com/downflux/go-geometry/nd/vector" 5 | ) 6 | 7 | type Comparator vector.D 8 | 9 | func (c Comparator) Less(v vector.V, u vector.V) bool { return v.X(vector.D(c)) < u.X(vector.D(c)) } 10 | 11 | // Less reports whether v sorts before u in lexicographical order: coordinates 12 | // are compared from axis 0 upward and the first differing coordinate decides. 13 | // Equal vectors are not less than each other. Less panics if the two vectors 14 | // have different dimensions. 15 | func Less(v vector.V, u vector.V) bool { 16 | if v.Dimension() != u.Dimension() { 17 | panic("mismatching vector dimensions") 18 | } 19 | for i := vector.D(0); i < v.Dimension(); i++ { 20 | c := Comparator(i) 21 | if c.Less(v, u) { 22 | return true 23 | } 24 | if c.Less(u, v) { 25 | return false 26 | } 27 | } 28 | return false 29 | } 30 | -------------------------------------------------------------------------------- /x/README.md: -------------------------------------------------------------------------------- 1 | # go-kd/x 2 | 3 | Modules here are experimental and not yet ready for release. Features may be 4 | broken.
5 | -------------------------------------------------------------------------------- /x/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/downflux/go-kd/x 2 | 3 | go 1.18 4 | -------------------------------------------------------------------------------- /x/go.sum: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/downflux/go-kd/1a161faebc58c0f2f1888820dd4d0e4501862e79/x/go.sum --------------------------------------------------------------------------------