├── .gitignore
├── .gitmodules
├── LICENSE
├── Readme.md
├── corenet.py
├── data.py
├── data
├── 000010_10_left.png
├── 000010_10_right.png
├── frame_0019.flo
├── frame_0019.png
├── frame_0020.png
├── frankfurt_val.png
├── im0.png
├── im1.png
├── output
│ ├── flow
│ │ └── .keep
│ ├── semantic
│ │ ├── .keep
│ │ └── sem_pred.png
│ └── stereo
│ │ └── .keep
├── params
│ ├── flow
│ │ ├── BP+MS (H)
│ │ │ ├── affinity_best.cpt
│ │ │ ├── crf0_lvl0_best.cpt
│ │ │ ├── crf0_lvl1_best.cpt
│ │ │ ├── crf0_lvl2_best.cpt
│ │ │ ├── matching_lvl0_best.cpt
│ │ │ ├── matching_lvl1_best.cpt
│ │ │ ├── matching_lvl2_best.cpt
│ │ │ └── unary_best.cpt
│ │ └── BP+MS+Ref (H)
│ │ │ ├── affinity_best.cpt
│ │ │ ├── crf0_lvl0_best.cpt
│ │ │ ├── crf0_lvl1_best.cpt
│ │ │ ├── crf0_lvl2_best.cpt
│ │ │ ├── matching_lvl0_best.cpt
│ │ │ ├── matching_lvl1_best.cpt
│ │ │ ├── matching_lvl2_best.cpt
│ │ │ ├── refinement_best.cpt
│ │ │ └── unary_best.cpt
│ ├── semantic
│ │ ├── global_model.cpt
│ │ └── pixel_model.cpt
│ └── stereo
│ │ ├── BP+MS (H)
│ │ ├── affinity_best.cpt
│ │ ├── crf0_lvl0_best.cpt
│ │ ├── crf0_lvl1_best.cpt
│ │ ├── crf0_lvl2_best.cpt
│ │ ├── matching_lvl0_best.cpt
│ │ ├── matching_lvl1_best.cpt
│ │ ├── matching_lvl2_best.cpt
│ │ └── unary_best.cpt
│ │ ├── BP+MS (NLL)
│ │ ├── affinity_best.cpt
│ │ ├── crf0_lvl0_best.cpt
│ │ ├── crf0_lvl1_best.cpt
│ │ ├── crf0_lvl2_best.cpt
│ │ ├── matching_lvl0_best.cpt
│ │ ├── matching_lvl1_best.cpt
│ │ ├── matching_lvl2_best.cpt
│ │ └── unary_best.cpt
│ │ ├── BP+MS+Ref (H)
│ │ ├── affinity_best.cpt
│ │ ├── crf0_lvl0_best.cpt
│ │ ├── crf0_lvl1_best.cpt
│ │ ├── crf0_lvl2_best.cpt
│ │ ├── matching_lvl0_best.cpt
│ │ ├── matching_lvl1_best.cpt
│ │ ├── matching_lvl2_best.cpt
│ │ ├── refinement_best.cpt
│ │ └── unary_best.cpt
│ │ ├── Kitti
│ │ ├── affinity_best.cpt
│ │ ├── crf0_lvl0_best.cpt
│ │ ├── crf0_lvl1_best.cpt
│ │ ├── crf0_lvl2_best.cpt
│ │ ├── matching_lvl0_best.cpt
│ │ ├── matching_lvl1_best.cpt
│ │ ├── matching_lvl2_best.cpt
│ │ └── unary_best.cpt
│ │ ├── MB
│ │ ├── affinity_best.cpt
│ │ ├── crf0_lvl0_best.cpt
│ │ ├── crf0_lvl1_best.cpt
│ │ ├── crf0_lvl2_best.cpt
│ │ ├── matching_lvl0_best.cpt
│ │ ├── matching_lvl1_best.cpt
│ │ ├── matching_lvl2_best.cpt
│ │ └── unary_best.cpt
│ │ └── WTA (NLL)
│ │ ├── matching_best.cpt
│ │ └── unary_best.cpt
├── sf_0006_left.png
└── sf_0006_right.png
├── flow.py
├── flow_matching.py
├── flow_tools.py
├── github_imgs
├── flow_example.png
├── kitti.png
├── mb.png
├── sem_input.png
├── sem_pred.png
├── sf.png
└── teaser.gif
├── main_flow.py
├── main_semantic.py
├── main_stereo.py
├── matching.py
├── networks.py
├── ops
├── flow_mp_sad
│ ├── flow_mp_sad.py
│ └── src
│ │ ├── flow_mp_sad.cpp
│ │ ├── flow_mp_sad_kernel.cu
│ │ └── flow_mp_sad_kernel.cuh
├── include
│ ├── error_util.h
│ └── tensor.h
├── lbp_semantic_pw
│ ├── bp_op_cuda.py
│ ├── inference_op.py
│ ├── message_passing_op_cuda.py
│ ├── setup.py
│ └── src
│ │ ├── lbp.cpp
│ │ ├── lbp_min_sum_kernel.cu
│ │ ├── lbp_min_sum_kernel.cuh
│ │ └── util.cuh
├── lbp_semantic_pw_pixel
│ ├── bp_op_cuda.py
│ ├── inference_op.py
│ ├── message_passing_op_pw_pixel.py
│ ├── setup.py
│ └── src
│ │ ├── lbp.cpp
│ │ ├── lbp_min_sum_kernel.cu
│ │ ├── lbp_min_sum_kernel.cuh
│ │ └── util.cuh
├── lbp_stereo
│ ├── bp_op_cuda.py
│ ├── inference_op.py
│ ├── message_passing_op_cuda.py
│ ├── setup.py
│ └── src
│ │ ├── lbp.cpp
│ │ ├── lbp_min_sum_kernel.cu
│ │ ├── lbp_min_sum_kernel.cuh
│ │ └── util.cuh
├── sad
│ ├── src
│ │ ├── stereo_sad.cpp
│ │ ├── stereo_sad_kernel.cu
│ │ └── stereo_sad_kernel.cuh
│ └── stereo_sad.py
└── setup.py
├── run_flow.sh
├── run_semantic_global.sh
├── run_semantic_pixel.sh
├── run_stereo_kitti.sh
├── run_stereo_mb.sh
├── run_stereo_sf.sh
├── semantic_segmentation.py
└── stereo.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.fls
2 | *.fdb_latexmk
3 | *.aux
4 | *.log
5 | *.out
6 | *.pdf
7 | *.DS_Store
8 | .idea
9 | .vscode
10 | *.gz
11 | images
12 | build
13 | __pycache__
14 | *egg*
15 | .ipynb_checkpoints
16 | *.pyc
17 | *.zip
18 | *.tar
19 | *.nav
20 | *.snm
21 | *.toc
22 | *.brf
23 | *.blg
24 | *.bbl
25 | output
26 |
27 |
28 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "dependencies/ESPNet"]
2 | path = dependencies/ESPNet
3 | url = https://github.com/sacmehta/ESPNet.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020
4 | Patrick Knoebelreiter and Christian Sormann
5 | Institute for Computer Graphics and Vision, Graz University of Technology
6 | https://www.tugraz.at/institute/icg/research/team-pock/
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy
9 | of this software and associated documentation files (the "Software"), to deal
10 | in the Software without restriction, including without limitation the rights
11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | copies of the Software, and to permit persons to whom the Software is
13 | furnished to do so, subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included in all
16 | copies or substantial portions of the Software.
17 |
18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | SOFTWARE.
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | # BP-Layers
2 | This repository contains the implementation for our publication "Belief Propagation Reloaded: Learning BP-Layers for Labeling Problems". If you use this implementation please cite the following publication:
3 |
4 | ~~~
5 | @InProceedings{Knobelreiter_2020_CVPR,
6 | author = {Knöbelreiter, Patrick and Sormann, Christian and Shekhovtsov, Alexander and Fraundorfer, Friedrich and Pock, Thomas},
7 | title = {Belief Propagation Reloaded: Learning BP-Layers for Labeling Problems},
8 | booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
9 | month = {June},
10 | year = {2020}
11 | }
12 | ~~~
13 |
14 |

15 |
16 | ## Repository Structure
17 |
18 | The repository is structured as follows:
19 | - the base directory contains scripts for running inference and python implementations of the networks
20 | - 'data' includes sample images for stereo/semantic/flow inference
21 | - 'ops' contains custom PyTorch-Ops which need to be installed before running the respective stereo/semantic implementation (note that this is also taken care of by the run_*.sh scripts)
22 |
23 | For your convenience, the required libraries are added as submodules. To clone them issue the command
24 |
25 | ~~~
26 | git submodule update --init --recursive
27 | ~~~
28 |
29 | ## Dependencies
30 |
31 | * Cuda 10.2
32 | * pytorch >= 1.3
33 | * argparse
34 | * imageio (with libpfm installed)*
35 | * numpy
36 |
37 | The stereo results are saved as PFM images. If your imageio installation does not have libpfm installed automatically, execute the following command in a Python interpreter (after `import imageio`):
38 |
39 | ~~~
40 | imageio.plugins.freeimage.download()
41 | ~~~
42 |
43 | In order to display pfm files we highly recommend the tool provided by the Middlebury stereo benchmark. You can find it here.
44 |
45 |
46 | ## Running the implementation
47 |
48 | After installing all of the required dependencies above, you need to install the provided modules into your Python environment. This can be done with
49 |
50 | ~~~
51 | cd ops
52 | python setup.py install
53 | ~~~
54 |
55 | This will install the SAD matching kernels for stereo and Optical flow. The BP-Layer is installed automatically upon execution of the provided shell scripts. The following sections show how to use them.
56 |
57 |
58 | ### Stereo
59 | * run_stereo_sf.sh: The models trained on the Scene-Flow Dataset
60 | * run_stereo_mb.sh: The model used for evaluation on the Middlebury dataset
61 | * run_stereo_kitti.sh: The model used for evaluation on the Kitti dataset
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 | ### Flow
74 | * run_flow.sh
75 |
76 |
77 | 
78 |
79 | ### Semantic Segmentation
80 | * run_semantic_global.sh: Our semantic segmentation model with *global* pairwise weights
81 | * run_semantic_pixel.sh: Our semantic segmentation model with *pixel-wise* pairwise weights
82 |
83 |
84 |
85 |
86 | ~~~
87 | sh run_semantic_pixel.sh
88 | ~~~
89 |
90 | Should yield this result:
91 |
92 |
93 | 
94 |
95 | Inside these scripts you can also specify different images to be used for inference. The correct PyTorch-Ops are also automatically installed by these scripts before running the inference.
--------------------------------------------------------------------------------
/corenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import numpy as np
6 |
class CvConfidence(nn.Module):
    """Pick, for every pixel, the volume entry at that pixel's rounded disparity."""

    def __init__(self, device):
        super(CvConfidence, self).__init__()
        self._device = device

    def forward(self, prob_volume, disps):
        """Gather ``prob_volume`` along its second axis at ``round(disps)``.

        ``prob_volume`` is unpacked as a 4-D tensor (batch, labels, height,
        width). ``disps`` is used as the index along the label axis.
        # NOTE(review): disps is presumably (batch, 1, height, width) -- confirm
        # against the caller.
        """
        batch, _, height, width = prob_volume.shape
        dev = prob_volume.device

        def index_grid(count, view_shape, tiling):
            # Integer coordinates 0..count-1, tiled to (batch, 1, height, width).
            base = torch.arange(count, device=dev, dtype=torch.long)
            return base.view(view_shape).repeat(tiling)

        batch_idx = index_grid(batch, (batch, 1, 1, 1), (1, 1, height, width))
        col_idx = index_grid(width, (width,), (batch, 1, height, 1))
        row_idx = index_grid(height, (height, 1), (batch, 1, 1, width))

        label_idx = torch.round(disps).long()
        return prob_volume[batch_idx, label_idx, row_idx, col_idx]
30 |
class LrDistance(nn.Module):
    """Left-right consistency distance between two disparity maps."""

    def __init__(self, device):
        super(LrDistance, self).__init__()
        # Kept for interface compatibility; forward() follows the input's device.
        self._device = device

    def forward(self, disps_lr, disps_rl):
        """Warp ``disps_rl`` with ``disps_lr`` and return ``|lr + warped_rl|``.

        Both inputs are cast to float and unpacked as 4-D tensors; only
        channel 0 of the sampling coordinates is used.
        # NOTE(review): inputs are presumably single-channel (S, 1, M, N)
        # disparity maps -- confirm against the caller.

        Pixels whose warped x-coordinate falls outside ``[0, N)`` are set to
        100.0.
        """
        dispsLR = disps_lr.float()
        dispsRL = disps_rl.float()

        S, _, M, N = dispsLR.shape
        dev = dispsLR.device  # build coordinates where the data lives

        # generate coordinates
        x_coords = torch.arange(N, device=dev, dtype=torch.float)
        y_coords = torch.arange(M, device=dev, dtype=torch.float)

        xl = x_coords.view(1, 1, 1, N).expand(S, 1, M, N)
        yl = y_coords.view(1, 1, M, 1).expand(S, 1, M, N)

        xr = xl - dispsLR

        # normalize coordinates for sampling between [-1, 1]; pixel 0 maps to
        # -1 and the last pixel to +1 (guard against a size-1 axis)
        xr_normed = (2 * xr / max(N - 1, 1)) - 1.0
        yl_normed = (2 * yl / max(M - 1, 1)) - 1.0

        # coords must have shape N x OH x OW x 2
        sample_coords = torch.stack((xr_normed[:, 0], yl_normed[:, 0]), dim=-1)
        # align_corners=True matches the (size - 1) normalisation above; the
        # library default (False) would shift every sample by a sub-pixel.
        dispsRL_warped = F.grid_sample(dispsRL, sample_coords, align_corners=True)

        lr_distance = torch.abs(dispsLR + dispsRL_warped)
        lr_distance[(xr >= N) | (xr < 0)] = 100.0

        return lr_distance
73 |
class LrCheck(nn.Module):
    """Threshold left-right distances into a binary mask and a soft confidence."""

    def __init__(self, device, eps):
        super(LrCheck, self).__init__()
        # Kept for interface compatibility; forward() no longer depends on it.
        self._device = device
        self._eps = eps

    def forward(self, lr_dists):
        """Return ``(lr_mask, confidence)`` for the given distances.

        ``lr_mask`` is 1 where ``lr_dists <= eps`` and 0 where it is larger.
        ``confidence`` is ``max(eps - lr_dists, 0) / eps``: 1 at distance 0,
        falling linearly to 0 at ``eps`` and beyond.
        """
        lr_mask = torch.ones_like(lr_dists)
        lr_mask[lr_dists > self._eps] = 0.0

        # clamp needs no helper tensor, so it works on whatever device
        # lr_dists lives on.
        confidence = torch.clamp(self._eps - lr_dists, min=0.0) / self._eps

        return lr_mask, confidence
96 |
class TemperatureSoftmax(nn.Module):
    """Softmax over ``dim`` with a learnable temperature ``T``."""

    def __init__(self, dim, init_temp=1.0):
        nn.Module.__init__(self)
        # Place T on the GPU when one is present (as before), but fall back to
        # the CPU instead of crashing on machines without CUDA.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.T = nn.Parameter(torch.ones(1, dtype=torch.float, device=device) * init_temp,
                              requires_grad=True)
        self.dim = dim

    def forward(self, x):
        """Return ``softmax(x / T)`` along ``self.dim``."""
        return F.softmax(x / self.T, dim=self.dim)
105 |
class TemperatureSoftmin(nn.Module):
    """Softmin over ``dim`` with a learnable temperature ``T``."""

    def __init__(self, dim, init_temp=1.0):
        nn.Module.__init__(self)
        # Place T on the GPU when one is present (as before), but fall back to
        # the CPU instead of crashing on machines without CUDA.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.T = nn.Parameter(torch.ones(1, dtype=torch.float, device=device) * init_temp,
                              requires_grad=True)
        self.dim = dim

    def forward(self, x):
        """Return ``softmin(x / T)`` along ``self.dim``."""
        return F.softmin(x / self.T, dim=self.dim)
114 |
class Pad(nn.Module):
    """Reflect-pad a 4-D tensor so height and width become multiples of ``divisor``.

    The amounts applied by the most recent call are remembered in
    ``self.l``, ``self.r``, ``self.t`` and ``self.b`` (left, right, top,
    bottom).
    """

    def __init__(self, divisor, extra_pad=(0, 0)):
        nn.Module.__init__(self)
        self.divisor = divisor
        # extra_pad = (rows added at top AND bottom, columns added at left AND right)
        self.extra_pad_h = extra_pad[0]
        self.extra_pad_w = extra_pad[1]

        self.l = self.r = self.t = self.b = 0

    def pad(self, x):
        """Pad ``x`` and record the per-side padding on ``self``."""
        _, _, height, width = x.shape

        def shortfall(size):
            # Smallest non-negative amount that makes `size` divisible.
            extra = 0
            while (size + extra) % self.divisor != 0:
                extra += 1
            return extra

        w_add = shortfall(width)
        h_add = shortfall(height)

        # Split the shortfall over both sides; an odd remainder goes to the
        # left / top. The extra padding is added on every side.
        self.l = self.extra_pad_w + np.ceil(w_add / 2.0).astype('int')
        self.r = self.extra_pad_w + np.floor(w_add / 2.0).astype('int')
        self.t = self.extra_pad_h + np.ceil(h_add / 2.0).astype('int')
        self.b = self.extra_pad_h + np.floor(h_add / 2.0).astype('int')

        return F.pad(x, (self.l, self.r, self.t, self.b), mode='reflect')

    def forward(self, x):
        return self.pad(x)
151 |
class Unpad(nn.Module):
    """Crop previously added padding from 4-D or 5-D tensors."""

    def __init__(self):
        nn.Module.__init__(self)

    @staticmethod
    def _spans(l, r, t, b):
        # Row / column slices: always start at t / l; trim the end only when
        # the bottom / right amount is positive.
        rows = slice(t, -b if b > 0 else None)
        cols = slice(l, -r if r > 0 else None)
        return rows, cols

    def unpad_NCHW(self, x, l, r, t, b):
        """Crop axes 2 (height) and 3 (width)."""
        rows, cols = self._spans(l, r, t, b)
        return x[:, :, rows, cols].contiguous()

    def unpad_NHWC(self, x, l, r, t, b):
        """Crop axes 1 (height) and 2 (width)."""
        rows, cols = self._spans(l, r, t, b)
        return x[:, rows, cols, :].contiguous()

    def unpad_flow_N2HWC(self, x, l, r, t, b):
        """Crop axes 2 (height) and 3 (width) of a 5-D tensor."""
        rows, cols = self._spans(l, r, t, b)
        return x[:, :, rows, cols, :].contiguous()

    def unpad_flow_N2CHW(self, x, l, r, t, b):
        """Crop axes 3 (height) and 4 (width) of a 5-D tensor."""
        rows, cols = self._spans(l, r, t, b)
        return x[:, :, :, rows, cols].contiguous()

    def forward(self, x, l, r, t, b, NCHW=True):
        """Dispatch on the layout flag and on whether ``x`` is 4-D."""
        is_4d = len(x.shape) == 4
        if NCHW:
            crop = self.unpad_NCHW if is_4d else self.unpad_flow_N2CHW
        else:
            crop = self.unpad_NHWC if is_4d else self.unpad_flow_N2HWC
        return crop(x, l, r, t, b)
200 |
class PadUnpad(Pad, Unpad):
    """Run a wrapped network on a padded input and crop each of its outputs."""

    def __init__(self, net, divisor=1, extra_pad=(0, 0)):
        Pad.__init__(self, divisor, extra_pad)
        Unpad.__init__(self)

        self.net = net

    def forward(self, ipt):
        """Pad ``ipt``, call ``net.forward`` and return the cropped outputs as a list.

        The network output is iterated, so it must be a sequence of tensors;
        each one is cropped with the padding recorded by the ``pad`` call.
        """
        outputs = self.net.forward(self.pad(ipt))
        return [self.unpad_NCHW(o, self.l, self.r, self.t, self.b) for o in outputs]
214 |
class PaddedConv2d(nn.Module):
    """2-D convolution preceded by reflect padding of ``(kernel // 2) * dilation``.

    The wrapped ``nn.Conv2d`` itself uses ``padding=0``. For odd kernels and
    stride 1 the output has the same spatial size as the input.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=False):
        super(PaddedConv2d, self).__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=0, dilation=dilation, bias=bias)

        self.dilation = dilation

    def forward(self, x):
        """Reflect-pad ``x`` per axis, then convolve."""
        # nn.Conv2d stores kernel_size and dilation as (height, width) tuples,
        # so non-square kernels and per-axis dilations get the right padding.
        k_h, k_w = self.conv.kernel_size
        d_h, d_w = self.conv.dilation
        pad_h = (k_h // 2) * d_h
        pad_w = (k_w // 2) * d_w

        # F.pad order is (left, right, top, bottom)
        x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), mode='reflect')
        return self.conv(x)
229 |
class ResidualBlock(nn.Module):
    """Conv-norm-act, then conv-norm, add the input, then a final activation."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0,
                 activation='ReLU', transposed=False, with_bn=True, leaky_alpha=1e-2):
        super(ResidualBlock, self).__init__()

        self.convbnact1 = ConvBatchNormAct(in_channels, out_channels, kernel_size,
                                           stride=stride, dilation=dilation, padding=padding,
                                           activation=activation, transposed=transposed,
                                           with_bn=with_bn)

        # No conv bias: the affine group norm that follows already has one.
        self.conv2 = PaddedConv2d(out_channels, out_channels, kernel_size,
                                  stride=stride, dilation=dilation, bias=False)
        # One group per channel.
        self.bn2 = nn.GroupNorm(out_channels, out_channels, affine=True)

        activations = {
            'relu': lambda: nn.ReLU(inplace=True),
            'leakyrelu': lambda: nn.LeakyReLU(negative_slope=leaky_alpha, inplace=True),
            'elu': lambda: nn.ELU(inplace=True),
        }
        make_act = activations.get(activation.lower())
        if make_act is None:
            raise NotImplementedError("Activation " + activation + " is currently not implemented!")
        self.act2 = make_act()

    def forward(self, x):
        """Return ``act2(bn2(conv2(convbnact1(x))) + x)``.

        The skip connection adds ``x`` unchanged, so the main path must
        produce a tensor of the same shape as ``x``.
        """
        out = self.bn2(self.conv2(self.convbnact1(x)))
        out += x
        return self.act2(out)
