├── .gitignore
├── README.md
├── knn_cuda
│   ├── __init__.py
│   └── csrc
│       └── cuda
│           ├── knn.cpp
│           └── knn.cu
├── makefile
├── ninja
├── requirements.txt
├── setup.py
├── test_knn_cuda.log
└── tests
    └── test_knn_cuda.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | _ext
3 | __pycache__
4 | dist
5 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # KNN_CUDA
2 | 
3 | + ref: [kNN-CUDA](https://github.com/vincentfpgarcia/kNN-CUDA)
4 | + ref: [pytorch knn cuda](https://github.com/chrischoy/pytorch_knn_cuda)
5 | + author: [sli@mail.bnu.edu.cn](mailto:sli@mail.bnu.edu.cn)
6 | 
7 | 
8 | #### Modifications
9 | + ATen support
10 | + pytorch v1.0+ support
11 | + pytorch c++ extension
12 | 
13 | #### Performance
14 | 
15 | + dim = 5
16 | + k = 100
17 | + ref = 224
18 | + query = 224
19 | + Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz
20 | + NVIDIA GeForce 940MX
21 | 
22 | | Loop | sklearn | CUDA | Memory (used/total) |
23 | | :---: | :---: | :---: | :---: |
24 | | 100 | 2.34 ms | 0.06 ms | 652/1024 |
25 | | 1000 | 2.30 ms | 1.40 ms | 652/1024 |
26 | 
27 | 
28 | #### Install
29 | 
30 | 
31 | + from source
32 | 
33 | ```bash
34 | git clone https://github.com/unlimblue/KNN_CUDA.git
35 | cd KNN_CUDA
36 | make && make install
37 | ```
38 | 
39 | + from wheel
40 | 
41 | ```bash
42 | pip install --upgrade https://github.com/unlimblue/KNN_CUDA/releases/download/0.2/KNN_CUDA-0.2-py3-none-any.whl
43 | ```
44 | Then make sure [`ninja`](https://ninja-build.org/) has been installed:
45 | 1. see [https://pytorch.org/tutorials/advanced/cpp_extension.html](https://pytorch.org/tutorials/advanced/cpp_extension.html)
46 | 2. **or just**:
47 | ```bash
48 | wget -P /usr/bin https://github.com/unlimblue/KNN_CUDA/raw/master/ninja
49 | ```
50 | 
51 | + for windows
52 | 
53 | You should use branch `windows`:
54 | 
55 | ```bash
56 | git clone --branch windows https://github.com/unlimblue/KNN_CUDA.git
57 | cd C:\\PATH_TO_KNN_CUDA
58 | make
59 | make install
60 | ```
61 | 
62 | #### Usage
63 | 
64 | ```python
65 | import torch
66 | 
67 | # Make sure CUDA is available.
68 | assert torch.cuda.is_available()
69 | 
70 | from knn_cuda import KNN
71 | """
72 | if transpose_mode is True,
73 |     ref   is Tensor [bs x nr x dim]
74 |     query is Tensor [bs x nq x dim]
75 | 
76 |     return
77 |         dist is Tensor [bs x nq x k]
78 |         indx is Tensor [bs x nq x k]
79 | else
80 |     ref   is Tensor [bs x dim x nr]
81 |     query is Tensor [bs x dim x nq]
82 | 
83 |     return
84 |         dist is Tensor [bs x k x nq]
85 |         indx is Tensor [bs x k x nq]
86 | """
87 | 
88 | knn = KNN(k=10, transpose_mode=True)
89 | 
90 | ref = torch.rand(32, 1000, 5).cuda()
91 | query = torch.rand(32, 50, 5).cuda()
92 | 
93 | dist, indx = knn(ref, query)  # dist, indx: [32 x 50 x 10]
94 | ```
95 | 
--------------------------------------------------------------------------------
/knn_cuda/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from torch.utils.cpp_extension import load
5 | 
6 | 
7 | __version__ = "0.2"
8 | 
9 | 
10 | def load_cpp_ext(ext_name):
11 |     root_dir = os.path.join(os.path.split(__file__)[0])
12 |     ext_csrc = os.path.join(root_dir, "csrc")
13 |     ext_path = os.path.join(ext_csrc, "_ext", ext_name)
14 |     os.makedirs(ext_path, exist_ok=True)
15 |     assert torch.cuda.is_available(), "torch.cuda.is_available() is False."
16 |     ext_sources = [
17 |         os.path.join(ext_csrc, "cuda", "{}.cpp".format(ext_name)),
18 |         os.path.join(ext_csrc, "cuda", "{}.cu".format(ext_name))
19 |     ]
20 |     extra_cuda_cflags = [
21 |         "-DCUDA_HAS_FP16=1",
22 |         "-D__CUDA_NO_HALF_OPERATORS__",
23 |         "-D__CUDA_NO_HALF_CONVERSIONS__",
24 |         "-D__CUDA_NO_HALF2_OPERATORS__",
25 |     ]
26 |     ext = load(
27 |         name=ext_name,
28 |         sources=ext_sources,
29 |         extra_cflags=["-O2"],
30 |         build_directory=ext_path,
31 |         extra_cuda_cflags=extra_cuda_cflags,
32 |         verbose=False,
33 |         with_cuda=True
34 |     )
35 |     return ext
36 | 
37 | 
38 | _knn = load_cpp_ext("knn")
39 | 
40 | 
41 | def knn(ref, query, k):
42 |     d, i = _knn.knn(ref, query, k)
43 |     i -= 1  # the CUDA kernel stores 1-based indices; shift to 0-based
44 |     return d, i
45 | 
46 | 
47 | def _T(t, mode=False):
48 |     if mode:
49 |         return t.transpose(0, 1).contiguous()
50 |     else:
51 |         return t
52 | 
53 | 
54 | class KNN(nn.Module):
55 | 
56 |     def __init__(self, k, transpose_mode=False):
57 |         super(KNN, self).__init__()
58 |         self.k = k
59 |         self._t = transpose_mode
60 | 
61 |     def forward(self, ref, query):
62 |         assert ref.size(0) == query.size(0), "ref.shape={} != query.shape={}".format(ref.shape, query.shape)
63 |         with torch.no_grad():
64 |             batch_size = ref.size(0)
65 |             D, I = [], []
66 |             for bi in range(batch_size):
67 |                 r, q = _T(ref[bi], self._t), _T(query[bi], self._t)
68 |                 d, i = knn(r.float(), q.float(), self.k)
69 |                 d, i = _T(d, self._t), _T(i, self._t)
70 |                 D.append(d)
71 |                 I.append(i)
72 |             D = torch.stack(D, dim=0)
73 |             I = torch.stack(I, dim=0)
74 |             return D, I
75 | 
76 | 
--------------------------------------------------------------------------------
/knn_cuda/csrc/cuda/knn.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <ATen/cuda/CUDAContext.h>
3 | #include <vector>
4 | 
5 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
6 | #define CHECK_TYPE(x, t) AT_ASSERTM(x.dtype() == t, #x " must be " #t)
7 | #define CHECK_CUDA(x) AT_ASSERTM(x.device().type() == at::Device::Type::CUDA, #x " must be on CUDA")
8 | #define CHECK_INPUT(x, t) CHECK_CONTIGUOUS(x); CHECK_TYPE(x, t); CHECK_CUDA(x)
9 | 
10 | 
11 | void knn_device(
12 |     float* ref_dev,
13 |     int ref_nb,
14 |     float* query_dev,
15 |     int query_nb,
16 |     int dim,
17 |     int k,
18 |     float* dist_dev,
19 |     long* ind_dev,
20 |     cudaStream_t stream
21 | );
22 | 
23 | std::vector<at::Tensor> knn(
24 |     at::Tensor & ref,
25 |     at::Tensor & query,
26 |     const int k
27 | ){
28 | 
29 |     CHECK_INPUT(ref, at::kFloat);
30 |     CHECK_INPUT(query, at::kFloat);
31 |     int dim = ref.size(0);
32 |     int ref_nb = ref.size(1);
33 |     int query_nb = query.size(1);
34 |     float * ref_dev = ref.data<float>();
35 |     float * query_dev = query.data<float>();
36 |     auto dist = at::empty({ref_nb, query_nb}, query.options().dtype(at::kFloat));
37 |     auto ind = at::empty({k, query_nb}, query.options().dtype(at::kLong));
38 |     float * dist_dev = dist.data<float>();
39 |     long * ind_dev = ind.data<long>();
40 | 
41 |     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
42 | 
43 |     knn_device(
44 |         ref_dev,
45 |         ref_nb,
46 |         query_dev,
47 |         query_nb,
48 |         dim,
49 |         k,
50 |         dist_dev,
51 |         ind_dev,
52 |         stream
53 |     );
54 | 
55 |     return {dist.slice(0, 0, k), ind};
56 | }
57 | 
58 | 
59 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
60 |     m.def("knn", &knn, "KNN cuda version");
61 | }
62 | 
--------------------------------------------------------------------------------
/knn_cuda/csrc/cuda/knn.cu:
--------------------------------------------------------------------------------
1 | /** Modified version of knn-CUDA from https://github.com/vincentfpgarcia/kNN-CUDA
2 | *
The modifications are 3 | * removed texture memory usage 4 | * removed split query KNN computation 5 | * added feature extraction with bilinear interpolation 6 | * 7 | * Last modified by Christopher B. Choy 12/23/2016 8 | */ 9 | 10 | // Includes 11 | #include 12 | #include "cuda.h" 13 | 14 | // Constants used by the program 15 | #define BLOCK_DIM 16 16 | #define DEBUG 0 17 | 18 | /** 19 | * Computes the distance between two matrix A (reference points) and 20 | * B (query points) containing respectively wA and wB points. 21 | * 22 | * @param A pointer on the matrix A 23 | * @param wA width of the matrix A = number of points in A 24 | * @param B pointer on the matrix B 25 | * @param wB width of the matrix B = number of points in B 26 | * @param dim dimension of points = height of matrices A and B 27 | * @param AB pointer on the matrix containing the wA*wB distances computed 28 | */ 29 | __global__ void cuComputeDistanceGlobal( float* A, int wA, 30 | float* B, int wB, int dim, float* AB){ 31 | 32 | // Declaration of the shared memory arrays As and Bs used to store the sub-matrix of A and B 33 | __shared__ float shared_A[BLOCK_DIM][BLOCK_DIM]; 34 | __shared__ float shared_B[BLOCK_DIM][BLOCK_DIM]; 35 | 36 | // Sub-matrix of A (begin, step, end) and Sub-matrix of B (begin, step) 37 | __shared__ int begin_A; 38 | __shared__ int begin_B; 39 | __shared__ int step_A; 40 | __shared__ int step_B; 41 | __shared__ int end_A; 42 | 43 | // Thread index 44 | int tx = threadIdx.x; 45 | int ty = threadIdx.y; 46 | 47 | // Other variables 48 | float tmp; 49 | float ssd = 0; 50 | 51 | // Loop parameters 52 | begin_A = BLOCK_DIM * blockIdx.y; 53 | begin_B = BLOCK_DIM * blockIdx.x; 54 | step_A = BLOCK_DIM * wA; 55 | step_B = BLOCK_DIM * wB; 56 | end_A = begin_A + (dim-1) * wA; 57 | 58 | // Conditions 59 | int cond0 = (begin_A + tx < wA); // used to write in shared memory 60 | int cond1 = (begin_B + tx < wB); // used to write in shared memory & to computations and to write in output matrix 61 | int cond2 = (begin_A + ty < wA); // used to computations and to write in output matrix 62 | 63 | // Loop over all the sub-matrices of A and B required to compute the block sub-matrix 64 | for (int a = begin_A, b = begin_B; a <= end_A; a += step_A, b += step_B) { 65 | // Load the matrices from device memory to shared memory; each thread loads one element of each matrix 66 | if (a/wA + ty < dim){ 67 | shared_A[ty][tx] = (cond0)? A[a + wA * ty + tx] : 0; 68 | shared_B[ty][tx] = (cond1)? B[b + wB * ty + tx] : 0; 69 | } 70 | else{ 71 | shared_A[ty][tx] = 0; 72 | shared_B[ty][tx] = 0; 73 | } 74 | 75 | // Synchronize to make sure the matrices are loaded 76 | __syncthreads(); 77 | 78 | // Compute the difference between the two matrixes; each thread computes one element of the block sub-matrix 79 | if (cond2 && cond1){ 80 | for (int k = 0; k < BLOCK_DIM; ++k){ 81 | tmp = shared_A[k][ty] - shared_B[k][tx]; 82 | ssd += tmp*tmp; 83 | } 84 | } 85 | 86 | // Synchronize to make sure that the preceding computation is done before loading two new sub-matrices of A and B in the next iteration 87 | __syncthreads(); 88 | } 89 | 90 | // Write the block sub-matrix to device memory; each thread writes one element 91 | if (cond2 && cond1) 92 | AB[(begin_A + ty) * wB + begin_B + tx] = ssd; 93 | } 94 | 95 | 96 | /** 97 | * Gathers k-th smallest distances for each column of the distance matrix in the top. 
98 | * 99 | * @param dist distance matrix 100 | * @param ind index matrix 101 | * @param width width of the distance matrix and of the index matrix 102 | * @param height height of the distance matrix and of the index matrix 103 | * @param k number of neighbors to consider 104 | */ 105 | __global__ void cuInsertionSort(float *dist, long *ind, int width, int height, int k){ 106 | 107 | // Variables 108 | int l, i, j; 109 | float *p_dist; 110 | long *p_ind; 111 | float curr_dist, max_dist; 112 | long curr_row, max_row; 113 | unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x; 114 | if (xIndexcurr_dist){ 129 | i=a; 130 | break; 131 | } 132 | } 133 | for (j=l; j>i; j--){ 134 | p_dist[j*width] = p_dist[(j-1)*width]; 135 | p_ind[j*width] = p_ind[(j-1)*width]; 136 | } 137 | p_dist[i*width] = curr_dist; 138 | p_ind[i*width] = l + 1; 139 | } else { 140 | p_ind[l*width] = l + 1; 141 | } 142 | max_dist = p_dist[curr_row]; 143 | } 144 | 145 | // Part 2 : insert element in the k-th first lines 146 | max_row = (k-1)*width; 147 | for (l=k; lcurr_dist){ 153 | i=a; 154 | break; 155 | } 156 | } 157 | for (j=k-1; j>i; j--){ 158 | p_dist[j*width] = p_dist[(j-1)*width]; 159 | p_ind[j*width] = p_ind[(j-1)*width]; 160 | } 161 | p_dist[i*width] = curr_dist; 162 | p_ind[i*width] = l + 1; 163 | max_dist = p_dist[max_row]; 164 | } 165 | } 166 | } 167 | } 168 | 169 | 170 | /** 171 | * Computes the square root of the first line (width-th first element) 172 | * of the distance matrix. 173 | * 174 | * @param dist distance matrix 175 | * @param width width of the distance matrix 176 | * @param k number of neighbors to consider 177 | */ 178 | __global__ void cuParallelSqrt(float *dist, int width, int k){ 179 | unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x; 180 | unsigned int yIndex = blockIdx.y * blockDim.y + threadIdx.y; 181 | if (xIndex>>(ref_dev, ref_nb, 252 | query_dev, query_nb, dim, dist_dev); 253 | 254 | #if DEBUG 255 | printf("Pre insertionSort\n"); 256 | debug(dist_dev, ind_dev, query_nb, k); 257 | #endif 258 | 259 | // Kernel 2: Sort each column 260 | cuInsertionSort<<>>(dist_dev, ind_dev, query_nb, ref_nb, k); 261 | 262 | #if DEBUG 263 | printf("Post insertionSort\n"); 264 | debug(dist_dev, ind_dev, query_nb, k); 265 | #endif 266 | 267 | // Kernel 3: Compute square root of k first elements 268 | cuParallelSqrt<<>>(dist_dev, query_nb, k); 269 | } 270 | -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | .PHONY : build reqs install clean 2 | NINJA := $(shell command -v ninja 2> /dev/null) 3 | 4 | 5 | build : reqs 6 | python3 setup.py bdist_wheel 7 | 8 | reqs : 9 | ifndef NINJA 10 | sudo cp ./ninja /usr/bin 11 | endif 12 | pip3 install -r requirements.txt 13 | 14 | install : 15 | pip3 install --upgrade dist/*.whl 16 | 17 | clean : 18 | -rm -rf build dist/* *.egg-info 19 | -------------------------------------------------------------------------------- /ninja: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unlimblue/KNN_CUDA/619617b51ce2df785e276164cd8cac0234eb8e8c/ninja -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scikit_learn 3 | torch>=1.1.0 4 | -------------------------------------------------------------------------------- /setup.py: 
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup, find_packages
3 | from knn_cuda import __version__
4 | 
5 | 
6 | with open('requirements.txt') as f:
7 |     required = f.read().splitlines()
8 | 
9 | setup(
10 |     name='KNN_CUDA',
11 |     version=__version__,
12 |     description='PyTorch KNN with CUDA support.',
13 |     author='Shuaipeng Li',
14 |     author_email='sli@mail.bnu.edu.cn',
15 |     packages=find_packages(),
16 |     package_data={
17 |         'knn_cuda': ["csrc/cuda/knn.cu", "csrc/cuda/knn.cpp"]
18 |     },
19 |     install_requires=required
20 | )
21 | 
22 | 
--------------------------------------------------------------------------------
/test_knn_cuda.log:
--------------------------------------------------------------------------------
1 | ============================= test session starts ==============================
2 | platform linux -- Python 3.6.6, pytest-4.0.1, py-1.7.0, pluggy-0.8.0 -- /home/shuli/Env/py36/Python/virtual/env/bin/python
3 | cachedir: .pytest_cache
4 | benchmark: 3.1.1 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
5 | rootdir: /home/shuli/Repo/KNN, inifile:
6 | plugins: benchmark-3.1.1
7 | collecting ... collected 11 items
8 | 
9 | tests/test_knn_cuda.py::TestKNNCuda::test_knn_cuda_performance PASSED   [  9%]
10 | tests/test_knn_cuda.py::TestKNNCuda::test_knn_cuda_400_5_1000 PASSED    [ 18%]
11 | tests/test_knn_cuda.py::TestKNNCuda::test_knn_cuda_400_5_100 PASSED     [ 27%]
12 | tests/test_knn_cuda.py::TestKNNCuda::test_knn_cuda_400_5_10 PASSED      [ 36%]
13 | tests/test_knn_cuda.py::TestKNNCuda::test_knn_cuda_400_5_1001 PASSED    [ 45%]
14 | tests/test_knn_cuda.py::TestKNNCuda::test_knn_cuda_400_5_101 PASSED     [ 54%]
15 | tests/test_knn_cuda.py::TestKNNCuda::test_knn_cuda_400_5_11 PASSED      [ 63%]
16 | tests/test_knn_cuda.py::TestKNNCuda::test_knn_cuda_400_5_300000_50 PASSED [ 72%]
17 | tests/test_knn_cuda.py::TestKNNCuda::test_knn_cuda_400_5_300001_50 PASSED [ 81%]
18 | tests/test_knn_cuda.py::TestKNNCuda::test_knn_cuda_400_5_10000 PASSED   [ 90%]
19 | tests/test_knn_cuda.py::TestKNNCuda::test_knn_cuda_400_5_10001 PASSED   [100%]
20 | 
21 | 
22 | -------------------------------------------------- benchmark: 1 tests -------------------------------------------------
23 | Name (time in ms)              Min        Max       Mean    StdDev     Median       IQR  Outliers      OPS  Rounds  Iterations
24 | -----------------------------------------------------------------------------------------------------------------------
25 | test_knn_cuda_performance  10.8398    10.8665    10.8496    0.0108    10.8486    0.0153       1;0  92.1694       5           1
26 | -----------------------------------------------------------------------------------------------------------------------
27 | 
28 | Legend:
29 |   Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
30 |   OPS: Operations Per Second, computed as 1 / Mean
31 | ========================= 11 passed in 115.03 seconds ==========================
--------------------------------------------------------------------------------
/tests/test_knn_cuda.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from sklearn.neighbors import KDTree
4 | from knn_cuda import KNN
5 | 
6 | 
7 | def t2n(t):
8 |     return t.detach().cpu().numpy()
9 | 
10 | 
11 | def run_kdtree(ref, query, k):
12 |     bs = ref.shape[0]
13 |     D, I = [], []
14 |     for j in range(bs):
15 |         tree = KDTree(ref[j], leaf_size=100)
16 |         d, i = tree.query(query[j], k=k)
17 |         D.append(d)
18 |         I.append(i)
19 |     D = np.stack(D)
20 |     I = np.stack(I)
21 |     return D, I
22 | 
23 | 
24 | def run_knnCuda(ref, query, k):
25 |     ref = torch.from_numpy(ref).float().cuda()
26 |     query = torch.from_numpy(query).float().cuda()
27 |     knn = KNN(k, transpose_mode=True)
28 |     d, i = knn(ref, query)
29 |     return t2n(d), t2n(i)
30 | 
31 | 
32 | def compare(k, dim, n1, n2=-1):
33 |     if n2 < 0:
34 |         n2 = n1
35 |     for _ in range(5):
36 |         ref = np.random.random((2, n1, dim))
37 |         query = np.random.random((2, n2, dim))
38 | 
39 |         kd_dist, kd_indices = run_kdtree(ref, query, k)
40 |         kn_dist, kn_indices = run_knnCuda(ref, query, k)
41 | 
42 |         # diff = (kd_indices - kn_indices) != 0
43 |         # print(kd_dist[diff])
44 |         # print(kn_dist[diff])
45 | 
46 |         np.testing.assert_almost_equal(kd_dist, kn_dist, decimal=3)
47 |         # np.testing.assert_array_equal(kd_indices, kn_indices)
48 | 
49 | 
50 | class TestKNNCuda:
51 | 
52 |     def test_knn_cuda_performance(self, benchmark):
53 |         dim = 5
54 |         k = 100
55 |         ref = np.random.random((1, 224, dim))
56 |         query = np.random.random((1, 224, dim))
57 |         benchmark(run_knnCuda, ref, query, k)
58 | 
59 |     def test_knn_cuda_400_5_1000(self):
60 |         compare(400, 5, 1000)
61 | 
62 |     def test_knn_cuda_400_5_100(self):
63 |         compare(10, 5, 100)
64 | 
65 |     def test_knn_cuda_400_5_10(self):
66 |         compare(2, 5, 10)
67 | 
68 |     def test_knn_cuda_400_5_1001(self):
69 |         compare(400, 5, 1001)
70 | 
71 |     def test_knn_cuda_400_5_101(self):
72 |         compare(10, 5, 101)
73 | 
74 |     def test_knn_cuda_400_5_11(self):
75 |         compare(2, 5, 11)
76 | 
77 |     def test_knn_cuda_400_5_300000_50(self):
78 |         compare(400, 5, 30000, 50)
79 | 
80 |     def test_knn_cuda_400_5_300001_50(self):
81 |         compare(400, 5, 30001, 50)
82 | 
83 |     def test_knn_cuda_400_5_10000(self):
84 |         compare(400, 5, 10000)
85 | 
86 |     def test_knn_cuda_400_5_10001(self):
87 |         compare(400, 5, 10001)
88 | 
--------------------------------------------------------------------------------
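The test suite above validates the CUDA distances against `sklearn.neighbors.KDTree`. A brute-force reference in plain PyTorch (`torch.cdist` plus `torch.topk`) is another convenient cross-check and also documents the `transpose_mode` shape convention from the README. This is a minimal sketch; `knn_reference` and the check below are illustrative and not part of the package.

```python
import torch
from knn_cuda import KNN


def knn_reference(ref, query, k, transpose_mode=True):
    """Brute-force KNN in plain PyTorch, mirroring knn_cuda.KNN's shape convention.

    transpose_mode=True:  ref [bs x nr x dim], query [bs x nq x dim] -> dist, indx [bs x nq x k]
    transpose_mode=False: ref [bs x dim x nr], query [bs x dim x nq] -> dist, indx [bs x k x nq]
    """
    if not transpose_mode:
        ref, query = ref.transpose(1, 2), query.transpose(1, 2)
    dist = torch.cdist(query, ref)                             # [bs x nq x nr] Euclidean distances
    knn_dist, knn_indx = dist.topk(k, dim=-1, largest=False)   # k smallest per query point, ascending
    if not transpose_mode:
        knn_dist, knn_indx = knn_dist.transpose(1, 2), knn_indx.transpose(1, 2)
    return knn_dist, knn_indx


if torch.cuda.is_available():
    ref = torch.rand(2, 1000, 5).cuda()
    query = torch.rand(2, 50, 5).cuda()
    d_cuda, i_cuda = KNN(k=10, transpose_mode=True)(ref, query)
    d_ref, _ = knn_reference(ref, query, k=10)
    # Distances should agree closely; indices may differ where distances nearly tie.
    assert torch.allclose(d_cuda, d_ref, atol=1e-3)
```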
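For orientation, the extension decomposes the search into three kernels (see `knn.cpp` and `knn.cu`): `cuComputeDistanceGlobal` fills the full `ref_nb x query_nb` matrix of squared distances, with points stored column-wise as `[dim x n]`; `cuInsertionSort` keeps the `k` smallest entries of each column in ascending order and records 1-based row numbers (which is why `knn_cuda/__init__.py` subtracts 1 from the returned indices); and `cuParallelSqrt` takes the square root of the `k` surviving rows only. The NumPy sketch below mirrors that pipeline for clarity rather than speed; `knn_pipeline_sketch` is not part of the package.

```python
import numpy as np


def knn_pipeline_sketch(ref, query, k):
    """NumPy equivalent of the three-kernel pipeline in knn.cu.

    ref:   [dim x nr]  (column-major point layout, as knn.cpp expects)
    query: [dim x nq]
    Returns dist [k x nq], ind [k x nq] with 0-based indices
    (the kernel itself stores 1-based indices; __init__.py shifts them).
    """
    # Kernel 1 (cuComputeDistanceGlobal): full matrix of squared distances,
    # one row per reference point, one column per query point.
    diff = ref[:, :, None] - query[:, None, :]        # [dim x nr x nq]
    sq_dist = (diff ** 2).sum(axis=0)                 # [nr x nq]
    # Kernel 2 (cuInsertionSort): per column, keep the k smallest rows in order.
    # (A full argsort here; the kernel only maintains the top-k incrementally.)
    order = np.argsort(sq_dist, axis=0)[:k, :]        # [k x nq]
    kept = np.take_along_axis(sq_dist, order, axis=0)
    # Kernel 3 (cuParallelSqrt): square root of the k kept rows only.
    return np.sqrt(kept), order
```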
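`compare()` in the tests leaves the index assertion commented out, presumably because near-equal distances (float32 on the GPU versus float64 in the KDTree) can legitimately order neighbors differently. Under that assumption, a tie-tolerant variant could compare indices only where distances are unambiguous; `assert_same_neighbors` below is a sketch, not part of the test suite.

```python
import numpy as np


def assert_same_neighbors(kd_dist, kd_indices, kn_dist, kn_indices, decimal=3):
    """Tie-tolerant variant of the commented-out index assertion in compare().

    Distances must match to `decimal` places everywhere; indices are only
    required to match at positions whose (rounded) distance is unique within
    the row, so tied or near-tied neighbors are not reported as failures.
    Inputs are [bs x nq x k] arrays, as produced by run_kdtree / run_knnCuda.
    """
    np.testing.assert_almost_equal(kd_dist, kn_dist, decimal=decimal)
    rounded = np.round(kd_dist, decimal)
    for b in range(kd_dist.shape[0]):
        for q in range(kd_dist.shape[1]):
            row = rounded[b, q]
            # Keep only positions whose distance value appears exactly once in the row.
            unique_mask = np.array([np.sum(row == v) == 1 for v in row])
            np.testing.assert_array_equal(
                kd_indices[b, q][unique_mask], kn_indices[b, q][unique_mask]
            )
```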