├── README.md ├── LICENSE ├── src ├── kdtree.hpp ├── _kdtree_base.hpp ├── _tuple.hpp ├── _wtspace.hpp ├── _l2space.hpp ├── _so3rlspace.hpp ├── _spaces.hpp ├── _kdtree_median.hpp ├── _compoundspace.hpp ├── _kdtree_midpoint.hpp ├── _so3altspace.hpp └── _so3space.hpp └── test ├── Makefile ├── spaces_test.cpp ├── state_sampler.hpp ├── benchmark.cpp ├── test.hpp └── kdtree_test.cpp /README.md: -------------------------------------------------------------------------------- 1 | # kdtree 2 | Exact nearest neighbor searching in Euclidean, SO(3), and SE(3) spaces, and weighted combinations thereof. 3 | 4 | This is an implementation of: 5 | 6 | Jeffrey Ichnowski, Ron Alterovitz, “Fast Nearest Neighbor Search in SE(3) for Sampling-Based Motion Planning,” Proc. Algorithmic Foundations of Robotics (WAFR), August 2014 7 | 8 | # Warning: 9 | This project is no longer maintained. It has been migrated to: https://github.com/UNC-Robotics/nigh. 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, Jeff Ichnowski 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /src/kdtree.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2017 Jeffrey Ichnowski 2 | // All rights reserved. 3 | // 4 | // BSD 3 Clause 5 | // 6 | // Redistribution and use in source and binary forms, with or without 7 | // modification, are permitted provided that the following conditions 8 | // are met: 9 | // 1. Redistributions of source code must retain the above copyright 10 | // notice, this list of conditions and the following disclaimer. 11 | // 2. Redistributions in binary form must reproduce the above copyright 12 | // notice, this list of conditions and the following disclaimer in the 13 | // documentation and/or other materials provided with the distribution. 14 | // 3. Neither the name of the copyright holder nor the names of its 15 | // contributors may be used to endorse or promote products derived 16 | // from this software without specific prior written permission. 
17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 21 | // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 22 | // COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 23 | // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 26 | // HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 27 | // STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 29 | // OF THE POSSIBILITY OF SUCH DAMAGE. 30 | 31 | #pragma once 32 | #ifndef UNC_ROBOTICS_KDTREE_KDTREE_HPP 33 | #define UNC_ROBOTICS_KDTREE_KDTREE_HPP 34 | 35 | #include "_spaces.hpp" 36 | #include "_kdtree_base.hpp" 37 | #include "_kdtree_midpoint.hpp" 38 | #include "_kdtree_median.hpp" 39 | 40 | #endif // UNC_ROBOTICS_KDTREE_KDTREE_HPP 41 | -------------------------------------------------------------------------------- /test/Makefile: -------------------------------------------------------------------------------- 1 | TARGET = build 2 | SHELL = bash 3 | testsrc := $(wildcard *_test.cpp) 4 | testexe := $(patsubst %.cpp,$(TARGET)/%,$(testsrc)) 5 | 6 | CXXFLAGS += -std=c++14 -I../src -Wall -pedantic -Wno-ignored-attributes -O3 -g -stdlib=libc++ 7 | PKG_CONFIG ?= pkg-config 8 | 9 | # Check if pkg-config is available, if so, use it to check for dependencies 10 | ifneq ($(shell command -v $(PKG_CONFIG) 2>/dev/null),) 11 | 12 | ifeq ($(shell $(PKG_CONFIG) --exists eigen3 && echo 1),1) 13 | CXXFLAGS += `pkg-config --cflags eigen3` 14 | else 15 | MISSING += eigen3 16 | endif 17 | 18 | # else if pkg-config is not available, test for dependencies directly 19 | 
else # no pkg_config, manually test 20 | 21 | ifeq ($(shell echo $$'\#include \nint main(){}' | $(CXX) $(CXXFLAGS) -c -o/dev/null -xc++ - 2>/dev/null || echo 0),0) 22 | MISSING += eigen3 23 | endif 24 | 25 | endif 26 | #end of dependencies checking 27 | 28 | # check results of dependency checks 29 | ifeq ($(MISSING),) 30 | all: run_tests 31 | else 32 | all: 33 | @echo "MISSING DEPENDENCIES" 34 | @echo 35 | @echo "Your system is missing the following dependencies, please install the following" 36 | @echo "dependencies and make again" 37 | @echo 38 | @echo " $(MISSING)" 39 | endif 40 | 41 | 42 | # remove default rule 43 | %: %.cpp 44 | 45 | # Compile directly to executable (no intermediate .o files) 46 | $(TARGET)/%: %.cpp $(TARGET)/%.d 47 | @echo "Compiling" $@ 48 | @mkdir -p $(@D) 49 | @$(CXX) -MMD -MF $(patsubst %.cpp,$(TARGET)/%.dtmp,$<) $(CXXFLAGS) -o $@ $< 50 | @mv $(patsubst %.cpp,$(TARGET)/%.dtmp,$<) $(patsubst %.cpp,$(TARGET)/%.d,$<) 51 | @touch $@ 52 | 53 | # Run each test. If successful, touch the .success file to mark its success 54 | $(TARGET)/%.success: $(TARGET)/% 55 | @echo "RUNNING TEST " $(patsubst $(TARGET)/%.success,%,$@) 56 | @set -o pipefail ; $< | tee $<.log && touch $@ 57 | 58 | %.d: ; 59 | .PRECIOUS: %.d 60 | 61 | .PHONY: all run_tests compile_tests clean benchmark 62 | 63 | run_tests: $(patsubst %.cpp,$(TARGET)/%.success,$(testsrc)) 64 | 65 | compile_tests: $(testexe) 66 | 67 | benchmark: $(TARGET)/benchmark 68 | $(TARGET)/benchmark | gnuplot 69 | 70 | clean: 71 | $(RM) -r $(TARGET) 72 | 73 | # Include generated dependencies 74 | -include $(patsubst %.cpp,$(TARGET)/%.d,$(testsrc)) 75 | -include $(TARGET)/benchmark.d 76 | -------------------------------------------------------------------------------- /test/spaces_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../src/_spaces.hpp" 3 | #include "test.hpp" 4 | 5 | TEST_CASE(L2Distance) { 6 | using namespace 
unc::robotics::kdtree; 7 | 8 | typedef L2Space Space; 9 | typedef Space::State State; 10 | 11 | Space space; 12 | State a(1.2, -3.1); 13 | State b(5.1, 6.7); 14 | 15 | EXPECT(space.distance(a, b)) == std::sqrt( 16 | std::pow(5.1 - 1.2, 2) + 17 | std::pow(6.7 + 3.1, 2)); 18 | } 19 | 20 | TEST_CASE(SO3Distance) { 21 | using namespace unc::robotics::kdtree; 22 | 23 | typedef SO3Space Space; 24 | typedef Space::State State; 25 | 26 | Space space; 27 | 28 | State a(1, 0, 0, 0); 29 | State b(0, 1, 0, 0); 30 | 31 | EXPECT(space.distance(a, b)) == M_PI_2; 32 | 33 | State c(std::sin(M_PI/6), std::cos(M_PI/6), 0, 0); 34 | 35 | EXPECT(std::abs(space.distance(a, c) - M_PI/3)) < 1e-13; 36 | EXPECT(std::abs(space.distance(b, c) - M_PI/6)) < 1e-13; 37 | } 38 | 39 | TEST_CASE(RatioWeightedDistance) { 40 | using namespace unc::robotics::kdtree; 41 | 42 | typedef RatioWeightedSpace, std::ratio<17, 3>> Space; 43 | typedef Space::State State; 44 | 45 | Space space; 46 | State a(1.2, -3.1); 47 | State b(5.1, 6.7); 48 | 49 | EXPECT(space.distance(a, b)) == std::sqrt( 50 | std::pow(5.1 - 1.2, 2) + 51 | std::pow(6.7 + 3.1, 2)) * 17 / 3; 52 | } 53 | 54 | TEST_CASE(SE3Distance) { 55 | using namespace unc::robotics::kdtree; 56 | 57 | typedef SE3Space Space; 58 | typedef Space::State State; 59 | 60 | Space space; 61 | // ( 62 | // (RatioWeightedSpace>(SO3Space())), 63 | // (RatioWeightedSpace>(L2Space()))); 64 | 65 | // State a(SO3Space::State(1, 0, 0, 0), 66 | // L2Space::State(-1.2, 3.4, 5.6)); 67 | // State b(SO3Space::State(0, 1, 0, 0), 68 | // L2Space::State(9.8, -7.6, 5.4)); 69 | 70 | State a({1, 0, 0, 0}, {-1.2, 3.4, 5.6}); 71 | State b({0, 1, 0, 0}, {9.8, -7.6, 5.4}); 72 | 73 | EXPECT(std::abs(space.distance(a, b) - (std::sqrt( 74 | std::pow(9.8 + 1.2, 2) + 75 | std::pow(3.4 + 7.6, 2) + 76 | std::pow(5.6 - 5.4, 2))*3 + M_PI_2*5))) < 1e-9 ; 77 | } 78 | 79 | TEST_CASE(MixedScalarCompound) { 80 | using namespace unc::robotics::kdtree; 81 | 82 | typedef L2Space FloatSpace; 83 | typedef 
L2Space DoubleSpace; 84 | 85 | EXPECT((std::is_same::Distance>::value)) == true; 86 | EXPECT((std::is_same::Distance>::value)) == true; 87 | EXPECT((std::is_same::Distance>::value)) == true; 88 | } 89 | -------------------------------------------------------------------------------- /test/state_sampler.hpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | template 4 | struct StateSampler {}; 5 | 6 | template 7 | struct StateSampler> { 8 | typedef unc::robotics::kdtree::BoundedL2Space<_Scalar, _dimensions> Space; 9 | typedef typename Space::State State; 10 | 11 | template 12 | static State randomState(_RNG& rng, const Space& space) { 13 | State q; 14 | for (int i=0 ; i dist(space.bounds(i, 0), space.bounds(i, 1)); 16 | q[i] = dist(rng); 17 | } 18 | return q; 19 | } 20 | }; 21 | 22 | template 23 | struct StateSampler> { 24 | typedef unc::robotics::kdtree::L2Space<_Scalar, _dimensions> Space; 25 | typedef typename Space::State State; 26 | 27 | template 28 | static State randomState(_RNG& rng, const Space& space) { 29 | State q; 30 | std::uniform_real_distribution<_Scalar> dist(-50, 50); 31 | for (int i=0 ; i 39 | struct StateSampler> { 40 | typedef unc::robotics::kdtree::SO3Space<_Scalar> Space; 41 | template 42 | static typename Space::State 43 | randomState(_RNG& rng, const _Space&) { 44 | typename Space::State q; 45 | std::uniform_real_distribution<_Scalar> dist01(0, 1); 46 | std::uniform_real_distribution<_Scalar> dist2pi(0, 2*M_PI); 47 | _Scalar a = dist01(rng); 48 | _Scalar b = dist2pi(rng); 49 | _Scalar c = dist2pi(rng); 50 | 51 | return Eigen::Quaternion<_Scalar>( 52 | std::sqrt(1-a)*std::sin(b), 53 | std::sqrt(1-a)*std::cos(b), 54 | std::sqrt(a)*std::sin(c), 55 | std::sqrt(a)*std::cos(c)); 56 | } 57 | }; 58 | 59 | template 60 | struct StateSampler> 61 | : StateSampler> 62 | {}; 63 | 64 | template 65 | struct StateSampler> 66 | : StateSampler> 67 | {}; 68 | 69 | template 70 | struct StateSampler> { 71 | 
typedef unc::robotics::kdtree::RatioWeightedSpace<_Space, _Ratio> Space; 72 | 73 | template 74 | static typename Space::State 75 | randomState(_RNG& rng, const Space& space) { 76 | return StateSampler<_Space>::randomState(rng, space); 77 | } 78 | }; 79 | 80 | template 81 | struct StateSampler> { 82 | typedef unc::robotics::kdtree::WeightedSpace<_Space> Space; 83 | 84 | template 85 | static typename Space::State 86 | randomState(_RNG& rng, const Space& space) { 87 | return StateSampler<_Space>::randomState(rng, space); 88 | } 89 | }; 90 | 91 | template 92 | struct StateSampler> { 93 | typedef unc::robotics::kdtree::CompoundSpace<_Spaces...> Space; 94 | 95 | template 96 | static typename Space::State 97 | compoundRandomState(_RNG& rng, const Space& space, std::index_sequence) 98 | { 99 | return typename Space::State( 100 | StateSampler>::type> 101 | ::randomState(rng, std::get(space))...); 102 | } 103 | 104 | template 105 | static typename Space::State 106 | randomState(_RNG& rng, const Space& space) { 107 | return compoundRandomState(rng, space, std::make_index_sequence{}); 108 | } 109 | }; 110 | -------------------------------------------------------------------------------- /src/_kdtree_base.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2017 Jeffrey Ichnowski 2 | // All rights reserved. 3 | // 4 | // BSD 3 Clause 5 | // 6 | // Redistribution and use in source and binary forms, with or without 7 | // modification, are permitted provided that the following conditions 8 | // are met: 9 | // 1. Redistributions of source code must retain the above copyright 10 | // notice, this list of conditions and the following disclaimer. 11 | // 2. Redistributions in binary form must reproduce the above copyright 12 | // notice, this list of conditions and the following disclaimer in the 13 | // documentation and/or other materials provided with the distribution. 14 | // 3. 
Neither the name of the copyright holder nor the names of its 15 | // contributors may be used to endorse or promote products derived 16 | // from this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 21 | // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 22 | // COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 23 | // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 26 | // HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 27 | // STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 29 | // OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | #pragma once 31 | #ifndef UNC_ROBOTICS_KDTREE_KDTREE_BASE_HPP 32 | #define UNC_ROBOTICS_KDTREE_KDTREE_BASE_HPP 33 | 34 | #include "_spaces.hpp" 35 | 36 | namespace unc { namespace robotics { namespace kdtree { 37 | 38 | struct MidpointSplit {}; 39 | struct MedianSplit {}; 40 | 41 | struct DynamicBuild {}; 42 | struct StaticBuild {}; 43 | 44 | struct SingleThread {}; 45 | struct MultiThread {}; 46 | 47 | template < 48 | typename _T, 49 | typename _Space, 50 | typename _GetKey, 51 | typename _SplitStrategy, 52 | typename _Construction = DynamicBuild, 53 | typename _Locking = SingleThread, 54 | typename _Allocator = std::allocator<_T>> 55 | struct KDTree; 56 | 57 | namespace detail { 58 | template 59 | struct MidpointAddTraversal; 60 | template 61 | struct MidpointNearestTraversal; 62 | template 63 | struct MedianAccum; 64 | template 65 | struct MedianNearestTraversal; 66 | 67 | struct CompareSecond { 68 | template 69 | constexpr bool operator() (const std::pair<_First,_Second>& a, const std::pair<_First,_Second>& b) const { 70 | return a.second < b.second; 71 | } 72 | }; 73 | 74 | // helper to enable builtin wrapper for long types. The problem this 75 | // resolves is that `int` and `long` may be the same type on some 76 | // systems and not on others. 77 | template 78 | struct enable_builtin_long { 79 | static constexpr bool value = std::is_integral::value 80 | && sizeof(T)==sizeof(long) && sizeof(T) != sizeof(int); 81 | }; 82 | 83 | // helper to enable builtin wrapper for long long types. The problem 84 | // this resolves is that `long` and `long long` may be the same type 85 | // on some systems and not on others. 86 | template 87 | struct enable_builtin_long_long { 88 | static constexpr bool value = std::is_integral::value 89 | && sizeof(T)==sizeof(long long) && sizeof(T) != sizeof(long); 90 | }; 91 | 92 | // clz returns the number of leading 0-bits in argument, starting with 93 | // the most significant bit. If x is 0, the result is undefined. 
The 94 | // builtins make use of processor instructions, and are defined for 95 | // unsigned, unsigned long, and unsigned long long types. 96 | constexpr int clz(unsigned x) { return __builtin_clz(x); } 97 | 98 | template 99 | constexpr typename std::enable_if::value, int>::type 100 | clz(T x) { return __builtin_clzl(x); } 101 | 102 | template 103 | constexpr typename std::enable_if::value, int>::type 104 | clz(T x) { return __builtin_clzll(x); } 105 | 106 | template 107 | constexpr int log2(UInt x) { return sizeof(x)*8 - 1 - clz(x); } 108 | 109 | 110 | template 111 | struct AllocatorDestructor { 112 | typedef std::allocator_traits<_Allocator> AllocatorTraits; 113 | typedef typename AllocatorTraits::pointer pointer; 114 | typedef typename AllocatorTraits::size_type size_type; 115 | 116 | _Allocator& allocator_; 117 | size_type count_; 118 | 119 | inline AllocatorDestructor(_Allocator& allocator, size_type count) 120 | : allocator_(allocator), 121 | count_(count) 122 | { 123 | } 124 | 125 | inline void operator() (pointer p) { 126 | AllocatorTraits::deallocate(allocator_, p, count_); 127 | } 128 | }; 129 | 130 | }}}} 131 | 132 | #include "_l2space.hpp" 133 | #include "_so3space.hpp" 134 | #include "_so3altspace.hpp" 135 | #include "_so3rlspace.hpp" 136 | #include "_wtspace.hpp" 137 | #include "_compoundspace.hpp" 138 | 139 | #endif // UNC_ROBOTICS_KDTREE_KDTREE_BASE_HPP 140 | -------------------------------------------------------------------------------- /src/_tuple.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2017 Jeffrey Ichnowski 2 | // All rights reserved. 3 | // 4 | // BSD 3 Clause 5 | // 6 | // Redistribution and use in source and binary forms, with or without 7 | // modification, are permitted provided that the following conditions 8 | // are met: 9 | // 1. Redistributions of source code must retain the above copyright 10 | // notice, this list of conditions and the following disclaimer. 
11 | // 2. Redistributions in binary form must reproduce the above copyright 12 | // notice, this list of conditions and the following disclaimer in the 13 | // documentation and/or other materials provided with the distribution. 14 | // 3. Neither the name of the copyright holder nor the names of its 15 | // contributors may be used to endorse or promote products derived 16 | // from this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 21 | // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 22 | // COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 23 | // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 26 | // HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 27 | // STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 29 | // OF THE POSSIBILITY OF SUCH DAMAGE. 30 | #pragma once 31 | #ifndef UNC_ROBOTICS_KDTREE_TUPLE_HPP 32 | #define UNC_ROBOTICS_KDTREE_TUPLE_HPP 33 | 34 | // helper methods for dealing with tuples used by CompoundSpace. 35 | 36 | namespace unc { 37 | namespace robotics { 38 | namespace kdtree { 39 | namespace detail { 40 | 41 | // computes the result type of a compound sum 42 | // e.g. 
float + double => double 43 | template 44 | struct SumResultType { typedef _T type; }; 45 | template 46 | struct SumResultType<_T, _U, _Rest...> { typedef typename SumResultType::type type; }; 47 | 48 | 49 | template 50 | constexpr decltype(auto) reduceArgs(_Fn&& fn, _T a) { return a; } 51 | template 52 | constexpr decltype(auto) reduceArgs(_Fn&& fn, _First&& a, _Second&& b, _Rest&& ... args) { 53 | return reduceArgs(std::forward<_Fn>(fn), std::forward<_Fn>(fn)(a, b), args...); 54 | } 55 | template 56 | constexpr decltype(auto) reduce_impl(_Fn&& fn, _Tuple&& args, std::index_sequence) { 57 | return reduceArgs(std::forward<_Fn>(fn), std::get(std::forward<_Tuple>(args))...); 58 | } 59 | template 60 | constexpr decltype(auto) reduce(_Fn&& fn, _Tuple&& args) { 61 | return reduce_impl( 62 | std::forward<_Fn>(fn), 63 | std::forward<_Tuple>(args), 64 | std::make_index_sequence::type>::value>{}); 65 | } 66 | template 67 | constexpr decltype(auto) sum(_First&& a, _Rest&& ... args) { 68 | return reduce(std::plus::type>(), a, args...); 69 | } 70 | template 71 | constexpr decltype(auto) sum(std::tuple<_T...>&& args) { 72 | return reduce(std::plus::type>(), args); 73 | } 74 | template 75 | constexpr decltype(auto) map_impl(_Fn&& fn, _Tuple&& args, std::index_sequence) { 76 | return std::make_tuple(std::forward<_Fn>(fn)(std::get(std::forward<_Tuple>(args)))...); 77 | } 78 | template 79 | constexpr decltype(auto) map(_Fn&& fn, const std::tuple<_T...>& args) { 80 | return map_impl(std::forward<_Fn>(fn), args, std::make_index_sequence{}); 81 | } 82 | template 83 | constexpr decltype(auto) apply_impl(_Fn&& fn, _Tuple&& args, std::index_sequence) { 84 | return std::forward<_Fn>(fn)(std::get(std::forward<_Tuple>(args))...); 85 | } 86 | template 87 | constexpr decltype(auto) apply(_Fn&& fn, _Tuple&& t) { 88 | return apply_impl( 89 | std::forward<_Fn>(fn), 90 | std::forward<_Tuple>(t), 91 | std::make_index_sequence::value>{}); 92 | } 93 | template 94 | constexpr decltype(auto) 
slice_impl(_Tuple&& tuple, std::index_sequence) { 95 | return std::make_tuple(std::get(std::get(std::forward<_Tuple>(tuple)))...); 96 | } 97 | 98 | // Returns a tuple containing the result of std::get (at one fixed index) applied to each element of the argument tuple 99 | template 100 | constexpr decltype(auto) slice(_Tuple&& tuple) { 101 | return slice_impl( 102 | std::forward<_Tuple>(tuple), 103 | std::make_index_sequence::value>{}); 104 | } 105 | template 106 | constexpr decltype(auto) zip_impl(_Fn&& fn, _Args&& args, std::index_sequence) { 107 | return std::make_tuple( 108 | apply(std::forward<_Fn>(fn), slice(std::forward<_Args>(args)))...); 109 | } 110 | template 111 | constexpr decltype(auto) zip(_Fn&& fn, _First&& first, _Rest&& ... rest) { 112 | return zip_impl( 113 | fn, 114 | std::make_tuple(std::forward<_First>(first), std::forward<_Rest>(rest)...), 115 | std::make_index_sequence::type>::value>{}); 116 | } 117 | 118 | 119 | } 120 | } 121 | } 122 | } 123 | 124 | 125 | #endif // UNC_ROBOTICS_KDTREE_TUPLE_HPP 126 | -------------------------------------------------------------------------------- /test/benchmark.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "../src/kdtree.hpp" 5 | #include "state_sampler.hpp" 6 | 7 | template 8 | struct Node { 9 | _State key_; 10 | 11 | Node(const _State& key) 12 | : key_(key) 13 | { 14 | } 15 | }; 16 | 17 | struct GetKey { 18 | template 19 | constexpr const _State& operator() (const Node<_State>& state) const { 20 | return state.key_; 21 | } 22 | }; 23 | 24 | template 25 | void printStats(std::size_t size, Duration elapsed, unsigned queryCount, double overhead) { 26 | auto elapsedNanos = std::chrono::duration(elapsed).count() 27 | - static_cast(queryCount * overhead); 28 | 29 | std::cout << size << "\t" 30 | << elapsedNanos << "\t" 31 | << queryCount << "\t" 32 | << elapsedNanos / (queryCount * 1e3) << std::endl; 33 | } 34 | 35 | 36 | template 37 | void benchmark( 38 | const 
std::string& name, 39 | const _Space& space, 40 | std::size_t N, std::size_t Q, std::size_t k, 41 | _StepDuration stepDuration, 42 | double stepsPerExp, 43 | double overhead, 44 | const _SplitStrategy&, 45 | const _Locking&) 46 | { 47 | using namespace unc::robotics::kdtree; 48 | 49 | typedef typename _Space::Distance Distance; 50 | typedef typename _Space::State Key; 51 | 52 | constexpr std::size_t nTrees = 16; 53 | 54 | std::vector, _Space, GetKey, _SplitStrategy, DynamicBuild, _Locking>> trees; 55 | trees.reserve(nTrees); 56 | for (std::size_t i=0 ; i> nodes; 61 | nodes.reserve(N); 62 | 63 | typedef std::pair, Distance> NodeDist; 64 | std::vector> nearest; 65 | nearest.reserve(k+1); 66 | 67 | std::vector queries; 68 | queries.reserve(Q); 69 | for (std::size_t j=0 ; j::randomState(rng, space)); 71 | 72 | typedef std::chrono::high_resolution_clock Clock; 73 | 74 | Clock::duration elapsed{}; 75 | unsigned queryCount = 0; 76 | unsigned treeNo = 0; 77 | 78 | std::size_t nextStat = 100; 79 | unsigned statCounter = static_cast(std::log(nextStat) / std::log(10) * stepsPerExp); 80 | Clock::duration timePerSize = Clock::duration(stepDuration) / (nextStat - trees[0].size()); 81 | 82 | for (std::size_t i=1 ; i<=N ; ++i) { 83 | nodes.emplace_back(StateSampler<_Space>::randomState(rng, space)); 84 | for (std::size_t j=0 ; j= nextStat) { 96 | printStats(trees[0].size(), elapsed, queryCount, overhead); 97 | elapsed = Clock::duration::zero(); 98 | queryCount = 0; 99 | do { 100 | nextStat = static_cast(std::pow(10, ++statCounter/stepsPerExp) + 0.5); 101 | } while (nextStat <= i); 102 | timePerSize = Clock::duration(stepDuration) / (nextStat - trees[0].size()); 103 | } 104 | } 105 | 106 | if (queryCount > 0) 107 | printStats(trees[0].size(), elapsed, queryCount, overhead); 108 | } 109 | 110 | template 111 | double measureNow(_Duration checkTime) { 112 | typedef std::chrono::high_resolution_clock Clock; 113 | Clock::duration clockCheckTime(checkTime); 114 | Clock::duration elapsed; 
115 | unsigned count = 0; 116 | Clock::time_point start = Clock::now(); 117 | do { 118 | ++count; 119 | } while ((elapsed = Clock::now() - start) < clockCheckTime); 120 | 121 | double elapsedNanos = std::chrono::duration(elapsed).count(); 122 | 123 | std::cout << "# Clock overhead: " << elapsedNanos / 1e6 << " ms over " 124 | << count << " calls = " << elapsedNanos / count << " ns/call" 125 | << std::endl; 126 | 127 | return elapsedNanos / count; 128 | } 129 | 130 | int main(int argc, char *argv[]) { 131 | using namespace unc::robotics::kdtree; 132 | using namespace std::literals::chrono_literals; 133 | 134 | std::size_t N = 100000; 135 | std::size_t Q = 10000; 136 | std::size_t k = 20; 137 | 138 | auto stepTime = 25ms; 139 | double overhead = measureNow(1s); 140 | double steps = 250; 141 | 142 | std::time_t tm = std::time(nullptr); 143 | std::cout << 144 | "set title '" << std::put_time(std::localtime(&tm), "%c") << "'\n" 145 | "set logscale x\n" 146 | "set key top left\n" 147 | // "plot '-' u 1:4 w lines title 'L2(3) Midpoint', " 148 | // " '-' u 1:4 w lines title 'L2(3) Median'\n"; 149 | "plot '-' u 1:4 w lines title 'SO(3) Midpoint', " 150 | " '-' u 1:4 w lines title 'SO(3) Median'\n"; 151 | 152 | Eigen::Array bounds; 153 | bounds.col(0) = -1.0; 154 | bounds.col(1) = 1.0; 155 | 156 | // benchmark("SO(3)", BoundedL2Space(bounds), N, Q, k, stepTime, steps, overhead, MidpointSplit{}, SingleThread{}); 157 | // std::cout << "e" << std::endl; 158 | // benchmark("SO(3)", BoundedL2Space(bounds), N, Q, k, stepTime, steps, overhead, MedianSplit{}, SingleThread{}); 159 | // std::cout << "e" << std::endl; 160 | 161 | benchmark("SO(3)", SO3Space(), N, Q, k, stepTime, steps, overhead, MidpointSplit{}, SingleThread{}); 162 | std::cout << "e" << std::endl; 163 | benchmark("SO(3)", SO3Space(), N, Q, k, stepTime, steps, overhead, MedianSplit{}, SingleThread{}); 164 | std::cout << "e" << std::endl; 165 | 166 | return 0; 167 | } 168 | 169 | 170 | 171 | 172 | 
-------------------------------------------------------------------------------- /src/_wtspace.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2017 Jeffrey Ichnowski 2 | // All rights reserved. 3 | // 4 | // BSD 3 Clause 5 | // 6 | // Redistribution and use in source and binary forms, with or without 7 | // modification, are permitted provided that the following conditions 8 | // are met: 9 | // 1. Redistributions of source code must retain the above copyright 10 | // notice, this list of conditions and the following disclaimer. 11 | // 2. Redistributions in binary form must reproduce the above copyright 12 | // notice, this list of conditions and the following disclaimer in the 13 | // documentation and/or other materials provided with the distribution. 14 | // 3. Neither the name of the copyright holder nor the names of its 15 | // contributors may be used to endorse or promote products derived 16 | // from this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 21 | // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 22 | // COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 23 | // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 26 | // HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 27 | // STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 29 | // OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | #pragma once 31 | #ifndef UNC_ROBOTICS_KDTREE_WTSPACE_HPP 32 | #define UNC_ROBOTICS_KDTREE_WTSPACE_HPP 33 | 34 | namespace unc { namespace robotics { namespace kdtree { namespace detail { 35 | 36 | template 37 | struct MidpointAddTraversal<_Node, RatioWeightedSpace<_Space, _Ratio>> 38 | : MidpointAddTraversal<_Node, _Space> 39 | { 40 | typedef RatioWeightedSpace<_Space, _Ratio> Space; 41 | typedef typename Space::State State; 42 | typedef typename Space::Distance Distance; 43 | 44 | // inherit constructor 45 | using MidpointAddTraversal<_Node, _Space>::MidpointAddTraversal; 46 | 47 | template 48 | constexpr Distance keyDistance(const _State& q) const { 49 | return MidpointAddTraversal<_Node, _Space>::keyDistance(q) * _Ratio::num / _Ratio::den; 50 | } 51 | 52 | constexpr Distance maxAxis(unsigned *axis) const { 53 | return MidpointAddTraversal<_Node, _Space>::maxAxis(axis) * _Ratio::num / _Ratio::den; 54 | } 55 | }; 56 | 57 | template 58 | struct MidpointAddTraversal<_Node, WeightedSpace<_Space>> 59 | : MidpointAddTraversal<_Node, _Space> 60 | { 61 | typedef WeightedSpace<_Space> Space; 62 | typedef typename Space::State State; 63 | typedef typename Space::Distance Distance; 64 | 65 | Distance weight_; 66 | 67 | MidpointAddTraversal(const Space& space, const State& key) 68 | : MidpointAddTraversal<_Node, _Space>(space, key), 69 | weight_(space.weight()) 70 | { 71 | } 72 | 73 | template 74 | constexpr Distance keyDistance(const _State& q) const { 75 | return MidpointAddTraversal<_Node, _Space>::keyDistance(q) * weight_; 76 | } 77 | 78 | constexpr Distance maxAxis(unsigned *axis) const { 79 | return MidpointAddTraversal<_Node, _Space>::maxAxis(axis) * weight_; 80 | } 81 | }; 82 | 83 | template 84 | struct MidpointNearestTraversal<_Node, RatioWeightedSpace<_Space, _Ratio>> 85 | : MidpointNearestTraversal<_Node, _Space> 86 | { 87 | typedef RatioWeightedSpace<_Space, _Ratio> Space; 88 | typedef typename Space::State State; 89 | typedef typename Space::Distance 
Distance; 90 | 91 | // inherit constructor 92 | using MidpointNearestTraversal<_Node, _Space>::MidpointNearestTraversal; 93 | 94 | // TODO: keyDistance and maxAxis implementations are duplicated 95 | // with MidpointAddTraversal. Would be nice to merge them. 96 | template 97 | constexpr Distance keyDistance(const _State& q) const { 98 | return MidpointNearestTraversal<_Node, _Space>::keyDistance(q) * _Ratio::num / _Ratio::den; 99 | } 100 | 101 | constexpr Distance maxAxis(unsigned *axis) const { 102 | return MidpointNearestTraversal<_Node, _Space>::maxAxis(axis) * _Ratio::num / _Ratio::den; 103 | } 104 | 105 | constexpr Distance distToRegion() const { 106 | return MidpointNearestTraversal<_Node, _Space>::distToRegion() * _Ratio::num / _Ratio::den; 107 | } 108 | 109 | // template 110 | // inline void traverse(_Nearest& nearest, const _Node* n, unsigned axis) { 111 | // MidpointNearestTraversal<_Node, _Space>::traverse(nearest, n, axis); 112 | // } 113 | }; 114 | 115 | template 116 | struct MidpointNearestTraversal<_Node, WeightedSpace<_Space>> 117 | : MidpointNearestTraversal<_Node, _Space> 118 | { 119 | typedef WeightedSpace<_Space> Space; 120 | typedef typename Space::State State; 121 | typedef typename Space::Distance Distance; 122 | 123 | Distance weight_; 124 | 125 | MidpointNearestTraversal(const Space& space, const State& key) 126 | : MidpointNearestTraversal<_Node, _Space>(space, key), 127 | weight_(space.weight()) 128 | { 129 | } 130 | 131 | // TODO: keyDistance and maxAxis implementations are duplicated 132 | // with MidpointAddTraversal. Would be nice to merge them. 
133 | template 134 | constexpr Distance keyDistance(const _State& q) const { 135 | return MidpointNearestTraversal<_Node, _Space>::keyDistance(q) * weight_; 136 | } 137 | 138 | constexpr Distance maxAxis(unsigned *axis) const { 139 | return MidpointNearestTraversal<_Node, _Space>::maxAxis(axis) * weight_; 140 | } 141 | 142 | constexpr Distance distToRegion() const { 143 | return MidpointNearestTraversal<_Node, _Space>::distToRegion() * weight_; 144 | } 145 | 146 | // template 147 | // void traverse(_Nearest& nearest, const _Node* n, unsigned axis) { 148 | // MidpointNearestTraversal<_Node, _Space>::traverse(nearest, n, axis); 149 | // } 150 | }; 151 | 152 | template 153 | struct MedianAccum> 154 | : MedianAccum<_Space> 155 | { 156 | typedef MedianAccum<_Space> Base; 157 | typedef typename _Space::Distance Distance; 158 | 159 | using Base::Base; 160 | 161 | constexpr Distance maxAxis(unsigned *axis) const { 162 | return Base::maxAxis(axis) * _Ratio::num / _Ratio::den; 163 | } 164 | }; 165 | 166 | template 167 | struct MedianNearestTraversal> 168 | : MedianNearestTraversal<_Space> 169 | { 170 | typedef MedianNearestTraversal<_Space> Base; 171 | typedef typename _Space::State Key; 172 | typedef typename _Space::Distance Distance; 173 | 174 | using Base::Base; 175 | 176 | constexpr Distance distToRegion() const { 177 | return Base::distToRegion() * _Ratio::num / _Ratio::den; 178 | } 179 | 180 | template 181 | constexpr Distance keyDistance(const _Key& key) const { 182 | return Base::keyDistance(key) * _Ratio::num / _Ratio::den; 183 | } 184 | }; 185 | 186 | template 187 | struct MedianAccum> 188 | : MedianAccum<_Space> 189 | { 190 | typedef MedianAccum<_Space> Base; 191 | typedef typename _Space::Distance Distance; 192 | 193 | Distance weight_; 194 | 195 | MedianAccum(const WeightedSpace<_Space>& space) 196 | : Base(space), 197 | weight_(space.weight()) 198 | { 199 | } 200 | 201 | constexpr Distance maxAxis(unsigned *axis) const { 202 | return Base::maxAxis(axis) * 
weight_; 203 | } 204 | }; 205 | 206 | template 207 | struct MedianNearestTraversal> 208 | : MedianNearestTraversal<_Space> 209 | { 210 | typedef MedianNearestTraversal<_Space> Base; 211 | typedef typename _Space::State Key; 212 | typedef typename _Space::Distance Distance; 213 | 214 | Distance weight_; 215 | 216 | MedianNearestTraversal( 217 | const WeightedSpace<_Space>& space, 218 | const Key& key) 219 | : Base(space, key), 220 | weight_(space.weight()) 221 | { 222 | } 223 | 224 | constexpr Distance distToRegion() const { 225 | return Base::distToRegion() * weight_; 226 | } 227 | 228 | template 229 | constexpr Distance keyDistance(const _Key& key) const { 230 | return Base::keyDistance(key) * weight_; 231 | } 232 | 233 | using Base::traverse; 234 | }; 235 | 236 | 237 | 238 | }}}} 239 | 240 | #endif // UNC_ROBOTICS_KDTREE_WTSPACE_HPP 241 | -------------------------------------------------------------------------------- /test/test.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace test { 11 | 12 | static std::atomic_uintmax_t g_assertionCount(0); 13 | 14 | template 15 | struct PrintWrap { 16 | const _T& value_; 17 | 18 | PrintWrap(const _T& v) : value_(v) {} 19 | 20 | template 21 | inline friend std::basic_ostream<_Char, _Traits>& 22 | operator << (std::basic_ostream<_Char, _Traits>& os, const PrintWrap& w) { 23 | return os << w.value_; 24 | } 25 | }; 26 | 27 | template <> 28 | struct PrintWrap { 29 | PrintWrap(std::nullptr_t v) {} 30 | 31 | template 32 | inline friend std::basic_ostream<_Char, _Traits>& 33 | operator << (std::basic_ostream<_Char, _Traits>& os, const PrintWrap& w) { 34 | return os << "nullptr"; 35 | } 36 | }; 37 | 38 | template <> 39 | struct PrintWrap { 40 | bool value_; 41 | PrintWrap(bool v) : value_(v) {} 42 | template 43 | inline friend std::basic_ostream<_Char, _Traits>& 44 | operator << 
(std::basic_ostream<_Char, _Traits>& os, const PrintWrap& w) { 45 | // could also use std::boolalpha 46 | return os << (w.value_ ? "true" : "false"); 47 | } 48 | }; 49 | 50 | // helper to print elements of a tuple 51 | template 52 | struct PrintTupleElements { 53 | typedef std::tuple<_Types...> Tuple; 54 | 55 | template 56 | static void apply(std::basic_ostream<_Char, _Traits>& os, const Tuple& t) { 57 | if (_index > 0) os << ", "; 58 | os << PrintWrap::type>(std::get<_index>(t)); 59 | PrintTupleElements<_index+1, _Types...>::apply(os, t); 60 | } 61 | }; 62 | 63 | // helper to print elements of a tuple, base case. 64 | template 65 | struct PrintTupleElements { 66 | typedef std::tuple<_Types...> Tuple; 67 | 68 | template 69 | static void apply(std::basic_ostream<_Char, _Traits>& os, const Tuple& t) {} 70 | }; 71 | 72 | // tuples get printed as [<0>, <1>, ...], where is the element at 73 | // index . 74 | template 75 | struct PrintWrap> { 76 | typedef std::tuple<_Types...> Tuple; 77 | 78 | const Tuple& value_; 79 | PrintWrap(const Tuple& v) : value_(v) {} 80 | 81 | template 82 | inline friend std::basic_ostream<_Char, _Traits>& 83 | operator << (std::basic_ostream<_Char, _Traits>& os, const PrintWrap& w) { 84 | os << "["; 85 | PrintTupleElements<0, _Types...>::apply(os, w.value_); 86 | return os << "]"; 87 | } 88 | }; 89 | 90 | 91 | template 92 | PrintWrap<_T> printWrap(const _T& t) { return PrintWrap<_T>(t); } 93 | 94 | template 95 | struct Expectation { 96 | _T value_; 97 | 98 | const char *expr_; 99 | const char *file_; 100 | int line_; 101 | 102 | mutable bool checked_; 103 | 104 | Expectation(const _T& value, const char *expr, const char *file, int line) 105 | : value_(value), expr_(expr), file_(file), line_(line), checked_(false) 106 | { 107 | } 108 | 109 | Expectation(_T&& value, const char *expr, const char *file, int line) 110 | : value_(std::move(value)), expr_(expr), file_(file), line_(line), checked_(false) 111 | { 112 | } 113 | 114 | ~Expectation() { 
assert(checked_); } 115 | 116 | template 117 | void fail(const char *op, const _E& expect) const { 118 | std::ostringstream str; 119 | str << "Expected " << expr_ << op << printWrap(expect) 120 | << ", got " << printWrap(value_) << " at " << file_ << ':' << line_; 121 | throw std::runtime_error(str.str()); 122 | } 123 | 124 | #define DEFINE_OP(_op_) \ 125 | template \ 126 | void operator _op_ (const _E& expect) const { \ 127 | assert(!checked_); \ 128 | checked_ = true; \ 129 | ++g_assertionCount; \ 130 | if (!(value_ _op_ expect)) \ 131 | fail(" " #_op_ " ", expect); \ 132 | } 133 | 134 | DEFINE_OP(==) 135 | DEFINE_OP(!=) 136 | DEFINE_OP(<) 137 | DEFINE_OP(>) 138 | DEFINE_OP(<=) 139 | DEFINE_OP(>=) 140 | #undef DEFINE_OP 141 | }; 142 | 143 | class TestCase; 144 | 145 | std::vector g_testCases; 146 | 147 | class TestCase { 148 | std::string name_; 149 | 150 | public: 151 | TestCase(const std::string& name) : name_(name) { 152 | g_testCases.push_back(this); 153 | } 154 | 155 | const std::string& name() const { 156 | return name_; 157 | } 158 | 159 | virtual void testImpl() = 0; 160 | 161 | template 162 | static void failed(const std::string& name, std::uintmax_t nAsserts, const _Reason& reason) { 163 | std::ostringstream msg; 164 | msg.imbue(std::locale("")); 165 | msg << name << " \33[31;1mfailed ⚠\33[0m after " 166 | << nAsserts << " assertion" << (nAsserts == 1 ? "" : "s") << "\n\t" 167 | << reason << "\n"; 168 | std::cout << msg.str() << std::flush; 169 | } 170 | 171 | bool run() { 172 | auto assertionsBefore = g_assertionCount.load(); 173 | try { 174 | auto start = std::chrono::high_resolution_clock::now(); 175 | testImpl(); 176 | auto elapsed = std::chrono::high_resolution_clock::now() - start; 177 | auto nAsserts = g_assertionCount.load() - assertionsBefore; 178 | std::ostringstream msg; 179 | msg.imbue(std::locale("")); 180 | msg << name_ << " \33[32mpassed ✓\33[0m (" 181 | << (nAsserts ? "" : "\33[31m") 182 | << nAsserts << " assertion" << (nAsserts == 1 ? 
"" : "s") 183 | << (nAsserts ? "" : "\33[0m") 184 | << ", " << std::chrono::duration(elapsed).count() 185 | << " ms)\n"; 186 | std::cout << msg.str() << std::flush; 187 | return true; 188 | } catch (const std::runtime_error& e) { 189 | failed(name_, g_assertionCount.load() - assertionsBefore - 1, e.what()); 190 | return false; 191 | } 192 | } 193 | }; 194 | 195 | } // namespace test 196 | 197 | #define EXPECT(expr) (::test::Expectation::type>(expr, #expr, __FILE__, __LINE__)) 198 | 199 | #define TEST_CASE(name) \ 200 | struct test_case_ ## name : public ::test::TestCase { \ 201 | test_case_ ## name () : TestCase(#name) {} \ 202 | void testImpl(); \ 203 | }; \ 204 | test_case_ ## name test_instance_ ## name; \ 205 | void test_case_ ## name :: testImpl() 206 | 207 | 208 | // template 209 | // bool runTest(const std::string& name, _Fn fn) { 210 | // auto assertionsBefore = g_assertionCount.load(); 211 | // try { 212 | // fn(); 213 | // auto nAsserts = g_assertionCount.load() - assertionsBefore; 214 | // std::cout << name << " \33[32mpassed\33[0m (" 215 | // << nAsserts << " assertion" << (nAsserts == 1 ? 
"" : "s") 216 | // << ")" << std::endl; 217 | // return true; 218 | // } catch (const std::runtime_error& e) { 219 | // std::cerr << name << " \33[31;1mfailed.\33[0m\n\t" << e.what() << std::endl; 220 | // return false; 221 | // } 222 | // } 223 | 224 | int main(int argc, char* argv[]) { 225 | using namespace test; 226 | 227 | int testCases; 228 | int passed = 0; 229 | if (argc > 1) { 230 | testCases = argc - 1; 231 | for (int i=1 ; iname() == name; }); 236 | 237 | if (it == g_testCases.end()) { 238 | TestCase::failed(name, 0, "no test found with matching name"); 239 | } else { 240 | passed += (*it)->run(); 241 | } 242 | } 243 | } else { 244 | testCases = test::g_testCases.size(); 245 | for (test::TestCase *test : test::g_testCases) { 246 | passed += test->run(); 247 | } 248 | } 249 | 250 | std::cout << passed << " of " 251 | << testCases << " test" << (testCases == 1?"":"s") << " passed." 252 | << std::endl; 253 | return testCases == passed ? 0 : 1; 254 | } 255 | 256 | -------------------------------------------------------------------------------- /src/_l2space.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2017 Jeffrey Ichnowski 2 | // All rights reserved. 3 | // 4 | // BSD 3 Clause 5 | // 6 | // Redistribution and use in source and binary forms, with or without 7 | // modification, are permitted provided that the following conditions 8 | // are met: 9 | // 1. Redistributions of source code must retain the above copyright 10 | // notice, this list of conditions and the following disclaimer. 11 | // 2. Redistributions in binary form must reproduce the above copyright 12 | // notice, this list of conditions and the following disclaimer in the 13 | // documentation and/or other materials provided with the distribution. 14 | // 3. 
Neither the name of the copyright holder nor the names of its 15 | // contributors may be used to endorse or promote products derived 16 | // from this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 21 | // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 22 | // COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 23 | // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 26 | // HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 27 | // STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 29 | // OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | #pragma once 31 | #ifndef UNC_ROBOTICS_KDTREE_L2SPACE_HPP 32 | #define UNC_ROBOTICS_KDTREE_L2SPACE_HPP 33 | 34 | namespace unc { namespace robotics { namespace kdtree { namespace detail { 35 | 36 | template 37 | struct MidpointBoundedL2TraversalBase { 38 | typedef BoundedL2Space<_Scalar, _dimensions> Space; 39 | typedef typename Space::State Key; 40 | 41 | const Key key_; 42 | Eigen::Array<_Scalar, _dimensions, 2> bounds_; 43 | 44 | template 45 | inline MidpointBoundedL2TraversalBase(const Space& space, const Eigen::MatrixBase<_Derived>& key) 46 | : key_(key), bounds_(space.bounds()) 47 | { 48 | } 49 | 50 | constexpr unsigned dimensions() const { 51 | return _dimensions; 52 | } 53 | }; 54 | 55 | template 56 | struct MidpointBoundedL2TraversalBase<_Scalar, Eigen::Dynamic> { 57 | typedef BoundedL2Space<_Scalar, Eigen::Dynamic> Space; 58 | typedef typename Space::State Key; 59 | 60 | const Key& key_; 61 | Eigen::Array<_Scalar, Eigen::Dynamic, 2> bounds_; 62 | unsigned dimensions_; 63 | 64 | template 65 | inline MidpointBoundedL2TraversalBase(const Space& space, const Eigen::MatrixBase<_Derived>& key) 66 | : key_(key), 67 | bounds_(space.bounds()), 68 | dimensions_(space.dimensions()) 69 | { 70 | } 71 | 72 | constexpr unsigned dimensions() const { 73 | return dimensions_; 74 | } 75 | }; 76 | 77 | 78 | template 79 | struct MidpointAddTraversal<_Node, BoundedL2Space<_Scalar, _dimensions>> 80 | : MidpointBoundedL2TraversalBase<_Scalar, _dimensions> 81 | { 82 | typedef BoundedL2Space<_Scalar, _dimensions> Space; 83 | typedef typename Space::State Key; 84 | 85 | using MidpointBoundedL2TraversalBase<_Scalar, _dimensions>::bounds_; 86 | using MidpointBoundedL2TraversalBase<_Scalar, _dimensions>::MidpointBoundedL2TraversalBase; 87 | 88 | constexpr _Scalar maxAxis(unsigned *axis) const { 89 | return (this->bounds_.col(1) - this->bounds_.col(0)).maxCoeff(axis); 90 | } 91 | 92 | template 93 | void addImpl(_Adder& adder, unsigned axis, _Node* p, _Node *n) { 94 | _Scalar split 
= (bounds_(axis, 0) + bounds_(axis, 1)) * 0.5; 95 | int childNo = (split - this->key_[axis]) < 0; 96 | _Node* c = _Adder::child(p, childNo); 97 | while (c == nullptr) 98 | if (_Adder::update(p, childNo, c, n)) 99 | return; 100 | 101 | bounds_(axis, 1-childNo) = split; 102 | adder(c, n); 103 | } 104 | }; 105 | 106 | template 107 | struct MidpointNearestTraversal<_Node, BoundedL2Space<_Scalar, _dimensions>> 108 | : MidpointBoundedL2TraversalBase<_Scalar, _dimensions> 109 | { 110 | typedef BoundedL2Space<_Scalar, _dimensions> Space; 111 | typedef typename Space::State Key; 112 | typedef typename Space::Distance Distance; 113 | 114 | using MidpointBoundedL2TraversalBase<_Scalar, _dimensions>::key_; 115 | using MidpointBoundedL2TraversalBase<_Scalar, _dimensions>::bounds_; 116 | 117 | _Scalar distToRegionCache_ = 0; 118 | _Scalar distToRegionSum_ = 0; 119 | Eigen::Array<_Scalar, _dimensions, 1> regionDeltas_; 120 | 121 | template 122 | MidpointNearestTraversal(const Space& space, const Eigen::MatrixBase<_Derived>& key) 123 | : MidpointBoundedL2TraversalBase<_Scalar, _dimensions>(space, key), 124 | regionDeltas_(space.dimensions(), 1) 125 | { 126 | regionDeltas_.setZero(); 127 | } 128 | 129 | template 130 | constexpr _Scalar keyDistance(const Eigen::MatrixBase<_Derived>& q) const { 131 | return (this->key_ - q).norm(); 132 | } 133 | 134 | constexpr Distance distToRegion() const { 135 | return distToRegionCache_; // std::sqrt(regionDeltas_.sum()); 136 | } 137 | 138 | template 139 | inline void traverse(_Nearest& nearest, const _Node* n, unsigned axis) { 140 | _Scalar split = (bounds_(axis, 0) + bounds_(axis, 1)) * 0.5; 141 | _Scalar delta = (split - key_[axis]); 142 | int childNo = delta < 0; 143 | 144 | if (const _Node* c = _Nearest::child(n, childNo)) { 145 | std::swap(bounds_(axis, 1-childNo), split); 146 | nearest(c); 147 | std::swap(bounds_(axis, 1-childNo), split); 148 | } 149 | 150 | nearest.update(n); 151 | 152 | if (const _Node* c = _Nearest::child(n, 1-childNo)) 
{ 153 | Distance oldDelta = regionDeltas_[axis]; 154 | Distance oldSum = distToRegionSum_; 155 | Distance oldDist = distToRegionCache_; 156 | delta *= delta; 157 | regionDeltas_[axis] = delta; 158 | distToRegionSum_ = distToRegionSum_ - oldDelta + delta; 159 | distToRegionCache_ = std::sqrt(distToRegionSum_); 160 | if (nearest.shouldTraverse()) { 161 | std::swap(bounds_(axis, childNo), split); 162 | nearest(c); 163 | std::swap(bounds_(axis, childNo), split); 164 | } 165 | regionDeltas_[axis] = oldDelta; 166 | distToRegionSum_ = oldSum; 167 | distToRegionCache_ = oldDist; 168 | } 169 | } 170 | }; 171 | 172 | template 173 | struct MedianAccum> { 174 | typedef _Scalar Scalar; 175 | typedef L2Space Space; 176 | 177 | Eigen::Array<_Scalar, _dimensions, 1> min_; 178 | Eigen::Array<_Scalar, _dimensions, 1> max_; 179 | 180 | inline MedianAccum(const Space& space) 181 | : min_(space.dimensions()), 182 | max_(space.dimensions()) 183 | { 184 | } 185 | 186 | constexpr unsigned dimensions() const { 187 | return min_.rows(); 188 | } 189 | 190 | template 191 | void init(const Eigen::MatrixBase<_Derived>& q) { 192 | min_ = q; 193 | max_ = q; 194 | } 195 | 196 | template 197 | void accum(const Eigen::MatrixBase<_Derived>& q) { 198 | min_ = min_.min(q.array()); 199 | max_ = max_.max(q.array()); 200 | } 201 | 202 | constexpr Scalar maxAxis(unsigned *axis) const { 203 | return (max_ - min_).maxCoeff(axis); 204 | } 205 | 206 | template 207 | void partition(_Builder& builder, unsigned axis, _Iter begin, _Iter end, const _GetKey& getKey) { 208 | _Iter mid = begin + (std::distance(begin, end)-1)/2; 209 | std::nth_element(begin, mid, end, [&] (auto& a, auto& b) { 210 | return getKey(a)[axis] < getKey(b)[axis]; 211 | }); 212 | std::iter_swap(begin, mid); 213 | 214 | _Builder::setSplit(*begin, getKey(*begin)[axis]); 215 | // begin->split_ = getKey(*begin)[axis]; 216 | 217 | builder(++begin, ++mid); 218 | builder(mid, end); 219 | } 220 | }; 221 | 222 | template 223 | struct MedianAccum> 224 | 
: MedianAccum> 225 | { 226 | using MedianAccum>::MedianAccum; 227 | }; 228 | 229 | template 230 | struct MedianNearestTraversal> { 231 | typedef L2Space<_Scalar, _dimensions> Space; 232 | typedef typename Space::State Key; 233 | typedef typename Space::Distance Distance; 234 | 235 | const Key key_; 236 | 237 | Eigen::Array<_Scalar, _dimensions, 1> regionDeltas_; 238 | 239 | template 240 | MedianNearestTraversal(const Space& space, const Eigen::MatrixBase<_Derived>& key) 241 | : key_(key), 242 | regionDeltas_(space.dimensions()) 243 | { 244 | regionDeltas_.setZero(); 245 | } 246 | 247 | constexpr unsigned dimensions() const { 248 | return regionDeltas_.rows(); 249 | } 250 | 251 | Distance distToRegion() const { 252 | return std::sqrt(regionDeltas_.sum()); 253 | } 254 | 255 | template 256 | Distance keyDistance(const Eigen::MatrixBase<_Derived>& q) const { 257 | return (key_ - q).norm(); 258 | } 259 | 260 | template 261 | void traverse(_Nearest& nearest, unsigned axis, _Iter begin, _Iter end) { 262 | const auto& n = *begin++; 263 | std::array<_Iter, 3> iters{{begin, begin + std::distance(begin, end)/2, end}}; 264 | Distance delta = _Nearest::split(n) - key_[axis]; 265 | int childNo = delta < 0; 266 | nearest(iters[childNo], iters[childNo+1]); 267 | nearest.update(n); 268 | delta *= delta; 269 | std::swap(regionDeltas_[axis], delta); 270 | if (nearest.shouldTraverse()) 271 | nearest(iters[1-childNo], iters[2-childNo]); 272 | regionDeltas_[axis] = delta; 273 | } 274 | }; 275 | 276 | template 277 | struct MedianNearestTraversal> 278 | : MedianNearestTraversal> 279 | { 280 | using MedianNearestTraversal>::MedianNearestTraversal; 281 | }; 282 | 283 | }}}} 284 | 285 | #endif // UNC_ROBOTICS_KDTREE_L2SPACE_HPP 286 | -------------------------------------------------------------------------------- /src/_so3rlspace.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2017 Jeffrey Ichnowski 2 | // All rights reserved. 
3 | // 4 | // BSD 3 Clause 5 | // 6 | // Redistribution and use in source and binary forms, with or without 7 | // modification, are permitted provided that the following conditions 8 | // are met: 9 | // 1. Redistributions of source code must retain the above copyright 10 | // notice, this list of conditions and the following disclaimer. 11 | // 2. Redistributions in binary form must reproduce the above copyright 12 | // notice, this list of conditions and the following disclaimer in the 13 | // documentation and/or other materials provided with the distribution. 14 | // 3. Neither the name of the copyright holder nor the names of its 15 | // contributors may be used to endorse or promote products derived 16 | // from this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 21 | // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 22 | // COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 23 | // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 26 | // HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 27 | // STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 29 | // OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | #pragma once 31 | #ifndef UNC_ROBOTICS_KDTREE_SO3RLSPACE_HPP 32 | #define UNC_ROBOTICS_KDTREE_SO3RLSPACE_HPP 33 | 34 | namespace unc { namespace robotics { namespace kdtree { namespace detail { 35 | 36 | template 37 | struct MidpointSO3RLTraversalBase { 38 | typedef SO3RLSpace<_Scalar> Space; 39 | typedef typename Space::State Key; 40 | 41 | Key key_; 42 | Eigen::Array<_Scalar, 4, 2> bounds_; 43 | 44 | MidpointSO3RLTraversalBase(const Space& space, const Key& key) 45 | : key_(key) 46 | { 47 | bounds_.col(0) = -1; 48 | bounds_.col(1) = 1; 49 | } 50 | 51 | constexpr unsigned dimensions() const { 52 | return 4; 53 | } 54 | }; 55 | 56 | template 57 | struct MidpointAddTraversal<_Node, SO3RLSpace<_Scalar>> 58 | : MidpointSO3RLTraversalBase<_Scalar> 59 | { 60 | typedef _Scalar Scalar; 61 | typedef SO3RLSpace Space; 62 | typedef typename Space::State Key; 63 | 64 | using MidpointSO3RLTraversalBase<_Scalar>::bounds_; 65 | using MidpointSO3RLTraversalBase<_Scalar>::MidpointSO3RLTraversalBase; 66 | 67 | constexpr _Scalar maxAxis(unsigned *axis) const { 68 | return (bounds_.col(1) - bounds_.col(0)).maxCoeff(axis) * M_PI_2; 69 | } 70 | 71 | template 72 | void addImpl(_Adder& adder, unsigned axis, _Node* p, _Node *n) { 73 | _Scalar split = (bounds_(axis, 0) + bounds_(axis, 1)) * 0.5; 74 | int childNo = split < this->key_.coeffs()[axis]; 75 | _Node *c = _Adder::child(p, childNo); 76 | while (c == nullptr) 77 | if (_Adder::update(p, childNo, c, n)) 78 | return; 79 | bounds_(axis, 1-childNo) = split; 80 | adder(c, n); 81 | } 82 | }; 83 | 84 | template 85 | _Scalar distSide2(_Scalar min, _Scalar pt, _Scalar max) { 86 | _Scalar d; 87 | if (pt < min) { 88 | d = min - pt; 89 | } else if (pt < max) { 90 | return 0; 91 | } else { 92 | d = max - pt; 93 | } 94 | return d*d; 95 | } 96 | 97 | template 98 | inline auto distPtRect(const Eigen::DenseBase<_Derived1>& min, 99 | const Eigen::DenseBase<_Derived2>& max, 100 | const Eigen::DenseBase<_Derived3>& q) 101 | { 102 | return 
distSide2(min[0], q[0], max[0]) 103 | + distSide2(min[1], q[1], max[1]) 104 | + distSide2(min[2], q[2], max[2]) 105 | + distSide2(min[3], q[3], max[3]); 106 | } 107 | 108 | 109 | template 110 | auto so3RLdistPointRect( 111 | const Eigen::ArrayBase<_Min>& min, 112 | const Eigen::ArrayBase<_Max>& max, 113 | const Eigen::QuaternionBase<_Split>& split) 114 | { 115 | const auto& pt = split.coeffs().array(); 116 | 117 | // -2 1 118 | // -3 => 1 119 | // 3 => 2 120 | // 0 => 0 121 | // 122 | // (-2 - -3).max(-3 - 1).max(0) = 1 123 | // (-2 - 3).max( 3 - 1).max(0) = 2 124 | // (-2 - 0).max( 0 - 1).max(0) = 0 125 | 126 | // std::cout << (min - pt).max(pt - max).max(0).transpose() << std::endl 127 | // << (min + pt).max(-pt - max).max(0).transpose() << std::endl; 128 | 129 | auto r = std::min( 130 | (min - pt).max(pt - max).max(0).matrix().squaredNorm(), 131 | (min + pt).max(-pt -max).max(0).matrix().squaredNorm()); 132 | 133 | // auto c = std::min(distPtRect(min, max, pt), 134 | // distPtRect(min, max, -pt)); 135 | 136 | 137 | // if (std::abs(r - c) > 1e-5) { 138 | // std::cout << std::abs(r - c) << std::endl; 139 | // abort(); 140 | // } 141 | 142 | return r; 143 | } 144 | 145 | template 146 | auto so3RLdistPointRect( 147 | const Eigen::ArrayBase<_Bounds>& bounds, 148 | const Eigen::QuaternionBase<_Split>& split) 149 | { 150 | return so3RLdistPointRect(bounds.col(0), bounds.col(1), split); 151 | } 152 | 153 | // TODO: use this for bounds update instead of std::swap or similar. 
154 | template 155 | struct PushVal { 156 | _Scalar& var_; 157 | _Scalar prev_; 158 | 159 | PushVal(_Scalar& v, _Scalar n) : var_(v), prev_(v) { var_ = n; } 160 | ~PushVal() { var_ = prev_; } 161 | }; 162 | 163 | template 164 | struct MidpointNearestTraversal<_Node, SO3RLSpace<_Scalar>> 165 | : MidpointSO3RLTraversalBase<_Scalar> 166 | { 167 | typedef SO3RLSpace<_Scalar> Space; 168 | typedef typename Space::State Key; 169 | typedef typename Space::Distance Distance; 170 | 171 | Distance distToRegionCache_ = 0; 172 | 173 | using MidpointSO3RLTraversalBase<_Scalar>::key_; 174 | using MidpointSO3RLTraversalBase<_Scalar>::bounds_; 175 | 176 | template 177 | MidpointNearestTraversal(const Space& space, const Eigen::QuaternionBase<_Derived>& key) 178 | : MidpointSO3RLTraversalBase<_Scalar>(space, key) 179 | { 180 | } 181 | 182 | template 183 | constexpr Distance keyDistance(const Eigen::QuaternionBase<_Derived>& q) const { 184 | _Scalar dot = std::abs(key_.coeffs().matrix().dot(q.coeffs().matrix())); 185 | return dot < 0 ? M_PI_2 : dot > 1 ? 
0 : std::acos(dot); 186 | } 187 | 188 | constexpr Distance distToRegion() const { 189 | return distToRegionCache_; 190 | } 191 | 192 | template 193 | inline void traverse(_Nearest& nearest, const _Node* n, unsigned axis) { 194 | _Scalar split = (bounds_(axis, 0) + bounds_(axis, 1)) * 0.5; 195 | 196 | std::swap(bounds_(axis, 1), split); 197 | Distance d0 = so3RLdistPointRect(bounds_, key_); 198 | std::swap(bounds_(axis, 1), split); 199 | 200 | std::swap(bounds_(axis, 0), split); 201 | Distance d1 = so3RLdistPointRect(bounds_, key_); 202 | std::swap(bounds_(axis, 0), split); 203 | 204 | int childNo = d0 > d1; 205 | 206 | if (const _Node* c = _Nearest::child(n, childNo)) { 207 | std::swap(bounds_(axis, 1-childNo), split); 208 | nearest(c); 209 | std::swap(bounds_(axis, 1-childNo), split); 210 | } 211 | 212 | nearest.update(n); 213 | 214 | if (const _Node* c = _Nearest::child(n, 1-childNo)) { 215 | Distance oldDist = distToRegionCache_; 216 | distToRegionCache_ = oldDist + std::abs(d1 - d0); 217 | if (nearest.shouldTraverse()) { 218 | std::swap(bounds_(axis, childNo), split); 219 | nearest(c); 220 | std::swap(bounds_(axis, childNo), split); 221 | } 222 | distToRegionCache_ = oldDist; 223 | } 224 | } 225 | }; 226 | 227 | template 228 | struct MedianAccum> { 229 | typedef _Scalar Scalar; 230 | typedef SO3RLSpace Space; 231 | 232 | Eigen::Array min_; 233 | Eigen::Array max_; 234 | 235 | inline MedianAccum(const Space& space) { 236 | min_ = -1; 237 | max_ = 1; 238 | } 239 | 240 | constexpr unsigned dimensions() const { 241 | return 4; 242 | } 243 | 244 | template 245 | void init(const Eigen::QuaternionBase<_Derived>& q) { 246 | min_ = q.coeffs(); 247 | max_ = q.coeffs(); 248 | } 249 | 250 | template 251 | void accum(const Eigen::QuaternionBase<_Derived>& q) { 252 | min_ = min_.min(q.coeffs().array()); 253 | max_ = max_.max(q.coeffs().array()); 254 | } 255 | 256 | constexpr Scalar maxAxis(unsigned *axis) const { 257 | return (max_ - min_).maxCoeff(axis); 258 | } 259 | 260 | 
template 261 | void partition(_Builder& builder, unsigned axis, _Iter begin, _Iter end, const _GetKey& getKey) { 262 | _Iter mid = begin + (std::distance(begin, end) - 1)/2; 263 | std::nth_element(begin, mid, end, [&] (auto& a, auto& b) { 264 | return getKey(a).coeffs()[axis] < getKey(b).coeffs()[axis]; 265 | }); 266 | std::iter_swap(begin, mid); 267 | _Builder::setSplit(*begin, getKey(*begin).coeffs()[axis]); 268 | builder(++begin, ++mid); 269 | builder(mid, end); 270 | } 271 | }; 272 | 273 | template 274 | struct MedianNearestTraversal> { 275 | typedef SO3RLSpace<_Scalar> Space; 276 | typedef typename Space::State Key; 277 | typedef typename Space::Distance Distance; 278 | 279 | const Key key_; 280 | Eigen::Array<_Scalar, 4, 2> bounds_; 281 | Distance distToRegionCache_; 282 | 283 | template 284 | MedianNearestTraversal(const Space& space, const Eigen::QuaternionBase<_Derived>& key) 285 | : key_(key), distToRegionCache_(0) 286 | { 287 | bounds_.col(0) = -1; 288 | bounds_.col(1) = 1; 289 | } 290 | 291 | constexpr unsigned dimensions() const { 292 | return 4; 293 | } 294 | 295 | Distance distToRegion() const { 296 | return distToRegionCache_; 297 | } 298 | 299 | template 300 | Distance keyDistance(const Eigen::QuaternionBase<_Derived>& q) const { 301 | _Scalar dot = std::abs(key_.coeffs().matrix().dot(q.coeffs().matrix())); 302 | return dot < 0 ? M_PI_2 : dot > 1 ? 
0 : std::acos(dot); 303 | } 304 | 305 | template 306 | void traverse(_Nearest& nearest, unsigned axis, _Iter begin, _Iter end) { 307 | const auto& n = *begin++; 308 | Distance split = _Nearest::split(n); 309 | std::array<_Iter, 3> iters{{begin, begin + std::distance(begin,end)/2, end}}; 310 | 311 | 312 | std::swap(bounds_(axis, 1), split); 313 | Distance d0 = so3RLdistPointRect(bounds_, key_); 314 | std::swap(bounds_(axis, 1), split); 315 | 316 | std::swap(bounds_(axis, 0), split); 317 | Distance d1 = so3RLdistPointRect(bounds_, key_); 318 | std::swap(bounds_(axis, 0), split); 319 | 320 | int childNo = d0 > d1; 321 | 322 | std::swap(bounds_(axis, 1-childNo), split); 323 | nearest(iters[childNo], iters[childNo+1]); 324 | std::swap(bounds_(axis, 1-childNo), split); 325 | 326 | nearest.update(n); 327 | 328 | Distance oldDist = distToRegionCache_; 329 | distToRegionCache_ += std::abs(d1 - d0); 330 | if (nearest.shouldTraverse()) { 331 | std::swap(bounds_(axis, childNo), split); 332 | nearest(iters[1-childNo], iters[2-childNo]); 333 | std::swap(bounds_(axis, childNo), split); 334 | } 335 | distToRegionCache_ = oldDist; 336 | } 337 | 338 | }; 339 | 340 | }}}} 341 | 342 | #endif // UNC_ROBOTICS_KDTREE_SO3RLSPACE_HPP 343 | -------------------------------------------------------------------------------- /src/_spaces.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2017 Jeffrey Ichnowski 2 | // All rights reserved. 3 | // 4 | // BSD 3 Clause 5 | // 6 | // Redistribution and use in source and binary forms, with or without 7 | // modification, are permitted provided that the following conditions 8 | // are met: 9 | // 1. Redistributions of source code must retain the above copyright 10 | // notice, this list of conditions and the following disclaimer. 11 | // 2. 
Redistributions in binary form must reproduce the above copyright 12 | // notice, this list of conditions and the following disclaimer in the 13 | // documentation and/or other materials provided with the distribution. 14 | // 3. Neither the name of the copyright holder nor the names of its 15 | // contributors may be used to endorse or promote products derived 16 | // from this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 21 | // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 22 | // COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 23 | // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 26 | // HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 27 | // STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 29 | // OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | #pragma once 31 | #ifndef UNC_ROBOTICS_KDTREE_SPACES_HPP 32 | #define UNC_ROBOTICS_KDTREE_SPACES_HPP 33 | 34 | #include 35 | #include 36 | #include 37 | #include 38 | #include 39 | #include 40 | #include "_tuple.hpp" 41 | 42 | namespace unc { 43 | namespace robotics { 44 | namespace kdtree { 45 | 46 | namespace detail { 47 | template 48 | class L2SpaceBase { 49 | public: 50 | typedef _Scalar Distance; 51 | typedef Eigen::Matrix<_Scalar, _dimensions, 1> State; 52 | 53 | template 54 | bool isValid(const Eigen::MatrixBase<_Derived>& q) const { 55 | return q.rows() == _dimensions && q.cols() == 1 && q.allFinite(); 56 | } 57 | 58 | template 59 | constexpr Distance distance( 60 | const Eigen::MatrixBase<_DerivedA>& a, 61 | const Eigen::MatrixBase<_DerivedB>& b) const 62 | { 63 | return (a - b).norm(); 64 | } 65 | 66 | template 67 | constexpr State interpolate( 68 | const Eigen::MatrixBase<_DerivedA>& from, 69 | const Eigen::MatrixBase<_DerivedB>& to, 70 | Distance t) const 71 | { 72 | return from + (to - from) * t; 73 | } 74 | }; 75 | } 76 | 77 | template 78 | class L2Space : public detail::L2SpaceBase<_Scalar, _dimensions> { 79 | public: 80 | L2Space(unsigned dimensions = _dimensions) { 81 | assert(dimensions == _dimensions); 82 | } 83 | 84 | constexpr unsigned dimensions() const { return _dimensions; } 85 | 86 | template 87 | bool isValid(const Eigen::MatrixBase<_Derived>& q) const { 88 | return q.rows() == _dimensions && detail::L2SpaceBase<_Scalar, _dimensions>::isValid(q); 89 | } 90 | }; 91 | 92 | template 93 | class L2Space<_Scalar, Eigen::Dynamic> : public detail::L2SpaceBase<_Scalar, Eigen::Dynamic> { 94 | unsigned dimensions_; 95 | 96 | public: 97 | L2Space(unsigned dimensions) 98 | : dimensions_(dimensions) 99 | { 100 | } 101 | 102 | constexpr unsigned dimensions() const { 103 | return dimensions_; 104 | } 105 | 106 | template 107 | bool isValid(const Eigen::MatrixBase<_Derived>& q) const { 108 | return q.rows() == dimensions_ && 
detail::L2SpaceBase<_Scalar, Eigen::Dynamic>::isValid(q); 109 | } 110 | }; 111 | 112 | template 113 | class BoundedL2Space : public L2Space<_Scalar, _dimensions> { /* L2 space restricted to an axis-aligned box: bounds_ column 0 holds the per-dimension minimums, column 1 the maximums */ 114 | Eigen::Array<_Scalar, _dimensions, 2> bounds_; 115 | 116 | typedef typename Eigen::Array<_Scalar, _dimensions, 2>::Index Index; 117 | 118 | void checkBounds() { /* debug-only sanity check: every minimum is strictly below its maximum and every extent is finite */ 119 | assert((bounds_.col(0) < bounds_.col(1)).all()); 120 | assert((bounds_.col(1) - bounds_.col(0)).allFinite()); 121 | } 122 | 123 | public: 124 | using typename L2Space<_Scalar, _dimensions>::State; 125 | 126 | template 127 | BoundedL2Space(const Eigen::DenseBase<_Derived>& bounds) 128 | : L2Space<_Scalar, _dimensions>(bounds.rows()), 129 | bounds_(bounds) 130 | { 131 | checkBounds(); 132 | } 133 | 134 | template 135 | BoundedL2Space( 136 | const Eigen::DenseBase<_DerivedMin>& min, 137 | const Eigen::DenseBase<_DerivedMax>& max) 138 | : L2Space<_Scalar, _dimensions>(min.rows()) 139 | { 140 | bounds_.col(0) = min; 141 | bounds_.col(1) = max; checkBounds(); /* validate min < max here too, exactly as the bounds-array constructor does */ 142 | } 143 | 144 | template 145 | bool isValid(const Eigen::MatrixBase<_Derived>& q) const { 146 | return L2Space<_Scalar, _dimensions>::isValid(q) 147 | && (bounds_.col(0) <= q.array()).all() /* q is a matrix expression; Eigen only allows coefficient-wise comparison between array expressions */ 148 | && (bounds_.col(1) >= q.array()).all(); 149 | } 150 | 151 | const Eigen::Array<_Scalar, _dimensions, 2>& bounds() const { 152 | return bounds_; 153 | } 154 | 155 | _Scalar bounds(Index dim, Index j) const { 156 | return bounds_(dim, j); 157 | } 158 | }; 159 | 160 | template 161 | class SO3Space { 162 | public: 163 | typedef _Scalar Distance; 164 | typedef Eigen::Quaternion<_Scalar> State; 165 | 166 | constexpr unsigned dimensions() const { return 3; } 167 | 168 | template 169 | bool isValid(const Eigen::QuaternionBase<_Derived>& q) const { 170 | return std::abs(1 - q.coeffs().squaredNorm()) <= 1e-5; 171 | } 172 | 173 | template 174 | inline Distance distance( 175 | const Eigen::QuaternionBase<_DerivedA>& a, 176 | const Eigen::QuaternionBase<_DerivedB>& b) const 177 | { 178 | Distance dot = 
std::abs(a.coeffs().matrix().dot(b.coeffs().matrix())); 179 | return dot < 0 ? M_PI_2 : dot > 1 ? 0 : std::acos(dot); 180 | } 181 | 182 | template 183 | constexpr State interpolate( 184 | const Eigen::QuaternionBase<_DerivedA>& from, 185 | const Eigen::QuaternionBase<_DerivedB>& to, 186 | Distance t) const 187 | { 188 | Distance dq = from.coeffs().matrix().dot(to.coeffs().matrix()); 189 | if (std::abs(dq) >= 1) 190 | return from; 191 | 192 | Distance theta = std::acos(std::abs(dq)); 193 | Distance d = 1 / std::sin(theta); 194 | Distance s0 = std::sin((1 - t) * theta); 195 | Distance s1 = std::sin(t * theta); 196 | 197 | if (dq < 0) 198 | s1 = -s1; 199 | 200 | return State(d * (from.coeffs() * s0 + to.coeffs() * s1)); 201 | } 202 | }; 203 | 204 | template 205 | class SO3AltSpace { 206 | public: 207 | typedef _Scalar Distance; 208 | typedef Eigen::Quaternion<_Scalar> State; 209 | 210 | constexpr unsigned dimensions() const { return 3; } 211 | 212 | template 213 | bool isValid(const Eigen::QuaternionBase<_Derived>& q) const { 214 | return std::abs(1 - q.coeffs().squaredNorm()) <= 1e-5; 215 | } 216 | 217 | template 218 | inline Distance distance( 219 | const Eigen::QuaternionBase<_DerivedA>& a, 220 | const Eigen::QuaternionBase<_DerivedB>& b) const 221 | { 222 | Distance dot = std::abs(a.coeffs().matrix().dot(b.coeffs().matrix())); 223 | return dot < 0 ? M_PI_2 : dot > 1 ? 
0 : std::acos(dot); 224 | } 225 | }; 226 | 227 | template 228 | class SO3RLSpace { 229 | public: 230 | typedef _Scalar Distance; 231 | typedef Eigen::Quaternion<_Scalar> State; 232 | 233 | constexpr unsigned dimensions() const { return 4; } 234 | 235 | template 236 | bool isValid(const Eigen::QuaternionBase<_Derived>& q) const { 237 | return std::abs(1 - q.coeffs().squaredNorm()) <= 1e-5; 238 | } 239 | 240 | template 241 | inline Distance distance( 242 | const Eigen::QuaternionBase<_DerivedA>& a, 243 | const Eigen::QuaternionBase<_DerivedB>& b) const 244 | { 245 | Distance dot = std::abs(a.coeffs().matrix().dot(b.coeffs().matrix())); 246 | return dot < 0 ? M_PI_2 : dot > 1 ? 0 : std::acos(dot); 247 | } 248 | }; 249 | 250 | template > 251 | class RatioWeightedSpace : public _Space { 252 | public: 253 | typedef _Ratio Ratio; 254 | 255 | static constexpr std::intmax_t num = Ratio::num; 256 | static constexpr std::intmax_t den = Ratio::den; 257 | 258 | // inherit constructor 259 | using _Space::_Space; 260 | 261 | // RatioWeightedSpace() {} 262 | 263 | RatioWeightedSpace(const _Space& space) 264 | : _Space(space) 265 | { 266 | } 267 | 268 | RatioWeightedSpace(_Space&& space) 269 | : _Space(std::forward<_Space>(space)) 270 | { 271 | } 272 | 273 | template 274 | inline typename _Space::Distance distance(const _A& a, const _B& b) const { 275 | return _Space::distance(a, b) * num / den; 276 | } 277 | }; 278 | 279 | template 280 | auto makeRatioWeightedSpace(_Space&& space) { 281 | return RatioWeightedSpace<_Space, std::ratio>(std::forward<_Space>(space)); 282 | } 283 | 284 | template 285 | class WeightedSpace : public _Space { 286 | public: 287 | using typename _Space::Distance; 288 | 289 | private: 290 | Distance weight_; 291 | 292 | public: 293 | WeightedSpace(Distance weight, const _Space& space = _Space()) 294 | : _Space(space), weight_(weight) 295 | { 296 | } 297 | 298 | WeightedSpace(Distance weight, _Space&& space) 299 | : _Space(std::forward<_Space>(space)), 
weight_(weight) 300 | { 301 | } 302 | 303 | template 304 | WeightedSpace(Distance weight, _Args&& ... args) 305 | : _Space(std::forward<_Args>(args)...), 306 | weight_(weight) 307 | { 308 | } 309 | 310 | constexpr Distance weight() const { 311 | return weight_; 312 | } 313 | 314 | template 315 | constexpr Distance distance(const _A& a, const _B& b) const { 316 | return _Space::distance(a, b) * weight_; 317 | } 318 | }; 319 | 320 | template 321 | class CompoundSpace { 322 | typedef std::tuple<_Spaces...> Spaces; 323 | 324 | Spaces spaces_; 325 | 326 | static_assert(sizeof...(_Spaces) > 1, "compound space must have two or more subspaces"); 327 | 328 | public: 329 | typedef std::tuple State; 330 | typedef typename detail::SumResultType::type Distance; 331 | 332 | explicit CompoundSpace() { 333 | } 334 | 335 | explicit CompoundSpace(const _Spaces& ... args) 336 | : spaces_(args...) 337 | { 338 | } 339 | 340 | template 341 | explicit CompoundSpace(_Args&& ... args) 342 | : spaces_(std::forward<_Args>(args)...) 
343 | { 344 | } 345 | 346 | template 347 | constexpr typename std::tuple_element::type& get() { 348 | return std::get(spaces_); 349 | } 350 | 351 | template 352 | constexpr typename std::tuple_element::type const& get() const { 353 | return std::get(spaces_); 354 | } 355 | 356 | inline constexpr unsigned dimensions() const { 357 | using namespace detail; 358 | return sum(map([](const auto& space) { return space.dimensions(); }, spaces_)); 359 | } 360 | 361 | template 362 | bool isValid(_State&& q) const { 363 | using namespace detail; 364 | // TODO: return reduce(std::logical_and(), zip([](auto&& subs, auto&& subq) { return subs.isValid(subq); }, spaces_, q)); 365 | assert(false); 366 | return false; 367 | } 368 | 369 | template 370 | inline Distance distance(_StateA&& a, _StateB&& b) const { 371 | using namespace detail; 372 | return sum(zip([](auto&& subs, auto&& suba, auto&& subb) { 373 | return subs.distance(suba, subb); 374 | }, spaces_, a, b)); 375 | } 376 | 377 | private: 378 | template 379 | constexpr State interpolate( 380 | const _StateA& from, 381 | const _StateB& to, 382 | Distance t, 383 | std::index_sequence) const 384 | { 385 | return State(std::get(spaces_).interpolate(std::get(from), std::get(to), t)...); 386 | } 387 | 388 | public: 389 | template 390 | constexpr State interpolate( 391 | const _StateA& from, 392 | const _StateB& to, 393 | Distance t) const 394 | { 395 | return interpolate(from, to, t, std::make_index_sequence{}); 396 | } 397 | 398 | }; 399 | 400 | template 401 | constexpr auto makeCompoundSpace(_Spaces&&... 
args) { 402 | return CompoundSpace::type...>(std::forward<_Spaces>(args)...); 403 | } 404 | 405 | template 406 | constexpr auto makeWeightedSpace( 407 | typename _Space::Distance weight, 408 | _Space&& space = _Space()) 409 | { 410 | return WeightedSpace<_Space>(weight, std::forward<_Space>(space)); 411 | } 412 | 413 | template 414 | using SE3Space = CompoundSpace< 415 | RatioWeightedSpace, std::ratio<_qWeight>>, 416 | RatioWeightedSpace, std::ratio<_tWeight>>>; 417 | 418 | template 419 | using BoundedSE3Space = CompoundSpace< 420 | RatioWeightedSpace, std::ratio<_qWeight>>, 421 | RatioWeightedSpace, std::ratio<_tWeight>>>; 422 | 423 | } 424 | } 425 | } 426 | 427 | namespace std { 428 | 429 | template 430 | class tuple_element> { 431 | public: 432 | typedef typename std::tuple_element>::type type; 433 | }; 434 | 435 | template 436 | constexpr typename std::tuple_element>::type& 437 | get(unc::robotics::kdtree::CompoundSpace<_Spaces...>& space) { 438 | return space.template get(); 439 | } 440 | 441 | template 442 | constexpr typename std::tuple_element>::type const& 443 | get(const unc::robotics::kdtree::CompoundSpace<_Spaces...>& space) { 444 | return space.template get(); 445 | } 446 | 447 | 448 | } 449 | 450 | #endif // UNC_ROBOTICS_KDTREE_SPACES_HPP 451 | -------------------------------------------------------------------------------- /test/kdtree_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../src/kdtree.hpp" 3 | #include "test.hpp" 4 | #include "state_sampler.hpp" 5 | 6 | template 7 | struct TestNode { 8 | _State key_; 9 | int name_; 10 | 11 | EIGEN_MAKE_ALIGNED_OPERATOR_NEW 12 | 13 | TestNode(const _State& key, int name) 14 | : key_(key), name_(name) 15 | { 16 | } 17 | 18 | template 19 | friend std::basic_ostream<_Char, _Traits>& operator << ( 20 | std::basic_ostream<_Char, _Traits>& os, const TestNode& node) 21 | { 22 | return os << "Node{" << node.name_ << "}"; 23 | } 24 | }; 25 | 26 | 
struct TestNodeKey { 27 | template 28 | constexpr const _State& operator() (const TestNode<_State>& state) const { 29 | return state.key_; 30 | } 31 | }; 32 | 33 | template 34 | unc::robotics::kdtree::BoundedL2Space<_Scalar, _dimensions> 35 | createBoundedL2Space(int dim = _dimensions) { 36 | using namespace unc::robotics::kdtree; 37 | 38 | Eigen::Array<_Scalar, _dimensions, 2> bounds(dim, 2); 39 | bounds.col(0) = -1; 40 | bounds.col(1) = 1; 41 | 42 | return BoundedL2Space<_Scalar, _dimensions>(bounds); 43 | } 44 | 45 | template 46 | void testAdd(const Space& space, _Split&&) { 47 | using namespace unc::robotics::kdtree; 48 | 49 | typedef typename Space::Distance Distance; 50 | typedef typename Space::State Key; 51 | 52 | constexpr std::size_t N = 1000; 53 | 54 | KDTree, Space, TestNodeKey, _Split> tree(space); 55 | 56 | std::vector, Distance>> nearest; 57 | std::mt19937_64 rng; 58 | std::vector> nodes; 59 | for (std::size_t i=0 ; i::randomState(rng, space), i); 61 | EXPECT(tree.size()) == i; 62 | tree.add(nodes.back()); 63 | 64 | tree.nearest(nodes.back().key_); 65 | tree.nearest(nearest, nodes.back().key_, 1); 66 | } 67 | EXPECT(tree.size()) == N; 68 | } 69 | 70 | template 71 | void testKNN(const Space& space, _Split&&, std::size_t N = 10000, std::size_t Q = 5000, std::size_t k = 20) { /* checks k-nearest results of the tree against a brute-force partial sort of the same nodes */ 72 | using namespace unc::robotics::kdtree; 73 | 74 | typedef typename Space::Distance Distance; 75 | typedef typename Space::State Key; 76 | 77 | KDTree, Space, TestNodeKey, _Split> tree(space); /* build with the split strategy under test (as testAdd does) instead of always MidpointSplit */ 78 | 79 | std::mt19937_64 rng; 80 | std::vector> nodes; 81 | nodes.reserve(N); 82 | for (std::size_t i=0 ; i::randomState(rng, space), i); 84 | tree.add(nodes.back()); 85 | } 86 | 87 | std::vector, Distance>> nearest; 88 | nearest.reserve(k); 89 | for (std::size_t i=0 ; i::randomState(rng, space); 91 | tree.nearest(nearest, q, k); 92 | 93 | EXPECT(nearest.size()) == k; 94 | 95 | std::partial_sort(nodes.begin(), nodes.begin() + k, nodes.end(), [&q, &space] (auto& a, auto& b) { 96 | 
return space.distance(a.key_, q) < space.distance(b.key_, q); 97 | }); 98 | 99 | for (std::size_t j=0 ; j 108 | std::pair benchmark( 109 | const std::string& name, 110 | const Space& space, 111 | std::size_t N, std::size_t k, 112 | _Duration duration) 113 | { 114 | using namespace unc::robotics::kdtree; 115 | 116 | typedef typename Space::Distance Distance; 117 | typedef typename Space::State Key; 118 | 119 | KDTree, Space, TestNodeKey, MidpointSplit> tree(space); 120 | 121 | std::mt19937_64 rng; 122 | std::vector> nodes; 123 | nodes.reserve(N); 124 | for (std::size_t i=0 ; i::randomState(rng, space), i); 126 | tree.add(nodes.back()); 127 | } 128 | 129 | typedef std::pair, Distance> DistNode; 130 | std::vector> nearest; 131 | nearest.reserve(k); 132 | constexpr std::size_t batchSize = 100; 133 | typedef std::chrono::high_resolution_clock Clock; 134 | Clock::duration maxElapsed = duration; 135 | Clock::duration elapsed; 136 | std::size_t count = 0; 137 | Clock::time_point start = Clock::now(); 138 | do { 139 | for (std::size_t i=0 ; i::randomState(rng, space); 141 | tree.nearest(nearest, q, k); 142 | } 143 | count += batchSize; 144 | } while ((elapsed = Clock::now() - start) < maxElapsed); 145 | 146 | double seconds = std::chrono::duration(elapsed).count(); 147 | std::cout << name << ": " << seconds*1e6/count << " us/op" << std::endl; 148 | 149 | return std::make_pair(seconds, count); 150 | } 151 | 152 | template 153 | void testStaticBuildAndQuery(const Space& space, std::size_t N = 10000) { 154 | using namespace unc::robotics::kdtree; 155 | typedef typename Space::State Key; 156 | typedef typename Space::Distance Distance; 157 | 158 | std::size_t Q = 1000; 159 | std::size_t k = 20; 160 | 161 | std::mt19937_64 rng; 162 | std::vector> nodes; 163 | KDTree, Space, TestNodeKey, MedianSplit, StaticBuild> tree(space); 164 | 165 | for (std::size_t i=0 ; i::randomState(rng, space), i); 167 | 168 | tree.build(nodes); 169 | 170 | for (std::size_t i=0 ; i* n = tree.nearest(q, 
&dist); 174 | // close to zero, but not always exactly 0 due to numerical issues 175 | EXPECT(dist) == space.distance(q, q); 176 | EXPECT(n) != nullptr; 177 | EXPECT(n->name_) == nodes[i].name_; 178 | } 179 | 180 | 181 | std::vector, Distance>> nearest; 182 | nearest.reserve(k); 183 | for (std::size_t i=0 ; i::randomState(rng, space); 185 | tree.nearest(nearest, q, k); 186 | 187 | EXPECT(nearest.size()) == k; 188 | 189 | std::partial_sort( 190 | nodes.begin(), nodes.begin() + k, nodes.end(), 191 | [&q, &space] (auto& a, auto& b) { 192 | return space.distance(q, a.key_) < space.distance(q, b.key_); 193 | }); 194 | 195 | for (std::size_t j=0 ; j auto createL2_2Space() { return unc::robotics::kdtree::L2Space<_Scalar, 2>(); } 230 | template auto createL2_3Space() { return unc::robotics::kdtree::L2Space<_Scalar, 3>(); } 231 | template auto createL2_6Space() { return unc::robotics::kdtree::L2Space<_Scalar, 6>(); } 232 | template auto createBoundedL2_2Space() { return createBoundedL2Space<_Scalar, 2>(); } 233 | template auto createBoundedL2_3Space() { return createBoundedL2Space<_Scalar, 3>(); } 234 | template auto createBoundedL2_6Space() { return createBoundedL2Space<_Scalar, 6>(); } 235 | template auto createSO3Space() { return unc::robotics::kdtree::SO3Space<_Scalar>(); } 236 | template auto createSO3AltSpace() { return unc::robotics::kdtree::SO3AltSpace<_Scalar>(); } 237 | template auto createSO3RLSpace() { return unc::robotics::kdtree::SO3RLSpace<_Scalar>(); } 238 | 239 | template 240 | auto createBoundedSE3_1to1Space() { 241 | using namespace unc::robotics::kdtree; 242 | return makeCompoundSpace( 243 | SO3Space<_Scalar>(), 244 | createBoundedL2Space<_Scalar, 3>()); 245 | } 246 | 247 | template 248 | auto createRatioWeightedBoundedSE3Space() { 249 | using namespace unc::robotics::kdtree; 250 | return makeCompoundSpace( 251 | makeRatioWeightedSpace<_q>(SO3Space<_Scalar>()), 252 | makeRatioWeightedSpace<_t>(createBoundedL2Space<_Scalar, 3>())); 253 | } 254 | 255 | 
template 256 | auto createBoundedSE3_5to17Space() { 257 | return createRatioWeightedBoundedSE3Space<_Scalar, 5, 17>(); 258 | } 259 | 260 | template 261 | auto createBoundedSE3_31416to10000Space() { 262 | return createRatioWeightedBoundedSE3Space<_Scalar, 31416, 10000>(); 263 | } 264 | 265 | template 266 | auto createBoundedSE3_PISpace() { 267 | using namespace unc::robotics::kdtree; 268 | return makeCompoundSpace( 269 | SO3Space<_Scalar>(), 270 | makeWeightedSpace(M_PI, createBoundedL2Space<_Scalar, 3>())); 271 | } 272 | 273 | template 274 | auto createThreeSE3Space() { 275 | return makeCompoundSpace( 276 | createBoundedSE3_1to1Space<_Scalar>(), 277 | createBoundedSE3_1to1Space<_Scalar>(), 278 | createBoundedSE3_1to1Space<_Scalar>()); 279 | } 280 | 281 | // TEST_CASE(benchmark) { 282 | // using namespace unc::robotics::kdtree; 283 | // using namespace std::literals::chrono_literals; 284 | 285 | // std::size_t N = 1000000; 286 | // std::size_t k = 20; 287 | // auto timeLimit = 5s; 288 | 289 | // // benchmark("R^3 l2", createBoundedL2Space(), N, k, timeLimit); 290 | // // benchmark("R^6 l2", createBoundedL2Space(), N, k, timeLimit); 291 | // benchmark("SO(3)F", SO3Space(), N, k, timeLimit); 292 | // benchmark("SO(3)A", SO3AltSpace(), N, k, timeLimit); 293 | // benchmark("SO(3)R", SO3RLSpace(), N, k, timeLimit); 294 | // benchmark("SE(3)F", createBoundedSE3_1to1Space(), N, k, timeLimit); 295 | // } 296 | 297 | #define SCALAR_TESTS(name, split, space) \ 298 | TEST_CASE(name##_##split##_##space##_float) { \ 299 | test##name(create##space##Space(), ::unc::robotics::kdtree:: split##Split{}); \ 300 | } \ 301 | TEST_CASE(name##_##split##_##space##_double) { \ 302 | test##name(create##space##Space(), ::unc::robotics::kdtree:: split##Split{}); \ 303 | } \ 304 | TEST_CASE(name##_##split##_##space##_long_double) { \ 305 | test##name(create##space##Space(), ::unc::robotics::kdtree:: split##Split{}); \ 306 | } 307 | 308 | #define SPLIT_TESTS(name, space) \ 309 | 
SCALAR_TESTS(name, Median, space) \ 310 | SCALAR_TESTS(name, Midpoint, space) 311 | 312 | #define SPACE_TESTS(name) \ 313 | SPLIT_TESTS(name, SO3) \ 314 | SPLIT_TESTS(name, SO3Alt) \ 315 | SPLIT_TESTS(name, SO3RL) 316 | // SPLIT_TESTS(name, BoundedL2_2) \ 317 | // SPLIT_TESTS(name, BoundedL2_3) \ 318 | // SPLIT_TESTS(name, BoundedL2_6) \ 319 | // SPLIT_TESTS(name, BoundedSE3_1to1) \ 320 | // SPLIT_TESTS(name, BoundedSE3_5to17) \ 321 | // SPLIT_TESTS(name, BoundedSE3_31416to10000) \ 322 | // SPLIT_TESTS(name, BoundedSE3_PI) \ 323 | // SPLIT_TESTS(name, ThreeSE3) 324 | 325 | 326 | // TEST_CASE(Add_BoundedL2) { 327 | // testAdd(createBoundedL2Space()); 328 | // } 329 | // TEST_CASE(KNN_BoundedL2) { 330 | // testKNN(createBoundedL2Space(), 10000, 1000, 20); 331 | // } 332 | 333 | // TEST_CASE(Add_SO3Space) { 334 | // testAdd(unc::robotics::kdtree::SO3Space()); 335 | // } 336 | // TEST_CASE(KNN_SO3Space) { 337 | // testKNN(unc::robotics::kdtree::SO3Space(), 10000, 1000, 20); 338 | // } 339 | 340 | // TEST_CASE(Add_RatioWeightedSpace) { 341 | // using namespace unc::robotics::kdtree; 342 | // testAdd(makeRatioWeightedSpace<17,5>(SO3Space())); 343 | // } 344 | 345 | // TEST_CASE(KNN_RatioWeightedSpace) { 346 | // using namespace unc::robotics::kdtree; 347 | // testKNN(makeRatioWeightedSpace<17,5>(SO3Space()), 10000, 1000, 20); 348 | // } 349 | 350 | // TEST_CASE(Add_WeightedSpace) { 351 | // using namespace unc::robotics::kdtree; 352 | // testAdd(WeightedSpace>(3.21)); 353 | // } 354 | 355 | // TEST_CASE(KNN_WeightedSpace) { 356 | // using namespace unc::robotics::kdtree; 357 | // testKNN(WeightedSpace>(3.21), 10000, 1000, 20); 358 | // } 359 | 360 | // TEST_CASE(Add_CompoundSpace_SE3_1to1) { 361 | // testAdd(createBoundedSE3Space()); 362 | // } 363 | 364 | // TEST_CASE(KNN_CompoundSpace_SE3_1to1) { 365 | // testKNN(createBoundedSE3Space(), 10000, 1000, 20); 366 | // } 367 | 368 | // TEST_CASE(Add_CompoundSpace_SE3_5to17) { 369 | // 
testAdd(createRatioWeightedBoundedSE3Space()); 370 | // } 371 | 372 | // TEST_CASE(KNN_CompoundSpace_SE3_5to17) { 373 | // testKNN(createRatioWeightedBoundedSE3Space(), 10000, 1000, 20); 374 | // } 375 | 376 | TEST_CASE(StaticBuildAndQuery_L2) { 377 | using namespace unc::robotics::kdtree; 378 | testStaticBuildAndQuery(L2Space()); 379 | } 380 | 381 | TEST_CASE(StaticBuildAndQuery_BoundedL2) { 382 | testStaticBuildAndQuery(createBoundedL2Space()); 383 | } 384 | 385 | TEST_CASE(StaticBuildAndQuery_SO3) { 386 | testStaticBuildAndQuery(unc::robotics::kdtree::SO3Space()); 387 | } 388 | 389 | TEST_CASE(StaticBuildAndQuery_SE3) { 390 | using namespace unc::robotics::kdtree; 391 | testStaticBuildAndQuery( 392 | CompoundSpace, L2Space>()); 393 | } 394 | 395 | TEST_CASE(StaticBuildAndQuery_SE3_5to17) { 396 | using namespace unc::robotics::kdtree; 397 | testStaticBuildAndQuery( 398 | CompoundSpace< 399 | RatioWeightedSpace, std::ratio<5>>, 400 | RatioWeightedSpace, std::ratio<17>>>()); 401 | } 402 | 403 | TEST_CASE(StaticBuildAndQuery_SE3_PI) { 404 | using namespace unc::robotics::kdtree; 405 | testStaticBuildAndQuery( 406 | CompoundSpace< 407 | SO3Space, 408 | WeightedSpace>>( 409 | SO3Space(), 410 | WeightedSpace>( 411 | M_PI, L2Space()))); 412 | } 413 | 414 | // TODO: two SE3 spaces (compound of compounds) 415 | 416 | SPACE_TESTS(Add) 417 | SPACE_TESTS(KNN) 418 | -------------------------------------------------------------------------------- /src/_kdtree_median.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2017 Jeffrey Ichnowski 2 | // All rights reserved. 3 | // 4 | // BSD 3 Clause 5 | // 6 | // Redistribution and use in source and binary forms, with or without 7 | // modification, are permitted provided that the following conditions 8 | // are met: 9 | // 1. Redistributions of source code must retain the above copyright 10 | // notice, this list of conditions and the following disclaimer. 11 | // 2. 
Redistributions in binary form must reproduce the above copyright 12 | // notice, this list of conditions and the following disclaimer in the 13 | // documentation and/or other materials provided with the distribution. 14 | // 3. Neither the name of the copyright holder nor the names of its 15 | // contributors may be used to endorse or promote products derived 16 | // from this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 21 | // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 22 | // COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 23 | // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 26 | // HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 27 | // STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 29 | // OF THE POSSIBILITY OF SUCH DAMAGE. 30 | #pragma once 31 | #ifndef UNC_ROBOTICS_KDTREE_KDTREE_MEDIAN_HPP 32 | #define UNC_ROBOTICS_KDTREE_KDTREE_MEDIAN_HPP 33 | 34 | namespace unc { namespace robotics { namespace kdtree { 35 | 36 | template 37 | struct MedianSplitNodeMember { 38 | union { 39 | _Distance split_; 40 | // TODO: consider making offset_'s type a template parameter 41 | std::ptrdiff_t offset_; 42 | }; 43 | unsigned axis_; 44 | }; 45 | 46 | namespace detail { 47 | 48 | template 49 | struct MedianSplitNode { 50 | _T data_; 51 | MedianSplitNodeMember, _Distance> hook_; 52 | 53 | // TODO: delete? 
54 | // MedianSplitNode(const MedianSplitNode&) = delete; 55 | // MedianSplitNode(MedianSplitNode&&) = delete; 56 | 57 | MedianSplitNode(const _T& value) : data_(value) {} /* copies the payload into the node */ 58 | template 59 | MedianSplitNode(_Args&& ... args) : data_(std::forward<_Args>(args)...) {} /* constructs the payload in place from the forwarded arguments; data_ is a member, so it cannot be initialized as if _T were a base class */ 60 | }; 61 | 62 | template 63 | struct MedianSplitNodeKey : _GetKey { 64 | inline MedianSplitNodeKey(const _GetKey& getKey) : _GetKey(getKey) {} 65 | 66 | constexpr decltype(auto) operator() (const MedianSplitNode<_T, _Distance>& node) const { 67 | return _GetKey::operator()(node.data_); 68 | } 69 | constexpr decltype(auto) operator() (MedianSplitNode<_T, _Distance>& node) { 70 | return _GetKey::operator()(node.data_); 71 | } 72 | }; 73 | 74 | template _Node::* _member> 76 | struct MedianBuilder { 77 | typedef _Space Space; 78 | typedef _Node Node; 79 | typedef typename Space::Distance Distance; 80 | typedef typename Space::State Key; 81 | typedef MedianSplitNodeMember Member; 82 | 83 | MedianAccum<_Space> accum_; 84 | _GetKey getKey_; 85 | 86 | MedianBuilder(const _Space& space, const _GetKey& getKey) 87 | : accum_(space), 88 | getKey_(getKey) 89 | { 90 | } 91 | 92 | static void setSplit(Node& node, Distance split) { 93 | (node.*_member).split_ = split; 94 | } 95 | 96 | static void setOffset(Node& node, std::ptrdiff_t offset) { 97 | (node.*_member).offset_ = offset; 98 | } 99 | 100 | template 101 | void operator() (_Iter begin, _Iter end) { 102 | if (begin == end) 103 | return; 104 | 105 | _Iter it = begin; 106 | accum_.init(getKey_(*it)); 107 | while (++it != end) 108 | accum_.accum(getKey_(*it)); 109 | 110 | unsigned axis; 111 | accum_.maxAxis(&axis); 112 | accum_.partition(*this, axis, begin, end, getKey_); 113 | ((*begin).*_member).axis_ = axis; 114 | } 115 | }; 116 | 117 | template 118 | struct MedianNearest { 119 | typedef _Tree Tree; 120 | 121 | typedef typename Tree::Space Space; 122 | typedef typename Space::Distance Distance; 123 | typedef typename Space::State Key; 124 | typedef typename 
Tree::Node Node; 125 | 126 | const Tree& tree_; 127 | MedianNearestTraversal traversal_; 128 | Distance dist_; 129 | 130 | template 131 | MedianNearest(const Tree& tree, const _Key& key, Distance dist) 132 | : tree_(tree), traversal_(tree.space(), key), dist_(dist) 133 | { 134 | } 135 | 136 | constexpr bool shouldTraverse() const { 137 | return traversal_.distToRegion() <= dist_; 138 | } 139 | 140 | static constexpr Distance split(const Node& n) { 141 | return (n.*_Tree::member_).split_; 142 | } 143 | 144 | static constexpr std::ptrdiff_t offset(const Node& n) { 145 | return (n.*_Tree::member_).offset_; 146 | } 147 | 148 | template 149 | inline void operator() (_Iter begin, _Iter end) { 150 | assert(begin <= end); 151 | if (begin != end) 152 | traversal_.traverse(*this, ((*begin).*_Tree::member_).axis_, begin, end); 153 | } 154 | 155 | template 156 | void updateX(const Node& n) { update(n); } 157 | 158 | template 159 | void update(const Node& n) { 160 | Distance d = traversal_.keyDistance(tree_.getNodeKey_(n)); 161 | if (d <= dist_) 162 | static_cast<_Derived*>(this)->update(d, n); 163 | } 164 | }; 165 | 166 | template 167 | struct MedianNearest1 168 | : MedianNearest, _Tree> 169 | { 170 | typedef MedianNearest, _Tree> Base; 171 | typedef typename _Tree::Space Space; 172 | typedef typename Space::Distance Distance; 173 | typedef typename Space::State Key; 174 | typedef typename Base::Node Node; 175 | 176 | using MedianNearest, _Tree>::MedianNearest; 177 | using Base::dist_; 178 | 179 | const Node* nearest_ = nullptr; 180 | 181 | // MedianNearest1(const _Tree& tree, const Key& key) 182 | // : Base(tree, key) 183 | // { 184 | // std::cout << "nearest = " << nearest_ << std::endl; 185 | // } 186 | 187 | inline void update(Distance d, const Node& n) { 188 | dist_ = d; 189 | nearest_ = &n; 190 | // std::cout << d << std::endl; 191 | // std::cout << nearest_ << std::endl; 192 | } 193 | }; 194 | 195 | template 196 | struct MedianNearestK 197 | : MedianNearest, _Tree> 198 
| { 199 | typedef _Tree Tree; 200 | typedef _Value Value; 201 | typedef _NodeValueFn NodeValueFn; 202 | typedef MedianNearest, Tree> Base; 203 | typedef typename _Tree::Space Space; 204 | typedef typename Space::State Key; 205 | typedef typename Space::Distance Distance; 206 | typedef typename Base::Node Node; 207 | 208 | using Base::dist_; 209 | 210 | std::size_t k_; 211 | std::vector, _ResultAllocator>& nearest_; 212 | NodeValueFn nodeValueFn_; 213 | 214 | template 215 | MedianNearestK( 216 | const Tree& tree, 217 | const _Key& key, 218 | Distance dist, 219 | std::size_t k, 220 | std::vector, _ResultAllocator>& nearest, 221 | NodeValueFn&& nodeValueFn) 222 | : Base(tree, key, dist), k_(k), nearest_(nearest), nodeValueFn_(nodeValueFn) 223 | { 224 | } 225 | 226 | inline void update(Distance d, const Node& n) { 227 | if (nearest_.size() == k_) { 228 | std::pop_heap(nearest_.begin(), nearest_.end(), CompareSecond()); 229 | nearest_.pop_back(); 230 | } 231 | 232 | nearest_.emplace_back(nodeValueFn_(n), d); 233 | std::push_heap(nearest_.begin(), nearest_.end(), CompareSecond()); 234 | 235 | if (nearest_.size() == k_) 236 | dist_ = nearest_[0].second; 237 | } 238 | }; 239 | 240 | 241 | 242 | } // namespace detail 243 | 244 | // Static Median-Split Tree (lock-free operation is not supported) 245 | template < 246 | typename _T, 247 | typename _Space, 248 | typename _GetKey, 249 | typename _Allocator> 250 | struct KDTree<_T, _Space, _GetKey, MedianSplit, StaticBuild, SingleThread, _Allocator> { 251 | typedef _Space Space; 252 | 253 | private: 254 | typedef typename Space::Distance Distance; 255 | typedef typename Space::State Key; 256 | typedef _GetKey GetKey; 257 | 258 | typedef detail::MedianSplitNode<_T, Distance> Node; 259 | typedef detail::MedianSplitNodeKey<_T, Distance, _GetKey> GetNodeKey; 260 | 261 | typedef std::allocator_traits<_Allocator> AllocatorTraits; 262 | typedef typename AllocatorTraits::template rebind_alloc NodeAllocator; 263 | 264 | static constexpr 
auto member_ = &Node::hook_; 265 | 266 | Space space_; 267 | GetNodeKey getNodeKey_; 268 | 269 | std::vector nodes_; 270 | 271 | template friend struct detail::MedianNearest; 272 | 273 | public: 274 | KDTree(const Space& space, const _GetKey& getKey = _GetKey(), const _Allocator& alloc = _Allocator()) 275 | : space_(space), 276 | getNodeKey_(getKey), 277 | nodes_(NodeAllocator(alloc)) 278 | { 279 | } 280 | 281 | template 282 | KDTree(const Space& space, const _GetKey& getKey, _Iter begin, _Iter end) 283 | : KDTree(space, getKey) 284 | { 285 | build(begin, end); 286 | } 287 | 288 | template 289 | KDTree(const Space& space, const _GetKey& getKey, const _Container& container) 290 | : KDTree(space, getKey) 291 | { 292 | build(container); 293 | } 294 | 295 | template 296 | void build(_Iter begin, _Iter end) { 297 | nodes_.clear(); 298 | nodes_.reserve(std::distance(begin, end)); 299 | std::transform(begin, end, std::back_inserter(nodes_), [&](auto& v) { return Node(v); }); 300 | 301 | detail::MedianBuilder builder(space_, getNodeKey_); 302 | builder(nodes_.begin(), nodes_.end()); 303 | } 304 | 305 | template 306 | void build(const _Container& container) { 307 | build(container.begin(), container.end()); 308 | } 309 | 310 | constexpr const Space& space() const { 311 | return space_; 312 | } 313 | 314 | constexpr std::size_t size() const { 315 | return nodes_.size(); 316 | } 317 | 318 | // TODO: non-const version returning non-const result? 
319 | const _T* nearest(const Key& key, Distance *distOut = nullptr) const { 320 | if (nodes_.size() == 0) 321 | return nullptr; 322 | 323 | detail::MedianNearest1 nearest( 324 | *this, key, std::numeric_limits::infinity()); 325 | nearest(nodes_.begin(), nodes_.end()); 326 | if (distOut) 327 | *distOut = nearest.dist_; 328 | 329 | return &(nearest.nearest_->data_); 330 | } 331 | 332 | template 333 | void nearest( 334 | std::vector, _ResultAllocator>& result, 335 | const Key& key, 336 | std::size_t k, 337 | Distance maxRadius, 338 | _NodeValueFn&& nodeValueFn) const 339 | { 340 | result.clear(); 341 | if (k == 0) 342 | return; 343 | 344 | detail::MedianNearestK nearest( 345 | *this, key, maxRadius, k, result, std::forward<_NodeValueFn>(nodeValueFn)); 346 | 347 | nearest(nodes_.begin(), nodes_.end()); 348 | std::sort_heap(result.begin(), result.end(), detail::CompareSecond()); 349 | } 350 | 351 | template 352 | void nearest( 353 | std::vector, _ResultAllocator>& result, 354 | const Key& key, 355 | std::size_t k, 356 | Distance maxRadius = std::numeric_limits::infinity()) const 357 | { 358 | nearest(result, key, k, maxRadius, [] (const Node& n) -> const auto& { return n.data_; }); 359 | } 360 | }; 361 | 362 | // Dynamically balanced tree 363 | template < 364 | typename _T, 365 | typename _Space, 366 | typename _GetKey, 367 | typename _Allocator> 368 | struct KDTree<_T, _Space, _GetKey, MedianSplit, DynamicBuild, SingleThread, _Allocator> { 369 | typedef _Space Space; 370 | 371 | private: 372 | // must be a power of 2 373 | static constexpr std::size_t minStaticTreeSize_ = 2; 374 | 375 | typedef typename Space::Distance Distance; 376 | typedef typename Space::State Key; 377 | typedef _GetKey GetKey; 378 | 379 | typedef detail::MedianSplitNode<_T, Distance> Node; 380 | typedef detail::MedianSplitNodeKey<_T, Distance, _GetKey> GetNodeKey; 381 | 382 | typedef std::vector Nodes; 383 | typedef typename Nodes::iterator Iter; 384 | typedef typename Nodes::const_iterator 
ConstIter; 385 | 386 | static constexpr auto member_ = &Node::hook_; 387 | 388 | Space space_; 389 | Nodes nodes_; 390 | detail::MedianBuilder builder_; 391 | 392 | template friend struct detail::MedianNearest; 393 | 394 | template 395 | inline void scanTrees(_Nearest& nearest) const { 396 | ConstIter it = nodes_.begin(); 397 | for (std::size_t remaining = size() ; remaining >= minStaticTreeSize_ ; ) { 398 | std::size_t treeSize = 1 << detail::log2(remaining); 399 | // std::cout << "scan " << remaining << " -> " << treeSize << std::endl; 400 | nearest(it, it + treeSize); 401 | it += treeSize; 402 | remaining &= ~treeSize; 403 | } 404 | 405 | for ( ; it != nodes_.end() ; ++it) 406 | nearest.updateX(*it); 407 | } 408 | 409 | // TODO: change to reference 410 | constexpr decltype(auto) getNodeKey_(const Node& node) const { 411 | return builder_.getKey_(node); 412 | } 413 | 414 | public: 415 | KDTree(const Space& space, const _GetKey& getKey = _GetKey()) 416 | : space_(space), 417 | builder_(space, getKey) 418 | { 419 | } 420 | 421 | constexpr const Space& space() const { 422 | return space_; 423 | } 424 | 425 | constexpr std::size_t size() const { 426 | return nodes_.size(); 427 | } 428 | 429 | constexpr bool empty() const { 430 | return nodes_.empty(); 431 | } 432 | 433 | void add(const _T& value) { 434 | nodes_.emplace_back(value); 435 | std::size_t s = nodes_.size(); 436 | 437 | std::size_t newTreeSize = ((s^(s-1)) + 1) >> 1; 438 | 439 | if (newTreeSize >= minStaticTreeSize_) 440 | builder_(nodes_.end() - newTreeSize, nodes_.end()); 441 | } 442 | 443 | template 444 | const _T* nearest(const _Key& key, Distance *distOut = nullptr) const { 445 | if (empty()) 446 | return nullptr; 447 | 448 | detail::MedianNearest1 nearest( 449 | *this, key, std::numeric_limits::infinity()); 450 | scanTrees(nearest); 451 | if (distOut) 452 | *distOut = nearest.dist_; 453 | return &(nearest.nearest_->data_); 454 | } 455 | 456 | template 457 | void nearest( 458 | std::vector, 
_ResultAllocator>& result, 459 | const _Key& key, 460 | std::size_t k, 461 | Distance maxRadius, 462 | _NodeValueFn&& nodeValueFn) const 463 | { 464 | result.clear(); 465 | if (k == 0) 466 | return; 467 | 468 | detail::MedianNearestK nearest( 469 | *this, key, maxRadius, k, result, std::forward<_NodeValueFn>(nodeValueFn)); 470 | scanTrees(nearest); 471 | std::sort_heap(result.begin(), result.end(), detail::CompareSecond()); 472 | } 473 | 474 | template 475 | void nearest( 476 | std::vector, _ResultAllocator>& result, 477 | const _Key& key, 478 | std::size_t k, 479 | Distance maxRadius = std::numeric_limits::infinity()) const 480 | { 481 | nearest(result, key, k, maxRadius, [] (const Node& n) -> const auto& { return n.data_; }); 482 | } 483 | 484 | template 485 | void visitAll(_Fn&& f) const { 486 | for (auto& n : nodes_) 487 | f(n.data_); 488 | } 489 | }; 490 | 491 | 492 | }}} 493 | 494 | #endif // UNC_ROBOTICS_KDTREE_KDTREE_MEDIAN_HPP 495 | -------------------------------------------------------------------------------- /src/_compoundspace.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2017 Jeffrey Ichnowski 2 | // All rights reserved. 3 | // 4 | // BSD 3 Clause 5 | // 6 | // Redistribution and use in source and binary forms, with or without 7 | // modification, are permitted provided that the following conditions 8 | // are met: 9 | // 1. Redistributions of source code must retain the above copyright 10 | // notice, this list of conditions and the following disclaimer. 11 | // 2. Redistributions in binary form must reproduce the above copyright 12 | // notice, this list of conditions and the following disclaimer in the 13 | // documentation and/or other materials provided with the distribution. 14 | // 3. 
Neither the name of the copyright holder nor the names of its 15 | // contributors may be used to endorse or promote products derived 16 | // from this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 21 | // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 22 | // COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 23 | // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 26 | // HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 27 | // STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 29 | // OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | #pragma once 31 | #ifndef UNC_ROBOTICS_KDTREE_COMPOUNDSPACE_HPP 32 | #define UNC_ROBOTICS_KDTREE_COMPOUNDSPACE_HPP 33 | 34 | namespace unc { namespace robotics { namespace kdtree { namespace detail { 35 | 36 | template 37 | struct MidpointCompoundHelper { 38 | typedef CompoundSpace<_Spaces...> Space; 39 | typedef typename Space::State Key; 40 | typedef typename Space::Distance Distance; 41 | 42 | typedef MidpointCompoundHelper<_Node, _index+1, _Spaces...> Next; 43 | 44 | template 45 | static constexpr unsigned dimensions(_Traversals& traversals, unsigned sum) { 46 | return Next::dimensions(traversals, sum + std::get<_index>(traversals).dimensions()); 47 | } 48 | 49 | template 50 | static inline Distance maxAxis(_Traversals& traversals, unsigned dimBefore, Distance bestDist, unsigned *bestAxis) { 51 | unsigned axis; 52 | typename Space::Distance d = std::get<_index>(traversals).maxAxis(&axis); 53 | if (d > bestDist) { 54 | *bestAxis = dimBefore + axis; 55 | bestDist = d; 56 | } 57 | return Next::maxAxis(traversals, dimBefore + std::get<_index>(traversals).dimensions(), bestDist, bestAxis); 58 | } 59 | 60 | template 61 | static inline void addImpl(_Traversals& traversals, _Adder& adder, unsigned dimBefore, unsigned axis, _Node* p, _Node* n) { 62 | unsigned dimAfter = dimBefore + std::get<_index>(traversals).dimensions(); 63 | if (axis < dimAfter) { 64 | std::get<_index>(traversals).addImpl(adder, axis - dimBefore, p, n); 65 | } else { 66 | Next::addImpl(traversals, adder, dimAfter, axis, p, n); 67 | } 68 | } 69 | 70 | template 71 | static inline Distance keyDistance(const _Traversals& traversals, const _State& q, Distance sum) { 72 | return Next::keyDistance(traversals, q, sum + std::get<_index>(traversals).keyDistance(std::get<_index>(q))); 73 | } 74 | 75 | template 76 | static inline Distance distToRegion(const _Traversals& traversals, Distance sum) { 77 | return Next::distToRegion(traversals, sum + std::get<_index>(traversals).distToRegion()); 78 | } 79 | 80 
| template 81 | static inline void traverse(_Traversals& traversals, _Nearest& nearest, const _Node* n, unsigned dimBefore, unsigned axis) { 82 | unsigned dimAfter = dimBefore + std::get<_index>(traversals).dimensions(); 83 | if (axis < dimAfter) { 84 | std::get<_index>(traversals).traverse(nearest, n, axis - dimBefore); 85 | } else { 86 | Next::traverse(traversals, nearest, n, dimAfter, axis); 87 | } 88 | } 89 | }; 90 | 91 | 92 | template 93 | struct MidpointCompoundHelper<_Node, sizeof...(_Spaces)-1, _Spaces...> { 94 | typedef CompoundSpace<_Spaces...> Space; 95 | typedef typename Space::State Key; 96 | typedef typename Space::Distance Distance; 97 | static constexpr int _index = sizeof...(_Spaces)-1; 98 | 99 | template 100 | static constexpr unsigned dimensions(_Traversals& traversals, unsigned sum) { 101 | return sum + std::get<_index>(traversals).dimensions(); 102 | } 103 | 104 | template 105 | static inline Distance maxAxis(_Traversals& traversals, unsigned dimBefore, Distance bestDist, unsigned *bestAxis) { 106 | unsigned axis; 107 | typename Space::Distance d = std::get<_index>(traversals).maxAxis(&axis); 108 | if (d > bestDist) { 109 | *bestAxis = dimBefore + axis; 110 | bestDist = d; 111 | } 112 | return bestDist; 113 | } 114 | 115 | template 116 | static inline void addImpl(_Traversals& traversals, _Adder& adder, unsigned dimBefore, unsigned axis, _Node* p, _Node* n) { 117 | std::get<_index>(traversals).addImpl(adder, axis - dimBefore, p, n); 118 | } 119 | 120 | template 121 | static inline Distance keyDistance(_Traversals& traversals, const _State& q, Distance sum) { 122 | return sum + std::get<_index>(traversals).keyDistance(std::get<_index>(q)); 123 | } 124 | 125 | template 126 | static inline Distance distToRegion(_Traversals& traversals, Distance sum) { 127 | return sum + std::get<_index>(traversals).distToRegion(); 128 | } 129 | 130 | template 131 | static inline void traverse(_Traversals& traversals, _Nearest& nearest, const _Node* n, unsigned 
dimBefore, unsigned axis) { 132 | unsigned dimAfter = dimBefore + std::get<_index>(traversals).dimensions(); 133 | assert(axis < dimAfter); 134 | std::get<_index>(traversals).traverse(nearest, n, axis - dimBefore); 135 | } 136 | }; 137 | 138 | 139 | // Alternate base case for MidpointCompoundHelper, 140 | // 141 | // template 142 | // struct MidpointCompoundHelper<_Node, sizeof...(_Spaces), _Spaces...> { 143 | // typedef CompoundSpace<_Spaces...> Space; 144 | // typedef typename Space::State Key; 145 | // typedef typename Space::Distance Distance; 146 | 147 | // template 148 | // static inline Distance maxAxis(_Traversals& traversals, unsigned dimBefore, Distance bestDist, unsigned *bestAxis) { 149 | // return bestDist; 150 | // } 151 | 152 | // template 153 | // static inline void addImpl(_Traversals& traversals, _Adder& adder, unsigned dimBefore, unsigned axis, _Node* p, _Node* n) { 154 | // assert(false); // should not happen 155 | // } 156 | 157 | // template 158 | // static inline Distance keyDistance(_Traversals& traversals, const _State& q, Distance sum) { 159 | // return sum; 160 | // } 161 | 162 | // template 163 | // static inline Distance distToRegion(_Traversals& traversals, Distance sum) { 164 | // return sum; 165 | // } 166 | 167 | // template 168 | // static inline void traverse(_Traversals& traversals, _Nearest& nearest, const _Node* n, unsigned dimBefore, unsigned axis) { 169 | // assert(false); // should not happen 170 | // } 171 | // }; 172 | 173 | template 174 | struct MidpointAddTraversal<_Node, CompoundSpace<_Spaces...>> { 175 | typedef CompoundSpace<_Spaces...> Space; 176 | typedef typename Space::State Key; 177 | typedef typename Space::Distance Distance; 178 | typedef typename std::tuple<_Spaces...> Tuple; 179 | 180 | const Space& space_; 181 | std::tuple...> traversals_; 182 | 183 | template 184 | MidpointAddTraversal(const Space& space, const _Key& key, std::index_sequence) 185 | : space_(space), 186 | 
traversals_(MidpointAddTraversal<_Node, typename std::tuple_element::type>( 187 | std::get(space), std::get(key))...) 188 | { 189 | } 190 | 191 | template 192 | MidpointAddTraversal(const Space& space, const _Key& key) 193 | : MidpointAddTraversal(space, key, std::make_index_sequence{}) 194 | { 195 | } 196 | 197 | constexpr unsigned dimensions() const { 198 | return MidpointCompoundHelper<_Node, 0, _Spaces...>::dimensions(traversals_, 0); 199 | } 200 | 201 | inline Distance maxAxis(unsigned* axis) { 202 | Distance d = std::get<0>(traversals_).maxAxis(axis); 203 | return MidpointCompoundHelper<_Node, 1, _Spaces...>::maxAxis( 204 | traversals_, std::get<0>(space_).dimensions(), d, axis); 205 | } 206 | 207 | template 208 | void addImpl(_Adder& adder, unsigned axis, _Node* p, _Node* n) { 209 | MidpointCompoundHelper<_Node, 0, _Spaces...>::addImpl( 210 | traversals_, adder, 0, axis, p, n); 211 | } 212 | }; 213 | 214 | template 215 | struct MidpointNearestTraversal<_Node, CompoundSpace<_Spaces...>> { 216 | typedef CompoundSpace<_Spaces...> Space; 217 | typedef typename Space::State Key; 218 | typedef typename Space::Distance Distance; 219 | typedef typename std::tuple<_Spaces...> Tuple; 220 | 221 | const Space& space_; 222 | std::tuple...> traversals_; 223 | 224 | template 225 | MidpointNearestTraversal( 226 | const Space& space, const _Key& key, std::index_sequence) 227 | : space_(space), 228 | traversals_(MidpointNearestTraversal<_Node, typename std::tuple_element::type>( 229 | std::get(space), std::get(key))...) 
230 | { 231 | } 232 | 233 | template 234 | MidpointNearestTraversal(const Space& space, const _Key& key) 235 | : MidpointNearestTraversal(space, key, std::make_index_sequence{}) 236 | { 237 | } 238 | 239 | constexpr unsigned dimensions() const { 240 | return MidpointCompoundHelper<_Node, 0, _Spaces...>::dimensions(traversals_, 0); 241 | } 242 | 243 | template 244 | Distance keyDistance(const _State& q) const { 245 | return MidpointCompoundHelper<_Node, 1, _Spaces...>::keyDistance( 246 | traversals_, q, std::get<0>(traversals_).keyDistance(std::get<0>(q))); 247 | } 248 | 249 | inline Distance distToRegion() const { 250 | return MidpointCompoundHelper<_Node, 1, _Spaces...>::distToRegion( 251 | traversals_, std::get<0>(traversals_).distToRegion()); 252 | } 253 | 254 | template 255 | void traverse(_Nearest& nearest, const _Node* n, unsigned axis) { 256 | MidpointCompoundHelper<_Node, 0, _Spaces...>::traverse( 257 | traversals_, nearest, n, 0, axis); 258 | } 259 | }; 260 | 261 | template 262 | struct CompoundMedianHelper { 263 | typedef CompoundSpace<_Spaces...> Space; 264 | typedef typename Space::State Key; 265 | typedef typename Space::Distance Distance; 266 | typedef std::tuple...> Accums; 267 | typedef std::tuple...> Traversals; 268 | typedef CompoundMedianHelper<_index+1, _Spaces...> Next; 269 | 270 | template 271 | static unsigned dimensions(_Accums& accums, unsigned sum) { 272 | return Next::dimensions(accums, sum + std::get<_index>(accums).dimensions()); 273 | } 274 | 275 | template 276 | static void init(Accums& accums, const _Key& q) { 277 | std::get<_index>(accums).init(std::get<_index>(q)); 278 | return Next::init(accums, q); 279 | } 280 | 281 | template 282 | static void accum(Accums& accums, const _Key& q) { 283 | std::get<_index>(accums).accum(std::get<_index>(q)); 284 | return Next::accum(accums, q); 285 | } 286 | 287 | static Distance maxAxis(Accums& accums, unsigned dimBefore, Distance dist, unsigned *axis) { 288 | unsigned a; 289 | Distance d = 
std::get<_index>(accums).maxAxis(&a); 290 | if (d > dist) { 291 | dist = d; 292 | *axis = a + dimBefore; 293 | } 294 | return Next::maxAxis(accums, dimBefore + std::get<_index>(accums).dimensions(), dist, axis); 295 | } 296 | 297 | template 298 | static void partition( 299 | Accums& accums, _Builder& builder, unsigned axis, 300 | _Iter begin, _Iter end, 301 | const _GetKey& getKey) 302 | { 303 | unsigned dim = std::get<_index>(accums).dimensions(); 304 | if (axis < dim) { 305 | std::get<_index>(accums).partition( 306 | builder, axis, begin, end, 307 | [&] (auto& t) { return std::get<_index>(getKey(t)); }); // TODO: -> const auto& needed? 308 | } else { 309 | Next::partition(accums, builder, axis - dim, begin, end, getKey); 310 | } 311 | } 312 | 313 | static Distance distToRegion(const Traversals& traversals, Distance sum) { 314 | return Next::distToRegion(traversals, sum + std::get<_index>(traversals).distToRegion()); 315 | } 316 | 317 | template 318 | static Distance keyDistance(const Traversals& traversals, const _Key& q, Distance sum) { 319 | return Next::keyDistance( 320 | traversals, q, 321 | sum + std::get<_index>(traversals).keyDistance(std::get<_index>(q))); 322 | } 323 | 324 | template 325 | static void traverse(Traversals& traversals, _Nearest& nearest, unsigned axis, _Iter begin, _Iter end) { 326 | unsigned dim = std::get<_index>(traversals).dimensions(); 327 | if (axis < dim) { 328 | std::get<_index>(traversals).traverse(nearest, axis, begin, end); 329 | } else { 330 | Next::traverse(traversals, nearest, axis - dim, begin, end); 331 | } 332 | } 333 | }; 334 | 335 | template 336 | struct CompoundMedianHelper { 337 | typedef CompoundSpace<_Spaces...> Space; 338 | typedef typename Space::State Key; 339 | typedef typename Space::Distance Distance; 340 | typedef std::tuple...> Accums; 341 | typedef std::tuple...> Traversals; 342 | static constexpr int _index = sizeof...(_Spaces)-1; 343 | 344 | template 345 | static unsigned dimensions(_Accums& accums, 
unsigned sum) { 346 | return sum + std::get<_index>(accums).dimensions(); 347 | } 348 | 349 | template 350 | static void init(Accums& accums, const _Key& q) { 351 | std::get<_index>(accums).init(std::get<_index>(q)); 352 | } 353 | 354 | template 355 | static void accum(Accums& accums, const _Key& q) { 356 | std::get<_index>(accums).accum(std::get<_index>(q)); 357 | } 358 | 359 | static Distance maxAxis(Accums& accums, unsigned dimBefore, Distance dist, unsigned *axis) { 360 | unsigned a; 361 | Distance d = std::get<_index>(accums).maxAxis(&a); 362 | if (d > dist) { 363 | dist = d; 364 | *axis = a + dimBefore; 365 | } 366 | return dist; 367 | } 368 | 369 | template 370 | static void partition( 371 | Accums& accums, _Builder& builder, unsigned axis, 372 | _Iter begin, _Iter end, 373 | const _GetKey& getKey) 374 | { 375 | std::get<_index>(accums).partition( 376 | builder, axis, begin, end, 377 | [&] (auto& t) { return std::get<_index>(getKey(t)); }); // TODO: -> const auto& needed? 378 | } 379 | 380 | static Distance distToRegion(const Traversals& traversals, Distance sum) { 381 | return sum + std::get<_index>(traversals).distToRegion(); 382 | } 383 | 384 | template 385 | static Distance keyDistance(const Traversals& traversals, const _Key& q, Distance sum) { 386 | return sum + std::get<_index>(traversals).keyDistance(std::get<_index>(q)); 387 | } 388 | 389 | template 390 | static void traverse(Traversals& traversals, _Nearest& nearest, unsigned axis, _Iter begin, _Iter end) { 391 | std::get<_index>(traversals).traverse(nearest, axis, begin, end); 392 | } 393 | 394 | }; 395 | 396 | template 397 | struct MedianAccum> { 398 | typedef CompoundSpace<_Spaces...> Space; 399 | typedef typename Space::State Key; 400 | typedef typename Space::Distance Distance; 401 | 402 | std::tuple...> accums_; 403 | 404 | template 405 | MedianAccum(const Space& space, std::index_sequence) 406 | : accums_(std::get(space)...) 
407 | { 408 | } 409 | 410 | MedianAccum(const Space& space) 411 | : MedianAccum(space, std::make_index_sequence{}) 412 | { 413 | } 414 | 415 | constexpr unsigned dimensions() const { 416 | return CompoundMedianHelper<0, _Spaces...>::dimensions(accums_, 0); 417 | } 418 | 419 | template 420 | void init(const _Key& q) { 421 | CompoundMedianHelper<0, _Spaces...>::init(accums_, q); 422 | } 423 | 424 | template 425 | void accum(const _Key& q) { 426 | CompoundMedianHelper<0, _Spaces...>::accum(accums_, q); 427 | } 428 | 429 | Distance maxAxis(unsigned *axis) { 430 | Distance d = std::get<0>(accums_).maxAxis(axis); 431 | return CompoundMedianHelper<1, _Spaces...>::maxAxis( 432 | accums_, std::get<0>(accums_).dimensions(), d, axis); 433 | } 434 | 435 | template 436 | void partition(_Builder& builder, unsigned axis, _Iter begin, _Iter end, const _GetKey& getKey) { 437 | CompoundMedianHelper<0, _Spaces...>::partition( 438 | accums_, builder, axis, begin, end, getKey); 439 | } 440 | }; 441 | 442 | template 443 | struct MedianNearestTraversal> { 444 | typedef CompoundSpace<_Spaces...> Space; 445 | typedef typename Space::State Key; 446 | typedef typename Space::Distance Distance; 447 | 448 | typedef std::tuple<_Spaces...> Tuple; 449 | 450 | std::tuple...> traversals_; 451 | 452 | template 453 | MedianNearestTraversal(const Space& space, const _Key& key, std::index_sequence) 454 | : traversals_( 455 | MedianNearestTraversal::type>( 456 | std::get(space), 457 | std::get(key))...) 
458 | { 459 | } 460 | 461 | template 462 | MedianNearestTraversal(const Space& space, const _Key& key) 463 | : MedianNearestTraversal(space, key, std::make_index_sequence{}) 464 | { 465 | } 466 | 467 | constexpr unsigned dimensions() const { 468 | return CompoundMedianHelper<0, _Spaces...>::dimensions(traversals_, 0); 469 | } 470 | 471 | Distance distToRegion() const { 472 | return CompoundMedianHelper<1, _Spaces...>::distToRegion( 473 | traversals_, std::get<0>(traversals_).distToRegion()); 474 | } 475 | 476 | template 477 | Distance keyDistance(const _State& q) const { 478 | return CompoundMedianHelper<1, _Spaces...>::keyDistance( 479 | traversals_, q, std::get<0>(traversals_).keyDistance(std::get<0>(q))); 480 | } 481 | 482 | template 483 | void traverse(_Nearest& nearest, unsigned axis, _Iter begin, _Iter end) { 484 | CompoundMedianHelper<0, _Spaces...>::traverse( 485 | traversals_, nearest, axis, begin, end); 486 | } 487 | }; 488 | 489 | }}}} 490 | 491 | #endif // UNC_ROBOTICS_KDTREE_COMPOUNDSPACE_HPP 492 | -------------------------------------------------------------------------------- /src/_kdtree_midpoint.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2017 Jeffrey Ichnowski 2 | // All rights reserved. 3 | // 4 | // BSD 3 Clause 5 | // 6 | // Redistribution and use in source and binary forms, with or without 7 | // modification, are permitted provided that the following conditions 8 | // are met: 9 | // 1. Redistributions of source code must retain the above copyright 10 | // notice, this list of conditions and the following disclaimer. 11 | // 2. Redistributions in binary form must reproduce the above copyright 12 | // notice, this list of conditions and the following disclaimer in the 13 | // documentation and/or other materials provided with the distribution. 14 | // 3. 
Neither the name of the copyright holder nor the names of its 15 | // contributors may be used to endorse or promote products derived 16 | // from this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 21 | // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 22 | // COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 23 | // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 26 | // HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 27 | // STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 29 | // OF THE POSSIBILITY OF SUCH DAMAGE. 30 | #pragma once 31 | #ifndef UNC_ROBOTICS_KDTREE_KDTREE_MIDPOINT_HPP 32 | #define UNC_ROBOTICS_KDTREE_KDTREE_MIDPOINT_HPP 33 | 34 | #include "_spaces.hpp" 35 | 36 | #include 37 | #include 38 | #include 39 | 40 | namespace unc { namespace robotics { namespace kdtree { 41 | 42 | // When using the intrusive version of the KDTree, the caller must 43 | // provide a member node. 
44 | template 45 | struct MidpointSplitNodeMember; 46 | 47 | template 48 | struct MidpointSplitNodeMember<_Node, SingleThread> { 49 | typedef SingleThread Locking; 50 | 51 | std::array<_Node*, 2> children_{}; 52 | 53 | MidpointSplitNodeMember(const MidpointSplitNodeMember&) = delete; 54 | 55 | constexpr _Node* child(int no) { return children_[no]; } 56 | constexpr const _Node* child(int no) const { return children_[no]; } 57 | inline constexpr bool hasChild() const { return children_[0] != children_[1]; } 58 | inline bool update(int no, _Node*, _Node* n) { 59 | children_[no] = n; 60 | return true; 61 | } 62 | }; 63 | 64 | template 65 | struct MidpointSplitNodeMember<_Node, MultiThread> { 66 | typedef MultiThread Locking; 67 | 68 | std::array, 2> children_{}; 69 | 70 | MidpointSplitNodeMember(const MidpointSplitNodeMember&) = delete; 71 | 72 | constexpr _Node* child(int no) { return children_[no].load(std::memory_order_acquire); } 73 | constexpr const _Node* child(int no) const { 74 | return children_[no].load(std::memory_order_relaxed); 75 | } 76 | inline constexpr bool hasChild() const { 77 | return children_[0].load(std::memory_order_relaxed) != children_[1].load(std::memory_order_relaxed); 78 | } 79 | inline bool update(int no, _Node*& c, _Node* n) { 80 | return children_[no].compare_exchange_weak( 81 | c, n, std::memory_order_release, std::memory_order_relaxed); 82 | } 83 | }; 84 | 85 | namespace detail { 86 | 87 | // MidpointSplitNode is used in the default non-intrusive KDTree 88 | // implementation with MidpointSplits. It extends the value type and 89 | // adds the required intrusive KDTree child members. 90 | template 91 | struct MidpointSplitNode { 92 | _T value_; 93 | 94 | MidpointSplitNodeMember, _Locking> children_{}; 95 | 96 | MidpointSplitNode(const MidpointSplitNode&) = delete; 97 | MidpointSplitNode(MidpointSplitNode&&) = delete; 98 | 99 | MidpointSplitNode(const _T& value) : value_(value) {} 100 | template 101 | MidpointSplitNode(_Args&& ... 
args) : value_(std::forward<_Args>(args)...) {} 102 | }; 103 | 104 | // This class is usually not required, and _GetKey could be used 105 | // directly instead. However, there is a chance the caller could 106 | // provide a _GetKey that would unexpectedly handle the derived 107 | // class of _T that we use in the default implementation. Both overloads apply _GetKey to node.value_. 108 | template 109 | struct MidpointSplitNodeKey : _GetKey { 110 | inline MidpointSplitNodeKey(const _GetKey& getKey) : _GetKey(getKey) {} 111 | 112 | constexpr decltype(auto) operator() (const MidpointSplitNode<_T, _Locking>& node) const { 113 | return _GetKey::operator()(static_cast(node.value_)); 114 | } 115 | 116 | constexpr decltype(auto) operator() (MidpointSplitNode<_T, _Locking>& node) const { 117 | return _GetKey::operator()(static_cast<_T&>(node.value_)); 118 | } 119 | }; 120 | 121 | template 122 | struct MidpointSplitRoot; 123 | 124 | template 125 | struct MidpointSplitRoot<_Node, SingleThread, _Allocator> : _Allocator { 126 | _Node *root_ = nullptr; 127 | std::size_t size_ = 0; 128 | 129 | MidpointSplitRoot(const _Allocator& alloc) : _Allocator(alloc) {} 130 | 131 | constexpr const _Node* get() const { 132 | return root_; 133 | } 134 | 135 | constexpr _Node* get() { 136 | return root_; 137 | } 138 | 139 | inline _Node* update(_Node *node) { 140 | _Node *root = root_; 141 | if (root_ == nullptr) 142 | root_ = node; 143 | return root; 144 | } 145 | }; 146 | 147 | template 148 | struct MidpointSplitRoot<_Node, MultiThread, _Allocator> : _Allocator { 149 | std::atomic<_Node*> root_{}; 150 | std::atomic size_{}; 151 | 152 | MidpointSplitRoot(const _Allocator& alloc) : _Allocator(alloc) {} 153 | 154 | constexpr const _Node* get() const { 155 | return root_.load(std::memory_order_relaxed); 156 | } 157 | 158 | constexpr _Node* get() { 159 | return root_.load(std::memory_order_acquire); 160 | } 161 | 162 | inline _Node* update(_Node *node) { 163 | _Node *root = root_.load(std::memory_order_acquire); 164 | while (root ==
nullptr) 165 | if (root_.compare_exchange_weak(root, node, std::memory_order_release, std::memory_order_relaxed)) 166 | return nullptr; 167 | return root; 168 | } 169 | }; 170 | 171 | template 172 | struct MidpointAxisCache; 173 | 174 | template <> 175 | struct MidpointAxisCache { 176 | unsigned axis_; 177 | MidpointAxisCache* next_; 178 | 179 | MidpointAxisCache(unsigned axis) : axis_(axis), next_(nullptr) {} 180 | 181 | constexpr MidpointAxisCache* next() { return next_; } 182 | constexpr const MidpointAxisCache* next() const { return next_; } 183 | 184 | template 185 | inline MidpointAxisCache* next(unsigned axis, _Allocator& allocator) { 186 | // return next_ = new MidpointAxisCache(axis); 187 | typedef std::allocator_traits<_Allocator> Traits; 188 | MidpointAxisCache *n = Traits::allocate(allocator, 1); 189 | Traits::construct(allocator, n, axis); 190 | return next_ = n; 191 | } 192 | }; 193 | 194 | template <> 195 | struct MidpointAxisCache { 196 | unsigned axis_; 197 | std::atomic next_{}; 198 | 199 | MidpointAxisCache(unsigned axis) : axis_(axis) {} 200 | 201 | constexpr MidpointAxisCache* next() { return next_.load(std::memory_order_acquire); } 202 | constexpr const MidpointAxisCache* next() const { return next_.load(std::memory_order_acquire); } 203 | 204 | template 205 | MidpointAxisCache* next(unsigned axis, _Allocator& allocator) { 206 | typedef std::allocator_traits<_Allocator> Traits; 207 | // MidpointAxisCache* next = new MidpointAxisCache(axis); 208 | MidpointAxisCache* next = Traits::allocate(allocator, 1); 209 | Traits::construct(allocator, next, axis); 210 | MidpointAxisCache* prev = nullptr; 211 | if (next_.compare_exchange_strong(prev, next)) 212 | return next; 213 | 214 | // other thread beat this thread to the update. 
215 | assert(prev->axis_ == axis); 216 | // delete next; 217 | Traits::destroy(allocator, next); 218 | Traits::deallocate(allocator, next, 1); 219 | return prev; 220 | } 221 | }; 222 | 223 | template < 224 | typename _Node, 225 | typename _Space, 226 | typename _GetKey, 227 | typename _Locking, 228 | MidpointSplitNodeMember<_Node, _Locking> _Node::* _member, 229 | typename _Allocator = std::allocator<_Node>> 230 | struct KDTreeMidpointSplitIntrusiveImpl 231 | { 232 | typedef _Space Space; 233 | typedef _Node Node; 234 | typedef typename Space::Distance Distance; 235 | typedef typename Space::State Key; 236 | typedef MidpointSplitNodeMember Member; 237 | typedef MidpointAxisCache<_Locking> AxisCache; 238 | 239 | typedef std::allocator_traits<_Allocator> AllocatorTraits; 240 | typedef typename AllocatorTraits::template rebind_alloc AxisCacheAllocator; 241 | 242 | Space space_; 243 | _GetKey getKey_; 244 | MidpointSplitRoot root_; 245 | AxisCache axisCache_; 246 | 247 | struct Adder : AxisCacheAllocator { 248 | MidpointAddTraversal traversal_; 249 | MidpointAxisCache<_Locking>* axisCache_; 250 | 251 | template 252 | Adder(KDTreeMidpointSplitIntrusiveImpl& tree, const _Key& key) 253 | : AxisCacheAllocator(tree.root_), // root is an allocator 254 | traversal_(tree.space_, key), 255 | axisCache_(&tree.axisCache_) 256 | { 257 | } 258 | 259 | static constexpr _Node* child(_Node *p, int childNo) { 260 | return (p->*_member).child(childNo); 261 | } 262 | 263 | static constexpr bool update(Node *p, int childNo, Node*& c, Node* n) { 264 | return (p->*_member).update(childNo, c, n); 265 | } 266 | 267 | void operator() (Node* p, Node* n) { 268 | MidpointAxisCache<_Locking>* nextAxis = axisCache_->next(); 269 | if (nextAxis == nullptr) { 270 | unsigned axis; 271 | traversal_.maxAxis(&axis); 272 | nextAxis = axisCache_->next(axis, *this); // *this is an allocator 273 | } 274 | axisCache_ = nextAxis; 275 | traversal_.addImpl(*this, axisCache_->axis_, p, n); 276 | } 277 | }; 278 | 
279 | template 280 | struct Nearest { 281 | MidpointNearestTraversal<_Node, _Space> traversal_; 282 | const KDTreeMidpointSplitIntrusiveImpl& tree_; 283 | Distance dist_; 284 | const MidpointAxisCache<_Locking>* axisCache_; 285 | 286 | template 287 | Nearest( 288 | const KDTreeMidpointSplitIntrusiveImpl& tree, 289 | const _Key& key, 290 | Distance dist = std::numeric_limits::infinity()) 291 | : traversal_(tree.space_, key), 292 | tree_(tree), 293 | dist_(dist), 294 | axisCache_(&tree.axisCache_) 295 | { 296 | } 297 | 298 | constexpr bool shouldTraverse() const { 299 | return traversal_.distToRegion() <= dist_; 300 | } 301 | 302 | static constexpr const _Node* child(const _Node* n, int no) { 303 | return (n->*_member).child(no); 304 | } 305 | 306 | inline void update(const _Node* n) { 307 | Distance d = traversal_.keyDistance(tree_.getKey_(*n)); 308 | if (d <= dist_) { 309 | static_cast<_Derived*>(this)->update(d, n); 310 | } 311 | } 312 | 313 | inline void operator() (const _Node* n) { 314 | if ((n->*_member).hasChild()) { 315 | const MidpointAxisCache<_Locking> *oldCache = axisCache_; 316 | axisCache_ = axisCache_->next(); 317 | traversal_.traverse(*this, n, axisCache_->axis_); 318 | axisCache_ = oldCache; 319 | } else { 320 | update(n); 321 | } 322 | } 323 | }; 324 | 325 | struct Nearest1 : Nearest { 326 | const _Node *nearest_ = nullptr; 327 | 328 | using Nearest::Nearest; 329 | using Nearest::dist_; 330 | 331 | inline void update(Distance d, const _Node* n) { 332 | dist_ = d; 333 | nearest_ = n; 334 | } 335 | }; 336 | 337 | template 338 | struct NearestK : Nearest> { 339 | std::vector, _ResultAllocator>& nearest_; 340 | std::size_t k_; 341 | _NodeValueFn nodeValueFn_; 342 | 343 | using Nearest::dist_; 344 | 345 | template 346 | NearestK( 347 | const KDTreeMidpointSplitIntrusiveImpl& tree, 348 | std::vector, _ResultAllocator>& result, 349 | const _Key& key, 350 | std::size_t k, 351 | Distance dist, 352 | const _NodeValueFn& nodeValueFn) 353 | : Nearest(tree, 
key, dist), 354 | nearest_(result), 355 | k_(k), 356 | nodeValueFn_(nodeValueFn) 357 | { 358 | } 359 | 360 | void update(Distance d, const _Node* n) { 361 | #if 0 362 | // R^3 l2: 4.76677 us/op 363 | // R^6 l2: 18.6769 us/op 364 | // SO(3)F: 37.8118 us/op 365 | // SO(3)A: 27.7197 us/op 366 | // SO(3)R: 105.142 us/op 367 | // SE(3)F: 161.597 us/op 368 | 369 | 370 | if (nearest_.size() == k_) { 371 | std::pop_heap(nearest_.begin(), nearest_.end(), CompareSecond()); 372 | nearest_.pop_back(); 373 | } 374 | 375 | nearest_.emplace_back(nodeValueFn_(n), d); 376 | std::push_heap(nearest_.begin(), nearest_.end(), CompareSecond()); 377 | 378 | if (nearest_.size() == k_) 379 | dist_ = nearest_[0].second; 380 | #else 381 | // R^3 l2: 4.21419 us/op 382 | // R^6 l2: 16.1458 us/op 383 | // SO(3)F: 33.538 us/op 384 | // SO(3)A: 24.9379 us/op 385 | // SO(3)R: 95.3695 us/op 386 | // SE(3)F: 145.62 us/op 387 | 388 | if (nearest_.size() < k_) { 389 | // until we've reached k_ elements, we just collect 390 | // elements in nearest_. once k is reached, we 391 | // maintain a heap. 392 | nearest_.emplace_back(nodeValueFn_(n), d); 393 | if (nearest_.size() < k_) 394 | return; 395 | std::make_heap(nearest_.begin(), nearest_.end(), CompareSecond()); 396 | } else { 397 | // slightly hacky: pop_heap() operates by first 398 | // swapping the first and last element, then by moving 399 | // the first element into position. By first placing 400 | // the new element at the end, it will be swapped into 401 | // the top, then put in the correct position. The old 402 | // top will be placed at the end, and we can then remove 403 | // it. This slightly shortens the pop_heap/push_heap 404 | // alternative at the expense of requiring nearest_ to 405 | // have a minimum capacity of k_ + 1.
406 | nearest_.emplace_back(nodeValueFn_(n), d); 407 | std::pop_heap(nearest_.begin(), nearest_.end(), CompareSecond()); 408 | nearest_.pop_back(); 409 | 410 | // std::pop_heap(nearest_.begin(), nearest_.end(), CompareSecond()); 411 | // nearest_.back() = std::make_pair(nodeValueFn_(n), d); 412 | // std::push_heap(nearest_.begin(), nearest_.end(), CompareSecond()); 413 | } 414 | 415 | dist_ = nearest_[0].second; 416 | #endif 417 | } 418 | }; 419 | 420 | 421 | KDTreeMidpointSplitIntrusiveImpl( 422 | const Space& space, 423 | const _GetKey& getKey, 424 | const _Allocator& alloc = _Allocator()) 425 | : space_(space), 426 | getKey_(getKey), 427 | root_(alloc), 428 | axisCache_(~0) 429 | { 430 | } 431 | 432 | ~KDTreeMidpointSplitIntrusiveImpl() { 433 | typedef std::allocator_traits Traits; 434 | AxisCacheAllocator alloc(root_); 435 | 436 | for (AxisCache *n, *c = axisCache_.next() ; c ; c = n) { 437 | n = c->next(); 438 | Traits::destroy(alloc, c); 439 | Traits::deallocate(alloc, c, 1); 440 | } 441 | } 442 | 443 | constexpr _Allocator& allocator() { return root_; } 444 | 445 | constexpr std::size_t size() const { 446 | return root_.size_; 447 | } 448 | 449 | template 450 | void clear(const _Destroy& destroy) { 451 | clear(root_.get(), destroy); 452 | } 453 | 454 | template 455 | void clear(Node *n, const _Destroy& destroy) { 456 | if (n) { 457 | clear((n->*_member).child(0), destroy); 458 | Node *c = (n->*_member).child(1); 459 | destroy(n); 460 | clear(c, destroy); // tail recursion 461 | } 462 | } 463 | 464 | void add(Node* node) { 465 | if (Node* root = root_.update(node)) { 466 | Adder adder(*this, getKey_(*node)); 467 | adder(root, node); 468 | } 469 | ++root_.size_; 470 | } 471 | 472 | // TODO: support non-const return on non-const call? 
473 | template 474 | const _Node* nearest(const _Key& key, Distance* distOut = nullptr) const { 475 | if (const Node* root = root_.get()) { 476 | Nearest1 nearest(*this, key); 477 | nearest(root); 478 | if (distOut) 479 | *distOut = nearest.dist_; 480 | return nearest.nearest_; 481 | } 482 | return nullptr; 483 | } 484 | 485 | template 486 | void nearest( 487 | std::vector, _ResultAllocator>& result, 488 | const _Key& key, 489 | std::size_t k, 490 | Distance maxDist, 491 | _NodeValueFn&& nodeValue) const 492 | { 493 | result.clear(); 494 | if (k == 0) 495 | return; 496 | 497 | // std::cout << "nearest = " << &result << std::endl; 498 | // result.size(); 499 | if (const Node* root = root_.get()) { 500 | NearestK<_Value, _ResultAllocator, _NodeValueFn> nearest( 501 | *this, result, key, k, maxDist, std::forward<_NodeValueFn>(nodeValue)); 502 | nearest(root); 503 | if (result.size() < k) { 504 | std::sort(result.begin(), result.end(), CompareSecond()); 505 | } else { 506 | std::sort_heap(result.begin(), result.end(), CompareSecond()); 507 | } 508 | } 509 | } 510 | 511 | template 512 | void visitAll(const Node *n, _Fn&& f) const { 513 | if (n) { 514 | f(n); 515 | visitAll((n->*_member).child(0), f); 516 | visitAll((n->*_member).child(1), f); 517 | } 518 | } 519 | 520 | template 521 | void visitAll(_Fn&& f) const { 522 | visitAll(root_.get(), f); 523 | } 524 | }; 525 | 526 | } // namespace detail 527 | 528 | template < 529 | typename _T, 530 | typename _Space, 531 | typename _GetKey, 532 | typename _Locking, 533 | typename _Allocator> 534 | struct KDTree<_T, _Space, _GetKey, MidpointSplit, DynamicBuild, _Locking, _Allocator> 535 | : private detail::KDTreeMidpointSplitIntrusiveImpl< 536 | detail::MidpointSplitNode<_T, _Locking>, 537 | _Space, 538 | detail::MidpointSplitNodeKey<_T, _Locking, _GetKey>, 539 | _Locking, 540 | &detail::MidpointSplitNode<_T, _Locking>::children_, 541 | typename std::allocator_traits<_Allocator>::template rebind_alloc< 542 | 
detail::MidpointSplitNode<_T, _Locking>>> 543 | { 544 | typedef detail::KDTreeMidpointSplitIntrusiveImpl< 545 | detail::MidpointSplitNode<_T, _Locking>, 546 | _Space, 547 | detail::MidpointSplitNodeKey<_T, _Locking, _GetKey>, 548 | _Locking, 549 | &detail::MidpointSplitNode<_T, _Locking>::children_, 550 | typename std::allocator_traits<_Allocator>::template rebind_alloc< 551 | detail::MidpointSplitNode<_T, _Locking>>> Base; 552 | 553 | typedef _Space Space; 554 | typedef typename Space::Distance Distance; 555 | typedef typename Space::State Key; 556 | typedef detail::MidpointSplitNode<_T, _Locking> Node; 557 | typedef std::allocator_traits<_Allocator> AllocatorTraits; 558 | typedef typename AllocatorTraits::template rebind_alloc NodeAllocator; 559 | typedef std::allocator_traits NodeAllocatorTraits; 560 | 561 | public: 562 | KDTree( 563 | const Space& space, 564 | const _GetKey& getKey = _GetKey(), 565 | const _Allocator& alloc = _Allocator()) 566 | : Base( 567 | space, 568 | detail::MidpointSplitNodeKey<_T, _Locking, _GetKey>(getKey), 569 | NodeAllocator(alloc)) // TODO: allocator for axiscache 570 | // nodes_(NodeAllocator(alloc)) 571 | { 572 | } 573 | 574 | ~KDTree() { 575 | clear(); 576 | } 577 | 578 | void clear() { 579 | Base::clear([&] (Node *n) { 580 | NodeAllocatorTraits::destroy(Base::allocator(), n); 581 | NodeAllocatorTraits::deallocate(Base::allocator(), n, 1); 582 | }); 583 | } 584 | 585 | void add(const _T& arg) { 586 | // Base::add(new Node(arg)); 587 | typedef detail::AllocatorDestructor Destruct; 588 | NodeAllocator& na = Base::allocator(); 589 | std::unique_ptr hold(NodeAllocatorTraits::allocate(na, 1), Destruct(na, 1)); 590 | NodeAllocatorTraits::construct(na, hold.get(), arg); 591 | Base::add(hold.get()); 592 | hold.release(); 593 | } 594 | 595 | using Base::size; 596 | 597 | template 598 | void emplace(_Args&& ... 
args) { 599 | // Base::add(new Node(std::forward<_Args>(args)...)); 600 | 601 | typedef detail::AllocatorDestructor Destruct; 602 | NodeAllocator& na = Base::allocator(); 603 | std::unique_ptr hold(NodeAllocatorTraits::allocate(na, 1), Destruct(na, 1)); 604 | NodeAllocatorTraits::construct(na, hold.get(), std::forward<_Args>(args)...); 605 | Base::add(hold.get()); 606 | hold.release(); 607 | } 608 | 609 | template 610 | const _T* nearest(const _Key& key, Distance* distOut = nullptr) const { 611 | const Node* n = Base::nearest(key, distOut); 612 | return n == nullptr ? nullptr : &n->value_; 613 | } 614 | 615 | template 616 | void nearest( 617 | std::vector, _ResultAllocator>& result, 618 | const _Key& key, 619 | std::size_t k, 620 | Distance maxDist = std::numeric_limits::infinity()) const 621 | { 622 | Base::nearest(result, key, k, maxDist, [](const Node* n) -> const auto& { return n->value_; }); 623 | } 624 | 625 | template 626 | void visitAll(_Fn&& f) const { 627 | Base::visitAll([&] (const Node* n) { f(n->value_); }); 628 | } 629 | }; 630 | 631 | 632 | }}} 633 | 634 | #endif // UNC_ROBOTICS_KDTREE_KDTREE_MIDPOINT_HPP 635 | 636 | -------------------------------------------------------------------------------- /src/_so3altspace.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2017 Jeffrey Ichnowski 2 | // All rights reserved. 3 | // 4 | // BSD 3 Clause 5 | // 6 | // Redistribution and use in source and binary forms, with or without 7 | // modification, are permitted provided that the following conditions 8 | // are met: 9 | // 1. Redistributions of source code must retain the above copyright 10 | // notice, this list of conditions and the following disclaimer. 11 | // 2. Redistributions in binary form must reproduce the above copyright 12 | // notice, this list of conditions and the following disclaimer in the 13 | // documentation and/or other materials provided with the distribution. 14 | // 3. 
Neither the name of the copyright holder nor the names of its 15 | // contributors may be used to endorse or promote products derived 16 | // from this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 21 | // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 22 | // COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 23 | // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 26 | // HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 27 | // STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 29 | // OF THE POSSIBILITY OF SUCH DAMAGE. 30 | #pragma once 31 | #ifndef UNC_ROBOTICS_KDTREE_SO3MINSPACE_HPP 32 | #define UNC_ROBOTICS_KDTREE_SO3MINSPACE_HPP 33 | 34 | // The SO(3) space with an alternative traversal strategy. This one 35 | // is considerably simpler than the full SO(3) traversal strategy as it 36 | // only requires std::max and std::asin on each midpoint. This does 37 | // not sacrifice accuracy, but it does not cull out branch traversal 38 | // as tightly as the full traversal--so given an equal overhead 39 | // between the two this should be slower. However this method has a 40 | // significantly lower overhead, and can thus sometimes be faster.
41 | 42 | namespace unc { namespace robotics { namespace kdtree { namespace detail { 43 | 44 | template 45 | struct MidpointSO3MinTraversalBase { 46 | typedef SO3AltSpace<_Scalar> Space; 47 | typedef typename Space::State Key; 48 | 49 | Eigen::Matrix<_Scalar, 4, 1> key_; 50 | std::array, 2> soBounds_; 51 | unsigned soDepth_; 52 | unsigned keyVol_; 53 | 54 | MidpointSO3MinTraversalBase(const Space& space, const Key& key) 55 | : soDepth_(2), 56 | keyVol_(so3VolumeIndex(key)) 57 | { 58 | key_ = rotateCoeffs(key.coeffs(), keyVol_ + 1); 59 | if (key_[3] < 0) 60 | key_ = -key_; 61 | 62 | soBounds_[0] = M_SQRT1_2; 63 | soBounds_[1].colwise() = Eigen::Array<_Scalar, 2, 1>(-M_SQRT1_2, M_SQRT1_2); 64 | } 65 | 66 | constexpr unsigned dimensions() const { 67 | return 3; 68 | } 69 | 70 | constexpr _Scalar maxAxis(unsigned *axis) const { 71 | *axis = soDepth_ % 3; 72 | return M_PI / (1 << (soDepth_ / 3)); 73 | } 74 | }; 75 | 76 | template 77 | struct MidpointAddTraversal<_Node, SO3AltSpace<_Scalar>> 78 | : MidpointSO3MinTraversalBase<_Scalar> 79 | { 80 | typedef _Scalar Scalar; 81 | typedef SO3AltSpace Space; 82 | typedef typename Space::State Key; 83 | 84 | using MidpointSO3MinTraversalBase<_Scalar>::soDepth_; 85 | using MidpointSO3MinTraversalBase<_Scalar>::keyVol_; 86 | using MidpointSO3MinTraversalBase<_Scalar>::key_; 87 | 88 | MidpointAddTraversal(const Space& space, const Key& key) 89 | : MidpointSO3MinTraversalBase<_Scalar>(space, key) 90 | { 91 | } 92 | 93 | template 94 | void addImpl(_Adder& adder, unsigned axis, _Node* p, _Node *n) { 95 | int childNo; 96 | _Node *c; 97 | 98 | if (soDepth_ < 3) { 99 | c = _Adder::child(p, childNo = keyVol_ & 1); 100 | while (c == nullptr) 101 | if (_Adder::update(p, childNo, c, n)) 102 | return; 103 | 104 | // if ((c = p->children_[childNo = keyVol_ & 1]) == nullptr) { 105 | // p->children_[childNo] = n; 106 | // return; 107 | // } 108 | p = c; 109 | 110 | c = _Adder::child(p, childNo = keyVol_ >> 1); 111 | while (c == nullptr) 112 | if 
(_Adder::update(p, childNo, c, n)) 113 | return; 114 | 115 | // if ((c = p->children_[childNo = keyVol_ >> 1]) == nullptr) { 116 | // p->children_[childNo] = n; 117 | // return; 118 | // } 119 | 120 | ++soDepth_; 121 | adder(c, n); 122 | } else { 123 | Eigen::Matrix mp = (this->soBounds_[0].col(axis) + this->soBounds_[1].col(axis)) 124 | .matrix().normalized(); 125 | 126 | // assert(inSoBounds(keyVol_, 0, soBounds_, key_)); 127 | // assert(inSoBounds(keyVol_, 1, soBounds_, key_)); 128 | // assert(inSoBounds(keyVol_, 2, soBounds_, key_)); 129 | 130 | Scalar dot = mp[0]*key_[3] + mp[1]*key_[axis]; 131 | // if ((c = p->children_[childNo = (dot > 0)]) == nullptr) { 132 | // p->children_[childNo] = n; 133 | // return; 134 | // } 135 | c = _Adder::child(p, childNo = (dot > 0)); 136 | while (c == nullptr) 137 | if (_Adder::update(p, childNo, c, n)) 138 | return; 139 | 140 | this->soBounds_[1-childNo].col(axis) = mp; 141 | ++soDepth_; 142 | adder(c, n); 143 | } 144 | } 145 | }; 146 | 147 | template 148 | struct MidpointNearestTraversal<_Node, SO3AltSpace<_Scalar>> 149 | : MidpointSO3MinTraversalBase<_Scalar> 150 | { 151 | typedef _Scalar Scalar; 152 | typedef SO3AltSpace<_Scalar> Space; 153 | typedef typename Space::State Key; 154 | typedef typename Space::Distance Distance; 155 | 156 | using MidpointSO3MinTraversalBase<_Scalar>::soBounds_; 157 | using MidpointSO3MinTraversalBase<_Scalar>::soDepth_; 158 | using MidpointSO3MinTraversalBase<_Scalar>::keyVol_; 159 | using MidpointSO3MinTraversalBase<_Scalar>::key_; 160 | 161 | Key origKey_; 162 | Distance distToRegionCache_ = 0; 163 | 164 | MidpointNearestTraversal(const Space& space, const Key& key) 165 | : MidpointSO3MinTraversalBase<_Scalar>(space, key), 166 | origKey_(key) 167 | { 168 | } 169 | 170 | template 171 | inline _Scalar keyDistance(const Eigen::QuaternionBase<_Derived>& q) const { 172 | _Scalar dot = std::abs(origKey_.coeffs().matrix().dot(q.coeffs().matrix())); 173 | return dot < 0 ? M_PI_2 : dot > 1 ? 
0 : std::acos(dot); 174 | } 175 | 176 | constexpr Distance distToRegion() const { 177 | return distToRegionCache_; 178 | } 179 | 180 | template 181 | inline Distance dotBounds(int b, unsigned axis, const Eigen::DenseBase<_Derived>& q) { 182 | // assert(b == 0 || b == 1); 183 | // assert(0 <= axis && axis < 3); 184 | 185 | return soBounds_[b](0, axis)*q[3] 186 | + soBounds_[b](1, axis)*q[axis]; 187 | } 188 | 189 | Distance initialBounds() { 190 | Distance d = 0; 191 | for (int a=0 ; a<3 ; ++a) { 192 | Distance d0 = dotBounds(0, a, key_); 193 | Distance d1 = dotBounds(1, a, key_); 194 | if (d0 < 0 || d1 > 0) 195 | d = std::max(d, std::min(std::abs(d0), std::abs(d1))); 196 | } 197 | return std::asin(d); 198 | } 199 | 200 | template 201 | inline void traverse(_Nearest& nearest, const _Node* n, unsigned axis) { 202 | if (soDepth_ < 3) { 203 | ++soDepth_; 204 | if (const _Node *c = _Nearest::child(n, keyVol_ & 1)) { 205 | // std::cout << c->value_.name_ << " " << soDepth_ << ".5" << std::endl; 206 | if (const _Node *g = _Nearest::child(c, keyVol_ >> 1)) { 207 | // assert(std::abs(origKey_.coeffs()[keyVol_]) == key_[3]); 208 | nearest(g); 209 | } 210 | // TODO: can we gain some efficiency by exploring the
212 | nearest.update(c); 213 | if (const _Node *g = _Nearest::child(c, 1 - (keyVol_ >> 1))) { 214 | key_ = rotateCoeffs(origKey_.coeffs(), (keyVol_ ^ 2) + 1); 215 | if (key_[3] < 0) 216 | key_ = -key_; 217 | // assert(std::abs(origKey_.coeffs()[keyVol_ ^ 2]) == key_[3]); 218 | distToRegionCache_ = initialBounds(); 219 | if (nearest.shouldTraverse()) 220 | nearest(g); 221 | } 222 | } 223 | nearest.update(n); 224 | if (const _Node *c = _Nearest::child(n, 1 - (keyVol_ & 1))) { 225 | // std::cout << c->value_.name_ << " " << soDepth_ << ".5" << std::endl; 226 | if (const _Node *g = _Nearest::child(c, keyVol_ >> 1)) { 227 | key_ = rotateCoeffs(origKey_.coeffs(), (keyVol_ ^ 1) + 1); 228 | if (key_[3] < 0) 229 | key_ = -key_; 230 | // assert(std::abs(origKey_.coeffs()[keyVol_ ^ 1]) == key_[3]); 231 | distToRegionCache_ = initialBounds(); 232 | if (nearest.shouldTraverse()) 233 | nearest(g); 234 | } 235 | nearest.update(c); 236 | if (const _Node *g = _Nearest::child(c, 1 - (keyVol_ >> 1))) { 237 | key_ = rotateCoeffs(origKey_.coeffs(), (keyVol_ ^ 3) + 1); 238 | if (key_[3] < 0) 239 | key_ = -key_; 240 | // assert(std::abs(origKey_.coeffs()[keyVol_ ^ 3]) == key_[3]); 241 | distToRegionCache_ = initialBounds(); 242 | if (nearest.shouldTraverse()) 243 | nearest(g); 244 | } 245 | } 246 | // setting vol_ to keyVol_ is only needed when part of a compound space 247 | // if (key_[vol_ = keyVol_] < 0) 248 | // key_ = -key_; 249 | distToRegionCache_ = 0; 250 | key_ = rotateCoeffs(origKey_.coeffs(), keyVol_ + 1); 251 | if (key_[3] < 0) 252 | key_ = -key_; 253 | --soDepth_; 254 | // assert(distToRegion() == 0); 255 | // assert(soDepth_ == 2); 256 | } else { 257 | Eigen::Matrix mp = (soBounds_[0].col(axis) + soBounds_[1].col(axis)) 258 | .matrix().normalized(); 259 | Scalar dot = mp[0]*key_[3] 260 | + mp[1]*key_[axis]; 261 | ++soDepth_; 262 | int childNo = (dot > 0); 263 | if (const _Node *c = _Nearest::child(n, childNo)) { 264 | Eigen::Matrix tmp = soBounds_[1-childNo].col(axis); 265 
| soBounds_[1-childNo].col(axis) = mp; 266 | // #ifdef KD_PEDANTIC 267 | // Scalar soBoundsDistNow = soBoundsDist(); 268 | // if (soBoundsDistNow + rvBoundsDistCache_ <= dist_) { 269 | // std::swap(soBoundsDistNow, soBoundsDistCache_); 270 | // #endif 271 | nearest(c); 272 | // #ifdef KD_PEDANTIC 273 | // soBoundsDistCache_ = soBoundsDistNow; 274 | // } 275 | // #endif 276 | soBounds_[1-childNo].col(axis) = tmp; 277 | } 278 | nearest.update(n); 279 | if (const _Node *c = _Nearest::child(n, 1-childNo)) { 280 | Eigen::Matrix tmp = soBounds_[childNo].col(axis); 281 | soBounds_[childNo].col(axis) = mp; 282 | Scalar distToSplit = std::asin(std::abs(dot)); 283 | Scalar oldDistToRegion = distToRegionCache_; 284 | distToRegionCache_ = std::max(oldDistToRegion, distToSplit); 285 | // distToRegionCache_ = computeDistToRegion(); 286 | if (nearest.shouldTraverse()) 287 | nearest(c); 288 | distToRegionCache_ = oldDistToRegion; 289 | soBounds_[childNo].col(axis) = tmp; 290 | } 291 | --soDepth_; 292 | } 293 | } 294 | }; 295 | 296 | template 297 | struct MedianAccum> { 298 | typedef _Scalar Scalar; 299 | typedef SO3AltSpace<_Scalar> Space; 300 | 301 | Eigen::Array min_; 302 | Eigen::Array max_; 303 | 304 | int vol_ = -1; 305 | 306 | MedianAccum(const Space& space) {} 307 | 308 | constexpr unsigned dimensions() const { 309 | return 3; 310 | } 311 | 312 | template 313 | void init(const Eigen::QuaternionBase<_Derived>& q) { 314 | if (vol_ < 0) return; 315 | for (unsigned axis = 0 ; axis<3 ; ++axis) 316 | min_.col(axis) = max_.col(axis) = projectToAxis(q, vol_, axis); 317 | } 318 | 319 | template 320 | void accum(const Eigen::QuaternionBase<_Derived>& q) { 321 | if (vol_ < 0) return; 322 | for (unsigned axis = 0 ; axis<3 ; ++axis) { 323 | Eigen::Matrix split = projectToAxis(q, vol_, axis); 324 | if (split[0] < min_(0, axis)) 325 | min_.col(axis) = split; 326 | if (split[0] > max_(0, axis)) 327 | max_.col(axis) = split; 328 | } 329 | } 330 | 331 | constexpr Scalar maxAxis(unsigned 
*axis) const { 332 | if (vol_ < 0) { 333 | *axis = 0; 334 | return M_PI; 335 | } else { 336 | // Compute: 337 | // (x_min * x_max) + (w_min * w_max) for each axis 338 | // 339 | // This is the dot product between the min and max 340 | // boundaries. By finding the minimum we find the maximum 341 | // acos distance. 342 | 343 | return (min_ * max_).colwise().sum().minCoeff(axis); 344 | } 345 | } 346 | 347 | template 348 | void partition(_Builder& builder, unsigned axis, _Iter begin, _Iter end, const _GetKey& getKey) { 349 | if (vol_ < 0) { 350 | if (std::distance(begin, end) < 4) { 351 | for (_Iter it = begin ; it != end ; ++it) 352 | _Builder::setOffset(*it, 0); 353 | return; 354 | } 355 | 356 | // radix sort into 4 partitions, one for each volume 357 | Eigen::Array counts; 358 | counts.setZero(); 359 | 360 | for (_Iter it = begin ; it != end ; ++it) 361 | counts[so3VolumeIndex(getKey(*it))]++; 362 | 363 | std::array<_Iter, 4> its; 364 | std::array<_Iter, 3> stops; 365 | its[0] = begin; 366 | for (int i=0 ; i<3 ; ++i) 367 | its[i+1] = stops[i] = its[i] + counts[i]; 368 | assert(its[3]+counts[3] == end); 369 | for (int i=0 ; i<3 ; ++i) 370 | for (int v ; its[i] != stops[i] ; ++(its[v])) 371 | if ((v = so3VolumeIndex(getKey(*its[i]))) != i) 372 | std::iter_swap(its[i], its[v]); 373 | 374 | // after sorting, organize the range s.t. the first 3 375 | // elements are roots of a tree of 4 volumes. This makes 376 | // use of the offset_ member of the union to determine 377 | // where the subtrees split. 378 | 379 | // [begin q0 end) 380 | // begin [q0 .. q2) [q2 .. end) 381 | // begin q0 (q0 .. q1) [q1 .. q2) q2 (q2 .. q3) [q3 .. end) 382 | 383 | // select the volume with the most elements to be the root 384 | // this will help balance the subtrees out.
385 | 386 | for (int i=0, v ; i<3 ; ++i) { 387 | counts.maxCoeff(&v); 388 | 389 | for (int j=0 ; j aProj = projectToAxis(getKey(a), vol_, axis); 418 | Eigen::Matrix bProj = projectToAxis(getKey(b), vol_, axis); 419 | return aProj[0] < bProj[0]; 420 | }); 421 | std::iter_swap(begin, mid); 422 | Eigen::Matrix split = projectToAxis(getKey(*begin), vol_, axis); 423 | 424 | // split[0] may be positive or negative, whereas split[1] 425 | // is always non-negative. Given that, split.norm() == 1, 426 | // we only need to store split[0] and can recompute 427 | // split[1] from it when necessary. 428 | _Builder::setSplit(*begin, split[0]); 429 | 430 | ++mid; 431 | 432 | builder(begin+1, mid); 433 | builder(mid, end); 434 | } 435 | } 436 | }; 437 | 438 | template 439 | struct MedianNearestTraversal> { 440 | typedef _Scalar Scalar; 441 | typedef SO3AltSpace Space; 442 | typedef typename Space::State Key; 443 | typedef typename Space::Distance Distance; 444 | 445 | Key key_; 446 | int keyVol_; 447 | int vol_ = -1; 448 | Distance distToRegionCache_ = 0; // defined start value: traverse() reads this before its first assignment 449 | 450 | std::array, 2> soBounds_; 451 | 452 | MedianNearestTraversal(const Space& space, const Key& key) 453 | : key_(key), 454 | keyVol_(so3VolumeIndex(key)) 455 | { 456 | soBounds_[0] = M_SQRT1_2; 457 | soBounds_[1].colwise() = Eigen::Array(-M_SQRT1_2, M_SQRT1_2); 458 | } 459 | 460 | constexpr unsigned dimensions() const { 461 | return 3; 462 | } 463 | 464 | template 465 | Distance keyDistance(const Eigen::QuaternionBase<_Derived>& q) const { 466 | Distance dot = std::abs(key_.coeffs().matrix().dot(q.coeffs().matrix())); 467 | return dot < 0 ? M_PI_2 : dot > 1 ?
0 : std::acos(dot); 468 | } 469 | 470 | constexpr Distance distToRegion() const { 471 | return distToRegionCache_; 472 | } 473 | 474 | template 475 | inline Scalar dotBounds(int b, unsigned axis, const Eigen::DenseBase<_Derived>& q) { 476 | // assert(b == 0 || b == 1); 477 | // assert(0 <= axis && axis < 3); 478 | assert(q[vol_] >= 0); 479 | return soBounds_[b](0, axis)*q[vol_] 480 | + soBounds_[b](1, axis)*q[(vol_ + axis + 1)%4]; 481 | } 482 | 483 | 484 | template 485 | void traverse(_Nearest& nearest, unsigned axis, _Iter begin, _Iter end) { 486 | if (vol_ < 0) { 487 | if (std::distance(begin, end) < 4) { 488 | for (_Iter it = begin ; it != end ; ++it) 489 | nearest.update(*it); 490 | return; 491 | } 492 | 493 | std::array<_Iter, 5> iters{{ 494 | begin + 3, 495 | begin + _Nearest::offset(begin[0]), 496 | begin + _Nearest::offset(begin[1]), 497 | begin + _Nearest::offset(begin[2]), 498 | end 499 | }}; 500 | 501 | // std::cout << "--- " << std::distance(begin, end) << " @ " << &*begin << std::endl; 502 | // for (int i=0 ; i<3 ; ++i) 503 | // std::cout << _Nearest::offset(begin[i]) << std::endl; 504 | 505 | for (int i=0 ; i<4 ; ++i) 506 | assert(std::distance(iters[i], iters[i+1]) >= 0); 507 | 508 | for (int v=0 ; v<4 ; ++v) { 509 | if (key_.coeffs()[vol_ = (keyVol_ + v)%4] < 0) 510 | key_.coeffs() = -key_.coeffs(); 511 | // TODO: add back 512 | // if (v != 0) 513 | // distToRegionCache_ = computeDistToRegion(); 514 | if (v) { 515 | Distance d = 0; 516 | for (unsigned a=0 ; a<3 ; ++a) { 517 | Distance d0 = dotBounds(0, a, key_.coeffs()); 518 | Distance d1 = dotBounds(1, a, key_.coeffs()); 519 | if (d0 < 0 || d1 > 0) 520 | d = std::max(d, std::min(-d0, d1)); // std::abs(d0), std::abs(d1))); 521 | } 522 | distToRegionCache_ = std::asin(d); 523 | } 524 | 525 | if (nearest.shouldTraverse()) { 526 | // std::cout << "q" << v << ": " << std::distance(iters[vol_], iters[vol_+1]) << std::endl; 527 | nearest(iters[vol_], iters[vol_+1]); 528 | } 529 | } 530 | vol_ = -1; 531 | 
distToRegionCache_ = 0; 532 | 533 | for (int i=0 ; i<3 ; ++i) 534 | nearest.update(begin[i]); 535 | 536 | } else { 537 | const auto& n = *begin++; 538 | 539 | _Iter mid = begin + std::distance(begin, end)/2; 540 | // std::cout << std::distance(begin, end) << " " << std::distance(begin, mid) << std::endl; 541 | assert(std::distance(begin, mid) >= 0); 542 | assert(std::distance(mid, end) >= 0); 543 | Distance q0 = key_.coeffs()[vol_]; 544 | Distance qa = key_.coeffs()[(vol_ + axis + 1)%4]; 545 | 546 | Eigen::Matrix split; 547 | split[0] = _Nearest::split(n); 548 | split[1] = std::sqrt(1 - split[0]*split[0]); 549 | 550 | Distance dot = split[0] * q0 + split[1] * qa; 551 | int childNo = (dot > 0); 552 | 553 | Eigen::Matrix tmp = soBounds_[1-childNo].col(axis); 554 | soBounds_[1-childNo].col(axis) = split; 555 | if (nearest.shouldTraverse()) { 556 | if (childNo) { 557 | nearest(begin, mid); 558 | } else { 559 | nearest(mid, end); 560 | } 561 | } 562 | soBounds_[1-childNo].col(axis) = tmp; 563 | 564 | tmp = soBounds_[childNo].col(axis); 565 | soBounds_[childNo].col(axis) = split; 566 | Scalar prevDistToRegion = distToRegionCache_; 567 | Scalar distToSplit = std::asin(std::abs(dot)); 568 | distToRegionCache_ = std::max(prevDistToRegion, distToSplit); 569 | if (nearest.shouldTraverse()) { 570 | if (childNo) { 571 | nearest(mid, end); 572 | } else { 573 | nearest(begin, mid); 574 | } 575 | } 576 | soBounds_[childNo].col(axis) = tmp; 577 | distToRegionCache_ = prevDistToRegion; 578 | 579 | nearest.update(n); 580 | } 581 | } 582 | }; 583 | 584 | }}}} 585 | 586 | #endif // UNC_ROBOTICS_KDTREE_SO3MINSPACE_HPP 587 | -------------------------------------------------------------------------------- /src/_so3space.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2017 Jeffrey Ichnowski 2 | // All rights reserved. 
3 | // 4 | // BSD 3 Clause 5 | // 6 | // Redistribution and use in source and binary forms, with or without 7 | // modification, are permitted provided that the following conditions 8 | // are met: 9 | // 1. Redistributions of source code must retain the above copyright 10 | // notice, this list of conditions and the following disclaimer. 11 | // 2. Redistributions in binary form must reproduce the above copyright 12 | // notice, this list of conditions and the following disclaimer in the 13 | // documentation and/or other materials provided with the distribution. 14 | // 3. Neither the name of the copyright holder nor the names of its 15 | // contributors may be used to endorse or promote products derived 16 | // from this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 21 | // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 22 | // COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 23 | // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 26 | // HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 27 | // STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 29 | // OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | #pragma once 31 | #ifndef UNC_ROBOTICS_KDTREE_SO3SPACE_HPP 32 | #define UNC_ROBOTICS_KDTREE_SO3SPACE_HPP 33 | 34 | namespace unc { namespace robotics { namespace kdtree { namespace detail { 35 | 36 | template 37 | unsigned so3VolumeIndex(const Eigen::MatrixBase<_Derived>& q) { 38 | unsigned index; 39 | q.array().abs().maxCoeff(&index); 40 | return index; 41 | } 42 | 43 | template 44 | unsigned so3VolumeIndex(const Eigen::QuaternionBase<_Scalar>& q) { 45 | return so3VolumeIndex(q.coeffs()); 46 | } 47 | 48 | template 49 | Eigen::Matrix rotateCoeffs(const Eigen::DenseBase<_Derived>& m, unsigned shift) { 50 | // 0: 0 1 2 3 51 | // 1: 1 2 3 0 52 | // 2: 2 3 0 1 53 | // 3: 3 0 1 2 54 | 55 | return Eigen::Matrix( 56 | m[shift%4], 57 | m[(shift+1)%4], 58 | m[(shift+2)%4], 59 | m[(shift+3)%4]); 60 | } 61 | 62 | template 63 | class SO3Region { 64 | typedef _Scalar Scalar; 65 | 66 | std::array, 2> bounds_; 67 | 68 | public: 69 | SO3Region() { 70 | bounds_[0] = M_SQRT1_2; 71 | bounds_[1].colwise() = Eigen::Array<_Scalar, 2, 1>(-M_SQRT1_2, M_SQRT1_2); 72 | } 73 | 74 | Eigen::Matrix<_Scalar, 2, 1> midPoint(unsigned axis) { 75 | return (bounds_[0].col(axis) + bounds_[1].col(axis)).matrix().normalized(); 76 | } 77 | 78 | auto operator() (int which, unsigned axis) { 79 | return bounds_[which].col(axis); 80 | } 81 | 82 | template 83 | inline _Scalar dotBounds(int which, unsigned axis, const Eigen::DenseBase<_Derived>& q) { 84 | return bounds_[which](0, axis)*q[3] 85 | + bounds_[which](1, axis)*q[axis]; 86 | } 87 | 88 | template 89 | inline _Scalar computeDistToRegion(const Eigen::MatrixBase<_Derived>& q) { 90 | int edgesToCheck = 0; 91 | 92 | // check faces 93 | for (int a0 = 0 ; a0 < 3 ; ++a0) { 94 | Eigen::Matrix dot(dotBounds(0, a0, q), dotBounds(1, a0, q)); 95 | int b0 = dot[0] >= 0; 96 | if (b0 && dot[1] <= 0) 97 | continue; // in bounds 98 | 99 | Eigen::Matrix p0 = q; 100 | p0[3] -= bounds_[b0](0, a0) * dot[b0]; 101 | p0[a0] -= bounds_[b0](1, a0) * dot[b0]; 102 | 103 | 
int a1 = (a0+1)%3; 104 | if (dotBounds(1, a1, p0) > 0 || dotBounds(0, a1, p0) < 0) { 105 | edgesToCheck |= 1 << (a0+a1); 106 | continue; // not on face with this axis 107 | } 108 | int a2 = (a0+2)%3; 109 | if (dotBounds(1, a2, p0) > 0 || dotBounds(0, a2, p0) < 0) { 110 | edgesToCheck |= 1 << (a0+a2); 111 | continue; // not on face with this axis 112 | } 113 | // the projected point is on this face, the distance to 114 | // the projected point is the closest point in the bounded 115 | // region to the query key. Use asin of the dot product 116 | // to the bounding face for the distance, instead of the 117 | // acos of the dot product to p, since p0 is not 118 | // normalized for efficiency. 119 | return std::asin(std::abs(dot[b0])); 120 | } 121 | 122 | // if the query point is within all bounds of all 3 axes, then it is within the region. 123 | if (edgesToCheck == 0) 124 | return 0; 125 | 126 | // int cornerChecked = 0; 127 | int cornersToCheck = 0; 128 | Eigen::Matrix T; 129 | T.row(0) = bounds_[0].row(0) / bounds_[0].row(1); 130 | T.row(1) = bounds_[1].row(0) / bounds_[1].row(1); 131 | 132 | // check edges 133 | // ++, +-, --, -+ for 01, 12, 20 134 | Scalar dotMax = 0; 135 | for (int a0 = 0 ; a0 < 3 ; ++a0) { 136 | int a1 = (a0 + 1)%3; 137 | int a2 = (a0 + 2)%3; 138 | 139 | if ((edgesToCheck & (1 << (a0+a1))) == 0) 140 | continue; 141 | 142 | for (int edge = 0 ; edge < 4 ; ++edge) { 143 | int b0 = edge & 1; 144 | int b1 = edge >> 1; 145 | 146 | Eigen::Matrix p1; 147 | Scalar t0 = T(b0, a0); // bounds_[b0](0, a0) / bounds_[b0](1, a0); 148 | Scalar t1 = T(b1, a1); // bounds_[b1](0, a1) / bounds_[b1](1, a1); 149 | Scalar r = q[3] - t0*q[a0] - t1*q[a1]; 150 | Scalar s = t0*t0 + t1*t1 + 1; 151 | 152 | // bounds check only requires p1[3] and p1[a2], and 153 | // p1[3] must be non-negative. If in bounds, then 154 | // [a0] and [a1] are required to compute the distance 155 | // to the edge. 
156 | p1[3] = r; 157 | // p1[a0] = -t0*r; 158 | // p1[a1] = -t1*r; 159 | p1[a2] = q[a2] * s; 160 | 161 | int b2; 162 | if ((b2 = dotBounds(0, a2, p1) >= 0) && dotBounds(1, a2, p1) <= 0) { 163 | // projection onto edge is in bounds of a2, this 164 | // point will be closer than the corners. 165 | p1[a0] = -t0*r; 166 | p1[a1] = -t1*r; 167 | dotMax = std::max(dotMax, std::abs(p1.dot(q)) / p1.norm()); 168 | continue; 169 | } 170 | if (r < 0) b2 = 1-b2; 171 | 172 | int cornerCode = 1 << ((b0 << a0) | (b1 << a1) | (b2 << a2)); 173 | cornersToCheck |= cornerCode; 174 | 175 | // if (cornerChecked & cornerCode) 176 | // continue; 177 | // cornerChecked |= cornerCode; 178 | // // edge is not in bounds, use the distance to the corner 179 | // Eigen::Matrix p2; 180 | // Scalar aw = bounds_[b0](0, a0); 181 | // Scalar ax = bounds_[b0](1, a0); 182 | // Scalar bw = bounds_[b1](0, a1); 183 | // Scalar by = bounds_[b1](1, a1); 184 | // Scalar cw = bounds_[b2](0, a2); 185 | // Scalar cz = bounds_[b2](1, a2); 186 | 187 | // p2[a0] = aw*by*cz; 188 | // p2[a1] = ax*bw*cz; 189 | // p2[a2] = ax*by*cw; 190 | // p2[ 3] = -ax*by*cz; 191 | 192 | // // // p2 should be on both bounds 193 | // // assert(std::abs(dotBounds(b0, a0, p2)) < 1e-7); 194 | // // assert(std::abs(dotBounds(b1, a1, p2)) < 1e-7); 195 | // // assert(std::abs(dotBounds(b2, a2, p2)) < 1e-7); 196 | 197 | // dotMax = std::max(dotMax, std::abs(q.dot(p2)) / p2.norm()); 198 | } 199 | } 200 | 201 | for (int i=0 ; i<8 ; ++i) { 202 | if ((cornersToCheck & (1 << i)) == 0) 203 | continue; 204 | 205 | int b0 = i&1; 206 | int b1 = (i>>1)&1; 207 | int b2 = i>>2; 208 | 209 | Eigen::Matrix p2; 210 | Scalar aw = bounds_[b0](0, 0); 211 | Scalar ax = bounds_[b0](1, 0); 212 | Scalar bw = bounds_[b1](0, 1); 213 | Scalar by = bounds_[b1](1, 1); 214 | Scalar cw = bounds_[b2](0, 2); 215 | Scalar cz = bounds_[b2](1, 2); 216 | 217 | p2[0] = aw*by*cz; 218 | p2[1] = ax*bw*cz; 219 | p2[2] = ax*by*cw; 220 | p2[3] = -ax*by*cz; 221 | 222 | // // p2 should 
be on both bounds 223 | // assert(std::abs(dotBounds(b0, a0, p2)) < 1e-7); 224 | // assert(std::abs(dotBounds(b1, a1, p2)) < 1e-7); 225 | // assert(std::abs(dotBounds(b2, a2, p2)) < 1e-7); 226 | 227 | dotMax = std::max(dotMax, std::abs(q.dot(p2)) / p2.norm()); 228 | } 229 | 230 | return std::acos(dotMax); 231 | } 232 | 233 | }; 234 | 235 | template 236 | struct MidpointSO3TraversalBase { 237 | typedef SO3Space<_Scalar> Space; 238 | typedef typename Space::State Key; 239 | 240 | Eigen::Matrix<_Scalar, 4, 1> key_; 241 | SO3Region<_Scalar> soBounds_; 242 | unsigned soDepth_; 243 | unsigned keyVol_; 244 | 245 | MidpointSO3TraversalBase(const Space& space, const Key& key) 246 | : soDepth_(2), 247 | keyVol_(so3VolumeIndex(key)) 248 | { 249 | key_ = rotateCoeffs(key.coeffs(), keyVol_ + 1); 250 | if (key_[3] < 0) 251 | key_ = -key_; 252 | } 253 | 254 | constexpr unsigned dimensions() const { 255 | return 3; 256 | } 257 | 258 | constexpr _Scalar maxAxis(unsigned *axis) const { 259 | *axis = soDepth_ % 3; 260 | return M_PI / (1 << (soDepth_ / 3)); 261 | } 262 | }; 263 | 264 | template 265 | struct MidpointAddTraversal<_Node, SO3Space<_Scalar>> 266 | : MidpointSO3TraversalBase<_Scalar> 267 | { 268 | typedef _Scalar Scalar; 269 | typedef SO3Space Space; 270 | typedef typename Space::State Key; 271 | 272 | using MidpointSO3TraversalBase<_Scalar>::soDepth_; 273 | using MidpointSO3TraversalBase<_Scalar>::keyVol_; 274 | using MidpointSO3TraversalBase<_Scalar>::key_; 275 | 276 | MidpointAddTraversal(const Space& space, const Key& key) 277 | : MidpointSO3TraversalBase<_Scalar>(space, key) 278 | { 279 | } 280 | 281 | template 282 | void addImpl(_Adder& adder, unsigned axis, _Node* p, _Node *n) { 283 | int childNo; 284 | _Node *c; 285 | 286 | if (soDepth_ < 3) { 287 | c = _Adder::child(p, childNo = keyVol_ & 1); 288 | while (c == nullptr) 289 | if (_Adder::update(p, childNo, c, n)) 290 | return; 291 | 292 | // if ((c = p->children_[childNo = keyVol_ & 1]) == nullptr) { 293 | // 
p->children_[childNo] = n; 294 | // return; 295 | // } 296 | p = c; 297 | 298 | c = _Adder::child(p, childNo = keyVol_ >> 1); 299 | while (c == nullptr) 300 | if (_Adder::update(p, childNo, c, n)) 301 | return; 302 | 303 | // if ((c = p->children_[childNo = keyVol_ >> 1]) == nullptr) { 304 | // p->children_[childNo] = n; 305 | // return; 306 | // } 307 | 308 | ++soDepth_; 309 | adder(c, n); 310 | } else { 311 | Eigen::Matrix mp = this->soBounds_.midPoint(axis); 312 | 313 | // assert(inSoBounds(keyVol_, 0, soBounds_, key_)); 314 | // assert(inSoBounds(keyVol_, 1, soBounds_, key_)); 315 | // assert(inSoBounds(keyVol_, 2, soBounds_, key_)); 316 | 317 | Scalar dot = mp[0]*key_[3] + mp[1]*key_[axis]; 318 | // if ((c = p->children_[childNo = (dot > 0)]) == nullptr) { 319 | // p->children_[childNo] = n; 320 | // return; 321 | // } 322 | c = _Adder::child(p, childNo = (dot > 0)); 323 | while (c == nullptr) 324 | if (_Adder::update(p, childNo, c, n)) 325 | return; 326 | 327 | this->soBounds_(1-childNo, axis) = mp; 328 | ++soDepth_; 329 | adder(c, n); 330 | } 331 | } 332 | }; 333 | 334 | template 335 | struct MidpointNearestTraversal<_Node, SO3Space<_Scalar>> 336 | : MidpointSO3TraversalBase<_Scalar> 337 | { 338 | typedef _Scalar Scalar; 339 | typedef SO3Space<_Scalar> Space; 340 | typedef typename Space::State Key; 341 | typedef typename Space::Distance Distance; 342 | 343 | using MidpointSO3TraversalBase<_Scalar>::soBounds_; 344 | using MidpointSO3TraversalBase<_Scalar>::soDepth_; 345 | using MidpointSO3TraversalBase<_Scalar>::keyVol_; 346 | using MidpointSO3TraversalBase<_Scalar>::key_; 347 | 348 | Key origKey_; 349 | Distance distToRegionCache_ = 0; 350 | 351 | MidpointNearestTraversal(const Space& space, const Key& key) 352 | : MidpointSO3TraversalBase<_Scalar>(space, key), 353 | origKey_(key) 354 | { 355 | } 356 | 357 | template 358 | inline _Scalar keyDistance(const Eigen::QuaternionBase<_Derived>& q) const { 359 | _Scalar dot = 
std::abs(origKey_.coeffs().matrix().dot(q.coeffs().matrix())); 360 | return dot < 0 ? M_PI_2 : dot > 1 ? 0 : std::acos(dot); 361 | } 362 | 363 | inline Distance distToRegion() const { 364 | return distToRegionCache_; 365 | } 366 | 367 | template 368 | inline void traverse(_Nearest& nearest, const _Node* n, unsigned axis) { 369 | if (soDepth_ < 3) { 370 | ++soDepth_; 371 | // 27.6 before change 372 | 373 | // 0 -> SO(3)F: 33.538 us/op 374 | // 0 -> SO(3)F: 33.3035 us/op 375 | // 1 -> SO(3)F: 33.6403 us/op 376 | // 1 -> SO(3)F: 33.2282 us/op 377 | 378 | // 0 -> SO(3)F: 44.1291 us/op 44.537 us/op 379 | // 1.0 -> 44.6697 us/op 50.3254 us/op 380 | // 1.1 -> 44.9184 us/op 45.0832 us/op 44.7288 us/op 381 | #if 1 382 | std::array roots{}; 383 | 384 | nearest.update(n); 385 | 386 | if (const _Node *c = _Nearest::child(n, 0)) { 387 | nearest.update(c); 388 | roots[0] = _Nearest::child(c, 0); 389 | roots[2] = _Nearest::child(c, 1); 390 | } 391 | if (const _Node *c = _Nearest::child(n, 1)) { 392 | nearest.update(c); 393 | roots[1] = _Nearest::child(c, 0); 394 | roots[3] = _Nearest::child(c, 1); 395 | } 396 | 397 | if (const _Node *c = roots[keyVol_]) { 398 | nearest(c); 399 | } 400 | 401 | #if 1 402 | std::array, 3> volDists; 403 | for (unsigned i=1 ; i<4 ; ++i) { 404 | unsigned vol = volDists[i-1].first = (keyVol_ + i)%4; 405 | if (roots[vol]) { 406 | if ((key_ = rotateCoeffs(origKey_.coeffs(), vol + 1))[3] < 0) 407 | key_ = -key_; 408 | volDists[i-1].second = soBounds_.computeDistToRegion(key_); 409 | } else { 410 | volDists[i-1].second = std::numeric_limits::infinity(); 411 | } 412 | } 413 | 414 | std::sort(volDists.begin(), volDists.end(), CompareSecond()); 415 | for (int i=0 ; i<3 && 416 | (distToRegionCache_ = volDists[i].second) < std::numeric_limits::infinity() ; ++i) { 417 | unsigned vol = volDists[i].first; 418 | if (nearest.shouldTraverse()) { 419 | if ((key_ = rotateCoeffs(origKey_.coeffs(), vol + 1))[3] < 0) 420 | key_ = -key_; 421 | nearest(roots[vol]); 422 | } 
423 | } 424 | #else 425 | for (unsigned i=1 ; i<4 ; ++i) { 426 | unsigned vol = (keyVol_ + i)%4; 427 | if (const _Node *c = roots[vol]) { 428 | key_ = rotateCoeffs(origKey_.coeffs(), vol + 1); 429 | if (key_[3] < 0) 430 | key_ = -key_; 431 | distToRegionCache_ = computeDistToRegion(); 432 | if (nearest.shouldTraverse()) 433 | nearest(c); 434 | } 435 | } 436 | #endif 437 | 438 | #else 439 | if (const _Node *c = _Nearest::child(n, keyVol_ & 1)) { 440 | // std::cout << c->value_.name_ << " " << soDepth_ << ".5" << std::endl; 441 | if (const _Node *g = _Nearest::child(c, keyVol_ >> 1)) { 442 | // assert(std::abs(origKey_.coeffs()[keyVol_]) == key_[3]); 443 | nearest(g); 444 | } 445 | // TODO: can we gain so efficiency by exploring the 446 | // nearest of the remaining 3 volumes first? 447 | nearest.update(c); 448 | if (const _Node *g = _Nearest::child(c, 1 - (keyVol_ >> 1))) { 449 | key_ = rotateCoeffs(origKey_.coeffs(), (keyVol_ ^ 2) + 1); 450 | if (key_[3] < 0) 451 | key_ = -key_; 452 | // assert(std::abs(origKey_.coeffs()[keyVol_ ^ 2]) == key_[3]); 453 | distToRegionCache_ = computeDistToRegion(); 454 | if (nearest.shouldTraverse()) 455 | nearest(g); 456 | } 457 | } 458 | nearest.update(n); 459 | if (const _Node *c = _Nearest::child(n, 1 - (keyVol_ & 1))) { 460 | // std::cout << c->value_.name_ << " " << soDepth_ << ".5" << std::endl; 461 | if (const _Node *g = _Nearest::child(c, keyVol_ >> 1)) { 462 | key_ = rotateCoeffs(origKey_.coeffs(), (keyVol_ ^ 1) + 1); 463 | if (key_[3] < 0) 464 | key_ = -key_; 465 | // assert(std::abs(origKey_.coeffs()[keyVol_ ^ 1]) == key_[3]); 466 | distToRegionCache_ = computeDistToRegion(); 467 | if (nearest.shouldTraverse()) 468 | nearest(g); 469 | } 470 | nearest.update(c); 471 | if (const _Node *g = _Nearest::child(c, 1 - (keyVol_ >> 1))) { 472 | key_ = rotateCoeffs(origKey_.coeffs(), (keyVol_ ^ 3) + 1); 473 | if (key_[3] < 0) 474 | key_ = -key_; 475 | // assert(std::abs(origKey_.coeffs()[keyVol_ ^ 3]) == key_[3]); 476 | 
distToRegionCache_ = computeDistToRegion(); 477 | if (nearest.shouldTraverse()) 478 | nearest(g); 479 | } 480 | } 481 | #endif 482 | 483 | // setting vol_ to keyVol_ is only needed when part of a compound space 484 | // if (key_[vol_ = keyVol_] < 0) 485 | // key_ = -key_; 486 | distToRegionCache_ = 0; 487 | key_ = rotateCoeffs(origKey_.coeffs(), keyVol_ + 1); 488 | if (key_[3] < 0) 489 | key_ = -key_; 490 | --soDepth_; 491 | // assert(distToRegion() == 0); 492 | // assert(soDepth_ == 2); 493 | } else { 494 | Eigen::Matrix mp = soBounds_.midPoint(axis); 495 | Scalar dot = mp[0]*key_[3] 496 | + mp[1]*key_[axis]; 497 | ++soDepth_; 498 | int childNo = (dot > 0); 499 | if (const _Node *c = _Nearest::child(n, childNo)) { 500 | Eigen::Matrix tmp = soBounds_(1-childNo, axis); 501 | soBounds_(1-childNo, axis) = mp; 502 | // #ifdef KD_PEDANTIC 503 | // Scalar soBoundsDistNow = soBoundsDist(); 504 | // if (soBoundsDistNow + rvBoundsDistCache_ <= dist_) { 505 | // std::swap(soBoundsDistNow, soBoundsDistCache_); 506 | // #endif 507 | nearest(c); 508 | // #ifdef KD_PEDANTIC 509 | // soBoundsDistCache_ = soBoundsDistNow; 510 | // } 511 | // #endif 512 | soBounds_(1-childNo, axis) = tmp; 513 | } 514 | nearest.update(n); 515 | if (const _Node *c = _Nearest::child(n, 1-childNo)) { 516 | Eigen::Matrix tmp = soBounds_(childNo, axis); 517 | soBounds_(childNo, axis) = mp; 518 | Scalar oldDistToRegion = distToRegionCache_; 519 | distToRegionCache_ = soBounds_.computeDistToRegion(key_); 520 | if (nearest.shouldTraverse()) 521 | nearest(c); 522 | distToRegionCache_ = oldDistToRegion; 523 | soBounds_(childNo, axis) = tmp; 524 | } 525 | --soDepth_; 526 | } 527 | } 528 | }; 529 | 530 | template 531 | Eigen::Matrix projectToAxis( 532 | const Eigen::QuaternionBase& q, int vol, int axis) 533 | { 534 | typedef typename Derived::Scalar Scalar; 535 | 536 | Eigen::Matrix vec(-q.coeffs()[(vol + 1 + axis)%4], q.coeffs()[vol]); 537 | Scalar norm = 1 / vec.norm(); 538 | if (vec[1] < 0) norm = -norm; 539 
| return vec*norm; 540 | } 541 | 542 | 543 | template 544 | struct MedianAccum> { 545 | typedef _Scalar Scalar; 546 | typedef SO3Space<_Scalar> Space; 547 | 548 | Eigen::Array min_; 549 | Eigen::Array max_; 550 | 551 | int vol_ = -1; 552 | 553 | MedianAccum(const Space& space) {} 554 | 555 | constexpr unsigned dimensions() const { 556 | return 3; 557 | } 558 | 559 | template 560 | void init(const Eigen::QuaternionBase<_Derived>& q) { 561 | if (vol_ < 0) return; 562 | for (unsigned axis = 0 ; axis<3 ; ++axis) 563 | min_.col(axis) = max_.col(axis) = projectToAxis(q, vol_, axis); 564 | } 565 | 566 | template 567 | void accum(const Eigen::QuaternionBase<_Derived>& q) { 568 | if (vol_ < 0) return; 569 | for (unsigned axis = 0 ; axis<3 ; ++axis) { 570 | Eigen::Matrix split = projectToAxis(q, vol_, axis); 571 | if (split[0] < min_(0, axis)) 572 | min_.col(axis) = split; 573 | if (split[0] > max_(0, axis)) 574 | max_.col(axis) = split; 575 | } 576 | } 577 | 578 | constexpr Scalar maxAxis(unsigned *axis) const { 579 | if (vol_ < 0) { 580 | *axis = 0; 581 | return M_PI; 582 | } else { 583 | // Compute: 584 | // (x_min * x_max) + (w_min * w_max) for wach axis 585 | // 586 | // This is the dot product between the min and max 587 | // boundaries. By finding the minimum we find the maximum 588 | // acos distance. 
589 | 590 | return (min_ * max_).colwise().sum().minCoeff(axis); 591 | } 592 | } 593 | 594 | template 595 | void partition(_Builder& builder, unsigned axis, _Iter begin, _Iter end, const _GetKey& getKey) { 596 | if (vol_ < 0) { 597 | if (std::distance(begin, end) < 4) { 598 | for (_Iter it = begin ; it != end ; ++it) 599 | _Builder::setOffset(*it, 0); 600 | return; 601 | } 602 | 603 | // radix sort into 4 partitions, one for each volume 604 | Eigen::Array counts; 605 | counts.setZero(); 606 | 607 | for (_Iter it = begin ; it != end ; ++it) 608 | counts[so3VolumeIndex(getKey(*it))]++; 609 | 610 | std::array<_Iter, 4> its; 611 | std::array<_Iter, 3> stops; 612 | its[0] = begin; 613 | for (int i=0 ; i<3 ; ++i) 614 | its[i+1] = stops[i] = its[i] + counts[i]; 615 | assert(its[3]+counts[3] == end); 616 | for (int i=0 ; i<3 ; ++i) 617 | for (int v ; its[i] != stops[i] ; ++(its[v])) 618 | if ((v = so3VolumeIndex(getKey(*its[i]))) != i) 619 | std::iter_swap(its[i], its[v]); 620 | 621 | // after sorting, organize the range s.t. the first 3 622 | // elements are roots of a tree of 4 volumes. This makes 623 | // use of the offset_ member of the union to determine 624 | // where the subtrees split. 625 | 626 | // [begin q0 end) 627 | // begin [q0 .. q2) [q2 .. end) 628 | // begin q0 (q0 .. q1) [q1 .. q2) q2 (q2 .. q3) [q3 .. end) 629 | 630 | // select the volume with the most elements to be the root 631 | // this will help balance the subtrees out. 632 | 633 | for (int i=0, v ; i<3 ; ++i) { 634 | counts.maxCoeff(&v); 635 | 636 | for (int j=0 ; j aProj = projectToAxis(getKey(a), vol_, axis); 664 | Eigen::Matrix bProj = projectToAxis(getKey(b), vol_, axis); 665 | return aProj[0] > bProj[0]; 666 | }); 667 | std::iter_swap(begin, mid); 668 | Eigen::Matrix split = projectToAxis(getKey(*begin), vol_, axis); 669 | 670 | // split[0] may be positive or negative, whereas split[1] 671 | // is always non-negative. 
Given that, split.norm() == 1, 672 | // we only need to store split[0] and can recomput 673 | // split[1] from it when necessary. 674 | _Builder::setSplit(*begin, split[0]); 675 | 676 | ++mid; 677 | 678 | builder(begin+1, mid); 679 | builder(mid, end); 680 | } 681 | } 682 | }; 683 | 684 | template 685 | struct MedianNearestTraversal> { 686 | typedef _Scalar Scalar; 687 | typedef SO3Space Space; 688 | typedef typename Space::State Key; 689 | typedef typename Space::Distance Distance; 690 | 691 | Key origKey_; 692 | Eigen::Matrix<_Scalar, 4, 1> key_; 693 | int keyVol_; 694 | Distance distToRegionCache_{0}; 695 | bool atRoot_{true}; 696 | SO3Region<_Scalar> soBounds_; 697 | 698 | MedianNearestTraversal(const Space& space, const Key& key) 699 | : origKey_(key), 700 | keyVol_(so3VolumeIndex(key)) 701 | { 702 | } 703 | 704 | constexpr unsigned dimensions() const { 705 | return 3; 706 | } 707 | 708 | template 709 | Distance keyDistance(const Eigen::QuaternionBase<_Derived>& q) const { 710 | Distance dot = std::abs(origKey_.coeffs().matrix().dot(q.coeffs().matrix())); 711 | return dot < 0 ? M_PI_2 : dot > 1 ? 
0 : std::acos(dot); 712 | } 713 | 714 | constexpr Distance distToRegion() const { 715 | return distToRegionCache_; 716 | } 717 | 718 | template 719 | void traverse(_Nearest& nearest, unsigned axis, _Iter begin, _Iter end) { 720 | if (atRoot_) { 721 | if (std::distance(begin, end) < 4) { 722 | for (_Iter it = begin ; it != end ; ++it) 723 | nearest.update(*it); 724 | return; 725 | } 726 | 727 | atRoot_ = false; 728 | std::array<_Iter, 5> iters{{ 729 | begin + 3, 730 | begin + _Nearest::offset(begin[0]), 731 | begin + _Nearest::offset(begin[1]), 732 | begin + _Nearest::offset(begin[2]), 733 | end 734 | }}; 735 | 736 | // std::cout << "--- " << std::distance(begin, end) << " @ " << &*begin << std::endl; 737 | // for (int i=0 ; i<3 ; ++i) 738 | // std::cout << _Nearest::offset(begin[i]) << std::endl; 739 | 740 | if ((key_ = rotateCoeffs(origKey_.coeffs(), keyVol_ + 1))[3] < 0) 741 | key_ = -key_; 742 | 743 | nearest(iters[keyVol_], iters[keyVol_+1]); 744 | 745 | for (unsigned v=1 ; v<4 ; ++v) { 746 | unsigned vol = (keyVol_ + v)%4; 747 | if ((key_ = rotateCoeffs(origKey_.coeffs(), vol + 1))[3] < 0) 748 | key_ = -key_; 749 | distToRegionCache_ = soBounds_.computeDistToRegion(key_); 750 | if (nearest.shouldTraverse()) { 751 | // std::cout << "q" << v << ": " << std::distance(iters[vol_], iters[vol_+1]) << std::endl; 752 | nearest(iters[vol], iters[vol+1]); 753 | } 754 | } 755 | distToRegionCache_ = 0; 756 | 757 | for (int i=0 ; i<3 ; ++i) 758 | nearest.update(begin[i]); 759 | 760 | atRoot_ = true; 761 | } else { 762 | const auto& n = *begin++; 763 | 764 | _Iter mid = begin + std::distance(begin, end)/2; 765 | // std::cout << std::distance(begin, end) << " " << std::distance(begin, mid) << std::endl; 766 | assert(std::distance(begin, mid) >= 0); 767 | assert(std::distance(mid, end) >= 0); 768 | 769 | std::array<_Iter,3> iters{{ begin, mid, end }}; 770 | Eigen::Matrix split; 771 | split[0] = _Nearest::split(n); 772 | split[1] = std::sqrt(1 - split[0]*split[0]); 773 | 774 | 
Distance dot = split[0]*key_[3] + split[1]*key_[axis]; 775 | int childNo = (dot > 0); 776 | 777 | nearest.update(n); 778 | 779 | if (iters[childNo] != iters[childNo+1]) { 780 | Eigen::Matrix tmp = soBounds_(1-childNo, axis); 781 | soBounds_(1-childNo, axis) = split; 782 | nearest(iters[childNo], iters[childNo+1]); // 0 -> (0, 1), 1 -> (1, 2) 783 | soBounds_(1-childNo, axis) = tmp; 784 | } 785 | 786 | if (iters[1-childNo] != iters[2-childNo]) { 787 | Eigen::Matrix tmp = soBounds_(childNo, axis); 788 | soBounds_(childNo, axis) = split; 789 | Scalar prevDistToRegion = distToRegionCache_; 790 | distToRegionCache_ = soBounds_.computeDistToRegion(key_); 791 | if (nearest.shouldTraverse()) 792 | nearest(iters[1-childNo], iters[2-childNo]); // 0 -> (1, 2), 1 -> (0, 1) 793 | soBounds_(childNo, axis) = tmp; 794 | distToRegionCache_ = prevDistToRegion; 795 | } 796 | } 797 | } 798 | }; 799 | 800 | }}}} 801 | 802 | #endif // UNC_ROBOTICS_KDTREE_SO3SPACE_HPP 803 | --------------------------------------------------------------------------------