├── .github └── workflows │ └── go.yml ├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── filter.go ├── geom.go ├── geom_test.go ├── go.mod ├── rtree.go └── rtree_test.go /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | 11 | build: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v2 15 | 16 | - name: Set up Go 17 | uses: actions/setup-go@v2 18 | with: 19 | go-version: 1.17 20 | 21 | - name: Build 22 | run: go build -v ./... 23 | 24 | - name: Test 25 | run: go test -v ./... 26 | 27 | - name: Run golangci-lint 28 | uses: golangci/golangci-lint-action@v3.7.0 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Daniel Connelly 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. 
Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | rtreego 2 | ======= 3 | 4 | A library for efficiently storing and querying spatial data 5 | in the Go programming language. 6 | 7 | [![CI](https://github.com/dhconnelly/rtreego/actions/workflows/go.yml/badge.svg)](https://github.com/dhconnelly/rtreego/actions/workflows/go.yml) 8 | [![Go Report Card](https://goreportcard.com/badge/github.com/dhconnelly/rtreego)](https://goreportcard.com/report/github.com/dhconnelly/rtreego) 9 | [![GoDoc](https://godoc.org/github.com/dhconnelly/rtreego?status.svg)](https://godoc.org/github.com/dhconnelly/rtreego) 10 | 11 | About 12 | ----- 13 | 14 | The R-tree is a popular data structure for efficiently storing and 15 | querying spatial objects; one common use is implementing geospatial 16 | indexes in database management systems. 
Both bounding-box queries 17 | and k-nearest-neighbor queries are supported. 18 | 19 | R-trees are balanced, so maximum tree height is guaranteed to be 20 | logarithmic in the number of entries; however, good worst-case 21 | performance is not guaranteed. Instead, a number of rebalancing 22 | heuristics are applied that perform well in practice. For more 23 | details please refer to the references. 24 | 25 | This implementation handles the general N-dimensional case; for a more 26 | efficient implementation for the 3-dimensional case, see [Patrick 27 | Higgins' fork](https://github.com/patrick-higgins/rtreego). 28 | 29 | Getting Started 30 | --------------- 31 | 32 | Get the source code from [GitHub](https://github.com/dhconnelly/rtreego) or, 33 | with Go 1 installed, run `go get github.com/dhconnelly/rtreego`. 34 | 35 | Make sure you `import "github.com/dhconnelly/rtreego"` in your Go source files. 36 | 37 | Documentation 38 | ------------- 39 | 40 | ### Storing, updating, and deleting objects 41 | 42 | To create a new tree, specify the number of spatial dimensions and the minimum 43 | and maximum branching factors: 44 | ```Go 45 | rt := rtreego.NewTree(2, 25, 50) 46 | ``` 47 | You can also bulk-load the tree when creating it by passing the objects as 48 | a parameter. 49 | ```Go 50 | rt := rtreego.NewTree(2, 25, 50, objects...) 51 | ``` 52 | Any type that implements the `Spatial` interface can be stored in the tree: 53 | ```Go 54 | type Spatial interface { 55 | Bounds() Rect 56 | } 57 | ``` 58 | `Rect`s are data structures for representing spatial objects, while `Point`s 59 | represent spatial locations.
Creating `Point`s is easy--they're just slices 60 | of `float64`s: 61 | ```Go 62 | p1 := rtreego.Point{0.4, 0.5} 63 | p2 := rtreego.Point{6.2, -3.4} 64 | ``` 65 | To create a `Rect`, specify a location and the lengths of the sides: 66 | ```Go 67 | r1, _ := rtreego.NewRect(p1, []float64{1, 2}) 68 | r2, _ := rtreego.NewRect(p2, []float64{1.7, 2.7}) 69 | ``` 70 | To demonstrate, let's create and store some test data. 71 | ```Go 72 | type Thing struct { 73 | where rtreego.Rect 74 | name string 75 | } 76 | 77 | func (t *Thing) Bounds() rtreego.Rect { 78 | return t.where 79 | } 80 | 81 | rt.Insert(&Thing{r1, "foo"}) 82 | rt.Insert(&Thing{r2, "bar"}) 83 | 84 | size := rt.Size() // returns 2 85 | ``` 86 | We can insert and delete objects from the tree in any order. 87 | ```Go 88 | rt.Delete(thing2) 89 | // do some stuff... 90 | rt.Insert(anotherThing) 91 | ``` 92 | Note that the `Delete` function does the equality comparison by comparing the 93 | memory addresses of the objects. If you do not have a pointer to the original 94 | object anymore, you can define a custom comparator. 95 | ```Go 96 | type Comparator func(obj1, obj2 Spatial) (equal bool) 97 | ``` 98 | You can use a custom comparator with the `DeleteWithComparator` function.
99 | ```Go 100 | cmp := func(obj1, obj2 Spatial) bool { 101 | sp1 := obj1.(*IDRect) 102 | sp2 := obj2.(*IDRect) 103 | 104 | return sp1.ID == sp2.ID 105 | } 106 | 107 | rt.DeleteWithComparator(obj, cmp) 108 | ``` 109 | If you want to store points instead of rectangles, you can easily convert a 110 | point into a rectangle using the `ToRect` method: 111 | ```Go 112 | var tol = 0.01 113 | 114 | type Somewhere struct { 115 | location rtreego.Point 116 | name string 117 | wormhole chan int 118 | } 119 | 120 | func (s *Somewhere) Bounds() rtreego.Rect { 121 | // define the bounds of s to be a rectangle centered at s.location 122 | // with side lengths 2 * tol: 123 | return s.location.ToRect(tol) 124 | } 125 | 126 | rt.Insert(&Somewhere{rtreego.Point{0, 0}, "Someplace", nil}) 127 | ``` 128 | If you want to update the location of an object, you must delete it, update it, 129 | and re-insert. Just modifying the object so that the `Rect` returned by 130 | `Bounds()` changes, without deleting and re-inserting the object, will 131 | corrupt the tree. 132 | 133 | ### Queries 134 | 135 | Bounding-box and k-nearest-neighbors queries are supported. 136 | 137 | Bounding-box queries require a search `Rect`. This function will return all 138 | objects that have a non-zero intersection volume with the input search rectangle. 139 | ```Go 140 | bb, _ := rtreego.NewRect(rtreego.Point{1.7, -3.4}, []float64{3.2, 1.9}) 141 | 142 | // Get a slice of the objects in rt that intersect bb: 143 | results := rt.SearchIntersect(bb) 144 | ``` 145 | ### Filters 146 | 147 | You can filter out values during searches by implementing Filter functions. 148 | ```Go 149 | type Filter func(results []Spatial, object Spatial) (refuse, abort bool) 150 | ``` 151 | A filter for limiting results by result count is included in the package for 152 | backwards compatibility.
153 | ```Go 154 | // maximum of three results will be returned 155 | rt.SearchIntersect(bb, rtreego.LimitFilter(3)) 156 | ``` 157 | Nearest-neighbor queries find the objects in a tree closest to a specified 158 | query point. 159 | ```Go 160 | q := rtreego.Point{6.5, -2.47} 161 | k := 5 162 | 163 | // Get a slice of the k objects in rt closest to q: 164 | results = rt.NearestNeighbors(k, q) 165 | ``` 166 | ### More information 167 | 168 | See [GoDoc](http://godoc.org/github.com/dhconnelly/rtreego) for full API 169 | documentation. 170 | 171 | References 172 | ---------- 173 | 174 | - A. Guttman. R-trees: A Dynamic Index Structure for Spatial Searching. 175 | Proceedings of ACM SIGMOD, pages 47-57, 1984. 176 | http://www.cs.jhu.edu/~misha/ReadingSeminar/Papers/Guttman84.pdf 177 | 178 | - N. Beckmann, H.-P. Kriegel, R. Schneider and B. Seeger. The R*-tree: An 179 | Efficient and Robust Access Method for Points and Rectangles. Proceedings 180 | of ACM SIGMOD, pages 322-331, May 1990. 181 | http://infolab.usc.edu/csci587/Fall2011/papers/p322-beckmann.pdf 182 | 183 | - N. Roussopoulos, S. Kelley and F. Vincent. Nearest Neighbor Queries. ACM 184 | SIGMOD, pages 71-79, 1995. 185 | http://www.postgis.org/support/nearestneighbor.pdf 186 | 187 | Author 188 | ------ 189 | 190 | Written by [Daniel Connelly](http://dhconnelly.com) (). 191 | 192 | License 193 | ------- 194 | 195 | rtreego is released under a BSD-style license, described in the `LICENSE` 196 | file. 197 | -------------------------------------------------------------------------------- /filter.go: -------------------------------------------------------------------------------- 1 | package rtreego 2 | 3 | // Filter is a function type for filtering leaves during search. The parameters 4 | // should be treated as read-only. If refuse is true, the current entry will 5 | // not be added to the result set. If abort is true, the search is aborted and 6 | // the current result set will be returned.
7 | type Filter func(results []Spatial, object Spatial) (refuse, abort bool) 8 | 9 | // ApplyFilters applies the given filters and returns whether the entry is 10 | // refused and/or the search should be aborted. If a filter refuses an entry, 11 | // the following filters are not applied for the entry. If a filter aborts, the 12 | // search terminates without further applying any filter. 13 | func applyFilters(results []Spatial, object Spatial, filters []Filter) (bool, bool) { 14 | for _, filter := range filters { 15 | refuse, abort := filter(results, object) 16 | if refuse || abort { 17 | return refuse, abort 18 | } 19 | } 20 | return false, false 21 | } 22 | 23 | // LimitFilter checks if the results have reached the limit size and aborts if so. 24 | func LimitFilter(limit int) Filter { 25 | return func(results []Spatial, object Spatial) (refuse, abort bool) { 26 | if len(results) >= limit { 27 | return true, true 28 | } 29 | 30 | return false, false 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /geom.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012 Daniel Connelly. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package rtreego 6 | 7 | import ( 8 | "fmt" 9 | "math" 10 | "strings" 11 | ) 12 | 13 | // DimError represents a failure due to mismatched dimensions. 14 | type DimError struct { 15 | Expected int 16 | Actual int 17 | } 18 | 19 | func (err DimError) Error() string { 20 | return "rtreego: dimension mismatch" 21 | } 22 | 23 | // DistError is an improper distance measurement. It implements the error 24 | // and is generated when a distance-related assertion fails. 25 | type DistError float64 26 | 27 | func (err DistError) Error() string { 28 | return "rtreego: improper distance" 29 | } 30 | 31 | // Point represents a point in n-dimensional Euclidean space. 
32 | type Point []float64 33 | 34 | func (p Point) Copy() Point { 35 | result := make(Point, len(p)) 36 | copy(result, p) 37 | return result 38 | } 39 | 40 | // Dist computes the Euclidean distance between two points p and q. 41 | func (p Point) dist(q Point) float64 { 42 | if len(p) != len(q) { 43 | panic(DimError{len(p), len(q)}) 44 | } 45 | sum := 0.0 46 | for i := range p { 47 | dx := p[i] - q[i] 48 | sum += dx * dx 49 | } 50 | return math.Sqrt(sum) 51 | } 52 | 53 | // minDist computes the square of the distance from a point to a rectangle. 54 | // If the point is contained in the rectangle then the distance is zero. 55 | // 56 | // Implemented per Definition 2 of "Nearest Neighbor Queries" by 57 | // N. Roussopoulos, S. Kelley and F. Vincent, ACM SIGMOD, pages 71-79, 1995. 58 | func (p Point) minDist(r Rect) float64 { 59 | if len(p) != len(r.p) { 60 | panic(DimError{len(p), len(r.p)}) 61 | } 62 | 63 | sum := 0.0 64 | for i, pi := range p { 65 | if pi < r.p[i] { 66 | d := pi - r.p[i] 67 | sum += d * d 68 | } else if pi > r.q[i] { 69 | d := pi - r.q[i] 70 | sum += d * d 71 | } else { 72 | sum += 0 73 | } 74 | } 75 | return sum 76 | } 77 | 78 | // minMaxDist computes the minimum of the maximum distances from p to points 79 | // on r. If r is the bounding box of some geometric objects, then there is 80 | // at least one object contained in r within minMaxDist(p, r) of p. 81 | // 82 | // Implemented per Definition 4 of "Nearest Neighbor Queries" by 83 | // N. Roussopoulos, S. Kelley and F. Vincent, ACM SIGMOD, pages 71-79, 1995. 
84 | func (p Point) minMaxDist(r Rect) float64 { 85 | if len(p) != len(r.p) { 86 | panic(DimError{len(p), len(r.p)}) 87 | } 88 | 89 | // by definition, MinMaxDist(p, r) = 90 | // min{1<=k<=n}(|pk - rmk|^2 + sum{1<=i<=n, i != k}(|pi - rMi|^2)) 91 | // where rmk and rMk are defined as follows: 92 | 93 | rm := func(k int) float64 { 94 | if p[k] <= (r.p[k]+r.q[k])/2 { 95 | return r.p[k] 96 | } 97 | return r.q[k] 98 | } 99 | 100 | rM := func(k int) float64 { 101 | if p[k] >= (r.p[k]+r.q[k])/2 { 102 | return r.p[k] 103 | } 104 | return r.q[k] 105 | } 106 | 107 | // This formula can be computed in linear time by precomputing 108 | // S = sum{1<=i<=n}(|pi - rMi|^2). 109 | 110 | S := 0.0 111 | for i := range p { 112 | d := p[i] - rM(i) 113 | S += d * d 114 | } 115 | 116 | // Compute MinMaxDist using the precomputed S. 117 | min := math.MaxFloat64 118 | for k := range p { 119 | d1 := p[k] - rM(k) 120 | d2 := p[k] - rm(k) 121 | d := S - d1*d1 + d2*d2 122 | if d < min { 123 | min = d 124 | } 125 | } 126 | 127 | return min 128 | } 129 | 130 | // Rect represents a subset of n-dimensional Euclidean space of the form 131 | // [a1, b1] x [a2, b2] x ... x [an, bn], where ai < bi for all 1 <= i <= n. 132 | type Rect struct { 133 | p, q Point // Enforced by NewRect: p[i] <= q[i] for all i. 
134 | } 135 | 136 | // PointCoord returns the coordinate of the point of the rectangle at i 137 | func (r Rect) PointCoord(i int) float64 { 138 | return r.p[i] 139 | } 140 | 141 | // LengthsCoord returns the coordinate of the lengths of the rectangle at i 142 | func (r Rect) LengthsCoord(i int) float64 { 143 | return r.q[i] - r.p[i] 144 | } 145 | 146 | // Equal returns true if the two rectangles are equal 147 | func (r Rect) Equal(other Rect) bool { 148 | for i, e := range r.p { 149 | if e != other.p[i] { 150 | return false 151 | } 152 | } 153 | for i, e := range r.q { 154 | if e != other.q[i] { 155 | return false 156 | } 157 | } 158 | return true 159 | } 160 | 161 | func (r Rect) String() string { 162 | s := make([]string, len(r.p)) 163 | for i, a := range r.p { 164 | b := r.q[i] 165 | s[i] = fmt.Sprintf("[%.2f, %.2f]", a, b) 166 | } 167 | return strings.Join(s, "x") 168 | } 169 | 170 | // NewRect constructs and returns a pointer to a Rect given a corner point and 171 | // the lengths of each dimension. The point p should be the most-negative point 172 | // on the rectangle (in every dimension) and every length should be positive. 173 | func NewRect(p Point, lengths []float64) (r Rect, err error) { 174 | r.p = p 175 | if len(p) != len(lengths) { 176 | err = &DimError{len(p), len(lengths)} 177 | return 178 | } 179 | r.q = make([]float64, len(p)) 180 | for i := range p { 181 | if lengths[i] <= 0 { 182 | err = DistError(lengths[i]) 183 | return 184 | } 185 | r.q[i] = p[i] + lengths[i] 186 | } 187 | return 188 | } 189 | 190 | // NewRectFromPoints constructs and returns a pointer to a Rect given a corner points. 
191 | func NewRectFromPoints(minPoint, maxPoint Point) (r Rect, err error) { 192 | if len(minPoint) != len(maxPoint) { 193 | err = &DimError{len(minPoint), len(maxPoint)} 194 | return 195 | } 196 | 197 | // check that min and max point coordinates require swapping 198 | copied := false 199 | for i, p := range minPoint { 200 | if minPoint[i] > maxPoint[i] { 201 | if !copied { 202 | minPoint = minPoint.Copy() 203 | maxPoint = maxPoint.Copy() 204 | copied = true 205 | } 206 | minPoint[i] = maxPoint[i] 207 | maxPoint[i] = p 208 | } 209 | } 210 | 211 | r = Rect{p: minPoint, q: maxPoint} 212 | return 213 | } 214 | 215 | // Size computes the measure of a rectangle (the product of its side lengths). 216 | func (r Rect) Size() float64 { 217 | size := 1.0 218 | for i, a := range r.p { 219 | b := r.q[i] 220 | size *= b - a 221 | } 222 | return size 223 | } 224 | 225 | // margin computes the sum of the edge lengths of a rectangle. 226 | func (r Rect) margin() float64 { 227 | // The number of edges in an n-dimensional rectangle is n * 2^(n-1) 228 | // (http://en.wikipedia.org/wiki/Hypercube_graph). Thus the number 229 | // of edges of length (ai - bi), where the rectangle is determined 230 | // by p = (a1, a2, ..., an) and q = (b1, b2, ..., bn), is 2^(n-1). 231 | // 232 | // The margin of the rectangle, then, is given by the formula 233 | // 2^(n-1) * [(b1 - a1) + (b2 - a2) + ... + (bn - an)]. 234 | dim := len(r.p) 235 | sum := 0.0 236 | for i, a := range r.p { 237 | b := r.q[i] 238 | sum += b - a 239 | } 240 | return math.Pow(2, float64(dim-1)) * sum 241 | } 242 | 243 | // containsPoint tests whether p is located inside or on the boundary of r. 244 | func (r Rect) containsPoint(p Point) bool { 245 | if len(p) != len(r.p) { 246 | panic(DimError{len(r.p), len(p)}) 247 | } 248 | 249 | for i, a := range p { 250 | // p is contained in (or on) r if and only if p <= a <= q for 251 | // every dimension. 
252 | if a < r.p[i] || a > r.q[i] { 253 | return false 254 | } 255 | } 256 | 257 | return true 258 | } 259 | 260 | // containsRect tests whether r2 is is located inside r1. 261 | func (r Rect) containsRect(r2 Rect) bool { 262 | if len(r.p) != len(r2.p) { 263 | panic(DimError{len(r.p), len(r2.p)}) 264 | } 265 | 266 | for i, a1 := range r.p { 267 | b1, a2, b2 := r.q[i], r2.p[i], r2.q[i] 268 | // enforced by constructor: a1 <= b1 and a2 <= b2. 269 | // so containment holds if and only if a1 <= a2 <= b2 <= b1 270 | // for every dimension. 271 | if a1 > a2 || b2 > b1 { 272 | return false 273 | } 274 | } 275 | 276 | return true 277 | } 278 | 279 | // intersect computes the intersection of two rectangles. If no intersection 280 | // exists, the intersection is nil. 281 | func intersect(r1, r2 Rect) bool { 282 | dim := len(r1.p) 283 | if len(r2.p) != dim { 284 | panic(DimError{dim, len(r2.p)}) 285 | } 286 | 287 | // There are four cases of overlap: 288 | // 289 | // 1. a1------------b1 290 | // a2------------b2 291 | // p--------q 292 | // 293 | // 2. a1------------b1 294 | // a2------------b2 295 | // p--------q 296 | // 297 | // 3. a1-----------------b1 298 | // a2-------b2 299 | // p--------q 300 | // 301 | // 4. a1-------b1 302 | // a2-----------------b2 303 | // p--------q 304 | // 305 | // Thus there are only two cases of non-overlap: 306 | // 307 | // 1. a1------b1 308 | // a2------b2 309 | // 310 | // 2. a1------b1 311 | // a2------b2 312 | // 313 | // Enforced by constructor: a1 <= b1 and a2 <= b2. So we can just 314 | // check the endpoints. 315 | 316 | for i := range r1.p { 317 | a1, b1, a2, b2 := r1.p[i], r1.q[i], r2.p[i], r2.q[i] 318 | if b2 <= a1 || b1 <= a2 { 319 | return false 320 | } 321 | } 322 | return true 323 | } 324 | 325 | // ToRect constructs a rectangle containing p with side lengths 2*tol. 
326 | func (p Point) ToRect(tol float64) Rect { 327 | dim := len(p) 328 | a, b := make([]float64, dim), make([]float64, dim) 329 | for i := range p { 330 | a[i] = p[i] - tol 331 | b[i] = p[i] + tol 332 | } 333 | return Rect{a, b} 334 | } 335 | 336 | // boundingBox constructs the smallest rectangle containing both r1 and r2. 337 | func boundingBox(r1, r2 Rect) (bb Rect) { 338 | dim := len(r1.p) 339 | bb.p = make([]float64, dim) 340 | bb.q = make([]float64, dim) 341 | if len(r2.p) != dim { 342 | panic(DimError{dim, len(r2.p)}) 343 | } 344 | for i := 0; i < dim; i++ { 345 | if r1.p[i] <= r2.p[i] { 346 | bb.p[i] = r1.p[i] 347 | } else { 348 | bb.p[i] = r2.p[i] 349 | } 350 | if r1.q[i] <= r2.q[i] { 351 | bb.q[i] = r2.q[i] 352 | } else { 353 | bb.q[i] = r1.q[i] 354 | } 355 | } 356 | return 357 | } 358 | -------------------------------------------------------------------------------- /geom_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012 Daniel Connelly. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package rtreego 6 | 7 | import ( 8 | "math" 9 | "testing" 10 | ) 11 | 12 | const EPS = 0.000000001 13 | 14 | func TestDist(t *testing.T) { 15 | p := Point{1, 2, 3} 16 | q := Point{4, 5, 6} 17 | dist := math.Sqrt(27) 18 | if d := p.dist(q); d != dist { 19 | t.Errorf("dist(%v, %v) = %v; expected %v", p, q, d, dist) 20 | } 21 | } 22 | 23 | func TestNewRect(t *testing.T) { 24 | p := Point{1.0, -2.5, 3.0} 25 | q := Point{3.5, 5.5, 4.5} 26 | lengths := []float64{2.5, 8.0, 1.5} 27 | 28 | rect, err := NewRect(p, lengths) 29 | if err != nil { 30 | t.Errorf("Error on NewRect(%v, %v): %v", p, lengths, err) 31 | } 32 | if d := p.dist(rect.p); d > EPS { 33 | t.Errorf("Expected p == rect.p") 34 | } 35 | if d := q.dist(rect.q); d > EPS { 36 | t.Errorf("Expected q == rect.q") 37 | } 38 | } 39 | 40 | func TestNewRectFromPoints(t *testing.T) { 41 | p := Point{1.0, -2.5, 3.0} 42 | q := Point{3.5, 5.5, 4.5} 43 | 44 | rect, err := NewRectFromPoints(p, q) 45 | if err != nil { 46 | t.Errorf("Error on NewRect(%v, %v): %v", p, q, err) 47 | } 48 | if d := p.dist(rect.p); d > EPS { 49 | t.Errorf("Expected p == rect.p") 50 | } 51 | if d := q.dist(rect.q); d > EPS { 52 | t.Errorf("Expected q == rect.q") 53 | } 54 | } 55 | 56 | func TestNewRectFromPointsWithSwapPoints(t *testing.T) { 57 | p := Point{1.0, -2.5, 3.0} 58 | q := Point{3.5, 5.5, 4.5} 59 | 60 | rect, err := NewRectFromPoints(q, p) 61 | if err != nil { 62 | t.Errorf("Error on NewRect(%v, %v): %v", q, p, err) 63 | } 64 | 65 | if d := p.dist(rect.p); d > EPS { 66 | t.Errorf("Expected p == rect.") 67 | } 68 | if d := q.dist(rect.q); d > EPS { 69 | t.Errorf("Expected q == rect.q") 70 | } 71 | } 72 | 73 | func TestNewRectDimMismatch(t *testing.T) { 74 | p := Point{-7.0, 10.0} 75 | lengths := []float64{2.5, 8.0, 1.5} 76 | _, err := NewRect(p, lengths) 77 | if _, ok := err.(*DimError); !ok { 78 | t.Errorf("Expected DimError on NewRect(%v, %v)", p, lengths) 79 | } 80 | } 81 | 82 | func TestNewRectDistError(t *testing.T) { 83 | p := 
Point{1.0, -2.5, 3.0} 84 | lengths := []float64{2.5, -8.0, 1.5} 85 | _, err := NewRect(p, lengths) 86 | if _, ok := err.(DistError); !ok { 87 | t.Errorf("Expected distError on NewRect(%v, %v)", p, lengths) 88 | } 89 | } 90 | 91 | func TestRectPointCoord(t *testing.T) { 92 | p := Point{1.0, -2.5} 93 | lengths := []float64{2.5, 8.0} 94 | rect, _ := NewRect(p, lengths) 95 | 96 | f := rect.PointCoord(0) 97 | if f != 1.0 { 98 | t.Errorf("Expected %v.PointCoord(0) == 1.0, got %v", rect, f) 99 | } 100 | f = rect.PointCoord(1) 101 | if f != -2.5 { 102 | t.Errorf("Expected %v.PointCoord(1) == -2.5, got %v", rect, f) 103 | } 104 | } 105 | 106 | func TestRectLengthsCoord(t *testing.T) { 107 | p := Point{1.0, -2.5} 108 | lengths := []float64{2.5, 8.0} 109 | rect, _ := NewRect(p, lengths) 110 | 111 | f := rect.LengthsCoord(0) 112 | if f != 2.5 { 113 | t.Errorf("Expected %v.LengthsCoord(0) == 2.5, got %v", rect, f) 114 | } 115 | f = rect.LengthsCoord(1) 116 | if f != 8.0 { 117 | t.Errorf("Expected %v.LengthsCoord(1) == 8.0, got %v", rect, f) 118 | } 119 | } 120 | 121 | func TestRectEqual(t *testing.T) { 122 | p := Point{1.0, -2.5, 3.0} 123 | lengths := []float64{2.5, 8.0, 1.5} 124 | a, _ := NewRect(p, lengths) 125 | b, _ := NewRect(p, lengths) 126 | c, _ := NewRect(Point{0.0, -2.5, 3.0}, lengths) 127 | if !a.Equal(b) { 128 | t.Errorf("Expected %v.Equal(%v) to return true", a, b) 129 | } 130 | if a.Equal(c) { 131 | t.Errorf("Expected %v.Equal(%v) to return false", a, c) 132 | } 133 | } 134 | 135 | func TestRectSize(t *testing.T) { 136 | p := Point{1.0, -2.5, 3.0} 137 | lengths := []float64{2.5, 8.0, 1.5} 138 | rect, _ := NewRect(p, lengths) 139 | size := lengths[0] * lengths[1] * lengths[2] 140 | actual := rect.Size() 141 | if size != actual { 142 | t.Errorf("Expected %v.Size() == %v, got %v", rect, size, actual) 143 | } 144 | } 145 | 146 | func TestRectMargin(t *testing.T) { 147 | p := Point{1.0, -2.5, 3.0} 148 | lengths := []float64{2.5, 8.0, 1.5} 149 | rect, _ := NewRect(p, 
lengths) 150 | size := 4*2.5 + 4*8.0 + 4*1.5 151 | actual := rect.margin() 152 | if size != actual { 153 | t.Errorf("Expected %v.margin() == %v, got %v", rect, size, actual) 154 | } 155 | } 156 | 157 | func TestContainsPoint(t *testing.T) { 158 | p := Point{3.7, -2.4, 0.0} 159 | lengths := []float64{6.2, 1.1, 4.9} 160 | rect, _ := NewRect(p, lengths) 161 | 162 | q := Point{4.5, -1.7, 4.8} 163 | if yes := rect.containsPoint(q); !yes { 164 | t.Errorf("Expected %v contains %v", rect, q) 165 | } 166 | } 167 | 168 | func TestDoesNotContainPoint(t *testing.T) { 169 | p := Point{3.7, -2.4, 0.0} 170 | lengths := []float64{6.2, 1.1, 4.9} 171 | rect, _ := NewRect(p, lengths) 172 | 173 | q := Point{4.5, -1.7, -3.2} 174 | if yes := rect.containsPoint(q); yes { 175 | t.Errorf("Expected %v doesn't contain %v", rect, q) 176 | } 177 | } 178 | 179 | func TestContainsRect(t *testing.T) { 180 | p := Point{3.7, -2.4, 0.0} 181 | lengths1 := []float64{6.2, 1.1, 4.9} 182 | rect1, _ := NewRect(p, lengths1) 183 | 184 | q := Point{4.1, -1.9, 1.0} 185 | lengths2 := []float64{3.2, 0.6, 3.7} 186 | rect2, _ := NewRect(q, lengths2) 187 | 188 | if yes := rect1.containsRect(rect2); !yes { 189 | t.Errorf("Expected %v.containsRect(%v", rect1, rect2) 190 | } 191 | } 192 | 193 | func TestDoesNotContainRectOverlaps(t *testing.T) { 194 | p := Point{3.7, -2.4, 0.0} 195 | lengths1 := []float64{6.2, 1.1, 4.9} 196 | rect1, _ := NewRect(p, lengths1) 197 | 198 | q := Point{4.1, -1.9, 1.0} 199 | lengths2 := []float64{3.2, 1.4, 3.7} 200 | rect2, _ := NewRect(q, lengths2) 201 | 202 | if yes := rect1.containsRect(rect2); yes { 203 | t.Errorf("Expected %v doesn't contain %v", rect1, rect2) 204 | } 205 | } 206 | 207 | func TestDoesNotContainRectDisjoint(t *testing.T) { 208 | p := Point{3.7, -2.4, 0.0} 209 | lengths1 := []float64{6.2, 1.1, 4.9} 210 | rect1, _ := NewRect(p, lengths1) 211 | 212 | q := Point{1.2, -19.6, -4.0} 213 | lengths2 := []float64{2.2, 5.9, 0.5} 214 | rect2, _ := NewRect(q, lengths2) 215 | 216 | 
if yes := rect1.containsRect(rect2); yes { 217 | t.Errorf("Expected %v doesn't contain %v", rect1, rect2) 218 | } 219 | } 220 | 221 | func TestNoIntersection(t *testing.T) { 222 | p := Point{1, 2, 3} 223 | lengths1 := []float64{1, 1, 1} 224 | rect1, _ := NewRect(p, lengths1) 225 | 226 | q := Point{-1, -2, -3} 227 | lengths2 := []float64{2.5, 3, 6.5} 228 | rect2, _ := NewRect(q, lengths2) 229 | 230 | // rect1 and rect2 fail to overlap in just one dimension (second) 231 | 232 | if intersect(rect1, rect2) { 233 | t.Errorf("Expected intersect(%v, %v) == false", rect1, rect2) 234 | } 235 | } 236 | 237 | func TestNoIntersectionJustTouches(t *testing.T) { 238 | p := Point{1, 2, 3} 239 | lengths1 := []float64{1, 1, 1} 240 | rect1, _ := NewRect(p, lengths1) 241 | 242 | q := Point{-1, -2, -3} 243 | lengths2 := []float64{2.5, 4, 6.5} 244 | rect2, _ := NewRect(q, lengths2) 245 | 246 | // rect1 and rect2 fail to overlap in just one dimension (second) 247 | 248 | if intersect(rect1, rect2) { 249 | t.Errorf("Expected intersect(%v, %v) == false", rect1, rect2) 250 | } 251 | } 252 | 253 | func TestContainmentIntersection(t *testing.T) { 254 | p := Point{1, 2, 3} 255 | lengths1 := []float64{1, 1, 1} 256 | rect1, _ := NewRect(p, lengths1) 257 | 258 | q := Point{1, 2.2, 3.3} 259 | lengths2 := []float64{0.5, 0.5, 0.5} 260 | rect2, _ := NewRect(q, lengths2) 261 | 262 | r := Point{1, 2.2, 3.3} 263 | s := Point{1.5, 2.7, 3.8} 264 | 265 | if !intersect(rect1, rect2) { 266 | t.Errorf("intersect(%v, %v) != %v, %v", rect1, rect2, r, s) 267 | } 268 | } 269 | 270 | func TestOverlapIntersection(t *testing.T) { 271 | p := Point{1, 2, 3} 272 | lengths1 := []float64{1, 2.5, 1} 273 | rect1, _ := NewRect(p, lengths1) 274 | 275 | q := Point{1, 4, -3} 276 | lengths2 := []float64{3, 2, 6.5} 277 | rect2, _ := NewRect(q, lengths2) 278 | 279 | r := Point{1, 4, 3} 280 | s := Point{2, 4.5, 3.5} 281 | 282 | if !intersect(rect1, rect2) { 283 | t.Errorf("intersect(%v, %v) != %v, %v", rect1, rect2, r, s) 284 | } 
285 | } 286 | 287 | func TestToRect(t *testing.T) { 288 | x := Point{3.7, -2.4, 0.0} 289 | tol := 0.05 290 | rect := x.ToRect(tol) 291 | 292 | p := Point{3.65, -2.45, -0.05} 293 | q := Point{3.75, -2.35, 0.05} 294 | d1 := p.dist(rect.p) 295 | d2 := q.dist(rect.q) 296 | if d1 > EPS || d2 > EPS { 297 | t.Errorf("Expected %v.ToRect(%v) == %v, %v, got %v", x, tol, p, q, rect) 298 | } 299 | } 300 | 301 | func TestBoundingBox(t *testing.T) { 302 | p := Point{3.7, -2.4, 0.0} 303 | lengths1 := []float64{1, 15, 3} 304 | rect1, _ := NewRect(p, lengths1) 305 | 306 | q := Point{-6.5, 4.7, 2.5} 307 | lengths2 := []float64{4, 5, 6} 308 | rect2, _ := NewRect(q, lengths2) 309 | 310 | r := Point{-6.5, -2.4, 0.0} 311 | s := Point{4.7, 12.6, 8.5} 312 | 313 | bb := boundingBox(rect1, rect2) 314 | d1 := r.dist(bb.p) 315 | d2 := s.dist(bb.q) 316 | if d1 > EPS || d2 > EPS { 317 | t.Errorf("boundingBox(%v, %v) != %v, %v, got %v", rect1, rect2, r, s, bb) 318 | } 319 | } 320 | 321 | func TestBoundingBoxContains(t *testing.T) { 322 | p := Point{3.7, -2.4, 0.0} 323 | lengths1 := []float64{1, 15, 3} 324 | rect1, _ := NewRect(p, lengths1) 325 | 326 | q := Point{4.0, 0.0, 1.5} 327 | lengths2 := []float64{0.56, 6.222222, 0.946} 328 | rect2, _ := NewRect(q, lengths2) 329 | 330 | bb := boundingBox(rect1, rect2) 331 | d1 := rect1.p.dist(bb.p) 332 | d2 := rect1.q.dist(bb.q) 333 | if d1 > EPS || d2 > EPS { 334 | t.Errorf("boundingBox(%v, %v) != %v, got %v", rect1, rect2, rect1, bb) 335 | } 336 | } 337 | 338 | func TestMinDistZero(t *testing.T) { 339 | p := Point{1, 2, 3} 340 | r := p.ToRect(1) 341 | if d := p.minDist(r); d > EPS { 342 | t.Errorf("Expected %v.minDist(%v) == 0, got %v", p, r, d) 343 | } 344 | } 345 | 346 | func TestMinDistPositive(t *testing.T) { 347 | p := Point{1, 2, 3} 348 | r := Rect{Point{-1, -4, 7}, Point{2, -2, 9}} 349 | expected := float64((-2-2)*(-2-2) + (7-3)*(7-3)) 350 | if d := p.minDist(r); math.Abs(d-expected) > EPS { 351 | t.Errorf("Expected %v.minDist(%v) == %v, got %v", 
p, r, expected, d) 352 | } 353 | } 354 | 355 | func TestMinMaxdist(t *testing.T) { 356 | p := Point{-3, -2, -1} 357 | r := Rect{Point{0, 0, 0}, Point{1, 2, 3}} 358 | 359 | // furthest points from p on the faces closest to p in each dimension 360 | q1 := Point{0, 2, 3} 361 | q2 := Point{1, 0, 3} 362 | q3 := Point{1, 2, 0} 363 | 364 | // find the closest distance from p to one of these furthest points 365 | d1 := p.dist(q1) 366 | d2 := p.dist(q2) 367 | d3 := p.dist(q3) 368 | expected := math.Min(d1*d1, math.Min(d2*d2, d3*d3)) 369 | 370 | if d := p.minMaxDist(r); math.Abs(d-expected) > EPS { 371 | t.Errorf("Expected %v.minMaxDist(%v) == %v, got %v", p, r, expected, d) 372 | } 373 | } 374 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/dhconnelly/rtreego 2 | 3 | go 1.13 4 | -------------------------------------------------------------------------------- /rtree.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012 Daniel Connelly. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package rtreego is a library for efficiently storing and querying spatial data. 6 | package rtreego 7 | 8 | import ( 9 | "fmt" 10 | "math" 11 | "sort" 12 | ) 13 | 14 | // Comparator compares two spatials and returns whether they are equal. 15 | type Comparator func(obj1, obj2 Spatial) (equal bool) 16 | 17 | func defaultComparator(obj1, obj2 Spatial) bool { 18 | return obj1 == obj2 19 | } 20 | 21 | // Rtree represents an R-tree, a balanced search tree for storing and querying 22 | // spatial objects. Dim specifies the number of spatial dimensions and 23 | // MinChildren/MaxChildren specify the minimum/maximum branching factors. 
24 | type Rtree struct { 25 | Dim int 26 | MinChildren int 27 | MaxChildren int 28 | root *node 29 | size int 30 | height int 31 | 32 | // deleted is a temporary buffer to avoid memory allocations in Delete. 33 | // It is just an optimization and not part of the data structure. 34 | deleted []*node 35 | 36 | // FloatingPointTolerance is the tolerance to guard against floating point rounding errors during minMaxDist calculations. 37 | FloatingPointTolerance float64 38 | } 39 | 40 | // NewTree returns an Rtree. If the number of objects given on initialization 41 | // is larger than max, the Rtree will be initialized using the Overlap 42 | // Minimizing Top-down bulk-loading algorithm. 43 | func NewTree(dim, min, max int, objs ...Spatial) *Rtree { 44 | rt := &Rtree{ 45 | Dim: dim, 46 | MinChildren: min, 47 | MaxChildren: max, 48 | height: 1, 49 | FloatingPointTolerance: 1e-6, 50 | root: &node{ 51 | entries: []entry{}, 52 | leaf: true, 53 | level: 1, 54 | }, 55 | } 56 | 57 | if len(objs) <= rt.MaxChildren { 58 | for _, obj := range objs { 59 | rt.Insert(obj) 60 | } 61 | } else { 62 | rt.bulkLoad(objs) 63 | } 64 | 65 | return rt 66 | } 67 | 68 | // Size returns the number of objects currently stored in tree. 69 | func (tree *Rtree) Size() int { 70 | return tree.size 71 | } 72 | 73 | func (tree *Rtree) String() string { 74 | return "foo" 75 | } 76 | 77 | // Depth returns the maximum depth of tree. 78 | func (tree *Rtree) Depth() int { 79 | return tree.height 80 | } 81 | 82 | type dimSorter struct { 83 | dim int 84 | objs []entry 85 | } 86 | 87 | func (s *dimSorter) Len() int { 88 | return len(s.objs) 89 | } 90 | 91 | func (s *dimSorter) Swap(i, j int) { 92 | s.objs[i], s.objs[j] = s.objs[j], s.objs[i] 93 | } 94 | 95 | func (s *dimSorter) Less(i, j int) bool { 96 | return s.objs[i].bb.p[s.dim] < s.objs[j].bb.p[s.dim] 97 | } 98 | 99 | // walkPartitions splits objs into slices of maximum k elements and 100 | // iterates over these partitions. 
101 | func walkPartitions(k int, objs []entry, iter func(parts []entry)) { 102 | n := (len(objs) + k - 1) / k // ceil(len(objs) / k) 103 | 104 | for i := 1; i < n; i++ { 105 | iter(objs[(i-1)*k : i*k]) 106 | } 107 | iter(objs[(n-1)*k:]) 108 | } 109 | 110 | func sortByDim(dim int, objs []entry) { 111 | sort.Sort(&dimSorter{dim, objs}) 112 | } 113 | 114 | // bulkLoad bulk loads the Rtree using OMT algorithm. bulkLoad contains special 115 | // handling for the root node. 116 | func (tree *Rtree) bulkLoad(objs []Spatial) { 117 | n := len(objs) 118 | 119 | // create entries for all the objects 120 | entries := make([]entry, n) 121 | for i := range objs { 122 | entries[i] = entry{ 123 | bb: objs[i].Bounds(), 124 | obj: objs[i], 125 | } 126 | } 127 | 128 | // following equations are defined in the paper describing OMT 129 | var ( 130 | N = float64(n) 131 | M = float64(tree.MaxChildren) 132 | ) 133 | // Eq1: height of the tree 134 | // use log2 instead of log due to rounding errors with log, 135 | // eg, math.Log(9) / math.Log(3) > 2 136 | h := math.Ceil(math.Log2(N) / math.Log2(M)) 137 | 138 | // Eq2: size of subtrees at the root 139 | nsub := math.Pow(M, h-1) 140 | 141 | // Inner Eq3: number of subtrees at the root 142 | s := math.Ceil(N / nsub) 143 | 144 | // Eq3: number of slices 145 | S := math.Floor(math.Sqrt(s)) 146 | 147 | // sort all entries by first dimension 148 | sortByDim(0, entries) 149 | 150 | tree.height = int(h) 151 | tree.size = n 152 | tree.root = tree.omt(int(h), int(S), entries, int(s)) 153 | } 154 | 155 | // omt is the recursive part of the Overlap Minimizing Top-loading bulk- 156 | // load approach. Returns the root node of a subtree. 
157 | func (tree *Rtree) omt(level, nSlices int, objs []entry, m int) *node { 158 | // if number of objects is less than or equal than max children per leaf, 159 | // we need to create a leaf node 160 | if len(objs) <= m { 161 | // as long as the recursion is not at the leaf, call it again 162 | if level > 1 { 163 | child := tree.omt(level-1, nSlices, objs, m) 164 | n := &node{ 165 | level: level, 166 | entries: []entry{{ 167 | bb: child.computeBoundingBox(), 168 | child: child, 169 | }}, 170 | } 171 | child.parent = n 172 | return n 173 | } 174 | entries := make([]entry, len(objs)) 175 | copy(entries, objs) 176 | return &node{ 177 | leaf: true, 178 | entries: entries, 179 | level: level, 180 | } 181 | } 182 | 183 | n := &node{ 184 | level: level, 185 | entries: make([]entry, 0, m), 186 | } 187 | 188 | // maximum node size given at most M nodes at this level 189 | k := (len(objs) + m - 1) / m // = ceil(N / M) 190 | 191 | // In the root level, split objs in nSlices. In all other levels, 192 | // we use a single slice. 193 | vertSize := len(objs) 194 | if nSlices > 1 { 195 | vertSize = nSlices * k 196 | } 197 | 198 | // create sub trees 199 | walkPartitions(vertSize, objs, func(vert []entry) { 200 | // sort vertical slice by a different dimension on every level 201 | sortByDim((tree.height-level+1)%tree.Dim, vert) 202 | 203 | // split slice into groups of size k 204 | walkPartitions(k, vert, func(part []entry) { 205 | child := tree.omt(level-1, 1, part, tree.MaxChildren) 206 | child.parent = n 207 | 208 | n.entries = append(n.entries, entry{ 209 | bb: child.computeBoundingBox(), 210 | child: child, 211 | }) 212 | }) 213 | }) 214 | return n 215 | } 216 | 217 | // node represents a tree node of an Rtree. 
218 | type node struct { 219 | parent *node 220 | entries []entry 221 | level int // node depth in the Rtree 222 | leaf bool 223 | } 224 | 225 | func (n *node) String() string { 226 | return fmt.Sprintf("node{leaf: %v, entries: %v}", n.leaf, n.entries) 227 | } 228 | 229 | // entry represents a spatial index record stored in a tree node. 230 | type entry struct { 231 | bb Rect // bounding-box of all children of this entry 232 | child *node 233 | obj Spatial 234 | } 235 | 236 | func (e entry) String() string { 237 | if e.child != nil { 238 | return fmt.Sprintf("entry{bb: %v, child: %v}", e.bb, e.child) 239 | } 240 | return fmt.Sprintf("entry{bb: %v, obj: %v}", e.bb, e.obj) 241 | } 242 | 243 | // Spatial is an interface for objects that can be stored in an Rtree and queried. 244 | type Spatial interface { 245 | Bounds() Rect 246 | } 247 | 248 | // Insertion 249 | 250 | // Insert inserts a spatial object into the tree. If insertion 251 | // causes a leaf node to overflow, the tree is rebalanced automatically. 252 | // 253 | // Implemented per Section 3.2 of "R-trees: A Dynamic Index Structure for 254 | // Spatial Searching" by A. Guttman, Proceedings of ACM SIGMOD, p. 47-57, 1984. 255 | func (tree *Rtree) Insert(obj Spatial) { 256 | e := entry{obj.Bounds(), nil, obj} 257 | tree.insert(e, 1) 258 | tree.size++ 259 | } 260 | 261 | // insert adds the specified entry to the tree at the specified level. 
262 | func (tree *Rtree) insert(e entry, level int) { 263 | leaf := tree.chooseNode(tree.root, e, level) 264 | leaf.entries = append(leaf.entries, e) 265 | 266 | // update parent pointer if necessary 267 | if e.child != nil { 268 | e.child.parent = leaf 269 | } 270 | 271 | // split leaf if overflows 272 | var split *node 273 | if len(leaf.entries) > tree.MaxChildren { 274 | leaf, split = leaf.split(tree.MinChildren) 275 | } 276 | root, splitRoot := tree.adjustTree(leaf, split) 277 | if splitRoot != nil { 278 | oldRoot := root 279 | tree.height++ 280 | tree.root = &node{ 281 | parent: nil, 282 | level: tree.height, 283 | entries: []entry{ 284 | {bb: oldRoot.computeBoundingBox(), child: oldRoot}, 285 | {bb: splitRoot.computeBoundingBox(), child: splitRoot}, 286 | }, 287 | } 288 | oldRoot.parent = tree.root 289 | splitRoot.parent = tree.root 290 | } 291 | } 292 | 293 | // chooseNode finds the node at the specified level to which e should be added. 294 | func (tree *Rtree) chooseNode(n *node, e entry, level int) *node { 295 | if n.leaf || n.level == level { 296 | return n 297 | } 298 | 299 | // find the entry whose bb needs least enlargement to include obj 300 | diff := math.MaxFloat64 301 | var chosen entry 302 | for _, en := range n.entries { 303 | bb := boundingBox(en.bb, e.bb) 304 | d := bb.Size() - en.bb.Size() 305 | if d < diff || (d == diff && en.bb.Size() < chosen.bb.Size()) { 306 | diff = d 307 | chosen = en 308 | } 309 | } 310 | 311 | return tree.chooseNode(chosen.child, e, level) 312 | } 313 | 314 | // adjustTree splits overflowing nodes and propagates the changes upwards. 315 | func (tree *Rtree) adjustTree(n, nn *node) (*node, *node) { 316 | // Let the caller handle root adjustments. 317 | if n == tree.root { 318 | return n, nn 319 | } 320 | 321 | // Re-size the bounding box of n to account for lower-level changes. 
322 | en := n.getEntry() 323 | prevBox := en.bb 324 | en.bb = n.computeBoundingBox() 325 | 326 | // If nn is nil, then we're just propagating changes upwards. 327 | if nn == nil { 328 | // Optimize for the case where nothing is changed 329 | // to avoid computeBoundingBox which is expensive. 330 | if en.bb.Equal(prevBox) { 331 | return tree.root, nil 332 | } 333 | return tree.adjustTree(n.parent, nil) 334 | } 335 | 336 | // Otherwise, these are two nodes resulting from a split. 337 | // n was reused as the "left" node, but we need to add nn to n.parent. 338 | enn := entry{nn.computeBoundingBox(), nn, nil} 339 | n.parent.entries = append(n.parent.entries, enn) 340 | 341 | // If the new entry overflows the parent, split the parent and propagate. 342 | if len(n.parent.entries) > tree.MaxChildren { 343 | return tree.adjustTree(n.parent.split(tree.MinChildren)) 344 | } 345 | 346 | // Otherwise keep propagating changes upwards. 347 | return tree.adjustTree(n.parent, nil) 348 | } 349 | 350 | // getEntry returns a pointer to the entry for the node n from n's parent. 351 | func (n *node) getEntry() *entry { 352 | var e *entry 353 | for i := range n.parent.entries { 354 | if n.parent.entries[i].child == n { 355 | e = &n.parent.entries[i] 356 | break 357 | } 358 | } 359 | return e 360 | } 361 | 362 | // computeBoundingBox finds the MBR of the children of n. 363 | func (n *node) computeBoundingBox() (bb Rect) { 364 | if len(n.entries) == 1 { 365 | bb = n.entries[0].bb 366 | return 367 | } 368 | 369 | bb = boundingBox(n.entries[0].bb, n.entries[1].bb) 370 | for _, e := range n.entries[2:] { 371 | bb = boundingBox(bb, e.bb) 372 | } 373 | return 374 | } 375 | 376 | // split splits a node into two groups while attempting to minimize the 377 | // bounding-box area of the resulting groups. 
378 | func (n *node) split(minGroupSize int) (left, right *node) { 379 | // find the initial split 380 | l, r := n.pickSeeds() 381 | leftSeed, rightSeed := n.entries[l], n.entries[r] 382 | 383 | // get the entries to be divided between left and right 384 | remaining := append(n.entries[:l], n.entries[l+1:r]...) 385 | remaining = append(remaining, n.entries[r+1:]...) 386 | 387 | // setup the new split nodes, but re-use n as the left node 388 | left = n 389 | left.entries = []entry{leftSeed} 390 | right = &node{ 391 | parent: n.parent, 392 | leaf: n.leaf, 393 | level: n.level, 394 | entries: []entry{rightSeed}, 395 | } 396 | 397 | // TODO 398 | if rightSeed.child != nil { 399 | rightSeed.child.parent = right 400 | } 401 | if leftSeed.child != nil { 402 | leftSeed.child.parent = left 403 | } 404 | 405 | // distribute all of n's old entries into left and right. 406 | for len(remaining) > 0 { 407 | next := pickNext(left, right, remaining) 408 | e := remaining[next] 409 | 410 | if len(remaining)+len(left.entries) <= minGroupSize { 411 | assign(e, left) 412 | } else if len(remaining)+len(right.entries) <= minGroupSize { 413 | assign(e, right) 414 | } else { 415 | assignGroup(e, left, right) 416 | } 417 | 418 | remaining = append(remaining[:next], remaining[next+1:]...) 419 | } 420 | 421 | return 422 | } 423 | 424 | // getAllBoundingBoxes traverses tree populating slice of bounding boxes of non-leaf nodes. 425 | func (n *node) getAllBoundingBoxes() []Rect { 426 | var rects []Rect 427 | if n.leaf { 428 | return rects 429 | } 430 | for _, e := range n.entries { 431 | if e.child == nil { 432 | return rects 433 | } 434 | rectsInter := append(e.child.getAllBoundingBoxes(), e.bb) 435 | rects = append(rects, rectsInter...) 
436 | } 437 | return rects 438 | } 439 | 440 | func assign(e entry, group *node) { 441 | if e.child != nil { 442 | e.child.parent = group 443 | } 444 | group.entries = append(group.entries, e) 445 | } 446 | 447 | // assignGroup chooses one of two groups to which a node should be added. 448 | func assignGroup(e entry, left, right *node) { 449 | leftBB := left.computeBoundingBox() 450 | rightBB := right.computeBoundingBox() 451 | leftEnlarged := boundingBox(leftBB, e.bb) 452 | rightEnlarged := boundingBox(rightBB, e.bb) 453 | 454 | // first, choose the group that needs the least enlargement 455 | leftDiff := leftEnlarged.Size() - leftBB.Size() 456 | rightDiff := rightEnlarged.Size() - rightBB.Size() 457 | if diff := leftDiff - rightDiff; diff < 0 { 458 | assign(e, left) 459 | return 460 | } else if diff > 0 { 461 | assign(e, right) 462 | return 463 | } 464 | 465 | // next, choose the group that has smaller area 466 | if diff := leftBB.Size() - rightBB.Size(); diff < 0 { 467 | assign(e, left) 468 | return 469 | } else if diff > 0 { 470 | assign(e, right) 471 | return 472 | } 473 | 474 | // next, choose the group with fewer entries 475 | if diff := len(left.entries) - len(right.entries); diff <= 0 { 476 | assign(e, left) 477 | return 478 | } 479 | assign(e, right) 480 | } 481 | 482 | // pickSeeds chooses two child entries of n to start a split. 483 | func (n *node) pickSeeds() (int, int) { 484 | left, right := 0, 1 485 | maxWastedSpace := -1.0 486 | for i, e1 := range n.entries { 487 | for j, e2 := range n.entries[i+1:] { 488 | d := boundingBox(e1.bb, e2.bb).Size() - e1.bb.Size() - e2.bb.Size() 489 | if d > maxWastedSpace { 490 | maxWastedSpace = d 491 | left, right = i, j+i+1 492 | } 493 | } 494 | } 495 | return left, right 496 | } 497 | 498 | // pickNext chooses an entry to be added to an entry group. 
499 | func pickNext(left, right *node, entries []entry) (next int) { 500 | maxDiff := -1.0 501 | leftBB := left.computeBoundingBox() 502 | rightBB := right.computeBoundingBox() 503 | for i, e := range entries { 504 | d1 := boundingBox(leftBB, e.bb).Size() - leftBB.Size() 505 | d2 := boundingBox(rightBB, e.bb).Size() - rightBB.Size() 506 | d := math.Abs(d1 - d2) 507 | if d > maxDiff { 508 | maxDiff = d 509 | next = i 510 | } 511 | } 512 | return 513 | } 514 | 515 | // Deletion 516 | 517 | // Delete removes an object from the tree. If the object is not found, returns 518 | // false, otherwise returns true. Uses the default comparator when checking 519 | // equality. 520 | // 521 | // Implemented per Section 3.3 of "R-trees: A Dynamic Index Structure for 522 | // Spatial Searching" by A. Guttman, Proceedings of ACM SIGMOD, p. 47-57, 1984. 523 | func (tree *Rtree) Delete(obj Spatial) bool { 524 | return tree.DeleteWithComparator(obj, defaultComparator) 525 | } 526 | 527 | // DeleteWithComparator removes an object from the tree using a custom 528 | // comparator for evaluating equalness. This is useful when you want to remove 529 | // an object from a tree but don't have a pointer to the original object 530 | // anymore. 531 | func (tree *Rtree) DeleteWithComparator(obj Spatial, cmp Comparator) bool { 532 | n := tree.findLeaf(tree.root, obj, cmp) 533 | if n == nil { 534 | return false 535 | } 536 | 537 | ind := -1 538 | for i, e := range n.entries { 539 | if cmp(e.obj, obj) { 540 | ind = i 541 | } 542 | } 543 | if ind < 0 { 544 | return false 545 | } 546 | 547 | n.entries = append(n.entries[:ind], n.entries[ind+1:]...) 548 | 549 | tree.condenseTree(n) 550 | tree.size-- 551 | 552 | /* 553 | when the tree is deep, and deleting nodes, will cause the issue. 554 | the tree could be like this: one obj but 3 levels depth. 
555 | { 556 | "size": 1, 557 | "depth": 3, 558 | "root": { 559 | "entries": [ 560 | { 561 | "bb": "[1.00, 2.00]x[1.00, 2.00]", 562 | "child": { 563 | "entries": [ 564 | { 565 | "bb": "[1.00, 2.00]x[1.00, 2.00]", 566 | "child": { 567 | "leaf": true, 568 | "entries": [ 569 | { 570 | "bb": "[1.00, 2.00]x[1.00, 2.00]" 571 | } 572 | ] 573 | } 574 | } 575 | ] 576 | } 577 | } 578 | ] 579 | } 580 | } 581 | so we need to merge the root in loop, instead of once. 582 | */ 583 | for !tree.root.leaf && len(tree.root.entries) == 1 { 584 | tree.root = tree.root.entries[0].child 585 | } 586 | 587 | tree.height = tree.root.level 588 | 589 | return true 590 | } 591 | 592 | // findLeaf finds the leaf node containing obj. 593 | func (tree *Rtree) findLeaf(n *node, obj Spatial, cmp Comparator) *node { 594 | if n.leaf { 595 | return n 596 | } 597 | // if not leaf, search all candidate subtrees 598 | for _, e := range n.entries { 599 | if e.bb.containsRect(obj.Bounds()) { 600 | leaf := tree.findLeaf(e.child, obj, cmp) 601 | if leaf == nil { 602 | continue 603 | } 604 | // check if the leaf actually contains the object 605 | for _, leafEntry := range leaf.entries { 606 | if cmp(leafEntry.obj, obj) { 607 | return leaf 608 | } 609 | } 610 | } 611 | } 612 | return nil 613 | } 614 | 615 | // condenseTree deletes underflowing nodes and propagates the changes upwards. 
616 | func (tree *Rtree) condenseTree(n *node) { 617 | // reset the deleted buffer 618 | tree.deleted = tree.deleted[:0] 619 | 620 | for n != tree.root { 621 | if len(n.entries) < tree.MinChildren { 622 | // find n and delete it by swapping the last entry into its place 623 | idx := -1 624 | for i, e := range n.parent.entries { 625 | if e.child == n { 626 | idx = i 627 | break 628 | } 629 | } 630 | if idx == -1 { 631 | panic(fmt.Errorf("Failed to remove entry from parent")) 632 | } 633 | l := len(n.parent.entries) 634 | n.parent.entries[idx] = n.parent.entries[l-1] 635 | n.parent.entries = n.parent.entries[:l-1] 636 | 637 | // only add n to deleted if it still has children 638 | if len(n.entries) > 0 { 639 | tree.deleted = append(tree.deleted, n) 640 | } 641 | } else { 642 | // just a child entry deletion, no underflow 643 | en := n.getEntry() 644 | prevBox := en.bb 645 | en.bb = n.computeBoundingBox() 646 | 647 | if en.bb.Equal(prevBox) { 648 | // Optimize for the case where nothing is changed 649 | // to avoid computeBoundingBox which is expensive. 650 | break 651 | } 652 | } 653 | n = n.parent 654 | } 655 | 656 | for i := len(tree.deleted) - 1; i >= 0; i-- { 657 | n := tree.deleted[i] 658 | // reinsert entry so that it will remain at the same level as before 659 | e := entry{n.computeBoundingBox(), n, nil} 660 | tree.insert(e, n.level+1) 661 | } 662 | } 663 | 664 | // Searching 665 | 666 | // SearchIntersect returns all objects that intersect the specified rectangle. 667 | // Implemented per Section 3.1 of "R-trees: A Dynamic Index Structure for 668 | // Spatial Searching" by A. Guttman, Proceedings of ACM SIGMOD, p. 47-57, 1984. 669 | func (tree *Rtree) SearchIntersect(bb Rect, filters ...Filter) []Spatial { 670 | return tree.searchIntersect([]Spatial{}, tree.root, bb, filters) 671 | } 672 | 673 | // SearchIntersectWithLimit is similar to SearchIntersect, but returns 674 | // immediately when the first k results are found. 
A negative k behaves exactly 675 | // like SearchIntersect and returns all the results. 676 | // 677 | // Kept for backwards compatibility, please use SearchIntersect with a 678 | // LimitFilter. 679 | func (tree *Rtree) SearchIntersectWithLimit(k int, bb Rect) []Spatial { 680 | // backwards compatibility, previous implementation didn't limit results if 681 | // k was negative. 682 | if k < 0 { 683 | return tree.SearchIntersect(bb) 684 | } 685 | return tree.SearchIntersect(bb, LimitFilter(k)) 686 | } 687 | 688 | func (tree *Rtree) searchIntersect(results []Spatial, n *node, bb Rect, filters []Filter) []Spatial { 689 | for _, e := range n.entries { 690 | if !intersect(e.bb, bb) { 691 | continue 692 | } 693 | 694 | if !n.leaf { 695 | results = tree.searchIntersect(results, e.child, bb, filters) 696 | continue 697 | } 698 | 699 | refuse, abort := applyFilters(results, e.obj, filters) 700 | if !refuse { 701 | results = append(results, e.obj) 702 | } 703 | 704 | if abort { 705 | break 706 | } 707 | } 708 | return results 709 | } 710 | 711 | // NearestNeighbor returns the closest object to the specified point. 712 | // Implemented per "Nearest Neighbor Queries" by Roussopoulos et al 713 | func (tree *Rtree) NearestNeighbor(p Point) Spatial { 714 | obj, _ := tree.nearestNeighbor(p, tree.root, math.MaxFloat64, nil) 715 | return obj 716 | } 717 | 718 | // GetAllBoundingBoxes returning slice of bounding boxes by traversing tree. Slice 719 | // includes bounding boxes from all non-leaf nodes. 
720 | func (tree *Rtree) GetAllBoundingBoxes() []Rect { 721 | var rects []Rect 722 | if tree.root != nil { 723 | rects = tree.root.getAllBoundingBoxes() 724 | } 725 | return rects 726 | } 727 | 728 | // utilities for sorting slices of entries 729 | 730 | type entrySlice struct { 731 | entries []entry 732 | dists []float64 733 | } 734 | 735 | func (s entrySlice) Len() int { return len(s.entries) } 736 | 737 | func (s entrySlice) Swap(i, j int) { 738 | s.entries[i], s.entries[j] = s.entries[j], s.entries[i] 739 | s.dists[i], s.dists[j] = s.dists[j], s.dists[i] 740 | } 741 | 742 | func (s entrySlice) Less(i, j int) bool { 743 | return s.dists[i] < s.dists[j] 744 | } 745 | 746 | func sortEntries(p Point, entries []entry) ([]entry, []float64) { 747 | sorted := make([]entry, len(entries)) 748 | dists := make([]float64, len(entries)) 749 | return sortPreallocEntries(p, entries, sorted, dists) 750 | } 751 | 752 | func sortPreallocEntries(p Point, entries, sorted []entry, dists []float64) ([]entry, []float64) { 753 | // use preallocated slices 754 | sorted = sorted[:len(entries)] 755 | dists = dists[:len(entries)] 756 | 757 | for i := 0; i < len(entries); i++ { 758 | sorted[i] = entries[i] 759 | dists[i] = p.minDist(entries[i].bb) 760 | } 761 | sort.Sort(entrySlice{sorted, dists}) 762 | return sorted, dists 763 | } 764 | 765 | func pruneEntriesMinDist(d float64, entries []entry, minDists []float64) []entry { 766 | var i int 767 | for ; i < len(entries); i++ { 768 | if minDists[i] > d { 769 | break 770 | } 771 | } 772 | return entries[:i] 773 | } 774 | 775 | func (tree *Rtree) nearestNeighbor(p Point, n *node, d float64, nearest Spatial) (Spatial, float64) { 776 | if n.leaf { 777 | for _, e := range n.entries { 778 | dist := math.Sqrt(p.minDist(e.bb)) 779 | if dist < d { 780 | d = dist 781 | nearest = e.obj 782 | } 783 | } 784 | } else { 785 | // Search only through entries with minDist <= minMinMaxDist, 786 | // where minDist is the distance between a point and a rectangle, 
787 | // and minMaxDist is the smallest value among the maximum distance across all axes. 788 | // 789 | // Entries with minDist > minMinMaxDist are guaranteed to be farther away than some other entry. 790 | // 791 | // For more details, please consult 792 | // N. Roussopoulos, S. Kelley and F. Vincent, ACM SIGMOD, pages 71-79, 1995. 793 | minMinMaxDist := math.MaxFloat64 794 | for _, e := range n.entries { 795 | minMaxDist := p.minMaxDist(e.bb) 796 | if minMaxDist < minMinMaxDist { 797 | minMinMaxDist = minMaxDist 798 | } 799 | } 800 | 801 | for _, e := range n.entries { 802 | minDist := p.minDist(e.bb) 803 | // Add a bit of tolerance to guard against floating point rounding errors. 804 | if minDist > minMinMaxDist+tree.FloatingPointTolerance { 805 | continue 806 | } 807 | 808 | subNearest, dist := tree.nearestNeighbor(p, e.child, d, nearest) 809 | if dist < d { 810 | d = dist 811 | nearest = subNearest 812 | } 813 | } 814 | } 815 | 816 | return nearest, d 817 | } 818 | 819 | // NearestNeighbors gets the closest Spatials to the Point. 820 | func (tree *Rtree) NearestNeighbors(k int, p Point, filters ...Filter) []Spatial { 821 | // preallocate the buffers for sortings the branches. At each level of the 822 | // tree, we slide the buffer by the number of entries in the node. 823 | maxBufSize := tree.MaxChildren * tree.Depth() 824 | branches := make([]entry, maxBufSize) 825 | branchDists := make([]float64, maxBufSize) 826 | 827 | // allocate the buffers for the results 828 | dists := make([]float64, 0, k) 829 | objs := make([]Spatial, 0, k) 830 | 831 | objs, _, _ = tree.nearestNeighbors(k, p, tree.root, dists, objs, filters, branches, branchDists) 832 | return objs 833 | } 834 | 835 | // insert obj into nearest and return the first k elements in increasing order. 
836 | func insertNearest(k int, dists []float64, nearest []Spatial, dist float64, obj Spatial, filters []Filter) ([]float64, []Spatial, bool) { 837 | i := sort.SearchFloat64s(dists, dist) 838 | for i < len(nearest) && dist >= dists[i] { 839 | i++ 840 | } 841 | if i >= k { 842 | return dists, nearest, false 843 | } 844 | 845 | if refuse, abort := applyFilters(nearest, obj, filters); refuse || abort { 846 | return dists, nearest, abort 847 | } 848 | 849 | // no resize since cap = k 850 | if len(nearest) < k { 851 | dists = append(dists, 0) 852 | nearest = append(nearest, nil) 853 | } 854 | 855 | left, right := dists[:i], dists[i:len(dists)-1] 856 | copy(dists, left) 857 | copy(dists[i+1:], right) 858 | dists[i] = dist 859 | 860 | leftObjs, rightObjs := nearest[:i], nearest[i:len(nearest)-1] 861 | copy(nearest, leftObjs) 862 | copy(nearest[i+1:], rightObjs) 863 | nearest[i] = obj 864 | 865 | return dists, nearest, false 866 | } 867 | 868 | func (tree *Rtree) nearestNeighbors(k int, p Point, n *node, dists []float64, nearest []Spatial, filters []Filter, b []entry, bd []float64) ([]Spatial, []float64, bool) { 869 | var abort bool 870 | if n.leaf { 871 | for _, e := range n.entries { 872 | dist := p.minDist(e.bb) 873 | dists, nearest, abort = insertNearest(k, dists, nearest, dist, e.obj, filters) 874 | if abort { 875 | break 876 | } 877 | } 878 | } else { 879 | branches, branchDists := sortPreallocEntries(p, n.entries, b, bd) 880 | // only prune if buffer has k elements 881 | if l := len(dists); l >= k { 882 | branches = pruneEntriesMinDist(dists[l-1], branches, branchDists) 883 | } 884 | for _, e := range branches { 885 | nearest, dists, abort = tree.nearestNeighbors(k, p, e.child, dists, nearest, filters, b[len(n.entries):], bd[len(n.entries):]) 886 | if abort { 887 | break 888 | } 889 | } 890 | } 891 | return nearest, dists, abort 892 | } 893 | -------------------------------------------------------------------------------- /rtree_test.go: 
-------------------------------------------------------------------------------- 1 | package rtreego 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "math/rand" 7 | "sort" 8 | "strings" 9 | "testing" 10 | ) 11 | 12 | type testCase struct { 13 | name string 14 | build func() *Rtree 15 | } 16 | 17 | func tests(dim, min, max int, objs ...Spatial) []*testCase { 18 | return []*testCase{ 19 | { 20 | "dynamically built", 21 | func() *Rtree { 22 | rt := NewTree(dim, min, max) 23 | for _, thing := range objs { 24 | rt.Insert(thing) 25 | } 26 | return rt 27 | }, 28 | }, 29 | { 30 | "bulk-loaded", 31 | func() *Rtree { 32 | return NewTree(dim, min, max, objs...) 33 | }, 34 | }, 35 | } 36 | } 37 | 38 | func (r Rect) Bounds() Rect { 39 | return r 40 | } 41 | 42 | func rectEq(a, b Rect) bool { 43 | if len(a.p) != len(b.p) { 44 | return false 45 | } 46 | for i := 0; i < len(a.p); i++ { 47 | if a.p[i] != b.p[i] { 48 | return false 49 | } 50 | } 51 | 52 | if len(a.q) != len(b.q) { 53 | return false 54 | } 55 | for i := 0; i < len(a.q); i++ { 56 | if a.q[i] != b.q[i] { 57 | return false 58 | } 59 | } 60 | 61 | return true 62 | } 63 | 64 | func entryEq(a, b entry) bool { 65 | if !rectEq(a.bb, b.bb) { 66 | return false 67 | } 68 | if a.child != b.child { 69 | return false 70 | } 71 | if a.obj != b.obj { 72 | return false 73 | } 74 | return true 75 | } 76 | 77 | func mustRect(p Point, widths []float64) Rect { 78 | r, err := NewRect(p, widths) 79 | if err != nil { 80 | panic(err) 81 | } 82 | return r 83 | } 84 | 85 | func printNode(n *node, level int) { 86 | padding := strings.Repeat("\t", level) 87 | fmt.Printf("%sNode: %p\n", padding, n) 88 | fmt.Printf("%sParent: %p\n", padding, n.parent) 89 | fmt.Printf("%sLevel: %d\n", padding, n.level) 90 | fmt.Printf("%sLeaf: %t\n%sEntries:\n", padding, n.leaf, padding) 91 | for _, e := range n.entries { 92 | printEntry(e, level+1) 93 | } 94 | } 95 | 96 | func printEntry(e entry, level int) { 97 | padding := strings.Repeat("\t", level) 98 | 
fmt.Printf("%sBB: %v\n", padding, e.bb) 99 | if e.child != nil { 100 | printNode(e.child, level) 101 | } else { 102 | fmt.Printf("%sObject: %v\n", padding, e.obj) 103 | } 104 | fmt.Println() 105 | } 106 | 107 | func items(n *node) chan Spatial { 108 | ch := make(chan Spatial) 109 | go func() { 110 | for _, e := range n.entries { 111 | if n.leaf { 112 | ch <- e.obj 113 | } else { 114 | for obj := range items(e.child) { 115 | ch <- obj 116 | } 117 | } 118 | } 119 | close(ch) 120 | }() 121 | return ch 122 | } 123 | 124 | func validate(n *node, height, max int) error { 125 | if n.level != height { 126 | return fmt.Errorf("level %d != height %d", n.level, height) 127 | } 128 | if len(n.entries) > max { 129 | return fmt.Errorf("node with too many entries at level %d/%d (actual: %d max: %d)", n.level, height, len(n.entries), max) 130 | } 131 | if n.leaf { 132 | if n.level != 1 { 133 | return fmt.Errorf("leaf node at level %d", n.level) 134 | } 135 | return nil 136 | } 137 | for _, e := range n.entries { 138 | if e.child.level != n.level-1 { 139 | return fmt.Errorf("failed to preserve level order") 140 | } 141 | if e.child.parent != n { 142 | return fmt.Errorf("failed to update parent pointer") 143 | } 144 | if err := validate(e.child, height-1, max); err != nil { 145 | return err 146 | } 147 | } 148 | return nil 149 | } 150 | 151 | func verify(t *testing.T, rt *Rtree) { 152 | if rt.height != rt.root.level { 153 | t.Errorf("invalid tree: height %d differs root level %d", rt.height, rt.root.level) 154 | } 155 | 156 | if err := validate(rt.root, rt.height, rt.MaxChildren); err != nil { 157 | printNode(rt.root, 0) 158 | t.Errorf("invalid tree: %v", err) 159 | } 160 | } 161 | 162 | var chooseLeafNodeTests = []struct { 163 | bb0, bb1, bb2 Rect // leaf bounding boxes 164 | exp int // expected chosen leaf 165 | desc string 166 | level int 167 | }{ 168 | { 169 | mustRect(Point{1, 1, 1}, []float64{1, 1, 1}), 170 | mustRect(Point{-1, -1, -1}, []float64{0.5, 0.5, 0.5}), 171 | 
mustRect(Point{3, 4, -5}, []float64{2, 0.9, 8}), 172 | 1, 173 | "clear winner", 174 | 1, 175 | }, 176 | { 177 | mustRect(Point{-1, -1.5, -1}, []float64{0.5, 2.5025, 0.5}), 178 | mustRect(Point{0.5, 1, 0.5}, []float64{0.5, 0.815, 0.5}), 179 | mustRect(Point{3, 4, -5}, []float64{2, 0.9, 8}), 180 | 1, 181 | "leaves tie", 182 | 1, 183 | }, 184 | { 185 | mustRect(Point{-1, -1.5, -1}, []float64{0.5, 2.5025, 0.5}), 186 | mustRect(Point{0.5, 1, 0.5}, []float64{0.5, 0.815, 0.5}), 187 | mustRect(Point{-1, -2, -3}, []float64{2, 4, 6}), 188 | 2, 189 | "leaf contains obj", 190 | 1, 191 | }, 192 | } 193 | 194 | func TestChooseLeafNodeEmpty(t *testing.T) { 195 | rt := NewTree(3, 5, 10) 196 | obj := Point{0, 0, 0}.ToRect(0.5) 197 | e := entry{obj, nil, obj} 198 | if leaf := rt.chooseNode(rt.root, e, 1); leaf != rt.root { 199 | t.Errorf("expected chooseLeaf of empty tree to return root") 200 | } 201 | } 202 | 203 | func TestChooseLeafNode(t *testing.T) { 204 | for _, test := range chooseLeafNodeTests { 205 | rt := Rtree{} 206 | rt.root = &node{} 207 | 208 | leaf0 := &node{rt.root, []entry{}, 1, true} 209 | entry0 := entry{test.bb0, leaf0, nil} 210 | 211 | leaf1 := &node{rt.root, []entry{}, 1, true} 212 | entry1 := entry{test.bb1, leaf1, nil} 213 | 214 | leaf2 := &node{rt.root, []entry{}, 1, true} 215 | entry2 := entry{test.bb2, leaf2, nil} 216 | 217 | rt.root.entries = []entry{entry0, entry1, entry2} 218 | 219 | obj := Point{0, 0, 0}.ToRect(0.5) 220 | e := entry{obj, nil, obj} 221 | 222 | expected := rt.root.entries[test.exp].child 223 | if leaf := rt.chooseNode(rt.root, e, 1); leaf != expected { 224 | t.Errorf("%s: expected %d", test.desc, test.exp) 225 | } 226 | } 227 | } 228 | 229 | func TestPickSeeds(t *testing.T) { 230 | entry1 := entry{bb: mustRect(Point{1, 1}, []float64{1, 1})} 231 | entry2 := entry{bb: mustRect(Point{1, -1}, []float64{2, 1})} 232 | entry3 := entry{bb: mustRect(Point{-1, -1}, []float64{1, 2})} 233 | n := node{entries: []entry{entry1, entry2, entry3}} 234 | 
left, right := n.pickSeeds() 235 | if !entryEq(n.entries[left], entry1) || !entryEq(n.entries[right], entry3) { 236 | t.Errorf("expected entries %d, %d", 1, 3) 237 | } 238 | } 239 | 240 | func TestPickNext(t *testing.T) { 241 | leftEntry := entry{bb: mustRect(Point{1, 1}, []float64{1, 1})} 242 | left := &node{entries: []entry{leftEntry}} 243 | 244 | rightEntry := entry{bb: mustRect(Point{-1, -1}, []float64{1, 2})} 245 | right := &node{entries: []entry{rightEntry}} 246 | 247 | entry1 := entry{bb: mustRect(Point{0, 0}, []float64{1, 1})} 248 | entry2 := entry{bb: mustRect(Point{-2, -2}, []float64{1, 1})} 249 | entry3 := entry{bb: mustRect(Point{1, 2}, []float64{1, 1})} 250 | entries := []entry{entry1, entry2, entry3} 251 | 252 | chosen := pickNext(left, right, entries) 253 | if !entryEq(entries[chosen], entry2) { 254 | t.Errorf("expected entry %d", 3) 255 | } 256 | } 257 | 258 | func TestSplit(t *testing.T) { 259 | entry1 := entry{bb: mustRect(Point{-3, -1}, []float64{2, 1})} 260 | entry2 := entry{bb: mustRect(Point{1, 2}, []float64{1, 1})} 261 | entry3 := entry{bb: mustRect(Point{-1, 0}, []float64{1, 1})} 262 | entry4 := entry{bb: mustRect(Point{-3, -3}, []float64{1, 1})} 263 | entry5 := entry{bb: mustRect(Point{1, -1}, []float64{2, 2})} 264 | entries := []entry{entry1, entry2, entry3, entry4, entry5} 265 | n := &node{entries: entries} 266 | 267 | l, r := n.split(0) // left=entry2, right=entry4 268 | expLeft := mustRect(Point{1, -1}, []float64{2, 4}) 269 | expRight := mustRect(Point{-3, -3}, []float64{3, 4}) 270 | 271 | lbb := l.computeBoundingBox() 272 | rbb := r.computeBoundingBox() 273 | if lbb.p.dist(expLeft.p) >= EPS || lbb.q.dist(expLeft.q) >= EPS { 274 | t.Errorf("expected left.bb = %s, got %s", expLeft, lbb) 275 | } 276 | if rbb.p.dist(expRight.p) >= EPS || rbb.q.dist(expRight.q) >= EPS { 277 | t.Errorf("expected right.bb = %s, got %s", expRight, rbb) 278 | } 279 | } 280 | 281 | func TestSplitUnderflow(t *testing.T) { 282 | entry1 := entry{bb: 
mustRect(Point{0, 0}, []float64{1, 1})} 283 | entry2 := entry{bb: mustRect(Point{0, 1}, []float64{1, 1})} 284 | entry3 := entry{bb: mustRect(Point{0, 2}, []float64{1, 1})} 285 | entry4 := entry{bb: mustRect(Point{0, 3}, []float64{1, 1})} 286 | entry5 := entry{bb: mustRect(Point{-50, -50}, []float64{1, 1})} 287 | entries := []entry{entry1, entry2, entry3, entry4, entry5} 288 | n := &node{entries: entries} 289 | 290 | l, r := n.split(2) 291 | 292 | if len(l.entries) != 3 || len(r.entries) != 2 { 293 | t.Errorf("expected underflow assignment for right group") 294 | } 295 | } 296 | 297 | func TestAssignGroupLeastEnlargement(t *testing.T) { 298 | r00 := entry{bb: mustRect(Point{0, 0}, []float64{1, 1})} 299 | r01 := entry{bb: mustRect(Point{0, 1}, []float64{1, 1})} 300 | r10 := entry{bb: mustRect(Point{1, 0}, []float64{1, 1})} 301 | r11 := entry{bb: mustRect(Point{1, 1}, []float64{1, 1})} 302 | r02 := entry{bb: mustRect(Point{0, 2}, []float64{1, 1})} 303 | 304 | group1 := &node{entries: []entry{r00, r01}} 305 | group2 := &node{entries: []entry{r10, r11}} 306 | 307 | assignGroup(r02, group1, group2) 308 | if len(group1.entries) != 3 || len(group2.entries) != 2 { 309 | t.Errorf("expected r02 added to group 1") 310 | } 311 | } 312 | 313 | func TestAssignGroupSmallerArea(t *testing.T) { 314 | r00 := entry{bb: mustRect(Point{0, 0}, []float64{1, 1})} 315 | r01 := entry{bb: mustRect(Point{0, 1}, []float64{1, 1})} 316 | r12 := entry{bb: mustRect(Point{1, 2}, []float64{1, 1})} 317 | r02 := entry{bb: mustRect(Point{0, 2}, []float64{1, 1})} 318 | 319 | group1 := &node{entries: []entry{r00, r01}} 320 | group2 := &node{entries: []entry{r12}} 321 | 322 | assignGroup(r02, group1, group2) 323 | if len(group2.entries) != 2 || len(group1.entries) != 2 { 324 | t.Errorf("expected r02 added to group 2") 325 | } 326 | } 327 | 328 | func TestAssignGroupFewerEntries(t *testing.T) { 329 | r0001 := entry{bb: mustRect(Point{0, 0}, []float64{1, 2})} 330 | r12 := entry{bb: mustRect(Point{1, 2}, 
[]float64{1, 1})} 331 | r22 := entry{bb: mustRect(Point{2, 2}, []float64{1, 1})} 332 | r02 := entry{bb: mustRect(Point{0, 2}, []float64{1, 1})} 333 | 334 | group1 := &node{entries: []entry{r0001}} 335 | group2 := &node{entries: []entry{r12, r22}} 336 | 337 | assignGroup(r02, group1, group2) 338 | if len(group2.entries) != 2 || len(group1.entries) != 2 { 339 | t.Errorf("expected r02 added to group 2") 340 | } 341 | } 342 | 343 | func TestAdjustTreeNoPreviousSplit(t *testing.T) { 344 | rt := Rtree{root: &node{}} 345 | 346 | r00 := entry{bb: mustRect(Point{0, 0}, []float64{1, 1})} 347 | r01 := entry{bb: mustRect(Point{0, 1}, []float64{1, 1})} 348 | r10 := entry{bb: mustRect(Point{1, 0}, []float64{1, 1})} 349 | entries := []entry{r00, r01, r10} 350 | n := node{rt.root, entries, 1, false} 351 | rt.root.entries = []entry{{bb: Point{0, 0}.ToRect(0), child: &n}} 352 | 353 | rt.adjustTree(&n, nil) 354 | 355 | e := rt.root.entries[0] 356 | p, q := Point{0, 0}, Point{2, 2} 357 | if p.dist(e.bb.p) >= EPS || q.dist(e.bb.q) >= EPS { 358 | t.Errorf("Expected adjustTree to fit %v,%v,%v", r00.bb, r01.bb, r10.bb) 359 | } 360 | } 361 | 362 | func TestAdjustTreeNoSplit(t *testing.T) { 363 | rt := NewTree(2, 3, 3) 364 | 365 | r00 := entry{bb: mustRect(Point{0, 0}, []float64{1, 1})} 366 | r01 := entry{bb: mustRect(Point{0, 1}, []float64{1, 1})} 367 | left := node{rt.root, []entry{r00, r01}, 1, false} 368 | leftEntry := entry{bb: Point{0, 0}.ToRect(0), child: &left} 369 | 370 | r10 := entry{bb: mustRect(Point{1, 0}, []float64{1, 1})} 371 | r11 := entry{bb: mustRect(Point{1, 1}, []float64{1, 1})} 372 | right := node{rt.root, []entry{r10, r11}, 1, false} 373 | 374 | rt.root.entries = []entry{leftEntry} 375 | retl, retr := rt.adjustTree(&left, &right) 376 | 377 | if retl != rt.root || retr != nil { 378 | t.Errorf("Expected adjustTree didn't split the root") 379 | } 380 | 381 | entries := rt.root.entries 382 | if entries[0].child != &left || entries[1].child != &right { 383 | 
t.Errorf("Expected adjustTree keeps left and adds n in parent") 384 | } 385 | 386 | lbb, rbb := entries[0].bb, entries[1].bb 387 | if lbb.p.dist(Point{0, 0}) >= EPS || lbb.q.dist(Point{1, 2}) >= EPS { 388 | t.Errorf("Expected adjustTree to adjust left bb") 389 | } 390 | if rbb.p.dist(Point{1, 0}) >= EPS || rbb.q.dist(Point{2, 2}) >= EPS { 391 | t.Errorf("Expected adjustTree to adjust right bb") 392 | } 393 | } 394 | 395 | func TestAdjustTreeSplitParent(t *testing.T) { 396 | rt := NewTree(2, 1, 1) 397 | 398 | r00 := entry{bb: mustRect(Point{0, 0}, []float64{1, 1})} 399 | r01 := entry{bb: mustRect(Point{0, 1}, []float64{1, 1})} 400 | left := node{rt.root, []entry{r00, r01}, 1, false} 401 | leftEntry := entry{bb: Point{0, 0}.ToRect(0), child: &left} 402 | 403 | r10 := entry{bb: mustRect(Point{1, 0}, []float64{1, 1})} 404 | r11 := entry{bb: mustRect(Point{1, 1}, []float64{1, 1})} 405 | right := node{rt.root, []entry{r10, r11}, 1, false} 406 | 407 | rt.root.entries = []entry{leftEntry} 408 | retl, retr := rt.adjustTree(&left, &right) 409 | 410 | if len(retl.entries) != 1 || len(retr.entries) != 1 { 411 | t.Errorf("Expected adjustTree distributed the entries") 412 | } 413 | 414 | lbb, rbb := retl.entries[0].bb, retr.entries[0].bb 415 | if lbb.p.dist(Point{0, 0}) >= EPS || lbb.q.dist(Point{1, 2}) >= EPS { 416 | t.Errorf("Expected left split got left entry") 417 | } 418 | if rbb.p.dist(Point{1, 0}) >= EPS || rbb.q.dist(Point{2, 2}) >= EPS { 419 | t.Errorf("Expected right split got right entry") 420 | } 421 | } 422 | 423 | func TestInsertRepeated(t *testing.T) { 424 | var things []Spatial 425 | for i := 0; i < 10; i++ { 426 | things = append(things, mustRect(Point{0, 0}, []float64{2, 1})) 427 | } 428 | 429 | for _, tc := range tests(2, 3, 5, things...) 
{ 430 | t.Run(tc.name, func(t *testing.T) { 431 | rt := tc.build() 432 | rt.Insert(mustRect(Point{0, 0}, []float64{2, 1})) 433 | }) 434 | } 435 | } 436 | 437 | func TestInsertNoSplit(t *testing.T) { 438 | rt := NewTree(2, 3, 3) 439 | thing := mustRect(Point{0, 0}, []float64{2, 1}) 440 | rt.Insert(thing) 441 | 442 | if rt.Size() != 1 { 443 | t.Errorf("Insert failed to increase tree size") 444 | } 445 | 446 | if len(rt.root.entries) != 1 || !rectEq(rt.root.entries[0].obj.(Rect), thing) { 447 | t.Errorf("Insert failed to insert thing into root entries") 448 | } 449 | } 450 | 451 | func TestInsertSplitRoot(t *testing.T) { 452 | rt := NewTree(2, 3, 3) 453 | things := []Rect{ 454 | mustRect(Point{0, 0}, []float64{2, 1}), 455 | mustRect(Point{3, 1}, []float64{1, 2}), 456 | mustRect(Point{1, 2}, []float64{2, 2}), 457 | mustRect(Point{8, 6}, []float64{1, 1}), 458 | mustRect(Point{10, 3}, []float64{1, 2}), 459 | mustRect(Point{11, 7}, []float64{1, 1}), 460 | } 461 | for _, thing := range things { 462 | rt.Insert(thing) 463 | } 464 | 465 | if rt.Size() != 6 { 466 | t.Errorf("Insert failed to insert") 467 | } 468 | 469 | if len(rt.root.entries) != 2 { 470 | t.Errorf("Insert failed to split") 471 | } 472 | 473 | left, right := rt.root.entries[0].child, rt.root.entries[1].child 474 | if len(left.entries) != 3 || len(right.entries) != 3 { 475 | t.Errorf("Insert failed to split evenly") 476 | } 477 | } 478 | 479 | func TestInsertSplit(t *testing.T) { 480 | rt := NewTree(2, 3, 3) 481 | things := []Rect{ 482 | mustRect(Point{0, 0}, []float64{2, 1}), 483 | mustRect(Point{3, 1}, []float64{1, 2}), 484 | mustRect(Point{1, 2}, []float64{2, 2}), 485 | mustRect(Point{8, 6}, []float64{1, 1}), 486 | mustRect(Point{10, 3}, []float64{1, 2}), 487 | mustRect(Point{11, 7}, []float64{1, 1}), 488 | mustRect(Point{10, 10}, []float64{2, 2}), 489 | } 490 | for _, thing := range things { 491 | rt.Insert(thing) 492 | } 493 | 494 | if rt.Size() != 7 { 495 | t.Errorf("Insert failed to insert") 496 | } 497 
| 498 | if len(rt.root.entries) != 3 { 499 | t.Errorf("Insert failed to split") 500 | } 501 | 502 | a, b, c := rt.root.entries[0], rt.root.entries[1], rt.root.entries[2] 503 | if len(a.child.entries) != 3 || 504 | len(b.child.entries) != 3 || 505 | len(c.child.entries) != 1 { 506 | t.Errorf("Insert failed to split evenly") 507 | } 508 | } 509 | 510 | func TestInsertSplitSecondLevel(t *testing.T) { 511 | rt := NewTree(2, 3, 3) 512 | things := []Rect{ 513 | mustRect(Point{0, 0}, []float64{2, 1}), 514 | mustRect(Point{3, 1}, []float64{1, 2}), 515 | mustRect(Point{1, 2}, []float64{2, 2}), 516 | mustRect(Point{8, 6}, []float64{1, 1}), 517 | mustRect(Point{10, 3}, []float64{1, 2}), 518 | mustRect(Point{11, 7}, []float64{1, 1}), 519 | mustRect(Point{0, 6}, []float64{1, 2}), 520 | mustRect(Point{1, 6}, []float64{1, 2}), 521 | mustRect(Point{0, 8}, []float64{1, 2}), 522 | mustRect(Point{1, 8}, []float64{1, 2}), 523 | } 524 | for _, thing := range things { 525 | rt.Insert(thing) 526 | } 527 | 528 | if rt.Size() != 10 { 529 | t.Errorf("Insert failed to insert") 530 | } 531 | 532 | // should split root 533 | if len(rt.root.entries) != 2 { 534 | t.Errorf("Insert failed to split the root") 535 | } 536 | 537 | // split level + entries level + objs level 538 | if rt.Depth() != 3 { 539 | t.Errorf("Insert failed to adjust properly") 540 | } 541 | 542 | var checkParents func(n *node) 543 | checkParents = func(n *node) { 544 | if n.leaf { 545 | return 546 | } 547 | for _, e := range n.entries { 548 | if e.child.parent != n { 549 | t.Errorf("Insert failed to update parent pointers") 550 | } 551 | checkParents(e.child) 552 | } 553 | } 554 | checkParents(rt.root) 555 | } 556 | 557 | func TestBulkLoadingValidity(t *testing.T) { 558 | var things []Spatial 559 | for i := float64(0); i < float64(100); i++ { 560 | things = append(things, mustRect(Point{i, i}, []float64{1, 1})) 561 | } 562 | 563 | testCases := []struct { 564 | count int 565 | max int 566 | }{ 567 | { 568 | count: 5, 569 | max: 
2, 570 | }, 571 | { 572 | count: 33, 573 | max: 5, 574 | }, 575 | { 576 | count: 34, 577 | max: 7, 578 | }, 579 | } 580 | 581 | for _, tc := range testCases { 582 | t.Run(fmt.Sprintf("count=%d-max=%d", tc.count, tc.max), func(t *testing.T) { 583 | rt := NewTree(2, 1, tc.max, things[:tc.count]...) 584 | verify(t, rt) 585 | }) 586 | } 587 | } 588 | 589 | func TestFindLeaf(t *testing.T) { 590 | rt := NewTree(2, 3, 3) 591 | rects := []Rect{ 592 | mustRect(Point{0, 0}, []float64{2, 1}), 593 | mustRect(Point{3, 1}, []float64{1, 2}), 594 | mustRect(Point{1, 2}, []float64{2, 2}), 595 | mustRect(Point{8, 6}, []float64{1, 1}), 596 | mustRect(Point{10, 3}, []float64{1, 2}), 597 | mustRect(Point{11, 7}, []float64{1, 1}), 598 | mustRect(Point{0, 6}, []float64{1, 2}), 599 | mustRect(Point{1, 6}, []float64{1, 2}), 600 | mustRect(Point{0, 8}, []float64{1, 2}), 601 | mustRect(Point{1, 8}, []float64{1, 2}), 602 | } 603 | things := []Spatial{} 604 | for i := range rects { 605 | things = append(things, &rects[i]) 606 | } 607 | 608 | for _, thing := range things { 609 | rt.Insert(thing) 610 | } 611 | verify(t, rt) 612 | for _, thing := range things { 613 | leaf := rt.findLeaf(rt.root, thing, defaultComparator) 614 | if leaf == nil { 615 | printNode(rt.root, 0) 616 | t.Fatalf("Unable to find leaf containing an entry after insertion!") 617 | } 618 | var found *Rect 619 | for _, other := range leaf.entries { 620 | if other.obj == thing { 621 | found = other.obj.(*Rect) 622 | break 623 | } 624 | } 625 | if found == nil { 626 | printNode(rt.root, 0) 627 | printNode(leaf, 0) 628 | t.Errorf("Entry %v not found in leaf node %v!", thing, leaf) 629 | } 630 | } 631 | } 632 | 633 | func TestFindLeafDoesNotExist(t *testing.T) { 634 | rt := NewTree(2, 3, 3) 635 | things := []Rect{ 636 | mustRect(Point{0, 0}, []float64{2, 1}), 637 | mustRect(Point{3, 1}, []float64{1, 2}), 638 | mustRect(Point{1, 2}, []float64{2, 2}), 639 | mustRect(Point{8, 6}, []float64{1, 1}), 640 | mustRect(Point{10, 3}, 
[]float64{1, 2}), 641 | mustRect(Point{11, 7}, []float64{1, 1}), 642 | mustRect(Point{0, 6}, []float64{1, 2}), 643 | mustRect(Point{1, 6}, []float64{1, 2}), 644 | mustRect(Point{0, 8}, []float64{1, 2}), 645 | mustRect(Point{1, 8}, []float64{1, 2}), 646 | } 647 | for _, thing := range things { 648 | rt.Insert(thing) 649 | } 650 | 651 | obj := mustRect(Point{99, 99}, []float64{99, 99}) 652 | leaf := rt.findLeaf(rt.root, obj, defaultComparator) 653 | if leaf != nil { 654 | t.Errorf("findLeaf failed to return nil for non-existent object") 655 | } 656 | } 657 | 658 | func TestCondenseTreeEliminate(t *testing.T) { 659 | rt := NewTree(2, 3, 3) 660 | things := []Rect{ 661 | mustRect(Point{0, 0}, []float64{2, 1}), 662 | mustRect(Point{3, 1}, []float64{1, 2}), 663 | mustRect(Point{1, 2}, []float64{2, 2}), 664 | mustRect(Point{8, 6}, []float64{1, 1}), 665 | mustRect(Point{10, 3}, []float64{1, 2}), 666 | mustRect(Point{11, 7}, []float64{1, 1}), 667 | mustRect(Point{0, 6}, []float64{1, 2}), 668 | mustRect(Point{1, 6}, []float64{1, 2}), 669 | mustRect(Point{0, 8}, []float64{1, 2}), 670 | mustRect(Point{1, 8}, []float64{1, 2}), 671 | } 672 | for _, thing := range things { 673 | rt.Insert(thing) 674 | } 675 | 676 | // delete entry 2 from parent entries 677 | parent := rt.root.entries[0].child.entries[1].child 678 | parent.entries = append(parent.entries[:2], parent.entries[3:]...) 
679 | rt.condenseTree(parent) 680 | 681 | retrieved := []Spatial{} 682 | for obj := range items(rt.root) { 683 | retrieved = append(retrieved, obj) 684 | } 685 | 686 | if len(retrieved) != len(things)-1 { 687 | t.Errorf("condenseTree failed to reinsert upstream elements") 688 | } 689 | 690 | verify(t, rt) 691 | } 692 | 693 | func TestChooseNodeNonLeaf(t *testing.T) { 694 | rt := NewTree(2, 3, 3) 695 | things := []Rect{ 696 | mustRect(Point{0, 0}, []float64{2, 1}), 697 | mustRect(Point{3, 1}, []float64{1, 2}), 698 | mustRect(Point{1, 2}, []float64{2, 2}), 699 | mustRect(Point{8, 6}, []float64{1, 1}), 700 | mustRect(Point{10, 3}, []float64{1, 2}), 701 | mustRect(Point{11, 7}, []float64{1, 1}), 702 | mustRect(Point{0, 6}, []float64{1, 2}), 703 | mustRect(Point{1, 6}, []float64{1, 2}), 704 | mustRect(Point{0, 8}, []float64{1, 2}), 705 | mustRect(Point{1, 8}, []float64{1, 2}), 706 | } 707 | for _, thing := range things { 708 | rt.Insert(thing) 709 | } 710 | 711 | obj := mustRect(Point{0, 10}, []float64{1, 2}) 712 | e := entry{obj, nil, obj} 713 | n := rt.chooseNode(rt.root, e, 2) 714 | if n.level != 2 { 715 | t.Errorf("chooseNode failed to stop at desired level") 716 | } 717 | } 718 | 719 | func TestInsertNonLeaf(t *testing.T) { 720 | rt := NewTree(2, 3, 3) 721 | things := []Rect{ 722 | mustRect(Point{0, 0}, []float64{2, 1}), 723 | mustRect(Point{3, 1}, []float64{1, 2}), 724 | mustRect(Point{1, 2}, []float64{2, 2}), 725 | mustRect(Point{8, 6}, []float64{1, 1}), 726 | mustRect(Point{10, 3}, []float64{1, 2}), 727 | mustRect(Point{11, 7}, []float64{1, 1}), 728 | mustRect(Point{0, 6}, []float64{1, 2}), 729 | mustRect(Point{1, 6}, []float64{1, 2}), 730 | mustRect(Point{0, 8}, []float64{1, 2}), 731 | mustRect(Point{1, 8}, []float64{1, 2}), 732 | } 733 | for _, thing := range things { 734 | rt.Insert(thing) 735 | } 736 | 737 | obj := mustRect(Point{99, 99}, []float64{99, 99}) 738 | e := entry{obj, nil, obj} 739 | rt.insert(e, 2) 740 | 741 | expected := rt.root.entries[1].child 
742 | if !rectEq(expected.entries[1].obj.(Rect), obj) { 743 | t.Errorf("insert failed to insert entry at correct level") 744 | } 745 | } 746 | 747 | func TestDeleteFlatten(t *testing.T) { 748 | rects := []Rect{ 749 | mustRect(Point{0, 0}, []float64{2, 1}), 750 | mustRect(Point{3, 1}, []float64{1, 2}), 751 | } 752 | things := []Spatial{} 753 | for i := range rects { 754 | things = append(things, &rects[i]) 755 | } 756 | 757 | for _, tc := range tests(2, 3, 3, things...) { 758 | t.Run(tc.name, func(t *testing.T) { 759 | rt := tc.build() 760 | // make sure flattening didn't nuke the tree 761 | rt.Delete(things[0]) 762 | verify(t, rt) 763 | }) 764 | } 765 | } 766 | 767 | func TestDelete(t *testing.T) { 768 | rects := []Rect{ 769 | mustRect(Point{0, 0}, []float64{2, 1}), 770 | mustRect(Point{3, 1}, []float64{1, 2}), 771 | mustRect(Point{1, 2}, []float64{2, 2}), 772 | mustRect(Point{8, 6}, []float64{1, 1}), 773 | mustRect(Point{10, 3}, []float64{1, 2}), 774 | mustRect(Point{11, 7}, []float64{1, 1}), 775 | mustRect(Point{0, 6}, []float64{1, 2}), 776 | mustRect(Point{1, 6}, []float64{1, 2}), 777 | mustRect(Point{0, 8}, []float64{1, 2}), 778 | mustRect(Point{1, 8}, []float64{1, 2}), 779 | } 780 | things := []Spatial{} 781 | for i := range rects { 782 | things = append(things, &rects[i]) 783 | } 784 | 785 | for _, tc := range tests(2, 3, 3, things...) { 786 | t.Run(tc.name, func(t *testing.T) { 787 | rt := tc.build() 788 | 789 | verify(t, rt) 790 | 791 | things2 := []Spatial{} 792 | for len(things) > 0 { 793 | i := rand.Int() % len(things) 794 | things2 = append(things2, things[i]) 795 | things = append(things[:i], things[i+1:]...) 
796 | } 797 | 798 | for i, thing := range things2 { 799 | ok := rt.Delete(thing) 800 | if !ok { 801 | t.Errorf("Thing %v was not found in tree during deletion", thing) 802 | return 803 | } 804 | 805 | if rt.Size() != len(things2)-i-1 { 806 | t.Errorf("Delete failed to remove %v", thing) 807 | return 808 | } 809 | verify(t, rt) 810 | } 811 | }) 812 | } 813 | } 814 | 815 | func TestDeleteWithDepthChange(t *testing.T) { 816 | rt := NewTree(2, 3, 3) 817 | rects := []Rect{ 818 | mustRect(Point{0, 0}, []float64{2, 1}), 819 | mustRect(Point{3, 1}, []float64{1, 2}), 820 | mustRect(Point{1, 2}, []float64{2, 2}), 821 | mustRect(Point{8, 6}, []float64{1, 1}), 822 | } 823 | things := []Spatial{} 824 | for i := range rects { 825 | things = append(things, &rects[i]) 826 | } 827 | 828 | for _, thing := range things { 829 | rt.Insert(thing) 830 | } 831 | 832 | // delete last item and condense nodes 833 | rt.Delete(things[3]) 834 | 835 | // rt.height should be 1 otherwise insert increases height to 3 836 | rt.Insert(things[3]) 837 | 838 | // and verify would fail 839 | verify(t, rt) 840 | } 841 | 842 | func TestDeleteWithComparator(t *testing.T) { 843 | type IDRect struct { 844 | ID string 845 | Rect 846 | } 847 | 848 | things := []Spatial{ 849 | &IDRect{"1", mustRect(Point{0, 0}, []float64{2, 1})}, 850 | &IDRect{"2", mustRect(Point{3, 1}, []float64{1, 2})}, 851 | &IDRect{"3", mustRect(Point{1, 2}, []float64{2, 2})}, 852 | &IDRect{"4", mustRect(Point{8, 6}, []float64{1, 1})}, 853 | &IDRect{"5", mustRect(Point{10, 3}, []float64{1, 2})}, 854 | &IDRect{"6", mustRect(Point{11, 7}, []float64{1, 1})}, 855 | &IDRect{"7", mustRect(Point{0, 6}, []float64{1, 2})}, 856 | &IDRect{"8", mustRect(Point{1, 6}, []float64{1, 2})}, 857 | &IDRect{"9", mustRect(Point{0, 8}, []float64{1, 2})}, 858 | &IDRect{"10", mustRect(Point{1, 8}, []float64{1, 2})}, 859 | } 860 | 861 | for _, tc := range tests(2, 3, 3, things...) 
{ 862 | t.Run(tc.name, func(t *testing.T) { 863 | rt := tc.build() 864 | 865 | verify(t, rt) 866 | 867 | cmp := func(obj1, obj2 Spatial) bool { 868 | idr1 := obj1.(*IDRect) 869 | idr2 := obj2.(*IDRect) 870 | return idr1.ID == idr2.ID 871 | } 872 | 873 | things2 := []*IDRect{} 874 | for len(things) > 0 { 875 | i := rand.Int() % len(things) 876 | // make a deep copy 877 | copy := &IDRect{things[i].(*IDRect).ID, things[i].(*IDRect).Rect} 878 | things2 = append(things2, copy) 879 | 880 | if !cmp(things[i], copy) { 881 | log.Fatalf("expected copy to be equal to the original, original: %v, copy: %v", things[i], copy) 882 | } 883 | 884 | things = append(things[:i], things[i+1:]...) 885 | } 886 | 887 | for i, thing := range things2 { 888 | ok := rt.DeleteWithComparator(thing, cmp) 889 | if !ok { 890 | t.Errorf("Thing %v was not found in tree during deletion", thing) 891 | return 892 | } 893 | 894 | if rt.Size() != len(things2)-i-1 { 895 | t.Errorf("Delete failed to remove %v", thing) 896 | return 897 | } 898 | verify(t, rt) 899 | } 900 | }) 901 | } 902 | } 903 | 904 | func TestDeleteThenInsert(t *testing.T) { 905 | tol := 1e-3 906 | rects := []Rect{ 907 | mustRect(Point{3, 1}, []float64{tol, tol}), 908 | mustRect(Point{1, 2}, []float64{tol, tol}), 909 | mustRect(Point{2, 6}, []float64{tol, tol}), 910 | mustRect(Point{3, 6}, []float64{tol, tol}), 911 | mustRect(Point{2, 8}, []float64{tol, tol}), 912 | } 913 | things := []Spatial{} 914 | for i := range rects { 915 | things = append(things, &rects[i]) 916 | } 917 | 918 | rt := NewTree(2, 2, 2, things...) 919 | 920 | if ok := rt.Delete(things[3]); !ok { 921 | t.Fatalf("%#v", things[3]) 922 | } 923 | rt.Insert(things[3]) 924 | 925 | // Deleting and then inserting things[3] should not affect things[4]. 
926 | if ok := rt.Delete(things[4]); !ok { 927 | t.Fatalf("%#v", things[4]) 928 | } 929 | } 930 | 931 | func TestSearchIntersect(t *testing.T) { 932 | rects := []Rect{ 933 | mustRect(Point{0, 0}, []float64{2, 1}), 934 | mustRect(Point{3, 1}, []float64{1, 2}), 935 | mustRect(Point{1, 2}, []float64{2, 2}), 936 | mustRect(Point{8, 6}, []float64{1, 1}), 937 | mustRect(Point{10, 3}, []float64{1, 2}), 938 | mustRect(Point{11, 7}, []float64{1, 1}), 939 | mustRect(Point{2, 6}, []float64{1, 2}), 940 | mustRect(Point{3, 6}, []float64{1, 2}), 941 | mustRect(Point{2, 8}, []float64{1, 2}), 942 | mustRect(Point{3, 8}, []float64{1, 2}), 943 | } 944 | things := []Spatial{} 945 | for i := range rects { 946 | things = append(things, &rects[i]) 947 | } 948 | 949 | for _, tc := range tests(2, 3, 3, things...) { 950 | t.Run(tc.name, func(t *testing.T) { 951 | rt := tc.build() 952 | 953 | p := Point{2, 1.5} 954 | bb := mustRect(p, []float64{10, 5.5}) 955 | q := rt.SearchIntersect(bb) 956 | 957 | var expected []Spatial 958 | for _, i := range []int{1, 2, 3, 4, 6, 7} { 959 | expected = append(expected, things[i]) 960 | } 961 | 962 | ensureDisorderedSubset(t, q, expected) 963 | }) 964 | } 965 | 966 | } 967 | 968 | func TestSearchIntersectWithLimit(t *testing.T) { 969 | rects := []Rect{ 970 | mustRect(Point{0, 0}, []float64{2, 1}), 971 | mustRect(Point{3, 1}, []float64{1, 2}), 972 | mustRect(Point{1, 2}, []float64{2, 2}), 973 | mustRect(Point{8, 6}, []float64{1, 1}), 974 | mustRect(Point{10, 3}, []float64{1, 2}), 975 | mustRect(Point{11, 7}, []float64{1, 1}), 976 | mustRect(Point{2, 6}, []float64{1, 2}), 977 | mustRect(Point{3, 6}, []float64{1, 2}), 978 | mustRect(Point{2, 8}, []float64{1, 2}), 979 | mustRect(Point{3, 8}, []float64{1, 2}), 980 | } 981 | things := []Spatial{} 982 | for i := range rects { 983 | things = append(things, &rects[i]) 984 | } 985 | 986 | for _, tc := range tests(2, 3, 3, things...) 
{ 987 | t.Run(tc.name, func(t *testing.T) { 988 | rt := tc.build() 989 | 990 | bb := mustRect(Point{2, 1.5}, []float64{10, 5.5}) 991 | 992 | // expected contains all the intersecting things 993 | var expected []Spatial 994 | for _, i := range []int{1, 2, 6, 7, 3, 4} { 995 | expected = append(expected, things[i]) 996 | } 997 | 998 | // Loop through all possible limits k of SearchIntersectWithLimit, 999 | // and test that the results are as expected. 1000 | for k := -1; k <= len(things); k++ { 1001 | q := rt.SearchIntersectWithLimit(k, bb) 1002 | 1003 | if k == -1 { 1004 | ensureDisorderedSubset(t, q, expected) 1005 | if len(q) != len(expected) { 1006 | t.Fatalf("length of actual (%v) was different from expected (%v)", len(q), len(expected)) 1007 | } 1008 | } else if k == 0 { 1009 | if len(q) != 0 { 1010 | t.Fatalf("length of actual (%v) was different from expected (%v)", len(q), len(expected)) 1011 | } 1012 | } else if k <= len(expected) { 1013 | ensureDisorderedSubset(t, q, expected) 1014 | if len(q) != k { 1015 | t.Fatalf("length of actual (%v) was different from expected (%v)", len(q), len(expected)) 1016 | } 1017 | } else { 1018 | ensureDisorderedSubset(t, q, expected) 1019 | if len(q) != len(expected) { 1020 | t.Fatalf("length of actual (%v) was different from expected (%v)", len(q), len(expected)) 1021 | } 1022 | } 1023 | } 1024 | }) 1025 | } 1026 | } 1027 | 1028 | func TestSearchIntersectWithTestFilter(t *testing.T) { 1029 | rects := []Rect{ 1030 | mustRect(Point{0, 0}, []float64{2, 1}), 1031 | mustRect(Point{3, 1}, []float64{1, 2}), 1032 | mustRect(Point{1, 2}, []float64{2, 2}), 1033 | mustRect(Point{8, 6}, []float64{1, 1}), 1034 | mustRect(Point{10, 3}, []float64{1, 2}), 1035 | mustRect(Point{11, 7}, []float64{1, 1}), 1036 | mustRect(Point{2, 6}, []float64{1, 2}), 1037 | mustRect(Point{3, 6}, []float64{1, 2}), 1038 | mustRect(Point{2, 8}, []float64{1, 2}), 1039 | mustRect(Point{3, 8}, []float64{1, 2}), 1040 | } 1041 | things := []Spatial{} 1042 | for i := 
range rects { 1043 | things = append(things, &rects[i]) 1044 | } 1045 | 1046 | for _, tc := range tests(2, 3, 3, things...) { 1047 | t.Run(tc.name, func(t *testing.T) { 1048 | rt := tc.build() 1049 | 1050 | bb := mustRect(Point{2, 1.5}, []float64{10, 5.5}) 1051 | 1052 | // intersecting indexes are 1, 2, 6, 7, 3, 4 1053 | // rects which we do not filter out 1054 | var expected []Spatial 1055 | for _, i := range []int{1, 6, 4} { 1056 | expected = append(expected, things[i]) 1057 | } 1058 | 1059 | // this test filter will only pick the objects that are in expected 1060 | objects := rt.SearchIntersect(bb, func(results []Spatial, object Spatial) (bool, bool) { 1061 | for _, exp := range expected { 1062 | if exp == object { 1063 | return false, false 1064 | } 1065 | } 1066 | return true, false 1067 | }) 1068 | 1069 | ensureDisorderedSubset(t, objects, expected) 1070 | }) 1071 | } 1072 | } 1073 | 1074 | func TestSearchIntersectNoResults(t *testing.T) { 1075 | things := []Spatial{ 1076 | mustRect(Point{0, 0}, []float64{2, 1}), 1077 | mustRect(Point{3, 1}, []float64{1, 2}), 1078 | mustRect(Point{1, 2}, []float64{2, 2}), 1079 | mustRect(Point{8, 6}, []float64{1, 1}), 1080 | mustRect(Point{10, 3}, []float64{1, 2}), 1081 | mustRect(Point{11, 7}, []float64{1, 1}), 1082 | mustRect(Point{2, 6}, []float64{1, 2}), 1083 | mustRect(Point{3, 6}, []float64{1, 2}), 1084 | mustRect(Point{2, 8}, []float64{1, 2}), 1085 | mustRect(Point{3, 8}, []float64{1, 2}), 1086 | } 1087 | 1088 | for _, tc := range tests(2, 3, 3, things...) 
{
		t.Run(tc.name, func(t *testing.T) {
			rt := tc.build()

			bb := mustRect(Point{99, 99}, []float64{10, 5.5})
			q := rt.SearchIntersect(bb)
			if len(q) != 0 {
				t.Errorf("SearchIntersect failed to return nil slice on failing query")
			}
		})
	}
}

// TestSortEntries checks that sortEntries orders entries by distance from the
// query point and returns the matching distances. The entries are supplied
// farthest-first; the expected distances 2, 8 and 18 are the squared
// distances from the origin to the corners (1,1), (2,2) and (3,3).
func TestSortEntries(t *testing.T) {
	objs := []Rect{
		mustRect(Point{1, 1}, []float64{1, 1}),
		mustRect(Point{2, 2}, []float64{1, 1}),
		mustRect(Point{3, 3}, []float64{1, 1}),
	}
	entries := []entry{
		{objs[2], nil, &objs[2]},
		{objs[1], nil, &objs[1]},
		{objs[0], nil, &objs[0]},
	}
	sorted, dists := sortEntries(Point{0, 0}, entries)
	if !entryEq(sorted[0], entries[2]) || !entryEq(sorted[1], entries[1]) || !entryEq(sorted[2], entries[0]) {
		t.Errorf("sortEntries failed")
	}
	if dists[0] != 2 || dists[1] != 8 || dists[2] != 18 {
		t.Errorf("sortEntries failed to calculate proper distances")
	}
}

// TestNearestNeighbor checks NearestNeighbor against hand-picked query points
// for each case returned by tests(...). The last two queries both expect
// things[2].
func TestNearestNeighbor(t *testing.T) {
	rects := []Rect{
		mustRect(Point{1, 1}, []float64{1, 1}),
		mustRect(Point{1, 3}, []float64{1, 1}),
		mustRect(Point{3, 2}, []float64{1, 1}),
		mustRect(Point{-7, -7}, []float64{1, 1}),
		mustRect(Point{7, 7}, []float64{1, 1}),
		mustRect(Point{10, 2}, []float64{1, 1}),
	}
	things := []Spatial{}
	for i := range rects {
		things = append(things, &rects[i])
	}

	for _, tc := range tests(2, 3, 3, things...) {
		t.Run(tc.name, func(t *testing.T) {
			rt := tc.build()

			obj1 := rt.NearestNeighbor(Point{0.5, 0.5})
			obj2 := rt.NearestNeighbor(Point{1.5, 4.5})
			obj3 := rt.NearestNeighbor(Point{5, 2.5})
			obj4 := rt.NearestNeighbor(Point{3.5, 2.5})

			if obj1 != things[0] || obj2 != things[1] || obj3 != things[2] || obj4 != things[2] {
				t.Errorf("NearestNeighbor failed")
			}
		})
	}
}

// TestComputeBoundingBox checks that a node's bounding box encloses all of its
// entries: unit squares at (0,0), (0,1) and (1,0) must produce the 2x2 square
// anchored at the origin.
func TestComputeBoundingBox(t *testing.T) {
	rect1 := mustRect(Point{0, 0}, []float64{1, 1})
	rect2 := mustRect(Point{0, 1}, []float64{1, 1})
	rect3 := mustRect(Point{1, 0}, []float64{1, 1})
	n := &node{}
	n.entries = append(n.entries, entry{bb: rect1}, entry{bb: rect2}, entry{bb: rect3})

	exp := mustRect(Point{0, 0}, []float64{2, 2})
	bb := n.computeBoundingBox()
	// Compare corners within EPS rather than exactly to tolerate rounding.
	d1 := bb.p.dist(exp.p)
	d2 := bb.q.dist(exp.q)
	if d1 > EPS || d2 > EPS {
		t.Errorf("computeBoundingBox(%v, %v, %v) != %v, got %v", rect1, rect2, rect3, exp, bb)
	}
}

// TestGetAllBoundingBoxes inserts the same 14 rectangles into three trees
// created with different NewTree parameters and checks how many bounding
// boxes GetAllBoundingBoxes reports for each tree.
func TestGetAllBoundingBoxes(t *testing.T) {
	rt1 := NewTree(2, 3, 3)
	rt2 := NewTree(2, 2, 4)
	rt3 := NewTree(2, 4, 8)
	things := []Rect{
		mustRect(Point{0, 0}, []float64{2, 1}),
		mustRect(Point{3, 1}, []float64{1, 2}),
		mustRect(Point{1, 2}, []float64{2, 2}),
		mustRect(Point{8, 6}, []float64{1, 1}),
		mustRect(Point{10, 3}, []float64{1, 2}),
		mustRect(Point{11, 7}, []float64{1, 1}),
		mustRect(Point{10, 10}, []float64{2, 2}),
		mustRect(Point{2, 3}, []float64{0.5, 1}),
		mustRect(Point{3, 5}, []float64{1.5, 2}),
		mustRect(Point{7, 14}, []float64{2.5, 2}),
		mustRect(Point{15, 6}, []float64{1, 1}),
		mustRect(Point{4, 3}, []float64{1, 2}),
		mustRect(Point{1, 7}, []float64{1, 1}),
		mustRect(Point{10, 5}, []float64{2, 2}),
	}
	// Each tree receives the rectangles in the same order.
	for _, thing := range things {
		rt1.Insert(thing)
		rt2.Insert(thing)
		rt3.Insert(thing)
	}

	if rt1.Size() != 14 {
		t.Errorf("Insert failed to insert into rt1: got size %d, expected 14", rt1.Size())
	}
	if rt2.Size() != 14 {
		t.Errorf("Insert failed to insert into rt2: got size %d, expected 14", rt2.Size())
	}
	if rt3.Size() != 14 {
		t.Errorf("Insert failed to insert into rt3: got size %d, expected 14", rt3.Size())
	}

	rtbb1 := rt1.GetAllBoundingBoxes()
	rtbb2 := rt2.GetAllBoundingBoxes()
	rtbb3 := rt3.GetAllBoundingBoxes()

	if len(rtbb1) != 13 {
		t.Errorf("Failed bounding box traversal expected 13, got %d", len(rtbb1))
	}
	if len(rtbb2) != 7 {
		t.Errorf("Failed bounding box traversal expected 7, got %d", len(rtbb2))
	}
	if len(rtbb3) != 2 {
		t.Errorf("Failed bounding box traversal expected 2, got %d", len(rtbb3))
	}
}

// byMinDist implements sort.Interface, ordering r by ascending minDist from p
// to each element's bounding box.
type byMinDist struct {
	r []Spatial
	p Point
}

// Less reports whether r[i]'s bounding box is strictly closer to p than r[j]'s.
func (r byMinDist) Less(i, j int) bool {
	return r.p.minDist(r.r[i].Bounds()) < r.p.minDist(r.r[j].Bounds())
}

// Len returns the number of elements being sorted.
func (r byMinDist) Len() int {
	return len(r.r)
}

// Swap exchanges r[i] and r[j] in place.
func (r byMinDist) Swap(i, j int) {
	r.r[i], r.r[j] = r.r[j], r.r[i]
}

// TestNearestNeighborsAll asks NearestNeighbors for every object and checks
// the results come back in minDist order. It then asks for more objects than
// exist and checks that exactly len(things) results are returned.
func TestNearestNeighborsAll(t *testing.T) {
	rects := []Rect{
		mustRect(Point{1, 1}, []float64{1, 1}),
		mustRect(Point{-7, -7}, []float64{1, 1}),
		mustRect(Point{1, 3}, []float64{1, 1}),
		mustRect(Point{7, 7}, []float64{1, 1}),
		mustRect(Point{10, 2}, []float64{1, 1}),
		mustRect(Point{3, 3}, []float64{1, 1}),
	}
	things := []Spatial{}
	for i := range rects {
		things = append(things, &rects[i])
	}

	for _, tc := range tests(2, 3, 3, things...) {
		t.Run(tc.name, func(t *testing.T) {
			rt := tc.build()

			verify(t, rt)

			// Sort things into the order NearestNeighbors is expected to return them.
			p := Point{0.5, 0.5}
			sort.Sort(byMinDist{things, p})

			objs := rt.NearestNeighbors(len(things), p)
			// Check the length first so a wrong-sized result fails cleanly
			// instead of panicking on objs[i] below.
			if len(objs) != len(things) {
				t.Fatalf("NearestNeighbors returned %d elements, expected %d", len(objs), len(things))
			}
			for i := range things {
				if objs[i] != things[i] {
					t.Errorf("NearestNeighbors failed at index %d: %v != %v", i, objs[i], things[i])
				}
			}

			objs = rt.NearestNeighbors(len(things)+2, p)
			if len(objs) > len(things) {
				t.Errorf("NearestNeighbors failed: too many elements")
			}
			if len(objs) < len(things) {
				t.Errorf("NearestNeighbors failed: not enough elements")
			}
		})
	}
}

// TestNearestNeighborsFilters passes a filter that is meant to admit only the
// objects in expected. The filter returns (false, false) for those objects and
// (true, false) for every other object; the first result is presumably
// "refuse". The test checks that the results are a prefix of expected in
// minDist order.
func TestNearestNeighborsFilters(t *testing.T) {
	rects := []Rect{
		mustRect(Point{1, 1}, []float64{1, 1}),
		mustRect(Point{-7, -7}, []float64{1, 1}),
		mustRect(Point{1, 3}, []float64{1, 1}),
		mustRect(Point{7, 7}, []float64{1, 1}),
		mustRect(Point{10, 2}, []float64{1, 1}),
		mustRect(Point{3, 3}, []float64{1, 1}),
	}
	things := []Spatial{}
	for i := range rects {
		things = append(things, &rects[i])
	}

	expected := []Spatial{things[0], things[2], things[3]}

	for _, tc := range tests(2, 3, 3, things...) {
		t.Run(tc.name, func(t *testing.T) {
			rt := tc.build()

			p := Point{0.5, 0.5}
			sort.Sort(byMinDist{expected, p})

			objs := rt.NearestNeighbors(len(things), p, func(r []Spatial, obj Spatial) (bool, bool) {
				for _, ex := range expected {
					if ex == obj {
						return false, false
					}
				}

				return true, false
			})

			ensureOrderedSubset(t, objs, expected)
		})
	}
}

// TestNearestNeighborsHalf asks for the 3 nearest of 6 objects and checks they
// are exactly the first 3 in minDist order. It then checks that over-asking
// does not return more objects than exist.
func TestNearestNeighborsHalf(t *testing.T) {
	rects := []Rect{
		mustRect(Point{1, 1}, []float64{1, 1}),
		mustRect(Point{-7, -7}, []float64{1, 1}),
		mustRect(Point{1, 3}, []float64{1, 1}),
		mustRect(Point{7, 7}, []float64{1, 1}),
		mustRect(Point{10, 2}, []float64{1, 1}),
		mustRect(Point{3, 3}, []float64{1, 1}),
	}
	things := []Spatial{}
	for i := range rects {
		things = append(things, &rects[i])
	}

	p := Point{0.5, 0.5}
	sort.Sort(byMinDist{things, p})

	for _, tc := range tests(2, 3, 3, things...) {
		t.Run(tc.name, func(t *testing.T) {
			rt := tc.build()

			objs := rt.NearestNeighbors(3, p)
			// Without this check an empty or short result would pass vacuously.
			if len(objs) != 3 {
				t.Fatalf("NearestNeighbors returned %d elements, expected 3", len(objs))
			}
			for i := range objs {
				if objs[i] != things[i] {
					t.Errorf("NearestNeighbors failed at index %d: %v != %v", i, objs[i], things[i])
				}
			}

			objs = rt.NearestNeighbors(len(things)+2, p)
			if len(objs) > len(things) {
				t.Errorf("NearestNeighbors failed: too many elements")
			}
		})
	}
}

// TestMinMaxDistFloatingPointRoundingError is a regression test, as its name
// indicates, for rounding error in the min/max distance computation. With
// these large x coordinates, the query point must still resolve to things[1].
func TestMinMaxDistFloatingPointRoundingError(t *testing.T) {
	rects := []Rect{
		Point{1134900, 15600}.ToRect(0),
		Point{1134900, 25600}.ToRect(0),
		Point{1134900, 22805}.ToRect(0),
		Point{1134900, 29116}.ToRect(0),
	}
	things := make([]Spatial, 0, len(rects))
	for i := range rects {
		things = append(things, &rects[i])
	}
	rt := NewTree(2, 1, 2, things...)
	n := rt.NearestNeighbor(Point{1134851.8, 25570.8})
	if n != things[1] {
		t.Fatalf("wrong neighbor, expected %v, got %v", things[1], n)
	}
}

// TestInsertThenDeleteAllInDifferentOrder inserts five rectangles, deletes all
// of them in each listed order, and checks the tree ends with size 0 and
// depth 1.
func TestInsertThenDeleteAllInDifferentOrder(t *testing.T) {
	rects := []Rect{
		mustRect(Point{1, 1}, []float64{1, 1}),
		mustRect(Point{2, 2}, []float64{1, 1}),
		mustRect(Point{3, 3}, []float64{1, 1}),
		mustRect(Point{4, 4}, []float64{1, 1}),
		mustRect(Point{5, 5}, []float64{1, 1}),
	}
	things := []Spatial{}
	for i := range rects {
		things = append(things, &rects[i])
	}

	deleteOrders := [][]int{
		{0, 1, 2, 3, 4},
		// in this case, the last delete will cause the issue: no thing but 2 levels depth.
		// {"size":0,"depth":2,"root":{"entries":[]}}
		{1, 2, 3, 4, 0},
	}
	for _, order := range deleteOrders {
		rt := NewTree(2, 2, 2)
		for _, thing := range things {
			rt.Insert(thing)
		}
		if rt.Size() != 5 {
			t.Errorf("Insert failed to insert")
		}

		for _, idx := range order {
			rt.Delete(things[idx])
		}
		if rt.Size() != 0 {
			t.Errorf("Delete failed to delete, got size: %d, expected size: 0", rt.Size())
		}
		if rt.Depth() != 1 {
			t.Errorf("Delete failed to delete, got depth: %d, expected depth: 1", rt.Depth())
		}
	}
}

// ensureOrderedSubset fails the test immediately unless actual is a prefix of
// expected, i.e. actual[i] == expected[i] for every index of actual.
func ensureOrderedSubset(t *testing.T, actual []Spatial, expected []Spatial) {
	t.Helper() // report failures at the calling test's line
	for i := range actual {
		if i >= len(expected) || actual[i] != expected[i] {
			t.Fatalf("actual is not an ordered subset of expected (mismatch at index %d)", i)
		}
	}
}

// ensureDisorderedSubset fails the test immediately if actual contains any
// object that is not in expected; order is ignored.
func ensureDisorderedSubset(t *testing.T, actual []Spatial, expected []Spatial) {
	t.Helper() // report failures at the calling test's line
	for _, obj := range actual {
		if !contains(obj, expected) {
			t.Fatalf("actual contained an object that was not expected: %+v", obj)
		}
	}
}

// contains reports whether slice holds obj, comparing interface values with ==.
func contains(obj Spatial, slice []Spatial) bool {
	for _, s := range slice {
		if s == obj {
			return true
		}
	}

	return false
}
--------------------------------------------------------------------------------