├── .gitignore
├── LICENSE
├── README.md
├── alt_cuda_corr
│   ├── correlation.cpp
│   ├── correlation_kernel.cu
│   └── setup.py
├── chairs_split.txt
├── core
│   ├── __init__.py
│   ├── corr.py
│   ├── datasets.py
│   ├── extractor.py
│   ├── gma.py
│   ├── gmflownet_model.py
│   ├── loss.py
│   ├── onecyclelr.py
│   ├── onecyclelr.py.save
│   ├── raft_gma_model.py
│   ├── raft_model.py
│   ├── swin_transformer.py
│   ├── update.py
│   └── utils
│       ├── __init__.py
│       ├── augmentor.py
│       ├── drop.py
│       ├── flow_viz.py
│       ├── frame_utils.py
│       ├── helpers.py
│       ├── utils.py
│       └── weight_init.py
├── demo-frames
│   ├── frame_0016.png
│   ├── frame_0017.png
│   ├── frame_0018.png
│   ├── frame_0019.png
│   ├── frame_0020.png
│   ├── frame_0021.png
│   ├── frame_0022.png
│   ├── frame_0023.png
│   ├── frame_0024.png
│   └── frame_0025.png
├── demo.py
├── evaluate.py
├── train.py
├── train_gmflownet.sh
└── train_gmflownet_mix.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.egg-info
3 | dist
4 | datasets
5 | datasets/
6 | pretrained_models/
7 | build/
8 | correlation.egg-info
9 | checkpoints/
10 | runs/
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Shiyu Zhao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GMFlowNet
2 |
3 | This repository contains the official implementation for the paper:
4 |
5 | [Global Matching with Overlapping Attention for Optical Flow Estimation](https://arxiv.org/abs/2203.11335)
6 | CVPR 2022
7 | Shiyu Zhao, Long Zhao, Zhixing Zhang, Enyu Zhou, Dimitris Metaxas
8 |
9 | ## Requirements
10 | The code has been tested with PyTorch 1.7 and CUDA 11.0. Later versions of PyTorch may also work.
11 | ```Shell
12 | conda create --name gmflownet
13 | conda activate gmflownet
14 | conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=11.0 -c pytorch
15 | conda install matplotlib tensorboard scipy opencv
16 | ```
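
You can optionally run a quick check that PyTorch and CUDA are visible in the new environment:
```Shell
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```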
17 |
18 | ## Demos
19 | Download the .zip file with pretrained models from [Google Drive](https://drive.google.com/file/d/1rVfu0j9O1M2hNsew-dVRlx9jL7c6ICru/view?usp=sharing). Unzip `pretrained_models.zip` in the repository root:
20 | ```Shell
21 | unzip pretrained_models.zip
22 | ```
23 |
24 | You can demo a trained model on a sequence of frames:
25 | ```Shell
26 | python demo.py --model gmflownet --ckpt=pretrained_models/gmflownet-things.pth --path=demo-frames
27 | ```
28 |
29 | ## Required Data
30 | To evaluate/train GMFlowNet, you need to download the following datasets.
31 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
32 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
33 | * [Sintel](http://sintel.is.tue.mpg.de/)
34 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
35 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional)
36 |
37 | Place all datasets in your preferred directory and create a symbolic link to it at `./datasets` (e.g., `ln -s /path/to/your/datasets ./datasets`) so that your `./datasets` folder looks like
38 | ```Shell
39 | ├── datasets
40 | ├── Sintel
41 | ├── test
42 | ├── training
43 | ├── KITTI
44 | ├── testing
45 | ├── training
46 | ├── devkit
47 | ├── FlyingChairs_release
48 | ├── data
49 | ├── FlyingThings3D
50 | ├── frames_cleanpass
51 | ├── frames_finalpass
52 | ├── optical_flow
53 | ...
54 | ```
55 |
56 | ## Evaluation
57 | Download the pretrained models described in [Demos](https://github.com/xiaofeng94/GMFlowNet/blob/master/README.md#demos).
58 | You may evaluate a pretrained model using `evaluate.py`. To get the best results:
59 |
60 | On Sintel, evaluate the `gmflownet_mix` model as:
61 | ```Shell
62 | python evaluate.py --model gmflownet --use_mix_attn --ckpt=pretrained_models/gmflownet_mix-things.pth --dataset=sintel
63 | ```
64 | On KITTI, evaluate the `gmflownet` model as:
65 | ```Shell
66 | python evaluate.py --model gmflownet --ckpt=pretrained_models/gmflownet-things.pth --dataset=kitti
67 | ```
68 | Note: `gmflownet_mix` replaces half of the heads (4 of 8) in each POLA attention of `gmflownet` with [axial attention](https://arxiv.org/abs/2003.07853) heads and achieves better results on Sintel.
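
For intuition only, here is a minimal sketch of that head-splitting idea; `MixedHeadAttention` and the identity placeholders below are illustrative and are not the modules used in this repository:
```python
import torch
import torch.nn as nn

class MixedHeadAttention(nn.Module):
    """Toy sketch: route half of the channels (i.e., half of the heads) through one
    attention module and the other half through another, then concatenate."""
    def __init__(self, attn_a, attn_b):
        super().__init__()
        self.attn_a = attn_a  # stands in for the POLA-attention heads
        self.attn_b = attn_b  # stands in for the axial-attention heads

    def forward(self, x):                     # x: (N, C, H, W)
        xa, xb = x.chunk(2, dim=1)            # split channels between the two branches
        return torch.cat([self.attn_a(xa), self.attn_b(xb)], dim=1)

# Identity modules stand in for the two attention types in this toy example.
mixed = MixedHeadAttention(nn.Identity(), nn.Identity())
out = mixed(torch.randn(1, 64, 32, 32))       # -> (1, 64, 32, 32)
```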
69 |
70 |
71 | ## Training
72 | We used the following training schedules in our paper (2 GPUs).
73 |
74 | - Train `gmflownet` as:
75 | ```Shell
76 | ./train_gmflownet.sh
77 | ```
78 | - Train `gmflownet_mix` as:
79 | ```Shell
80 | ./train_gmflownet_mix.sh
81 | ```
82 |
83 | Training logs will be written to `./runs` and can be visualized with TensorBoard:
84 | ```Shell
85 | tensorboard --bind_all --port 8080 --logdir ./runs
86 | ```
87 |
88 |
89 | ## Acknowledgement
90 | The code is based on [RAFT](https://github.com/princeton-vl/RAFT) and [SwinTransformer](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).
91 | We sincerely thank the authors for their great work.
92 |
93 |
94 |
95 |
96 |
--------------------------------------------------------------------------------
/alt_cuda_corr/correlation.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 | // CUDA forward declarations
5 | std::vector<torch::Tensor> corr_cuda_forward(
6 | torch::Tensor fmap1,
7 | torch::Tensor fmap2,
8 | torch::Tensor coords,
9 | int radius);
10 |
11 | std::vector<torch::Tensor> corr_cuda_backward(
12 | torch::Tensor fmap1,
13 | torch::Tensor fmap2,
14 | torch::Tensor coords,
15 | torch::Tensor corr_grad,
16 | int radius);
17 |
18 | // C++ interface
19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
22 |
23 | std::vector<torch::Tensor> corr_forward(
24 | torch::Tensor fmap1,
25 | torch::Tensor fmap2,
26 | torch::Tensor coords,
27 | int radius) {
28 | CHECK_INPUT(fmap1);
29 | CHECK_INPUT(fmap2);
30 | CHECK_INPUT(coords);
31 |
32 | return corr_cuda_forward(fmap1, fmap2, coords, radius);
33 | }
34 |
35 |
36 | std::vector<torch::Tensor> corr_backward(
37 | torch::Tensor fmap1,
38 | torch::Tensor fmap2,
39 | torch::Tensor coords,
40 | torch::Tensor corr_grad,
41 | int radius) {
42 | CHECK_INPUT(fmap1);
43 | CHECK_INPUT(fmap2);
44 | CHECK_INPUT(coords);
45 | CHECK_INPUT(corr_grad);
46 |
47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
48 | }
49 |
50 |
51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
52 | m.def("forward", &corr_forward, "CORR forward");
53 | m.def("backward", &corr_backward, "CORR backward");
54 | }
--------------------------------------------------------------------------------
/alt_cuda_corr/correlation_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <cuda.h>
3 | #include <cuda_runtime.h>
4 | #include <vector>
5 |
6 |
7 | #define BLOCK_H 4
8 | #define BLOCK_W 8
9 | #define BLOCK_HW BLOCK_H * BLOCK_W
10 | #define CHANNEL_STRIDE 32
11 |
12 |
13 | __forceinline__ __device__
14 | bool within_bounds(int h, int w, int H, int W) {
15 | return h >= 0 && h < H && w >= 0 && w < W;
16 | }
17 |
18 | template <typename scalar_t>
19 | __global__ void corr_forward_kernel(
20 |     const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
21 |     const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
22 |     const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
23 |     torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
24 | int r)
25 | {
26 | const int b = blockIdx.x;
27 | const int h0 = blockIdx.y * blockDim.x;
28 | const int w0 = blockIdx.z * blockDim.y;
29 | const int tid = threadIdx.x * blockDim.y + threadIdx.y;
30 |
31 | const int H1 = fmap1.size(1);
32 | const int W1 = fmap1.size(2);
33 | const int H2 = fmap2.size(1);
34 | const int W2 = fmap2.size(2);
35 | const int N = coords.size(1);
36 | const int C = fmap1.size(3);
37 |
38 | __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
39 | __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
40 | __shared__ scalar_t x2s[BLOCK_HW];
41 | __shared__ scalar_t y2s[BLOCK_HW];
42 |
43 | for (int c=0; c(floor(y2s[k1]))-r+iy;
76 |         int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
77 | int c2 = tid % CHANNEL_STRIDE;
78 |
79 | auto fptr = fmap2[b][h2][w2];
80 | if (within_bounds(h2, w2, H2, W2))
81 | f2[c2][k1] = fptr[c+c2];
82 | else
83 | f2[c2][k1] = 0.0;
84 | }
85 |
86 | __syncthreads();
87 |
88 | scalar_t s = 0.0;
89 | for (int k=0; k 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
105 | *(corr_ptr + ix_nw) += nw;
106 |
107 | if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
108 | *(corr_ptr + ix_ne) += ne;
109 |
110 | if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
111 | *(corr_ptr + ix_sw) += sw;
112 |
113 | if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
114 | *(corr_ptr + ix_se) += se;
115 | }
116 | }
117 | }
118 | }
119 | }
120 |
121 |
122 | template <typename scalar_t>
123 | __global__ void corr_backward_kernel(
124 |     const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
125 |     const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
126 |     const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
127 |     const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
128 |     torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
129 |     torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
130 |     torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
131 | int r)
132 | {
133 |
134 | const int b = blockIdx.x;
135 | const int h0 = blockIdx.y * blockDim.x;
136 | const int w0 = blockIdx.z * blockDim.y;
137 | const int tid = threadIdx.x * blockDim.y + threadIdx.y;
138 |
139 | const int H1 = fmap1.size(1);
140 | const int W1 = fmap1.size(2);
141 | const int H2 = fmap2.size(1);
142 | const int W2 = fmap2.size(2);
143 | const int N = coords.size(1);
144 | const int C = fmap1.size(3);
145 |
146 | __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
147 | __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
148 |
149 | __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
150 | __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];
151 |
152 | __shared__ scalar_t x2s[BLOCK_HW];
153 | __shared__ scalar_t y2s[BLOCK_HW];
154 |
155 | for (int c=0; c(floor(y2s[k1]))-r+iy;
190 |         int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
191 | int c2 = tid % CHANNEL_STRIDE;
192 |
193 | auto fptr = fmap2[b][h2][w2];
194 | if (within_bounds(h2, w2, H2, W2))
195 | f2[c2][k1] = fptr[c+c2];
196 | else
197 | f2[c2][k1] = 0.0;
198 |
199 | f2_grad[c2][k1] = 0.0;
200 | }
201 |
202 | __syncthreads();
203 |
204 | const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
205 | scalar_t g = 0.0;
206 |
207 | int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
208 | int ix_ne = H1*W1*((iy-1) + rd*ix);
209 | int ix_sw = H1*W1*(iy + rd*(ix-1));
210 | int ix_se = H1*W1*(iy + rd*ix);
211 |
212 | if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
213 | g += *(grad_ptr + ix_nw) * dy * dx;
214 |
215 | if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
216 | g += *(grad_ptr + ix_ne) * dy * (1-dx);
217 |
218 | if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
219 | g += *(grad_ptr + ix_sw) * (1-dy) * dx;
220 |
221 | if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
222 | g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);
223 |
224 | for (int k=0; k(floor(y2s[k1]))-r+iy;
232 |         int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
233 | int c2 = tid % CHANNEL_STRIDE;
234 |
235 | scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
236 | if (within_bounds(h2, w2, H2, W2))
237 | atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
238 | }
239 | }
240 | }
241 | }
242 | __syncthreads();
243 |
244 |
245 | for (int k=0; k corr_cuda_forward(
261 | torch::Tensor fmap1,
262 | torch::Tensor fmap2,
263 | torch::Tensor coords,
264 | int radius)
265 | {
266 | const auto B = coords.size(0);
267 | const auto N = coords.size(1);
268 | const auto H = coords.size(2);
269 | const auto W = coords.size(3);
270 |
271 | const auto rd = 2 * radius + 1;
272 | auto opts = fmap1.options();
273 | auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);
274 |
275 | const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);
276 | const dim3 threads(BLOCK_H, BLOCK_W);
277 |
278 |   corr_forward_kernel<float><<<blocks, threads>>>(
279 |     fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
280 |     fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
281 |     coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
282 |     corr.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
283 | radius);
284 |
285 | return {corr};
286 | }
287 |
288 | std::vector<torch::Tensor> corr_cuda_backward(
289 | torch::Tensor fmap1,
290 | torch::Tensor fmap2,
291 | torch::Tensor coords,
292 | torch::Tensor corr_grad,
293 | int radius)
294 | {
295 | const auto B = coords.size(0);
296 | const auto N = coords.size(1);
297 |
298 | const auto H1 = fmap1.size(1);
299 | const auto W1 = fmap1.size(2);
300 | const auto H2 = fmap2.size(1);
301 | const auto W2 = fmap2.size(2);
302 | const auto C = fmap1.size(3);
303 |
304 | auto opts = fmap1.options();
305 | auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
306 | auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
307 | auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);
308 |
309 | const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
310 | const dim3 threads(BLOCK_H, BLOCK_W);
311 |
312 |
313 |   corr_backward_kernel<float><<<blocks, threads>>>(
314 |     fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
315 |     fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
316 |     coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
317 |     corr_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
318 |     fmap1_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
319 |     fmap2_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
320 |     coords_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
321 | radius);
322 |
323 | return {fmap1_grad, fmap2_grad, coords_grad};
324 | }
--------------------------------------------------------------------------------
/alt_cuda_corr/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 |
5 | setup(
6 | name='correlation',
7 | ext_modules=[
8 | CUDAExtension('alt_cuda_corr',
9 | sources=['correlation.cpp', 'correlation_kernel.cu'],
10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
11 | ],
12 | cmdclass={
13 | 'build_ext': BuildExtension
14 | })
15 |
16 |
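# Typical build/install for this standard setuptools CUDA extension (run from this directory):
#   python setup.py install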
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch.nn as nn
3 |
4 |
5 | def find_model_using_name(model_name):
6 |     """Import the module "[model_name]_model.py".
7 |
8 | In the file, the class called [model_name]Model() will
9 | be instantiated. It has to be a subclass of nn.Module,
10 | and it is case-insensitive.
11 | """
12 | model_filename = model_name + "_model"
13 | modellib = importlib.import_module(model_filename)
14 | model = None
15 | target_model_name = model_name.replace('_', '') + 'model'
16 | for name, cls in modellib.__dict__.items():
17 | if name.lower() == target_model_name.lower() \
18 | and issubclass(cls, nn.Module):
19 | model = cls
20 |
21 | if model is None:
22 | print("In %s.py, there should be a subclass of nn.Module with class name that matches %s in lowercase." % (model_filename, target_model_name))
23 | exit(0)
24 |
25 | return model
26 |
27 |
28 | def create_model(opt):
29 | """Create a model given the option.
30 |
31 |     This function wraps `find_model_using_name` and instantiates the selected model.
32 | This is the main interface between this package and 'train.py'/'test.py'
33 |
34 | Example:
35 | >>> from core import create_model
36 | >>> model = create_model(opt)
37 | """
38 | model = find_model_using_name(opt.model)
39 | instance = model(opt)
40 | print("model [%s] was created" % type(instance).__name__)
41 | return instance
--------------------------------------------------------------------------------
/core/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from utils.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 | import alt_cuda_corr
7 | except ImportError:
8 |     # alt_cuda_corr is not compiled (optional; it can be built from alt_cuda_corr/setup.py)
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | self.corrMap = corr.view(batch, h1*w1, h2*w2)
23 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
24 |
25 |
26 | self.corr_pyramid.append(corr)
27 | for i in range(self.num_levels-1):
28 | corr = F.avg_pool2d(corr, 2, stride=2)
29 | self.corr_pyramid.append(corr)
30 |
31 | def __call__(self, coords):
32 | r = self.radius
33 | coords = coords.permute(0, 2, 3, 1)
34 | batch, h1, w1, _ = coords.shape
35 |
36 | out_pyramid = []
37 | for i in range(self.num_levels):
38 | corr = self.corr_pyramid[i]
39 | dx = torch.linspace(-r, r, 2*r+1)
40 | dy = torch.linspace(-r, r, 2*r+1)
41 | delta = torch.stack(torch.meshgrid(dy, dx), dim=-1).to(coords.device)
42 |
43 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
44 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
45 | coords_lvl = centroid_lvl + delta_lvl
46 |
47 | corr = bilinear_sampler(corr, coords_lvl)
48 | corr = corr.view(batch, h1, w1, -1)
49 | out_pyramid.append(corr)
50 |
51 | out = torch.cat(out_pyramid, dim=-1)
52 | return out.permute(0, 3, 1, 2).contiguous().float()
53 |
54 | @staticmethod
55 | def corr(fmap1, fmap2):
56 | batch, dim, ht, wd = fmap1.shape
57 | fmap1 = fmap1.view(batch, dim, ht*wd)
58 | fmap2 = fmap2.view(batch, dim, ht*wd)
59 |
60 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
61 | corr = corr.view(batch, ht, wd, 1, ht, wd)
62 | return corr / torch.sqrt(torch.tensor(dim).float())
63 |
64 |
65 | class AlternateCorrBlock:
66 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
67 | self.num_levels = num_levels
68 | self.radius = radius
69 |
70 | self.pyramid = [(fmap1, fmap2)]
71 | for i in range(self.num_levels):
72 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
73 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
74 | self.pyramid.append((fmap1, fmap2))
75 |
76 | def __call__(self, coords):
77 | coords = coords.permute(0, 2, 3, 1)
78 | B, H, W, _ = coords.shape
79 | dim = self.pyramid[0][0].shape[1]
80 |
81 | corr_list = []
82 | for i in range(self.num_levels):
83 | r = self.radius
84 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
85 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
86 |
87 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
88 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
89 | corr_list.append(corr.squeeze(1))
90 |
91 | corr = torch.stack(corr_list, dim=1)
92 | corr = corr.reshape(B, -1, H, W)
93 | return corr / torch.sqrt(torch.tensor(dim).float())
94 |
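# Usage sketch (shapes inferred from the code above, for illustration only):
#   corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)  # fmap1, fmap2: (B, C, H, W)
#   corr_feats = corr_fn(coords)  # coords: (B, 2, H, W) -> (B, num_levels*(2*radius+1)**2, H, W)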
--------------------------------------------------------------------------------
/core/datasets.py:
--------------------------------------------------------------------------------
1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2 |
3 | import numpy as np
4 | import torch
5 | import torch.utils.data as data
6 | import torch.nn.functional as F
7 |
8 | import os
9 | import math
10 | import random
11 | from glob import glob
12 | import os.path as osp
13 |
14 | from core.utils import frame_utils
15 | from core.utils.augmentor import FlowAugmentor, SparseFlowAugmentor
16 |
17 | import sys
18 |
19 | class FlowDataset(data.Dataset):
20 | def __init__(self, aug_params=None, sparse=False):
21 | self.augmentor = None
22 | self.sparse = sparse
23 | if aug_params is not None:
24 | if sparse:
25 | self.augmentor = SparseFlowAugmentor(**aug_params)
26 | else:
27 | self.augmentor = FlowAugmentor(**aug_params)
28 |
29 | self.is_test = False
30 | self.is_validate = False
31 | self.init_seed = False
32 | self.flow_list = []
33 | self.image_list = []
34 | self.extra_info = []
35 |
36 | def __getitem__(self, index):
37 | # print('Index is {}'.format(index))
38 | # sys.stdout.flush()
39 | if self.is_test:
40 | img1 = frame_utils.read_gen(self.image_list[index][0])
41 | img2 = frame_utils.read_gen(self.image_list[index][1])
42 | img1 = np.array(img1).astype(np.uint8)[..., :3]
43 | img2 = np.array(img2).astype(np.uint8)[..., :3]
44 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
45 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
46 | return img1, img2, self.extra_info[index]
47 |
48 | # if not self.init_seed:
49 | # worker_info = torch.utils.data.get_worker_info()
50 | # if worker_info is not None:
51 | # torch.manual_seed(worker_info.id)
52 | # np.random.seed(worker_info.id)
53 | # random.seed(worker_info.id)
54 | # self.init_seed = True
55 |
56 | index = index % len(self.image_list)
57 | valid = None
58 | if self.sparse:
59 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
60 | else:
61 | flow = frame_utils.read_gen(self.flow_list[index])
62 |
63 | img1 = frame_utils.read_gen(self.image_list[index][0])
64 | img2 = frame_utils.read_gen(self.image_list[index][1])
65 |
66 | flow = np.array(flow).astype(np.float32)
67 | img1 = np.array(img1).astype(np.uint8)
68 | img2 = np.array(img2).astype(np.uint8)
69 |
70 | # grayscale images
71 | if len(img1.shape) == 2:
72 | img1 = np.tile(img1[...,None], (1, 1, 3))
73 | img2 = np.tile(img2[...,None], (1, 1, 3))
74 | else:
75 | img1 = img1[..., :3]
76 | img2 = img2[..., :3]
77 |
78 | if self.augmentor is not None:
79 | if self.sparse:
80 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
81 | else:
82 | img1, img2, flow = self.augmentor(img1, img2, flow)
83 |
84 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
85 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
86 | flow = torch.from_numpy(flow).permute(2, 0, 1).float()
87 |
88 | if valid is not None:
89 | valid = torch.from_numpy(valid)
90 | else:
91 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
92 |
93 | if self.is_validate:
94 | return img1, img2, flow, valid.float(), self.extra_info[index]
95 | else:
96 | return img1, img2, flow, valid.float()
97 |
98 | def getDataWithPath(self, index):
99 | img1, img2, flow, valid = self.__getitem__(index)
100 |
101 | imgPath_1 = self.image_list[index][0]
102 | imgPath_2 = self.image_list[index][1]
103 |
104 | return img1, img2, flow, valid, imgPath_1, imgPath_2
105 |
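    # Multiplying a dataset by an integer (e.g., `100 * sintel_clean` in fetch_dataloader)
    # repeats its sample lists, which weights the datasets when mixing training stages.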
106 | def __rmul__(self, v):
107 | self.flow_list = v * self.flow_list
108 | self.image_list = v * self.image_list
109 | return self
110 |
111 | def __len__(self):
112 | return len(self.image_list)
113 |
114 |
115 | class MpiSintel(FlowDataset):
116 | def __init__(self, aug_params=None, split='training', root='./datasets/Sintel', dstype='clean', is_validate=False):
117 | super(MpiSintel, self).__init__(aug_params)
118 | flow_root = osp.join(root, split, 'flow')
119 | image_root = osp.join(root, split, dstype)
120 |
121 | self.is_validate = is_validate
122 | if split == 'test':
123 | self.is_test = True
124 |
125 | for scene in os.listdir(image_root):
126 | image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
127 | for i in range(len(image_list)-1):
128 | self.image_list += [ [image_list[i], image_list[i+1]] ]
129 | self.extra_info += [ (scene, i) ] # scene and frame_id
130 |
131 | if split != 'test':
132 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
133 |
134 |
135 | class FlyingChairs(FlowDataset):
136 | def __init__(self, aug_params=None, split='train', root='./datasets/FlyingChairs_release/data'):
137 | super(FlyingChairs, self).__init__(aug_params)
138 |
139 | images = sorted(glob(osp.join(root, '*.ppm')))
140 | flows = sorted(glob(osp.join(root, '*.flo')))
141 | assert (len(images)//2 == len(flows))
142 |
143 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
144 | for i in range(len(flows)):
145 | xid = split_list[i]
146 | if (split=='training' and xid==1) or (split=='validation' and xid==2):
147 | self.flow_list += [ flows[i] ]
148 | self.image_list += [ [images[2*i], images[2*i+1]] ]
149 |
150 |
151 | class FlyingThings3D(FlowDataset):
152 | def __init__(self, aug_params=None, root='./datasets/FlyingThings3D', dstype='frames_cleanpass'):
153 | super(FlyingThings3D, self).__init__(aug_params)
154 |
155 | for cam in ['left']:
156 | for direction in ['into_future', 'into_past']:
157 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
158 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
159 |
160 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
161 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
162 |
163 | for idir, fdir in zip(image_dirs, flow_dirs):
164 | images = sorted(glob(osp.join(idir, '*.png')) )
165 | flows = sorted(glob(osp.join(fdir, '*.pfm')) )
166 | for i in range(len(flows)-1):
167 | if direction == 'into_future':
168 | self.image_list += [ [images[i], images[i+1]] ]
169 | self.flow_list += [ flows[i] ]
170 | elif direction == 'into_past':
171 | self.image_list += [ [images[i+1], images[i]] ]
172 | self.flow_list += [ flows[i+1] ]
173 |
174 |
175 | class KITTI(FlowDataset):
176 | def __init__(self, aug_params=None, split='training', root='./datasets/kitti'):
177 | super(KITTI, self).__init__(aug_params, sparse=True)
178 | if split == 'testing':
179 | self.is_test = True
180 |
181 | root = osp.join(root, split)
182 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
183 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
184 |
185 | for img1, img2 in zip(images1, images2):
186 | frame_id = img1.split('/')[-1]
187 | self.extra_info += [ [frame_id] ]
188 | self.image_list += [ [img1, img2] ]
189 |
190 | if split == 'training':
191 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
192 |
193 |
194 | class HD1K(FlowDataset):
195 | def __init__(self, aug_params=None, root='./datasets/hd1k_full_package'):
196 | super(HD1K, self).__init__(aug_params, sparse=True)
197 |
198 | seq_ix = 0
199 | while 1:
200 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
201 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
202 |
203 | if len(flows) == 0:
204 | break
205 |
206 | for i in range(len(flows)-1):
207 | self.flow_list += [flows[i]]
208 | self.image_list += [ [images[i], images[i+1]] ]
209 |
210 | seq_ix += 1
211 |
212 |
213 | def fetch_dataloader(args, TRAIN_DS='C+T+K/S'):
214 |     """ Create the data loader for the corresponding training set """
215 |
216 | if args.stage == 'chairs':
217 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
218 | train_dataset = FlyingChairs(aug_params, split='training')
219 |
220 | elif args.stage == 'things':
221 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
222 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
223 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
224 | train_dataset = clean_dataset + final_dataset
225 |
226 | elif args.stage == 'sintel':
227 | print('Training Sintel Stage...')
228 | sys.stdout.flush()
229 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
230 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
231 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
232 | sintel_final = MpiSintel(aug_params, split='training', dstype='final')
233 |
234 | if TRAIN_DS == 'C+T+K+S+H':
235 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
236 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
237 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
238 |
239 | elif TRAIN_DS == 'C+T+K+H':
240 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
241 | train_dataset = 100*sintel_clean + 100*sintel_final + 5*hd1k + things
242 |
243 | elif TRAIN_DS == 'C+T+S+K':
244 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
245 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + things
246 |
247 | elif TRAIN_DS == 'C+T+K/S':
248 | train_dataset = 100*sintel_clean + 100*sintel_final + things
249 |
250 | elif args.stage == 'kitti':
251 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
252 | train_dataset = KITTI(aug_params, split='training')
253 |
254 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
255 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
256 |
257 | print('Training with %d image pairs' % len(train_dataset))
258 | return train_loader
259 |
--------------------------------------------------------------------------------
/core/extractor.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class ResidualBlock(nn.Module):
8 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
9 | super(ResidualBlock, self).__init__()
10 |
11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
13 | self.relu = nn.ReLU(inplace=True)
14 |
15 | num_groups = planes // 8
16 |
17 | if norm_fn == 'group':
18 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
20 | if not stride == 1:
21 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
22 |
23 | elif norm_fn == 'batch':
24 | self.norm1 = nn.BatchNorm2d(planes)
25 | self.norm2 = nn.BatchNorm2d(planes)
26 | if not stride == 1:
27 | self.norm3 = nn.BatchNorm2d(planes)
28 |
29 | elif norm_fn == 'instance':
30 | self.norm1 = nn.InstanceNorm2d(planes)
31 | self.norm2 = nn.InstanceNorm2d(planes)
32 | if not stride == 1:
33 | self.norm3 = nn.InstanceNorm2d(planes)
34 |
35 | elif norm_fn == 'none':
36 | self.norm1 = nn.Sequential()
37 | self.norm2 = nn.Sequential()
38 | if not stride == 1:
39 | self.norm3 = nn.Sequential()
40 |
41 | if stride == 1:
42 | self.downsample = None
43 |
44 | else:
45 | self.downsample = nn.Sequential(
46 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
47 |
48 |
49 | def forward(self, x):
50 | y = x
51 | y = self.relu(self.norm1(self.conv1(y)))
52 | y = self.relu(self.norm2(self.conv2(y)))
53 |
54 | if self.downsample is not None:
55 | x = self.downsample(x)
56 |
57 | return self.relu(x+y)
58 |
59 |
60 |
61 | class BottleneckBlock(nn.Module):
62 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
63 | super(BottleneckBlock, self).__init__()
64 |
65 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
66 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
67 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
68 | self.relu = nn.ReLU(inplace=True)
69 |
70 | num_groups = planes // 8
71 |
72 | if norm_fn == 'group':
73 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
74 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
75 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
76 | if not stride == 1:
77 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
78 |
79 | elif norm_fn == 'batch':
80 | self.norm1 = nn.BatchNorm2d(planes//4)
81 | self.norm2 = nn.BatchNorm2d(planes//4)
82 | self.norm3 = nn.BatchNorm2d(planes)
83 | if not stride == 1:
84 | self.norm4 = nn.BatchNorm2d(planes)
85 |
86 | elif norm_fn == 'instance':
87 | self.norm1 = nn.InstanceNorm2d(planes//4)
88 | self.norm2 = nn.InstanceNorm2d(planes//4)
89 | self.norm3 = nn.InstanceNorm2d(planes)
90 | if not stride == 1:
91 | self.norm4 = nn.InstanceNorm2d(planes)
92 |
93 | elif norm_fn == 'none':
94 | self.norm1 = nn.Sequential()
95 | self.norm2 = nn.Sequential()
96 | self.norm3 = nn.Sequential()
97 | if not stride == 1:
98 | self.norm4 = nn.Sequential()
99 |
100 | if stride == 1:
101 | self.downsample = None
102 |
103 | else:
104 | self.downsample = nn.Sequential(
105 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
106 |
107 |
108 | def forward(self, x):
109 | y = x
110 | y = self.relu(self.norm1(self.conv1(y)))
111 | y = self.relu(self.norm2(self.conv2(y)))
112 | y = self.relu(self.norm3(self.conv3(y)))
113 |
114 | if self.downsample is not None:
115 | x = self.downsample(x)
116 |
117 | return self.relu(x+y)
118 |
119 | class BasicEncoder(nn.Module):
120 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
121 | super(BasicEncoder, self).__init__()
122 | self.norm_fn = norm_fn
123 |
124 | if self.norm_fn == 'group':
125 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
126 |
127 | elif self.norm_fn == 'batch':
128 | self.norm1 = nn.BatchNorm2d(64)
129 |
130 | elif self.norm_fn == 'instance':
131 | self.norm1 = nn.InstanceNorm2d(64)
132 |
133 | elif self.norm_fn == 'none':
134 | self.norm1 = nn.Sequential()
135 |
136 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
137 | self.relu1 = nn.ReLU(inplace=True)
138 |
139 | self.in_planes = 64
140 | self.layer1 = self._make_layer(64, stride=1)
141 | self.layer2 = self._make_layer(96, stride=2)
142 | self.layer3 = self._make_layer(128, stride=2)
143 |
144 | # output convolution
145 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
146 |
147 | self.dropout = None
148 | if dropout > 0:
149 | self.dropout = nn.Dropout2d(p=dropout)
150 |
151 | for m in self.modules():
152 | if isinstance(m, nn.Conv2d):
153 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
154 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
155 | if m.weight is not None:
156 | nn.init.constant_(m.weight, 1)
157 | if m.bias is not None:
158 | nn.init.constant_(m.bias, 0)
159 |
160 | def _make_layer(self, dim, stride=1):
161 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
162 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
163 | layers = (layer1, layer2)
164 |
165 | self.in_planes = dim
166 | return nn.Sequential(*layers)
167 |
168 |
169 | def forward(self, x):
170 |
171 | # if input is list, combine batch dimension
172 | is_list = isinstance(x, tuple) or isinstance(x, list)
173 | if is_list:
174 | batch_dim = x[0].shape[0]
175 | x = torch.cat(x, dim=0)
176 |
177 | x = self.conv1(x)
178 | x = self.norm1(x)
179 | x = self.relu1(x)
180 |
181 | x = self.layer1(x)
182 | x = self.layer2(x)
183 | x = self.layer3(x)
184 |
185 | x = self.conv2(x)
186 |
187 | if self.training and self.dropout is not None:
188 | x = self.dropout(x)
189 |
190 | if is_list:
191 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
192 |
193 | return x
194 |
195 |
196 | class IPTHeadEncoder(nn.Module):
197 | """docstring for IPTHead"""
198 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
199 | super(IPTHeadEncoder, self).__init__()
200 | self.norm_fn = norm_fn
201 |
202 | if self.norm_fn == 'group':
203 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
204 |
205 | elif self.norm_fn == 'batch':
206 | self.norm1 = nn.BatchNorm2d(64)
207 |
208 | elif self.norm_fn == 'instance':
209 | self.norm1 = nn.InstanceNorm2d(64)
210 |
211 | elif self.norm_fn == 'none':
212 | self.norm1 = nn.Sequential()
213 |
214 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
215 | self.relu1 = nn.ReLU(inplace=True)
216 |
217 | half_out_dim = max(output_dim // 2, 64)
218 | self.layer1 = ResidualBlock(64, half_out_dim, self.norm_fn, stride=2)
219 | self.layer2 = ResidualBlock(half_out_dim, output_dim, self.norm_fn, stride=2)
220 |
221 |         # # output convolution; this can solve a mixed memory warning, though it is unclear why
222 | # self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
223 |
224 | self.dropout = None
225 | if dropout > 0:
226 | self.dropout = nn.Dropout2d(p=dropout)
227 |
228 | def forward(self, x):
229 |
230 | # if input is list, combine batch dimension
231 | is_list = isinstance(x, tuple) or isinstance(x, list)
232 | if is_list:
233 | batch_dim = x[0].shape[0]
234 | x = torch.cat(x, dim=0)
235 |
236 | x = self.relu1(self.norm1(self.conv1(x)))
237 | x = self.layer1(x)
238 | x = self.layer2(x)
239 |
240 | if self.training and self.dropout is not None:
241 | x = self.dropout(x)
242 |
243 | if is_list:
244 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
245 |
246 | return x
247 |
248 |
249 | class BasicConvEncoder(nn.Module):
250 | """docstring for BasicConvEncoder"""
251 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
252 | super(BasicConvEncoder, self).__init__()
253 | self.norm_fn = norm_fn
254 |
255 | half_out_dim = max(output_dim // 2, 64)
256 |
257 | if self.norm_fn == 'group':
258 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
259 | self.norm2 = nn.GroupNorm(num_groups=8, num_channels=64)
260 | self.norm3 = nn.GroupNorm(num_groups=8, num_channels=64)
261 |
262 | elif self.norm_fn == 'batch':
263 | self.norm1 = nn.BatchNorm2d(64)
264 | self.norm2 = nn.BatchNorm2d(half_out_dim)
265 | self.norm3 = nn.BatchNorm2d(output_dim)
266 |
267 | elif self.norm_fn == 'instance':
268 | self.norm1 = nn.InstanceNorm2d(64)
269 | self.norm2 = nn.InstanceNorm2d(half_out_dim)
270 | self.norm3 = nn.InstanceNorm2d(output_dim)
271 |
272 | elif self.norm_fn == 'none':
273 | self.norm1 = nn.Sequential()
274 | self.norm2 = nn.Sequential()
275 | self.norm3 = nn.Sequential()
276 |
277 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
278 | self.conv2 = nn.Conv2d(64, half_out_dim, kernel_size=3, stride=2, padding=1)
279 | self.conv3 = nn.Conv2d(half_out_dim, output_dim, kernel_size=3, stride=2, padding=1)
280 |
281 |         # # output convolution; this can solve a mixed memory warning, though it is unclear why
282 | # self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
283 |
284 | self.dropout = None
285 | if dropout > 0:
286 | self.dropout = nn.Dropout2d(p=dropout)
287 |
288 | def forward(self, x):
289 |
290 | # if input is list, combine batch dimension
291 | is_list = isinstance(x, tuple) or isinstance(x, list)
292 | if is_list:
293 | batch_dim = x[0].shape[0]
294 | x = torch.cat(x, dim=0)
295 |
296 | x = F.relu(self.norm1(self.conv1(x)), inplace=True)
297 | x = F.relu(self.norm2(self.conv2(x)), inplace=True)
298 | x = F.relu(self.norm3(self.conv3(x)), inplace=True)
299 |
300 | if self.training and self.dropout is not None:
301 | x = self.dropout(x)
302 |
303 | if is_list:
304 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
305 |
306 | return x
307 |
308 |
309 | class SmallEncoder(nn.Module):
310 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
311 | super(SmallEncoder, self).__init__()
312 | self.norm_fn = norm_fn
313 |
314 | if self.norm_fn == 'group':
315 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
316 |
317 | elif self.norm_fn == 'batch':
318 | self.norm1 = nn.BatchNorm2d(32)
319 |
320 | elif self.norm_fn == 'instance':
321 | self.norm1 = nn.InstanceNorm2d(32)
322 |
323 | elif self.norm_fn == 'none':
324 | self.norm1 = nn.Sequential()
325 |
326 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
327 | self.relu1 = nn.ReLU(inplace=True)
328 |
329 | self.in_planes = 32
330 | self.layer1 = self._make_layer(32, stride=1)
331 | self.layer2 = self._make_layer(64, stride=2)
332 | self.layer3 = self._make_layer(96, stride=2)
333 |
334 | self.dropout = None
335 | if dropout > 0:
336 | self.dropout = nn.Dropout2d(p=dropout)
337 |
338 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
339 |
340 | for m in self.modules():
341 | if isinstance(m, nn.Conv2d):
342 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
343 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
344 | if m.weight is not None:
345 | nn.init.constant_(m.weight, 1)
346 | if m.bias is not None:
347 | nn.init.constant_(m.bias, 0)
348 |
349 | def _make_layer(self, dim, stride=1):
350 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
351 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
352 | layers = (layer1, layer2)
353 |
354 | self.in_planes = dim
355 | return nn.Sequential(*layers)
356 |
357 |
358 | def forward(self, x):
359 |
360 | # if input is list, combine batch dimension
361 | is_list = isinstance(x, tuple) or isinstance(x, list)
362 | if is_list:
363 | batch_dim = x[0].shape[0]
364 | x = torch.cat(x, dim=0)
365 |
366 | x = self.conv1(x)
367 | x = self.norm1(x)
368 | x = self.relu1(x)
369 |
370 | x = self.layer1(x)
371 | x = self.layer2(x)
372 | x = self.layer3(x)
373 | x = self.conv2(x)
374 |
375 | if self.training and self.dropout is not None:
376 | x = self.dropout(x)
377 |
378 | if is_list:
379 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
380 |
381 | return x
382 |
383 |
384 |
385 |
386 | class MultiHeadAttention(nn.Module):
387 | ''' Multi-Head Attention module '''
388 |
389 | def __init__(self, d_model, n_head, dropout=0.1):
390 | super().__init__()
391 |
392 | assert d_model % n_head == 0
393 |
394 | self.n_head = n_head
395 | self.d_model = d_model
396 | self.d_head = self.d_model // self.n_head
397 |
398 | self.w_qs = nn.Linear(self.d_model, self.d_model, bias=False) # TODO: enable bias
399 | self.w_ks = nn.Linear(self.d_model, self.d_model, bias=False)
400 | self.w_vs = nn.Linear(self.d_model, self.d_model, bias=False)
401 | self.fc = nn.Linear(self.d_model, self.d_model)
402 |
403 | # self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5) # TODO
404 |
405 | self.dropout = nn.Dropout(dropout)
406 | # self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
407 |
408 |
409 | def forward(self, q, k, v):
410 | '''
411 |         q, k, v: tensors of shape (N, len, C)
412 | '''
413 | d_head, n_head = self.d_head, self.n_head
414 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
415 |
416 | # residual = q
417 |
418 | # Pass through the pre-attention projection: b x lq x (n*dv)
419 | # Separate different heads: b x lq x n x dv
420 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_head)
421 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_head)
422 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_head)
423 |
424 | # Transpose for attention dot product: b x n x lq x dv
425 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
426 |
427 | attn = torch.matmul(q, k.transpose(-1,-2)) / math.sqrt(self.d_head)
428 | attn = self.dropout(F.softmax(attn, dim=-1))
429 | q_updated = torch.matmul(attn, v)
430 |
431 | # Transpose to move the head dimension back: b x lq x n x dv
432 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
433 | q_updated = q_updated.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
434 | q_updated = self.dropout(self.fc(q_updated))
435 | # q_updated += residual
436 |
437 | # q_updated = self.layer_norm(q_updated)
438 |
439 | return q_updated, attn
440 |
441 |
442 |
443 | class AnchorEncoderBlock(nn.Module):
444 |
445 | def __init__(self, anchor_dist, d_model, num_heads, d_ff, dropout=0.):
446 | super().__init__()
447 |
448 | self.anchor_dist = anchor_dist
449 | self.half_anchor_dist = anchor_dist // 2
450 |
451 | self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout)
452 | self.dropout = nn.Dropout(dropout)
453 | self.layer_norm_1 = nn.LayerNorm(d_model)
454 |
455 | self.FFN = nn.Sequential(
456 | nn.Linear(d_model, d_ff),
457 | nn.ReLU(),
458 | nn.Linear(d_ff, d_model),
459 | )
460 |
461 | self.layer_norm_2 = nn.LayerNorm(d_model)
462 |
463 |
464 | def forward(self, inputs):
465 | '''
466 | inputs: batches with N*C*H*W
467 | '''
468 | N, C, H, W = inputs.shape
469 |
470 | x = inputs
471 | anchors = inputs[:,:, self.half_anchor_dist::self.anchor_dist,
472 | self.half_anchor_dist::self.anchor_dist].clone()
473 |
474 | # flatten feature maps
475 | x = x.reshape(N, C, H*W).transpose(-1,-2)
476 | anchors = anchors.reshape(N, C, anchors.shape[2]* anchors.shape[3]).transpose(-1,-2)
477 |
478 | # two-stage multi-head self-attention
479 | anchors_new = self.dropout(self.selfAttn(anchors, x, x)[0])
480 | residual = self.dropout(self.selfAttn(x, anchors_new, anchors_new)[0])
481 |
482 | norm_1 = self.layer_norm_1(x + residual)
483 | x_linear = self.dropout(self.FFN(norm_1))
484 | x_new = self.layer_norm_2(norm_1 + x_linear)
485 |
486 | outputs = x_new.transpose(-1,-2).reshape(N, C, H, W)
487 | return outputs
488 |
489 |
490 |
491 | class EncoderBlock(nn.Module):
492 |
493 | def __init__(self, d_model, num_heads, d_ff, dropout=0.):
494 | super().__init__()
495 |
496 | self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout)
497 | self.dropout = nn.Dropout(dropout)
498 | self.layer_norm_1 = nn.LayerNorm(d_model)
499 |
500 | self.FFN = nn.Sequential(
501 | nn.Linear(d_model, d_ff),
502 | nn.ReLU(),
503 | nn.Linear(d_ff, d_model),
504 | )
505 |
506 | self.layer_norm_2 = nn.LayerNorm(d_model)
507 |
508 |
509 | def forward(self, x):
510 | '''
511 | x: input batches with N*C*H*W
512 | '''
513 | N, C, H, W = x.shape
514 |
515 | # update x
516 | x = x.reshape(N, C, H*W).transpose(-1,-2)
517 |
518 | residual = self.dropout(self.selfAttn(x, x, x)[0])
519 | norm_1 = self.layer_norm_1(x + residual)
520 | x_linear = self.dropout(self.FFN(norm_1))
521 | x_new = self.layer_norm_2(norm_1 + x_linear)
522 |
523 | outputs = x_new.transpose(-1,-2).reshape(N, C, H, W)
524 | return outputs
525 |
526 |
527 |
528 | class ReduceEncoderBlock(nn.Module):
529 |
530 | def __init__(self, d_model, num_heads, d_ff, dropout=0.):
531 | super().__init__()
532 |
533 | self.reduce = nn.Sequential(
534 | nn.Conv2d(d_model, d_model, 2, 2),
535 | nn.Conv2d(d_model, d_model, 2, 2)
536 | )
537 | # self.reduce = nn.Sequential(
538 | # nn.AvgPool2d(16, 16)
539 | # )
540 |
541 | self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout)
542 | self.dropout = nn.Dropout(dropout)
543 | self.layer_norm_1 = nn.LayerNorm(d_model)
544 |
545 | self.FFN = nn.Sequential(
546 | nn.Linear(d_model, d_ff),
547 | nn.ReLU(),
548 | nn.Linear(d_ff, d_model)
549 | )
550 |
551 | self.layer_norm_2 = nn.LayerNorm(d_model)
552 |
553 |
554 | def forward(self, x):
555 | '''
556 | x: input batches with N*C*H*W
557 | '''
558 | N, C, H, W = x.shape
559 | x_reduced = self.reduce(x)
560 |
561 | # update x
562 | x = x.reshape(N, C, H*W).transpose(-1,-2)
563 | x_reduced = x_reduced.reshape(N, C, -1).transpose(-1,-2)
564 |
565 | # print('x ', x.shape)
566 | # print('x_reduced ', x_reduced.shape)
567 | # exit()
568 |
569 | residual = self.dropout(self.selfAttn(x, x_reduced, x_reduced)[0])
570 |
571 | norm_1 = self.layer_norm_1(x + residual)
572 | x_linear = self.dropout(self.FFN(norm_1))
573 | x_new = self.layer_norm_2(norm_1 + x_linear)
574 |
575 | outputs = x_new.transpose(-1,-2).reshape(N, C, H, W)
576 | return outputs
577 |
578 |
579 |
580 | def window_partition(x, window_size):
581 | """
582 | Args:
583 | x: (B, H, W, C)
584 | window_size (int): window size
585 |
586 | Returns:
587 | windows: (num_windows*B, window_size, window_size, C)
588 | """
589 | B, H, W, C = x.shape
590 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
591 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
592 | return windows
593 |
594 |
595 | def window_reverse(windows, window_size, H, W):
596 | """
597 | Args:
598 | windows: (num_windows*B, window_size, window_size, C)
599 | window_size (int): Window size
600 | H (int): Height of image
601 | W (int): Width of image
602 |
603 | Returns:
604 | x: (B, H, W, C)
605 | """
606 | B = int(windows.shape[0] / (H * W / window_size / window_size))
607 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
608 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
609 | return x
610 |
611 |
612 | class LayerEncoderBlock(nn.Module):
613 |
614 | def __init__(self, win_size, d_model, num_heads, d_ff, dropout=0.):
615 | super().__init__()
616 |
617 | self.win_size = win_size
618 | self.down_factor = 4
619 | self.unfold_stride = int(self.win_size//self.down_factor)
620 |
621 | self.stride_list = [math.floor(win_size/self.down_factor**idx) for idx in range(8) if win_size/self.down_factor**idx >= 1]
622 | # [16, 4, 1]
623 |
624 | self.reduce = nn.Sequential(
625 | nn.AvgPool2d(self.down_factor, self.down_factor)
626 | )
627 |
628 | self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout)
629 | self.crossAttn = MultiHeadAttention(d_model, num_heads, dropout)
630 |
631 | self.dropout = nn.Dropout(dropout)
632 | self.layerNormSelf = nn.LayerNorm(d_model)
633 | self.layerNormCross = nn.LayerNorm(d_model)
634 |
635 | self.FFN = nn.Sequential(
636 | nn.Linear(d_model, d_ff),
637 | nn.ReLU(),
638 | nn.Linear(d_ff, d_model)
639 | )
640 |
641 | self.layer_norm_out = nn.LayerNorm(d_model)
642 |
643 |
644 | def Circular_pad2D(self, x, pad_right, pad_bottom):
645 | '''
646 | x: (N, H, W, C)
647 | x_pad: (N, H_pad, W_pad, C)
648 | '''
649 | N, H, W, C = x.shape
650 |
651 | H_pad = H + pad_bottom
652 | W_pad = W + pad_right
653 |
654 | H_repeat = math.ceil(H_pad/H)
655 | W_repeat = math.ceil(W_pad/W)
656 | x_repeat = x.repeat(1, H_repeat, W_repeat, 1)
657 |
658 | x_pad = x_repeat[:, :H_pad, :W_pad, :]
659 | return x_pad
660 |
661 |
662 | def pad_fit_win(self, x, win_size):
663 | N, H, W, C = x.shape
664 |
665 | W_ = math.ceil(W/win_size)*win_size
666 | H_ = math.ceil(H/win_size)*win_size
667 | padRight = W_ - W
668 | padBottom = H_ - H
669 |
670 | x_pad = self.Circular_pad2D(x, padRight, padBottom) # N*H_*W_*C
671 | return x_pad
672 |
673 |
674 | def self_attention(self, x):
675 | '''
676 | x: (N, H, W, C)
677 | out: (N, H, W, C)
678 | '''
679 | N, H, W, C = x.shape
680 | x_pad = self.pad_fit_win(x, self.win_size) # N*H_*W_*C
681 | _, H_, W_, _ = x_pad.shape
682 |
683 | # x_pad = F.pad(x.permute(xxx), (0, padRight, 0, padBottom), mode='reflect') # N*C*H_*W_
684 |
685 | x_window = window_partition(x_pad, self.win_size) # (num_win*B, win_size, win_size, C)
686 | x_window = x_window.view(-1, self.win_size*self.win_size, C) # (num_win*B, win_size*win_size, C)
687 |
688 | # self-attention
689 | residual = self.dropout(self.selfAttn(x_window, x_window, x_window)[0])
690 | residual = residual.view(-1, self.win_size, self.win_size, C)
691 | residual = window_reverse(residual, self.win_size, H_, W_) # (N, H_, W_, C)
692 |
693 | out = x_pad + residual
694 | out = out[:, :H, :W, :]
695 | return out
696 |
697 |
698 | def cross_attention(self, query, keyVal):
699 | '''
700 | query: (N, qH, qW, C)
701 | keyVal: (N, kH, kW, C)
702 | out: (N, qH, qW, C)
703 | '''
704 | _, qH, qW, C = query.shape
705 | _, kH, kW, C = keyVal.shape
706 |
707 | # print('in query ', query.shape)
708 | # print('in keyVal ', keyVal.shape)
709 | # print('-')
710 |
711 | query = self.pad_fit_win(query, self.win_size) # N*H_*W_*C
712 | _, qH_, qW_, C = query.shape
713 |
714 | query_win = window_partition(query, self.win_size)
715 | query_win = query_win.view(-1, self.win_size*self.win_size, C) # (num_win*B, win_size*win_size, C)
716 |
717 | # pad and unfold keyVal
718 | kW_ = (math.ceil(kW/self.unfold_stride) - 1)*self.unfold_stride + self.win_size
719 | kH_ = (math.ceil(kH/self.unfold_stride) - 1)*self.unfold_stride + self.win_size
720 | padRight = kW_ - kW
721 | padBottom = kH_ - kH
722 |
723 | keyVal_pad = self.Circular_pad2D(keyVal, padRight, padBottom)
724 | keyVal = F.unfold(keyVal_pad.permute(0, 3, 1, 2), self.win_size, stride=self.unfold_stride) # (N, C*win_size*win_size, num_win)
725 | keyVal = keyVal.permute(0,2,1).reshape(-1, C, self.win_size*self.win_size).permute(0,2,1) # (num_win*B, win_size*win_size, C)
726 |
727 | # print('win query ', query_win.shape)
728 | # print('win keyVal ', keyVal.shape)
729 | # print('-')
730 |
731 | residual = self.dropout(self.crossAttn(query_win, keyVal, keyVal)[0])
732 | residual = residual.view(-1, self.win_size, self.win_size, C)
733 | residual = window_reverse(residual, self.win_size, qH_, qW_) # (N, H, W, C)
734 |
735 | out = query + residual
736 | out = out[:, :qH, :qW, :]
737 | return out
738 |
739 |
740 | def forward(self, x):
741 | '''
742 | x: input batches with N*C*H*W
743 | '''
744 | N, C, H, W = x.shape
745 | x = x.permute(0, 2, 3, 1) # N*H*W*C
746 | x = self.pad_fit_win(x, self.win_size) # pad
747 |
748 | # layered self-attention
749 | layerAttnList = []
750 | strideListLen = len(self.stride_list)
751 | for idx in range(strideListLen):
752 | x_attn = self.self_attention(x) # built-in shortcut
753 | x_attn = self.layerNormSelf(x_attn)
754 | layerAttnList.append(x_attn)
755 |
756 | if idx < strideListLen - 1:
757 | x = self.reduce(x_attn.permute(0, 3, 1 ,2)) # N*C*H*W
758 | x = x.permute(0, 2, 3, 1) # N*H*W*C
759 |
760 | # layered cross-attention
761 | KeyVal = layerAttnList[-1]
762 | for idx in range(strideListLen-1, 0, -1):
763 | Query = layerAttnList[idx-1]
764 | Query = self.cross_attention(Query, KeyVal) # built-in shortcut
765 | Query = self.layerNormCross(Query)
766 |
767 | KeyVal = Query
768 |
769 | Query = Query[:, :H, :W, :] # unpad
770 |
771 | q_residual = self.dropout(self.FFN(Query))
772 | x_new = self.layer_norm_out(Query + q_residual)
773 |
774 | outputs = x_new.permute(0, 3, 1, 2)
775 | return outputs
776 |
777 |
778 |
779 | class BasicLayerEncoderBlock(nn.Module):
780 |
781 | def __init__(self, win_size, d_model, num_heads, d_ff, dropout=0.):
782 | super().__init__()
783 |
784 | self.win_size = win_size
785 | self.down_factor = 2
786 | self.unfold_stride = int(self.win_size//self.down_factor)
787 |
788 | self.stride_list = [math.floor(win_size/self.down_factor**idx) for idx in range(8) if win_size/self.down_factor**idx >= 1]
789 | # [16, 8, 4, 2, 1]
790 |
791 | self.reduce = nn.Sequential(
792 | nn.AvgPool2d(self.down_factor, self.down_factor)
793 | )
794 |
795 | self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout)
796 | self.crossAttn = MultiHeadAttention(d_model, num_heads, dropout)
797 |
798 | self.dropout = nn.Dropout(dropout)
799 | self.layerNormSelf = nn.LayerNorm(d_model)
800 | self.layerNormCross = nn.LayerNorm(d_model)
801 |
802 | self.FFN = nn.Sequential(
803 | nn.Linear(d_model, d_ff),
804 | nn.ReLU(),
805 | nn.Linear(d_ff, d_model)
806 | )
807 |
808 | self.layer_norm_out = nn.LayerNorm(d_model)
809 |
810 |
811 | def Circular_pad2D(self, x, pad_right, pad_bottom):
812 | '''
813 | x: (N, H, W, C)
814 | x_pad: (N, H_pad, W_pad, C)
815 | '''
816 | N, H, W, C = x.shape
817 |
818 | H_pad = H + pad_bottom
819 | W_pad = W + pad_right
820 |
821 | H_repeat = math.ceil(H_pad/H)
822 | W_repeat = math.ceil(W_pad/W)
823 | x_repeat = x.repeat(1, H_repeat, W_repeat, 1)
824 |
825 | x_pad = x_repeat[:, :H_pad, :W_pad, :]
826 | return x_pad
827 |
828 |
829 | def pad_fit_win(self, x, win_size):
830 | N, H, W, C = x.shape
831 |
832 | W_ = math.ceil(W/win_size)*win_size
833 | H_ = math.ceil(H/win_size)*win_size
834 | padRight = W_ - W
835 | padBottom = H_ - H
836 |
837 | x_pad = self.Circular_pad2D(x, padRight, padBottom) # N*H_*W_*C
838 | return x_pad
839 |
840 |
841 | def self_attention(self, x):
842 | '''
843 | x: (N, H, W, C)
844 | out: (N, H, W, C)
845 | '''
846 | N, H, W, C = x.shape
847 | x_pad = self.pad_fit_win(x, self.win_size) # N*H_*W_*C
848 | _, H_, W_, _ = x_pad.shape
849 |
850 | # x_pad = F.pad(x.permute(xxx), (0, padRight, 0, padBottom), mode='reflect') # N*C*H_*W_
851 |
852 | x_window = window_partition(x_pad, self.win_size) # (num_win*B, win_size, win_size, C)
853 | x_window = x_window.view(-1, self.win_size*self.win_size, C) # (num_win*B, win_size*win_size, C)
854 |
855 | # self-attention
856 | residual = self.dropout(self.selfAttn(x_window, x_window, x_window)[0])
857 | residual = residual.view(-1, self.win_size, self.win_size, C)
858 | residual = window_reverse(residual, self.win_size, H_, W_) # (N, H_, W_, C)
859 |
860 | out = x_pad + residual
861 | out = out[:, :H, :W, :]
862 | return out
863 |
864 |
865 | def cross_attention(self, query, keyVal, query_win_size):
866 | '''
867 | query: (N, qH, qW, C)
868 | keyVal: (N, kH, kW, C)
869 | out: (N, qH, qW, C)
870 | '''
871 | _, qH, qW, C = query.shape
872 |
873 | query_win = window_partition(query, query_win_size)
874 | query_win = query_win.view(-1, query_win_size*query_win_size, C) # (num_win*B, win_size*win_size, C)
875 |
876 | keyWinSize = query_win_size // 2
877 | keyVal_win = window_partition(keyVal, keyWinSize)
878 |         keyVal_win = keyVal_win.view(-1, keyWinSize*keyWinSize, C) # (num_win*B, keyWinSize*keyWinSize, C)
879 |
880 | residual = self.dropout(self.crossAttn(query_win, keyVal_win, keyVal_win)[0])
881 | residual = residual.view(-1, query_win_size, query_win_size, C)
882 | residual = window_reverse(residual, query_win_size, qH, qW) # (N, H, W, C)
883 |
884 | out = query + residual
885 | return out
886 |
887 |
888 | def forward(self, x):
889 | '''
890 | x: input batches with N*C*H*W
891 | '''
892 | N, C, H, W = x.shape
893 | x = x.permute(0, 2, 3, 1) # N*H*W*C
894 | x = self.pad_fit_win(x, self.win_size) # pad
895 |
896 | # layered self-attention
897 | layerAttnList = []
898 | strideListLen = len(self.stride_list)
899 | for idx in range(strideListLen):
900 | x_attn = self.self_attention(x) # built-in shortcut
901 | x_attn = self.layerNormSelf(x_attn)
902 | layerAttnList.append(x_attn)
903 |
904 | if idx < strideListLen - 1:
905 |                 x = self.reduce(x_attn.permute(0, 3, 1, 2)) # N*C*H*W
906 | x = x.permute(0, 2, 3, 1) # N*H*W*C
907 |
908 | # layered cross-attention
909 | KeyVal = layerAttnList[-1]
910 | for idx in range(strideListLen-1, 0, -1):
911 | Query = layerAttnList[idx-1]
912 | QueryWinSize = self.stride_list[idx-1]
913 |
914 | Query = self.cross_attention(Query, KeyVal, QueryWinSize) # built-in shortcut
915 | Query = self.layerNormCross(Query)
916 |
917 | KeyVal = Query
918 |
919 | Query = Query[:, :H, :W, :] # unpad
920 |
921 | q_residual = self.dropout(self.FFN(Query))
922 | x_new = self.layer_norm_out(Query + q_residual)
923 |
924 | outputs = x_new.permute(0, 3, 1, 2)
925 | return outputs
926 |
927 |
928 |
929 |
930 | class PositionalEncoding(nn.Module):
931 |
932 | def __init__(self, d_model, dropout=0.):
933 | super().__init__()
934 |
935 | self.max_len = 256
936 | self.d_model = d_model
937 |
938 | self._update_PE_table(self.max_len, self.d_model//2)
939 |
940 |
941 | def _update_PE_table(self, max_len, d_model):
942 | self.PE_table = torch.zeros(max_len, d_model)
943 |
944 | pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
945 | denominator = torch.pow(10000, torch.arange(0, d_model, 2).float()/d_model)
946 |
947 | self.PE_table[:, 0::2] = torch.sin(pos/denominator)
948 | self.PE_table[:, 1::2] = torch.cos(pos/denominator)
949 |
950 |
951 | def forward(self, x):
952 | ''' x: image batches with N*C*H*W '''
953 |
954 | N, C, H, W = x.shape
955 | max_hw = max(H, W)
956 |
957 | if max_hw > self.max_len or self.d_model != C:
958 | self.max_len = max_hw
959 | self.d_model = C
960 |
961 | self._update_PE_table(self.max_len, self.d_model//2)
962 |
963 | if self.PE_table.device != x.device:
964 | self.PE_table = self.PE_table.to(x.device)
965 |
966 | h_pos_emb = self.PE_table[:H, :].unsqueeze(1).repeat(1, W, 1) # H*W*C/2
967 | w_pos_emb = self.PE_table[:W, :].unsqueeze(0).repeat(H, 1, 1) # H*W*C/2
968 | pos_emb = torch.cat([h_pos_emb, w_pos_emb], dim=-1
969 | ).permute([2,0,1]).unsqueeze(0).repeat(N,1,1,1) # N*C*H*W
970 |
971 | output = x + pos_emb
972 | return output
973 |
974 |
975 |
976 | class TransformerEncoder(nn.Module):
977 |
978 | def __init__(self, anchor_dist, num_blocks, d_model, num_heads, d_ff, dropout=0.):
979 | super().__init__()
980 |
981 | self.anchor_dist = anchor_dist
982 |
983 | blocks_list = []
984 | for idx in range(num_blocks):
985 | # blocks_list.append( AnchorEncoderBlock(anchor_dist, d_model, num_heads, d_ff, dropout) )
986 | # blocks_list.append( EncoderBlock(d_model, num_heads, d_ff, dropout) )
987 | blocks_list.append( ReduceEncoderBlock(d_model, num_heads, d_ff, dropout) )
988 | # blocks_list.append( BasicLayerEncoderBlock(anchor_dist, d_model, num_heads, d_ff, dropout) )
989 |
990 | self.blocks = nn.Sequential(*blocks_list)
991 |
992 | self.posEmbedding = PositionalEncoding(d_model, dropout)
993 |
994 |
995 | def forward(self, x):
996 | x_w_pos = self.posEmbedding(x)
997 | x_updated = self.blocks(x_w_pos)
998 |
999 | return x_updated
1000 |
1001 |
1002 |
1003 | class RawInputTransEncoder(nn.Module):
1004 |
1005 | def __init__(self, anchor_dist, num_blocks, d_model, num_heads, d_ff, dropout=0.):
1006 | super().__init__()
1007 |
1008 | self.anchor_dist = anchor_dist
1009 |
1010 | self.linear = nn.Conv2d(3, d_model, 8, 8)
1011 |
1012 | blocks_list = []
1013 | for idx in range(num_blocks):
1014 | # blocks_list.append( AnchorEncoderBlock(anchor_dist, d_model, num_heads, d_ff, dropout) )
1015 | # blocks_list.append( EncoderBlock(d_model, num_heads, d_ff, dropout) )
1016 | # blocks_list.append( ReduceEncoderBlock(d_model, num_heads, d_ff, dropout) )
1017 | blocks_list.append( LayerEncoderBlock(anchor_dist, d_model, num_heads, d_ff, dropout) )
1018 |
1019 | self.blocks = nn.Sequential(*blocks_list)
1020 |
1021 | self.posEmbedding = PositionalEncoding(d_model, dropout)
1022 |
1023 | # initialization
1024 | for m in self.modules():
1025 | if isinstance(m, nn.Conv2d):
1026 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
1027 | elif isinstance(m, nn.Linear):
1028 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
1029 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm, nn.LayerNorm)):
1030 | if m.weight is not None:
1031 | nn.init.constant_(m.weight, 1)
1032 | if m.bias is not None:
1033 | nn.init.constant_(m.bias, 0)
1034 |
1035 |
1036 | def forward(self, x):
1037 | # if input is list, combine batch dimension
1038 | is_list = isinstance(x, tuple) or isinstance(x, list)
1039 | if is_list:
1040 | batch_dim = x[0].shape[0]
1041 | x = torch.cat(x, dim=0)
1042 |
1043 | x = self.linear(x)
1044 | x_w_pos = self.posEmbedding(x)
1045 | x_updated = self.blocks(x_w_pos)
1046 |
1047 | if is_list:
1048 | x_updated = torch.split(x_updated, [batch_dim, batch_dim], dim=0)
1049 |
1050 | return x_updated
1051 |
1052 |
1053 |
1054 | class GlobalLocalBlock(nn.Module):
1055 |
1056 | def __init__(self, anchor_dist, d_model, num_heads, out_dim, dropout=0., stride=1):
1057 | super().__init__()
1058 |
1059 | self.anchor_dist = anchor_dist
1060 | self.half_anchor_dist = anchor_dist // 2
1061 | self.d_model = d_model
1062 | self.out_dim = out_dim
1063 |
1064 | self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout)
1065 | self.dropout = nn.Dropout(dropout)
1066 | self.layer_norm_1 = nn.LayerNorm(d_model)
1067 |
1068 | self.resBlock_1 = ResidualBlock(d_model, d_model, norm_fn='instance', stride=stride)
1069 | self.change_channel = nn.Linear(d_model, out_dim)
1070 | self.resBlock_2 = ResidualBlock(out_dim, out_dim, norm_fn='instance', stride=1)
1071 |
1072 | self.posEmbedding = PositionalEncoding(d_model, dropout)
1073 |
1074 |
1075 | def forward(self, inputs):
1076 | '''
1077 |         inputs: input batches with N*C*H*W
1078 | '''
1079 |
1080 | # local update 1
1081 | x = self.resBlock_1(inputs)
1082 | x = self.posEmbedding(x)
1083 | anchors = x[:,:, self.half_anchor_dist::self.anchor_dist,
1084 | self.half_anchor_dist::self.anchor_dist].clone()
1085 |
1086 | # flatten feature maps
1087 | N, C, H, W = x.shape
1088 | x = x.reshape(N, C, H*W).transpose(-1,-2)
1089 | anchors = anchors.reshape(N, C, anchors.shape[2]* anchors.shape[3]).transpose(-1,-2)
1090 |
1091 |         # global update with two-stage multi-head self-attention
1092 | anchors_new = self.dropout(self.selfAttn(anchors, x, x)[0])
1093 | residual = self.dropout(self.selfAttn(x, anchors_new, anchors_new)[0])
1094 | norm_1 = self.layer_norm_1(x + residual)
1095 |
1096 | # local update 2
1097 | norm_1 = self.change_channel(norm_1)
1098 | norm_1 = norm_1.transpose(-1,-2).reshape(N, self.out_dim, H, W)
1099 | outputs = self.resBlock_2(norm_1)
1100 |
1101 | return outputs
1102 |
1103 |
1104 |
1105 | class GlobalLocalEncoder(nn.Module):
1106 |
1107 | def __init__(self, anchor_dist, output_dim, dropout=0.):
1108 | super().__init__()
1109 |
1110 | self.anchor_dist = anchor_dist
1111 | self.output_dim = output_dim
1112 |
1113 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
1114 | self.norm1 = nn.InstanceNorm2d(64)
1115 | self.relu1 = nn.ReLU(inplace=True)
1116 |
1117 | self.layer1 = GlobalLocalBlock(self.anchor_dist, 64, 2, 96, dropout, stride=2)
1118 | self.layer2 = GlobalLocalBlock(self.anchor_dist, 96, 3, 96, dropout, stride=1)
1119 | self.layer3 = GlobalLocalBlock(self.anchor_dist//2, 96, 4, 128, dropout, stride=2)
1120 |
1121 | # output convolution
1122 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
1123 |
1124 | self.dropout = None
1125 | if dropout > 0:
1126 | self.dropout = nn.Dropout2d(p=dropout)
1127 |
1128 | # initialization
1129 | for m in self.modules():
1130 | if isinstance(m, nn.Conv2d):
1131 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
1132 | elif isinstance(m, nn.Linear):
1133 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
1134 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm, nn.LayerNorm)):
1135 | if m.weight is not None:
1136 | nn.init.constant_(m.weight, 1)
1137 | if m.bias is not None:
1138 | nn.init.constant_(m.bias, 0)
1139 |
1140 |
1141 | # def _make_layer(self, in_dim, out_dim, dropout=0., stride=1):
1142 | # layer1 = GlobalLocalBlock(self.anchor_dist, in_dim, in_dim//32, out_dim, dropout=0., stride=stride)
1143 | # layer2 = GlobalLocalBlock(self.anchor_dist, out_dim, out_dim//32, out_dim, dropout=0., stride=1)
1144 | # layers = (layer1, layer2)
1145 |
1146 | # return nn.Sequential(*layers)
1147 |
1148 |
1149 | def forward(self, x):
1150 | # if input is list, combine batch dimension
1151 | is_list = isinstance(x, tuple) or isinstance(x, list)
1152 | if is_list:
1153 | batch_dim = x[0].shape[0]
1154 | x = torch.cat(x, dim=0)
1155 |
1156 | x = self.relu1(self.norm1(self.conv1(x)))
1157 |
1158 | x = self.layer1(x)
1159 | x = self.layer2(x)
1160 | x = self.layer3(x)
1161 |
1162 | x = self.conv2(x)
1163 |
1164 | if self.training and self.dropout is not None:
1165 | x = self.dropout(x)
1166 |
1167 | if is_list:
1168 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
1169 |
1170 | return x
--------------------------------------------------------------------------------
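
Note (editorial sketch, not part of the repository): the `Circular_pad2D` and `pad_fit_win` helpers above pad a feature map by tiling it circularly until its height and width are multiples of the attention window. A minimal standalone version of that padding, with a hypothetical function name:

```python
import math
import torch

def circular_pad_to_window(x: torch.Tensor, win_size: int) -> torch.Tensor:
    """Circularly pad an (N, H, W, C) tensor so H and W become multiples of win_size."""
    N, H, W, C = x.shape
    H_pad = math.ceil(H / win_size) * win_size
    W_pad = math.ceil(W / win_size) * win_size
    # tile the tensor enough times along H and W, then crop to the padded size
    x_rep = x.repeat(1, math.ceil(H_pad / H), math.ceil(W_pad / W), 1)
    return x_rep[:, :H_pad, :W_pad, :]

x = torch.randn(1, 30, 45, 8)
print(circular_pad_to_window(x, 16).shape)  # torch.Size([1, 32, 48, 8])
```
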
/core/gma.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from einops import rearrange
4 |
5 |
6 | class RelPosEmb(nn.Module):
7 | def __init__(
8 | self,
9 | max_pos_size,
10 | dim_head
11 | ):
12 | super().__init__()
13 | self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head)
14 | self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head)
15 |
16 | deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(max_pos_size).view(-1, 1)
17 | rel_ind = deltas + max_pos_size - 1
18 | self.register_buffer('rel_ind', rel_ind)
19 |
20 | def forward(self, q):
21 | batch, heads, h, w, c = q.shape
22 | height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1))
23 | width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1))
24 |
25 | height_emb = rearrange(height_emb, '(x u) d -> x u () d', x=h)
26 | width_emb = rearrange(width_emb, '(y v) d -> y () v d', y=w)
27 |
28 | height_score = einsum('b h x y d, x u v d -> b h x y u v', q, height_emb)
29 | width_score = einsum('b h x y d, y u v d -> b h x y u v', q, width_emb)
30 |
31 | return height_score + width_score
32 |
33 |
34 | class Attention(nn.Module):
35 | def __init__(
36 | self,
37 | *,
38 | args,
39 | dim,
40 | max_pos_size = 100,
41 | heads = 4,
42 | dim_head = 128,
43 | ):
44 | super().__init__()
45 | self.args = args
46 | self.heads = heads
47 | self.scale = dim_head ** -0.5
48 | inner_dim = heads * dim_head
49 |
50 | self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)
51 |
52 | self.pos_emb = RelPosEmb(max_pos_size, dim_head)
53 |
54 | def forward(self, fmap):
55 | heads, b, c, h, w = self.heads, *fmap.shape
56 |
57 | q, k = self.to_qk(fmap).chunk(2, dim=1)
58 |
59 | q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k))
60 | q = self.scale * q
61 |
62 | if self.args.position_only:
63 | sim = self.pos_emb(q)
64 |
65 | elif self.args.position_and_content:
66 | sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k)
67 | sim_pos = self.pos_emb(q)
68 | sim = sim_content + sim_pos
69 |
70 | else:
71 | sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k)
72 |
73 | sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)')
74 | attn = sim.softmax(dim=-1)
75 |
76 | return attn
77 |
78 |
79 | class Aggregate(nn.Module):
80 | def __init__(
81 | self,
82 | args,
83 | dim,
84 | heads = 4,
85 | dim_head = 128,
86 | ):
87 | super().__init__()
88 | self.args = args
89 | self.heads = heads
90 | self.scale = dim_head ** -0.5
91 | inner_dim = heads * dim_head
92 |
93 | self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False)
94 |
95 | self.gamma = nn.Parameter(torch.zeros(1))
96 |
97 | if dim != inner_dim:
98 | self.project = nn.Conv2d(inner_dim, dim, 1, bias=False)
99 | else:
100 | self.project = None
101 |
102 | def forward(self, attn, fmap):
103 | heads, b, c, h, w = self.heads, *fmap.shape
104 |
105 | v = self.to_v(fmap)
106 | v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads)
107 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
108 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
109 |
110 | if self.project is not None:
111 | out = self.project(out)
112 |
113 | out = fmap + self.gamma * out
114 |
115 | return out
116 |
117 |
118 | if __name__ == "__main__":
119 |     from types import SimpleNamespace  # Attention requires the keyword-only 'args'
120 |     att = Attention(args=SimpleNamespace(position_only=False, position_and_content=False), dim=128, heads=1)
121 |     fmap = torch.randn(2, 128, 40, 90)
122 |     out = att(fmap)
123 |     print(out.shape)
124 |
--------------------------------------------------------------------------------
/core/gmflownet_model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from update import BasicUpdateBlock
8 | from extractor import BasicEncoder, BasicConvEncoder
9 | from corr import CorrBlock, AlternateCorrBlock
10 | from utils.utils import bilinear_sampler, coords_grid, upflow8
11 | from swin_transformer import POLAUpdate, MixAxialPOLAUpdate
12 |
13 | try:
14 | autocast = torch.cuda.amp.autocast
15 | except:
16 | # dummy autocast for PyTorch < 1.6
17 | class autocast:
18 | def __init__(self, enabled):
19 | pass
20 | def __enter__(self):
21 | pass
22 | def __exit__(self, *args):
23 | pass
24 |
25 |
26 | class GMFlowNetModel(nn.Module):
27 | def __init__(self, args):
28 | super().__init__()
29 | self.args = args
30 |
31 | self.hidden_dim = hdim = 128
32 | self.context_dim = cdim = 128
33 | args.corr_levels = 4
34 | args.corr_radius = 4
35 |
36 | if not hasattr(self.args, 'dropout'):
37 | self.args.dropout = 0
38 |
39 | if not hasattr(self.args, 'alternate_corr'):
40 | self.args.alternate_corr = False
41 |
42 | # feature network, context network, and update block
43 | if self.args.use_mix_attn:
44 | self.fnet = nn.Sequential(
45 | BasicConvEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout),
46 | MixAxialPOLAUpdate(embed_dim=256, depth=6, num_head=8, window_size=7)
47 | )
48 | else:
49 | self.fnet = nn.Sequential(
50 | BasicConvEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout),
51 | POLAUpdate(embed_dim=256, depth=6, num_head=8, window_size=7, neig_win_num=1)
52 | )
53 |
54 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
55 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim, input_dim=cdim)
56 |
57 |
58 | def freeze_bn(self):
59 | for m in self.modules():
60 | if isinstance(m, nn.BatchNorm2d):
61 | m.eval()
62 |
63 | def initialize_flow(self, img):
64 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
65 | N, C, H, W = img.shape
66 | coords0 = coords_grid(N, H//8, W//8).to(img.device)
67 | coords1 = coords_grid(N, H//8, W//8).to(img.device)
68 |
69 | # optical flow computed as difference: flow = coords1 - coords0
70 | return coords0, coords1
71 |
72 | def upsample_flow(self, flow, mask):
73 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
74 | N, _, H, W = flow.shape
75 | mask = mask.view(N, 1, 9, 8, 8, H, W)
76 | mask = torch.softmax(mask, dim=2)
77 |
78 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
79 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
80 |
81 | up_flow = torch.sum(mask * up_flow, dim=2)
82 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
83 | return up_flow.reshape(N, 2, 8*H, 8*W)
84 |
85 |
86 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
87 | """ Estimate optical flow between pair of frames """
88 |
89 | image1 = 2 * (image1 / 255.0) - 1.0
90 | image2 = 2 * (image2 / 255.0) - 1.0
91 |
92 | image1 = image1.contiguous()
93 | image2 = image2.contiguous()
94 |
95 | hdim = self.hidden_dim
96 | cdim = self.context_dim
97 |
98 | # run the feature network
99 | with autocast(enabled=self.args.mixed_precision):
100 | fmap1, fmap2 = self.fnet([image1, image2])
101 |
102 | fmap1 = fmap1.float()
103 | fmap2 = fmap2.float()
104 |
105 | # # Self-attention update
106 | # fmap1 = self.transEncoder(fmap1)
107 | # fmap2 = self.transEncoder(fmap2)
108 |
109 | if self.args.alternate_corr:
110 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
111 | else:
112 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
113 |
114 | # run the context network
115 | with autocast(enabled=self.args.mixed_precision):
116 | cnet = self.cnet(image1)
117 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
118 | net = torch.tanh(net)
119 | inp = torch.relu(inp)
120 |
121 | coords0, coords1 = self.initialize_flow(image1)
122 |
123 | # Correlation as initialization
124 | N, fC, fH, fW = fmap1.shape
125 | corrMap = corr_fn.corrMap
126 |
127 | #_, coords_index = torch.max(corrMap, dim=-1) # no gradient here
128 | softCorrMap = F.softmax(corrMap, dim=2) * F.softmax(corrMap, dim=1) # (N, fH*fW, fH*fW)
129 |
130 | if flow_init is not None:
131 | coords1 = coords1 + flow_init
132 | else:
133 | # print('matching as init')
134 | # mutual match selection
135 | match12, match_idx12 = softCorrMap.max(dim=2) # (N, fH*fW)
136 | match21, match_idx21 = softCorrMap.max(dim=1)
137 |
138 | for b_idx in range(N):
139 | match21_b = match21[b_idx,:]
140 | match_idx12_b = match_idx12[b_idx,:]
141 | match21[b_idx,:] = match21_b[match_idx12_b]
142 |
143 | matched = (match12 - match21) == 0 # (N, fH*fW)
144 | coords_index = torch.arange(fH*fW).unsqueeze(0).repeat(N,1).to(softCorrMap.device)
145 | coords_index[matched] = match_idx12[matched]
146 |
147 | # matched coords
148 | coords_index = coords_index.reshape(N, fH, fW)
149 | coords_x = coords_index % fW
150 | coords_y = coords_index // fW
151 |
152 | coords_xy = torch.stack([coords_x, coords_y], dim=1).float()
153 | coords1 = coords_xy
154 |
155 | # Iterative update
156 | flow_predictions = []
157 | for itr in range(iters):
158 | coords1 = coords1.detach()
159 | corr = corr_fn(coords1) # index correlation volume
160 |
161 | flow = coords1 - coords0
162 | with autocast(enabled=self.args.mixed_precision):
163 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
164 |
165 | # F(t+1) = F(t) + \Delta(t)
166 | coords1 = coords1 + delta_flow
167 |
168 | # upsample predictions
169 | if up_mask is None:
170 | flow_up = upflow8(coords1 - coords0)
171 | else:
172 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
173 |
174 | flow_predictions.append(flow_up)
175 |
176 | if test_mode:
177 | return coords1 - coords0, flow_up
178 |
179 | return flow_predictions, softCorrMap
180 |
--------------------------------------------------------------------------------
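
Note (editorial sketch, not part of the repository): when no `flow_init` is given, `GMFlowNetModel.forward` initializes `coords1` from the soft correlation map by keeping only mutually consistent argmax matches, i.e. positions where the forward and backward argmax agree; everywhere else the identity coordinate (zero flow) is kept. A self-contained toy version of that initialization, with a hypothetical helper name:

```python
import torch

def mutual_match_init(softCorrMap: torch.Tensor, fH: int, fW: int) -> torch.Tensor:
    """softCorrMap: (N, fH*fW, fH*fW) -> matched target coordinates (N, 2, fH, fW)."""
    N = softCorrMap.shape[0]
    match12, idx12 = softCorrMap.max(dim=2)    # best target index/score for every source cell
    match21, _ = softCorrMap.max(dim=1)        # best score achieved by every target cell
    match21 = torch.gather(match21, 1, idx12)  # that best score, looked up at the forward argmax
    matched = (match12 - match21) == 0         # forward and backward argmax agree

    coords_index = torch.arange(fH * fW, device=softCorrMap.device).unsqueeze(0).repeat(N, 1)
    coords_index[matched] = idx12[matched]     # mutual matches move; the rest keep identity (zero flow)
    coords_index = coords_index.reshape(N, fH, fW)
    coords_x = coords_index % fW
    coords_y = coords_index // fW
    return torch.stack([coords_x, coords_y], dim=1).float()

corr = torch.softmax(torch.rand(2, 6 * 8, 6 * 8), dim=2)
print(mutual_match_init(corr, 6, 8).shape)  # torch.Size([2, 2, 6, 8])
```
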
/core/loss.py:
--------------------------------------------------------------------------------
1 | # from operator import index
2 | import torch
3 | import sys
4 | import numpy as np
5 | # from .utils.iamge_utils import save_flow, save_image
6 | import torch.nn.functional as F
7 | # from .utils.kornia import create_meshgrid
8 |
9 |
10 | from typing import Optional
11 | def create_meshgrid(
12 | height: int,
13 | width: int,
14 | normalized_coordinates: bool = True,
15 | device: Optional[torch.device] = torch.device('cpu'),
16 | dtype: torch.dtype = torch.float32,
17 | ) -> torch.Tensor:
18 | """Generate a coordinate grid for an image.
19 |
20 | When the flag ``normalized_coordinates`` is set to True, the grid is
21 | normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch
22 | function :py:func:`torch.nn.functional.grid_sample`.
23 |
24 | Args:
25 | height: the image height (rows).
26 | width: the image width (cols).
27 | normalized_coordinates: whether to normalize
28 | coordinates in the range :math:`[-1,1]` in order to be consistent with the
29 | PyTorch function :py:func:`torch.nn.functional.grid_sample`.
30 | device: the device on which the grid will be generated.
31 | dtype: the data type of the generated grid.
32 |
33 | Return:
34 | grid tensor with shape :math:`(1, H, W, 2)`.
35 |
36 | Example:
37 | >>> create_meshgrid(2, 2)
38 | tensor([[[[-1., -1.],
39 | [ 1., -1.]],
40 |
41 | [[-1., 1.],
42 | [ 1., 1.]]]])
43 |
44 | >>> create_meshgrid(2, 2, normalized_coordinates=False)
45 | tensor([[[[0., 0.],
46 | [1., 0.]],
47 |
48 | [[0., 1.],
49 | [1., 1.]]]])
50 | """
51 | xs: torch.Tensor = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
52 | ys: torch.Tensor = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
53 | # Fix TracerWarning
54 |     # Note: normalize_pixel_coordinates still gets TracerWarning since new width and height
55 | # tensors will be generated.
56 | # Below is the code using normalize_pixel_coordinates:
57 | # base_grid: torch.Tensor = torch.stack(torch.meshgrid([xs, ys]), dim=2)
58 | # if normalized_coordinates:
59 | # base_grid = K.geometry.normalize_pixel_coordinates(base_grid, height, width)
60 | # return torch.unsqueeze(base_grid.transpose(0, 1), dim=0)
61 | if normalized_coordinates:
62 | xs = (xs / (width - 1) - 0.5) * 2
63 | ys = (ys / (height - 1) - 0.5) * 2
64 | # generate grid by stacking coordinates
65 | base_grid: torch.Tensor = torch.stack(torch.meshgrid([xs, ys])).transpose(1, 2) # 2xHxW
66 | return torch.unsqueeze(base_grid, dim=0).permute(0, 2, 3, 1) # 1xHxWx2
67 |
68 |
69 | def backwarp(img, flow):
70 | _, _, H, W = img.size()
71 |
72 | u = flow[:, 0, :, :]
73 | v = flow[:, 1, :, :]
74 |
75 | gridX, gridY = np.meshgrid(np.arange(W), np.arange(H))
76 | gridX = torch.tensor(gridX, requires_grad=False,).to(flow.device)
77 | gridY = torch.tensor(gridY, requires_grad=False,).to(flow.device)
78 | x = gridX.unsqueeze(0).expand_as(u).float() + u
79 | y = gridY.unsqueeze(0).expand_as(v).float() + v
80 | # range -1 to 1
81 | x = 2*(x/W - 0.5)
82 | y = 2*(y/H - 0.5)
83 | # stacking X and Y
84 | grid = torch.stack((x,y), dim=3)
85 | # Sample pixels using bilinear interpolation.
86 | imgOut = torch.nn.functional.grid_sample(img, grid, align_corners=False)
87 |
88 | return imgOut
89 |
90 |
91 | @torch.no_grad()
92 | def compute_supervision_coarse(flow, occlusions, scale: int):
93 | N, _, H, W = flow.shape
94 | Hc, Wc = int(np.ceil(H / scale)), int(np.ceil(W / scale))
95 |
96 | occlusions_c = occlusions[:, :, ::scale, ::scale]
97 | flow_c = flow[:, :, ::scale, ::scale] / scale
98 | occlusions_c = occlusions_c.reshape(N, Hc * Wc)
99 |
100 | grid_c = create_meshgrid(Hc, Wc, False, device=flow.device).reshape(1, Hc * Wc, 2).repeat(N, 1, 1)
101 | warp_c = grid_c + flow_c.permute(0, 2, 3, 1).reshape(N, Hc * Wc, 2)
102 | warp_c = warp_c.round().long()
103 |
104 | def out_bound_mask(pt, w, h):
105 | return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
106 |
107 | occlusions_c[out_bound_mask(warp_c, Wc, Hc)] = 1
108 | warp_c = warp_c[..., 0] + warp_c[..., 1] * Wc
109 |
110 | b_ids, i_ids = torch.split(torch.nonzero(occlusions_c == 0), 1, dim=1)
111 | conf_matrix_gt = torch.zeros(N, Hc * Wc, Hc * Wc, device=flow.device)
112 | j_ids = warp_c[b_ids, i_ids]
113 | conf_matrix_gt[b_ids, i_ids, j_ids] = 1
114 |
115 | return conf_matrix_gt
116 |
117 |
118 | def compute_coarse_loss(conf, conf_gt, cfg):
119 | c_pos_w, c_neg_w = cfg.POS_WEIGHT, cfg.NEG_WEIGHT
120 | pos_mask, neg_mask = conf_gt == 1, conf_gt == 0
121 |
122 | if cfg.COARSE_TYPE == 'cross_entropy':
123 | conf = torch.clamp(conf, 1e-6, 1 - 1e-6)
124 | loss_pos = -torch.log(conf[pos_mask])
125 | loss_neg = -torch.log(1 - conf[neg_mask])
126 |
127 | return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
128 | elif cfg.COARSE_TYPE == 'focal':
129 | conf = torch.clamp(conf, 1e-6, 1 - 1e-6)
130 | alpha = cfg.FOCAL_ALPHA
131 | gamma = cfg.FOCAL_GAMMA
132 | loss_pos = -alpha * torch.pow(1 - conf[pos_mask], gamma) * (conf[pos_mask]).log()
133 | loss_neg = -alpha * torch.pow(conf[neg_mask], gamma) * (1 - conf[neg_mask]).log()
134 | return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
135 | else:
136 | raise ValueError('Unknown coarse loss: {type}'.format(type=cfg.COARSE_TYPE))
137 |
138 |
139 | def compute_fine_loss(kflow, kflow_gt, cfg):
140 | fine_correct_thr = cfg.WINDOW_SIZE // 2 * 2
141 | error = (kflow - kflow_gt).abs()
142 | correct = torch.max(error, dim=1)[0] < fine_correct_thr
143 | rate = torch.sum(correct).float() / correct.shape[0]
144 | num = correct.shape[0]
145 | return error[correct].mean(), rate.item(), num
146 |
147 |
148 | def compute_flow_loss(flow, flow_gt):
149 | loss = (flow - flow_gt).abs().mean()
150 | epe = torch.sum((flow - flow_gt)**2, dim=1).sqrt()
151 |
152 | metrics = {
153 | 'epe': epe.mean().item(),
154 | '1px': (epe < 1).float().mean().item(),
155 | '3px': (epe < 3).float().mean().item(),
156 | '5px': (epe < 5).float().mean().item(),
157 | }
158 |
159 | return loss, metrics
160 |
--------------------------------------------------------------------------------
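
Note (editorial usage sketch, not part of the repository): `compute_coarse_loss` above expects a config object exposing `POS_WEIGHT`, `NEG_WEIGHT` and `COARSE_TYPE` (plus `FOCAL_ALPHA`/`FOCAL_GAMMA` for the focal variant). A toy call with a stand-in namespace; the import path from the repository root is an assumption:

```python
from types import SimpleNamespace
import torch

from core.loss import compute_coarse_loss  # assumed import path

cfg = SimpleNamespace(POS_WEIGHT=1.0, NEG_WEIGHT=1.0, COARSE_TYPE='cross_entropy',
                      FOCAL_ALPHA=0.25, FOCAL_GAMMA=2.0)

conf_gt = (torch.rand(2, 36, 36) > 0.9).float()  # sparse ground-truth matches, shape (N, Hc*Wc, Hc*Wc)
conf = torch.rand(2, 36, 36)                     # predicted soft correlation in (0, 1)
loss = compute_coarse_loss(conf, conf_gt, cfg)
print(loss.item())
```
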
/core/onecyclelr.py:
--------------------------------------------------------------------------------
1 | import types
2 | import math
3 | from torch._six import inf
4 | from functools import wraps
5 | import warnings
6 | import weakref
7 | from collections import Counter
8 | from bisect import bisect_right
9 |
10 | EPOCH_DEPRECATION_WARNING = (
11 | "The epoch parameter in `scheduler.step()` was not necessary and is being "
12 | "deprecated where possible. Please use `scheduler.step()` to step the "
13 | "scheduler. During the deprecation, if epoch is different from None, the "
14 | "closed form is used instead of the new chainable form, where available. "
15 | "Please open an issue if you are unable to replicate your use case: "
16 | "https://github.com/pytorch/pytorch/issues/new/choose."
17 | )
18 |
19 | class _LRScheduler(object):
20 |
21 | def __init__(self, optimizer, last_epoch=-1):
22 |
23 | # Attach optimizer
24 | # if not isinstance(optimizer, Optimizer):
25 | # raise TypeError('{} is not an Optimizer'.format(
26 | # type(optimizer).__name__))
27 | self.optimizer = optimizer
28 |
29 | # Initialize epoch and base learning rates
30 | if last_epoch == -1:
31 | for group in optimizer.param_groups:
32 | group.setdefault('initial_lr', group['lr'])
33 | else:
34 | for i, group in enumerate(optimizer.param_groups):
35 | if 'initial_lr' not in group:
36 | raise KeyError("param 'initial_lr' is not specified "
37 | "in param_groups[{}] when resuming an optimizer".format(i))
38 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
39 | self.last_epoch = last_epoch
40 |
41 | # Following https://github.com/pytorch/pytorch/issues/20124
42 | # We would like to ensure that `lr_scheduler.step()` is called after
43 | # `optimizer.step()`
44 | def with_counter(method):
45 | if getattr(method, '_with_counter', False):
46 | # `optimizer.step()` has already been replaced, return.
47 | return method
48 |
49 | # Keep a weak reference to the optimizer instance to prevent
50 | # cyclic references.
51 | instance_ref = weakref.ref(method.__self__)
52 | # Get the unbound method for the same purpose.
53 | func = method.__func__
54 | cls = instance_ref().__class__
55 | del method
56 |
57 | @wraps(func)
58 | def wrapper(*args, **kwargs):
59 | instance = instance_ref()
60 | instance._step_count += 1
61 | wrapped = func.__get__(instance, cls)
62 | return wrapped(*args, **kwargs)
63 |
64 | # Note that the returned function here is no longer a bound method,
65 | # so attributes like `__func__` and `__self__` no longer exist.
66 | wrapper._with_counter = True
67 | return wrapper
68 |
69 | self.optimizer.step = with_counter(self.optimizer.step)
70 | self.optimizer._step_count = 0
71 | self._step_count = 0
72 |
73 | self.step()
74 |
75 | def state_dict(self):
76 | """Returns the state of the scheduler as a :class:`dict`.
77 |
78 | It contains an entry for every variable in self.__dict__ which
79 | is not the optimizer.
80 | """
81 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
82 |
83 | def load_state_dict(self, state_dict):
84 | """Loads the schedulers state.
85 |
86 | Arguments:
87 | state_dict (dict): scheduler state. Should be an object returned
88 | from a call to :meth:`state_dict`.
89 | """
90 | self.__dict__.update(state_dict)
91 |
92 | def get_last_lr(self):
93 | """ Return last computed learning rate by current scheduler.
94 | """
95 | return self._last_lr
96 |
97 | def get_lr(self):
98 | # Compute learning rate using chainable form of the scheduler
99 | raise NotImplementedError
100 |
101 | def step(self, epoch=None):
102 | # Raise a warning if old pattern is detected
103 | # https://github.com/pytorch/pytorch/issues/20124
104 | if self._step_count == 1:
105 | if not hasattr(self.optimizer.step, "_with_counter"):
106 | warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
107 | "initialization. Please, make sure to call `optimizer.step()` before "
108 | "`lr_scheduler.step()`. See more details at "
109 | "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
110 |
111 | # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
112 | elif self.optimizer._step_count < 1:
113 | warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
114 | "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
115 | "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
116 | "will result in PyTorch skipping the first value of the learning rate schedule. "
117 | "See more details at "
118 | "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
119 | self._step_count += 1
120 |
121 | class _enable_get_lr_call:
122 |
123 | def __init__(self, o):
124 | self.o = o
125 |
126 | def __enter__(self):
127 | self.o._get_lr_called_within_step = True
128 | return self
129 |
130 | def __exit__(self, type, value, traceback):
131 | self.o._get_lr_called_within_step = False
132 |
133 | with _enable_get_lr_call(self):
134 | if epoch is None:
135 | self.last_epoch += 1
136 | values = self.get_lr()
137 | else:
138 | warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
139 | self.last_epoch = epoch
140 | if hasattr(self, "_get_closed_form_lr"):
141 | values = self._get_closed_form_lr()
142 | else:
143 | values = self.get_lr()
144 |
145 | for param_group, lr in zip(self.optimizer.param_groups, values):
146 | param_group['lr'] = lr
147 |
148 | self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
149 |
150 | class OneCycleLR(_LRScheduler):
151 | r"""Sets the learning rate of each parameter group according to the
152 | 1cycle learning rate policy. The 1cycle policy anneals the learning
153 | rate from an initial learning rate to some maximum learning rate and then
154 | from that maximum learning rate to some minimum learning rate much lower
155 | than the initial learning rate.
156 | This policy was initially described in the paper `Super-Convergence:
157 | Very Fast Training of Neural Networks Using Large Learning Rates`_.
158 |
159 | The 1cycle learning rate policy changes the learning rate after every batch.
160 | `step` should be called after a batch has been used for training.
161 |
162 | This scheduler is not chainable.
163 |
164 | Note also that the total number of steps in the cycle can be determined in one
165 | of two ways (listed in order of precedence):
166 |
167 | #. A value for total_steps is explicitly provided.
168 | #. A number of epochs (epochs) and a number of steps per epoch
169 | (steps_per_epoch) are provided.
170 | In this case, the number of total steps is inferred by
171 | total_steps = epochs * steps_per_epoch
172 |
173 | You must either provide a value for total_steps or provide a value for both
174 | epochs and steps_per_epoch.
175 |
176 | Args:
177 | optimizer (Optimizer): Wrapped optimizer.
178 | max_lr (float or list): Upper learning rate boundaries in the cycle
179 | for each parameter group.
180 | total_steps (int): The total number of steps in the cycle. Note that
181 | if a value is not provided here, then it must be inferred by providing
182 | a value for epochs and steps_per_epoch.
183 | Default: None
184 | epochs (int): The number of epochs to train for. This is used along
185 | with steps_per_epoch in order to infer the total number of steps in the cycle
186 | if a value for total_steps is not provided.
187 | Default: None
188 | steps_per_epoch (int): The number of steps per epoch to train for. This is
189 | used along with epochs in order to infer the total number of steps in the
190 | cycle if a value for total_steps is not provided.
191 | Default: None
192 | pct_start (float): The percentage of the cycle (in number of steps) spent
193 | increasing the learning rate.
194 | Default: 0.3
195 | anneal_strategy (str): {'cos', 'linear'}
196 | Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
197 | linear annealing.
198 | Default: 'cos'
199 | cycle_momentum (bool): If ``True``, momentum is cycled inversely
200 | to learning rate between 'base_momentum' and 'max_momentum'.
201 | Default: True
202 | base_momentum (float or list): Lower momentum boundaries in the cycle
203 | for each parameter group. Note that momentum is cycled inversely
204 | to learning rate; at the peak of a cycle, momentum is
205 | 'base_momentum' and learning rate is 'max_lr'.
206 | Default: 0.85
207 | max_momentum (float or list): Upper momentum boundaries in the cycle
208 | for each parameter group. Functionally,
209 | it defines the cycle amplitude (max_momentum - base_momentum).
210 | Note that momentum is cycled inversely
211 | to learning rate; at the start of a cycle, momentum is 'max_momentum'
212 | and learning rate is 'base_lr'
213 | Default: 0.95
214 | div_factor (float): Determines the initial learning rate via
215 | initial_lr = max_lr/div_factor
216 | Default: 25
217 | final_div_factor (float): Determines the minimum learning rate via
218 | min_lr = initial_lr/final_div_factor
219 | Default: 1e4
220 | last_epoch (int): The index of the last batch. This parameter is used when
221 | resuming a training job. Since `step()` should be invoked after each
222 | batch instead of after each epoch, this number represents the total
223 | number of *batches* computed, not the total number of epochs computed.
224 | When last_epoch=-1, the schedule is started from the beginning.
225 | Default: -1
226 |
227 | Example:
228 | >>> data_loader = torch.utils.data.DataLoader(...)
229 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
230 | >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
231 | >>> for epoch in range(10):
232 | >>> for batch in data_loader:
233 | >>> train_batch(...)
234 | >>> scheduler.step()
235 |
236 |
237 | .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
238 | https://arxiv.org/abs/1708.07120
239 | """
240 | def __init__(self,
241 | optimizer,
242 | max_lr,
243 | total_steps=None,
244 | epochs=None,
245 | steps_per_epoch=None,
246 | pct_start=0.3,
247 | anneal_strategy='cos',
248 | cycle_momentum=True,
249 | base_momentum=0.85,
250 | max_momentum=0.95,
251 | div_factor=25.,
252 | final_div_factor=1e4,
253 | last_epoch=-1):
254 |
255 | # Validate optimizer
256 | # if not isinstance(optimizer, Optimizer):
257 | # raise TypeError('{} is not an Optimizer'.format(
258 | # type(optimizer).__name__))
259 | self.optimizer = optimizer
260 |
261 | # Validate total_steps
262 | if total_steps is None and epochs is None and steps_per_epoch is None:
263 | raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
264 | elif total_steps is not None:
265 | if total_steps <= 0 or not isinstance(total_steps, int):
266 |                 raise ValueError("Expected positive integer total_steps, but got {}".format(total_steps))
267 |             self.total_steps = total_steps
268 |         else:
269 |             if epochs <= 0 or not isinstance(epochs, int):
270 |                 raise ValueError("Expected positive integer epochs, but got {}".format(epochs))
271 |             if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
272 |                 raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch))
273 | self.total_steps = epochs * steps_per_epoch
274 | self.step_size_up = float(pct_start * self.total_steps) - 1
275 | self.step_size_down = float(self.total_steps - self.step_size_up) - 1
276 |
277 | # Validate pct_start
278 | if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
279 | raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start))
280 |
281 | # Validate anneal_strategy
282 | if anneal_strategy not in ['cos', 'linear']:
283 |             raise ValueError("anneal_strategy must be one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
284 | elif anneal_strategy == 'cos':
285 | self.anneal_func = self._annealing_cos
286 | elif anneal_strategy == 'linear':
287 | self.anneal_func = self._annealing_linear
288 |
289 | # Initialize learning rate variables
290 | max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
291 | for idx, group in enumerate(self.optimizer.param_groups):
292 | group['initial_lr'] = max_lrs[idx] / div_factor
293 | group['max_lr'] = max_lrs[idx]
294 | group['min_lr'] = group['initial_lr'] / final_div_factor
295 |
296 | # Initialize momentum variables
297 | self.cycle_momentum = cycle_momentum
298 | if self.cycle_momentum:
299 | if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
300 | raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
301 | self.use_beta1 = 'betas' in self.optimizer.defaults
302 | max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
303 | base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
304 | if last_epoch == -1:
305 | for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
306 | if self.use_beta1:
307 | _, beta2 = group['betas']
308 | group['betas'] = (m_momentum, beta2)
309 | else:
310 | group['momentum'] = m_momentum
311 | group['max_momentum'] = m_momentum
312 | group['base_momentum'] = b_momentum
313 |
314 | super(OneCycleLR, self).__init__(optimizer, last_epoch)
315 |
316 | def _format_param(self, name, optimizer, param):
317 | """Return correctly formatted lr/momentum for each param group."""
318 | if isinstance(param, (list, tuple)):
319 | if len(param) != len(optimizer.param_groups):
320 | raise ValueError("expected {} values for {}, got {}".format(
321 | len(optimizer.param_groups), name, len(param)))
322 | return param
323 | else:
324 | return [param] * len(optimizer.param_groups)
325 |
326 | def _annealing_cos(self, start, end, pct):
327 | "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
328 | cos_out = math.cos(math.pi * pct) + 1
329 | return end + (start - end) / 2.0 * cos_out
330 |
331 | def _annealing_linear(self, start, end, pct):
332 | "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
333 | return (end - start) * pct + start
334 |
335 | def get_lr(self):
336 | if not self._get_lr_called_within_step:
337 | warnings.warn("To get the last learning rate computed by the scheduler, "
338 | "please use `get_last_lr()`.", UserWarning)
339 |
340 | lrs = []
341 | step_num = self.last_epoch
342 |
343 | if step_num > self.total_steps:
344 | raise ValueError("Tried to step {} times. The specified number of total steps is {}"
345 | .format(step_num + 1, self.total_steps))
346 |
347 | for group in self.optimizer.param_groups:
348 | if step_num <= self.step_size_up:
349 | computed_lr = self.anneal_func(group['initial_lr'], group['max_lr'], step_num / self.step_size_up)
350 | if self.cycle_momentum:
351 | computed_momentum = self.anneal_func(group['max_momentum'], group['base_momentum'],
352 | step_num / self.step_size_up)
353 | else:
354 | down_step_num = step_num - self.step_size_up
355 | computed_lr = self.anneal_func(group['max_lr'], group['min_lr'], down_step_num / self.step_size_down)
356 | if self.cycle_momentum:
357 | computed_momentum = self.anneal_func(group['base_momentum'], group['max_momentum'],
358 | down_step_num / self.step_size_down)
359 |
360 | lrs.append(computed_lr)
361 | if self.cycle_momentum:
362 | if self.use_beta1:
363 | _, beta2 = group['betas']
364 | group['betas'] = (computed_momentum, beta2)
365 | else:
366 | group['momentum'] = computed_momentum
367 |
368 | return lrs
369 |
--------------------------------------------------------------------------------
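
Note (editorial usage sketch, not part of the repository): the docstring example above drives `torch.optim.lr_scheduler.OneCycleLR`; the local copy in this file is used the same way, stepping once per batch and only after the optimizer has stepped. The import path, model and hyperparameters below are assumptions, and the module's `torch._six` import ties it to older PyTorch releases:

```python
import torch

from core.onecyclelr import OneCycleLR  # assumed import path

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4, weight_decay=1e-5)
scheduler = OneCycleLR(optimizer, max_lr=4e-4, total_steps=100, pct_start=0.05)

for step in range(100):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()    # optimizer first, then the scheduler (see the warning in step())
    scheduler.step()
```
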
/core/onecyclelr.py.save:
--------------------------------------------------------------------------------
1 | import types
2 | import math
3 | from torch._six import inf
4 | from functools import wraps
5 | import warnings
6 | import weakref
7 | from collections import Counter
8 | from bisect import bisect_right
9 |
10 | EPOCH_DEPRECATION_WARNING = (
11 | "The epoch parameter in `scheduler.step()` was not necessary and is being "
12 | "deprecated where possible. Please use `scheduler.step()` to step the "
13 | "scheduler. During the deprecation, if epoch is different from None, the "
14 | "closed form is used instead of the new chainable form, where available. "
15 | "Please open an issue if you are unable to replicate your use case: "
16 | "https://github.com/pytorch/pytorch/issues/new/choose."
17 | )
18 |
19 | class _LRScheduler(object):
20 |
21 | def __init__(self, optimizer, last_epoch=-1):
22 |
23 | # Attach optimizer
24 | # if not isinstance(optimizer, Optimizer):
25 | # raise TypeError('{} is not an Optimizer'.format(
26 | # type(optimizer).__name__))
27 | self.optimizer = optimizer
28 |
29 | # Initialize epoch and base learning rates
30 | if last_epoch == -1:
31 | for group in optimizer.param_groups:
32 | group.setdefault('initial_lr', group['lr'])
33 | else:
34 | for i, group in enumerate(optimizer.param_groups):
35 | if 'initial_lr' not in group:
36 | raise KeyError("param 'initial_lr' is not specified "
37 | "in param_groups[{}] when resuming an optimizer".format(i))
38 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
39 | self.last_epoch = last_epoch
40 |
41 | # Following https://github.com/pytorch/pytorch/issues/20124
42 | # We would like to ensure that `lr_scheduler.step()` is called after
43 | # `optimizer.step()`
44 | def with_counter(method):
45 | if getattr(method, '_with_counter', False):
46 | # `optimizer.step()` has already been replaced, return.
47 | return method
48 |
49 | # Keep a weak reference to the optimizer instance to prevent
50 | # cyclic references.
51 | instance_ref = weakref.ref(method.__self__)
52 | # Get the unbound method for the same purpose.
53 | func = method.__func__
54 | cls = instance_ref().__class__
55 | del method
56 |
57 | @wraps(func)
58 | def wrapper(*args, **kwargs):
59 | instance = instance_ref()
60 | instance._step_count += 1
61 | wrapped = func.__get__(instance, cls)
62 | return wrapped(*args, **kwargs)
63 |
64 | # Note that the returned function here is no longer a bound method,
65 | # so attributes like `__func__` and `__self__` no longer exist.
66 | wrapper._with_counter = True
67 | return wrapper
68 |
69 | self.optimizer.step = with_counter(self.optimizer.step)
70 | self.optimizer._step_count = 0
71 | self._step_count = 0
72 |
73 | self.step()
74 |
75 | def state_dict(self):
76 | """Returns the state of the scheduler as a :class:`dict`.
77 |
78 | It contains an entry for every variable in self.__dict__ which
79 | is not the optimizer.
80 | """
81 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
82 |
83 | def load_state_dict(self, state_dict):
84 | """Loads the schedulers state.
85 |
86 | Arguments:
87 | state_dict (dict): scheduler state. Should be an object returned
88 | from a call to :meth:`state_dict`.
89 | """
90 | self.__dict__.update(state_dict)
91 |
92 | def get_last_lr(self):
93 | """ Return last computed learning rate by current scheduler.
94 | """
95 | return self._last_lr
96 |
97 | def get_lr(self):
98 | # Compute learning rate using chainable form of the scheduler
99 | raise NotImplementedError
100 |
101 | def step(self, epoch=None):
102 | # Raise a warning if old pattern is detected
103 | # https://github.com/pytorch/pytorch/issues/20124
104 | if self._step_count == 1:
105 | if not hasattr(self.optimizer.step, "_with_counter"):
106 | warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
107 | "initialization. Please, make sure to call `optimizer.step()` before "
108 | "`lr_scheduler.step()`. See more details at "
109 | "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
110 |
111 | # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
112 | elif self.optimizer._step_count < 1:
113 | warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
114 | "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
115 | "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
116 | "will result in PyTorch skipping the first value of the learning rate schedule. "
117 | "See more details at "
118 | "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
119 | self._step_count += 1
120 |
121 | class _enable_get_lr_call:
122 |
123 | def __init__(self, o):
124 | self.o = o
125 |
126 | def __enter__(self):
127 | self.o._get_lr_called_within_step = True
128 | return self
129 |
130 | def __exit__(self, type, value, traceback):
131 | self.o._get_lr_called_within_step = False
132 |
133 | with _enable_get_lr_call(self):
134 | if epoch is None:
135 | self.last_epoch += 1
136 | values = self.get_lr()
137 | else:
138 | warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
139 | self.last_epoch = epoch
140 | if hasattr(self, "_get_closed_form_lr"):
141 | values = self._get_closed_form_lr()
142 | else:
143 | values = self.get_lr()
144 |
145 | for param_group, lr in zip(self.optimizer.param_groups, values):
146 | param_group['lr'] = lr
147 |
148 | self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
149 |
150 | class OneCycleLR(_LRScheduler):
151 | r"""Sets the learning rate of each parameter group according to the
152 | 1cycle learning rate policy. The 1cycle policy anneals the learning
153 | rate from an initial learning rate to some maximum learning rate and then
154 | from that maximum learning rate to some minimum learning rate much lower
155 | than the initial learning rate.
156 | This policy was initially described in the paper `Super-Convergence:
157 | Very Fast Training of Neural Networks Using Large Learning Rates`_.
158 |
159 | The 1cycle learning rate policy changes the learning rate after every batch.
160 | `step` should be called after a batch has been used for training.
161 |
162 | This scheduler is not chainable.
163 |
164 | Note also that the total number of steps in the cycle can be determined in one
165 | of two ways (listed in order of precedence):
166 |
167 | #. A value for total_steps is explicitly provided.
168 | #. A number of epochs (epochs) and a number of steps per epoch
169 | (steps_per_epoch) are provided.
170 | In this case, the number of total steps is inferred by
171 | total_steps = epochs * steps_per_epoch
172 |
173 | You must either provide a value for total_steps or provide a value for both
174 | epochs and steps_per_epoch.
175 |
176 | Args:
177 | optimizer (Optimizer): Wrapped optimizer.
178 | max_lr (float or list): Upper learning rate boundaries in the cycle
179 | for each parameter group.
180 | total_steps (int): The total number of steps in the cycle. Note that
181 | if a value is not provided here, then it must be inferred by providing
182 | a value for epochs and steps_per_epoch.
183 | Default: None
184 | epochs (int): The number of epochs to train for. This is used along
185 | with steps_per_epoch in order to infer the total number of steps in the cycle
186 | if a value for total_steps is not provided.
187 | Default: None
188 | steps_per_epoch (int): The number of steps per epoch to train for. This is
189 | used along with epochs in order to infer the total number of steps in the
190 | cycle if a value for total_steps is not provided.
191 | Default: None
192 | pct_start (float): The percentage of the cycle (in number of steps) spent
193 | increasing the learning rate.
194 | Default: 0.3
195 | anneal_strategy (str): {'cos', 'linear'}
196 | Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
197 | linear annealing.
198 | Default: 'cos'
199 | cycle_momentum (bool): If ``True``, momentum is cycled inversely
200 | to learning rate between 'base_momentum' and 'max_momentum'.
201 | Default: True
202 | base_momentum (float or list): Lower momentum boundaries in the cycle
203 | for each parameter group. Note that momentum is cycled inversely
204 | to learning rate; at the peak of a cycle, momentum is
205 | 'base_momentum' and learning rate is 'max_lr'.
206 | Default: 0.85
207 | max_momentum (float or list): Upper momentum boundaries in the cycle
208 | for each parameter group. Functionally,
209 | it defines the cycle amplitude (max_momentum - base_momentum).
210 | Note that momentum is cycled inversely
211 | to learning rate; at the start of a cycle, momentum is 'max_momentum'
212 | and learning rate is 'base_lr'
213 | Default: 0.95
214 | div_factor (float): Determines the initial learning rate via
215 | initial_lr = max_lr/div_factor
216 | Default: 25
217 | final_div_factor (float): Determines the minimum learning rate via
218 | min_lr = initial_lr/final_div_factor
219 | Default: 1e4
220 | last_epoch (int): The index of the last batch. This parameter is used when
221 | resuming a training job. Since `step()` should be invoked after each
222 | batch instead of after each epoch, this number represents the total
223 | number of *batches* computed, not the total number of epochs computed.
224 | When last_epoch=-1, the schedule is started from the beginning.
225 | Default: -1
226 |
227 | Example:
228 | >>> data_loader = torch.utils.data.DataLoader(...)
229 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
230 | >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
231 | >>> for epoch in range(10):
232 | >>> for batch in data_loader:
233 | >>> train_batch(...)
234 | >>> scheduler.step()
235 |
236 |
237 | .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
238 | https://arxiv.org/abs/1708.07120
239 | """
240 | def __init__(self,
241 | optimizer,
242 | max_lr,
243 | total_steps=None,
244 | epochs=None,
245 | steps_per_epoch=None,
246 | pct_start=0.3,
247 | anneal_strategy='cos',
248 | cycle_momentum=True,
249 | base_momentum=0.85,
250 | max_momentum=0.95,
251 | div_factor=25.,
252 | final_div_factor=1e4,
253 | last_epoch=-1):
254 |
255 | # Validate optimizer
256 | # if not isinstance(optimizer, Optimizer):
257 | # raise TypeError('{} is not an Optimizer'.format(
258 | # type(optimizer).__name__))
259 | self.optimizer = optimizer
260 |
261 | # Validate total_steps
262 | if total_steps is None and epochs is None and steps_per_epoch is None:
263 | raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
264 | elif total_steps is not None:
265 | if total_steps <= 0 or not isinstance(total_steps, int):
266 |                 raise ValueError("Expected positive integer total_steps, but got {}".format(total_steps))
267 |             self.total_steps = total_steps
268 |         else:
269 |             if epochs <= 0 or not isinstance(epochs, int):
270 |                 raise ValueError("Expected positive integer epochs, but got {}".format(epochs))
271 |             if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
272 |                 raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch))
273 | self.total_steps = epochs * steps_per_epoch
274 | self.step_size_up = float(pct_start * self.total_steps) - 1
275 | self.step_size_down = float(self.total_steps - self.step_size_up) - 1
276 |
277 | # Validate pct_start
278 | if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
279 | raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start))
280 |
281 | # Validate anneal_strategy
282 | if anneal_strategy not in ['cos', 'linear']:
283 |             raise ValueError("anneal_strategy must be one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
284 | elif anneal_strategy == 'cos':
285 | self.anneal_func = self._annealing_cos
286 | elif anneal_strategy == 'linear':
287 | self.anneal_func = self._annealing_linear
288 |
289 | # Initialize learning rate variables
290 | print('groups before: ', self.optimizer.param_groups[0].keys())
291 |
292 | max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
293 | if last_epoch == -1:
294 | for idx, group in enumerate(self.optimizer.param_groups):
295 | self.optimizer.param_groups[idx]['initial_lr'] = max_lrs[idx] / div_factor
296 | self.optimizer.param_groups[idx]['max_lr'] = max_lrs[idx]
297 | self.optimizer.param_groups[idx]['min_lr'] = group['initial_lr'] / final_div_factor
298 |
299 | print('groups: ', self.optimizer.param_groups[0].keys())
300 | exit()
301 |
302 | # Initialize momentum variables
303 | self.cycle_momentum = cycle_momentum
304 | if self.cycle_momentum:
305 | if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
306 |                 raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
307 | self.use_beta1 = 'betas' in self.optimizer.defaults
308 | max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
309 | base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
310 | if last_epoch == -1:
311 | for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
312 | if self.use_beta1:
313 | _, beta2 = group['betas']
314 | group['betas'] = (m_momentum, beta2)
315 | else:
316 | group['momentum'] = m_momentum
317 | group['max_momentum'] = m_momentum
318 | group['base_momentum'] = b_momentum
319 |
320 | super(OneCycleLR, self).__init__(optimizer, last_epoch)
321 |
322 | def _format_param(self, name, optimizer, param):
323 | """Return correctly formatted lr/momentum for each param group."""
324 | if isinstance(param, (list, tuple)):
325 | if len(param) != len(optimizer.param_groups):
326 | raise ValueError("expected {} values for {}, got {}".format(
327 | len(optimizer.param_groups), name, len(param)))
328 | return param
329 | else:
330 | return [param] * len(optimizer.param_groups)
331 |
332 | def _annealing_cos(self, start, end, pct):
333 | "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
334 | cos_out = math.cos(math.pi * pct) + 1
335 | return end + (start - end) / 2.0 * cos_out
336 |
337 | def _annealing_linear(self, start, end, pct):
338 | "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
339 | return (end - start) * pct + start
340 |
341 | def get_lr(self):
342 | if not self._get_lr_called_within_step:
343 | warnings.warn("To get the last learning rate computed by the scheduler, "
344 | "please use `get_last_lr()`.", UserWarning)
345 |
346 | lrs = []
347 | step_num = self.last_epoch
348 |
349 | if step_num > self.total_steps:
350 | raise ValueError("Tried to step {} times. The specified number of total steps is {}"
351 | .format(step_num + 1, self.total_steps))
352 |
353 | for group in self.optimizer.param_groups:
354 | if step_num <= self.step_size_up:
355 | computed_lr = self.anneal_func(group['initial_lr'], group['max_lr'], step_num / self.step_size_up)
356 | if self.cycle_momentum:
357 | computed_momentum = self.anneal_func(group['max_momentum'], group['base_momentum'],
358 | step_num / self.step_size_up)
359 | else:
360 | down_step_num = step_num - self.step_size_up
361 | computed_lr = self.anneal_func(group['max_lr'], group['min_lr'], down_step_num / self.step_size_down)
362 | if self.cycle_momentum:
363 | computed_momentum = self.anneal_func(group['base_momentum'], group['max_momentum'],
364 | down_step_num / self.step_size_down)
365 |
366 | lrs.append(computed_lr)
367 | if self.cycle_momentum:
368 | if self.use_beta1:
369 | _, beta2 = group['betas']
370 | group['betas'] = (computed_momentum, beta2)
371 | else:
372 | group['momentum'] = computed_momentum
373 |
374 | return lrs
375 |
--------------------------------------------------------------------------------
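The `OneCycleLR` variant above ramps each group's learning rate from `initial_lr` up to `max_lr` over the first `step_size_up` steps and then anneals it down to `min_lr` over the remaining `step_size_down` steps, with `_annealing_cos` (or `_annealing_linear`) interpolating between the endpoints. A minimal standalone sketch of that schedule shape, using purely illustrative step counts and learning rates rather than anything from the training scripts:

```python
import math

def annealing_cos(start, end, pct):
    # same cosine interpolation as OneCycleLR._annealing_cos
    cos_out = math.cos(math.pi * pct) + 1
    return end + (start - end) / 2.0 * cos_out

total_steps, pct_start = 100, 0.3                 # illustrative values
step_size_up = float(pct_start * total_steps) - 1
step_size_down = float(total_steps - step_size_up) - 1
initial_lr, max_lr, min_lr = 1e-5, 2.5e-4, 1e-8   # illustrative values

for step in range(0, total_steps, 10):
    if step <= step_size_up:
        lr = annealing_cos(initial_lr, max_lr, step / step_size_up)
    else:
        lr = annealing_cos(max_lr, min_lr, (step - step_size_up) / step_size_down)
    print(f"step {step:3d}  lr {lr:.3e}")
```
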
/core/raft_gma_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from update import GMAUpdateBlock
6 | from extractor import BasicEncoder
7 | from corr import CorrBlock
8 | from utils.utils import bilinear_sampler, coords_grid, upflow8
9 | from gma import Attention
10 |
11 | # from swin_basic import BasicSwinUpdate  # unused; no swin_basic module in this repository
12 |
13 | try:
14 | autocast = torch.cuda.amp.autocast
15 | except:
16 | # dummy autocast for PyTorch < 1.6
17 | class autocast:
18 | def __init__(self, enabled):
19 | pass
20 |
21 | def __enter__(self):
22 | pass
23 |
24 | def __exit__(self, *args):
25 | pass
26 |
27 |
28 | class RAFTGMAModel(nn.Module):
29 | def __init__(self, args):
30 | super().__init__()
31 | self.args = args
32 | self.args.num_heads = 1
33 | self.args.position_only = False
34 | self.args.position_and_content = False
35 |
36 | self.hidden_dim = hdim = 128
37 | self.context_dim = cdim = 128
38 | args.corr_levels = 4
39 | args.corr_radius = 4
40 |
41 | if 'dropout' not in self.args:
42 | self.args.dropout = 0
43 |
44 | # feature network, context network, and update block
45 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
46 | self.cnet = BasicEncoder(output_dim=hdim + cdim, norm_fn='batch', dropout=args.dropout)
47 | self.update_block = GMAUpdateBlock(self.args, hidden_dim=hdim)
48 | self.att = Attention(args=self.args, dim=cdim, heads=self.args.num_heads, max_pos_size=160, dim_head=cdim)
49 |
50 | def freeze_bn(self):
51 | for m in self.modules():
52 | if isinstance(m, nn.BatchNorm2d):
53 | m.eval()
54 |
55 | def initialize_flow(self, img):
56 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
57 | N, C, H, W = img.shape
58 | coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
59 | coords1 = coords_grid(N, H // 8, W // 8).to(img.device)
60 |
61 | # optical flow computed as difference: flow = coords1 - coords0
62 | return coords0, coords1
63 |
64 | def upsample_flow(self, flow, mask):
65 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
66 | N, _, H, W = flow.shape
67 | mask = mask.view(N, 1, 9, 8, 8, H, W)
68 | mask = torch.softmax(mask, dim=2)
69 |
70 | up_flow = F.unfold(8 * flow, [3, 3], padding=1)
71 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
72 |
73 | up_flow = torch.sum(mask * up_flow, dim=2)
74 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
75 | return up_flow.reshape(N, 2, 8 * H, 8 * W)
76 |
77 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
78 | """ Estimate optical flow between pair of frames """
79 |
80 | image1 = 2 * (image1 / 255.0) - 1.0
81 | image2 = 2 * (image2 / 255.0) - 1.0
82 |
83 | image1 = image1.contiguous()
84 | image2 = image2.contiguous()
85 |
86 | hdim = self.hidden_dim
87 | cdim = self.context_dim
88 |
89 | # run the feature network
90 | with autocast(enabled=self.args.mixed_precision):
91 | fmap1, fmap2 = self.fnet([image1, image2])
92 |
93 | fmap1 = fmap1.float()
94 | fmap2 = fmap2.float()
95 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
96 |
97 | # run the context network
98 | with autocast(enabled=self.args.mixed_precision):
99 | cnet = self.cnet(image1)
100 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
101 | net = torch.tanh(net)
102 | inp = torch.relu(inp)
103 | # attention, att_c, att_p = self.att(inp)
104 | attention = self.att(inp)
105 |
106 | coords0, coords1 = self.initialize_flow(image1)
107 |
108 | if flow_init is not None:
109 | coords1 = coords1 + flow_init
110 |
111 | flow_predictions = []
112 | for itr in range(iters):
113 | coords1 = coords1.detach()
114 | corr = corr_fn(coords1) # index correlation volume
115 |
116 | flow = coords1 - coords0
117 | with autocast(enabled=self.args.mixed_precision):
118 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow, attention)
119 |
120 | # F(t+1) = F(t) + \Delta(t)
121 | coords1 = coords1 + delta_flow
122 |
123 | # upsample predictions
124 | if up_mask is None:
125 | flow_up = upflow8(coords1 - coords0)
126 | else:
127 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
128 |
129 | flow_predictions.append(flow_up)
130 |
131 | if test_mode:
132 | return coords1 - coords0, flow_up
133 |
134 | return flow_predictions, None
135 |
136 |
137 |
138 | if __name__ == "__main__":
139 | att = Attention(dim=128, heads=1)
140 | fmap = torch.randn(2, 128, 40, 90)
141 | out = att(fmap)
142 |
143 | print(out.shape)
144 |
--------------------------------------------------------------------------------
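The `upsample_flow` method in `raft_gma_model.py` above (and its twin in `raft_model.py` below) recovers full-resolution flow by writing every fine-grid vector as a softmax-weighted, i.e. convex, combination of the 3x3 coarse neighbours around it, weighted by the 64*9-channel mask predicted by the update block. A small shape walk-through of the same operations on random tensors (the spatial size is illustrative):

```python
import torch
import torch.nn.functional as F

N, H, W = 1, 46, 62                       # coarse (1/8-resolution) grid
flow = torch.randn(N, 2, H, W)            # coarse flow field
mask = torch.randn(N, 64 * 9, H, W)       # as predicted by the update block

mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)         # convex weights over the 3x3 neighbourhood

up_flow = F.unfold(8 * flow, [3, 3], padding=1)   # gather 3x3 neighbours, scale flow by 8
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

up_flow = torch.sum(mask * up_flow, dim=2)        # weighted sum -> (N, 2, 8, 8, H, W)
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)       # -> (N, 2, H, 8, W, 8)
print(up_flow.reshape(N, 2, 8 * H, 8 * W).shape)  # torch.Size([1, 2, 368, 496])
```
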
/core/raft_model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from update import BasicUpdateBlock, SmallUpdateBlock
8 | from extractor import BasicEncoder, SmallEncoder
9 | from corr import CorrBlock, AlternateCorrBlock
10 | from utils.utils import bilinear_sampler, coords_grid, upflow8
11 |
12 | try:
13 | autocast = torch.cuda.amp.autocast
14 | except:
15 | # dummy autocast for PyTorch < 1.6
16 | class autocast:
17 | def __init__(self, enabled):
18 | pass
19 | def __enter__(self):
20 | pass
21 | def __exit__(self, *args):
22 | pass
23 |
24 |
25 | class RAFTModel(nn.Module):
26 | def __init__(self, args):
27 | super(RAFTModel, self).__init__()
28 | self.args = args
29 |
30 | if args.small:
31 | self.hidden_dim = hdim = 96
32 | self.context_dim = cdim = 64
33 | args.corr_levels = 4
34 | args.corr_radius = 3
35 |
36 | else:
37 | self.hidden_dim = hdim = 128
38 | self.context_dim = cdim = 128
39 | args.corr_levels = 4
40 | args.corr_radius = 4
41 |
42 | if not hasattr(self.args, 'dropout'):
43 | self.args.dropout = 0
44 |
45 | if not hasattr(self.args, 'alternate_corr'):
46 | self.args.alternate_corr = False
47 |
48 | # feature network, context network, and update block
49 | if args.small:
50 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
51 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
52 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
53 |
54 | else:
55 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
56 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
57 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
58 |
59 | def freeze_bn(self):
60 | for m in self.modules():
61 | if isinstance(m, nn.BatchNorm2d):
62 | m.eval()
63 |
64 | def initialize_flow(self, img):
65 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
66 | N, C, H, W = img.shape
67 | coords0 = coords_grid(N, H//8, W//8).to(img.device)
68 | coords1 = coords_grid(N, H//8, W//8).to(img.device)
69 |
70 | # optical flow computed as difference: flow = coords1 - coords0
71 | return coords0, coords1
72 |
73 | def upsample_flow(self, flow, mask):
74 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
75 | N, _, H, W = flow.shape
76 | mask = mask.view(N, 1, 9, 8, 8, H, W)
77 | mask = torch.softmax(mask, dim=2)
78 |
79 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
80 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
81 |
82 | up_flow = torch.sum(mask * up_flow, dim=2)
83 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
84 | return up_flow.reshape(N, 2, 8*H, 8*W)
85 |
86 |
87 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
88 | """ Estimate optical flow between pair of frames """
89 |
90 | image1 = 2 * (image1 / 255.0) - 1.0
91 | image2 = 2 * (image2 / 255.0) - 1.0
92 |
93 | image1 = image1.contiguous()
94 | image2 = image2.contiguous()
95 |
96 | hdim = self.hidden_dim
97 | cdim = self.context_dim
98 |
99 | # run the feature network
100 | with autocast(enabled=self.args.mixed_precision):
101 | fmap1, fmap2 = self.fnet([image1, image2])
102 |
103 | fmap1 = fmap1.float()
104 | fmap2 = fmap2.float()
105 | if self.args.alternate_corr:
106 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
107 | else:
108 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
109 |
110 | # run the context network
111 | with autocast(enabled=self.args.mixed_precision):
112 | cnet = self.cnet(image1)
113 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
114 | net = torch.tanh(net)
115 | inp = torch.relu(inp)
116 |
117 | coords0, coords1 = self.initialize_flow(image1)
118 |
119 | if flow_init is not None:
120 | coords1 = coords1 + flow_init
121 |
122 | flow_predictions = []
123 | for itr in range(iters):
124 | coords1 = coords1.detach()
125 | corr = corr_fn(coords1) # index correlation volume
126 |
127 | flow = coords1 - coords0
128 | with autocast(enabled=self.args.mixed_precision):
129 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
130 |
131 | # F(t+1) = F(t) + \Delta(t)
132 | coords1 = coords1 + delta_flow
133 |
134 | # upsample predictions
135 | if up_mask is None:
136 | flow_up = upflow8(coords1 - coords0)
137 | else:
138 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
139 |
140 | flow_predictions.append(flow_up)
141 |
142 | if test_mode:
143 | return coords1 - coords0, flow_up
144 |
145 | return flow_predictions, fmap1, fmap2
146 |
147 |
--------------------------------------------------------------------------------
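For reference, a hedged sketch of how a model such as `RAFTModel` above is driven: two images in the 0-255 range (spatial dimensions divisible by 8) go in, `iters` refinement steps are run, and `test_mode=True` returns the coarse flow plus the convex-upsampled flow. The argument namespace below is an assumption pieced together from the attributes read in `__init__` and `forward`; the actual entry points are `demo.py`, `evaluate.py`, and `train.py`.

```python
import argparse
import torch
from raft_model import RAFTModel   # assumes ./core is on sys.path, as in demo.py

# only the attributes read by RAFTModel are set here; the scripts pass more
args = argparse.Namespace(small=False, dropout=0, alternate_corr=False, mixed_precision=False)
model = RAFTModel(args).cuda().eval()

image1 = torch.randint(0, 256, (1, 3, 368, 496)).float().cuda()
image2 = torch.randint(0, 256, (1, 3, 368, 496)).float().cuda()

with torch.no_grad():
    flow_low, flow_up = model(image1, image2, iters=12, test_mode=True)
print(flow_low.shape, flow_up.shape)   # [1, 2, 46, 62] and [1, 2, 368, 496]
```
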
/core/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from gma import Aggregate
6 |
7 |
8 | class FlowHead(nn.Module):
9 | def __init__(self, input_dim=128, hidden_dim=256):
10 | super(FlowHead, self).__init__()
11 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
12 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
13 | self.relu = nn.ReLU(inplace=True)
14 |
15 | def forward(self, x):
16 | return self.conv2(self.relu(self.conv1(x)))
17 |
18 |
19 | class ConvGRU(nn.Module):
20 | def __init__(self, hidden_dim=128, input_dim=192+128):
21 | super(ConvGRU, self).__init__()
22 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
23 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
24 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
25 |
26 | def forward(self, h, x):
27 | hx = torch.cat([h, x], dim=1)
28 |
29 | z = torch.sigmoid(self.convz(hx))
30 | r = torch.sigmoid(self.convr(hx))
31 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
32 |
33 | h = (1-z) * h + z * q
34 | return h
35 |
36 |
37 | class SepConvGRU(nn.Module):
38 | def __init__(self, hidden_dim=128, input_dim=192+128):
39 | super(SepConvGRU, self).__init__()
40 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
41 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
42 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
43 |
44 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
45 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
46 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
47 |
48 |
49 | def forward(self, h, x):
50 | # horizontal
51 | hx = torch.cat([h, x], dim=1)
52 | z = torch.sigmoid(self.convz1(hx))
53 | r = torch.sigmoid(self.convr1(hx))
54 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
55 | h = (1-z) * h + z * q
56 |
57 | # vertical
58 | hx = torch.cat([h, x], dim=1)
59 | z = torch.sigmoid(self.convz2(hx))
60 | r = torch.sigmoid(self.convr2(hx))
61 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
62 | h = (1-z) * h + z * q
63 |
64 | return h
65 |
66 |
67 | class SmallMotionEncoder(nn.Module):
68 | def __init__(self, args):
69 | super(SmallMotionEncoder, self).__init__()
70 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
71 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
72 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
73 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
74 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
75 |
76 | def forward(self, flow, corr):
77 | cor = F.relu(self.convc1(corr))
78 | flo = F.relu(self.convf1(flow))
79 | flo = F.relu(self.convf2(flo))
80 | cor_flo = torch.cat([cor, flo], dim=1)
81 | out = F.relu(self.conv(cor_flo))
82 | return torch.cat([out, flow], dim=1)
83 |
84 |
85 | class BasicMotionEncoder(nn.Module):
86 | def __init__(self, args):
87 | super(BasicMotionEncoder, self).__init__()
88 | if hasattr(args, 'motion_feat_indim'):
89 | cor_planes = args.motion_feat_indim
90 | else:
91 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
92 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
93 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
94 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
95 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
96 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
97 |
98 | def forward(self, flow, corr):
99 | cor = F.relu(self.convc1(corr))
100 | cor = F.relu(self.convc2(cor))
101 | flo = F.relu(self.convf1(flow))
102 | flo = F.relu(self.convf2(flo))
103 |
104 | cor_flo = torch.cat([cor, flo], dim=1)
105 | out = F.relu(self.conv(cor_flo))
106 | return torch.cat([out, flow], dim=1)
107 |
108 |
109 | class SmallUpdateBlock(nn.Module):
110 | def __init__(self, args, hidden_dim=96):
111 | super(SmallUpdateBlock, self).__init__()
112 | self.encoder = SmallMotionEncoder(args)
113 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
114 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
115 |
116 | def forward(self, net, inp, corr, flow):
117 | motion_features = self.encoder(flow, corr)
118 | inp = torch.cat([inp, motion_features], dim=1)
119 | net = self.gru(net, inp)
120 | delta_flow = self.flow_head(net)
121 |
122 | return net, None, delta_flow
123 |
124 |
125 | class BasicUpdateBlock(nn.Module):
126 | def __init__(self, args, hidden_dim=128, input_dim=128):
127 | super(BasicUpdateBlock, self).__init__()
128 | self.args = args
129 | self.encoder = BasicMotionEncoder(args)
130 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+input_dim)
131 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
132 |
133 | hidden_dim2 = (hidden_dim // 256)*128 + 256
134 |
135 | self.mask = nn.Sequential(
136 | nn.Conv2d(hidden_dim, hidden_dim2, 3, padding=1),
137 | nn.ReLU(inplace=True),
138 | nn.Conv2d(hidden_dim2, 64*9, 1, padding=0))
139 |
140 | def forward(self, net, inp, corr, flow, upsample=True):
141 | motion_features = self.encoder(flow, corr)
142 | inp = torch.cat([inp, motion_features], dim=1)
143 |
144 | net = self.gru(net, inp)
145 | delta_flow = self.flow_head(net)
146 |
147 | # scale mask to balance gradients
148 | mask = .25 * self.mask(net)
149 | return net, mask, delta_flow
150 |
151 |
152 | class GMAUpdateBlock(nn.Module):
153 | def __init__(self, args, hidden_dim=128):
154 | super().__init__()
155 | self.args = args
156 | self.encoder = BasicMotionEncoder(args)
157 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim+hidden_dim)
158 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
159 |
160 | self.mask = nn.Sequential(
161 | nn.Conv2d(128, 256, 3, padding=1),
162 | nn.ReLU(inplace=True),
163 | nn.Conv2d(256, 64*9, 1, padding=0))
164 |
165 | self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=self.args.num_heads)
166 |
167 | def forward(self, net, inp, corr, flow, attention):
168 | motion_features = self.encoder(flow, corr)
169 | motion_features_global = self.aggregator(attention, motion_features)
170 | inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1)
171 |
172 | # Attentional update
173 | net = self.gru(net, inp_cat)
174 |
175 | delta_flow = self.flow_head(net)
176 |
177 | # scale mask to balance gradients
178 | mask = .25 * self.mask(net)
179 | return net, mask, delta_flow
180 |
--------------------------------------------------------------------------------
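`SepConvGRU` above factorizes the 3x3 `ConvGRU` into a horizontal (1x5) update followed by a vertical (5x1) update on the hidden state `net`, with the concatenated context and motion features as input. A quick shape check with random tensors (channel counts follow the `BasicUpdateBlock` configuration; the spatial size is illustrative):

```python
import torch
from update import SepConvGRU   # assumes ./core is on sys.path, as in demo.py

gru = SepConvGRU(hidden_dim=128, input_dim=128 + 128)
net = torch.randn(1, 128, 46, 62)   # recurrent hidden state
inp = torch.randn(1, 256, 46, 62)   # context features + encoded motion features
net = gru(net, inp)
print(net.shape)                    # torch.Size([1, 128, 46, 62])
```
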
/core/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # functions from timm
2 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
3 | from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
4 | from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_
--------------------------------------------------------------------------------
/core/utils/augmentor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import math
4 | from PIL import Image
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | import torch
11 | from torchvision.transforms import ColorJitter
12 | import torch.nn.functional as F
13 |
14 |
15 | class FlowAugmentor:
16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
17 |
18 | # spatial augmentation params
19 | self.crop_size = crop_size
20 | self.min_scale = min_scale
21 | self.max_scale = max_scale
22 | self.spatial_aug_prob = 0.8
23 | self.stretch_prob = 0.8
24 | self.max_stretch = 0.2
25 |
26 | # flip augmentation params
27 | self.do_flip = do_flip
28 | self.h_flip_prob = 0.5
29 | self.v_flip_prob = 0.1
30 |
31 | # photometric augmentation params
32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
33 | self.asymmetric_color_aug_prob = 0.2
34 | self.eraser_aug_prob = 0.5
35 |
36 | def color_transform(self, img1, img2):
37 | """ Photometric augmentation """
38 |
39 | # asymmetric
40 | if np.random.rand() < self.asymmetric_color_aug_prob:
41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
43 |
44 | # symmetric
45 | else:
46 | image_stack = np.concatenate([img1, img2], axis=0)
47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
48 | img1, img2 = np.split(image_stack, 2, axis=0)
49 |
50 | return img1, img2
51 |
52 | def eraser_transform(self, img1, img2, bounds=[50, 100]):
53 | """ Occlusion augmentation """
54 |
55 | ht, wd = img1.shape[:2]
56 | if np.random.rand() < self.eraser_aug_prob:
57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
58 | for _ in range(np.random.randint(1, 3)):
59 | x0 = np.random.randint(0, wd)
60 | y0 = np.random.randint(0, ht)
61 | dx = np.random.randint(bounds[0], bounds[1])
62 | dy = np.random.randint(bounds[0], bounds[1])
63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
64 |
65 | return img1, img2
66 |
67 | def spatial_transform(self, img1, img2, flow):
68 | # randomly sample scale
69 | ht, wd = img1.shape[:2]
70 | min_scale = np.maximum(
71 | (self.crop_size[0] + 8) / float(ht),
72 | (self.crop_size[1] + 8) / float(wd))
73 |
74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
75 | scale_x = scale
76 | scale_y = scale
77 | if np.random.rand() < self.stretch_prob:
78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
80 |
81 | scale_x = np.clip(scale_x, min_scale, None)
82 | scale_y = np.clip(scale_y, min_scale, None)
83 |
84 | if np.random.rand() < self.spatial_aug_prob:
85 | # rescale the images
86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
89 | flow = flow * [scale_x, scale_y]
90 |
91 | if self.do_flip:
92 | if np.random.rand() < self.h_flip_prob: # h-flip
93 | img1 = img1[:, ::-1]
94 | img2 = img2[:, ::-1]
95 | flow = flow[:, ::-1] * [-1.0, 1.0]
96 |
97 | if np.random.rand() < self.v_flip_prob: # v-flip
98 | img1 = img1[::-1, :]
99 | img2 = img2[::-1, :]
100 | flow = flow[::-1, :] * [1.0, -1.0]
101 |
102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
104 |
105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
106 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
108 |
109 | return img1, img2, flow
110 |
111 | def __call__(self, img1, img2, flow):
112 | img1, img2 = self.color_transform(img1, img2)
113 | img1, img2 = self.eraser_transform(img1, img2)
114 | img1, img2, flow = self.spatial_transform(img1, img2, flow)
115 |
116 | img1 = np.ascontiguousarray(img1)
117 | img2 = np.ascontiguousarray(img2)
118 | flow = np.ascontiguousarray(flow)
119 |
120 | return img1, img2, flow
121 |
122 | class SparseFlowAugmentor:
123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
124 | # spatial augmentation params
125 | self.crop_size = crop_size
126 | self.min_scale = min_scale
127 | self.max_scale = max_scale
128 | self.spatial_aug_prob = 0.8
129 | self.stretch_prob = 0.8
130 | self.max_stretch = 0.2
131 |
132 | # flip augmentation params
133 | self.do_flip = do_flip
134 | self.h_flip_prob = 0.5
135 | self.v_flip_prob = 0.1
136 |
137 | # photometric augmentation params
138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
139 | self.asymmetric_color_aug_prob = 0.2
140 | self.eraser_aug_prob = 0.5
141 |
142 | def color_transform(self, img1, img2):
143 | image_stack = np.concatenate([img1, img2], axis=0)
144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
145 | img1, img2 = np.split(image_stack, 2, axis=0)
146 | return img1, img2
147 |
148 | def eraser_transform(self, img1, img2):
149 | ht, wd = img1.shape[:2]
150 | if np.random.rand() < self.eraser_aug_prob:
151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
152 | for _ in range(np.random.randint(1, 3)):
153 | x0 = np.random.randint(0, wd)
154 | y0 = np.random.randint(0, ht)
155 | dx = np.random.randint(50, 100)
156 | dy = np.random.randint(50, 100)
157 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
158 |
159 | return img1, img2
160 |
161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
162 | ht, wd = flow.shape[:2]
163 | coords = np.meshgrid(np.arange(wd), np.arange(ht))
164 | coords = np.stack(coords, axis=-1)
165 |
166 | coords = coords.reshape(-1, 2).astype(np.float32)
167 | flow = flow.reshape(-1, 2).astype(np.float32)
168 | valid = valid.reshape(-1).astype(np.float32)
169 |
170 | coords0 = coords[valid>=1]
171 | flow0 = flow[valid>=1]
172 |
173 | ht1 = int(round(ht * fy))
174 | wd1 = int(round(wd * fx))
175 |
176 | coords1 = coords0 * [fx, fy]
177 | flow1 = flow0 * [fx, fy]
178 |
179 | xx = np.round(coords1[:,0]).astype(np.int32)
180 | yy = np.round(coords1[:,1]).astype(np.int32)
181 |
182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
183 | xx = xx[v]
184 | yy = yy[v]
185 | flow1 = flow1[v]
186 |
187 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32)
189 |
190 | flow_img[yy, xx] = flow1
191 | valid_img[yy, xx] = 1
192 |
193 | return flow_img, valid_img
194 |
195 | def spatial_transform(self, img1, img2, flow, valid):
196 | # randomly sample scale
197 |
198 | ht, wd = img1.shape[:2]
199 | min_scale = np.maximum(
200 | (self.crop_size[0] + 1) / float(ht),
201 | (self.crop_size[1] + 1) / float(wd))
202 |
203 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
204 | scale_x = np.clip(scale, min_scale, None)
205 | scale_y = np.clip(scale, min_scale, None)
206 |
207 | if np.random.rand() < self.spatial_aug_prob:
208 | # rescale the images
209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
212 |
213 | if self.do_flip:
214 | if np.random.rand() < 0.5: # h-flip
215 | img1 = img1[:, ::-1]
216 | img2 = img2[:, ::-1]
217 | flow = flow[:, ::-1] * [-1.0, 1.0]
218 | valid = valid[:, ::-1]
219 |
220 | margin_y = 20
221 | margin_x = 50
222 |
223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
225 |
226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
228 |
229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
233 | return img1, img2, flow, valid
234 |
235 |
236 | def __call__(self, img1, img2, flow, valid):
237 | img1, img2 = self.color_transform(img1, img2)
238 | img1, img2 = self.eraser_transform(img1, img2)
239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
240 |
241 | img1 = np.ascontiguousarray(img1)
242 | img2 = np.ascontiguousarray(img2)
243 | flow = np.ascontiguousarray(flow)
244 | valid = np.ascontiguousarray(valid)
245 |
246 | return img1, img2, flow, valid
247 |
--------------------------------------------------------------------------------
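Both augmentors above take numpy uint8 frames plus a flow map, jitter colors, randomly occlude parts of the second frame, rescale/flip, and crop to `crop_size`, scaling the flow vectors consistently with the resize. A minimal sketch on random data (the Sintel-like input size and the crop size are illustrative; `datasets.py` is what actually wires the augmentors up):

```python
import numpy as np
from utils.augmentor import FlowAugmentor   # assumes ./core is on sys.path, as in demo.py

aug = FlowAugmentor(crop_size=(368, 496))
img1 = np.random.randint(0, 256, (436, 1024, 3), dtype=np.uint8)
img2 = np.random.randint(0, 256, (436, 1024, 3), dtype=np.uint8)
flow = np.random.randn(436, 1024, 2).astype(np.float32)

img1, img2, flow = aug(img1, img2, flow)
print(img1.shape, flow.shape)   # (368, 496, 3) (368, 496, 2)
```
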
/core/utils/drop.py:
--------------------------------------------------------------------------------
1 | """ DropBlock, DropPath
2 |
3 | PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
4 |
5 | Papers:
6 | DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
7 |
8 | Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
9 |
10 | Code:
11 | DropBlock impl inspired by two Tensorflow impl that I liked:
12 | - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
13 | - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
14 |
15 | Hacked together by / Copyright 2020 Ross Wightman
16 | """
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 |
21 |
22 | def drop_block_2d(
23 | x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
24 | with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
25 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
26 |
27 | DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
28 | runs with success, but needs further validation and possibly optimization for lower runtime impact.
29 | """
30 | B, C, H, W = x.shape
31 | total_size = W * H
32 | clipped_block_size = min(block_size, min(W, H))
33 | # seed_drop_rate, the gamma parameter
34 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
35 | (W - block_size + 1) * (H - block_size + 1))
36 |
37 | # Forces the block to be inside the feature map.
38 | w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
39 | valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
40 | ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
41 | valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
42 |
43 | if batchwise:
44 | # one mask for whole batch, quite a bit faster
45 | uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
46 | else:
47 | uniform_noise = torch.rand_like(x)
48 | block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
49 | block_mask = -F.max_pool2d(
50 | -block_mask,
51 | kernel_size=clipped_block_size, # block_size,
52 | stride=1,
53 | padding=clipped_block_size // 2)
54 |
55 | if with_noise:
56 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
57 | if inplace:
58 | x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
59 | else:
60 | x = x * block_mask + normal_noise * (1 - block_mask)
61 | else:
62 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
63 | if inplace:
64 | x.mul_(block_mask * normalize_scale)
65 | else:
66 | x = x * block_mask * normalize_scale
67 | return x
68 |
69 |
70 | def drop_block_fast_2d(
71 | x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
72 | gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False):
73 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
74 |
75 | DropBlock with an experimental gaussian noise option. Simplified from above without concern for valid
76 | block mask at edges.
77 | """
78 | B, C, H, W = x.shape
79 | total_size = W * H
80 | clipped_block_size = min(block_size, min(W, H))
81 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
82 | (W - block_size + 1) * (H - block_size + 1))
83 |
84 | block_mask = torch.empty_like(x).bernoulli_(gamma)
85 | block_mask = F.max_pool2d(
86 | block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
87 |
88 | if with_noise:
89 | normal_noise = torch.empty_like(x).normal_()
90 | if inplace:
91 | x.mul_(1. - block_mask).add_(normal_noise * block_mask)
92 | else:
93 | x = x * (1. - block_mask) + normal_noise * block_mask
94 | else:
95 | block_mask = 1 - block_mask
96 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
97 | if inplace:
98 | x.mul_(block_mask * normalize_scale)
99 | else:
100 | x = x * block_mask * normalize_scale
101 | return x
102 |
103 |
104 | class DropBlock2d(nn.Module):
105 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
106 | """
107 |
108 | def __init__(
109 | self,
110 | drop_prob: float = 0.1,
111 | block_size: int = 7,
112 | gamma_scale: float = 1.0,
113 | with_noise: bool = False,
114 | inplace: bool = False,
115 | batchwise: bool = False,
116 | fast: bool = True):
117 | super(DropBlock2d, self).__init__()
118 | self.drop_prob = drop_prob
119 | self.gamma_scale = gamma_scale
120 | self.block_size = block_size
121 | self.with_noise = with_noise
122 | self.inplace = inplace
123 | self.batchwise = batchwise
124 | self.fast = fast # FIXME finish comparisons of fast vs not
125 |
126 | def forward(self, x):
127 | if not self.training or not self.drop_prob:
128 | return x
129 | if self.fast:
130 | return drop_block_fast_2d(
131 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace)
132 | else:
133 | return drop_block_2d(
134 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
135 |
136 |
137 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
138 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
139 |
140 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
141 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
142 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
143 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
144 | 'survival rate' as the argument.
145 |
146 | """
147 | if drop_prob == 0. or not training:
148 | return x
149 | keep_prob = 1 - drop_prob
150 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
151 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
152 | if keep_prob > 0.0 and scale_by_keep:
153 | random_tensor.div_(keep_prob)
154 | return x * random_tensor
155 |
156 |
157 | class DropPath(nn.Module):
158 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
159 | """
160 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
161 | super(DropPath, self).__init__()
162 | self.drop_prob = drop_prob
163 | self.scale_by_keep = scale_by_keep
164 |
165 | def forward(self, x):
166 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
--------------------------------------------------------------------------------
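`drop_path` above implements stochastic depth: in training mode each sample's residual branch is either zeroed out entirely or kept and rescaled by 1/(1-p), so the expectation is preserved; in eval mode it is a no-op. A tiny illustration (the drop probability and tensor shape are arbitrary):

```python
import torch
from utils.drop import DropPath   # assumes ./core is on sys.path, as in demo.py

torch.manual_seed(0)
dp = DropPath(drop_prob=0.25)
dp.train()

x = torch.ones(8, 128, 46, 62)
y = dp(x)
# per sample: either all zeros or all 1/(1 - 0.25) = 1.3333...
print(y[:, 0, 0, 0])
```
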
/core/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 | Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 | Expects a flow image of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
--------------------------------------------------------------------------------
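`flow_to_image` above normalizes the flow by its maximum magnitude and maps direction to hue and magnitude to saturation via the Middlebury-style color wheel, returning a uint8 visualization. A minimal usage sketch on a synthetic flow field:

```python
import numpy as np
from utils import flow_viz   # assumes ./core is on sys.path, as in demo.py

# synthetic flow: constant rightward motion, vertical component varying by row
u = np.full((128, 192), 4.0, dtype=np.float32)
v = np.tile(np.linspace(-4.0, 4.0, 128, dtype=np.float32)[:, None], (1, 192))
flow = np.stack([u, v], axis=-1)

img = flow_viz.flow_to_image(flow)
print(img.shape, img.dtype)   # (128, 192, 3) uint8
```
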
/core/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | TAG_CHAR = np.array([202021.25], np.float32)
11 |
12 | def readFlow(fn):
13 | """ Read .flo file in Middlebury format"""
14 | # Code adapted from:
15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16 |
17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18 | # print 'fn = %s'%(fn)
19 | with open(fn, 'rb') as f:
20 | magic = np.fromfile(f, np.float32, count=1)
21 | if 202021.25 != magic:
22 | print('Magic number incorrect. Invalid .flo file')
23 | return None
24 | else:
25 | w = np.fromfile(f, np.int32, count=1)
26 | h = np.fromfile(f, np.int32, count=1)
27 | # print 'Reading %d x %d flo file\n' % (w, h)
28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
29 | # Reshape data into 3D array (columns, rows, bands)
30 | # The reshape here is for visualization, the original code is (w,h,2)
31 | return np.resize(data, (int(h), int(w), 2))
32 |
33 | def readPFM(file):
34 | file = open(file, 'rb')
35 |
36 | color = None
37 | width = None
38 | height = None
39 | scale = None
40 | endian = None
41 |
42 | header = file.readline().rstrip()
43 | if header == b'PF':
44 | color = True
45 | elif header == b'Pf':
46 | color = False
47 | else:
48 | raise Exception('Not a PFM file.')
49 |
50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
51 | if dim_match:
52 | width, height = map(int, dim_match.groups())
53 | else:
54 | raise Exception('Malformed PFM header.')
55 |
56 | scale = float(file.readline().rstrip())
57 | if scale < 0: # little-endian
58 | endian = '<'
59 | scale = -scale
60 | else:
61 | endian = '>' # big-endian
62 |
63 | data = np.fromfile(file, endian + 'f')
64 | shape = (height, width, 3) if color else (height, width)
65 |
66 | data = np.reshape(data, shape)
67 | data = np.flipud(data)
68 | return data
69 |
70 | def writeFlow(filename,uv,v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert(uv.ndim == 3)
81 | assert(uv.shape[2] == 2)
82 | u = uv[:,:,0]
83 | v = uv[:,:,1]
84 | else:
85 | u = uv
86 |
87 | assert(u.shape == v.shape)
88 | height,width = u.shape
89 | f = open(filename,'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width*nBands))
96 | tmp[:,np.arange(width)*2] = u
97 | tmp[:,np.arange(width)*2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
104 | flow = flow[:,:,::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2**15) / 64.0
107 | return flow, valid
108 |
109 | def readDispKITTI(filename):
110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
111 | valid = disp > 0.0
112 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
113 | return flow, valid
114 |
115 |
116 | def writeFlowKITTI(filename, uv):
117 | uv = 64.0 * uv + 2**15
118 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
120 | cv2.imwrite(filename, uv[..., ::-1])
121 |
122 |
123 | def read_gen(file_name, pil=False):
124 | ext = splitext(file_name)[-1]
125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
126 | return Image.open(file_name)
127 | elif ext == '.bin' or ext == '.raw':
128 | return np.load(file_name)
129 | elif ext == '.flo':
130 | return readFlow(file_name).astype(np.float32)
131 | elif ext == '.pfm':
132 | flow = readPFM(file_name).astype(np.float32)
133 | if len(flow.shape) == 2:
134 | return flow
135 | else:
136 | return flow[:, :, :-1]
137 | return []
--------------------------------------------------------------------------------
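`writeFlow`/`readFlow` above implement the Middlebury `.flo` container: a float32 magic tag, int32 width and height, then interleaved u/v values. A quick round-trip sanity check (the temporary file path is only for illustration):

```python
import os
import tempfile
import numpy as np
from utils import frame_utils   # assumes ./core is on sys.path, as in demo.py

flow = np.random.randn(100, 150, 2).astype(np.float32)
path = os.path.join(tempfile.gettempdir(), 'roundtrip_example.flo')

frame_utils.writeFlow(path, flow)
restored = frame_utils.readFlow(path)
print(restored.shape, np.allclose(flow, restored))   # (100, 150, 2) True
```
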
/core/utils/helpers.py:
--------------------------------------------------------------------------------
1 | """ Layer/Module Helpers
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | from itertools import repeat
6 | import collections.abc
7 |
8 |
9 | # From PyTorch internals
10 | def _ntuple(n):
11 | def parse(x):
12 | if isinstance(x, collections.abc.Iterable):
13 | return x
14 | return tuple(repeat(x, n))
15 | return parse
16 |
17 |
18 | to_1tuple = _ntuple(1)
19 | to_2tuple = _ntuple(2)
20 | to_3tuple = _ntuple(3)
21 | to_4tuple = _ntuple(4)
22 | to_ntuple = _ntuple
23 |
24 |
25 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
26 | min_value = min_value or divisor
27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
28 | # Make sure that round down does not go down by more than 10%.
29 | if new_v < round_limit * v:
30 | new_v += divisor
31 | return new_v
--------------------------------------------------------------------------------
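These timm helpers are small utilities: `_ntuple` broadcasts a scalar hyper-parameter into an n-tuple while passing iterables through unchanged, and `make_divisible` snaps a channel count to a multiple of `divisor` without letting it drop more than roughly 10% below the requested value. For example:

```python
from utils.helpers import to_2tuple, make_divisible   # assumes ./core is on sys.path

print(to_2tuple(7))        # (7, 7)   -- scalar repeated
print(to_2tuple((3, 5)))   # (3, 5)   -- iterable passed through
print(make_divisible(42))  # 40       -- nearest multiple of 8 within the ~10% limit
```
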
/core/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 | def __init__(self, dims, mode='sintel'):
10 | self.ht, self.wd = dims[-2:]
11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13 | if mode == 'sintel':
14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15 | else:
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17 |
18 | def pad(self, *inputs):
19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
20 |
21 | def unpad(self,x):
22 | ht, wd = x.shape[-2:]
23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
24 | return x[..., c[0]:c[1], c[2]:c[3]]
25 |
26 | def forward_interpolate(flow):
27 | flow = flow.detach().cpu().numpy()
28 | dx, dy = flow[0], flow[1]
29 |
30 | ht, wd = dx.shape
31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
32 |
33 | x1 = x0 + dx
34 | y1 = y0 + dy
35 |
36 | x1 = x1.reshape(-1)
37 | y1 = y1.reshape(-1)
38 | dx = dx.reshape(-1)
39 | dy = dy.reshape(-1)
40 |
41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
42 | x1 = x1[valid]
43 | y1 = y1[valid]
44 | dx = dx[valid]
45 | dy = dy[valid]
46 |
47 | flow_x = interpolate.griddata(
48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
49 |
50 | flow_y = interpolate.griddata(
51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
52 |
53 | flow = np.stack([flow_x, flow_y], axis=0)
54 | return torch.from_numpy(flow).float()
55 |
56 |
57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
58 | """ Wrapper for grid_sample, uses pixel coordinates """
59 | H, W = img.shape[-2:]
60 | xgrid, ygrid = coords.split([1,1], dim=-1)
61 | xgrid = 2*xgrid/(W-1) - 1
62 | ygrid = 2*ygrid/(H-1) - 1
63 |
64 | grid = torch.cat([xgrid, ygrid], dim=-1)
65 | img = F.grid_sample(img, grid, align_corners=True)
66 |
67 | if mask:
68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
69 | return img, mask.float()
70 |
71 | return img
72 |
73 |
74 | def coords_grid(batch, ht, wd):
75 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
76 | coords = torch.stack(coords[::-1], dim=0).float()
77 | return coords[None].repeat(batch, 1, 1, 1)
78 |
79 |
80 | def upflow8(flow, mode='bilinear'):
81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
83 |
--------------------------------------------------------------------------------
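`InputPadder` above pads a frame pair with replicate padding so that both spatial dimensions become divisible by 8 (centered padding for 'sintel', all height padding at the bottom otherwise) and can later crop a prediction back to the original size. A short sketch with a Sintel-sized frame (sizes are illustrative):

```python
import torch
from utils.utils import InputPadder   # assumes ./core is on sys.path, as in demo.py

image1 = torch.randn(1, 3, 436, 1024)
image2 = torch.randn(1, 3, 436, 1024)

padder = InputPadder(image1.shape)      # 436 -> 440; 1024 is already divisible by 8
image1, image2 = padder.pad(image1, image2)
print(image1.shape)                     # torch.Size([1, 3, 440, 1024])

flow = torch.randn(1, 2, 440, 1024)     # e.g. a flow prediction at the padded size
print(padder.unpad(flow).shape)         # torch.Size([1, 2, 436, 1024])
```
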
/core/utils/weight_init.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import warnings
4 |
5 | from torch.nn.init import _calculate_fan_in_and_fan_out
6 |
7 |
8 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
11 | def norm_cdf(x):
12 | # Computes standard normal cumulative distribution function
13 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
14 |
15 | if (mean < a - 2 * std) or (mean > b + 2 * std):
16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
17 | "The distribution of values may be incorrect.",
18 | stacklevel=2)
19 |
20 | with torch.no_grad():
21 | # Values are generated by using a truncated uniform distribution and
22 | # then using the inverse CDF for the normal distribution.
23 | # Get upper and lower cdf values
24 | l = norm_cdf((a - mean) / std)
25 | u = norm_cdf((b - mean) / std)
26 |
27 | # Uniformly fill tensor with values from [l, u], then translate to
28 | # [2l-1, 2u-1].
29 | tensor.uniform_(2 * l - 1, 2 * u - 1)
30 |
31 | # Use inverse cdf transform for normal distribution to get truncated
32 | # standard normal
33 | tensor.erfinv_()
34 |
35 | # Transform to proper mean, std
36 | tensor.mul_(std * math.sqrt(2.))
37 | tensor.add_(mean)
38 |
39 | # Clamp to ensure it's in the proper range
40 | tensor.clamp_(min=a, max=b)
41 | return tensor
42 |
43 |
44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
45 | # type: (Tensor, float, float, float, float) -> Tensor
46 | r"""Fills the input Tensor with values drawn from a truncated
47 | normal distribution. The values are effectively drawn from the
48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
49 | with values outside :math:`[a, b]` redrawn until they are within
50 | the bounds. The method used for generating the random values works
51 | best when :math:`a \leq \text{mean} \leq b`.
52 | Args:
53 | tensor: an n-dimensional `torch.Tensor`
54 | mean: the mean of the normal distribution
55 | std: the standard deviation of the normal distribution
56 | a: the minimum cutoff value
57 | b: the maximum cutoff value
58 | Examples:
59 | >>> w = torch.empty(3, 5)
60 | >>> nn.init.trunc_normal_(w)
61 | """
62 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
63 |
64 |
65 | def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
66 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
67 | if mode == 'fan_in':
68 | denom = fan_in
69 | elif mode == 'fan_out':
70 | denom = fan_out
71 | elif mode == 'fan_avg':
72 | denom = (fan_in + fan_out) / 2
73 |
74 | variance = scale / denom
75 |
76 | if distribution == "truncated_normal":
77 | # constant is stddev of standard normal truncated to (-2, 2)
78 | trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
79 | elif distribution == "normal":
80 | tensor.normal_(std=math.sqrt(variance))
81 | elif distribution == "uniform":
82 | bound = math.sqrt(3 * variance)
83 | tensor.uniform_(-bound, bound)
84 | else:
85 | raise ValueError(f"invalid distribution {distribution}")
86 |
87 |
88 | def lecun_normal_(tensor):
89 | variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
--------------------------------------------------------------------------------
/demo-frames/frame_0016.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0016.png
--------------------------------------------------------------------------------
/demo-frames/frame_0017.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0017.png
--------------------------------------------------------------------------------
/demo-frames/frame_0018.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0018.png
--------------------------------------------------------------------------------
/demo-frames/frame_0019.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0019.png
--------------------------------------------------------------------------------
/demo-frames/frame_0020.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0020.png
--------------------------------------------------------------------------------
/demo-frames/frame_0021.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0021.png
--------------------------------------------------------------------------------
/demo-frames/frame_0022.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0022.png
--------------------------------------------------------------------------------
/demo-frames/frame_0023.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0023.png
--------------------------------------------------------------------------------
/demo-frames/frame_0024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0024.png
--------------------------------------------------------------------------------
/demo-frames/frame_0025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0025.png
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('core')
3 |
4 | import argparse
5 | import os
6 | import cv2
7 | import glob
8 | import numpy as np
9 | import torch
10 | from PIL import Image
11 |
12 | from utils import flow_viz
13 | from utils.utils import InputPadder
14 | from core import create_model
15 |
16 |
17 | DEVICE = 'cuda'
18 |
19 | def load_image(imfile):
20 | img = np.array(Image.open(imfile)).astype(np.uint8)
21 | img = torch.from_numpy(img).permute(2, 0, 1).float()
22 | return img[None].to(DEVICE)
23 |
24 |
25 | def viz(img, flo):
26 | img = img[0].permute(1,2,0).cpu().numpy()
27 | flo = flo[0].permute(1,2,0).cpu().numpy()
28 |
29 | # map flow to rgb image
30 | flo = flow_viz.flow_to_image(flo)
31 | img_flo = np.concatenate([img, flo], axis=0)
32 |
33 | # import matplotlib.pyplot as plt
34 | # plt.imshow(img_flo / 255.0)
35 | # plt.show()
36 |
37 | cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
38 | cv2.waitKey()
39 |
40 |
41 | def demo(args):
42 | model = torch.nn.DataParallel(create_model(args))
43 | model.load_state_dict(torch.load(args.ckpt))
44 |
45 | model = model.module
46 | model.to(DEVICE)
47 | model.eval()
48 |
49 | with torch.no_grad():
50 | images = glob.glob(os.path.join(args.path, '*.png')) + \
51 | glob.glob(os.path.join(args.path, '*.jpg'))
52 |
53 | images = sorted(images)
54 | for imfile1, imfile2 in zip(images[:-1], images[1:]):
55 | image1 = load_image(imfile1)
56 | image2 = load_image(imfile2)
57 |
58 | padder = InputPadder(image1.shape)
59 | image1, image2 = padder.pad(image1, image2)
60 |
61 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
62 | viz(image1, flow_up)
63 |
64 |
65 | if __name__ == '__main__':
66 | parser = argparse.ArgumentParser()
67 | parser.add_argument('--model', default='gmflownet', help="model class. `<model>_model.py` should be in ./core and `Model` should be defined in that file")
68 | parser.add_argument('--ckpt', help="restored checkpoint")
69 |
70 | parser.add_argument('--path', help="dataset for evaluation")
71 | parser.add_argument('--use_mix_attn', action='store_true', help='use mixture of POLA and axial attentions')
72 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
73 | parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
74 | args = parser.parse_args()
75 |
76 | demo(args)
77 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('core')
3 |
4 | from PIL import Image
5 | import argparse
6 | import os
7 | import time
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | import matplotlib.pyplot as plt
12 |
13 | import datasets
14 | from utils import flow_viz
15 | from utils import frame_utils
16 |
17 | # from raft import RAFT, RAFT_Transformer
18 | from core import create_model
19 | from utils.utils import InputPadder, forward_interpolate
20 |
21 |
22 | @torch.no_grad()
23 | def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):
24 | """ Create submission for the Sintel leaderboard """
25 | model.eval()
26 | for dstype in ['clean', 'final']:
27 | test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)
28 |
29 | flow_prev, sequence_prev = None, None
30 | for test_id in range(len(test_dataset)):
31 | image1, image2, (sequence, frame) = test_dataset[test_id]
32 | if sequence != sequence_prev:
33 | flow_prev = None
34 |
35 | padder = InputPadder(image1.shape)
36 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
37 |
38 | flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)
39 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
40 |
41 | if warm_start:
42 | flow_prev = forward_interpolate(flow_low[0])[None].cuda()
43 |
44 | output_dir = os.path.join(output_path, dstype, sequence)
45 | output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))
46 |
47 | if not os.path.exists(output_dir):
48 | os.makedirs(output_dir)
49 |
50 | frame_utils.writeFlow(output_file, flow)
51 | sequence_prev = sequence
52 |
53 |
54 | @torch.no_grad()
55 | def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
56 |     """ Create submission for the KITTI leaderboard """
57 | model.eval()
58 | test_dataset = datasets.KITTI(split='testing', aug_params=None)
59 |
60 | if not os.path.exists(output_path):
61 | os.makedirs(output_path)
62 |
63 | for test_id in range(len(test_dataset)):
64 | image1, image2, (frame_id, ) = test_dataset[test_id]
65 | padder = InputPadder(image1.shape, mode='kitti')
66 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
67 |
68 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
69 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
70 |
71 | output_filename = os.path.join(output_path, frame_id)
72 | frame_utils.writeFlowKITTI(output_filename, flow)
73 |
74 |
75 | @torch.no_grad()
76 | def validate_chairs(model, iters=24):
77 | """ Perform evaluation on the FlyingChairs (test) split """
78 | model.eval()
79 | epe_list = []
80 |
81 | val_dataset = datasets.FlyingChairs(split='validation')
82 | for val_id in range(len(val_dataset)):
83 | image1, image2, flow_gt, _ = val_dataset[val_id]
84 | image1 = image1[None].cuda()
85 | image2 = image2[None].cuda()
86 |
87 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
88 | epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt()
89 | epe_list.append(epe.view(-1).numpy())
90 |
91 | epe = np.mean(np.concatenate(epe_list))
92 | print("Validation Chairs EPE: %f" % epe)
93 | return {'chairs': epe}
94 |
95 |
96 | @torch.no_grad()
97 | def validate_sintel(model, iters=32, warm_start=False):
98 |     """ Perform validation using the Sintel (train) split """
99 | model.eval()
100 | results = {}
101 | for dstype in ['clean', 'final']:
102 | val_dataset = datasets.MpiSintel(split='training', dstype=dstype, is_validate=True)
103 | epe_list = []
104 |
105 | flow_prev, sequence_prev = None, None
106 | for val_id in range(len(val_dataset)):
107 | image1, image2, flow_gt, _, (sequence, frame) = val_dataset[val_id]
108 | image1 = image1[None].cuda()
109 | image2 = image2[None].cuda()
110 |
111 | if sequence != sequence_prev:
112 | flow_prev = None
113 |
114 | padder = InputPadder(image1.shape)
115 | image1, image2 = padder.pad(image1, image2)
116 |
117 | flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)
118 | flow = padder.unpad(flow_pr[0]).cpu()
119 |
120 | if warm_start:
121 | flow_prev = forward_interpolate(flow_low[0])[None].cuda()
122 |
123 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
124 | epe_list.append(epe.view(-1).numpy())
125 |
126 | sequence_prev = sequence
127 |
128 | epe_all = np.concatenate(epe_list)
129 | epe = np.mean(epe_all)
130 | px1 = np.mean(epe_all<1)
131 | px3 = np.mean(epe_all<3)
132 | px5 = np.mean(epe_all<5)
133 |
134 | print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
135 |         results[dstype] = epe
136 |
137 | return results
138 |
139 |
140 | @torch.no_grad()
141 | def validate_kitti(model, iters=24):
142 |     """ Perform validation using the KITTI-2015 (train) split """
143 | model.eval()
144 | val_dataset = datasets.KITTI(split='training')
145 |
146 | out_list, epe_list = [], []
147 | for val_id in range(len(val_dataset)):
148 | image1, image2, flow_gt, valid_gt = val_dataset[val_id]
149 | image1 = image1[None].cuda()
150 | image2 = image2[None].cuda()
151 |
152 | padder = InputPadder(image1.shape, mode='kitti')
153 | image1, image2 = padder.pad(image1, image2)
154 |
155 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
156 | flow = padder.unpad(flow_pr[0]).cpu()
157 |
158 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
159 | mag = torch.sum(flow_gt**2, dim=0).sqrt()
160 |
161 | epe = epe.view(-1)
162 | mag = mag.view(-1)
163 | val = valid_gt.view(-1) >= 0.5
164 |
165 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
166 | epe_list.append(epe[val].mean().item())
167 | out_list.append(out[val].cpu().numpy())
168 |
169 | epe_list = np.array(epe_list)
170 | out_list = np.concatenate(out_list)
171 |
172 | epe = np.mean(epe_list)
173 | f1 = 100 * np.mean(out_list)
174 |
175 | print("Validation KITTI: %f, %f" % (epe, f1))
176 | return {'kitti-epe': epe, 'kitti-f1': f1}
177 |
178 |
179 | if __name__ == '__main__':
180 | parser = argparse.ArgumentParser()
181 |     parser.add_argument('--model', default='gmflownet', help="model class. `<model>_model.py` should be in ./core and `Model` should be defined in that file")
182 | parser.add_argument('--ckpt', help="restored checkpoint")
183 |
184 | parser.add_argument('--dataset', help="dataset for evaluation")
185 | parser.add_argument('--use_mix_attn', action='store_true', help='use mixture of POLA and axial attentions')
186 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
187 |     parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
188 | args = parser.parse_args()
189 |
190 | model = torch.nn.DataParallel(create_model(args))
191 | model.load_state_dict(torch.load(args.ckpt), strict=True)
192 |
193 | model.cuda()
194 | model.eval()
195 |
196 | # create_sintel_submission(model.module, warm_start=True)
197 | # create_kitti_submission(model.module)
198 |
199 | with torch.no_grad():
200 | if args.dataset == 'chairs':
201 | validate_chairs(model.module)
202 |
203 | elif args.dataset == 'sintel':
204 | validate_sintel(model.module)
205 |
206 | elif args.dataset == 'sintel_test':
207 | create_sintel_submission(model.module)
208 |
209 | elif args.dataset == 'kitti':
210 | validate_kitti(model.module)
211 |
212 | elif args.dataset == 'kitti_test':
213 | create_kitti_submission(model.module)
214 |
--------------------------------------------------------------------------------
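evaluate.py is driven by --dataset, which selects one of the branches above: chairs, sintel, and kitti run validation on the corresponding training/validation splits, while sintel_test and kitti_test write leaderboard submission files to sintel_submission/ and kitti_submission/. A sketch of typical invocations, assuming the datasets expected by core/datasets.py are in place (checkpoint names again illustrative):

```Shell
python evaluate.py --model gmflownet --ckpt pretrained_models/gmflownet-sintel.pth --dataset sintel
python evaluate.py --model gmflownet --ckpt pretrained_models/gmflownet-kitti.pth --dataset kitti
```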
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import sys
3 | sys.path.append('core')
4 |
5 | import argparse, configparser
6 | import os
7 | import cv2
8 | import time
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.optim as optim
15 | import torch.nn.functional as F
16 |
17 | from torch.utils.data import DataLoader
18 | # from torch.optim import Adam as AdamW
19 | from torch.optim.adamw import AdamW
20 | from core.onecyclelr import OneCycleLR
21 | from core import create_model
22 |
23 | from loss import compute_supervision_coarse, compute_coarse_loss, backwarp
24 |
25 | import evaluate
26 | import datasets
27 |
28 | from tensorboardX import SummaryWriter
29 |
30 | try:
31 | from torch.cuda.amp import GradScaler
32 | except ImportError:
33 | # dummy GradScaler for PyTorch < 1.6
34 | class GradScaler:
35 | def __init__(self, enabled=False):
36 | pass
37 | def scale(self, loss):
38 | return loss
39 | def unscale_(self, optimizer):
40 | pass
41 | def step(self, optimizer):
42 | optimizer.step()
43 | def update(self):
44 | pass
45 |
46 |
47 | # exclude extremely large displacements
48 | MAX_FLOW = 400
49 | SUM_FREQ = 100
50 | VAL_FREQ = 5000
51 |
52 |
53 | def sequence_loss(train_outputs, image1, image2, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW, use_matching_loss=False):
54 | """ Loss function defined over sequence of flow predictions """
55 | flow_preds, softCorrMap = train_outputs
56 |
57 | # original RAFT loss
58 | n_predictions = len(flow_preds)
59 | flow_loss = 0.0
60 |
61 | # exclude invalid pixels and extremely large displacements
62 | mag = torch.sum(flow_gt**2, dim=1).sqrt()
63 | valid = (valid >= 0.5) & (mag < max_flow)
64 |
65 | for i in range(n_predictions):
66 | i_weight = gamma**(n_predictions - i - 1)
67 | i_loss = (flow_preds[i] - flow_gt).abs()
68 | flow_loss += i_weight * (valid[:, None].float() * i_loss).mean()
69 |
70 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
71 | epe = epe.view(-1)[valid.view(-1)]
72 |
73 | metrics = {
74 | 'epe': epe.mean().item(),
75 | '1px': (epe < 1).float().mean().item(),
76 | '3px': (epe < 3).float().mean().item(),
77 | '5px': (epe < 5).float().mean().item(),
78 | }
79 |
80 | if use_matching_loss:
81 |         # enable the global matching loss. Intended for the later stages of training
82 | img_2back1 = backwarp(image2, flow_gt)
83 | occlusionMap = (image1 - img_2back1).mean(1, keepdims=True) #(N, H, W)
84 | occlusionMap = torch.abs(occlusionMap) > 20
85 | occlusionMap = occlusionMap.float()
86 |
87 | conf_matrix_gt = compute_supervision_coarse(flow_gt, occlusionMap, 8) # 8 from RAFT downsample
88 |
89 | matchLossCfg = configparser.ConfigParser()
90 | matchLossCfg.POS_WEIGHT = 1
91 | matchLossCfg.NEG_WEIGHT = 1
92 | matchLossCfg.FOCAL_ALPHA = 0.25
93 | matchLossCfg.FOCAL_GAMMA = 2.0
94 | matchLossCfg.COARSE_TYPE = 'cross_entropy'
95 | match_loss = compute_coarse_loss(softCorrMap, conf_matrix_gt, matchLossCfg)
96 |
97 | flow_loss = flow_loss + 0.01*match_loss
98 |
99 | return flow_loss, metrics
100 |
101 |
102 | def count_parameters(model):
103 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
104 |
105 |
106 | def fetch_optimizer(args, model, last_iters=-1):
107 | """ Create the optimizer and learning rate scheduler """
108 | optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
109 |
110 | scheduler = OneCycleLR(optimizer, args.lr, args.num_steps+100,
111 | pct_start=0.05, cycle_momentum=False, anneal_strategy='linear', last_epoch=last_iters)
112 |
113 | return optimizer, scheduler
114 |
115 |
116 | class Logger:
117 | def __init__(self, model, scheduler, total_steps=0, log_dir=None):
118 | self.model = model
119 | self.scheduler = scheduler
120 | self.total_steps = total_steps
121 | self.running_loss = {}
122 | self.writer = None
123 | self.log_dir = log_dir
124 |
125 | def _print_training_status(self):
126 | metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())]
127 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
128 | metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)
129 |
130 | # print the training status
131 | print(training_str + metrics_str)
132 |
133 | if self.writer is None:
134 | self.writer = SummaryWriter(logdir=self.log_dir)
135 |
136 | for k in self.running_loss:
137 | self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps)
138 | self.running_loss[k] = 0.0
139 |
140 | def push(self, metrics):
141 | self.total_steps += 1
142 |
143 | for key in metrics:
144 | if key not in self.running_loss:
145 | self.running_loss[key] = 0.0
146 |
147 | self.running_loss[key] += metrics[key]
148 |
149 | if self.total_steps % SUM_FREQ == SUM_FREQ-1:
150 | self._print_training_status()
151 | self.running_loss = {}
152 |
153 | def write_dict(self, results):
154 | if self.writer is None:
155 | self.writer = SummaryWriter(logdir=self.log_dir)
156 |
157 | for key in results:
158 | self.writer.add_scalar(key, results[key], self.total_steps)
159 |
160 | def close(self):
161 | self.writer.close()
162 |
163 |
164 | def train(args):
165 |
166 | model = nn.DataParallel(create_model(args), device_ids=args.gpus)
167 | print("Parameter Count: %d" % count_parameters(model))
168 |
169 | if args.restore_ckpt is not None:
170 | model.load_state_dict(torch.load(args.restore_ckpt), strict=True)
171 |
172 | model.cuda()
173 | model.train()
174 |
175 | if args.stage != 'chairs':
176 | model.module.freeze_bn()
177 |
178 | if args.restore_ckpt is not None:
179 | strStep = os.path.split(args.restore_ckpt)[-1].split('_')[0]
180 | total_steps = int(strStep) if strStep.isdigit() else 0
181 | else:
182 | total_steps = 0
183 |
184 | train_loader = datasets.fetch_dataloader(args, TRAIN_DS='C+T+K/S')
185 | optimizer, scheduler = fetch_optimizer(args, model, total_steps)
186 |
187 | scaler = GradScaler(enabled=args.mixed_precision)
188 | logger = Logger(model, scheduler, total_steps, os.path.join('runs', args.name))
189 |
190 |     add_noise = True  # unused; noise injection is controlled by the --add_noise flag below
191 |
192 | should_keep_training = True
193 | while should_keep_training:
194 |
195 | for i_batch, data_blob in enumerate(train_loader):
196 | optimizer.zero_grad()
197 | image1, image2, flow, valid = [x.cuda() for x in data_blob]
198 |
199 | if args.add_noise:
200 | stdv = np.random.uniform(0.0, 5.0)
201 | image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
202 | image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)
203 |
204 | flow_predictions = model(image1, image2, iters=args.iters)
205 |
206 | loss, metrics = sequence_loss(flow_predictions, image1, image2, flow, valid, gamma=args.gamma, use_matching_loss=args.use_mix_attn)
207 | scaler.scale(loss).backward()
208 | scaler.unscale_(optimizer)
209 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
210 |
211 | scaler.step(optimizer)
212 | scheduler.step()
213 | scaler.update()
214 |
215 | logger.push(metrics)
216 |
217 | if total_steps % VAL_FREQ == VAL_FREQ - 1:
218 | PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
219 | torch.save(model.state_dict(), PATH)
220 |
221 | results = {}
222 | for val_dataset in args.validation:
223 | if val_dataset == 'chairs':
224 | results.update(evaluate.validate_chairs(model.module))
225 | elif val_dataset == 'sintel':
226 | results.update(evaluate.validate_sintel(model.module))
227 | elif val_dataset == 'kitti':
228 | results.update(evaluate.validate_kitti(model.module))
229 |
230 | logger.write_dict(results)
231 |
232 | model.train()
233 | if args.stage != 'chairs':
234 | model.module.freeze_bn()
235 |
236 | total_steps += 1
237 |
238 | if total_steps > args.num_steps:
239 | should_keep_training = False
240 | break
241 |
242 | logger.close()
243 | PATH = 'checkpoints/%s.pth' % args.name
244 | torch.save(model.state_dict(), PATH)
245 |
246 | return PATH
247 |
248 |
249 | if __name__ == '__main__':
250 | parser = argparse.ArgumentParser()
251 |     parser.add_argument('--name', default='gmflownet', help="name of your experiment. The saved checkpoint will be named after this in `./checkpoints/`")
252 |     parser.add_argument('--model', default='gmflownet', help="model class. `<model>_model.py` should be in ./core and `Model` should be defined in that file")
253 | parser.add_argument('--stage', help="determines which dataset to use for training")
254 | parser.add_argument('--restore_ckpt', help="restore checkpoint")
255 | parser.add_argument('--use_mix_attn', action='store_true', help='use mixture of POLA and axial attentions')
256 | parser.add_argument('--validation', type=str, nargs='+')
257 |
258 | parser.add_argument('--lr', type=float, default=0.00002)
259 | parser.add_argument('--num_steps', type=int, default=100000)
260 | parser.add_argument('--batch_size', type=int, default=6)
261 | parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
262 | parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
263 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
264 |
265 | parser.add_argument('--iters', type=int, default=12)
266 | parser.add_argument('--wdecay', type=float, default=.00005)
267 | parser.add_argument('--epsilon', type=float, default=1e-8)
268 | parser.add_argument('--clip', type=float, default=1.0)
269 | parser.add_argument('--dropout', type=float, default=0.0)
270 | parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
271 | parser.add_argument('--add_noise', action='store_true')
272 | args = parser.parse_args()
273 |
274 | torch.set_num_threads(16)
275 |
276 | torch.manual_seed(1234)
277 | np.random.seed(1234)
278 |
279 | if not os.path.exists('checkpoints'):
280 | os.mkdir('checkpoints')
281 | if not os.path.exists('runs'):
282 | os.mkdir('runs')
283 |
284 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, args.gpus))
285 | args.gpus = [i for i in range(len(args.gpus))]
286 | train(args)
287 |
--------------------------------------------------------------------------------
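train.py can resume a stage from one of the intermediate checkpoints it writes to checkpoints/: when the file name passed to --restore_ckpt starts with a step count (as the periodic saves do), that count is parsed back out and used as the starting step for the OneCycleLR schedule. A sketch using the things-stage settings from the script below; the checkpoint name is illustrative:

```Shell
python -u train.py --model gmflownet --name gmflownet-things --stage things --validation sintel kitti \
    --restore_ckpt checkpoints/50000_gmflownet-things.pth \
    --gpus 0 1 --num_steps 160000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
```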
/train_gmflownet.sh:
--------------------------------------------------------------------------------
1 | python -u train.py --model gmflownet --name gmflownet-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 120000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
2 | python -u train.py --model gmflownet --name gmflownet-things --stage things --validation sintel kitti --restore_ckpt checkpoints/gmflownet-chairs.pth --gpus 0 1 --num_steps 160000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
3 | python -u train.py --model gmflownet --name gmflownet-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/gmflownet-things.pth --gpus 0 1 --num_steps 160000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma 0.85
4 | python -u train.py --model gmflownet --name gmflownet-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/gmflownet-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85
5 |
6 |
7 |
--------------------------------------------------------------------------------
/train_gmflownet_mix.sh:
--------------------------------------------------------------------------------
1 | python -u train.py --model gmflownet --name gmflownet_mix-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 120000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 --use_mix_attn
2 | python -u train.py --model gmflownet --name gmflownet_mix-things --stage things --validation sintel kitti --restore_ckpt checkpoints/gmflownet_mix-chairs.pth --gpus 0 1 --num_steps 160000 --batch_size 8 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 --use_mix_attn
3 | python -u train.py --model gmflownet --name gmflownet_mix-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/gmflownet_mix-things.pth --gpus 0 1 --num_steps 160000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma 0.85 --use_mix_attn
4 | python -u train.py --model gmflownet --name gmflownet_mix-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/gmflownet_mix-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --use_mix_attn
5 |
--------------------------------------------------------------------------------