├── .gitignore ├── .gitmodules ├── LICENSE ├── Readme.md ├── corenet.py ├── data.py ├── data ├── 000010_10_left.png ├── 000010_10_right.png ├── frame_0019.flo ├── frame_0019.png ├── frame_0020.png ├── frankfurt_val.png ├── im0.png ├── im1.png ├── output │ ├── flow │ │ └── .keep │ ├── semantic │ │ ├── .keep │ │ └── sem_pred.png │ └── stereo │ │ └── .keep ├── params │ ├── flow │ │ ├── BP+MS (H) │ │ │ ├── affinity_best.cpt │ │ │ ├── crf0_lvl0_best.cpt │ │ │ ├── crf0_lvl1_best.cpt │ │ │ ├── crf0_lvl2_best.cpt │ │ │ ├── matching_lvl0_best.cpt │ │ │ ├── matching_lvl1_best.cpt │ │ │ ├── matching_lvl2_best.cpt │ │ │ └── unary_best.cpt │ │ └── BP+MS+Ref (H) │ │ │ ├── affinity_best.cpt │ │ │ ├── crf0_lvl0_best.cpt │ │ │ ├── crf0_lvl1_best.cpt │ │ │ ├── crf0_lvl2_best.cpt │ │ │ ├── matching_lvl0_best.cpt │ │ │ ├── matching_lvl1_best.cpt │ │ │ ├── matching_lvl2_best.cpt │ │ │ ├── refinement_best.cpt │ │ │ └── unary_best.cpt │ ├── semantic │ │ ├── global_model.cpt │ │ └── pixel_model.cpt │ └── stereo │ │ ├── BP+MS (H) │ │ ├── affinity_best.cpt │ │ ├── crf0_lvl0_best.cpt │ │ ├── crf0_lvl1_best.cpt │ │ ├── crf0_lvl2_best.cpt │ │ ├── matching_lvl0_best.cpt │ │ ├── matching_lvl1_best.cpt │ │ ├── matching_lvl2_best.cpt │ │ └── unary_best.cpt │ │ ├── BP+MS (NLL) │ │ ├── affinity_best.cpt │ │ ├── crf0_lvl0_best.cpt │ │ ├── crf0_lvl1_best.cpt │ │ ├── crf0_lvl2_best.cpt │ │ ├── matching_lvl0_best.cpt │ │ ├── matching_lvl1_best.cpt │ │ ├── matching_lvl2_best.cpt │ │ └── unary_best.cpt │ │ ├── BP+MS+Ref (H) │ │ ├── affinity_best.cpt │ │ ├── crf0_lvl0_best.cpt │ │ ├── crf0_lvl1_best.cpt │ │ ├── crf0_lvl2_best.cpt │ │ ├── matching_lvl0_best.cpt │ │ ├── matching_lvl1_best.cpt │ │ ├── matching_lvl2_best.cpt │ │ ├── refinement_best.cpt │ │ └── unary_best.cpt │ │ ├── Kitti │ │ ├── affinity_best.cpt │ │ ├── crf0_lvl0_best.cpt │ │ ├── crf0_lvl1_best.cpt │ │ ├── crf0_lvl2_best.cpt │ │ ├── matching_lvl0_best.cpt │ │ ├── matching_lvl1_best.cpt │ │ ├── matching_lvl2_best.cpt │ │ └── unary_best.cpt │ │ ├── MB │ │ ├── affinity_best.cpt │ │ ├── crf0_lvl0_best.cpt │ │ ├── crf0_lvl1_best.cpt │ │ ├── crf0_lvl2_best.cpt │ │ ├── matching_lvl0_best.cpt │ │ ├── matching_lvl1_best.cpt │ │ ├── matching_lvl2_best.cpt │ │ └── unary_best.cpt │ │ └── WTA (NLL) │ │ ├── matching_best.cpt │ │ └── unary_best.cpt ├── sf_0006_left.png └── sf_0006_right.png ├── flow.py ├── flow_matching.py ├── flow_tools.py ├── github_imgs ├── flow_example.png ├── kitti.png ├── mb.png ├── sem_input.png ├── sem_pred.png ├── sf.png └── teaser.gif ├── main_flow.py ├── main_semantic.py ├── main_stereo.py ├── matching.py ├── networks.py ├── ops ├── flow_mp_sad │ ├── flow_mp_sad.py │ └── src │ │ ├── flow_mp_sad.cpp │ │ ├── flow_mp_sad_kernel.cu │ │ └── flow_mp_sad_kernel.cuh ├── include │ ├── error_util.h │ └── tensor.h ├── lbp_semantic_pw │ ├── bp_op_cuda.py │ ├── inference_op.py │ ├── message_passing_op_cuda.py │ ├── setup.py │ └── src │ │ ├── lbp.cpp │ │ ├── lbp_min_sum_kernel.cu │ │ ├── lbp_min_sum_kernel.cuh │ │ └── util.cuh ├── lbp_semantic_pw_pixel │ ├── bp_op_cuda.py │ ├── inference_op.py │ ├── message_passing_op_pw_pixel.py │ ├── setup.py │ └── src │ │ ├── lbp.cpp │ │ ├── lbp_min_sum_kernel.cu │ │ ├── lbp_min_sum_kernel.cuh │ │ └── util.cuh ├── lbp_stereo │ ├── bp_op_cuda.py │ ├── inference_op.py │ ├── message_passing_op_cuda.py │ ├── setup.py │ └── src │ │ ├── lbp.cpp │ │ ├── lbp_min_sum_kernel.cu │ │ ├── lbp_min_sum_kernel.cuh │ │ └── util.cuh ├── sad │ ├── src │ │ ├── stereo_sad.cpp │ │ ├── stereo_sad_kernel.cu │ │ └── stereo_sad_kernel.cuh │ └── stereo_sad.py └── setup.py ├── run_flow.sh ├── run_semantic_global.sh ├── run_semantic_pixel.sh ├── run_stereo_kitti.sh ├── run_stereo_mb.sh ├── run_stereo_sf.sh ├── semantic_segmentation.py └── stereo.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.fls 2 | *.fdb_latexmk 3 | *.aux 4 | *.log 5 | *.out 6 | *.pdf 7 | *.DS_Store 8 | .idea 9 | .vscode 10 | *.gz 11 | images 12 | build 13 | __pycache__ 14 | *egg* 15 | .ipynb_checkpoints 16 | *.pyc 17 | *.zip 18 | *.tar 19 | *.nav 20 | *.snm 21 | *.toc 22 | *.brf 23 | *.blg 24 | *.bbl 25 | output 26 | 27 | 28 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "dependencies/ESPNet"] 2 | path = dependencies/ESPNet 3 | url = https://github.com/sacmehta/ESPNet.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 4 | Patrick Knoebelreiter and Christian Sormann 5 | Institute for Computer Graphics and Vision, Graz University of Technology 6 | https://www.tugraz.at/institute/icg/research/team-pock/ 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # BP-Layers 2 | This repository contains the implementation for our publication "Belief Propagation Reloaded: Learning BP-Layers for Labeling Problems". If you use this implementation please cite the following publication: 3 | 4 | ~~~ 5 | @InProceedings{Knobelreiter_2020_CVPR, 6 | author = {Knöbelreiter, Patrick and Sormann, Christian and Shekhovtsov, Alexander and Fraundorfer, Friedrich and Pock, Thomas}, 7 | title = {Belief Propagation Reloaded: Learning BP-Layers for Labeling Problems}, 8 | booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 9 | month = {June}, 10 | year = {2020} 11 | } 12 | ~~~ 13 | 14 |

15 | 16 | ## Repository Structure 17 | 18 | The repository is structured as follows: 19 | - the base directory contains scripts for running inference and python implementations of the networks 20 | - 'data' includes sample images for stereo/semantic/flow inference 21 | - 'ops' contains custom PyTorch-Ops which need to be installed before running the respective stereo/semantic implementation (note that this is also taken care of by the run_*.sh scripts) 22 | 23 | For your convenience, the required libraries are added as submodules. To clone them issue the command 24 | 25 | ~~~ 26 | git submodule update --init --recursive 27 | ~~~ 28 | 29 | ## Dependencies 30 | 31 | * Cuda 10.2 32 | * pytorch >= 1.3 33 | * argparse 34 | * imageio (with libpfm installed)* 35 | * numpy 36 | 37 | The stereo results are saved as pfm images. If your imageio does not have libpfm installed automatically, execute the following command in a python: 38 | 39 | ~~~ 40 | imageio.plugins.freeimage.download() 41 | ~~~ 42 | 43 | In order to display pfm files we highly recommend the tool provided by the Middlebury stereo benchmark. You can find it here. 44 | 45 | 46 | ## Running the implementation 47 | 48 | After installing all of the required dependencies above you need to install the provided modules to you python environment. This can be done with 49 | 50 | ~~~ 51 | cd ops 52 | python setup.py install 53 | ~~~ 54 | 55 | This will install the SAD matching kernels for stereo and Optical flow. The BP-Layer is installed automatically upon execution of the provided shell scripts. The following sections show how to use them. 56 | 57 | 58 | ### Stereo 59 | * run_stereo_sf.sh: The models trained on the Scene-Flow Dataset 60 | * run_stereo_mb.sh: The model used for evaluation on the Middlebury dataset 61 | * run_stereo_kitti: The model used for evaluation on the Kitti dataset 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | ### Flow 74 | * run_flow.sh 75 | 76 |

77 |

78 | 79 | ### Semantic Segmentation 80 | * run_semantic_global.sh: Our semantic segmentation model with *global* pairwise weights 81 | * run_semantic_pixel: Our semantic segmentation model with *pixel-wise* pairwise weights 82 | 83 | 84 | 85 | 86 | ~~~ 87 | sh run_semantic_pixel.sh 88 | ~~~ 89 | 90 | Should yield this result: 91 | 92 |

93 |