262 |
263 |
264 |
class ConvBatchNormAct(nn.Module):
    """Convolution, optional normalisation, then an activation.

    Despite the name, the normalisation is ``nn.GroupNorm`` with one group
    per output channel. ``padding`` is forwarded only to the transposed
    convolution; the regular path uses ``PaddedConv2d``, which pads on its
    own. The transposed convolution keeps its default bias whatever
    ``with_bn`` is.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0,
                 activation='ReLU', transposed=False, with_bn=True):
        super(ConvBatchNormAct, self).__init__()

        if transposed:
            self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation)
        else:
            # A conv bias is only needed when no affine norm follows.
            self.conv = PaddedConv2d(in_channels, out_channels, kernel_size, stride=stride,
                                     dilation=dilation, bias=not with_bn)

        self.bn = nn.GroupNorm(out_channels, out_channels, affine=True) if with_bn else None

        activations = {
            'relu': nn.ReLU,
            'leakyrelu': nn.LeakyReLU,
            'elu': nn.ELU,
        }
        act_cls = activations.get(activation.lower())
        if act_cls is None:
            raise NotImplementedError("Activation " + activation + " is currently not implemented!")
        self.act = act_cls(inplace=True)

    def forward(self, x):
        """Apply conv, then the norm if present, then the activation."""
        out = self.conv(x)
        if self.bn is not None:
            out = self.bn(out)
        return self.act(out)
304 |
305 |
class BaseNet(nn.Module):
    """Base class exposing read-only metadata that subclasses fill in.

    All three fields start at -1, meaning "not set by the subclass".
    """

    def __init__(self):
        super(BaseNet, self).__init__()

        self._divisor = -1
        self._out_channels = -1
        # Initialised here so the num_output_levels property never hits a
        # missing attribute on subclasses that do not set it.
        self._num_output_levels = -1

    @property
    def divisor(self):
        """Value of ``_divisor`` (-1 until a subclass sets it)."""
        return self._divisor

    @property
    def out_channels(self):
        """Value of ``_out_channels`` (-1 until a subclass sets it)."""
        return self._out_channels

    @property
    def num_output_levels(self):
        """Value of ``_num_output_levels`` (-1 until a subclass sets it)."""
        return self._num_output_levels

    def forward(self, *input):
        """Subclasses must override this; the base class always raises."""
        raise NotImplementedError
327 |
328 |
class StereoUnaryUnetDyn(BaseNet):
    """U-Net style feature extractor with two pooling stages and skip connections.

    Reports ``divisor == 8`` and 32 output channels per level. With
    ``multi_level_output`` it also produces 32-channel outputs from the two
    coarser stages (three levels in total); otherwise one level.
    """

    def __init__(self, multi_level_output=False, activation='relu',
                 with_bn=True, with_upconv=False, with_output_bn=True):
        super(StereoUnaryUnetDyn, self).__init__()

        self.multi_level_output = multi_level_output
        self.with_upconv = with_upconv
        self.with_output_bn = with_output_bn
        self._divisor = 8

        n_levels = 3 if multi_level_output else 1
        self._num_output_levels = n_levels
        self._out_channels = [32] * n_levels

        def cba(c_in, c_out, padding=1, **extra):
            # 3x3 conv + norm + activation with the network-wide settings.
            return ConvBatchNormAct(c_in, c_out, 3, padding=padding, activation=activation,
                                    with_bn=with_bn, **extra)

        # NOTE: attributes are assigned in the original order so module
        # registration (parameters() / state_dict order) is unchanged.

        # encoder, full resolution
        self.conv1 = cba(3, 16)
        self.conv2 = cba(16, 16)
        self.pool1 = nn.MaxPool2d(2)

        # encoder, 1/2 resolution
        self.conv3 = cba(16, 32)
        self.conv4 = cba(32, 32)
        self.pool2 = nn.MaxPool2d(2)

        # bottleneck, 1/4 resolution
        self.conv5 = cba(32, 64)
        self.conv6 = cba(64, 64)

        # decoder, 1/2 resolution: input is lvl1 (32) concatenated with the
        # upsampled bottleneck (32 after upconv, 64 after interpolation)
        if self.with_upconv:
            self.upconv7 = cba(64, 32, padding=0, stride=2, transposed=True)
            self.conv8 = cba(64, 32)
        else:
            self.conv8 = cba(96, 32)
        self.conv9 = cba(32, 32)

        # decoder, full resolution: input is lvl0 (16) concatenated with the
        # upsampled features (16 after upconv, 32 after interpolation)
        if self.with_upconv:
            self.upconv10 = cba(32, 16, padding=0, stride=2, transposed=True)
            self.conv11 = cba(32, 32)
        else:
            self.conv11 = cba(48, 32)

        # output head: plain conv, optionally followed by a per-channel norm
        self.conv12 = PaddedConv2d(32, 32, 3)
        self.bn12 = None
        if with_bn and with_output_bn:
            self.bn12 = nn.GroupNorm(32, 32, affine=True)

        if self.multi_level_output:
            self.conv_lvl1 = PaddedConv2d(32, 32, 3)
            self.conv_lvl2 = PaddedConv2d(64, 32, 3)

    def forward(self, x_in):
        """Compute the feature maps for ``x_in``.

        Returns the tuple ``(x, lvl1_out, lvl2_out)`` when
        ``multi_level_output`` is set, otherwise the one-element list ``[x]``.
        """
        # encoder
        lvl0 = self.conv2(self.conv1(x_in))
        lvl1 = self.conv4(self.conv3(self.pool1(lvl0)))
        feat = self.conv6(self.conv5(self.pool2(lvl1)))

        if self.multi_level_output:
            lvl2_out = self.conv_lvl2(feat)

        # decoder, 1/2 resolution
        if self.with_upconv:
            # drop the first row/column of the transposed-conv output
            feat = self.upconv7(feat)[:, :, 1:, 1:]
        else:
            feat = F.interpolate(feat, lvl1.shape[2:], mode='bilinear')
        feat = self.conv9(self.conv8(torch.cat([lvl1, feat], dim=1)))

        if self.multi_level_output:
            lvl1_out = self.conv_lvl1(feat)

        # decoder, full resolution
        if self.with_upconv:
            # drop the last row/column of the transposed-conv output
            feat = self.upconv10(feat)[:, :, :-1, :-1]
        else:
            feat = F.interpolate(feat, lvl0.shape[2:], mode='bilinear')
        feat = self.conv12(self.conv11(torch.cat([lvl0, feat], dim=1)))

        if self.bn12 is not None:
            feat = self.bn12(feat)

        if self.multi_level_output:
            return feat, lvl1_out, lvl2_out

        return [feat]
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from imageio import imread
3 | from skimage.color import rgb2lab
4 | from skimage.transform import rescale
5 | import numpy as np
6 |
7 |
def get_lvl_name(lvl):
    """Return the key suffix for pyramid level ``lvl``, e.g. ``'_lvl0'``."""
    return '_lvl%s' % (lvl,)
10 |
def load_sample(left_path, right_path):
    """Load two images and return a 3-level Lab pyramid for each.

    Each image is read, a 4th (alpha) channel is dropped if present, and the
    image is converted to Lab at full, half and quarter scale. Every level
    is returned as a float32 tensor with the channel axis moved to the front
    and a leading batch axis of size 1.

    Returns:
        ``(I0_pyramid, I1_pyramid)``: two lists of three tensors each,
        ordered from full scale to quarter scale.
    """
    def read_rgb(path):
        img = imread(path)
        if img.shape[2] == 4:
            img = img[:, :, :3]  # remove alpha channel
        return img

    def lab_at_scale(img, lvl):
        # Level 0 converts the image as loaded; coarser levels halve the
        # size per level before the colour conversion.
        if lvl > 0:
            img = rescale(img, 1.0 / 2 ** lvl, order=1, anti_aliasing=True,
                          mode='reflect', multichannel=True)
        lab = rgb2lab(img).astype('float32')
        return torch.from_numpy(lab.transpose(2, 0, 1)).unsqueeze(0)

    # load left/right image
    left = read_rgb(left_path)
    right = read_rgb(right_path)

    # construct image pyramid
    I0_pyramid = [lab_at_scale(left, lvl) for lvl in range(3)]
    I1_pyramid = [lab_at_scale(right, lvl) for lvl in range(3)]

    return I0_pyramid, I1_pyramid
45 |
46 |
def readFlow(name):
    """Read a ``.flo`` optical-flow file into a ``(height, width, 2)`` float32 array.

    The layout read here is: the 4 bytes ``PIEH``, then width and height as
    int32, then ``width * height * 2`` float32 values.

    Raises:
        Exception: if the header is not ``PIEH`` or the file is truncated.
    """
    # The context manager closes the file on every path, including errors.
    with open(name, 'rb') as f:
        # Compare raw bytes: decoding arbitrary bytes as UTF-8 can fail before
        # the intended error is raised.
        if f.read(4) != b'PIEH':
            raise Exception('Flow file header does not contain PIEH')

        dims = np.fromfile(f, np.int32, 2)
        if dims.size != 2:
            raise Exception('Flow file is truncated: missing width/height')
        # Plain Python ints so the element count below cannot overflow int32.
        width, height = int(dims[0]), int(dims[1])

        count = width * height * 2
        data = np.fromfile(f, np.float32, count)
        if data.size != count:
            raise Exception('Flow file is truncated: missing flow data')

    return data.reshape((height, width, 2)).astype(np.float32)
--------------------------------------------------------------------------------
/data/000010_10_left.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/000010_10_left.png
--------------------------------------------------------------------------------
/data/000010_10_right.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/000010_10_right.png
--------------------------------------------------------------------------------
/data/frame_0019.flo:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/frame_0019.flo
--------------------------------------------------------------------------------
/data/frame_0019.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/frame_0019.png
--------------------------------------------------------------------------------
/data/frame_0020.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/frame_0020.png
--------------------------------------------------------------------------------
/data/frankfurt_val.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/frankfurt_val.png
--------------------------------------------------------------------------------
/data/im0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/im0.png
--------------------------------------------------------------------------------
/data/im1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/im1.png
--------------------------------------------------------------------------------
/data/output/flow/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/output/flow/.keep
--------------------------------------------------------------------------------
/data/output/semantic/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/output/semantic/.keep
--------------------------------------------------------------------------------
/data/output/semantic/sem_pred.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/output/semantic/sem_pred.png
--------------------------------------------------------------------------------
/data/output/stereo/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/output/stereo/.keep
--------------------------------------------------------------------------------
/data/params/flow/BP+MS (H)/affinity_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS (H)/affinity_best.cpt
--------------------------------------------------------------------------------
/data/params/flow/BP+MS (H)/crf0_lvl0_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS (H)/crf0_lvl0_best.cpt
--------------------------------------------------------------------------------
/data/params/flow/BP+MS (H)/crf0_lvl1_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS (H)/crf0_lvl1_best.cpt
--------------------------------------------------------------------------------
/data/params/flow/BP+MS (H)/crf0_lvl2_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS (H)/crf0_lvl2_best.cpt
--------------------------------------------------------------------------------
/data/params/flow/BP+MS (H)/matching_lvl0_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS (H)/matching_lvl0_best.cpt
--------------------------------------------------------------------------------
/data/params/flow/BP+MS (H)/matching_lvl1_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS (H)/matching_lvl1_best.cpt
--------------------------------------------------------------------------------
/data/params/flow/BP+MS (H)/matching_lvl2_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS (H)/matching_lvl2_best.cpt
--------------------------------------------------------------------------------
/data/params/flow/BP+MS (H)/unary_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS (H)/unary_best.cpt
--------------------------------------------------------------------------------
/data/params/flow/BP+MS+Ref (H)/affinity_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/affinity_best.cpt
--------------------------------------------------------------------------------
/data/params/flow/BP+MS+Ref (H)/crf0_lvl0_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/crf0_lvl0_best.cpt
--------------------------------------------------------------------------------
/data/params/flow/BP+MS+Ref (H)/crf0_lvl1_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/crf0_lvl1_best.cpt
--------------------------------------------------------------------------------
/data/params/flow/BP+MS+Ref (H)/crf0_lvl2_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/crf0_lvl2_best.cpt
--------------------------------------------------------------------------------
/data/params/flow/BP+MS+Ref (H)/matching_lvl0_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/matching_lvl0_best.cpt
--------------------------------------------------------------------------------
/data/params/flow/BP+MS+Ref (H)/matching_lvl1_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/matching_lvl1_best.cpt
--------------------------------------------------------------------------------
/data/params/flow/BP+MS+Ref (H)/matching_lvl2_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/matching_lvl2_best.cpt
--------------------------------------------------------------------------------
/data/params/flow/BP+MS+Ref (H)/refinement_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/refinement_best.cpt
--------------------------------------------------------------------------------
/data/params/flow/BP+MS+Ref (H)/unary_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/unary_best.cpt
--------------------------------------------------------------------------------
/data/params/semantic/global_model.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/semantic/global_model.cpt
--------------------------------------------------------------------------------
/data/params/semantic/pixel_model.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/semantic/pixel_model.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS (H)/affinity_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (H)/affinity_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS (H)/crf0_lvl0_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (H)/crf0_lvl0_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS (H)/crf0_lvl1_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (H)/crf0_lvl1_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS (H)/crf0_lvl2_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (H)/crf0_lvl2_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS (H)/matching_lvl0_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (H)/matching_lvl0_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS (H)/matching_lvl1_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (H)/matching_lvl1_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS (H)/matching_lvl2_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (H)/matching_lvl2_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS (H)/unary_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (H)/unary_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS (NLL)/affinity_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (NLL)/affinity_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS (NLL)/crf0_lvl0_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (NLL)/crf0_lvl0_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS (NLL)/crf0_lvl1_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (NLL)/crf0_lvl1_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS (NLL)/crf0_lvl2_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (NLL)/crf0_lvl2_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS (NLL)/matching_lvl0_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (NLL)/matching_lvl0_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS (NLL)/matching_lvl1_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (NLL)/matching_lvl1_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS (NLL)/matching_lvl2_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (NLL)/matching_lvl2_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS (NLL)/unary_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (NLL)/unary_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS+Ref (H)/affinity_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/affinity_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS+Ref (H)/crf0_lvl0_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/crf0_lvl0_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS+Ref (H)/crf0_lvl1_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/crf0_lvl1_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS+Ref (H)/crf0_lvl2_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/crf0_lvl2_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS+Ref (H)/matching_lvl0_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/matching_lvl0_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS+Ref (H)/matching_lvl1_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/matching_lvl1_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS+Ref (H)/matching_lvl2_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/matching_lvl2_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS+Ref (H)/refinement_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/refinement_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/BP+MS+Ref (H)/unary_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/unary_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/Kitti/affinity_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/Kitti/affinity_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/Kitti/crf0_lvl0_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/Kitti/crf0_lvl0_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/Kitti/crf0_lvl1_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/Kitti/crf0_lvl1_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/Kitti/crf0_lvl2_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/Kitti/crf0_lvl2_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/Kitti/matching_lvl0_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/Kitti/matching_lvl0_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/Kitti/matching_lvl1_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/Kitti/matching_lvl1_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/Kitti/matching_lvl2_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/Kitti/matching_lvl2_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/Kitti/unary_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/Kitti/unary_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/MB/affinity_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/MB/affinity_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/MB/crf0_lvl0_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/MB/crf0_lvl0_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/MB/crf0_lvl1_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/MB/crf0_lvl1_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/MB/crf0_lvl2_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/MB/crf0_lvl2_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/MB/matching_lvl0_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/MB/matching_lvl0_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/MB/matching_lvl1_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/MB/matching_lvl1_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/MB/matching_lvl2_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/MB/matching_lvl2_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/MB/unary_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/MB/unary_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/WTA (NLL)/matching_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/WTA (NLL)/matching_best.cpt
--------------------------------------------------------------------------------
/data/params/stereo/WTA (NLL)/unary_best.cpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/WTA (NLL)/unary_best.cpt
--------------------------------------------------------------------------------
/data/sf_0006_left.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/sf_0006_left.png
--------------------------------------------------------------------------------
/data/sf_0006_right.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/sf_0006_right.png
--------------------------------------------------------------------------------
/flow.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from networks import FeatureNet, AffinityNet, RefinementNet
6 | from flow_matching import FlowMatchingSad
7 | from ops.lbp_stereo.bp_op_cuda import BP
8 |
9 | from corenet import CvConfidence#, LrCheck, LrDistance,
10 | from corenet import Pad, Unpad
11 |
class FlowMethod(nn.Module):
    """Base class for the multi-scale optical-flow pipeline.

    The pipeline stages are optional and are filled in by subclasses through
    the private attributes set in ``__init__``: feature extraction, per-level
    matching, affinity estimation, per-level CRF (BP) layers and refinement.
    ``forward`` runs whichever stages are present and returns a dict whose
    ``'flow0'`` entry is the list of per-level flow estimates.
    """

    def __init__(self, device, args):
        """Store the configuration and create empty slots for every stage.

        Args:
            device: device the confidence module is moved to and inputs are
                copied to in ``forward``.
            args: configuration object; this class reads ``args.sws``,
                ``args.pad`` and ``args.input_level_offset``.
        """
        super().__init__()
        self.args = args

        self._feature_net = None
        self._matching = []
        self._affinity_net = None
        self._refinement_net = None
        # Backing fields of the offset_net / min_disp / max_disp properties.
        # They are initialised here so the properties return None instead of
        # raising AttributeError when no subclass sets them.
        self._offset_net = None
        self._min_disp = None
        self._max_disp = None
        self._crf = []  # list (per level) of lists of bp layers

        self.cv_conf = CvConfidence(device).to(device)

        self._sws = args.sws

        # Created lazily in forward() because Pad needs the feature net.
        self.pad = None
        self.unpad = None

        self.device = device

    def forward(self, I0_pyramid, I1_pyramid, beliefs_in=None, sws=None):
        """Estimate flow from two image pyramids.

        Args:
            I0_pyramid: indexable collection of first-image tensors; the entry
                at ``args.input_level_offset`` is used as network input.
            I1_pyramid: same for the second image.
            beliefs_in: accepted for interface compatibility but ignored; the
                coarse-to-fine loop always starts without incoming beliefs.
            sws: accepted for interface compatibility but currently unused.

        Returns:
            dict with key ``'flow0'``: the per-level argmax of the matching
            volumes when no CRF layers exist, otherwise the CRF flow pyramid,
            or the refined flow pyramid when a refinement net is present.
        """
        # necessary for evaluation
        if self.pad is None:
            self.pad = Pad(self.feature_net.net.divisor, self.args.pad)
        if self.unpad is None:
            self.unpad = Unpad()

        res_dict = {'flow0': None}

        I0_in = I0_pyramid[self.args.input_level_offset].to(self.device)
        I1_in = I1_pyramid[self.args.input_level_offset].to(self.device)

        # pad input for multi-scale (for evaluation)
        I0_in = self.pad.forward(I0_in)
        I1_in = self.pad.forward(I1_in)

        f0_pyramid = self.extract_features(I0_in)
        f1_pyramid = self.extract_features(I1_in)

        # multi-scale-matching
        prob_vol_pyramid = self.match(f0_pyramid, f1_pyramid)

        res_dict['flow0'] = [torch.argmax(pv0, dim=-1) for pv0 in prob_vol_pyramid]

        if not self._crf:
            return res_dict

        affinity_pyramid = None
        if self.affinity_net:
            affinity_pyramid = self.extract_affinities(I0_in)
            for lvl, level_affinity in enumerate(affinity_pyramid):
                _, _, h, w = level_affinity.shape
                # Regroup the channels into blocks of 2 x 5 and add a leading
                # axis; optimize_crf later indexes axis 1 by BP-layer index.
                affinity_pyramid[lvl] = level_affinity.view((-1, 2, 5, h, w)).unsqueeze(0)

        crf_flow_pyramid = []
        beliefs_pyramid = []
        # The caller-supplied beliefs_in is deliberately discarded here
        # (this matches the historical behaviour of this method).
        beliefs_in = None
        # Coarse-to-fine: start at the last level and finish at level 0.
        for lvl in reversed(range(len(prob_vol_pyramid))):
            pv_lvl = prob_vol_pyramid[lvl]
            m = self.matching[lvl]

            affinity = None
            if affinity_pyramid is not None:
                affinity = affinity_pyramid[lvl]
            crf = self.crf[lvl]

            # add probably an if condition whether do add multi-scale to crf
            if beliefs_in is not None:
                # Upsample the previous level's beliefs (u and v separately)
                # to twice the spatial size and 2*K-1 labels, then add half of
                # them to this level's matching volume.
                N, _, H, W, K = beliefs_in.shape
                size = (2 * H, 2 * W, 2 * K - 1)
                beliefs_in_u = F.interpolate(beliefs_in[:, 0].unsqueeze(1), size=size, mode='trilinear')[:, 0]
                beliefs_in_v = F.interpolate(beliefs_in[:, 1].unsqueeze(1), size=size, mode='trilinear')[:, 0]
                beliefs_in = torch.cat((beliefs_in_u.unsqueeze(1), beliefs_in_v.unsqueeze(1)), dim=1).contiguous()
                pv_lvl = pv_lvl + beliefs_in / 2.0

            flow_lvl, beliefs_lvl, _, _ = self.optimize_crf(crf, pv_lvl, None, affinity, None)

            if lvl == 0:  # TODO FOR EVAL!!!!
                # Remove the evaluation padding from the finest level only.
                beliefs_lvl = self.unpad(beliefs_lvl, self.pad.l, self.pad.r, self.pad.t,
                                         self.pad.b, NCHW=False)
                flow_lvl = self.unpad(flow_lvl, self.pad.l, self.pad.r, self.pad.t,
                                      self.pad.b)

            beliefs_pyramid.append(beliefs_lvl)
            # Shift label indices so that the window centre maps to zero.
            crf_flow_pyramid.append(flow_lvl - m.sws // 2)

            beliefs_in = beliefs_pyramid[-1]

        # The loop appended the last level first; reverse so index 0 is level 0.
        beliefs_pyramid.reverse()
        crf_flow_pyramid.reverse()
        output_flow_pyramid = crf_flow_pyramid

        if self.refinement_net:
            # The confidence module is fed label indices, so undo the centring
            # shift applied above using the level-0 search window.
            half_sws = self.matching[0].sws // 2
            cv_conf_u = self.cv_conf.forward(beliefs_pyramid[0][:, 0].permute(0, 3, 1, 2),
                                             crf_flow_pyramid[0][:, 0:1] + half_sws)
            cv_conf_v = self.cv_conf.forward(beliefs_pyramid[0][:, 1].permute(0, 3, 1, 2),
                                             crf_flow_pyramid[0][:, 1:2] + half_sws)

            conf_all = torch.cat((cv_conf_u, cv_conf_v), dim=1)
            refined_flow_pyramid, _ = self.refine_disps(I0_pyramid,
                                                        crf_flow_pyramid[0],
                                                        confidence=conf_all,
                                                        I1=I1_pyramid)
            refined_flow_pyramid.reverse()
            output_flow_pyramid = refined_flow_pyramid

        res_dict['flow0'] = output_flow_pyramid

        return res_dict

    def extract_features(self, ipt):
        """Run the feature net on ``ipt``; return None when there is none."""
        if self.feature_net:
            return self.feature_net.forward(ipt)
        return None

    def compute_guidance(self, ipt):
        """Run ``self.guidance_net`` on ``ipt``; return None when it is absent.

        This class never defines ``guidance_net``; it is looked up defensively
        so that a missing attribute yields None like the other stages.
        """
        guidance_net = getattr(self, 'guidance_net', None)
        if guidance_net:
            return guidance_net.forward(ipt)
        return None

    def extract_edges(self, ipt):
        """Run ``self.edge_net`` on ``ipt``; return None when it is absent.

        This class never defines ``edge_net``; it is looked up defensively so
        that a missing attribute yields None like the other stages.
        """
        edge_net = getattr(self, 'edge_net', None)
        if edge_net:
            return edge_net.forward(ipt)
        return None

    def extract_affinities(self, ipt):
        """Run the affinity net on ``ipt``; return None when there is none."""
        if self.affinity_net:
            return self.affinity_net.forward(ipt)
        return None

    def extract_offsets(self, ipt):
        """Run the offset net on ``ipt``; return None when there is none."""
        if self.offset_net:
            return self.offset_net.forward(ipt)
        return None

    def match(self, f0, f1):
        """Apply each level's matcher to the corresponding feature pair.

        Returns:
            list with one matching volume per level, or None when no matchers
            are configured.
        """
        if not self.matching:
            return None
        return [matching.forward(f0s, f1s)
                for matching, f0s, f1s in zip(self.matching, f0, f1)]

    def optimize_crf(self, crf_layer, prob_vol, weights, affinities, offsets):
        """Run all BP layers of one level in sequence.

        Args:
            crf_layer: iterable of BP layers for this level.
            prob_vol: input volume; each layer's output volume feeds the next.
            weights, offsets: forwarded to the layer's ``adjust_input_*``
                helpers (``forward`` passes None for both).
            affinities: tensor indexed as ``affinities[:, idx]`` per BP layer,
                or None when no affinity net exists.

        Returns:
            ``(disps, prob_vol, affinities_shift, offsets_shift)`` from the
            last layer, or None when ``crf_layer`` is empty/None.
        """
        if not crf_layer:
            return None

        # iterate over all bp "layers"
        for idx, crf in enumerate(crf_layer):
            # TODO take care of BP layer idx in adjust functions
            prob_vol = prob_vol.contiguous()
            weights_input = crf.adjust_input_weights(weights, idx)
            # Without an affinity net there is nothing to slice; pass None on.
            # NOTE(review): assumes adjust_input_affinities accepts None like
            # the weight/offset helpers do - confirm against the BP op.
            layer_affinities = affinities[:, idx] if affinities is not None else None
            affinities_shift = crf.adjust_input_affinities(layer_affinities)
            offsets_shift = crf.adjust_input_offsets(offsets)

            disps, prob_vol, _ = crf.forward(prob_vol, weights_input, affinities_shift, offsets_shift)

        return disps, prob_vol, affinities_shift, offsets_shift

    def refine_disps(self, I0, d0, confidence=None, I1=None):
        """Run the refinement net; return ``(refined, steps)`` or None."""
        if self.refinement_net:
            refined, steps = self.refinement_net.forward(I0, d0, confidence, I1)
            return refined, steps
        return None

    def feature_net_params(self, requires_grad=None):
        """Return the feature net's parameter list, or [] when there is none."""
        if self.feature_net:
            return self.feature_net.parameter_list(requires_grad)
        return []

    def matching_params(self, requires_grad=None):
        """Return the concatenated parameter lists of all matchers."""
        params = []
        for m in self.matching or []:
            params += m.parameter_list(requires_grad)
        return params

    def affinity_net_params(self, requires_grad=None):
        """Return the affinity net's parameter list, or [] when there is none."""
        if self.affinity_net:
            return self.affinity_net.parameter_list(requires_grad)
        return []

    def crf_params(self, requires_grad=None):
        """Return the concatenated parameter lists of every BP layer."""
        crf_params = []
        for crf_layer in self.crf or []:
            for crf in crf_layer:
                crf_params += crf.parameter_list(requires_grad)
        return crf_params

    def refinement_net_params(self, requires_grad=None):
        """Return the refinement net's parameter list, or [] when there is none."""
        if self.refinement_net:
            return self.refinement_net.parameter_list(requires_grad)
        return []

    @property
    def feature_net(self):
        return self._feature_net

    @property
    def affinity_net(self):
        return self._affinity_net

    @property
    def offset_net(self):
        return self._offset_net

    @property
    def crf(self):
        # An empty CRF list is reported as None so callers can test "is there
        # a CRF" uniformly.
        if not self._crf:
            return None
        return self._crf

    @property
    def refinement_net(self):
        return self._refinement_net

    @property
    def matching(self):
        return self._matching

    @property
    def min_disp(self):
        return self._min_disp

    @property
    def max_disp(self):
        return self._max_disp
