├── .gitignore
├── README.md
├── back2future.py
├── convert.lua
├── correlation_package
│   ├── .make.sh.swp
│   ├── __init__.py
│   ├── build.py
│   ├── functions
│   │   ├── __init__.py
│   │   └── correlation.py
│   ├── make.sh
│   ├── modules
│   │   ├── __init__.py
│   │   └── correlation.py
│   └── src
│       ├── correlation.c
│       ├── correlation.h
│       ├── correlation_cuda.c
│       ├── correlation_cuda.h
│       ├── correlation_cuda_kernel.cu
│       └── correlation_cuda_kernel.h
├── demo.py
├── flow_io.py
├── pretrained
│   ├── b2f_kitti.pth.tar
│   └── b2f_sintel.pth.tar
├── samples
│   ├── 000010_09.png
│   ├── 000010_10.png
│   └── 000010_11.png
└── test_back2future.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.o
3 | correlation_package/_ext/*
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This is a PyTorch implementation of
2 |
3 | Janai, J., Güney, F., Ranjan, A., Black, M. and Geiger, A., **Unsupervised Learning of Multi-Frame Optical Flow with Occlusions.** ECCV 2018.
4 |
5 | [[Link to Paper](http://www.cvlibs.net/publications/Janai2018ECCV.pdf)] [[Project Page](https://avg.is.tuebingen.mpg.de/research_projects/back2future)] [[Original Torch Code](https://github.com/jjanai/back2future)]
6 |
7 | ## Requirements
8 | - Runs on and tested with [PyTorch 0.3.1](https://pytorch.org/get-started/previous-versions/); it should be compatible with newer versions with little or no modification.
9 | - The correlation package is taken from [NVIDIA/flownet2-pytorch](https://github.com/NVIDIA/flownet2-pytorch/) and can be installed using
10 | ```bash
11 | cd correlation_package
12 | bash make.sh
13 | ```
14 | If you are using PyTorch > 0.3.1, you can use the correlation layer from [here](https://github.com/ClementPinard/Pytorch-Correlation-extension) instead.
15 | ## Usage
16 | To use the model, from your favorite Python environment:
17 | ```python
18 | from back2future import Model
19 | model = Model(pretrained='pretrained/path_to_your_favorite_model')
20 | ```
21 | There are two pretrained models in `pretrained/`, fine-tuned on Sintel and KITTI in an unsupervised way.
22 |
23 | Refer to `demo.py` for more.
24 |
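A minimal end-to-end sketch (editorial; it mirrors `demo.py`, with random tensors standing in for real frames):
```python
import torch
from torch.autograd import Variable
from back2future import Model

model = Model(pretrained='pretrained/b2f_kitti.pth.tar').cuda()
# Three consecutive frames as 1x3xHxW tensors in [0, 1]; random here for illustration
past, target, future = [Variable(torch.rand(1, 3, 256, 832)).cuda() for _ in range(3)]
# Each output is a multi-scale list (finest first); occ holds 2-channel softmax maps
flow_fwd, flow_bwd, occ = model(target, [past, future])
print(flow_fwd[0].size())  # torch.Size([1, 2, 256, 832])
```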
25 | ## Testing
26 | To test performance on KITTI, use
27 | ```bash
28 | python3 test_back2future.py --pretrained-flow path/to/pretrained/model --kitti-dir path/to/kitti/2015/root
29 | ```
30 |
31 | ## Training
32 | Please use the [[original torch code](https://github.com/jjanai/back2future)] for training new models.
33 |
34 | ## License
35 | This is a reimplementation. The license for the original work can be found at [JJanai/back2future](https://github.com/JJanai/back2future/blob/master/LICENSE).
36 |
37 | ## If you use this code, please cite
38 | ```
39 | @inproceedings{Janai2018ECCV,
40 |   title = {Unsupervised Learning of Multi-Frame Optical Flow with Occlusions},
41 |   author = {Janai, Joel and G{\"u}ney, Fatma and Ranjan, Anurag and Black, Michael J. and Geiger, Andreas},
42 | booktitle = {European Conference on Computer Vision (ECCV)},
43 | volume = {Lecture Notes in Computer Science, vol 11220},
44 | pages = {713--731},
45 | publisher = {Springer, Cham},
46 | month = sep,
47 | year = {2018},
48 | month_numeric = {9}
49 | }
50 | ```
51 |
--------------------------------------------------------------------------------
/back2future.py:
--------------------------------------------------------------------------------
1 | # Author: Anurag Ranjan
2 | # Copyright (c) 2018, Max Planck Society
3 |
4 | import os
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.utils.serialization import load_lua
11 | from torch.autograd import Variable
12 | from correlation_package.modules.correlation import Correlation
13 |
14 | def conv_feat_block(nIn, nOut):
15 | return nn.Sequential(
16 | nn.Conv2d(nIn, nOut, kernel_size=3, stride=2, padding=1),
17 | nn.LeakyReLU(0.2),
18 | nn.Conv2d(nOut, nOut, kernel_size=3, stride=1, padding=1),
19 | nn.LeakyReLU(0.2)
20 | )
21 |
22 | def conv_dec_block(nIn):
23 | return nn.Sequential(
24 | nn.Conv2d(nIn, 128, kernel_size=3, stride=1, padding=1),
25 | nn.LeakyReLU(0.2),
26 | nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
27 | nn.LeakyReLU(0.2),
28 | nn.Conv2d(128, 96, kernel_size=3, stride=1, padding=1),
29 | nn.LeakyReLU(0.2),
30 | nn.Conv2d(96, 64, kernel_size=3, stride=1, padding=1),
31 | nn.LeakyReLU(0.2),
32 | nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
33 | nn.LeakyReLU(0.2),
34 | nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1)
35 | )
36 |
37 |
38 | class Model(nn.Module):
39 | def __init__(self, pretrained=None):
40 | super(Model, self).__init__()
41 |
42 | idx = [list(range(n, -1, -9)) for n in range(80,71,-1)]
43 | idx = list(np.array(idx).flatten())
44 | self.idx_fwd = Variable(torch.LongTensor(np.array(idx)).cuda(), requires_grad=False)
45 | self.idx_bwd = Variable(torch.LongTensor(np.array(list(reversed(idx)))).cuda(), requires_grad=False)
46 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
47 | self.softmax2d = nn.Softmax2d()
48 |
49 | self.conv1a = conv_feat_block(3,16)
50 | self.conv1b = conv_feat_block(3,16)
51 | self.conv1c = conv_feat_block(3,16)
52 |
53 | self.conv2a = conv_feat_block(16,32)
54 | self.conv2b = conv_feat_block(16,32)
55 | self.conv2c = conv_feat_block(16,32)
56 |
57 | self.conv3a = conv_feat_block(32,64)
58 | self.conv3b = conv_feat_block(32,64)
59 | self.conv3c = conv_feat_block(32,64)
60 |
61 | self.conv4a = conv_feat_block(64,96)
62 | self.conv4b = conv_feat_block(64,96)
63 | self.conv4c = conv_feat_block(64,96)
64 |
65 | self.conv5a = conv_feat_block(96,128)
66 | self.conv5b = conv_feat_block(96,128)
67 | self.conv5c = conv_feat_block(96,128)
68 |
69 | self.conv6a = conv_feat_block(128,192)
70 | self.conv6b = conv_feat_block(128,192)
71 | self.conv6c = conv_feat_block(128,192)
72 |
73 | self.corr = Correlation(pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1, corr_multiply=1)
74 |
75 | self.decoder_fwd6 = conv_dec_block(162)
76 | self.decoder_bwd6 = conv_dec_block(162)
77 | self.decoder_fwd5 = conv_dec_block(292)
78 | self.decoder_bwd5 = conv_dec_block(292)
79 | self.decoder_fwd4 = conv_dec_block(260)
80 | self.decoder_bwd4 = conv_dec_block(260)
81 | self.decoder_fwd3 = conv_dec_block(228)
82 | self.decoder_bwd3 = conv_dec_block(228)
83 | self.decoder_fwd2 = conv_dec_block(196)
84 | self.decoder_bwd2 = conv_dec_block(196)
85 |
86 | self.decoder_occ6 = conv_dec_block(354)
87 | self.decoder_occ5 = conv_dec_block(292)
88 | self.decoder_occ4 = conv_dec_block(260)
89 | self.decoder_occ3 = conv_dec_block(228)
90 | self.decoder_occ2 = conv_dec_block(196)
91 |
92 | if pretrained is not None:
93 | self.load_state_dict(torch.load(pretrained))
94 | print('Model loaded from ', pretrained)
95 |
96 | def load(self, key, load_id):
97 | module = getattr(self, key)
98 | loadpath = os.path.join('pretrained/pwc', load_id)  # the SAVE_DIR used by convert.lua
99 | for i,m in enumerate(module):
100 | if type(m)==nn.Conv2d:
101 | weight_path = loadpath + '_' + str(i+1) + 'weight.t7'
102 | bias_path = loadpath + '_' + str(i+1) + 'bias.t7'
103 | m.weight.data.copy_(load_lua(weight_path))
104 | m.bias.data.copy_(load_lua(bias_path))
105 |
106 | def initialize(self):
107 | # This is used for the first load after saving weight files from Lua
108 | map = {'conv1a': '3', 'conv1b': '6', 'conv1c': '22',
109 | 'conv2a': '4', 'conv2b': '7', 'conv2c': '23',
110 | 'conv3a': '9', 'conv3b': '10', 'conv3c': '24',
111 | 'conv4a': '12', 'conv4b': '13', 'conv4c': '25',
112 | 'conv5a': '15', 'conv5b': '16', 'conv5c': '26',
113 | 'conv6a': '18', 'conv6b': '19', 'conv6c': '27',
114 | 'decoder_fwd6': '30', 'decoder_bwd6': '93', 'decoder_occ6': '183',
115 | 'decoder_fwd5': '45', 'decoder_bwd5': '96', 'decoder_occ5': '164',
116 | 'decoder_fwd4': '60', 'decoder_bwd4': '99', 'decoder_occ4': '145',
117 | 'decoder_fwd3': '75', 'decoder_bwd3': '102', 'decoder_occ3': '126',
118 | 'decoder_fwd2': '90', 'decoder_bwd2': '105', 'decoder_occ2': '109'
119 | }
120 |
121 | for key in map.keys():
122 | print("Loading ", key)
123 | self.load(key, map[key])
124 |
125 | def normalize(self, ims):
126 | imt = []
127 | for im in ims:
128 | im[:,0,:,:] = im[:,0,:,:] - 0.485 # Red
129 | im[:,1,:,:] = im[:,1,:,:] - 0.456 # Green
130 | im[:,2,:,:] = im[:,2,:,:] - 0.406 # Blue
131 |
132 | im[:,0,:,:] = im[:,0,:,:] / 0.229 # Red
133 | im[:,1,:,:] = im[:,1,:,:] / 0.224 # Green
134 | im[:,2,:,:] = im[:,2,:,:] / 0.225 # Blue
135 |
136 | imt.append(im)
137 | return imt
138 |
139 | def forward(self, im_tar, im_refs):
140 | '''
141 | Arguments:
142 | im_tar : Centre Frame
143 | im_refs : List containing [Past_Frame, Future_Frame]
144 | '''
145 | im_norm = self.normalize([im_tar] + im_refs)
146 |
147 | feat1a = self.conv1a(im_norm[0])
148 | feat2a = self.conv2a(feat1a)
149 | feat3a = self.conv3a(feat2a)
150 | feat4a = self.conv4a(feat3a)
151 | feat5a = self.conv5a(feat4a)
152 | feat6a = self.conv6a(feat5a)
153 |
154 | feat1b = self.conv1b(im_norm[2])
155 | feat2b = self.conv2b(feat1b)
156 | feat3b = self.conv3b(feat2b)
157 | feat4b = self.conv4b(feat3b)
158 | feat5b = self.conv5b(feat4b)
159 | feat6b = self.conv6b(feat5b)
160 |
161 | feat1c = self.conv1c(im_norm[1])
162 | feat2c = self.conv2c(feat1c)
163 | feat3c = self.conv3c(feat2c)
164 | feat4c = self.conv4c(feat3c)
165 | feat5c = self.conv5c(feat4c)
166 | feat6c = self.conv6c(feat5c)
167 |
168 | corr6_fwd = self.corr(feat6a, feat6b)
169 | corr6_fwd = corr6_fwd.index_select(1,self.idx_fwd)
170 | corr6_bwd = self.corr(feat6a, feat6c)
171 | corr6_bwd = corr6_bwd.index_select(1,self.idx_bwd)
172 | corr6 = torch.cat((corr6_fwd, corr6_bwd), 1)
173 |
174 | flow6_fwd = self.decoder_fwd6(corr6)
175 | flow6_fwd_up = self.upsample(flow6_fwd)
176 | flow6_bwd = self.decoder_bwd6(corr6)
177 | flow6_bwd_up = self.upsample(flow6_bwd)
178 | feat5b_warped = self.warp(feat5b, 0.625*flow6_fwd_up)
179 | feat5c_warped = self.warp(feat5c, -0.625*flow6_bwd_up)
180 |
181 | occ6_feat = torch.cat((corr6, feat6a), 1)
182 | occ6 = self.softmax2d(self.decoder_occ6(occ6_feat))
183 |
184 | corr5_fwd = self.corr(feat5a, feat5b_warped)
185 | corr5_fwd = corr5_fwd.index_select(1,self.idx_fwd)
186 | corr5_bwd = self.corr(feat5a, feat5c_warped)
187 | corr5_bwd = corr5_bwd.index_select(1,self.idx_bwd)
188 | corr5 = torch.cat((corr5_fwd, corr5_bwd), 1)
189 |
190 | upfeat5_fwd = torch.cat((corr5, feat5a, flow6_fwd_up), 1)
191 | flow5_fwd = self.decoder_fwd5(upfeat5_fwd)
192 | flow5_fwd_up = self.upsample(flow5_fwd)
193 | upfeat5_bwd = torch.cat((corr5, feat5a, flow6_bwd_up),1)
194 | flow5_bwd = self.decoder_bwd5(upfeat5_bwd)
195 | flow5_bwd_up = self.upsample(flow5_bwd)
196 | feat4b_warped = self.warp(feat4b, 1.25*flow5_fwd_up)
197 | feat4c_warped = self.warp(feat4c, -1.25*flow5_bwd_up)
198 |
199 | occ5 = self.softmax2d(self.decoder_occ5(upfeat5_fwd))
200 |
201 | corr4_fwd = self.corr(feat4a, feat4b_warped)
202 | corr4_fwd = corr4_fwd.index_select(1,self.idx_fwd)
203 | corr4_bwd = self.corr(feat4a, feat4c_warped)
204 | corr4_bwd = corr4_bwd.index_select(1,self.idx_bwd)
205 | corr4 = torch.cat((corr4_fwd, corr4_bwd), 1)
206 |
207 | upfeat4_fwd = torch.cat((corr4, feat4a, flow5_fwd_up), 1)
208 | flow4_fwd = self.decoder_fwd4(upfeat4_fwd)
209 | flow4_fwd_up = self.upsample(flow4_fwd)
210 | upfeat4_bwd = torch.cat((corr4, feat4a, flow5_bwd_up),1)
211 | flow4_bwd = self.decoder_bwd4(upfeat4_bwd)
212 | flow4_bwd_up = self.upsample(flow4_bwd)
213 | feat3b_warped = self.warp(feat3b, 2.5*flow4_fwd_up)
214 | feat3c_warped = self.warp(feat3c, -2.5*flow4_bwd_up)
215 |
216 | occ4 = self.softmax2d(self.decoder_occ4(upfeat4_fwd))
217 |
218 | corr3_fwd = self.corr(feat3a, feat3b_warped)
219 | corr3_fwd = corr3_fwd.index_select(1,self.idx_fwd)
220 | corr3_bwd = self.corr(feat3a, feat3c_warped)
221 | corr3_bwd = corr3_bwd.index_select(1,self.idx_bwd)
222 | corr3 = torch.cat((corr3_fwd, corr3_bwd), 1)
223 |
224 | upfeat3_fwd = torch.cat((corr3, feat3a, flow4_fwd_up), 1)
225 | flow3_fwd = self.decoder_fwd3(upfeat3_fwd)
226 | flow3_fwd_up = self.upsample(flow3_fwd)
227 | upfeat3_bwd = torch.cat((corr3, feat3a, flow4_bwd_up),1)
228 | flow3_bwd = self.decoder_bwd3(upfeat3_bwd)
229 | flow3_bwd_up = self.upsample(flow3_bwd)
230 | feat2b_warped = self.warp(feat2b, 5.0*flow3_fwd_up)
231 | feat2c_warped = self.warp(feat2c, -5.0*flow3_bwd_up)
232 |
233 | occ3 = self.softmax2d(self.decoder_occ3(upfeat3_fwd))
234 |
235 | corr2_fwd = self.corr(feat2a, feat2b_warped)
236 | corr2_fwd = corr2_fwd.index_select(1,self.idx_fwd)
237 | corr2_bwd = self.corr(feat2a, feat2c_warped)
238 | corr2_bwd = corr2_bwd.index_select(1,self.idx_bwd)
239 | corr2 = torch.cat((corr2_fwd, corr2_bwd), 1)
240 |
241 | upfeat2_fwd = torch.cat((corr2, feat2a, flow3_fwd_up), 1)
242 | flow2_fwd = self.decoder_fwd2(upfeat2_fwd)
243 | flow2_fwd_up = self.upsample(flow2_fwd)
244 | upfeat2_bwd = torch.cat((corr2, feat2a, flow3_bwd_up),1)
245 | flow2_bwd = self.decoder_bwd2(upfeat2_bwd)
246 | flow2_bwd_up = self.upsample(flow2_bwd)
247 |
248 | occ2 = self.softmax2d(self.decoder_occ2(upfeat2_fwd))
249 |
250 | flow2_fwd_fullres = 20*self.upsample(flow2_fwd_up)
251 | flow3_fwd_fullres = 10*self.upsample(flow3_fwd_up)
252 | flow4_fwd_fullres = 5*self.upsample(flow4_fwd_up)
253 | flow5_fwd_fullres = 2.5*self.upsample(flow5_fwd_up)
254 | flow6_fwd_fullres = 1.25*self.upsample(flow6_fwd_up)
255 |
256 | flow2_bwd_fullres = -20*self.upsample(flow2_bwd_up)
257 | flow3_bwd_fullres = -10*self.upsample(flow3_bwd_up)
258 | flow4_bwd_fullres = -5*self.upsample(flow4_bwd_up)
259 | flow5_bwd_fullres = -2.5*self.upsample(flow5_bwd_up)
260 | flow6_bwd_fullres = -1.25*self.upsample(flow6_bwd_up)
261 |
262 | occ2_fullres = F.upsample(occ2, scale_factor=4)
263 | occ3_fullres = F.upsample(occ3, scale_factor=4)
264 | occ4_fullres = F.upsample(occ4, scale_factor=4)
265 | occ5_fullres = F.upsample(occ5, scale_factor=4)
266 | occ6_fullres = F.upsample(occ6, scale_factor=4)
267 |
268 | flow_fwd = [flow2_fwd_fullres, flow3_fwd_fullres, flow4_fwd_fullres, flow5_fwd_fullres, flow6_fwd_fullres]
269 | flow_bwd = [flow2_bwd_fullres, flow3_bwd_fullres, flow4_bwd_fullres, flow5_bwd_fullres, flow6_bwd_fullres]
270 | occ = [occ2_fullres, occ3_fullres, occ4_fullres, occ5_fullres, occ6_fullres]
271 |
272 | return flow_fwd, flow_bwd, occ
273 |
274 | def warp(self, x, flo):
275 | """
276 | warp an image/tensor (im2) back to im1, according to the optical flow
277 | x: [B, C, H, W] (im2)
278 | flo: [B, 2, H, W] flow
279 | """
280 | B, C, H, W = x.size()
281 | # mesh grid
282 | xx = torch.arange(0, W).view(1,-1).repeat(H,1)
283 | yy = torch.arange(0, H).view(-1,1).repeat(1,W)
284 | xx = xx.view(1,1,H,W).repeat(B,1,1,1)
285 | yy = yy.view(1,1,H,W).repeat(B,1,1,1)
286 | grid = torch.cat((xx,yy),1).float()
287 |
288 | if x.is_cuda:
289 | grid = grid.cuda()
290 | vgrid = Variable(grid) + flo
291 |
292 | # scale grid to [-1,1]
293 | vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:]/max(W-1,1)-1.0
294 | vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:]/max(H-1,1)-1.0
295 |
296 | vgrid = vgrid.permute(0,2,3,1)
297 | output = nn.functional.grid_sample(x, vgrid, padding_mode='border')
298 | mask = torch.autograd.Variable(torch.ones(x.size()), requires_grad=False).cuda()
299 | mask = nn.functional.grid_sample(mask, vgrid)
300 |
301 | # if W==128:
302 | # np.save('mask.npy', mask.cpu().data.numpy())
303 | # np.save('warp.npy', output.cpu().data.numpy())
304 |
305 | mask[mask.data<0.9999] = 0
306 | mask[mask.data>0] = 1
307 |
308 | return output  # *mask (masking disabled)
309 |
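Editor's note: the `idx_fwd`/`idx_bwd` buffers built in `__init__` permute the 81 channels of the 9×9 correlation cost volume before it is fed to the decoders. A standalone check (it assumes the `9*dy + dx` channel layout produced by the CUDA kernel):
```python
import numpy as np

# Reproduce the index construction from Model.__init__
idx = np.array([list(range(n, -1, -9)) for n in range(80, 71, -1)]).flatten()

assert sorted(idx) == list(range(81))   # a permutation of the 81 cost channels
grid = np.arange(81).reshape(9, 9)      # channel = 9*dy + dx over a 9x9 displacement grid
print(np.array_equal(idx.reshape(9, 9), grid[::-1, ::-1].T))   # True: transpose + full flip
print(np.array_equal(idx[::-1].reshape(9, 9), grid.T))         # idx_bwd: plain transpose
```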
--------------------------------------------------------------------------------
/convert.lua:
--------------------------------------------------------------------------------
1 | -- Author: Anurag Ranjan
2 | -- Copyright (c) 2018, Max Planck Society
3 |
4 | require 'paths'
5 | require 'torch'
6 | require 'cutorch'
7 | require 'nn'
8 | require 'cunn'
9 | require 'cudnn'
10 | require 'stn'
11 | require 'spy'
12 | require 'nngraph'
13 | require 'models/CostVolMulti'
14 |
15 | SAVE_DIR = 'pretrained/pwc'
16 | model = torch.load('pretrained/Roaming_KITTI_model_300.t7' )
17 |
18 | function save_sequential(name, model)
19 | for i = 1, #model do
20 | module = model:get(i)
21 | if tostring(torch.type(module)) == 'nn.SpatialConvolution' then
22 | torch.save(paths.concat(SAVE_DIR, name..'_'..tostring(i)..'weight.t7'), module.weight)
23 | torch.save(paths.concat(SAVE_DIR, name..'_'..tostring(i)..'bias.t7'), module.bias)
24 | end
25 | end
26 | end
27 |
28 | for i = 1, 200 do
29 | print('Traversing node ' .. i)
30 | node = model:get(i)
31 | if tostring(torch.type(node)) == 'nn.Sequential' then
32 | nodenn = cudnn.convert(node, nn)
33 | nodenn_float = nodenn:float()
34 | save_sequential(tostring(i), nodenn_float)
35 | end
36 | end
37 |
38 | function warpingUnit()
39 | local I = nn.Identity()()
40 | local F = nn.Identity()()
41 | local input = I - nn.Transpose({2,3}, {3,4})
42 | local flow = F - nn.Transpose({2,3}, {3,4})
43 | local W = {input, flow} - nn.BilinearSamplerBHWD() - nn.Transpose({3,4}, {2,3})
44 | local model = nn.gModule({I, F}, {W})
45 | return model
46 | end
47 |
48 | for k, v in ipairs(model.forwardnodes) do
49 | print(k-1, v.id, v.data.module)
50 | v.data.annotations.name = tostring(k-1)
51 | end
52 |
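The exported `.t7` tensors are what `Model.load` in `back2future.py` reads back with `load_lua`; a sketch of inspecting one exported blob (the path follows the `SAVE_DIR` naming scheme above):
```python
from torch.utils.serialization import load_lua  # available up to PyTorch 0.4.x

# convert.lua saves conv i of graph node <id> as '<SAVE_DIR>/<id>_<i>weight.t7' (and bias)
w = load_lua('pretrained/pwc/3_1weight.t7')
b = load_lua('pretrained/pwc/3_1bias.t7')
print(w.size(), b.size())
```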
--------------------------------------------------------------------------------
/correlation_package/.make.sh.swp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/correlation_package/.make.sh.swp
--------------------------------------------------------------------------------
/correlation_package/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/correlation_package/__init__.py
--------------------------------------------------------------------------------
/correlation_package/build.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.ffi
4 |
5 | this_folder = os.path.dirname(os.path.abspath(__file__)) + '/'
6 |
7 | Headers = []
8 | Sources = []
9 | Defines = []
10 | Objects = []
11 |
12 | if torch.cuda.is_available():
13 | Headers += ['src/correlation_cuda.h']
14 | Sources += ['src/correlation_cuda.c']
15 | Defines += [('WITH_CUDA', None)]
16 | Objects += ['src/correlation_cuda_kernel.o']
17 |
18 | ffi = torch.utils.ffi.create_extension(
19 | name='_ext.correlation',
20 | headers=Headers,
21 | sources=Sources,
22 | verbose=False,
23 | with_cuda=True,
24 | package=False,
25 | relative_to=this_folder,
26 | define_macros=Defines,
27 | extra_objects=[os.path.join(this_folder, Object) for Object in Objects]
28 | )
29 |
30 | if __name__ == '__main__':
31 | ffi.build()
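Once `make.sh` has run this script, the compiled extension should be importable; a quick smoke test (run from the repository root):
```python
# functions/correlation.py calls these two entry points
from correlation_package._ext import correlation

print(hasattr(correlation, 'Correlation_forward_cuda'))   # True if the build succeeded
print(hasattr(correlation, 'Correlation_backward_cuda'))  # True
```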
--------------------------------------------------------------------------------
/correlation_package/functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/correlation_package/functions/__init__.py
--------------------------------------------------------------------------------
/correlation_package/functions/correlation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from .._ext import correlation
4 |
5 |
6 | class CorrelationFunction(Function):
7 |
8 | def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
9 | super(CorrelationFunction, self).__init__()
10 | self.pad_size = pad_size
11 | self.kernel_size = kernel_size
12 | self.max_displacement = max_displacement
13 | self.stride1 = stride1
14 | self.stride2 = stride2
15 | self.corr_multiply = corr_multiply
16 | # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1)
17 |
18 | def forward(self, input1, input2):
19 | self.save_for_backward(input1, input2)
20 |
21 | assert input1.is_contiguous()
22 | assert input2.is_contiguous()
23 |
24 | with torch.cuda.device_of(input1):
25 | rbot1 = input1.new()
26 | rbot2 = input2.new()
27 | output = input1.new()
28 |
29 | correlation.Correlation_forward_cuda(input1, input2, rbot1, rbot2, output,
30 | self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)
31 |
32 | return output
33 |
34 | def backward(self, grad_output):
35 | input1, input2 = self.saved_tensors
36 |
37 | assert grad_output.is_contiguous()
38 |
39 | with torch.cuda.device_of(input1):
40 | rbot1 = input1.new()
41 | rbot2 = input2.new()
42 |
43 | grad_input1 = input1.new()
44 | grad_input2 = input2.new()
45 |
46 | correlation.Correlation_backward_cuda(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
47 | self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)
48 |
49 | return grad_input1, grad_input2
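The commented-out `out_channel` formula matches the CUDA side; evaluated for the parameters `back2future.py` passes in (a small editorial check):
```python
def corr_out_channels(max_displacement, stride2):
    side = (max_displacement // stride2) * 2 + 1
    return side * side

# Model uses Correlation(pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1)
print(corr_out_channels(max_displacement=4, stride2=1))  # 81, i.e. a 9x9 displacement grid
```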
--------------------------------------------------------------------------------
/correlation_package/make.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | TORCH=$(python -c "import os; import torch; print(os.path.dirname(torch.__file__))")
3 |
4 | cd src
5 |
6 | echo "Compiling correlation kernels by nvcc..."
7 |
8 | rm -f correlation_cuda_kernel.o
9 | rm -rf ../_ext
10 |
11 | nvcc -c -o correlation_cuda_kernel.o correlation_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
12 |
13 | cd ../
14 | python build.py
15 |
--------------------------------------------------------------------------------
/correlation_package/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/correlation_package/modules/__init__.py
--------------------------------------------------------------------------------
/correlation_package/modules/correlation.py:
--------------------------------------------------------------------------------
1 | from torch.nn.modules.module import Module
2 |
3 | from ..functions.correlation import CorrelationFunction
4 |
5 | class Correlation(Module):
6 | def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):
7 | super(Correlation, self).__init__()
8 | self.pad_size = pad_size
9 | self.kernel_size = kernel_size
10 | self.max_displacement = max_displacement
11 | self.stride1 = stride1
12 | self.stride2 = stride2
13 | self.corr_multiply = corr_multiply
14 |
15 | def forward(self, input1, input2):
16 |
17 | result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)(input1, input2)
18 |
19 | return result
20 |
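For intuition and CPU-side sanity checks, here is a NumPy sketch of what this module computes for the configuration used in `back2future.py` (`kernel_size=1`, unit strides); an editorial reference, not part of the package:
```python
import numpy as np

def correlation_ref(f1, f2, max_disp=4, pad_size=4):
    """Reference for Correlation with kernel_size=1 and stride1=stride2=1.

    f1, f2: (C, H, W) feature maps. The output channel for displacement
    (dy, dx) holds the channel-wise mean of f1[:, y, x] * f2[:, y+dy, x+dx].
    """
    C, H, W = f1.shape
    side = 2 * max_disp + 1
    f2p = np.pad(f2, ((0, 0), (pad_size, pad_size), (pad_size, pad_size)), mode='constant')
    out = np.zeros((side * side, H, W), dtype=f1.dtype)
    for j, dy in enumerate(range(-max_disp, max_disp + 1)):
        for i, dx in enumerate(range(-max_disp, max_disp + 1)):
            win = f2p[:, pad_size + dy:pad_size + dy + H, pad_size + dx:pad_size + dx + W]
            out[j * side + i] = (f1 * win).mean(axis=0)
    return out

print(correlation_ref(np.random.rand(16, 8, 8), np.random.rand(16, 8, 8)).shape)  # (81, 8, 8)
```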
--------------------------------------------------------------------------------
/correlation_package/src/correlation.c:
--------------------------------------------------------------------------------
1 | #include <TH/TH.h>
2 |
3 | int Correlation_forward_cpu(THFloatTensor *input1,
4 | THFloatTensor *input2,
5 | THFloatTensor *rInput1,
6 | THFloatTensor *rInput2,
7 | THFloatTensor *output,
8 | int pad_size,
9 | int kernel_size,
10 | int max_displacement,
11 | int stride1,
12 | int stride2,
13 | int corr_type_multiply)
14 | {
15 | return 1;
16 | }
17 |
18 | int Correlation_backward_cpu(THFloatTensor *input1,
19 | THFloatTensor *input2,
20 | THFloatTensor *rInput1,
21 | THFloatTensor *rInput2,
22 | THFloatTensor *gradOutput,
23 | THFloatTensor *gradInput1,
24 | THFloatTensor *gradInput2,
25 | int pad_size,
26 | int kernel_size,
27 | int max_displacement,
28 | int stride1,
29 | int stride2,
30 | int corr_type_multiply)
31 | {
32 | return 1;
33 | }
34 |
--------------------------------------------------------------------------------
/correlation_package/src/correlation.h:
--------------------------------------------------------------------------------
1 | int Correlation_forward_cpu(THFloatTensor *input1,
2 | THFloatTensor *input2,
3 | THFloatTensor *rInput1,
4 | THFloatTensor *rInput2,
5 | THFloatTensor *output,
6 | int pad_size,
7 | int kernel_size,
8 | int max_displacement,
9 | int stride1,
10 | int stride2,
11 | int corr_type_multiply);
12 |
13 | int Correlation_backward_cpu(THFloatTensor *input1,
14 | THFloatTensor *input2,
15 | THFloatTensor *rInput1,
16 | THFloatTensor *rInput2,
17 | THFloatTensor *gradOutput,
18 | THFloatTensor *gradInput1,
19 | THFloatTensor *gradInput2,
20 | int pad_size,
21 | int kernel_size,
22 | int max_displacement,
23 | int stride1,
24 | int stride2,
25 | int corr_type_multiply);
26 |
--------------------------------------------------------------------------------
/correlation_package/src/correlation_cuda.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 | #include <math.h>
3 |
4 | #include "correlation_cuda_kernel.h"
5 |
6 | #define real float
7 |
8 | // symbol to be automatically resolved by PyTorch libs
9 | extern THCState *state;
10 |
11 | int Correlation_forward_cuda(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *rInput1, THCudaTensor *rInput2, THCudaTensor *output,
12 | int pad_size,
13 | int kernel_size,
14 | int max_displacement,
15 | int stride1,
16 | int stride2,
17 | int corr_type_multiply)
18 | {
19 |
20 | int batchSize = input1->size[0];
21 | int nInputChannels = input1->size[1];
22 | int inputHeight = input1->size[2];
23 | int inputWidth = input1->size[3];
24 |
25 | int kernel_radius = (kernel_size - 1) / 2;
26 | int border_radius = kernel_radius + max_displacement;
27 |
28 | int paddedInputHeight = inputHeight + 2 * pad_size;
29 | int paddedInputWidth = inputWidth + 2 * pad_size;
30 |
31 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);
32 |
33 | int outputHeight = ceil((float)(paddedInputHeight - 2 * border_radius) / (float)stride1);
34 | int outputWidth = ceil((float)(paddedInputWidth - 2 * border_radius) / (float)stride1);
35 |
36 | THCudaTensor_resize4d(state, rInput1, batchSize, paddedInputHeight, paddedInputWidth, nInputChannels);
37 | THCudaTensor_resize4d(state, rInput2, batchSize, paddedInputHeight, paddedInputWidth, nInputChannels);
38 | THCudaTensor_resize4d(state, output, batchSize, nOutputChannels, outputHeight, outputWidth);
39 |
40 | THCudaTensor_fill(state, rInput1, 0);
41 | THCudaTensor_fill(state, rInput2, 0);
42 | THCudaTensor_fill(state, output, 0);
43 |
44 | int success = 0;
45 | success = Correlation_forward_cuda_kernel( THCudaTensor_data(state, output),
46 | THCudaTensor_size(state, output, 0),
47 | THCudaTensor_size(state, output, 1),
48 | THCudaTensor_size(state, output, 2),
49 | THCudaTensor_size(state, output, 3),
50 | THCudaTensor_stride(state, output, 0),
51 | THCudaTensor_stride(state, output, 1),
52 | THCudaTensor_stride(state, output, 2),
53 | THCudaTensor_stride(state, output, 3),
54 |
55 | THCudaTensor_data(state, input1),
56 | THCudaTensor_size(state, input1, 1),
57 | THCudaTensor_size(state, input1, 2),
58 | THCudaTensor_size(state, input1, 3),
59 | THCudaTensor_stride(state, input1, 0),
60 | THCudaTensor_stride(state, input1, 1),
61 | THCudaTensor_stride(state, input1, 2),
62 | THCudaTensor_stride(state, input1, 3),
63 |
64 | THCudaTensor_data(state, input2),
65 | THCudaTensor_size(state, input2, 1),
66 | THCudaTensor_stride(state, input2, 0),
67 | THCudaTensor_stride(state, input2, 1),
68 | THCudaTensor_stride(state, input2, 2),
69 | THCudaTensor_stride(state, input2, 3),
70 |
71 | THCudaTensor_data(state, rInput1),
72 | THCudaTensor_data(state, rInput2),
73 |
74 | pad_size,
75 | kernel_size,
76 | max_displacement,
77 | stride1,
78 | stride2,
79 | corr_type_multiply,
80 |
81 | THCState_getCurrentStream(state));
82 |
83 | THCudaTensor_free(state, rInput1);
84 | THCudaTensor_free(state, rInput2);
85 |
86 | //check for errors
87 | if (!success) {
88 | THError("aborting");
89 | }
90 |
91 | return 1;
92 |
93 | }
94 |
95 | int Correlation_backward_cuda(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *rInput1, THCudaTensor *rInput2, THCudaTensor *gradOutput,
96 | THCudaTensor *gradInput1, THCudaTensor *gradInput2,
97 | int pad_size,
98 | int kernel_size,
99 | int max_displacement,
100 | int stride1,
101 | int stride2,
102 | int corr_type_multiply)
103 | {
104 |
105 | int batchSize = input1->size[0];
106 | int nInputChannels = input1->size[1];
107 | int paddedInputHeight = input1->size[2]+ 2 * pad_size;
108 | int paddedInputWidth = input1->size[3]+ 2 * pad_size;
109 |
110 | int height = input1->size[2];
111 | int width = input1->size[3];
112 |
113 | THCudaTensor_resize4d(state, rInput1, batchSize, paddedInputHeight, paddedInputWidth, nInputChannels);
114 | THCudaTensor_resize4d(state, rInput2, batchSize, paddedInputHeight, paddedInputWidth, nInputChannels);
115 | THCudaTensor_resize4d(state, gradInput1, batchSize, nInputChannels, height, width);
116 | THCudaTensor_resize4d(state, gradInput2, batchSize, nInputChannels, height, width);
117 |
118 | THCudaTensor_fill(state, rInput1, 0);
119 | THCudaTensor_fill(state, rInput2, 0);
120 | THCudaTensor_fill(state, gradInput1, 0);
121 | THCudaTensor_fill(state, gradInput2, 0);
122 |
123 | int success = 0;
124 | success = Correlation_backward_cuda_kernel(
125 | THCudaTensor_data(state, gradOutput),
126 | THCudaTensor_size(state, gradOutput, 0),
127 | THCudaTensor_size(state, gradOutput, 1),
128 | THCudaTensor_size(state, gradOutput, 2),
129 | THCudaTensor_size(state, gradOutput, 3),
130 | THCudaTensor_stride(state, gradOutput, 0),
131 | THCudaTensor_stride(state, gradOutput, 1),
132 | THCudaTensor_stride(state, gradOutput, 2),
133 | THCudaTensor_stride(state, gradOutput, 3),
134 |
135 | THCudaTensor_data(state, input1),
136 | THCudaTensor_size(state, input1, 1),
137 | THCudaTensor_size(state, input1, 2),
138 | THCudaTensor_size(state, input1, 3),
139 | THCudaTensor_stride(state, input1, 0),
140 | THCudaTensor_stride(state, input1, 1),
141 | THCudaTensor_stride(state, input1, 2),
142 | THCudaTensor_stride(state, input1, 3),
143 |
144 | THCudaTensor_data(state, input2),
145 | THCudaTensor_stride(state, input2, 0),
146 | THCudaTensor_stride(state, input2, 1),
147 | THCudaTensor_stride(state, input2, 2),
148 | THCudaTensor_stride(state, input2, 3),
149 |
150 | THCudaTensor_data(state, gradInput1),
151 | THCudaTensor_stride(state, gradInput1, 0),
152 | THCudaTensor_stride(state, gradInput1, 1),
153 | THCudaTensor_stride(state, gradInput1, 2),
154 | THCudaTensor_stride(state, gradInput1, 3),
155 |
156 | THCudaTensor_data(state, gradInput2),
157 | THCudaTensor_size(state, gradInput2, 1),
158 | THCudaTensor_stride(state, gradInput2, 0),
159 | THCudaTensor_stride(state, gradInput2, 1),
160 | THCudaTensor_stride(state, gradInput2, 2),
161 | THCudaTensor_stride(state, gradInput2, 3),
162 |
163 | THCudaTensor_data(state, rInput1),
164 | THCudaTensor_data(state, rInput2),
165 | pad_size,
166 | kernel_size,
167 | max_displacement,
168 | stride1,
169 | stride2,
170 | corr_type_multiply,
171 | THCState_getCurrentStream(state));
172 |
173 | THCudaTensor_free(state, rInput1);
174 | THCudaTensor_free(state, rInput2);
175 |
176 | if (!success) {
177 | THError("aborting");
178 | }
179 | return 1;
180 | }
181 |
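The shape arithmetic in `Correlation_forward_cuda`, restated in Python (an editorial aid, using the same rounding):
```python
import math

def corr_output_shape(h, w, pad_size=4, kernel_size=1, max_displacement=4,
                      stride1=1, stride2=1):
    kernel_radius = (kernel_size - 1) // 2
    border_radius = kernel_radius + max_displacement
    side = (max_displacement // stride2) * 2 + 1
    out_h = int(math.ceil(float(h + 2 * pad_size - 2 * border_radius) / stride1))
    out_w = int(math.ceil(float(w + 2 * pad_size - 2 * border_radius) / stride1))
    return side * side, out_h, out_w

# With back2future's settings pad_size == border_radius, so H and W are preserved
print(corr_output_shape(32, 104))  # (81, 32, 104)
```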
--------------------------------------------------------------------------------
/correlation_package/src/correlation_cuda.h:
--------------------------------------------------------------------------------
1 | int Correlation_forward_cuda(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *rInput1, THCudaTensor *rInput2,
2 | THCudaTensor *output,
3 | int pad_size,
4 | int kernel_size,
5 | int max_displacement,
6 | int stride1,
7 | int stride2,
8 | int corr_type_multiply);
9 |
10 | int Correlation_backward_cuda(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *rInput1, THCudaTensor *rInput2,
11 | THCudaTensor *gradOutput, THCudaTensor *gradInput1, THCudaTensor *gradInput2,
12 | int pad_size,
13 | int kernel_size,
14 | int max_displacement,
15 | int stride1,
16 | int stride2,
17 | int corr_type_multiply);
18 |
19 |
--------------------------------------------------------------------------------
/correlation_package/src/correlation_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 |
3 | #include "correlation_cuda_kernel.h"
4 |
5 | #define real float
6 |
7 | #define CUDA_NUM_THREADS 1024
8 | #define THREADS_PER_BLOCK 32
9 |
10 | __global__ void channels_first(float* input, float* rinput, int channels, int height, int width, int pad_size)
11 | {
12 | // n (batch size), c (num of channels), y (height), x (width)
13 | int n = blockIdx.x;
14 | int y = blockIdx.y;
15 | int x = blockIdx.z;
16 |
17 | int ch_off = threadIdx.x;
18 | float value;
19 |
20 | int dimcyx = channels * height * width;
21 | int dimyx = height * width;
22 |
23 | int p_dimx = (width + 2 * pad_size);
24 | int p_dimy = (height + 2 * pad_size);
25 | int p_dimyxc = channels * p_dimy * p_dimx;
26 | int p_dimxc = p_dimx * channels;
27 |
28 | for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) {
29 | value = input[n * dimcyx + c * dimyx + y * width + x];
30 | rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value;
31 | }
32 | }
33 |
34 | __global__ void Correlation_forward( float *output, int nOutputChannels, int outputHeight, int outputWidth,
35 | float *rInput1, int nInputChannels, int inputHeight, int inputWidth,
36 | float *rInput2,
37 | int pad_size,
38 | int kernel_size,
39 | int max_displacement,
40 | int stride1,
41 | int stride2)
42 | {
43 | // n (batch size), c (num of channels), y (height), x (width)
44 |
45 | int pInputWidth = inputWidth + 2 * pad_size;
46 | int pInputHeight = inputHeight + 2 * pad_size;
47 |
48 | int kernel_rad = (kernel_size - 1) / 2;
49 | int displacement_rad = max_displacement / stride2;
50 | int displacement_size = 2 * displacement_rad + 1;
51 |
52 | int n = blockIdx.x;
53 | int y1 = blockIdx.y * stride1 + max_displacement + kernel_rad;
54 | int x1 = blockIdx.z * stride1 + max_displacement + kernel_rad;
55 | int c = threadIdx.x;
56 |
57 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
58 | int pdimxc = pInputWidth * nInputChannels;
59 | int pdimc = nInputChannels;
60 |
61 | int tdimcyx = nOutputChannels * outputHeight * outputWidth;
62 | int tdimyx = outputHeight * outputWidth;
63 | int tdimx = outputWidth;
64 |
65 | float nelems = kernel_size * kernel_size * pdimc;
66 |
67 | __shared__ float prod_sum[THREADS_PER_BLOCK];
68 |
69 |   // no significant speed-up from using on-chip memory for input1 sub-data,
70 |   // and not enough on-chip memory to accommodate per-block input2 sub-data,
71 |   // so device memory is used for both instead
72 |
73 | // element-wise product along channel axis
74 | for (int tj = -displacement_rad; tj <= displacement_rad; ++tj ) {
75 | for (int ti = -displacement_rad; ti <= displacement_rad; ++ti ) {
76 | prod_sum[c] = 0;
77 | int x2 = x1 + ti*stride2;
78 | int y2 = y1 + tj*stride2;
79 |
80 | for (int j = -kernel_rad; j <= kernel_rad; ++j) {
81 | for (int i = -kernel_rad; i <= kernel_rad; ++i) {
82 | for (int ch = c; ch < pdimc; ch += THREADS_PER_BLOCK) {
83 | int indx1 = n * pdimyxc + (y1+j) * pdimxc + (x1 + i) * pdimc + ch;
84 | int indx2 = n * pdimyxc + (y2+j) * pdimxc + (x2 + i) * pdimc + ch;
85 |
86 | prod_sum[c] += rInput1[indx1] * rInput2[indx2];
87 | }
88 | }
89 | }
90 |
91 | // accumulate
92 | __syncthreads();
93 | if (c == 0) {
94 | float reduce_sum = 0;
95 | for (int index = 0; index < THREADS_PER_BLOCK; ++index) {
96 | reduce_sum += prod_sum[index];
97 | }
98 | int tc = (tj + displacement_rad) * displacement_size + (ti + displacement_rad);
99 | const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx + blockIdx.z;
100 | output[tindx] = reduce_sum / nelems;
101 | }
102 |
103 | }
104 | }
105 |
106 | }
107 |
108 | __global__ void Correlation_backward_input1(int item, float *gradInput1, int nInputChannels, int inputHeight, int inputWidth,
109 | float *gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
110 | float *rInput2,
111 | int pad_size,
112 | int kernel_size,
113 | int max_displacement,
114 | int stride1,
115 | int stride2)
116 | {
117 | // n (batch size), c (num of channels), y (height), x (width)
118 |
119 | int n = item;
120 | int y = blockIdx.x * stride1 + pad_size;
121 | int x = blockIdx.y * stride1 + pad_size;
122 | int c = blockIdx.z;
123 | int tch_off = threadIdx.x;
124 |
125 | int kernel_rad = (kernel_size - 1) / 2;
126 | int displacement_rad = max_displacement / stride2;
127 | int displacement_size = 2 * displacement_rad + 1;
128 |
129 | int xmin = (x - kernel_rad - max_displacement) / stride1;
130 | int ymin = (y - kernel_rad - max_displacement) / stride1;
131 |
132 | int xmax = (x + kernel_rad - max_displacement) / stride1;
133 | int ymax = (y + kernel_rad - max_displacement) / stride1;
134 |
135 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
136 | // assumes gradInput1 is pre-allocated and zero filled
137 | return;
138 | }
139 |
140 | if (xmin > xmax || ymin > ymax) {
141 | // assumes gradInput1 is pre-allocated and zero filled
142 | return;
143 | }
144 |
145 | xmin = max(0,xmin);
146 | xmax = min(outputWidth-1,xmax);
147 |
148 | ymin = max(0,ymin);
149 | ymax = min(outputHeight-1,ymax);
150 |
151 | int pInputWidth = inputWidth + 2 * pad_size;
152 | int pInputHeight = inputHeight + 2 * pad_size;
153 |
154 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
155 | int pdimxc = pInputWidth * nInputChannels;
156 | int pdimc = nInputChannels;
157 |
158 | int tdimcyx = nOutputChannels * outputHeight * outputWidth;
159 | int tdimyx = outputHeight * outputWidth;
160 | int tdimx = outputWidth;
161 |
162 | int odimcyx = nInputChannels * inputHeight* inputWidth;
163 | int odimyx = inputHeight * inputWidth;
164 | int odimx = inputWidth;
165 |
166 | float nelems = kernel_size * kernel_size * nInputChannels;
167 |
168 | __shared__ float prod_sum[CUDA_NUM_THREADS];
169 | prod_sum[tch_off] = 0;
170 |
171 | for (int tc = tch_off; tc < nOutputChannels; tc += CUDA_NUM_THREADS) {
172 |
173 | int i2 = (tc % displacement_size - displacement_rad) * stride2;
174 | int j2 = (tc / displacement_size - displacement_rad) * stride2;
175 |
176 | int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c;
177 |
178 | float val2 = rInput2[indx2];
179 |
180 | for (int j = ymin; j <= ymax; ++j) {
181 | for (int i = xmin; i <= xmax; ++i) {
182 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
183 | prod_sum[tch_off] += gradOutput[tindx] * val2;
184 | }
185 | }
186 | }
187 | __syncthreads();
188 |
189 | if(tch_off == 0) {
190 | float reduce_sum = 0;
191 | for(int idx = 0; idx < CUDA_NUM_THREADS; idx++) {
192 | reduce_sum += prod_sum[idx];
193 | }
194 | const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
195 | gradInput1[indx1] = reduce_sum / nelems;
196 | }
197 |
198 | }
199 |
200 | __global__ void Correlation_backward_input2(int item, float *gradInput2, int nInputChannels, int inputHeight, int inputWidth,
201 | float *gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
202 | float *rInput1,
203 | int pad_size,
204 | int kernel_size,
205 | int max_displacement,
206 | int stride1,
207 | int stride2)
208 | {
209 | // n (batch size), c (num of channels), y (height), x (width)
210 |
211 | int n = item;
212 | int y = blockIdx.x * stride1 + pad_size;
213 | int x = blockIdx.y * stride1 + pad_size;
214 | int c = blockIdx.z;
215 |
216 | int tch_off = threadIdx.x;
217 |
218 | int kernel_rad = (kernel_size - 1) / 2;
219 | int displacement_rad = max_displacement / stride2;
220 | int displacement_size = 2 * displacement_rad + 1;
221 |
222 | int pInputWidth = inputWidth + 2 * pad_size;
223 | int pInputHeight = inputHeight + 2 * pad_size;
224 |
225 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
226 | int pdimxc = pInputWidth * nInputChannels;
227 | int pdimc = nInputChannels;
228 |
229 | int tdimcyx = nOutputChannels * outputHeight * outputWidth;
230 | int tdimyx = outputHeight * outputWidth;
231 | int tdimx = outputWidth;
232 |
233 | int odimcyx = nInputChannels * inputHeight* inputWidth;
234 | int odimyx = inputHeight * inputWidth;
235 | int odimx = inputWidth;
236 |
237 | float nelems = kernel_size * kernel_size * nInputChannels;
238 |
239 | __shared__ float prod_sum[CUDA_NUM_THREADS];
240 | prod_sum[tch_off] = 0;
241 |
242 | for (int tc = tch_off; tc < nOutputChannels; tc += CUDA_NUM_THREADS) {
243 | int i2 = (tc % displacement_size - displacement_rad) * stride2;
244 | int j2 = (tc / displacement_size - displacement_rad) * stride2;
245 |
246 | int xmin = (x - kernel_rad - max_displacement - i2) / stride1;
247 | int ymin = (y - kernel_rad - max_displacement - j2) / stride1;
248 |
249 | int xmax = (x + kernel_rad - max_displacement - i2) / stride1;
250 | int ymax = (y + kernel_rad - max_displacement - j2) / stride1;
251 |
252 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
253 | // assumes gradInput2 is pre-allocated and zero filled
254 | continue;
255 | }
256 |
257 | if (xmin > xmax || ymin > ymax) {
258 | // assumes gradInput2 is pre-allocated and zero filled
259 | continue;
260 | }
261 |
262 | xmin = max(0,xmin);
263 | xmax = min(outputWidth-1,xmax);
264 |
265 | ymin = max(0,ymin);
266 | ymax = min(outputHeight-1,ymax);
267 |
268 | int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c;
269 | float val1 = rInput1[indx1];
270 |
271 | for (int j = ymin; j <= ymax; ++j) {
272 | for (int i = xmin; i <= xmax; ++i) {
273 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
274 | prod_sum[tch_off] += gradOutput[tindx] * val1;
275 | }
276 | }
277 | }
278 |
279 | __syncthreads();
280 |
281 | if(tch_off == 0) {
282 | float reduce_sum = 0;
283 | for(int idx = 0; idx < CUDA_NUM_THREADS; idx++) {
284 | reduce_sum += prod_sum[idx];
285 | }
286 | const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
287 | gradInput2[indx2] = reduce_sum / nelems;
288 | }
289 |
290 | }
291 |
292 | #ifdef __cplusplus
293 | extern "C" {
294 | #endif
295 |
296 | int Correlation_forward_cuda_kernel(/*THCudaTensor_data(state, output)*/ float *output,
297 | /*THCudaTensor_size(state, output, 0)*/ int ob,
298 | /*THCudaTensor_size(state, output, 1)*/ int oc,
299 | /*THCudaTensor_size(state, output, 2)*/ int oh,
300 | /*THCudaTensor_size(state, output, 3)*/ int ow,
301 | /*THCudaTensor_stride(state, output, 0)*/ int osb,
302 | /*THCudaTensor_stride(state, output, 1)*/ int osc,
303 | /*THCudaTensor_stride(state, output, 2)*/ int osh,
304 | /*THCudaTensor_stride(state, output, 3)*/ int osw,
305 |
306 | /*THCudaTensor_data(state, input1)*/ float *input1,
307 | /*THCudaTensor_size(state, input1, 1)*/ int ic,
308 | /*THCudaTensor_size(state, input1, 2)*/ int ih,
309 | /*THCudaTensor_size(state, input1, 3)*/ int iw,
310 | /*THCudaTensor_stride(state, input1, 0)*/ int isb,
311 | /*THCudaTensor_stride(state, input1, 1)*/ int isc,
312 | /*THCudaTensor_stride(state, input1, 2)*/ int ish,
313 | /*THCudaTensor_stride(state, input1, 3)*/ int isw,
314 |
315 | /*THCudaTensor_data(state, input2)*/ float *input2,
316 | /*THCudaTensor_size(state, input2, 1)*/ int gc,
317 | /*THCudaTensor_stride(state, input2, 0)*/ int gsb,
318 | /*THCudaTensor_stride(state, input2, 1)*/ int gsc,
319 | /*THCudaTensor_stride(state, input2, 2)*/ int gsh,
320 | /*THCudaTensor_stride(state, input2, 3)*/ int gsw,
321 |
322 | /*THCudaTensor_data(state, rInput1)*/ float *rInput1,
323 | /*THCudaTensor_data(state, rInput2)*/ float *rInput2,
324 | int pad_size,
325 | int kernel_size,
326 | int max_displacement,
327 | int stride1,
328 | int stride2,
329 | int corr_type_multiply,
330 | /*THCState_getCurrentStream(state)*/ cudaStream_t stream)
331 | {
332 | int batchSize = ob;
333 |
334 | int nInputChannels = ic;
335 | int inputWidth = iw;
336 | int inputHeight = ih;
337 |
338 | int nOutputChannels = oc;
339 | int outputWidth = ow;
340 | int outputHeight = oh;
341 |
342 | dim3 blocks_grid(batchSize, inputHeight, inputWidth);
343 | dim3 threads_block(THREADS_PER_BLOCK);
344 |
345 |   channels_first<<<blocks_grid, threads_block, 0, stream>>>(input1, rInput1, nInputChannels, inputHeight, inputWidth, pad_size);
346 |   channels_first<<<blocks_grid, threads_block, 0, stream>>>(input2, rInput2, nInputChannels, inputHeight, inputWidth, pad_size);
347 |
348 | dim3 threadsPerBlock(THREADS_PER_BLOCK);
349 | dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth);
350 |
351 | Correlation_forward <<< totalBlocksCorr, threadsPerBlock, 0, stream >>>
352 | (output, nOutputChannels, outputHeight, outputWidth,
353 | rInput1, nInputChannels, inputHeight, inputWidth,
354 | rInput2,
355 | pad_size,
356 | kernel_size,
357 | max_displacement,
358 | stride1,
359 | stride2);
360 |
361 | // check for errors
362 | cudaError_t err = cudaGetLastError();
363 | if (err != cudaSuccess) {
364 | printf("error in Correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err));
365 | return 0;
366 | }
367 |
368 | return 1;
369 | }
370 |
371 | int Correlation_backward_cuda_kernel(
372 | /*THCudaTensor_data(state, gradOutput)*/ float *gradOutput,
373 | /*THCudaTensor_size(state, gradOutput, 0)*/ int gob,
374 | /*THCudaTensor_size(state, gradOutput, 1)*/ int goc,
375 | /*THCudaTensor_size(state, gradOutput, 2)*/ int goh,
376 | /*THCudaTensor_size(state, gradOutput, 3)*/ int gow,
377 | /*THCudaTensor_stride(state, gradOutput, 0)*/ int gosb,
378 | /*THCudaTensor_stride(state, gradOutput, 1)*/ int gosc,
379 | /*THCudaTensor_stride(state, gradOutput, 2)*/ int gosh,
380 | /*THCudaTensor_stride(state, gradOutput, 3)*/ int gosw,
381 |
382 | /*THCudaTensor_data(state, input1)*/ float* input1,
383 | /*THCudaTensor_size(state, input1, 1)*/ int ic,
384 | /*THCudaTensor_size(state, input1, 2)*/ int ih,
385 | /*THCudaTensor_size(state, input1, 3)*/ int iw,
386 | /*THCudaTensor_stride(state, input1, 0)*/ int isb,
387 | /*THCudaTensor_stride(state, input1, 1)*/ int isc,
388 | /*THCudaTensor_stride(state, input1, 2)*/ int ish,
389 | /*THCudaTensor_stride(state, input1, 3)*/ int isw,
390 |
391 | /*THCudaTensor_data(state, input2)*/ float *input2,
392 | /*THCudaTensor_stride(state, input2, 0)*/ int gsb,
393 | /*THCudaTensor_stride(state, input2, 1)*/ int gsc,
394 | /*THCudaTensor_stride(state, input2, 2)*/ int gsh,
395 | /*THCudaTensor_stride(state, input2, 3)*/ int gsw,
396 |
397 | /*THCudaTensor_data(state, gradInput1)*/ float *gradInput1,
398 | /*THCudaTensor_stride(state, gradInput1, 0)*/ int gisb,
399 | /*THCudaTensor_stride(state, gradInput1, 1)*/ int gisc,
400 | /*THCudaTensor_stride(state, gradInput1, 2)*/ int gish,
401 | /*THCudaTensor_stride(state, gradInput1, 3)*/ int gisw,
402 |
403 | /*THCudaTensor_data(state, gradInput2)*/ float *gradInput2,
404 | /*THCudaTensor_size(state, gradInput2, 1)*/ int ggc,
405 | /*THCudaTensor_stride(state, gradInput2, 0)*/ int ggsb,
406 | /*THCudaTensor_stride(state, gradInput2, 1)*/ int ggsc,
407 | /*THCudaTensor_stride(state, gradInput2, 2)*/ int ggsh,
408 | /*THCudaTensor_stride(state, gradInput2, 3)*/ int ggsw,
409 |
410 | /*THCudaTensor_data(state, rInput1)*/ float *rInput1,
411 | /*THCudaTensor_data(state, rInput2)*/ float *rInput2,
412 | int pad_size,
413 | int kernel_size,
414 | int max_displacement,
415 | int stride1,
416 | int stride2,
417 | int corr_type_multiply,
418 | /*THCState_getCurrentStream(state)*/cudaStream_t stream)
419 | {
420 |
421 | int batchSize = gob;
422 | int num = batchSize;
423 |
424 | int nInputChannels = ic;
425 | int inputWidth = iw;
426 | int inputHeight = ih;
427 |
428 | int nOutputChannels = goc;
429 | int outputWidth = gow;
430 | int outputHeight = goh;
431 |
432 | dim3 blocks_grid(batchSize, inputHeight, inputWidth);
433 | dim3 threads_block(THREADS_PER_BLOCK);
434 |
435 |   channels_first<<<blocks_grid, threads_block, 0, stream>>>(input1, rInput1, nInputChannels, inputHeight, inputWidth, pad_size);
436 |   channels_first<<<blocks_grid, threads_block, 0, stream>>>(input2, rInput2, nInputChannels, inputHeight, inputWidth, pad_size);
437 |
438 | dim3 threadsPerBlock(CUDA_NUM_THREADS);
439 | dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels);
440 |
441 | for (int n = 0; n < num; ++n) {
442 |     Correlation_backward_input1<<<totalBlocksCorr, threadsPerBlock, 0, stream>>>(
443 | n, gradInput1, nInputChannels, inputHeight, inputWidth,
444 | gradOutput, nOutputChannels, outputHeight, outputWidth,
445 | rInput2,
446 | pad_size,
447 | kernel_size,
448 | max_displacement,
449 | stride1,
450 | stride2);
451 | }
452 |
453 | for(int n = 0; n < batchSize; n++) {
454 |     Correlation_backward_input2<<<totalBlocksCorr, threadsPerBlock, 0, stream>>>(
455 | n, gradInput2, nInputChannels, inputHeight, inputWidth,
456 | gradOutput, nOutputChannels, outputHeight, outputWidth,
457 | rInput1,
458 | pad_size,
459 | kernel_size,
460 | max_displacement,
461 | stride1,
462 | stride2);
463 | }
464 |
465 | // check for errors
466 | cudaError_t err = cudaGetLastError();
467 | if (err != cudaSuccess) {
468 | printf("error in Correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err));
469 | return 0;
470 | }
471 |
472 | return 1;
473 | }
474 |
475 | #ifdef __cplusplus
476 | }
477 | #endif
478 |
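Editor's note: despite its name, `channels_first` converts NCHW input into a zero-padded channels-last (NHWC) buffer, which is the layout the correlation kernels index. The same transform in NumPy:
```python
import numpy as np

def channels_last_padded(x, pad_size):
    """x: (N, C, H, W) float array -> zero-padded (N, H + 2p, W + 2p, C)."""
    n, c, h, w = x.shape
    out = np.zeros((n, h + 2 * pad_size, w + 2 * pad_size, c), dtype=x.dtype)
    out[:, pad_size:pad_size + h, pad_size:pad_size + w, :] = x.transpose(0, 2, 3, 1)
    return out
```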
--------------------------------------------------------------------------------
/correlation_package/src/correlation_cuda_kernel.h:
--------------------------------------------------------------------------------
1 | #ifdef __cplusplus
2 | extern "C" {
3 | #endif
4 |
5 | int Correlation_forward_cuda_kernel(/*THCudaTensor_data(state, output)*/ float *output,
6 | /*THCudaTensor_size(state, output, 0)*/ int ob,
7 | /*THCudaTensor_size(state, output, 1)*/ int oc,
8 | /*THCudaTensor_size(state, output, 2)*/ int oh,
9 | /*THCudaTensor_size(state, output, 3)*/ int ow,
10 | /*THCudaTensor_stride(state, output, 0)*/ int osb,
11 | /*THCudaTensor_stride(state, output, 1)*/ int osc,
12 | /*THCudaTensor_stride(state, output, 2)*/ int osh,
13 | /*THCudaTensor_stride(state, output, 3)*/ int osw,
14 |
15 | /*THCudaTensor_data(state, input1)*/ float *input1,
16 | /*THCudaTensor_size(state, input1, 1)*/ int ic,
17 | /*THCudaTensor_size(state, input1, 2)*/ int ih,
18 | /*THCudaTensor_size(state, input1, 3)*/ int iw,
19 | /*THCudaTensor_stride(state, input1, 0)*/ int isb,
20 | /*THCudaTensor_stride(state, input1, 1)*/ int isc,
21 | /*THCudaTensor_stride(state, input1, 2)*/ int ish,
22 | /*THCudaTensor_stride(state, input1, 3)*/ int isw,
23 |
24 | /*THCudaTensor_data(state, input2)*/ float *input2,
25 | /*THCudaTensor_size(state, input2, 1)*/ int gc,
26 | /*THCudaTensor_stride(state, input2, 0)*/ int gsb,
27 | /*THCudaTensor_stride(state, input2, 1)*/ int gsc,
28 | /*THCudaTensor_stride(state, input2, 2)*/ int gsh,
29 | /*THCudaTensor_stride(state, input2, 3)*/ int gsw,
30 |
31 | /*THCudaTensor_data(state, rInput1)*/ float *rInput1,
32 | /*THCudaTensor_data(state, rInput2)*/ float *rInput2,
33 | int pad_size,
34 | int kernel_size,
35 | int max_displacement,
36 | int stride1,
37 | int stride2,
38 | int corr_type_multiply,
39 | /*THCState_getCurrentStream(state)*/ cudaStream_t stream);
40 |
41 | int Correlation_backward_cuda_kernel(
42 | /*THCudaTensor_data(state, gradOutput)*/ float *gradOutput,
43 | /*THCudaTensor_size(state, gradOutput, 0)*/ int gob,
44 | /*THCudaTensor_size(state, gradOutput, 1)*/ int goc,
45 | /*THCudaTensor_size(state, gradOutput, 2)*/ int goh,
46 | /*THCudaTensor_size(state, gradOutput, 3)*/ int gow,
47 | /*THCudaTensor_stride(state, gradOutput, 0)*/ int gosb,
48 | /*THCudaTensor_stride(state, gradOutput, 1)*/ int gosc,
49 | /*THCudaTensor_stride(state, gradOutput, 2)*/ int gosh,
50 | /*THCudaTensor_stride(state, gradOutput, 3)*/ int gosw,
51 |
52 | /*THCudaTensor_data(state, input1)*/ float* input1,
53 | /*THCudaTensor_size(state, input1, 1)*/ int ic,
54 | /*THCudaTensor_size(state, input1, 2)*/ int ih,
55 | /*THCudaTensor_size(state, input1, 3)*/ int iw,
56 | /*THCudaTensor_stride(state, input1, 0)*/ int isb,
57 | /*THCudaTensor_stride(state, input1, 1)*/ int isc,
58 | /*THCudaTensor_stride(state, input1, 2)*/ int ish,
59 | /*THCudaTensor_stride(state, input1, 3)*/ int isw,
60 |
61 | /*THCudaTensor_data(state, input2)*/ float *input2,
62 | /*THCudaTensor_stride(state, input2, 0)*/ int gsb,
63 | /*THCudaTensor_stride(state, input2, 1)*/ int gsc,
64 | /*THCudaTensor_stride(state, input2, 2)*/ int gsh,
65 | /*THCudaTensor_stride(state, input2, 3)*/ int gsw,
66 |
67 | /*THCudaTensor_data(state, gradInput1)*/ float *gradInput1,
68 | /*THCudaTensor_stride(state, gradInput1, 0)*/ int gisb,
69 | /*THCudaTensor_stride(state, gradInput1, 1)*/ int gisc,
70 | /*THCudaTensor_stride(state, gradInput1, 2)*/ int gish,
71 | /*THCudaTensor_stride(state, gradInput1, 3)*/ int gisw,
72 |
73 | /*THCudaTensor_data(state, gradInput2)*/ float *gradInput2,
74 | /*THCudaTensor_size(state, gradInput2, 1)*/ int ggc,
75 | /*THCudaTensor_stride(state, gradInput2, 0)*/ int ggsb,
76 | /*THCudaTensor_stride(state, gradInput2, 1)*/ int ggsc,
77 | /*THCudaTensor_stride(state, gradInput2, 2)*/ int ggsh,
78 | /*THCudaTensor_stride(state, gradInput2, 3)*/ int ggsw,
79 |
80 | /*THCudaTensor_data(state, rInput1)*/ float *rInput1,
81 | /*THCudaTensor_data(state, rInput2)*/ float *rInput2,
82 | int pad_size,
83 | int kernel_size,
84 | int max_displacement,
85 | int stride1,
86 | int stride2,
87 | int corr_type_multiply,
88 | /*THCState_getCurrentStream(state)*/cudaStream_t stream);
89 |
90 | #ifdef __cplusplus
91 | }
92 | #endif
93 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from back2future import Model
3 | import numpy as np
4 | from scipy.misc import imread, imresize
5 | from torchvision.transforms import ToTensor
6 | from torch.autograd import Variable
7 |
8 | def main():
9 | model = Model(pretrained='pretrained/b2f_kitti.pth.tar')
10 | model = model.cuda()
11 | im_tar, im_refs = fetch_image_tensors()
12 | im_tar = Variable(im_tar.unsqueeze(0)).cuda()
13 | im_refs = [Variable(im_r.unsqueeze(0)).cuda() for im_r in im_refs]
14 | flow_fwd, flow_bwd, occ = model(im_tar, im_refs)
15 | np.save('outputs.npy', {'flow_fwd':flow_fwd[0].cpu().data.numpy(),
16 | 'flow_bwd':flow_bwd[0].cpu().data.numpy(),
17 | 'occ':occ[0].cpu().data.numpy()})
18 | print('Outputs saved in outputs.npy')
19 | print('Done!')
20 |
21 | def load_as_float(path):
22 | return imread(path).astype(np.float32)
23 |
24 | def fetch_image_tensors():
25 | im0 = load_as_float('samples/000010_09.png')
26 | im1 = load_as_float('samples/000010_10.png')
27 | im2 = load_as_float('samples/000010_11.png')
28 | scale = Scale(h=256, w=832)
29 | im012 = scale([im0, im1, im2])
30 | im_tar = im012[1]
31 | im_refs = [im012[0], im012[2]]
32 | return im_tar, im_refs
33 |
34 | class Scale(object):
35 | """Scales images to a particular size"""
36 | def __init__(self, h, w):
37 | self.h = h
38 | self.w = w
39 |
40 | def __call__(self, images):
41 | in_h, in_w, _ = images[0].shape
42 | scaled_h, scaled_w = self.h, self.w
43 | scaled_images = [ToTensor()(imresize(im, (scaled_h, scaled_w))) for im in images]
44 | return scaled_images
45 |
46 | if __name__ == '__main__':
47 | main()
48 |
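Note that `scipy.misc.imread` and `imresize` were deprecated in SciPy 1.0 and removed in 1.2; on newer environments an equivalent loader might look like this (a sketch assuming Pillow, not part of the original demo):
```python
import numpy as np
from PIL import Image

def load_and_scale(path, h=256, w=832):
    # PIL resize takes (width, height); sizes match Scale(h=256, w=832) above
    return np.asarray(Image.open(path).resize((w, h)), dtype=np.float32)
```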
--------------------------------------------------------------------------------
/flow_io.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python2
2 |
3 | """
4 | I/O script to save and load the data coming with the MPI-Sintel low-level
5 | computer vision benchmark.
6 |
7 | For more details about the benchmark, please visit www.mpi-sintel.de
8 |
9 | CHANGELOG:
10 | v1.0 (2015/02/03): First release
11 |
12 | Copyright (c) 2015 Jonas Wulff
13 | Max Planck Institute for Intelligent Systems, Tuebingen, Germany
14 |
15 | """
16 |
17 | # Requirements: Numpy and PIL/Pillow
18 | import numpy as np
19 | try:
20 | import png
21 | has_png = True
22 | except ImportError:
23 | has_png = False
24 | png=None
25 |
26 |
27 |
28 | # Check for endianness, based on Daniel Scharstein's optical flow code.
29 | # Using little-endian architecture, these two should be equal.
30 | TAG_FLOAT = 202021.25
31 | TAG_CHAR = 'PIEH'.encode()
32 |
33 | def flow_read(filename, return_validity=False):
34 | """ Read optical flow from file, return (U,V) tuple.
35 |
36 | Original code by Deqing Sun, adapted from Daniel Scharstein.
37 | """
38 | f = open(filename,'rb')
39 | check = np.fromfile(f,dtype=np.float32,count=1)[0]
40 | assert check == TAG_FLOAT, ' flow_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check)
41 | width = np.fromfile(f,dtype=np.int32,count=1)[0]
42 | height = np.fromfile(f,dtype=np.int32,count=1)[0]
43 | size = width*height
44 | assert width > 0 and height > 0 and size > 1 and size < 100000000, ' flow_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height)
45 | tmp = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width*2))
46 | u = tmp[:,np.arange(width)*2]
47 | v = tmp[:,np.arange(width)*2 + 1]
48 |
49 | if return_validity:
50 | valid = u<1e19
51 | u[valid==0] = 0
52 | v[valid==0] = 0
53 | return u,v,valid
54 | else:
55 | return u,v
56 |
57 | def flow_write(filename,uv,v=None):
58 | """ Write optical flow to file.
59 |
60 | If v is None, uv is assumed to contain both u and v channels,
61 | stacked in depth.
62 |
63 | Original code by Deqing Sun, adapted from Daniel Scharstein.
64 | """
65 | nBands = 2
66 |
67 | if v is None:
68 | uv_ = np.array(uv)
69 | assert(uv_.ndim==3)
70 | if uv_.shape[0] == 2:
71 | u = uv_[0,:,:]
72 | v = uv_[1,:,:]
73 | elif uv_.shape[2] == 2:
74 | u = uv_[:,:,0]
75 | v = uv_[:,:,1]
76 | else:
77 |             raise ValueError('Wrong format for flow input')
78 | else:
79 | u = uv
80 |
81 | assert(u.shape == v.shape)
82 | height,width = u.shape
83 | f = open(filename,'wb')
84 | # write the header
85 | f.write(TAG_CHAR)
86 | np.array(width).astype(np.int32).tofile(f)
87 | np.array(height).astype(np.int32).tofile(f)
88 | # arrange into matrix form
89 | tmp = np.zeros((height, width*nBands))
90 | tmp[:,np.arange(width)*2] = u
91 | tmp[:,np.arange(width)*2 + 1] = v
92 | tmp.astype(np.float32).tofile(f)
93 | f.close()
94 |
95 |
96 | def flow_read_png(fpath):
97 | """
98 | Read KITTI optical flow, returns u,v,valid mask
99 |
100 | """
101 | if not has_png:
102 | print('Error. Please install the PyPNG library')
103 | return
104 |
105 | R = png.Reader(fpath)
106 | width,height,data,_ = R.asDirect()
107 | # This only worked with python2.
108 | #I = np.array(map(lambda x:x,data)).reshape((height,width,3))
109 | I = np.array([x for x in data]).reshape((height,width,3))
110 | u_ = I[:,:,0]
111 | v_ = I[:,:,1]
112 | valid = I[:,:,2]
113 |
114 | u = (u_.astype('float64')-2**15)/64.0
115 | v = (v_.astype('float64')-2**15)/64.0
116 |
117 | return u,v,valid
118 |
119 |
120 | def flow_write_png(fpath,u,v,valid=None):
121 | """
122 | Write KITTI optical flow.
123 |
124 | """
125 | if not has_png:
126 | print('Error. Please install the PyPNG library')
127 | return
128 |
129 |
130 |     if valid is None:
131 | valid_ = np.ones(u.shape,dtype='uint16')
132 | else:
133 | valid_ = valid.astype('uint16')
134 |
135 |
136 | u = u.astype('float64')
137 | v = v.astype('float64')
138 |
139 | u_ = ((u*64.0)+2**15).astype('uint16')
140 | v_ = ((v*64.0)+2**15).astype('uint16')
141 |
142 | I = np.dstack((u_,v_,valid_))
143 |
144 | W = png.Writer(width=u.shape[1],
145 | height=u.shape[0],
146 | bitdepth=16,
147 | planes=3)
148 |
149 | with open(fpath,'wb') as fil:
150 | W.write(fil,I.reshape((-1,3*u.shape[1])))
151 |
--------------------------------------------------------------------------------
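A quick round-trip through the `.flo` helpers above, as a usage sketch (the file name `example.flo` is arbitrary; the KITTI PNG variants follow the same pattern but encode flow as `uint16` via `u_png = u * 64 + 2**15`, which is exactly what `flow_write_png`/`flow_read_png` invert):

```python
import numpy as np
import flow_io

# Write a small synthetic flow field and read it back.
u = np.random.randn(4, 5).astype(np.float32)
v = np.random.randn(4, 5).astype(np.float32)
flow_io.flow_write('example.flo', u, v)

u2, v2 = flow_io.flow_read('example.flo')
assert np.allclose(u, u2) and np.allclose(v, v2)
```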
/pretrained/b2f_kitti.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/pretrained/b2f_kitti.pth.tar
--------------------------------------------------------------------------------
/pretrained/b2f_sintel.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/pretrained/b2f_sintel.pth.tar
--------------------------------------------------------------------------------
/samples/000010_09.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/samples/000010_09.png
--------------------------------------------------------------------------------
/samples/000010_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/samples/000010_10.png
--------------------------------------------------------------------------------
/samples/000010_11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anuragranj/back2future.pytorch/a3b619a9eb11c91866160565a8593dda690c2da9/samples/000010_11.png
--------------------------------------------------------------------------------
/test_back2future.py:
--------------------------------------------------------------------------------
1 | # Author: Anurag Ranjan
2 | # Copyright (c) 2018, Max Planck Society
3 |
4 | import argparse
5 | import torch
6 | import torch.nn as nn
7 | from path import Path
8 | from torch.autograd import Variable
9 | from torchvision.transforms import ToTensor
10 |
11 | from scipy.misc import imread, imresize
12 | from tqdm import tqdm
13 | import numpy as np
14 |
15 | from back2future import Model
16 | import flow_io
17 |
18 | parser = argparse.ArgumentParser(description='Code to test performance of Back2Future models on KITTI benchmarks',
19 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
20 | parser.add_argument('--pretrained-flow', dest='pretrained_flow', default=None, metavar='PATH',
21 | help='path to pre-trained Flow net model')
22 | parser.add_argument('--kitti-dir', dest='kitti_dir', default=None, metavar='PATH',
23 | help='path to KITTI 2015 directory')
24 |
25 |
26 | def main():
27 | global args
28 | args = parser.parse_args()
29 | flow_loader_h, flow_loader_w = 384, 1280
30 | valid_flow_transform = Scale(h=flow_loader_h, w=flow_loader_w)
31 | val_flow_set = KITTI2015(root=args.kitti_dir,
32 | transform=valid_flow_transform)
33 |
34 | val_flow_loader = torch.utils.data.DataLoader(val_flow_set, batch_size=1, shuffle=False,
35 | num_workers=2, pin_memory=True)
36 |
37 | flow_net = Model(pretrained=args.pretrained_flow).cuda()
38 |
39 | flow_net.eval()
40 | error_names = ['epe_total']
41 | errors = AverageMeter(i=len(error_names))
42 |
43 | for i, (tgt_img, ref_imgs, flow_gt) in enumerate(tqdm(val_flow_loader)):
44 | tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
45 | ref_imgs_var = [Variable(img.cuda(), volatile=True) for img in ref_imgs]
46 | flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
47 |
48 | # compute output
49 | flow_fwd, flow_bwd, occ = flow_net(tgt_img_var, ref_imgs_var)
50 | epe = compute_epe(gt=flow_gt_var, pred=flow_fwd[0])
51 | errors.update(epe)
52 |
53 |     print("Average EPE", errors.avg)
54 |
55 |
56 | class KITTI2015(torch.utils.data.Dataset):
57 | """
58 |     KITTI 2015 loader
59 | """
60 |
61 | def __init__(self, root, transform=None, N=200, train=True, seed=0):
62 | self.root = Path(root)
63 | self.scenes = range(N)
64 | self.N = N
65 | self.transform = transform
66 | self.phase = 'training' if train else 'testing'
67 | self.seq_ids = [9, 11]
68 |
69 | def __getitem__(self, index):
70 | tgt_img_path = self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2',str(index).zfill(6)+'_10.png')
71 | ref_img_paths = [self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2',str(index).zfill(6)+'_'+str(k).zfill(2)+'.png') for k in self.seq_ids]
72 | gt_flow_path = self.root.joinpath('data_scene_flow', self.phase, 'flow_occ', str(index).zfill(6)+'_10.png')
73 |
74 | tgt_img = load_as_float(tgt_img_path)
75 | ref_imgs = [load_as_float(ref_img) for ref_img in ref_img_paths]
76 |
77 | u,v,valid = flow_io.flow_read_png(gt_flow_path)
78 | gtFlow = np.dstack((u,v,valid))
79 | gtFlow = torch.FloatTensor(gtFlow.transpose(2,0,1))
80 |
81 | if self.transform is not None:
82 | imgs = self.transform([tgt_img] + ref_imgs)
83 | tgt_img = imgs[0]
84 | ref_imgs = imgs[1:]
85 |
86 | return tgt_img, ref_imgs, gtFlow
87 |
88 | def __len__(self):
89 | return self.N
90 |
91 |
92 | class Scale(object):
93 | """Scales images to a particular size"""
94 | def __init__(self, h, w):
95 | self.h = h
96 | self.w = w
97 |
98 | def __call__(self, images):
99 | in_h, in_w, _ = images[0].shape
100 |         scaled_h, scaled_w = self.h, self.w
101 | scaled_images = [ToTensor()(imresize(im, (scaled_h, scaled_w))) for im in images]
102 | return scaled_images
103 |
104 | def compute_epe(gt, pred):
105 | _, _, h_pred, w_pred = pred.size()
106 | bs, nc, h_gt, w_gt = gt.size()
107 |
108 | u_gt, v_gt = gt[:,0,:,:], gt[:,1,:,:]
109 | pred = nn.functional.upsample(pred, size=(h_gt, w_gt), mode='bilinear')
110 | u_pred = pred[:,0,:,:] * (w_gt/w_pred)
111 | v_pred = pred[:,1,:,:] * (h_gt/h_pred)
112 |
113 | epe = torch.sqrt(torch.pow((u_gt - u_pred), 2) + torch.pow((v_gt - v_pred), 2))
114 |
115 | if nc == 3:
116 | valid = gt[:,2,:,:]
117 | epe = epe * valid
118 | avg_epe = epe.sum()/(valid.sum() + 1e-6)
119 | else:
120 | avg_epe = epe.sum()/(bs*h_gt*w_gt)
121 |
122 |     if isinstance(avg_epe, Variable): avg_epe = avg_epe.data
123 | return avg_epe[0]
124 |
125 | def load_as_float(path):
126 | return imread(path).astype(np.float32)
127 |
128 | class AverageMeter(object):
129 | """Computes and stores the average and current value"""
130 |
131 | def __init__(self, i=1, precision=3):
132 | self.meters = i
133 | self.precision = precision
134 | self.reset(self.meters)
135 |
136 | def reset(self, i):
137 | self.val = [0]*i
138 | self.avg = [0]*i
139 | self.sum = [0]*i
140 | self.count = 0
141 |
142 | def update(self, val, n=1):
143 | if not isinstance(val, list):
144 | val = [val]
145 | assert(len(val) == self.meters)
146 | self.count += n
147 | for i,v in enumerate(val):
148 | self.val[i] = v
149 | self.sum[i] += v * n
150 | self.avg[i] = self.sum[i] / self.count
151 |
152 | def __repr__(self):
153 | val = ' '.join(['{:.{}f}'.format(v, self.precision) for v in self.val])
154 | avg = ' '.join(['{:.{}f}'.format(a, self.precision) for a in self.avg])
155 | return '{} ({})'.format(val, avg)
156 |
157 | if __name__ == '__main__':
158 | main()
159 |
--------------------------------------------------------------------------------
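For reference, the endpoint error that `compute_epe` reports is the mean L2 distance between predicted and ground-truth flow vectors (restricted to valid pixels when a KITTI validity mask is present), after upsampling the prediction to the ground-truth resolution and rescaling its u/v components by the width/height ratios. A minimal standalone check of that definition (the `epe` helper below is illustrative, not part of the repo):

```python
import torch

def epe(gt, pred):
    """Mean endpoint error between two (B, 2, H, W) flow fields."""
    return torch.sqrt(((gt - pred) ** 2).sum(dim=1)).mean()

gt = torch.zeros(1, 2, 4, 4)
pred = torch.zeros(1, 2, 4, 4)
pred[:, 0] += 3.0  # 3 px horizontal error at every pixel
print(epe(gt, pred))  # tensor(3.)
```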