├── .gitignore
├── README.md
├── back2future.py
├── convert.lua
├── correlation_package
│   ├── .make.sh.swp
│   ├── __init__.py
│   ├── build.py
│   ├── functions
│   │   ├── __init__.py
│   │   └── correlation.py
│   ├── make.sh
│   ├── modules
│   │   ├── __init__.py
│   │   └── correlation.py
│   └── src
│       ├── correlation.c
│       ├── correlation.h
│       ├── correlation_cuda.c
│       ├── correlation_cuda.h
│       ├── correlation_cuda_kernel.cu
│       └── correlation_cuda_kernel.h
├── demo.py
├── flow_io.py
├── pretrained
│   ├── b2f_kitti.pth.tar
│   └── b2f_sintel.pth.tar
├── samples
│   ├── 000010_09.png
│   ├── 000010_10.png
│   └── 000010_11.png
└── test_back2future.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.o
3 | correlation_package/_ext/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This is a PyTorch implementation of
2 | 
3 | Janai, J., Güney, F., Ranjan, A., Black, M. and Geiger, A., **Unsupervised Learning of Multi-Frame Optical Flow with Occlusions.** ECCV 2018.
4 | 
5 | [[Link to Paper](http://www.cvlibs.net/publications/Janai2018ECCV.pdf)] [[Project Page](https://avg.is.tuebingen.mpg.de/research_projects/back2future)] [[Original Torch Code](https://github.com/jjanai/back2future)]
6 | 
7 | ## Requirements
8 | - Runs on and is tested with [PyTorch 0.3.1](https://pytorch.org/get-started/previous-versions/); it should be compatible with higher versions with little or no modification.
9 | - The correlation package is taken from [NVIDIA/flownet2-pytorch](https://github.com/NVIDIA/flownet2-pytorch/) and can be installed using
10 | ```bash
11 | cd correlation_package
12 | bash make.sh
13 | ```
14 | If you are using PyTorch > 0.3.1, you can use the correlation layer from [here](https://github.com/ClementPinard/Pytorch-Correlation-extension) instead (a minimal adapter is sketched after this README).
15 | ## Usage
16 | To use the model, run the following in your favorite Python environment:
17 | ```python
18 | from back2future import Model
19 | model = Model(pretrained='pretrained/path_to_your_favorite_model')
20 | ```
21 | There are two pretrained models in `pretrained/`, fine-tuned on Sintel and KITTI in an unsupervised way.
22 | 
23 | Refer to `demo.py` for more.
24 | 
25 | ## Testing
26 | To test performance on KITTI, use
27 | ```bash
28 | python3 test_back2future.py --pretrained-flow path/to/pretrained/model --kitti-dir path/to/kitti/2015/root
29 | ```
30 | 
31 | ## Training
32 | Please use the [original Torch code](https://github.com/jjanai/back2future) for training new models.
33 | 
34 | ## License
35 | This is a reimplementation. The license for the original work can be found at [JJanai/back2future](https://github.com/JJanai/back2future/blob/master/LICENSE).
36 | 
37 | ## While using this code, please cite
38 | ```
39 | @inproceedings{Janai2018ECCV,
40 |   title = {Unsupervised Learning of Multi-Frame Optical Flow with Occlusions},
41 |   author = {Janai, Joel and G{\"u}ney, Fatma and Ranjan, Anurag and Black, Michael J. and Geiger, Andreas},
42 |   booktitle = {European Conference on Computer Vision (ECCV)},
43 |   volume = {Lecture Notes in Computer Science, vol 11220},
44 |   pages = {713--731},
45 |   publisher = {Springer, Cham},
46 |   month = sep,
47 |   year = {2018},
48 |   month_numeric = {9}
49 | }
50 | ```
51 | 
--------------------------------------------------------------------------------
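As noted in the Requirements section of the README above, on PyTorch > 0.3.1 the bundled FFI correlation package can be swapped for [Pytorch-Correlation-extension](https://github.com/ClementPinard/Pytorch-Correlation-extension). Below is a minimal adapter sketch; `SpatialCorrelationSampler` and its keyword arguments are that external package's API, not part of this repository, and the channel ordering and normalization should be verified against the bundled kernel before use:

```python
import torch.nn as nn
from spatial_correlation_sampler import SpatialCorrelationSampler

class Correlation(nn.Module):
    """Stand-in for correlation_package.modules.correlation.Correlation."""
    def __init__(self, pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1, corr_multiply=1):
        super(Correlation, self).__init__()
        # max_displacement=4 with stride2=1 (as in back2future.py) -> 9x9 window, 81 channels
        self.sampler = SpatialCorrelationSampler(
            kernel_size=kernel_size,
            patch_size=2 * max_displacement + 1,
            stride=stride1,
            padding=0,
            dilation_patch=stride2)

    def forward(self, input1, input2):
        b, c, h, w = input1.size()
        corr = self.sampler(input1, input2)   # (B, 9, 9, H, W)
        # flatten the window dims; divide by C to mimic the FFI kernel's averaging
        return corr.view(b, -1, h, w) / c
```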
/back2future.py:
--------------------------------------------------------------------------------
1 | # Author: Anurag Ranjan
2 | # Copyright (c) 2018, Max Planck Society
3 | 
4 | import os
5 | import numpy as np
6 | 
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.utils.serialization import load_lua
11 | from torch.autograd import Variable
12 | from correlation_package.modules.correlation import Correlation
13 | 
14 | def conv_feat_block(nIn, nOut):
15 |     return nn.Sequential(
16 |         nn.Conv2d(nIn, nOut, kernel_size=3, stride=2, padding=1),
17 |         nn.LeakyReLU(0.2),
18 |         nn.Conv2d(nOut, nOut, kernel_size=3, stride=1, padding=1),
19 |         nn.LeakyReLU(0.2)
20 |     )
21 | 
22 | def conv_dec_block(nIn):
23 |     return nn.Sequential(
24 |         nn.Conv2d(nIn, 128, kernel_size=3, stride=1, padding=1),
25 |         nn.LeakyReLU(0.2),
26 |         nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
27 |         nn.LeakyReLU(0.2),
28 |         nn.Conv2d(128, 96, kernel_size=3, stride=1, padding=1),
29 |         nn.LeakyReLU(0.2),
30 |         nn.Conv2d(96, 64, kernel_size=3, stride=1, padding=1),
31 |         nn.LeakyReLU(0.2),
32 |         nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
33 |         nn.LeakyReLU(0.2),
34 |         nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1)
35 |     )
36 | 
37 | 
38 | class Model(nn.Module):
39 |     def __init__(self, pretrained=None):
40 |         super(Model, self).__init__()
41 | 
42 |         idx = [list(range(n, -1, -9)) for n in range(80,71,-1)]  # index permutation over the 81 channels of the 9x9 correlation window
43 |         idx = list(np.array(idx).flatten())
44 |         self.idx_fwd = Variable(torch.LongTensor(np.array(idx)).cuda(), requires_grad=False)
45 |         self.idx_bwd = Variable(torch.LongTensor(np.array(list(reversed(idx)))).cuda(), requires_grad=False)  # same permutation, reversed, for the backward direction
46 |         self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
47 |         self.softmax2d = nn.Softmax2d()
48 | 
49 |         self.conv1a = conv_feat_block(3,16)
50 |         self.conv1b = conv_feat_block(3,16)
51 |         self.conv1c = conv_feat_block(3,16)
52 | 
53 |         self.conv2a = conv_feat_block(16,32)
54 |         self.conv2b = conv_feat_block(16,32)
55 |         self.conv2c = conv_feat_block(16,32)
56 | 
57 |         self.conv3a = conv_feat_block(32,64)
58 |         self.conv3b = conv_feat_block(32,64)
59 |         self.conv3c = conv_feat_block(32,64)
60 | 
61 |         self.conv4a = conv_feat_block(64,96)
62 |         self.conv4b = conv_feat_block(64,96)
63 |         self.conv4c = conv_feat_block(64,96)
64 | 
65 |         self.conv5a = conv_feat_block(96,128)
66 |         self.conv5b = conv_feat_block(96,128)
67 |         self.conv5c = conv_feat_block(96,128)
68 | 
69 |         self.conv6a = conv_feat_block(128,192)
70 |         self.conv6b = conv_feat_block(128,192)
71 |         self.conv6c = conv_feat_block(128,192)
72 | 
73 |         self.corr = Correlation(pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1, corr_multiply=1)
74 | 
75 |         self.decoder_fwd6 = conv_dec_block(162)  # 2 x 81 correlation channels
76 |         self.decoder_bwd6 = conv_dec_block(162)
77 |         self.decoder_fwd5 = conv_dec_block(292)  # 162 corr + 128 feature + 2 flow channels
78 |         self.decoder_bwd5 = conv_dec_block(292)
79 |         self.decoder_fwd4 = conv_dec_block(260)
80 |         self.decoder_bwd4 = conv_dec_block(260)
81 |         self.decoder_fwd3 = conv_dec_block(228)
82 |         self.decoder_bwd3 = conv_dec_block(228)
83 |         self.decoder_fwd2 = conv_dec_block(196)
84 |         self.decoder_bwd2 = conv_dec_block(196)
85 | 
86 |         self.decoder_occ6 = conv_dec_block(354)  # 162 corr + 192 feature channels
87 |         self.decoder_occ5 = conv_dec_block(292)
88 |         self.decoder_occ4 = conv_dec_block(260)
89 |         self.decoder_occ3 = conv_dec_block(228)
90 |         self.decoder_occ2 = conv_dec_block(196)
91 | 
92 |         if pretrained is not None:
93 |             self.load_state_dict(torch.load(pretrained))
94 |             print('Model loaded from ', pretrained)
95 | 
96 |     def load(self, key, load_id):
97 |         module = getattr(self, key)
98 |         loadpath = os.path.join(LOAD_DIR, load_id)  # LOAD_DIR: directory of the .t7 weight files written by convert.lua
99 |         for i,m in enumerate(module):
100 |             if type(m)==nn.Conv2d:
101 |                 weight_path = loadpath + '_' + str(i+1) + 'weight.t7'
102 |                 bias_path = loadpath + '_' + str(i+1) + 'bias.t7'
103 |                 m.weight.data.copy_(load_lua(weight_path))
104 |                 m.bias.data.copy_(load_lua(bias_path))
105 | 
106 |     def initialize(self):
107 |         # This is used for the first load after saving weight files from lua
108 |         map = {'conv1a': '3', 'conv1b': '6', 'conv1c': '22',
109 |                'conv2a': '4', 'conv2b': '7', 'conv2c': '23',
110 |                'conv3a': '9', 'conv3b': '10', 'conv3c': '24',
111 |                'conv4a': '12', 'conv4b': '13', 'conv4c': '25',
112 |                'conv5a': '15', 'conv5b': '16', 'conv5c': '26',
113 |                'conv6a': '18', 'conv6b': '19', 'conv6c': '27',
114 |                'decoder_fwd6': '30', 'decoder_bwd6': '93', 'decoder_occ6': '183',
115 |                'decoder_fwd5': '45', 'decoder_bwd5': '96', 'decoder_occ5': '164',
116 |                'decoder_fwd4': '60', 'decoder_bwd4': '99', 'decoder_occ4': '145',
117 |                'decoder_fwd3': '75', 'decoder_bwd3': '102', 'decoder_occ3': '126',
118 |                'decoder_fwd2': '90', 'decoder_bwd2': '105', 'decoder_occ2': '109'
119 |                }
120 | 
121 |         for key in map.keys():
122 |             print("Loading ", key)
123 |             self.load(key, map[key])
124 | 
125 |     def normalize(self, ims):
126 |         imt = []
127 |         for im in ims:
128 |             im[:,0,:,:] = im[:,0,:,:] - 0.485 # Red
129 |             im[:,1,:,:] = im[:,1,:,:] - 0.456 # Green
130 |             im[:,2,:,:] = im[:,2,:,:] - 0.406 # Blue
131 | 
132 |             im[:,0,:,:] = im[:,0,:,:] / 0.229 # Red
133 |             im[:,1,:,:] = im[:,1,:,:] / 0.224 # Green
134 |             im[:,2,:,:] = im[:,2,:,:] / 0.225 # Blue
135 | 
136 |             imt.append(im)
137 |         return imt
138 | 
139 |     def forward(self, im_tar, im_refs):
140 |         '''
141 |         Arguments:
142 |             im_tar : Centre Frame
143 |             im_refs : List containing [Past_Frame, Future_Frame]
144 |         '''
145 |         im_norm = self.normalize([im_tar] + im_refs)
146 | 
147 |         feat1a = self.conv1a(im_norm[0])
148 |         feat2a = self.conv2a(feat1a)
149 |         feat3a = self.conv3a(feat2a)
150 |         feat4a = self.conv4a(feat3a)
151 |         feat5a = self.conv5a(feat4a)
152 |         feat6a = self.conv6a(feat5a)
153 | 
154 |         feat1b = self.conv1b(im_norm[2])
155 |         feat2b = self.conv2b(feat1b)
156 |         feat3b = self.conv3b(feat2b)
157 |         feat4b = self.conv4b(feat3b)
158 |         feat5b = self.conv5b(feat4b)
159 |         feat6b = self.conv6b(feat5b)
160 | 
161 |         feat1c = self.conv1c(im_norm[1])
162 |         feat2c = self.conv2c(feat1c)
163 |         feat3c = self.conv3c(feat2c)
164 |         feat4c = self.conv4c(feat3c)
165 |         feat5c = self.conv5c(feat4c)
166 |         feat6c = self.conv6c(feat5c)
167 | 
168 |         corr6_fwd = self.corr(feat6a, feat6b)
169 |         corr6_fwd = corr6_fwd.index_select(1,self.idx_fwd)
170 |         corr6_bwd = self.corr(feat6a, feat6c)
171 |         corr6_bwd = corr6_bwd.index_select(1,self.idx_bwd)
172 |         corr6 = torch.cat((corr6_fwd, corr6_bwd), 1)
173 | 
174 |         flow6_fwd = self.decoder_fwd6(corr6)
175 |         flow6_fwd_up = self.upsample(flow6_fwd)
176 |         flow6_bwd = self.decoder_bwd6(corr6)
177 |         flow6_bwd_up = self.upsample(flow6_bwd)
178 |         feat5b_warped = self.warp(feat5b, 0.625*flow6_fwd_up)  # 0.625 = 20/32: flow units to pixels at 1/32 resolution
179 |         feat5c_warped = self.warp(feat5c, -0.625*flow6_bwd_up)
180 | 
181 |         occ6_feat = torch.cat((corr6, feat6a), 1)
182 |         occ6 = 
self.softmax2d(self.decoder_occ6(occ6_feat)) 183 | 184 | corr5_fwd = self.corr(feat5a, feat5b_warped) 185 | corr5_fwd = corr5_fwd.index_select(1,self.idx_fwd) 186 | corr5_bwd = self.corr(feat5a, feat5c_warped) 187 | corr5_bwd = corr5_bwd.index_select(1,self.idx_bwd) 188 | corr5 = torch.cat((corr5_fwd, corr5_bwd), 1) 189 | 190 | upfeat5_fwd = torch.cat((corr5, feat5a, flow6_fwd_up), 1) 191 | flow5_fwd = self.decoder_fwd5(upfeat5_fwd) 192 | flow5_fwd_up = self.upsample(flow5_fwd) 193 | upfeat5_bwd = torch.cat((corr5, feat5a, flow6_bwd_up),1) 194 | flow5_bwd = self.decoder_bwd5(upfeat5_bwd) 195 | flow5_bwd_up = self.upsample(flow5_bwd) 196 | feat4b_warped = self.warp(feat4b, 1.25*flow5_fwd_up) 197 | feat4c_warped = self.warp(feat4c, -1.25*flow5_bwd_up) 198 | 199 | occ5 = self.softmax2d(self.decoder_occ5(upfeat5_fwd)) 200 | 201 | corr4_fwd = self.corr(feat4a, feat4b_warped) 202 | corr4_fwd = corr4_fwd.index_select(1,self.idx_fwd) 203 | corr4_bwd = self.corr(feat4a, feat4c_warped) 204 | corr4_bwd = corr4_bwd.index_select(1,self.idx_bwd) 205 | corr4 = torch.cat((corr4_fwd, corr4_bwd), 1) 206 | 207 | upfeat4_fwd = torch.cat((corr4, feat4a, flow5_fwd_up), 1) 208 | flow4_fwd = self.decoder_fwd4(upfeat4_fwd) 209 | flow4_fwd_up = self.upsample(flow4_fwd) 210 | upfeat4_bwd = torch.cat((corr4, feat4a, flow5_bwd_up),1) 211 | flow4_bwd = self.decoder_bwd4(upfeat4_bwd) 212 | flow4_bwd_up = self.upsample(flow4_bwd) 213 | feat3b_warped = self.warp(feat3b, 2.5*flow4_fwd_up) 214 | feat3c_warped = self.warp(feat3c, -2.5*flow4_bwd_up) 215 | 216 | occ4 = self.softmax2d(self.decoder_occ4(upfeat4_fwd)) 217 | 218 | corr3_fwd = self.corr(feat3a, feat3b_warped) 219 | corr3_fwd = corr3_fwd.index_select(1,self.idx_fwd) 220 | corr3_bwd = self.corr(feat3a, feat3c_warped) 221 | corr3_bwd = corr3_bwd.index_select(1,self.idx_bwd) 222 | corr3 = torch.cat((corr3_fwd, corr3_bwd), 1) 223 | 224 | upfeat3_fwd = torch.cat((corr3, feat3a, flow4_fwd_up), 1) 225 | flow3_fwd = self.decoder_fwd3(upfeat3_fwd) 226 | flow3_fwd_up = self.upsample(flow3_fwd) 227 | upfeat3_bwd = torch.cat((corr3, feat3a, flow4_bwd_up),1) 228 | flow3_bwd = self.decoder_bwd3(upfeat3_bwd) 229 | flow3_bwd_up = self.upsample(flow3_bwd) 230 | feat2b_warped = self.warp(feat2b, 5.0*flow3_fwd_up) 231 | feat2c_warped = self.warp(feat2c, -5.0*flow3_bwd_up) 232 | 233 | occ3 = self.softmax2d(self.decoder_occ3(upfeat3_fwd)) 234 | 235 | corr2_fwd = self.corr(feat2a, feat2b_warped) 236 | corr2_fwd = corr2_fwd.index_select(1,self.idx_fwd) 237 | corr2_bwd = self.corr(feat2a, feat2c_warped) 238 | corr2_bwd = corr2_bwd.index_select(1,self.idx_bwd) 239 | corr2 = torch.cat((corr2_fwd, corr2_bwd), 1) 240 | 241 | upfeat2_fwd = torch.cat((corr2, feat2a, flow3_fwd_up), 1) 242 | flow2_fwd = self.decoder_fwd2(upfeat2_fwd) 243 | flow2_fwd_up = self.upsample(flow2_fwd) 244 | upfeat2_bwd = torch.cat((corr2, feat2a, flow3_bwd_up),1) 245 | flow2_bwd = self.decoder_bwd2(upfeat2_bwd) 246 | flow2_bwd_up = self.upsample(flow2_bwd) 247 | 248 | occ2 = self.softmax2d(self.decoder_occ2(upfeat2_fwd)) 249 | 250 | flow2_fwd_fullres = 20*self.upsample(flow2_fwd_up) 251 | flow3_fwd_fullres = 10*self.upsample(flow3_fwd_up) 252 | flow4_fwd_fullres = 5*self.upsample(flow4_fwd_up) 253 | flow5_fwd_fullres = 2.5*self.upsample(flow5_fwd_up) 254 | flow6_fwd_fullres = 1.25*self.upsample(flow6_fwd_up) 255 | 256 | flow2_bwd_fullres = -20*self.upsample(flow2_bwd_up) 257 | flow3_bwd_fullres = -10*self.upsample(flow3_bwd_up) 258 | flow4_bwd_fullres = -5*self.upsample(flow4_bwd_up) 259 | flow5_bwd_fullres = 
-2.5*self.upsample(flow5_bwd_up)
260 |         flow6_bwd_fullres = -1.25*self.upsample(flow6_bwd_up)
261 | 
262 |         occ2_fullres = F.upsample(occ2, scale_factor=4)
263 |         occ3_fullres = F.upsample(occ3, scale_factor=4)
264 |         occ4_fullres = F.upsample(occ4, scale_factor=4)
265 |         occ5_fullres = F.upsample(occ5, scale_factor=4)
266 |         occ6_fullres = F.upsample(occ6, scale_factor=4)
267 | 
268 |         flow_fwd = [flow2_fwd_fullres, flow3_fwd_fullres, flow4_fwd_fullres, flow5_fwd_fullres, flow6_fwd_fullres]
269 |         flow_bwd = [flow2_bwd_fullres, flow3_bwd_fullres, flow4_bwd_fullres, flow5_bwd_fullres, flow6_bwd_fullres]
270 |         occ = [occ2_fullres, occ3_fullres, occ4_fullres, occ5_fullres, occ6_fullres]
271 | 
272 |         return flow_fwd, flow_bwd, occ
273 | 
274 |     def warp(self, x, flo):
275 |         """
276 |         warp an image/tensor (im2) back to im1, according to the optical flow
277 |         x: [B, C, H, W] (im2)
278 |         flo: [B, 2, H, W] flow
279 |         """
280 |         B, C, H, W = x.size()
281 |         # mesh grid
282 |         xx = torch.arange(0, W).view(1,-1).repeat(H,1)
283 |         yy = torch.arange(0, H).view(-1,1).repeat(1,W)
284 |         xx = xx.view(1,1,H,W).repeat(B,1,1,1)
285 |         yy = yy.view(1,1,H,W).repeat(B,1,1,1)
286 |         grid = torch.cat((xx,yy),1).float()
287 | 
288 |         if x.is_cuda:
289 |             grid = grid.cuda()
290 |         vgrid = Variable(grid) + flo
291 | 
292 |         # scale grid to [-1,1]
293 |         vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:]/max(W-1,1)-1.0
294 |         vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:]/max(H-1,1)-1.0
295 | 
296 |         vgrid = vgrid.permute(0,2,3,1)
297 |         output = nn.functional.grid_sample(x, vgrid, padding_mode='border')
298 |         mask = torch.autograd.Variable(torch.ones(x.size()), requires_grad=False).cuda()
299 |         mask = nn.functional.grid_sample(mask, vgrid)
300 | 
301 |         # if W==128:
302 |         #     np.save('mask.npy', mask.cpu().data.numpy())
303 |         #     np.save('warp.npy', output.cpu().data.numpy())
304 | 
305 |         mask[mask.data<0.9999] = 0
306 |         mask[mask.data>0] = 1
307 | 
308 |         return output  # *mask: masking of out-of-range samples is disabled
309 | 
--------------------------------------------------------------------------------
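For orientation, a tiny sanity check of the `warp` helper above: with a constant flow of +1 pixel in x, every output pixel samples its right-hand neighbour, clamped at the image border. A minimal sketch, assuming a CUDA device and a built correlation package (needed only so that `back2future` imports):

```python
import torch
from torch.autograd import Variable
from back2future import Model

model = Model().cuda()                    # no pretrained weights needed for warp()
x = Variable(torch.arange(0, 16).view(1, 1, 4, 4).cuda())
flo = torch.zeros(1, 2, 4, 4)
flo[:, 0, :, :] = 1.0                     # +1 px horizontal flow everywhere
y = model.warp(x, Variable(flo.cuda()))   # y[..., i, j] == x[..., i, j+1], clamped at the border
```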
/convert.lua:
--------------------------------------------------------------------------------
1 | -- Author: Anurag Ranjan
2 | -- Copyright (c) 2018, Max Planck Society
3 | 
4 | require 'paths'
5 | require 'torch'
6 | require 'cutorch'
7 | require 'nn'
8 | require 'cunn'
9 | require 'cudnn'
10 | require 'stn'
11 | require 'spy'
12 | require 'nngraph'
13 | require 'models/CostVolMulti'
14 | 
15 | SAVE_DIR = 'pretrained/pwc'
16 | model = torch.load('pretrained/Roaming_KITTI_model_300.t7')
17 | 
18 | function save_sequential(name, model)
19 |     for i = 1, #model do
20 |         module = model:get(i)
21 |         if tostring(torch.type(module)) == 'nn.SpatialConvolution' then
22 |             torch.save(paths.concat(SAVE_DIR, name..'_'..tostring(i)..'weight.t7'), module.weight)
23 |             torch.save(paths.concat(SAVE_DIR, name..'_'..tostring(i)..'bias.t7'), module.bias)
24 |         end
25 |     end
26 | end
27 | 
28 | for i = 1, 200 do
29 |     print('Traversing node '..i)
30 |     node = model:get(i)
31 |     if tostring(torch.type(node)) == 'nn.Sequential' then
32 |         nodenn = cudnn.convert(node, nn)
33 |         nodenn_float = nodenn:float()
34 |         save_sequential(tostring(i), nodenn_float)
35 |     end
36 | end
37 | 
38 | function warpingUnit()
39 |     local I = nn.Identity()()
40 |     local F = nn.Identity()()
41 |     local input = I - nn.Transpose({2,3}, {3,4})
42 |     local flow = F - nn.Transpose({2,3}, {3,4})
43 |     local W = {input, flow} - nn.BilinearSamplerBHWD() - nn.Transpose({3,4}, {2,3})
44 |     local model = nn.gModule({I, F}, {W})
45 |     return model
46 | end
47 | 
48 | for k, v in ipairs(model.forwardnodes) do
49 |     print(k-1, v.id, v.data.module)
50 |     v.data.annotations.name = tostring(k-1)
51 | end
52 | 
--------------------------------------------------------------------------------
/correlation_package/.make.sh.swp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/correlation_package/.make.sh.swp
--------------------------------------------------------------------------------
/correlation_package/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/correlation_package/__init__.py
--------------------------------------------------------------------------------
/correlation_package/build.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.ffi
4 | 
5 | this_folder = os.path.dirname(os.path.abspath(__file__)) + '/'
6 | 
7 | Headers = []
8 | Sources = []
9 | Defines = []
10 | Objects = []
11 | 
12 | if torch.cuda.is_available():
13 |     Headers += ['src/correlation_cuda.h']
14 |     Sources += ['src/correlation_cuda.c']
15 |     Defines += [('WITH_CUDA', None)]
16 |     Objects += ['src/correlation_cuda_kernel.o']
17 | 
18 | ffi = torch.utils.ffi.create_extension(
19 |     name='_ext.correlation',
20 |     headers=Headers,
21 |     sources=Sources,
22 |     verbose=False,
23 |     with_cuda=True,
24 |     package=False,
25 |     relative_to=this_folder,
26 |     define_macros=Defines,
27 |     extra_objects=[os.path.join(this_folder, Object) for Object in Objects]
28 | )
29 | 
30 | if __name__ == '__main__':
31 |     ffi.build()
--------------------------------------------------------------------------------
/correlation_package/functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/correlation_package/functions/__init__.py
--------------------------------------------------------------------------------
/correlation_package/functions/correlation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from .._ext import correlation
4 | 
5 | 
6 | class CorrelationFunction(Function):
7 | 
8 |     def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
9 |         super(CorrelationFunction, self).__init__()
10 |         self.pad_size = pad_size
11 |         self.kernel_size = kernel_size
12 |         self.max_displacement = max_displacement
13 |         self.stride1 = stride1
14 |         self.stride2 = stride2
15 |         self.corr_multiply = corr_multiply
16 |         # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1)
17 | 
18 |     def forward(self, input1, input2):
19 |         self.save_for_backward(input1, input2)
20 | 
21 |         assert(input1.is_contiguous() == True)
22 |         assert(input2.is_contiguous() == True)
23 | 
24 |         with torch.cuda.device_of(input1):
25 |             rbot1 = input1.new()
26 |             rbot2 = input2.new()
27 |             output = input1.new()
28 | 
29 |             correlation.Correlation_forward_cuda(input1, input2, rbot1, rbot2, output,
30 |                 self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)
31 | 
32 |         return output
33 | 
34 |     def backward(self, grad_output):
35 |         input1, input2 = 
self.saved_tensors 36 | 37 | assert(grad_output.is_contiguous() == True) 38 | 39 | with torch.cuda.device_of(input1): 40 | rbot1 = input1.new() 41 | rbot2 = input2.new() 42 | 43 | grad_input1 = input1.new() 44 | grad_input2 = input2.new() 45 | 46 | correlation.Correlation_backward_cuda(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2, 47 | self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply) 48 | 49 | return grad_input1, grad_input2 -------------------------------------------------------------------------------- /correlation_package/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | TORCH=$(python -c "import os; import torch; print(os.path.dirname(torch.__file__))") 3 | 4 | cd src 5 | 6 | echo "Compiling correlation kernels by nvcc..." 7 | 8 | rm correlation_cuda_kernel.o 9 | rm -r ../_ext 10 | 11 | nvcc -c -o correlation_cuda_kernel.o correlation_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52 12 | 13 | cd ../ 14 | python build.py 15 | -------------------------------------------------------------------------------- /correlation_package/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/correlation_package/modules/__init__.py -------------------------------------------------------------------------------- /correlation_package/modules/correlation.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | 3 | from ..functions.correlation import CorrelationFunction 4 | 5 | class Correlation(Module): 6 | def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1): 7 | super(Correlation, self).__init__() 8 | self.pad_size = pad_size 9 | self.kernel_size = kernel_size 10 | self.max_displacement = max_displacement 11 | self.stride1 = stride1 12 | self.stride2 = stride2 13 | self.corr_multiply = corr_multiply 14 | 15 | def forward(self, input1, input2): 16 | 17 | result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)(input1, input2) 18 | 19 | return result 20 | -------------------------------------------------------------------------------- /correlation_package/src/correlation.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | int Correlation_forward_cpu(THFloatTensor *input1, 4 | THFloatTensor *input2, 5 | THFloatTensor *rInput1, 6 | THFloatTensor *rInput2, 7 | THFloatTensor *output, 8 | int pad_size, 9 | int kernel_size, 10 | int max_displacement, 11 | int stride1, 12 | int stride2, 13 | int corr_type_multiply) 14 | { 15 | return 1; 16 | } 17 | 18 | int Correlation_backward_cpu(THFloatTensor *input1, 19 | THFloatTensor *input2, 20 | THFloatTensor *rInput1, 21 | THFloatTensor *rInput2, 22 | THFloatTensor *gradOutput, 23 | THFloatTensor *gradInput1, 24 | THFloatTensor *gradInput2, 25 | int pad_size, 26 | int kernel_size, 27 | int max_displacement, 28 | int stride1, 29 | int stride2, 30 | int corr_type_multiply) 31 | { 32 | return 1; 33 | } 34 | -------------------------------------------------------------------------------- /correlation_package/src/correlation.h: -------------------------------------------------------------------------------- 1 | int 
Correlation_forward_cpu(THFloatTensor *input1, 2 | THFloatTensor *input2, 3 | THFloatTensor *rInput1, 4 | THFloatTensor *rInput2, 5 | THFloatTensor *output, 6 | int pad_size, 7 | int kernel_size, 8 | int max_displacement, 9 | int stride1, 10 | int stride2, 11 | int corr_type_multiply); 12 | 13 | int Correlation_backward_cpu(THFloatTensor *input1, 14 | THFloatTensor *input2, 15 | THFloatTensor *rInput1, 16 | THFloatTensor *rInput2, 17 | THFloatTensor *gradOutput, 18 | THFloatTensor *gradInput1, 19 | THFloatTensor *gradInput2, 20 | int pad_size, 21 | int kernel_size, 22 | int max_displacement, 23 | int stride1, 24 | int stride2, 25 | int corr_type_multiply); 26 | -------------------------------------------------------------------------------- /correlation_package/src/correlation_cuda.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "correlation_cuda_kernel.h" 5 | 6 | #define real float 7 | 8 | // symbol to be automatically resolved by PyTorch libs 9 | extern THCState *state; 10 | 11 | int Correlation_forward_cuda(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *rInput1, THCudaTensor *rInput2, THCudaTensor *output, 12 | int pad_size, 13 | int kernel_size, 14 | int max_displacement, 15 | int stride1, 16 | int stride2, 17 | int corr_type_multiply) 18 | { 19 | 20 | int batchSize = input1->size[0]; 21 | int nInputChannels = input1->size[1]; 22 | int inputHeight = input1->size[2]; 23 | int inputWidth = input1->size[3]; 24 | 25 | int kernel_radius = (kernel_size - 1) / 2; 26 | int border_radius = kernel_radius + max_displacement; 27 | 28 | int paddedInputHeight = inputHeight + 2 * pad_size; 29 | int paddedInputWidth = inputWidth + 2 * pad_size; 30 | 31 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); 32 | 33 | int outputHeight = ceil((float)(paddedInputHeight - 2 * border_radius) / (float)stride1); 34 | int outputwidth = ceil((float)(paddedInputWidth - 2 * border_radius) / (float)stride1); 35 | 36 | THCudaTensor_resize4d(state, rInput1, batchSize, paddedInputHeight, paddedInputWidth, nInputChannels); 37 | THCudaTensor_resize4d(state, rInput2, batchSize, paddedInputHeight, paddedInputWidth, nInputChannels); 38 | THCudaTensor_resize4d(state, output, batchSize, nOutputChannels, outputHeight, outputwidth); 39 | 40 | THCudaTensor_fill(state, rInput1, 0); 41 | THCudaTensor_fill(state, rInput2, 0); 42 | THCudaTensor_fill(state, output, 0); 43 | 44 | int success = 0; 45 | success = Correlation_forward_cuda_kernel( THCudaTensor_data(state, output), 46 | THCudaTensor_size(state, output, 0), 47 | THCudaTensor_size(state, output, 1), 48 | THCudaTensor_size(state, output, 2), 49 | THCudaTensor_size(state, output, 3), 50 | THCudaTensor_stride(state, output, 0), 51 | THCudaTensor_stride(state, output, 1), 52 | THCudaTensor_stride(state, output, 2), 53 | THCudaTensor_stride(state, output, 3), 54 | 55 | THCudaTensor_data(state, input1), 56 | THCudaTensor_size(state, input1, 1), 57 | THCudaTensor_size(state, input1, 2), 58 | THCudaTensor_size(state, input1, 3), 59 | THCudaTensor_stride(state, input1, 0), 60 | THCudaTensor_stride(state, input1, 1), 61 | THCudaTensor_stride(state, input1, 2), 62 | THCudaTensor_stride(state, input1, 3), 63 | 64 | THCudaTensor_data(state, input2), 65 | THCudaTensor_size(state, input2, 1), 66 | THCudaTensor_stride(state, input2, 0), 67 | THCudaTensor_stride(state, input2, 1), 68 | THCudaTensor_stride(state, input2, 2), 69 | THCudaTensor_stride(state, 
input2, 3), 70 | 71 | THCudaTensor_data(state, rInput1), 72 | THCudaTensor_data(state, rInput2), 73 | 74 | pad_size, 75 | kernel_size, 76 | max_displacement, 77 | stride1, 78 | stride2, 79 | corr_type_multiply, 80 | 81 | THCState_getCurrentStream(state)); 82 | 83 | THCudaTensor_free(state, rInput1); 84 | THCudaTensor_free(state, rInput2); 85 | 86 | //check for errors 87 | if (!success) { 88 | THError("aborting"); 89 | } 90 | 91 | return 1; 92 | 93 | } 94 | 95 | int Correlation_backward_cuda(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *rInput1, THCudaTensor *rInput2, THCudaTensor *gradOutput, 96 | THCudaTensor *gradInput1, THCudaTensor *gradInput2, 97 | int pad_size, 98 | int kernel_size, 99 | int max_displacement, 100 | int stride1, 101 | int stride2, 102 | int corr_type_multiply) 103 | { 104 | 105 | int batchSize = input1->size[0]; 106 | int nInputChannels = input1->size[1]; 107 | int paddedInputHeight = input1->size[2]+ 2 * pad_size; 108 | int paddedInputWidth = input1->size[3]+ 2 * pad_size; 109 | 110 | int height = input1->size[2]; 111 | int width = input1->size[3]; 112 | 113 | THCudaTensor_resize4d(state, rInput1, batchSize, paddedInputHeight, paddedInputWidth, nInputChannels); 114 | THCudaTensor_resize4d(state, rInput2, batchSize, paddedInputHeight, paddedInputWidth, nInputChannels); 115 | THCudaTensor_resize4d(state, gradInput1, batchSize, nInputChannels, height, width); 116 | THCudaTensor_resize4d(state, gradInput2, batchSize, nInputChannels, height, width); 117 | 118 | THCudaTensor_fill(state, rInput1, 0); 119 | THCudaTensor_fill(state, rInput2, 0); 120 | THCudaTensor_fill(state, gradInput1, 0); 121 | THCudaTensor_fill(state, gradInput2, 0); 122 | 123 | int success = 0; 124 | success = Correlation_backward_cuda_kernel( 125 | THCudaTensor_data(state, gradOutput), 126 | THCudaTensor_size(state, gradOutput, 0), 127 | THCudaTensor_size(state, gradOutput, 1), 128 | THCudaTensor_size(state, gradOutput, 2), 129 | THCudaTensor_size(state, gradOutput, 3), 130 | THCudaTensor_stride(state, gradOutput, 0), 131 | THCudaTensor_stride(state, gradOutput, 1), 132 | THCudaTensor_stride(state, gradOutput, 2), 133 | THCudaTensor_stride(state, gradOutput, 3), 134 | 135 | THCudaTensor_data(state, input1), 136 | THCudaTensor_size(state, input1, 1), 137 | THCudaTensor_size(state, input1, 2), 138 | THCudaTensor_size(state, input1, 3), 139 | THCudaTensor_stride(state, input1, 0), 140 | THCudaTensor_stride(state, input1, 1), 141 | THCudaTensor_stride(state, input1, 2), 142 | THCudaTensor_stride(state, input1, 3), 143 | 144 | THCudaTensor_data(state, input2), 145 | THCudaTensor_stride(state, input2, 0), 146 | THCudaTensor_stride(state, input2, 1), 147 | THCudaTensor_stride(state, input2, 2), 148 | THCudaTensor_stride(state, input2, 3), 149 | 150 | THCudaTensor_data(state, gradInput1), 151 | THCudaTensor_stride(state, gradInput1, 0), 152 | THCudaTensor_stride(state, gradInput1, 1), 153 | THCudaTensor_stride(state, gradInput1, 2), 154 | THCudaTensor_stride(state, gradInput1, 3), 155 | 156 | THCudaTensor_data(state, gradInput2), 157 | THCudaTensor_size(state, gradInput2, 1), 158 | THCudaTensor_stride(state, gradInput2, 0), 159 | THCudaTensor_stride(state, gradInput2, 1), 160 | THCudaTensor_stride(state, gradInput2, 2), 161 | THCudaTensor_stride(state, gradInput2, 3), 162 | 163 | THCudaTensor_data(state, rInput1), 164 | THCudaTensor_data(state, rInput2), 165 | pad_size, 166 | kernel_size, 167 | max_displacement, 168 | stride1, 169 | stride2, 170 | corr_type_multiply, 171 | 
THCState_getCurrentStream(state)); 172 | 173 | THCudaTensor_free(state, rInput1); 174 | THCudaTensor_free(state, rInput2); 175 | 176 | if (!success) { 177 | THError("aborting"); 178 | } 179 | return 1; 180 | } 181 | -------------------------------------------------------------------------------- /correlation_package/src/correlation_cuda.h: -------------------------------------------------------------------------------- 1 | int Correlation_forward_cuda(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *rInput1, THCudaTensor *rInput2, 2 | THCudaTensor *output, 3 | int pad_size, 4 | int kernel_size, 5 | int max_displacement, 6 | int stride1, 7 | int stride2, 8 | int corr_type_multiply); 9 | 10 | int Correlation_backward_cuda(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *rInput1, THCudaTensor *rInput2, 11 | THCudaTensor *gradOutput, THCudaTensor *gradInput1, THCudaTensor *gradInput2, 12 | int pad_size, 13 | int kernel_size, 14 | int max_displacement, 15 | int stride1, 16 | int stride2, 17 | int corr_type_multiply); 18 | 19 | -------------------------------------------------------------------------------- /correlation_package/src/correlation_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "correlation_cuda_kernel.h" 4 | 5 | #define real float 6 | 7 | #define CUDA_NUM_THREADS 1024 8 | #define THREADS_PER_BLOCK 32 9 | 10 | __global__ void channels_first(float* input, float* rinput, int channels, int height, int width, int pad_size) 11 | { 12 | // n (batch size), c (num of channels), y (height), x (width) 13 | int n = blockIdx.x; 14 | int y = blockIdx.y; 15 | int x = blockIdx.z; 16 | 17 | int ch_off = threadIdx.x; 18 | float value; 19 | 20 | int dimcyx = channels * height * width; 21 | int dimyx = height * width; 22 | 23 | int p_dimx = (width + 2 * pad_size); 24 | int p_dimy = (height + 2 * pad_size); 25 | int p_dimyxc = channels * p_dimy * p_dimx; 26 | int p_dimxc = p_dimx * channels; 27 | 28 | for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) { 29 | value = input[n * dimcyx + c * dimyx + y * width + x]; 30 | rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value; 31 | } 32 | } 33 | 34 | __global__ void Correlation_forward( float *output, int nOutputChannels, int outputHeight, int outputWidth, 35 | float *rInput1, int nInputChannels, int inputHeight, int inputWidth, 36 | float *rInput2, 37 | int pad_size, 38 | int kernel_size, 39 | int max_displacement, 40 | int stride1, 41 | int stride2) 42 | { 43 | // n (batch size), c (num of channels), y (height), x (width) 44 | 45 | int pInputWidth = inputWidth + 2 * pad_size; 46 | int pInputHeight = inputHeight + 2 * pad_size; 47 | 48 | int kernel_rad = (kernel_size - 1) / 2; 49 | int displacement_rad = max_displacement / stride2; 50 | int displacement_size = 2 * displacement_rad + 1; 51 | 52 | int n = blockIdx.x; 53 | int y1 = blockIdx.y * stride1 + max_displacement + kernel_rad; 54 | int x1 = blockIdx.z * stride1 + max_displacement + kernel_rad; 55 | int c = threadIdx.x; 56 | 57 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels; 58 | int pdimxc = pInputWidth * nInputChannels; 59 | int pdimc = nInputChannels; 60 | 61 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 62 | int tdimyx = outputHeight * outputWidth; 63 | int tdimx = outputWidth; 64 | 65 | float nelems = kernel_size * kernel_size * pdimc; 66 | 67 | __shared__ float prod_sum[THREADS_PER_BLOCK]; 68 | 69 | // no significant speed-up in 
using chip memory for input1 sub-data,
70 |     // not enough chip memory to accommodate the per-block memory for input2 sub-data;
71 |     // instead, device memory is used for both
72 | 
73 |     // element-wise product along channel axis
74 |     for (int tj = -displacement_rad; tj <= displacement_rad; ++tj ) {
75 |         for (int ti = -displacement_rad; ti <= displacement_rad; ++ti ) {
76 |             prod_sum[c] = 0;
77 |             int x2 = x1 + ti*stride2;
78 |             int y2 = y1 + tj*stride2;
79 | 
80 |             for (int j = -kernel_rad; j <= kernel_rad; ++j) {
81 |                 for (int i = -kernel_rad; i <= kernel_rad; ++i) {
82 |                     for (int ch = c; ch < pdimc; ch += THREADS_PER_BLOCK) {
83 |                         int indx1 = n * pdimyxc + (y1+j) * pdimxc + (x1 + i) * pdimc + ch;
84 |                         int indx2 = n * pdimyxc + (y2+j) * pdimxc + (x2 + i) * pdimc + ch;
85 | 
86 |                         prod_sum[c] += rInput1[indx1] * rInput2[indx2];
87 |                     }
88 |                 }
89 |             }
90 | 
91 |             // accumulate
92 |             __syncthreads();
93 |             if (c == 0) {
94 |                 float reduce_sum = 0;
95 |                 for (int index = 0; index < THREADS_PER_BLOCK; ++index) {
96 |                     reduce_sum += prod_sum[index];
97 |                 }
98 |                 int tc = (tj + displacement_rad) * displacement_size + (ti + displacement_rad);
99 |                 const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx + blockIdx.z;
100 |                 output[tindx] = reduce_sum / nelems;
101 |             }
102 | 
103 |         }
104 |     }
105 | 
106 | }
107 | 
108 | __global__ void Correlation_backward_input1(int item, float *gradInput1, int nInputChannels, int inputHeight, int inputWidth,
109 |     float *gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
110 |     float *rInput2,
111 |     int pad_size,
112 |     int kernel_size,
113 |     int max_displacement,
114 |     int stride1,
115 |     int stride2)
116 | {
117 |     // n (batch size), c (num of channels), y (height), x (width)
118 | 
119 |     int n = item;
120 |     int y = blockIdx.x * stride1 + pad_size;
121 |     int x = blockIdx.y * stride1 + pad_size;
122 |     int c = blockIdx.z;
123 |     int tch_off = threadIdx.x;
124 | 
125 |     int kernel_rad = (kernel_size - 1) / 2;
126 |     int displacement_rad = max_displacement / stride2;
127 |     int displacement_size = 2 * displacement_rad + 1;
128 | 
129 |     int xmin = (x - kernel_rad - max_displacement) / stride1;
130 |     int ymin = (y - kernel_rad - max_displacement) / stride1;
131 | 
132 |     int xmax = (x + kernel_rad - max_displacement) / stride1;
133 |     int ymax = (y + kernel_rad - max_displacement) / stride1;
134 | 
135 |     if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
136 |         // assumes gradInput1 is pre-allocated and zero filled
137 |         return;
138 |     }
139 | 
140 |     if (xmin > xmax || ymin > ymax) {
141 |         // assumes gradInput1 is pre-allocated and zero filled
142 |         return;
143 |     }
144 | 
145 |     xmin = max(0,xmin);
146 |     xmax = min(outputWidth-1,xmax);
147 | 
148 |     ymin = max(0,ymin);
149 |     ymax = min(outputHeight-1,ymax);
150 | 
151 |     int pInputWidth = inputWidth + 2 * pad_size;
152 |     int pInputHeight = inputHeight + 2 * pad_size;
153 | 
154 |     int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
155 |     int pdimxc = pInputWidth * nInputChannels;
156 |     int pdimc = nInputChannels;
157 | 
158 |     int tdimcyx = nOutputChannels * outputHeight * outputWidth;
159 |     int tdimyx = outputHeight * outputWidth;
160 |     int tdimx = outputWidth;
161 | 
162 |     int odimcyx = nInputChannels * inputHeight* inputWidth;
163 |     int odimyx = inputHeight * inputWidth;
164 |     int odimx = inputWidth;
165 | 
166 |     float nelems = kernel_size * kernel_size * nInputChannels;
167 | 
168 |     __shared__ float prod_sum[CUDA_NUM_THREADS];
169 |     prod_sum[tch_off] = 0;
170 | 
171 |     for (int tc = tch_off; tc < 
nOutputChannels; tc += CUDA_NUM_THREADS) { 172 | 173 | int i2 = (tc % displacement_size - displacement_rad) * stride2; 174 | int j2 = (tc / displacement_size - displacement_rad) * stride2; 175 | 176 | int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c; 177 | 178 | float val2 = rInput2[indx2]; 179 | 180 | for (int j = ymin; j <= ymax; ++j) { 181 | for (int i = xmin; i <= xmax; ++i) { 182 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; 183 | prod_sum[tch_off] += gradOutput[tindx] * val2; 184 | } 185 | } 186 | } 187 | __syncthreads(); 188 | 189 | if(tch_off == 0) { 190 | float reduce_sum = 0; 191 | for(int idx = 0; idx < CUDA_NUM_THREADS; idx++) { 192 | reduce_sum += prod_sum[idx]; 193 | } 194 | const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); 195 | gradInput1[indx1] = reduce_sum / nelems; 196 | } 197 | 198 | } 199 | 200 | __global__ void Correlation_backward_input2(int item, float *gradInput2, int nInputChannels, int inputHeight, int inputWidth, 201 | float *gradOutput, int nOutputChannels, int outputHeight, int outputWidth, 202 | float *rInput1, 203 | int pad_size, 204 | int kernel_size, 205 | int max_displacement, 206 | int stride1, 207 | int stride2) 208 | { 209 | // n (batch size), c (num of channels), y (height), x (width) 210 | 211 | int n = item; 212 | int y = blockIdx.x * stride1 + pad_size; 213 | int x = blockIdx.y * stride1 + pad_size; 214 | int c = blockIdx.z; 215 | 216 | int tch_off = threadIdx.x; 217 | 218 | int kernel_rad = (kernel_size - 1) / 2; 219 | int displacement_rad = max_displacement / stride2; 220 | int displacement_size = 2 * displacement_rad + 1; 221 | 222 | int pInputWidth = inputWidth + 2 * pad_size; 223 | int pInputHeight = inputHeight + 2 * pad_size; 224 | 225 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels; 226 | int pdimxc = pInputWidth * nInputChannels; 227 | int pdimc = nInputChannels; 228 | 229 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 230 | int tdimyx = outputHeight * outputWidth; 231 | int tdimx = outputWidth; 232 | 233 | int odimcyx = nInputChannels * inputHeight* inputWidth; 234 | int odimyx = inputHeight * inputWidth; 235 | int odimx = inputWidth; 236 | 237 | float nelems = kernel_size * kernel_size * nInputChannels; 238 | 239 | __shared__ float prod_sum[CUDA_NUM_THREADS]; 240 | prod_sum[tch_off] = 0; 241 | 242 | for (int tc = tch_off; tc < nOutputChannels; tc += CUDA_NUM_THREADS) { 243 | int i2 = (tc % displacement_size - displacement_rad) * stride2; 244 | int j2 = (tc / displacement_size - displacement_rad) * stride2; 245 | 246 | int xmin = (x - kernel_rad - max_displacement - i2) / stride1; 247 | int ymin = (y - kernel_rad - max_displacement - j2) / stride1; 248 | 249 | int xmax = (x + kernel_rad - max_displacement - i2) / stride1; 250 | int ymax = (y + kernel_rad - max_displacement - j2) / stride1; 251 | 252 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { 253 | // assumes gradInput2 is pre-allocated and zero filled 254 | continue; 255 | } 256 | 257 | if (xmin > xmax || ymin > ymax) { 258 | // assumes gradInput2 is pre-allocated and zero filled 259 | continue; 260 | } 261 | 262 | xmin = max(0,xmin); 263 | xmax = min(outputWidth-1,xmax); 264 | 265 | ymin = max(0,ymin); 266 | ymax = min(outputHeight-1,ymax); 267 | 268 | int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c; 269 | float val1 = rInput1[indx1]; 270 | 271 | for (int j = ymin; j <= ymax; ++j) { 272 | for (int i = xmin; i <= xmax; ++i) { 273 | int tindx = n * 
tdimcyx + tc * tdimyx + j * tdimx + i;
274 |                 prod_sum[tch_off] += gradOutput[tindx] * val1;
275 |             }
276 |         }
277 |     }
278 | 
279 |     __syncthreads();
280 | 
281 |     if(tch_off == 0) {
282 |         float reduce_sum = 0;
283 |         for(int idx = 0; idx < CUDA_NUM_THREADS; idx++) {
284 |             reduce_sum += prod_sum[idx];
285 |         }
286 |         const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
287 |         gradInput2[indx2] = reduce_sum / nelems;
288 |     }
289 | 
290 | }
291 | 
292 | #ifdef __cplusplus
293 | extern "C" {
294 | #endif
295 | 
296 | int Correlation_forward_cuda_kernel(/*THCudaTensor_data(state, output)*/ float *output,
297 |     /*THCudaTensor_size(state, output, 0)*/ int ob,
298 |     /*THCudaTensor_size(state, output, 1)*/ int oc,
299 |     /*THCudaTensor_size(state, output, 2)*/ int oh,
300 |     /*THCudaTensor_size(state, output, 3)*/ int ow,
301 |     /*THCudaTensor_stride(state, output, 0)*/ int osb,
302 |     /*THCudaTensor_stride(state, output, 1)*/ int osc,
303 |     /*THCudaTensor_stride(state, output, 2)*/ int osh,
304 |     /*THCudaTensor_stride(state, output, 3)*/ int osw,
305 | 
306 |     /*THCudaTensor_data(state, input1)*/ float *input1,
307 |     /*THCudaTensor_size(state, input1, 1)*/ int ic,
308 |     /*THCudaTensor_size(state, input1, 2)*/ int ih,
309 |     /*THCudaTensor_size(state, input1, 3)*/ int iw,
310 |     /*THCudaTensor_stride(state, input1, 0)*/ int isb,
311 |     /*THCudaTensor_stride(state, input1, 1)*/ int isc,
312 |     /*THCudaTensor_stride(state, input1, 2)*/ int ish,
313 |     /*THCudaTensor_stride(state, input1, 3)*/ int isw,
314 | 
315 |     /*THCudaTensor_data(state, input2)*/ float *input2,
316 |     /*THCudaTensor_size(state, input2, 1)*/ int gc,
317 |     /*THCudaTensor_stride(state, input2, 0)*/ int gsb,
318 |     /*THCudaTensor_stride(state, input2, 1)*/ int gsc,
319 |     /*THCudaTensor_stride(state, input2, 2)*/ int gsh,
320 |     /*THCudaTensor_stride(state, input2, 3)*/ int gsw,
321 | 
322 |     /*THCudaTensor_data(state, rInput1)*/ float *rInput1,
323 |     /*THCudaTensor_data(state, rInput2)*/ float *rInput2,
324 |     int pad_size,
325 |     int kernel_size,
326 |     int max_displacement,
327 |     int stride1,
328 |     int stride2,
329 |     int corr_type_multiply,
330 |     /*THCState_getCurrentStream(state)*/ cudaStream_t stream)
331 | {
332 |     int batchSize = ob;
333 | 
334 |     int nInputChannels = ic;
335 |     int inputWidth = iw;
336 |     int inputHeight = ih;
337 | 
338 |     int nOutputChannels = oc;
339 |     int outputWidth = ow;
340 |     int outputHeight = oh;
341 | 
342 |     dim3 blocks_grid(batchSize, inputHeight, inputWidth);
343 |     dim3 threads_block(THREADS_PER_BLOCK);
344 | 
345 |     channels_first<<<blocks_grid, threads_block, 0, stream>>>(input1, rInput1, nInputChannels, inputHeight, inputWidth, pad_size);
346 |     channels_first<<<blocks_grid, threads_block, 0, stream>>>(input2, rInput2, nInputChannels, inputHeight, inputWidth, pad_size);
347 | 
348 |     dim3 threadsPerBlock(THREADS_PER_BLOCK);
349 |     dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth);
350 | 
351 |     Correlation_forward <<< totalBlocksCorr, threadsPerBlock, 0, stream >>>
352 |         (output, nOutputChannels, outputHeight, outputWidth,
353 |         rInput1, nInputChannels, inputHeight, inputWidth,
354 |         rInput2,
355 |         pad_size,
356 |         kernel_size,
357 |         max_displacement,
358 |         stride1,
359 |         stride2);
360 | 
361 |     // check for errors
362 |     cudaError_t err = cudaGetLastError();
363 |     if (err != cudaSuccess) {
364 |         printf("error in Correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err));
365 |         return 0;
366 |     }
367 | 
368 |     return 1;
369 | }
370 | 
371 | int Correlation_backward_cuda_kernel(
372 |     /*THCudaTensor_data(state, gradOutput)*/ float *gradOutput,
373 |     /*THCudaTensor_size(state, 
gradOutput, 0)*/ int gob,
374 |     /*THCudaTensor_size(state, gradOutput, 1)*/ int goc,
375 |     /*THCudaTensor_size(state, gradOutput, 2)*/ int goh,
376 |     /*THCudaTensor_size(state, gradOutput, 3)*/ int gow,
377 |     /*THCudaTensor_stride(state, gradOutput, 0)*/ int gosb,
378 |     /*THCudaTensor_stride(state, gradOutput, 1)*/ int gosc,
379 |     /*THCudaTensor_stride(state, gradOutput, 2)*/ int gosh,
380 |     /*THCudaTensor_stride(state, gradOutput, 3)*/ int gosw,
381 | 
382 |     /*THCudaTensor_data(state, input1)*/ float* input1,
383 |     /*THCudaTensor_size(state, input1, 1)*/ int ic,
384 |     /*THCudaTensor_size(state, input1, 2)*/ int ih,
385 |     /*THCudaTensor_size(state, input1, 3)*/ int iw,
386 |     /*THCudaTensor_stride(state, input1, 0)*/ int isb,
387 |     /*THCudaTensor_stride(state, input1, 1)*/ int isc,
388 |     /*THCudaTensor_stride(state, input1, 2)*/ int ish,
389 |     /*THCudaTensor_stride(state, input1, 3)*/ int isw,
390 | 
391 |     /*THCudaTensor_data(state, input2)*/ float *input2,
392 |     /*THCudaTensor_stride(state, input2, 0)*/ int gsb,
393 |     /*THCudaTensor_stride(state, input2, 1)*/ int gsc,
394 |     /*THCudaTensor_stride(state, input2, 2)*/ int gsh,
395 |     /*THCudaTensor_stride(state, input2, 3)*/ int gsw,
396 | 
397 |     /*THCudaTensor_data(state, gradInput1)*/ float *gradInput1,
398 |     /*THCudaTensor_stride(state, gradInput1, 0)*/ int gisb,
399 |     /*THCudaTensor_stride(state, gradInput1, 1)*/ int gisc,
400 |     /*THCudaTensor_stride(state, gradInput1, 2)*/ int gish,
401 |     /*THCudaTensor_stride(state, gradInput1, 3)*/ int gisw,
402 | 
403 |     /*THCudaTensor_data(state, gradInput2)*/ float *gradInput2,
404 |     /*THCudaTensor_size(state, gradInput2, 1)*/ int ggc,
405 |     /*THCudaTensor_stride(state, gradInput2, 0)*/ int ggsb,
406 |     /*THCudaTensor_stride(state, gradInput2, 1)*/ int ggsc,
407 |     /*THCudaTensor_stride(state, gradInput2, 2)*/ int ggsh,
408 |     /*THCudaTensor_stride(state, gradInput2, 3)*/ int ggsw,
409 | 
410 |     /*THCudaTensor_data(state, rInput1)*/ float *rInput1,
411 |     /*THCudaTensor_data(state, rInput2)*/ float *rInput2,
412 |     int pad_size,
413 |     int kernel_size,
414 |     int max_displacement,
415 |     int stride1,
416 |     int stride2,
417 |     int corr_type_multiply,
418 |     /*THCState_getCurrentStream(state)*/cudaStream_t stream)
419 | {
420 | 
421 |     int batchSize = gob;
422 |     int num = batchSize;
423 | 
424 |     int nInputChannels = ic;
425 |     int inputWidth = iw;
426 |     int inputHeight = ih;
427 | 
428 |     int nOutputChannels = goc;
429 |     int outputWidth = gow;
430 |     int outputHeight = goh;
431 | 
432 |     dim3 blocks_grid(batchSize, inputHeight, inputWidth);
433 |     dim3 threads_block(THREADS_PER_BLOCK);
434 | 
435 |     channels_first<<<blocks_grid, threads_block, 0, stream>>>(input1, rInput1, nInputChannels, inputHeight, inputWidth, pad_size);
436 |     channels_first<<<blocks_grid, threads_block, 0, stream>>>(input2, rInput2, nInputChannels, inputHeight, inputWidth, pad_size);
437 | 
438 |     dim3 threadsPerBlock(CUDA_NUM_THREADS);
439 |     dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels);
440 | 
441 |     for (int n = 0; n < num; ++n) {
442 |         Correlation_backward_input1<<<totalBlocksCorr, threadsPerBlock, 0, stream>>>(
443 |             n, gradInput1, nInputChannels, inputHeight, inputWidth,
444 |             gradOutput, nOutputChannels, outputHeight, outputWidth,
445 |             rInput2,
446 |             pad_size,
447 |             kernel_size,
448 |             max_displacement,
449 |             stride1,
450 |             stride2);
451 |     }
452 | 
453 |     for(int n = 0; n < batchSize; n++) {
454 |         Correlation_backward_input2<<<totalBlocksCorr, threadsPerBlock, 0, stream>>>(
455 |             n, gradInput2, nInputChannels, inputHeight, inputWidth,
456 |             gradOutput, nOutputChannels, outputHeight, outputWidth,
457 |             rInput1,
458 |             pad_size,
459 |             kernel_size,
460 |             max_displacement,
461 |             stride1,
462 |             stride2);
463
| } 464 | 465 | // check for errors 466 | cudaError_t err = cudaGetLastError(); 467 | if (err != cudaSuccess) { 468 | printf("error in Correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err)); 469 | return 0; 470 | } 471 | 472 | return 1; 473 | } 474 | 475 | #ifdef __cplusplus 476 | } 477 | #endif 478 | -------------------------------------------------------------------------------- /correlation_package/src/correlation_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | int Correlation_forward_cuda_kernel(/*THCudaTensor_data(state, output)*/ float *output, 6 | /*THCudaTensor_size(state, output, 0)*/ int ob, 7 | /*THCudaTensor_size(state, output, 1)*/ int oc, 8 | /*THCudaTensor_size(state, output, 2)*/ int oh, 9 | /*THCudaTensor_size(state, output, 3)*/ int ow, 10 | /*THCudaTensor_stride(state, output, 0)*/ int osb, 11 | /*THCudaTensor_stride(state, output, 1)*/ int osc, 12 | /*THCudaTensor_stride(state, output, 2)*/ int osh, 13 | /*THCudaTensor_stride(state, output, 3)*/ int osw, 14 | 15 | /*THCudaTensor_data(state, input1)*/ float *input1, 16 | /*THCudaTensor_size(state, input1, 1)*/ int ic, 17 | /*THCudaTensor_size(state, input1, 2)*/ int ih, 18 | /*THCudaTensor_size(state, input1, 3)*/ int iw, 19 | /*THCudaTensor_stride(state, input1, 0)*/ int isb, 20 | /*THCudaTensor_stride(state, input1, 1)*/ int isc, 21 | /*THCudaTensor_stride(state, input1, 2)*/ int ish, 22 | /*THCudaTensor_stride(state, input1, 3)*/ int isw, 23 | 24 | /*THCudaTensor_data(state, input2)*/ float *input2, 25 | /*THCudaTensor_size(state, input2, 1)*/ int gc, 26 | /*THCudaTensor_stride(state, input2, 0)*/ int gsb, 27 | /*THCudaTensor_stride(state, input2, 1)*/ int gsc, 28 | /*THCudaTensor_stride(state, input2, 2)*/ int gsh, 29 | /*THCudaTensor_stride(state, input2, 3)*/ int gsw, 30 | 31 | /*THCudaTensor_data(state, rInput1)*/ float *rInput1, 32 | /*THCudaTensor_data(state, rInput2)*/ float *rInput2, 33 | int pad_size, 34 | int kernel_size, 35 | int max_displacement, 36 | int stride1, 37 | int stride2, 38 | int corr_type_multiply, 39 | /*THCState_getCurrentStream(state)*/ cudaStream_t stream); 40 | 41 | int Correlation_backward_cuda_kernel( 42 | /*THCudaTensor_data(state, gradOutput)*/ float *gradOutput, 43 | /*THCudaTensor_size(state, gradOutput, 0)*/ int gob, 44 | /*THCudaTensor_size(state, gradOutput, 1)*/ int goc, 45 | /*THCudaTensor_size(state, gradOutput, 2)*/ int goh, 46 | /*THCudaTensor_size(state, gradOutput, 3)*/ int gow, 47 | /*THCudaTensor_stride(state, gradOutput, 0)*/ int gosb, 48 | /*THCudaTensor_stride(state, gradOutput, 1)*/ int gosc, 49 | /*THCudaTensor_stride(state, gradOutput, 2)*/ int gosh, 50 | /*THCudaTensor_stride(state, gradOutput, 3)*/ int gosw, 51 | 52 | /*THCudaTensor_data(state, input1)*/ float* input1, 53 | /*THCudaTensor_size(state, input1, 1)*/ int ic, 54 | /*THCudaTensor_size(state, input1, 2)*/ int ih, 55 | /*THCudaTensor_size(state, input1, 3)*/ int iw, 56 | /*THCudaTensor_stride(state, input1, 0)*/ int isb, 57 | /*THCudaTensor_stride(state, input1, 1)*/ int isc, 58 | /*THCudaTensor_stride(state, input1, 2)*/ int ish, 59 | /*THCudaTensor_stride(state, input1, 3)*/ int isw, 60 | 61 | /*THCudaTensor_data(state, input2)*/ float *input2, 62 | /*THCudaTensor_stride(state, input2, 0)*/ int gsb, 63 | /*THCudaTensor_stride(state, input2, 1)*/ int gsc, 64 | /*THCudaTensor_stride(state, input2, 2)*/ int gsh, 65 | /*THCudaTensor_stride(state, input2, 3)*/ int gsw, 66 | 67 | 
/*THCudaTensor_data(state, gradInput1)*/ float *gradInput1,
68 |     /*THCudaTensor_stride(state, gradInput1, 0)*/ int gisb,
69 |     /*THCudaTensor_stride(state, gradInput1, 1)*/ int gisc,
70 |     /*THCudaTensor_stride(state, gradInput1, 2)*/ int gish,
71 |     /*THCudaTensor_stride(state, gradInput1, 3)*/ int gisw,
72 | 
73 |     /*THCudaTensor_data(state, gradInput2)*/ float *gradInput2,
74 |     /*THCudaTensor_size(state, gradInput2, 1)*/ int ggc,
75 |     /*THCudaTensor_stride(state, gradInput2, 0)*/ int ggsb,
76 |     /*THCudaTensor_stride(state, gradInput2, 1)*/ int ggsc,
77 |     /*THCudaTensor_stride(state, gradInput2, 2)*/ int ggsh,
78 |     /*THCudaTensor_stride(state, gradInput2, 3)*/ int ggsw,
79 | 
80 |     /*THCudaTensor_data(state, rInput1)*/ float *rInput1,
81 |     /*THCudaTensor_data(state, rInput2)*/ float *rInput2,
82 |     int pad_size,
83 |     int kernel_size,
84 |     int max_displacement,
85 |     int stride1,
86 |     int stride2,
87 |     int corr_type_multiply,
88 |     /*THCState_getCurrentStream(state)*/cudaStream_t stream);
89 | 
90 | #ifdef __cplusplus
91 | }
92 | #endif
93 | 
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from back2future import Model
3 | import numpy as np
4 | from scipy.misc import imread, imresize
5 | from torchvision.transforms import ToTensor
6 | from torch.autograd import Variable
7 | 
8 | def main():
9 |     model = Model(pretrained='pretrained/b2f_kitti.pth.tar')  # KITTI model shipped in pretrained/
10 |     model = model.cuda()
11 |     im_tar, im_refs = fetch_image_tensors()
12 |     im_tar = Variable(im_tar.unsqueeze(0)).cuda()
13 |     im_refs = [Variable(im_r.unsqueeze(0)).cuda() for im_r in im_refs]
14 |     flow_fwd, flow_bwd, occ = model(im_tar, im_refs)
15 |     np.save('outputs.npy', {'flow_fwd':flow_fwd[0].cpu().data.numpy(),
16 |                             'flow_bwd':flow_bwd[0].cpu().data.numpy(),
17 |                             'occ':occ[0].cpu().data.numpy()})
18 |     print('Outputs saved in outputs.npy')
19 |     print('Done!')
20 | 
21 | def load_as_float(path):
22 |     return imread(path).astype(np.float32)
23 | 
24 | def fetch_image_tensors():
25 |     im0 = load_as_float('samples/000010_09.png')
26 |     im1 = load_as_float('samples/000010_10.png')
27 |     im2 = load_as_float('samples/000010_11.png')
28 |     scale = Scale(h=256, w=832)
29 |     im012 = scale([im0, im1, im2])
30 |     im_tar = im012[1]
31 |     im_refs = [im012[0], im012[2]]
32 |     return im_tar, im_refs
33 | 
34 | class Scale(object):
35 |     """Scales images to a particular size"""
36 |     def __init__(self, h, w):
37 |         self.h = h
38 |         self.w = w
39 | 
40 |     def __call__(self, images):
41 |         in_h, in_w, _ = images[0].shape
42 |         scaled_h, scaled_w = self.h, self.w
43 |         scaled_images = [ToTensor()(imresize(im, (scaled_h, scaled_w))) for im in images]
44 |         return scaled_images
45 | 
46 | if __name__ == '__main__':
47 |     main()
48 | 
--------------------------------------------------------------------------------
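For reference, a minimal sketch of reading the file that `demo.py` writes. The `allow_pickle` flag only exists on newer NumPy releases; on the NumPy versions contemporary with PyTorch 0.3.1, plain `np.load('outputs.npy').item()` suffices:

```python
import numpy as np

# np.save wraps the dict in a 0-d object array; .item() unwraps it
outputs = np.load('outputs.npy', allow_pickle=True).item()
flow_fwd = outputs['flow_fwd'][0]   # (2, H, W) full-resolution forward flow
flow_bwd = outputs['flow_bwd'][0]
occ = outputs['occ'][0]
print(flow_fwd.shape, flow_bwd.shape, occ.shape)
```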
/flow_io.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python2
2 | 
3 | """
4 | I/O script to save and load the data coming with the MPI-Sintel low-level
5 | computer vision benchmark.
6 | 
7 | For more details about the benchmark, please visit www.mpi-sintel.de
8 | 
9 | CHANGELOG:
10 | v1.0 (2015/02/03): First release
11 | 
12 | Copyright (c) 2015 Jonas Wulff
13 | Max Planck Institute for Intelligent Systems, Tuebingen, Germany
14 | 
15 | """
16 | 
17 | # Requirements: Numpy and PIL/Pillow
18 | import numpy as np
19 | try:
20 |     import png
21 |     has_png = True
22 | except ImportError:
23 |     has_png = False
24 |     png = None
25 | 
26 | 
27 | 
28 | # Check for endianness, based on Daniel Scharstein's optical flow code.
29 | # Using little-endian architecture, these two should be equal.
30 | TAG_FLOAT = 202021.25
31 | TAG_CHAR = 'PIEH'.encode()
32 | 
33 | def flow_read(filename, return_validity=False):
34 |     """ Read optical flow from file, return (U,V) tuple.
35 | 
36 |     Original code by Deqing Sun, adapted from Daniel Scharstein.
37 |     """
38 |     f = open(filename,'rb')
39 |     check = np.fromfile(f,dtype=np.float32,count=1)[0]
40 |     assert check == TAG_FLOAT, ' flow_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check)
41 |     width = np.fromfile(f,dtype=np.int32,count=1)[0]
42 |     height = np.fromfile(f,dtype=np.int32,count=1)[0]
43 |     size = width*height
44 |     assert width > 0 and height > 0 and size > 1 and size < 100000000, ' flow_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height)
45 |     tmp = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width*2))
46 |     u = tmp[:,np.arange(width)*2]
47 |     v = tmp[:,np.arange(width)*2 + 1]
48 | 
49 |     if return_validity:
50 |         valid = u<1e19
51 |         u[valid==0] = 0
52 |         v[valid==0] = 0
53 |         return u,v,valid
54 |     else:
55 |         return u,v
56 | 
57 | def flow_write(filename,uv,v=None):
58 |     """ Write optical flow to file.
59 | 
60 |     If v is None, uv is assumed to contain both u and v channels,
61 |     stacked in depth.
62 | 
63 |     Original code by Deqing Sun, adapted from Daniel Scharstein.
64 |     """
65 |     nBands = 2
66 | 
67 |     if v is None:
68 |         uv_ = np.array(uv)
69 |         assert(uv_.ndim==3)
70 |         if uv_.shape[0] == 2:
71 |             u = uv_[0,:,:]
72 |             v = uv_[1,:,:]
73 |         elif uv_.shape[2] == 2:
74 |             u = uv_[:,:,0]
75 |             v = uv_[:,:,1]
76 |         else:
77 |             raise ValueError('Wrong format for flow input')
78 |     else:
79 |         u = uv
80 | 
81 |     assert(u.shape == v.shape)
82 |     height,width = u.shape
83 |     f = open(filename,'wb')
84 |     # write the header
85 |     f.write(TAG_CHAR)
86 |     np.array(width).astype(np.int32).tofile(f)
87 |     np.array(height).astype(np.int32).tofile(f)
88 |     # arrange into matrix form
89 |     tmp = np.zeros((height, width*nBands))
90 |     tmp[:,np.arange(width)*2] = u
91 |     tmp[:,np.arange(width)*2 + 1] = v
92 |     tmp.astype(np.float32).tofile(f)
93 |     f.close()
94 | 
95 | 
96 | def flow_read_png(fpath):
97 |     """
98 |     Read KITTI optical flow, returns u,v,valid mask
99 | 
100 |     """
101 |     if not has_png:
102 |         print('Error. Please install the PyPNG library')
103 |         return
104 | 
105 |     R = png.Reader(fpath)
106 |     width,height,data,_ = R.asDirect()
107 |     # This only worked with python2.
108 |     #I = np.array(map(lambda x:x,data)).reshape((height,width,3))
109 |     I = np.array([x for x in data]).reshape((height,width,3))
110 |     u_ = I[:,:,0]
111 |     v_ = I[:,:,1]
112 |     valid = I[:,:,2]
113 | 
114 |     u = (u_.astype('float64')-2**15)/64.0
115 |     v = (v_.astype('float64')-2**15)/64.0
116 | 
117 |     return u,v,valid
118 | 
119 | 
120 | def flow_write_png(fpath,u,v,valid=None):
121 |     """
122 |     Write KITTI optical flow.
123 | 
124 |     """
125 |     if not has_png:
127 |         return
128 | 
129 | 
130 |     if valid is None:
131 |         valid_ = np.ones(u.shape, dtype='uint16')
132 |     else:
133 |         valid_ = valid.astype('uint16')
134 | 
135 | 
136 |     u = u.astype('float64')
137 |     v = v.astype('float64')
138 | 
139 |     u_ = ((u*64.0)+2**15).astype('uint16')  # inverse of the read-side decoding
140 |     v_ = ((v*64.0)+2**15).astype('uint16')
141 | 
142 |     I = np.dstack((u_, v_, valid_))
143 | 
144 |     W = png.Writer(width=u.shape[1],
145 |                    height=u.shape[0],
146 |                    bitdepth=16,
147 |                    planes=3)
148 | 
149 |     with open(fpath, 'wb') as fil:
150 |         W.write(fil, I.reshape((-1, 3*u.shape[1])))
151 | 
--------------------------------------------------------------------------------
/pretrained/b2f_kitti.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/pretrained/b2f_kitti.pth.tar
--------------------------------------------------------------------------------
/pretrained/b2f_sintel.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/pretrained/b2f_sintel.pth.tar
--------------------------------------------------------------------------------
/samples/000010_09.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/samples/000010_09.png
--------------------------------------------------------------------------------
/samples/000010_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/samples/000010_10.png
--------------------------------------------------------------------------------
/samples/000010_11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/samples/000010_11.png
--------------------------------------------------------------------------------
/test_back2future.py:
--------------------------------------------------------------------------------
1 | # Author: Anurag Ranjan
2 | # Copyright (c) 2018, Max Planck Society
3 | 
4 | import argparse
5 | import torch
6 | import torch.nn as nn
7 | from path import Path
8 | from torch.autograd import Variable
9 | from torchvision.transforms import ToTensor
10 | 
11 | from scipy.misc import imread, imresize
12 | from tqdm import tqdm
13 | import numpy as np
14 | 
15 | from back2future import Model
16 | import flow_io
17 | 
18 | parser = argparse.ArgumentParser(description='Code to test performance of Back2Future models on KITTI benchmarks',
19 |                                  formatter_class=argparse.ArgumentDefaultsHelpFormatter)
20 | parser.add_argument('--pretrained-flow', dest='pretrained_flow', default=None, metavar='PATH',
21 |                     help='path to pre-trained Flow net model')
22 | parser.add_argument('--kitti-dir', dest='kitti_dir', default=None, metavar='PATH',
23 |                     help='path to KITTI 2015 directory')
24 | 
25 | 
26 | def main():
27 |     global args
28 |     args = parser.parse_args()
29 |     flow_loader_h, flow_loader_w = 384, 1280
30 |     valid_flow_transform = Scale(h=flow_loader_h, w=flow_loader_w)
31 |     val_flow_set = KITTI2015(root=args.kitti_dir,
32 |                              transform=valid_flow_transform)
33 | 
34 |     val_flow_loader = torch.utils.data.DataLoader(val_flow_set, batch_size=1, shuffle=False,
35 |                                                    num_workers=2, pin_memory=True)
36 | 
37 |     flow_net = Model(pretrained=args.pretrained_flow).cuda()
38 | 
39 |     flow_net.eval()
40 |     error_names = ['epe_total']
41 |     errors = AverageMeter(i=len(error_names))
42 | 
43 |     for i, (tgt_img, ref_imgs, flow_gt) in enumerate(tqdm(val_flow_loader)):
44 |         tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
45 |         ref_imgs_var = [Variable(img.cuda(), volatile=True) for img in ref_imgs]
46 |         flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
47 | 
48 |         # compute output
49 |         flow_fwd, flow_bwd, occ = flow_net(tgt_img_var, ref_imgs_var)
50 |         epe = compute_epe(gt=flow_gt_var, pred=flow_fwd[0])
51 |         errors.update(epe)
52 | 
53 |     print("Average EPE", errors.avg)
54 | 
55 | 
56 | class KITTI2015(torch.utils.data.Dataset):
57 |     """
58 |     Kitti 2015 loader
59 |     """
60 | 
61 |     def __init__(self, root, transform=None, N=200, train=True, seed=0):
62 |         self.root = Path(root)
63 |         self.scenes = range(N)
64 |         self.N = N
65 |         self.transform = transform
66 |         self.phase = 'training' if train else 'testing'
67 |         self.seq_ids = [9, 11]  # reference frames surrounding the target frame (_10)
68 | 
69 |     def __getitem__(self, index):
70 |         tgt_img_path = self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2', str(index).zfill(6)+'_10.png')
71 |         ref_img_paths = [self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2', str(index).zfill(6)+'_'+str(k).zfill(2)+'.png') for k in self.seq_ids]
72 |         gt_flow_path = self.root.joinpath('data_scene_flow', self.phase, 'flow_occ', str(index).zfill(6)+'_10.png')
73 | 
74 |         tgt_img = load_as_float(tgt_img_path)
75 |         ref_imgs = [load_as_float(ref_img) for ref_img in ref_img_paths]
76 | 
77 |         u, v, valid = flow_io.flow_read_png(gt_flow_path)
78 |         gtFlow = np.dstack((u, v, valid))
79 |         gtFlow = torch.FloatTensor(gtFlow.transpose(2, 0, 1))
80 | 
81 |         if self.transform is not None:
82 |             imgs = self.transform([tgt_img] + ref_imgs)
83 |             tgt_img = imgs[0]
84 |             ref_imgs = imgs[1:]
85 | 
86 |         return tgt_img, ref_imgs, gtFlow
87 | 
88 |     def __len__(self):
89 |         return self.N
90 | 
91 | 
92 | class Scale(object):
93 |     """Scales images to a particular size"""
94 |     def __init__(self, h, w):
95 |         self.h = h
96 |         self.w = w
97 | 
98 |     def __call__(self, images):
99 |         in_h, in_w, _ = images[0].shape
100 |         scaled_h, scaled_w = self.h, self.w
101 |         scaled_images = [ToTensor()(imresize(im, (scaled_h, scaled_w))) for im in images]
102 |         return scaled_images
103 | 
104 | def compute_epe(gt, pred):
105 |     _, _, h_pred, w_pred = pred.size()
106 |     bs, nc, h_gt, w_gt = gt.size()
107 | 
108 |     u_gt, v_gt = gt[:,0,:,:], gt[:,1,:,:]
109 |     pred = nn.functional.upsample(pred, size=(h_gt, w_gt), mode='bilinear')  # bring prediction to ground-truth resolution
110 |     u_pred = pred[:,0,:,:] * (w_gt/w_pred)  # rescale flow magnitudes by the resize factor
111 |     v_pred = pred[:,1,:,:] * (h_gt/h_pred)
112 | 
113 |     epe = torch.sqrt(torch.pow((u_gt - u_pred), 2) + torch.pow((v_gt - v_pred), 2))
114 | 
115 |     if nc == 3:
116 |         valid = gt[:,2,:,:]  # third channel of KITTI ground truth is the validity mask
117 |         epe = epe * valid
118 |         avg_epe = epe.sum()/(valid.sum() + 1e-6)
119 |     else:
120 |         avg_epe = epe.sum()/(bs*h_gt*w_gt)
121 | 
122 |     if type(avg_epe) == Variable: avg_epe = avg_epe.data
123 |     return avg_epe[0]
124 | 
125 | def load_as_float(path):
126 |     return imread(path).astype(np.float32)
127 | 
128 | class AverageMeter(object):
129 |     """Computes and stores the average and current value"""
130 | 
131 |     def __init__(self, i=1, precision=3):
132 |         self.meters = i
133 |         self.precision = precision
134 |         self.reset(self.meters)
135 | 
136 |     def reset(self, i):
137 |         self.val = [0]*i
138 |         self.avg = [0]*i
139 |         self.sum = [0]*i
140 |         self.count = 0
141 | 
142 |     def update(self, val, n=1):
143 |         if not isinstance(val, list):
144 |             val = [val]
145 |         assert(len(val) == self.meters)
146 |         self.count += n
147 |         for i, v in enumerate(val):
148 |             self.val[i] = v
149 |             self.sum[i] += v * n
150 |             self.avg[i] = self.sum[i] / self.count
151 | 
152 |     def __repr__(self):
153 |         val = ' '.join(['{:.{}f}'.format(v, self.precision) for v in self.val])
154 |         avg = ' '.join(['{:.{}f}'.format(a, self.precision) for a in self.avg])
155 |         return '{} ({})'.format(val, avg)
156 | 
157 | if __name__ == '__main__':
158 |     main()
159 | 
--------------------------------------------------------------------------------
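
As a quick sanity check of the `.flo` I/O in `flow_io.py`, the minimal sketch below writes a random flow field with `flow_write` and reads it back with `flow_read`. It assumes only NumPy and that it runs from the repository root; the scratch file name `example.flo` is illustrative and not part of the repository.

```python
# Minimal round-trip sketch for flow_io.flow_write / flow_io.flow_read.
# Assumes it runs in the repository root; 'example.flo' is a scratch file.
import numpy as np
import flow_io

u = np.random.randn(4, 8).astype(np.float32)  # horizontal flow component
v = np.random.randn(4, 8).astype(np.float32)  # vertical flow component

# flow_write accepts u and v stacked in depth (H, W, 2), or as two arguments.
flow_io.flow_write('example.flo', np.dstack((u, v)))
u2, v2 = flow_io.flow_read('example.flo')

# Values are stored as float32, so the round trip is exact.
assert np.allclose(u, u2) and np.allclose(v, v2)
```

The same module serves `test_back2future.py`, except that KITTI ground truth arrives as 16-bit PNGs and is decoded through `flow_read_png` rather than the `.flo` reader.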