254 |
255 | ####################################################################################################
256 | # Block Match
257 | ####################################################################################################
class BlockMatchFlow(FlowMethod):
    """Flow method with a feature net and one SAD matcher per feature level."""

    def __init__(self, device, args):
        """Build the feature net and the per-level matchers.

        Args:
            device: forwarded to ``FlowMethod``, ``FeatureNet`` and every
                ``FlowMatchingSad``.
            args: configuration object; ``args.matching`` is only inspected to
                warn when something other than ``'sad'`` was requested.
        """
        FlowMethod.__init__(self, device, args)
        self._feature_net = FeatureNet(device, args)

        # One matcher per feature level; the search-window size is halved
        # (integer division) with every level.
        num_levels = self._feature_net.net.num_output_levels
        self._matching = [
            FlowMatchingSad(device, args, self._sws // 2**matching_lvl, lvl=matching_lvl)
            for matching_lvl in range(num_levels)
        ]

        # SAD is always used here regardless of the setting, so report a
        # mismatch once instead of once per pyramid level.
        if args.matching != 'sad':
            print('WARNING: Use SAD matching for flow, but', args.matching, 'was chosen.')
269 |
270 |
271 | ####################################################################################################
272 | # Min-Sum LBP
273 | ####################################################################################################
class MinSumFlow(BlockMatchFlow):
    """Block-matching flow followed by min-sum BP layers on every level."""

    def __init__(self, device, args):
        """Add an affinity net and ``args.num_bp_layers`` BP layers per level.

        Args:
            device: forwarded to the parent class, ``AffinityNet`` and ``BP``.
            args: configuration object; reads ``max_iter``, ``bp_inference``
                and ``num_bp_layers``.
        """
        BlockMatchFlow.__init__(self, device, args)

        self.max_iter = args.max_iter
        # The same label count (full search window + 1) is used on all levels.
        label_count = self._sws + 1

        self._affinity_net = AffinityNet(device, args)

        level_count = self._feature_net.net.num_output_levels
        for level in range(level_count):
            bp_layers = []
            for layer_idx in range(args.num_bp_layers):
                # NOTE(review): the meaning of the positional 3 is defined by
                # BP's signature, which is not visible here.
                bp_layers.append(BP(device, args, self.max_iter, label_count, 3,
                                    mode_inference=args.bp_inference,
                                    mode_message_passing='min-sum',
                                    layer_idx=layer_idx, level=level))
            self._crf.append(bp_layers)
288 |
class RefinedMinSumFlow(MinSumFlow):
    """Min-sum BP flow with an additional refinement network."""

    def __init__(self, device, args):
        """Build the min-sum pipeline, then attach the refinement net."""
        super().__init__(device, args)

        # 7 input channels, 2 output channels, no ReLU on the output.
        self._refinement_net = RefinementNet(
            device,
            args,
            in_channels=7,
            out_channels=2,
            with_output_relu=False,
        )
294 |
--------------------------------------------------------------------------------
/flow_matching.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import torch
3 |
4 | from networks import SubNetwork
5 |
6 | from corenet import TemperatureSoftmax
7 | from ops.flow_mp_sad.flow_mp_sad import FlowMpSadFunction
8 |
class FlowMatching(SubNetwork):
    """Base class for flow matchers producing a two-component volume.

    Subclasses implement ``compute_score_volume``; ``forward`` applies a
    temperature softmax (over dimension 3) to both returned score volumes and
    joins them along a new dimension 1.
    """

    def __init__(self, device, args, sws, in_channels=None, lvl=0):
        """Store the matcher configuration.

        Args:
            device: stored on the instance and forwarded to ``SubNetwork``.
            args: configuration object forwarded to ``SubNetwork``.
            sws: search-window size; stored as ``self.sws``.
            in_channels: stored as ``self.in_channels``.
            lvl: pyramid level of this matcher; stored as ``self.level``.
        """
        super().__init__(args, device)
        self.device = device
        self.sws = sws
        self.in_channels = in_channels
        self.level = lvl

        self._softmax = TemperatureSoftmax(dim=3, init_temp=1.0)

    def compute_score_volume(self, f0, f1):
        """Return the ``(u, v)`` score volumes; must be overridden."""
        raise NotImplementedError

    def forward(self, f0, f1):
        """Return the contiguous softmax-ed u/v volumes joined on dim 1."""
        score_u, score_v = self.compute_score_volume(f0, f1)
        components = [
            self._softmax.forward(score).unsqueeze(1)
            for score in (score_u, score_v)
        ]
        return torch.cat(components, dim=1).contiguous()

    @staticmethod
    def argmin_to_disp(argmin, min_disp):
        """Offset an argmin index by ``min_disp``."""
        return argmin + min_disp

    def save_checkpoint(self, epoch, iteration):
        """Save the state dict to ``args.train_dir`` if ``'u'`` is trained.

        Nothing is written unless ``'u'`` is contained in
        ``args.train_params``. The iteration is zero-padded to six digits in
        the file name.
        """
        if 'u' not in self.args.train_params:
            return

        state = self.state_dict()
        file_name = ('matching_lvl' + str(self.level) +
                     '_checkpoint_' + str(epoch) + '_' + str(iteration).zfill(6) + '.cpt')
        torch.save(state, osp.join(self.args.train_dir, file_name))
39 |
40 |
class FlowMatchingSad(FlowMatching):
    """Flow matcher whose scores come from ``FlowMpSadFunction``."""

    def __init__(self, device, args, sws, lvl=0):
        """Set up the matcher and optionally load a checkpoint.

        Args:
            device: target device for the parameters and the module.
            args: configuration object; ``args.checkpoint_matching`` is a
                (possibly empty) sequence of checkpoint paths.
            sws: search-window size handed to the SAD function.
            lvl: pyramid level of this matcher.
        """
        super().__init__(device, args, sws, lvl=lvl)

        checkpoints = args.checkpoint_matching
        if checkpoints:  # not-empty check
            # Levels beyond the supplied checkpoints fall back to the last one.
            ckpt_idx = min(self.level, len(checkpoints) - 1)
            self.load_parameters(checkpoints[ckpt_idx], device)
        self.to(device)

    def compute_score_volume(self, f0, f1):
        """Return the negated first two outputs of the SAD function."""
        # The function returns four values; the last two are not needed here.
        cost_u, cost_v, _, _ = FlowMpSadFunction.apply(f0, f1, self.sws)
        return -cost_u, -cost_v
53 |
--------------------------------------------------------------------------------
/flow_tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage.transform import resize
3 | from skimage.color import rgb2gray
4 | from scipy.ndimage import map_coordinates
5 | from scipy.ndimage.filters import correlate
6 |
def makeColorWheel():
    """Build the colour wheel used to colour-code flow.

    The wheel consists of six consecutive segments (RY, YG, GC, CB, BM, MR).
    In each segment one row is held at 255 while another ramps linearly up
    from 0 or down from 255; the remaining row stays 0.

    Returns:
        ``uint8`` array of shape ``(3, 55)``.
    """
    # (segment length, row held at 255, row that ramps, ramp is rising)
    segments = (
        (15, 0, 1, True),    # RY
        (6, 1, 0, False),    # YG
        (4, 1, 2, True),     # GC
        (11, 2, 1, False),   # CB
        (13, 2, 0, True),    # BM
        (6, 0, 2, False),    # MR
    )

    total = sum(length for length, _, _, _ in segments)
    wheel = np.zeros((3, total))

    start = 0
    for length, fixed_row, ramp_row, rising in segments:
        stop = start + length
        ramp = np.floor(255 * np.arange(length) / length)
        wheel[fixed_row, start:stop] = 255
        wheel[ramp_row, start:stop] = ramp if rising else 255 - ramp
        start = stop

    return wheel.astype('uint8')
49 |
def computeNormalizedFlow(u, v, u_ref=None, v_ref=None, verbose=False):
    """Scale a flow field by its largest flow magnitude.

    Entries whose absolute u or v exceeds 1e9 are treated as "unknown" and set
    to zero first. The field is then divided by the maximum magnitude of the
    reference flow (`u_ref`, `v_ref`) when both are given, otherwise of the
    flow itself.

    NaN magnitudes are ignored when searching for the maximum, so a single NaN
    pixel no longer turns the whole normalised field into NaN. NaN entries of
    `u` / `v` themselves stay NaN in the result.

    Args:
        u, v: flow components (numpy arrays of equal shape); not modified.
        u_ref, v_ref: optional reference flow that defines the scale.
        verbose: print the maximum magnitude and the u/v value ranges.

    Returns:
        Tuple `(u, v)` of normalised copies.
    """
    # copy to not overwrite the inputs
    u = u.copy()
    v = v.copy()

    eps = 1e-15  # avoids division by zero for an all-zero flow
    UNKNOWN_FLOW_THRES = 1e9

    # fix unknown flow
    idxUnknown = np.logical_or(np.abs(u) > UNKNOWN_FLOW_THRES, np.abs(v) > UNKNOWN_FLOW_THRES)
    u[idxUnknown] = 0
    v[idxUnknown] = 0

    if u_ref is not None and v_ref is not None:
        rad = np.sqrt(u_ref**2 + v_ref**2)
    else:
        rad = np.sqrt(u**2 + v**2)

    # largest valid magnitude; NaNs must not poison the maximum
    valid_rad = rad[~np.isnan(rad)]
    maxrad = np.maximum(-1, np.max(valid_rad)) if valid_rad.size else 0.0

    if verbose and u.size:
        # same clamped range values the original reported
        maxu = np.maximum(-999, np.nanmax(u))
        minu = np.minimum(999, np.nanmin(u))
        maxv = np.maximum(-999, np.nanmax(v))
        minv = np.minimum(999, np.nanmin(v))
        print("max flow: ", maxrad, " flow range: u = ", minu, "..", maxu, "v = ", minv, "..", maxv)

    u = u / (maxrad + eps)
    v = v / (maxrad + eps)

    return u, v
89 |
def computeFlowImg(u, v, u_ref=None, v_ref=None):
    """Render a flow field as an RGB image using the colour wheel.

    The flow is normalised with `computeNormalizedFlow` (which works on copies,
    so the inputs are untouched). The flow direction selects a position on the
    colour wheel, with linear blending between the two neighbouring entries.
    For magnitudes <= 1 the colour is blended towards white as the magnitude
    shrinks; larger magnitudes are dimmed by a factor 0.75. Pixels that are NaN
    after normalisation are rendered black.

    Returns:
        uint8 array of shape (M, N, 3) for an (M, N) flow field.
    """
    u, v = computeNormalizedFlow(u, v, u_ref, v_ref)

    invalid = np.logical_or(np.isnan(u), np.isnan(v))
    u[invalid] = 0
    v[invalid] = 0
    keep = 1 - invalid  # 0 where the flow was NaN, 1 elsewhere

    wheel = makeColorWheel().T  # one row per wheel entry, one column per channel
    n_entries, n_channels = wheel.shape

    rows, cols = u.shape
    img = np.zeros((rows, cols, 3)).astype('uint8')

    mag = np.sqrt(u**2 + v**2)

    angle = np.arctan2(-v, -u) / np.pi  # [-1, 1]
    pos = (angle + 1.0) / 2.0 * (n_entries - 1)
    lower = np.floor(pos).astype('int')

    # upper neighbour, wrapping past the last entry back to the first
    upper = lower + 1
    upper[upper == n_entries] = 0

    weight = pos - lower
    small = mag <= 1

    for ch in range(n_channels):
        channel = wheel[:, ch]

        # linear blend between colors
        low_col = channel[lower] / 255.0
        high_col = channel[upper] / 255.0
        blended = (1.0 - weight) * low_col + weight * high_col

        # increase saturation for small magnitude, dim everything else
        shaded = np.where(small, 1 - mag * (1 - blended), blended * 0.75)

        img[:, :, ch] = (np.floor(255.0 * shaded * keep)).astype('uint8')
    return img
133 |
134 |
--------------------------------------------------------------------------------
/github_imgs/flow_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/github_imgs/flow_example.png
--------------------------------------------------------------------------------
/github_imgs/kitti.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/github_imgs/kitti.png
--------------------------------------------------------------------------------
/github_imgs/mb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/github_imgs/mb.png
--------------------------------------------------------------------------------
/github_imgs/sem_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/github_imgs/sem_input.png
--------------------------------------------------------------------------------
/github_imgs/sem_pred.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/github_imgs/sem_pred.png
--------------------------------------------------------------------------------
/github_imgs/sf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/github_imgs/sf.png
--------------------------------------------------------------------------------
/github_imgs/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/github_imgs/teaser.gif
--------------------------------------------------------------------------------
/main_flow.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from flow import MinSumFlow, BlockMatchFlow, RefinedMinSumFlow
5 | import data
6 | from flow_tools import computeFlowImg
7 |
8 | import imageio
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 |
# ---------------------------------------------------------------------------
# Command line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--im0', type=str, default="data/sf_0006_left.png")
parser.add_argument('--im1', type=str, default="data/sf_0006_right.png")
parser.add_argument('--sws', type=int, default=108)
parser.add_argument('--multi-level-output', action='store_true')
parser.add_argument('--activation', choices=['relu', 'leakyrelu', 'elu'], default='leakyrelu')
for switch in ('--with-bn', '--with-upconv', '--with-output-bn'):
    parser.add_argument(switch, action='store_true')
parser.add_argument('--pad', default=(0, 0), nargs=2, type=int,
                    help='extra padding of in height and in width on every side')

parser.add_argument('--model', default='bp+ms+ref+h',
                    choices=['wta', 'bp+ms', 'bp+ms+h', 'bp+ms+ref+h'])
parser.add_argument('--checkpoint-unary', type=str, default=None)
parser.add_argument('--checkpoint-matching', type=str, default=[], nargs='+')
parser.add_argument('--checkpoint-affinity', type=str, default=None)
parser.add_argument('--checkpoint-crf', action='append', default=[], type=str, nargs='+')
parser.add_argument('--checkpoint-refinement', type=str, default=None)

parser.add_argument('--lbp-min-disp', action='store_true')
parser.add_argument('--max-iter', type=int, default=1)
parser.add_argument('--num-bp-layers', type=int, default=1)
parser.add_argument('--bp-inference', type=str, default='sub-exp',
                    choices=['wta', 'expectation', 'sub-exp'])

parser.add_argument('--matching', type=str, default='sad',
                    choices=['corr', 'sad', 'conv3d'])

parser.add_argument('--input-level-offset', type=int, default=1,
                    help='1 means that level 1 is the input resolution')
parser.add_argument('--output-level-offset', type=int, default=1,
                    help="0 means that level 0 (=full res) is the output resolution")
args = parser.parse_args()

# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
I0_pyramid, I1_pyramid = data.load_sample(args.im0, args.im1)

# the flow models are always run with multi-level output, whatever the CLI flag says
args.multi_level_output = True

# 'bp+ms' and 'bp+ms+h' share the same model class
model_classes = {
    'wta': BlockMatchFlow,
    'bp+ms': MinSumFlow,
    'bp+ms+h': MinSumFlow,
    'bp+ms+ref+h': RefinedMinSumFlow,
}

device = 'cuda:0'
with torch.no_grad():
    model = model_classes[args.model](device, args)
    res_dict = model.to(device).forward(I0_pyramid, I1_pyramid, sws=args.sws)

# ---------------------------------------------------------------------------
# Visualise and save
# ---------------------------------------------------------------------------
flow = res_dict['flow0'][0].squeeze().float().detach().cpu().numpy()

gt_flow = data.readFlow("data/frame_0019.flo")  # only used for normalizing the flow output!
flow_img = computeFlowImg(flow[0], flow[1], gt_flow[0], gt_flow[1])

imageio.imwrite("data/output/flow/" + args.model + ".png", flow_img)
--------------------------------------------------------------------------------
/main_semantic.py:
--------------------------------------------------------------------------------
1 | from semantic_segmentation import SemanticNet
2 |
3 | import argparse
4 | import os
5 | import yaml
6 | import cv2
7 | import numpy as np
8 | import torch
9 | import matplotlib.pyplot as plt
10 | from imageio import imsave
11 |
# ---------------------------------------------------------------------------
# Command line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--img', action='store', required=True, type=str)
parser.add_argument('--scale-imgs', action='store', default=0.5, type=float)
parser.add_argument('--checkpoint-semantic', action='store', default=None, type=str)
parser.add_argument('--checkpoint-esp-net', action='store', default=None, type=str)
parser.add_argument('--num-labels', action='store', default=20, type=int)
parser.add_argument('--activation', action='store', choices=['relu', 'leakyrelu', 'elu'], default='leakyrelu')
parser.add_argument('--num-bp-layers', action='store', default=1, type=int)
# NOTE: store_true combined with default=True means this option is always True.
parser.add_argument('--with-bn', action='store_true', default=True)
parser.add_argument('--with-upconv', action='store_true', default=False)
parser.add_argument('--with-output-bn', action='store_true', default=False)
parser.add_argument('--pad', action='store', default=(0, 0), nargs=2, type=int, help='extra padding of in height and in width on every side')
parser.add_argument('--pairwise-type', action='store', choices=["global", "pixel"], default="global")
parser.add_argument('--multi-level-features', action='store_true', default=False)
parser.add_argument('--with-esp', action='store_true', default=False)
parser.add_argument('--with-edges', action='store_true', default=False)
parser.add_argument('--multi-level-output', action='store_true', default=False)


args = parser.parse_args()

cuda_device = 'cuda:0'

semantic_model = SemanticNet(cuda_device, args)

# read input image
test_img_path = args.img

test_img = cv2.imread(test_img_path)
if test_img is None:
    # cv2.imread signals an unreadable/missing file by returning None instead
    # of raising; fail with a clear message rather than an AttributeError.
    raise FileNotFoundError("Could not read input image: " + test_img_path)
test_img = test_img.astype(np.float32)

# normalize input and convert to torch
mean = np.array([72.39231, 82.908936, 73.1584])
std = np.array([45.31922, 46.152893, 44.914833])

test_img -= mean[np.newaxis, np.newaxis, :]
test_img /= std[np.newaxis, np.newaxis, :]

height, width, _ = test_img.shape
scaled_height = int(height * args.scale_imgs)
scaled_width = int(width * args.scale_imgs)

# cv2.resize takes the target size as (width, height)
test_img = cv2.resize(test_img, (scaled_width, scaled_height))

# H x W x C -> 1 x C x H x W on the GPU
test_img_torch = torch.from_numpy(test_img).to(device=cuda_device).permute(2, 0, 1).unsqueeze(0) / 255.0

# run model
sem_pred, _, _ = semantic_model.forward(test_img_torch)

# visualize/save result

########## cityscapes visualization from ESP Net: https://github.com/sacmehta/ESPNet ###########################
label_colors = [[128, 64, 128],
                [244, 35, 232],
                [70, 70, 70],
                [102, 102, 156],
                [190, 153, 153],
                [153, 153, 153],
                [250, 170, 30],
                [220, 220, 0],
                [107, 142, 35],
                [152, 251, 152],
                [70, 130, 180],
                [220, 20, 60],
                [255, 0, 0],
                [0, 0, 142],
                [0, 0, 70],
                [0, 60, 100],
                [0, 80, 100],
                [0, 0, 230],
                [119, 11, 32],
                [0, 0, 0]]

sem_pred_np = sem_pred.squeeze().byte().cpu().data.numpy()
sem_pred_np_color = np.zeros((sem_pred_np.shape[0], sem_pred_np.shape[1], 3), dtype=np.uint8)

# paint every pixel with the colour of its predicted label; labels without an
# entry in label_colors stay black
for label in range(len(label_colors)):
    sem_pred_np_color[sem_pred_np == label] = label_colors[label]

imsave("data/output/semantic/sem_pred.png", sem_pred_np_color)
91 |
92 | # plt.figure()
93 | # plt.imshow(sem_pred_np_color)
94 | # plt.show()
95 |
96 |
--------------------------------------------------------------------------------
/main_stereo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from stereo import MinSumStereo, BlockMatchStereo, RefinedMinSumStereo
5 | import data
6 |
7 | import imageio
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 |
# ---------------------------------------------------------------------------
# Command line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--im0', type=str, required=True)
parser.add_argument('--im1', type=str, required=True)
parser.add_argument('--min-disp', type=int, default=0)
parser.add_argument('--max-disp', type=int, default=127)
parser.add_argument('--stride-in', type=int, default=1)
parser.add_argument('--stride-out', type=int, default=1)
parser.add_argument('--multi-level-output', action='store_true')
parser.add_argument('--activation', choices=['relu', 'leakyrelu', 'elu'], default='leakyrelu')
for switch in ('--with-bn', '--with-upconv', '--with-output-bn'):
    parser.add_argument(switch, action='store_true')
parser.add_argument('--pad', default=(0, 0), nargs=2, type=int,
                    help='extra padding of in height and in width on every side')

parser.add_argument('--model', default='bp+ms+h',
                    choices=['wta', 'bp+ms', 'bp+ms+h', 'bp+ms+ref+h'])
parser.add_argument('--checkpoint-unary', type=str, default=None)
parser.add_argument('--checkpoint-matching', type=str, default=[], nargs='+')
parser.add_argument('--checkpoint-affinity', type=str, default=None)
parser.add_argument('--checkpoint-crf', action='append', default=[], type=str, nargs='+')
parser.add_argument('--checkpoint-refinement', type=str, default=None)

parser.add_argument('--lbp-min-disp', action='store_true')
parser.add_argument('--max-iter', type=int, default=1)
parser.add_argument('--num-bp-layers', type=int, default=1)
parser.add_argument('--bp-inference', type=str, default='sub-exp',
                    choices=['wta', 'expectation', 'sub-exp'])

parser.add_argument('--matching', type=str, default='sad',
                    choices=['corr', 'sad', 'conv3d'])

parser.add_argument('--input-level-offset', type=int, default=1,
                    help='1 means that level 1 is the input resolution')
parser.add_argument('--output-level-offset', type=int, default=1,
                    help="0 means that level 0 (=full res) is the output resolution")
args = parser.parse_args()

# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
I0_pyramid, I1_pyramid = data.load_sample(args.im0, args.im1)

# 'bp+ms' and 'bp+ms+h' share the same model class
model_classes = {
    'wta': BlockMatchStereo,
    'bp+ms': MinSumStereo,
    'bp+ms+h': MinSumStereo,
    'bp+ms+ref+h': RefinedMinSumStereo,
}

device = 'cuda:0'
with torch.no_grad():
    model = model_classes[args.model](device, args)

    max_disp = None  # not used below; the forward call receives args.max_disp
    res_dict = model.to(device).forward(I0_pyramid, I1_pyramid, max_disp=args.max_disp, step=1)

# ---------------------------------------------------------------------------
# Save the first 'disps0' entry, flipped vertically, as PFM
# ---------------------------------------------------------------------------
disparity = res_dict['disps0'][0].squeeze().float().detach().cpu().numpy()
imageio.imwrite("data/output/stereo/" + args.model + ".pfm", np.flipud(disparity))
--------------------------------------------------------------------------------
/matching.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from networks import SubNetwork
3 | from corenet import TemperatureSoftmax
4 |
5 | from ops.sad.stereo_sad import StereoMatchingSadFunction
6 |
class StereoMatching(SubNetwork):
    """Base class for stereo matching: score volume -> softmax probabilities.

    Subclasses implement `compute_score_volume`; `forward` passes its result
    through a temperature softmax over dimension 3.
    """

    def __init__(self, device, args, min_disp, max_disp, in_channels=None, lvl=0, step=1.0):
        """Store the disparity range / step and create the softmax layer."""
        super(StereoMatching, self).__init__(args, device)
        self.device = device
        self.min_disp, self.max_disp, self.step = min_disp, max_disp, step
        self.in_channels, self.level = in_channels, lvl

        self._softmax = TemperatureSoftmax(dim=3, init_temp=1.0)

    def compute_score_volume(self, f0, f1):
        """Return the score volume for a feature pair; subclasses must override."""
        raise NotImplementedError()

    def forward(self, f0, f1):
        """Return the contiguous, softmax-normalised score volume."""
        scores = self.compute_score_volume(f0, f1)
        return self._softmax.forward(scores).contiguous()

    @staticmethod
    def argmin_to_disp(argmin, min_disp):
        """Shift an argmin index by `min_disp` and return the sum."""
        return argmin + min_disp

    def save_checkpoint(self, epoch, iteration):
        """Intentionally a no-op for stereo matching."""
        return None
34 |
class StereoMatchingSad(StereoMatching):
    """Stereo matching whose scores are the negated costs from `StereoMatchingSadFunction`."""

    def __init__(self, device, args, min_disp, max_disp, lvl=0, step=1.0):
        """Build the module, optionally load a checkpoint and move it to `device`.

        `args.checkpoint_matching` defaults to an empty list on the command
        line, so indexing it unconditionally raised IndexError when no matching
        checkpoint was supplied. Loading is now skipped for an empty list, and
        the index is clamped to the last entry when fewer checkpoints than
        levels are given. This matches `FlowMatchingSad`.
        """
        super(StereoMatchingSad, self).__init__(device, args, min_disp, max_disp, lvl=lvl, step=step)

        checkpoints = args.checkpoint_matching
        if checkpoints:  # not-empty check
            ckpt_idx = min(self.level, len(checkpoints) - 1)
            self.load_parameters(checkpoints[ckpt_idx], device)
        self.to(device)

    def compute_score_volume(self, f0, f1):
        """Return the negated cost volume (lower cost -> higher score)."""
        cost_vol = StereoMatchingSadFunction.apply(f0, f1, self.min_disp, self.max_disp, self.step)
        return -cost_vol
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import os.path as osp
6 |
7 | from corenet import StereoUnaryUnetDyn, PadUnpad, PaddedConv2d, ResidualBlock
8 |
class SubNetwork(nn.Module):
    """Common base for the sub-networks: padding wrapper, freezing and checkpoints.

    Subclasses are expected to assign `self.net` and `self.pad_unpad`; both are
    initialised to None here.
    """

    def __init__(self, args, device):
        """Store `args` and `device`; note the argument order (args first)."""
        super(SubNetwork, self).__init__()
        self.net = None
        self.pad_unpad = None
        self.args = args
        self.device = device

    def forward(self, ipt):
        """Run the input through `self.pad_unpad`."""
        return self.pad_unpad.forward(ipt)

    def freeze_parameters(self):
        """Switch to eval mode and disable gradients for all parameters."""
        self.eval()
        for p in self.parameters():
            p.requires_grad = False

    def parameter_list(self, requires_grad=None):
        """Return the parameters as a list, optionally filtered.

        Args:
            requires_grad: None returns every parameter; True / False returns
                only the parameters whose `requires_grad` flag is that value.
        """
        if requires_grad is None:
            return list(self.parameters())
        return [p for p in self.parameters() if p.requires_grad is requires_grad]

    def load_parameters(self, path, device=None):
        """Load a state dict from `path` into this module; no-op if `path` is None.

        Args:
            path: checkpoint file, or None to skip loading.
            device: forwarded to `torch.load` as `map_location` (None keeps the
                stored device placement).

        Raises:
            FileNotFoundError: if `path` is given but does not exist.
        """
        if path is None:
            return

        if not osp.exists(path):
            raise FileNotFoundError(f"Specified checkpoint file does not exist: {path}")

        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=device)

        # let subclasses rename/convert entries before loading
        checkpoint = self.hook_adjust_checkpoint(checkpoint)

        self.load_state_dict(checkpoint)
        print('Successfully loaded checkpoint %s' % path)

    def hook_adjust_checkpoint(self, checkpoint):
        """Hook for subclasses to adapt a loaded checkpoint; identity by default."""
        return checkpoint

    def save_checkpoint(self, epoch, iteration):
        """Write a checkpoint; subclasses must override."""
        raise NotImplementedError

    @property
    def divisor(self):
        """The `divisor` attribute of the wrapped network."""
        return self.net.divisor

    @property
    def pad_input(self):
        """Return `self._pad_input`.

        NOTE(review): `_pad_input` is not assigned anywhere in this class, so a
        subclass has to set it before this property is read.
        """
        return self._pad_input
69 |
class FeatureNet(SubNetwork):
    """Feature extractor: a `StereoUnaryUnetDyn` wrapped in `PadUnpad`."""

    def __init__(self, device, args):
        """Build the network from `args` and load `args.checkpoint_unary` (if set)."""
        super(FeatureNet, self).__init__(args, device)

        backbone = StereoUnaryUnetDyn(
            args.multi_level_output,
            args.activation,
            args.with_bn,
            args.with_upconv,
            args.with_output_bn,
        )
        self.net = backbone

        # padding wrapper around the backbone
        self.pad_unpad = PadUnpad(backbone, backbone.divisor, tuple(args.pad))

        # restore weights; a None path is skipped by load_parameters
        self.load_parameters(args.checkpoint_unary, device)
82 |
83 |
class AffinityNet(SubNetwork):
    """Backbone plus one output convolution per level producing non-negative affinities."""

    def __init__(self, device, args):
        """Build backbone and output heads, move to `device`, load `args.checkpoint_affinity`."""
        super(AffinityNet, self).__init__(args, device)

        self.net = StereoUnaryUnetDyn(
            args.multi_level_output,
            args.activation,
            args.with_bn,
            args.with_upconv,
            args.with_output_bn,
        )

        # per BP layer: 2 directions, 5 values # L2-, L2+, L1-, L1+, L3
        n_out = args.num_bp_layers * 2 * 5

        # one 3x3 output head per backbone level, registered as conv_out_lvl<k>
        self.conv_out = []
        for lvl, n_in in enumerate(self.net.out_channels):
            head = PaddedConv2d(n_in, n_out, 3, bias=True).to(device)
            self.conv_out.append(head)
            self.add_module('conv_out_lvl' + str(lvl), head)

        # provide pad-unpad
        self.pad_unpad = PadUnpad(self.net, self.net.divisor, tuple(args.pad))
        self.to(device)

        # load params
        self.load_parameters(args.checkpoint_affinity, device)

    def forward(self, ipt):
        """Return one tensor per level: absolute value of that level's output head."""
        features = super(AffinityNet, self).forward(ipt)

        affinities = []
        for head, level_features in zip(self.conv_out, features):
            # outshape = N x 2 * 5 x H x W
            affinities.append(torch.abs(head(level_features)))
        return affinities
111 |
class RefinementNet(SubNetwork):
    """Residual refinement of a disparity/flow estimate, one small CNN per level.

    For every level from `args.input_level_offset` down to
    `args.output_level_offset` a dilated residual network is created. `forward`
    upsamples the current estimate to the image resolution of the level,
    concatenates image, estimate and confidence, and adds the predicted
    residuum.
    """

    def __init__(self, device, args, in_channels=5, out_channels=1, with_output_relu=True):
        """Create the per-level networks and load `args.checkpoint_refinement` (if set).

        Args:
            device: target device for the module.
            args: options; `input_level_offset`, `output_level_offset` and
                `checkpoint_refinement` are read here.
            in_channels: channels of the concatenated network input.
            out_channels: channels of the predicted residuum.
            with_output_relu: apply a ReLU to the refined estimate in `forward`.
        """
        super(RefinementNet, self).__init__(args, device)

        # plain Python list; the networks are registered via add_module below
        self.net = []
        self.with_output_relu = with_output_relu

        # levels are visited from the input level down to the output level
        for lvl in range(self.args.input_level_offset, self.args.output_level_offset - 1, -1):
            net = nn.Sequential(
                PaddedConv2d(in_channels, 32, 3, bias=True),
                ResidualBlock(32, 32, 3, dilation=1),
                ResidualBlock(32, 32, 3, dilation=2),
                ResidualBlock(32, 32, 3, dilation=4),
                ResidualBlock(32, 32, 3, dilation=8),
                ResidualBlock(32, 32, 3, dilation=1),
                ResidualBlock(32, 32, 3, dilation=1),
                PaddedConv2d(32, out_channels, 3, bias=True),
            )
            self.net.append(net)
            self.add_module('sn_' + str(lvl), net)

        self.to(device)
        self.relu = nn.ReLU(inplace=True)

        # load params
        self.load_parameters(args.checkpoint_refinement, device)


    def forward(self, I0_pyramid, d0, confidence, I1_pyramid):
        """Refine `d0` level by level.

        Args:
            I0_pyramid: indexable by level; `I0_pyramid[lvl]` is the image at that level.
            d0: initial estimate; cloned, not modified.
            confidence: upsampled (nearest) to every level's resolution and fed to the network.
            I1_pyramid: not used in this method.

        Returns:
            `(refined_pyramid, [residuum_pyramid])`: the refined estimate per
            visited level, and the residuum list wrapped in a one-element list.
        """
        d0_lvl = d0.clone()
        refined_pyramid = []
        residuum_pyramid = []
        for ref_lvl in range(self.args.input_level_offset, self.args.output_level_offset -1, -1):
            I0_lvl = I0_pyramid[ref_lvl].to(self.device)
            # scale the image by its variance over dims 2 and 3
            I0_lvl = I0_lvl / I0_lvl.var(dim=(2,3), keepdim=True)

            # adapt input size
            scale_factor = I0_lvl.shape[2] / d0_lvl.shape[2]
            if abs(scale_factor - round(scale_factor)) > 0:
                print('WARNING: something weird is going on, got a fractional scale-factor in ref', scale_factor)
            # values are multiplied by the scale factor after the nearest-neighbour resize
            d0_up = F.interpolate(d0_lvl.float(), size=I0_lvl.shape[2:], mode='nearest') * scale_factor
            conf_up = F.interpolate(confidence.float(), size=I0_lvl.shape[2:], mode='nearest')

            # compute input tensor
            ipt = torch.cat((I0_lvl, d0_up, conf_up), dim=1)
            # NOTE(review): the networks were appended in descending level order,
            # so index `ref_lvl - output_level_offset` picks them in reverse
            # creation order when more than one level is refined (identical for
            # a single level). Presumably the checkpoints were trained this way;
            # confirm before changing.
            residuum = self.net[ref_lvl - self.args.output_level_offset].forward(ipt)

            d0_lvl = d0_up.float() + residuum # flow
            if self.with_output_relu:
                d0_lvl = self.relu(d0_lvl) # stereo
            refined_pyramid.append(d0_lvl)
            residuum_pyramid.append(residuum)

        return refined_pyramid, [residuum_pyramid]
166 |
167 |
class EdgeNet(SubNetwork):
    """Backbone plus one output convolution per level producing non-negative edge weights."""

    def __init__(self, device, args):
        """Build backbone and output heads and move everything to `device`."""
        super(EdgeNet, self).__init__(args, device)

        self.net = StereoUnaryUnetDyn(
            args.multi_level_output,
            args.activation,
            args.with_bn,
            args.with_upconv,
            args.with_output_bn,
        )

        n_out = args.num_bp_layers * 2

        # one 3x3 output head per backbone level, registered as conv_out_lvl<k>
        self.conv_out = []
        for lvl, n_in in enumerate(self.net.out_channels):
            head = PaddedConv2d(n_in, n_out, 3, bias=True).to(device)
            self.conv_out.append(head)
            self.add_module('conv_out_lvl' + str(lvl), head)

        # provide pad-unpad
        self.pad_unpad = PadUnpad(self.net, self.net.divisor, tuple(args.pad))
        self.to(device)

    def forward(self, ipt):
        """Return one tensor per level: absolute value of that level's output head."""
        features = super(EdgeNet, self).forward(ipt)

        edge_weights = []
        for head, level_features in zip(self.conv_out, features):
            edge_weights.append(torch.abs(head(level_features)))
        return edge_weights
188 |
189 |
class PWNet(SubNetwork):
    """Backbone plus one output convolution per level producing non-negative pairwise costs."""

    def __init__(self, device, args):
        """Build backbone and output heads and move everything to `device`."""
        super(PWNet, self).__init__(args, device)

        self.net = StereoUnaryUnetDyn(
            args.multi_level_output,
            args.activation,
            args.with_bn,
            args.with_upconv,
            args.with_output_bn,
        )

        # per BP layer: factor 2 times a num_labels x num_labels table
        n_out = args.num_bp_layers * 2 * args.num_labels * args.num_labels

        # one 3x3 output head per backbone level, registered as conv_out_lvl<k>
        self.conv_out = []
        for lvl, n_in in enumerate(self.net.out_channels):
            head = PaddedConv2d(n_in, n_out, 3, bias=True).to(device)
            self.conv_out.append(head)
            self.add_module('conv_out_lvl' + str(lvl), head)

        # provide pad-unpad
        self.pad_unpad = PadUnpad(self.net, self.net.divisor, tuple(args.pad))
        self.to(device)

    def forward(self, ipt):
        """Return one tensor per level: absolute value of that level's output head."""
        features = super(PWNet, self).forward(ipt)

        pairwise_costs = []
        for head, level_features in zip(self.conv_out, features):
            pairwise_costs.append(torch.abs(head(level_features)))
        return pairwise_costs
