├── simple_knn └── .gitkeep ├── .gitignore ├── spatial.h ├── ext.cpp ├── simple_knn.h ├── README.md ├── setup.py ├── spatial.cu └── simple_knn.cu /simple_knn/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | simple_knn.egg-info/ 3 | dist/ 4 | simple_knn/*.so -------------------------------------------------------------------------------- /spatial.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | 14 | torch::Tensor distCUDA2(const torch::Tensor& points); 15 | -------------------------------------------------------------------------------- /ext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | #include "spatial.h" 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("distCUDA2", &distCUDA2); 17 | } 18 | -------------------------------------------------------------------------------- /simple_knn.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef SIMPLEKNN_H_INCLUDED 13 | #define SIMPLEKNN_H_INCLUDED 14 | 15 | class SimpleKNN 16 | { 17 | public: 18 | static void knn(int P, float3* points, float* meanDists); 19 | }; 20 | 21 | #endif -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | simple-knn 2 | --- 3 | 4 | Description: It compute the **average distance to the nearest neighbors** for a set of 3D points. 5 | 6 | Install: 7 | ```bash 8 | pip install git+https://github.com/camenduru/simple-knn 9 | 10 | # or 11 | git clone https://github.com/camenduru/simple-knn && cd simple-knn 12 | pip install . 13 | ``` 14 | 15 | Usage: 16 | ```python 17 | from simple_knn._C import distCUDA2 18 | 19 | # shape: [N, 3] 20 | demopc = torch.from_numpy(np.load("/path")).float().cuda().contiguous() 21 | 22 | # shape: [N] 23 | mean_distances = distCUDA2(demopc) 24 | ``` -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from setuptools import setup 13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 14 | import os 15 | 16 | cxx_compiler_flags = [] 17 | 18 | if os.name == 'nt': 19 | cxx_compiler_flags.append("/wd4624") 20 | 21 | setup( 22 | name="simple_knn", 23 | ext_modules=[ 24 | CUDAExtension( 25 | name="simple_knn._C", 26 | sources=[ 27 | "spatial.cu", 28 | "simple_knn.cu", 29 | "ext.cpp"], 30 | extra_compile_args={"nvcc": [], "cxx": cxx_compiler_flags}) 31 | ], 32 | cmdclass={ 33 | 'build_ext': BuildExtension 34 | }, 35 | version='1.0.0' 36 | ) 37 | -------------------------------------------------------------------------------- /spatial.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include "spatial.h" 13 | #include "simple_knn.h" 14 | 15 | #include // Include the CUDA runtime header for cudaSetDevice() 16 | 17 | torch::Tensor 18 | distCUDA2(const torch::Tensor& points) 19 | { 20 | const int P = points.size(0); 21 | 22 | // Determine which device the tensor is on 23 | auto device = points.device(); 24 | int device_index = device.index(); // Get the index of the device 25 | 26 | // Set the current CUDA device to the device where 'points' is located 27 | cudaSetDevice(device_index); 28 | 29 | auto float_opts = points.options().dtype(torch::kFloat32); 30 | torch::Tensor means = torch::full({P}, 0.0, float_opts); 31 | 32 | SimpleKNN::knn(P, (float3*)points.contiguous().data_ptr(), means.contiguous().data_ptr()); 33 | 34 | return means; 35 | } 36 | -------------------------------------------------------------------------------- /simple_knn.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #define BOX_SIZE 1024 13 | 14 | #include "cuda_runtime.h" 15 | #include "device_launch_parameters.h" 16 | #include "simple_knn.h" 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #define __CUDACC__ 25 | #include 26 | #include 27 | 28 | namespace cg = cooperative_groups; 29 | 30 | struct CustomMin 31 | { 32 | __device__ __forceinline__ 33 | float3 operator()(const float3& a, const float3& b) const { 34 | return { min(a.x, b.x), min(a.y, b.y), min(a.z, b.z) }; 35 | } 36 | }; 37 | 38 | struct CustomMax 39 | { 40 | __device__ __forceinline__ 41 | float3 operator()(const float3& a, const float3& b) const { 42 | return { max(a.x, b.x), max(a.y, b.y), max(a.z, b.z) }; 43 | } 44 | }; 45 | 46 | __host__ __device__ uint32_t prepMorton(uint32_t x) 47 | { 48 | x = (x | (x << 16)) & 0x030000FF; 49 | x = (x | (x << 8)) & 0x0300F00F; 50 | x = (x | (x << 4)) & 0x030C30C3; 51 | x = (x | (x << 2)) & 0x09249249; 52 | return x; 53 | } 54 | 55 | __host__ __device__ uint32_t coord2Morton(float3 coord, float3 minn, float3 maxx) 56 | { 57 | uint32_t x = prepMorton(((coord.x - minn.x) / (maxx.x - minn.x)) * ((1 << 10) - 1)); 58 | uint32_t y = prepMorton(((coord.y - minn.y) / (maxx.y - minn.y)) * ((1 << 10) - 1)); 59 | uint32_t z = prepMorton(((coord.z - minn.z) / (maxx.z - minn.z)) * ((1 << 10) - 1)); 60 | 61 | return x | (y << 1) | (z << 2); 62 | } 63 | 64 | __global__ void coord2Morton(int P, const float3* points, float3 minn, float3 maxx, uint32_t* codes) 65 | { 66 | auto idx = cg::this_grid().thread_rank(); 67 | if (idx >= P) 68 | return; 69 | 70 | codes[idx] = coord2Morton(points[idx], minn, maxx); 71 | } 72 | 73 | struct MinMax 74 | { 75 | float3 minn; 76 | float3 maxx; 77 | }; 78 | 79 | __global__ void boxMinMax(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes) 80 | { 81 | auto idx = cg::this_grid().thread_rank(); 82 | 83 | MinMax me; 84 | if (idx < P) 85 | { 86 | me.minn = points[indices[idx]]; 87 | me.maxx = points[indices[idx]]; 88 | } 89 | else 90 | { 91 | me.minn = { FLT_MAX, FLT_MAX, FLT_MAX }; 92 | me.maxx = { -FLT_MAX,-FLT_MAX,-FLT_MAX }; 93 | } 94 | 95 | __shared__ MinMax redResult[BOX_SIZE]; 96 | 97 | for (int off = BOX_SIZE / 2; off >= 1; off /= 2) 98 | { 99 | if (threadIdx.x < 2 * off) 100 | redResult[threadIdx.x] = me; 101 | __syncthreads(); 102 | 103 | if (threadIdx.x < off) 104 | { 105 | MinMax other = redResult[threadIdx.x + off]; 106 | me.minn.x = min(me.minn.x, other.minn.x); 107 | me.minn.y = min(me.minn.y, other.minn.y); 108 | me.minn.z = min(me.minn.z, other.minn.z); 109 | me.maxx.x = max(me.maxx.x, other.maxx.x); 110 | me.maxx.y = max(me.maxx.y, other.maxx.y); 111 | me.maxx.z = max(me.maxx.z, other.maxx.z); 112 | } 113 | __syncthreads(); 114 | } 115 | 116 | if (threadIdx.x == 0) 117 | boxes[blockIdx.x] = me; 118 | } 119 | 120 | __device__ __host__ float distBoxPoint(const MinMax& box, const float3& p) 121 | { 122 | float3 diff = { 0, 0, 0 }; 123 | if (p.x < box.minn.x || p.x > box.maxx.x) 124 | diff.x = min(abs(p.x - box.minn.x), abs(p.x - box.maxx.x)); 125 | if (p.y < box.minn.y || p.y > box.maxx.y) 126 | diff.y = min(abs(p.y - box.minn.y), abs(p.y - box.maxx.y)); 127 | if (p.z < box.minn.z || p.z > box.maxx.z) 128 | diff.z = min(abs(p.z - box.minn.z), abs(p.z - box.maxx.z)); 129 | return diff.x * diff.x + diff.y * diff.y + diff.z * diff.z; 130 | } 131 | 132 | template 133 | __device__ void updateKBest(const float3& ref, const float3& point, float* knn) 134 | { 135 | float3 d = { point.x - ref.x, point.y - ref.y, point.z - ref.z }; 136 | float dist = d.x * d.x + d.y * d.y + d.z * d.z; 137 | for (int j = 0; j < K; j++) 138 | { 139 | if (knn[j] > dist) 140 | { 141 | float t = knn[j]; 142 | knn[j] = dist; 143 | dist = t; 144 | } 145 | } 146 | } 147 | 148 | __global__ void boxMeanDist(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes, float* dists) 149 | { 150 | int idx = cg::this_grid().thread_rank(); 151 | if (idx >= P) 152 | return; 153 | 154 | float3 point = points[indices[idx]]; 155 | float best[3] = { FLT_MAX, FLT_MAX, FLT_MAX }; 156 | 157 | for (int i = max(0, idx - 3); i <= min(P - 1, idx + 3); i++) 158 | { 159 | if (i == idx) 160 | continue; 161 | updateKBest<3>(point, points[indices[i]], best); 162 | } 163 | 164 | float reject = best[2]; 165 | best[0] = FLT_MAX; 166 | best[1] = FLT_MAX; 167 | best[2] = FLT_MAX; 168 | 169 | for (int b = 0; b < (P + BOX_SIZE - 1) / BOX_SIZE; b++) 170 | { 171 | MinMax box = boxes[b]; 172 | float dist = distBoxPoint(box, point); 173 | if (dist > reject || dist > best[2]) 174 | continue; 175 | 176 | for (int i = b * BOX_SIZE; i < min(P, (b + 1) * BOX_SIZE); i++) 177 | { 178 | if (i == idx) 179 | continue; 180 | updateKBest<3>(point, points[indices[i]], best); 181 | } 182 | } 183 | dists[indices[idx]] = (best[0] + best[1] + best[2]) / 3.0f; 184 | } 185 | 186 | void SimpleKNN::knn(int P, float3* points, float* meanDists) 187 | { 188 | float3* result; 189 | cudaMalloc(&result, sizeof(float3)); 190 | size_t temp_storage_bytes; 191 | 192 | float3 init = { 0, 0, 0 }, minn, maxx; 193 | 194 | cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, points, result, P, CustomMin(), init); 195 | thrust::device_vector temp_storage(temp_storage_bytes); 196 | 197 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMin(), init); 198 | cudaMemcpy(&minn, result, sizeof(float3), cudaMemcpyDeviceToHost); 199 | 200 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMax(), init); 201 | cudaMemcpy(&maxx, result, sizeof(float3), cudaMemcpyDeviceToHost); 202 | 203 | thrust::device_vector morton(P); 204 | thrust::device_vector morton_sorted(P); 205 | coord2Morton << <(P + 255) / 256, 256 >> > (P, points, minn, maxx, morton.data().get()); 206 | 207 | thrust::device_vector indices(P); 208 | thrust::sequence(indices.begin(), indices.end()); 209 | thrust::device_vector indices_sorted(P); 210 | 211 | cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); 212 | temp_storage.resize(temp_storage_bytes); 213 | 214 | cub::DeviceRadixSort::SortPairs(temp_storage.data().get(), temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); 215 | 216 | uint32_t num_boxes = (P + BOX_SIZE - 1) / BOX_SIZE; 217 | thrust::device_vector boxes(num_boxes); 218 | boxMinMax << > > (P, points, indices_sorted.data().get(), boxes.data().get()); 219 | boxMeanDist << > > (P, points, indices_sorted.data().get(), boxes.data().get(), meanDists); 220 | 221 | cudaFree(result); 222 | } --------------------------------------------------------------------------------