├── .ameba.yml ├── .editorconfig ├── .github ├── dependabot.yml └── workflows │ ├── crystal.yml │ └── documentation.yml ├── .gitignore ├── LICENSE ├── README.md ├── benchmark └── benchmark.cr ├── samples ├── geo_location.cr └── haversine_distance.cr ├── shard.yml ├── spec ├── kd_tree_spec.cr └── spec_helper.cr └── src ├── kd_tree.cr └── kd_tree └── version.cr /.ameba.yml: -------------------------------------------------------------------------------- 1 | Lint/NotNil: 2 | Excluded: 3 | - benchmark/benchmark.cr 4 | Enabled: true 5 | 6 | Naming/BlockParameterName: 7 | AllowNamesEndingInNumbers: true 8 | Enabled: false 9 | 10 | Lint/DebugCalls: 11 | Excluded: 12 | - samples/* 13 | Enabled: true 14 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*.cr] 4 | charset = utf-8 5 | end_of_line = lf 6 | insert_final_newline = true 7 | indent_style = space 8 | indent_size = 2 9 | trim_trailing_whitespace = true 10 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | 9 | # Maintain dependencies for GitHub Actions 10 | - package-ecosystem: "github-actions" 11 | directory: "/" 12 | schedule: 13 | interval: "weekly" 14 | -------------------------------------------------------------------------------- /.github/workflows/crystal.yml: -------------------------------------------------------------------------------- 1 | name: Crystal CI 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | schedule: 9 | - cron: '0 0 * * *' 10 | 11 | jobs: 12 | check_format: 13 | runs-on: ubuntu-latest 14 | container: 15 | image: crystallang/crystal 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Install dependencies 19 | run: shards install --ignore-crystal-version 20 | - name: Check format 21 | run: crystal tool format --check 22 | check_ameba: 23 | runs-on: ubuntu-latest 24 | container: 25 | image: crystallang/crystal 26 | steps: 27 | - uses: actions/checkout@v4 28 | - name: Install dependencies 29 | run: shards install --ignore-crystal-version 30 | - name: Check ameba 31 | run: ./bin/ameba 32 | test_latest: 33 | runs-on: ubuntu-latest 34 | container: 35 | image: crystallang/crystal 36 | steps: 37 | - uses: actions/checkout@v4 38 | - name: Install dependencies 39 | run: shards install --ignore-crystal-version 40 | - name: Run tests 41 | run: crystal spec 42 | test_nightly: 43 | runs-on: ubuntu-latest 44 | container: 45 | image: crystallang/crystal:nightly 46 | steps: 47 | - uses: actions/checkout@v4 48 | - name: Install dependencies 49 | run: shards install --ignore-crystal-version 50 | - name: Run tests 51 | run: crystal spec 52 | -------------------------------------------------------------------------------- /.github/workflows/documentation.yml: 
-------------------------------------------------------------------------------- 1 | name: website 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | 7 | jobs: 8 | publish: 9 | runs-on: ubuntu-latest 10 | container: 11 | image: crystallang/crystal 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Install dependencies 15 | run: shards install --ignore-crystal-version 16 | - name: Generate documentation 17 | run: crystal docs 18 | - 19 | name: Deploy to GitHub Pages 20 | if: success() 21 | uses: crazy-max/ghaction-github-pages@v4 22 | with: 23 | target_branch: gh-pages 24 | build_dir: docs 25 | env: 26 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /docs/ 2 | /lib/ 3 | /bin/ 4 | /.shards/ 5 | *.dwarf 6 | 7 | # Libraries don't need dependency lock 8 | # Dependencies will be locked in application that uses them 9 | /shard.lock 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2018-2024 Anton Maminov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Kd::Tree 2 | 3 | [![Crystal CI](https://github.com/geocrystal/kd_tree/actions/workflows/crystal.yml/badge.svg)](https://github.com/geocrystal/kd_tree/actions/workflows/crystal.yml) 4 | [![GitHub release](https://img.shields.io/github/release/geocrystal/kd_tree.svg)](https://github.com/geocrystal/kd_tree/releases) 5 | [![Docs](https://img.shields.io/badge/docs-available-brightgreen.svg)](https://geocrystal.github.io/kd_tree/) 6 | [![License](https://img.shields.io/github/license/geocrystal/kd_tree.svg)](https://github.com/geocrystal/kd_tree/blob/master/LICENSE) 7 | 8 | Crystal implementation of "K-Dimensional Tree" and "N-Nearest Neighbors" 9 | based on <https://en.wikipedia.org/wiki/K-d_tree>. 10 | 11 | ## Installation 12 | 13 | Add this to your application's `shard.yml`: 14 | 15 | ```yaml 16 | dependencies: 17 | kd_tree: 18 | github: geocrystal/kd_tree 19 | ``` 20 | 21 | ## Usage 22 | 23 | ```crystal 24 | require "kd_tree" 25 | ``` 26 | 27 | For example, construct a new tree where each point is represented as a two-element array in the form `[x, y]`, where `x` and `y` are numbers (such as `Int32`, `Float64`, etc.). 28 | 29 | ```crystal 30 | kd = Kd::Tree(Array(Int32)).new(points) 31 | ``` 32 | 33 | Find the nearest point to `[x, y]`. 
Returns an array with one point: 34 | 35 | ```crystal 36 | kd.nearest([x, y]) 37 | ``` 38 | 39 | Find the nearest `k` points to `[x, y]`. Returns an array of points: 40 | 41 | ```crystal 42 | kd.nearest([x, y], k) 43 | ``` 44 | 45 | ## Example 46 | 47 | ```crystal 48 | require "kd_tree" 49 | 50 | points = [ 51 | [2.0, 3.0], 52 | [5.0, 4.0], 53 | [4.0, 7.0], 54 | [7.0, 2.0], 55 | [8.0, 1.0], 56 | [9.0, 6.0], 57 | ] 58 | 59 | kd = Kd::Tree(Array(Float64)).new(points) 60 | 61 | kd.nearest([1.0, 1.0]) 62 | # => [[2.0, 3.0]] 63 | 64 | kd.nearest([1.0, 1.0], 2) 65 | # => [[2.0, 3.0], [5.0, 4.0]] 66 | ``` 67 | 68 | ### Complex objects 69 | 70 | `Kd::Tree(T)` can accept any object that responds to `#size` and `#[](i : Int)` methods. 71 | 72 | ```crystal 73 | class GeoLocation 74 | property name : String 75 | property longitude : Float64 76 | property latitude : Float64 77 | getter size = 2 # Assuming all GeoLocation objects are 2-dimensional 78 | 79 | def initialize(@name : String, @longitude : Float64, @latitude : Float64) 80 | end 81 | 82 | # Define an indexer to allow easy access by index for longitude and latitude 83 | def [](index : Int32) : Float64 84 | case index 85 | when 0 then @longitude 86 | when 1 then @latitude 87 | else raise "Index out of bounds" 88 | end 89 | end 90 | end 91 | 92 | # Create an array of GeoLocation points 93 | points = [ 94 | GeoLocation.new("New York", -73.935242, 40.730610), 95 | GeoLocation.new("Los Angeles", -118.243683, 34.052235), 96 | GeoLocation.new("London", -0.127647, 51.507322), 97 | GeoLocation.new("Tokyo", 139.691711, 35.689487), 98 | ] 99 | 100 | # Initialize the KD-tree with these points 101 | kd_tree = Kd::Tree(GeoLocation).new(points) 102 | 103 | # Find the nearest point to London 104 | target = GeoLocation.new("Near London", -0.125740, 51.508530) 105 | nearest_point = kd_tree.nearest(target, 1) 106 | puts "Nearest to London: #{nearest_point.first.name} (longitude #{nearest_point.first.longitude}, latitude 
#{nearest_point.first.latitude})" 107 | # Nearest to London: London (longitude -0.127647, latitude 51.507322) 108 | ``` 109 | 110 | ### Distance 111 | 112 | For distance calculations, the squared Euclidean distance is used. However, you can easily monkey-patch the `Kd::Tree#distance` method to implement another algorithm, such as the Haversine formula, to calculate distances between two points given their latitudes and longitudes. 113 | 114 | ```crystal 115 | require "haversine" 116 | 117 | module Kd 118 | class Tree(T) 119 | private def distance(m : T, n : T) 120 | # Calling `Haversine.distance` with 2 pairs of latitude/longitude coordinates. 121 | # Returns a distance in meters. 122 | Haversine.distance({m.latitude, m.longitude}, {n.latitude, n.longitude}).to_meters 123 | end 124 | end 125 | end 126 | 127 | points = [ 128 | GeoLocation.new("New York", -73.935242, 40.730610), 129 | GeoLocation.new("Los Angeles", -118.243683, 34.052235), 130 | GeoLocation.new("London", -0.127647, 51.507322), 131 | GeoLocation.new("Tokyo", 139.691711, 35.689487), 132 | ] 133 | 134 | kd_tree = Kd::Tree(GeoLocation).new(points) 135 | 136 | # Find the nearest point to London 137 | target = GeoLocation.new("Near London", -0.125740, 51.508530) 138 | nearest_point = kd_tree.nearest(target, 1) 139 | puts "Nearest to London: #{nearest_point.first.name} (longitude #{nearest_point.first.longitude}, latitude #{nearest_point.first.latitude})" 140 | # Nearest to London: London (longitude -0.127647, latitude 51.507322) 141 | ``` 142 | 143 | ## Performance 144 | 145 | Using a tree with 1 million points `[x, y] of Float64` on my Apple M1 Pro (10) @ 3.23 GHz: 146 | 147 | `crystal run benchmark/benchmark.cr --release` 148 | 149 | ```console 150 | Benchmarking KD-Tree with 1 million points 151 | user system total real 152 | build(init) 1.840140 0.021103 1.861243 ( 1.872732) 153 | nearest point 1 0.004484 0.000002 0.004486 ( 0.004490) 154 | nearest point 5 0.007391 0.000010 0.007401 ( 0.007479) 155 | 
nearest point 10 0.011406 0.000090 0.011496 ( 0.011679) 156 | nearest point 50 0.034097 0.000819 0.034916 ( 0.035175) 157 | nearest point 100 0.133828 0.003721 0.137549 ( 0.156548) 158 | nearest point 255 0.220200 0.000631 0.220831 ( 0.223081) 159 | nearest point 999 0.731941 0.000441 0.732382 ( 0.737236) 160 | ``` 161 | 162 | ## Contributing 163 | 164 | 1. Fork it (<https://github.com/geocrystal/kd_tree/fork>) 165 | 2. Create your feature branch (`git checkout -b my-new-feature`) 166 | 3. Commit your changes (`git commit -am 'Add some feature'`) 167 | 4. Push to the branch (`git push origin my-new-feature`) 168 | 5. Create a new Pull Request 169 | 170 | ## Contributors 171 | 172 | - [mamantoha](https://github.com/mamantoha) Anton Maminov - creator, maintainer 173 | -------------------------------------------------------------------------------- /benchmark/benchmark.cr: -------------------------------------------------------------------------------- 1 | require "benchmark" 2 | require "../src/kd_tree" 3 | 4 | # Generate 1 million random points 5 | points = Array.new(1_000_000) { [rand * 100.0, rand * 100.0] } 6 | 7 | puts "Benchmarking KD-Tree with 1 million points" 8 | 9 | Benchmark.bm do |x| 10 | tree = nil 11 | 12 | x.report("build(init)") { 13 | tree = Kd::Tree(Array(Float64)).new(points) 14 | } 15 | 16 | [1, 5, 10, 50, 100, 255, 999].each do |n| 17 | x.report("nearest point #{n.to_s.rjust(3, ' ')}") do 18 | 1000.times do 19 | test_point = [rand * 100.0, rand * 100.0] 20 | 21 | tree.not_nil!.nearest(test_point, n) 22 | end 23 | end 24 | end 25 | end 26 | -------------------------------------------------------------------------------- /samples/geo_location.cr: -------------------------------------------------------------------------------- 1 | require "../src/kd_tree" 2 | 3 | class GeoLocation 4 | property name : String 5 | property longitude : Float64 6 | property latitude : Float64 7 | getter size = 2 # Assuming all GeoLocation objects are 2-dimensional 8 | 9 | def initialize(@name : String, @longitude 
: Float64, @latitude : Float64) 10 | end 11 | 12 | # Define an indexer to allow easy access by index for longitude and latitude 13 | def [](index : Int32) : Float64 14 | case index 15 | when 0 then @longitude 16 | when 1 then @latitude 17 | else raise "Index out of bounds" 18 | end 19 | end 20 | end 21 | 22 | # Example Usage: 23 | # Create an array of GeoLocation points 24 | points = [ 25 | GeoLocation.new("New York", -73.935242, 40.730610), 26 | GeoLocation.new("Los Angeles", -118.243683, 34.052235), 27 | GeoLocation.new("London", -0.127647, 51.507322), 28 | GeoLocation.new("Tokyo", 139.691711, 35.689487), 29 | ] 30 | 31 | # Initialize the KD-tree with these points 32 | kd_tree = Kd::Tree(GeoLocation).new(points) 33 | 34 | # Find the nearest point to London 35 | target = GeoLocation.new("Near London", -0.125740, 51.508530) 36 | nearest_point = kd_tree.nearest(target, 1) 37 | puts "Nearest to London: #{nearest_point.first.name} (longitude #{nearest_point.first.longitude}, latitude #{nearest_point.first.latitude})" 38 | -------------------------------------------------------------------------------- /samples/haversine_distance.cr: -------------------------------------------------------------------------------- 1 | require "haversine" 2 | require "../src/kd_tree" 3 | 4 | class GeoLocation 5 | property name : String 6 | property longitude : Float64 7 | property latitude : Float64 8 | getter size = 2 # Assuming all GeoLocation objects are 2-dimensional 9 | 10 | def initialize(@name : String, @longitude : Float64, @latitude : Float64) 11 | end 12 | 13 | # Define an indexer to allow easy access by index for longitude and latitude 14 | def [](index : Int32) : Float64 15 | case index 16 | when 0 then @longitude 17 | when 1 then @latitude 18 | else raise "Index out of bounds" 19 | end 20 | end 21 | end 22 | 23 | module Kd 24 | class Tree(T) 25 | private def distance(m : T, n : T) 26 | # Calling `Haversine.distance` with 2 pairs of latitude/longitude coordinates. 
27 | # Returns a distance in meters. 28 | Haversine.distance({m.latitude, m.longitude}, {n.latitude, n.longitude}).to_meters 29 | end 30 | end 31 | end 32 | 33 | # Example Usage: 34 | # Create an array of GeoLocation points 35 | points = [ 36 | GeoLocation.new("New York", -73.935242, 40.730610), 37 | GeoLocation.new("Los Angeles", -118.243683, 34.052235), 38 | GeoLocation.new("London", -0.127647, 51.507322), 39 | GeoLocation.new("Paris", 2.349014, 48.864716), 40 | GeoLocation.new("Tokyo", 139.691711, 35.689487), 41 | ] 42 | 43 | # Initialize the KD-tree with these points 44 | kd_tree = Kd::Tree(GeoLocation).new(points) 45 | 46 | # Find the nearest point to London 47 | target = GeoLocation.new("Near London", -0.125740, 51.508530) 48 | nearest_point = kd_tree.nearest(target, 3) 49 | puts "First: #{nearest_point[0].name} (longitude #{nearest_point[0].longitude}, latitude #{nearest_point[0].latitude})" 50 | puts "Second: #{nearest_point[1].name} (longitude #{nearest_point[1].longitude}, latitude #{nearest_point[1].latitude})" 51 | puts "Third: #{nearest_point[2].name} (longitude #{nearest_point[2].longitude}, latitude #{nearest_point[2].latitude})" 52 | -------------------------------------------------------------------------------- /shard.yml: -------------------------------------------------------------------------------- 1 | name: kd_tree 2 | version: 0.6.0 3 | 4 | description: | 5 | Crystal implementation of "K-Dimensional Tree" and "N-Nearest Neighbors" 6 | 7 | authors: 8 | - Anton Maminov 9 | 10 | dependencies: 11 | priority-queue: 12 | github: spider-gazelle/priority-queue 13 | version: ">= 1.1.0" 14 | 15 | development_dependencies: 16 | ameba: 17 | github: crystal-ameba/ameba 18 | haversine: 19 | github: geocrystal/haversine 20 | 21 | crystal: ">= 1.0.0" 22 | 23 | license: MIT 24 | -------------------------------------------------------------------------------- /spec/kd_tree_spec.cr: 
-------------------------------------------------------------------------------- 1 | require "./spec_helper" 2 | 3 | class GeoLocation 4 | property name : String 5 | property longitude : Float64 6 | property latitude : Float64 7 | 8 | def initialize(@name : String, @longitude : Float64, @latitude : Float64) 9 | end 10 | 11 | # Define an indexer to allow easy access by index for longitude and latitude 12 | def [](index : Int32) : Float64 13 | case index 14 | when 0 then @longitude 15 | when 1 then @latitude 16 | else raise "Index out of bounds" 17 | end 18 | end 19 | 20 | # Assuming all GeoLocation objects are 2-dimensional 21 | def size 22 | 2 23 | end 24 | end 25 | 26 | describe Kd::Tree do 27 | describe "#initialize" do 28 | it "with Float64" do 29 | points = [[2.0, 3.0], [5.0, 4.0]] 30 | 31 | kd_tree = Kd::Tree(Array(Float64)).new(points) 32 | kd_tree.should be_a(Kd::Tree(Array(Float64))) 33 | end 34 | 35 | it "with Int32" do 36 | points = [[2, 3], [5, 4]] 37 | 38 | kd_tree = Kd::Tree(Array(Int32)).new(points) 39 | kd_tree.should be_a(Kd::Tree(Array(Int32))) 40 | end 41 | end 42 | 43 | describe "two-dimensional array" do 44 | describe "with Int32" do 45 | points = [ 46 | [2, 3], 47 | [5, 4], 48 | [4, 7], 49 | [7, 2], 50 | [8, 1], 51 | [9, 6], 52 | ] 53 | kd_tree = Kd::Tree(Array(Int32)).new(points) 54 | 55 | it "#nearest one" do 56 | res = kd_tree.nearest([1, 1]) 57 | res.should eq([[2, 3]]) 58 | end 59 | end 60 | 61 | describe "with negative" do 62 | points = [ 63 | [-1, -1], 64 | [0, 0], 65 | [5, 4], 66 | [4, 7], 67 | [7, 2], 68 | [8, 1], 69 | [9, 6], 70 | ] 71 | kd_tree = Kd::Tree(Array(Int32)).new(points) 72 | 73 | it "#nearest one" do 74 | res = kd_tree.nearest([-2, -2]) 75 | res.should eq([[-1, -1]]) 76 | end 77 | end 78 | 79 | points = [ 80 | [2.0, 3.0], 81 | [5.0, 4.0], 82 | [4.0, 7.0], 83 | [7.0, 2.0], 84 | [8.0, 1.0], 85 | [9.0, 6.0], 86 | ] 87 | kd_tree = Kd::Tree(Array(Float64)).new(points) 88 | 89 | it "have root" do 90 | kd_tree.root.should_not 
eq(nil) 91 | end 92 | 93 | it "#nearest one" do 94 | res = kd_tree.nearest([1.0, 1.0]) 95 | res.should eq([[2.0, 3.0]]) 96 | end 97 | 98 | it "#nearest many" do 99 | res = kd_tree.nearest([1.0, 1.0], 2) 100 | res.should eq([[2.0, 3.0], [5.0, 4.0]]) 101 | end 102 | 103 | it "#nearest too many" do 104 | res = kd_tree.nearest([1.0, 1.0], 100) 105 | res.size.should eq(points.size) 106 | end 107 | end 108 | 109 | describe "tree-dimensional array" do 110 | points = [ 111 | [2.0, 3.0, 0.0], 112 | [5.0, 4.0, 0.0], 113 | [4.0, 7.0, 0.0], 114 | [7.0, 2.0, 0.0], 115 | [8.0, 1.0, 0.0], 116 | [9.0, 6.0, 0.1], 117 | ] 118 | kd_tree = Kd::Tree(Array(Float64)).new(points) 119 | 120 | it "#nearest one" do 121 | res = kd_tree.nearest([1.0, 1.0, 0.0]) 122 | res.should eq([[2.0, 3.0, 0.0]]) 123 | end 124 | 125 | it "#nearest many" do 126 | res = kd_tree.nearest([1.0, 1.0, 0.0], 2) 127 | res.should eq([[2.0, 3.0, 0.0], [5.0, 4.0, 0.0]]) 128 | end 129 | end 130 | 131 | describe "#nearest" do 132 | # https://github.com/geocrystal/kd_tree/issues/2 133 | it "should equal naive implementation" do 134 | ndim = 2 135 | k = 3 136 | distance = ->(m : Array(Float64), n : Array(Float64)) do 137 | m.each_with_index.reduce(0) do |sum, (coord, index)| 138 | sum += (coord - n[index]) ** 2 139 | sum 140 | end 141 | end 142 | 143 | 10.times do 144 | points = Array.new(10) do 145 | Array.new(ndim) do 146 | rand(-10.0..10.0) 147 | end 148 | end 149 | kd_tree = Kd::Tree(Array(Float64)).new(points) 150 | target = Array.new(ndim) do 151 | rand(-11.0..11.0) 152 | end 153 | res = kd_tree.nearest(target, k) 154 | sorted = points.sort_by do |p| 155 | distance.call(p, target) 156 | end.reverse! 
157 | (res - sorted[-k..]).should eq [] of Float64 158 | end 159 | end 160 | end 161 | 162 | describe "complex object" do 163 | # Create an array of GeoLocation points 164 | points = [ 165 | GeoLocation.new("New York", -73.935242, 40.730610), 166 | GeoLocation.new("Los Angeles", -118.243683, 34.052235), 167 | GeoLocation.new("London", -0.127647, 51.507322), 168 | GeoLocation.new("Tokyo", 139.691711, 35.689487), 169 | ] 170 | 171 | kd_tree = Kd::Tree(GeoLocation).new(points) 172 | 173 | # Find the nearest point to London 174 | target = GeoLocation.new("Near London", -0.125740, 51.508530) 175 | nearest_point = kd_tree.nearest(target, 1) 176 | nearest_point.first.name.should eq("London") 177 | end 178 | end 179 | -------------------------------------------------------------------------------- /spec/spec_helper.cr: -------------------------------------------------------------------------------- 1 | require "spec" 2 | require "../src/kd_tree" 3 | -------------------------------------------------------------------------------- /src/kd_tree.cr: -------------------------------------------------------------------------------- 1 | require "priority-queue" 2 | require "./kd_tree/*" 3 | 4 | module Kd 5 | # A generic KD-tree implementation where `T` is the type of the points. 6 | class Tree(T) 7 | # Represents a node in the KD-tree. Each node stores a pivot point, 8 | # the axis it splits, and references to its left and right children. 9 | class Node(T) 10 | getter pivot : T, split : Int32, left : Node(T)?, right : Node(T)? 11 | 12 | def initialize(@pivot : T, @split : Int32, @left : self?, @right : self?) 13 | end 14 | end 15 | 16 | getter root : Node(T)? # The root node of the KD-tree 17 | @k : Int32 # Dimensionality of the points 18 | 19 | # Constructor for the KD-tree. Takes an array of points of type T and builds the tree. The array is copied first, so the caller's ordering is left untouched; an empty array yields an empty tree (nil root). 
20 | def initialize(points : Array(T)) 21 | @k = points.first?.try(&.size) || 0 # Assumes all points have the same dimension; 0 when there are no points 22 | @root = build_tree(points.dup, 0) # Build from a copy: build_tree sorts in place and must not reorder the caller's array 23 | end 24 | 25 | # Recursive method to build the KD-tree from a given list of points. 26 | private def build_tree(points : Array(T), depth : Int32) : Node(T)? 27 | return if points.empty? 28 | 29 | axis = depth % @k # Determine the axis to split on based on the current depth 30 | points.sort_by!(&.[axis]) # Sort points by the current axis 31 | median = points.size // 2 # Find the median index 32 | 33 | right_subtree = build_tree(points[median + 1..], depth + 1) 34 | left_subtree = build_tree(points[...median], depth + 1) 35 | 36 | # Create a new Node with the median point as pivot, and recursively build the left and right subtrees. 37 | Node(T).new( 38 | points[median], 39 | axis, 40 | left_subtree, 41 | right_subtree, 42 | ) 43 | end 44 | 45 | # Method to find the nearest 'n' points to a given target point. Returns an array of these points. 46 | def nearest(target : T, n : Int32 = 1) : Array(T) 47 | return [] of T if n < 1 48 | 49 | best_nodes = Priority::Queue(Node(T)).new # Initialize a priority queue to store the best nodes found 50 | 51 | find_n_nearest(@root, target, 0, best_nodes, n) # Recursively find the nearest nodes 52 | 53 | best_nodes.map(&.value.pivot) # Extract the pivot points from the nodes and return them 54 | end 55 | 56 | # Recursive method to find the nearest nodes to a target point. 57 | private def find_n_nearest( 58 | node : Node(T)?, 59 | target : T, 60 | depth : Int32, 61 | best_nodes : Priority::Queue(Node(T)), 62 | n : Int32 63 | ) 64 | return unless node 65 | 66 | axis = depth % @k # Determine the axis to compare based on depth 67 | 68 | # Determine which child node to search next, prioritizing the side closer to the target 69 | next_node = target[axis] < node.pivot[axis] ? node.left : node.right 70 | other_node = target[axis] < node.pivot[axis] ? 
node.right : node.left 71 | 72 | # Recursively search the more likely side first 73 | find_n_nearest(next_node, target, depth + 1, best_nodes, n) 74 | 75 | # Calculate the distance from the target to the current node's pivot and add to the queue 76 | best_nodes.push(distance(target, node.pivot), node) 77 | 78 | # Ensure that only the 'n' closest nodes are kept in the queue 79 | best_nodes.pop if best_nodes.size > n 80 | 81 | # Check if the other side might contain closer points and potentially search there too 82 | if other_node && (best_nodes.size < n || (target[axis] - node.pivot[axis]).abs ** 2 < distance(target, best_nodes.last.value.pivot)) 83 | find_n_nearest(other_node, target, depth + 1, best_nodes, n) 84 | end 85 | end 86 | 87 | # Calculate squared Euclidean distance between two points of type T. 88 | private def distance(m : T, n : T) : Float64 89 | @k.times.sum { |i| (m[i] - n[i]) ** 2 }.to_f 90 | end 91 | end 92 | end 93 | -------------------------------------------------------------------------------- /src/kd_tree/version.cr: -------------------------------------------------------------------------------- 1 | module Kd 2 | VERSION = {{ `shards version #{__DIR__}`.chomp.stringify }} 3 | end 4 | --------------------------------------------------------------------------------