--------------------------------------------------------------------------------
/ops/flow_mp_sad/flow_mp_sad.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import pytorch_cuda_flow_mp_sad_op
4 |
class FlowMpSadFunction(torch.autograd.Function):
    """Autograd wrapper around the `pytorch_cuda_flow_mp_sad_op` CUDA extension.

    Supported search-window sizes are `sws <= 108` (one kernel call) and
    `sws == 216` (four half-size calls that are merged). The former
    `sws == 432` branch was unfinished: it indexed two-element lists with
    indices up to 3 and never produced the output volumes, so it always failed.
    That size is now rejected up front with the same ValueError as any other
    unsupported value.
    """

    @staticmethod
    def forward(ctx, f0, f1, sws):
        """Compute the u and v cost volumes for features `f0` and `f1`.

        Args:
            ctx: autograd context; inputs and the `*_star` outputs are saved on it.
            f0, f1: feature tensors handed to the CUDA op.
            sws: search-window size (`<= 108` or `216`).

        Returns:
            `(cv_u, cv_v, u_star, v_star)` as produced by the CUDA op.
            NOTE(review): `u_star` / `v_star` are presumably the arg-min indices
            needed by the backward pass; confirm against the kernel.

        Raises:
            ValueError: for an unsupported `sws`.
        """
        # shape = N x H x W x K => 2x 3D cost volume for u (=x dir) and v (=y dir)
        sad_forward = pytorch_cuda_flow_mp_sad_op.forward

        if sws <= 108:
            # single call: no offset, block index (0, 0)
            cv_u, cv_v, u_star, v_star = sad_forward(f0, f1, sws, 0, 0, 0, 0)
        elif sws == 2 * 108:
            s = sws // 2  # new sub-search-window-size
            # 2 x 2 grid of sub-windows; arguments: offset_u, offset_v, block_u, block_v
            cv00_u, cv00_v, u00_star, v00_star = sad_forward(f0, f1, s, -s//2, -s//2, 0, 0)  # x0y0
            cv01_u, cv01_v, u01_star, v01_star = sad_forward(f0, f1, s, -s//2, s//2, 0, 1)   # x0y1
            cv10_u, cv10_v, u10_star, v10_star = sad_forward(f0, f1, s, s//2, -s//2, 1, 0)   # x1y0
            cv11_u, cv11_v, u11_star, v11_star = sad_forward(f0, f1, s, s//2, s//2, 1, 1)    # x1y1

            # merge sub-volumes to one volume: concatenate along the last
            # dimension (dropping the first entry of the second block), then
            # keep element-wise the smaller cost of the two candidates
            u0 = torch.cat((cv00_u, cv10_u[:, :, :, 1:]), dim=-1)
            u1 = torch.cat((cv01_u, cv11_u[:, :, :, 1:]), dim=-1)
            au0 = torch.cat((u00_star, u10_star[:, :, :, 1:]), dim=-1)
            au1 = torch.cat((u01_star, u11_star[:, :, :, 1:]), dim=-1)

            idx = u1 < u0
            u0[idx] = u1[idx]  # overwrite better values
            au0[idx] = au1[idx]

            v0 = torch.cat((cv00_v, cv01_v[:, :, :, 1:]), dim=-1)
            v1 = torch.cat((cv10_v, cv11_v[:, :, :, 1:]), dim=-1)
            av0 = torch.cat((v00_star, v01_star[:, :, :, 1:]), dim=-1)
            av1 = torch.cat((v10_star, v11_star[:, :, :, 1:]), dim=-1)

            idx = v1 < v0
            v0[idx] = v1[idx]  # overwrite better values
            av0[idx] = av1[idx]

            cv_u, u_star = u0, au0
            cv_v, v_star = v0, av0
        else:
            # Larger tiled windows (e.g. 432 or a KITTI-specific 372) are not
            # implemented: the backward pass would need the original search window.
            raise ValueError(f"Unsupported sws: {sws}; only allowed: <= 108 or 216")

        # sws is wrapped in a tensor because save_for_backward only stores tensors
        ctx.save_for_backward(f0, f1, torch.tensor(sws), u_star, v_star)
        return cv_u, cv_v, u_star, v_star

    @staticmethod
    def backward(ctx, in_grad_u, in_grad_v, u_star_grad, v_star_grad):
        """Return gradients for `(f0, f1, sws)`; `sws` gets None.

        `u_star_grad` and `v_star_grad` are ignored (they are just zeros).
        """
        f0, f1, sws, u_star, v_star = ctx.saved_tensors
        df0, df1 = pytorch_cuda_flow_mp_sad_op.backward(f0, f1, int(sws), in_grad_u, in_grad_v, u_star, v_star)
        # one gradient per forward input: f0, f1, sws
        return df0, df1, None
120 |
--------------------------------------------------------------------------------
/ops/flow_mp_sad/src/flow_mp_sad.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include "flow_mp_sad_kernel.cuh"
3 |
4 | // C++ interface
5 | // AT_ASSERTM in pytorch 1.0
6 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
7 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
8 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x);
9 |
10 | std::vector flow_mp_sad_forward(at::Tensor f0, at::Tensor f1, int sws, int offset_u, int offset_v,
11 | int blockIdx_u, int blockIdx_v)
12 | {
13 | CHECK_INPUT(f0)
14 | CHECK_INPUT(f1)
15 | return cuda::flow_mp_sad_forward(f0, f1, sws, offset_u, offset_v, blockIdx_u, blockIdx_v);
16 | }
17 |
18 | std::vector flow_mp_sad_backward(at::Tensor f0, at::Tensor f1, int sws,
19 | at::Tensor in_grad_u, at::Tensor in_grad_v,
20 | at::Tensor u_star, at::Tensor v_star)
21 | {
22 | CHECK_INPUT(f0)
23 | CHECK_INPUT(f1)
24 | CHECK_INPUT(in_grad_u)
25 | CHECK_INPUT(in_grad_v)
26 | CHECK_INPUT(u_star)
27 | CHECK_INPUT(v_star)
28 |
29 | return cuda::flow_mp_sad_backward(f0, f1, sws, in_grad_u, in_grad_v, u_star, v_star);
30 | }
31 |
32 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
33 | {
34 | m.def("forward", &flow_mp_sad_forward, "Flow correlation Matching forward (CUDA)");
35 | m.def("backward", &flow_mp_sad_backward, "Flow correlation Matching backward (CUDA)");
36 | }
--------------------------------------------------------------------------------
/ops/flow_mp_sad/src/flow_mp_sad_kernel.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include "flow_mp_sad_kernel.cuh"
4 | #include "tensor.h"
5 | #include "error_util.h"
6 |
7 | // ============================================================================
8 | // CUDA KERNELS
9 | // ============================================================================
/**
 * Forward kernel of the SAD min-projection matching op.
 *
 * For one reference pixel (y, x) of f0, each thread handles one horizontal
 * displacement index u_idx and loops over all vertical displacements. The
 * K x K sum-of-absolute-differences costs of the pixel are staged in dynamic
 * shared memory and then min-projected:
 *   cv_u(n, y, x, u_idx) = min over v, u_star = argmin index (+ blockIdx_v * sws)
 *   cv_v(n, y, x, v_idx) = min over u, v_star = argmin index (+ blockIdx_u * sws)
 * Candidates whose match position (x + offset_u + u, y + offset_v + v) lies
 * outside the image keep the constant initial cost 40.
 *
 * Only batch element n = 0 is processed.
 *
 * NOTE: threads outside the image / displacement range return before the
 * __syncthreads() barriers. With the launch configuration used by the host
 * wrapper in this file (block = (K, 1, 1), one block in x) that exit condition
 * is the same for all threads of a block; keep this in mind when changing the
 * block shape.
 */
__global__ void flow_mp_sad_cuda_forward_kernel(
    KernelData f0,
    KernelData f1,
    int sws,
    KernelData cv_u,
    KernelData cv_v,
    KernelData u_star,
    KernelData v_star,
    int offset_u,
    int offset_v,
    int blockIdx_u, // necessary for argmin computation
    int blockIdx_v  // same here
    )
{
    // parallelize over u (x-dimension of the launch), loop over v;
    // the image row is mapped to the y-dimension, the image column to z.
    const int u_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;
    const int x = blockIdx.z * blockDim.z + threadIdx.z;

    // shared memory for search-window matching costs
    extern __shared__ float sdata[];

    // global defines
    const int K = cv_u.size3;
    const int sws_half = sws / 2;
    const int u = u_idx - sws_half;

    // Offset of this pixel's K x K tile inside the shared buffer. Kept as int:
    // a short would overflow as soon as a block holds more than a few pixels.
    const int sm_offset = blockDim.z * K * K * threadIdx.y + K * K * threadIdx.z;

    // check inside image for reference pixel
    const int n = 0;
    if(x >= f0.size3 || y >= f0.size2 || u_idx >= K)
        return;

    // cost tile of this pixel, addressed as costs[K * v_idx + u_idx]
    float *costs = sdata + sm_offset;

    // initialize all v displacements of this thread's u_idx with a constant value
    for(int v_idx = 0; v_idx < K; ++v_idx)
    {
        costs[K * v_idx + u_idx] = 40.0f;
    }
    __syncthreads();

    // Threads whose match column is outside the image must not return here:
    // every thread is needed for the min-computation below.
    const int match_x = x + offset_u + u;
    if(match_x >= 0 && match_x < f0.size3) // check match inside
    {
        for(int v = -sws_half; v <= sws_half; ++v)
        {
            // skip outside pixels (match-pixel)
            const int match_y = y + offset_v + v;
            if(match_y < 0 || match_y >= f0.size2)
                continue;

            float sad = 0.0f;
            for(int c = 0; c < f0.size1; ++c)
            {
                sad += fabsf(f0(n, c, y, x) - f1(n, c, match_y, match_x));
            }

            // save result to shared mem
            costs[K * (v + sws_half) + u_idx] = sad;
        }
    }
    __syncthreads(); // all u-threads must be ready here!

    // min-projection over v for this thread's u_idx
    float min_v = 9999999.0f;
    int argmin_v = 0;
    for(int v_idx = 0; v_idx < K; ++v_idx)
    {
        const float cost = costs[K * v_idx + u_idx];
        if(cost < min_v)
        {
            min_v = cost;
            argmin_v = v_idx;
        }
    }
    cv_u(n, y, x, u_idx) = min_v;
    u_star(n, y, x, u_idx) = argmin_v + blockIdx_v * sws; // sws = K - 1 => default overlap

    // min-projection over u: the roles are swapped and this thread's u_idx is
    // reused as v_idx for easier parallelization
    const int v_idx = u_idx;
    float min_u = 9999999.0f;
    int argmin_u = 0;
    for(int u_scan = 0; u_scan < K; ++u_scan)
    {
        const float cost = costs[K * v_idx + u_scan];
        if(cost < min_u)
        {
            min_u = cost;
            argmin_u = u_scan;
        }
    }
    cv_v(n, y, x, v_idx) = min_u;
    v_star(n, y, x, v_idx) = argmin_u + blockIdx_u * sws; // sws = K - 1 => default overlap
}
125 |
/**
 * Backward kernel of the SAD min-projection matching op.
 *
 * One thread per (x, y, channel c) of f0. For every displacement kept by the
 * forward min-projection (u with its stored argmin v from u_star, and v with
 * its stored argmin u from v_star) the sub-gradient of |f0 - f1| is
 * sign(f0 - f1) times the incoming gradient. It is accumulated locally for
 * df0 and scattered with atomicAdd (negated) into df1, because several
 * reference pixels can match the same f1 location.
 *
 * Only batch element n = 0 is processed. df1 is expected to be zero-initialised.
 * NOTE(review): the displacement offsets of the forward pass (offset_u/offset_v)
 * are not part of this kernel's interface.
 */
__global__ void flow_mp_sad_cuda_backward_kernel(
    KernelData f0,
    KernelData f1,
    int sws,
    KernelData in_grad_u,
    KernelData in_grad_v,
    KernelData u_star,
    KernelData v_star,
    KernelData df0,
    KernelData df1
    )
{
    // Signed coordinates: with unsigned values "x + u >= 0" would always be
    // true and the bounds test below would only work through wrap-around.
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;
    const int c = blockIdx.z * blockDim.z + threadIdx.z;

    const float eps = 1e-15f;

    // check inside image
    const int n = 0;
    if(x >= f0.size3 || y >= f0.size2 || c >= f0.size1)
        return;

    const int sws_half = sws / 2;

    float grad_f0 = 0.0f;

    // contributions of the cv_u output: one (u, argmin v) pair per u
    for(int u = -sws_half; u <= sws_half; ++u)
    {
        const int u_idx = u + sws_half;
        const short v_idx = u_star(n, y, x, u_idx);
        const int v = v_idx - sws_half;

        // skip outside pixels
        if(x + u >= 0 && x + u < f0.size3 && y + v >= 0 && y + v < f0.size2)
        {
            const float diff = f0(n, c, y, x) - f1(n, c, y + v, x + u);
            if(fabsf(diff) > eps) // gradient is zero if diff is zero!
            {
                const float sign = (diff > 0.0f) ? 1.0f : -1.0f;
                const float update = sign * in_grad_u(n, y, x, u_idx);
                // local update for df0
                grad_f0 += update;

                // global update for df1 (multiple vars can point to one address!)
                atomicAdd(&df1(n, c, y + v, x + u), -update);
            }
        }
    }

    // contributions of the cv_v output: one (argmin u, v) pair per v
    for(int v = -sws_half; v <= sws_half; ++v)
    {
        const int v_idx = v + sws_half;
        const short u_idx = v_star(n, y, x, v_idx);
        const int u = u_idx - sws_half;

        // same as above, only in_grad_v is used here
        if(x + u >= 0 && x + u < f0.size3 && y + v >= 0 && y + v < f0.size2)
        {
            const float diff = f0(n, c, y, x) - f1(n, c, y + v, x + u);
            if(fabsf(diff) > eps) // gradient is zero if diff is zero!
            {
                const float sign = (diff > 0.0f) ? 1.0f : -1.0f;
                const float update = sign * in_grad_v(n, y, x, v_idx);
                // local update for df0
                grad_f0 += update;

                // global update for df1 (multiple vars can point to one address!)
                atomicAdd(&df1(n, c, y + v, x + u), -update);
            }
        }
    }

    df0(n, c, y, x) = grad_f0;
}
201 |
202 |
203 | // ============================================================================
204 | // CPP KERNEL CALLS
205 | // ============================================================================
namespace cuda
{
    /**
     * Allocates the outputs and launches the forward kernel.
     *
     * @param f0, f1      feature maps, indexed as (N, C, H, W).
     * @param sws         search-window size; K = sws + 1 displacements per axis.
     * @param offset_u/v  displacement offsets passed through to the kernel.
     * @param blockIdx_u/v passed through to the kernel for the argmin outputs.
     * @return {cv_u, cv_v, u_star, v_star}, each of shape (N, H, W, K).
     *         The cost volumes are initialised with 40, the argmins with 0.
     */
    std::vector<at::Tensor> flow_mp_sad_forward(at::Tensor f0, at::Tensor f1, int sws, int offset_u, int offset_v,
                                                int blockIdx_u, int blockIdx_v)
    {
        const int N = f0.size(0);
        const int H = f0.size(2);
        const int W = f0.size(3);
        const int K = sws + 1;

        auto cv_u = at::ones({N, H, W, K}, f0.options()) * 40;
        auto cv_v = at::ones({N, H, W, K}, f0.options()) * 40;

        auto u_star = at::zeros({N, H, W, K}, f0.options());
        auto v_star = at::zeros({N, H, W, K}, f0.options());

        if(K > 128)
            std::cout << "Error: Maximal search window size is " << K << " which is larger than max allowed 128!!" << std::endl;

        // parallelise over H x W x K
        // all K need to be in one block in order to have access to the same shared memory!
        // K needs to be the first, because last idx must be < 64.
        const dim3 blockSize(K, 1, 1);
        const dim3 numBlocks(std::ceil(K / static_cast<float>(blockSize.x)),
                             std::ceil(H / static_cast<float>(blockSize.y)),
                             std::ceil(W / static_cast<float>(blockSize.z)));

        const int threadsPerBlock = blockSize.x * blockSize.y * blockSize.z;

        // The kernel stages a K x K float cost tile per pixel of the block in
        // dynamic shared memory.
        const size_t sharedMemBytes = threadsPerBlock * K * sizeof(float);

        flow_mp_sad_cuda_forward_kernel<<<numBlocks, blockSize, sharedMemBytes>>>(
            f0, f1, sws, cv_u, cv_v, u_star, v_star, offset_u, offset_v, blockIdx_u, blockIdx_v);
        cudaSafeCall(cudaGetLastError());

        std::vector<at::Tensor> res;
        res.push_back(cv_u);
        res.push_back(cv_v);
        res.push_back(u_star);
        res.push_back(v_star);
        return res;
    }

    /**
     * Allocates zero-initialised gradients and launches the backward kernel.
     *
     * @return {df0, df1}, shaped like f0 and f1 respectively.
     */
    std::vector<at::Tensor> flow_mp_sad_backward(at::Tensor f0, at::Tensor f1,
                                                 int sws, at::Tensor in_grad_u, at::Tensor in_grad_v,
                                                 at::Tensor u_star, at::Tensor v_star)
    {
        const int C = f0.size(1);
        const int H = f0.size(2);
        const int W = f0.size(3);

        // df1 must start at zero: the kernel accumulates into it with atomicAdd.
        auto df0 = at::zeros_like(f0);
        auto df1 = at::zeros_like(f1);

        // parallelise over W x H x C
        const dim3 blockSize(8, 8, 4);
        const dim3 numBlocks(std::ceil(W / static_cast<float>(blockSize.x)),
                             std::ceil(H / static_cast<float>(blockSize.y)),
                             std::ceil(C / static_cast<float>(blockSize.z)));

        flow_mp_sad_cuda_backward_kernel<<<numBlocks, blockSize>>>(f0, f1, sws, in_grad_u, in_grad_v,
                                                                   u_star, v_star, df0, df1);
        cudaSafeCall(cudaGetLastError());

        std::vector<at::Tensor> gradients;
        gradients.push_back(df0);
        gradients.push_back(df1);

        return gradients;
    }
}
--------------------------------------------------------------------------------
/ops/flow_mp_sad/src/flow_mp_sad_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
4 | namespace cuda
5 | {
6 | std::vector flow_mp_sad_forward(at::Tensor f0, at::Tensor f1, int sws, int offset_u, int offset_v,
7 | int blockIdx_u, int blockIdx_v);
8 | std::vector flow_mp_sad_backward(at::Tensor f0, at::Tensor f1, int sws,
9 | at::Tensor in_grad_u, at::Tensor in_grad_v,
10 | at::Tensor u_star, at::Tensor v_star);
11 | }
--------------------------------------------------------------------------------
/ops/include/error_util.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 | //#include "cudnn.h"
5 | #include "stdio.h"
6 | #include "stdlib.h"
7 |
8 | #define CUDA_ERROR_CHECK
9 |
10 | #define cudaSafeCall( err ) __cnnCudaSafeCall( err, __FILE__, __LINE__ )
11 | #define cudnnSafeCall( err ) __cnnCudnnSafeCall( err, __FILE__, __LINE__ )
12 |
13 | inline void __cnnCudaSafeCall( cudaError_t err, const char *file, const int line )
14 | {
15 | #ifdef CUDA_ERROR_CHECK
16 | if ( cudaSuccess != err )
17 | {
18 | fprintf( stderr, "cudaSafeCall() failed at %s:%i : %s\n", file, line, cudaGetErrorString( err ) );
19 | exit( -1 );
20 | }
21 | #endif
22 |
23 | return;
24 | }
25 |
26 | // inline void __cnnCudnnSafeCall( cudnnStatus_t err, const char *file, const int line)
27 | // {
28 | // #ifdef CUDA_ERROR_CHECK
29 | // if(err != CUDNN_STATUS_SUCCESS)
30 | // {
31 | // fprintf( stderr, "cudaSafeCall() failed at %s:%i : %s\n", file, line, cudnnGetErrorString( err ) );
32 | // exit( -1 );
33 | // }
34 | // #endif
35 |
36 | // return;
37 | // }
38 |
--------------------------------------------------------------------------------
/ops/include/tensor.h:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include
4 | #include
5 |
/** \brief Lightweight view of a 5-dimensional float tensor for use in CUDA kernels.
 *
 * Stores the device data pointer together with the sizes and strides of the
 * tensor. Element access and coords() treat the last dimension as densely
 * packed (stride4 is stored but not used for addressing).
 * Sizes are stored as unsigned short and indices are passed as short, which
 * limits the supported extent per dimension.
 */
struct KernelData5
{
  /** Pointer to device buffer. */
  float* data_;

  unsigned int stride0;
  unsigned int stride1;
  unsigned int stride2;
  unsigned int stride3;
  unsigned int stride4;

  // size of dimensions 0 - 4
  unsigned short size0;
  unsigned short size1;
  unsigned short size2;
  unsigned short size3;
  unsigned short size4;


  /** Access an element via the () operator
   * @param x0 Position in the first dimension.
   * @param x1 Position in the second dimension.
   * @param x2 Position in the third dimension.
   * @param x3 Position in the fourth dimension.
   * @param x4 Position in the fifth dimension.
   * @return value at position (x0, x1, x2, x3, x4).
   */
  __device__ float& operator()(short x0, short x1, short x2, short x3, short x4)
  {
    return data_[x0 * stride0 + x1 * stride1 + x2 * stride2 + x3 * stride3 + x4];
  }

  /** Get position / coordinates for a linear index.
   * @param[in] linearIdx Linear index.
   * @param[out] x0 Position in the first dimension.
   * @param[out] x1 Position in the second dimension.
   * @param[out] x2 Position in the third dimension.
   * @param[out] x3 Position in the fourth dimension.
   * @param[out] x4 Position in the fifth dimension.
   */
  __device__ void coords(unsigned int linearIdx, short *x0, short *x1, short *x2, short *x3, short *x4)
  {
    // modulo is slow, so peel off one dimension after the other by
    // division and subtraction
    unsigned int rest = linearIdx;
    *x0 = rest / stride0;
    rest -= *x0 * stride0;
    *x1 = rest / stride1;
    rest -= *x1 * stride1;
    *x2 = rest / stride2;
    rest -= *x2 * stride2;
    // the fourth coordinate must be divided by its stride as well; what is
    // left afterwards is the position in the densely packed last dimension
    *x3 = rest / stride3;
    rest -= *x3 * stride3;
    *x4 = rest;
  }

  /** Constructor: captures pointer, strides and sizes of a 5-dimensional float tensor.
   *  The initialiser list follows the member declaration order.
   */
  __host__ KernelData5(const at::Tensor &tensor) :
      data_(tensor.data<float>()),
      stride0(tensor.stride(0)),
      stride1(tensor.stride(1)),
      stride2(tensor.stride(2)),
      stride3(tensor.stride(3)),
      stride4(tensor.stride(4)),
      size0(tensor.size(0)),
      size1(tensor.size(1)),
      size2(tensor.size(2)),
      size3(tensor.size(3)),
      size4(tensor.size(4))
  {
  }
};
80 |
/** \brief Lightweight view of a 4-dimensional float tensor for use in CUDA kernels.
 *
 * Stores the device data pointer together with the sizes and strides of the
 * tensor. Element access and coords() treat the last dimension as densely
 * packed (stride3 is stored but not used for addressing).
 * Sizes are stored as unsigned short and indices are passed as short, which
 * limits the supported extent per dimension.
 */
struct KernelData
{
  /** Pointer to device buffer. */
  float* data_;

  unsigned int stride0;
  unsigned int stride1;
  unsigned int stride2;
  unsigned int stride3;

  // size of dimensions 0 - 3
  unsigned short size0;
  unsigned short size1;
  unsigned short size2;
  unsigned short size3;


  /** Access an element via the () operator
   * @param x0 Position in the first dimension.
   * @param x1 Position in the second dimension.
   * @param x2 Position in the third dimension.
   * @param x3 Position in the fourth dimension.
   * @return value at position (x0, x1, x2, x3).
   */
  __device__ float& operator()(short x0, short x1, short x2, short x3)
  {
    return data_[x0 * stride0 + x1 * stride1 + x2 * stride2 + x3];
  }

  /** Get position / coordinates for a linear index.
   * @param[in] linearIdx Linear index.
   * @param[out] x0 Position in the first dimension.
   * @param[out] x1 Position in the second dimension.
   * @param[out] x2 Position in the third dimension.
   * @param[out] x3 Position in the fourth dimension.
   */
  __device__ void coords(unsigned int linearIdx, short *x0, short *x1, short *x2, short *x3)
  {
    // modulo is slow, so peel off one dimension after the other by
    // division and subtraction
    *x0 = linearIdx / stride0;
    *x1 = (linearIdx - *x0 * stride0) / stride1;
    *x2 = (linearIdx - (*x0 * stride0 + *x1 * stride1)) / stride2;
    *x3 = linearIdx - (*x0 * stride0 + *x1 * stride1 + *x2 * stride2);
  }

  /** Constructor: captures pointer, strides and sizes of a 4-dimensional float tensor.
   *  Not explicit on purpose: kernels in this project are launched with
   *  at::Tensor arguments that convert implicitly.
   *  The initialiser list follows the member declaration order.
   */
  __host__ KernelData(const at::Tensor &tensor) :
      data_(tensor.data<float>()),
      stride0(tensor.stride(0)),
      stride1(tensor.stride(1)),
      stride2(tensor.stride(2)),
      stride3(tensor.stride(3)),
      size0(tensor.size(0)),
      size1(tensor.size(1)),
      size2(tensor.size(2)),
      size3(tensor.size(3))
  {
  }
};
--------------------------------------------------------------------------------
/ops/lbp_semantic_pw/bp_op_cuda.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 |
7 | from networks import SubNetwork
8 |
9 | from ops.lbp_semantic_pw.inference_op import Inference
10 | from ops.lbp_semantic_pw.message_passing_op_cuda import MessagePassing
11 |
class BP(SubNetwork):
    """Belief-propagation block: min-sum message passing followed by inference.

    Wraps a ``MessagePassing`` module (which owns the jump costs ``L1``/``L2``)
    and an ``Inference`` module that turns the resulting beliefs into the
    final output.
    """

    @staticmethod
    def construct_pw_energy_weights(K, L1, L2):
        """Build a K x K pairwise energy matrix (K = num labels).

        Returns a numpy array that is 0 on the main diagonal, ``L1`` on the two
        neighbouring diagonals and ``L2`` everywhere else.
        """
        fij = np.ones((K, K)) * L2
        for i in range(-1, 2):
            if i == 0:
                fij -= np.diag(L2 * np.ones(K - np.abs(i)), i)
            else:
                fij -= np.diag((L2 - L1) * np.ones(K - np.abs(i)), i)

        return fij

    @staticmethod
    def construct_pw_prob_weights(K, L1, L2):
        """Build a K x K pairwise probability matrix (K = num labels).

        ``L1`` is placed on the main diagonal and ``L2`` on the two
        neighbouring diagonals; every entry that is still exactly 0 afterwards
        is set to ``1 - L1 - L2``.
        """
        fij = np.zeros((K, K))

        for i in range(-1, 2):
            if i == 0:
                fij += np.diag(L1 * np.ones(K - np.abs(i)), i)
            else:
                fij += np.diag(L2 * np.ones(K - np.abs(i)), i)

        fij[fij == 0] = 1.0 - L1 - L2

        return fij

    def __init__(self, device, args, max_iter, num_labels, delta, mode_inference='expectation', mode_message_passing='min-sum', layer_idx=0, single_pw=False):
        """Create the message-passing and inference sub-modules.

        Args:
            device: device used for the tensors created by this block.
            args: project argument object, forwarded to ``SubNetwork``.
            max_iter: number of message-passing iterations.
            num_labels: forwarded to ``MessagePassing``.
            delta: forwarded to ``MessagePassing``.
            mode_inference: one of 'wta', 'expectation', 'norm', 'raw'.
            mode_message_passing: only 'min-sum' is supported.
            layer_idx: index used in the checkpoint file name.
            single_pw: forwarded to the message passing in ``forward``.

        Raises:
            ValueError: if either mode is not supported.
        """
        super(BP, self).__init__(args, device)
        self.device = device
        self.max_iter = max_iter
        self.layer_idx = layer_idx
        self.delta = delta
        self.single_pw = single_pw

        # Reject the configuration as soon as *either* mode is unknown; the
        # previous and-chain only fired when both were invalid at once.
        if mode_inference not in ('wta', 'expectation', 'norm', 'raw') or mode_message_passing != 'min-sum':
            raise ValueError("Unknown inference/message passing mode " + mode_inference + " " + mode_message_passing)

        self.message_passing = MessagePassing(self.device, self.max_iter, num_labels, self.delta, mode_message_passing)
        self.inference = Inference(self.device, mode_inference, mode_passing=mode_message_passing)

    def forward(self, prob_vol, edge_weights, jump):
        """Run message passing and inference.

        Args:
            prob_vol: 4-D volume, unpacked as (N, H, W, K).
            edge_weights: forwarded to the message passing (may be None).
            jump: pairwise costs forwarded to the message passing.

        Returns:
            Tuple ``(result, beliefs)`` where ``result`` is the inference
            output permuted with ``(0, 3, 1, 2)``.
        """
        N, H, W, K = prob_vol.shape
        # One zero-initialised message volume per direction (4 in total).
        messages = torch.zeros((N, 4, H, W, K), requires_grad=True, device=self.device,
                               dtype=torch.float)

        # compute messages
        beliefs = self.message_passing.forward(prob_vol, edge_weights, messages, jump, self.single_pw)

        # + wta/expectation
        result = self.inference.forward(beliefs)

        return result.permute(0, 3, 1, 2), beliefs

    def project_jumpcosts(self):
        """Project the jump costs via ``MessagePassing.projectL1L2``."""
        self.message_passing.projectL1L2()

    def save_checkpoint(self, epoch, iteration):
        """Save the state dict to ``args.train_dir`` if 'c' is in ``args.train_params``."""
        if 'c' in self.args.train_params:
            torch.save(self.state_dict(),
                       osp.join(self.args.train_dir, 'crf' + str(self.layer_idx) + '_checkpoint_' +
                                str(epoch) + '_' + str(iteration).zfill(6) + '.cpt'))

    def adjust_input_weights(self, weights, idx):
        """Expand two edge-weight channels into the four per-direction channels.

        Channels 0 and 2 are copies of weight channels 0 and 1; channels 1 and
        3 are the same channels shifted by one pixel along the last and the
        second-to-last axis respectively (the first column/row stays zero).

        Returns None if ``weights`` is None, otherwise a contiguous CUDA tensor.
        """
        if weights is None:
            return None

        # NOTE(review): the idx-selected slice only determines the output
        # shape; the values below are read from channels 0 and 1 of the full
        # ``weights`` tensor regardless of ``idx`` - confirm this is intended.
        weights_idx = weights[:, idx * 2 : (idx + 1) * 2, :, :]

        weights_input = torch.zeros((weights_idx.shape[0], 4, weights_idx.shape[2], weights_idx.shape[3])).cuda()

        weights_input[:, 0] = weights[:, 0]
        # wx RL
        weights_input[:, 1, :, 1:] = weights[:, 0, :, :-1]
        # wy UD
        weights_input[:, 2] = weights[:, 1]
        # wy DU
        weights_input[:, 3, 1:, :] = weights[:, 1, :-1, :]

        return weights_input.contiguous()

    @property
    def L1(self):
        """Jump cost L1 of the message-passing module."""
        return self.message_passing.L1

    @L1.setter
    def L1(self, value):
        self.message_passing.L1.data = torch.tensor(value, device=self.device, dtype=torch.float)

    @property
    def L2(self):
        """Jump cost L2 of the message-passing module."""
        return self.message_passing.L2

    @L2.setter
    def L2(self, value):
        self.message_passing.L2.data = torch.tensor(value, device=self.device, dtype=torch.float)

    @property
    def rescaleT(self):
        """``rescaleT`` attribute of the message-passing module."""
        return self.message_passing.rescaleT

    @rescaleT.setter
    def rescaleT(self, value):
        self.message_passing.rescaleT = value
127 |
128 |
129 |
130 |
--------------------------------------------------------------------------------
/ops/lbp_semantic_pw/inference_op.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 | import pytorch_cuda_lbp_op as lbp
7 |
class Inference(nn.Module):
    """Turn a belief volume into the final output.

    Supported modes (all operate on dimension 3 of the beliefs):
        'wta':         index of the maximum belief.
        'expectation': label index weighted by the sum-normalised beliefs.
        'norm':        the sum-normalised beliefs.
        'raw':         the beliefs unchanged.
    """

    _MODES = ('wta', 'expectation', 'norm', 'raw')

    def __init__(self, device, mode='wta', mode_passing='min-sum'):
        """Store the configuration.

        Raises:
            ValueError: if ``mode`` is not one of the supported modes.
        """
        super(Inference, self).__init__()
        self.device = device
        if mode not in self._MODES:
            raise ValueError("Unknown inference mode " + mode)
        self.mode = mode
        self.mode_passing = mode_passing

    def forward(self, beliefs):
        """Apply the configured inference mode to ``beliefs``.

        Raises:
            ValueError: if ``self.mode`` was changed to an unsupported value.
        """
        if self.mode == "wta":
            return torch.argmax(beliefs, dim=3, keepdim=True)

        if self.mode == "raw":
            return beliefs

        if self.mode == "norm":
            return beliefs / beliefs.sum(dim=3, keepdim=True)

        if self.mode == "expectation":
            beliefs_normal = beliefs / beliefs.sum(dim=3, keepdim=True)

            # Count once and reuse the result for both the test and the message.
            nan_count = torch.isnan(beliefs_normal).sum()
            if nan_count > 0:
                print("Beliefs normalized contains " + str(nan_count) + " NaNs ;(")

            # Label indices 0..K-1, broadcast along the first three dimensions.
            labels_tensor = torch.arange(
                beliefs.shape[3], dtype=torch.float32, device=self.device).view(1, 1, 1, -1)
            return (beliefs_normal * labels_tensor).sum(dim=3, keepdim=True)

        raise ValueError("Unknown inference mode " + self.mode)
40 |
--------------------------------------------------------------------------------
/ops/lbp_semantic_pw/message_passing_op_cuda.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | import pytorch_cuda_lbp_op as lbp
7 |
8 | from corenet import TemperatureSoftmin
9 |
def construct_pw_tensor(L1, L2, K):
    """Build the pairwise cost tensor for K labels.

    The K x K matrix is 0 on the main diagonal, ``L1`` on the two neighbouring
    diagonals and ``L2`` everywhere else.

    Returns:
        Float tensor of shape (1, 1, K, K).
    """
    pw_tensor = torch.ones((K, K), dtype=torch.float) * L2

    for offset in (-1, 0, 1):
        # Main diagonal: remove the full L2. Neighbouring diagonals: reduce L2 to L1.
        reduction = L2 if offset == 0 else (L2 - L1)
        # abs() on the Python int gives the diagonal length directly; no
        # tensor round-trip is needed for the size argument.
        pw_tensor -= torch.diag(reduction * torch.ones(K - abs(offset)), offset)

    return pw_tensor.unsqueeze(0).unsqueeze(0)
22 |
class LBPMinSumFunction(torch.autograd.Function):
    """Autograd wrapper around one min-sum LBP iteration of the CUDA extension.

    Only batch element 0 of ``jump`` is expanded into the per-direction
    pairwise costs that the extension receives.
    """

    @staticmethod
    def forward(ctx, cost, L1, L2, edge, messages, delta, jump, single_pw):
        """Expand ``jump`` to four direction channels and run ``lbp.forward_minsum``.

        ``L1`` and ``L2`` are accepted but not used here. Returns the updated
        messages.
        """
        if single_pw.item():
            # A single pairwise matrix is shared by all four directions.
            jump_in = torch.zeros((jump.shape[0], 4, jump.shape[2], jump.shape[3])).cuda()
            for direction in range(4):
                jump_in[0, direction] = jump[0, 0]
        else:
            # Two pairwise matrices: channels 0/2 take them as they are,
            # channels 1/3 take their transposes.
            jump_in = torch.zeros((jump.shape[0], jump.shape[1] + 2, jump.shape[2], jump.shape[3])).cuda()
            first_pw = jump[0, 0]
            second_pw = jump[0, 1]
            jump_in[0, 0] = first_pw
            jump_in[0, 1] = first_pw.permute((1, 0))
            jump_in[0, 2] = second_pw
            jump_in[0, 3] = second_pw.permute((1, 0))

        messages, messages_argmin, message_scale = lbp.forward_minsum(cost, jump_in, edge, messages, delta)

        ctx.save_for_backward(cost, jump_in, edge, messages, messages_argmin, message_scale, single_pw)
        return messages

    @staticmethod
    def backward(ctx, in_grad):
        """Run ``lbp.backward_minsum`` and fold the jump gradient back to the input layout.

        Returns one entry per ``forward`` input; ``L1``, ``L2``, ``delta`` and
        ``single_pw`` receive ``None``.
        """
        (cost, jump, edge, messages,
         messages_argmin, message_scale, single_pw) = ctx.saved_tensors

        grad_cost, grad_jump, grad_edge, grad_message = lbp.backward_minsum(
            cost, jump.contiguous(), edge, in_grad.contiguous(), messages, messages_argmin, message_scale)

        if single_pw.item():
            # All four directions shared one matrix: their gradients add up.
            grad_jump_out = torch.zeros((grad_jump.shape[0], 1, grad_jump.shape[2], grad_jump.shape[3])).cuda()
            for direction in range(4):
                grad_jump_out[0, 0] += grad_jump[0, direction]
        else:
            # Undo the expansion of forward: transposed channels are
            # transposed back before they are accumulated.
            grad_jump_out = torch.zeros((grad_jump.shape[0], grad_jump.shape[1] - 2, grad_jump.shape[2], grad_jump.shape[3])).cuda()
            grad_jump_out[0, 0] += grad_jump[0, 0]
            grad_jump_out[0, 1] += grad_jump[0, 2]
            grad_jump_out[0, 0] += grad_jump[0, 1].permute((1, 0))
            grad_jump_out[0, 1] += grad_jump[0, 3].permute((1, 0))

        return grad_cost, None, None, grad_edge, grad_message, None, grad_jump_out, None
73 |
74 |
class MessagePassing(nn.Module):
    """Iterated min-sum message passing on top of ``LBPMinSumFunction``."""

    def __init__(self, device, max_iter, num_labels, delta, mode='min-sum'):
        """Create the learnable jump costs and the output normalisation.

        Args:
            device: device on which ``L1`` and ``L2`` are created.
            max_iter: number of message-passing iterations in ``forward``.
            num_labels: accepted for interface compatibility; not used here.
            delta: forwarded to the LBP op.
            mode: only 'min-sum' is supported.

        Raises:
            ValueError: if ``mode`` is not 'min-sum'.
        """
        super(MessagePassing, self).__init__()
        self.device = device
        self.max_iter = max_iter

        if mode != 'min-sum':
            raise ValueError("Unknown message parsing mode " + mode)
        self.mode = mode

        L1 = torch.tensor(0.1, device=device)
        L2 = torch.tensor(2.5, device=device)
        self.L1 = nn.Parameter(L1, requires_grad=True)
        self.L2 = nn.Parameter(L2, requires_grad=True)

        self.softmin = TemperatureSoftmin(dim=3, init_temp=1.0)

        self.delta = delta
        self.rescaleT = None

    def projectL1L2(self):
        """Enforce ``L2 >= L1`` by raising ``L2`` to ``L1`` where necessary."""
        self.L2.data = torch.max(self.L1.data, self.L2.data)

    def forward(self, prob_vol, edge_weights, messages, jump, single_pw=False):
        """Run ``max_iter`` message-passing iterations and return normalised beliefs.

        Args:
            prob_vol: 4-D volume, unpacked as (N, H, W, C); its negation is the cost.
            edge_weights: per-direction edge weights, or None for all-ones
                weights of shape (N, 4, H, W).
            messages: initial messages; summed over dimension 1 afterwards.
            jump: pairwise costs forwarded to the LBP op.
            single_pw: forwarded (as a tensor) to the LBP op.

        Raises:
            NotImplementedError: if ``self.mode`` is not 'min-sum'.
        """
        if self.mode != 'min-sum':
            raise NotImplementedError("message parsing mode " + self.mode + " is currently not implemented!")

        N, H, W, _ = prob_vol.shape
        if edge_weights is None:
            # Create the default weights next to the data: the extension
            # rejects non-CUDA inputs, so a CPU default would fail.
            edge_weights = torch.ones((N, 4, H, W), dtype=torch.float, device=prob_vol.device)

        # convert to cost-input
        cost = -prob_vol

        # The flag does not change between iterations, so build it once.
        single_pw_flag = torch.tensor(single_pw)

        # perform message-passing iterations
        for _ in range(self.max_iter):
            messages = LBPMinSumFunction.apply(cost, self.L1, self.L2, edge_weights, messages,
                                               self.delta, jump, single_pw_flag)

        # compute beliefs
        beliefs = messages.sum(dim=1) + cost

        # normalize output
        beliefs = self.softmin.forward(beliefs)

        return beliefs
121 |
--------------------------------------------------------------------------------
/ops/lbp_semantic_pw/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 |
# CUDA extension module `pytorch_cuda_lbp_op`, built from the C++ binding file
# and the min-sum kernel file.
lbp_extension = CUDAExtension(
    'pytorch_cuda_lbp_op',
    [
        'src/lbp.cpp',
        'src/lbp_min_sum_kernel.cu',
    ],
)

# Python packages declared as install requirements of this distribution.
runtime_requirements = [
    "numpy >= 1.15",
    "torch >= 0.4.1",
    "matplotlib >= 3.0.0",
    "scikit-image >= 0.14.1",
    "numba >= 0.42",
]

setup(
    name='pytorch-lbp-op',
    version='0.2',
    author="Patrick Knöbelreiter",
    author_email="knoebelreiter@icg.tugraz.at",
    packages=["src"],
    # shared headers live one directory up, next to the op directories
    include_dirs=['../include/'],
    ext_modules=[lbp_extension],
    cmdclass={'build_ext': BuildExtension},
    install_requires=runtime_requirements,
)
29 |
--------------------------------------------------------------------------------
/ops/lbp_semantic_pw/src/lbp.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | #include "lbp_min_sum_kernel.cuh"
5 |
// C++ interface
// Input validation helpers (AT_ASSERTM is the assertion macro in pytorch 1.0).
// The macro argument is parenthesised so that an arbitrary tensor expression
// can be checked without operator-precedence surprises.
#define CHECK_CUDA(x) AT_ASSERTM((x).type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM((x).is_contiguous(), #x " must be contiguous")
// Expands to two complete statements; the call sites in this file use it
// without a trailing semicolon, so it is intentionally not wrapped in do/while.
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x);
11 |
12 |
13 | // ================================================================================================
14 | // MIN-SUM LBP
15 | // ================================================================================================
16 | std::vector lbp_forward_min_sum(at::Tensor cost,
17 | at::Tensor jump,
18 | at::Tensor edge,
19 | at::Tensor messages, unsigned short delta)
20 | {
21 | CHECK_INPUT(cost)
22 | CHECK_INPUT(jump)
23 | CHECK_INPUT(edge)
24 | CHECK_INPUT(messages)
25 |
26 | return cuda::lbp_forward_min_sum(cost, jump, edge, messages, delta);
27 | }
28 |
29 | std::vector lbp_backward_min_sum(at::Tensor cost,
30 | at::Tensor jump,
31 | at::Tensor edge,
32 | at::Tensor in_grad,
33 | at::Tensor messages,
34 | at::Tensor messages_argmin,
35 | at::Tensor message_scale)
36 | {
37 | CHECK_INPUT(cost)
38 | CHECK_INPUT(jump)
39 | CHECK_INPUT(edge)
40 | CHECK_INPUT(in_grad)
41 | CHECK_INPUT(messages)
42 | CHECK_INPUT(messages_argmin)
43 | CHECK_INPUT(message_scale)
44 |
45 | return cuda::lbp_backward_min_sum(cost, jump, edge, in_grad, messages, messages_argmin, message_scale);
46 | }
47 |
48 | // ================================================================================================
49 | // Pytorch Interfaces
50 | // ================================================================================================
51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
52 | {
53 | m.def("forward_minsum", &lbp_forward_min_sum, "LBP forward (CUDA)");
54 | m.def("backward_minsum", &lbp_backward_min_sum, "LBP backward (CUDA)");
55 | }
56 |
57 |
--------------------------------------------------------------------------------
/ops/lbp_semantic_pw/src/lbp_min_sum_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 |
5 | namespace cuda
6 | {
7 | std::vector lbp_forward_min_sum(at::Tensor cost,
8 | at::Tensor jump,
9 | at::Tensor edge,
10 | at::Tensor messages, unsigned short delta);
11 |
12 | std::vector lbp_backward_min_sum(at::Tensor cost,
13 | at::Tensor jump,
14 | at::Tensor edge,
15 | at::Tensor in_grad,
16 | at::Tensor messages,
17 | at::Tensor messages_argmin,
18 | at::Tensor message_scale);
19 |
20 | }
--------------------------------------------------------------------------------
/ops/lbp_semantic_pw/src/util.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 |
5 | #include "../../include/tensor.h"
6 |
7 | enum DIRECTION {LEFT, RIGHT, UP, DOWN};
8 |
9 |
// Returns the entry prev_values(0, direction, 0, 0) unless it equals the
// sentinel max_float, in which case current_value is returned instead.
extern __device__ __forceinline__ float getDerivativeValue(KernelData prev_values, int direction, float current_value, float max_float)
{
    const float stored = prev_values(0, direction, 0, 0);
    return (stored != max_float) ? stored : current_value;
}
24 |
25 |
// Reports whether any of the four entries prev_values(0, d, 0, 0), d = 0..3,
// differs from the sentinel max_float. The boolean result is converted to the
// declared float return type (1.0f / 0.0f).
extern __device__ __forceinline__ float computeCrossGradient(KernelData prev_values, float max_float)
{
    // Entries are tested in order 0..3; the first non-sentinel entry decides.
    for(int d = 0; d < 4; ++d)
    {
        if(prev_values(0, d, 0, 0) != max_float)
        {
            return true;
        }
    }

    return false;
}
40 |
// Reads one accumulator entry. For UP/DOWN the second index is x, for
// LEFT/RIGHT it is y; any other direction value yields 0.
extern __device__ __forceinline__ float getGradientAcc(KernelData gradient_accumulation, int direction, int n, int y, int x, int c, int grad_acc_idx)
{
    switch(direction)
    {
        case UP:
        case DOWN:
            return gradient_accumulation(n, x, grad_acc_idx, c);
        case LEFT:
        case RIGHT:
            return gradient_accumulation(n, y, grad_acc_idx, c);
        default:
            return 0.0;
    }
}
57 |
// Atomically adds `value` to one accumulator entry. For UP/DOWN the second
// index is x, for LEFT/RIGHT it is y; any other direction value is a no-op.
extern __device__ __forceinline__ void updateGradientAcc(KernelData gradient_accumulation, float value, int direction, int n, int y, int x, int c, int grad_acc_idx)
{
    switch(direction)
    {
        case UP:
        case DOWN:
            atomicAdd(&gradient_accumulation(n, x, grad_acc_idx, c), value);
            break;
        case LEFT:
        case RIGHT:
            atomicAdd(&gradient_accumulation(n, y, grad_acc_idx, c), value);
            break;
        default:
            break;
    }
}
73 |
// Overwrites one accumulator entry with `value` (plain store, not atomic).
// For UP/DOWN the second index is x, for LEFT/RIGHT it is y; any other
// direction value is a no-op.
extern __device__ __forceinline__ void setGradientAcc(KernelData gradient_accumulation, float value, int direction, int n, int y, int x, int c, int grad_acc_idx)
{
    switch(direction)
    {
        case UP:
        case DOWN:
            gradient_accumulation(n, x, grad_acc_idx, c) = value;
            break;
        case LEFT:
        case RIGHT:
            gradient_accumulation(n, y, grad_acc_idx, c) = value;
            break;
        default:
            break;
    }
}
87 |
// Looks up the edge weight at (n, y, x): index 1 of the second dimension for
// UP/DOWN, index 0 for every other direction value.
extern __device__ __forceinline__ float getEdgeWeight(KernelData edges, int n, int y, int x, int direction)
{
    const bool vertical = (direction == UP || direction == DOWN);
    const int channel = vertical ? 1 : 0;

    return edges(n, channel, y, x);
}
103 |
--------------------------------------------------------------------------------
/ops/lbp_semantic_pw_pixel/bp_op_cuda.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 |
7 | from networks import SubNetwork
8 |
9 | from ops.lbp_semantic_pw_pixel.inference_op import Inference
10 | from ops.lbp_semantic_pw_pixel.message_passing_op_pw_pixel import MessagePassing
11 |
class BP(SubNetwork):
    """Belief-propagation layer with pixel-wise pairwise costs.

    Chains a ``MessagePassing`` module (message iterations -> beliefs) with an
    ``Inference`` module (beliefs -> result) and exposes the jump-cost
    parameters ``L1``/``L2`` of the message-passing module as properties.
    """

    # Inference modes accepted by the constructor.
    _INFERENCE_MODES = ('wta', 'expectation', 'norm', 'raw')

    @staticmethod
    def construct_pw_energy_weights(K, L1, L2):
        """Return a K x K pairwise energy matrix (K = number of labels).

        Starts from a matrix filled with ``L2``, subtracts ``L2`` on the main
        diagonal and ``L2 - L1`` on the two first off-diagonals.
        """
        fij = np.ones((K, K)) * L2
        for i in range(-1, 2):
            if i == 0:
                fij -= np.diag(L2 * np.ones(K - np.abs(i)), i)
            else:
                fij -= np.diag((L2 - L1) * np.ones(K - np.abs(i)), i)

        return fij

    @staticmethod
    def construct_pw_prob_weights(K, L1, L2):
        """Return a K x K pairwise probability matrix (K = number of labels).

        ``L1`` is placed on the main diagonal and ``L2`` on the two first
        off-diagonals; every entry that is still exactly 0 afterwards is set
        to ``1.0 - L1 - L2``.
        """
        fij = np.zeros((K, K))

        for i in range(-1, 2):
            if i == 0:
                fij += np.diag(L1 * np.ones(K - np.abs(i)), i)
            else:
                fij += np.diag(L2 * np.ones(K - np.abs(i)), i)

        fij[fij == 0] = 1.0 - L1 - L2

        return fij

    def __init__(self, device, args, max_iter, num_labels, delta, mode_inference='expectation', mode_message_passing='min-sum', layer_idx=0):
        """Build the message-passing and inference sub-modules.

        Args:
            device: device used for tensors allocated by this module.
            args: configuration object handed to ``SubNetwork``.
            max_iter: number of message-passing iterations.
            num_labels: number of labels, forwarded to ``MessagePassing``.
            delta: forwarded to ``MessagePassing``.
            mode_inference: one of 'wta', 'expectation', 'norm', 'raw'.
            mode_message_passing: only 'min-sum' is supported.
            layer_idx: index used in checkpoint file names.

        Raises:
            ValueError: if either mode is not supported.
        """
        super(BP, self).__init__(args, device)
        self.device = device
        self.max_iter = max_iter
        self.layer_idx = layer_idx
        self.delta = delta

        print("init pixel wise bp...")

        # The original check chained everything with `and`, so an unknown
        # inference mode was only rejected when the message-passing mode was
        # wrong as well. Either problem on its own must raise.
        if mode_inference not in self._INFERENCE_MODES or mode_message_passing != 'min-sum':
            raise ValueError("Unknown inference/message passing mode " + mode_inference + " " + mode_message_passing)

        self.message_passing = MessagePassing(self.device, self.max_iter, num_labels, self.delta, mode_message_passing)
        self.inference = Inference(self.device, mode_inference, mode_passing=mode_message_passing)

    def forward(self, prob_vol, edge_weights, jump):
        """Run message passing followed by inference.

        Args:
            prob_vol: 4-D tensor unpacked as (N, H, W, K).
            edge_weights: forwarded to the message-passing module.
            jump: forwarded to the message-passing module.

        Returns:
            Tuple ``(result, beliefs)`` where ``result`` is the inference
            output permuted with ``(0, 3, 1, 2)``.
        """
        N, H, W, K = prob_vol.shape
        # Messages start at zero; one slot per direction (4).
        messages = torch.zeros((N, 4, H, W, K), requires_grad=False, device=self.device,
                               dtype=torch.float)

        # compute messages
        beliefs = self.message_passing.forward(prob_vol, edge_weights, messages, jump)

        # + wta/expectation
        result = self.inference.forward(beliefs)

        return result.permute(0, 3, 1, 2), beliefs

    def project_jumpcosts(self):
        """Delegate to ``MessagePassing.projectL1L2``."""
        self.message_passing.projectL1L2()

    def save_checkpoint(self, epoch, iteration):
        """Save ``state_dict()`` to ``args.train_dir`` if 'c' is in ``args.train_params``."""
        if 'c' in self.args.train_params:
            file_name = ('crf' + str(self.layer_idx) + '_checkpoint_' +
                         str(epoch) + '_' + str(iteration).zfill(6) + '.cpt')
            torch.save(self.state_dict(), osp.join(self.args.train_dir, file_name))

    def adjust_input_weights(self, weights, idx):
        """Expand a pair of weight channels into four directional channels.

        Channels ``2*idx`` and ``2*idx + 1`` of ``weights`` are selected.
        Output channel 0 / 2 are copies of the first / second selected channel;
        channel 1 is the first one shifted by one along the last axis and
        channel 3 is the second one shifted by one along the second-to-last
        axis (the vacated border stays 0).

        Returns ``None`` when ``weights`` is ``None``.
        """
        if weights is None:
            return None

        weights_idx = weights[:, idx * 2: (idx + 1) * 2, :, :]

        # Allocate on the module's device rather than a hard-coded .cuda().
        weights_input = torch.zeros(
            (weights_idx.shape[0], 4, weights_idx.shape[2], weights_idx.shape[3]),
            device=self.device)

        # The original read channels 0/1 of the *unsliced* tensor, silently
        # ignoring `idx`; for idx == 0 the result is unchanged.
        weights_input[:, 0] = weights_idx[:, 0]
        # wx RL
        weights_input[:, 1, :, 1:] = weights_idx[:, 0, :, :-1]
        # wy UD
        weights_input[:, 2] = weights_idx[:, 1]
        # wy DU
        weights_input[:, 3, 1:, :] = weights_idx[:, 1, :-1, :]

        return weights_input.contiguous()

    @property
    def L1(self):
        """``L1`` parameter of the message-passing module."""
        return self.message_passing.L1

    @L1.setter
    def L1(self, value):
        # Replace the parameter's data in place so the Parameter object survives.
        self.message_passing.L1.data = torch.tensor(value, device=self.device, dtype=torch.float)

    @property
    def L2(self):
        """``L2`` parameter of the message-passing module."""
        return self.message_passing.L2

    @L2.setter
    def L2(self, value):
        self.message_passing.L2.data = torch.tensor(value, device=self.device, dtype=torch.float)

    @property
    def rescaleT(self):
        """``rescaleT`` attribute of the message-passing module."""
        return self.message_passing.rescaleT

    @rescaleT.setter
    def rescaleT(self, value):
        self.message_passing.rescaleT = value
130 |
131 |
132 |
--------------------------------------------------------------------------------
/ops/lbp_semantic_pw_pixel/inference_op.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 | import pytorch_cuda_lbp_op as lbp
7 |
class Inference(nn.Module):
    """Turn a belief volume into a result along dimension 3.

    Modes:
        'wta'         -- index of the largest belief (argmax, ``keepdim=True``).
        'expectation' -- expected label index under the sum-normalized beliefs.
        'norm'        -- beliefs divided by their sum along dimension 3.
        'raw'         -- beliefs returned unchanged.
    """

    _MODES = ('wta', 'expectation', 'norm', 'raw')

    def __init__(self, device, mode='wta', mode_passing='min-sum'):
        """Store the configuration.

        Raises:
            ValueError: if ``mode`` is not one of the supported modes.
        """
        super(Inference, self).__init__()
        self.device = device
        if mode not in self._MODES:
            raise ValueError("Unknown inference mode " + mode)
        self.mode = mode
        self.mode_passing = mode_passing

    def forward(self, beliefs):
        """Apply the configured inference mode to ``beliefs``.

        Raises:
            ValueError: if ``self.mode`` was changed to an unsupported value
                after construction (previously an unbound-local error).
        """
        # The original had a second, identical 'wta' branch guarded by
        # mode_passing == 'min-sum'; its result was always overwritten.
        if self.mode == "wta":
            return torch.argmax(beliefs, dim=3, keepdim=True)

        if self.mode == "expectation":
            beliefs_normal = beliefs / beliefs.sum(dim=3, keepdim=True)

            nan_count = torch.isnan(beliefs_normal).sum()
            if nan_count > 0:
                print("Beliefs normalized contains " + str(nan_count) + " NaNs ;(")

            # Label indices 0..K-1, created next to the beliefs so the product
            # never mixes devices.
            labels_tensor = torch.arange(beliefs.shape[3], dtype=torch.float32,
                                         device=beliefs.device).view(1, 1, 1, -1)
            return (beliefs_normal * labels_tensor).sum(dim=3, keepdim=True)

        if self.mode == "norm":
            return beliefs / beliefs.sum(dim=3, keepdim=True)

        if self.mode == "raw":
            return beliefs

        raise ValueError("Unknown inference mode " + str(self.mode))
40 |
--------------------------------------------------------------------------------
/ops/lbp_semantic_pw_pixel/message_passing_op_pw_pixel.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | import pytorch_cuda_lbp_op as lbp
7 |
8 | from corenet import TemperatureSoftmin
9 |
def construct_pw_tensor(L1, L2, K):
    """Build the pairwise cost tensor of shape (1, 1, K, K).

    A K x K matrix filled with ``L2`` has ``L2`` subtracted on its main
    diagonal and ``L2 - L1`` subtracted on the two first off-diagonals; two
    leading singleton dimensions are then added.
    """
    costs = torch.ones((K, K), dtype=torch.float) * L2

    for offset in (-1, 0, 1):
        band_len = K - abs(offset)
        # Main diagonal loses all of L2, the neighbouring diagonals L2 - L1.
        amount = L2 if offset == 0 else (L2 - L1)
        costs -= torch.diag(amount * torch.ones(band_len), offset)

    return costs[None, None]
22 |
class LBPMinSumFunction(torch.autograd.Function):
    """Autograd wrapper around the ``forward_minsum`` / ``backward_minsum`` extension calls."""

    @staticmethod
    def forward(ctx, cost, L1, L2, edge, messages, delta, jump):
        """Run one min-sum message update through the CUDA extension.

        ``L1`` and ``L2`` are accepted but not used in this method. ``jump`` is
        expanded from its leading dimension to ``jump.shape[0] + 2`` slots
        before it is handed to the extension.

        Returns the updated messages; ``cost``, ``edge`` and the three
        extension outputs are saved for ``backward``.
        """

        # Buffer with two extra leading slots; always allocated on the current CUDA device.
        pw_net_jump_shift = torch.zeros((jump.shape[0] + 2, jump.shape[1], jump.shape[2], jump.shape[3], jump.shape[4])).cuda()

        # LR: slot 0 is jump[0] unchanged
        pw_net_jump_shift[0] = jump[0]
        # RL: jump[0] shifted by one along its third axis, last two axes swapped
        pw_net_jump_shift[1, :, 1:] = jump[0, :, :-1].permute((0, 1, 3, 2))
        # UD: slot 2 is jump[1] unchanged
        pw_net_jump_shift[2] = jump[1]
        # DU: jump[1] shifted by one along its second axis, last two axes swapped
        pw_net_jump_shift[3, 1:] = jump[1, :-1].permute((0, 1, 3, 2))

        pw_net_jump_shift = pw_net_jump_shift.contiguous()

        messages, messages_argmin, message_scale = lbp.forward_minsum(cost, pw_net_jump_shift, edge, messages, delta)

        ctx.save_for_backward(cost, edge, messages, messages_argmin, message_scale)
        return messages

    @staticmethod
    # @profile
    def backward(ctx, in_grad):
        """Compute gradients via the CUDA extension.

        Returns one entry per ``forward`` input, in order:
        (cost, L1, L2, edge, messages, delta, jump). ``L1``, ``L2`` and
        ``delta`` receive ``None``.
        """
        cost, edge, messages, messages_argmin, message_scale = ctx.saved_tensors

        grad_cost, grad_jump, grad_edge, grad_message = lbp.backward_minsum(cost, edge, in_grad.contiguous(), messages, messages_argmin, message_scale)

        L1_grad = None
        L2_grad = None

        # Fold the four slots back onto the two slots of the original `jump`,
        # mirroring the shift/permute done in forward. The statements are
        # order-dependent: slot 1 is first added into slot 0, then cleared and
        # reused as the accumulator for slots 2 and 3.
        grad_jump[0, :, :-1] += grad_jump[1, :, 1:].permute((0, 1, 3, 2))
        grad_jump[1] = 0
        grad_jump[1] += grad_jump[2]
        grad_jump[1, :-1] += grad_jump[3, 1:].permute((0, 1, 3, 2))

        return grad_cost, L1_grad, L2_grad, grad_edge, grad_message, None, grad_jump[:2]
61 |
class MessagePassing(nn.Module):
    """Iterated min-sum message passing followed by a temperature softmin.

    Holds the learnable scalars ``L1`` (initialised to 0.1) and ``L2``
    (initialised to 2.5).
    """

    def __init__(self, device, max_iter, num_labels, delta, mode='min-sum'):
        """Create the module.

        Args:
            device: device on which ``L1``/``L2`` are created.
            max_iter: number of message-passing iterations per forward call.
            num_labels: accepted for interface compatibility; not used here.
            delta: forwarded to ``LBPMinSumFunction`` on every iteration.
            mode: only 'min-sum' is supported.

        Raises:
            ValueError: if ``mode`` is not 'min-sum'.
        """
        super(MessagePassing, self).__init__()

        self.device = device
        self.max_iter = max_iter

        if mode != 'min-sum':
            raise ValueError("Unknown message parsing mode " + mode)
        self.mode = mode

        self.L1 = nn.Parameter(torch.tensor(0.1, device=device), requires_grad=True)
        self.L2 = nn.Parameter(torch.tensor(2.5, device=device), requires_grad=True)

        self.softmin = TemperatureSoftmin(dim=3, init_temp=1.0)

        self.delta = delta
        self.rescaleT = None

    def projectL1L2(self):
        """Enforce ``L2 >= L1`` by raising ``L2`` to ``L1`` where necessary."""
        self.L2.data = torch.max(self.L1.data, self.L2.data)

    def forward(self, prob_vol, edge_weights, messages, jump):
        """Run ``max_iter`` message updates and return the normalized beliefs.

        Args:
            prob_vol: 4-D tensor unpacked as (N, H, W, C); negated to obtain costs.
            edge_weights: edge weights, or ``None`` for all-ones weights of
                shape (N, 4, H, W).
            messages: initial messages.
            jump: forwarded to ``LBPMinSumFunction``.

        Raises:
            NotImplementedError: if ``self.mode`` is not 'min-sum'.
        """
        if self.mode != 'min-sum':
            raise NotImplementedError("message parsing mode " + self.mode + " is currently not implemented!")

        N, H, W, _ = prob_vol.shape
        if edge_weights is None:
            # Previously created on the CPU, which the CUDA-only extension
            # rejects; allocate next to the input instead.
            edge_weights = torch.ones((N, 4, H, W), device=prob_vol.device)

        # convert to cost-input
        cost = -prob_vol

        # perform message-passing iterations
        for _ in range(self.max_iter):
            messages = LBPMinSumFunction.apply(cost, self.L1, self.L2, edge_weights, messages, self.delta, jump)

        # compute beliefs: sum the messages over dim 1 and add the costs
        beliefs = messages.sum(dim=1) + cost

        # normalize output
        return self.softmin.forward(beliefs)
--------------------------------------------------------------------------------
/ops/lbp_semantic_pw_pixel/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 |
# Build script for the `pytorch_cuda_lbp_op` extension imported by the Python
# wrappers in this directory.
setup(
    name='pytorch-lbp-op',
    version='0.2',
    author="Patrick Knöbelreiter",
    author_email="knoebelreiter@icg.tugraz.at",
    packages=["src"],
    include_dirs=[],
    ext_modules=[
        # One extension module built from the C++ binding and the CUDA kernel source.
        CUDAExtension('pytorch_cuda_lbp_op', [
            'src/lbp.cpp',
            'src/lbp_min_sum_kernel.cu',
        ]),
    ],
    cmdclass={
        # Use PyTorch's build_ext implementation for the extension above.
        'build_ext': BuildExtension
    },
    install_requires=[
        "numpy >= 1.15",
        "torch >= 0.4.1",
        "matplotlib >= 3.0.0",
        "scikit-image >= 0.14.1",
        "numba >= 0.42"
    ],
)
29 |
--------------------------------------------------------------------------------
/ops/lbp_semantic_pw_pixel/src/lbp.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | #include "lbp_min_sum_kernel.cuh"
5 |
6 | // C++ interface
7 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
8 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x);
10 |
11 |
12 | // ================================================================================================
13 | // MIN-SUM LBP
14 | // ================================================================================================
15 | std::vector lbp_forward_min_sum(at::Tensor cost,
16 | at::Tensor jump,
17 | at::Tensor edge,
18 | at::Tensor messages, unsigned short delta)
19 | {
20 | CHECK_INPUT(cost)
21 | CHECK_INPUT(jump)
22 | CHECK_INPUT(edge)
23 | CHECK_INPUT(messages)
24 |
25 | return cuda::lbp_forward_min_sum(cost, jump, edge, messages, delta);
26 | }
27 |
28 | std::vector lbp_backward_min_sum(at::Tensor cost,
29 | at::Tensor edge,
30 | at::Tensor in_grad,
31 | at::Tensor messages,
32 | at::Tensor messages_argmin,
33 | at::Tensor message_scale)
34 | {
35 | CHECK_INPUT(cost)
36 | //CHECK_INPUT(jump)
37 | CHECK_INPUT(edge)
38 | CHECK_INPUT(in_grad)
39 | CHECK_INPUT(messages)
40 | CHECK_INPUT(messages_argmin)
41 | CHECK_INPUT(message_scale)
42 |
43 | return cuda::lbp_backward_min_sum(cost, edge, in_grad, messages, messages_argmin, message_scale);
44 | }
45 |
46 | // ================================================================================================
47 | // Pytorch Interfaces
48 | // ================================================================================================
// Registers the two host-side entry points on the Python extension module:
//   forward_minsum  -> lbp_forward_min_sum
//   backward_minsum -> lbp_backward_min_sum
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward_minsum", &lbp_forward_min_sum, "LBP forward (CUDA)");
    m.def("backward_minsum", &lbp_backward_min_sum, "LBP backward (CUDA)");
}
54 |
55 |
--------------------------------------------------------------------------------
/ops/lbp_semantic_pw_pixel/src/lbp_min_sum_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 |
// Declarations of the min-sum LBP entry points that the C++ binding layer calls.
// NOTE(review): the vector element type was missing in this listing; it is
// restored as at::Tensor (the type of every argument) so the declarations
// compile. Confirm against the definitions in the .cu file.
namespace cuda
{
    // Forward pass; `delta` is passed through from the Python caller.
    std::vector<at::Tensor> lbp_forward_min_sum(at::Tensor cost,
                                                at::Tensor jump,
                                                at::Tensor edge,
                                                at::Tensor messages, unsigned short delta);

