├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── hnms
│   ├── __init__.py
│   ├── extension
│   │   ├── include
│   │   │   └── hnms.h
│   │   └── src
│   │       ├── cpu
│   │       │   └── hnms.cpp
│   │       ├── cuda
│   │       │   └── hnms.cu
│   │       └── hnms_module.cpp
│   └── multi_hnms.py
├── setup.py
└── test
    ├── example.py
    └── run.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
# Microsoft Open Source Code of Conduct

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).

Resources:

- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Hashing-based Non-Maximum Suppression

## Installation
```
git clone https://github.com/microsoft/hnms.git
python setup.py install
```
The code has been tested on Ubuntu 16.04 with Python 3.6, CUDA 10.1, and
PyTorch 1.4/1.5.

## Usage
```
import torch
from hnms import MultiHNMS

hnms = MultiHNMS(num=1, alpha=0.7)

# center x, center y, width, height
xywh = [[10, 20, 10, 20], [10, 20, 10, 20], [30, 6, 4, 5]]
conf = [0.9, 0.8, 0.9]
xywh = torch.tensor(xywh).float()
conf = torch.tensor(conf)
keep = hnms(xywh, conf)
print(keep)
```
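HNMS quantizes each box into a hash cell and keeps only the highest-confidence
box per cell, so two overlapping boxes that land just on either side of a cell
boundary are never compared. Increasing `num` applies several shifted hash
tables in sequence, which reduces such boundary effects. A minimal sketch of a
two-table run (the `num=2` value here is only illustrative):
```
import torch
from hnms import MultiHNMS

# Two shifted hash tables, applied one after the other; a box survives only
# if it is the best-scoring box of its cell in every table.
hnms2 = MultiHNMS(num=2, alpha=0.7)

xywh = torch.tensor([[10., 20., 10., 20.],
                     [11., 21., 10., 20.],  # near-duplicate of the first box
                     [30., 6., 4., 5.]])
conf = torch.tensor([0.9, 0.8, 0.9])
keep = hnms2(xywh, conf)  # indices of the surviving boxes
print(keep)
```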

## Reference
```
@article{DBLP:journals/corr/abs-2005-11426,
  author    = {Jianfeng Wang and
               Xi Yin and
               Lijuan Wang and
               Lei Zhang},
  title     = {Hashing-based Non-Maximum Suppression for Crowded Object Detection},
  journal   = {CoRR},
  volume    = {abs/2005.11426},
  year      = {2020},
  url       = {https://arxiv.org/abs/2005.11426},
  archivePrefix = {arXiv},
  eprint    = {2005.11426},
}
```


# Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
## Security

Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, including [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [others](https://opensource.microsoft.com/).

If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.

## Reporting Security Issues

**Please do not report security vulnerabilities through public GitHub issues.**

Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).

If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).

You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).

Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:

* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue

This information will help us triage your report more quickly.

If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.

## Preferred Languages

We prefer all communications to be in English.

## Policy

Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
--------------------------------------------------------------------------------
/hnms/__init__.py:
--------------------------------------------------------------------------------
from .multi_hnms import MultiHNMS, HNMS
--------------------------------------------------------------------------------
/hnms/extension/include/hnms.h:
--------------------------------------------------------------------------------
#pragma once
#include <ATen/ATen.h>

#ifdef WITH_CUDA
at::Tensor hnms_cuda(const at::Tensor& dets,
                     const at::Tensor& scores,
                     float w0,
                     float h0,
                     float alpha,
                     float bx,
                     float by
                     );
#endif


at::Tensor hnms_cpu(const at::Tensor& dets,
                    const at::Tensor& scores,
                    float w0,
                    float h0,
                    float alpha,
                    float bx,
                    float by
                    );


at::Tensor hnms(const at::Tensor& dets,
                const at::Tensor& scores,
                float w0,
                float h0,
                float alpha,
                float bx,
                float by
                );
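All three entry points share one signature: `dets` holds (cx, cy, w, h) boxes,
`scores` their confidences, `w0`/`h0` the reference box size, `alpha` the
log-scale quantization ratio, and `bx`/`by` the cell offsets. The hashing step
that both backends implement maps each box to an integer cell (qx, qy, i, j):
i and j quantize width and height on a log scale, and the center is quantized
with a cell size that grows with the quantized box size. A Python
transcription of that mapping, for reference only (it mirrors the C++
`hash_rects` in the next file and is not part of the package):
```
import math
import torch

def hash_rects_py(dets, w0, h0, alpha, bx, by):
    # dets: (N, 4) float tensor of (cx, cy, w, h) boxes.
    x, y, w, h = dets.unbind(1)
    alpha_ratio = (1. - alpha) / (1. + alpha)
    # Quantize sizes on a log scale with ratio alpha.
    i = torch.round((math.log(w0) - torch.log(w)) / math.log(alpha))
    j = torch.round((math.log(h0) - torch.log(h)) / math.log(alpha))
    # Cell width/height are proportional to the quantized box size.
    di = w0 * alpha_ratio / alpha ** i
    dj = h0 * alpha_ratio / alpha ** j
    qx = torch.round(x / di - bx)
    qy = torch.round(y / dj - by)
    return torch.stack([qx, qy, i, j], 1).long()
```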
--------------------------------------------------------------------------------
/hnms/extension/src/cpu/hnms.cpp:
--------------------------------------------------------------------------------
#include "hnms.h"
#include <cmath>
#include <map>
#include <vector>
#include <iostream>
#include <algorithm>


at::Tensor hash_rects(const at::Tensor& dets,
                      float w0,
                      float h0,
                      float alpha,
                      float bx,
                      float by) {
    auto log_w0 = log(w0);
    auto log_h0 = log(h0);
    auto log_alpha = log(alpha);

    // map the rects to the code
    auto x = dets.select(1, 0).contiguous();
    auto y = dets.select(1, 1).contiguous();
    auto w = dets.select(1, 2).contiguous();
    auto h = dets.select(1, 3).contiguous();
    auto alpha_ratio = (1. - alpha) / (1. + alpha);
    auto w0_alpha = w0 * alpha_ratio;
    auto h0_alpha = h0 * alpha_ratio;

    auto i = at::round((log_w0 - at::log(w)) / log_alpha);
    auto j = at::round((log_h0 - at::log(h)) / log_alpha);

    auto di = w0_alpha / at::pow(alpha, i);
    auto dj = h0_alpha / at::pow(alpha, j);

    at::Tensor qx, qy;
    qx = at::round(x / di - bx);
    qy = at::round(y / dj - by);
    auto result = at::stack({qx, qy, i, j}, 1);
    return at::_cast_Long(result).contiguous();
}

typedef long TCode;
TCode get_code(const long* p_code) {
    // pack (qx, qy, i, j) into a single integer; assumes each component
    // stays within its decimal band
    return p_code[0] + p_code[1] * 10000 +
        p_code[2] * 100000000 + p_code[3] * 1000000000000;
}

at::Tensor get_best_score_each_code(
        at::Tensor codes,
        const at::Tensor& scores) {
    // for each distinct code, the index of its highest-scoring box
    std::map<TCode, long> code_to_idx;

    auto p_code = codes.data<long>();
    auto p_score = scores.data<float>();

    auto ndets = codes.size(0);
    for (auto i = 0; i < ndets; i++) {
        auto code = get_code(p_code);
        if (code_to_idx.count(code) == 0) {
            code_to_idx[code] = i;
        } else {
            auto& pre_idx = code_to_idx[code];
            if (p_score[pre_idx] < p_score[i]) {
                pre_idx = i;
            }
        }
        p_code += 4;
    }

    at::Tensor result = at::ones({long(code_to_idx.size())},
            scores.options().dtype(at::kLong).device(at::kCPU));
    auto p = result.data<long>();
    int idx = 0;
    for (auto i = code_to_idx.begin(); i != code_to_idx.end(); i++) {
        p[idx++] = i->second;
    }

    return result;
}

at::Tensor hnms_cpu(const at::Tensor& dets,
                    const at::Tensor& scores,
                    float w0,
                    float h0,
                    float alpha,
                    float bx,
                    float by
                    ) {
    AT_ASSERTM(!dets.is_cuda(), "dets must be a CPU tensor");
    AT_ASSERTM(!scores.is_cuda(), "scores must be a CPU tensor");
    if (dets.numel() == 0) {
        return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
    }

    auto codes = hash_rects(dets, w0, h0, alpha, bx, by);

    auto result = get_best_score_each_code(codes, scores);

    return result;
}

at::Tensor hnms(const at::Tensor& dets,
                const at::Tensor& scores,
                float w0,
                float h0,
                float alpha,
                float bx,
                float by
                ) {
    if (dets.is_cuda()) {
#ifdef WITH_CUDA
        if (dets.numel() == 0)
            return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
        return hnms_cuda(dets, scores, w0, h0, alpha, bx, by);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }

    return hnms_cpu(dets, scores, w0, h0, alpha, bx, by);
}
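The CPU path reduces deduplication to one map lookup per box by packing the
four code components into a single integer with decimal bases 10^4, 10^8, and
10^12, which implicitly assumes each component stays within its band. The
selection logic itself is simple; an illustrative Python equivalent of
`get_best_score_each_code` (not part of the package):
```
import torch

def best_per_code(codes, scores):
    # codes: (N, 4) integer tensor; scores: (N,) tensor.
    # For each distinct code, keep the index of the highest-scoring box.
    best = {}
    for idx, code in enumerate(map(tuple, codes.tolist())):
        if code not in best or scores[best[code]] < scores[idx]:
            best[code] = idx
    return torch.tensor(sorted(best.values()))
```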
--------------------------------------------------------------------------------
/hnms/extension/src/cuda/hnms.cu:
--------------------------------------------------------------------------------
#ifdef WITH_CUDA
#include "hnms.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh>

#include <vector>
#include <iostream>

#include <map>
#include <chrono>
using namespace std::chrono;


#define CONF_TO_INT_MULT 1000000
#define CONF_TO_INT_ADD 100000
#define CONF_TO_INT(x) (long long)((x) * CONF_TO_INT_MULT) + CONF_TO_INT_ADD

#define CUDA_1D_KERNEL_LOOP(i, n)                            \
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
         i += blockDim.x * gridDim.x)

const int CUDA_NUM_THREADS = 512;

int const threadsPerBlock = sizeof(unsigned long long) * 8;

inline int GET_BLOCKS(const int N) {
    return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}


template <typename T>
__global__ void hnms_max_conf_kernel(long long nthreads,
                                     T* box_confs,
                                     int64_t* cell_indices,
                                     int64_t* cell_max_confs) {
    CUDA_1D_KERNEL_LOOP(i, nthreads) {
        unsigned long long conf = CONF_TO_INT(box_confs[i]);
        unsigned long long cell = cell_indices[i];
        unsigned long long* cell_max = (unsigned long long*)(cell_max_confs + cell);
        // the long long type is not supported by atomicMax, so reinterpret
        // the slot as unsigned long long
        atomicMax(cell_max, conf);
    }
}

template <typename T>
__global__ void hnms_max_idx_kernel(long long nthreads,
                                    T* box_confs,
                                    int64_t* cell_indices,
                                    int64_t* cell_max_confs) {
    CUDA_1D_KERNEL_LOOP(i, nthreads) {
        unsigned long long conf = CONF_TO_INT(box_confs[i]);
        auto cell = cell_indices[i];
        unsigned long long* cell_max = (unsigned long long*)(cell_max_confs + cell);
        // no implementation takes long long, only unsigned long long
        atomicCAS(cell_max, conf, (unsigned long long)i);
    }
}
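The two kernels above implement a per-cell argmax without sorting: the first
pass encodes each confidence as an integer (CONF_TO_INT) and takes a per-cell
atomicMax; the second pass compare-and-swaps the winning encoded confidence
for the index of the box that produced it. The CONF_TO_INT_ADD offset keeps
encoded confidences above any plausible box index, so an index written in the
second pass cannot be mistaken for a confidence. A sequential Python
simulation of the idea (illustrative only; real atomics resolve ties
nondeterministically):
```
def argmax_per_cell(scores, cells, num_cells):
    # scores: per-box confidences in [0, 1]; cells: per-box cell ids.
    enc = [int(s * 1_000_000) + 100_000 for s in scores]  # CONF_TO_INT
    slot = [0] * num_cells
    for i, c in enumerate(cells):      # pass 1: atomicMax on encoded conf
        slot[c] = max(slot[c], enc[i])
    for i, c in enumerate(cells):      # pass 2: atomicCAS(slot, enc[i], i)
        if slot[c] == enc[i]:
            slot[c] = i
    return slot  # per-cell index of its highest-scoring box
```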
template <typename T>
__global__ void hash_rects_kernel(int64_t nthreads,
                                  T* dets,
                                  T w0, T h0, T alpha,
                                  T bx, T by,
                                  T alpha_ratio,
                                  int64_t* out) {
    CUDA_1D_KERNEL_LOOP(idx_box, nthreads) {
        auto log_w0 = log(w0);
        auto log_h0 = log(h0);
        auto log_alpha = log(alpha);

        auto curr_det = dets + idx_box * 4;
        auto x = curr_det[0];
        auto y = curr_det[1];
        auto w = curr_det[2];
        auto h = curr_det[3];
        auto w0_alpha = w0 * alpha_ratio;
        auto h0_alpha = h0 * alpha_ratio;

        auto i = round((log_w0 - log(w)) / log_alpha);
        auto j = round((log_h0 - log(h)) / log_alpha);
        auto di = w0_alpha / pow(alpha, i);
        auto dj = h0_alpha / pow(alpha, j);

        int64_t qx, qy;
        qx = round(x / di - bx);
        qy = round(y / dj - by);
        auto curr_out = out + 4 * idx_box;
        curr_out[0] = qx;
        curr_out[1] = qy;
        curr_out[2] = i;
        curr_out[3] = j;
    }
}

at::Tensor hash_rects_cuda(const at::Tensor& dets,
                           float w0,
                           float h0,
                           float alpha,
                           float bx,
                           float by) {
    auto num_box = dets.size(0);
    auto alpha_ratio = (1. - alpha) / (1. + alpha);

    auto result = at::zeros({long(num_box), 4},
                            dets.options().dtype(at::kLong));

    AT_DISPATCH_FLOATING_TYPES(dets.type(), "HASH_RECTS", [&] {
        hash_rects_kernel<scalar_t><<<GET_BLOCKS(num_box), CUDA_NUM_THREADS>>>(
            num_box,
            dets.data<scalar_t>(),
            (scalar_t)w0, (scalar_t)h0, (scalar_t)alpha,
            (scalar_t)bx, (scalar_t)by,
            alpha_ratio,
            result.data<int64_t>());
    });
    return result;
}

__global__ void map_code(int num_box,
                         int64_t* codes,
                         int64_t* codes_as_one) {
    CUDA_1D_KERNEL_LOOP(idx_box, num_box) {
        auto curr_code = codes + 4 * idx_box;
        auto curr_mapped = codes_as_one + idx_box;
        *curr_mapped = curr_code[0] +
            curr_code[1] * 10000 +
            curr_code[2] * 100000000 +
            curr_code[3] * 1000000000000;
    }
}

at::Tensor get_best_idx_each_code(
        at::Tensor codes,
        const at::Tensor& scores) {
    auto num_box = codes.size(0);
    auto codes_as_one = at::zeros({long(num_box)},
                                  codes.options().dtype(at::kLong));
    map_code<<<GET_BLOCKS(num_box), CUDA_NUM_THREADS>>>(
        num_box,
        codes.data<int64_t>(),
        codes_as_one.data<int64_t>());
    THCudaCheck(cudaGetLastError());

    auto unique_result = at::unique_dim(codes_as_one, 0, // dim
                                        false, true);

    at::Tensor reverse_index = std::get<1>(unique_result);
    auto count = std::get<0>(unique_result).size(0);

    auto result = at::zeros({long(count)},
                            codes.options().dtype(at::kLong));

    // get the maximum confidence score for each code with the atomic
    // operation of atomicMax.
    AT_DISPATCH_FLOATING_TYPES(scores.type(), "HNMS_MAX_CONF_KERNEL", [&] {
        hnms_max_conf_kernel<scalar_t><<<GET_BLOCKS(num_box), CUDA_NUM_THREADS>>>(
            num_box,
            scores.data<scalar_t>(),
            reverse_index.data<int64_t>(),
            result.data<int64_t>());
    });
    THCudaCheck(cudaGetLastError());

    // replace each cell's best score with the index of the box that achieved it
    AT_DISPATCH_FLOATING_TYPES(scores.type(), "HNMS_MAX_IDX_KERNEL", [&] {
        hnms_max_idx_kernel<scalar_t><<<GET_BLOCKS(num_box), CUDA_NUM_THREADS>>>(
            num_box,
            scores.data<scalar_t>(),
            reverse_index.data_ptr<int64_t>(),
            result.data<int64_t>());
    });
    return result;
}

at::Tensor hnms_cuda(const at::Tensor& dets,
                     const at::Tensor& scores,
                     float w0,
                     float h0,
                     float alpha,
                     float bx,
                     float by
                     ) {
    AT_ASSERTM(dets.type().is_cuda(), "dets must be a CUDA tensor");
    AT_ASSERTM(scores.type().is_cuda(), "scores must be a CUDA tensor");
    AT_ASSERTM(dets.type() == scores.type(), "dets should have the same type as scores");
    if (dets.numel() == 0) {
        return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
    }

    auto codes = hash_rects_cuda(dets, w0, h0, alpha, bx, by);
    auto result = get_best_idx_each_code(codes, scores);
    return result;
}

#endif
--------------------------------------------------------------------------------
/hnms/extension/src/hnms_module.cpp:
--------------------------------------------------------------------------------
#include "hnms.h"


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("hnms", &hnms, "HNMS");
}
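The extension exposes a single `hnms` function; `MultiHNMS` and `HNMS` in the
next file are thin wrappers around it. Calling the binding directly looks like
this (a sketch; the parameter values are the `HNMS` defaults from
`multi_hnms.py`):
```
import torch
from hnms import _c as hnms_c

dets = torch.tensor([[10., 20., 10., 20.],
                     [10., 20., 10., 20.]])
conf = torch.tensor([0.9, 0.8])
# arguments: (dets, scores, w0, h0, alpha, bx, by)
keep = hnms_c.hnms(dets, conf, 1.0, 1.0, 0.7, 0.5, 0.5)
print(keep)  # indices of the surviving boxes
```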
--------------------------------------------------------------------------------
/hnms/multi_hnms.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import math
from hnms import _c as hnms_c


class MultiHNMS(nn.ModuleList):
    def __init__(self, num, alpha):
        all_hash_rect = []
        for i in range(num):
            # shift each table's reference size and cell offset by i / num
            curr_w0 = math.exp(1. * i / num * (-math.log(alpha)))
            curr_h0 = math.exp(1. * i / num * (-math.log(alpha)))
            bx = 1. * i / num
            by = 1. * i / num

            hr = HNMS(alpha=alpha,
                      w0=curr_w0,
                      h0=curr_h0,
                      bx=bx,
                      by=by)
            all_hash_rect.append(hr)
        super(MultiHNMS, self).__init__(all_hash_rect)

    def forward(self, rects, conf):
        # apply each HNMS stage to the boxes kept by the previous stage
        for i, hr in enumerate(self):
            if i == 0:
                keep = hr(rects, conf)
            else:
                curr_keep = hr(rects[keep], conf[keep])
                keep = keep[curr_keep]
        return keep


class HNMS(nn.Module):
    def __init__(self, alpha, w0=1., h0=1., bx=0.5, by=0.5):
        super().__init__()
        self.w0 = float(w0)
        self.h0 = float(h0)
        self.alpha = alpha
        self.bx = bx
        self.by = by

    def forward(self, rects, conf):
        result = hnms_c.hnms(rects, conf,
                             self.w0, self.h0,
                             self.alpha,
                             self.bx, self.by)
        return result
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import torch
import os
import os.path as op
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension


ext_folder = 'hnms/extension/'

include_dirs = [op.join(ext_folder, 'include')]

include_dirs = [op.abspath(i) for i in include_dirs]

define_macros = []

extension = CppExtension

if torch.cuda.is_available() or os.getenv("FORCE_CUDA", "0") == "1":
    extension = CUDAExtension
    define_macros += [("WITH_CUDA", None)]


setup(
    name='hnms',
    ext_modules=[
        extension(
            'hnms._c',
            [
                op.join(ext_folder, 'src/hnms_module.cpp'),
                op.join(ext_folder, 'src/cuda/hnms.cu'),
                op.join(ext_folder, 'src/cpu/hnms.cpp'),
            ],
            include_dirs=include_dirs,
            define_macros=define_macros,
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension
    },
)
--------------------------------------------------------------------------------
/test/example.py:
--------------------------------------------------------------------------------
import torch
from hnms import MultiHNMS

hnms = MultiHNMS(num=1, alpha=0.7)

rects = [[10, 20, 10, 20], [10, 20, 10, 20], [30, 6, 4, 5]]
conf = [0.9, 0.8, 0.9]
rects = torch.tensor(rects).float()
conf = torch.tensor(conf)
keep = hnms(rects, conf)
print(keep)
--------------------------------------------------------------------------------
/test/run.py:
--------------------------------------------------------------------------------
import torch
from hnms import MultiHNMS

hnms = MultiHNMS(num=1, alpha=0.7)

xywh = [[10, 20, 10, 20], [10, 20, 10, 20], [30, 6, 4, 5]]
conf = [0.9, 0.8, 0.9]
xywh = torch.tensor(xywh).float()
conf = torch.tensor(conf)
keep = hnms(xywh, conf)
print(keep)
--------------------------------------------------------------------------------
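Both test scripts feed boxes in (cx, cy, w, h) format, which is what the
hashing code expects. Detectors that emit corner-format (x1, y1, x2, y2)
boxes need a conversion first; a minimal helper (not part of the package):
```
import torch

def xyxy_to_cxcywh(boxes):
    # boxes: (N, 4) tensor of (x1, y1, x2, y2) corners.
    x1, y1, x2, y2 = boxes.unbind(1)
    return torch.stack([(x1 + x2) / 2, (y1 + y2) / 2,
                        x2 - x1, y2 - y1], 1)
```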