94 | 95 | Inside these scripts you can also specify different images to be used for inference. The correct PyTorch-Ops are also automatically installed by these scripts before running the inference. -------------------------------------------------------------------------------- /corenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | class CvConfidence(nn.Module): 8 | def __init__(self, device): 9 | super(CvConfidence, self).__init__() 10 | self._device = device 11 | 12 | def forward(self, prob_volume, disps): 13 | N, _, H, W = prob_volume.shape 14 | 15 | # generate coordinates 16 | n_coords = torch.arange(N, device=prob_volume.device, dtype=torch.long) 17 | n_coords = n_coords.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 18 | 19 | x_coords = torch.arange(W, device=prob_volume.device, dtype=torch.long) 20 | y_coords = torch.arange(H, device=prob_volume.device, dtype=torch.long) 21 | y_coords = y_coords.unsqueeze(-1) 22 | # xl, yl = torch.meshgrid((x_coords, y_coords)) # with torch >= 0.4.1 23 | 24 | nl = n_coords.repeat((1, 1, H, W)).long() 25 | xl = x_coords.repeat((N, 1, H, 1)).long() 26 | yl = y_coords.repeat((N, 1, 1, W)).long() 27 | 28 | cv_confidence = prob_volume[nl, torch.round(disps).long(), yl, xl] 29 | return cv_confidence 30 | 31 | class LrDistance(nn.Module): 32 | def __init__(self, device): 33 | super(LrDistance, self).__init__() 34 | self._device = device 35 | 36 | def forward(self, disps_lr, disps_rl): 37 | dispsLR = disps_lr.float() 38 | dispsRL = disps_rl.float() 39 | 40 | S, _, M, N = dispsLR.shape 41 | 42 | # generate coordinates 43 | x_coords = torch.arange(N, device=self._device, dtype=torch.float) 44 | y_coords = torch.arange(M, device=self._device, dtype=torch.float) 45 | 46 | xl = x_coords.repeat((S, M, 1)).unsqueeze(1) 47 | yl = y_coords.repeat((S, N, 1)).transpose(2, 1).unsqueeze(1) 48 | 49 | xr = xl - dispsLR 50 | 51 | # normalize coordinates for sampling between [-1, 1] 52 | xr_normed = (2 * xr / (N - 1)) - 1.0 53 | yl_normed = (2 * yl / (M - 1)) - 1.0 54 | 55 | # coords must have sahpe N x OH x OW x 2 56 | sample_coords = torch.stack((xr_normed[:, 0], yl_normed[:, 0]), dim=-1) 57 | dispsRL_warped = nn.functional.grid_sample(dispsRL, sample_coords) 58 | 59 | lr_distance = torch.abs(dispsLR + dispsRL_warped) 60 | lr_distance[(xr >= N) | (xr < 0)] = 100.0 61 | 62 | # import matplotlib.pyplot as plt 63 | # plt.ion() 64 | # plt.figure(), plt.imshow(disps_lr[0,0].detach().cpu()), plt.title('LR') 65 | # plt.figure(), plt.imshow(disps_rl[0,0].detach().cpu()), plt.title('RL') 66 | # plt.figure(), plt.imshow(dispsRL_warped[0,0].detach().cpu()), plt.title('RL WARPED') 67 | # plt.figure(), plt.imshow(lr_distance[0,0].detach().cpu()), plt.title('DIST') 68 | 69 | # import pdb 70 | # pdb.set_trace() 71 | 72 | return lr_distance 73 | 74 | class LrCheck(nn.Module): 75 | def __init__(self, device, eps): 76 | super(LrCheck, self).__init__() 77 | self._device = device 78 | self._eps = eps 79 | 80 | def forward(self, lr_dists): 81 | lr_mask = torch.ones_like(lr_dists) 82 | lr_mask[lr_dists > self._eps] = 0.0 83 | 84 | zero = torch.tensor([0], dtype=torch.float).to(self._device) 85 | confidence = torch.max(self._eps - lr_dists, zero) / self._eps 86 | 87 | # import matplotlib.pyplot as plt 88 | # plt.ion() 89 | # plt.imshow(confidence.detach()[0,0]), plt.title('conf') 90 | # plt.figure(), plt.imshow(lr_mask.detach()[0,0]), plt.title('mask') 91 | 92 | # import pdb 93 | # pdb.set_trace() 94 | 95 | return lr_mask, confidence 96 | 97 | class TemperatureSoftmax(nn.Module): 98 | def __init__(self, dim, init_temp=1.0): 99 | nn.Module.__init__(self) 100 | self.T = nn.Parameter(torch.ones(1).float().cuda() * init_temp, requires_grad=True) 101 | self.dim = dim 102 | 103 | def forward(self, x): 104 | return F.softmax(x / self.T, dim=self.dim) 105 | 106 | class TemperatureSoftmin(nn.Module): 107 | def __init__(self, dim, init_temp=1.0): 108 | nn.Module.__init__(self) 109 | self.T = nn.Parameter(torch.ones(1).float().cuda() * init_temp, requires_grad=True) 110 | self.dim = dim 111 | 112 | def forward(self, x): 113 | return F.softmin(x / self.T, dim=self.dim) 114 | 115 | class Pad(nn.Module): 116 | def __init__(self, divisor, extra_pad=(0, 0)): 117 | nn.Module.__init__(self) 118 | self.divisor = divisor 119 | self.extra_pad_h = extra_pad[0] # pad at top and bottom with specified value 120 | self.extra_pad_w = extra_pad[1] # pad at left and right with specified value 121 | 122 | self.l = 0 123 | self.r = 0 124 | self.t = 0 125 | self.b = 0 126 | 127 | def pad(self, x): 128 | N, C, H, W = x.shape 129 | 130 | w_add = 0 131 | while W % self.divisor != 0: 132 | W += 1 133 | w_add += 1 134 | 135 | h_add = 0 136 | while H % self.divisor != 0: 137 | H += 1 138 | h_add += 1 139 | 140 | # additionally pad kitti imgs 141 | self.l = self.extra_pad_w + np.ceil(w_add / 2.0).astype('int') 142 | self.r = self.extra_pad_w + np.floor(w_add / 2.0).astype('int') 143 | self.t = self.extra_pad_h + np.ceil(h_add / 2.0).astype('int') 144 | self.b = self.extra_pad_h + np.floor(h_add / 2.0).astype('int') 145 | 146 | padded = F.pad(x, (self.l, self.r, self.t, self.b), mode='reflect') 147 | return padded 148 | 149 | def forward(self, x): 150 | return self.pad(x) 151 | 152 | class Unpad(nn.Module): 153 | def __init__(self): 154 | nn.Module.__init__(self) 155 | 156 | def unpad_NCHW(self, x, l, r, t, b): 157 | x = x[:, :, t:, l:] 158 | if b > 0: 159 | x = x[:, :, :-b, :] 160 | if r > 0: 161 | x = x[:, :, :, :-r] 162 | return x.contiguous() 163 | 164 | def unpad_NHWC(self, x, l, r, t, b): 165 | x = x[:, t:, l:, :] 166 | if b > 0: 167 | x = x[:, :-b, :, :] 168 | if r > 0: 169 | x = x[:, :, :-r, :] 170 | return x.contiguous() 171 | 172 | 173 | def unpad_flow_N2HWC(self, x, l, r, t, b): 174 | x = x[:, :, t:, l:, :] 175 | if b > 0: 176 | x = x[:, :, :-b, :, :] 177 | if r > 0: 178 | x = x[:, :, :, :-r, :] 179 | return x.contiguous() 180 | 181 | def unpad_flow_N2CHW(self, x, l, r, t, b): 182 | x = x[:, :, :, t:, l:] 183 | if b > 0: 184 | x = x[:, :, :, :-b, :] 185 | if r > 0: 186 | x = x[:, :, :, :, :-r] 187 | return x.contiguous() 188 | 189 | def forward(self, x, l, r, t, b, NCHW=True): 190 | if NCHW: 191 | if len(x.shape) == 4: 192 | return self.unpad_NCHW(x, l, r, t, b) 193 | else: 194 | return self.unpad_flow_N2CHW(x, l, r, t, b) 195 | else: 196 | if len(x.shape) == 4: 197 | return self.unpad_NHWC(x, l, r, t, b) 198 | else: 199 | return self.unpad_flow_N2HWC(x, l, r, t, b) 200 | 201 | class PadUnpad(Pad, Unpad): 202 | def __init__(self, net, divisor=1, extra_pad=(0, 0)): 203 | Pad.__init__(self, divisor, extra_pad) 204 | Unpad.__init__(self) 205 | 206 | self.net = net 207 | 208 | def forward(self, ipt): 209 | out = self.net.forward(self.pad(ipt)) 210 | res = [] 211 | for o in out: 212 | res.append(self.unpad_NCHW(o, self.l, self.r, self.t, self.b)) 213 | return res 214 | 215 | class PaddedConv2d(nn.Module): 216 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=False): 217 | super(PaddedConv2d, self).__init__() 218 | 219 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, 220 | stride=stride, padding=0, dilation=dilation, bias=bias) 221 | 222 | self.dilation = dilation 223 | 224 | def forward(self, x): 225 | pd = (self.conv.kernel_size[0] // 2) * self.dilation 226 | x = F.pad(x, (pd, pd, pd, pd), mode='reflect') 227 | x = self.conv(x) 228 | return x 229 | 230 | class ResidualBlock(nn.Module): 231 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, 232 | activation='ReLU', transposed=False, with_bn=True, leaky_alpha=1e-2): 233 | super(ResidualBlock, self).__init__() 234 | 235 | self.convbnact1 = ConvBatchNormAct(in_channels, out_channels, kernel_size, stride, dilation, 236 | padding, activation, transposed, with_bn) 237 | 238 | bias = False # because the parameter is already contained in group-norm!! 239 | self.conv2 = PaddedConv2d(out_channels, out_channels, kernel_size, stride, dilation, bias) 240 | num_groups = out_channels 241 | self.bn2 = nn.GroupNorm(num_groups, out_channels, affine=True) 242 | 243 | if activation.lower() == 'relu': 244 | self.act2 = nn.ReLU(inplace=True) 245 | elif activation.lower() == 'leakyrelu': 246 | self.act2 = nn.LeakyReLU(negative_slope=leaky_alpha, inplace=True) 247 | elif activation.lower() == 'elu': 248 | self.act2 = nn.ELU(inplace=True) 249 | else: 250 | raise NotImplementedError("Activation " + activation + " is currently not implemented!") 251 | 252 | def forward(self, x): 253 | residual = x 254 | x = self.convbnact1(x) 255 | 256 | x = self.conv2(x) 257 | x = self.bn2(x) 258 | 259 | x += residual 260 | x = self.act2(x) 261 | return x 262 | 263 | 264 | 265 | class ConvBatchNormAct(nn.Module): 266 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, 267 | activation='ReLU', transposed=False, with_bn=True): 268 | super(ConvBatchNormAct, self).__init__() 269 | 270 | if with_bn: 271 | bias = False 272 | else: 273 | bias = True 274 | 275 | if transposed: 276 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, 277 | stride=stride, padding=padding, dilation=dilation) 278 | else: 279 | self.conv = PaddedConv2d(in_channels, out_channels, kernel_size, stride=stride, 280 | dilation=dilation, bias=bias) 281 | 282 | self.bn = None 283 | if with_bn: 284 | #self.bn = nn.BatchNorm2d(out_channels, affine=True, track_running_stats=False) 285 | num_groups = out_channels 286 | self.bn = nn.GroupNorm(num_groups, out_channels, affine=True) 287 | 288 | 289 | if activation.lower() == 'relu': 290 | self.act = nn.ReLU(inplace=True) 291 | elif activation.lower() == 'leakyrelu': 292 | self.act = nn.LeakyReLU(inplace=True) 293 | elif activation.lower() == 'elu': 294 | self.act = nn.ELU(inplace=True) 295 | else: 296 | raise NotImplementedError("Activation " + activation + " is currently not implemented!") 297 | 298 | def forward(self, x): 299 | x = self.conv(x) 300 | if self.bn: 301 | x = self.bn(x) 302 | x = self.act(x) 303 | return x 304 | 305 | 306 | class BaseNet(nn.Module): 307 | def __init__(self): 308 | super(BaseNet, self).__init__() 309 | 310 | self._divisor = -1 311 | self._out_channels = -1 312 | 313 | @property 314 | def divisor(self): 315 | return self._divisor 316 | 317 | @property 318 | def out_channels(self): 319 | return self._out_channels 320 | 321 | @property 322 | def num_output_levels(self): 323 | return self._num_output_levels 324 | 325 | def forward(self, *input): 326 | raise NotImplementedError 327 | 328 | 329 | class StereoUnaryUnetDyn(BaseNet): 330 | def __init__(self, multi_level_output=False, activation='relu', 331 | with_bn=True, with_upconv=False, with_output_bn=True): 332 | super(StereoUnaryUnetDyn, self).__init__() 333 | 334 | self.multi_level_output = multi_level_output 335 | self.with_upconv = with_upconv 336 | self.with_output_bn = with_output_bn 337 | self._divisor = 8 338 | self._out_channels = [64] 339 | if multi_level_output: 340 | self._out_channels = [64, 64, 64] 341 | 342 | self._num_output_levels = 1 343 | if multi_level_output: 344 | self._num_output_levels = 3 345 | 346 | self._out_channels = [32] 347 | if multi_level_output: 348 | self._out_channels = [32, 32, 32] 349 | 350 | self.conv1 = ConvBatchNormAct(3, 16, 3, padding=1, activation=activation, with_bn=with_bn) 351 | self.conv2 = ConvBatchNormAct(16, 16, 3, padding=1, activation=activation, with_bn=with_bn) 352 | self.pool1 = nn.MaxPool2d(2) 353 | 354 | self.conv3 = ConvBatchNormAct(16, 32, 3, padding=1, activation=activation, with_bn=with_bn) 355 | self.conv4 = ConvBatchNormAct(32, 32, 3, padding=1, activation=activation, with_bn=with_bn) 356 | self.pool2 = nn.MaxPool2d(2) 357 | 358 | self.conv5 = ConvBatchNormAct(32, 64, 3, padding=1, activation=activation, with_bn=with_bn) 359 | self.conv6 = ConvBatchNormAct(64, 64, 3, padding=1, activation=activation, with_bn=with_bn) 360 | 361 | if self.with_upconv: 362 | self.upconv7 = ConvBatchNormAct(64, 32, 3, stride=2, padding=0, activation=activation, 363 | transposed=True, with_bn=with_bn) 364 | self.conv8 = ConvBatchNormAct(64, 32, 3, padding=1, activation=activation, 365 | with_bn=with_bn) 366 | else: 367 | self.conv8 = ConvBatchNormAct(96, 32, 3, padding=1, activation=activation, 368 | with_bn=with_bn) 369 | self.conv9 = ConvBatchNormAct(32, 32, 3, padding=1, activation=activation, with_bn=with_bn) 370 | 371 | if self.with_upconv: 372 | self.upconv10 = ConvBatchNormAct(32, 16, 3, stride=2, padding=0, activation=activation, 373 | transposed=True, with_bn=with_bn) 374 | self.conv11 = ConvBatchNormAct(32, 32, 3, padding=1, activation=activation, 375 | with_bn=with_bn) 376 | else: 377 | self.conv11 = ConvBatchNormAct(48, 32, 3, padding=1, activation=activation, 378 | with_bn=with_bn) 379 | 380 | self.conv12 = PaddedConv2d(32, 32, 3) 381 | self.bn12 = None 382 | if with_bn and with_output_bn: 383 | # self.bn12 = nn.BatchNorm2d(32, affine=True, track_running_stats=False) 384 | self.bn12 = nn.GroupNorm(32, 32, affine=True) 385 | 386 | if self.multi_level_output: 387 | self.conv_lvl1 = PaddedConv2d(32, 32, 3) 388 | self.conv_lvl2 = PaddedConv2d(64, 32, 3) 389 | 390 | def forward(self, x_in): 391 | x = x_in 392 | x = self.conv1(x) 393 | lvl0 = self.conv2(x) 394 | x = self.pool1(lvl0) 395 | 396 | x = self.conv3(x) 397 | lvl1 = self.conv4(x) 398 | x = self.pool2(lvl1) 399 | 400 | x = self.conv5(x) 401 | x = self.conv6(x) 402 | 403 | if self.multi_level_output: 404 | lvl2_out = self.conv_lvl2(x) 405 | 406 | if self.with_upconv: 407 | x = self.upconv7(x)[:, :, 1:, 1:] 408 | else: 409 | x = F.interpolate(x, lvl1.shape[2:], mode='bilinear') 410 | x = torch.cat([lvl1, x], dim=1) 411 | 412 | x = self.conv8(x) 413 | x = self.conv9(x) 414 | 415 | if self.multi_level_output: 416 | lvl1_out = self.conv_lvl1(x) 417 | 418 | if self.with_upconv: 419 | x = self.upconv10(x)[:, :, :-1, :-1] 420 | else: 421 | x = F.interpolate(x, lvl0.shape[2:], mode='bilinear') 422 | x = torch.cat([lvl0, x], dim=1) 423 | 424 | x = self.conv11(x) 425 | x = self.conv12(x) 426 | if self.bn12: 427 | x = self.bn12(x) 428 | 429 | if self.multi_level_output: 430 | return x, lvl1_out, lvl2_out 431 | 432 | return [x] -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from imageio import imread 3 | from skimage.color import rgb2lab 4 | from skimage.transform import rescale 5 | import numpy as np 6 | 7 | 8 | def get_lvl_name(lvl): 9 | return '_lvl' + str(lvl) 10 | 11 | def load_sample(left_path, right_path): 12 | # load left/right image 13 | i0 = imread(left_path) 14 | if i0.shape[2] == 4: 15 | i0 = i0[:, :, :3] # remove alpha channel 16 | 17 | i1 = imread(right_path) 18 | if i1.shape[2] == 4: 19 | i1 = i1[:, :, :3] # remove alpha channel 20 | 21 | # construct sample 22 | sample = {'i0': i0, 'i1': i1} 23 | 24 | # apply transforms 25 | sample['i0_lvl0'] = rgb2lab(sample['i0']).astype('float32') 26 | sample['i1_lvl0'] = rgb2lab(sample['i1']).astype('float32') 27 | for lvl in range(2): 28 | scale = 1.0 / 2**(lvl+1) 29 | sample['i0_lvl' + str(lvl + 1)] = rgb2lab(rescale(sample['i0'], scale, order=1, anti_aliasing=True, 30 | mode='reflect', multichannel=True)).astype('float32') 31 | sample['i1_lvl' + str(lvl + 1)] = rgb2lab(rescale(sample['i1'], scale, order=1, anti_aliasing=True, 32 | mode='reflect', multichannel=True)).astype('float32') 33 | 34 | for key in sample.keys(): 35 | sample[key] = torch.from_numpy(sample[key].transpose(2, 0, 1)).unsqueeze(0) 36 | 37 | # construct image pyramid 38 | I0_pyramid = [] 39 | I1_pyramid = [] 40 | for lvl in range(3): 41 | I0_pyramid.append(sample['i0' + get_lvl_name(lvl)]) 42 | I1_pyramid.append(sample['i1' + get_lvl_name(lvl)]) 43 | 44 | return I0_pyramid, I1_pyramid 45 | 46 | 47 | def readFlow(name): 48 | f = open(name, 'rb') 49 | 50 | header = f.read(4) 51 | if header.decode("utf-8") != 'PIEH': 52 | raise Exception('Flow file header does not contain PIEH') 53 | 54 | width = np.fromfile(f, np.int32, 1).squeeze() 55 | height = np.fromfile(f, np.int32, 1).squeeze() 56 | 57 | flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2)) 58 | 59 | return flow.astype(np.float32) -------------------------------------------------------------------------------- /data/000010_10_left.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/000010_10_left.png -------------------------------------------------------------------------------- /data/000010_10_right.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/000010_10_right.png -------------------------------------------------------------------------------- /data/frame_0019.flo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/frame_0019.flo -------------------------------------------------------------------------------- /data/frame_0019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/frame_0019.png -------------------------------------------------------------------------------- /data/frame_0020.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/frame_0020.png -------------------------------------------------------------------------------- /data/frankfurt_val.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/frankfurt_val.png -------------------------------------------------------------------------------- /data/im0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/im0.png -------------------------------------------------------------------------------- /data/im1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/im1.png -------------------------------------------------------------------------------- /data/output/flow/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/output/flow/.keep -------------------------------------------------------------------------------- /data/output/semantic/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/output/semantic/.keep -------------------------------------------------------------------------------- /data/output/semantic/sem_pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/output/semantic/sem_pred.png -------------------------------------------------------------------------------- /data/output/stereo/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/output/stereo/.keep -------------------------------------------------------------------------------- /data/params/flow/BP+MS (H)/affinity_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS (H)/affinity_best.cpt -------------------------------------------------------------------------------- /data/params/flow/BP+MS (H)/crf0_lvl0_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS (H)/crf0_lvl0_best.cpt -------------------------------------------------------------------------------- /data/params/flow/BP+MS (H)/crf0_lvl1_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS (H)/crf0_lvl1_best.cpt -------------------------------------------------------------------------------- /data/params/flow/BP+MS (H)/crf0_lvl2_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS (H)/crf0_lvl2_best.cpt -------------------------------------------------------------------------------- /data/params/flow/BP+MS (H)/matching_lvl0_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS (H)/matching_lvl0_best.cpt -------------------------------------------------------------------------------- /data/params/flow/BP+MS (H)/matching_lvl1_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS (H)/matching_lvl1_best.cpt -------------------------------------------------------------------------------- /data/params/flow/BP+MS (H)/matching_lvl2_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS (H)/matching_lvl2_best.cpt -------------------------------------------------------------------------------- /data/params/flow/BP+MS (H)/unary_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS (H)/unary_best.cpt -------------------------------------------------------------------------------- /data/params/flow/BP+MS+Ref (H)/affinity_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/affinity_best.cpt -------------------------------------------------------------------------------- /data/params/flow/BP+MS+Ref (H)/crf0_lvl0_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/crf0_lvl0_best.cpt -------------------------------------------------------------------------------- /data/params/flow/BP+MS+Ref (H)/crf0_lvl1_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/crf0_lvl1_best.cpt -------------------------------------------------------------------------------- /data/params/flow/BP+MS+Ref (H)/crf0_lvl2_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/crf0_lvl2_best.cpt -------------------------------------------------------------------------------- /data/params/flow/BP+MS+Ref (H)/matching_lvl0_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/matching_lvl0_best.cpt -------------------------------------------------------------------------------- /data/params/flow/BP+MS+Ref (H)/matching_lvl1_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/matching_lvl1_best.cpt -------------------------------------------------------------------------------- /data/params/flow/BP+MS+Ref (H)/matching_lvl2_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/matching_lvl2_best.cpt -------------------------------------------------------------------------------- /data/params/flow/BP+MS+Ref (H)/refinement_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/refinement_best.cpt -------------------------------------------------------------------------------- /data/params/flow/BP+MS+Ref (H)/unary_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/flow/BP+MS+Ref (H)/unary_best.cpt -------------------------------------------------------------------------------- /data/params/semantic/global_model.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/semantic/global_model.cpt -------------------------------------------------------------------------------- /data/params/semantic/pixel_model.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/semantic/pixel_model.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS (H)/affinity_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (H)/affinity_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS (H)/crf0_lvl0_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (H)/crf0_lvl0_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS (H)/crf0_lvl1_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (H)/crf0_lvl1_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS (H)/crf0_lvl2_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (H)/crf0_lvl2_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS (H)/matching_lvl0_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (H)/matching_lvl0_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS (H)/matching_lvl1_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (H)/matching_lvl1_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS (H)/matching_lvl2_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (H)/matching_lvl2_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS (H)/unary_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (H)/unary_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS (NLL)/affinity_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (NLL)/affinity_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS (NLL)/crf0_lvl0_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (NLL)/crf0_lvl0_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS (NLL)/crf0_lvl1_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (NLL)/crf0_lvl1_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS (NLL)/crf0_lvl2_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (NLL)/crf0_lvl2_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS (NLL)/matching_lvl0_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (NLL)/matching_lvl0_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS (NLL)/matching_lvl1_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (NLL)/matching_lvl1_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS (NLL)/matching_lvl2_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (NLL)/matching_lvl2_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS (NLL)/unary_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS (NLL)/unary_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS+Ref (H)/affinity_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/affinity_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS+Ref (H)/crf0_lvl0_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/crf0_lvl0_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS+Ref (H)/crf0_lvl1_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/crf0_lvl1_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS+Ref (H)/crf0_lvl2_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/crf0_lvl2_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS+Ref (H)/matching_lvl0_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/matching_lvl0_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS+Ref (H)/matching_lvl1_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/matching_lvl1_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS+Ref (H)/matching_lvl2_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/matching_lvl2_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS+Ref (H)/refinement_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/refinement_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/BP+MS+Ref (H)/unary_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/BP+MS+Ref (H)/unary_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/Kitti/affinity_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/Kitti/affinity_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/Kitti/crf0_lvl0_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/Kitti/crf0_lvl0_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/Kitti/crf0_lvl1_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/Kitti/crf0_lvl1_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/Kitti/crf0_lvl2_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/Kitti/crf0_lvl2_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/Kitti/matching_lvl0_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/Kitti/matching_lvl0_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/Kitti/matching_lvl1_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/Kitti/matching_lvl1_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/Kitti/matching_lvl2_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/Kitti/matching_lvl2_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/Kitti/unary_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/Kitti/unary_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/MB/affinity_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/MB/affinity_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/MB/crf0_lvl0_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/MB/crf0_lvl0_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/MB/crf0_lvl1_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/MB/crf0_lvl1_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/MB/crf0_lvl2_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/MB/crf0_lvl2_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/MB/matching_lvl0_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/MB/matching_lvl0_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/MB/matching_lvl1_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/MB/matching_lvl1_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/MB/matching_lvl2_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/MB/matching_lvl2_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/MB/unary_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/MB/unary_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/WTA (NLL)/matching_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/WTA (NLL)/matching_best.cpt -------------------------------------------------------------------------------- /data/params/stereo/WTA (NLL)/unary_best.cpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/params/stereo/WTA (NLL)/unary_best.cpt -------------------------------------------------------------------------------- /data/sf_0006_left.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/sf_0006_left.png -------------------------------------------------------------------------------- /data/sf_0006_right.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/data/sf_0006_right.png -------------------------------------------------------------------------------- /flow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from networks import FeatureNet, AffinityNet, RefinementNet 6 | from flow_matching import FlowMatchingSad 7 | from ops.lbp_stereo.bp_op_cuda import BP 8 | 9 | from corenet import CvConfidence#, LrCheck, LrDistance, 10 | from corenet import Pad, Unpad 11 | 12 | class FlowMethod(nn.Module): 13 | def __init__(self, device, args): 14 | nn.Module.__init__(self) 15 | self.args = args 16 | 17 | self._feature_net = None 18 | self._matching = [] 19 | self._affinity_net = None 20 | self._refinement_net = None 21 | self._crf = [] # list of bp layers 22 | 23 | self.cv_conf = CvConfidence(device).to(device) 24 | 25 | self._sws = args.sws 26 | 27 | self.pad = None 28 | self.unpad = None 29 | 30 | self.device = device 31 | 32 | def forward(self, I0_pyramid, I1_pyramid, beliefs_in=None, sws=None): 33 | 34 | # necessary for evaluation 35 | if self.pad is None: 36 | self.pad = Pad(self.feature_net.net.divisor, self.args.pad) 37 | if self.unpad is None: 38 | self.unpad = Unpad() 39 | 40 | res_dict = {'flow0': None} 41 | 42 | I0_in = I0_pyramid[self.args.input_level_offset].to(self.device) 43 | I1_in = I1_pyramid[self.args.input_level_offset].to(self.device) 44 | 45 | # pad input for multi-scale (for evaluation) 46 | I0_in = self.pad.forward(I0_in) 47 | I1_in = self.pad.forward(I1_in) 48 | 49 | f0_pyramid = self.extract_features(I0_in) 50 | f1_pyramid = self.extract_features(I1_in) 51 | 52 | # if sws is not None: 53 | # self.matching.sws = sws 54 | 55 | # multi-scale-matching 56 | prob_vol_pyramid = self.match(f0_pyramid, f1_pyramid) 57 | 58 | uflow0_pyramid = [] 59 | for pv0 in prob_vol_pyramid: 60 | uflow0_pyramid.append(torch.argmax(pv0, dim=-1)) 61 | res_dict['flow0'] = uflow0_pyramid 62 | 63 | if not self._crf: 64 | return res_dict 65 | 66 | affinity_pyramid = None 67 | if self.affinity_net: 68 | affinity_pyramid = self.extract_affinities(I0_in) 69 | for lvl in range(len(affinity_pyramid)): 70 | _, _, h, w = affinity_pyramid[lvl].shape 71 | affinity_pyramid[lvl] = affinity_pyramid[lvl].view((-1, 2, 5, h, w)) 72 | affinity_pyramid[lvl] = affinity_pyramid[lvl].unsqueeze(0) 73 | 74 | output_flow_pyramid = [] 75 | beliefs_pyramid = None 76 | 77 | crf_flow_pyramid = [] 78 | beliefs_pyramid = [] 79 | beliefs_in = None 80 | for lvl in reversed(range(len(prob_vol_pyramid))): 81 | pv_lvl = prob_vol_pyramid[lvl] 82 | m = self.matching[lvl] 83 | 84 | affinity = None 85 | if affinity_pyramid is not None: 86 | affinity = affinity_pyramid[lvl] 87 | crf = self.crf[lvl] 88 | 89 | # add probably an if condition whether do add multi-scale to crf 90 | if beliefs_in is not None: 91 | N,_,H,W,K = beliefs_in.shape 92 | size = (2*H, 2*W, 2*K-1) 93 | beliefs_in_u = F.interpolate(beliefs_in[:, 0].unsqueeze(1), size=size, mode='trilinear')[:, 0] 94 | beliefs_in_v = F.interpolate(beliefs_in[:, 1].unsqueeze(1), size=size, mode='trilinear')[:, 0] 95 | beliefs_in = torch.cat((beliefs_in_u.unsqueeze(1), beliefs_in_v.unsqueeze(1)), dim=1).contiguous() 96 | pv_lvl = pv_lvl + beliefs_in / 2.0 97 | 98 | flow_lvl, beliefs_lvl, affinities_lvl, offsets_lvl = self.optimize_crf(crf, pv_lvl, None, affinity, None) 99 | 100 | if lvl == 0: # TODO FOR EVAL!!!! 101 | beliefs_lvl = self.unpad(beliefs_lvl, self.pad.l, self.pad.r, self.pad.t, 102 | self.pad.b, NCHW=False) 103 | flow_lvl = self.unpad(flow_lvl, self.pad.l, self.pad.r, self.pad.t, 104 | self.pad.b) 105 | 106 | beliefs_pyramid.append(beliefs_lvl) 107 | crf_flow_pyramid.append(flow_lvl - m.sws // 2) 108 | 109 | beliefs_in = beliefs_pyramid[-1] 110 | 111 | # beliefs are from low res to high res 112 | beliefs_pyramid.reverse() 113 | crf_flow_pyramid.reverse() 114 | output_flow_pyramid = crf_flow_pyramid 115 | res_dict['flow0'] = crf_flow_pyramid 116 | 117 | if self.refinement_net: 118 | # crf 119 | cv_conf_u = self.cv_conf.forward(beliefs_pyramid[0][:,0].permute(0, 3, 1, 2), 120 | crf_flow_pyramid[0][:,0:1] + m.sws // 2) 121 | cv_conf_v = self.cv_conf.forward(beliefs_pyramid[0][:,1].permute(0, 3, 1, 2), 122 | crf_flow_pyramid[0][:,1:2] + m.sws // 2) 123 | 124 | conf_all = torch.cat((cv_conf_u, cv_conf_v), dim=1) 125 | refined_flow_pyramid, _ = self.refine_disps(I0_pyramid, 126 | crf_flow_pyramid[0], 127 | confidence=conf_all, 128 | I1=I1_pyramid) 129 | refined_flow_pyramid.reverse() 130 | output_flow_pyramid = refined_flow_pyramid 131 | 132 | res_dict['flow0'] = output_flow_pyramid 133 | 134 | return res_dict 135 | 136 | def extract_features(self, ipt): 137 | if self.feature_net: 138 | return self.feature_net.forward(ipt) 139 | return None 140 | 141 | def compute_guidance(self, ipt): 142 | if self.guidance_net: 143 | return self.guidance_net.forward(ipt) 144 | return None 145 | 146 | def extract_edges(self, ipt): 147 | if self.edge_net: 148 | return self.edge_net.forward(ipt) 149 | return None 150 | 151 | def extract_affinities(self, ipt): 152 | if self.affinity_net: 153 | return self.affinity_net.forward(ipt) 154 | return None 155 | 156 | def extract_offsets(self, ipt): 157 | if self.offset_net: 158 | return self.offset_net.forward(ipt) 159 | return None 160 | 161 | def match(self, f0, f1): 162 | prob_vols = [] 163 | if self.matching: 164 | for matching, f0s, f1s in zip(self.matching, f0, f1): 165 | prob_vols.append(matching.forward(f0s, f1s)) 166 | return prob_vols 167 | return None 168 | 169 | def optimize_crf(self, crf_layer, prob_vol, weights, affinities, offsets): 170 | if crf_layer: 171 | # iterate over all bp "layers" 172 | for idx, crf in enumerate(crf_layer): 173 | #TODO take care of BP layer idx in adjust functions 174 | prob_vol = prob_vol.contiguous() 175 | weights_input = crf.adjust_input_weights(weights, idx) 176 | affinities_shift = crf.adjust_input_affinities(affinities[:,idx]) 177 | offsets_shift = crf.adjust_input_offsets(offsets) 178 | 179 | disps, prob_vol, messages = crf.forward(prob_vol, weights_input, affinities_shift, offsets_shift) 180 | 181 | return disps, prob_vol, affinities_shift, offsets_shift 182 | return None 183 | 184 | def refine_disps(self, I0, d0, confidence=None, I1=None): 185 | if self.refinement_net: 186 | refined, steps = self.refinement_net.forward(I0, d0, confidence, I1) 187 | return refined, steps 188 | return None 189 | 190 | def feature_net_params(self, requires_grad=None): 191 | if self.feature_net: 192 | return self.feature_net.parameter_list(requires_grad) 193 | return [] 194 | 195 | def matching_params(self, requires_grad=None): 196 | params = [] 197 | if self.matching: 198 | for m in self.matching: 199 | params += m.parameter_list(requires_grad) 200 | return params 201 | 202 | def affinity_net_params(self, requires_grad=None): 203 | if self.affinity_net: 204 | return self.affinity_net.parameter_list(requires_grad) 205 | return [] 206 | 207 | def crf_params(self, requires_grad=None): 208 | crf_params = [] 209 | if self.crf: 210 | for crf_layer in self.crf: 211 | for crf in crf_layer: 212 | crf_params += crf.parameter_list(requires_grad) 213 | return crf_params 214 | 215 | def refinement_net_params(self, requires_grad=None): 216 | if self.refinement_net: 217 | return self.refinement_net.parameter_list(requires_grad) 218 | return [] 219 | 220 | @property 221 | def feature_net(self): 222 | return self._feature_net 223 | 224 | @property 225 | def affinity_net(self): 226 | return self._affinity_net 227 | 228 | @property 229 | def offset_net(self): 230 | return self._offset_net 231 | 232 | @property 233 | def crf(self): 234 | if self._crf == []: 235 | return None 236 | return self._crf 237 | 238 | @property 239 | def refinement_net(self): 240 | return self._refinement_net 241 | 242 | @property 243 | def matching(self): 244 | return self._matching 245 | 246 | @property 247 | def min_disp(self): 248 | return self._min_disp 249 | 250 | @property 251 | def max_disp(self): 252 | return self._max_disp 253 | 254 | 255 | #################################################################################################### 256 | # Block Match 257 | #################################################################################################### 258 | class BlockMatchFlow(FlowMethod): 259 | def __init__(self, device, args): 260 | FlowMethod.__init__(self, device, args) 261 | self._feature_net = FeatureNet(device, args) 262 | 263 | self._matching = [] 264 | for matching_lvl in range(self._feature_net.net.num_output_levels): 265 | sws = ((self._sws) // 2**matching_lvl) 266 | self._matching.append(FlowMatchingSad(device, args, sws, lvl=matching_lvl)) 267 | if args.matching != 'sad': 268 | print('WARNING: Use SAD matching for flow, but', args.matching, 'was chosen.') 269 | 270 | 271 | #################################################################################################### 272 | # Min-Sum LBP 273 | #################################################################################################### 274 | class MinSumFlow(BlockMatchFlow): 275 | def __init__(self, device, args): 276 | BlockMatchFlow.__init__(self, device, args) 277 | 278 | self.max_iter = args.max_iter 279 | num_labels = self._sws + 1 280 | 281 | self._affinity_net = AffinityNet(device, args) 282 | 283 | for lvl in range(self._feature_net.net.num_output_levels): 284 | self._crf.append([BP(device, args, self.max_iter, num_labels, 3, 285 | mode_inference = args.bp_inference, 286 | mode_message_passing='min-sum', layer_idx=idx, level=lvl) 287 | for idx in range(args.num_bp_layers)]) 288 | 289 | class RefinedMinSumFlow(MinSumFlow): 290 | def __init__(self, device, args): 291 | super(RefinedMinSumFlow, self).__init__(device, args) 292 | 293 | self._refinement_net = RefinementNet(device, args, in_channels=7, out_channels=2, with_output_relu=False) 294 | -------------------------------------------------------------------------------- /flow_matching.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import torch 3 | 4 | from networks import SubNetwork 5 | 6 | from corenet import TemperatureSoftmax 7 | from ops.flow_mp_sad.flow_mp_sad import FlowMpSadFunction 8 | 9 | class FlowMatching(SubNetwork): 10 | def __init__(self, device, args, sws, in_channels=None, lvl=0): 11 | super(FlowMatching, self).__init__(args, device) 12 | self.device = device 13 | self.sws = sws 14 | self.in_channels = in_channels 15 | self.level = lvl 16 | 17 | self._softmax = TemperatureSoftmax(dim=3, init_temp=1.0) 18 | 19 | def compute_score_volume(self, f0, f1): 20 | raise NotImplementedError 21 | 22 | def forward(self, f0, f1): 23 | score_vol_u, score_vol_v = self.compute_score_volume(f0, f1) 24 | prob_vol_u = self._softmax.forward(score_vol_u) 25 | prob_vol_v = self._softmax.forward(score_vol_v) 26 | prob_vol_uv = torch.cat((prob_vol_u.unsqueeze(1), prob_vol_v.unsqueeze(1)), dim=1) 27 | return prob_vol_uv.contiguous() 28 | 29 | @staticmethod 30 | def argmin_to_disp(argmin, min_disp): 31 | res = argmin + min_disp 32 | return res 33 | 34 | def save_checkpoint(self, epoch, iteration): 35 | if 'u' in self.args.train_params: 36 | torch.save(self.state_dict(), 37 | osp.join(self.args.train_dir, 'matching_lvl' + str(self.level) + 38 | '_checkpoint_' + str(epoch) + '_' + str(iteration).zfill(6) + '.cpt')) 39 | 40 | 41 | class FlowMatchingSad(FlowMatching): 42 | def __init__(self, device, args, sws, lvl=0): 43 | super(FlowMatchingSad, self).__init__(device, args, sws, lvl=lvl) 44 | 45 | if args.checkpoint_matching: # not-empty check 46 | lvl = min(self.level, len(args.checkpoint_matching) - 1) 47 | self.load_parameters(args.checkpoint_matching[lvl], device) 48 | self.to(device) 49 | 50 | def compute_score_volume(self, f0, f1): 51 | cv_u, cv_v, amin_u, amin_v = FlowMpSadFunction.apply(f0, f1, self.sws) 52 | return -cv_u, -cv_v 53 | -------------------------------------------------------------------------------- /flow_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.transform import resize 3 | from skimage.color import rgb2gray 4 | from scipy.ndimage import map_coordinates 5 | from scipy.ndimage.filters import correlate 6 | 7 | def makeColorWheel(): 8 | RY = 15 9 | YG = 6 10 | GC = 4 11 | CB = 11 12 | BM = 13 13 | MR = 6 14 | 15 | size = RY + YG + GC + CB + BM + MR 16 | colorwheel = np.zeros((3, size)) 17 | 18 | col = 0 19 | # RY 20 | colorwheel[0, col:col+RY] = 255 21 | colorwheel[1, col:col+RY] = np.floor(255 * np.arange(RY)/RY) 22 | col += RY 23 | 24 | # YG 25 | colorwheel[0, col:col+YG] = 255 - np.floor(255 * np.arange(YG)/YG) 26 | colorwheel[1, col:col+YG] = 255 27 | col += YG 28 | 29 | # GC 30 | colorwheel[1, col:col+GC] = 255 31 | colorwheel[2, col:col+GC] = np.floor(255 * np.arange(GC)/GC) 32 | col += GC 33 | 34 | # CB 35 | colorwheel[1, col:col+CB] = 255 - np.floor(255 * np.arange(CB)/CB) 36 | colorwheel[2, col:col+CB] = 255 37 | col += CB 38 | 39 | # BM 40 | colorwheel[0, col:col+BM] = np.floor(255 * np.arange(BM)/BM) 41 | colorwheel[2, col:col+BM] = 255 42 | col += BM 43 | 44 | # MR 45 | colorwheel[0, col:col+MR] = 255 46 | colorwheel[2, col:col+MR] = 255 - np.floor(255 * np.arange(MR)/MR) 47 | 48 | return colorwheel.astype('uint8') 49 | 50 | def computeNormalizedFlow(u, v, u_ref=None, v_ref=None, verbose=False): 51 | # copy to not overwrite the inputs 52 | u = u.copy() 53 | v = v.copy() 54 | 55 | eps = 1e-15 56 | UNKNOWN_FLOW_THRES = 1e9 57 | # UNKNOWN_FLOW = 1e10 58 | 59 | maxu = -999 60 | maxv = -999 61 | minu = 999 62 | minv = 999 63 | maxrad = -1 64 | 65 | # fix unknown flow 66 | idxUnknown = np.logical_or(np.abs(u) > UNKNOWN_FLOW_THRES, np.abs(v) > UNKNOWN_FLOW_THRES) 67 | u[idxUnknown] = 0 68 | v[idxUnknown] = 0 69 | 70 | maxu = np.maximum(maxu, np.max(u)) 71 | minu = np.minimum(minu, np.min(u)) 72 | 73 | maxv = np.maximum(maxv, np.max(v)) 74 | minv = np.minimum(minv, np.min(v)) 75 | 76 | if u_ref is not None and v_ref is not None: 77 | rad = np.sqrt(u_ref**2 + v_ref**2) 78 | else: 79 | rad = np.sqrt(u**2 + v**2) 80 | maxrad = np.maximum(maxrad, np.max(rad)) 81 | 82 | if verbose: 83 | print("max flow: ", maxrad, " flow range: u = ", minu, "..", maxu, "v = ", minv, "..", maxv) 84 | 85 | u = u / (maxrad + eps) 86 | v = v / (maxrad + eps) 87 | 88 | return u, v 89 | 90 | def computeFlowImg(u, v, u_ref=None, v_ref=None): 91 | # do not overwrite input flow! 92 | u = u.copy() 93 | v = v.copy() 94 | 95 | u, v = computeNormalizedFlow(u, v, u_ref, v_ref) 96 | 97 | nanIdx = np.logical_or(np.isnan(u), np.isnan(v)) 98 | u[nanIdx] = 0 99 | v[nanIdx] = 0 100 | 101 | cw = makeColorWheel().T 102 | 103 | M, N = u.shape 104 | img = np.zeros((M, N, 3)).astype('uint8') 105 | 106 | mag = np.sqrt(u**2 + v**2) 107 | 108 | phi = np.arctan2(-v, -u) / np.pi # [-1, 1] 109 | phi_idx = (phi + 1.0) / 2.0 * (cw.shape[0] - 1) 110 | f_phi_idx = np.floor(phi_idx).astype('int') 111 | 112 | c_phi_idx = f_phi_idx + 1 113 | c_phi_idx[c_phi_idx == cw.shape[0]] = 0 114 | 115 | floor = phi_idx - f_phi_idx 116 | 117 | for i in range(cw.shape[1]): 118 | tmp = cw[:, i] 119 | 120 | # linear blend between colors 121 | col0 = tmp[f_phi_idx] / 255.0 # from colorwheel take specified values in phi_idx 122 | col1 = tmp[c_phi_idx] / 255.0 123 | col = (1.0 - floor)*col0 + floor * col1 124 | 125 | # increase saturation for small magnitude 126 | sat_idx = mag <= 1 127 | col[sat_idx] = 1 - mag[sat_idx] * (1 - col[sat_idx]) 128 | 129 | col[np.logical_not(sat_idx)] = col[np.logical_not(sat_idx)] * 0.75 130 | 131 | img[:, :, i] = (np.floor(255.0*col*(1-nanIdx))).astype('uint8') 132 | return img 133 | 134 | -------------------------------------------------------------------------------- /github_imgs/flow_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/github_imgs/flow_example.png -------------------------------------------------------------------------------- /github_imgs/kitti.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/github_imgs/kitti.png -------------------------------------------------------------------------------- /github_imgs/mb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/github_imgs/mb.png -------------------------------------------------------------------------------- /github_imgs/sem_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/github_imgs/sem_input.png -------------------------------------------------------------------------------- /github_imgs/sem_pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/github_imgs/sem_pred.png -------------------------------------------------------------------------------- /github_imgs/sf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/github_imgs/sf.png -------------------------------------------------------------------------------- /github_imgs/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VLOGroup/bp-layers/226d67f73396dfd14f1cbb12445e542781256260/github_imgs/teaser.gif -------------------------------------------------------------------------------- /main_flow.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from flow import MinSumFlow, BlockMatchFlow, RefinedMinSumFlow 5 | import data 6 | from flow_tools import computeFlowImg 7 | 8 | import imageio 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--im0', action='store', default="data/sf_0006_left.png", required=False, type=str) 14 | parser.add_argument('--im1', action='store', default="data/sf_0006_right.png", required=False, type=str) 15 | parser.add_argument('--sws', action='store', default=108, type=int) 16 | parser.add_argument('--multi-level-output', action='store_true', default=False) 17 | parser.add_argument('--activation', action='store', choices=['relu', 'leakyrelu', 'elu'], default='leakyrelu') 18 | parser.add_argument('--with-bn', action='store_true', default=False) 19 | parser.add_argument('--with-upconv', action='store_true', default=False) 20 | parser.add_argument('--with-output-bn', action='store_true', default=False) 21 | parser.add_argument('--pad', action='store', default=(0, 0), nargs=2, type=int, 22 | help='extra padding of in height and in width on every side') 23 | 24 | parser.add_argument('--model', action='store', default='bp+ms+ref+h', 25 | choices=['wta', 'bp+ms', 'bp+ms+h', 'bp+ms+ref+h']) 26 | parser.add_argument('--checkpoint-unary', action='store', default=None, type=str) 27 | parser.add_argument('--checkpoint-matching', action='store', default=[], nargs='+', type=str) 28 | parser.add_argument('--checkpoint-affinity', action='store', default=None, type=str) 29 | parser.add_argument('--checkpoint-crf', action='append', default=[], type=str, nargs='+') 30 | parser.add_argument('--checkpoint-refinement', action='store', default=None, type=str) 31 | 32 | parser.add_argument('--lbp-min-disp', action='store_true', default=False) 33 | parser.add_argument('--max-iter', action='store', default=1, type=int) 34 | parser.add_argument('--num-bp-layers', action='store', default=1, type=int) 35 | parser.add_argument('--bp-inference', action='store', default='sub-exp', 36 | choices=['wta', 'expectation', 'sub-exp'], type=str) 37 | 38 | parser.add_argument('--matching', action='store', choices=['corr', 'sad', 'conv3d'], 39 | default='sad', type=str) 40 | 41 | parser.add_argument('--input-level-offset', action='store', default=1, type=int, 42 | help='1 means that level 1 is the input resolution') 43 | parser.add_argument('--output-level-offset', action='store', default=1, type=int, 44 | help="0 means that level 0 (=full res) is the output resolution") 45 | args = parser.parse_args() 46 | 47 | I0_pyramid, I1_pyramid = data.load_sample(args.im0, args.im1) 48 | 49 | args.multi_level_output = True 50 | 51 | device = 'cuda:0' 52 | with torch.no_grad(): 53 | if args.model == 'wta': 54 | model = BlockMatchFlow(device, args) 55 | elif args.model == 'bp+ms': 56 | model = MinSumFlow(device, args) 57 | elif args.model == 'bp+ms+h': 58 | model = MinSumFlow(device, args) 59 | elif args.model == 'bp+ms+ref+h': 60 | model = RefinedMinSumFlow(device, args) 61 | 62 | res_dict = model.to(device).forward(I0_pyramid, I1_pyramid, sws=args.sws) 63 | 64 | 65 | flow = res_dict['flow0'][0].squeeze().float().detach().cpu().numpy() 66 | 67 | gt_flow = data.readFlow("data/frame_0019.flo") # only used for normalizing the flow output! 68 | flow_img = computeFlowImg(flow[0], flow[1], gt_flow[0], gt_flow[1]) 69 | 70 | imageio.imwrite("data/output/flow/" + args.model + ".png", flow_img) -------------------------------------------------------------------------------- /main_semantic.py: -------------------------------------------------------------------------------- 1 | from semantic_segmentation import SemanticNet 2 | 3 | import argparse 4 | import os 5 | import yaml 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import matplotlib.pyplot as plt 10 | from imageio import imsave 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--img', action='store', required=True, type=str) 14 | parser.add_argument('--scale-imgs', action='store', default=0.5, type=float) 15 | parser.add_argument('--checkpoint-semantic', action='store', default=None, type=str) 16 | parser.add_argument('--checkpoint-esp-net', action='store', default=None, type=str) 17 | parser.add_argument('--num-labels', action='store', default=20, type=int) 18 | parser.add_argument('--activation', action='store', choices=['relu', 'leakyrelu', 'elu'], default='leakyrelu') 19 | parser.add_argument('--num-bp-layers', action='store', default=1, type=int) 20 | parser.add_argument('--with-bn', action='store_true', default=True) 21 | parser.add_argument('--with-upconv', action='store_true', default=False) 22 | parser.add_argument('--with-output-bn', action='store_true', default=False) 23 | parser.add_argument('--pad', action='store', default=(0, 0), nargs=2, type=int, help='extra padding of in height and in width on every side') 24 | parser.add_argument('--pairwise-type', action='store', choices=["global", "pixel"], default="global") 25 | parser.add_argument('--multi-level-features', action='store_true', default=False) 26 | parser.add_argument('--with-esp', action='store_true', default=False) 27 | parser.add_argument('--with-edges', action='store_true', default=False) 28 | parser.add_argument('--multi-level-output', action='store_true', default=False) 29 | 30 | 31 | args = parser.parse_args() 32 | 33 | cuda_device = 'cuda:0' 34 | 35 | semantic_model = SemanticNet(cuda_device, args) 36 | 37 | # read input image 38 | test_img_path = args.img 39 | 40 | test_img = cv2.imread(test_img_path).astype(np.float32) 41 | 42 | # normalize input and convert to torch 43 | mean = np.array([72.39231, 82.908936, 73.1584]) 44 | std = np.array([45.31922, 46.152893, 44.914833]) 45 | 46 | test_img -= mean[np.newaxis, np.newaxis, :] 47 | test_img /= std[np.newaxis, np.newaxis, :] 48 | 49 | height, width, _ = test_img.shape 50 | scaled_height = int(height * args.scale_imgs) 51 | scaled_width = int(width * args.scale_imgs) 52 | 53 | test_img = cv2.resize(test_img, (scaled_width, scaled_height)) 54 | 55 | test_img_torch = torch.from_numpy(test_img).to(device=cuda_device).permute(2, 0, 1).unsqueeze(0) / 255.0 56 | 57 | # run model 58 | sem_pred, _, _ = semantic_model.forward(test_img_torch) 59 | 60 | # visualize/save result 61 | 62 | ########## cityscapes visualization from ESP Net: https://github.com/sacmehta/ESPNet ########################### 63 | label_colors = [[128, 64, 128], 64 | [244, 35, 232], 65 | [70, 70, 70], 66 | [102, 102, 156], 67 | [190, 153, 153], 68 | [153, 153, 153], 69 | [250, 170, 30], 70 | [220, 220, 0], 71 | [107, 142, 35], 72 | [152, 251, 152], 73 | [70, 130, 180], 74 | [220, 20, 60], 75 | [255, 0, 0], 76 | [0, 0, 142], 77 | [0, 0, 70], 78 | [0, 60, 100], 79 | [0, 80, 100], 80 | [0, 0, 230], 81 | [119, 11, 32], 82 | [0, 0, 0]] 83 | 84 | sem_pred_np = sem_pred.squeeze().byte().cpu().data.numpy() 85 | sem_pred_np_color = np.zeros((sem_pred_np.shape[0], sem_pred_np.shape[1], 3), dtype=np.uint8) 86 | 87 | for label in range(len(label_colors)): 88 | sem_pred_np_color[sem_pred_np == label] = label_colors[label] 89 | 90 | imsave("data/output/semantic/sem_pred.png", sem_pred_np_color) 91 | 92 | # plt.figure() 93 | # plt.imshow(sem_pred_np_color) 94 | # plt.show() 95 | 96 | -------------------------------------------------------------------------------- /main_stereo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from stereo import MinSumStereo, BlockMatchStereo, RefinedMinSumStereo 5 | import data 6 | 7 | import imageio 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--im0', action='store', required=True, type=str) 13 | parser.add_argument('--im1', action='store', required=True, type=str) 14 | parser.add_argument('--min-disp', action='store', default=0, type=int) 15 | parser.add_argument('--max-disp', action='store', default=127, type=int) 16 | parser.add_argument('--stride-in', action='store', default=1, type=int) 17 | parser.add_argument('--stride-out', action='store', default=1, type=int) 18 | parser.add_argument('--multi-level-output', action='store_true', default=False) 19 | parser.add_argument('--activation', action='store', choices=['relu', 'leakyrelu', 'elu'], default='leakyrelu') 20 | parser.add_argument('--with-bn', action='store_true', default=False) 21 | parser.add_argument('--with-upconv', action='store_true', default=False) 22 | parser.add_argument('--with-output-bn', action='store_true', default=False) 23 | parser.add_argument('--pad', action='store', default=(0, 0), nargs=2, type=int, 24 | help='extra padding of in height and in width on every side') 25 | 26 | parser.add_argument('--model', action='store', default='bp+ms+h', 27 | choices=['wta', 'bp+ms', 'bp+ms+h', 'bp+ms+ref+h']) 28 | parser.add_argument('--checkpoint-unary', action='store', default=None, type=str) 29 | parser.add_argument('--checkpoint-matching', action='store', default=[], nargs='+', type=str) 30 | parser.add_argument('--checkpoint-affinity', action='store', default=None, type=str) 31 | parser.add_argument('--checkpoint-crf', action='append', default=[], type=str, nargs='+') 32 | parser.add_argument('--checkpoint-refinement', action='store', default=None, type=str) 33 | 34 | parser.add_argument('--lbp-min-disp', action='store_true', default=False) 35 | parser.add_argument('--max-iter', action='store', default=1, type=int) 36 | parser.add_argument('--num-bp-layers', action='store', default=1, type=int) 37 | parser.add_argument('--bp-inference', action='store', default='sub-exp', 38 | choices=['wta', 'expectation', 'sub-exp'], type=str) 39 | 40 | parser.add_argument('--matching', action='store', choices=['corr', 'sad', 'conv3d'], 41 | default='sad', type=str) 42 | 43 | parser.add_argument('--input-level-offset', action='store', default=1, type=int, 44 | help='1 means that level 1 is the input resolution') 45 | parser.add_argument('--output-level-offset', action='store', default=1, type=int, 46 | help="0 means that level 0 (=full res) is the output resolution") 47 | args = parser.parse_args() 48 | 49 | I0_pyramid, I1_pyramid = data.load_sample(args.im0, args.im1) 50 | 51 | device = 'cuda:0' 52 | with torch.no_grad(): 53 | if args.model == 'wta': 54 | model = BlockMatchStereo(device, args) 55 | elif args.model == 'bp+ms': 56 | model = MinSumStereo(device, args) 57 | elif args.model == 'bp+ms+h': 58 | model = MinSumStereo(device, args) 59 | elif args.model == 'bp+ms+ref+h': 60 | model = RefinedMinSumStereo(device, args) 61 | 62 | max_disp = None # use original max-disp 63 | res_dict = model.to(device).forward(I0_pyramid, I1_pyramid, max_disp=args.max_disp, step=1) 64 | 65 | 66 | imageio.imwrite("data/output/stereo/" + args.model + ".pfm", 67 | np.flipud(res_dict['disps0'][0].squeeze().float().detach().cpu().numpy())) -------------------------------------------------------------------------------- /matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from networks import SubNetwork 3 | from corenet import TemperatureSoftmax 4 | 5 | from ops.sad.stereo_sad import StereoMatchingSadFunction 6 | 7 | class StereoMatching(SubNetwork): 8 | def __init__(self, device, args, min_disp, max_disp, in_channels=None, lvl=0, step=1.0): 9 | super(StereoMatching, self).__init__(args, device) 10 | self.device = device 11 | self.min_disp = min_disp 12 | self.max_disp = max_disp 13 | self.step = step 14 | self.in_channels = in_channels 15 | self.level = lvl 16 | 17 | self._softmax = TemperatureSoftmax(dim=3, init_temp=1.0) 18 | 19 | def compute_score_volume(self, f0, f1): 20 | raise NotImplementedError 21 | 22 | def forward(self, f0, f1): 23 | score_vol = self.compute_score_volume(f0, f1) 24 | prob_vol = self._softmax.forward(score_vol) 25 | return prob_vol.contiguous() 26 | 27 | @staticmethod 28 | def argmin_to_disp(argmin, min_disp): 29 | res = argmin + min_disp 30 | return res 31 | 32 | def save_checkpoint(self, epoch, iteration): 33 | pass 34 | 35 | class StereoMatchingSad(StereoMatching): 36 | def __init__(self, device, args, min_disp, max_disp, lvl=0, step=1.0): 37 | super(StereoMatchingSad, self).__init__(device, args, min_disp, max_disp, lvl=lvl, step=step) 38 | 39 | self.load_parameters(args.checkpoint_matching[self.level], device) 40 | self.to(device) 41 | 42 | def compute_score_volume(self, f0, f1): 43 | cost_vol = StereoMatchingSadFunction.apply(f0, f1, self.min_disp, self.max_disp, self.step) 44 | return -cost_vol -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import os.path as osp 6 | 7 | from corenet import StereoUnaryUnetDyn, PadUnpad, PaddedConv2d, ResidualBlock 8 | 9 | class SubNetwork(nn.Module): 10 | def __init__(self, args, device): 11 | super(SubNetwork, self).__init__() 12 | self.net = None 13 | self.pad_unpad = None 14 | self.args = args 15 | self.device = device 16 | 17 | def forward(self, ipt): 18 | return self.pad_unpad.forward(ipt) 19 | 20 | def freeze_parameters(self): 21 | self.eval() 22 | for p in self.parameters(): 23 | p.requires_grad = False 24 | 25 | def parameter_list(self, requires_grad=None): 26 | def check(var): 27 | if requires_grad is None: 28 | return True 29 | elif requires_grad is True and var is True: 30 | return True 31 | elif requires_grad is False and var is False: 32 | return True 33 | else: return False 34 | 35 | params = [] 36 | for p in self.parameters(): 37 | if check(p.requires_grad): 38 | params.append(p) 39 | return params 40 | 41 | def load_parameters(self, path, device=None): 42 | if path is not None: 43 | if not osp.exists(path): 44 | raise FileNotFoundError("Specified unary checkpoint file does not exist!" + path) 45 | 46 | if device is not None: 47 | checkpoint = torch.load(path, map_location=device) 48 | else: 49 | checkpoint = torch.load(path) 50 | 51 | checkpoint = self.hook_adjust_checkpoint(checkpoint) 52 | 53 | self.load_state_dict(checkpoint) 54 | print('Successfully loaded checkpoint %s' % path) 55 | 56 | def hook_adjust_checkpoint(self, checkpoint): 57 | return checkpoint 58 | 59 | def save_checkpoint(self, epoch, iteration): 60 | raise NotImplementedError 61 | 62 | @property 63 | def divisor(self): 64 | return self.net.divisor 65 | 66 | @property 67 | def pad_input(self): 68 | return self._pad_input 69 | 70 | class FeatureNet(SubNetwork): 71 | def __init__(self, device, args): 72 | super(FeatureNet, self).__init__(args, device) 73 | 74 | self.net = StereoUnaryUnetDyn(args.multi_level_output, 75 | args.activation, args.with_bn, args.with_upconv, args.with_output_bn) 76 | 77 | # provided pad-unpad 78 | self.pad_unpad = PadUnpad(self.net, self.net.divisor, tuple(args.pad)) 79 | 80 | # load params 81 | self.load_parameters(args.checkpoint_unary, device) 82 | 83 | 84 | class AffinityNet(SubNetwork): 85 | def __init__(self, device, args): 86 | super(AffinityNet, self).__init__(args, device) 87 | 88 | self.net = StereoUnaryUnetDyn(args.multi_level_output, 89 | args.activation, args.with_bn, args.with_upconv, args.with_output_bn) 90 | 91 | 92 | out_channel_factor = args.num_bp_layers 93 | out_channels = out_channel_factor * 2 * 5 # 2 directions, 5 values # L2-, L2+, L1-, L1+, L3 94 | self.conv_out = [PaddedConv2d(ic, out_channels, 3, bias=True).to(device) for ic in \ 95 | self.net.out_channels] 96 | for lvl, conv in enumerate(self.conv_out): 97 | self.add_module('conv_out_lvl' + str(lvl), conv) 98 | 99 | # provide pad-unpad 100 | self.pad_unpad = PadUnpad(self.net, self.net.divisor, tuple(args.pad)) 101 | self.to(device) 102 | 103 | # load params 104 | self.load_parameters(args.checkpoint_affinity, device) 105 | 106 | def forward(self, ipt): 107 | features = super(AffinityNet, self).forward(ipt) 108 | affinities = [torch.abs(conv(fi)) for conv, fi in zip(self.conv_out, features)] # outshape = N x 2 * 5 x H x W 109 | 110 | return affinities 111 | 112 | class RefinementNet(SubNetwork): 113 | def __init__(self, device, args, in_channels=5, out_channels=1, with_output_relu=True): 114 | super(RefinementNet, self).__init__(args, device) 115 | 116 | self.net = [] 117 | self.with_output_relu = with_output_relu 118 | 119 | for lvl in range(self.args.input_level_offset, self.args.output_level_offset - 1, -1): 120 | net = nn.Sequential( 121 | PaddedConv2d(in_channels, 32, 3, bias=True), 122 | ResidualBlock(32, 32, 3, dilation=1), 123 | ResidualBlock(32, 32, 3, dilation=2), 124 | ResidualBlock(32, 32, 3, dilation=4), 125 | ResidualBlock(32, 32, 3, dilation=8), 126 | ResidualBlock(32, 32, 3, dilation=1), 127 | ResidualBlock(32, 32, 3, dilation=1), 128 | PaddedConv2d(32, out_channels, 3, bias=True), 129 | ) 130 | self.net.append(net) 131 | self.add_module('sn_' + str(lvl), net) 132 | 133 | self.to(device) 134 | self.relu = nn.ReLU(inplace=True) 135 | 136 | # load params 137 | self.load_parameters(args.checkpoint_refinement, device) 138 | 139 | 140 | def forward(self, I0_pyramid, d0, confidence, I1_pyramid): 141 | d0_lvl = d0.clone() 142 | refined_pyramid = [] 143 | residuum_pyramid = [] 144 | for ref_lvl in range(self.args.input_level_offset, self.args.output_level_offset -1, -1): 145 | I0_lvl = I0_pyramid[ref_lvl].to(self.device) 146 | I0_lvl = I0_lvl / I0_lvl.var(dim=(2,3), keepdim=True) 147 | 148 | # adapt input size 149 | scale_factor = I0_lvl.shape[2] / d0_lvl.shape[2] 150 | if abs(scale_factor - round(scale_factor)) > 0: 151 | print('WARNING: something weird is going on, got a fractional scale-factor in ref', scale_factor) 152 | d0_up = F.interpolate(d0_lvl.float(), size=I0_lvl.shape[2:], mode='nearest') * scale_factor 153 | conf_up = F.interpolate(confidence.float(), size=I0_lvl.shape[2:], mode='nearest') 154 | 155 | # compute input tensor 156 | ipt = torch.cat((I0_lvl, d0_up, conf_up), dim=1) 157 | residuum = self.net[ref_lvl - self.args.output_level_offset].forward(ipt) 158 | 159 | d0_lvl = d0_up.float() + residuum # flow 160 | if self.with_output_relu: 161 | d0_lvl = self.relu(d0_lvl) # stereo 162 | refined_pyramid.append(d0_lvl) 163 | residuum_pyramid.append(residuum) 164 | 165 | return refined_pyramid, [residuum_pyramid] 166 | 167 | 168 | class EdgeNet(SubNetwork): 169 | def __init__(self, device, args): 170 | super(EdgeNet, self).__init__(args, device) 171 | 172 | self.net = StereoUnaryUnetDyn(args.multi_level_output, args.activation, args.with_bn, args.with_upconv, args.with_output_bn) 173 | 174 | out_channels = args.num_bp_layers * 2 175 | self.conv_out = [PaddedConv2d(ic, out_channels, 3, bias=True).to(device) for ic in \ 176 | self.net.out_channels] 177 | for lvl, conv in enumerate(self.conv_out): 178 | self.add_module('conv_out_lvl' + str(lvl), conv) 179 | 180 | # provide pad-unpad 181 | self.pad_unpad = PadUnpad(self.net, self.net.divisor, tuple(args.pad)) 182 | self.to(device) 183 | 184 | def forward(self, ipt): 185 | features = super(EdgeNet, self).forward(ipt) 186 | edge_weights = [torch.abs(conv(fi)) for conv, fi in zip(self.conv_out, features)] 187 | return edge_weights 188 | 189 | 190 | class PWNet(SubNetwork): 191 | def __init__(self, device, args): 192 | super(PWNet, self).__init__(args, device) 193 | 194 | self.net = StereoUnaryUnetDyn(args.multi_level_output, 195 | args.activation, args.with_bn, args.with_upconv, args.with_output_bn) 196 | 197 | out_channels = args.num_bp_layers * 2 * args.num_labels * args.num_labels # 2 directions, 5 values # L2-, L2+, L1-, L1+, L3 198 | 199 | self.conv_out = [PaddedConv2d(ic, out_channels, 3, bias=True).to(device) for ic in \ 200 | self.net.out_channels] 201 | for lvl, conv in enumerate(self.conv_out): 202 | self.add_module('conv_out_lvl' + str(lvl), conv) 203 | 204 | # provide pad-unpad 205 | self.pad_unpad = PadUnpad(self.net, self.net.divisor, tuple(args.pad)) 206 | self.to(device) 207 | 208 | def forward(self, ipt): 209 | features = super(PWNet, self).forward(ipt) 210 | 211 | pairwise_costs = [torch.abs(conv(fi)) for conv, fi in zip(self.conv_out, features)] # outshape = N x 2 * 5 x H x W 212 | 213 | return pairwise_costs -------------------------------------------------------------------------------- /ops/flow_mp_sad/flow_mp_sad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import pytorch_cuda_flow_mp_sad_op 4 | 5 | class FlowMpSadFunction(torch.autograd.Function): 6 | @staticmethod 7 | def forward(ctx, f0, f1, sws): 8 | # shape = N x H x W x K => 2x 3D cost volume for u (=x dir) and v (=y dir) 9 | offset_u = offset_v = 0 10 | block_u = block_v = 0 11 | if sws <= 108: 12 | #if False: 13 | cv_u, cv_v, u_star, v_star = pytorch_cuda_flow_mp_sad_op.forward(f0, f1, sws, offset_u, offset_v, block_u, block_v) 14 | elif sws == 2*108: 15 | #else: 16 | #print('Hello -> split-up flow SAD') 17 | s = sws // 2 # new sub-search-window-size 18 | cv00_u, cv00_v, u00_star, v00_star = pytorch_cuda_flow_mp_sad_op.forward(f0, f1, s, -s//2, -s//2, 0, 0)#x0y0 19 | cv01_u, cv01_v, u01_star, v01_star = pytorch_cuda_flow_mp_sad_op.forward(f0, f1, s, -s//2, s//2, 0, 1)#x0y1 20 | cv10_u, cv10_v, u10_star, v10_star = pytorch_cuda_flow_mp_sad_op.forward(f0, f1, s, s//2, -s//2, 1, 0)#x1y0 21 | cv11_u, cv11_v, u11_star, v11_star = pytorch_cuda_flow_mp_sad_op.forward(f0, f1, s, s//2, s//2, 1, 1)#x1y1 22 | 23 | # ref 24 | #cv_u_ref, cv_v_ref, u_star_ref, v_star_ref = pytorch_cuda_flow_mp_sad_op.forward(f0, f1, sws, 0,0,0,0) 25 | 26 | # merge sub-volums to one volume 27 | u0 = torch.cat((cv00_u, cv10_u[:,:,:,1:]), dim=-1) 28 | u1 = torch.cat((cv01_u, cv11_u[:,:,:,1:]), dim=-1) 29 | au0 = torch.cat((u00_star, u10_star[:,:,:,1:]), dim=-1) 30 | au1 = torch.cat((u01_star, u11_star[:,:,:,1:]), dim=-1) 31 | 32 | idx = u1 < u0 33 | u0[idx] = u1[idx] # overwrite better values 34 | au0[idx] = au1[idx] 35 | 36 | 37 | v0 = torch.cat((cv00_v, cv01_v[:,:,:,1:]), dim=-1) 38 | v1 = torch.cat((cv10_v, cv11_v[:,:,:,1:]), dim=-1) 39 | av0 = torch.cat((v00_star, v01_star[:,:,:,1:]), dim=-1) 40 | av1 = torch.cat((v10_star, v11_star[:,:,:,1:]), dim=-1) 41 | 42 | idx = v1 < v0 43 | v0[idx] = v1[idx] # overwrite better valves 44 | av0[idx] = av1[idx] 45 | 46 | cv_u = u0 47 | u_star = au0 48 | cv_v = v0 49 | v_star = av0 50 | elif sws == 432: # sintel all!! 51 | # global offset 52 | go_u = 0 53 | go_v = 0 54 | 55 | # read cost-volume block u 56 | # index with row and col 57 | cv_bu = [[], []] 58 | cv_bv = [[], []] 59 | bu_star = [[], []] 60 | bv_star = [[], []] 61 | 62 | s = 108 # search-window size of block 63 | # iterate over 4 x 4 grid 64 | for idx_v, grid_v in enumerate([-2, -1, 1, -2]): # grid-position 65 | for idx_u, grid_u in enumerate([-2, -1, 1, -2]): 66 | ro_u = (grid_u * 2 - 1) * s // 2 # relative offset in x 67 | ro_v = (grid_v * 2 - 1) * s // 2 # relative offset in y 68 | 69 | cv_uu, cv_vv, uu_star, vv_star = pytorch_cuda_flow_mp_sad_op.forward(f0, f1, s, 70 | ro_u + go_u, ro_v + go_v, idx_u, idx_v) 71 | cv_bu[idx_u].append(cv_uu) 72 | cv_bv[idx_u].append(cv_vv) 73 | bu_star[idx_u].append(uu_star) 74 | bv_star[idx_v].append(vv_star) 75 | 76 | # stitch complete cv_u 77 | # read: u0 for all v, u_star for all v 78 | u0_v = [cv_bu[0][idx_v] for idx_v in range(4)] 79 | u_star_v = [bu_star[0][idx_v] for idx_v in range(4)] 80 | for idx_u in range(1, len(cv_bu[0])): 81 | for idx_v in range(4): 82 | # increment u, keepp v 83 | u0_v[idx_v] = torch.cat((u0_v[idx_v], cv_bu[idx_u][idx_v][:,:,:,1:]), dim=-1) 84 | u_star_v[idx_v] = torch.cat((u_star_v[idx_v], bu_star[idx_u][idx_v][:,:,:,1:]), dim=-1) 85 | 86 | 87 | 88 | 89 | #elif sws == 372: # kitti 90 | # ATTENTION: FORWARD WOULD WORK NICELY LIKE THAT, BUT BACKWARD WOULD NEED THE ORIGINAL SEARCH WINDOW!! 91 | # PROBABLY NOT A GOOD IDEA TO CHANGE THIS NOW ... 92 | # # 2 x 6 grid with size 372 x 124 with block-size 62 and offset (+18, +36) 93 | 94 | # # global offset 95 | # go = np.array([18, 36]) 96 | 97 | # # read cost-volume block u 98 | # # index with row and col 99 | # cv_bu = [[], []] 100 | # cv_bv = [[], []] 101 | # bu_star = [[], []] 102 | # bv_star = [[], []] 103 | 104 | # # iterate over 2 x 6 grid 105 | # for grid_v in range(-1, 2): 106 | # for grid_u in range(-3, 4): 107 | # cv_uu, cv_vv, uu_star, vv_star = pytorch_cuda_flow_mp_sad_op.forward(f0, f1, s, -s//2, -s//2, 0, 0)#x0y0 108 | else: 109 | raise ValueError("Unsupported sws: ", sws, "only allowd: <= 108 or 216") 110 | 111 | ctx.save_for_backward(f0, f1, torch.tensor(sws), u_star, v_star) 112 | return cv_u, cv_v, u_star, v_star 113 | 114 | @staticmethod 115 | def backward(ctx, in_grad_u, in_grad_v, u_star_grad, v_star_grad): 116 | # u_star_grad and v_star_grad are just zeros. 117 | f0, f1, sws, u_star, v_star = ctx.saved_tensors 118 | df0, df1 = pytorch_cuda_flow_mp_sad_op.backward(f0, f1, int(sws), in_grad_u, in_grad_v, u_star, v_star) 119 | return df0, df1, None, None 120 | -------------------------------------------------------------------------------- /ops/flow_mp_sad/src/flow_mp_sad.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "flow_mp_sad_kernel.cuh" 3 | 4 | // C++ interface 5 | // AT_ASSERTM in pytorch 1.0 6 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 7 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 8 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x); 9 | 10 | std::vector flow_mp_sad_forward(at::Tensor f0, at::Tensor f1, int sws, int offset_u, int offset_v, 11 | int blockIdx_u, int blockIdx_v) 12 | { 13 | CHECK_INPUT(f0) 14 | CHECK_INPUT(f1) 15 | return cuda::flow_mp_sad_forward(f0, f1, sws, offset_u, offset_v, blockIdx_u, blockIdx_v); 16 | } 17 | 18 | std::vector flow_mp_sad_backward(at::Tensor f0, at::Tensor f1, int sws, 19 | at::Tensor in_grad_u, at::Tensor in_grad_v, 20 | at::Tensor u_star, at::Tensor v_star) 21 | { 22 | CHECK_INPUT(f0) 23 | CHECK_INPUT(f1) 24 | CHECK_INPUT(in_grad_u) 25 | CHECK_INPUT(in_grad_v) 26 | CHECK_INPUT(u_star) 27 | CHECK_INPUT(v_star) 28 | 29 | return cuda::flow_mp_sad_backward(f0, f1, sws, in_grad_u, in_grad_v, u_star, v_star); 30 | } 31 | 32 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 33 | { 34 | m.def("forward", &flow_mp_sad_forward, "Flow correlation Matching forward (CUDA)"); 35 | m.def("backward", &flow_mp_sad_backward, "Flow correlation Matching backward (CUDA)"); 36 | } -------------------------------------------------------------------------------- /ops/flow_mp_sad/src/flow_mp_sad_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "flow_mp_sad_kernel.cuh" 4 | #include "tensor.h" 5 | #include "error_util.h" 6 | 7 | // ============================================================================ 8 | // CUDA KERNELS 9 | // ============================================================================ 10 | __global__ void flow_mp_sad_cuda_forward_kernel( 11 | KernelData f0, 12 | KernelData f1, 13 | int sws, 14 | KernelData cv_u, 15 | KernelData cv_v, 16 | KernelData u_star, 17 | KernelData v_star, 18 | int offset_u, 19 | int offset_v, 20 | int blockIdx_u, // necessary for argmin computation 21 | int blockIdx_v // same here 22 | ) 23 | { 24 | // parallelize over u, loop over v 25 | //const int x = blockIdx.x * blockDim.x + threadIdx.x; 26 | const int y = blockIdx.y * blockDim.y + threadIdx.y; 27 | // const int u_idx = blockIdx.z * blockDim.z + threadIdx.z; 28 | 29 | const int u_idx = blockIdx.x * blockDim.x + threadIdx.x; 30 | const int x = blockIdx.z * blockDim.z + threadIdx.z; 31 | 32 | // shared memory for search-window matching costs 33 | extern __shared__ float sdata[]; 34 | 35 | // global defines 36 | unsigned short K = cv_u.size3; 37 | short sws_half = sws / 2; 38 | short u = u_idx - sws_half; 39 | 40 | //short sm_offset = blockDim.x * K * K * threadIdx.y + K * K * threadIdx.x; 41 | short sm_offset = blockDim.z * K * K * threadIdx.y + K * K * threadIdx.z; 42 | 43 | // check inside image for reference pixel 44 | int n = 0; 45 | if(x >= f0.size3 || y >= f0.size2 || u_idx >= K) 46 | return; 47 | 48 | // initialize all sws with constant value (initialize all v displacements for given u_idx) 49 | for(short v_idx = 0; v_idx < K; ++v_idx) 50 | { 51 | sdata[sm_offset + K * v_idx + u_idx] = 40.0; 52 | } 53 | __syncthreads(); 54 | 55 | // skip outside pixels 56 | // if(x + u < 0 || x + u >= f0.size3) 57 | // return; 58 | 59 | // I cannot return outside pixels directly, because I need all the treads for the min-computation 60 | // later!! 61 | if(x + offset_u + u >= 0 && x + offset_u + u < f0.size3) // check match inside 62 | { 63 | for(short v = -sws_half; v <= sws_half; ++v) 64 | { 65 | short v_idx = v + sws_half; 66 | 67 | // skip outside pixels (match-pixel) 68 | if(y + offset_v + v < 0 || y + offset_v + v >= f0.size2) 69 | continue; 70 | 71 | float sad = 0.0f; 72 | for(int c = 0; c < f0.size1; ++c) 73 | { 74 | sad += fabs(f0(n, c, y, x) - f1(n, c, y + offset_v + v, x + offset_u + u)); 75 | } 76 | 77 | // save result to shared mem 78 | sdata[sm_offset + K * v_idx + u_idx] = sad; 79 | //cv_all(n, y, x, v_idx, u_idx) = sad; 80 | } 81 | } 82 | __syncthreads(); // all u-threads must be ready here! 83 | 84 | // compute min-projection in shared memory 85 | // Note: u_idx is parallelized within the kernel! 86 | float min_v = 9999999.0; 87 | short argmin_v = 0; 88 | for(unsigned short v_idx = 0; v_idx < K; ++v_idx) 89 | { 90 | if(sdata[sm_offset + K * v_idx + u_idx] < min_v) 91 | { 92 | min_v = sdata[sm_offset + K * v_idx + u_idx]; 93 | argmin_v = v_idx; 94 | } 95 | } 96 | 97 | // update min only if the current block has a better min 98 | // if(min_v < cv_u(n, y, x, u_idx)) // for inplace variant which I do not have yet 99 | //{ 100 | cv_u(n, y, x, u_idx) = min_v; 101 | u_star(n, y, x, u_idx) = argmin_v + blockIdx_v * sws; // sws = K - 1 => default overlap 102 | //} 103 | 104 | // compute min-projection in shared memory 105 | // here I swap rules and use the u_idx as v_idx for easier parallelization 106 | float min_u = 9999999.0; 107 | short v_idx = u_idx; 108 | short argmin_u = 0; 109 | for(unsigned short u_idx = 0; u_idx < K; ++u_idx) 110 | { 111 | if(sdata[sm_offset + K * v_idx + u_idx] < min_u) 112 | { 113 | min_u = sdata[sm_offset + K * v_idx + u_idx]; 114 | argmin_u = u_idx; 115 | } 116 | } 117 | 118 | // update min only if the current block has a better min 119 | //if(min_u < cv_v(n, y, x, v_idx)) // for inplace variant which I do not have yet 120 | //{ 121 | cv_v(n, y, x, v_idx) = min_u; 122 | v_star(n, y, x, v_idx) = argmin_u + blockIdx_u * sws; // sws = K - 1 => default overlap 123 | //} 124 | } 125 | 126 | __global__ void flow_mp_sad_cuda_backward_kernel( 127 | KernelData f0, 128 | KernelData f1, 129 | int sws, 130 | KernelData in_grad_u, 131 | KernelData in_grad_v, 132 | KernelData u_star, 133 | KernelData v_star, 134 | KernelData df0, 135 | KernelData df1 136 | ) 137 | { 138 | const unsigned int x = blockIdx.x * blockDim.x + threadIdx.x; 139 | const unsigned int y = blockIdx.y * blockDim.y + threadIdx.y; 140 | const unsigned int c = blockIdx.z * blockDim.z + threadIdx.z; 141 | 142 | float eps = 1e-15; 143 | 144 | // check inside image 145 | int n = 0; 146 | if(x >= f0.size3 || y >= f0.size2 || c >= f0.size1) 147 | return; 148 | 149 | int sws_half = sws / 2; 150 | 151 | float grad_f0 = 0.0f; 152 | float grad_f1 = 0.0f; 153 | for(short u = -sws_half; u <= sws_half; ++u) 154 | { 155 | short u_idx = u + sws_half; 156 | short v_idx = u_star(n, y, x, u_idx); 157 | short v = v_idx - sws_half; 158 | 159 | // skip outside pixels 160 | if(x + u >= 0 && x + u < f0.size3 && y + v >= 0 && y + v < f0.size2) 161 | { 162 | float diff = f0(n, c, y, x) - f1(n, c, y + v, x + u); 163 | if(fabsf(diff) > eps) // gradient is zero if diff is zero! 164 | { 165 | float update = diff / fabsf(diff) * in_grad_u(n, y, x, u_idx); 166 | // local update for df0 167 | grad_f0 += update; 168 | 169 | // global update for df1 (multiple vars can point to one address!) 170 | atomicAdd(&df1(n, c, y + v, x + u), -update); 171 | } 172 | } 173 | 174 | } 175 | 176 | for(short v = -sws_half; v <= sws_half; ++v) 177 | { 178 | short v_idx = v + sws_half; 179 | short u_idx = v_star(n, y, x, v_idx); 180 | short u = u_idx - sws_half; 181 | 182 | // copied from above, only change is that here in_grad_v is used 183 | if(x + u >= 0 && x + u < f0.size3 && y + v >= 0 && y + v < f0.size2) 184 | { 185 | float diff = f0(n, c, y, x) - f1(n, c, y + v, x + u); 186 | if(fabsf(diff) > eps) // gradient is zero if diff is zero! 187 | { 188 | float update = diff / fabsf(diff) * in_grad_v(n, y, x, v_idx); 189 | // local update for df0 190 | grad_f0 += update; 191 | 192 | // global update for df1 (multiple vars can point to one address!) 193 | atomicAdd(&df1(n, c, y + v, x + u), -update); 194 | } 195 | } 196 | } 197 | 198 | df0(n, c, y, x) = grad_f0; 199 | 200 | } 201 | 202 | 203 | // ============================================================================ 204 | // CPP KERNEL CALLS 205 | // ============================================================================ 206 | namespace cuda 207 | { 208 | std::vector flow_mp_sad_forward(at::Tensor f0, at::Tensor f1, int sws, int offset_u, int offset_v, 209 | int blockIdx_u, int blockIdx_v) 210 | { 211 | int N = f0.size(0); 212 | int C = f0.size(1); 213 | int H = f0.size(2); 214 | int W = f0.size(3); 215 | int K = sws + 1; 216 | 217 | auto cv_u = at::ones({N, H, W, K}, f0.options()) * 40; 218 | auto cv_v = at::ones({N, H, W, K}, f0.options()) * 40; 219 | 220 | auto u_star = at::zeros({N, H, W, K}, f0.options()); 221 | auto v_star = at::zeros({N, H, W, K}, f0.options()); 222 | 223 | //auto cv_all = at::ones({N, H, W, K, K}, f0.options()) * 40; 224 | 225 | if(K > 128) 226 | std::cout << "Error: Maximal search window size is " << K << " which is larger than max allowed 128!!" << std::endl; 227 | 228 | // parallelise over H x W x K 229 | // all K need to be in one block in order to have access to the same shared memory! 230 | // K needs to be the first, because last idx must be < 64. 231 | const dim3 blockSize(K, 1, 1); 232 | const dim3 numBlocks(std::ceil(K / static_cast(blockSize.x)), 233 | std::ceil(H / static_cast(blockSize.y)), 234 | std::ceil(W / static_cast(blockSize.z))); 235 | 236 | const int threadsPerBlock = blockSize.x * blockSize.y * blockSize.z; 237 | 238 | // std::cout << "N=" << N << " C=" << C << " H=" << H << " W=" << W << " K=" << K << std::endl; 239 | // std::cout << "threadsPerBlock=" << threadsPerBlock << std::endl; 240 | // std::cout << "numBlocks.x=" << numBlocks.x << " .y=" << numBlocks.y << " .z=" << numBlocks.z << std::endl; 241 | // std::cout << "mem-use=" << threadsPerBlock*K*sizeof(float) << "bytes" << std::endl; 242 | 243 | //CudaTimer cut; 244 | //cut.start(); 245 | flow_mp_sad_cuda_forward_kernel<<>>(f0, f1, sws, cv_u, cv_v, u_star, v_star, offset_u, offset_v, blockIdx_u, blockIdx_v); 246 | cudaSafeCall(cudaGetLastError()); 247 | // cudaDeviceSynchronize(); 248 | //std::cout << "SAD forward time " << cut.elapsed() << std::endl; 249 | std::vector res; 250 | //cost_vols.push_back(cv_all); 251 | res.push_back(cv_u); 252 | res.push_back(cv_v); 253 | res.push_back(u_star); 254 | res.push_back(v_star); 255 | return res; 256 | } 257 | 258 | std::vector flow_mp_sad_backward(at::Tensor f0, at::Tensor f1, 259 | int sws, at::Tensor in_grad_u, at::Tensor in_grad_v, 260 | at::Tensor u_star, at::Tensor v_star) 261 | { 262 | int N = f0.size(0); 263 | int C = f0.size(1); 264 | int H = f0.size(2); 265 | int W = f0.size(3); 266 | int K = sws + 1; 267 | 268 | auto df0 = at::zeros_like(f0); 269 | auto df1 = at::zeros_like(f1); 270 | 271 | // parallelise over H x W x D 272 | const dim3 blockSize(8, 8, 4); 273 | const dim3 numBlocks(std::ceil(W / static_cast(blockSize.x)), 274 | std::ceil(H / static_cast(blockSize.y)), 275 | std::ceil(C / static_cast(blockSize.z))); 276 | 277 | //CudaTimer cut; 278 | //cut.start(); 279 | flow_mp_sad_cuda_backward_kernel<<>>(f0, f1, sws, in_grad_u, in_grad_v, 280 | u_star, v_star, df0, df1); 281 | cudaSafeCall(cudaGetLastError()); 282 | // cudaDeviceSynchronize(); 283 | 284 | //std::cout << "SAD backward time " << cut.elapsed() << std::endl; 285 | 286 | std::vector gradients; 287 | gradients.push_back(df0); 288 | gradients.push_back(df1); 289 | 290 | return gradients; 291 | } 292 | } -------------------------------------------------------------------------------- /ops/flow_mp_sad/src/flow_mp_sad_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace cuda 5 | { 6 | std::vector flow_mp_sad_forward(at::Tensor f0, at::Tensor f1, int sws, int offset_u, int offset_v, 7 | int blockIdx_u, int blockIdx_v); 8 | std::vector flow_mp_sad_backward(at::Tensor f0, at::Tensor f1, int sws, 9 | at::Tensor in_grad_u, at::Tensor in_grad_v, 10 | at::Tensor u_star, at::Tensor v_star); 11 | } -------------------------------------------------------------------------------- /ops/include/error_util.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | //#include "cudnn.h" 5 | #include "stdio.h" 6 | #include "stdlib.h" 7 | 8 | #define CUDA_ERROR_CHECK 9 | 10 | #define cudaSafeCall( err ) __cnnCudaSafeCall( err, __FILE__, __LINE__ ) 11 | #define cudnnSafeCall( err ) __cnnCudnnSafeCall( err, __FILE__, __LINE__ ) 12 | 13 | inline void __cnnCudaSafeCall( cudaError_t err, const char *file, const int line ) 14 | { 15 | #ifdef CUDA_ERROR_CHECK 16 | if ( cudaSuccess != err ) 17 | { 18 | fprintf( stderr, "cudaSafeCall() failed at %s:%i : %s\n", file, line, cudaGetErrorString( err ) ); 19 | exit( -1 ); 20 | } 21 | #endif 22 | 23 | return; 24 | } 25 | 26 | // inline void __cnnCudnnSafeCall( cudnnStatus_t err, const char *file, const int line) 27 | // { 28 | // #ifdef CUDA_ERROR_CHECK 29 | // if(err != CUDNN_STATUS_SUCCESS) 30 | // { 31 | // fprintf( stderr, "cudaSafeCall() failed at %s:%i : %s\n", file, line, cudnnGetErrorString( err ) ); 32 | // exit( -1 ); 33 | // } 34 | // #endif 35 | 36 | // return; 37 | // } 38 | -------------------------------------------------------------------------------- /ops/include/tensor.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | /** \brief Struct pointer TensorKernelData that can be used in CUDA kernels. 7 | * 8 | * This struct provides the device data pointer as well as important class 9 | * properties. 10 | */ 11 | struct KernelData5 12 | { 13 | /** Pointer to device buffer. */ 14 | float* data_; 15 | 16 | unsigned int stride0; 17 | unsigned int stride1; 18 | unsigned int stride2; 19 | unsigned int stride3; 20 | unsigned int stride4; 21 | 22 | // size of dimensions 0 - 4 23 | unsigned short size0; 24 | unsigned short size1; 25 | unsigned short size2; 26 | unsigned short size3; 27 | unsigned short size4; 28 | 29 | 30 | /** Access the image via the () operator 31 | * @param x0 Position in the first dimension. 32 | * @param x1 Position in the second dimension. 33 | * @param x2 Position in the third dimension. 34 | * @param x3 Position in the forth dimension. 35 | * @param x4 Position in the fifth dimension. 36 | * @return value at position (x0, x1, x2, x3, x4). 37 | */ 38 | __device__ float& operator()(short x0, short x1, short x2, short x3, short x4) 39 | { 40 | return data_[x0 * stride0 + x1 * stride1 + x2 * stride2 + x3 *stride3 + x4]; 41 | } 42 | 43 | /** Get position / coordinates for a linear index. 44 | * @param[in] linearIdx Linear index. 45 | * @param[out] dim0 Position in the first dimension. 46 | * @param[out] dim1 Position in the second dimension. 47 | * @param[out] dim2 Position in the third dimension. 48 | * @param[out] dim3 Position in the forth dimension. 49 | */ 50 | __device__ void coords(unsigned int linearIdx, short *x0, short *x1, short *x2, short *x3, short *x4) 51 | { 52 | // modulo is slow 53 | // *dim0 = linearIdx / stride0; 54 | // *dim1 = (linearIdx % stride0) / stride1; 55 | // *dim2 = ((linearIdx % stride0) % stride1) / stride2; 56 | // *dim3 = ((linearIdx % stride0) % stride1) % stride2; 57 | *x0 = linearIdx / stride0; 58 | *x1 = (linearIdx - *x0 * stride0) / stride1; 59 | *x2 = (linearIdx - (*x0 * stride0 + *x1 * stride1)) / stride2; 60 | *x3 = linearIdx - (*x0 * stride0 + *x1 * stride1 + *x2 * stride2); 61 | *x4 = linearIdx - (*x0 * stride0 + *x1 * stride1 + *x2 * stride2 + *x3 * stride3); 62 | } 63 | 64 | /** Constructor */ 65 | __host__ KernelData5(const at::Tensor &tensor) : 66 | data_(tensor.data()), 67 | size0(tensor.size(0)), 68 | size1(tensor.size(1)), 69 | size2(tensor.size(2)), 70 | size3(tensor.size(3)), 71 | size4(tensor.size(4)), 72 | stride0(tensor.stride(0)), 73 | stride1(tensor.stride(1)), 74 | stride2(tensor.stride(2)), 75 | stride3(tensor.stride(3)), 76 | stride4(tensor.stride(4)) 77 | { 78 | } 79 | }; 80 | 81 | /** \brief Struct pointer TensorKernelData that can be used in CUDA kernels. 82 | * 83 | * This struct provides the device data pointer as well as important class 84 | * properties. 85 | */ 86 | struct KernelData 87 | { 88 | /** Pointer to device buffer. */ 89 | float* data_; 90 | 91 | unsigned int stride0; 92 | unsigned int stride1; 93 | unsigned int stride2; 94 | unsigned int stride3; 95 | 96 | // size of dimensions 0 - 3 97 | unsigned short size0; 98 | unsigned short size1; 99 | unsigned short size2; 100 | unsigned short size3; 101 | 102 | 103 | /** Access the image via the () operator 104 | * @param x0 Position in the first dimension. 105 | * @param x1 Position in the second dimension. 106 | * @param x2 Position in the third dimension. 107 | * @param x3 Position in the forth dimension. 108 | * @return value at position (x0, x1, x2, x3). 109 | */ 110 | __device__ float& operator()(short x0, short x1, short x2, short x3) 111 | { 112 | return data_[x0 * stride0 + x1 * stride1 + x2 * stride2 + x3]; 113 | } 114 | 115 | /** Get position / coordinates for a linear index. 116 | * @param[in] linearIdx Linear index. 117 | * @param[out] dim0 Position in the first dimension. 118 | * @param[out] dim1 Position in the second dimension. 119 | * @param[out] dim2 Position in the third dimension. 120 | * @param[out] dim3 Position in the forth dimension. 121 | */ 122 | __device__ void coords(unsigned int linearIdx, short *x0, short *x1, short *x2, short *x3) 123 | { 124 | // modulo is slow 125 | // *dim0 = linearIdx / stride0; 126 | // *dim1 = (linearIdx % stride0) / stride1; 127 | // *dim2 = ((linearIdx % stride0) % stride1) / stride2; 128 | // *dim3 = ((linearIdx % stride0) % stride1) % stride2; 129 | *x0 = linearIdx / stride0; 130 | *x1 = (linearIdx - *x0 * stride0) / stride1; 131 | *x2 = (linearIdx - (*x0 * stride0 + *x1 * stride1)) / stride2; 132 | *x3 = linearIdx - (*x0 * stride0 + *x1 * stride1 + *x2 * stride2); 133 | } 134 | 135 | /** Constructor */ 136 | __host__ KernelData(const at::Tensor &tensor) : 137 | data_(tensor.data()), 138 | size0(tensor.size(0)), 139 | size1(tensor.size(1)), 140 | size2(tensor.size(2)), 141 | size3(tensor.size(3)), 142 | stride0(tensor.stride(0)), 143 | stride1(tensor.stride(1)), 144 | stride2(tensor.stride(2)), 145 | stride3(tensor.stride(3)) 146 | { 147 | // std::cout << "size of size " << tensor.sizes().size() << std::endl; 148 | // std::cout << "s0 " << tensor.size(0) << std::endl; 149 | // std::cout << "s1 " << tensor.size(1) << std::endl; 150 | // std::cout << "s2 " << tensor.size(2) << std::endl; 151 | // std::cout << "s3 " << tensor.size(3) << std::endl; 152 | //std::cout << "s0 " << tensor.size(4) << std::endl; 153 | } 154 | }; -------------------------------------------------------------------------------- /ops/lbp_semantic_pw/bp_op_cuda.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | from networks import SubNetwork 8 | 9 | from ops.lbp_semantic_pw.inference_op import Inference 10 | from ops.lbp_semantic_pw.message_passing_op_cuda import MessagePassing 11 | 12 | class BP(SubNetwork): 13 | @staticmethod 14 | def construct_pw_energy_weights(K, L1, L2): 15 | '''K = num labels''' 16 | fij = np.ones((K, K)) * L2 17 | for i in range(-1, 2): 18 | if i == 0: 19 | fij -= np.diag(L2 * np.ones(K - np.abs(i)), i) 20 | else: 21 | fij -= np.diag((L2 - L1) * np.ones(K - np.abs(i)), i) 22 | 23 | return fij 24 | 25 | @staticmethod 26 | def construct_pw_prob_weights(K, L1, L2): 27 | '''K = num labels''' 28 | fij = np.zeros((K, K)) 29 | 30 | for i in range(-1, 2): 31 | if i == 0: 32 | fij += np.diag(L1 * np.ones(K - np.abs(i)), i) 33 | else: 34 | fij += np.diag(L2 * np.ones(K - np.abs(i)), i) 35 | 36 | fij[fij == 0] = 1.0 - L1 - L2 #(1.0 - L1 - 2 * L2) / (K - 3.0) #1.0 - L1 - L2 37 | 38 | return fij 39 | 40 | # modes = wta / expectation 41 | def __init__(self, device, args, max_iter, num_labels, delta, mode_inference='expectation', mode_message_passing='min-sum', layer_idx=0, single_pw=False): 42 | super(BP, self).__init__(args, device) 43 | self.device = device 44 | self.max_iter = max_iter 45 | self.layer_idx = layer_idx 46 | self.delta = delta 47 | self.single_pw = single_pw 48 | 49 | if mode_inference != 'wta' and mode_inference != 'expectation' and mode_message_passing != 'min-sum' and mode_inference != 'norm' and mode_inference != 'raw': 50 | raise ValueError("Unknown inference/message passing mode " + mode_inference + " " + mode_message_passing) 51 | 52 | self.message_passing = MessagePassing(self.device, self.max_iter, num_labels, self.delta, mode_message_passing) 53 | self.inference = Inference(self.device, mode_inference, mode_passing=mode_message_passing) 54 | 55 | def forward(self, prob_vol, edge_weights, jump): 56 | 57 | N, H, W, K = prob_vol.shape 58 | messages = torch.zeros((N, 4, H, W, K), requires_grad=True, device=self.device, 59 | dtype=torch.float) 60 | 61 | # compute messages 62 | beliefs = self.message_passing.forward(prob_vol, edge_weights, messages, jump, self.single_pw) 63 | 64 | # + wta/expectation 65 | result = self.inference.forward(beliefs) 66 | 67 | return result.permute(0,3,1,2), beliefs 68 | 69 | def project_jumpcosts(self): 70 | self.message_passing.projectL1L2() 71 | 72 | def save_checkpoint(self, epoch, iteration): 73 | if 'c' in self.args.train_params: 74 | torch.save(self.state_dict(), 75 | osp.join(self.args.train_dir, 'crf' + str(self.layer_idx) + '_checkpoint_' + 76 | str(epoch) + '_' + str(iteration).zfill(6) + '.cpt')) 77 | 78 | def adjust_input_weights(self, weights, idx): 79 | if weights is not None: 80 | weights_idx = weights[:, idx * 2 : (idx + 1) * 2, :, :] 81 | 82 | # wx_L = np.zeros_like(wx) 83 | # wy_D = np.zeros_like(wy) 84 | # wx_L[:, 1:] = wx[:, :-1] 85 | # wy_D[1:, :] = wy[:-1, :] 86 | 87 | weights_input = torch.zeros((weights_idx.shape[0], 4, weights_idx.shape[2], weights_idx.shape[3])).cuda() 88 | 89 | weights_input[:, 0] = weights[:, 0] 90 | # wx RL 91 | weights_input[:, 1, :, 1:] = weights[:, 0, :, :-1] 92 | # wy UD 93 | weights_input[:, 2] = weights[:, 1] 94 | # wy DU 95 | weights_input[:, 3, 1:, :] = weights[:, 1, :-1, :] 96 | 97 | weights_input = weights_input.contiguous() 98 | 99 | else: 100 | weights_input = None 101 | 102 | return weights_input 103 | 104 | @property 105 | def L1(self): 106 | return self.message_passing.L1 107 | 108 | @L1.setter 109 | def L1(self, value): 110 | self.message_passing.L1.data = torch.tensor(value, device=self.device, dtype=torch.float) 111 | 112 | @property 113 | def L2(self): 114 | return self.message_passing.L2 115 | 116 | @L2.setter 117 | def L2(self, value): 118 | self.message_passing.L2.data = torch.tensor(value, device=self.device, dtype=torch.float) 119 | 120 | @property 121 | def rescaleT(self): 122 | return self.message_passing.rescaleT 123 | 124 | @rescaleT.setter 125 | def rescaleT(self, value): 126 | self.message_passing.rescaleT = value 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /ops/lbp_semantic_pw/inference_op.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | import pytorch_cuda_lbp_op as lbp 7 | 8 | class Inference(nn.Module): 9 | # modes = wta / expectation 10 | def __init__(self, device, mode='wta', mode_passing='min-sum'): 11 | super(Inference, self).__init__() 12 | self.device = device 13 | if mode != 'wta' and mode != 'expectation' and mode != 'norm' and mode != 'raw': 14 | raise ValueError("Unknown inference mode " + mode) 15 | self.mode = mode 16 | self.mode_passing = mode_passing 17 | 18 | def forward(self, beliefs): 19 | if self.mode == "wta" and self.mode_passing == "min-sum": 20 | res = torch.argmax(beliefs, dim=3, keepdim=True) 21 | if self.mode == "wta": 22 | res = torch.argmax(beliefs, dim=3, keepdim=True) 23 | elif self.mode == "expectation": 24 | beliefs_normal = beliefs / beliefs.sum(dim=3, keepdim=True) 25 | 26 | if torch.isnan(beliefs_normal).sum() > 0: 27 | print("Beliefs normalized contains " + str(torch.isnan(beliefs_normal).sum()) + " NaNs ;(") 28 | 29 | labels = np.arange(beliefs.shape[3])[np.newaxis, np.newaxis, np.newaxis, :] 30 | labels_tensor = torch.tensor(labels.astype('float32'), device=self.device) 31 | res = (beliefs_normal * labels_tensor).sum(dim=3, keepdim=True) 32 | elif self.mode == "norm": 33 | beliefs_normal = beliefs / beliefs.sum(dim=3, keepdim=True) 34 | res = beliefs_normal 35 | elif self.mode == "raw": 36 | #print("using raw inference...") 37 | res = beliefs 38 | 39 | return res 40 | -------------------------------------------------------------------------------- /ops/lbp_semantic_pw/message_passing_op_cuda.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import pytorch_cuda_lbp_op as lbp 7 | 8 | from corenet import TemperatureSoftmin 9 | 10 | def construct_pw_tensor(L1, L2, K): 11 | pw_tensor = torch.ones((K, K), dtype=torch.float) * L2 12 | 13 | for i in range(-1, 2): 14 | if i == 0: 15 | pw_tensor -= torch.diag(L2 * torch.ones(K - torch.abs(torch.tensor(i))), i) 16 | else: 17 | pw_tensor -= torch.diag((L2 - L1) * torch.ones(K - torch.abs(torch.tensor(i))), i) 18 | 19 | pw_tensor = pw_tensor.unsqueeze(0).unsqueeze(0) 20 | 21 | return pw_tensor 22 | 23 | class LBPMinSumFunction(torch.autograd.Function): 24 | @staticmethod 25 | def forward(ctx, cost, L1, L2, edge, messages, delta, jump, single_pw): 26 | 27 | if not single_pw.item(): 28 | #print("NOT SINGLE!") 29 | jump_in = torch.zeros((jump.shape[0], jump.shape[1] + 2, jump.shape[2], jump.shape[3])).cuda() 30 | jump_in[0, 0] = jump[0,0] 31 | jump_in[0, 2] = jump[0,1] 32 | jump_in[0, 1] = jump[0,0].permute((1,0)) 33 | jump_in[0, 3] = jump[0,1].permute((1,0)) 34 | else: 35 | #print("SINGLE!") 36 | jump_in = torch.zeros((jump.shape[0], 4, jump.shape[2], jump.shape[3])).cuda() 37 | jump_in[0, 0] = jump[0,0] 38 | jump_in[0, 2] = jump[0,0] 39 | jump_in[0, 1] = jump[0,0] 40 | jump_in[0, 3] = jump[0,0] 41 | 42 | messages, messages_argmin, message_scale = lbp.forward_minsum(cost, jump_in, edge, messages, delta) 43 | 44 | ctx.save_for_backward(cost, jump_in, edge, messages, messages_argmin, message_scale, single_pw) 45 | return messages 46 | 47 | @staticmethod 48 | # @profile 49 | def backward(ctx, in_grad): 50 | cost, jump, edge, messages, messages_argmin, message_scale, single_pw = ctx.saved_tensors 51 | 52 | grad_cost, grad_jump, grad_edge, grad_message = lbp.backward_minsum(cost, jump.contiguous(), edge, in_grad.contiguous(), messages, messages_argmin, message_scale) 53 | 54 | L1_grad = None 55 | L2_grad = None 56 | 57 | if not single_pw.item(): 58 | #print("NOT SINGLE!") 59 | grad_jump_out = torch.zeros((grad_jump.shape[0], grad_jump.shape[1] - 2, grad_jump.shape[2], grad_jump.shape[3])).cuda() 60 | grad_jump_out[0,0] += grad_jump[0,0] 61 | grad_jump_out[0,1] += grad_jump[0,2] 62 | grad_jump_out[0,0] += grad_jump[0,1].permute((1,0)) 63 | grad_jump_out[0,1] += grad_jump[0,3].permute((1,0)) 64 | else: 65 | #print("SINGLE!") 66 | grad_jump_out = torch.zeros((grad_jump.shape[0], 1, grad_jump.shape[2], grad_jump.shape[3])).cuda() 67 | grad_jump_out[0,0] += grad_jump[0,0] 68 | grad_jump_out[0,0] += grad_jump[0,1] 69 | grad_jump_out[0,0] += grad_jump[0,2] 70 | grad_jump_out[0,0] += grad_jump[0,3] 71 | 72 | return grad_cost, L1_grad, L2_grad, grad_edge, grad_message, None, grad_jump_out, None 73 | 74 | 75 | class MessagePassing(nn.Module): 76 | def __init__(self, device, max_iter, num_labels, delta, mode='min-sum'): 77 | super(MessagePassing, self).__init__() 78 | self.device = device 79 | self.max_iter = max_iter 80 | 81 | if mode != 'min-sum': 82 | raise ValueError("Unknown message parsing mode " + mode) 83 | self.mode = mode 84 | 85 | L1 = torch.tensor(0.1, device=device) 86 | L2 = torch.tensor(2.5, device=device) 87 | self.L1 = nn.Parameter(L1, requires_grad=True) 88 | self.L2 = nn.Parameter(L2, requires_grad=True) 89 | 90 | self.softmin = TemperatureSoftmin(dim=3, init_temp=1.0) 91 | 92 | self.delta = delta 93 | self.rescaleT = None 94 | 95 | def projectL1L2(self): 96 | self.L2.data = torch.max(self.L1.data, self.L2.data) 97 | 98 | def forward(self, prob_vol, edge_weights, messages, jump, single_pw=False): 99 | N, H, W, C = prob_vol.shape 100 | if edge_weights is None: 101 | edge_weights = torch.ones((N, 4, H, W)) 102 | 103 | if self.mode == 'min-sum': 104 | # convert to cost-input 105 | cost = -prob_vol 106 | 107 | # perform message-passing iterations 108 | for it in range(self.max_iter): 109 | messages = LBPMinSumFunction.apply(cost, self.L1, self.L2, edge_weights, messages, self.delta, jump, torch.tensor(single_pw)) 110 | 111 | # compute beliefs 112 | beliefs = messages.sum(dim=1) + cost 113 | 114 | # normalize output 115 | beliefs = self.softmin.forward(beliefs) 116 | 117 | else: 118 | raise NotImplementedError("message parsing mode " + self.mode + " is currently not implemented!") 119 | 120 | return beliefs 121 | -------------------------------------------------------------------------------- /ops/lbp_semantic_pw/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | setup( 6 | name='pytorch-lbp-op', 7 | version='0.2', 8 | author="Patrick Knöbelreiter", 9 | author_email="knoebelreiter@icg.tugraz.at", 10 | packages=["src"], 11 | include_dirs=['../include/'], 12 | ext_modules=[ 13 | CUDAExtension('pytorch_cuda_lbp_op', [ 14 | 'src/lbp.cpp', 15 | 'src/lbp_min_sum_kernel.cu', 16 | ]), 17 | ], 18 | cmdclass={ 19 | 'build_ext': BuildExtension 20 | }, 21 | install_requires=[ 22 | "numpy >= 1.15", 23 | "torch >= 0.4.1", 24 | "matplotlib >= 3.0.0", 25 | "scikit-image >= 0.14.1", 26 | "numba >= 0.42" 27 | ], 28 | ) 29 | -------------------------------------------------------------------------------- /ops/lbp_semantic_pw/src/lbp.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "lbp_min_sum_kernel.cuh" 5 | 6 | // C++ interface 7 | // AT_ASSERTM in pytorch 1.0 8 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x); 11 | 12 | 13 | // ================================================================================================ 14 | // MIN-SUM LBP 15 | // ================================================================================================ 16 | std::vector lbp_forward_min_sum(at::Tensor cost, 17 | at::Tensor jump, 18 | at::Tensor edge, 19 | at::Tensor messages, unsigned short delta) 20 | { 21 | CHECK_INPUT(cost) 22 | CHECK_INPUT(jump) 23 | CHECK_INPUT(edge) 24 | CHECK_INPUT(messages) 25 | 26 | return cuda::lbp_forward_min_sum(cost, jump, edge, messages, delta); 27 | } 28 | 29 | std::vector lbp_backward_min_sum(at::Tensor cost, 30 | at::Tensor jump, 31 | at::Tensor edge, 32 | at::Tensor in_grad, 33 | at::Tensor messages, 34 | at::Tensor messages_argmin, 35 | at::Tensor message_scale) 36 | { 37 | CHECK_INPUT(cost) 38 | CHECK_INPUT(jump) 39 | CHECK_INPUT(edge) 40 | CHECK_INPUT(in_grad) 41 | CHECK_INPUT(messages) 42 | CHECK_INPUT(messages_argmin) 43 | CHECK_INPUT(message_scale) 44 | 45 | return cuda::lbp_backward_min_sum(cost, jump, edge, in_grad, messages, messages_argmin, message_scale); 46 | } 47 | 48 | // ================================================================================================ 49 | // Pytorch Interfaces 50 | // ================================================================================================ 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 52 | { 53 | m.def("forward_minsum", &lbp_forward_min_sum, "LBP forward (CUDA)"); 54 | m.def("backward_minsum", &lbp_backward_min_sum, "LBP backward (CUDA)"); 55 | } 56 | 57 | -------------------------------------------------------------------------------- /ops/lbp_semantic_pw/src/lbp_min_sum_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace cuda 6 | { 7 | std::vector lbp_forward_min_sum(at::Tensor cost, 8 | at::Tensor jump, 9 | at::Tensor edge, 10 | at::Tensor messages, unsigned short delta); 11 | 12 | std::vector lbp_backward_min_sum(at::Tensor cost, 13 | at::Tensor jump, 14 | at::Tensor edge, 15 | at::Tensor in_grad, 16 | at::Tensor messages, 17 | at::Tensor messages_argmin, 18 | at::Tensor message_scale); 19 | 20 | } -------------------------------------------------------------------------------- /ops/lbp_semantic_pw/src/util.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #include "../../include/tensor.h" 6 | 7 | enum DIRECTION {LEFT, RIGHT, UP, DOWN}; 8 | 9 | 10 | extern __device__ __forceinline__ float getDerivativeValue(KernelData prev_values, int direction, float current_value, float max_float) 11 | { 12 | float ret_val = 0.0; 13 | if(prev_values(0, direction, 0, 0) != max_float) 14 | { 15 | ret_val = prev_values(0, direction, 0, 0); 16 | } 17 | else 18 | { 19 | ret_val = current_value; 20 | } 21 | 22 | return ret_val; 23 | } 24 | 25 | 26 | extern __device__ __forceinline__ float computeCrossGradient(KernelData prev_values, float max_float) 27 | { 28 | 29 | bool is_cross = false; 30 | if(prev_values(0,0,0,0) != max_float || 31 | prev_values(0,1,0,0) != max_float || 32 | prev_values(0,2,0,0) != max_float || 33 | prev_values(0,3,0,0) != max_float) 34 | { 35 | is_cross = true; 36 | } 37 | 38 | return is_cross; 39 | } 40 | 41 | extern __device__ __forceinline__ float getGradientAcc(KernelData gradient_accumulation, int direction, int n, int y, int x, int c, int grad_acc_idx) 42 | { 43 | 44 | float ret_val = 0.0; 45 | if(direction == UP || direction == DOWN) 46 | { 47 | ret_val = gradient_accumulation(n, x, grad_acc_idx, c); 48 | } 49 | if(direction == LEFT || direction == RIGHT) 50 | { 51 | ret_val = gradient_accumulation(n, y, grad_acc_idx, c); 52 | } 53 | 54 | return ret_val; 55 | 56 | } 57 | 58 | extern __device__ __forceinline__ void updateGradientAcc(KernelData gradient_accumulation, float value, int direction, int n, int y, int x, int c, int grad_acc_idx) 59 | { 60 | 61 | if(direction == UP || direction == DOWN) 62 | { 63 | //gradient_accumulation(n, x, grad_acc_idx, c) = value; 64 | atomicAdd(&gradient_accumulation(n, x, grad_acc_idx, c), value); 65 | } 66 | if(direction == LEFT || direction == RIGHT) 67 | { 68 | //gradient_accumulation(n, y, grad_acc_idx, c) = value; 69 | atomicAdd(&gradient_accumulation(n, y, grad_acc_idx, c), value); 70 | } 71 | 72 | } 73 | 74 | extern __device__ __forceinline__ void setGradientAcc(KernelData gradient_accumulation, float value, int direction, int n, int y, int x, int c, int grad_acc_idx) 75 | { 76 | 77 | if(direction == UP || direction == DOWN) 78 | { 79 | gradient_accumulation(n, x, grad_acc_idx, c) = value; 80 | } 81 | if(direction == LEFT || direction == RIGHT) 82 | { 83 | gradient_accumulation(n, y, grad_acc_idx, c) = value; 84 | } 85 | 86 | } 87 | 88 | extern __device__ __forceinline__ float getEdgeWeight(KernelData edges, int n, int y, int x, int direction) 89 | { 90 | 91 | float w = 1.0; 92 | if(direction == UP || direction == DOWN) 93 | { 94 | w = edges(n, 1, y, x); 95 | } 96 | else 97 | { 98 | w = edges(n, 0, y, x); 99 | } 100 | 101 | return w; 102 | } 103 | -------------------------------------------------------------------------------- /ops/lbp_semantic_pw_pixel/bp_op_cuda.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | from networks import SubNetwork 8 | 9 | from ops.lbp_semantic_pw_pixel.inference_op import Inference 10 | from ops.lbp_semantic_pw_pixel.message_passing_op_pw_pixel import MessagePassing 11 | 12 | class BP(SubNetwork): 13 | @staticmethod 14 | def construct_pw_energy_weights(K, L1, L2): 15 | '''K = num labels''' 16 | fij = np.ones((K, K)) * L2 17 | for i in range(-1, 2): 18 | if i == 0: 19 | fij -= np.diag(L2 * np.ones(K - np.abs(i)), i) 20 | else: 21 | fij -= np.diag((L2 - L1) * np.ones(K - np.abs(i)), i) 22 | 23 | return fij 24 | 25 | @staticmethod 26 | def construct_pw_prob_weights(K, L1, L2): 27 | '''K = num labels''' 28 | fij = np.zeros((K, K)) 29 | 30 | for i in range(-1, 2): 31 | if i == 0: 32 | fij += np.diag(L1 * np.ones(K - np.abs(i)), i) 33 | else: 34 | fij += np.diag(L2 * np.ones(K - np.abs(i)), i) 35 | 36 | fij[fij == 0] = 1.0 - L1 - L2 #(1.0 - L1 - 2 * L2) / (K - 3.0) #1.0 - L1 - L2 37 | 38 | return fij 39 | 40 | # modes = wta / expectation 41 | def __init__(self, device, args, max_iter, num_labels, delta, mode_inference='expectation', mode_message_passing='min-sum', layer_idx=0): 42 | super(BP, self).__init__(args, device) 43 | self.device = device 44 | self.max_iter = max_iter 45 | self.layer_idx = layer_idx 46 | self.delta = delta 47 | 48 | print("init pixel wise bp...") 49 | 50 | if mode_inference != 'wta' and mode_inference != 'expectation' and mode_message_passing != 'min-sum' and mode_inference != 'norm' and mode_inference != 'raw': 51 | raise ValueError("Unknown inference/message passing mode " + mode_inference + " " + mode_message_passing) 52 | 53 | self.message_passing = MessagePassing(self.device, self.max_iter, num_labels, self.delta, mode_message_passing) 54 | self.inference = Inference(self.device, mode_inference, mode_passing=mode_message_passing) 55 | 56 | 57 | def forward(self, prob_vol, edge_weights, jump): 58 | 59 | N, H, W, K = prob_vol.shape 60 | messages = torch.zeros((N, 4, H, W, K), requires_grad=False, device=self.device, 61 | dtype=torch.float) 62 | 63 | # compute messages 64 | beliefs = self.message_passing.forward(prob_vol, edge_weights, messages, jump) 65 | 66 | # + wta/expectation 67 | result = self.inference.forward(beliefs) 68 | 69 | return result.permute(0,3,1,2), beliefs 70 | 71 | def project_jumpcosts(self): 72 | self.message_passing.projectL1L2() 73 | 74 | def save_checkpoint(self, epoch, iteration): 75 | if 'c' in self.args.train_params: 76 | torch.save(self.state_dict(), 77 | osp.join(self.args.train_dir, 'crf' + str(self.layer_idx) + '_checkpoint_' + 78 | str(epoch) + '_' + str(iteration).zfill(6) + '.cpt')) 79 | 80 | def adjust_input_weights(self, weights, idx): 81 | if weights is not None: 82 | weights_idx = weights[:, idx * 2 : (idx + 1) * 2, :, :] 83 | 84 | # wx_L = np.zeros_like(wx) 85 | # wy_D = np.zeros_like(wy) 86 | # wx_L[:, 1:] = wx[:, :-1] 87 | # wy_D[1:, :] = wy[:-1, :] 88 | 89 | weights_input = torch.zeros((weights_idx.shape[0], 4, weights_idx.shape[2], weights_idx.shape[3])).cuda() 90 | 91 | weights_input[:, 0] = weights[:, 0] 92 | # wx RL 93 | weights_input[:, 1, :, 1:] = weights[:, 0, :, :-1] 94 | # wy UD 95 | weights_input[:, 2] = weights[:, 1] 96 | # wy DU 97 | weights_input[:, 3, 1:, :] = weights[:, 1, :-1, :] 98 | 99 | weights_input = weights_input.contiguous() 100 | 101 | else: 102 | weights_input = None 103 | 104 | return weights_input 105 | 106 | @property 107 | def L1(self): 108 | return self.message_passing.L1 109 | 110 | @L1.setter 111 | def L1(self, value): 112 | self.message_passing.L1.data = torch.tensor(value, device=self.device, dtype=torch.float) 113 | 114 | @property 115 | def L2(self): 116 | return self.message_passing.L2 117 | 118 | @L2.setter 119 | def L2(self, value): 120 | self.message_passing.L2.data = torch.tensor(value, device=self.device, dtype=torch.float) 121 | 122 | @property 123 | def rescaleT(self): 124 | return self.message_passing.rescaleT 125 | 126 | @rescaleT.setter 127 | def rescaleT(self, value): 128 | self.message_passing.rescaleT = value 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /ops/lbp_semantic_pw_pixel/inference_op.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | import pytorch_cuda_lbp_op as lbp 7 | 8 | class Inference(nn.Module): 9 | # modes = wta / expectation 10 | def __init__(self, device, mode='wta', mode_passing='min-sum'): 11 | super(Inference, self).__init__() 12 | self.device = device 13 | if mode != 'wta' and mode != 'expectation' and mode != 'norm' and mode != 'raw': 14 | raise ValueError("Unknown inference mode " + mode) 15 | self.mode = mode 16 | self.mode_passing = mode_passing 17 | 18 | def forward(self, beliefs): 19 | if self.mode == "wta" and self.mode_passing == "min-sum": 20 | res = torch.argmax(beliefs, dim=3, keepdim=True) 21 | if self.mode == "wta": 22 | res = torch.argmax(beliefs, dim=3, keepdim=True) 23 | elif self.mode == "expectation": 24 | beliefs_normal = beliefs / beliefs.sum(dim=3, keepdim=True) 25 | 26 | if torch.isnan(beliefs_normal).sum() > 0: 27 | print("Beliefs normalized contains " + str(torch.isnan(beliefs_normal).sum()) + " NaNs ;(") 28 | 29 | labels = np.arange(beliefs.shape[3])[np.newaxis, np.newaxis, np.newaxis, :] 30 | labels_tensor = torch.tensor(labels.astype('float32'), device=self.device) 31 | res = (beliefs_normal * labels_tensor).sum(dim=3, keepdim=True) 32 | elif self.mode == "norm": 33 | beliefs_normal = beliefs / beliefs.sum(dim=3, keepdim=True) 34 | res = beliefs_normal 35 | elif self.mode == "raw": 36 | #print("using raw inference...") 37 | res = beliefs 38 | 39 | return res 40 | -------------------------------------------------------------------------------- /ops/lbp_semantic_pw_pixel/message_passing_op_pw_pixel.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import pytorch_cuda_lbp_op as lbp 7 | 8 | from corenet import TemperatureSoftmin 9 | 10 | def construct_pw_tensor(L1, L2, K): 11 | pw_tensor = torch.ones((K, K), dtype=torch.float) * L2 12 | 13 | for i in range(-1, 2): 14 | if i == 0: 15 | pw_tensor -= torch.diag(L2 * torch.ones(K - torch.abs(torch.tensor(i))), i) 16 | else: 17 | pw_tensor -= torch.diag((L2 - L1) * torch.ones(K - torch.abs(torch.tensor(i))), i) 18 | 19 | pw_tensor = pw_tensor.unsqueeze(0).unsqueeze(0) 20 | 21 | return pw_tensor 22 | 23 | class LBPMinSumFunction(torch.autograd.Function): 24 | @staticmethod 25 | def forward(ctx, cost, L1, L2, edge, messages, delta, jump): 26 | 27 | pw_net_jump_shift = torch.zeros((jump.shape[0] + 2, jump.shape[1], jump.shape[2], jump.shape[3], jump.shape[4])).cuda() 28 | 29 | # LR 30 | pw_net_jump_shift[0] = jump[0] 31 | # RL 32 | pw_net_jump_shift[1, :, 1:] = jump[0, :, :-1].permute((0, 1, 3, 2)) 33 | # UD 34 | pw_net_jump_shift[2] = jump[1] 35 | # DU 36 | pw_net_jump_shift[3, 1:] = jump[1, :-1].permute((0, 1, 3, 2)) 37 | 38 | pw_net_jump_shift = pw_net_jump_shift.contiguous() 39 | 40 | messages, messages_argmin, message_scale = lbp.forward_minsum(cost, pw_net_jump_shift, edge, messages, delta) 41 | 42 | ctx.save_for_backward(cost, edge, messages, messages_argmin, message_scale) 43 | return messages 44 | 45 | @staticmethod 46 | # @profile 47 | def backward(ctx, in_grad): 48 | cost, edge, messages, messages_argmin, message_scale = ctx.saved_tensors 49 | 50 | grad_cost, grad_jump, grad_edge, grad_message = lbp.backward_minsum(cost, edge, in_grad.contiguous(), messages, messages_argmin, message_scale) 51 | 52 | L1_grad = None 53 | L2_grad = None 54 | 55 | grad_jump[0, :, :-1] += grad_jump[1, :, 1:].permute((0, 1, 3, 2)) 56 | grad_jump[1] = 0 57 | grad_jump[1] += grad_jump[2] 58 | grad_jump[1, :-1] += grad_jump[3, 1:].permute((0, 1, 3, 2)) 59 | 60 | return grad_cost, L1_grad, L2_grad, grad_edge, grad_message, None, grad_jump[:2] 61 | 62 | class MessagePassing(nn.Module): 63 | def __init__(self, device, max_iter, num_labels, delta, mode='min-sum'): 64 | super(MessagePassing, self).__init__() 65 | 66 | self.device = device 67 | self.max_iter = max_iter 68 | 69 | if mode != 'min-sum': 70 | raise ValueError("Unknown message parsing mode " + mode) 71 | self.mode = mode 72 | 73 | L1 = torch.tensor(0.1, device=device) 74 | L2 = torch.tensor(2.5, device=device) 75 | self.L1 = nn.Parameter(L1, requires_grad=True) 76 | self.L2 = nn.Parameter(L2, requires_grad=True) 77 | 78 | self.softmin = TemperatureSoftmin(dim=3, init_temp=1.0) 79 | 80 | self.delta = delta 81 | self.rescaleT = None 82 | 83 | def projectL1L2(self): 84 | self.L2.data = torch.max(self.L1.data, self.L2.data) 85 | 86 | def forward(self, prob_vol, edge_weights, messages, jump): 87 | 88 | N, H, W, C = prob_vol.shape 89 | if edge_weights is None: 90 | edge_weights = torch.ones((N, 4, H, W)) 91 | 92 | if self.mode == 'min-sum': 93 | # convert to cost-input 94 | cost = -prob_vol 95 | 96 | # perform message-passing iterations 97 | for it in range(self.max_iter): 98 | messages = LBPMinSumFunction.apply(cost, self.L1, self.L2, edge_weights, messages, self.delta, jump) 99 | 100 | # compute beliefs 101 | beliefs = messages.sum(dim=1) + cost 102 | 103 | # normalize output 104 | beliefs = self.softmin.forward(beliefs) 105 | 106 | else: 107 | raise NotImplementedError("message parsing mode " + self.mode + " is currently not implemented!") 108 | 109 | return beliefs -------------------------------------------------------------------------------- /ops/lbp_semantic_pw_pixel/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | setup( 6 | name='pytorch-lbp-op', 7 | version='0.2', 8 | author="Patrick Knöbelreiter", 9 | author_email="knoebelreiter@icg.tugraz.at", 10 | packages=["src"], 11 | include_dirs=[], 12 | ext_modules=[ 13 | CUDAExtension('pytorch_cuda_lbp_op', [ 14 | 'src/lbp.cpp', 15 | 'src/lbp_min_sum_kernel.cu', 16 | ]), 17 | ], 18 | cmdclass={ 19 | 'build_ext': BuildExtension 20 | }, 21 | install_requires=[ 22 | "numpy >= 1.15", 23 | "torch >= 0.4.1", 24 | "matplotlib >= 3.0.0", 25 | "scikit-image >= 0.14.1", 26 | "numba >= 0.42" 27 | ], 28 | ) 29 | -------------------------------------------------------------------------------- /ops/lbp_semantic_pw_pixel/src/lbp.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "lbp_min_sum_kernel.cuh" 5 | 6 | // C++ interface 7 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x); 10 | 11 | 12 | // ================================================================================================ 13 | // MIN-SUM LBP 14 | // ================================================================================================ 15 | std::vector lbp_forward_min_sum(at::Tensor cost, 16 | at::Tensor jump, 17 | at::Tensor edge, 18 | at::Tensor messages, unsigned short delta) 19 | { 20 | CHECK_INPUT(cost) 21 | CHECK_INPUT(jump) 22 | CHECK_INPUT(edge) 23 | CHECK_INPUT(messages) 24 | 25 | return cuda::lbp_forward_min_sum(cost, jump, edge, messages, delta); 26 | } 27 | 28 | std::vector lbp_backward_min_sum(at::Tensor cost, 29 | at::Tensor edge, 30 | at::Tensor in_grad, 31 | at::Tensor messages, 32 | at::Tensor messages_argmin, 33 | at::Tensor message_scale) 34 | { 35 | CHECK_INPUT(cost) 36 | //CHECK_INPUT(jump) 37 | CHECK_INPUT(edge) 38 | CHECK_INPUT(in_grad) 39 | CHECK_INPUT(messages) 40 | CHECK_INPUT(messages_argmin) 41 | CHECK_INPUT(message_scale) 42 | 43 | return cuda::lbp_backward_min_sum(cost, edge, in_grad, messages, messages_argmin, message_scale); 44 | } 45 | 46 | // ================================================================================================ 47 | // Pytorch Interfaces 48 | // ================================================================================================ 49 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 50 | { 51 | m.def("forward_minsum", &lbp_forward_min_sum, "LBP forward (CUDA)"); 52 | m.def("backward_minsum", &lbp_backward_min_sum, "LBP backward (CUDA)"); 53 | } 54 | 55 | -------------------------------------------------------------------------------- /ops/lbp_semantic_pw_pixel/src/lbp_min_sum_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace cuda 6 | { 7 | std::vector lbp_forward_min_sum(at::Tensor cost, 8 | at::Tensor jump, 9 | at::Tensor edge, 10 | at::Tensor messages, unsigned short delta); 11 | 12 | std::vector lbp_backward_min_sum(at::Tensor cost, 13 | at::Tensor edge, 14 | at::Tensor in_grad, 15 | at::Tensor messages, 16 | at::Tensor messages_argmin, 17 | at::Tensor message_scale); 18 | 19 | } -------------------------------------------------------------------------------- /ops/lbp_semantic_pw_pixel/src/util.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #include "../../include/tensor.h" 6 | 7 | enum DIRECTION {LEFT, RIGHT, UP, DOWN}; 8 | 9 | 10 | extern __device__ __forceinline__ float getDerivativeValue(KernelData prev_values, int direction, float current_value, float max_float) 11 | { 12 | float ret_val = 0.0; 13 | if(prev_values(0, direction, 0, 0) != max_float) 14 | { 15 | ret_val = prev_values(0, direction, 0, 0); 16 | } 17 | else 18 | { 19 | ret_val = current_value; 20 | } 21 | 22 | return ret_val; 23 | } 24 | 25 | 26 | extern __device__ __forceinline__ float computeCrossGradient(KernelData prev_values, float max_float) 27 | { 28 | 29 | bool is_cross = false; 30 | if(prev_values(0,0,0,0) != max_float || 31 | prev_values(0,1,0,0) != max_float || 32 | prev_values(0,2,0,0) != max_float || 33 | prev_values(0,3,0,0) != max_float) 34 | { 35 | is_cross = true; 36 | } 37 | 38 | return is_cross; 39 | } 40 | 41 | extern __device__ __forceinline__ float getGradientAcc(KernelData gradient_accumulation, int direction, int n, int y, int x, int c, int grad_acc_idx) 42 | { 43 | 44 | float ret_val = 0.0; 45 | if(direction == UP || direction == DOWN) 46 | { 47 | ret_val = gradient_accumulation(n, x, grad_acc_idx, c); 48 | } 49 | if(direction == LEFT || direction == RIGHT) 50 | { 51 | ret_val = gradient_accumulation(n, y, grad_acc_idx, c); 52 | } 53 | 54 | return ret_val; 55 | 56 | } 57 | 58 | extern __device__ __forceinline__ void updateGradientAcc(KernelData gradient_accumulation, float value, int direction, int n, int y, int x, int c, int grad_acc_idx) 59 | { 60 | 61 | if(direction == UP || direction == DOWN) 62 | { 63 | //gradient_accumulation(n, x, grad_acc_idx, c) = value; 64 | atomicAdd(&gradient_accumulation(n, x, grad_acc_idx, c), value); 65 | } 66 | if(direction == LEFT || direction == RIGHT) 67 | { 68 | //gradient_accumulation(n, y, grad_acc_idx, c) = value; 69 | atomicAdd(&gradient_accumulation(n, y, grad_acc_idx, c), value); 70 | } 71 | 72 | } 73 | 74 | extern __device__ __forceinline__ void setGradientAcc(KernelData gradient_accumulation, float value, int direction, int n, int y, int x, int c, int grad_acc_idx) 75 | { 76 | 77 | if(direction == UP || direction == DOWN) 78 | { 79 | gradient_accumulation(n, x, grad_acc_idx, c) = value; 80 | } 81 | if(direction == LEFT || direction == RIGHT) 82 | { 83 | gradient_accumulation(n, y, grad_acc_idx, c) = value; 84 | } 85 | 86 | } 87 | 88 | extern __device__ __forceinline__ float getEdgeWeight(KernelData edges, int n, int y, int x, int direction) 89 | { 90 | 91 | float w = 1.0; 92 | if(direction == UP || direction == DOWN) 93 | { 94 | w = edges(n, 1, y, x); 95 | } 96 | else 97 | { 98 | w = edges(n, 0, y, x); 99 | } 100 | 101 | return w; 102 | } 103 | -------------------------------------------------------------------------------- /ops/lbp_stereo/bp_op_cuda.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | from networks import SubNetwork 8 | 9 | import pytorch_cuda_lbp_op as lbp 10 | from ops.lbp_stereo.inference_op import Inference 11 | from ops.lbp_stereo.message_passing_op_cuda import MessagePassing 12 | 13 | import numba 14 | 15 | class BP(SubNetwork): 16 | 17 | @staticmethod 18 | def get_linear_idx(pos, L_vec_size): 19 | vec_idx = L_vec_size // 2 + pos 20 | vec_idx = np.maximum(1, vec_idx) 21 | vec_idx = np.minimum(L_vec_size - 1, vec_idx) 22 | 23 | return vec_idx 24 | 25 | @staticmethod 26 | def f_func(t, s, L_vec): 27 | 28 | o = L_vec[0] 29 | 30 | input_pos = t - s + o 31 | 32 | lower_pos = int(np.floor(input_pos)) 33 | upper_pos = int(np.ceil(input_pos)) 34 | 35 | if lower_pos == upper_pos: 36 | vec_idx = BP.get_linear_idx(lower_pos, L_vec.shape[0]) 37 | #print(L_vec[vec_idx]) 38 | return L_vec[vec_idx] 39 | 40 | lower_vec_idx = BP.get_linear_idx(lower_pos, L_vec.shape[0]) 41 | upper_vec_idx = BP.get_linear_idx(upper_pos, L_vec.shape[0]) 42 | 43 | lower_val = L_vec[lower_vec_idx] 44 | upper_val = L_vec[upper_vec_idx] 45 | 46 | weight_upper = input_pos - lower_pos 47 | weight_lower = upper_pos - input_pos 48 | 49 | interp_val = weight_lower * lower_val + weight_upper * upper_val 50 | 51 | return interp_val 52 | 53 | @staticmethod 54 | def construct_pw_energy_weights(K, L1, L2): 55 | '''K = num labels''' 56 | fij = np.ones((K, K)) * L2 57 | for i in range(-1, 2): 58 | if i == 0: 59 | fij -= np.diag(L2 * np.ones(K - np.abs(i)), i) 60 | else: 61 | fij -= np.diag((L2 - L1) * np.ones(K - np.abs(i)), i) 62 | 63 | return fij 64 | 65 | @staticmethod 66 | def construct_pw_prob_weights(K, L1, L2): 67 | '''K = num labels''' 68 | fij = np.zeros((K, K)) 69 | 70 | for i in range(-1, 2): 71 | if i == 0: 72 | fij += np.diag(L1 * np.ones(K - np.abs(i)), i) 73 | else: 74 | fij += np.diag(L2 * np.ones(K - np.abs(i)), i) 75 | 76 | fij[fij == 0] = 1.0 - L1 - L2 #(1.0 - L1 - 2 * L2) / (K - 3.0) #1.0 - L1 - L2 77 | 78 | return fij 79 | 80 | # modes = wta / expectation 81 | def __init__(self, device, args, max_iter, num_labels, delta, mode_inference='expectation', mode_message_passing='min-sum', layer_idx=0, level=0): 82 | super(BP, self).__init__(args, device) 83 | self.device = device 84 | self.max_iter = max_iter 85 | self.layer_idx = layer_idx 86 | self.level = level 87 | self.delta = delta 88 | 89 | if mode_inference != 'wta' and mode_inference != 'expectation' and mode_message_passing != 'min-sum' and mode_inference != 'norm' and mode_inference != 'raw': 90 | raise ValueError("Unknown inference/message passing mode " + mode_inference + " " + mode_message_passing) 91 | 92 | self.message_passing = MessagePassing(self.device, self.max_iter, num_labels, delta, mode_message_passing) 93 | self.inference = Inference(self.device, mode_inference, mode_passing=mode_message_passing) 94 | 95 | if args.checkpoint_crf and args and 'checkpoint_crf' in args.__dict__.keys() and args.checkpoint_crf[0] is not None: 96 | self.load_parameters(args.checkpoint_crf[self.level][self.layer_idx], device) 97 | 98 | def forward(self, prob_vol, edge_weights, affinities = None, offsets = None): 99 | 100 | if len(prob_vol.shape) == 5: 101 | N, _, H, W, K = prob_vol.shape 102 | 103 | # u-flow 104 | messages_u = torch.zeros((N, 4, H, W, K), requires_grad=True, device=self.device, dtype=torch.float) 105 | beliefs_u, messages_u = self.message_passing.forward(prob_vol[:,0].contiguous(), edge_weights, affinities, offsets, messages_u) 106 | result_u = self.inference.forward(beliefs_u) 107 | 108 | # v-flow 109 | messages_v = torch.zeros((N, 4, H, W, K), requires_grad=True, device=self.device, dtype=torch.float) 110 | beliefs_v, messages_v = self.message_passing.forward(prob_vol[:,1].contiguous(), edge_weights, affinities, offsets, messages_u) 111 | result_v = self.inference.forward(beliefs_v) 112 | 113 | flow = torch.cat((result_u, result_v), dim=-1).permute(0,3,1,2) 114 | beliefs = torch.cat((beliefs_u.unsqueeze(1), beliefs_v.unsqueeze(1)), dim=1) 115 | messages = torch.cat((messages_u.unsqueeze(1), messages_v.unsqueeze(1)), dim=1) 116 | 117 | return flow, beliefs, messages 118 | else: # 4 119 | N, H, W, K = prob_vol.shape 120 | 121 | # disps 122 | messages = torch.zeros((N, 4, H, W, K), requires_grad=True, device=self.device, dtype=torch.float) 123 | beliefs, messages = self.message_passing.forward(prob_vol.contiguous(), edge_weights, affinities, offsets, messages) 124 | result = self.inference.forward(beliefs) 125 | 126 | disps = result.permute(0,3,1,2) 127 | return disps, beliefs, messages 128 | 129 | 130 | def adjust_input_weights(self, weights, idx): 131 | if weights is not None: 132 | weights_idx = weights[:, idx * 2 : (idx + 1) * 2, :, :] 133 | 134 | # wx_L = np.zeros_like(wx) 135 | # wy_D = np.zeros_like(wy) 136 | # wx_L[:, 1:] = wx[:, :-1] 137 | # wy_D[1:, :] = wy[:-1, :] 138 | 139 | weights_input = torch.zeros((weights_idx.shape[0], 4, weights_idx.shape[2], weights_idx.shape[3])).cuda() 140 | 141 | weights_input[:, 0] = weights[:, 0] 142 | # wx RL 143 | weights_input[:, 1, :, 1:] = weights[:, 0, :, :-1] 144 | # wy UD 145 | weights_input[:, 2] = weights[:, 1] 146 | # wy DU 147 | weights_input[:, 3, 1:, :] = weights[:, 1, :-1, :] 148 | 149 | weights_input = weights_input.contiguous() 150 | 151 | else: 152 | weights_input = None 153 | 154 | return weights_input 155 | 156 | def adjust_input_affinities(self, affinities): 157 | # create affinities for 4 directions 158 | if affinities is not None: 159 | # outshape = N x 2, 5 x H x W 160 | # ensure ordering constraint # L2-, L2+, L1-, L1+, L3 161 | # L3 >= L2 <= L1 162 | affinities_new = affinities.clone() 163 | affinities_new[:, :, 0] = torch.max(affinities[:, :, 0], affinities[:, :, 2]) 164 | affinities_new[:, :, 1] = torch.max(affinities[:, :, 1], affinities[:, :, 3]) 165 | affinities_new[:, :, 4] = torch.max(affinities[:, :, 4], 166 | torch.max(affinities_new[:, :, 0].clone(), 167 | affinities_new[:, :, 1].clone())) 168 | 169 | affinities = affinities_new 170 | 171 | 172 | # shifted affinities 173 | affinities_shift = torch.zeros((affinities.shape[0], affinities.shape[1] + 2, affinities.shape[2], affinities.shape[3], affinities.shape[4])).cuda() 174 | 175 | # ax LR 176 | affinities_shift[:, 0] = affinities[:, 0] 177 | # ax RL 178 | affinities_shift[:, 1, :, :, 1:] = affinities[:, 0, :, :, :-1] 179 | # ay UD 180 | affinities_shift[:, 2] = affinities[:, 1] 181 | # ay DU 182 | affinities_shift[:, 3, :, 1:, :] = affinities[:, 1, :, :-1, :] 183 | 184 | affinities_shift = affinities_shift.contiguous() 185 | else: 186 | affinities_shift = None 187 | 188 | return affinities_shift 189 | 190 | def adjust_input_offsets(self, offsets): 191 | # create offsets for 4 directions 192 | if offsets is not None: 193 | # shifted offsets 194 | offsets_shift = torch.zeros((offsets.shape[0], offsets.shape[1] + 2, offsets.shape[2], offsets.shape[3])).cuda() 195 | 196 | # ox LR 197 | offsets_shift[:, 0] = offsets[:, 0] 198 | # ox RL 199 | offsets_shift[:, 1, :, 1:] = -offsets[:, 0, :, :-1] 200 | # oy UD 201 | offsets_shift[:, 2] = offsets[:, 1] 202 | # oy DU 203 | offsets_shift[:, 3, 1:, :] = -offsets[:, 1, :-1, :] 204 | 205 | offsets_shift = offsets_shift.contiguous() 206 | else: 207 | offsets_shift = None 208 | 209 | return offsets_shift 210 | 211 | 212 | def project_jumpcosts(self): 213 | self.message_passing.projectL1L2() 214 | 215 | def save_checkpoint(self, epoch, iteration): 216 | if 'c' in self.args.train_params: 217 | torch.save(self.state_dict(), 218 | osp.join(self.args.train_dir, 'crf' + str(self.layer_idx) + '_lvl' + str(self.level) + '_checkpoint_' + 219 | str(epoch) + '_' + str(iteration).zfill(6) + '.cpt')) 220 | 221 | def hook_adjust_checkpoint(self, checkpoint): 222 | # allow to continue a training where the temperature parameter after the BP did not exist 223 | if 'message_passing.softmin.T' not in checkpoint.keys(): 224 | print('Info: Adjust loaded BP-checkpoint -> Use temperature T=1 (=fully backward compatible)') 225 | checkpoint['message_passing.softmin.T'] = self.message_passing.softmin.T 226 | 227 | if 'message_passing.L1' in checkpoint.keys() and 'message_passing.L2' in checkpoint.keys(): 228 | print('Info: Adjust loaded BP-checkpoint -> Use setter for L1 L2') 229 | #L1 = nn.Parameter(checkpoint['message_passing.L1']) 230 | #L2 = nn.Parameter(checkpoint['message_passing.L2']) 231 | #self.setL1L2(L1, L2) 232 | checkpoint['message_passing._L1'] = checkpoint['message_passing.L1'] 233 | checkpoint['message_passing._L2'] = checkpoint['message_passing.L2'] 234 | checkpoint.pop('message_passing.L1') 235 | checkpoint.pop('message_passing.L2') 236 | 237 | return checkpoint 238 | 239 | @property 240 | def L1(self): 241 | return self.message_passing.L1 242 | 243 | @property 244 | def L2(self): 245 | return self.message_passing.L2 246 | 247 | def setL1L2(self, value_L1, value_L2): 248 | self.message_passing.setL1L2(value_L1, value_L2) 249 | 250 | @property 251 | def rescaleT(self): 252 | return self.message_passing.rescaleT 253 | 254 | @rescaleT.setter 255 | def rescaleT(self, value): 256 | self.message_passing.rescaleT = value 257 | 258 | 259 | 260 | -------------------------------------------------------------------------------- /ops/lbp_stereo/inference_op.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | import pytorch_cuda_lbp_op as lbp 7 | 8 | class Inference(nn.Module): 9 | # modes = wta / expectation 10 | def __init__(self, device, mode='wta', mode_passing='min-sum'): 11 | super(Inference, self).__init__() 12 | self.device = device 13 | if mode != 'wta' \ 14 | and mode != 'expectation' \ 15 | and mode != 'norm' \ 16 | and mode != 'raw' \ 17 | and mode != 'sub-exp': 18 | raise ValueError("Unknown inference mode " + mode) 19 | self.mode = mode 20 | self.mode_passing = mode_passing 21 | 22 | def forward(self, beliefs): 23 | if self.mode == "wta" and self.mode_passing == "min-sum": 24 | res = torch.argmax(beliefs, dim=3, keepdim=True).float() 25 | 26 | if self.mode == "wta": 27 | res = torch.argmax(beliefs, dim=3, keepdim=True) 28 | elif self.mode == "expectation": 29 | beliefs_normal = beliefs / beliefs.sum(dim=3, keepdim=True) 30 | 31 | if torch.isnan(beliefs_normal).sum() > 0: 32 | print("Beliefs normalized contains " + str(torch.isnan(beliefs_normal).sum()) + \ 33 | " NaNs ;(") 34 | 35 | labels = np.arange(beliefs.shape[3])[np.newaxis, np.newaxis, np.newaxis, :] 36 | labels_tensor = torch.tensor(labels.astype('float32'), device=self.device) 37 | res = (beliefs_normal * labels_tensor).sum(dim=3, keepdim=True) 38 | 39 | elif self.mode == "norm": 40 | beliefs_normal = beliefs / beliefs.sum(dim=3, keepdim=True) 41 | res = beliefs_normal 42 | 43 | elif self.mode == "raw": 44 | #print("using raw inference...") 45 | res = beliefs 46 | 47 | elif self.mode == 'sub-exp': 48 | res = self.compute_sub_expectation(beliefs) 49 | 50 | return res 51 | 52 | @staticmethod 53 | def compute_sub_expectation(beliefs, support=3): 54 | N, H, W, K = beliefs.shape 55 | device = beliefs.device 56 | 57 | disp = beliefs.argmax(dim=-1).unsqueeze(-1) 58 | 59 | # generate coordinates 60 | n_coords = torch.arange(N, device=device, dtype=torch.long) 61 | n_coords = n_coords.view(-1, 1, 1, 1) 62 | 63 | x_coords = torch.arange(W, device=device, dtype=torch.long).view(1, 1, -1, 1) 64 | y_coords = torch.arange(H, device=device, dtype=torch.long).view(1, -1, 1, 1) 65 | 66 | # nl = n_coords.expand((N, H, W, K)).long() 67 | # xl = x_coords.expand((N, H, W, K)).long() 68 | # yl = y_coords.expand((N, H, W, K)).long() 69 | 70 | #disp_multiple_hot = torch.zeros((N, H, W, K), device=device, dtype=torch.float) 71 | #torch.cuda.empty_cache() 72 | #for offset in range(-support, support + 1): 73 | # disp_offs = torch.min(torch.max(disp + offset, torch.tensor(0, device=device)), 74 | # torch.tensor(K - 1, device=device)) 75 | # print(offset) 76 | # disp_multiple_hot[nl, yl, xl, disp_offs] = 1 77 | # torch.cuda.empty_cache() 78 | 79 | # disps_range = torch.arange(K, device=device, dtype=torch.float).view(1, 1, 1, -1) 80 | 81 | # beliefs_max = beliefs * disp_multiple_hot 82 | # beliefs_max_normalized = beliefs_max / beliefs_max.sum(dim=-1, keepdim=True) 83 | # disp_subpix = torch.sum(beliefs_max_normalized * disps_range, dim=-1, keepdim=True) 84 | 85 | # reduces GPU memory requirement significantly 86 | ws = 2 * support + 1 87 | nl = n_coords.expand((N, H, W, ws)).long() 88 | xl = x_coords.expand((N, H, W, ws)).long() 89 | yl = y_coords.expand((N, H, W, ws)).long() 90 | 91 | disp_windows = torch.arange(-support, support + 1).cuda() + disp 92 | disp_windows = torch.min(torch.max(disp_windows, torch.tensor(0, device=device)), torch.tensor(K - 1, device=device)) 93 | beliefs_windows = beliefs[nl, yl, xl, disp_windows] 94 | beliefs_windows_normalized = beliefs_windows / beliefs_windows.sum(dim=-1, keepdim=True) 95 | disp_subpix = torch.sum(beliefs_windows_normalized * disp_windows, dim=-1, keepdim=True) 96 | 97 | return disp_subpix 98 | -------------------------------------------------------------------------------- /ops/lbp_stereo/message_passing_op_cuda.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from corenet import TemperatureSoftmin 6 | 7 | import pytorch_cuda_lbp_op as lbp 8 | 9 | def construct_pw_tensor(L1, L2, K): 10 | pw_tensor = torch.ones((K, K), dtype=torch.float) * L2 11 | 12 | for i in range(-1, 2): 13 | if i == 0: 14 | pw_tensor -= torch.diag(L2 * torch.ones(K - torch.abs(torch.tensor(i))), i) 15 | else: 16 | pw_tensor -= torch.diag((L2 - L1) * torch.ones(K - torch.abs(torch.tensor(i))), i) 17 | 18 | pw_tensor = pw_tensor.unsqueeze(0).unsqueeze(0) 19 | 20 | return pw_tensor 21 | 22 | class LBPMinSumFunction(torch.autograd.Function): 23 | @staticmethod 24 | def forward(ctx, cost, edge, messages, delta, affinities, offset): 25 | 26 | #print("Min sum forward OP......") 27 | 28 | #jump = construct_pw_tensor(L1, L2, cost.shape[3]).cuda() 29 | 30 | # affinities shape = N x 2*2 x 5 x H x W dim1 = 4 => already shifted 31 | # affinties dim1 = LR, RL, UD, DU 32 | # affinities dim2 = # L2-, L2+, L1-, L1+, L3 33 | # affiniteis_input shape = N x 2*2 x 5+3 x H x W 34 | # affinities_input dim2 = # offset, L3, L2-, L1-, L0, L1+, L2+, L3 35 | affinities_input = torch.zeros((affinities.shape[0], affinities.shape[1], affinities.shape[2] + 3, affinities.shape[3], affinities.shape[4])).cuda() 36 | 37 | affinities_input[:, :, 0, :, :] = offset 38 | 39 | # L3 is symmetric => assign all directions 40 | affinities_input[:, :, 1, :, :] = affinities[:, :, -1, :, :] 41 | affinities_input[:, :, -1, :, :] = affinities[:, :, -1, :, :] 42 | 43 | # L0 44 | affinities_input[:, :, int(affinities_input.shape[2] / 2), :, :] = 0.0 45 | 46 | # everything else 47 | in_count = 0 # affinities counter 48 | for i in range(2, int(affinities_input.shape[2] / 2)): # i = affinities_input counter 49 | # copy original ordering for LR and UD (first negative, then positive) 50 | affinities_input[:, 0::2, i, :, :] = affinities[:, 0::2, in_count, :, :] 51 | affinities_input[:, 0::2, -i, :, :] = affinities[:, 0::2, in_count + 1, :, :] 52 | 53 | # copy reverse ordering for RL and DU (first positive, then negative) 54 | affinities_input[:, 1::2, i, :, :] = affinities[:, 1::2, in_count + 1, :, :] 55 | affinities_input[:, 1::2, -i, :, :] = affinities[:, 1::2, in_count, :, :] 56 | in_count += 2 57 | 58 | #print(affinities_input[0, 0, :, 0, 0]) 59 | 60 | affinities_input = affinities_input.contiguous() 61 | torch.cuda.empty_cache() 62 | messages, messages_argmin, message_scale = lbp.forward_minsum(cost, affinities_input, edge, messages, delta) 63 | 64 | ctx.save_for_backward(cost, affinities_input, edge, messages, messages_argmin, message_scale) 65 | return messages 66 | 67 | @staticmethod 68 | # @profile 69 | def backward(ctx, in_grad): 70 | cost, affinities_input, edge, messages, messages_argmin, message_scale = ctx.saved_tensors 71 | 72 | grad_cost, grad_affinities_input, grad_edge, grad_message = lbp.backward_minsum(cost, affinities_input, edge, in_grad.contiguous(), messages, messages_argmin, message_scale) 73 | 74 | #re-compute affinities grad for all learned params 75 | grad_affinities_out = torch.zeros((affinities_input.shape[0], affinities_input.shape[1], affinities_input.shape[2] - 3, affinities_input.shape[3], affinities_input.shape[4])).cuda() 76 | 77 | # sum up grad L3 78 | grad_affinities_out[:, :, -1, :, :] += grad_affinities_input[:, :, 1, :, :] 79 | grad_affinities_out[:, :, -1, :, :] += grad_affinities_input[:, :, -1, :, :] 80 | 81 | in_count = 0 82 | for i in range(2, int(affinities_input.shape[2] / 2)): 83 | grad_affinities_out[:, 0::2, in_count, :, :] = grad_affinities_input[:, 0::2, i, :, :] 84 | grad_affinities_out[:, 0::2, in_count + 1, :, :] = grad_affinities_input[:, 0::2, -i, :, :] 85 | 86 | grad_affinities_out[:, 1::2, in_count + 1, :, :] = grad_affinities_input[:, 1::2, i, :, :] 87 | grad_affinities_out[:, 1::2, in_count, :, :] = grad_affinities_input[:, 1::2, -i, :, :] 88 | in_count += 2 89 | 90 | #offset grad 91 | grad_offset = grad_affinities_input[:, :, 0, :, :] 92 | 93 | return grad_cost, grad_edge, grad_message, None, grad_affinities_out, grad_offset 94 | 95 | class MessagePassing(nn.Module): 96 | 97 | @property 98 | def L1(self): 99 | return self._L1 100 | 101 | @property 102 | def L2(self): 103 | return self._L2 104 | 105 | def setL1L2(self, value_L1, value_L2): 106 | 107 | if value_L1 > 0 and value_L2 > 0 and value_L1 <= value_L2: 108 | self._L1 = value_L1 109 | self._L2 = value_L2 110 | elif value_L1 < 0 or value_L2 < 0: 111 | raise ValueError("L1 or L2 is < 0!") 112 | elif value_L1 > value_L2: 113 | raise ValueError("L1 must be smaller than or equal L2!") 114 | 115 | def __init__(self, device, max_iter, num_labels, delta, mode='min-sum'): 116 | super(MessagePassing, self).__init__() 117 | self.device = device 118 | self.max_iter = max_iter 119 | 120 | if mode != 'min-sum': 121 | raise ValueError("Unknown message parsing mode " + mode) 122 | self.mode = mode 123 | 124 | L1 = torch.tensor(0.1, device=device) 125 | L2 = torch.tensor(2.5, device=device) 126 | self._L1 = nn.Parameter(L1, requires_grad=True) 127 | self._L2 = nn.Parameter(L2, requires_grad=True) 128 | 129 | self.softmin = TemperatureSoftmin(dim=3, init_temp=1.0) 130 | 131 | self.delta = delta 132 | 133 | self.rescaleT = None 134 | 135 | def projectL1L2(self): 136 | self.L2.data = torch.max(self.L1.data, self.L2.data) 137 | 138 | def forward(self, prob_vol, edge_weights, affinities, offset, messages): 139 | N, H, W, C = prob_vol.shape 140 | 141 | NUM_DIR = 4 142 | 143 | if edge_weights is None: 144 | edge_weights = torch.ones((N, NUM_DIR, H, W)).cuda() 145 | 146 | if affinities is None: 147 | if (self._L1 is not None) and (self._L2 is not None): 148 | # parameters are expected as follows L1-left L1-right L2 149 | affinities = torch.zeros((N, NUM_DIR, 3, H, W), dtype=torch.float).cuda() 150 | affinities[:, :, :2, :, :] = self._L1 151 | affinities[:, :, 2, :, :] = self._L2 152 | 153 | else: 154 | if (self._L1 is not None and self._L2 is None) or (self._L1 is None and self._L2 is not None): 155 | raise ValueError("L1 or L2 is None and affinities are not set!") 156 | 157 | if offset is None: 158 | offset = torch.zeros((N, NUM_DIR, H, W)) 159 | 160 | if self.mode == 'min-sum': 161 | # convert to cost-input 162 | cost = -prob_vol 163 | 164 | # perform message-passing iterations 165 | for it in range(self.max_iter): 166 | torch.cuda.empty_cache() 167 | messages = LBPMinSumFunction.apply(cost, edge_weights, messages, self.delta, affinities, offset) 168 | 169 | # compute beliefs 170 | beliefs = messages.sum(dim=1) + cost 171 | 172 | # normalize output 173 | beliefs = self.softmin.forward(beliefs) 174 | 175 | else: 176 | raise NotImplementedError("message parsing mode " + self.mode + " is currently not implemented!") 177 | 178 | return beliefs, messages -------------------------------------------------------------------------------- /ops/lbp_stereo/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | setup( 6 | name='pytorch-lbp-op', 7 | version='0.2', 8 | author="Patrick Knöbelreiter", 9 | author_email="knoebelreiter@icg.tugraz.at", 10 | packages=["src"], 11 | include_dirs=[], 12 | ext_modules=[ 13 | CUDAExtension('pytorch_cuda_lbp_op', [ 14 | 'src/lbp.cpp', 15 | 'src/lbp_min_sum_kernel.cu', 16 | ]), 17 | ], 18 | cmdclass={ 19 | 'build_ext': BuildExtension 20 | }, 21 | install_requires=[ 22 | "numpy >= 1.15", 23 | "torch >= 0.4.1", 24 | "matplotlib >= 3.0.0", 25 | "scikit-image >= 0.14.1", 26 | "numba >= 0.42" 27 | ], 28 | ) 29 | -------------------------------------------------------------------------------- /ops/lbp_stereo/src/lbp.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "lbp_min_sum_kernel.cuh" 5 | 6 | // C++ interface 7 | // AT_ASSERTM in pytorch 1.0 8 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x); 11 | 12 | 13 | // ================================================================================================ 14 | // MIN-SUM LBP 15 | // ================================================================================================ 16 | std::vector lbp_forward_min_sum(at::Tensor cost, 17 | at::Tensor jump, 18 | at::Tensor edge, 19 | at::Tensor messages, unsigned short delta) 20 | { 21 | CHECK_INPUT(cost) 22 | CHECK_INPUT(jump) 23 | CHECK_INPUT(edge) 24 | CHECK_INPUT(messages) 25 | 26 | return cuda::lbp_forward_min_sum(cost, jump, edge, messages, delta); 27 | } 28 | 29 | std::vector lbp_backward_min_sum(at::Tensor cost, 30 | at::Tensor jump, 31 | at::Tensor edge, 32 | at::Tensor in_grad, 33 | at::Tensor messages, 34 | at::Tensor messages_argmin, 35 | at::Tensor message_scale) 36 | { 37 | CHECK_INPUT(cost) 38 | CHECK_INPUT(jump) 39 | CHECK_INPUT(edge) 40 | CHECK_INPUT(in_grad) 41 | CHECK_INPUT(messages) 42 | CHECK_INPUT(messages_argmin) 43 | CHECK_INPUT(message_scale) 44 | 45 | return cuda::lbp_backward_min_sum(cost, jump, edge, in_grad, messages, messages_argmin, message_scale); 46 | } 47 | 48 | // ================================================================================================ 49 | // Pytorch Interfaces 50 | // ================================================================================================ 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 52 | { 53 | m.def("forward_minsum", &lbp_forward_min_sum, "LBP forward (CUDA)"); 54 | m.def("backward_minsum", &lbp_backward_min_sum, "LBP backward (CUDA)"); 55 | } 56 | 57 | -------------------------------------------------------------------------------- /ops/lbp_stereo/src/lbp_min_sum_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace cuda 6 | { 7 | std::vector lbp_forward_min_sum(at::Tensor cost, 8 | at::Tensor jump, 9 | at::Tensor edge, 10 | at::Tensor messages, unsigned short delta); 11 | 12 | std::vector lbp_backward_min_sum(at::Tensor cost, 13 | at::Tensor jump, 14 | at::Tensor edge, 15 | at::Tensor in_grad, 16 | at::Tensor messages, 17 | at::Tensor messages_argmin, 18 | at::Tensor message_scale); 19 | 20 | } -------------------------------------------------------------------------------- /ops/lbp_stereo/src/util.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #include "../../include/tensor.h" 6 | 7 | enum DIRECTION {LEFT, RIGHT, UP, DOWN}; 8 | 9 | extern __device__ __forceinline__ float getDerivativeValue(KernelData prev_values, int direction, float current_value, float max_float) 10 | { 11 | float ret_val = 0.0; 12 | if(prev_values(0, direction, 0, 0) != max_float) 13 | { 14 | ret_val = prev_values(0, direction, 0, 0); 15 | } 16 | else 17 | { 18 | ret_val = current_value; 19 | } 20 | 21 | return ret_val; 22 | } 23 | 24 | 25 | extern __device__ __forceinline__ float computeCrossGradient(KernelData prev_values, float max_float) 26 | { 27 | 28 | bool is_cross = false; 29 | if(prev_values(0,0,0,0) != max_float || 30 | prev_values(0,1,0,0) != max_float || 31 | prev_values(0,2,0,0) != max_float || 32 | prev_values(0,3,0,0) != max_float) 33 | { 34 | is_cross = true; 35 | } 36 | 37 | return is_cross; 38 | } 39 | 40 | extern __device__ __forceinline__ float getGradientAcc(KernelData gradient_accumulation, int direction, int n, int y, int x, int c, int grad_acc_idx) 41 | { 42 | 43 | float ret_val = 0.0; 44 | if(direction == UP || direction == DOWN) 45 | { 46 | ret_val = gradient_accumulation(n, x, grad_acc_idx, c); 47 | } 48 | if(direction == LEFT || direction == RIGHT) 49 | { 50 | ret_val = gradient_accumulation(n, y, grad_acc_idx, c); 51 | } 52 | 53 | return ret_val; 54 | 55 | } 56 | 57 | extern __device__ __forceinline__ void updateGradientAcc(KernelData gradient_accumulation, float value, int direction, int n, int y, int x, int c, int grad_acc_idx) 58 | { 59 | 60 | if(direction == UP || direction == DOWN) 61 | { 62 | //gradient_accumulation(n, x, grad_acc_idx, c) = value; 63 | atomicAdd(&gradient_accumulation(n, x, grad_acc_idx, c), value); 64 | } 65 | if(direction == LEFT || direction == RIGHT) 66 | { 67 | //gradient_accumulation(n, y, grad_acc_idx, c) = value; 68 | atomicAdd(&gradient_accumulation(n, y, grad_acc_idx, c), value); 69 | } 70 | 71 | } 72 | 73 | extern __device__ __forceinline__ void setGradientAcc(KernelData gradient_accumulation, float value, int direction, int n, int y, int x, int c, int grad_acc_idx) 74 | { 75 | 76 | if(direction == UP || direction == DOWN) 77 | { 78 | gradient_accumulation(n, x, grad_acc_idx, c) = value; 79 | } 80 | if(direction == LEFT || direction == RIGHT) 81 | { 82 | gradient_accumulation(n, y, grad_acc_idx, c) = value; 83 | } 84 | 85 | } 86 | 87 | extern __device__ __forceinline__ float getEdgeWeight(KernelData edges, int n, int y, int x, int direction) 88 | { 89 | 90 | float w = 1.0; 91 | if(direction == UP || direction == DOWN) 92 | { 93 | w = edges(n, 1, y, x); 94 | } 95 | else 96 | { 97 | w = edges(n, 0, y, x); 98 | } 99 | 100 | return w; 101 | } 102 | 103 | extern __device__ __forceinline__ int getLinearIdx(int pos, int L_vec_size) 104 | { 105 | int vec_idx = L_vec_size / 2 + pos; 106 | vec_idx = max(1, vec_idx); 107 | vec_idx = min(L_vec_size - 1, vec_idx); 108 | return vec_idx; 109 | } 110 | 111 | extern __device__ __forceinline__ float getJumpCost(int t, int s, KernelData5 jump_cost, int n, int direction, int y, int x) 112 | { 113 | int num_L = jump_cost.size2; 114 | 115 | float input_pos = t - s; 116 | 117 | int vec_idx = getLinearIdx(input_pos, num_L); 118 | return jump_cost(n, direction, vec_idx, y, x); 119 | } 120 | 121 | extern __device__ __forceinline__ void addGradientJump(int t, int s, KernelData5 jump_cost, int n, int direction, int y, int x, KernelData5 gradient_pairwise, int grad_xy_idx, float additive_hor, float additive_up, float additive_down) 122 | { 123 | int num_L = jump_cost.size2; 124 | 125 | float input_pos = t - s; 126 | 127 | int vec_idx = getLinearIdx(input_pos, num_L); 128 | 129 | atomicAdd(&gradient_pairwise(n, grad_xy_idx, vec_idx, y, x), additive_hor); 130 | atomicAdd(&gradient_pairwise(n, grad_xy_idx, vec_idx, y, x), additive_up); 131 | atomicAdd(&gradient_pairwise(n, grad_xy_idx, vec_idx, y, x), additive_down); 132 | } 133 | -------------------------------------------------------------------------------- /ops/sad/src/stereo_sad.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "stereo_sad_kernel.cuh" 3 | 4 | // C++ interface 5 | // AT_ASSERTM in pytorch 1.0 6 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 7 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 8 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x); 9 | 10 | at::Tensor stereo_sad_forward(at::Tensor f0, at::Tensor f1, int min_disp, int max_disp, float step) 11 | { 12 | CHECK_INPUT(f0) 13 | CHECK_INPUT(f1) 14 | return cuda::stereo_sad_forward(f0, f1, min_disp, max_disp, step); 15 | } 16 | 17 | std::vector stereo_sad_backward(at::Tensor f0, at::Tensor f1, int min_disp, int max_disp, 18 | at::Tensor in_grad) 19 | { 20 | CHECK_INPUT(f0) 21 | CHECK_INPUT(f1) 22 | CHECK_INPUT(in_grad) 23 | return cuda::stereo_sad_backward(f0, f1, min_disp, max_disp, in_grad); 24 | } 25 | 26 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 27 | { 28 | m.def("forward", &stereo_sad_forward, "SAD Matching forward (CUDA)"); 29 | m.def("backward", &stereo_sad_backward, "SAD Matching backward (CUDA)"); 30 | } -------------------------------------------------------------------------------- /ops/sad/src/stereo_sad_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "stereo_sad_kernel.cuh" 4 | #include "tensor.h" 5 | #include "error_util.h" 6 | 7 | // get y for position x 8 | __device__ float dLinearInterpolation1D(float x, float x0, float x1, float y0, float y1) 9 | { 10 | return y0 + (y1 - y0) * (x - x0) / (x1 - x0); 11 | } 12 | 13 | // ============================================================================ 14 | // CUDA KERNELS 15 | // ============================================================================ 16 | __global__ void stereo_sad_cuda_forward_kernel( 17 | KernelData f0, 18 | KernelData f1, 19 | int min_disp, 20 | int max_disp, 21 | float step, 22 | KernelData output 23 | ) 24 | { 25 | const unsigned int x = blockIdx.x * blockDim.x + threadIdx.x; 26 | const unsigned int y = blockIdx.y * blockDim.y + threadIdx.y; 27 | const unsigned int d_idx = blockIdx.z * blockDim.z + threadIdx.z; 28 | 29 | // check inside image 30 | int n = 0; 31 | if(x >= f0.size3 || y >= f0.size2 || d_idx >= output.size3) 32 | return; 33 | 34 | // 35 | int d = (d_idx * step) + min_disp; 36 | 37 | // skip outside pixels 38 | if(x - d < 0 || x - d >= f0.size3) 39 | return; 40 | 41 | float sad = 0.0f; 42 | for(int c = 0; c < f0.size1; ++c) 43 | { 44 | float f1_c = 0.0; 45 | if(step == 1.0) 46 | { 47 | f1_c = f1(n, c, y, x - d); 48 | } 49 | else 50 | { 51 | int floor_x = (int) floorf(x - d); 52 | int ceil_x = floor_x + 1; 53 | float x_pos = x - d; 54 | f1_c = dLinearInterpolation1D(x_pos, floor_x, ceil_x, f1(n, c, y, floor_x), f1(n, c, y, ceil_x)); 55 | } 56 | 57 | sad += fabs(f0(n, c, y, x) - f1_c); 58 | } 59 | 60 | // write result back to global memory 61 | output(n, y, x, d_idx) = sad; 62 | } 63 | 64 | __global__ void stereo_sad_cuda_backward_kernel( 65 | KernelData f0, 66 | KernelData f1, 67 | int min_disp, 68 | int max_disp, 69 | KernelData in_grad, 70 | KernelData df0, 71 | KernelData df1 72 | ) 73 | { 74 | const unsigned int x = blockIdx.x * blockDim.x + threadIdx.x; 75 | const unsigned int y = blockIdx.y * blockDim.y + threadIdx.y; 76 | const unsigned int c = blockIdx.z * blockDim.z + threadIdx.z; 77 | 78 | float eps = 1e-15; 79 | 80 | // check inside image 81 | int n = 0; 82 | if(x >= f0.size3 || y >= f0.size2 || c >= f0.size1) 83 | return; 84 | 85 | float grad_f0 = 0.0f; 86 | float grad_f1 = 0.0f; 87 | for(int d = min_disp; d <= max_disp; ++d) 88 | { 89 | int idx = d - min_disp; 90 | // skip outside pixels 91 | if(x - d >= 0 && x - d < f0.size3) 92 | { 93 | float diff = f0(n, c, y, x) - f1(n, c, y, x - d); 94 | if(fabsf(diff) > eps) // gradient is zero if diff is zero! 95 | grad_f0 += (diff / fabsf(diff)) * in_grad(n, y, x, idx); 96 | } 97 | 98 | if(x + d >= 0 && x + d < f0.size3) 99 | { 100 | float diff1 = f0(n, c, y, x + d) - f1(n, c, y, x); 101 | if(fabsf(diff1) > eps) 102 | grad_f1 -= (diff1 / fabsf(diff1)) * in_grad(n, y, x + d, idx); 103 | } 104 | } 105 | 106 | df0(n, c, y, x) = grad_f0; 107 | df1(n, c, y, x) = grad_f1; 108 | } 109 | 110 | 111 | // ============================================================================ 112 | // CPP KERNEL CALLS 113 | // ============================================================================ 114 | namespace cuda 115 | { 116 | at::Tensor stereo_sad_forward(at::Tensor f0, at::Tensor f1, int min_disp, int max_disp, float step) 117 | { 118 | int N = f0.size(0); 119 | int C = f0.size(1); 120 | int H = f0.size(2); 121 | int W = f0.size(3); 122 | int D = (max_disp - min_disp + 1) / step; 123 | 124 | 125 | auto cost_vol = at::ones({N, H, W, D}, f0.options()) * 40; 126 | 127 | // parallelise over H x W x D 128 | const dim3 blockSize(8, 8, 4); 129 | const dim3 numBlocks(std::ceil(W / static_cast(blockSize.x)), 130 | std::ceil(H / static_cast(blockSize.y)), 131 | std::ceil(D / static_cast(blockSize.z))); 132 | 133 | stereo_sad_cuda_forward_kernel<<>>(f0, f1, min_disp, max_disp, step, cost_vol); 134 | cudaSafeCall(cudaGetLastError()); 135 | return cost_vol; 136 | } 137 | 138 | std::vector stereo_sad_backward(at::Tensor f0, at::Tensor f1, 139 | int min_disp, int max_disp, 140 | at::Tensor in_grad) 141 | { 142 | int N = f0.size(0); 143 | int C = f0.size(1); 144 | int H = f0.size(2); 145 | int W = f0.size(3); 146 | int D = max_disp - min_disp + 1; 147 | 148 | auto df0 = at::zeros_like(f0); 149 | auto df1 = at::zeros_like(f1); 150 | 151 | // parallelise over H x W x D 152 | const dim3 blockSize(8, 8, 4); 153 | const dim3 numBlocks(std::ceil(W / static_cast(blockSize.x)), 154 | std::ceil(H / static_cast(blockSize.y)), 155 | std::ceil(D / static_cast(blockSize.z))); 156 | 157 | stereo_sad_cuda_backward_kernel<<>>(f0, f1, min_disp, max_disp, in_grad, 158 | df0, df1); 159 | cudaSafeCall(cudaGetLastError()); 160 | 161 | std::vector gradients; 162 | gradients.push_back(df0); 163 | gradients.push_back(df1); 164 | 165 | return gradients; 166 | } 167 | } -------------------------------------------------------------------------------- /ops/sad/src/stereo_sad_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace cuda 5 | { 6 | at::Tensor stereo_sad_forward(at::Tensor f0, at::Tensor f1, int min_disp, int max_disp, float step); 7 | std::vector stereo_sad_backward(at::Tensor f0, at::Tensor f1, int min_disp, 8 | int max_disp, at::Tensor in_grad); 9 | } -------------------------------------------------------------------------------- /ops/sad/stereo_sad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import pytorch_cuda_stereo_sad_op 4 | 5 | class StereoMatchingSadFunction(torch.autograd.Function): 6 | @staticmethod 7 | def forward(ctx, f0, f1, min_disp, max_disp, step=1.0): 8 | ctx.save_for_backward(f0, f1, torch.tensor(min_disp), torch.tensor(max_disp), torch.tensor(step)) 9 | res = pytorch_cuda_stereo_sad_op.forward(f0, f1, min_disp, max_disp, step) 10 | return res 11 | 12 | @staticmethod 13 | def backward(ctx, in_grad): 14 | f0, f1, min_disp, max_disp, step = ctx.saved_tensors 15 | if step != 1.0: 16 | raise ValueError("Error: Backward for step != 1 is not implemented!") 17 | df0, df1 = pytorch_cuda_stereo_sad_op.backward(f0, f1, int(min_disp), int(max_disp), 18 | in_grad) 19 | return df0, df1, None, None 20 | -------------------------------------------------------------------------------- /ops/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | setup( 6 | name='matching', 7 | version='1.0', 8 | author="Patrick Knöbelreiter", 9 | author_email="knoebelreiter@icg.tugraz.at", 10 | packages=["sad"], 11 | include_dirs=['include/' ], 12 | ext_modules=[ 13 | CUDAExtension('pytorch_cuda_stereo_sad_op', [ 14 | 'sad/src/stereo_sad.cpp', 15 | 'sad/src/stereo_sad_kernel.cu', 16 | ]), 17 | CUDAExtension('pytorch_cuda_flow_mp_sad_op', [ 18 | 'flow_mp_sad/src/flow_mp_sad.cpp', 19 | 'flow_mp_sad/src/flow_mp_sad_kernel.cu', 20 | ]) 21 | ], 22 | cmdclass={ 23 | 'build_ext': BuildExtension 24 | }, 25 | install_requires=[ 26 | "numpy >= 1.15", 27 | "torch >= 0.4.1", 28 | "scikit-image >= 0.14.1", 29 | ], 30 | ) 31 | -------------------------------------------------------------------------------- /run_flow.sh: -------------------------------------------------------------------------------- 1 | #/bin/sh 2 | IM0="data/frame_0019.png" 3 | IM1="data/frame_0020.png" 4 | 5 | cd ops/lbp_stereo 6 | python setup.py install 7 | cd ../../ 8 | 9 | python main_flow.py --model bp+ms+h --checkpoint-unary "data/params/flow/BP+MS (H)/unary_best.cpt" --checkpoint-matching "data/params/flow/BP+MS (H)/matching_lvl0_best.cpt" "data/params/flow/BP+MS (H)/matching_lvl1_best.cpt" "data/params/flow/BP+MS (H)/matching_lvl2_best.cpt" --checkpoint-affinity "data/params/flow/BP+MS (H)/affinity_best.cpt" --checkpoint-crf "data/params/flow/BP+MS (H)/crf0_lvl0_best.cpt" --checkpoint-crf "data/params/flow/BP+MS (H)/crf0_lvl1_best.cpt" --checkpoint-crf "data/params/flow/BP+MS (H)/crf0_lvl2_best.cpt" --multi-level-output --bp-inference sub-exp --with-bn --im0 $IM0 --im1 $IM1 10 | 11 | python main_flow.py --model bp+ms+ref+h --checkpoint-unary "data/params/flow/BP+MS+Ref (H)/unary_best.cpt" --checkpoint-matching "data/params/flow/BP+MS+Ref (H)/matching_lvl0_best.cpt" "data/params/flow/BP+MS+Ref (H)/matching_lvl1_best.cpt" "data/params/flow/BP+MS+Ref (H)/matching_lvl2_best.cpt" --checkpoint-affinity "data/params/flow/BP+MS+Ref (H)/affinity_best.cpt" --checkpoint-crf "data/params/flow/BP+MS+Ref (H)/crf0_lvl0_best.cpt" --checkpoint-crf "data/params/flow/BP+MS+Ref (H)/crf0_lvl1_best.cpt" --checkpoint-crf "data/params/flow/BP+MS+Ref (H)/crf0_lvl2_best.cpt" --checkpoint-refinement "data/params/flow/BP+MS+Ref (H)/refinement_best.cpt" --with-bn --bp-inference sub-exp --output-level-offset 0 --multi-level-output --im0 $IM0 --im1 $IM1 -------------------------------------------------------------------------------- /run_semantic_global.sh: -------------------------------------------------------------------------------- 1 | #/bin/sh 2 | IMG="data/frankfurt_val.png" 3 | cd ops/lbp_semantic_pw && 4 | python setup.py install && 5 | cd ../../ && 6 | python main_semantic.py --img=$IMG --checkpoint-semantic data/params/semantic/global_model.cpt --checkpoint-esp-net dependencies/ESPNet/pretrained --pairwise-type global --with-edges -------------------------------------------------------------------------------- /run_semantic_pixel.sh: -------------------------------------------------------------------------------- 1 | #/bin/sh 2 | IMG="data/frankfurt_val.png" 3 | cd ops/lbp_semantic_pw_pixel && 4 | python setup.py install && 5 | cd ../../ && 6 | 7 | python main_semantic.py --img=$IMG --checkpoint-semantic data/params/semantic/pixel_model.cpt --checkpoint-esp-net dependencies/ESPNet/pretrained --pairwise-type pixel -------------------------------------------------------------------------------- /run_stereo_kitti.sh: -------------------------------------------------------------------------------- 1 | #/bin/sh 2 | cd ops/lbp_stereo 3 | python setup.py install 4 | cd ../../ 5 | 6 | python main_stereo.py --model bp+ms+h --checkpoint-unary "data/params/stereo/BP+MS (H)/unary_best.cpt" --checkpoint-matching "data/params/stereo/Kitti/matching_lvl0_best.cpt" "data/params/stereo/Kitti/matching_lvl1_best.cpt" "data/params/stereo/Kitti/matching_lvl2_best.cpt" --checkpoint-affinity "data/params/stereo/Kitti/affinity_best.cpt" --checkpoint-crf "data/params/stereo/Kitti/crf0_lvl0_best.cpt" --checkpoint-crf "data/params/stereo/Kitti/crf0_lvl1_best.cpt" --checkpoint-crf "data/params/stereo/Kitti/crf0_lvl2_best.cpt" --with-bn --bp-inference sub-exp --input-level-offset 0 --output-level-offset 0 --multi-level-output --im0 "data/000010_10_left.png" --im1 "data/000010_10_right.png" -------------------------------------------------------------------------------- /run_stereo_mb.sh: -------------------------------------------------------------------------------- 1 | #/bin/sh 2 | cd ops/lbp_stereo 3 | python setup.py install 4 | cd ../../ 5 | 6 | python main_stereo.py --model bp+ms+h --checkpoint-unary "data/params/stereo/MB/unary_best.cpt" --checkpoint-matching "data/params/stereo/MB/matching_lvl0_best.cpt" "data/params/stereo/MB/matching_lvl1_best.cpt" "data/params/stereo/MB/matching_lvl2_best.cpt" --checkpoint-affinity "data/params/stereo/MB/affinity_best.cpt" --checkpoint-crf "data/params/stereo/MB/crf0_lvl0_best.cpt" --checkpoint-crf "data/params/stereo/MB/crf0_lvl1_best.cpt" --checkpoint-crf "data/params/stereo/MB/crf0_lvl2_best.cpt" --multi-level-output --bp-inference sub-exp --with-bn --input-level-offset 1 --output-level-offset 1 --im0 "data/im0.png" --im1 "data/im1.png" -------------------------------------------------------------------------------- /run_stereo_sf.sh: -------------------------------------------------------------------------------- 1 | #/bin/sh 2 | IM0="data/sf_0006_left.png" 3 | IM1="data/sf_0006_right.png" 4 | 5 | cd ops/lbp_stereo 6 | python setup.py install 7 | cd ../../ 8 | 9 | python main_stereo.py --model wta --checkpoint-unary "data/params/stereo/WTA (NLL)/unary_best.cpt" --checkpoint-matching "data/params/stereo/WTA (NLL)/matching_best.cpt" --with-bn --with-output-bn --im0 $IM0 --im1 $IM1 10 | 11 | python main_stereo.py --model bp+ms --checkpoint-unary "data/params/stereo/BP+MS (NLL)/unary_best.cpt" --checkpoint-matching "data/params/stereo/BP+MS (NLL)/matching_lvl0_best.cpt" "data/params/stereo/BP+MS (NLL)/matching_lvl1_best.cpt" "data/params/stereo/BP+MS (NLL)/matching_lvl2_best.cpt" --checkpoint-affinity "data/params/stereo/BP+MS (NLL)/affinity_best.cpt" --checkpoint-crf "data/params/stereo/BP+MS (H)/crf0_lvl0_best.cpt" --checkpoint-crf "data/params/stereo/BP+MS (H)/crf0_lvl1_best.cpt" --checkpoint-crf "data/params/stereo/BP+MS (H)/crf0_lvl2_best.cpt" --with-bn --bp-inference wta --multi-level-output --im0 $IM0 --im1 $IM1 12 | 13 | python main_stereo.py --model bp+ms+h --checkpoint-unary "data/params/stereo/BP+MS (H)/unary_best.cpt" --checkpoint-matching "data/params/stereo/BP+MS (H)/matching_lvl0_best.cpt" "data/params/stereo/BP+MS (H)/matching_lvl1_best.cpt" "data/params/stereo/BP+MS (H)/matching_lvl2_best.cpt" --checkpoint-affinity "data/params/stereo/BP+MS (H)/affinity_best.cpt" --checkpoint-crf "data/params/stereo/BP+MS (H)/crf0_lvl0_best.cpt" --checkpoint-crf "data/params/stereo/BP+MS (H)/crf0_lvl1_best.cpt" --checkpoint-crf "data/params/stereo/BP+MS (H)/crf0_lvl2_best.cpt" --multi-level-output --bp-inference sub-exp --with-bn --im0 $IM0 --im1 $IM1 14 | 15 | python main_stereo.py --model bp+ms+ref+h --checkpoint-unary "data/params/stereo/BP+MS+Ref (H)/unary_best.cpt" --checkpoint-matching "data/params/stereo/BP+MS+Ref (H)/matching_lvl0_best.cpt" "data/params/stereo/BP+MS+Ref (H)/matching_lvl1_best.cpt" "data/params/stereo/BP+MS+Ref (H)/matching_lvl2_best.cpt" --checkpoint-affinity "data/params/stereo/BP+MS+Ref (H)/affinity_best.cpt" --checkpoint-crf "data/params/stereo/BP+MS+Ref (H)/crf0_lvl0_best.cpt" --checkpoint-crf "data/params/stereo/BP+MS+Ref (H)/crf0_lvl1_best.cpt" --checkpoint-crf "data/params/stereo/BP+MS+Ref (H)/crf0_lvl2_best.cpt" --checkpoint-refinement "data/params/stereo/BP+MS+Ref (H)/refinement_best.cpt" --with-bn --bp-inference sub-exp --output-level-offset 0 --multi-level-output --im0 $IM0 --im1 $IM1 -------------------------------------------------------------------------------- /semantic_segmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import os.path as osp 5 | from networks import SubNetwork 6 | from networks import PWNet, EdgeNet 7 | 8 | from ops.lbp_semantic_pw.bp_op_cuda import BP as BP_PW 9 | from ops.lbp_semantic_pw_pixel.bp_op_cuda import BP as BP_PW_PIXEL 10 | 11 | from dependencies.ESPNet.test.Model import ESPNet 12 | 13 | import os 14 | 15 | from corenet import TemperatureSoftmax 16 | 17 | from time import time 18 | 19 | class SemanticNet(SubNetwork): 20 | def __init__(self, device, args): 21 | super(SemanticNet, self).__init__(args, device) 22 | 23 | self.num_labels = args.num_labels 24 | 25 | self.forward_timings = [] 26 | 27 | self._esp_net = ESPNet(self.num_labels, 2, 8, None) 28 | 29 | self._softmax_input = TemperatureSoftmax(dim=3, init_temp=1.0) 30 | 31 | if (self.args.pairwise_type == "global") and args.with_edges: 32 | self._edge_net = EdgeNet(device, args) 33 | 34 | if self.args.checkpoint_esp_net: 35 | esp_weight_file = os.path.join(self.args.checkpoint_esp_net, "decoder/espnet_p_" + str(2) + "_q_" + str(8) + ".pth") 36 | self._esp_net.load_state_dict(torch.load(esp_weight_file)) 37 | 38 | max_iter = 1 39 | 40 | if args.pairwise_type == "global": 41 | self._crf = [BP_PW(device, args, max_iter, self.num_labels, self.num_labels, 42 | mode_inference = 'wta', 43 | mode_message_passing='min-sum', layer_idx=0)] 44 | 45 | self.pw_mat_init = torch.zeros((1, 2, self.num_labels, self.num_labels), dtype=torch.float).cuda() 46 | self._pw_mat = nn.Parameter(self.pw_mat_init, requires_grad=True) 47 | self._pw_net = None 48 | 49 | elif args.pairwise_type == "pixel": 50 | self._crf = [BP_PW_PIXEL(device, args, max_iter, self.num_labels, self.num_labels, 51 | mode_inference = 'wta', 52 | mode_message_passing='min-sum', layer_idx=0)] 53 | 54 | self._pw_net = PWNet(device, args) 55 | self._pw_mat = None 56 | 57 | if args.checkpoint_semantic is not None: 58 | print("Loading semantic checkpoint!") 59 | self.load_state_dict(torch.load(args.checkpoint_semantic)) 60 | 61 | if not self.args.with_esp: 62 | self._esp_net.eval() 63 | 64 | self._esp_net.to(device) 65 | 66 | def forward(self, ipt): 67 | 68 | t0_fwd = time() 69 | 70 | res_esp = self._esp_net.forward(ipt) 71 | 72 | res_esp = res_esp.permute((0, 2, 3, 1)) 73 | res_esp = res_esp.contiguous() 74 | 75 | res_esp = self._softmax_input(res_esp) 76 | 77 | esp_only_res = res_esp[0].max(2)[1].unsqueeze(0).unsqueeze(0) 78 | 79 | N, H, W, C = res_esp.shape 80 | 81 | if (self.args.pairwise_type == "global") and self.args.with_edges: 82 | weights = self.extract_edges(ipt)[0] 83 | else: 84 | weights = torch.ones((N, 2, H, W), dtype=torch.float).cuda() 85 | 86 | res, beliefs = self.optimize_crf(ipt, res_esp, weights, None, None) 87 | 88 | torch.cuda.synchronize() 89 | self.forward_timings.append(time() - t0_fwd) 90 | 91 | torch.cuda.empty_cache() 92 | 93 | return res, beliefs, esp_only_res 94 | 95 | def esp_net(self): 96 | return self._esp_net 97 | 98 | def esp_net_params(self, requires_grad=None): 99 | if self._esp_net: 100 | return self._esp_net.parameter_list(requires_grad) 101 | return [] 102 | 103 | def sem_params(self, requires_grad=None): 104 | return self.parameter_list(requires_grad) 105 | 106 | def pw_mat(self): 107 | return self._pw_mat 108 | 109 | def edge_net(self): 110 | return self._edge_net 111 | 112 | def edge_net_params(self, requires_grad=None): 113 | if self._edge_net: 114 | return self._edge_net.parameter_list(requires_grad) 115 | return [] 116 | 117 | def extract_edges(self, ipt): 118 | if self.edge_net: 119 | return self._edge_net.forward(ipt) 120 | return None 121 | 122 | def optimize_crf(self, ipt, prob_vol, weights, affinities, offsets): 123 | # iterate over all bp "layers" 124 | for idx, crf in enumerate(self._crf): 125 | 126 | prob_vol = prob_vol.contiguous() 127 | 128 | weights_input = crf.adjust_input_weights(weights, idx) 129 | 130 | if self._pw_net is not None: 131 | 132 | pw_net_jump = self._pw_net.forward(ipt)[0] 133 | 134 | _, _, H, W = pw_net_jump.shape 135 | 136 | pw_net_jump = pw_net_jump.view((2, self.num_labels, self.num_labels, H, W)) 137 | pw_net_jump = pw_net_jump.permute((0, 3, 4, 1, 2)).contiguous() 138 | 139 | disps, prob_vol = crf.forward(prob_vol, weights_input, pw_net_jump) 140 | elif self._pw_mat is not None: 141 | disps, prob_vol = crf.forward(prob_vol, weights_input, self._pw_mat) 142 | 143 | return disps, prob_vol 144 | -------------------------------------------------------------------------------- /stereo.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from networks import FeatureNet, AffinityNet, RefinementNet 7 | from matching import StereoMatchingSad 8 | from ops.lbp_stereo.bp_op_cuda import BP 9 | from ops.lbp_stereo.inference_op import Inference 10 | 11 | from corenet import LrCheck, LrDistance, CvConfidence 12 | from corenet import PadUnpad, Pad, Unpad 13 | 14 | import numpy as np 15 | 16 | class StereoMethod(nn.Module): 17 | def __init__(self, device, args): 18 | nn.Module.__init__(self) 19 | self.args = args 20 | 21 | self._feature_net = None 22 | self._matching = [] 23 | self._affinity_net = None 24 | self._refinement_net = None 25 | self._crf = [] # list of bp layers 26 | 27 | max_dist = 3.0 28 | self.lr_check = LrCheck(device, max_dist).to(device) 29 | self.cv_conf = CvConfidence(device).to(device) 30 | self.lr_dist = LrDistance(device).to(device) 31 | 32 | self._min_disp = args.min_disp 33 | self._max_disp = args.max_disp 34 | 35 | self.pad = None 36 | self.unpad = None 37 | 38 | self.device = device 39 | 40 | self.logger = logging.getLogger("StereoMethod") 41 | 42 | 43 | def forward(self, I0_pyramid, I1_pyramid, offsets_orig=None, edges_orig=None, beliefs_in=None, 44 | min_disp=None, max_disp=None, step=None): 45 | 46 | # necessary for evaluation 47 | if self.pad is None: 48 | self.pad = Pad(self.feature_net.net.divisor, self.args.pad) 49 | if self.unpad is None: 50 | self.unpad = Unpad() 51 | 52 | res_dict = {'disps0': None} 53 | 54 | I0_in = I0_pyramid[self.args.input_level_offset].to(self.device) 55 | I1_in = I1_pyramid[self.args.input_level_offset].to(self.device) 56 | 57 | # pad input for multi-scale (for evaluation) 58 | I0_in = self.pad.forward(I0_in).cuda() 59 | I1_in = self.pad.forward(I1_in).cuda() 60 | 61 | f0_pyramid = self.extract_features(I0_in) 62 | f1_pyramid = self.extract_features(I1_in) 63 | 64 | if max_disp is not None: 65 | for matching_lvl, m in enumerate(self.matching): 66 | m.max_disp = ((max_disp + 1) // 2**matching_lvl) - 1 67 | if step is not None: 68 | for matching_lvl, m in enumerate(self.matching): 69 | m.step = step 70 | 71 | # multi-scale-matching 72 | prob_vol_pyramid = self.match(f0_pyramid, f1_pyramid) 73 | 74 | udisp0_pyramid = [] 75 | for pv0 in prob_vol_pyramid: 76 | udisp0_pyramid.append(torch.argmax(pv0, dim=-1, keepdim=True).permute(0, 3, 1, 2)) 77 | res_dict['disps0'] = udisp0_pyramid 78 | 79 | if self.args.model == 'wta': 80 | return res_dict 81 | 82 | affinity_pyramid = None 83 | if self.affinity_net: 84 | affinity_pyramid = self.extract_affinities(I0_in) 85 | for lvl in range(len(affinity_pyramid)): 86 | _, _, h, w = affinity_pyramid[lvl].shape 87 | affinity_pyramid[lvl] = affinity_pyramid[lvl].view((2, 5, h, w)) 88 | affinity_pyramid[lvl] = affinity_pyramid[lvl].unsqueeze(0) 89 | 90 | output_disps_pyramid = [] 91 | beliefs_pyramid = None 92 | crf_disps_pyramid = [] 93 | beliefs_pyramid = [] 94 | beliefs_in = None 95 | for lvl in reversed(range(len(prob_vol_pyramid))): 96 | pv_lvl = prob_vol_pyramid[lvl] 97 | m = self.matching[lvl] 98 | 99 | affinity = None 100 | if affinity_pyramid is not None: 101 | affinity = affinity_pyramid[lvl] 102 | crf = self.crf[lvl] 103 | 104 | # add probably an if condition whether do add multi-scale to crf 105 | if beliefs_in is not None: 106 | beliefs_in = F.interpolate(beliefs_in.unsqueeze(1), scale_factor=2.0, mode='trilinear')[:, 0] 107 | 108 | if beliefs_in.requires_grad: 109 | # print('requires grad') 110 | pv_lvl = pv_lvl / pv_lvl.sum(dim=-1, keepdim=True) 111 | else: 112 | # print('no grad-> inplace') 113 | pv_lvl += beliefs_in / 2.0 # in-place saves memory 114 | del beliefs_in 115 | 116 | torch.cuda.empty_cache() 117 | disps_lvl, beliefs_lvl, affinities_lvl, _ = self.optimize_crf(crf, pv_lvl, None, affinity) 118 | del affinities_lvl 119 | 120 | if lvl == 0: 121 | beliefs_lvl = self.unpad(beliefs_lvl, self.pad.l, self.pad.r, self.pad.t, 122 | self.pad.b, NCHW=False) 123 | disps_lvl = self.unpad(disps_lvl, self.pad.l, self.pad.r, self.pad.t, 124 | self.pad.b) 125 | 126 | beliefs_pyramid.append(beliefs_lvl) 127 | crf_disps_pyramid.append(disps_lvl + m.min_disp) 128 | beliefs_in = beliefs_pyramid[-1] 129 | 130 | # beliefs are from low res to high res 131 | beliefs_pyramid.reverse() 132 | crf_disps_pyramid.reverse() 133 | res_dict['disps0'] = crf_disps_pyramid 134 | 135 | if self.refinement_net: 136 | # crf 137 | cv_conf = self.cv_conf.forward(beliefs_pyramid[0].permute(0, 3, 1, 2), 138 | crf_disps_pyramid[0]) 139 | 140 | conf_all = cv_conf 141 | refined_disps_pyramid, refinement_steps = self.refine_disps(I0_pyramid, 142 | crf_disps_pyramid[0], 143 | confidence=conf_all, 144 | I1=I1_pyramid) 145 | if refinement_steps is not None: 146 | refinement_steps.reverse() 147 | refined_disps_pyramid.reverse() 148 | output_disps_pyramid = refined_disps_pyramid 149 | 150 | res_dict['disps0'] = output_disps_pyramid 151 | 152 | return res_dict 153 | 154 | def extract_features(self, ipt): 155 | if self.feature_net: 156 | return self.feature_net.forward(ipt) 157 | return None 158 | 159 | def extract_affinities(self, ipt): 160 | if self.affinity_net: 161 | return self.affinity_net.forward(ipt) 162 | return None 163 | 164 | def match(self, f0, f1, lr=False): 165 | prob_vols = [] 166 | if self.matching: 167 | for matching, f0s, f1s in zip(self.matching, f0, f1): 168 | if lr: 169 | f0s = torch.flip(f0s, dims=(3,)).contiguous() 170 | f1s = torch.flip(f1s, dims=(3,)).contiguous() 171 | prob_vol_s = matching.forward(f1s, f0s) 172 | prob_vol_s = torch.flip(prob_vol_s, dims=(2,)) 173 | else: 174 | prob_vol_s = matching.forward(f0s, f1s) 175 | 176 | prob_vols.append(prob_vol_s) 177 | return prob_vols 178 | return None 179 | 180 | def optimize_crf(self, crf_layer, prob_vol, weights, affinities): 181 | if crf_layer: 182 | offsets = None 183 | # iterate over all bp "layers" 184 | for idx, crf in enumerate(crf_layer): 185 | prob_vol = prob_vol.contiguous() 186 | weights_input = crf.adjust_input_weights(weights, idx) 187 | affinities_shift = crf.adjust_input_affinities(affinities) 188 | offsets_shift = crf.adjust_input_offsets(offsets) 189 | 190 | if not prob_vol.requires_grad: 191 | torch.cuda.empty_cache() 192 | 193 | disps, prob_vol, messages = crf.forward(prob_vol, weights_input, affinities_shift, offsets_shift) 194 | del messages # never used again 195 | 196 | return disps, prob_vol, affinities_shift, offsets_shift 197 | return None 198 | 199 | def refine_disps(self, I0, d0, confidence=None, I1=None): 200 | if self.refinement_net: 201 | refined, steps = self.refinement_net.forward(I0, d0, confidence, I1) 202 | return refined, steps 203 | return None 204 | 205 | def feature_net_params(self, requires_grad=None): 206 | if self.feature_net: 207 | return self.feature_net.parameter_list(requires_grad) 208 | return [] 209 | 210 | def matching_params(self, requires_grad=None): 211 | params = [] 212 | if self.matching: 213 | for m in self.matching: 214 | params += m.parameter_list(requires_grad) 215 | return params 216 | 217 | def affinity_net_params(self, requires_grad=None): 218 | if self.affinity_net: 219 | return self.affinity_net.parameter_list(requires_grad) 220 | return [] 221 | 222 | def crf_params(self, requires_grad=None): 223 | crf_params = [] 224 | if self.crf: 225 | for crf_layer in self.crf: 226 | for crf in crf_layer: 227 | crf_params += crf.parameter_list(requires_grad) 228 | return crf_params 229 | 230 | def refinement_net_params(self, requires_grad=None): 231 | if self.refinement_net: 232 | return self.refinement_net.parameter_list(requires_grad) 233 | return [] 234 | 235 | @property 236 | def feature_net(self): 237 | return self._feature_net 238 | 239 | @property 240 | def affinity_net(self): 241 | return self._affinity_net 242 | 243 | @property 244 | def crf(self): 245 | if self._crf == []: 246 | return None 247 | return self._crf 248 | 249 | @property 250 | def refinement_net(self): 251 | return self._refinement_net 252 | 253 | @property 254 | def matching(self): 255 | return self._matching 256 | 257 | @property 258 | def min_disp(self): 259 | return self._min_disp 260 | 261 | @property 262 | def max_disp(self): 263 | return self._max_disp 264 | 265 | @property 266 | def gc_net(self): 267 | return self._gc_net 268 | 269 | 270 | #################################################################################################### 271 | # Block Match 272 | #################################################################################################### 273 | class BlockMatchStereo(StereoMethod): 274 | def __init__(self, device, args): 275 | StereoMethod.__init__(self, device, args) 276 | self._feature_net = FeatureNet(device, args) 277 | 278 | self._matching = [] 279 | for matching_lvl in range(self._feature_net.net.num_output_levels): 280 | if self.args.lbp_min_disp: 281 | min_disp = self.min_disp # original 282 | else: 283 | min_disp = self.min_disp // 2**matching_lvl 284 | max_disp = ((self.max_disp + 1) // 2**matching_lvl) - 1 285 | self.logger.info("Construct Matching Level %d with min-disp=%d and max-disp=%d" %(matching_lvl, min_disp, max_disp)) 286 | 287 | self._matching.append(StereoMatchingSad(device, args, min_disp, max_disp, 288 | lvl=matching_lvl)) 289 | 290 | 291 | #################################################################################################### 292 | # Min-Sum LBP 293 | #################################################################################################### 294 | class MinSumStereo(BlockMatchStereo): 295 | def __init__(self, device, args): 296 | BlockMatchStereo.__init__(self, device, args) 297 | 298 | self.max_iter = args.max_iter 299 | num_labels = self.max_disp - self.min_disp + 1 300 | 301 | self._affinity_net = AffinityNet(device, args) 302 | 303 | for lvl in range(self._feature_net.net.num_output_levels): 304 | self._crf.append([BP(device, args, self.max_iter, num_labels, 3, 305 | mode_inference = args.bp_inference, 306 | mode_message_passing='min-sum', layer_idx=idx, level=lvl) 307 | for idx in range(args.num_bp_layers)]) 308 | 309 | 310 | class RefinedMinSumStereo(MinSumStereo): 311 | def __init__(self, device, args): 312 | super(RefinedMinSumStereo, self).__init__(device, args) 313 | 314 | self._refinement_net = RefinementNet(device, args) --------------------------------------------------------------------------------