    // Backward pass; this variant takes no `jump` tensor.
    std::vector<at::Tensor> lbp_backward_min_sum(at::Tensor cost,
                                                 at::Tensor edge,
                                                 at::Tensor in_grad,
                                                 at::Tensor messages,
                                                 at::Tensor messages_argmin,
                                                 at::Tensor message_scale);

}
--------------------------------------------------------------------------------
/ops/lbp_semantic_pw_pixel/src/util.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 |
5 | #include "../../include/tensor.h"
6 |
7 | enum DIRECTION {LEFT, RIGHT, UP, DOWN};
8 |
9 |
// Returns the entry prev_values(0, direction, 0, 0) unless it equals the
// sentinel max_float, in which case current_value is returned instead.
extern __device__ __forceinline__ float getDerivativeValue(KernelData prev_values, int direction, float current_value, float max_float)
{
    const float stored = prev_values(0, direction, 0, 0);
    return (stored != max_float) ? stored : current_value;
}
24 |
25 |
// Reports whether any of the four entries prev_values(0, d, 0, 0), d = 0..3,
// differs from the sentinel max_float. The boolean result is converted to the
// declared float return type (1.0f / 0.0f).
extern __device__ __forceinline__ float computeCrossGradient(KernelData prev_values, float max_float)
{
    // Entries are tested in order 0..3; the first non-sentinel entry decides.
    for(int d = 0; d < 4; ++d)
    {
        if(prev_values(0, d, 0, 0) != max_float)
        {
            return true;
        }
    }

    return false;
}
40 |
// Reads one accumulator entry. For UP/DOWN the second index is x, for
// LEFT/RIGHT it is y; any other direction value yields 0.
extern __device__ __forceinline__ float getGradientAcc(KernelData gradient_accumulation, int direction, int n, int y, int x, int c, int grad_acc_idx)
{
    switch(direction)
    {
        case UP:
        case DOWN:
            return gradient_accumulation(n, x, grad_acc_idx, c);
        case LEFT:
        case RIGHT:
            return gradient_accumulation(n, y, grad_acc_idx, c);
        default:
            return 0.0;
    }
}
57 |
// Atomically adds `value` to one accumulator entry. For UP/DOWN the second
// index is x, for LEFT/RIGHT it is y; any other direction value is a no-op.
extern __device__ __forceinline__ void updateGradientAcc(KernelData gradient_accumulation, float value, int direction, int n, int y, int x, int c, int grad_acc_idx)
{
    switch(direction)
    {
        case UP:
        case DOWN:
            atomicAdd(&gradient_accumulation(n, x, grad_acc_idx, c), value);
            break;
        case LEFT:
        case RIGHT:
            atomicAdd(&gradient_accumulation(n, y, grad_acc_idx, c), value);
            break;
        default:
            break;
    }
}
73 |
// Overwrites one accumulator entry with `value` (plain store, not atomic).
// For UP/DOWN the second index is x, for LEFT/RIGHT it is y; any other
// direction value is a no-op.
extern __device__ __forceinline__ void setGradientAcc(KernelData gradient_accumulation, float value, int direction, int n, int y, int x, int c, int grad_acc_idx)
{
    switch(direction)
    {
        case UP:
        case DOWN:
            gradient_accumulation(n, x, grad_acc_idx, c) = value;
            break;
        case LEFT:
        case RIGHT:
            gradient_accumulation(n, y, grad_acc_idx, c) = value;
            break;
        default:
            break;
    }
}
87 |
// Looks up the edge weight at (n, y, x): index 1 of the second dimension for
// UP/DOWN, index 0 for every other direction value.
extern __device__ __forceinline__ float getEdgeWeight(KernelData edges, int n, int y, int x, int direction)
{
    const bool vertical = (direction == UP || direction == DOWN);
    const int channel = vertical ? 1 : 0;

    return edges(n, channel, y, x);
}
103 |
--------------------------------------------------------------------------------
/ops/lbp_stereo/bp_op_cuda.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 |
7 | from networks import SubNetwork
8 |
9 | import pytorch_cuda_lbp_op as lbp
10 | from ops.lbp_stereo.inference_op import Inference
11 | from ops.lbp_stereo.message_passing_op_cuda import MessagePassing
12 |
13 | import numba
14 |
class BP(SubNetwork):
    """Belief-propagation layer for stereo (4-D input) and flow (5-D input).

    Chains a ``MessagePassing`` module with an ``Inference`` module and offers
    helpers that expand edge weights, affinities and offsets from two to four
    directional channels.
    """

    # Inference modes accepted by the constructor check below.
    _INFERENCE_MODES = ('wta', 'expectation', 'norm', 'raw')

    @staticmethod
    def get_linear_idx(pos, L_vec_size):
        """Map a signed position to an index into ``L_vec``, clamped to [1, L_vec_size - 1]."""
        vec_idx = L_vec_size // 2 + pos
        vec_idx = np.maximum(1, vec_idx)
        vec_idx = np.minimum(L_vec_size - 1, vec_idx)

        return vec_idx

    @staticmethod
    def f_func(t, s, L_vec):
        """Evaluate ``L_vec`` at position ``t - s + L_vec[0]`` with linear interpolation.

        ``L_vec[0]`` is read as an offset; the remaining entries are looked up
        through ``get_linear_idx``. Integer positions return the stored value,
        fractional ones blend the two neighbouring entries.
        """
        o = L_vec[0]

        input_pos = t - s + o

        lower_pos = int(np.floor(input_pos))
        upper_pos = int(np.ceil(input_pos))

        if lower_pos == upper_pos:
            vec_idx = BP.get_linear_idx(lower_pos, L_vec.shape[0])
            return L_vec[vec_idx]

        lower_vec_idx = BP.get_linear_idx(lower_pos, L_vec.shape[0])
        upper_vec_idx = BP.get_linear_idx(upper_pos, L_vec.shape[0])

        lower_val = L_vec[lower_vec_idx]
        upper_val = L_vec[upper_vec_idx]

        weight_upper = input_pos - lower_pos
        weight_lower = upper_pos - input_pos

        interp_val = weight_lower * lower_val + weight_upper * upper_val

        return interp_val

    @staticmethod
    def construct_pw_energy_weights(K, L1, L2):
        """Return a K x K pairwise energy matrix (K = number of labels).

        Starts from a matrix filled with ``L2``, subtracts ``L2`` on the main
        diagonal and ``L2 - L1`` on the two first off-diagonals.
        """
        fij = np.ones((K, K)) * L2
        for i in range(-1, 2):
            if i == 0:
                fij -= np.diag(L2 * np.ones(K - np.abs(i)), i)
            else:
                fij -= np.diag((L2 - L1) * np.ones(K - np.abs(i)), i)

        return fij

    @staticmethod
    def construct_pw_prob_weights(K, L1, L2):
        """Return a K x K pairwise probability matrix (K = number of labels).

        ``L1`` is placed on the main diagonal and ``L2`` on the two first
        off-diagonals; every entry that is still exactly 0 afterwards is set
        to ``1.0 - L1 - L2``.
        """
        fij = np.zeros((K, K))

        for i in range(-1, 2):
            if i == 0:
                fij += np.diag(L1 * np.ones(K - np.abs(i)), i)
            else:
                fij += np.diag(L2 * np.ones(K - np.abs(i)), i)

        fij[fij == 0] = 1.0 - L1 - L2

        return fij

    def __init__(self, device, args, max_iter, num_labels, delta, mode_inference='expectation', mode_message_passing='min-sum', layer_idx=0, level=0):
        """Build the sub-modules and optionally load a CRF checkpoint.

        Args:
            device: device used for tensors allocated by this module.
            args: configuration object handed to ``SubNetwork``; an optional
                ``checkpoint_crf`` attribute is indexed as ``[level][layer_idx]``.
            max_iter: number of message-passing iterations.
            num_labels: number of labels, forwarded to ``MessagePassing``.
            delta: forwarded to ``MessagePassing``.
            mode_inference: inference mode, validated by ``Inference``.
            mode_message_passing: message-passing mode, validated by ``MessagePassing``.
            layer_idx: layer index (checkpoint lookup and file names).
            level: pyramid level (checkpoint lookup and file names).

        Raises:
            ValueError: for the mode combination rejected by the check below,
                or by the sub-module constructors.
        """
        super(BP, self).__init__(args, device)
        self.device = device
        self.max_iter = max_iter
        self.layer_idx = layer_idx
        self.level = level
        self.delta = delta

        # Kept logically identical to the original check (raise only when the
        # inference mode is not in the list AND the passing mode is not
        # 'min-sum'): the sub-modules accept further modes such as 'sub-exp'
        # and perform the authoritative validation themselves.
        if mode_inference not in self._INFERENCE_MODES and mode_message_passing != 'min-sum':
            raise ValueError("Unknown inference/message passing mode " + mode_inference + " " + mode_message_passing)

        self.message_passing = MessagePassing(self.device, self.max_iter, num_labels, delta, mode_message_passing)
        self.inference = Inference(self.device, mode_inference, mode_passing=mode_message_passing)

        # The original dereferenced args.checkpoint_crf *before* checking that
        # args is set and has that attribute; look it up defensively instead.
        checkpoint_crf = getattr(args, 'checkpoint_crf', None) if args else None
        if checkpoint_crf and checkpoint_crf[0] is not None:
            self.load_parameters(checkpoint_crf[self.level][self.layer_idx], device)

    def forward(self, prob_vol, edge_weights, affinities=None, offsets=None):
        """Run message passing and inference.

        A 5-D ``prob_vol`` (N, 2, H, W, K) is treated as flow: channel 0 (u)
        and channel 1 (v) are processed independently and concatenated.
        A 4-D ``prob_vol`` (N, H, W, K) is treated as a single disparity volume.

        Returns:
            ``(flow_or_disps, beliefs, messages)``.
        """
        if len(prob_vol.shape) == 5:
            N, _, H, W, K = prob_vol.shape

            # u-flow
            messages_u = torch.zeros((N, 4, H, W, K), requires_grad=True, device=self.device, dtype=torch.float)
            beliefs_u, messages_u = self.message_passing.forward(prob_vol[:, 0].contiguous(), edge_weights, affinities, offsets, messages_u)
            result_u = self.inference.forward(beliefs_u)

            # v-flow
            # NOTE: the original passed `messages_u` (the *output* of the u
            # pass) here and never used the freshly allocated `messages_v`.
            # Each component now starts from its own zero messages.
            messages_v = torch.zeros((N, 4, H, W, K), requires_grad=True, device=self.device, dtype=torch.float)
            beliefs_v, messages_v = self.message_passing.forward(prob_vol[:, 1].contiguous(), edge_weights, affinities, offsets, messages_v)
            result_v = self.inference.forward(beliefs_v)

            flow = torch.cat((result_u, result_v), dim=-1).permute(0, 3, 1, 2)
            beliefs = torch.cat((beliefs_u.unsqueeze(1), beliefs_v.unsqueeze(1)), dim=1)
            messages = torch.cat((messages_u.unsqueeze(1), messages_v.unsqueeze(1)), dim=1)

            return flow, beliefs, messages

        # 4-D input: disparities
        N, H, W, K = prob_vol.shape

        messages = torch.zeros((N, 4, H, W, K), requires_grad=True, device=self.device, dtype=torch.float)
        beliefs, messages = self.message_passing.forward(prob_vol.contiguous(), edge_weights, affinities, offsets, messages)
        result = self.inference.forward(beliefs)

        disps = result.permute(0, 3, 1, 2)
        return disps, beliefs, messages

    def adjust_input_weights(self, weights, idx):
        """Expand a pair of weight channels into four directional channels.

        Channels ``2*idx`` and ``2*idx + 1`` of ``weights`` are selected.
        Output channel 0 / 2 are copies of the first / second selected channel;
        channel 1 is the first one shifted by one along the last axis and
        channel 3 is the second one shifted by one along the second-to-last
        axis (the vacated border stays 0).

        Returns ``None`` when ``weights`` is ``None``.
        """
        if weights is None:
            return None

        weights_idx = weights[:, idx * 2: (idx + 1) * 2, :, :]

        # Allocate on the module's device rather than a hard-coded .cuda().
        weights_input = torch.zeros(
            (weights_idx.shape[0], 4, weights_idx.shape[2], weights_idx.shape[3]),
            device=self.device)

        # The original read channels 0/1 of the *unsliced* tensor, silently
        # ignoring `idx`; for idx == 0 the result is unchanged.
        weights_input[:, 0] = weights_idx[:, 0]
        # wx RL
        weights_input[:, 1, :, 1:] = weights_idx[:, 0, :, :-1]
        # wy UD
        weights_input[:, 2] = weights_idx[:, 1]
        # wy DU
        weights_input[:, 3, 1:, :] = weights_idx[:, 1, :-1, :]

        return weights_input.contiguous()

    def adjust_input_affinities(self, affinities):
        """Project the affinities onto their ordering constraint and expand to four directions.

        Index 2 of ``affinities`` selects one of five channels, ordered
        (L2-, L2+, L1-, L1+, L3) according to the original comment. The
        projection enforces L2- >= L1-, L2+ >= L1+ and L3 >= max(L2-, L2+).
        The two direction slots in dimension 1 are then expanded to four
        (LR, RL, UD, DU) by shifting by one pixel.

        Returns ``None`` when ``affinities`` is ``None``.
        """
        if affinities is None:
            return None

        affinities_new = affinities.clone()
        affinities_new[:, :, 0] = torch.max(affinities[:, :, 0], affinities[:, :, 2])
        affinities_new[:, :, 1] = torch.max(affinities[:, :, 1], affinities[:, :, 3])
        affinities_new[:, :, 4] = torch.max(affinities[:, :, 4],
                                            torch.max(affinities_new[:, :, 0].clone(),
                                                      affinities_new[:, :, 1].clone()))

        affinities = affinities_new

        # shifted affinities (two extra direction slots)
        affinities_shift = torch.zeros(
            (affinities.shape[0], affinities.shape[1] + 2, affinities.shape[2],
             affinities.shape[3], affinities.shape[4]),
            device=self.device)

        # ax LR
        affinities_shift[:, 0] = affinities[:, 0]
        # ax RL
        affinities_shift[:, 1, :, :, 1:] = affinities[:, 0, :, :, :-1]
        # ay UD
        affinities_shift[:, 2] = affinities[:, 1]
        # ay DU
        affinities_shift[:, 3, :, 1:, :] = affinities[:, 1, :, :-1, :]

        return affinities_shift.contiguous()

    def adjust_input_offsets(self, offsets):
        """Expand two offset channels to four directions.

        The reverse directions (RL, DU) use the negated offset of the
        neighbouring pixel, shifted by one.

        Returns ``None`` when ``offsets`` is ``None``.
        """
        if offsets is None:
            return None

        offsets_shift = torch.zeros(
            (offsets.shape[0], offsets.shape[1] + 2, offsets.shape[2], offsets.shape[3]),
            device=self.device)

        # ox LR
        offsets_shift[:, 0] = offsets[:, 0]
        # ox RL
        offsets_shift[:, 1, :, 1:] = -offsets[:, 0, :, :-1]
        # oy UD
        offsets_shift[:, 2] = offsets[:, 1]
        # oy DU
        offsets_shift[:, 3, 1:, :] = -offsets[:, 1, :-1, :]

        return offsets_shift.contiguous()

    def project_jumpcosts(self):
        """Delegate to ``MessagePassing.projectL1L2``."""
        self.message_passing.projectL1L2()

    def save_checkpoint(self, epoch, iteration):
        """Save ``state_dict()`` to ``args.train_dir`` if 'c' is in ``args.train_params``."""
        if 'c' in self.args.train_params:
            file_name = ('crf' + str(self.layer_idx) + '_lvl' + str(self.level) + '_checkpoint_' +
                         str(epoch) + '_' + str(iteration).zfill(6) + '.cpt')
            torch.save(self.state_dict(), osp.join(self.args.train_dir, file_name))

    def hook_adjust_checkpoint(self, checkpoint):
        """Upgrade an older checkpoint dict in place and return it.

        * Adds the softmin temperature when the checkpoint predates it.
        * Renames ``message_passing.L1``/``L2`` to ``message_passing._L1``/``_L2``.
        """
        # allow to continue a training where the temperature parameter after the BP did not exist
        if 'message_passing.softmin.T' not in checkpoint:
            print('Info: Adjust loaded BP-checkpoint -> Use temperature T=1 (=fully backward compatible)')
            checkpoint['message_passing.softmin.T'] = self.message_passing.softmin.T

        if 'message_passing.L1' in checkpoint and 'message_passing.L2' in checkpoint:
            print('Info: Adjust loaded BP-checkpoint -> Use setter for L1 L2')
            checkpoint['message_passing._L1'] = checkpoint.pop('message_passing.L1')
            checkpoint['message_passing._L2'] = checkpoint.pop('message_passing.L2')

        return checkpoint

    @property
    def L1(self):
        """``L1`` of the message-passing module."""
        return self.message_passing.L1

    @property
    def L2(self):
        """``L2`` of the message-passing module."""
        return self.message_passing.L2

    def setL1L2(self, value_L1, value_L2):
        """Delegate to ``MessagePassing.setL1L2``."""
        self.message_passing.setL1L2(value_L1, value_L2)

    @property
    def rescaleT(self):
        """``rescaleT`` attribute of the message-passing module."""
        return self.message_passing.rescaleT

    @rescaleT.setter
    def rescaleT(self, value):
        self.message_passing.rescaleT = value
