├── .gitignore
├── LICENSE
├── README.md
├── box_convolution
│   ├── __init__.py
│   ├── box_convolution_function.py
│   ├── box_convolution_module.py
│   └── test.py
├── examples
│   ├── Cityscapes
│   │   ├── README.md
│   │   ├── convnet-runtime-benchmark.py
│   │   ├── datasets.py
│   │   ├── models
│   │   │   ├── ENet.py
│   │   │   └── ERFNet.py
│   │   └── train.py
│   └── mnist.py
├── setup.py
└── src
    ├── bind.cpp
    ├── box_convolution.cpp
    ├── box_convolution.h
    ├── box_convolution_cuda_backward.cu
    ├── box_convolution_cuda_forward.cu
    ├── box_convolution_cuda_misc.cu
    ├── box_convolution_interface.cpp
    ├── cuda_stubs.cpp
    ├── integral_image.cpp
    ├── integral_image.h
    ├── integral_image_cuda.cu
    └── integral_image_interface.cpp
/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__/
2 | *.pyc
3 | *.so
4 | build/
5 | *.egg-info
6 | examples/MNIST/
7 | *runs/
8 | *log.txt
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2019 Egor Burkov
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Box Convolution Layer for ConvNets
2 | ==================================
3 |
4 |
5 |
6 |
7 | Single-box-conv network (from `examples/mnist.py`) learns patterns on MNIST
8 |
9 |
10 | # What This Is
11 |
12 | This is a PyTorch implementation of the box convolution layer as introduced in the 2018 NeurIPS [paper](https://papers.nips.cc/paper/7859-deep-neural-networks-with-box-convolutions):
13 |
14 | Burkov, E., & Lempitsky, V. (2018) **Deep Neural Networks with Box Convolutions**. *Advances in Neural Information Processing Systems 31*, 6214-6224.
15 |
16 | # How to Use
17 |
18 | ## Installing
19 |
20 | ```bash
21 | python3 -m pip install git+https://github.com/shrubb/box-convolutions.git
22 | python3 -m box_convolution.test # if this throws errors, please open a GitHub issue
23 | ```
24 |
25 | To uninstall:
26 |
27 | ```bash
28 | python3 -m pip uninstall box_convolution
29 | ```
30 |
31 | Tested on Ubuntu 18.04.2, Python 3.6, PyTorch 1.0.0, GCC {4.9, 5.5, 6.5, 7.3}, CUDA 9.2. Other versions (e.g. macOS, Python 2.7, CUDA 8, CUDA 10) should work too, but I haven't checked. If something doesn't build, please open a GitHub issue.
32 |
33 | Known issues (see [this chat](https://github.com/shrubb/box-convolutions/issues/2)):
34 |
35 | * CUDA 9/9.1 + GCC 6 isn't supported due to a bug in NVCC.
36 |
37 | You can specify a different compiler with the `CC` environment variable:
38 |
39 | ```bash
40 | CC=g++-7 python3 -m pip install git+https://github.com/shrubb/box-convolutions.git
41 | ```
42 |
43 | ## Using
44 |
45 | ```python
46 | import torch
47 | from box_convolution import BoxConv2d
48 |
49 | box_conv = BoxConv2d(16, 8, 240, 320)
50 | help(BoxConv2d)
51 | ```
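
For instance, a quick shape check (a minimal sketch; as the module's docstring explains, the output has `in_planes * num_filters` channels):

```python
import torch
from box_convolution import BoxConv2d

# 16 input channels, 8 box filters per channel, inputs up to 240 x 320
box_conv = BoxConv2d(16, 8, 240, 320)

x = torch.rand(1, 16, 240, 320)
y = box_conv(x)
print(y.shape)  # torch.Size([1, 128, 240, 320]) -- 16 * 8 = 128 channels
```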
52 |
53 | Also, there are usage examples in `examples/`.
54 |
55 |
56 | # Quick Tour of Box convolutions
57 |
58 | You may want to see our [poster](https://yadi.sk/i/LNnMrj6FwbOc9A).
59 |
60 | ### Why reinvent the old convolution?
61 |
62 | `3×3` convolutions are too small ⮕ the receptive field grows too slowly ⮕ ConvNets have to be very deep.
63 |
64 | This is especially undesirable in dense prediction tasks (*segmentation, depth estimation, object detection, ...*).
65 |
66 | Today people solve this by
67 |
68 | * dilated/deformable convolutions (can bring artifacts or degrade to `1×1` conv; almost always filter high-frequency);
69 | * "global" spatial pooling layers (usually too constrained, fixed size, not "fully convolutional").
70 |
71 | ### How does it work?
72 |
73 | The box convolution layer is essentially a *depthwise convolution* (i.e. `Conv2d` with `groups=in_channels`) but with special kernels called *box kernels*.
74 |
75 | A box kernel is a rectangular averaging filter. That is, its filter values are fixed and all equal! Instead of learning weights, we learn four parameters per rectangle, namely its size and offset (see the sketch below):
76 |
77 | 
78 |
79 | 
80 |
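To make this concrete, here is a minimal sketch of an *explicit* (non-learned) box kernel: an all-ones rectangle, optionally normalized by its area, applied depthwise. The actual layer never materializes such kernels (it computes box sums via integral images, see `src/`), but the result is the same:

```python
import torch
import torch.nn.functional as F

def box_kernel(height, width, normalize=True):
    kernel = torch.ones(height, width)
    return kernel / kernel.numel() if normalize else kernel

x = torch.rand(1, 3, 32, 32)
k = box_kernel(5, 7).repeat(3, 1, 1, 1)       # one 5x7 box per channel
y = F.conv2d(x, k, padding=(2, 3), groups=3)  # depthwise: groups = in_channels
print(y.shape)  # torch.Size([1, 3, 32, 32])
```
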
81 | ### Any success stories?
82 |
83 | One example: there is an efficient semantic segmentation model [**ENet**](https://github.com/e-lab/ENet-training). It's a classical hourglass architecture built from dozens of ResNet-like blocks (left image).
84 |
85 | Let's replace some of these blocks with our "box convolution block" (right image).
86 |
87 |
88 |
89 | First we replaced every second block with a box convolution block (*Box*ENet in the paper). The model became
90 |
91 | * more accurate,
92 | * faster,
93 | * lighter
94 | * **without dilated convolutions**.
95 |
96 | Then, we replaced **every** residual block (except the down- and up-sampling ones)! The result, *BoxOnly*ENet, is
97 |
98 | * a **ConvNet almost without** (traditional learnable weight) **convolutions**,
99 | * **2×** fewer operations,
100 | * **3×** fewer parameters,
101 | * still **more accurate** than ENet!
102 |
103 |
--------------------------------------------------------------------------------
/box_convolution/__init__.py:
--------------------------------------------------------------------------------
1 | from .box_convolution_module import BoxConv2d
2 |
--------------------------------------------------------------------------------
/box_convolution/box_convolution_function.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import box_convolution_cpp_cuda as cpp_cuda
4 |
5 | def reparametrize(
6 | x_min, x_max, y_min, y_max, reparametrization_h, reparametrization_w,
7 | inplace=False, inverse=False):
8 | """
9 | If `inverse is False`, scale the module's parameters so that their range becomes
10 | approximately [-1; 1]. Otherwise, do the inverse operation.
11 |
12 | This hack is unfortunately needed for the parameters to work with variants of SGD.
13 | Without this "reparametrization", the box size gradients would be extremely small.
14 |
15 | If `not inplace`, returns 4 new tensors, otherwise modifies the given ones.
16 | """
17 | scalar_h = reparametrization_h if inverse else (1 / reparametrization_h)
18 | scalar_w = reparametrization_w if inverse else (1 / reparametrization_w)
19 |
20 | with torch.no_grad():
21 | if inplace:
22 | x_min *= scalar_h
23 | x_max *= scalar_h
24 | y_min *= scalar_w
25 | y_max *= scalar_w
26 | else:
27 | return x_min * scalar_h, x_max * scalar_h, y_min * scalar_w, y_max * scalar_w
28 |
29 | # TODO: rename `x_` and `y_` to `h_` and `w_`
30 | class BoxConvolutionFunction(torch.autograd.Function):
31 | @staticmethod
32 | def forward(ctx, input, x_min, x_max, y_min, y_max,
33 | reparametrization_h, reparametrization_w, normalize, exact):
34 |
35 | # store all non-tensor arguments in `ctx`
36 | ctx.normalize = normalize
37 | ctx.reparametrization_h = reparametrization_h
38 | ctx.reparametrization_w = reparametrization_w
39 | ctx.exact = exact
40 |
41 | x_min, x_max, y_min, y_max = reparametrize(
42 | x_min, x_max, y_min, y_max, reparametrization_h, reparametrization_w, inverse=True)
43 |
44 | input_integrated = cpp_cuda.integral_image(input)
45 | output = cpp_cuda.box_convolution_forward(
46 | input_integrated, x_min, x_max, y_min, y_max, normalize, exact)
47 |
48 | ctx.save_for_backward(
49 | input_integrated, x_min, x_max, y_min, y_max, output if normalize else None)
50 |
51 | return output
52 |
53 | @staticmethod
54 | def backward(ctx, grad_output):
55 | input_integrated, x_min, x_max, y_min, y_max, output = ctx.saved_variables
56 | if output is None:
57 | output = torch.empty(0) # to satisfy `box_convolution_backward`'s signature
58 |
59 | retval = cpp_cuda.box_convolution_backward(
60 | input_integrated, x_min, x_max, y_min, y_max, grad_output, output,
61 | ctx.reparametrization_h, ctx.reparametrization_w,
62 | ctx.normalize, ctx.exact, *ctx.needs_input_grad[:5])
63 |
64 | # 4 `None`s for `reparametrization_h, reparametrization_w, normalize, exact`
65 | return tuple(retval) + (None,) * 4
66 |
--------------------------------------------------------------------------------
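The key trick in `forward` above is the integral image: once `cpp_cuda.integral_image` has built a summed-area table, any box sum costs four lookups regardless of box size. A minimal pure-PyTorch sketch of the same idea (assuming the same `(h+1, w+1)` zero-padded layout):

```python
import torch
import torch.nn.functional as F

def integral_image(x):
    # (..., h, w) -> (..., h+1, w+1): prefix sums with a zero 0th row and column
    return F.pad(x.cumsum(-2).cumsum(-1), (1, 0, 1, 0))

def box_sum(ii, x0, x1, y0, y1):
    # sum of x over the inclusive pixel range [x0, x1] x [y0, y1],
    # recovered from four corners of the integral image
    return ii[..., x1 + 1, y1 + 1] - ii[..., x0, y1 + 1] \
         - ii[..., x1 + 1, y0] + ii[..., x0, y0]

x = torch.arange(12.0).reshape(3, 4)
ii = integral_image(x)
assert box_sum(ii, 1, 2, 0, 1) == x[1:3, 0:2].sum()  # 26.0 either way
```
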
/box_convolution/box_convolution_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 |
4 | from .box_convolution_function import BoxConvolutionFunction, reparametrize
5 | import box_convolution_cpp_cuda as cpp_cuda
6 |
7 | class BoxConv2d(torch.nn.Module):
8 | """
9 | Module that performs depthwise box convolution.
10 | Convolves each of the incoming channels with `num_filters` different,
11 | possibly normalized, box kernels.
12 |
13 | Input : `(batch_size) x (in_planes) x (h) x (w)`
14 | Output: `(batch_size) x (in_planes*num_filters) x (h) x (w)`
15 |
16 | Constructor arguments:
17 |
18 | in_planes: int
19 | Number of channels in the input image (as in Conv2d).
20 | num_filters: int
21 | Number of filters to apply per channel (as in depthwise Conv2d).
22 | max_input_h, max_input_w: int
23 | Maximum estimated height/width of future input images. This parameter does
24 | not strictly bind input images to certain sizes. However, this is used
25 | when clipping the boxes to detect if some box has become too large or has
26 | drifted too far away from the image. See `_clip_parameters()` for details.
27 | reparametrization_factor: float
28 | In module parameters, boxes are not represented directly by their
29 | relative pixel coordinates, because then the gradients will usually
30 | be too small. Rather, here they are scaled into a range that is inside
31 | [-1; 1] by `1 / (reparametrization_factor * max_input_[h/w])`.
32 | When setting up training, generate a video of boxes using `draw_boxes()`.
33 | If they move too slowly, increasing this parameter might help. If they
34 | converge too quickly, reduce this value.
35 | stride_h, stride_w: int
36 | Stride (as in Conv2d). Not yet implemented.
37 | normalize: bool
38 | If `False`, computes sums over boxes (traditional box filters).
39 | If `True`, computes averages over boxes (normalized box filters).
40 |
41 | Useful fields (change after construction):
42 |
43 | exact: bool
44 | If `False`, box coordinates are rounded (towards smaller box size) before
45 | the output is computed. Significantly faster, but might be harmful for
46 | convergence. Still, it often works for some reason, so try it and see.
47 | Default: `True`.
48 | """
49 | def __init__(self,
50 | in_planes, num_filters, max_input_h, max_input_w,
51 | reparametrization_factor=8, stride_h=1, stride_w=1, normalize=True):
52 |
53 | super(BoxConv2d, self).__init__()
54 | self.in_planes = in_planes
55 | self.num_filters = num_filters
56 | self.max_input_h, self.max_input_w = max_input_h, max_input_w
57 | # default reparametrization; can be changed instead of setting a separate learning rate
58 | self.reparametrization_h = max_input_h * reparametrization_factor
59 | self.reparametrization_w = max_input_w * reparametrization_factor
60 | self.stride_h, self.stride_w = stride_h, stride_w
61 | assert stride_h == 1 and stride_w == 1, 'Sorry, strides are NYI'
62 | self.normalize = normalize
63 | self.exact = True
64 |
65 | self.x_min, self.x_max, self.y_min, self.y_max = \
66 | (torch.nn.Parameter(torch.empty(in_planes, num_filters)) for _ in range(4))
67 | self.reset_parameters()
68 |
69 | def reset_parameters(self):
70 | """
71 | One of the various possible random box initializations.
72 | """
73 | with torch.no_grad():
74 | # TODO speed up
75 | # TODO use torch's random generator
76 | # TODO provide the algorithm used in all original paper's experiments?
77 | max_h, max_w = self.max_input_h, self.max_input_w
78 | min_h, min_w = 2, 2
79 | for in_plane_idx in range(self.in_planes):
80 | for filter_idx in range(self.num_filters):
81 | center_h = random.uniform(
82 | -max_h*2/4.8+1+min_h/2, max_h*2/4.8-1-min_h/2)
83 | center_w = random.uniform(
84 | -max_w*2/4.8+1+min_w/2, max_w*2/4.8-1-min_w/2)
85 | height = 2 * random.uniform(
86 | min_h/2, min((max_h*2/4.8-1)-center_h, center_h-(-max_h*2/4.8+1)))
87 | width = 2 * random.uniform(
88 | min_w/2, min((max_w*2/4.8-1)-center_w, center_w-(-max_w*2/4.8+1)))
89 |
90 | self.x_min[in_plane_idx, filter_idx] = (center_h - height/2) * 1.5
91 | self.x_max[in_plane_idx, filter_idx] = (center_h + height/2) * 1.5
92 | self.y_min[in_plane_idx, filter_idx] = (center_w - width /2) * 1.5
93 | self.y_max[in_plane_idx, filter_idx] = (center_w + width /2) * 1.5
94 |
95 | reparametrize(
96 | self.x_min, self.x_max, self.y_min, self.y_max,
97 | self.reparametrization_h, self.reparametrization_w, inplace=True)
98 |
99 | def draw_boxes(self, channels=None, resolution=(600, 600), weights=None):
100 | """
101 | Plot all rectangles corresponding to box filters. Useful for debugging.
102 | Return the resulting image as an (H x W x 3) uint8 numpy array.
103 |
104 | channels: List of input channels to draw boxes for.
105 | Default: `[0, 1, ..., self.in_planes-1]` (draw all boxes).
106 | resolution: Tuple (h, w) -- returned image resolution.
107 | Default: (600, 600)
108 | weights: `len(channels) x self.num_filters` array of values in [0; 1] that define
109 | "importance" of each box (e.g. function of weights from a successive
110 | convolution). More important boxes are given a brighter color; unimportant
111 | ones are drawn almost transparent.
112 | Default: `numpy.ones((len(channels), self.num_filters))`.
113 | """
114 | import cv2
115 | import numpy as np
116 |
117 | if channels is None:
118 | channels = range(self.in_planes)
119 |
120 | weights_shape = (len(channels), self.num_filters)
121 | if weights is None:
122 | weights = torch.ones(weights_shape)  # torch tensor, so the .cpu() call below works
123 | weights = weights.cpu().float().numpy().reshape(weights_shape)
124 | weights.clip(0.01, 1.0, out=weights)
125 |
126 | retval = np.zeros(resolution + (3,), dtype=np.uint8)
127 |
128 | # draw gray lines at center
129 | center = [resolution[0] // 2, resolution[1] // 2]
130 | retval[center[0], :] = 70
131 | retval[:, center[1]] = 70
132 |
133 | colors = np.array([
134 | [255, 0, 0],
135 | [ 0, 255, 0],
136 | [ 0, 0, 255],
137 | [255, 255, 255],
138 | [255, 255, 0],
139 | [255, 0, 255],
140 | [ 0, 255, 255],
141 | [ 47, 20, 255],
142 | [255, 60, 160],
143 | [ 60, 170, 255],
144 | [ 30, 105, 210],
145 | [222, 196, 176],
146 | [212, 255, 127],
147 | [250, 206, 135],
148 | [ 50, 205, 50],
149 | [ 0, 165, 255],
150 | [ 60, 20, 220],
151 | [170, 178, 32]], dtype=np.float32)
152 |
153 | x_min, x_max, y_min, y_max = (p.float() for p in self.get_actual_parameters())
154 | x_min = x_min / self.max_input_h * (resolution[0] / 2) + center[0]
155 | y_min = y_min / self.max_input_w * (resolution[1] / 2) + center[1]
156 | x_max = (x_max + 1) / self.max_input_h * (resolution[0] / 2) + center[0]
157 | y_max = (y_max + 1) / self.max_input_w * (resolution[1] / 2) + center[1]
158 |
159 | for channel_idx in channels:
160 | for filter_idx in range(self.num_filters):
161 | box_weight = weights[channel_idx, filter_idx]
162 | # heuristic for single-plane inputs
163 | color = colors[(filter_idx if len(channels) == 1 else channel_idx) % len(colors)]
164 | # take weights into account
165 | color = (color * box_weight).astype(int)
166 |
167 | param_2d_idx = channel_idx, filter_idx
168 | x_min_curr = int(round(x_min[param_2d_idx].item()))  # cv2 expects int coords
169 | x_max_curr = int(round(x_max[param_2d_idx].item()))
170 | y_min_curr = int(round(y_min[param_2d_idx].item()))
171 | y_max_curr = int(round(y_max[param_2d_idx].item()))
172 |
173 | # if a rect has negative size, fill it
174 | box_is_invalid = x_min_curr > x_max_curr or y_min_curr > y_max_curr
175 | thickness = -1 if box_is_invalid else round(resolution[0] / 500 + 0.5)
176 |
177 | cv2.rectangle(
178 | retval, (y_min_curr, x_min_curr), (y_max_curr, x_max_curr),
179 | color.tolist(), thickness)
180 |
181 | return retval
182 |
183 | def get_actual_parameters(self):
184 | """
185 | As parameters are scaled (see `reparametrization_factor`), they don't
186 | represent actual box coordinates.
187 |
188 | Return the **real** parameters (i.e. actual relative box coordinates)
189 | as if they weren't rescaled.
190 | """
191 | return reparametrize(
192 | self.x_min, self.x_max, self.y_min, self.y_max,
193 | self.reparametrization_h, self.reparametrization_w, inplace=False, inverse=True)
194 |
195 | def _clip_parameters(self):
196 | """
197 | Internal method, do not invoke as a user.
198 |
199 | Dirty parameter fix for projected gradient descent:
200 | - If a filter's width or height is negative, reset it to the minimum allowed positive.
201 | - If a filter is at least twice as tall or wide as the input image, shrink it back.
202 | """
203 | cpp_cuda.clip_parameters(
204 | self.x_min, self.x_max, self.y_min, self.y_max,
205 | self.reparametrization_h, self.reparametrization_w,
206 | self.max_input_h, self.max_input_w, self.exact)
207 |
208 | def train(self, mode=True):
209 | self.training = mode
210 | if mode is False:
211 | # TODO would be good to also precompute rounded parameters and areas
212 | self._clip_parameters()
213 |
214 | return self
215 |
216 | def forward(self, input):
217 | if self.training:
218 | self._clip_parameters()
219 |
220 | return BoxConvolutionFunction.apply(
221 | input, self.x_min, self.x_max, self.y_min, self.y_max,
222 | self.reparametrization_h, self.reparametrization_w, self.normalize, self.exact)
223 |
--------------------------------------------------------------------------------
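A quick way to inspect a freshly initialized (or trained) module, sketched under the assumption that OpenCV is installed; `draw_boxes` returns an H x W x 3 uint8 array, so it can go straight to `cv2.imwrite`:

```python
import cv2
import torch
from box_convolution import BoxConv2d

box_conv = BoxConv2d(4, 8, 128, 128)

# actual (de-reparametrized) box coordinates, e.g. for logging
x_min, x_max, y_min, y_max = box_conv.get_actual_parameters()
print(x_min.shape)  # torch.Size([4, 8]) -- one box per (channel, filter) pair

cv2.imwrite('boxes.png', box_conv.draw_boxes(resolution=(600, 600)))
```
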
/box_convolution/test.py:
--------------------------------------------------------------------------------
1 | import time
2 | import random
3 | import torch
4 |
5 | try:
6 | from tqdm import tqdm
7 | except ImportError:
8 | tqdm = lambda x: x
9 |
10 | def test_integral_image(device):
11 | # or use torch.cumsum
12 | def integral_image_reference(input):
13 | assert input.ndimension() >= 2
14 | h, w = input.shape[-2:]
15 | output_shape = input.shape[:-2] + (h+1, w+1)
16 | output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
17 |
18 | # zero out the 0th row and column
19 | output.select(-2, 0).fill_(0)
20 | output.select(-1, 0).fill_(0)
21 |
22 | # accumulate rows
23 | output_no_zero_col = output.narrow(-1, 1, w)
24 | sum_rows = torch.zeros_like(input.select(-2, 0), dtype=torch.float64)
25 | for row_idx in range(h):
26 | sum_rows += input.select(-2, row_idx).double()
27 | output_no_zero_col.select(-2, row_idx+1).copy_(sum_rows)
28 |
29 | # accumulate columns
30 | sum_cols = torch.zeros_like(output.select(-1, 0), dtype=torch.float64)
31 | for col_idx in range(w):
32 | sum_cols += output.select(-1, col_idx+1).double()
33 | output.select(-1, col_idx+1).copy_(sum_cols)
34 |
35 | return output
36 |
37 | from box_convolution_cpp_cuda import integral_image
38 |
39 | # check IntegralImageFunction vs reference implementation
40 | for test_idx in tqdm(range(50)):
41 | batch_size = random.randint(1, 3)
42 | in_planes = random.randint(1, 3)
43 | stride_h, stride_w = 1, 1 # may change in the future
44 | h, w = random.randint(1+stride_h, 10), random.randint(1+stride_w, 10)
45 |
46 | input_image = torch.rand(batch_size, in_planes, h, w, requires_grad=True, device=device)
47 | grad_output = torch.rand(batch_size, in_planes, h+1, w+1) < 0.15
48 | grad_output = grad_output.to(device, input_image.dtype)  # NOTE: currently unused
49 |
50 | reference_result = integral_image_reference(input_image)
51 | our_result = integral_image(input_image)
52 |
53 | if not our_result.allclose(reference_result):
54 | raise ValueError(
55 | 'Test %d failed at forward pass.\n\nInput:\n%s\n\n'
56 | 'Our output:\n%s\n\nReference output:\n%s\n\n'
57 | % (test_idx, input_image, our_result, reference_result))
58 |
59 | def test_box_convolution_module(device):
60 | def explicit_box_kernel(x_min, x_max, y_min, y_max, normalize):
61 | import math
62 | h_farthest = math.ceil(max(x_max, -x_min))
63 | w_farthest = math.ceil(max(y_max, -y_min))
64 |
65 | retval = torch.ones(1+2*h_farthest, 1+2*w_farthest, device=x_min.device)
66 |
67 | def segments_intersection(a_l, a_r, b_l, b_r):
68 | common_l = max(a_l, b_l)
69 | common_r = min(a_r, b_r)
70 | return max(0.0, common_r - common_l)
71 |
72 | for x, row in enumerate(retval, start=-h_farthest):
73 | # h-extent of the current row: [x; x+1]
74 | # h-extent of the box of interest: [x_min; x_max+1]
75 | # find length of their intersection and multiply the row by it
76 | row *= segments_intersection(x, x+1, x_min, x_max+1)
77 |
78 | for y, col in enumerate(retval.t(), start=-w_farthest):
79 | # same for columns
80 | col *= segments_intersection(y, y+1, y_min, y_max+1)
81 |
82 | if normalize:
83 | area = (y_max-y_min+1) * (x_max-x_min+1)
84 | retval *= 1/area
85 |
86 | return retval
87 |
88 | # reference implementation
89 | def box_convolution_reference(
90 | input, x_min, x_max, y_min, y_max,
91 | reparametrization_h, reparametrization_w, normalize, exact):
92 |
93 | assert x_min.shape == x_max.shape
94 | assert x_min.shape == y_min.shape
95 | assert x_min.shape == y_max.shape
96 |
97 | assert input.ndimension() == 4
98 | assert type(normalize) is bool
99 |
100 | x_min, x_max, y_min, y_max = \
101 | reparametrize(
102 | x_min, x_max, y_min, y_max,
103 | reparametrization_h, reparametrization_w, inverse=True)
104 |
105 | if not exact:
106 | x_min.ceil_()
107 | y_min.ceil_()
108 | x_max.floor_()
109 | y_max.floor_()
110 |
111 | in_planes, num_filters = x_min.shape
112 | assert input.shape[1] == in_planes
113 |
114 | # in_c, out_c = input channel, output channel
115 | kernels = [[explicit_box_kernel(*out_c, normalize) for out_c in zip(*in_c)] \
116 | for in_c in zip(x_min, x_max, y_min, y_max)]
117 | assert len(kernels) == in_planes
118 | assert all(len(x) == num_filters for x in kernels)
119 |
120 | def conv2d_single_channel(image, kernel):
121 | image = image.view((1,1) + image.shape)
122 | kernel = kernel.view((1,1) + kernel.shape)
123 | padding = (kernel.shape[-2] // 2, kernel.shape[-1] // 2)
124 | return torch.conv2d(image, kernel, padding=padding)[0,0]
125 |
126 | output_shape = list(input.shape)
127 | output_shape.insert(2, num_filters)
128 | output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
129 |
130 | for in_sample, out_sample in zip(input, output):
131 | for in_plane_idx, in_plane_kernels in enumerate(kernels):
132 | for filter_idx, kernel in enumerate(in_plane_kernels):
133 | filtered = conv2d_single_channel(in_sample[in_plane_idx], kernel)
134 | out_sample[in_plane_idx, filter_idx].copy_(filtered)
135 |
136 | retval_shape = list(input.shape)
137 | retval_shape[1] *= num_filters
138 | return output.reshape(retval_shape)
139 |
140 | from . import BoxConv2d
141 | from .box_convolution_function import reparametrize
142 |
143 | # same interface for our target function
144 | def box_convolution_wrapper(
145 | input, x_min, x_max, y_min, y_max,
146 | max_input_h, max_input_w, reparametrization_factor, normalize, exact):
147 |
148 | assert x_min.shape == x_max.shape
149 | assert x_min.shape == y_min.shape
150 | assert x_min.shape == y_max.shape
151 |
152 | assert input.ndimension() == 4
153 | assert type(normalize) is bool
154 |
155 | in_planes, num_filters = x_min.shape
156 | assert input.shape[1] == in_planes
157 |
158 | module = BoxConv2d(
159 | in_planes, num_filters, max_input_h, max_input_w,
160 | reparametrization_factor).to(input)
161 |
162 | del module.x_min; module.x_min = x_min
163 | del module.x_max; module.x_max = x_max
164 | del module.y_min; module.y_min = y_min
165 | del module.y_max; module.y_max = y_max
166 | module.normalize = normalize
167 | module.exact = exact
168 |
169 | params_before = module.get_actual_parameters()
170 |
171 | output = module(input)
172 |
173 | params_after = module.get_actual_parameters()
174 | param_names = 'x_min', 'x_max', 'y_min', 'y_max'
175 | for p_before, p_after, p_name in zip(params_before, params_after, param_names):
176 | if not torch.equal(p_before, p_after):
177 | raise ValueError(
178 | 'Wrong test case configuration: `_clip_parameters` '
179 | 'has changed one of the parameters.\n\n' + \
180 | 'h, w = %d, %d\n\n' % (h, w) + \
181 | 'Before:\n' + \
182 | '\n'.join('%s: %s' % (n,p) for n,p in zip(param_names, params_before)) + \
183 | '\n\nAfter:\n' + \
184 | '\n'.join('%s: %s' % (n,p) for n,p in zip(param_names, params_after)))
185 |
186 | return output
187 |
188 | for test_idx in tqdm(range(40)):
189 | batch_size = random.randint(1, 1)
190 | in_planes = random.randint(1, 1)
191 | num_filters = random.randint(1, 1)
192 | stride_h, stride_w = 1, 1 # may change in the future
193 | exact = random.random() < 0.7
194 |
195 | # if not exact, minimum box size changes from 1 to 2
196 | h = random.randint(1 + stride_h + (not exact), 10)
197 | w = random.randint(1 + stride_w + (not exact), 10)
198 | max_input_h, max_input_w = h+1, w+1
199 | reparametrization_factor = random.random() * 4.5 + 0.5
200 | reparametrization_h = max_input_h * reparametrization_factor
201 | reparametrization_w = max_input_w * reparametrization_factor
202 | gradcheck_step = 0.004
203 |
204 | input_image = torch.rand(batch_size, in_planes, h, w, device=device, requires_grad=True)
205 |
206 | # sample boxes more or less randomly (algorithm isn't practical but is OK for gradcheck)
207 | x_min, x_max, y_min, y_max = \
208 | (torch.empty(in_planes, num_filters, device=device) for _ in range(4))
209 | for plane_idx in range(in_planes):
210 | for filter_idx in range(num_filters):
211 |
212 | box_is_valid = False
213 | while not box_is_valid:
214 | x_min_curr = random.uniform(-h+1.05, h-(not exact)-1.1)
215 | y_min_curr = random.uniform(-w+1.05, w-(not exact)-1.1)
216 |
217 | # set sizes to at least 1.001 because of `_clip_parameters`'s behavior
218 | x_max_curr = random.uniform(
219 | x_min_curr + (not exact) + 3*gradcheck_step + 0.001, h-1.05)
220 | y_max_curr = random.uniform(
221 | y_min_curr + (not exact) + 3*gradcheck_step + 0.001, w-1.05)
222 |
223 | # As a function of box coordinates (x_min etc.), box convolution isn't smooth
224 | # at integer points, so the finite difference gradcheck will fail.
225 | # Therefore, let's resample the box until all coordinates are far
226 | # enough from integers.
227 | box_is_valid = True
228 | for value in x_min_curr, y_min_curr, x_max_curr, y_max_curr:
229 | if abs(value - round(value)) <= gradcheck_step * 3: # *3 for extra safety
230 | box_is_valid = False
231 |
232 | x_min[plane_idx, filter_idx] = x_min_curr
233 | y_min[plane_idx, filter_idx] = y_min_curr
234 | x_max[plane_idx, filter_idx] = x_max_curr
235 | y_max[plane_idx, filter_idx] = y_max_curr
236 |
237 | # reparametrize
238 | x_min, x_max, y_min, y_max = \
239 | reparametrize(x_min, x_max, y_min, y_max, reparametrization_h, reparametrization_w)
240 |
241 | # randomly test either sum filter or average filter
242 | normalize = random.choice((False, True))
243 |
244 | grad_output = (torch.rand(batch_size, in_planes*num_filters, h, w) < 0.15).to(input_image)
245 |
246 | # check output and grad w.r.t. input vs reference ones
247 | reference_result = box_convolution_reference(
248 | input_image, x_min, x_max, y_min, y_max,
249 | reparametrization_h, reparametrization_w, normalize, exact)
250 | reference_result.backward(grad_output)
251 | reference_grad_input = input_image.grad.clone()
252 | input_image.grad.zero_()
253 |
254 | our_result = box_convolution_wrapper(
255 | input_image, x_min, x_max, y_min, y_max,
256 | max_input_h, max_input_w, reparametrization_factor, normalize, exact)
257 | our_result.backward(grad_output)
258 | our_grad_input = input_image.grad.clone()
259 |
260 | if not our_result.allclose(reference_result, rtol=3e-5, atol=1e-5):
261 | raise ValueError(
262 | 'Test %d failed at forward pass.\n\nNormalize: %s\n\nInput:\n%s\n\n'
263 | 'Our output:\n%s\n\nReference output:\n%s\n\nMax diff: %f\n\n'
264 | % (test_idx, normalize, input_image, our_result, reference_result, \
265 | (our_result - reference_result).abs().max()))
266 |
267 | if not our_grad_input.allclose(reference_grad_input, rtol=3e-5, atol=1e-5):
268 | raise ValueError(
269 | 'Test %d failed at backward pass.\n\nNormalize: %s\n\n'
270 | 'Input:\n%s\n\nOutput:\n%s\n\ngradOutput:\n%s\n\nOur gradInput:\n%s\n\n'
271 | 'Reference gradInput:\n%s\n\nMax diff: %f\n\n'
272 | % (test_idx, normalize, input_image, our_result, \
273 | grad_output, our_grad_input, reference_grad_input, \
274 | (our_grad_input-reference_grad_input).abs().max()))
275 |
276 | # sorry, I don't want to reliably check gradients w.r.t. parameters in rounded mode
277 | if not exact:
278 | continue
279 |
280 | # convert to double and check our grads w.r.t. parameters against finite differences
281 | with torch.no_grad():
282 | input_image = input_image.double()
283 | x_min = x_min.double()
284 | x_max = x_max.double()
285 | y_min = y_min.double()
286 | y_max = y_max.double()
287 |
288 | for tensor in x_min, x_max, y_min, y_max:
289 | tensor.requires_grad_()
290 | input_image.requires_grad_(False) # already tested above
291 |
292 | try:
293 | original_parameters = reparametrize(
294 | x_min, x_max, y_min, y_max, reparametrization_h, reparametrization_w, inverse=True)
295 |
296 | torch.autograd.gradcheck(
297 | box_convolution_wrapper,
298 | (input_image, x_min, x_max, y_min, y_max, max_input_h, max_input_w, \
299 | reparametrization_factor, normalize, exact),
300 | eps=gradcheck_step / max(reparametrization_h, reparametrization_w),
301 | raise_exception=True)
302 | except Exception:
303 | print('Test %d failed at finite difference grad check w.r.t. parameters.' % test_idx)
304 | print('Normalize: %s' % normalize)
305 | print('h, w = %d, %d' % (h, w))
306 | print('x_min, x_max, y_min, y_max are:')
307 | for parameter in original_parameters:
308 | print(parameter)
309 | raise
310 |
311 | if __name__ == '__main__':
312 | seed = int(time.time())
313 | seed = 1546545757  # fixed seed for reproducibility; comment out to use the time-based one
314 | torch.manual_seed(seed)
315 | random.seed(seed)
316 | print('Random seed is %d' % seed)
317 |
318 | # TODO add dtypes too
319 | devices = ('cpu',)
320 | if torch.cuda.is_available():
321 | devices += ('cuda',)
322 |
323 | for device in devices:
324 | print('Testing for device \'%s\'' % device)
325 | for testing_function in test_integral_image, test_box_convolution_module:
326 | print('Running %s()...' % testing_function.__name__)
327 | # TODO [re]set random state etc.
328 | testing_function(device)
329 | print('OK')
330 |
--------------------------------------------------------------------------------
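The `eps` passed to `gradcheck` above is deliberately divided by the reparametrization, so a finite-difference step in *parameter* space corresponds to `gradcheck_step` in *pixel* space. For reference, a minimal standalone `gradcheck` call in the same double-precision style (a toy function, not part of the test suite):

```python
import torch

def toy_box_area(x_min, x_max):
    # a smooth stand-in for a function of box coordinates
    return (x_max - x_min).clamp(min=0).pow(2).sum()

x_min = torch.tensor([0.3], dtype=torch.float64, requires_grad=True)
x_max = torch.tensor([2.7], dtype=torch.float64, requires_grad=True)
torch.autograd.gradcheck(toy_box_area, (x_min, x_max), eps=1e-6, raise_exception=True)
```
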
/examples/Cityscapes/README.md:
--------------------------------------------------------------------------------
1 | Box convolutions for semantic segmentation
2 | ======
3 |
4 | This folder has the code to reproduce our experiments (**to be uploaded soon**).
5 |
6 | Pretrained models (metrics are shown for the validation set):
7 |
8 | | Name and link | Class IoU, % | Category IoU, % |
9 | |:--------------------------------------------------------:|:------------:|:---------------:|
10 | | [BoxENet]( ) | --.- | --.- |
11 | | [BoxOnlyENet]( ) | --.- | --.- |
12 | | [BoxERFNet]( ) | --.- | --.- |
--------------------------------------------------------------------------------
/examples/Cityscapes/convnet-runtime-benchmark.py:
--------------------------------------------------------------------------------
1 | """
2 | The script I use for benchmarking.
3 |
4 | Architectures are optimized for inference as if they were deployed in
5 | production: BatchNorms and Dropouts are absorbed, "void" class is removed.
6 |
7 | I managed to fully reproduce Torch7 runtimes from the paper for ENet
8 | and ERFNet; however, for some reason smaller models (e.g. ENet^-, ERFNet^-)
9 | are slower in PyTorch 1.0.1 than in Torch7.
10 | """
11 | import torch
12 |
13 | architecture = 'ENet' # choices: ENet / BoxENet / BoxOnlyENet / ERFNet / BoxERFNet / ENetMinus
14 | device = 'cuda' # or 'cpu'
15 | dtype = torch.float32
16 |
17 | if device == 'cuda':
18 | assert torch.backends.cudnn.enabled
19 | torch.backends.cudnn.benchmark = True
20 | torch.backends.cudnn.deterministic = False
21 |
22 | torch.set_num_threads(1)
23 | torch.manual_seed(666)
24 |
25 | # Cityscapes 1024x512 configuration
26 | n_classes = 18
27 | input_image = torch.rand((1, 3, 512, 1024), dtype=dtype, device=device)
28 |
29 | print('Architecture:', architecture)
30 | print('Device:', input_image.device)
31 | print('Data type:', input_image.dtype)
32 |
33 | from models.ERFNet import ERFNet, BoxERFNet
34 | from models.ENet import ENet, BoxENet, BoxOnlyENet, ENetMinus
35 |
36 | model = globals()[architecture](n_classes).to(input_image)
37 |
38 | # optimize the model for inference
39 | def remove_bn_and_dropout(module):
40 | for child_name, child in module.named_children():
41 | child_type = str(type(child))
42 | if 'BatchNorm' in child_type or 'Dropout' in child_type:
43 | module.__setattr__(child_name, torch.nn.Sequential())
44 | else:
45 | remove_bn_and_dropout(child)
46 |
47 | from box_convolution import BoxConv2d
48 | def set_boxconv_to_nonexact(module):
49 | if isinstance(module, BoxConv2d):
50 | module.exact = False
51 |
52 | model.apply(set_boxconv_to_nonexact)
53 | remove_bn_and_dropout(model)
54 | model.eval()
55 |
56 | # warm up
57 | print('Output shape:', model(input_image).shape)
58 |
59 | n_runs = 10 if device == 'cpu' else 160
60 | import time
61 |
62 | with torch.no_grad():
63 | if device == 'cuda':
64 | torch.cuda.synchronize()
65 |
66 | start = time.time()
67 | for _ in range(n_runs):
68 | model(input_image)
69 |
70 | if device == 'cuda':
71 | torch.cuda.synchronize()
72 | end = time.time()
73 |
74 | time_per_frame = (end - start) / n_runs
75 | print('%.1f ms per frame / %.2f FPS' % (time_per_frame * 1000, 1 / time_per_frame))
76 |
77 |
--------------------------------------------------------------------------------
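The script times with `time.time()` bracketed by `torch.cuda.synchronize()`, which is correct. An equivalent alternative is CUDA events; a sketch reusing the script's `model`, `input_image`, and `n_runs`:

```python
# alternative GPU timing via CUDA events (events report elapsed time in ms)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

with torch.no_grad():
    start_event.record()
    for _ in range(n_runs):
        model(input_image)
    end_event.record()
    torch.cuda.synchronize()  # wait so that elapsed_time() is valid

time_per_frame = start_event.elapsed_time(end_event) / 1000 / n_runs
print('%.1f ms per frame / %.2f FPS' % (time_per_frame * 1000, 1 / time_per_frame))
```
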
/examples/Cityscapes/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import torchvision
4 |
5 | import cv2
6 | import numpy as np
7 |
8 | class Cityscapes(torch.utils.data.Dataset):
9 | """
10 | Native PyTorch Cityscapes but with proper data augmentation.
11 | """
12 | def __init__(
13 | self, root='/media/hpc4_Raid/e_burkov/Datasets/Cityscapes/',
14 | split='train', size=(1024, 512), augmented=False):
15 |
16 | super().__init__()
17 |
18 | self.cityscapes = torchvision.datasets.Cityscapes(
19 | root=root, split=split, mode='fine', target_type='semantic')
20 |
21 | self.size = size
22 | self.n_classes = 19
23 |
24 | # precomputed mean and stddev per channel
25 | self.mean = np.float32([0.28470638394356, 0.32577008008957, 0.28766867518425]) * 255.0
26 | self.std = np.float32([0.18671783804893, 0.1899059265852, 0.18665011227131]) * 255.0
27 |
28 | # precomputed class frequencies
29 | class_probs = np.float32([
30 | 0.36869695782661,
31 | 0.060849856585264,
32 | 0.22824048995972,
33 | 0.0065539856441319,
34 | 0.0087727159261703,
35 | 0.012273414991796,
36 | 0.0020779478363693,
37 | 0.0055127013474703,
38 | 0.1592865139246,
39 | 0.011578181758523,
40 | 0.040189824998379,
41 | 0.012189572677016,
42 | 0.0013512192526832,
43 | 0.069945447146893,
44 | 0.0026745572686195,
45 | 0.0023519159294665,
46 | 0.0023290426470339,
47 | 0.00098657899070531,
48 | 0.0041390685364604,
49 | ])
50 | # add "void" class and adopt a slightly modified class weighting scheme from ENet
51 | self.class_weights = np.concatenate(([0], 1.0 / np.log(class_probs + 1.1)))
52 | self.class_weights = torch.tensor(self.class_weights, dtype=torch.float32)
53 |
54 | self.augmented = augmented
55 |
56 | @staticmethod
57 | def augment(image, labels):
58 | """
59 | image: np.uint8, H x W x 3
60 | labels: np.uint8, H x W
61 | """
62 | flip = bool(np.random.randint(2))
63 | maybe_flip_matrix = np.eye(3)
64 | if flip:
65 | maybe_flip_matrix[0,0] = -1
66 | maybe_flip_matrix[0,2] = labels.shape[1]
67 |
68 | angle = (np.random.rand() * 2 - 1) * 7.5
69 | scale_factor = (np.random.rand() * 2 - 1) * 0.12 + 1.0
70 | image_center = (labels.shape[1] / 2, labels.shape[0] / 2)
71 | rotation_matrix = np.eye(3)
72 | rotation_matrix[:2] = cv2.getRotationMatrix2D(image_center, angle, scale_factor)
73 |
74 | transformation_matrix = (maybe_flip_matrix @ rotation_matrix)[:2]
75 |
76 | image_size = (labels.shape[1], labels.shape[0])
77 |
78 | image = cv2.warpAffine(
79 | image, transformation_matrix, image_size,
80 | flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)
81 | labels = cv2.warpAffine(
82 | labels, transformation_matrix, image_size,
83 | flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT)
84 |
85 | return image, labels
86 |
87 | def __len__(self):
88 | return len(self.cityscapes)
89 |
90 | def __getitem__(self, idx):
91 | image, labels = map(np.array, self.cityscapes[idx])
92 |
93 | if image.shape[:2] != self.size[::-1]:
94 | image = cv2.resize(image, self.size, interpolation=cv2.INTER_AREA)
95 | if labels.shape != self.size[::-1]:
96 | labels = cv2.resize(labels, self.size, interpolation=cv2.INTER_NEAREST)
97 |
98 | def remap_labels(labels):
99 | """
100 | Shift labels according to
101 | https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py#L61
102 | """
103 | retval = np.zeros_like(labels)
104 | class_map = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
105 | for new_class, old_class in enumerate(class_map, start=1):
106 | retval[labels == old_class] = new_class
107 | return retval
108 |
109 | # comment this line if you have already preprocessed the labels this way
110 | labels = remap_labels(labels)
111 |
112 | if self.augmented:
113 | image, labels = self.augment(image, labels)
114 |
115 | image = image.transpose((2, 0, 1)).astype(np.float32)
116 | image -= self.mean.reshape(3, 1, 1)
117 | image *= 1 / self.std.reshape(3, 1, 1)
118 |
119 | return torch.tensor(image), torch.tensor(labels, dtype=torch.long)
120 |
121 |
122 | if __name__ == '__main__':
123 | dataset = Cityscapes('/home/shrubb/Datasets/Cityscapes', augmented=True)
124 | print(len(dataset))
125 |
126 |
--------------------------------------------------------------------------------
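A minimal sketch of consuming the dataset during training (the root path is a placeholder). Note that `class_weights` starts with a zero for the added "void" class, so those pixels contribute nothing to the loss:

```python
import torch
from datasets import Cityscapes  # assumes this file is on the import path

dataset = Cityscapes(root='/path/to/Cityscapes', split='train', augmented=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

# weight 0 for class 0 ("void") => unlabeled pixels are effectively ignored
criterion = torch.nn.CrossEntropyLoss(weight=dataset.class_weights)

images, labels = next(iter(loader))
print(images.shape, labels.shape)  # [4, 3, 512, 1024] and [4, 512, 1024]
```
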
/examples/Cityscapes/models/ENet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class ENet(nn.ModuleList):
6 | def __init__(self, n_classes=19):
7 | super().__init__([
8 | Downsampler(3, 16),
9 | Bottleneck(16, 64, 0.01, downsample=True),
10 |
11 | Bottleneck(64, 64, 0.01),
12 | Bottleneck(64, 64, 0.01),
13 | Bottleneck(64, 64, 0.01),
14 | Bottleneck(64, 64, 0.01),
15 |
16 | Bottleneck(64, 128, 0.1, downsample=True),
17 |
18 | Bottleneck(128, 128, 0.1),
19 | Bottleneck(128, 128, 0.1, dilation=2),
20 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5),
21 | Bottleneck(128, 128, 0.1, dilation=4),
22 | Bottleneck(128, 128, 0.1),
23 | Bottleneck(128, 128, 0.1, dilation=8),
24 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5),
25 | Bottleneck(128, 128, 0.1, dilation=16),
26 |
27 | Bottleneck(128, 128, 0.1),
28 | Bottleneck(128, 128, 0.1, dilation=2),
29 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5),
30 | Bottleneck(128, 128, 0.1, dilation=4),
31 | Bottleneck(128, 128, 0.1),
32 | Bottleneck(128, 128, 0.1, dilation=8),
33 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5),
34 | Bottleneck(128, 128, 0.1, dilation=16),
35 |
36 | Upsampler(128, 64),
37 |
38 | Bottleneck(64, 64, 0.1),
39 | Bottleneck(64, 64, 0.1),
40 |
41 | Upsampler(64, 16),
42 |
43 | Bottleneck(16, 16, 0.1),
44 |
45 | nn.ConvTranspose2d(16, n_classes+1, (2,2), (2,2))])
46 |
47 | def forward(self, x):
48 | max_indices_stack = []
49 |
50 | for module in self:
51 | if isinstance(module, Upsampler):
52 | x = module(x, max_indices_stack.pop())
53 | else:
54 | x = module(x)
55 |
56 | if type(x) is tuple: # then it was a downsampling bottleneck block
57 | x, max_indices = x
58 | max_indices_stack.append(max_indices)
59 |
60 | return x
61 |
62 | class BoxENet(ENet):
63 | def __init__(self, n_classes=19, max_input_h=512, max_input_w=1024):
64 | h, w = max_input_h, max_input_w # shorten names for convenience
65 | r = 0.860 # reparametrization factor
66 |
67 | nn.ModuleList.__init__(self, [
68 | Downsampler(3, 16),
69 | Bottleneck(16, 64, 0.01, downsample=True),
70 |
71 | Bottleneck(64, 64, 0.01),
72 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.15, reparam_factor=r),
73 | Bottleneck(64, 64, 0.01),
74 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.15, reparam_factor=r),
75 |
76 | Bottleneck(64, 128, 0.1, downsample=True),
77 |
78 | Bottleneck(128, 128, 0.1),
79 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.25, reparam_factor=r),
80 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5),
81 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.25, reparam_factor=r),
82 | Bottleneck(128, 128, 0.1),
83 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.25, reparam_factor=r),
84 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5),
85 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.25, reparam_factor=r),
86 |
87 | Bottleneck(128, 128, 0.1),
88 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.25, reparam_factor=r),
89 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5),
90 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.25, reparam_factor=r),
91 | Bottleneck(128, 128, 0.1),
92 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.25, reparam_factor=r),
93 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5),
94 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.25, reparam_factor=r),
95 |
96 | Upsampler(128, 64),
97 |
98 | Bottleneck(64, 64, 0.1),
99 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.1, reparam_factor=r),
100 |
101 | Upsampler(64, 16),
102 |
103 | BottleneckBoxConv(16, 2, h // 2, w // 2, 0.1, reparam_factor=r),
104 |
105 | nn.ConvTranspose2d(16, n_classes+1, (2,2), (2,2))])
106 |
107 | class BoxOnlyENet(ENet):
108 | def __init__(self, n_classes=19, max_input_h=512, max_input_w=1024):
109 | h, w = max_input_h, max_input_w # shorten names for convenience
110 | r = 0.510 # reparametrization factor
111 |
112 | nn.ModuleList.__init__(self, [
113 | Downsampler(3, 16),
114 | Bottleneck(16, 64, 0.01, downsample=True),
115 |
116 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.15, reparam_factor=r),
117 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.15, reparam_factor=r),
118 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.15, reparam_factor=r),
119 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.15, reparam_factor=r),
120 |
121 | Bottleneck(64, 128, 0.01, downsample=True),
122 |
123 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r),
124 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r),
125 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r),
126 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r),
127 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r),
128 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r),
129 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r),
130 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r),
131 |
132 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r),
133 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r),
134 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r),
135 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r),
136 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r),
137 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r),
138 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r),
139 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.4, reparam_factor=r),
140 |
141 | Upsampler(128, 64),
142 |
143 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.15, reparam_factor=r),
144 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.15, reparam_factor=r),
145 |
146 | Upsampler(64, 16),
147 |
148 | BottleneckBoxConv(16, 4, h // 2, w // 2, 0.05, reparam_factor=r),
149 |
150 | nn.ConvTranspose2d(16, n_classes+1, (2,2), (2,2))])
151 |
152 | class ENetMinus(ENet):
153 | def __init__(self, n_classes=19, max_input_h=512, max_input_w=1024):
154 | h, w = max_input_h, max_input_w # shorten names for convenience
155 | r = 0.860 # reparametrization factor
156 |
157 | nn.ModuleList.__init__(self, [
158 | Downsampler(3, 16),
159 | Bottleneck(16, 64, 0.01, downsample=True),
160 |
161 | Bottleneck(64, 64, 0.01),
162 | Bottleneck(64, 64, 0.01),
163 |
164 | Bottleneck(64, 128, 0.1, downsample=True),
165 |
166 | Bottleneck(128, 128, 0.1),
167 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5),
168 | Bottleneck(128, 128, 0.1),
169 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5),
170 |
171 | Bottleneck(128, 128, 0.1),
172 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5),
173 | Bottleneck(128, 128, 0.1),
174 | Bottleneck(128, 128, 0.1, asymmetric_ksize=5),
175 |
176 | Upsampler(128, 64),
177 |
178 | Bottleneck(64, 64, 0.1),
179 |
180 | Upsampler(64, 16),
181 |
182 | nn.ConvTranspose2d(16, n_classes+1, (2,2), (2,2))])
183 |
184 | class Upsampler(nn.Module):
185 | def __init__(self, in_channels, out_channels):
186 | super().__init__()
187 | bt_channels = out_channels // 4
188 |
189 | self.main_branch = nn.Sequential(
190 | nn.Conv2d(in_channels, bt_channels, (1,1), bias=False),
191 | nn.BatchNorm2d(bt_channels, 1e-3),
192 | nn.ReLU(True),
193 |
194 | nn.ConvTranspose2d(bt_channels, bt_channels, (3,3), 2, 1, 1),
195 | nn.BatchNorm2d(bt_channels, 1e-3),
196 | nn.ReLU(True),
197 |
198 | nn.Conv2d(bt_channels, out_channels, (1,1), bias=False),
199 | nn.BatchNorm2d(out_channels, 1e-3))
200 |
201 | self.skip_connection = nn.Sequential(
202 | nn.Conv2d(in_channels, out_channels, (1,1), bias=False),
203 | nn.BatchNorm2d(out_channels, 1e-3))
204 |
205 | def forward(self, x, max_indices):
206 | x_skip_connection = self.skip_connection(x)
207 | x_skip_connection = F.max_unpool2d(x_skip_connection, max_indices, (2,2))
208 |
209 | return (x_skip_connection + self.main_branch(x)).relu_()
210 |
211 | class Downsampler(nn.Module):
212 | def __init__(self, in_channels, out_channels):
213 | super().__init__()
214 | self.conv = nn.Conv2d(in_channels, out_channels-in_channels, (3,3), 2, 1, bias=False)
215 | self.bn = nn.BatchNorm2d(out_channels, 1e-3)
216 | self.prelu = nn.PReLU(out_channels)
217 |
218 | def forward(self, x):
219 | x = torch.cat([F.max_pool2d(x, (2,2)), self.conv(x)], 1)
220 | x = self.bn(x)
221 | x = self.prelu(x)
222 | return x
223 |
224 | class Bottleneck(nn.Module):
225 | def __init__(self, in_channels, out_channels, dropout_prob=0.0, downsample=False,
226 | asymmetric_ksize=None, dilation=1, use_prelu=True):
227 |
228 | super().__init__()
229 | bt_channels = out_channels // 4
230 | self.downsample = downsample
231 | self.channels_to_pad = out_channels-in_channels
232 |
233 | input_stride = 2 if downsample else 1
234 |
235 | main_branch = [
236 | nn.Conv2d(in_channels, bt_channels, input_stride, input_stride, bias=False),
237 | nn.BatchNorm2d(bt_channels, 1e-3),
238 | nn.PReLU(bt_channels) if use_prelu else nn.ReLU(True)
239 | ]
240 |
241 | if asymmetric_ksize is None:
242 | main_branch += [
243 | nn.Conv2d(bt_channels, bt_channels, (3,3), 1, dilation, dilation, bias=False)
244 | ]
245 | else:
246 | assert type(asymmetric_ksize) is int
247 | ksize, padding = asymmetric_ksize, (asymmetric_ksize-1) // 2
248 | main_branch += [
249 | nn.Conv2d(bt_channels, bt_channels, (ksize,1), 1, (padding,0), bias=False),
250 | nn.Conv2d(bt_channels, bt_channels, (1,ksize), 1, (0,padding))
251 | ]
252 |
253 | main_branch += [
254 | nn.BatchNorm2d(bt_channels, 1e-3),
255 | nn.PReLU(bt_channels) if use_prelu else nn.ReLU(True),
256 | nn.Conv2d(bt_channels, out_channels, (1,1), bias=False),
257 | nn.BatchNorm2d(out_channels, 1e-3),
258 | nn.Dropout2d(dropout_prob)
259 | ]
260 |
261 | self.main_branch = nn.Sequential(*main_branch)
262 | self.output_activation = nn.PReLU(out_channels) if use_prelu else nn.ReLU(True)
263 |
264 | def forward(self, x):
265 | if self.downsample:
266 | x_skip_connection, max_indices = F.max_pool2d(x, (2,2), return_indices=True)
267 | else:
268 | x_skip_connection = x
269 |
270 | if self.channels_to_pad > 0:
271 | x_skip_connection = F.pad(x_skip_connection, (0,0, 0,0, 0,self.channels_to_pad))
272 |
273 | x = self.output_activation(x_skip_connection + self.main_branch(x))
274 |
275 | if self.downsample:
276 | return x, max_indices
277 | else:
278 | return x
279 |
280 | from box_convolution import BoxConv2d
281 |
282 | class BottleneckBoxConv(nn.Module):
283 | def __init__(self, in_channels, num_boxes, max_input_h, max_input_w,
284 | dropout_prob=0.0, reparam_factor=1.5625):
285 |
286 | super().__init__()
287 | assert in_channels % num_boxes == 0
288 | bt_channels = in_channels // num_boxes # bottleneck channels
289 |
290 | self.main_branch = nn.Sequential(
291 | nn.Conv2d(in_channels, bt_channels, (1,1), bias=False),
292 | nn.BatchNorm2d(bt_channels),
293 | nn.ReLU(True),
294 |
295 | # BEHOLD:
296 | BoxConv2d(
297 | bt_channels, num_boxes, max_input_h, max_input_w,
298 | reparametrization_factor=reparam_factor),
299 |
300 | nn.BatchNorm2d(in_channels),
301 | nn.Dropout2d(dropout_prob))
302 |
303 | def forward(self, x):
304 | return (x + self.main_branch(x)).relu_()
305 |
306 |
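# Editor's sketch (not part of the original file): a minimal smoke test of the
# residual box-convolution block above, assuming the box_convolution extension
# is built; shapes are illustrative only.
#
#     block = BottleneckBoxConv(64, 4, 128, 256, dropout_prob=0.1)
#     x = torch.rand(2, 64, 128, 256)
#     assert block(x).shape == x.shape  # the residual block preserves shape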
--------------------------------------------------------------------------------
/examples/Cityscapes/models/ERFNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class ERFNet(nn.Sequential):
6 | def __init__(self, n_classes=19):
7 | super().__init__(
8 | Downsampler( 3, 16, 0.0 ),
9 | Downsampler(16, 64, 0.03),
10 |
11 | NonBottleneck1D(64, 0.03),
12 | NonBottleneck1D(64, 0.03),
13 | NonBottleneck1D(64, 0.03),
14 | NonBottleneck1D(64, 0.03),
15 | NonBottleneck1D(64, 0.03),
16 |
17 | Downsampler(64, 128, 0.3),
18 |
19 | NonBottleneck1D(128, 0.3, 2),
20 | NonBottleneck1D(128, 0.3, 4),
21 | NonBottleneck1D(128, 0.3, 8),
22 | NonBottleneck1D(128, 0.3, 16),
23 | NonBottleneck1D(128, 0.3, 2),
24 | NonBottleneck1D(128, 0.3, 4),
25 | NonBottleneck1D(128, 0.3, 8),
26 | NonBottleneck1D(128, 0.3, 16),
27 |
28 | Upsampler(128, 64),
29 |
30 | NonBottleneck1D(64),
31 | NonBottleneck1D(64),
32 |
33 | Upsampler(64, 16),
34 |
35 | NonBottleneck1D(16),
36 | NonBottleneck1D(16),
37 |
38 | nn.ConvTranspose2d(16, n_classes+1, (3,3), 2, 1, 1))
39 |
40 | class BoxERFNet(nn.Sequential):
41 | def __init__(self, n_classes=19, max_input_h=512, max_input_w=1024):
42 | h, w = max_input_h, max_input_w # shorten names for convenience
43 |
44 | super().__init__(
45 | Downsampler( 3, 16, 0.0 ),
46 | Downsampler(16, 64, 0.03),
47 |
48 | NonBottleneck1D(64, 0.03),
49 | BottleneckBoxConv(64, 4, h // 4, w // 4, 0.03),
50 |
51 | Downsampler(64, 128, 0.3),
52 |
53 | NonBottleneck1D(128, 0.3, 2),
54 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.3),
55 | NonBottleneck1D(128, 0.3, 4),
56 |
57 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.3),
58 |
59 | NonBottleneck1D(128, 0.3, 2),
60 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.3),
61 | NonBottleneck1D(128, 0.3, 4),
62 | BottleneckBoxConv(128, 4, h // 8, w // 8, 0.3),
63 |
64 | Upsampler(128, 64),
65 |
66 | NonBottleneck1D(64),
67 |
68 | Upsampler(64, 16),
69 |
70 | NonBottleneck1D(16),
71 |
72 | nn.ConvTranspose2d(16, n_classes+1, (3,3), 2, 1, 1))
73 |
74 | def Upsampler(in_channels, out_channels):
75 | return nn.Sequential(
76 | nn.ConvTranspose2d(in_channels, out_channels, (3,3), 2, 1, 1, bias=False),
77 | nn.BatchNorm2d(out_channels),
78 | nn.ReLU(inplace=True))
79 |
80 | class Downsampler(nn.Module):
81 | def __init__(self, in_channels, out_channels, dropout_prob=0.0):
82 | super().__init__()
83 | self.conv = nn.Conv2d(in_channels, out_channels-in_channels, (3,3), 2, 1, bias=False)
84 | self.bn = nn.BatchNorm2d(out_channels)
85 | self.dropout = nn.Dropout2d(dropout_prob)
86 |
87 | def forward(self, x):
88 | x = torch.cat([F.max_pool2d(x, (2,2)), self.conv(x)], 1)
89 | x = self.bn(x)
90 | x = self.dropout(x)
91 | x = F.relu(x, inplace=True)
92 | return x
93 |
94 | class NonBottleneck1D(nn.Module):
95 | def __init__(self, in_channels, dropout_prob=0.0, dilation=1):
96 | super().__init__()
97 | dil = dilation # shorten the name for convenience
98 |
99 | self.main_branch = nn.Sequential(
100 | nn.Conv2d(in_channels, in_channels, (3,1), 1, (1,0), bias=False),
101 | nn.ReLU(True),
102 | nn.Conv2d(in_channels, in_channels, (1,3), 1, (0,1), bias=False),
103 | nn.BatchNorm2d(in_channels),
104 | nn.ReLU(True),
105 |
106 | nn.Conv2d(in_channels, in_channels, (3,1), 1, (dil,0), (dil,dil), bias=False),
107 | nn.ReLU(True),
108 | nn.Conv2d(in_channels, in_channels, (1,3), 1, (0,dil), (dil,dil), bias=False),
109 | nn.BatchNorm2d(in_channels),
110 | nn.Dropout2d(dropout_prob))
111 |
112 | def forward(self, x):
113 | return F.relu(x + self.main_branch(x), inplace=True)
114 |
115 | from box_convolution import BoxConv2d
116 |
117 | class BottleneckBoxConv(nn.Module):
118 | def __init__(self, in_channels, num_boxes, max_input_h, max_input_w, dropout_prob=0.0):
119 | super().__init__()
120 |
121 | assert in_channels % num_boxes == 0
122 | bt_channels = in_channels // num_boxes # bottleneck channels
123 |
124 | self.main_branch = nn.Sequential(
125 | nn.Conv2d(in_channels, bt_channels, (1,1), bias=False),
126 | nn.BatchNorm2d(bt_channels),
127 |
128 | # BEHOLD:
129 | BoxConv2d(
130 | bt_channels, num_boxes, max_input_h, max_input_w,
131 | reparametrization_factor=1.5625),
132 |
133 | nn.BatchNorm2d(in_channels),
134 | nn.Dropout2d(dropout_prob))
135 |
136 | def forward(self, x):
137 | return F.relu(x + self.main_branch(x), inplace=True)
138 |
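# Editor's note (not in the original file): each (3,1)+(1,3) pair in
# NonBottleneck1D factorizes a dense 3x3 convolution -- the core ERFNet trick --
# using 2 * 3 * C * C weights instead of 9 * C * C. For illustration, at C = 128:
#
#     factorized = 2 * (3 * 128 * 128)   # 98,304 weights
#     dense      = 9 * 128 * 128         # 147,456 weights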
--------------------------------------------------------------------------------
/examples/Cityscapes/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim
10 | import torch.utils.data
11 |
12 | import tensorboardX
13 |
14 | from models.ERFNet import ERFNet, BoxERFNet
15 | from models.ENet import ENet, BoxENet, BoxOnlyENet
16 |
17 | model_names = ['ENet', 'BoxENet', 'BoxOnlyENet', 'ERFNet', 'BoxERFNet']
18 |
19 | parser = argparse.ArgumentParser(
20 | description='Simple Cityscapes semantic segmentation training')
21 |
22 | parser.add_argument('data',
23 | help='path to dataset')
24 | parser.add_argument('--arch', '-a', default='BoxENet', choices=model_names,
25 | help='model architecture: ' + ' | '.join(model_names) + ' (default: BoxENet)')
26 | parser.add_argument('-j', '--workers', default=4, type=int,
27 | help='number of data loading workers (default: 4)')
28 | parser.add_argument('--epochs', default=1300, type=int,
29 | help='number of total epochs to run')
30 | parser.add_argument('-b', '--batch-size', default=12, type=int,
31 | help='mini-batch size (default: 12)')
32 |
33 | # optimizer
34 | parser.add_argument('--optimizer', default='Adam', type=str,
35 | help='optimizer type (default: Adam)')
36 | parser.add_argument('--lr', '--learning-rate', default=None, type=float, # 1e-3
37 | help='initial learning rate')
38 | parser.add_argument('--lr-decay', default=0.45, type=float,
39 | help='See --decay-steps')
40 | parser.add_argument('--decay-steps', default='', type=str, # "600,780,880,960,1040,1120,1200,1260"
41 | help='Comma-separated epoch numbers at which to multiply LR by --lr-decay')
42 | parser.add_argument('--momentum', default=None, type=float, # 0.9
43 | help='momentum')
44 | parser.add_argument('--nesterov', default=None, type=lambda s: s.lower() in ('1', 'true', 'yes'), # True
45 |     help='use Nesterov momentum? (true/false; plain type=bool would parse any non-empty string as True)')
46 | parser.add_argument('--weight-decay', '--wd', default=None, type=float, # 2e-4
47 | help='weight decay')
48 |
49 | parser.add_argument('--run-name', default=time.ctime(time.time())[4:-8], type=str,
50 |     help='name of this run; checkpoints and logs go to runs/<run-name> (default: current date and time)')
51 | parser.add_argument('--resume', action='store_true',
52 | help='resume from latest checkpoint')
53 | parser.add_argument('--start-epoch', default=0, type=int,
54 | help='Manual epoch number (useful on restarts)')
55 |
56 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
57 | help='Do not train, only evaluate model on the validation set')
58 |
59 | best_classIoU = 0
60 |
61 |
62 | def main():
63 | global args, best_classIoU
64 | args = parser.parse_args()
65 |
66 | random.seed(666)
67 | torch.manual_seed(666)
68 |
69 | from datasets import Cityscapes
70 |
71 | train_dataset = Cityscapes(
72 | args.data, split='train', size=(1024, 512), augmented=True)
73 |
74 | train_loader = torch.utils.data.DataLoader(
75 | train_dataset, batch_size=args.batch_size, shuffle=True,
76 | num_workers=args.workers, pin_memory=True)
77 | train_loader.iter = iter(train_loader)
78 |
79 | val_dataset = Cityscapes(
80 | args.data, split='val', size=(1024, 512), augmented=False)
81 |
82 | val_loader = torch.utils.data.DataLoader(
83 | val_dataset, batch_size=args.batch_size, shuffle=False,
84 | num_workers=args.workers, pin_memory=True)
85 |
86 |     # create the model and append `log_softmax` to it, so this part of the criterion runs on each GPU under DataParallel
87 | print('Architecture:', args.arch)
88 | class ModelWithLogSoftmax(globals()[args.arch]):
89 | def forward(self, x):
90 | heatmap_raw = super().forward(x)
91 | return torch.nn.functional.log_softmax(heatmap_raw, dim=1)
92 |
93 | if 'Box' in args.arch:
94 | model = ModelWithLogSoftmax(
95 | n_classes=19, max_input_h=512, max_input_w=1024)
96 | else:
97 | model = ModelWithLogSoftmax(n_classes=19)
98 | model.cuda()
99 |
100 | # define loss function (criterion) and optimizer
101 | criterion = nn.NLLLoss(weight=train_dataset.class_weights).cuda()
102 |
103 | optimizer_kwargs = {hyperparam: getattr(args, hyperparam) \
104 | for hyperparam in ('lr', 'weight_decay', 'momentum', 'nesterov') \
105 | if getattr(args, hyperparam) is not None}
106 | optimizer = torch.optim.__dict__[args.optimizer](model.parameters(), **optimizer_kwargs)
107 |
108 | # optionally resume from a checkpoint
109 | if args.resume:
110 | checkpoint_path = os.path.join('runs', args.run_name, 'model.pth')
111 | if os.path.isfile(checkpoint_path):
112 | print("=> loading checkpoint '{}'".format(checkpoint_path))
113 | checkpoint = torch.load(checkpoint_path)
114 | args.start_epoch = checkpoint['epoch']
115 | best_classIoU = checkpoint['best_classIoU']
116 |
117 | model.load_state_dict(checkpoint['state_dict'])
118 | optimizer.load_state_dict(checkpoint['optimizer'])
119 | print("=> loaded checkpoint '{}' (epoch {})"
120 | .format(checkpoint_path, checkpoint['epoch']))
121 | else:
122 | raise FileNotFoundError("=> no checkpoint found at '{}'".format(checkpoint_path))
123 |
124 | torch.backends.cudnn.benchmark = True
125 | model = torch.nn.DataParallel(model).cuda()
126 |
127 | board_writer = tensorboardX.SummaryWriter(os.path.join('runs', args.run_name))
128 |
129 | if args.evaluate:
130 | torch.backends.cudnn.deterministic = True
131 | val_loader.iter = iter(val_loader)
132 | validate(val_loader, model, criterion)
133 | return
134 |
135 | # warm up to save GPU memory
136 | #sample_input = torch.zeros(
137 | # (12, 3, 512, 1024), device='cuda', dtype=torch.float32, requires_grad=True)
138 | #(model(sample_input).sum() * 0).backward()
139 |
140 | for epoch in range(args.start_epoch, args.epochs):
141 | print('Epoch', epoch+1)
142 |
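        # editor's note: the iterator of the *other* loader is deliberately
        # (re)created before each phase, presumably so that its DataLoader
        # workers prefetch batches while the current phase is still running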
143 | val_loader.iter = iter(val_loader)
144 | # train for one epoch
145 | train_metrics = train(train_loader, model, criterion, optimizer, epoch, board_writer)
146 |
147 | train_loader.iter = iter(train_loader)
148 | # evaluate on validation set
149 | val_metrics = validate(val_loader, model, criterion)
150 |
151 | # record metrics to tensorboard
152 | for tra,val,name in zip(train_metrics, val_metrics, ('Class IoU', 'Category IoU', 'Loss')):
153 | board_writer.add_scalars(name, {'Train': tra, 'Test': val}, epoch+1)
154 |
155 | # remember best score and save checkpoint
156 | classIoU = val_metrics[0]
157 | is_best = classIoU > best_classIoU
158 | best_classIoU = max(classIoU, best_classIoU)
159 | save_checkpoint({
160 | 'epoch': epoch + 1,
161 | 'arch': args.arch,
162 | 'state_dict': model.module.state_dict(),
163 | 'best_classIoU': best_classIoU,
164 | 'optimizer' : optimizer.state_dict(),
165 | }, is_best, os.path.join('runs', args.run_name, 'model.pth'))
166 |
167 |
168 | def train(train_loader, model, criterion, optimizer, epoch, board_writer):
169 | batch_time = AverageMeter()
170 | data_time = AverageMeter()
171 | loss_meter = AverageMeter()
172 |
173 | n_classes = criterion.weight.numel()-1
174 | confusion_matrix = torch.zeros((n_classes, n_classes), device='cuda', dtype=torch.long)
175 |
176 | # switch to train mode
177 | model.train()
178 |
179 | total_batches = len(train_loader)
180 |
181 | end = time.time()
182 | for i, (input, target) in enumerate(train_loader.iter):
183 | if i > total_batches: break
184 |
185 | # measure data loading time
186 | data_time.update(time.time() - end)
187 | input = input.cuda(non_blocking=True)
188 | target = target.cuda(non_blocking=True)
189 | # compute output
190 | output = model(input)
191 | loss = criterion(output, target)
192 | # compute gradient and do SGD step
193 | global_iteration = epoch * total_batches + i
194 | adjust_learning_rate(optimizer, epoch, i, total_batches)
195 | optimizer.zero_grad()
196 | loss.backward()
197 | optimizer.step()
198 |
199 | # measure elapsed time
200 | batch_time.update(time.time() - end)
201 |
202 | with torch.no_grad():
203 | if i % 100 == 0:
204 | """import numpy as np
205 | image = input[0].cpu().numpy().transpose((1,2,0)).copy()
206 | image -= image.min()
207 | image /= image.max()
208 | image *= 255.
209 | image = image.astype(np.uint8)
210 | output_map = output[0].max(0)[1].cpu().numpy()
211 | import imgaug
212 | segmap_display = imgaug.SegmentationMapOnImage(output_map, image.shape[:2], 20).draw_on_image(image)
213 | segmap_display = segmap_display.transpose((2,0,1)).copy()
214 |
215 | board_writer.add_image('Example segmentation', segmap_display, global_iteration)"""
216 |
217 | # update confusion matrix to compute IoU
218 | output = output.max(1)[1].view(-1).to(confusion_matrix)
219 | target = target.view(-1).to(output)
220 |
221 | confusion_matrix_update = \
222 | torch.bincount(target*(n_classes+1) + output, minlength=(n_classes+1)**2)
223 | confusion_matrix += \
224 | confusion_matrix_update.view(n_classes+1, n_classes+1)[1:,1:].to(confusion_matrix)
225 |
226 | loss_meter.update(loss.item(), input.size(0))
227 |
228 | board_writer.add_scalar('Learning rate',
229 | next(iter(optimizer.param_groups))['lr'], global_iteration)
230 | board_writer.add_scalars('Time',
231 | {'Total batch time': batch_time.val,
232 | 'Data loading time': data_time.val}, global_iteration)
233 | board_writer.add_scalar('Online batch loss', loss_meter.val, global_iteration)
234 |
235 | end = time.time()
236 |
237 | classIoU, categIoU = compute_IoU(confusion_matrix.cpu())
238 | return classIoU, categIoU, loss_meter.avg
239 |
240 |
241 | def validate(val_loader, model, criterion):
242 | batch_time = AverageMeter()
243 | loss_meter = AverageMeter()
244 |
245 | # switch to evaluate mode
246 | model.eval()
247 |
248 | n_classes = criterion.weight.numel()-1
249 | confusion_matrix = torch.zeros((n_classes, n_classes), device='cuda', dtype=torch.long)
250 |
251 | with torch.no_grad():
252 | end = time.time()
253 | for i, (input, target) in enumerate(val_loader.iter):
254 | input = input.cuda(non_blocking=True)
255 | target = target.cuda(non_blocking=True)
256 |
257 | # compute output
258 | output = model(input)
259 | loss = criterion(output, target)
260 |
261 | # record loss
262 | loss_meter.update(loss.item(), input.size(0))
263 |
264 | # update confusion matrix to compute IoU
265 | output = output.max(1)[1].view(-1)
266 | target = target.view(-1).to(output)
267 |
268 | confusion_matrix_update = \
269 | torch.bincount(target*(n_classes+1) + output, minlength=(n_classes+1)**2)
270 | confusion_matrix += \
271 | confusion_matrix_update.view(n_classes+1, n_classes+1)[1:,1:].to(confusion_matrix)
272 |
273 | # measure elapsed time
274 | batch_time.update(time.time() - end)
275 | end = time.time()
276 |
277 | classIoU, categoryIoU = compute_IoU(confusion_matrix.cpu())
278 |
279 | print('Class IoU:', classIoU)
280 | print('Category IoU:', categoryIoU)
281 | return classIoU, categoryIoU, loss_meter.avg
282 |
283 |
284 | def save_checkpoint(state, is_best, filename='checkpoint.pth'):
285 | assert filename.endswith('.pth')
286 | torch.save(state, filename)
287 | if is_best:
288 | shutil.copyfile(filename, filename[:-4] + '_best.pth')
289 |
290 |
291 | class AverageMeter(object):
292 | """Computes and stores the average and current value"""
293 | def __init__(self):
294 | self.reset()
295 |
296 | def reset(self):
297 | self.val = 0
298 | self.avg = 0
299 | self.sum = 0
300 | self.count = 0
301 |
302 | def update(self, val, n=1):
303 | self.val = val
304 | self.sum += val * n
305 | self.count += n
306 | self.avg = self.sum / self.count
307 |
308 |
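# Editor's worked example (derived from adjust_learning_rate below): with
# --lr 1e-3 --lr-decay 0.45 --decay-steps 600,780 the learning rate is 1e-3
# through epoch 600, 1e-3 * 0.45 = 4.5e-4 through epoch 780, and
# 1e-3 * 0.45**2 = 2.025e-4 afterwards.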
309 | def adjust_learning_rate(optimizer, epoch, epoch_iteration=None, iters_per_epoch=None):
310 |     if args.lr is None or not args.decay_steps:
311 |         return  # schedule not configured; int('') from the default '' would raise below
312 |     step_epochs = torch.tensor(list(map(int, args.decay_steps.split(','))))
313 |     lr = args.lr * (args.lr_decay ** (epoch > step_epochs).sum().item())
312 |
313 | for param_group in optimizer.param_groups:
314 | param_group['lr'] = lr
315 |
316 |
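# Editor's worked example (not in the original): per-class IoU is
# TP / (TP + FP + FN) read off the confusion matrix; for a 2-class matrix
# [[3, 1], [2, 4]], class 0 has TP = 3, FN = 1, FP = 2, so IoU = 3 / 6 = 0.5.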
317 | def compute_IoU(confusion_matrix):
318 | n_classes = confusion_matrix.shape[0]
319 |
320 | with torch.no_grad():
321 | class_categories = [
322 | [0, 1], # flat
323 | [2, 3, 4], # construction
324 | [5, 6, 7], # object
325 | [8, 9], # nature
326 | [10], # sky
327 | [11, 12], # human
328 | [13, 14, 15, 16, 17, 18], # vehicle
329 | ]
330 |
331 | classIoU = torch.empty(n_classes, dtype=torch.float32)
332 | for class_idx in range(n_classes):
333 | TP = confusion_matrix[class_idx, class_idx].item()
334 | FN = confusion_matrix[class_idx, :].sum().item() - TP
335 | FP = confusion_matrix[:, class_idx].sum().item() - TP
336 |
337 | classIoU[class_idx] = TP / max(TP + FP + FN, 1)
338 |
339 | categoryIoU = torch.empty(len(class_categories), dtype=torch.float32)
340 | for category_idx, category in enumerate(class_categories):
341 | TP = 0
342 | for class_idx in category:
343 | TP += confusion_matrix[class_idx, category].sum().item()
344 |
345 | FN, FP = -TP, -TP
346 | for class_idx in category:
347 | FN += confusion_matrix[class_idx, :].sum().item()
348 | FP += confusion_matrix[:, class_idx].sum().item()
349 |
350 | categoryIoU[category_idx] = TP / max(TP + FP + FN, 1)
351 |
352 | return classIoU.mean(), categoryIoU.mean()
353 |
354 |
355 | if __name__ == '__main__':
356 | main()
357 |
--------------------------------------------------------------------------------
/examples/mnist.py:
--------------------------------------------------------------------------------
1 | """
2 | This script trains a very simple box convnet on MNIST.
3 | If OpenCV `videoio` is available, also outputs an animation
4 | of boxes' evolution to 'mnist-boxes.avi'.
5 | Based on https://github.com/pytorch/examples/blob/master/mnist/main.py
6 | """
7 | from __future__ import print_function
8 | import argparse
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torch.optim as optim
13 | from torchvision import datasets, transforms
14 |
15 | from box_convolution import BoxConv2d
16 |
17 | class Net(nn.Module):
18 | def __init__(self):
19 | super(Net, self).__init__()
20 | self.conv1 = BoxConv2d(1, 40, 28, 28)
21 | self.conv1_1x1 = nn.Conv2d(40, 40, 1, 1)
22 |
23 | self.fc1 = nn.Linear(7*7*40, 10)
24 |
25 | def forward(self, x):
26 | # The following line computes responses to 40 "generalized Haar filters"
27 | x = self.conv1_1x1(self.conv1(x))
28 | x = F.relu(F.max_pool2d(x, 4))
29 |
30 | x = self.fc1(x.view(-1, 7*7*40))
31 | return F.log_softmax(x, dim=1)
32 |
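# Editor's note (not in the original file): BoxConv2d(1, 40, 28, 28) learns 40
# box filters over the single input plane, and the 1x1 convolution mixes their
# responses with signed weights -- so each output channel is a weighted
# sum/difference of box averages, i.e. a "generalized Haar filter" (classic
# Haar features are fixed +1/-1 combinations of box sums).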
33 | try:
34 | import cv2
35 | box_video_resolution = (300, 300)
36 | box_video = cv2.VideoWriter(
37 | 'mnist-boxes.avi', cv2.VideoWriter_fourcc(*'MJPG'), 25, tuple(reversed(box_video_resolution)))
38 | box_video_frame_count = 0
39 |     video_background = None # to be defined in `train()`, sorry for globals and messy code
40 | except ImportError:
41 | box_video = None
42 | print('Couldn\'t import OpenCV. Will not log boxes to a video file')
43 |
44 |
45 | def train(model, device, train_loader, optimizer, epoch):
46 | model.train()
47 | for batch_idx, (data, target) in enumerate(train_loader):
48 |
49 | # log boxes to a video file
50 | if box_video is not None:
51 | global box_video_frame_count
52 |
53 | # change video background
54 | if box_video_frame_count % 5 == 0:
55 | global video_background # defined at the top for beautiful box visualization
56 | sample_idx = torch.randint(len(train_loader.dataset), (1,)).item()
57 | sample_digit = train_loader.dataset[sample_idx][0]
58 | video_background = torch.nn.functional.pad(sample_digit, (14,14,14,14))
59 | video_background = torch.nn.functional.interpolate(
60 | video_background.unsqueeze(0), size=box_video_resolution, mode='nearest')[0,0]
61 | video_background = video_background.unsqueeze(-1).repeat(1, 1, 3)
62 | video_background = video_background.mul(255).round().byte().numpy()
63 |
64 | # log boxes to the video file
65 | if batch_idx % 5 == 0:
66 | box_importances = model.conv1_1x1.weight.detach().float().abs().max(0)[0].squeeze()
67 | box_importances /= box_importances.max()
68 | boxes_plot = model.conv1.draw_boxes(
69 | resolution=box_video_resolution, weights=box_importances)
70 | box_video.write(cv2.addWeighted(boxes_plot, 1.0, video_background, 0.25, 0.0))
71 | box_video_frame_count += 1
72 |
73 | data, target = data.to(device), target.to(device)
74 | optimizer.zero_grad()
75 | output = model(data)
76 | loss = F.nll_loss(output, target)
77 | loss.backward()
78 | optimizer.step()
79 | if batch_idx % 100 == 0:
80 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
81 | epoch, batch_idx * len(data), len(train_loader.dataset),
82 | 100. * batch_idx / len(train_loader), loss.item()))
83 |
84 | for g in optimizer.param_groups:
85 | g['lr'] *= 0.999
86 |
87 | def test(model, device, test_loader):
88 | model.eval()
89 | test_loss = 0
90 | correct = 0
91 | with torch.no_grad():
92 | for data, target in test_loader:
93 | data, target = data.to(device), target.to(device)
94 | output = model(data)
95 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
96 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
97 | correct += pred.eq(target.view_as(pred)).sum().item()
98 |
99 | test_loss /= len(test_loader.dataset)
100 |
101 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
102 | test_loss, correct, len(test_loader.dataset),
103 | 100. * correct / len(test_loader.dataset)))
104 |
105 | def main():
106 | # Training settings
107 | use_cuda = torch.cuda.is_available()
108 | batch_size = 64
109 | n_epochs = 10
110 |
111 | torch.manual_seed(666)
112 |
113 | device = torch.device('cuda' if use_cuda else 'cpu')
114 |
115 | mnist_train = datasets.MNIST(
116 | './', train=True, download=True, transform=transforms.ToTensor())
117 | mnist_test = datasets.MNIST(
118 | './', train=False, transform=transforms.ToTensor())
119 |
120 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
121 | train_loader = torch.utils.data.DataLoader(
122 | mnist_train, batch_size=batch_size, shuffle=True, **kwargs)
123 | test_loader = torch.utils.data.DataLoader(
124 | mnist_test, batch_size=batch_size, shuffle=True, **kwargs)
125 |
126 | model = Net().to(device)
127 | optimizer = optim.Adam(model.parameters(), lr=1e-3)
128 |
129 | for epoch in range(1, n_epochs + 1):
130 | train(model, device, train_loader, optimizer, epoch)
131 | test(model, device, test_loader)
132 |
133 | if __name__ == '__main__':
134 | main()
135 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # monkey-patch for parallel compilation
2 | # https://stackoverflow.com/questions/11013851/speeding-up-build-process-with-distutils
3 | def parallelCCompile(
4 | self, sources, output_dir=None, macros=None, include_dirs=None, debug=0,
5 | extra_preargs=None, extra_postargs=None, depends=None):
6 |
7 | # those lines are copied from distutils.ccompiler.CCompiler directly
8 | macros, objects, extra_postargs, pp_opts, build = \
9 | self._setup_compile(output_dir, macros, include_dirs, sources, depends, extra_postargs)
10 | cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)
11 |
12 | # parallel code
13 | from multiprocessing import cpu_count
14 | try:
15 | n_processes = cpu_count() # number of parallel compilations
16 | except NotImplementedError:
17 | print('multiprocessing.cpu_count() failed, building on 1 core')
18 | n_processes = 1
19 |
20 | def _single_compile(obj):
21 | try: src, ext = build[obj]
22 | except KeyError: return
23 | self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)
24 |
25 | import multiprocessing.pool
26 | multiprocessing.pool.ThreadPool(n_processes).map(_single_compile, objects)
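    # editor's note: a thread pool suffices here because each _compile call
    # spends nearly all of its time waiting on an external compiler process,
    # so the GIL does not serialize the builds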
27 |
28 | return objects
29 |
30 | import distutils.ccompiler
31 | distutils.ccompiler.CCompiler.compile = parallelCCompile
32 |
33 | import torch
34 |
35 | build_cuda = torch.cuda.is_available() # TODO allow cross-compiling too
36 |
37 | source_root = 'src'
38 | source_files_cpp = [
39 | 'integral_image_interface.cpp',
40 | 'integral_image.cpp',
41 | 'box_convolution_interface.cpp',
42 | 'box_convolution.cpp',
43 | 'bind.cpp'
44 | ]
45 | source_files_cuda = [
46 | 'integral_image_cuda.cu',
47 | 'box_convolution_cuda_forward.cu',
48 | 'box_convolution_cuda_backward.cu',
49 | 'box_convolution_cuda_misc.cu'
50 | ]
51 | source_files_cuda_stubs = [
52 | 'cuda_stubs.cpp'
53 | ]
54 | source_files = source_files_cpp + (source_files_cuda if build_cuda else source_files_cuda_stubs)
55 |
56 | from torch.utils.cpp_extension import CppExtension, CUDAExtension
57 | import os
58 |
59 | extra_compile_args = {'cxx': [], 'nvcc': []}
60 | if os.getenv('CC'):
61 | # temporary hack to allow choosing a different host compiler for NVCC too
62 | extra_compile_args['nvcc'] += ['-ccbin', os.getenv('CC')]
63 |
64 | cpp_cuda = (CUDAExtension if build_cuda else CppExtension)(
65 | name='box_convolution_cpp_cuda',
66 | sources=[os.path.join(source_root, file) for file in source_files],
67 | include_dirs=[source_root],
68 | extra_compile_args=extra_compile_args
69 | )
70 |
71 | from setuptools import setup
72 |
73 | setup(
74 | name='box_convolution',
75 | packages=['box_convolution'],
76 | ext_modules=[cpp_cuda],
77 | cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension},
78 | install_requires=['future', 'torch>=1.0.0a0']
79 | )
80 |
--------------------------------------------------------------------------------
/src/bind.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | at::Tensor integral_image(
4 | at::Tensor input);
5 |
6 | at::Tensor box_convolution_forward(
7 | at::Tensor input_integrated,
8 | at::Tensor x_min, at::Tensor x_max,
9 | at::Tensor y_min, at::Tensor y_max,
10 | const bool normalize, const bool exact);
11 |
12 | std::vector<at::Tensor> box_convolution_backward(
13 | at::Tensor input_integrated,
14 | at::Tensor x_min, at::Tensor x_max,
15 | at::Tensor y_min, at::Tensor y_max,
16 | at::Tensor grad_output, at::Tensor output,
17 | const float reparametrization_h, const float reparametrization_w,
18 | const bool normalize, const bool exact,
19 | const bool input_needs_grad,
20 | const bool x_min_needs_grad, const bool x_max_needs_grad,
21 | const bool y_min_needs_grad, const bool y_max_needs_grad);
22 |
23 | void clip_parameters(
24 | at::Tensor x_min, at::Tensor x_max,
25 | at::Tensor y_min, at::Tensor y_max,
26 | const double reparametrization_h, const double reparametrization_w,
27 | const double max_input_h, const double max_input_w, const bool exact);
28 |
29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
30 | m.def("integral_image", &integral_image, "Integral image");
31 | m.def("box_convolution_forward" , &box_convolution_forward , "Box convolution, forward" );
32 | m.def("box_convolution_backward", &box_convolution_backward, "Box convolution, backward");
33 | m.def("clip_parameters", &clip_parameters, "Box convolution, clip parameters");
34 | }
35 |
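// Editor's sketch (not in the original file): once built by setup.py, these
// bindings are importable from Python under the extension name declared there;
// illustrative only:
//
//     import torch, box_convolution_cpp_cuda as ext
//     ii = ext.integral_image(torch.rand(2, 3, 5, 5))  # -> shape (2, 3, 6, 6)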
--------------------------------------------------------------------------------
/src/box_convolution.cpp:
--------------------------------------------------------------------------------
1 | /* CPU implementations of the functions that actually operate on tensor data
2 |  * at a low level. Used by those in `box_convolution_interface.cpp`. */
3 |
4 | #include <torch/extension.h>
5 |
6 | using std::min;
7 | using std::max;
8 |
9 | #include "box_convolution.h" // for `enum class Parameter`
10 |
11 | namespace cpu {
12 |
13 | // Splits x_min, x_max, y_min, y_max into integer and fractional parts
14 | void splitParameters(
15 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max ,
16 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
17 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac) {
18 |
19 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_min.scalar_type(), "cpu::splitParameters", ([&] {
20 | scalar_t minInt, maxInt;
21 |
22 | for (int i = 0; i < x_min.numel(); ++i) {
23 |             minInt = std::ceil(x_min.data_ptr<scalar_t>()[i]);
24 |             xMinFrac.data_ptr<scalar_t>()[i] = minInt - x_min.data_ptr<scalar_t>()[i];
25 |             xMinInt.data_ptr<int>()[i] = static_cast<int>(minInt);
26 | 
27 |             minInt = std::ceil(y_min.data_ptr<scalar_t>()[i]);
28 |             yMinFrac.data_ptr<scalar_t>()[i] = minInt - y_min.data_ptr<scalar_t>()[i];
29 |             yMinInt.data_ptr<int>()[i] = static_cast<int>(minInt);
30 | 
31 |             maxInt = std::floor(x_max.data_ptr<scalar_t>()[i]);
32 |             xMaxFrac.data_ptr<scalar_t>()[i] = x_max.data_ptr<scalar_t>()[i] - maxInt;
33 |             xMaxInt.data_ptr<int>()[i] = static_cast<int>(maxInt) + 1;
34 | 
35 |             maxInt = std::floor(y_max.data_ptr<scalar_t>()[i]);
36 |             yMaxFrac.data_ptr<scalar_t>()[i] = y_max.data_ptr<scalar_t>()[i] - maxInt;
37 |             yMaxInt.data_ptr<int>()[i] = static_cast<int>(maxInt) + 1;
38 | }
39 | }));
40 | }
41 |
42 | // A special parameter split for the backward pass w.r.t. the input
43 | void splitParametersUpdateGradInput(
44 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max ,
45 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
46 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac) {
47 |
48 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_min.scalar_type(), "cpu::splitParametersUpdateGradInput", ([&] {
49 | scalar_t minInt, maxInt;
50 |
51 | for (int i = 0; i < x_min.numel(); ++i) {
52 |             minInt = std::ceil(-x_max.data_ptr<scalar_t>()[i]);
53 |             xMinFrac.data_ptr<scalar_t>()[i] = minInt + x_max.data_ptr<scalar_t>()[i];
54 |             xMinInt.data_ptr<int>()[i] = static_cast<int>(minInt);
55 | 
56 |             minInt = std::ceil(-y_max.data_ptr<scalar_t>()[i]);
57 |             yMinFrac.data_ptr<scalar_t>()[i] = minInt + y_max.data_ptr<scalar_t>()[i];
58 |             yMinInt.data_ptr<int>()[i] = static_cast<int>(minInt);
59 | 
60 |             maxInt = std::floor(-x_min.data_ptr<scalar_t>()[i]) + 1;
61 |             xMaxFrac.data_ptr<scalar_t>()[i] = -x_min.data_ptr<scalar_t>()[i] + 1 - maxInt;
62 |             xMaxInt.data_ptr<int>()[i] = static_cast<int>(maxInt);
63 | 
64 |             maxInt = std::floor(-y_min.data_ptr<scalar_t>()[i]) + 1;
65 |             yMaxFrac.data_ptr<scalar_t>()[i] = -y_min.data_ptr<scalar_t>()[i] + 1 - maxInt;
66 |             yMaxInt.data_ptr<int>()[i] = static_cast<int>(maxInt);
67 | }
68 | }));
69 | }
70 |
71 | // A special parameter split for the backward pass w.r.t. x_min, x_max, y_min, y_max
72 | void splitParametersAccGradParameters(
73 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max ,
74 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
75 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac) {
76 |
77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_min.scalar_type(), "cpu::splitParametersAccGradParams", ([&] {
78 | scalar_t minInt, maxInt;
79 |
80 | for (int i = 0; i < x_min.numel(); ++i) {
81 |             minInt = std::ceil(x_min.data_ptr<scalar_t>()[i] - 1);
82 |             xMinFrac.data_ptr<scalar_t>()[i] = minInt - x_min.data_ptr<scalar_t>()[i] + 1;
83 |             xMinInt.data_ptr<int>()[i] = static_cast<int>(minInt);
84 | 
85 |             minInt = std::ceil(y_min.data_ptr<scalar_t>()[i] - 1);
86 |             yMinFrac.data_ptr<scalar_t>()[i] = minInt - y_min.data_ptr<scalar_t>()[i] + 1;
87 |             yMinInt.data_ptr<int>()[i] = static_cast<int>(minInt);
88 | 
89 |             maxInt = std::floor(x_max.data_ptr<scalar_t>()[i]);
90 |             xMaxFrac.data_ptr<scalar_t>()[i] = x_max.data_ptr<scalar_t>()[i] - maxInt;
91 |             xMaxInt.data_ptr<int>()[i] = static_cast<int>(maxInt);
92 | 
93 |             maxInt = std::floor(y_max.data_ptr<scalar_t>()[i]);
94 |             yMaxFrac.data_ptr<scalar_t>()[i] = y_max.data_ptr<scalar_t>()[i] - maxInt;
95 |             yMaxInt.data_ptr<int>()[i] = static_cast<int>(maxInt);
96 | }
97 | }));
98 | }
99 |
100 | template <bool normalize, bool exact>
101 | void boxConvUpdateOutput(
102 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
103 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac,
104 | at::Tensor & area, at::Tensor & input_integrated, at::Tensor & output) {
105 |
106 | // was `const int`, but had to remove `const` to work around a bug in GCC 5
107 | int h = output.size(-2);
108 | int w = output.size(-1);
109 |
110 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(output.scalar_type(), "cpu::boxConvUpdateOutput", ([&] {
111 |         auto xMinIntAcsr = xMinInt.accessor<int, 2>();
112 |         auto xMaxIntAcsr = xMaxInt.accessor<int, 2>();
113 |         auto yMinIntAcsr = yMinInt.accessor<int, 2>();
114 |         auto yMaxIntAcsr = yMaxInt.accessor<int, 2>();
115 | 
116 |         auto xMinFracAcsr = xMinFrac.accessor<scalar_t, 2>();
117 |         auto xMaxFracAcsr = xMaxFrac.accessor<scalar_t, 2>();
118 |         auto yMinFracAcsr = yMinFrac.accessor<scalar_t, 2>();
119 |         auto yMaxFracAcsr = yMaxFrac.accessor<scalar_t, 2>();
120 |
121 | auto areaAcsr = xMinFracAcsr; // because there's no default ctor :(
122 | // only initialize the accessor if `area` is defined (errors otherwise)
123 | if (normalize) {
124 |             areaAcsr = area.accessor<scalar_t, 2>();
125 | }
126 |
127 |         scalar_t *outputData = output.data_ptr<scalar_t>();
128 |
129 | for (int batchIdx = 0; batchIdx < input_integrated.size(0); ++batchIdx) {
130 | for (int inPlaneIdx = 0; inPlaneIdx < input_integrated.size(1); ++inPlaneIdx) {
131 | auto inputIntPlane = input_integrated[batchIdx][inPlaneIdx];
132 |                 auto inputIntAcsr = inputIntPlane.accessor<scalar_t, 2>();
133 |
134 | for (int filterIdx = 0; filterIdx < xMinInt.size(1); ++filterIdx) {
135 |
136 | // TODO make a separate loop for each 2D array access?
137 | for (int x = 0; x < h; ++x) {
138 | for (int y = 0; y < w; ++y) {
139 | const int xMinCurr = xMinIntAcsr[inPlaneIdx][filterIdx];
140 | const int xMaxCurr = xMaxIntAcsr[inPlaneIdx][filterIdx];
141 | const int yMinCurr = yMinIntAcsr[inPlaneIdx][filterIdx];
142 | const int yMaxCurr = yMaxIntAcsr[inPlaneIdx][filterIdx];
143 |
144 | const scalar_t xMinCurrFrac = xMinFracAcsr[inPlaneIdx][filterIdx];
145 | const scalar_t xMaxCurrFrac = xMaxFracAcsr[inPlaneIdx][filterIdx];
146 | const scalar_t yMinCurrFrac = yMinFracAcsr[inPlaneIdx][filterIdx];
147 | const scalar_t yMaxCurrFrac = yMaxFracAcsr[inPlaneIdx][filterIdx];
148 |
149 | // Must add 1 to xMax/yMax/xMin/yMin due to OpenCV's
150 | // `integral()` behavior. Namely, I(x,0) and I(0,y) are
151 | // always 0 (so it's a C-style array sum).
152 |
153 | // However, when computing sums, we subtract values at points
154 | // like y+yMin-1 and x+xMin-1, so we also SUBTRACT 1 from xMin
155 | // and yMin, and thus finally they are not affected.
156 |
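                            // editor's note: with the zero-padded integral image
                            // I(x, y) = sum_{i < x, j < y} input(i, j), the sum of
                            // `input` over the half-open box [t, b) x [l, r) is
                            //     I[b][r] - I[t][r] - I[b][l] + I[t][l],
                            // which is exactly the "main area" term computed below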
157 | const int t = max(0, min(x+xMinCurr, h));
158 | const int b = max(0, min(x+xMaxCurr, h));
159 | const int l = max(0, min(y+yMinCurr, w));
160 | const int r = max(0, min(y+yMaxCurr, w));
161 |
162 | const int bAdv = max(0, min(x+xMaxCurr+1, h));
163 | const int rAdv = max(0, min(y+yMaxCurr+1, w));
164 | const int tAdv = max(0, min(x+xMinCurr-1, h));
165 | const int lAdv = max(0, min(y+yMinCurr-1, w));
166 |
167 | scalar_t outValue;
168 |
169 | // -- main area
170 | outValue =
171 | inputIntAcsr[b][r]
172 | - inputIntAcsr[t][r]
173 | - inputIntAcsr[b][l]
174 | + inputIntAcsr[t][l];
175 |
176 | if (exact) {
177 | // -- xMax border
178 | outValue +=
179 | ( inputIntAcsr[bAdv][r]
180 | - inputIntAcsr[b ][r]
181 | - inputIntAcsr[bAdv][l]
182 | + inputIntAcsr[b ][l]) * xMaxCurrFrac;
183 |
184 | // -- yMax border
185 | outValue +=
186 | ( inputIntAcsr[b][rAdv]
187 | - inputIntAcsr[b][r ]
188 | - inputIntAcsr[t][rAdv]
189 | + inputIntAcsr[t][r ]) * yMaxCurrFrac;
190 |
191 | // -- xMin border
192 | outValue +=
193 | ( inputIntAcsr[t ][r]
194 | - inputIntAcsr[tAdv][r]
195 | - inputIntAcsr[t ][l]
196 | + inputIntAcsr[tAdv][l]) * xMinCurrFrac;
197 |
198 | // -- yMin border
199 | outValue +=
200 | ( inputIntAcsr[b][l ]
201 | - inputIntAcsr[b][lAdv]
202 | - inputIntAcsr[t][l ]
203 | + inputIntAcsr[t][lAdv]) * yMinCurrFrac;
204 |
205 | // -- corner pixels
206 | // Note: before, I used plain `input` to access corner values
207 | // with lower memory overhead. Moved to `input_integrated`
208 | // to get rid of an extra input to this function.
209 |
210 | if (not ((x+xMaxCurr >= h) | (y+yMaxCurr >= w) |
211 | (x+xMaxCurr < 0) | (y+yMaxCurr < 0))) {
212 | outValue +=
213 | xMaxCurrFrac * yMaxCurrFrac *
214 | ( inputIntAcsr[b+1][r+1]
215 | - inputIntAcsr[b ][r+1]
216 | - inputIntAcsr[b+1][r ]
217 | + inputIntAcsr[b ][r ]);
218 | }
219 |
220 | if (not ((x+xMinCurr > h) | (y+yMaxCurr >= w) |
221 | (x+xMinCurr <= 0) | (y+yMaxCurr < 0))) {
222 | outValue +=
223 | xMinCurrFrac * yMaxCurrFrac *
224 | ( inputIntAcsr[t ][r+1]
225 | - inputIntAcsr[t-1][r+1]
226 | - inputIntAcsr[t ][r ]
227 | + inputIntAcsr[t-1][r ]);
228 | }
229 |
230 | if (not ((x+xMaxCurr >= h) | (y+yMinCurr > w) |
231 | (x+xMaxCurr < 0) | (y+yMinCurr <= 0))) {
232 | outValue +=
233 | xMaxCurrFrac * yMinCurrFrac *
234 | ( inputIntAcsr[b+1][l ]
235 | - inputIntAcsr[b ][l ]
236 | - inputIntAcsr[b+1][l-1]
237 | + inputIntAcsr[b ][l-1]);
238 | }
239 |
240 | if (not ((x+xMinCurr > h) | (y+yMinCurr > w) |
241 | (x+xMinCurr <= 0) | (y+yMinCurr <= 0))) {
242 | outValue +=
243 | xMinCurrFrac * yMinCurrFrac *
244 | ( inputIntAcsr[t ][l ]
245 | - inputIntAcsr[t-1][l ]
246 | - inputIntAcsr[t ][l-1]
247 | + inputIntAcsr[t-1][l-1]);
248 | }
249 | }
250 |
251 | *(outputData++) = outValue *
252 | (normalize ?
253 | areaAcsr[inPlaneIdx][filterIdx] :
254 |                             static_cast<scalar_t>(1));
255 | }
256 | }
257 | } // filterIdx
258 | } // inPlaneIdx
259 | } // batchIdx
260 | }));
261 | }
262 |
263 | // explicitly instantiate
264 | template void boxConvUpdateOutput<true, true>(
265 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
266 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
267 |     at::Tensor &, at::Tensor &, at::Tensor &);
268 | 
269 | template void boxConvUpdateOutput<true, false>(
270 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
271 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
272 |     at::Tensor &, at::Tensor &, at::Tensor &);
273 | 
274 | template void boxConvUpdateOutput<false, true>(
275 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
276 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
277 |     at::Tensor &, at::Tensor &, at::Tensor &);
278 | 
279 | template void boxConvUpdateOutput<false, false>(
280 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
281 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
282 |     at::Tensor &, at::Tensor &, at::Tensor &);
283 |
284 | // `grad_output_integrated` size: {batchSize, nInputPlanes, numFilters, h+1, w+1}
285 | // `tmpArray` size: {batchSize, nInputPlanes, numFilters, h, w}
286 | template <bool normalize, bool exact>
287 | void boxConvUpdateGradInput(
288 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
289 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac,
290 | at::Tensor & area, at::Tensor & grad_output_integrated, at::Tensor & tmpArray) {
291 |
292 | // was `const int`, but had to remove `const` to work around a bug in GCC 5
293 | int h = tmpArray.size(-2);
294 | int w = tmpArray.size(-1);
295 |
296 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tmpArray.scalar_type(), "cpu::boxConvUpdateGradInput", ([&] {
297 |
298 |         auto xMinIntAcsr = xMinInt.accessor<int, 2>();
299 |         auto xMaxIntAcsr = xMaxInt.accessor<int, 2>();
300 |         auto yMinIntAcsr = yMinInt.accessor<int, 2>();
301 |         auto yMaxIntAcsr = yMaxInt.accessor<int, 2>();
302 | 
303 |         auto xMinFracAcsr = xMinFrac.accessor<scalar_t, 2>();
304 |         auto xMaxFracAcsr = xMaxFrac.accessor<scalar_t, 2>();
305 |         auto yMinFracAcsr = yMinFrac.accessor<scalar_t, 2>();
306 |         auto yMaxFracAcsr = yMaxFrac.accessor<scalar_t, 2>();
307 |
308 | auto areaAcsr = xMinFracAcsr; // because there's no default ctor :(
309 | // only initialize the accessor if `area` is defined (errors otherwise)
310 | if (normalize) {
311 |             areaAcsr = area.accessor<scalar_t, 2>();
312 | }
313 |
314 |         scalar_t *tmpArrayData = tmpArray.data_ptr<scalar_t>();
315 |
316 | for (int batchIdx = 0; batchIdx < grad_output_integrated.size(0); ++batchIdx) {
317 | for (int inPlaneIdx = 0; inPlaneIdx < grad_output_integrated.size(1); ++inPlaneIdx) {
318 | for (int filterIdx = 0; filterIdx < xMinInt.size(1); ++filterIdx) {
319 |
320 | const int xMinCurr = xMinIntAcsr[inPlaneIdx][filterIdx];
321 | const int xMaxCurr = xMaxIntAcsr[inPlaneIdx][filterIdx];
322 | const int yMinCurr = yMinIntAcsr[inPlaneIdx][filterIdx];
323 | const int yMaxCurr = yMaxIntAcsr[inPlaneIdx][filterIdx];
324 |
325 | const scalar_t xMinCurrFrac = xMinFracAcsr[inPlaneIdx][filterIdx];
326 | const scalar_t xMaxCurrFrac = xMaxFracAcsr[inPlaneIdx][filterIdx];
327 | const scalar_t yMinCurrFrac = yMinFracAcsr[inPlaneIdx][filterIdx];
328 | const scalar_t yMaxCurrFrac = yMaxFracAcsr[inPlaneIdx][filterIdx];
329 |
330 | auto gradOutputIntPlane =
331 | grad_output_integrated[batchIdx][inPlaneIdx][filterIdx];
332 |                     auto gradOutputAcsr = gradOutputIntPlane.accessor<scalar_t, 2>();
333 |
334 | for (int x = 0; x < h; ++x) {
335 | for (int y = 0; y < w; ++y) {
336 |
337 | const int t = max(0, min(x+xMinCurr, h));
338 | const int b = max(0, min(x+xMaxCurr, h));
339 | const int l = max(0, min(y+yMinCurr, w));
340 | const int r = max(0, min(y+yMaxCurr, w));
341 |
342 | const int tAdv = x+xMinCurr-1 < h ? max(0, min(t-1, h)) : t;
343 | const int bAdv = x+xMaxCurr >= 0 ? max(0, min(b+1, h)) : b;
344 | const int lAdv = y+yMinCurr-1 < w ? max(0, min(l-1, w)) : l;
345 | const int rAdv = y+yMaxCurr >= 0 ? max(0, min(r+1, w)) : r;
346 |
347 | scalar_t outValue;
348 |
349 | outValue =
350 | gradOutputAcsr[b][r]
351 | - gradOutputAcsr[t][r]
352 | - gradOutputAcsr[b][l]
353 | + gradOutputAcsr[t][l];
354 |
355 | if (exact) {
356 | // -- xMax border
357 | outValue +=
358 | ( gradOutputAcsr[bAdv][r]
359 | - gradOutputAcsr[b ][r]
360 | - gradOutputAcsr[bAdv][l]
361 | + gradOutputAcsr[b ][l]
362 | ) * xMaxCurrFrac;
363 |
364 | // -- yMax border
365 | outValue +=
366 | ( gradOutputAcsr[b][rAdv]
367 | - gradOutputAcsr[b][r ]
368 | - gradOutputAcsr[t][rAdv]
369 | + gradOutputAcsr[t][r ]
370 | ) * yMaxCurrFrac;
371 |
372 | // -- xMin border
373 | outValue +=
374 | ( gradOutputAcsr[t ][r]
375 | - gradOutputAcsr[tAdv][r]
376 | - gradOutputAcsr[t ][l]
377 | + gradOutputAcsr[tAdv][l]
378 | ) * xMinCurrFrac;
379 |
380 | // -- yMin border
381 | outValue +=
382 | ( gradOutputAcsr[b][l ]
383 | - gradOutputAcsr[b][lAdv]
384 | - gradOutputAcsr[t][l ]
385 | + gradOutputAcsr[t][lAdv]
386 | ) * yMinCurrFrac;
387 |
388 | // -- corner pixels
389 | outValue +=
390 | xMaxCurrFrac*yMaxCurrFrac * (
391 | (x+xMaxCurr >= h or
392 | y+yMaxCurr >= w or
393 | x+xMaxCurr < 0 or
394 | y+yMaxCurr < 0 or
395 | b == bAdv or
396 |                                  r == rAdv) ? static_cast<scalar_t>(0) :
397 |
398 | ( gradOutputAcsr[b+1][r+1]
399 | - gradOutputAcsr[b ][r+1]
400 | - gradOutputAcsr[b+1][r ]
401 | + gradOutputAcsr[b ][r ]));
402 |
403 | outValue +=
404 | xMinCurrFrac*yMaxCurrFrac * (
405 | (x+xMinCurr > h or
406 | y+yMaxCurr >= w or
407 | x+xMinCurr <= 0 or
408 | y+yMaxCurr < 0 or
409 | t == tAdv or
410 |                                  r == rAdv) ? static_cast<scalar_t>(0) :
411 |
412 | ( gradOutputAcsr[tAdv+1][r+1]
413 | - gradOutputAcsr[tAdv+1][r ]
414 | - gradOutputAcsr[tAdv ][r+1]
415 | + gradOutputAcsr[tAdv ][r ]));
416 |
417 | outValue +=
418 | xMaxCurrFrac*yMinCurrFrac * (
419 | (x+xMaxCurr >= h or
420 | y+yMinCurr > w or
421 | x+xMaxCurr < 0 or
422 | y+yMinCurr <= 0 or
423 | b == bAdv or
424 |                                  l == lAdv) ? static_cast<scalar_t>(0) :
425 |
426 | ( gradOutputAcsr[b+1][lAdv+1]
427 | - gradOutputAcsr[b ][lAdv+1]
428 | - gradOutputAcsr[b+1][lAdv ]
429 | + gradOutputAcsr[b ][lAdv ]));
430 |
431 | outValue +=
432 | xMinCurrFrac*yMinCurrFrac * (
433 | (x+xMinCurr > h or
434 | y+yMinCurr > w or
435 | x+xMinCurr <= 0 or
436 | y+yMinCurr <= 0 or
437 | t == tAdv or
438 |                                  l == lAdv) ? static_cast<scalar_t>(0) :
439 |
440 | ( gradOutputAcsr[tAdv+1][lAdv+1]
441 | - gradOutputAcsr[tAdv+1][lAdv ]
442 | - gradOutputAcsr[tAdv ][lAdv+1]
443 | + gradOutputAcsr[tAdv ][lAdv ]));
444 | }
445 |
446 | *(tmpArrayData++) = outValue *
447 | (normalize ?
448 | areaAcsr[inPlaneIdx][filterIdx] :
449 |                             static_cast<scalar_t>(1));
450 | }
451 | }
452 | } // filterIdx
453 | } // inPlaneIdx
454 | } // batchIdx
455 | }));
456 | }
457 |
458 | // explicitly instantiate
459 | template void boxConvUpdateGradInput<true, true>(
460 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
461 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
462 |     at::Tensor &, at::Tensor &, at::Tensor &);
463 | 
464 | template void boxConvUpdateGradInput<true, false>(
465 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
466 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
467 |     at::Tensor &, at::Tensor &, at::Tensor &);
468 | 
469 | template void boxConvUpdateGradInput<false, true>(
470 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
471 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
472 |     at::Tensor &, at::Tensor &, at::Tensor &);
473 | 
474 | template void boxConvUpdateGradInput<false, false>(
475 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
476 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
477 |     at::Tensor &, at::Tensor &, at::Tensor &);
478 |
479 |
480 | template <bool exact>
481 | void boxConvAccGradParameters(
482 | // tmpArray size: {batchSize, nInputPlanes, numFilters, h, w}
483 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
484 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac,
485 | at::Tensor & input_integrated, at::Tensor & tmpArray, Parameter parameter) {
486 |
487 | // was `const int`, but had to remove `const` to work around a bug in GCC 5
488 | int h = tmpArray.size(-2);
489 | int w = tmpArray.size(-1);
490 |
491 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tmpArray.scalar_type(), "cpu::boxConvAccGradParameters", ([&] {
492 |
493 |         auto xMinIntAcsr = xMinInt.accessor<int, 2>();
494 |         auto xMaxIntAcsr = xMaxInt.accessor<int, 2>();
495 |         auto yMinIntAcsr = yMinInt.accessor<int, 2>();
496 |         auto yMaxIntAcsr = yMaxInt.accessor<int, 2>();
497 | 
498 |         auto xMinFracAcsr = xMinFrac.accessor<scalar_t, 2>();
499 |         auto xMaxFracAcsr = xMaxFrac.accessor<scalar_t, 2>();
500 |         auto yMinFracAcsr = yMinFrac.accessor<scalar_t, 2>();
501 |         auto yMaxFracAcsr = yMaxFrac.accessor<scalar_t, 2>();
502 |
503 |         scalar_t *tmpArrayData = tmpArray.data_ptr<scalar_t>();
504 |
505 | for (int batchIdx = 0; batchIdx < input_integrated.size(0); ++batchIdx) {
506 | for (int inPlaneIdx = 0; inPlaneIdx < input_integrated.size(1); ++inPlaneIdx) {
507 |
508 | auto inputIntPlane =
509 | input_integrated[batchIdx][inPlaneIdx];
510 |                 auto inputIntAcsr = inputIntPlane.accessor<scalar_t, 2>();
511 |
512 | for (int filterIdx = 0; filterIdx < xMinInt.size(1); ++filterIdx) {
513 |
514 | const int xMinInt = xMinIntAcsr[inPlaneIdx][filterIdx];
515 | const int xMaxInt = xMaxIntAcsr[inPlaneIdx][filterIdx];
516 | const int yMinInt = yMinIntAcsr[inPlaneIdx][filterIdx];
517 | const int yMaxInt = yMaxIntAcsr[inPlaneIdx][filterIdx];
518 |
519 | const scalar_t xMinFrac = xMinFracAcsr[inPlaneIdx][filterIdx];
520 | const scalar_t xMaxFrac = xMaxFracAcsr[inPlaneIdx][filterIdx];
521 | const scalar_t yMinFrac = yMinFracAcsr[inPlaneIdx][filterIdx];
522 | const scalar_t yMaxFrac = yMaxFracAcsr[inPlaneIdx][filterIdx];
523 |
524 | for (int x = 1; x <= h; ++x) {
525 | for (int y = 1; y <= w; ++y) {
526 |
527 | if (parameter == Parameter::xMin) {
528 | // Had to move 3 following lines into the
529 | // `if` to ensure loop unswitching
530 | int valid;
531 | int cornerX, cornerY;
532 |
533 | scalar_t delta = 0;
534 |
535 | if (exact) {
536 | // TODO maybe use `input` instead of `inputInt`
537 | valid =
538 | not (y+yMinInt < 1) & not (y+yMinInt > w) & not (x+xMinInt < 1);
539 | cornerX = max(0,min(h-1,x+xMinInt-1));
540 | cornerY = max(0,min(w-1,y+yMinInt-1));
541 | const scalar_t tlCorner = valid *
542 | ( inputIntAcsr[cornerX+1][cornerY+1]
543 | - inputIntAcsr[cornerX ][cornerY+1]
544 | - inputIntAcsr[cornerX+1][cornerY ]
545 | + inputIntAcsr[cornerX ][cornerY ]);
546 |
547 | valid =
548 | not (y+yMaxInt < 0) & not (y+yMaxInt >= w) & not (x+xMinInt < 1);
549 | cornerX = max(0,min(h-1,x+xMinInt-1));
550 | cornerY = max(0,min(w-1,y+yMaxInt ));
551 | const scalar_t trCorner = valid *
552 | ( inputIntAcsr[cornerX+1][cornerY+1]
553 | - inputIntAcsr[cornerX ][cornerY+1]
554 | - inputIntAcsr[cornerX+1][cornerY ]
555 | + inputIntAcsr[cornerX ][cornerY ]);
556 |
557 | delta += trCorner * yMaxFrac;
558 | delta += tlCorner * yMinFrac;
559 | } // if (exact)
560 |
561 | delta += inputIntAcsr
562 | [max(0,min(x+xMinInt , h))][max(0,min(y+yMaxInt , w))];
563 | delta -= inputIntAcsr
564 | [max(0,min(x+xMinInt-1, h))][max(0,min(y+yMaxInt , w))];
565 | delta -= inputIntAcsr
566 | [max(0,min(x+xMinInt , h))][max(0,min(y+yMinInt , w))];
567 | delta += inputIntAcsr
568 | [max(0,min(x+xMinInt-1, h))][max(0,min(y+yMinInt , w))];
569 |
570 | delta *= (x+xMinInt >= 1) & (x+xMinInt <= h);
571 |
572 | *(tmpArrayData++) = -delta;
573 | }
574 |
575 | else if (parameter == Parameter::xMax) {
576 | int valid;
577 | int cornerX, cornerY;
578 |
579 | scalar_t delta = 0;
580 |
581 | if (exact) {
582 | valid =
583 | not (y+yMinInt < 1) & not (y+yMinInt > w) & not (x+xMaxInt >= h);
584 | cornerX = max(0,min(h-1,x+xMaxInt ));
585 | cornerY = max(0,min(w-1,y+yMinInt-1));
586 | const scalar_t blCorner = valid *
587 | ( inputIntAcsr[cornerX+1][cornerY+1]
588 | - inputIntAcsr[cornerX ][cornerY+1]
589 | - inputIntAcsr[cornerX+1][cornerY ]
590 | + inputIntAcsr[cornerX ][cornerY ]);
591 |
592 | valid =
593 | not (y+yMaxInt < 0) & not (y+yMaxInt >= w) & not (x+xMaxInt >= h);
594 | cornerX = max(0,min(h-1,x+xMaxInt ));
595 | cornerY = max(0,min(w-1,y+yMaxInt ));
596 | const scalar_t brCorner = valid *
597 | ( inputIntAcsr[cornerX+1][cornerY+1]
598 | - inputIntAcsr[cornerX ][cornerY+1]
599 | - inputIntAcsr[cornerX+1][cornerY ]
600 | + inputIntAcsr[cornerX ][cornerY ]);
601 |
602 | delta += brCorner * yMaxFrac;
603 | delta += blCorner * yMinFrac;
604 | } // if (exact)
605 |
606 | delta += inputIntAcsr
607 | [max(0,min(x+xMaxInt+1, h))][max(0,min(y+yMaxInt , w))];
608 | delta -= inputIntAcsr
609 | [max(0,min(x+xMaxInt , h))][max(0,min(y+yMaxInt , w))];
610 | delta -= inputIntAcsr
611 | [max(0,min(x+xMaxInt+1, h))][max(0,min(y+yMinInt , w))];
612 | delta += inputIntAcsr
613 | [max(0,min(x+xMaxInt , h))][max(0,min(y+yMinInt , w))];
614 |
615 | delta *= (x+xMaxInt >= 0) & (x+xMaxInt < h);
616 |
617 | *(tmpArrayData++) = delta;
618 | }
619 |
620 | else if (parameter == Parameter::yMin) {
621 | int valid;
622 | int cornerX, cornerY;
623 |
624 | scalar_t delta = 0;
625 |
626 | if (exact) {
627 | valid =
628 | not (y+yMinInt < 1) & not (x+xMinInt < 1) & not (x+xMinInt > h);
629 | cornerX = max(0,min(h-1,x+xMinInt-1));
630 | cornerY = max(0,min(w-1,y+yMinInt-1));
631 | const scalar_t tlCorner = valid *
632 | ( inputIntAcsr[cornerX+1][cornerY+1]
633 | - inputIntAcsr[cornerX ][cornerY+1]
634 | - inputIntAcsr[cornerX+1][cornerY ]
635 | + inputIntAcsr[cornerX ][cornerY ]);
636 |
637 | valid =
638 | not (y+yMinInt < 1) & not (x+xMaxInt < 0) & not (x+xMaxInt >= h);
639 | cornerX = max(0,min(h-1,x+xMaxInt ));
640 | cornerY = max(0,min(w-1,y+yMinInt-1));
641 | const scalar_t blCorner = valid *
642 | ( inputIntAcsr[cornerX+1][cornerY+1]
643 | - inputIntAcsr[cornerX ][cornerY+1]
644 | - inputIntAcsr[cornerX+1][cornerY ]
645 | + inputIntAcsr[cornerX ][cornerY ]);
646 |
647 | delta += tlCorner * xMinFrac;
648 | delta += blCorner * xMaxFrac;
649 | } // if (exact)
650 |
651 | delta += inputIntAcsr
652 | [max(0,min(x+xMaxInt , h))][max(0,min(y+yMinInt , w))];
653 | delta -= inputIntAcsr
654 | [max(0,min(x+xMaxInt , h))][max(0,min(y+yMinInt-1, w))];
655 | delta -= inputIntAcsr
656 | [max(0,min(x+xMinInt , h))][max(0,min(y+yMinInt , w))];
657 | delta += inputIntAcsr
658 | [max(0,min(x+xMinInt , h))][max(0,min(y+yMinInt-1, w))];
659 |
660 | delta *= (y+yMinInt >= 1) & (y+yMinInt <= w);
661 |
662 | *(tmpArrayData++) = -delta;
663 | }
664 |
665 | else if (parameter == Parameter::yMax) {
666 | int valid;
667 | int cornerX, cornerY;
668 |
669 | scalar_t delta = 0;
670 |
671 | if (exact) {
672 | valid =
673 | not (y+yMaxInt >= w) & not (x+xMinInt < 1) & not (x+xMinInt > h);
674 | cornerX = max(0,min(h-1,x+xMinInt-1));
675 | cornerY = max(0,min(w-1,y+yMaxInt ));
676 | const scalar_t trCorner = valid *
677 | ( inputIntAcsr[cornerX+1][cornerY+1]
678 | - inputIntAcsr[cornerX ][cornerY+1]
679 | - inputIntAcsr[cornerX+1][cornerY ]
680 | + inputIntAcsr[cornerX ][cornerY ]);
681 |
682 | valid =
683 | not (y+yMaxInt >= w) & not (x+xMaxInt < 0) & not (x+xMaxInt >= h);
684 | cornerX = max(0,min(h-1,x+xMaxInt ));
685 | cornerY = max(0,min(w-1,y+yMaxInt ));
686 | const scalar_t brCorner = valid *
687 | ( inputIntAcsr[cornerX+1][cornerY+1]
688 | - inputIntAcsr[cornerX ][cornerY+1]
689 | - inputIntAcsr[cornerX+1][cornerY ]
690 | + inputIntAcsr[cornerX ][cornerY ]);
691 |
692 | delta += trCorner * xMinFrac;
693 | delta += brCorner * xMaxFrac;
694 | } // if (exact)
695 |
696 | delta += inputIntAcsr
697 | [max(0,min(x+xMaxInt , h))][max(0,min(y+yMaxInt+1, w))];
698 | delta -= inputIntAcsr
699 | [max(0,min(x+xMaxInt , h))][max(0,min(y+yMaxInt , w))];
700 | delta -= inputIntAcsr
701 | [max(0,min(x+xMinInt , h))][max(0,min(y+yMaxInt+1, w))];
702 | delta += inputIntAcsr
703 | [max(0,min(x+xMinInt , h))][max(0,min(y+yMaxInt , w))];
704 |
705 | delta *= (y+yMaxInt >= 0) & (y+yMaxInt < w);
706 |
707 | *(tmpArrayData++) = delta;
708 | }
709 | }
710 | }
711 | } // filterIdx
712 | } // inPlaneIdx
713 | } // batchIdx
714 | }));
715 | }
716 |
717 | // explicitly instantiate
718 | template void boxConvAccGradParameters<true>(
719 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
720 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
721 |     at::Tensor &, at::Tensor &, Parameter);
722 | 
723 | template void boxConvAccGradParameters<false>(
724 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
725 |     at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
726 |     at::Tensor &, at::Tensor &, Parameter);
727 |
728 |
729 | void clipParameters(
730 | at::Tensor & paramMin, at::Tensor & paramMax,
731 | const double reparametrization, const double minSize, const double maxSize) {
732 |
733 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(paramMin.scalar_type(), "cpu::clipParameters", ([&] {
734 |
735 | scalar_t *paramMinPtr = paramMin.data_ptr<scalar_t>();
736 | scalar_t *paramMaxPtr = paramMax.data_ptr<scalar_t>();
737 |
738 | const double inverseReparam = 1.0 / reparametrization;
739 |
740 | for (int idx = 0; idx < paramMin.numel(); ++idx) {
741 |
742 | double minValue, maxValue;
743 | const double paramMinCurrent = static_cast<double>(paramMinPtr[idx]);
744 | const double paramMaxCurrent = static_cast<double>(paramMaxPtr[idx]);
745 |
746 | // clamp parameters
747 | minValue = max(-(maxSize+1) * inverseReparam,
748 | min((maxSize-1) * inverseReparam, paramMinCurrent));
749 | maxValue = max(-(maxSize+1) * inverseReparam,
750 | min((maxSize-1) * inverseReparam, paramMaxCurrent));
751 |
752 | // make sure bottom/right border doesn't come before top/left
753 | if (minValue + (minSize - 0.9999) * inverseReparam > maxValue) {
754 | const scalar_t mean = 0.5 * (minValue + maxValue);
755 | minValue = mean - 0.5 * (minSize - 0.9999) * inverseReparam;
756 | maxValue = mean + 0.5 * (minSize - 0.9999) * inverseReparam;
757 | }
758 |
759 | paramMinPtr[idx] = static_cast<scalar_t>(minValue);
760 | paramMaxPtr[idx] = static_cast<scalar_t>(maxValue);
761 | }
762 | }));
763 | }
764 |
765 | at::Tensor computeArea(
766 | at::Tensor x_min, at::Tensor x_max, at::Tensor y_min, at::Tensor y_max,
767 | const bool exact, const bool needXDeriv, const bool needYDeriv) {
768 |
769 | // TODO: how to stop tracking operations??? `.is_variable_(false)` doesn't work
770 | auto retval = at::ones_like(x_min);
771 |
772 | if (not exact) {
773 | x_min = x_min.ceil();
774 | y_min = y_min.ceil();
775 | x_max = x_max.floor();
776 | y_max = y_max.floor();
777 | }
778 |
779 | if (needXDeriv) {
780 | auto xArea = x_max - x_min;
781 | xArea += 1;
782 | retval *= xArea;
783 | }
784 |
785 | if (needYDeriv) {
786 | auto yArea = y_max - y_min;
787 | yArea += 1;
788 | retval *= yArea;
789 | }
790 |
791 | retval.reciprocal_(); // inverse areas
792 | return retval;
793 | }
794 |
795 | } // namespace cpu
796 |
--------------------------------------------------------------------------------
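(Annotation, not part of the repository: the four-corner expressions that recur in the file above all evaluate the same integral-image identity. For an integral image I whose row 0 and column 0 are all zeros, the sum of the source values over rows [t, b) and columns [l, r) is I[b][r] - I[t][r] - I[b][l] + I[t][l]. A minimal standalone sketch of that identity; `boxSum` and `rowStride` are illustrative names:

    // Sum of the source image over rows [t, b) and columns [l, r),
    // given an (h+1) x (w+1) integral image with a zero first row/column.
    template <typename scalar_t>
    scalar_t boxSum(const scalar_t *intImage, int rowStride,
                    int t, int b, int l, int r) {
        return intImage[b * rowStride + r] - intImage[t * rowStride + r]
             - intImage[b * rowStride + l] + intImage[t * rowStride + l];
    }

The fractional-border and corner terms in the kernels above are corrections built from the same identity applied to one-pixel-wide strips.)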
/src/box_convolution.h:
--------------------------------------------------------------------------------
1 | #include <ciso646> // && -> and, || -> or etc.
2 |
3 | enum class Parameter {xMin, xMax, yMin, yMax};
4 |
5 | namespace cpu {
6 |
7 | void splitParameters(
8 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max ,
9 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
10 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac);
11 |
12 | void splitParametersUpdateGradInput(
13 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max ,
14 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
15 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac);
16 |
17 | void splitParametersAccGradParameters(
18 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max ,
19 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
20 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac);
21 |
22 | template <bool normalize, bool exact>
23 | void boxConvUpdateOutput(
24 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
25 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac,
26 | at::Tensor & area, at::Tensor & input_integrated, at::Tensor & output);
27 |
28 | template <bool normalize, bool exact>
29 | void boxConvUpdateGradInput(
30 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
31 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac,
32 | at::Tensor & area, at::Tensor & grad_output_integrated, at::Tensor & tmpArray);
33 |
34 | template <bool exact>
35 | void boxConvAccGradParameters(
36 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
37 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac,
38 | at::Tensor & input_integrated, at::Tensor & tmpArray, Parameter parameter);
39 |
40 | void clipParameters(
41 | at::Tensor & paramMin, at::Tensor & paramMax,
42 | const double reparametrization, const double minSize, const double maxSize);
43 |
44 | at::Tensor computeArea(
45 | at::Tensor x_min, at::Tensor x_max, at::Tensor y_min, at::Tensor y_max,
46 | const bool exact, const bool needXDeriv = true, const bool needYDeriv = true);
47 |
48 | }
49 |
50 | namespace gpu {
51 |
52 | void splitParameters(
53 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max ,
54 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
55 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac);
56 |
57 | void splitParametersUpdateGradInput(
58 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max ,
59 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
60 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac);
61 |
62 | void splitParametersAccGradParameters(
63 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max ,
64 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
65 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac);
66 |
67 | template <bool normalize, bool exact>
68 | void boxConvUpdateOutput(
69 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
70 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac,
71 | at::Tensor & area, at::Tensor & input_integrated, at::Tensor & output);
72 |
73 | template <bool normalize, bool exact>
74 | void boxConvUpdateGradInput(
75 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
76 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac,
77 | at::Tensor & area, at::Tensor & grad_output_integrated, at::Tensor & tmpArray);
78 |
79 | template <bool exact>
80 | void boxConvAccGradParameters(
81 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
82 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac,
83 | at::Tensor & input_integrated, at::Tensor & tmpArray, Parameter parameter);
84 |
85 | void clipParameters(
86 | at::Tensor & paramMin, at::Tensor & paramMax,
87 | const double reparametrization, const double minSize, const double maxSize);
88 |
89 | at::Tensor computeArea(
90 | at::Tensor x_min, at::Tensor x_max, at::Tensor y_min, at::Tensor y_max,
91 | const bool exact, const bool needXDeriv = true, const bool needYDeriv = true);
92 |
93 | }
--------------------------------------------------------------------------------
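(Annotation, not part of the repository: the three `splitParameters*` declarations above all decompose each real-valued box border into an integer part plus a fractional remainder; min borders are rounded up, max borders are rounded down with an extra +1 for integral-image indexing, as implemented in src/box_convolution_cuda_misc.cu. A self-contained host-side sketch for a single (min, max) pair; `splitOne` is an illustrative name:

    #include <cmath>
    #include <cstdint>

    // Split one (min, max) border pair the way splitParameters does.
    void splitOne(float minV, float maxV,
                  int32_t &minInt, float &minFrac,
                  int32_t &maxInt, float &maxFrac) {
        const float minCeil = std::ceil(minV);
        minFrac = minCeil - minV;                     // fractional part of the min border
        minInt  = static_cast<int32_t>(minCeil);

        const float maxFloor = std::floor(maxV);
        maxFrac = maxV - maxFloor;                    // fractional part of the max border
        maxInt  = static_cast<int32_t>(maxFloor) + 1; // +1 for integral-image indexing
    }
)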
/src/box_convolution_cuda_backward.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | 
3 | #include <ATen/cuda/CUDAContext.h>
4 | #include <THC/THC.h>
5 |
6 | #include "box_convolution.h" // for `enum class Parameter`
7 |
8 | #define BLOCK_SIZE 256
9 | #define NUM_THREADS 1024
10 |
11 | using std::min;
12 | using std::max;
13 |
14 | namespace gpu {
15 |
16 | template <typename scalar_t, size_t nDims>
17 | using CudaAcsr = const at::PackedTensorAccessor32<scalar_t, nDims, at::RestrictPtrTraits>;
18 |
19 | // TODO switch to square blocks
20 | template <typename scalar_t, bool normalize, bool exact>
21 | __global__ void boxConvUpdateGradInputKernel(
22 | CudaAcsr<scalar_t,3> gradOutputInt, scalar_t * __restrict__ tmpArray,
23 | const int32_t * __restrict__ xMinInt , const int32_t * __restrict__ xMaxInt ,
24 | const int32_t * __restrict__ yMinInt , const int32_t * __restrict__ yMaxInt ,
25 | const scalar_t * __restrict__ xMinFrac, const scalar_t * __restrict__ xMaxFrac,
26 | const scalar_t * __restrict__ yMinFrac, const scalar_t * __restrict__ yMaxFrac,
27 | const scalar_t * __restrict__ area, const int nParams) {
28 |
29 | int32_t id = NUM_THREADS * blockIdx.x + threadIdx.x;
30 | tmpArray += id;
31 |
32 | const int32_t h = gradOutputInt.size(1) - 1;
33 | const int32_t w = gradOutputInt.size(2) - 1;
34 | const int32_t y = id % w; id /= w;
35 | const int32_t x = id % h; id /= h;
36 | const int32_t paramIdx = id % nParams;
37 |
38 | // `id` is now the current plane number
39 | auto gradOutputIntPlane = gradOutputInt[id];
40 |
41 | if (id < gradOutputInt.size(0)) {
42 |
43 | const int32_t xMinCurr = xMinInt[paramIdx];
44 | const int32_t xMaxCurr = xMaxInt[paramIdx];
45 | const int32_t yMinCurr = yMinInt[paramIdx];
46 | const int32_t yMaxCurr = yMaxInt[paramIdx];
47 |
48 | const int t = max(0, min(x+xMinCurr, h));
49 | const int b = max(0, min(x+xMaxCurr, h));
50 | const int l = max(0, min(y+yMinCurr, w));
51 | const int r = max(0, min(y+yMaxCurr, w));
52 |
53 | scalar_t outValue;
54 |
55 | outValue =
56 | gradOutputIntPlane[b][r]
57 | - gradOutputIntPlane[t][r]
58 | - gradOutputIntPlane[b][l]
59 | + gradOutputIntPlane[t][l];
60 |
61 | if (exact) {
62 | const scalar_t xMinCurrFrac = xMinFrac[paramIdx];
63 | const scalar_t xMaxCurrFrac = xMaxFrac[paramIdx];
64 | const scalar_t yMinCurrFrac = yMinFrac[paramIdx];
65 | const scalar_t yMaxCurrFrac = yMaxFrac[paramIdx];
66 |
67 | const int tAdv = x+xMinCurr-1 < h ? max(0, min(t-1, h)) : t;
68 | const int bAdv = x+xMaxCurr >= 0 ? max(0, min(b+1, h)) : b;
69 | const int lAdv = y+yMinCurr-1 < w ? max(0, min(l-1, w)) : l;
70 | const int rAdv = y+yMaxCurr >= 0 ? max(0, min(r+1, w)) : r;
71 |
72 | // -- xMax border
73 | outValue +=
74 | ( gradOutputIntPlane[bAdv][r]
75 | - gradOutputIntPlane[b ][r]
76 | - gradOutputIntPlane[bAdv][l]
77 | + gradOutputIntPlane[b ][l]
78 | ) * xMaxCurrFrac;
79 |
80 | // -- yMax border
81 | outValue +=
82 | ( gradOutputIntPlane[b][rAdv]
83 | - gradOutputIntPlane[b][r ]
84 | - gradOutputIntPlane[t][rAdv]
85 | + gradOutputIntPlane[t][r ]
86 | ) * yMaxCurrFrac;
87 |
88 | // -- xMin border
89 | outValue +=
90 | ( gradOutputIntPlane[t ][r]
91 | - gradOutputIntPlane[tAdv][r]
92 | - gradOutputIntPlane[t ][l]
93 | + gradOutputIntPlane[tAdv][l]
94 | ) * xMinCurrFrac;
95 |
96 | // -- yMin border
97 | outValue +=
98 | ( gradOutputIntPlane[b][l ]
99 | - gradOutputIntPlane[b][lAdv]
100 | - gradOutputIntPlane[t][l ]
101 | + gradOutputIntPlane[t][lAdv]
102 | ) * yMinCurrFrac;
103 |
104 | // -- corner pixels
105 | outValue +=
106 | xMaxCurrFrac*yMaxCurrFrac * (
107 | (x+xMaxCurr >= h or
108 | y+yMaxCurr >= w or
109 | x+xMaxCurr < 0 or
110 | y+yMaxCurr < 0 or
111 | b == bAdv or
112 | r == rAdv) ? static_cast<scalar_t>(0) :
113 |
114 | ( gradOutputIntPlane[b+1][r+1]
115 | - gradOutputIntPlane[b ][r+1]
116 | - gradOutputIntPlane[b+1][r ]
117 | + gradOutputIntPlane[b ][r ]));
118 |
119 | outValue +=
120 | xMinCurrFrac*yMaxCurrFrac * (
121 | (x+xMinCurr > h or
122 | y+yMaxCurr >= w or
123 | x+xMinCurr <= 0 or
124 | y+yMaxCurr < 0 or
125 | t == tAdv or
126 | r == rAdv) ? static_cast<scalar_t>(0) :
127 |
128 | ( gradOutputIntPlane[tAdv+1][r+1]
129 | - gradOutputIntPlane[tAdv+1][r ]
130 | - gradOutputIntPlane[tAdv ][r+1]
131 | + gradOutputIntPlane[tAdv ][r ]));
132 |
133 | outValue +=
134 | xMaxCurrFrac*yMinCurrFrac * (
135 | (x+xMaxCurr >= h or
136 | y+yMinCurr > w or
137 | x+xMaxCurr < 0 or
138 | y+yMinCurr <= 0 or
139 | b == bAdv or
140 | l == lAdv) ? static_cast<scalar_t>(0) :
141 |
142 | ( gradOutputIntPlane[b+1][lAdv+1]
143 | - gradOutputIntPlane[b ][lAdv+1]
144 | - gradOutputIntPlane[b+1][lAdv ]
145 | + gradOutputIntPlane[b ][lAdv ]));
146 |
147 | outValue +=
148 | xMinCurrFrac*yMinCurrFrac * (
149 | (x+xMinCurr > h or
150 | y+yMinCurr > w or
151 | x+xMinCurr <= 0 or
152 | y+yMinCurr <= 0 or
153 | t == tAdv or
154 | l == lAdv) ? static_cast<scalar_t>(0) :
155 |
156 | ( gradOutputIntPlane[tAdv+1][lAdv+1]
157 | - gradOutputIntPlane[tAdv+1][lAdv ]
158 | - gradOutputIntPlane[tAdv ][lAdv+1]
159 | + gradOutputIntPlane[tAdv ][lAdv ]));
160 | }
161 |
162 | *tmpArray = outValue * (normalize ? area[paramIdx] : static_cast<scalar_t>(1));
163 | }
164 | }
165 |
166 | template <bool normalize, bool exact>
167 | void boxConvUpdateGradInput(
168 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
169 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac,
170 | at::Tensor & area, at::Tensor & grad_output_integrated, at::Tensor & tmpArray) {
171 |
172 | // TODO use square blocks as in `boxConvUpdateOutput`?
173 | const int threadsNeeded = tmpArray.numel();
174 | int numBlocks = (threadsNeeded + NUM_THREADS - 1) / NUM_THREADS;
175 |
176 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tmpArray.scalar_type(), "gpu::boxConvUpdateGradInput", ([&] {
177 | auto gradOutputIntFlattened = grad_output_integrated.view(
178 | {-1, grad_output_integrated.size(-2), grad_output_integrated.size(-1)});
179 | auto gradOutputIntAcsr =
180 | gradOutputIntFlattened.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>();
181 |
182 | boxConvUpdateGradInputKernel <scalar_t, normalize, exact>
183 | <<<numBlocks, NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>> (
184 | gradOutputIntAcsr, tmpArray.data_ptr<scalar_t>(),
185 | xMinInt.data_ptr<int32_t>(), xMaxInt.data_ptr<int32_t>(),
186 | yMinInt.data_ptr<int32_t>(), yMaxInt.data_ptr<int32_t>(),
187 | xMinFrac.data_ptr<scalar_t>(), xMaxFrac.data_ptr<scalar_t>(),
188 | yMinFrac.data_ptr<scalar_t>(), yMaxFrac.data_ptr<scalar_t>(),
189 | normalize ? area.data_ptr<scalar_t>() : nullptr, xMinInt.numel());
190 | THCudaCheck(cudaGetLastError());
191 | }));
192 | }
193 |
194 | // explicitly instantiate
195 | template void boxConvUpdateGradInput<true, true>(
196 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
197 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
198 | at::Tensor &, at::Tensor &, at::Tensor &);
199 |
200 | template void boxConvUpdateGradInput<false, true>(
201 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
202 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
203 | at::Tensor &, at::Tensor &, at::Tensor &);
204 |
205 | template void boxConvUpdateGradInput<true, false>(
206 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
207 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
208 | at::Tensor &, at::Tensor &, at::Tensor &);
209 |
210 | template void boxConvUpdateGradInput<false, false>(
211 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
212 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
213 | at::Tensor &, at::Tensor &, at::Tensor &);
214 |
215 |
216 | // TODO overload for exact/truncated mode
217 | // TODO accept only three pairs of parameter arrays, not four (one is always redundant)
218 | template <typename scalar_t, bool exact, Parameter parameter>
219 | __global__ void boxConvAccGradParametersKernel(
220 | CudaAcsr<scalar_t,3> inputInt, scalar_t * __restrict__ tmpArray,
221 | const int32_t * __restrict__ xMinInt , const int32_t * __restrict__ xMaxInt ,
222 | const int32_t * __restrict__ yMinInt , const int32_t * __restrict__ yMaxInt ,
223 | const scalar_t * __restrict__ xMinFrac, const scalar_t * __restrict__ xMaxFrac,
224 | const scalar_t * __restrict__ yMinFrac, const scalar_t * __restrict__ yMaxFrac,
225 | const int nParams) {
226 |
227 | int32_t id = NUM_THREADS * blockIdx.x + threadIdx.x;
228 | tmpArray += id;
229 |
230 | const int32_t h = inputInt.size(1) - 1;
231 | const int32_t w = inputInt.size(2) - 1;
232 | const int32_t y = id % w + 1; id /= w;
233 | const int32_t x = id % h + 1; id /= h;
234 | const int32_t paramIdx = id % nParams; id /= nParams;
235 |
236 | // `id` is now the current absolute input plane number
237 | auto inputIntPlane = inputInt[id];
238 |
239 | if (id < inputInt.size(0)) {
240 |
241 | const int32_t xMinCurr = xMinInt[paramIdx];
242 | const int32_t xMaxCurr = xMaxInt[paramIdx];
243 | const int32_t yMinCurr = yMinInt[paramIdx];
244 | const int32_t yMaxCurr = yMaxInt[paramIdx];
245 |
246 | // TODO only define these if `exact == true`
247 | const scalar_t xMinCurrFrac = xMinFrac[paramIdx];
248 | const scalar_t xMaxCurrFrac = xMaxFrac[paramIdx];
249 | const scalar_t yMinCurrFrac = yMinFrac[paramIdx];
250 | const scalar_t yMaxCurrFrac = yMaxFrac[paramIdx];
251 |
252 | int valid;
253 | int cornerX, cornerY;
254 |
255 | scalar_t delta = 0;
256 |
257 | if (parameter == Parameter::xMin) {
258 | if (exact) {
259 | // TODO maybe use `input` instead of `inputInt`
260 | valid =
261 | not (y+yMinCurr < 1) & not (y+yMinCurr > w) & not (x+xMinCurr < 1);
262 | cornerX = max(0,min(h-1,x+xMinCurr-1));
263 | cornerY = max(0,min(w-1,y+yMinCurr-1));
264 | const scalar_t tlCorner = valid *
265 | ( inputIntPlane[cornerX+1][cornerY+1]
266 | - inputIntPlane[cornerX ][cornerY+1]
267 | - inputIntPlane[cornerX+1][cornerY ]
268 | + inputIntPlane[cornerX ][cornerY ]);
269 |
270 | valid =
271 | not (y+yMaxCurr < 0) & not (y+yMaxCurr >= w) & not (x+xMinCurr < 1);
272 | cornerX = max(0,min(h-1,x+xMinCurr -1));
273 | cornerY = max(0,min(w-1,y+yMaxCurr ));
274 | const scalar_t trCorner = valid *
275 | ( inputIntPlane[cornerX+1][cornerY+1]
276 | - inputIntPlane[cornerX ][cornerY+1]
277 | - inputIntPlane[cornerX+1][cornerY ]
278 | + inputIntPlane[cornerX ][cornerY ]);
279 |
280 | delta += trCorner * yMaxCurrFrac;
281 | delta += tlCorner * yMinCurrFrac;
282 | } // if (exact)
283 |
284 | delta += inputIntPlane
285 | [max(0,min(x+xMinCurr , h))][max(0,min(y+yMaxCurr , w))];
286 | delta -= inputIntPlane
287 | [max(0,min(x+xMinCurr -1, h))][max(0,min(y+yMaxCurr , w))];
288 | delta -= inputIntPlane
289 | [max(0,min(x+xMinCurr , h))][max(0,min(y+yMinCurr , w))];
290 | delta += inputIntPlane
291 | [max(0,min(x+xMinCurr -1, h))][max(0,min(y+yMinCurr , w))];
292 |
293 | delta *= (x+xMinCurr >= 1) & (x+xMinCurr <= h);
294 |
295 | *tmpArray = -delta;
296 | }
297 |
298 | else if (parameter == Parameter::xMax) {
299 | if (exact) {
300 | valid =
301 | not (y+yMinCurr < 1) & not (y+yMinCurr > w) & not (x+xMaxCurr >= h);
302 | cornerX = max(0,min(h-1,x+xMaxCurr ));
303 | cornerY = max(0,min(w-1,y+yMinCurr -1));
304 | const scalar_t blCorner = valid *
305 | ( inputIntPlane[cornerX+1][cornerY+1]
306 | - inputIntPlane[cornerX ][cornerY+1]
307 | - inputIntPlane[cornerX+1][cornerY ]
308 | + inputIntPlane[cornerX ][cornerY ]);
309 |
310 | valid =
311 | not (y+yMaxCurr < 0) & not (y+yMaxCurr >= w) & not (x+xMaxCurr >= h);
312 | cornerX = max(0,min(h-1,x+xMaxCurr ));
313 | cornerY = max(0,min(w-1,y+yMaxCurr ));
314 | const scalar_t brCorner = valid *
315 | ( inputIntPlane[cornerX+1][cornerY+1]
316 | - inputIntPlane[cornerX ][cornerY+1]
317 | - inputIntPlane[cornerX+1][cornerY ]
318 | + inputIntPlane[cornerX ][cornerY ]);
319 |
320 | delta += brCorner * yMaxCurrFrac;
321 | delta += blCorner * yMinCurrFrac;
322 | } // if (exact)
323 |
324 | delta += inputIntPlane
325 | [max(0,min(x+xMaxCurr +1, h))][max(0,min(y+yMaxCurr , w))];
326 | delta -= inputIntPlane
327 | [max(0,min(x+xMaxCurr , h))][max(0,min(y+yMaxCurr , w))];
328 | delta -= inputIntPlane
329 | [max(0,min(x+xMaxCurr +1, h))][max(0,min(y+yMinCurr , w))];
330 | delta += inputIntPlane
331 | [max(0,min(x+xMaxCurr , h))][max(0,min(y+yMinCurr , w))];
332 |
333 | delta *= (x+xMaxCurr >= 0) & (x+xMaxCurr < h);
334 |
335 | *tmpArray = delta;
336 | }
337 |
338 | else if (parameter == Parameter::yMin) {
339 | if (exact) {
340 | valid =
341 | not (y+yMinCurr < 1) & not (x+xMinCurr < 1) & not (x+xMinCurr > h);
342 | cornerX = max(0,min(h-1,x+xMinCurr -1));
343 | cornerY = max(0,min(w-1,y+yMinCurr -1));
344 | const scalar_t tlCorner = valid *
345 | ( inputIntPlane[cornerX+1][cornerY+1]
346 | - inputIntPlane[cornerX ][cornerY+1]
347 | - inputIntPlane[cornerX+1][cornerY ]
348 | + inputIntPlane[cornerX ][cornerY ]);
349 |
350 | valid =
351 | not (y+yMinCurr < 1) & not (x+xMaxCurr < 0) & not (x+xMaxCurr >= h);
352 | cornerX = max(0,min(h-1,x+xMaxCurr ));
353 | cornerY = max(0,min(w-1,y+yMinCurr -1));
354 | const scalar_t blCorner = valid *
355 | ( inputIntPlane[cornerX+1][cornerY+1]
356 | - inputIntPlane[cornerX ][cornerY+1]
357 | - inputIntPlane[cornerX+1][cornerY ]
358 | + inputIntPlane[cornerX ][cornerY ]);
359 |
360 | delta += tlCorner * xMinCurrFrac;
361 | delta += blCorner * xMaxCurrFrac;
362 | } // if (exact)
363 |
364 | delta += inputIntPlane
365 | [max(0,min(x+xMaxCurr , h))][max(0,min(y+yMinCurr , w))];
366 | delta -= inputIntPlane
367 | [max(0,min(x+xMaxCurr , h))][max(0,min(y+yMinCurr -1, w))];
368 | delta -= inputIntPlane
369 | [max(0,min(x+xMinCurr , h))][max(0,min(y+yMinCurr , w))];
370 | delta += inputIntPlane
371 | [max(0,min(x+xMinCurr , h))][max(0,min(y+yMinCurr -1, w))];
372 |
373 | delta *= (y+yMinCurr >= 1) & (y+yMinCurr <= w);
374 |
375 | *tmpArray = -delta;
376 | }
377 |
378 | else if (parameter == Parameter::yMax) {
379 | if (exact) {
380 | valid =
381 | not (y+yMaxCurr >= w) & not (x+xMinCurr < 1) & not (x+xMinCurr > h);
382 | cornerX = max(0,min(h-1,x+xMinCurr -1));
383 | cornerY = max(0,min(w-1,y+yMaxCurr ));
384 | const scalar_t trCorner = valid *
385 | ( inputIntPlane[cornerX+1][cornerY+1]
386 | - inputIntPlane[cornerX ][cornerY+1]
387 | - inputIntPlane[cornerX+1][cornerY ]
388 | + inputIntPlane[cornerX ][cornerY ]);
389 |
390 | valid =
391 | not (y+yMaxCurr >= w) & not (x+xMaxCurr < 0) & not (x+xMaxCurr >= h);
392 | cornerX = max(0,min(h-1,x+xMaxCurr ));
393 | cornerY = max(0,min(w-1,y+yMaxCurr ));
394 | const scalar_t brCorner = valid *
395 | ( inputIntPlane[cornerX+1][cornerY+1]
396 | - inputIntPlane[cornerX ][cornerY+1]
397 | - inputIntPlane[cornerX+1][cornerY ]
398 | + inputIntPlane[cornerX ][cornerY ]);
399 |
400 | delta += trCorner * xMinCurrFrac;
401 | delta += brCorner * xMaxCurrFrac;
402 | } // if (exact)
403 |
404 | delta += inputIntPlane
405 | [max(0,min(x+xMaxCurr , h))][max(0,min(y+yMaxCurr +1, w))];
406 | delta -= inputIntPlane
407 | [max(0,min(x+xMaxCurr , h))][max(0,min(y+yMaxCurr , w))];
408 | delta -= inputIntPlane
409 | [max(0,min(x+xMinCurr , h))][max(0,min(y+yMaxCurr +1, w))];
410 | delta += inputIntPlane
411 | [max(0,min(x+xMinCurr , h))][max(0,min(y+yMaxCurr , w))];
412 |
413 | delta *= (y+yMaxCurr >= 0) & (y+yMaxCurr < w);
414 |
415 | *tmpArray = delta;
416 | }
417 | }
418 | }
419 |
420 | template <bool exact>
421 | void boxConvAccGradParameters(
422 | // tmpArray size: {batchSize, nInputPlanes, numFilters, h, w}
423 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
424 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac,
425 | at::Tensor & input_integrated, at::Tensor & tmpArray, Parameter parameter) {
426 |
427 | // TODO switch to square blocks?
428 | const int threadsNeeded = tmpArray.numel();
429 | int numBlocks = (threadsNeeded + NUM_THREADS - 1) / NUM_THREADS;
430 |
431 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tmpArray.scalar_type(), "gpu::boxConvAccGradParameters", ([&] {
432 | auto inputIntFlattened = input_integrated.view(
433 | {-1, input_integrated.size(-2), input_integrated.size(-1)});
434 | auto inputIntAcsr =
435 | inputIntFlattened.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>();
436 |
437 | switch (parameter) {
438 | case Parameter::xMin:
439 | boxConvAccGradParametersKernel <scalar_t, exact, Parameter::xMin>
440 | <<<numBlocks, NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>> (
441 | inputIntAcsr, tmpArray.data_ptr<scalar_t>(),
442 | xMinInt.data_ptr<int32_t>(), xMaxInt.data_ptr<int32_t>(),
443 | yMinInt.data_ptr<int32_t>(), yMaxInt.data_ptr<int32_t>(),
444 | xMinFrac.data_ptr<scalar_t>(), xMaxFrac.data_ptr<scalar_t>(),
445 | yMinFrac.data_ptr<scalar_t>(), yMaxFrac.data_ptr<scalar_t>(), xMinInt.numel()); break;
446 | case Parameter::xMax:
447 | boxConvAccGradParametersKernel <scalar_t, exact, Parameter::xMax>
448 | <<<numBlocks, NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>> (
449 | inputIntAcsr, tmpArray.data_ptr<scalar_t>(),
450 | xMinInt.data_ptr<int32_t>(), xMaxInt.data_ptr<int32_t>(),
451 | yMinInt.data_ptr<int32_t>(), yMaxInt.data_ptr<int32_t>(),
452 | xMinFrac.data_ptr<scalar_t>(), xMaxFrac.data_ptr<scalar_t>(),
453 | yMinFrac.data_ptr<scalar_t>(), yMaxFrac.data_ptr<scalar_t>(), xMinInt.numel()); break;
454 | case Parameter::yMin:
455 | boxConvAccGradParametersKernel <scalar_t, exact, Parameter::yMin>
456 | <<<numBlocks, NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>> (
457 | inputIntAcsr, tmpArray.data_ptr<scalar_t>(),
458 | xMinInt.data_ptr<int32_t>(), xMaxInt.data_ptr<int32_t>(),
459 | yMinInt.data_ptr<int32_t>(), yMaxInt.data_ptr<int32_t>(),
460 | xMinFrac.data_ptr<scalar_t>(), xMaxFrac.data_ptr<scalar_t>(),
461 | yMinFrac.data_ptr<scalar_t>(), yMaxFrac.data_ptr<scalar_t>(), xMinInt.numel()); break;
462 | case Parameter::yMax:
463 | boxConvAccGradParametersKernel <scalar_t, exact, Parameter::yMax>
464 | <<<numBlocks, NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>> (
465 | inputIntAcsr, tmpArray.data_ptr<scalar_t>(),
466 | xMinInt.data_ptr<int32_t>(), xMaxInt.data_ptr<int32_t>(),
467 | yMinInt.data_ptr<int32_t>(), yMaxInt.data_ptr<int32_t>(),
468 | xMinFrac.data_ptr<scalar_t>(), xMaxFrac.data_ptr<scalar_t>(),
469 | yMinFrac.data_ptr<scalar_t>(), yMaxFrac.data_ptr<scalar_t>(), xMinInt.numel()); break;
470 | }
471 | THCudaCheck(cudaGetLastError());
472 | }));
473 | }
474 |
475 | // explicitly instantiate
476 | template void boxConvAccGradParameters<true>(
477 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
478 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
479 | at::Tensor &, at::Tensor &, Parameter);
480 |
481 | template void boxConvAccGradParameters<false>(
482 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
483 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
484 | at::Tensor &, at::Tensor &, Parameter);
485 |
486 | }
487 |
--------------------------------------------------------------------------------
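(Annotation, not part of the repository: both kernels in the file above map a flat thread id onto (plane, parameter, x, y) coordinates by repeated modulo/divide, so a single 1-D launch covers the whole {batch, planes, filters, h, w} temporary array. A host-side illustration of the decomposition used in `boxConvAccGradParametersKernel`; `decodeThreadId` is an illustrative name:

    #include <cstdint>

    // Recover (y, x, paramIdx, planeIdx) from a flat thread id,
    // mirroring the index arithmetic in boxConvAccGradParametersKernel.
    void decodeThreadId(int32_t id, int32_t h, int32_t w, int32_t nParams,
                        int32_t &y, int32_t &x,
                        int32_t &paramIdx, int32_t &planeIdx) {
        y = id % w + 1; id /= w;                 // spatial column (1-based there)
        x = id % h + 1; id /= h;                 // spatial row (1-based there)
        paramIdx = id % nParams; id /= nParams;  // which box parameter
        planeIdx = id;                           // remaining quotient: input plane
    }

Note that `boxConvUpdateGradInputKernel` differs slightly: it keeps the combined plane index after taking `paramIdx`, because its integral tensor is flattened over batch, planes, and filters at once.)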
/src/box_convolution_cuda_forward.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | 
3 | #include <ATen/cuda/CUDAContext.h>
4 | #include <THC/THC.h>
5 |
6 | #define BLOCK_SIZE 256
7 |
8 | using std::min;
9 | using std::max;
10 |
11 | #include "box_convolution.h" // for `enum class Parameter`
12 |
13 | namespace gpu {
14 |
15 | // TODO use constant memory when possible
16 | // namespace constant {
17 | // __constant__ float xMinFrac[1536], xMaxFrac[1536];
18 | // __constant__ float yMinFrac[1536], yMaxFrac[1536];
19 | // __constant__ int xMinInt[1536], xMaxInt[1536];
20 | // __constant__ int yMinInt[1536], yMaxInt[1536];
21 | // __constant__ float area[1536];
22 | // }
23 |
24 | template <typename scalar_t, size_t nDims>
25 | using CudaAcsr = const at::PackedTensorAccessor32<scalar_t, nDims, at::RestrictPtrTraits>;
26 |
27 | // overload for "truncated"/"rounded" mode
28 | template <typename scalar_t, bool normalize>
29 | __global__ void boxConvUpdateOutputKernel(
30 | CudaAcsr<scalar_t,3> inputInt, CudaAcsr<scalar_t,5> output,
31 | const int32_t * __restrict__ xMinInt, const int32_t * __restrict__ xMaxInt,
32 | const int32_t * __restrict__ yMinInt, const int32_t * __restrict__ yMaxInt,
33 | const scalar_t * __restrict__ area) {
34 |
35 | // `output` size: `batch_size x in_planes x num_filters x h x w`
36 | const int32_t y = blockDim.y * blockIdx.y + threadIdx.y;
37 | const int32_t x = blockDim.z * blockIdx.z + threadIdx.z;
38 | const int32_t inPlaneIdx = blockIdx.x / output.size(2);
39 | const int32_t paramIdx = blockIdx.x % (output.size(1) * output.size(2));
40 | const int32_t h = output.size(3);
41 | const int32_t w = output.size(4);
42 |
43 | const auto inputIntPlane = inputInt[inPlaneIdx];
44 |
45 | if (x < h and y < w) {
46 | // Must add 1 to xMax/yMax/xMin/yMin due to OpenCV's
47 | // `integral()` behavior. Namely, I(x,0) and I(0,y) are
48 | // always 0 (so it's a C-style array sum).
49 |
50 | // However, when computing sums, we subtract values at points
51 | // like y+yMin-1 and x+xMin-1, so we also SUBTRACT 1 from xMin
52 | // and yMin, and thus finally they are not affected.
53 |
54 | const int32_t t = max(0, min(x+xMinInt[paramIdx], h));
55 | const int32_t b = max(0, min(x+xMaxInt[paramIdx], h));
56 | const int32_t l = max(0, min(y+yMinInt[paramIdx], w));
57 | const int32_t r = max(0, min(y+yMaxInt[paramIdx], w));
58 |
59 | scalar_t outValue = 0;
60 |
61 | outValue += inputIntPlane[b][r];
62 | outValue -= inputIntPlane[t][r];
63 | outValue -= inputIntPlane[b][l];
64 | outValue += inputIntPlane[t][l];
65 |
66 | // TODO error: expression must be a modifiable lvalue
67 | output.data()[(blockIdx.x * h + x) * w + y] =
68 | outValue * (normalize ? area[paramIdx] : static_cast<scalar_t>(1));
69 | }
70 | }
71 |
72 | // overload for "exact" mode
73 | template <typename scalar_t, bool normalize>
74 | __global__ void boxConvUpdateOutputKernel(
75 | CudaAcsr<scalar_t,3> inputInt, CudaAcsr<scalar_t,5> output,
76 | const int32_t * __restrict__ xMinInt, const int32_t * __restrict__ xMaxInt,
77 | const int32_t * __restrict__ yMinInt, const int32_t * __restrict__ yMaxInt,
78 | const scalar_t * __restrict__ xMinFrac, const scalar_t * __restrict__ xMaxFrac,
79 | const scalar_t * __restrict__ yMinFrac, const scalar_t * __restrict__ yMaxFrac,
80 | const scalar_t * __restrict__ area) {
81 |
82 | const int32_t y = blockDim.y * blockIdx.y + threadIdx.y;
83 | const int32_t x = blockDim.z * blockIdx.z + threadIdx.z;
84 | const int32_t inPlaneIdx = blockIdx.x / output.size(2);
85 | const int32_t paramIdx = blockIdx.x % (output.size(1) * output.size(2));
86 | const int32_t h = output.size(3);
87 | const int32_t w = output.size(4);
88 |
89 | const auto inputIntPlane = inputInt[inPlaneIdx];
90 |
91 | if (x < h and y < w) {
92 | // Must add 1 to xMax/yMax/xMin/yMin due to OpenCV's
93 | // `integral()` behavior. Namely, I(x,0) and I(0,y) are
94 | // always 0 (so it's a C-style array sum).
95 |
96 | // However, when computing sums, we subtract values at points
97 | // like y+yMin-1 and x+xMin-1, so we also SUBTRACT 1 from xMin
98 | // and yMin, and thus finally they are not affected.
99 | const int xMinCurr = xMinInt[paramIdx];
100 | const int xMaxCurr = xMaxInt[paramIdx];
101 | const int yMinCurr = yMinInt[paramIdx];
102 | const int yMaxCurr = yMaxInt[paramIdx];
103 |
104 | const scalar_t xMinCurrFrac = xMinFrac[paramIdx];
105 | const scalar_t xMaxCurrFrac = xMaxFrac[paramIdx];
106 | const scalar_t yMinCurrFrac = yMinFrac[paramIdx];
107 | const scalar_t yMaxCurrFrac = yMaxFrac[paramIdx];
108 |
109 | const int32_t t = max(0, min(x+xMinCurr, h));
110 | const int32_t b = max(0, min(x+xMaxCurr, h));
111 | const int32_t l = max(0, min(y+yMinCurr, w));
112 | const int32_t r = max(0, min(y+yMaxCurr, w));
113 |
114 | const int32_t bAdv = max(0, min(x+xMaxCurr+1, h));
115 | const int32_t rAdv = max(0, min(y+yMaxCurr+1, w));
116 | const int32_t tAdv = max(0, min(x+xMinCurr-1, h));
117 | const int32_t lAdv = max(0, min(y+yMinCurr-1, w));
118 |
119 | scalar_t outValue;
120 |
121 | // -- main area
122 | outValue =
123 | inputIntPlane[b][r]
124 | - inputIntPlane[t][r]
125 | - inputIntPlane[b][l]
126 | + inputIntPlane[t][l];
127 |
128 | // -- xMax border
129 | outValue +=
130 | ( inputIntPlane[bAdv][r]
131 | - inputIntPlane[b ][r]
132 | - inputIntPlane[bAdv][l]
133 | + inputIntPlane[b ][l]) * xMaxCurrFrac;
134 |
135 | // -- yMax border
136 | outValue +=
137 | ( inputIntPlane[b][rAdv]
138 | - inputIntPlane[b][r ]
139 | - inputIntPlane[t][rAdv]
140 | + inputIntPlane[t][r ]) * yMaxCurrFrac;
141 |
142 | // -- xMin border
143 | outValue +=
144 | ( inputIntPlane[t ][r]
145 | - inputIntPlane[tAdv][r]
146 | - inputIntPlane[t ][l]
147 | + inputIntPlane[tAdv][l]) * xMinCurrFrac;
148 |
149 | // -- yMin border
150 | outValue +=
151 | ( inputIntPlane[b][l ]
152 | - inputIntPlane[b][lAdv]
153 | - inputIntPlane[t][l ]
154 | + inputIntPlane[t][lAdv]) * yMinCurrFrac;
155 |
156 | // -- corner pixels
157 | // Note: before, I used plain `input` to access corner values
158 | // with lower memory access overhead. Moved to `input_integrated`
159 | // to get rid of an extra input to this function.
160 | if (not ((x+xMaxCurr >= h) | (y+yMaxCurr >= w) |
161 | (x+xMaxCurr < 0) | (y+yMaxCurr < 0))) {
162 | outValue +=
163 | xMaxCurrFrac * yMaxCurrFrac *
164 | ( inputIntPlane[b+1][r+1]
165 | - inputIntPlane[b ][r+1]
166 | - inputIntPlane[b+1][r ]
167 | + inputIntPlane[b ][r ]);
168 | }
169 |
170 | if (not ((x+xMinCurr > h) | (y+yMaxCurr >= w) |
171 | (x+xMinCurr <= 0) | (y+yMaxCurr < 0))) {
172 | outValue +=
173 | xMinCurrFrac * yMaxCurrFrac *
174 | ( inputIntPlane[t ][r+1]
175 | - inputIntPlane[t-1][r+1]
176 | - inputIntPlane[t ][r ]
177 | + inputIntPlane[t-1][r ]);
178 | }
179 |
180 | if (not ((x+xMaxCurr >= h) | (y+yMinCurr > w) |
181 | (x+xMaxCurr < 0) | (y+yMinCurr <= 0))) {
182 | outValue +=
183 | xMaxCurrFrac * yMinCurrFrac *
184 | ( inputIntPlane[b+1][l ]
185 | - inputIntPlane[b ][l ]
186 | - inputIntPlane[b+1][l-1]
187 | + inputIntPlane[b ][l-1]);
188 | }
189 |
190 | if (not ((x+xMinCurr > h) | (y+yMinCurr > w) |
191 | (x+xMinCurr <= 0) | (y+yMinCurr <= 0))) {
192 | outValue +=
193 | xMinCurrFrac * yMinCurrFrac *
194 | ( inputIntPlane[t ][l ]
195 | - inputIntPlane[t-1][l ]
196 | - inputIntPlane[t ][l-1]
197 | + inputIntPlane[t-1][l-1]);
198 | }
199 |
200 | // TODO error: expression must be a modifiable lvalue
201 | output.data()[(blockIdx.x * h + x) * w + y] =
202 | outValue * (normalize ? area[paramIdx] : static_cast<scalar_t>(1));
203 | }
204 | }
205 |
206 | // TODO put split params and area into constant memory
207 | template <bool normalize, bool exact>
208 | void boxConvUpdateOutput(
209 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
210 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac,
211 | at::Tensor & area, at::Tensor & input_integrated, at::Tensor & output) {
212 |
213 | // was `const int`, but had to remove `const` to work around a bug in GCC 5
214 | int h = output.size(-2);
215 | int w = output.size(-1);
216 | const int totalOutputChannels = output.numel() / (h * w);
217 |
218 | const dim3 blockSize(1, 32, 32);
219 | const dim3 gridSize(
220 | (totalOutputChannels + blockSize.x - 1) / blockSize.x,
221 | (w + blockSize.y - 1) / blockSize.y,
222 | (h + blockSize.z - 1) / blockSize.z);
223 |
224 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(output.scalar_type(), "gpu::boxConvUpdateOutput", ([&] {
225 |
226 | auto inputIntFlattened = input_integrated.view({-1, h+1, w+1});
227 | auto inputIntAcsr =
228 | inputIntFlattened.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>();
229 |
230 | auto outputAcsr = output.packed_accessor32<scalar_t, 5, at::RestrictPtrTraits>();
231 |
232 | if (exact) {
233 | boxConvUpdateOutputKernel <scalar_t, normalize>
234 | <<<gridSize, blockSize, 0, at::cuda::getCurrentCUDAStream()>>> (
235 | inputIntAcsr, outputAcsr,
236 | xMinInt.data_ptr<int32_t>(), xMaxInt.data_ptr<int32_t>(),
237 | yMinInt.data_ptr<int32_t>(), yMaxInt.data_ptr<int32_t>(),
238 | xMinFrac.data_ptr<scalar_t>(), xMaxFrac.data_ptr<scalar_t>(),
239 | yMinFrac.data_ptr<scalar_t>(), yMaxFrac.data_ptr<scalar_t>(),
240 | normalize ? area.data_ptr<scalar_t>() : nullptr);
241 | } else {
242 | boxConvUpdateOutputKernel <scalar_t, normalize>
243 | <<<gridSize, blockSize, 0, at::cuda::getCurrentCUDAStream()>>> (
244 | inputIntAcsr, outputAcsr,
245 | xMinInt.data_ptr<int32_t>(), xMaxInt.data_ptr<int32_t>(),
246 | yMinInt.data_ptr<int32_t>(), yMaxInt.data_ptr<int32_t>(),
247 | normalize ? area.data_ptr<scalar_t>() : nullptr);
248 | }
249 | THCudaCheck(cudaGetLastError());
250 | }));
251 | }
252 |
253 | // explicitly instantiate
254 | template void boxConvUpdateOutput<true, true>(
255 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
256 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
257 | at::Tensor &, at::Tensor &, at::Tensor &);
258 |
259 | template void boxConvUpdateOutput<false, true>(
260 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
261 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
262 | at::Tensor &, at::Tensor &, at::Tensor &);
263 |
264 | template void boxConvUpdateOutput<true, false>(
265 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
266 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
267 | at::Tensor &, at::Tensor &, at::Tensor &);
268 |
269 | template void boxConvUpdateOutput<false, false>(
270 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
271 | at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &,
272 | at::Tensor &, at::Tensor &, at::Tensor &);
273 |
274 | }
275 |
--------------------------------------------------------------------------------
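(Annotation, not part of the repository: `gpu::boxConvUpdateOutput` above launches a 3-D grid with blockSize (1, 32, 32): one grid-x slot per output channel and 32x32 spatial tiles over each output plane. A small standalone sketch of the resulting grid arithmetic; the batch/plane/filter sizes below are made-up example values:

    #include <cstdio>

    int main() {
        // Example only: batch = 2, input planes = 16, filters = 4.
        const int totalOutputChannels = 2 * 16 * 4;
        const int h = 96, w = 128;
        const int bx = 1, by = 32, bz = 32;                 // blockSize from the launch above
        const int gx = (totalOutputChannels + bx - 1) / bx; // one block column per channel
        const int gy = (w + by - 1) / by;                   // tiles along width
        const int gz = (h + bz - 1) / bz;                   // tiles along height
        std::printf("grid = (%d, %d, %d)\n", gx, gy, gz);   // prints: grid = (128, 4, 3)
        return 0;
    }
)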
/src/box_convolution_cuda_misc.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | 
3 | #include <ATen/cuda/CUDAContext.h>
4 | #include <THC/THC.h>
5 |
6 | #include "box_convolution.h" // for `enum class Parameter`
7 |
8 | #define BLOCK_SIZE 256
9 |
10 | using std::min;
11 | using std::max;
12 |
13 | namespace gpu {
14 |
15 | // TODO make sure xMin and yMin threads don't fall into same warp
16 | template <typename scalar_t>
17 | __global__ void splitParametersKernel(
18 | const scalar_t * __restrict__ xMin, const scalar_t * __restrict__ xMax,
19 | const scalar_t * __restrict__ yMin, const scalar_t * __restrict__ yMax,
20 | int32_t * __restrict__ xMinInt, int32_t * __restrict__ xMaxInt,
21 | int32_t * __restrict__ yMinInt, int32_t * __restrict__ yMaxInt,
22 | scalar_t * __restrict__ xMinFrac, scalar_t * __restrict__ xMaxFrac,
23 | scalar_t * __restrict__ yMinFrac, scalar_t * __restrict__ yMaxFrac,
24 | const int nParameters) {
25 |
26 | const int idx = BLOCK_SIZE * blockIdx.x + threadIdx.x;
27 | if (idx < 2 * nParameters) {
28 | const int paramIndex = idx < nParameters ? idx : idx - nParameters;
29 |
30 | const scalar_t *param;
31 | scalar_t *fracParam;
32 | int32_t *intParam;
33 |
34 | param = idx < nParameters ? xMin : yMin;
35 | fracParam = idx < nParameters ? xMinFrac : yMinFrac;
36 | intParam = idx < nParameters ? xMinInt : yMinInt;
37 |
38 | const scalar_t minInt = std::ceil(param[paramIndex]);
39 | fracParam[paramIndex] = minInt - param[paramIndex];
40 | intParam[paramIndex] = static_cast<int32_t>(minInt);
41 |
42 | param = idx < nParameters ? xMax : yMax;
43 | fracParam = idx < nParameters ? xMaxFrac : yMaxFrac;
44 | intParam = idx < nParameters ? xMaxInt : yMaxInt;
45 |
46 | const scalar_t maxInt = std::floor(param[paramIndex]);
47 | fracParam[paramIndex] = param[paramIndex] - maxInt;
48 | intParam[paramIndex] = static_cast<int32_t>(maxInt) + 1;
49 | }
50 | }
51 |
52 | void splitParameters(
53 | at::Tensor & x_min , at::Tensor & x_max , at::Tensor & y_min , at::Tensor & y_max ,
54 | at::Tensor & xMinInt , at::Tensor & xMaxInt , at::Tensor & yMinInt , at::Tensor & yMaxInt ,
55 | at::Tensor & xMinFrac, at::Tensor & xMaxFrac, at::Tensor & yMinFrac, at::Tensor & yMaxFrac) {
56 |
57 | const int threadsNeeded = 2 * x_min.numel();
58 | const int numBlocks = (threadsNeeded + BLOCK_SIZE - 1) / BLOCK_SIZE;
59 |
60 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x_min.scalar_type(), "gpu::splitParameters", ([&] {
61 | splitParametersKernel <scalar_t>
62 | <<<numBlocks, BLOCK_SIZE, 0, at::cuda::getCurrentCUDAStream()>>> (
63 | x_min.data_ptr<scalar_t>(), x_max.data_ptr