├── .gitignore
├── LICENSE
├── RAFT.png
├── README.md
├── alt_cuda_corr
│   ├── correlation.cpp
│   ├── correlation_kernel.cu
│   └── setup.py
├── chairs_split.txt
├── core
│   ├── __init__.py
│   ├── corr.py
│   ├── datasets.py
│   ├── extractor.py
│   ├── raft.py
│   ├── update.py
│   └── utils
│       ├── __init__.py
│       ├── augmentor.py
│       ├── flow_viz.py
│       ├── frame_utils.py
│       └── utils.py
├── demo-frames
│   ├── frame_0016.png
│   ├── frame_0017.png
│   ├── frame_0018.png
│   ├── frame_0019.png
│   ├── frame_0020.png
│   ├── frame_0021.png
│   ├── frame_0022.png
│   ├── frame_0023.png
│   ├── frame_0024.png
│   └── frame_0025.png
├── demo.py
├── download_models.sh
├── evaluate.py
├── train.py
├── train_mixed.sh
└── train_standard.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.egg-info
3 | dist
4 | datasets
5 | pytorch_env
6 | models
7 | build
8 | correlation.egg-info
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020, princeton-vl
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/RAFT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/RAFT/3fa0bb0a9c633ea0a9bb8a79c576b6785d4e6a02/RAFT.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RAFT
2 | This repository contains the source code for our paper:
3 |
4 | [RAFT: Recurrent All-Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)
5 | ECCV 2020
6 | Zachary Teed and Jia Deng
7 |
8 |
9 |
10 | ## Requirements
11 | The code has been tested with PyTorch 1.6 and CUDA 10.1.
12 | ```Shell
13 | conda create --name raft
14 | conda activate raft
15 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch
16 | ```
17 |
18 | ## Demos
19 | Pretrained models can be downloaded by running
20 | ```Shell
21 | ./download_models.sh
22 | ```
23 | or downloaded from [google drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing)
24 |
25 | You can demo a trained model on a sequence of frames
26 | ```Shell
27 | python demo.py --model=models/raft-things.pth --path=demo-frames
28 | ```
29 |
30 | ## Required Data
31 | To evaluate/train RAFT, you will need to download the required datasets.
32 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
33 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
34 | * [Sintel](http://sintel.is.tue.mpg.de/)
35 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
36 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional)
37 |
38 |
39 | By default `datasets.py` will search for the datasets in these locations. You can create symbolic links in the `datasets` folder that point to wherever the datasets were downloaded:
40 |
41 | ```Shell
42 | ├── datasets
43 | ├── Sintel
44 | ├── test
45 | ├── training
46 | ├── KITTI
47 | ├── testing
48 | ├── training
49 | ├── devkit
50 | ├── FlyingChairs_release
51 | ├── data
52 | ├── FlyingThings3D
53 | ├── frames_cleanpass
54 | ├── frames_finalpass
55 | ├── optical_flow
56 | ```
57 |
58 | ## Evaluation
59 | You can evaluate a trained model using `evaluate.py`
60 | ```Shell
61 | python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision
62 | ```
63 |
64 | ## Training
65 | We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` directory, which can be visualized using tensorboard:
66 | ```Shell
67 | ./train_standard.sh
68 | ```
69 |
70 | If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU):
71 | ```Shell
72 | ./train_mixed.sh
73 | ```
74 |
75 | ## (Optional) Efficient Implementation
76 | You can optionally use our alternate (efficient) implementation by compiling the provided CUDA extension
77 | ```Shell
78 | cd alt_cuda_corr && python setup.py install && cd ..
79 | ```
80 | and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note that this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
81 |
--------------------------------------------------------------------------------
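
For reference, a minimal sketch of what `demo.py` does end-to-end. The checkpoint name comes from `download_models.sh`; the `DataParallel` wrapper (checkpoints are saved with `module.` key prefixes) and the synthetic input frames are assumptions for illustration — real frames need height and width divisible by 8, or padding:

```Python
import sys
sys.path.append('core')

from argparse import Namespace
import torch
from raft import RAFT
from utils import flow_viz

args = Namespace(small=False, mixed_precision=False, alternate_corr=False)
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load('models/raft-things.pth'))
model = model.module.cuda().eval()

# synthetic stand-ins for two RGB frames: float in [0, 255], shape [1, 3, H, W]
image1 = (torch.rand(1, 3, 440, 1024) * 255).cuda()
image2 = (torch.rand(1, 3, 440, 1024) * 255).cuda()

with torch.no_grad():
    flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)

# color-code the upsampled flow for display
flow_rgb = flow_viz.flow_to_image(flow_up[0].permute(1, 2, 0).cpu().numpy())
```
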
/alt_cuda_corr/correlation.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 | // CUDA forward declarations
5 | std::vector<torch::Tensor> corr_cuda_forward(
6 | torch::Tensor fmap1,
7 | torch::Tensor fmap2,
8 | torch::Tensor coords,
9 | int radius);
10 |
11 | std::vector<torch::Tensor> corr_cuda_backward(
12 | torch::Tensor fmap1,
13 | torch::Tensor fmap2,
14 | torch::Tensor coords,
15 | torch::Tensor corr_grad,
16 | int radius);
17 |
18 | // C++ interface
19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
22 |
23 | std::vector<torch::Tensor> corr_forward(
24 | torch::Tensor fmap1,
25 | torch::Tensor fmap2,
26 | torch::Tensor coords,
27 | int radius) {
28 | CHECK_INPUT(fmap1);
29 | CHECK_INPUT(fmap2);
30 | CHECK_INPUT(coords);
31 |
32 | return corr_cuda_forward(fmap1, fmap2, coords, radius);
33 | }
34 |
35 |
36 | std::vector<torch::Tensor> corr_backward(
37 | torch::Tensor fmap1,
38 | torch::Tensor fmap2,
39 | torch::Tensor coords,
40 | torch::Tensor corr_grad,
41 | int radius) {
42 | CHECK_INPUT(fmap1);
43 | CHECK_INPUT(fmap2);
44 | CHECK_INPUT(coords);
45 | CHECK_INPUT(corr_grad);
46 |
47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
48 | }
49 |
50 |
51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
52 | m.def("forward", &corr_forward, "CORR forward");
53 | m.def("backward", &corr_backward, "CORR backward");
54 | }
--------------------------------------------------------------------------------
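
The tensor layouts these entry points expect can be read off `AlternateCorrBlock` in `core/corr.py`: channels-last CUDA feature maps of shape `[B, H, W, C]` and query coordinates of shape `[B, N, H, W, 2]`. A hedged sketch of calling the compiled extension directly (shapes and the `(x, y)` coordinate order are inferred from the Python caller, not documented in this file):

```Python
import torch
import alt_cuda_corr  # requires: cd alt_cuda_corr && python setup.py install

B, H, W, C, r = 1, 55, 128, 256, 4
fmap1 = torch.randn(B, H, W, C, device='cuda').contiguous()
fmap2 = torch.randn(B, H, W, C, device='cuda').contiguous()

# identity lookup: each pixel queries its own (x, y) location in fmap2
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W))
coords = torch.stack([xs, ys], dim=-1).float().cuda()
coords = coords[None, None].contiguous()   # [B, N=1, H, W, 2]

corr, = alt_cuda_corr.forward(fmap1, fmap2, coords, r)
print(corr.shape)                          # [B, 1, (2r+1)**2, H, W]
```
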
/alt_cuda_corr/correlation_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <cuda.h>
3 | #include <cuda_runtime.h>
4 | #include <vector>
5 |
6 |
7 | #define BLOCK_H 4
8 | #define BLOCK_W 8
9 | #define BLOCK_HW BLOCK_H * BLOCK_W
10 | #define CHANNEL_STRIDE 32
11 |
12 |
13 | __forceinline__ __device__
14 | bool within_bounds(int h, int w, int H, int W) {
15 | return h >= 0 && h < H && w >= 0 && w < W;
16 | }
17 |
18 | template <typename scalar_t>
19 | __global__ void corr_forward_kernel(
20 |     const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
21 |     const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
22 |     const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
23 |     torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
24 | int r)
25 | {
26 | const int b = blockIdx.x;
27 | const int h0 = blockIdx.y * blockDim.x;
28 | const int w0 = blockIdx.z * blockDim.y;
29 | const int tid = threadIdx.x * blockDim.y + threadIdx.y;
30 |
31 | const int H1 = fmap1.size(1);
32 | const int W1 = fmap1.size(2);
33 | const int H2 = fmap2.size(1);
34 | const int W2 = fmap2.size(2);
35 | const int N = coords.size(1);
36 | const int C = fmap1.size(3);
37 |
38 | __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
39 | __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
40 | __shared__ scalar_t x2s[BLOCK_HW];
41 | __shared__ scalar_t y2s[BLOCK_HW];
42 |
43 |   for (int c=0; c<C; c+=CHANNEL_STRIDE) {
44 |     for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
45 |       int k1 = k + tid / CHANNEL_STRIDE;
46 |       int h1 = h0 + k1 / BLOCK_W;
47 |       int w1 = w0 + k1 % BLOCK_W;
48 |       int c1 = tid % CHANNEL_STRIDE;
49 | 
50 |       auto fptr = fmap1[b][h1][w1];
51 |       if (within_bounds(h1, w1, H1, W1))
52 |         f1[c1][k1] = fptr[c+c1];
53 |       else
54 |         f1[c1][k1] = 0.0;
55 |     }
56 | 
57 |     __syncthreads();
58 | 
59 |     for (int n=0; n<N; n++) {
60 |       int h1 = h0 + threadIdx.x;
61 |       int w1 = w0 + threadIdx.y;
62 |       if (within_bounds(h1, w1, H1, W1)) {
63 |         x2s[tid] = coords[b][n][h1][w1][0];
64 |         y2s[tid] = coords[b][n][h1][w1][1];
65 |       }
66 | 
67 |       scalar_t dx = x2s[tid] - floor(x2s[tid]);
68 |       scalar_t dy = y2s[tid] - floor(y2s[tid]);
69 | 
70 |       int rd = 2*r + 1;
71 |       for (int iy=0; iy<rd+1; iy++) {
72 |         for (int ix=0; ix<rd+1; ix++) {
73 |           for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
74 |             int k1 = k + tid / CHANNEL_STRIDE;
75 |             int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
76 |             int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
77 | int c2 = tid % CHANNEL_STRIDE;
78 |
79 | auto fptr = fmap2[b][h2][w2];
80 | if (within_bounds(h2, w2, H2, W2))
81 | f2[c2][k1] = fptr[c+c2];
82 | else
83 | f2[c2][k1] = 0.0;
84 | }
85 |
86 | __syncthreads();
87 |
88 | scalar_t s = 0.0;
89 |           for (int k=0; k<CHANNEL_STRIDE; k++)
90 |             s += f1[k][tid] * f2[k][tid];
91 | 
92 |           int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
93 |           int ix_ne = H1*W1*((iy-1) + rd*ix);
94 |           int ix_sw = H1*W1*(iy + rd*(ix-1));
95 |           int ix_se = H1*W1*(iy + rd*ix);
96 | 
97 |           scalar_t nw = s * (dy) * (dx);
98 |           scalar_t ne = s * (dy) * (1-dx);
99 |           scalar_t sw = s * (1-dy) * (dx);
100 |           scalar_t se = s * (1-dy) * (1-dx);
101 | 
102 |           scalar_t* corr_ptr = &corr[b][n][0][h1][w1];
103 | 
104 |           if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
105 | *(corr_ptr + ix_nw) += nw;
106 |
107 | if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
108 | *(corr_ptr + ix_ne) += ne;
109 |
110 | if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
111 | *(corr_ptr + ix_sw) += sw;
112 |
113 | if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
114 | *(corr_ptr + ix_se) += se;
115 | }
116 | }
117 | }
118 | }
119 | }
120 |
121 |
122 | template <typename scalar_t>
123 | __global__ void corr_backward_kernel(
124 |     const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
125 |     const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
126 |     const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
127 |     const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
128 |     torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
129 |     torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
130 |     torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
131 | int r)
132 | {
133 |
134 | const int b = blockIdx.x;
135 | const int h0 = blockIdx.y * blockDim.x;
136 | const int w0 = blockIdx.z * blockDim.y;
137 | const int tid = threadIdx.x * blockDim.y + threadIdx.y;
138 |
139 | const int H1 = fmap1.size(1);
140 | const int W1 = fmap1.size(2);
141 | const int H2 = fmap2.size(1);
142 | const int W2 = fmap2.size(2);
143 | const int N = coords.size(1);
144 | const int C = fmap1.size(3);
145 |
146 | __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
147 | __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
148 |
149 | __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
150 | __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];
151 |
152 | __shared__ scalar_t x2s[BLOCK_HW];
153 | __shared__ scalar_t y2s[BLOCK_HW];
154 |
155 |   for (int c=0; c<C; c+=CHANNEL_STRIDE) {
156 | 
157 |     for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
158 |       int k1 = k + tid / CHANNEL_STRIDE;
159 |       int h1 = h0 + k1 / BLOCK_W;
160 |       int w1 = w0 + k1 % BLOCK_W;
161 |       int c1 = tid % CHANNEL_STRIDE;
162 | 
163 |       auto fptr = fmap1[b][h1][w1];
164 |       if (within_bounds(h1, w1, H1, W1))
165 |         f1[c1][k1] = fptr[c+c1];
166 |       else
167 |         f1[c1][k1] = 0.0;
168 | 
169 |       f1_grad[c1][k1] = 0.0;
170 |     }
171 | 
172 |     __syncthreads();
173 | 
174 |     int h1 = h0 + threadIdx.x;
175 |     int w1 = w0 + threadIdx.y;
176 | 
177 |     for (int n=0; n<N; n++) {
178 |       if (within_bounds(h1, w1, H1, W1)) {
179 |         x2s[tid] = coords[b][n][h1][w1][0];
180 |         y2s[tid] = coords[b][n][h1][w1][1];
181 |       }
182 |       scalar_t dx = x2s[tid] - floor(x2s[tid]);
183 |       scalar_t dy = y2s[tid] - floor(y2s[tid]);
184 |       int rd = 2*r + 1;
185 |       for (int iy=0; iy<rd+1; iy++) {
186 |         for (int ix=0; ix<rd+1; ix++) {
187 |           for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
188 |             int k1 = k + tid / CHANNEL_STRIDE;
189 |             int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
190 |             int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
191 | int c2 = tid % CHANNEL_STRIDE;
192 |
193 | auto fptr = fmap2[b][h2][w2];
194 | if (within_bounds(h2, w2, H2, W2))
195 | f2[c2][k1] = fptr[c+c2];
196 | else
197 | f2[c2][k1] = 0.0;
198 |
199 | f2_grad[c2][k1] = 0.0;
200 | }
201 |
202 | __syncthreads();
203 |
204 | const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
205 | scalar_t g = 0.0;
206 |
207 | int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
208 | int ix_ne = H1*W1*((iy-1) + rd*ix);
209 | int ix_sw = H1*W1*(iy + rd*(ix-1));
210 | int ix_se = H1*W1*(iy + rd*ix);
211 |
212 | if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
213 | g += *(grad_ptr + ix_nw) * dy * dx;
214 |
215 | if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
216 | g += *(grad_ptr + ix_ne) * dy * (1-dx);
217 |
218 | if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
219 | g += *(grad_ptr + ix_sw) * (1-dy) * dx;
220 |
221 | if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
222 | g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);
223 |
224 |           for (int k=0; k<CHANNEL_STRIDE; k++) {
225 |             f1_grad[k][tid] += g * f2[k][tid];
226 |             f2_grad[k][tid] += g * f1[k][tid];
227 |           }
228 | 
229 |           for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
230 |             int k1 = k + tid / CHANNEL_STRIDE;
231 |             int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
232 |             int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
233 | int c2 = tid % CHANNEL_STRIDE;
234 |
235 | scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
236 | if (within_bounds(h2, w2, H2, W2))
237 | atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
238 | }
239 | }
240 | }
241 | }
242 | __syncthreads();
243 |
244 |
245 |     for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
246 |       int k1 = k + tid / CHANNEL_STRIDE;
247 |       int h1 = h0 + k1 / BLOCK_W;
248 |       int w1 = w0 + k1 % BLOCK_W;
249 |       int c1 = tid % CHANNEL_STRIDE;
250 | 
251 |       scalar_t* fptr = &fmap1_grad[b][h1][w1][0];
252 |       if (within_bounds(h1, w1, H1, W1))
253 |         fptr[c+c1] += f1_grad[c1][k1];
254 |     }
255 |   }
256 | }
257 | 
258 | 
259 | 
260 | std::vector<torch::Tensor> corr_cuda_forward(
261 | torch::Tensor fmap1,
262 | torch::Tensor fmap2,
263 | torch::Tensor coords,
264 | int radius)
265 | {
266 | const auto B = coords.size(0);
267 | const auto N = coords.size(1);
268 | const auto H = coords.size(2);
269 | const auto W = coords.size(3);
270 |
271 | const auto rd = 2 * radius + 1;
272 | auto opts = fmap1.options();
273 | auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);
274 |
275 | const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);
276 | const dim3 threads(BLOCK_H, BLOCK_W);
277 |
278 |   corr_forward_kernel<float><<<blocks, threads>>>(
279 |     fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
280 |     fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
281 |     coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
282 |     corr.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
283 | radius);
284 |
285 | return {corr};
286 | }
287 |
288 | std::vector<torch::Tensor> corr_cuda_backward(
289 | torch::Tensor fmap1,
290 | torch::Tensor fmap2,
291 | torch::Tensor coords,
292 | torch::Tensor corr_grad,
293 | int radius)
294 | {
295 | const auto B = coords.size(0);
296 | const auto N = coords.size(1);
297 |
298 | const auto H1 = fmap1.size(1);
299 | const auto W1 = fmap1.size(2);
300 | const auto H2 = fmap2.size(1);
301 | const auto W2 = fmap2.size(2);
302 | const auto C = fmap1.size(3);
303 |
304 | auto opts = fmap1.options();
305 | auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
306 | auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
307 | auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);
308 |
309 | const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
310 | const dim3 threads(BLOCK_H, BLOCK_W);
311 |
312 |
313 |   corr_backward_kernel<float><<<blocks, threads>>>(
314 |     fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
315 |     fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
316 |     coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
317 |     corr_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
318 |     fmap1_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
319 |     fmap2_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
320 |     coords_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
321 | radius);
322 |
323 | return {fmap1_grad, fmap2_grad, coords_grad};
324 | }
--------------------------------------------------------------------------------
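
Since `CorrBlock` and `AlternateCorrBlock` in `core/corr.py` compute the same multi-scale cost-volume lookup, one way to smoke-test this kernel is to compare the two on random features. A hedged sketch — close agreement up to float tolerance and border handling is an assumption, not a documented guarantee:

```Python
import sys
sys.path.append('core')

import torch
from corr import CorrBlock, AlternateCorrBlock
from utils.utils import coords_grid

fmap1 = torch.randn(1, 256, 48, 64, device='cuda')
fmap2 = torch.randn(1, 256, 48, 64, device='cuda')
coords = coords_grid(1, 48, 64, device='cuda')  # zero-flow query grid

ref = CorrBlock(fmap1, fmap2, radius=4)(coords)           # pure PyTorch
alt = AlternateCorrBlock(fmap1, fmap2, radius=4)(coords)  # this extension
print(ref.shape, (ref - alt).abs().max())  # interior pixels should agree closely
```
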
/alt_cuda_corr/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 |
5 | setup(
6 | name='correlation',
7 | ext_modules=[
8 | CUDAExtension('alt_cuda_corr',
9 | sources=['correlation.cpp', 'correlation_kernel.cu'],
10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
11 | ],
12 | cmdclass={
13 | 'build_ext': BuildExtension
14 | })
15 |
16 |
--------------------------------------------------------------------------------
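
After building, the module is importable as `alt_cuda_corr` (the name given in the `CUDAExtension` call above), which is exactly what `core/corr.py` tries to import. A quick availability check:

```Python
# minimal check, mirroring the try/except at the top of core/corr.py
try:
    import alt_cuda_corr
    print('alt_cuda_corr built:', hasattr(alt_cuda_corr, 'forward'))
except ImportError:
    print('not built; run `python setup.py install` inside alt_cuda_corr/')
```
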
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/RAFT/3fa0bb0a9c633ea0a9bb8a79c576b6785d4e6a02/core/__init__.py
--------------------------------------------------------------------------------
/core/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from utils.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 | import alt_cuda_corr
7 | except:
8 | # alt_cuda_corr is not compiled
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels-1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
38 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
40 |
41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 |
45 | corr = bilinear_sampler(corr, coords_lvl)
46 | corr = corr.view(batch, h1, w1, -1)
47 | out_pyramid.append(corr)
48 |
49 | out = torch.cat(out_pyramid, dim=-1)
50 | return out.permute(0, 3, 1, 2).contiguous().float()
51 |
52 | @staticmethod
53 | def corr(fmap1, fmap2):
54 | batch, dim, ht, wd = fmap1.shape
55 | fmap1 = fmap1.view(batch, dim, ht*wd)
56 | fmap2 = fmap2.view(batch, dim, ht*wd)
57 |
58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59 | corr = corr.view(batch, ht, wd, 1, ht, wd)
60 | return corr / torch.sqrt(torch.tensor(dim).float())
61 |
62 |
63 | class AlternateCorrBlock:
64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
65 | self.num_levels = num_levels
66 | self.radius = radius
67 |
68 | self.pyramid = [(fmap1, fmap2)]
69 | for i in range(self.num_levels):
70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
72 | self.pyramid.append((fmap1, fmap2))
73 |
74 | def __call__(self, coords):
75 | coords = coords.permute(0, 2, 3, 1)
76 | B, H, W, _ = coords.shape
77 | dim = self.pyramid[0][0].shape[1]
78 |
79 | corr_list = []
80 | for i in range(self.num_levels):
81 | r = self.radius
82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
84 |
85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
87 | corr_list.append(corr.squeeze(1))
88 |
89 | corr = torch.stack(corr_list, dim=1)
90 | corr = corr.reshape(B, -1, H, W)
91 | return corr / torch.sqrt(torch.tensor(dim).float())
92 |
--------------------------------------------------------------------------------
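
For orientation, a shape walk-through of `CorrBlock`: it precomputes all-pairs correlation once, pools it into a 4-level pyramid, and each call samples a `(2r+1)×(2r+1)` window per level around the current flow estimate. A minimal sketch with random features (sizes chosen divisible by 8 so the pyramid pools cleanly):

```Python
import sys
sys.path.append('core')

import torch
from corr import CorrBlock
from utils.utils import coords_grid

fmap1 = torch.randn(2, 256, 48, 64)
fmap2 = torch.randn(2, 256, 48, 64)
corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)

coords = coords_grid(2, 48, 64, device=fmap1.device)  # [2, 2, 48, 64], zero flow
cost = corr_fn(coords)
print(cost.shape)  # [2, 324, 48, 64]: num_levels * (2*radius+1)**2 = 4 * 81 channels
```
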
/core/datasets.py:
--------------------------------------------------------------------------------
1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2 |
3 | import numpy as np
4 | import torch
5 | import torch.utils.data as data
6 | import torch.nn.functional as F
7 |
8 | import os
9 | import math
10 | import random
11 | from glob import glob
12 | import os.path as osp
13 |
14 | from utils import frame_utils
15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
16 |
17 |
18 | class FlowDataset(data.Dataset):
19 | def __init__(self, aug_params=None, sparse=False):
20 | self.augmentor = None
21 | self.sparse = sparse
22 | if aug_params is not None:
23 | if sparse:
24 | self.augmentor = SparseFlowAugmentor(**aug_params)
25 | else:
26 | self.augmentor = FlowAugmentor(**aug_params)
27 |
28 | self.is_test = False
29 | self.init_seed = False
30 | self.flow_list = []
31 | self.image_list = []
32 | self.extra_info = []
33 |
34 | def __getitem__(self, index):
35 |
36 | if self.is_test:
37 | img1 = frame_utils.read_gen(self.image_list[index][0])
38 | img2 = frame_utils.read_gen(self.image_list[index][1])
39 | img1 = np.array(img1).astype(np.uint8)[..., :3]
40 | img2 = np.array(img2).astype(np.uint8)[..., :3]
41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
43 | return img1, img2, self.extra_info[index]
44 |
45 | if not self.init_seed:
46 | worker_info = torch.utils.data.get_worker_info()
47 | if worker_info is not None:
48 | torch.manual_seed(worker_info.id)
49 | np.random.seed(worker_info.id)
50 | random.seed(worker_info.id)
51 | self.init_seed = True
52 |
53 | index = index % len(self.image_list)
54 | valid = None
55 | if self.sparse:
56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
57 | else:
58 | flow = frame_utils.read_gen(self.flow_list[index])
59 |
60 | img1 = frame_utils.read_gen(self.image_list[index][0])
61 | img2 = frame_utils.read_gen(self.image_list[index][1])
62 |
63 | flow = np.array(flow).astype(np.float32)
64 | img1 = np.array(img1).astype(np.uint8)
65 | img2 = np.array(img2).astype(np.uint8)
66 |
67 | # grayscale images
68 | if len(img1.shape) == 2:
69 | img1 = np.tile(img1[...,None], (1, 1, 3))
70 | img2 = np.tile(img2[...,None], (1, 1, 3))
71 | else:
72 | img1 = img1[..., :3]
73 | img2 = img2[..., :3]
74 |
75 | if self.augmentor is not None:
76 | if self.sparse:
77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
78 | else:
79 | img1, img2, flow = self.augmentor(img1, img2, flow)
80 |
81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float()
84 |
85 | if valid is not None:
86 | valid = torch.from_numpy(valid)
87 | else:
88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
89 |
90 | return img1, img2, flow, valid.float()
91 |
92 |
93 | def __rmul__(self, v):
94 | self.flow_list = v * self.flow_list
95 | self.image_list = v * self.image_list
96 | return self
97 |
98 | def __len__(self):
99 | return len(self.image_list)
100 |
101 |
102 | class MpiSintel(FlowDataset):
103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
104 | super(MpiSintel, self).__init__(aug_params)
105 | flow_root = osp.join(root, split, 'flow')
106 | image_root = osp.join(root, split, dstype)
107 |
108 | if split == 'test':
109 | self.is_test = True
110 |
111 | for scene in os.listdir(image_root):
112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
113 | for i in range(len(image_list)-1):
114 | self.image_list += [ [image_list[i], image_list[i+1]] ]
115 | self.extra_info += [ (scene, i) ] # scene and frame_id
116 |
117 | if split != 'test':
118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
119 |
120 |
121 | class FlyingChairs(FlowDataset):
122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
123 | super(FlyingChairs, self).__init__(aug_params)
124 |
125 | images = sorted(glob(osp.join(root, '*.ppm')))
126 | flows = sorted(glob(osp.join(root, '*.flo')))
127 | assert (len(images)//2 == len(flows))
128 |
129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
130 | for i in range(len(flows)):
131 | xid = split_list[i]
132 | if (split=='training' and xid==1) or (split=='validation' and xid==2):
133 | self.flow_list += [ flows[i] ]
134 | self.image_list += [ [images[2*i], images[2*i+1]] ]
135 |
136 |
137 | class FlyingThings3D(FlowDataset):
138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
139 | super(FlyingThings3D, self).__init__(aug_params)
140 |
141 | for cam in ['left']:
142 | for direction in ['into_future', 'into_past']:
143 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
144 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
145 |
146 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
147 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
148 |
149 | for idir, fdir in zip(image_dirs, flow_dirs):
150 | images = sorted(glob(osp.join(idir, '*.png')) )
151 | flows = sorted(glob(osp.join(fdir, '*.pfm')) )
152 | for i in range(len(flows)-1):
153 | if direction == 'into_future':
154 | self.image_list += [ [images[i], images[i+1]] ]
155 | self.flow_list += [ flows[i] ]
156 | elif direction == 'into_past':
157 | self.image_list += [ [images[i+1], images[i]] ]
158 | self.flow_list += [ flows[i+1] ]
159 |
160 |
161 | class KITTI(FlowDataset):
162 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
163 | super(KITTI, self).__init__(aug_params, sparse=True)
164 | if split == 'testing':
165 | self.is_test = True
166 |
167 | root = osp.join(root, split)
168 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
169 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
170 |
171 | for img1, img2 in zip(images1, images2):
172 | frame_id = img1.split('/')[-1]
173 | self.extra_info += [ [frame_id] ]
174 | self.image_list += [ [img1, img2] ]
175 |
176 | if split == 'training':
177 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
178 |
179 |
180 | class HD1K(FlowDataset):
181 | def __init__(self, aug_params=None, root='datasets/HD1k'):
182 | super(HD1K, self).__init__(aug_params, sparse=True)
183 |
184 | seq_ix = 0
185 | while 1:
186 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
187 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
188 |
189 | if len(flows) == 0:
190 | break
191 |
192 | for i in range(len(flows)-1):
193 | self.flow_list += [flows[i]]
194 | self.image_list += [ [images[i], images[i+1]] ]
195 |
196 | seq_ix += 1
197 |
198 |
199 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
200 | """ Create the data loader for the corresponding trainign set """
201 |
202 | if args.stage == 'chairs':
203 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
204 | train_dataset = FlyingChairs(aug_params, split='training')
205 |
206 | elif args.stage == 'things':
207 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
208 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
209 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
210 | train_dataset = clean_dataset + final_dataset
211 |
212 | elif args.stage == 'sintel':
213 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
214 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
215 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
216 | sintel_final = MpiSintel(aug_params, split='training', dstype='final')
217 |
218 | if TRAIN_DS == 'C+T+K+S+H':
219 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
220 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
221 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
222 |
223 | elif TRAIN_DS == 'C+T+K/S':
224 | train_dataset = 100*sintel_clean + 100*sintel_final + things
225 |
226 | elif args.stage == 'kitti':
227 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
228 | train_dataset = KITTI(aug_params, split='training')
229 |
230 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
231 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
232 |
233 | print('Training with %d image pairs' % len(train_dataset))
234 | return train_loader
235 |
236 |
--------------------------------------------------------------------------------
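
`fetch_dataloader` is normally driven by `train.py`'s argparse flags. A hedged sketch of calling it directly — the `image_size` and `batch_size` values here are illustrative assumptions, `datasets/FlyingChairs_release/data` must exist, and it should run from the repository root so `chairs_split.txt` resolves:

```Python
import sys
sys.path.append('core')

from argparse import Namespace
import datasets

args = Namespace(stage='chairs', image_size=[368, 496], batch_size=6)
loader = datasets.fetch_dataloader(args)

img1, img2, flow, valid = next(iter(loader))
print(img1.shape, flow.shape, valid.shape)  # [B,3,H,W], [B,2,H,W], [B,H,W]
```
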
/core/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8 | super(ResidualBlock, self).__init__()
9 |
10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12 | self.relu = nn.ReLU(inplace=True)
13 |
14 | num_groups = planes // 8
15 |
16 | if norm_fn == 'group':
17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | if not stride == 1:
20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21 |
22 | elif norm_fn == 'batch':
23 | self.norm1 = nn.BatchNorm2d(planes)
24 | self.norm2 = nn.BatchNorm2d(planes)
25 | if not stride == 1:
26 | self.norm3 = nn.BatchNorm2d(planes)
27 |
28 | elif norm_fn == 'instance':
29 | self.norm1 = nn.InstanceNorm2d(planes)
30 | self.norm2 = nn.InstanceNorm2d(planes)
31 | if not stride == 1:
32 | self.norm3 = nn.InstanceNorm2d(planes)
33 |
34 | elif norm_fn == 'none':
35 | self.norm1 = nn.Sequential()
36 | self.norm2 = nn.Sequential()
37 | if not stride == 1:
38 | self.norm3 = nn.Sequential()
39 |
40 | if stride == 1:
41 | self.downsample = None
42 |
43 | else:
44 | self.downsample = nn.Sequential(
45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46 |
47 |
48 | def forward(self, x):
49 | y = x
50 | y = self.relu(self.norm1(self.conv1(y)))
51 | y = self.relu(self.norm2(self.conv2(y)))
52 |
53 | if self.downsample is not None:
54 | x = self.downsample(x)
55 |
56 | return self.relu(x+y)
57 |
58 |
59 |
60 | class BottleneckBlock(nn.Module):
61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
62 | super(BottleneckBlock, self).__init__()
63 |
64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
67 | self.relu = nn.ReLU(inplace=True)
68 |
69 | num_groups = planes // 8
70 |
71 | if norm_fn == 'group':
72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
75 | if not stride == 1:
76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
77 |
78 | elif norm_fn == 'batch':
79 | self.norm1 = nn.BatchNorm2d(planes//4)
80 | self.norm2 = nn.BatchNorm2d(planes//4)
81 | self.norm3 = nn.BatchNorm2d(planes)
82 | if not stride == 1:
83 | self.norm4 = nn.BatchNorm2d(planes)
84 |
85 | elif norm_fn == 'instance':
86 | self.norm1 = nn.InstanceNorm2d(planes//4)
87 | self.norm2 = nn.InstanceNorm2d(planes//4)
88 | self.norm3 = nn.InstanceNorm2d(planes)
89 | if not stride == 1:
90 | self.norm4 = nn.InstanceNorm2d(planes)
91 |
92 | elif norm_fn == 'none':
93 | self.norm1 = nn.Sequential()
94 | self.norm2 = nn.Sequential()
95 | self.norm3 = nn.Sequential()
96 | if not stride == 1:
97 | self.norm4 = nn.Sequential()
98 |
99 | if stride == 1:
100 | self.downsample = None
101 |
102 | else:
103 | self.downsample = nn.Sequential(
104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
105 |
106 |
107 | def forward(self, x):
108 | y = x
109 | y = self.relu(self.norm1(self.conv1(y)))
110 | y = self.relu(self.norm2(self.conv2(y)))
111 | y = self.relu(self.norm3(self.conv3(y)))
112 |
113 | if self.downsample is not None:
114 | x = self.downsample(x)
115 |
116 | return self.relu(x+y)
117 |
118 | class BasicEncoder(nn.Module):
119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
120 | super(BasicEncoder, self).__init__()
121 | self.norm_fn = norm_fn
122 |
123 | if self.norm_fn == 'group':
124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
125 |
126 | elif self.norm_fn == 'batch':
127 | self.norm1 = nn.BatchNorm2d(64)
128 |
129 | elif self.norm_fn == 'instance':
130 | self.norm1 = nn.InstanceNorm2d(64)
131 |
132 | elif self.norm_fn == 'none':
133 | self.norm1 = nn.Sequential()
134 |
135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
136 | self.relu1 = nn.ReLU(inplace=True)
137 |
138 | self.in_planes = 64
139 | self.layer1 = self._make_layer(64, stride=1)
140 | self.layer2 = self._make_layer(96, stride=2)
141 | self.layer3 = self._make_layer(128, stride=2)
142 |
143 | # output convolution
144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
145 |
146 | self.dropout = None
147 | if dropout > 0:
148 | self.dropout = nn.Dropout2d(p=dropout)
149 |
150 | for m in self.modules():
151 | if isinstance(m, nn.Conv2d):
152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
154 | if m.weight is not None:
155 | nn.init.constant_(m.weight, 1)
156 | if m.bias is not None:
157 | nn.init.constant_(m.bias, 0)
158 |
159 | def _make_layer(self, dim, stride=1):
160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
162 | layers = (layer1, layer2)
163 |
164 | self.in_planes = dim
165 | return nn.Sequential(*layers)
166 |
167 |
168 | def forward(self, x):
169 |
170 | # if input is list, combine batch dimension
171 | is_list = isinstance(x, tuple) or isinstance(x, list)
172 | if is_list:
173 | batch_dim = x[0].shape[0]
174 | x = torch.cat(x, dim=0)
175 |
176 | x = self.conv1(x)
177 | x = self.norm1(x)
178 | x = self.relu1(x)
179 |
180 | x = self.layer1(x)
181 | x = self.layer2(x)
182 | x = self.layer3(x)
183 |
184 | x = self.conv2(x)
185 |
186 | if self.training and self.dropout is not None:
187 | x = self.dropout(x)
188 |
189 | if is_list:
190 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
191 |
192 | return x
193 |
194 |
195 | class SmallEncoder(nn.Module):
196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
197 | super(SmallEncoder, self).__init__()
198 | self.norm_fn = norm_fn
199 |
200 | if self.norm_fn == 'group':
201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
202 |
203 | elif self.norm_fn == 'batch':
204 | self.norm1 = nn.BatchNorm2d(32)
205 |
206 | elif self.norm_fn == 'instance':
207 | self.norm1 = nn.InstanceNorm2d(32)
208 |
209 | elif self.norm_fn == 'none':
210 | self.norm1 = nn.Sequential()
211 |
212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
213 | self.relu1 = nn.ReLU(inplace=True)
214 |
215 | self.in_planes = 32
216 | self.layer1 = self._make_layer(32, stride=1)
217 | self.layer2 = self._make_layer(64, stride=2)
218 | self.layer3 = self._make_layer(96, stride=2)
219 |
220 | self.dropout = None
221 | if dropout > 0:
222 | self.dropout = nn.Dropout2d(p=dropout)
223 |
224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
225 |
226 | for m in self.modules():
227 | if isinstance(m, nn.Conv2d):
228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
230 | if m.weight is not None:
231 | nn.init.constant_(m.weight, 1)
232 | if m.bias is not None:
233 | nn.init.constant_(m.bias, 0)
234 |
235 | def _make_layer(self, dim, stride=1):
236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
238 | layers = (layer1, layer2)
239 |
240 | self.in_planes = dim
241 | return nn.Sequential(*layers)
242 |
243 |
244 | def forward(self, x):
245 |
246 | # if input is list, combine batch dimension
247 | is_list = isinstance(x, tuple) or isinstance(x, list)
248 | if is_list:
249 | batch_dim = x[0].shape[0]
250 | x = torch.cat(x, dim=0)
251 |
252 | x = self.conv1(x)
253 | x = self.norm1(x)
254 | x = self.relu1(x)
255 |
256 | x = self.layer1(x)
257 | x = self.layer2(x)
258 | x = self.layer3(x)
259 | x = self.conv2(x)
260 |
261 | if self.training and self.dropout is not None:
262 | x = self.dropout(x)
263 |
264 | if is_list:
265 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
266 |
267 | return x
268 |
--------------------------------------------------------------------------------
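
Both encoders map images to 1/8-resolution features via three stride-2 stages (the initial 7×7 conv plus `layer2` and `layer3`), and passing a list batches both frames through a single forward pass, as `RAFT.forward` does. A minimal sketch:

```Python
import sys
sys.path.append('core')

import torch
from extractor import BasicEncoder

fnet = BasicEncoder(output_dim=256, norm_fn='instance').eval()

x = torch.randn(1, 3, 256, 512)
print(fnet(x).shape)            # [1, 256, 32, 64]: 1/8 of the input resolution

f1, f2 = fnet([x, x.clone()])   # list input -> concatenated batch, then split
print(f1.shape, f2.shape)
```
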
/core/raft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from update import BasicUpdateBlock, SmallUpdateBlock
7 | from extractor import BasicEncoder, SmallEncoder
8 | from corr import CorrBlock, AlternateCorrBlock
9 | from utils.utils import bilinear_sampler, coords_grid, upflow8
10 |
11 | try:
12 | autocast = torch.cuda.amp.autocast
13 | except:
14 | # dummy autocast for PyTorch < 1.6
15 | class autocast:
16 | def __init__(self, enabled):
17 | pass
18 | def __enter__(self):
19 | pass
20 | def __exit__(self, *args):
21 | pass
22 |
23 |
24 | class RAFT(nn.Module):
25 | def __init__(self, args):
26 | super(RAFT, self).__init__()
27 | self.args = args
28 |
29 | if args.small:
30 | self.hidden_dim = hdim = 96
31 | self.context_dim = cdim = 64
32 | args.corr_levels = 4
33 | args.corr_radius = 3
34 |
35 | else:
36 | self.hidden_dim = hdim = 128
37 | self.context_dim = cdim = 128
38 | args.corr_levels = 4
39 | args.corr_radius = 4
40 |
41 | if 'dropout' not in self.args:
42 | self.args.dropout = 0
43 |
44 | if 'alternate_corr' not in self.args:
45 | self.args.alternate_corr = False
46 |
47 | # feature network, context network, and update block
48 | if args.small:
49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
52 |
53 | else:
54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
57 |
58 | def freeze_bn(self):
59 | for m in self.modules():
60 | if isinstance(m, nn.BatchNorm2d):
61 | m.eval()
62 |
63 | def initialize_flow(self, img):
64 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
65 | N, C, H, W = img.shape
66 | coords0 = coords_grid(N, H//8, W//8, device=img.device)
67 | coords1 = coords_grid(N, H//8, W//8, device=img.device)
68 |
69 | # optical flow computed as difference: flow = coords1 - coords0
70 | return coords0, coords1
71 |
72 | def upsample_flow(self, flow, mask):
73 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
74 | N, _, H, W = flow.shape
75 | mask = mask.view(N, 1, 9, 8, 8, H, W)
76 | mask = torch.softmax(mask, dim=2)
77 |
78 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
79 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
80 |
81 | up_flow = torch.sum(mask * up_flow, dim=2)
82 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
83 | return up_flow.reshape(N, 2, 8*H, 8*W)
84 |
85 |
86 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
87 | """ Estimate optical flow between pair of frames """
88 |
89 | image1 = 2 * (image1 / 255.0) - 1.0
90 | image2 = 2 * (image2 / 255.0) - 1.0
91 |
92 | image1 = image1.contiguous()
93 | image2 = image2.contiguous()
94 |
95 | hdim = self.hidden_dim
96 | cdim = self.context_dim
97 |
98 | # run the feature network
99 | with autocast(enabled=self.args.mixed_precision):
100 | fmap1, fmap2 = self.fnet([image1, image2])
101 |
102 | fmap1 = fmap1.float()
103 | fmap2 = fmap2.float()
104 | if self.args.alternate_corr:
105 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
106 | else:
107 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
108 |
109 | # run the context network
110 | with autocast(enabled=self.args.mixed_precision):
111 | cnet = self.cnet(image1)
112 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
113 | net = torch.tanh(net)
114 | inp = torch.relu(inp)
115 |
116 | coords0, coords1 = self.initialize_flow(image1)
117 |
118 | if flow_init is not None:
119 | coords1 = coords1 + flow_init
120 |
121 | flow_predictions = []
122 | for itr in range(iters):
123 | coords1 = coords1.detach()
124 | corr = corr_fn(coords1) # index correlation volume
125 |
126 | flow = coords1 - coords0
127 | with autocast(enabled=self.args.mixed_precision):
128 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
129 |
130 | # F(t+1) = F(t) + \Delta(t)
131 | coords1 = coords1 + delta_flow
132 |
133 | # upsample predictions
134 | if up_mask is None:
135 | flow_up = upflow8(coords1 - coords0)
136 | else:
137 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
138 |
139 | flow_predictions.append(flow_up)
140 |
141 | if test_mode:
142 | return coords1 - coords0, flow_up
143 |
144 | return flow_predictions
145 |
--------------------------------------------------------------------------------
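
A hedged sketch of the two calling conventions visible in `forward` above: training mode returns the list of per-iteration upsampled predictions, while `test_mode` returns the final `(1/8-resolution flow, upsampled flow)` pair, and the low-resolution flow can warm-start a subsequent call through `flow_init` (random weights here; shapes only):

```Python
import sys
sys.path.append('core')

from argparse import Namespace
import torch
from raft import RAFT

args = Namespace(small=False, mixed_precision=False, alternate_corr=False)
model = RAFT(args).eval()

image1 = torch.rand(1, 3, 64, 128) * 255
image2 = torch.rand(1, 3, 64, 128) * 255

with torch.no_grad():
    flow_low, flow_up = model(image1, image2, iters=12, test_mode=True)
    # warm start: reuse the previous 1/8-resolution estimate
    _, flow_up2 = model(image1, image2, iters=6, flow_init=flow_low, test_mode=True)

print(flow_low.shape, flow_up.shape)  # [1, 2, 8, 16], [1, 2, 64, 128]
```
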
/core/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class ConvGRU(nn.Module):
17 | def __init__(self, hidden_dim=128, input_dim=192+128):
18 | super(ConvGRU, self).__init__()
19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
22 |
23 | def forward(self, h, x):
24 | hx = torch.cat([h, x], dim=1)
25 |
26 | z = torch.sigmoid(self.convz(hx))
27 | r = torch.sigmoid(self.convr(hx))
28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
29 |
30 | h = (1-z) * h + z * q
31 | return h
32 |
33 | class SepConvGRU(nn.Module):
34 | def __init__(self, hidden_dim=128, input_dim=192+128):
35 | super(SepConvGRU, self).__init__()
36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39 |
40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43 |
44 |
45 | def forward(self, h, x):
46 | # horizontal
47 | hx = torch.cat([h, x], dim=1)
48 | z = torch.sigmoid(self.convz1(hx))
49 | r = torch.sigmoid(self.convr1(hx))
50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
51 | h = (1-z) * h + z * q
52 |
53 | # vertical
54 | hx = torch.cat([h, x], dim=1)
55 | z = torch.sigmoid(self.convz2(hx))
56 | r = torch.sigmoid(self.convr2(hx))
57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
58 | h = (1-z) * h + z * q
59 |
60 | return h
61 |
62 | class SmallMotionEncoder(nn.Module):
63 | def __init__(self, args):
64 | super(SmallMotionEncoder, self).__init__()
65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
69 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
70 |
71 | def forward(self, flow, corr):
72 | cor = F.relu(self.convc1(corr))
73 | flo = F.relu(self.convf1(flow))
74 | flo = F.relu(self.convf2(flo))
75 | cor_flo = torch.cat([cor, flo], dim=1)
76 | out = F.relu(self.conv(cor_flo))
77 | return torch.cat([out, flow], dim=1)
78 |
79 | class BasicMotionEncoder(nn.Module):
80 | def __init__(self, args):
81 | super(BasicMotionEncoder, self).__init__()
82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
88 |
89 | def forward(self, flow, corr):
90 | cor = F.relu(self.convc1(corr))
91 | cor = F.relu(self.convc2(cor))
92 | flo = F.relu(self.convf1(flow))
93 | flo = F.relu(self.convf2(flo))
94 |
95 | cor_flo = torch.cat([cor, flo], dim=1)
96 | out = F.relu(self.conv(cor_flo))
97 | return torch.cat([out, flow], dim=1)
98 |
99 | class SmallUpdateBlock(nn.Module):
100 | def __init__(self, args, hidden_dim=96):
101 | super(SmallUpdateBlock, self).__init__()
102 | self.encoder = SmallMotionEncoder(args)
103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
105 |
106 | def forward(self, net, inp, corr, flow):
107 | motion_features = self.encoder(flow, corr)
108 | inp = torch.cat([inp, motion_features], dim=1)
109 | net = self.gru(net, inp)
110 | delta_flow = self.flow_head(net)
111 |
112 | return net, None, delta_flow
113 |
114 | class BasicUpdateBlock(nn.Module):
115 | def __init__(self, args, hidden_dim=128, input_dim=128):
116 | super(BasicUpdateBlock, self).__init__()
117 | self.args = args
118 | self.encoder = BasicMotionEncoder(args)
119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
121 |
122 | self.mask = nn.Sequential(
123 | nn.Conv2d(128, 256, 3, padding=1),
124 | nn.ReLU(inplace=True),
125 | nn.Conv2d(256, 64*9, 1, padding=0))
126 |
127 | def forward(self, net, inp, corr, flow, upsample=True):
128 | motion_features = self.encoder(flow, corr)
129 | inp = torch.cat([inp, motion_features], dim=1)
130 |
131 | net = self.gru(net, inp)
132 | delta_flow = self.flow_head(net)
133 |
134 |         # scale mask to balance gradients
135 | mask = .25 * self.mask(net)
136 | return net, mask, delta_flow
137 |
138 |
139 |
140 |
--------------------------------------------------------------------------------
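
For orientation, a shape walk-through of one `BasicUpdateBlock` step with the non-small configuration from `raft.py` (hidden=128, context=128, `corr_levels=4`, `corr_radius=4`, so the cost-volume lookup has 4·9² = 324 channels):

```Python
import sys
sys.path.append('core')

from argparse import Namespace
import torch
from update import BasicUpdateBlock

args = Namespace(corr_levels=4, corr_radius=4)
block = BasicUpdateBlock(args, hidden_dim=128)

B, H, W = 1, 8, 16
net = torch.randn(B, 128, H, W)   # GRU hidden state (tanh half of the context split)
inp = torch.randn(B, 128, H, W)   # context features (relu half of the context split)
corr = torch.randn(B, 324, H, W)  # cost-volume lookup
flow = torch.zeros(B, 2, H, W)    # current flow estimate

net, up_mask, delta_flow = block(net, inp, corr, flow)
print(net.shape, up_mask.shape, delta_flow.shape)  # [1,128,8,16] [1,576,8,16] [1,2,8,16]
```

The 576 = 64·9 mask channels are reshaped to `[N, 1, 9, 8, 8, H, W]` in `RAFT.upsample_flow` to form per-pixel convex weights over a 3×3 neighborhood.
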
/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/RAFT/3fa0bb0a9c633ea0a9bb8a79c576b6785d4e6a02/core/utils/__init__.py
--------------------------------------------------------------------------------
/core/utils/augmentor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import math
4 | from PIL import Image
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | import torch
11 | from torchvision.transforms import ColorJitter
12 | import torch.nn.functional as F
13 |
14 |
15 | class FlowAugmentor:
16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
17 |
18 | # spatial augmentation params
19 | self.crop_size = crop_size
20 | self.min_scale = min_scale
21 | self.max_scale = max_scale
22 | self.spatial_aug_prob = 0.8
23 | self.stretch_prob = 0.8
24 | self.max_stretch = 0.2
25 |
26 | # flip augmentation params
27 | self.do_flip = do_flip
28 | self.h_flip_prob = 0.5
29 | self.v_flip_prob = 0.1
30 |
31 | # photometric augmentation params
32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
33 | self.asymmetric_color_aug_prob = 0.2
34 | self.eraser_aug_prob = 0.5
35 |
36 | def color_transform(self, img1, img2):
37 | """ Photometric augmentation """
38 |
39 | # asymmetric
40 | if np.random.rand() < self.asymmetric_color_aug_prob:
41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
43 |
44 | # symmetric
45 | else:
46 | image_stack = np.concatenate([img1, img2], axis=0)
47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
48 | img1, img2 = np.split(image_stack, 2, axis=0)
49 |
50 | return img1, img2
51 |
52 | def eraser_transform(self, img1, img2, bounds=[50, 100]):
53 | """ Occlusion augmentation """
54 |
55 | ht, wd = img1.shape[:2]
56 | if np.random.rand() < self.eraser_aug_prob:
57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
58 | for _ in range(np.random.randint(1, 3)):
59 | x0 = np.random.randint(0, wd)
60 | y0 = np.random.randint(0, ht)
61 | dx = np.random.randint(bounds[0], bounds[1])
62 | dy = np.random.randint(bounds[0], bounds[1])
63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
64 |
65 | return img1, img2
66 |
67 | def spatial_transform(self, img1, img2, flow):
68 | # randomly sample scale
69 | ht, wd = img1.shape[:2]
70 | min_scale = np.maximum(
71 | (self.crop_size[0] + 8) / float(ht),
72 | (self.crop_size[1] + 8) / float(wd))
73 |
74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
75 | scale_x = scale
76 | scale_y = scale
77 | if np.random.rand() < self.stretch_prob:
78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
80 |
81 | scale_x = np.clip(scale_x, min_scale, None)
82 | scale_y = np.clip(scale_y, min_scale, None)
83 |
84 | if np.random.rand() < self.spatial_aug_prob:
85 | # rescale the images
86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
89 | flow = flow * [scale_x, scale_y]
90 |
91 | if self.do_flip:
92 | if np.random.rand() < self.h_flip_prob: # h-flip
93 | img1 = img1[:, ::-1]
94 | img2 = img2[:, ::-1]
95 | flow = flow[:, ::-1] * [-1.0, 1.0]
96 |
97 | if np.random.rand() < self.v_flip_prob: # v-flip
98 | img1 = img1[::-1, :]
99 | img2 = img2[::-1, :]
100 | flow = flow[::-1, :] * [1.0, -1.0]
101 |
102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
104 |
105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
106 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
108 |
109 | return img1, img2, flow
110 |
111 | def __call__(self, img1, img2, flow):
112 | img1, img2 = self.color_transform(img1, img2)
113 | img1, img2 = self.eraser_transform(img1, img2)
114 | img1, img2, flow = self.spatial_transform(img1, img2, flow)
115 |
116 | img1 = np.ascontiguousarray(img1)
117 | img2 = np.ascontiguousarray(img2)
118 | flow = np.ascontiguousarray(flow)
119 |
120 | return img1, img2, flow
121 |
122 | class SparseFlowAugmentor:
123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
124 | # spatial augmentation params
125 | self.crop_size = crop_size
126 | self.min_scale = min_scale
127 | self.max_scale = max_scale
128 | self.spatial_aug_prob = 0.8
129 | self.stretch_prob = 0.8
130 | self.max_stretch = 0.2
131 |
132 | # flip augmentation params
133 | self.do_flip = do_flip
134 | self.h_flip_prob = 0.5
135 | self.v_flip_prob = 0.1
136 |
137 | # photometric augmentation params
138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
139 | self.asymmetric_color_aug_prob = 0.2
140 | self.eraser_aug_prob = 0.5
141 |
142 | def color_transform(self, img1, img2):
143 | image_stack = np.concatenate([img1, img2], axis=0)
144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
145 | img1, img2 = np.split(image_stack, 2, axis=0)
146 | return img1, img2
147 |
148 | def eraser_transform(self, img1, img2):
149 | ht, wd = img1.shape[:2]
150 | if np.random.rand() < self.eraser_aug_prob:
151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
152 | for _ in range(np.random.randint(1, 3)):
153 | x0 = np.random.randint(0, wd)
154 | y0 = np.random.randint(0, ht)
155 | dx = np.random.randint(50, 100)
156 | dy = np.random.randint(50, 100)
157 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
158 |
159 | return img1, img2
160 |
161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
162 | ht, wd = flow.shape[:2]
163 | coords = np.meshgrid(np.arange(wd), np.arange(ht))
164 | coords = np.stack(coords, axis=-1)
165 |
166 | coords = coords.reshape(-1, 2).astype(np.float32)
167 | flow = flow.reshape(-1, 2).astype(np.float32)
168 | valid = valid.reshape(-1).astype(np.float32)
169 |
170 | coords0 = coords[valid>=1]
171 | flow0 = flow[valid>=1]
172 |
173 | ht1 = int(round(ht * fy))
174 | wd1 = int(round(wd * fx))
175 |
176 | coords1 = coords0 * [fx, fy]
177 | flow1 = flow0 * [fx, fy]
178 |
179 | xx = np.round(coords1[:,0]).astype(np.int32)
180 | yy = np.round(coords1[:,1]).astype(np.int32)
181 |
182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
183 | xx = xx[v]
184 | yy = yy[v]
185 | flow1 = flow1[v]
186 |
187 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32)
189 |
190 | flow_img[yy, xx] = flow1
191 | valid_img[yy, xx] = 1
192 |
193 | return flow_img, valid_img
194 |
195 | def spatial_transform(self, img1, img2, flow, valid):
196 | # randomly sample scale
197 |
198 | ht, wd = img1.shape[:2]
199 | min_scale = np.maximum(
200 | (self.crop_size[0] + 1) / float(ht),
201 | (self.crop_size[1] + 1) / float(wd))
202 |
203 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
204 | scale_x = np.clip(scale, min_scale, None)
205 | scale_y = np.clip(scale, min_scale, None)
206 |
207 | if np.random.rand() < self.spatial_aug_prob:
208 | # rescale the images
209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
212 |
213 | if self.do_flip:
214 | if np.random.rand() < 0.5: # h-flip
215 | img1 = img1[:, ::-1]
216 | img2 = img2[:, ::-1]
217 | flow = flow[:, ::-1] * [-1.0, 1.0]
218 | valid = valid[:, ::-1]
219 |
220 | margin_y = 20
221 | margin_x = 50
222 |
223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
225 |
226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
228 |
229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
233 | return img1, img2, flow, valid
234 |
235 |
236 | def __call__(self, img1, img2, flow, valid):
237 | img1, img2 = self.color_transform(img1, img2)
238 | img1, img2 = self.eraser_transform(img1, img2)
239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
240 |
241 | img1 = np.ascontiguousarray(img1)
242 | img2 = np.ascontiguousarray(img2)
243 | flow = np.ascontiguousarray(flow)
244 | valid = np.ascontiguousarray(valid)
245 |
246 | return img1, img2, flow, valid
247 |
--------------------------------------------------------------------------------
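
A minimal usage sketch: the augmentor consumes uint8 `H×W×3` images and a float32 `H×W×2` flow, and returns crops of exactly `crop_size`; the input should comfortably exceed `crop_size` so the random rescale and crop always fit. The sizes below are illustrative assumptions:

```Python
import sys
sys.path.append('core')

import numpy as np
from utils.augmentor import FlowAugmentor

aug = FlowAugmentor(crop_size=(368, 496), min_scale=-0.1, max_scale=1.0, do_flip=True)

img1 = np.random.randint(0, 255, (384, 512, 3), dtype=np.uint8)
img2 = np.random.randint(0, 255, (384, 512, 3), dtype=np.uint8)
flow = np.random.randn(384, 512, 2).astype(np.float32)

img1a, img2a, flowa = aug(img1, img2, flow)
print(img1a.shape, flowa.shape)  # (368, 496, 3), (368, 496, 2)
```
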
/core/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 |     Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 |     fk = (a + 1) / 2 * (ncols - 1)  # map angle in [-1, 1] onto [0, ncols-1]
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 |     Expects a flow image of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 |         clip_flow (float, optional): Clip flow values to [-clip_flow, clip_flow]. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 |         flow_uv = np.clip(flow_uv, -clip_flow, clip_flow)  # clip both flow components symmetrically
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
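133 | 
134 | 
135 | if __name__ == '__main__':
136 |     # Illustrative usage (not part of the original file): visualize a small
137 |     # synthetic flow field; hue encodes direction, saturation encodes magnitude.
138 |     gx, gy = np.meshgrid(np.linspace(-1, 1, 64), np.linspace(-1, 1, 64))
139 |     flow = np.stack([gx, gy], axis=-1).astype(np.float32)
140 |     rgb = flow_to_image(flow)
141 |     print(rgb.shape, rgb.dtype)  # (64, 64, 3) uint8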
--------------------------------------------------------------------------------
/core/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import splitext
4 | import re
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | TAG_CHAR = np.array([202021.25], np.float32)
11 |
12 | def readFlow(fn):
13 | """ Read .flo file in Middlebury format"""
14 | # Code adapted from:
15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16 |
17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18 | # print 'fn = %s'%(fn)
19 | with open(fn, 'rb') as f:
20 | magic = np.fromfile(f, np.float32, count=1)
21 | if 202021.25 != magic:
22 | print('Magic number incorrect. Invalid .flo file')
23 | return None
24 | else:
25 | w = np.fromfile(f, np.int32, count=1)
26 | h = np.fromfile(f, np.int32, count=1)
27 | # print 'Reading %d x %d flo file\n' % (w, h)
28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
29 | # Reshape data into 3D array (columns, rows, bands)
30 | # The reshape here is for visualization, the original code is (w,h,2)
31 | return np.resize(data, (int(h), int(w), 2))
32 |
33 | def readPFM(file):
34 | file = open(file, 'rb')
35 |
36 | color = None
37 | width = None
38 | height = None
39 | scale = None
40 | endian = None
41 |
42 | header = file.readline().rstrip()
43 | if header == b'PF':
44 | color = True
45 | elif header == b'Pf':
46 | color = False
47 | else:
48 | raise Exception('Not a PFM file.')
49 |
50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
51 | if dim_match:
52 | width, height = map(int, dim_match.groups())
53 | else:
54 | raise Exception('Malformed PFM header.')
55 |
56 | scale = float(file.readline().rstrip())
57 | if scale < 0: # little-endian
58 | endian = '<'
59 | scale = -scale
60 | else:
61 | endian = '>' # big-endian
62 |
63 | data = np.fromfile(file, endian + 'f')
64 | shape = (height, width, 3) if color else (height, width)
65 |
66 | data = np.reshape(data, shape)
67 | data = np.flipud(data)
68 | return data
69 |
70 | def writeFlow(filename,uv,v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert(uv.ndim == 3)
81 | assert(uv.shape[2] == 2)
82 | u = uv[:,:,0]
83 | v = uv[:,:,1]
84 | else:
85 | u = uv
86 |
87 | assert(u.shape == v.shape)
88 | height,width = u.shape
89 | f = open(filename,'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width*nBands))
96 | tmp[:,np.arange(width)*2] = u
97 | tmp[:,np.arange(width)*2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
104 | flow = flow[:,:,::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2**15) / 64.0
107 | return flow, valid
108 |
109 | def readDispKITTI(filename):
110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
111 | valid = disp > 0.0
112 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
113 | return flow, valid
114 |
115 |
116 | def writeFlowKITTI(filename, uv):
117 | uv = 64.0 * uv + 2**15
118 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
120 | cv2.imwrite(filename, uv[..., ::-1])
121 |
122 |
123 | def read_gen(file_name, pil=False):
124 | ext = splitext(file_name)[-1]
125 |     if ext in ('.png', '.jpeg', '.ppm', '.jpg'):
126 | return Image.open(file_name)
127 |     elif ext in ('.bin', '.raw'):
128 | return np.load(file_name)
129 | elif ext == '.flo':
130 | return readFlow(file_name).astype(np.float32)
131 | elif ext == '.pfm':
132 | flow = readPFM(file_name).astype(np.float32)
133 | if len(flow.shape) == 2:
134 | return flow
135 | else:
136 | return flow[:, :, :-1]
137 | return []
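138 | 
139 | 
140 | if __name__ == '__main__':
141 |     # Illustrative round trip (not part of the original file): write a random
142 |     # flow field in Middlebury .flo format, then read it back.
143 |     import os, tempfile
144 |     uv = np.random.randn(32, 48, 2).astype(np.float32)
145 |     path = os.path.join(tempfile.mkdtemp(), 'test.flo')
146 |     writeFlow(path, uv)
147 |     assert np.allclose(readFlow(path), uv)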
--------------------------------------------------------------------------------
/core/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 | def __init__(self, dims, mode='sintel'):
10 | self.ht, self.wd = dims[-2:]
11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13 | if mode == 'sintel':
14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15 | else:
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17 |
18 | def pad(self, *inputs):
19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
20 |
21 | def unpad(self,x):
22 | ht, wd = x.shape[-2:]
23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
24 | return x[..., c[0]:c[1], c[2]:c[3]]
25 |
26 | def forward_interpolate(flow):
27 | flow = flow.detach().cpu().numpy()
28 | dx, dy = flow[0], flow[1]
29 |
30 | ht, wd = dx.shape
31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
32 |
33 | x1 = x0 + dx
34 | y1 = y0 + dy
35 |
36 | x1 = x1.reshape(-1)
37 | y1 = y1.reshape(-1)
38 | dx = dx.reshape(-1)
39 | dy = dy.reshape(-1)
40 |
41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
42 | x1 = x1[valid]
43 | y1 = y1[valid]
44 | dx = dx[valid]
45 | dy = dy[valid]
46 |
47 | flow_x = interpolate.griddata(
48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
49 |
50 | flow_y = interpolate.griddata(
51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
52 |
53 | flow = np.stack([flow_x, flow_y], axis=0)
54 | return torch.from_numpy(flow).float()
55 |
56 |
57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
58 | """ Wrapper for grid_sample, uses pixel coordinates """
59 | H, W = img.shape[-2:]
60 | xgrid, ygrid = coords.split([1,1], dim=-1)
61 | xgrid = 2*xgrid/(W-1) - 1
62 | ygrid = 2*ygrid/(H-1) - 1
63 |
64 | grid = torch.cat([xgrid, ygrid], dim=-1)
65 |     img = F.grid_sample(img, grid, mode=mode, align_corners=True)  # honor the mode argument
66 |
67 | if mask:
68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
69 | return img, mask.float()
70 |
71 | return img
72 |
73 |
74 | def coords_grid(batch, ht, wd, device):
75 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
76 | coords = torch.stack(coords[::-1], dim=0).float()
77 | return coords[None].repeat(batch, 1, 1, 1)
78 |
79 |
80 | def upflow8(flow, mode='bilinear'):
81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
83 |
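84 | 
85 | if __name__ == '__main__':
86 |     # Illustrative self-tests (not part of the original file).
87 |     # InputPadder: pad a Sintel-sized frame to multiples of 8, then unpad.
88 |     x = torch.randn(1, 3, 436, 1024)
89 |     padder = InputPadder(x.shape)
90 |     y, = padder.pad(x)
91 |     assert y.shape[-2] % 8 == 0 and y.shape[-1] % 8 == 0
92 |     assert padder.unpad(y).shape == x.shape
93 |     # bilinear_sampler: sampling at the identity pixel grid returns the input.
94 |     coords = coords_grid(1, 436, 1024, 'cpu').permute(0, 2, 3, 1)
95 |     assert torch.allclose(bilinear_sampler(x, coords), x, atol=1e-4)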
--------------------------------------------------------------------------------
/demo-frames/frame_0016.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/RAFT/3fa0bb0a9c633ea0a9bb8a79c576b6785d4e6a02/demo-frames/frame_0016.png
--------------------------------------------------------------------------------
/demo-frames/frame_0017.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/RAFT/3fa0bb0a9c633ea0a9bb8a79c576b6785d4e6a02/demo-frames/frame_0017.png
--------------------------------------------------------------------------------
/demo-frames/frame_0018.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/RAFT/3fa0bb0a9c633ea0a9bb8a79c576b6785d4e6a02/demo-frames/frame_0018.png
--------------------------------------------------------------------------------
/demo-frames/frame_0019.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/RAFT/3fa0bb0a9c633ea0a9bb8a79c576b6785d4e6a02/demo-frames/frame_0019.png
--------------------------------------------------------------------------------
/demo-frames/frame_0020.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/RAFT/3fa0bb0a9c633ea0a9bb8a79c576b6785d4e6a02/demo-frames/frame_0020.png
--------------------------------------------------------------------------------
/demo-frames/frame_0021.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/RAFT/3fa0bb0a9c633ea0a9bb8a79c576b6785d4e6a02/demo-frames/frame_0021.png
--------------------------------------------------------------------------------
/demo-frames/frame_0022.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/RAFT/3fa0bb0a9c633ea0a9bb8a79c576b6785d4e6a02/demo-frames/frame_0022.png
--------------------------------------------------------------------------------
/demo-frames/frame_0023.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/RAFT/3fa0bb0a9c633ea0a9bb8a79c576b6785d4e6a02/demo-frames/frame_0023.png
--------------------------------------------------------------------------------
/demo-frames/frame_0024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/RAFT/3fa0bb0a9c633ea0a9bb8a79c576b6785d4e6a02/demo-frames/frame_0024.png
--------------------------------------------------------------------------------
/demo-frames/frame_0025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/RAFT/3fa0bb0a9c633ea0a9bb8a79c576b6785d4e6a02/demo-frames/frame_0025.png
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('core')
3 |
4 | import argparse
5 | import os
6 | import cv2
7 | import glob
8 | import numpy as np
9 | import torch
10 | from PIL import Image
11 |
12 | from raft import RAFT
13 | from utils import flow_viz
14 | from utils.utils import InputPadder
15 |
16 |
17 |
18 | DEVICE = 'cuda'
19 |
20 | def load_image(imfile):
21 | img = np.array(Image.open(imfile)).astype(np.uint8)
22 | img = torch.from_numpy(img).permute(2, 0, 1).float()
23 | return img[None].to(DEVICE)
24 |
25 |
26 | def viz(img, flo):
27 | img = img[0].permute(1,2,0).cpu().numpy()
28 | flo = flo[0].permute(1,2,0).cpu().numpy()
29 |
30 | # map flow to rgb image
31 | flo = flow_viz.flow_to_image(flo)
32 | img_flo = np.concatenate([img, flo], axis=0)
33 |
34 | # import matplotlib.pyplot as plt
35 | # plt.imshow(img_flo / 255.0)
36 | # plt.show()
37 |
38 | cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
39 | cv2.waitKey()
40 |
41 |
42 | def demo(args):
43 | model = torch.nn.DataParallel(RAFT(args))
44 | model.load_state_dict(torch.load(args.model))
45 |
46 | model = model.module
47 | model.to(DEVICE)
48 | model.eval()
49 |
50 | with torch.no_grad():
51 | images = glob.glob(os.path.join(args.path, '*.png')) + \
52 | glob.glob(os.path.join(args.path, '*.jpg'))
53 |
54 | images = sorted(images)
55 | for imfile1, imfile2 in zip(images[:-1], images[1:]):
56 | image1 = load_image(imfile1)
57 | image2 = load_image(imfile2)
58 |
59 | padder = InputPadder(image1.shape)
60 | image1, image2 = padder.pad(image1, image2)
61 |
62 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
63 | viz(image1, flow_up)
64 |
65 |
66 | if __name__ == '__main__':
67 | parser = argparse.ArgumentParser()
68 | parser.add_argument('--model', help="restore checkpoint")
69 |     parser.add_argument('--path', help="path to a directory of frames")
70 | parser.add_argument('--small', action='store_true', help='use small model')
71 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
72 |     parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
73 | args = parser.parse_args()
74 |
75 | demo(args)
76 |
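77 | # Example invocation (per the repository README):
78 | #   python demo.py --model=models/raft-things.pth --path=demo-frames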
--------------------------------------------------------------------------------
/download_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget https://dl.dropboxusercontent.com/s/4j4z58wuv8o0mfz/models.zip
3 | unzip models.zip
4 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('core')
3 |
4 | from PIL import Image
5 | import argparse
6 | import os
7 | import time
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | import matplotlib.pyplot as plt
12 |
13 | import datasets
14 | from utils import flow_viz
15 | from utils import frame_utils
16 |
17 | from raft import RAFT
18 | from utils.utils import InputPadder, forward_interpolate
19 |
20 |
21 | @torch.no_grad()
22 | def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):
23 | """ Create submission for the Sintel leaderboard """
24 | model.eval()
25 | for dstype in ['clean', 'final']:
26 | test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)
27 |
28 | flow_prev, sequence_prev = None, None
29 | for test_id in range(len(test_dataset)):
30 | image1, image2, (sequence, frame) = test_dataset[test_id]
31 | if sequence != sequence_prev:
32 | flow_prev = None
33 |
34 | padder = InputPadder(image1.shape)
35 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
36 |
37 | flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)
38 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
39 |
40 | if warm_start:
41 | flow_prev = forward_interpolate(flow_low[0])[None].cuda()
42 |
43 | output_dir = os.path.join(output_path, dstype, sequence)
44 | output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))
45 |
46 | if not os.path.exists(output_dir):
47 | os.makedirs(output_dir)
48 |
49 | frame_utils.writeFlow(output_file, flow)
50 | sequence_prev = sequence
51 |
52 |
53 | @torch.no_grad()
54 | def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
55 |     """ Create submission for the KITTI leaderboard """
56 | model.eval()
57 | test_dataset = datasets.KITTI(split='testing', aug_params=None)
58 |
59 | if not os.path.exists(output_path):
60 | os.makedirs(output_path)
61 |
62 | for test_id in range(len(test_dataset)):
63 | image1, image2, (frame_id, ) = test_dataset[test_id]
64 | padder = InputPadder(image1.shape, mode='kitti')
65 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
66 |
67 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
68 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
69 |
70 | output_filename = os.path.join(output_path, frame_id)
71 | frame_utils.writeFlowKITTI(output_filename, flow)
72 |
73 |
74 | @torch.no_grad()
75 | def validate_chairs(model, iters=24):
76 | """ Perform evaluation on the FlyingChairs (test) split """
77 | model.eval()
78 | epe_list = []
79 |
80 | val_dataset = datasets.FlyingChairs(split='validation')
81 | for val_id in range(len(val_dataset)):
82 | image1, image2, flow_gt, _ = val_dataset[val_id]
83 | image1 = image1[None].cuda()
84 | image2 = image2[None].cuda()
85 |
86 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
87 | epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt()
88 | epe_list.append(epe.view(-1).numpy())
89 |
90 | epe = np.mean(np.concatenate(epe_list))
91 | print("Validation Chairs EPE: %f" % epe)
92 | return {'chairs': epe}
93 |
94 |
95 | @torch.no_grad()
96 | def validate_sintel(model, iters=32):
97 |     """ Perform validation using the Sintel (train) split """
98 | model.eval()
99 | results = {}
100 | for dstype in ['clean', 'final']:
101 | val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
102 | epe_list = []
103 |
104 | for val_id in range(len(val_dataset)):
105 | image1, image2, flow_gt, _ = val_dataset[val_id]
106 | image1 = image1[None].cuda()
107 | image2 = image2[None].cuda()
108 |
109 | padder = InputPadder(image1.shape)
110 | image1, image2 = padder.pad(image1, image2)
111 |
112 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
113 | flow = padder.unpad(flow_pr[0]).cpu()
114 |
115 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
116 | epe_list.append(epe.view(-1).numpy())
117 |
118 | epe_all = np.concatenate(epe_list)
119 | epe = np.mean(epe_all)
120 | px1 = np.mean(epe_all<1)
121 | px3 = np.mean(epe_all<3)
122 | px5 = np.mean(epe_all<5)
123 |
124 | print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
125 |         results[dstype] = epe  # epe_list holds ragged per-frame arrays; reuse the mean over epe_all
126 |
127 | return results
128 |
129 |
130 | @torch.no_grad()
131 | def validate_kitti(model, iters=24):
132 |     """ Perform validation using the KITTI-2015 (train) split """
133 | model.eval()
134 | val_dataset = datasets.KITTI(split='training')
135 |
136 | out_list, epe_list = [], []
137 | for val_id in range(len(val_dataset)):
138 | image1, image2, flow_gt, valid_gt = val_dataset[val_id]
139 | image1 = image1[None].cuda()
140 | image2 = image2[None].cuda()
141 |
142 | padder = InputPadder(image1.shape, mode='kitti')
143 | image1, image2 = padder.pad(image1, image2)
144 |
145 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
146 | flow = padder.unpad(flow_pr[0]).cpu()
147 |
148 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
149 | mag = torch.sum(flow_gt**2, dim=0).sqrt()
150 |
151 | epe = epe.view(-1)
152 | mag = mag.view(-1)
153 | val = valid_gt.view(-1) >= 0.5
154 |
155 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
156 | epe_list.append(epe[val].mean().item())
157 | out_list.append(out[val].cpu().numpy())
158 |
159 | epe_list = np.array(epe_list)
160 | out_list = np.concatenate(out_list)
161 |
162 | epe = np.mean(epe_list)
163 | f1 = 100 * np.mean(out_list)
164 |
165 |     print("Validation KITTI: EPE %f, F1-all %f" % (epe, f1))
166 | return {'kitti-epe': epe, 'kitti-f1': f1}
167 |
168 |
169 | if __name__ == '__main__':
170 | parser = argparse.ArgumentParser()
171 | parser.add_argument('--model', help="restore checkpoint")
172 | parser.add_argument('--dataset', help="dataset for evaluation")
173 | parser.add_argument('--small', action='store_true', help='use small model')
174 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
175 |     parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
176 | args = parser.parse_args()
177 |
178 | model = torch.nn.DataParallel(RAFT(args))
179 | model.load_state_dict(torch.load(args.model))
180 |
181 | model.cuda()
182 | model.eval()
183 |
184 | # create_sintel_submission(model.module, warm_start=True)
185 | # create_kitti_submission(model.module)
186 |
187 | with torch.no_grad():
188 | if args.dataset == 'chairs':
189 | validate_chairs(model.module)
190 |
191 | elif args.dataset == 'sintel':
192 | validate_sintel(model.module)
193 |
194 | elif args.dataset == 'kitti':
195 | validate_kitti(model.module)
196 |
197 |
198 |
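199 | # Example invocation (per the repository README):
200 | #   python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision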
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import sys
3 | sys.path.append('core')
4 |
5 | import argparse
6 | import os
7 | import cv2
8 | import time
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.optim as optim
15 | import torch.nn.functional as F
16 |
17 | from torch.utils.data import DataLoader
18 | from raft import RAFT
19 | import evaluate
20 | import datasets
21 |
22 | from torch.utils.tensorboard import SummaryWriter
23 |
24 | try:
25 | from torch.cuda.amp import GradScaler
26 | except ImportError:
27 | # dummy GradScaler for PyTorch < 1.6
28 | class GradScaler:
29 |         def __init__(self, enabled=False):  # match the torch.cuda.amp.GradScaler signature used below
30 | pass
31 | def scale(self, loss):
32 | return loss
33 | def unscale_(self, optimizer):
34 | pass
35 | def step(self, optimizer):
36 | optimizer.step()
37 | def update(self):
38 | pass
39 |
40 |
41 | # exclude extremely large displacements
42 | MAX_FLOW = 400
43 | SUM_FREQ = 100
44 | VAL_FREQ = 5000
45 |
46 |
47 | def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
48 | """ Loss function defined over sequence of flow predictions """
49 |
50 | n_predictions = len(flow_preds)
51 | flow_loss = 0.0
52 |
53 |     # exclude invalid pixels and extremely large displacements
54 | mag = torch.sum(flow_gt**2, dim=1).sqrt()
55 | valid = (valid >= 0.5) & (mag < max_flow)
56 |
57 | for i in range(n_predictions):
58 | i_weight = gamma**(n_predictions - i - 1)
59 | i_loss = (flow_preds[i] - flow_gt).abs()
60 | flow_loss += i_weight * (valid[:, None] * i_loss).mean()
61 |
62 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
63 | epe = epe.view(-1)[valid.view(-1)]
64 |
65 | metrics = {
66 | 'epe': epe.mean().item(),
67 | '1px': (epe < 1).float().mean().item(),
68 | '3px': (epe < 3).float().mean().item(),
69 | '5px': (epe < 5).float().mean().item(),
70 | }
71 |
72 | return flow_loss, metrics
73 |
74 |
75 | def count_parameters(model):
76 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
77 |
78 |
79 | def fetch_optimizer(args, model):
80 | """ Create the optimizer and learning rate scheduler """
81 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
82 |
83 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
84 | pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')
85 |
86 | return optimizer, scheduler
87 |
88 |
89 | class Logger:
90 | def __init__(self, model, scheduler):
91 | self.model = model
92 | self.scheduler = scheduler
93 | self.total_steps = 0
94 | self.running_loss = {}
95 | self.writer = None
96 |
97 | def _print_training_status(self):
98 | metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())]
99 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
100 | metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)
101 |
102 | # print the training status
103 | print(training_str + metrics_str)
104 |
105 | if self.writer is None:
106 | self.writer = SummaryWriter()
107 |
108 | for k in self.running_loss:
109 | self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps)
110 | self.running_loss[k] = 0.0
111 |
112 | def push(self, metrics):
113 | self.total_steps += 1
114 |
115 | for key in metrics:
116 | if key not in self.running_loss:
117 | self.running_loss[key] = 0.0
118 |
119 | self.running_loss[key] += metrics[key]
120 |
121 | if self.total_steps % SUM_FREQ == SUM_FREQ-1:
122 | self._print_training_status()
123 | self.running_loss = {}
124 |
125 | def write_dict(self, results):
126 | if self.writer is None:
127 | self.writer = SummaryWriter()
128 |
129 | for key in results:
130 | self.writer.add_scalar(key, results[key], self.total_steps)
131 |
132 | def close(self):
133 | self.writer.close()
134 |
135 |
136 | def train(args):
137 |
138 | model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
139 | print("Parameter Count: %d" % count_parameters(model))
140 |
141 | if args.restore_ckpt is not None:
142 | model.load_state_dict(torch.load(args.restore_ckpt), strict=False)
143 |
144 | model.cuda()
145 | model.train()
146 |
147 | if args.stage != 'chairs':
148 | model.module.freeze_bn()
149 |
150 | train_loader = datasets.fetch_dataloader(args)
151 | optimizer, scheduler = fetch_optimizer(args, model)
152 |
153 | total_steps = 0
154 | scaler = GradScaler(enabled=args.mixed_precision)
155 | logger = Logger(model, scheduler)
156 |
157 |     VAL_FREQ = 5000   # validation/checkpoint interval (shadows the module-level constant)
158 |     add_noise = True  # unused; noise injection is controlled by args.add_noise
159 |
160 | should_keep_training = True
161 | while should_keep_training:
162 |
163 | for i_batch, data_blob in enumerate(train_loader):
164 | optimizer.zero_grad()
165 | image1, image2, flow, valid = [x.cuda() for x in data_blob]
166 |
167 | if args.add_noise:
168 | stdv = np.random.uniform(0.0, 5.0)
169 | image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
170 | image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)
171 |
172 | flow_predictions = model(image1, image2, iters=args.iters)
173 |
174 | loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
175 | scaler.scale(loss).backward()
176 | scaler.unscale_(optimizer)
177 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
178 |
179 | scaler.step(optimizer)
180 | scheduler.step()
181 | scaler.update()
182 |
183 | logger.push(metrics)
184 |
185 | if total_steps % VAL_FREQ == VAL_FREQ - 1:
186 | PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
187 | torch.save(model.state_dict(), PATH)
188 |
189 | results = {}
190 | for val_dataset in args.validation:
191 | if val_dataset == 'chairs':
192 | results.update(evaluate.validate_chairs(model.module))
193 | elif val_dataset == 'sintel':
194 | results.update(evaluate.validate_sintel(model.module))
195 | elif val_dataset == 'kitti':
196 | results.update(evaluate.validate_kitti(model.module))
197 |
198 | logger.write_dict(results)
199 |
200 | model.train()
201 | if args.stage != 'chairs':
202 | model.module.freeze_bn()
203 |
204 | total_steps += 1
205 |
206 | if total_steps > args.num_steps:
207 | should_keep_training = False
208 | break
209 |
210 | logger.close()
211 | PATH = 'checkpoints/%s.pth' % args.name
212 | torch.save(model.state_dict(), PATH)
213 |
214 | return PATH
215 |
216 |
217 | if __name__ == '__main__':
218 | parser = argparse.ArgumentParser()
219 | parser.add_argument('--name', default='raft', help="name your experiment")
220 | parser.add_argument('--stage', help="determines which dataset to use for training")
221 | parser.add_argument('--restore_ckpt', help="restore checkpoint")
222 | parser.add_argument('--small', action='store_true', help='use small model')
223 | parser.add_argument('--validation', type=str, nargs='+')
224 |
225 | parser.add_argument('--lr', type=float, default=0.00002)
226 | parser.add_argument('--num_steps', type=int, default=100000)
227 | parser.add_argument('--batch_size', type=int, default=6)
228 | parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
229 | parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
230 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
231 |
232 | parser.add_argument('--iters', type=int, default=12)
233 | parser.add_argument('--wdecay', type=float, default=.00005)
234 | parser.add_argument('--epsilon', type=float, default=1e-8)
235 | parser.add_argument('--clip', type=float, default=1.0)
236 | parser.add_argument('--dropout', type=float, default=0.0)
237 | parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
238 | parser.add_argument('--add_noise', action='store_true')
239 | args = parser.parse_args()
240 |
241 | torch.manual_seed(1234)
242 | np.random.seed(1234)
243 |
244 | if not os.path.isdir('checkpoints'):
245 | os.mkdir('checkpoints')
246 |
247 | train(args)
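248 | 
249 | # Illustrative note (not part of the original file): with three flow
250 | # predictions and gamma = 0.8, sequence_loss weights them by
251 | # 0.8**2 = 0.64, 0.8**1 = 0.8, and 0.8**0 = 1.0, so later refinements
252 | # contribute most to the loss.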
--------------------------------------------------------------------------------
/train_mixed.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir -p checkpoints
3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision
4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision
5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision
6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision
7 |
--------------------------------------------------------------------------------
/train_standard.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir -p checkpoints
3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85
6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85
7 |
--------------------------------------------------------------------------------