257 |
258 |
259 |
260 |
--------------------------------------------------------------------------------
/ops/lbp_stereo/inference_op.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 | import pytorch_cuda_lbp_op as lbp
7 |
class Inference(nn.Module):
    """Turn a belief volume (N, H, W, K) into a result along the label dimension.

    Modes:
        'wta'         -- index of the largest belief (argmax, ``keepdim=True``).
        'expectation' -- expected label index under the sum-normalized beliefs.
        'norm'        -- beliefs divided by their sum along dimension 3.
        'raw'         -- beliefs returned unchanged.
        'sub-exp'     -- expectation restricted to a window around the argmax.
    """

    _MODES = ('wta', 'expectation', 'norm', 'raw', 'sub-exp')

    def __init__(self, device, mode='wta', mode_passing='min-sum'):
        """Store the configuration.

        Raises:
            ValueError: if ``mode`` is not one of the supported modes.
        """
        super(Inference, self).__init__()
        self.device = device
        if mode not in self._MODES:
            raise ValueError("Unknown inference mode " + mode)
        self.mode = mode
        self.mode_passing = mode_passing

    def forward(self, beliefs):
        """Apply the configured inference mode to ``beliefs``.

        Raises:
            ValueError: if ``self.mode`` was changed to an unsupported value
                after construction (previously an unbound-local error).
        """
        # The original had an extra 'wta' branch (with a .float() cast) guarded
        # by mode_passing == 'min-sum'; its result was always overwritten by
        # the branch below, so the effective output is the integer argmax.
        if self.mode == "wta":
            return torch.argmax(beliefs, dim=3, keepdim=True)

        if self.mode == "expectation":
            beliefs_normal = beliefs / beliefs.sum(dim=3, keepdim=True)

            nan_count = torch.isnan(beliefs_normal).sum()
            if nan_count > 0:
                print("Beliefs normalized contains " + str(nan_count) + " NaNs ;(")

            # Label indices 0..K-1, created next to the beliefs so the product
            # never mixes devices.
            labels_tensor = torch.arange(beliefs.shape[3], dtype=torch.float32,
                                         device=beliefs.device).view(1, 1, 1, -1)
            return (beliefs_normal * labels_tensor).sum(dim=3, keepdim=True)

        if self.mode == "norm":
            return beliefs / beliefs.sum(dim=3, keepdim=True)

        if self.mode == "raw":
            return beliefs

        if self.mode == 'sub-exp':
            return self.compute_sub_expectation(beliefs)

        raise ValueError("Unknown inference mode " + str(self.mode))

    @staticmethod
    def compute_sub_expectation(beliefs, support=3):
        """Sub-pixel estimate: expectation over a window around the argmax.

        For every pixel the ``2 * support + 1`` labels centred on the argmax
        are gathered (indices clamped to ``[0, K - 1]``, so border labels may
        appear more than once), their beliefs are normalized to sum to one and
        the weighted mean of the label indices is returned.

        Args:
            beliefs: tensor of shape (N, H, W, K).
            support: half-width of the window.

        Returns:
            Tensor of shape (N, H, W, 1).
        """
        _, _, _, K = beliefs.shape
        device = beliefs.device

        disp = beliefs.argmax(dim=-1, keepdim=True)

        # Window of label indices per pixel. Created on the beliefs' device
        # (the original hard-coded .cuda(), which broke CPU execution).
        offsets = torch.arange(-support, support + 1, device=device)
        disp_windows = (disp + offsets).clamp(0, K - 1)

        # gather along the label axis replaces the explicit (n, y, x) index grids.
        beliefs_windows = torch.gather(beliefs, -1, disp_windows)
        beliefs_windows_normalized = beliefs_windows / beliefs_windows.sum(dim=-1, keepdim=True)

        return torch.sum(beliefs_windows_normalized * disp_windows, dim=-1, keepdim=True)
98 |
--------------------------------------------------------------------------------
/ops/lbp_stereo/message_passing_op_cuda.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from corenet import TemperatureSoftmin
6 |
7 | import pytorch_cuda_lbp_op as lbp
8 |
def construct_pw_tensor(L1, L2, K):
    """Build the K x K pairwise cost table of a two-level jump penalty.

    Entries on the main diagonal (equal labels) are 0, entries on the first
    off-diagonals (labels differing by one) are ``L1`` and every other entry
    is ``L2``.

    Returns:
        Float tensor of shape (1, 1, K, K).
    """
    # Start from "everything costs L2" and discount the three central bands.
    pairwise = torch.ones((K, K), dtype=torch.float) * L2

    for band, discount in ((-1, L2 - L1), (0, L2), (1, L2 - L1)):
        pairwise -= torch.diag(discount * torch.ones(K - abs(band)), band)

    # Two leading singleton axes.
    return pairwise[None, None]
