├── docs └── images │ └── clustered_results.png ├── scripts └── install_dependencies.bash ├── include └── dbscan │ ├── dbscan.hpp │ ├── Utility.hpp │ ├── DBSCAN.hpp │ ├── KDTree.hpp │ └── impl │ ├── DBSCANImpl.ipp │ └── KDTreeImpl.ipp ├── CPPLINT.cfg ├── cmake ├── googletest-download.cmake └── googletest.cmake ├── cuda ├── src │ ├── CudaUtils.cu │ ├── Types.cuh │ └── DBSCAN.cu ├── include │ └── dbscan │ │ ├── CudaUtils.cuh │ │ └── DBSCAN.cuh └── CMakeLists.txt ├── tests ├── CMakeLists.txt ├── TestDBSCAN.cpp └── TestKDTree.cpp ├── Makefile ├── dockerfiles ├── ubuntu2004.dockerfile └── ubuntu2004_gpu.dockerfile ├── CMakeLists.txt ├── .gitignore ├── .github └── workflows │ └── build.yml ├── examples ├── TestDBSCAN.cpp ├── CMakeLists.txt ├── DBSCANPointCloudGPU.cpp ├── AppUtility.hpp └── DBSCANPointCloud.cpp ├── .clang-format └── README.md /docs/images/clustered_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmba15/generic_dbscan/HEAD/docs/images/clustered_results.png -------------------------------------------------------------------------------- /scripts/install_dependencies.bash: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | sudo -l 4 | 5 | sudo apt-get install -y --no-install-recommends libpcl-dev 6 | -------------------------------------------------------------------------------- /include/dbscan/dbscan.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file dbscan.hpp 3 | * 4 | * @author btran 5 | * 6 | */ 7 | 8 | #pragma once 9 | 10 | #include 11 | #include 12 | #include 13 | -------------------------------------------------------------------------------- /CPPLINT.cfg: -------------------------------------------------------------------------------- 1 | set noparent 2 | 3 | filter=-legal/copyright 4 | filter=-whitespace/braces 5 | filter=-build/header_guard 6 | 
filter=+build/pragma_once 7 | filter=-runtime/references 8 | filter=-build/include_order 9 | filter=-build/include_subdir 10 | linelength=120 11 | -------------------------------------------------------------------------------- /cmake/googletest-download.cmake: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.8) 2 | 3 | project(googletest-download NONE) 4 | 5 | include(ExternalProject) 6 | 7 | ExternalProject_Add( 8 | googletest 9 | SOURCE_DIR "@GOOGLETEST_DOWNLOAD_ROOT@/googletest-src" 10 | BINARY_DIR "@GOOGLETEST_DOWNLOAD_ROOT@/googletest-build" 11 | GIT_REPOSITORY 12 | https://github.com/google/googletest.git 13 | GIT_TAG 14 | release-1.8.0 15 | CONFIGURE_COMMAND "" 16 | BUILD_COMMAND "" 17 | INSTALL_COMMAND "" 18 | TEST_COMMAND "" 19 | ) 20 | -------------------------------------------------------------------------------- /cuda/src/CudaUtils.cu: -------------------------------------------------------------------------------- 1 | /** 2 | * @file CudaUtils.cu 3 | * 4 | * @author btran 5 | * 6 | */ 7 | 8 | #include 9 | 10 | namespace cuda 11 | { 12 | namespace utils 13 | { 14 | namespace 15 | { 16 | __global__ void warmUpGPUKernel() 17 | { 18 | int idx = blockIdx.x * blockDim.x + threadIdx.x; 19 | ++idx; 20 | } 21 | } // namespace 22 | 23 | cudaError_t warmUpGPU() 24 | { 25 | warmUpGPUKernel<<<1, 1>>>(); 26 | cudaDeviceSynchronize(); 27 | return cudaGetLastError(); 28 | } 29 | } // namespace utils 30 | } // namespace cuda 31 | -------------------------------------------------------------------------------- /tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if (NOT TARGET gtest) 2 | include(googletest) 3 | __fetch_googletest( 4 | ${PROJECT_SOURCE_DIR}/cmake 5 | ${PROJECT_BINARY_DIR}/${PROJECT_NAME}_googletest 6 | ) 7 | endif() 8 | 9 | add_executable( 10 | ${PROJECT_NAME}_unit_tests 11 | TestKDTree.cpp 12 | TestDBSCAN.cpp 13 | ) 14 | 15 | 
target_link_libraries( 16 | ${PROJECT_NAME}_unit_tests 17 | PUBLIC 18 | ${PROJECT_NAME} 19 | PRIVATE 20 | gtest_main 21 | ) 22 | 23 | add_test( 24 | NAME 25 | ${PROJECT_NAME}_unit_tests 26 | COMMAND 27 | $ 28 | ) 29 | -------------------------------------------------------------------------------- /cuda/include/dbscan/CudaUtils.cuh: -------------------------------------------------------------------------------- 1 | /** 2 | * @file CudaUtils.cuh 3 | * 4 | * @author btran 5 | * 6 | */ 7 | 8 | #pragma once 9 | 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | namespace cuda 16 | { 17 | namespace utils 18 | { 19 | cudaError_t warmUpGPU(); 20 | } // namespace utils 21 | } // namespace cuda 22 | 23 | inline void HandleError(cudaError_t err, const char* file, int line) 24 | { 25 | if (err != cudaSuccess) { 26 | printf("%s in %s at line %d\n", cudaGetErrorString(err), file, line); 27 | exit(EXIT_FAILURE); 28 | } 29 | } 30 | 31 | #define HANDLE_ERROR(err) (HandleError(err, __FILE__, __LINE__)) 32 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | UTEST=OFF 2 | BUILD_EXAMPLES=OFF 3 | GPU=OFF 4 | CMAKE_ARGS:=$(CMAKE_ARGS) 5 | 6 | # recursive invocations go through $(MAKE) so that -j and other flags reach the sub-make 7 | default: 8 | @mkdir -p build 9 | @cd build && cmake .. -DBUILDING_TEST=$(UTEST) -DBUILD_EXAMPLES=$(BUILD_EXAMPLES) -DUSE_GPU=$(GPU) -DCMAKE_BUILD_TYPE=Release $(CMAKE_ARGS) && $(MAKE) 10 | 11 | debug: 12 | @mkdir -p build 13 | @cd build && cmake .. 
-DBUILDING_TEST=$(UTEST) -DBUILD_EXAMPLES=$(BUILD_EXAMPLES) -DUSE_GPU=$(GPU) -DCMAKE_BUILD_TYPE=Debug $(CMAKE_ARGS) && $(MAKE) 14 | 15 | apps: 16 | @$(MAKE) default BUILD_EXAMPLES=ON 17 | 18 | gpu_apps: 19 | @$(MAKE) apps GPU=ON 20 | 21 | debug_apps: 22 | @$(MAKE) debug BUILD_EXAMPLES=ON 23 | 24 | unittest: 25 | @$(MAKE) default UTEST=ON 26 | @cd build && $(MAKE) test 27 | 28 | clean: 29 | @rm -rf build* 30 | -------------------------------------------------------------------------------- /dockerfiles/ubuntu2004.dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:20.04 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive 4 | 5 | WORKDIR /build 6 | COPY ./scripts/install_dependencies.bash . 7 | 8 | RUN apt-get update && \ 9 | apt-get install -y --no-install-recommends \ 10 | sudo \ 11 | gnupg2 \ 12 | lsb-release \ 13 | build-essential \ 14 | software-properties-common \ 15 | cmake \ 16 | git \ 17 | tmux && \ 18 | chmod +x ./install_dependencies.bash && \ 19 | ./install_dependencies.bash && \ 20 | rm -rf /build && \ 21 | apt-get clean && \ 22 | rm -rf /var/lib/apt/lists/* 23 | 24 | WORKDIR /workspace 25 | 26 | ENTRYPOINT ["/bin/bash"] 27 | -------------------------------------------------------------------------------- /dockerfiles/ubuntu2004_gpu.dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.4.0-devel-ubuntu20.04 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive 4 | 5 | WORKDIR /build 6 | COPY ./scripts/install_dependencies.bash . 
7 | 8 | RUN apt-get update && \ 9 | apt-get install -y --no-install-recommends \ 10 | sudo \ 11 | gnupg2 \ 12 | lsb-release \ 13 | build-essential \ 14 | software-properties-common \ 15 | cmake \ 16 | git \ 17 | tmux && \ 18 | chmod +x ./install_dependencies.bash && \ 19 | ./install_dependencies.bash && \ 20 | rm -rf /build && \ 21 | apt-get clean && \ 22 | rm -rf /var/lib/apt/lists/* 23 | 24 | WORKDIR /workspace 25 | 26 | ENTRYPOINT ["/bin/bash"] 27 | -------------------------------------------------------------------------------- /cuda/include/dbscan/DBSCAN.cuh: -------------------------------------------------------------------------------- 1 | /** 2 | * @file DBSCAN.cuh 3 | * 4 | * @author btran 5 | * 6 | */ 7 | 8 | #pragma once 9 | 10 | #include 11 | 12 | namespace clustering 13 | { 14 | namespace cuda 15 | { 16 | template class DBSCAN 17 | { 18 | public: 19 | struct Param { 20 | int pointDimension; 21 | double eps; 22 | int minPoints; 23 | }; 24 | 25 | explicit DBSCAN(const Param& param); 26 | ~DBSCAN(); 27 | 28 | std::vector> run(const PointType*, int numPoints) const; 29 | 30 | private: 31 | Param m_param; 32 | 33 | mutable PointType* m_dPoints; 34 | mutable int m_allocatedSize; 35 | }; 36 | } // namespace cuda 37 | } // namespace clustering 38 | -------------------------------------------------------------------------------- /cmake/googletest.cmake: -------------------------------------------------------------------------------- 1 | function(__fetch_googletest download_module_path download_root) 2 | set(GOOGLETEST_DOWNLOAD_ROOT ${download_root}) 3 | configure_file( 4 | ${download_module_path}/googletest-download.cmake 5 | ${download_root}/CMakeLists.txt 6 | @ONLY 7 | ) 8 | unset(GOOGLETEST_DOWNLOAD_ROOT) 9 | 10 | execute_process( 11 | COMMAND 12 | "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . 13 | WORKING_DIRECTORY 14 | ${download_root} 15 | ) 16 | execute_process( 17 | COMMAND 18 | "${CMAKE_COMMAND}" --build . 
19 | WORKING_DIRECTORY 20 | ${download_root} 21 | ) 22 | 23 | add_subdirectory( 24 | ${download_root}/googletest-src 25 | ${download_root}/googletest-build 26 | ) 27 | endfunction() 28 | -------------------------------------------------------------------------------- /cuda/src/Types.cuh: -------------------------------------------------------------------------------- 1 | /** 2 | * @file Types.cuh 3 | * 4 | * @author btran 5 | * 6 | */ 7 | 8 | #pragma once 9 | 10 | #include 11 | 12 | namespace clustering 13 | { 14 | namespace cuda 15 | { 16 | enum class NodeType : int { CORE, NOISE }; 17 | 18 | struct Node { 19 | __host__ __device__ Node() 20 | : type(NodeType::NOISE) 21 | , numNeighbors(0) 22 | , visited(false) 23 | { 24 | } 25 | 26 | NodeType type; 27 | int numNeighbors; 28 | char visited; 29 | }; 30 | 31 | struct Graph { 32 | thrust::device_vector nodes; 33 | thrust::device_vector neighborStartIndices; 34 | thrust::device_vector adjList; 35 | }; 36 | } // namespace cuda 37 | } // namespace clustering 38 | -------------------------------------------------------------------------------- /cuda/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 2 | 3 | project(dbscan_cuda) 4 | 5 | set (CMAKE_CXX_STANDARD 17) 6 | 7 | find_package(PCL REQUIRED) 8 | 9 | add_library(dbscan_cuda 10 | SHARED 11 | ${CMAKE_CURRENT_SOURCE_DIR}/src/DBSCAN.cu 12 | ${CMAKE_CURRENT_SOURCE_DIR}/src/CudaUtils.cu 13 | ) 14 | 15 | target_include_directories(dbscan_cuda 16 | SYSTEM PUBLIC 17 | $ 18 | $ 19 | ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} 20 | PRIVATE 21 | ${PCL_INCLUDE_DIRS} 22 | ) 23 | 24 | target_link_libraries(dbscan_cuda 25 | PRIVATE 26 | ${PCL_LIBRARIES} 27 | ) 28 | 29 | target_compile_options(dbscan_cuda 30 | PRIVATE 31 | $<$:-expt-extended-lambda -Xcompiler -fPIC -Xcudafe --diag_suppress=esa_on_defaulted_function_ignored> 32 | ) 33 | 
-------------------------------------------------------------------------------- /include/dbscan/Utility.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file Utility.hpp 3 | * 4 | * @author btran 5 | * 6 | */ 7 | 8 | #pragma once 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | namespace clustering 15 | { 16 | /** 17 | * @brief bounds-checked access to the coordinate of a point along one axis 18 | * 19 | * @throws std::runtime_error if axis is negative or not smaller than POINT_DIMENSION 20 | */ 21 | template double at(const POINT_TYPE& p, const int axis) 22 | { 23 | if (axis < 0 || axis >= POINT_DIMENSION) { 24 | throw std::runtime_error("axis out of range"); 25 | } 26 | 27 | return p[axis]; 28 | } 29 | 30 | /** 31 | * @brief l2 (euclidean) distance between two points; valueAtFunc reads the coordinate of a point along an axis 32 | */ 33 | template 34 | double distance(const POINT_TYPE& p1, const POINT_TYPE& p2, 35 | const std::function& valueAtFunc = 36 | clustering::at) 37 | { 38 | double result = 0.0; 39 | for (int i = 0; i < POINT_DIMENSION; ++i) { 40 | const double diff = valueAtFunc(p1, i) - valueAtFunc(p2, i); 41 | result += diff * diff; 42 | } 43 | return std::sqrt(result); 44 | } 45 | } // namespace clustering 46 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 2 | 3 | project(dbscan VERSION 0.0.1) 4 | 5 | set (CMAKE_CXX_STANDARD 17) 6 | 7 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) 8 | 9 | # ------------------------------------------------ 10 | # cpu library 11 | # ------------------------------------------------ 12 | add_library(dbscan 13 | INTERFACE 14 | ) 15 | 16 | target_include_directories(dbscan 17 | INTERFACE 18 | $ 19 | $ 20 | ) 21 | 22 | target_compile_options(dbscan 23 | INTERFACE 24 | $<$:-O0 -g -Wall -Werror> 25 | $<$:-O3> 26 | ) 27 | 28 | # ------------------------------------------------ 29 | # gpu library 30 | # ------------------------------------------------ 31 | if(USE_GPU) 32 | include(CheckLanguage) 33 | check_language(CUDA) 34 | if (CMAKE_CUDA_COMPILER) 35 | enable_language(CUDA) 36 
| add_subdirectory(cuda) 37 | else() 38 | message(STATUS "CUDA not found") 39 | return() 40 | endif() 41 | endif() 42 | 43 | if(BUILD_EXAMPLES) 44 | add_subdirectory(examples) 45 | endif(BUILD_EXAMPLES) 46 | 47 | if(BUILDING_TEST) 48 | enable_testing() 49 | add_subdirectory(tests) 50 | endif() 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled source # 2 | ################### 3 | *.com 4 | *.class 5 | *.dll 6 | *.exe 7 | *.o 8 | *.so 9 | *.pyc 10 | .ipynb_checkpoints 11 | *~ 12 | *# 13 | build* 14 | 15 | # Packages # 16 | ################### 17 | # it's better to unpack these files and commit the raw source 18 | # git has its own built in compression methods 19 | *.7z 20 | *.dmg 21 | *.gz 22 | *.iso 23 | *.jar 24 | *.rar 25 | *.tar 26 | *.zip 27 | 28 | # Logs and databases # 29 | ###################### 30 | *.log 31 | *.sql 32 | *.sqlite 33 | 34 | # OS generated files # 35 | ###################### 36 | .DS_Store 37 | .DS_Store? 
38 | ._* 39 | .Spotlight-V100 40 | .Trashes 41 | ehthumbs.db 42 | Thumbs.db 43 | 44 | # Images 45 | ###################### 46 | *.jpg 47 | *.gif 48 | *.png 49 | *.svg 50 | *.ico 51 | 52 | # Video 53 | ###################### 54 | *.wmv 55 | *.mpg 56 | *.mpeg 57 | *.mp4 58 | *.mov 59 | *.flv 60 | *.avi 61 | *.ogv 62 | *.ogg 63 | *.webm 64 | 65 | # Audio 66 | ###################### 67 | *.wav 68 | *.mp3 69 | *.wma 70 | 71 | # Fonts 72 | ###################### 73 | Fonts 74 | *.eot 75 | *.ttf 76 | *.woff 77 | 78 | # Format 79 | ###################### 80 | CPPLINT.cfg 81 | .clang-format 82 | 83 | # Gtags 84 | ###################### 85 | GPATH 86 | GRTAGS 87 | GSYMS 88 | GTAGS 89 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | branches: [ "master" ] 6 | pull_request: 7 | branches: [ "master" ] 8 | 9 | jobs: 10 | test-production: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | 16 | - name: Setup python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: '3.x' 20 | 21 | - name: Apply cpplint 22 | run: | 23 | pip install cpplint 24 | find ./ -iname *.hpp -o -iname *.cpp | xargs cpplint 25 | - name: Apply clang-format 26 | run: | 27 | sudo apt-get update 28 | sudo apt-get install -y --no-install-recommends clang-format 29 | find ./ -iname *.hpp -o -iname *.cpp | xargs clang-format -n --verbose --Werror 30 | - name: Setup environment 31 | run: | 32 | sudo apt-get update 33 | sudo apt-get install -y --no-install-recommends software-properties-common build-essential cmake git 34 | chmod +x scripts/install_dependencies.bash && bash scripts/install_dependencies.bash 35 | - name: Run tests 36 | run: | 37 | make unittest 38 | 39 | - name: Run apps 40 | run: | 41 | make apps -j`nproc` 42 | ./build/examples/test_pointcloud_clustering 
./data/street_no_ground.pcd 0.7 3 0 43 | -------------------------------------------------------------------------------- /examples/TestDBSCAN.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file TestDBSCAN.cpp 3 | * 4 | * @author btran 5 | * 6 | */ 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | namespace 14 | { 15 | class RandDouble 16 | { 17 | public: 18 | RandDouble(const double low, const double high) 19 | : m_genFunc(std::bind(std::uniform_real_distribution<>(low, high), std::default_random_engine())) 20 | { 21 | } 22 | 23 | double operator()() const 24 | { 25 | return m_genFunc(); 26 | } 27 | 28 | private: 29 | std::function m_genFunc; 30 | }; 31 | 32 | class Point3D : public std::array // 3d point stored as {x, y, z} 33 | { 34 | public: 35 | Point3D(const double x, const double y, const double z) 36 | { 37 | (*this)[0] = x; 38 | (*this)[1] = y; 39 | (*this)[2] = z; 40 | } 41 | }; 42 | 43 | using DBSCAN = clustering::DBSCAN<3, Point3D>; 44 | } // namespace 45 | 46 | int main(int argc, char* argv[]) 47 | { 48 | const double eps = 10; 49 | const int minPoints = 3; 50 | 51 | const int numPoints = 24000; 52 | ::DBSCAN::VecPointType points; 53 | points.reserve(numPoints); 54 | { 55 | auto doubleGenerator = ::RandDouble(0.0, 2000.0); 56 | for (int i = 0; i < numPoints; ++i) { 57 | points.emplace_back(doubleGenerator(), doubleGenerator(), doubleGenerator()); 58 | } 59 | } 60 | ::DBSCAN dbscan(eps, minPoints); 61 | dbscan.estimateClusterIndices(points); 62 | 63 | return EXIT_SUCCESS; 64 | } 65 | -------------------------------------------------------------------------------- /examples/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 2 | 3 | # --------------------------------------------------------------- 4 | # test simple dbscan example 5 | # --------------------------------------------------------------- 6 | 7 | add_executable(test_simple_dbscan 
8 | ${CMAKE_CURRENT_LIST_DIR}/TestDBSCAN.cpp 9 | ) 10 | 11 | target_link_libraries(test_simple_dbscan 12 | PUBLIC 13 | ${PROJECT_NAME} 14 | ) 15 | 16 | # --------------------------------------------------------------- 17 | # test point cloud clustering with dbscan 18 | # --------------------------------------------------------------- 19 | 20 | find_package(PCL REQUIRED) 21 | 22 | add_executable(test_pointcloud_clustering 23 | ${CMAKE_CURRENT_LIST_DIR}/DBSCANPointCloud.cpp 24 | ) 25 | 26 | target_link_libraries(test_pointcloud_clustering 27 | PUBLIC 28 | ${PROJECT_NAME} 29 | ${PCL_LIBRARIES} 30 | ) 31 | 32 | target_include_directories(test_pointcloud_clustering 33 | SYSTEM PUBLIC 34 | ${PCL_INCLUDE_DIRS} 35 | ) 36 | 37 | # --------------------------------------------------------------- 38 | # test point cloud clustering with gpu dbscan 39 | # --------------------------------------------------------------- 40 | 41 | if(NOT CMAKE_CUDA_COMPILER OR NOT USE_GPU) 42 | return() 43 | endif() 44 | 45 | add_executable(test_pointcloud_clustering_gpu 46 | ${CMAKE_CURRENT_LIST_DIR}/DBSCANPointCloudGPU.cpp 47 | ) 48 | 49 | target_link_libraries(test_pointcloud_clustering_gpu 50 | PUBLIC 51 | dbscan_cuda 52 | ${PCL_LIBRARIES} 53 | ) 54 | 55 | target_include_directories(test_pointcloud_clustering_gpu 56 | SYSTEM PUBLIC 57 | ${PCL_INCLUDE_DIRS} 58 | ) 59 | -------------------------------------------------------------------------------- /tests/TestDBSCAN.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file TestDBSCAN.cpp 3 | * 4 | * @author btran 5 | * 6 | * Copyright (c) organization 7 | * 8 | */ 9 | 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | #include 16 | 17 | namespace 18 | { 19 | class RandDouble 20 | { 21 | public: 22 | RandDouble(const double low, const double high) 23 | : m_genFunc(std::bind(std::uniform_real_distribution<>(low, high), std::default_random_engine())) 24 | { 25 | } 26 | 27 | double operator()() 
const 28 | { 29 | return m_genFunc(); 30 | } 31 | 32 | private: 33 | std::function m_genFunc; 34 | }; 35 | 36 | class Point3D : public std::array // 3d point stored as {x, y, z} 37 | { 38 | public: 39 | Point3D(const double x, const double y, const double z) 40 | { 41 | (*this)[0] = x; 42 | (*this)[1] = y; 43 | (*this)[2] = z; 44 | } 45 | }; 46 | 47 | using DBSCAN = clustering::DBSCAN<3, Point3D>; 48 | } // namespace 49 | 50 | TEST(TestDBSCAN, TestInitialization) 51 | { 52 | const double eps = 0.2; 53 | const int minPoints = 3; 54 | 55 | ASSERT_NO_THROW(::DBSCAN dbscan(eps, minPoints)); 56 | } 57 | 58 | TEST(TestDBSCAN, TestClustering) 59 | { 60 | const double eps = 5; 61 | const int minPoints = 3; 62 | 63 | const int numPoints = 1000; 64 | ::DBSCAN::VecPointType points; 65 | points.reserve(numPoints); 66 | { 67 | auto doubleGenerator = ::RandDouble(0.0, 200.0); 68 | for (int i = 0; i < numPoints; ++i) { 69 | points.emplace_back(doubleGenerator(), doubleGenerator(), doubleGenerator()); 70 | } 71 | } 72 | 73 | ::DBSCAN dbscan(eps, minPoints); 74 | ASSERT_NO_THROW(dbscan.estimateClusterIndices(points)); 75 | ASSERT_NO_THROW(dbscan.estimateClusters(points)); 76 | } 77 | -------------------------------------------------------------------------------- /include/dbscan/DBSCAN.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file DBSCAN.hpp 3 | * 4 | * @author btran 5 | * 6 | * Copyright (c) organization 7 | * 8 | */ 9 | 10 | #pragma once 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include 18 | #include 19 | 20 | namespace clustering 21 | { 22 | template > class DBSCAN 23 | { 24 | public: 25 | using PointType = POINT_TYPE; 26 | using VecPointType = VEC_POINT_TYPE; 27 | 28 | using Ptr = std::shared_ptr; 29 | 30 | using ValueAtFunctionType = const std::function; 31 | using DistanceFunctionType = std::function; 32 | 33 | public: 34 | DBSCAN(const double eps, const int minPoints, 35 | const DistanceFunctionType& distFunc = 
clustering::distance, 36 | const ValueAtFunctionType& valueAtFunc = clustering::at); 37 | 38 | ~DBSCAN() = default; 39 | 40 | std::vector> estimateClusterIndices(const VecPointType& points); 41 | std::vector estimateClusters(const VecPointType& points); 42 | 43 | protected: 44 | void expandCluster(const VecPointType& points, std::vector& curCluster, const std::vector& neighbors, 45 | std::vector& visited, std::vector& isNoise) const; 46 | 47 | private: 48 | double m_eps; 49 | int m_minPoints; 50 | 51 | const ValueAtFunctionType m_valueAtFunc; 52 | const DistanceFunctionType m_distFunc; 53 | 54 | typename KDTree::Ptr m_kdtree; 55 | }; 56 | } // namespace clustering 57 | #include "impl/DBSCANImpl.ipp" 58 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | # BasedOnStyle: LLVM 4 | AccessModifierOffset: -3 5 | AlignAfterOpenBracket: true 6 | AlignEscapedNewlinesLeft: false 7 | AlignOperands: true 8 | AlignTrailingComments: true 9 | AllowAllParametersOfDeclarationOnNextLine: true 10 | AllowShortBlocksOnASingleLine: false 11 | AllowShortCaseLabelsOnASingleLine: false 12 | AllowShortIfStatementsOnASingleLine: false 13 | AllowShortLoopsOnASingleLine: false 14 | AllowShortFunctionsOnASingleLine: None 15 | AlwaysBreakAfterDefinitionReturnType: false 16 | AlwaysBreakTemplateDeclarations: false 17 | AlwaysBreakBeforeMultilineStrings: false 18 | BreakBeforeBinaryOperators: None 19 | BreakBeforeTernaryOperators: true 20 | BreakConstructorInitializersBeforeComma: true 21 | BinPackParameters: true 22 | BinPackArguments: true 23 | ColumnLimit: 120 24 | ConstructorInitializerAllOnOneLineOrOnePerLine: false 25 | ConstructorInitializerIndentWidth: 4 26 | DerivePointerAlignment: false 27 | ExperimentalAutoDetectBinPacking: false 28 | IndentCaseLabels: true 29 | IndentWrappedFunctionNames: false 30 | 
IndentFunctionDeclarationAfterType: false 31 | MaxEmptyLinesToKeep: 1 32 | KeepEmptyLinesAtTheStartOfBlocks: false 33 | NamespaceIndentation: None 34 | ObjCBlockIndentWidth: 2 35 | ObjCSpaceAfterProperty: false 36 | ObjCSpaceBeforeProtocolList: true 37 | PenaltyBreakBeforeFirstCallParameter: 19 38 | PenaltyBreakComment: 300 39 | PenaltyBreakString: 1000 40 | PenaltyBreakFirstLessLess: 120 41 | PenaltyExcessCharacter: 1000000 42 | PenaltyReturnTypeOnItsOwnLine: 60 43 | PointerAlignment: Left 44 | SpacesBeforeTrailingComments: 2 45 | Cpp11BracedListStyle: true 46 | Standard: Cpp11 47 | IndentWidth: 4 48 | TabWidth: 8 49 | UseTab: Never 50 | BreakBeforeBraces: Linux # the position of the braces 51 | SpacesInParentheses: false 52 | SpacesInSquareBrackets: false 53 | SpacesInAngles: false 54 | SpaceInEmptyParentheses: false 55 | SpacesInCStyleCastParentheses: false 56 | SpaceAfterCStyleCast: false 57 | SpacesInContainerLiterals: true 58 | SpaceBeforeAssignmentOperators: true 59 | ContinuationIndentWidth: 4 60 | CommentPragmas: '.*' 61 | ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] 62 | SpaceBeforeParens: ControlStatements 63 | DisableFormat: false 64 | ... 
65 | -------------------------------------------------------------------------------- /examples/DBSCANPointCloudGPU.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file DBSCANPointCloudGPU.cpp 3 | * 4 | * @author btran 5 | * 6 | */ 7 | 8 | #include 9 | 10 | #include 11 | 12 | #include "AppUtility.hpp" 13 | #include 14 | #include 15 | 16 | namespace 17 | { 18 | using PointCloudType = pcl::PointXYZ; 19 | using PointCloud = pcl::PointCloud; 20 | using PointCloudPtr = PointCloud::Ptr; 21 | using DBSCAN = clustering::cuda::DBSCAN; 22 | 23 | auto timer = pcl::StopWatch(); 24 | int NUM_TEST = 10; 25 | } // namespace 26 | 27 | int main(int argc, char* argv[]) 28 | { 29 | if (argc < 4) { 30 | std::cerr << "Usage: [app] [path/to/pcl/file] [eps] [min/points] [to/visualize/0:no/1:yes]\n"; 31 | return EXIT_FAILURE; 32 | } 33 | 34 | const std::string pclFilePath = argv[1]; 35 | const double eps = std::atof(argv[2]); 36 | const int minPoints = std::atoi(argv[3]); 37 | bool toVisualize = argc == 5 ? 
std::atoi(argv[4]) : true; 38 | 39 | PointCloudPtr inCloud(new PointCloud); 40 | if (pcl::io::loadPCDFile(pclFilePath, *inCloud) == -1) { 41 | std::cerr << "Failed to load pcl file\n"; 42 | return EXIT_FAILURE; 43 | } 44 | std::cout << "number of points: " << inCloud->size() << "\n"; 45 | 46 | cuda::utils::warmUpGPU(); 47 | 48 | DBSCAN::Param param{.pointDimension = 3, .eps = eps, .minPoints = minPoints}; 49 | DBSCAN dbscanHandler(param); 50 | std::vector> clusterIndices; 51 | 52 | timer.reset(); 53 | for (int i = 0; i < NUM_TEST; ++i) { 54 | clusterIndices = dbscanHandler.run(inCloud->points.data(), inCloud->size()); 55 | } 56 | 57 | std::cout << "number of clusters: " << clusterIndices.size() << "\n"; 58 | std::cout << "processing time (gpu): " << timer.getTime() / NUM_TEST << "[ms]\n"; 59 | 60 | if (!toVisualize) { 61 | return EXIT_SUCCESS; 62 | } 63 | 64 | std::vector clusters = ::toClusters(clusterIndices, inCloud); 65 | 66 | auto [colors, colorHandlers] = ::initPclColorHandlers(clusters); 67 | 68 | std::size_t countElem = 0; 69 | auto pclViewer = initializeViewer(); 70 | for (const auto& cluster : clusters) { 71 | pclViewer->addPointCloud(cluster, colorHandlers[countElem], std::to_string(countElem)); 72 | pclViewer->setPointCloudRenderingProperties(pcl::visualization::PCL_VISUALIZER_POINT_SIZE, 3, 73 | std::to_string(countElem)); 74 | countElem++; 75 | } 76 | 77 | while (!pclViewer->wasStopped()) { 78 | pclViewer->spinOnce(); 79 | } 80 | 81 | return EXIT_SUCCESS; 82 | } 83 | -------------------------------------------------------------------------------- /include/dbscan/KDTree.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file KDTree.hpp 3 | * 4 | * @author btran 5 | * 6 | */ 7 | 8 | #pragma once 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | #include 21 | 22 | namespace clustering 23 | { 24 | template > class 
KDTree 25 | { 26 | public: 27 | using PointType = POINT_TYPE; 28 | using VecPointType = VEC_POINT_TYPE; 29 | 30 | using Ptr = std::shared_ptr; 31 | 32 | using ValueAtFunctionType = const std::function; 33 | using DistanceFunctionType = std::function; 34 | 35 | public: 36 | explicit KDTree(const VecPointType& points, 37 | const DistanceFunctionType& distFunc = clustering::distance, 38 | const ValueAtFunctionType& valueAtFunc = clustering::at); 39 | ~KDTree(); 40 | 41 | int nearestSearch(const PointType& queryPoint) const; 42 | 43 | std::vector radiusSearch(const PointType& queryPoint, const double radius) const; 44 | 45 | std::vector knnSearch(const PointType& queryPoint, const int k) const; 46 | 47 | auto valueAtFunc() const 48 | { 49 | return m_valueAtFunc; 50 | } 51 | 52 | auto distFunc() const 53 | { 54 | return m_distFunc; 55 | } 56 | 57 | private: 58 | void buildKDTree(); 59 | void clear(); 60 | 61 | // DistanceIndex->(distance from query point to current point, index of the current point) 62 | using DistanceIndex = std::pair; 63 | 64 | // priority queue from max heap based on distance 65 | using KNNPriorityQueue = std::priority_queue, std::less>; 66 | 67 | struct KDNode; 68 | KDNode* insertRec(std::vector& pointIndices, const int depth); 69 | void clearRec(KDNode* node); 70 | void nearestSearchRec(const PointType& queryPoint, const KDNode* node, int& nearestPointIdx, 71 | double& nearestDist) const; 72 | void radiusSearchRec(const PointType& queryPoint, const KDNode* node, std::vector& indices, 73 | const double radius) const; 74 | void knnSearchRec(const PointType& queryPoint, const KDNode* node, KNNPriorityQueue& knnMaxHeap, 75 | double& maxDistance, const int k) const; 76 | 77 | private: 78 | KDNode* m_root; 79 | 80 | const VecPointType& m_points; 81 | 82 | ValueAtFunctionType m_valueAtFunc; 83 | DistanceFunctionType m_distFunc; 84 | }; 85 | } // namespace clustering 86 | 87 | #include "impl/KDTreeImpl.ipp" 88 | 
-------------------------------------------------------------------------------- /examples/AppUtility.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file AppUtility.hpp 3 | * 4 | * @author btran 5 | * 6 | */ 7 | 8 | #pragma once 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | namespace 16 | { 17 | inline pcl::PointIndices::Ptr toPclPointIndices(const std::vector& indices) 18 | { 19 | pcl::PointIndices::Ptr pclIndices(new pcl::PointIndices); 20 | pclIndices->indices = indices; 21 | 22 | return pclIndices; 23 | } 24 | 25 | template 26 | inline std::vector::Ptr> 27 | toClusters(const std::vector>& clusterIndices, 28 | const typename pcl::PointCloud::Ptr& inPcl) 29 | { 30 | std::vector::Ptr> clusters; 31 | clusters.reserve(clusterIndices.size()); 32 | 33 | for (const auto& curIndices : clusterIndices) { 34 | typename pcl::PointCloud::Ptr curPcl(new pcl::PointCloud); 35 | pcl::copyPointCloud(*inPcl, *toPclPointIndices(curIndices), *curPcl); 36 | clusters.emplace_back(curPcl); 37 | } 38 | 39 | return clusters; 40 | } 41 | 42 | inline std::vector> generateColorCharts(const std::uint16_t numSources, 43 | const std::uint16_t seed = 2020) 44 | { 45 | std::srand(seed); 46 | 47 | std::vector> colors; 48 | colors.reserve(numSources); 49 | 50 | for (std::uint16_t i = 0; i < numSources; ++i) { 51 | colors.emplace_back(std::array{static_cast(std::rand() % 256), 52 | static_cast(std::rand() % 256), 53 | static_cast(std::rand() % 256)}); 54 | } 55 | 56 | return colors; 57 | } 58 | 59 | template 60 | inline std::pair>, 61 | std::vector>> 62 | initPclColorHandlers(const std::vector::Ptr>& inPcls) 63 | { 64 | using PointCloudType = POINT_CLOUD_TYPE; 65 | using PointCloud = pcl::PointCloud; 66 | 67 | std::vector> pclColorHandlers; 68 | pclColorHandlers.reserve(inPcls.size()); 69 | auto colorCharts = ::generateColorCharts(inPcls.size()); 70 | 71 | for (std::size_t i = 0; i < inPcls.size(); ++i) { // unsigned index to match inPcls.size() 72 | const auto& curColor = 
colorCharts[i]; 73 | pclColorHandlers.emplace_back(pcl::visualization::PointCloudColorHandlerCustom( 74 | inPcls[i], curColor[0], curColor[1], curColor[2])); 75 | } 76 | 77 | return std::make_pair(colorCharts, pclColorHandlers); 78 | } 79 | 80 | pcl::visualization::PCLVisualizer::Ptr initializeViewer() 81 | { 82 | pcl::visualization::PCLVisualizer::Ptr viewer(new pcl::visualization::PCLVisualizer("3D Viewer")); 83 | pcl::PointXYZ o(0.1, 0, 0); 84 | viewer->addSphere(o, 0.1, "sphere", 0); 85 | viewer->setBackgroundColor(0.05, 0.05, 0.05, 0); 86 | viewer->addCoordinateSystem(0.5); 87 | viewer->setCameraPosition(-26, 0, 3, 10, -1, 0.5, 0, 0, 1); 88 | 89 | return viewer; 90 | } 91 | } // namespace 92 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![.github/workflows/build.yml](https://github.com/xmba15/generic_dbscan/actions/workflows/build.yml/badge.svg)](https://github.com/xmba15/generic_dbscan/actions/workflows/build.yml) 2 | 3 | # 📝 c++ generic DBSCAN # 4 | *** 5 | 6 | - c++ generic dbscan library for CPU & GPU; the cpu dbscan is header-only. 7 | - cpu dbscan uses a kd-tree for radius search. 8 | - gpu dbscan (or [G-DBSCAN](https://www.sciencedirect.com/science/article/pii/S1877050913003438)) runs a breadth-first search (BFS) over a graph whose nodes are the points. 
9 | 10 | ## :tada: TODO ## 11 | *** 12 | 13 | - [x] Implement generic kd-tree 14 | - [x] Implement generic dbscan 15 | - [x] Create unittest & examples 16 | - [x] GPU DBSCAN 17 | 18 | ## 🎛 Dependencies ## 19 | *** 20 | 21 | - base dependencies 22 | 23 | ```bash 24 | sudo apt-get install \ 25 | libpcl-dev 26 | ``` 27 | 28 | *tested with pcl 1.10* 29 | 30 | - gpu dbscan 31 | 32 | [CUDA toolkit](https://developer.nvidia.com/cuda-toolkit) (tested on cuda 11) 33 | 34 | *older pcl versions (which depend on older eigen versions) might not be compatible with cuda* 35 | 36 | ## 🔨 How to Build ## 37 | *** 38 | 39 | ```bash 40 | # build lib 41 | make default -j`nproc` 42 | 43 | # build examples 44 | 45 | # build only cpu dbscan 46 | make apps -j`nproc` 47 | 48 | # build both cpu and gpu dbscan 49 | make gpu_apps -j`nproc` 50 | ``` 51 | 52 | ## :running: How to Run ## 53 | *** 54 | - *This library provides an example that clusters a point cloud captured by a [livox horizon lidar](https://www.livoxtech.com/horizon)* 55 | 56 | ```bash 57 | # after make apps or make gpu_apps 58 | 59 | # cpu 60 | ./build/examples/test_pointcloud_clustering [path/to/pcl/file] [eps] [min/points] [to/visualize/0:no/1:yes] 61 | 62 | # gpu 63 | 64 | ./build/examples/test_pointcloud_clustering_gpu [path/to/pcl/file] [eps] [min/points] [to/visualize/0:no/1:yes] 65 | 66 | # eps and min points are the parameters of the dbscan algorithm 67 | 68 | # for example 69 | ./build/examples/test_pointcloud_clustering ./data/street_no_ground.pcd 0.7 3 1 70 | 71 | # or 72 | ./build/examples/test_pointcloud_clustering_gpu ./data/street_no_ground.pcd 0.7 3 1 73 | 74 | # set the final argument to 0 to disable visualization 75 | ``` 76 | 77 | - processing time (average of 10 tests): 78 | 79 | ```bash 80 | # number of points: 11619 81 | # processing time (cpu): 80[ms] 82 | # processing time (gpu): 38[ms] 83 | 84 | # the difference in speed becomes more pronounced as the number of points grows 85 | ``` 86 | 87 | - Here is the sample result: 88 | 89 | 
![clustered_results](./docs/images/clustered_results.png) 90 | 91 | ### :whale: How to Run with Docker ### 92 | 93 | - cpu 94 | 95 | ```bash 96 | # build 97 | docker build -f ./dockerfiles/ubuntu2004.dockerfile -t dbscan . 98 | 99 | # run 100 | docker run -it --rm -v `pwd`:/workspace dbscan 101 | ``` 102 | 103 | - gpu: change [the cuda version here](https://github.com/xmba15/generic_dbscan/blob/master/dockerfiles/ubuntu2004_gpu.dockerfile#L1) to match your local cuda version before building. 104 | 105 | ```bash 106 | # build 107 | docker build -f ./dockerfiles/ubuntu2004_gpu.dockerfile -t dbscan_gpu . 108 | 109 | # run 110 | docker run -it --rm --gpus all -v `pwd`:/workspace dbscan_gpu 111 | ``` 112 | 113 | ## :gem: References ## 114 | *** 115 | 116 | - [kdtree](https://en.wikipedia.org/wiki/K-d_tree) 117 | - [dbscan](https://en.wikipedia.org/wiki/DBSCAN) 118 | - [G-DBSCAN: A GPU Accelerated Algorithm for Density-based Clustering](https://www.sciencedirect.com/science/article/pii/S1877050913003438) 119 | - [SLIC-DBSCAN-CUDA](https://github.com/ca1773130n/SLIC-DBSCAN-CUDA) 120 | -------------------------------------------------------------------------------- /examples/DBSCANPointCloud.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file DBSCANPointCloud.cpp 3 | * 4 | * @author btran 5 | * 6 | */ 7 | 8 | #include 9 | 10 | #include 11 | 12 | #include "AppUtility.hpp" 13 | 14 | namespace 15 | { 16 | template double at(const POINT_CLOUD_TYPE& p, const int axis) 17 | { 18 | switch (axis) { 19 | case 0: { 20 | return p.x; 21 | break; 22 | } 23 | case 1: { 24 | return p.y; 25 | break; 26 | } 27 | case 2: { 28 | return p.z; 29 | break; 30 | } 31 | default: { 32 | throw std::runtime_error("axis out of range"); 33 | } 34 | } 35 | } 36 | 37 | template 38 | double 39 | distance(const POINT_CLOUD_TYPE& p1, const POINT_CLOUD_TYPE& p2, 40 | const std::function& valueAtFunc = ::at) 41 | { 42 | double result = 0.0; 43 | for (int 
i = 0; i < 3; ++i) { 44 | result += std::pow(valueAtFunc(p1, i) - valueAtFunc(p2, i), 2); 45 | } 46 | return std::sqrt(result); 47 | } 48 | 49 | using PointCloudType = pcl::PointXYZ; 50 | using PointCloud = pcl::PointCloud; 51 | using PointCloudPtr = PointCloud::Ptr; 52 | 53 | auto timer = pcl::StopWatch(); 54 | int NUM_TEST = 10; 55 | } // namespace 56 | 57 | int main(int argc, char* argv[]) 58 | { 59 | if (argc < 4) { 60 | std::cerr << "Usage: [app] [path/to/pcl/file] [eps] [min/points] [to/visualize/0:no/1:yes]\n"; 61 | return EXIT_FAILURE; 62 | } 63 | 64 | const std::string pclFilePath = argv[1]; 65 | const double eps = std::atof(argv[2]); 66 | const int minPoints = std::atoi(argv[3]); 67 | bool toVisualize = argc == 5 ? std::atoi(argv[4]) : true; 68 | 69 | PointCloudPtr inputPcl(new PointCloud); 70 | if (pcl::io::loadPCDFile(pclFilePath, *inputPcl) == -1) { 71 | std::cerr << "Failed to load pcl file\n"; 72 | return EXIT_FAILURE; 73 | } 74 | std::cout << "number of points: " << inputPcl->size() << "\n"; 75 | 76 | using DBSCAN = clustering::DBSCAN<3, PointCloudType, decltype(inputPcl->points)>; 77 | DBSCAN::Ptr dbscan = std::make_shared(eps, minPoints, ::distance, ::at); 78 | 79 | std::vector> clusterIndices; 80 | timer.reset(); 81 | for (int i = 0; i < NUM_TEST; ++i) { 82 | clusterIndices = dbscan->estimateClusterIndices(inputPcl->points); 83 | } 84 | 85 | std::cout << "number of clusters: " << clusterIndices.size() << "\n"; 86 | std::cout << "processing time (cpu): " << timer.getTime() / NUM_TEST << "[ms]\n"; 87 | 88 | if (!toVisualize) { 89 | return EXIT_SUCCESS; 90 | } 91 | 92 | std::vector clusters = ::toClusters(clusterIndices, inputPcl); 93 | 94 | auto [colors, colorHandlers] = ::initPclColorHandlers(clusters); 95 | 96 | std::size_t countElem = 0; 97 | auto pclViewer = initializeViewer(); 98 | for (const auto& cluster : clusters) { 99 | pclViewer->addPointCloud(cluster, colorHandlers[countElem], std::to_string(countElem)); 100 | 
pclViewer->setPointCloudRenderingProperties(pcl::visualization::PCL_VISUALIZER_POINT_SIZE, 3, 101 | std::to_string(countElem)); 102 | countElem++; 103 | } 104 | 105 | while (!pclViewer->wasStopped()) { 106 | pclViewer->spinOnce(); 107 | } 108 | 109 | return EXIT_SUCCESS; 110 | } 111 | -------------------------------------------------------------------------------- /include/dbscan/impl/DBSCANImpl.ipp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file DBSCANImpl.ipp 3 | * 4 | * @author btran 5 | * 6 | * Copyright (c) organization 7 | * 8 | */ 9 | 10 | namespace clustering 11 | { 12 | template 13 | DBSCAN::DBSCAN(const double eps, const int minPoints, 14 | const DistanceFunctionType& distFunc, 15 | const ValueAtFunctionType& valueAtFunc) 16 | : m_valueAtFunc(valueAtFunc) 17 | , m_distFunc(distFunc) 18 | , m_eps(eps) 19 | , m_minPoints(minPoints) 20 | , m_kdtree(nullptr) 21 | { 22 | } 23 | 24 | template 25 | std::vector> 26 | DBSCAN::estimateClusterIndices(const VecPointType& points) 27 | { 28 | m_kdtree = std::make_shared>(points, m_distFunc, m_valueAtFunc); 29 | 30 | std::vector> clusterIndices; 31 | std::vector visited(points.size(), false); 32 | std::vector isNoise(points.size(), false); 33 | 34 | for (std::size_t i = 0; i < points.size(); ++i) { 35 | if (visited[i]) { 36 | continue; 37 | } 38 | 39 | visited[i] = true; 40 | const PointType& curPoint = points[i]; 41 | const std::vector neighbors = m_kdtree->radiusSearch(curPoint, m_eps); 42 | 43 | if (neighbors.size() < m_minPoints) { 44 | isNoise[i] = true; 45 | } else { 46 | clusterIndices.emplace_back(std::vector{static_cast(i)}); 47 | this->expandCluster(points, clusterIndices.back(), neighbors, visited, isNoise); 48 | } 49 | } 50 | 51 | return clusterIndices; 52 | } 53 | 54 | template 55 | void DBSCAN::expandCluster(const VecPointType& points, 56 | std::vector& curCluster, 57 | const std::vector& neighbors, 58 | std::vector& visited, 59 | std::vector& isNoise) 
const 60 | { 61 | std::deque neighborDeque(neighbors.begin(), neighbors.end()); 62 | while (!neighborDeque.empty()) { 63 | int curIdx = neighborDeque.front(); 64 | neighborDeque.pop_front(); 65 | 66 | if (isNoise[curIdx]) { 67 | isNoise[curIdx] = false; // claim the former noise point as a border point exactly once; it stays visited, so later dequeues skip it 68 | curCluster.emplace_back(curIdx); 69 | continue; 70 | } 71 | if (!visited[curIdx]) { 72 | visited[curIdx] = true; 73 | curCluster.emplace_back(curIdx); 74 | 75 | const PointType& curPoint = points[curIdx]; 76 | std::vector curNeighbors = m_kdtree->radiusSearch(curPoint, m_eps); 77 | 78 | if (curNeighbors.size() < m_minPoints) { 79 | continue; 80 | } 81 | 82 | std::copy(curNeighbors.begin(), curNeighbors.end(), std::back_inserter(neighborDeque)); 83 | } 84 | } 85 | } 86 | 87 | template 88 | std::vector::VecPointType> 89 | DBSCAN::estimateClusters(const VecPointType& points) 90 | { 91 | std::vector> clusterIndices = this->estimateClusterIndices(points); 92 | std::vector clusters; 93 | clusters.reserve(clusterIndices.size()); 94 | 95 | for (const auto& curIndices : clusterIndices) { 96 | VecPointType curCluster; 97 | curCluster.reserve(curIndices.size()); 98 | for (const int pointIdx : curIndices) { 99 | curCluster.emplace_back(points[pointIdx]); 100 | } 101 | clusters.emplace_back(curCluster); 102 | } 103 | 104 | return clusters; 105 | } 106 | } // namespace clustering 107 | -------------------------------------------------------------------------------- /tests/TestKDTree.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file TestKDTree.cpp 3 | * 4 | * @author btran 5 | * 6 | * Copyright (c) organization 7 | * 8 | */ 9 | 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | #include 16 | 17 | namespace 18 | { 19 | class RandDouble 20 | { 21 | public: 22 | RandDouble(const double low, const double high) 23 | : m_genFunc(std::bind(std::uniform_real_distribution<>(low, high), std::default_random_engine())) 24 | { 25 | } 26 | 27 | double operator()() const 28 | { 29 | return m_genFunc(); 
30 | } 31 | 32 | private: 33 | std::function m_genFunc; 34 | }; 35 | 36 | class Point3D : public std::array 37 | { 38 | public: 39 | Point3D(const double x, const double y, const double z) 40 | { 41 | (*this)[0] = x; 42 | (*this)[1] = y; 43 | (*this)[2] = z; // was y, which silently discarded the z coordinate of every test point 44 | } 45 | }; 46 | 47 | using KDTree = clustering::KDTree<3, Point3D>; 48 | } // namespace 49 | 50 | TEST(TestKDTree, TestInitialization) 51 | { 52 | const int numPoints = 1000; 53 | KDTree::VecPointType points; 54 | points.reserve(numPoints); 55 | { 56 | auto doubleGenerator = ::RandDouble(0.0, 200.0); 57 | for (int i = 0; i < numPoints; ++i) { 58 | points.emplace_back(doubleGenerator(), doubleGenerator(), doubleGenerator()); 59 | } 60 | } 61 | 62 | { 63 | ASSERT_NO_THROW(::KDTree kdTree(points);); 64 | } 65 | 66 | { 67 | ::KDTree::ValueAtFunctionType valueAtFunc = [](const ::KDTree::PointType& p, const int axis) { 68 | if (axis >= 3) { 69 | throw std::runtime_error("axis out of range\n"); 70 | } 71 | 72 | return p[axis]; 73 | }; 74 | 75 | const auto distFunc = [](const ::KDTree::PointType& p1, const ::KDTree::PointType& p2, 76 | const ::KDTree::ValueAtFunctionType& valueAtFunc) { 77 | double result = 0.0; 78 | for (int i = 0; i < 3; ++i) { 79 | result += std::fabs(valueAtFunc(p1, i) - valueAtFunc(p2, i)); 80 | } 81 | return std::sqrt(result); 82 | }; 83 | 84 | ASSERT_NO_THROW(::KDTree kdTree(points, distFunc, valueAtFunc);); 85 | } 86 | } 87 | 88 | TEST(TestKDTree, TestNearestSearch) 89 | { 90 | { 91 | const int numPoints = 10000; 92 | KDTree::VecPointType points; 93 | points.reserve(numPoints); 94 | { 95 | auto doubleGenerator = ::RandDouble(0.0, 200.0); 96 | for (int i = 0; i < numPoints; ++i) { 97 | points.emplace_back(doubleGenerator(), doubleGenerator(), doubleGenerator()); 98 | } 99 | } 100 | 101 | ::KDTree kdTree(points); 102 | 103 | KDTree::PointType queryPoint(50.0, 40.0, 40.6); 104 | 105 | int expectedNearestIdx = kdTree.nearestSearch(queryPoint); 106 | 107 | int bruteForcedNearestIdx = -1; 108 | 
double bruteForcedMinDist = std::numeric_limits::max(); 109 | for (int i = 0; i < numPoints; ++i) { 110 | const auto& curPoint = points[i]; 111 | const double curDist = kdTree.distFunc()(queryPoint, curPoint, kdTree.valueAtFunc()); 112 | if (curDist < bruteForcedMinDist) { 113 | bruteForcedNearestIdx = i; 114 | bruteForcedMinDist = curDist; 115 | } 116 | } 117 | 118 | EXPECT_EQ(expectedNearestIdx, bruteForcedNearestIdx); 119 | } 120 | } 121 | 122 | TEST(TestKDTree, TestRadiusSearch) 123 | { 124 | { 125 | const int numPoints = 10000; 126 | KDTree::VecPointType points; 127 | points.reserve(numPoints); 128 | { 129 | auto doubleGenerator = ::RandDouble(0.0, 200.0); 130 | for (int i = 0; i < numPoints; ++i) { 131 | points.emplace_back(doubleGenerator(), doubleGenerator(), doubleGenerator()); 132 | } 133 | } 134 | 135 | ::KDTree kdTree(points); 136 | 137 | KDTree::PointType queryPoint(50.0, 40.0, 40.6); 138 | const double radius = 10; 139 | 140 | std::vector expectedIndices = kdTree.radiusSearch(queryPoint, radius); 141 | 142 | std::vector bruteForcedIndices; 143 | for (int i = 0; i < numPoints; ++i) { 144 | const auto& curPoint = points[i]; 145 | const double curDist = kdTree.distFunc()(queryPoint, curPoint, kdTree.valueAtFunc()); 146 | if (curDist < radius) { 147 | bruteForcedIndices.emplace_back(i); 148 | } 149 | } 150 | 151 | EXPECT_EQ(expectedIndices.size(), bruteForcedIndices.size()); 152 | 153 | std::sort(expectedIndices.begin(), expectedIndices.end()); 154 | std::sort(bruteForcedIndices.begin(), bruteForcedIndices.end()); 155 | std::vector idxIndices(expectedIndices.size()); 156 | std::iota(idxIndices.begin(), idxIndices.end(), 0); 157 | EXPECT_TRUE(std::all_of(idxIndices.begin(), idxIndices.end(), [&](const int idxIdx) { 158 | return expectedIndices[idxIdx] == bruteForcedIndices[idxIdx]; 159 | })); 160 | } 161 | } 162 | 163 | TEST(TestKDTree, TestKNNSearch) 164 | { 165 | { 166 | const int numPoints = 10000; 167 | KDTree::VecPointType points; 168 | 
points.reserve(numPoints); 169 | { 170 | auto doubleGenerator = ::RandDouble(0.0, 200.0); 171 | for (int i = 0; i < numPoints; ++i) { 172 | points.emplace_back(doubleGenerator(), doubleGenerator(), doubleGenerator()); 173 | } 174 | } 175 | 176 | ::KDTree kdTree(points); 177 | 178 | KDTree::PointType queryPoint(50.0, 40.0, 40.6); 179 | const int k = 10; 180 | 181 | std::vector expectedIndices = kdTree.knnSearch(queryPoint, k); 182 | 183 | std::vector allIndices(numPoints); 184 | std::iota(allIndices.begin(), allIndices.end(), 0); 185 | std::sort(allIndices.begin(), allIndices.end(), [&](const int idx1, const int idx2) { 186 | return kdTree.distFunc()(queryPoint, points[idx1], kdTree.valueAtFunc()) < 187 | kdTree.distFunc()(queryPoint, points[idx2], kdTree.valueAtFunc()); 188 | }); 189 | 190 | std::vector bruteForcedIndices(allIndices.begin(), allIndices.begin() + k); 191 | 192 | EXPECT_EQ(expectedIndices.size(), bruteForcedIndices.size()); 193 | 194 | std::sort(expectedIndices.begin(), expectedIndices.end()); 195 | std::sort(bruteForcedIndices.begin(), bruteForcedIndices.end()); 196 | std::vector idxIndices(expectedIndices.size()); 197 | std::iota(idxIndices.begin(), idxIndices.end(), 0); 198 | EXPECT_TRUE(std::all_of(idxIndices.begin(), idxIndices.end(), [&](const int idxIdx) { 199 | return expectedIndices[idxIdx] == bruteForcedIndices[idxIdx]; 200 | })); 201 | } 202 | } 203 | -------------------------------------------------------------------------------- /cuda/src/DBSCAN.cu: -------------------------------------------------------------------------------- 1 | /** 2 | * @file DBSCAN.cu 3 | * 4 | * @author btran 5 | * 6 | */ 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | #include "Types.cuh" 17 | 18 | namespace clustering 19 | { 20 | namespace cuda 21 | { 22 | namespace 23 | { 24 | cudaDeviceProp dProperties; 25 | 26 | template 27 | __global__ void makeGraphStep1Kernel(const PointType* __restrict__ points, Node* nodes, 
int* nodeDegs, int numPoints, 28 | float eps, int minPoints) 29 | { 30 | int tid = threadIdx.x + blockIdx.x * blockDim.x; 31 | 32 | while (tid < numPoints) { 33 | Node* node = &nodes[tid]; 34 | const PointType* point = &points[tid]; 35 | 36 | for (int i = 0; i < numPoints; ++i) { 37 | if (i == tid) { 38 | continue; 39 | } 40 | 41 | const PointType* curPoint = &points[i]; 42 | float diffx = point->x - curPoint->x; 43 | float diffy = point->y - curPoint->y; 44 | float diffz = point->z - curPoint->z; 45 | 46 | float sqrDist = diffx * diffx + diffy * diffy + diffz * diffz; 47 | 48 | if (sqrDist < eps * eps) { 49 | node->numNeighbors++; 50 | } 51 | } 52 | 53 | if (node->numNeighbors >= minPoints) { 54 | node->type = NodeType::CORE; 55 | } 56 | 57 | nodeDegs[tid] = node->numNeighbors; 58 | 59 | tid += blockDim.x * gridDim.x; 60 | } 61 | } 62 | 63 | template 64 | __global__ void makeGraphStep2Kernel(const PointType* __restrict__ points, const int* __restrict__ neighborStartIndices, 65 | int* adjList, int numPoints, float eps) 66 | { 67 | int tid = threadIdx.x + blockIdx.x * blockDim.x; 68 | 69 | while (tid < numPoints) { 70 | const PointType* point = &points[tid]; 71 | 72 | int startIdx = neighborStartIndices[tid]; 73 | 74 | int countNeighbors = 0; 75 | for (int i = 0; i < numPoints; ++i) { 76 | if (i == tid) { 77 | continue; 78 | } 79 | 80 | const PointType* curPoint = &points[i]; 81 | float diffx = point->x - curPoint->x; 82 | float diffy = point->y - curPoint->y; 83 | float diffz = point->z - curPoint->z; 84 | 85 | float sqrDist = diffx * diffx + diffy * diffy + diffz * diffz; 86 | 87 | if (sqrDist < eps * eps) { 88 | adjList[startIdx + countNeighbors++] = i; 89 | } 90 | } 91 | 92 | tid += blockDim.x * gridDim.x; 93 | } 94 | } 95 | 96 | __global__ void BFSKernel(const Node* __restrict__ nodes, const int* __restrict__ adjList, 97 | const int* __restrict__ neighborStartIndices, char* Fa, char* Xa, int numPoints) 98 | { 99 | int tid = threadIdx.x + blockIdx.x * 
blockDim.x; 100 | 101 | while (tid < numPoints) { 102 | if (Fa[tid]) { 103 | Fa[tid] = false; 104 | Xa[tid] = true; 105 | 106 | int startIdx = neighborStartIndices[tid]; 107 | 108 | for (int i = 0; i < nodes[tid].numNeighbors; ++i) { 109 | int nIdx = adjList[startIdx + i]; 110 | Fa[nIdx] = 1 - Xa[nIdx]; 111 | } 112 | } 113 | 114 | tid += blockDim.x * gridDim.x; 115 | } 116 | } 117 | 118 | void BFS(Node* hNodes, Node* dNodes, const int* adjList, const int* neighborStartIndices, int v, int numPoints, 119 | std::vector& curCluster) 120 | { 121 | thrust::device_vector dXa(numPoints, false); 122 | thrust::device_vector dFa(numPoints, false); 123 | dFa[v] = true; 124 | 125 | int numThreads = 256; 126 | int numBlocks = std::min(dProperties.maxGridSize[0], (numPoints + numThreads - 1) / numThreads); 127 | 128 | int countFa = 1; 129 | while (countFa > 0) { 130 | BFSKernel<<>>(dNodes, adjList, neighborStartIndices, 131 | thrust::raw_pointer_cast(dFa.data()), thrust::raw_pointer_cast(dXa.data()), 132 | numPoints); 133 | countFa = thrust::count(thrust::device, dFa.begin(), dFa.end(), true); 134 | } 135 | 136 | thrust::host_vector hXa = dXa; 137 | 138 | for (int i = 0; i < numPoints; ++i) { 139 | if (hXa[i]) { 140 | hNodes[i].visited = true; 141 | curCluster.emplace_back(i); 142 | } 143 | } 144 | } 145 | 146 | template Graph makeGraph(const PointType* points, int numPoints, float eps, int minPoints) 147 | { 148 | Graph graph; 149 | 150 | graph.nodes.resize(numPoints); 151 | graph.neighborStartIndices.resize(numPoints); 152 | 153 | thrust::device_vector dNodeDegs(numPoints); 154 | 155 | int numThreads = 256; 156 | int numBlocks = std::min(dProperties.maxGridSize[0], (numPoints + numThreads - 1) / numThreads); 157 | 158 | makeGraphStep1Kernel<<>>(points, thrust::raw_pointer_cast(graph.nodes.data()), 159 | thrust::raw_pointer_cast(dNodeDegs.data()), numPoints, 160 | eps, minPoints); 161 | 162 | thrust::exclusive_scan(dNodeDegs.begin(), dNodeDegs.end(), 
graph.neighborStartIndices.begin()); 163 | 164 | int totalEdges = dNodeDegs.back() + graph.neighborStartIndices.back(); 165 | graph.adjList.resize(totalEdges); 166 | 167 | makeGraphStep2Kernel 168 | <<>>(points, thrust::raw_pointer_cast(graph.neighborStartIndices.data()), 169 | thrust::raw_pointer_cast(graph.adjList.data()), numPoints, eps); 170 | 171 | return graph; 172 | } 173 | } // namespace 174 | 175 | template 176 | DBSCAN::DBSCAN(const Param& param) 177 | : m_param(param) 178 | , m_dPoints(nullptr) 179 | , m_allocatedSize(-1) 180 | { 181 | HANDLE_ERROR(cudaGetDeviceProperties(&dProperties, 0)); 182 | } 183 | 184 | template DBSCAN::~DBSCAN() 185 | { 186 | if (m_dPoints) { 187 | HANDLE_ERROR(cudaFree(m_dPoints)); 188 | m_dPoints = nullptr; 189 | } 190 | } 191 | 192 | template 193 | std::vector> DBSCAN::run(const PointType* points, int numPoints) const 194 | { 195 | if (numPoints <= 0) { 196 | throw std::runtime_error("number of points must be more than 0"); 197 | } 198 | 199 | if (m_allocatedSize < numPoints) { 200 | m_allocatedSize = numPoints; 201 | if (m_dPoints) { 202 | HANDLE_ERROR(cudaFree(m_dPoints)); 203 | } 204 | HANDLE_ERROR(cudaMalloc((void**)&m_dPoints, numPoints * sizeof(PointType))); 205 | } 206 | HANDLE_ERROR(cudaMemcpy(m_dPoints, points, numPoints * sizeof(PointType), cudaMemcpyHostToDevice)); 207 | 208 | auto graph = makeGraph(m_dPoints, numPoints, m_param.eps, m_param.minPoints); 209 | 210 | thrust::host_vector hNodes = graph.nodes; 211 | 212 | std::vector> clusterIndices; 213 | for (int i = 0; i < numPoints; ++i) { 214 | auto& curHNode = hNodes[i]; 215 | if (curHNode.visited || curHNode.type != NodeType::CORE) { 216 | continue; 217 | } 218 | 219 | std::vector curCluster; 220 | // the seed index i is not pushed here: BFS() marks Xa[i] and appends i itself, so pushing it here listed the seed twice 221 | curHNode.visited = true; 222 | 223 | BFS(hNodes.data(), thrust::raw_pointer_cast(graph.nodes.data()), thrust::raw_pointer_cast(graph.adjList.data()), 224 | thrust::raw_pointer_cast(graph.neighborStartIndices.data()), i, numPoints, 
curCluster); 225 | 226 | clusterIndices.emplace_back(std::move(curCluster)); 227 | } 228 | 229 | return clusterIndices; 230 | } 231 | 232 | #undef INSTANTIATE_TEMPLATE 233 | #define INSTANTIATE_TEMPLATE(DATA_TYPE) template class DBSCAN; 234 | 235 | INSTANTIATE_TEMPLATE(pcl::PointXYZ); 236 | INSTANTIATE_TEMPLATE(pcl::PointXYZI); 237 | INSTANTIATE_TEMPLATE(pcl::PointXYZRGB); 238 | INSTANTIATE_TEMPLATE(pcl::PointNormal); 239 | INSTANTIATE_TEMPLATE(pcl::PointXYZRGBNormal); 240 | INSTANTIATE_TEMPLATE(pcl::PointXYZINormal); 241 | 242 | #undef INSTANTIATE_TEMPLATE 243 | } // namespace cuda 244 | } // namespace clustering 245 | -------------------------------------------------------------------------------- /include/dbscan/impl/KDTreeImpl.ipp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file KDTreeImpl.ipp 3 | * 4 | * @author btran 5 | * 6 | * Copyright (c) organization 7 | * 8 | */ 9 | 10 | namespace clustering 11 | { 12 | template 13 | KDTree::KDTree(const VecPointType& points, 14 | const DistanceFunctionType& distFunc, 15 | const ValueAtFunctionType& valueAtFunc) 16 | : m_points(points) 17 | , m_valueAtFunc(valueAtFunc) 18 | , m_distFunc(distFunc) 19 | { 20 | this->buildKDTree(); 21 | } 22 | 23 | template 24 | KDTree::~KDTree() 25 | { 26 | this->clear(); 27 | } 28 | 29 | template 30 | void KDTree::buildKDTree() 31 | { 32 | std::vector pointIndices(m_points.size()); 33 | std::iota(pointIndices.begin(), pointIndices.end(), 0); 34 | m_root = this->insertRec(pointIndices, 0); 35 | } 36 | 37 | template 38 | void KDTree::clear() 39 | { 40 | this->clearRec(m_root); 41 | m_root = nullptr; 42 | } 43 | 44 | template 45 | int KDTree::nearestSearch(const PointType& queryPoint) const 46 | { 47 | int nearestPointIdx = -1; 48 | 49 | if (!m_root) { 50 | return nearestPointIdx; 51 | } 52 | 53 | double nearestDist = std::numeric_limits::max(); 54 | this->nearestSearchRec(queryPoint, m_root, nearestPointIdx, nearestDist); 55 | 56 | 
return nearestPointIdx; 57 | } 58 | 59 | template 60 | std::vector KDTree::radiusSearch(const PointType& queryPoint, 61 | const double radius) const 62 | { 63 | std::vector indices; 64 | 65 | if (!m_root) { 66 | return indices; 67 | } 68 | 69 | this->radiusSearchRec(queryPoint, m_root, indices, radius); 70 | return indices; 71 | } 72 | 73 | template 74 | std::vector KDTree::knnSearch(const PointType& queryPoint, 75 | const int k) const 76 | { 77 | std::vector indices; 78 | 79 | if (!m_root || k < 0) { 80 | return indices; 81 | } 82 | 83 | if (k > m_points.size()) { 84 | indices.resize(m_points.size()); 85 | std::iota(indices.begin(), indices.end(), 0); 86 | std::sort(indices.begin(), indices.end(), [&queryPoint, this](const int idx1, const int idx2) { 87 | return m_distFunc(queryPoint, m_points[idx1], m_valueAtFunc) < 88 | m_distFunc(queryPoint, m_points[idx2], m_valueAtFunc); 89 | }); 90 | 91 | return indices; 92 | } 93 | 94 | double maxDistance = std::numeric_limits::max(); 95 | KNNPriorityQueue knnMaxHeap; 96 | this->knnSearchRec(queryPoint, m_root, knnMaxHeap, maxDistance, k); 97 | 98 | indices.reserve(k); 99 | while (!knnMaxHeap.empty()) { 100 | indices.emplace_back(knnMaxHeap.top().second); 101 | knnMaxHeap.pop(); 102 | } 103 | 104 | return indices; 105 | } 106 | 107 | template 108 | typename KDTree::KDNode* 109 | KDTree::insertRec(std::vector& pointIndices, const int depth) 110 | { 111 | if (pointIndices.empty()) { 112 | return nullptr; 113 | } 114 | 115 | const int medianIdx = pointIndices.size() / 2; 116 | const int splitType = depth % POINT_DIMENSION; 117 | std::nth_element(pointIndices.begin(), pointIndices.begin() + medianIdx, pointIndices.end(), 118 | [&](const int lhs, const int rhs) { 119 | return (m_valueAtFunc(m_points[lhs], splitType) < m_valueAtFunc(m_points[rhs], splitType)); 120 | }); 121 | 122 | KDNode* curNode = new KDNode(); 123 | curNode->pointIdx = pointIndices[medianIdx]; 124 | curNode->splitType = splitType; 125 | 126 | std::vector 
leftIndices(pointIndices.begin(), pointIndices.begin() + medianIdx); 127 | std::vector rightIndices(pointIndices.begin() + medianIdx + 1, pointIndices.end()); 128 | curNode->child[0] = this->insertRec(leftIndices, depth + 1); 129 | curNode->child[1] = this->insertRec(rightIndices, depth + 1); 130 | 131 | return curNode; 132 | } 133 | 134 | template 135 | void KDTree::clearRec(KDNode* node) 136 | { 137 | if (!node) { 138 | return; 139 | } 140 | 141 | if (node->child[0]) { 142 | this->clearRec(node->child[0]); 143 | } 144 | 145 | if (node->child[1]) { 146 | this->clearRec(node->child[1]); 147 | } 148 | 149 | delete node; 150 | } 151 | 152 | template 153 | void KDTree::nearestSearchRec(const PointType& queryPoint, 154 | const KDNode* node, int& nearestPointIdx, 155 | double& nearestDist) const 156 | { 157 | if (!node) { 158 | return; 159 | } 160 | 161 | const PointType& curPoint = m_points[node->pointIdx]; 162 | const double curDist = m_distFunc(queryPoint, curPoint, m_valueAtFunc); 163 | if (curDist < nearestDist) { 164 | nearestPointIdx = node->pointIdx; 165 | nearestDist = curDist; 166 | } 167 | 168 | const int splitType = node->splitType; 169 | const int searchDirection = m_valueAtFunc(queryPoint, splitType) < m_valueAtFunc(curPoint, splitType) ? 
0 : 1; 170 | this->nearestSearchRec(queryPoint, node->child[searchDirection], nearestPointIdx, nearestDist); 171 | 172 | const double distToTheRemainingPlane = 173 | std::fabs(m_valueAtFunc(queryPoint, splitType) - m_valueAtFunc(curPoint, splitType)); 174 | if (distToTheRemainingPlane < nearestDist) { 175 | this->nearestSearchRec(queryPoint, node->child[!searchDirection], nearestPointIdx, nearestDist); 176 | } 177 | } 178 | 179 | template 180 | void KDTree::radiusSearchRec(const PointType& queryPoint, 181 | const KDNode* node, std::vector& indices, 182 | const double radius) const 183 | { 184 | if (!node) { 185 | return; 186 | } 187 | 188 | const PointType& curPoint = m_points[node->pointIdx]; 189 | const double curDist = m_distFunc(queryPoint, curPoint, m_valueAtFunc); 190 | if (curDist < radius) { 191 | indices.emplace_back(node->pointIdx); 192 | } 193 | 194 | const int splitType = node->splitType; 195 | const int searchDirection = m_valueAtFunc(queryPoint, splitType) < m_valueAtFunc(curPoint, splitType) ? 
0 : 1; 196 | this->radiusSearchRec(queryPoint, node->child[searchDirection], indices, radius); 197 | 198 | const double distToTheRemainingPlane = 199 | std::fabs(m_valueAtFunc(queryPoint, splitType) - m_valueAtFunc(curPoint, splitType)); 200 | if (distToTheRemainingPlane < radius) { 201 | this->radiusSearchRec(queryPoint, node->child[!searchDirection], indices, radius); 202 | } 203 | } 204 | 205 | template 206 | void KDTree::knnSearchRec(const PointType& queryPoint, const KDNode* node, 207 | KNNPriorityQueue& knnMaxHeap, 208 | double& maxDistance, const int k) const 209 | { 210 | if (!node) { 211 | return; 212 | } 213 | 214 | if (knnMaxHeap.size() == k) { 215 | maxDistance = knnMaxHeap.top().first; 216 | } 217 | 218 | const PointType& curPoint = m_points[node->pointIdx]; 219 | const double curDist = m_distFunc(queryPoint, curPoint, m_valueAtFunc); 220 | 221 | if (curDist < maxDistance) { 222 | while (knnMaxHeap.size() >= k) { 223 | knnMaxHeap.pop(); 224 | } 225 | 226 | knnMaxHeap.emplace(std::make_pair(curDist, node->pointIdx)); 227 | } 228 | 229 | const int splitType = node->splitType; 230 | const int searchDirection = m_valueAtFunc(queryPoint, splitType) < m_valueAtFunc(curPoint, splitType) ? 
0 : 1; 231 | this->knnSearchRec(queryPoint, node->child[searchDirection], knnMaxHeap, maxDistance, k); 232 | 233 | const double distToTheRemainingPlane = 234 | std::fabs(m_valueAtFunc(queryPoint, splitType) - m_valueAtFunc(curPoint, splitType)); 235 | if (knnMaxHeap.size() < k || distToTheRemainingPlane < maxDistance) { 236 | this->knnSearchRec(queryPoint, node->child[!searchDirection], knnMaxHeap, maxDistance, k); 237 | } 238 | } 239 | 240 | //------------------------------------------------------------------------------ 241 | // kdnode data structure 242 | //------------------------------------------------------------------------------ 243 | 244 | template 245 | struct KDTree::KDNode { 246 | KDNode(); 247 | 248 | int pointIdx; // index of the point this node is pointing to 249 | int splitType; 250 | KDNode* child[2]; // left , right nodes 251 | }; 252 | 253 | template 254 | KDTree::KDNode::KDNode() 255 | : pointIdx(-1) 256 | , splitType(-1) 257 | { 258 | child[0] = nullptr; 259 | child[1] = nullptr; 260 | } 261 | } // namespace clustering 262 | --------------------------------------------------------------------------------