├── .gitignore ├── LICENSE ├── README.md ├── box_convolution ├── __init__.py ├── box_convolution_function.py ├── box_convolution_module.py └── test.py ├── examples ├── Cityscapes │ ├── README.md │ ├── convnet-runtime-benchmark.py │ ├── datasets.py │ ├── models │ │ ├── ENet.py │ │ └── ERFNet.py │ └── train.py └── mnist.py ├── setup.py └── src ├── bind.cpp ├── box_convolution.cpp ├── box_convolution.h ├── box_convolution_cuda_backward.cu ├── box_convolution_cuda_forward.cu ├── box_convolution_cuda_misc.cu ├── box_convolution_interface.cpp ├── cuda_stubs.cpp ├── integral_image.cpp ├── integral_image.h ├── integral_image_cuda.cu └── integral_image_interface.cpp /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__/ 2 | *.pyc 3 | *.so 4 | build/ 5 | *.egg-info 6 | examples/MNIST/ 7 | *runs/ 8 | *log.txt 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2019 Egor Burkov 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Box Convolution Layer for ConvNets 2 | ================================== 3 | 4 |

5 | 6 |
7 | Single-box-conv network (from `examples/mnist.py`) learns patterns on MNIST 8 |

9 | 10 | # What This Is 11 | 12 | This is a PyTorch implementation of the box convolution layer as introduced in the 2018 NeurIPS [paper](https://papers.nips.cc/paper/7859-deep-neural-networks-with-box-convolutions): 13 | 14 | Burkov, E., & Lempitsky, V. (2018) **Deep Neural Networks with Box Convolutions**. *Advances in Neural Information Processing Systems 31*, 6214-6224. 15 | 16 | # How to Use 17 | 18 | ## Installing 19 | 20 | ```bash 21 | python3 -m pip install git+https://github.com/shrubb/box-convolutions.git 22 | python3 -m box_convolution.test # if this throws errors, please open a GitHub issue 23 | ``` 24 | 25 | To uninstall: 26 | 27 | ```bash 28 | python3 -m pip uninstall box_convolution 29 | ``` 30 | 31 | Tested on Ubuntu 18.04.2, Python 3.6, PyTorch 1.0.0, GCC {4.9, 5.5, 6.5, 7.3}, CUDA 9.2. Other versions (e.g. macOS, Python 2.7, CUDA 8, CUDA 10) should work too, but I haven't checked. If something doesn't build, please open a GitHub issue. 32 | 33 | Known issues (see [this chat](https://github.com/shrubb/box-convolutions/issues/2)): 34 | 35 | * CUDA 9/9.1 + GCC 6 isn't supported due to a bug in NVCC. 36 | 37 | You can specify a different compiler with the `CC` environment variable: 38 | 39 | ```bash 40 | CC=g++-7 python3 -m pip install git+https://github.com/shrubb/box-convolutions.git 41 | ``` 42 | 43 | ## Using 44 | 45 | ```python3 46 | import torch 47 | from box_convolution import BoxConv2d 48 | 49 | box_conv = BoxConv2d(16, 8, 240, 320) 50 | help(BoxConv2d) 51 | ``` 52 | 53 | Also, there are usage examples in `examples/`. 54 | 55 | 56 | # Quick Tour of Box Convolutions 57 | 58 | You may want to see our [poster](https://yadi.sk/i/LNnMrj6FwbOc9A). 59 | 60 | ### Why reinvent the old convolution? 61 | 62 | `3×3` convolutions are too small ⮕ the receptive field grows too slowly ⮕ ConvNets have to be very deep. 63 | 64 | This is especially undesirable in dense prediction tasks (*segmentation, depth estimation, object detection, ...*). 65 | 66 | Today people address this with 67 | 68 | * dilated/deformable convolutions (can bring artifacts or degrade to `1×1` conv; almost always filter high frequencies); 69 | * "global" spatial pooling layers (usually too constrained, fixed size, not "fully convolutional"). 70 | 71 | ### How does it work? 72 | 73 | A box convolution layer is a basic *depthwise convolution* (i.e. `Conv2d` with `groups=in_channels`) but with special kernels called *box kernels*. 74 | 75 | A box kernel is a rectangular averaging filter. That is, its filter values are fixed and all equal: nothing inside the kernel is learned. Instead, we learn four parameters per rectangle, namely its size and offset (a minimal PyTorch reference sketch of this computation is given at the end of this README): 76 | 77 | ![image](https://user-images.githubusercontent.com/9570420/41361143-f6db467a-6f36-11e8-9dfc-086a79256bfc.png) 78 | 79 | ![image](https://user-images.githubusercontent.com/9570420/40393137-f371e1ea-5e26-11e8-868a-79ea3f6847f1.png) 80 | 81 | ### Any success stories? 82 | 83 | One example: there is an efficient semantic segmentation model [**ENet**](https://github.com/e-lab/ENet-training). It's a classical hourglass architecture built from dozens of ResNet-like blocks (left image). 84 | 85 | Let's replace some of these blocks with our "box convolution block" (right image). 86 | 87 | *(figure: the original ENet block, left, vs. the box convolution block, right)* 88 | 89 | First we replaced every second block with a box convolution block (*Box*ENet in the paper). The model became 90 | 91 | * more accurate, 92 | * faster, 93 | * lighter, 94 | * **without dilated convolutions**. 95 | 96 | Then, we replaced **every** residual block (except the down- and up-sampling ones)! The result, *BoxOnly*ENet, is 97 | 98 | * a **ConvNet almost without** (traditional learnable-weight) **convolutions**, 99 | * **2 times** fewer operations, 100 | * **3 times** fewer parameters, 101 | * still **more accurate** than ENet! 102 | 103 |
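### Appendix: a reference box filter in pure PyTorch

For intuition, here is a simplified, self-contained sketch of what one *normalized* box filter computes for a single channel. The helper name `box_filter_reference` is made up for illustration; the real layer is implemented in C++/CUDA under `src/`. The sketch uses integer box coordinates, zero padding, and division by the nominal box area — for integer coordinates this agrees with the explicit-kernel reference in `box_convolution/test.py`, whereas `BoxConv2d` itself learns *fractional* coordinates for every (input channel, filter) pair.

```python3
import torch

def box_filter_reference(image, x_min, x_max, y_min, y_max):
    """Average `image` (H x W) over the box [x_min; x_max] x [y_min; y_max]
    placed around every pixel; pixels outside the image count as zeros."""
    h, w = image.shape

    # integral image with an extra zero row/column: integral[i, j] == image[:i, :j].sum()
    integral = torch.zeros(h + 1, w + 1, dtype=image.dtype)
    integral[1:, 1:] = image.cumsum(0).cumsum(1)

    area = (x_max - x_min + 1) * (y_max - y_min + 1)  # nominal box area
    output = torch.empty_like(image)
    for i in range(h):
        for j in range(w):
            # clip the box to the image; four lookups give the sum over it
            top = min(max(i + x_min, 0), h)
            bottom = min(max(i + x_max + 1, 0), h)
            left = min(max(j + y_min, 0), w)
            right = min(max(j + y_max + 1, 0), w)
            box_sum = (integral[bottom, right] - integral[top, right]
                       - integral[bottom, left] + integral[top, left])
            output[i, j] = box_sum / area
    return output

# toy usage: a 3x5 box (1 px up/down, 2 px left/right) around every pixel
image = torch.arange(81, dtype=torch.float32).reshape(9, 9)
print(box_filter_reference(image, x_min=-1, x_max=1, y_min=-2, y_max=2).shape)
```

Because the kernel is constant inside the rectangle, a box of any size costs the same four integral-image lookups per output pixel; this is why `BoxConvolutionFunction.forward` in `box_convolution/box_convolution_function.py` first builds an integral image and only then evaluates the boxes.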
-------------------------------------------------------------------------------- /box_convolution/__init__.py: -------------------------------------------------------------------------------- 1 | from .box_convolution_module import BoxConv2d 2 | -------------------------------------------------------------------------------- /box_convolution/box_convolution_function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import box_convolution_cpp_cuda as cpp_cuda 4 | 5 | def reparametrize( 6 | x_min, x_max, y_min, y_max, reparametrization_h, reparametrization_w, 7 | inplace=False, inverse=False): 8 | """ 9 | If `inverse is False`, scale module's parameters so that their range becomes 10 | approximately [-1; 1]. Otherwise, do the inverse operation. 11 | 12 | This hack is unfortunately needed for the parameters to work with variants of SGD. 13 | Without this "reparametrization", box sizes' gradients will be extremely small. 14 | 15 | If `not inplace`, returns 4 new tensors, otherwise modifies the given ones. 16 | """ 17 | scalar_h = reparametrization_h if inverse else (1 / reparametrization_h) 18 | scalar_w = reparametrization_w if inverse else (1 / reparametrization_w) 19 | 20 | with torch.no_grad(): 21 | if inplace: 22 | x_min *= scalar_h 23 | x_max *= scalar_h 24 | y_min *= scalar_w 25 | y_max *= scalar_w 26 | else: 27 | return x_min * scalar_h, x_max * scalar_h, y_min * scalar_w, y_max * scalar_w 28 | 29 | # TODO: rename `x_` and `y_` to `h_` and `w_` 30 | class BoxConvolutionFunction(torch.autograd.Function): 31 | @staticmethod 32 | def forward(ctx, input, x_min, x_max, y_min, y_max, 33 | reparametrization_h, reparametrization_w, normalize, exact): 34 | 35 | # store all non-tensor arguments in `ctx` 36 | ctx.normalize = normalize 37 | ctx.reparametrization_h = reparametrization_h 38 | ctx.reparametrization_w = reparametrization_w 39 | ctx.exact = exact 40 | 41 | x_min, x_max, y_min, y_max = reparametrize( 42 | x_min, x_max, y_min, y_max, reparametrization_h, reparametrization_w, inverse=True) 43 | 44 | input_integrated = cpp_cuda.integral_image(input) 45 | output = cpp_cuda.box_convolution_forward( 46 | input_integrated, x_min, x_max, y_min, y_max, normalize, exact) 47 | 48 | ctx.save_for_backward( 49 | input_integrated, x_min, x_max, y_min, y_max, output if normalize else None) 50 | 51 | return output 52 | 53 | @staticmethod 54 | def backward(ctx, grad_output): 55 | input_integrated, x_min, x_max, y_min, y_max, output = ctx.saved_tensors 56 | if output is None: 57 | output = torch.empty(0) # to satisfy `box_convolution_backward`'s signature 58 | 59 | retval = cpp_cuda.box_convolution_backward( 60 | input_integrated, x_min, x_max, y_min, y_max, grad_output, output, 61 | ctx.reparametrization_h, ctx.reparametrization_w, 62 | ctx.normalize, ctx.exact, *ctx.needs_input_grad[:5]) 63 | 64 | # 4 `None`s for `reparametrization_h, reparametrization_w, normalize, exact` 65 | return tuple(retval) + (None,) * 4 66 | -------------------------------------------------------------------------------- /box_convolution/box_convolution_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | from .box_convolution_function import
BoxConvolutionFunction, reparametrize 5 | import box_convolution_cpp_cuda as cpp_cuda 6 | 7 | class BoxConv2d(torch.nn.Module): 8 | """ 9 | Module that performs depthwise box convolution. 10 | Convolves each of the incoming channels with `num_filters` different, 11 | possibly normalized, box kernels. 12 | 13 | Input : `(batch_size) x (in_planes) x (h) x (w)` 14 | Output: `(batch_size) x (in_planes*num_filters) x (h) x (w)` 15 | 16 | Constructor arguments: 17 | 18 | in_planes: int 19 | Number of channels in the input image (as in Conv2d). 20 | num_filters: int 21 | Number of filters to apply per channel (as in depthwise Conv2d). 22 | max_input_h, max_input_w: int 23 | Maximum estimated height/width of future input images. This parameter does 24 | not strictly bind input images to certain sizes. However, this is used 25 | when clipping the boxes to detect if some box has become too large or has 26 | drifted too far away from the image. See `_clip_parameters()` for details. 27 | reparametrization_factor: float 28 | In module parameters, boxes are not represented directly by their 29 | relative pixel coordinates, because then the gradients will usually 30 | be too small. Rather, here they are scaled into a range that is inside 31 | [-1; 1] by `1 / (reparametrization_factor * max_input_[h/w])`. 32 | When setting up training, generate a video of boxes using `draw_boxes()`. 33 | If they move too slow, increasing this parameter might help. If they 34 | converge too fast, reduce this value. 35 | stride_h, stride_w: int 36 | Stride (as in Conv2d). Not yet implemented. 37 | normalize: bool 38 | If `False`, computes sums over boxes (traditional box filters). 39 | If `True`, computes averages over boxes (normalized box filters). 40 | 41 | Useful fields (change after construction): 42 | 43 | exact: bool 44 | If `False`, box coordinates are rounded (towards smaller box size) before 45 | the output is computed. Significantly faster, but might be harmful for 46 | convergence. Well, still, often works for some reason, so try and see. 47 | Default: `True`. 48 | """ 49 | def __init__(self, 50 | in_planes, num_filters, max_input_h, max_input_w, 51 | reparametrization_factor=8, stride_h=1, stride_w=1, normalize=True): 52 | 53 | super(BoxConv2d, self).__init__() 54 | self.in_planes = in_planes 55 | self.num_filters = num_filters 56 | self.max_input_h, self.max_input_w = max_input_h, max_input_w 57 | # default reparametrization; can be changed instead of setting a separate learning rate 58 | self.reparametrization_h = max_input_h * reparametrization_factor 59 | self.reparametrization_w = max_input_w * reparametrization_factor 60 | self.stride_h, self.stride_w = stride_h, stride_w 61 | assert stride_h == 1 and stride_w == 1, 'Sorry, strides are NYI' 62 | self.normalize = normalize 63 | self.exact = True 64 | 65 | self.x_min, self.x_max, self.y_min, self.y_max = \ 66 | (torch.nn.Parameter(torch.empty(in_planes, num_filters)) for _ in range(4)) 67 | self.reset_parameters() 68 | 69 | def reset_parameters(self): 70 | """ 71 | One of the various possible random box initializations. 72 | """ 73 | with torch.no_grad(): 74 | # TODO speed up 75 | # TODO use torch's random generator 76 | # TODO provide the algorithm used in all original paper's experiments? 
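# The loop below draws, independently for each (input plane, filter) pair, a box center
# offset uniformly from a window of roughly +-(max_input / 2.4) around zero (leaving a
# small margin), then a height/width that keeps the box inside that window, and finally
# inflates the resulting coordinates by 1.5. `reparametrize` afterwards rescales the
# stored parameters into the small range that works well with SGD-style optimizers.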
77 | max_h, max_w = self.max_input_h, self.max_input_w 78 | min_h, min_w = 2, 2 79 | for in_plane_idx in range(self.in_planes): 80 | for filter_idx in range(self.num_filters): 81 | center_h = random.uniform( 82 | -max_h*2/4.8+1+min_h/2, max_h*2/4.8-1-min_h/2) 83 | center_w = random.uniform( 84 | -max_w*2/4.8+1+min_w/2, max_w*2/4.8-1-min_w/2) 85 | height = 2 * random.uniform( 86 | min_h/2, min((max_h*2/4.8-1)-center_h, center_h-(-max_h*2/4.8+1))) 87 | width = 2 * random.uniform( 88 | min_w/2, min((max_w*2/4.8-1)-center_w, center_w-(-max_w*2/4.8+1))) 89 | 90 | self.x_min[in_plane_idx, filter_idx] = (center_h - height/2) * 1.5 91 | self.x_max[in_plane_idx, filter_idx] = (center_h + height/2) * 1.5 92 | self.y_min[in_plane_idx, filter_idx] = (center_w - width /2) * 1.5 93 | self.y_max[in_plane_idx, filter_idx] = (center_w + width /2) * 1.5 94 | 95 | reparametrize( 96 | self.x_min, self.x_max, self.y_min, self.y_max, 97 | self.reparametrization_h, self.reparametrization_w, inplace=True) 98 | 99 | def draw_boxes(self, channels=None, resolution=(600, 600), weights=None): 100 | """ 101 | Plot all rectangles corresponding to box filters. Useful for debugging. 102 | Return the resulting image, an (H x W x 3) tensor. 103 | 104 | channels: List of input channels to draw boxes for. 105 | Default: `[0, 1, ..., self.in_planes-1]` (draw all boxes). 106 | resolution: Tuple (h, w) -- returned image resolution. 107 | Default: (600, 600) 108 | weights: `len(channels) x self.num_filters` array of values in [0; 1] that define 109 | "importance" of each box (e.g. function of weights from a successive 110 | convolution). More important boxes are given a brigter color, unimportant 111 | are drawn almost transparent. 112 | Default: `numpy.ones((len(channels), self.num_filters))`. 
113 | """ 114 | import cv2 115 | import numpy as np 116 | 117 | if channels is None: 118 | channels = range(self.in_planes) 119 | 120 | weights_shape = (len(channels), self.num_filters) 121 | if weights is None: 122 | weights = np.ones(weights_shape) 123 | weights = weights.cpu().float().numpy().reshape(weights_shape) 124 | weights.clip(0.01, 1.0, out=weights) 125 | 126 | retval = np.zeros(resolution + (3,), dtype=np.uint8) 127 | 128 | # draw gray lines at center 129 | center = [resolution[0] // 2, resolution[1] // 2] 130 | retval[center[0], :] = 70 131 | retval[:, center[1]] = 70 132 | 133 | colors = np.array([ 134 | [255, 0, 0], 135 | [ 0, 255, 0], 136 | [ 0, 0, 255], 137 | [255, 255, 255], 138 | [255, 255, 0], 139 | [255, 0, 255], 140 | [ 0, 255, 255], 141 | [ 47, 20, 255], 142 | [255, 60, 160], 143 | [ 60, 170, 255], 144 | [ 30, 105, 210], 145 | [222, 196, 176], 146 | [212, 255, 127], 147 | [250, 206, 135], 148 | [ 50, 205, 50], 149 | [ 0, 165, 255], 150 | [ 60, 20, 220], 151 | [170, 178, 32]], dtype=np.float32) 152 | 153 | x_min, x_max, y_min, y_max = (p.float() for p in self.get_actual_parameters()) 154 | x_min = x_min / self.max_input_h * (resolution[0] / 2) + center[0] 155 | y_min = y_min / self.max_input_w * (resolution[1] / 2) + center[1] 156 | x_max = (x_max + 1) / self.max_input_h * (resolution[0] / 2) + center[0] 157 | y_max = (y_max + 1) / self.max_input_w * (resolution[1] / 2) + center[1] 158 | 159 | for channel_idx in channels: 160 | for filter_idx in range(self.num_filters): 161 | box_weight = weights[channel_idx, filter_idx] 162 | # heuristic for single-plane inputs 163 | color = colors[(filter_idx if len(channels) == 1 else channel_idx) % len(colors)] 164 | # take weights into account 165 | color = (color * box_weight).astype(int) 166 | 167 | param_2d_idx = channel_idx, filter_idx 168 | x_min_curr = x_min[param_2d_idx] 169 | x_max_curr = x_max[param_2d_idx] 170 | y_min_curr = y_min[param_2d_idx] 171 | y_max_curr = y_max[param_2d_idx] 172 | 173 | # if a rect has negative size, fill it 174 | box_is_invalid = x_min_curr > x_max_curr or y_min_curr > y_max_curr 175 | thickness = -1 if box_is_invalid else round(resolution[0] / 500 + 0.5) 176 | 177 | cv2.rectangle( 178 | retval, (y_min_curr, x_min_curr), (y_max_curr, x_max_curr), 179 | color.tolist(), thickness) 180 | 181 | return retval 182 | 183 | def get_actual_parameters(self): 184 | """ 185 | As parameters are scaled (see `reparametrization_factor`), they don't 186 | represent actual box coordinates. 187 | 188 | Return the **real** parameters (i.e. actual relative box coordinates) 189 | as if they weren't rescaled. 190 | """ 191 | return reparametrize( 192 | self.x_min, self.x_max, self.y_min, self.y_max, 193 | self.reparametrization_h, self.reparametrization_w, inplace=False, inverse=True) 194 | 195 | def _clip_parameters(self): 196 | """ 197 | Internal method, do not invoke as a user. 198 | 199 | Dirty parameter fix for projected gradient descent: 200 | - If a filter's width or height is negative, reset it to the minimum allowed positive. 201 | - If the filter is >=twice higher or wider than the input image, shrink it back. 
202 | """ 203 | cpp_cuda.clip_parameters( 204 | self.x_min, self.x_max, self.y_min, self.y_max, 205 | self.reparametrization_h, self.reparametrization_w, 206 | self.max_input_h, self.max_input_w, self.exact) 207 | 208 | def train(self, mode=True): 209 | self.training = mode 210 | if mode is False: 211 | # TODO would be good to also precompute rounded parameters and areas 212 | self._clip_parameters() 213 | 214 | return self 215 | 216 | def forward(self, input): 217 | if self.training: 218 | self._clip_parameters() 219 | 220 | return BoxConvolutionFunction.apply( 221 | input, self.x_min, self.x_max, self.y_min, self.y_max, 222 | self.reparametrization_h, self.reparametrization_w, self.normalize, self.exact) 223 | -------------------------------------------------------------------------------- /box_convolution/test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import random 3 | import torch 4 | 5 | try: 6 | from tqdm import tqdm 7 | except ImportError: 8 | tqdm = lambda x: x 9 | 10 | def test_integral_image(device): 11 | # or use torch.cumsum 12 | def integral_image_reference(input): 13 | assert input.ndimension() >= 2 14 | h, w = input.shape[-2:] 15 | output_shape = input.shape[:-2] + (h+1, w+1) 16 | output = torch.empty(output_shape, dtype=input.dtype, device=input.device) 17 | 18 | # zero the 0th columns and rows 19 | output.select(-2, 0).fill_(0) 20 | output.select(-1, 0).fill_(0) 21 | 22 | # accumulate rows 23 | output_no_zero_col = output.narrow(-1, 1, w) 24 | sum_rows = torch.zeros_like(input.select(-2, 0), dtype=torch.float64) 25 | for row_idx in range(h): 26 | sum_rows += input.select(-2, row_idx).double() 27 | output_no_zero_col.select(-2, row_idx+1).copy_(sum_rows) 28 | 29 | # accumulate columns 30 | sum_cols = torch.zeros_like(output.select(-1, 0), dtype=torch.float64) 31 | for col_idx in range(w): 32 | sum_cols += output.select(-1, col_idx+1).double() 33 | output.select(-1, col_idx+1).copy_(sum_cols) 34 | 35 | return output 36 | 37 | from box_convolution_cpp_cuda import integral_image 38 | 39 | # check IntegralImageFunction vs reference implementation 40 | for test_idx in tqdm(range(50)): 41 | batch_size = random.randint(1, 3) 42 | in_planes = random.randint(1, 3) 43 | stride_h, stride_w = 1, 1 # may change in the future 44 | h, w = random.randint(1+stride_h, 10), random.randint(1+stride_w, 10) 45 | 46 | input_image = torch.rand(batch_size, in_planes, h, w, requires_grad=True, device=device) 47 | grad_output = torch.rand(batch_size, in_planes, h+1, w+1) < 0.15 48 | grad_output = grad_output.to(device, input_image.dtype) 49 | 50 | reference_result = integral_image_reference(input_image) 51 | our_result = integral_image(input_image) 52 | 53 | if not our_result.allclose(reference_result): 54 | raise ValueError( 55 | 'Test %d failed at forward pass.\n\nInput:\n%s\n\n' 56 | 'Our output:\n%s\n\nReference output:\n%s\n\n' 57 | % (test_idx, input_image, our_result, reference_result)) 58 | 59 | def test_box_convolution_module(device): 60 | def explicit_box_kernel(x_min, x_max, y_min, y_max, normalize): 61 | import math 62 | h_farthest = math.ceil(max(x_max, -x_min)) 63 | w_farthest = math.ceil(max(y_max, -y_min)) 64 | 65 | retval = torch.ones(1+2*h_farthest, 1+2*w_farthest, device=x_min.device) 66 | 67 | def segments_intersection(a_l, a_r, b_l, b_r): 68 | common_l = max(a_l, b_l) 69 | common_r = min(a_r, b_r) 70 | return max(0.0, common_r - common_l) 71 | 72 | for x, row in enumerate(retval, start=-h_farthest): 73 | # h-extent of 
the current row: [x; x+1] 74 | # h-extent of the box of interest: [x_min; x_max+1] 75 | # find length of their intersection and multiply the row by it 76 | row *= segments_intersection(x, x+1, x_min, x_max+1) 77 | 78 | for y, col in enumerate(retval.t(), start=-w_farthest): 79 | # same for columns 80 | col *= segments_intersection(y, y+1, y_min, y_max+1) 81 | 82 | if normalize: 83 | area = (y_max-y_min+1) * (x_max-x_min+1) 84 | retval *= 1/area 85 | 86 | return retval 87 | 88 | # reference implementation 89 | def box_convolution_reference( 90 | input, x_min, x_max, y_min, y_max, 91 | reparametrization_h, reparametrization_w, normalize, exact): 92 | 93 | assert x_min.shape == x_max.shape 94 | assert x_min.shape == y_min.shape 95 | assert x_min.shape == y_max.shape 96 | 97 | assert input.ndimension() == 4 98 | assert type(normalize) is bool 99 | 100 | x_min, x_max, y_min, y_max = \ 101 | reparametrize( 102 | x_min, x_max, y_min, y_max, 103 | reparametrization_h, reparametrization_w, inverse=True) 104 | 105 | if not exact: 106 | x_min.ceil_() 107 | y_min.ceil_() 108 | x_max.floor_() 109 | y_max.floor_() 110 | 111 | in_planes, num_filters = x_min.shape 112 | assert input.shape[1] == in_planes 113 | 114 | # in_c, out_c = input channel, output channel 115 | kernels = [[explicit_box_kernel(*out_c, normalize) for out_c in zip(*in_c)] \ 116 | for in_c in zip(x_min, x_max, y_min, y_max)] 117 | assert len(kernels) == in_planes 118 | assert all(len(x) == num_filters for x in kernels) 119 | 120 | def conv2d_single_channel(image, kernel): 121 | image = image.view((1,1) + image.shape) 122 | kernel = kernel.view((1,1) + kernel.shape) 123 | padding = (kernel.shape[-2] // 2, kernel.shape[-1] // 2) 124 | return torch.conv2d(image, kernel, padding=padding)[0,0] 125 | 126 | output_shape = list(input.shape) 127 | output_shape.insert(2, num_filters) 128 | output = torch.empty(output_shape, dtype=input.dtype, device=input.device) 129 | 130 | for in_sample, out_sample in zip(input, output): 131 | for in_plane_idx, in_plane_kernels in enumerate(kernels): 132 | for filter_idx, kernel in enumerate(in_plane_kernels): 133 | filtered = conv2d_single_channel(in_sample[in_plane_idx], kernel) 134 | out_sample[in_plane_idx, filter_idx].copy_(filtered) 135 | 136 | retval_shape = list(input.shape) 137 | retval_shape[1] *= num_filters 138 | return output.reshape(retval_shape) 139 | 140 | from . 
import BoxConv2d 141 | from .box_convolution_function import reparametrize 142 | 143 | # same interface for our target function 144 | def box_convolution_wrapper( 145 | input, x_min, x_max, y_min, y_max, 146 | max_input_h, max_input_w, reparametrization_factor, normalize, exact): 147 | 148 | assert x_min.shape == x_max.shape 149 | assert x_min.shape == y_min.shape 150 | assert x_min.shape == y_max.shape 151 | 152 | assert input.ndimension() == 4 153 | assert type(normalize) is bool 154 | 155 | in_planes, num_filters = x_min.shape 156 | assert input.shape[1] == in_planes 157 | 158 | module = BoxConv2d( 159 | in_planes, num_filters, max_input_h, max_input_w, 160 | reparametrization_factor).to(input) 161 | 162 | del module.x_min; module.x_min = x_min 163 | del module.x_max; module.x_max = x_max 164 | del module.y_min; module.y_min = y_min 165 | del module.y_max; module.y_max = y_max 166 | module.normalize = normalize 167 | module.exact = exact 168 | 169 | params_before = module.get_actual_parameters() 170 | 171 | output = module(input) 172 | 173 | params_after = module.get_actual_parameters() 174 | param_names = 'x_min', 'x_max', 'y_min', 'y_max' 175 | for p_before, p_after, p_name in zip(params_before, params_after, param_names): 176 | if not torch.equal(p_before, p_after): 177 | raise ValueError( 178 | 'Wrong test case configuration: `_clip_parameters` ' 179 | 'has changed one of the parameters.\n\n' + \ 180 | 'h, w = %d, %d\n\n' % (h, w) + \ 181 | 'Before:\n' + \ 182 | '\n'.join('%s: %s' % (n,p) for n,p in zip(param_names, params_before)) + \ 183 | '\n\nAfter:\n' + \ 184 | '\n'.join('%s: %s' % (n,p) for n,p in zip(param_names, params_after))) 185 | 186 | return output 187 | 188 | for test_idx in tqdm(range(40)): 189 | batch_size = random.randint(1, 1) 190 | in_planes = random.randint(1, 1) 191 | num_filters = random.randint(1, 1) 192 | stride_h, stride_w = 1, 1 # may change in the future 193 | exact = random.random() < 0.7 194 | 195 | # if not exact, minimum box size changes from 1 to 2 196 | h = random.randint(1 + stride_h + (not exact), 10) 197 | w = random.randint(1 + stride_w + (not exact), 10) 198 | max_input_h, max_input_w = h+1, w+1 199 | reparametrization_factor = random.random() * 4.5 + 0.5 200 | reparametrization_h = max_input_h * reparametrization_factor 201 | reparametrization_w = max_input_w * reparametrization_factor 202 | gradcheck_step = 0.004 203 | 204 | input_image = torch.rand(batch_size, in_planes, h, w, device=device, requires_grad=True) 205 | 206 | # sample boxes more or less randomly (algorithm isn't practical but is OK for gradcheck) 207 | x_min, x_max, y_min, y_max = \ 208 | (torch.empty(in_planes, num_filters, device=device) for _ in range(4)) 209 | for plane_idx in range(in_planes): 210 | for filter_idx in range(num_filters): 211 | 212 | box_is_valid = False 213 | while not box_is_valid: 214 | x_min_curr = random.uniform(-h+1.05, h-(not exact)-1.1) 215 | y_min_curr = random.uniform(-w+1.05, w-(not exact)-1.1) 216 | 217 | # set sizes to at least 1.001 because of `_clip_parameters`'s behavior 218 | x_max_curr = random.uniform( 219 | x_min_curr + (not exact) + 3*gradcheck_step + 0.001, h-1.05) 220 | y_max_curr = random.uniform( 221 | y_min_curr + (not exact) + 3*gradcheck_step + 0.001, w-1.05) 222 | 223 | # As a function of box coordinates (x_min etc.), box convolution isn't smooth 224 | # at integer points, so the finite difference gradcheck will fail. 225 | # Therefore, let's resample the box until all coordinates are far 226 | # enough from integers. 
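# (Concretely: the output has kinks, i.e. non-differentiable points, wherever a box
# coordinate is an integer, because that is where a new row or column of pixels starts
# entering the box; finite differences straddling such a point are unreliable.)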
227 | box_is_valid = True 228 | for value in x_min_curr, y_min_curr, x_max_curr, y_max_curr: 229 | if abs(value - round(value)) <= gradcheck_step * 3: # *3 for extra safety 230 | box_is_valid = False 231 | 232 | x_min[plane_idx, filter_idx] = x_min_curr 233 | y_min[plane_idx, filter_idx] = y_min_curr 234 | x_max[plane_idx, filter_idx] = x_max_curr 235 | y_max[plane_idx, filter_idx] = y_max_curr 236 | 237 | # reparametrize 238 | x_min, x_max, y_min, y_max = \ 239 | reparametrize(x_min, x_max, y_min, y_max, reparametrization_h, reparametrization_w) 240 | 241 | # randomly test either sum filter or average filter 242 | normalize = random.choice((False, True)) 243 | 244 | grad_output = (torch.rand(batch_size, in_planes*num_filters, h, w) < 0.15).to(input_image) 245 | 246 | # check output and grad w.r.t. input vs reference ones 247 | reference_result = box_convolution_reference( 248 | input_image, x_min, x_max, y_min, y_max, 249 | reparametrization_h, reparametrization_w, normalize, exact) 250 | reference_result.backward(grad_output) 251 | reference_grad_input = input_image.grad.clone() 252 | input_image.grad.zero_() 253 | 254 | our_result = box_convolution_wrapper( 255 | input_image, x_min, x_max, y_min, y_max, 256 | max_input_h, max_input_w, reparametrization_factor, normalize, exact) 257 | our_result.backward(grad_output) 258 | our_grad_input = input_image.grad.clone() 259 | 260 | if not our_result.allclose(reference_result, rtol=3e-5, atol=1e-5): 261 | raise ValueError( 262 | 'Test %d failed at forward pass.\n\nNormalize: %s\n\nInput:\n%s\n\n' 263 | 'Our output:\n%s\n\nReference output:\n%s\n\nMax diff: %f\n\n' 264 | % (test_idx, normalize, input_image, our_result, reference_result, \ 265 | (our_result - reference_result).abs().max())) 266 | 267 | if not our_grad_input.allclose(reference_grad_input, rtol=3e-5, atol=1e-5): 268 | raise ValueError( 269 | 'Test %d failed at backward pass.\n\nNormalize: %s\n\n' 270 | 'Input:\n%s\n\nOutput:\n%s\n\ngradOutput:\n%s\n\nOur gradInput:\n%s\n\n' 271 | 'Reference gradInput:\n%s\n\nMax diff: %f\n\n' 272 | % (test_idx, normalize, input_image, our_result, \ 273 | grad_output, our_grad_input, reference_grad_input, \ 274 | (our_grad_input-reference_grad_input).abs().max())) 275 | 276 | # sorry, I don't want to reliably check gradients w.r.t. parameters in rounded mode 277 | if not exact: 278 | continue 279 | 280 | # convert to double and check our grads w.r.t. parameters against finite differences 281 | with torch.no_grad(): 282 | input_image = input_image.double() 283 | x_min = x_min.double() 284 | x_max = x_max.double() 285 | y_min = y_min.double() 286 | y_max = y_max.double() 287 | 288 | for tensor in x_min, x_max, y_min, y_max: 289 | tensor.requires_grad_() 290 | input_image.requires_grad_(False) # already tested above 291 | 292 | try: 293 | original_parameters = reparametrize( 294 | x_min, x_max, y_min, y_max, reparametrization_h, reparametrization_w, inverse=True) 295 | 296 | torch.autograd.gradcheck( 297 | box_convolution_wrapper, 298 | (input_image, x_min, x_max, y_min, y_max, max_input_h, max_input_w, \ 299 | reparametrization_factor, normalize, exact), 300 | eps=gradcheck_step / max(reparametrization_h, reparametrization_w), 301 | raise_exception=True) 302 | except Exception: 303 | print('Test %d failed at finite difference grad check w.r.t. parameters.' 
% test_idx) 304 | print('Normalize: %s' % normalize) 305 | print('h, w = %d, %d' % (h, w)) 306 | print('x_min, x_max, y_min, y_max are:') 307 | for parameter in original_parameters: 308 | print(parameter) 309 | raise 310 | 311 | if __name__ == '__main__': 312 | seed = int(time.time()) 313 | seed = 1546545757 314 | torch.manual_seed(seed) 315 | random.seed(seed) 316 | print('Random seed is %d' % seed) 317 | 318 | # TODO add dtypes too 319 | devices = ('cpu',) 320 | if torch.cuda.is_available(): 321 | devices += ('cuda',) 322 | 323 | for device in devices: 324 | print('Testing for device \'%s\'' % device) 325 | for testing_function in test_integral_image, test_box_convolution_module: 326 | print('Running %s()...' % testing_function.__name__) 327 | # TODO [re]set random state etc. 328 | testing_function(device) 329 | print('OK') 330 | -------------------------------------------------------------------------------- /examples/Cityscapes/README.md: -------------------------------------------------------------------------------- 1 | Box convolutions for semantic segmentation 2 | ====== 3 | 4 | This folder has the code to reproduce our experiments (**to be uploaded soon**). 5 | 6 | Pretrained models (metrics are shown for the validation set): 7 | 8 | | Name and link | Class IoU, % | Category IoU, % | 9 | |:--------------------------------------------------------:|:------------:|:---------------:| 10 | | [BoxENet]( ) | --.- | --.- | 11 | | [BoxOnlyENet]( ) | --.- | --.- | 12 | | [BoxERFNet]( ) | --.- | --.- | -------------------------------------------------------------------------------- /examples/Cityscapes/convnet-runtime-benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | The script I use for benchmarking. 3 | 4 | Architectures are optimized for inference as if they were deployed in 5 | production: BatchNorms and Dropouts are absorbed, "void" class is removed. 6 | 7 | I managed to fully reproduce Torch7 runtimes from the paper for ENet 8 | and ERFNet; however, for some reason smaller models (e.g. ENet^-, ERFNet^-) 9 | are slower in PyTorch 1.0.1 than in Torch7. 
10 | """ 11 | import torch 12 | 13 | architecture = 'ENet' # choice: ENet / BoxENet / ERFNet / BoxERFNet / ENetMinus 14 | device = 'cuda' # or 'cpu' 15 | dtype = torch.float32 16 | 17 | if device == 'cuda': 18 | assert torch.backends.cudnn.enabled 19 | torch.backends.cudnn.benchmark = True 20 | torch.backends.cudnn.deterministic = False 21 | 22 | torch.set_num_threads(1) 23 | torch.manual_seed(666) 24 | 25 | # Cityscapes 1024x512 configuration 26 | n_classes = 18 27 | input_image = torch.rand((1, 3, 512, 1024), dtype=dtype, device=device) 28 | 29 | print('Architecture:', architecture) 30 | print('Device:', input_image.device) 31 | print('Data type:', input_image.dtype) 32 | 33 | from models.ERFNet import ERFNet, BoxERFNet 34 | from models.ENet import ENet, BoxENet, BoxOnlyENet, ENetMinus 35 | 36 | model = globals()[architecture](n_classes).to(input_image) 37 | 38 | # optimize the model for inference 39 | def remove_bn_and_dropout(module): 40 | for child_name, child in module.named_children(): 41 | child_type = str(type(child)) 42 | if 'BatchNorm' in child_type or 'Dropout' in child_type: 43 | module.__setattr__(child_name, torch.nn.Sequential()) 44 | else: 45 | remove_bn_and_dropout(child) 46 | 47 | from box_convolution import BoxConv2d 48 | def set_boxconv_to_nonexact(module): 49 | if isinstance(module, BoxConv2d): 50 | module.exact = False 51 | 52 | model.apply(set_boxconv_to_nonexact) 53 | remove_bn_and_dropout(model) 54 | model.eval() 55 | 56 | # warm up 57 | print('Output shape:', model(input_image).shape) 58 | 59 | n_runs = 10 if device == 'cpu' else 160 60 | import time 61 | 62 | with torch.no_grad(): 63 | if device == 'cuda': 64 | torch.cuda.synchronize() 65 | 66 | start = time.time() 67 | for _ in range(n_runs): 68 | model(input_image) 69 | 70 | if device == 'cuda': 71 | torch.cuda.synchronize() 72 | end = time.time() 73 | 74 | time_per_frame = (end - start) / n_runs 75 | print('%.1f ms per frame / %.2f FPS' % (time_per_frame * 1000, 1 / time_per_frame)) 76 | 77 | -------------------------------------------------------------------------------- /examples/Cityscapes/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import torchvision 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | class Cityscapes(torch.utils.data.Dataset): 9 | """ 10 | Native PyTorch Cityscapes but with proper data augmentation. 
11 | """ 12 | def __init__( 13 | self, root='/media/hpc4_Raid/e_burkov/Datasets/Cityscapes/', 14 | split='train', size=(1024, 512), augmented=False): 15 | 16 | super().__init__() 17 | 18 | self.cityscapes = torchvision.datasets.Cityscapes( 19 | root=root, split=split, mode='fine', target_type='semantic') 20 | 21 | self.size = size 22 | self.n_classes = 19 23 | 24 | # precomputed mean and stddev per channel 25 | self.mean = np.float32([0.28470638394356, 0.32577008008957, 0.28766867518425]) * 255.0 26 | self.std = np.float32([0.18671783804893, 0.1899059265852, 0.18665011227131]) * 255.0 27 | 28 | # precomputed class frequencies 29 | class_probs = np.float32([ 30 | 0.36869695782661, 31 | 0.060849856585264, 32 | 0.22824048995972, 33 | 0.0065539856441319, 34 | 0.0087727159261703, 35 | 0.012273414991796, 36 | 0.0020779478363693, 37 | 0.0055127013474703, 38 | 0.1592865139246, 39 | 0.011578181758523, 40 | 0.040189824998379, 41 | 0.012189572677016, 42 | 0.0013512192526832, 43 | 0.069945447146893, 44 | 0.0026745572686195, 45 | 0.0023519159294665, 46 | 0.0023290426470339, 47 | 0.00098657899070531, 48 | 0.0041390685364604, 49 | ]) 50 | # add "void" class and adopt a slightly modified class weighting scheme from ENet 51 | self.class_weights = np.concatenate(([0], 1.0 / np.log(class_probs + 1.1))) 52 | self.class_weights = torch.tensor(self.class_weights, dtype=torch.float32) 53 | 54 | self.augmented = augmented 55 | 56 | @staticmethod 57 | def augment(image, labels): 58 | """ 59 | image: np.uint8, H x W x 3 60 | labels: np.uint8, H x W 61 | """ 62 | flip = bool(np.random.randint(2)) 63 | maybe_flip_matrix = np.eye(3) 64 | if flip: 65 | maybe_flip_matrix[0,0] = -1 66 | maybe_flip_matrix[0,2] = labels.shape[1] 67 | 68 | angle = (np.random.rand() * 2 - 1) * 7.5 69 | scale_factor = (np.random.rand() * 2 - 1) * 0.12 + 1.0 70 | image_center = (labels.shape[1] / 2, labels.shape[0] / 2) 71 | rotation_matrix = np.eye(3) 72 | rotation_matrix[:2] = cv2.getRotationMatrix2D(image_center, angle, scale_factor) 73 | 74 | transformation_matrix = (maybe_flip_matrix @ rotation_matrix)[:2] 75 | 76 | image_size = (labels.shape[1], labels.shape[0]) 77 | 78 | image = cv2.warpAffine( 79 | image, transformation_matrix, image_size, 80 | flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE) 81 | labels = cv2.warpAffine( 82 | labels, transformation_matrix, image_size, 83 | flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT) 84 | 85 | return image, labels 86 | 87 | def __len__(self): 88 | return len(self.cityscapes) 89 | 90 | def __getitem__(self, idx): 91 | image, labels = map(np.array, self.cityscapes[idx]) 92 | 93 | if image.shape[:2] != self.size[::-1]: 94 | image = cv2.resize(image, self.size, interpolation=cv2.INTER_AREA) 95 | if labels.shape != self.size[::-1]: 96 | labels = cv2.resize(labels, self.size, interpolation=cv2.INTER_NEAREST) 97 | 98 | def remap_labels(labels): 99 | """ 100 | Shift labels according to 101 | https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py#L61 102 | """ 103 | retval = np.zeros_like(labels) 104 | class_map = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33] 105 | for new_class, old_class in enumerate(class_map, start=1): 106 | retval[labels == old_class] = new_class 107 | return retval 108 | 109 | # comment this line if you have already preprocessed the labels this way 110 | labels = remap_labels(labels) 111 | 112 | if self.augmented: 113 | image, labels = self.augment(image, labels) 114 | 115 | image = image.transpose((2, 
0, 1)).astype(np.float32) 116 | image -= self.mean.reshape(3, 1, 1) 117 | image *= 1 / self.std.reshape(3, 1, 1) 118 | 119 | return torch.tensor(image), torch.tensor(labels, dtype=torch.long) 120 | 121 | 122 | if __name__ == '__main__': 123 | dataset = Cityscapes('/home/shrubb/Datasets/Cityscapes', augmented=True) 124 | print(len(dataset)) 125 | 126 | -------------------------------------------------------------------------------- /examples/Cityscapes/models/ENet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class ENet(nn.ModuleList): 6 | def __init__(self, n_classes=19): 7 | super().__init__([ 8 | Downsampler(3, 16), 9 | Bottleneck(16, 64, 0.01, downsample=True), 10 | 11 | Bottleneck(64, 64, 0.01), 12 | Bottleneck(64, 64, 0.01), 13 | Bottleneck(64, 64, 0.01), 14 | Bottleneck(64, 64, 0.01), 15 | 16 | Bottleneck(64, 128, 0.1, downsample=True), 17 | 18 | Bottleneck(128, 128, 0.1), 19 | Bottleneck(128, 128, 0.1, dilation=2), 20 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5), 21 | Bottleneck(128, 128, 0.1, dilation=4), 22 | Bottleneck(128, 128, 0.1), 23 | Bottleneck(128, 128, 0.1, dilation=8), 24 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5), 25 | Bottleneck(128, 128, 0.1, dilation=16), 26 | 27 | Bottleneck(128, 128, 0.1), 28 | Bottleneck(128, 128, 0.1, dilation=2), 29 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5), 30 | Bottleneck(128, 128, 0.1, dilation=4), 31 | Bottleneck(128, 128, 0.1), 32 | Bottleneck(128, 128, 0.1, dilation=8), 33 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5), 34 | Bottleneck(128, 128, 0.1, dilation=16), 35 | 36 | Upsampler(128, 64), 37 | 38 | Bottleneck(64, 64, 0.1), 39 | Bottleneck(64, 64, 0.1), 40 | 41 | Upsampler(64, 16), 42 | 43 | Bottleneck(16, 16, 0.1), 44 | 45 | nn.ConvTranspose2d(16, n_classes+1, (2,2), (2,2))]) 46 | 47 | def forward(self, x): 48 | max_indices_stack = [] 49 | 50 | for module in self: 51 | if isinstance(module, Upsampler): 52 | x = module(x, max_indices_stack.pop()) 53 | else: 54 | x = module(x) 55 | 56 | if type(x) is tuple: # then it was a downsampling bottleneck block 57 | x, max_indices = x 58 | max_indices_stack.append(max_indices) 59 | 60 | return x 61 | 62 | class BoxENet(ENet): 63 | def __init__(self, n_classes=19, max_input_h=512, max_input_w=1024): 64 | h, w = max_input_h, max_input_w # shorten names for convenience 65 | r = 0.860 # reparametrization factor 66 | 67 | nn.ModuleList.__init__(self, [ 68 | Downsampler(3, 16), 69 | Bottleneck(16, 64, 0.01, downsample=True), 70 | 71 | Bottleneck(64, 64, 0.01), 72 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.15, reparam_factor=r), 73 | Bottleneck(64, 64, 0.01), 74 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.15, reparam_factor=r), 75 | 76 | Bottleneck(64, 128, 0.1, downsample=True), 77 | 78 | Bottleneck(128, 128, 0.1), 79 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.25, reparam_factor=r), 80 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5), 81 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.25, reparam_factor=r), 82 | Bottleneck(128, 128, 0.1), 83 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.25, reparam_factor=r), 84 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5), 85 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.25, reparam_factor=r), 86 | 87 | Bottleneck(128, 128, 0.1), 88 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.25, reparam_factor=r), 89 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5), 90 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.25, reparam_factor=r), 
91 | Bottleneck(128, 128, 0.1), 92 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.25, reparam_factor=r), 93 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5), 94 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.25, reparam_factor=r), 95 | 96 | Upsampler(128, 64), 97 | 98 | Bottleneck(64, 64, 0.1), 99 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.1, reparam_factor=r), 100 | 101 | Upsampler(64, 16), 102 | 103 | BottleneckBoxConv(16, 2, h // 2, w // 2, 0.1, reparam_factor=r), 104 | 105 | nn.ConvTranspose2d(16, n_classes+1, (2,2), (2,2))]) 106 | 107 | class BoxOnlyENet(ENet): 108 | def __init__(self, n_classes=19, max_input_h=512, max_input_w=1024): 109 | h, w = max_input_h, max_input_w # shorten names for convenience 110 | r = 0.510 # reparametrization factor 111 | 112 | nn.ModuleList.__init__(self, [ 113 | Downsampler(3, 16), 114 | Bottleneck(16, 64, 0.01, downsample=True), 115 | 116 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.15, reparam_factor=r), 117 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.15, reparam_factor=r), 118 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.15, reparam_factor=r), 119 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.15, reparam_factor=r), 120 | 121 | Bottleneck(64, 128, 0.01, downsample=True), 122 | 123 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r), 124 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r), 125 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r), 126 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r), 127 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r), 128 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r), 129 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r), 130 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r), 131 | 132 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r), 133 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r), 134 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r), 135 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r), 136 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r), 137 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r), 138 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r), 139 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r), 140 | 141 | Upsampler(128, 64), 142 | 143 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.15, reparam_factor=r), 144 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.15, reparam_factor=r), 145 | 146 | Upsampler(64, 16), 147 | 148 | BottleneckBoxConv(16, 4, h // 2, w // 2, 0.05, reparam_factor=r), 149 | 150 | nn.ConvTranspose2d(16, n_classes+1, (2,2), (2,2))]) 151 | 152 | class ENetMinus(ENet): 153 | def __init__(self, n_classes=19, max_input_h=512, max_input_w=1024): 154 | h, w = max_input_h, max_input_w # shorten names for convenience 155 | r = 0.860 # reparametrization factor 156 | 157 | nn.ModuleList.__init__(self, [ 158 | Downsampler(3, 16), 159 | Bottleneck(16, 64, 0.01, downsample=True), 160 | 161 | Bottleneck(64, 64, 0.01), 162 | Bottleneck(64, 64, 0.01), 163 | 164 | Bottleneck(64, 128, 0.1, downsample=True), 165 | 166 | Bottleneck(128, 128, 0.1), 167 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5), 168 | Bottleneck(128, 128, 0.1), 169 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5), 170 | 171 | Bottleneck(128, 128, 0.1), 172 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5), 173 | Bottleneck(128, 128, 0.1), 174 | Bottleneck(128, 128, 0.1, 
asymmetric_ksize=5), 175 | 176 | Upsampler(128, 64), 177 | 178 | Bottleneck(64, 64, 0.1), 179 | 180 | Upsampler(64, 16), 181 | 182 | nn.ConvTranspose2d(16, n_classes+1, (2,2), (2,2))]) 183 | 184 | class Upsampler(nn.Module): 185 | def __init__(self, in_channels, out_channels): 186 | super().__init__() 187 | bt_channels = out_channels // 4 188 | 189 | self.main_branch = nn.Sequential( 190 | nn.Conv2d(in_channels, bt_channels, (1,1), bias=False), 191 | nn.BatchNorm2d(bt_channels, 1e-3), 192 | nn.ReLU(True), 193 | 194 | nn.ConvTranspose2d(bt_channels, bt_channels, (3,3), 2, 1, 1), 195 | nn.BatchNorm2d(bt_channels, 1e-3), 196 | nn.ReLU(True), 197 | 198 | nn.Conv2d(bt_channels, out_channels, (1,1), bias=False), 199 | nn.BatchNorm2d(out_channels, 1e-3)) 200 | 201 | self.skip_connection = nn.Sequential( 202 | nn.Conv2d(in_channels, out_channels, (1,1), bias=False), 203 | nn.BatchNorm2d(out_channels, 1e-3)) 204 | 205 | def forward(self, x, max_indices): 206 | x_skip_connection = self.skip_connection(x) 207 | x_skip_connection = F.max_unpool2d(x_skip_connection, max_indices, (2,2)) 208 | 209 | return (x_skip_connection + self.main_branch(x)).relu_() 210 | 211 | class Downsampler(nn.Module): 212 | def __init__(self, in_channels, out_channels): 213 | super().__init__() 214 | self.conv = nn.Conv2d(in_channels, out_channels-in_channels, (3,3), 2, 1, bias=False) 215 | self.bn = nn.BatchNorm2d(out_channels, 1e-3) 216 | self.prelu = nn.PReLU(out_channels) 217 | 218 | def forward(self, x): 219 | x = torch.cat([F.max_pool2d(x, (2,2)), self.conv(x)], 1) 220 | x = self.bn(x) 221 | x = self.prelu(x) 222 | return x 223 | 224 | class Bottleneck(nn.Module): 225 | def __init__(self, in_channels, out_channels, dropout_prob=0.0, downsample=False, 226 | asymmetric_ksize=None, dilation=1, use_prelu=True): 227 | 228 | super().__init__() 229 | bt_channels = out_channels // 4 230 | self.downsample = downsample 231 | self.channels_to_pad = out_channels-in_channels 232 | 233 | input_stride = 2 if downsample else 1 234 | 235 | main_branch = [ 236 | nn.Conv2d(in_channels, bt_channels, input_stride, input_stride, bias=False), 237 | nn.BatchNorm2d(bt_channels, 1e-3), 238 | nn.PReLU(bt_channels) if use_prelu else nn.ReLU(True) 239 | ] 240 | 241 | if asymmetric_ksize is None: 242 | main_branch += [ 243 | nn.Conv2d(bt_channels, bt_channels, (3,3), 1, dilation, dilation, bias=False) 244 | ] 245 | else: 246 | assert type(asymmetric_ksize) is int 247 | ksize, padding = asymmetric_ksize, (asymmetric_ksize-1) // 2 248 | main_branch += [ 249 | nn.Conv2d(bt_channels, bt_channels, (ksize,1), 1, (padding,0), bias=False), 250 | nn.Conv2d(bt_channels, bt_channels, (1,ksize), 1, (0,padding)) 251 | ] 252 | 253 | main_branch += [ 254 | nn.BatchNorm2d(bt_channels, 1e-3), 255 | nn.PReLU(bt_channels) if use_prelu else nn.ReLU(True), 256 | nn.Conv2d(bt_channels, out_channels, (1,1), bias=False), 257 | nn.BatchNorm2d(out_channels, 1e-3), 258 | nn.Dropout2d(dropout_prob) 259 | ] 260 | 261 | self.main_branch = nn.Sequential(*main_branch) 262 | self.output_activation = nn.PReLU(out_channels) if use_prelu else nn.ReLU(True) 263 | 264 | def forward(self, x): 265 | if self.downsample: 266 | x_skip_connection, max_indices = F.max_pool2d(x, (2,2), return_indices=True) 267 | else: 268 | x_skip_connection = x 269 | 270 | if self.channels_to_pad > 0: 271 | x_skip_connection = F.pad(x_skip_connection, (0,0, 0,0, 0,self.channels_to_pad)) 272 | 273 | x = self.output_activation(x_skip_connection + self.main_branch(x)) 274 | 275 | if self.downsample: 276 | return 
x, max_indices 277 | else: 278 | return x 279 | 280 | from box_convolution import BoxConv2d 281 | 282 | class BottleneckBoxConv(nn.Module): 283 | def __init__(self, in_channels, num_boxes, max_input_h, max_input_w, 284 | dropout_prob=0.0, reparam_factor=1.5625): 285 | 286 | super().__init__() 287 | assert in_channels % num_boxes == 0 288 | bt_channels = in_channels // num_boxes # bottleneck channels 289 | 290 | self.main_branch = nn.Sequential( 291 | nn.Conv2d(in_channels, bt_channels, (1,1), bias=False), 292 | nn.BatchNorm2d(bt_channels), 293 | nn.ReLU(True), 294 | 295 | # BEHOLD: 296 | BoxConv2d( 297 | bt_channels, num_boxes, max_input_h, max_input_w, 298 | reparametrization_factor=reparam_factor), 299 | 300 | nn.BatchNorm2d(in_channels), 301 | nn.Dropout2d(dropout_prob)) 302 | 303 | def forward(self, x): 304 | return (x + self.main_branch(x)).relu_() 305 | 306 | -------------------------------------------------------------------------------- /examples/Cityscapes/models/ERFNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class ERFNet(nn.Sequential): 6 | def __init__(self, n_classes=19): 7 | super().__init__( 8 | Downsampler( 3, 16, 0.0 ), 9 | Downsampler(16, 64, 0.03), 10 | 11 | NonBottleneck1D(64, 0.03), 12 | NonBottleneck1D(64, 0.03), 13 | NonBottleneck1D(64, 0.03), 14 | NonBottleneck1D(64, 0.03), 15 | NonBottleneck1D(64, 0.03), 16 | 17 | Downsampler(64, 128, 0.3), 18 | 19 | NonBottleneck1D(128, 0.3, 2), 20 | NonBottleneck1D(128, 0.3, 4), 21 | NonBottleneck1D(128, 0.3, 8), 22 | NonBottleneck1D(128, 0.3, 16), 23 | NonBottleneck1D(128, 0.3, 2), 24 | NonBottleneck1D(128, 0.3, 4), 25 | NonBottleneck1D(128, 0.3, 8), 26 | NonBottleneck1D(128, 0.3, 16), 27 | 28 | Upsampler(128, 64), 29 | 30 | NonBottleneck1D(64), 31 | NonBottleneck1D(64), 32 | 33 | Upsampler(64, 16), 34 | 35 | NonBottleneck1D(16), 36 | NonBottleneck1D(16), 37 | 38 | nn.ConvTranspose2d(16, n_classes+1, (3,3), 2, 1, 1)) 39 | 40 | class BoxERFNet(nn.Sequential): 41 | def __init__(self, n_classes=19, max_input_h=512, max_input_w=1024): 42 | h, w = max_input_h, max_input_w # shorten names for convenience 43 | 44 | super().__init__( 45 | Downsampler( 3, 16, 0.0 ), 46 | Downsampler(16, 64, 0.03), 47 | 48 | NonBottleneck1D(64, 0.03), 49 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.03), 50 | 51 | Downsampler(64, 128, 0.3), 52 | 53 | NonBottleneck1D(128, 0.3, 2), 54 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.3), 55 | NonBottleneck1D(128, 0.3, 4), 56 | 57 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.3), 58 | 59 | NonBottleneck1D(128, 0.3, 2), 60 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.3), 61 | NonBottleneck1D(128, 0.3, 4), 62 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.3), 63 | 64 | Upsampler(128, 64), 65 | 66 | NonBottleneck1D(64), 67 | 68 | Upsampler(64, 16), 69 | 70 | NonBottleneck1D(16), 71 | 72 | nn.ConvTranspose2d(16, n_classes+1, (3,3), 2, 1, 1)) 73 | 74 | def Upsampler(in_channels, out_channels): 75 | return nn.Sequential( 76 | nn.ConvTranspose2d(in_channels, out_channels, (3,3), 2, 1, 1, bias=False), 77 | nn.BatchNorm2d(out_channels), 78 | nn.ReLU(inplace=True)) 79 | 80 | class Downsampler(nn.Module): 81 | def __init__(self, in_channels, out_channels, dropout_prob=0.0): 82 | super().__init__() 83 | self.conv = nn.Conv2d(in_channels, out_channels-in_channels, (3,3), 2, 1, bias=False) 84 | self.bn = nn.BatchNorm2d(out_channels) 85 | self.dropout = nn.Dropout2d(dropout_prob) 86 | 87 | def 
forward(self, x): 88 | x = torch.cat([F.max_pool2d(x, (2,2)), self.conv(x)], 1) 89 | x = self.bn(x) 90 | x = self.dropout(x) 91 | x = F.relu(x, inplace=True) 92 | return x 93 | 94 | class NonBottleneck1D(nn.Module): 95 | def __init__(self, in_channels, dropout_prob=0.0, dilation=1): 96 | super().__init__() 97 | dil = dilation # shorten the name for convenience 98 | 99 | self.main_branch = nn.Sequential( 100 | nn.Conv2d(in_channels, in_channels, (3,1), 1, (1,0), bias=False), 101 | nn.ReLU(True), 102 | nn.Conv2d(in_channels, in_channels, (1,3), 1, (0,1), bias=False), 103 | nn.BatchNorm2d(in_channels), 104 | nn.ReLU(True), 105 | 106 | nn.Conv2d(in_channels, in_channels, (3,1), 1, (dil,0), (dil,dil), bias=False), 107 | nn.ReLU(True), 108 | nn.Conv2d(in_channels, in_channels, (1,3), 1, (0,dil), (dil,dil), bias=False), 109 | nn.BatchNorm2d(in_channels), 110 | nn.Dropout2d(dropout_prob)) 111 | 112 | def forward(self, x): 113 | return F.relu(x + self.main_branch(x), inplace=True) 114 | 115 | from box_convolution import BoxConv2d 116 | 117 | class BottleneckBoxConv(nn.Module): 118 | def __init__(self, in_channels, num_boxes, max_input_h, max_input_w, dropout_prob=0.0): 119 | super().__init__() 120 | 121 | assert in_channels % num_boxes == 0 122 | bt_channels = in_channels // num_boxes # bottleneck channels 123 | 124 | self.main_branch = nn.Sequential( 125 | nn.Conv2d(in_channels, bt_channels, (1,1), bias=False), 126 | nn.BatchNorm2d(bt_channels), 127 | 128 | # BEHOLD: 129 | BoxConv2d( 130 | bt_channels, num_boxes, max_input_h, max_input_w, 131 | reparametrization_factor=1.5625), 132 | 133 | nn.BatchNorm2d(in_channels), 134 | nn.Dropout2d(dropout_prob)) 135 | 136 | def forward(self, x): 137 | return F.relu(x + self.main_branch(x), inplace=True) 138 | -------------------------------------------------------------------------------- /examples/Cityscapes/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim 10 | import torch.utils.data 11 | 12 | import tensorboardX 13 | 14 | from models.ERFNet import ERFNet, BoxERFNet 15 | from models.ENet import ENet, BoxENet, BoxOnlyENet 16 | 17 | model_names = ['ENet', 'BoxENet', 'BoxOnlyENet', 'ERFNet', 'BoxERFNet'] 18 | 19 | parser = argparse.ArgumentParser( 20 | description='Simple Cityscapes semantic segmentation training') 21 | 22 | parser.add_argument('data', 23 | help='path to dataset') 24 | parser.add_argument('--arch', '-a', default='BoxENet', choices=model_names, 25 | help='model architecture: ' + ' | '.join(model_names) + ' (default: BoxENet)') 26 | parser.add_argument('-j', '--workers', default=4, type=int, 27 | help='number of data loading workers (default: 4)') 28 | parser.add_argument('--epochs', default=1300, type=int, 29 | help='number of total epochs to run') 30 | parser.add_argument('-b', '--batch-size', default=12, type=int, 31 | help='mini-batch size (default: 12)') 32 | 33 | # optimizer 34 | parser.add_argument('--optimizer', default='Adam', type=str, 35 | help='optimizer type (default: Adam)') 36 | parser.add_argument('--lr', '--learning-rate', default=None, type=float, # 1e-3 37 | help='initial learning rate') 38 | parser.add_argument('--lr-decay', default=0.45, type=float, 39 | help='See --decay-steps') 40 | parser.add_argument('--decay-steps', default='', type=str, # "600,780,880,960,1040,1120,1200,1260" 41 | help='Comma-separated epoch numbers at 
which to multiply LR by --lr-decay') 42 | parser.add_argument('--momentum', default=None, type=float, # 0.9 43 | help='momentum') 44 | parser.add_argument('--nesterov', default=None, type=bool, # True 45 | help='use Nesterov momentum?') 46 | parser.add_argument('--weight-decay', '--wd', default=None, type=float, # 2e-4 47 | help='weight decay') 48 | 49 | parser.add_argument('--run-name', default=time.ctime(time.time())[4:-8], type=str, 50 | help='path to latest checkpoint (default: none)') 51 | parser.add_argument('--resume', action='store_true', 52 | help='resume from latest checkpoint') 53 | parser.add_argument('--start-epoch', default=0, type=int, 54 | help='Manual epoch number (useful on restarts)') 55 | 56 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 57 | help='Do not train, only evaluate model on the validation set') 58 | 59 | best_classIoU = 0 60 | 61 | 62 | def main(): 63 | global args, best_classIoU 64 | args = parser.parse_args() 65 | 66 | random.seed(666) 67 | torch.manual_seed(666) 68 | 69 | from datasets import Cityscapes 70 | 71 | train_dataset = Cityscapes( 72 | args.data, split='train', size=(1024, 512), augmented=True) 73 | 74 | train_loader = torch.utils.data.DataLoader( 75 | train_dataset, batch_size=args.batch_size, shuffle=True, 76 | num_workers=args.workers, pin_memory=True) 77 | train_loader.iter = iter(train_loader) 78 | 79 | val_dataset = Cityscapes( 80 | args.data, split='val', size=(1024, 512), augmented=False) 81 | 82 | val_loader = torch.utils.data.DataLoader( 83 | val_dataset, batch_size=args.batch_size, shuffle=False, 84 | num_workers=args.workers, pin_memory=True) 85 | 86 | # create model and append `log_softmax` to it, to split this part of Criterion between GPUs 87 | print('Architecture:', args.arch) 88 | class ModelWithLogSoftmax(globals()[args.arch]): 89 | def forward(self, x): 90 | heatmap_raw = super().forward(x) 91 | return torch.nn.functional.log_softmax(heatmap_raw, dim=1) 92 | 93 | if 'Box' in args.arch: 94 | model = ModelWithLogSoftmax( 95 | n_classes=19, max_input_h=512, max_input_w=1024) 96 | else: 97 | model = ModelWithLogSoftmax(n_classes=19) 98 | model.cuda() 99 | 100 | # define loss function (criterion) and optimizer 101 | criterion = nn.NLLLoss(weight=train_dataset.class_weights).cuda() 102 | 103 | optimizer_kwargs = {hyperparam: getattr(args, hyperparam) \ 104 | for hyperparam in ('lr', 'weight_decay', 'momentum', 'nesterov') \ 105 | if getattr(args, hyperparam) is not None} 106 | optimizer = torch.optim.__dict__[args.optimizer](model.parameters(), **optimizer_kwargs) 107 | 108 | # optionally resume from a checkpoint 109 | if args.resume: 110 | checkpoint_path = os.path.join('runs', args.run_name, 'model.pth') 111 | if os.path.isfile(checkpoint_path): 112 | print("=> loading checkpoint '{}'".format(checkpoint_path)) 113 | checkpoint = torch.load(checkpoint_path) 114 | args.start_epoch = checkpoint['epoch'] 115 | best_classIoU = checkpoint['best_classIoU'] 116 | 117 | model.load_state_dict(checkpoint['state_dict']) 118 | optimizer.load_state_dict(checkpoint['optimizer']) 119 | print("=> loaded checkpoint '{}' (epoch {})" 120 | .format(checkpoint_path, checkpoint['epoch'])) 121 | else: 122 | raise FileNotFoundError("=> no checkpoint found at '{}'".format(checkpoint_path)) 123 | 124 | torch.backends.cudnn.benchmark = True 125 | model = torch.nn.DataParallel(model).cuda() 126 | 127 | board_writer = tensorboardX.SummaryWriter(os.path.join('runs', args.run_name)) 128 | 129 | if args.evaluate: 130 | 
torch.backends.cudnn.deterministic = True 131 | val_loader.iter = iter(val_loader) 132 | validate(val_loader, model, criterion) 133 | return 134 | 135 | # warm up to save GPU memory 136 | #sample_input = torch.zeros( 137 | # (12, 3, 512, 1024), device='cuda', dtype=torch.float32, requires_grad=True) 138 | #(model(sample_input).sum() * 0).backward() 139 | 140 | for epoch in range(args.start_epoch, args.epochs): 141 | print('Epoch', epoch+1) 142 | 143 | val_loader.iter = iter(val_loader) 144 | # train for one epoch 145 | train_metrics = train(train_loader, model, criterion, optimizer, epoch, board_writer) 146 | 147 | train_loader.iter = iter(train_loader) 148 | # evaluate on validation set 149 | val_metrics = validate(val_loader, model, criterion) 150 | 151 | # record metrics to tensorboard 152 | for tra,val,name in zip(train_metrics, val_metrics, ('Class IoU', 'Category IoU', 'Loss')): 153 | board_writer.add_scalars(name, {'Train': tra, 'Test': val}, epoch+1) 154 | 155 | # remember best score and save checkpoint 156 | classIoU = val_metrics[0] 157 | is_best = classIoU > best_classIoU 158 | best_classIoU = max(classIoU, best_classIoU) 159 | save_checkpoint({ 160 | 'epoch': epoch + 1, 161 | 'arch': args.arch, 162 | 'state_dict': model.module.state_dict(), 163 | 'best_classIoU': best_classIoU, 164 | 'optimizer' : optimizer.state_dict(), 165 | }, is_best, os.path.join('runs', args.run_name, 'model.pth')) 166 | 167 | 168 | def train(train_loader, model, criterion, optimizer, epoch, board_writer): 169 | batch_time = AverageMeter() 170 | data_time = AverageMeter() 171 | loss_meter = AverageMeter() 172 | 173 | n_classes = criterion.weight.numel()-1 174 | confusion_matrix = torch.zeros((n_classes, n_classes), device='cuda', dtype=torch.long) 175 | 176 | # switch to train mode 177 | model.train() 178 | 179 | total_batches = len(train_loader) 180 | 181 | end = time.time() 182 | for i, (input, target) in enumerate(train_loader.iter): 183 | if i > total_batches: break 184 | 185 | # measure data loading time 186 | data_time.update(time.time() - end) 187 | input = input.cuda(non_blocking=True) 188 | target = target.cuda(non_blocking=True) 189 | # compute output 190 | output = model(input) 191 | loss = criterion(output, target) 192 | # compute gradient and do SGD step 193 | global_iteration = epoch * total_batches + i 194 | adjust_learning_rate(optimizer, epoch, i, total_batches) 195 | optimizer.zero_grad() 196 | loss.backward() 197 | optimizer.step() 198 | 199 | # measure elapsed time 200 | batch_time.update(time.time() - end) 201 | 202 | with torch.no_grad(): 203 | if i % 100 == 0: 204 | """import numpy as np 205 | image = input[0].cpu().numpy().transpose((1,2,0)).copy() 206 | image -= image.min() 207 | image /= image.max() 208 | image *= 255. 
209 | image = image.astype(np.uint8) 210 | output_map = output[0].max(0)[1].cpu().numpy() 211 | import imgaug 212 | segmap_display = imgaug.SegmentationMapOnImage(output_map, image.shape[:2], 20).draw_on_image(image) 213 | segmap_display = segmap_display.transpose((2,0,1)).copy() 214 | 215 | board_writer.add_image('Example segmentation', segmap_display, global_iteration)""" 216 | 217 | # update confusion matrix to compute IoU 218 | output = output.max(1)[1].view(-1).to(confusion_matrix) 219 | target = target.view(-1).to(output) 220 | 221 | confusion_matrix_update = \ 222 | torch.bincount(target*(n_classes+1) + output, minlength=(n_classes+1)**2) 223 | confusion_matrix += \ 224 | confusion_matrix_update.view(n_classes+1, n_classes+1)[1:,1:].to(confusion_matrix) 225 | 226 | loss_meter.update(loss.item(), input.size(0)) 227 | 228 | board_writer.add_scalar('Learning rate', 229 | next(iter(optimizer.param_groups))['lr'], global_iteration) 230 | board_writer.add_scalars('Time', 231 | {'Total batch time': batch_time.val, 232 | 'Data loading time': data_time.val}, global_iteration) 233 | board_writer.add_scalar('Online batch loss', loss_meter.val, global_iteration) 234 | 235 | end = time.time() 236 | 237 | classIoU, categIoU = compute_IoU(confusion_matrix.cpu()) 238 | return classIoU, categIoU, loss_meter.avg 239 | 240 | 241 | def validate(val_loader, model, criterion): 242 | batch_time = AverageMeter() 243 | loss_meter = AverageMeter() 244 | 245 | # switch to evaluate mode 246 | model.eval() 247 | 248 | n_classes = criterion.weight.numel()-1 249 | confusion_matrix = torch.zeros((n_classes, n_classes), device='cuda', dtype=torch.long) 250 | 251 | with torch.no_grad(): 252 | end = time.time() 253 | for i, (input, target) in enumerate(val_loader.iter): 254 | input = input.cuda(non_blocking=True) 255 | target = target.cuda(non_blocking=True) 256 | 257 | # compute output 258 | output = model(input) 259 | loss = criterion(output, target) 260 | 261 | # record loss 262 | loss_meter.update(loss.item(), input.size(0)) 263 | 264 | # update confusion matrix to compute IoU 265 | output = output.max(1)[1].view(-1) 266 | target = target.view(-1).to(output) 267 | 268 | confusion_matrix_update = \ 269 | torch.bincount(target*(n_classes+1) + output, minlength=(n_classes+1)**2) 270 | confusion_matrix += \ 271 | confusion_matrix_update.view(n_classes+1, n_classes+1)[1:,1:].to(confusion_matrix) 272 | 273 | # measure elapsed time 274 | batch_time.update(time.time() - end) 275 | end = time.time() 276 | 277 | classIoU, categoryIoU = compute_IoU(confusion_matrix.cpu()) 278 | 279 | print('Class IoU:', classIoU) 280 | print('Category IoU:', categoryIoU) 281 | return classIoU, categoryIoU, loss_meter.avg 282 | 283 | 284 | def save_checkpoint(state, is_best, filename='checkpoint.pth'): 285 | assert filename.endswith('.pth') 286 | torch.save(state, filename) 287 | if is_best: 288 | shutil.copyfile(filename, filename[:-4] + '_best.pth') 289 | 290 | 291 | class AverageMeter(object): 292 | """Computes and stores the average and current value""" 293 | def __init__(self): 294 | self.reset() 295 | 296 | def reset(self): 297 | self.val = 0 298 | self.avg = 0 299 | self.sum = 0 300 | self.count = 0 301 | 302 | def update(self, val, n=1): 303 | self.val = val 304 | self.sum += val * n 305 | self.count += n 306 | self.avg = self.sum / self.count 307 | 308 | 309 | def adjust_learning_rate(optimizer, epoch, epoch_iteration=None, iters_per_epoch=None): 310 | step_epochs = torch.tensor(list(map(int, args.decay_steps.split(',')))) 311 | lr 
= args.lr * (args.lr_decay ** (epoch > step_epochs).sum().item()) 312 | 313 | for param_group in optimizer.param_groups: 314 | param_group['lr'] = lr 315 | 316 | 317 | def compute_IoU(confusion_matrix): 318 | n_classes = confusion_matrix.shape[0] 319 | 320 | with torch.no_grad(): 321 | class_categories = [ 322 | [0, 1], # flat 323 | [2, 3, 4], # construction 324 | [5, 6, 7], # object 325 | [8, 9], # nature 326 | [10], # sky 327 | [11, 12], # human 328 | [13, 14, 15, 16, 17, 18], # vehicle 329 | ] 330 | 331 | classIoU = torch.empty(n_classes, dtype=torch.float32) 332 | for class_idx in range(n_classes): 333 | TP = confusion_matrix[class_idx, class_idx].item() 334 | FN = confusion_matrix[class_idx, :].sum().item() - TP 335 | FP = confusion_matrix[:, class_idx].sum().item() - TP 336 | 337 | classIoU[class_idx] = TP / max(TP + FP + FN, 1) 338 | 339 | categoryIoU = torch.empty(len(class_categories), dtype=torch.float32) 340 | for category_idx, category in enumerate(class_categories): 341 | TP = 0 342 | for class_idx in category: 343 | TP += confusion_matrix[class_idx, category].sum().item() 344 | 345 | FN, FP = -TP, -TP 346 | for class_idx in category: 347 | FN += confusion_matrix[class_idx, :].sum().item() 348 | FP += confusion_matrix[:, class_idx].sum().item() 349 | 350 | categoryIoU[category_idx] = TP / max(TP + FP + FN, 1) 351 | 352 | return classIoU.mean(), categoryIoU.mean() 353 | 354 | 355 | if __name__ == '__main__': 356 | main() 357 | -------------------------------------------------------------------------------- /examples/mnist.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script trains a very simple box convnet on MNIST. 3 | If OpenCV `videoio` is available, also outputs an animation 4 | of boxes' evolution to 'mnist-boxes.avi'. 5 | Based on https://github.com/pytorch/examples/blob/master/mnist/main.py 6 | """ 7 | from __future__ import print_function 8 | import argparse 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | from torchvision import datasets, transforms 14 | 15 | from box_convolution import BoxConv2d 16 | 17 | class Net(nn.Module): 18 | def __init__(self): 19 | super(Net, self).__init__() 20 | self.conv1 = BoxConv2d(1, 40, 28, 28) 21 | self.conv1_1x1 = nn.Conv2d(40, 40, 1, 1) 22 | 23 | self.fc1 = nn.Linear(7*7*40, 10) 24 | 25 | def forward(self, x): 26 | # The following line computes responses to 40 "generalized Haar filters" 27 | x = self.conv1_1x1(self.conv1(x)) 28 | x = F.relu(F.max_pool2d(x, 4)) 29 | 30 | x = self.fc1(x.view(-1, 7*7*40)) 31 | return F.log_softmax(x, dim=1) 32 | 33 | try: 34 | import cv2 35 | box_video_resolution = (300, 300) 36 | box_video = cv2.VideoWriter( 37 | 'mnist-boxes.avi', cv2.VideoWriter_fourcc(*'MJPG'), 25, tuple(reversed(box_video_resolution))) 38 | box_video_frame_count = 0 39 | video_background = None # to be defined in `main()`, sorry for globals and messy code 40 | except ImportError: 41 | box_video = None 42 | print('Couldn\'t import OpenCV. 
Will not log boxes to a video file') 43 | 44 | 45 | def train(model, device, train_loader, optimizer, epoch): 46 | model.train() 47 | for batch_idx, (data, target) in enumerate(train_loader): 48 | 49 | # log boxes to a video file 50 | if box_video is not None: 51 | global box_video_frame_count 52 | 53 | # change video background 54 | if box_video_frame_count % 5 == 0: 55 | global video_background # defined at the top for beautiful box visualization 56 | sample_idx = torch.randint(len(train_loader.dataset), (1,)).item() 57 | sample_digit = train_loader.dataset[sample_idx][0] 58 | video_background = torch.nn.functional.pad(sample_digit, (14,14,14,14)) 59 | video_background = torch.nn.functional.interpolate( 60 | video_background.unsqueeze(0), size=box_video_resolution, mode='nearest')[0,0] 61 | video_background = video_background.unsqueeze(-1).repeat(1, 1, 3) 62 | video_background = video_background.mul(255).round().byte().numpy() 63 | 64 | # log boxes to the video file 65 | if batch_idx % 5 == 0: 66 | box_importances = model.conv1_1x1.weight.detach().float().abs().max(0)[0].squeeze() 67 | box_importances /= box_importances.max() 68 | boxes_plot = model.conv1.draw_boxes( 69 | resolution=box_video_resolution, weights=box_importances) 70 | box_video.write(cv2.addWeighted(boxes_plot, 1.0, video_background, 0.25, 0.0)) 71 | box_video_frame_count += 1 72 | 73 | data, target = data.to(device), target.to(device) 74 | optimizer.zero_grad() 75 | output = model(data) 76 | loss = F.nll_loss(output, target) 77 | loss.backward() 78 | optimizer.step() 79 | if batch_idx % 100 == 0: 80 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 81 | epoch, batch_idx * len(data), len(train_loader.dataset), 82 | 100. * batch_idx / len(train_loader), loss.item())) 83 | 84 | for g in optimizer.param_groups: 85 | g['lr'] *= 0.999 86 | 87 | def test(model, device, test_loader): 88 | model.eval() 89 | test_loss = 0 90 | correct = 0 91 | with torch.no_grad(): 92 | for data, target in test_loader: 93 | data, target = data.to(device), target.to(device) 94 | output = model(data) 95 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 96 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 97 | correct += pred.eq(target.view_as(pred)).sum().item() 98 | 99 | test_loss /= len(test_loader.dataset) 100 | 101 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 102 | test_loss, correct, len(test_loader.dataset), 103 | 100. 
* correct / len(test_loader.dataset))) 104 | 105 | def main(): 106 | # Training settings 107 | use_cuda = torch.cuda.is_available() 108 | batch_size = 64 109 | n_epochs = 10 110 | 111 | torch.manual_seed(666) 112 | 113 | device = torch.device('cuda' if use_cuda else 'cpu') 114 | 115 | mnist_train = datasets.MNIST( 116 | './', train=True, download=True, transform=transforms.ToTensor()) 117 | mnist_test = datasets.MNIST( 118 | './', train=False, transform=transforms.ToTensor()) 119 | 120 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 121 | train_loader = torch.utils.data.DataLoader( 122 | mnist_train, batch_size=batch_size, shuffle=True, **kwargs) 123 | test_loader = torch.utils.data.DataLoader( 124 | mnist_test, batch_size=batch_size, shuffle=True, **kwargs) 125 | 126 | model = Net().to(device) 127 | optimizer = optim.Adam(model.parameters(), lr=1e-3) 128 | 129 | for epoch in range(1, n_epochs + 1): 130 | train(model, device, train_loader, optimizer, epoch) 131 | test(model, device, test_loader) 132 | 133 | if __name__ == '__main__': 134 | main() 135 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # monkey-patch for parallel compilation 2 | # https://stackoverflow.com/questions/11013851/speeding-up-build-process-with-distutils 3 | def parallelCCompile( 4 | self, sources, output_dir=None, macros=None, include_dirs=None, debug=0, 5 | extra_preargs=None, extra_postargs=None, depends=None): 6 | 7 | # those lines are copied from distutils.ccompiler.CCompiler directly 8 | macros, objects, extra_postargs, pp_opts, build = \ 9 | self._setup_compile(output_dir, macros, include_dirs, sources, depends, extra_postargs) 10 | cc_args = self._get_cc_args(pp_opts, debug, extra_preargs) 11 | 12 | # parallel code 13 | from multiprocessing import cpu_count 14 | try: 15 | n_processes = cpu_count() # number of parallel compilations 16 | except NotImplementedError: 17 | print('multiprocessing.cpu_count() failed, building on 1 core') 18 | n_processes = 1 19 | 20 | def _single_compile(obj): 21 | try: src, ext = build[obj] 22 | except KeyError: return 23 | self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts) 24 | 25 | import multiprocessing.pool 26 | multiprocessing.pool.ThreadPool(n_processes).map(_single_compile, objects) 27 | 28 | return objects 29 | 30 | import distutils.ccompiler 31 | distutils.ccompiler.CCompiler.compile = parallelCCompile 32 | 33 | import torch 34 | 35 | build_cuda = torch.cuda.is_available() # TODO allow cross-compiling too 36 | 37 | source_root = 'src' 38 | source_files_cpp = [ 39 | 'integral_image_interface.cpp', 40 | 'integral_image.cpp', 41 | 'box_convolution_interface.cpp', 42 | 'box_convolution.cpp', 43 | 'bind.cpp' 44 | ] 45 | source_files_cuda = [ 46 | 'integral_image_cuda.cu', 47 | 'box_convolution_cuda_forward.cu', 48 | 'box_convolution_cuda_backward.cu', 49 | 'box_convolution_cuda_misc.cu' 50 | ] 51 | source_files_cuda_stubs = [ 52 | 'cuda_stubs.cpp' 53 | ] 54 | source_files = source_files_cpp + (source_files_cuda if build_cuda else source_files_cuda_stubs) 55 | 56 | from torch.utils.cpp_extension import CppExtension, CUDAExtension 57 | import os 58 | 59 | extra_compile_args = {'cxx': [], 'nvcc': []} 60 | if os.getenv('CC'): 61 | # temporary hack to allow choosing a different host compiler for NVCC too 62 | extra_compile_args['nvcc'] += ['-ccbin', os.getenv('CC')] 63 | 64 | cpp_cuda = (CUDAExtension if build_cuda else 
CppExtension)( 65 | name='box_convolution_cpp_cuda', 66 | sources=[os.path.join(source_root, file) for file in source_files], 67 | include_dirs=[source_root], 68 | extra_compile_args=extra_compile_args 69 | ) 70 | 71 | from setuptools import setup 72 | 73 | setup( 74 | name='box_convolution', 75 | packages=['box_convolution'], 76 | ext_modules=[cpp_cuda], 77 | cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension}, 78 | install_requires=['future', 'torch>=1.0.0a0'] 79 | ) 80 | -------------------------------------------------------------------------------- /src/bind.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | at::Tensor integral_image( 4 | at::Tensor input); 5 | 6 | at::Tensor box_convolution_forward( 7 | at::Tensor input_integrated, 8 | at::Tensor x_min, at::Tensor x_max, 9 | at::Tensor y_min, at::Tensor y_max, 10 | const bool normalize, const bool exact); 11 | 12 | std::vector box_convolution_backward( 13 | at::Tensor input_integrated, 14 | at::Tensor x_min, at::Tensor x_max, 15 | at::Tensor y_min, at::Tensor y_max, 16 | at::Tensor grad_output, at::Tensor output, 17 | const float reparametrization_h, const float reparametrization_w, 18 | const bool normalize, const bool exact, 19 | const bool input_needs_grad, 20 | const bool x_min_needs_grad, const bool x_max_needs_grad, 21 | const bool y_min_needs_grad, const bool y_max_needs_grad); 22 | 23 | void clip_parameters( 24 | at::Tensor x_min, at::Tensor x_max, 25 | at::Tensor y_min, at::Tensor y_max, 26 | const double reparametrization_h, const double reparametrization_w, 27 | const double max_input_h, const double max_input_w, const bool exact); 28 | 29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 30 | m.def("integral_image", &integral_image, "Integral image"); 31 | m.def("box_convolution_forward" , &box_convolution_forward , "Box convolution, forward" ); 32 | m.def("box_convolution_backward", &box_convolution_backward, "Box convolution, backward"); 33 | m.def("clip_parameters", &clip_parameters, "Box convolution, clip parameters"); 34 | } 35 | -------------------------------------------------------------------------------- /src/box_convolution.cpp: -------------------------------------------------------------------------------- 1 | /* CPU implementations of the functions that really operate actual tensor data 2 | * on a low level. Used by those in `box_convolution_interface.cpp`. 
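 *
 * Orientation note (summarizing the code in this file): each routine
 * dispatches on dtype via AT_DISPATCH_FLOATING_TYPES_AND_HALF and then
 * walks the (batch, input plane, filter, x, y) grid with plain nested loops.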
*/ 3 | 4 | #include 5 | 6 | using std::min; 7 | using std::max; 8 | 9 | #include "box_convolution.h" // for `enum class Parameter` 10 | 11 | namespace cpu { 12 | 13 | // Splits x_min, x_max, y_min, y_max into integer and fractional parts 14 | void splitParameters( 15 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max , 16 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 17 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac) { 18 | 19 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_min.scalar_type(), "cpu::splitParameters", ([&] { 20 | scalar_t minInt, maxInt; 21 | 22 | for (int i = 0; i < x_min.numel(); ++i) { 23 | minInt = std::ceil(x_min.data_ptr()[i]); 24 | xMinFrac.data_ptr()[i] = minInt - x_min.data_ptr()[i]; 25 | xMinInt.data_ptr()[i] = static_cast(minInt); 26 | 27 | minInt = std::ceil(y_min.data_ptr()[i]); 28 | yMinFrac.data_ptr()[i] = minInt - y_min.data_ptr()[i]; 29 | yMinInt.data_ptr()[i] = static_cast(minInt); 30 | 31 | maxInt = std::floor(x_max.data_ptr()[i]); 32 | xMaxFrac.data_ptr()[i] = x_max.data_ptr()[i] - maxInt; 33 | xMaxInt.data_ptr()[i] = static_cast(maxInt) + 1; 34 | 35 | maxInt = std::floor(y_max.data_ptr()[i]); 36 | yMaxFrac.data_ptr()[i] = y_max.data_ptr()[i] - maxInt; 37 | yMaxInt.data_ptr()[i] = static_cast(maxInt) + 1; 38 | } 39 | })); 40 | } 41 | 42 | // A special parameters' split for backward pass wrt input 43 | void splitParametersUpdateGradInput( 44 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max , 45 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 46 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac) { 47 | 48 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_min.scalar_type(), "cpu::splitParametersUpdateGradInput", ([&] { 49 | scalar_t minInt, maxInt; 50 | 51 | for (int i = 0; i < x_min.numel(); ++i) { 52 | minInt = std::ceil(-x_max.data_ptr()[i]); 53 | xMinFrac.data_ptr()[i] = minInt + x_max.data_ptr()[i]; 54 | xMinInt.data_ptr()[i] = static_cast(minInt); 55 | 56 | minInt = std::ceil(-y_max.data_ptr()[i]); 57 | yMinFrac.data_ptr()[i] = minInt + y_max.data_ptr()[i]; 58 | yMinInt.data_ptr()[i] = static_cast(minInt); 59 | 60 | maxInt = std::floor(-x_min.data_ptr()[i]) + 1; 61 | xMaxFrac.data_ptr()[i] = -x_min.data_ptr()[i] + 1 - maxInt; 62 | xMaxInt.data_ptr()[i] = static_cast(maxInt); 63 | 64 | maxInt = std::floor(-y_min.data_ptr()[i]) + 1; 65 | yMaxFrac.data_ptr()[i] = -y_min.data_ptr()[i] + 1 - maxInt; 66 | yMaxInt.data_ptr()[i] = static_cast(maxInt); 67 | } 68 | })); 69 | } 70 | 71 | // A special parameters' split for backward pass wrt x_min, x_max, y_min, y_max 72 | void splitParametersAccGradParameters( 73 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max , 74 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 75 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac) { 76 | 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_min.scalar_type(), "cpu::splitParametersAccGradParams", ([&] { 78 | scalar_t minInt, maxInt; 79 | 80 | for (int i = 0; i < x_min.numel(); ++i) { 81 | minInt = std::ceil(x_min.data_ptr()[i] - 1); 82 | xMinFrac.data_ptr()[i] = minInt - x_min.data_ptr()[i] + 1; 83 | xMinInt.data_ptr()[i] = static_cast(minInt); 84 | 85 | minInt = std::ceil(y_min.data_ptr()[i] - 1); 86 | yMinFrac.data_ptr()[i] = minInt - 
y_min.data_ptr()[i] + 1; 87 | yMinInt.data_ptr()[i] = static_cast(minInt); 88 | 89 | maxInt = std::floor(x_max.data_ptr()[i]); 90 | xMaxFrac.data_ptr()[i] = x_max.data_ptr()[i] - maxInt; 91 | xMaxInt.data_ptr()[i] = static_cast(maxInt); 92 | 93 | maxInt = std::floor(y_max.data_ptr()[i]); 94 | yMaxFrac.data_ptr()[i] = y_max.data_ptr()[i] - maxInt; 95 | yMaxInt.data_ptr()[i] = static_cast(maxInt); 96 | } 97 | })); 98 | } 99 | 100 | template 101 | void boxConvUpdateOutput( 102 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 103 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac, 104 | at::Tensor & area, at::Tensor & input_integrated, at::Tensor & output) { 105 | 106 | // was `const int`, but had to remove `const` to work around a bug in GCC 5 107 | int h = output.size(-2); 108 | int w = output.size(-1); 109 | 110 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(output.scalar_type(), "cpu::boxConvUpdateOutput", ([&] { 111 | auto xMinIntAcsr = xMinInt.accessor(); 112 | auto xMaxIntAcsr = xMaxInt.accessor(); 113 | auto yMinIntAcsr = yMinInt.accessor(); 114 | auto yMaxIntAcsr = yMaxInt.accessor(); 115 | 116 | auto xMinFracAcsr = xMinFrac.accessor(); 117 | auto xMaxFracAcsr = xMaxFrac.accessor(); 118 | auto yMinFracAcsr = yMinFrac.accessor(); 119 | auto yMaxFracAcsr = yMaxFrac.accessor(); 120 | 121 | auto areaAcsr = xMinFracAcsr; // because there's no default ctor :( 122 | // only initialize the accessor if `area` is defined (errors otherwise) 123 | if (normalize) { 124 | areaAcsr = area.accessor(); 125 | } 126 | 127 | scalar_t *outputData = output.data_ptr(); 128 | 129 | for (int batchIdx = 0; batchIdx < input_integrated.size(0); ++batchIdx) { 130 | for (int inPlaneIdx = 0; inPlaneIdx < input_integrated.size(1); ++inPlaneIdx) { 131 | auto inputIntPlane = input_integrated[batchIdx][inPlaneIdx]; 132 | auto inputIntAcsr = inputIntPlane.accessor(); 133 | 134 | for (int filterIdx = 0; filterIdx < xMinInt.size(1); ++filterIdx) { 135 | 136 | // TODO make a separate loop for each 2D array access? 137 | for (int x = 0; x < h; ++x) { 138 | for (int y = 0; y < w; ++y) { 139 | const int xMinCurr = xMinIntAcsr[inPlaneIdx][filterIdx]; 140 | const int xMaxCurr = xMaxIntAcsr[inPlaneIdx][filterIdx]; 141 | const int yMinCurr = yMinIntAcsr[inPlaneIdx][filterIdx]; 142 | const int yMaxCurr = yMaxIntAcsr[inPlaneIdx][filterIdx]; 143 | 144 | const scalar_t xMinCurrFrac = xMinFracAcsr[inPlaneIdx][filterIdx]; 145 | const scalar_t xMaxCurrFrac = xMaxFracAcsr[inPlaneIdx][filterIdx]; 146 | const scalar_t yMinCurrFrac = yMinFracAcsr[inPlaneIdx][filterIdx]; 147 | const scalar_t yMaxCurrFrac = yMaxFracAcsr[inPlaneIdx][filterIdx]; 148 | 149 | // Must add 1 to xMax/yMax/xMin/yMin due to OpenCV's 150 | // `integral()` behavior. Namely, I(x,0) and I(0,y) are 151 | // always 0 (so it's a C-style array sum). 152 | 153 | // However, when computing sums, we subtract values at points 154 | // like y+yMin-1 and x+xMin-1, so we also SUBTRACT 1 from xMin 155 | // and yMin, and thus finally they are not affected. 
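// Restating the arithmetic below: with I the (h+1)x(w+1) integral image,
// the sum over the clipped box rows [t, b) and columns [l, r) is
//     I[b][r] - I[t][r] - I[b][l] + I[t][l];
// when `exact` is set, the *Frac-weighted terms then add back the fractional
// border rows/columns and the four corner pixels.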
156 | 157 | const int t = max(0, min(x+xMinCurr, h)); 158 | const int b = max(0, min(x+xMaxCurr, h)); 159 | const int l = max(0, min(y+yMinCurr, w)); 160 | const int r = max(0, min(y+yMaxCurr, w)); 161 | 162 | const int bAdv = max(0, min(x+xMaxCurr+1, h)); 163 | const int rAdv = max(0, min(y+yMaxCurr+1, w)); 164 | const int tAdv = max(0, min(x+xMinCurr-1, h)); 165 | const int lAdv = max(0, min(y+yMinCurr-1, w)); 166 | 167 | scalar_t outValue; 168 | 169 | // -- main area 170 | outValue = 171 | inputIntAcsr[b][r] 172 | - inputIntAcsr[t][r] 173 | - inputIntAcsr[b][l] 174 | + inputIntAcsr[t][l]; 175 | 176 | if (exact) { 177 | // -- xMax border 178 | outValue += 179 | ( inputIntAcsr[bAdv][r] 180 | - inputIntAcsr[b ][r] 181 | - inputIntAcsr[bAdv][l] 182 | + inputIntAcsr[b ][l]) * xMaxCurrFrac; 183 | 184 | // -- yMax border 185 | outValue += 186 | ( inputIntAcsr[b][rAdv] 187 | - inputIntAcsr[b][r ] 188 | - inputIntAcsr[t][rAdv] 189 | + inputIntAcsr[t][r ]) * yMaxCurrFrac; 190 | 191 | // -- xMin border 192 | outValue += 193 | ( inputIntAcsr[t ][r] 194 | - inputIntAcsr[tAdv][r] 195 | - inputIntAcsr[t ][l] 196 | + inputIntAcsr[tAdv][l]) * xMinCurrFrac; 197 | 198 | // -- yMin border 199 | outValue += 200 | ( inputIntAcsr[b][l ] 201 | - inputIntAcsr[b][lAdv] 202 | - inputIntAcsr[t][l ] 203 | + inputIntAcsr[t][lAdv]) * yMinCurrFrac; 204 | 205 | // -- corner pixels 206 | // Note: before, I used plain `input` to access corner values 207 | // with lower memory overhead. Moved to `input_integrated` 208 | // to get rid of an extra input to this function. 209 | 210 | if (not ((x+xMaxCurr >= h) | (y+yMaxCurr >= w) | 211 | (x+xMaxCurr < 0) | (y+yMaxCurr < 0))) { 212 | outValue += 213 | xMaxCurrFrac * yMaxCurrFrac * 214 | ( inputIntAcsr[b+1][r+1] 215 | - inputIntAcsr[b ][r+1] 216 | - inputIntAcsr[b+1][r ] 217 | + inputIntAcsr[b ][r ]); 218 | } 219 | 220 | if (not ((x+xMinCurr > h) | (y+yMaxCurr >= w) | 221 | (x+xMinCurr <= 0) | (y+yMaxCurr < 0))) { 222 | outValue += 223 | xMinCurrFrac * yMaxCurrFrac * 224 | ( inputIntAcsr[t ][r+1] 225 | - inputIntAcsr[t-1][r+1] 226 | - inputIntAcsr[t ][r ] 227 | + inputIntAcsr[t-1][r ]); 228 | } 229 | 230 | if (not ((x+xMaxCurr >= h) | (y+yMinCurr > w) | 231 | (x+xMaxCurr < 0) | (y+yMinCurr <= 0))) { 232 | outValue += 233 | xMaxCurrFrac * yMinCurrFrac * 234 | ( inputIntAcsr[b+1][l ] 235 | - inputIntAcsr[b ][l ] 236 | - inputIntAcsr[b+1][l-1] 237 | + inputIntAcsr[b ][l-1]); 238 | } 239 | 240 | if (not ((x+xMinCurr > h) | (y+yMinCurr > w) | 241 | (x+xMinCurr <= 0) | (y+yMinCurr <= 0))) { 242 | outValue += 243 | xMinCurrFrac * yMinCurrFrac * 244 | ( inputIntAcsr[t ][l ] 245 | - inputIntAcsr[t-1][l ] 246 | - inputIntAcsr[t ][l-1] 247 | + inputIntAcsr[t-1][l-1]); 248 | } 249 | } 250 | 251 | *(outputData++) = outValue * 252 | (normalize ? 
253 | areaAcsr[inPlaneIdx][filterIdx] : 254 | static_cast(1)); 255 | } 256 | } 257 | } // filterIdx 258 | } // inPlaneIdx 259 | } // batchIdx 260 | })); 261 | } 262 | 263 | // explicitly instantiate 264 | template void boxConvUpdateOutput( 265 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 266 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 267 | at::Tensor &, at::Tensor &, at::Tensor &); 268 | 269 | template void boxConvUpdateOutput( 270 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 271 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 272 | at::Tensor &, at::Tensor &, at::Tensor &); 273 | 274 | template void boxConvUpdateOutput( 275 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 276 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 277 | at::Tensor &, at::Tensor &, at::Tensor &); 278 | 279 | template void boxConvUpdateOutput( 280 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 281 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 282 | at::Tensor &, at::Tensor &, at::Tensor &); 283 | 284 | // `grad_output_integrated` size: {batchSize, nInputPlanes, numFilters, h+1, w+1} 285 | // `tmpArray` size: {batchSize, nInputPlanes, numFilters, h, w} 286 | template 287 | void boxConvUpdateGradInput( 288 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 289 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac, 290 | at::Tensor & area, at::Tensor & grad_output_integrated, at::Tensor & tmpArray) { 291 | 292 | // was `const int`, but had to remove `const` to work around a bug in GCC 5 293 | int h = tmpArray.size(-2); 294 | int w = tmpArray.size(-1); 295 | 296 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tmpArray.scalar_type(), "cpu::boxConvUpdateGradInput", ([&] { 297 | 298 | auto xMinIntAcsr = xMinInt.accessor(); 299 | auto xMaxIntAcsr = xMaxInt.accessor(); 300 | auto yMinIntAcsr = yMinInt.accessor(); 301 | auto yMaxIntAcsr = yMaxInt.accessor(); 302 | 303 | auto xMinFracAcsr = xMinFrac.accessor(); 304 | auto xMaxFracAcsr = xMaxFrac.accessor(); 305 | auto yMinFracAcsr = yMinFrac.accessor(); 306 | auto yMaxFracAcsr = yMaxFrac.accessor(); 307 | 308 | auto areaAcsr = xMinFracAcsr; // because there's no default ctor :( 309 | // only initialize the accessor if `area` is defined (errors otherwise) 310 | if (normalize) { 311 | areaAcsr = area.accessor(); 312 | } 313 | 314 | scalar_t *tmpArrayData = tmpArray.data_ptr(); 315 | 316 | for (int batchIdx = 0; batchIdx < grad_output_integrated.size(0); ++batchIdx) { 317 | for (int inPlaneIdx = 0; inPlaneIdx < grad_output_integrated.size(1); ++inPlaneIdx) { 318 | for (int filterIdx = 0; filterIdx < xMinInt.size(1); ++filterIdx) { 319 | 320 | const int xMinCurr = xMinIntAcsr[inPlaneIdx][filterIdx]; 321 | const int xMaxCurr = xMaxIntAcsr[inPlaneIdx][filterIdx]; 322 | const int yMinCurr = yMinIntAcsr[inPlaneIdx][filterIdx]; 323 | const int yMaxCurr = yMaxIntAcsr[inPlaneIdx][filterIdx]; 324 | 325 | const scalar_t xMinCurrFrac = xMinFracAcsr[inPlaneIdx][filterIdx]; 326 | const scalar_t xMaxCurrFrac = xMaxFracAcsr[inPlaneIdx][filterIdx]; 327 | const scalar_t yMinCurrFrac = yMinFracAcsr[inPlaneIdx][filterIdx]; 328 | const scalar_t yMaxCurrFrac = yMaxFracAcsr[inPlaneIdx][filterIdx]; 329 | 330 | auto gradOutputIntPlane = 331 | grad_output_integrated[batchIdx][inPlaneIdx][filterIdx]; 332 | auto gradOutputAcsr = gradOutputIntPlane.accessor(); 333 | 334 | for (int x = 0; x < h; ++x) { 335 | for (int y = 0; y < w; ++y) { 336 | 
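// Clip the shifted box corners to the valid integral-image index range
// [0, h] x [0, w]; t/b/l/r index the top/bottom/left/right box edges
// for the output location (x, y).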
337 | const int t = max(0, min(x+xMinCurr, h)); 338 | const int b = max(0, min(x+xMaxCurr, h)); 339 | const int l = max(0, min(y+yMinCurr, w)); 340 | const int r = max(0, min(y+yMaxCurr, w)); 341 | 342 | const int tAdv = x+xMinCurr-1 < h ? max(0, min(t-1, h)) : t; 343 | const int bAdv = x+xMaxCurr >= 0 ? max(0, min(b+1, h)) : b; 344 | const int lAdv = y+yMinCurr-1 < w ? max(0, min(l-1, w)) : l; 345 | const int rAdv = y+yMaxCurr >= 0 ? max(0, min(r+1, w)) : r; 346 | 347 | scalar_t outValue; 348 | 349 | outValue = 350 | gradOutputAcsr[b][r] 351 | - gradOutputAcsr[t][r] 352 | - gradOutputAcsr[b][l] 353 | + gradOutputAcsr[t][l]; 354 | 355 | if (exact) { 356 | // -- xMax border 357 | outValue += 358 | ( gradOutputAcsr[bAdv][r] 359 | - gradOutputAcsr[b ][r] 360 | - gradOutputAcsr[bAdv][l] 361 | + gradOutputAcsr[b ][l] 362 | ) * xMaxCurrFrac; 363 | 364 | // -- yMax border 365 | outValue += 366 | ( gradOutputAcsr[b][rAdv] 367 | - gradOutputAcsr[b][r ] 368 | - gradOutputAcsr[t][rAdv] 369 | + gradOutputAcsr[t][r ] 370 | ) * yMaxCurrFrac; 371 | 372 | // -- xMin border 373 | outValue += 374 | ( gradOutputAcsr[t ][r] 375 | - gradOutputAcsr[tAdv][r] 376 | - gradOutputAcsr[t ][l] 377 | + gradOutputAcsr[tAdv][l] 378 | ) * xMinCurrFrac; 379 | 380 | // -- yMin border 381 | outValue += 382 | ( gradOutputAcsr[b][l ] 383 | - gradOutputAcsr[b][lAdv] 384 | - gradOutputAcsr[t][l ] 385 | + gradOutputAcsr[t][lAdv] 386 | ) * yMinCurrFrac; 387 | 388 | // -- corner pixels 389 | outValue += 390 | xMaxCurrFrac*yMaxCurrFrac * ( 391 | (x+xMaxCurr >= h or 392 | y+yMaxCurr >= w or 393 | x+xMaxCurr < 0 or 394 | y+yMaxCurr < 0 or 395 | b == bAdv or 396 | r == rAdv) ? static_cast(0) : 397 | 398 | ( gradOutputAcsr[b+1][r+1] 399 | - gradOutputAcsr[b ][r+1] 400 | - gradOutputAcsr[b+1][r ] 401 | + gradOutputAcsr[b ][r ])); 402 | 403 | outValue += 404 | xMinCurrFrac*yMaxCurrFrac * ( 405 | (x+xMinCurr > h or 406 | y+yMaxCurr >= w or 407 | x+xMinCurr <= 0 or 408 | y+yMaxCurr < 0 or 409 | t == tAdv or 410 | r == rAdv) ? static_cast(0) : 411 | 412 | ( gradOutputAcsr[tAdv+1][r+1] 413 | - gradOutputAcsr[tAdv+1][r ] 414 | - gradOutputAcsr[tAdv ][r+1] 415 | + gradOutputAcsr[tAdv ][r ])); 416 | 417 | outValue += 418 | xMaxCurrFrac*yMinCurrFrac * ( 419 | (x+xMaxCurr >= h or 420 | y+yMinCurr > w or 421 | x+xMaxCurr < 0 or 422 | y+yMinCurr <= 0 or 423 | b == bAdv or 424 | l == lAdv) ? static_cast(0) : 425 | 426 | ( gradOutputAcsr[b+1][lAdv+1] 427 | - gradOutputAcsr[b ][lAdv+1] 428 | - gradOutputAcsr[b+1][lAdv ] 429 | + gradOutputAcsr[b ][lAdv ])); 430 | 431 | outValue += 432 | xMinCurrFrac*yMinCurrFrac * ( 433 | (x+xMinCurr > h or 434 | y+yMinCurr > w or 435 | x+xMinCurr <= 0 or 436 | y+yMinCurr <= 0 or 437 | t == tAdv or 438 | l == lAdv) ? static_cast(0) : 439 | 440 | ( gradOutputAcsr[tAdv+1][lAdv+1] 441 | - gradOutputAcsr[tAdv+1][lAdv ] 442 | - gradOutputAcsr[tAdv ][lAdv+1] 443 | + gradOutputAcsr[tAdv ][lAdv ])); 444 | } 445 | 446 | *(tmpArrayData++) = outValue * 447 | (normalize ? 
448 | areaAcsr[inPlaneIdx][filterIdx] : 449 | static_cast(1)); 450 | } 451 | } 452 | } // filterIdx 453 | } // inPlaneIdx 454 | } // batchIdx 455 | })); 456 | } 457 | 458 | // explicitly instantiate 459 | template void boxConvUpdateGradInput( 460 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 461 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 462 | at::Tensor &, at::Tensor &, at::Tensor &); 463 | 464 | template void boxConvUpdateGradInput( 465 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 466 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 467 | at::Tensor &, at::Tensor &, at::Tensor &); 468 | 469 | template void boxConvUpdateGradInput( 470 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 471 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 472 | at::Tensor &, at::Tensor &, at::Tensor &); 473 | 474 | template void boxConvUpdateGradInput( 475 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 476 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 477 | at::Tensor &, at::Tensor &, at::Tensor &); 478 | 479 | 480 | template 481 | void boxConvAccGradParameters( 482 | // tmpArray size: {batchSize, nInputPlanes, numFilters, h, w} 483 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 484 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac, 485 | at::Tensor & input_integrated, at::Tensor & tmpArray, Parameter parameter) { 486 | 487 | // was `const int`, but had to remove `const` to work around a bug in GCC 5 488 | int h = tmpArray.size(-2); 489 | int w = tmpArray.size(-1); 490 | 491 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tmpArray.scalar_type(), "cpu::boxConvAccGradParameters", ([&] { 492 | 493 | auto xMinIntAcsr = xMinInt.accessor(); 494 | auto xMaxIntAcsr = xMaxInt.accessor(); 495 | auto yMinIntAcsr = yMinInt.accessor(); 496 | auto yMaxIntAcsr = yMaxInt.accessor(); 497 | 498 | auto xMinFracAcsr = xMinFrac.accessor(); 499 | auto xMaxFracAcsr = xMaxFrac.accessor(); 500 | auto yMinFracAcsr = yMinFrac.accessor(); 501 | auto yMaxFracAcsr = yMaxFrac.accessor(); 502 | 503 | scalar_t *tmpArrayData = tmpArray.data_ptr(); 504 | 505 | for (int batchIdx = 0; batchIdx < input_integrated.size(0); ++batchIdx) { 506 | for (int inPlaneIdx = 0; inPlaneIdx < input_integrated.size(1); ++inPlaneIdx) { 507 | 508 | auto inputIntPlane = 509 | input_integrated[batchIdx][inPlaneIdx]; 510 | auto inputIntAcsr = inputIntPlane.accessor(); 511 | 512 | for (int filterIdx = 0; filterIdx < xMinInt.size(1); ++filterIdx) { 513 | 514 | const int xMinInt = xMinIntAcsr[inPlaneIdx][filterIdx]; 515 | const int xMaxInt = xMaxIntAcsr[inPlaneIdx][filterIdx]; 516 | const int yMinInt = yMinIntAcsr[inPlaneIdx][filterIdx]; 517 | const int yMaxInt = yMaxIntAcsr[inPlaneIdx][filterIdx]; 518 | 519 | const scalar_t xMinFrac = xMinFracAcsr[inPlaneIdx][filterIdx]; 520 | const scalar_t xMaxFrac = xMaxFracAcsr[inPlaneIdx][filterIdx]; 521 | const scalar_t yMinFrac = yMinFracAcsr[inPlaneIdx][filterIdx]; 522 | const scalar_t yMaxFrac = yMaxFracAcsr[inPlaneIdx][filterIdx]; 523 | 524 | for (int x = 1; x <= h; ++x) { 525 | for (int y = 1; y <= w; ++y) { 526 | 527 | if (parameter == Parameter::xMin) { 528 | // Had to move 3 following lines into the 529 | // `if` to ensure loop unswitching 530 | int valid; 531 | int cornerX, cornerY; 532 | 533 | scalar_t delta = 0; 534 | 535 | if (exact) { 536 | // TODO maybe use `input` instead of `inputInt` 537 | valid = 538 | not (y+yMinInt < 1) & not (y+yMinInt > w) 
& not (x+xMinInt < 1); 539 | cornerX = max(0,min(h-1,x+xMinInt-1)); 540 | cornerY = max(0,min(w-1,y+yMinInt-1)); 541 | const scalar_t tlCorner = valid * 542 | ( inputIntAcsr[cornerX+1][cornerY+1] 543 | - inputIntAcsr[cornerX ][cornerY+1] 544 | - inputIntAcsr[cornerX+1][cornerY ] 545 | + inputIntAcsr[cornerX ][cornerY ]); 546 | 547 | valid = 548 | not (y+yMaxInt < 0) & not (y+yMaxInt >= w) & not (x+xMinInt < 1); 549 | cornerX = max(0,min(h-1,x+xMinInt-1)); 550 | cornerY = max(0,min(w-1,y+yMaxInt )); 551 | const scalar_t trCorner = valid * 552 | ( inputIntAcsr[cornerX+1][cornerY+1] 553 | - inputIntAcsr[cornerX ][cornerY+1] 554 | - inputIntAcsr[cornerX+1][cornerY ] 555 | + inputIntAcsr[cornerX ][cornerY ]); 556 | 557 | delta += trCorner * yMaxFrac; 558 | delta += tlCorner * yMinFrac; 559 | } // if (exact) 560 | 561 | delta += inputIntAcsr 562 | [max(0,min(x+xMinInt , h))][max(0,min(y+yMaxInt , w))]; 563 | delta -= inputIntAcsr 564 | [max(0,min(x+xMinInt-1, h))][max(0,min(y+yMaxInt , w))]; 565 | delta -= inputIntAcsr 566 | [max(0,min(x+xMinInt , h))][max(0,min(y+yMinInt , w))]; 567 | delta += inputIntAcsr 568 | [max(0,min(x+xMinInt-1, h))][max(0,min(y+yMinInt , w))]; 569 | 570 | delta *= (x+xMinInt >= 1) & (x+xMinInt <= h); 571 | 572 | *(tmpArrayData++) = -delta; 573 | } 574 | 575 | else if (parameter == Parameter::xMax) { 576 | int valid; 577 | int cornerX, cornerY; 578 | 579 | scalar_t delta = 0; 580 | 581 | if (exact) { 582 | valid = 583 | not (y+yMinInt < 1) & not (y+yMinInt > w) & not (x+xMaxInt >= h); 584 | cornerX = max(0,min(h-1,x+xMaxInt )); 585 | cornerY = max(0,min(w-1,y+yMinInt-1)); 586 | const scalar_t blCorner = valid * 587 | ( inputIntAcsr[cornerX+1][cornerY+1] 588 | - inputIntAcsr[cornerX ][cornerY+1] 589 | - inputIntAcsr[cornerX+1][cornerY ] 590 | + inputIntAcsr[cornerX ][cornerY ]); 591 | 592 | valid = 593 | not (y+yMaxInt < 0) & not (y+yMaxInt >= w) & not (x+xMaxInt >= h); 594 | cornerX = max(0,min(h-1,x+xMaxInt )); 595 | cornerY = max(0,min(w-1,y+yMaxInt )); 596 | const scalar_t brCorner = valid * 597 | ( inputIntAcsr[cornerX+1][cornerY+1] 598 | - inputIntAcsr[cornerX ][cornerY+1] 599 | - inputIntAcsr[cornerX+1][cornerY ] 600 | + inputIntAcsr[cornerX ][cornerY ]); 601 | 602 | delta += brCorner * yMaxFrac; 603 | delta += blCorner * yMinFrac; 604 | } // if (exact) 605 | 606 | delta += inputIntAcsr 607 | [max(0,min(x+xMaxInt+1, h))][max(0,min(y+yMaxInt , w))]; 608 | delta -= inputIntAcsr 609 | [max(0,min(x+xMaxInt , h))][max(0,min(y+yMaxInt , w))]; 610 | delta -= inputIntAcsr 611 | [max(0,min(x+xMaxInt+1, h))][max(0,min(y+yMinInt , w))]; 612 | delta += inputIntAcsr 613 | [max(0,min(x+xMaxInt , h))][max(0,min(y+yMinInt , w))]; 614 | 615 | delta *= (x+xMaxInt >= 0) & (x+xMaxInt < h); 616 | 617 | *(tmpArrayData++) = delta; 618 | } 619 | 620 | else if (parameter == Parameter::yMin) { 621 | int valid; 622 | int cornerX, cornerY; 623 | 624 | scalar_t delta = 0; 625 | 626 | if (exact) { 627 | valid = 628 | not (y+yMinInt < 1) & not (x+xMinInt < 1) & not (x+xMinInt > h); 629 | cornerX = max(0,min(h-1,x+xMinInt-1)); 630 | cornerY = max(0,min(w-1,y+yMinInt-1)); 631 | const scalar_t tlCorner = valid * 632 | ( inputIntAcsr[cornerX+1][cornerY+1] 633 | - inputIntAcsr[cornerX ][cornerY+1] 634 | - inputIntAcsr[cornerX+1][cornerY ] 635 | + inputIntAcsr[cornerX ][cornerY ]); 636 | 637 | valid = 638 | not (y+yMinInt < 1) & not (x+xMaxInt < 0) & not (x+xMaxInt >= h); 639 | cornerX = max(0,min(h-1,x+xMaxInt )); 640 | cornerY = max(0,min(w-1,y+yMinInt-1)); 641 | const scalar_t blCorner = valid * 
642 | ( inputIntAcsr[cornerX+1][cornerY+1] 643 | - inputIntAcsr[cornerX ][cornerY+1] 644 | - inputIntAcsr[cornerX+1][cornerY ] 645 | + inputIntAcsr[cornerX ][cornerY ]); 646 | 647 | delta += tlCorner * xMinFrac; 648 | delta += blCorner * xMaxFrac; 649 | } // if (exact) 650 | 651 | delta += inputIntAcsr 652 | [max(0,min(x+xMaxInt , h))][max(0,min(y+yMinInt , w))]; 653 | delta -= inputIntAcsr 654 | [max(0,min(x+xMaxInt , h))][max(0,min(y+yMinInt-1, w))]; 655 | delta -= inputIntAcsr 656 | [max(0,min(x+xMinInt , h))][max(0,min(y+yMinInt , w))]; 657 | delta += inputIntAcsr 658 | [max(0,min(x+xMinInt , h))][max(0,min(y+yMinInt-1, w))]; 659 | 660 | delta *= (y+yMinInt >= 1) & (y+yMinInt <= w); 661 | 662 | *(tmpArrayData++) = -delta; 663 | } 664 | 665 | else if (parameter == Parameter::yMax) { 666 | int valid; 667 | int cornerX, cornerY; 668 | 669 | scalar_t delta = 0; 670 | 671 | if (exact) { 672 | valid = 673 | not (y+yMaxInt >= w) & not (x+xMinInt < 1) & not (x+xMinInt > h); 674 | cornerX = max(0,min(h-1,x+xMinInt-1)); 675 | cornerY = max(0,min(w-1,y+yMaxInt )); 676 | const scalar_t trCorner = valid * 677 | ( inputIntAcsr[cornerX+1][cornerY+1] 678 | - inputIntAcsr[cornerX ][cornerY+1] 679 | - inputIntAcsr[cornerX+1][cornerY ] 680 | + inputIntAcsr[cornerX ][cornerY ]); 681 | 682 | valid = 683 | not (y+yMaxInt >= w) & not (x+xMaxInt < 0) & not (x+xMaxInt >= h); 684 | cornerX = max(0,min(h-1,x+xMaxInt )); 685 | cornerY = max(0,min(w-1,y+yMaxInt )); 686 | const scalar_t brCorner = valid * 687 | ( inputIntAcsr[cornerX+1][cornerY+1] 688 | - inputIntAcsr[cornerX ][cornerY+1] 689 | - inputIntAcsr[cornerX+1][cornerY ] 690 | + inputIntAcsr[cornerX ][cornerY ]); 691 | 692 | delta += trCorner * xMinFrac; 693 | delta += brCorner * xMaxFrac; 694 | } // if (exact) 695 | 696 | delta += inputIntAcsr 697 | [max(0,min(x+xMaxInt , h))][max(0,min(y+yMaxInt+1, w))]; 698 | delta -= inputIntAcsr 699 | [max(0,min(x+xMaxInt , h))][max(0,min(y+yMaxInt , w))]; 700 | delta -= inputIntAcsr 701 | [max(0,min(x+xMinInt , h))][max(0,min(y+yMaxInt+1, w))]; 702 | delta += inputIntAcsr 703 | [max(0,min(x+xMinInt , h))][max(0,min(y+yMaxInt , w))]; 704 | 705 | delta *= (y+yMaxInt >= 0) & (y+yMaxInt < w); 706 | 707 | *(tmpArrayData++) = delta; 708 | } 709 | } 710 | } 711 | } // filterIdx 712 | } // inPlaneIdx 713 | } // batchIdx 714 | })); 715 | } 716 | 717 | // explicitly instantiate 718 | template void boxConvAccGradParameters( 719 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 720 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 721 | at::Tensor &, at::Tensor &, Parameter); 722 | 723 | template void boxConvAccGradParameters( 724 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 725 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 726 | at::Tensor &, at::Tensor &, Parameter); 727 | 728 | 729 | void clipParameters( 730 | at::Tensor & paramMin, at::Tensor & paramMax, 731 | const double reparametrization, const double minSize, const double maxSize) { 732 | 733 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(paramMin.scalar_type(), "cpu::clipParameters", ([&] { 734 | 735 | scalar_t *paramMinPtr = paramMin.data_ptr(); 736 | scalar_t *paramMaxPtr = paramMax.data_ptr(); 737 | 738 | const double inverseReparam = 1.0 / reparametrization; 739 | 740 | for (int idx = 0; idx < paramMin.numel(); ++idx) { 741 | 742 | double minValue, maxValue; 743 | const double paramMinCurrent = static_cast(paramMinPtr[idx]); 744 | const double paramMaxCurrent = static_cast(paramMaxPtr[idx]); 745 | 746 | // clamp parameters 747 | 
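// keep each coordinate inside [-(maxSize+1), maxSize-1] pixels, expressed
// in reparametrized units (hence the 1/reparametrization factor below)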
minValue = max(-(maxSize+1) * inverseReparam, 748 | min((maxSize-1) * inverseReparam, paramMinCurrent)); 749 | maxValue = max(-(maxSize+1) * inverseReparam, 750 | min((maxSize-1) * inverseReparam, paramMaxCurrent)); 751 | 752 | // make sure bottom/right border doesn't come before top/left 753 | if (minValue + (minSize - 0.9999) * inverseReparam > maxValue) { 754 | const scalar_t mean = 0.5 * (minValue + maxValue); 755 | minValue = mean - 0.5 * (minSize - 0.9999) * inverseReparam; 756 | maxValue = mean + 0.5 * (minSize - 0.9999) * inverseReparam; 757 | } 758 | 759 | paramMinPtr[idx] = static_cast(minValue); 760 | paramMaxPtr[idx] = static_cast(maxValue); 761 | } 762 | })); 763 | } 764 | 765 | at::Tensor computeArea( 766 | at::Tensor x_min, at::Tensor x_max, at::Tensor y_min, at::Tensor y_max, 767 | const bool exact, const bool needXDeriv, const bool needYDeriv) { 768 | 769 | // TODO: how to stop tracking operations??? `.is_variable_(false)` doesn't work 770 | auto retval = at::ones_like(x_min); 771 | 772 | if (not exact) { 773 | x_min = x_min.ceil(); 774 | y_min = y_min.ceil(); 775 | x_max = x_max.floor(); 776 | y_max = y_max.floor(); 777 | } 778 | 779 | if (needXDeriv) { 780 | auto xArea = x_max - x_min; 781 | xArea += 1; 782 | retval *= xArea; 783 | } 784 | 785 | if (needYDeriv) { 786 | auto yArea = y_max - y_min; 787 | yArea += 1; 788 | retval *= yArea; 789 | } 790 | 791 | retval.reciprocal_(); // inverse areas 792 | return retval; 793 | } 794 | 795 | } // namespace cpu 796 | -------------------------------------------------------------------------------- /src/box_convolution.h: -------------------------------------------------------------------------------- 1 | #include // && -> and, || -> or etc. 2 | 3 | enum class Parameter {xMin, xMax, yMin, yMax}; 4 | 5 | namespace cpu { 6 | 7 | void splitParameters( 8 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max , 9 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 10 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac); 11 | 12 | void splitParametersUpdateGradInput( 13 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max , 14 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 15 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac); 16 | 17 | void splitParametersAccGradParameters( 18 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max , 19 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 20 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac); 21 | 22 | template 23 | void boxConvUpdateOutput( 24 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 25 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac, 26 | at::Tensor & area, at::Tensor & input_integrated, at::Tensor & output); 27 | 28 | template 29 | void boxConvUpdateGradInput( 30 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 31 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac, 32 | at::Tensor & area, at::Tensor & grad_output_integrated, at::Tensor & tmpArray); 33 | 34 | template 35 | void boxConvAccGradParameters( 36 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor 
& yMinInt , at::Tensor & yMaxInt , 37 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac, 38 | at::Tensor & input_integrated, at::Tensor & tmpArray, Parameter parameter); 39 | 40 | void clipParameters( 41 | at::Tensor & paramMin, at::Tensor & paramMax, 42 | const double reparametrization, const double minSize, const double maxSize); 43 | 44 | at::Tensor computeArea( 45 | at::Tensor x_min, at::Tensor x_max, at::Tensor y_min, at::Tensor y_max, 46 | const bool exact, const bool needXDeriv = true, const bool needYDeriv = true); 47 | 48 | } 49 | 50 | namespace gpu { 51 | 52 | void splitParameters( 53 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max , 54 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 55 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac); 56 | 57 | void splitParametersUpdateGradInput( 58 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max , 59 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 60 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac); 61 | 62 | void splitParametersAccGradParameters( 63 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max , 64 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 65 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac); 66 | 67 | template 68 | void boxConvUpdateOutput( 69 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 70 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac, 71 | at::Tensor & area, at::Tensor & input_integrated, at::Tensor & output); 72 | 73 | template 74 | void boxConvUpdateGradInput( 75 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 76 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac, 77 | at::Tensor & area, at::Tensor & grad_output_integrated, at::Tensor & tmpArray); 78 | 79 | template 80 | void boxConvAccGradParameters( 81 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 82 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac, 83 | at::Tensor & input_integrated, at::Tensor & tmpArray, Parameter parameter); 84 | 85 | void clipParameters( 86 | at::Tensor & paramMin, at::Tensor & paramMax, 87 | const double reparametrization, const double minSize, const double maxSize); 88 | 89 | at::Tensor computeArea( 90 | at::Tensor x_min, at::Tensor x_max, at::Tensor y_min, at::Tensor y_max, 91 | const bool exact, const bool needXDeriv = true, const bool needYDeriv = true); 92 | 93 | } -------------------------------------------------------------------------------- /src/box_convolution_cuda_backward.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include "box_convolution.h" // for `enum class Parameter` 7 | 8 | #define BLOCK_SIZE 256 9 | #define NUM_THREADS 1024 10 | 11 | using std::min; 12 | using std::max; 13 | 14 | namespace gpu { 15 | 16 | template 17 | using CudaAcsr = const at::PackedTensorAccessor32; 18 | 19 | // TODO switch to square blocks 20 | template 21 | __global__ void 
boxConvUpdateGradInputKernel( 22 | CudaAcsr gradOutputInt, scalar_t * __restrict__ tmpArray, 23 | const int32_t * __restrict__ xMinInt , const int32_t * __restrict__ xMaxInt , 24 | const int32_t * __restrict__ yMinInt , const int32_t * __restrict__ yMaxInt , 25 | const scalar_t * __restrict__ xMinFrac, const scalar_t * __restrict__ xMaxFrac, 26 | const scalar_t * __restrict__ yMinFrac, const scalar_t * __restrict__ yMaxFrac, 27 | const scalar_t * __restrict__ area, const int nParams) { 28 | 29 | int32_t id = NUM_THREADS * blockIdx.x + threadIdx.x; 30 | tmpArray += id; 31 | 32 | const int32_t h = gradOutputInt.size(1) - 1; 33 | const int32_t w = gradOutputInt.size(2) - 1; 34 | const int32_t y = id % w; id /= w; 35 | const int32_t x = id % h; id /= h; 36 | const int32_t paramIdx = id % nParams; 37 | 38 | // `id` is now the current plane number 39 | auto gradOutputIntPlane = gradOutputInt[id]; 40 | 41 | if (id < gradOutputInt.size(0)) { 42 | 43 | const int32_t xMinCurr = xMinInt[paramIdx]; 44 | const int32_t xMaxCurr = xMaxInt[paramIdx]; 45 | const int32_t yMinCurr = yMinInt[paramIdx]; 46 | const int32_t yMaxCurr = yMaxInt[paramIdx]; 47 | 48 | const int t = max(0, min(x+xMinCurr, h)); 49 | const int b = max(0, min(x+xMaxCurr, h)); 50 | const int l = max(0, min(y+yMinCurr, w)); 51 | const int r = max(0, min(y+yMaxCurr, w)); 52 | 53 | scalar_t outValue; 54 | 55 | outValue = 56 | gradOutputIntPlane[b][r] 57 | - gradOutputIntPlane[t][r] 58 | - gradOutputIntPlane[b][l] 59 | + gradOutputIntPlane[t][l]; 60 | 61 | if (exact) { 62 | const scalar_t xMinCurrFrac = xMinFrac[paramIdx]; 63 | const scalar_t xMaxCurrFrac = xMaxFrac[paramIdx]; 64 | const scalar_t yMinCurrFrac = yMinFrac[paramIdx]; 65 | const scalar_t yMaxCurrFrac = yMaxFrac[paramIdx]; 66 | 67 | const int tAdv = x+xMinCurr-1 < h ? max(0, min(t-1, h)) : t; 68 | const int bAdv = x+xMaxCurr >= 0 ? max(0, min(b+1, h)) : b; 69 | const int lAdv = y+yMinCurr-1 < w ? max(0, min(l-1, w)) : l; 70 | const int rAdv = y+yMaxCurr >= 0 ? max(0, min(r+1, w)) : r; 71 | 72 | // -- xMax border 73 | outValue += 74 | ( gradOutputIntPlane[bAdv][r] 75 | - gradOutputIntPlane[b ][r] 76 | - gradOutputIntPlane[bAdv][l] 77 | + gradOutputIntPlane[b ][l] 78 | ) * xMaxCurrFrac; 79 | 80 | // -- yMax border 81 | outValue += 82 | ( gradOutputIntPlane[b][rAdv] 83 | - gradOutputIntPlane[b][r ] 84 | - gradOutputIntPlane[t][rAdv] 85 | + gradOutputIntPlane[t][r ] 86 | ) * yMaxCurrFrac; 87 | 88 | // -- xMin border 89 | outValue += 90 | ( gradOutputIntPlane[t ][r] 91 | - gradOutputIntPlane[tAdv][r] 92 | - gradOutputIntPlane[t ][l] 93 | + gradOutputIntPlane[tAdv][l] 94 | ) * xMinCurrFrac; 95 | 96 | // -- yMin border 97 | outValue += 98 | ( gradOutputIntPlane[b][l ] 99 | - gradOutputIntPlane[b][lAdv] 100 | - gradOutputIntPlane[t][l ] 101 | + gradOutputIntPlane[t][lAdv] 102 | ) * yMinCurrFrac; 103 | 104 | // -- corner pixels 105 | outValue += 106 | xMaxCurrFrac*yMaxCurrFrac * ( 107 | (x+xMaxCurr >= h or 108 | y+yMaxCurr >= w or 109 | x+xMaxCurr < 0 or 110 | y+yMaxCurr < 0 or 111 | b == bAdv or 112 | r == rAdv) ? static_cast(0) : 113 | 114 | ( gradOutputIntPlane[b+1][r+1] 115 | - gradOutputIntPlane[b ][r+1] 116 | - gradOutputIntPlane[b+1][r ] 117 | + gradOutputIntPlane[b ][r ])); 118 | 119 | outValue += 120 | xMinCurrFrac*yMaxCurrFrac * ( 121 | (x+xMinCurr > h or 122 | y+yMaxCurr >= w or 123 | x+xMinCurr <= 0 or 124 | y+yMaxCurr < 0 or 125 | t == tAdv or 126 | r == rAdv) ? 
static_cast(0) : 127 | 128 | ( gradOutputIntPlane[tAdv+1][r+1] 129 | - gradOutputIntPlane[tAdv+1][r ] 130 | - gradOutputIntPlane[tAdv ][r+1] 131 | + gradOutputIntPlane[tAdv ][r ])); 132 | 133 | outValue += 134 | xMaxCurrFrac*yMinCurrFrac * ( 135 | (x+xMaxCurr >= h or 136 | y+yMinCurr > w or 137 | x+xMaxCurr < 0 or 138 | y+yMinCurr <= 0 or 139 | b == bAdv or 140 | l == lAdv) ? static_cast(0) : 141 | 142 | ( gradOutputIntPlane[b+1][lAdv+1] 143 | - gradOutputIntPlane[b ][lAdv+1] 144 | - gradOutputIntPlane[b+1][lAdv ] 145 | + gradOutputIntPlane[b ][lAdv ])); 146 | 147 | outValue += 148 | xMinCurrFrac*yMinCurrFrac * ( 149 | (x+xMinCurr > h or 150 | y+yMinCurr > w or 151 | x+xMinCurr <= 0 or 152 | y+yMinCurr <= 0 or 153 | t == tAdv or 154 | l == lAdv) ? static_cast(0) : 155 | 156 | ( gradOutputIntPlane[tAdv+1][lAdv+1] 157 | - gradOutputIntPlane[tAdv+1][lAdv ] 158 | - gradOutputIntPlane[tAdv ][lAdv+1] 159 | + gradOutputIntPlane[tAdv ][lAdv ])); 160 | } 161 | 162 | *tmpArray = outValue * (normalize ? area[paramIdx] : static_cast(1)); 163 | } 164 | } 165 | 166 | template 167 | void boxConvUpdateGradInput( 168 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 169 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac, 170 | at::Tensor & area, at::Tensor & grad_output_integrated, at::Tensor & tmpArray) { 171 | 172 | // TODO use square blocks as in `boxConvUpdateOutput`? 173 | const int threadsNeeded = tmpArray.numel(); 174 | int numBlocks = (threadsNeeded + NUM_THREADS - 1) / NUM_THREADS; 175 | 176 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tmpArray.scalar_type(), "gpu::boxConvUpdateGradInput", ([&] { 177 | auto gradOutputIntFlattened = grad_output_integrated.view( 178 | {-1, grad_output_integrated.size(-2), grad_output_integrated.size(-1)}); 179 | auto gradOutputIntAcsr = 180 | gradOutputIntFlattened.packed_accessor32(); 181 | 182 | boxConvUpdateGradInputKernel 183 | <<>> ( 184 | gradOutputIntAcsr, tmpArray.data_ptr(), 185 | xMinInt.data_ptr(), xMaxInt.data_ptr(), 186 | yMinInt.data_ptr(), yMaxInt.data_ptr(), 187 | xMinFrac.data_ptr(), xMaxFrac.data_ptr(), 188 | yMinFrac.data_ptr(), yMaxFrac.data_ptr(), 189 | normalize ? 
area.data_ptr() : nullptr, xMinInt.numel()); 190 | THCudaCheck(cudaGetLastError()); 191 | })); 192 | } 193 | 194 | // explicitly instantiate 195 | template void boxConvUpdateGradInput( 196 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 197 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 198 | at::Tensor &, at::Tensor &, at::Tensor &); 199 | 200 | template void boxConvUpdateGradInput( 201 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 202 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 203 | at::Tensor &, at::Tensor &, at::Tensor &); 204 | 205 | template void boxConvUpdateGradInput( 206 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 207 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 208 | at::Tensor &, at::Tensor &, at::Tensor &); 209 | 210 | template void boxConvUpdateGradInput( 211 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 212 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 213 | at::Tensor &, at::Tensor &, at::Tensor &); 214 | 215 | 216 | // TODO overload for exact/truncated mode 217 | // TODO accept only three pairs of parameter arrays, not four (one is always redundant) 218 | template 219 | __global__ void boxConvAccGradParametersKernel( 220 | CudaAcsr inputInt, scalar_t * __restrict__ tmpArray, 221 | const int32_t * __restrict__ xMinInt , const int32_t * __restrict__ xMaxInt , 222 | const int32_t * __restrict__ yMinInt , const int32_t * __restrict__ yMaxInt , 223 | const scalar_t * __restrict__ xMinFrac, const scalar_t * __restrict__ xMaxFrac, 224 | const scalar_t * __restrict__ yMinFrac, const scalar_t * __restrict__ yMaxFrac, 225 | const int nParams) { 226 | 227 | int32_t id = NUM_THREADS * blockIdx.x + threadIdx.x; 228 | tmpArray += id; 229 | 230 | const int32_t h = inputInt.size(1) - 1; 231 | const int32_t w = inputInt.size(2) - 1; 232 | const int32_t y = id % w + 1; id /= w; 233 | const int32_t x = id % h + 1; id /= h; 234 | const int32_t paramIdx = id % nParams; id /= nParams; 235 | 236 | // `id` is now the current absolute input plane number 237 | auto inputIntPlane = inputInt[id]; 238 | 239 | if (id < inputInt.size(0)) { 240 | 241 | const int32_t xMinCurr = xMinInt[paramIdx]; 242 | const int32_t xMaxCurr = xMaxInt[paramIdx]; 243 | const int32_t yMinCurr = yMinInt[paramIdx]; 244 | const int32_t yMaxCurr = yMaxInt[paramIdx]; 245 | 246 | // TODO only define these if `exact == true` 247 | const scalar_t xMinCurrFrac = xMinFrac[paramIdx]; 248 | const scalar_t xMaxCurrFrac = xMaxFrac[paramIdx]; 249 | const scalar_t yMinCurrFrac = yMinFrac[paramIdx]; 250 | const scalar_t yMaxCurrFrac = yMaxFrac[paramIdx]; 251 | 252 | int valid; 253 | int cornerX, cornerY; 254 | 255 | scalar_t delta = 0; 256 | 257 | if (parameter == Parameter::xMin) { 258 | if (exact) { 259 | // TODO maybe use `input` instead of `inputInt` 260 | valid = 261 | not (y+yMinCurr < 1) & not (y+yMinCurr > w) & not (x+xMinCurr < 1); 262 | cornerX = max(0,min(h-1,x+xMinCurr-1)); 263 | cornerY = max(0,min(w-1,y+yMinCurr-1)); 264 | const scalar_t tlCorner = valid * 265 | ( inputIntPlane[cornerX+1][cornerY+1] 266 | - inputIntPlane[cornerX ][cornerY+1] 267 | - inputIntPlane[cornerX+1][cornerY ] 268 | + inputIntPlane[cornerX ][cornerY ]); 269 | 270 | valid = 271 | not (y+yMaxCurr < 0) & not (y+yMaxCurr >= w) & not (x+xMinCurr < 1); 272 | cornerX = max(0,min(h-1,x+xMinCurr -1)); 273 | cornerY = max(0,min(w-1,y+yMaxCurr )); 274 | const scalar_t trCorner = valid * 275 | ( inputIntPlane[cornerX+1][cornerY+1] 276 | - inputIntPlane[cornerX 
][cornerY+1] 277 | - inputIntPlane[cornerX+1][cornerY ] 278 | + inputIntPlane[cornerX ][cornerY ]); 279 | 280 | delta += trCorner * yMaxCurrFrac; 281 | delta += tlCorner * yMinCurrFrac; 282 | } // if (exact) 283 | 284 | delta += inputIntPlane 285 | [max(0,min(x+xMinCurr , h))][max(0,min(y+yMaxCurr , w))]; 286 | delta -= inputIntPlane 287 | [max(0,min(x+xMinCurr -1, h))][max(0,min(y+yMaxCurr , w))]; 288 | delta -= inputIntPlane 289 | [max(0,min(x+xMinCurr , h))][max(0,min(y+yMinCurr , w))]; 290 | delta += inputIntPlane 291 | [max(0,min(x+xMinCurr -1, h))][max(0,min(y+yMinCurr , w))]; 292 | 293 | delta *= (x+xMinCurr >= 1) & (x+xMinCurr <= h); 294 | 295 | *tmpArray = -delta; 296 | } 297 | 298 | else if (parameter == Parameter::xMax) { 299 | if (exact) { 300 | valid = 301 | not (y+yMinCurr < 1) & not (y+yMinCurr > w) & not (x+xMaxCurr >= h); 302 | cornerX = max(0,min(h-1,x+xMaxCurr )); 303 | cornerY = max(0,min(w-1,y+yMinCurr -1)); 304 | const scalar_t blCorner = valid * 305 | ( inputIntPlane[cornerX+1][cornerY+1] 306 | - inputIntPlane[cornerX ][cornerY+1] 307 | - inputIntPlane[cornerX+1][cornerY ] 308 | + inputIntPlane[cornerX ][cornerY ]); 309 | 310 | valid = 311 | not (y+yMaxCurr < 0) & not (y+yMaxCurr >= w) & not (x+xMaxCurr >= h); 312 | cornerX = max(0,min(h-1,x+xMaxCurr )); 313 | cornerY = max(0,min(w-1,y+yMaxCurr )); 314 | const scalar_t brCorner = valid * 315 | ( inputIntPlane[cornerX+1][cornerY+1] 316 | - inputIntPlane[cornerX ][cornerY+1] 317 | - inputIntPlane[cornerX+1][cornerY ] 318 | + inputIntPlane[cornerX ][cornerY ]); 319 | 320 | delta += brCorner * yMaxCurrFrac; 321 | delta += blCorner * yMinCurrFrac; 322 | } // if (exact) 323 | 324 | delta += inputIntPlane 325 | [max(0,min(x+xMaxCurr +1, h))][max(0,min(y+yMaxCurr , w))]; 326 | delta -= inputIntPlane 327 | [max(0,min(x+xMaxCurr , h))][max(0,min(y+yMaxCurr , w))]; 328 | delta -= inputIntPlane 329 | [max(0,min(x+xMaxCurr +1, h))][max(0,min(y+yMinCurr , w))]; 330 | delta += inputIntPlane 331 | [max(0,min(x+xMaxCurr , h))][max(0,min(y+yMinCurr , w))]; 332 | 333 | delta *= (x+xMaxCurr >= 0) & (x+xMaxCurr < h); 334 | 335 | *tmpArray = delta; 336 | } 337 | 338 | else if (parameter == Parameter::yMin) { 339 | if (exact) { 340 | valid = 341 | not (y+yMinCurr < 1) & not (x+xMinCurr < 1) & not (x+xMinCurr > h); 342 | cornerX = max(0,min(h-1,x+xMinCurr -1)); 343 | cornerY = max(0,min(w-1,y+yMinCurr -1)); 344 | const scalar_t tlCorner = valid * 345 | ( inputIntPlane[cornerX+1][cornerY+1] 346 | - inputIntPlane[cornerX ][cornerY+1] 347 | - inputIntPlane[cornerX+1][cornerY ] 348 | + inputIntPlane[cornerX ][cornerY ]); 349 | 350 | valid = 351 | not (y+yMinCurr < 1) & not (x+xMaxCurr < 0) & not (x+xMaxCurr >= h); 352 | cornerX = max(0,min(h-1,x+xMaxCurr )); 353 | cornerY = max(0,min(w-1,y+yMinCurr -1)); 354 | const scalar_t blCorner = valid * 355 | ( inputIntPlane[cornerX+1][cornerY+1] 356 | - inputIntPlane[cornerX ][cornerY+1] 357 | - inputIntPlane[cornerX+1][cornerY ] 358 | + inputIntPlane[cornerX ][cornerY ]); 359 | 360 | delta += tlCorner * xMinCurrFrac; 361 | delta += blCorner * xMaxCurrFrac; 362 | } // if (exact) 363 | 364 | delta += inputIntPlane 365 | [max(0,min(x+xMaxCurr , h))][max(0,min(y+yMinCurr , w))]; 366 | delta -= inputIntPlane 367 | [max(0,min(x+xMaxCurr , h))][max(0,min(y+yMinCurr -1, w))]; 368 | delta -= inputIntPlane 369 | [max(0,min(x+xMinCurr , h))][max(0,min(y+yMinCurr , w))]; 370 | delta += inputIntPlane 371 | [max(0,min(x+xMinCurr , h))][max(0,min(y+yMinCurr -1, w))]; 372 | 373 | delta *= (y+yMinCurr >= 1) & 
(y+yMinCurr <= w); 374 | 375 | *tmpArray = -delta; 376 | } 377 | 378 | else if (parameter == Parameter::yMax) { 379 | if (exact) { 380 | valid = 381 | not (y+yMaxCurr >= w) & not (x+xMinCurr < 1) & not (x+xMinCurr > h); 382 | cornerX = max(0,min(h-1,x+xMinCurr -1)); 383 | cornerY = max(0,min(w-1,y+yMaxCurr )); 384 | const scalar_t trCorner = valid * 385 | ( inputIntPlane[cornerX+1][cornerY+1] 386 | - inputIntPlane[cornerX ][cornerY+1] 387 | - inputIntPlane[cornerX+1][cornerY ] 388 | + inputIntPlane[cornerX ][cornerY ]); 389 | 390 | valid = 391 | not (y+yMaxCurr >= w) & not (x+xMaxCurr < 0) & not (x+xMaxCurr >= h); 392 | cornerX = max(0,min(h-1,x+xMaxCurr )); 393 | cornerY = max(0,min(w-1,y+yMaxCurr )); 394 | const scalar_t brCorner = valid * 395 | ( inputIntPlane[cornerX+1][cornerY+1] 396 | - inputIntPlane[cornerX ][cornerY+1] 397 | - inputIntPlane[cornerX+1][cornerY ] 398 | + inputIntPlane[cornerX ][cornerY ]); 399 | 400 | delta += trCorner * xMinCurrFrac; 401 | delta += brCorner * xMaxCurrFrac; 402 | } // if (exact) 403 | 404 | delta += inputIntPlane 405 | [max(0,min(x+xMaxCurr , h))][max(0,min(y+yMaxCurr +1, w))]; 406 | delta -= inputIntPlane 407 | [max(0,min(x+xMaxCurr , h))][max(0,min(y+yMaxCurr , w))]; 408 | delta -= inputIntPlane 409 | [max(0,min(x+xMinCurr , h))][max(0,min(y+yMaxCurr +1, w))]; 410 | delta += inputIntPlane 411 | [max(0,min(x+xMinCurr , h))][max(0,min(y+yMaxCurr , w))]; 412 | 413 | delta *= (y+yMaxCurr >= 0) & (y+yMaxCurr < w); 414 | 415 | *tmpArray = delta; 416 | } 417 | } 418 | } 419 | 420 | template 421 | void boxConvAccGradParameters( 422 | // tmpArray size: {batchSize, nInputPlanes, numFilters, h, w} 423 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 424 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac, 425 | at::Tensor & input_integrated, at::Tensor & tmpArray, Parameter parameter) { 426 | 427 | // TODO switch to square blocks? 
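        // Note on the launch configuration below (inferred from the kernel's flat
        // indexing): one CUDA thread is launched per element of `tmpArray`, whose
        // shape is {batchSize, nInputPlanes, numFilters, h, w}. `threadsNeeded` is
        // therefore `tmpArray.numel()`, and `numBlocks` is the ceiling division of
        // that count by NUM_THREADS (1024 threads per block, defined above).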
428 | const int threadsNeeded = tmpArray.numel(); 429 | int numBlocks = (threadsNeeded + NUM_THREADS - 1) / NUM_THREADS; 430 | 431 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tmpArray.scalar_type(), "gpu::boxConvAccGradParameters", ([&] { 432 | auto inputIntFlattened = input_integrated.view( 433 | {-1, input_integrated.size(-2), input_integrated.size(-1)}); 434 | auto inputIntAcsr = 435 | inputIntFlattened.packed_accessor32(); 436 | 437 | switch (parameter) { 438 | case Parameter::xMin: 439 | boxConvAccGradParametersKernel 440 | <<>> ( 441 | inputIntAcsr, tmpArray.data_ptr(), 442 | xMinInt.data_ptr(), xMaxInt.data_ptr(), 443 | yMinInt.data_ptr(), yMaxInt.data_ptr(), 444 | xMinFrac.data_ptr(), xMaxFrac.data_ptr(), 445 | yMinFrac.data_ptr(), yMaxFrac.data_ptr(), xMinInt.numel()); break; 446 | case Parameter::xMax: 447 | boxConvAccGradParametersKernel 448 | <<>> ( 449 | inputIntAcsr, tmpArray.data_ptr(), 450 | xMinInt.data_ptr(), xMaxInt.data_ptr(), 451 | yMinInt.data_ptr(), yMaxInt.data_ptr(), 452 | xMinFrac.data_ptr(), xMaxFrac.data_ptr(), 453 | yMinFrac.data_ptr(), yMaxFrac.data_ptr(), xMinInt.numel()); break; 454 | case Parameter::yMin: 455 | boxConvAccGradParametersKernel 456 | <<>> ( 457 | inputIntAcsr, tmpArray.data_ptr(), 458 | xMinInt.data_ptr(), xMaxInt.data_ptr(), 459 | yMinInt.data_ptr(), yMaxInt.data_ptr(), 460 | xMinFrac.data_ptr(), xMaxFrac.data_ptr(), 461 | yMinFrac.data_ptr(), yMaxFrac.data_ptr(), xMinInt.numel()); break; 462 | case Parameter::yMax: 463 | boxConvAccGradParametersKernel 464 | <<>> ( 465 | inputIntAcsr, tmpArray.data_ptr(), 466 | xMinInt.data_ptr(), xMaxInt.data_ptr(), 467 | yMinInt.data_ptr(), yMaxInt.data_ptr(), 468 | xMinFrac.data_ptr(), xMaxFrac.data_ptr(), 469 | yMinFrac.data_ptr(), yMaxFrac.data_ptr(), xMinInt.numel()); break; 470 | } 471 | THCudaCheck(cudaGetLastError()); 472 | })); 473 | } 474 | 475 | // explicitly instantiate 476 | template void boxConvAccGradParameters( 477 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 478 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 479 | at::Tensor &, at::Tensor &, Parameter); 480 | 481 | template void boxConvAccGradParameters( 482 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 483 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 484 | at::Tensor &, at::Tensor &, Parameter); 485 | 486 | } 487 | -------------------------------------------------------------------------------- /src/box_convolution_cuda_forward.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #define BLOCK_SIZE 256 7 | 8 | using std::min; 9 | using std::max; 10 | 11 | #include "box_convolution.h" // for `enum class Parameter` 12 | 13 | namespace gpu { 14 | 15 | // TODO use constant memory when possible 16 | // namespace constant { 17 | // __constant__ float xMinFrac[1536], xMaxFrac[1536]; 18 | // __constant__ float yMinFrac[1536], yMaxFrac[1536]; 19 | // __constant__ int xMinInt[1536], xMaxInt[1536]; 20 | // __constant__ int yMinInt[1536], yMaxInt[1536]; 21 | // __constant__ float area[1536]; 22 | // } 23 | 24 | template 25 | using CudaAcsr = const at::PackedTensorAccessor32; 26 | 27 | // overload for "truncated"/"rounded" mode 28 | template 29 | __global__ void boxConvUpdateOutputKernel( 30 | CudaAcsr inputInt, CudaAcsr output, 31 | const int32_t * __restrict__ xMinInt, const int32_t * __restrict__ xMaxInt, 32 | const int32_t * __restrict__ yMinInt, const int32_t * __restrict__ yMaxInt, 33 | const scalar_t * __restrict__ area) { 
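    // "Truncated"/"rounded" overload: only the integer parts of the box borders
    // (prepared by `splitParameters`) are used, so each output pixel costs just the
    // four integral-image lookups below. The fractional border and corner
    // corrections are applied only in the "exact" overload further down.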
34 | 35 | // `output` size: `batch_size x in_planes x num_filters x h x w` 36 | const int32_t y = blockDim.y * blockIdx.y + threadIdx.y; 37 | const int32_t x = blockDim.z * blockIdx.z + threadIdx.z; 38 | const int32_t inPlaneIdx = blockIdx.x / output.size(2); 39 | const int32_t paramIdx = blockIdx.x % (output.size(1) * output.size(2)); 40 | const int32_t h = output.size(3); 41 | const int32_t w = output.size(4); 42 | 43 | const auto inputIntPlane = inputInt[inPlaneIdx]; 44 | 45 | if (x < h and y < w) { 46 | // Must add 1 to xMax/yMax/xMin/yMin due to OpenCV's 47 | // `integral()` behavior. Namely, I(x,0) and I(0,y) are 48 | // always 0 (so it's a C-style array sum). 49 | 50 | // However, when computing sums, we subtract values at points 51 | // like y+yMin-1 and x+xMin-1, so we also SUBTRACT 1 from xMin 52 | // and yMin, and thus finally they are not affected. 53 | 54 | const int32_t t = max(0, min(x+xMinInt[paramIdx], h)); 55 | const int32_t b = max(0, min(x+xMaxInt[paramIdx], h)); 56 | const int32_t l = max(0, min(y+yMinInt[paramIdx], w)); 57 | const int32_t r = max(0, min(y+yMaxInt[paramIdx], w)); 58 | 59 | scalar_t outValue = 0; 60 | 61 | outValue += inputIntPlane[b][r]; 62 | outValue -= inputIntPlane[t][r]; 63 | outValue -= inputIntPlane[b][l]; 64 | outValue += inputIntPlane[t][l]; 65 | 66 | // TODO error: expression must be a modifiable lvalue 67 | output.data()[(blockIdx.x * h + x) * w + y] = 68 | outValue * (normalize ? area[paramIdx] : static_cast(1)); 69 | } 70 | } 71 | 72 | // overload for "exact" mode 73 | template 74 | __global__ void boxConvUpdateOutputKernel( 75 | CudaAcsr inputInt, CudaAcsr output, 76 | const int32_t * __restrict__ xMinInt, const int32_t * __restrict__ xMaxInt, 77 | const int32_t * __restrict__ yMinInt, const int32_t * __restrict__ yMaxInt, 78 | const scalar_t * __restrict__ xMinFrac, const scalar_t * __restrict__ xMaxFrac, 79 | const scalar_t * __restrict__ yMinFrac, const scalar_t * __restrict__ yMaxFrac, 80 | const scalar_t * __restrict__ area) { 81 | 82 | const int32_t y = blockDim.y * blockIdx.y + threadIdx.y; 83 | const int32_t x = blockDim.z * blockIdx.z + threadIdx.z; 84 | const int32_t inPlaneIdx = blockIdx.x / output.size(2); 85 | const int32_t paramIdx = blockIdx.x % (output.size(1) * output.size(2)); 86 | const int32_t h = output.size(3); 87 | const int32_t w = output.size(4); 88 | 89 | const auto inputIntPlane = inputInt[inPlaneIdx]; 90 | 91 | if (x < h and y < w) { 92 | // Must add 1 to xMax/yMax/xMin/yMin due to OpenCV's 93 | // `integral()` behavior. Namely, I(x,0) and I(0,y) are 94 | // always 0 (so it's a C-style array sum). 95 | 96 | // However, when computing sums, we subtract values at points 97 | // like y+yMin-1 and x+xMin-1, so we also SUBTRACT 1 from xMin 98 | // and yMin, and thus finally they are not affected. 
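        // Box sum via the summed-area table: for integer borders t <= b, l <= r,
        //   sum = I[b][r] - I[t][r] - I[b][l] + I[t][l].
        // The "exact" mode below additionally adds the partially covered one-pixel
        // strips along each border, weighted by the fractional parts of the box
        // coordinates, and the four corner pixels, weighted by products of those
        // fractions.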
99 | const int xMinCurr = xMinInt[paramIdx]; 100 | const int xMaxCurr = xMaxInt[paramIdx]; 101 | const int yMinCurr = yMinInt[paramIdx]; 102 | const int yMaxCurr = yMaxInt[paramIdx]; 103 | 104 | const scalar_t xMinCurrFrac = xMinFrac[paramIdx]; 105 | const scalar_t xMaxCurrFrac = xMaxFrac[paramIdx]; 106 | const scalar_t yMinCurrFrac = yMinFrac[paramIdx]; 107 | const scalar_t yMaxCurrFrac = yMaxFrac[paramIdx]; 108 | 109 | const int32_t t = max(0, min(x+xMinCurr, h)); 110 | const int32_t b = max(0, min(x+xMaxCurr, h)); 111 | const int32_t l = max(0, min(y+yMinCurr, w)); 112 | const int32_t r = max(0, min(y+yMaxCurr, w)); 113 | 114 | const int32_t bAdv = max(0, min(x+xMaxCurr+1, h)); 115 | const int32_t rAdv = max(0, min(y+yMaxCurr+1, w)); 116 | const int32_t tAdv = max(0, min(x+xMinCurr-1, h)); 117 | const int32_t lAdv = max(0, min(y+yMinCurr-1, w)); 118 | 119 | scalar_t outValue; 120 | 121 | // -- main area 122 | outValue = 123 | inputIntPlane[b][r] 124 | - inputIntPlane[t][r] 125 | - inputIntPlane[b][l] 126 | + inputIntPlane[t][l]; 127 | 128 | // -- xMax border 129 | outValue += 130 | ( inputIntPlane[bAdv][r] 131 | - inputIntPlane[b ][r] 132 | - inputIntPlane[bAdv][l] 133 | + inputIntPlane[b ][l]) * xMaxCurrFrac; 134 | 135 | // -- yMax border 136 | outValue += 137 | ( inputIntPlane[b][rAdv] 138 | - inputIntPlane[b][r ] 139 | - inputIntPlane[t][rAdv] 140 | + inputIntPlane[t][r ]) * yMaxCurrFrac; 141 | 142 | // -- xMin border 143 | outValue += 144 | ( inputIntPlane[t ][r] 145 | - inputIntPlane[tAdv][r] 146 | - inputIntPlane[t ][l] 147 | + inputIntPlane[tAdv][l]) * xMinCurrFrac; 148 | 149 | // -- yMin border 150 | outValue += 151 | ( inputIntPlane[b][l ] 152 | - inputIntPlane[b][lAdv] 153 | - inputIntPlane[t][l ] 154 | + inputIntPlane[t][lAdv]) * yMinCurrFrac; 155 | 156 | // -- corner pixels 157 | // Note: before, I used plain `input` to access corner values 158 | // with lower memory access overhead. Moved to `input_integrated` 159 | // to get rid of an extra input to this function. 160 | if (not ((x+xMaxCurr >= h) | (y+yMaxCurr >= w) | 161 | (x+xMaxCurr < 0) | (y+yMaxCurr < 0))) { 162 | outValue += 163 | xMaxCurrFrac * yMaxCurrFrac * 164 | ( inputIntPlane[b+1][r+1] 165 | - inputIntPlane[b ][r+1] 166 | - inputIntPlane[b+1][r ] 167 | + inputIntPlane[b ][r ]); 168 | } 169 | 170 | if (not ((x+xMinCurr > h) | (y+yMaxCurr >= w) | 171 | (x+xMinCurr <= 0) | (y+yMaxCurr < 0))) { 172 | outValue += 173 | xMinCurrFrac * yMaxCurrFrac * 174 | ( inputIntPlane[t ][r+1] 175 | - inputIntPlane[t-1][r+1] 176 | - inputIntPlane[t ][r ] 177 | + inputIntPlane[t-1][r ]); 178 | } 179 | 180 | if (not ((x+xMaxCurr >= h) | (y+yMinCurr > w) | 181 | (x+xMaxCurr < 0) | (y+yMinCurr <= 0))) { 182 | outValue += 183 | xMaxCurrFrac * yMinCurrFrac * 184 | ( inputIntPlane[b+1][l ] 185 | - inputIntPlane[b ][l ] 186 | - inputIntPlane[b+1][l-1] 187 | + inputIntPlane[b ][l-1]); 188 | } 189 | 190 | if (not ((x+xMinCurr > h) | (y+yMinCurr > w) | 191 | (x+xMinCurr <= 0) | (y+yMinCurr <= 0))) { 192 | outValue += 193 | xMinCurrFrac * yMinCurrFrac * 194 | ( inputIntPlane[t ][l ] 195 | - inputIntPlane[t-1][l ] 196 | - inputIntPlane[t ][l-1] 197 | + inputIntPlane[t-1][l-1]); 198 | } 199 | 200 | // TODO error: expression must be a modifiable lvalue 201 | output.data()[(blockIdx.x * h + x) * w + y] = 202 | outValue * (normalize ? 
area[paramIdx] : static_cast(1)); 203 | } 204 | } 205 | 206 | // TODO put split params and area into constant memory 207 | template 208 | void boxConvUpdateOutput( 209 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 210 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac, 211 | at::Tensor & area, at::Tensor & input_integrated, at::Tensor & output) { 212 | 213 | // was `const int`, but had to remove `const` to work around a bug in GCC 5 214 | int h = output.size(-2); 215 | int w = output.size(-1); 216 | const int totalOutputChannels = output.numel() / (h * w); 217 | 218 | const dim3 blockSize(1, 32, 32); 219 | const dim3 gridSize( 220 | (totalOutputChannels + blockSize.x - 1) / blockSize.x, 221 | (w + blockSize.y - 1) / blockSize.y, 222 | (h + blockSize.z - 1) / blockSize.z); 223 | 224 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(output.scalar_type(), "gpu::boxConvUpdateOutput", ([&] { 225 | 226 | auto inputIntFlattened = input_integrated.view({-1, h+1, w+1}); 227 | auto inputIntAcsr = 228 | inputIntFlattened.packed_accessor32(); 229 | 230 | auto outputAcsr = output.packed_accessor32(); 231 | 232 | if (exact) { 233 | boxConvUpdateOutputKernel 234 | <<>> ( 235 | inputIntAcsr, outputAcsr, 236 | xMinInt.data_ptr(), xMaxInt.data_ptr(), 237 | yMinInt.data_ptr(), yMaxInt.data_ptr(), 238 | xMinFrac.data_ptr(), xMaxFrac.data_ptr(), 239 | yMinFrac.data_ptr(), yMaxFrac.data_ptr(), 240 | normalize ? area.data_ptr() : nullptr); 241 | } else { 242 | boxConvUpdateOutputKernel 243 | <<>> ( 244 | inputIntAcsr, outputAcsr, 245 | xMinInt.data_ptr(), xMaxInt.data_ptr(), 246 | yMinInt.data_ptr(), yMaxInt.data_ptr(), 247 | normalize ? area.data_ptr() : nullptr); 248 | } 249 | THCudaCheck(cudaGetLastError()); 250 | })); 251 | } 252 | 253 | // explicitly instantiate 254 | template void boxConvUpdateOutput( 255 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 256 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 257 | at::Tensor &, at::Tensor &, at::Tensor &); 258 | 259 | template void boxConvUpdateOutput( 260 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 261 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 262 | at::Tensor &, at::Tensor &, at::Tensor &); 263 | 264 | template void boxConvUpdateOutput( 265 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 266 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 267 | at::Tensor &, at::Tensor &, at::Tensor &); 268 | 269 | template void boxConvUpdateOutput( 270 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 271 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, 272 | at::Tensor &, at::Tensor &, at::Tensor &); 273 | 274 | } 275 | -------------------------------------------------------------------------------- /src/box_convolution_cuda_misc.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include "box_convolution.h" // for `enum class Parameter` 7 | 8 | #define BLOCK_SIZE 256 9 | 10 | using std::min; 11 | using std::max; 12 | 13 | namespace gpu { 14 | 15 | // TODO make sure xMin and yMin threads don't fall into same warp 16 | template 17 | __global__ void splitParametersKernel( 18 | const scalar_t * __restrict__ xMin, const scalar_t * __restrict__ xMax, 19 | const scalar_t * __restrict__ yMin, const scalar_t * __restrict__ yMax, 20 | int32_t * __restrict__ xMinInt, int32_t * __restrict__ xMaxInt, 21 | int32_t * __restrict__ yMinInt, 
int32_t * __restrict__ yMaxInt, 22 | scalar_t * __restrict__ xMinFrac, scalar_t * __restrict__ xMaxFrac, 23 | scalar_t * __restrict__ yMinFrac, scalar_t * __restrict__ yMaxFrac, 24 | const int nParameters) { 25 | 26 | const int idx = BLOCK_SIZE * blockIdx.x + threadIdx.x; 27 | if (idx < 2 * nParameters) { 28 | const int paramIndex = idx < nParameters ? idx : idx - nParameters; 29 | 30 | const scalar_t *param; 31 | scalar_t *fracParam; 32 | int32_t *intParam; 33 | 34 | param = idx < nParameters ? xMin : yMin; 35 | fracParam = idx < nParameters ? xMinFrac : yMinFrac; 36 | intParam = idx < nParameters ? xMinInt : yMinInt; 37 | 38 | const scalar_t minInt = std::ceil(param[paramIndex]); 39 | fracParam[paramIndex] = minInt - param[paramIndex]; 40 | intParam[paramIndex] = static_cast(minInt); 41 | 42 | param = idx < nParameters ? xMax : yMax; 43 | fracParam = idx < nParameters ? xMaxFrac : yMaxFrac; 44 | intParam = idx < nParameters ? xMaxInt : yMaxInt; 45 | 46 | const scalar_t maxInt = std::floor(param[paramIndex]); 47 | fracParam[paramIndex] = param[paramIndex] - maxInt; 48 | intParam[paramIndex] = static_cast(maxInt) + 1; 49 | } 50 | } 51 | 52 | void splitParameters( 53 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max , 54 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 55 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac) { 56 | 57 | const int threadsNeeded = 2 * x_min.numel(); 58 | const int numBlocks = (threadsNeeded + BLOCK_SIZE - 1) / BLOCK_SIZE; 59 | 60 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_min.scalar_type(), "gpu::splitParameters", ([&] { 61 | splitParametersKernel 62 | <<>> ( 63 | x_min.data_ptr(), x_max.data_ptr(), 64 | y_min.data_ptr(), y_max.data_ptr(), 65 | xMinInt.data_ptr(), xMaxInt.data_ptr(), 66 | yMinInt.data_ptr(), yMaxInt.data_ptr(), 67 | xMinFrac.data_ptr(), xMaxFrac.data_ptr(), 68 | yMinFrac.data_ptr(), yMaxFrac.data_ptr(), 69 | x_min.numel()); 70 | THCudaCheck(cudaGetLastError()); 71 | })); 72 | } 73 | 74 | template 75 | __global__ void splitParametersUpdateGradInputKernel( 76 | const scalar_t * __restrict__ xMin, const scalar_t * __restrict__ xMax, 77 | const scalar_t * __restrict__ yMin, const scalar_t * __restrict__ yMax, 78 | int32_t * __restrict__ xMinInt, int32_t * __restrict__ xMaxInt, 79 | int32_t * __restrict__ yMinInt, int32_t * __restrict__ yMaxInt, 80 | scalar_t * __restrict__ xMinFrac, scalar_t * __restrict__ xMaxFrac, 81 | scalar_t * __restrict__ yMinFrac, scalar_t * __restrict__ yMaxFrac, 82 | const int nParameters) { 83 | 84 | const int idx = BLOCK_SIZE * blockIdx.x + threadIdx.x; 85 | if (idx < 2 * nParameters) { 86 | const int paramIndex = idx < nParameters ? idx : idx - nParameters; 87 | 88 | const scalar_t *param; 89 | scalar_t *fracParam; 90 | int32_t *intParam; 91 | 92 | param = idx < nParameters ? xMax : yMax; // note: min/max swapped 93 | fracParam = idx < nParameters ? xMinFrac : yMinFrac; 94 | intParam = idx < nParameters ? xMinInt : yMinInt; 95 | 96 | const scalar_t minInt = std::ceil(-param[paramIndex]); 97 | fracParam[paramIndex] = minInt + param[paramIndex]; 98 | intParam[paramIndex] = static_cast(minInt); 99 | 100 | param = idx < nParameters ? xMin : yMin; // note: min/max swapped 101 | fracParam = idx < nParameters ? xMaxFrac : yMaxFrac; 102 | intParam = idx < nParameters ? 
xMaxInt : yMaxInt; 103 | 104 | const scalar_t maxInt = std::floor(-param[paramIndex]) + 1; 105 | fracParam[paramIndex] = -param[paramIndex] + 1 - maxInt; 106 | intParam[paramIndex] = static_cast(maxInt); 107 | } 108 | } 109 | 110 | void splitParametersUpdateGradInput( 111 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max , 112 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 113 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac) { 114 | 115 | const int threadsNeeded = 2 * x_min.numel(); 116 | const int numBlocks = (threadsNeeded + BLOCK_SIZE - 1) / BLOCK_SIZE; 117 | 118 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_min.scalar_type(), "gpu::splitParametersUpdateGradInput", ([&] { 119 | splitParametersUpdateGradInputKernel 120 | <<>> ( 121 | x_min.data_ptr(), x_max.data_ptr(), 122 | y_min.data_ptr(), y_max.data_ptr(), 123 | xMinInt.data_ptr(), xMaxInt.data_ptr(), 124 | yMinInt.data_ptr(), yMaxInt.data_ptr(), 125 | xMinFrac.data_ptr(), xMaxFrac.data_ptr(), 126 | yMinFrac.data_ptr(), yMaxFrac.data_ptr(), 127 | x_min.numel()); 128 | THCudaCheck(cudaGetLastError()); 129 | })); 130 | } 131 | 132 | template 133 | __global__ void splitParametersAccGradParametersKernel( 134 | const scalar_t * __restrict__ xMin, const scalar_t * __restrict__ xMax, 135 | const scalar_t * __restrict__ yMin, const scalar_t * __restrict__ yMax, 136 | int32_t * __restrict__ xMinInt, int32_t * __restrict__ xMaxInt, 137 | int32_t * __restrict__ yMinInt, int32_t * __restrict__ yMaxInt, 138 | scalar_t * __restrict__ xMinFrac, scalar_t * __restrict__ xMaxFrac, 139 | scalar_t * __restrict__ yMinFrac, scalar_t * __restrict__ yMaxFrac, 140 | const int nParameters) { 141 | 142 | const int idx = BLOCK_SIZE * blockIdx.x + threadIdx.x; 143 | if (idx < 2 * nParameters) { 144 | const int paramIndex = idx < nParameters ? idx : idx - nParameters; 145 | 146 | const scalar_t *param; 147 | scalar_t *fracParam; 148 | int32_t *intParam; 149 | 150 | param = idx < nParameters ? xMin : yMin; 151 | fracParam = idx < nParameters ? xMinFrac : yMinFrac; 152 | intParam = idx < nParameters ? xMinInt : yMinInt; 153 | 154 | const scalar_t minInt = std::ceil(param[paramIndex] - 1); 155 | fracParam[paramIndex] = minInt - param[paramIndex] + 1; 156 | intParam[paramIndex] = static_cast(minInt); 157 | 158 | param = idx < nParameters ? xMax : yMax; 159 | fracParam = idx < nParameters ? xMaxFrac : yMaxFrac; 160 | intParam = idx < nParameters ? 
xMaxInt : yMaxInt; 161 | 162 | const scalar_t maxInt = std::floor(param[paramIndex]); 163 | fracParam[paramIndex] = param[paramIndex] - maxInt; 164 | intParam[paramIndex] = static_cast(maxInt); 165 | } 166 | } 167 | 168 | void splitParametersAccGradParameters( 169 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max , 170 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 171 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac) { 172 | 173 | const int threadsNeeded = 2 * x_min.numel(); 174 | const int numBlocks = (threadsNeeded + BLOCK_SIZE - 1) / BLOCK_SIZE; 175 | 176 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_min.scalar_type(), "gpu::splitParametersAccGradParams", ([&] { 177 | splitParametersAccGradParametersKernel 178 | <<>> ( 179 | x_min.data_ptr(), x_max.data_ptr(), 180 | y_min.data_ptr(), y_max.data_ptr(), 181 | xMinInt.data_ptr(), xMaxInt.data_ptr(), 182 | yMinInt.data_ptr(), yMaxInt.data_ptr(), 183 | xMinFrac.data_ptr(), xMaxFrac.data_ptr(), 184 | yMinFrac.data_ptr(), yMaxFrac.data_ptr(), 185 | x_min.numel()); 186 | THCudaCheck(cudaGetLastError()); 187 | })); 188 | } 189 | 190 | template 191 | __global__ void clipParametersKernel( 192 | scalar_t * __restrict__ paramMin, scalar_t * __restrict__ paramMax, 193 | const double inverseReparam, const double minSize, const double maxSize, const int nElements) { 194 | 195 | const int idx = BLOCK_SIZE * blockIdx.x + threadIdx.x; 196 | 197 | if (idx < nElements) { 198 | double minValue, maxValue; 199 | const double paramMinCurrent = static_cast(paramMin[idx]); 200 | const double paramMaxCurrent = static_cast(paramMax[idx]); 201 | 202 | // clamp parameters 203 | minValue = max(-(maxSize+1) * inverseReparam, 204 | min((maxSize-1) * inverseReparam, paramMinCurrent)); 205 | maxValue = max(-(maxSize+1) * inverseReparam, 206 | min((maxSize-1) * inverseReparam, paramMaxCurrent)); 207 | 208 | // make sure bottom/right border doesn't come before top/left 209 | if (minValue + (minSize - 0.9999) * inverseReparam > maxValue) { 210 | const scalar_t mean = 0.5 * (minValue + maxValue); 211 | minValue = mean - 0.5 * (minSize - 0.9999) * inverseReparam; 212 | maxValue = mean + 0.5 * (minSize - 0.9999) * inverseReparam; 213 | } 214 | 215 | paramMin[idx] = static_cast(minValue); 216 | paramMax[idx] = static_cast(maxValue); 217 | } 218 | } 219 | 220 | void clipParameters( 221 | at::Tensor & paramMin, at::Tensor & paramMax, 222 | const double reparametrization, const double minSize, const double maxSize) { 223 | 224 | const int threadsNeeded = paramMin.numel(); 225 | const int numBlocks = (threadsNeeded + BLOCK_SIZE - 1) / BLOCK_SIZE; 226 | 227 | const double inverseReparam = 1.0 / reparametrization; 228 | 229 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(paramMin.scalar_type(), "gpu::clipParameters", ([&] { 230 | clipParametersKernel 231 | <<>> ( 232 | paramMin.data_ptr(), paramMax.data_ptr(), 233 | inverseReparam, minSize, maxSize, paramMin.numel()); 234 | THCudaCheck(cudaGetLastError()); 235 | })); 236 | } 237 | 238 | template 239 | __global__ void computeAreaKernel( 240 | scalar_t * __restrict__ x_min, scalar_t * __restrict__ x_max, 241 | scalar_t * __restrict__ y_min, scalar_t * __restrict__ y_max, 242 | scalar_t * __restrict__ retval, const int nElements) { 243 | 244 | const int idx = BLOCK_SIZE * blockIdx.x + threadIdx.x; 245 | 246 | if (idx < nElements) { 247 | const scalar_t area = 248 | (needXDeriv ? 
x_max[idx]-x_min[idx]+1 : static_cast(1)) * 249 | (needYDeriv ? y_max[idx]-y_min[idx]+1 : static_cast(1)); 250 | retval[idx] = 1 / area; 251 | } 252 | } 253 | 254 | at::Tensor computeArea( 255 | at::Tensor x_min, at::Tensor x_max, at::Tensor y_min, at::Tensor y_max, 256 | const bool exact, const bool needXDeriv, const bool needYDeriv) { 257 | 258 | // TODO: how to stop tracking operations??? `.is_variable_(false)` doesn't work 259 | auto retval = at::empty_like(x_min); 260 | 261 | if (not exact) { 262 | x_min = x_min.ceil(); 263 | y_min = y_min.ceil(); 264 | x_max = x_max.floor(); 265 | y_max = y_max.floor(); 266 | } 267 | 268 | const int threadsNeeded = x_min.numel(); 269 | const int numBlocks = (threadsNeeded + BLOCK_SIZE - 1) / BLOCK_SIZE; 270 | 271 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_min.scalar_type(), "gpu::computeArea", ([&] { 272 | if (needXDeriv) { 273 | if (needYDeriv) { 274 | computeAreaKernel 275 | <<>> ( 276 | x_min.data_ptr(), x_max.data_ptr(), 277 | y_min.data_ptr(), y_max.data_ptr(), 278 | retval.data_ptr(), x_min.numel()); 279 | } else { 280 | computeAreaKernel 281 | <<>> ( 282 | x_min.data_ptr(), x_max.data_ptr(), 283 | y_min.data_ptr(), y_max.data_ptr(), 284 | retval.data_ptr(), x_min.numel()); 285 | } 286 | } else { 287 | if (needYDeriv) { 288 | computeAreaKernel 289 | <<>> ( 290 | x_min.data_ptr(), x_max.data_ptr(), 291 | y_min.data_ptr(), y_max.data_ptr(), 292 | retval.data_ptr(), x_min.numel()); 293 | } else { 294 | THError("computeArea called with needXDeriv == needYDeriv == false"); 295 | } 296 | } 297 | })); 298 | THCudaCheck(cudaGetLastError()); 299 | 300 | return retval; 301 | } 302 | 303 | } 304 | -------------------------------------------------------------------------------- /src/box_convolution_interface.cpp: -------------------------------------------------------------------------------- 1 | /* Functions actually called from Python. 
Registered in torch module in `bind.cpp` */ 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "box_convolution.h" 8 | 9 | at::Tensor integral_image(at::Tensor input); 10 | 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | 13 | at::Tensor box_convolution_forward( 14 | at::Tensor input_integrated, 15 | at::Tensor x_min, at::Tensor x_max, 16 | at::Tensor y_min, at::Tensor y_max, 17 | const bool normalize, const bool exact) { 18 | 19 | TORCH_CHECK(input_integrated.device() == x_min.device(), 20 | "BoxConv2d: input and parameters are on different devices"); 21 | 22 | input_integrated = input_integrated.contiguous(); // TODO support noncontiguous too 23 | TORCH_CHECK(input_integrated.dim() == 4, "BoxConv2d: input must have 4 dimensions"); 24 | TORCH_CHECK( 25 | x_min.dim() == 2 and x_max.dim() == 2 and y_min.dim() == 2 and y_max.dim() == 2, 26 | "BoxConv2d: all parameters must have 2 dimensions"); 27 | TORCH_CHECK( 28 | x_min.size(0) == x_max.size(0) and x_min.size(0) == y_min.size(0) and 29 | x_min.size(0) == y_max.size(0) and x_min.size(0) == input_integrated.size(1), 30 | "BoxConv2d: all parameters must have as many rows as there are input channels"); 31 | TORCH_CHECK( 32 | x_min.size(1) == x_max.size(1) and x_min.size(1) == y_min.size(1) and 33 | x_min.size(1) == y_max.size(1), 34 | "BoxConv2d: all parameters must have equal number of columns"); 35 | 36 | // Split x_min, x_max, y_min, y_max into integer and fractional parts 37 | auto intOptions = x_min.options().dtype(at::ScalarType::Int); 38 | auto xMinInt = at::empty(x_min.sizes(), intOptions); 39 | auto xMaxInt = at::empty(x_min.sizes(), intOptions); 40 | auto yMinInt = at::empty(x_min.sizes(), intOptions); 41 | auto yMaxInt = at::empty(x_min.sizes(), intOptions); 42 | 43 | auto fracOptions = x_min.options(); 44 | auto xMinFrac = at::empty(x_min.sizes(), fracOptions); 45 | auto xMaxFrac = at::empty(x_min.sizes(), fracOptions); 46 | auto yMinFrac = at::empty(x_min.sizes(), fracOptions); 47 | auto yMaxFrac = at::empty(x_min.sizes(), fracOptions); 48 | 49 | // inverse box areas for normalization 50 | at::Tensor area; 51 | 52 | if (x_min.is_cuda()) { 53 | gpu::splitParameters( 54 | x_min , x_max , y_min , y_max , 55 | xMinInt , xMaxInt , yMinInt , yMaxInt , 56 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac); 57 | 58 | if (normalize) { 59 | area = gpu::computeArea(x_min, x_max, y_min, y_max, exact); 60 | } 61 | } else { 62 | cpu::splitParameters( 63 | x_min , x_max , y_min , y_max , 64 | xMinInt , xMaxInt , yMinInt , yMaxInt , 65 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac); 66 | 67 | if (normalize) { 68 | area = cpu::computeArea(x_min, x_max, y_min, y_max, exact); 69 | } 70 | } 71 | 72 | const int batchSize = input_integrated.size(0); 73 | const int nInputPlanes = input_integrated.size(1); 74 | const int numFilters = x_min.size(1); 75 | const int h = input_integrated.size(2) - 1; 76 | const int w = input_integrated.size(3) - 1; 77 | 78 | // Output will be 1 pixel smaller and have `num_filters` channels per each input channel 79 | auto output = at::empty( 80 | {batchSize, nInputPlanes, numFilters, h, w}, input_integrated.options()); 81 | 82 | // Actually fill `output` 83 | if (input_integrated.is_cuda()) { 84 | // TODO what is the common practice of avoiding such `if`s? 
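        // `normalize` and `exact` are runtime flags here, but the kernels receive
        // them as template parameters (see the four explicit instantiations of
        // `boxConvUpdateOutput`), so this if-nest only selects the matching
        // specialization; inside each kernel the flags are compile-time constants.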
85 | if (normalize) { 86 | if (exact) { 87 | gpu::boxConvUpdateOutput( 88 | xMinInt , xMaxInt , yMinInt , yMaxInt , 89 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 90 | area, input_integrated, output); 91 | } else { 92 | gpu::boxConvUpdateOutput( 93 | xMinInt , xMaxInt , yMinInt , yMaxInt , 94 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 95 | area, input_integrated, output); 96 | } 97 | } else { 98 | if (exact) { 99 | gpu::boxConvUpdateOutput( 100 | xMinInt , xMaxInt , yMinInt , yMaxInt , 101 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 102 | area, input_integrated, output); 103 | } else { 104 | gpu::boxConvUpdateOutput( 105 | xMinInt , xMaxInt , yMinInt , yMaxInt , 106 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 107 | area, input_integrated, output); 108 | } 109 | } 110 | } else { 111 | if (normalize) { 112 | if (exact) { 113 | cpu::boxConvUpdateOutput( 114 | xMinInt , xMaxInt , yMinInt , yMaxInt , 115 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 116 | area, input_integrated, output); 117 | } else { 118 | cpu::boxConvUpdateOutput( 119 | xMinInt , xMaxInt , yMinInt , yMaxInt , 120 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 121 | area, input_integrated, output); 122 | } 123 | } else { 124 | if (exact) { 125 | cpu::boxConvUpdateOutput( 126 | xMinInt , xMaxInt , yMinInt , yMaxInt , 127 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 128 | area, input_integrated, output); 129 | } else { 130 | cpu::boxConvUpdateOutput( 131 | xMinInt , xMaxInt , yMinInt , yMaxInt , 132 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 133 | area, input_integrated, output); 134 | } 135 | } 136 | } 137 | 138 | return output.reshape({batchSize, nInputPlanes * numFilters, h, w}); 139 | } 140 | 141 | std::vector box_convolution_backward( 142 | at::Tensor input_integrated, 143 | at::Tensor x_min, at::Tensor x_max, 144 | at::Tensor y_min, at::Tensor y_max, 145 | at::Tensor grad_output, at::Tensor output, 146 | const float reparametrization_h, const float reparametrization_w, 147 | const bool normalize, const bool exact, 148 | const bool input_needs_grad, 149 | const bool x_min_needs_grad, const bool x_max_needs_grad, 150 | const bool y_min_needs_grad, const bool y_max_needs_grad) { 151 | 152 | grad_output = grad_output.contiguous(); // TODO support noncontiguous too 153 | TORCH_CHECK(grad_output.dim() == 4, "grad_output for box_convolution must have 4 dimensions") 154 | TORCH_CHECK( 155 | grad_output.size(0) == input_integrated.size(0) and 156 | grad_output.size(1) == input_integrated.size(1) * x_min.size(1) and 157 | grad_output.size(2) == input_integrated.size(2) - 1 and 158 | grad_output.size(2) == input_integrated.size(2) - 1, 159 | "box_convolution: sizes of grad_output and input_integrated don't match"); 160 | TORCH_CHECK( 161 | x_min.dim() == 2 and x_max.dim() == 2 and y_min.dim() == 2 and y_max.dim() == 2, 162 | "all box conv parameters must have 2 dimensions"); 163 | TORCH_CHECK( 164 | x_min.size(0) == x_max.size(0) and x_min.size(0) == y_min.size(0) and 165 | x_min.size(0) == y_max.size(0) and x_min.size(0) == input_integrated.size(1), 166 | "all box conv parameters must have as many rows as there are input channels"); 167 | TORCH_CHECK( 168 | x_min.size(1) == x_max.size(1) and x_min.size(1) == y_min.size(1) and 169 | x_min.size(1) == y_max.size(1), 170 | "all box conv parameters must have equal number of columns"); 171 | 172 | const int batchSize = input_integrated.size(0); 173 | const int nInputPlanes = input_integrated.size(1); 174 | const int numFilters = x_min.size(1); 175 | const int h = input_integrated.size(2) - 1; 176 | const 
int w = input_integrated.size(3) - 1; 177 | 178 | grad_output = grad_output.reshape({batchSize, nInputPlanes, numFilters, h, w}); 179 | 180 | // Return value 181 | // TODO change `nullTensor` to Python `None` 182 | at::Tensor nullTensor = at::empty({0}, at::TensorOptions()); 183 | at::Tensor gradInput = nullTensor; 184 | 185 | // Allocate memory for splitting x_min, x_max, y_min, y_max into integer and fractional parts 186 | auto intOptions = x_min.options().dtype(at::ScalarType::Int); 187 | auto xMinInt = at::empty(x_min.sizes(), intOptions); 188 | auto xMaxInt = at::empty(x_min.sizes(), intOptions); 189 | auto yMinInt = at::empty(x_min.sizes(), intOptions); 190 | auto yMaxInt = at::empty(x_min.sizes(), intOptions); 191 | 192 | auto fracOptions = x_min.options(); 193 | auto xMinFrac = at::empty(x_min.sizes(), fracOptions); 194 | auto xMaxFrac = at::empty(x_min.sizes(), fracOptions); 195 | auto yMinFrac = at::empty(x_min.sizes(), fracOptions); 196 | auto yMaxFrac = at::empty(x_min.sizes(), fracOptions); 197 | 198 | if (input_needs_grad) { 199 | at::Tensor grad_output_integrated = integral_image(grad_output); 200 | at::Tensor tmpArray = at::empty( 201 | {batchSize, nInputPlanes, numFilters, h, w}, grad_output.options()); 202 | CHECK_CONTIGUOUS(tmpArray); 203 | 204 | at::Tensor area; // box area for normalization 205 | 206 | if (grad_output_integrated.is_cuda()) { 207 | gpu::splitParametersUpdateGradInput( 208 | x_min, x_max, y_min, y_max, 209 | xMinInt, xMaxInt, yMinInt, yMaxInt, 210 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac); 211 | 212 | if (normalize) { 213 | area = gpu::computeArea(x_min, x_max, y_min, y_max, exact); 214 | 215 | if (exact) { 216 | gpu::boxConvUpdateGradInput( 217 | xMinInt , xMaxInt , yMinInt , yMaxInt , 218 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 219 | area, grad_output_integrated, tmpArray); 220 | } else { 221 | gpu::boxConvUpdateGradInput( 222 | xMinInt , xMaxInt , yMinInt , yMaxInt , 223 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 224 | area, grad_output_integrated, tmpArray); 225 | } 226 | } else { 227 | if (exact) { 228 | gpu::boxConvUpdateGradInput( 229 | xMinInt , xMaxInt , yMinInt , yMaxInt , 230 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 231 | area, grad_output_integrated, tmpArray); 232 | } else { 233 | gpu::boxConvUpdateGradInput( 234 | xMinInt , xMaxInt , yMinInt , yMaxInt , 235 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 236 | area, grad_output_integrated, tmpArray); 237 | } 238 | } 239 | } else { 240 | cpu::splitParametersUpdateGradInput( 241 | x_min, x_max, y_min, y_max, 242 | xMinInt, xMaxInt, yMinInt, yMaxInt, 243 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac); 244 | 245 | if (normalize) { 246 | area = cpu::computeArea(x_min, x_max, y_min, y_max, exact); 247 | 248 | if (exact) { 249 | cpu::boxConvUpdateGradInput( 250 | xMinInt , xMaxInt , yMinInt , yMaxInt , 251 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 252 | area, grad_output_integrated, tmpArray); 253 | } else { 254 | cpu::boxConvUpdateGradInput( 255 | xMinInt , xMaxInt , yMinInt , yMaxInt , 256 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 257 | area, grad_output_integrated, tmpArray); 258 | } 259 | } else { 260 | if (exact) { 261 | cpu::boxConvUpdateGradInput( 262 | xMinInt , xMaxInt , yMinInt , yMaxInt , 263 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 264 | area, grad_output_integrated, tmpArray); 265 | } else { 266 | cpu::boxConvUpdateGradInput( 267 | xMinInt , xMaxInt , yMinInt , yMaxInt , 268 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 269 | area, grad_output_integrated, tmpArray); 270 | } 271 | } 272 | } 273 | 
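        // `tmpArray` now holds one gradient contribution per (sample, input plane,
        // filter) triple, shaped {batchSize, nInputPlanes, numFilters, h, w};
        // summing over the filter dimension (dim 2) collapses it to the gradient
        // w.r.t. the input, shaped {batchSize, nInputPlanes, h, w}.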
274 | gradInput = tmpArray.sum(2); 275 | } // if (input_needs_grad) 276 | 277 | bool paramNeedsGrad[4] = {x_min_needs_grad, x_max_needs_grad, y_min_needs_grad, y_max_needs_grad}; 278 | at::Tensor gradParam[4] = {nullTensor, nullTensor, nullTensor, nullTensor}; 279 | 280 | at::Tensor tmpArray; 281 | at::Tensor area; // box area for normalization 282 | 283 | bool someParamNeedsGrad = false; 284 | for (bool needsGrad : paramNeedsGrad) { 285 | someParamNeedsGrad |= needsGrad; 286 | } 287 | 288 | if (someParamNeedsGrad) { 289 | tmpArray = at::empty({batchSize, nInputPlanes, numFilters, h, w}, x_min.options()); 290 | 291 | if (x_min.is_cuda()) { 292 | gpu::splitParametersAccGradParameters( 293 | x_min , x_max , y_min , y_max , 294 | xMinInt , xMaxInt , yMinInt , yMaxInt , 295 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac); 296 | 297 | if (normalize) { 298 | area = gpu::computeArea(x_min, x_max, y_min, y_max, exact); 299 | } 300 | } else { 301 | cpu::splitParametersAccGradParameters( 302 | x_min , x_max , y_min , y_max , 303 | xMinInt , xMaxInt , yMinInt , yMaxInt , 304 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac); 305 | 306 | if (normalize) { 307 | area = cpu::computeArea(x_min, x_max, y_min, y_max, exact); 308 | } 309 | } 310 | 311 | input_integrated = input_integrated.contiguous(); // TODO support noncontiguous too 312 | 313 | for (int paramIdx = 0; paramIdx < 4; ++paramIdx) { 314 | if (paramNeedsGrad[paramIdx]) { 315 | const Parameter paramId = static_cast(paramIdx); 316 | 317 | if (input_integrated.is_cuda()) { 318 | if (exact) { 319 | gpu::boxConvAccGradParameters( 320 | xMinInt , xMaxInt , yMinInt , yMaxInt , 321 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 322 | input_integrated, tmpArray, paramId); 323 | } else { 324 | gpu::boxConvAccGradParameters( 325 | xMinInt , xMaxInt , yMinInt , yMaxInt , 326 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 327 | input_integrated, tmpArray, paramId); 328 | } 329 | } else { 330 | if (exact) { 331 | cpu::boxConvAccGradParameters( 332 | xMinInt , xMaxInt , yMinInt , yMaxInt , 333 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 334 | input_integrated, tmpArray, paramId); 335 | } else { 336 | cpu::boxConvAccGradParameters( 337 | xMinInt , xMaxInt , yMinInt , yMaxInt , 338 | xMinFrac, xMaxFrac, yMinFrac, yMaxFrac, 339 | input_integrated, tmpArray, paramId); 340 | } 341 | } 342 | 343 | tmpArray.mul_(grad_output); 344 | 345 | gradParam[paramIdx] = 346 | tmpArray.reshape({batchSize, nInputPlanes, numFilters, h*w}).sum(c10::IntArrayRef({0, 3})); 347 | 348 | if (normalize) { 349 | gradParam[paramIdx].mul_(area); 350 | } 351 | } 352 | } 353 | 354 | if (normalize) { // add the second summand 355 | output = output.reshape({batchSize, nInputPlanes, numFilters, h, w}); 356 | 357 | tmpArray = grad_output.mul(output); 358 | tmpArray = tmpArray.reshape({batchSize, nInputPlanes, numFilters, h*w}).sum(c10::IntArrayRef({0, 3})); 359 | 360 | for (int paramIdx = 0; paramIdx < 4; ++paramIdx) { 361 | if (paramNeedsGrad[paramIdx]) { 362 | const Parameter paramId = static_cast(paramIdx); 363 | 364 | // multiply by area derivative and divide by area 365 | const bool needXDeriv = paramId == Parameter::xMin or paramId == Parameter::xMax; 366 | const bool needYDeriv = not needXDeriv; 367 | 368 | if (x_min.is_cuda()) { 369 | area = gpu::computeArea( 370 | x_min, x_max, y_min, y_max, exact, needXDeriv, needYDeriv); 371 | } else { 372 | area = cpu::computeArea( 373 | x_min, x_max, y_min, y_max, exact, needXDeriv, needYDeriv); 374 | } 375 | 376 | const bool minus = paramId == Parameter::xMax or paramId == 
Parameter::yMax; 377 | gradParam[paramIdx].addcmul_(tmpArray, area, minus ? -1.0 : 1.0); 378 | } 379 | } 380 | } 381 | 382 | // account for reparametrization 383 | for (int paramIdx = 0; paramIdx < 4; ++paramIdx) { 384 | if (paramNeedsGrad[paramIdx]) { 385 | const Parameter paramId = static_cast(paramIdx); 386 | const double scale = paramId == Parameter::xMin or paramId == Parameter::xMax 387 | ? reparametrization_h : reparametrization_w; 388 | 389 | gradParam[paramIdx].mul_(scale); 390 | } 391 | } 392 | } // if (someParamNeedsGrad) 393 | 394 | return {gradInput, gradParam[0], gradParam[1], gradParam[2], gradParam[3]}; 395 | } 396 | 397 | void clip_parameters( 398 | at::Tensor x_min, at::Tensor x_max, 399 | at::Tensor y_min, at::Tensor y_max, 400 | const double reparametrization_h, const double reparametrization_w, 401 | const double max_input_h, const double max_input_w, const bool exact) { 402 | 403 | CHECK_CONTIGUOUS(x_min); // and assume other parameter tensors have same layout 404 | 405 | const float minWidth = exact ? 1.001f : 2.001f; 406 | const float minHeight = exact ? 1.001f : 2.001f; 407 | 408 | if (x_min.is_cuda()) { 409 | gpu::clipParameters(x_min, x_max, reparametrization_h, minHeight, max_input_h); 410 | gpu::clipParameters(y_min, y_max, reparametrization_w, minWidth , max_input_w); 411 | } else { 412 | cpu::clipParameters(x_min, x_max, reparametrization_h, minHeight, max_input_h); 413 | cpu::clipParameters(y_min, y_max, reparametrization_w, minWidth , max_input_w); 414 | } 415 | } 416 | -------------------------------------------------------------------------------- /src/cuda_stubs.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "box_convolution.h" // for `enum class Parameter` 4 | 5 | #define STUB_ERROR TORCH_CHECK(false, "box_convolution was compiled withoud CUDA support because " \ 6 | "torch.cuda.is_available() was False when you ran setup.py.") 7 | 8 | namespace gpu { 9 | 10 | void integral_image(at::Tensor & input, at::Tensor & output) 11 | { STUB_ERROR; } 12 | 13 | void splitParameters( 14 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max , 15 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 16 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac) 17 | { STUB_ERROR; } 18 | 19 | void splitParametersUpdateGradInput( 20 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max , 21 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 22 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac) 23 | { STUB_ERROR; } 24 | 25 | void splitParametersAccGradParameters( 26 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max , 27 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 28 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac) 29 | { STUB_ERROR; } 30 | 31 | template 32 | void boxConvUpdateOutput( 33 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt , 34 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac, 35 | at::Tensor & area, at::Tensor & input_integrated, at::Tensor & output) 36 | { STUB_ERROR; } 37 | 38 | // explicitly instantiate 39 | template void boxConvUpdateOutput( 40 | at::Tensor 
--------------------------------------------------------------------------------
/src/cuda_stubs.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | 
3 | #include "box_convolution.h" // for `enum class Parameter`
4 | 
5 | #define STUB_ERROR TORCH_CHECK(false, "box_convolution was compiled without CUDA support because " \
6 |     "torch.cuda.is_available() was False when you ran setup.py.")
7 | 
8 | namespace gpu {
9 | 
10 | void integral_image(at::Tensor & input, at::Tensor & output)
11 |     { STUB_ERROR; }
12 | 
13 | void splitParameters(
14 |     at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max ,
15 |     at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
16 |     at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac)
17 |     { STUB_ERROR; }
18 | 
19 | void splitParametersUpdateGradInput(
20 |     at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max ,
21 |     at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
22 |     at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac)
23 |     { STUB_ERROR; }
24 | 
25 | void splitParametersAccGradParameters(
26 |     at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max ,
27 |     at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
28 |     at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac)
29 |     { STUB_ERROR; }
30 | 
31 | template <bool normalize, bool exact>
32 | void boxConvUpdateOutput(
33 |     at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
34 |     at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac,
35 |     at::Tensor & area, at::Tensor & input_integrated, at::Tensor & output)
36 |     { STUB_ERROR; }
37 | 
38 | // explicitly instantiate
39 | template void boxConvUpdateOutput<true, true>(
40 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
41 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
42 |     at::Tensor &, at::Tensor &, at::Tensor &);
43 | 
44 | template void boxConvUpdateOutput<true, false>(
45 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
46 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
47 |     at::Tensor &, at::Tensor &, at::Tensor &);
48 | 
49 | template void boxConvUpdateOutput<false, true>(
50 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
51 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
52 |     at::Tensor &, at::Tensor &, at::Tensor &);
53 | 
54 | template void boxConvUpdateOutput<false, false>(
55 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
56 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
57 |     at::Tensor &, at::Tensor &, at::Tensor &);
58 | 
59 | template <bool normalize, bool exact>
60 | void boxConvUpdateGradInput(
61 |     at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
62 |     at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac,
63 |     at::Tensor & area, at::Tensor & grad_output_integrated, at::Tensor & tmpArray)
64 |     { STUB_ERROR; }
65 | 
66 | // explicitly instantiate
67 | template void boxConvUpdateGradInput<true, true>(
68 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
69 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
70 |     at::Tensor &, at::Tensor &, at::Tensor &);
71 | 
72 | template void boxConvUpdateGradInput<true, false>(
73 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
74 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
75 |     at::Tensor &, at::Tensor &, at::Tensor &);
76 | 
77 | template void boxConvUpdateGradInput<false, true>(
78 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
79 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
80 |     at::Tensor &, at::Tensor &, at::Tensor &);
81 | 
82 | template void boxConvUpdateGradInput<false, false>(
83 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
84 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
85 |     at::Tensor &, at::Tensor &, at::Tensor &);
86 | 
87 | template <bool exact>
88 | void boxConvAccGradParameters(
89 |     at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
90 |     at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac,
91 |     at::Tensor & input_integrated, at::Tensor & tmpArray, Parameter parameter)
92 |     { STUB_ERROR; }
93 | 
94 | // explicitly instantiate
95 | template void boxConvAccGradParameters<true>(
96 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
97 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
98 |     at::Tensor &, at::Tensor &, Parameter);
99 | 
100 | template void boxConvAccGradParameters<false>(
101 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
102 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
103 |     at::Tensor &, at::Tensor &, Parameter);
104 | 
105 | void clipParameters(
106 |     at::Tensor & paramMin, at::Tensor & paramMax,
107 |     const double reparametrization, const double minSize, const double maxSize)
108 |     { STUB_ERROR; }
109 | 
110 | at::Tensor computeArea(
111 |     at::Tensor x_min, at::Tensor x_max, at::Tensor y_min, at::Tensor y_max,
112 |     const bool exact, const bool needXDeriv, const bool needYDeriv)
113 |     { STUB_ERROR; }
114 | 
115 | }
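cuda_stubs.cpp exists so that a CPU-only build (setup.py run while torch.cuda.is_available() is False) still defines every gpu:: symbol the rest of the extension links against; the stubs are only reached when a caller actually passes a CUDA tensor, and then they fail with the message above. The same pattern in Python, as a rough analogue (purely illustrative, not shipped by the package):

    # Python analogue of cuda_stubs.cpp: define the expected names so imports
    # succeed, but fail loudly as soon as a GPU code path is actually exercised.
    def _cuda_stub(*args, **kwargs):
        raise RuntimeError(
            "box_convolution was compiled without CUDA support because "
            "torch.cuda.is_available() was False when you ran setup.py.")

    integral_image = _cuda_stub
    splitParameters = _cuda_stub
    clipParameters = _cuda_stub
    computeArea = _cuda_stub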
--------------------------------------------------------------------------------
/src/integral_image.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <cstring>
3 | 
4 | #include "integral_image.h"
5 | 
6 | namespace cpu {
7 | 
8 | void integral_image(at::Tensor & input, at::Tensor & output) {
9 | 
10 |     const int h = input.size(-2);
11 |     const int w = input.size(-1);
12 |     const int nChannels = input.numel() / (h * w);
13 | 
14 |     AT_DISPATCH_ALL_TYPES(input.scalar_type(), "integral_image_forward_cpu", ([&] {
15 |         using accscalar_t = at::acc_type<scalar_t, false>;
16 | 
17 |         scalar_t *inputPtr  = input.data_ptr<scalar_t>();
18 |         scalar_t *outputPtr = output.data_ptr<scalar_t>();
19 | 
20 |         for (int c = 0; c < nChannels; ++c) {
21 |             // Fill the 0-th row
22 |             std::memset(outputPtr, 0, (w+1)*sizeof(scalar_t));
23 | 
24 |             // Fill the rest
25 |             for (int row = 0; row < h; ++row) {
26 |                 outputPtr[(row+1)*(w+1)] = 0.0;
27 | 
28 |                 accscalar_t sum = 0.0;
29 |                 for (int col = 0; col < w; ++col) {
30 |                     sum += inputPtr[row*w + col];
31 |                     outputPtr[(row+1)*(w+1) + (col+1)] = sum + outputPtr[row*(w+1) + (col+1)];
32 |                 }
33 |             }
34 | 
35 |             inputPtr  += h*w;
36 |             outputPtr += (h+1)*(w+1);
37 |         }
38 |     }));
39 | }
40 | 
41 | } // namespace cpu
42 | 
--------------------------------------------------------------------------------
/src/integral_image.h:
--------------------------------------------------------------------------------
1 | #include <ciso646> // && -> and, || -> or etc.
2 | 
3 | namespace cpu {
4 | 
5 | void integral_image(at::Tensor & input, at::Tensor & output);
6 | 
7 | }
8 | 
9 | namespace gpu {
10 | 
11 | void integral_image(at::Tensor & input, at::Tensor & output);
12 | 
13 | }
14 | 
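The CPU routine above builds the classic zero-padded integral image: output[c, i, j] is the sum of input[c, :i, :j], with an all-zero 0-th row and 0-th column so that later box lookups need no boundary checks. A short PyTorch reference that should produce the same values with two cumulative sums (the function name is illustrative):

    import torch

    def integral_image_reference(x):
        # x: (..., h, w) -> (..., h+1, w+1) with zero 0-th row/column;
        # out[..., i, j] == x[..., :i, :j].sum(), matching cpu::integral_image
        h, w = x.shape[-2], x.shape[-1]
        out = x.new_zeros(x.shape[:-2] + (h + 1, w + 1))
        out[..., 1:, 1:] = x.cumsum(dim=-2).cumsum(dim=-1)
        return out

    t = torch.arange(12.).reshape(1, 3, 4)
    assert torch.allclose(integral_image_reference(t)[0, 3, 4], t.sum())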
--------------------------------------------------------------------------------
/src/integral_image_cuda.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <ATen/cuda/CUDAContext.h>
3 | 
4 | #include <ATen/AccumulateType.h>
5 | #include <THC/THC.h>
6 | 
7 | #include <cublas_v2.h>
8 | #include <cuda_runtime.h>
9 | 
10 | #include "integral_image.h"
11 | 
12 | #include <type_traits>
13 | 
14 | #define BLOCK_SIZE 256
15 | 
16 | namespace gpu {
17 | 
18 | template <typename scalar_t, typename accscalar_t>
19 | __global__ void accumulateColsKernel(
20 |     const scalar_t * __restrict__ input, scalar_t * __restrict__ output,
21 |     const int channels, const int h, const int w);
22 | 
23 | template <typename scalar_t, typename accscalar_t>
24 | __global__ void accumulateColsInplaceTransposedKernel(
25 |     scalar_t * __restrict__ input, const int channels, const int h, const int w);
26 | 
27 | // contiguous out-of-place transpose
28 | template <typename scalar_t>
29 | void transpose(at::Tensor & input, at::Tensor & output) {
30 | 
31 |     TORCH_CHECK(input.dim() == 2);
32 |     TORCH_CHECK(input.numel() == output.numel());
33 | 
34 |     if (std::is_same<scalar_t, float>()) {
35 |         cublasHandle_t cublasHandle = at::cuda::getCurrentCUDABlasHandle();
36 |         cudaStream_t currentStream = at::cuda::getCurrentCUDAStream();
37 |         cublasSetStream(cublasHandle, currentStream);
38 |         const float ONE = 1.0, ZERO = 0.0;
39 | 
40 |         THCublasCheck(cublasSgeam(
41 |             cublasHandle,
42 |             CUBLAS_OP_T, CUBLAS_OP_N, input.size(0), input.size(1),
43 |             &ONE, input.data_ptr<float>(), input.size(1),
44 |             &ZERO, output.data_ptr<float>(), input.size(0),
45 |             output.data_ptr<float>(), input.size(0)));
46 | 
47 |     } else if (std::is_same<scalar_t, double>()) {
48 |         cublasHandle_t cublasHandle = at::cuda::getCurrentCUDABlasHandle();
49 |         cudaStream_t currentStream = at::cuda::getCurrentCUDAStream();
50 |         cublasSetStream(cublasHandle, currentStream);
51 |         const double ONE = 1.0, ZERO = 0.0;
52 | 
53 |         THCublasCheck(cublasDgeam(
54 |             cublasHandle,
55 |             CUBLAS_OP_T, CUBLAS_OP_N, input.size(0), input.size(1),
56 |             &ONE, input.data_ptr<double>(), input.size(1),
57 |             &ZERO, output.data_ptr<double>(), input.size(0),
58 |             output.data_ptr<double>(), input.size(0)));
59 | 
60 |     } else {
61 |         // TODO improve
62 |         output.view({input.size(1), input.size(0)}).copy_(input.t());
63 | 
64 |     }
65 | }
66 | 
67 | void integral_image(at::Tensor & input, at::Tensor & output) {
68 | 
69 |     const int h = input.size(-2);
70 |     const int w = input.size(-1);
71 |     const int channels = input.numel() / (h * w);
72 | 
73 |     auto inputView = input.view({channels, h, w});
74 |     auto outputView = output.view({channels, h+1, w+1});
75 |     auto tmpBuffer = at::empty_like(output);
76 | 
77 |     cudaStream_t currentStream = at::cuda::getCurrentCUDAStream();
78 | 
79 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "integral_image_forward_gpu", ([&] {
80 |         using accscalar_t = at::acc_type<scalar_t, true>;
81 | 
82 |         // input : (channels) x (h) x (w), contiguous
83 |         // output: (channels) x (h+1) x (w+1), contiguous
84 |         // tmpBuffer : at least (channels) * (h+1) * (w+1)
85 |         int blockSize1D, gridSize1D;
86 | 
87 |         // Compute prefix sums of columns, `input` -> `output`
88 |         // (channels) x (h) x (w) ==> (channels) x (h+1) x (w+1)
89 |         // Note: output[:,:,0] remains uninitialized
90 |         const int totalCols = channels * w;
91 |         blockSize1D = BLOCK_SIZE;
92 |         gridSize1D = (totalCols + blockSize1D - 1) / blockSize1D;
93 |         accumulateColsKernel<scalar_t, accscalar_t>
94 |             <<<gridSize1D, blockSize1D, 0, currentStream>>>
95 |             (inputView.data_ptr<scalar_t>(), outputView.data_ptr<scalar_t>(), channels, h, w);
96 |         THCudaCheck(cudaGetLastError());
97 | 
98 |         // transpose, `output` -> `tmpBuffer`
99 |         // (channels) x (h+1) x (w+1) ==> (w+1) x (channels) x (h+1)
100 |         auto output2Dim = output.view({channels * (h+1), w+1});
101 |         transpose<scalar_t>(output2Dim, tmpBuffer);
102 | 
103 |         // Compute prefix sums of columns (former rows), `tmpBuffer` -> `tmpBuffer`
104 |         // (w+1) x (channels) x (h+1) ==> (w+1) x (channels) x (h+1)
105 |         const int totalRows = channels * h; // actually, number of cols in (w+1) x (channels * (h+1)) image
106 |         blockSize1D = BLOCK_SIZE;
107 |         gridSize1D = (totalRows + blockSize1D - 1) / blockSize1D;
108 |         accumulateColsInplaceTransposedKernel<scalar_t, accscalar_t>
109 |             <<<gridSize1D, blockSize1D, 0, currentStream>>>
110 |             (tmpBuffer.data_ptr<scalar_t>(), channels, h, w);
111 |         THCudaCheck(cudaGetLastError());
112 | 
113 |         // transpose, `tmpBuffer` -> `output`
114 |         // (w+1) x (channels) x (h+1) ==> (channels) x (h+1) x (w+1)
115 |         tmpBuffer = tmpBuffer.reshape({w+1, channels * (h+1)});
116 |         transpose<scalar_t>(tmpBuffer, output);
117 |     })); // AT_DISPATCH_FLOATING_TYPES_AND_HALF
118 | }
119 | 
120 | template <typename scalar_t, typename accscalar_t>
121 | __global__ void accumulateColsKernel(
122 |     const scalar_t * __restrict__ input, scalar_t * __restrict__ output,
123 |     const int channels, const int h, const int w) {
124 |     // input : (channels * h) x (w)
125 |     // output: (channels * (h+1)) x (w+1) -- first column remains untouched
126 | 
127 |     // global column index (of total `channels * w` columns in this image):
128 |     const int globalColIdx = BLOCK_SIZE * blockIdx.x + threadIdx.x;
129 | 
130 |     if (globalColIdx < channels * w) {
131 |         const int channelIdx = globalColIdx / w;
132 |         const int colIdx = globalColIdx - channelIdx * w;
133 | 
134 |         // jump to the channel of interest:
135 |         int inputPos = channelIdx * h * w + colIdx;
136 |         // (let local columns be 1-indexed: 0-th output column is always zero)
137 |         int outputPos = channelIdx * (h+1) * (w+1) + colIdx + 1;
138 | 
139 |         output[outputPos] = 0; // 0-th element of every column is always zero
140 |         accscalar_t sum = 0;
141 |         for (int i = 1; i <= h; ++i) {
142 |             sum += static_cast<accscalar_t>(input[inputPos + (i-1) * w]);
143 |             output[outputPos + i * (w+1)] = static_cast<scalar_t>(sum);
144 |         }
145 |     }
146 | }
147 | 
148 | template <typename scalar_t, typename accscalar_t>
149 | __global__ void accumulateColsInplaceTransposedKernel(
150 |     scalar_t * __restrict__ input, const int channels, const int h, const int w) {
151 |     // in-place.
152 |     // input: (w+1) x (channels * (h+1))
153 | 
154 |     // global column index (of total `channels * h` columns in this image):
155 |     const int globalColIdx = BLOCK_SIZE * blockIdx.x + threadIdx.x;
156 | 
157 |     if (globalColIdx < channels * h) {
158 |         const int channelIdx = globalColIdx / h;
159 |         // add `channelIdx + 1` to account for one extra column in each horizontally stacked image
160 |         const int colIdx = globalColIdx + channelIdx + 1;
161 | 
162 |         // need to zero the (0,0) corner of the output separately >:(
163 |         input[channelIdx * (h+1)] = 0;
164 | 
165 |         input[colIdx] = 0; // first element of every column is always zero
166 |         accscalar_t sum = 0;
167 |         for (int i = 1; i <= w; ++i) {
168 |             scalar_t *currentElement = &input[i * channels * (h+1) + colIdx];
169 |             sum += static_cast<accscalar_t>(*currentElement);
170 |             *currentElement = static_cast<scalar_t>(sum);
171 |         }
172 |     }
173 | }
174 | 
175 | } // namespace gpu
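The GPU path avoids a per-row serial scan by running the same one-thread-per-column prefix-sum kernel twice: accumulate down the columns, transpose with cuBLAS geam (or a plain copy for unsupported dtypes), accumulate down the columns of the transposed buffer (i.e. over the original rows), then transpose back. The same four steps written with PyTorch ops, for reference (an illustrative helper, not part of the extension):

    import torch

    def integral_image_two_pass(x):
        # x: (channels, h, w) -> (channels, h+1, w+1), following the kernel plan above
        channels, h, w = x.shape
        out = x.new_zeros(channels, h + 1, w + 1)
        out[:, 1:, 1:] = x.cumsum(dim=1)          # pass 1: prefix sums down each column
        tmp = out.transpose(1, 2).contiguous()    # transpose (cublasSgeam / cublasDgeam)
        tmp = tmp.cumsum(dim=1)                   # pass 2: prefix sums over the former rows
        return tmp.transpose(1, 2).contiguous()   # transpose back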
--------------------------------------------------------------------------------
/src/integral_image_interface.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 | 
4 | #include "integral_image.h"
5 | 
6 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
7 | 
8 | at::Tensor integral_image(at::Tensor input) {
9 |     TORCH_CHECK(input.dim() >= 2, "integral image input must have >=2 dimensions");
10 |     input = input.contiguous();
11 | 
12 |     // The result will have extra row and column for zeros
13 |     // TODO maybe it's faster to eliminate them
14 |     std::vector<int64_t> outputSize(input.sizes().begin(), input.sizes().end());
15 |     ++outputSize[input.dim() - 2];
16 |     ++outputSize[input.dim() - 1];
17 | 
18 |     auto output = at::empty(outputSize, input.options());
19 | 
20 |     if (input.is_cuda()) {
21 |         gpu::integral_image(input, output);
22 |     } else {
23 |         cpu::integral_image(input, output);
24 |     }
25 | 
26 |     return output;
27 | }
28 | 
--------------------------------------------------------------------------------
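The integral image is the workhorse behind the box convolution kernels: once the zero-padded integral image exists, the sum over any axis-aligned box takes four lookups, independent of the box size. A small self-contained demonstration of that identity in PyTorch (box_sum is an illustrative helper, not an API of this package):

    import torch

    def box_sum(integral, y0, x0, y1, x1):
        # sum of img[y0:y1, x0:x1], given a zero-padded (h+1) x (w+1) integral image
        return (integral[y1, x1] - integral[y0, x1]
                - integral[y1, x0] + integral[y0, x0])

    img = torch.arange(12.).reshape(3, 4)
    integral = torch.zeros(4, 5)
    integral[1:, 1:] = img.cumsum(0).cumsum(1)
    assert box_sum(integral, 1, 1, 3, 4) == img[1:3, 1:4].sum()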