├── LICENSE ├── README.md ├── compile.sh ├── core ├── corr.py ├── cost_agg.py ├── datasets.py ├── extractor.py ├── sepflow.py ├── update.py └── utils │ ├── augmentor.py │ ├── flow_viz.py │ ├── frame_utils.py │ └── utils.py ├── demo.py ├── evaluate.py ├── evaluate.sh ├── libs └── GANet │ ├── functions │ ├── GANet.py │ └── __init__.py │ ├── modules │ ├── GANet.py │ └── __init__.py │ ├── setup.py │ └── src │ ├── GANet_cuda.cpp │ ├── GANet_cuda.h │ ├── GANet_kernel.cu │ ├── GANet_kernel.h │ ├── GANet_kernel_share.cu │ ├── NLF_kernel.cu │ └── costvolume.cu ├── train.py └── train.sh /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Feihu Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SeparableFlow 2 | [Separable Flow: Learning Motion Cost Volumes for Optical Flow Estimation](https://openaccess.thecvf.com/content/ICCV2021/papers/Zhang_Separable_Flow_Learning_Motion_Cost_Volumes_for_Optical_Flow_Estimation_ICCV_2021_paper.pdf) 3 | 4 | 5 | 6 | 7 | ## Building Requirements: 8 | 9 | gcc: >=5.3 10 | GPU mem: >=5G (for testing); >=11G (for training) 11 | pytorch: >=1.6 12 | cuda: >=9.2 (9.0 is not well supported by newer pytorch versions and may cause “pybind11 errors”.) 13 | tested platforms/settings: 14 | 1) ubuntu 18.04 + cuda 11.0 + python 3.6, 3.7 15 | 2) centos + cuda 11 + python 3.7 16 | 17 | 18 | ## Environment: 19 | 20 | conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia 21 | conda install matplotlib tensorboard scipy opencv 22 | pip install einops opencv-python pypng 23 | 24 | 25 | ## How to Use? 26 | 27 | Step 1: compile the libs with "sh compile.sh" 28 | - Change the environment variables ($PATH, $LD_LIBRARY_PATH, etc.) if they are not set correctly in your system environment (e.g. .bashrc). Examples are included in "compile.sh". 29 | 30 | Step 2: download and prepare the training dataset or your own test set. 31 | 32 | 33 | Step 3: revise parameter settings and run "train.sh" and "evaluate.sh" for training, finetuning, and prediction/testing. Note that the “crop_width” and “crop_height” must be multiples of 64 during training.
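For example, a quick way to keep those crop settings valid before editing "train.sh" (an illustrative snippet, not part of this repo; the helper name is made up):

    # Hypothetical helper: round a desired training crop size down to the
    # nearest multiple of 64, as the training crops require.
    def snap_to_64(size):
        return max(64, (size // 64) * 64)

    crop_width, crop_height = snap_to_64(968), snap_to_64(440)  # -> 960, 384
    assert crop_width % 64 == 0 and crop_height % 64 == 0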
34 | 35 | Demo example (use “sintel” or “universal” for other unseen datasets): 36 | $ python demo.py --model checkpoints/sepflow_universal.pth --path ./your-own-image-folder 37 | 38 | 39 | ## Pretrained models: 40 | 41 | | things | sintel | kitti | universal | 42 | |---|---|---|---| 43 | |[Google Drive](https://drive.google.com/file/d/1baepLE9wxmt4QJEGMC5QeaQCQfZETEAu/view?usp=sharing)|[Google Drive](https://drive.google.com/file/d/1bpm0HmwcBrbyAsikTJR3qST6mAavQ60k/view?usp=sharing)|[Google Drive](https://drive.google.com/file/d/1qqpuaPpFBcg5TjBrg49MZvdJoL7bEy8A/view?usp=sharing)|[Google Drive](https://drive.google.com/file/d/1FTYSdHzW12Iejal6n4xEbdKPyrSK-W6P/view?usp=sharing)| 44 | |[Baidu Yun (password: 9qcd)](https://pan.baidu.com/s/1lK2q0QtMwC0ROVCd6tyejA?pwd=9qcd)|[Baidu Yun (password: m1xs)](https://pan.baidu.com/s/1rtUrsGiTjU0GqMys1xRm6Q?pwd=m1xs)|[Baidu Yun (password: sg46)](https://pan.baidu.com/s/1ALo1lFmQkkziagoRPxzSsQ?pwd=sg46)|[Baidu Yun (password: 2has)](https://pan.baidu.com/s/1AP7ytz3HPy-oZZdNXzduWw?pwd=2has)| 45 | 46 | These pre-trained models perform a little better than those reported in our original paper. 47 | "universal" is trained on a mixture of synthetic and real datasets for cross-domain generalization. 48 | 49 | | Leaderboards | Sintel clean | Sintel final | KITTI | 50 | |---|---|---|---| 51 | | RAFT baseline | 1.94 | 3.18 | 5.10 | 52 | | Original paper | 1.50 | 2.67 | 4.64 | 53 | | This new implementation | 1.49 | 2.64 | 4.53 | 54 | 55 | *Standard two-frame evaluations without previous video frames for "warm start".* 56 | 57 | 58 | 59 | 60 | ## Reference: 61 | 62 | If you find the code useful, please cite our paper: 63 | 64 | @inproceedings{Zhang2021SepFlow, 65 | title={Separable Flow: Learning Motion Cost Volumes for Optical Flow Estimation}, 66 | author={Zhang, Feihu and Woodford, Oliver J. and Prisacariu, Victor Adrian and Torr, Philip H.S.}, 67 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 68 | year={2021}, 69 | pages={10807-10817} 70 | } 71 | 72 | The code is implemented based on 73 | https://github.com/feihuzhang/DSMNet and https://github.com/princeton-vl/RAFT.
74 | Please also consider citing: 75 | 76 | @inproceedings{zhang2019domaininvariant, 77 | title={Domain-invariant Stereo Matching Networks}, 78 | author={Feihu Zhang and Xiaojuan Qi and Ruigang Yang and Victor Prisacariu and Benjamin Wah and Philip Torr}, 79 | booktitle={European Conference on Computer Vision (ECCV)}, 80 | year={2020} 81 | } 82 | @inproceedings{teed2020raft, 83 | title={RAFT: Recurrent All Pairs Field Transforms for Optical Flow}, 84 | author={Zachary Teed and Jia Deng}, 85 | booktitle={European Conference on Computer Vision (ECCV)}, 86 | year={2020} 87 | } 88 | 89 | -------------------------------------------------------------------------------- /compile.sh: -------------------------------------------------------------------------------- 1 | #export LD_LIBRARY_PATH="/home/feihu/anaconda3/lib:$LD_LIBRARY_PATH" 2 | #export LD_INCLUDE_PATH="/home/feihu/anaconda3/include:$LD_INCLUDE_PATH" 3 | #export CUDA_HOME="/usr/local/cuda-10.0" 4 | #export PATH="/home/feihu/anaconda3/bin:/usr/local/cuda-10.0/bin:$PATH" 5 | #export CPATH="/usr/local/cuda-10.0/include" 6 | #export CUDNN_INCLUDE_DIR="/usr/local/cuda-10.0/include" 7 | #export CUDNN_LIB_DIR="/usr/local/cuda-10.0/lib64" 8 | 9 | TORCH=$(python -c "import os; import torch; print(os.path.dirname(torch.__file__))") 10 | #echo $TORCH 11 | cd libs/GANet 12 | python setup.py clean 13 | rm -rf build 14 | python setup.py build 15 | cp -r build/lib* build/lib -------------------------------------------------------------------------------- /core/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from utils.utils import bilinear_sampler, coords_grid 5 | from libs.GANet.modules.GANet import NLFMax, NLFIter 6 | 7 | try: 8 | import alt_cuda_corr 9 | except ImportError: 10 | # alt_cuda_corr is not compiled 11 | pass 12 | class NLF(nn.Module): 13 | 14 | def __init__(self, in_channel=32): 15 | super(NLF, self).__init__() 16 | self.nlf = NLFIter() 17 | def forward(self, x, g): 18 | N, D1, D2, H, W = x.shape 19 | x = x.reshape(N, D1*D2, H, W).contiguous() 20 | rem = x 21 | k1, k2, k3, k4 = torch.split(g, (5, 5, 5, 5), 1) 22 | # k1, k2, k3, k4 = self.getweights(x) 23 | k1 = F.normalize(k1, p=1, dim=1) 24 | k2 = F.normalize(k2, p=1, dim=1) 25 | k3 = F.normalize(k3, p=1, dim=1) 26 | k4 = F.normalize(k4, p=1, dim=1) 27 | 28 | x = self.nlf(x, k1, k2, k3, k4) 29 | # x = x + rem 30 | x = x.reshape(N, D1, D2, H, W) 31 | return x 32 | 33 | 34 | 35 | class CorrBlock: 36 | def __init__(self, fmap1, fmap2, guid, num_levels=4, radius=4): 37 | self.num_levels = num_levels 38 | self.radius = radius 39 | self.corr_pyramid = [] 40 | 41 | # all pairs correlation 42 | self.nlf = NLF() 43 | corr = self.corr_compute(fmap1, fmap2, guid, reverse=True) 44 | #corr = self.nlf(corr, g) 45 | #corr = corr.permute(0, 3,4, 1,2) 46 | 47 | batch, h1, w1, h2, w2 = corr.shape 48 | self.shape = corr.shape 49 | corr = corr.reshape(batch*h1*w1, 1, h2, w2) 50 | 51 | self.corr_pyramid.append(corr) 52 | for i in range(self.num_levels-1): 53 | corr = F.avg_pool2d(corr, 2, stride=2) 54 | self.corr_pyramid.append(corr) 55 | def separate(self): 56 | sep_u = [] 57 | sep_v = [] 58 | for i in range(self.num_levels): 59 | corr = self.corr_pyramid[i] 60 | m1, _ = corr.max(dim=2, keepdim=True) 61 | m2 = corr.mean(dim=2, keepdim=True) 62 | sep = torch.cat((m1, m2), dim=2) 63 | sep = sep.reshape(self.shape[0], self.shape[1], self.shape[2], sep.shape[2], sep.shape[3]).permute(0, 3, 4, 1, 2)
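            # Note: the max/mean reductions above collapse this pyramid level's 4D
            # correlation volume over dim=2 (the v displacement), leaving a 2-channel
            # 1D cost along u; the trilinear interpolation below resizes it back to
            # (w2, h1, w1) so that all num_levels levels can be concatenated into the
            # (2*num_levels)-channel sep_u volume consumed by CostAggregation.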
64 | sep = F.interpolate(sep, [self.shape[4], self.shape[1], self.shape[2]], mode='trilinear', align_corners=True) 65 | sep_u.append(sep) 66 | m1, _ = corr.max(dim=3, keepdim=True) 67 | m2 = corr.mean(dim=3, keepdim=True) 68 | sep = torch.cat((m1, m2), dim=3) 69 | sep = sep.reshape(self.shape[0], self.shape[1], self.shape[2], sep.shape[2], sep.shape[3]).permute(0, 4, 3, 1, 2) 70 | sep = F.interpolate(sep, [self.shape[3], self.shape[1], self.shape[2]], mode='trilinear', align_corners=True) 71 | sep_v.append(sep) 72 | sep_u = torch.cat(sep_u, dim=1) 73 | sep_v = torch.cat(sep_v, dim=1) 74 | return sep_u, sep_v 75 | 76 | 77 | def __call__(self, coords, sep=False): 78 | if sep: 79 | return self.separate() 80 | r = self.radius 81 | coords = coords.permute(0, 2, 3, 1) 82 | batch, h1, w1, _ = coords.shape 83 | 84 | out_pyramid = [] 85 | for i in range(self.num_levels): 86 | corr = self.corr_pyramid[i] 87 | dx = torch.linspace(-r, r, 2*r+1) 88 | dy = torch.linspace(-r, r, 2*r+1) 89 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 90 | 91 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 92 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 93 | coords_lvl = centroid_lvl + delta_lvl 94 | 95 | corr = bilinear_sampler(corr, coords_lvl) 96 | corr = corr.view(batch, h1, w1, -1) 97 | out_pyramid.append(corr) 98 | 99 | out = torch.cat(out_pyramid, dim=-1) 100 | return out.permute(0, 3, 1, 2).contiguous().float() 101 | 102 | #@staticmethod 103 | def corr_compute(self, fmap1, fmap2, guid, reverse=True): 104 | batch, dim, ht, wd = fmap1.shape 105 | fmap1 = fmap1.view(batch, dim, ht*wd) 106 | fmap2 = fmap2.view(batch, dim, ht*wd) 107 | 108 | if reverse: 109 | corr = torch.matmul(fmap2.transpose(1,2), fmap1) / torch.sqrt(torch.tensor(dim).float()) 110 | corr = corr.view(batch, ht, wd, ht, wd) 111 | corr = self.nlf(corr, guid) 112 | corr = corr.permute(0, 3, 4, 1, 2) 113 | else: 114 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) / torch.sqrt(torch.tensor(dim).float()) 115 | corr = corr.view(batch, ht, wd, ht, wd) 116 | corr = self.nlf(corr, guid) 117 | 118 | return corr 119 | 120 | 121 | class AlternateCorrBlock: 122 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 123 | self.num_levels = num_levels 124 | self.radius = radius 125 | 126 | self.pyramid = [(fmap1, fmap2)] 127 | for i in range(self.num_levels): 128 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 129 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 130 | self.pyramid.append((fmap1, fmap2)) 131 | 132 | def __call__(self, coords): 133 | coords = coords.permute(0, 2, 3, 1) 134 | B, H, W, _ = coords.shape 135 | dim = self.pyramid[0][0].shape[1] 136 | 137 | corr_list = [] 138 | for i in range(self.num_levels): 139 | r = self.radius 140 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 141 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 142 | 143 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 144 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 145 | corr_list.append(corr.squeeze(1)) 146 | 147 | corr = torch.stack(corr_list, dim=1) 148 | corr = corr.reshape(B, -1, H, W) 149 | return corr / torch.sqrt(torch.tensor(dim).float()) 150 | 151 | class CorrBlock1D: 152 | def __init__(self, corr1, corr2, num_levels=4, radius=4): 153 | self.num_levels = num_levels 154 | self.radius = radius 155 | self.corr_pyramid1 = [] 156 | self.corr_pyramid2 = [] 157 | 158 | corr1 = corr1.permute(0,3,4,1,2) 159 | corr2 = corr2.permute(0,3,4,1,2) 160 | batch, h1, w1, dim, w2 = 
corr1.shape 161 | batch, h1, w1, dim, h2 = corr2.shape 162 | assert(corr1.shape[:-1] == corr2.shape[:-1]) 163 | assert(h1 == h2 and w1 == w2) 164 | 165 | #self.coords = coords_grid(batch, h2, w2).to(corr1.device) 166 | 167 | corr1 = corr1.reshape(batch*h1*w1, dim, 1, w2) 168 | corr2 = corr2.reshape(batch*h1*w1, dim, 1, h2) 169 | 170 | self.corr_pyramid1.append(corr1) 171 | self.corr_pyramid2.append(corr2) 172 | for i in range(self.num_levels): 173 | corr1 = F.avg_pool2d(corr1, [1,2], stride=[1,2]) 174 | self.corr_pyramid1.append(corr1) 175 | corr2 = F.avg_pool2d(corr2, [1,2], stride=[1,2]) 176 | self.corr_pyramid2.append(corr2) 177 | #print(corr1.shape, corr1.mean().item(), corr2.shape, corr2.mean().item()) 178 | def bilinear_sampler(self, img, coords, mode='bilinear', mask=False): 179 | """ Wrapper for grid_sample, uses pixel coordinates """ 180 | H, W = img.shape[-2:] 181 | xgrid, ygrid = coords.split([1,1], dim=-1) 182 | xgrid = 2*xgrid/(W-1) - 1 183 | assert torch.unique(ygrid).numel() == 1 and H == 1 # This is a stereo problem 184 | 185 | grid = torch.cat([xgrid, ygrid], dim=-1) 186 | img = F.grid_sample(img, grid, align_corners=True) 187 | 188 | if mask: 189 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 190 | return img, mask.float() 191 | 192 | return img 193 | 194 | def __call__(self, coords): 195 | coords_org = coords.clone() 196 | coords = coords_org[:, :1, :, :] 197 | coords = coords.permute(0, 2, 3, 1) 198 | r = self.radius 199 | batch, h1, w1, _ = coords.shape 200 | 201 | out_pyramid = [] 202 | for i in range(self.num_levels): 203 | corr = self.corr_pyramid1[i] 204 | dx = torch.linspace(-r, r, 2*r+1) 205 | dx = dx.view(1, 1, 2*r+1, 1).to(coords.device) 206 | x0 = dx + coords.reshape(batch*h1*w1, 1, 1, 1) / 2**i 207 | y0 = torch.zeros_like(x0) 208 | 209 | coords_lvl = torch.cat([x0,y0], dim=-1) 210 | coords_lvl = torch.clamp(coords_lvl, -1, 1) 211 | #print("corri:", corr.shape, corr.mean().item(), coords_lvl.shape, coords_lvl.mean().item()) 212 | corr = self.bilinear_sampler(corr, coords_lvl) 213 | #print("corri:", corr.shape, corr.mean().item()) 214 | corr = corr.view(batch, h1, w1, -1) 215 | #print("corri:", corr.shape, corr.mean().item()) 216 | out_pyramid.append(corr) 217 | 218 | out = torch.cat(out_pyramid, dim=-1) 219 | out1 = out.permute(0, 3, 1, 2).contiguous().float() 220 | 221 | coords = coords_org[:, 1:, :, :] 222 | coords = coords.permute(0, 2, 3, 1) 223 | r = self.radius 224 | batch, h1, w1, _ = coords.shape 225 | 226 | out_pyramid = [] 227 | for i in range(self.num_levels): 228 | corr = self.corr_pyramid2[i] 229 | dx = torch.linspace(-r, r, 2*r+1) 230 | dx = dx.view(1, 1, 2*r+1, 1).to(coords.device) 231 | x0 = dx + coords.reshape(batch*h1*w1, 1, 1, 1) / 2**i 232 | y0 = torch.zeros_like(x0) 233 | 234 | coords_lvl = torch.cat([x0, y0], dim=-1) 235 | corr = self.bilinear_sampler(corr, coords_lvl) 236 | corr = corr.view(batch, h1, w1, -1) 237 | out_pyramid.append(corr) 238 | 239 | out = torch.cat(out_pyramid, dim=-1) 240 | out2 = out.permute(0, 3, 1, 2).contiguous().float() 241 | return out1, out2 242 | -------------------------------------------------------------------------------- /core/cost_agg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | from libs.GANet.modules.GANet import DisparityRegression 5 | from libs.GANet.modules.GANet import MyNormalize 6 | from libs.GANet.modules.GANet import GetWeights, GetFilters 7 | from 
libs.GANet.modules.GANet import SGA 8 | from libs.GANet.modules.GANet import NLFIter 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | import numpy as np 12 | 13 | 14 | class DomainNorm2(nn.Module): 15 | def __init__(self, channel, l2=True): 16 | super(DomainNorm2, self).__init__() 17 | self.normalize = nn.InstanceNorm2d(num_features=channel, affine=False) 18 | self.l2 = l2 19 | self.weight = nn.Parameter(torch.ones(1,channel,1,1)) 20 | self.bias = nn.Parameter(torch.zeros(1,channel,1,1)) 21 | self.weight.requires_grad = True 22 | self.bias.requires_grad = True 23 | def forward(self, x): 24 | x = self.normalize(x) 25 | if self.l2: 26 | x = F.normalize(x, p=2, dim=1) 27 | return x * self.weight + self.bias 28 | 29 | class DomainNorm(nn.Module): 30 | def __init__(self, channel, l2=True): 31 | super(DomainNorm, self).__init__() 32 | self.normalize = nn.InstanceNorm2d(num_features=channel, affine=True) 33 | self.l2 = l2 34 | def forward(self, x): 35 | if self.l2: 36 | x = F.normalize(x, p=2, dim=1) 37 | x = self.normalize(x) 38 | return x 39 | 40 | 41 | class BasicConv(nn.Module): 42 | 43 | def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, l2=True, relu=True, **kwargs): 44 | super(BasicConv, self).__init__() 45 | # print(in_channels, out_channels, deconv, is_3d, bn, relu, kwargs) 46 | self.relu = relu 47 | self.use_bn = bn 48 | self.l2 = l2 49 | if is_3d: 50 | if deconv: 51 | self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs) 52 | else: 53 | self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs) 54 | self.bn = nn.BatchNorm3d(out_channels) 55 | else: 56 | if deconv: 57 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs) 58 | else: 59 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 60 | self.bn = DomainNorm(channel=out_channels, l2=self.l2) 61 | # self.bn = nn.InstanceNorm2d(out_channels) 62 | 63 | def forward(self, x): 64 | x = self.conv(x) 65 | if self.use_bn: 66 | x = self.bn(x) 67 | if self.relu: 68 | x = F.relu(x, inplace=True) 69 | return x 70 | 71 | 72 | class Conv2x(nn.Module): 73 | 74 | def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, bn=True, relu=True, kernel=None): 75 | super(Conv2x, self).__init__() 76 | self.concat = concat 77 | if kernel is not None: 78 | self.kernel = kernel 79 | #elif deconv and is_3d: 80 | # kernel = (4, 4, 4) 81 | elif deconv: 82 | kernel = 4 83 | else: 84 | kernel = 3 85 | self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, stride=2, padding=1) 86 | 87 | if self.concat: 88 | self.conv2 = BasicConv(out_channels*2, out_channels, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1) 89 | else: 90 | self.conv2 = BasicConv(out_channels, out_channels, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1) 91 | 92 | def forward(self, x, rem): 93 | x = self.conv1(x) 94 | #print(x.shape, rem.shape) 95 | assert(x.size() == rem.size()),[x.size(), rem.size()] 96 | if self.concat: 97 | x = torch.cat((x, rem), 1) 98 | else: 99 | x = x + rem 100 | x = self.conv2(x) 101 | return x 102 | class SGABlock(nn.Module): 103 | def __init__(self, channels=32, refine=False): 104 | super(SGABlock, self).__init__() 105 | self.refine = refine 106 | if self.refine: 107 | self.bn_relu = nn.Sequential(nn.BatchNorm3d(channels), 108 | nn.ReLU(inplace=True)) 109 | self.conv_refine = BasicConv(channels, channels, is_3d=True, 
kernel_size=3, padding=1, relu=False) 110 | # self.conv_refine1 = BasicConv(8, 8, is_3d=True, kernel_size=1, padding=1) 111 | else: 112 | self.bn = nn.BatchNorm3d(channels) 113 | self.SGA=SGA() 114 | self.relu = nn.ReLU(inplace=True) 115 | def forward(self, x, g): 116 | rem = x 117 | k1, k2, k3, k4 = torch.split(g, (5, 5, 5, 5), 1) 118 | k1 = F.normalize(k1, p=1, dim=1) 119 | k2 = F.normalize(k2, p=1, dim=1) 120 | k3 = F.normalize(k3, p=1, dim=1) 121 | k4 = F.normalize(k4, p=1, dim=1) 122 | x = self.SGA(x, k1, k2, k3, k4) 123 | if self.refine: 124 | x = self.bn_relu(x) 125 | x = self.conv_refine(x) 126 | else: 127 | x = self.bn(x) 128 | assert(x.size() == rem.size()) 129 | x += rem 130 | return self.relu(x) 131 | 132 | def freeze_layers(model): 133 | for child in model.children(): 134 | for param in child.parameters(): 135 | param.requires_grad = False 136 | for param in model.parameters(): 137 | param.requires_grad = False 138 | 139 | class ShiftRegression(nn.Module): 140 | def __init__(self, max_shift=192): 141 | super(ShiftRegression, self).__init__() 142 | self.max_shift = max_shift 143 | # self.disp = Variable(torch.Tensor(np.reshape(np.array(range(self.maxdisp)),[1,self.maxdisp,1,1])).cuda(), requires_grad=False) 144 | 145 | def forward(self, x, max_shift=None): 146 | if max_shift is not None: 147 | self.max_shift = max_shift 148 | assert(x.is_contiguous() == True) 149 | with torch.cuda.device_of(x): 150 | shift = Variable(torch.Tensor(np.reshape(np.array(range(-self.max_shift, self.max_shift+1)),[1,self.max_shift*2+1,1,1])).cuda(), requires_grad=False) 151 | shift = shift.repeat(x.size()[0],1,x.size()[2],x.size()[3]) 152 | out = torch.sum(x*shift,dim=1,keepdim=True) 153 | return out 154 | class ShiftEstimate(nn.Module): 155 | 156 | def __init__(self, max_shift=192, InChannel=24): 157 | super(ShiftEstimate, self).__init__() 158 | self.max_shift = int(max_shift/2) 159 | self.softmax = nn.Softmax(dim=1) 160 | self.regression = ShiftRegression(max_shift=self.max_shift+1) 161 | self.conv3d_2d = nn.Conv3d(InChannel, 1, (3, 3, 3), (1, 1, 1), (1, 1, 1), bias=True) 162 | #self.upsample_cost = FilterUpsample() 163 | def upsample_flow(self, flow, mask): 164 | """ Upsample flow field [H/3, W/3, 2] -> [H, W, 2] using convex combination """ 165 | N, _, H, W = flow.shape 166 | mask = mask.view(N, 1, 9, 3, 3, H, W) 167 | mask = torch.softmax(mask, dim=2) 168 | 169 | up_flow = F.unfold(3 * flow, [3,3], padding=1) 170 | up_flow = up_flow.view(N, 1, 9, 1, 1, H, W) 171 | 172 | up_flow = torch.sum(mask * up_flow, dim=2) 173 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 174 | return up_flow.reshape(N, 1, 3*H, 3*W) 175 | 176 | def forward(self, x): 177 | #N, _, _, H, W = g.size() 178 | #assert (x.size(3)==H and x.size(4)==W) 179 | #x = self.upsample_cost(x, g) 180 | #print(x.size(), g.size()) 181 | x = F.interpolate(self.conv3d_2d(x), [self.max_shift*2+1, x.size()[3]*4, x.size()[4]*4], mode='trilinear', align_corners=True) 182 | x = torch.squeeze(x, 1) 183 | x = self.softmax(x) 184 | x = self.regression(x) 185 | x = F.interpolate(x, [x.size()[2]*2, x.size()[3]*2], mode='bilinear', align_corners=True) 186 | return x * 2 187 | 188 | class ShiftEstimate2(nn.Module): 189 | 190 | def __init__(self, max_shift=100, InChannel=24): 191 | super(ShiftEstimate2, self).__init__() 192 | self.max_shift = int(max_shift//4) 193 | self.softmax = nn.Softmax(dim=1) 194 | self.regression = ShiftRegression() 195 | self.conv3d_2d = nn.Conv3d(InChannel, 1, kernel_size=3, stride=1, padding=1, bias=True) 196 | #self.upsample_cost 
= FilterUpsample() 197 | 198 | def forward(self, x, max_shift=None): 199 | if max_shift is not None: 200 | assert ((max_shift//8 * 2 + 1) == x.shape[2]),[x.shape, max_shift, max_shift//8*2+1] 201 | #assert(x.size() == rem.size()),[x.size(), rem.size()] 202 | self.max_shift = max_shift // 4 203 | x = F.interpolate(self.conv3d_2d(x), [self.max_shift*2+1, x.size()[3]*2, x.size()[4]*2], mode='trilinear', align_corners=True) 204 | # x = self.conv3d_2d(x) 205 | x = torch.squeeze(x, 1) 206 | 207 | x = self.softmax(x) 208 | x = self.regression(x, self.max_shift) 209 | x = F.interpolate(x, [x.size()[2]*4, x.size()[3]*4], mode='bilinear', align_corners=True) 210 | 211 | return x * 4 212 | 213 | 214 | class CostAggregation(nn.Module): 215 | def __init__(self, max_shift=400, in_channel=8): 216 | super(CostAggregation, self).__init__() 217 | self.max_shift = max_shift 218 | self.in_channel = in_channel #t(self.max_shift / 6) * 2 + 1 219 | self.inner_channel = 8 220 | self.conv0 = BasicConv(self.in_channel, self.inner_channel, is_3d=True, kernel_size=3, padding=1, relu=True) 221 | 222 | self.conv1a = BasicConv(self.inner_channel, self.inner_channel*2, is_3d=True, kernel_size=3, stride=2, padding=1) 223 | self.conv2a = BasicConv(self.inner_channel*2, self.inner_channel*4, is_3d=True, kernel_size=3, stride=2, padding=1) 224 | self.conv3a = BasicConv(self.inner_channel*4, self.inner_channel*6, is_3d=True, kernel_size=3, stride=2, padding=1) 225 | 226 | self.deconv1a = Conv2x(self.inner_channel*2, self.inner_channel, deconv=True, is_3d=True, relu=True) 227 | self.deconv2a = Conv2x(self.inner_channel*4, self.inner_channel*2, deconv=True, is_3d=True) 228 | self.deconv3a = Conv2x(self.inner_channel*6, self.inner_channel*4, deconv=True, is_3d=True) 229 | 230 | self.conv1b = BasicConv(self.inner_channel, self.inner_channel*2, is_3d=True, kernel_size=3, stride=2, padding=1) 231 | self.conv2b = BasicConv(self.inner_channel*2, self.inner_channel*4, is_3d=True, kernel_size=3, stride=2, padding=1) 232 | self.conv3b = BasicConv(self.inner_channel*4, self.inner_channel*6, is_3d=True, kernel_size=3, stride=2, padding=1) 233 | 234 | self.deconv1b = Conv2x(self.inner_channel*2, self.inner_channel, deconv=True, is_3d=True, relu=True, kernel=(3,4,4)) 235 | self.deconv2b = Conv2x(self.inner_channel*4, self.inner_channel*2, deconv=True, is_3d=True, kernel=(3,4,4)) 236 | self.deconv3b = Conv2x(self.inner_channel*6, self.inner_channel*4, deconv=True, is_3d=True, kernel=(3,4,4)) 237 | self.shift0= ShiftEstimate2(self.max_shift, self.inner_channel) 238 | self.shift1= ShiftEstimate2(self.max_shift, self.inner_channel) 239 | self.shift2= ShiftEstimate2(self.max_shift, self.inner_channel) 240 | self.sga1 = SGABlock(channels=self.inner_channel, refine=True) 241 | self.sga2 = SGABlock(channels=self.inner_channel, refine=True) 242 | self.sga3 = SGABlock(channels=self.inner_channel, refine=True) 243 | self.sga11 = SGABlock(channels=self.inner_channel*2, refine=True) 244 | self.sga12 = SGABlock(channels=self.inner_channel*2, refine=True) 245 | self.corr_output = BasicConv(self.inner_channel, 1, is_3d=True, kernel_size=3, padding=1, relu=False) 246 | self.corr2cost = Corr2Cost() 247 | def forward(self, x, g, max_shift=400, is_ux=True): 248 | x = self.conv0(x) 249 | x = self.sga1(x, g['sg1']) 250 | rem0 = x 251 | 252 | if self.training: 253 | cost = self.corr2cost(x, max_shift//8, is_ux) 254 | shift0 = self.shift0(cost, max_shift) 255 | 256 | x = self.conv1a(x) 257 | x = self.sga11(x, g['sg11']) 258 | rem1 = x 259 | x = self.conv2a(x) 260 
| rem2 = x 261 | x = self.conv3a(x) 262 | rem3 = x 263 | 264 | x = self.deconv3a(x, rem2) 265 | rem2 = x 266 | x = self.deconv2a(x, rem1) 267 | x = self.sga12(x, g['sg12']) 268 | rem1 = x 269 | x = self.deconv1a(x, rem0) 270 | x = self.sga2(x, g['sg2']) 271 | rem0 = x 272 | cost = self.corr2cost(x, max_shift//8, is_ux) 273 | if self.training: 274 | shift1 = self.shift1(cost, max_shift) 275 | corr = self.corr_output(x) 276 | rem0 = cost 277 | x = self.conv1b(cost) 278 | rem1 = x 279 | x = self.conv2b(x) 280 | rem2 = x 281 | x = self.conv3b(x) 282 | x = self.deconv3b(x, rem2) 283 | x = self.deconv2b(x, rem1) 284 | x = self.deconv1b(x, rem0) 285 | x = self.sga3(x, g['sg3']) 286 | shift2 = self.shift2(x, max_shift) 287 | if self.training: 288 | return shift0, shift1, shift2, corr 289 | else: 290 | return shift2, corr 291 | 292 | class Corr2Cost(nn.Module): 293 | def __init__(self): 294 | super(Corr2Cost, self).__init__() 295 | def coords_grid(self, batch, ht, wd, device): 296 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) 297 | coords = torch.stack(coords[::-1], dim=0).float() 298 | return coords[None].repeat(batch, 1, 1, 1) 299 | def bilinear_sampler(self, img, coords, mode='bilinear', mask=False): 300 | """ Wrapper for grid_sample, uses pixel coordinates """ 301 | H, W = img.shape[-2:] 302 | xgrid, ygrid = coords.split([1,1], dim=-1) 303 | xgrid = 2*xgrid/(W-1) - 1 304 | assert torch.unique(ygrid).numel() == 1 and H == 1 # This is a stereo problem 305 | 306 | grid = torch.cat([xgrid, ygrid], dim=-1) 307 | img = F.grid_sample(img, grid, align_corners=True) 308 | 309 | if mask: 310 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 311 | return img, mask.float() 312 | 313 | return img 314 | 315 | def forward(self, corr, maxdisp=50, is_ux=True): 316 | batch, dim, d, h, w = corr.shape 317 | corr = corr.permute(0, 3, 4, 1, 2).reshape(batch*h*w, dim, 1, d) 318 | with torch.no_grad(): 319 | coords = self.coords_grid(batch, h, w, corr.device) 320 | if is_ux: 321 | coords = coords[:, :1, :, :] 322 | else: 323 | coords = coords[:, 1:, :, :] 324 | dx = torch.linspace(-maxdisp, maxdisp, maxdisp*2+1) 325 | dx = dx.view(1, 1, 2*maxdisp+1, 1).to(corr.device) 326 | x0 = dx + coords.reshape(batch*h*w, 1, 1, 1) 327 | y0 = torch.zeros_like(x0) 328 | # if is_ux: 329 | coords_lvl = torch.cat([x0,y0], dim=-1) 330 | # else: 331 | # coords_lvl = torch.cat([y0, x0], dim=-1) 332 | corr = self.bilinear_sampler(corr, coords_lvl) 333 | #print(corr.shape) 334 | corr = corr.view(batch, h, w, dim, maxdisp*2+1) 335 | corr = corr.permute(0, 3, 4, 1, 2).contiguous().float() 336 | return corr 337 | 338 | -------------------------------------------------------------------------------- /core/datasets.py: -------------------------------------------------------------------------------- 1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | import torch.nn.functional as F 7 | 8 | import os 9 | import math 10 | import random 11 | from glob import glob 12 | import os.path as osp 13 | 14 | from utils import frame_utils 15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor 16 | 17 | 18 | class FlowDataset(data.Dataset): 19 | def __init__(self, aug_params=None, sparse=False, vkitti=False): 20 | self.augmentor = None 21 | self.sparse = sparse 22 | self.vkitti = vkitti 23 | self.normalized = False 24 | if aug_params is not None: 25 | if sparse: 26 | self.augmentor = 
SparseFlowAugmentor(**aug_params, normalized=self.normalized) 27 | else: 28 | self.augmentor = FlowAugmentor(**aug_params, normalized=self.normalized) 29 | 30 | self.is_test = False 31 | self.init_seed = False 32 | self.flow_list = [] 33 | self.image_list = [] 34 | self.extra_info = [] 35 | def normalize(self, img): 36 | img = np.float32(img) 37 | r = img[:, :, 0] 38 | g = img[:, :, 1] 39 | b = img[:, :, 2] 40 | 41 | img[:, :, 0] = (r - np.mean(r[:])) / (np.std(r[:]) + 1e-6) 42 | img[:, :, 1] = (g - np.mean(g[:])) / (np.std(g[:]) + 1e-6) 43 | img[:, :, 2] = (b - np.mean(b[:])) / (np.std(b[:]) + 1e-6) 44 | return img 45 | 46 | def __getitem__(self, index): 47 | 48 | if self.is_test: 49 | img1 = frame_utils.read_gen(self.image_list[index][0]) 50 | img2 = frame_utils.read_gen(self.image_list[index][1]) 51 | img1 = np.array(img1).astype(np.uint8)[..., :3] 52 | img2 = np.array(img2).astype(np.uint8)[..., :3] 53 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 54 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 55 | return img1, img2, self.extra_info[index] 56 | 57 | if not self.init_seed: 58 | worker_info = torch.utils.data.get_worker_info() 59 | if worker_info is not None: 60 | torch.manual_seed(worker_info.id) 61 | np.random.seed(worker_info.id) 62 | random.seed(worker_info.id) 63 | self.init_seed = True 64 | 65 | index = index % len(self.image_list) 66 | valid = None 67 | if self.sparse: 68 | if self.vkitti: 69 | flow, valid = frame_utils.readFlowVKITTI(self.flow_list[index]) 70 | #flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) 71 | else: 72 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) 73 | else: 74 | flow = frame_utils.read_gen(self.flow_list[index]) 75 | 76 | img1 = frame_utils.read_gen(self.image_list[index][0]) 77 | img2 = frame_utils.read_gen(self.image_list[index][1]) 78 | 79 | flow = np.array(flow).astype(np.float32) 80 | img1 = np.array(img1).astype(np.uint8) 81 | img2 = np.array(img2).astype(np.uint8) 82 | 83 | # grayscale images 84 | if len(img1.shape) == 2: 85 | img1 = np.tile(img1[...,None], (1, 1, 3)) 86 | img2 = np.tile(img2[...,None], (1, 1, 3)) 87 | else: 88 | img1 = img1[..., :3] 89 | img2 = img2[..., :3] 90 | 91 | if self.augmentor is not None: 92 | if self.sparse: 93 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) 94 | else: 95 | img1, img2, flow = self.augmentor(img1, img2, flow) 96 | ##### normalize 97 | elif self.normalized: 98 | img1 = self.normalize(img1) 99 | img2 = self.normalize(img2) 100 | 101 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 102 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 103 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 104 | 105 | if valid is not None: 106 | valid = torch.from_numpy(valid) 107 | else: 108 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 109 | 110 | return img1, img2, flow, valid.float() 111 | 112 | 113 | def __rmul__(self, v): 114 | self.flow_list = v * self.flow_list 115 | self.image_list = v * self.image_list 116 | return self 117 | 118 | def __len__(self): 119 | return len(self.image_list) 120 | 121 | 122 | class MpiSintel(FlowDataset): 123 | def __init__(self, aug_params=None, split='training', root='/export/work/feihu/flow/Sintel', dstype='clean'): 124 | super(MpiSintel, self).__init__(aug_params) 125 | flow_root = osp.join(root, split, 'flow') 126 | image_root = osp.join(root, split, dstype) 127 | 128 | if split == 'test': 129 | self.is_test = True 130 | 131 | for scene in os.listdir(image_root): 132 | 
image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) 133 | for i in range(len(image_list)-1): 134 | self.image_list += [ [image_list[i], image_list[i+1]] ] 135 | self.extra_info += [ (scene, i) ] # scene and frame_id 136 | 137 | if split != 'test': 138 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) 139 | 140 | 141 | class FlyingChairs(FlowDataset): 142 | def __init__(self, aug_params=None, split='train', root='/export/work/feihu/flow/FlyingChairs_release/data'): 143 | super(FlyingChairs, self).__init__(aug_params) 144 | 145 | images = sorted(glob(osp.join(root, '*.ppm'))) 146 | flows = sorted(glob(osp.join(root, '*.flo'))) 147 | assert (len(images)//2 == len(flows)) 148 | 149 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) 150 | for i in range(len(flows)): 151 | xid = split_list[i] 152 | if (split=='training' and xid==1) or (split=='validation' and xid==2): 153 | self.flow_list += [ flows[i] ] 154 | self.image_list += [ [images[2*i], images[2*i+1]] ] 155 | 156 | 157 | class FlyingThings3D(FlowDataset): 158 | def __init__(self, aug_params=None, root='/export/work/feihu/flow/SceneFlow', dstype='frames_cleanpass'): 159 | super(FlyingThings3D, self).__init__(aug_params) 160 | 161 | for cam in ['left']: 162 | for direction in ['into_future', 'into_past']: 163 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) 164 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 165 | 166 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) 167 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 168 | 169 | for idir, fdir in zip(image_dirs, flow_dirs): 170 | images = sorted(glob(osp.join(idir, '*.png')) ) 171 | flows = sorted(glob(osp.join(fdir, '*.pfm')) ) 172 | for i in range(len(flows)-1): 173 | if direction == 'into_future': 174 | self.image_list += [ [images[i], images[i+1]] ] 175 | self.flow_list += [ flows[i] ] 176 | elif direction == 'into_past': 177 | self.image_list += [ [images[i+1], images[i]] ] 178 | self.flow_list += [ flows[i+1] ] 179 | 180 | 181 | class KITTI(FlowDataset): 182 | def __init__(self, aug_params=None, split='training', root='/export/work/feihu/KITTI'): 183 | super(KITTI, self).__init__(aug_params, sparse=True) 184 | if split == 'testing': 185 | self.is_test = True 186 | 187 | root = osp.join(root, split) 188 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) 189 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) 190 | 191 | for img1, img2 in zip(images1, images2): 192 | frame_id = img1.split('/')[-1] 193 | self.extra_info += [ [frame_id] ] 194 | self.image_list += [ [img1, img2] ] 195 | 196 | if split == 'training': 197 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) 198 | class KITTI2012(FlowDataset): 199 | def __init__(self, aug_params=None, split='training', root='/export/work/feihu/kitti2012'): 200 | super(KITTI2012, self).__init__(aug_params, sparse=True) 201 | if split == 'testing': 202 | self.is_test = True 203 | 204 | root = osp.join(root, split) 205 | images1 = sorted(glob(osp.join(root, 'colored_0/*_10.png'))) 206 | images2 = sorted(glob(osp.join(root, 'colored_0/*_11.png'))) 207 | 208 | for img1, img2 in zip(images1, images2): 209 | frame_id = img1.split('/')[-1] 210 | self.extra_info += [ [frame_id] ] 211 | self.image_list += [ [img1, img2] ] 212 | 213 | if split == 'training': 214 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) 215 | 216 | 217 | class VKITTI(FlowDataset): 218 | def __init__(self, 
aug_params=None, root='/export/work/feihu/vkitti', direction='forward', dstype='clone') 219 | super(VKITTI, self).__init__(aug_params, sparse=True, vkitti=True) 220 | 221 | for cam in ['Camera_0', 'Camera_1']: 222 | #for direction in ['forwardFlow', 'backwardFlow']: 223 | for direction in ['forwardFlow']: 224 | #for dstype in ['15-deg-left','15-deg-right','30-deg-left','30-deg-right','clone','fog','morning','overcast','rain', 'sunset']: 225 | for dstype in ['15-deg-left','15-deg-right','30-deg-left','30-deg-right','clone','morning','overcast', 'sunset']: 226 | image_dirs = sorted(glob(osp.join(root, 'Scene*/'+ dstype +'/frames/rgb'))) 227 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 228 | 229 | flow_dirs = sorted(glob(osp.join(root, 'correct/Scene*/'+ dstype +'/frames'))) 230 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 231 | 232 | for idir, fdir in zip(image_dirs, flow_dirs): 233 | images = sorted(glob(osp.join(idir, '*.jpg')) ) 234 | flows = sorted(glob(osp.join(fdir, '*.png')) ) 235 | for i in range(len(flows)): 236 | if direction == 'forwardFlow': 237 | self.image_list += [ [images[i], images[i+1]] ] 238 | self.flow_list += [ flows[i] ] 239 | elif direction == 'backwardFlow': 240 | self.image_list += [ [images[i+1], images[i]] ] 241 | self.flow_list += [ flows[i] ] 242 | 243 | class HD1K(FlowDataset): 244 | def __init__(self, aug_params=None, root='/export/work/feihu/flow/HD1K'): 245 | super(HD1K, self).__init__(aug_params, sparse=True) 246 | 247 | seq_ix = 0 248 | while True: 249 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 250 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 251 | 252 | if len(flows) == 0: 253 | break 254 | 255 | for i in range(len(flows)-1): 256 | self.flow_list += [flows[i]] 257 | self.image_list += [ [images[i], images[i+1]] ] 258 | 259 | seq_ix += 1 260 | class Cityscapes(FlowDataset): 261 | def __init__(self, aug_params=None, root='/export/work/feihu/cityscapes'): 262 | super(Cityscapes, self).__init__(aug_params, sparse=True) 263 | 264 | seq_ix = 0 265 | self.is_test = True 266 | images = sorted(glob(os.path.join(root, '*.png'))) 267 | for i in range(len(images)-1): 268 | frame_id = images[i].split('/')[-1] 269 | self.extra_info += [ [frame_id] ] 270 | self.flow_list += [images[i]] 271 | self.image_list += [ [images[i], images[i+1]] ] 272 | 273 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): 274 | """ Create the data loader for the corresponding training set """ 275 | 276 | if args.stage == 'chairs': 277 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} 278 | train_dataset = FlyingChairs(aug_params, split='training') 279 | 280 | elif args.stage == 'things': 281 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} 282 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') 283 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') 284 | train_dataset = clean_dataset + final_dataset 285 | 286 | elif args.stage == 'sintel': 287 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} 288 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass') 289 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') 290 | sintel_final = MpiSintel(aug_params, split='training', dstype='final') 291 | 292 | if TRAIN_DS == 'C+T+K+S+H': 293 | kitti = 
KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}, root="/export/work/feihu/kitti2015") 294 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) 295 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things 296 | 297 | elif TRAIN_DS == 'C+T+K/S': 298 | train_dataset = 100*sintel_clean + 100*sintel_final + things 299 | else: 300 | train_dataset = sintel_clean + sintel_final 301 | 302 | elif args.stage == 'vkitti': 303 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} 304 | vkitti = VKITTI(aug_params) 305 | train_dataset = vkitti 306 | 307 | elif args.stage == 'kitti': 308 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} 309 | train_dataset = 200 * KITTI(aug_params, split='training', root="/export/work/feihu/kitti2015") 310 | elif args.stage == 'kitti2012': 311 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} 312 | train_dataset = 200 * KITTI2012(aug_params, split='training', root="/export/work/feihu/kitti2012") 313 | print('Training with %d image pairs' % len(train_dataset)) 314 | return train_dataset 315 | 316 | 317 | -------------------------------------------------------------------------------- /core/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.utils import DomainNorm 5 | #from apex.parallel import SyncBatchNorm as BatchNorm 6 | #import torch.nn.BatchNorm2d as BatchNorm 7 | 8 | 9 | 10 | class ResidualBlock(nn.Module): 11 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 12 | super(ResidualBlock, self).__init__() 13 | 14 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 16 | self.relu = nn.ReLU(inplace=True) 17 | 18 | num_groups = planes // 8 19 | 20 | if norm_fn == 'group': 21 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 22 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 23 | if not stride == 1: 24 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 25 | 26 | elif norm_fn == 'batch': 27 | self.norm1 = nn.BatchNorm2d(planes) 28 | self.norm2 = nn.BatchNorm2d(planes) 29 | if not stride == 1: 30 | self.norm3 = nn.BatchNorm2d(planes) 31 | 32 | elif norm_fn == 'instance': 33 | self.norm1 = nn.InstanceNorm2d(planes) 34 | self.norm2 = nn.InstanceNorm2d(planes) 35 | if not stride == 1: 36 | self.norm3 = nn.InstanceNorm2d(planes) 37 | elif norm_fn == 'domain': 38 | self.norm1 = DomainNorm(planes) 39 | self.norm2 = DomainNorm(planes) 40 | if not stride == 1: 41 | self.norm3 = DomainNorm(planes) 42 | elif norm_fn == 'sync': 43 | self.norm1 = nn.BatchNorm2d(planes) 44 | self.norm2 = nn.BatchNorm2d(planes) 45 | if not stride == 1: 46 | self.norm3 = nn.BatchNorm2d(planes) 47 | 48 | elif norm_fn == 'none': 49 | self.norm1 = nn.Sequential() 50 | self.norm2 = nn.Sequential() 51 | if not stride == 1: 52 | self.norm3 = nn.Sequential() 53 | 54 | if stride == 1: 55 | self.downsample = None 56 | 57 | else: 58 | self.downsample = nn.Sequential( 59 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 60 | 61 | 62 | def forward(self, x): 63 | y = x 64 | y = self.relu(self.norm1(self.conv1(y))) 65 
| y = self.relu(self.norm2(self.conv2(y))) 66 | 67 | if self.downsample is not None: 68 | x = self.downsample(x) 69 | 70 | return self.relu(x+y) 71 | 72 | 73 | 74 | class BottleneckBlock(nn.Module): 75 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 76 | super(BottleneckBlock, self).__init__() 77 | 78 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 79 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 80 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 81 | self.relu = nn.ReLU(inplace=True) 82 | 83 | num_groups = planes // 8 84 | 85 | if norm_fn == 'group': 86 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 87 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 88 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 89 | if not stride == 1: 90 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 91 | 92 | elif norm_fn == 'batch': 93 | self.norm1 = nn.BatchNorm2d(planes//4) 94 | self.norm2 = nn.BatchNorm2d(planes//4) 95 | self.norm3 = nn.BatchNorm2d(planes) 96 | if not stride == 1: 97 | self.norm4 = nn.BatchNorm2d(planes) 98 | elif norm_fn == 'sync': 99 | self.norm1 = nn.BatchNorm2d(planes//4) 100 | self.norm2 = nn.BatchNorm2d(planes//4) 101 | self.norm3 = nn.BatchNorm2d(planes) 102 | if not stride == 1: 103 | self.norm4 = nn.BatchNorm2d(planes) 104 | 105 | elif norm_fn == 'instance': 106 | self.norm1 = nn.InstanceNorm2d(planes//4) 107 | self.norm2 = nn.InstanceNorm2d(planes//4) 108 | self.norm3 = nn.InstanceNorm2d(planes) 109 | if not stride == 1: 110 | self.norm4 = nn.InstanceNorm2d(planes) 111 | elif norm_fn == 'domain': 112 | self.norm1 = DomainNorm(planes//4) 113 | self.norm2 = DomainNorm(planes//4) 114 | self.norm3 = DomainNorm(planes) 115 | if not stride == 1: 116 | self.norm4 = DomainNorm(planes) 117 | 118 | elif norm_fn == 'none': 119 | self.norm1 = nn.Sequential() 120 | self.norm2 = nn.Sequential() 121 | self.norm3 = nn.Sequential() 122 | if not stride == 1: 123 | self.norm4 = nn.Sequential() 124 | 125 | if stride == 1: 126 | self.downsample = None 127 | 128 | else: 129 | self.downsample = nn.Sequential( 130 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 131 | 132 | 133 | def forward(self, x): 134 | y = x 135 | y = self.relu(self.norm1(self.conv1(y))) 136 | y = self.relu(self.norm2(self.conv2(y))) 137 | y = self.relu(self.norm3(self.conv3(y))) 138 | 139 | if self.downsample is not None: 140 | x = self.downsample(x) 141 | 142 | return self.relu(x+y) 143 | 144 | class BasicEncoder(nn.Module): 145 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 146 | super(BasicEncoder, self).__init__() 147 | self.norm_fn = norm_fn 148 | 149 | if self.norm_fn == 'group': 150 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 151 | 152 | elif self.norm_fn == 'batch': 153 | self.norm1 = nn.BatchNorm2d(64) 154 | elif self.norm_fn == 'sync': 155 | self.norm1 = nn.BatchNorm2d(64) 156 | 157 | elif self.norm_fn == 'instance': 158 | self.norm1 = nn.InstanceNorm2d(64) 159 | 160 | elif norm_fn == 'domain': 161 | self.norm1 = DomainNorm(64) 162 | 163 | elif self.norm_fn == 'none': 164 | self.norm1 = nn.Sequential() 165 | 166 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 167 | self.relu1 = nn.ReLU(inplace=True) 168 | 169 | self.in_planes = 64 170 | self.layer1 = self._make_layer(64, stride=1) 171 | self.layer2 = self._make_layer(96, stride=2) 172 | 
self.layer3 = self._make_layer(128, stride=2) 173 | 174 | # output convolution 175 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 176 | 177 | self.dropout = None 178 | if dropout > 0: 179 | self.dropout = nn.Dropout2d(p=dropout) 180 | 181 | for m in self.modules(): 182 | if isinstance(m, nn.Conv2d): 183 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 184 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 185 | if m.weight is not None: 186 | nn.init.constant_(m.weight, 1) 187 | if m.bias is not None: 188 | nn.init.constant_(m.bias, 0) 189 | 190 | def _make_layer(self, dim, stride=1): 191 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 192 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 193 | layers = (layer1, layer2) 194 | 195 | self.in_planes = dim 196 | return nn.Sequential(*layers) 197 | 198 | 199 | def forward(self, x): 200 | 201 | # if input is list, combine batch dimension 202 | is_list = isinstance(x, tuple) or isinstance(x, list) 203 | if is_list: 204 | batch_dim = x[0].shape[0] 205 | x = torch.cat(x, dim=0) 206 | 207 | x = self.conv1(x) 208 | x = self.norm1(x) 209 | x = self.relu1(x) 210 | 211 | #print(x.shape, x.mean().item()) 212 | x = self.layer1(x) 213 | #print(x.shape, x.mean().item()) 214 | x = self.layer2(x) 215 | #print(x.shape, x.mean().item()) 216 | x = self.layer3(x) 217 | #print(x.shape, x.mean().item()) 218 | 219 | x = self.conv2(x) 220 | #print(x.shape, x.mean().item()) 221 | 222 | if self.training and self.dropout is not None: 223 | x = self.dropout(x) 224 | 225 | if is_list: 226 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 227 | 228 | return x 229 | 230 | 231 | class SmallEncoder(nn.Module): 232 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 233 | super(SmallEncoder, self).__init__() 234 | self.norm_fn = norm_fn 235 | 236 | if self.norm_fn == 'group': 237 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 238 | 239 | elif self.norm_fn == 'batch': 240 | self.norm1 = nn.BatchNorm2d(32) 241 | 242 | elif self.norm_fn == 'instance': 243 | self.norm1 = nn.InstanceNorm2d(32) 244 | 245 | elif self.norm_fn == 'none': 246 | self.norm1 = nn.Sequential() 247 | 248 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) 249 | self.relu1 = nn.ReLU(inplace=True) 250 | 251 | self.in_planes = 32 252 | self.layer1 = self._make_layer(32, stride=1) 253 | self.layer2 = self._make_layer(64, stride=2) 254 | self.layer3 = self._make_layer(96, stride=2) 255 | 256 | self.dropout = None 257 | if dropout > 0: 258 | self.dropout = nn.Dropout2d(p=dropout) 259 | 260 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 261 | 262 | for m in self.modules(): 263 | if isinstance(m, nn.Conv2d): 264 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 265 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 266 | if m.weight is not None: 267 | nn.init.constant_(m.weight, 1) 268 | if m.bias is not None: 269 | nn.init.constant_(m.bias, 0) 270 | 271 | def _make_layer(self, dim, stride=1): 272 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) 273 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 274 | layers = (layer1, layer2) 275 | 276 | self.in_planes = dim 277 | return nn.Sequential(*layers) 278 | 279 | 280 | def forward(self, x): 281 | 282 | # if input is list, combine batch dimension 283 | is_list = isinstance(x, tuple) or isinstance(x, list) 284 | if is_list: 285 | batch_dim = 
x[0].shape[0] 286 | x = torch.cat(x, dim=0) 287 | 288 | x = self.conv1(x) 289 | x = self.norm1(x) 290 | x = self.relu1(x) 291 | #print(x.shape, x.mean().item()) 292 | 293 | x = self.layer1(x) 294 | #print(x.shape, x.mean().item()) 295 | x = self.layer2(x) 296 | #print(x.shape, x.mean().item()) 297 | x = self.layer3(x) 298 | #print(x.shape, x.mean().item()) 299 | x = self.conv2(x) 300 | #print(x.shape, x.mean().item()) 301 | 302 | if self.training and self.dropout is not None: 303 | x = self.dropout(x) 304 | 305 | if is_list: 306 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 307 | 308 | return x 309 | -------------------------------------------------------------------------------- /core/sepflow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from update import BasicUpdateBlock, SmallUpdateBlock 7 | from extractor import BasicEncoder, SmallEncoder 8 | from corr import CorrBlock, CorrBlock1D 9 | from cost_agg import CostAggregation 10 | from utils.utils import bilinear_sampler, coords_grid, upflow8 11 | 12 | try: 13 | autocast = torch.cuda.amp.autocast 14 | except AttributeError: 15 | # dummy autocast for PyTorch < 1.6 16 | class autocast: 17 | def __init__(self, enabled): 18 | pass 19 | def __enter__(self): 20 | pass 21 | def __exit__(self, *args): 22 | pass 23 | class Guidance(nn.Module): 24 | def __init__(self, channels=32, refine=False): 25 | super(Guidance, self).__init__() 26 | self.bn_relu = nn.Sequential(nn.InstanceNorm2d(channels), 27 | nn.ReLU(inplace=True)) 28 | self.conv0 = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1), 29 | nn.InstanceNorm2d(16), 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d(16, int(channels/4), kernel_size=3, stride=2, padding=1), 32 | nn.InstanceNorm2d(int(channels/4)), 33 | nn.ReLU(inplace=True), 34 | nn.Conv2d(int(channels/4), int(channels/2), kernel_size=3, stride=2, padding=1), 35 | nn.InstanceNorm2d(int(channels/2)), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(int(channels/2), channels, kernel_size=3, stride=2, padding=1), 38 | nn.InstanceNorm2d(channels), 39 | nn.ReLU(inplace=True)) 40 | inner_channels = channels // 4 41 | self.wsize = 20 42 | self.conv1 = nn.Sequential(nn.Conv2d(channels*2, inner_channels, kernel_size=3, padding=1), 43 | nn.InstanceNorm2d(inner_channels), 44 | nn.ReLU(inplace=True)) 45 | self.conv2 = nn.Sequential(nn.Conv2d(inner_channels, inner_channels, kernel_size=3, padding=1), 46 | nn.InstanceNorm2d(inner_channels), 47 | nn.ReLU(inplace=True), 48 | nn.Conv2d(inner_channels, inner_channels, kernel_size=3, stride=1, padding=1), 49 | nn.InstanceNorm2d(inner_channels), 50 | nn.ReLU(inplace=True)) 51 | self.conv3 = nn.Sequential(nn.Conv2d(inner_channels, inner_channels, kernel_size=3, padding=1), 52 | nn.InstanceNorm2d(inner_channels), 53 | nn.ReLU(inplace=True), 54 | nn.Conv2d(inner_channels, inner_channels, kernel_size=3, stride=1, padding=1), 55 | nn.InstanceNorm2d(inner_channels), 56 | nn.ReLU(inplace=True)) 57 | self.conv11 = nn.Sequential(nn.Conv2d(inner_channels, inner_channels*2, kernel_size=3, stride=2, padding=1), 58 | nn.InstanceNorm2d(inner_channels*2), 59 | nn.ReLU(inplace=True)) 60 | self.conv12 = nn.Sequential(nn.Conv2d(inner_channels*2, inner_channels*2, kernel_size=3, stride=1, padding=1), 61 | nn.InstanceNorm2d(inner_channels*2), 62 | nn.ReLU(inplace=True), 63 | nn.Conv2d(inner_channels*2, inner_channels*2, kernel_size=3, stride=1, padding=1), 64 | 
nn.InstanceNorm2d(inner_channels*2), 65 | nn.ReLU(inplace=True)) 66 | self.weights = nn.Sequential(nn.Conv2d(inner_channels, inner_channels, kernel_size=3, padding=1), 67 | nn.InstanceNorm2d(inner_channels), 68 | nn.ReLU(inplace=True), 69 | nn.Conv2d(inner_channels, self.wsize, kernel_size=3, stride=1, padding=1)) 70 | self.weight_sg1 = nn.Sequential(nn.Conv2d(inner_channels, inner_channels, kernel_size=3, padding=1), 71 | nn.InstanceNorm2d(inner_channels), 72 | nn.ReLU(inplace=True), 73 | nn.Conv2d(inner_channels, self.wsize*2, kernel_size=3, stride=1, padding=1)) 74 | self.weight_sg2 = nn.Sequential(nn.Conv2d(inner_channels, inner_channels, kernel_size=3, padding=1), 75 | nn.InstanceNorm2d(inner_channels), 76 | nn.ReLU(inplace=True), 77 | nn.Conv2d(inner_channels, self.wsize*2, kernel_size=3, stride=1, padding=1)) 78 | self.weight_sg11 = nn.Sequential(nn.Conv2d(inner_channels*2, inner_channels*2, kernel_size=3, padding=1), 79 | nn.InstanceNorm2d(inner_channels*2), 80 | nn.ReLU(inplace=True), 81 | nn.Conv2d(inner_channels*2, self.wsize*2, kernel_size=3, stride=1, padding=1)) 82 | self.weight_sg12 = nn.Sequential(nn.Conv2d(inner_channels*2, inner_channels*2, kernel_size=3, padding=1), 83 | nn.InstanceNorm2d(inner_channels*2), 84 | nn.ReLU(inplace=True), 85 | nn.Conv2d(inner_channels*2, self.wsize*2, kernel_size=3, stride=1, padding=1)) 86 | self.weight_sg3 = nn.Sequential(nn.Conv2d(inner_channels, inner_channels, kernel_size=3, padding=1), 87 | nn.InstanceNorm2d(inner_channels), 88 | nn.ReLU(inplace=True), 89 | nn.Conv2d(inner_channels, self.wsize*2, kernel_size=3, stride=1, padding=1)) 90 | #self.getweights = nn.Sequential(GetFilters(radius=1), 91 | # nn.Conv2d(9, 20, kernel_size=1, stride=1, padding=0, bias=False)) 92 | 93 | 94 | 95 | def forward(self, fea, img): 96 | x = self.conv0(img) 97 | x = torch.cat((self.bn_relu(fea), x), 1) 98 | x = self.conv1(x) 99 | rem = x 100 | x = self.conv2(x) + rem 101 | rem = x 102 | guid = self.weights(x) 103 | x = self.conv3(x) + rem 104 | sg1 = self.weight_sg1(x) 105 | sg1_u, sg1_v = torch.split(sg1, (self.wsize, self.wsize), dim=1) 106 | sg2 = self.weight_sg2(x) 107 | sg2_u, sg2_v = torch.split(sg2, (self.wsize, self.wsize), dim=1) 108 | sg3 = self.weight_sg3(x) 109 | sg3_u, sg3_v = torch.split(sg3, (self.wsize, self.wsize), dim=1) 110 | x = self.conv11(x) 111 | rem = x 112 | x = self.conv12(x) + rem 113 | sg11 = self.weight_sg11(x) 114 | sg11_u, sg11_v = torch.split(sg11, (self.wsize, self.wsize), dim=1) 115 | sg12 = self.weight_sg12(x) 116 | sg12_u, sg12_v = torch.split(sg12, (self.wsize, self.wsize), dim=1) 117 | guid_u = dict([('sg1', sg1_u), 118 | ('sg2', sg2_u), 119 | ('sg3', sg3_u), 120 | ('sg11', sg11_u), 121 | ('sg12', sg12_u)]) 122 | guid_v = dict([('sg1', sg1_v), 123 | ('sg2', sg2_v), 124 | ('sg3', sg3_v), 125 | ('sg11', sg11_v), 126 | ('sg12', sg12_v)]) 127 | return guid, guid_u, guid_v 128 | 129 | 130 | class SepFlow(nn.Module): 131 | def __init__(self, args): 132 | super(SepFlow, self).__init__() 133 | self.args = args 134 | self.hidden_dim = hdim = 128 135 | self.context_dim = cdim = 128 136 | args.corr_levels = 4 137 | args.corr_radius = 4 138 | 139 | if 'dropout' not in self.args: 140 | self.args.dropout = 0 141 | 142 | if 'alternate_corr' not in self.args: 143 | self.args.alternate_corr = False 144 | 145 | # feature network, context network, and update block 146 | 147 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 148 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', 
dropout=args.dropout) 149 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 150 | self.guidance = Guidance(channels=256) 151 | self.cost_agg1 = CostAggregation(in_channel=8) 152 | self.cost_agg2 = CostAggregation(in_channel=8) 153 | 154 | def freeze_bn(self): 155 | count1, count2, count3 = 0, 0, 0 156 | for m in self.modules(): 157 | if isinstance(m, nn.SyncBatchNorm): 158 | count1 += 1 159 | m.eval() 160 | if isinstance(m, nn.BatchNorm2d): 161 | count2 += 1 162 | m.eval() 163 | if isinstance(m, nn.BatchNorm3d): 164 | count3 += 1 165 | #print(m) 166 | m.eval() 167 | #print(count1, count2, count3) 168 | #print(m) 169 | 170 | def initialize_flow(self, img): 171 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 172 | N, C, H, W = img.shape 173 | coords0 = coords_grid(N, H//8, W//8).to(img.device) 174 | coords1 = coords_grid(N, H//8, W//8).to(img.device) 175 | 176 | # optical flow computed as difference: flow = coords1 - coords0 177 | return coords0, coords1 178 | 179 | def upsample_flow(self, flow, mask): 180 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 181 | N, _, H, W = flow.shape 182 | mask = mask.view(N, 1, 9, 8, 8, H, W) 183 | mask = torch.softmax(mask, dim=2) 184 | 185 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 186 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 187 | 188 | up_flow = torch.sum(mask * up_flow, dim=2) 189 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 190 | return up_flow.reshape(N, 2, 8*H, 8*W) 191 | 192 | 193 | def forward(self, image1, image2, iters=12, upsample=True): 194 | """ Estimate optical flow between pair of frames """ 195 | 196 | image1 = 2 * (image1 / 255.0) - 1.0 197 | image2 = 2 * (image2 / 255.0) - 1.0 198 | 199 | image1 = image1.contiguous() 200 | image2 = image2.contiguous() 201 | 202 | hdim = self.hidden_dim 203 | cdim = self.context_dim 204 | fmap1 = self.fnet(image1) 205 | fmap2 = self.fnet(image2) 206 | 207 | fmap1 = fmap1.float() 208 | fmap2 = fmap2.float() 209 | guid, guid_u, guid_v = self.guidance(fmap1.detach(), image1) 210 | corr_fn = CorrBlock(fmap1, fmap2, guid, radius=self.args.corr_radius) 211 | 212 | cnet = self.cnet(image1) 213 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 214 | net = torch.tanh(net) 215 | inp = torch.relu(inp) 216 | corr1, corr2 = corr_fn(None, sep=True) 217 | coords0, coords1 = self.initialize_flow(image1) 218 | if self.training: 219 | u0, u1, flow_u, corr1 = self.cost_agg1(corr1, guid_u, max_shift=384, is_ux=True) 220 | v0, v1, flow_v, corr2 = self.cost_agg2(corr2, guid_v, max_shift=384, is_ux=False) 221 | flow_init = torch.cat((flow_u, flow_v), dim=1) 222 | 223 | flow_predictions = [] 224 | flow_predictions.append(torch.cat((u0, v0), dim=1)) 225 | flow_predictions.append(torch.cat((u1, v1), dim=1)) 226 | flow_predictions.append(flow_init) 227 | 228 | else: 229 | flow_u, corr1 = self.cost_agg1(corr1, guid_u, max_shift=384, is_ux=True) 230 | flow_v, corr2 = self.cost_agg2(corr2, guid_v, max_shift=384, is_ux=False) 231 | flow_init = torch.cat((flow_u, flow_v), dim=1) 232 | flow_init = F.interpolate(flow_init.detach()/8.0, [cnet.shape[2], cnet.shape[3]], mode='bilinear', align_corners=True) 233 | corr1d_fn = CorrBlock1D(corr1, corr2, radius=self.args.corr_radius) 234 | coords1 = coords1 + flow_init 235 | for itr in range(iters): 236 | coords1 = coords1.detach() 237 | 238 | corr = corr_fn(coords1) # index correlation volume 239 | corr1, corr2 = corr1d_fn(coords1) # index correlation volume 240 | flow = coords1 - 
coords0 241 | with autocast(enabled=self.args.mixed_precision): 242 | net, up_mask, delta_flow = self.update_block(net, inp, corr, corr1, corr2, flow) 243 | 244 | # F(t+1) = F(t) + \Delta(t) 245 | coords1 = coords1 + delta_flow 246 | 247 | # upsample predictions 248 | if up_mask is None: 249 | flow_up = upflow8(coords1 - coords0) 250 | else: 251 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 252 | if self.training: 253 | flow_predictions.append(flow_up) 254 | 255 | if self.training: 256 | return flow_predictions 257 | else: 258 | return coords1 - coords0, flow_up 259 | 260 | -------------------------------------------------------------------------------- /core/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self, args): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 
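# note (added comment, not in the original source): channel bookkeeping for the concatenation
# below -- convc1 maps the correlation volume to 96 channels and convf2 maps the flow features
# to 32, so cor_flo has the 128 input channels self.conv expects; self.conv outputs 80 channels
# and the raw 2-channel flow is appended, giving the 82-channel motion feature used by ConvGRU.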
75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self, args): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 83 | cor1_planes = args.corr_levels * (2*args.corr_radius + 1) 84 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 85 | self.convc11 = nn.Conv2d(cor1_planes, 64, 1, padding=0) 86 | self.convc12 = nn.Conv2d(cor1_planes, 64, 1, padding=0) 87 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 88 | self.convc21 = nn.Conv2d(64, 64, 3, padding=1) 89 | self.convc22 = nn.Conv2d(64, 64, 3, padding=1) 90 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 91 | #self.convf1 = nn.Conv2d(2, 128, 7, padding=3) # duplicate of the assignment above, disabled 92 | #self.convf1 = nn.Conv2d(2, 128, 7, padding=3) # duplicate of the assignment above, disabled 93 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 94 | self.conv = nn.Conv2d(64+192+64*2, 128-2, 3, padding=1) 95 | 96 | def forward(self, flow, corr, corr1, corr2): 97 | cor = F.relu(self.convc1(corr)) 98 | cor = F.relu(self.convc2(cor)) 99 | cor1 = F.relu(self.convc11(corr1)) 100 | cor1 = F.relu(self.convc21(cor1)) 101 | cor2 = F.relu(self.convc12(corr2)) 102 | cor2 = F.relu(self.convc22(cor2)) 103 | flo = F.relu(self.convf1(flow)) 104 | flo = F.relu(self.convf2(flo)) 105 | 106 | cor_flo = torch.cat([cor, cor1, cor2, flo], dim=1) 107 | out = F.relu(self.conv(cor_flo)) 108 | return torch.cat([out, flow], dim=1) 109 | 110 | class SmallUpdateBlock(nn.Module): 111 | def __init__(self, args, hidden_dim=96): 112 | super(SmallUpdateBlock, self).__init__() 113 | self.encoder = SmallMotionEncoder(args) 114 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 115 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 116 | 117 | def forward(self, net, inp, corr, flow): 118 | motion_features = self.encoder(flow, corr) 119 | inp = torch.cat([inp, motion_features], dim=1) 120 | net = self.gru(net, inp) 121 | delta_flow = self.flow_head(net) 122 | 123 | return net, None, delta_flow 124 | 125 | class BasicUpdateBlock(nn.Module): 126 | def __init__(self, args, hidden_dim=128, input_dim=128): 127 | super(BasicUpdateBlock, self).__init__() 128 | self.args = args 129 | self.encoder = BasicMotionEncoder(args) 130 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 131 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 132 | 133 | self.mask = nn.Sequential( 134 | nn.Conv2d(128, 256, 3, padding=1), 135 | nn.ReLU(inplace=True), 136 | nn.Conv2d(256, 64*9, 1, padding=0)) 137 | 138 | def forward(self, net, inp, corr, corr1, corr2, flow, upsample=True): 139 | motion_features = self.encoder(flow, corr, corr1, corr2) 140 | inp = torch.cat([inp, motion_features], dim=1) 141 | 142 | net = self.gru(net, inp) 143 | delta_flow = self.flow_head(net) 144 | 145 | # scale mask to balance gradients 146 | mask = .25 * self.mask(net) 147 | return net, mask, delta_flow 148 | 149 | 150 | 151 | -------------------------------------------------------------------------------- /core/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import math 4 | from PIL import Image 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | import torch 11 | from torchvision.transforms import ColorJitter 12 | import torch.nn.functional as F 13 | 14 | 15 | class FlowAugmentor: 16 | def __init__(self, crop_size, min_scale=-0.2, 
max_scale=0.5, do_flip=True, normalized=False): 17 | 18 | # spatial augmentation params 19 | self.crop_size = crop_size 20 | self.min_scale = min_scale 21 | self.max_scale = max_scale 22 | self.spatial_aug_prob = 0.8 23 | self.stretch_prob = 0.8 24 | self.max_stretch = 0.2 25 | self.normalized = normalized 26 | 27 | # flip augmentation params 28 | self.do_flip = do_flip 29 | self.h_flip_prob = 0.5 30 | self.v_flip_prob = 0.1 31 | 32 | # photometric augmentation params 33 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) 34 | self.asymmetric_color_aug_prob = 0.2 35 | self.eraser_aug_prob = 0.5 36 | 37 | def color_transform(self, img1, img2): 38 | """ Photometric augmentation """ 39 | 40 | # asymmetric 41 | if np.random.rand() < self.asymmetric_color_aug_prob: 42 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 43 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 44 | 45 | # symmetric 46 | else: 47 | image_stack = np.concatenate([img1, img2], axis=0) 48 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 49 | img1, img2 = np.split(image_stack, 2, axis=0) 50 | 51 | return img1, img2 52 | 53 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 54 | """ Occlusion augmentation """ 55 | 56 | ht, wd = img1.shape[:2] 57 | if np.random.rand() < self.eraser_aug_prob: 58 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 59 | for _ in range(np.random.randint(1, 3)): 60 | x0 = np.random.randint(0, wd) 61 | y0 = np.random.randint(0, ht) 62 | dx = np.random.randint(bounds[0], bounds[1]) 63 | dy = np.random.randint(bounds[0], bounds[1]) 64 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 65 | 66 | return img1, img2 67 | 68 | def padding(self, img1, img2, flow, valid=None): 69 | h, w, channel = img1.shape[:3] 70 | crop_height, crop_width = self.crop_size[0] + 4, self.crop_size[1] + 4 71 | border_h = max(crop_height - h, 0) 72 | border_w = max(crop_width - w, 0) 73 | pad0 = border_h // 2 # top 74 | pad0 = border_h 75 | pad1 = border_h - pad0 76 | pad2 = border_w // 2 # left 77 | pad3 = border_w - pad2 78 | img1 = cv2.copyMakeBorder(img1, pad0, pad1, pad2, pad3, cv2.BORDER_REPLICATE) 79 | img2 = cv2.copyMakeBorder(img2, pad0, pad1, pad2, pad3, cv2.BORDER_REPLICATE) 80 | flow = cv2.copyMakeBorder(flow, pad0, pad1, pad2, pad3, cv2.BORDER_CONSTANT, value=(1e4, 1e4)) 81 | if valid is not None: 82 | valid = cv2.copyMakeBorder(valid, pad0, pad1, pad2, pad3, cv2.BORDER_CONSTANT, value=(0)) 83 | return img1, img2, flow, valid 84 | return img1, img2, flow 85 | def spatial_transform(self, img1, img2, flow): 86 | # randomly sample scale 87 | ht, wd = img1.shape[:2] 88 | min_scale = np.maximum( 89 | (self.crop_size[0] + 8) / float(ht), 90 | (self.crop_size[1] + 8) / float(wd)) 91 | 92 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 93 | scale_x = scale 94 | scale_y = scale 95 | if np.random.rand() < self.stretch_prob: 96 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 97 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 98 | 99 | scale_x = np.clip(scale_x, min_scale, None) 100 | scale_y = np.clip(scale_y, min_scale, None) 101 | 102 | if np.random.rand() < self.spatial_aug_prob: 103 | # rescale the images 104 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 105 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 106 | flow = cv2.resize(flow, None, fx=scale_x, 
fy=scale_y, interpolation=cv2.INTER_LINEAR) 107 | flow = flow * [scale_x, scale_y] 108 | ##normalize 109 | if self.normalized: 110 | img1 = self.normalize(img1) 111 | img2 = self.normalize(img2) 112 | #img1_, img2_, flow_ = self.padding(img1, img2, flow, None) 113 | #print(np.fabs(img1_-img1).sum()) 114 | #print(np.fabs(img2_-img2).sum()) 115 | #print(np.fabs(flow_-flow).sum()) 116 | if img1.shape[0] <= self.crop_size[0] or img1.shape[1] <= self.crop_size[1]: 117 | img1, img2, flow = self.padding(img1, img2, flow, None) 118 | 119 | if self.do_flip: 120 | if np.random.rand() < self.h_flip_prob: # h-flip 121 | img1 = img1[:, ::-1] 122 | img2 = img2[:, ::-1] 123 | flow = flow[:, ::-1] * [-1.0, 1.0] 124 | 125 | if np.random.rand() < self.v_flip_prob: # v-flip 126 | img1 = img1[::-1, :] 127 | img2 = img2[::-1, :] 128 | flow = flow[::-1, :] * [1.0, -1.0] 129 | 130 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 131 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 132 | 133 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 134 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 135 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 136 | 137 | return img1, img2, flow 138 | 139 | def normalize(self, img): 140 | img = np.float32(img) 141 | r = img[:, :, 0] 142 | g = img[:, :, 1] 143 | b = img[:, :, 2] 144 | #img[:,:,0] = r 145 | #img[:,:,1] = g 146 | #img[:,:,2] = b 147 | 148 | img[:, :, 0] = (r - np.mean(r[:])) / (np.std(r[:]) + 1e-6) 149 | img[:, :, 1] = (g - np.mean(g[:])) / (np.std(g[:]) + 1e-6) 150 | img[:, :, 2] = (b - np.mean(b[:])) / (np.std(b[:]) + 1e-6) 151 | return img 152 | 153 | def __call__(self, img1, img2, flow): 154 | img1, img2 = self.color_transform(img1, img2) 155 | img1, img2 = self.eraser_transform(img1, img2) 156 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 157 | 158 | img1 = np.ascontiguousarray(img1) 159 | img2 = np.ascontiguousarray(img2) 160 | flow = np.ascontiguousarray(flow) 161 | 162 | return img1, img2, flow 163 | 164 | class SparseFlowAugmentor: 165 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False, normalized=False): 166 | # spatial augmentation params 167 | self.crop_size = crop_size 168 | self.min_scale = min_scale 169 | self.max_scale = max_scale 170 | self.spatial_aug_prob = 0.8 171 | self.stretch_prob = 0.8 172 | self.max_stretch = 0.2 173 | self.normalized = normalized 174 | 175 | # flip augmentation params 176 | self.do_flip = do_flip 177 | self.h_flip_prob = 0.5 178 | self.v_flip_prob = 0.1 179 | 180 | # photometric augmentation params 181 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) 182 | self.asymmetric_color_aug_prob = 0.2 183 | self.eraser_aug_prob = 0.5 184 | 185 | def color_transform(self, img1, img2): 186 | image_stack = np.concatenate([img1, img2], axis=0) 187 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 188 | img1, img2 = np.split(image_stack, 2, axis=0) 189 | return img1, img2 190 | 191 | def eraser_transform(self, img1, img2): 192 | ht, wd = img1.shape[:2] 193 | if np.random.rand() < self.eraser_aug_prob: 194 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 195 | for _ in range(np.random.randint(1, 3)): 196 | x0 = np.random.randint(0, wd) 197 | y0 = np.random.randint(0, ht) 198 | dx = np.random.randint(50, 100) 199 | dy = np.random.randint(50, 100) 200 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 201 | 202 | return img1, img2 203 | 204 | def 
resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 205 | ht, wd = flow.shape[:2] 206 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 207 | coords = np.stack(coords, axis=-1) 208 | 209 | coords = coords.reshape(-1, 2).astype(np.float32) 210 | flow = flow.reshape(-1, 2).astype(np.float32) 211 | valid = valid.reshape(-1).astype(np.float32) 212 | 213 | coords0 = coords[valid>=1] 214 | flow0 = flow[valid>=1] 215 | 216 | ht1 = int(round(ht * fy)) 217 | wd1 = int(round(wd * fx)) 218 | 219 | coords1 = coords0 * [fx, fy] 220 | flow1 = flow0 * [fx, fy] 221 | 222 | xx = np.round(coords1[:,0]).astype(np.int32) 223 | yy = np.round(coords1[:,1]).astype(np.int32) 224 | 225 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 226 | xx = xx[v] 227 | yy = yy[v] 228 | flow1 = flow1[v] 229 | 230 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 231 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 232 | 233 | flow_img[yy, xx] = flow1 234 | valid_img[yy, xx] = 1 235 | 236 | return flow_img, valid_img 237 | 238 | def spatial_transform(self, img1, img2, flow, valid): 239 | # randomly sample scale 240 | 241 | ht, wd = img1.shape[:2] 242 | min_scale = np.maximum( 243 | (self.crop_size[0] + 1) / float(ht), 244 | (self.crop_size[1] + 1) / float(wd)) 245 | 246 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 247 | scale_x = np.clip(scale, min_scale, None) 248 | scale_y = np.clip(scale, min_scale, None) 249 | 250 | if np.random.rand() < self.spatial_aug_prob: 251 | # rescale the images 252 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 253 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 254 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 255 | ## normalize 256 | if self.normalized: 257 | img1 = self.normalize(img1) 258 | img2 = self.normalize(img2) 259 | if img1.shape[0] <= self.crop_size[0] or img1.shape[1] <= self.crop_size[1]: 260 | img1, img2, flow, valid = self.padding(img1, img2, flow, valid) 261 | #if np.random.rand() < 0.2 or img1.shape[0] <= self.crop_size[0] or img1.shape[1] <= self.crop_size[1]: 262 | # img1, img2, flow, valid = self.padding(img1, img2, flow, valid) 263 | 264 | if self.do_flip: 265 | if np.random.rand() < 0.5: # h-flip 266 | img1 = img1[:, ::-1] 267 | img2 = img2[:, ::-1] 268 | flow = flow[:, ::-1] * [-1.0, 1.0] 269 | valid = valid[:, ::-1] 270 | 271 | margin_y = 20 272 | margin_x = 50 273 | 274 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 275 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 276 | 277 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 278 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 279 | 280 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 281 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 282 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 283 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 284 | return img1, img2, flow, valid 285 | 286 | def padding(self, img1, img2, flow, valid=None): 287 | h, w, channel = img1.shape[:3] 288 | crop_height, crop_width = self.crop_size[0] + 8, self.crop_size[1] + 8 289 | #crop_height, crop_width = (h+63)//64*64+8, (w+63)//64*64 + 8 290 | border_h = max(crop_height - h, 0) 291 | border_w = max(crop_width - w, 0) 292 | pad0 = border_h // 2 # top 293 | pad0 = border_h 294 | pad1 = border_h - pad0 295 | pad2 = border_w // 2 # left 296 | pad3 = border_w - 
pad2 297 | img1 = cv2.copyMakeBorder(img1, pad0, pad1, pad2, pad3, cv2.BORDER_REPLICATE) 298 | img2 = cv2.copyMakeBorder(img2, pad0, pad1, pad2, pad3, cv2.BORDER_REPLICATE) 299 | flow = cv2.copyMakeBorder(flow, pad0, pad1, pad2, pad3, cv2.BORDER_CONSTANT, value=(1e4, 1e4)) 300 | if valid is not None: 301 | valid = cv2.copyMakeBorder(valid, pad0, pad1, pad2, pad3, cv2.BORDER_CONSTANT, value=(0)) 302 | return img1, img2, flow, valid 303 | return img1, img2, flow 304 | 305 | 306 | def normalize(self, img): 307 | img = np.float32(img) 308 | r = img[:, :, 0] 309 | g = img[:, :, 1] 310 | b = img[:, :, 2] 311 | #img[:,:,0] = r 312 | #img[:,:,1] = g 313 | #img[:,:,2] = b 314 | 315 | img[:, :, 0] = (r - np.mean(r[:])) / (np.std(r[:]) + 1e-6) 316 | img[:, :, 1] = (g - np.mean(g[:])) / (np.std(g[:]) + 1e-6) 317 | img[:, :, 2] = (b - np.mean(b[:])) / (np.std(b[:]) + 1e-6) 318 | return img 319 | 320 | 321 | def __call__(self, img1, img2, flow, valid): 322 | img1, img2 = self.color_transform(img1, img2) 323 | img1, img2 = self.eraser_transform(img1, img2) 324 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 325 | 326 | img1 = np.ascontiguousarray(img1) 327 | img2 = np.ascontiguousarray(img2) 328 | flow = np.ascontiguousarray(flow) 329 | valid = np.ascontiguousarray(valid) 330 | 331 | return img1, img2, flow, valid 332 | -------------------------------------------------------------------------------- /core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the Matlab source code of Deqing Sun. 
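    (Added note: the wheel concatenates six color ramps, RY + YG + GC + CB + BM + MR = 15 + 6 + 4 + 11 + 13 + 6 = 55 entries, matching the constants defined below.)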
28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
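    Note (added): the flow is normalized by its maximum magnitude before coloring (see the code below), so the visualization is invariant to the overall flow scale.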
117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
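    Illustrative example (not from the original source): writeFlow('out.flo', flow) with flow a float32 array of shape [H, W, 2].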
76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | def readFlowVKITTI(filename): 101 | bgr = cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) 102 | h, w, _c = bgr.shape 103 | assert bgr.dtype == np.uint16 and _c == 3 104 | # b == invalid flow flag == 0 for sky or other invalid flow 105 | invalid = bgr[:,:, 0] == 0 106 | valid = bgr[:,:,0].astype(np.float32) 107 | # g,r == flow_y,x normalized by height,width and scaled to [0;2**16 – 1] 108 | flow = 2.0 / (2**16 - 1.0) * bgr[:,:, 2:0:-1].astype(np.float32) - 1 109 | temp = flow[:,:,0] 110 | temp[invalid] = 0 111 | flow[:,:, 0] = temp * (w - 1) 112 | temp = flow[:,:,1] 113 | temp[invalid] = 0 114 | flow[:,:, 1] = temp * (h - 1) 115 | return flow, valid 116 | 117 | 118 | 119 | def readFlowKITTI(filename): 120 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 121 | flow = flow[:,:,::-1].astype(np.float32) 122 | flow, valid = flow[:, :, :2], flow[:, :, 2] 123 | flow = (flow - 2**15) / 64.0 124 | return flow, valid 125 | 126 | def readDispKITTI(filename): 127 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 128 | valid = disp > 0.0 129 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 130 | return flow, valid 131 | 132 | 133 | def writeFlowKITTI(filename, uv): 134 | uv = 64.0 * uv + 2**15 135 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 136 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 137 | cv2.imwrite(filename, uv[..., ::-1]) 138 | 139 | 140 | def read_gen(file_name, pil=False): 141 | ext = splitext(file_name)[-1] 142 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 143 | return Image.open(file_name) 144 | elif ext == '.bin' or ext == '.raw': 145 | return np.load(file_name) 146 | elif ext == '.flo': 147 | return readFlow(file_name).astype(np.float32) 148 | elif ext == '.pfm': 149 | flow = readPFM(file_name).astype(np.float32) 150 | if len(flow.shape) == 2: 151 | return flow 152 | else: 153 | return flow[:, :, :-1] 154 | return [] 155 | -------------------------------------------------------------------------------- /core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import torch.nn as nn 5 | from scipy import interpolate 6 | 7 | class DomainNorm(torch.nn.Module): 8 | def __init__(self, channel, l2=True): 9 | super(DomainNorm, self).__init__() 10 | self.normalize = nn.InstanceNorm2d(num_features=channel, affine=False) 11 | self.l2 = l2 12 | self.weight = nn.Parameter(torch.ones(1,channel,1,1)) 13 | self.bias = nn.Parameter(torch.zeros(1,channel,1,1)) 14 | self.weight.requires_grad = True 15 | self.bias.requires_grad = True 16 | def forward(self, x): 17 | x = self.normalize(x) 18 | if self.l2: 19 | x = F.normalize(x, p=2, dim=1) 20 | return x * self.weight + self.bias 21 | class InputPadder: 22 | """ Pads images such that dimensions are divisible by 8 """ 23 | def __init__(self, 
dims, mode='sintel'): 24 | self.ht, self.wd = dims[-2:] 25 | pad_ht = (((self.ht + 63) // 64) * 64 - self.ht) % 64 26 | pad_wd = (((self.wd + 63) // 64) * 64 - self.wd) % 64 27 | self.mode = mode 28 | if self.mode == 'sintel': 29 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 30 | else: 31 | #self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 32 | self.ht_new = (self.ht + 63) // 64 * 64 33 | self.wd_new = (self.wd + 63) // 64 * 64 34 | #self._pad = [0, 0, 0, pad_ht] 35 | 36 | def pad(self, *inputs): 37 | if self.mode == 'sintel': 38 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 39 | else: 40 | return [F.interpolate(x, [self.ht_new, self.wd_new], mode='bilinear', align_corners=True) for x in inputs] 41 | 42 | 43 | def unpad(self,x): 44 | if self.mode == 'sintel': 45 | ht, wd = x.shape[-2:] 46 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 47 | return x[..., c[0]:c[1], c[2]:c[3]] 48 | else: 49 | x_shape = len(x.shape) 50 | if x_shape == 3: 51 | x = x.unsqueeze(0) 52 | flow = F.interpolate(x, [self.ht, self.wd], mode='bilinear', align_corners=True) 53 | flow[:,0,:,:] *= (self.wd*1.0/self.wd_new) 54 | flow[:,1,:,:] *= (self.ht*1.0/self.ht_new) 55 | if x_shape == 3: 56 | return flow[0] 57 | return flow 58 | 59 | 60 | class InputPadder2: 61 | """ Pads images such that dimensions are divisible by 8 """ 62 | def __init__(self, dims, mode='sintel'): 63 | self.ht, self.wd = dims[-2:] 64 | pad_ht = (((self.ht + 63) // 64) * 64 - self.ht) % 64 65 | pad_wd = (((self.wd + 63) // 64) * 64 - self.wd) % 64 66 | if mode == 'sintel': 67 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 68 | else: 69 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 70 | 71 | def pad(self, *inputs): 72 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 73 | 74 | def unpad(self,x): 75 | ht, wd = x.shape[-2:] 76 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 77 | return x[..., c[0]:c[1], c[2]:c[3]] 78 | 79 | def forward_interpolate(flow): 80 | flow = flow.detach().cpu().numpy() 81 | dx, dy = flow[0], flow[1] 82 | 83 | ht, wd = dx.shape 84 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 85 | 86 | x1 = x0 + dx 87 | y1 = y0 + dy 88 | 89 | x1 = x1.reshape(-1) 90 | y1 = y1.reshape(-1) 91 | dx = dx.reshape(-1) 92 | dy = dy.reshape(-1) 93 | 94 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 95 | x1 = x1[valid] 96 | y1 = y1[valid] 97 | dx = dx[valid] 98 | dy = dy[valid] 99 | 100 | flow_x = interpolate.griddata( 101 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 102 | 103 | flow_y = interpolate.griddata( 104 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 105 | 106 | flow = np.stack([flow_x, flow_y], axis=0) 107 | return torch.from_numpy(flow).float() 108 | 109 | 110 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 111 | """ Wrapper for grid_sample, uses pixel coordinates """ 112 | H, W = img.shape[-2:] 113 | xgrid, ygrid = coords.split([1,1], dim=-1) 114 | xgrid = 2*xgrid/(W-1) - 1 115 | ygrid = 2*ygrid/(H-1) - 1 116 | 117 | grid = torch.cat([xgrid, ygrid], dim=-1) 118 | img = F.grid_sample(img, grid, align_corners=True) 119 | 120 | if mask: 121 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 122 | return img, mask.float() 123 | 124 | return img 125 | 126 | 127 | def coords_grid(batch, ht, wd): 128 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 129 | coords = torch.stack(coords[::-1], dim=0).float() 130 | return 
coords[None].repeat(batch, 1, 1, 1) 131 | 132 | 133 | def upflow8(flow, mode='bilinear'): 134 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 135 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 136 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | import argparse 5 | import os 6 | import cv2 7 | import glob 8 | import numpy as np 9 | import torch 10 | from PIL import Image 11 | 12 | from sepflow import SepFlow 13 | from utils import flow_viz 14 | from utils.utils import InputPadder 15 | 16 | 17 | 18 | DEVICE = 'cuda' 19 | 20 | def load_image(imfile): 21 | img = np.array(Image.open(imfile)).astype(np.uint8) 22 | img = torch.from_numpy(img).permute(2, 0, 1).float() 23 | return img[None].to(DEVICE) 24 | 25 | 26 | def viz(img, flo): 27 | img = img[0].permute(1,2,0).cpu().numpy() 28 | flo = flo.permute(1,2,0).cpu().numpy() 29 | 30 | # map flow to rgb image 31 | flo = flow_viz.flow_to_image(flo) 32 | img_flo = np.concatenate([img, flo], axis=0) 33 | 34 | # import matplotlib.pyplot as plt 35 | # plt.imshow(img_flo / 255.0) 36 | # plt.show() 37 | 38 | cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) 39 | cv2.waitKey() 40 | 41 | 42 | def demo(args): 43 | 44 | model = torch.nn.DataParallel(SepFlow(args)) 45 | checkpoint = torch.load(args.model) 46 | msg = model.load_state_dict(checkpoint['state_dict'], strict=False) 47 | print(msg) 48 | 49 | model = model.module 50 | model.cuda() 51 | model.eval() 52 | 53 | 54 | with torch.no_grad(): 55 | images = glob.glob(os.path.join(args.path, '*.png')) + \ 56 | glob.glob(os.path.join(args.path, '*.jpg')) 57 | 58 | images = sorted(images) 59 | for imfile1, imfile2 in zip(images[:-1], images[1:]): 60 | image1 = load_image(imfile1) 61 | image2 = load_image(imfile2) 62 | 63 | padder = InputPadder(image1.shape, mode='others') 64 | image1, image2 = padder.pad(image1, image2) 65 | 66 | flow_low, flow_up = model(image1, image2, iters=20) 67 | flow_up = padder.unpad(flow_up[0]) 68 | viz(image1, flow_up) 69 | 70 | 71 | if __name__ == '__main__': 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('--model', help="restore checkpoint") 74 | parser.add_argument('--path', help="dataset for evaluation") 75 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 76 | args = parser.parse_args() 77 | 78 | demo(args) 79 | 80 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | from PIL import Image 5 | import argparse 6 | import os 7 | import time 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | import matplotlib.pyplot as plt 12 | 13 | import datasets 14 | from utils import flow_viz 15 | from utils import frame_utils 16 | 17 | from sepflow import SepFlow 18 | from utils.utils import InputPadder, forward_interpolate 19 | 20 | @torch.no_grad() 21 | def create_sintel_submission(model, iters=32, output_path='sintel_submission'): 22 | """ Create submission for the Sintel leaderboard """ 23 | model.eval() 24 | for dstype in ['clean', 'final']: 25 | test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype) 26 | 27 | for test_id in range(len(test_dataset)): 28 | image1, image2, (sequence, frame) = 
test_dataset[test_id] 29 | 30 | padder = InputPadder(image1.shape) 31 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 32 | 33 | flow_low, flow_pr = model(image1, image2, iters=iters) 34 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 35 | 36 | 37 | output_dir = os.path.join(output_path, dstype, sequence) 38 | output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1)) 39 | 40 | if not os.path.exists(output_dir): 41 | os.makedirs(output_dir) 42 | 43 | frame_utils.writeFlow(output_file, flow) 44 | 45 | 46 | @torch.no_grad() 47 | def create_kitti_submission(model, iters=24, output_path='kitti_submission'): 48 | """ Create submission for the KITTI-2015 leaderboard """ 49 | model.eval() 50 | file_path = "/export/work/feihu/kitti2015" 51 | test_dataset = datasets.KITTI(split='testing', root=file_path, aug_params=None) 52 | 53 | if not os.path.exists(output_path): 54 | os.makedirs(output_path) 55 | 56 | for test_id in range(len(test_dataset)): 57 | image1, image2, (frame_id, ) = test_dataset[test_id] 58 | padder = InputPadder(image1.shape, mode='kitti') 59 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 60 | 61 | _, flow_pr = model(image1, image2, iters=iters) 62 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 63 | 64 | output_filename = os.path.join(output_path, frame_id) 65 | frame_utils.writeFlowKITTI(output_filename, flow) 66 | 67 | @torch.no_grad() 68 | def create_kitti2012_submission(model, iters=24, output_path='kitti2012_submission'): 69 | """ Create submission for the KITTI-2012 leaderboard """ 70 | model.eval() 71 | file_path = "/export/work/feihu/kitti2012" 72 | test_dataset = datasets.KITTI2012(split='testing', root=file_path, aug_params=None) 73 | 74 | if not os.path.exists(output_path): 75 | os.makedirs(output_path) 76 | 77 | for test_id in range(len(test_dataset)): 78 | image1, image2, (frame_id, ) = test_dataset[test_id] 79 | padder = InputPadder(image1.shape, mode='kitti') 80 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 81 | 82 | _, flow_pr = model(image1, image2, iters=iters) 83 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 84 | 85 | output_filename = os.path.join(output_path, frame_id) 86 | frame_utils.writeFlowKITTI(output_filename, flow) 87 | 88 | @torch.no_grad() 89 | def validate_chairs(model, iters=24): 90 | """ Perform evaluation on the FlyingChairs (test) split """ 91 | model.eval() 92 | epe_list = [] 93 | 94 | val_dataset = datasets.FlyingChairs(split='validation') 95 | for val_id in range(len(val_dataset)): 96 | image1, image2, flow_gt, _ = val_dataset[val_id] 97 | image1 = image1[None].cuda() 98 | image2 = image2[None].cuda() 99 | 100 | _, flow_pr = model(image1, image2, iters=iters) 101 | epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt() 102 | epe_list.append(epe.view(-1).numpy()) 103 | 104 | epe = np.mean(np.concatenate(epe_list)) 105 | print("Validation Chairs EPE: %f" % epe) 106 | return {'chairs': epe} 107 | 108 | 109 | @torch.no_grad() 110 | def validate_sintel(model, iters=32): 111 | """ Perform validation using the Sintel (train) split """ 112 | model.eval() 113 | results = {} 114 | for dstype in ['clean', 'final']: 115 | val_dataset = datasets.MpiSintel(split='training', dstype=dstype) 116 | epe_list = [] 117 | 118 | for val_id in range(len(val_dataset)): 119 | image1, image2, flow_gt, _ = val_dataset[val_id] 120 | image1 = image1[None].cuda() 121 | image2 = image2[None].cuda() 122 | 123 | padder = 
InputPadder(image1.shape) 124 | image1, image2 = padder.pad(image1, image2) 125 | 126 | flow_low, flow_pr = model(image1, image2, iters=iters) 127 | flow = padder.unpad(flow_pr[0]).cpu() 128 | 129 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 130 | epe_list.append(epe.view(-1).numpy()) 131 | 132 | epe_all = np.concatenate(epe_list) 133 | epe = np.mean(epe_all) 134 | px1 = np.mean(epe_all<1) 135 | px3 = np.mean(epe_all<3) 136 | px5 = np.mean(epe_all<5) 137 | 138 | print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5)) 139 | results[dstype] = epe 140 | 141 | return results 142 | 143 | 144 | @torch.no_grad() 145 | def validate_kitti(model, iters=24): 146 | """ Perform validation using the KITTI-2015 (train) split """ 147 | model.eval() 148 | file_path = "/export/work/feihu/kitti2015" 149 | val_dataset = datasets.KITTI(split='training', root=file_path) 150 | 151 | out_list, epe_list = [], [] 152 | for val_id in range(len(val_dataset)): 153 | image1, image2, flow_gt, valid_gt = val_dataset[val_id] 154 | image1 = image1[None].cuda() 155 | image2 = image2[None].cuda() 156 | 157 | padder = InputPadder(image1.shape, mode='kitti') 158 | image1, image2 = padder.pad(image1, image2) 159 | 160 | flow_low, flow_pr = model(image1, image2, iters=iters) 161 | flow = padder.unpad(flow_pr[0]).cpu() 162 | 163 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 164 | mag = torch.sum(flow_gt**2, dim=0).sqrt() 165 | 166 | epe = epe.view(-1) 167 | mag = mag.view(-1) 168 | val = valid_gt.view(-1) >= 0.5 169 | 170 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float() 171 | epe_list.append(epe[val].mean().item()) 172 | out_list.append(out[val].cpu().numpy()) 173 | 174 | epe_list = np.array(epe_list) 175 | out_list = np.concatenate(out_list) 176 | 177 | epe = np.mean(epe_list) 178 | f1 = 100 * np.mean(out_list) 179 | 180 | print("Validation KITTI: %f, %f" % (epe, f1)) 181 | return {'kitti-epe': epe, 'kitti-f1': f1} 182 | 183 | @torch.no_grad() 184 | def validate_kitti2012(model, iters=24): 185 | """ Perform validation using the KITTI-2012 (train) split """ 186 | model.eval() 187 | file_path = "/export/work/feihu/kitti2012" 188 | val_dataset = datasets.KITTI2012(split='training', root=file_path) 189 | 190 | out_list, epe_list = [], [] 191 | for val_id in range(len(val_dataset)): 192 | image1, image2, flow_gt, valid_gt = val_dataset[val_id] 193 | image1 = image1[None].cuda() 194 | image2 = image2[None].cuda() 195 | 196 | padder = InputPadder(image1.shape, mode='kitti') 197 | image1, image2 = padder.pad(image1, image2) 198 | 199 | flow_low, flow_pr = model(image1, image2, iters=iters) 200 | flow = padder.unpad(flow_pr[0]).cpu() 201 | 202 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 203 | mag = torch.sum(flow_gt**2, dim=0).sqrt() 204 | 205 | epe = epe.view(-1) 206 | mag = mag.view(-1) 207 | val = valid_gt.view(-1) >= 0.5 208 | 209 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float() 210 | epe_list.append(epe[val].mean().item()) 211 | out_list.append(out[val].cpu().numpy()) 212 | 213 | epe_list = np.array(epe_list) 214 | out_list = np.concatenate(out_list) 215 | 216 | epe = np.mean(epe_list) 217 | f1 = 100 * np.mean(out_list) 218 | 219 | print("Validation KITTI2012: %f, %f" % (epe, f1)) 220 | return {'kitti-epe': epe, 'kitti-f1': f1} 221 | 222 | if __name__ == '__main__': 223 | parser = argparse.ArgumentParser() 224 | parser.add_argument('--model', help="restore checkpoint") 225 | parser.add_argument('--dataset', help="dataset for evaluation") 226 | 
parser.add_argument('--small', action='store_true', help='use small model') 227 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 228 | parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation') 229 | args = parser.parse_args() 230 | 231 | model = torch.nn.DataParallel(SepFlow(args)) 232 | checkpoint = torch.load(args.model) 233 | msg = model.load_state_dict(checkpoint['state_dict'], strict=False) 234 | print(msg) 235 | 236 | model.cuda() 237 | model.eval() 238 | 239 | with torch.no_grad(): 240 | if args.dataset == 'sintel': 241 | create_sintel_submission(model.module) 242 | elif args.dataset == 'kitti': 243 | create_kitti_submission(model.module, output_path='kitti_submission') 244 | quit() 245 | 246 | with torch.no_grad(): 247 | if args.dataset == 'chairs': 248 | validate_chairs(model.module) 249 | 250 | elif args.dataset == 'sintel': 251 | validate_sintel(model.module) 252 | 253 | elif args.dataset == 'kitti': 254 | validate_kitti(model.module) 255 | elif args.dataset == 'kitti2012': 256 | validate_kitti2012(model.module) 257 | 258 | 259 | -------------------------------------------------------------------------------- /evaluate.sh: -------------------------------------------------------------------------------- 1 | python evaluate.py --model './checkpoints/sepflow_sintel.pth' --dataset 'sintel' 2 | python evaluate.py --model './checkpoints/sepflow_kitti.pth' --dataset 'kitti' 3 | -------------------------------------------------------------------------------- /libs/GANet/functions/GANet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from ..build.lib import GANet 4 | from torch.autograd import Variable 5 | #import GANet 6 | class NlfDownFunction(Function): 7 | @staticmethod 8 | def forward(ctx, input, g0): 9 | assert(input.is_contiguous() == True and g0.is_contiguous() == True) 10 | with torch.cuda.device_of(input): 11 | output_down = input.new().resize_(input.size()).zero_() 12 | GANet.nlf_down_cuda_forward(input, g0, output_down) 13 | output_down = output_down.contiguous() 14 | ctx.save_for_backward(input, g0, output_down) 15 | return output_down 16 | @staticmethod 17 | def backward(ctx, gradOutput): 18 | input, g0, output_down = ctx.saved_tensors 19 | assert(gradOutput.is_contiguous() == True) 20 | with torch.cuda.device_of(gradOutput): 21 | 22 | gradInput = gradOutput.new().resize_(input.size()).zero_() 23 | grad0 = gradOutput.new().resize_(g0.size()).zero_() 24 | GANet.nlf_down_cuda_backward(input, g0, output_down, gradOutput, gradInput, grad0) 25 | gradInput = gradInput.contiguous() 26 | grad0 = grad0.contiguous() 27 | return gradInput, grad0 28 | class NlfUpFunction(Function): 29 | @staticmethod 30 | def forward(ctx, input, g1): 31 | assert(input.is_contiguous() == True and g1.is_contiguous() == True) 32 | with torch.cuda.device_of(input): 33 | output_up = input.new().resize_(input.size()).zero_() 34 | GANet.nlf_up_cuda_forward(input, g1, output_up) 35 | output_up = output_up.contiguous() 36 | ctx.save_for_backward(input, g1, output_up) 37 | return output_up 38 | @staticmethod 39 | def backward(ctx, gradOutput): 40 | input, g1, output_up = ctx.saved_tensors 41 | assert(gradOutput.is_contiguous() == True) 42 | with torch.cuda.device_of(gradOutput): 43 | gradInput = gradOutput.new().resize_(input.size()).zero_() 44 | grad1 = gradOutput.new().resize_(g1.size()).zero_() 45 | 
GANet.nlf_up_cuda_backward(input, g1, output_up, gradOutput, gradInput, grad1) 46 | gradInput = gradInput.contiguous() 47 | grad1 = grad1.contiguous() 48 | return gradInput, grad1 49 | class NlfRightFunction(Function): 50 | @staticmethod 51 | def forward(ctx, input, g2): 52 | assert(input.is_contiguous() == True and g2.is_contiguous() == True) 53 | with torch.cuda.device_of(input): 54 | # num, channels, height, width = input.size() 55 | # output_right = input.new().resize_(num, channels, height, width).zero_() 56 | output_right = input.new().resize_(input.size()).zero_() 57 | GANet.nlf_right_cuda_forward(input, g2, output_right) 58 | output_right = output_right.contiguous() 59 | ctx.save_for_backward(input, g2, output_right) 60 | return output_right 61 | @staticmethod 62 | def backward(ctx, gradOutput): 63 | input, g2, output_right = ctx.saved_tensors 64 | assert(gradOutput.is_contiguous() == True) 65 | with torch.cuda.device_of(gradOutput): 66 | # num, channels, height, width = input.size() 67 | # _, fsize, _, _ = g2.size() 68 | # gradInput = gradOutput.new().resize_(num, channels, height, width).zero_() 69 | # grad2 = gradOutput.new().resize_(num, fsize, height, width).zero_() 70 | gradInput = gradOutput.new().resize_(input.size()).zero_() 71 | grad2 = gradOutput.new().resize_(g2.size()).zero_() 72 | GANet.nlf_right_cuda_backward(input, g2, output_right, gradOutput, gradInput, grad2) 73 | gradInput = gradInput.contiguous() 74 | grad2 = grad2.contiguous() 75 | return gradInput, grad2 76 | class NlfLeftFunction(Function): 77 | @staticmethod 78 | def forward(ctx, input, g3): 79 | 80 | assert(input.is_contiguous() == True and g3.is_contiguous() == True) 81 | with torch.cuda.device_of(input): 82 | # num, channels, height, width = input.size() 83 | # output_left = input.new().resize_(num, channels, height, width).zero_() 84 | output_left = input.new().resize_(input.size()).zero_() 85 | GANet.nlf_left_cuda_forward(input, g3, output_left) 86 | output_left = output_left.contiguous() 87 | ctx.save_for_backward(input, g3, output_left) 88 | return output_left 89 | @staticmethod 90 | def backward(ctx, gradOutput): 91 | input, g3, output_left = ctx.saved_tensors 92 | gradOutput = gradOutput.contiguous() 93 | assert(gradOutput.is_contiguous() == True) 94 | with torch.cuda.device_of(gradOutput): 95 | # num, channels, height, width = input.size() 96 | # _, fsize, _, _ = g3.size() 97 | # gradInput = gradOutput.new().resize_(num, channels, height, width).zero_() 98 | # grad3 = gradOutput.new().resize_(num, fsize, height, width).zero_() 99 | gradInput = gradOutput.new().resize_(input.size()).zero_() 100 | grad3 = gradOutput.new().resize_(g3.size()).zero_() 101 | GANet.nlf_left_cuda_backward(input, g3, output_left, gradOutput, gradInput, grad3) 102 | gradInput = gradInput.contiguous() 103 | grad3 = grad3.contiguous() 104 | return gradInput, grad3 105 | 106 | class NlfFunction(Function): 107 | @staticmethod 108 | def forward(ctx, input, g0, g1, g2, g3): 109 | assert(input.is_contiguous() == True and g0.is_contiguous() == True and g1.is_contiguous() == True and g2.is_contiguous() == True and g3.is_contiguous() == True) 110 | with torch.cuda.device_of(input): 111 | num, channels, height, width = input.size() 112 | output_down = input.new().resize_(num, channels, height, width).zero_() 113 | output_up = input.new().resize_(num, channels, height, width).zero_() 114 | output_right = input.new().resize_(num, channels, height, width).zero_() 115 | output_left = input.new().resize_(num, channels, height, 
width).zero_() 116 | GANet.nlf_cuda_forward(input, g0, g1, g2, g3, output_down, output_up, output_right, output_left) 117 | # GANet.sga_cuda_forward(input, filters, output, radius) 118 | 119 | output_down = output_down.contiguous() 120 | output_up = output_up.contiguous() 121 | output_right = output_right.contiguous() 122 | output_left = output_left.contiguous() 123 | ctx.save_for_backward(input, g0, g1, g2, g3, output_down, output_up, output_right, output_left) 124 | # print(output_left.size()) 125 | return output_left 126 | @staticmethod 127 | def backward(ctx, gradOutput): 128 | input, g0, g1, g2, g3, output_down, output_up, output_right, output_left = ctx.saved_tensors 129 | # print temp_out.size() 130 | # print mask.size() 131 | assert(gradOutput.is_contiguous() == True) 132 | with torch.cuda.device_of(gradOutput): 133 | num, channels, height, width = input.size() 134 | _, _, fsize, _, _ = g0.size() 135 | # print fsize 136 | gradInput = gradOutput.new().resize_(num, channels, height, width).zero_() 137 | grad0 = gradOutput.new().resize_(num, channels, fsize, height, width).zero_() 138 | grad1 = gradOutput.new().resize_(num, channels, fsize, height, width).zero_() 139 | grad2 = gradOutput.new().resize_(num, channels, fsize, height, width).zero_() 140 | grad3 = gradOutput.new().resize_(num, channels, fsize, height, width).zero_() 141 | 142 | GANet.nlf_cuda_backward(input, g0, g1, g2, g3, output_down, output_up, output_right, output_left, gradOutput, gradInput, grad0, grad1, grad2, grad3) 143 | # GANet.lga_cuda_backward(input, filters, gradOutput, gradInput, gradFilters, radius) 144 | gradInput = gradInput.contiguous() 145 | grad0 = grad0.contiguous() 146 | grad1 = grad1.contiguous() 147 | grad2 = grad2.contiguous() 148 | grad3 = grad3.contiguous() 149 | return gradInput, grad0, grad1, grad2, grad3 150 | 151 | class SgaFunction(Function): 152 | @staticmethod 153 | def forward(ctx, input, g0, g1, g2, g3): 154 | assert(input.is_contiguous() == True and g0.is_contiguous() == True and g1.is_contiguous() == True and g2.is_contiguous() == True and g3.is_contiguous() == True) 155 | with torch.cuda.device_of(input): 156 | num, channels, depth, height, width = input.size() 157 | output = input.new().resize_(num, channels, depth, height, width).zero_() 158 | temp_out = input.new().resize_(num, channels, depth, height, width).zero_() 159 | mask = input.new().resize_(num, channels, depth, height, width).zero_() 160 | GANet.sga_cuda_forward(input, g0, g1, g2, g3, temp_out, output, mask) 161 | # GANet.sga_cuda_forward(input, filters, output, radius) 162 | 163 | output = output.contiguous() 164 | ctx.save_for_backward(input, g0, g1, g2, g3, temp_out, mask) 165 | return output 166 | @staticmethod 167 | def backward(ctx, gradOutput): 168 | input, g0, g1, g2, g3, temp_out, mask = ctx.saved_tensors 169 | # print temp_out.size() 170 | # print mask.size() 171 | assert(gradOutput.is_contiguous() == True) 172 | with torch.cuda.device_of(gradOutput): 173 | num, channels, depth, height, width = input.size() 174 | # _, _, fsize, _, _ = g0.size() 175 | # print fsize 176 | gradInput = gradOutput.new().resize_(num, channels, depth, height, width).zero_() 177 | grad0 = gradOutput.new().resize_(g0.size()).zero_() 178 | grad1 = gradOutput.new().resize_(g1.size()).zero_() 179 | grad2 = gradOutput.new().resize_(g2.size()).zero_() 180 | grad3 = gradOutput.new().resize_(g3.size()).zero_() 181 | temp_grad = gradOutput.new().resize_(num, channels, depth, height, width).zero_() 182 | max_idx = gradOutput.new().resize_(num, 
channels, height, width).zero_() 183 | 184 | GANet.sga_cuda_backward(input, g0, g1, g2, g3, temp_out, mask, max_idx, gradOutput, temp_grad, gradInput, grad0, grad1, grad2, grad3) 185 | # GANet.lga_cuda_backward(input, filters, gradOutput, gradInput, gradFilters, radius) 186 | gradInput = gradInput.contiguous() 187 | grad0 = grad0.contiguous() 188 | grad1 = grad1.contiguous() 189 | grad2 = grad2.contiguous() 190 | grad3 = grad3.contiguous() 191 | return gradInput, grad0, grad1, grad2, grad3 192 | 193 | 194 | class Lga3d3Function(Function): 195 | @staticmethod 196 | def forward(ctx, input, filters, radius=1): 197 | ctx.radius = radius 198 | assert(input.is_contiguous() == True and filters.is_contiguous() == True) 199 | with torch.cuda.device_of(input): 200 | num, channels, depth, height, width = input.size() 201 | temp_out1 = input.new().resize_(num, channels, depth, height, width).zero_() 202 | temp_out2 = input.new().resize_(num, channels, depth, height, width).zero_() 203 | output = input.new().resize_(num, channels, depth, height, width).zero_() 204 | GANet.lga3d_cuda_forward(input, filters, temp_out1, radius) 205 | GANet.lga3d_cuda_forward(temp_out1, filters, temp_out2, radius) 206 | GANet.lga3d_cuda_forward(temp_out2, filters, output, radius) 207 | output = output.contiguous() 208 | ctx.save_for_backward(input, filters, temp_out1, temp_out2) 209 | return output 210 | @staticmethod 211 | def backward(ctx, gradOutput): 212 | input, filters, temp_out1, temp_out2 = ctx.saved_tensors 213 | assert(gradOutput.is_contiguous() == True) 214 | with torch.cuda.device_of(gradOutput): 215 | num, channels, depth, height, width = input.size() 216 | _, _, fsize, _, _ = filters.size() 217 | # gradInput = gradOutput.new().resize_(num, channels, height, width).zero_() 218 | gradFilters = gradOutput.new().resize_(num, channels, fsize, height, width).zero_() 219 | GANet.lga3d_cuda_backward(temp_out2, filters, gradOutput, temp_out2, gradFilters, ctx.radius) 220 | GANet.lga3d_cuda_backward(temp_out1, filters, temp_out2, temp_out1, gradFilters, ctx.radius) 221 | # temp_out[...] = 0 222 | GANet.lga3d_cuda_backward(input, filters, temp_out1, temp_out2, gradFilters, ctx.radius) 223 | # temp_out[...] = gradOutput[...] 
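            # A reading of the three chained calls above: assuming the kernel signature
            # lga3d_cuda_backward(input, filters, gradOutput, gradInput, gradFilters, radius),
            # with filter gradients accumulated into gradFilters, the calls chain-rule through
            # the three stacked filterings in reverse order, reusing the saved intermediates
            # as scratch buffers, so that temp_out2 finally holds d(loss)/d(input).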
224 | temp_out2 = temp_out2.contiguous() 225 | gradFilters = gradFilters.contiguous() 226 | return temp_out2, gradFilters, None 227 | class Lga3d2Function(Function): 228 | @staticmethod 229 | def forward(ctx, input, filters, radius=1): 230 | ctx.radius = radius 231 | assert(input.is_contiguous() == True and filters.is_contiguous() == True) 232 | with torch.cuda.device_of(input): 233 | num, channels, depth, height, width = input.size() 234 | temp_out = input.new().resize_(num, channels, depth, height, width).zero_() 235 | output = input.new().resize_(num, channels, depth, height, width).zero_() 236 | GANet.lga3d_cuda_forward(input, filters, temp_out, radius) 237 | GANet.lga3d_cuda_forward(temp_out, filters, output, radius) 238 | output = output.contiguous() 239 | ctx.save_for_backward(input, filters, temp_out) 240 | return output 241 | @staticmethod 242 | def backward(ctx, gradOutput): 243 | input, filters, temp_out = ctx.saved_tensors 244 | assert(gradOutput.is_contiguous() == True) 245 | with torch.cuda.device_of(gradOutput): 246 | num, channels, depth, height, width = input.size() 247 | _, _, fsize, _, _ = filters.size() 248 | # gradInput = gradOutput.new().resize_(num, channels, height, width).zero_() 249 | gradFilters = gradOutput.new().resize_(num, channels, fsize, height, width).zero_() 250 | GANet.lga3d_cuda_backward(temp_out, filters, gradOutput, temp_out, gradFilters, ctx.radius) 251 | # temp_out[...] = 0 252 | GANet.lga3d_cuda_backward(input, filters, temp_out, gradOutput, gradFilters, ctx.radius) 253 | temp_out[...] = gradOutput[...] 254 | temp_out = temp_out.contiguous() 255 | gradFilters = gradFilters.contiguous() 256 | return temp_out, gradFilters, None 257 | 258 | class Lga3dFunction(Function): 259 | @staticmethod 260 | def forward(ctx, input, filters, radius=1): 261 | ctx.radius = radius 262 | ctx.save_for_backward(input, filters) 263 | assert(input.is_contiguous() == True and filters.is_contiguous() == True) 264 | with torch.cuda.device_of(input): 265 | num, channels, depth, height, width = input.size() 266 | output = input.new().resize_(num, channels, depth, height, width).zero_() 267 | GANet.lga3d_cuda_forward(input, filters, output, radius) 268 | output = output.contiguous() 269 | return output 270 | @staticmethod 271 | def backward(ctx, gradOutput): 272 | input, filters = ctx.saved_tensors 273 | assert(gradOutput.is_contiguous() == True) 274 | with torch.cuda.device_of(gradOutput): 275 | num, channels, depth, height, width = input.size() 276 | _, _, fsize, _, _ = filters.size() 277 | gradInput = gradOutput.new().resize_(num, channels, depth, height, width).zero_() 278 | gradFilters = gradOutput.new().resize_(num, channels, fsize, height, width).zero_() 279 | GANet.lga3d_cuda_backward(input, filters, gradOutput, gradInput, gradFilters, ctx.radius) 280 | gradInput = gradInput.contiguous() 281 | gradFilters = gradFilters.contiguous() 282 | return gradInput, gradFilters, None 283 | 284 | class Lga3Function(Function): 285 | @staticmethod 286 | def forward(ctx, input, filters, radius=1): 287 | ctx.radius = radius 288 | assert(input.is_contiguous() == True and filters.is_contiguous() == True) 289 | with torch.cuda.device_of(input): 290 | num, channels, height, width = input.size() 291 | temp_out1 = input.new().resize_(num, channels, height, width).zero_() 292 | temp_out2 = input.new().resize_(num, channels, height, width).zero_() 293 | output = input.new().resize_(num, channels, height, width).zero_() 294 | GANet.lga_cuda_forward(input, filters, temp_out1, radius) 295 | 
GANet.lga_cuda_forward(temp_out1, filters, temp_out2, radius)
296 |             GANet.lga_cuda_forward(temp_out2, filters, output, radius)
297 |             output = output.contiguous()
298 |             ctx.save_for_backward(input, filters, temp_out1, temp_out2)
299 |         return output
300 |     @staticmethod
301 |     def backward(ctx, gradOutput):
302 |         input, filters, temp_out1, temp_out2 = ctx.saved_tensors
303 |         assert(gradOutput.is_contiguous() == True)
304 |         with torch.cuda.device_of(gradOutput):
305 |             num, channels, height, width = input.size()
306 |             _, fsize, _, _ = filters.size()
307 |             # gradInput = gradOutput.new().resize_(num, channels, height, width).zero_()
308 |             gradFilters = gradOutput.new().resize_(num, fsize, height, width).zero_()
309 |             GANet.lga_cuda_backward(temp_out2, filters, gradOutput, temp_out2, gradFilters, ctx.radius)
310 |             GANet.lga_cuda_backward(temp_out1, filters, temp_out2, temp_out1, gradFilters, ctx.radius)
311 |             # temp_out[...] = 0
312 |             GANet.lga_cuda_backward(input, filters, temp_out1, temp_out2, gradFilters, ctx.radius)
313 |             # temp_out[...] = gradOutput[...]
314 |             temp_out2 = temp_out2.contiguous()
315 |             gradFilters = gradFilters.contiguous()
316 |         return temp_out2, gradFilters, None
317 | class Lga2Function(Function):
318 |     @staticmethod
319 |     def forward(ctx, input, filters, radius=1):
320 |         ctx.radius = radius
321 |         assert(input.is_contiguous() == True and filters.is_contiguous() == True)
322 |         with torch.cuda.device_of(input):
323 |             num, channels, height, width = input.size()
324 |             temp_out = input.new().resize_(num, channels, height, width).zero_()
325 |             output = input.new().resize_(num, channels, height, width).zero_()
326 |             GANet.lga_cuda_forward(input, filters, temp_out, radius)
327 |             GANet.lga_cuda_forward(temp_out, filters, output, radius)
328 |             output = output.contiguous()
329 |             ctx.save_for_backward(input, filters, temp_out)
330 |         return output
331 |     @staticmethod
332 |     def backward(ctx, gradOutput):
333 |         input, filters, temp_out = ctx.saved_tensors
334 |         assert(gradOutput.is_contiguous() == True)
335 |         with torch.cuda.device_of(gradOutput):
336 |             num, channels, height, width = input.size()
337 |             _, fsize, _, _ = filters.size()
338 |             # gradInput = gradOutput.new().resize_(num, channels, height, width).zero_()
339 |             gradFilters = gradOutput.new().resize_(num, fsize, height, width).zero_()
340 |             GANet.lga_cuda_backward(temp_out, filters, gradOutput, temp_out, gradFilters, ctx.radius)
341 |             # temp_out[...] = 0
342 |             GANet.lga_cuda_backward(input, filters, temp_out, gradOutput, gradFilters, ctx.radius)
343 |             temp_out[...] = gradOutput[...]
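            # Backward through the two stacked filterings: assuming the signature
            # lga_cuda_backward(input, filters, gradOutput, gradInput, gradFilters, radius),
            # the first call writes d(loss)/d(temp_out) in place of temp_out and the second
            # writes d(loss)/d(input) into the incoming gradOutput buffer; the copy above
            # then moves that result into temp_out, which is returned as the input gradient.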
344 |             temp_out = temp_out.contiguous()
345 |             gradFilters = gradFilters.contiguous()
346 |         return temp_out, gradFilters, None
347 | 
348 | class Lgf2Function(Function):
349 |     @staticmethod
350 |     def forward(ctx, input, filters, radius=2):
351 |         ctx.radius = radius
352 |         assert(input.is_contiguous() == True and filters.is_contiguous() == True)
353 |         with torch.cuda.device_of(input):
354 |             # num, channels, depth, height, width = input.size()
355 |             # temp_out = input.new().resize_(num, channels, depth, height, width).zero_()
356 |             # output = input.new().resize_(num, channels, depth, height, width).zero_()
357 |             temp_out = input.new().resize_(input.size()).zero_()
358 |             output = input.new().resize_(input.size()).zero_()
359 |             GANet.lgf_cuda_forward(input, filters, temp_out, radius)
360 |             GANet.lgf_cuda_forward(temp_out, filters, output, radius)
361 |             output = output.contiguous()
362 |             ctx.save_for_backward(input, filters, temp_out)
363 |         return output
364 |     @staticmethod
365 |     def backward(ctx, gradOutput):
366 |         input, filters, temp_out = ctx.saved_tensors
367 |         assert(gradOutput.is_contiguous() == True)
368 |         with torch.cuda.device_of(gradOutput):
369 |             # num, channels, depth, height, width = input.size()
370 |             # _, fsize, _, _ = filters.size()
371 |             # gradInput = gradOutput.new().resize_(num, channels, height, width).zero_()
372 |             # gradFilters = gradOutput.new().resize_(num, fsize, height, width).zero_()
373 |             gradFilters = gradOutput.new().resize_(filters.size()).zero_()
374 |             GANet.lgf_cuda_backward(temp_out, filters, gradOutput, temp_out, gradFilters, ctx.radius)
375 |             # temp_out[...] = 0
376 |             GANet.lgf_cuda_backward(input, filters, temp_out, gradOutput, gradFilters, ctx.radius)
377 |             temp_out[...] = gradOutput[...]
378 |             temp_out = temp_out.contiguous()
379 |             gradFilters = gradFilters.contiguous()
380 |         return temp_out, gradFilters, None
381 | 
382 | class LgaFunction(Function):
383 |     @staticmethod
384 |     def forward(ctx, input, filters, radius=2):
385 |         ctx.radius = radius
386 |         assert(input.is_contiguous() == True and filters.is_contiguous() == True)
387 |         with torch.cuda.device_of(input):
388 |             num, channels, height, width = input.size()
389 |             output = input.new().resize_(num, channels, height, width).zero_()
390 |             GANet.lga_cuda_forward(input, filters, output, radius)
391 |             output = output.contiguous()
392 |         ctx.save_for_backward(input, filters)
393 |         return output
394 |     @staticmethod
395 |     def backward(ctx, gradOutput):
396 |         input, filters = ctx.saved_tensors
397 |         assert(gradOutput.is_contiguous() == True)
398 |         with torch.cuda.device_of(gradOutput):
399 |             num, channels, height, width = input.size()
400 |             _, fsize, _, _ = filters.size()
401 |             gradInput = gradOutput.new().resize_(num, channels, height, width).zero_()
402 |             gradFilters = gradOutput.new().resize_(num, fsize, height, width).zero_()
403 |             GANet.lga_cuda_backward(input, filters, gradOutput, gradInput, gradFilters, ctx.radius)
404 |             gradInput = gradInput.contiguous()
405 |             gradFilters = gradFilters.contiguous()
406 |         return gradInput, gradFilters, None
407 | class MyLoss2Function(Function):
408 |     @staticmethod
409 |     def forward(ctx, input1, input2, thresh=1, alpha=2):
410 |         ctx.thresh = thresh
411 |         ctx.alpha = alpha
412 |         diff = input1 - input2
413 |         temp = torch.abs(diff)
414 |         temp[temp < thresh] = temp[temp < thresh] ** 2 / thresh
415 |         tag = (temp <= thresh + alpha) & (temp >= thresh)
416 |         temp[tag] = temp[tag] * 2 - (temp[tag] - thresh) ** 2 / (2.0 * alpha) - thresh
417 |         temp[temp > thresh + alpha] += (alpha / 2.0)
418 | 
ctx.save_for_backward(diff) 419 | return torch.mean(temp) 420 | @staticmethod 421 | def backward(ctx, gradOutput): 422 | diff, = ctx.saved_tensors 423 | scale = torch.abs(diff) 424 | scale[scale > ctx.thresh + ctx.alpha] = 1 425 | tag = (scale <= ctx.thresh + ctx.alpha) & (scale >= ctx.thresh) 426 | scale[tag] = 2 - (scale[tag] - ctx.thresh) / ctx.alpha 427 | tag = scale < ctx.thresh 428 | scale[tag] = 2*scale[tag] / ctx.thresh 429 | diff[diff > 0] = 1.0 430 | diff[diff < 0] = -1.0 431 | diff = diff * scale * gradOutput / scale.numel() 432 | return diff, Variable(torch.Tensor([0])), None, None 433 | 434 | class MyLossFunction(Function): 435 | 436 | @staticmethod 437 | def forward(ctx, input1, input2, upper_thresh=5, lower_thresh=1): 438 | ctx.upper_thresh = upper_thresh 439 | ctx.lower_thresh = lower_thresh 440 | diff = input1 - input2 441 | ctx.save_for_backward(diff) 442 | return torch.mean(torch.abs(diff)) 443 | @staticmethod 444 | def backward(ctx, gradOutput): 445 | diff, = ctx.saved_tensors 446 | scale = torch.abs(diff) 447 | scale[scale > ctx.upper_thresh] = 1 448 | tag = (scale <= ctx.upper_thresh) & (scale >= ctx.lower_thresh) 449 | scale[tag] = 2 - torch.abs(scale[tag]-(ctx.upper_thresh + ctx.lower_thresh)/2.)/2. 450 | diff[diff > 0] = 1 451 | diff[diff < 0] = -1 452 | diff = diff * scale * gradOutput 453 | return diff, Variable(torch.Tensor([0])), None, None 454 | 455 | 456 | -------------------------------------------------------------------------------- /libs/GANet/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .GANet import * 2 | -------------------------------------------------------------------------------- /libs/GANet/modules/GANet.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | import torch 3 | import numpy as np 4 | from torch.autograd import Variable 5 | from ..functions import * 6 | import torch.nn.functional as F 7 | 8 | from ..functions.GANet import MyLossFunction 9 | from ..functions.GANet import SgaFunction 10 | from ..functions.GANet import NlfFunction 11 | from ..functions.GANet import LgaFunction 12 | from ..functions.GANet import Lga2Function 13 | from ..functions.GANet import Lga3Function 14 | from ..functions.GANet import Lga3dFunction 15 | from ..functions.GANet import Lga3d2Function 16 | from ..functions.GANet import Lga3d3Function 17 | from ..functions.GANet import MyLoss2Function 18 | from ..functions.GANet import NlfDownFunction 19 | from ..functions.GANet import NlfUpFunction 20 | from ..functions.GANet import NlfRightFunction 21 | from ..functions.GANet import NlfLeftFunction 22 | class ChannelNorm(Module): 23 | def __init__(self, eps=1e-5): 24 | super(ChannelNorm, self).__init__() 25 | # self.weight = nn.Parameter(torch.ones(1,num_features,1,1)) 26 | # self.bias = nn.Parameter(torch.zeros(1,num_features,1,1)) 27 | # self.num_groups = num_groups 28 | self.eps = eps 29 | 30 | def forward(self, x): 31 | # N,C,H,W = x.size() 32 | # G = self.num_groups 33 | 34 | # x = x.view(N,G,-1) 35 | mean = x.mean(1, keepdim=True) 36 | var = x.var(1, keepdim=True) 37 | 38 | x = (x-mean) / (var+self.eps).sqrt() 39 | return x 40 | # x = x.view(N,C,H,W) 41 | # return x * self.weight + self.bias 42 | 43 | class GetWeights(Module): 44 | def __init__(self, wsize=5): 45 | super(GetWeights, self).__init__() 46 | self.wsize = wsize 47 | 48 | # self.disp = 
Variable(torch.Tensor(np.reshape(np.array(range(self.maxdisp)),[1,self.maxdisp,1,1])).cuda(), requires_grad=False) 49 | def forward(self, x): 50 | assert(x.is_contiguous() == True) 51 | with torch.cuda.device_of(x): 52 | # x = F.normalize(x, p=2, dim=1) 53 | num, channels, height, width = x.size() 54 | 55 | weight_down = x.new().resize_(num, 5, height, width).zero_() 56 | weight_down[:, 0, :, :] = torch.sum(x * x, 1) 57 | weight_down[:, 1, 1:, :] = torch.sum(x[:, :, 1:, :] * x[:, :, :-1, :], 1) 58 | weight_down[:, 2, 1:, 1:] = torch.sum(x[:, :, 1:, 1:] * x[:, :, :-1, :-1], 1) 59 | weight_down[:, 3, 1:, :-1] = torch.sum(x[:, :, 1:, :-1] * x[:, :, :-1, 1:], 1) 60 | weight_down[:, 4, :, 1:] = torch.sum(x[:, :, :, 1:] * x[:, :, :, :-1], 1) 61 | 62 | weight_up = x.new().resize_(num, 5, height, width).zero_() 63 | weight_up[:, 0, :, :] = torch.sum(x * x, 1) 64 | weight_up[:, 1, :-1, :] = torch.sum(x[:, :, :-1, :] * x[:, :, 1:, :], 1) 65 | weight_up[:, 2, :-1, 1:] = torch.sum(x[:, :, :-1, 1:] * x[:, :, 1:, :-1], 1) 66 | weight_up[:, 3, :-1, :-1] = torch.sum(x[:, :, :-1, :-1] * x[:, :, 1:, 1:], 1) 67 | weight_up[:, 4, :, :-1] = torch.sum(x[:, :, :, :-1] * x[:, :, :, 1:], 1) 68 | 69 | weight_right = x.new().resize_(num, 5, height, width).zero_() 70 | weight_right[:, 0, :, :] = torch.sum(x * x, 1) 71 | weight_right[:, 1, :, 1:] = torch.sum(x[:, :, :, 1:] * x[:, :, :, :-1], 1) 72 | weight_right[:, 2, 1:, 1:] = torch.sum(x[:, :, 1:, 1:] * x[:, :, :-1, :-1], 1) 73 | weight_right[:, 3, :-1, 1:] = torch.sum(x[:, :, :-1, 1:] * x[:, :, 1:, :-1], 1) 74 | weight_right[:, 4, 1:, :] = torch.sum(x[:, :, 1:, :] * x[:, :, :-1, :], 1) 75 | 76 | weight_left = x.new().resize_(num, 5, height, width).zero_() 77 | weight_left[:, 0, :, :] = torch.sum(x * x, 1) 78 | weight_left[:, 1, :, :-1] = torch.sum(x[:, :, :, :-1] * x[:, :, :, 1:], 1) 79 | weight_left[:, 2, 1:, :-1] = torch.sum(x[:, :, 1:, :-1] * x[:, :, :-1, 1:], 1) 80 | weight_left[:, 3, :-1, :-1] = torch.sum(x[:, :, :-1, :-1] * x[:, :, 1:, 1:], 1) 81 | weight_left[:, 4, :-1, :] = torch.sum(x[:, :, :-1, :] * x[:, :, 1:, :], 1) 82 | 83 | # weight_down = F.normalize(weight_down, p=1, dim=1) 84 | # weight_up = F.normalize(weight_up, p=1, dim=1) 85 | # weight_right = F.normalize(weight_right, p=1, dim=1) 86 | # weight_left = F.normalize(weight_left, p=1, dim=1) 87 | 88 | weight_down = F.softmax(weight_down, dim=1) 89 | weight_up = F.softmax(weight_up, dim=1) 90 | weight_right = F.softmax(weight_right, dim=1) 91 | weight_left = F.softmax(weight_left, dim=1) 92 | 93 | weight_down = weight_down.contiguous() 94 | weight_up = weight_up.contiguous() 95 | weight_right = weight_right.contiguous() 96 | weight_left = weight_left.contiguous() 97 | 98 | return weight_down, weight_up, weight_right, weight_left 99 | 100 | class GetFilters(Module): 101 | def __init__(self, radius=2): 102 | super(GetFilters, self).__init__() 103 | self.radius = radius 104 | self.wsize = (radius*2 + 1) * (radius*2 + 1) 105 | # self.disp = Variable(torch.Tensor(np.reshape(np.array(range(self.maxdisp)),[1,self.maxdisp,1,1])).cuda(), requires_grad=False) 106 | def forward(self, x): 107 | #assert(x.is_contiguous() == True) 108 | 109 | with torch.cuda.device_of(x): 110 | x = F.normalize(x, p=2, dim=1) 111 | # num, channels, height, width = x.size() 112 | # rem = torch.unsqueeze(x, 2).repeat(1, 1, self.wsize, 1, 1) 113 | # 114 | # temp = x.new().resize_(num, channels, self.wsize, height, width).zero_() 115 | # idx = 0 116 | # for r in range(-self.radius, self.radius+1): 117 | # for c in range(-self.radius, 
self.radius+1):
118 |             #         temp[:, :, idx, max(-r, 0):min(height - r, height), max(-c,0):min(width-c, width)] = x[:, :, max(r, 0):min(height + r, height), max(c, 0):min(width + c, width)]
119 |             #         idx += 1
120 |             # filters = torch.squeeze(torch.sum(rem*temp, 1), 1)
121 |             # filters = F.normalize(filters, p=1, dim=1)
122 |             # filters = filters.contiguous()
123 |             # return filters
124 | 
125 | 
126 |             num, channels, height, width = x.size()
127 |             filters = x.new().resize_(num, self.wsize, height, width).zero_()
128 |             idx = 0
129 |             for r in range(-self.radius, self.radius+1):
130 |                 for c in range(-self.radius, self.radius+1):
131 |                     filters[:, idx, max(-r, 0):min(height - r, height), max(-c,0):min(width-c, width)] = torch.squeeze(torch.sum(x[:, :, max(r, 0):min(height + r, height), max(c, 0):min(width + c, width)] * x[:,:,max(-r, 0):min(height-r, height), max(-c,0):min(width-c,width)], 1),1)
132 |                     idx += 1
133 | 
134 |             filters = F.normalize(filters, p=1, dim=1)
135 |             filters = filters.contiguous()
136 |         return filters
137 | 
138 | 
139 | 
140 | 
141 | 
142 | class MyNormalize(Module):
143 |     def __init__(self, dim):
144 |         self.dim = dim
145 |         super(MyNormalize, self).__init__()
146 |     def forward(self, x):
147 |         # assert(x.is_contiguous() == True)
148 |         with torch.cuda.device_of(x):
149 |             norm = torch.sum(torch.abs(x), self.dim)
150 |             norm[norm <= 0] = norm[norm <= 0] - 1e-6
151 |             norm[norm >= 0] = norm[norm >= 0] + 1e-6
152 |             norm = torch.unsqueeze(norm, self.dim)
153 |             size = np.ones(x.dim(), dtype='int')
154 |             size[self.dim] = x.size()[self.dim]
155 |             norm = norm.repeat(*size)
156 |             x = torch.div(x, norm)
157 |         return x
158 | class MyLoss2(Module):
159 |     def __init__(self, thresh=1, alpha=2):
160 |         super(MyLoss2, self).__init__()
161 |         self.thresh = thresh
162 |         self.alpha = alpha
163 |     def forward(self, input1, input2):
164 |         result = MyLoss2Function.apply(input1, input2, self.thresh, self.alpha)
165 |         return result
166 | class MyLoss(Module):
167 |     def __init__(self, upper_thresh=5, lower_thresh=1):
168 |         super(MyLoss, self).__init__()
169 |         self.upper_thresh = upper_thresh
170 |         self.lower_thresh = lower_thresh
171 |     def forward(self, input1, input2):
172 |         result = MyLossFunction.apply(input1, input2, self.upper_thresh, self.lower_thresh)
173 |         return result
174 | class NLFMax(Module):
175 |     def __init__(self):
176 |         super(NLFMax, self).__init__()
177 |     def forward(self, input, g0, g1, g2, g3):
178 |         result0 = NlfDownFunction.apply(input, g0)
179 |         result1 = NlfUpFunction.apply(input, g1)
180 |         result2 = NlfRightFunction.apply(input, g2)
181 |         result3 = NlfLeftFunction.apply(input, g3)
182 |         return torch.max(torch.max(torch.max(result0, result1), result2), result3)
183 | 
184 | class NLFMean(Module):
185 |     def __init__(self):
186 |         super(NLFMean, self).__init__()
187 |     def forward(self, input, g0, g1, g2, g3):
188 |         result0 = NlfDownFunction.apply(input, g0)
189 |         result1 = NlfUpFunction.apply(input, g1)
190 |         result2 = NlfRightFunction.apply(input, g2)
191 |         result3 = NlfLeftFunction.apply(input, g3)
192 |         # result1 = NlfUpFunction()(input, g1)
193 |         # result2 = NlfRightFunction()(input, g2)
194 |         # result3 = NlfLeftFunction()(input, g3)
195 |         # return torch.add(torch.add(torch.add(result0, result1), result2), result3)
196 |         return (result0 + result1 + result2 + result3) * 0.25
197 | class NLFIter(Module):
198 |     def __init__(self):
199 |         super(NLFIter, self).__init__()
200 |     def forward(self, input, g0, g1, g2, g3):
201 |         result = NlfDownFunction.apply(input, g0)
202 |         result = NlfUpFunction.apply(result, g1)
203 |         result = 
NlfRightFunction.apply(result, g2) 204 | result = NlfLeftFunction.apply(result, g3) 205 | return result 206 | class NLF(Module): 207 | def __init__(self): 208 | super(NLF, self).__init__() 209 | 210 | def forward(self, input, g0, g1, g2, g3): 211 | result = NlfFunction.apply(input, g0, g1, g2, g3) 212 | return result 213 | class SGA(Module): 214 | def __init__(self): 215 | super(SGA, self).__init__() 216 | 217 | def forward(self, input, g0, g1, g2, g3): 218 | result = SgaFunction.apply(input, g0, g1, g2, g3) 219 | return result 220 | 221 | 222 | 223 | class LGA3D3(Module): 224 | def __init__(self, radius=2): 225 | super(LGA3D3, self).__init__() 226 | self.radius = radius 227 | 228 | def forward(self, input1, input2): 229 | result = Lga3d3Function.apply(input1, input2, self.radius) 230 | return result 231 | class LGA3D2(Module): 232 | def __init__(self, radius=2): 233 | super(LGA3D2, self).__init__() 234 | self.radius = radius 235 | 236 | def forward(self, input1, input2): 237 | result = Lga3d2Function.apply(input1, input2, self.radius) 238 | return result 239 | class LGA3D(Module): 240 | def __init__(self, radius=2): 241 | super(LGA3D, self).__init__() 242 | self.radius = radius 243 | 244 | def forward(self, input1, input2): 245 | result = Lga3dFunction.apply(input1, input2, self.radius) 246 | return result 247 | 248 | class LGA3(Module): 249 | def __init__(self, radius=2): 250 | super(LGA3, self).__init__() 251 | self.radius = radius 252 | 253 | def forward(self, input1, input2): 254 | result = Lga3Function.apply(input1, input2, self.radius) 255 | return result 256 | class LGA2(Module): 257 | def __init__(self, radius=2): 258 | super(LGA2, self).__init__() 259 | self.radius = radius 260 | 261 | def forward(self, input1, input2): 262 | result = Lga2Function.apply(input1, input2, self.radius) 263 | return result 264 | class LGA(Module): 265 | def __init__(self, radius=2): 266 | super(LGA, self).__init__() 267 | self.radius = radius 268 | 269 | def forward(self, input1, input2): 270 | result = LgaFunction.apply(input1, input2, self.radius) 271 | return result 272 | 273 | 274 | 275 | class GetCostVolume(Module): 276 | def __init__(self, maxdisp): 277 | super(GetCostVolume, self).__init__() 278 | self.maxdisp=maxdisp+1 279 | 280 | def forward(self, x,y): 281 | assert(x.is_contiguous() == True) 282 | with torch.cuda.device_of(x): 283 | num, channels, height, width = x.size() 284 | cost = x.new().resize_(num, channels*2, self.maxdisp, height, width).zero_() 285 | # cost = Variable(torch.FloatTensor(x.size()[0], x.size()[1]*2, self.maxdisp, x.size()[2], x.size()[3]).zero_(), volatile= not self.training).cuda() 286 | for i in range(self.maxdisp): 287 | if i > 0 : 288 | cost[:, :x.size()[1], i, :,i:] = x[:,:,:,i:] 289 | cost[:, x.size()[1]:, i, :,i:] = y[:,:,:,:-i] 290 | else: 291 | cost[:, :x.size()[1], i, :,:] = x 292 | cost[:, x.size()[1]:, i, :,:] = y 293 | 294 | cost = cost.contiguous() 295 | return cost 296 | 297 | class DisparityRegression(Module): 298 | def __init__(self, maxdisp): 299 | super(DisparityRegression, self).__init__() 300 | self.maxdisp = maxdisp+1 301 | # self.disp = Variable(torch.Tensor(np.reshape(np.array(range(self.maxdisp)),[1,self.maxdisp,1,1])).cuda(), requires_grad=False) 302 | 303 | def forward(self, x): 304 | assert(x.is_contiguous() == True) 305 | with torch.cuda.device_of(x): 306 | disp = Variable(torch.Tensor(np.reshape(np.array(range(self.maxdisp)),[1,self.maxdisp,1,1])).cuda(), requires_grad=False) 307 | disp = disp.repeat(x.size()[0],1,x.size()[2],x.size()[3]) 
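            # Soft-argmax regression: with x assumed to hold a probability volume that
            # sums to 1 over the disparity dimension (dim 1), the weighted sum below
            # returns the expected disparity, out[b, h, w] = sum_d d * x[b, d, h, w].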
308 | out = torch.sum(x*disp,1) 309 | return out 310 | 311 | -------------------------------------------------------------------------------- /libs/GANet/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .GANet import * 2 | -------------------------------------------------------------------------------- /libs/GANet/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import CppExtension, BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='GANet', 7 | ext_modules=[ 8 | CUDAExtension('GANet', [ 9 | 'src/GANet_cuda.cpp', 10 | 'src/GANet_kernel_share.cu', 11 | 'src/NLF_kernel.cu', 12 | ]) 13 | ], 14 | cmdclass={ 15 | 'build_ext': BuildExtension 16 | }) 17 | -------------------------------------------------------------------------------- /libs/GANet/src/GANet_cuda.cpp: -------------------------------------------------------------------------------- 1 | //#include 2 | #include 3 | #include "GANet_kernel.h" 4 | 5 | extern "C" int 6 | lga_cuda_backward (at::Tensor input, at::Tensor filters, 7 | at::Tensor gradOutput, at::Tensor gradInput, 8 | at::Tensor gradFilters, const int radius) 9 | { 10 | lga_backward (input, filters, gradOutput, gradInput, gradFilters, radius); 11 | return 1; 12 | } 13 | 14 | extern "C" int 15 | lga_cuda_forward (at::Tensor input, at::Tensor filters, at::Tensor output, 16 | const int radius) 17 | { 18 | lga_forward (input, filters, output, radius); 19 | return 1; 20 | } 21 | 22 | extern "C" int 23 | lga3d_cuda_backward (at::Tensor input, at::Tensor filters, 24 | at::Tensor gradOutput, at::Tensor gradInput, 25 | at::Tensor gradFilters, const int radius) 26 | { 27 | lga3d_backward (input, filters, gradOutput, gradInput, gradFilters, radius); 28 | return 1; 29 | } 30 | 31 | extern "C" int 32 | lga3d_cuda_forward (at::Tensor input, at::Tensor filters, at::Tensor output, 33 | const int radius) 34 | { 35 | lga3d_forward (input, filters, output, radius); 36 | return 1; 37 | } 38 | 39 | extern "C" int 40 | sga_cuda_forward (at::Tensor input, at::Tensor guidance_down, 41 | at::Tensor guidance_up, at::Tensor guidance_right, 42 | at::Tensor guidance_left, at::Tensor temp_out, 43 | at::Tensor output, at::Tensor mask) 44 | { 45 | sga_kernel_forward (input, guidance_down, guidance_up, guidance_right, 46 | guidance_left, temp_out, output, mask); 47 | return 1; 48 | } 49 | 50 | extern "C" int 51 | sga_cuda_backward (at::Tensor input, at::Tensor guidance_down, 52 | at::Tensor guidance_up, at::Tensor guidance_right, 53 | at::Tensor guidance_left, at::Tensor temp_out, 54 | at::Tensor mask, at::Tensor max_idx, at::Tensor gradOutput, 55 | at::Tensor temp_grad, at::Tensor gradInput, 56 | at::Tensor grad_down, at::Tensor grad_up, 57 | at::Tensor grad_right, at::Tensor grad_left) 58 | { 59 | sga_kernel_backward (input, guidance_down, guidance_up, guidance_right, 60 | guidance_left, temp_out, mask, max_idx, gradOutput, 61 | temp_grad, gradInput, grad_down, grad_up, grad_right, 62 | grad_left); 63 | return 1; 64 | } 65 | 66 | extern "C" int 67 | nlf_cuda_backward (at::Tensor input, at::Tensor guidance_down, 68 | at::Tensor guidance_up, at::Tensor guidance_right, 69 | at::Tensor guidance_left, at::Tensor output_down, 70 | at::Tensor output_up, at::Tensor output_right, 71 | at::Tensor output_left, at::Tensor gradOutput, 72 | at::Tensor gradInput, at::Tensor grad_down, 73 | at::Tensor grad_up, at::Tensor grad_right, 74 | at::Tensor 
grad_left) 75 | { 76 | nlf_kernel_backward(input, guidance_down, guidance_up, guidance_right, guidance_left, output_down, output_up, output_right, output_left, gradOutput, gradInput, grad_down, grad_up, grad_right, grad_left); 77 | return 1; 78 | } 79 | extern "C" int 80 | nlf_cuda_forward (at::Tensor input, at::Tensor guidance_down, 81 | at::Tensor guidance_up, at::Tensor guidance_right, 82 | at::Tensor guidance_left, at::Tensor output_down, 83 | at::Tensor output_up, at::Tensor output_right, 84 | at::Tensor output_left) 85 | { 86 | nlf_kernel_forward(input, guidance_down, guidance_up, guidance_right, guidance_left, output_down, output_up, output_right, output_left); 87 | return 1; 88 | } 89 | 90 | extern "C" int 91 | nlf_down_cuda_forward (at::Tensor input, at::Tensor guidance_down, at::Tensor output_down){ 92 | nlf_down_kernel_forward (input, guidance_down, output_down); 93 | return 1; 94 | 95 | } 96 | extern "C" int 97 | nlf_down_cuda_backward (at::Tensor input, at::Tensor guidance_down, at::Tensor output_down, at::Tensor gradOutput, 98 | at::Tensor gradInput, at::Tensor grad_down){ 99 | nlf_down_kernel_backward (input, guidance_down, output_down, gradOutput, gradInput, grad_down); 100 | return 1; 101 | } 102 | extern "C" int 103 | nlf_up_cuda_forward (at::Tensor input, at::Tensor guidance_up, at::Tensor output_up){ 104 | nlf_up_kernel_forward (input, guidance_up, output_up); 105 | return 1; 106 | } 107 | extern "C" int 108 | nlf_up_cuda_backward (at::Tensor input, at::Tensor guidance_up, at::Tensor output_up, at::Tensor gradOutput, 109 | at::Tensor gradInput, at::Tensor grad_up){ 110 | nlf_up_kernel_backward (input, guidance_up, output_up, gradOutput, gradInput, grad_up); 111 | return 1; 112 | } 113 | extern "C" int 114 | nlf_left_cuda_forward (at::Tensor input, at::Tensor guidance_left, at::Tensor output_left){ 115 | nlf_left_kernel_forward (input, guidance_left, output_left); 116 | return 1; 117 | } 118 | extern "C" int 119 | nlf_left_cuda_backward (at::Tensor input, at::Tensor guidance_left, at::Tensor output_left, at::Tensor gradOutput, 120 | at::Tensor gradInput, at::Tensor grad_left){ 121 | nlf_left_kernel_backward (input, guidance_left, output_left, gradOutput, gradInput, grad_left); 122 | return 1; 123 | } 124 | extern "C" int 125 | nlf_right_cuda_forward (at::Tensor input, at::Tensor guidance_right, at::Tensor output_right){ 126 | nlf_right_kernel_forward (input, guidance_right, output_right); 127 | return 1; 128 | } 129 | extern "C" int 130 | nlf_right_cuda_backward (at::Tensor input, at::Tensor guidance_right, at::Tensor output_right, at::Tensor gradOutput, 131 | at::Tensor gradInput, at::Tensor grad_right){ 132 | nlf_right_kernel_backward (input, guidance_right, output_right, gradOutput, gradInput, grad_right); 133 | return 1; 134 | } 135 | /* 136 | extern "C" int 137 | lgf_cuda_forward (at::Tensor input, at::Tensor filters, at::Tensor output, 138 | const int radius){ 139 | lgf_kernel_forward (input, filters, output, radius); 140 | return 1; 141 | } 142 | extern "C" int 143 | lgf_cuda_backward (at::Tensor input, at::Tensor filters, at::Tensor gradOutput, 144 | at::Tensor gradInput, at::Tensor gradFilters, const int radius){ 145 | lgf_kernel_backward (input, filters, gradOutput, gradInput, gradFilters, radius); 146 | return 1; 147 | }*/ 148 | /* 149 | extern "C" int 150 | sgf3d_down_cuda_forward (at::Tensor input, at::Tensor guidance_down, at::Tensor output_down){ 151 | sgf3d_down_kernel_forward (input, guidance_down, output_down); 152 | return 1; 153 | 154 | } 155 | extern 
"C" int 156 | sgf3d_down_cuda_backward (at::Tensor input, at::Tensor guidance_down, at::Tensor output_down, at::Tensor gradOutput, 157 | at::Tensor gradInput, at::Tensor grad_down){ 158 | sgf3d_down_kernel_backward (input, guidance_down, output_down, gradOutput, gradInput, grad_down); 159 | return 1; 160 | } 161 | extern "C" int 162 | sgf3d_up_cuda_forward (at::Tensor input, at::Tensor guidance_up, at::Tensor output_up){ 163 | sgf3d_up_kernel_forward (input, guidance_up, output_up); 164 | return 1; 165 | } 166 | extern "C" int 167 | sgf3d_up_cuda_backward (at::Tensor input, at::Tensor guidance_up, at::Tensor output_up, at::Tensor gradOutput, 168 | at::Tensor gradInput, at::Tensor grad_up){ 169 | sgf3d_up_kernel_backward (input, guidance_up, output_up, gradOutput, gradInput, grad_up); 170 | return 1; 171 | } 172 | extern "C" int 173 | sgf3d_left_cuda_forward (at::Tensor input, at::Tensor guidance_left, at::Tensor output_left){ 174 | sgf3d_left_kernel_forward (input, guidance_left, output_left); 175 | return 1; 176 | } 177 | extern "C" int 178 | sgf3d_left_cuda_backward (at::Tensor input, at::Tensor guidance_left, at::Tensor output_left, at::Tensor gradOutput, 179 | at::Tensor gradInput, at::Tensor grad_left){ 180 | sgf3d_left_kernel_backward (input, guidance_left, output_left, gradOutput, gradInput, grad_left); 181 | return 1; 182 | } 183 | extern "C" int 184 | sgf3d_right_cuda_forward (at::Tensor input, at::Tensor guidance_right, at::Tensor output_right){ 185 | sgf3d_right_kernel_forward (input, guidance_right, output_right); 186 | return 1; 187 | } 188 | extern "C" int 189 | sgf3d_right_cuda_backward (at::Tensor input, at::Tensor guidance_right, at::Tensor output_right, at::Tensor gradOutput, 190 | at::Tensor gradInput, at::Tensor grad_right){ 191 | sgf3d_right_kernel_backward (input, guidance_right, output_right, gradOutput, gradInput, grad_right); 192 | return 1; 193 | } 194 | */ 195 | PYBIND11_MODULE (TORCH_EXTENSION_NAME, GANet) 196 | { 197 | GANet.def ("lga_cuda_forward", &lga_cuda_forward, "lga forward (CUDA)"); 198 | GANet.def ("lga_cuda_backward", &lga_cuda_backward, "lga backward (CUDA)"); 199 | GANet.def ("lga3d_cuda_forward", &lga3d_cuda_forward, "lga3d forward (CUDA)"); 200 | GANet.def ("lga3d_cuda_backward", &lga3d_cuda_backward, "lga3d backward (CUDA)"); 201 | GANet.def ("sga_cuda_backward", &sga_cuda_backward, "sga backward (CUDA)"); 202 | GANet.def ("sga_cuda_forward", &sga_cuda_forward, "sga forward (CUDA)"); 203 | GANet.def ("nlf_cuda_backward", &nlf_cuda_backward, "sgf backward (CUDA)"); 204 | GANet.def ("nlf_cuda_forward", &nlf_cuda_forward, "sgf forward (CUDA)"); 205 | GANet.def ("nlf_down_cuda_forward", &nlf_down_cuda_forward, "sgf down forward (CUDA)"); 206 | GANet.def ("nlf_down_cuda_backward", &nlf_down_cuda_backward, "sgf down backward (CUDA)"); 207 | GANet.def ("nlf_up_cuda_forward", &nlf_up_cuda_forward, "sgf up forward (CUDA)"); 208 | GANet.def ("nlf_up_cuda_backward", &nlf_up_cuda_backward, "sgf up backward (CUDA)"); 209 | GANet.def ("nlf_right_cuda_forward", &nlf_right_cuda_forward, "sgf right forward (CUDA)"); 210 | GANet.def ("nlf_right_cuda_backward", &nlf_right_cuda_backward, "sgf right backward (CUDA)"); 211 | GANet.def ("nlf_left_cuda_forward", &nlf_left_cuda_forward, "sgf left forward (CUDA)"); 212 | GANet.def ("nlf_left_cuda_backward", &nlf_left_cuda_backward, "sgf left backward (CUDA)"); 213 | // GANet.def ("lgf_cuda_forward", &lgf_cuda_forward, "lgf forward (CUDA)"); 214 | // GANet.def ("lgf_cuda_backward", &lgf_cuda_backward, "lgf backward 
(CUDA)"); 215 | /* GANet.def ("sgf3d_down_cuda_forward", &sgf3d_down_cuda_forward, "sgf3d down forward (CUDA)"); 216 | GANet.def ("sgf3d_down_cuda_backward", &sgf3d_down_cuda_backward, "sgf3d down backward (CUDA)"); 217 | GANet.def ("sgf3d_up_cuda_forward", &sgf3d_up_cuda_forward, "sgf3d up forward (CUDA)"); 218 | GANet.def ("sgf3d_up_cuda_backward", &sgf3d_up_cuda_backward, "sgf3d up backward (CUDA)"); 219 | GANet.def ("sgf3d_right_cuda_forward", &sgf3d_right_cuda_forward, "sgf3d right forward (CUDA)"); 220 | GANet.def ("sgf3d_right_cuda_backward", &sgf3d_right_cuda_backward, "sgf3d right backward (CUDA)"); 221 | GANet.def ("sgf3d_left_cuda_forward", &sgf3d_left_cuda_forward, "sgf3d left forward (CUDA)"); 222 | GANet.def ("sgf3d_left_cuda_backward", &sgf3d_left_cuda_backward, "sgf3d left backward (CUDA)"); 223 | */ 224 | } 225 | 226 | -------------------------------------------------------------------------------- /libs/GANet/src/GANet_cuda.h: -------------------------------------------------------------------------------- 1 | int lga_cuda_backward (at::Tensor input, at::Tensor filters, 2 | at::Tensor gradOutput, at::Tensor gradInput, 3 | at::Tensor gradFilters, const int radius); 4 | int lga_cuda_forward (at::Tensor input, at::Tensor filters, at::Tensor output, 5 | const int radius); 6 | int lga3d_cuda_backward (at::Tensor input, at::Tensor filters, 7 | at::Tensor gradOutput, at::Tensor gradInput, 8 | at::Tensor gradFilters, const int radius); 9 | int lga3d_cuda_forward (at::Tensor input, at::Tensor filters, 10 | at::Tensor output, const int radius); 11 | int sga_cuda_forward (at::Tensor input, at::Tensor guidance_down, 12 | at::Tensor guidance_up, at::Tensor guidance_right, 13 | at::Tensor guidance_left, at::Tensor temp_out, 14 | at::Tensor output, at::Tensor mask); 15 | int sga_cuda_backward (at::Tensor input, at::Tensor guidance_down, 16 | at::Tensor guidance_up, at::Tensor guidance_right, 17 | at::Tensor guidance_left, at::Tensor temp_out, 18 | at::Tensor mask, at::Tensor max_idx, 19 | at::Tensor gradOutput, at::Tensor temp_grad, 20 | at::Tensor gradInput, at::Tensor grad_down, 21 | at::Tensor grad_up, at::Tensor grad_right, 22 | at::Tensor grad_left); 23 | 24 | int nlf_cuda_backward (at::Tensor input, at::Tensor guidance_down, 25 | at::Tensor guidance_up, at::Tensor guidance_right, 26 | at::Tensor guidance_left, at::Tensor output_down, 27 | at::Tensor output_up, at::Tensor output_right, 28 | at::Tensor output_left, at::Tensor gradOutput, 29 | at::Tensor gradInput, at::Tensor grad_down, 30 | at::Tensor grad_up, at::Tensor grad_right, 31 | at::Tensor grad_left); 32 | int nlf_cuda_forward (at::Tensor input, at::Tensor guidance_down, 33 | at::Tensor guidance_up, at::Tensor guidance_right, 34 | at::Tensor guidance_left, at::Tensor output_down, 35 | at::Tensor output_up, at::Tensor output_right, 36 | at::Tensor output_left); 37 | int nlf_left_cuda_forward (at::Tensor input, at::Tensor guidance_left, at::Tensor output_left); 38 | int nlf_left_cuda_backward (at::Tensor input, at::Tensor guidance_left, at::Tensor output_left, at::Tensor gradOutput, 39 | at::Tensor gradInput, at::Tensor grad_left); 40 | int nlf_right_cuda_forward (at::Tensor input, at::Tensor guidance_right, at::Tensor output_right); 41 | int nlf_right_cuda_backward (at::Tensor input, at::Tensor guidance_right, at::Tensor output_right, at::Tensor gradOutput, 42 | at::Tensor gradInput, at::Tensor grad_right); 43 | int nlf_down_cuda_forward (at::Tensor input, at::Tensor guidance_down, at::Tensor output_down); 44 | int 
nlf_down_cuda_backward (at::Tensor input, at::Tensor guidance_down, at::Tensor output_down, at::Tensor gradOutput, 45 | at::Tensor gradInput, at::Tensor grad_down); 46 | int nlf_up_cuda_forward (at::Tensor input, at::Tensor guidance_up, at::Tensor output_up); 47 | int nlf_up_cuda_backward (at::Tensor input, at::Tensor guidance_up, at::Tensor output_up, at::Tensor gradOutput, 48 | at::Tensor gradInput, at::Tensor grad_up); 49 | 50 | 51 | -------------------------------------------------------------------------------- /libs/GANet/src/GANet_kernel.h: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | void nlf_kernel_backward (at::Tensor input, at::Tensor guidance_down, 8 | at::Tensor guidance_up, at::Tensor guidance_right, 9 | at::Tensor guidance_left, at::Tensor output_down, 10 | at::Tensor output_up, at::Tensor output_right, 11 | at::Tensor output_left, at::Tensor gradOutput, 12 | at::Tensor gradInput, at::Tensor grad_down, 13 | at::Tensor grad_up, at::Tensor grad_right, 14 | at::Tensor grad_left); 15 | void nlf_kernel_forward (at::Tensor input, at::Tensor guidance_down, 16 | at::Tensor guidance_up, at::Tensor guidance_right, 17 | at::Tensor guidance_left, at::Tensor output_down, 18 | at::Tensor output_up, at::Tensor output_right, 19 | at::Tensor output_left); 20 | 21 | 22 | void sga_kernel_forward (at::Tensor input, at::Tensor guidance_down, 23 | at::Tensor guidance_up, at::Tensor guidance_right, 24 | at::Tensor guidance_left, at::Tensor temp_out, 25 | at::Tensor output, at::Tensor mask); 26 | void sga_kernel_backward (at::Tensor input, at::Tensor guidance_down, 27 | at::Tensor guidance_up, at::Tensor guidance_right, 28 | at::Tensor guidance_left, at::Tensor temp_out, 29 | at::Tensor mask, at::Tensor max_idx, 30 | at::Tensor gradOutput, at::Tensor temp_grad, 31 | at::Tensor gradInput, at::Tensor grad_down, 32 | at::Tensor grad_up, at::Tensor grad_right, 33 | at::Tensor grad_left); 34 | 35 | void lga_backward (at::Tensor input, at::Tensor filters, 36 | at::Tensor gradOutput, at::Tensor gradInput, 37 | at::Tensor gradFilters, const int radius); 38 | void lga_forward (at::Tensor input, at::Tensor filters, at::Tensor output, 39 | const int radius); 40 | 41 | void lga3d_backward (at::Tensor input, at::Tensor filters, 42 | at::Tensor gradOutput, at::Tensor gradInput, 43 | at::Tensor gradFilters, const int radius); 44 | void lga3d_forward (at::Tensor input, at::Tensor filters, at::Tensor output, 45 | const int radius); 46 | void nlf_down_kernel_forward (at::Tensor input, at::Tensor guidance_down, at::Tensor output_down); 47 | void nlf_down_kernel_backward (at::Tensor input, at::Tensor guidance_down, at::Tensor output_down, at::Tensor gradOutput, 48 | at::Tensor gradInput, at::Tensor grad_down); 49 | void nlf_up_kernel_forward (at::Tensor input, at::Tensor guidance_up, at::Tensor output_up); 50 | void nlf_up_kernel_backward (at::Tensor input, at::Tensor guidance_up, at::Tensor output_up, at::Tensor gradOutput, 51 | at::Tensor gradInput, at::Tensor grad_up); 52 | 53 | void nlf_right_kernel_forward (at::Tensor input, at::Tensor guidance_right, at::Tensor output_right); 54 | void nlf_right_kernel_backward (at::Tensor input, at::Tensor guidance_right, at::Tensor output_right, at::Tensor gradOutput, 55 | at::Tensor gradInput, at::Tensor grad_right); 56 | 57 | void nlf_left_kernel_forward (at::Tensor input, at::Tensor guidance_left, at::Tensor output_left); 58 | void nlf_left_kernel_backward 
(at::Tensor input, at::Tensor guidance_left, at::Tensor output_left, at::Tensor gradOutput,
59 |                                 at::Tensor gradInput, at::Tensor grad_left);
60 | 
61 | /*
62 | void sgf3d_down_kernel_forward (at::Tensor input, at::Tensor guidance_down, at::Tensor output_down);
63 | void sgf3d_down_kernel_backward (at::Tensor input, at::Tensor guidance_down, at::Tensor output_down, at::Tensor gradOutput,
64 |                                  at::Tensor gradInput, at::Tensor grad_down);
65 | void sgf3d_up_kernel_forward (at::Tensor input, at::Tensor guidance_up, at::Tensor output_up);
66 | void sgf3d_up_kernel_backward (at::Tensor input, at::Tensor guidance_up, at::Tensor output_up, at::Tensor gradOutput,
67 |                                at::Tensor gradInput, at::Tensor grad_up);
68 | 
69 | void sgf3d_right_kernel_forward (at::Tensor input, at::Tensor guidance_right, at::Tensor output_right);
70 | void sgf3d_right_kernel_backward (at::Tensor input, at::Tensor guidance_right, at::Tensor output_right, at::Tensor gradOutput,
71 |                                   at::Tensor gradInput, at::Tensor grad_right);
72 | 
73 | void sgf3d_left_kernel_forward (at::Tensor input, at::Tensor guidance_left, at::Tensor output_left);
74 | void sgf3d_left_kernel_backward (at::Tensor input, at::Tensor guidance_left, at::Tensor output_left, at::Tensor gradOutput,
75 |                                  at::Tensor gradInput, at::Tensor grad_left);
76 | */
77 | #ifdef __cplusplus
78 | }
79 | #endif
80 | 
--------------------------------------------------------------------------------
/libs/GANet/src/costvolume.cu:
--------------------------------------------------------------------------------
1 | #include
2 | //#include
3 | //#include
4 | //#include
5 | //#include
6 | //#include
7 | 
8 | #define CUDA_NUM_THREADS 256
9 | #define THREADS_PER_BLOCK 64
10 | 
11 | #define DIM0(TENSOR) ((TENSOR).x)
12 | #define DIM1(TENSOR) ((TENSOR).y)
13 | #define DIM2(TENSOR) ((TENSOR).z)
14 | #define DIM3(TENSOR) ((TENSOR).w)
15 | 
16 | #define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))])
17 | 
18 | #ifdef __cplusplus
19 | extern "C" {
20 | #endif
21 | 
22 | 
23 | 
24 | 
25 | __global__ void cost_kernel_forward(const int n, const float* x, const float* y, const int shift1, const int shift2, const int stride1, const int stride2, const int heightx, const int widthx, const int heighty, const int widthy, const int channel, float* top_data) {
26 | 
27 |     int index = blockIdx.x * blockDim.x + threadIdx.x;
28 | 
29 |     if (index >= n) {
30 |         return;
31 |     }
32 |     int stepx = heightx * widthx;
33 |     int stepy = heighty * widthy;
34 |     int size1 = (shift1*2+1);
35 |     int size2 = (shift2*2+1);
36 |     int loc0 = index/(size1*size2*stepx);
37 |     int loc34 = index%stepx;
38 |     int loc1 = index/(size2*stepx)%size1-shift1;
39 |     int loc2 = index/stepx%size2-shift2;
40 | 
41 |     int cur_x = loc0*stepx*channel + loc34;
42 |     int row = loc34/widthx/stride1;
43 |     int col = (loc34%widthx)/stride2;
44 |     if(row+loc1<0||row+loc1>=heighty||col+loc2<0 ||col+loc2>=widthy){
45 |         top_data[index]=0;
46 |         return;
47 |     }
48 |     int cur_y = loc0*stepy*channel + (row+loc1)*widthy + col+loc2;
49 |     float temp = 0;
50 |     for(int i=0;i<channel;i++){
51 |         temp += x[cur_x + i*stepx] * y[cur_y + i*stepy];
52 |     }
53 |     top_data[index] = temp;
54 | }
55 | 
56 | 
57 | 
58 | __global__ void cost_backward_x(const int n, const float* y, const float* top_diff, const int shift1, const int shift2, const int stride1, const int stride2, const int heightx, const int widthx, const int heighty, const int widthy, const int channel, float* bottom_diff){
59 | 
60 |     int index = blockIdx.x * blockDim.x + threadIdx.x;
61 | 
62 | 
63 |     if (index >= n) {
64 |         return;
65 |     }
66 |     int stepx = heightx * widthx;
67 |     int stepy = heighty * widthy;
68 |     int size1 = (shift1*2+1);
69 |     int size2 = (shift2*2+1);
70 |     int loc0 = index/(channel*stepx);
71 |     int loc34 = index%stepx;
72 |     int rowx = loc34/widthx;
73 |     int colx = loc34%widthx;
74 |     int rowy = rowx/stride1;
75 |     int coly= colx/stride2;
76 |     int basey = index/stepx*stepy;
77 |     int base = loc0*size1*size2*stepx + loc34;
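    // One thread per element of input1's gradient: loc0 is the batch index and loc34
    // the pixel within the H x W plane; the nested loops below accumulate top_diff over
    // the (2*shift1+1) x (2*shift2+1) displacement window, weighted by the matching y values.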
78 |     float temp = 0;
79 |     for(int i=0;i<size1;i++){
80 |         int ry = rowy + i - shift1;
81 |         if(ry<0||ry>=heighty)
82 |             continue;
83 |         int basei = base + i*size2*stepx;
84 |         int baseyi = basey + ry * widthy;
85 |         for(int j=0;j<size2;j++){
86 |             int cy = coly + j - shift2;
87 |             if(cy<0||cy>=widthy)
88 |                 continue;
89 |             temp += top_diff[basei+j*stepx] * y[baseyi+cy];
90 |         }
91 |     }
92 |     bottom_diff[index]=temp;
93 | }
94 | __global__ void cost_backward_y(const int n, const float* x, const float* top_diff, const int shift1, const int shift2, const int stride1, const int stride2, const int heightx, const int widthx, const int heighty, const int widthy, const int channel, float* bottom_diff){
95 | 
96 |     int index = blockIdx.x * blockDim.x + threadIdx.x;
97 | 
98 |     if (index >= n) {
99 |         return;
100 |     }
101 |     int stepx = heightx * widthx;
102 |     int stepy = heighty * widthy;
103 |     int size1 = (shift1*2+1);
104 |     int size2 = (shift2*2+1);
105 |     int loc0 = index/(channel*stepy);
106 |     int loc34 = index%stepy;
107 |     int rowy = loc34/widthy;
108 |     int coly = loc34%widthy;
109 |     int rowx = rowy*stride1;
110 |     int colx= coly*stride2;
111 |     int basex = index/stepy*stepx;
112 |     int base = loc0*size1*size2*stepx;
113 |     float temp = 0;
114 |     for(int i=-shift1*stride1;i<(shift1+1)*stride1;i++){
115 |         int rx = rowx + i;
116 |         if(rx<0||rx>=heightx)
117 |             continue;
118 |         int loc1 = -i/stride1 + shift1;
119 |         int basei = base + loc1*size2*stepx + rx*widthx;
120 |         int basexi = basex + rx * widthx;
121 |         for(int j=-shift2*stride2;j<(shift2+1)*stride2;j++){
122 |             int cx = colx + j;
123 |             if(cx<0||cx>=widthx)
124 |                 continue;
125 |             int loc2 = -j/stride2 + shift2;
126 |             int curx = basexi + cx;
127 |             int cur = basei + loc2*stepx +cx;
128 |             temp += top_diff[cur] * x[curx];
129 |         }
130 |     }
131 |     bottom_diff[index]=temp;
132 | }
133 | 
134 | 
135 | 
136 | 
137 | void cost_volume_forward (at::Tensor input1, at::Tensor input2,
138 |                           at::Tensor output, const int shift1=48,
139 |                           const int shift2 = 48, const int stride1=1,
140 |                           const int stride2=1){
141 | 
142 |     int num = input1.size(0);
143 |     int channel = input1.size(1);
144 |     int heightx = input1.size(2);
145 |     int widthx = input1.size(3);
146 |     int heighty = input2.size(2);
147 |     int widthy = input2.size(3);
148 | 
149 |     float *cost = output.data<float>();
150 | 
151 |     const float *x = input1.data<float>();
152 |     const float *y = input2.data<float>();
153 | 
154 |     int n = output.numel();
155 |     int threads = (n + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
156 |     // printf("%d %d %d %d %d %d %d %d\n", num, channel, height, width, wsize, n, threads, N);
157 |     cost_kernel_forward<<<threads, CUDA_NUM_THREADS>>>(n, x, y, shift1, shift2, stride1, stride2, heightx, widthx, heighty, widthy, channel, cost);
158 |     // printf("sgf down done...\n");
159 | }
160 | void cost_volume_backward (at::Tensor input1, at::Tensor input2,
161 |                            at::Tensor grad_output, at::Tensor grad_input1, at::Tensor grad_input2,
162 |                            const int shift1=48, const int shift2 = 48, const int stride1=1,
163 |                            const int stride2=1){
164 | 
165 |     int num = input1.size(0);
166 |     int channel = input1.size(1);
167 |     int heightx = input1.size(2);
168 |     int widthx = input1.size(3);
169 |     int heighty = input2.size(2);
170 |     int widthy = input2.size(3);
171 | 
172 |     const float *grad_out = grad_output.data<float>();
173 |     const float *x = input1.data<float>();
174 |     const float *y = input2.data<float>();
175 | 
176 |     float *gradx = grad_input1.data<float>();
177 |     float *grady = grad_input2.data<float>();
178 | 
179 |     int n = input1.numel();
180 |     int threads = (n + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
181 |     // printf("%d %d %d %d %d %d %d %d\n", num, channel, height, width, wsize, n, threads, N);
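    // The two gradient kernels below are launched separately because the two inputs
    // may have different spatial sizes: each launch parallelizes over the numel() of
    // the tensor whose gradient it fills.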
182 |     cost_backward_x<<<threads, CUDA_NUM_THREADS>>>(n, y, grad_out, shift1, shift2, stride1, stride2, heightx, widthx, heighty, widthy, channel, gradx);
183 |     n = input2.numel();
184 |     threads = (n + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
185 |     cost_backward_y<<<threads, CUDA_NUM_THREADS>>>(n, x, grad_out, shift1, shift2, stride1, stride2, heightx, widthx, heighty, widthy, channel, grady);
186 |     // printf("sgf down done...\n");
187 | }
188 | 
189 | 
190 | 
191 | 
192 | #ifdef __cplusplus
193 | }
194 | #endif
195 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | from math import log10
4 | import sys
5 | sys.path.append('core')
6 | import shutil
7 | import os
8 | import torch
9 | import torch.distributed as dist
10 | import torch.nn as nn
11 | import torch.nn.parallel
12 | import torch.backends.cudnn as cudnn
13 | import torch.optim as optim
14 | from torch.autograd import Variable
15 | from torch.utils.data import DataLoader
16 | import torch.nn.functional as F
17 | import torch.multiprocessing as mp
18 | import numpy as np
19 | import cv2
20 | import time
21 | import matplotlib.pyplot as plt
22 | from sepflow import SepFlow
23 | import evaluate
24 | import datasets
25 | from torch.utils.tensorboard import SummaryWriter
26 | from utils.utils import InputPadder, forward_interpolate
27 | 
28 | try:
29 |     from torch.cuda.amp import GradScaler
30 | except:
31 |     # dummy GradScaler for PyTorch < 1.6
32 |     class GradScaler:
33 |         def __init__(self):
34 |             pass
35 |         def scale(self, loss):
36 |             return loss
37 |         def unscale_(self, optimizer):
38 |             pass
39 |         def step(self, optimizer):
40 |             optimizer.step()
41 |         def update(self):
42 |             pass
43 | 
44 | # Training settings
45 | parser = argparse.ArgumentParser(description='PyTorch SepFlow Example')
46 | parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
47 | parser.add_argument('--resume', type=str, default='', help="resume from saved model")
48 | parser.add_argument('--weights', type=str, default='', help="weights from saved model")
49 | parser.add_argument('--batchSize', type=int, default=1, help='training batch size')
50 | parser.add_argument('--testBatchSize', type=int, default=1, help='testing batch size')
51 | parser.add_argument('--nEpochs', type=int, default=2048, help='number of epochs to train for')
52 | parser.add_argument('--lr', type=float, default=0.001, help='Learning Rate. Default=0.001')
53 | parser.add_argument('--cuda', type=int, default=1, help='use cuda? Default=1')
54 | parser.add_argument('--threads', type=int, default=1, help='number of threads for data loader to use')
55 | parser.add_argument('--manual_seed', type=int, default=1234, help='random seed to use. Default=1234')
56 | parser.add_argument('--shift', type=int, default=0, help='random shift of left image. 
Default=0') 57 | parser.add_argument('--data_path', type=str, default='/export/work/feihu/flow/SceneFlow/', help="data root") 58 | parser.add_argument('--save_path', type=str, default='./checkpoints/', help="location to save models") 59 | parser.add_argument('--gpu', default='0,1,2,3,4,5,6,7', type=str, help="gpu idxs") 60 | parser.add_argument('--workers', type=int, default=16, help="workers") 61 | parser.add_argument('--world_size', type=int, default=1, help="world_size") 62 | parser.add_argument('--rank', type=int, default=0, help="rank") 63 | parser.add_argument('--dist_backend', type=str, default="nccl", help="dist_backend") 64 | parser.add_argument('--dist_url', type=str, default="tcp://127.0.0.1:6789", help="dist_url") 65 | parser.add_argument('--distributed', type=int, default=0, help="distribute") 66 | parser.add_argument('--sync_bn', type=int, default=0, help="sync bn") 67 | parser.add_argument('--multiprocessing_distributed', type=int, default=0, help="multiprocess") 68 | parser.add_argument('--freeze_bn', type=int, default=0, help="freeze bn") 69 | parser.add_argument('--start_epoch', type=int, default=0, help="start epoch") 70 | parser.add_argument('--stage', type=str, default='chairs', help="training stage: 1) things 2) chairs 3) kitti 4) mixed.") 71 | parser.add_argument('--validation', type=str, nargs='+') 72 | parser.add_argument('--num_steps', type=int, default=100000) 73 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 74 | parser.add_argument('--iters', type=int, default=12) 75 | parser.add_argument('--wdecay', type=float, default=.00005) 76 | parser.add_argument('--epsilon', type=float, default=1e-8) 77 | parser.add_argument('--clip', type=float, default=1.0) 78 | parser.add_argument('--dropout', type=float, default=0.0) 79 | parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting') 80 | parser.add_argument('--add_noise', action='store_true') 81 | parser.add_argument('--small', action='store_true', help='use small model') 82 | #parser.add_argument('--smoothl1', action='store_true', help='use smooth l1 loss') 83 | 84 | MAX_FLOW = 400 85 | SUM_FREQ = 100 86 | VAL_FREQ = 2500 87 | 88 | def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW): 89 | """ Loss function defined over sequence of flow predictions """ 90 | 91 | n_predictions = len(flow_preds) 92 | flow_loss = 0.0 93 | 94 | mag = torch.sum(flow_gt**2, dim=1).sqrt() 95 | valid = (valid >= 0.5) & (mag < max_flow) 96 | 97 | weights = [0.1, 0.3, 0.5] 98 | base = weights[2] - gamma ** (n_predictions - 3) 99 | for i in range(n_predictions - 3): 100 | weights.append( base + gamma**(n_predictions - i - 4) ) 101 | 102 | for i in range(n_predictions): 103 | i_loss = (flow_preds[i] - flow_gt).abs() 104 | flow_loss += weights[i] * (valid[:, None] * i_loss).mean() 105 | 106 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt() 107 | epe = epe.view(-1)[valid.view(-1)] 108 | loss_value = flow_loss.detach() 109 | rate0 = (epe > 1).float().mean() 110 | rate1 = (epe > 3).float().mean() 111 | error3 = epe.mean() 112 | epe = torch.sum((flow_preds[1] - flow_gt)**2, dim=1).sqrt() 113 | epe = epe.view(-1)[valid.view(-1)] 114 | error1 = epe.mean() 115 | epe = torch.sum((flow_preds[0] - flow_gt)**2, dim=1).sqrt() 116 | epe = epe.view(-1)[valid.view(-1)] 117 | error0 = epe.mean() 118 | epe = torch.sum((flow_preds[2] - flow_gt)**2, dim=1).sqrt() 119 | epe = epe.view(-1)[valid.view(-1)] 120 | error2 = epe.mean() 121 | 122 | if 
122 | if args.multiprocessing_distributed: 123 | count = flow_gt.new_tensor([1], dtype=torch.long) 124 | dist.all_reduce(loss_value); dist.all_reduce(error3); dist.all_reduce(error0); dist.all_reduce(error1); dist.all_reduce(error2); dist.all_reduce(count) 125 | dist.all_reduce(rate0); dist.all_reduce(rate1) 126 | n = count.item() 127 | loss_value, error0, error1, error2, error3 = loss_value / n, error0 / n, error1 / n, error2 / n, error3 / n 128 | rate1, rate0 = rate1 / n, rate0 / n 129 | 130 | metrics = { 131 | 'epe0': error0.item(), 132 | 'epe1': error1.item(), 133 | 'epe2': error2.item(), 134 | 'epe3': error3.item(), 135 | '1px': rate0.item(), 136 | '3px': rate1.item(), 137 | 'loss': loss_value.item() 138 | } 139 | return flow_loss, metrics 140 | 141 | class Logger: 142 | def __init__(self, model, scheduler): 143 | self.model = model 144 | self.scheduler = scheduler 145 | self.total_steps = 0 146 | self.running_loss = {} 147 | self.writer = None 148 | 149 | def _print_training_status(self): 150 | metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())] 151 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0]) 152 | metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data) 153 | 154 | # print the training status 155 | print(training_str + metrics_str) 156 | 157 | if self.writer is None: 158 | self.writer = SummaryWriter() 159 | 160 | for k in self.running_loss: 161 | self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps) 162 | self.running_loss[k] = 0.0 163 | 164 | def push(self, metrics): 165 | self.total_steps += 1 166 | 167 | for key in metrics: 168 | if key not in self.running_loss: 169 | self.running_loss[key] = 0.0 170 | 171 | self.running_loss[key] += metrics[key] 172 | 173 | if self.total_steps % SUM_FREQ == SUM_FREQ-1: 174 | self._print_training_status() 175 | self.running_loss = {} 176 | 177 | def write_dict(self, results): 178 | if self.writer is None: 179 | self.writer = SummaryWriter() 180 | 181 | for key in results: 182 | self.writer.add_scalar(key, results[key], self.total_steps) 183 | 184 | def close(self): 185 | self.writer.close() 186 | 187 | def find_free_port(): 188 | import socket 189 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 190 | # Binding to port 0 will cause the OS to find an available port for us 191 | sock.bind(("", 0)) 192 | port = sock.getsockname()[1] 193 | sock.close() 194 | # NOTE: there is still a chance the port could be taken by other processes.
195 | return port 196 | def main(): 197 | args = parser.parse_args() 198 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 199 | args.gpu = (args.gpu).split(',') 200 | torch.backends.cudnn.benchmark = True 201 | # os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.gpu.split(',')) 202 | #args.distributed = args.world_size > 1 or args.multiprocessing_distributed 203 | if args.manual_seed is not None: 204 | np.random.seed(args.manual_seed) 205 | torch.manual_seed(args.manual_seed) 206 | torch.cuda.manual_seed(args.manual_seed) 207 | torch.cuda.manual_seed_all(args.manual_seed) 208 | cudnn.benchmark = True 209 | cudnn.deterministic = True 210 | args.ngpus_per_node = len(args.gpu) 211 | if len(args.gpu) == 1: 212 | args.sync_bn = False 213 | args.distributed = False 214 | args.multiprocessing_distributed = False 215 | main_worker(args.gpu, args.ngpus_per_node, args) 216 | else: 217 | args.sync_bn = True 218 | args.distributed = True 219 | args.multiprocessing_distributed = True 220 | port = find_free_port() 221 | args.dist_url = f"tcp://127.0.0.1:{port}" 222 | #print(args) 223 | #quit() 224 | args.world_size = args.ngpus_per_node * args.world_size 225 | mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args.ngpus_per_node, args)) 226 | 227 | def fetch_optimizer(args, model): 228 | """ Create the optimizer and learning rate scheduler """ 229 | modules_ori = [model.cnet, model.fnet, model.update_block, model.guidance] 230 | modules_new = [model.cost_agg1, model.cost_agg2] 231 | params_list = [] 232 | for module in modules_ori: 233 | params_list.append(dict(params=module.parameters(), lr=args.lr)) 234 | for module in modules_new: 235 | params_list.append(dict(params=module.parameters(), lr=args.lr * 2.5)) 236 | #optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon) 237 | optimizer = optim.AdamW(params_list, lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon) 238 | 239 | 240 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100, 241 | pct_start=0.05, cycle_momentum=False, anneal_strategy='linear') 242 | 243 | return optimizer, scheduler 244 | 245 | def main_process(): 246 | return not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % args.ngpus_per_node == 0) 247 | def main_worker(gpu, ngpus_per_node, argss): 248 | global args 249 | args = argss 250 | if args.distributed: 251 | if args.dist_url == "env://" and args.rank == -1: 252 | args.rank = int(os.environ["RANK"]) 253 | if args.multiprocessing_distributed: 254 | args.rank = args.rank * ngpus_per_node + gpu 255 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank) 256 | 257 | model = SepFlow(args) 258 | optimizer, scheduler = fetch_optimizer(args, model) 259 | 260 | if args.sync_bn: 261 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 262 | if args.distributed: 263 | torch.cuda.set_device(gpu) 264 | args.batchSize = int(args.batchSize / args.ngpus_per_node) 265 | args.testBatchSize = int(args.testBatchSize / args.ngpus_per_node) 266 | args.workers = int((args.workers + args.ngpus_per_node - 1) / args.ngpus_per_node) 267 | model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[gpu]) 268 | else: 269 | model = torch.nn.DataParallel(model).cuda() 270 | 271 | #scheduler = None 272 | logger = Logger(model, scheduler) 273 | 274 | 275 | if args.weights: 276 | if os.path.isfile(args.weights):
277 | checkpoint = torch.load(args.weights, map_location=lambda storage, loc: storage.cuda()) 278 | msg = model.load_state_dict(checkpoint['state_dict'], strict=False) 279 | if main_process(): 280 | print("=> loaded checkpoint '{}'".format(args.weights)) 281 | print(msg) 282 | sys.stdout.flush() 283 | else: 284 | if main_process(): 285 | print("=> no checkpoint found at '{}'".format(args.weights)) 286 | if args.resume: 287 | if os.path.isfile(args.resume): 288 | checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda()) 289 | msg = model.load_state_dict(checkpoint['state_dict'], strict=False) 290 | optimizer.load_state_dict(checkpoint['optimizer']) 291 | scheduler.load_state_dict(checkpoint['scheduler']) 292 | 293 | args.start_epoch = checkpoint['epoch'] + 1 294 | if main_process(): 295 | print("=> resume checkpoint '{}'".format(args.resume)) 296 | print(msg) 297 | sys.stdout.flush() 298 | else: 299 | if main_process(): 300 | print("=> no checkpoint found at '{}'".format(args.resume)) 301 | 302 | train_set = datasets.fetch_dataloader(args) 303 | val_set = datasets.KITTI(split='training') 304 | val_set3 = datasets.FlyingChairs(split='validation') 305 | val_set2_2 = datasets.MpiSintel(split='training', dstype='final') 306 | val_set2_1 = datasets.MpiSintel(split='training', dstype='clean') 307 | sys.stdout.flush() 308 | if args.distributed: 309 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_set) 310 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_set) 311 | val_sampler2_2 = torch.utils.data.distributed.DistributedSampler(val_set2_2) 312 | val_sampler2_1 = torch.utils.data.distributed.DistributedSampler(val_set2_1) 313 | val_sampler3 = torch.utils.data.distributed.DistributedSampler(val_set3) 314 | else: 315 | train_sampler = None 316 | val_sampler = None 317 | val_sampler2_1 = None 318 | val_sampler2_2 = None 319 | val_sampler3 = None 320 | training_data_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batchSize, shuffle=(train_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) 321 | val_data_loader = torch.utils.data.DataLoader(val_set, batch_size=args.testBatchSize, shuffle=False, num_workers=args.workers//2, pin_memory=True, sampler=val_sampler) 322 | val_data_loader2_2 = torch.utils.data.DataLoader(val_set2_2, batch_size=args.testBatchSize, shuffle=False, num_workers=args.workers//2, pin_memory=True, sampler=val_sampler2_2) 323 | val_data_loader2_1 = torch.utils.data.DataLoader(val_set2_1, batch_size=args.testBatchSize, shuffle=False, num_workers=args.workers//2, pin_memory=True, sampler=val_sampler2_1) 324 | val_data_loader3 = torch.utils.data.DataLoader(val_set3, batch_size=args.testBatchSize, shuffle=False, num_workers=args.workers//2, pin_memory=True, sampler=val_sampler3) 325 | 326 | error = 100 327 | args.nEpochs = args.num_steps // len(training_data_loader) + 1 328 | 329 | for epoch in range(args.start_epoch, args.nEpochs): 330 | if args.distributed: 331 | train_sampler.set_epoch(epoch) 332 | 333 | train(training_data_loader, model, optimizer, scheduler, logger, epoch) 334 | if main_process() and epoch > args.nEpochs - 3: 335 | save_checkpoint(args.save_path, epoch, { 336 | 'epoch': epoch, 337 | 'state_dict': model.state_dict(), 338 | 'optimizer' : optimizer.state_dict(), 339 | 'scheduler' : scheduler.state_dict(), 340 | }, False) 341 | 342 | if args.stage == 'chairs': 343 | loss = val(val_data_loader3, model, split='chairs')
344 | elif args.stage == 'sintel' or args.stage == 'things': 345 | loss_tmp = val(val_data_loader2_1, model, split='sintel', iters=32) 346 | loss_tmp = val(val_data_loader2_2, model, split='sintel', iters=32) 347 | loss_tmp = val(val_data_loader, model, split='kitti') 348 | elif args.stage == 'kitti': 349 | loss_tmp = val(val_data_loader, model, split='kitti') 350 | 351 | if main_process(): 352 | save_checkpoint(args.save_path, args.nEpochs, { 353 | 'state_dict': model.state_dict() 354 | }, True) 355 | 356 | 357 | 358 | 359 | def train(training_data_loader, model, optimizer, scheduler, logger, epoch): 360 | valid_iteration = 0 361 | model.train() 362 | if args.freeze_bn: 363 | model.module.freeze_bn() 364 | if main_process(): 365 | print("Epoch " + str(epoch) + ": freezing bn...") 366 | sys.stdout.flush() 367 | for iteration, batch in enumerate(training_data_loader): 368 | input1, input2, target, valid = Variable(batch[0], requires_grad=True), Variable(batch[1], requires_grad=True), Variable(batch[2], requires_grad=False), Variable(batch[3], requires_grad=False) 369 | input1 = input1.cuda(non_blocking=True) 370 | input2 = input2.cuda(non_blocking=True) 371 | target = target.cuda(non_blocking=True) 372 | valid = valid.cuda(non_blocking=True) 373 | if len(valid.shape) > 3: 374 | valid = valid.squeeze(1) 375 | if valid.sum() > 0: 376 | optimizer.zero_grad() 377 | if args.add_noise: 378 | stdv = np.random.uniform(0.0, 5.0) 379 | input1 = (input1 + stdv * torch.randn(*input1.shape).cuda()).clamp(0.0, 255.0) 380 | input2 = (input2 + stdv * torch.randn(*input2.shape).cuda()).clamp(0.0, 255.0) 381 | 382 | flow_predictions = model(input1, input2, iters=args.iters) 383 | loss, metrics = sequence_loss(flow_predictions, target, valid) 384 | 385 | loss.backward() 386 | optimizer.step() 387 | scheduler.step() 388 | adjust_learning_rate(optimizer, scheduler) 389 | if scheduler.get_last_lr()[0] < 0.0000002: 390 | return 391 | 392 | 393 | valid_iteration += 1 394 | 395 | if main_process(): 396 | logger.push(metrics) 397 | # print(metrics) 398 | if valid_iteration % 10000 == 0: 399 | save_checkpoint(args.save_path, epoch, { 400 | 'epoch': epoch, 401 | 'state_dict': model.state_dict(), 402 | 'optimizer' : optimizer.state_dict(), 403 | 'scheduler' : scheduler.state_dict(), 404 | }, False) 405 | 406 | sys.stdout.flush() 407 | 408 | def val(testing_data_loader, model, split='sintel', iters=24): 409 | epoch_error = 0 410 | epoch_error_rate0 = 0 411 | epoch_error_rate1 = 0 412 | valid_iteration = 0 413 | model.eval() 414 | for iteration, batch in enumerate(testing_data_loader): 415 | input1, input2, target, valid = Variable(batch[0], requires_grad=False), Variable(batch[1], requires_grad=False), Variable(batch[2], requires_grad=False), Variable(batch[3], requires_grad=False) 416 | input1 = input1.cuda(non_blocking=True) 417 | input2 = input2.cuda(non_blocking=True) 418 | padder = InputPadder(input1.shape, mode=split) 419 | input1, input2 = padder.pad(input1, input2) 420 | target = target.cuda(non_blocking=True) 421 | valid = valid.cuda(non_blocking=True) 422 | mag = torch.sum(target**2, dim=1, keepdim=False).sqrt() 423 | if len(valid.shape) > 3: 424 | valid = valid.squeeze(1) 425 | valid = (valid >= 0.001) #& (mag < MAX_FLOW) 426 | if valid.sum() > 0: 427 | with torch.no_grad(): 428 | _, flow = model(input1, input2, iters=iters) 429 | flow = padder.unpad(flow) 430 | epe = torch.sum((flow - target)**2, dim=1).sqrt() 431 | epe = epe.view(-1)[valid.view(-1)] 432 | rate0 = (epe > 1).float().mean() 433 | if split == 'kitti':
434 | rate1 = ((epe > 3.0) & ((epe/mag.view(-1)[valid.view(-1)]) > 0.05)).float().mean() 435 | else: 436 | rate1 = (epe > 3.0).float().mean() 437 | error = epe.mean() 438 | valid_iteration += 1 439 | if args.multiprocessing_distributed: 440 | count = target.new_tensor([1], dtype=torch.long) 441 | dist.all_reduce(error) 442 | dist.all_reduce(rate0) 443 | dist.all_reduce(rate1) 444 | dist.all_reduce(count) 445 | n = count.item() 446 | error /= n 447 | rate0 /= n 448 | rate1 /= n 449 | epoch_error += error.item() 450 | epoch_error_rate0 += rate0.item() 451 | epoch_error_rate1 += rate1.item() 452 | 453 | if main_process() and (valid_iteration % 1000 == 0): 454 | print("===> Test({}/{}): Error: ({:.4f} {:.4f} {:.4f})".format(iteration, len(testing_data_loader), error.item(), rate0.item(), rate1.item())) 455 | sys.stdout.flush() 456 | 457 | if main_process(): 458 | print("===> Test: Avg. Error: ({:.4f} {:.4f} {:.4f})".format(epoch_error/valid_iteration, epoch_error_rate0/valid_iteration, epoch_error_rate1/valid_iteration)) 459 | 460 | return epoch_error/valid_iteration 461 | 462 | def save_checkpoint(save_path, epoch, state, is_best): 463 | filename = save_path + "_epoch_{}.pth".format(epoch) 464 | if is_best: 465 | filename = save_path + ".pth" 466 | torch.save(state, filename) 467 | print("Checkpoint saved to {}".format(filename)) 468 | 469 | def adjust_learning_rate(optimizer, scheduler): 470 | lr = scheduler.get_last_lr()[0] 471 | nums = len(optimizer.param_groups) 472 | for index in range(0, nums-2): 473 | optimizer.param_groups[index]['lr'] = lr 474 | for index in range(nums-2, nums): 475 | optimizer.param_groups[index]['lr'] = lr * 2.5 476 | 477 | if __name__ == '__main__': 478 | main() 479 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | mkdir -p checkpoints 2 | mkdir -p logs 3 | python train.py --batchSize=12 --testBatchSize=4 --threads=16 --stage='chairs' --freeze_bn=0 --gpu='0,1,2,3' --lr=0.0004 --save_path='checkpoints/chairs' --start_epoch=0 --image_size 320 448 --wdecay 0.0001 --gamma=0.8 --num_steps 50000 #2>&1 | tee logs/log_chairs.txt 4 | python train.py --stage='things' --weights 'checkpoints/chairs.pth' --gpu='0,1,2,3' --num_steps 100000 --batchSize 8 --testBatchSize=4 --lr 0.000125 --image_size 448 768 --wdecay 0.0001 --freeze_bn=1 --save_path='checkpoints/things' --gamma=0.8 #2>&1 | tee logs/log_things.txt 5 | python -u train.py --stage='sintel' --weights 'checkpoints/things.pth' --gpu='0,1,2,3' --num_steps 100000 --batchSize 8 --testBatchSize=4 --lr 0.000125 --image_size 384 832 --wdecay 0.00001 --freeze_bn=1 --save_path='checkpoints/sintel' --gamma=0.85 #2>&1 | tee logs/log_sintel.txt 6 | python -u train.py --stage='kitti' --weights 'checkpoints/sintel.pth' --gpu='0,1,2,3' --num_steps 50000 --batchSize 8 --testBatchSize=4 --lr 0.0001 --image_size 320 1024 --wdecay 0.00001 --freeze_bn=1 --save_path='checkpoints/kitti' --gamma=0.85 #2>&1 | tee logs/log_kitti.txt 7 | 8 | ## If you have more GPU resources, consider using a larger batch size 9 | 10 | #python train.py --batchSize=16 --testBatchSize=4 --threads=16 --resume='' --stage='chairs' --freeze_bn=0 --gpu='0,1,2,3' --lr=0.0004 --save_path='checkpoints/chairs' --start_epoch=0 --image_size 320 448 --wdecay 0.0001 --gamma=0.8 --num_steps 50000 2>&1 | tee logs/log_chairs.txt 11 |
12 | #python train.py --stage='things' --weights 'checkpoints/chairs.pth' --gpu='0,1,2,3' --num_steps 100000 --batchSize 12 --testBatchSize=4 --lr 0.000125 --image_size 448 768 --wdecay 0.0001 --freeze_bn=1 --save_path='checkpoints/things' --gamma=0.8 2>&1 | tee logs/log_things.txt 13 | 14 | #python -u train.py --stage='sintel' --weights 'checkpoints/things.pth' --gpu='0,1,2,3' --num_steps 100000 --batchSize 12 --testBatchSize=4 --lr 0.000125 --image_size 384 832 --wdecay 0.00001 --freeze_bn=1 --save_path='checkpoints/sintel' --gamma=0.85 2>&1 | tee logs/log_sintel.txt 15 | 16 | #python -u train.py --stage='kitti' --weights 'checkpoints/sintel.pth' --gpu='0,1,2,3' --num_steps 50000 --batchSize 12 --testBatchSize=4 --lr 0.0001 --image_size 320 1024 --wdecay 0.00001 --freeze_bn=1 --save_path='checkpoints/kitti' --gamma=0.85 2>&1 | tee logs/log_kitti.txt 17 | 18 | --------------------------------------------------------------------------------