├── .gitignore ├── LICENSE ├── README.md ├── SECURITY.md ├── mictorch ├── __init__.py ├── common │ ├── caffe_cuda.h │ ├── mtorch_common.h │ └── region_common.hpp ├── imresize.py ├── nms.py ├── nms │ └── nms_cpu.cpp ├── nmsfilt │ ├── nmsfilt_cpu.cpp │ ├── nmsfilt_cuda.cpp │ └── nmsfilt_cuda_kernel.cu ├── nmsfilter.py ├── simple_parser.py ├── smt │ ├── smt_cpu.cpp │ ├── smt_cuda.cpp │ └── smt_cuda_kernel.cu ├── smtpred │ ├── smtpred_cpu.cpp │ ├── smtpred_cuda.cpp │ └── smtpred_cuda_kernel.cu ├── softmaxtree.py └── softmaxtree_prediction.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # PyCharm IDE 2 | .idea/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Object Detection Modules for PyTorch 2 | 3 | SoftmaxTree is a tree of softmax groups (i.e. the softmax axis is jagged). 4 | NMSFilter applies Non-Maximal Suppression to a batch of confidence values and bounding boxes. 5 | 6 | # Contributing 7 | 8 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 9 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 10 | the rights to use your contribution. For details, visit https://cla.microsoft.com. 11 | 12 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide 13 | a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions 14 | provided by the bot. You will only need to do this once across all repos using our CLA. 15 | 16 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 17 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 18 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 19 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /mictorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/pytorch_od/03d40e60c8c283d12cf61d90a93c70bb3ba10470/mictorch/__init__.py -------------------------------------------------------------------------------- /mictorch/common/caffe_cuda.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #ifndef CAFFE_CUDA_HPP_ 5 | #define CAFFE_CUDA_HPP_ 6 | /** 7 | * Caffe CUDA port to PyTorch 8 | **/ 9 | 10 | // CUDA: grid stride looping 11 | #define CUDA_KERNEL_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 13 | i < (n); \ 14 | i += blockDim.x * gridDim.x) 15 | 16 | // CUDA: use 1024 threads per block 17 | const int CUDA_NUM_THREADS = 1024; 18 | 19 | // CUDA: number of blocks for threads. 20 | inline int GET_BLOCKS(const int N) { 21 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; 22 | } 23 | 24 | #endif // CAFFE_CUDA_HPP_ 25 | -------------------------------------------------------------------------------- /mictorch/common/mtorch_common.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #ifndef MTORCH_COMMON_HPP_ 5 | #define MTORCH_COMMON_HPP_ 6 | /** 7 | * Common checks for all the extensions 8 | **/ 9 | 10 | // Work-around ATen regression 11 | #ifndef AT_ASSERTM 12 | #define AT_ASSERTM AT_ASSERT 13 | #endif 14 | 15 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 16 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 17 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 18 | 19 | #define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be a CPU tensor") 20 | #define CHECK_INPUT_CPU(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x) 21 | 22 | #ifdef MSVC 23 | #define RESTRICT __restrict 24 | #else 25 | #define RESTRICT __restrict__ 26 | #endif 27 | #endif // MTORCH_COMMON_HPP_ 28 | -------------------------------------------------------------------------------- /mictorch/common/region_common.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #ifndef REGION_COMMON_HPP_ 5 | #define REGION_COMMON_HPP_ 6 | 7 | template 8 | struct TBox { 9 | scalar_t x, y, w, h; 10 | }; 11 | 12 | #ifdef __CUDACC__ 13 | #define CUDA_HOSTDEV __host__ __device__ __forceinline__ 14 | #else 15 | #define CUDA_HOSTDEV 16 | #endif 17 | 18 | template 19 | CUDA_HOSTDEV scalar_t TOverlap(scalar_t x1, scalar_t w1, scalar_t x2, scalar_t w2) 20 | { 21 | auto l1 = x1 - w1 / 2; 22 | auto l2 = x2 - w2 / 2; 23 | auto left = l1 > l2 ? l1 : l2; 24 | auto r1 = x1 + w1 / 2; 25 | auto r2 = x2 + w2 / 2; 26 | auto right = r1 < r2 ? r1 : r2; 27 | return right - left; 28 | } 29 | 30 | template 31 | CUDA_HOSTDEV scalar_t TBoxIntersection(scalar_t ax, scalar_t ay, scalar_t aw, scalar_t ah, 32 | scalar_t bx, scalar_t by, scalar_t bw, scalar_t bh) { 33 | auto w = TOverlap(ax, aw, bx, bw); 34 | auto h = TOverlap(ay, ah, by, bh); 35 | if (w < 0 || h < 0) { 36 | return 0; 37 | } 38 | else { 39 | return w * h; 40 | } 41 | } 42 | 43 | template 44 | CUDA_HOSTDEV scalar_t TBoxIou(scalar_t ax, scalar_t ay, scalar_t aw, scalar_t ah, 45 | scalar_t bx, scalar_t by, scalar_t bw, scalar_t bh) { 46 | auto i = TBoxIntersection(ax, ay, aw, ah, bx, by, bw, bh); 47 | auto u = aw * ah + bw * bh - i; 48 | return i / u; 49 | } 50 | 51 | #endif // REGION_COMMON_HPP_ 52 | -------------------------------------------------------------------------------- /mictorch/imresize.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------- 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License. 
4 | # -------------------------------------------------------------------------- 5 | 6 | import torch 7 | 8 | 9 | @torch.jit.script 10 | def _dynsize_helper(crop_height_i, crop_width_i): 11 | """The input shape could be dynamic 12 | This will be exported as .ones().nonzero() with proper params 13 | """ 14 | y = torch.arange(crop_height_i, dtype=torch.float32) 15 | x = torch.arange(crop_width_i, dtype=torch.float32) 16 | return y, x 17 | 18 | 19 | def resize_bilinear(im, 20 | resized_shape=None, output_crop_shape=None, 21 | darknet=False, edge=True, axis=2): 22 | """Bilinear interpolate 23 | :param im: Image tensor shape (1xCxHxW) 24 | :type im: torch.Tensor 25 | :param resized_shape: shape of the resized image (H_r, W_r) 26 | :param output_crop_shape: shape of the output center crop (H_c, W_c) 27 | :param darknet: if should resize darknet-style 28 | :param edge: if should use edge (like in OpenCV) 29 | :param axis: height axis (0 or 2) 30 | :return: resized image 31 | :rtype: torch.Tensor 32 | """ 33 | 34 | if resized_shape is None: 35 | assert output_crop_shape is not None, "No dimension given to resize" 36 | resized_shape = output_crop_shape 37 | 38 | input_height, input_width = im.shape[axis:axis + 2] 39 | if not isinstance(input_height, torch.Tensor): 40 | input_height, input_width = torch.tensor(input_height), torch.tensor(input_width) 41 | input_height, input_width = input_height.float(), input_width.float() 42 | 43 | assert resized_shape is not None, "No dimension given to resize" 44 | target_height, target_width = resized_shape 45 | if not isinstance(target_height, torch.Tensor): 46 | target_height, target_width = torch.tensor(target_height), torch.tensor(target_width) 47 | resized_shape_i = target_height, target_width 48 | target_height, target_width = target_height.float(), target_width.float() 49 | resized_shape = target_height, target_width 50 | 51 | top = left = None 52 | if output_crop_shape is None: 53 | crop_height_i, crop_width_i = resized_shape_i 54 | crop_height, crop_width = resized_shape 55 | top = 0 56 | left = 0 57 | else: 58 | crop_height_i, crop_width_i = output_crop_shape 59 | if not isinstance(crop_height_i, torch.Tensor): 60 | crop_height_i, crop_width_i = torch.tensor(crop_height_i), torch.tensor(crop_width_i) 61 | crop_height, crop_width = crop_height_i, crop_width_i 62 | 63 | if not crop_height.dtype.is_floating_point: 64 | crop_height, crop_width = crop_height.float(), crop_width.float() 65 | 66 | # TODO: ONNX does not like float in arange, can avoid .long() once issue #27718 is fixed in release 67 | if crop_height_i.dtype.is_floating_point: 68 | crop_height_i, crop_width_i = crop_height_i.long(), crop_width_i.long() 69 | 70 | # TODO: Use normal arange once issue #20075 is fixed in release 71 | y, x = _dynsize_helper(crop_height_i, crop_width_i) 72 | y, x = y.to(im.device), x.to(im.device) 73 | 74 | if top is None: 75 | assert left is None 76 | assert crop_height <= target_height and crop_width <= target_width, "invalid output_crop_shape" 77 | if not crop_height.dtype.is_floating_point: 78 | crop_height, crop_width = crop_height.float(), crop_width.float() 79 | # TODO: use .round() when PyTorch Issue # 25806 is fixed (round for ONNX is released) 80 | top = ((target_height - crop_height) / 2 + 0.5).floor() 81 | left = ((target_width - crop_width) / 2 + 0.5).floor() 82 | 83 | rh = target_height / input_height 84 | rw = target_width / input_width 85 | if edge: 86 | ty = (y + top + 1) / rh + 0.5 * (1 - 1.0 / rh) - 1 87 | tx = (x + left + 1) / rw + 0.5 * (1 - 1.0 
/ rw) - 1 88 | zero = torch.tensor(0.0, dtype=torch.float32) 89 | ty = torch.max(ty, zero) # ty[ty < 0] = 0 90 | tx = torch.max(tx, zero) # tx[tx < 0] = 0 91 | else: 92 | ty = (y + top) / rh 93 | tx = (x + left) / rw 94 | del y, x 95 | 96 | ity0 = ty.floor() 97 | if darknet: 98 | ity1 = ity0 + 1 99 | else: 100 | ity1 = ty.ceil() 101 | 102 | itx0 = tx.floor() 103 | if darknet: 104 | itx1 = itx0 + 1 105 | else: 106 | itx1 = tx.ceil() 107 | 108 | dy = ty - ity0 109 | dx = tx - itx0 110 | del ty, tx 111 | if axis == 0: 112 | dy = dy.view(-1, 1, 1) 113 | dx = dx.view(-1, 1) 114 | else: 115 | assert axis == 2, "Only 1xCxHxW and HxWxC inputs supported" 116 | dy = dy.view(-1, 1) 117 | dx = dx.view(-1) 118 | dydx = dy * dx 119 | 120 | # noinspection PyProtectedMember 121 | if torch._C._get_tracing_state(): 122 | # always do clamp when tracing 123 | ity1 = torch.min(ity1, input_height - 1) 124 | itx1 = torch.min(itx1, input_width - 1) 125 | else: 126 | # TODO: use searchsorted once avaialble 127 | # items at the end could be out of bound (if upsampling) 128 | if ity1[-1] >= input_height: 129 | ity1[ity1 >= input_height] = input_height - 1 130 | if itx1[-1] >= input_width: 131 | itx1[itx1 >= input_width] = input_width - 1 132 | 133 | iy0 = ity0.long() 134 | ix0 = itx0.long() 135 | iy1 = ity1.long() 136 | ix1 = itx1.long() 137 | del ity0, itx0, ity1, itx1 138 | 139 | if not im.dtype.is_floating_point: 140 | im = im.float() 141 | im_iy0 = im.index_select(axis, iy0) 142 | im_iy1 = im.index_select(axis, iy1) 143 | d = im_iy0.index_select(axis + 1, ix0) * (1 - dx - dy + dydx) + \ 144 | im_iy1.index_select(axis + 1, ix0) * (dy - dydx) + \ 145 | im_iy0.index_select(axis + 1, ix1) * (dx - dydx) + \ 146 | im_iy1.index_select(axis + 1, ix1) * dydx 147 | 148 | return d 149 | -------------------------------------------------------------------------------- /mictorch/nms.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------- 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License. 4 | # -------------------------------------------------------------------------- 5 | 6 | import torch 7 | import nms_cpu 8 | 9 | nms = torch.ops.mtorch_ops.nms 10 | 11 | 12 | def register_custom_nms_op(): 13 | # experimenting custom op registration. 
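    # The symbolic below maps mtorch_ops::nms onto the ONNX NonMaxSuppression op:
    # boxes are reshaped to (1, N, 4) and scores to (1, 1, N) to add the batch and
    # class dimensions ONNX expects, center_point_box=1 selects the
    # (center_x, center_y, width, height) box format, and column 2 of the
    # NonMaxSuppression output (the selected box indices) is flattened and
    # returned, matching what the eager nms() call produces.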
14 | from torch.onnx.symbolic_helper import parse_args 15 | from torch.onnx.symbolic_opset9 import view, select 16 | @parse_args('v', 'v', 'f', 'i') 17 | def symbolic_nms(g, boxes, scores, iou_threshold, max_output_boxes): 18 | # if should return all 19 | if max_output_boxes <= 0: 20 | max_output_boxes = 10000 21 | boxes = view(g, boxes, (1, -1, 4)) 22 | max_output_per_class = g.op('Constant', value_t=torch.tensor([max_output_boxes], dtype=torch.long)) 23 | iou_threshold = g.op('Constant', value_t=torch.tensor([iou_threshold], dtype=torch.float)) 24 | # center_point_box == 1 is for our center_x, centr_y, width, height format 25 | nms_out = g.op('NonMaxSuppression', 26 | boxes, view(g, scores, (1, 1, -1)), max_output_per_class, iou_threshold, 27 | center_point_box_i=1) 28 | idx = select(g, nms_out, 1, g.op('Constant', value_t=torch.tensor([2], dtype=torch.long))) 29 | return view(g, idx, (-1,)) 30 | 31 | from torch.onnx import register_custom_op_symbolic 32 | register_custom_op_symbolic('mtorch_ops::nms', symbolic_nms, 10) 33 | -------------------------------------------------------------------------------- /mictorch/nms/nms_cpu.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | #include "mtorch_common.h" 10 | #include "region_common.hpp" 11 | 12 | namespace { 13 | 14 | template 15 | at::Tensor nms_kernel( 16 | const at::Tensor& bbs, 17 | const at::Tensor& conf, 18 | float thresh, 19 | int max_output_boxes) { 20 | auto w_t = bbs.select(1, 2).contiguous(); 21 | auto h_t = bbs.select(1, 3).contiguous(); 22 | 23 | auto x_t = bbs.select(1, 0).contiguous(); 24 | auto y_t = bbs.select(1, 1).contiguous(); 25 | 26 | at::Tensor areas_t = w_t * h_t; 27 | auto outer_num = bbs.size(0); 28 | at::Tensor suppressed_t = at::zeros({outer_num}, bbs.options().dtype(at::kByte).device(at::kCPU)); 29 | auto order_t = std::get<1>(conf.sort(0, /* descending=*/true)); 30 | 31 | 32 | auto bbs_data = bbs.data(); 33 | auto order_data = order_t.data(); 34 | auto areas_data = areas_t.data(); 35 | auto suppressed_data = suppressed_t.data(); 36 | auto x_data = x_t.data(); 37 | auto y_data = y_t.data(); 38 | auto w_data = w_t.data(); 39 | auto h_data = h_t.data(); 40 | 41 | int non_zero_count = 0; 42 | for (int64_t i = 0; i < outer_num; i++) { 43 | auto i_idx = order_data[i]; 44 | if (suppressed_data[i_idx]) 45 | continue; 46 | if (non_zero_count == max_output_boxes) { 47 | // suppress the rest 48 | suppressed_data[i_idx] = 1; 49 | continue; 50 | } 51 | non_zero_count++; 52 | 53 | auto ax = x_data[i_idx]; 54 | auto ay = y_data[i_idx]; 55 | auto aw = w_data[i_idx]; 56 | auto ah = h_data[i_idx]; 57 | auto area = areas_data[i_idx]; 58 | 59 | for (int64_t j = i + 1; j < outer_num; j++) { 60 | auto j_idx = order_data[j]; 61 | if (suppressed_data[j_idx]) 62 | continue; 63 | 64 | auto bx = x_data[j_idx]; 65 | auto by = y_data[j_idx]; 66 | auto bw = w_data[j_idx]; 67 | auto bh = h_data[j_idx]; 68 | auto inter = TBoxIntersection(ax, ay, aw, ah, bx, by, bw, bh); 69 | auto iou = inter / (area + areas_data[j_idx] - inter); 70 | if (iou > thresh) 71 | suppressed_data[j_idx] = 1; 72 | } 73 | } 74 | 75 | return at::nonzero(suppressed_t == 0).squeeze(1); 76 | } 77 | 78 | } // namespace 79 | 80 | // C++ interface 81 | at::Tensor nms_cpu_forward( 82 | at::Tensor bbs, at::Tensor conf, 83 | float nms_threshold, int max_output_boxes) { 84 | 
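  // Shape expectations (asserted below): bbs is an [N, 4] tensor of
  // (center_x, center_y, w, h) boxes and conf is the matching [N] confidence
  // vector. Boxes are visited in descending confidence order, any box whose IoU
  // with an already-kept box exceeds nms_threshold is suppressed, and the indices
  // of the surviving boxes are returned (capped at max_output_boxes when positive).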
CHECK_INPUT_CPU(bbs); 85 | CHECK_INPUT_CPU(conf); 86 | 87 | if (bbs.numel() == 0 || nms_threshold <= 0) { 88 | return at::empty({0}, bbs.options().dtype(at::kLong).device(at::kCPU)); 89 | } 90 | 91 | AT_ASSERTM(conf.dim() == 1, "invalid conf dim"); 92 | AT_ASSERTM(bbs.dim() == 2, "invalid bbs dim"); 93 | AT_ASSERTM(bbs.size(-1) == 4, "bbs axis must have 4 corners"); 94 | AT_ASSERTM(bbs.numel() == conf.numel() * 4, "conf and bbs mismatch element count"); 95 | 96 | at::Tensor keep; 97 | AT_DISPATCH_FLOATING_TYPES(conf.scalar_type(), "nms::nms_cpu_forward", ([&] { 98 | keep = nms_kernel(bbs, conf, nms_threshold, max_output_boxes); 99 | })); 100 | 101 | return keep; 102 | } 103 | 104 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 105 | m.def("forward", &nms_cpu_forward, "NMS forward (CPU)"); 106 | } 107 | 108 | at::Tensor nms( 109 | at::Tensor bbs, at::Tensor conf, 110 | const double nms_threshold, const int64_t max_output_boxes) { 111 | return nms_cpu_forward(bbs, conf, nms_threshold, max_output_boxes); 112 | } 113 | 114 | static auto registry = torch::RegisterOperators() 115 | .op("mtorch_ops::nms(Tensor bbs, Tensor conf," 116 | "float nms_threshold, int max_output_boxes) -> Tensor", 117 | &nms); 118 | -------------------------------------------------------------------------------- /mictorch/nmsfilt/nmsfilt_cpu.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | #include "mtorch_common.h" 11 | #include "region_common.hpp" 12 | 13 | namespace { 14 | 15 | // sort the values in p in descending order and keep the index in result 16 | template 17 | void sort_nms_idx(const scalar_t* p, 18 | std::vector& result) { 19 | std::iota(result.begin(), result.end(), 0); 20 | std::sort(result.begin(), result.end(), 21 | [p](int i, int j) { 22 | return p[i] > p[j]; 23 | }); 24 | } 25 | template 26 | void pre_filter(int outer_num, int channels, int inner_num, int classes, int first_class, 27 | float thresh, 28 | scalar_t* RESTRICT top_conf_data) { 29 | for (int index = 0; index < outer_num * classes * inner_num; ++index) { 30 | const int s = index % inner_num; 31 | const int c = (index / inner_num) % classes + first_class; 32 | const int n = (index / inner_num) / classes; 33 | int dim = (n * channels + c) * inner_num + s; 34 | if (top_conf_data[dim] <= thresh) 35 | top_conf_data[dim] = 0; 36 | } 37 | } 38 | 39 | template 40 | void nms_filter(const scalar_t* RESTRICT bbs_data, 41 | int outer_num, int channels, int inner_num, int classes, int first_class, int max_output_boxes, 42 | float thresh, 43 | scalar_t* RESTRICT top_conf_data) { 44 | 45 | for (int index = 0; index < outer_num * classes; ++index) { 46 | int c = index % classes + first_class; 47 | int n = index / classes; 48 | 49 | const int dim = (n * channels + c) * inner_num; 50 | std::vector idx(inner_num); 51 | sort_nms_idx(top_conf_data + dim, idx); 52 | int non_zero_count = 0; 53 | 54 | // TODO: profile the performance and try vectorizing with BLAS (or at::) 55 | for (int i_idx = 0; i_idx < inner_num; ++i_idx) { 56 | int i = idx[i_idx]; 57 | if (top_conf_data[dim + i] == 0) 58 | continue; 59 | if (non_zero_count == max_output_boxes) { 60 | // zero out the rest 61 | top_conf_data[dim + i] = 0; 62 | continue; 63 | } 64 | ++non_zero_count; 65 | auto i_bb = bbs_data + (n * inner_num + i) * 4; 66 | for (int j_idx = i_idx + 1; j_idx < inner_num; ++j_idx) 
{ 67 | int j = idx[j_idx]; 68 | if (top_conf_data[dim + j] == 0) 69 | continue; 70 | auto j_bb = bbs_data + (n * inner_num + j) * 4; 71 | scalar_t curr_iou = TBoxIou(i_bb[0], i_bb[1], i_bb[2], i_bb[3], 72 | j_bb[0], j_bb[1], j_bb[2], j_bb[3]); 73 | if (curr_iou > thresh) 74 | top_conf_data[dim + j] = 0; 75 | } 76 | } 77 | } 78 | } 79 | 80 | } // namespace 81 | 82 | // C++ interface 83 | std::vector nmsfilt_forward( 84 | at::Tensor bbs, at::Tensor conf, 85 | float nms_threshold, int classes, float pre_threshold, int first_class, int max_output_boxes) { 86 | CHECK_INPUT_CPU(bbs); 87 | CHECK_INPUT_CPU(conf); 88 | 89 | AT_ASSERTM(bbs.dim() >= 2, "invalid bbs dim"); 90 | AT_ASSERTM(conf.dim() >= 2, "invalid conf dim"); 91 | int bbs_axis = bbs.dim() - 1; 92 | AT_ASSERTM(bbs.size(bbs_axis) == 4, "bbs axis must have 4 corners"); 93 | 94 | int outer_num = bbs.size(0); 95 | AT_ASSERTM(conf.size(0) == outer_num, "conf has invalid number of batches"); 96 | int inner_num = 1; 97 | for (int i = 1; i < bbs_axis; ++i) 98 | inner_num *= bbs.size(i); 99 | AT_ASSERTM(bbs.numel() == inner_num * outer_num * 4, "bbs invalid size"); 100 | 101 | int channels = 1; 102 | if (conf.numel() != inner_num * outer_num) 103 | channels = conf.size(1); 104 | AT_ASSERTM(classes <= channels, "classes must be less than or equal to channels"); 105 | AT_ASSERTM(conf.numel() == inner_num * channels * outer_num, "conf invalid size"); 106 | 107 | if (classes <= 0) 108 | classes = channels; 109 | AT_ASSERTM(classes + first_class <= channels, "classes + first_class_ must be <= channels"); 110 | 111 | auto top_conf = conf.clone(); 112 | 113 | if (pre_threshold >= 0) { 114 | AT_DISPATCH_FLOATING_TYPES(conf.scalar_type(), "nmsfilt_forward::pre_filter", ([&] { 115 | pre_filter(outer_num, channels, inner_num, classes, first_class, 116 | pre_threshold, 117 | top_conf.data()); 118 | })); 119 | } 120 | 121 | if (nms_threshold <= 0 || inner_num == 1) 122 | return {top_conf}; 123 | 124 | AT_DISPATCH_FLOATING_TYPES(conf.scalar_type(), "nmsfilt_forward::nms_filter", ([&] { 125 | nms_filter(bbs.data(), 126 | outer_num, channels, inner_num, classes, first_class, max_output_boxes, 127 | nms_threshold, 128 | top_conf.data()); 129 | })); 130 | 131 | return {top_conf}; 132 | } 133 | 134 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 135 | m.def("forward", &nmsfilt_forward, "NMSFilter forward (CPU)"); 136 | } 137 | 138 | at::Tensor nmsfilt( 139 | at::Tensor bbs, at::Tensor conf, 140 | const double nms_threshold, const double pre_threshold, const int64_t max_output_boxes) { 141 | return nmsfilt_forward(bbs, conf, nms_threshold, 1, pre_threshold, 0, max_output_boxes)[0]; 142 | } 143 | 144 | static auto registry = torch::RegisterOperators() 145 | .op("mtorch_ops::nmsfilt(Tensor bbs, Tensor conf," 146 | "float nms_threshold, float pre_threshold, int max_output_boxes) -> Tensor", 147 | &nmsfilt); 148 | -------------------------------------------------------------------------------- /mictorch/nmsfilt/nmsfilt_cuda.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License.
3 | 4 | #include 5 | 6 | #include 7 | 8 | #include "mtorch_common.h" 9 | 10 | // CUDA forward declarations 11 | 12 | std::vector nmsfilt_cuda_forward( 13 | at::Tensor bbs, at::Tensor conf, 14 | float nms_threshold, int classes, float pre_threshold, int first_class, int max_output_boxes, 15 | int outer_num, int channels, int inner_num); 16 | 17 | // C++ interface 18 | std::vector nmsfilt_forward( 19 | at::Tensor bbs, at::Tensor conf, 20 | float nms_threshold, int classes, float pre_threshold, int first_class, int max_output_boxes) { 21 | CHECK_INPUT(bbs); 22 | CHECK_INPUT(conf); 23 | 24 | AT_ASSERTM(bbs.dim() >= 2, "invalid bbs dim"); 25 | AT_ASSERTM(conf.dim() >= 2, "invalid conf dim"); 26 | int bbs_axis = bbs.dim() - 1; // Last axis 27 | AT_ASSERTM(bbs.size(bbs_axis) == 4, "bbs axis must have 4 corners"); 28 | 29 | int outer_num = bbs.size(0); 30 | AT_ASSERTM(conf.size(0) == outer_num, "conf has invalid number of batches"); 31 | int inner_num = 1; 32 | for (int i = 1; i < bbs_axis; ++i) 33 | inner_num *= bbs.size(i); 34 | AT_ASSERTM(bbs.numel() == inner_num * outer_num * 4, "bbs invalid size"); 35 | 36 | int channels = 1; 37 | if (conf.numel() != inner_num * outer_num) 38 | channels = conf.size(1); 39 | AT_ASSERTM(classes <= channels, "classes must be less than or equal to channels"); 40 | AT_ASSERTM(conf.numel() == inner_num * channels * outer_num, "conf invalid size"); 41 | 42 | if (classes <= 0) 43 | classes = channels; 44 | AT_ASSERTM(classes + first_class <= channels, "classes + first_class_ must be <= channels"); 45 | 46 | return nmsfilt_cuda_forward( 47 | bbs, conf, 48 | nms_threshold, classes, pre_threshold, first_class, max_output_boxes, 49 | outer_num, channels, inner_num); 50 | } 51 | 52 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 53 | m.def("forward", &nmsfilt_forward, "NMSFilter forward (CUDA)"); 54 | } 55 | -------------------------------------------------------------------------------- /mictorch/nmsfilt/nmsfilt_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License.
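// The kernels in this file implement NMSFilter on the GPU: kernel_pre_filter
// zeroes confidences at or below the pre-threshold, kernel_channel_argmergesort
// builds a per-batch, per-class descending argsort of the confidences with a
// bottom-up merge sort that ping-pongs between two index buffers, and
// kernel_nms_filter walks that order, keeping at most max_output_boxes boxes and
// zeroing any confidence whose box overlaps a kept box by more than the NMS
// threshold.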
3 | 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include "caffe_cuda.h" 11 | #include "region_common.hpp" 12 | 13 | namespace { 14 | 15 | template 16 | __device__ void bottom_up_argmerge(const scalar_t* __restrict__ p, 17 | int left, int right, int end, 18 | const int* __restrict__ src, int* __restrict__ dst) { 19 | int i = left; 20 | int j = right; 21 | // Merge 2 already sorted lists 22 | for (int k = left; k < end; ++k) { 23 | if (i < right && (j >= end || p[src[i]] > p[src[j]])) { 24 | dst[k] = src[i]; 25 | i++; 26 | } else { 27 | dst[k] = src[j]; 28 | j++; 29 | } 30 | } 31 | } 32 | 33 | template 34 | __global__ void kernel_channel_argmergesort( 35 | int outer_num, int channels, int inner_num, int classes, int first_class, 36 | int width, int chunks, 37 | const scalar_t* __restrict__ data, 38 | int* __restrict__ src, int* __restrict__ dst) { 39 | CUDA_KERNEL_LOOP(index, outer_num * classes * chunks) { 40 | const int i = index % chunks; 41 | const int c_idx = (index / chunks) % classes; 42 | const int c = c_idx + first_class; 43 | const int n = (index / chunks) / classes; 44 | const int dim = (n * channels + c) * inner_num; 45 | const int idx_dim = (n * classes + c_idx) * inner_num; 46 | int left = i * width; 47 | int right = min(left + width / 2, inner_num); 48 | int end = min(left + width, inner_num); 49 | int* src_idx = src + idx_dim; 50 | int* dst_idx = dst + idx_dim; 51 | if (width == 2) { 52 | // Initialize the index 53 | if (right < end) 54 | src_idx[right] = left + 1; 55 | src_idx[left] = left + 0; 56 | } 57 | bottom_up_argmerge(data + dim, 58 | left, right, end, 59 | src_idx, dst_idx); 60 | } 61 | } 62 | 63 | template 64 | __global__ void kernel_pre_filter( 65 | int outer_num, int channels, int inner_num, int classes, int first_class, 66 | float thresh, 67 | scalar_t* __restrict__ top_conf_data) { 68 | CUDA_KERNEL_LOOP(index, outer_num * classes * inner_num) { 69 | const int s = index % inner_num; 70 | const int c = (index / inner_num) % classes + first_class; 71 | const int n = (index / inner_num) / classes; 72 | int dim = (n * channels + c) * inner_num + s; 73 | if (top_conf_data[dim] <= thresh) 74 | top_conf_data[dim] = 0; 75 | } 76 | } 77 | 78 | template 79 | __global__ void kernel_nms_filter( 80 | int outer_num, int channels, int inner_num, int classes, int first_class, int max_output_boxes, 81 | const int* __restrict__ idx, 82 | const scalar_t* __restrict__ bbs_data, float thresh, 83 | scalar_t* __restrict__ top_conf_data) { 84 | CUDA_KERNEL_LOOP(index, outer_num * classes) { 85 | const int c_idx = index % classes; 86 | const int c = c_idx + first_class; 87 | const int n = index / classes; 88 | const int dim = (n * channels + c) * inner_num; 89 | const int idx_dim = (n * classes + c_idx) * inner_num; 90 | const int* src_idx = idx + idx_dim; 91 | int non_zero_count = 0; 92 | for (int i_idx = 0; i_idx < inner_num; ++i_idx) { 93 | int i = src_idx[i_idx]; 94 | if (top_conf_data[dim + i] == 0) 95 | continue; 96 | if (non_zero_count == max_output_boxes) { 97 | // zero out the rest 98 | top_conf_data[dim + i] = 0; 99 | continue; 100 | } 101 | ++non_zero_count; 102 | auto i_bb = bbs_data + (n * inner_num + i) * 4; 103 | for (int j_idx = i_idx + 1; j_idx < inner_num; ++j_idx) { 104 | int j = src_idx[j_idx]; 105 | if (top_conf_data[dim + j] == 0) 106 | continue; 107 | auto j_bb = bbs_data + (n * inner_num + j) * 4; 108 | scalar_t curr_iou = TBoxIou(i_bb[0], i_bb[1], i_bb[2], i_bb[3], 109 | j_bb[0], j_bb[1], j_bb[2], j_bb[3]); 110 | if (curr_iou > thresh) 111 | 
top_conf_data[dim + j] = 0; 112 | } 113 | } 114 | } 115 | } 116 | 117 | 118 | } // namespace 119 | 120 | std::vector nmsfilt_cuda_forward( 121 | at::Tensor bbs, at::Tensor conf, 122 | float nms_threshold, int classes, float pre_threshold, int first_class, int max_output_boxes, 123 | int outer_num, int channels, int inner_num) { 124 | 125 | auto top_conf = conf.clone(); 126 | 127 | if (pre_threshold >= 0) { 128 | AT_DISPATCH_FLOATING_TYPES(conf.scalar_type(), "nmsfilt_cuda_forward::kernel_pre_filter", ([&] { 129 | kernel_pre_filter<<>>( 130 | outer_num, channels, inner_num, classes, first_class, 131 | pre_threshold, 132 | top_conf.data()); 133 | })); 134 | } 135 | 136 | if (nms_threshold <= 0 || inner_num == 1) 137 | return {top_conf}; 138 | 139 | // intermediate variables 140 | auto idx = at::empty({outer_num, classes, inner_num}, bbs.options().dtype(at::kInt)); 141 | int* idx_data = idx.data(); 142 | 143 | { 144 | // This memory is safe to release after sorting but we keep it in GPU memory, 145 | auto idx_swp = at::empty_like(idx); 146 | int* idx_tmp = idx_swp.data(); 147 | // Start swapped if loop runs for an odd number 148 | bool is_swapped = ((int)ceil(log2((double)inner_num))) % 2 != 0; 149 | AT_DISPATCH_FLOATING_TYPES(conf.scalar_type(), "nmsfilt_cuda_forward::kernel_channel_argmergesort", ([&] { 150 | for (int width = 2; width < inner_num * 2; width *= 2) { 151 | int chunks = (inner_num + width - 1) / width; 152 | int* src_idx = is_swapped ? idx_tmp : idx_data; 153 | int* dst_idx = is_swapped ? idx_data : idx_tmp; 154 | kernel_channel_argmergesort<<>>( 155 | outer_num, channels, inner_num, classes, first_class, 156 | width, chunks, 157 | conf.data(), 158 | src_idx, dst_idx); 159 | is_swapped = !is_swapped; 160 | } 161 | })); 162 | } 163 | 164 | AT_DISPATCH_FLOATING_TYPES(conf.scalar_type(), "nmsfilt_cuda_forward::kernel_nms_filter", ([&] { 165 | kernel_nms_filter <<>>( 166 | outer_num, channels, inner_num, classes, first_class, max_output_boxes, 167 | idx.data(), 168 | bbs.data(), nms_threshold, 169 | top_conf.data() 170 | ); 171 | })); 172 | 173 | return {top_conf}; 174 | } 175 | -------------------------------------------------------------------------------- /mictorch/nmsfilter.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------- 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License. 
4 | # -------------------------------------------------------------------------- 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Function 9 | 10 | try: 11 | import nmsfilt_cuda 12 | except ImportError: 13 | nmsfilt_cuda = None 14 | import nmsfilt_cpu 15 | 16 | 17 | class NMSFilterFunction(Function): 18 | @staticmethod 19 | def forward(ctx, bbs, conf, 20 | nms_threshold, classes, pre_threshold, first_class, max_output_boxes): 21 | if bbs.is_cuda: 22 | nms_ = nmsfilt_cuda 23 | else: 24 | nms_ = nmsfilt_cpu 25 | top_conf = nms_.forward(bbs, conf, 26 | nms_threshold, classes, pre_threshold, first_class, max_output_boxes)[0] 27 | 28 | return top_conf 29 | 30 | @staticmethod 31 | def backward(ctx, grad_output): 32 | return tuple([None] * 6) 33 | 34 | 35 | class NMSFilter(nn.Module): 36 | """Applies Non-Maximal Suppression filter to bounding box confidence values 37 | Each class (and each batch) will be filtered independently 38 | """ 39 | def __init__(self, nms_threshold=0.45, classes=1, pre_threshold=-1.0, first_class=0, max_output_boxes=-1, 40 | return_bbs=False): 41 | """NMSFilter 42 | :param nms_threshold: NMS threshold 43 | :param classes: number of classes to filter (-1 to filter all classes independently) 44 | :param pre_threshold: amplitude threshold to apply before NMS 45 | :param first_class: The first class to start filtering "classes" 46 | :param max_output_boxes: maximum number of boxes per-class per-batch (<= 0 to ignore) 47 | :param return_bbs: if should return the input bbs as well as filtered conf 48 | """ 49 | super(NMSFilter, self).__init__() 50 | self.nms_threshold = nms_threshold 51 | self.classes = classes 52 | self.pre_threshold = pre_threshold # amplitude threshold to apply before NMS 53 | self.first_class = first_class 54 | self.max_output_boxes = max_output_boxes 55 | self.return_bbs = return_bbs 56 | 57 | assert self.first_class >= 0 58 | 59 | def forward(self, bbs, conf): 60 | """NMS filter confidences based on bounding boxes 61 | :param bbs: bounding boxes 62 | :param conf: confidences to filter 63 | """ 64 | if isinstance(bbs.shape[0], torch.Tensor): 65 | assert self.classes <= 1, "multi-class NMS tracing not supported yet" 66 | filt = torch.ops.mtorch_ops.nmsfilt( 67 | bbs, conf, 68 | self.nms_threshold, self.pre_threshold, self.max_output_boxes 69 | ) 70 | else: 71 | filt = NMSFilterFunction.apply( 72 | bbs, conf, 73 | self.nms_threshold, self.classes, self.pre_threshold, self.first_class, self.max_output_boxes 74 | ) 75 | 76 | if self.return_bbs: 77 | return bbs, filt 78 | return filt 79 | 80 | def extra_repr(self): 81 | """Extra information 82 | """ 83 | return 'nms_threshold={}, classes={}{}{}{}'.format( 84 | self.nms_threshold, self.classes, 85 | ", pre_threshold={}".format(self.pre_threshold) if self.pre_threshold > 0 else "", 86 | ", first_class={}".format(self.first_class) if self.first_class > 0 else "", 87 | ", max_output_boxes={}".format(self.max_output_boxes) if self.max_output_boxes > 0 else "", 88 | ) 89 | 90 | 91 | def register_custom_nms_op(): 92 | # experimenting custom op registration. 
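    # The symbolic below maps mtorch_ops::nmsfilt onto ONNX NonMaxSuppression:
    # boxes become (1, N, 4) and scores (1, 1, N) with center_point_box=1, the
    # score_threshold input plays the role of the pre-filter, the surviving box
    # indices are read from column 2 of the NonMaxSuppression output, and the kept
    # scores are scattered back into a zero tensor so the result keeps the shape
    # of the input confidences.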
93 | from torch.onnx.symbolic_helper import parse_args 94 | from torch.onnx.symbolic_opset9 import view, select, index_select, scatter 95 | @parse_args('v', 'v', 'f', 'f', 'i') 96 | def symbolic_nmsfilt(g, boxes, scores, iou_threshold, score_threshold, max_output_boxes): 97 | # if should return all 98 | if max_output_boxes <= 0: 99 | max_output_boxes = 10000 100 | shape = g.op("Shape", scores) # original shape 101 | boxes = view(g, boxes, (1, -1, 4)) 102 | max_output_per_class = g.op('Constant', value_t=torch.tensor([max_output_boxes], dtype=torch.long)) 103 | iou_threshold = g.op('Constant', value_t=torch.tensor([iou_threshold], dtype=torch.float)) 104 | score_threshold = g.op('Constant', value_t=torch.tensor([score_threshold], dtype=torch.float)) 105 | # center_point_box == 1 is for our center_x, centr_y, width, height format 106 | nms_out = g.op('NonMaxSuppression', 107 | boxes, view(g, scores, (1, 1, -1)), max_output_per_class, iou_threshold, score_threshold, 108 | center_point_box_i=1) 109 | idx = view(g, select(g, nms_out, 1, g.op('Constant', value_t=torch.tensor([2], dtype=torch.long))), (-1,)) 110 | scores = view(g, scores, (-1,)) 111 | flat_shape = g.op("Shape", scores) 112 | src = index_select(g, scores, 0, idx) 113 | src = view(g, src, (-1,)) 114 | filt = g.op("ConstantOfShape", flat_shape) 115 | filt = scatter(g, filt, 0, idx, src) 116 | return view(g, filt, shape) 117 | 118 | from torch.onnx import register_custom_op_symbolic 119 | register_custom_op_symbolic('mtorch_ops::nmsfilt', symbolic_nmsfilt, 10) 120 | 121 | 122 | # TODO: make these a proper unit test 123 | if __name__ == '__main__': 124 | n = 4 125 | a = 3 126 | c = 10 127 | net = NMSFilter(classes=c, pre_threshold=0.1) 128 | 129 | bb = torch.empty(n, a, 4) 130 | xy = bb[:, :, :2] 131 | wh = bb[:, :, 2:] 132 | xy.uniform_(0.25, 0.35) # cluster them within 0.1 133 | wh.uniform_(0.1, 0.5) 134 | 135 | prob = torch.empty(n, c, a).uniform_(0, .3) 136 | b = net(bb, prob) 137 | 138 | # now test with cuda 139 | net = net.cuda() 140 | bb = bb.cuda() 141 | prob = prob.cuda() 142 | b2 = net(bb, prob) 143 | 144 | print((b2.cpu() - b).sum()) 145 | -------------------------------------------------------------------------------- /mictorch/simple_parser.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------- 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License. 
4 | # -------------------------------------------------------------------------- 5 | 6 | from __future__ import print_function 7 | 8 | 9 | def read_softmax_tree(tree_file): 10 | """Simple parsing of softmax tree with subgroups 11 | :param tree_file: path to the tree file, or open file object 12 | :type tree_file: str or file 13 | """ 14 | group_offsets = [] 15 | group_sizes = [] 16 | cid_groups = [] 17 | parents = [] 18 | child = [] # child group 19 | child_sizes = [] # number of child groups 20 | root_size = 0 # number of child sub-groups at root 21 | last_p = -1 22 | last_sg = -1 23 | groups = 0 24 | sub_groups = 0 25 | size = 0 26 | n = 0 27 | with open(tree_file, 'r') as f: 28 | for line in f.readlines(): 29 | tokens = [t for t in line.split(' ') if t] 30 | assert len(tokens) == 2 or len(tokens) == 3, "invalid tree: {} node: {} line: {}".format( 31 | tree_file, n, line) 32 | p = int(tokens[1]) 33 | assert n > p >= -1, "invalid parent: {} node: {} tree: {}".format(p, n, tree_file) 34 | parents.append(p) 35 | sg = -1 36 | if len(tokens) == 3: 37 | sg = int(tokens[2]) 38 | new_group = new_sub_group = False 39 | if p != last_p: 40 | last_p = p 41 | last_sg = sg 42 | new_group = True 43 | sub_groups = 0 44 | elif sg != last_sg: 45 | assert sg > last_sg, "invalid sg: {} node: {} tree: {}".format(sg, n, tree_file) 46 | last_sg = sg 47 | new_sub_group = True 48 | sub_groups += 1 49 | if new_group or new_sub_group: 50 | group_sizes.append(size) 51 | group_offsets.append(n - size) 52 | groups += 1 53 | size = 0 54 | child.append(-1) 55 | child_sizes.append(0) 56 | if p >= 0: 57 | if new_group: 58 | assert child[p] == -1, "node: {} parent discontinuity in tree: {}".format(n, tree_file) 59 | child[p] = groups # start group of child subgroup 60 | elif new_sub_group: 61 | child_sizes[p] = sub_groups 62 | else: 63 | root_size = sub_groups 64 | n += 1 65 | size += 1 66 | cid_groups.append(groups) 67 | group_sizes.append(size) 68 | group_offsets.append(n - size) 69 | 70 | assert len(cid_groups) == len(parents) == len(child) == len(child_sizes) 71 | assert len(group_offsets) == len(group_sizes) == max(cid_groups) + 1 72 | return group_offsets, group_sizes, cid_groups, parents, child, child_sizes, root_size 73 | -------------------------------------------------------------------------------- /mictorch/smt/smt_cpu.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include "mtorch_common.h" 10 | 11 | 12 | // C++ interface 13 | std::vector smt_forward( 14 | at::Tensor input, 15 | at::Tensor group_offset, at::Tensor group_size, 16 | int axis) { 17 | CHECK_INPUT_CPU(input); 18 | CHECK_INPUT_CPU(group_offset); 19 | CHECK_INPUT_CPU(group_size); 20 | 21 | int outer_num = 1; 22 | for (int i = 0; i < axis; ++i) 23 | outer_num *= input.size(i); 24 | int inner_num = 1; 25 | for (int i = axis + 1; i < input.dim(); ++i) 26 | inner_num *= input.size(i); 27 | 28 | int groups = group_offset.numel(); 29 | int channels = input.size(axis); 30 | 31 | auto prob = input.clone(); 32 | 33 | // We need to subtract the per-group max to avoid numerical issues, compute the exp, 34 | // and then per-group normalize. 
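  // Concretely, for a group with offset o and size k this computes an ordinary
  // softmax restricted to that slice of the channel axis, independently for every
  // (batch, spatial) position:
  //   prob[o + j] = exp(x[o + j] - m) / sum_{i < k} exp(x[o + i] - m),  m = max_i x[o + i]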
35 | AT_DISPATCH_FLOATING_TYPES(input.type(), "smt_cpu_forward", ([&] { 36 | const int* group_offset_data = group_offset.data(); 37 | const int* group_size_data = group_size.data(); 38 | scalar_t* data = prob.data(); 39 | for (int index = 0; index < outer_num * groups * inner_num; ++index) { 40 | int s = index % inner_num; 41 | int g = (index / inner_num) % groups; 42 | int n = (index / inner_num) / groups; 43 | auto offset = group_offset_data[g]; 44 | auto size = group_size_data[g]; 45 | scalar_t maxval = -FLT_MAX; 46 | for (int j = 0; j < size; ++j) { 47 | if (data[(n * channels + offset + j) * inner_num + s] > maxval) 48 | maxval = data[(n * channels + offset + j) * inner_num + s]; 49 | } 50 | // Subtract the max 51 | for (int j = 0; j < size; ++j) 52 | data[(n * channels + offset + j) * inner_num + s] -= maxval; 53 | } 54 | })); 55 | 56 | // exponentiate 57 | prob.exp_(); 58 | 59 | // per-group sum after exp, and divide 60 | AT_DISPATCH_FLOATING_TYPES(input.type(), "smt_cpu_forward", ([&] { 61 | const int* group_offset_data = group_offset.data(); 62 | const int* group_size_data = group_size.data(); 63 | scalar_t* data = prob.data(); 64 | for (int index = 0; index < outer_num * groups * inner_num; ++index) { 65 | int s = index % inner_num; 66 | int g = (index / inner_num) % groups; 67 | int n = (index / inner_num) / groups; 68 | auto offset = group_offset_data[g]; 69 | auto size = group_size_data[g]; 70 | scalar_t sum = 0; 71 | for (int j = 0; j < size; ++j) 72 | sum += data[(n * channels + offset + j) * inner_num + s]; 73 | // divide by sum 74 | for (int j = 0; j < size; ++j) 75 | data[(n * channels + offset + j) * inner_num + s] /= sum; 76 | } 77 | })); 78 | 79 | return {prob}; 80 | } 81 | 82 | std::vector smt_backward( 83 | at::Tensor prob, at::Tensor grad_output, 84 | at::Tensor group_offset, at::Tensor group_size, 85 | int axis) { 86 | CHECK_INPUT_CPU(prob); 87 | CHECK_INPUT_CPU(grad_output); 88 | CHECK_INPUT_CPU(group_offset); 89 | CHECK_INPUT_CPU(group_size); 90 | 91 | int outer_num = 1; 92 | for (int i = 0; i < axis; ++i) 93 | outer_num *= prob.size(i); 94 | int inner_num = 1; 95 | for (int i = axis + 1; i < prob.dim(); ++i) 96 | inner_num *= prob.size(i); 97 | 98 | int groups = group_offset.numel(); 99 | int channels = prob.size(axis); 100 | 101 | auto diff = grad_output.clone(); // bottom diff 102 | 103 | // Compute per-group inner1d(top_diff, top_data) and subtract them from the bottom diff. 
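  // This is the usual softmax Jacobian applied per group: with p the forward
  // probabilities and dy the incoming gradient restricted to one group,
  //   dx[j] = p[j] * (dy[j] - sum_i dy[i] * p[i])
  // The loop below forms the per-group dot product and subtracts it from dy, and
  // the trailing diff.mul_(prob) supplies the final elementwise factor.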
104 | AT_DISPATCH_FLOATING_TYPES(prob.type(), "smt_cpu_backward", ([&] { 105 | const int* group_offset_data = group_offset.data(); 106 | const int* group_size_data = group_size.data(); 107 | const scalar_t* data_1 = grad_output.data(); 108 | const scalar_t* data_2 = prob.data(); 109 | scalar_t* out = diff.data(); 110 | for(int index = 0; index < outer_num * groups * inner_num; ++index) { 111 | int s = index % inner_num; 112 | int g = (index / inner_num) % groups; 113 | int n = (index / inner_num) / groups; 114 | auto offset = group_offset_data[g]; 115 | auto size = group_size_data[g]; 116 | scalar_t dot = 0; 117 | for (int j = 0; j < size; ++j) { 118 | dot += (data_1[(n * channels + offset + j) * inner_num + s] 119 | * data_2[(n * channels + offset + j) * inner_num + s]); 120 | } 121 | // subtract the dot 122 | for (int j = 0; j < size; ++j) 123 | out[(n * channels + offset + j) * inner_num + s] -= dot; 124 | } 125 | })); 126 | 127 | // elementwise multiplication 128 | diff.mul_(prob); 129 | 130 | return {diff}; 131 | } 132 | 133 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 134 | m.def("forward", &smt_forward, "SMT forward (CPU)"); 135 | m.def("backward", &smt_backward, "SMT backward (CPU)"); 136 | } 137 | -------------------------------------------------------------------------------- /mictorch/smt/smt_cuda.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #include 5 | 6 | #include 7 | 8 | #include "mtorch_common.h" 9 | 10 | // CUDA forward declarations 11 | 12 | std::vector smt_cuda_forward( 13 | at::Tensor input, 14 | at::Tensor group_offset, at::Tensor group_size, 15 | int outer_num, int inner_num, int axis); 16 | 17 | std::vector smt_cuda_backward( 18 | at::Tensor prob, at::Tensor grad_output, 19 | at::Tensor group_offset, at::Tensor group_size, 20 | int outer_num, int inner_num, int axis); 21 | 22 | // C++ interface 23 | std::vector smt_forward( 24 | at::Tensor input, 25 | at::Tensor group_offset, at::Tensor group_size, 26 | int axis) { 27 | CHECK_INPUT(input); 28 | CHECK_INPUT(group_offset); 29 | CHECK_INPUT(group_size); 30 | 31 | int outer_num = 1; 32 | for (int i = 0; i < axis; ++i) 33 | outer_num *= input.size(i); 34 | int inner_num = 1; 35 | for (int i = axis + 1; i < input.dim(); ++i) 36 | inner_num *= input.size(i); 37 | 38 | return smt_cuda_forward(input, 39 | group_offset, group_size, 40 | outer_num, inner_num, 41 | axis); 42 | } 43 | 44 | std::vector smt_backward( 45 | at::Tensor prob, at::Tensor grad_output, 46 | at::Tensor group_offset, at::Tensor group_size, 47 | int axis) { 48 | CHECK_INPUT(prob); 49 | CHECK_INPUT(grad_output); 50 | CHECK_INPUT(group_offset); 51 | CHECK_INPUT(group_size); 52 | 53 | int outer_num = 1; 54 | for (int i = 0; i < axis; ++i) 55 | outer_num *= prob.size(i); 56 | int inner_num = 1; 57 | for (int i = axis + 1; i < prob.dim(); ++i) 58 | inner_num *= prob.size(i); 59 | 60 | return smt_cuda_backward( 61 | prob, grad_output, 62 | group_offset, group_size, 63 | outer_num, inner_num, 64 | axis); 65 | } 66 | 67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 68 | m.def("forward", &smt_forward, "SMT forward (CUDA)"); 69 | m.def("backward", &smt_backward, "SMT backward (CUDA)"); 70 | } 71 | -------------------------------------------------------------------------------- /mictorch/smt/smt_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 
Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include "caffe_cuda.h" 12 | 13 | namespace { 14 | 15 | template 16 | __global__ void kernel_subtract_max(const int num, const int channels, const int spatial_dim, const int groups, 17 | const int* __restrict__ group_offset_data, const int* __restrict__ group_size_data, scalar_t* __restrict__ data) { 18 | CUDA_KERNEL_LOOP(index, num * groups * spatial_dim) { 19 | int s = index % spatial_dim; 20 | int g = (index / spatial_dim) % groups; 21 | int n = (index / spatial_dim) / groups; 22 | auto offset = group_offset_data[g]; 23 | auto size = group_size_data[g]; 24 | scalar_t maxval = -FLT_MAX; 25 | for (int j = 0; j < size; ++j) { 26 | if (data[(n * channels + offset + j) * spatial_dim + s] > maxval) 27 | maxval = data[(n * channels + offset + j) * spatial_dim + s]; 28 | } 29 | // TODO: Use dynamic parallelism for devices with 3.5 compute capability 30 | // Subtract the max 31 | for (int j = 0; j < size; ++j) 32 | data[(n * channels + offset + j) * spatial_dim + s] -= maxval; 33 | } 34 | } 35 | 36 | template 37 | __global__ void kernel_div_sum(const int num, const int channels, const int spatial_dim, const int groups, 38 | const int* __restrict__ group_offset_data, const int* __restrict__ group_size_data, scalar_t* __restrict__ data) { 39 | CUDA_KERNEL_LOOP(index, num * groups * spatial_dim) { 40 | int s = index % spatial_dim; 41 | int g = (index / spatial_dim) % groups; 42 | int n = (index / spatial_dim) / groups; 43 | auto offset = group_offset_data[g]; 44 | auto size = group_size_data[g]; 45 | scalar_t sum = 0; 46 | for (int j = 0; j < size; ++j) 47 | sum += data[(n * channels + offset + j) * spatial_dim + s]; 48 | // TODO: Use dynamic parallelism for devices with 3.5 compute capability 49 | // divide by sum 50 | for (int j = 0; j < size; ++j) 51 | data[(n * channels + offset + j) * spatial_dim + s] /= sum; 52 | } 53 | } 54 | 55 | template 56 | __global__ void kernel_subtract_dot(const int num, const int channels, const int spatial_dim, const int groups, 57 | const int* group_offset_data, const int* group_size_data, 58 | const scalar_t* __restrict__ data_1, const scalar_t* __restrict__ data_2, scalar_t* __restrict__ out) { 59 | CUDA_KERNEL_LOOP(index, num * groups * spatial_dim) { 60 | int s = index % spatial_dim; 61 | int g = (index / spatial_dim) % groups; 62 | int n = (index / spatial_dim) / groups; 63 | auto offset = group_offset_data[g]; 64 | auto size = group_size_data[g]; 65 | scalar_t dot = 0; 66 | for (int j = 0; j < size; ++j) { 67 | dot += (data_1[(n * channels + offset + j) * spatial_dim + s] 68 | * data_2[(n * channels + offset + j) * spatial_dim + s]); 69 | } 70 | // TODO: Use dynamic parallelism for devices with 3.5 compute capability 71 | // subtract the dot 72 | for (int j = 0; j < size; ++j) 73 | out[(n * channels + offset + j) * spatial_dim + s] -= dot; 74 | } 75 | } 76 | 77 | } // namespace 78 | 79 | std::vector smt_cuda_forward( 80 | at::Tensor input, 81 | at::Tensor group_offset, at::Tensor group_size, 82 | int outer_num, int inner_num, int axis) { 83 | 84 | int groups = group_offset.numel(); 85 | int channels = input.size(axis); 86 | 87 | auto prob = input.clone(); 88 | 89 | // We need to subtract the per-group max to avoid numerical issues, compute the exp, 90 | // and then per-group normalize. 
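  // Each loop index in the kernels below corresponds to one (batch, group,
  // spatial) position and iterates over that group's channels, so the launches
  // cover outer_num * groups * inner_num independent softmax groups; prob.exp_()
  // runs tensor-wide between the max-subtraction and the sum/divide passes.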
91 | AT_DISPATCH_FLOATING_TYPES(input.type(), "smt_cuda_forward", ([&] { 92 | kernel_subtract_max<<>>( 93 | outer_num, channels, inner_num, groups, 94 | group_offset.data(), group_size.data(), 95 | prob.data()); 96 | })); 97 | 98 | // exponentiate 99 | prob.exp_(); 100 | 101 | // per-group sum after exp, and divide 102 | AT_DISPATCH_FLOATING_TYPES(input.type(), "smt_cuda_forward", ([&] { 103 | kernel_div_sum<<>>( 104 | outer_num, channels, inner_num, groups, 105 | group_offset.data(), group_size.data(), 106 | prob.data()); 107 | })); 108 | 109 | return {prob}; 110 | } 111 | 112 | std::vector smt_cuda_backward( 113 | at::Tensor prob, at::Tensor grad_output, 114 | at::Tensor group_offset, at::Tensor group_size, 115 | int outer_num, int inner_num, int axis) { 116 | 117 | int groups = group_offset.numel(); 118 | int channels = prob.size(axis); 119 | 120 | auto diff = grad_output.clone(); // bottom diff 121 | 122 | // Compute per-group inner1d(top_diff, top_data) and subtract them from the bottom diff. 123 | AT_DISPATCH_FLOATING_TYPES(prob.type(), "smt_cuda_backward", ([&] { 124 | kernel_subtract_dot<<>>( 125 | outer_num, channels, inner_num, groups, 126 | group_offset.data(), group_size.data(), 127 | grad_output.data(), prob.data(), diff.data()); 128 | })); 129 | 130 | // elementwise multiplication 131 | diff.mul_(prob); 132 | 133 | return {diff}; 134 | } 135 | -------------------------------------------------------------------------------- /mictorch/smtpred/smtpred_cpu.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "mtorch_common.h" 11 | 12 | namespace { 13 | 14 | struct Pred { 15 | double parent_p; 16 | int parent_argmax; 17 | int g; 18 | }; 19 | 20 | template 21 | void predict_tree_stack(int outer_num, int channels, int inner_num, 22 | bool append_max, 23 | float threshold, 24 | const int* group_offset_data, const int* group_size_data, const int* child_data, const int* child_size_data, 25 | const scalar_t* obj_data, const scalar_t* prob_data, 26 | int max_stack_size, int n, int s, int g, 27 | scalar_t* top_data, bool output_tree_path) { 28 | std::stack preds; 29 | scalar_t obj = obj_data ? obj_data[n * inner_num + s] : 1; 30 | double root_p = output_tree_path ? obj : 1.0; 31 | // if it is output_tree_path, the score should be the obj * category_prob 32 | // in the path 33 | threshold = output_tree_path ? (threshold * obj) : threshold; 34 | preds.push({ root_p, -1, g }); 35 | const int top_channels = append_max ? 
(channels + 1) : channels; 36 | while (!preds.empty()) { 37 | assert(preds.size() <= max_stack_size); 38 | auto pred = preds.top(); 39 | preds.pop(); 40 | double p = pred.parent_p; 41 | int argmax = 0; 42 | { 43 | g = pred.g; 44 | scalar_t maxval = -FLT_MAX; 45 | auto offset = group_offset_data[g]; 46 | argmax = offset; 47 | auto size = group_size_data[g]; 48 | for (int j = 0; j < size; ++j) { 49 | scalar_t prob = prob_data[(n * channels + offset + j) * inner_num + s]; 50 | if (prob > maxval) { 51 | argmax = offset + j; 52 | maxval = prob; 53 | } 54 | } 55 | p *= maxval; 56 | } 57 | if (p > threshold) { 58 | if (output_tree_path) { 59 | top_data[(n * top_channels + argmax) * inner_num + s] = static_cast(p); 60 | } 61 | g = child_data[argmax]; // initial child group 62 | if (g >= 0) { 63 | // if there is any child, descend further 64 | int sg_count = child_size_data[argmax] + 1; 65 | for (int sg = 0; sg < sg_count; ++sg) 66 | preds.push({ p, argmax, g + sg }); 67 | continue; 68 | } 69 | } else { 70 | argmax = pred.parent_argmax; 71 | if (argmax < 0) 72 | continue; 73 | p = pred.parent_p; 74 | } 75 | 76 | scalar_t node_p = 0; 77 | if (!output_tree_path) { 78 | node_p = obj_data ? obj : static_cast(p); 79 | top_data[(n * top_channels + argmax) * inner_num + s] = node_p; 80 | } 81 | if (append_max) { 82 | int max_idx = (n * top_channels + channels) * inner_num + s; 83 | if (output_tree_path) { 84 | // in this case, we use the obj as the max value, which will be 85 | // used as the indicator for class-independent NMS. otherwise, the 86 | // maximum value will always be the ones in the first 87 | // child-level of the root node. 88 | top_data[max_idx] = obj; 89 | } else { 90 | if (node_p > top_data[max_idx]) { 91 | top_data[max_idx] = node_p; 92 | } 93 | } 94 | } 95 | } 96 | } 97 | 98 | } // namespace 99 | 100 | // C++ interface 101 | std::vector smtpred_forward( 102 | at::Tensor conf, at::Tensor obj, 103 | at::Tensor group_offset, at::Tensor group_size, at::Tensor child, at::Tensor child_size, 104 | float threshold, bool output_tree_path, bool append_max, 105 | int root_size, int stack_size 106 | ) { 107 | 108 | CHECK_INPUT_CPU(conf); 109 | CHECK_INPUT_CPU(obj); 110 | CHECK_INPUT_CPU(group_offset); 111 | CHECK_INPUT_CPU(group_size); 112 | CHECK_INPUT_CPU(child); 113 | CHECK_INPUT_CPU(child_size); 114 | 115 | AT_ASSERTM(conf.dim() >= 2, "invalid conf dim"); 116 | 117 | int outer_num = conf.size(0); 118 | int inner_num = 1; 119 | for (int i = 2; i < conf.dim(); ++i) 120 | inner_num *= conf.size(i); 121 | 122 | auto shape = conf.sizes().vec(); 123 | int channels = shape[1]; 124 | if (append_max) 125 | shape[1] = channels + 1; 126 | 127 | auto top = at::zeros(shape, conf.type()); 128 | 129 | root_size++; 130 | 131 | AT_DISPATCH_FLOATING_TYPES(conf.type(), "smtpred_forward", ([&] { 132 | auto group_offset_data = group_offset.data(); 133 | auto group_size_data = group_size.data(); 134 | auto child_data = child.data(); 135 | auto child_size_data = child_size.data(); 136 | scalar_t* obj_data = nullptr; 137 | if (obj.numel()) 138 | obj_data = obj.data(); 139 | auto prob_data = conf.data(); 140 | auto top_data = top.data(); 141 | for (int index = 0; index < outer_num * root_size * inner_num; ++index) { 142 | const int s = index % inner_num; 143 | const int g = (index / inner_num) % root_size; 144 | const int n = (index / inner_num) / root_size; 145 | 146 | predict_tree_stack(outer_num, channels, inner_num, 147 | append_max, 148 | threshold, 149 | group_offset_data, group_size_data, child_data, 
child_size_data,
150 |                 obj_data, prob_data,
151 |                 stack_size, n, s, g,
152 |                 top_data,
153 |                 output_tree_path);
154 |         }
155 |     }));
156 | 
157 |     return {top};
158 | }
159 | 
160 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
161 |     m.def("forward", &smtpred_forward, "SoftmaxTreePrediction forward (CPU)");
162 | }
163 | 
--------------------------------------------------------------------------------
/mictorch/smtpred/smtpred_cuda.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT License.
3 | 
4 | #include <torch/extension.h>
5 | 
6 | #include <vector>
7 | 
8 | #include "mtorch_common.h"
9 | 
10 | // CUDA forward declarations
11 | 
12 | std::vector<at::Tensor> smtpred_cuda_forward(
13 |         at::Tensor conf, at::Tensor obj,
14 |         at::Tensor group_offset, at::Tensor group_size, at::Tensor child, at::Tensor child_size,
15 |         float threshold, bool output_tree_path, bool append_max,
16 |         int root_size, int stack_size,
17 |         int outer_num, int inner_num
18 | );
19 | 
20 | // C++ interface
21 | std::vector<at::Tensor> smtpred_forward(
22 |         at::Tensor conf, at::Tensor obj,
23 |         at::Tensor group_offset, at::Tensor group_size, at::Tensor child, at::Tensor child_size,
24 |         float threshold, bool output_tree_path, bool append_max,
25 |         int root_size, int stack_size
26 | ) {
27 |     CHECK_INPUT(conf);
28 |     CHECK_INPUT(obj);
29 |     CHECK_INPUT(group_offset);
30 |     CHECK_INPUT(group_size);
31 |     CHECK_INPUT(child);
32 |     CHECK_INPUT(child_size);
33 | 
34 |     AT_ASSERTM(conf.dim() >= 2, "invalid conf dim");
35 | 
36 |     int outer_num = conf.size(0);
37 |     int inner_num = 1;
38 |     for (int i = 2; i < conf.dim(); ++i)
39 |         inner_num *= conf.size(i);
40 | 
41 |     return smtpred_cuda_forward(
42 |         conf, obj,
43 |         group_offset, group_size, child, child_size,
44 |         threshold, output_tree_path, append_max,
45 |         root_size, stack_size,
46 |         outer_num, inner_num);
47 | }
48 | 
49 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
50 |     m.def("forward", &smtpred_forward, "SoftmaxTreePrediction forward (CUDA)");
51 | }
52 | 
--------------------------------------------------------------------------------
/mictorch/smtpred/smtpred_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT License.
3 | 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | 12 | #include "caffe_cuda.h" 13 | #include "region_common.hpp" 14 | 15 | namespace { 16 | 17 | __device__ void stack_push(double* __restrict__ parent_p_data, int* __restrict__ parent_argmax_data, int* __restrict__ g_data, 18 | int& stack_size, 19 | double p, int argmax, int g) { 20 | parent_p_data[stack_size] = p; 21 | parent_argmax_data[stack_size] = argmax; 22 | g_data[stack_size] = g; 23 | stack_size++; 24 | } 25 | 26 | __device__ void stack_pop(const double* __restrict__ parent_p_data, const int* __restrict__ parent_argmax_data, const int* __restrict__ g_data, 27 | int& stack_size, 28 | double& p, int& argmax, int& g) { 29 | assert(stack_size > 0); 30 | stack_size--; 31 | p = parent_p_data[stack_size]; 32 | argmax = parent_argmax_data[stack_size]; 33 | g = g_data[stack_size]; 34 | } 35 | 36 | template 37 | __device__ void predict_tree_stack( 38 | int outer_num, int channels, int inner_num, 39 | bool append_max, 40 | float threshold, 41 | const int* __restrict__ group_offset_data, const int* __restrict__ group_size_data, const int* __restrict__ child_data, const int* __restrict__ child_size_data, 42 | double* __restrict__ parent_p_data, int* __restrict__ parent_argmax_data, int* __restrict__ g_data, 43 | const scalar_t* __restrict__ obj_data, const scalar_t* __restrict__ prob_data, 44 | int max_stack_size, int n, int s, int g, 45 | scalar_t* top_data, 46 | bool output_tree_path) { 47 | 48 | int stack_size = 0; 49 | const int top_channels = append_max ? (channels + 1) : channels; 50 | scalar_t obj = obj_data ? obj_data[n * inner_num + s] : 1; 51 | double root_p = output_tree_path ? obj : 1.0; 52 | threshold = output_tree_path ? (threshold * obj) : threshold; 53 | stack_push(parent_p_data, parent_argmax_data, g_data, 54 | stack_size, 55 | root_p, -1, g); 56 | while (stack_size) { 57 | assert(stack_size <= max_stack_size); 58 | double parent_p; 59 | int parent_argmax; 60 | int g; 61 | stack_pop(parent_p_data, parent_argmax_data, g_data, 62 | stack_size, 63 | parent_p, parent_argmax, g); 64 | double p = parent_p; 65 | int argmax = 0; 66 | { 67 | scalar_t maxval = -FLT_MAX; 68 | auto offset = group_offset_data[g]; 69 | argmax = offset; 70 | auto size = group_size_data[g]; 71 | for (int j = 0; j < size; ++j) { 72 | scalar_t prob = prob_data[(n * channels + offset + j) * inner_num + s]; 73 | if (prob > maxval) { 74 | argmax = offset + j; 75 | maxval = prob; 76 | } 77 | } 78 | p *= maxval; 79 | } 80 | if (p > threshold) { 81 | if (output_tree_path) { 82 | top_data[(n * top_channels + argmax) * inner_num + s] = static_cast(p); 83 | } 84 | g = child_data[argmax]; // initial child group 85 | if (g >= 0) { 86 | // if there is any child, descend further 87 | int sg_count = child_size_data[argmax] + 1; 88 | for (int sg = 0; sg < sg_count; ++sg) { 89 | stack_push(parent_p_data, parent_argmax_data, g_data, 90 | stack_size, 91 | p, argmax, g + sg); 92 | 93 | } 94 | continue; 95 | } 96 | } else { 97 | argmax = parent_argmax; 98 | if (argmax < 0) 99 | continue; 100 | p = parent_p; 101 | } 102 | 103 | scalar_t node_p = 0; 104 | if (!output_tree_path) { 105 | node_p = obj_data ? obj : static_cast(p); 106 | top_data[(n * top_channels + argmax) * inner_num + s] = node_p; 107 | } 108 | if (append_max) { 109 | int max_idx = (n * top_channels + channels) * inner_num + s; 110 | if (output_tree_path) { 111 | // in this case, we use the obj as the max value, which will be 112 | // used as the indicator for class-independent NMS. 
or the 113 | // maximum value will always be the ones in the root. 114 | // gradually, we might remove the support of append_max since 115 | // it is more like a legacy strategy 116 | top_data[max_idx] = obj; 117 | } else { 118 | if (node_p > top_data[max_idx]) { 119 | top_data[max_idx] = node_p; 120 | } 121 | } 122 | } 123 | } 124 | } 125 | 126 | template 127 | __global__ void kernel_smt_prediction( 128 | int outer_num, int channels, int inner_num, int root_size, 129 | bool append_max, 130 | float threshold, 131 | const int* __restrict__ group_offset_data, const int* __restrict__ group_size_data, const int* __restrict__ child_data, const int* __restrict__ child_size_data, 132 | double* __restrict__ parent_p_data, int* __restrict__ parent_argmax_data, int* __restrict__ g_data, 133 | const scalar_t* __restrict__ obj_data, const scalar_t* __restrict__ prob_data, 134 | int max_stack_size, 135 | scalar_t* __restrict__ top_data, 136 | bool output_tree_path) { 137 | CUDA_KERNEL_LOOP(index, outer_num * root_size * inner_num) { 138 | const int s = index % inner_num; 139 | const int g = (index / inner_num) % root_size; 140 | const int n = (index / inner_num) / root_size; 141 | 142 | predict_tree_stack(outer_num, channels, inner_num, 143 | append_max, 144 | threshold, 145 | group_offset_data, group_size_data, child_data, child_size_data, 146 | &parent_p_data[index * max_stack_size], &parent_argmax_data[index * max_stack_size], &g_data[index * max_stack_size], 147 | obj_data, prob_data, 148 | max_stack_size, n, s, g, 149 | top_data, 150 | output_tree_path); 151 | } 152 | } 153 | 154 | 155 | } // namespace 156 | 157 | std::vector smtpred_cuda_forward( 158 | at::Tensor conf, at::Tensor obj, 159 | at::Tensor group_offset, at::Tensor group_size, at::Tensor child, at::Tensor child_size, 160 | float threshold, bool output_tree_path, bool append_max, 161 | int root_size, int stack_size, 162 | int outer_num, int inner_num 163 | ) { 164 | 165 | root_size++; 166 | 167 | // Intermediate variables 168 | auto stack_parent_p = at::empty({outer_num, root_size, inner_num, stack_size}, at::CUDA(at::kDouble)); 169 | auto stack_parent_argmax = at::empty({outer_num, root_size, inner_num, stack_size}, at::CUDA(at::kInt)); 170 | auto stack_g = at::empty({outer_num, root_size, inner_num, stack_size}, at::CUDA(at::kInt)); 171 | 172 | auto shape = conf.sizes().vec(); 173 | int channels = shape[1]; 174 | if (append_max) 175 | shape[1] = channels + 1; 176 | 177 | auto top = at::zeros(shape, conf.type()); 178 | 179 | AT_DISPATCH_FLOATING_TYPES(conf.type(), "smtpred_cuda_forward::kernel_smt_prediction", ([&] { 180 | scalar_t* obj_data = nullptr; 181 | if (obj.numel()) 182 | obj_data = obj.data(); 183 | kernel_smt_prediction<<>>( 184 | outer_num, channels, inner_num, root_size, 185 | append_max, 186 | threshold, 187 | group_offset.data(), group_size.data(), child.data(), child_size.data(), 188 | stack_parent_p.data(), stack_parent_argmax.data(), stack_g.data(), 189 | obj_data, conf.data(), 190 | stack_size, 191 | top.data(), 192 | output_tree_path); 193 | })); 194 | 195 | return {top}; 196 | } 197 | -------------------------------------------------------------------------------- /mictorch/softmaxtree.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------- 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License. 
4 | # -------------------------------------------------------------------------- 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.autograd import Function 10 | 11 | from mictorch.simple_parser import read_softmax_tree 12 | 13 | import smt_cuda 14 | import smt_cpu 15 | 16 | 17 | class SoftmaxTreeFunction(Function): 18 | @staticmethod 19 | def forward(ctx, x, group_offsets, group_sizes, axis): 20 | assert 0 <= axis < x.dim(), "invalid axis for x of size: {}".format(x.size()) 21 | node_count = group_offsets[-1] + group_sizes[-1] 22 | assert x.size(axis) == node_count, "Channel count: {} must match tree node count: {}".format( 23 | x.size(axis), node_count 24 | ) 25 | if x.is_cuda: 26 | smt_ = smt_cuda 27 | else: 28 | smt_ = smt_cpu 29 | prob = smt_.forward(x, group_offsets, group_sizes, axis)[0] 30 | 31 | ctx.softmax_axis = axis 32 | ctx.save_for_backward(prob, group_offsets, group_sizes) 33 | return prob 34 | 35 | @staticmethod 36 | def backward(ctx, grad_output): 37 | grad_x = grad_group_offsets = grad_group_sizes = grad_axis = None 38 | if ctx.needs_input_grad[0]: 39 | axis = ctx.softmax_axis 40 | prob, group_offsets, group_sizes = ctx.saved_tensors 41 | if prob.is_cuda: 42 | smt_ = smt_cuda 43 | else: 44 | smt_ = smt_cpu 45 | grad_x = smt_.backward( 46 | prob, grad_output, 47 | group_offsets, group_sizes, 48 | axis 49 | )[0] 50 | 51 | return grad_x, grad_group_offsets, grad_group_sizes, grad_axis 52 | 53 | 54 | class SoftmaxTree(nn.Module): 55 | def __init__(self, tree, axis=1): 56 | """SoftmaxTree is multiple softmaxes with an inherent tree relation assumed between softmax groups 57 | :param tree: path to the tree file (format as in Yolo) 58 | :param axis: axis to apply softmax (jagged axis) 59 | """ 60 | super(SoftmaxTree, self).__init__() 61 | self.tree = tree # type: str 62 | self.axis = axis 63 | 64 | group_offsets, group_sizes, cid_groups, parents, _, _, _ = read_softmax_tree(self.tree) 65 | self.register_buffer('group_offsets', torch.from_numpy(np.array(group_offsets, dtype=np.int32))) 66 | self.register_buffer('group_sizes', torch.from_numpy(np.array(group_sizes, dtype=np.int32))) 67 | self.node_count = len(cid_groups) 68 | self.group_count = len(group_offsets) 69 | assert self.node_count == group_offsets[-1] + group_sizes[-1], "node count: {} last group: {}+{}".format( 70 | self.node_count, group_offsets[-1], group_sizes[-1] 71 | ) 72 | 73 | def forward(self, x): 74 | return SoftmaxTreeFunction.apply( 75 | x, 76 | self.group_offsets, self.group_sizes, self.axis 77 | ) 78 | 79 | def extra_repr(self): 80 | """Extra information 81 | """ 82 | return 'tree={}, nodes={}, groups={}{}'.format( 83 | self.tree, self.node_count, self.group_count, ", axis={}".format(self.axis) if self.axis != 1 else "" 84 | ) 85 | 86 | 87 | # TODO: make these a proper unit test 88 | if __name__ == '__main__': 89 | from StringIO import StringIO 90 | 91 | # Create a flat softmax 92 | net = SoftmaxTree(StringIO("boz -1\nbozak -1\ngoat -1\n")) 93 | 94 | # create a matrix 4x3x8 so that softmax will be applied to the second dimension 95 | a = torch.rand(4, 3, 8) 96 | b = net(a) 97 | # total should be 1.0 98 | print(b[0, :, 0].sum()) 99 | 100 | # now test with cuda 101 | net = net.cuda() 102 | a = a.cuda() 103 | b = net(a) 104 | print(b[0, :, 0].sum()) 105 | -------------------------------------------------------------------------------- /mictorch/softmaxtree_prediction.py: -------------------------------------------------------------------------------- 1 | # 
------------------------------------------------------------------------- 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License. 4 | # -------------------------------------------------------------------------- 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.autograd import Function 10 | 11 | from mictorch.simple_parser import read_softmax_tree 12 | 13 | import smtpred_cuda 14 | import smtpred_cpu 15 | 16 | 17 | def _find_max_stack_size(group_offsets, group_sizes, child, child_sizes, root_size, g=-1): 18 | if g == -1: 19 | max_stack_size = 0 20 | for g in range(root_size + 1): 21 | stack_size = _find_max_stack_size(group_offsets, group_sizes, child, child_sizes, 22 | root_size, g=g) 23 | if stack_size > max_stack_size: 24 | max_stack_size = stack_size 25 | return max_stack_size 26 | 27 | max_stack_size = 1 28 | offset = group_offsets[g] 29 | size = group_sizes[g] 30 | for n in range(offset, offset + size): 31 | g = child[n] 32 | if g < 0: 33 | continue 34 | stack_size = child_sizes[n] + _find_max_stack_size(group_offsets, group_sizes, child, child_sizes, 35 | root_size, g=g) 36 | if stack_size > max_stack_size: 37 | max_stack_size = stack_size 38 | return max_stack_size 39 | 40 | 41 | class SoftmaxTreePredictionFunction(Function): 42 | @staticmethod 43 | def forward(ctx, 44 | conf, obj, 45 | group_offsets, group_sizes, child, child_sizes, 46 | threshold, output_tree_path, append_max, 47 | root_size, stack_size, 48 | ): 49 | node_count = group_offsets[-1] + group_sizes[-1] 50 | assert conf.size(1) == node_count, "Channel count: {} must match tree node count: {}".format( 51 | conf.size(1), node_count 52 | ) 53 | if conf.is_cuda: 54 | smtpred_ = smtpred_cuda 55 | else: 56 | smtpred_ = smtpred_cpu 57 | if obj is None: 58 | obj = torch.zeros(0).type_as(conf) 59 | else: 60 | # if objectness is provided 61 | assert conf.numel() / conf.size(1) == obj.numel(), "Invalid obj dimension" 62 | top_pred = smtpred_.forward( 63 | conf, obj, 64 | group_offsets, group_sizes, child, child_sizes, 65 | threshold, output_tree_path, append_max, 66 | root_size, stack_size 67 | )[0] 68 | 69 | return top_pred 70 | 71 | @staticmethod 72 | def backward(ctx, grad_output): 73 | return tuple([None] * 11) 74 | 75 | 76 | class SoftmaxTreePrediction(nn.Module): 77 | def __init__(self, tree, threshold=0.5, append_max=True, output_tree_path=False): 78 | super(SoftmaxTreePrediction, self).__init__() 79 | self.tree = tree # type: str 80 | self.threshold = threshold 81 | self.append_max = append_max 82 | self.output_tree_path = output_tree_path 83 | 84 | group_offsets, group_sizes, cid_groups, parents, child, child_sizes, self.root_size = read_softmax_tree( 85 | self.tree 86 | ) 87 | self.stack_size = _find_max_stack_size(group_offsets, group_sizes, child, child_sizes, self.root_size) 88 | # TODO: share buffers with SoftmaxTree 89 | self.register_buffer('group_offsets', torch.from_numpy(np.array(group_offsets, dtype=np.int32))) 90 | self.register_buffer('group_sizes', torch.from_numpy(np.array(group_sizes, dtype=np.int32))) 91 | self.register_buffer('child', torch.from_numpy(np.array(child, dtype=np.int32))) 92 | self.register_buffer('child_sizes', torch.from_numpy(np.array(child_sizes, dtype=np.int32))) 93 | self.node_count = len(cid_groups) 94 | self.group_count = len(group_offsets) 95 | assert self.node_count == group_offsets[-1] + group_sizes[-1], "node count: {} last group: {}+{}".format( 96 | self.node_count, group_offsets[-1], 
group_sizes[-1]
97 |         )
98 | 
99 |     def forward(self, conf, obj=None):
100 |         if self.output_tree_path:
101 |             assert obj is not None, "output_tree_path requires objectness bottom"
102 |         return SoftmaxTreePredictionFunction.apply(
103 |             conf, obj,
104 |             self.group_offsets, self.group_sizes, self.child, self.child_sizes,
105 |             self.threshold, self.output_tree_path, self.append_max,
106 |             self.root_size, self.stack_size,
107 |         )
108 | 
109 |     def extra_repr(self):
110 |         """Extra information
111 |         """
112 |         return 'tree={}, nodes={}, groups={}{}{}{}{}'.format(
113 |             self.tree, self.node_count, self.group_count,
114 |             ", root_size={}".format(self.root_size) if self.root_size else "",
115 |             ", stack_size={}".format(self.stack_size) if self.stack_size != 1 else "",
116 |             ", append_max={}".format(self.append_max) if not self.append_max else "",
117 |             ", output_tree_path={}".format(self.output_tree_path) if self.output_tree_path else "",
118 |         )
119 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | 
3 | # -------------------------------------------------------------------------
4 | # Copyright (c) Microsoft Corporation. All rights reserved.
5 | # Licensed under the MIT License.
6 | # --------------------------------------------------------------------------
7 | 
8 | from __future__ import print_function
9 | import os, sys
10 | import os.path as op
11 | import torch
12 | from setuptools import find_packages, setup
13 | from torch.utils.cpp_extension import BuildExtension, CppExtension
14 | try:
15 |     from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME
16 |     assert CUDA_HOME or os.getenv("FORCE_CUDA", "0") == "1", "CUDA not found"
17 |     if not torch.cuda.is_available():
18 |         arch_list = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
19 |         if not arch_list:
20 |             print("No CUDA runtime found and TORCH_CUDA_ARCH_LIST not set")
21 |             try:
22 |                 driver_version = torch._C._cuda_getDriverVersion()
23 |             except Exception as e:
24 |                 print("torch._C._cuda_getDriverVersion() may be deprecated; error: {}".format(e))
25 |                 driver_version = 0
26 |             if driver_version == 0:
27 |                 arch_list = 'Pascal;Volta;Turing'
28 |                 print("No driver found; defaulting TORCH_CUDA_ARCH_LIST to {}".format(arch_list))
29 |                 os.environ['TORCH_CUDA_ARCH_LIST'] = arch_list
30 | except (ImportError, OSError, AssertionError) as e:
31 |     CUDAExtension = None
32 |     print("No CUDA was detected; building without CUDA. Error: {}".format(e))
33 | 
34 | # change directory to this module path
35 | try:
36 |     this_file = __file__
37 | except NameError:
38 |     this_file = sys.argv[0]
39 | this_file = os.path.abspath(this_file)
40 | if op.dirname(this_file):
41 |     os.chdir(op.dirname(this_file))
42 | script_dir = os.getcwd()
43 | 
44 | include_dirs = [op.abspath('./mictorch/common/')]
45 | 
46 | 
47 | def readme(fname):
48 |     """Read text out of a file in the same directory as setup.py.
49 | """ 50 | return open(op.join(script_dir, fname)).read() 51 | 52 | 53 | if CUDAExtension is None: 54 | cuda_extensions = [] 55 | else: 56 | cuda_extensions = [ 57 | CUDAExtension('smt_cuda', [ 58 | 'mictorch/smt/smt_cuda.cpp', 59 | 'mictorch/smt/smt_cuda_kernel.cu', 60 | ], include_dirs=include_dirs), 61 | CUDAExtension('smtpred_cuda', [ 62 | 'mictorch/smtpred/smtpred_cuda.cpp', 63 | 'mictorch/smtpred/smtpred_cuda_kernel.cu', 64 | ], include_dirs=include_dirs), 65 | CUDAExtension('nmsfilt_cuda', [ 66 | 'mictorch/nmsfilt/nmsfilt_cuda.cpp', 67 | 'mictorch/nmsfilt/nmsfilt_cuda_kernel.cu', 68 | ], include_dirs=include_dirs), 69 | ] 70 | 71 | setup( 72 | name="mictorch", 73 | version="0.0.1", 74 | author="ehazar", 75 | author_email="ehazar@microsoft.com", 76 | url='', 77 | description="Microsoft PyTorch object detection modules", 78 | long_description=readme('README.md'), 79 | packages=find_packages(), 80 | ext_modules=[ 81 | CppExtension('smt_cpu', [ 82 | 'mictorch/smt/smt_cpu.cpp', 83 | ], include_dirs=include_dirs), 84 | CppExtension('nmsfilt_cpu', [ 85 | 'mictorch/nmsfilt/nmsfilt_cpu.cpp', 86 | ], include_dirs=include_dirs), 87 | CppExtension('smtpred_cpu', [ 88 | 'mictorch/smtpred/smtpred_cpu.cpp', 89 | ], include_dirs=include_dirs), 90 | CppExtension('nms_cpu', [ 91 | 'mictorch/nms/nms_cpu.cpp', 92 | ], include_dirs=include_dirs), 93 | ] + cuda_extensions, 94 | cmdclass={ 95 | 'build_ext': BuildExtension 96 | }, 97 | zip_safe=False, 98 | license="MIT", 99 | classifiers=[ 100 | 'Intended Audience :: Developers', 101 | "Programming Language :: Python", 102 | 'Topic :: Software Development', 103 | ] 104 | ) 105 | --------------------------------------------------------------------------------
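A minimal usage sketch (not part of the repository sources above): it assumes the extensions have been built and installed, for example with python setup.py install, that mictorch.simple_parser.read_softmax_tree accepts a plain file path as well as a file object, and it uses a made-up three-class flat tree and tensor shapes purely for illustration.

import torch
from mictorch.softmaxtree import SoftmaxTree
from mictorch.softmaxtree_prediction import SoftmaxTreePrediction

# A flat, single-group tree in the Yolo tree format: "<label> <parent index>".
with open("tiny_tree.txt", "w") as f:
    f.write("boz -1\nbozak -1\ngoat -1\n")

smt = SoftmaxTree("tiny_tree.txt", axis=1)                    # jagged softmax over axis 1
pred = SoftmaxTreePrediction("tiny_tree.txt", threshold=0.5)  # append_max=True by default

x = torch.rand(4, 3, 8)      # batch=4, channels=node count=3, spatial size=8
prob = smt(x)                # per-group probabilities along axis 1
print(prob[0, :, 0].sum())   # ~1.0: the single flat group sums to one
top = pred(prob)             # thresholded tree prediction; a max channel is appended
print(top.shape)             # torch.Size([4, 4, 8])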