├── LICENSE.txt
├── RAFT
│   ├── .gitignore
│   ├── LICENSE
│   ├── RAFT.png
│   ├── README.md
│   ├── alt_cuda_corr
│   │   ├── correlation.cpp
│   │   ├── correlation_kernel.cu
│   │   └── setup.py
│   ├── chairs_split.txt
│   ├── core
│   │   ├── __init__.py
│   │   ├── corr.py
│   │   ├── datasets.py
│   │   ├── extractor.py
│   │   ├── raft.py
│   │   ├── update.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── augmentor.py
│   │       ├── flow_viz.py
│   │       ├── frame_utils.py
│   │       └── utils.py
│   ├── demo-frames
│   │   ├── frame_0016.png
│   │   ├── frame_0017.png
│   │   ├── frame_0018.png
│   │   ├── frame_0019.png
│   │   ├── frame_0020.png
│   │   ├── frame_0021.png
│   │   ├── frame_0022.png
│   │   ├── frame_0023.png
│   │   ├── frame_0024.png
│   │   └── frame_0025.png
│   ├── demo.py
│   ├── download_models.sh
│   ├── evaluate.py
│   ├── train.py
│   ├── train_mixed.sh
│   ├── train_standard.sh
│   └── weights
│       └── raft-things.pth
├── README.md
├── adampiweight
│   └── adampi_64p.pth
├── bilateral_filter.py
├── core
│   ├── __init__.py
│   ├── corr.py
│   ├── datasets.py
│   ├── extractor.py
│   ├── raft.py
│   ├── update.py
│   └── utils
│       ├── __init__.py
│       ├── augmentor.py
│       ├── flow_viz.py
│       ├── frame_utils.py
│       └── utils.py
├── external
│   └── forward_warping
│       ├── compile.sh
│       ├── libwarping.so
│       └── warping.c
├── flow_colors.py
├── gen_3dphoto_dynamic_v2.py
├── geometry.py
├── misc
│   └── train_image_2_000000_00_1.png
├── model
│   ├── AdaMPI.py
│   ├── CPN
│   │   ├── decoder.py
│   │   ├── encoder.py
│   │   └── unet.py
│   └── PAN.py
├── moving_obj.py
├── scripts
│   ├── gen_coco.sh
│   ├── gen_test_kitti15.sh
│   ├── gen_train_kitti15.sh
│   └── gen_train_kitti15_v2.sh
├── utils
│   ├── arrow.py
│   ├── flow_viz.py
│   ├── mpi
│   │   ├── homography_sampler.py
│   │   ├── mpi_rendering.py
│   │   └── rendering_utils.py
│   ├── transform.py
│   ├── utils copy.py
│   ├── utils.py
│   └── utils_coco.py
├── vis_flow.py
├── warpback
│   ├── networks.py
│   ├── stage1_dataset.py
│   ├── stage2_dataset.py
│   └── utils.py
└── write_flow.py
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Dmitry Ryumin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/RAFT/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.egg-info
3 | dist
4 | datasets
5 | pytorch_env
6 | models
7 | build
8 | correlation.egg-info
9 |
--------------------------------------------------------------------------------
/RAFT/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020, princeton-vl
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/RAFT/RAFT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/RAFT.png
--------------------------------------------------------------------------------
/RAFT/README.md:
--------------------------------------------------------------------------------
1 | # RAFT
2 | This repository contains the source code for our paper:
3 |
4 | [RAFT: Recurrent All-Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)
5 | ECCV 2020
6 | Zachary Teed and Jia Deng
7 |
8 | <img src="RAFT.png">
9 |
10 | ## Requirements
11 | The code has been tested with PyTorch 1.6 and CUDA 10.1.
12 | ```Shell
13 | conda create --name raft
14 | conda activate raft
15 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch
16 | ```
17 |
18 | ## Demos
19 | Pretrained models can be downloaded by running
20 | ```Shell
21 | ./download_models.sh
22 | ```
23 | or downloaded from [google drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing)
24 |
25 | You can demo a trained model on a sequence of frames
26 | ```Shell
27 | python demo.py --model=models/raft-things.pth --path=demo-frames
28 | ```
29 |
30 | ## Required Data
31 | To evaluate/train RAFT, you will need to download the required datasets.
32 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
33 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
34 | * [Sintel](http://sintel.is.tue.mpg.de/)
35 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
36 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional)
37 |
38 |
39 | By default `datasets.py` will search for the datasets in these locations. You can create symbolic links from the `datasets` folder to wherever the datasets were downloaded
40 |
41 | ```Shell
42 | ├── datasets
43 | ├── Sintel
44 | ├── test
45 | ├── training
46 | ├── KITTI
47 | ├── testing
48 | ├── training
49 | ├── devkit
50 | ├── FlyingChairs_release
51 | ├── data
52 | ├── FlyingThings3D
53 | ├── frames_cleanpass
54 | ├── frames_finalpass
55 | ├── optical_flow
56 | ```
57 |
58 | ## Evaluation
59 | You can evaluate a trained model using `evaluate.py`
60 | ```Shell
61 | python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision
62 | ```
63 |
64 | ## Training
65 | We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` directory, which can be visualized using tensorboard
66 | ```Shell
67 | ./train_standard.sh
68 | ```
69 |
70 | If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU)
71 | ```Shell
72 | ./train_mixed.sh
73 | ```
74 |
75 | ## (Optional) Efficient Implementation
76 | You can optionally use our alternate (efficient) implementation by compiling the provided CUDA extension
77 | ```Shell
78 | cd alt_cuda_corr && python setup.py install && cd ..
79 | ```
80 | and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note that this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
81 |
--------------------------------------------------------------------------------
/RAFT/alt_cuda_corr/correlation.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 | // CUDA forward declarations
5 | std::vector<torch::Tensor> corr_cuda_forward(
6 | torch::Tensor fmap1,
7 | torch::Tensor fmap2,
8 | torch::Tensor coords,
9 | int radius);
10 |
11 | std::vector<torch::Tensor> corr_cuda_backward(
12 | torch::Tensor fmap1,
13 | torch::Tensor fmap2,
14 | torch::Tensor coords,
15 | torch::Tensor corr_grad,
16 | int radius);
17 |
18 | // C++ interface
19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
22 |
23 | std::vector<torch::Tensor> corr_forward(
24 | torch::Tensor fmap1,
25 | torch::Tensor fmap2,
26 | torch::Tensor coords,
27 | int radius) {
28 | CHECK_INPUT(fmap1);
29 | CHECK_INPUT(fmap2);
30 | CHECK_INPUT(coords);
31 |
32 | return corr_cuda_forward(fmap1, fmap2, coords, radius);
33 | }
34 |
35 |
36 | std::vector<torch::Tensor> corr_backward(
37 | torch::Tensor fmap1,
38 | torch::Tensor fmap2,
39 | torch::Tensor coords,
40 | torch::Tensor corr_grad,
41 | int radius) {
42 | CHECK_INPUT(fmap1);
43 | CHECK_INPUT(fmap2);
44 | CHECK_INPUT(coords);
45 | CHECK_INPUT(corr_grad);
46 |
47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
48 | }
49 |
50 |
51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
52 | m.def("forward", &corr_forward, "CORR forward");
53 | m.def("backward", &corr_backward, "CORR backward");
54 | }
--------------------------------------------------------------------------------
/RAFT/alt_cuda_corr/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 |
5 | setup(
6 | name='correlation',
7 | ext_modules=[
8 | CUDAExtension('alt_cuda_corr',
9 | sources=['correlation.cpp', 'correlation_kernel.cu'],
10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
11 | ],
12 | cmdclass={
13 | 'build_ext': BuildExtension
14 | })
15 |
16 |
--------------------------------------------------------------------------------
/RAFT/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/core/__init__.py
--------------------------------------------------------------------------------
/RAFT/core/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from utils.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 | import alt_cuda_corr
7 | except:
8 | # alt_cuda_corr is not compiled
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels-1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
38 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
40 |
41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 |
45 | corr = bilinear_sampler(corr, coords_lvl)
46 | corr = corr.view(batch, h1, w1, -1)
47 | out_pyramid.append(corr)
48 |
49 | out = torch.cat(out_pyramid, dim=-1)
50 | return out.permute(0, 3, 1, 2).contiguous().float()
51 |
52 | @staticmethod
53 | def corr(fmap1, fmap2):
54 | batch, dim, ht, wd = fmap1.shape
55 | fmap1 = fmap1.view(batch, dim, ht*wd)
56 | fmap2 = fmap2.view(batch, dim, ht*wd)
57 |
58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59 | corr = corr.view(batch, ht, wd, 1, ht, wd)
60 | return corr / torch.sqrt(torch.tensor(dim).float())
61 |
62 |
63 | class AlternateCorrBlock:
64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
65 | self.num_levels = num_levels
66 | self.radius = radius
67 |
68 | self.pyramid = [(fmap1, fmap2)]
69 | for i in range(self.num_levels):
70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
72 | self.pyramid.append((fmap1, fmap2))
73 |
74 | def __call__(self, coords):
75 | coords = coords.permute(0, 2, 3, 1)
76 | B, H, W, _ = coords.shape
77 | dim = self.pyramid[0][0].shape[1]
78 |
79 | corr_list = []
80 | for i in range(self.num_levels):
81 | r = self.radius
82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
84 |
85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
87 | corr_list.append(corr.squeeze(1))
88 |
89 | corr = torch.stack(corr_list, dim=1)
90 | corr = corr.reshape(B, -1, H, W)
91 | return corr / torch.sqrt(torch.tensor(dim).float())
92 |
--------------------------------------------------------------------------------
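The correlation volume built above is what the update block indexes at every refinement iteration. A minimal sketch of how `CorrBlock` is driven, mirroring its use in `raft.py` and assuming `RAFT/core` is on `sys.path`; the feature maps here are random stand-ins for `BasicEncoder` outputs:

```python
import torch
from corr import CorrBlock            # RAFT/core/corr.py
from utils.utils import coords_grid   # RAFT/core/utils/utils.py

# Stand-in feature maps at 1/8 input resolution (normally produced by BasicEncoder).
fmap1 = torch.randn(1, 256, 46, 62)
fmap2 = torch.randn(1, 256, 46, 62)

# Build the 4-level all-pairs correlation pyramid once per frame pair.
corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)

# Index it at the current coordinate estimate (identity flow for this sketch).
coords = coords_grid(1, 46, 62, device=fmap1.device)
corr = corr_fn(coords)

# Each level contributes (2*radius+1)**2 = 81 channels, so 4 levels give 324.
print(corr.shape)  # torch.Size([1, 324, 46, 62])
```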
/RAFT/core/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8 | super(ResidualBlock, self).__init__()
9 |
10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12 | self.relu = nn.ReLU(inplace=True)
13 |
14 | num_groups = planes // 8
15 |
16 | if norm_fn == 'group':
17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | if not stride == 1:
20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21 |
22 | elif norm_fn == 'batch':
23 | self.norm1 = nn.BatchNorm2d(planes)
24 | self.norm2 = nn.BatchNorm2d(planes)
25 | if not stride == 1:
26 | self.norm3 = nn.BatchNorm2d(planes)
27 |
28 | elif norm_fn == 'instance':
29 | self.norm1 = nn.InstanceNorm2d(planes)
30 | self.norm2 = nn.InstanceNorm2d(planes)
31 | if not stride == 1:
32 | self.norm3 = nn.InstanceNorm2d(planes)
33 |
34 | elif norm_fn == 'none':
35 | self.norm1 = nn.Sequential()
36 | self.norm2 = nn.Sequential()
37 | if not stride == 1:
38 | self.norm3 = nn.Sequential()
39 |
40 | if stride == 1:
41 | self.downsample = None
42 |
43 | else:
44 | self.downsample = nn.Sequential(
45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46 |
47 |
48 | def forward(self, x):
49 | y = x
50 | y = self.relu(self.norm1(self.conv1(y)))
51 | y = self.relu(self.norm2(self.conv2(y)))
52 |
53 | if self.downsample is not None:
54 | x = self.downsample(x)
55 |
56 | return self.relu(x+y)
57 |
58 |
59 |
60 | class BottleneckBlock(nn.Module):
61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
62 | super(BottleneckBlock, self).__init__()
63 |
64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
67 | self.relu = nn.ReLU(inplace=True)
68 |
69 | num_groups = planes // 8
70 |
71 | if norm_fn == 'group':
72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
75 | if not stride == 1:
76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
77 |
78 | elif norm_fn == 'batch':
79 | self.norm1 = nn.BatchNorm2d(planes//4)
80 | self.norm2 = nn.BatchNorm2d(planes//4)
81 | self.norm3 = nn.BatchNorm2d(planes)
82 | if not stride == 1:
83 | self.norm4 = nn.BatchNorm2d(planes)
84 |
85 | elif norm_fn == 'instance':
86 | self.norm1 = nn.InstanceNorm2d(planes//4)
87 | self.norm2 = nn.InstanceNorm2d(planes//4)
88 | self.norm3 = nn.InstanceNorm2d(planes)
89 | if not stride == 1:
90 | self.norm4 = nn.InstanceNorm2d(planes)
91 |
92 | elif norm_fn == 'none':
93 | self.norm1 = nn.Sequential()
94 | self.norm2 = nn.Sequential()
95 | self.norm3 = nn.Sequential()
96 | if not stride == 1:
97 | self.norm4 = nn.Sequential()
98 |
99 | if stride == 1:
100 | self.downsample = None
101 |
102 | else:
103 | self.downsample = nn.Sequential(
104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
105 |
106 |
107 | def forward(self, x):
108 | y = x
109 | y = self.relu(self.norm1(self.conv1(y)))
110 | y = self.relu(self.norm2(self.conv2(y)))
111 | y = self.relu(self.norm3(self.conv3(y)))
112 |
113 | if self.downsample is not None:
114 | x = self.downsample(x)
115 |
116 | return self.relu(x+y)
117 |
118 | class BasicEncoder(nn.Module):
119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
120 | super(BasicEncoder, self).__init__()
121 | self.norm_fn = norm_fn
122 |
123 | if self.norm_fn == 'group':
124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
125 |
126 | elif self.norm_fn == 'batch':
127 | self.norm1 = nn.BatchNorm2d(64)
128 |
129 | elif self.norm_fn == 'instance':
130 | self.norm1 = nn.InstanceNorm2d(64)
131 |
132 | elif self.norm_fn == 'none':
133 | self.norm1 = nn.Sequential()
134 |
135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
136 | self.relu1 = nn.ReLU(inplace=True)
137 |
138 | self.in_planes = 64
139 | self.layer1 = self._make_layer(64, stride=1)
140 | self.layer2 = self._make_layer(96, stride=2)
141 | self.layer3 = self._make_layer(128, stride=2)
142 |
143 | # output convolution
144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
145 |
146 | self.dropout = None
147 | if dropout > 0:
148 | self.dropout = nn.Dropout2d(p=dropout)
149 |
150 | for m in self.modules():
151 | if isinstance(m, nn.Conv2d):
152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
154 | if m.weight is not None:
155 | nn.init.constant_(m.weight, 1)
156 | if m.bias is not None:
157 | nn.init.constant_(m.bias, 0)
158 |
159 | def _make_layer(self, dim, stride=1):
160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
162 | layers = (layer1, layer2)
163 |
164 | self.in_planes = dim
165 | return nn.Sequential(*layers)
166 |
167 |
168 | def forward(self, x):
169 |
170 | # if input is list, combine batch dimension
171 | is_list = isinstance(x, tuple) or isinstance(x, list)
172 | if is_list:
173 | batch_dim = x[0].shape[0]
174 | x = torch.cat(x, dim=0)
175 |
176 | x = self.conv1(x)
177 | x = self.norm1(x)
178 | x = self.relu1(x)
179 |
180 | x = self.layer1(x)
181 | x = self.layer2(x)
182 | x = self.layer3(x)
183 |
184 | x = self.conv2(x)
185 |
186 | if self.training and self.dropout is not None:
187 | x = self.dropout(x)
188 |
189 | if is_list:
190 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
191 |
192 | return x
193 |
194 |
195 | class SmallEncoder(nn.Module):
196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
197 | super(SmallEncoder, self).__init__()
198 | self.norm_fn = norm_fn
199 |
200 | if self.norm_fn == 'group':
201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
202 |
203 | elif self.norm_fn == 'batch':
204 | self.norm1 = nn.BatchNorm2d(32)
205 |
206 | elif self.norm_fn == 'instance':
207 | self.norm1 = nn.InstanceNorm2d(32)
208 |
209 | elif self.norm_fn == 'none':
210 | self.norm1 = nn.Sequential()
211 |
212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
213 | self.relu1 = nn.ReLU(inplace=True)
214 |
215 | self.in_planes = 32
216 | self.layer1 = self._make_layer(32, stride=1)
217 | self.layer2 = self._make_layer(64, stride=2)
218 | self.layer3 = self._make_layer(96, stride=2)
219 |
220 | self.dropout = None
221 | if dropout > 0:
222 | self.dropout = nn.Dropout2d(p=dropout)
223 |
224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
225 |
226 | for m in self.modules():
227 | if isinstance(m, nn.Conv2d):
228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
230 | if m.weight is not None:
231 | nn.init.constant_(m.weight, 1)
232 | if m.bias is not None:
233 | nn.init.constant_(m.bias, 0)
234 |
235 | def _make_layer(self, dim, stride=1):
236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
238 | layers = (layer1, layer2)
239 |
240 | self.in_planes = dim
241 | return nn.Sequential(*layers)
242 |
243 |
244 | def forward(self, x):
245 |
246 | # if input is list, combine batch dimension
247 | is_list = isinstance(x, tuple) or isinstance(x, list)
248 | if is_list:
249 | batch_dim = x[0].shape[0]
250 | x = torch.cat(x, dim=0)
251 |
252 | x = self.conv1(x)
253 | x = self.norm1(x)
254 | x = self.relu1(x)
255 |
256 | x = self.layer1(x)
257 | x = self.layer2(x)
258 | x = self.layer3(x)
259 | x = self.conv2(x)
260 |
261 | if self.training and self.dropout is not None:
262 | x = self.dropout(x)
263 |
264 | if is_list:
265 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
266 |
267 | return x
268 |
--------------------------------------------------------------------------------
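Both encoders downsample by a factor of 8 (one stride-2 stem plus two stride-2 residual stages). A small shape check, assuming `RAFT/core` is importable; the input size is arbitrary as long as it is divisible by 8:

```python
import torch
from extractor import BasicEncoder   # RAFT/core/extractor.py

# The full-size feature network used by RAFT (see raft.py): 256 output channels.
fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0.0)

img = torch.randn(2, 3, 368, 496)
fmap = fnet(img)
print(fmap.shape)                     # torch.Size([2, 256, 46, 62])

# Passing a list concatenates along the batch dimension and splits the result,
# which is how raft.py extracts features for both frames in one pass.
fmap1, fmap2 = fnet([img, img])
print(fmap1.shape, fmap2.shape)
```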
/RAFT/core/raft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from update import BasicUpdateBlock, SmallUpdateBlock
7 | from extractor import BasicEncoder, SmallEncoder
8 | from corr import CorrBlock, AlternateCorrBlock
9 | from utils.utils import bilinear_sampler, coords_grid, upflow8
10 |
11 | try:
12 | autocast = torch.cuda.amp.autocast
13 | except:
14 | # dummy autocast for PyTorch < 1.6
15 | class autocast:
16 | def __init__(self, enabled):
17 | pass
18 | def __enter__(self):
19 | pass
20 | def __exit__(self, *args):
21 | pass
22 |
23 |
24 | class RAFT(nn.Module):
25 | def __init__(self, args):
26 | super(RAFT, self).__init__()
27 | self.args = args
28 |
29 | if args.small:
30 | self.hidden_dim = hdim = 96
31 | self.context_dim = cdim = 64
32 | args.corr_levels = 4
33 | args.corr_radius = 3
34 |
35 | else:
36 | self.hidden_dim = hdim = 128
37 | self.context_dim = cdim = 128
38 | args.corr_levels = 4
39 | args.corr_radius = 4
40 |
41 | if 'dropout' not in self.args:
42 | self.args.dropout = 0
43 |
44 | if 'alternate_corr' not in self.args:
45 | self.args.alternate_corr = False
46 |
47 | # feature network, context network, and update block
48 | if args.small:
49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
52 |
53 | else:
54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
57 |
58 | def freeze_bn(self):
59 | for m in self.modules():
60 | if isinstance(m, nn.BatchNorm2d):
61 | m.eval()
62 |
63 | def initialize_flow(self, img):
64 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
65 | N, C, H, W = img.shape
66 | coords0 = coords_grid(N, H//8, W//8, device=img.device)
67 | coords1 = coords_grid(N, H//8, W//8, device=img.device)
68 |
69 | # optical flow computed as difference: flow = coords1 - coords0
70 | return coords0, coords1
71 |
72 | def upsample_flow(self, flow, mask):
73 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
74 | N, _, H, W = flow.shape
75 | mask = mask.view(N, 1, 9, 8, 8, H, W)
76 | mask = torch.softmax(mask, dim=2)
77 |
78 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
79 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
80 |
81 | up_flow = torch.sum(mask * up_flow, dim=2)
82 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
83 | return up_flow.reshape(N, 2, 8*H, 8*W)
84 |
85 |
86 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
87 | """ Estimate optical flow between pair of frames """
88 |
89 | image1 = 2 * (image1 / 255.0) - 1.0
90 | image2 = 2 * (image2 / 255.0) - 1.0
91 |
92 | image1 = image1.contiguous()
93 | image2 = image2.contiguous()
94 |
95 | hdim = self.hidden_dim
96 | cdim = self.context_dim
97 |
98 | # run the feature network
99 | with autocast(enabled=self.args.mixed_precision):
100 | fmap1, fmap2 = self.fnet([image1, image2])
101 |
102 | fmap1 = fmap1.float()
103 | fmap2 = fmap2.float()
104 | if self.args.alternate_corr:
105 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
106 | else:
107 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
108 |
109 | # run the context network
110 | with autocast(enabled=self.args.mixed_precision):
111 | cnet = self.cnet(image1)
112 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
113 | net = torch.tanh(net)
114 | inp = torch.relu(inp)
115 |
116 | coords0, coords1 = self.initialize_flow(image1)
117 |
118 | if flow_init is not None:
119 | coords1 = coords1 + flow_init
120 |
121 | flow_predictions = []
122 | for itr in range(iters):
123 | coords1 = coords1.detach()
124 | corr = corr_fn(coords1) # index correlation volume
125 |
126 | flow = coords1 - coords0
127 | with autocast(enabled=self.args.mixed_precision):
128 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
129 |
130 | # F(t+1) = F(t) + \Delta(t)
131 | coords1 = coords1 + delta_flow
132 |
133 | # upsample predictions
134 | if up_mask is None:
135 | flow_up = upflow8(coords1 - coords0)
136 | else:
137 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
138 |
139 | flow_predictions.append(flow_up)
140 |
141 | if test_mode:
142 | return coords1 - coords0, flow_up
143 |
144 | return flow_predictions
145 |
--------------------------------------------------------------------------------
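`RAFT.forward` expects images in the 0-255 range and reads its configuration from an argparse-style namespace. A minimal inference sketch with random, untrained weights, assuming `RAFT/core` is on `sys.path` (as `demo.py` arranges); it only verifies shapes:

```python
import argparse
import torch
from raft import RAFT                 # RAFT/core/raft.py

# The flags demo.py / evaluate.py define; raft.py fills in the remaining defaults.
args = argparse.Namespace(small=False, mixed_precision=False, alternate_corr=False)

model = RAFT(args)
model.eval()

# Two random frames in [0, 255]; forward() rescales them to [-1, 1] internally.
image1 = torch.randint(0, 256, (1, 3, 368, 496)).float()
image2 = torch.randint(0, 256, (1, 3, 368, 496)).float()

with torch.no_grad():
    # test_mode returns (1/8-resolution flow, upsampled flow) instead of a list.
    flow_low, flow_up = model(image1, image2, iters=12, test_mode=True)

print(flow_low.shape)  # torch.Size([1, 2, 46, 62])
print(flow_up.shape)   # torch.Size([1, 2, 368, 496])
```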
/RAFT/core/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class ConvGRU(nn.Module):
17 | def __init__(self, hidden_dim=128, input_dim=192+128):
18 | super(ConvGRU, self).__init__()
19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
22 |
23 | def forward(self, h, x):
24 | hx = torch.cat([h, x], dim=1)
25 |
26 | z = torch.sigmoid(self.convz(hx))
27 | r = torch.sigmoid(self.convr(hx))
28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
29 |
30 | h = (1-z) * h + z * q
31 | return h
32 |
33 | class SepConvGRU(nn.Module):
34 | def __init__(self, hidden_dim=128, input_dim=192+128):
35 | super(SepConvGRU, self).__init__()
36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39 |
40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43 |
44 |
45 | def forward(self, h, x):
46 | # horizontal
47 | hx = torch.cat([h, x], dim=1)
48 | z = torch.sigmoid(self.convz1(hx))
49 | r = torch.sigmoid(self.convr1(hx))
50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
51 | h = (1-z) * h + z * q
52 |
53 | # vertical
54 | hx = torch.cat([h, x], dim=1)
55 | z = torch.sigmoid(self.convz2(hx))
56 | r = torch.sigmoid(self.convr2(hx))
57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
58 | h = (1-z) * h + z * q
59 |
60 | return h
61 |
62 | class SmallMotionEncoder(nn.Module):
63 | def __init__(self, args):
64 | super(SmallMotionEncoder, self).__init__()
65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
69 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
70 |
71 | def forward(self, flow, corr):
72 | cor = F.relu(self.convc1(corr))
73 | flo = F.relu(self.convf1(flow))
74 | flo = F.relu(self.convf2(flo))
75 | cor_flo = torch.cat([cor, flo], dim=1)
76 | out = F.relu(self.conv(cor_flo))
77 | return torch.cat([out, flow], dim=1)
78 |
79 | class BasicMotionEncoder(nn.Module):
80 | def __init__(self, args):
81 | super(BasicMotionEncoder, self).__init__()
82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
88 |
89 | def forward(self, flow, corr):
90 | cor = F.relu(self.convc1(corr))
91 | cor = F.relu(self.convc2(cor))
92 | flo = F.relu(self.convf1(flow))
93 | flo = F.relu(self.convf2(flo))
94 |
95 | cor_flo = torch.cat([cor, flo], dim=1)
96 | out = F.relu(self.conv(cor_flo))
97 | return torch.cat([out, flow], dim=1)
98 |
99 | class SmallUpdateBlock(nn.Module):
100 | def __init__(self, args, hidden_dim=96):
101 | super(SmallUpdateBlock, self).__init__()
102 | self.encoder = SmallMotionEncoder(args)
103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
105 |
106 | def forward(self, net, inp, corr, flow):
107 | motion_features = self.encoder(flow, corr)
108 | inp = torch.cat([inp, motion_features], dim=1)
109 | net = self.gru(net, inp)
110 | delta_flow = self.flow_head(net)
111 |
112 | return net, None, delta_flow
113 |
114 | class BasicUpdateBlock(nn.Module):
115 | def __init__(self, args, hidden_dim=128, input_dim=128):
116 | super(BasicUpdateBlock, self).__init__()
117 | self.args = args
118 | self.encoder = BasicMotionEncoder(args)
119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
121 |
122 | self.mask = nn.Sequential(
123 | nn.Conv2d(128, 256, 3, padding=1),
124 | nn.ReLU(inplace=True),
125 | nn.Conv2d(256, 64*9, 1, padding=0))
126 |
127 | def forward(self, net, inp, corr, flow, upsample=True):
128 | motion_features = self.encoder(flow, corr)
129 | inp = torch.cat([inp, motion_features], dim=1)
130 |
131 | net = self.gru(net, inp)
132 | delta_flow = self.flow_head(net)
133 |
134 | # scale mask to balance gradients
135 | mask = .25 * self.mask(net)
136 | return net, mask, delta_flow
137 |
138 |
139 |
140 |
--------------------------------------------------------------------------------
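Channel bookkeeping for the full-size update block, under the `corr_levels=4`, `corr_radius=4` defaults set in `raft.py`. This is only a shape sketch with random tensors, assuming `RAFT/core` is importable:

```python
import argparse
import torch
from update import BasicUpdateBlock   # RAFT/core/update.py

# corr input: 4 * (2*4+1)**2 = 324 channels; the motion encoder emits 126 + 2 = 128.
args = argparse.Namespace(corr_levels=4, corr_radius=4)
block = BasicUpdateBlock(args, hidden_dim=128)

net  = torch.randn(1, 128, 46, 62)   # hidden state (tanh half of the cnet output)
inp  = torch.randn(1, 128, 46, 62)   # context features (relu half of the cnet output)
corr = torch.randn(1, 324, 46, 62)   # output of CorrBlock.__call__
flow = torch.zeros(1, 2, 46, 62)     # current flow estimate

net, up_mask, delta_flow = block(net, inp, corr, flow)
print(net.shape)         # torch.Size([1, 128, 46, 62])
print(up_mask.shape)     # torch.Size([1, 576, 46, 62])  -> 64 * 9 convex-upsampling weights
print(delta_flow.shape)  # torch.Size([1, 2, 46, 62])
```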
/RAFT/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/core/utils/__init__.py
--------------------------------------------------------------------------------
/RAFT/core/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 | Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 | Expects a two dimensional flow image of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
--------------------------------------------------------------------------------
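A quick way to exercise `flow_to_image` on a synthetic flow field, importing the module the same way `demo.py` does (with `RAFT/core` on `sys.path`):

```python
import numpy as np
from utils import flow_viz            # RAFT/core/utils/flow_viz.py

# Synthetic flow: horizontal motion growing from 0 to 10 px, left to right.
h, w = 120, 160
flow = np.zeros((h, w, 2), dtype=np.float32)
flow[..., 0] = np.linspace(0, 10, w)[None, :]

rgb = flow_viz.flow_to_image(flow)    # hue encodes direction, saturation encodes magnitude
print(rgb.shape, rgb.dtype)           # (120, 160, 3) uint8
```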
/RAFT/core/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | TAG_CHAR = np.array([202021.25], np.float32)
11 |
12 | def readFlow(fn):
13 | """ Read .flo file in Middlebury format"""
14 | # Code adapted from:
15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16 |
17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18 | # print 'fn = %s'%(fn)
19 | with open(fn, 'rb') as f:
20 | magic = np.fromfile(f, np.float32, count=1)
21 | if 202021.25 != magic:
22 | print('Magic number incorrect. Invalid .flo file')
23 | return None
24 | else:
25 | w = np.fromfile(f, np.int32, count=1)
26 | h = np.fromfile(f, np.int32, count=1)
27 | # print 'Reading %d x %d flo file\n' % (w, h)
28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
29 | # Reshape data into 3D array (columns, rows, bands)
30 | # The reshape here is for visualization, the original code is (w,h,2)
31 | return np.resize(data, (int(h), int(w), 2))
32 |
33 | def readPFM(file):
34 | file = open(file, 'rb')
35 |
36 | color = None
37 | width = None
38 | height = None
39 | scale = None
40 | endian = None
41 |
42 | header = file.readline().rstrip()
43 | if header == b'PF':
44 | color = True
45 | elif header == b'Pf':
46 | color = False
47 | else:
48 | raise Exception('Not a PFM file.')
49 |
50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
51 | if dim_match:
52 | width, height = map(int, dim_match.groups())
53 | else:
54 | raise Exception('Malformed PFM header.')
55 |
56 | scale = float(file.readline().rstrip())
57 | if scale < 0: # little-endian
58 | endian = '<'
59 | scale = -scale
60 | else:
61 | endian = '>' # big-endian
62 |
63 | data = np.fromfile(file, endian + 'f')
64 | shape = (height, width, 3) if color else (height, width)
65 |
66 | data = np.reshape(data, shape)
67 | data = np.flipud(data)
68 | return data
69 |
70 | def writeFlow(filename,uv,v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert(uv.ndim == 3)
81 | assert(uv.shape[2] == 2)
82 | u = uv[:,:,0]
83 | v = uv[:,:,1]
84 | else:
85 | u = uv
86 |
87 | assert(u.shape == v.shape)
88 | height,width = u.shape
89 | f = open(filename,'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width*nBands))
96 | tmp[:,np.arange(width)*2] = u
97 | tmp[:,np.arange(width)*2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
104 | flow = flow[:,:,::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2**15) / 64.0
107 | return flow, valid
108 |
109 | def readDispKITTI(filename):
110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
111 | valid = disp > 0.0
112 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
113 | return flow, valid
114 |
115 |
116 | def writeFlowKITTI(filename, uv):
117 | uv = 64.0 * uv + 2**15
118 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
120 | cv2.imwrite(filename, uv[..., ::-1])
121 |
122 |
123 | def read_gen(file_name, pil=False):
124 | ext = splitext(file_name)[-1]
125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
126 | return Image.open(file_name)
127 | elif ext == '.bin' or ext == '.raw':
128 | return np.load(file_name)
129 | elif ext == '.flo':
130 | return readFlow(file_name).astype(np.float32)
131 | elif ext == '.pfm':
132 | flow = readPFM(file_name).astype(np.float32)
133 | if len(flow.shape) == 2:
134 | return flow
135 | else:
136 | return flow[:, :, :-1]
137 | return []
--------------------------------------------------------------------------------
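A round-trip through the Middlebury `.flo` reader and writer above; the temporary path is arbitrary, and `cv2` must be installed since the module imports it at the top:

```python
import numpy as np
from utils import frame_utils         # RAFT/core/utils/frame_utils.py

flow = np.random.randn(100, 200, 2).astype(np.float32)
frame_utils.writeFlow('/tmp/example.flo', flow)

flow_back = frame_utils.readFlow('/tmp/example.flo')
print(flow_back.shape)                # (100, 200, 2)
print(np.allclose(flow, flow_back))   # True
```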
/RAFT/core/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 | def __init__(self, dims, mode='sintel'):
10 | self.ht, self.wd = dims[-2:]
11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13 | if mode == 'sintel':
14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15 | else:
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17 |
18 | def pad(self, *inputs):
19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
20 |
21 | def unpad(self,x):
22 | ht, wd = x.shape[-2:]
23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
24 | return x[..., c[0]:c[1], c[2]:c[3]]
25 |
26 | def forward_interpolate(flow):
27 | flow = flow.detach().cpu().numpy()
28 | dx, dy = flow[0], flow[1]
29 |
30 | ht, wd = dx.shape
31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
32 |
33 | x1 = x0 + dx
34 | y1 = y0 + dy
35 |
36 | x1 = x1.reshape(-1)
37 | y1 = y1.reshape(-1)
38 | dx = dx.reshape(-1)
39 | dy = dy.reshape(-1)
40 |
41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
42 | x1 = x1[valid]
43 | y1 = y1[valid]
44 | dx = dx[valid]
45 | dy = dy[valid]
46 |
47 | flow_x = interpolate.griddata(
48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
49 |
50 | flow_y = interpolate.griddata(
51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
52 |
53 | flow = np.stack([flow_x, flow_y], axis=0)
54 | return torch.from_numpy(flow).float()
55 |
56 |
57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
58 | """ Wrapper for grid_sample, uses pixel coordinates """
59 | H, W = img.shape[-2:]
60 | xgrid, ygrid = coords.split([1,1], dim=-1)
61 | xgrid = 2*xgrid/(W-1) - 1
62 | ygrid = 2*ygrid/(H-1) - 1
63 |
64 | grid = torch.cat([xgrid, ygrid], dim=-1)
65 | img = F.grid_sample(img, grid, align_corners=True)
66 |
67 | if mask:
68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
69 | return img, mask.float()
70 |
71 | return img
72 |
73 |
74 | def coords_grid(batch, ht, wd, device):
75 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
76 | coords = torch.stack(coords[::-1], dim=0).float()
77 | return coords[None].repeat(batch, 1, 1, 1)
78 |
79 |
80 | def upflow8(flow, mode='bilinear'):
81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
83 |
--------------------------------------------------------------------------------
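`InputPadder` is how `demo.py` and `evaluate.py` make arbitrary frame sizes divisible by 8 and then crop the predicted flow back. A short sketch with Sintel-sized tensors, assuming `RAFT/core` is on `sys.path`:

```python
import torch
from utils.utils import InputPadder   # RAFT/core/utils/utils.py

image1 = torch.randn(1, 3, 436, 1024)
image2 = torch.randn(1, 3, 436, 1024)

padder = InputPadder(image1.shape)    # mode='sintel' pads top/bottom symmetrically
image1, image2 = padder.pad(image1, image2)
print(image1.shape)                   # torch.Size([1, 3, 440, 1024])

flow = torch.randn(1, 2, 440, 1024)   # stand-in for the model's upsampled flow
print(padder.unpad(flow).shape)       # torch.Size([1, 2, 436, 1024])
```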
/RAFT/demo-frames/frame_0016.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0016.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0017.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0017.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0018.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0018.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0019.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0019.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0020.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0020.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0021.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0021.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0022.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0022.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0023.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0023.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0024.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/demo-frames/frame_0025.png
--------------------------------------------------------------------------------
/RAFT/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('core')
3 |
4 | import argparse
5 | import os
6 | import cv2
7 | import glob
8 | import numpy as np
9 | import torch
10 | from PIL import Image
11 |
12 | from raft import RAFT
13 | from utils import flow_viz
14 | from utils.utils import InputPadder
15 |
16 |
17 |
18 | DEVICE = 'cuda'
19 |
20 | def load_image(imfile):
21 | img = np.array(Image.open(imfile)).astype(np.uint8)
22 | img = torch.from_numpy(img).permute(2, 0, 1).float()
23 | return img[None].to(DEVICE)
24 |
25 |
26 | def viz(img, flo):
27 | img = img[0].permute(1,2,0).cpu().numpy()
28 | flo = flo[0].permute(1,2,0).cpu().numpy()
29 |
30 | # map flow to rgb image
31 | flo = flow_viz.flow_to_image(flo)
32 | img_flo = np.concatenate([img, flo], axis=0)
33 |
34 | # import matplotlib.pyplot as plt
35 | # plt.imshow(img_flo / 255.0)
36 | # plt.show()
37 |
38 | cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
39 | cv2.waitKey()
40 |
41 |
42 | def demo(args):
43 | model = torch.nn.DataParallel(RAFT(args))
44 | model.load_state_dict(torch.load(args.model))
45 |
46 | model = model.module
47 | model.to(DEVICE)
48 | model.eval()
49 |
50 | with torch.no_grad():
51 | images = glob.glob(os.path.join(args.path, '*.png')) + \
52 | glob.glob(os.path.join(args.path, '*.jpg'))
53 |
54 | images = sorted(images)
55 | for imfile1, imfile2 in zip(images[:-1], images[1:]):
56 | image1 = load_image(imfile1)
57 | image2 = load_image(imfile2)
58 |
59 | padder = InputPadder(image1.shape)
60 | image1, image2 = padder.pad(image1, image2)
61 |
62 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
63 | viz(image1, flow_up)
64 |
65 |
66 | if __name__ == '__main__':
67 | parser = argparse.ArgumentParser()
68 | parser.add_argument('--model', help="restore checkpoint")
69 | parser.add_argument('--path', help="dataset for evaluation")
70 | parser.add_argument('--small', action='store_true', help='use small model')
71 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
72 | parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
73 | args = parser.parse_args()
74 |
75 | demo(args)
76 |
--------------------------------------------------------------------------------
/RAFT/download_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget https://dl.dropboxusercontent.com/s/4j4z58wuv8o0mfz/models.zip
3 | unzip models.zip
4 |
--------------------------------------------------------------------------------
/RAFT/evaluate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('core')
3 |
4 | from PIL import Image
5 | import argparse
6 | import os
7 | import time
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | import matplotlib.pyplot as plt
12 |
13 | import datasets
14 | from utils import flow_viz
15 | from utils import frame_utils
16 |
17 | from raft import RAFT
18 | from utils.utils import InputPadder, forward_interpolate
19 |
20 |
21 | @torch.no_grad()
22 | def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):
23 | """ Create submission for the Sintel leaderboard """
24 | model.eval()
25 | for dstype in ['clean', 'final']:
26 | test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)
27 |
28 | flow_prev, sequence_prev = None, None
29 | for test_id in range(len(test_dataset)):
30 | image1, image2, (sequence, frame) = test_dataset[test_id]
31 | if sequence != sequence_prev:
32 | flow_prev = None
33 |
34 | padder = InputPadder(image1.shape)
35 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
36 |
37 | flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)
38 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
39 |
40 | if warm_start:
41 | flow_prev = forward_interpolate(flow_low[0])[None].cuda()
42 |
43 | output_dir = os.path.join(output_path, dstype, sequence)
44 | output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))
45 |
46 | if not os.path.exists(output_dir):
47 | os.makedirs(output_dir)
48 |
49 | frame_utils.writeFlow(output_file, flow)
50 | sequence_prev = sequence
51 |
52 |
53 | @torch.no_grad()
54 | def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
55 | """ Create submission for the Sintel leaderboard """
56 | model.eval()
57 | test_dataset = datasets.KITTI(split='testing', aug_params=None)
58 |
59 | if not os.path.exists(output_path):
60 | os.makedirs(output_path)
61 |
62 | for test_id in range(len(test_dataset)):
63 | image1, image2, (frame_id, ) = test_dataset[test_id]
64 | padder = InputPadder(image1.shape, mode='kitti')
65 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
66 |
67 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
68 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
69 |
70 | output_filename = os.path.join(output_path, frame_id)
71 | frame_utils.writeFlowKITTI(output_filename, flow)
72 |
73 |
74 | @torch.no_grad()
75 | def validate_chairs(model, iters=24):
76 | """ Perform evaluation on the FlyingChairs (test) split """
77 | model.eval()
78 | epe_list = []
79 |
80 | val_dataset = datasets.FlyingChairs(split='validation')
81 | for val_id in range(len(val_dataset)):
82 | image1, image2, flow_gt, _ = val_dataset[val_id]
83 | image1 = image1[None].cuda()
84 | image2 = image2[None].cuda()
85 |
86 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
87 | epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt()
88 | epe_list.append(epe.view(-1).numpy())
89 |
90 | epe = np.mean(np.concatenate(epe_list))
91 | print("Validation Chairs EPE: %f" % epe)
92 | return {'chairs': epe}
93 |
94 |
95 | @torch.no_grad()
96 | def validate_sintel(model, iters=32):
97 | """ Peform validation using the Sintel (train) split """
98 | model.eval()
99 | results = {}
100 | for dstype in ['clean', 'final']:
101 | val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
102 | epe_list = []
103 |
104 | for val_id in range(len(val_dataset)):
105 | image1, image2, flow_gt, _ = val_dataset[val_id]
106 | image1 = image1[None].cuda()
107 | image2 = image2[None].cuda()
108 |
109 | padder = InputPadder(image1.shape)
110 | image1, image2 = padder.pad(image1, image2)
111 |
112 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
113 | flow = padder.unpad(flow_pr[0]).cpu()
114 |
115 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
116 | epe_list.append(epe.view(-1).numpy())
117 |
118 | epe_all = np.concatenate(epe_list)
119 | epe = np.mean(epe_all)
120 | px1 = np.mean(epe_all<1)
121 | px3 = np.mean(epe_all<3)
122 | px5 = np.mean(epe_all<5)
123 |
124 | print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
125 | results[dstype] = np.mean(epe_list)
126 |
127 | return results
128 |
129 |
130 | @torch.no_grad()
131 | def validate_kitti(model, iters=24):
132 | """ Peform validation using the KITTI-2015 (train) split """
133 | model.eval()
134 | val_dataset = datasets.KITTI(split='training')
135 |
136 | out_list, epe_list = [], []
137 | for val_id in range(len(val_dataset)):
138 | image1, image2, flow_gt, valid_gt = val_dataset[val_id]
139 | image1 = image1[None].cuda()
140 | image2 = image2[None].cuda()
141 |
142 | padder = InputPadder(image1.shape, mode='kitti')
143 | image1, image2 = padder.pad(image1, image2)
144 |
145 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
146 | flow = padder.unpad(flow_pr[0]).cpu()
147 |
148 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
149 | mag = torch.sum(flow_gt**2, dim=0).sqrt()
150 |
151 | epe = epe.view(-1)
152 | mag = mag.view(-1)
153 | val = valid_gt.view(-1) >= 0.5
154 |
155 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
156 | epe_list.append(epe[val].mean().item())
157 | out_list.append(out[val].cpu().numpy())
158 |
159 | epe_list = np.array(epe_list)
160 | out_list = np.concatenate(out_list)
161 |
162 | epe = np.mean(epe_list)
163 | f1 = 100 * np.mean(out_list)
164 |
165 | print("Validation KITTI: %f, %f" % (epe, f1))
166 | return {'kitti-epe': epe, 'kitti-f1': f1}
167 |
168 |
169 | if __name__ == '__main__':
170 | parser = argparse.ArgumentParser()
171 | parser.add_argument('--model', help="restore checkpoint")
172 | parser.add_argument('--dataset', help="dataset for evaluation")
173 | parser.add_argument('--small', action='store_true', help='use small model')
174 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
175 |     parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
176 | args = parser.parse_args()
177 |
178 | model = torch.nn.DataParallel(RAFT(args))
179 | model.load_state_dict(torch.load(args.model))
180 |
181 | model.cuda()
182 | model.eval()
183 |
184 | # create_sintel_submission(model.module, warm_start=True)
185 | # create_kitti_submission(model.module)
186 |
187 | with torch.no_grad():
188 | if args.dataset == 'chairs':
189 | validate_chairs(model.module)
190 |
191 | elif args.dataset == 'sintel':
192 | validate_sintel(model.module)
193 |
194 | elif args.dataset == 'kitti':
195 | validate_kitti(model.module)
196 |
197 |
198 |
--------------------------------------------------------------------------------
/RAFT/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import sys
3 | sys.path.append('core')
4 |
5 | import argparse
6 | import os
7 | import cv2
8 | import time
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.optim as optim
15 | import torch.nn.functional as F
16 |
17 | from torch.utils.data import DataLoader
18 | from raft import RAFT
19 | import evaluate
20 | import datasets
21 |
22 | from torch.utils.tensorboard import SummaryWriter
23 |
24 | try:
25 | from torch.cuda.amp import GradScaler
26 | except:
27 | # dummy GradScaler for PyTorch < 1.6
28 | class GradScaler:
29 |         def __init__(self, enabled=False):
30 | pass
31 | def scale(self, loss):
32 | return loss
33 | def unscale_(self, optimizer):
34 | pass
35 | def step(self, optimizer):
36 | optimizer.step()
37 | def update(self):
38 | pass
39 |
40 |
41 | # exclude extremely large displacements
42 | MAX_FLOW = 400
43 | SUM_FREQ = 100
44 | VAL_FREQ = 5000
45 |
46 |
47 | def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
48 | """ Loss function defined over sequence of flow predictions """
49 |
50 | n_predictions = len(flow_preds)
51 | flow_loss = 0.0
52 |
53 |     # exclude invalid pixels and extremely large displacements
54 | mag = torch.sum(flow_gt**2, dim=1).sqrt()
55 | valid = (valid >= 0.5) & (mag < max_flow)
56 |
57 | for i in range(n_predictions):
58 | i_weight = gamma**(n_predictions - i - 1)
59 | i_loss = (flow_preds[i] - flow_gt).abs()
60 | flow_loss += i_weight * (valid[:, None] * i_loss).mean()
61 |
62 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
63 | epe = epe.view(-1)[valid.view(-1)]
64 |
65 | metrics = {
66 | 'epe': epe.mean().item(),
67 | '1px': (epe < 1).float().mean().item(),
68 | '3px': (epe < 3).float().mean().item(),
69 | '5px': (epe < 5).float().mean().item(),
70 | }
71 |
72 | return flow_loss, metrics
73 |
74 |
75 | def count_parameters(model):
76 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
77 |
78 |
79 | def fetch_optimizer(args, model):
80 | """ Create the optimizer and learning rate scheduler """
81 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
82 |
83 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
84 | pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')
85 |
86 | return optimizer, scheduler
87 |
88 |
89 | class Logger:
90 | def __init__(self, model, scheduler):
91 | self.model = model
92 | self.scheduler = scheduler
93 | self.total_steps = 0
94 | self.running_loss = {}
95 | self.writer = None
96 |
97 | def _print_training_status(self):
98 | metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())]
99 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
100 | metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)
101 |
102 | # print the training status
103 | print(training_str + metrics_str)
104 |
105 | if self.writer is None:
106 | self.writer = SummaryWriter()
107 |
108 | for k in self.running_loss:
109 | self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps)
110 | self.running_loss[k] = 0.0
111 |
112 | def push(self, metrics):
113 | self.total_steps += 1
114 |
115 | for key in metrics:
116 | if key not in self.running_loss:
117 | self.running_loss[key] = 0.0
118 |
119 | self.running_loss[key] += metrics[key]
120 |
121 | if self.total_steps % SUM_FREQ == SUM_FREQ-1:
122 | self._print_training_status()
123 | self.running_loss = {}
124 |
125 | def write_dict(self, results):
126 | if self.writer is None:
127 | self.writer = SummaryWriter()
128 |
129 | for key in results:
130 | self.writer.add_scalar(key, results[key], self.total_steps)
131 |
132 | def close(self):
133 | self.writer.close()
134 |
135 |
136 | def train(args):
137 |
138 | model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
139 | print("Parameter Count: %d" % count_parameters(model))
140 |
141 | if args.restore_ckpt is not None:
142 | model.load_state_dict(torch.load(args.restore_ckpt), strict=False)
143 |
144 | model.cuda()
145 | model.train()
146 |
147 | if args.stage != 'chairs':
148 | model.module.freeze_bn()
149 |
150 | train_loader = datasets.fetch_dataloader(args)
151 | optimizer, scheduler = fetch_optimizer(args, model)
152 |
153 | total_steps = 0
154 | scaler = GradScaler(enabled=args.mixed_precision)
155 | logger = Logger(model, scheduler)
156 |
157 | VAL_FREQ = 5000
158 | add_noise = True
159 |
160 | should_keep_training = True
161 | while should_keep_training:
162 |
163 | for i_batch, data_blob in enumerate(train_loader):
164 | optimizer.zero_grad()
165 | image1, image2, flow, valid = [x.cuda() for x in data_blob]
166 |
167 | if args.add_noise:
168 | stdv = np.random.uniform(0.0, 5.0)
169 | image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
170 | image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)
171 |
172 | flow_predictions = model(image1, image2, iters=args.iters)
173 |
174 | loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
175 | scaler.scale(loss).backward()
176 | scaler.unscale_(optimizer)
177 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
178 |
179 | scaler.step(optimizer)
180 | scheduler.step()
181 | scaler.update()
182 |
183 | logger.push(metrics)
184 |
185 | if total_steps % VAL_FREQ == VAL_FREQ - 1:
186 | PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
187 | torch.save(model.state_dict(), PATH)
188 |
189 | results = {}
190 | for val_dataset in args.validation:
191 | if val_dataset == 'chairs':
192 | results.update(evaluate.validate_chairs(model.module))
193 | elif val_dataset == 'sintel':
194 | results.update(evaluate.validate_sintel(model.module))
195 | elif val_dataset == 'kitti':
196 | results.update(evaluate.validate_kitti(model.module))
197 |
198 | logger.write_dict(results)
199 |
200 | model.train()
201 | if args.stage != 'chairs':
202 | model.module.freeze_bn()
203 |
204 | total_steps += 1
205 |
206 | if total_steps > args.num_steps:
207 | should_keep_training = False
208 | break
209 |
210 | logger.close()
211 | PATH = 'checkpoints/%s.pth' % args.name
212 | torch.save(model.state_dict(), PATH)
213 |
214 | return PATH
215 |
216 |
217 | if __name__ == '__main__':
218 | parser = argparse.ArgumentParser()
219 | parser.add_argument('--name', default='raft', help="name your experiment")
220 | parser.add_argument('--stage', help="determines which dataset to use for training")
221 | parser.add_argument('--restore_ckpt', help="restore checkpoint")
222 |     parser.add_argument('--data_root', type=str, help="root directory of the training data")
223 | parser.add_argument('--small', action='store_true', help='use small model')
224 | parser.add_argument('--validation', type=str, nargs='+')
225 |
226 | parser.add_argument('--lr', type=float, default=0.00002)
227 | parser.add_argument('--num_steps', type=int, default=100000)
228 | parser.add_argument('--batch_size', type=int, default=6)
229 | parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
230 | parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
231 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
232 |
233 | parser.add_argument('--iters', type=int, default=12)
234 | parser.add_argument('--wdecay', type=float, default=.00005)
235 | parser.add_argument('--epsilon', type=float, default=1e-8)
236 | parser.add_argument('--clip', type=float, default=1.0)
237 | parser.add_argument('--dropout', type=float, default=0.0)
238 | parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
239 | parser.add_argument('--add_noise', action='store_true')
240 | args = parser.parse_args()
241 |
242 | torch.manual_seed(1234)
243 | np.random.seed(1234)
244 |
245 | if not os.path.isdir('checkpoints'):
246 | os.mkdir('checkpoints')
247 |
248 | train(args)
--------------------------------------------------------------------------------
/RAFT/train_mixed.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir -p checkpoints
3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision
4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision
5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision
6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision
7 |
--------------------------------------------------------------------------------
/RAFT/train_standard.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | CUDA_VISIBLE_DEVICES=3,5 python -u train.py --name mpi-0.1-0.25-rf \
3 | --stage mpi-flow --validation kitti \
4 | --restore_ckpt weights/raft-things.pth \
5 | --gpus 0 1 --num_steps 50000 --batch_size 6 \
6 | --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 \
7 | --data_root /data1/liangyingping/MPI-Flow/dataset/debug
--------------------------------------------------------------------------------
/RAFT/weights/raft-things.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/RAFT/weights/raft-things.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [ICCV 2023] MPI-Flow: Learning Realistic Optical Flow with Multiplane Images
2 |
3 | [Paper](https://arxiv.org/abs/2309.06714) | [Checkpoints](https://drive.google.com/drive/folders/1q0UxlswSwZjLgLkEjUNmBuVi0LJfY_b7?usp=sharing) | [Project Page](https://sites.google.com/view/mpi-flow) | [My Home Page](https://sharpiless.github.io/)
4 |
5 | ## Update
6 | - **2024.05.01** - Update large-scale dataset generation [scripts](scripts).
7 | - **2023.12.18** - Code for online training released at [Sharpiless/Train-RAFT-from-single-view-images](https://github.com/Sharpiless/Train-RAFT-from-single-view-images).
8 | - **2023.09.13** - Code released.
9 |
10 | # MPI-Flow
11 |
12 | This is a PyTorch implementation of our paper.
13 |
14 | **Abstract**: *The accuracy of learning-based optical flow estimation models heavily relies on the realism of the training datasets. Current approaches for generating such datasets either employ synthetic data or generate images with limited realism. However, the domain gap of these data with real-world scenes constrains the generalization of the trained model to real-world applications. To address this issue, we investigate generating realistic optical flow datasets from real-world images. Firstly, to generate highly realistic new images, we construct a layered depth representation, known as multiplane images (MPI), from single-view images. This allows us to generate novel view images that are highly realistic. To generate optical flow maps that correspond accurately to the new image, we calculate the optical flows of each plane using the camera matrix and plane depths. We then project these layered optical flows into the output optical flow map with volume rendering. Secondly, to ensure the realism of motion, we present an independent object motion module that can separate the camera and dynamic object motion in MPI. This module addresses the deficiency in MPI-based single-view methods, where optical flow is generated only by camera motion and does not account for any object movement. We additionally devise a depth-aware inpainting module to merge new images with dynamic objects and address unnatural motion occlusions. We show the superior performance of our method through extensive experiments on real-world datasets. Moreover, our approach achieves state-of-the-art performance in both unsupervised and supervised training of learning-based models.*
15 |
16 | # Document for *MPI-Flow*
17 | ## Environment
18 | ```
19 | conda create -n mpiflow python=3.8
20 |
21 | # here we use pytorch 1.11.0 and CUDA 11.3 for an example
22 |
23 | # install pytorch
24 | pip install https://download.pytorch.org/whl/cu113/torch-1.11.0%2Bcu113-cp38-cp38-linux_x86_64.whl
25 |
26 | # install torchvision
27 | pip install https://download.pytorch.org/whl/cu113/torchvision-0.12.0%2Bcu113-cp38-cp38-linux_x86_64.whl
28 |
29 | # install pytorch3d
30 | conda install https://anaconda.org/pytorch3d/pytorch3d/0.6.2/download/linux-64/pytorch3d-0.6.2-py38_cu113_pyt1100.tar.bz2
31 |
32 | # install other libs
33 | pip install \
34 | numpy==1.19 \
35 | scikit-image==0.19.1 \
36 | scipy==1.8.0 \
37 | pillow==9.0.1 \
38 | opencv-python==4.4.0.40 \
39 | tqdm==4.64.0 \
40 | moviepy==1.0.3 \
41 | pyyaml \
42 | matplotlib \
43 | scikit-learn \
44 | lpips \
45 | kornia \
46 | focal_frequency_loss \
47 | tensorboard \
48 | transformers
49 |
50 | cd external/forward_warping
51 | bash compile.sh
52 | cd ../..
53 | ```
54 |
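
After installing, an optional sanity check (not part of the original instructions) is to confirm that PyTorch sees your GPU and that `pytorch3d` imports cleanly:

```
python -c "import torch, pytorch3d; print(torch.__version__, torch.cuda.is_available())"
```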
55 | ## Usage
56 |
57 | The input to our MPI-Flow is a single in-the-wild image with its monocular depth estimation and main object mask.
58 | You can use the [MiDaS](https://github.com/isl-org/MiDaS) model to obtain the estimated depth map and use the [Mask2Former](https://github.com/facebookresearch/Mask2Former) to obtain the object mask.
59 |
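If you need to prepare your own inputs, the sketch below shows one way to obtain a depth map with MiDaS via `torch.hub`. It is not part of this repository, the image path is hypothetical, and the final normalization/storage format should be matched to the example depth maps shipped with the repo:

```
import cv2
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load a MiDaS model and its matching input transform from torch.hub
midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large").to(device).eval()
transform = torch.hub.load("intel-isl/MiDaS", "transforms").dpt_transform

# Read a single in-the-wild image (hypothetical path) and convert BGR -> RGB
img = cv2.cvtColor(cv2.imread("images_kitti/example.png"), cv2.COLOR_BGR2RGB)

with torch.no_grad():
    prediction = midas(transform(img).to(device))
    # Resize the prediction back to the original image resolution
    prediction = torch.nn.functional.interpolate(
        prediction.unsqueeze(1), size=img.shape[:2],
        mode="bicubic", align_corners=False).squeeze()

# MiDaS predicts relative inverse depth; normalize/convert it to match
# the format of the example depth maps before feeding it to MPI-Flow
depth = prediction.cpu().numpy()
```
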
60 | We provide some example inputs in `./images_kitti`; you can use the images, depths, and masks there to test our model.
61 | Here is an example to run the code:
62 |
63 | ```
64 | python gen_3dphoto_dynamic.py
65 | ```
66 |
67 | Then, you will see a result like this:
68 |
69 |
70 | ## Training online
71 |
72 | We have also released an online training version at [https://github.com/Sharpiless/Train-RAFT-from-single-view-images](https://github.com/Sharpiless/Train-RAFT-from-single-view-images).
73 |
74 | ## Performance (Online Training, single V100 GPU)
75 | 32k steps on COCO:
76 | | Dataset | EPE | F1 |
77 | | :-------: | :--------: | :-----: |
78 | | KITTI-15 (train) | 3.537468 | 11.694042 |
79 | | Sintel.C | 1.857986 | - |
80 | | Sintel.F | 3.250774 | - |
81 |
82 | 320k steps on COCO:
83 | | Dataset | EPE | F1 |
84 | | :-------: | :--------: | :-----: |
85 | | KITTI-15 (train) | 3.586417 | 9.887916 |
86 | | Sintel.C | - | - |
87 | | Sintel.F | - | - |
88 |
89 | ## Checkpoints
90 |
91 | | Image Source | Method | KITTI 12 EPE ↓ | KITTI 12 F1 ↓ | KITTI 15 EPE ↓ | KITTI 15 F1 ↓ |
92 | |--------------|--------------------|----------------|---------------|----------------|---------------|
94 | | COCO | Depthstillation [1]| 1.74 | 6.81| 3.45 | 13.08|
95 | | | RealFlow [12] | N/A | N/A | N/A | N/A |
96 | | | MPI-Flow (ours) | 1.36 | 4.91| 3.44 | 10.66|
97 | | DAVIS | Depthstillation [1]| 1.81 | 6.89| 3.79 | 13.22|
98 | | | RealFlow [12] | 1.59 | 6.08| 3.55 | 12.52|
99 | | | MPI-Flow (ours) | 1.41 | 5.36| 3.32 | 10.47|
100 | | KITTI 15 Test| Depthstillation [1]| 1.77 | 5.97| 3.99 | 13.34|
101 | | | RealFlow [12] | 1.27 | 5.16| 2.43 | 8.86 |
102 | | | MPI-Flow (ours) | 1.24 | 4.51| 2.16 | 7.30 |
103 | | KITTI 15 Train| Depthstillation [1]| 1.67 | 5.71| 2.99 | 9.94|
104 | | | RealFlow [12] | 1.25 | 5.02| 2.17 | 8.64|
105 | | | MPI-Flow (ours) | 1.26 | 4.66| 1.88 | 7.16|
106 |
107 |
108 | Checkpoints to reproduce our results in Table 1 can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1q0UxlswSwZjLgLkEjUNmBuVi0LJfY_b7?usp=sharing).
109 |
110 | You can use the code in [RAFT](https://github.com/princeton-vl/RAFT) to evaluate/train the models.
111 |
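For example, to evaluate a downloaded checkpoint on the KITTI training split with the `RAFT/evaluate.py` script in this repository (the checkpoint filename below is hypothetical, and the KITTI data must be available where the RAFT dataloader expects it):

```
cd RAFT
python evaluate.py --model ../checkpoints/MPI-Flow-kitti.pth --dataset kitti --mixed_precision
```
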
112 | ## Contact
113 | If you have any questions, please contact Yingping Liang (liangyingping@bit.edu.cn).
114 |
115 | ## License and Citation
116 | This repository can only be used for personal/research/non-commercial purposes.
117 | Please cite the following paper if this model helps your research:
118 |
119 | @inproceedings{liang2023mpi,
120 |     author = {Liang, Yingping and Liu, Jiaming and Zhang, Debing and Fu, Ying},
121 | title = {MPI-Flow: Learning Realistic Optical Flow with Multiplane Images},
122 |     booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
123 | year={2023}
124 | }
125 |
126 | ## Acknowledgments
127 | * The code is heavily borrowed from [AdaMPI](https://github.com/yxuhan/AdaMPI); we thank the authors for their great effort.
128 |
--------------------------------------------------------------------------------
/adampiweight/adampi_64p.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/adampiweight/adampi_64p.pth
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/core/__init__.py
--------------------------------------------------------------------------------
/core/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from utils.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 | import alt_cuda_corr
7 | except:
8 | # alt_cuda_corr is not compiled
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels-1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
38 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
40 |
41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 |
45 | corr = bilinear_sampler(corr, coords_lvl)
46 | corr = corr.view(batch, h1, w1, -1)
47 | out_pyramid.append(corr)
48 |
49 | out = torch.cat(out_pyramid, dim=-1)
50 | return out.permute(0, 3, 1, 2).contiguous().float()
51 |
52 | @staticmethod
53 | def corr(fmap1, fmap2):
54 | batch, dim, ht, wd = fmap1.shape
55 | fmap1 = fmap1.view(batch, dim, ht*wd)
56 | fmap2 = fmap2.view(batch, dim, ht*wd)
57 |
58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59 | corr = corr.view(batch, ht, wd, 1, ht, wd)
60 | return corr / torch.sqrt(torch.tensor(dim).float())
61 |
62 |
63 | class AlternateCorrBlock:
64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
65 | self.num_levels = num_levels
66 | self.radius = radius
67 |
68 | self.pyramid = [(fmap1, fmap2)]
69 | for i in range(self.num_levels):
70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
72 | self.pyramid.append((fmap1, fmap2))
73 |
74 | def __call__(self, coords):
75 | coords = coords.permute(0, 2, 3, 1)
76 | B, H, W, _ = coords.shape
77 | dim = self.pyramid[0][0].shape[1]
78 |
79 | corr_list = []
80 | for i in range(self.num_levels):
81 | r = self.radius
82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
84 |
85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
87 | corr_list.append(corr.squeeze(1))
88 |
89 | corr = torch.stack(corr_list, dim=1)
90 | corr = corr.reshape(B, -1, H, W)
91 | return corr / torch.sqrt(torch.tensor(dim).float())
92 |
--------------------------------------------------------------------------------
/core/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8 | super(ResidualBlock, self).__init__()
9 |
10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12 | self.relu = nn.ReLU(inplace=True)
13 |
14 | num_groups = planes // 8
15 |
16 | if norm_fn == 'group':
17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | if not stride == 1:
20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21 |
22 | elif norm_fn == 'batch':
23 | self.norm1 = nn.BatchNorm2d(planes)
24 | self.norm2 = nn.BatchNorm2d(planes)
25 | if not stride == 1:
26 | self.norm3 = nn.BatchNorm2d(planes)
27 |
28 | elif norm_fn == 'instance':
29 | self.norm1 = nn.InstanceNorm2d(planes)
30 | self.norm2 = nn.InstanceNorm2d(planes)
31 | if not stride == 1:
32 | self.norm3 = nn.InstanceNorm2d(planes)
33 |
34 | elif norm_fn == 'none':
35 | self.norm1 = nn.Sequential()
36 | self.norm2 = nn.Sequential()
37 | if not stride == 1:
38 | self.norm3 = nn.Sequential()
39 |
40 | if stride == 1:
41 | self.downsample = None
42 |
43 | else:
44 | self.downsample = nn.Sequential(
45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46 |
47 |
48 | def forward(self, x):
49 | y = x
50 | y = self.relu(self.norm1(self.conv1(y)))
51 | y = self.relu(self.norm2(self.conv2(y)))
52 |
53 | if self.downsample is not None:
54 | x = self.downsample(x)
55 |
56 | return self.relu(x+y)
57 |
58 |
59 |
60 | class BottleneckBlock(nn.Module):
61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
62 | super(BottleneckBlock, self).__init__()
63 |
64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
67 | self.relu = nn.ReLU(inplace=True)
68 |
69 | num_groups = planes // 8
70 |
71 | if norm_fn == 'group':
72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
75 | if not stride == 1:
76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
77 |
78 | elif norm_fn == 'batch':
79 | self.norm1 = nn.BatchNorm2d(planes//4)
80 | self.norm2 = nn.BatchNorm2d(planes//4)
81 | self.norm3 = nn.BatchNorm2d(planes)
82 | if not stride == 1:
83 | self.norm4 = nn.BatchNorm2d(planes)
84 |
85 | elif norm_fn == 'instance':
86 | self.norm1 = nn.InstanceNorm2d(planes//4)
87 | self.norm2 = nn.InstanceNorm2d(planes//4)
88 | self.norm3 = nn.InstanceNorm2d(planes)
89 | if not stride == 1:
90 | self.norm4 = nn.InstanceNorm2d(planes)
91 |
92 | elif norm_fn == 'none':
93 | self.norm1 = nn.Sequential()
94 | self.norm2 = nn.Sequential()
95 | self.norm3 = nn.Sequential()
96 | if not stride == 1:
97 | self.norm4 = nn.Sequential()
98 |
99 | if stride == 1:
100 | self.downsample = None
101 |
102 | else:
103 | self.downsample = nn.Sequential(
104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
105 |
106 |
107 | def forward(self, x):
108 | y = x
109 | y = self.relu(self.norm1(self.conv1(y)))
110 | y = self.relu(self.norm2(self.conv2(y)))
111 | y = self.relu(self.norm3(self.conv3(y)))
112 |
113 | if self.downsample is not None:
114 | x = self.downsample(x)
115 |
116 | return self.relu(x+y)
117 |
118 | class BasicEncoder(nn.Module):
119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
120 | super(BasicEncoder, self).__init__()
121 | self.norm_fn = norm_fn
122 |
123 | if self.norm_fn == 'group':
124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
125 |
126 | elif self.norm_fn == 'batch':
127 | self.norm1 = nn.BatchNorm2d(64)
128 |
129 | elif self.norm_fn == 'instance':
130 | self.norm1 = nn.InstanceNorm2d(64)
131 |
132 | elif self.norm_fn == 'none':
133 | self.norm1 = nn.Sequential()
134 |
135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
136 | self.relu1 = nn.ReLU(inplace=True)
137 |
138 | self.in_planes = 64
139 | self.layer1 = self._make_layer(64, stride=1)
140 | self.layer2 = self._make_layer(96, stride=2)
141 | self.layer3 = self._make_layer(128, stride=2)
142 |
143 | # output convolution
144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
145 |
146 | self.dropout = None
147 | if dropout > 0:
148 | self.dropout = nn.Dropout2d(p=dropout)
149 |
150 | for m in self.modules():
151 | if isinstance(m, nn.Conv2d):
152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
154 | if m.weight is not None:
155 | nn.init.constant_(m.weight, 1)
156 | if m.bias is not None:
157 | nn.init.constant_(m.bias, 0)
158 |
159 | def _make_layer(self, dim, stride=1):
160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
162 | layers = (layer1, layer2)
163 |
164 | self.in_planes = dim
165 | return nn.Sequential(*layers)
166 |
167 |
168 | def forward(self, x):
169 |
170 | # if input is list, combine batch dimension
171 | is_list = isinstance(x, tuple) or isinstance(x, list)
172 | if is_list:
173 | batch_dim = x[0].shape[0]
174 | x = torch.cat(x, dim=0)
175 |
176 | x = self.conv1(x)
177 | x = self.norm1(x)
178 | x = self.relu1(x)
179 |
180 | x = self.layer1(x)
181 | x = self.layer2(x)
182 | x = self.layer3(x)
183 |
184 | x = self.conv2(x)
185 |
186 | if self.training and self.dropout is not None:
187 | x = self.dropout(x)
188 |
189 | if is_list:
190 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
191 |
192 | return x
193 |
194 |
195 | class SmallEncoder(nn.Module):
196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
197 | super(SmallEncoder, self).__init__()
198 | self.norm_fn = norm_fn
199 |
200 | if self.norm_fn == 'group':
201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
202 |
203 | elif self.norm_fn == 'batch':
204 | self.norm1 = nn.BatchNorm2d(32)
205 |
206 | elif self.norm_fn == 'instance':
207 | self.norm1 = nn.InstanceNorm2d(32)
208 |
209 | elif self.norm_fn == 'none':
210 | self.norm1 = nn.Sequential()
211 |
212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
213 | self.relu1 = nn.ReLU(inplace=True)
214 |
215 | self.in_planes = 32
216 | self.layer1 = self._make_layer(32, stride=1)
217 | self.layer2 = self._make_layer(64, stride=2)
218 | self.layer3 = self._make_layer(96, stride=2)
219 |
220 | self.dropout = None
221 | if dropout > 0:
222 | self.dropout = nn.Dropout2d(p=dropout)
223 |
224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
225 |
226 | for m in self.modules():
227 | if isinstance(m, nn.Conv2d):
228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
230 | if m.weight is not None:
231 | nn.init.constant_(m.weight, 1)
232 | if m.bias is not None:
233 | nn.init.constant_(m.bias, 0)
234 |
235 | def _make_layer(self, dim, stride=1):
236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
238 | layers = (layer1, layer2)
239 |
240 | self.in_planes = dim
241 | return nn.Sequential(*layers)
242 |
243 |
244 | def forward(self, x):
245 |
246 | # if input is list, combine batch dimension
247 | is_list = isinstance(x, tuple) or isinstance(x, list)
248 | if is_list:
249 | batch_dim = x[0].shape[0]
250 | x = torch.cat(x, dim=0)
251 |
252 | x = self.conv1(x)
253 | x = self.norm1(x)
254 | x = self.relu1(x)
255 |
256 | x = self.layer1(x)
257 | x = self.layer2(x)
258 | x = self.layer3(x)
259 | x = self.conv2(x)
260 |
261 | if self.training and self.dropout is not None:
262 | x = self.dropout(x)
263 |
264 | if is_list:
265 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
266 |
267 | return x
268 |
--------------------------------------------------------------------------------
/core/raft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from update import BasicUpdateBlock, SmallUpdateBlock
7 | from extractor import BasicEncoder, SmallEncoder
8 | from corr import CorrBlock, AlternateCorrBlock
9 | from utils.utils import bilinear_sampler, coords_grid, upflow8
10 |
11 | try:
12 | autocast = torch.cuda.amp.autocast
13 | except:
14 | # dummy autocast for PyTorch < 1.6
15 | class autocast:
16 | def __init__(self, enabled):
17 | pass
18 | def __enter__(self):
19 | pass
20 | def __exit__(self, *args):
21 | pass
22 |
23 |
24 | class RAFT(nn.Module):
25 | def __init__(self, args):
26 | super(RAFT, self).__init__()
27 | self.args = args
28 |
29 | if args.small:
30 | self.hidden_dim = hdim = 96
31 | self.context_dim = cdim = 64
32 | args.corr_levels = 4
33 | args.corr_radius = 3
34 |
35 | else:
36 | self.hidden_dim = hdim = 128
37 | self.context_dim = cdim = 128
38 | args.corr_levels = 4
39 | args.corr_radius = 4
40 |
41 | if 'dropout' not in self.args:
42 | self.args.dropout = 0
43 |
44 | if 'alternate_corr' not in self.args:
45 | self.args.alternate_corr = False
46 |
47 | # feature network, context network, and update block
48 | if args.small:
49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
52 |
53 | else:
54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
57 |
58 | def freeze_bn(self):
59 | for m in self.modules():
60 | if isinstance(m, nn.BatchNorm2d):
61 | m.eval()
62 |
63 | def initialize_flow(self, img):
64 |         """ Flow is represented as the difference between two coordinate grids: flow = coords1 - coords0 """
65 | N, C, H, W = img.shape
66 | coords0 = coords_grid(N, H//8, W//8, device=img.device)
67 | coords1 = coords_grid(N, H//8, W//8, device=img.device)
68 |
69 | # optical flow computed as difference: flow = coords1 - coords0
70 | return coords0, coords1
71 |
72 | def upsample_flow(self, flow, mask):
73 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
74 | N, _, H, W = flow.shape
75 | mask = mask.view(N, 1, 9, 8, 8, H, W)
76 | mask = torch.softmax(mask, dim=2)
77 |
78 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
79 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
80 |
81 | up_flow = torch.sum(mask * up_flow, dim=2)
82 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
83 | return up_flow.reshape(N, 2, 8*H, 8*W)
84 |
85 |
86 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
87 | """ Estimate optical flow between pair of frames """
88 |
89 | image1 = 2 * (image1 / 255.0) - 1.0
90 | image2 = 2 * (image2 / 255.0) - 1.0
91 |
92 | image1 = image1.contiguous()
93 | image2 = image2.contiguous()
94 |
95 | hdim = self.hidden_dim
96 | cdim = self.context_dim
97 |
98 | # run the feature network
99 | with autocast(enabled=self.args.mixed_precision):
100 | fmap1, fmap2 = self.fnet([image1, image2])
101 |
102 | fmap1 = fmap1.float()
103 | fmap2 = fmap2.float()
104 | if self.args.alternate_corr:
105 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
106 | else:
107 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
108 |
109 | # run the context network
110 | with autocast(enabled=self.args.mixed_precision):
111 | cnet = self.cnet(image1)
112 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
113 | net = torch.tanh(net)
114 | inp = torch.relu(inp)
115 |
116 | coords0, coords1 = self.initialize_flow(image1)
117 |
118 | if flow_init is not None:
119 | coords1 = coords1 + flow_init
120 |
121 | flow_predictions = []
122 | for itr in range(iters):
123 | coords1 = coords1.detach()
124 | corr = corr_fn(coords1) # index correlation volume
125 |
126 | flow = coords1 - coords0
127 | with autocast(enabled=self.args.mixed_precision):
128 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
129 |
130 | # F(t+1) = F(t) + \Delta(t)
131 | coords1 = coords1 + delta_flow
132 |
133 | # upsample predictions
134 | if up_mask is None:
135 | flow_up = upflow8(coords1 - coords0)
136 | else:
137 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
138 |
139 | flow_predictions.append(flow_up)
140 |
141 | if test_mode:
142 | return coords1 - coords0, flow_up
143 |
144 | return flow_predictions
145 |
--------------------------------------------------------------------------------
/core/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class ConvGRU(nn.Module):
17 | def __init__(self, hidden_dim=128, input_dim=192+128):
18 | super(ConvGRU, self).__init__()
19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
22 |
23 | def forward(self, h, x):
24 | hx = torch.cat([h, x], dim=1)
25 |
26 | z = torch.sigmoid(self.convz(hx))
27 | r = torch.sigmoid(self.convr(hx))
28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
29 |
30 | h = (1-z) * h + z * q
31 | return h
32 |
33 | class SepConvGRU(nn.Module):
34 | def __init__(self, hidden_dim=128, input_dim=192+128):
35 | super(SepConvGRU, self).__init__()
36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39 |
40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43 |
44 |
45 | def forward(self, h, x):
46 | # horizontal
47 | hx = torch.cat([h, x], dim=1)
48 | z = torch.sigmoid(self.convz1(hx))
49 | r = torch.sigmoid(self.convr1(hx))
50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
51 | h = (1-z) * h + z * q
52 |
53 | # vertical
54 | hx = torch.cat([h, x], dim=1)
55 | z = torch.sigmoid(self.convz2(hx))
56 | r = torch.sigmoid(self.convr2(hx))
57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
58 | h = (1-z) * h + z * q
59 |
60 | return h
61 |
62 | class SmallMotionEncoder(nn.Module):
63 | def __init__(self, args):
64 | super(SmallMotionEncoder, self).__init__()
65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
69 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
70 |
71 | def forward(self, flow, corr):
72 | cor = F.relu(self.convc1(corr))
73 | flo = F.relu(self.convf1(flow))
74 | flo = F.relu(self.convf2(flo))
75 | cor_flo = torch.cat([cor, flo], dim=1)
76 | out = F.relu(self.conv(cor_flo))
77 | return torch.cat([out, flow], dim=1)
78 |
79 | class BasicMotionEncoder(nn.Module):
80 | def __init__(self, args):
81 | super(BasicMotionEncoder, self).__init__()
82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
88 |
89 | def forward(self, flow, corr):
90 | cor = F.relu(self.convc1(corr))
91 | cor = F.relu(self.convc2(cor))
92 | flo = F.relu(self.convf1(flow))
93 | flo = F.relu(self.convf2(flo))
94 |
95 | cor_flo = torch.cat([cor, flo], dim=1)
96 | out = F.relu(self.conv(cor_flo))
97 | return torch.cat([out, flow], dim=1)
98 |
99 | class SmallUpdateBlock(nn.Module):
100 | def __init__(self, args, hidden_dim=96):
101 | super(SmallUpdateBlock, self).__init__()
102 | self.encoder = SmallMotionEncoder(args)
103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
105 |
106 | def forward(self, net, inp, corr, flow):
107 | motion_features = self.encoder(flow, corr)
108 | inp = torch.cat([inp, motion_features], dim=1)
109 | net = self.gru(net, inp)
110 | delta_flow = self.flow_head(net)
111 |
112 | return net, None, delta_flow
113 |
114 | class BasicUpdateBlock(nn.Module):
115 | def __init__(self, args, hidden_dim=128, input_dim=128):
116 | super(BasicUpdateBlock, self).__init__()
117 | self.args = args
118 | self.encoder = BasicMotionEncoder(args)
119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
121 |
122 | self.mask = nn.Sequential(
123 | nn.Conv2d(128, 256, 3, padding=1),
124 | nn.ReLU(inplace=True),
125 | nn.Conv2d(256, 64*9, 1, padding=0))
126 |
127 | def forward(self, net, inp, corr, flow, upsample=True):
128 | motion_features = self.encoder(flow, corr)
129 | inp = torch.cat([inp, motion_features], dim=1)
130 |
131 | net = self.gru(net, inp)
132 | delta_flow = self.flow_head(net)
133 |
134 |         # scale mask to balance gradients
135 | mask = .25 * self.mask(net)
136 | return net, mask, delta_flow
137 |
138 |
139 |
140 |
--------------------------------------------------------------------------------
/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/core/utils/__init__.py
--------------------------------------------------------------------------------
/core/utils/augmentor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import math
4 | from PIL import Image
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | import torch
11 | from torchvision.transforms import ColorJitter
12 | import torch.nn.functional as F
13 |
14 |
15 | class FlowAugmentor:
16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
17 |
18 | # spatial augmentation params
19 | self.crop_size = crop_size
20 | self.min_scale = min_scale
21 | self.max_scale = max_scale
22 | self.spatial_aug_prob = 0.8
23 | self.stretch_prob = 0.8
24 | self.max_stretch = 0.2
25 |
26 | # flip augmentation params
27 | self.do_flip = do_flip
28 | self.h_flip_prob = 0.5
29 | self.v_flip_prob = 0.1
30 |
31 | # photometric augmentation params
32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
33 | self.asymmetric_color_aug_prob = 0.2
34 | self.eraser_aug_prob = 0.5
35 |
36 | def color_transform(self, img1, img2):
37 | """ Photometric augmentation """
38 |
39 | # asymmetric
40 | if np.random.rand() < self.asymmetric_color_aug_prob:
41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
43 |
44 | # symmetric
45 | else:
46 | image_stack = np.concatenate([img1, img2], axis=0)
47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
48 | img1, img2 = np.split(image_stack, 2, axis=0)
49 |
50 | return img1, img2
51 |
52 | def eraser_transform(self, img1, img2, bounds=[50, 100]):
53 | """ Occlusion augmentation """
54 |
55 | ht, wd = img1.shape[:2]
56 | if np.random.rand() < self.eraser_aug_prob:
57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
58 | for _ in range(np.random.randint(1, 3)):
59 | x0 = np.random.randint(0, wd)
60 | y0 = np.random.randint(0, ht)
61 | dx = np.random.randint(bounds[0], bounds[1])
62 | dy = np.random.randint(bounds[0], bounds[1])
63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
64 |
65 | return img1, img2
66 |
67 | def spatial_transform(self, img1, img2, flow):
68 | # randomly sample scale
69 | ht, wd = img1.shape[:2]
70 | min_scale = np.maximum(
71 | (self.crop_size[0] + 8) / float(ht),
72 | (self.crop_size[1] + 8) / float(wd))
73 |
74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
75 | scale_x = scale
76 | scale_y = scale
77 | if np.random.rand() < self.stretch_prob:
78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
80 |
81 | scale_x = np.clip(scale_x, min_scale, None)
82 | scale_y = np.clip(scale_y, min_scale, None)
83 |
84 | if np.random.rand() < self.spatial_aug_prob:
85 | # rescale the images
86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
89 | flow = flow * [scale_x, scale_y]
90 |
91 | if self.do_flip:
92 | if np.random.rand() < self.h_flip_prob: # h-flip
93 | img1 = img1[:, ::-1]
94 | img2 = img2[:, ::-1]
95 | flow = flow[:, ::-1] * [-1.0, 1.0]
96 |
97 | if np.random.rand() < self.v_flip_prob: # v-flip
98 | img1 = img1[::-1, :]
99 | img2 = img2[::-1, :]
100 | flow = flow[::-1, :] * [1.0, -1.0]
101 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
102 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
103 |
104 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
105 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
106 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
107 |
108 | return img1, img2, flow
109 |
110 | def __call__(self, img1, img2, flow):
111 | img1, img2 = self.color_transform(img1, img2)
112 | img1, img2 = self.eraser_transform(img1, img2)
113 | img1, img2, flow = self.spatial_transform(img1, img2, flow)
114 |
115 | img1 = np.ascontiguousarray(img1)
116 | img2 = np.ascontiguousarray(img2)
117 | flow = np.ascontiguousarray(flow)
118 |
119 | return img1, img2, flow
120 |
121 | class SparseFlowAugmentor:
122 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
123 | # spatial augmentation params
124 | self.crop_size = crop_size
125 | self.min_scale = min_scale
126 | self.max_scale = max_scale
127 | self.spatial_aug_prob = 0.8
128 | self.stretch_prob = 0.8
129 | self.max_stretch = 0.2
130 |
131 | # flip augmentation params
132 | self.do_flip = do_flip
133 | self.h_flip_prob = 0.5
134 | self.v_flip_prob = 0.1
135 |
136 | # photometric augmentation params
137 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
138 | self.asymmetric_color_aug_prob = 0.2
139 | self.eraser_aug_prob = 0.5
140 |
141 | def color_transform(self, img1, img2):
142 | image_stack = np.concatenate([img1, img2], axis=0)
143 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
144 | img1, img2 = np.split(image_stack, 2, axis=0)
145 | return img1, img2
146 |
147 | def eraser_transform(self, img1, img2):
148 | ht, wd = img1.shape[:2]
149 | if np.random.rand() < self.eraser_aug_prob:
150 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
151 | for _ in range(np.random.randint(1, 3)):
152 | x0 = np.random.randint(0, wd)
153 | y0 = np.random.randint(0, ht)
154 | dx = np.random.randint(50, 100)
155 | dy = np.random.randint(50, 100)
156 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
157 |
158 | return img1, img2
159 |
160 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
161 | ht, wd = flow.shape[:2]
162 | coords = np.meshgrid(np.arange(wd), np.arange(ht))
163 | coords = np.stack(coords, axis=-1)
164 |
165 | coords = coords.reshape(-1, 2).astype(np.float32)
166 | flow = flow.reshape(-1, 2).astype(np.float32)
167 | valid = valid.reshape(-1).astype(np.float32)
168 |
169 | coords0 = coords[valid>=1]
170 | flow0 = flow[valid>=1]
171 |
172 | ht1 = int(round(ht * fy))
173 | wd1 = int(round(wd * fx))
174 |
175 | coords1 = coords0 * [fx, fy]
176 | flow1 = flow0 * [fx, fy]
177 |
178 | xx = np.round(coords1[:,0]).astype(np.int32)
179 | yy = np.round(coords1[:,1]).astype(np.int32)
180 |
181 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
182 | xx = xx[v]
183 | yy = yy[v]
184 | flow1 = flow1[v]
185 |
186 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
187 | valid_img = np.zeros([ht1, wd1], dtype=np.int32)
188 |
189 | flow_img[yy, xx] = flow1
190 | valid_img[yy, xx] = 1
191 |
192 | return flow_img, valid_img
193 |
194 | def spatial_transform(self, img1, img2, flow, valid):
195 | # randomly sample scale
196 |
197 | ht, wd = img1.shape[:2]
198 | min_scale = np.maximum(
199 | (self.crop_size[0] + 1) / float(ht),
200 | (self.crop_size[1] + 1) / float(wd))
201 |
202 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
203 | scale_x = np.clip(scale, min_scale, None)
204 | scale_y = np.clip(scale, min_scale, None)
205 |
206 | if np.random.rand() < self.spatial_aug_prob:
207 | # rescale the images
208 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
209 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
210 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
211 |
212 | if self.do_flip:
213 | if np.random.rand() < 0.5: # h-flip
214 | img1 = img1[:, ::-1]
215 | img2 = img2[:, ::-1]
216 | flow = flow[:, ::-1] * [-1.0, 1.0]
217 | valid = valid[:, ::-1]
218 |
219 | margin_y = 20
220 | margin_x = 50
221 |
222 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
223 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
224 |
225 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
226 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
227 |
228 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
229 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
230 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
231 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
232 | return img1, img2, flow, valid
233 |
234 |
235 | def __call__(self, img1, img2, flow, valid):
236 | img1, img2 = self.color_transform(img1, img2)
237 | img1, img2 = self.eraser_transform(img1, img2)
238 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
239 |
240 | img1 = np.ascontiguousarray(img1)
241 | img2 = np.ascontiguousarray(img2)
242 | flow = np.ascontiguousarray(flow)
243 | valid = np.ascontiguousarray(valid)
244 |
245 | return img1, img2, flow, valid
246 |
--------------------------------------------------------------------------------
/core/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 |     Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 |     Expects a two-dimensional flow image of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
--------------------------------------------------------------------------------
/core/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | TAG_CHAR = np.array([202021.25], np.float32)
11 |
12 | def readFlow(fn):
13 | """ Read .flo file in Middlebury format"""
14 | # Code adapted from:
15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16 |
17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18 | # print 'fn = %s'%(fn)
19 | with open(fn, 'rb') as f:
20 | magic = np.fromfile(f, np.float32, count=1)
21 | if 202021.25 != magic:
22 | print('Magic number incorrect. Invalid .flo file')
23 | return None
24 | else:
25 | w = np.fromfile(f, np.int32, count=1)
26 | h = np.fromfile(f, np.int32, count=1)
27 | # print 'Reading %d x %d flo file\n' % (w, h)
28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
29 | # Reshape data into 3D array (columns, rows, bands)
30 | # The reshape here is for visualization, the original code is (w,h,2)
31 | return np.resize(data, (int(h), int(w), 2))
32 |
33 | def readPFM(file):
34 | file = open(file, 'rb')
35 |
36 | color = None
37 | width = None
38 | height = None
39 | scale = None
40 | endian = None
41 |
42 | header = file.readline().rstrip()
43 | if header == b'PF':
44 | color = True
45 | elif header == b'Pf':
46 | color = False
47 | else:
48 | raise Exception('Not a PFM file.')
49 |
50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
51 | if dim_match:
52 | width, height = map(int, dim_match.groups())
53 | else:
54 | raise Exception('Malformed PFM header.')
55 |
56 | scale = float(file.readline().rstrip())
57 | if scale < 0: # little-endian
58 | endian = '<'
59 | scale = -scale
60 | else:
61 | endian = '>' # big-endian
62 |
63 | data = np.fromfile(file, endian + 'f')
64 | shape = (height, width, 3) if color else (height, width)
65 |
66 | data = np.reshape(data, shape)
67 | data = np.flipud(data)
68 | return data
69 |
70 | def writeFlow(filename,uv,v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert(uv.ndim == 3)
81 | assert(uv.shape[2] == 2)
82 | u = uv[:,:,0]
83 | v = uv[:,:,1]
84 | else:
85 | u = uv
86 |
87 | assert(u.shape == v.shape)
88 | height,width = u.shape
89 | f = open(filename,'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width*nBands))
96 | tmp[:,np.arange(width)*2] = u
97 | tmp[:,np.arange(width)*2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
104 | flow = flow[:,:,::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2**15) / 64.0
107 | return flow, valid
108 |
109 | def readDispKITTI(filename):
110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
111 | valid = disp > 0.0
112 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
113 | return flow, valid
114 |
115 |
116 | def writeFlowKITTI(filename, uv):
117 | uv = 64.0 * uv + 2**15
118 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
120 | cv2.imwrite(filename, uv[..., ::-1])
121 |
122 |
123 | def read_gen(file_name, pil=False):
124 | ext = splitext(file_name)[-1]
125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
126 | return Image.open(file_name)
127 | elif ext == '.bin' or ext == '.raw':
128 | return np.load(file_name)
129 | elif ext == '.flo':
130 | return readFlow(file_name).astype(np.float32)
131 | elif ext == '.pfm':
132 | flow = readPFM(file_name).astype(np.float32)
133 | if len(flow.shape) == 2:
134 | return flow
135 | else:
136 | return flow[:, :, :-1]
137 | return []
--------------------------------------------------------------------------------
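The helpers above follow the Middlebury .flo layout: a float32 magic tag (202021.25), two int32 values for width and height, then H*W interleaved (u, v) float32 pairs. A minimal round-trip sketch, not part of the repository; the import path and the temporary output path are assumptions:

import numpy as np
from core.utils.frame_utils import readFlow, writeFlow  # assumes core/ is on PYTHONPATH

flow = np.random.randn(375, 1242, 2).astype(np.float32)  # a KITTI-sized flow field
writeFlow('/tmp/example.flo', flow)                       # writes magic tag, width, height, interleaved u/v
restored = readFlow('/tmp/example.flo')                   # back to a [H,W,2] float32 array
assert restored.shape == flow.shape and np.allclose(restored, flow)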
/core/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 | def __init__(self, dims, mode='sintel'):
10 | self.ht, self.wd = dims[-2:]
11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13 | if mode == 'sintel':
14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15 | else:
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17 |
18 | def pad(self, *inputs):
19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
20 |
21 | def unpad(self,x):
22 | ht, wd = x.shape[-2:]
23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
24 | return x[..., c[0]:c[1], c[2]:c[3]]
25 |
26 | def forward_interpolate(flow):
27 | flow = flow.detach().cpu().numpy()
28 | dx, dy = flow[0], flow[1]
29 |
30 | ht, wd = dx.shape
31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
32 |
33 | x1 = x0 + dx
34 | y1 = y0 + dy
35 |
36 | x1 = x1.reshape(-1)
37 | y1 = y1.reshape(-1)
38 | dx = dx.reshape(-1)
39 | dy = dy.reshape(-1)
40 |
41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
42 | x1 = x1[valid]
43 | y1 = y1[valid]
44 | dx = dx[valid]
45 | dy = dy[valid]
46 |
47 | flow_x = interpolate.griddata(
48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
49 |
50 | flow_y = interpolate.griddata(
51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
52 |
53 | flow = np.stack([flow_x, flow_y], axis=0)
54 | return torch.from_numpy(flow).float()
55 |
56 |
57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
58 | """ Wrapper for grid_sample, uses pixel coordinates """
59 | H, W = img.shape[-2:]
60 | xgrid, ygrid = coords.split([1,1], dim=-1)
61 | xgrid = 2*xgrid/(W-1) - 1
62 | ygrid = 2*ygrid/(H-1) - 1
63 |
64 | grid = torch.cat([xgrid, ygrid], dim=-1)
65 | img = F.grid_sample(img, grid, align_corners=True)
66 |
67 | if mask:
68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
69 | return img, mask.float()
70 |
71 | return img
72 |
73 |
74 | def coords_grid(batch, ht, wd, device):
75 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
76 | coords = torch.stack(coords[::-1], dim=0).float()
77 | return coords[None].repeat(batch, 1, 1, 1)
78 |
79 |
80 | def upflow8(flow, mode='bilinear'):
81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
83 |
--------------------------------------------------------------------------------
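InputPadder above pads height and width up to the next multiple of 8 (RAFT's feature pyramid runs at 1/8 resolution), and unpad crops a result back to the original size. A hedged usage sketch with Sintel-sized frames, assuming core/utils is importable as core.utils:

import torch
from core.utils.utils import InputPadder

img1 = torch.randn(1, 3, 436, 1024)   # Sintel frames: 436 rows, not divisible by 8
img2 = torch.randn(1, 3, 436, 1024)
padder = InputPadder(img1.shape)      # default 'sintel' mode pads rows symmetrically
img1_p, img2_p = padder.pad(img1, img2)
assert img1_p.shape[-2] % 8 == 0 and img1_p.shape[-1] % 8 == 0   # 440 x 1024
flow_p = torch.randn(1, 2, *img1_p.shape[-2:])   # stand-in for a network output
flow = padder.unpad(flow_p)                      # cropped back to 436 x 1024
assert flow.shape[-2:] == (436, 1024)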
/external/forward_warping/compile.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | gcc -fPIC -shared -o libwarping.so warping.c
--------------------------------------------------------------------------------
/external/forward_warping/libwarping.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/external/forward_warping/libwarping.so
--------------------------------------------------------------------------------
/external/forward_warping/warping.c:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 | #define valid(X, Y, W) (Y*W*5+X*5+3)
4 | #define collision(X, Y, W) (Y*W*5+X*5+4)
5 |
6 | void forward_warping(const void *src, const void *idx, const void *idy, const void *z, void *warped, int h, int w)
7 | {
8 | float *dlut = (float *)calloc(h * w, sizeof(float));
9 | for (int i = 0; i < h; i++)
10 | for (int j = 0; j < w; j++)
11 | dlut[i * w + j] = 1000;
12 |
13 | for (int i = 0; i < h; i++)
14 | for (int j = 0; j < w; j++)
15 | {
16 | int x = ((long *)idx)[i * w + j];
17 | int y = ((long *)idy)[i * w + j];
18 |
19 | if (((float *)z)[i * w + j] < dlut[y * w + x])
20 | for (int c = 0; c < 3; c++)
21 | ((unsigned char *)warped)[y * w * 5 + x * 5 + c] = ((unsigned char *)src)[i * w * 3 + j * 3 + c];
22 |
23 | ((unsigned char *)warped)[valid(x,y,w)] = 1;
24 | if (dlut[y * w + x] != 1000)
25 | ((unsigned char *)warped)[collision(x,y,w)] = 0;
26 | else
27 | ((unsigned char *)warped)[collision(x,y,w)] = 1;
28 | dlut[y * w + x] = ((float *)z)[i * w + j];
29 | }
30 |
31 | free(dlut);
32 | return;
33 | }
34 |
--------------------------------------------------------------------------------
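forward_warping above scatters every source pixel to its integer target coordinates, keeps only the closest pixel per target via the dlut depth buffer, and records a validity flag (channel 3) and a first-hit flag (channel 4) in an HxWx5 byte buffer. A pure-numpy re-implementation of the same loop, offered only as a readability aid (moving_obj.py shows the actual ctypes call into libwarping.so):

import numpy as np

def forward_warping_np(src, idx, idy, z):
    # src: HxWx3 uint8, idx/idy: HxW integer target coordinates, z: HxW float depth
    h, w = z.shape
    warped = np.zeros((h, w, 5), np.uint8)
    dlut = np.full((h, w), 1000.0, np.float32)          # depth buffer, 1000 = empty
    for i in range(h):
        for j in range(w):
            x, y = idx[i, j], idy[i, j]
            if z[i, j] < dlut[y, x]:                    # keep the closest source pixel
                warped[y, x, :3] = src[i, j]
            warped[y, x, 3] = 1                                  # validity flag
            warped[y, x, 4] = 1 if dlut[y, x] == 1000.0 else 0   # 1 on first hit, 0 on collision
            dlut[y, x] = z[i, j]
    return warped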
/flow_colors.py:
--------------------------------------------------------------------------------
1 | #
2 | # Utility functions for coloring optical flow maps
3 | #
4 |
5 | import os
6 | import numpy as np
7 | import sys
8 | import re
9 | import cv2
10 | import torch
11 | import torch.nn.functional as F
12 |
13 | def make_colorwheel():
14 | """
15 | Generates a color wheel for optical flow visualization as presented in:
16 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
17 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
18 | Code follows the original C++ source code of Daniel Scharstein.
19 |     Code follows the Matlab source code of Deqing Sun.
20 | Returns:
21 | np.ndarray: Color wheel
22 | """
23 |
24 | RY = 15
25 | YG = 6
26 | GC = 4
27 | CB = 11
28 | BM = 13
29 | MR = 6
30 |
31 | ncols = RY + YG + GC + CB + BM + MR
32 | colorwheel = np.zeros((ncols, 3))
33 | col = 0
34 |
35 | # RY
36 | colorwheel[0:RY, 0] = 255
37 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
38 | col = col + RY
39 | # YG
40 | colorwheel[col: col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
41 | colorwheel[col: col + YG, 1] = 255
42 | col = col + YG
43 | # GC
44 | colorwheel[col: col + GC, 1] = 255
45 | colorwheel[col: col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
46 | col = col + GC
47 | # CB
48 | colorwheel[col: col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
49 | colorwheel[col: col + CB, 2] = 255
50 | col = col + CB
51 | # BM
52 | colorwheel[col: col + BM, 2] = 255
53 | colorwheel[col: col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
54 | col = col + BM
55 | # MR
56 | colorwheel[col: col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
57 | colorwheel[col: col + MR, 0] = 255
58 | return colorwheel
59 |
60 |
61 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
62 | """
63 | Applies the flow color wheel to (possibly clipped) flow components u and v.
64 | According to the C++ source code of Daniel Scharstein
65 | According to the Matlab source code of Deqing Sun
66 | Args:
67 | u (np.ndarray): Input horizontal flow of shape [H,W]
68 | v (np.ndarray): Input vertical flow of shape [H,W]
69 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
70 | Returns:
71 | np.ndarray: Flow visualization image of shape [H,W,3]
72 | """
73 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
74 | colorwheel = make_colorwheel() # shape [55x3]
75 | ncols = colorwheel.shape[0]
76 | rad = np.sqrt(np.square(u) + np.square(v))
77 | a = np.arctan2(-v, -u) / np.pi
78 | fk = (a + 1) / 2 * (ncols - 1)
79 | k0 = np.floor(fk).astype(np.int32)
80 | k1 = k0 + 1
81 | k1[k1 == ncols] = 0
82 | f = fk - k0
83 | for i in range(colorwheel.shape[1]):
84 | tmp = colorwheel[:, i]
85 | col0 = tmp[k0] / 255.0
86 | col1 = tmp[k1] / 255.0
87 | col = (1 - f) * col0 + f * col1
88 | idx = rad <= 1
89 | col[idx] = 1 - rad[idx] * (1 - col[idx])
90 | col[~idx] = col[~idx] * 0.75 # out of range
91 | # Note the 2-i => BGR instead of RGB
92 | ch_idx = 2 - i if convert_to_bgr else i
93 | flow_image[:, :, ch_idx] = np.floor(255 * col)
94 | return flow_image
95 |
96 |
97 | def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
98 | """
99 |     Expects a two dimensional flow image of shape [H,W,2].
100 | Args:
101 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
102 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
103 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
104 | Returns:
105 | np.ndarray: Flow visualization image of shape [H,W,3]
106 | """
107 | assert flow_uv.ndim == 3, "input flow must have three dimensions"
108 | assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]"
109 | if clip_flow is not None:
110 | flow_uv = np.clip(flow_uv, 0, clip_flow)
111 | u = flow_uv[:, :, 0]
112 | v = flow_uv[:, :, 1]
113 | rad = np.sqrt(np.square(u) + np.square(v))
114 | rad_max = np.max(rad)
115 | epsilon = 1e-5
116 | u = u / (rad_max + epsilon)
117 | v = v / (rad_max + epsilon)
118 | return flow_uv_to_colors(u, v, convert_to_bgr)
119 |
--------------------------------------------------------------------------------
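flow_to_color above normalizes the flow by its largest magnitude, then flow_uv_to_colors maps flow angle to a hue on the 55-entry color wheel and magnitude to saturation. A hedged usage sketch, assuming flow_colors.py is importable from the repository root; flow_vis.png is just an example output path:

import numpy as np
import cv2
from flow_colors import flow_to_color

flow = np.zeros((256, 256, 2), np.float32)
flow[..., 0] = 10.0                               # constant rightward motion -> a single hue
vis = flow_to_color(flow, convert_to_bgr=True)    # HxWx3 uint8, BGR for cv2
cv2.imwrite('flow_vis.png', vis)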
/gen_3dphoto_dynamic_v2.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn.functional as F
4 | import os
5 | import cv2
6 | from tqdm import tqdm
7 | from torchvision.utils import save_image
8 | from write_flow import writeFlow
9 |
10 | from utils.utils import (
11 | image_to_tensor,
12 | disparity_to_tensor,
13 | render_3dphoto_dynamic,
14 | )
15 | from model.AdaMPI import MPIPredictor
16 | from random import seed
17 | import numpy as np
18 | from PIL import Image
19 |
20 | parser = argparse.ArgumentParser(
21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
22 | parser.add_argument('--width', type=int, default=1280)
23 | parser.add_argument('--height', type=int, default=384)
24 | parser.add_argument('--seed', type=int, default=114514)
25 | parser.add_argument('--ext_cz', type=float, default=0.15)
26 | parser.add_argument('--ckpt_path', type=str,
27 | default='adampiweight/adampi_64p.pth')
28 | parser.add_argument('--repeat', type=int, default=5)
29 | parser.add_argument('--base', type=str,
30 | default='', required=True)
31 | parser.add_argument('--out', type=str,
32 | default='', required=True)
33 |
34 | opt, _ = parser.parse_known_args()
35 |
36 | print(opt)
37 |
38 | seed(opt.seed)
39 | np.random.seed(opt.seed)
40 |
41 | # render 3D photo
42 | K = torch.tensor([
43 | [0.58, 0, 0.5],
44 | [0, 0.58, 0.5],
45 | [0, 0, 1]
46 | ]).cuda().half()
47 | K[0, :] *= opt.width
48 | K[1, :] *= opt.height
49 | K = K.unsqueeze(0)
50 |
51 | # load pretrained model
52 | ckpt = torch.load(opt.ckpt_path)
53 | model = MPIPredictor(
54 | width=opt.width,
55 | height=opt.height,
56 | num_planes=ckpt['num_planes'],
57 | )
58 | model.load_state_dict(ckpt['weight'])
59 | model = model.cuda().half()
60 | model.eval()
61 | # model = torch.jit.script(model)
62 |
63 | out = opt.out
64 | base = opt.base
65 |
66 | if not os.path.exists(out):
67 | os.mkdir(out)
68 | os.mkdir(f"{out}/src_images")
69 | os.mkdir(f"{out}/dst_images")
70 | os.mkdir(f"{out}/flows")
71 | os.mkdir(f"{out}/obj_mask")
72 |
73 |
74 | img_base = os.path.join(base, "images")
75 | disp_base = os.path.join(base, "disps")
76 | mask_base = os.path.join(base, "masks")
77 |
78 | for img in tqdm(sorted(os.listdir(img_base))):
79 |
80 | name = img.split(".")[0]
81 |
82 | image = image_to_tensor(os.path.join(img_base, img)).cuda().half() # [1,3,h,w]
83 | obj_mask_np = np.array(Image.open(os.path.join(mask_base, img)).convert("L"))
84 | disp = disparity_to_tensor(os.path.join(disp_base, img)).cuda().half() # [1,1,h,w]
85 |
86 | image = F.interpolate(image, size=(opt.height, opt.width),
87 | mode='bilinear', align_corners=True)
88 | disp = F.interpolate(disp, size=(opt.height, opt.width),
89 | mode='bilinear', align_corners=True)
90 |
91 | # disp.requires_grad = True
92 | with torch.no_grad():
93 | mpi_all_src, disparity_all_src = model(image, disp) # [b,s,4,h,w]
94 |
95 | # import IPython
96 | # IPython.embed()
97 | # exit()
98 |
99 | for r in range(opt.repeat):
100 | # predict MPI planes
101 | obj_index = np.random.randint(obj_mask_np.max()) + 1
102 | # print(obj_mask_np.max(), obj_index)
103 | obj_mask = torch.FloatTensor(obj_mask_np == obj_index).cuda().half().unsqueeze(0).unsqueeze(0) # [1,3,h,w]
104 | obj_mask = F.interpolate(obj_mask, size=(opt.height, opt.width),
105 | mode='bilinear', align_corners=True)
106 |
107 | flow_mix, src_np, inpainted, res = render_3dphoto_dynamic(
108 | opt,
109 | image,
110 | obj_mask,
111 | disp,
112 | mpi_all_src,
113 | disparity_all_src,
114 | K,
115 | K,
116 | data_path='outputs',
117 | name='demo'
118 | )
119 |
120 | writeFlow(os.path.join(out, "flows", f'{name}_{r}.flo'), flow_mix)
121 | cv2.imwrite(os.path.join(out, "dst_images", f'{name}_{r}.png'), inpainted)
122 | cv2.imwrite(os.path.join(out, "src_images", f'{name}_{r}.png'), src_np)
--------------------------------------------------------------------------------
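The intrinsics in the script above are given in normalized form and scaled to the rendering resolution, so the same matrix works for any --width/--height. A small CPU-only sketch of that scaling using the script's default resolution (the .cuda().half() casts are omitted here):

import torch

width, height = 1280, 384                       # argparse defaults from the script above
K = torch.tensor([[0.58, 0.0, 0.5],
                  [0.0, 0.58, 0.5],
                  [0.0, 0.0, 1.0]])
K[0, :] *= width                                # fx = 742.4, cx = 640
K[1, :] *= height                               # fy ~= 222.7, cy = 192
K = K.unsqueeze(0)                              # [1,3,3] batched intrinsics
print(K)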
/geometry.py:
--------------------------------------------------------------------------------
1 | #
2 | # Classes and functions in this script are taken from https://github.com/nianticlabs/monodepth2
3 | # Use conditions available in the LICENSE file at https://github.com/nianticlabs/monodepth2/blob/master/LICENSE
4 | # Copyright © Niantic, Inc. 2018. All rights reserved.
5 |
6 | from __future__ import absolute_import, division, print_function
7 |
8 | import numpy as np
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | __all__ = ["BackprojectDepth", "Project3D", "transformation_from_parameters"]
15 |
16 |
17 | class BackprojectDepth(nn.Module):
18 | """Layer to transform a depth image into a point cloud
19 | """
20 | def __init__(self, batch_size, height, width):
21 | super(BackprojectDepth, self).__init__()
22 |
23 | self.batch_size = batch_size
24 | self.height = height
25 | self.width = width
26 |
27 | meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
28 | self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
29 | self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
30 | requires_grad=False)
31 |
32 | self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
33 | requires_grad=False)
34 |
35 | self.pix_coords = torch.unsqueeze(torch.stack(
36 | [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
37 | self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
38 | self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
39 | requires_grad=False)
40 |
41 | def forward(self, depth, inv_K):
42 | cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
43 | cam_points = depth.view(self.batch_size, 1, -1) * cam_points
44 | cam_points = torch.cat([cam_points, self.ones], 1)
45 | # import IPython
46 | # IPython.embed()
47 | # exit()
48 |
49 | return cam_points
50 |
51 |
52 | class Project3D(nn.Module):
53 | """Layer which projects 3D points into a camera with intrinsics K and at position T
54 | """
55 | def __init__(self, batch_size, height, width, eps=1e-7):
56 | super(Project3D, self).__init__()
57 |
58 | self.batch_size = batch_size
59 | self.height = height
60 | self.width = width
61 | self.eps = eps
62 |
63 | def forward(self, points, K, T, T2=None):
64 | if not T2 is None:
65 | T = torch.matmul(T, torch.inverse(T2))
66 | P = torch.matmul(K, T)[:, :3, :]
67 |
68 | cam_points = torch.matmul(P, points)
69 |
70 | pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
71 | pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
72 | pix_coords = pix_coords.permute(0, 2, 3, 1)
73 | pix_coords[..., 0] /= self.width - 1
74 | pix_coords[..., 1] /= self.height - 1
75 | pix_coords = (pix_coords - 0.5) * 2
76 | return pix_coords, cam_points[:, 2, :].unsqueeze(1)
77 |
78 |
79 | def transformation_from_parameters(axisangle, translation, invert=False):
80 | """Convert the network's (axisangle, translation) output into a 4x4 matrix
81 | """
82 | R = rot_from_axisangle(axisangle)
83 | t = translation.clone()
84 |
85 | if invert:
86 | R = R.transpose(1, 2)
87 | t *= -1
88 |
89 | T = get_translation_matrix(t)
90 |
91 | if invert:
92 | M = torch.matmul(R, T)
93 | else:
94 | M = torch.matmul(T, R)
95 | return M
96 |
97 |
98 | def get_translation_matrix(translation_vector):
99 | """Convert a translation vector into a 4x4 transformation matrix
100 | """
101 | T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device)
102 |
103 | t = translation_vector.contiguous().view(-1, 3, 1)
104 |
105 | T[:, 0, 0] = 1
106 | T[:, 1, 1] = 1
107 | T[:, 2, 2] = 1
108 | T[:, 3, 3] = 1
109 | T[:, :3, 3, None] = t
110 |
111 | return T
112 |
113 |
114 | def rot_from_axisangle(vec):
115 | """Convert an axisangle rotation into a 4x4 transformation matrix
116 | (adapted from https://github.com/Wallacoloo/printipi)
117 | Input 'vec' has to be Bx1x3
118 | """
119 | angle = torch.norm(vec, 2, 2, True)
120 | axis = vec / (angle + 1e-7)
121 |
122 | ca = torch.cos(angle)
123 | sa = torch.sin(angle)
124 | C = 1 - ca
125 |
126 | x = axis[..., 0].unsqueeze(1)
127 | y = axis[..., 1].unsqueeze(1)
128 | z = axis[..., 2].unsqueeze(1)
129 |
130 | xs = x * sa
131 | ys = y * sa
132 | zs = z * sa
133 | xC = x * C
134 | yC = y * C
135 | zC = z * C
136 | xyC = x * yC
137 | yzC = y * zC
138 | zxC = z * xC
139 |
140 | rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)
141 |
142 | rot[:, 0, 0] = torch.squeeze(x * xC + ca)
143 | rot[:, 0, 1] = torch.squeeze(xyC - zs)
144 | rot[:, 0, 2] = torch.squeeze(zxC + ys)
145 | rot[:, 1, 0] = torch.squeeze(xyC + zs)
146 | rot[:, 1, 1] = torch.squeeze(y * yC + ca)
147 | rot[:, 1, 2] = torch.squeeze(yzC - xs)
148 | rot[:, 2, 0] = torch.squeeze(zxC - ys)
149 | rot[:, 2, 1] = torch.squeeze(yzC + xs)
150 | rot[:, 2, 2] = torch.squeeze(z * zC + ca)
151 | rot[:, 3, 3] = 1
152 |
153 | return rot
154 |
--------------------------------------------------------------------------------
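Together these layers implement the depth-based warp used throughout the repository: back-project pixels with the inverse intrinsics and a depth map, apply a camera motion (R|t), re-project, and read optical flow off as p1 - p0. A CPU-only sketch with assumed toy intrinsics and a flat depth map; moving_obj.py runs the same steps on the GPU with real inputs:

import numpy as np
import torch
from geometry import BackprojectDepth, Project3D, transformation_from_parameters

b, h, w = 1, 96, 128
K = torch.eye(4)[None]                          # 4x4 intrinsics, toy values
K[:, 0, 0], K[:, 1, 1] = 100.0, 100.0           # fx, fy
K[:, 0, 2], K[:, 1, 2] = w / 2, h / 2           # cx, cy
inv_K = torch.inverse(K)

depth = torch.full((b, h, w), 5.0)              # flat scene five units away
axisangle = torch.zeros(b, 1, 3)
translation = torch.tensor([[[0.1, 0.0, 0.0]]]) # small sideways camera motion
T = transformation_from_parameters(axisangle, translation)

cam_points = BackprojectDepth(b, h, w)(depth, inv_K)   # [b,4,h*w] homogeneous points
p1, z1 = Project3D(b, h, w)(cam_points, K, T)          # normalized coords in [-1,1]

p1 = (p1 + 1) / 2                               # back to pixel coordinates
p1[..., 0] *= w - 1
p1[..., 1] *= h - 1
p0 = np.stack(np.meshgrid(range(w), range(h), indexing='xy'), axis=-1)
flow = p1.numpy() - p0                          # [b,h,w,2] optical flow induced by the motion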
/misc/train_image_2_000000_00_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/MPI-Flow/5ca4894cb36d9ad1e99af7db908735b539f80a4e/misc/train_image_2_000000_00_1.png
--------------------------------------------------------------------------------
/model/AdaMPI.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class MPIPredictor(nn.Module):
7 | def __init__(
8 | self,
9 | width=384,
10 | height=256,
11 | num_planes=64,
12 | ):
13 | super(MPIPredictor, self).__init__()
14 | self.num_planes = num_planes
15 | disp_range = [0.001, 1]
16 | self.far, self.near = disp_range
17 |
18 | H_tgt, W_tgt = height, width
19 | ctx_spatial_scale = 4
20 | self.low_res_size = (int(H_tgt / ctx_spatial_scale), int(W_tgt / ctx_spatial_scale))
21 |
22 | # -----------------------
23 | # CPN Encoder
24 | # -----------------------
25 | from model.CPN.encoder import ResnetEncoder
26 | self.encoder = ResnetEncoder(num_layers=18)
27 |
28 | # -----------------------
29 | # CPN Feature Mask UNet
30 | # -----------------------
31 | from model.CPN.unet import FeatMaskNetwork
32 | self.fmn = FeatMaskNetwork()
33 |
34 | # -----------------------
35 | # PAN
36 | # -----------------------
37 | from model.PAN import DepthPredictionNetwork
38 | self.dpn = DepthPredictionNetwork(
39 | disp_range=disp_range,
40 | n_planes=num_planes,
41 | )
42 |
43 | # -----------------------
44 | # CPN Decoder
45 | # -----------------------
46 | from model.CPN.decoder import DepthDecoder
47 | num_ch_enc = self.encoder.num_ch_enc
48 | self.decoder = DepthDecoder(
49 | num_ch_enc=num_ch_enc,
50 | use_alpha=False,
51 | scales=range(4),
52 | use_skips=True,
53 | )
54 |
55 | def forward(
56 | self,
57 | src_imgs,
58 | src_depths,
59 | ):
60 | rgb_low_res = F.interpolate(src_imgs, size=self.low_res_size, mode='bilinear', align_corners=True)
61 | disp_low_res = F.interpolate(src_depths, size=self.low_res_size, mode='bilinear', align_corners=True)
62 |
63 | bs = src_imgs.shape[0]
64 | dpn_input_disparity = torch.linspace(
65 | self.near,
66 | self.far,
67 | self.num_planes + 2
68 | )[1:-1].to(src_imgs.dtype).to(src_imgs.device).unsqueeze(0).repeat(bs, 1)
69 |
70 | # render_disp = self.dpn(dpn_input_disparity, rgb_low_res, disp_low_res)
71 | render_disp = dpn_input_disparity
72 | feature_mask = self.fmn(src_imgs, src_depths, render_disp)
73 | # Encoder forward
74 | conv1_out, block1_out, block2_out, block3_out, block4_out = self.encoder(src_imgs, src_depths)
75 | enc_features = [conv1_out, block1_out, block2_out, block3_out, block4_out]
76 | # Decoder forward
77 | outputs = self.decoder(enc_features, feature_mask)
78 | return outputs[0], render_disp
79 |
--------------------------------------------------------------------------------
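A hedged loading and inference sketch for MPIPredictor, mirroring gen_3dphoto_dynamic_v2.py: the checkpoint keys 'num_planes' and 'weight' and the adampiweight path come from that script and the repository tree, the input resolution must be divisible by 32, and running this requires the downloaded AdaMPI checkpoint:

import torch
from model.AdaMPI import MPIPredictor

ckpt = torch.load('adampiweight/adampi_64p.pth', map_location='cpu')
model = MPIPredictor(width=384, height=256, num_planes=ckpt['num_planes'])
model.load_state_dict(ckpt['weight'])
model.eval()

image = torch.rand(1, 3, 256, 384)     # [b,3,h,w] RGB in [0,1]
disp = torch.rand(1, 1, 256, 384)      # [b,1,h,w] monocular disparity
with torch.no_grad():
    mpi_all_src, disparity_all_src = model(image, disp)   # [b,s,4,h,w] RGBA planes, [b,s] disparities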
/model/CPN/decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright Niantic 2019. Patent Pending. All rights reserved.
2 | #
3 | # This software is licensed under the terms of the Monodepth2 licence
4 | # which allows for non-commercial use only, the full terms of which are made
5 | # available in the LICENSE file.
6 |
7 |
8 | '''
9 | This code is borrowed heavily from MINE: https://github.com/vincentfung13/MINE
10 | '''
11 |
12 |
13 | import numpy as np
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 |
18 |
19 | def upsample(x):
20 | return F.interpolate(x, scale_factor=2, mode="nearest")
21 |
22 |
23 | class GatedConv(nn.Module):
24 | def __init__(self, in_channels, out_channels):
25 | super(GatedConv, self).__init__()
26 | self.pad = nn.ReflectionPad2d(1)
27 |
28 | self.conv2d = nn.Conv2d(in_channels, out_channels, 3)
29 | self.mask_conv2d = nn.Conv2d(in_channels, out_channels, 3)
30 |
31 | self.sigmoid = nn.Sigmoid()
32 |
33 | def forward(self, feat):
34 | feat = self.pad(feat)
35 | x = self.conv2d(feat)
36 | mask = self.mask_conv2d(feat)
37 | return x * self.sigmoid(mask)
38 |
39 |
40 | class GatedConvBlock(nn.Module):
41 | def __init__(self, in_channels, out_channels):
42 | super(GatedConvBlock, self).__init__()
43 | self.gated_conv = GatedConv(in_channels, out_channels)
44 | self.nonlin = nn.ELU(inplace=True)
45 | self.bn = nn.BatchNorm2d(out_channels)
46 |
47 | def forward(self, feat):
48 | x = self.gated_conv(feat)
49 | x = self.bn(x)
50 | x = self.nonlin(x)
51 | return x
52 |
53 |
54 | def conv(in_planes, out_planes, kernel_size, instancenorm=False):
55 | if instancenorm:
56 | m = nn.Sequential(
57 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
58 | stride=1, padding=(kernel_size - 1) // 2, bias=False),
59 | nn.InstanceNorm2d(out_planes),
60 | nn.LeakyReLU(0.1, inplace=True),
61 | )
62 | else:
63 | m = nn.Sequential(
64 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
65 | stride=1, padding=(kernel_size - 1) // 2, bias=False),
66 | nn.BatchNorm2d(out_planes),
67 | nn.LeakyReLU(0.1, inplace=True)
68 | )
69 | return m
70 |
71 |
72 | class DepthDecoder(nn.Module):
73 | def tuple_to_str(self, key_tuple):
74 | key_str = '-'.join(str(key_tuple))
75 | return key_str
76 |
77 | def __init__(self, num_ch_enc,
78 | use_alpha=False, scales=range(4), num_output_channels=4,
79 | use_skips=True, **kwargs):
80 | super(DepthDecoder, self).__init__()
81 |
82 | self.num_output_channels = num_output_channels
83 | self.use_skips = use_skips
84 | self.upsample_mode = 'nearest'
85 | self.scales = scales
86 | self.use_alpha = use_alpha
87 |
88 | final_enc_out_channels = num_ch_enc[-1]
89 | self.downsample = nn.MaxPool2d(3, stride=2, padding=1)
90 | self.upsample = nn.UpsamplingNearest2d(scale_factor=2)
91 | self.conv_down1 = conv(final_enc_out_channels, 512, 1, False)
92 | self.conv_down2 = conv(512, 256, 3, False)
93 | self.conv_up1 = conv(256, 256, 3, False)
94 | self.conv_up2 = conv(256, final_enc_out_channels, 1, False)
95 |
96 | self.num_ch_enc = num_ch_enc
97 | # print("num_ch_enc=", num_ch_enc)
98 | self.num_ch_enc = [x + 2 for x in self.num_ch_enc]
99 | self.num_ch_dec = np.array([12, 24, 48, 96, 192])
100 | # self.num_ch_enc = np.array([64, 64, 128, 256, 512])
101 |
102 | # decoder
103 | self.convs = nn.ModuleDict()
104 | for i in range(4, -1, -1):
105 | # upconv_0
106 | num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
107 | num_ch_out = self.num_ch_dec[i]
108 | self.convs[self.tuple_to_str(("upconv", i, 0))] = GatedConvBlock(num_ch_in, num_ch_out)
109 | # print("upconv_{}_{}".format(i, 0), num_ch_in, num_ch_out)
110 |
111 | # upconv_1
112 | num_ch_in = self.num_ch_dec[i]
113 | if self.use_skips and i > 0:
114 | num_ch_in += self.num_ch_enc[i - 1]
115 | num_ch_out = self.num_ch_dec[i]
116 | self.convs[self.tuple_to_str(("upconv", i, 1))] = GatedConvBlock(num_ch_in, num_ch_out)
117 | # print("upconv_{}_{}".format(i, 1), num_ch_in, num_ch_out)
118 |
119 | for s in self.scales:
120 | self.convs[self.tuple_to_str(("dispconv", s))] = GatedConv(self.num_ch_dec[s], self.num_output_channels)
121 |
122 | self.sigmoid = nn.Sigmoid()
123 |
124 | def forward(self, input_features, feature_mask):
125 | B, S, _, _ = feature_mask.size()
126 | # extension of encoder to increase receptive field
127 | encoder_out = input_features[-1]
128 | conv_down1 = self.conv_down1(self.downsample(encoder_out))
129 | conv_down2 = self.conv_down2(self.downsample(conv_down1))
130 | conv_up1 = self.conv_up1(self.upsample(conv_down2))
131 | conv_up2 = self.conv_up2(self.upsample(conv_up1))
132 |
133 | # repeat / reshape features
134 | _, C_feat, H_feat, W_feat = conv_up2.size()
135 | cum_mask = torch.cumsum(feature_mask, dim=1) # [B,S,H,W]
136 | inpaint_mask = torch.cat([torch.zeros_like(cum_mask[:, -1:, :, :]), cum_mask[:, :-1, :, :]], dim=1) # [B,S,H,W]
137 | context_mask = 1 - inpaint_mask # [B,S,H,W]
138 |
139 | cur_context_mask = F.adaptive_avg_pool2d(context_mask, (H_feat, W_feat)).unsqueeze(2)
140 | cur_feature_mask = F.adaptive_avg_pool2d(feature_mask, (H_feat, W_feat)).unsqueeze(2)
141 | conv_up2 = conv_up2.unsqueeze(1).repeat(1, S, 1, 1, 1)
142 | conv_up2 = torch.cat([conv_up2 * cur_context_mask, cur_context_mask, cur_feature_mask], dim=2) # [B,S,C+2,H,W]
143 | conv_up2 = conv_up2.reshape(-1, C_feat + 2, H_feat, W_feat) # [BxS,C+2,H,W]
144 |
145 | # repeat / reshape features
146 | for i, feat in enumerate(input_features):
147 | _, C_feat, H_feat, W_feat = feat.size()
148 | cur_context_mask = F.adaptive_avg_pool2d(context_mask, (H_feat, W_feat)).unsqueeze(2)
149 | cur_feature_mask = F.adaptive_avg_pool2d(feature_mask, (H_feat, W_feat)).unsqueeze(2)
150 | feat = feat.unsqueeze(1).repeat(1, S, 1, 1, 1)
151 | feat = torch.cat([feat * cur_context_mask, cur_context_mask, cur_feature_mask], dim=2) # [B,S,C+2,H,W]
152 | input_features[i] = feat.reshape(-1, C_feat + 2, H_feat, W_feat) # [BxS,C+2,H,W]
153 |
154 | outputs = []
155 | x = conv_up2
156 | for i in range(4, -1, -1):
157 | x = self.convs[self.tuple_to_str(("upconv", i, 0))](x)
158 | x = [upsample(x)]
159 | if self.use_skips and i > 0:
160 | x += [input_features[i - 1]]
161 | x = torch.cat(x, 1)
162 | x = self.convs[self.tuple_to_str(("upconv", i, 1))](x)
163 | if i in self.scales:
164 | output = self.convs[self.tuple_to_str(("dispconv", i))](x)
165 | H_mpi, W_mpi = output.size(2), output.size(3)
166 | cur_mask = F.adaptive_avg_pool2d(cum_mask, (H_mpi, W_mpi)).unsqueeze(2)
167 | mpi = output.view(B, S, 4, H_mpi, W_mpi)
168 | mpi_rgb = self.sigmoid(mpi[:, :, 0:3, :, :])
169 | if self.use_alpha:
170 | mpi_sigma = self.sigmoid(mpi[:, :, 3:, :, :]) * cur_mask
171 | else:
172 | mpi_sigma = torch.relu(mpi[:, :, 3:, :, :] * cur_mask) + 1e-4
173 | outputs.append(torch.cat((mpi_rgb, mpi_sigma), dim=2))
174 | return outputs[::-1]
175 |
--------------------------------------------------------------------------------
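The mask bookkeeping in DepthDecoder.forward is easiest to read on a toy tensor: the cumulative sum over planes marks everything already covered by nearer planes, that region becomes the inpainting mask for the planes behind it, and each plane's usable context is the complement. A small illustration, not repository code:

import torch

# one pixel fully assigned to the nearest of S = 3 planes (H = W = 1)
feature_mask = torch.tensor([[[[1.0]], [[0.0]], [[0.0]]]])            # [B,S,H,W]
cum_mask = torch.cumsum(feature_mask, dim=1)                          # 1, 1, 1
inpaint_mask = torch.cat(
    [torch.zeros_like(cum_mask[:, -1:]), cum_mask[:, :-1]], dim=1)    # 0, 1, 1
context_mask = 1 - inpaint_mask                                       # 1, 0, 0
# the nearest plane keeps its original content; the two planes behind it
# fall in the inpainting region and must be hallucinated by the decoder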
/model/CPN/encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright Niantic 2019. Patent Pending. All rights reserved.
2 | #
3 | # This software is licensed under the terms of the Monodepth2 licence
4 | # which allows for non-commercial use only, the full terms of which are made
5 | # available in the LICENSE file.
6 |
7 |
8 | '''
9 | This code is borrowed heavily from MINE: https://github.com/vincentfung13/MINE
10 | '''
11 |
12 |
13 | import numpy as np
14 | import torch
15 | import torch.nn as nn
16 | import torchvision.models as models
17 |
18 |
19 | class ResNetMultiImageInput(models.ResNet):
20 | """Constructs a resnet model with varying number of input images.
21 | Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
22 | """
23 | def __init__(self, block, layers, num_input_images=1):
24 | super(ResNetMultiImageInput, self).__init__(block, layers)
25 | self.inplanes = 64
26 | self.conv1 = nn.Conv2d(
27 |             num_input_images * 4, 64, kernel_size=7, stride=2, padding=3, bias=False)  # input is RGB-D
28 | self.bn1 = nn.BatchNorm2d(64)
29 | self.relu = nn.ReLU(inplace=True)
30 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
31 | self.layer1 = self._make_layer(block, 64, layers[0])
32 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
33 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
34 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
35 |
36 | for m in self.modules():
37 | if isinstance(m, nn.Conv2d):
38 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
39 | elif isinstance(m, nn.BatchNorm2d):
40 | nn.init.constant_(m.weight, 1)
41 | nn.init.constant_(m.bias, 0)
42 |
43 |
44 | def resnet_multiimage_input(num_layers, num_input_images=1):
45 | """Constructs a ResNet model.
46 | Args:
47 | num_layers (int): Number of resnet layers. Must be 18 or 50
48 | pretrained (bool): If True, returns a model pre-trained on ImageNet
49 | num_input_images (int): Number of frames stacked as input
50 | """
51 | assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
52 | blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
53 | block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
54 | model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)
55 |
56 | return model
57 |
58 |
59 | class ResnetEncoder(nn.Module):
60 | """Pytorch module for a resnet encoder
61 | """
62 | def __init__(self, num_layers, num_input_images=1, **kwargs):
63 | super(ResnetEncoder, self).__init__()
64 |
65 | self.num_ch_enc = np.array([64, 64, 128, 256, 512])
66 |
67 | resnets = {18: models.resnet18,
68 | 34: models.resnet34,
69 | 50: models.resnet50,
70 | 101: models.resnet101,
71 | 152: models.resnet152}
72 |
73 | if num_layers not in resnets:
74 | raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
75 |
76 | self.encoder = resnet_multiimage_input(num_layers, num_input_images)
77 |
78 | if num_layers > 34:
79 | self.num_ch_enc[1:] *= 4
80 |
81 | self.img_mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)
82 | self.img_mean = self.img_mean.view(1, 3, 1, 1)
83 | self.img_std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)
84 | self.img_std = self.img_std.view(1, 3, 1, 1)
85 |
86 | def forward(self, input_image, input_depth):
87 | # normalize before going into network
88 | ref_images_normalized = (input_image - self.img_mean.to(input_image)) / self.img_std.to(input_image)
89 |
90 | self.features = []
91 | # x = (input_image - 0.45) / 0.225
92 | x = torch.cat([ref_images_normalized, input_depth], dim=1)
93 | x = self.encoder.conv1(x)
94 | x = self.encoder.bn1(x)
95 | conv1_out = self.encoder.relu(x) # [bs,64,h//2,w//2]
96 |         block1_out = self.encoder.layer1(self.encoder.maxpool(conv1_out))  # [bs,64,h//4,w//4] (channels x4 for resnet50)
97 |         block2_out = self.encoder.layer2(block1_out)  # [bs,128,h//8,w//8]
98 |         block3_out = self.encoder.layer3(block2_out)  # [bs,256,h//16,w//16]
99 |         block4_out = self.encoder.layer4(block3_out)  # [bs,512,h//32,w//32]
100 |
101 | return conv1_out, block1_out, block2_out, block3_out, block4_out
102 |
--------------------------------------------------------------------------------
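ResnetEncoder above consumes RGB plus disparity as a 4-channel input and returns five feature maps at strides 2 to 32. A hedged shape-check sketch with the ResNet-18 configuration used by AdaMPI.py; the input resolution here is an arbitrary multiple of 32:

import torch
from model.CPN.encoder import ResnetEncoder

enc = ResnetEncoder(num_layers=18)
rgb = torch.rand(1, 3, 256, 384)
disp = torch.rand(1, 1, 256, 384)
feats = enc(rgb, disp)
print([tuple(f.shape) for f in feats])
# (1, 64, 128, 192), (1, 64, 64, 96), (1, 128, 32, 48), (1, 256, 16, 24), (1, 512, 8, 16)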
/model/CPN/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ConvBNReLU(nn.Module):
6 | def __init__(self, ch_in, ch_out, kernel_size, stride, pad):
7 | super().__init__()
8 | self.layer = nn.Sequential(
9 | nn.Conv2d(ch_in, ch_out, kernel_size, stride, pad),
10 | nn.BatchNorm2d(ch_out),
11 | nn.ReLU()
12 | )
13 |
14 | def forward(self, x):
15 | return self.layer(x)
16 |
17 |
18 | class FeatMaskNetwork(nn.Module):
19 | def __init__(self, **kwargs):
20 | super().__init__()
21 | self.conv1 = ConvBNReLU(5, 16, 3, 1, 1)
22 | self.conv2 = ConvBNReLU(16, 32, 3, 2, 1)
23 | self.conv3 = ConvBNReLU(32, 64, 3, 2, 1)
24 | self.conv4 = ConvBNReLU(64, 128, 3, 2, 1)
25 | self.conv5 = ConvBNReLU(128, 128, 3, 1, 1)
26 | self.conv6 = ConvBNReLU(192, 64, 3, 1, 1)
27 | self.conv7 = ConvBNReLU(96, 32, 3, 1, 1)
28 | self.conv8 = ConvBNReLU(48, 16, 3, 1, 1)
29 | self.conv9 = ConvBNReLU(16, 1, 3, 1, 1)
30 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
31 |
32 | def forward(self, input_image, input_depth, input_mpi_disparity):
33 | '''
34 | input_image: [b,3,h,w]
35 | input_depth: [b,1,h,w]
36 | input_mpi_disparity: [b,s]
37 | '''
38 | _, _, h, w = input_image.size() # spatial dim
39 | b, s = input_mpi_disparity.size() # number of mpi planes
40 |
41 | # repeat input rgb
42 | expanded_image = input_image.unsqueeze(1).repeat(1, s, 1, 1, 1) # [b,s,3,h,w]
43 |
44 | # repeat input depth
45 | expanded_depth = input_depth.unsqueeze(1).repeat(1, s, 1, 1, 1) # [b,s,1,h,w]
46 |
47 | # repeat and reshape input mpi disparity
48 | expanded_mpi_disp = input_mpi_disparity[:, :, None, None, None].repeat(1, 1, 1, h, w) # [b,s,1,h,w]
49 |
50 | # concat together
51 | x = torch.cat([expanded_image, expanded_depth, expanded_mpi_disp], dim=2).reshape(b * s, 5, h, w) # [bs,5,h,w]
52 |
53 | # forward
54 | c1 = self.conv1(x)
55 | c2 = self.conv2(c1)
56 | c3 = self.conv3(c2)
57 | c4 = self.conv4(c3)
58 | c5 = self.conv5(c4)
59 | u5 = self.upsample(c5)
60 | c6 = self.conv6(torch.cat([u5, c3], dim=1))
61 | u6 = self.upsample(c6)
62 | c7 = self.conv7(torch.cat([u6, c2], dim=1))
63 | u7 = self.upsample(c7)
64 | c8 = self.conv8(torch.cat([u7, c1], dim=1))
65 | c9 = self.conv9(c8) # [bs,1,h,w]
66 | fm = c9.reshape(b, s, h, w)
67 |         fm = torch.softmax(fm, dim=1)
68 |
69 | return fm
70 |
--------------------------------------------------------------------------------
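FeatMaskNetwork above stacks RGB, depth, and a per-plane disparity into a 5-channel input for each of the S planes, and predicts a soft plane-assignment mask that is normalized across planes. A hedged shape-check sketch; the spatial size must be divisible by 8 for the three down/upsampling stages:

import torch
from model.CPN.unet import FeatMaskNetwork

fmn = FeatMaskNetwork()
image = torch.rand(2, 3, 64, 96)
depth = torch.rand(2, 1, 64, 96)
mpi_disp = torch.linspace(1.0, 0.001, 8).unsqueeze(0).repeat(2, 1)   # [b,s], near to far
mask = fmn(image, depth, mpi_disp)                                    # [b,s,h,w]
assert torch.allclose(mask.sum(dim=1), torch.ones(2, 64, 96), atol=1e-5)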
/model/PAN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def MLP(channels):
7 | """ Multi-layer perceptron """
8 | n = len(channels)
9 | layers = []
10 | for i in range(1, n):
11 | layers.append(
12 | nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
13 | if i < (n-1):
14 | layers.append(nn.ReLU())
15 | return nn.Sequential(*layers)
16 |
17 |
18 | class ResBlock(nn.Module):
19 | def __init__(self, in_channels, out_channels, hidden_channels):
20 | super().__init__()
21 | self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1)
22 | self.conv2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1)
23 | self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
24 | self.activation = nn.ReLU()
25 | self.bn = nn.BatchNorm2d(hidden_channels)
26 |
27 | def forward(self, x):
28 | return self.activation(self.conv3(x) + self.conv2(self.bn(self.activation(self.conv1(x)))))
29 |
30 |
31 | class DownsizeEncoder(nn.Module):
32 | def __init__(self, num_blocks, dim_in, dim_out):
33 | super().__init__()
34 | res_blocks = []
35 | for i_block in range(0, num_blocks):
36 | d_in = dim_in if i_block == 0 else max(dim_in, dim_out // (2 ** (num_blocks - i_block)))
37 | d_out = max(dim_in, dim_out // (2 ** (num_blocks - i_block - 1)))
38 | res_blocks.append(ResBlock(in_channels=d_in, out_channels=d_out, hidden_channels=d_out))
39 | self.res_blocks = nn.ModuleList(res_blocks)
40 |
41 | def forward(self, x):
42 | # [b, c, h, w]
43 | for res_block in self.res_blocks:
44 | x = res_block(x)
45 | x = F.avg_pool2d(x, kernel_size=2)
46 | return x # [b, c, h, w]
47 |
48 |
49 | class MultiheadSelfAttention(nn.Module):
50 | def __init__(self, num_heads, dim_in, dim_qk, dim_v):
51 | super().__init__()
52 | self.wQs = nn.ModuleList([nn.Linear(dim_in, dim_qk) for _ in range(num_heads)])
53 | self.wKs = nn.ModuleList([nn.Linear(dim_in, dim_qk) for _ in range(num_heads)])
54 | self.wVs = nn.ModuleList([nn.Linear(dim_in, dim_v // num_heads) for _ in range(num_heads)])
55 | self.fusion = nn.Linear(dim_v, dim_v)
56 | self.norm = dim_qk ** 0.5
57 |
58 | def forward(self, feat):
59 | feat_atted = []
60 | for wQ, wK, wV in zip(self.wQs, self.wKs, self.wVs):
61 | Q = wQ(feat) # [b,s,cq]
62 | K = wK(feat) # [b,s,cq]
63 | V = wV(feat) # [b,s,cv]
64 | att = torch.softmax(torch.einsum('bik,bjk->bij', Q, K) / self.norm, dim=2)
65 | feat_atted.append(torch.einsum('bij,bjc->bic', att, V))
66 | return self.fusion(torch.cat(feat_atted, dim=-1)) # [b,s,c]
67 |
68 |
69 | class LinearSigmoid(nn.Module):
70 | def __init__(self, in_ch, disp_range):
71 | super().__init__()
72 | self.start, self.end = disp_range
73 | self.linear = nn.Linear(in_ch, 1)
74 |
75 | def forward(self, feat, init_disp):
76 | feat = self.linear(feat).squeeze(-1) # [b,s]
77 | return init_disp + feat * 1. / init_disp.shape[1]
78 |
79 |
80 | class DepthPredictionNetwork(nn.Module):
81 | def __init__(self, disp_range, **kwargs):
82 | super().__init__()
83 | self.context_encoder = DownsizeEncoder(num_blocks=5, dim_in=5, dim_out=128)
84 | self.self_attention = MultiheadSelfAttention(num_heads=4, dim_in=128, dim_qk=32, dim_v=128)
85 | self.embed = nn.Sequential(
86 | nn.Linear(128, 32),
87 | nn.ReLU(),
88 | )
89 | self.to_disp = LinearSigmoid(32, disp_range)
90 |
91 | def forward(self, init_disp, rgb_low_res, disp_low_res):
92 | B, S = init_disp.shape
93 |
94 | # context encoder
95 | x = torch.cat([
96 | rgb_low_res[:, None, ...].repeat(1, S, 1, 1, 1),
97 | disp_low_res[:, None, ...].repeat(1, S, 1, 1, 1),
98 | init_disp[:, :, None, None, None].repeat(1, 1, 1, *rgb_low_res.shape[-2:])
99 | ], dim=-3) # [b, s, 5, h/4, w/4]
100 | x = x.view(-1, *x.shape[-3:]) # [b*s, 5, h/4, w/4]
101 | context = self.context_encoder(x) # [b*s, c, h/128, w/128]
102 | context = F.adaptive_avg_pool2d(context, (1, 1)).squeeze(-1).squeeze(-1) # [b*s, c]
103 | context = context.view(B, S, -1) # [b, s, c]
104 |
105 | # self attention
106 | feat_atted = self.self_attention(context) # [b, s, c ]
107 | feat = self.embed(feat_atted) # [b, s, c]
108 | disp_bs = self.to_disp(feat, init_disp) # [b, s]
109 | return disp_bs
110 |
--------------------------------------------------------------------------------
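DepthPredictionNetwork above refines S initial disparity values from a low-resolution RGB-D context via per-plane self-attention; note that AdaMPI.py currently bypasses it and uses the initial linspace directly. A hedged shape-check sketch with toy inputs:

import torch
from model.PAN import DepthPredictionNetwork

dpn = DepthPredictionNetwork(disp_range=[0.001, 1], n_planes=8)
init_disp = torch.linspace(1.0, 0.001, 8).unsqueeze(0)   # [b,s] initial plane disparities
rgb_low = torch.rand(1, 3, 64, 96)                       # low-resolution RGB context
disp_low = torch.rand(1, 1, 64, 96)                      # low-resolution disparity context
refined = dpn(init_disp, rgb_low, disp_low)              # [b,s] refined disparities
print(refined.shape)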
/moving_obj.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import torch
5 | from flow_colors import flow_to_color
6 | from geometry import *
7 | from ctypes import *
8 | import ctypes
9 | import random
10 | import math
11 |
12 | lib = cdll.LoadLibrary("external/forward_warping/libwarping.so")
13 | warp = lib.forward_warping
14 |
15 |
16 | def moveing_object_with_mask(depth_path, disp, rgb, K, inv_K, instance_mask, i):
17 |
18 | # Cast I0 and D0 to pytorch tensors
19 | h, w = rgb.shape[:2]
20 | rgb = torch.from_numpy(np.expand_dims(rgb, 0)).float().cuda()
21 | # depth = torch.from_numpy(np.expand_dims(depth, 0)).float().cuda()
22 |
23 | # debug
24 | # depth = cv2.imread(depth_path, -1) / (2**16-1)
25 | # if depth.shape[0] != h or depth.shape[1] != w:
26 | # depth = cv2.resize(depth, (w, h))
27 |
28 | # Get depth map and normalize
29 | depth = 1.0 / (disp[0] + 0.005)
30 | depth[depth > 100] = 100
31 | # depth = torch.from_numpy(np.expand_dims(depth, 0)).float().cuda()
32 |
33 | instance_mask = instance_mask[0]
34 | instance_mask = torch.stack([instance_mask, instance_mask], -1)
35 |
36 | # Create objects in charge of 3D projection
37 | backproject_depth = BackprojectDepth(1, h, w).cuda()
38 | project_3d = Project3D(1, h, w).cuda()
39 |
40 | # Prepare p0 coordinates
41 | meshgrid = np.meshgrid(range(w), range(h), indexing="xy")
42 | p0 = np.stack(meshgrid, axis=-1).astype(np.float32)
43 |
44 | # Initiate masks dictionary
45 | masks = {}
46 | axisangle = torch.from_numpy(np.array([[[0, 0, 0]]], dtype=np.float32)).cuda()
47 | translation = torch.from_numpy(np.array([[0, 0, 0]])).cuda()
48 |
49 | # Compute (R|t)
50 | T1 = transformation_from_parameters(axisangle, translation)
51 |
52 | temp = torch.zeros((1, 4, 4)).cuda()
53 | temp[0, -1, -1] = 1.
54 | temp[:, :3, :3] = K
55 | K = temp
56 |
57 | temp = torch.zeros((1, 4, 4)).cuda()
58 | temp[0, -1, -1] = 1.
59 | temp[:, :3, :3] = inv_K
60 | inv_K = temp
61 |
62 | # Back-projection
63 | cam_points = backproject_depth(depth, inv_K)
64 |
65 | # Apply transformation T_{0->1}
66 | p1, z1 = project_3d(cam_points, K, T1)
67 | z1 = z1.reshape(1, h, w)
68 |
69 | # Simulate objects moving independently
70 | if True:
71 |
72 | sign = -1
73 |
74 | # Random t (scalars and signs). Zeros and small motions are avoided as before
75 | # cix = (random.random()*0.05+0.05) * \
76 | # (sign*(-1)**random.randrange(2))
77 | # ciy = (random.random()*0.05+0.05) * \
78 | # (sign*(-1)**random.randrange(2))
79 | # ciz = (random.random()*0.05+0.05) * \
80 | # (sign*(-1)**random.randrange(2))
81 | cix = (random.random()*0.05+0.05)
82 | ciy = -1*(random.random()*0.05+0.05)
83 | ciz = (random.random()*0.05+0.05)
84 | camerai_mot = [cix, ciy, ciz]
85 |
86 | # Random Euler angles (scalars and signs). Zeros and small rotations are avoided as before
87 | aix = (random.random()*math.pi / 72.0 + math.pi /
88 | 72.0) * (sign*(-1)**random.randrange(2))
89 | aiy = (random.random()*math.pi / 72.0 + math.pi /
90 | 72.0) * (sign*(-1)**random.randrange(2))
91 | aiz = (random.random()*math.pi / 72.0 + math.pi /
92 | 72.0) * (sign*(-1)**random.randrange(2))
93 | camerai_ang = [aix, aiy, aiz]
94 | camerai_ang = [0, 0, 0]
95 |
96 | ai = torch.from_numpy(
97 | np.array([[camerai_ang]], dtype=np.float32)).cuda()
98 | tri = torch.from_numpy(np.array([[camerai_mot]])).cuda()
99 |
100 | # Compute (R|t)
101 | Ti = transformation_from_parameters(
102 | axisangle + ai, translation + tri)
103 |
104 | # Apply transformation T_{0->\pi_i}
105 | pi, zi = project_3d(cam_points, K, Ti)
106 |
107 | # If a pixel belongs to object label l, replace coordinates in I1...
108 | p1[instance_mask > 0] = pi[instance_mask > 0]
109 |
110 | # ... and its depth
111 | zi = zi.reshape(1, h, w)
112 | z1[instance_mask[:, :, :, 0] > 0] = zi[instance_mask[:, :, :, 0] > 0]
113 |
114 | # Bring p1 coordinates in [0,W-1]x[0,H-1] format
115 | p1 = (p1 + 1) / 2
116 | p1[:, :, :, 0] *= w - 1
117 | p1[:, :, :, 1] *= h - 1
118 |
119 | # Create auxiliary data for warping
120 | dlut = torch.ones(1, h, w).float().cuda() * 1000
121 | safe_y = np.maximum(np.minimum(p1[:, :, :, 1].cpu().long(), h - 1), 0)
122 | safe_x = np.maximum(np.minimum(p1[:, :, :, 0].cpu().long(), w - 1), 0)
123 | warped_arr = np.zeros(h*w*5).astype(np.uint8)
124 | img = rgb.reshape(-1).to(torch.uint8)
125 |
126 | # Call forward warping routine (C code)
127 | warp(c_void_p(img.cpu().numpy().ctypes.data), c_void_p(safe_x[0].cpu().numpy().ctypes.data),
128 | c_void_p(safe_y[0].cpu().numpy().ctypes.data), c_void_p(z1.reshape(-1).cpu().numpy().ctypes.data),
129 | c_void_p(warped_arr.ctypes.data), c_int(h), c_int(w))
130 | warped_arr = warped_arr.reshape(1, h, w, 5).astype(np.uint8)
131 |
132 | # Warped image
133 | im1_raw = warped_arr[0, :, :, 0:3]
134 |
135 | # Validity mask H
136 | masks["H"] = warped_arr[0, :, :, 3:4]
137 |
138 | # Collision mask M
139 | masks["M"] = warped_arr[0, :, :, 4:5]
140 | # Keep all pixels that are invalid (H) or collide (M)
141 | masks["M"] = 1-(masks["M"] == masks["H"]).astype(np.uint8)
142 |
143 | # Dilated collision mask M'
144 | kernel = np.ones((3, 3), np.uint8)
145 | masks["M'"] = cv2.dilate(masks["M"], kernel, iterations=1)
146 | masks["P"] = (np.expand_dims(masks["M'"], -1)
147 | == masks["M"]).astype(np.uint8)
148 |
149 |     # Final mask H' = H * P
150 | masks["H'"] = masks["H"]*masks["P"]
151 |
152 | # Compute flow as p1-p0
153 | flow_01 = p1.cpu().numpy() - p0
154 | im1 = rgb[0].cpu().numpy().copy()
155 | # mask_idx = np.logical_and(
156 | # flow_01[0, :, :, 0] > 1,
157 | # flow_01[0, :, :, 1] > 1
158 | # )
159 | mask_idx = np.where(instance_mask[0, :, :, 0].cpu().numpy())
160 | # mask_xp = mask_x, mask_y
161 |
162 | im1 = cv2.inpaint(im1_raw, 1 - masks["H"], 3, cv2.INPAINT_TELEA)
163 | flow_color = flow_to_color(flow_01[0], convert_to_bgr=True)
164 | mask = cv2.merge([masks["H"]*255, masks["H"]*255, masks["H"]*255])
165 | res = np.vstack(
166 | [rgb[0].cpu().numpy(), im1, im1_raw, mask, flow_color]
167 | )
168 | cv2.imwrite('temp/res-{:06d}.png'.format(i), res)
--------------------------------------------------------------------------------
/scripts/gen_coco.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=5 python gen_3dphoto_dynamic_coco.py \
2 | --base ../dataset/Flow/extra/coco/outputs/ \
3 | --out ../dataset/Flow/extra/coco/MPI-Flow-data \
4 | --repeat 2
--------------------------------------------------------------------------------
/scripts/gen_test_kitti15.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=2 python gen_3dphoto_dynamic.py \
2 | --base ../dataset/Flow/testing/outputs/ \
3 | --out ../dataset/Flow/testing/MPI-Flow-data \
4 | --repeat 5
--------------------------------------------------------------------------------
/scripts/gen_train_kitti15.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=3 python gen_3dphoto_dynamic.py \
2 | --base ../dataset/Flow/training/outputs/ \
3 | --out ./dataset/debug \
4 | --repeat 2 --seed 0
--------------------------------------------------------------------------------
/scripts/gen_train_kitti15_v2.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=2 python gen_3dphoto_dynamic_v2.py \
2 | --base ../dataset/Flow/training/outputs/ \
3 | --out /data/liangyingping/debug_0.35 \
4 | --repeat 2 --seed 0 --ext_cz 0.25
--------------------------------------------------------------------------------
/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 | from utils.arrow import arrowon
20 | import cv2
21 |
22 | def make_colorwheel():
23 | """
24 | Generates a color wheel for optical flow visualization as presented in:
25 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
26 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
27 |
28 | Code follows the original C++ source code of Daniel Scharstein.
29 |     Code follows the Matlab source code of Deqing Sun.
30 |
31 | Returns:
32 | np.ndarray: Color wheel
33 | """
34 |
35 | RY = 15
36 | YG = 6
37 | GC = 4
38 | CB = 11
39 | BM = 13
40 | MR = 6
41 |
42 | ncols = RY + YG + GC + CB + BM + MR
43 | colorwheel = np.zeros((ncols, 3))
44 | col = 0
45 |
46 | # RY
47 | colorwheel[0:RY, 0] = 255
48 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
49 | col = col+RY
50 | # YG
51 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
52 | colorwheel[col:col+YG, 1] = 255
53 | col = col+YG
54 | # GC
55 | colorwheel[col:col+GC, 1] = 255
56 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
57 | col = col+GC
58 | # CB
59 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
60 | colorwheel[col:col+CB, 2] = 255
61 | col = col+CB
62 | # BM
63 | colorwheel[col:col+BM, 2] = 255
64 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
65 | col = col+BM
66 | # MR
67 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
68 | colorwheel[col:col+MR, 0] = 255
69 | return colorwheel
70 |
71 |
72 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
73 | """
74 | Applies the flow color wheel to (possibly clipped) flow components u and v.
75 |
76 | According to the C++ source code of Daniel Scharstein
77 | According to the Matlab source code of Deqing Sun
78 |
79 | Args:
80 | u (np.ndarray): Input horizontal flow of shape [H,W]
81 | v (np.ndarray): Input vertical flow of shape [H,W]
82 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
83 |
84 | Returns:
85 | np.ndarray: Flow visualization image of shape [H,W,3]
86 | """
87 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
88 | colorwheel = make_colorwheel() # shape [55x3]
89 | ncols = colorwheel.shape[0]
90 | rad = np.sqrt(np.square(u) + np.square(v))
91 | a = np.arctan2(-v, -u)/np.pi
92 | fk = (a+1) / 2*(ncols-1)
93 | k0 = np.floor(fk).astype(np.int32)
94 | k1 = k0 + 1
95 | k1[k1 == ncols] = 0
96 | f = fk - k0
97 | for i in range(colorwheel.shape[1]):
98 | tmp = colorwheel[:,i]
99 | col0 = tmp[k0] / 255.0
100 | col1 = tmp[k1] / 255.0
101 | col = (1-f)*col0 + f*col1
102 | idx = (rad <= 1)
103 | col[idx] = 1 - rad[idx] * (1-col[idx])
104 | col[~idx] = col[~idx] * 0.75 # out of range
105 | # Note the 2-i => BGR instead of RGB
106 | ch_idx = 2-i if convert_to_bgr else i
107 | flow_image[:,:,ch_idx] = np.floor(255 * col)
108 | return flow_image
109 |
110 |
111 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
112 | """
113 |     Expects a two dimensional flow image of shape [H,W,2].
114 |
115 | Args:
116 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
117 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
118 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
119 |
120 | Returns:
121 | np.ndarray: Flow visualization image of shape [H,W,3]
122 | """
123 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
124 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
125 | if clip_flow is not None:
126 | flow_uv = np.clip(flow_uv, 0, clip_flow)
127 | u = flow_uv[:,:,0]
128 | v = flow_uv[:,:,1]
129 | rad = np.sqrt(np.square(u) + np.square(v))
130 | rad_max = np.max(rad)
131 | epsilon = 1e-5
132 | u = u / (rad_max + epsilon)
133 | v = v / (rad_max + epsilon)
134 | return flow_uv_to_colors(u, v, convert_to_bgr)
135 |
136 | def viz_batch_mask_np(imgs, flos, fusions=None, masks=None, arrow_step=32, save_path='tmp.jpg'):
137 | '''
138 | input: imgs = [image1, image2]
139 | fusions = [1->2, 2->1]
140 | flos = [1->2, 2->1]
141 | masks = [1->2, 2->1]
142 | flow_past image2_arrow_past fusion1 mask_past
143 | flow_future image1_arrow_future fusion2 mask_future
144 | '''
145 | image1 = np.array(imgs[0])
146 | image2 = np.array(imgs[1])
147 |
148 | show_img_past = []
149 | #show_img_past.append(image1)
150 | img2_past = arrowon(image1, flos[0], arrow_step)
151 | show_img_past.append(img2_past)
152 | show_img_past.append(flow_to_image(flos[0]))
153 | if fusions is not None:
154 | show_img_past.append(fusions[0])
155 | if masks is not None:
156 | show_img_past.append(np.tile(masks[0]*255, (1, 1, 3)))
157 |
158 | show_img_future = []
159 | #show_img_future.append(image2)
160 | img2_past = arrowon(image2, flos[1], arrow_step)
161 | show_img_future.append(img2_past)
162 | show_img_future.append(flow_to_image(flos[1]))
163 | if fusions is not None:
164 | show_img_future.append(fusions[1])
165 | if masks is not None:
166 | show_img_future.append(np.tile(masks[1]*255, (1, 1, 3)))
167 |
168 | #img_flo = np.concatenate(show_img, axis=1)
169 | show_past = np.concatenate(show_img_past, axis=1)
170 | show_future = np.concatenate(show_img_future, axis=1)
171 | show_img = np.concatenate([show_past, show_future], axis=0)
172 |
173 | cv2.imwrite(save_path, show_img[:, :, [2,1,0]])
174 |
175 | def viz_batch_mask(imgs, fusions, flos, masks, save_path):
176 | '''
177 | input: imgs = [image1, image2, image3]
178 | fusions = [1->2, 3->2]
179 | flos = [flow_past, flow_future]
180 | masks = [mask_past, mask_future]
181 | image1 image2_arrow_past fusion1 flow_past mask_past
182 | image3 image2_arrow_future fusion2 flow_future mask_future
183 | '''
184 | image2 = imgs[1][0].permute(1,2,0).cpu().numpy()
185 |
186 | show_img_past = []
187 | show_img_past.append(imgs[0][0].permute(1,2,0).cpu().numpy())
188 | img2_past = arrowon(image2, flos[0][0], 32)
189 | show_img_past.append(img2_past)
190 | show_img_past.append(fusions[0][0].permute(1,2,0).cpu().numpy())
191 | show_img_past.append(flow_to_image(flos[0][0].permute(1,2,0).cpu().numpy()))
192 | show_img_past.append(np.tile(masks[0][0].permute(1,2,0).cpu().numpy()*255, (1, 1, 3)))
193 |
194 | show_img_future = []
195 | show_img_future.append(imgs[2][0].permute(1,2,0).cpu().numpy())
196 | img2_future = arrowon(image2, flos[1][0], 32)
197 | show_img_future.append(img2_future)
198 | show_img_future.append(fusions[1][0].permute(1,2,0).cpu().numpy())
199 | show_img_future.append(flow_to_image(flos[1][0].permute(1,2,0).cpu().numpy()))
200 | show_img_future.append(np.tile(masks[1][0].permute(1,2,0).cpu().numpy()*255, (1, 1, 3)))
201 |
202 | #img_flo = np.concatenate(show_img, axis=1)
203 | show_past = np.concatenate(show_img_past, axis=1)
204 | show_future = np.concatenate(show_img_future, axis=1)
205 | show_img = np.concatenate([show_past, show_future], axis=0)
206 |
207 | cv2.imwrite(save_path, show_img[:, :, [2,1,0]])
208 |
209 | def viz_batch(imgs, flo, save_path):
210 | show_img = []
211 | for img in imgs:
212 | img = img[0].permute(1,2,0).cpu().numpy()
213 | show_img.append(img)
214 |
215 | flo = flo[0].permute(1,2,0).cpu().numpy()
216 | show_img[0] = arrowon(show_img[0], flo, 32)
217 |
218 | # map flow to rgb image
219 | flo = flow_to_image(flo)
220 | show_img.append(flo)
221 | img_flo = np.concatenate(show_img, axis=1)
222 |
223 | cv2.imwrite(save_path, img_flo[:, :, [2,1,0]])
224 |
--------------------------------------------------------------------------------
/utils/mpi/rendering_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def transform_G_xyz(G, xyz, is_return_homo=False):
5 | """
6 |     apply transform G to a batch of 3D points
7 |     :param G: Bx4x4
8 |     :param xyz: Bx3xN
9 |     :return: Bx3xN transformed points (Bx4xN homogeneous if is_return_homo is True)
10 | """
11 | assert len(G.size()) == len(xyz.size())
12 | if len(G.size()) == 2:
13 | G_B44 = G.unsqueeze(0)
14 | xyz_B3N = xyz.unsqueeze(0)
15 | else:
16 | G_B44 = G
17 | xyz_B3N = xyz
18 | xyz_B4N = torch.cat((xyz_B3N, torch.ones_like(xyz_B3N[:, 0:1, :])), dim=1)
19 | G_xyz_B4N = torch.matmul(G_B44, xyz_B4N)
20 | if is_return_homo:
21 | return G_xyz_B4N
22 | else:
23 | return G_xyz_B4N[:, 0:3, :]
24 |
25 |
26 | def gather_pixel_by_pxpy(img, pxpy):
27 | """
28 |     gather pixel values from img at (rounded) pixel coordinates pxpy
29 |     :param img: Bx3xHxW
30 |     :param pxpy: Bx2xN
31 |     :return: BxCxN gathered pixel values
32 | """
33 | with torch.no_grad():
34 | B, C, H, W = img.size()
35 |         if pxpy.dtype == torch.float32:
36 |             pxpy = torch.round(pxpy)
37 |         pxpy_int = pxpy.to(torch.int64)
38 | pxpy_int[:, 0, :] = torch.clamp(pxpy_int[:, 0, :], min=0, max=W-1)
39 | pxpy_int[:, 1, :] = torch.clamp(pxpy_int[:, 1, :], min=0, max=H-1)
40 | pxpy_idx = pxpy_int[:, 0:1, :] + W * pxpy_int[:, 1:2, :] # Bx1xN_pt
41 | rgb = torch.gather(img.view(B, C, H * W), dim=2,
42 | index=pxpy_idx.repeat(1, C, 1)) # BxCxN_pt
43 | return rgb
44 |
45 |
46 | def uniformly_sample_disparity_from_bins(batch_size, disparity_np, device):
47 | """
48 | In the disparity dimension, it has to be from large to small, i.e., depth from small (near) to large (far)
49 |     :param batch_size: B
50 |     :param disparity_np: numpy array of S+1 bin edges, sorted from large to small
51 |     :param device: torch device for the returned tensor
52 |     :return: BxS disparity samples, one drawn uniformly from each bin
53 | """
54 | assert disparity_np[0] > disparity_np[-1]
55 | S = disparity_np.shape[0] - 1
56 |
57 | B = batch_size
58 | bin_edges = torch.from_numpy(disparity_np).to(dtype=torch.float32, device=device) # S+1
59 | interval = bin_edges[1:] - bin_edges[0:-1] # S
60 | bin_edges_start = bin_edges[0:-1].unsqueeze(0).repeat(B, 1) # S -> BxS
61 | # bin_edges_end = bin_edges[1:].unsqueeze(0).repeat(B, 1) # S -> BxS
62 | interval = interval.unsqueeze(0).repeat(B, 1) # S -> BxS
63 |
64 | random_float = torch.rand((B, S), dtype=torch.float32, device=device) # BxS
65 | disparity_array = bin_edges_start + interval * random_float
66 | return disparity_array # BxS
67 |
68 |
69 | def uniformly_sample_disparity_from_linspace_bins(batch_size, num_bins, start, end, device):
70 | """
71 | In the disparity dimension, it has to be from large to small, i.e., depth from small (near) to large (far)
72 |     :param start: largest disparity (nearest plane)
73 |     :param end: smallest disparity (farthest plane)
74 |     :param num_bins: number of bins S
75 |     :return: BxS disparity samples, one drawn uniformly from each bin
76 | """
77 | assert start > end
78 |
79 | B, S = batch_size, num_bins
80 | bin_edges = torch.linspace(start, end, num_bins+1, dtype=torch.float32, device=device) # S+1
81 | interval = bin_edges[1] - bin_edges[0] # scalar
82 | bin_edges_start = bin_edges[0:-1].unsqueeze(0).repeat(B, 1) # S -> BxS
83 | # bin_edges_end = bin_edges[1:].unsqueeze(0).repeat(B, 1) # S -> BxS
84 |
85 | random_float = torch.rand((B, S), dtype=torch.float32, device=device) # BxS
86 | disparity_array = bin_edges_start + interval * random_float
87 | return disparity_array # BxS
88 |
89 |
90 | def sample_pdf(values, weights, N_samples):
91 | """
92 | draw samples from distribution approximated by values and weights.
93 | the probability distribution can be denoted as weights = p(values)
94 | :param values: Bx1xNxS
95 | :param weights: Bx1xNxS
96 |     :param N_samples: number of samples to draw
97 |     :return: Bx1xNxN_samples sampled values
98 | """
99 | B, N, S = weights.size(0), weights.size(2), weights.size(3)
100 | assert values.size() == (B, 1, N, S)
101 |
102 | # convert values to bin edges
103 | bin_edges = (values[:, :, :, 1:] + values[:, :, :, :-1]) * 0.5 # Bx1xNxS-1
104 | bin_edges = torch.cat((values[:, :, :, 0:1],
105 | bin_edges,
106 | values[:, :, :, -1:]), dim=3) # Bx1xNxS+1
107 |
108 | pdf = weights / (torch.sum(weights, dim=3, keepdim=True) + 1e-5) # Bx1xNxS
109 | cdf = torch.cumsum(pdf, dim=3) # Bx1xNxS
110 | cdf = torch.cat((torch.zeros((B, 1, N, 1), dtype=cdf.dtype, device=cdf.device),
111 | cdf), dim=3) # Bx1xNxS+1
112 |
113 | # uniform sample over the cdf values
114 | u = torch.rand((B, 1, N, N_samples), dtype=weights.dtype, device=weights.device) # Bx1xNxN_samples
115 |
116 | # get the index on the cdf array
117 | cdf_idx = torch.searchsorted(cdf, u, right=True) # Bx1xNxN_samples
118 | cdf_idx_lower = torch.clamp(cdf_idx-1, min=0) # Bx1xNxN_samples
119 | cdf_idx_upper = torch.clamp(cdf_idx, max=S) # Bx1xNxN_samples
120 |
121 | # linear approximation for each bin
122 | cdf_idx_lower_upper = torch.cat((cdf_idx_lower, cdf_idx_upper), dim=3) # Bx1xNx(N_samplesx2)
123 | cdf_bounds_N2 = torch.gather(cdf, index=cdf_idx_lower_upper, dim=3) # Bx1xNx(N_samplesx2)
124 | cdf_bounds = torch.stack((cdf_bounds_N2[..., 0:N_samples], cdf_bounds_N2[..., N_samples:]), dim=4)
125 | bin_bounds_N2 = torch.gather(bin_edges, index=cdf_idx_lower_upper, dim=3) # Bx1xNx(N_samplesx2)
126 | bin_bounds = torch.stack((bin_bounds_N2[..., 0:N_samples], bin_bounds_N2[..., N_samples:]), dim=4)
127 |
128 | # avoid zero cdf_intervals
129 | cdf_intervals = cdf_bounds[:, :, :, :, 1] - cdf_bounds[:, :, :, :, 0] # Bx1xNxN_samples
130 | bin_intervals = bin_bounds[:, :, :, :, 1] - bin_bounds[:, :, :, :, 0] # Bx1xNxN_samples
131 | u_cdf_lower = u - cdf_bounds[:, :, :, :, 0] # Bx1xNxN_samples
132 |     # handle the case cdf_intervals == 0, caused by the cdf_idx_lower/upper clamp above
133 | t = u_cdf_lower / torch.clamp(cdf_intervals, min=1e-5)
134 | t = torch.where(cdf_intervals <= 1e-4,
135 | torch.full_like(u_cdf_lower, 0.5),
136 | t)
137 |
138 | samples = bin_bounds[:, :, :, :, 0] + t*bin_intervals
139 | return samples
140 |
--------------------------------------------------------------------------------
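A minimal usage sketch for the disparity samplers above (hedged: it assumes the repo root is on PYTHONPATH so that `utils.mpi.rendering_utils` is importable, and the 32-plane count is an arbitrary illustrative choice):

    import torch
    from utils.mpi.rendering_utils import uniformly_sample_disparity_from_linspace_bins

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    B, S = 2, 32  # batch size and number of MPI planes (illustrative values)
    # disparity must run from large (near) to small (far), so start > end
    disp = uniformly_sample_disparity_from_linspace_bins(B, S, start=1.0, end=0.01, device=device)
    assert disp.shape == (B, S)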
/utils/transform.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import glob
4 | from utils import flow_viz
5 | import torch
6 | from torch.nn import functional as F
7 |
8 | def gen_random_perspective():
9 | '''
10 | generate a random 3x3 perspective matrix
11 | '''
12 | init_M = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
13 | noise = np.random.normal(0, 0.0001, 9)
14 | noise = np.reshape(noise, [3, 3])
15 | noise[2, 2] = 0
16 | return init_M + noise
17 |
18 | def get_flow(img, M):
19 | '''
20 | use img shape and M to calculate flow
21 | return flow
22 | '''
23 | ## calculate flow
24 | x = np.linspace(0, img.shape[1]-1, img.shape[1])
25 | y = np.linspace(0, img.shape[0]-1, img.shape[0])
26 | xx, yy = np.meshgrid(x, y)
27 | coords = np.stack([xx, yy, np.ones_like(xx)], axis=0)
28 | #new_coords = np.einsum('ij,jkl->ikl', M, np.transpose(coords, (2, 0, 1)))
29 | new_coords = np.einsum('ij,jkl->ikl', M, coords)
30 | xx2 = new_coords[0, :, :] / new_coords[2, :, :]
31 | yy2 = new_coords[1, :, :] / new_coords[2, :, :]
32 | #xx2 = xx2 * img.shape[1]
33 | #yy2 = yy2 * img.shape[0]
34 | #import pdb; pdb.set_trace()
35 | xx2 = xx2.astype(np.float32)
36 | yy2 = yy2.astype(np.float32)
37 | flow_x = xx2-xx #* img.shape[1]
38 | flow_y = yy2-yy #* img.shape[0]
39 | flow_x = flow_x.astype(np.float32)
40 | flow_y = flow_y.astype(np.float32)
41 | return np.stack([flow_x, flow_y], axis=2)
42 |
43 | def transform(img, flow):
44 | '''
45 |     remap the image according to the given flow
46 |     return the warped img
47 | '''
48 |
49 | flow = flow.astype(np.float32)
50 | flow_x = flow[:, :, 0]
51 | flow_y = flow[:, :, 1]
52 |
53 | ## warp img by flow
54 | fh, fw = flow_x.shape
55 |     add = np.mgrid[0:fh,0:fw].astype(np.float32)
56 |
57 |     # cv2.remap expects map1 = x (column) coordinates and map2 = y (row) coordinates
58 |     img_flow = cv2.remap(img, flow_x+add[1,:,:], flow_y+add[0,:,:], cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101)
58 | return img_flow
59 |
60 | def warp(x, flo):
61 | """
62 | warp an image/tensor (im2) back to im1, according to the optical flow
63 |
64 | x: [B, C, H, W] (im2)
65 | flo: [B, 2, H, W] flow
66 |
67 | """
68 | B, C, H, W = x.size()
69 | # mesh grid
70 |     xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
71 |     yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
72 |     xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
73 |     yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
74 |     grid = torch.cat((xx, yy), 1).float()
75 |     if x.is_cuda:
76 |         grid = grid.cuda()
77 |     vgrid = grid.detach() + flo
78 | 
79 |     # scale grid to [-1,1]
80 |     vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
81 |     vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0
82 | 
83 |     vgrid = vgrid.permute(0, 2, 3, 1)
84 |     flo = flo.permute(0, 2, 3, 1)
85 |     output = F.grid_sample(x, vgrid)
86 |     mask = torch.ones(x.size(), device=x.device)
87 |     mask = F.grid_sample(mask, vgrid).detach()
88 |
89 | #mask[mask <0.9999] = 0
90 | #mask[mask >0] = 1
91 |
92 | return output, mask
93 |
94 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
95 | backwarp_tenGrid = {}
96 |
97 | def warp_rife(tenInput, tenFlow):
98 | k = (str(tenFlow.device), str(tenFlow.size()))
99 | if k not in backwarp_tenGrid:
100 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
101 | 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
102 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
103 | 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
104 | backwarp_tenGrid[k] = torch.cat(
105 | [tenHorizontal, tenVertical], 1).to(device)
106 |
107 | tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
108 | tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
109 |
110 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
111 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
112 |
113 | if __name__ == "__main__":
114 |
115 | import sys
116 | #img_path = '/share/boyuan/Data/snow_good_imgs/01e1fbc1a563c467018370037ebf6cd3ae_258/000006.png'
117 | img_path = '/share/boyuan/Data/snow_good_imgs/skiing_0825/000006.png'
118 | seg_prefix = '/share/boyuan/Projects/mmdetection/output/skiing_0825/000006'
119 | img = cv2.imread(img_path)
120 |
121 | #all_mask = np.zeros((img.shape[0], img.shape[1]))
122 | res_flow = np.zeros((img.shape[0], img.shape[1], 2))
123 | for m_path in glob.glob(seg_prefix+"-*"):
124 | sub_mask = cv2.imread(m_path)[:, :, 0]
125 | sub_mask = sub_mask / 255
126 | sub_mask = sub_mask.astype(np.uint8)
127 | M = gen_random_perspective()
128 | sub_flow = get_flow(img, M)
129 | res_flow[np.where(sub_mask==1)] = sub_flow[np.where(sub_mask==1)]
130 |
131 | res_img = transform(img, res_flow)
132 |
133 | img_flo = flow_viz.flow_to_image(res_flow)
134 | out = np.concatenate([img, img_flo, res_img], axis=1)
135 | print(np.max(res_flow))
136 | cv2.imwrite('tmp/warp.jpg', out)
137 | cv2.imwrite('/share/boyuan/Projects/RAFT/tmp/skiing_per/000007.png', res_img)
138 |
139 |
140 |
141 |
--------------------------------------------------------------------------------
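A minimal sketch of calling the backward-warping helper above (hedged: it assumes the repo root is on PYTHONPATH and reuses the module-level `device`; the 256x384 resolution and the zero flow are illustrative only):

    import torch
    from utils.transform import warp_rife, device

    frame2 = torch.rand(1, 3, 256, 384, device=device)      # [B,C,H,W] image tensor
    flow_1to2 = torch.zeros(1, 2, 256, 384, device=device)  # zero flow -> identity warp
    frame1_rec = warp_rife(frame2, flow_1to2)                # matches frame2 up to interpolation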
/vis_flow.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 |
5 | def readFlow(fn):
6 | """ Read .flo file in Middlebury format"""
7 | # Code adapted from:
8 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
9 |
10 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
11 | # print 'fn = %s'%(fn)
12 | with open(fn, 'rb') as f:
13 | magic = np.fromfile(f, np.float32, count=1)
14 | if 202021.25 != magic:
15 | print('Magic number incorrect. Invalid .flo file')
16 | return None
17 | else:
18 | w = np.fromfile(f, np.int32, count=1)
19 | h = np.fromfile(f, np.int32, count=1)
20 | # print 'Reading %d x %d flo file\n' % (w, h)
21 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
22 | # Reshape data into 3D array (columns, rows, bands)
23 | # The reshape here is for visualization, the original code is (w,h,2)
24 | return np.resize(data, (int(h), int(w), 2))
25 |
26 | base = "dataset/debug"
27 |
28 | if not os.path.exists(os.path.join(base, "vis")):
29 | os.mkdir(os.path.join(base, "vis"))
30 |
31 | for img in os.listdir(os.path.join(base, "src_images")):
32 | for r in range(4):
33 | image1 = cv2.imread(os.path.join(base, "src_images", img))
34 | image2 = cv2.imread(os.path.join(base, "dst_images", img.replace(".png", f"_{r}.png")))
35 |         flow = readFlow(os.path.join(base, "flows", img.replace(".png", f"_{r}.flo")))
36 | print(flow.max(), flow.min(), flow.shape)
37 |
38 | H, W = image1.shape[:2]
39 | res = np.vstack([image1, image2])
40 |
41 | for _ in range(30):
42 |
43 | x1 = np.random.randint(W)
44 | y1 = np.random.randint(H)
45 | x2 = x1 + int(flow[y1, x1, 0])
46 | y2 = y1 + int(flow[y1, x1, 1]) + H
47 |
48 | cv2.line(res, (x1, y1), (x2, y2), (0,255,0), 2)
49 |
50 | cv2.imwrite(os.path.join(base, "vis", img.replace(".png", f"_{r}.png")), res)
--------------------------------------------------------------------------------
/warpback/networks.py:
--------------------------------------------------------------------------------
1 | '''
2 | this code is adapted from the EdgeConnect repo (https://github.com/knazeri/edge-connect)
3 | '''
4 |
5 |
6 | import torch
7 | import torch.nn as nn
8 | import os
9 |
10 |
11 | def get_edge_connect(weight_dir):
12 | inpaint_model = InpaintGenerator()
13 | inpaint_model_weight = torch.load(os.path.join(weight_dir, "InpaintingModel_gen.pth"))
14 | inpaint_model.load_state_dict(inpaint_model_weight["generator"])
15 | inpaint_model.eval()
16 |
17 | edge_model = EdgeGenerator()
18 | edge_model_weight = torch.load(os.path.join(weight_dir, "EdgeModel_gen.pth"))
19 | edge_model.load_state_dict(edge_model_weight["generator"])
20 | edge_model.eval()
21 |
22 | disp_model = InpaintGenerator(in_channels=2, out_channels=1)
23 | disp_model_weight = torch.load(os.path.join(weight_dir, "InpaintingModel_disp.pth"))
24 | disp_model.load_state_dict(disp_model_weight["generator"])
25 | disp_model.eval()
26 | return edge_model, inpaint_model, disp_model
27 |
28 |
29 | class BaseNetwork(nn.Module):
30 | def __init__(self):
31 | super(BaseNetwork, self).__init__()
32 |
33 | def init_weights(self, init_type='normal', gain=0.02):
34 | '''
35 | initialize network's weights
36 | init_type: normal | xavier | kaiming | orthogonal
37 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
38 | '''
39 |
40 | def init_func(m):
41 | classname = m.__class__.__name__
42 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
43 | if init_type == 'normal':
44 | nn.init.normal_(m.weight.data, 0.0, gain)
45 | elif init_type == 'xavier':
46 | nn.init.xavier_normal_(m.weight.data, gain=gain)
47 | elif init_type == 'kaiming':
48 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
49 | elif init_type == 'orthogonal':
50 | nn.init.orthogonal_(m.weight.data, gain=gain)
51 |
52 | if hasattr(m, 'bias') and m.bias is not None:
53 | nn.init.constant_(m.bias.data, 0.0)
54 |
55 | elif classname.find('BatchNorm2d') != -1:
56 | nn.init.normal_(m.weight.data, 1.0, gain)
57 | nn.init.constant_(m.bias.data, 0.0)
58 |
59 | self.apply(init_func)
60 |
61 |
62 | class InpaintGenerator(BaseNetwork):
63 | def __init__(self, residual_blocks=8, init_weights=True, in_channels=4, out_channels=3):
64 | super(InpaintGenerator, self).__init__()
65 |
66 | self.encoder = nn.Sequential(
67 | nn.ReflectionPad2d(3),
68 | nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, padding=0),
69 | nn.InstanceNorm2d(64, track_running_stats=False),
70 | nn.ReLU(True),
71 |
72 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
73 | nn.InstanceNorm2d(128, track_running_stats=False),
74 | nn.ReLU(True),
75 |
76 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
77 | nn.InstanceNorm2d(256, track_running_stats=False),
78 | nn.ReLU(True)
79 | )
80 |
81 | blocks = []
82 | for _ in range(residual_blocks):
83 | block = ResnetBlock(256, 2)
84 | blocks.append(block)
85 |
86 | self.middle = nn.Sequential(*blocks)
87 |
88 | self.decoder = nn.Sequential(
89 | nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
90 | nn.InstanceNorm2d(128, track_running_stats=False),
91 | nn.ReLU(True),
92 |
93 | nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
94 | nn.InstanceNorm2d(64, track_running_stats=False),
95 | nn.ReLU(True),
96 |
97 | nn.ReflectionPad2d(3),
98 | nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=7, padding=0),
99 | )
100 |
101 | if init_weights:
102 | self.init_weights()
103 |
104 | def forward(self, x):
105 | x = self.encoder(x)
106 | x = self.middle(x)
107 | x = self.decoder(x)
108 | x = (torch.tanh(x) + 1) / 2
109 |
110 | return x
111 |
112 |
113 | class EdgeGenerator(BaseNetwork):
114 | def __init__(self, residual_blocks=8, use_spectral_norm=True, init_weights=True):
115 | super(EdgeGenerator, self).__init__()
116 |
117 | self.encoder = nn.Sequential(
118 | nn.ReflectionPad2d(3),
119 | spectral_norm(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding=0), use_spectral_norm),
120 | nn.InstanceNorm2d(64, track_running_stats=False),
121 | nn.ReLU(True),
122 |
123 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
124 | nn.InstanceNorm2d(128, track_running_stats=False),
125 | nn.ReLU(True),
126 |
127 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), use_spectral_norm),
128 | nn.InstanceNorm2d(256, track_running_stats=False),
129 | nn.ReLU(True)
130 | )
131 |
132 | blocks = []
133 | for _ in range(residual_blocks):
134 | block = ResnetBlock(256, 2, use_spectral_norm=use_spectral_norm)
135 | blocks.append(block)
136 |
137 | self.middle = nn.Sequential(*blocks)
138 |
139 | self.decoder = nn.Sequential(
140 | spectral_norm(nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
141 | nn.InstanceNorm2d(128, track_running_stats=False),
142 | nn.ReLU(True),
143 |
144 | spectral_norm(nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), use_spectral_norm),
145 | nn.InstanceNorm2d(64, track_running_stats=False),
146 | nn.ReLU(True),
147 |
148 | nn.ReflectionPad2d(3),
149 | nn.Conv2d(in_channels=64, out_channels=1, kernel_size=7, padding=0),
150 | )
151 |
152 | if init_weights:
153 | self.init_weights()
154 |
155 | def forward(self, x):
156 | x = self.encoder(x)
157 | x = self.middle(x)
158 | x = self.decoder(x)
159 | x = torch.sigmoid(x)
160 | return x
161 |
162 |
163 | class ResnetBlock(nn.Module):
164 | def __init__(self, dim, dilation=1, use_spectral_norm=False):
165 | super(ResnetBlock, self).__init__()
166 | self.conv_block = nn.Sequential(
167 | nn.ReflectionPad2d(dilation),
168 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm),
169 | nn.InstanceNorm2d(dim, track_running_stats=False),
170 | nn.ReLU(True),
171 |
172 | nn.ReflectionPad2d(1),
173 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm),
174 | nn.InstanceNorm2d(dim, track_running_stats=False),
175 | )
176 |
177 | def forward(self, x):
178 | out = x + self.conv_block(x)
179 | return out
180 |
181 |
182 | def spectral_norm(module, mode=True):
183 | if mode:
184 | return nn.utils.spectral_norm(module)
185 |
186 | return module
--------------------------------------------------------------------------------
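A small shape-check sketch for the generators above (hedged: the 256x384 resolution is only an example, any height/width divisible by 4 works since the networks are fully convolutional, and the random inputs stand in for real gray/edge/mask tensors):

    import torch
    from warpback.networks import InpaintGenerator, EdgeGenerator

    rgb_plus_edge = torch.rand(1, 4, 256, 384)    # InpaintGenerator default input: RGB + inpainted edge
    gray_edge_mask = torch.rand(1, 3, 256, 384)   # EdgeGenerator input: gray image + edge + hole mask
    with torch.no_grad():
        out_rgb = InpaintGenerator()(rgb_plus_edge)   # [1,3,256,384], values in [0,1]
        out_edge = EdgeGenerator()(gray_edge_mask)    # [1,1,256,384], values in [0,1]
    assert out_rgb.shape == (1, 3, 256, 384) and out_edge.shape == (1, 1, 256, 384)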
/warpback/stage1_dataset.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(".")
3 | sys.path.append("..")
4 | import os
5 | import glob
6 | import math
7 | import torch
8 | import torch.nn.functional as F
9 | from torch.utils.data.dataset import Dataset
10 | from torch.utils.data.dataloader import DataLoader, default_collate
11 | from torchvision.utils import save_image
12 |
13 | from warpback.utils import (
14 | RGBDRenderer,
15 | image_to_tensor,
16 | disparity_to_tensor,
17 | transformation_from_parameters,
18 | )
19 |
20 |
21 | class WarpBackStage1Dataset(Dataset):
22 | def __init__(
23 | self,
24 | data_root,
25 | width=384,
26 | height=256,
27 | depth_dir_name="dpt_depth",
28 | device="cuda", # device of mesh renderer
29 | trans_range={"x":0.2, "y":-1, "z":-1, "a":-1, "b":-1, "c":-1}, # xyz for translation, abc for euler angle
30 | ):
31 | self.data_root = data_root
32 | self.depth_dir_name = depth_dir_name
33 | self.renderer = RGBDRenderer(device)
34 | self.width = width
35 | self.height = height
36 | self.device = device
37 | self.trans_range = trans_range
38 | self.image_path_list = glob.glob(os.path.join(self.data_root, "*.jpg"))
39 | self.image_path_list += glob.glob(os.path.join(self.data_root, "*.png"))
40 |
41 | # set intrinsics
42 | self.K = torch.tensor([
43 | [0.58, 0, 0.5],
44 | [0, 0.58, 0.5],
45 | [0, 0, 1]
46 | ]).to(device)
47 |
48 | def __len__(self):
49 | return len(self.image_path_list)
50 |
51 | def __getitem__(self, idx):
52 | image_path = self.image_path_list[idx]
53 | image_name = os.path.splitext(os.path.basename(image_path))[0]
54 | disp_path = os.path.join(self.data_root, self.depth_dir_name, "%s.png" % image_name)
55 |
56 | image = image_to_tensor(image_path, unsqueeze=False) # [3,h,w]
57 | disp = disparity_to_tensor(disp_path, unsqueeze=False) # [1,h,w]
58 |
59 | # do some data augmentation, ensure the rgbd spatial resolution is (self.height, self.width)
60 | image, disp = self.preprocess_rgbd(image, disp)
61 |
62 | return image, disp
63 |
64 | def preprocess_rgbd(self, image, disp):
65 | # NOTE
66 | # (1) here we directly resize the image to the target size (self.height, self.width)
67 | # a better way is to first crop a random patch from the image according to the height-width ratio
68 | # then resize this patch to the target size
69 |         # (2) another suggestion is to add some code to filter the depth map to reduce artifacts around
70 | # depth discontinuities
71 | image = F.interpolate(image.unsqueeze(0), (self.height, self.width), mode="bilinear").squeeze(0)
72 | disp = F.interpolate(disp.unsqueeze(0), (self.height, self.width), mode="bilinear").squeeze(0)
73 | return image, disp
74 |
75 | def get_rand_ext(self, bs):
76 | x, y, z = self.trans_range['x'], self.trans_range['y'], self.trans_range['z']
77 | a, b, c = self.trans_range['a'], self.trans_range['b'], self.trans_range['c']
78 | cix = self.rand_tensor(x, bs)
79 | ciy = self.rand_tensor(y, bs)
80 | ciz = self.rand_tensor(z, bs)
81 | aix = self.rand_tensor(math.pi / a, bs)
82 | aiy = self.rand_tensor(math.pi / b, bs)
83 | aiz = self.rand_tensor(math.pi / c, bs)
84 |
85 | axisangle = torch.cat([aix, aiy, aiz], dim=-1) # [b,1,3]
86 | translation = torch.cat([cix, ciy, ciz], dim=-1)
87 |
88 | cam_ext = transformation_from_parameters(axisangle, translation) # [b,4,4]
89 | cam_ext_inv = torch.inverse(cam_ext) # [b,4,4]
90 | return cam_ext[:, :-1], cam_ext_inv[:, :-1]
91 |
92 | def rand_tensor(self, r, l):
93 | '''
94 |         return a tensor of shape [l,1,1], where each element is in range [-r,-r/2] or [r/2,r]
95 |         '''
96 |         if r < 0:  # we can set a negative value in self.trans_range to disable the random transformation
97 | return torch.zeros((l, 1, 1))
98 | rand = torch.rand((l, 1, 1))
99 | sign = 2 * (torch.randn_like(rand) > 0).float() - 1
100 | return sign * (r / 2 + r / 2 * rand)
101 |
102 | def collect_data(self, batch):
103 | batch = default_collate(batch)
104 | image, disp = batch
105 | image = image.to(self.device)
106 | disp = disp.to(self.device)
107 | rgbd = torch.cat([image, disp], dim=1) # [b,4,h,w]
108 | b = image.shape[0]
109 |
110 | cam_int = self.K.repeat(b, 1, 1) # [b,3,3]
111 | cam_ext, cam_ext_inv = self.get_rand_ext(b) # [b,3,4]
112 | cam_ext = cam_ext.to(self.device)
113 | cam_ext_inv = cam_ext_inv.to(self.device)
114 |
115 | # warp to a random novel view
116 | mesh = self.renderer.construct_mesh(rgbd, cam_int)
117 | warp_image, warp_disp, warp_mask = self.renderer.render_mesh(mesh, cam_int, cam_ext)
118 |
119 | # warp back to the original view
120 | warp_rgbd = torch.cat([warp_image, warp_disp], dim=1) # [b,4,h,w]
121 | warp_mesh = self.renderer.construct_mesh(warp_rgbd, cam_int)
122 | warp_back_image, warp_back_disp, mask = self.renderer.render_mesh(warp_mesh, cam_int, cam_ext_inv)
123 |
124 | # NOTE
125 | # (1) to train the inpainting network, you only need image, disp, and mask
126 | # (2) you can add some morphological operation to refine the mask
127 | return {
128 | "rgb": image,
129 | "disp": disp,
130 | "mask": mask,
131 | "warp_rgb": warp_image,
132 | "warp_disp": warp_disp,
133 | "warp_back_rgb": warp_back_image,
134 | "warp_back_disp": warp_back_disp,
135 | }
136 |
137 |
138 | if __name__ == "__main__":
139 | bs = 8
140 | data = WarpBackStage1Dataset(
141 | data_root="warpback/toydata",
142 | )
143 | loader = DataLoader(
144 | dataset=data,
145 | batch_size=bs,
146 | shuffle=True,
147 | collate_fn=data.collect_data,
148 | )
149 | for idx, batch in enumerate(loader):
150 | image, disp, mask = batch["rgb"], batch["disp"], batch["mask"]
151 | w_image, w_disp = batch["warp_rgb"], batch["warp_disp"]
152 | wb_image, wb_disp = batch["warp_back_rgb"], batch["warp_back_disp"]
153 | visual = torch.cat([
154 | image,
155 | disp.repeat(1, 3, 1, 1),
156 | mask.repeat(1, 3, 1, 1),
157 | wb_image,
158 | wb_disp.repeat(1, 3, 1, 1),
159 | w_image,
160 | w_disp.repeat(1, 3, 1, 1),
161 | ], dim=0)
162 | save_image(visual, "debug/stage1-%03d.jpg" % idx, nrow=bs)
163 |
--------------------------------------------------------------------------------
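The NOTE in collect_data above suggests refining the rendered mask with a morphological operation before training the inpainting network; below is one hedged way to do that (refine_mask is a hypothetical helper, and the 5x5 kernel is an arbitrary choice):

    import cv2
    import numpy as np
    import torch

    def refine_mask(mask, ksize=5):
        # mask: [b,1,h,w] float validity mask; erode it so that thin, unreliable pixels
        # near depth discontinuities are also treated as holes to be inpainted
        kernel = np.ones((ksize, ksize), np.uint8)
        mask_np = mask.squeeze(1).cpu().numpy().astype(np.uint8)    # [b,h,w], binarised
        eroded = np.stack([cv2.erode(m, kernel) for m in mask_np])  # [b,h,w]
        return torch.from_numpy(eroded).unsqueeze(1).float().to(mask.device)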
/warpback/stage2_dataset.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(".")
3 | sys.path.append("..")
4 | import os
5 | import glob
6 | import math
7 | import numpy as np
8 | from skimage.feature import canny
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.utils.data.dataset import Dataset
12 | from torch.utils.data.dataloader import DataLoader, default_collate
13 | from torchvision.utils import save_image
14 | from torchvision import transforms
15 |
16 | from warpback.utils import (
17 | RGBDRenderer,
18 | image_to_tensor,
19 | disparity_to_tensor,
20 | transformation_from_parameters,
21 | )
22 | from warpback.networks import get_edge_connect
23 |
24 |
25 | class WarpBackStage2Dataset(Dataset):
26 | def __init__(
27 | self,
28 | data_root,
29 | width=384,
30 | height=256,
31 | depth_dir_name="dpt_depth",
32 | device="cuda", # device of mesh renderer
33 | trans_range={"x":0.2, "y":-1, "z":-1, "a":-1, "b":-1, "c":-1}, # xyz for translation, abc for euler angle
34 | ec_weight_dir="warpback/ecweight",
35 | ):
36 | self.data_root = data_root
37 | self.depth_dir_name = depth_dir_name
38 | self.renderer = RGBDRenderer(device)
39 | self.width = width
40 | self.height = height
41 | self.device = device
42 | self.trans_range = trans_range
43 | self.image_path_list = glob.glob(os.path.join(self.data_root, "*.jpg"))
44 | self.image_path_list += glob.glob(os.path.join(self.data_root, "*.png"))
45 |
46 | # get Stage-1 pretrained inpainting network
47 | self.edge_model, self.inpaint_model, self.disp_model = get_edge_connect(ec_weight_dir)
48 | self.edge_model = self.edge_model.to(self.device)
49 | self.inpaint_model = self.inpaint_model.to(self.device)
50 | self.disp_model = self.disp_model.to(self.device)
51 |
52 | # set intrinsics
53 | self.K = torch.tensor([
54 | [0.58, 0, 0.5],
55 | [0, 0.58, 0.5],
56 | [0, 0, 1]
57 | ]).to(device)
58 |
59 | def __len__(self):
60 | return len(self.image_path_list)
61 |
62 | def __getitem__(self, idx):
63 | image_path = self.image_path_list[idx]
64 | image_name = os.path.splitext(os.path.basename(image_path))[0]
65 | disp_path = os.path.join(self.data_root, self.depth_dir_name, "%s.png" % image_name)
66 |
67 | image = image_to_tensor(image_path, unsqueeze=False) # [3,h,w]
68 | disp = disparity_to_tensor(disp_path, unsqueeze=False) # [1,h,w]
69 |
70 | # do some data augmentation, ensure the rgbd spatial resolution is (self.height, self.width)
71 | image, disp = self.preprocess_rgbd(image, disp)
72 |
73 | return image, disp
74 |
75 | def preprocess_rgbd(self, image, disp):
76 | # NOTE
77 | # (1) here we directly resize the image to the target size (self.height, self.width)
78 | # a better way is to first crop a random patch from the image according to the height-width ratio
79 | # then resize this patch to the target size
80 |         # (2) another suggestion is to add some code to filter the depth map to reduce artifacts around
81 | # depth discontinuities
82 | image = F.interpolate(image.unsqueeze(0), (self.height, self.width), mode="bilinear").squeeze(0)
83 | disp = F.interpolate(disp.unsqueeze(0), (self.height, self.width), mode="bilinear").squeeze(0)
84 | return image, disp
85 |
86 | def get_rand_ext(self, bs):
87 | x, y, z = self.trans_range['x'], self.trans_range['y'], self.trans_range['z']
88 | a, b, c = self.trans_range['a'], self.trans_range['b'], self.trans_range['c']
89 | cix = self.rand_tensor(x, bs)
90 | ciy = self.rand_tensor(y, bs)
91 | ciz = self.rand_tensor(z, bs)
92 | aix = self.rand_tensor(math.pi / a, bs)
93 | aiy = self.rand_tensor(math.pi / b, bs)
94 | aiz = self.rand_tensor(math.pi / c, bs)
95 |
96 | axisangle = torch.cat([aix, aiy, aiz], dim=-1) # [b,1,3]
97 | translation = torch.cat([cix, ciy, ciz], dim=-1)
98 |
99 | cam_ext = transformation_from_parameters(axisangle, translation) # [b,4,4]
100 | cam_ext_inv = torch.inverse(cam_ext) # [b,4,4]
101 | return cam_ext[:, :-1], cam_ext_inv[:, :-1]
102 |
103 | def rand_tensor(self, r, l):
104 | '''
105 |         return a tensor of shape [l,1,1], where each element is in range [-r,-r/2] or [r/2,r]
106 |         '''
107 |         if r < 0:  # we can set a negative value in self.trans_range to disable the random transformation
108 | return torch.zeros((l, 1, 1))
109 | rand = torch.rand((l, 1, 1))
110 | sign = 2 * (torch.randn_like(rand) > 0).float() - 1
111 | return sign * (r / 2 + r / 2 * rand)
112 |
113 | def inpaint(self, image, disp, mask):
114 | image_gray = transforms.Grayscale()(image)
115 | edge = self.get_edge(image_gray, mask)
116 |
117 | mask_hole = 1 - mask
118 |
119 | # inpaint edge
120 | edge_model_input = torch.cat([image_gray, edge, mask_hole], dim=1) # [b,4,h,w]
121 | edge_inpaint = self.edge_model(edge_model_input) # [b,1,h,w]
122 |
123 | # inpaint RGB
124 | inpaint_model_input = torch.cat([image + mask_hole, edge_inpaint], dim=1)
125 | image_inpaint = self.inpaint_model(inpaint_model_input)
126 | image_merged = image * (1 - mask_hole) + image_inpaint * mask_hole
127 |
128 | # inpaint Disparity
129 | disp_model_input = torch.cat([disp + mask_hole, edge_inpaint], dim=1)
130 | disp_inpaint = self.disp_model(disp_model_input)
131 | disp_merged = disp * (1 - mask_hole) + disp_inpaint * mask_hole
132 |
133 | return image_merged, disp_merged
134 |
135 | def get_edge(self, image_gray, mask):
136 | image_gray_np = image_gray.squeeze(1).cpu().numpy() # [b,h,w]
137 | mask_bool_np = np.array(mask.squeeze(1).cpu(), dtype=np.bool_) # [b,h,w]
138 | edges = []
139 | for i in range(mask.shape[0]):
140 | cur_edge = canny(image_gray_np[i], sigma=2, mask=mask_bool_np[i])
141 | edges.append(torch.from_numpy(cur_edge).unsqueeze(0)) # [1,h,w]
142 | edge = torch.cat(edges, dim=0).unsqueeze(1).float() # [b,1,h,w]
143 | return edge.to(self.device)
144 |
145 | def collect_data(self, batch):
146 | batch = default_collate(batch)
147 | image, disp = batch
148 | image = image.to(self.device)
149 | disp = disp.to(self.device)
150 | rgbd = torch.cat([image, disp], dim=1) # [b,4,h,w]
151 | b = image.shape[0]
152 |
153 | cam_int = self.K.repeat(b, 1, 1) # [b,3,3]
154 | cam_ext, cam_ext_inv = self.get_rand_ext(b) # [b,3,4]
155 | cam_ext = cam_ext.to(self.device)
156 | cam_ext_inv = cam_ext_inv.to(self.device)
157 |
158 | # warp to a random novel view and inpaint the holes
159 | # as the source view (input view) to the single-view view synthesis method
160 | mesh = self.renderer.construct_mesh(rgbd, cam_int)
161 | warp_image, warp_disp, warp_mask = self.renderer.render_mesh(mesh, cam_int, cam_ext)
162 |
163 | with torch.no_grad():
164 | src_image, src_disp = self.inpaint(warp_image, warp_disp, warp_mask)
165 |
166 | return {
167 | "src_rgb": src_image,
168 | "src_disp": src_disp,
169 | "tgt_rgb": image,
170 | "tgt_disp": disp,
171 | "warp_rgb": warp_image,
172 | "warp_disp": warp_disp,
173 | "cam_int": cam_int, # src and tgt view share the same intrinsic
174 | "cam_ext": cam_ext_inv,
175 | }
176 |
177 |
178 | if __name__ == "__main__":
179 | bs = 8
180 | data = WarpBackStage2Dataset(
181 | data_root="warpback/toydata",
182 | )
183 | loader = DataLoader(
184 | dataset=data,
185 | batch_size=bs,
186 | shuffle=True,
187 | collate_fn=data.collect_data,
188 | )
189 | for idx, batch in enumerate(loader):
190 | src_rgb, src_disp = batch["src_rgb"], batch["src_disp"]
191 | tgt_rgb, tgt_disp = batch["tgt_rgb"], batch["tgt_disp"]
192 | warp_rgb, warp_disp = batch["warp_rgb"], batch["warp_disp"]
193 | visual = torch.cat([
194 | warp_rgb,
195 | warp_disp.repeat(1, 3, 1, 1),
196 | src_rgb,
197 | src_disp.repeat(1, 3, 1, 1),
198 | tgt_rgb,
199 | tgt_disp.repeat(1, 3, 1, 1),
200 | ], dim=0)
201 | save_image(visual, "debug/stage2-%03d.jpg" % idx, nrow=bs)
202 |
--------------------------------------------------------------------------------
/write_flow.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 |
8 | cv2.setNumThreads(0)
9 | cv2.ocl.setUseOpenCL(False)
10 |
11 | TAG_CHAR = np.array([202021.25], np.float32)
12 |
13 |
14 | def readFlow(fn):
15 | """ Read .flo file in Middlebury format"""
16 | # Code adapted from:
17 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
18 |
19 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
20 | # print 'fn = %s'%(fn)
21 | with open(fn, 'rb') as f:
22 | magic = np.fromfile(f, np.float32, count=1)
23 | if 202021.25 != magic:
24 | print('Magic number incorrect. Invalid .flo file')
25 | return None
26 | else:
27 | w = np.fromfile(f, np.int32, count=1)
28 | h = np.fromfile(f, np.int32, count=1)
29 | # print 'Reading %d x %d flo file\n' % (w, h)
30 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
31 | # Reshape data into 3D array (columns, rows, bands)
32 | # The reshape here is for visualization, the original code is (w,h,2)
33 | return np.resize(data, (int(h), int(w), 2))
34 |
35 |
36 | def readPFM(file):
37 | file = open(file, 'rb')
38 |
39 | color = None
40 | width = None
41 | height = None
42 | scale = None
43 | endian = None
44 |
45 | header = file.readline().rstrip()
46 | if header == b'PF':
47 | color = True
48 | elif header == b'Pf':
49 | color = False
50 | else:
51 | raise Exception('Not a PFM file.')
52 |
53 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
54 | if dim_match:
55 | width, height = map(int, dim_match.groups())
56 | else:
57 | raise Exception('Malformed PFM header.')
58 |
59 | scale = float(file.readline().rstrip())
60 | if scale < 0: # little-endian
61 | endian = '<'
62 | scale = -scale
63 | else:
64 | endian = '>' # big-endian
65 |
66 | data = np.fromfile(file, endian + 'f')
67 | shape = (height, width, 3) if color else (height, width)
68 |
69 | data = np.reshape(data, shape)
70 | data = np.flipud(data)
71 | return data
72 |
73 |
74 | def writeFlow(filename, uv, v=None):
75 | """ Write optical flow to file.
76 |
77 | If v is None, uv is assumed to contain both u and v channels,
78 | stacked in depth.
79 | Original code by Deqing Sun, adapted from Daniel Scharstein.
80 | """
81 | nBands = 2
82 |
83 | if v is None:
84 | assert (uv.ndim == 3)
85 | assert (uv.shape[2] == 2)
86 | u = uv[:, :, 0]
87 | v = uv[:, :, 1]
88 | else:
89 | u = uv
90 |
91 | assert (u.shape == v.shape)
92 | height, width = u.shape
93 | f = open(filename, 'wb')
94 | # write the header
95 | f.write(TAG_CHAR)
96 | np.array(width).astype(np.int32).tofile(f)
97 | np.array(height).astype(np.int32).tofile(f)
98 | # arrange into matrix form
99 | tmp = np.zeros((height, width * nBands))
100 | tmp[:, np.arange(width) * 2] = u
101 | tmp[:, np.arange(width) * 2 + 1] = v
102 | tmp.astype(np.float32).tofile(f)
103 | f.close()
104 |
105 |
106 | def readFlowKITTI(filename):
107 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
108 | flow = flow[:, :, ::-1].astype(np.float32)
109 | flow, valid = flow[:, :, :2], flow[:, :, 2]
110 | flow = (flow - 2 ** 15) / 64.0
111 | return flow, valid
112 |
113 |
114 | def readDispKITTI(filename):
115 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
116 | valid = disp > 0.0
117 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
118 | return flow, valid
119 |
120 |
121 | def writeFlowKITTI(filename, uv):
122 | uv = 64.0 * uv + 2 ** 15
123 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
124 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
125 | cv2.imwrite(filename, uv[..., ::-1])
126 |
127 |
128 | def read_gen(file_name, pil=False):
129 | ext = splitext(file_name)[-1]
130 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
131 | return Image.open(file_name)
132 | elif ext == '.bin' or ext == '.raw':
133 | return np.load(file_name)
134 | elif ext == '.flo':
135 | return readFlow(file_name).astype(np.float32)
136 | elif ext == '.pfm':
137 | flow = readPFM(file_name).astype(np.float32)
138 | if len(flow.shape) == 2:
139 | return flow
140 | else:
141 | return flow[:, :, :-1]
142 | else:
143 | raise ValueError('wrong file type: %s' % ext)
144 |
145 | TAG_FLOAT = 202021.25
146 | def depth_read(filename):
147 | """ Read depth data from file, return as numpy array. """
148 | f = open(filename,'rb')
149 | check = np.fromfile(f,dtype=np.float32,count=1)[0]
150 | assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check)
151 | width = np.fromfile(f,dtype=np.int32,count=1)[0]
152 | height = np.fromfile(f,dtype=np.int32,count=1)[0]
153 | size = width*height
154 | assert width > 0 and height > 0 and size > 1 and size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height)
155 | depth = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width))
156 | return depth
157 |
158 | if __name__ == '__main__':
159 | import os
160 | from tqdm import tqdm
161 | base = '/Extra/guowx/data/dCOCO-mpi/coco/flow'
162 | out = '/Extra/guowx/data/dCOCO-mpi/coco/flo'
163 | if not os.path.exists(out):
164 | os.mkdir(out)
165 | for flow in tqdm(os.listdir(base)):
166 | flo = np.load(os.path.join(base, flow))
167 | pt = os.path.join(out, flow[:-4]+'.flo')
168 | if not os.path.exists(pt):
169 | writeFlow(pt, flo)
170 |
--------------------------------------------------------------------------------
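A hedged round-trip check for the .flo helpers above (it assumes a little-endian machine, as noted in readFlow, that the repo root is the working directory, and that the file name tmp.flo is arbitrary):

    import numpy as np
    from write_flow import readFlow, writeFlow

    flow = np.random.randn(4, 5, 2).astype(np.float32)  # tiny HxWx2 flow field
    writeFlow("tmp.flo", flow)
    flow_back = readFlow("tmp.flo")
    assert flow_back.shape == (4, 5, 2)
    assert np.allclose(flow, flow_back)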