21 |
class LBPMinSumFunction(torch.autograd.Function):
    """Autograd wrapper around one min-sum LBP sweep of the CUDA extension.

    The learned affinities are re-packed into the per-direction layout that
    the kernel reads, and the gradient coming back from the kernel is
    un-packed into the layout of the learned affinities again.
    """

    @staticmethod
    def forward(ctx, cost, edge, messages, delta, affinities, offset):
        """Run ``lbp.forward_minsum`` and return the updated messages.

        Layouts, as described by the original author:
          affinities:       N x 2*2 x 5 x H x W, dim1 = LR, RL, UD, DU
                            (already shifted), dim2 = L2-, L2+, L1-, L1+, L3
          affinities_input: N x 2*2 x 5+3 x H x W,
                            dim2 = offset, L3, L2-, L1-, L0, L1+, L2+, L3
        """
        n, num_dir, num_aff, height, width = affinities.shape

        # Allocate directly next to `cost` instead of building the tensor on
        # the host and copying it to whatever the default GPU happens to be.
        affinities_input = torch.zeros(
            (n, num_dir, num_aff + 3, height, width), device=cost.device)

        affinities_input[:, :, 0, :, :] = offset

        # The outermost level is symmetric => same value at both ends.
        affinities_input[:, :, 1, :, :] = affinities[:, :, -1, :, :]
        affinities_input[:, :, -1, :, :] = affinities[:, :, -1, :, :]

        # L0 (no label change) is free.
        center = affinities_input.shape[2] // 2
        affinities_input[:, :, center, :, :] = 0.0

        # Remaining levels come in (negative, positive) pairs.
        for pair, i in enumerate(range(2, center)):
            neg, pos = 2 * pair, 2 * pair + 1

            # LR and UD keep the original ordering (first negative, then positive)
            affinities_input[:, 0::2, i, :, :] = affinities[:, 0::2, neg, :, :]
            affinities_input[:, 0::2, -i, :, :] = affinities[:, 0::2, pos, :, :]

            # RL and DU use the reverse ordering (first positive, then negative)
            affinities_input[:, 1::2, i, :, :] = affinities[:, 1::2, pos, :, :]
            affinities_input[:, 1::2, -i, :, :] = affinities[:, 1::2, neg, :, :]

        affinities_input = affinities_input.contiguous()
        torch.cuda.empty_cache()
        messages_out, messages_argmin, message_scale = lbp.forward_minsum(
            cost, affinities_input, edge, messages, delta)

        # The kernel's backward pass needs the *updated* messages.
        ctx.save_for_backward(cost, affinities_input, edge, messages_out,
                              messages_argmin, message_scale)
        return messages_out

    @staticmethod
    # @profile
    def backward(ctx, in_grad):
        """Return gradients for (cost, edge, messages, delta, affinities, offset).

        ``delta`` receives no gradient.
        """
        cost, affinities_input, edge, messages, messages_argmin, message_scale = ctx.saved_tensors

        grad_cost, grad_affinities_input, grad_edge, grad_message = lbp.backward_minsum(
            cost, affinities_input, edge, in_grad.contiguous(),
            messages, messages_argmin, message_scale)

        # Fold the kernel-layout gradient back onto the learned parameters.
        n, num_dir, num_levels, height, width = affinities_input.shape
        grad_affinities_out = torch.zeros(
            (n, num_dir, num_levels - 3, height, width),
            device=grad_affinities_input.device)

        # The outermost level was written to two slots => sum both gradients.
        grad_affinities_out[:, :, -1, :, :] += grad_affinities_input[:, :, 1, :, :]
        grad_affinities_out[:, :, -1, :, :] += grad_affinities_input[:, :, -1, :, :]

        for pair, i in enumerate(range(2, num_levels // 2)):
            neg, pos = 2 * pair, 2 * pair + 1

            grad_affinities_out[:, 0::2, neg, :, :] = grad_affinities_input[:, 0::2, i, :, :]
            grad_affinities_out[:, 0::2, pos, :, :] = grad_affinities_input[:, 0::2, -i, :, :]

            grad_affinities_out[:, 1::2, pos, :, :] = grad_affinities_input[:, 1::2, i, :, :]
            grad_affinities_out[:, 1::2, neg, :, :] = grad_affinities_input[:, 1::2, -i, :, :]

        # offset grad
        grad_offset = grad_affinities_input[:, :, 0, :, :]

        return grad_cost, grad_edge, grad_message, None, grad_affinities_out, grad_offset
94 |
class MessagePassing(nn.Module):
    """Runs ``max_iter`` min-sum LBP sweeps and turns the result into beliefs.

    ``_L1`` and ``_L2`` are learnable jump costs that are only used when the
    caller does not pass per-pixel affinities to :meth:`forward`.
    """

    @property
    def L1(self):
        return self._L1

    @property
    def L2(self):
        return self._L2

    def setL1L2(self, value_L1, value_L2):
        """Replace L1 and L2 after validating ``0 < L1 <= L2``.

        Raises:
            ValueError: if a value is negative or if L1 exceeds L2.

        NOTE(review): a value of exactly 0 is neither rejected nor stored
        (the call is a silent no-op); kept as-is because it is unclear which
        of the two was intended.
        """
        if value_L1 < 0 or value_L2 < 0:
            raise ValueError("L1 or L2 is < 0!")
        if value_L1 > value_L2:
            raise ValueError("L1 must be smaller than or equal L2!")

        if value_L1 > 0 and value_L2 > 0:
            self._L1 = value_L1
            self._L2 = value_L2

    def __init__(self, device, max_iter, num_labels, delta, mode='min-sum'):
        """
        Args:
            device: device on which the L1/L2 parameters are created.
            max_iter: number of message-passing sweeps per forward call.
            num_labels: accepted for interface compatibility; not used here.
            delta: forwarded unchanged to the CUDA op.
            mode: only 'min-sum' is supported.

        Raises:
            ValueError: for any other ``mode``.
        """
        super(MessagePassing, self).__init__()
        self.device = device
        self.max_iter = max_iter

        if mode != 'min-sum':
            raise ValueError("Unknown message parsing mode " + mode)
        self.mode = mode

        L1 = torch.tensor(0.1, device=device)
        L2 = torch.tensor(2.5, device=device)
        self._L1 = nn.Parameter(L1, requires_grad=True)
        self._L2 = nn.Parameter(L2, requires_grad=True)

        self.softmin = TemperatureSoftmin(dim=3, init_temp=1.0)

        self.delta = delta

        self.rescaleT = None

    def projectL1L2(self):
        """Project L2 back onto the feasible set L2 >= L1."""
        self.L2.data = torch.max(self.L1.data, self.L2.data)

    def forward(self, prob_vol, edge_weights, affinities, offset, messages):
        """Run message passing on ``prob_vol`` (unpacked as N x H x W x C).

        ``edge_weights``, ``affinities`` and ``offset`` may be ``None``; they
        then default to all-ones edges, affinities built from L1/L2 and a
        zero offset, all created on ``prob_vol``'s device.

        Returns:
            (beliefs, messages): soft-min normalised beliefs and the messages
            after the last sweep.

        Raises:
            ValueError: if ``affinities`` is None and L1 or L2 is None.
            NotImplementedError: if ``self.mode`` is not 'min-sum'.
        """
        N, H, W, _ = prob_vol.shape
        device = prob_vol.device

        NUM_DIR = 4

        if edge_weights is None:
            edge_weights = torch.ones((N, NUM_DIR, H, W), device=device)

        if affinities is None:
            # Without explicit affinities both learned jump costs are needed;
            # fail here instead of passing None on to the CUDA op.
            if self._L1 is None or self._L2 is None:
                raise ValueError("L1 or L2 is None and affinities are not set!")

            # parameters are expected as follows L1-left L1-right L2
            affinities = torch.zeros((N, NUM_DIR, 3, H, W), dtype=torch.float, device=device)
            affinities[:, :, :2, :, :] = self._L1
            affinities[:, :, 2, :, :] = self._L2

        if offset is None:
            # Previously this default lived on the CPU while every other
            # default was moved to the GPU.
            offset = torch.zeros((N, NUM_DIR, H, W), device=device)

        if self.mode != 'min-sum':
            raise NotImplementedError("message parsing mode " + self.mode + " is currently not implemented!")

        # convert to cost-input
        cost = -prob_vol

        # perform message-passing iterations
        for _ in range(self.max_iter):
            torch.cuda.empty_cache()
            messages = LBPMinSumFunction.apply(cost, edge_weights, messages, self.delta, affinities, offset)

        # compute beliefs
        beliefs = messages.sum(dim=1) + cost

        # normalize output
        beliefs = self.softmin(beliefs)

        return beliefs, messages
--------------------------------------------------------------------------------
/ops/lbp_stereo/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 |
# CUDA extension holding the min-sum LBP forward/backward kernels.
LBP_EXTENSION = CUDAExtension(
    'pytorch_cuda_lbp_op',
    ['src/lbp.cpp', 'src/lbp_min_sum_kernel.cu'],
)

# Runtime requirements declared for the package.
REQUIREMENTS = [
    "numpy >= 1.15",
    "torch >= 0.4.1",
    "matplotlib >= 3.0.0",
    "scikit-image >= 0.14.1",
    "numba >= 0.42",
]

setup(
    name='pytorch-lbp-op',
    version='0.2',
    author="Patrick Knöbelreiter",
    author_email="knoebelreiter@icg.tugraz.at",
    packages=["src"],
    include_dirs=[],
    ext_modules=[LBP_EXTENSION],
    cmdclass={'build_ext': BuildExtension},
    install_requires=REQUIREMENTS,
)
29 |
--------------------------------------------------------------------------------
/ops/lbp_stereo/src/lbp.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | #include "lbp_min_sum_kernel.cuh"
5 |
// C++ interface
// AT_ASSERTM in pytorch 1.0
// Argument validation helpers for the bindings below.
//  - CHECK_CUDA:       asserts that x is a CUDA tensor.
//  - CHECK_CONTIGUOUS: asserts that x is contiguous in memory.
//  - CHECK_INPUT:      runs both checks. It expands to two statements that are
//                      already terminated by ';', which is why the call sites
//                      in this file do not add a semicolon of their own.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x);
11 |
12 |
13 | // ================================================================================================
14 | // MIN-SUM LBP
15 | // ================================================================================================
16 | std::vector lbp_forward_min_sum(at::Tensor cost,
17 | at::Tensor jump,
18 | at::Tensor edge,
19 | at::Tensor messages, unsigned short delta)
20 | {
21 | CHECK_INPUT(cost)
22 | CHECK_INPUT(jump)
23 | CHECK_INPUT(edge)
24 | CHECK_INPUT(messages)
25 |
26 | return cuda::lbp_forward_min_sum(cost, jump, edge, messages, delta);
27 | }
28 |
29 | std::vector lbp_backward_min_sum(at::Tensor cost,
30 | at::Tensor jump,
31 | at::Tensor edge,
32 | at::Tensor in_grad,
33 | at::Tensor messages,
34 | at::Tensor messages_argmin,
35 | at::Tensor message_scale)
36 | {
37 | CHECK_INPUT(cost)
38 | CHECK_INPUT(jump)
39 | CHECK_INPUT(edge)
40 | CHECK_INPUT(in_grad)
41 | CHECK_INPUT(messages)
42 | CHECK_INPUT(messages_argmin)
43 | CHECK_INPUT(message_scale)
44 |
45 | return cuda::lbp_backward_min_sum(cost, jump, edge, in_grad, messages, messages_argmin, message_scale);
46 | }
47 |
48 | // ================================================================================================
49 | // Pytorch Interfaces
50 | // ================================================================================================
// Exposes the two validated wrappers to Python. The Python side calls them as
// lbp.forward_minsum(...) and lbp.backward_minsum(...).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("forward_minsum", &lbp_forward_min_sum, "LBP forward (CUDA)");
  m.def("backward_minsum", &lbp_backward_min_sum, "LBP backward (CUDA)");
}
56 |
57 |
--------------------------------------------------------------------------------
/ops/lbp_stereo/src/lbp_min_sum_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 |
5 | namespace cuda
6 | {
7 | std::vector lbp_forward_min_sum(at::Tensor cost,
8 | at::Tensor jump,
9 | at::Tensor edge,
10 | at::Tensor messages, unsigned short delta);
11 |
12 | std::vector lbp_backward_min_sum(at::Tensor cost,
13 | at::Tensor jump,
14 | at::Tensor edge,
15 | at::Tensor in_grad,
16 | at::Tensor messages,
17 | at::Tensor messages_argmin,
18 | at::Tensor message_scale);
19 |
20 | }
--------------------------------------------------------------------------------
/ops/lbp_stereo/src/util.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 |
5 | #include "../../include/tensor.h"
6 |
// Message-passing directions. The enumerators take the values 0..3 in the
// order listed; the helpers below compare their `direction` argument against
// these names.
enum DIRECTION {LEFT, RIGHT, UP, DOWN};
8 |
// Returns the value stored for `direction` in prev_values unless that slot
// still holds the sentinel `max_float`; in that case current_value is returned.
extern __device__ __forceinline__ float getDerivativeValue(KernelData prev_values, int direction, float current_value, float max_float)
{
  const bool has_stored_value = prev_values(0, direction, 0, 0) != max_float;
  return has_stored_value ? prev_values(0, direction, 0, 0) : current_value;
}
23 |
24 |
// Returns 1 if at least one of the four direction slots of prev_values differs
// from the sentinel `max_float`, and 0 otherwise.
extern __device__ __forceinline__ float computeCrossGradient(KernelData prev_values, float max_float)
{
  for(int dir = 0; dir < 4; ++dir)
  {
    if(prev_values(0, dir, 0, 0) != max_float)
    {
      return 1.0f;
    }
  }

  return 0.0f;
}
39 |
// Reads one entry of the gradient accumulator. Vertical sweeps (UP/DOWN) are
// addressed by x, horizontal sweeps (LEFT/RIGHT) by y; any other direction
// yields 0.
extern __device__ __forceinline__ float getGradientAcc(KernelData gradient_accumulation, int direction, int n, int y, int x, int c, int grad_acc_idx)
{
  switch(direction)
  {
    case UP:
    case DOWN:
      return gradient_accumulation(n, x, grad_acc_idx, c);
    case LEFT:
    case RIGHT:
      return gradient_accumulation(n, y, grad_acc_idx, c);
    default:
      return 0.0f;
  }
}
56 |
// Atomically adds `value` to one entry of the gradient accumulator. Vertical
// sweeps (UP/DOWN) are addressed by x, horizontal sweeps (LEFT/RIGHT) by y;
// any other direction is ignored.
extern __device__ __forceinline__ void updateGradientAcc(KernelData gradient_accumulation, float value, int direction, int n, int y, int x, int c, int grad_acc_idx)
{
  switch(direction)
  {
    case UP:
    case DOWN:
      atomicAdd(&gradient_accumulation(n, x, grad_acc_idx, c), value);
      break;
    case LEFT:
    case RIGHT:
      atomicAdd(&gradient_accumulation(n, y, grad_acc_idx, c), value);
      break;
    default:
      break;
  }
}
72 |
// Overwrites one entry of the gradient accumulator with `value`. Vertical
// sweeps (UP/DOWN) are addressed by x, horizontal sweeps (LEFT/RIGHT) by y;
// any other direction is ignored.
extern __device__ __forceinline__ void setGradientAcc(KernelData gradient_accumulation, float value, int direction, int n, int y, int x, int c, int grad_acc_idx)
{
  const bool vertical = (direction == UP || direction == DOWN);
  const bool horizontal = (direction == LEFT || direction == RIGHT);

  if(vertical)
  {
    gradient_accumulation(n, x, grad_acc_idx, c) = value;
  }
  else if(horizontal)
  {
    gradient_accumulation(n, y, grad_acc_idx, c) = value;
  }
}
86 |
// Looks up the edge weight at (n, y, x): channel 1 for vertical sweeps
// (UP/DOWN), channel 0 for every other direction.
extern __device__ __forceinline__ float getEdgeWeight(KernelData edges, int n, int y, int x, int direction)
{
  const int channel = (direction == UP || direction == DOWN) ? 1 : 0;
  return edges(n, channel, y, x);
}
102 |
// Maps a signed label difference `pos` to a slot of the per-pixel cost vector:
// the centre slot is L_vec_size / 2 and the result is clamped to
// [1, L_vec_size - 1], so slot 0 is never returned.
extern __device__ __forceinline__ int getLinearIdx(int pos, int L_vec_size)
{
  const int centred = L_vec_size / 2 + pos;
  return min(L_vec_size - 1, max(1, centred));
}
110 |
// Returns the pairwise cost of moving from label s to label t for the given
// pixel and direction. The label difference selects the slot of the per-pixel
// cost vector through getLinearIdx.
extern __device__ __forceinline__ float getJumpCost(int t, int s, KernelData5 jump_cost, int n, int direction, int y, int x)
{
  const int num_L = jump_cost.size2;

  // Keep the difference integral; it was previously stored in a float and
  // converted straight back to int by the call below.
  const int label_diff = t - s;

  const int vec_idx = getLinearIdx(label_diff, num_L);
  return jump_cost(n, direction, vec_idx, y, x);
}
120 |
// Accumulates three gradient contributions into the pairwise-cost gradient slot
// selected by the label difference t - s. `jump_cost` is only consulted for the
// length of the cost vector; `direction` is not used here.
extern __device__ __forceinline__ void addGradientJump(int t, int s, KernelData5 jump_cost, int n, int direction, int y, int x, KernelData5 gradient_pairwise, int grad_xy_idx, float additive_hor, float additive_up, float additive_down)
{
  const int num_L = jump_cost.size2;

  // Keep the difference integral; it was previously stored in a float and
  // converted straight back to int by the call below.
  const int label_diff = t - s;

  const int vec_idx = getLinearIdx(label_diff, num_L);

  // Three separate atomic updates, in the original order.
  atomicAdd(&gradient_pairwise(n, grad_xy_idx, vec_idx, y, x), additive_hor);
  atomicAdd(&gradient_pairwise(n, grad_xy_idx, vec_idx, y, x), additive_up);
  atomicAdd(&gradient_pairwise(n, grad_xy_idx, vec_idx, y, x), additive_down);
}
133 |
--------------------------------------------------------------------------------
/ops/sad/src/stereo_sad.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include "stereo_sad_kernel.cuh"
3 |
// C++ interface
// AT_ASSERTM in pytorch 1.0
// Argument validation helpers for the bindings below.
//  - CHECK_CUDA:       asserts that x is a CUDA tensor.
//  - CHECK_CONTIGUOUS: asserts that x is contiguous in memory.
//  - CHECK_INPUT:      runs both checks. It expands to two statements that are
//                      already terminated by ';', which is why the call sites
//                      in this file do not add a semicolon of their own.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x);
9 |
10 | at::Tensor stereo_sad_forward(at::Tensor f0, at::Tensor f1, int min_disp, int max_disp, float step)
11 | {
12 | CHECK_INPUT(f0)
13 | CHECK_INPUT(f1)
14 | return cuda::stereo_sad_forward(f0, f1, min_disp, max_disp, step);
15 | }
16 |
17 | std::vector stereo_sad_backward(at::Tensor f0, at::Tensor f1, int min_disp, int max_disp,
18 | at::Tensor in_grad)
19 | {
20 | CHECK_INPUT(f0)
21 | CHECK_INPUT(f1)
22 | CHECK_INPUT(in_grad)
23 | return cuda::stereo_sad_backward(f0, f1, min_disp, max_disp, in_grad);
24 | }
25 |
// Exposes the two validated wrappers to Python. The Python side calls them as
// pytorch_cuda_stereo_sad_op.forward(...) and .backward(...).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("forward", &stereo_sad_forward, "SAD Matching forward (CUDA)");
  m.def("backward", &stereo_sad_backward, "SAD Matching backward (CUDA)");
}
--------------------------------------------------------------------------------
/ops/sad/src/stereo_sad_kernel.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include "stereo_sad_kernel.cuh"
4 | #include "tensor.h"
5 | #include "error_util.h"
6 |
// Evaluates, at position x, the straight line through (x0, y0) and (x1, y1).
// x0 and x1 must differ.
__device__ float dLinearInterpolation1D(float x, float x0, float x1, float y0, float y1)
{
  const float rise = y1 - y0;
  const float run = x1 - x0;
  // Same evaluation order as before: (rise * dx) / run.
  return y0 + rise * (x - x0) / run;
}
12 |
13 | // ============================================================================
14 | // CUDA KERNELS
15 | // ============================================================================
// One thread per (x, y, disparity index): writes the sum of absolute feature
// differences between f0 at x and f1 at x - disparity into
// output(n, y, x, d_idx). Entries whose matching position lies outside the
// image are left untouched, so they keep the value the caller initialised the
// volume with. Only batch index 0 is processed.
__global__ void stereo_sad_cuda_forward_kernel(
    KernelData f0,
    KernelData f1,
    int min_disp,
    int max_disp,
    float step,
    KernelData output
    )
{
  // Signed coordinates: with unsigned x the old test `x - d < 0` could never
  // be true and the bounds check only worked through wrap-around.
  const int x = blockIdx.x * blockDim.x + threadIdx.x;
  const int y = blockIdx.y * blockDim.y + threadIdx.y;
  const int d_idx = blockIdx.z * blockDim.z + threadIdx.z;

  // check inside image
  const int n = 0;
  if(x >= f0.size3 || y >= f0.size2 || d_idx >= output.size3)
    return;

  const int width = f0.size3;

  // Keep the disparity fractional. It used to be truncated to int, which
  // turned the interpolation branch below into a no-op for step != 1.
  const float disp = d_idx * step + min_disp;
  const float x_pos = x - disp;

  // skip outside pixels
  if(x_pos < 0.0f || x_pos > width - 1)
    return;

  const int floor_x = (int) floorf(x_pos);
  const int ceil_x = floor_x + 1;
  // At the right border the upper neighbour has weight 0; clamp the column
  // that is read so the access stays inside the row.
  const int ceil_x_read = min(ceil_x, width - 1);

  float sad = 0.0f;
  for(int c = 0; c < f0.size1; ++c)
  {
    float f1_c = 0.0f;
    if(step == 1.0f)
    {
      // integral disparity: x_pos is an exact pixel position
      f1_c = f1(n, c, y, floor_x);
    }
    else
    {
      f1_c = dLinearInterpolation1D(x_pos, floor_x, ceil_x,
                                    f1(n, c, y, floor_x), f1(n, c, y, ceil_x_read));
    }

    sad += fabsf(f0(n, c, y, x) - f1_c);
  }

  // write result back to global memory
  output(n, y, x, d_idx) = sad;
}
63 |
// One thread per (x, y, channel): accumulates over all integer disparities in
// [min_disp, max_disp] the gradient of the SAD cost with respect to
// f0(n, c, y, x) and f1(n, c, y, x). Only batch index 0 is processed.
__global__ void stereo_sad_cuda_backward_kernel(
    KernelData f0,
    KernelData f1,
    int min_disp,
    int max_disp,
    KernelData in_grad,
    KernelData df0,
    KernelData df1
    )
{
  // Signed coordinates: with unsigned x the old tests `x - d >= 0` and
  // `x + d >= 0` were always true and out-of-range positions were only
  // rejected through wrap-around.
  const int x = blockIdx.x * blockDim.x + threadIdx.x;
  const int y = blockIdx.y * blockDim.y + threadIdx.y;
  const int c = blockIdx.z * blockDim.z + threadIdx.z;

  const float eps = 1e-15;

  // check inside image
  const int n = 0;
  if(x >= f0.size3 || y >= f0.size2 || c >= f0.size1)
    return;

  const int width = f0.size3;

  float grad_f0 = 0.0f;
  float grad_f1 = 0.0f;
  for(int d = min_disp; d <= max_disp; ++d)
  {
    const int idx = d - min_disp;

    // f0(x) was compared against f1(x - d)
    const int x_match = x - d;
    if(x_match >= 0 && x_match < width)
    {
      const float diff = f0(n, c, y, x) - f1(n, c, y, x_match);
      if(fabsf(diff) > eps) // gradient is zero if diff is zero!
        grad_f0 += (diff / fabsf(diff)) * in_grad(n, y, x, idx);
    }

    // f1(x) was used by the cost entry of f0(x + d)
    const int x_ref = x + d;
    if(x_ref >= 0 && x_ref < width)
    {
      const float diff1 = f0(n, c, y, x_ref) - f1(n, c, y, x);
      if(fabsf(diff1) > eps)
        grad_f1 -= (diff1 / fabsf(diff1)) * in_grad(n, y, x_ref, idx);
    }
  }

  df0(n, c, y, x) = grad_f0;
  df1(n, c, y, x) = grad_f1;
}
109 |
110 |
111 | // ============================================================================
112 | // CPP KERNEL CALLS
113 | // ============================================================================
// Host-side launchers for the SAD kernels. The kernels in this file only
// process batch index 0.
namespace cuda
{
    // Allocates an N x H x W x D cost volume pre-filled with 40 and lets the
    // forward kernel overwrite every entry whose matching position is valid.
    at::Tensor stereo_sad_forward(at::Tensor f0, at::Tensor f1, int min_disp, int max_disp, float step)
    {
        int N = f0.size(0);
        int H = f0.size(2);
        int W = f0.size(3);
        int D = (max_disp - min_disp + 1) / step;

        auto cost_vol = at::ones({N, H, W, D}, f0.options()) * 40;

        // parallelise over H x W x D (integer ceil-division per axis)
        const dim3 blockSize(8, 8, 4);
        const dim3 numBlocks((W + blockSize.x - 1) / blockSize.x,
                             (H + blockSize.y - 1) / blockSize.y,
                             (D + blockSize.z - 1) / blockSize.z);

        stereo_sad_cuda_forward_kernel<<<numBlocks, blockSize>>>(f0, f1, min_disp, max_disp, step, cost_vol);
        cudaSafeCall(cudaGetLastError());
        return cost_vol;
    }

    // Returns {df0, df1}, the gradients of the cost volume with respect to
    // both feature tensors.
    std::vector<at::Tensor> stereo_sad_backward(at::Tensor f0, at::Tensor f1,
                                                int min_disp, int max_disp,
                                                at::Tensor in_grad)
    {
        int C = f0.size(1);
        int H = f0.size(2);
        int W = f0.size(3);

        auto df0 = at::zeros_like(f0);
        auto df1 = at::zeros_like(f1);

        // parallelise over H x W x C. The backward kernel uses the z index as
        // the channel and loops over the disparities itself, so the grid has
        // to cover all C channels; sizing it by the number of disparities left
        // channels without a gradient whenever C was larger.
        const dim3 blockSize(8, 8, 4);
        const dim3 numBlocks((W + blockSize.x - 1) / blockSize.x,
                             (H + blockSize.y - 1) / blockSize.y,
                             (C + blockSize.z - 1) / blockSize.z);

        stereo_sad_cuda_backward_kernel<<<numBlocks, blockSize>>>(f0, f1, min_disp, max_disp, in_grad,
                                                                  df0, df1);
        cudaSafeCall(cudaGetLastError());

        std::vector<at::Tensor> gradients;
        gradients.push_back(df0);
        gradients.push_back(df1);

        return gradients;
    }
}
--------------------------------------------------------------------------------
/ops/sad/src/stereo_sad_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
4 | namespace cuda
5 | {
6 | at::Tensor stereo_sad_forward(at::Tensor f0, at::Tensor f1, int min_disp, int max_disp, float step);
7 | std::vector stereo_sad_backward(at::Tensor f0, at::Tensor f1, int min_disp,
8 | int max_disp, at::Tensor in_grad);
9 | }
--------------------------------------------------------------------------------
/ops/sad/stereo_sad.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import pytorch_cuda_stereo_sad_op
4 |
class StereoMatchingSadFunction(torch.autograd.Function):
    """Autograd wrapper around the CUDA SAD stereo matching op."""

    @staticmethod
    def forward(ctx, f0, f1, min_disp, max_disp, step=1.0):
        """Return the SAD cost volume computed by the CUDA extension.

        ``min_disp``, ``max_disp`` and ``step`` are plain Python numbers, so
        they are kept as attributes on ``ctx``; only the two feature tensors
        go through ``save_for_backward``.
        """
        ctx.save_for_backward(f0, f1)
        ctx.min_disp = min_disp
        ctx.max_disp = max_disp
        ctx.step = step
        return pytorch_cuda_stereo_sad_op.forward(f0, f1, min_disp, max_disp, step)

    @staticmethod
    def backward(ctx, in_grad):
        """Return one gradient per forward input (None for the scalars).

        Raises:
            ValueError: if the forward pass used ``step != 1``.
        """
        f0, f1 = ctx.saved_tensors
        if ctx.step != 1.0:
            raise ValueError("Error: Backward for step != 1 is not implemented!")

        # The C++ binding asserts that in_grad is contiguous.
        df0, df1 = pytorch_cuda_stereo_sad_op.backward(
            f0, f1, int(ctx.min_disp), int(ctx.max_disp), in_grad.contiguous())

        # forward takes five inputs (f0, f1, min_disp, max_disp, step), so five
        # gradients are returned; the previous four-tuple was one short
        # whenever `step` was passed explicitly.
        return df0, df1, None, None, None
20 |
--------------------------------------------------------------------------------
/ops/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 |
# CUDA extensions built by this package: SAD stereo matching and the flow
# message-passing SAD op.
STEREO_SAD_EXTENSION = CUDAExtension(
    'pytorch_cuda_stereo_sad_op',
    ['sad/src/stereo_sad.cpp', 'sad/src/stereo_sad_kernel.cu'],
)
FLOW_MP_SAD_EXTENSION = CUDAExtension(
    'pytorch_cuda_flow_mp_sad_op',
    ['flow_mp_sad/src/flow_mp_sad.cpp', 'flow_mp_sad/src/flow_mp_sad_kernel.cu'],
)

# Runtime requirements declared for the package.
REQUIREMENTS = [
    "numpy >= 1.15",
    "torch >= 0.4.1",
    "scikit-image >= 0.14.1",
]

setup(
    name='matching',
    version='1.0',
    author="Patrick Knöbelreiter",
    author_email="knoebelreiter@icg.tugraz.at",
    packages=["sad"],
    include_dirs=['include/'],
    ext_modules=[STEREO_SAD_EXTENSION, FLOW_MP_SAD_EXTENSION],
    cmdclass={'build_ext': BuildExtension},
    install_requires=REQUIREMENTS,
)
31 |
--------------------------------------------------------------------------------
/run_flow.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Builds the LBP CUDA op and runs the two optical-flow demo configurations on
# one image pair.
IM0="data/frame_0019.png"
IM1="data/frame_0020.png"

# Build in a subshell: the working directory of this script never changes, so
# a failing `cd` can no longer send the following commands to the wrong place.
(cd ops/lbp_stereo && python setup.py install)

H_DIR="data/params/flow/BP+MS (H)"
REF_DIR="data/params/flow/BP+MS+Ref (H)"

