├── README.md
├── alt_cuda_corr
│   ├── correlation.cpp
│   ├── correlation_kernel.cu
│   └── setup.py
├── chairs_split.txt
├── core
│   ├── __init__.py
│   ├── corr.py
│   ├── datasets.py
│   ├── extractor.py
│   ├── fd_corr.py
│   ├── fd_decoder.py
│   ├── fd_encoder.py
│   ├── flowdiffuser.py
│   ├── module.py
│   ├── raft.py
│   ├── update.py
│   └── utils
│       ├── __init__.py
│       ├── augmentor.py
│       ├── flow_viz.py
│       ├── frame_utils.py
│       └── utils.py
├── eval.sh
├── evaluate.py
├── train.py
└── train.sh
/README.md:
--------------------------------------------------------------------------------
1 | # [CVPR 2024] FlowDiffuser: Advancing Optical Flow Estimation with Diffusion Models
2 |
3 |
Ao Luo1,2, Xin Li3, Fan Yang3, Jiangyu Liu2, Haoqiang Fan2, and Shuaicheng Liu4,2
4 | 1. Southwest Jiaotong University  2. Megvii Research  3. Group 42
5 | 4. University of Electronic Science and Technology of China
6 |
7 | This project provides the official implementation of '[**FlowDiffuser: Advancing Optical Flow Estimation with Diffusion Models**](https://openaccess.thecvf.com/content/CVPR2024/papers/Luo_FlowDiffuser_Advancing_Optical_Flow_Estimation_with_Diffusion_Models_CVPR_2024_paper.pdf)'.
8 |
9 | ## Abstract
10 | Optical flow estimation, a process of predicting pixel-wise displacement between consecutive frames, has commonly been approached as a regression task in the age of deep learning. Despite notable advancements, this de facto paradigm unfortunately falls short in generalization performance when trained on synthetic or constrained data. Pioneering a paradigm shift, we reformulate optical flow estimation as a conditional flow generation challenge, unveiling FlowDiffuser — a new family of optical flow models that could have stronger learning and generalization capabilities. FlowDiffuser estimates optical flow through a ‘noise-to-flow’ strategy, progressively eliminating noise from randomly generated flows conditioned on the provided pairs. To optimize accuracy and efficiency, our FlowDiffuser incorporates a novel Conditional Recurrent Denoising Decoder (Conditional-RDD), streamlining the flow estimation process. It incorporates a unique Hidden State Denoising (HSD) paradigm, effectively leveraging the information from previous time steps. Moreover, FlowDiffuser can be easily integrated into existing flow networks, leading to significant improvements in performance metrics compared to conventional implementations. Experiments on challenging benchmarks, including Sintel and KITTI, demonstrate the effectiveness of our FlowDiffuser with superior performance to existing state-of-the-art models.
11 |
12 |
13 | ## Overview
14 |
15 | 
16 |
17 | 
18 |
19 |
20 | ## Requirements
21 |
22 | Python 3.8 with the following packages:
23 | ```Shell
24 | pytorch 1.9.0
25 | torchvision 0.10.0
26 | numpy 1.19.5
27 | opencv-python 4.6.0.66
28 | timm 0.6.12
29 | scipy 1.5.4
30 | matplotlib 3.3.4
31 | ```
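
These can typically be installed with pip; the command below is a sketch under the assumption of a CUDA-enabled PyTorch build, not a command taken from this repository:
```Shell
pip install torch==1.9.0 torchvision==0.10.0 numpy==1.19.5 opencv-python==4.6.0.66 timm==0.6.12 scipy==1.5.4 matplotlib==3.3.4
```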
32 |
33 |
34 | ## Usage
35 |
36 | 1. Download the [Sintel](http://sintel.is.tue.mpg.de/) and [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) datasets, and set the root path of each dataset class in `./core/datasets.py`.
37 |
38 | 2. Put the `*.pth` checkpoint file ([GoogleDrive](https://drive.google.com/file/d/1msAB8-ibMCTUEQbT6yjV1y13D_-6LTSX/view?usp=sharing)) into the folder `./weights`.
39 |
40 | 3. Evaluate on Sintel and KITTI:
41 | ```Shell
42 | ./eval.sh
43 | ```
44 |
45 |
46 | ## Q & A
47 |
48 | Due to a recent job change, I am busy with other matters. If you have any questions, please email me at aoluo@swjtu.edu.cn, and I will respond as soon as I can.
49 |
50 |
51 | ## Citation
52 |
53 | If you find this work helpful, please cite:
54 | ```
55 | @inproceedings{luo2024flowdiffuser,
56 | title={FlowDiffuser: Advancing Optical Flow Estimation with Diffusion Models},
57 | author={Luo, Ao and Li, Xin and Yang, Fan and Liu, Jiangyu and Fan, Haoqiang and Liu, Shuaicheng},
58 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
59 | pages={19167--19176},
60 | year={2024}
61 | }
62 | ```
63 |
64 | ## Acknowledgement
65 |
66 | The code is built upon [RAFT](https://github.com/princeton-vl/RAFT), [SKFlow](https://github.com/littlespray/SKFlow), [DiffusionDet](https://github.com/ShoufaChen/DiffusionDet), and [EMD-Flow](https://github.com/gddcx/EMD-Flow). We thank the authors for their contributions.
67 |
--------------------------------------------------------------------------------
/alt_cuda_corr/correlation.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 | // CUDA forward declarations
5 | std::vector<torch::Tensor> corr_cuda_forward(
6 | torch::Tensor fmap1,
7 | torch::Tensor fmap2,
8 | torch::Tensor coords,
9 | int radius);
10 |
11 | std::vector<torch::Tensor> corr_cuda_backward(
12 | torch::Tensor fmap1,
13 | torch::Tensor fmap2,
14 | torch::Tensor coords,
15 | torch::Tensor corr_grad,
16 | int radius);
17 |
18 | // C++ interface
19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
22 |
23 | std::vector<torch::Tensor> corr_forward(
24 | torch::Tensor fmap1,
25 | torch::Tensor fmap2,
26 | torch::Tensor coords,
27 | int radius) {
28 | CHECK_INPUT(fmap1);
29 | CHECK_INPUT(fmap2);
30 | CHECK_INPUT(coords);
31 |
32 | return corr_cuda_forward(fmap1, fmap2, coords, radius);
33 | }
34 |
35 |
36 | std::vector<torch::Tensor> corr_backward(
37 | torch::Tensor fmap1,
38 | torch::Tensor fmap2,
39 | torch::Tensor coords,
40 | torch::Tensor corr_grad,
41 | int radius) {
42 | CHECK_INPUT(fmap1);
43 | CHECK_INPUT(fmap2);
44 | CHECK_INPUT(coords);
45 | CHECK_INPUT(corr_grad);
46 |
47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
48 | }
49 |
50 |
51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
52 | m.def("forward", &corr_forward, "CORR forward");
53 | m.def("backward", &corr_backward, "CORR backward");
54 | }
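
// Python-side usage sketch (see core/corr.py, which calls this optional extension
// once it has been built). The shapes are inferred from the calling code and the
// kernels rather than from documented API, so treat them as assumptions:
//   corr, = alt_cuda_corr.forward(fmap1, fmap2, coords, radius)
// with fmap1/fmap2 contiguous CUDA tensors of shape (B, H, W, C) and coords of
// shape (B, N, H, W, 2).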
--------------------------------------------------------------------------------
/alt_cuda_corr/correlation_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <cuda.h>
3 | #include <cuda_runtime.h>
4 | #include <vector>
5 |
6 |
7 | #define BLOCK_H 4
8 | #define BLOCK_W 8
9 | #define BLOCK_HW BLOCK_H * BLOCK_W
10 | #define CHANNEL_STRIDE 32
11 |
12 |
13 | __forceinline__ __device__
14 | bool within_bounds(int h, int w, int H, int W) {
15 | return h >= 0 && h < H && w >= 0 && w < W;
16 | }
17 |
18 | template <typename scalar_t>
19 | __global__ void corr_forward_kernel(
20 |     const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
21 |     const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
22 |     const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
23 |     torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
24 | int r)
25 | {
26 | const int b = blockIdx.x;
27 | const int h0 = blockIdx.y * blockDim.x;
28 | const int w0 = blockIdx.z * blockDim.y;
29 | const int tid = threadIdx.x * blockDim.y + threadIdx.y;
30 |
31 | const int H1 = fmap1.size(1);
32 | const int W1 = fmap1.size(2);
33 | const int H2 = fmap2.size(1);
34 | const int W2 = fmap2.size(2);
35 | const int N = coords.size(1);
36 | const int C = fmap1.size(3);
37 |
38 | __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
39 | __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
40 | __shared__ scalar_t x2s[BLOCK_HW];
41 | __shared__ scalar_t y2s[BLOCK_HW];
42 |
43 | for (int c=0; c(floor(y2s[k1]))-r+iy;
76 |         int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
77 | int c2 = tid % CHANNEL_STRIDE;
78 |
79 | auto fptr = fmap2[b][h2][w2];
80 | if (within_bounds(h2, w2, H2, W2))
81 | f2[c2][k1] = fptr[c+c2];
82 | else
83 | f2[c2][k1] = 0.0;
84 | }
85 |
86 | __syncthreads();
87 |
88 | scalar_t s = 0.0;
89 | for (int k=0; k 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
105 | *(corr_ptr + ix_nw) += nw;
106 |
107 | if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
108 | *(corr_ptr + ix_ne) += ne;
109 |
110 | if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
111 | *(corr_ptr + ix_sw) += sw;
112 |
113 | if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
114 | *(corr_ptr + ix_se) += se;
115 | }
116 | }
117 | }
118 | }
119 | }
120 |
121 |
122 | template <typename scalar_t>
123 | __global__ void corr_backward_kernel(
124 |     const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
125 |     const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
126 |     const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
127 |     const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
128 |     torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
129 |     torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
130 |     torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
131 | int r)
132 | {
133 |
134 | const int b = blockIdx.x;
135 | const int h0 = blockIdx.y * blockDim.x;
136 | const int w0 = blockIdx.z * blockDim.y;
137 | const int tid = threadIdx.x * blockDim.y + threadIdx.y;
138 |
139 | const int H1 = fmap1.size(1);
140 | const int W1 = fmap1.size(2);
141 | const int H2 = fmap2.size(1);
142 | const int W2 = fmap2.size(2);
143 | const int N = coords.size(1);
144 | const int C = fmap1.size(3);
145 |
146 | __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
147 | __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
148 |
149 | __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
150 | __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];
151 |
152 | __shared__ scalar_t x2s[BLOCK_HW];
153 | __shared__ scalar_t y2s[BLOCK_HW];
154 |
155 | for (int c=0; c(floor(y2s[k1]))-r+iy;
190 |         int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
191 | int c2 = tid % CHANNEL_STRIDE;
192 |
193 | auto fptr = fmap2[b][h2][w2];
194 | if (within_bounds(h2, w2, H2, W2))
195 | f2[c2][k1] = fptr[c+c2];
196 | else
197 | f2[c2][k1] = 0.0;
198 |
199 | f2_grad[c2][k1] = 0.0;
200 | }
201 |
202 | __syncthreads();
203 |
204 | const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
205 | scalar_t g = 0.0;
206 |
207 | int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
208 | int ix_ne = H1*W1*((iy-1) + rd*ix);
209 | int ix_sw = H1*W1*(iy + rd*(ix-1));
210 | int ix_se = H1*W1*(iy + rd*ix);
211 |
212 | if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
213 | g += *(grad_ptr + ix_nw) * dy * dx;
214 |
215 | if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
216 | g += *(grad_ptr + ix_ne) * dy * (1-dx);
217 |
218 | if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
219 | g += *(grad_ptr + ix_sw) * (1-dy) * dx;
220 |
221 | if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
222 | g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);
223 |
224 | for (int k=0; k(floor(y2s[k1]))-r+iy;
232 |         int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
233 | int c2 = tid % CHANNEL_STRIDE;
234 |
235 | scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
236 | if (within_bounds(h2, w2, H2, W2))
237 | atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
238 | }
239 | }
240 | }
241 | }
242 | __syncthreads();
243 |
244 |
245 | for (int k=0; k corr_cuda_forward(
261 | torch::Tensor fmap1,
262 | torch::Tensor fmap2,
263 | torch::Tensor coords,
264 | int radius)
265 | {
266 | const auto B = coords.size(0);
267 | const auto N = coords.size(1);
268 | const auto H = coords.size(2);
269 | const auto W = coords.size(3);
270 |
271 | const auto rd = 2 * radius + 1;
272 | auto opts = fmap1.options();
273 | auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);
274 |
275 | const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);
276 | const dim3 threads(BLOCK_H, BLOCK_W);
277 |
278 |   corr_forward_kernel<float><<<blocks, threads>>>(
279 |     fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
280 |     fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
281 |     coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
282 |     corr.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
283 | radius);
284 |
285 | return {corr};
286 | }
287 |
288 | std::vector<torch::Tensor> corr_cuda_backward(
289 | torch::Tensor fmap1,
290 | torch::Tensor fmap2,
291 | torch::Tensor coords,
292 | torch::Tensor corr_grad,
293 | int radius)
294 | {
295 | const auto B = coords.size(0);
296 | const auto N = coords.size(1);
297 |
298 | const auto H1 = fmap1.size(1);
299 | const auto W1 = fmap1.size(2);
300 | const auto H2 = fmap2.size(1);
301 | const auto W2 = fmap2.size(2);
302 | const auto C = fmap1.size(3);
303 |
304 | auto opts = fmap1.options();
305 | auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
306 | auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
307 | auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);
308 |
309 | const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
310 | const dim3 threads(BLOCK_H, BLOCK_W);
311 |
312 |
313 |   corr_backward_kernel<float><<<blocks, threads>>>(
314 |     fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
315 |     fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
316 |     coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
317 |     corr_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
318 |     fmap1_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
319 |     fmap2_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
320 |     coords_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
321 | radius);
322 |
323 | return {fmap1_grad, fmap2_grad, coords_grad};
324 | }
--------------------------------------------------------------------------------
/alt_cuda_corr/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 |
5 | setup(
6 | name='correlation',
7 | ext_modules=[
8 | CUDAExtension('alt_cuda_corr',
9 | sources=['correlation.cpp', 'correlation_kernel.cu'],
10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
11 | ],
12 | cmdclass={
13 | 'build_ext': BuildExtension
14 | })
15 |
16 |
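# Build sketch (an assumption based on standard torch-extension practice, not a
# documented step in this repository): run the following from the alt_cuda_corr/
# directory so that `import alt_cuda_corr` becomes available to core/corr.py:
#   python setup.py install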
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LA30/FlowDiffuser/9aff9c6e8c68f809e40bb0ae4273621276686168/core/__init__.py
--------------------------------------------------------------------------------
/core/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from utils.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 | import alt_cuda_corr
7 | except ImportError:
8 | # alt_cuda_corr is not compiled
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels-1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
38 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
40 |
41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 |
45 | corr = bilinear_sampler(corr, coords_lvl)
46 | corr = corr.view(batch, h1, w1, -1)
47 | out_pyramid.append(corr)
48 |
49 | out = torch.cat(out_pyramid, dim=-1)
50 | return out.permute(0, 3, 1, 2).contiguous().float()
51 |
52 | @staticmethod
53 | def corr(fmap1, fmap2):
54 | batch, dim, ht, wd = fmap1.shape
55 | fmap1 = fmap1.view(batch, dim, ht*wd)
56 | fmap2 = fmap2.view(batch, dim, ht*wd)
57 |
58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59 | corr = corr.view(batch, ht, wd, 1, ht, wd)
60 | return corr / torch.sqrt(torch.tensor(dim).float())
61 |
62 |
63 | class AlternateCorrBlock:
64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
65 | self.num_levels = num_levels
66 | self.radius = radius
67 |
68 | self.pyramid = [(fmap1, fmap2)]
69 | for i in range(self.num_levels):
70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
72 | self.pyramid.append((fmap1, fmap2))
73 |
74 | def __call__(self, coords):
75 | coords = coords.permute(0, 2, 3, 1)
76 | B, H, W, _ = coords.shape
77 | dim = self.pyramid[0][0].shape[1]
78 |
79 | corr_list = []
80 | for i in range(self.num_levels):
81 | r = self.radius
82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
84 |
85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
87 | corr_list.append(corr.squeeze(1))
88 |
89 | corr = torch.stack(corr_list, dim=1)
90 | corr = corr.reshape(B, -1, H, W)
91 | return corr / torch.sqrt(torch.tensor(dim).float())
92 |
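# Usage sketch (shapes follow the code above; the 256-dim features and spatial size
# are illustrative assumptions):
#   fmap1 = torch.randn(1, 256, 46, 62).cuda()
#   fmap2 = torch.randn(1, 256, 46, 62).cuda()
#   corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)
#   coords = torch.zeros(1, 2, 46, 62).cuda()   # per-pixel lookup centers, (B, 2, H, W)
#   cost = corr_fn(coords)                       # (B, num_levels*(2*radius+1)**2, H, W)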
--------------------------------------------------------------------------------
/core/datasets.py:
--------------------------------------------------------------------------------
1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2 |
3 | import numpy as np
4 | import torch
5 | import torch.utils.data as data
6 | import torch.nn.functional as F
7 |
8 | import os
9 | import math
10 | import random
11 | from glob import glob
12 | import os.path as osp
13 |
14 | from utils import frame_utils
15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
16 |
17 |
18 | class FlowDataset(data.Dataset):
19 | def __init__(self, aug_params=None, sparse=False):
20 | self.augmentor = None
21 | self.sparse = sparse
22 | if aug_params is not None:
23 | if sparse:
24 | self.augmentor = SparseFlowAugmentor(**aug_params)
25 | else:
26 | self.augmentor = FlowAugmentor(**aug_params)
27 |
28 | self.is_test = False
29 | self.init_seed = False
30 | self.flow_list = []
31 | self.image_list = []
32 | self.extra_info = []
33 |
34 | def __getitem__(self, index):
35 |
36 | if self.is_test:
37 | img1 = frame_utils.read_gen(self.image_list[index][0])
38 | img2 = frame_utils.read_gen(self.image_list[index][1])
39 | img1 = np.array(img1).astype(np.uint8)[..., :3]
40 | img2 = np.array(img2).astype(np.uint8)[..., :3]
41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
43 | return img1, img2, self.extra_info[index]
44 |
45 | if not self.init_seed:
46 | worker_info = torch.utils.data.get_worker_info()
47 | if worker_info is not None:
48 | torch.manual_seed(worker_info.id)
49 | np.random.seed(worker_info.id)
50 | random.seed(worker_info.id)
51 | self.init_seed = True
52 |
53 | index = index % len(self.image_list)
54 | valid = None
55 | if self.sparse:
56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
57 | else:
58 | flow = frame_utils.read_gen(self.flow_list[index])
59 |
60 | img1 = frame_utils.read_gen(self.image_list[index][0])
61 | img2 = frame_utils.read_gen(self.image_list[index][1])
62 |
63 | flow = np.array(flow).astype(np.float32)
64 | img1 = np.array(img1).astype(np.uint8)
65 | img2 = np.array(img2).astype(np.uint8)
66 |
67 | # grayscale images
68 | if len(img1.shape) == 2:
69 | img1 = np.tile(img1[...,None], (1, 1, 3))
70 | img2 = np.tile(img2[...,None], (1, 1, 3))
71 | else:
72 | img1 = img1[..., :3]
73 | img2 = img2[..., :3]
74 |
75 | if self.augmentor is not None:
76 | if self.sparse:
77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
78 | else:
79 | img1, img2, flow = self.augmentor(img1, img2, flow)
80 |
81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float()
84 |
85 | if valid is not None:
86 | valid = torch.from_numpy(valid)
87 | else:
88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
89 |
90 | return img1, img2, flow, valid.float()
91 |
92 |
93 | def __rmul__(self, v):
94 | self.flow_list = v * self.flow_list
95 | self.image_list = v * self.image_list
96 | return self
97 |
98 | def __len__(self):
99 | return len(self.image_list)
100 |
101 |
102 | class MpiSintel(FlowDataset):
103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
104 | super(MpiSintel, self).__init__(aug_params)
105 | flow_root = osp.join(root, split, 'flow')
106 | image_root = osp.join(root, split, dstype)
107 |
108 | if split == 'test':
109 | self.is_test = True
110 |
111 | for scene in os.listdir(image_root):
112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
113 | for i in range(len(image_list)-1):
114 | self.image_list += [ [image_list[i], image_list[i+1]] ]
115 | self.extra_info += [ (scene, i) ] # scene and frame_id
116 |
117 | if split != 'test':
118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
119 |
120 |
121 | class FlyingChairs(FlowDataset):
122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
123 | super(FlyingChairs, self).__init__(aug_params)
124 |
125 | images = sorted(glob(osp.join(root, '*.ppm')))
126 | flows = sorted(glob(osp.join(root, '*.flo')))
127 | assert (len(images)//2 == len(flows))
128 |
129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
130 | for i in range(len(flows)):
131 | xid = split_list[i]
132 | if (split=='training' and xid==1) or (split=='validation' and xid==2):
133 | self.flow_list += [ flows[i] ]
134 | self.image_list += [ [images[2*i], images[2*i+1]] ]
135 |
136 |
137 | class FlyingThings3D(FlowDataset):
138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
139 | super(FlyingThings3D, self).__init__(aug_params)
140 |
141 | for cam in ['left']:
142 | for direction in ['into_future', 'into_past']:
143 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
144 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
145 |
146 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
147 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
148 |
149 | for idir, fdir in zip(image_dirs, flow_dirs):
150 | images = sorted(glob(osp.join(idir, '*.png')) )
151 | flows = sorted(glob(osp.join(fdir, '*.pfm')) )
152 | for i in range(len(flows)-1):
153 | if direction == 'into_future':
154 | self.image_list += [ [images[i], images[i+1]] ]
155 | self.flow_list += [ flows[i] ]
156 | elif direction == 'into_past':
157 | self.image_list += [ [images[i+1], images[i]] ]
158 | self.flow_list += [ flows[i+1] ]
159 |
160 |
161 | class KITTI(FlowDataset):
162 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
163 | super(KITTI, self).__init__(aug_params, sparse=True)
164 | if split == 'testing':
165 | self.is_test = True
166 |
167 | root = osp.join(root, split)
168 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
169 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
170 |
171 | for img1, img2 in zip(images1, images2):
172 | frame_id = img1.split('/')[-1]
173 | self.extra_info += [ [frame_id] ]
174 | self.image_list += [ [img1, img2] ]
175 |
176 | if split == 'training':
177 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
178 |
179 |
180 | class HD1K(FlowDataset):
181 | def __init__(self, aug_params=None, root='datasets/HD1k'):
182 | super(HD1K, self).__init__(aug_params, sparse=True)
183 |
184 | seq_ix = 0
185 | while 1:
186 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
187 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
188 |
189 | if len(flows) == 0:
190 | break
191 |
192 | for i in range(len(flows)-1):
193 | self.flow_list += [flows[i]]
194 | self.image_list += [ [images[i], images[i+1]] ]
195 |
196 | seq_ix += 1
197 |
198 |
199 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
200 | """ Create the data loader for the corresponding trainign set """
201 |
202 | if args.stage == 'chairs':
203 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
204 | train_dataset = FlyingChairs(aug_params, split='training')
205 |
206 | elif args.stage == 'things':
207 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
208 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
209 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
210 | train_dataset = clean_dataset + final_dataset
211 |
212 | elif args.stage == 'sintel':
213 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
214 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
215 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
216 | sintel_final = MpiSintel(aug_params, split='training', dstype='final')
217 |
218 | if TRAIN_DS == 'C+T+K+S+H':
219 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
220 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
221 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
222 |
223 | elif TRAIN_DS == 'C+T+K/S':
224 | train_dataset = 100*sintel_clean + 100*sintel_final + things
225 |
226 | elif args.stage == 'kitti':
227 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
228 | train_dataset = KITTI(aug_params, split='training')
229 |
230 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
231 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
232 |
233 | print('Training with %d image pairs' % len(train_dataset))
234 | return train_loader
235 |
236 |
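# Usage sketch: fetch_dataloader only reads args.stage, args.image_size and
# args.batch_size; the minimal Namespace below is an assumption for illustration
# (the real training script may pass more fields):
#   from argparse import Namespace
#   args = Namespace(stage='chairs', image_size=[368, 496], batch_size=6)
#   train_loader = fetch_dataloader(args)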
--------------------------------------------------------------------------------
/core/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8 | super(ResidualBlock, self).__init__()
9 |
10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12 | self.relu = nn.ReLU(inplace=True)
13 |
14 | num_groups = planes // 8
15 |
16 | if norm_fn == 'group':
17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | if not stride == 1:
20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21 |
22 | elif norm_fn == 'batch':
23 | self.norm1 = nn.BatchNorm2d(planes)
24 | self.norm2 = nn.BatchNorm2d(planes)
25 | if not stride == 1:
26 | self.norm3 = nn.BatchNorm2d(planes)
27 |
28 | elif norm_fn == 'instance':
29 | self.norm1 = nn.InstanceNorm2d(planes)
30 | self.norm2 = nn.InstanceNorm2d(planes)
31 | if not stride == 1:
32 | self.norm3 = nn.InstanceNorm2d(planes)
33 |
34 | elif norm_fn == 'none':
35 | self.norm1 = nn.Sequential()
36 | self.norm2 = nn.Sequential()
37 | if not stride == 1:
38 | self.norm3 = nn.Sequential()
39 |
40 | if stride == 1:
41 | self.downsample = None
42 |
43 | else:
44 | self.downsample = nn.Sequential(
45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46 |
47 |
48 | def forward(self, x):
49 | y = x
50 | y = self.relu(self.norm1(self.conv1(y)))
51 | y = self.relu(self.norm2(self.conv2(y)))
52 |
53 | if self.downsample is not None:
54 | x = self.downsample(x)
55 |
56 | return self.relu(x+y)
57 |
58 |
59 |
60 | class BottleneckBlock(nn.Module):
61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
62 | super(BottleneckBlock, self).__init__()
63 |
64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
67 | self.relu = nn.ReLU(inplace=True)
68 |
69 | num_groups = planes // 8
70 |
71 | if norm_fn == 'group':
72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
75 | if not stride == 1:
76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
77 |
78 | elif norm_fn == 'batch':
79 | self.norm1 = nn.BatchNorm2d(planes//4)
80 | self.norm2 = nn.BatchNorm2d(planes//4)
81 | self.norm3 = nn.BatchNorm2d(planes)
82 | if not stride == 1:
83 | self.norm4 = nn.BatchNorm2d(planes)
84 |
85 | elif norm_fn == 'instance':
86 | self.norm1 = nn.InstanceNorm2d(planes//4)
87 | self.norm2 = nn.InstanceNorm2d(planes//4)
88 | self.norm3 = nn.InstanceNorm2d(planes)
89 | if not stride == 1:
90 | self.norm4 = nn.InstanceNorm2d(planes)
91 |
92 | elif norm_fn == 'none':
93 | self.norm1 = nn.Sequential()
94 | self.norm2 = nn.Sequential()
95 | self.norm3 = nn.Sequential()
96 | if not stride == 1:
97 | self.norm4 = nn.Sequential()
98 |
99 | if stride == 1:
100 | self.downsample = None
101 |
102 | else:
103 | self.downsample = nn.Sequential(
104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
105 |
106 |
107 | def forward(self, x):
108 | y = x
109 | y = self.relu(self.norm1(self.conv1(y)))
110 | y = self.relu(self.norm2(self.conv2(y)))
111 | y = self.relu(self.norm3(self.conv3(y)))
112 |
113 | if self.downsample is not None:
114 | x = self.downsample(x)
115 |
116 | return self.relu(x+y)
117 |
118 | class BasicEncoder(nn.Module):
119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
120 | super(BasicEncoder, self).__init__()
121 | self.norm_fn = norm_fn
122 |
123 | if self.norm_fn == 'group':
124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
125 |
126 | elif self.norm_fn == 'batch':
127 | self.norm1 = nn.BatchNorm2d(64)
128 |
129 | elif self.norm_fn == 'instance':
130 | self.norm1 = nn.InstanceNorm2d(64)
131 |
132 | elif self.norm_fn == 'none':
133 | self.norm1 = nn.Sequential()
134 |
135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
136 | self.relu1 = nn.ReLU(inplace=True)
137 |
138 | self.in_planes = 64
139 | self.layer1 = self._make_layer(64, stride=1)
140 | self.layer2 = self._make_layer(96, stride=2)
141 | self.layer3 = self._make_layer(128, stride=2)
142 |
143 | # output convolution
144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
145 |
146 | self.dropout = None
147 | if dropout > 0:
148 | self.dropout = nn.Dropout2d(p=dropout)
149 |
150 | for m in self.modules():
151 | if isinstance(m, nn.Conv2d):
152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
154 | if m.weight is not None:
155 | nn.init.constant_(m.weight, 1)
156 | if m.bias is not None:
157 | nn.init.constant_(m.bias, 0)
158 |
159 | def _make_layer(self, dim, stride=1):
160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
162 | layers = (layer1, layer2)
163 |
164 | self.in_planes = dim
165 | return nn.Sequential(*layers)
166 |
167 |
168 | def forward(self, x):
169 |
170 | # if input is list, combine batch dimension
171 | is_list = isinstance(x, tuple) or isinstance(x, list)
172 | if is_list:
173 | batch_dim = x[0].shape[0]
174 | x = torch.cat(x, dim=0)
175 |
176 | x = self.conv1(x)
177 | x = self.norm1(x)
178 | x = self.relu1(x)
179 |
180 | x = self.layer1(x)
181 | x = self.layer2(x)
182 | x = self.layer3(x)
183 |
184 | x = self.conv2(x)
185 |
186 | if self.training and self.dropout is not None:
187 | x = self.dropout(x)
188 |
189 | if is_list:
190 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
191 |
192 | return x
193 |
194 |
195 | class SmallEncoder(nn.Module):
196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
197 | super(SmallEncoder, self).__init__()
198 | self.norm_fn = norm_fn
199 |
200 | if self.norm_fn == 'group':
201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
202 |
203 | elif self.norm_fn == 'batch':
204 | self.norm1 = nn.BatchNorm2d(32)
205 |
206 | elif self.norm_fn == 'instance':
207 | self.norm1 = nn.InstanceNorm2d(32)
208 |
209 | elif self.norm_fn == 'none':
210 | self.norm1 = nn.Sequential()
211 |
212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
213 | self.relu1 = nn.ReLU(inplace=True)
214 |
215 | self.in_planes = 32
216 | self.layer1 = self._make_layer(32, stride=1)
217 | self.layer2 = self._make_layer(64, stride=2)
218 | self.layer3 = self._make_layer(96, stride=2)
219 |
220 | self.dropout = None
221 | if dropout > 0:
222 | self.dropout = nn.Dropout2d(p=dropout)
223 |
224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
225 |
226 | for m in self.modules():
227 | if isinstance(m, nn.Conv2d):
228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
230 | if m.weight is not None:
231 | nn.init.constant_(m.weight, 1)
232 | if m.bias is not None:
233 | nn.init.constant_(m.bias, 0)
234 |
235 | def _make_layer(self, dim, stride=1):
236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
238 | layers = (layer1, layer2)
239 |
240 | self.in_planes = dim
241 | return nn.Sequential(*layers)
242 |
243 |
244 | def forward(self, x):
245 |
246 | # if input is list, combine batch dimension
247 | is_list = isinstance(x, tuple) or isinstance(x, list)
248 | if is_list:
249 | batch_dim = x[0].shape[0]
250 | x = torch.cat(x, dim=0)
251 |
252 | x = self.conv1(x)
253 | x = self.norm1(x)
254 | x = self.relu1(x)
255 |
256 | x = self.layer1(x)
257 | x = self.layer2(x)
258 | x = self.layer3(x)
259 | x = self.conv2(x)
260 |
261 | if self.training and self.dropout is not None:
262 | x = self.dropout(x)
263 |
264 | if is_list:
265 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
266 |
267 | return x
268 |
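# Shape note (follows from the stride-2 stem plus two stride-2 residual stages):
# BasicEncoder maps (B, 3, H, W) images to (B, output_dim, H/8, W/8) features, and
# SmallEncoder does the same with lighter bottleneck blocks. Passing a list/tuple of
# two images returns a pair of feature maps split back along the batch dimension.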
--------------------------------------------------------------------------------
/core/fd_corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from utils.utils import bilinear_sampler
4 |
5 |
6 | class CorrBlock_FD_Sp4:
7 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4, coords_init=None, rad=1):
8 | self.num_levels = num_levels
9 | self.radius = radius
10 | self.corr_pyramid = []
11 |
12 | corr = CorrBlock_FD_Sp4.corr(fmap1, fmap2, coords_init, r=rad)
13 |
14 | batch, h1, w1, dim, h2, w2 = corr.shape
15 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
16 |
17 | self.corr_pyramid.append(corr)
18 | for i in range(self.num_levels-1):
19 | corr = F.avg_pool2d(corr, 2, stride=2)
20 | self.corr_pyramid.append(corr)
21 |
22 | def __call__(self, coords):
23 | r = self.radius
24 | coords = coords.permute(0, 2, 3, 1)
25 | batch, h1, w1, _ = coords.shape
26 |
27 | out_pyramid = []
28 | for i in range(self.num_levels):
29 | corr = self.corr_pyramid[i]
30 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
31 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
32 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
33 |
34 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
35 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
36 | coords_lvl = centroid_lvl + delta_lvl
37 |
38 | corr = bilinear_sampler(corr, coords_lvl)
39 | corr = corr.view(batch, h1, w1, -1)
40 | out_pyramid.append(corr)
41 |
42 | out = torch.cat(out_pyramid, dim=-1)
43 | return out.permute(0, 3, 1, 2).contiguous().float()
44 |
45 | @staticmethod
46 | def corr(fmap1, fmap2, coords_init, r):
47 | batch, dim, ht, wd = fmap1.shape
48 | fmap1 = fmap1.view(batch, dim, ht*wd)
49 | fmap2 = fmap2.view(batch, dim, ht*wd)
50 |
51 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
52 | corr = corr.view(batch, ht, wd, 1, ht, wd)
53 | # return corr / torch.sqrt(torch.tensor(dim).float())
54 |
55 | coords = coords_init.permute(0, 2, 3, 1).contiguous()
56 | batch, h1, w1, _ = coords.shape
57 |
58 | corr = corr.view(batch*h1*w1, 1, h1, w1)
59 |
60 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
61 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
62 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
63 |
64 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2)
65 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
66 | coords_lvl = centroid_lvl + delta_lvl
67 |
68 | corr = bilinear_sampler(corr, coords_lvl)
69 |
70 | corr = corr.view(batch, h1, w1, 1, 2*r+1, 2*r+1)
71 | return corr.permute(0, 1, 2, 3, 5, 4).contiguous() / torch.sqrt(torch.tensor(dim).float())
72 |
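# Note: unlike CorrBlock in corr.py, which keeps the full all-pairs cost volume,
# CorrBlock_FD_Sp4 stores only a (2*rad+1) x (2*rad+1) window of correlations around
# coords_init for each pixel (the all-pairs product is still computed once in corr()),
# so later lookups index this local window instead of the full H x W cost map.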
--------------------------------------------------------------------------------
/core/fd_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 | from corr import CorrBlock
7 |
8 |
9 | # --- transformer modules
10 | class TransformerModule(nn.Module):
11 | def __init__(self, args):
12 | super().__init__()
13 | self.args = args
14 | self.tb = TransBlocks(args)
15 |
16 | def forward(self, fmap1, fmap2, inp):
17 | batch, ch, ht, wd = fmap1.shape
18 | fmap1, fmap2 = self.tb(fmap1, fmap2)
19 | corr_fn = CorrBlock(fmap1, fmap2, num_levels=self.args.corr_levels, radius=self.args.corr_radius)
20 | return corr_fn
21 |
22 |
23 | class TransBlocks(nn.Module):
24 | def __init__(self, args):
25 | super().__init__()
26 | dim = args.m_dim
27 | mlp_scale = 4
28 | window_size = [8, 8]
29 | num_layers = [2, 2]
30 |
31 | self.num_layers = len(num_layers)
32 | self.blocks = nn.ModuleList()
33 | for n in range(self.num_layers):
34 | if n == self.num_layers - 1:
35 | self.blocks.append(
36 | BasicLayer(num_layer=num_layers[n], dim=dim, mlp_scale=mlp_scale, window_size=window_size, cross=False))
37 | else:
38 | self.blocks.append(
39 | BasicLayer(num_layer=num_layers[n], dim=dim, mlp_scale=mlp_scale, window_size=window_size, cross=True))
40 |
41 | def forward(self, fmap1, fmap2):
42 | _, _, ht, wd = fmap1.shape
43 | pad_h, pad_w = (8 - (ht % 8)) % 8, (8 - (wd % 8)) % 8
44 | _pad = [pad_w // 2, pad_w - pad_w // 2, pad_h, 0]
45 | fmap1 = F.pad(fmap1, pad=_pad, mode='constant', value=0)
46 | fmap2 = F.pad(fmap2, pad=_pad, mode='constant', value=0)
47 | mask = torch.zeros([1, ht, wd]).to(fmap1.device)
48 | mask = torch.nn.functional.pad(mask, pad=_pad, mode='constant', value=1)
49 | mask = mask.bool()
50 | fmap1 = fmap1.permute(0, 2, 3, 1).contiguous().float()
51 | fmap2 = fmap2.permute(0, 2, 3, 1).contiguous().float()
52 |
53 | for idx, blk in enumerate(self.blocks):
54 | fmap1, fmap2 = blk(fmap1, fmap2, mask=mask)
55 |
56 | _, ht, wd, _ = fmap1.shape
57 | fmap1 = fmap1[:, _pad[2]:ht - _pad[3], _pad[0]:wd - _pad[1], :]
58 | fmap2 = fmap2[:, _pad[2]:ht - _pad[3], _pad[0]:wd - _pad[1], :]
59 |
60 | fmap1 = fmap1.permute(0, 3, 1, 2).contiguous()
61 | fmap2 = fmap2.permute(0, 3, 1, 2).contiguous()
62 |
63 | return fmap1, fmap2
64 |
65 |
66 | def window_partition(fmap, window_size):
67 | """
68 | :param fmap: shape:B, H, W, C
69 | :param window_size: Wh, Ww
70 | :return: shape: B*nW, Wh*Ww, C
71 | """
72 | B, H, W, C = fmap.shape
73 | fmap = fmap.reshape(B, H//window_size[0], window_size[0], W//window_size[1], window_size[1], C)
74 | fmap = fmap.permute(0, 1, 3, 2, 4, 5).contiguous()
75 | fmap = fmap.reshape(B*(H//window_size[0])*(W//window_size[1]), window_size[0]*window_size[1], C)
76 | return fmap
77 |
78 |
79 | def window_reverse(fmap, window_size, H, W):
80 | """
81 | :param fmap: shape:B*nW, Wh*Ww, dim
82 | :param window_size: Wh, Ww
83 | :param H: original image height
84 | :param W: original image width
85 | :return: shape: B, H, W, C
86 | """
87 | Bnw, _, dim = fmap.shape
88 | nW = (H // window_size[0]) * (W // window_size[1])
89 | fmap = fmap.reshape(Bnw//nW, H // window_size[0], W // window_size[1], window_size[0], window_size[1], dim)
90 | fmap = fmap.permute(0, 1, 3, 2, 4, 5).contiguous()
91 | fmap = fmap.reshape(Bnw//nW, H, W, dim)
92 | return fmap
93 |
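# Round-trip sketch (assumes H and W are multiples of the window size, which the
# caller guarantees by padding in TransBlocks.forward):
#   wins = window_partition(x, [8, 8])            # (B*nW, 64, C)
#   x_rec = window_reverse(wins, [8, 8], H, W)    # (B, H, W, C), equal to x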
94 |
95 | class WindowAttention(nn.Module):
96 | def __init__(self, dim, window_size, scale=None):
97 | super().__init__()
98 | self.dim = dim
99 | self.scale = scale or dim ** (-0.5)
100 | self.q = nn.Linear(in_features=dim, out_features=dim)
101 | self.k = nn.Linear(in_features=dim, out_features=dim)
102 | self.v = nn.Linear(in_features=dim, out_features=dim)
103 | self.softmax = nn.Softmax(dim=-1)
104 | self.proj = nn.Linear(dim, dim)
105 |
106 | def forward(self, fmap, mask=None):
107 | """
108 |         :param fmap: B*nW, Wh*Ww, dim
109 | :param mask: nw, Wh*Ww, Ww*Wh
110 | :return: B*nW, Wh*Ww, dim
111 | """
112 | Bnw, WhWw, dim = fmap.shape
113 | q = self.q(fmap)
114 | k = self.k(fmap)
115 | v = self.v(fmap)
116 |
117 | q = q * self.scale
118 | attn = q @ k.transpose(1, 2)
119 |
120 | if mask is not None:
121 | nw = mask.shape[0]
122 | attn = attn.reshape(Bnw//nw, nw, WhWw, WhWw) + mask.unsqueeze(0)
123 | attn = attn.reshape(Bnw, WhWw, WhWw)
124 | attn = self.softmax(attn)
125 | else:
126 | attn = self.softmax(attn)
127 | x = attn @ v
128 | x = self.proj(x)
129 | return x
130 |
131 |
132 | class GlobalAttention(nn.Module):
133 | def __init__(self, dim, scale=None):
134 | super().__init__()
135 | self.dim = dim
136 | self.scale = scale or dim ** (-0.5)
137 | self.q = nn.Linear(in_features=dim, out_features=dim)
138 | self.k = nn.Linear(in_features=dim, out_features=dim)
139 | self.v = nn.Linear(in_features=dim, out_features=dim)
140 | self.softmax = nn.Softmax(dim=-1)
141 | self.proj = nn.Linear(in_features=dim, out_features=dim)
142 |
143 | def forward(self, fmap1, fmap2, mask=None):
144 | """
145 | :param fmap1: B, H, W, C
146 | :param fmap2: B, H, W, C
147 | :param pe: B, H, W, C
148 | :return:
149 | """
150 | B, H, W, C = fmap1.shape
151 | q = self.q(fmap1)
152 | k = self.k(fmap2)
153 | v = self.v(fmap2)
154 |
155 | q, k, v = map(lambda x: x.reshape(B, H*W, C), [q, k, v])
156 |
157 | q = q * self.scale
158 | attn = q @ k.transpose(1, 2)
159 | if mask is not None:
160 | mask = mask.reshape(1, H * W, 1) | mask.reshape(1, 1, H * W) # batch, hw, hw
161 | mask = mask.float() * -100.0
162 | attn = attn + mask
163 | attn = self.softmax(attn)
164 | x = attn @ v # B, HW, C
165 |
166 | x = self.proj(x)
167 |
168 | x = x.reshape(B, H, W, C)
169 | return x
170 |
171 |
172 | class SelfTransformerBlcok(nn.Module):
173 | def __init__(self, dim, mlp_scale, window_size, shift_size=None, norm=None):
174 | super().__init__()
175 | self.dim = dim
176 | self.window_size = window_size
177 | self.shift_size = shift_size
178 |
179 | if norm == 'layer':
180 | self.layer_norm1 = nn.LayerNorm(dim)
181 | self.layer_norm2 = nn.LayerNorm(dim)
182 | else:
183 | self.layer_norm1 = nn.Identity()
184 | self.layer_norm2 = nn.Identity()
185 |
186 | self.self_attn = WindowAttention(dim=dim, window_size=window_size)
187 | self.mlp = nn.Sequential(
188 | nn.Linear(dim, dim * mlp_scale),
189 | nn.GELU(),
190 | nn.Linear(dim * mlp_scale, dim)
191 | )
192 |
193 | def forward(self, fmap, mask=None):
194 | """
195 | :param fmap: shape: B, H, W, C
196 | :return: B, H, W, C
197 | """
198 | B, H, W, C = fmap.shape
199 |
200 | shortcut = fmap
201 | fmap = self.layer_norm1(fmap)
202 |
203 | if self.shift_size is not None:
204 | shifted_fmap = torch.roll(fmap, [-self.shift_size[0], -self.shift_size[1]], dims=(1, 2))
205 | if mask is not None:
206 | shifted_mask = torch.roll(mask, [-self.shift_size[0], -self.shift_size[1]], dims=(1, 2))
207 | else:
208 | shifted_fmap = fmap
209 | if mask is not None:
210 | shifted_mask = mask
211 |
212 | win_fmap = window_partition(shifted_fmap, window_size=self.window_size)
213 | if mask is not None:
214 | pad_mask = window_partition(shifted_mask.unsqueeze(-1), self.window_size)
215 | pad_mask = pad_mask.reshape(-1, self.window_size[0] * self.window_size[1], 1) \
216 | | pad_mask.reshape(-1, 1, self.window_size[0] * self.window_size[1])
217 |
218 | if self.shift_size is not None:
219 | h_slice = [slice(0, -self.window_size[0]), slice(-self.window_size[0], -self.shift_size[0]), slice(-self.shift_size[0], None)]
220 | w_slice = [slice(0, -self.window_size[1]), slice(-self.window_size[1], -self.shift_size[1]), slice(-self.shift_size[1], None)]
221 | img_mask = torch.zeros([1, H, W, 1]).to(win_fmap.device)
222 | count = 0
223 | for h in h_slice:
224 | for w in w_slice:
225 | img_mask[:, h, w, :] = count
226 | count += 1
227 | win_mask = window_partition(img_mask, self.window_size)
228 | win_mask = win_mask.reshape(-1, self.window_size[0] * self.window_size[1]) # nW, Wh*Ww
229 | attn_mask = win_mask.unsqueeze(2) - win_mask.unsqueeze(1) # nw, Wh*Ww, Wh*Ww
230 | if mask is not None:
231 | attn_mask = attn_mask.masked_fill(attn_mask == 0, 0.0).masked_fill((attn_mask != 0) | pad_mask, -100.0)
232 | else:
233 | attn_mask = attn_mask.masked_fill(attn_mask == 0, 0.0).masked_fill(attn_mask != 0, -100.0)
234 | attn_fmap = self.self_attn(win_fmap, attn_mask)
235 | else:
236 | if mask is not None:
237 | pad_mask = pad_mask.float()
238 | pad_mask = pad_mask.masked_fill(pad_mask != 0, -100.0).masked_fill(pad_mask == 0, 0.0)
239 | attn_fmap = self.self_attn(win_fmap, pad_mask)
240 | else:
241 | attn_fmap = self.self_attn(win_fmap, None)
242 | shifted_fmap = window_reverse(attn_fmap, self.window_size, H, W)
243 |
244 | if self.shift_size is not None:
245 | fmap = torch.roll(shifted_fmap, [self.shift_size[0], self.shift_size[1]], dims=(1, 2))
246 | else:
247 | fmap = shifted_fmap
248 |
249 | fmap = shortcut + fmap
250 | fmap = fmap + self.mlp(self.layer_norm2(fmap)) # B, H, W, C
251 | return fmap
252 |
253 |
254 | class CrossTransformerBlcok(nn.Module):
255 | def __init__(self, dim, mlp_scale, norm=None):
256 | super().__init__()
257 | self.dim = dim
258 |
259 | if norm == 'layer':
260 | self.layer_norm1 = nn.LayerNorm(dim)
261 | self.layer_norm2 = nn.LayerNorm(dim)
262 | self.layer_norm3 = nn.LayerNorm(dim)
263 | else:
264 | self.layer_norm1 = nn.Identity()
265 | self.layer_norm2 = nn.Identity()
266 | self.layer_norm3 = nn.Identity()
267 | self.cross_attn = GlobalAttention(dim=dim)
268 | self.mlp = nn.Sequential(
269 | nn.Linear(dim, dim * mlp_scale),
270 | nn.GELU(),
271 | nn.Linear(dim * mlp_scale, dim)
272 | )
273 |
274 | def forward(self, fmap1, fmap2, mask=None):
275 | """
276 | :param fmap1: shape: B, H, W, C
277 | :param fmap2: shape: B, H, W, C
278 | :return: B, H, W, C
279 | """
280 | shortcut = fmap1
281 |
282 | fmap1 = self.layer_norm1(fmap1)
283 | fmap2 = self.layer_norm2(fmap2)
284 |
285 | attn_fmap = self.cross_attn(fmap1, fmap2, mask)
286 | attn_fmap = shortcut + attn_fmap
287 | fmap = attn_fmap + self.mlp(self.layer_norm3(attn_fmap)) # B, H, W, C
288 | return fmap
289 |
290 |
291 | class BasicLayer(nn.Module):
292 | def __init__(self, num_layer, dim, mlp_scale, window_size, cross=False):
293 | super().__init__()
294 | assert num_layer % 2 == 0, "The number of Transformer Block must be even!"
295 | self.blocks = nn.ModuleList()
296 | for n in range(num_layer):
297 | shift_size = None if n % 2 == 0 else [window_size[0]//2, window_size[1]//2]
298 | self.blocks.append(
299 | SelfTransformerBlcok(
300 | dim=dim,
301 | mlp_scale=mlp_scale,
302 | window_size=window_size,
303 | shift_size=shift_size,
304 | norm='layer'))
305 |
306 | if cross:
307 | self.cross_transformer = CrossTransformerBlcok(dim=dim, mlp_scale=mlp_scale, norm='layer')
308 |
309 | self.cross = cross
310 |
311 | def forward(self, fmap1, fmap2, mask=None):
312 | """
313 | :param fmap1: B, H, W, C
314 | :param fmap2: B, H, W, C
315 | :return: B, H, W, C
316 | """
317 | B = fmap1.shape[0]
318 | fmap = torch.cat([fmap1, fmap2], dim=0)
319 | for blk in self.blocks:
320 | fmap = blk(fmap, mask)
321 | fmap1, fmap2 = torch.split(fmap, [B]*2, dim=0)
322 | if self.cross:
323 | fmap2 = self.cross_transformer(fmap2, fmap1, mask) + fmap2
324 | fmap1 = self.cross_transformer(fmap1, fmap2, mask) + fmap1
325 | return fmap1, fmap2
326 |
327 |
328 | # --- upsample modules
329 | class UpSampleMask8(nn.Module):
330 | def __init__(self, dim):
331 | super().__init__()
332 | self.up_sample_mask = nn.Sequential(
333 | nn.Conv2d(in_channels=dim, out_channels=256, kernel_size=3, padding=1, stride=1),
334 | nn.ReLU(inplace=True),
335 | nn.Conv2d(in_channels=256, out_channels=64 * 9, kernel_size=1, stride=1)
336 | )
337 |
338 | def forward(self, data):
339 | """
340 | :param data: B, C, H, W
341 | :return: batch, 8*8*9, H, W
342 | """
343 |         mask = self.up_sample_mask(data)  # B, 64*9, H, W
344 | return mask
345 |
346 |
347 | class UpSampleMask4(nn.Module):
348 | def __init__(self, dim):
349 | super().__init__()
350 | self.up_sample_mask = nn.Sequential(
351 | nn.Conv2d(in_channels=dim, out_channels=256, kernel_size=3, padding=1, stride=1),
352 | nn.ReLU(inplace=True),
353 | nn.Conv2d(in_channels=256, out_channels=16 * 9, kernel_size=1, stride=1)
354 | )
355 |
356 | def forward(self, data):
357 | """
358 | :param data: B, C, H, W
359 |         :return: batch, 4*4*9, H, W
360 |         """
361 |         mask = self.up_sample_mask(data)  # B, 16*9, H, W
362 | return mask
363 |
364 |
365 | # --- SK decoder modules
366 | class PCBlock4_Deep_nopool_res(nn.Module):
367 | def __init__(self, C_in, C_out, k_conv):
368 | super().__init__()
369 | self.conv_list = nn.ModuleList([
370 | nn.Conv2d(C_in, C_in, kernel, stride=1, padding=kernel//2, groups=C_in) for kernel in k_conv])
371 |
372 | self.ffn1 = nn.Sequential(
373 | nn.Conv2d(C_in, int(1.5*C_in), 1, padding=0),
374 | nn.GELU(),
375 | nn.Conv2d(int(1.5*C_in), C_in, 1, padding=0),
376 | )
377 | self.pw = nn.Conv2d(C_in, C_in, 1, padding=0)
378 | self.ffn2 = nn.Sequential(
379 | nn.Conv2d(C_in, int(1.5*C_in), 1, padding=0),
380 | nn.GELU(),
381 | nn.Conv2d(int(1.5*C_in), C_out, 1, padding=0),
382 | )
383 |
384 | def forward(self, x):
385 | x = F.gelu(x + self.ffn1(x))
386 | for conv in self.conv_list:
387 | x = F.gelu(x + conv(x))
388 | x = F.gelu(x + self.pw(x))
389 | x = self.ffn2(x)
390 | return x
391 |
392 |
393 | class SKMotionEncoder6_Deep_nopool_res(nn.Module):
394 | def __init__(self, args):
395 | super().__init__()
396 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
397 | self.convc1 = PCBlock4_Deep_nopool_res(cor_planes, 256, k_conv=args.k_conv)
398 | self.convc2 = PCBlock4_Deep_nopool_res(256, 192, k_conv=args.k_conv)
399 |
400 | self.convf1 = nn.Conv2d(2, 128, 1, 1, 0)
401 | self.convf2 = PCBlock4_Deep_nopool_res(128, 64, k_conv=args.k_conv)
402 |
403 | self.conv = PCBlock4_Deep_nopool_res(64+192, 128-2, k_conv=args.k_conv)
404 |
405 | def forward(self, flow, corr):
406 | cor = F.gelu(self.convc1(corr))
407 |
408 | cor = self.convc2(cor)
409 |
410 | flo = self.convf1(flow)
411 | flo = self.convf2(flo)
412 |
413 | cor_flo = torch.cat([cor, flo], dim=1)
414 | out = self.conv(cor_flo)
415 |
416 | return torch.cat([out, flow], dim=1)
417 |
418 |
419 | class SKUpdateBlock6_Deep_nopoolres_AllDecoder(nn.Module):
420 | def __init__(self, args, hidden_dim):
421 | super().__init__()
422 | self.args = args
423 | self.encoder = SKMotionEncoder6_Deep_nopool_res(args)
424 | self.gru = PCBlock4_Deep_nopool_res(128+hidden_dim+hidden_dim+128, 128, k_conv=args.PCUpdater_conv)
425 | self.flow_head = PCBlock4_Deep_nopool_res(128, 2, k_conv=args.k_conv)
426 |
427 | self.mask = nn.Sequential(
428 | nn.Conv2d(128, 256, 3, padding=1),
429 | nn.ReLU(inplace=True),
430 | nn.Conv2d(256, 64*9, 1, padding=0))
431 |
432 | self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=self.args.num_heads)
433 |
434 | def forward(self, net, inp, corr, flow, attention):
435 | motion_features = self.encoder(flow, corr)
436 | motion_features_global = self.aggregator(attention, motion_features)
437 | inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1)
438 |
439 | # Attentional update
440 | net = self.gru(torch.cat([net, inp_cat], dim=1))
441 |
442 | delta_flow = self.flow_head(net)
443 |
444 |         # scale mask to balance gradients
445 | mask = .25 * self.mask(net)
446 | return net, mask, delta_flow
447 |
448 |
449 | class Aggregator(nn.Module):
450 | def __init__(self, args, chnn, heads=1):
451 | super().__init__()
452 | self.scale = chnn ** -0.5
453 | self.to_qk = nn.Conv2d(chnn, chnn * 2, 1, bias=False)
454 | self.to_v = nn.Conv2d(128, 128, 1, bias=False)
455 | self.gamma = nn.Parameter(torch.zeros(1))
456 |
457 | def forward(self, *inputs):
458 | feat_ctx, feat_mo, itr = inputs
459 |
460 | feat_shape = feat_mo.shape
461 | b, c, h, w = feat_shape
462 | c_c = feat_ctx.shape[1]
463 |
464 | if itr == 0:
465 | feat_q, feat_k = self.to_qk(feat_ctx).chunk(2, dim=1)
466 | feat_q = self.scale * feat_q.view(b, c_c, h*w)
467 | feat_k = feat_k.view(b, c_c, h*w)
468 |
469 | attn = torch.einsum('b c n, b c m -> b m n', feat_q, feat_k)
470 | attn = attn.view(b, 1, h*w, h*w)
471 | self.attn = attn.softmax(2).view(b, h*w, h*w).permute(0, 2, 1).contiguous()
472 |
473 | feat_v = self.to_v(feat_mo).view(b, c, h*w)
474 | feat_o = torch.einsum('b n m, b c m -> b c n', self.attn, feat_v).contiguous().view(b, c, h, w)
475 | feat_o = feat_mo + feat_o * self.gamma
476 | return feat_o
477 |
478 |
479 | class SKUpdate(nn.Module):
480 | def __init__(self, args, hidden_dim):
481 | super().__init__()
482 | self.args = args
483 | d_dim = args.c_dim
484 |
485 | self.encoder = SKMotionEncoder6_Deep_nopool_res(args)
486 | self.gru = PCBlock4_Deep_nopool_res(128+hidden_dim+hidden_dim+128, d_dim, k_conv=args.PCUpdater_conv)
487 | self.flow_head = PCBlock4_Deep_nopool_res(d_dim, 2, k_conv=args.k_conv)
488 | self.aggregator = Aggregator(self.args, d_dim)
489 |
490 | def forward(self, net, inp, corr, flow, itr=None, sp4=False):
491 | motion_features = self.encoder(flow, corr)
492 |
493 | if not sp4:
494 | motion_features_global = self.aggregator(inp, motion_features, itr)
495 | else:
496 | motion_features_global = motion_features
497 |
498 | inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1)
499 | net = self.gru(torch.cat([net, inp_cat], dim=1))
500 |
501 | delta_flow = self.flow_head(net)
502 | return net, delta_flow
503 |
504 |
505 | class SinusoidalPositionEmbeddings(nn.Module):
506 | def __init__(self, dim):
507 | super().__init__()
508 | self.dim = dim
509 |
510 | def forward(self, time):
511 | device = time.device
512 | half_dim = self.dim // 2
513 | embeddings = math.log(10000) / (half_dim - 1)
514 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
515 | embeddings = time[:, None] * embeddings[None, :]
516 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
517 | return embeddings
518 |
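# Shape note: given a batch of diffusion timesteps `time` of shape (B,), forward()
# returns the standard transformer-style sinusoidal embedding of shape (B, dim),
# with sine features in the first half and cosine features in the second half.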
519 |
520 | class ConvEE(nn.Module):
521 | def __init__(self, C_in, C_out):
522 | super().__init__()
523 | groups = 4
524 | self.conv1 = nn.Sequential(
525 | nn.GroupNorm(groups, C_in),
526 | nn.GELU(),
527 | nn.Conv2d(C_in, C_in, 3, padding=1),
528 | nn.GroupNorm(groups, C_in))
529 | self.conv2 = nn.Sequential(
530 | nn.GELU(),
531 | nn.Conv2d(C_in, C_in, 3, padding=1))
532 | self.gamma = nn.Parameter(torch.zeros(1))
533 |
534 | def forward(self, x, t_emb):
535 | scale, shift = t_emb
536 | x_res = x
537 | x = self.conv1(x)
538 |
539 | x = x * (scale + 1) + shift
540 |
541 | x = self.conv2(x)
542 | x_o = x * self.gamma
543 |
544 | return x_o
545 |
546 |
547 | class SKUpdateDFM(nn.Module):
548 | def __init__(self, args, hidden_dim):
549 | super().__init__()
550 | self.args = args
551 | chnn = hidden_dim
552 | self.conv_ee = ConvEE(chnn, chnn)
553 |
554 | d_model = 256
555 | self.d_model = d_model
556 | time_dim = d_model * 2
557 | self.time_mlp = nn.Sequential(
558 | SinusoidalPositionEmbeddings(d_model),
559 | nn.Linear(d_model, time_dim),
560 | nn.GELU(),
561 | nn.Linear(time_dim, time_dim))
562 | self.chnn_o = 256
563 | self.block_time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, self.chnn_o))
564 |
565 | def forward(self, net, inp, corr, flow, itr, first_step=False, dfm_params=[]):
566 | t, funcs, i_ddim, dfm_itrs = dfm_params
567 | b = t.shape[0]
568 | time_emb = self.time_mlp(t)
569 |
570 | scale_shift = self.block_time_mlp(time_emb)
571 | scale_shift = scale_shift.view(b, 256, 1, 1)
572 | scale, shift = scale_shift.chunk(2, dim=1)
573 |
574 | motion_features = funcs.encoder(flow, corr)
575 |
576 | if first_step:
577 | self.shape = net.shape
578 |
579 | if self.shape == net.shape:
580 | feat_mo = funcs.aggregator(inp, motion_features, itr)
581 | else:
582 | feat_mo = motion_features
583 |
584 | feat_mo = self.conv_ee(feat_mo, [scale, shift])
585 |
586 | inp = torch.cat([inp, motion_features, feat_mo], dim=1)
587 | net = funcs.gru(torch.cat([net, inp], dim=1))
588 |
589 | net = net * (scale + 1) + shift
590 |
591 | delta_flow = funcs.flow_head(net)
592 |
593 | return net, delta_flow
594 |
595 |
596 |
597 |
--------------------------------------------------------------------------------
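A note on the decoder above: `SKUpdateDFM` conditions its update on the diffusion timestep by passing `t` through `SinusoidalPositionEmbeddings` and `time_mlp`, then applying the resulting scale/shift to the features. A minimal standalone sketch of that embedding (a hypothetical `sinusoidal_embedding` helper mirroring `SinusoidalPositionEmbeddings.forward`), useful for checking the output shape:

```python
import math
import torch

def sinusoidal_embedding(t, dim=256):
    # mirrors SinusoidalPositionEmbeddings.forward: sin/cos at geometrically spaced frequencies
    half_dim = dim // 2
    freq = math.log(10000) / (half_dim - 1)
    freq = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -freq)
    args = t.float()[:, None] * freq[None, :]
    return torch.cat((args.sin(), args.cos()), dim=-1)   # shape (B, dim)

t = torch.tensor([0, 250, 999])        # diffusion timesteps from a 1000-step schedule
print(sinusoidal_embedding(t).shape)   # torch.Size([3, 256])
```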
/core/fd_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import timm
6 | import numpy as np
7 |
8 |
9 | class twins_svt_large(nn.Module):
10 | def __init__(self, pretrained=True):
11 | super().__init__()
12 | self.svt = timm.create_model('twins_svt_large', pretrained=pretrained)
13 |
14 | del self.svt.head
15 | del self.svt.patch_embeds[2]
16 | del self.svt.patch_embeds[2]
17 | del self.svt.blocks[2]
18 | del self.svt.blocks[2]
19 | del self.svt.pos_block[2]
20 | del self.svt.pos_block[2]
21 |
22 | def forward(self, x, data=None, layer=2):
23 | B = x.shape[0]
24 | x_4 = None
25 | for i, (embed, drop, blocks, pos_blk) in enumerate(
26 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)):
27 |
28 | patch_size = embed.patch_size
29 | if i == layer - 1:
30 | embed.patch_size = (1, 1)
31 | embed.proj.stride = embed.patch_size
32 | x_4 = torch.nn.functional.pad(x, [1, 0, 1, 0], mode='constant', value=0)
33 | x_4, size_4 = embed(x_4)
34 | size_4 = (size_4[0] - 1, size_4[1] - 1)
35 | x_4 = drop(x_4)
36 | for j, blk in enumerate(blocks):
37 | x_4 = blk(x_4, size_4)
38 | if j == 0:
39 | x_4 = pos_blk(x_4, size_4)
40 |
41 | if i < len(self.svt.depths) - 1:
42 | x_4 = x_4.reshape(B, *size_4, -1).permute(0, 3, 1, 2).contiguous()
43 |
44 | embed.patch_size = patch_size
45 | embed.proj.stride = patch_size
46 | x, size = embed(x)
47 | x = drop(x)
48 | for j, blk in enumerate(blocks):
49 | x = blk(x, size)
50 | if j==0:
51 | x = pos_blk(x, size)
52 | if i < len(self.svt.depths) - 1:
53 | x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
54 |
55 | if i == layer-1:
56 | break
57 |
58 | return x, x_4
59 |
60 | def compute_params(self, layer=2):
61 | num = 0
62 | for i, (embed, drop, blocks, pos_blk) in enumerate(
63 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)):
64 |
65 | for param in embed.parameters():
66 | num += np.prod(param.size())
67 |
68 | for param in drop.parameters():
69 | num += np.prod(param.size())
70 |
71 | for param in blocks.parameters():
72 | num += np.prod(param.size())
73 |
74 | for param in pos_blk.parameters():
75 | num += np.prod(param.size())
76 |
77 | if i == layer-1:
78 | break
79 |
80 |         for param in (self.svt.head.parameters() if hasattr(self.svt, 'head') else []):  # head is deleted in __init__
81 |             num += np.prod(param.size())
82 |
83 | return num
84 |
85 |
86 | class twins_svt_small_context(nn.Module):
87 | def __init__(self, pretrained=True):
88 | super().__init__()
89 | self.svt = timm.create_model('twins_svt_small', pretrained=pretrained)
90 |
91 | del self.svt.head
92 | del self.svt.patch_embeds[2]
93 | del self.svt.patch_embeds[2]
94 | del self.svt.blocks[2]
95 | del self.svt.blocks[2]
96 | del self.svt.pos_block[2]
97 | del self.svt.pos_block[2]
98 |
99 | def forward(self, x, data=None, layer=2):
100 | B = x.shape[0]
101 | x_4 = None
102 | for i, (embed, drop, blocks, pos_blk) in enumerate(
103 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)):
104 |
105 | patch_size = embed.patch_size
106 | if i == layer - 1:
107 | embed.patch_size = (1, 1)
108 | embed.proj.stride = embed.patch_size
109 | x_4 = torch.nn.functional.pad(x, [1, 0, 1, 0], mode='constant', value=0)
110 | x_4, size_4 = embed(x_4)
111 | size_4 = (size_4[0] - 1, size_4[1] - 1)
112 | x_4 = drop(x_4)
113 | for j, blk in enumerate(blocks):
114 | x_4 = blk(x_4, size_4)
115 | if j == 0:
116 | x_4 = pos_blk(x_4, size_4)
117 |
118 | if i < len(self.svt.depths) - 1:
119 | x_4 = x_4.reshape(B, *size_4, -1).permute(0, 3, 1, 2).contiguous()
120 |
121 | embed.patch_size = patch_size
122 | embed.proj.stride = patch_size
123 | x, size = embed(x)
124 | x = drop(x)
125 | for j, blk in enumerate(blocks):
126 | x = blk(x, size)
127 | if j == 0:
128 | x = pos_blk(x, size)
129 | if i < len(self.svt.depths) - 1:
130 | x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
131 |
132 | if i == layer - 1:
133 | break
134 |
135 | return x, x_4
136 |
137 | def compute_params(self, layer=2):
138 | num = 0
139 | for i, (embed, drop, blocks, pos_blk) in enumerate(
140 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)):
141 |
142 | for param in embed.parameters():
143 | num += np.prod(param.size())
144 |
145 | for param in drop.parameters():
146 | num += np.prod(param.size())
147 |
148 | for param in blocks.parameters():
149 | num += np.prod(param.size())
150 |
151 | for param in pos_blk.parameters():
152 | num += np.prod(param.size())
153 |
154 | if i == layer - 1:
155 | break
156 |
157 |         for param in (self.svt.head.parameters() if hasattr(self.svt, 'head') else []):  # head is deleted in __init__
158 |             num += np.prod(param.size())
159 |
160 | return num
161 |
162 |
--------------------------------------------------------------------------------
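The two encoders above return a stride-8 feature map together with a stride-4 one, obtained by re-running the last kept stage with its projection stride set to 1. A rough smoke test, assuming `core/` is on `sys.path` (as `evaluate.py` arranges) and using `pretrained=False` to skip the timm weight download; the sizes in the comment are what the code implies for a 384x512 input, not numbers taken from the paper:

```python
import torch
from fd_encoder import twins_svt_large   # assumes 'core' is on sys.path

enc = twins_svt_large(pretrained=False).eval()
img = torch.randn(1, 3, 384, 512)         # H and W divisible by 8
with torch.no_grad():
    f8, f4 = enc(img)                     # stride-8 and stride-4 feature maps
print(f8.shape, f4.shape)                 # expected: (1, 256, 48, 64) and (1, 256, 96, 128)
```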
/core/flowdiffuser.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 | from corr import CorrBlock
7 | from utils.utils import coords_grid
8 |
9 | from fd_encoder import twins_svt_large, twins_svt_small_context
10 | from fd_decoder import UpSampleMask8, UpSampleMask4, TransformerModule, SKUpdate, SinusoidalPositionEmbeddings, SKUpdateDFM
11 | from fd_corr import CorrBlock_FD_Sp4
12 |
13 | autocast = torch.cuda.amp.autocast
14 |
15 |
16 | def exists(x):
17 | return x is not None
18 |
19 |
20 | def default(val, d):
21 | if exists(val):
22 | return val
23 | return d() if callable(d) else d
24 |
25 |
26 | def extract(a, t, x_shape):
27 | """extract the appropriate t index for a batch of indices"""
28 | batch_size = t.shape[0]
29 | out = a.gather(-1, t)
30 | return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
31 |
32 |
33 | def cosine_beta_schedule(timesteps, s=0.008):
34 | """
35 | cosine schedule
36 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
37 | """
38 | steps = timesteps + 1
39 | x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
40 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
41 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
42 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
43 | return torch.clip(betas, 0, 0.999)
44 |
45 |
46 | def ste_round(x):
47 | return torch.round(x) - x.detach() + x
48 |
49 |
50 | class FlowDiffuser(nn.Module):
51 | def __init__(self, args):
52 | super().__init__()
53 | print('\n ---------- model: FlowDiffuser ---------- \n')
54 |
55 | args.corr_levels = 4
56 | args.corr_radius = 4
57 | args.m_dim = 256
58 | args.c_dim = c_dim = 128
59 | args.iters_const6 = 6
60 |
61 | self.args = args
62 | self.args.UpdateBlock = 'SKUpdateBlock6_Deep_nopoolres_AllDecoder'
63 | self.args.k_conv = [1, 15]
64 | self.args.PCUpdater_conv = [1, 7]
65 | self.sp4 = True
66 | self.rad = 8
67 |
68 | self.fnet = twins_svt_large(pretrained=True)
69 | self.cnet = twins_svt_small_context(pretrained=True)
70 | self.trans = TransformerModule(args)
71 | self.C_inp = nn.Conv2d(in_channels=c_dim, out_channels=c_dim, kernel_size=1)
72 | self.C_net = nn.Conv2d(in_channels=c_dim, out_channels=c_dim, kernel_size=1)
73 | self.update = SKUpdate(self.args, hidden_dim=c_dim)
74 | self.um8 = UpSampleMask8(c_dim)
75 | self.um4 = UpSampleMask4(c_dim)
76 | self.zero = nn.Parameter(torch.zeros(12), requires_grad=False)
77 |
78 | self.diffusion = True
79 | if self.diffusion:
80 | self.update_dfm = SKUpdateDFM(self.args, hidden_dim=c_dim)
81 |
82 | timesteps = 1000
83 | sampling_timesteps = 4
84 | recurr_itrs = 6
85 | print(' -- denoise steps: %d \n' % sampling_timesteps)
86 | print(' -- recurrent iterations: %d \n' % recurr_itrs)
87 |
88 | self.ddim_n = sampling_timesteps
89 | self.recurr_itrs = recurr_itrs
90 | self.n_sc = 0.1
91 | self.scale = nn.Parameter(torch.ones(1) * 0.5, requires_grad=False)
92 | self.n_lambda = 0.2
93 |
94 | self.objective = 'pred_x0'
95 | betas = cosine_beta_schedule(timesteps)
96 | alphas = 1. - betas
97 | alphas_cumprod = torch.cumprod(alphas, dim=0)
98 | alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
99 | timesteps, = betas.shape
100 | self.num_timesteps = int(timesteps)
101 |
102 | self.sampling_timesteps = default(sampling_timesteps, timesteps)
103 | assert self.sampling_timesteps <= timesteps
104 | self.is_ddim_sampling = self.sampling_timesteps < timesteps
105 | self.ddim_sampling_eta = 1.
106 | self.self_condition = False
107 |
108 | self.register_buffer('betas', betas)
109 | self.register_buffer('alphas_cumprod', alphas_cumprod)
110 | self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
111 | self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
112 | self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
113 | self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
114 | self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
115 | self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
116 |
117 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
118 | self.register_buffer('posterior_variance', posterior_variance)
119 | self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))
120 | self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
121 | self.register_buffer('posterior_mean_coef2',
122 | (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
123 |
124 | def freeze_bn(self):
125 | for m in self.modules():
126 | if isinstance(m, nn.BatchNorm2d):
127 | m.eval()
128 |
129 | def up_sample_flow8(self, flow, mask):
130 | B, _, H, W = flow.shape
131 | flow = torch.nn.functional.unfold(8 * flow, kernel_size=[3, 3], stride=[1, 1], padding=[1, 1])
132 | flow = flow.reshape(B, 2, 9, 1, 1, H, W)
133 | mask = mask.reshape(B, 1, 9, 8, 8, H, W)
134 | mask = torch.softmax(mask, dim=2)
135 | up_flow = torch.sum(flow * mask, dim=2)
136 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3).contiguous()
137 | up_flow = up_flow.reshape(B, 2, H * 8, W * 8)
138 | return up_flow
139 |
140 | def up_sample_flow4(self, flow, mask):
141 | B, _, H, W = flow.shape
142 | flow = torch.nn.functional.unfold(4 * flow, kernel_size=[3, 3], stride=[1, 1], padding=[1, 1])
143 | flow = flow.reshape(B, 2, 9, 1, 1, H, W)
144 | mask = mask.reshape(B, 1, 9, 4, 4, H, W)
145 | mask = torch.softmax(mask, dim=2)
146 | up_flow = torch.sum(flow * mask, dim=2)
147 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3).contiguous()
148 | up_flow = up_flow.reshape(B, 2, H * 4, W * 4)
149 | return up_flow
150 |
151 | def initialize_flow8(self, img):
152 |         """ Flow is represented as the difference between two coordinate grids: flow = coords1 - coords0 """
153 | N, C, H, W = img.shape
154 | coords0 = coords_grid(N, H//8, W//8, device=img.device).permute(0, 2, 3, 1).contiguous()
155 | coords1 = coords_grid(N, H//8, W//8, device=img.device).permute(0, 2, 3, 1).contiguous()
156 | return coords0, coords1
157 |
158 | def initialize_flow4(self, img):
159 |         """ Flow is represented as the difference between two coordinate grids: flow = coords1 - coords0 """
160 | N, C, H, W = img.shape
161 | coords0 = coords_grid(N, H//4, W//4, device=img.device).permute(0, 2, 3, 1).contiguous()
162 | coords1 = coords_grid(N, H//4, W//4, device=img.device).permute(0, 2, 3, 1).contiguous()
163 | return coords0, coords1
164 |
165 | def _train_dfm(self, feat_shape, flow_gt, net, inp8, coords0, coords1):
166 | b, c, h, w = feat_shape
167 | if len(flow_gt.shape) == 3:
168 | flow_gt = flow_gt.unsqueeze(0)
169 | flow_gt_sp8 = F.interpolate(flow_gt, (h, w), mode='bilinear', align_corners=True) / 8.
170 |
171 | x_t, noises, t = self._prepare_targets(flow_gt_sp8)
172 | x_t = x_t * self.norm_const
173 | coords1 = coords1 + x_t.float()
174 |
175 | flow_up_s = []
176 | for ii in range(self.recurr_itrs):
177 | t_ii = (t - t / self.recurr_itrs * ii).int()
178 |
179 | itr = ii
180 | first_step = False if itr != 0 else True
181 |
182 | coords1 = coords1.detach()
183 | corr = self.corr_fn(coords1)
184 | flow = coords1 - coords0
185 | with autocast(enabled=self.args.mixed_precision):
186 | dfm_params = [t_ii, self.update, ii, 0]
187 | net, delta_flow = self.update_dfm(net, inp8, corr, flow, itr, first_step=first_step, dfm_params=dfm_params)
188 | up_mask = self.um8(net)
189 |
190 | coords1 = coords1 + delta_flow
191 | flow = coords1 - coords0
192 |
193 | flow_up = self.up_sample_flow8(flow, up_mask)
194 | flow_up_s.append(flow_up)
195 |
196 | return flow_up_s, coords1, net
197 |
198 | def _prepare_targets(self, flow_gt):
199 | noise = torch.randn(flow_gt.shape, device=self.device)
200 | t = torch.randint(0, self.num_timesteps, (1,), device=self.device).long()
201 |
202 | x_start = flow_gt / self.norm_const
203 | x_start = x_start * self.scale
204 | x_t = self._q_sample(x_start=x_start, t=t, noise=noise)
205 | x_t = torch.clamp(x_t, min=-1, max=1)
206 | x_t = x_t * self.n_sc
207 | return x_t, noise, t
208 |
209 | def _q_sample(self, x_start, t, noise=None):
210 | if noise is None:
211 | noise = torch.randn_like(x_start)
212 |
213 | sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
214 | sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
215 |
216 | return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
217 |
218 | @torch.no_grad()
219 | def _ddim_sample(self, feat_shape, net, inp, coords0, coords1_init, clip_denoised=True):
220 | batch, c, h, w = feat_shape
221 | shape = (batch, 2, h, w)
222 | total_timesteps, sampling_timesteps, eta, objective = self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
223 | times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
224 | times = list(reversed(times.int().tolist()))
225 | time_pairs = list(zip(times[:-1], times[1:]))
226 | x_in = torch.randn(shape, device=self.device)
227 |
228 | flow_s = []
229 | x_start = None
230 | pred_s = None
231 | for i_ddim, time_s in enumerate(time_pairs):
232 | time, time_next = time_s
233 | time_cond = torch.full((batch,), time, device=self.device, dtype=torch.long)
234 | t_next = torch.full((batch,), time_next, device=self.device, dtype=torch.long)
235 |
236 | x_pred, inner_flow_s, pred_s = self._model_predictions(x_in, time_cond, net, inp, coords0, coords1_init, i_ddim, pred_s, t_next)
237 | flow_s = flow_s + inner_flow_s
238 |
239 | alpha = self.alphas_cumprod[time]
240 | alpha_next = self.alphas_cumprod[time_next]
241 |
242 | x_t = x_in
243 | x_pred = x_pred * self.scale
244 | x_pred = torch.clamp(x_pred, min=-1 * self.scale, max=self.scale)
245 | eps = (1 / (1 - alpha).sqrt()) * (x_t - alpha.sqrt() * x_pred)
246 | x_next = alpha_next.sqrt() * x_pred + (1 - alpha_next).sqrt() * eps
247 | x_in = x_next
248 |
249 | net, up_mask, coords1 = pred_s
250 |
251 | return coords1, net, flow_s
252 |
253 | def _model_predictions(self, x, t, net, inp8, coords0, coords1, i_ddim, pred_last=None, t_next=None):
254 | x_flow = torch.clamp(x, min=-1, max=1)
255 | x_flow = x_flow * self.n_sc
256 | x_flow = x_flow * self.norm_const
257 |
258 | if pred_last:
259 | net, _, coords1 = pred_last
260 | x_flow = x_flow * self.n_lambda
261 |
262 | coords1 = coords1 + x_flow.float()
263 |
264 | flow_s = []
265 | for ii in range(self.recurr_itrs):
266 | t_ii = (t - (t - 0) / self.recurr_itrs * ii).int()
267 |
268 | corr = self.corr_fn(coords1)
269 | flow = coords1 - coords0
270 |
271 | with autocast(enabled=self.args.mixed_precision):
272 | itr = ii
273 | first_step = False if itr != 0 else True
274 | dfm_params = [t_ii, self.update, ii, 0]
275 | net, delta_flow = self.update_dfm(net, inp8, corr, flow, itr, first_step=first_step, dfm_params=dfm_params)
276 | up_mask = self.um8(net)
277 |
278 | coords1 = coords1 + delta_flow
279 |
280 | flow = coords1 - coords0
281 | flow_up = self.up_sample_flow8(flow, up_mask)
282 |
283 | flow_s.append(flow_up)
284 |
285 | flow = coords1 - coords0
286 | x_pred = flow / self.norm_const
287 |
288 | return x_pred, flow_s, [net, up_mask, coords1]
289 |
290 | def _predict_noise_from_start(self, x_t, t, x0):
291 | return (
292 | (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) /
293 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
294 | )
295 |
296 | def forward(self, image1, image2, test_mode=False, iters=None, flow_gt=None, flow_init=None):
297 |
298 | image1 = 2 * (image1 / 255.0) - 1.0
299 | image2 = 2 * (image2 / 255.0) - 1.0
300 |
301 | with autocast(enabled=self.args.mixed_precision):
302 | fmap = self.fnet(torch.cat([image1, image2], dim=0))
303 | inp = self.cnet(image1)
304 |
305 | fmap, fmap4 = fmap
306 | inp, inp4 = inp
307 | fmap = fmap.float()
308 | fmap4 = fmap4.float()
309 | inp = inp.float()
310 | inp4 = inp4.float()
311 |
312 | fmap1_4, fmap2_4 = torch.chunk(fmap4, chunks=2, dim=0)
313 | fmap1_8, fmap2_8 = torch.chunk(fmap, chunks=2, dim=0)
314 | inp8 = self.C_inp(inp)
315 | net = self.C_net(inp)
316 |
317 | corr_fn = self.trans(fmap1_8, fmap2_8, inp8)
318 |
319 | coords0, coords1 = self.initialize_flow8(image1)
320 | coords0 = coords0.permute(0, 3, 1, 2).contiguous()
321 | coords1 = coords1.permute(0, 3, 1, 2).contiguous()
322 |
323 | flow_list = []
324 | if flow_init is not None:
325 | if flow_init.shape[-2:] != coords1.shape[-2:]:
326 | flow_init = F.interpolate(flow_init, coords1.shape[-2:], mode='bilinear', align_corners=True) * 0.5
327 | coords1 = coords1 + flow_init
328 |
329 | if self.diffusion:
330 | self.corr_fn = corr_fn
331 | self.device = fmap1_8.device
332 | h, w = fmap1_8.shape[-2:]
333 | self.norm_const = torch.as_tensor([w, h], dtype=torch.float, device=self.device).view(1, 2, 1, 1)
334 |
335 | if self.training:
336 | coords1 = coords1.detach()
337 | flow_up_s, coords1, net = self._train_dfm(fmap1_8.shape, flow_gt, net, inp8, coords0, coords1)
338 | else:
339 | coords1, net, flow_up_s = self._ddim_sample(fmap1_8.shape, net, inp8, coords0, coords1)
340 |
341 | if self.sp4:
342 | flow4 = torch.nn.functional.interpolate(2 * (coords1 - coords0), scale_factor=2, mode='bilinear', align_corners=True)
343 | coords0, coords1 = self.initialize_flow4(image1)
344 | coords0 = coords0.permute(0, 3, 1, 2).contiguous()
345 | coords1 = coords1.permute(0, 3, 1, 2).contiguous()
346 | coords1 = coords1 + flow4
347 |
348 | net = torch.nn.functional.interpolate(net, scale_factor=2, mode='bilinear', align_corners=True)
349 | coords1_rd = ste_round(coords1)
350 |
351 | corr_fn4 = CorrBlock_FD_Sp4(fmap1_4, fmap2_4, num_levels=self.args.corr_levels, radius=self.args.corr_radius, coords_init=coords1_rd, rad=self.rad)
352 |
353 | for itr in range(self.args.iters_const6):
354 | coords1 = coords1.detach()
355 | corr = corr_fn4(coords1 - coords1_rd + self.rad)
356 |
357 | flow = coords1 - coords0
358 | with autocast(enabled=self.args.mixed_precision):
359 | net, delta_flow = self.update(net, inp4, corr, flow, itr, sp4=True)
360 | up_mask = self.um4(net)
361 |
362 | coords1 = coords1 + delta_flow
363 | flow_up = self.up_sample_flow4(coords1 - coords0, up_mask)
364 | flow_up_s.append(flow_up)
365 |
366 | flow_list = flow_list + flow_up_s
367 |
368 | if test_mode:
369 | flow = coords1 - coords0
370 | return flow, flow_list[-1]
371 |
372 | return flow_list
373 |
374 |
--------------------------------------------------------------------------------
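The diffusion bookkeeping above follows the standard DDPM forward process: `_q_sample` draws x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise under the cosine beta schedule. A small standalone sketch (the schedule is duplicated here so the snippet runs without building the model):

```python
import math
import torch

def cosine_beta_schedule(timesteps, s=0.008):
    # same cosine schedule as above (Nichol & Dhariwal)
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    ac = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    ac = ac / ac[0]
    betas = 1 - (ac[1:] / ac[:-1])
    return torch.clip(betas, 0, 0.999)

alphas_cumprod = torch.cumprod(1. - cosine_beta_schedule(1000), dim=0)

x0 = torch.randn(1, 2, 8, 8)      # a toy normalized flow field
noise = torch.randn_like(x0)
for t in (0, 499, 999):           # alphas_cumprod[t] decays toward 0, so x_t approaches pure noise
    x_t = alphas_cumprod[t].sqrt() * x0 + (1 - alphas_cumprod[t]).sqrt() * noise
    print(t, round(float(alphas_cumprod[t]), 4))
```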
/core/module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import time
6 | import os
7 | from torch.nn.parameter import Parameter
8 | import math
9 | import numpy as np
10 | import cv2
11 |
12 |
13 | class KPAFlowDec(nn.Module):
14 | def __init__(self, args, chnn=128):
15 | super().__init__()
16 | self.args = args
17 | cor_planes = 4 * (2 * args.corr_radius + 1) ** 2
18 | self.C_cor = nn.Sequential(
19 | nn.Conv2d(cor_planes, 256, 1),
20 | nn.ReLU(inplace=True),
21 | nn.Conv2d(256, 192, 3, padding=1),
22 | nn.ReLU(inplace=True))
23 | self.C_flo = nn.Sequential(
24 | nn.Conv2d(2, 128, 7, padding=3),
25 | nn.ReLU(inplace=True),
26 | nn.Conv2d(128, 64, 3, padding=1),
27 | nn.ReLU(inplace=True))
28 | self.C_mo = nn.Sequential(
29 | nn.Conv2d(192+64, 128-2, 3, padding=1),
30 | nn.ReLU(inplace=True))
31 |
32 | self.kpa = KPA(args, chnn)
33 | self.gru = SepConvGRU(hidden_dim=chnn, input_dim=chnn+chnn+chnn)
34 | self.C_flow = nn.Sequential(
35 | nn.Conv2d(chnn, chnn*2, 3, padding=1),
36 | nn.ReLU(inplace=True),
37 | nn.Conv2d(chnn*2, 2, 3, padding=1))
38 | self.C_mask = nn.Sequential(
39 | nn.Conv2d(chnn, chnn*2, 3, padding=1),
40 | nn.ReLU(inplace=True),
41 | nn.Conv2d(chnn*2, 64*9, 1, padding=0))
42 |
43 | def _mo_enc(self, flow, corr, itr):
44 | feat_cor = self.C_cor(corr)
45 | feat_flo = self.C_flo(flow)
46 | feat_cat = torch.cat([feat_cor, feat_flo], dim=1)
47 | feat_mo = self.C_mo(feat_cat)
48 | feat_mo = torch.cat([feat_mo, flow], dim=1)
49 | return feat_mo
50 |
51 | def forward(self, net, inp, corr, flow, itr, upsample=True):
52 | feat_mo = self._mo_enc(flow, corr, itr)
53 | feat_moa = self.kpa(inp, feat_mo, itr)
54 | inp = torch.cat([inp, feat_mo, feat_moa], dim=1)
55 | net = self.gru(net, inp)
56 | delta_flow = self.C_flow(net)
57 |
58 |         # scale mask to balance gradients
59 | mask = .25 * self.C_mask(net)
60 | return net, mask, delta_flow
61 |
62 |
63 | class KPA(nn.Module):
64 | def __init__(self, args, chnn):
65 | super().__init__()
66 | self.unfold_type = 'x311'
67 | if 'kitti' in args.dataset:
68 | self.sc = 15
69 | else:
70 | self.sc = 19
71 |
72 | self.unfold = nn.Unfold(kernel_size=3*self.sc, dilation=1, padding=self.sc, stride=self.sc)
73 | self.scale = chnn ** -0.5
74 | self.to_qk = nn.Conv2d(chnn, chnn * 2, 1, bias=False)
75 | self.to_v = nn.Conv2d(chnn, chnn, 1, bias=False)
76 | self.gamma = nn.Parameter(torch.zeros(1))
77 |
78 | h_k = (3 * self.sc - 1) / 2
79 | self.w_prelu = nn.Parameter(torch.zeros(1) + 1/h_k)
80 | self.scp = 0.02
81 | self.b = 1.
82 |
83 | def _FS(self, attn, shape):
84 | b, c, h, w, h_sc, w_sc = shape
85 | device = attn.device
86 | k = int(math.sqrt(attn.shape[1]))
87 | crd_k = torch.linspace(0, k-1, k).to(device)
88 | x = crd_k.view(1, 1, k, 1, 1).expand(b, 1, k, h, w)
89 | y = crd_k.view(1, k, 1, 1, 1).expand(b, k, 1, h, w)
90 |
91 | sc = torch.tensor(self.sc).to(device)
92 | idx_x = sc.view(1, 1, 1, 1, 1).expand(b, 1, 1, h, w)
93 | idx_y = sc.view(1, 1, 1, 1, 1).expand(b, 1, 1, h, w)
94 | crd_w = torch.linspace(0, w-1, w).to(device)
95 | crd_h = torch.linspace(0, h-1, h).to(device)
96 | idx_x = idx_x + crd_w.view(1, 1, 1, 1, w).expand(b, 1, 1, h, w) % self.sc
97 | idx_y = idx_y + crd_h.view(1, 1, 1, h, 1).expand(b, 1, 1, h, w) % self.sc
98 |
99 | half_ker = torch.tensor(self.sc * 2).to(device)
100 | o_x = -1 * F.prelu(abs(x - idx_x) - half_ker, self.w_prelu * self.scp) + self.b
101 | o_x[o_x < 0] = 0
102 | o_y = -1 * F.prelu(abs(y - idx_y) - half_ker, self.w_prelu * self.scp) + self.b
103 | o_y[o_y < 0] = 0
104 | ker_S = o_x * o_y
105 | ker_S = ker_S.view(b, k**2, h_sc, self.sc, w_sc, self.sc).permute(0, 1, 2, 4, 3, 5).contiguous().view(b, k**2, h_sc*w_sc, self.sc**2)
106 | return ker_S
107 |
108 | def forward(self, *inputs):
109 | feat_ci, feat_mi, itr = inputs
110 | b, c, h_in, w_in = feat_mi.shape
111 |
112 | x_pad = self.sc - w_in % self.sc
113 | y_pad = self.sc - h_in % self.sc
114 | feat_c = F.pad(feat_ci, (0, x_pad, 0, y_pad))
115 | feat_m = F.pad(feat_mi, (0, x_pad, 0, y_pad))
116 | b, c, h, w = feat_c.shape
117 | h_sc = h // self.sc
118 | w_sc = w // self.sc
119 |
120 | fm = torch.ones(1, 1, h_in, w_in).to(feat_m.device)
121 | fm = F.pad(fm, (0, x_pad, 0, y_pad))
122 | fm_k = self.unfold(fm).view(1, 1, -1, h_sc*w_sc)
123 | fm_q = fm.view(1, 1, h_sc, self.sc, w_sc, self.sc).permute(0, 1, 2, 4, 3, 5).contiguous().view(1, 1, h_sc*w_sc, self.sc**2)
124 | am = torch.einsum('b c k n, b c n s -> b k n s', fm_k, fm_q)
125 | am = (am - 1) * 99.
126 | am = am.repeat(b, 1, 1, 1)
127 |
128 | if itr == 0:
129 | feat_q, feat_k = self.to_qk(feat_c).chunk(2, dim=1)
130 | feat_k = self.unfold(feat_k).view(b, c, -1, h_sc*w_sc)
131 | feat_k = self.scale * feat_k
132 | feat_q = feat_q.view(b, c, h_sc, self.sc, w_sc, self.sc).permute(0, 1, 2, 4, 3, 5).contiguous().view(b, c, h_sc*w_sc, self.sc**2)
133 | attn = torch.einsum('b c k n, b c n s -> b k n s', feat_k, feat_q)
134 | attn = attn + am
135 |
136 | ker_S = self._FS(attn, [b, c, h, w, h_sc, w_sc])
137 | attn_kpa = ker_S.view(attn.shape) * attn
138 | self.attn = F.softmax(attn_kpa, dim=1)
139 |
140 | feat_v = self.to_v(feat_m)
141 | feat_v = self.unfold(feat_v).view(b, c, -1, h_sc*w_sc)
142 | feat_r = torch.einsum('b k n s, b c k n -> b c n s', self.attn, feat_v)
143 | feat_r = feat_r.view(b, c, h_sc, w_sc, self.sc, self.sc).permute(0, 1, 2, 4, 3, 5).contiguous().view(b, c, h, w)
144 | feat_r = feat_r[:,:,:h_in,:w_in]
145 |
146 | feat_o = feat_mi + feat_r * self.gamma
147 | return feat_o
148 |
149 |
150 | class KPAEnc(nn.Module):
151 | def __init__(self, args, chnn, sc):
152 | super().__init__()
153 | self.sc = sc
154 | self.unfold = nn.Unfold(kernel_size=3*self.sc, dilation=1, padding=self.sc, stride=self.sc)
155 | self.scale = chnn ** -0.5
156 | self.to_qk = nn.Conv2d(chnn, chnn * 2, 1, bias=False)
157 | self.to_v = nn.Conv2d(chnn, chnn, 1, bias=False)
158 | self.gamma = nn.Parameter(torch.zeros(1))
159 | self.mask_k = True
160 |
161 | def forward(self, inputs):
162 | feat_i = inputs
163 | b, c, h_in, w_in = feat_i.shape
164 | x_pad = self.sc - w_in % self.sc
165 | y_pad = self.sc - h_in % self.sc
166 | feat = F.pad(feat_i, (0, x_pad, 0, y_pad))
167 | b, c, h, w = feat.shape
168 | h_sc = h // self.sc
169 | w_sc = w // self.sc
170 |
171 | fm = torch.ones(1, 1, h_in, w_in).to(feat.device)
172 | fm = F.pad(fm, (0, x_pad, 0, y_pad))
173 | fm_k = self.unfold(fm).view(1, 1, -1, h_sc*w_sc)
174 | fm_q = fm.view(1, 1, h_sc, self.sc, w_sc, self.sc).permute(0, 1, 2, 4, 3, 5).contiguous().view(1, 1, h_sc*w_sc, self.sc**2)
175 | am = torch.einsum('b c k n, b c n s -> b k n s', fm_k, fm_q)
176 | am = (am - 1) * 99.
177 | am = am.repeat(b, 1, 1, 1)
178 |
179 | feat_q, feat_k = self.to_qk(feat).chunk(2, dim=1)
180 | feat_k = self.unfold(feat_k).view(b, c, -1, h_sc*w_sc)
181 | feat_k = self.scale * feat_k
182 | feat_q = feat_q.view(b, c, h_sc, self.sc, w_sc, self.sc).permute(0, 1, 2, 4, 3, 5).contiguous().view(b, c, h_sc*w_sc, self.sc**2)
183 | attn = torch.einsum('b c k n, b c n s -> b k n s', feat_k, feat_q)
184 |
185 | attn = attn + am
186 | self.attn = F.softmax(attn, dim=1)
187 |
188 | feat_v = self.to_v(feat)
189 | feat_v = self.unfold(feat_v).view(b, c, -1, h_sc*w_sc)
190 | feat_r = torch.einsum('b k n s, b c k n -> b c n s', self.attn, feat_v)
191 | feat_r = feat_r.view(b, c, h_sc, w_sc, self.sc, self.sc).permute(0, 1, 2, 4, 3, 5).contiguous().view(b, c, h, w)
192 | feat_r = feat_r[:,:,:h_in,:w_in]
193 | feat_o = feat_i + feat_r * self.gamma
194 | return feat_o
195 |
196 |
197 | class SepConvGRU(nn.Module):
198 | def __init__(self, hidden_dim=128, input_dim=192+128):
199 | super(SepConvGRU, self).__init__()
200 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
201 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
202 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
203 |
204 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
205 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
206 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
207 |
208 | def forward(self, h, x):
209 | # horizontal
210 | hx = torch.cat([h, x], dim=1)
211 | z = torch.sigmoid(self.convz1(hx))
212 | r = torch.sigmoid(self.convr1(hx))
213 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
214 | h = (1-z) * h + z * q
215 |
216 | # vertical
217 | hx = torch.cat([h, x], dim=1)
218 | z = torch.sigmoid(self.convz2(hx))
219 | r = torch.sigmoid(self.convr2(hx))
220 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
221 | h = (1-z) * h + z * q
222 |
223 | return h
224 |
--------------------------------------------------------------------------------
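The shape bookkeeping in `KPA`/`KPAEnc` hinges on `nn.Unfold` with kernel `3*sc`, padding `sc`, and stride `sc`: every `sc x sc` query window attends to its `3sc x 3sc` neighbourhood. A toy check of those shapes, using a small made-up `sc` instead of the 15/19 used above:

```python
import torch
import torch.nn as nn

sc = 4                                     # toy window stride (the real module uses 15 or 19)
c, h, w = 8, 12, 16                        # already padded to multiples of sc
unfold = nn.Unfold(kernel_size=3 * sc, dilation=1, padding=sc, stride=sc)

x = torch.randn(1, c, h, w)
cols = unfold(x)                           # (1, c * (3*sc)**2, (h//sc) * (w//sc))
k = (3 * sc) ** 2                          # 144 key positions per neighbourhood
print(cols.shape)                          # torch.Size([1, 1152, 12])
cols = cols.view(1, c, k, (h // sc) * (w // sc))   # matches feat_k in KPA.forward
print(cols.shape)                          # torch.Size([1, 8, 144, 12])
```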
/core/raft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from update import BasicUpdateBlock, SmallUpdateBlock
7 | from extractor import BasicEncoder, SmallEncoder
8 | from corr import CorrBlock, AlternateCorrBlock
9 | from utils.utils import bilinear_sampler, coords_grid, upflow8
10 |
11 | try:
12 | autocast = torch.cuda.amp.autocast
13 | except:
14 | # dummy autocast for PyTorch < 1.6
15 | class autocast:
16 | def __init__(self, enabled):
17 | pass
18 | def __enter__(self):
19 | pass
20 | def __exit__(self, *args):
21 | pass
22 |
23 |
24 | class RAFT(nn.Module):
25 | def __init__(self, args):
26 | super(RAFT, self).__init__()
27 | self.args = args
28 |
29 | if args.small:
30 | self.hidden_dim = hdim = 96
31 | self.context_dim = cdim = 64
32 | args.corr_levels = 4
33 | args.corr_radius = 3
34 |
35 | else:
36 | self.hidden_dim = hdim = 128
37 | self.context_dim = cdim = 128
38 | args.corr_levels = 4
39 | args.corr_radius = 4
40 |
41 | if 'dropout' not in self.args:
42 | self.args.dropout = 0
43 |
44 | if 'alternate_corr' not in self.args:
45 | self.args.alternate_corr = False
46 |
47 | # feature network, context network, and update block
48 | if args.small:
49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
52 |
53 | else:
54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
57 |
58 | def freeze_bn(self):
59 | for m in self.modules():
60 | if isinstance(m, nn.BatchNorm2d):
61 | m.eval()
62 |
63 | def initialize_flow(self, img):
64 |         """ Flow is represented as the difference between two coordinate grids: flow = coords1 - coords0 """
65 | N, C, H, W = img.shape
66 | coords0 = coords_grid(N, H//8, W//8, device=img.device)
67 | coords1 = coords_grid(N, H//8, W//8, device=img.device)
68 |
69 | # optical flow computed as difference: flow = coords1 - coords0
70 | return coords0, coords1
71 |
72 | def upsample_flow(self, flow, mask):
73 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
74 | N, _, H, W = flow.shape
75 | mask = mask.view(N, 1, 9, 8, 8, H, W)
76 | mask = torch.softmax(mask, dim=2)
77 |
78 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
79 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
80 |
81 | up_flow = torch.sum(mask * up_flow, dim=2)
82 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
83 | return up_flow.reshape(N, 2, 8*H, 8*W)
84 |
85 |
86 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
87 |         """ Estimate optical flow between a pair of frames """
88 |
89 | image1 = 2 * (image1 / 255.0) - 1.0
90 | image2 = 2 * (image2 / 255.0) - 1.0
91 |
92 | image1 = image1.contiguous()
93 | image2 = image2.contiguous()
94 |
95 | hdim = self.hidden_dim
96 | cdim = self.context_dim
97 |
98 | # run the feature network
99 | with autocast(enabled=self.args.mixed_precision):
100 | fmap1, fmap2 = self.fnet([image1, image2])
101 |
102 | fmap1 = fmap1.float()
103 | fmap2 = fmap2.float()
104 | if self.args.alternate_corr:
105 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
106 | else:
107 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
108 |
109 | # run the context network
110 | with autocast(enabled=self.args.mixed_precision):
111 | cnet = self.cnet(image1)
112 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
113 | net = torch.tanh(net)
114 | inp = torch.relu(inp)
115 |
116 | coords0, coords1 = self.initialize_flow(image1)
117 |
118 | if flow_init is not None:
119 | coords1 = coords1 + flow_init
120 |
121 | flow_predictions = []
122 | for itr in range(iters):
123 | coords1 = coords1.detach()
124 | corr = corr_fn(coords1) # index correlation volume
125 |
126 | flow = coords1 - coords0
127 | with autocast(enabled=self.args.mixed_precision):
128 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
129 |
130 | # F(t+1) = F(t) + \Delta(t)
131 | coords1 = coords1 + delta_flow
132 |
133 | # upsample predictions
134 | if up_mask is None:
135 | flow_up = upflow8(coords1 - coords0)
136 | else:
137 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
138 |
139 | flow_predictions.append(flow_up)
140 |
141 | if test_mode:
142 | return coords1 - coords0, flow_up
143 |
144 | return flow_predictions
145 |
--------------------------------------------------------------------------------
/core/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class ConvGRU(nn.Module):
17 | def __init__(self, hidden_dim=128, input_dim=192+128):
18 | super(ConvGRU, self).__init__()
19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
22 |
23 | def forward(self, h, x):
24 | hx = torch.cat([h, x], dim=1)
25 |
26 | z = torch.sigmoid(self.convz(hx))
27 | r = torch.sigmoid(self.convr(hx))
28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
29 |
30 | h = (1-z) * h + z * q
31 | return h
32 |
33 | class SepConvGRU(nn.Module):
34 | def __init__(self, hidden_dim=128, input_dim=192+128):
35 | super(SepConvGRU, self).__init__()
36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39 |
40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43 |
44 |
45 | def forward(self, h, x):
46 | # horizontal
47 | hx = torch.cat([h, x], dim=1)
48 | z = torch.sigmoid(self.convz1(hx))
49 | r = torch.sigmoid(self.convr1(hx))
50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
51 | h = (1-z) * h + z * q
52 |
53 | # vertical
54 | hx = torch.cat([h, x], dim=1)
55 | z = torch.sigmoid(self.convz2(hx))
56 | r = torch.sigmoid(self.convr2(hx))
57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
58 | h = (1-z) * h + z * q
59 |
60 | return h
61 |
62 | class SmallMotionEncoder(nn.Module):
63 | def __init__(self, args):
64 | super(SmallMotionEncoder, self).__init__()
65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
69 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
70 |
71 | def forward(self, flow, corr):
72 | cor = F.relu(self.convc1(corr))
73 | flo = F.relu(self.convf1(flow))
74 | flo = F.relu(self.convf2(flo))
75 | cor_flo = torch.cat([cor, flo], dim=1)
76 | out = F.relu(self.conv(cor_flo))
77 | return torch.cat([out, flow], dim=1)
78 |
79 | class BasicMotionEncoder(nn.Module):
80 | def __init__(self, args):
81 | super(BasicMotionEncoder, self).__init__()
82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
88 |
89 | def forward(self, flow, corr):
90 | cor = F.relu(self.convc1(corr))
91 | cor = F.relu(self.convc2(cor))
92 | flo = F.relu(self.convf1(flow))
93 | flo = F.relu(self.convf2(flo))
94 |
95 | cor_flo = torch.cat([cor, flo], dim=1)
96 | out = F.relu(self.conv(cor_flo))
97 | return torch.cat([out, flow], dim=1)
98 |
99 | class SmallUpdateBlock(nn.Module):
100 | def __init__(self, args, hidden_dim=96):
101 | super(SmallUpdateBlock, self).__init__()
102 | self.encoder = SmallMotionEncoder(args)
103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
105 |
106 | def forward(self, net, inp, corr, flow):
107 | motion_features = self.encoder(flow, corr)
108 | inp = torch.cat([inp, motion_features], dim=1)
109 | net = self.gru(net, inp)
110 | delta_flow = self.flow_head(net)
111 |
112 | return net, None, delta_flow
113 |
114 | class BasicUpdateBlock(nn.Module):
115 | def __init__(self, args, hidden_dim=128, input_dim=128):
116 | super(BasicUpdateBlock, self).__init__()
117 | self.args = args
118 | self.encoder = BasicMotionEncoder(args)
119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
121 |
122 | self.mask = nn.Sequential(
123 | nn.Conv2d(128, 256, 3, padding=1),
124 | nn.ReLU(inplace=True),
125 | nn.Conv2d(256, 64*9, 1, padding=0))
126 |
127 | def forward(self, net, inp, corr, flow, upsample=True):
128 | motion_features = self.encoder(flow, corr)
129 | inp = torch.cat([inp, motion_features], dim=1)
130 |
131 | net = self.gru(net, inp)
132 | delta_flow = self.flow_head(net)
133 |
134 |         # scale mask to balance gradients
135 | mask = .25 * self.mask(net)
136 | return net, mask, delta_flow
137 |
138 |
139 |
140 |
--------------------------------------------------------------------------------
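The GRU blocks above gate the hidden state as h <- (1 - z) * h + z * q, with sigmoid gates z, r and a tanh candidate q, so the state stays in (-1, 1). A quick usage sketch of `ConvGRU`, assuming `core/` is on `sys.path`:

```python
import torch
from update import ConvGRU   # assumes 'core' is on sys.path

gru = ConvGRU(hidden_dim=128, input_dim=192 + 128)
h = torch.tanh(torch.randn(1, 128, 32, 48))   # hidden state, bounded like the tanh-initialized net in RAFT
x = torch.randn(1, 320, 32, 48)               # motion features concatenated with context
h_new = gru(h, x)
print(h_new.shape, float(h_new.abs().max()) < 1.0)   # torch.Size([1, 128, 32, 48]) True
```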
/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LA30/FlowDiffuser/9aff9c6e8c68f809e40bb0ae4273621276686168/core/utils/__init__.py
--------------------------------------------------------------------------------
/core/utils/augmentor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import math
4 | from PIL import Image
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | import torch
11 | from torchvision.transforms import ColorJitter
12 | import torch.nn.functional as F
13 |
14 |
15 | class FlowAugmentor:
16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
17 |
18 | # spatial augmentation params
19 | self.crop_size = crop_size
20 | self.min_scale = min_scale
21 | self.max_scale = max_scale
22 | self.spatial_aug_prob = 0.8
23 | self.stretch_prob = 0.8
24 | self.max_stretch = 0.2
25 |
26 | # flip augmentation params
27 | self.do_flip = do_flip
28 | self.h_flip_prob = 0.5
29 | self.v_flip_prob = 0.1
30 |
31 | # photometric augmentation params
32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
33 | self.asymmetric_color_aug_prob = 0.2
34 | self.eraser_aug_prob = 0.5
35 |
36 | def color_transform(self, img1, img2):
37 | """ Photometric augmentation """
38 |
39 | # asymmetric
40 | if np.random.rand() < self.asymmetric_color_aug_prob:
41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
43 |
44 | # symmetric
45 | else:
46 | image_stack = np.concatenate([img1, img2], axis=0)
47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
48 | img1, img2 = np.split(image_stack, 2, axis=0)
49 |
50 | return img1, img2
51 |
52 | def eraser_transform(self, img1, img2, bounds=[50, 100]):
53 | """ Occlusion augmentation """
54 |
55 | ht, wd = img1.shape[:2]
56 | if np.random.rand() < self.eraser_aug_prob:
57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
58 | for _ in range(np.random.randint(1, 3)):
59 | x0 = np.random.randint(0, wd)
60 | y0 = np.random.randint(0, ht)
61 | dx = np.random.randint(bounds[0], bounds[1])
62 | dy = np.random.randint(bounds[0], bounds[1])
63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
64 |
65 | return img1, img2
66 |
67 | def spatial_transform(self, img1, img2, flow):
68 | # randomly sample scale
69 | ht, wd = img1.shape[:2]
70 | min_scale = np.maximum(
71 | (self.crop_size[0] + 8) / float(ht),
72 | (self.crop_size[1] + 8) / float(wd))
73 |
74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
75 | scale_x = scale
76 | scale_y = scale
77 | if np.random.rand() < self.stretch_prob:
78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
80 |
81 | scale_x = np.clip(scale_x, min_scale, None)
82 | scale_y = np.clip(scale_y, min_scale, None)
83 |
84 | if np.random.rand() < self.spatial_aug_prob:
85 | # rescale the images
86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
89 | flow = flow * [scale_x, scale_y]
90 |
91 | if self.do_flip:
92 | if np.random.rand() < self.h_flip_prob: # h-flip
93 | img1 = img1[:, ::-1]
94 | img2 = img2[:, ::-1]
95 | flow = flow[:, ::-1] * [-1.0, 1.0]
96 |
97 | if np.random.rand() < self.v_flip_prob: # v-flip
98 | img1 = img1[::-1, :]
99 | img2 = img2[::-1, :]
100 | flow = flow[::-1, :] * [1.0, -1.0]
101 |
102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
104 |
105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
106 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
108 |
109 | return img1, img2, flow
110 |
111 | def __call__(self, img1, img2, flow):
112 | img1, img2 = self.color_transform(img1, img2)
113 | img1, img2 = self.eraser_transform(img1, img2)
114 | img1, img2, flow = self.spatial_transform(img1, img2, flow)
115 |
116 | img1 = np.ascontiguousarray(img1)
117 | img2 = np.ascontiguousarray(img2)
118 | flow = np.ascontiguousarray(flow)
119 |
120 | return img1, img2, flow
121 |
122 | class SparseFlowAugmentor:
123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
124 | # spatial augmentation params
125 | self.crop_size = crop_size
126 | self.min_scale = min_scale
127 | self.max_scale = max_scale
128 | self.spatial_aug_prob = 0.8
129 | self.stretch_prob = 0.8
130 | self.max_stretch = 0.2
131 |
132 | # flip augmentation params
133 | self.do_flip = do_flip
134 | self.h_flip_prob = 0.5
135 | self.v_flip_prob = 0.1
136 |
137 | # photometric augmentation params
138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
139 | self.asymmetric_color_aug_prob = 0.2
140 | self.eraser_aug_prob = 0.5
141 |
142 | def color_transform(self, img1, img2):
143 | image_stack = np.concatenate([img1, img2], axis=0)
144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
145 | img1, img2 = np.split(image_stack, 2, axis=0)
146 | return img1, img2
147 |
148 | def eraser_transform(self, img1, img2):
149 | ht, wd = img1.shape[:2]
150 | if np.random.rand() < self.eraser_aug_prob:
151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
152 | for _ in range(np.random.randint(1, 3)):
153 | x0 = np.random.randint(0, wd)
154 | y0 = np.random.randint(0, ht)
155 | dx = np.random.randint(50, 100)
156 | dy = np.random.randint(50, 100)
157 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
158 |
159 | return img1, img2
160 |
161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
162 | ht, wd = flow.shape[:2]
163 | coords = np.meshgrid(np.arange(wd), np.arange(ht))
164 | coords = np.stack(coords, axis=-1)
165 |
166 | coords = coords.reshape(-1, 2).astype(np.float32)
167 | flow = flow.reshape(-1, 2).astype(np.float32)
168 | valid = valid.reshape(-1).astype(np.float32)
169 |
170 | coords0 = coords[valid>=1]
171 | flow0 = flow[valid>=1]
172 |
173 | ht1 = int(round(ht * fy))
174 | wd1 = int(round(wd * fx))
175 |
176 | coords1 = coords0 * [fx, fy]
177 | flow1 = flow0 * [fx, fy]
178 |
179 | xx = np.round(coords1[:,0]).astype(np.int32)
180 | yy = np.round(coords1[:,1]).astype(np.int32)
181 |
182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
183 | xx = xx[v]
184 | yy = yy[v]
185 | flow1 = flow1[v]
186 |
187 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32)
189 |
190 | flow_img[yy, xx] = flow1
191 | valid_img[yy, xx] = 1
192 |
193 | return flow_img, valid_img
194 |
195 | def spatial_transform(self, img1, img2, flow, valid):
196 | # randomly sample scale
197 |
198 | ht, wd = img1.shape[:2]
199 | min_scale = np.maximum(
200 | (self.crop_size[0] + 1) / float(ht),
201 | (self.crop_size[1] + 1) / float(wd))
202 |
203 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
204 | scale_x = np.clip(scale, min_scale, None)
205 | scale_y = np.clip(scale, min_scale, None)
206 |
207 | if np.random.rand() < self.spatial_aug_prob:
208 | # rescale the images
209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
212 |
213 | if self.do_flip:
214 | if np.random.rand() < 0.5: # h-flip
215 | img1 = img1[:, ::-1]
216 | img2 = img2[:, ::-1]
217 | flow = flow[:, ::-1] * [-1.0, 1.0]
218 | valid = valid[:, ::-1]
219 |
220 | margin_y = 20
221 | margin_x = 50
222 |
223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
225 |
226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
228 |
229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
233 | return img1, img2, flow, valid
234 |
235 |
236 | def __call__(self, img1, img2, flow, valid):
237 | img1, img2 = self.color_transform(img1, img2)
238 | img1, img2 = self.eraser_transform(img1, img2)
239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
240 |
241 | img1 = np.ascontiguousarray(img1)
242 | img2 = np.ascontiguousarray(img2)
243 | flow = np.ascontiguousarray(flow)
244 | valid = np.ascontiguousarray(valid)
245 |
246 | return img1, img2, flow, valid
247 |
--------------------------------------------------------------------------------
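Note that `spatial_transform` multiplies the flow values by `[scale_x, scale_y]` whenever it resizes the images, because flow vectors are measured in pixels of the (now rescaled) frame. A toy illustration of that scaling:

```python
import numpy as np
import cv2

flow = np.zeros((100, 200, 2), np.float32)
flow[...] = [4.0, 2.0]                        # constant flow: 4 px right, 2 px down

fx = fy = 0.5                                 # shrink the frame to half size
flow_r = cv2.resize(flow, None, fx=fx, fy=fy, interpolation=cv2.INTER_LINEAR)
flow_r = flow_r * [fx, fy]                    # displacements shrink by the same factors
print(flow_r.shape, flow_r[0, 0])             # (50, 100, 2) [2. 1.]
```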
/core/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 |     Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 |     Expects a two dimensional flow image of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
--------------------------------------------------------------------------------
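A short usage sketch of `flow_to_image` on a synthetic radial flow field (color encodes direction, saturation encodes magnitude), assuming `core/` is on `sys.path` as in `evaluate.py`:

```python
import numpy as np
from utils import flow_viz   # assumes 'core' is on sys.path

h, w = 128, 192
y, x = np.mgrid[0:h, 0:w].astype(np.float32)
flow = np.stack([x - w / 2, y - h / 2], axis=-1) / 32.0   # radial flow, shape (H, W, 2)
img = flow_viz.flow_to_image(flow)                        # uint8 RGB image, shape (H, W, 3)
print(img.shape, img.dtype)
```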
/core/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | TAG_CHAR = np.array([202021.25], np.float32)
11 |
12 | def readFlow(fn):
13 | """ Read .flo file in Middlebury format"""
14 | # Code adapted from:
15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16 |
17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18 | # print 'fn = %s'%(fn)
19 | with open(fn, 'rb') as f:
20 | magic = np.fromfile(f, np.float32, count=1)
21 | if 202021.25 != magic:
22 | print('Magic number incorrect. Invalid .flo file')
23 | return None
24 | else:
25 | w = np.fromfile(f, np.int32, count=1)
26 | h = np.fromfile(f, np.int32, count=1)
27 | # print 'Reading %d x %d flo file\n' % (w, h)
28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
29 | # Reshape data into 3D array (columns, rows, bands)
30 | # The reshape here is for visualization, the original code is (w,h,2)
31 | return np.resize(data, (int(h), int(w), 2))
32 |
33 | def readPFM(file):
34 | file = open(file, 'rb')
35 |
36 | color = None
37 | width = None
38 | height = None
39 | scale = None
40 | endian = None
41 |
42 | header = file.readline().rstrip()
43 | if header == b'PF':
44 | color = True
45 | elif header == b'Pf':
46 | color = False
47 | else:
48 | raise Exception('Not a PFM file.')
49 |
50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
51 | if dim_match:
52 | width, height = map(int, dim_match.groups())
53 | else:
54 | raise Exception('Malformed PFM header.')
55 |
56 | scale = float(file.readline().rstrip())
57 | if scale < 0: # little-endian
58 | endian = '<'
59 | scale = -scale
60 | else:
61 | endian = '>' # big-endian
62 |
63 | data = np.fromfile(file, endian + 'f')
64 | shape = (height, width, 3) if color else (height, width)
65 |
66 | data = np.reshape(data, shape)
67 | data = np.flipud(data)
68 | return data
69 |
70 | def writeFlow(filename,uv,v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert(uv.ndim == 3)
81 | assert(uv.shape[2] == 2)
82 | u = uv[:,:,0]
83 | v = uv[:,:,1]
84 | else:
85 | u = uv
86 |
87 | assert(u.shape == v.shape)
88 | height,width = u.shape
89 | f = open(filename,'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width*nBands))
96 | tmp[:,np.arange(width)*2] = u
97 | tmp[:,np.arange(width)*2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
104 | flow = flow[:,:,::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2**15) / 64.0
107 | return flow, valid
108 |
109 | def readDispKITTI(filename):
110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
111 | valid = disp > 0.0
112 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
113 | return flow, valid
114 |
115 |
116 | def writeFlowKITTI(filename, uv):
117 | uv = 64.0 * uv + 2**15
118 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
120 | cv2.imwrite(filename, uv[..., ::-1])
121 |
122 |
123 | def read_gen(file_name, pil=False):
124 | ext = splitext(file_name)[-1]
125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
126 | return Image.open(file_name)
127 | elif ext == '.bin' or ext == '.raw':
128 | return np.load(file_name)
129 | elif ext == '.flo':
130 | return readFlow(file_name).astype(np.float32)
131 | elif ext == '.pfm':
132 | flow = readPFM(file_name).astype(np.float32)
133 | if len(flow.shape) == 2:
134 | return flow
135 | else:
136 | return flow[:, :, :-1]
137 | return []
--------------------------------------------------------------------------------
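`writeFlowKITTI` and `readFlowKITTI` use the KITTI 16-bit PNG encoding `value = 64 * flow + 2**15`, decoded as `(value - 2**15) / 64`. A file-free round-trip check of that arithmetic:

```python
import numpy as np

flow = np.array([[[-3.25, 10.5]]], dtype=np.float32)     # one pixel, (u, v) in pixels
encoded = (64.0 * flow + 2**15).astype(np.uint16)        # what writeFlowKITTI stores
decoded = (encoded.astype(np.float32) - 2**15) / 64.0    # what readFlowKITTI recovers
print(decoded)                                           # [[[-3.25  10.5 ]]]
```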
/core/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 | def __init__(self, dims, mode='sintel'):
10 | self.ht, self.wd = dims[-2:]
11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13 | if mode == 'sintel':
14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15 | else:
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17 |
18 | def pad(self, *inputs):
19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
20 |
21 | def unpad(self,x):
22 | ht, wd = x.shape[-2:]
23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
24 | return x[..., c[0]:c[1], c[2]:c[3]]
25 |
26 | def forward_interpolate(flow):
27 | flow = flow.detach().cpu().numpy()
28 | dx, dy = flow[0], flow[1]
29 |
30 | ht, wd = dx.shape
31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
32 |
33 | x1 = x0 + dx
34 | y1 = y0 + dy
35 |
36 | x1 = x1.reshape(-1)
37 | y1 = y1.reshape(-1)
38 | dx = dx.reshape(-1)
39 | dy = dy.reshape(-1)
40 |
41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
42 | x1 = x1[valid]
43 | y1 = y1[valid]
44 | dx = dx[valid]
45 | dy = dy[valid]
46 |
47 | flow_x = interpolate.griddata(
48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
49 |
50 | flow_y = interpolate.griddata(
51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
52 |
53 | flow = np.stack([flow_x, flow_y], axis=0)
54 | return torch.from_numpy(flow).float()
55 |
56 |
57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
58 | """ Wrapper for grid_sample, uses pixel coordinates """
59 | H, W = img.shape[-2:]
60 | xgrid, ygrid = coords.split([1,1], dim=-1)
61 | xgrid = 2*xgrid/(W-1) - 1
62 | ygrid = 2*ygrid/(H-1) - 1
63 |
64 | grid = torch.cat([xgrid, ygrid], dim=-1)
65 | img = F.grid_sample(img, grid, align_corners=True)
66 |
67 | if mask:
68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
69 | return img, mask.float()
70 |
71 | return img
72 |
73 |
74 | def coords_grid(batch, ht, wd, device):
75 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
76 | coords = torch.stack(coords[::-1], dim=0).float()
77 | return coords[None].repeat(batch, 1, 1, 1)
78 |
79 |
80 | def upflow8(flow, mode='bilinear'):
81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
83 |
--------------------------------------------------------------------------------
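As a usage note, `InputPadder` is applied the same way throughout `evaluate.py`: pad both frames to multiples of 8 before the forward pass, then crop the prediction back to the original resolution. A minimal sketch with stand-in tensors (the model call is only indicated):

```Python
# Hypothetical usage sketch, not part of the repo.
import torch
from utils.utils import InputPadder  # assumes sys.path.append('core')

image1 = torch.randn(1, 3, 436, 1024)  # Sintel-sized frames
image2 = torch.randn(1, 3, 436, 1024)

padder = InputPadder(image1.shape)            # 'sintel' mode pads evenly top/bottom
image1, image2 = padder.pad(image1, image2)   # now 440 x 1024, divisible by 8
# flow_low, flow_pr = model(image1, image2, iters=32, test_mode=True)
flow_pr = torch.randn(1, 2, *image1.shape[-2:])  # stand-in for the model output
flow = padder.unpad(flow_pr)                     # cropped back to 436 x 1024
```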
/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python evaluate.py --model=weights/FlowDiffuser-things.pth --dataset=sintel
3 | python evaluate.py --model=weights/FlowDiffuser-things.pth --dataset=kitti
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('core')
3 |
4 | from PIL import Image
5 | import argparse
6 | import os
7 | import time
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | import matplotlib.pyplot as plt
12 |
13 | import datasets
14 | from utils import flow_viz
15 | from utils import frame_utils
16 |
17 | from flowdiffuser import FlowDiffuser
18 |
19 | from utils.utils import InputPadder, forward_interpolate
20 |
21 |
22 | @torch.no_grad()
23 | def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):
24 | """ Create submission for the Sintel leaderboard """
25 | model.eval()
26 | for dstype in ['clean', 'final']:
27 | test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)
28 |
29 | flow_prev, sequence_prev = None, None
30 | for test_id in range(len(test_dataset)):
31 | image1, image2, (sequence, frame) = test_dataset[test_id]
32 | # if sequence != sequence_prev:
33 | # flow_prev = None
34 |
35 | if (sequence != sequence_prev) or (dstype == 'final' and sequence in ['market_4', ]) or dstype == 'clean':
36 | flow_prev = None
37 |
38 | padder = InputPadder(image1.shape)
39 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
40 |
41 | flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)
42 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
43 |
44 | if warm_start:
45 | flow_prev = forward_interpolate(flow_low[0])[None].cuda()
46 |
47 | output_dir = os.path.join(output_path, dstype, sequence)
48 | output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))
49 |
50 | if not os.path.exists(output_dir):
51 | os.makedirs(output_dir)
52 |
53 | frame_utils.writeFlow(output_file, flow)
54 | sequence_prev = sequence
55 |
56 |
57 | @torch.no_grad()
58 | def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
59 | """ Create submission for the Sintel leaderboard """
60 | model.eval()
61 | test_dataset = datasets.KITTI(split='testing', aug_params=None)
62 |
63 | if not os.path.exists(output_path):
64 | os.makedirs(output_path)
65 |
66 | for test_id in range(len(test_dataset)):
67 | image1, image2, (frame_id, ) = test_dataset[test_id]
68 | padder = InputPadder(image1.shape, mode='kitti')
69 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
70 |
71 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
72 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
73 |
74 | output_filename = os.path.join(output_path, frame_id)
75 | frame_utils.writeFlowKITTI(output_filename, flow)
76 |
77 |
78 | @torch.no_grad()
79 | def validate_chairs(model, iters=24):
80 | """ Perform evaluation on the FlyingChairs (test) split """
81 | model.eval()
82 | epe_list = []
83 |
84 | val_dataset = datasets.FlyingChairs(split='validation')
85 | for val_id in range(len(val_dataset)):
86 | image1, image2, flow_gt, _ = val_dataset[val_id]
87 | image1 = image1[None].cuda()
88 | image2 = image2[None].cuda()
89 |
90 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
91 | epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt()
92 | epe_list.append(epe.view(-1).numpy())
93 |
94 | epe = np.mean(np.concatenate(epe_list))
95 | print("Validation Chairs EPE: %f" % epe)
96 | return {'chairs': epe}
97 |
98 |
99 | @torch.no_grad()
100 | def validate_sintel(model, iters=32):
101 | """ Peform validation using the Sintel (train) split """
102 | model.eval()
103 | results = {}
104 | for dstype in ['clean', 'final']:
105 | val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
106 | epe_list = []
107 |
108 | for val_id in range(len(val_dataset)):
109 | image1, image2, flow_gt, _ = val_dataset[val_id]
110 | image1 = image1[None].cuda()
111 | image2 = image2[None].cuda()
112 |
113 | padder = InputPadder(image1.shape)
114 | image1, image2 = padder.pad(image1, image2)
115 |
116 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
117 | flow = padder.unpad(flow_pr[0]).cpu()
118 |
119 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
120 | epe_list.append(epe.view(-1).numpy())
121 |
122 | epe_all = np.concatenate(epe_list)
123 | epe = np.mean(epe_all)
124 | px1 = np.mean(epe_all<1)
125 | px3 = np.mean(epe_all<3)
126 | px5 = np.mean(epe_all<5)
127 |
128 | print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
129 | results[dstype] = epe
130 |
131 | return results
132 |
133 |
134 | @torch.no_grad()
135 | def validate_kitti(model, iters=24):
136 | """ Peform validation using the KITTI-2015 (train) split """
137 | model.eval()
138 | val_dataset = datasets.KITTI(split='training')
139 |
140 | out_list, epe_list = [], []
141 | for val_id in range(len(val_dataset)):
142 | image1, image2, flow_gt, valid_gt = val_dataset[val_id]
143 | image1 = image1[None].cuda()
144 | image2 = image2[None].cuda()
145 |
146 | padder = InputPadder(image1.shape, mode='kitti')
147 | image1, image2 = padder.pad(image1, image2)
148 |
149 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
150 | flow = padder.unpad(flow_pr[0]).cpu()
151 |
152 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
153 | mag = torch.sum(flow_gt**2, dim=0).sqrt()
154 |
155 | epe = epe.view(-1)
156 | mag = mag.view(-1)
157 | val = valid_gt.view(-1) >= 0.5
158 |
159 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
160 | epe_list.append(epe[val].mean().item())
161 | out_list.append(out[val].cpu().numpy())
162 |
163 | epe_list = np.array(epe_list)
164 | out_list = np.concatenate(out_list)
165 |
166 | epe = np.mean(epe_list)
167 | f1 = 100 * np.mean(out_list)
168 |
169 | print("Validation KITTI: %f, %f" % (epe, f1))
170 | return {'kitti-epe': epe, 'kitti-f1': f1}
171 |
172 |
173 | if __name__ == '__main__':
174 | parser = argparse.ArgumentParser()
175 | parser.add_argument('--model', help="restore checkpoint")
176 | parser.add_argument('--dataset', help="dataset for evaluation")
177 | parser.add_argument('--small', action='store_true', help='use small model')
178 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
179 | parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
180 | args = parser.parse_args()
181 |
182 | model = torch.nn.DataParallel(FlowDiffuser(args))
183 | model.load_state_dict(torch.load(args.model))
184 |
185 | model.cuda()
186 | model.eval()
187 |
188 | # create_sintel_submission(model.module, warm_start=True)
189 | # create_kitti_submission(model.module)
190 |
191 | with torch.no_grad():
192 | if args.dataset == 'chairs':
193 | validate_chairs(model.module)
194 |
195 | elif args.dataset == 'sintel':
196 | validate_sintel(model.module)
197 |
198 | elif args.dataset == 'kitti':
199 | validate_kitti(model.module)
200 |
201 |
202 |
--------------------------------------------------------------------------------
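A side note on the F1 number printed by `validate_kitti`: it is the KITTI Fl-all outlier rate, where a pixel counts as an outlier only if its endpoint error exceeds both 3 px and 5% of the ground-truth flow magnitude. A small NumPy sketch with made-up values illustrates the rule:

```Python
# Illustrative numbers only.
import numpy as np

epe = np.array([0.5, 4.0, 4.0, 10.0])      # endpoint error per pixel (px)
mag = np.array([20.0, 100.0, 50.0, 30.0])  # ground-truth flow magnitude (px)

# Same test as validate_kitti: > 3 px AND > 5% of the magnitude
out = (epe > 3.0) & ((epe / mag) > 0.05)
print(out)               # [False False  True  True]
print(100 * out.mean())  # Fl-all = 50.0 (% outliers)
```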
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import sys
3 | sys.path.append('core')
4 |
5 | import argparse
6 | import os
7 | import cv2
8 | import time
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.optim as optim
15 | import torch.nn.functional as F
16 |
17 | from torch.utils.data import DataLoader
18 | import evaluate
19 | import datasets
20 | from flowdiffuser import FlowDiffuser
21 |
22 | from torch.utils.tensorboard import SummaryWriter
23 |
24 | from torch.cuda.amp import GradScaler
25 |
26 |
27 | # exclude extremely large displacements
28 | MAX_FLOW = 400
29 | SUM_FREQ = 100
30 | VAL_FREQ = 5000
31 |
32 |
33 | def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
34 | """ Loss function defined over sequence of flow predictions """
35 |
36 | n_predictions = len(flow_preds)
37 | flow_loss = 0.0
38 |
39 | # exclude invalid pixels and extremely large displacements
40 | mag = torch.sum(flow_gt**2, dim=1).sqrt()
41 | valid = (valid >= 0.5) & (mag < max_flow)
42 |
43 | for i in range(n_predictions):
44 | i_weight = gamma**(n_predictions - i - 1)
45 | i_loss = (flow_preds[i] - flow_gt).abs()
46 | flow_loss += i_weight * (valid[:, None] * i_loss).mean()
47 |
48 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
49 | epe = epe.view(-1)[valid.view(-1)]
50 |
51 | metrics = {
52 | 'epe': epe.mean().item(),
53 | '1px': (epe < 1).float().mean().item(),
54 | '3px': (epe < 3).float().mean().item(),
55 | '5px': (epe < 5).float().mean().item(),
56 | }
57 |
58 | return flow_loss, metrics
59 |
60 |
61 | def count_parameters(model):
62 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
63 |
64 |
65 | def fetch_optimizer(args, model):
66 | """ Create the optimizer and learning rate scheduler """
67 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
68 |
69 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
70 | pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')
71 |
72 | return optimizer, scheduler
73 |
74 |
75 | class Logger:
76 | def __init__(self, model, scheduler):
77 | self.model = model
78 | self.scheduler = scheduler
79 | self.total_steps = 0
80 | self.running_loss = {}
81 | self.writer = None
82 |
83 | def _print_training_status(self):
84 | metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())]
85 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
86 | metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)
87 |
88 | # print the training status
89 | print(training_str + metrics_str)
90 |
91 | if self.writer is None:
92 | self.writer = SummaryWriter()
93 |
94 | for k in self.running_loss:
95 | self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps)
96 | self.running_loss[k] = 0.0
97 |
98 | def push(self, metrics):
99 | self.total_steps += 1
100 |
101 | for key in metrics:
102 | if key not in self.running_loss:
103 | self.running_loss[key] = 0.0
104 |
105 | self.running_loss[key] += metrics[key]
106 |
107 | if self.total_steps % SUM_FREQ == SUM_FREQ-1:
108 | self._print_training_status()
109 | self.running_loss = {}
110 |
111 | def write_dict(self, results):
112 | if self.writer is None:
113 | self.writer = SummaryWriter()
114 |
115 | for key in results:
116 | self.writer.add_scalar(key, results[key], self.total_steps)
117 |
118 | def close(self):
119 | self.writer.close()
120 |
121 |
122 | def train(args):
123 |
124 | model = nn.DataParallel(FlowDiffuser(args), device_ids=args.gpus)
125 | print("Parameter Count: %d" % count_parameters(model))
126 |
127 | if args.restore_ckpt is not None:
128 | model.load_state_dict(torch.load(args.restore_ckpt), strict=False)
129 |
130 | model.cuda()
131 | model.train()
132 |
133 | if args.stage != 'chairs':
134 | model.module.freeze_bn()
135 |
136 | train_loader = datasets.fetch_dataloader(args)
137 | optimizer, scheduler = fetch_optimizer(args, model)
138 |
139 | total_steps = 0
140 | scaler = GradScaler(enabled=args.mixed_precision)
141 | logger = Logger(model, scheduler)
142 |
143 | VAL_FREQ = 5000
144 | add_noise = True
145 |
146 | should_keep_training = True
147 | while should_keep_training:
148 |
149 | for i_batch, data_blob in enumerate(train_loader):
150 | optimizer.zero_grad()
151 | image1, image2, flow, valid = [x.cuda() for x in data_blob]
152 |
153 | if args.add_noise:
154 | stdv = np.random.uniform(0.0, 5.0)
155 | image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
156 | image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)
157 |
158 | flow_predictions = model(image1, image2, iters=args.iters, flow_gt=flow)
159 |
160 | loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
161 | scaler.scale(loss).backward()
162 | scaler.unscale_(optimizer)
163 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
164 |
165 | scaler.step(optimizer)
166 | scheduler.step()
167 | scaler.update()
168 |
169 | logger.push(metrics)
170 |
171 | if total_steps % VAL_FREQ == VAL_FREQ - 1:
172 | PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
173 | torch.save(model.state_dict(), PATH)
174 |
175 | results = {}
176 | for val_dataset in args.validation:
177 | if val_dataset == 'chairs':
178 | results.update(evaluate.validate_chairs(model.module))
179 | elif val_dataset == 'sintel':
180 | results.update(evaluate.validate_sintel(model.module))
181 | elif val_dataset == 'kitti':
182 | results.update(evaluate.validate_kitti(model.module))
183 |
184 | logger.write_dict(results)
185 |
186 | model.train()
187 | if args.stage != 'chairs':
188 | model.module.freeze_bn()
189 |
190 | total_steps += 1
191 |
192 | if total_steps > args.num_steps:
193 | should_keep_training = False
194 | break
195 |
196 | logger.close()
197 | PATH = 'checkpoints/%s.pth' % args.name
198 | torch.save(model.state_dict(), PATH)
199 |
200 | return PATH
201 |
202 |
203 | if __name__ == '__main__':
204 | parser = argparse.ArgumentParser()
205 | parser.add_argument('--name', default='flowdiffuser', help="name your experiment")
206 | parser.add_argument('--stage', help="determines which dataset to use for training")
207 | parser.add_argument('--restore_ckpt', help="restore checkpoint")
208 | parser.add_argument('--small', action='store_true', help='use small model')
209 | parser.add_argument('--validation', type=str, nargs='+')
210 |
211 | parser.add_argument('--lr', type=float, default=0.00002)
212 | parser.add_argument('--num_steps', type=int, default=100000)
213 | parser.add_argument('--batch_size', type=int, default=6)
214 | parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
215 | parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
216 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
217 |
218 | parser.add_argument('--iters', type=int, default=12)
219 | parser.add_argument('--wdecay', type=float, default=.00005)
220 | parser.add_argument('--epsilon', type=float, default=1e-8)
221 | parser.add_argument('--clip', type=float, default=1.0)
222 | parser.add_argument('--dropout', type=float, default=0.0)
223 | parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
224 | parser.add_argument('--add_noise', action='store_true')
225 | args = parser.parse_args()
226 |
227 | torch.manual_seed(1234)
228 | np.random.seed(1234)
229 |
230 | if not os.path.isdir('checkpoints'):
231 | os.mkdir('checkpoints')
232 |
233 | train(args)
--------------------------------------------------------------------------------
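For reference, the exponential weighting in `sequence_loss` gives the final prediction weight 1 and earlier iterations geometrically smaller weights, so later refinements dominate the loss. With the default `--iters 12` and `--gamma 0.8` the weights run from roughly `0.8**11 ≈ 0.086` up to `1.0`. A self-contained sketch of the schedule (the helper name is made up):

```Python
# Hypothetical helper mirroring the weighting used in sequence_loss.
def loss_weights(n_predictions, gamma=0.8):
    return [gamma ** (n_predictions - i - 1) for i in range(n_predictions)]

print(loss_weights(4))      # [0.512, 0.64, 0.8, 1.0] -- the last prediction dominates
print(loss_weights(12)[0])  # ~0.0859, the weight on the first of 12 iterations
```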
/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir -p checkpoints
3 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python -u train.py --name fd-chairs --stage chairs --validation chairs --gpus 0 1 2 3 4 5 --num_steps 100000 --batch_size 12 --lr 0.00045 --image_size 384 512 --wdecay 0.0001
4 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python -u train.py --name fd-things --stage things --validation sintel --restore_ckpt checkpoints/fd-chairs.pth --gpus 0 1 2 3 4 5 --num_steps 200000 --batch_size 6 --lr 0.000175 --image_size 432 960 --wdecay 0.0001
5 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python -u train.py --name fd-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/fd-things.pth --gpus 0 1 2 3 4 5 --num_steps 180000 --batch_size 6 --lr 0.000175 --image_size 432 960 --wdecay 0.00001 --gamma=0.85
6 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python -u train.py --name fd-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/fd-sintel.pth --gpus 0 1 2 3 4 5 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85
--------------------------------------------------------------------------------