├── LICENSE.txt ├── RAFT ├── .gitignore ├── LICENSE ├── RAFT.png ├── README.md ├── alt_cuda_corr │ ├── correlation.cpp │ ├── correlation_kernel.cu │ └── setup.py ├── chairs_split.txt ├── core │ ├── __init__.py │ ├── corr.py │ ├── datasets.py │ ├── extractor.py │ ├── raft.py │ ├── update.py │ └── utils │ │ ├── __init__.py │ │ ├── augmentor.py │ │ ├── flow_viz.py │ │ ├── frame_utils.py │ │ └── utils.py ├── demo-frames │ ├── frame_0016.png │ ├── frame_0017.png │ ├── frame_0018.png │ ├── frame_0019.png │ ├── frame_0020.png │ ├── frame_0021.png │ ├── frame_0022.png │ ├── frame_0023.png │ ├── frame_0024.png │ └── frame_0025.png ├── demo.py ├── download_models.sh ├── evaluate.py ├── train.py ├── train_mixed.sh ├── train_standard.sh └── weights │ └── raft-things.pth ├── README.md ├── adampiweight └── adampi_64p.pth ├── bilateral_filter.py ├── core ├── __init__.py ├── corr.py ├── datasets.py ├── extractor.py ├── raft.py ├── update.py └── utils │ ├── __init__.py │ ├── augmentor.py │ ├── flow_viz.py │ ├── frame_utils.py │ └── utils.py ├── external └── forward_warping │ ├── compile.sh │ ├── libwarping.so │ └── warping.c ├── flow_colors.py ├── gen_3dphoto_dynamic_v2.py ├── geometry.py ├── misc └── train_image_2_000000_00_1.png ├── model ├── AdaMPI.py ├── CPN │ ├── decoder.py │ ├── encoder.py │ └── unet.py └── PAN.py ├── moving_obj.py ├── scripts ├── gen_coco.sh ├── gen_test_kitti15.sh ├── gen_train_kitti15.sh └── gen_train_kitti15_v2.sh ├── utils ├── arrow.py ├── flow_viz.py ├── mpi │ ├── homography_sampler.py │ ├── mpi_rendering.py │ └── rendering_utils.py ├── transform.py ├── utils copy.py ├── utils.py └── utils_coco.py ├── vis_flow.py ├── warpback ├── networks.py ├── stage1_dataset.py ├── stage2_dataset.py └── utils.py └── write_flow.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Dmitry Ryumin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /RAFT/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg-info 3 | dist 4 | datasets 5 | pytorch_env 6 | models 7 | build 8 | correlation.egg-info 9 | -------------------------------------------------------------------------------- /RAFT/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, princeton-vl 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /RAFT/RAFT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/RAFT.png -------------------------------------------------------------------------------- /RAFT/README.md: -------------------------------------------------------------------------------- 1 | # RAFT 2 | This repository contains the source code for our paper: 3 | 4 | [RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)
5 | ECCV 2020
6 | Zachary Teed and Jia Deng
7 | 8 | 9 | 10 | ## Requirements 11 | The code has been tested with PyTorch 1.6 and CUDA 10.1. 12 | ```Shell 13 | conda create --name raft 14 | conda activate raft 15 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch 16 | ``` 17 | 18 | ## Demos 19 | Pretrained models can be downloaded by running 20 | ```Shell 21 | ./download_models.sh 22 | ``` 23 | or downloaded from [Google Drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing) 24 | 25 | You can demo a trained model on a sequence of frames 26 | ```Shell 27 | python demo.py --model=models/raft-things.pth --path=demo-frames 28 | ``` 29 | 30 | ## Required Data 31 | To evaluate/train RAFT, you will need to download the required datasets. 32 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) 33 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 34 | * [Sintel](http://sintel.is.tue.mpg.de/) 35 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) 36 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional) 37 | 38 | 39 | By default `datasets.py` will search for the datasets in these locations. You can create symbolic links in the `datasets` folder to wherever the datasets were downloaded. 40 | 41 | ```Shell 42 | ├── datasets 43 | ├── Sintel 44 | ├── test 45 | ├── training 46 | ├── KITTI 47 | ├── testing 48 | ├── training 49 | ├── devkit 50 | ├── FlyingChairs_release 51 | ├── data 52 | ├── FlyingThings3D 53 | ├── frames_cleanpass 54 | ├── frames_finalpass 55 | ├── optical_flow 56 | ``` 57 | 58 | ## Evaluation 59 | You can evaluate a trained model using `evaluate.py` 60 | ```Shell 61 | python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision 62 | ``` 63 | 64 | ## Training 65 | We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` directory and can be visualized using tensorboard 66 | ```Shell 67 | ./train_standard.sh 68 | ``` 69 | 70 | If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU) 71 | ```Shell 72 | ./train_mixed.sh 73 | ``` 74 | 75 | ## (Optional) Efficient Implementation 76 | You can optionally use our alternate (efficient) implementation by compiling the provided CUDA extension 77 | ```Shell 78 | cd alt_cuda_corr && python setup.py install && cd .. 79 | ``` 80 | and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note that this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
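The demo command above reduces to a few PyTorch calls when driven directly from Python. The following is a minimal sketch condensed from `demo.py` in this repository; the checkpoint path, frame names, and iteration count are illustrative, and it assumes you run from the `RAFT` directory with the pretrained weights from `./download_models.sh`.

```python
# Minimal RAFT inference sketch, condensed from demo.py (paths and iters are illustrative).
import sys
sys.path.append('core')  # so raft.py and utils/ resolve, as demo.py does

from argparse import Namespace

import numpy as np
import torch
from PIL import Image

from raft import RAFT
from utils.utils import InputPadder


def load_image(path, device='cuda'):
    img = np.array(Image.open(path)).astype(np.uint8)
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img[None].to(device)


# Flags mirror demo.py's argparse defaults (all store_true flags off).
args = Namespace(small=False, mixed_precision=False, alternate_corr=False)
model = torch.nn.DataParallel(RAFT(args))                 # checkpoints are saved with "module." keys
model.load_state_dict(torch.load('models/raft-things.pth'))
model = model.module.cuda().eval()

with torch.no_grad():
    image1 = load_image('demo-frames/frame_0016.png')
    image2 = load_image('demo-frames/frame_0017.png')
    padder = InputPadder(image1.shape)                    # pad so height and width are divisible by 8
    image1, image2 = padder.pad(image1, image2)
    flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
    flow = padder.unpad(flow_up[0]).permute(1, 2, 0).cpu().numpy()   # [H, W, 2] flow in pixels
```

If the `alt_cuda_corr` extension has been built, setting `alternate_corr=True` (or passing `--alternate_corr` to the scripts) makes `raft.py` use `AlternateCorrBlock` instead of the all-pairs `CorrBlock`, trading some speed for a much smaller memory footprint.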
81 | -------------------------------------------------------------------------------- /RAFT/alt_cuda_corr/correlation.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <vector> 3 | 4 | // CUDA forward declarations 5 | std::vector<torch::Tensor> corr_cuda_forward( 6 | torch::Tensor fmap1, 7 | torch::Tensor fmap2, 8 | torch::Tensor coords, 9 | int radius); 10 | 11 | std::vector<torch::Tensor> corr_cuda_backward( 12 | torch::Tensor fmap1, 13 | torch::Tensor fmap2, 14 | torch::Tensor coords, 15 | torch::Tensor corr_grad, 16 | int radius); 17 | 18 | // C++ interface 19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 22 | 23 | std::vector<torch::Tensor> corr_forward( 24 | torch::Tensor fmap1, 25 | torch::Tensor fmap2, 26 | torch::Tensor coords, 27 | int radius) { 28 | CHECK_INPUT(fmap1); 29 | CHECK_INPUT(fmap2); 30 | CHECK_INPUT(coords); 31 | 32 | return corr_cuda_forward(fmap1, fmap2, coords, radius); 33 | } 34 | 35 | 36 | std::vector<torch::Tensor> corr_backward( 37 | torch::Tensor fmap1, 38 | torch::Tensor fmap2, 39 | torch::Tensor coords, 40 | torch::Tensor corr_grad, 41 | int radius) { 42 | CHECK_INPUT(fmap1); 43 | CHECK_INPUT(fmap2); 44 | CHECK_INPUT(coords); 45 | CHECK_INPUT(corr_grad); 46 | 47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); 48 | } 49 | 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &corr_forward, "CORR forward"); 53 | m.def("backward", &corr_backward, "CORR backward"); 54 | } -------------------------------------------------------------------------------- /RAFT/alt_cuda_corr/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='correlation', 7 | ext_modules=[ 8 | CUDAExtension('alt_cuda_corr', 9 | sources=['correlation.cpp', 'correlation_kernel.cu'], 10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | 16 | -------------------------------------------------------------------------------- /RAFT/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/core/__init__.py -------------------------------------------------------------------------------- /RAFT/core/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch,
h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 38 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class AlternateCorrBlock: 64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 65 | self.num_levels = num_levels 66 | self.radius = radius 67 | 68 | self.pyramid = [(fmap1, fmap2)] 69 | for i in range(self.num_levels): 70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 72 | self.pyramid.append((fmap1, fmap2)) 73 | 74 | def __call__(self, coords): 75 | coords = coords.permute(0, 2, 3, 1) 76 | B, H, W, _ = coords.shape 77 | dim = self.pyramid[0][0].shape[1] 78 | 79 | corr_list = [] 80 | for i in range(self.num_levels): 81 | r = self.radius 82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 84 | 85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 87 | corr_list.append(corr.squeeze(1)) 88 | 89 | corr = torch.stack(corr_list, dim=1) 90 | corr = corr.reshape(B, -1, H, W) 91 | return corr / torch.sqrt(torch.tensor(dim).float()) 92 | -------------------------------------------------------------------------------- /RAFT/core/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not stride == 1: 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not stride == 1: 38 | 
self.norm3 = nn.Sequential() 39 | 40 | if stride == 1: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | 48 | def forward(self, x): 49 | y = x 50 | y = self.relu(self.norm1(self.conv1(y))) 51 | y = self.relu(self.norm2(self.conv2(y))) 52 | 53 | if self.downsample is not None: 54 | x = self.downsample(x) 55 | 56 | return self.relu(x+y) 57 | 58 | 59 | 60 | class BottleneckBlock(nn.Module): 61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 62 | super(BottleneckBlock, self).__init__() 63 | 64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | num_groups = planes // 8 70 | 71 | if norm_fn == 'group': 72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 75 | if not stride == 1: 76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 77 | 78 | elif norm_fn == 'batch': 79 | self.norm1 = nn.BatchNorm2d(planes//4) 80 | self.norm2 = nn.BatchNorm2d(planes//4) 81 | self.norm3 = nn.BatchNorm2d(planes) 82 | if not stride == 1: 83 | self.norm4 = nn.BatchNorm2d(planes) 84 | 85 | elif norm_fn == 'instance': 86 | self.norm1 = nn.InstanceNorm2d(planes//4) 87 | self.norm2 = nn.InstanceNorm2d(planes//4) 88 | self.norm3 = nn.InstanceNorm2d(planes) 89 | if not stride == 1: 90 | self.norm4 = nn.InstanceNorm2d(planes) 91 | 92 | elif norm_fn == 'none': 93 | self.norm1 = nn.Sequential() 94 | self.norm2 = nn.Sequential() 95 | self.norm3 = nn.Sequential() 96 | if not stride == 1: 97 | self.norm4 = nn.Sequential() 98 | 99 | if stride == 1: 100 | self.downsample = None 101 | 102 | else: 103 | self.downsample = nn.Sequential( 104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 105 | 106 | 107 | def forward(self, x): 108 | y = x 109 | y = self.relu(self.norm1(self.conv1(y))) 110 | y = self.relu(self.norm2(self.conv2(y))) 111 | y = self.relu(self.norm3(self.conv3(y))) 112 | 113 | if self.downsample is not None: 114 | x = self.downsample(x) 115 | 116 | return self.relu(x+y) 117 | 118 | class BasicEncoder(nn.Module): 119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 120 | super(BasicEncoder, self).__init__() 121 | self.norm_fn = norm_fn 122 | 123 | if self.norm_fn == 'group': 124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 125 | 126 | elif self.norm_fn == 'batch': 127 | self.norm1 = nn.BatchNorm2d(64) 128 | 129 | elif self.norm_fn == 'instance': 130 | self.norm1 = nn.InstanceNorm2d(64) 131 | 132 | elif self.norm_fn == 'none': 133 | self.norm1 = nn.Sequential() 134 | 135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 136 | self.relu1 = nn.ReLU(inplace=True) 137 | 138 | self.in_planes = 64 139 | self.layer1 = self._make_layer(64, stride=1) 140 | self.layer2 = self._make_layer(96, stride=2) 141 | self.layer3 = self._make_layer(128, stride=2) 142 | 143 | # output convolution 144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 145 | 146 | self.dropout = None 147 | if dropout > 0: 148 | self.dropout = nn.Dropout2d(p=dropout) 149 | 150 | for m in self.modules(): 151 | if 
isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 154 | if m.weight is not None: 155 | nn.init.constant_(m.weight, 1) 156 | if m.bias is not None: 157 | nn.init.constant_(m.bias, 0) 158 | 159 | def _make_layer(self, dim, stride=1): 160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 162 | layers = (layer1, layer2) 163 | 164 | self.in_planes = dim 165 | return nn.Sequential(*layers) 166 | 167 | 168 | def forward(self, x): 169 | 170 | # if input is list, combine batch dimension 171 | is_list = isinstance(x, tuple) or isinstance(x, list) 172 | if is_list: 173 | batch_dim = x[0].shape[0] 174 | x = torch.cat(x, dim=0) 175 | 176 | x = self.conv1(x) 177 | x = self.norm1(x) 178 | x = self.relu1(x) 179 | 180 | x = self.layer1(x) 181 | x = self.layer2(x) 182 | x = self.layer3(x) 183 | 184 | x = self.conv2(x) 185 | 186 | if self.training and self.dropout is not None: 187 | x = self.dropout(x) 188 | 189 | if is_list: 190 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 191 | 192 | return x 193 | 194 | 195 | class SmallEncoder(nn.Module): 196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 197 | super(SmallEncoder, self).__init__() 198 | self.norm_fn = norm_fn 199 | 200 | if self.norm_fn == 'group': 201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 202 | 203 | elif self.norm_fn == 'batch': 204 | self.norm1 = nn.BatchNorm2d(32) 205 | 206 | elif self.norm_fn == 'instance': 207 | self.norm1 = nn.InstanceNorm2d(32) 208 | 209 | elif self.norm_fn == 'none': 210 | self.norm1 = nn.Sequential() 211 | 212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) 213 | self.relu1 = nn.ReLU(inplace=True) 214 | 215 | self.in_planes = 32 216 | self.layer1 = self._make_layer(32, stride=1) 217 | self.layer2 = self._make_layer(64, stride=2) 218 | self.layer3 = self._make_layer(96, stride=2) 219 | 220 | self.dropout = None 221 | if dropout > 0: 222 | self.dropout = nn.Dropout2d(p=dropout) 223 | 224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 225 | 226 | for m in self.modules(): 227 | if isinstance(m, nn.Conv2d): 228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 230 | if m.weight is not None: 231 | nn.init.constant_(m.weight, 1) 232 | if m.bias is not None: 233 | nn.init.constant_(m.bias, 0) 234 | 235 | def _make_layer(self, dim, stride=1): 236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) 237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 238 | layers = (layer1, layer2) 239 | 240 | self.in_planes = dim 241 | return nn.Sequential(*layers) 242 | 243 | 244 | def forward(self, x): 245 | 246 | # if input is list, combine batch dimension 247 | is_list = isinstance(x, tuple) or isinstance(x, list) 248 | if is_list: 249 | batch_dim = x[0].shape[0] 250 | x = torch.cat(x, dim=0) 251 | 252 | x = self.conv1(x) 253 | x = self.norm1(x) 254 | x = self.relu1(x) 255 | 256 | x = self.layer1(x) 257 | x = self.layer2(x) 258 | x = self.layer3(x) 259 | x = self.conv2(x) 260 | 261 | if self.training and self.dropout is not None: 262 | x = self.dropout(x) 263 | 264 | if is_list: 265 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 266 | 267 | return x 268 | 
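Both encoders in `extractor.py` downsample by a factor of 8 overall (a stride-2 input convolution followed by two stride-2 residual stages), which is where RAFT's 1/8-resolution feature and context maps come from. The snippet below is a small shape-check sketch, not part of the repository; it assumes you run it from the `RAFT` directory so that `core/extractor.py` is importable, and the 384x512 input size is just an example matching the default training crop.

```python
# Shape sanity check for the RAFT encoders (illustrative; assumes RAFT/core is on sys.path).
import sys
sys.path.append('core')

import torch
from extractor import BasicEncoder

fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0.0)  # feature network in raft.py
cnet = BasicEncoder(output_dim=256, norm_fn='batch', dropout=0.0)     # context network (hdim+cdim = 128+128)

x1 = torch.randn(1, 3, 384, 512)
x2 = torch.randn(1, 3, 384, 512)

# Passing a list concatenates along the batch dimension and splits the result back into two tensors.
fmap1, fmap2 = fnet([x1, x2])
print(fmap1.shape)  # torch.Size([1, 256, 48, 64]): one stride-2 conv plus two stride-2 stages = 1/8 resolution

cmap = cnet(x1)
net, inp = torch.split(cmap, [128, 128], dim=1)  # GRU hidden state and context features, as in RAFT.forward
```

Working at 1/8 resolution is what keeps the all-pairs correlation volume in `corr.py` tractable: it has one entry per pair of feature-map positions, so its size grows with the square of the number of positions.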
-------------------------------------------------------------------------------- /RAFT/core/raft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from update import BasicUpdateBlock, SmallUpdateBlock 7 | from extractor import BasicEncoder, SmallEncoder 8 | from corr import CorrBlock, AlternateCorrBlock 9 | from utils.utils import bilinear_sampler, coords_grid, upflow8 10 | 11 | try: 12 | autocast = torch.cuda.amp.autocast 13 | except: 14 | # dummy autocast for PyTorch < 1.6 15 | class autocast: 16 | def __init__(self, enabled): 17 | pass 18 | def __enter__(self): 19 | pass 20 | def __exit__(self, *args): 21 | pass 22 | 23 | 24 | class RAFT(nn.Module): 25 | def __init__(self, args): 26 | super(RAFT, self).__init__() 27 | self.args = args 28 | 29 | if args.small: 30 | self.hidden_dim = hdim = 96 31 | self.context_dim = cdim = 64 32 | args.corr_levels = 4 33 | args.corr_radius = 3 34 | 35 | else: 36 | self.hidden_dim = hdim = 128 37 | self.context_dim = cdim = 128 38 | args.corr_levels = 4 39 | args.corr_radius = 4 40 | 41 | if 'dropout' not in self.args: 42 | self.args.dropout = 0 43 | 44 | if 'alternate_corr' not in self.args: 45 | self.args.alternate_corr = False 46 | 47 | # feature network, context network, and update block 48 | if args.small: 49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) 50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) 51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 52 | 53 | else: 54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) 56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 57 | 58 | def freeze_bn(self): 59 | for m in self.modules(): 60 | if isinstance(m, nn.BatchNorm2d): 61 | m.eval() 62 | 63 | def initialize_flow(self, img): 64 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 65 | N, C, H, W = img.shape 66 | coords0 = coords_grid(N, H//8, W//8, device=img.device) 67 | coords1 = coords_grid(N, H//8, W//8, device=img.device) 68 | 69 | # optical flow computed as difference: flow = coords1 - coords0 70 | return coords0, coords1 71 | 72 | def upsample_flow(self, flow, mask): 73 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 74 | N, _, H, W = flow.shape 75 | mask = mask.view(N, 1, 9, 8, 8, H, W) 76 | mask = torch.softmax(mask, dim=2) 77 | 78 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 79 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 80 | 81 | up_flow = torch.sum(mask * up_flow, dim=2) 82 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 83 | return up_flow.reshape(N, 2, 8*H, 8*W) 84 | 85 | 86 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): 87 | """ Estimate optical flow between pair of frames """ 88 | 89 | image1 = 2 * (image1 / 255.0) - 1.0 90 | image2 = 2 * (image2 / 255.0) - 1.0 91 | 92 | image1 = image1.contiguous() 93 | image2 = image2.contiguous() 94 | 95 | hdim = self.hidden_dim 96 | cdim = self.context_dim 97 | 98 | # run the feature network 99 | with autocast(enabled=self.args.mixed_precision): 100 | fmap1, fmap2 = self.fnet([image1, image2]) 101 | 102 | fmap1 = fmap1.float() 103 | fmap2 = fmap2.float() 104 | if self.args.alternate_corr: 
105 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 106 | else: 107 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 108 | 109 | # run the context network 110 | with autocast(enabled=self.args.mixed_precision): 111 | cnet = self.cnet(image1) 112 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 113 | net = torch.tanh(net) 114 | inp = torch.relu(inp) 115 | 116 | coords0, coords1 = self.initialize_flow(image1) 117 | 118 | if flow_init is not None: 119 | coords1 = coords1 + flow_init 120 | 121 | flow_predictions = [] 122 | for itr in range(iters): 123 | coords1 = coords1.detach() 124 | corr = corr_fn(coords1) # index correlation volume 125 | 126 | flow = coords1 - coords0 127 | with autocast(enabled=self.args.mixed_precision): 128 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 129 | 130 | # F(t+1) = F(t) + \Delta(t) 131 | coords1 = coords1 + delta_flow 132 | 133 | # upsample predictions 134 | if up_mask is None: 135 | flow_up = upflow8(coords1 - coords0) 136 | else: 137 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 138 | 139 | flow_predictions.append(flow_up) 140 | 141 | if test_mode: 142 | return coords1 - coords0, flow_up 143 | 144 | return flow_predictions 145 | -------------------------------------------------------------------------------- /RAFT/core/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = 
torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self, args): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self, args): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 88 | 89 | def forward(self, flow, corr): 90 | cor = F.relu(self.convc1(corr)) 91 | cor = F.relu(self.convc2(cor)) 92 | flo = F.relu(self.convf1(flow)) 93 | flo = F.relu(self.convf2(flo)) 94 | 95 | cor_flo = torch.cat([cor, flo], dim=1) 96 | out = F.relu(self.conv(cor_flo)) 97 | return torch.cat([out, flow], dim=1) 98 | 99 | class SmallUpdateBlock(nn.Module): 100 | def __init__(self, args, hidden_dim=96): 101 | super(SmallUpdateBlock, self).__init__() 102 | self.encoder = SmallMotionEncoder(args) 103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | inp = torch.cat([inp, motion_features], dim=1) 109 | net = self.gru(net, inp) 110 | delta_flow = self.flow_head(net) 111 | 112 | return net, None, delta_flow 113 | 114 | class BasicUpdateBlock(nn.Module): 115 | def __init__(self, args, hidden_dim=128, input_dim=128): 116 | super(BasicUpdateBlock, self).__init__() 117 | self.args = args 118 | self.encoder = BasicMotionEncoder(args) 119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 121 | 122 | self.mask = nn.Sequential( 123 | nn.Conv2d(128, 256, 3, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(256, 64*9, 1, padding=0)) 126 | 127 | def forward(self, net, inp, corr, flow, upsample=True): 128 | motion_features = self.encoder(flow, corr) 129 | inp = torch.cat([inp, motion_features], dim=1) 130 | 131 | net = self.gru(net, inp) 132 | delta_flow = self.flow_head(net) 133 | 134 | # scale mask to balence gradients 135 | mask = .25 * self.mask(net) 136 | return net, mask, delta_flow 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /RAFT/core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/core/utils/__init__.py -------------------------------------------------------------------------------- /RAFT/core/utils/flow_viz.py: 
-------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /RAFT/core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. 
Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /RAFT/core/utils/utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd, device): 75 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0016.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0017.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0018.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0018.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0019.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0020.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0020.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0021.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0021.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0022.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0022.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0023.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0023.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0024.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0024.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0025.png -------------------------------------------------------------------------------- /RAFT/demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | import argparse 5 | import os 6 | import cv2 7 | import glob 8 | import numpy as np 9 | import torch 10 | from PIL import Image 11 | 12 | from raft import RAFT 13 | from utils import flow_viz 14 | from utils.utils import InputPadder 15 | 16 | 17 | 18 | DEVICE = 'cuda' 19 | 20 | def load_image(imfile): 21 | img = np.array(Image.open(imfile)).astype(np.uint8) 22 | img = torch.from_numpy(img).permute(2, 0, 1).float() 23 | return img[None].to(DEVICE) 24 | 25 | 26 | def viz(img, flo): 27 | img = img[0].permute(1,2,0).cpu().numpy() 28 | flo = flo[0].permute(1,2,0).cpu().numpy() 29 | 30 | # map flow to rgb image 31 | flo = flow_viz.flow_to_image(flo) 32 | img_flo = np.concatenate([img, flo], axis=0) 33 | 34 | # import matplotlib.pyplot as plt 35 | # plt.imshow(img_flo / 255.0) 36 | # plt.show() 37 | 38 | cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) 39 | cv2.waitKey() 40 | 41 | 42 | def demo(args): 43 | model = 
torch.nn.DataParallel(RAFT(args)) 44 | model.load_state_dict(torch.load(args.model)) 45 | 46 | model = model.module 47 | model.to(DEVICE) 48 | model.eval() 49 | 50 | with torch.no_grad(): 51 | images = glob.glob(os.path.join(args.path, '*.png')) + \ 52 | glob.glob(os.path.join(args.path, '*.jpg')) 53 | 54 | images = sorted(images) 55 | for imfile1, imfile2 in zip(images[:-1], images[1:]): 56 | image1 = load_image(imfile1) 57 | image2 = load_image(imfile2) 58 | 59 | padder = InputPadder(image1.shape) 60 | image1, image2 = padder.pad(image1, image2) 61 | 62 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 63 | viz(image1, flow_up) 64 | 65 | 66 | if __name__ == '__main__': 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('--model', help="restore checkpoint") 69 | parser.add_argument('--path', help="dataset for evaluation") 70 | parser.add_argument('--small', action='store_true', help='use small model') 71 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 72 | parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') 73 | args = parser.parse_args() 74 | 75 | demo(args) 76 | -------------------------------------------------------------------------------- /RAFT/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget https://dl.dropboxusercontent.com/s/4j4z58wuv8o0mfz/models.zip 3 | unzip models.zip 4 | -------------------------------------------------------------------------------- /RAFT/evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | from PIL import Image 5 | import argparse 6 | import os 7 | import time 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | import matplotlib.pyplot as plt 12 | 13 | import datasets 14 | from utils import flow_viz 15 | from utils import frame_utils 16 | 17 | from raft import RAFT 18 | from utils.utils import InputPadder, forward_interpolate 19 | 20 | 21 | @torch.no_grad() 22 | def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'): 23 | """ Create submission for the Sintel leaderboard """ 24 | model.eval() 25 | for dstype in ['clean', 'final']: 26 | test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype) 27 | 28 | flow_prev, sequence_prev = None, None 29 | for test_id in range(len(test_dataset)): 30 | image1, image2, (sequence, frame) = test_dataset[test_id] 31 | if sequence != sequence_prev: 32 | flow_prev = None 33 | 34 | padder = InputPadder(image1.shape) 35 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 36 | 37 | flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True) 38 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 39 | 40 | if warm_start: 41 | flow_prev = forward_interpolate(flow_low[0])[None].cuda() 42 | 43 | output_dir = os.path.join(output_path, dstype, sequence) 44 | output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1)) 45 | 46 | if not os.path.exists(output_dir): 47 | os.makedirs(output_dir) 48 | 49 | frame_utils.writeFlow(output_file, flow) 50 | sequence_prev = sequence 51 | 52 | 53 | @torch.no_grad() 54 | def create_kitti_submission(model, iters=24, output_path='kitti_submission'): 55 | """ Create submission for the Sintel leaderboard """ 56 | model.eval() 57 | test_dataset 
= datasets.KITTI(split='testing', aug_params=None) 58 | 59 | if not os.path.exists(output_path): 60 | os.makedirs(output_path) 61 | 62 | for test_id in range(len(test_dataset)): 63 | image1, image2, (frame_id, ) = test_dataset[test_id] 64 | padder = InputPadder(image1.shape, mode='kitti') 65 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 66 | 67 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True) 68 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 69 | 70 | output_filename = os.path.join(output_path, frame_id) 71 | frame_utils.writeFlowKITTI(output_filename, flow) 72 | 73 | 74 | @torch.no_grad() 75 | def validate_chairs(model, iters=24): 76 | """ Perform evaluation on the FlyingChairs (test) split """ 77 | model.eval() 78 | epe_list = [] 79 | 80 | val_dataset = datasets.FlyingChairs(split='validation') 81 | for val_id in range(len(val_dataset)): 82 | image1, image2, flow_gt, _ = val_dataset[val_id] 83 | image1 = image1[None].cuda() 84 | image2 = image2[None].cuda() 85 | 86 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True) 87 | epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt() 88 | epe_list.append(epe.view(-1).numpy()) 89 | 90 | epe = np.mean(np.concatenate(epe_list)) 91 | print("Validation Chairs EPE: %f" % epe) 92 | return {'chairs': epe} 93 | 94 | 95 | @torch.no_grad() 96 | def validate_sintel(model, iters=32): 97 | """ Peform validation using the Sintel (train) split """ 98 | model.eval() 99 | results = {} 100 | for dstype in ['clean', 'final']: 101 | val_dataset = datasets.MpiSintel(split='training', dstype=dstype) 102 | epe_list = [] 103 | 104 | for val_id in range(len(val_dataset)): 105 | image1, image2, flow_gt, _ = val_dataset[val_id] 106 | image1 = image1[None].cuda() 107 | image2 = image2[None].cuda() 108 | 109 | padder = InputPadder(image1.shape) 110 | image1, image2 = padder.pad(image1, image2) 111 | 112 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True) 113 | flow = padder.unpad(flow_pr[0]).cpu() 114 | 115 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 116 | epe_list.append(epe.view(-1).numpy()) 117 | 118 | epe_all = np.concatenate(epe_list) 119 | epe = np.mean(epe_all) 120 | px1 = np.mean(epe_all<1) 121 | px3 = np.mean(epe_all<3) 122 | px5 = np.mean(epe_all<5) 123 | 124 | print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5)) 125 | results[dstype] = np.mean(epe_list) 126 | 127 | return results 128 | 129 | 130 | @torch.no_grad() 131 | def validate_kitti(model, iters=24): 132 | """ Peform validation using the KITTI-2015 (train) split """ 133 | model.eval() 134 | val_dataset = datasets.KITTI(split='training') 135 | 136 | out_list, epe_list = [], [] 137 | for val_id in range(len(val_dataset)): 138 | image1, image2, flow_gt, valid_gt = val_dataset[val_id] 139 | image1 = image1[None].cuda() 140 | image2 = image2[None].cuda() 141 | 142 | padder = InputPadder(image1.shape, mode='kitti') 143 | image1, image2 = padder.pad(image1, image2) 144 | 145 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True) 146 | flow = padder.unpad(flow_pr[0]).cpu() 147 | 148 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 149 | mag = torch.sum(flow_gt**2, dim=0).sqrt() 150 | 151 | epe = epe.view(-1) 152 | mag = mag.view(-1) 153 | val = valid_gt.view(-1) >= 0.5 154 | 155 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float() 156 | epe_list.append(epe[val].mean().item()) 157 | out_list.append(out[val].cpu().numpy()) 158 | 159 | epe_list = 
np.array(epe_list) 160 | out_list = np.concatenate(out_list) 161 | 162 | epe = np.mean(epe_list) 163 | f1 = 100 * np.mean(out_list) 164 | 165 | print("Validation KITTI: %f, %f" % (epe, f1)) 166 | return {'kitti-epe': epe, 'kitti-f1': f1} 167 | 168 | 169 | if __name__ == '__main__': 170 | parser = argparse.ArgumentParser() 171 | parser.add_argument('--model', help="restore checkpoint") 172 | parser.add_argument('--dataset', help="dataset for evaluation") 173 | parser.add_argument('--small', action='store_true', help='use small model') 174 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 175 | parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') 176 | args = parser.parse_args() 177 | 178 | model = torch.nn.DataParallel(RAFT(args)) 179 | model.load_state_dict(torch.load(args.model)) 180 | 181 | model.cuda() 182 | model.eval() 183 | 184 | # create_sintel_submission(model.module, warm_start=True) 185 | # create_kitti_submission(model.module) 186 | 187 | with torch.no_grad(): 188 | if args.dataset == 'chairs': 189 | validate_chairs(model.module) 190 | 191 | elif args.dataset == 'sintel': 192 | validate_sintel(model.module) 193 | 194 | elif args.dataset == 'kitti': 195 | validate_kitti(model.module) 196 | 197 | 198 | -------------------------------------------------------------------------------- /RAFT/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import sys 3 | sys.path.append('core') 4 | 5 | import argparse 6 | import os 7 | import cv2 8 | import time 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | 17 | from torch.utils.data import DataLoader 18 | from raft import RAFT 19 | import evaluate 20 | import datasets 21 | 22 | from torch.utils.tensorboard import SummaryWriter 23 | 24 | try: 25 | from torch.cuda.amp import GradScaler 26 | except: 27 | # dummy GradScaler for PyTorch < 1.6 28 | class GradScaler: 29 | def __init__(self): 30 | pass 31 | def scale(self, loss): 32 | return loss 33 | def unscale_(self, optimizer): 34 | pass 35 | def step(self, optimizer): 36 | optimizer.step() 37 | def update(self): 38 | pass 39 | 40 | 41 | # exclude extremly large displacements 42 | MAX_FLOW = 400 43 | SUM_FREQ = 100 44 | VAL_FREQ = 5000 45 | 46 | 47 | def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW): 48 | """ Loss function defined over sequence of flow predictions """ 49 | 50 | n_predictions = len(flow_preds) 51 | flow_loss = 0.0 52 | 53 | # exlude invalid pixels and extremely large diplacements 54 | mag = torch.sum(flow_gt**2, dim=1).sqrt() 55 | valid = (valid >= 0.5) & (mag < max_flow) 56 | 57 | for i in range(n_predictions): 58 | i_weight = gamma**(n_predictions - i - 1) 59 | i_loss = (flow_preds[i] - flow_gt).abs() 60 | flow_loss += i_weight * (valid[:, None] * i_loss).mean() 61 | 62 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt() 63 | epe = epe.view(-1)[valid.view(-1)] 64 | 65 | metrics = { 66 | 'epe': epe.mean().item(), 67 | '1px': (epe < 1).float().mean().item(), 68 | '3px': (epe < 3).float().mean().item(), 69 | '5px': (epe < 5).float().mean().item(), 70 | } 71 | 72 | return flow_loss, metrics 73 | 74 | 75 | def count_parameters(model): 76 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 77 | 78 | 79 | def 
fetch_optimizer(args, model): 80 | """ Create the optimizer and learning rate scheduler """ 81 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon) 82 | 83 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100, 84 | pct_start=0.05, cycle_momentum=False, anneal_strategy='linear') 85 | 86 | return optimizer, scheduler 87 | 88 | 89 | class Logger: 90 | def __init__(self, model, scheduler): 91 | self.model = model 92 | self.scheduler = scheduler 93 | self.total_steps = 0 94 | self.running_loss = {} 95 | self.writer = None 96 | 97 | def _print_training_status(self): 98 | metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())] 99 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0]) 100 | metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data) 101 | 102 | # print the training status 103 | print(training_str + metrics_str) 104 | 105 | if self.writer is None: 106 | self.writer = SummaryWriter() 107 | 108 | for k in self.running_loss: 109 | self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps) 110 | self.running_loss[k] = 0.0 111 | 112 | def push(self, metrics): 113 | self.total_steps += 1 114 | 115 | for key in metrics: 116 | if key not in self.running_loss: 117 | self.running_loss[key] = 0.0 118 | 119 | self.running_loss[key] += metrics[key] 120 | 121 | if self.total_steps % SUM_FREQ == SUM_FREQ-1: 122 | self._print_training_status() 123 | self.running_loss = {} 124 | 125 | def write_dict(self, results): 126 | if self.writer is None: 127 | self.writer = SummaryWriter() 128 | 129 | for key in results: 130 | self.writer.add_scalar(key, results[key], self.total_steps) 131 | 132 | def close(self): 133 | self.writer.close() 134 | 135 | 136 | def train(args): 137 | 138 | model = nn.DataParallel(RAFT(args), device_ids=args.gpus) 139 | print("Parameter Count: %d" % count_parameters(model)) 140 | 141 | if args.restore_ckpt is not None: 142 | model.load_state_dict(torch.load(args.restore_ckpt), strict=False) 143 | 144 | model.cuda() 145 | model.train() 146 | 147 | if args.stage != 'chairs': 148 | model.module.freeze_bn() 149 | 150 | train_loader = datasets.fetch_dataloader(args) 151 | optimizer, scheduler = fetch_optimizer(args, model) 152 | 153 | total_steps = 0 154 | scaler = GradScaler(enabled=args.mixed_precision) 155 | logger = Logger(model, scheduler) 156 | 157 | VAL_FREQ = 5000 158 | add_noise = True 159 | 160 | should_keep_training = True 161 | while should_keep_training: 162 | 163 | for i_batch, data_blob in enumerate(train_loader): 164 | optimizer.zero_grad() 165 | image1, image2, flow, valid = [x.cuda() for x in data_blob] 166 | 167 | if args.add_noise: 168 | stdv = np.random.uniform(0.0, 5.0) 169 | image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0) 170 | image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0) 171 | 172 | flow_predictions = model(image1, image2, iters=args.iters) 173 | 174 | loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma) 175 | scaler.scale(loss).backward() 176 | scaler.unscale_(optimizer) 177 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 178 | 179 | scaler.step(optimizer) 180 | scheduler.step() 181 | scaler.update() 182 | 183 | logger.push(metrics) 184 | 185 | if total_steps % VAL_FREQ == VAL_FREQ - 1: 186 | PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name) 187 | 
torch.save(model.state_dict(), PATH) 188 | 189 | results = {} 190 | for val_dataset in args.validation: 191 | if val_dataset == 'chairs': 192 | results.update(evaluate.validate_chairs(model.module)) 193 | elif val_dataset == 'sintel': 194 | results.update(evaluate.validate_sintel(model.module)) 195 | elif val_dataset == 'kitti': 196 | results.update(evaluate.validate_kitti(model.module)) 197 | 198 | logger.write_dict(results) 199 | 200 | model.train() 201 | if args.stage != 'chairs': 202 | model.module.freeze_bn() 203 | 204 | total_steps += 1 205 | 206 | if total_steps > args.num_steps: 207 | should_keep_training = False 208 | break 209 | 210 | logger.close() 211 | PATH = 'checkpoints/%s.pth' % args.name 212 | torch.save(model.state_dict(), PATH) 213 | 214 | return PATH 215 | 216 | 217 | if __name__ == '__main__': 218 | parser = argparse.ArgumentParser() 219 | parser.add_argument('--name', default='raft', help="name your experiment") 220 | parser.add_argument('--stage', help="determines which dataset to use for training") 221 | parser.add_argument('--restore_ckpt', help="restore checkpoint") 222 | parser.add_argument('--data_root', type=str, help="restore checkpoint") 223 | parser.add_argument('--small', action='store_true', help='use small model') 224 | parser.add_argument('--validation', type=str, nargs='+') 225 | 226 | parser.add_argument('--lr', type=float, default=0.00002) 227 | parser.add_argument('--num_steps', type=int, default=100000) 228 | parser.add_argument('--batch_size', type=int, default=6) 229 | parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512]) 230 | parser.add_argument('--gpus', type=int, nargs='+', default=[0,1]) 231 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 232 | 233 | parser.add_argument('--iters', type=int, default=12) 234 | parser.add_argument('--wdecay', type=float, default=.00005) 235 | parser.add_argument('--epsilon', type=float, default=1e-8) 236 | parser.add_argument('--clip', type=float, default=1.0) 237 | parser.add_argument('--dropout', type=float, default=0.0) 238 | parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting') 239 | parser.add_argument('--add_noise', action='store_true') 240 | args = parser.parse_args() 241 | 242 | torch.manual_seed(1234) 243 | np.random.seed(1234) 244 | 245 | if not os.path.isdir('checkpoints'): 246 | os.mkdir('checkpoints') 247 | 248 | train(args) -------------------------------------------------------------------------------- /RAFT/train_mixed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p checkpoints 3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision 4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision 5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision 6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 
--gamma=0.85 --mixed_precision 7 | -------------------------------------------------------------------------------- /RAFT/train_standard.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CUDA_VISIBLE_DEVICES=3,5 python -u train.py --name mpi-0.1-0.25-rf \ 3 | --stage mpi-flow --validation kitti \ 4 | --restore_ckpt weights/raft-things.pth \ 5 | --gpus 0 1 --num_steps 50000 --batch_size 6 \ 6 | --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 \ 7 | --data_root /data1/liangyingping/MPI-Flow/dataset/debug -------------------------------------------------------------------------------- /RAFT/weights/raft-things.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/weights/raft-things.pth -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ICCV 2023] MPI-Flow: Learning Realistic Optical Flow with Multiplane Images 2 | 3 | [Paper](https://arxiv.org/abs/2309.06714) | [Checkpoints](https://drive.google.com/drive/folders/1q0UxlswSwZjLgLkEjUNmBuVi0LJfY_b7?usp=sharing) | [Project Page](https://sites.google.com/view/mpi-flow) | [My Home Page](https://sharpiless.github.io/) 4 | 5 | ## Update 6 | - **2024.05.01** - Update large-scale dataset generation [scripts](scripts). 7 | - **2023.12.18** - Code for online training released at [Sharpiless/Train-RAFT-from-single-view-images](https://github.com/Sharpiless/Train-RAFT-from-single-view-images). 8 | - **2023.09.13** - Code released. 9 | 10 | # MPI-Flow 11 | 12 | This is a PyTorch implementation of our paper. 13 | 14 | **Abstract**: *The accuracy of learning-based optical flow estimation models heavily relies on the realism of the training datasets. Current approaches for generating such datasets either employ synthetic data or generate images with limited realism. However, the domain gap of these data with real-world scenes constrains the generalization of the trained model to real-world applications. To address this issue, we investigate generating realistic optical flow datasets from real-world images. Firstly, to generate highly realistic new images, we construct a layered depth representation, known as multiplane images (MPI), from single-view images. This allows us to generate novel view images that are highly realistic. To generate optical flow maps that correspond accurately to the new image, we calculate the optical flows of each plane using the camera matrix and plane depths. We then project these layered optical flows into the output optical flow map with volume rendering. Secondly, to ensure the realism of motion, we present an independent object motion module that can separate the camera and dynamic object motion in MPI. This module addresses the deficiency in MPI-based single-view methods, where optical flow is generated only by camera motion and does not account for any object movement. We additionally devise a depth-aware inpainting module to merge new images with dynamic objects and address unnatural motion occlusions. We show the superior performance of our method through extensive experiments on real-world datasets. 
Moreover, our approach achieves state-of-the-art performance in both unsupervised and supervised training of learning-based models.* 15 | 16 | # Document for *MPI-Flow* 17 | ## Environment 18 | ``` 19 | conda create -n mpiflow python=3.8 20 | 21 | # here we use pytorch 1.11.0 and CUDA 11.3 as an example 22 | 23 | # install pytorch 24 | pip install https://download.pytorch.org/whl/cu113/torch-1.11.0%2Bcu113-cp38-cp38-linux_x86_64.whl 25 | 26 | # install torchvision 27 | pip install https://download.pytorch.org/whl/cu113/torchvision-0.12.0%2Bcu113-cp38-cp38-linux_x86_64.whl 28 | 29 | # install pytorch3d 30 | conda install https://anaconda.org/pytorch3d/pytorch3d/0.6.2/download/linux-64/pytorch3d-0.6.2-py38_cu113_pyt1100.tar.bz2 31 | 32 | # install other libs 33 | pip install \ 34 | numpy==1.19 \ 35 | scikit-image==0.19.1 \ 36 | scipy==1.8.0 \ 37 | pillow==9.0.1 \ 38 | opencv-python==4.4.0.40 \ 39 | tqdm==4.64.0 \ 40 | moviepy==1.0.3 \ 41 | pyyaml \ 42 | matplotlib \ 43 | scikit-learn \ 44 | lpips \ 45 | kornia \ 46 | focal_frequency_loss \ 47 | tensorboard \ 48 | transformers 49 | 50 | cd external/forward_warping 51 | bash compile.sh 52 | cd ../.. 53 | ``` 54 | 55 | ## Usage 56 | 57 | The input to our MPI-Flow is a single in-the-wild image with its monocular depth estimation and main object mask. 58 | You can use the [MiDaS](https://github.com/isl-org/MiDaS) model to obtain the estimated depth map and use the [Mask2Former](https://github.com/facebookresearch/Mask2Former) to obtain the object mask. 59 | 60 | We provide some example inputs in `./images_kitti`; you can use the images, depths, and masks there to test our model. 61 | Here is an example of how to run the code: 62 | 63 | ``` 64 | python gen_3dphoto_dynamic.py 65 | ``` 66 | 67 | Then, you will see a result like this: 68 | 69 | 70 | ## Training online 71 | 72 | We have also released an online training version at [https://github.com/Sharpiless/Train-RAFT-from-single-view-images](https://github.com/Sharpiless/Train-RAFT-from-single-view-images).
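The performance numbers in the next section come from evaluating RAFT models trained on the generated data, using the evaluation code bundled under `RAFT/`. As a minimal sketch of such an evaluation run (the checkpoint path `checkpoints/raft-mpi.pth` is a placeholder, and a KITTI dataset laid out where `RAFT/core/datasets.py` expects it is assumed):

```
# Minimal evaluation sketch; run from the RAFT/ directory.
# The checkpoint path below is a placeholder, not a file shipped with this repo.
import sys
sys.path.append('core')

import argparse
import torch
from raft import RAFT
import evaluate

args = argparse.Namespace(small=False, mixed_precision=True, alternate_corr=False)
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load('checkpoints/raft-mpi.pth'))  # placeholder checkpoint
model.cuda()
model.eval()

with torch.no_grad():
    results = evaluate.validate_kitti(model.module)  # {'kitti-epe': ..., 'kitti-f1': ...}
    print(results)
```

This mirrors the `__main__` block of `RAFT/evaluate.py`; the command-line equivalent is `python evaluate.py --model=checkpoints/raft-mpi.pth --dataset=kitti --mixed_precision`.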
73 | 74 | ## Performance (Online Training, single V100 GPU) 75 | 32k steps on COCO: 76 | | Dataset | EPE | F1 | 77 | | :-------: | :--------: | :-----: | 78 | | KITTI-15 (train) | 3.537468 | 11.694042 | 79 | | Sintel.C | 1.857986 | - | 80 | | Sintel.F | 3.250774 | - | 81 | 82 | 320k steps on COCO: 83 | | Dataset | EPE | F1 | 84 | | :-------: | :--------: | :-----: | 85 | | KITTI-15 (train) | 3.586417 | 9.887916 | 86 | | Sintel.C | - | - | 87 | | Sintel.F | - | - | 88 | 89 | ## Checkpoints 90 | 91 | | Image Source | Method | KITTI 12 | | KITTI 15 | | 92 | |--------------|--------------------|----------|-----|----------|-----| 93 | | | | EPE ↓ | F1 ↓| EPE ↓ | F1 ↓| 94 | | COCO | Depthstillation [1]| 1.74 | 6.81| 3.45 | 13.08| 95 | | | RealFlow [12] | N/A | N/A | N/A | N/A | 96 | | | MPI-Flow (ours) | 1.36 | 4.91| 3.44 | 10.66| 97 | | DAVIS | Depthstillation [1]| 1.81 | 6.89| 3.79 | 13.22| 98 | | | RealFlow [12] | 1.59 | 6.08| 3.55 | 12.52| 99 | | | MPI-Flow (ours) | 1.41 | 5.36| 3.32 | 10.47| 100 | | KITTI 15 Test| Depthstillation [1]| 1.77 | 5.97| 3.99 | 13.34| 101 | | | RealFlow [12] | 1.27 | 5.16| 2.43 | 8.86 | 102 | | | MPI-Flow (ours) | 1.24 | 4.51| 2.16 | 7.30 | 103 | | KITTI 15 Train| Depthstillation [1]| 1.67 | 5.71| {2.99} | {9.94}| 104 | | | RealFlow [12] | 1.25 | 5.02| {2.17} | {8.64}| 105 | | | MPI-Flow (ours) | 1.26 | 4.66| {1.88} | {7.16}| 106 | 107 | 108 | Checkpoints to reproduce our results in Table 1 can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1q0UxlswSwZjLgLkEjUNmBuVi0LJfY_b7?usp=sharing). 109 | 110 | You can use the code in [RAFT](https://github.com/princeton-vl/RAFT) to evaluate/train the models; a minimal inference sketch is given at the end of this README. 111 | 112 | ## Contact 113 | If you have any questions, please contact Yingping Liang (liangyingping@bit.edu.cn). 114 | 115 | ## License and Citation 116 | This repository can only be used for personal/research/non-commercial purposes. 117 | Please cite the following paper if this model helps your research: 118 | 119 | @inproceedings{liang2023mpi, 120 | author = {Liang, Yingping and Liu, Jiaming and Zhang, Debing and Fu, Ying}, 121 | title = {MPI-Flow: Learning Realistic Optical Flow with Multiplane Images}, 122 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 123 | year={2023} 124 | } 125 | 126 | ## Acknowledgments 127 | * The code is heavily borrowed from [AdaMPI](https://github.com/yxuhan/AdaMPI); we thank the authors for their great effort.
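As referenced in the Checkpoints section above, here is a minimal inference sketch for a downloaded checkpoint, using only the RAFT code shipped in this repository (run from `RAFT/`; the checkpoint filename is a placeholder, the input frames are the bundled demo frames):

```
# Minimal inference sketch: estimate flow between two frames and save a color map.
import sys
sys.path.append('core')

import argparse
import cv2
import numpy as np
import torch
from raft import RAFT
from utils.utils import InputPadder
from utils import flow_viz

args = argparse.Namespace(small=False, mixed_precision=False, alternate_corr=False)
net = torch.nn.DataParallel(RAFT(args))
net.load_state_dict(torch.load('checkpoints/mpi-flow.pth'))  # placeholder checkpoint name
net = net.module.cuda().eval()

def load_image(path):
    img = cv2.imread(path)[..., ::-1].astype(np.float32)        # BGR -> RGB, keep 0-255 range
    return torch.from_numpy(img).permute(2, 0, 1)[None].cuda()  # 1 x 3 x H x W

image1 = load_image('demo-frames/frame_0016.png')
image2 = load_image('demo-frames/frame_0017.png')
padder = InputPadder(image1.shape)          # pad so H and W are divisible by 8
image1, image2 = padder.pad(image1, image2)

with torch.no_grad():
    _, flow_up = net(image1, image2, iters=20, test_mode=True)

flow = padder.unpad(flow_up)[0].permute(1, 2, 0).cpu().numpy()
flow_rgb = flow_viz.flow_to_image(flow)
cv2.imwrite('flow.png', np.ascontiguousarray(flow_rgb[:, :, ::-1]))  # RGB -> BGR for OpenCV
```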
128 | -------------------------------------------------------------------------------- /adampiweight/adampi_64p.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/adampiweight/adampi_64p.pth -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/core/__init__.py -------------------------------------------------------------------------------- /core/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 38 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class AlternateCorrBlock: 64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 65 | self.num_levels = num_levels 66 | self.radius = radius 67 | 68 | self.pyramid = [(fmap1, fmap2)] 69 | for i in range(self.num_levels): 70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 72 | self.pyramid.append((fmap1, fmap2)) 73 | 74 | def __call__(self, coords): 75 | coords = coords.permute(0, 2, 3, 1) 76 | B, H, W, _ = coords.shape 77 | dim = self.pyramid[0][0].shape[1] 78 | 79 | corr_list = [] 80 | for i in range(self.num_levels): 81 | r = self.radius 82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 84 | 85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 87 | 
corr_list.append(corr.squeeze(1)) 88 | 89 | corr = torch.stack(corr_list, dim=1) 90 | corr = corr.reshape(B, -1, H, W) 91 | return corr / torch.sqrt(torch.tensor(dim).float()) 92 | -------------------------------------------------------------------------------- /core/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not stride == 1: 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not stride == 1: 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | 48 | def forward(self, x): 49 | y = x 50 | y = self.relu(self.norm1(self.conv1(y))) 51 | y = self.relu(self.norm2(self.conv2(y))) 52 | 53 | if self.downsample is not None: 54 | x = self.downsample(x) 55 | 56 | return self.relu(x+y) 57 | 58 | 59 | 60 | class BottleneckBlock(nn.Module): 61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 62 | super(BottleneckBlock, self).__init__() 63 | 64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | num_groups = planes // 8 70 | 71 | if norm_fn == 'group': 72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 75 | if not stride == 1: 76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 77 | 78 | elif norm_fn == 'batch': 79 | self.norm1 = nn.BatchNorm2d(planes//4) 80 | self.norm2 = nn.BatchNorm2d(planes//4) 81 | self.norm3 = nn.BatchNorm2d(planes) 82 | if not stride == 1: 83 | self.norm4 = nn.BatchNorm2d(planes) 84 | 85 | elif norm_fn == 'instance': 86 | self.norm1 = nn.InstanceNorm2d(planes//4) 87 | self.norm2 = nn.InstanceNorm2d(planes//4) 88 | self.norm3 = nn.InstanceNorm2d(planes) 89 | if not stride == 1: 90 | self.norm4 = nn.InstanceNorm2d(planes) 91 | 92 | elif norm_fn == 'none': 93 | self.norm1 = nn.Sequential() 94 | self.norm2 = nn.Sequential() 95 | self.norm3 = nn.Sequential() 96 | if 
not stride == 1: 97 | self.norm4 = nn.Sequential() 98 | 99 | if stride == 1: 100 | self.downsample = None 101 | 102 | else: 103 | self.downsample = nn.Sequential( 104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 105 | 106 | 107 | def forward(self, x): 108 | y = x 109 | y = self.relu(self.norm1(self.conv1(y))) 110 | y = self.relu(self.norm2(self.conv2(y))) 111 | y = self.relu(self.norm3(self.conv3(y))) 112 | 113 | if self.downsample is not None: 114 | x = self.downsample(x) 115 | 116 | return self.relu(x+y) 117 | 118 | class BasicEncoder(nn.Module): 119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 120 | super(BasicEncoder, self).__init__() 121 | self.norm_fn = norm_fn 122 | 123 | if self.norm_fn == 'group': 124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 125 | 126 | elif self.norm_fn == 'batch': 127 | self.norm1 = nn.BatchNorm2d(64) 128 | 129 | elif self.norm_fn == 'instance': 130 | self.norm1 = nn.InstanceNorm2d(64) 131 | 132 | elif self.norm_fn == 'none': 133 | self.norm1 = nn.Sequential() 134 | 135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 136 | self.relu1 = nn.ReLU(inplace=True) 137 | 138 | self.in_planes = 64 139 | self.layer1 = self._make_layer(64, stride=1) 140 | self.layer2 = self._make_layer(96, stride=2) 141 | self.layer3 = self._make_layer(128, stride=2) 142 | 143 | # output convolution 144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 145 | 146 | self.dropout = None 147 | if dropout > 0: 148 | self.dropout = nn.Dropout2d(p=dropout) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 154 | if m.weight is not None: 155 | nn.init.constant_(m.weight, 1) 156 | if m.bias is not None: 157 | nn.init.constant_(m.bias, 0) 158 | 159 | def _make_layer(self, dim, stride=1): 160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 162 | layers = (layer1, layer2) 163 | 164 | self.in_planes = dim 165 | return nn.Sequential(*layers) 166 | 167 | 168 | def forward(self, x): 169 | 170 | # if input is list, combine batch dimension 171 | is_list = isinstance(x, tuple) or isinstance(x, list) 172 | if is_list: 173 | batch_dim = x[0].shape[0] 174 | x = torch.cat(x, dim=0) 175 | 176 | x = self.conv1(x) 177 | x = self.norm1(x) 178 | x = self.relu1(x) 179 | 180 | x = self.layer1(x) 181 | x = self.layer2(x) 182 | x = self.layer3(x) 183 | 184 | x = self.conv2(x) 185 | 186 | if self.training and self.dropout is not None: 187 | x = self.dropout(x) 188 | 189 | if is_list: 190 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 191 | 192 | return x 193 | 194 | 195 | class SmallEncoder(nn.Module): 196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 197 | super(SmallEncoder, self).__init__() 198 | self.norm_fn = norm_fn 199 | 200 | if self.norm_fn == 'group': 201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 202 | 203 | elif self.norm_fn == 'batch': 204 | self.norm1 = nn.BatchNorm2d(32) 205 | 206 | elif self.norm_fn == 'instance': 207 | self.norm1 = nn.InstanceNorm2d(32) 208 | 209 | elif self.norm_fn == 'none': 210 | self.norm1 = nn.Sequential() 211 | 212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) 213 | self.relu1 = nn.ReLU(inplace=True) 214 | 215 | self.in_planes = 32 216 | self.layer1 = 
self._make_layer(32, stride=1) 217 | self.layer2 = self._make_layer(64, stride=2) 218 | self.layer3 = self._make_layer(96, stride=2) 219 | 220 | self.dropout = None 221 | if dropout > 0: 222 | self.dropout = nn.Dropout2d(p=dropout) 223 | 224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 225 | 226 | for m in self.modules(): 227 | if isinstance(m, nn.Conv2d): 228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 230 | if m.weight is not None: 231 | nn.init.constant_(m.weight, 1) 232 | if m.bias is not None: 233 | nn.init.constant_(m.bias, 0) 234 | 235 | def _make_layer(self, dim, stride=1): 236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) 237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 238 | layers = (layer1, layer2) 239 | 240 | self.in_planes = dim 241 | return nn.Sequential(*layers) 242 | 243 | 244 | def forward(self, x): 245 | 246 | # if input is list, combine batch dimension 247 | is_list = isinstance(x, tuple) or isinstance(x, list) 248 | if is_list: 249 | batch_dim = x[0].shape[0] 250 | x = torch.cat(x, dim=0) 251 | 252 | x = self.conv1(x) 253 | x = self.norm1(x) 254 | x = self.relu1(x) 255 | 256 | x = self.layer1(x) 257 | x = self.layer2(x) 258 | x = self.layer3(x) 259 | x = self.conv2(x) 260 | 261 | if self.training and self.dropout is not None: 262 | x = self.dropout(x) 263 | 264 | if is_list: 265 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 266 | 267 | return x 268 | -------------------------------------------------------------------------------- /core/raft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from update import BasicUpdateBlock, SmallUpdateBlock 7 | from extractor import BasicEncoder, SmallEncoder 8 | from corr import CorrBlock, AlternateCorrBlock 9 | from utils.utils import bilinear_sampler, coords_grid, upflow8 10 | 11 | try: 12 | autocast = torch.cuda.amp.autocast 13 | except: 14 | # dummy autocast for PyTorch < 1.6 15 | class autocast: 16 | def __init__(self, enabled): 17 | pass 18 | def __enter__(self): 19 | pass 20 | def __exit__(self, *args): 21 | pass 22 | 23 | 24 | class RAFT(nn.Module): 25 | def __init__(self, args): 26 | super(RAFT, self).__init__() 27 | self.args = args 28 | 29 | if args.small: 30 | self.hidden_dim = hdim = 96 31 | self.context_dim = cdim = 64 32 | args.corr_levels = 4 33 | args.corr_radius = 3 34 | 35 | else: 36 | self.hidden_dim = hdim = 128 37 | self.context_dim = cdim = 128 38 | args.corr_levels = 4 39 | args.corr_radius = 4 40 | 41 | if 'dropout' not in self.args: 42 | self.args.dropout = 0 43 | 44 | if 'alternate_corr' not in self.args: 45 | self.args.alternate_corr = False 46 | 47 | # feature network, context network, and update block 48 | if args.small: 49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) 50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) 51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 52 | 53 | else: 54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) 56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 57 | 58 | def freeze_bn(self): 59 | for m in self.modules(): 60 | if 
isinstance(m, nn.BatchNorm2d): 61 | m.eval() 62 | 63 | def initialize_flow(self, img): 64 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 65 | N, C, H, W = img.shape 66 | coords0 = coords_grid(N, H//8, W//8, device=img.device) 67 | coords1 = coords_grid(N, H//8, W//8, device=img.device) 68 | 69 | # optical flow computed as difference: flow = coords1 - coords0 70 | return coords0, coords1 71 | 72 | def upsample_flow(self, flow, mask): 73 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 74 | N, _, H, W = flow.shape 75 | mask = mask.view(N, 1, 9, 8, 8, H, W) 76 | mask = torch.softmax(mask, dim=2) 77 | 78 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 79 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 80 | 81 | up_flow = torch.sum(mask * up_flow, dim=2) 82 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 83 | return up_flow.reshape(N, 2, 8*H, 8*W) 84 | 85 | 86 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): 87 | """ Estimate optical flow between pair of frames """ 88 | 89 | image1 = 2 * (image1 / 255.0) - 1.0 90 | image2 = 2 * (image2 / 255.0) - 1.0 91 | 92 | image1 = image1.contiguous() 93 | image2 = image2.contiguous() 94 | 95 | hdim = self.hidden_dim 96 | cdim = self.context_dim 97 | 98 | # run the feature network 99 | with autocast(enabled=self.args.mixed_precision): 100 | fmap1, fmap2 = self.fnet([image1, image2]) 101 | 102 | fmap1 = fmap1.float() 103 | fmap2 = fmap2.float() 104 | if self.args.alternate_corr: 105 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 106 | else: 107 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 108 | 109 | # run the context network 110 | with autocast(enabled=self.args.mixed_precision): 111 | cnet = self.cnet(image1) 112 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 113 | net = torch.tanh(net) 114 | inp = torch.relu(inp) 115 | 116 | coords0, coords1 = self.initialize_flow(image1) 117 | 118 | if flow_init is not None: 119 | coords1 = coords1 + flow_init 120 | 121 | flow_predictions = [] 122 | for itr in range(iters): 123 | coords1 = coords1.detach() 124 | corr = corr_fn(coords1) # index correlation volume 125 | 126 | flow = coords1 - coords0 127 | with autocast(enabled=self.args.mixed_precision): 128 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 129 | 130 | # F(t+1) = F(t) + \Delta(t) 131 | coords1 = coords1 + delta_flow 132 | 133 | # upsample predictions 134 | if up_mask is None: 135 | flow_up = upflow8(coords1 - coords0) 136 | else: 137 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 138 | 139 | flow_predictions.append(flow_up) 140 | 141 | if test_mode: 142 | return coords1 - coords0, flow_up 143 | 144 | return flow_predictions 145 | -------------------------------------------------------------------------------- /core/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | 
super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self, args): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self, args): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 88 | 89 | def forward(self, flow, corr): 90 | cor = F.relu(self.convc1(corr)) 91 | cor = F.relu(self.convc2(cor)) 92 | flo = F.relu(self.convf1(flow)) 93 | flo = F.relu(self.convf2(flo)) 94 | 95 | cor_flo = torch.cat([cor, flo], dim=1) 96 | out = F.relu(self.conv(cor_flo)) 97 | return torch.cat([out, flow], dim=1) 98 | 99 | class SmallUpdateBlock(nn.Module): 100 | def __init__(self, args, hidden_dim=96): 101 | super(SmallUpdateBlock, self).__init__() 102 | self.encoder = SmallMotionEncoder(args) 103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | inp 
= torch.cat([inp, motion_features], dim=1) 109 | net = self.gru(net, inp) 110 | delta_flow = self.flow_head(net) 111 | 112 | return net, None, delta_flow 113 | 114 | class BasicUpdateBlock(nn.Module): 115 | def __init__(self, args, hidden_dim=128, input_dim=128): 116 | super(BasicUpdateBlock, self).__init__() 117 | self.args = args 118 | self.encoder = BasicMotionEncoder(args) 119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 121 | 122 | self.mask = nn.Sequential( 123 | nn.Conv2d(128, 256, 3, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(256, 64*9, 1, padding=0)) 126 | 127 | def forward(self, net, inp, corr, flow, upsample=True): 128 | motion_features = self.encoder(flow, corr) 129 | inp = torch.cat([inp, motion_features], dim=1) 130 | 131 | net = self.gru(net, inp) 132 | delta_flow = self.flow_head(net) 133 | 134 | # scale mask to balence gradients 135 | mask = .25 * self.mask(net) 136 | return net, mask, delta_flow 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/core/utils/__init__.py -------------------------------------------------------------------------------- /core/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import math 4 | from PIL import Image 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | import torch 11 | from torchvision.transforms import ColorJitter 12 | import torch.nn.functional as F 13 | 14 | 15 | class FlowAugmentor: 16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): 17 | 18 | # spatial augmentation params 19 | self.crop_size = crop_size 20 | self.min_scale = min_scale 21 | self.max_scale = max_scale 22 | self.spatial_aug_prob = 0.8 23 | self.stretch_prob = 0.8 24 | self.max_stretch = 0.2 25 | 26 | # flip augmentation params 27 | self.do_flip = do_flip 28 | self.h_flip_prob = 0.5 29 | self.v_flip_prob = 0.1 30 | 31 | # photometric augmentation params 32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) 33 | self.asymmetric_color_aug_prob = 0.2 34 | self.eraser_aug_prob = 0.5 35 | 36 | def color_transform(self, img1, img2): 37 | """ Photometric augmentation """ 38 | 39 | # asymmetric 40 | if np.random.rand() < self.asymmetric_color_aug_prob: 41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 43 | 44 | # symmetric 45 | else: 46 | image_stack = np.concatenate([img1, img2], axis=0) 47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 48 | img1, img2 = np.split(image_stack, 2, axis=0) 49 | 50 | return img1, img2 51 | 52 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 53 | """ Occlusion augmentation """ 54 | 55 | ht, wd = img1.shape[:2] 56 | if np.random.rand() < self.eraser_aug_prob: 57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 58 | for _ in range(np.random.randint(1, 3)): 59 | x0 = np.random.randint(0, wd) 60 | y0 = np.random.randint(0, ht) 61 | dx = np.random.randint(bounds[0], bounds[1]) 62 | dy = np.random.randint(bounds[0], bounds[1]) 63 | img2[y0:y0+dy, 
x0:x0+dx, :] = mean_color 64 | 65 | return img1, img2 66 | 67 | def spatial_transform(self, img1, img2, flow): 68 | # randomly sample scale 69 | ht, wd = img1.shape[:2] 70 | min_scale = np.maximum( 71 | (self.crop_size[0] + 8) / float(ht), 72 | (self.crop_size[1] + 8) / float(wd)) 73 | 74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 75 | scale_x = scale 76 | scale_y = scale 77 | if np.random.rand() < self.stretch_prob: 78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 80 | 81 | scale_x = np.clip(scale_x, min_scale, None) 82 | scale_y = np.clip(scale_y, min_scale, None) 83 | 84 | if np.random.rand() < self.spatial_aug_prob: 85 | # rescale the images 86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 89 | flow = flow * [scale_x, scale_y] 90 | 91 | if self.do_flip: 92 | if np.random.rand() < self.h_flip_prob: # h-flip 93 | img1 = img1[:, ::-1] 94 | img2 = img2[:, ::-1] 95 | flow = flow[:, ::-1] * [-1.0, 1.0] 96 | 97 | if np.random.rand() < self.v_flip_prob: # v-flip 98 | img1 = img1[::-1, :] 99 | img2 = img2[::-1, :] 100 | flow = flow[::-1, :] * [1.0, -1.0] 101 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 102 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 103 | 104 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 105 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 106 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 107 | 108 | return img1, img2, flow 109 | 110 | def __call__(self, img1, img2, flow): 111 | img1, img2 = self.color_transform(img1, img2) 112 | img1, img2 = self.eraser_transform(img1, img2) 113 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 114 | 115 | img1 = np.ascontiguousarray(img1) 116 | img2 = np.ascontiguousarray(img2) 117 | flow = np.ascontiguousarray(flow) 118 | 119 | return img1, img2, flow 120 | 121 | class SparseFlowAugmentor: 122 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): 123 | # spatial augmentation params 124 | self.crop_size = crop_size 125 | self.min_scale = min_scale 126 | self.max_scale = max_scale 127 | self.spatial_aug_prob = 0.8 128 | self.stretch_prob = 0.8 129 | self.max_stretch = 0.2 130 | 131 | # flip augmentation params 132 | self.do_flip = do_flip 133 | self.h_flip_prob = 0.5 134 | self.v_flip_prob = 0.1 135 | 136 | # photometric augmentation params 137 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) 138 | self.asymmetric_color_aug_prob = 0.2 139 | self.eraser_aug_prob = 0.5 140 | 141 | def color_transform(self, img1, img2): 142 | image_stack = np.concatenate([img1, img2], axis=0) 143 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 144 | img1, img2 = np.split(image_stack, 2, axis=0) 145 | return img1, img2 146 | 147 | def eraser_transform(self, img1, img2): 148 | ht, wd = img1.shape[:2] 149 | if np.random.rand() < self.eraser_aug_prob: 150 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 151 | for _ in range(np.random.randint(1, 3)): 152 | x0 = np.random.randint(0, wd) 153 | y0 = np.random.randint(0, ht) 154 | dx = np.random.randint(50, 100) 155 | dy = np.random.randint(50, 100) 156 | img2[y0:y0+dy, 
x0:x0+dx, :] = mean_color 157 | 158 | return img1, img2 159 | 160 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 161 | ht, wd = flow.shape[:2] 162 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 163 | coords = np.stack(coords, axis=-1) 164 | 165 | coords = coords.reshape(-1, 2).astype(np.float32) 166 | flow = flow.reshape(-1, 2).astype(np.float32) 167 | valid = valid.reshape(-1).astype(np.float32) 168 | 169 | coords0 = coords[valid>=1] 170 | flow0 = flow[valid>=1] 171 | 172 | ht1 = int(round(ht * fy)) 173 | wd1 = int(round(wd * fx)) 174 | 175 | coords1 = coords0 * [fx, fy] 176 | flow1 = flow0 * [fx, fy] 177 | 178 | xx = np.round(coords1[:,0]).astype(np.int32) 179 | yy = np.round(coords1[:,1]).astype(np.int32) 180 | 181 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 182 | xx = xx[v] 183 | yy = yy[v] 184 | flow1 = flow1[v] 185 | 186 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 187 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 188 | 189 | flow_img[yy, xx] = flow1 190 | valid_img[yy, xx] = 1 191 | 192 | return flow_img, valid_img 193 | 194 | def spatial_transform(self, img1, img2, flow, valid): 195 | # randomly sample scale 196 | 197 | ht, wd = img1.shape[:2] 198 | min_scale = np.maximum( 199 | (self.crop_size[0] + 1) / float(ht), 200 | (self.crop_size[1] + 1) / float(wd)) 201 | 202 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 203 | scale_x = np.clip(scale, min_scale, None) 204 | scale_y = np.clip(scale, min_scale, None) 205 | 206 | if np.random.rand() < self.spatial_aug_prob: 207 | # rescale the images 208 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 209 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 210 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 211 | 212 | if self.do_flip: 213 | if np.random.rand() < 0.5: # h-flip 214 | img1 = img1[:, ::-1] 215 | img2 = img2[:, ::-1] 216 | flow = flow[:, ::-1] * [-1.0, 1.0] 217 | valid = valid[:, ::-1] 218 | 219 | margin_y = 20 220 | margin_x = 50 221 | 222 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 223 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 224 | 225 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 226 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 227 | 228 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 229 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 230 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 231 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 232 | return img1, img2, flow, valid 233 | 234 | 235 | def __call__(self, img1, img2, flow, valid): 236 | img1, img2 = self.color_transform(img1, img2) 237 | img1, img2 = self.eraser_transform(img1, img2) 238 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 239 | 240 | img1 = np.ascontiguousarray(img1) 241 | img2 = np.ascontiguousarray(img2) 242 | flow = np.ascontiguousarray(flow) 243 | valid = np.ascontiguousarray(valid) 244 | 245 | return img1, img2, flow, valid 246 | -------------------------------------------------------------------------------- /core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is 
hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 
112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | (x1, y1), dy, (x0, y0), 
method='nearest', fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd, device): 75 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /external/forward_warping/compile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | gcc -fPIC -shared -o libwarping.so warping.c -------------------------------------------------------------------------------- /external/forward_warping/libwarping.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/external/forward_warping/libwarping.so -------------------------------------------------------------------------------- /external/forward_warping/warping.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #define valid(X, Y, W) (Y*W*5+X*5+3) 4 | #define collision(X, Y, W) (Y*W*5+X*5+4) 5 | 6 | void forward_warping(const void *src, const void *idx, const void *idy, const void *z, void *warped, int h, int w) 7 | { 8 | float *dlut = (float *)calloc(h * w, sizeof(float)); 9 | for (int i = 0; i < h; i++) 10 | for (int j = 0; j < w; j++) 11 | dlut[i * w + j] = 1000; 12 | 13 | for (int i = 0; i < h; i++) 14 | for (int j = 0; j < w; j++) 15 | { 16 | int x = ((long *)idx)[i * w + j]; 17 | int y = ((long *)idy)[i * w + j]; 18 | 19 | if (((float *)z)[i * w + j] < dlut[y * w + x]) 20 | for (int c = 0; c < 3; c++) 21 | ((unsigned char *)warped)[y * w * 5 + x * 5 + c] = ((unsigned char *)src)[i * w * 3 + j * 3 + c]; 22 | 23 | ((unsigned char *)warped)[valid(x,y,w)] = 1; 24 | if (dlut[y * w + x] != 1000) 25 | ((unsigned char *)warped)[collision(x,y,w)] = 0; 26 | else 27 | ((unsigned char *)warped)[collision(x,y,w)] = 1; 28 | dlut[y * w + x] = ((float *)z)[i * w + j]; 29 | } 30 | 31 | free(dlut); 32 | return; 33 | } 34 | -------------------------------------------------------------------------------- /flow_colors.py: -------------------------------------------------------------------------------- 1 | # 2 | # Utility functions for coloring optical flow maps 3 | # 4 | 5 | import os 6 | import numpy as np 7 | import sys 8 | import re 9 | import cv2 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | def make_colorwheel(): 14 | """ 15 | Generates a color wheel for optical flow visualization as presented in: 16 | Baker et al. 
"A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 17 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 18 | Code follows the original C++ source code of Daniel Scharstein. 19 | Code follows the the Matlab source code of Deqing Sun. 20 | Returns: 21 | np.ndarray: Color wheel 22 | """ 23 | 24 | RY = 15 25 | YG = 6 26 | GC = 4 27 | CB = 11 28 | BM = 13 29 | MR = 6 30 | 31 | ncols = RY + YG + GC + CB + BM + MR 32 | colorwheel = np.zeros((ncols, 3)) 33 | col = 0 34 | 35 | # RY 36 | colorwheel[0:RY, 0] = 255 37 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 38 | col = col + RY 39 | # YG 40 | colorwheel[col: col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 41 | colorwheel[col: col + YG, 1] = 255 42 | col = col + YG 43 | # GC 44 | colorwheel[col: col + GC, 1] = 255 45 | colorwheel[col: col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 46 | col = col + GC 47 | # CB 48 | colorwheel[col: col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 49 | colorwheel[col: col + CB, 2] = 255 50 | col = col + CB 51 | # BM 52 | colorwheel[col: col + BM, 2] = 255 53 | colorwheel[col: col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 54 | col = col + BM 55 | # MR 56 | colorwheel[col: col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 57 | colorwheel[col: col + MR, 0] = 255 58 | return colorwheel 59 | 60 | 61 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 62 | """ 63 | Applies the flow color wheel to (possibly clipped) flow components u and v. 64 | According to the C++ source code of Daniel Scharstein 65 | According to the Matlab source code of Deqing Sun 66 | Args: 67 | u (np.ndarray): Input horizontal flow of shape [H,W] 68 | v (np.ndarray): Input vertical flow of shape [H,W] 69 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 70 | Returns: 71 | np.ndarray: Flow visualization image of shape [H,W,3] 72 | """ 73 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 74 | colorwheel = make_colorwheel() # shape [55x3] 75 | ncols = colorwheel.shape[0] 76 | rad = np.sqrt(np.square(u) + np.square(v)) 77 | a = np.arctan2(-v, -u) / np.pi 78 | fk = (a + 1) / 2 * (ncols - 1) 79 | k0 = np.floor(fk).astype(np.int32) 80 | k1 = k0 + 1 81 | k1[k1 == ncols] = 0 82 | f = fk - k0 83 | for i in range(colorwheel.shape[1]): 84 | tmp = colorwheel[:, i] 85 | col0 = tmp[k0] / 255.0 86 | col1 = tmp[k1] / 255.0 87 | col = (1 - f) * col0 + f * col1 88 | idx = rad <= 1 89 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 90 | col[~idx] = col[~idx] * 0.75 # out of range 91 | # Note the 2-i => BGR instead of RGB 92 | ch_idx = 2 - i if convert_to_bgr else i 93 | flow_image[:, :, ch_idx] = np.floor(255 * col) 94 | return flow_image 95 | 96 | 97 | def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False): 98 | """ 99 | Expects a two dimensional flow image of shape. 100 | Args: 101 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 102 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 103 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
104 | Returns: 105 | np.ndarray: Flow visualization image of shape [H,W,3] 106 | """ 107 | assert flow_uv.ndim == 3, "input flow must have three dimensions" 108 | assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]" 109 | if clip_flow is not None: 110 | flow_uv = np.clip(flow_uv, 0, clip_flow) 111 | u = flow_uv[:, :, 0] 112 | v = flow_uv[:, :, 1] 113 | rad = np.sqrt(np.square(u) + np.square(v)) 114 | rad_max = np.max(rad) 115 | epsilon = 1e-5 116 | u = u / (rad_max + epsilon) 117 | v = v / (rad_max + epsilon) 118 | return flow_uv_to_colors(u, v, convert_to_bgr) 119 | -------------------------------------------------------------------------------- /gen_3dphoto_dynamic_v2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn.functional as F 4 | import os 5 | import cv2 6 | from tqdm import tqdm 7 | from torchvision.utils import save_image 8 | from write_flow import writeFlow 9 | 10 | from utils.utils import ( 11 | image_to_tensor, 12 | disparity_to_tensor, 13 | render_3dphoto_dynamic, 14 | ) 15 | from model.AdaMPI import MPIPredictor 16 | from random import seed 17 | import numpy as np 18 | from PIL import Image 19 | 20 | parser = argparse.ArgumentParser( 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument('--width', type=int, default=1280) 23 | parser.add_argument('--height', type=int, default=384) 24 | parser.add_argument('--seed', type=int, default=114514) 25 | parser.add_argument('--ext_cz', type=float, default=0.15) 26 | parser.add_argument('--ckpt_path', type=str, 27 | default='adampiweight/adampi_64p.pth') 28 | parser.add_argument('--repeat', type=int, default=5) 29 | parser.add_argument('--base', type=str, 30 | default='', required=True) 31 | parser.add_argument('--out', type=str, 32 | default='', required=True) 33 | 34 | opt, _ = parser.parse_known_args() 35 | 36 | print(opt) 37 | 38 | seed(opt.seed) 39 | np.random.seed(opt.seed) 40 | 41 | # render 3D photo 42 | K = torch.tensor([ 43 | [0.58, 0, 0.5], 44 | [0, 0.58, 0.5], 45 | [0, 0, 1] 46 | ]).cuda().half() 47 | K[0, :] *= opt.width 48 | K[1, :] *= opt.height 49 | K = K.unsqueeze(0) 50 | 51 | # load pretrained model 52 | ckpt = torch.load(opt.ckpt_path) 53 | model = MPIPredictor( 54 | width=opt.width, 55 | height=opt.height, 56 | num_planes=ckpt['num_planes'], 57 | ) 58 | model.load_state_dict(ckpt['weight']) 59 | model = model.cuda().half() 60 | model.eval() 61 | # model = torch.jit.script(model) 62 | 63 | out = opt.out 64 | base = opt.base 65 | 66 | if not os.path.exists(out): 67 | os.mkdir(out) 68 | os.mkdir(f"{out}/src_images") 69 | os.mkdir(f"{out}/dst_images") 70 | os.mkdir(f"{out}/flows") 71 | os.mkdir(f"{out}/obj_mask") 72 | 73 | 74 | img_base = os.path.join(base, "images") 75 | disp_base = os.path.join(base, "disps") 76 | mask_base = os.path.join(base, "masks") 77 | 78 | for img in tqdm(sorted(os.listdir(img_base))): 79 | 80 | name = img.split(".")[0] 81 | 82 | image = image_to_tensor(os.path.join(img_base, img)).cuda().half() # [1,3,h,w] 83 | obj_mask_np = np.array(Image.open(os.path.join(mask_base, img)).convert("L")) 84 | disp = disparity_to_tensor(os.path.join(disp_base, img)).cuda().half() # [1,1,h,w] 85 | 86 | image = F.interpolate(image, size=(opt.height, opt.width), 87 | mode='bilinear', align_corners=True) 88 | disp = F.interpolate(disp, size=(opt.height, opt.width), 89 | mode='bilinear', align_corners=True) 90 | 91 | # disp.requires_grad = True 92 | with torch.no_grad(): 93 | 
mpi_all_src, disparity_all_src = model(image, disp) # [b,s,4,h,w] 94 | 95 | # import IPython 96 | # IPython.embed() 97 | # exit() 98 | 99 | for r in range(opt.repeat): 100 | # predict MPI planes 101 | obj_index = np.random.randint(obj_mask_np.max()) + 1 102 | # print(obj_mask_np.max(), obj_index) 103 | obj_mask = torch.FloatTensor(obj_mask_np == obj_index).cuda().half().unsqueeze(0).unsqueeze(0) # [1,3,h,w] 104 | obj_mask = F.interpolate(obj_mask, size=(opt.height, opt.width), 105 | mode='bilinear', align_corners=True) 106 | 107 | flow_mix, src_np, inpainted, res = render_3dphoto_dynamic( 108 | opt, 109 | image, 110 | obj_mask, 111 | disp, 112 | mpi_all_src, 113 | disparity_all_src, 114 | K, 115 | K, 116 | data_path='outputs', 117 | name='demo' 118 | ) 119 | 120 | writeFlow(os.path.join(out, "flows", f'{name}_{r}.flo'), flow_mix) 121 | cv2.imwrite(os.path.join(out, "dst_images", f'{name}_{r}.png'), inpainted) 122 | cv2.imwrite(os.path.join(out, "src_images", f'{name}_{r}.png'), src_np) -------------------------------------------------------------------------------- /geometry.py: -------------------------------------------------------------------------------- 1 | # 2 | # Classes and functions in this script are taken from https://github.com/nianticlabs/monodepth2 3 | # Use conditions available in the LICENSE file at https://github.com/nianticlabs/monodepth2/blob/master/LICENSE 4 | # Copyright © Niantic, Inc. 2018. All rights reserved. 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | __all__ = ["BackprojectDepth", "Project3D", "transformation_from_parameters"] 15 | 16 | 17 | class BackprojectDepth(nn.Module): 18 | """Layer to transform a depth image into a point cloud 19 | """ 20 | def __init__(self, batch_size, height, width): 21 | super(BackprojectDepth, self).__init__() 22 | 23 | self.batch_size = batch_size 24 | self.height = height 25 | self.width = width 26 | 27 | meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy') 28 | self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32) 29 | self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords), 30 | requires_grad=False) 31 | 32 | self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width), 33 | requires_grad=False) 34 | 35 | self.pix_coords = torch.unsqueeze(torch.stack( 36 | [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0) 37 | self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1) 38 | self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1), 39 | requires_grad=False) 40 | 41 | def forward(self, depth, inv_K): 42 | cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords) 43 | cam_points = depth.view(self.batch_size, 1, -1) * cam_points 44 | cam_points = torch.cat([cam_points, self.ones], 1) 45 | # import IPython 46 | # IPython.embed() 47 | # exit() 48 | 49 | return cam_points 50 | 51 | 52 | class Project3D(nn.Module): 53 | """Layer which projects 3D points into a camera with intrinsics K and at position T 54 | """ 55 | def __init__(self, batch_size, height, width, eps=1e-7): 56 | super(Project3D, self).__init__() 57 | 58 | self.batch_size = batch_size 59 | self.height = height 60 | self.width = width 61 | self.eps = eps 62 | 63 | def forward(self, points, K, T, T2=None): 64 | if not T2 is None: 65 | T = torch.matmul(T, torch.inverse(T2)) 66 | P = torch.matmul(K, T)[:, :3, :] 67 | 68 | cam_points = 
torch.matmul(P, points) 69 | 70 | pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps) 71 | pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width) 72 | pix_coords = pix_coords.permute(0, 2, 3, 1) 73 | pix_coords[..., 0] /= self.width - 1 74 | pix_coords[..., 1] /= self.height - 1 75 | pix_coords = (pix_coords - 0.5) * 2 76 | return pix_coords, cam_points[:, 2, :].unsqueeze(1) 77 | 78 | 79 | def transformation_from_parameters(axisangle, translation, invert=False): 80 | """Convert the network's (axisangle, translation) output into a 4x4 matrix 81 | """ 82 | R = rot_from_axisangle(axisangle) 83 | t = translation.clone() 84 | 85 | if invert: 86 | R = R.transpose(1, 2) 87 | t *= -1 88 | 89 | T = get_translation_matrix(t) 90 | 91 | if invert: 92 | M = torch.matmul(R, T) 93 | else: 94 | M = torch.matmul(T, R) 95 | return M 96 | 97 | 98 | def get_translation_matrix(translation_vector): 99 | """Convert a translation vector into a 4x4 transformation matrix 100 | """ 101 | T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device) 102 | 103 | t = translation_vector.contiguous().view(-1, 3, 1) 104 | 105 | T[:, 0, 0] = 1 106 | T[:, 1, 1] = 1 107 | T[:, 2, 2] = 1 108 | T[:, 3, 3] = 1 109 | T[:, :3, 3, None] = t 110 | 111 | return T 112 | 113 | 114 | def rot_from_axisangle(vec): 115 | """Convert an axisangle rotation into a 4x4 transformation matrix 116 | (adapted from https://github.com/Wallacoloo/printipi) 117 | Input 'vec' has to be Bx1x3 118 | """ 119 | angle = torch.norm(vec, 2, 2, True) 120 | axis = vec / (angle + 1e-7) 121 | 122 | ca = torch.cos(angle) 123 | sa = torch.sin(angle) 124 | C = 1 - ca 125 | 126 | x = axis[..., 0].unsqueeze(1) 127 | y = axis[..., 1].unsqueeze(1) 128 | z = axis[..., 2].unsqueeze(1) 129 | 130 | xs = x * sa 131 | ys = y * sa 132 | zs = z * sa 133 | xC = x * C 134 | yC = y * C 135 | zC = z * C 136 | xyC = x * yC 137 | yzC = y * zC 138 | zxC = z * xC 139 | 140 | rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device) 141 | 142 | rot[:, 0, 0] = torch.squeeze(x * xC + ca) 143 | rot[:, 0, 1] = torch.squeeze(xyC - zs) 144 | rot[:, 0, 2] = torch.squeeze(zxC + ys) 145 | rot[:, 1, 0] = torch.squeeze(xyC + zs) 146 | rot[:, 1, 1] = torch.squeeze(y * yC + ca) 147 | rot[:, 1, 2] = torch.squeeze(yzC - xs) 148 | rot[:, 2, 0] = torch.squeeze(zxC - ys) 149 | rot[:, 2, 1] = torch.squeeze(yzC + xs) 150 | rot[:, 2, 2] = torch.squeeze(z * zC + ca) 151 | rot[:, 3, 3] = 1 152 | 153 | return rot 154 | -------------------------------------------------------------------------------- /misc/train_image_2_000000_00_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/misc/train_image_2_000000_00_1.png -------------------------------------------------------------------------------- /model/AdaMPI.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MPIPredictor(nn.Module): 7 | def __init__( 8 | self, 9 | width=384, 10 | height=256, 11 | num_planes=64, 12 | ): 13 | super(MPIPredictor, self).__init__() 14 | self.num_planes = num_planes 15 | disp_range = [0.001, 1] 16 | self.far, self.near = disp_range 17 | 18 | H_tgt, W_tgt = height, width 19 | ctx_spatial_scale = 4 20 | self.low_res_size = (int(H_tgt / ctx_spatial_scale), int(W_tgt / ctx_spatial_scale)) 21 | 22 | # 
----------------------- 23 | # CPN Encoder 24 | # ----------------------- 25 | from model.CPN.encoder import ResnetEncoder 26 | self.encoder = ResnetEncoder(num_layers=18) 27 | 28 | # ----------------------- 29 | # CPN Feature Mask UNet 30 | # ----------------------- 31 | from model.CPN.unet import FeatMaskNetwork 32 | self.fmn = FeatMaskNetwork() 33 | 34 | # ----------------------- 35 | # PAN 36 | # ----------------------- 37 | from model.PAN import DepthPredictionNetwork 38 | self.dpn = DepthPredictionNetwork( 39 | disp_range=disp_range, 40 | n_planes=num_planes, 41 | ) 42 | 43 | # ----------------------- 44 | # CPN Decoder 45 | # ----------------------- 46 | from model.CPN.decoder import DepthDecoder 47 | num_ch_enc = self.encoder.num_ch_enc 48 | self.decoder = DepthDecoder( 49 | num_ch_enc=num_ch_enc, 50 | use_alpha=False, 51 | scales=range(4), 52 | use_skips=True, 53 | ) 54 | 55 | def forward( 56 | self, 57 | src_imgs, 58 | src_depths, 59 | ): 60 | rgb_low_res = F.interpolate(src_imgs, size=self.low_res_size, mode='bilinear', align_corners=True) 61 | disp_low_res = F.interpolate(src_depths, size=self.low_res_size, mode='bilinear', align_corners=True) 62 | 63 | bs = src_imgs.shape[0] 64 | dpn_input_disparity = torch.linspace( 65 | self.near, 66 | self.far, 67 | self.num_planes + 2 68 | )[1:-1].to(src_imgs.dtype).to(src_imgs.device).unsqueeze(0).repeat(bs, 1) 69 | 70 | # render_disp = self.dpn(dpn_input_disparity, rgb_low_res, disp_low_res) 71 | render_disp = dpn_input_disparity 72 | feature_mask = self.fmn(src_imgs, src_depths, render_disp) 73 | # Encoder forward 74 | conv1_out, block1_out, block2_out, block3_out, block4_out = self.encoder(src_imgs, src_depths) 75 | enc_features = [conv1_out, block1_out, block2_out, block3_out, block4_out] 76 | # Decoder forward 77 | outputs = self.decoder(enc_features, feature_mask) 78 | return outputs[0], render_disp 79 | -------------------------------------------------------------------------------- /model/CPN/decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 
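# Orienting note: gated-convolution decoder that combines the shared RGBD
# encoder features with a per-plane feature mask and predicts, for every MPI
# plane and output scale, a 4-channel (RGB + density) image, applying
# per-plane context / inpainting masks along the way.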
6 | 7 | 8 | ''' 9 | This code is borrowed heavily from MINE: https://github.com/vincentfung13/MINE 10 | ''' 11 | 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | 19 | def upsample(x): 20 | return F.interpolate(x, scale_factor=2, mode="nearest") 21 | 22 | 23 | class GatedConv(nn.Module): 24 | def __init__(self, in_channels, out_channels): 25 | super(GatedConv, self).__init__() 26 | self.pad = nn.ReflectionPad2d(1) 27 | 28 | self.conv2d = nn.Conv2d(in_channels, out_channels, 3) 29 | self.mask_conv2d = nn.Conv2d(in_channels, out_channels, 3) 30 | 31 | self.sigmoid = nn.Sigmoid() 32 | 33 | def forward(self, feat): 34 | feat = self.pad(feat) 35 | x = self.conv2d(feat) 36 | mask = self.mask_conv2d(feat) 37 | return x * self.sigmoid(mask) 38 | 39 | 40 | class GatedConvBlock(nn.Module): 41 | def __init__(self, in_channels, out_channels): 42 | super(GatedConvBlock, self).__init__() 43 | self.gated_conv = GatedConv(in_channels, out_channels) 44 | self.nonlin = nn.ELU(inplace=True) 45 | self.bn = nn.BatchNorm2d(out_channels) 46 | 47 | def forward(self, feat): 48 | x = self.gated_conv(feat) 49 | x = self.bn(x) 50 | x = self.nonlin(x) 51 | return x 52 | 53 | 54 | def conv(in_planes, out_planes, kernel_size, instancenorm=False): 55 | if instancenorm: 56 | m = nn.Sequential( 57 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, 58 | stride=1, padding=(kernel_size - 1) // 2, bias=False), 59 | nn.InstanceNorm2d(out_planes), 60 | nn.LeakyReLU(0.1, inplace=True), 61 | ) 62 | else: 63 | m = nn.Sequential( 64 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, 65 | stride=1, padding=(kernel_size - 1) // 2, bias=False), 66 | nn.BatchNorm2d(out_planes), 67 | nn.LeakyReLU(0.1, inplace=True) 68 | ) 69 | return m 70 | 71 | 72 | class DepthDecoder(nn.Module): 73 | def tuple_to_str(self, key_tuple): 74 | key_str = '-'.join(str(key_tuple)) 75 | return key_str 76 | 77 | def __init__(self, num_ch_enc, 78 | use_alpha=False, scales=range(4), num_output_channels=4, 79 | use_skips=True, **kwargs): 80 | super(DepthDecoder, self).__init__() 81 | 82 | self.num_output_channels = num_output_channels 83 | self.use_skips = use_skips 84 | self.upsample_mode = 'nearest' 85 | self.scales = scales 86 | self.use_alpha = use_alpha 87 | 88 | final_enc_out_channels = num_ch_enc[-1] 89 | self.downsample = nn.MaxPool2d(3, stride=2, padding=1) 90 | self.upsample = nn.UpsamplingNearest2d(scale_factor=2) 91 | self.conv_down1 = conv(final_enc_out_channels, 512, 1, False) 92 | self.conv_down2 = conv(512, 256, 3, False) 93 | self.conv_up1 = conv(256, 256, 3, False) 94 | self.conv_up2 = conv(256, final_enc_out_channels, 1, False) 95 | 96 | self.num_ch_enc = num_ch_enc 97 | # print("num_ch_enc=", num_ch_enc) 98 | self.num_ch_enc = [x + 2 for x in self.num_ch_enc] 99 | self.num_ch_dec = np.array([12, 24, 48, 96, 192]) 100 | # self.num_ch_enc = np.array([64, 64, 128, 256, 512]) 101 | 102 | # decoder 103 | self.convs = nn.ModuleDict() 104 | for i in range(4, -1, -1): 105 | # upconv_0 106 | num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1] 107 | num_ch_out = self.num_ch_dec[i] 108 | self.convs[self.tuple_to_str(("upconv", i, 0))] = GatedConvBlock(num_ch_in, num_ch_out) 109 | # print("upconv_{}_{}".format(i, 0), num_ch_in, num_ch_out) 110 | 111 | # upconv_1 112 | num_ch_in = self.num_ch_dec[i] 113 | if self.use_skips and i > 0: 114 | num_ch_in += self.num_ch_enc[i - 1] 115 | num_ch_out = self.num_ch_dec[i] 116 | 
self.convs[self.tuple_to_str(("upconv", i, 1))] = GatedConvBlock(num_ch_in, num_ch_out) 117 | # print("upconv_{}_{}".format(i, 1), num_ch_in, num_ch_out) 118 | 119 | for s in self.scales: 120 | self.convs[self.tuple_to_str(("dispconv", s))] = GatedConv(self.num_ch_dec[s], self.num_output_channels) 121 | 122 | self.sigmoid = nn.Sigmoid() 123 | 124 | def forward(self, input_features, feature_mask): 125 | B, S, _, _ = feature_mask.size() 126 | # extension of encoder to increase receptive field 127 | encoder_out = input_features[-1] 128 | conv_down1 = self.conv_down1(self.downsample(encoder_out)) 129 | conv_down2 = self.conv_down2(self.downsample(conv_down1)) 130 | conv_up1 = self.conv_up1(self.upsample(conv_down2)) 131 | conv_up2 = self.conv_up2(self.upsample(conv_up1)) 132 | 133 | # repeat / reshape features 134 | _, C_feat, H_feat, W_feat = conv_up2.size() 135 | cum_mask = torch.cumsum(feature_mask, dim=1) # [B,S,H,W] 136 | inpaint_mask = torch.cat([torch.zeros_like(cum_mask[:, -1:, :, :]), cum_mask[:, :-1, :, :]], dim=1) # [B,S,H,W] 137 | context_mask = 1 - inpaint_mask # [B,S,H,W] 138 | 139 | cur_context_mask = F.adaptive_avg_pool2d(context_mask, (H_feat, W_feat)).unsqueeze(2) 140 | cur_feature_mask = F.adaptive_avg_pool2d(feature_mask, (H_feat, W_feat)).unsqueeze(2) 141 | conv_up2 = conv_up2.unsqueeze(1).repeat(1, S, 1, 1, 1) 142 | conv_up2 = torch.cat([conv_up2 * cur_context_mask, cur_context_mask, cur_feature_mask], dim=2) # [B,S,C+2,H,W] 143 | conv_up2 = conv_up2.reshape(-1, C_feat + 2, H_feat, W_feat) # [BxS,C+2,H,W] 144 | 145 | # repeat / reshape features 146 | for i, feat in enumerate(input_features): 147 | _, C_feat, H_feat, W_feat = feat.size() 148 | cur_context_mask = F.adaptive_avg_pool2d(context_mask, (H_feat, W_feat)).unsqueeze(2) 149 | cur_feature_mask = F.adaptive_avg_pool2d(feature_mask, (H_feat, W_feat)).unsqueeze(2) 150 | feat = feat.unsqueeze(1).repeat(1, S, 1, 1, 1) 151 | feat = torch.cat([feat * cur_context_mask, cur_context_mask, cur_feature_mask], dim=2) # [B,S,C+2,H,W] 152 | input_features[i] = feat.reshape(-1, C_feat + 2, H_feat, W_feat) # [BxS,C+2,H,W] 153 | 154 | outputs = [] 155 | x = conv_up2 156 | for i in range(4, -1, -1): 157 | x = self.convs[self.tuple_to_str(("upconv", i, 0))](x) 158 | x = [upsample(x)] 159 | if self.use_skips and i > 0: 160 | x += [input_features[i - 1]] 161 | x = torch.cat(x, 1) 162 | x = self.convs[self.tuple_to_str(("upconv", i, 1))](x) 163 | if i in self.scales: 164 | output = self.convs[self.tuple_to_str(("dispconv", i))](x) 165 | H_mpi, W_mpi = output.size(2), output.size(3) 166 | cur_mask = F.adaptive_avg_pool2d(cum_mask, (H_mpi, W_mpi)).unsqueeze(2) 167 | mpi = output.view(B, S, 4, H_mpi, W_mpi) 168 | mpi_rgb = self.sigmoid(mpi[:, :, 0:3, :, :]) 169 | if self.use_alpha: 170 | mpi_sigma = self.sigmoid(mpi[:, :, 3:, :, :]) * cur_mask 171 | else: 172 | mpi_sigma = torch.relu(mpi[:, :, 3:, :, :] * cur_mask) + 1e-4 173 | outputs.append(torch.cat((mpi_rgb, mpi_sigma), dim=2)) 174 | return outputs[::-1] 175 | -------------------------------------------------------------------------------- /model/CPN/encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 
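# Orienting note: a torchvision ResNet whose first convolution is widened to
# 4 input channels so RGB and disparity are encoded jointly; forward() returns
# the stem output and the four residual-block feature maps consumed by the
# CPN decoder above.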
6 | 7 | 8 | ''' 9 | This code is borrowed heavily from MINE: https://github.com/vincentfung13/MINE 10 | ''' 11 | 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | import torchvision.models as models 17 | 18 | 19 | class ResNetMultiImageInput(models.ResNet): 20 | """Constructs a resnet model with varying number of input images. 21 | Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 22 | """ 23 | def __init__(self, block, layers, num_input_images=1): 24 | super(ResNetMultiImageInput, self).__init__(block, layers) 25 | self.inplanes = 64 26 | self.conv1 = nn.Conv2d( 27 | num_input_images * 4, 64, kernel_size=7, stride=2, padding=3, bias=False) # 输入为RGBD 28 | self.bn1 = nn.BatchNorm2d(64) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 31 | self.layer1 = self._make_layer(block, 64, layers[0]) 32 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 33 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 34 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 35 | 36 | for m in self.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 39 | elif isinstance(m, nn.BatchNorm2d): 40 | nn.init.constant_(m.weight, 1) 41 | nn.init.constant_(m.bias, 0) 42 | 43 | 44 | def resnet_multiimage_input(num_layers, num_input_images=1): 45 | """Constructs a ResNet model. 46 | Args: 47 | num_layers (int): Number of resnet layers. Must be 18 or 50 48 | pretrained (bool): If True, returns a model pre-trained on ImageNet 49 | num_input_images (int): Number of frames stacked as input 50 | """ 51 | assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet" 52 | blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers] 53 | block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers] 54 | model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images) 55 | 56 | return model 57 | 58 | 59 | class ResnetEncoder(nn.Module): 60 | """Pytorch module for a resnet encoder 61 | """ 62 | def __init__(self, num_layers, num_input_images=1, **kwargs): 63 | super(ResnetEncoder, self).__init__() 64 | 65 | self.num_ch_enc = np.array([64, 64, 128, 256, 512]) 66 | 67 | resnets = {18: models.resnet18, 68 | 34: models.resnet34, 69 | 50: models.resnet50, 70 | 101: models.resnet101, 71 | 152: models.resnet152} 72 | 73 | if num_layers not in resnets: 74 | raise ValueError("{} is not a valid number of resnet layers".format(num_layers)) 75 | 76 | self.encoder = resnet_multiimage_input(num_layers, num_input_images) 77 | 78 | if num_layers > 34: 79 | self.num_ch_enc[1:] *= 4 80 | 81 | self.img_mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32) 82 | self.img_mean = self.img_mean.view(1, 3, 1, 1) 83 | self.img_std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32) 84 | self.img_std = self.img_std.view(1, 3, 1, 1) 85 | 86 | def forward(self, input_image, input_depth): 87 | # normalize before going into network 88 | ref_images_normalized = (input_image - self.img_mean.to(input_image)) / self.img_std.to(input_image) 89 | 90 | self.features = [] 91 | # x = (input_image - 0.45) / 0.225 92 | x = torch.cat([ref_images_normalized, input_depth], dim=1) 93 | x = self.encoder.conv1(x) 94 | x = self.encoder.bn1(x) 95 | conv1_out = self.encoder.relu(x) # [bs,64,h//2,w//2] 96 | block1_out = self.encoder.layer1(self.encoder.maxpool(conv1_out)) # 
[bs,256,h//4,w//4] 97 | block2_out = self.encoder.layer2(block1_out) # [bs,512,h//8,w//8] 98 | block3_out = self.encoder.layer3(block2_out) # [bs,1024,h//16,w//16] 99 | block4_out = self.encoder.layer4(block3_out) # [bs,2048,h//32,w//32] 100 | 101 | return conv1_out, block1_out, block2_out, block3_out, block4_out 102 | -------------------------------------------------------------------------------- /model/CPN/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ConvBNReLU(nn.Module): 6 | def __init__(self, ch_in, ch_out, kernel_size, stride, pad): 7 | super().__init__() 8 | self.layer = nn.Sequential( 9 | nn.Conv2d(ch_in, ch_out, kernel_size, stride, pad), 10 | nn.BatchNorm2d(ch_out), 11 | nn.ReLU() 12 | ) 13 | 14 | def forward(self, x): 15 | return self.layer(x) 16 | 17 | 18 | class FeatMaskNetwork(nn.Module): 19 | def __init__(self, **kwargs): 20 | super().__init__() 21 | self.conv1 = ConvBNReLU(5, 16, 3, 1, 1) 22 | self.conv2 = ConvBNReLU(16, 32, 3, 2, 1) 23 | self.conv3 = ConvBNReLU(32, 64, 3, 2, 1) 24 | self.conv4 = ConvBNReLU(64, 128, 3, 2, 1) 25 | self.conv5 = ConvBNReLU(128, 128, 3, 1, 1) 26 | self.conv6 = ConvBNReLU(192, 64, 3, 1, 1) 27 | self.conv7 = ConvBNReLU(96, 32, 3, 1, 1) 28 | self.conv8 = ConvBNReLU(48, 16, 3, 1, 1) 29 | self.conv9 = ConvBNReLU(16, 1, 3, 1, 1) 30 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 31 | 32 | def forward(self, input_image, input_depth, input_mpi_disparity): 33 | ''' 34 | input_image: [b,3,h,w] 35 | input_depth: [b,1,h,w] 36 | input_mpi_disparity: [b,s] 37 | ''' 38 | _, _, h, w = input_image.size() # spatial dim 39 | b, s = input_mpi_disparity.size() # number of mpi planes 40 | 41 | # repeat input rgb 42 | expanded_image = input_image.unsqueeze(1).repeat(1, s, 1, 1, 1) # [b,s,3,h,w] 43 | 44 | # repeat input depth 45 | expanded_depth = input_depth.unsqueeze(1).repeat(1, s, 1, 1, 1) # [b,s,1,h,w] 46 | 47 | # repeat and reshape input mpi disparity 48 | expanded_mpi_disp = input_mpi_disparity[:, :, None, None, None].repeat(1, 1, 1, h, w) # [b,s,1,h,w] 49 | 50 | # concat together 51 | x = torch.cat([expanded_image, expanded_depth, expanded_mpi_disp], dim=2).reshape(b * s, 5, h, w) # [bs,5,h,w] 52 | 53 | # forward 54 | c1 = self.conv1(x) 55 | c2 = self.conv2(c1) 56 | c3 = self.conv3(c2) 57 | c4 = self.conv4(c3) 58 | c5 = self.conv5(c4) 59 | u5 = self.upsample(c5) 60 | c6 = self.conv6(torch.cat([u5, c3], dim=1)) 61 | u6 = self.upsample(c6) 62 | c7 = self.conv7(torch.cat([u6, c2], dim=1)) 63 | u7 = self.upsample(c7) 64 | c8 = self.conv8(torch.cat([u7, c1], dim=1)) 65 | c9 = self.conv9(c8) # [bs,1,h,w] 66 | fm = c9.reshape(b, s, h, w) 67 | fm = torch.softmax(fm ,dim=1) 68 | 69 | return fm 70 | -------------------------------------------------------------------------------- /model/PAN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def MLP(channels): 7 | """ Multi-layer perceptron """ 8 | n = len(channels) 9 | layers = [] 10 | for i in range(1, n): 11 | layers.append( 12 | nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)) 13 | if i < (n-1): 14 | layers.append(nn.ReLU()) 15 | return nn.Sequential(*layers) 16 | 17 | 18 | class ResBlock(nn.Module): 19 | def __init__(self, in_channels, out_channels, hidden_channels): 20 | super().__init__() 21 | self.conv1 = nn.Conv2d(in_channels, 
hidden_channels, kernel_size=3, padding=1) 22 | self.conv2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1) 23 | self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1) 24 | self.activation = nn.ReLU() 25 | self.bn = nn.BatchNorm2d(hidden_channels) 26 | 27 | def forward(self, x): 28 | return self.activation(self.conv3(x) + self.conv2(self.bn(self.activation(self.conv1(x))))) 29 | 30 | 31 | class DownsizeEncoder(nn.Module): 32 | def __init__(self, num_blocks, dim_in, dim_out): 33 | super().__init__() 34 | res_blocks = [] 35 | for i_block in range(0, num_blocks): 36 | d_in = dim_in if i_block == 0 else max(dim_in, dim_out // (2 ** (num_blocks - i_block))) 37 | d_out = max(dim_in, dim_out // (2 ** (num_blocks - i_block - 1))) 38 | res_blocks.append(ResBlock(in_channels=d_in, out_channels=d_out, hidden_channels=d_out)) 39 | self.res_blocks = nn.ModuleList(res_blocks) 40 | 41 | def forward(self, x): 42 | # [b, c, h, w] 43 | for res_block in self.res_blocks: 44 | x = res_block(x) 45 | x = F.avg_pool2d(x, kernel_size=2) 46 | return x # [b, c, h, w] 47 | 48 | 49 | class MultiheadSelfAttention(nn.Module): 50 | def __init__(self, num_heads, dim_in, dim_qk, dim_v): 51 | super().__init__() 52 | self.wQs = nn.ModuleList([nn.Linear(dim_in, dim_qk) for _ in range(num_heads)]) 53 | self.wKs = nn.ModuleList([nn.Linear(dim_in, dim_qk) for _ in range(num_heads)]) 54 | self.wVs = nn.ModuleList([nn.Linear(dim_in, dim_v // num_heads) for _ in range(num_heads)]) 55 | self.fusion = nn.Linear(dim_v, dim_v) 56 | self.norm = dim_qk ** 0.5 57 | 58 | def forward(self, feat): 59 | feat_atted = [] 60 | for wQ, wK, wV in zip(self.wQs, self.wKs, self.wVs): 61 | Q = wQ(feat) # [b,s,cq] 62 | K = wK(feat) # [b,s,cq] 63 | V = wV(feat) # [b,s,cv] 64 | att = torch.softmax(torch.einsum('bik,bjk->bij', Q, K) / self.norm, dim=2) 65 | feat_atted.append(torch.einsum('bij,bjc->bic', att, V)) 66 | return self.fusion(torch.cat(feat_atted, dim=-1)) # [b,s,c] 67 | 68 | 69 | class LinearSigmoid(nn.Module): 70 | def __init__(self, in_ch, disp_range): 71 | super().__init__() 72 | self.start, self.end = disp_range 73 | self.linear = nn.Linear(in_ch, 1) 74 | 75 | def forward(self, feat, init_disp): 76 | feat = self.linear(feat).squeeze(-1) # [b,s] 77 | return init_disp + feat * 1. 
/ init_disp.shape[1] 78 | 79 | 80 | class DepthPredictionNetwork(nn.Module): 81 | def __init__(self, disp_range, **kwargs): 82 | super().__init__() 83 | self.context_encoder = DownsizeEncoder(num_blocks=5, dim_in=5, dim_out=128) 84 | self.self_attention = MultiheadSelfAttention(num_heads=4, dim_in=128, dim_qk=32, dim_v=128) 85 | self.embed = nn.Sequential( 86 | nn.Linear(128, 32), 87 | nn.ReLU(), 88 | ) 89 | self.to_disp = LinearSigmoid(32, disp_range) 90 | 91 | def forward(self, init_disp, rgb_low_res, disp_low_res): 92 | B, S = init_disp.shape 93 | 94 | # context encoder 95 | x = torch.cat([ 96 | rgb_low_res[:, None, ...].repeat(1, S, 1, 1, 1), 97 | disp_low_res[:, None, ...].repeat(1, S, 1, 1, 1), 98 | init_disp[:, :, None, None, None].repeat(1, 1, 1, *rgb_low_res.shape[-2:]) 99 | ], dim=-3) # [b, s, 5, h/4, w/4] 100 | x = x.view(-1, *x.shape[-3:]) # [b*s, 5, h/4, w/4] 101 | context = self.context_encoder(x) # [b*s, c, h/128, w/128] 102 | context = F.adaptive_avg_pool2d(context, (1, 1)).squeeze(-1).squeeze(-1) # [b*s, c] 103 | context = context.view(B, S, -1) # [b, s, c] 104 | 105 | # self attention 106 | feat_atted = self.self_attention(context) # [b, s, c ] 107 | feat = self.embed(feat_atted) # [b, s, c] 108 | disp_bs = self.to_disp(feat, init_disp) # [b, s] 109 | return disp_bs 110 | -------------------------------------------------------------------------------- /moving_obj.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import torch 5 | from flow_colors import flow_to_color 6 | from geometry import * 7 | from ctypes import * 8 | import ctypes 9 | import random 10 | import math 11 | 12 | lib = cdll.LoadLibrary("external/forward_warping/libwarping.so") 13 | warp = lib.forward_warping 14 | 15 | 16 | def moveing_object_with_mask(depth_path, disp, rgb, K, inv_K, instance_mask, i): 17 | 18 | # Cast I0 and D0 to pytorch tensors 19 | h, w = rgb.shape[:2] 20 | rgb = torch.from_numpy(np.expand_dims(rgb, 0)).float().cuda() 21 | # depth = torch.from_numpy(np.expand_dims(depth, 0)).float().cuda() 22 | 23 | # debug 24 | # depth = cv2.imread(depth_path, -1) / (2**16-1) 25 | # if depth.shape[0] != h or depth.shape[1] != w: 26 | # depth = cv2.resize(depth, (w, h)) 27 | 28 | # Get depth map and normalize 29 | depth = 1.0 / (disp[0] + 0.005) 30 | depth[depth > 100] = 100 31 | # depth = torch.from_numpy(np.expand_dims(depth, 0)).float().cuda() 32 | 33 | instance_mask = instance_mask[0] 34 | instance_mask = torch.stack([instance_mask, instance_mask], -1) 35 | 36 | # Create objects in charge of 3D projection 37 | backproject_depth = BackprojectDepth(1, h, w).cuda() 38 | project_3d = Project3D(1, h, w).cuda() 39 | 40 | # Prepare p0 coordinates 41 | meshgrid = np.meshgrid(range(w), range(h), indexing="xy") 42 | p0 = np.stack(meshgrid, axis=-1).astype(np.float32) 43 | 44 | # Initiate masks dictionary 45 | masks = {} 46 | axisangle = torch.from_numpy(np.array([[[0, 0, 0]]], dtype=np.float32)).cuda() 47 | translation = torch.from_numpy(np.array([[0, 0, 0]])).cuda() 48 | 49 | # Compute (R|t) 50 | T1 = transformation_from_parameters(axisangle, translation) 51 | 52 | temp = torch.zeros((1, 4, 4)).cuda() 53 | temp[0, -1, -1] = 1. 54 | temp[:, :3, :3] = K 55 | K = temp 56 | 57 | temp = torch.zeros((1, 4, 4)).cuda() 58 | temp[0, -1, -1] = 1. 
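    # Both K and inv_K are lifted to 4x4 homogeneous matrices so they can be
    # composed with the 4x4 (R|t) transforms and the homogeneous point cloud
    # returned by BackprojectDepth.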
59 | temp[:, :3, :3] = inv_K 60 | inv_K = temp 61 | 62 | # Back-projection 63 | cam_points = backproject_depth(depth, inv_K) 64 | 65 | # Apply transformation T_{0->1} 66 | p1, z1 = project_3d(cam_points, K, T1) 67 | z1 = z1.reshape(1, h, w) 68 | 69 | # Simulate objects moving independently 70 | if True: 71 | 72 | sign = -1 73 | 74 | # Random t (scalars and signs). Zeros and small motions are avoided as before 75 | # cix = (random.random()*0.05+0.05) * \ 76 | # (sign*(-1)**random.randrange(2)) 77 | # ciy = (random.random()*0.05+0.05) * \ 78 | # (sign*(-1)**random.randrange(2)) 79 | # ciz = (random.random()*0.05+0.05) * \ 80 | # (sign*(-1)**random.randrange(2)) 81 | cix = (random.random()*0.05+0.05) 82 | ciy = -1*(random.random()*0.05+0.05) 83 | ciz = (random.random()*0.05+0.05) 84 | camerai_mot = [cix, ciy, ciz] 85 | 86 | # Random Euler angles (scalars and signs). Zeros and small rotations are avoided as before 87 | aix = (random.random()*math.pi / 72.0 + math.pi / 88 | 72.0) * (sign*(-1)**random.randrange(2)) 89 | aiy = (random.random()*math.pi / 72.0 + math.pi / 90 | 72.0) * (sign*(-1)**random.randrange(2)) 91 | aiz = (random.random()*math.pi / 72.0 + math.pi / 92 | 72.0) * (sign*(-1)**random.randrange(2)) 93 | camerai_ang = [aix, aiy, aiz] 94 | camerai_ang = [0, 0, 0] 95 | 96 | ai = torch.from_numpy( 97 | np.array([[camerai_ang]], dtype=np.float32)).cuda() 98 | tri = torch.from_numpy(np.array([[camerai_mot]])).cuda() 99 | 100 | # Compute (R|t) 101 | Ti = transformation_from_parameters( 102 | axisangle + ai, translation + tri) 103 | 104 | # Apply transformation T_{0->\pi_i} 105 | pi, zi = project_3d(cam_points, K, Ti) 106 | 107 | # If a pixel belongs to object label l, replace coordinates in I1... 108 | p1[instance_mask > 0] = pi[instance_mask > 0] 109 | 110 | # ... 
and its depth 111 | zi = zi.reshape(1, h, w) 112 | z1[instance_mask[:, :, :, 0] > 0] = zi[instance_mask[:, :, :, 0] > 0] 113 | 114 | # Bring p1 coordinates in [0,W-1]x[0,H-1] format 115 | p1 = (p1 + 1) / 2 116 | p1[:, :, :, 0] *= w - 1 117 | p1[:, :, :, 1] *= h - 1 118 | 119 | # Create auxiliary data for warping 120 | dlut = torch.ones(1, h, w).float().cuda() * 1000 121 | safe_y = np.maximum(np.minimum(p1[:, :, :, 1].cpu().long(), h - 1), 0) 122 | safe_x = np.maximum(np.minimum(p1[:, :, :, 0].cpu().long(), w - 1), 0) 123 | warped_arr = np.zeros(h*w*5).astype(np.uint8) 124 | img = rgb.reshape(-1).to(torch.uint8) 125 | 126 | # Call forward warping routine (C code) 127 | warp(c_void_p(img.cpu().numpy().ctypes.data), c_void_p(safe_x[0].cpu().numpy().ctypes.data), 128 | c_void_p(safe_y[0].cpu().numpy().ctypes.data), c_void_p(z1.reshape(-1).cpu().numpy().ctypes.data), 129 | c_void_p(warped_arr.ctypes.data), c_int(h), c_int(w)) 130 | warped_arr = warped_arr.reshape(1, h, w, 5).astype(np.uint8) 131 | 132 | # Warped image 133 | im1_raw = warped_arr[0, :, :, 0:3] 134 | 135 | # Validity mask H 136 | masks["H"] = warped_arr[0, :, :, 3:4] 137 | 138 | # Collision mask M 139 | masks["M"] = warped_arr[0, :, :, 4:5] 140 | # Keep all pixels that are invalid (H) or collide (M) 141 | masks["M"] = 1-(masks["M"] == masks["H"]).astype(np.uint8) 142 | 143 | # Dilated collision mask M' 144 | kernel = np.ones((3, 3), np.uint8) 145 | masks["M'"] = cv2.dilate(masks["M"], kernel, iterations=1) 146 | masks["P"] = (np.expand_dims(masks["M'"], -1) 147 | == masks["M"]).astype(np.uint8) 148 | 149 | # Final mask P 150 | masks["H'"] = masks["H"]*masks["P"] 151 | 152 | # Compute flow as p1-p0 153 | flow_01 = p1.cpu().numpy() - p0 154 | im1 = rgb[0].cpu().numpy().copy() 155 | # mask_idx = np.logical_and( 156 | # flow_01[0, :, :, 0] > 1, 157 | # flow_01[0, :, :, 1] > 1 158 | # ) 159 | mask_idx = np.where(instance_mask[0, :, :, 0].cpu().numpy()) 160 | # mask_xp = mask_x, mask_y 161 | 162 | im1 = cv2.inpaint(im1_raw, 1 - masks["H"], 3, cv2.INPAINT_TELEA) 163 | flow_color = flow_to_color(flow_01[0], convert_to_bgr=True) 164 | mask = cv2.merge([masks["H"]*255, masks["H"]*255, masks["H"]*255]) 165 | res = np.vstack( 166 | [rgb[0].cpu().numpy(), im1, im1_raw, mask, flow_color] 167 | ) 168 | cv2.imwrite('temp/res-{:06d}.png'.format(i), res) -------------------------------------------------------------------------------- /scripts/gen_coco.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=5 python gen_3dphoto_dynamic_coco.py \ 2 | --base ../dataset/Flow/extra/coco/outputs/ \ 3 | --out ../dataset/Flow/extra/coco/MPI-Flow-data \ 4 | --repeat 2 -------------------------------------------------------------------------------- /scripts/gen_test_kitti15.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 python gen_3dphoto_dynamic.py \ 2 | --base ../dataset/Flow/testing/outputs/ \ 3 | --out ../dataset/Flow/testing/MPI-Flow-data \ 4 | --repeat 5 -------------------------------------------------------------------------------- /scripts/gen_train_kitti15.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=3 python gen_3dphoto_dynamic.py \ 2 | --base ../dataset/Flow/training/outputs/ \ 3 | --out ./dataset/debug \ 4 | --repeat 2 --seed 0 -------------------------------------------------------------------------------- /scripts/gen_train_kitti15_v2.sh: 
-------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 python gen_3dphoto_dynamic_v2.py \ 2 | --base ../dataset/Flow/training/outputs/ \ 3 | --out /data/liangyingping/debug_0.35 \ 4 | --repeat 2 --seed 0 --ext_cz 0.25 -------------------------------------------------------------------------------- /utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | from utils.arrow import arrowon 20 | import cv2 21 | 22 | def make_colorwheel(): 23 | """ 24 | Generates a color wheel for optical flow visualization as presented in: 25 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 26 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 27 | 28 | Code follows the original C++ source code of Daniel Scharstein. 29 | Code follows the the Matlab source code of Deqing Sun. 30 | 31 | Returns: 32 | np.ndarray: Color wheel 33 | """ 34 | 35 | RY = 15 36 | YG = 6 37 | GC = 4 38 | CB = 11 39 | BM = 13 40 | MR = 6 41 | 42 | ncols = RY + YG + GC + CB + BM + MR 43 | colorwheel = np.zeros((ncols, 3)) 44 | col = 0 45 | 46 | # RY 47 | colorwheel[0:RY, 0] = 255 48 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 49 | col = col+RY 50 | # YG 51 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 52 | colorwheel[col:col+YG, 1] = 255 53 | col = col+YG 54 | # GC 55 | colorwheel[col:col+GC, 1] = 255 56 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 57 | col = col+GC 58 | # CB 59 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 60 | colorwheel[col:col+CB, 2] = 255 61 | col = col+CB 62 | # BM 63 | colorwheel[col:col+BM, 2] = 255 64 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 65 | col = col+BM 66 | # MR 67 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 68 | colorwheel[col:col+MR, 0] = 255 69 | return colorwheel 70 | 71 | 72 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 73 | """ 74 | Applies the flow color wheel to (possibly clipped) flow components u and v. 75 | 76 | According to the C++ source code of Daniel Scharstein 77 | According to the Matlab source code of Deqing Sun 78 | 79 | Args: 80 | u (np.ndarray): Input horizontal flow of shape [H,W] 81 | v (np.ndarray): Input vertical flow of shape [H,W] 82 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
83 | 84 | Returns: 85 | np.ndarray: Flow visualization image of shape [H,W,3] 86 | """ 87 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 88 | colorwheel = make_colorwheel() # shape [55x3] 89 | ncols = colorwheel.shape[0] 90 | rad = np.sqrt(np.square(u) + np.square(v)) 91 | a = np.arctan2(-v, -u)/np.pi 92 | fk = (a+1) / 2*(ncols-1) 93 | k0 = np.floor(fk).astype(np.int32) 94 | k1 = k0 + 1 95 | k1[k1 == ncols] = 0 96 | f = fk - k0 97 | for i in range(colorwheel.shape[1]): 98 | tmp = colorwheel[:,i] 99 | col0 = tmp[k0] / 255.0 100 | col1 = tmp[k1] / 255.0 101 | col = (1-f)*col0 + f*col1 102 | idx = (rad <= 1) 103 | col[idx] = 1 - rad[idx] * (1-col[idx]) 104 | col[~idx] = col[~idx] * 0.75 # out of range 105 | # Note the 2-i => BGR instead of RGB 106 | ch_idx = 2-i if convert_to_bgr else i 107 | flow_image[:,:,ch_idx] = np.floor(255 * col) 108 | return flow_image 109 | 110 | 111 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 112 | """ 113 | Expects a two dimensional flow image of shape. 114 | 115 | Args: 116 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 117 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 118 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 119 | 120 | Returns: 121 | np.ndarray: Flow visualization image of shape [H,W,3] 122 | """ 123 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 124 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 125 | if clip_flow is not None: 126 | flow_uv = np.clip(flow_uv, 0, clip_flow) 127 | u = flow_uv[:,:,0] 128 | v = flow_uv[:,:,1] 129 | rad = np.sqrt(np.square(u) + np.square(v)) 130 | rad_max = np.max(rad) 131 | epsilon = 1e-5 132 | u = u / (rad_max + epsilon) 133 | v = v / (rad_max + epsilon) 134 | return flow_uv_to_colors(u, v, convert_to_bgr) 135 | 136 | def viz_batch_mask_np(imgs, flos, fusions=None, masks=None, arrow_step=32, save_path='tmp.jpg'): 137 | ''' 138 | input: imgs = [image1, image2] 139 | fusions = [1->2, 2->1] 140 | flos = [1->2, 2->1] 141 | masks = [1->2, 2->1] 142 | flow_past image2_arrow_past fusion1 mask_past 143 | flow_future image1_arrow_future fusion2 mask_future 144 | ''' 145 | image1 = np.array(imgs[0]) 146 | image2 = np.array(imgs[1]) 147 | 148 | show_img_past = [] 149 | #show_img_past.append(image1) 150 | img2_past = arrowon(image1, flos[0], arrow_step) 151 | show_img_past.append(img2_past) 152 | show_img_past.append(flow_to_image(flos[0])) 153 | if fusions is not None: 154 | show_img_past.append(fusions[0]) 155 | if masks is not None: 156 | show_img_past.append(np.tile(masks[0]*255, (1, 1, 3))) 157 | 158 | show_img_future = [] 159 | #show_img_future.append(image2) 160 | img2_past = arrowon(image2, flos[1], arrow_step) 161 | show_img_future.append(img2_past) 162 | show_img_future.append(flow_to_image(flos[1])) 163 | if fusions is not None: 164 | show_img_future.append(fusions[1]) 165 | if masks is not None: 166 | show_img_future.append(np.tile(masks[1]*255, (1, 1, 3))) 167 | 168 | #img_flo = np.concatenate(show_img, axis=1) 169 | show_past = np.concatenate(show_img_past, axis=1) 170 | show_future = np.concatenate(show_img_future, axis=1) 171 | show_img = np.concatenate([show_past, show_future], axis=0) 172 | 173 | cv2.imwrite(save_path, show_img[:, :, [2,1,0]]) 174 | 175 | def viz_batch_mask(imgs, fusions, flos, masks, save_path): 176 | ''' 177 | input: imgs = [image1, image2, image3] 178 | fusions = [1->2, 3->2] 179 | flos = [flow_past, flow_future] 180 | masks = [mask_past, 
mask_future] 181 | image1 image2_arrow_past fusion1 flow_past mask_past 182 | image3 image2_arrow_future fusion2 flow_future mask_future 183 | ''' 184 | image2 = imgs[1][0].permute(1,2,0).cpu().numpy() 185 | 186 | show_img_past = [] 187 | show_img_past.append(imgs[0][0].permute(1,2,0).cpu().numpy()) 188 | img2_past = arrowon(image2, flos[0][0], 32) 189 | show_img_past.append(img2_past) 190 | show_img_past.append(fusions[0][0].permute(1,2,0).cpu().numpy()) 191 | show_img_past.append(flow_to_image(flos[0][0].permute(1,2,0).cpu().numpy())) 192 | show_img_past.append(np.tile(masks[0][0].permute(1,2,0).cpu().numpy()*255, (1, 1, 3))) 193 | 194 | show_img_future = [] 195 | show_img_future.append(imgs[2][0].permute(1,2,0).cpu().numpy()) 196 | img2_future = arrowon(image2, flos[1][0], 32) 197 | show_img_future.append(img2_future) 198 | show_img_future.append(fusions[1][0].permute(1,2,0).cpu().numpy()) 199 | show_img_future.append(flow_to_image(flos[1][0].permute(1,2,0).cpu().numpy())) 200 | show_img_future.append(np.tile(masks[1][0].permute(1,2,0).cpu().numpy()*255, (1, 1, 3))) 201 | 202 | #img_flo = np.concatenate(show_img, axis=1) 203 | show_past = np.concatenate(show_img_past, axis=1) 204 | show_future = np.concatenate(show_img_future, axis=1) 205 | show_img = np.concatenate([show_past, show_future], axis=0) 206 | 207 | cv2.imwrite(save_path, show_img[:, :, [2,1,0]]) 208 | 209 | def viz_batch(imgs, flo, save_path): 210 | show_img = [] 211 | for img in imgs: 212 | img = img[0].permute(1,2,0).cpu().numpy() 213 | show_img.append(img) 214 | 215 | flo = flo[0].permute(1,2,0).cpu().numpy() 216 | show_img[0] = arrowon(show_img[0], flo, 32) 217 | 218 | # map flow to rgb image 219 | flo = flow_to_image(flo) 220 | show_img.append(flo) 221 | img_flo = np.concatenate(show_img, axis=1) 222 | 223 | cv2.imwrite(save_path, img_flo[:, :, [2,1,0]]) 224 | -------------------------------------------------------------------------------- /utils/mpi/rendering_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def transform_G_xyz(G, xyz, is_return_homo=False): 5 | """ 6 | 7 | :param G: Bx4x4 8 | :param xyz: Bx3xN 9 | :return: 10 | """ 11 | assert len(G.size()) == len(xyz.size()) 12 | if len(G.size()) == 2: 13 | G_B44 = G.unsqueeze(0) 14 | xyz_B3N = xyz.unsqueeze(0) 15 | else: 16 | G_B44 = G 17 | xyz_B3N = xyz 18 | xyz_B4N = torch.cat((xyz_B3N, torch.ones_like(xyz_B3N[:, 0:1, :])), dim=1) 19 | G_xyz_B4N = torch.matmul(G_B44, xyz_B4N) 20 | if is_return_homo: 21 | return G_xyz_B4N 22 | else: 23 | return G_xyz_B4N[:, 0:3, :] 24 | 25 | 26 | def gather_pixel_by_pxpy(img, pxpy): 27 | """ 28 | 29 | :param img: Bx3xHxW 30 | :param pxpy: Bx2xN 31 | :return: 32 | """ 33 | with torch.no_grad(): 34 | B, C, H, W = img.size() 35 | if pxpy.dtype == torch.float32: 36 | pxpy_int = torch.round(pxpy).to(torch.int64) 37 | pxpy_int = pxpy_int.to(torch.int64) 38 | pxpy_int[:, 0, :] = torch.clamp(pxpy_int[:, 0, :], min=0, max=W-1) 39 | pxpy_int[:, 1, :] = torch.clamp(pxpy_int[:, 1, :], min=0, max=H-1) 40 | pxpy_idx = pxpy_int[:, 0:1, :] + W * pxpy_int[:, 1:2, :] # Bx1xN_pt 41 | rgb = torch.gather(img.view(B, C, H * W), dim=2, 42 | index=pxpy_idx.repeat(1, C, 1)) # BxCxN_pt 43 | return rgb 44 | 45 | 46 | def uniformly_sample_disparity_from_bins(batch_size, disparity_np, device): 47 | """ 48 | In the disparity dimension, it has to be from large to small, i.e., depth from small (near) to large (far) 49 | :param start: 50 | :param end: 51 | :param num_bins: 52 | :return: 53 | 
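        a BxS tensor of disparities, one value drawn uniformly at random from
        each of the S bins defined by the S+1 edges in disparity_np
        (stratified sampling, one sample per MPI plane)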
""" 54 | assert disparity_np[0] > disparity_np[-1] 55 | S = disparity_np.shape[0] - 1 56 | 57 | B = batch_size 58 | bin_edges = torch.from_numpy(disparity_np).to(dtype=torch.float32, device=device) # S+1 59 | interval = bin_edges[1:] - bin_edges[0:-1] # S 60 | bin_edges_start = bin_edges[0:-1].unsqueeze(0).repeat(B, 1) # S -> BxS 61 | # bin_edges_end = bin_edges[1:].unsqueeze(0).repeat(B, 1) # S -> BxS 62 | interval = interval.unsqueeze(0).repeat(B, 1) # S -> BxS 63 | 64 | random_float = torch.rand((B, S), dtype=torch.float32, device=device) # BxS 65 | disparity_array = bin_edges_start + interval * random_float 66 | return disparity_array # BxS 67 | 68 | 69 | def uniformly_sample_disparity_from_linspace_bins(batch_size, num_bins, start, end, device): 70 | """ 71 | In the disparity dimension, it has to be from large to small, i.e., depth from small (near) to large (far) 72 | :param start: 73 | :param end: 74 | :param num_bins: 75 | :return: 76 | """ 77 | assert start > end 78 | 79 | B, S = batch_size, num_bins 80 | bin_edges = torch.linspace(start, end, num_bins+1, dtype=torch.float32, device=device) # S+1 81 | interval = bin_edges[1] - bin_edges[0] # scalar 82 | bin_edges_start = bin_edges[0:-1].unsqueeze(0).repeat(B, 1) # S -> BxS 83 | # bin_edges_end = bin_edges[1:].unsqueeze(0).repeat(B, 1) # S -> BxS 84 | 85 | random_float = torch.rand((B, S), dtype=torch.float32, device=device) # BxS 86 | disparity_array = bin_edges_start + interval * random_float 87 | return disparity_array # BxS 88 | 89 | 90 | def sample_pdf(values, weights, N_samples): 91 | """ 92 | draw samples from distribution approximated by values and weights. 93 | the probability distribution can be denoted as weights = p(values) 94 | :param values: Bx1xNxS 95 | :param weights: Bx1xNxS 96 | :param N_samples: number of sample to draw 97 | :return: 98 | """ 99 | B, N, S = weights.size(0), weights.size(2), weights.size(3) 100 | assert values.size() == (B, 1, N, S) 101 | 102 | # convert values to bin edges 103 | bin_edges = (values[:, :, :, 1:] + values[:, :, :, :-1]) * 0.5 # Bx1xNxS-1 104 | bin_edges = torch.cat((values[:, :, :, 0:1], 105 | bin_edges, 106 | values[:, :, :, -1:]), dim=3) # Bx1xNxS+1 107 | 108 | pdf = weights / (torch.sum(weights, dim=3, keepdim=True) + 1e-5) # Bx1xNxS 109 | cdf = torch.cumsum(pdf, dim=3) # Bx1xNxS 110 | cdf = torch.cat((torch.zeros((B, 1, N, 1), dtype=cdf.dtype, device=cdf.device), 111 | cdf), dim=3) # Bx1xNxS+1 112 | 113 | # uniform sample over the cdf values 114 | u = torch.rand((B, 1, N, N_samples), dtype=weights.dtype, device=weights.device) # Bx1xNxN_samples 115 | 116 | # get the index on the cdf array 117 | cdf_idx = torch.searchsorted(cdf, u, right=True) # Bx1xNxN_samples 118 | cdf_idx_lower = torch.clamp(cdf_idx-1, min=0) # Bx1xNxN_samples 119 | cdf_idx_upper = torch.clamp(cdf_idx, max=S) # Bx1xNxN_samples 120 | 121 | # linear approximation for each bin 122 | cdf_idx_lower_upper = torch.cat((cdf_idx_lower, cdf_idx_upper), dim=3) # Bx1xNx(N_samplesx2) 123 | cdf_bounds_N2 = torch.gather(cdf, index=cdf_idx_lower_upper, dim=3) # Bx1xNx(N_samplesx2) 124 | cdf_bounds = torch.stack((cdf_bounds_N2[..., 0:N_samples], cdf_bounds_N2[..., N_samples:]), dim=4) 125 | bin_bounds_N2 = torch.gather(bin_edges, index=cdf_idx_lower_upper, dim=3) # Bx1xNx(N_samplesx2) 126 | bin_bounds = torch.stack((bin_bounds_N2[..., 0:N_samples], bin_bounds_N2[..., N_samples:]), dim=4) 127 | 128 | # avoid zero cdf_intervals 129 | cdf_intervals = cdf_bounds[:, :, :, :, 1] - cdf_bounds[:, :, :, :, 0] # Bx1xNxN_samples 130 | 
bin_intervals = bin_bounds[:, :, :, :, 1] - bin_bounds[:, :, :, :, 0] # Bx1xNxN_samples 131 | u_cdf_lower = u - cdf_bounds[:, :, :, :, 0] # Bx1xNxN_samples 132 | # there is the case that cdf_interval = 0, caused by the cdf_idx_lower/upper clamp above, need special handling 133 | t = u_cdf_lower / torch.clamp(cdf_intervals, min=1e-5) 134 | t = torch.where(cdf_intervals <= 1e-4, 135 | torch.full_like(u_cdf_lower, 0.5), 136 | t) 137 | 138 | samples = bin_bounds[:, :, :, :, 0] + t*bin_intervals 139 | return samples 140 | -------------------------------------------------------------------------------- /utils/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import glob 4 | from utils import flow_viz 5 | import torch 6 | from torch.nn import functional as F 7 | 8 | def gen_random_perspective(): 9 | ''' 10 | generate a random 3x3 perspective matrix 11 | ''' 12 | init_M = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) 13 | noise = np.random.normal(0, 0.0001, 9) 14 | noise = np.reshape(noise, [3, 3]) 15 | noise[2, 2] = 0 16 | return init_M + noise 17 | 18 | def get_flow(img, M): 19 | ''' 20 | use img shape and M to calculate flow 21 | return flow 22 | ''' 23 | ## calculate flow 24 | x = np.linspace(0, img.shape[1]-1, img.shape[1]) 25 | y = np.linspace(0, img.shape[0]-1, img.shape[0]) 26 | xx, yy = np.meshgrid(x, y) 27 | coords = np.stack([xx, yy, np.ones_like(xx)], axis=0) 28 | #new_coords = np.einsum('ij,jkl->ikl', M, np.transpose(coords, (2, 0, 1))) 29 | new_coords = np.einsum('ij,jkl->ikl', M, coords) 30 | xx2 = new_coords[0, :, :] / new_coords[2, :, :] 31 | yy2 = new_coords[1, :, :] / new_coords[2, :, :] 32 | #xx2 = xx2 * img.shape[1] 33 | #yy2 = yy2 * img.shape[0] 34 | #import pdb; pdb.set_trace() 35 | xx2 = xx2.astype(np.float32) 36 | yy2 = yy2.astype(np.float32) 37 | flow_x = xx2-xx #* img.shape[1] 38 | flow_y = yy2-yy #* img.shape[0] 39 | flow_x = flow_x.astype(np.float32) 40 | flow_y = flow_y.astype(np.float32) 41 | return np.stack([flow_x, flow_y], axis=2) 42 | 43 | def transform(img, flow): 44 | ''' 45 | remap image according to the M. 
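(the 3x3 matrix M is only used upstream in get_flow; here the image is resampled with
cv2.remap using the given per-pixel flow, i.e. a backward-warp style lookup)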
46 | return warped img and flow 47 | ''' 48 | 49 | flow = flow.astype(np.float32) 50 | flow_x = flow[:, :, 0] 51 | flow_y = flow[:, :, 1] 52 | 53 | ## warp img by flow 54 | fh, fw = flow_x.shape 55 | add = np.mgrid[0:fh,0:fw].astype(np.float32); 56 | 57 | img_flow = cv2.remap(img, flow_y+add[1,:,:], flow_x+add[0,:,:], cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101) 58 | return img_flow 59 | 60 | def warp(x, flo): 61 | """ 62 | warp an image/tensor (im2) back to im1, according to the optical flow 63 | 64 | x: [B, C, H, W] (im2) 65 | flo: [B, 2, H, W] flow 66 | 67 | """ 68 | B, C, H, W = x.size() 69 | # mesh grid 70 | xx = torch.arange(0, W).view(1 ,-1).repeat(H ,1) 71 | yy = torch.arange(0, H).view(-1 ,1).repeat(1 ,W) 72 | xx = xx.view(1 ,1 ,H ,W).repeat(B ,1 ,1 ,1) 73 | yy = yy.view(1 ,1 ,H ,W).repeat(B ,1 ,1 ,1) 74 | grid = torch.cat((xx ,yy) ,1).float() 75 | if x.is_cuda: 76 | grid = grid.cuda() 77 | vgrid = torch.autograd.Variable(grid).detach() + flo 78 | 79 | # scale grid to [-1,1] 80 | vgrid[: ,0 ,: ,:] = 2.0 *vgrid[: ,0 ,: ,:].clone() / max( W -1 ,1 ) -1.0 81 | vgrid[: ,1 ,: ,:] = 2.0 *vgrid[: ,1 ,: ,:].clone() / max( H -1 ,1 ) -1.0 82 | 83 | vgrid = vgrid.permute(0 ,2 ,3 ,1) 84 | flo = flo.permute(0 ,2 ,3 ,1) 85 | output = F.grid_sample(x, vgrid) 86 | mask = torch.autograd.Variable(torch.ones(x.size())).cuda() 87 | mask = F.grid_sample(mask, vgrid).detach() 88 | 89 | #mask[mask <0.9999] = 0 90 | #mask[mask >0] = 1 91 | 92 | return output, mask 93 | 94 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 95 | backwarp_tenGrid = {} 96 | 97 | def warp_rife(tenInput, tenFlow): 98 | k = (str(tenFlow.device), str(tenFlow.size())) 99 | if k not in backwarp_tenGrid: 100 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view( 101 | 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) 102 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view( 103 | 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) 104 | backwarp_tenGrid[k] = torch.cat( 105 | [tenHorizontal, tenVertical], 1).to(device) 106 | 107 | tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), 108 | tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) 109 | 110 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) 111 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) 112 | 113 | if __name__ == "__main__": 114 | 115 | import sys 116 | #img_path = '/share/boyuan/Data/snow_good_imgs/01e1fbc1a563c467018370037ebf6cd3ae_258/000006.png' 117 | img_path = '/share/boyuan/Data/snow_good_imgs/skiing_0825/000006.png' 118 | seg_prefix = '/share/boyuan/Projects/mmdetection/output/skiing_0825/000006' 119 | img = cv2.imread(img_path) 120 | 121 | #all_mask = np.zeros((img.shape[0], img.shape[1])) 122 | res_flow = np.zeros((img.shape[0], img.shape[1], 2)) 123 | for m_path in glob.glob(seg_prefix+"-*"): 124 | sub_mask = cv2.imread(m_path)[:, :, 0] 125 | sub_mask = sub_mask / 255 126 | sub_mask = sub_mask.astype(np.uint8) 127 | M = gen_random_perspective() 128 | sub_flow = get_flow(img, M) 129 | res_flow[np.where(sub_mask==1)] = sub_flow[np.where(sub_mask==1)] 130 | 131 | res_img = transform(img, res_flow) 132 | 133 | img_flo = flow_viz.flow_to_image(res_flow) 134 | out = np.concatenate([img, img_flo, res_img], axis=1) 135 | print(np.max(res_flow)) 136 | cv2.imwrite('tmp/warp.jpg', out) 137 | 
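# NOTE: the output path below is machine-specific; adjust or remove it before running this demo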
cv2.imwrite('/share/boyuan/Projects/RAFT/tmp/skiing_per/000007.png', res_img) 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /vis_flow.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | 5 | def readFlow(fn): 6 | """ Read .flo file in Middlebury format""" 7 | # Code adapted from: 8 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 9 | 10 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 11 | # print 'fn = %s'%(fn) 12 | with open(fn, 'rb') as f: 13 | magic = np.fromfile(f, np.float32, count=1) 14 | if 202021.25 != magic: 15 | print('Magic number incorrect. Invalid .flo file') 16 | return None 17 | else: 18 | w = np.fromfile(f, np.int32, count=1) 19 | h = np.fromfile(f, np.int32, count=1) 20 | # print 'Reading %d x %d flo file\n' % (w, h) 21 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 22 | # Reshape data into 3D array (columns, rows, bands) 23 | # The reshape here is for visualization, the original code is (w,h,2) 24 | return np.resize(data, (int(h), int(w), 2)) 25 | 26 | base = "dataset/debug" 27 | 28 | if not os.path.exists(os.path.join(base, "vis")): 29 | os.mkdir(os.path.join(base, "vis")) 30 | 31 | for img in os.listdir(os.path.join(base, "src_images")): 32 | for r in range(4): 33 | image1 = cv2.imread(os.path.join(base, "src_images", img)) 34 | image2 = cv2.imread(os.path.join(base, "dst_images", img.replace(".png", f"_{r}.png"))) 35 | flow= readFlow(os.path.join(base, "flows", img.replace(".png", f"_{r}.flo"))) 36 | print(flow.max(), flow.min(), flow.shape) 37 | 38 | H, W = image1.shape[:2] 39 | res = np.vstack([image1, image2]) 40 | 41 | for _ in range(30): 42 | 43 | x1 = np.random.randint(W) 44 | y1 = np.random.randint(H) 45 | x2 = x1 + int(flow[y1, x1, 0]) 46 | y2 = y1 + int(flow[y1, x1, 1]) + H 47 | 48 | cv2.line(res, (x1, y1), (x2, y2), (0,255,0), 2) 49 | 50 | cv2.imwrite(os.path.join(base, "vis", img.replace(".png", f"_{r}.png")), res) -------------------------------------------------------------------------------- /warpback/networks.py: -------------------------------------------------------------------------------- 1 | ''' 2 | this code is adapt from the EdgeConnect repo (https://github.com/knazeri/edge-connect) 3 | ''' 4 | 5 | 6 | import torch 7 | import torch.nn as nn 8 | import os 9 | 10 | 11 | def get_edge_connect(weight_dir): 12 | inpaint_model = InpaintGenerator() 13 | inpaint_model_weight = torch.load(os.path.join(weight_dir, "InpaintingModel_gen.pth")) 14 | inpaint_model.load_state_dict(inpaint_model_weight["generator"]) 15 | inpaint_model.eval() 16 | 17 | edge_model = EdgeGenerator() 18 | edge_model_weight = torch.load(os.path.join(weight_dir, "EdgeModel_gen.pth")) 19 | edge_model.load_state_dict(edge_model_weight["generator"]) 20 | edge_model.eval() 21 | 22 | disp_model = InpaintGenerator(in_channels=2, out_channels=1) 23 | disp_model_weight = torch.load(os.path.join(weight_dir, "InpaintingModel_disp.pth")) 24 | disp_model.load_state_dict(disp_model_weight["generator"]) 25 | disp_model.eval() 26 | return edge_model, inpaint_model, disp_model 27 | 28 | 29 | class BaseNetwork(nn.Module): 30 | def __init__(self): 31 | super(BaseNetwork, self).__init__() 32 | 33 | def init_weights(self, init_type='normal', gain=0.02): 34 | ''' 35 | initialize network's weights 36 | init_type: normal | xavier | kaiming | orthogonal 37 | 
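gain: std for 'normal' (and BatchNorm) init, gain for 'xavier'/'orthogonal'; unused by 'kaiming'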
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 38 | ''' 39 | 40 | def init_func(m): 41 | classname = m.__class__.__name__ 42 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 43 | if init_type == 'normal': 44 | nn.init.normal_(m.weight.data, 0.0, gain) 45 | elif init_type == 'xavier': 46 | nn.init.xavier_normal_(m.weight.data, gain=gain) 47 | elif init_type == 'kaiming': 48 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 49 | elif init_type == 'orthogonal': 50 | nn.init.orthogonal_(m.weight.data, gain=gain) 51 | 52 | if hasattr(m, 'bias') and m.bias is not None: 53 | nn.init.constant_(m.bias.data, 0.0) 54 | 55 | elif classname.find('BatchNorm2d') != -1: 56 | nn.init.normal_(m.weight.data, 1.0, gain) 57 | nn.init.constant_(m.bias.data, 0.0) 58 | 59 | self.apply(init_func) 60 | 61 | 62 | class InpaintGenerator(BaseNetwork): 63 | def __init__(self, residual_blocks=8, init_weights=True, in_channels=4, out_channels=3): 64 | super(InpaintGenerator, self).__init__() 65 | 66 | self.encoder = nn.Sequential( 67 | nn.ReflectionPad2d(3), 68 | nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, padding=0), 69 | nn.InstanceNorm2d(64, track_running_stats=False), 70 | nn.ReLU(True), 71 | 72 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), 73 | nn.InstanceNorm2d(128, track_running_stats=False), 74 | nn.ReLU(True), 75 | 76 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), 77 | nn.InstanceNorm2d(256, track_running_stats=False), 78 | nn.ReLU(True) 79 | ) 80 | 81 | blocks = [] 82 | for _ in range(residual_blocks): 83 | block = ResnetBlock(256, 2) 84 | blocks.append(block) 85 | 86 | self.middle = nn.Sequential(*blocks) 87 | 88 | self.decoder = nn.Sequential( 89 | nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), 90 | nn.InstanceNorm2d(128, track_running_stats=False), 91 | nn.ReLU(True), 92 | 93 | nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), 94 | nn.InstanceNorm2d(64, track_running_stats=False), 95 | nn.ReLU(True), 96 | 97 | nn.ReflectionPad2d(3), 98 | nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=7, padding=0), 99 | ) 100 | 101 | if init_weights: 102 | self.init_weights() 103 | 104 | def forward(self, x): 105 | x = self.encoder(x) 106 | x = self.middle(x) 107 | x = self.decoder(x) 108 | x = (torch.tanh(x) + 1) / 2 109 | 110 | return x 111 | 112 | 113 | class EdgeGenerator(BaseNetwork): 114 | def __init__(self, residual_blocks=8, use_spectral_norm=True, init_weights=True): 115 | super(EdgeGenerator, self).__init__() 116 | 117 | self.encoder = nn.Sequential( 118 | nn.ReflectionPad2d(3), 119 | spectral_norm(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding=0), use_spectral_norm), 120 | nn.InstanceNorm2d(64, track_running_stats=False), 121 | nn.ReLU(True), 122 | 123 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm), 124 | nn.InstanceNorm2d(128, track_running_stats=False), 125 | nn.ReLU(True), 126 | 127 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), use_spectral_norm), 128 | nn.InstanceNorm2d(256, track_running_stats=False), 129 | nn.ReLU(True) 130 | ) 131 | 132 | blocks = [] 133 | for _ in range(residual_blocks): 134 | block = ResnetBlock(256, 2, 
use_spectral_norm=use_spectral_norm) 135 | blocks.append(block) 136 | 137 | self.middle = nn.Sequential(*blocks) 138 | 139 | self.decoder = nn.Sequential( 140 | spectral_norm(nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm), 141 | nn.InstanceNorm2d(128, track_running_stats=False), 142 | nn.ReLU(True), 143 | 144 | spectral_norm(nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), use_spectral_norm), 145 | nn.InstanceNorm2d(64, track_running_stats=False), 146 | nn.ReLU(True), 147 | 148 | nn.ReflectionPad2d(3), 149 | nn.Conv2d(in_channels=64, out_channels=1, kernel_size=7, padding=0), 150 | ) 151 | 152 | if init_weights: 153 | self.init_weights() 154 | 155 | def forward(self, x): 156 | x = self.encoder(x) 157 | x = self.middle(x) 158 | x = self.decoder(x) 159 | x = torch.sigmoid(x) 160 | return x 161 | 162 | 163 | class ResnetBlock(nn.Module): 164 | def __init__(self, dim, dilation=1, use_spectral_norm=False): 165 | super(ResnetBlock, self).__init__() 166 | self.conv_block = nn.Sequential( 167 | nn.ReflectionPad2d(dilation), 168 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm), 169 | nn.InstanceNorm2d(dim, track_running_stats=False), 170 | nn.ReLU(True), 171 | 172 | nn.ReflectionPad2d(1), 173 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm), 174 | nn.InstanceNorm2d(dim, track_running_stats=False), 175 | ) 176 | 177 | def forward(self, x): 178 | out = x + self.conv_block(x) 179 | return out 180 | 181 | 182 | def spectral_norm(module, mode=True): 183 | if mode: 184 | return nn.utils.spectral_norm(module) 185 | 186 | return module -------------------------------------------------------------------------------- /warpback/stage1_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | sys.path.append("..") 4 | import os 5 | import glob 6 | import math 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.utils.data.dataset import Dataset 10 | from torch.utils.data.dataloader import DataLoader, default_collate 11 | from torchvision.utils import save_image 12 | 13 | from warpback.utils import ( 14 | RGBDRenderer, 15 | image_to_tensor, 16 | disparity_to_tensor, 17 | transformation_from_parameters, 18 | ) 19 | 20 | 21 | class WarpBackStage1Dataset(Dataset): 22 | def __init__( 23 | self, 24 | data_root, 25 | width=384, 26 | height=256, 27 | depth_dir_name="dpt_depth", 28 | device="cuda", # device of mesh renderer 29 | trans_range={"x":0.2, "y":-1, "z":-1, "a":-1, "b":-1, "c":-1}, # xyz for translation, abc for euler angle 30 | ): 31 | self.data_root = data_root 32 | self.depth_dir_name = depth_dir_name 33 | self.renderer = RGBDRenderer(device) 34 | self.width = width 35 | self.height = height 36 | self.device = device 37 | self.trans_range = trans_range 38 | self.image_path_list = glob.glob(os.path.join(self.data_root, "*.jpg")) 39 | self.image_path_list += glob.glob(os.path.join(self.data_root, "*.png")) 40 | 41 | # set intrinsics 42 | self.K = torch.tensor([ 43 | [0.58, 0, 0.5], 44 | [0, 0.58, 0.5], 45 | [0, 0, 1] 46 | ]).to(device) 47 | 48 | def __len__(self): 49 | return len(self.image_path_list) 50 | 51 | def __getitem__(self, idx): 52 | image_path = self.image_path_list[idx] 53 | image_name = 
os.path.splitext(os.path.basename(image_path))[0] 54 | disp_path = os.path.join(self.data_root, self.depth_dir_name, "%s.png" % image_name) 55 | 56 | image = image_to_tensor(image_path, unsqueeze=False) # [3,h,w] 57 | disp = disparity_to_tensor(disp_path, unsqueeze=False) # [1,h,w] 58 | 59 | # do some data augmentation, ensure the rgbd spatial resolution is (self.height, self.width) 60 | image, disp = self.preprocess_rgbd(image, disp) 61 | 62 | return image, disp 63 | 64 | def preprocess_rgbd(self, image, disp): 65 | # NOTE 66 | # (1) here we directly resize the image to the target size (self.height, self.width) 67 | # a better way is to first crop a random patch from the image according to the height-width ratio 68 | # then resize this patch to the target size 69 | # (2) another suggestion is, add some code to filter the depth map to reduce artifacts around 70 | # depth discontinuities 71 | image = F.interpolate(image.unsqueeze(0), (self.height, self.width), mode="bilinear").squeeze(0) 72 | disp = F.interpolate(disp.unsqueeze(0), (self.height, self.width), mode="bilinear").squeeze(0) 73 | return image, disp 74 | 75 | def get_rand_ext(self, bs): 76 | x, y, z = self.trans_range['x'], self.trans_range['y'], self.trans_range['z'] 77 | a, b, c = self.trans_range['a'], self.trans_range['b'], self.trans_range['c'] 78 | cix = self.rand_tensor(x, bs) 79 | ciy = self.rand_tensor(y, bs) 80 | ciz = self.rand_tensor(z, bs) 81 | aix = self.rand_tensor(math.pi / a, bs) 82 | aiy = self.rand_tensor(math.pi / b, bs) 83 | aiz = self.rand_tensor(math.pi / c, bs) 84 | 85 | axisangle = torch.cat([aix, aiy, aiz], dim=-1) # [b,1,3] 86 | translation = torch.cat([cix, ciy, ciz], dim=-1) 87 | 88 | cam_ext = transformation_from_parameters(axisangle, translation) # [b,4,4] 89 | cam_ext_inv = torch.inverse(cam_ext) # [b,4,4] 90 | return cam_ext[:, :-1], cam_ext_inv[:, :-1] 91 | 92 | def rand_tensor(self, r, l): 93 | ''' 94 | return a tensor of size [l], where each element is in range [-r,-r/2] or [r/2,r] 95 | ''' 96 | if r < 0: # we can set a negtive value in self.trans_range to avoid random transformation 97 | return torch.zeros((l, 1, 1)) 98 | rand = torch.rand((l, 1, 1)) 99 | sign = 2 * (torch.randn_like(rand) > 0).float() - 1 100 | return sign * (r / 2 + r / 2 * rand) 101 | 102 | def collect_data(self, batch): 103 | batch = default_collate(batch) 104 | image, disp = batch 105 | image = image.to(self.device) 106 | disp = disp.to(self.device) 107 | rgbd = torch.cat([image, disp], dim=1) # [b,4,h,w] 108 | b = image.shape[0] 109 | 110 | cam_int = self.K.repeat(b, 1, 1) # [b,3,3] 111 | cam_ext, cam_ext_inv = self.get_rand_ext(b) # [b,3,4] 112 | cam_ext = cam_ext.to(self.device) 113 | cam_ext_inv = cam_ext_inv.to(self.device) 114 | 115 | # warp to a random novel view 116 | mesh = self.renderer.construct_mesh(rgbd, cam_int) 117 | warp_image, warp_disp, warp_mask = self.renderer.render_mesh(mesh, cam_int, cam_ext) 118 | 119 | # warp back to the original view 120 | warp_rgbd = torch.cat([warp_image, warp_disp], dim=1) # [b,4,h,w] 121 | warp_mesh = self.renderer.construct_mesh(warp_rgbd, cam_int) 122 | warp_back_image, warp_back_disp, mask = self.renderer.render_mesh(warp_mesh, cam_int, cam_ext_inv) 123 | 124 | # NOTE 125 | # (1) to train the inpainting network, you only need image, disp, and mask 126 | # (2) you can add some morphological operation to refine the mask 127 | return { 128 | "rgb": image, 129 | "disp": disp, 130 | "mask": mask, 131 | "warp_rgb": warp_image, 132 | "warp_disp": warp_disp, 133 | "warp_back_rgb": 
warp_back_image, 134 | "warp_back_disp": warp_back_disp, 135 | } 136 | 137 | 138 | if __name__ == "__main__": 139 | bs = 8 140 | data = WarpBackStage1Dataset( 141 | data_root="warpback/toydata", 142 | ) 143 | loader = DataLoader( 144 | dataset=data, 145 | batch_size=bs, 146 | shuffle=True, 147 | collate_fn=data.collect_data, 148 | ) 149 | for idx, batch in enumerate(loader): 150 | image, disp, mask = batch["rgb"], batch["disp"], batch["mask"] 151 | w_image, w_disp = batch["warp_rgb"], batch["warp_disp"] 152 | wb_image, wb_disp = batch["warp_back_rgb"], batch["warp_back_disp"] 153 | visual = torch.cat([ 154 | image, 155 | disp.repeat(1, 3, 1, 1), 156 | mask.repeat(1, 3, 1, 1), 157 | wb_image, 158 | wb_disp.repeat(1, 3, 1, 1), 159 | w_image, 160 | w_disp.repeat(1, 3, 1, 1), 161 | ], dim=0) 162 | save_image(visual, "debug/stage1-%03d.jpg" % idx, nrow=bs) 163 | -------------------------------------------------------------------------------- /warpback/stage2_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | sys.path.append("..") 4 | import os 5 | import glob 6 | import math 7 | import numpy as np 8 | from skimage.feature import canny 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.data.dataset import Dataset 12 | from torch.utils.data.dataloader import DataLoader, default_collate 13 | from torchvision.utils import save_image 14 | from torchvision import transforms 15 | 16 | from warpback.utils import ( 17 | RGBDRenderer, 18 | image_to_tensor, 19 | disparity_to_tensor, 20 | transformation_from_parameters, 21 | ) 22 | from warpback.networks import get_edge_connect 23 | 24 | 25 | class WarpBackStage2Dataset(Dataset): 26 | def __init__( 27 | self, 28 | data_root, 29 | width=384, 30 | height=256, 31 | depth_dir_name="dpt_depth", 32 | device="cuda", # device of mesh renderer 33 | trans_range={"x":0.2, "y":-1, "z":-1, "a":-1, "b":-1, "c":-1}, # xyz for translation, abc for euler angle 34 | ec_weight_dir="warpback/ecweight", 35 | ): 36 | self.data_root = data_root 37 | self.depth_dir_name = depth_dir_name 38 | self.renderer = RGBDRenderer(device) 39 | self.width = width 40 | self.height = height 41 | self.device = device 42 | self.trans_range = trans_range 43 | self.image_path_list = glob.glob(os.path.join(self.data_root, "*.jpg")) 44 | self.image_path_list += glob.glob(os.path.join(self.data_root, "*.png")) 45 | 46 | # get Stage-1 pretrained inpainting network 47 | self.edge_model, self.inpaint_model, self.disp_model = get_edge_connect(ec_weight_dir) 48 | self.edge_model = self.edge_model.to(self.device) 49 | self.inpaint_model = self.inpaint_model.to(self.device) 50 | self.disp_model = self.disp_model.to(self.device) 51 | 52 | # set intrinsics 53 | self.K = torch.tensor([ 54 | [0.58, 0, 0.5], 55 | [0, 0.58, 0.5], 56 | [0, 0, 1] 57 | ]).to(device) 58 | 59 | def __len__(self): 60 | return len(self.image_path_list) 61 | 62 | def __getitem__(self, idx): 63 | image_path = self.image_path_list[idx] 64 | image_name = os.path.splitext(os.path.basename(image_path))[0] 65 | disp_path = os.path.join(self.data_root, self.depth_dir_name, "%s.png" % image_name) 66 | 67 | image = image_to_tensor(image_path, unsqueeze=False) # [3,h,w] 68 | disp = disparity_to_tensor(disp_path, unsqueeze=False) # [1,h,w] 69 | 70 | # do some data augmentation, ensure the rgbd spatial resolution is (self.height, self.width) 71 | image, disp = self.preprocess_rgbd(image, disp) 72 | 73 | return image, disp 74 | 75 | def 
preprocess_rgbd(self, image, disp): 76 | # NOTE 77 | # (1) here we directly resize the image to the target size (self.height, self.width) 78 | # a better way is to first crop a random patch from the image according to the height-width ratio 79 | # then resize this patch to the target size 80 | # (2) another suggestion is, add some code to filter the depth map to reduce artifacts around 81 | # depth discontinuities 82 | image = F.interpolate(image.unsqueeze(0), (self.height, self.width), mode="bilinear").squeeze(0) 83 | disp = F.interpolate(disp.unsqueeze(0), (self.height, self.width), mode="bilinear").squeeze(0) 84 | return image, disp 85 | 86 | def get_rand_ext(self, bs): 87 | x, y, z = self.trans_range['x'], self.trans_range['y'], self.trans_range['z'] 88 | a, b, c = self.trans_range['a'], self.trans_range['b'], self.trans_range['c'] 89 | cix = self.rand_tensor(x, bs) 90 | ciy = self.rand_tensor(y, bs) 91 | ciz = self.rand_tensor(z, bs) 92 | aix = self.rand_tensor(math.pi / a, bs) 93 | aiy = self.rand_tensor(math.pi / b, bs) 94 | aiz = self.rand_tensor(math.pi / c, bs) 95 | 96 | axisangle = torch.cat([aix, aiy, aiz], dim=-1) # [b,1,3] 97 | translation = torch.cat([cix, ciy, ciz], dim=-1) 98 | 99 | cam_ext = transformation_from_parameters(axisangle, translation) # [b,4,4] 100 | cam_ext_inv = torch.inverse(cam_ext) # [b,4,4] 101 | return cam_ext[:, :-1], cam_ext_inv[:, :-1] 102 | 103 | def rand_tensor(self, r, l): 104 | ''' 105 | return a tensor of size [l], where each element is in range [-r,-r/2] or [r/2,r] 106 | ''' 107 | if r < 0: # we can set a negtive value in self.trans_range to avoid random transformation 108 | return torch.zeros((l, 1, 1)) 109 | rand = torch.rand((l, 1, 1)) 110 | sign = 2 * (torch.randn_like(rand) > 0).float() - 1 111 | return sign * (r / 2 + r / 2 * rand) 112 | 113 | def inpaint(self, image, disp, mask): 114 | image_gray = transforms.Grayscale()(image) 115 | edge = self.get_edge(image_gray, mask) 116 | 117 | mask_hole = 1 - mask 118 | 119 | # inpaint edge 120 | edge_model_input = torch.cat([image_gray, edge, mask_hole], dim=1) # [b,4,h,w] 121 | edge_inpaint = self.edge_model(edge_model_input) # [b,1,h,w] 122 | 123 | # inpaint RGB 124 | inpaint_model_input = torch.cat([image + mask_hole, edge_inpaint], dim=1) 125 | image_inpaint = self.inpaint_model(inpaint_model_input) 126 | image_merged = image * (1 - mask_hole) + image_inpaint * mask_hole 127 | 128 | # inpaint Disparity 129 | disp_model_input = torch.cat([disp + mask_hole, edge_inpaint], dim=1) 130 | disp_inpaint = self.disp_model(disp_model_input) 131 | disp_merged = disp * (1 - mask_hole) + disp_inpaint * mask_hole 132 | 133 | return image_merged, disp_merged 134 | 135 | def get_edge(self, image_gray, mask): 136 | image_gray_np = image_gray.squeeze(1).cpu().numpy() # [b,h,w] 137 | mask_bool_np = np.array(mask.squeeze(1).cpu(), dtype=np.bool_) # [b,h,w] 138 | edges = [] 139 | for i in range(mask.shape[0]): 140 | cur_edge = canny(image_gray_np[i], sigma=2, mask=mask_bool_np[i]) 141 | edges.append(torch.from_numpy(cur_edge).unsqueeze(0)) # [1,h,w] 142 | edge = torch.cat(edges, dim=0).unsqueeze(1).float() # [b,1,h,w] 143 | return edge.to(self.device) 144 | 145 | def collect_data(self, batch): 146 | batch = default_collate(batch) 147 | image, disp = batch 148 | image = image.to(self.device) 149 | disp = disp.to(self.device) 150 | rgbd = torch.cat([image, disp], dim=1) # [b,4,h,w] 151 | b = image.shape[0] 152 | 153 | cam_int = self.K.repeat(b, 1, 1) # [b,3,3] 154 | cam_ext, cam_ext_inv = self.get_rand_ext(b) # 
[b,3,4] 155 | cam_ext = cam_ext.to(self.device) 156 | cam_ext_inv = cam_ext_inv.to(self.device) 157 | 158 | # warp to a random novel view and inpaint the holes 159 | # as the source view (input view) to the single-view view synthesis method 160 | mesh = self.renderer.construct_mesh(rgbd, cam_int) 161 | warp_image, warp_disp, warp_mask = self.renderer.render_mesh(mesh, cam_int, cam_ext) 162 | 163 | with torch.no_grad(): 164 | src_image, src_disp = self.inpaint(warp_image, warp_disp, warp_mask) 165 | 166 | return { 167 | "src_rgb": src_image, 168 | "src_disp": src_disp, 169 | "tgt_rgb": image, 170 | "tgt_disp": disp, 171 | "warp_rgb": warp_image, 172 | "warp_disp": warp_disp, 173 | "cam_int": cam_int, # src and tgt view share the same intrinsic 174 | "cam_ext": cam_ext_inv, 175 | } 176 | 177 | 178 | if __name__ == "__main__": 179 | bs = 8 180 | data = WarpBackStage2Dataset( 181 | data_root="warpback/toydata", 182 | ) 183 | loader = DataLoader( 184 | dataset=data, 185 | batch_size=bs, 186 | shuffle=True, 187 | collate_fn=data.collect_data, 188 | ) 189 | for idx, batch in enumerate(loader): 190 | src_rgb, src_disp = batch["src_rgb"], batch["src_disp"] 191 | tgt_rgb, tgt_disp = batch["tgt_rgb"], batch["tgt_disp"] 192 | warp_rgb, warp_disp = batch["warp_rgb"], batch["warp_disp"] 193 | visual = torch.cat([ 194 | warp_rgb, 195 | warp_disp.repeat(1, 3, 1, 1), 196 | src_rgb, 197 | src_disp.repeat(1, 3, 1, 1), 198 | tgt_rgb, 199 | tgt_disp.repeat(1, 3, 1, 1), 200 | ], dim=0) 201 | save_image(visual, "debug/stage2-%03d.jpg" % idx, nrow=bs) 202 | -------------------------------------------------------------------------------- /write_flow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | 8 | cv2.setNumThreads(0) 9 | cv2.ocl.setUseOpenCL(False) 10 | 11 | TAG_CHAR = np.array([202021.25], np.float32) 12 | 13 | 14 | def readFlow(fn): 15 | """ Read .flo file in Middlebury format""" 16 | # Code adapted from: 17 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 18 | 19 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 20 | # print 'fn = %s'%(fn) 21 | with open(fn, 'rb') as f: 22 | magic = np.fromfile(f, np.float32, count=1) 23 | if 202021.25 != magic: 24 | print('Magic number incorrect. 
Invalid .flo file') 25 | return None 26 | else: 27 | w = np.fromfile(f, np.int32, count=1) 28 | h = np.fromfile(f, np.int32, count=1) 29 | # print 'Reading %d x %d flo file\n' % (w, h) 30 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) 31 | # Reshape data into 3D array (columns, rows, bands) 32 | # The reshape here is for visualization, the original code is (w,h,2) 33 | return np.resize(data, (int(h), int(w), 2)) 34 | 35 | 36 | def readPFM(file): 37 | file = open(file, 'rb') 38 | 39 | color = None 40 | width = None 41 | height = None 42 | scale = None 43 | endian = None 44 | 45 | header = file.readline().rstrip() 46 | if header == b'PF': 47 | color = True 48 | elif header == b'Pf': 49 | color = False 50 | else: 51 | raise Exception('Not a PFM file.') 52 | 53 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 54 | if dim_match: 55 | width, height = map(int, dim_match.groups()) 56 | else: 57 | raise Exception('Malformed PFM header.') 58 | 59 | scale = float(file.readline().rstrip()) 60 | if scale < 0: # little-endian 61 | endian = '<' 62 | scale = -scale 63 | else: 64 | endian = '>' # big-endian 65 | 66 | data = np.fromfile(file, endian + 'f') 67 | shape = (height, width, 3) if color else (height, width) 68 | 69 | data = np.reshape(data, shape) 70 | data = np.flipud(data) 71 | return data 72 | 73 | 74 | def writeFlow(filename, uv, v=None): 75 | """ Write optical flow to file. 76 | 77 | If v is None, uv is assumed to contain both u and v channels, 78 | stacked in depth. 79 | Original code by Deqing Sun, adapted from Daniel Scharstein. 80 | """ 81 | nBands = 2 82 | 83 | if v is None: 84 | assert (uv.ndim == 3) 85 | assert (uv.shape[2] == 2) 86 | u = uv[:, :, 0] 87 | v = uv[:, :, 1] 88 | else: 89 | u = uv 90 | 91 | assert (u.shape == v.shape) 92 | height, width = u.shape 93 | f = open(filename, 'wb') 94 | # write the header 95 | f.write(TAG_CHAR) 96 | np.array(width).astype(np.int32).tofile(f) 97 | np.array(height).astype(np.int32).tofile(f) 98 | # arrange into matrix form 99 | tmp = np.zeros((height, width * nBands)) 100 | tmp[:, np.arange(width) * 2] = u 101 | tmp[:, np.arange(width) * 2 + 1] = v 102 | tmp.astype(np.float32).tofile(f) 103 | f.close() 104 | 105 | 106 | def readFlowKITTI(filename): 107 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) 108 | flow = flow[:, :, ::-1].astype(np.float32) 109 | flow, valid = flow[:, :, :2], flow[:, :, 2] 110 | flow = (flow - 2 ** 15) / 64.0 111 | return flow, valid 112 | 113 | 114 | def readDispKITTI(filename): 115 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 116 | valid = disp > 0.0 117 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 118 | return flow, valid 119 | 120 | 121 | def writeFlowKITTI(filename, uv): 122 | uv = 64.0 * uv + 2 ** 15 123 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 124 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 125 | cv2.imwrite(filename, uv[..., ::-1]) 126 | 127 | 128 | def read_gen(file_name, pil=False): 129 | ext = splitext(file_name)[-1] 130 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 131 | return Image.open(file_name) 132 | elif ext == '.bin' or ext == '.raw': 133 | return np.load(file_name) 134 | elif ext == '.flo': 135 | return readFlow(file_name).astype(np.float32) 136 | elif ext == '.pfm': 137 | flow = readPFM(file_name).astype(np.float32) 138 | if len(flow.shape) == 2: 139 | return flow 140 | else: 141 | return flow[:, :, :-1] 142 | else: 143 | raise ValueError('wrong file type: %s' % ext) 144 | 145 | 
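# same magic number as TAG_CHAR above; depth_read below uses it to validate the header of Sintel-style .dpt depth files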
TAG_FLOAT = 202021.25 146 | def depth_read(filename): 147 | """ Read depth data from file, return as numpy array. """ 148 | f = open(filename,'rb') 149 | check = np.fromfile(f,dtype=np.float32,count=1)[0] 150 | assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check) 151 | width = np.fromfile(f,dtype=np.int32,count=1)[0] 152 | height = np.fromfile(f,dtype=np.int32,count=1)[0] 153 | size = width*height 154 | assert width > 0 and height > 0 and size > 1 and size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height) 155 | depth = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width)) 156 | return depth 157 | 158 | if __name__ == '__main__': 159 | import os 160 | from tqdm import tqdm 161 | base = '/Extra/guowx/data/dCOCO-mpi/coco/flow' 162 | out = '/Extra/guowx/data/dCOCO-mpi/coco/flo' 163 | if not os.path.exists(out): 164 | os.mkdir(out) 165 | for flow in tqdm(os.listdir(base)): 166 | flo = np.load(os.path.join(base, flow)) 167 | pt = os.path.join(out, flow[:-4]+'.flo') 168 | if not os.path.exists(pt): 169 | writeFlow(pt, flo) 170 | --------------------------------------------------------------------------------
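A minimal read-back check for the conversion above — a sketch, not part of the scripts: it reuses the `out` directory from write_flow.py and the repo's utils/flow_viz helper.

# verify one converted .flo file by reading it back and writing a colour visualization
import os
import cv2
from utils import flow_viz
from write_flow import readFlow

out = '/Extra/guowx/data/dCOCO-mpi/coco/flo'        # output dir of the conversion above
name = sorted(os.listdir(out))[0]                   # pick any converted file
flow = readFlow(os.path.join(out, name))            # [H, W, 2] float32
print(name, flow.shape, float(flow.min()), float(flow.max()))
vis = flow_viz.flow_to_image(flow)                  # [H, W, 3] uint8, RGB
cv2.imwrite('flow_check.png', vis[:, :, [2, 1, 0]]) # RGB -> BGR for OpenCV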