python main_flow.py --model bp+ms+h \
    --checkpoint-unary "$H_DIR/unary_best.cpt" \
    --checkpoint-matching "$H_DIR/matching_lvl0_best.cpt" "$H_DIR/matching_lvl1_best.cpt" "$H_DIR/matching_lvl2_best.cpt" \
    --checkpoint-affinity "$H_DIR/affinity_best.cpt" \
    --checkpoint-crf "$H_DIR/crf0_lvl0_best.cpt" \
    --checkpoint-crf "$H_DIR/crf0_lvl1_best.cpt" \
    --checkpoint-crf "$H_DIR/crf0_lvl2_best.cpt" \
    --multi-level-output --bp-inference sub-exp --with-bn \
    --im0 "$IM0" --im1 "$IM1"

python main_flow.py --model bp+ms+ref+h \
    --checkpoint-unary "$REF_DIR/unary_best.cpt" \
    --checkpoint-matching "$REF_DIR/matching_lvl0_best.cpt" "$REF_DIR/matching_lvl1_best.cpt" "$REF_DIR/matching_lvl2_best.cpt" \
    --checkpoint-affinity "$REF_DIR/affinity_best.cpt" \
    --checkpoint-crf "$REF_DIR/crf0_lvl0_best.cpt" \
    --checkpoint-crf "$REF_DIR/crf0_lvl1_best.cpt" \
    --checkpoint-crf "$REF_DIR/crf0_lvl2_best.cpt" \
    --checkpoint-refinement "$REF_DIR/refinement_best.cpt" \
    --with-bn --bp-inference sub-exp --output-level-offset 0 --multi-level-output \
    --im0 "$IM0" --im1 "$IM1"
--------------------------------------------------------------------------------
/run_semantic_global.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Builds the global-pairwise semantic LBP op and runs the semantic demo with
# edge weights. Every step only runs if the previous one succeeded.
IMG="data/frankfurt_val.png"
cd ops/lbp_semantic_pw &&
python setup.py install &&
cd ../../ &&
python main_semantic.py --img="$IMG" --checkpoint-semantic data/params/semantic/global_model.cpt --checkpoint-esp-net dependencies/ESPNet/pretrained --pairwise-type global --with-edges
--------------------------------------------------------------------------------
/run_semantic_pixel.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Builds the pixel-wise-pairwise semantic LBP op and runs the semantic demo.
# Every step only runs if the previous one succeeded.
IMG="data/frankfurt_val.png"
cd ops/lbp_semantic_pw_pixel &&
python setup.py install &&
cd ../../ &&
python main_semantic.py --img="$IMG" --checkpoint-semantic data/params/semantic/pixel_model.cpt --checkpoint-esp-net dependencies/ESPNet/pretrained --pairwise-type pixel
--------------------------------------------------------------------------------
/run_stereo_kitti.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Builds the LBP CUDA op and runs the stereo demo with the Kitti checkpoints.

# Build in a subshell: the working directory of this script never changes, so
# a failing `cd` can no longer send the following command to the wrong place.
(cd ops/lbp_stereo && python setup.py install)

KITTI_DIR="data/params/stereo/Kitti"

# NOTE(review): the unary checkpoint is taken from "BP+MS (H)" while every
# other checkpoint comes from "Kitti" - confirm that this mix is intended.
python main_stereo.py --model bp+ms+h \
    --checkpoint-unary "data/params/stereo/BP+MS (H)/unary_best.cpt" \
    --checkpoint-matching "$KITTI_DIR/matching_lvl0_best.cpt" "$KITTI_DIR/matching_lvl1_best.cpt" "$KITTI_DIR/matching_lvl2_best.cpt" \
    --checkpoint-affinity "$KITTI_DIR/affinity_best.cpt" \
    --checkpoint-crf "$KITTI_DIR/crf0_lvl0_best.cpt" \
    --checkpoint-crf "$KITTI_DIR/crf0_lvl1_best.cpt" \
    --checkpoint-crf "$KITTI_DIR/crf0_lvl2_best.cpt" \
    --with-bn --bp-inference sub-exp --input-level-offset 0 --output-level-offset 0 --multi-level-output \
    --im0 "data/000010_10_left.png" --im1 "data/000010_10_right.png"
--------------------------------------------------------------------------------
/run_stereo_mb.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Builds the LBP CUDA op and runs the stereo demo with the MB checkpoints.

# Build in a subshell: the working directory of this script never changes, so
# a failing `cd` can no longer send the following command to the wrong place.
(cd ops/lbp_stereo && python setup.py install)

MB_DIR="data/params/stereo/MB"

python main_stereo.py --model bp+ms+h \
    --checkpoint-unary "$MB_DIR/unary_best.cpt" \
    --checkpoint-matching "$MB_DIR/matching_lvl0_best.cpt" "$MB_DIR/matching_lvl1_best.cpt" "$MB_DIR/matching_lvl2_best.cpt" \
    --checkpoint-affinity "$MB_DIR/affinity_best.cpt" \
    --checkpoint-crf "$MB_DIR/crf0_lvl0_best.cpt" \
    --checkpoint-crf "$MB_DIR/crf0_lvl1_best.cpt" \
    --checkpoint-crf "$MB_DIR/crf0_lvl2_best.cpt" \
    --multi-level-output --bp-inference sub-exp --with-bn --input-level-offset 1 --output-level-offset 1 \
    --im0 "data/im0.png" --im1 "data/im1.png"
--------------------------------------------------------------------------------
/run_stereo_sf.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Builds the LBP CUDA op and runs the four stereo demo configurations
# (wta, bp+ms, bp+ms+h, bp+ms+ref+h) on one image pair.
IM0="data/sf_0006_left.png"
IM1="data/sf_0006_right.png"

# Build in a subshell: the working directory of this script never changes, so
# a failing `cd` can no longer send the following commands to the wrong place.
(cd ops/lbp_stereo && python setup.py install)

WTA_DIR="data/params/stereo/WTA (NLL)"
NLL_DIR="data/params/stereo/BP+MS (NLL)"
H_DIR="data/params/stereo/BP+MS (H)"
REF_DIR="data/params/stereo/BP+MS+Ref (H)"

python main_stereo.py --model wta \
    --checkpoint-unary "$WTA_DIR/unary_best.cpt" \
    --checkpoint-matching "$WTA_DIR/matching_best.cpt" \
    --with-bn --with-output-bn \
    --im0 "$IM0" --im1 "$IM1"

# NOTE(review): unary, matching and affinity checkpoints come from
# "BP+MS (NLL)" but the CRF checkpoints from "BP+MS (H)" - confirm that this
# mix is intended.
python main_stereo.py --model bp+ms \
    --checkpoint-unary "$NLL_DIR/unary_best.cpt" \
    --checkpoint-matching "$NLL_DIR/matching_lvl0_best.cpt" "$NLL_DIR/matching_lvl1_best.cpt" "$NLL_DIR/matching_lvl2_best.cpt" \
    --checkpoint-affinity "$NLL_DIR/affinity_best.cpt" \
    --checkpoint-crf "$H_DIR/crf0_lvl0_best.cpt" \
    --checkpoint-crf "$H_DIR/crf0_lvl1_best.cpt" \
    --checkpoint-crf "$H_DIR/crf0_lvl2_best.cpt" \
    --with-bn --bp-inference wta --multi-level-output \
    --im0 "$IM0" --im1 "$IM1"

python main_stereo.py --model bp+ms+h \
    --checkpoint-unary "$H_DIR/unary_best.cpt" \
    --checkpoint-matching "$H_DIR/matching_lvl0_best.cpt" "$H_DIR/matching_lvl1_best.cpt" "$H_DIR/matching_lvl2_best.cpt" \
    --checkpoint-affinity "$H_DIR/affinity_best.cpt" \
    --checkpoint-crf "$H_DIR/crf0_lvl0_best.cpt" \
    --checkpoint-crf "$H_DIR/crf0_lvl1_best.cpt" \
    --checkpoint-crf "$H_DIR/crf0_lvl2_best.cpt" \
    --multi-level-output --bp-inference sub-exp --with-bn \
    --im0 "$IM0" --im1 "$IM1"

python main_stereo.py --model bp+ms+ref+h \
    --checkpoint-unary "$REF_DIR/unary_best.cpt" \
    --checkpoint-matching "$REF_DIR/matching_lvl0_best.cpt" "$REF_DIR/matching_lvl1_best.cpt" "$REF_DIR/matching_lvl2_best.cpt" \
    --checkpoint-affinity "$REF_DIR/affinity_best.cpt" \
    --checkpoint-crf "$REF_DIR/crf0_lvl0_best.cpt" \
    --checkpoint-crf "$REF_DIR/crf0_lvl1_best.cpt" \
    --checkpoint-crf "$REF_DIR/crf0_lvl2_best.cpt" \
    --checkpoint-refinement "$REF_DIR/refinement_best.cpt" \
    --with-bn --bp-inference sub-exp --output-level-offset 0 --multi-level-output \
    --im0 "$IM0" --im1 "$IM1"
--------------------------------------------------------------------------------
/semantic_segmentation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import os.path as osp
5 | from networks import SubNetwork
6 | from networks import PWNet, EdgeNet
7 |
8 | from ops.lbp_semantic_pw.bp_op_cuda import BP as BP_PW
9 | from ops.lbp_semantic_pw_pixel.bp_op_cuda import BP as BP_PW_PIXEL
10 |
11 | from dependencies.ESPNet.test.Model import ESPNet
12 |
13 | import os
14 |
15 | from corenet import TemperatureSoftmax
16 |
17 | from time import time
18 |
class SemanticNet(SubNetwork):
    """Semantic segmentation network: ESPNet scores refined by one BP/CRF layer.

    The ESPNet output is turned into a per-pixel label distribution with a
    temperature softmax and then passed through a single belief-propagation
    layer. Depending on ``args.pairwise_type`` the pairwise costs are either
    one learned matrix shared by all pixels (``"global"``, optionally combined
    with edge weights from an ``EdgeNet``) or predicted by a ``PWNet``
    (``"pixel"``).
    """

    def __init__(self, device, args):
        """Build the sub-networks and optionally load checkpoints.

        Args:
            device: device handed to the sub-networks and BP layers.
            args: configuration namespace; this constructor reads
                ``num_labels``, ``pairwise_type``, ``with_edges``,
                ``checkpoint_esp_net``, ``checkpoint_semantic`` and
                ``with_esp``.
        """
        super(SemanticNet, self).__init__(args, device)

        self.num_labels = args.num_labels

        # wall-clock duration (seconds) of every forward() call
        self.forward_timings = []

        # ESPNet hyper-parameters; they also select the pretrained weight file
        esp_p, esp_q = 2, 8
        self._esp_net = ESPNet(self.num_labels, esp_p, esp_q, None)

        self._softmax_input = TemperatureSoftmax(dim=3, init_temp=1.0)

        # Always define the attribute so that edge_net(), edge_net_params()
        # and extract_edges() work in configurations without an edge network.
        if (self.args.pairwise_type == "global") and args.with_edges:
            self._edge_net = EdgeNet(device, args)
        else:
            self._edge_net = None

        if self.args.checkpoint_esp_net:
            esp_weight_file = os.path.join(
                self.args.checkpoint_esp_net,
                "decoder/espnet_p_" + str(esp_p) + "_q_" + str(esp_q) + ".pth")
            self._esp_net.load_state_dict(torch.load(esp_weight_file))

        max_iter = 1

        if args.pairwise_type == "global":
            self._crf = [BP_PW(device, args, max_iter, self.num_labels, self.num_labels,
                               mode_inference='wta',
                               mode_message_passing='min-sum', layer_idx=0)]

            # learnable pairwise matrix, initialised with zeros
            self.pw_mat_init = torch.zeros((1, 2, self.num_labels, self.num_labels),
                                           dtype=torch.float).cuda()
            self._pw_mat = nn.Parameter(self.pw_mat_init, requires_grad=True)
            self._pw_net = None

        elif args.pairwise_type == "pixel":
            self._crf = [BP_PW_PIXEL(device, args, max_iter, self.num_labels, self.num_labels,
                                     mode_inference='wta',
                                     mode_message_passing='min-sum', layer_idx=0)]

            self._pw_net = PWNet(device, args)
            self._pw_mat = None

        # NOTE(review): any other pairwise_type leaves _crf/_pw_net/_pw_mat
        # undefined, so forward() will fail later - confirm that only
        # "global" and "pixel" are ever passed.

        if args.checkpoint_semantic is not None:
            print("Loading semantic checkpoint!")
            self.load_state_dict(torch.load(args.checkpoint_semantic))

        if not self.args.with_esp:
            self._esp_net.eval()

        self._esp_net.to(device)

    def forward(self, ipt):
        """Run ESPNet followed by the CRF on ``ipt``.

        Returns:
            Tuple ``(res, beliefs, esp_only_res)``: the two outputs of
            ``optimize_crf`` plus the arg-max labelling of the softmaxed
            ESPNet scores of the first batch element, shaped ``(1, 1, H, W)``.
        """
        t0_fwd = time()

        res_esp = self._esp_net.forward(ipt)

        # channels-last layout so that the softmax (dim=3) runs over labels
        res_esp = res_esp.permute((0, 2, 3, 1))
        res_esp = res_esp.contiguous()

        res_esp = self._softmax_input(res_esp)

        esp_only_res = res_esp[0].max(2)[1].unsqueeze(0).unsqueeze(0)

        N, H, W, C = res_esp.shape

        if (self.args.pairwise_type == "global") and self.args.with_edges:
            weights = self.extract_edges(ipt)[0]
        else:
            # uniform weights when no edge network is used
            weights = torch.ones((N, 2, H, W), dtype=torch.float).cuda()

        res, beliefs = self.optimize_crf(ipt, res_esp, weights, None, None)

        # wait for the GPU so that the measured time covers the whole pass
        torch.cuda.synchronize()
        self.forward_timings.append(time() - t0_fwd)

        torch.cuda.empty_cache()

        return res, beliefs, esp_only_res

    def esp_net(self):
        """Return the ESPNet sub-network."""
        return self._esp_net

    def esp_net_params(self, requires_grad=None):
        """Return the ESPNet parameter list, or ``[]`` if there is no ESPNet."""
        if self._esp_net is not None:
            return self._esp_net.parameter_list(requires_grad)
        return []

    def sem_params(self, requires_grad=None):
        """Return this network's own ``parameter_list(requires_grad)``."""
        return self.parameter_list(requires_grad)

    def pw_mat(self):
        """Return the global pairwise matrix (``None`` for pixel-wise pairwise costs)."""
        return self._pw_mat

    def edge_net(self):
        """Return the edge network, or ``None`` if none was built."""
        return self._edge_net

    def edge_net_params(self, requires_grad=None):
        """Return the edge network's parameter list, or ``[]`` without an edge network."""
        if self._edge_net is not None:
            return self._edge_net.parameter_list(requires_grad)
        return []

    def extract_edges(self, ipt):
        """Run the edge network on ``ipt``; return ``None`` if there is none."""
        # Test the attribute, not the ``edge_net`` accessor: a bound method is
        # always truthy, which made the ``None`` fallback unreachable.
        if self._edge_net is not None:
            return self._edge_net.forward(ipt)
        return None

    def optimize_crf(self, ipt, prob_vol, weights, affinities, offsets):
        """Pass ``prob_vol`` through every BP layer in ``self._crf``.

        ``affinities`` and ``offsets`` are accepted but not used here.

        Returns:
            Tuple ``(disps, prob_vol)`` produced by the last BP layer.
        """
        # iterate over all bp "layers"
        for idx, crf in enumerate(self._crf):

            prob_vol = prob_vol.contiguous()

            weights_input = crf.adjust_input_weights(weights, idx)

            if self._pw_net is not None:
                # pairwise costs predicted from the input image
                pw_net_jump = self._pw_net.forward(ipt)[0]

                _, _, H, W = pw_net_jump.shape

                # NOTE(review): this view drops the batch axis, so it
                # presumably expects a batch of one - confirm.
                pw_net_jump = pw_net_jump.view((2, self.num_labels, self.num_labels, H, W))
                pw_net_jump = pw_net_jump.permute((0, 3, 4, 1, 2)).contiguous()

                disps, prob_vol = crf.forward(prob_vol, weights_input, pw_net_jump)
            elif self._pw_mat is not None:
                disps, prob_vol = crf.forward(prob_vol, weights_input, self._pw_mat)

        return disps, prob_vol
144 |
--------------------------------------------------------------------------------
/stereo.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from networks import FeatureNet, AffinityNet, RefinementNet
7 | from matching import StereoMatchingSad
8 | from ops.lbp_stereo.bp_op_cuda import BP
9 | from ops.lbp_stereo.inference_op import Inference
10 |
11 | from corenet import LrCheck, LrDistance, CvConfidence
12 | from corenet import PadUnpad, Pad, Unpad
13 |
14 | import numpy as np
15 |
class StereoMethod(nn.Module):
    """Base class of the stereo pipelines.

    Holds the optional building blocks (feature net, per-level matching,
    affinity net, BP/CRF layers, refinement net) and wires them together in
    ``forward``. Subclasses fill in the ``_feature_net``, ``_matching``,
    ``_affinity_net``, ``_crf`` and ``_refinement_net`` attributes.
    """

    def __init__(self, device, args):
        """Create the shared helpers; all optional components start empty.

        Args:
            device: device the helper modules are moved to and inputs are sent to.
            args: configuration namespace; ``min_disp`` and ``max_disp`` are
                read here, further fields are read in ``forward``.
        """
        nn.Module.__init__(self)
        self.args = args

        self._feature_net = None
        self._matching = []
        self._affinity_net = None
        self._refinement_net = None
        self._crf = []  # list of bp layers
        # Backs the ``gc_net`` property, which otherwise raised AttributeError
        # because nothing in this class ever assigned it.
        self._gc_net = None

        max_dist = 3.0
        self.lr_check = LrCheck(device, max_dist).to(device)
        self.cv_conf = CvConfidence(device).to(device)
        self.lr_dist = LrDistance(device).to(device)

        self._min_disp = args.min_disp
        self._max_disp = args.max_disp

        # created lazily in forward()
        self.pad = None
        self.unpad = None

        self.device = device

        self.logger = logging.getLogger("StereoMethod")

    def forward(self, I0_pyramid, I1_pyramid, offsets_orig=None, edges_orig=None, beliefs_in=None,
                min_disp=None, max_disp=None, step=None):
        """Estimate disparities for an image pair.

        Args:
            I0_pyramid, I1_pyramid: indexable image pyramids; the entry at
                ``args.input_level_offset`` is used as network input, the
                full pyramids are handed to the refinement net.
            offsets_orig, edges_orig, min_disp: accepted but not used here.
            beliefs_in: accepted but overwritten before use (see note below).
            max_disp: if given, overrides ``max_disp`` of every matching level.
            step: if given, overrides ``step`` of every matching level.

        Returns:
            dict with the single key ``'disps0'`` holding a list of
            disparity maps: the arg-max of the matching volumes for the
            ``'wta'`` model, the CRF output otherwise, or the refined
            disparities when a refinement net is present.
        """
        # necessary for evaluation
        if self.pad is None:
            self.pad = Pad(self.feature_net.net.divisor, self.args.pad)
        if self.unpad is None:
            self.unpad = Unpad()

        res_dict = {'disps0': None}

        I0_in = I0_pyramid[self.args.input_level_offset].to(self.device)
        I1_in = I1_pyramid[self.args.input_level_offset].to(self.device)

        # pad input for multi-scale (for evaluation); stay on the configured
        # device instead of whatever the current CUDA device happens to be
        I0_in = self.pad.forward(I0_in).to(self.device)
        I1_in = self.pad.forward(I1_in).to(self.device)

        f0_pyramid = self.extract_features(I0_in)
        f1_pyramid = self.extract_features(I1_in)

        if max_disp is not None:
            for matching_lvl, m in enumerate(self.matching):
                m.max_disp = ((max_disp + 1) // 2**matching_lvl) - 1
        if step is not None:
            for m in self.matching:
                m.step = step

        # multi-scale-matching
        prob_vol_pyramid = self.match(f0_pyramid, f1_pyramid)

        # winner-takes-all disparities straight from the matching volumes
        res_dict['disps0'] = [
            torch.argmax(pv0, dim=-1, keepdim=True).permute(0, 3, 1, 2)
            for pv0 in prob_vol_pyramid
        ]

        if self.args.model == 'wta':
            return res_dict

        affinity_pyramid = None
        if self.affinity_net:
            affinity_pyramid = self.extract_affinities(I0_in)
            for lvl, affinity_lvl in enumerate(affinity_pyramid):
                _, _, h, w = affinity_lvl.shape
                # NOTE(review): the fixed (2, 5, h, w) view presumably
                # expects a batch of one with 10 channels - confirm.
                affinity_pyramid[lvl] = affinity_lvl.view((2, 5, h, w)).unsqueeze(0)

        crf_disps_pyramid = []
        beliefs_pyramid = []
        # NOTE(review): a caller-supplied ``beliefs_in`` is discarded here;
        # kept as in the original - confirm this is intended.
        beliefs_in = None
        # coarse-to-fine: start at the last pyramid level, finish at level 0
        for lvl in reversed(range(len(prob_vol_pyramid))):
            pv_lvl = prob_vol_pyramid[lvl]
            m = self.matching[lvl]

            affinity = None
            if affinity_pyramid is not None:
                affinity = affinity_pyramid[lvl]
            crf = self.crf[lvl]

            # add probably an if condition whether do add multi-scale to crf
            if beliefs_in is not None:
                # upsample the beliefs of the previous (coarser) level by 2
                beliefs_in = F.interpolate(beliefs_in.unsqueeze(1), scale_factor=2.0,
                                           mode='trilinear')[:, 0]

                if beliefs_in.requires_grad:
                    # NOTE(review): in this branch pv_lvl is only normalised
                    # and the upsampled beliefs are not added - confirm.
                    pv_lvl = pv_lvl / pv_lvl.sum(dim=-1, keepdim=True)
                else:
                    pv_lvl += beliefs_in / 2.0  # in-place saves memory
                del beliefs_in

            torch.cuda.empty_cache()
            disps_lvl, beliefs_lvl, affinities_lvl, _ = self.optimize_crf(crf, pv_lvl, None,
                                                                          affinity)
            del affinities_lvl

            if lvl == 0:
                # only the finest level is cropped back to the unpadded size
                beliefs_lvl = self.unpad(beliefs_lvl, self.pad.l, self.pad.r, self.pad.t,
                                         self.pad.b, NCHW=False)
                disps_lvl = self.unpad(disps_lvl, self.pad.l, self.pad.r, self.pad.t,
                                       self.pad.b)

            beliefs_pyramid.append(beliefs_lvl)
            crf_disps_pyramid.append(disps_lvl + m.min_disp)
            beliefs_in = beliefs_pyramid[-1]

        # the lists were filled starting at the last level; put level 0 first
        beliefs_pyramid.reverse()
        crf_disps_pyramid.reverse()
        res_dict['disps0'] = crf_disps_pyramid

        if self.refinement_net:
            # crf
            cv_conf = self.cv_conf.forward(beliefs_pyramid[0].permute(0, 3, 1, 2),
                                           crf_disps_pyramid[0])

            conf_all = cv_conf
            refined_disps_pyramid, refinement_steps = self.refine_disps(I0_pyramid,
                                                                        crf_disps_pyramid[0],
                                                                        confidence=conf_all,
                                                                        I1=I1_pyramid)
            if refinement_steps is not None:
                refinement_steps.reverse()
            refined_disps_pyramid.reverse()

            res_dict['disps0'] = refined_disps_pyramid

        return res_dict

    def extract_features(self, ipt):
        """Run the feature net on ``ipt``; ``None`` without a feature net."""
        if self.feature_net:
            return self.feature_net.forward(ipt)
        return None

    def extract_affinities(self, ipt):
        """Run the affinity net on ``ipt``; ``None`` without an affinity net."""
        if self.affinity_net:
            return self.affinity_net.forward(ipt)
        return None

    def match(self, f0, f1, lr=False):
        """Match the feature pyramids level by level.

        With ``lr=True`` both feature maps are flipped along dim 3, matched
        with swapped roles, and the result is flipped back along dim 2.

        Returns:
            List with one matching volume per level, or ``None`` if no
            matching modules are configured.
        """
        if not self.matching:
            return None

        prob_vols = []
        for matching, f0s, f1s in zip(self.matching, f0, f1):
            if lr:
                f0s = torch.flip(f0s, dims=(3,)).contiguous()
                f1s = torch.flip(f1s, dims=(3,)).contiguous()
                prob_vol_s = matching.forward(f1s, f0s)
                prob_vol_s = torch.flip(prob_vol_s, dims=(2,))
            else:
                prob_vol_s = matching.forward(f0s, f1s)

            prob_vols.append(prob_vol_s)
        return prob_vols

    def optimize_crf(self, crf_layer, prob_vol, weights, affinities):
        """Pass ``prob_vol`` through all BP layers of one pyramid level.

        Returns:
            Tuple ``(disps, prob_vol, affinities_shift, offsets_shift)`` of
            the last BP layer, or ``None`` if ``crf_layer`` is empty/``None``.
        """
        if not crf_layer:
            return None

        offsets = None
        # iterate over all bp "layers"
        for idx, crf in enumerate(crf_layer):
            prob_vol = prob_vol.contiguous()
            weights_input = crf.adjust_input_weights(weights, idx)
            affinities_shift = crf.adjust_input_affinities(affinities)
            offsets_shift = crf.adjust_input_offsets(offsets)

            if not prob_vol.requires_grad:
                torch.cuda.empty_cache()

            disps, prob_vol, messages = crf.forward(prob_vol, weights_input, affinities_shift,
                                                    offsets_shift)
            del messages  # never used again

        return disps, prob_vol, affinities_shift, offsets_shift

    def refine_disps(self, I0, d0, confidence=None, I1=None):
        """Run the refinement net; returns ``(refined, steps)`` or ``None`` without one."""
        if self.refinement_net:
            refined, steps = self.refinement_net.forward(I0, d0, confidence, I1)
            return refined, steps
        return None

    def feature_net_params(self, requires_grad=None):
        """Parameter list of the feature net, or ``[]`` if there is none."""
        if self.feature_net:
            return self.feature_net.parameter_list(requires_grad)
        return []

    def matching_params(self, requires_grad=None):
        """Concatenated parameter lists of all matching levels."""
        params = []
        if self.matching:
            for m in self.matching:
                params += m.parameter_list(requires_grad)
        return params

    def affinity_net_params(self, requires_grad=None):
        """Parameter list of the affinity net, or ``[]`` if there is none."""
        if self.affinity_net:
            return self.affinity_net.parameter_list(requires_grad)
        return []

    def crf_params(self, requires_grad=None):
        """Concatenated parameter lists of every BP layer on every level."""
        crf_params = []
        if self.crf:
            for crf_layer in self.crf:
                for crf in crf_layer:
                    crf_params += crf.parameter_list(requires_grad)
        return crf_params

    def refinement_net_params(self, requires_grad=None):
        """Parameter list of the refinement net, or ``[]`` if there is none."""
        if self.refinement_net:
            return self.refinement_net.parameter_list(requires_grad)
        return []

    @property
    def feature_net(self):
        return self._feature_net

    @property
    def affinity_net(self):
        return self._affinity_net

    @property
    def crf(self):
        # an empty layer list is reported as None so callers can test ``if self.crf``
        if not self._crf:
            return None
        return self._crf

    @property
    def refinement_net(self):
        return self._refinement_net

    @property
    def matching(self):
        return self._matching

    @property
    def min_disp(self):
        return self._min_disp

    @property
    def max_disp(self):
        return self._max_disp

    @property
    def gc_net(self):
        return self._gc_net
268 |
269 |
270 | ####################################################################################################
271 | # Block Match
272 | ####################################################################################################
class BlockMatchStereo(StereoMethod):
    """Stereo method with a feature net and one SAD matching module per feature level."""

    def __init__(self, device, args):
        StereoMethod.__init__(self, device, args)
        self._feature_net = FeatureNet(device, args)

        self._matching = []
        num_levels = self._feature_net.net.num_output_levels
        for lvl in range(num_levels):
            # with args.lbp_min_disp the original min-disp is used on every
            # level, otherwise it is scaled down by 2**lvl like max-disp
            lvl_min_disp = self.min_disp if self.args.lbp_min_disp else self.min_disp // 2**lvl
            lvl_max_disp = ((self.max_disp + 1) // 2**lvl) - 1

            message = "Construct Matching Level %d with min-disp=%d and max-disp=%d" % (
                lvl, lvl_min_disp, lvl_max_disp)
            self.logger.info(message)

            sad_matching = StereoMatchingSad(device, args, lvl_min_disp, lvl_max_disp, lvl=lvl)
            self._matching.append(sad_matching)
289 |
290 |
291 | ####################################################################################################
292 | # Min-Sum LBP
293 | ####################################################################################################
class MinSumStereo(BlockMatchStereo):
    """Block matching followed by min-sum BP layers and an affinity net."""

    def __init__(self, device, args):
        BlockMatchStereo.__init__(self, device, args)

        self.max_iter = args.max_iter
        # number of labels spanned by the configured disparity range
        label_count = self.max_disp - self.min_disp + 1

        self._affinity_net = AffinityNet(device, args)

        # one list of args.num_bp_layers BP modules per feature level
        for level in range(self._feature_net.net.num_output_levels):
            level_layers = []
            for layer_idx in range(args.num_bp_layers):
                bp_layer = BP(device, args, self.max_iter, label_count, 3,
                              mode_inference=args.bp_inference,
                              mode_message_passing='min-sum',
                              layer_idx=layer_idx, level=level)
                level_layers.append(bp_layer)
            self._crf.append(level_layers)
308 |
309 |
class RefinedMinSumStereo(MinSumStereo):
    """Min-sum BP stereo with an additional refinement network."""

    def __init__(self, device, args):
        MinSumStereo.__init__(self, device, args)

        # enables the refinement stage of the inherited forward()
        self._refinement_net = RefinementNet(device, args)
--------------------------------------------------------------------------------