├── .gitignore ├── LICENSE ├── README.md ├── alt_cuda_corr ├── correlation.cpp ├── correlation_kernel.cu └── setup.py ├── chairs_split.txt ├── core ├── __init__.py ├── corr.py ├── datasets.py ├── extractor.py ├── gma.py ├── gmflownet_model.py ├── loss.py ├── onecyclelr.py ├── onecyclelr.py.save ├── raft_gma_model.py ├── raft_model.py ├── swin_transformer.py ├── update.py └── utils │ ├── __init__.py │ ├── augmentor.py │ ├── drop.py │ ├── flow_viz.py │ ├── frame_utils.py │ ├── helpers.py │ ├── utils.py │ └── weight_init.py ├── demo-frames ├── frame_0016.png ├── frame_0017.png ├── frame_0018.png ├── frame_0019.png ├── frame_0020.png ├── frame_0021.png ├── frame_0022.png ├── frame_0023.png ├── frame_0024.png └── frame_0025.png ├── demo.py ├── evaluate.py ├── train.py ├── train_gmflownet.sh └── train_gmflownet_mix.sh /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg-info 3 | dist 4 | datasets 5 | datasets/ 6 | pretrained_models/ 7 | build/ 8 | correlation.egg-info 9 | checkpoints/ 10 | runs/ 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Shiyu Zhao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GMFlowNet 2 | 3 | This repository contains the official implementation for the paper: 4 | 5 | [Global Matching with Overlapping Attention for Optical Flow Estimation](https://arxiv.org/abs/2203.11335)
6 | CVPR 2022
7 | Shiyu Zhao, Long Zhao, Zhixing Zhang, Enyu Zhou, Dimitris Metaxas
8 | 9 | ## Requirements 10 | The code has been tested with PyTorch 1.7 and CUDA 11.0. Later versions of PyTorch may also work. 11 | ```Shell 12 | conda create --name gmflownet 13 | conda activate gmflownet 14 | conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=11.0 -c pytorch 15 | conda install matplotlib tensorboard scipy opencv 16 | ``` 17 | 18 | ## Demos 19 | Download the .zip file with pretrained models from [Google Drive](https://drive.google.com/file/d/1rVfu0j9O1M2hNsew-dVRlx9jL7c6ICru/view?usp=sharing) and unzip `pretrained_models.zip` in the repository root. 20 | ```Shell 21 | unzip pretrained_models.zip 22 | ``` 23 | 24 | You can demo a trained model on a sequence of frames (the `--model` name is resolved to a model class by `core/__init__.py`; see the sketch at the end of this README): 25 | ```Shell 26 | python demo.py --model gmflownet --ckpt=pretrained_models/gmflownet-things.pth --path=demo-frames 27 | ``` 28 | 29 | ## Required Data 30 | To evaluate/train GMFlowNet, you need to download the following datasets. 31 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) 32 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 33 | * [Sintel](http://sintel.is.tue.mpg.de/) 34 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) 35 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional) 36 | 37 | Place all datasets in your preferred directory and symbolically link that directory to `./datasets` (e.g., `ln -s /path/to/your/datasets ./datasets`) so that your `./datasets` folder looks like 38 | ```Shell 39 | ├── datasets 40 | ├── Sintel 41 | ├── test 42 | ├── training 43 | ├── KITTI 44 | ├── testing 45 | ├── training 46 | ├── devkit 47 | ├── FlyingChairs_release 48 | ├── data 49 | ├── FlyingThings3D 50 | ├── frames_cleanpass 51 | ├── frames_finalpass 52 | ├── optical_flow 53 | ... 54 | ``` 55 | 56 | ## Evaluation 57 | Download the pretrained models described in [Demos](https://github.com/xiaofeng94/GMFlowNet/blob/master/README.md#demos). 58 | You may evaluate a pretrained model using `evaluate.py`. To get the best results: 59 | 60 | On Sintel, evaluate the `gmflownet_mix` model as, 61 | ```Shell 62 | python evaluate.py --model gmflownet --use_mix_attn --ckpt=pretrained_models/gmflownet_mix-things.pth --dataset=sintel 63 | ``` 64 | On KITTI, evaluate the `gmflownet` model as, 65 | ```Shell 66 | python evaluate.py --model gmflownet --ckpt=pretrained_models/gmflownet-things.pth --dataset=kitti 67 | ``` 68 | Note: `gmflownet_mix` replaces half of the heads (4 out of 8) in each POLA attention layer of `gmflownet` with heads of [axial attention](https://arxiv.org/abs/2003.07853) and achieves better results on Sintel. 69 | 70 | 71 | ## Training 72 | We used the following training schedules in our paper (2 GPUs). 73 | 74 | - Train `gmflownet` as, 75 | ```Shell 76 | ./train_gmflownet.sh 77 | ``` 78 | - Train `gmflownet_mix` as, 79 | ```Shell 80 | ./train_gmflownet_mix.sh 81 | ``` 82 | 83 | Training logs will be written to `./runs`, which can be visualized with TensorBoard as, 84 | ```Shell 85 | tensorboard --bind_all --port 8080 --logdir ./runs 86 | ``` 87 | 88 | 89 | ## Acknowledgement 90 | The code is based on [RAFT](https://github.com/princeton-vl/RAFT) and [SwinTransformer](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection). 91 | We sincerely thank the authors for their great work. 
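As referenced in the Demos section, here is a minimal sketch (an editor's illustration, not repository code) of how the `--model` flag is resolved: `core/__init__.py` appends `_model` to the name, imports that module, and picks the `nn.Module` subclass whose lowercase class name matches the name with underscores removed plus `model`. The snippet assumes it is run from the repository root with `./core` importable.
```python
# Illustrative only: mirrors the name-resolution convention in core/__init__.py.
import sys
sys.path.extend(['.', './core'])   # make 'core' and 'gmflownet_model' importable

from core import find_model_using_name

# 'gmflownet' -> module 'gmflownet_model' -> the nn.Module subclass whose
# lowercase class name equals 'gmflownetmodel'; create_model(opt) instantiates it.
ModelClass = find_model_using_name('gmflownet')
print(ModelClass.__name__)
```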
92 | 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /alt_cuda_corr/correlation.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | // CUDA forward declarations 5 | std::vector corr_cuda_forward( 6 | torch::Tensor fmap1, 7 | torch::Tensor fmap2, 8 | torch::Tensor coords, 9 | int radius); 10 | 11 | std::vector corr_cuda_backward( 12 | torch::Tensor fmap1, 13 | torch::Tensor fmap2, 14 | torch::Tensor coords, 15 | torch::Tensor corr_grad, 16 | int radius); 17 | 18 | // C++ interface 19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 22 | 23 | std::vector corr_forward( 24 | torch::Tensor fmap1, 25 | torch::Tensor fmap2, 26 | torch::Tensor coords, 27 | int radius) { 28 | CHECK_INPUT(fmap1); 29 | CHECK_INPUT(fmap2); 30 | CHECK_INPUT(coords); 31 | 32 | return corr_cuda_forward(fmap1, fmap2, coords, radius); 33 | } 34 | 35 | 36 | std::vector corr_backward( 37 | torch::Tensor fmap1, 38 | torch::Tensor fmap2, 39 | torch::Tensor coords, 40 | torch::Tensor corr_grad, 41 | int radius) { 42 | CHECK_INPUT(fmap1); 43 | CHECK_INPUT(fmap2); 44 | CHECK_INPUT(coords); 45 | CHECK_INPUT(corr_grad); 46 | 47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); 48 | } 49 | 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &corr_forward, "CORR forward"); 53 | m.def("backward", &corr_backward, "CORR backward"); 54 | } -------------------------------------------------------------------------------- /alt_cuda_corr/correlation_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | 7 | #define BLOCK_H 4 8 | #define BLOCK_W 8 9 | #define BLOCK_HW BLOCK_H * BLOCK_W 10 | #define CHANNEL_STRIDE 32 11 | 12 | 13 | __forceinline__ __device__ 14 | bool within_bounds(int h, int w, int H, int W) { 15 | return h >= 0 && h < H && w >= 0 && w < W; 16 | } 17 | 18 | template 19 | __global__ void corr_forward_kernel( 20 | const torch::PackedTensorAccessor32 fmap1, 21 | const torch::PackedTensorAccessor32 fmap2, 22 | const torch::PackedTensorAccessor32 coords, 23 | torch::PackedTensorAccessor32 corr, 24 | int r) 25 | { 26 | const int b = blockIdx.x; 27 | const int h0 = blockIdx.y * blockDim.x; 28 | const int w0 = blockIdx.z * blockDim.y; 29 | const int tid = threadIdx.x * blockDim.y + threadIdx.y; 30 | 31 | const int H1 = fmap1.size(1); 32 | const int W1 = fmap1.size(2); 33 | const int H2 = fmap2.size(1); 34 | const int W2 = fmap2.size(2); 35 | const int N = coords.size(1); 36 | const int C = fmap1.size(3); 37 | 38 | __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1]; 39 | __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1]; 40 | __shared__ scalar_t x2s[BLOCK_HW]; 41 | __shared__ scalar_t y2s[BLOCK_HW]; 42 | 43 | for (int c=0; c(floor(y2s[k1]))-r+iy; 76 | int w2 = static_cast(floor(x2s[k1]))-r+ix; 77 | int c2 = tid % CHANNEL_STRIDE; 78 | 79 | auto fptr = fmap2[b][h2][w2]; 80 | if (within_bounds(h2, w2, H2, W2)) 81 | f2[c2][k1] = fptr[c+c2]; 82 | else 83 | f2[c2][k1] = 0.0; 84 | } 85 | 86 | __syncthreads(); 87 | 88 | scalar_t s = 0.0; 89 | for (int k=0; k 0 && ix > 0 && within_bounds(h1, w1, H1, W1)) 105 | *(corr_ptr + ix_nw) += nw; 106 | 107 | if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, 
W1)) 108 | *(corr_ptr + ix_ne) += ne; 109 | 110 | if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1)) 111 | *(corr_ptr + ix_sw) += sw; 112 | 113 | if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1)) 114 | *(corr_ptr + ix_se) += se; 115 | } 116 | } 117 | } 118 | } 119 | } 120 | 121 | 122 | template 123 | __global__ void corr_backward_kernel( 124 | const torch::PackedTensorAccessor32 fmap1, 125 | const torch::PackedTensorAccessor32 fmap2, 126 | const torch::PackedTensorAccessor32 coords, 127 | const torch::PackedTensorAccessor32 corr_grad, 128 | torch::PackedTensorAccessor32 fmap1_grad, 129 | torch::PackedTensorAccessor32 fmap2_grad, 130 | torch::PackedTensorAccessor32 coords_grad, 131 | int r) 132 | { 133 | 134 | const int b = blockIdx.x; 135 | const int h0 = blockIdx.y * blockDim.x; 136 | const int w0 = blockIdx.z * blockDim.y; 137 | const int tid = threadIdx.x * blockDim.y + threadIdx.y; 138 | 139 | const int H1 = fmap1.size(1); 140 | const int W1 = fmap1.size(2); 141 | const int H2 = fmap2.size(1); 142 | const int W2 = fmap2.size(2); 143 | const int N = coords.size(1); 144 | const int C = fmap1.size(3); 145 | 146 | __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1]; 147 | __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1]; 148 | 149 | __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1]; 150 | __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1]; 151 | 152 | __shared__ scalar_t x2s[BLOCK_HW]; 153 | __shared__ scalar_t y2s[BLOCK_HW]; 154 | 155 | for (int c=0; c(floor(y2s[k1]))-r+iy; 190 | int w2 = static_cast(floor(x2s[k1]))-r+ix; 191 | int c2 = tid % CHANNEL_STRIDE; 192 | 193 | auto fptr = fmap2[b][h2][w2]; 194 | if (within_bounds(h2, w2, H2, W2)) 195 | f2[c2][k1] = fptr[c+c2]; 196 | else 197 | f2[c2][k1] = 0.0; 198 | 199 | f2_grad[c2][k1] = 0.0; 200 | } 201 | 202 | __syncthreads(); 203 | 204 | const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1]; 205 | scalar_t g = 0.0; 206 | 207 | int ix_nw = H1*W1*((iy-1) + rd*(ix-1)); 208 | int ix_ne = H1*W1*((iy-1) + rd*ix); 209 | int ix_sw = H1*W1*(iy + rd*(ix-1)); 210 | int ix_se = H1*W1*(iy + rd*ix); 211 | 212 | if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1)) 213 | g += *(grad_ptr + ix_nw) * dy * dx; 214 | 215 | if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1)) 216 | g += *(grad_ptr + ix_ne) * dy * (1-dx); 217 | 218 | if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1)) 219 | g += *(grad_ptr + ix_sw) * (1-dy) * dx; 220 | 221 | if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1)) 222 | g += *(grad_ptr + ix_se) * (1-dy) * (1-dx); 223 | 224 | for (int k=0; k(floor(y2s[k1]))-r+iy; 232 | int w2 = static_cast(floor(x2s[k1]))-r+ix; 233 | int c2 = tid % CHANNEL_STRIDE; 234 | 235 | scalar_t* fptr = &fmap2_grad[b][h2][w2][0]; 236 | if (within_bounds(h2, w2, H2, W2)) 237 | atomicAdd(fptr+c+c2, f2_grad[c2][k1]); 238 | } 239 | } 240 | } 241 | } 242 | __syncthreads(); 243 | 244 | 245 | for (int k=0; k corr_cuda_forward( 261 | torch::Tensor fmap1, 262 | torch::Tensor fmap2, 263 | torch::Tensor coords, 264 | int radius) 265 | { 266 | const auto B = coords.size(0); 267 | const auto N = coords.size(1); 268 | const auto H = coords.size(2); 269 | const auto W = coords.size(3); 270 | 271 | const auto rd = 2 * radius + 1; 272 | auto opts = fmap1.options(); 273 | auto corr = torch::zeros({B, N, rd*rd, H, W}, opts); 274 | 275 | const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W); 276 | const dim3 threads(BLOCK_H, BLOCK_W); 277 | 278 | corr_forward_kernel<<>>( 279 | fmap1.packed_accessor32(), 280 | 
fmap2.packed_accessor32(), 281 | coords.packed_accessor32(), 282 | corr.packed_accessor32(), 283 | radius); 284 | 285 | return {corr}; 286 | } 287 | 288 | std::vector corr_cuda_backward( 289 | torch::Tensor fmap1, 290 | torch::Tensor fmap2, 291 | torch::Tensor coords, 292 | torch::Tensor corr_grad, 293 | int radius) 294 | { 295 | const auto B = coords.size(0); 296 | const auto N = coords.size(1); 297 | 298 | const auto H1 = fmap1.size(1); 299 | const auto W1 = fmap1.size(2); 300 | const auto H2 = fmap2.size(1); 301 | const auto W2 = fmap2.size(2); 302 | const auto C = fmap1.size(3); 303 | 304 | auto opts = fmap1.options(); 305 | auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts); 306 | auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts); 307 | auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts); 308 | 309 | const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W); 310 | const dim3 threads(BLOCK_H, BLOCK_W); 311 | 312 | 313 | corr_backward_kernel<<>>( 314 | fmap1.packed_accessor32(), 315 | fmap2.packed_accessor32(), 316 | coords.packed_accessor32(), 317 | corr_grad.packed_accessor32(), 318 | fmap1_grad.packed_accessor32(), 319 | fmap2_grad.packed_accessor32(), 320 | coords_grad.packed_accessor32(), 321 | radius); 322 | 323 | return {fmap1_grad, fmap2_grad, coords_grad}; 324 | } -------------------------------------------------------------------------------- /alt_cuda_corr/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='correlation', 7 | ext_modules=[ 8 | CUDAExtension('alt_cuda_corr', 9 | sources=['correlation.cpp', 'correlation_kernel.cu'], 10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | 16 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch.nn as nn 3 | 4 | 5 | def find_model_using_name(model_name): 6 | """Import the module "models/[model_name]_model.py". 7 | 8 | In the file, the class called [model_name]Model() will 9 | be instantiated. It has to be a subclass of nn.Module, 10 | and it is case-insensitive. 11 | """ 12 | model_filename = model_name + "_model" 13 | modellib = importlib.import_module(model_filename) 14 | model = None 15 | target_model_name = model_name.replace('_', '') + 'model' 16 | for name, cls in modellib.__dict__.items(): 17 | if name.lower() == target_model_name.lower() \ 18 | and issubclass(cls, nn.Module): 19 | model = cls 20 | 21 | if model is None: 22 | print("In %s.py, there should be a subclass of nn.Module with class name that matches %s in lowercase." % (model_filename, target_model_name)) 23 | exit(0) 24 | 25 | return model 26 | 27 | 28 | def create_model(opt): 29 | """Create a model given the option. 30 | 31 | This function warps the class CustomDatasetDataLoader. 
32 | This is the main interface between this package and 'train.py'/'test.py' 33 | 34 | Example: 35 | >>> from core import create_model 36 | >>> model = create_model(opt) 37 | """ 38 | model = find_model_using_name(opt.model) 39 | instance = model(opt) 40 | print("model [%s] was created" % type(instance).__name__) 41 | return instance -------------------------------------------------------------------------------- /core/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | self.corrMap = corr.view(batch, h1*w1, h2*w2) 23 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 24 | 25 | 26 | self.corr_pyramid.append(corr) 27 | for i in range(self.num_levels-1): 28 | corr = F.avg_pool2d(corr, 2, stride=2) 29 | self.corr_pyramid.append(corr) 30 | 31 | def __call__(self, coords): 32 | r = self.radius 33 | coords = coords.permute(0, 2, 3, 1) 34 | batch, h1, w1, _ = coords.shape 35 | 36 | out_pyramid = [] 37 | for i in range(self.num_levels): 38 | corr = self.corr_pyramid[i] 39 | dx = torch.linspace(-r, r, 2*r+1) 40 | dy = torch.linspace(-r, r, 2*r+1) 41 | delta = torch.stack(torch.meshgrid(dy, dx), dim=-1).to(coords.device) 42 | 43 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 44 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 45 | coords_lvl = centroid_lvl + delta_lvl 46 | 47 | corr = bilinear_sampler(corr, coords_lvl) 48 | corr = corr.view(batch, h1, w1, -1) 49 | out_pyramid.append(corr) 50 | 51 | out = torch.cat(out_pyramid, dim=-1) 52 | return out.permute(0, 3, 1, 2).contiguous().float() 53 | 54 | @staticmethod 55 | def corr(fmap1, fmap2): 56 | batch, dim, ht, wd = fmap1.shape 57 | fmap1 = fmap1.view(batch, dim, ht*wd) 58 | fmap2 = fmap2.view(batch, dim, ht*wd) 59 | 60 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 61 | corr = corr.view(batch, ht, wd, 1, ht, wd) 62 | return corr / torch.sqrt(torch.tensor(dim).float()) 63 | 64 | 65 | class AlternateCorrBlock: 66 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 67 | self.num_levels = num_levels 68 | self.radius = radius 69 | 70 | self.pyramid = [(fmap1, fmap2)] 71 | for i in range(self.num_levels): 72 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 73 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 74 | self.pyramid.append((fmap1, fmap2)) 75 | 76 | def __call__(self, coords): 77 | coords = coords.permute(0, 2, 3, 1) 78 | B, H, W, _ = coords.shape 79 | dim = self.pyramid[0][0].shape[1] 80 | 81 | corr_list = [] 82 | for i in range(self.num_levels): 83 | r = self.radius 84 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 85 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 86 | 87 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 88 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 89 | corr_list.append(corr.squeeze(1)) 90 | 91 | corr = torch.stack(corr_list, dim=1) 92 | corr = corr.reshape(B, -1, H, W) 93 | return corr / torch.sqrt(torch.tensor(dim).float()) 94 | 
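# ----------------------------------------------------------------------------
# Editor's addition: a minimal usage sketch for CorrBlock above, not part of
# the original corr.py. Tensor sizes are made up; fmap1/fmap2 stand in for the
# 1/8-resolution encoder features that the model normally provides.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    B, C, H, W = 1, 256, 46, 62
    fmap1 = torch.randn(B, C, H, W)
    fmap2 = torch.randn(B, C, H, W)
    corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)

    # coords holds, for every pixel of frame 1, its current match location in
    # frame 2 in (x, y) order; here it is initialized to the identity mapping.
    ys, xs = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float())
    coords = torch.stack([xs, ys], dim=0)[None]   # (B, 2, H, W)

    corr = corr_fn(coords)                        # (B, num_levels*(2*radius+1)**2, H, W)
    print(corr.shape)                             # torch.Size([1, 324, 46, 62])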
-------------------------------------------------------------------------------- /core/datasets.py: -------------------------------------------------------------------------------- 1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | import torch.nn.functional as F 7 | 8 | import os 9 | import math 10 | import random 11 | from glob import glob 12 | import os.path as osp 13 | 14 | from core.utils import frame_utils 15 | from core.utils.augmentor import FlowAugmentor, SparseFlowAugmentor 16 | 17 | import sys 18 | 19 | class FlowDataset(data.Dataset): 20 | def __init__(self, aug_params=None, sparse=False): 21 | self.augmentor = None 22 | self.sparse = sparse 23 | if aug_params is not None: 24 | if sparse: 25 | self.augmentor = SparseFlowAugmentor(**aug_params) 26 | else: 27 | self.augmentor = FlowAugmentor(**aug_params) 28 | 29 | self.is_test = False 30 | self.is_validate = False 31 | self.init_seed = False 32 | self.flow_list = [] 33 | self.image_list = [] 34 | self.extra_info = [] 35 | 36 | def __getitem__(self, index): 37 | # print('Index is {}'.format(index)) 38 | # sys.stdout.flush() 39 | if self.is_test: 40 | img1 = frame_utils.read_gen(self.image_list[index][0]) 41 | img2 = frame_utils.read_gen(self.image_list[index][1]) 42 | img1 = np.array(img1).astype(np.uint8)[..., :3] 43 | img2 = np.array(img2).astype(np.uint8)[..., :3] 44 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 45 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 46 | return img1, img2, self.extra_info[index] 47 | 48 | # if not self.init_seed: 49 | # worker_info = torch.utils.data.get_worker_info() 50 | # if worker_info is not None: 51 | # torch.manual_seed(worker_info.id) 52 | # np.random.seed(worker_info.id) 53 | # random.seed(worker_info.id) 54 | # self.init_seed = True 55 | 56 | index = index % len(self.image_list) 57 | valid = None 58 | if self.sparse: 59 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) 60 | else: 61 | flow = frame_utils.read_gen(self.flow_list[index]) 62 | 63 | img1 = frame_utils.read_gen(self.image_list[index][0]) 64 | img2 = frame_utils.read_gen(self.image_list[index][1]) 65 | 66 | flow = np.array(flow).astype(np.float32) 67 | img1 = np.array(img1).astype(np.uint8) 68 | img2 = np.array(img2).astype(np.uint8) 69 | 70 | # grayscale images 71 | if len(img1.shape) == 2: 72 | img1 = np.tile(img1[...,None], (1, 1, 3)) 73 | img2 = np.tile(img2[...,None], (1, 1, 3)) 74 | else: 75 | img1 = img1[..., :3] 76 | img2 = img2[..., :3] 77 | 78 | if self.augmentor is not None: 79 | if self.sparse: 80 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) 81 | else: 82 | img1, img2, flow = self.augmentor(img1, img2, flow) 83 | 84 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 85 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 86 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 87 | 88 | if valid is not None: 89 | valid = torch.from_numpy(valid) 90 | else: 91 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 92 | 93 | if self.is_validate: 94 | return img1, img2, flow, valid.float(), self.extra_info[index] 95 | else: 96 | return img1, img2, flow, valid.float() 97 | 98 | def getDataWithPath(self, index): 99 | img1, img2, flow, valid = self.__getitem__(index) 100 | 101 | imgPath_1 = self.image_list[index][0] 102 | imgPath_2 = self.image_list[index][1] 103 | 104 | return img1, img2, flow, valid, imgPath_1, imgPath_2 105 | 106 | def 
__rmul__(self, v): 107 | self.flow_list = v * self.flow_list 108 | self.image_list = v * self.image_list 109 | return self 110 | 111 | def __len__(self): 112 | return len(self.image_list) 113 | 114 | 115 | class MpiSintel(FlowDataset): 116 | def __init__(self, aug_params=None, split='training', root='./datasets/Sintel', dstype='clean', is_validate=False): 117 | super(MpiSintel, self).__init__(aug_params) 118 | flow_root = osp.join(root, split, 'flow') 119 | image_root = osp.join(root, split, dstype) 120 | 121 | self.is_validate = is_validate 122 | if split == 'test': 123 | self.is_test = True 124 | 125 | for scene in os.listdir(image_root): 126 | image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) 127 | for i in range(len(image_list)-1): 128 | self.image_list += [ [image_list[i], image_list[i+1]] ] 129 | self.extra_info += [ (scene, i) ] # scene and frame_id 130 | 131 | if split != 'test': 132 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) 133 | 134 | 135 | class FlyingChairs(FlowDataset): 136 | def __init__(self, aug_params=None, split='train', root='./datasets/FlyingChairs_release/data'): 137 | super(FlyingChairs, self).__init__(aug_params) 138 | 139 | images = sorted(glob(osp.join(root, '*.ppm'))) 140 | flows = sorted(glob(osp.join(root, '*.flo'))) 141 | assert (len(images)//2 == len(flows)) 142 | 143 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) 144 | for i in range(len(flows)): 145 | xid = split_list[i] 146 | if (split=='training' and xid==1) or (split=='validation' and xid==2): 147 | self.flow_list += [ flows[i] ] 148 | self.image_list += [ [images[2*i], images[2*i+1]] ] 149 | 150 | 151 | class FlyingThings3D(FlowDataset): 152 | def __init__(self, aug_params=None, root='./datasets/FlyingThings3D', dstype='frames_cleanpass'): 153 | super(FlyingThings3D, self).__init__(aug_params) 154 | 155 | for cam in ['left']: 156 | for direction in ['into_future', 'into_past']: 157 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) 158 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 159 | 160 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) 161 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 162 | 163 | for idir, fdir in zip(image_dirs, flow_dirs): 164 | images = sorted(glob(osp.join(idir, '*.png')) ) 165 | flows = sorted(glob(osp.join(fdir, '*.pfm')) ) 166 | for i in range(len(flows)-1): 167 | if direction == 'into_future': 168 | self.image_list += [ [images[i], images[i+1]] ] 169 | self.flow_list += [ flows[i] ] 170 | elif direction == 'into_past': 171 | self.image_list += [ [images[i+1], images[i]] ] 172 | self.flow_list += [ flows[i+1] ] 173 | 174 | 175 | class KITTI(FlowDataset): 176 | def __init__(self, aug_params=None, split='training', root='./datasets/kitti'): 177 | super(KITTI, self).__init__(aug_params, sparse=True) 178 | if split == 'testing': 179 | self.is_test = True 180 | 181 | root = osp.join(root, split) 182 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) 183 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) 184 | 185 | for img1, img2 in zip(images1, images2): 186 | frame_id = img1.split('/')[-1] 187 | self.extra_info += [ [frame_id] ] 188 | self.image_list += [ [img1, img2] ] 189 | 190 | if split == 'training': 191 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) 192 | 193 | 194 | class HD1K(FlowDataset): 195 | def __init__(self, aug_params=None, root='./datasets/hd1k_full_package'): 196 | super(HD1K, 
self).__init__(aug_params, sparse=True) 197 | 198 | seq_ix = 0 199 | while 1: 200 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 201 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 202 | 203 | if len(flows) == 0: 204 | break 205 | 206 | for i in range(len(flows)-1): 207 | self.flow_list += [flows[i]] 208 | self.image_list += [ [images[i], images[i+1]] ] 209 | 210 | seq_ix += 1 211 | 212 | 213 | def fetch_dataloader(args, TRAIN_DS='C+T+K/S'): 214 | """ Create the data loader for the corresponding trainign set """ 215 | 216 | if args.stage == 'chairs': 217 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} 218 | train_dataset = FlyingChairs(aug_params, split='training') 219 | 220 | elif args.stage == 'things': 221 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} 222 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') 223 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') 224 | train_dataset = clean_dataset + final_dataset 225 | 226 | elif args.stage == 'sintel': 227 | print('Training Sintel Stage...') 228 | sys.stdout.flush() 229 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} 230 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass') 231 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') 232 | sintel_final = MpiSintel(aug_params, split='training', dstype='final') 233 | 234 | if TRAIN_DS == 'C+T+K+S+H': 235 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) 236 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) 237 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things 238 | 239 | elif TRAIN_DS == 'C+T+K+H': 240 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) 241 | train_dataset = 100*sintel_clean + 100*sintel_final + 5*hd1k + things 242 | 243 | elif TRAIN_DS == 'C+T+S+K': 244 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) 245 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + things 246 | 247 | elif TRAIN_DS == 'C+T+K/S': 248 | train_dataset = 100*sintel_clean + 100*sintel_final + things 249 | 250 | elif args.stage == 'kitti': 251 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} 252 | train_dataset = KITTI(aug_params, split='training') 253 | 254 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, 255 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True) 256 | 257 | print('Training with %d image pairs' % len(train_dataset)) 258 | return train_loader 259 | -------------------------------------------------------------------------------- /core/extractor.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class ResidualBlock(nn.Module): 8 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 9 | super(ResidualBlock, self).__init__() 10 | 11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 13 | self.relu 
= nn.ReLU(inplace=True) 14 | 15 | num_groups = planes // 8 16 | 17 | if norm_fn == 'group': 18 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 20 | if not stride == 1: 21 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 22 | 23 | elif norm_fn == 'batch': 24 | self.norm1 = nn.BatchNorm2d(planes) 25 | self.norm2 = nn.BatchNorm2d(planes) 26 | if not stride == 1: 27 | self.norm3 = nn.BatchNorm2d(planes) 28 | 29 | elif norm_fn == 'instance': 30 | self.norm1 = nn.InstanceNorm2d(planes) 31 | self.norm2 = nn.InstanceNorm2d(planes) 32 | if not stride == 1: 33 | self.norm3 = nn.InstanceNorm2d(planes) 34 | 35 | elif norm_fn == 'none': 36 | self.norm1 = nn.Sequential() 37 | self.norm2 = nn.Sequential() 38 | if not stride == 1: 39 | self.norm3 = nn.Sequential() 40 | 41 | if stride == 1: 42 | self.downsample = None 43 | 44 | else: 45 | self.downsample = nn.Sequential( 46 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 47 | 48 | 49 | def forward(self, x): 50 | y = x 51 | y = self.relu(self.norm1(self.conv1(y))) 52 | y = self.relu(self.norm2(self.conv2(y))) 53 | 54 | if self.downsample is not None: 55 | x = self.downsample(x) 56 | 57 | return self.relu(x+y) 58 | 59 | 60 | 61 | class BottleneckBlock(nn.Module): 62 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 63 | super(BottleneckBlock, self).__init__() 64 | 65 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 66 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 67 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 68 | self.relu = nn.ReLU(inplace=True) 69 | 70 | num_groups = planes // 8 71 | 72 | if norm_fn == 'group': 73 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 74 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 75 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 76 | if not stride == 1: 77 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 78 | 79 | elif norm_fn == 'batch': 80 | self.norm1 = nn.BatchNorm2d(planes//4) 81 | self.norm2 = nn.BatchNorm2d(planes//4) 82 | self.norm3 = nn.BatchNorm2d(planes) 83 | if not stride == 1: 84 | self.norm4 = nn.BatchNorm2d(planes) 85 | 86 | elif norm_fn == 'instance': 87 | self.norm1 = nn.InstanceNorm2d(planes//4) 88 | self.norm2 = nn.InstanceNorm2d(planes//4) 89 | self.norm3 = nn.InstanceNorm2d(planes) 90 | if not stride == 1: 91 | self.norm4 = nn.InstanceNorm2d(planes) 92 | 93 | elif norm_fn == 'none': 94 | self.norm1 = nn.Sequential() 95 | self.norm2 = nn.Sequential() 96 | self.norm3 = nn.Sequential() 97 | if not stride == 1: 98 | self.norm4 = nn.Sequential() 99 | 100 | if stride == 1: 101 | self.downsample = None 102 | 103 | else: 104 | self.downsample = nn.Sequential( 105 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 106 | 107 | 108 | def forward(self, x): 109 | y = x 110 | y = self.relu(self.norm1(self.conv1(y))) 111 | y = self.relu(self.norm2(self.conv2(y))) 112 | y = self.relu(self.norm3(self.conv3(y))) 113 | 114 | if self.downsample is not None: 115 | x = self.downsample(x) 116 | 117 | return self.relu(x+y) 118 | 119 | class BasicEncoder(nn.Module): 120 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 121 | super(BasicEncoder, self).__init__() 122 | self.norm_fn = norm_fn 123 | 124 | if self.norm_fn == 'group': 125 | 
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 126 | 127 | elif self.norm_fn == 'batch': 128 | self.norm1 = nn.BatchNorm2d(64) 129 | 130 | elif self.norm_fn == 'instance': 131 | self.norm1 = nn.InstanceNorm2d(64) 132 | 133 | elif self.norm_fn == 'none': 134 | self.norm1 = nn.Sequential() 135 | 136 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 137 | self.relu1 = nn.ReLU(inplace=True) 138 | 139 | self.in_planes = 64 140 | self.layer1 = self._make_layer(64, stride=1) 141 | self.layer2 = self._make_layer(96, stride=2) 142 | self.layer3 = self._make_layer(128, stride=2) 143 | 144 | # output convolution 145 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 146 | 147 | self.dropout = None 148 | if dropout > 0: 149 | self.dropout = nn.Dropout2d(p=dropout) 150 | 151 | for m in self.modules(): 152 | if isinstance(m, nn.Conv2d): 153 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 154 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 155 | if m.weight is not None: 156 | nn.init.constant_(m.weight, 1) 157 | if m.bias is not None: 158 | nn.init.constant_(m.bias, 0) 159 | 160 | def _make_layer(self, dim, stride=1): 161 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 162 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 163 | layers = (layer1, layer2) 164 | 165 | self.in_planes = dim 166 | return nn.Sequential(*layers) 167 | 168 | 169 | def forward(self, x): 170 | 171 | # if input is list, combine batch dimension 172 | is_list = isinstance(x, tuple) or isinstance(x, list) 173 | if is_list: 174 | batch_dim = x[0].shape[0] 175 | x = torch.cat(x, dim=0) 176 | 177 | x = self.conv1(x) 178 | x = self.norm1(x) 179 | x = self.relu1(x) 180 | 181 | x = self.layer1(x) 182 | x = self.layer2(x) 183 | x = self.layer3(x) 184 | 185 | x = self.conv2(x) 186 | 187 | if self.training and self.dropout is not None: 188 | x = self.dropout(x) 189 | 190 | if is_list: 191 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 192 | 193 | return x 194 | 195 | 196 | class IPTHeadEncoder(nn.Module): 197 | """docstring for IPTHead""" 198 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 199 | super(IPTHeadEncoder, self).__init__() 200 | self.norm_fn = norm_fn 201 | 202 | if self.norm_fn == 'group': 203 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 204 | 205 | elif self.norm_fn == 'batch': 206 | self.norm1 = nn.BatchNorm2d(64) 207 | 208 | elif self.norm_fn == 'instance': 209 | self.norm1 = nn.InstanceNorm2d(64) 210 | 211 | elif self.norm_fn == 'none': 212 | self.norm1 = nn.Sequential() 213 | 214 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 215 | self.relu1 = nn.ReLU(inplace=True) 216 | 217 | half_out_dim = max(output_dim // 2, 64) 218 | self.layer1 = ResidualBlock(64, half_out_dim, self.norm_fn, stride=2) 219 | self.layer2 = ResidualBlock(half_out_dim, output_dim, self.norm_fn, stride=2) 220 | 221 | # # output convolution; this can solve mixed memory warning, not know why 222 | # self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 223 | 224 | self.dropout = None 225 | if dropout > 0: 226 | self.dropout = nn.Dropout2d(p=dropout) 227 | 228 | def forward(self, x): 229 | 230 | # if input is list, combine batch dimension 231 | is_list = isinstance(x, tuple) or isinstance(x, list) 232 | if is_list: 233 | batch_dim = x[0].shape[0] 234 | x = torch.cat(x, dim=0) 235 | 236 | x = self.relu1(self.norm1(self.conv1(x))) 237 | x = self.layer1(x) 238 | x = self.layer2(x) 239 | 
240 | if self.training and self.dropout is not None: 241 | x = self.dropout(x) 242 | 243 | if is_list: 244 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 245 | 246 | return x 247 | 248 | 249 | class BasicConvEncoder(nn.Module): 250 | """docstring for BasicConvEncoder""" 251 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 252 | super(BasicConvEncoder, self).__init__() 253 | self.norm_fn = norm_fn 254 | 255 | half_out_dim = max(output_dim // 2, 64) 256 | 257 | if self.norm_fn == 'group': 258 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 259 | self.norm2 = nn.GroupNorm(num_groups=8, num_channels=64) 260 | self.norm3 = nn.GroupNorm(num_groups=8, num_channels=64) 261 | 262 | elif self.norm_fn == 'batch': 263 | self.norm1 = nn.BatchNorm2d(64) 264 | self.norm2 = nn.BatchNorm2d(half_out_dim) 265 | self.norm3 = nn.BatchNorm2d(output_dim) 266 | 267 | elif self.norm_fn == 'instance': 268 | self.norm1 = nn.InstanceNorm2d(64) 269 | self.norm2 = nn.InstanceNorm2d(half_out_dim) 270 | self.norm3 = nn.InstanceNorm2d(output_dim) 271 | 272 | elif self.norm_fn == 'none': 273 | self.norm1 = nn.Sequential() 274 | self.norm2 = nn.Sequential() 275 | self.norm3 = nn.Sequential() 276 | 277 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 278 | self.conv2 = nn.Conv2d(64, half_out_dim, kernel_size=3, stride=2, padding=1) 279 | self.conv3 = nn.Conv2d(half_out_dim, output_dim, kernel_size=3, stride=2, padding=1) 280 | 281 | # # output convolution; this can solve mixed memory warning, not know why 282 | # self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 283 | 284 | self.dropout = None 285 | if dropout > 0: 286 | self.dropout = nn.Dropout2d(p=dropout) 287 | 288 | def forward(self, x): 289 | 290 | # if input is list, combine batch dimension 291 | is_list = isinstance(x, tuple) or isinstance(x, list) 292 | if is_list: 293 | batch_dim = x[0].shape[0] 294 | x = torch.cat(x, dim=0) 295 | 296 | x = F.relu(self.norm1(self.conv1(x)), inplace=True) 297 | x = F.relu(self.norm2(self.conv2(x)), inplace=True) 298 | x = F.relu(self.norm3(self.conv3(x)), inplace=True) 299 | 300 | if self.training and self.dropout is not None: 301 | x = self.dropout(x) 302 | 303 | if is_list: 304 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 305 | 306 | return x 307 | 308 | 309 | class SmallEncoder(nn.Module): 310 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 311 | super(SmallEncoder, self).__init__() 312 | self.norm_fn = norm_fn 313 | 314 | if self.norm_fn == 'group': 315 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 316 | 317 | elif self.norm_fn == 'batch': 318 | self.norm1 = nn.BatchNorm2d(32) 319 | 320 | elif self.norm_fn == 'instance': 321 | self.norm1 = nn.InstanceNorm2d(32) 322 | 323 | elif self.norm_fn == 'none': 324 | self.norm1 = nn.Sequential() 325 | 326 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) 327 | self.relu1 = nn.ReLU(inplace=True) 328 | 329 | self.in_planes = 32 330 | self.layer1 = self._make_layer(32, stride=1) 331 | self.layer2 = self._make_layer(64, stride=2) 332 | self.layer3 = self._make_layer(96, stride=2) 333 | 334 | self.dropout = None 335 | if dropout > 0: 336 | self.dropout = nn.Dropout2d(p=dropout) 337 | 338 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 339 | 340 | for m in self.modules(): 341 | if isinstance(m, nn.Conv2d): 342 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 343 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 344 | if 
m.weight is not None: 345 | nn.init.constant_(m.weight, 1) 346 | if m.bias is not None: 347 | nn.init.constant_(m.bias, 0) 348 | 349 | def _make_layer(self, dim, stride=1): 350 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) 351 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 352 | layers = (layer1, layer2) 353 | 354 | self.in_planes = dim 355 | return nn.Sequential(*layers) 356 | 357 | 358 | def forward(self, x): 359 | 360 | # if input is list, combine batch dimension 361 | is_list = isinstance(x, tuple) or isinstance(x, list) 362 | if is_list: 363 | batch_dim = x[0].shape[0] 364 | x = torch.cat(x, dim=0) 365 | 366 | x = self.conv1(x) 367 | x = self.norm1(x) 368 | x = self.relu1(x) 369 | 370 | x = self.layer1(x) 371 | x = self.layer2(x) 372 | x = self.layer3(x) 373 | x = self.conv2(x) 374 | 375 | if self.training and self.dropout is not None: 376 | x = self.dropout(x) 377 | 378 | if is_list: 379 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 380 | 381 | return x 382 | 383 | 384 | 385 | 386 | class MultiHeadAttention(nn.Module): 387 | ''' Multi-Head Attention module ''' 388 | 389 | def __init__(self, d_model, n_head, dropout=0.1): 390 | super().__init__() 391 | 392 | assert d_model % n_head == 0 393 | 394 | self.n_head = n_head 395 | self.d_model = d_model 396 | self.d_head = self.d_model // self.n_head 397 | 398 | self.w_qs = nn.Linear(self.d_model, self.d_model, bias=False) # TODO: enable bias 399 | self.w_ks = nn.Linear(self.d_model, self.d_model, bias=False) 400 | self.w_vs = nn.Linear(self.d_model, self.d_model, bias=False) 401 | self.fc = nn.Linear(self.d_model, self.d_model) 402 | 403 | # self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5) # TODO 404 | 405 | self.dropout = nn.Dropout(dropout) 406 | # self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 407 | 408 | 409 | def forward(self, q, k, v): 410 | ''' 411 | q: shape of N*len*C 412 | ''' 413 | d_head, n_head = self.d_head, self.n_head 414 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) 415 | 416 | # residual = q 417 | 418 | # Pass through the pre-attention projection: b x lq x (n*dv) 419 | # Separate different heads: b x lq x n x dv 420 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_head) 421 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_head) 422 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_head) 423 | 424 | # Transpose for attention dot product: b x n x lq x dv 425 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 426 | 427 | attn = torch.matmul(q, k.transpose(-1,-2)) / math.sqrt(self.d_head) 428 | attn = self.dropout(F.softmax(attn, dim=-1)) 429 | q_updated = torch.matmul(attn, v) 430 | 431 | # Transpose to move the head dimension back: b x lq x n x dv 432 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) 433 | q_updated = q_updated.transpose(1, 2).contiguous().view(sz_b, len_q, -1) 434 | q_updated = self.dropout(self.fc(q_updated)) 435 | # q_updated += residual 436 | 437 | # q_updated = self.layer_norm(q_updated) 438 | 439 | return q_updated, attn 440 | 441 | 442 | 443 | class AnchorEncoderBlock(nn.Module): 444 | 445 | def __init__(self, anchor_dist, d_model, num_heads, d_ff, dropout=0.): 446 | super().__init__() 447 | 448 | self.anchor_dist = anchor_dist 449 | self.half_anchor_dist = anchor_dist // 2 450 | 451 | self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout) 452 | self.dropout = nn.Dropout(dropout) 453 | self.layer_norm_1 = nn.LayerNorm(d_model) 454 | 455 | 
self.FFN = nn.Sequential( 456 | nn.Linear(d_model, d_ff), 457 | nn.ReLU(), 458 | nn.Linear(d_ff, d_model), 459 | ) 460 | 461 | self.layer_norm_2 = nn.LayerNorm(d_model) 462 | 463 | 464 | def forward(self, inputs): 465 | ''' 466 | inputs: batches with N*C*H*W 467 | ''' 468 | N, C, H, W = inputs.shape 469 | 470 | x = inputs 471 | anchors = inputs[:,:, self.half_anchor_dist::self.anchor_dist, 472 | self.half_anchor_dist::self.anchor_dist].clone() 473 | 474 | # flatten feature maps 475 | x = x.reshape(N, C, H*W).transpose(-1,-2) 476 | anchors = anchors.reshape(N, C, anchors.shape[2]* anchors.shape[3]).transpose(-1,-2) 477 | 478 | # two-stage multi-head self-attention 479 | anchors_new = self.dropout(self.selfAttn(anchors, x, x)[0]) 480 | residual = self.dropout(self.selfAttn(x, anchors_new, anchors_new)[0]) 481 | 482 | norm_1 = self.layer_norm_1(x + residual) 483 | x_linear = self.dropout(self.FFN(norm_1)) 484 | x_new = self.layer_norm_2(norm_1 + x_linear) 485 | 486 | outputs = x_new.transpose(-1,-2).reshape(N, C, H, W) 487 | return outputs 488 | 489 | 490 | 491 | class EncoderBlock(nn.Module): 492 | 493 | def __init__(self, d_model, num_heads, d_ff, dropout=0.): 494 | super().__init__() 495 | 496 | self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout) 497 | self.dropout = nn.Dropout(dropout) 498 | self.layer_norm_1 = nn.LayerNorm(d_model) 499 | 500 | self.FFN = nn.Sequential( 501 | nn.Linear(d_model, d_ff), 502 | nn.ReLU(), 503 | nn.Linear(d_ff, d_model), 504 | ) 505 | 506 | self.layer_norm_2 = nn.LayerNorm(d_model) 507 | 508 | 509 | def forward(self, x): 510 | ''' 511 | x: input batches with N*C*H*W 512 | ''' 513 | N, C, H, W = x.shape 514 | 515 | # update x 516 | x = x.reshape(N, C, H*W).transpose(-1,-2) 517 | 518 | residual = self.dropout(self.selfAttn(x, x, x)[0]) 519 | norm_1 = self.layer_norm_1(x + residual) 520 | x_linear = self.dropout(self.FFN(norm_1)) 521 | x_new = self.layer_norm_2(norm_1 + x_linear) 522 | 523 | outputs = x_new.transpose(-1,-2).reshape(N, C, H, W) 524 | return outputs 525 | 526 | 527 | 528 | class ReduceEncoderBlock(nn.Module): 529 | 530 | def __init__(self, d_model, num_heads, d_ff, dropout=0.): 531 | super().__init__() 532 | 533 | self.reduce = nn.Sequential( 534 | nn.Conv2d(d_model, d_model, 2, 2), 535 | nn.Conv2d(d_model, d_model, 2, 2) 536 | ) 537 | # self.reduce = nn.Sequential( 538 | # nn.AvgPool2d(16, 16) 539 | # ) 540 | 541 | self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout) 542 | self.dropout = nn.Dropout(dropout) 543 | self.layer_norm_1 = nn.LayerNorm(d_model) 544 | 545 | self.FFN = nn.Sequential( 546 | nn.Linear(d_model, d_ff), 547 | nn.ReLU(), 548 | nn.Linear(d_ff, d_model) 549 | ) 550 | 551 | self.layer_norm_2 = nn.LayerNorm(d_model) 552 | 553 | 554 | def forward(self, x): 555 | ''' 556 | x: input batches with N*C*H*W 557 | ''' 558 | N, C, H, W = x.shape 559 | x_reduced = self.reduce(x) 560 | 561 | # update x 562 | x = x.reshape(N, C, H*W).transpose(-1,-2) 563 | x_reduced = x_reduced.reshape(N, C, -1).transpose(-1,-2) 564 | 565 | # print('x ', x.shape) 566 | # print('x_reduced ', x_reduced.shape) 567 | # exit() 568 | 569 | residual = self.dropout(self.selfAttn(x, x_reduced, x_reduced)[0]) 570 | 571 | norm_1 = self.layer_norm_1(x + residual) 572 | x_linear = self.dropout(self.FFN(norm_1)) 573 | x_new = self.layer_norm_2(norm_1 + x_linear) 574 | 575 | outputs = x_new.transpose(-1,-2).reshape(N, C, H, W) 576 | return outputs 577 | 578 | 579 | 580 | def window_partition(x, window_size): 581 | """ 582 | Args: 583 | x: (B, H, W, C) 
584 | window_size (int): window size 585 | 586 | Returns: 587 | windows: (num_windows*B, window_size, window_size, C) 588 | """ 589 | B, H, W, C = x.shape 590 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 591 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 592 | return windows 593 | 594 | 595 | def window_reverse(windows, window_size, H, W): 596 | """ 597 | Args: 598 | windows: (num_windows*B, window_size, window_size, C) 599 | window_size (int): Window size 600 | H (int): Height of image 601 | W (int): Width of image 602 | 603 | Returns: 604 | x: (B, H, W, C) 605 | """ 606 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 607 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 608 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 609 | return x 610 | 611 | 612 | class LayerEncoderBlock(nn.Module): 613 | 614 | def __init__(self, win_size, d_model, num_heads, d_ff, dropout=0.): 615 | super().__init__() 616 | 617 | self.win_size = win_size 618 | self.down_factor = 4 619 | self.unfold_stride = int(self.win_size//self.down_factor) 620 | 621 | self.stride_list = [math.floor(win_size/self.down_factor**idx) for idx in range(8) if win_size/self.down_factor**idx >= 1] 622 | # [16, 4, 1] 623 | 624 | self.reduce = nn.Sequential( 625 | nn.AvgPool2d(self.down_factor, self.down_factor) 626 | ) 627 | 628 | self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout) 629 | self.crossAttn = MultiHeadAttention(d_model, num_heads, dropout) 630 | 631 | self.dropout = nn.Dropout(dropout) 632 | self.layerNormSelf = nn.LayerNorm(d_model) 633 | self.layerNormCross = nn.LayerNorm(d_model) 634 | 635 | self.FFN = nn.Sequential( 636 | nn.Linear(d_model, d_ff), 637 | nn.ReLU(), 638 | nn.Linear(d_ff, d_model) 639 | ) 640 | 641 | self.layer_norm_out = nn.LayerNorm(d_model) 642 | 643 | 644 | def Circular_pad2D(self, x, pad_right, pad_bottom): 645 | ''' 646 | x: (N, H, W, C) 647 | x_pad: (N, H_pad, W_pad, C) 648 | ''' 649 | N, H, W, C = x.shape 650 | 651 | H_pad = H + pad_bottom 652 | W_pad = W + pad_right 653 | 654 | H_repeat = math.ceil(H_pad/H) 655 | W_repeat = math.ceil(W_pad/W) 656 | x_repeat = x.repeat(1, H_repeat, W_repeat, 1) 657 | 658 | x_pad = x_repeat[:, :H_pad, :W_pad, :] 659 | return x_pad 660 | 661 | 662 | def pad_fit_win(self, x, win_size): 663 | N, H, W, C = x.shape 664 | 665 | W_ = math.ceil(W/win_size)*win_size 666 | H_ = math.ceil(H/win_size)*win_size 667 | padRight = W_ - W 668 | padBottom = H_ - H 669 | 670 | x_pad = self.Circular_pad2D(x, padRight, padBottom) # N*H_*W_*C 671 | return x_pad 672 | 673 | 674 | def self_attention(self, x): 675 | ''' 676 | x: (N, H, W, C) 677 | out: (N, H, W, C) 678 | ''' 679 | N, H, W, C = x.shape 680 | x_pad = self.pad_fit_win(x, self.win_size) # N*H_*W_*C 681 | _, H_, W_, _ = x_pad.shape 682 | 683 | # x_pad = F.pad(x.permute(xxx), (0, padRight, 0, padBottom), mode='reflect') # N*C*H_*W_ 684 | 685 | x_window = window_partition(x_pad, self.win_size) # (num_win*B, win_size, win_size, C) 686 | x_window = x_window.view(-1, self.win_size*self.win_size, C) # (num_win*B, win_size*win_size, C) 687 | 688 | # self-attention 689 | residual = self.dropout(self.selfAttn(x_window, x_window, x_window)[0]) 690 | residual = residual.view(-1, self.win_size, self.win_size, C) 691 | residual = window_reverse(residual, self.win_size, H_, W_) # (N, H_, W_, C) 692 | 693 | out = x_pad + residual 694 | out = out[:, :H, :W, :] 695 | return out 696 | 
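    # ------------------------------------------------------------------
    # Editor's note (not original code): shape walk-through of self_attention
    # above, with hypothetical sizes and win_size = 16 (matching the
    # "[16, 4, 1]" stride_list comment in __init__).
    #   x:      (N, H, W, C) = (2, 46, 62, 256)
    #   pad_fit_win        -> x_pad: (2, 48, 64, 256)   # circular pad to multiples of 16
    #   window_partition   -> (N * 3 * 4, 16, 16, 256) = (24, 16, 16, 256)
    #   .view              -> (24, 16*16, 256) tokens for window-local self-attention
    #   window_reverse     -> residual: (2, 48, 64, 256); out = x_pad + residual,
    #                         then cropped back to (2, 46, 62, 256)
    # ------------------------------------------------------------------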
697 | 698 | def cross_attention(self, query, keyVal): 699 | ''' 700 | query: (N, qH, qW, C) 701 | keyVal: (N, kH, kW, C) 702 | out: (N, qH, qW, C) 703 | ''' 704 | _, qH, qW, C = query.shape 705 | _, kH, kW, C = keyVal.shape 706 | 707 | # print('in query ', query.shape) 708 | # print('in keyVal ', keyVal.shape) 709 | # print('-') 710 | 711 | query = self.pad_fit_win(query, self.win_size) # N*H_*W_*C 712 | _, qH_, qW_, C = query.shape 713 | 714 | query_win = window_partition(query, self.win_size) 715 | query_win = query_win.view(-1, self.win_size*self.win_size, C) # (num_win*B, win_size*win_size, C) 716 | 717 | # pad and unfold keyVal 718 | kW_ = (math.ceil(kW/self.unfold_stride) - 1)*self.unfold_stride + self.win_size 719 | kH_ = (math.ceil(kH/self.unfold_stride) - 1)*self.unfold_stride + self.win_size 720 | padRight = kW_ - kW 721 | padBottom = kH_ - kH 722 | 723 | keyVal_pad = self.Circular_pad2D(keyVal, padRight, padBottom) 724 | keyVal = F.unfold(keyVal_pad.permute(0, 3, 1, 2), self.win_size, stride=self.unfold_stride) # (N, C*win_size*win_size, num_win) 725 | keyVal = keyVal.permute(0,2,1).reshape(-1, C, self.win_size*self.win_size).permute(0,2,1) # (num_win*B, win_size*win_size, C) 726 | 727 | # print('win query ', query_win.shape) 728 | # print('win keyVal ', keyVal.shape) 729 | # print('-') 730 | 731 | residual = self.dropout(self.crossAttn(query_win, keyVal, keyVal)[0]) 732 | residual = residual.view(-1, self.win_size, self.win_size, C) 733 | residual = window_reverse(residual, self.win_size, qH_, qW_) # (N, H, W, C) 734 | 735 | out = query + residual 736 | out = out[:, :qH, :qW, :] 737 | return out 738 | 739 | 740 | def forward(self, x): 741 | ''' 742 | x: input batches with N*C*H*W 743 | ''' 744 | N, C, H, W = x.shape 745 | x = x.permute(0, 2, 3, 1) # N*H*W*C 746 | x = self.pad_fit_win(x, self.win_size) # pad 747 | 748 | # layered self-attention 749 | layerAttnList = [] 750 | strideListLen = len(self.stride_list) 751 | for idx in range(strideListLen): 752 | x_attn = self.self_attention(x) # built-in shortcut 753 | x_attn = self.layerNormSelf(x_attn) 754 | layerAttnList.append(x_attn) 755 | 756 | if idx < strideListLen - 1: 757 | x = self.reduce(x_attn.permute(0, 3, 1 ,2)) # N*C*H*W 758 | x = x.permute(0, 2, 3, 1) # N*H*W*C 759 | 760 | # layered cross-attention 761 | KeyVal = layerAttnList[-1] 762 | for idx in range(strideListLen-1, 0, -1): 763 | Query = layerAttnList[idx-1] 764 | Query = self.cross_attention(Query, KeyVal) # built-in shortcut 765 | Query = self.layerNormCross(Query) 766 | 767 | KeyVal = Query 768 | 769 | Query = Query[:, :H, :W, :] # unpad 770 | 771 | q_residual = self.dropout(self.FFN(Query)) 772 | x_new = self.layer_norm_out(Query + q_residual) 773 | 774 | outputs = x_new.permute(0, 3, 1, 2) 775 | return outputs 776 | 777 | 778 | 779 | class BasicLayerEncoderBlock(nn.Module): 780 | 781 | def __init__(self, win_size, d_model, num_heads, d_ff, dropout=0.): 782 | super().__init__() 783 | 784 | self.win_size = win_size 785 | self.down_factor = 2 786 | self.unfold_stride = int(self.win_size//self.down_factor) 787 | 788 | self.stride_list = [math.floor(win_size/self.down_factor**idx) for idx in range(8) if win_size/self.down_factor**idx >= 1] 789 | # [16, 8, 4, 2, 1] 790 | 791 | self.reduce = nn.Sequential( 792 | nn.AvgPool2d(self.down_factor, self.down_factor) 793 | ) 794 | 795 | self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout) 796 | self.crossAttn = MultiHeadAttention(d_model, num_heads, dropout) 797 | 798 | self.dropout = nn.Dropout(dropout) 799 | 
self.layerNormSelf = nn.LayerNorm(d_model) 800 | self.layerNormCross = nn.LayerNorm(d_model) 801 | 802 | self.FFN = nn.Sequential( 803 | nn.Linear(d_model, d_ff), 804 | nn.ReLU(), 805 | nn.Linear(d_ff, d_model) 806 | ) 807 | 808 | self.layer_norm_out = nn.LayerNorm(d_model) 809 | 810 | 811 | def Circular_pad2D(self, x, pad_right, pad_bottom): 812 | ''' 813 | x: (N, H, W, C) 814 | x_pad: (N, H_pad, W_pad, C) 815 | ''' 816 | N, H, W, C = x.shape 817 | 818 | H_pad = H + pad_bottom 819 | W_pad = W + pad_right 820 | 821 | H_repeat = math.ceil(H_pad/H) 822 | W_repeat = math.ceil(W_pad/W) 823 | x_repeat = x.repeat(1, H_repeat, W_repeat, 1) 824 | 825 | x_pad = x_repeat[:, :H_pad, :W_pad, :] 826 | return x_pad 827 | 828 | 829 | def pad_fit_win(self, x, win_size): 830 | N, H, W, C = x.shape 831 | 832 | W_ = math.ceil(W/win_size)*win_size 833 | H_ = math.ceil(H/win_size)*win_size 834 | padRight = W_ - W 835 | padBottom = H_ - H 836 | 837 | x_pad = self.Circular_pad2D(x, padRight, padBottom) # N*H_*W_*C 838 | return x_pad 839 | 840 | 841 | def self_attention(self, x): 842 | ''' 843 | x: (N, H, W, C) 844 | out: (N, H, W, C) 845 | ''' 846 | N, H, W, C = x.shape 847 | x_pad = self.pad_fit_win(x, self.win_size) # N*H_*W_*C 848 | _, H_, W_, _ = x_pad.shape 849 | 850 | # x_pad = F.pad(x.permute(xxx), (0, padRight, 0, padBottom), mode='reflect') # N*C*H_*W_ 851 | 852 | x_window = window_partition(x_pad, self.win_size) # (num_win*B, win_size, win_size, C) 853 | x_window = x_window.view(-1, self.win_size*self.win_size, C) # (num_win*B, win_size*win_size, C) 854 | 855 | # self-attention 856 | residual = self.dropout(self.selfAttn(x_window, x_window, x_window)[0]) 857 | residual = residual.view(-1, self.win_size, self.win_size, C) 858 | residual = window_reverse(residual, self.win_size, H_, W_) # (N, H_, W_, C) 859 | 860 | out = x_pad + residual 861 | out = out[:, :H, :W, :] 862 | return out 863 | 864 | 865 | def cross_attention(self, query, keyVal, query_win_size): 866 | ''' 867 | query: (N, qH, qW, C) 868 | keyVal: (N, kH, kW, C) 869 | out: (N, qH, qW, C) 870 | ''' 871 | _, qH, qW, C = query.shape 872 | 873 | query_win = window_partition(query, query_win_size) 874 | query_win = query_win.view(-1, query_win_size*query_win_size, C) # (num_win*B, win_size*win_size, C) 875 | 876 | keyWinSize = query_win_size // 2 877 | keyVal_win = window_partition(keyVal, keyWinSize) 878 | keyVal_win = keyVal_win.view(-1, keyWinSize*keyWinSize, C) # (num_win*B, win_size*win_size, C) 879 | 880 | residual = self.dropout(self.crossAttn(query_win, keyVal_win, keyVal_win)[0]) 881 | residual = residual.view(-1, query_win_size, query_win_size, C) 882 | residual = window_reverse(residual, query_win_size, qH, qW) # (N, H, W, C) 883 | 884 | out = query + residual 885 | return out 886 | 887 | 888 | def forward(self, x): 889 | ''' 890 | x: input batches with N*C*H*W 891 | ''' 892 | N, C, H, W = x.shape 893 | x = x.permute(0, 2, 3, 1) # N*H*W*C 894 | x = self.pad_fit_win(x, self.win_size) # pad 895 | 896 | # layered self-attention 897 | layerAttnList = [] 898 | strideListLen = len(self.stride_list) 899 | for idx in range(strideListLen): 900 | x_attn = self.self_attention(x) # built-in shortcut 901 | x_attn = self.layerNormSelf(x_attn) 902 | layerAttnList.append(x_attn) 903 | 904 | if idx < strideListLen - 1: 905 | x = self.reduce(x_attn.permute(0, 3, 1 ,2)) # N*C*H*W 906 | x = x.permute(0, 2, 3, 1) # N*H*W*C 907 | 908 | # layered cross-attention 909 | KeyVal = layerAttnList[-1] 910 | for idx in range(strideListLen-1, 0, -1): 911 | Query = 
layerAttnList[idx-1] 912 | QueryWinSize = self.stride_list[idx-1] 913 | 914 | Query = self.cross_attention(Query, KeyVal, QueryWinSize) # built-in shortcut 915 | Query = self.layerNormCross(Query) 916 | 917 | KeyVal = Query 918 | 919 | Query = Query[:, :H, :W, :] # unpad 920 | 921 | q_residual = self.dropout(self.FFN(Query)) 922 | x_new = self.layer_norm_out(Query + q_residual) 923 | 924 | outputs = x_new.permute(0, 3, 1, 2) 925 | return outputs 926 | 927 | 928 | 929 | 930 | class PositionalEncoding(nn.Module): 931 | 932 | def __init__(self, d_model, dropout=0.): 933 | super().__init__() 934 | 935 | self.max_len = 256 936 | self.d_model = d_model 937 | 938 | self._update_PE_table(self.max_len, self.d_model//2) 939 | 940 | 941 | def _update_PE_table(self, max_len, d_model): 942 | self.PE_table = torch.zeros(max_len, d_model) 943 | 944 | pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 945 | denominator = torch.pow(10000, torch.arange(0, d_model, 2).float()/d_model) 946 | 947 | self.PE_table[:, 0::2] = torch.sin(pos/denominator) 948 | self.PE_table[:, 1::2] = torch.cos(pos/denominator) 949 | 950 | 951 | def forward(self, x): 952 | ''' x: image batches with N*C*H*W ''' 953 | 954 | N, C, H, W = x.shape 955 | max_hw = max(H, W) 956 | 957 | if max_hw > self.max_len or self.d_model != C: 958 | self.max_len = max_hw 959 | self.d_model = C 960 | 961 | self._update_PE_table(self.max_len, self.d_model//2) 962 | 963 | if self.PE_table.device != x.device: 964 | self.PE_table = self.PE_table.to(x.device) 965 | 966 | h_pos_emb = self.PE_table[:H, :].unsqueeze(1).repeat(1, W, 1) # H*W*C/2 967 | w_pos_emb = self.PE_table[:W, :].unsqueeze(0).repeat(H, 1, 1) # H*W*C/2 968 | pos_emb = torch.cat([h_pos_emb, w_pos_emb], dim=-1 969 | ).permute([2,0,1]).unsqueeze(0).repeat(N,1,1,1) # N*C*H*W 970 | 971 | output = x + pos_emb 972 | return output 973 | 974 | 975 | 976 | class TransformerEncoder(nn.Module): 977 | 978 | def __init__(self, anchor_dist, num_blocks, d_model, num_heads, d_ff, dropout=0.): 979 | super().__init__() 980 | 981 | self.anchor_dist = anchor_dist 982 | 983 | blocks_list = [] 984 | for idx in range(num_blocks): 985 | # blocks_list.append( AnchorEncoderBlock(anchor_dist, d_model, num_heads, d_ff, dropout) ) 986 | # blocks_list.append( EncoderBlock(d_model, num_heads, d_ff, dropout) ) 987 | blocks_list.append( ReduceEncoderBlock(d_model, num_heads, d_ff, dropout) ) 988 | # blocks_list.append( BasicLayerEncoderBlock(anchor_dist, d_model, num_heads, d_ff, dropout) ) 989 | 990 | self.blocks = nn.Sequential(*blocks_list) 991 | 992 | self.posEmbedding = PositionalEncoding(d_model, dropout) 993 | 994 | 995 | def forward(self, x): 996 | x_w_pos = self.posEmbedding(x) 997 | x_updated = self.blocks(x_w_pos) 998 | 999 | return x_updated 1000 | 1001 | 1002 | 1003 | class RawInputTransEncoder(nn.Module): 1004 | 1005 | def __init__(self, anchor_dist, num_blocks, d_model, num_heads, d_ff, dropout=0.): 1006 | super().__init__() 1007 | 1008 | self.anchor_dist = anchor_dist 1009 | 1010 | self.linear = nn.Conv2d(3, d_model, 8, 8) 1011 | 1012 | blocks_list = [] 1013 | for idx in range(num_blocks): 1014 | # blocks_list.append( AnchorEncoderBlock(anchor_dist, d_model, num_heads, d_ff, dropout) ) 1015 | # blocks_list.append( EncoderBlock(d_model, num_heads, d_ff, dropout) ) 1016 | # blocks_list.append( ReduceEncoderBlock(d_model, num_heads, d_ff, dropout) ) 1017 | blocks_list.append( LayerEncoderBlock(anchor_dist, d_model, num_heads, d_ff, dropout) ) 1018 | 1019 | self.blocks = nn.Sequential(*blocks_list) 
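`PositionalEncoding` (defined above, instantiated just below) builds a single sinusoidal table of width d_model//2 and concatenates a row encoding with a column encoding along the channel axis, so half the channels carry the y position and half the x position. A hedged toy re-derivation with assumed sizes, for illustration only:

import torch

# Reproduce the 2D sin/cos position embedding with d_model//2 = 4, H = 3, W = 5 (toy sizes).
d_half, H, W, max_len = 4, 3, 5, 16
pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
denom = torch.pow(10000, torch.arange(0, d_half, 2).float() / d_half)
table = torch.zeros(max_len, d_half)
table[:, 0::2] = torch.sin(pos / denom)
table[:, 1::2] = torch.cos(pos / denom)
h_emb = table[:H].unsqueeze(1).repeat(1, W, 1)        # (H, W, d_half), varies with the row index
w_emb = table[:W].unsqueeze(0).repeat(H, 1, 1)        # (H, W, d_half), varies with the column index
pos_emb = torch.cat([h_emb, w_emb], dim=-1)           # (H, W, 2*d_half), added to the feature map
assert pos_emb.shape == (H, W, 2 * d_half)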
1020 | 1021 | self.posEmbedding = PositionalEncoding(d_model, dropout) 1022 | 1023 | # initialization 1024 | for m in self.modules(): 1025 | if isinstance(m, nn.Conv2d): 1026 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 1027 | elif isinstance(m, nn.Linear): 1028 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 1029 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm, nn.LayerNorm)): 1030 | if m.weight is not None: 1031 | nn.init.constant_(m.weight, 1) 1032 | if m.bias is not None: 1033 | nn.init.constant_(m.bias, 0) 1034 | 1035 | 1036 | def forward(self, x): 1037 | # if input is list, combine batch dimension 1038 | is_list = isinstance(x, tuple) or isinstance(x, list) 1039 | if is_list: 1040 | batch_dim = x[0].shape[0] 1041 | x = torch.cat(x, dim=0) 1042 | 1043 | x = self.linear(x) 1044 | x_w_pos = self.posEmbedding(x) 1045 | x_updated = self.blocks(x_w_pos) 1046 | 1047 | if is_list: 1048 | x_updated = torch.split(x_updated, [batch_dim, batch_dim], dim=0) 1049 | 1050 | return x_updated 1051 | 1052 | 1053 | 1054 | class GlobalLocalBlock(nn.Module): 1055 | 1056 | def __init__(self, anchor_dist, d_model, num_heads, out_dim, dropout=0., stride=1): 1057 | super().__init__() 1058 | 1059 | self.anchor_dist = anchor_dist 1060 | self.half_anchor_dist = anchor_dist // 2 1061 | self.d_model = d_model 1062 | self.out_dim = out_dim 1063 | 1064 | self.selfAttn = MultiHeadAttention(d_model, num_heads, dropout) 1065 | self.dropout = nn.Dropout(dropout) 1066 | self.layer_norm_1 = nn.LayerNorm(d_model) 1067 | 1068 | self.resBlock_1 = ResidualBlock(d_model, d_model, norm_fn='instance', stride=stride) 1069 | self.change_channel = nn.Linear(d_model, out_dim) 1070 | self.resBlock_2 = ResidualBlock(out_dim, out_dim, norm_fn='instance', stride=1) 1071 | 1072 | self.posEmbedding = PositionalEncoding(d_model, dropout) 1073 | 1074 | 1075 | def forward(self, inputs): 1076 | ''' 1077 | inputs: batches with N*H*W*C 1078 | ''' 1079 | 1080 | # local update 1 1081 | x = self.resBlock_1(inputs) 1082 | x = self.posEmbedding(x) 1083 | anchors = x[:,:, self.half_anchor_dist::self.anchor_dist, 1084 | self.half_anchor_dist::self.anchor_dist].clone() 1085 | 1086 | # flatten feature maps 1087 | N, C, H, W = x.shape 1088 | x = x.reshape(N, C, H*W).transpose(-1,-2) 1089 | anchors = anchors.reshape(N, C, anchors.shape[2]* anchors.shape[3]).transpose(-1,-2) 1090 | 1091 | # gloabl update with two-stage multi-head self-attention 1092 | anchors_new = self.dropout(self.selfAttn(anchors, x, x)[0]) 1093 | residual = self.dropout(self.selfAttn(x, anchors_new, anchors_new)[0]) 1094 | norm_1 = self.layer_norm_1(x + residual) 1095 | 1096 | # local update 2 1097 | norm_1 = self.change_channel(norm_1) 1098 | norm_1 = norm_1.transpose(-1,-2).reshape(N, self.out_dim, H, W) 1099 | outputs = self.resBlock_2(norm_1) 1100 | 1101 | return outputs 1102 | 1103 | 1104 | 1105 | class GlobalLocalEncoder(nn.Module): 1106 | 1107 | def __init__(self, anchor_dist, output_dim, dropout=0.): 1108 | super().__init__() 1109 | 1110 | self.anchor_dist = anchor_dist 1111 | self.output_dim = output_dim 1112 | 1113 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 1114 | self.norm1 = nn.InstanceNorm2d(64) 1115 | self.relu1 = nn.ReLU(inplace=True) 1116 | 1117 | self.layer1 = GlobalLocalBlock(self.anchor_dist, 64, 2, 96, dropout, stride=2) 1118 | self.layer2 = GlobalLocalBlock(self.anchor_dist, 96, 3, 96, dropout, stride=1) 1119 | self.layer3 = GlobalLocalBlock(self.anchor_dist//2, 96, 
4, 128, dropout, stride=2) 1120 | 1121 | # output convolution 1122 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 1123 | 1124 | self.dropout = None 1125 | if dropout > 0: 1126 | self.dropout = nn.Dropout2d(p=dropout) 1127 | 1128 | # initialization 1129 | for m in self.modules(): 1130 | if isinstance(m, nn.Conv2d): 1131 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 1132 | elif isinstance(m, nn.Linear): 1133 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 1134 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm, nn.LayerNorm)): 1135 | if m.weight is not None: 1136 | nn.init.constant_(m.weight, 1) 1137 | if m.bias is not None: 1138 | nn.init.constant_(m.bias, 0) 1139 | 1140 | 1141 | # def _make_layer(self, in_dim, out_dim, dropout=0., stride=1): 1142 | # layer1 = GlobalLocalBlock(self.anchor_dist, in_dim, in_dim//32, out_dim, dropout=0., stride=stride) 1143 | # layer2 = GlobalLocalBlock(self.anchor_dist, out_dim, out_dim//32, out_dim, dropout=0., stride=1) 1144 | # layers = (layer1, layer2) 1145 | 1146 | # return nn.Sequential(*layers) 1147 | 1148 | 1149 | def forward(self, x): 1150 | # if input is list, combine batch dimension 1151 | is_list = isinstance(x, tuple) or isinstance(x, list) 1152 | if is_list: 1153 | batch_dim = x[0].shape[0] 1154 | x = torch.cat(x, dim=0) 1155 | 1156 | x = self.relu1(self.norm1(self.conv1(x))) 1157 | 1158 | x = self.layer1(x) 1159 | x = self.layer2(x) 1160 | x = self.layer3(x) 1161 | 1162 | x = self.conv2(x) 1163 | 1164 | if self.training and self.dropout is not None: 1165 | x = self.dropout(x) 1166 | 1167 | if is_list: 1168 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 1169 | 1170 | return x -------------------------------------------------------------------------------- /core/gma.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange 4 | 5 | 6 | class RelPosEmb(nn.Module): 7 | def __init__( 8 | self, 9 | max_pos_size, 10 | dim_head 11 | ): 12 | super().__init__() 13 | self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head) 14 | self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head) 15 | 16 | deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(max_pos_size).view(-1, 1) 17 | rel_ind = deltas + max_pos_size - 1 18 | self.register_buffer('rel_ind', rel_ind) 19 | 20 | def forward(self, q): 21 | batch, heads, h, w, c = q.shape 22 | height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1)) 23 | width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1)) 24 | 25 | height_emb = rearrange(height_emb, '(x u) d -> x u () d', x=h) 26 | width_emb = rearrange(width_emb, '(y v) d -> y () v d', y=w) 27 | 28 | height_score = einsum('b h x y d, x u v d -> b h x y u v', q, height_emb) 29 | width_score = einsum('b h x y d, y u v d -> b h x y u v', q, width_emb) 30 | 31 | return height_score + width_score 32 | 33 | 34 | class Attention(nn.Module): 35 | def __init__( 36 | self, 37 | *, 38 | args, 39 | dim, 40 | max_pos_size = 100, 41 | heads = 4, 42 | dim_head = 128, 43 | ): 44 | super().__init__() 45 | self.args = args 46 | self.heads = heads 47 | self.scale = dim_head ** -0.5 48 | inner_dim = heads * dim_head 49 | 50 | self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False) 51 | 52 | self.pos_emb = RelPosEmb(max_pos_size, dim_head) 53 | 54 | def forward(self, fmap): 55 | heads, b, c, h, w = self.heads, *fmap.shape 56 | 57 | q, k = 
self.to_qk(fmap).chunk(2, dim=1) 58 | 59 | q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k)) 60 | q = self.scale * q 61 | 62 | if self.args.position_only: 63 | sim = self.pos_emb(q) 64 | 65 | elif self.args.position_and_content: 66 | sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k) 67 | sim_pos = self.pos_emb(q) 68 | sim = sim_content + sim_pos 69 | 70 | else: 71 | sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k) 72 | 73 | sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)') 74 | attn = sim.softmax(dim=-1) 75 | 76 | return attn 77 | 78 | 79 | class Aggregate(nn.Module): 80 | def __init__( 81 | self, 82 | args, 83 | dim, 84 | heads = 4, 85 | dim_head = 128, 86 | ): 87 | super().__init__() 88 | self.args = args 89 | self.heads = heads 90 | self.scale = dim_head ** -0.5 91 | inner_dim = heads * dim_head 92 | 93 | self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False) 94 | 95 | self.gamma = nn.Parameter(torch.zeros(1)) 96 | 97 | if dim != inner_dim: 98 | self.project = nn.Conv2d(inner_dim, dim, 1, bias=False) 99 | else: 100 | self.project = None 101 | 102 | def forward(self, attn, fmap): 103 | heads, b, c, h, w = self.heads, *fmap.shape 104 | 105 | v = self.to_v(fmap) 106 | v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads) 107 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 108 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w) 109 | 110 | if self.project is not None: 111 | out = self.project(out) 112 | 113 | out = fmap + self.gamma * out 114 | 115 | return out 116 | 117 | 118 | if __name__ == "__main__": 119 | att = Attention(dim=128, heads=1) 120 | fmap = torch.randn(2, 128, 40, 90) 121 | out = att(fmap) 122 | 123 | print(out.shape) 124 | -------------------------------------------------------------------------------- /core/gmflownet_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from update import BasicUpdateBlock 8 | from extractor import BasicEncoder, BasicConvEncoder 9 | from corr import CorrBlock, AlternateCorrBlock 10 | from utils.utils import bilinear_sampler, coords_grid, upflow8 11 | from swin_transformer import POLAUpdate, MixAxialPOLAUpdate 12 | 13 | try: 14 | autocast = torch.cuda.amp.autocast 15 | except: 16 | # dummy autocast for PyTorch < 1.6 17 | class autocast: 18 | def __init__(self, enabled): 19 | pass 20 | def __enter__(self): 21 | pass 22 | def __exit__(self, *args): 23 | pass 24 | 25 | 26 | class GMFlowNetModel(nn.Module): 27 | def __init__(self, args): 28 | super().__init__() 29 | self.args = args 30 | 31 | self.hidden_dim = hdim = 128 32 | self.context_dim = cdim = 128 33 | args.corr_levels = 4 34 | args.corr_radius = 4 35 | 36 | if not hasattr(self.args, 'dropout'): 37 | self.args.dropout = 0 38 | 39 | if not hasattr(self.args, 'alternate_corr'): 40 | self.args.alternate_corr = False 41 | 42 | # feature network, context network, and update block 43 | if self.args.use_mix_attn: 44 | self.fnet = nn.Sequential( 45 | BasicConvEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout), 46 | MixAxialPOLAUpdate(embed_dim=256, depth=6, num_head=8, window_size=7) 47 | ) 48 | else: 49 | self.fnet = nn.Sequential( 50 | BasicConvEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout), 51 | POLAUpdate(embed_dim=256, depth=6, num_head=8, window_size=7, neig_win_num=1) 52 | ) 53 | 54 | self.cnet = 
BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) 55 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim, input_dim=cdim) 56 | 57 | 58 | def freeze_bn(self): 59 | for m in self.modules(): 60 | if isinstance(m, nn.BatchNorm2d): 61 | m.eval() 62 | 63 | def initialize_flow(self, img): 64 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 65 | N, C, H, W = img.shape 66 | coords0 = coords_grid(N, H//8, W//8).to(img.device) 67 | coords1 = coords_grid(N, H//8, W//8).to(img.device) 68 | 69 | # optical flow computed as difference: flow = coords1 - coords0 70 | return coords0, coords1 71 | 72 | def upsample_flow(self, flow, mask): 73 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 74 | N, _, H, W = flow.shape 75 | mask = mask.view(N, 1, 9, 8, 8, H, W) 76 | mask = torch.softmax(mask, dim=2) 77 | 78 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 79 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 80 | 81 | up_flow = torch.sum(mask * up_flow, dim=2) 82 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 83 | return up_flow.reshape(N, 2, 8*H, 8*W) 84 | 85 | 86 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): 87 | """ Estimate optical flow between pair of frames """ 88 | 89 | image1 = 2 * (image1 / 255.0) - 1.0 90 | image2 = 2 * (image2 / 255.0) - 1.0 91 | 92 | image1 = image1.contiguous() 93 | image2 = image2.contiguous() 94 | 95 | hdim = self.hidden_dim 96 | cdim = self.context_dim 97 | 98 | # run the feature network 99 | with autocast(enabled=self.args.mixed_precision): 100 | fmap1, fmap2 = self.fnet([image1, image2]) 101 | 102 | fmap1 = fmap1.float() 103 | fmap2 = fmap2.float() 104 | 105 | # # Self-attention update 106 | # fmap1 = self.transEncoder(fmap1) 107 | # fmap2 = self.transEncoder(fmap2) 108 | 109 | if self.args.alternate_corr: 110 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 111 | else: 112 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 113 | 114 | # run the context network 115 | with autocast(enabled=self.args.mixed_precision): 116 | cnet = self.cnet(image1) 117 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 118 | net = torch.tanh(net) 119 | inp = torch.relu(inp) 120 | 121 | coords0, coords1 = self.initialize_flow(image1) 122 | 123 | # Correlation as initialization 124 | N, fC, fH, fW = fmap1.shape 125 | corrMap = corr_fn.corrMap 126 | 127 | #_, coords_index = torch.max(corrMap, dim=-1) # no gradient here 128 | softCorrMap = F.softmax(corrMap, dim=2) * F.softmax(corrMap, dim=1) # (N, fH*fW, fH*fW) 129 | 130 | if flow_init is not None: 131 | coords1 = coords1 + flow_init 132 | else: 133 | # print('matching as init') 134 | # mutual match selection 135 | match12, match_idx12 = softCorrMap.max(dim=2) # (N, fH*fW) 136 | match21, match_idx21 = softCorrMap.max(dim=1) 137 | 138 | for b_idx in range(N): 139 | match21_b = match21[b_idx,:] 140 | match_idx12_b = match_idx12[b_idx,:] 141 | match21[b_idx,:] = match21_b[match_idx12_b] 142 | 143 | matched = (match12 - match21) == 0 # (N, fH*fW) 144 | coords_index = torch.arange(fH*fW).unsqueeze(0).repeat(N,1).to(softCorrMap.device) 145 | coords_index[matched] = match_idx12[matched] 146 | 147 | # matched coords 148 | coords_index = coords_index.reshape(N, fH, fW) 149 | coords_x = coords_index % fW 150 | coords_y = coords_index // fW 151 | 152 | coords_xy = torch.stack([coords_x, coords_y], dim=1).float() 153 | coords1 = coords_xy 154 | 155 | # 
Iterative update 156 | flow_predictions = [] 157 | for itr in range(iters): 158 | coords1 = coords1.detach() 159 | corr = corr_fn(coords1) # index correlation volume 160 | 161 | flow = coords1 - coords0 162 | with autocast(enabled=self.args.mixed_precision): 163 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 164 | 165 | # F(t+1) = F(t) + \Delta(t) 166 | coords1 = coords1 + delta_flow 167 | 168 | # upsample predictions 169 | if up_mask is None: 170 | flow_up = upflow8(coords1 - coords0) 171 | else: 172 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 173 | 174 | flow_predictions.append(flow_up) 175 | 176 | if test_mode: 177 | return coords1 - coords0, flow_up 178 | 179 | return flow_predictions, softCorrMap 180 | -------------------------------------------------------------------------------- /core/loss.py: -------------------------------------------------------------------------------- 1 | # from operator import index 2 | import torch 3 | import sys 4 | import numpy as np 5 | # from .utils.iamge_utils import save_flow, save_image 6 | import torch.nn.functional as F 7 | # from .utils.kornia import create_meshgrid 8 | 9 | 10 | from typing import Optional 11 | def create_meshgrid( 12 | height: int, 13 | width: int, 14 | normalized_coordinates: bool = True, 15 | device: Optional[torch.device] = torch.device('cpu'), 16 | dtype: torch.dtype = torch.float32, 17 | ) -> torch.Tensor: 18 | """Generate a coordinate grid for an image. 19 | 20 | When the flag ``normalized_coordinates`` is set to True, the grid is 21 | normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch 22 | function :py:func:`torch.nn.functional.grid_sample`. 23 | 24 | Args: 25 | height: the image height (rows). 26 | width: the image width (cols). 27 | normalized_coordinates: whether to normalize 28 | coordinates in the range :math:`[-1,1]` in order to be consistent with the 29 | PyTorch function :py:func:`torch.nn.functional.grid_sample`. 30 | device: the device on which the grid will be generated. 31 | dtype: the data type of the generated grid. 32 | 33 | Return: 34 | grid tensor with shape :math:`(1, H, W, 2)`. 35 | 36 | Example: 37 | >>> create_meshgrid(2, 2) 38 | tensor([[[[-1., -1.], 39 | [ 1., -1.]], 40 | 41 | [[-1., 1.], 42 | [ 1., 1.]]]]) 43 | 44 | >>> create_meshgrid(2, 2, normalized_coordinates=False) 45 | tensor([[[[0., 0.], 46 | [1., 0.]], 47 | 48 | [[0., 1.], 49 | [1., 1.]]]]) 50 | """ 51 | xs: torch.Tensor = torch.linspace(0, width - 1, width, device=device, dtype=dtype) 52 | ys: torch.Tensor = torch.linspace(0, height - 1, height, device=device, dtype=dtype) 53 | # Fix TracerWarning 54 | # Note: normalize_pixel_coordinates still gots TracerWarning since new width and height 55 | # tensors will be generated. 
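# Illustrative check (assumed width of 5, not part of the original file): the normalization
# a few lines below maps pixel index 0 to -1 and index (width - 1) to +1, which is the
# convention grid_sample expects:
#   (torch.linspace(0, 4, 5) / 4 - 0.5) * 2  ->  tensor([-1.0, -0.5, 0.0, 0.5, 1.0])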
56 | # Below is the code using normalize_pixel_coordinates: 57 | # base_grid: torch.Tensor = torch.stack(torch.meshgrid([xs, ys]), dim=2) 58 | # if normalized_coordinates: 59 | # base_grid = K.geometry.normalize_pixel_coordinates(base_grid, height, width) 60 | # return torch.unsqueeze(base_grid.transpose(0, 1), dim=0) 61 | if normalized_coordinates: 62 | xs = (xs / (width - 1) - 0.5) * 2 63 | ys = (ys / (height - 1) - 0.5) * 2 64 | # generate grid by stacking coordinates 65 | base_grid: torch.Tensor = torch.stack(torch.meshgrid([xs, ys])).transpose(1, 2) # 2xHxW 66 | return torch.unsqueeze(base_grid, dim=0).permute(0, 2, 3, 1) # 1xHxWx2 67 | 68 | 69 | def backwarp(img, flow): 70 | _, _, H, W = img.size() 71 | 72 | u = flow[:, 0, :, :] 73 | v = flow[:, 1, :, :] 74 | 75 | gridX, gridY = np.meshgrid(np.arange(W), np.arange(H)) 76 | gridX = torch.tensor(gridX, requires_grad=False,).to(flow.device) 77 | gridY = torch.tensor(gridY, requires_grad=False,).to(flow.device) 78 | x = gridX.unsqueeze(0).expand_as(u).float() + u 79 | y = gridY.unsqueeze(0).expand_as(v).float() + v 80 | # range -1 to 1 81 | x = 2*(x/W - 0.5) 82 | y = 2*(y/H - 0.5) 83 | # stacking X and Y 84 | grid = torch.stack((x,y), dim=3) 85 | # Sample pixels using bilinear interpolation. 86 | imgOut = torch.nn.functional.grid_sample(img, grid, align_corners=False) 87 | 88 | return imgOut 89 | 90 | 91 | @torch.no_grad() 92 | def compute_supervision_coarse(flow, occlusions, scale: int): 93 | N, _, H, W = flow.shape 94 | Hc, Wc = int(np.ceil(H / scale)), int(np.ceil(W / scale)) 95 | 96 | occlusions_c = occlusions[:, :, ::scale, ::scale] 97 | flow_c = flow[:, :, ::scale, ::scale] / scale 98 | occlusions_c = occlusions_c.reshape(N, Hc * Wc) 99 | 100 | grid_c = create_meshgrid(Hc, Wc, False, device=flow.device).reshape(1, Hc * Wc, 2).repeat(N, 1, 1) 101 | warp_c = grid_c + flow_c.permute(0, 2, 3, 1).reshape(N, Hc * Wc, 2) 102 | warp_c = warp_c.round().long() 103 | 104 | def out_bound_mask(pt, w, h): 105 | return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) 106 | 107 | occlusions_c[out_bound_mask(warp_c, Wc, Hc)] = 1 108 | warp_c = warp_c[..., 0] + warp_c[..., 1] * Wc 109 | 110 | b_ids, i_ids = torch.split(torch.nonzero(occlusions_c == 0), 1, dim=1) 111 | conf_matrix_gt = torch.zeros(N, Hc * Wc, Hc * Wc, device=flow.device) 112 | j_ids = warp_c[b_ids, i_ids] 113 | conf_matrix_gt[b_ids, i_ids, j_ids] = 1 114 | 115 | return conf_matrix_gt 116 | 117 | 118 | def compute_coarse_loss(conf, conf_gt, cfg): 119 | c_pos_w, c_neg_w = cfg.POS_WEIGHT, cfg.NEG_WEIGHT 120 | pos_mask, neg_mask = conf_gt == 1, conf_gt == 0 121 | 122 | if cfg.COARSE_TYPE == 'cross_entropy': 123 | conf = torch.clamp(conf, 1e-6, 1 - 1e-6) 124 | loss_pos = -torch.log(conf[pos_mask]) 125 | loss_neg = -torch.log(1 - conf[neg_mask]) 126 | 127 | return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean() 128 | elif cfg.COARSE_TYPE == 'focal': 129 | conf = torch.clamp(conf, 1e-6, 1 - 1e-6) 130 | alpha = cfg.FOCAL_ALPHA 131 | gamma = cfg.FOCAL_GAMMA 132 | loss_pos = -alpha * torch.pow(1 - conf[pos_mask], gamma) * (conf[pos_mask]).log() 133 | loss_neg = -alpha * torch.pow(conf[neg_mask], gamma) * (1 - conf[neg_mask]).log() 134 | return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean() 135 | else: 136 | raise ValueError('Unknown coarse loss: {type}'.format(type=cfg.COARSE_TYPE)) 137 | 138 | 139 | def compute_fine_loss(kflow, kflow_gt, cfg): 140 | fine_correct_thr = cfg.WINDOW_SIZE // 2 * 2 141 | error = (kflow - kflow_gt).abs() 142 | correct = 
torch.max(error, dim=1)[0] < fine_correct_thr 143 | rate = torch.sum(correct).float() / correct.shape[0] 144 | num = correct.shape[0] 145 | return error[correct].mean(), rate.item(), num 146 | 147 | 148 | def compute_flow_loss(flow, flow_gt): 149 | loss = (flow - flow_gt).abs().mean() 150 | epe = torch.sum((flow - flow_gt)**2, dim=1).sqrt() 151 | 152 | metrics = { 153 | 'epe': epe.mean().item(), 154 | '1px': (epe < 1).float().mean().item(), 155 | '3px': (epe < 3).float().mean().item(), 156 | '5px': (epe < 5).float().mean().item(), 157 | } 158 | 159 | return loss, metrics 160 | -------------------------------------------------------------------------------- /core/onecyclelr.py: -------------------------------------------------------------------------------- 1 | import types 2 | import math 3 | from torch._six import inf 4 | from functools import wraps 5 | import warnings 6 | import weakref 7 | from collections import Counter 8 | from bisect import bisect_right 9 | 10 | EPOCH_DEPRECATION_WARNING = ( 11 | "The epoch parameter in `scheduler.step()` was not necessary and is being " 12 | "deprecated where possible. Please use `scheduler.step()` to step the " 13 | "scheduler. During the deprecation, if epoch is different from None, the " 14 | "closed form is used instead of the new chainable form, where available. " 15 | "Please open an issue if you are unable to replicate your use case: " 16 | "https://github.com/pytorch/pytorch/issues/new/choose." 17 | ) 18 | 19 | class _LRScheduler(object): 20 | 21 | def __init__(self, optimizer, last_epoch=-1): 22 | 23 | # Attach optimizer 24 | # if not isinstance(optimizer, Optimizer): 25 | # raise TypeError('{} is not an Optimizer'.format( 26 | # type(optimizer).__name__)) 27 | self.optimizer = optimizer 28 | 29 | # Initialize epoch and base learning rates 30 | if last_epoch == -1: 31 | for group in optimizer.param_groups: 32 | group.setdefault('initial_lr', group['lr']) 33 | else: 34 | for i, group in enumerate(optimizer.param_groups): 35 | if 'initial_lr' not in group: 36 | raise KeyError("param 'initial_lr' is not specified " 37 | "in param_groups[{}] when resuming an optimizer".format(i)) 38 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 39 | self.last_epoch = last_epoch 40 | 41 | # Following https://github.com/pytorch/pytorch/issues/20124 42 | # We would like to ensure that `lr_scheduler.step()` is called after 43 | # `optimizer.step()` 44 | def with_counter(method): 45 | if getattr(method, '_with_counter', False): 46 | # `optimizer.step()` has already been replaced, return. 47 | return method 48 | 49 | # Keep a weak reference to the optimizer instance to prevent 50 | # cyclic references. 51 | instance_ref = weakref.ref(method.__self__) 52 | # Get the unbound method for the same purpose. 53 | func = method.__func__ 54 | cls = instance_ref().__class__ 55 | del method 56 | 57 | @wraps(func) 58 | def wrapper(*args, **kwargs): 59 | instance = instance_ref() 60 | instance._step_count += 1 61 | wrapped = func.__get__(instance, cls) 62 | return wrapped(*args, **kwargs) 63 | 64 | # Note that the returned function here is no longer a bound method, 65 | # so attributes like `__func__` and `__self__` no longer exist. 66 | wrapper._with_counter = True 67 | return wrapper 68 | 69 | self.optimizer.step = with_counter(self.optimizer.step) 70 | self.optimizer._step_count = 0 71 | self._step_count = 0 72 | 73 | self.step() 74 | 75 | def state_dict(self): 76 | """Returns the state of the scheduler as a :class:`dict`. 
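        A hedged usage sketch (checkpoint keys assumed, not defined by this repo):
        save with ``torch.save({'scheduler': scheduler.state_dict()}, 'ckpt.pth')`` and,
        after rebuilding the optimizer and scheduler on resume, restore with
        ``scheduler.load_state_dict(torch.load('ckpt.pth')['scheduler'])``.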
77 | 78 | It contains an entry for every variable in self.__dict__ which 79 | is not the optimizer. 80 | """ 81 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 82 | 83 | def load_state_dict(self, state_dict): 84 | """Loads the schedulers state. 85 | 86 | Arguments: 87 | state_dict (dict): scheduler state. Should be an object returned 88 | from a call to :meth:`state_dict`. 89 | """ 90 | self.__dict__.update(state_dict) 91 | 92 | def get_last_lr(self): 93 | """ Return last computed learning rate by current scheduler. 94 | """ 95 | return self._last_lr 96 | 97 | def get_lr(self): 98 | # Compute learning rate using chainable form of the scheduler 99 | raise NotImplementedError 100 | 101 | def step(self, epoch=None): 102 | # Raise a warning if old pattern is detected 103 | # https://github.com/pytorch/pytorch/issues/20124 104 | if self._step_count == 1: 105 | if not hasattr(self.optimizer.step, "_with_counter"): 106 | warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler " 107 | "initialization. Please, make sure to call `optimizer.step()` before " 108 | "`lr_scheduler.step()`. See more details at " 109 | "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) 110 | 111 | # Just check if there were two first lr_scheduler.step() calls before optimizer.step() 112 | elif self.optimizer._step_count < 1: 113 | warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. " 114 | "In PyTorch 1.1.0 and later, you should call them in the opposite order: " 115 | "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " 116 | "will result in PyTorch skipping the first value of the learning rate schedule. " 117 | "See more details at " 118 | "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) 119 | self._step_count += 1 120 | 121 | class _enable_get_lr_call: 122 | 123 | def __init__(self, o): 124 | self.o = o 125 | 126 | def __enter__(self): 127 | self.o._get_lr_called_within_step = True 128 | return self 129 | 130 | def __exit__(self, type, value, traceback): 131 | self.o._get_lr_called_within_step = False 132 | 133 | with _enable_get_lr_call(self): 134 | if epoch is None: 135 | self.last_epoch += 1 136 | values = self.get_lr() 137 | else: 138 | warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) 139 | self.last_epoch = epoch 140 | if hasattr(self, "_get_closed_form_lr"): 141 | values = self._get_closed_form_lr() 142 | else: 143 | values = self.get_lr() 144 | 145 | for param_group, lr in zip(self.optimizer.param_groups, values): 146 | param_group['lr'] = lr 147 | 148 | self._last_lr = [group['lr'] for group in self.optimizer.param_groups] 149 | 150 | class OneCycleLR(_LRScheduler): 151 | r"""Sets the learning rate of each parameter group according to the 152 | 1cycle learning rate policy. The 1cycle policy anneals the learning 153 | rate from an initial learning rate to some maximum learning rate and then 154 | from that maximum learning rate to some minimum learning rate much lower 155 | than the initial learning rate. 156 | This policy was initially described in the paper `Super-Convergence: 157 | Very Fast Training of Neural Networks Using Large Learning Rates`_. 158 | 159 | The 1cycle learning rate policy changes the learning rate after every batch. 160 | `step` should be called after a batch has been used for training. 161 | 162 | This scheduler is not chainable. 
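    Both phases interpolate with the cosine rule implemented by `_annealing_cos` further
    below (a sketch, assuming the default anneal_strategy='cos'):

        lr(pct) = end + (start - end) / 2 * (1 + cos(pi * pct)),  with pct running from 0 to 1,

    so pct = 0 returns `start` and pct = 1 returns `end`; the warm-up phase anneals from
    initial_lr = max_lr / div_factor up to max_lr, and the decay phase from max_lr down to
    min_lr = initial_lr / final_div_factor.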
163 | 164 | Note also that the total number of steps in the cycle can be determined in one 165 | of two ways (listed in order of precedence): 166 | 167 | #. A value for total_steps is explicitly provided. 168 | #. A number of epochs (epochs) and a number of steps per epoch 169 | (steps_per_epoch) are provided. 170 | In this case, the number of total steps is inferred by 171 | total_steps = epochs * steps_per_epoch 172 | 173 | You must either provide a value for total_steps or provide a value for both 174 | epochs and steps_per_epoch. 175 | 176 | Args: 177 | optimizer (Optimizer): Wrapped optimizer. 178 | max_lr (float or list): Upper learning rate boundaries in the cycle 179 | for each parameter group. 180 | total_steps (int): The total number of steps in the cycle. Note that 181 | if a value is not provided here, then it must be inferred by providing 182 | a value for epochs and steps_per_epoch. 183 | Default: None 184 | epochs (int): The number of epochs to train for. This is used along 185 | with steps_per_epoch in order to infer the total number of steps in the cycle 186 | if a value for total_steps is not provided. 187 | Default: None 188 | steps_per_epoch (int): The number of steps per epoch to train for. This is 189 | used along with epochs in order to infer the total number of steps in the 190 | cycle if a value for total_steps is not provided. 191 | Default: None 192 | pct_start (float): The percentage of the cycle (in number of steps) spent 193 | increasing the learning rate. 194 | Default: 0.3 195 | anneal_strategy (str): {'cos', 'linear'} 196 | Specifies the annealing strategy: "cos" for cosine annealing, "linear" for 197 | linear annealing. 198 | Default: 'cos' 199 | cycle_momentum (bool): If ``True``, momentum is cycled inversely 200 | to learning rate between 'base_momentum' and 'max_momentum'. 201 | Default: True 202 | base_momentum (float or list): Lower momentum boundaries in the cycle 203 | for each parameter group. Note that momentum is cycled inversely 204 | to learning rate; at the peak of a cycle, momentum is 205 | 'base_momentum' and learning rate is 'max_lr'. 206 | Default: 0.85 207 | max_momentum (float or list): Upper momentum boundaries in the cycle 208 | for each parameter group. Functionally, 209 | it defines the cycle amplitude (max_momentum - base_momentum). 210 | Note that momentum is cycled inversely 211 | to learning rate; at the start of a cycle, momentum is 'max_momentum' 212 | and learning rate is 'base_lr' 213 | Default: 0.95 214 | div_factor (float): Determines the initial learning rate via 215 | initial_lr = max_lr/div_factor 216 | Default: 25 217 | final_div_factor (float): Determines the minimum learning rate via 218 | min_lr = initial_lr/final_div_factor 219 | Default: 1e4 220 | last_epoch (int): The index of the last batch. This parameter is used when 221 | resuming a training job. Since `step()` should be invoked after each 222 | batch instead of after each epoch, this number represents the total 223 | number of *batches* computed, not the total number of epochs computed. 224 | When last_epoch=-1, the schedule is started from the beginning. 225 | Default: -1 226 | 227 | Example: 228 | >>> data_loader = torch.utils.data.DataLoader(...) 229 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 230 | >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10) 231 | >>> for epoch in range(10): 232 | >>> for batch in data_loader: 233 | >>> train_batch(...) 
234 | >>> scheduler.step() 235 | 236 | 237 | .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: 238 | https://arxiv.org/abs/1708.07120 239 | """ 240 | def __init__(self, 241 | optimizer, 242 | max_lr, 243 | total_steps=None, 244 | epochs=None, 245 | steps_per_epoch=None, 246 | pct_start=0.3, 247 | anneal_strategy='cos', 248 | cycle_momentum=True, 249 | base_momentum=0.85, 250 | max_momentum=0.95, 251 | div_factor=25., 252 | final_div_factor=1e4, 253 | last_epoch=-1): 254 | 255 | # Validate optimizer 256 | # if not isinstance(optimizer, Optimizer): 257 | # raise TypeError('{} is not an Optimizer'.format( 258 | # type(optimizer).__name__)) 259 | self.optimizer = optimizer 260 | 261 | # Validate total_steps 262 | if total_steps is None and epochs is None and steps_per_epoch is None: 263 | raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)") 264 | elif total_steps is not None: 265 | if total_steps <= 0 or not isinstance(total_steps, int): 266 | raise ValueError("Expected non-negative integer total_steps, but got {}".format(total_steps)) 267 | self.total_steps = total_steps 268 | else: 269 | if epochs <= 0 or not isinstance(epochs, int): 270 | raise ValueError("Expected non-negative integer epochs, but got {}".format(epochs)) 271 | if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int): 272 | raise ValueError("Expected non-negative integer steps_per_epoch, but got {}".format(steps_per_epoch)) 273 | self.total_steps = epochs * steps_per_epoch 274 | self.step_size_up = float(pct_start * self.total_steps) - 1 275 | self.step_size_down = float(self.total_steps - self.step_size_up) - 1 276 | 277 | # Validate pct_start 278 | if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): 279 | raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start)) 280 | 281 | # Validate anneal_strategy 282 | if anneal_strategy not in ['cos', 'linear']: 283 | raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy)) 284 | elif anneal_strategy == 'cos': 285 | self.anneal_func = self._annealing_cos 286 | elif anneal_strategy == 'linear': 287 | self.anneal_func = self._annealing_linear 288 | 289 | # Initialize learning rate variables 290 | max_lrs = self._format_param('max_lr', self.optimizer, max_lr) 291 | for idx, group in enumerate(self.optimizer.param_groups): 292 | group['initial_lr'] = max_lrs[idx] / div_factor 293 | group['max_lr'] = max_lrs[idx] 294 | group['min_lr'] = group['initial_lr'] / final_div_factor 295 | 296 | # Initialize momentum variables 297 | self.cycle_momentum = cycle_momentum 298 | if self.cycle_momentum: 299 | if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults: 300 | raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled') 301 | self.use_beta1 = 'betas' in self.optimizer.defaults 302 | max_momentums = self._format_param('max_momentum', optimizer, max_momentum) 303 | base_momentums = self._format_param('base_momentum', optimizer, base_momentum) 304 | if last_epoch == -1: 305 | for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups): 306 | if self.use_beta1: 307 | _, beta2 = group['betas'] 308 | group['betas'] = (m_momentum, beta2) 309 | else: 310 | group['momentum'] = m_momentum 311 | group['max_momentum'] = m_momentum 312 | group['base_momentum'] = b_momentum 313 | 314 | super(OneCycleLR, self).__init__(optimizer, 
last_epoch) 315 | 316 | def _format_param(self, name, optimizer, param): 317 | """Return correctly formatted lr/momentum for each param group.""" 318 | if isinstance(param, (list, tuple)): 319 | if len(param) != len(optimizer.param_groups): 320 | raise ValueError("expected {} values for {}, got {}".format( 321 | len(optimizer.param_groups), name, len(param))) 322 | return param 323 | else: 324 | return [param] * len(optimizer.param_groups) 325 | 326 | def _annealing_cos(self, start, end, pct): 327 | "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0." 328 | cos_out = math.cos(math.pi * pct) + 1 329 | return end + (start - end) / 2.0 * cos_out 330 | 331 | def _annealing_linear(self, start, end, pct): 332 | "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0." 333 | return (end - start) * pct + start 334 | 335 | def get_lr(self): 336 | if not self._get_lr_called_within_step: 337 | warnings.warn("To get the last learning rate computed by the scheduler, " 338 | "please use `get_last_lr()`.", UserWarning) 339 | 340 | lrs = [] 341 | step_num = self.last_epoch 342 | 343 | if step_num > self.total_steps: 344 | raise ValueError("Tried to step {} times. The specified number of total steps is {}" 345 | .format(step_num + 1, self.total_steps)) 346 | 347 | for group in self.optimizer.param_groups: 348 | if step_num <= self.step_size_up: 349 | computed_lr = self.anneal_func(group['initial_lr'], group['max_lr'], step_num / self.step_size_up) 350 | if self.cycle_momentum: 351 | computed_momentum = self.anneal_func(group['max_momentum'], group['base_momentum'], 352 | step_num / self.step_size_up) 353 | else: 354 | down_step_num = step_num - self.step_size_up 355 | computed_lr = self.anneal_func(group['max_lr'], group['min_lr'], down_step_num / self.step_size_down) 356 | if self.cycle_momentum: 357 | computed_momentum = self.anneal_func(group['base_momentum'], group['max_momentum'], 358 | down_step_num / self.step_size_down) 359 | 360 | lrs.append(computed_lr) 361 | if self.cycle_momentum: 362 | if self.use_beta1: 363 | _, beta2 = group['betas'] 364 | group['betas'] = (computed_momentum, beta2) 365 | else: 366 | group['momentum'] = computed_momentum 367 | 368 | return lrs 369 | -------------------------------------------------------------------------------- /core/onecyclelr.py.save: -------------------------------------------------------------------------------- 1 | import types 2 | import math 3 | from torch._six import inf 4 | from functools import wraps 5 | import warnings 6 | import weakref 7 | from collections import Counter 8 | from bisect import bisect_right 9 | 10 | EPOCH_DEPRECATION_WARNING = ( 11 | "The epoch parameter in `scheduler.step()` was not necessary and is being " 12 | "deprecated where possible. Please use `scheduler.step()` to step the " 13 | "scheduler. During the deprecation, if epoch is different from None, the " 14 | "closed form is used instead of the new chainable form, where available. " 15 | "Please open an issue if you are unable to replicate your use case: " 16 | "https://github.com/pytorch/pytorch/issues/new/choose." 
17 | ) 18 | 19 | class _LRScheduler(object): 20 | 21 | def __init__(self, optimizer, last_epoch=-1): 22 | 23 | # Attach optimizer 24 | # if not isinstance(optimizer, Optimizer): 25 | # raise TypeError('{} is not an Optimizer'.format( 26 | # type(optimizer).__name__)) 27 | self.optimizer = optimizer 28 | 29 | # Initialize epoch and base learning rates 30 | if last_epoch == -1: 31 | for group in optimizer.param_groups: 32 | group.setdefault('initial_lr', group['lr']) 33 | else: 34 | for i, group in enumerate(optimizer.param_groups): 35 | if 'initial_lr' not in group: 36 | raise KeyError("param 'initial_lr' is not specified " 37 | "in param_groups[{}] when resuming an optimizer".format(i)) 38 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 39 | self.last_epoch = last_epoch 40 | 41 | # Following https://github.com/pytorch/pytorch/issues/20124 42 | # We would like to ensure that `lr_scheduler.step()` is called after 43 | # `optimizer.step()` 44 | def with_counter(method): 45 | if getattr(method, '_with_counter', False): 46 | # `optimizer.step()` has already been replaced, return. 47 | return method 48 | 49 | # Keep a weak reference to the optimizer instance to prevent 50 | # cyclic references. 51 | instance_ref = weakref.ref(method.__self__) 52 | # Get the unbound method for the same purpose. 53 | func = method.__func__ 54 | cls = instance_ref().__class__ 55 | del method 56 | 57 | @wraps(func) 58 | def wrapper(*args, **kwargs): 59 | instance = instance_ref() 60 | instance._step_count += 1 61 | wrapped = func.__get__(instance, cls) 62 | return wrapped(*args, **kwargs) 63 | 64 | # Note that the returned function here is no longer a bound method, 65 | # so attributes like `__func__` and `__self__` no longer exist. 66 | wrapper._with_counter = True 67 | return wrapper 68 | 69 | self.optimizer.step = with_counter(self.optimizer.step) 70 | self.optimizer._step_count = 0 71 | self._step_count = 0 72 | 73 | self.step() 74 | 75 | def state_dict(self): 76 | """Returns the state of the scheduler as a :class:`dict`. 77 | 78 | It contains an entry for every variable in self.__dict__ which 79 | is not the optimizer. 80 | """ 81 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 82 | 83 | def load_state_dict(self, state_dict): 84 | """Loads the schedulers state. 85 | 86 | Arguments: 87 | state_dict (dict): scheduler state. Should be an object returned 88 | from a call to :meth:`state_dict`. 89 | """ 90 | self.__dict__.update(state_dict) 91 | 92 | def get_last_lr(self): 93 | """ Return last computed learning rate by current scheduler. 94 | """ 95 | return self._last_lr 96 | 97 | def get_lr(self): 98 | # Compute learning rate using chainable form of the scheduler 99 | raise NotImplementedError 100 | 101 | def step(self, epoch=None): 102 | # Raise a warning if old pattern is detected 103 | # https://github.com/pytorch/pytorch/issues/20124 104 | if self._step_count == 1: 105 | if not hasattr(self.optimizer.step, "_with_counter"): 106 | warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler " 107 | "initialization. Please, make sure to call `optimizer.step()` before " 108 | "`lr_scheduler.step()`. 
See more details at " 109 | "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) 110 | 111 | # Just check if there were two first lr_scheduler.step() calls before optimizer.step() 112 | elif self.optimizer._step_count < 1: 113 | warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. " 114 | "In PyTorch 1.1.0 and later, you should call them in the opposite order: " 115 | "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " 116 | "will result in PyTorch skipping the first value of the learning rate schedule. " 117 | "See more details at " 118 | "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) 119 | self._step_count += 1 120 | 121 | class _enable_get_lr_call: 122 | 123 | def __init__(self, o): 124 | self.o = o 125 | 126 | def __enter__(self): 127 | self.o._get_lr_called_within_step = True 128 | return self 129 | 130 | def __exit__(self, type, value, traceback): 131 | self.o._get_lr_called_within_step = False 132 | 133 | with _enable_get_lr_call(self): 134 | if epoch is None: 135 | self.last_epoch += 1 136 | values = self.get_lr() 137 | else: 138 | warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) 139 | self.last_epoch = epoch 140 | if hasattr(self, "_get_closed_form_lr"): 141 | values = self._get_closed_form_lr() 142 | else: 143 | values = self.get_lr() 144 | 145 | for param_group, lr in zip(self.optimizer.param_groups, values): 146 | param_group['lr'] = lr 147 | 148 | self._last_lr = [group['lr'] for group in self.optimizer.param_groups] 149 | 150 | class OneCycleLR(_LRScheduler): 151 | r"""Sets the learning rate of each parameter group according to the 152 | 1cycle learning rate policy. The 1cycle policy anneals the learning 153 | rate from an initial learning rate to some maximum learning rate and then 154 | from that maximum learning rate to some minimum learning rate much lower 155 | than the initial learning rate. 156 | This policy was initially described in the paper `Super-Convergence: 157 | Very Fast Training of Neural Networks Using Large Learning Rates`_. 158 | 159 | The 1cycle learning rate policy changes the learning rate after every batch. 160 | `step` should be called after a batch has been used for training. 161 | 162 | This scheduler is not chainable. 163 | 164 | Note also that the total number of steps in the cycle can be determined in one 165 | of two ways (listed in order of precedence): 166 | 167 | #. A value for total_steps is explicitly provided. 168 | #. A number of epochs (epochs) and a number of steps per epoch 169 | (steps_per_epoch) are provided. 170 | In this case, the number of total steps is inferred by 171 | total_steps = epochs * steps_per_epoch 172 | 173 | You must either provide a value for total_steps or provide a value for both 174 | epochs and steps_per_epoch. 175 | 176 | Args: 177 | optimizer (Optimizer): Wrapped optimizer. 178 | max_lr (float or list): Upper learning rate boundaries in the cycle 179 | for each parameter group. 180 | total_steps (int): The total number of steps in the cycle. Note that 181 | if a value is not provided here, then it must be inferred by providing 182 | a value for epochs and steps_per_epoch. 183 | Default: None 184 | epochs (int): The number of epochs to train for. This is used along 185 | with steps_per_epoch in order to infer the total number of steps in the cycle 186 | if a value for total_steps is not provided. 187 | Default: None 188 | steps_per_epoch (int): The number of steps per epoch to train for. 
This is 189 | used along with epochs in order to infer the total number of steps in the 190 | cycle if a value for total_steps is not provided. 191 | Default: None 192 | pct_start (float): The percentage of the cycle (in number of steps) spent 193 | increasing the learning rate. 194 | Default: 0.3 195 | anneal_strategy (str): {'cos', 'linear'} 196 | Specifies the annealing strategy: "cos" for cosine annealing, "linear" for 197 | linear annealing. 198 | Default: 'cos' 199 | cycle_momentum (bool): If ``True``, momentum is cycled inversely 200 | to learning rate between 'base_momentum' and 'max_momentum'. 201 | Default: True 202 | base_momentum (float or list): Lower momentum boundaries in the cycle 203 | for each parameter group. Note that momentum is cycled inversely 204 | to learning rate; at the peak of a cycle, momentum is 205 | 'base_momentum' and learning rate is 'max_lr'. 206 | Default: 0.85 207 | max_momentum (float or list): Upper momentum boundaries in the cycle 208 | for each parameter group. Functionally, 209 | it defines the cycle amplitude (max_momentum - base_momentum). 210 | Note that momentum is cycled inversely 211 | to learning rate; at the start of a cycle, momentum is 'max_momentum' 212 | and learning rate is 'base_lr' 213 | Default: 0.95 214 | div_factor (float): Determines the initial learning rate via 215 | initial_lr = max_lr/div_factor 216 | Default: 25 217 | final_div_factor (float): Determines the minimum learning rate via 218 | min_lr = initial_lr/final_div_factor 219 | Default: 1e4 220 | last_epoch (int): The index of the last batch. This parameter is used when 221 | resuming a training job. Since `step()` should be invoked after each 222 | batch instead of after each epoch, this number represents the total 223 | number of *batches* computed, not the total number of epochs computed. 224 | When last_epoch=-1, the schedule is started from the beginning. 225 | Default: -1 226 | 227 | Example: 228 | >>> data_loader = torch.utils.data.DataLoader(...) 229 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 230 | >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10) 231 | >>> for epoch in range(10): 232 | >>> for batch in data_loader: 233 | >>> train_batch(...) 234 | >>> scheduler.step() 235 | 236 | 237 | .. 
_Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: 238 | https://arxiv.org/abs/1708.07120 239 | """ 240 | def __init__(self, 241 | optimizer, 242 | max_lr, 243 | total_steps=None, 244 | epochs=None, 245 | steps_per_epoch=None, 246 | pct_start=0.3, 247 | anneal_strategy='cos', 248 | cycle_momentum=True, 249 | base_momentum=0.85, 250 | max_momentum=0.95, 251 | div_factor=25., 252 | final_div_factor=1e4, 253 | last_epoch=-1): 254 | 255 | # Validate optimizer 256 | # if not isinstance(optimizer, Optimizer): 257 | # raise TypeError('{} is not an Optimizer'.format( 258 | # type(optimizer).__name__)) 259 | self.optimizer = optimizer 260 | 261 | # Validate total_steps 262 | if total_steps is None and epochs is None and steps_per_epoch is None: 263 | raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)") 264 | elif total_steps is not None: 265 | if total_steps <= 0 or not isinstance(total_steps, int): 266 | raise ValueError("Expected non-negative integer total_steps, but got {}".format(total_steps)) 267 | self.total_steps = total_steps 268 | else: 269 | if epochs <= 0 or not isinstance(epochs, int): 270 | raise ValueError("Expected non-negative integer epochs, but got {}".format(epochs)) 271 | if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int): 272 | raise ValueError("Expected non-negative integer steps_per_epoch, but got {}".format(steps_per_epoch)) 273 | self.total_steps = epochs * steps_per_epoch 274 | self.step_size_up = float(pct_start * self.total_steps) - 1 275 | self.step_size_down = float(self.total_steps - self.step_size_up) - 1 276 | 277 | # Validate pct_start 278 | if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): 279 | raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start)) 280 | 281 | # Validate anneal_strategy 282 | if anneal_strategy not in ['cos', 'linear']: 283 | raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy)) 284 | elif anneal_strategy == 'cos': 285 | self.anneal_func = self._annealing_cos 286 | elif anneal_strategy == 'linear': 287 | self.anneal_func = self._annealing_linear 288 | 289 | # Initialize learning rate variables 290 | print('groups before: ', self.optimizer.param_groups[0].keys()) 291 | 292 | max_lrs = self._format_param('max_lr', self.optimizer, max_lr) 293 | if last_epoch == -1: 294 | for idx, group in enumerate(self.optimizer.param_groups): 295 | self.optimizer.param_groups[idx]['initial_lr'] = max_lrs[idx] / div_factor 296 | self.optimizer.param_groups[idx]['max_lr'] = max_lrs[idx] 297 | self.optimizer.param_groups[idx]['min_lr'] = group['initial_lr'] / final_div_factor 298 | 299 | print('groups: ', self.optimizer.param_groups[0].keys()) 300 | exit() 301 | 302 | # Initialize momentum variables 303 | self.cycle_momentum = cycle_momentum 304 | if self.cycle_momentum: 305 | if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults: 306 | raise Valuerror('optimizer must support momentum with `cycle_momentum` option enabled') 307 | self.use_beta1 = 'betas' in self.optimizer.defaults 308 | max_momentums = self._format_param('max_momentum', optimizer, max_momentum) 309 | base_momentums = self._format_param('base_momentum', optimizer, base_momentum) 310 | if last_epoch == -1: 311 | for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups): 312 | if self.use_beta1: 313 | _, beta2 = group['betas'] 314 | 
group['betas'] = (m_momentum, beta2) 315 | else: 316 | group['momentum'] = m_momentum 317 | group['max_momentum'] = m_momentum 318 | group['base_momentum'] = b_momentum 319 | 320 | super(OneCycleLR, self).__init__(optimizer, last_epoch) 321 | 322 | def _format_param(self, name, optimizer, param): 323 | """Return correctly formatted lr/momentum for each param group.""" 324 | if isinstance(param, (list, tuple)): 325 | if len(param) != len(optimizer.param_groups): 326 | raise ValueError("expected {} values for {}, got {}".format( 327 | len(optimizer.param_groups), name, len(param))) 328 | return param 329 | else: 330 | return [param] * len(optimizer.param_groups) 331 | 332 | def _annealing_cos(self, start, end, pct): 333 | "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0." 334 | cos_out = math.cos(math.pi * pct) + 1 335 | return end + (start - end) / 2.0 * cos_out 336 | 337 | def _annealing_linear(self, start, end, pct): 338 | "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0." 339 | return (end - start) * pct + start 340 | 341 | def get_lr(self): 342 | if not self._get_lr_called_within_step: 343 | warnings.warn("To get the last learning rate computed by the scheduler, " 344 | "please use `get_last_lr()`.", UserWarning) 345 | 346 | lrs = [] 347 | step_num = self.last_epoch 348 | 349 | if step_num > self.total_steps: 350 | raise ValueError("Tried to step {} times. The specified number of total steps is {}" 351 | .format(step_num + 1, self.total_steps)) 352 | 353 | for group in self.optimizer.param_groups: 354 | if step_num <= self.step_size_up: 355 | computed_lr = self.anneal_func(group['initial_lr'], group['max_lr'], step_num / self.step_size_up) 356 | if self.cycle_momentum: 357 | computed_momentum = self.anneal_func(group['max_momentum'], group['base_momentum'], 358 | step_num / self.step_size_up) 359 | else: 360 | down_step_num = step_num - self.step_size_up 361 | computed_lr = self.anneal_func(group['max_lr'], group['min_lr'], down_step_num / self.step_size_down) 362 | if self.cycle_momentum: 363 | computed_momentum = self.anneal_func(group['base_momentum'], group['max_momentum'], 364 | down_step_num / self.step_size_down) 365 | 366 | lrs.append(computed_lr) 367 | if self.cycle_momentum: 368 | if self.use_beta1: 369 | _, beta2 = group['betas'] 370 | group['betas'] = (computed_momentum, beta2) 371 | else: 372 | group['momentum'] = computed_momentum 373 | 374 | return lrs 375 | -------------------------------------------------------------------------------- /core/raft_gma_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from update import GMAUpdateBlock 6 | from extractor import BasicEncoder 7 | from corr import CorrBlock 8 | from utils.utils import bilinear_sampler, coords_grid, upflow8 9 | from gma import Attention 10 | 11 | from swin_basic import BasicSwinUpdate 12 | 13 | try: 14 | autocast = torch.cuda.amp.autocast 15 | except: 16 | # dummy autocast for PyTorch < 1.6 17 | class autocast: 18 | def __init__(self, enabled): 19 | pass 20 | 21 | def __enter__(self): 22 | pass 23 | 24 | def __exit__(self, *args): 25 | pass 26 | 27 | 28 | class RAFTGMAModel(nn.Module): 29 | def __init__(self, args): 30 | super().__init__() 31 | self.args = args 32 | self.args.num_heads = 1 33 | self.args.position_only = False 34 | self.args.position_and_content = False 35 | 36 | self.hidden_dim = hdim = 128 37 | self.context_dim = cdim = 128 38 | 
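The hidden and context widths set here are consumed in `forward()` below, where the context encoder's hdim+cdim output channels are split into a tanh-bounded recurrent hidden state and a ReLU context feature. A hedged toy sketch of that split (sizes assumed), for illustration only:

import torch

# Split a (hdim + cdim)-channel context feature the same way forward() does below (toy sizes).
hdim, cdim = 128, 128
cnet = torch.randn(1, hdim + cdim, 46, 62)        # e.g. a 368x496 input at 1/8 resolution
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
net = torch.tanh(net)                             # GRU hidden state, updated every iteration
inp = torch.relu(inp)                             # static context injected at every iteration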
args.corr_levels = 4 39 | args.corr_radius = 4 40 | 41 | if 'dropout' not in self.args: 42 | self.args.dropout = 0 43 | 44 | # feature network, context network, and update block 45 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 46 | self.cnet = BasicEncoder(output_dim=hdim + cdim, norm_fn='batch', dropout=args.dropout) 47 | self.update_block = GMAUpdateBlock(self.args, hidden_dim=hdim) 48 | self.att = Attention(args=self.args, dim=cdim, heads=self.args.num_heads, max_pos_size=160, dim_head=cdim) 49 | 50 | def freeze_bn(self): 51 | for m in self.modules(): 52 | if isinstance(m, nn.BatchNorm2d): 53 | m.eval() 54 | 55 | def initialize_flow(self, img): 56 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 57 | N, C, H, W = img.shape 58 | coords0 = coords_grid(N, H // 8, W // 8).to(img.device) 59 | coords1 = coords_grid(N, H // 8, W // 8).to(img.device) 60 | 61 | # optical flow computed as difference: flow = coords1 - coords0 62 | return coords0, coords1 63 | 64 | def upsample_flow(self, flow, mask): 65 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 66 | N, _, H, W = flow.shape 67 | mask = mask.view(N, 1, 9, 8, 8, H, W) 68 | mask = torch.softmax(mask, dim=2) 69 | 70 | up_flow = F.unfold(8 * flow, [3, 3], padding=1) 71 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 72 | 73 | up_flow = torch.sum(mask * up_flow, dim=2) 74 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 75 | return up_flow.reshape(N, 2, 8 * H, 8 * W) 76 | 77 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): 78 | """ Estimate optical flow between pair of frames """ 79 | 80 | image1 = 2 * (image1 / 255.0) - 1.0 81 | image2 = 2 * (image2 / 255.0) - 1.0 82 | 83 | image1 = image1.contiguous() 84 | image2 = image2.contiguous() 85 | 86 | hdim = self.hidden_dim 87 | cdim = self.context_dim 88 | 89 | # run the feature network 90 | with autocast(enabled=self.args.mixed_precision): 91 | fmap1, fmap2 = self.fnet([image1, image2]) 92 | 93 | fmap1 = fmap1.float() 94 | fmap2 = fmap2.float() 95 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 96 | 97 | # run the context network 98 | with autocast(enabled=self.args.mixed_precision): 99 | cnet = self.cnet(image1) 100 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 101 | net = torch.tanh(net) 102 | inp = torch.relu(inp) 103 | # attention, att_c, att_p = self.att(inp) 104 | attention = self.att(inp) 105 | 106 | coords0, coords1 = self.initialize_flow(image1) 107 | 108 | if flow_init is not None: 109 | coords1 = coords1 + flow_init 110 | 111 | flow_predictions = [] 112 | for itr in range(iters): 113 | coords1 = coords1.detach() 114 | corr = corr_fn(coords1) # index correlation volume 115 | 116 | flow = coords1 - coords0 117 | with autocast(enabled=self.args.mixed_precision): 118 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow, attention) 119 | 120 | # F(t+1) = F(t) + \Delta(t) 121 | coords1 = coords1 + delta_flow 122 | 123 | # upsample predictions 124 | if up_mask is None: 125 | flow_up = upflow8(coords1 - coords0) 126 | else: 127 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 128 | 129 | flow_predictions.append(flow_up) 130 | 131 | if test_mode: 132 | return coords1 - coords0, flow_up 133 | 134 | return flow_predictions, None 135 | 136 | 137 | 138 | if __name__ == "__main__": 139 | att = Attention(dim=128, heads=1) 140 | fmap = torch.randn(2, 128, 40, 90) 141 | out = att(fmap) 142 | 143 | 
print(out.shape) 144 | -------------------------------------------------------------------------------- /core/raft_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from update import BasicUpdateBlock, SmallUpdateBlock 8 | from extractor import BasicEncoder, SmallEncoder 9 | from corr import CorrBlock, AlternateCorrBlock 10 | from utils.utils import bilinear_sampler, coords_grid, upflow8 11 | 12 | try: 13 | autocast = torch.cuda.amp.autocast 14 | except: 15 | # dummy autocast for PyTorch < 1.6 16 | class autocast: 17 | def __init__(self, enabled): 18 | pass 19 | def __enter__(self): 20 | pass 21 | def __exit__(self, *args): 22 | pass 23 | 24 | 25 | class RAFTModel(nn.Module): 26 | def __init__(self, args): 27 | super(RAFTModel, self).__init__() 28 | self.args = args 29 | 30 | if args.small: 31 | self.hidden_dim = hdim = 96 32 | self.context_dim = cdim = 64 33 | args.corr_levels = 4 34 | args.corr_radius = 3 35 | 36 | else: 37 | self.hidden_dim = hdim = 128 38 | self.context_dim = cdim = 128 39 | args.corr_levels = 4 40 | args.corr_radius = 4 41 | 42 | if not hasattr(self.args, 'dropout'): 43 | self.args.dropout = 0 44 | 45 | if not hasattr(self.args, 'alternate_corr'): 46 | self.args.alternate_corr = False 47 | 48 | # feature network, context network, and update block 49 | if args.small: 50 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) 51 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) 52 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 53 | 54 | else: 55 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 56 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) 57 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 58 | 59 | def freeze_bn(self): 60 | for m in self.modules(): 61 | if isinstance(m, nn.BatchNorm2d): 62 | m.eval() 63 | 64 | def initialize_flow(self, img): 65 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 66 | N, C, H, W = img.shape 67 | coords0 = coords_grid(N, H//8, W//8).to(img.device) 68 | coords1 = coords_grid(N, H//8, W//8).to(img.device) 69 | 70 | # optical flow computed as difference: flow = coords1 - coords0 71 | return coords0, coords1 72 | 73 | def upsample_flow(self, flow, mask): 74 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 75 | N, _, H, W = flow.shape 76 | mask = mask.view(N, 1, 9, 8, 8, H, W) 77 | mask = torch.softmax(mask, dim=2) 78 | 79 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 80 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 81 | 82 | up_flow = torch.sum(mask * up_flow, dim=2) 83 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 84 | return up_flow.reshape(N, 2, 8*H, 8*W) 85 | 86 | 87 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): 88 | """ Estimate optical flow between pair of frames """ 89 | 90 | image1 = 2 * (image1 / 255.0) - 1.0 91 | image2 = 2 * (image2 / 255.0) - 1.0 92 | 93 | image1 = image1.contiguous() 94 | image2 = image2.contiguous() 95 | 96 | hdim = self.hidden_dim 97 | cdim = self.context_dim 98 | 99 | # run the feature network 100 | with autocast(enabled=self.args.mixed_precision): 101 | fmap1, fmap2 = self.fnet([image1, image2]) 102 | 103 | fmap1 = fmap1.float() 104 | 
fmap2 = fmap2.float() 105 | if self.args.alternate_corr: 106 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 107 | else: 108 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 109 | 110 | # run the context network 111 | with autocast(enabled=self.args.mixed_precision): 112 | cnet = self.cnet(image1) 113 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 114 | net = torch.tanh(net) 115 | inp = torch.relu(inp) 116 | 117 | coords0, coords1 = self.initialize_flow(image1) 118 | 119 | if flow_init is not None: 120 | coords1 = coords1 + flow_init 121 | 122 | flow_predictions = [] 123 | for itr in range(iters): 124 | coords1 = coords1.detach() 125 | corr = corr_fn(coords1) # index correlation volume 126 | 127 | flow = coords1 - coords0 128 | with autocast(enabled=self.args.mixed_precision): 129 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 130 | 131 | # F(t+1) = F(t) + \Delta(t) 132 | coords1 = coords1 + delta_flow 133 | 134 | # upsample predictions 135 | if up_mask is None: 136 | flow_up = upflow8(coords1 - coords0) 137 | else: 138 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 139 | 140 | flow_predictions.append(flow_up) 141 | 142 | if test_mode: 143 | return coords1 - coords0, flow_up 144 | 145 | return flow_predictions, fmap1, fmap2 146 | 147 | -------------------------------------------------------------------------------- /core/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from gma import Aggregate 6 | 7 | 8 | class FlowHead(nn.Module): 9 | def __init__(self, input_dim=128, hidden_dim=256): 10 | super(FlowHead, self).__init__() 11 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 12 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 13 | self.relu = nn.ReLU(inplace=True) 14 | 15 | def forward(self, x): 16 | return self.conv2(self.relu(self.conv1(x))) 17 | 18 | 19 | class ConvGRU(nn.Module): 20 | def __init__(self, hidden_dim=128, input_dim=192+128): 21 | super(ConvGRU, self).__init__() 22 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 23 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 24 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 25 | 26 | def forward(self, h, x): 27 | hx = torch.cat([h, x], dim=1) 28 | 29 | z = torch.sigmoid(self.convz(hx)) 30 | r = torch.sigmoid(self.convr(hx)) 31 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 32 | 33 | h = (1-z) * h + z * q 34 | return h 35 | 36 | 37 | class SepConvGRU(nn.Module): 38 | def __init__(self, hidden_dim=128, input_dim=192+128): 39 | super(SepConvGRU, self).__init__() 40 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 41 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 42 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 43 | 44 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 45 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 46 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 47 | 48 | 49 | def forward(self, h, x): 50 | # horizontal 51 | hx = torch.cat([h, x], dim=1) 52 | z = torch.sigmoid(self.convz1(hx)) 53 | r = torch.sigmoid(self.convr1(hx)) 54 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 55 | h = (1-z) * h + z * q 56 | 
57 | # vertical 58 | hx = torch.cat([h, x], dim=1) 59 | z = torch.sigmoid(self.convz2(hx)) 60 | r = torch.sigmoid(self.convr2(hx)) 61 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 62 | h = (1-z) * h + z * q 63 | 64 | return h 65 | 66 | 67 | class SmallMotionEncoder(nn.Module): 68 | def __init__(self, args): 69 | super(SmallMotionEncoder, self).__init__() 70 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 71 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 72 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 73 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 74 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 75 | 76 | def forward(self, flow, corr): 77 | cor = F.relu(self.convc1(corr)) 78 | flo = F.relu(self.convf1(flow)) 79 | flo = F.relu(self.convf2(flo)) 80 | cor_flo = torch.cat([cor, flo], dim=1) 81 | out = F.relu(self.conv(cor_flo)) 82 | return torch.cat([out, flow], dim=1) 83 | 84 | 85 | class BasicMotionEncoder(nn.Module): 86 | def __init__(self, args): 87 | super(BasicMotionEncoder, self).__init__() 88 | if hasattr(args, 'motion_feat_indim'): 89 | cor_planes = args.motion_feat_indim 90 | else: 91 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 92 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 93 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 94 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 95 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 96 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 97 | 98 | def forward(self, flow, corr): 99 | cor = F.relu(self.convc1(corr)) 100 | cor = F.relu(self.convc2(cor)) 101 | flo = F.relu(self.convf1(flow)) 102 | flo = F.relu(self.convf2(flo)) 103 | 104 | cor_flo = torch.cat([cor, flo], dim=1) 105 | out = F.relu(self.conv(cor_flo)) 106 | return torch.cat([out, flow], dim=1) 107 | 108 | 109 | class SmallUpdateBlock(nn.Module): 110 | def __init__(self, args, hidden_dim=96): 111 | super(SmallUpdateBlock, self).__init__() 112 | self.encoder = SmallMotionEncoder(args) 113 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 114 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 115 | 116 | def forward(self, net, inp, corr, flow): 117 | motion_features = self.encoder(flow, corr) 118 | inp = torch.cat([inp, motion_features], dim=1) 119 | net = self.gru(net, inp) 120 | delta_flow = self.flow_head(net) 121 | 122 | return net, None, delta_flow 123 | 124 | 125 | class BasicUpdateBlock(nn.Module): 126 | def __init__(self, args, hidden_dim=128, input_dim=128): 127 | super(BasicUpdateBlock, self).__init__() 128 | self.args = args 129 | self.encoder = BasicMotionEncoder(args) 130 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+input_dim) 131 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 132 | 133 | hidden_dim2 = (hidden_dim // 256)*128 + 256 134 | 135 | self.mask = nn.Sequential( 136 | nn.Conv2d(hidden_dim, hidden_dim2, 3, padding=1), 137 | nn.ReLU(inplace=True), 138 | nn.Conv2d(hidden_dim2, 64*9, 1, padding=0)) 139 | 140 | def forward(self, net, inp, corr, flow, upsample=True): 141 | motion_features = self.encoder(flow, corr) 142 | inp = torch.cat([inp, motion_features], dim=1) 143 | 144 | net = self.gru(net, inp) 145 | delta_flow = self.flow_head(net) 146 | 147 | # scale mask to balence gradients 148 | mask = .25 * self.mask(net) 149 | return net, mask, delta_flow 150 | 151 | 152 | class GMAUpdateBlock(nn.Module): 153 | def __init__(self, args, hidden_dim=128): 154 | super().__init__() 155 | self.args = args 156 | self.encoder = BasicMotionEncoder(args) 157 | 
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim+hidden_dim) 158 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 159 | 160 | self.mask = nn.Sequential( 161 | nn.Conv2d(128, 256, 3, padding=1), 162 | nn.ReLU(inplace=True), 163 | nn.Conv2d(256, 64*9, 1, padding=0)) 164 | 165 | self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=self.args.num_heads) 166 | 167 | def forward(self, net, inp, corr, flow, attention): 168 | motion_features = self.encoder(flow, corr) 169 | motion_features_global = self.aggregator(attention, motion_features) 170 | inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) 171 | 172 | # Attentional update 173 | net = self.gru(net, inp_cat) 174 | 175 | delta_flow = self.flow_head(net) 176 | 177 | # scale mask to balence gradients 178 | mask = .25 * self.mask(net) 179 | return net, mask, delta_flow 180 | -------------------------------------------------------------------------------- /core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # functions from timm 2 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path 3 | from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible 4 | from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_ -------------------------------------------------------------------------------- /core/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import math 4 | from PIL import Image 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | import torch 11 | from torchvision.transforms import ColorJitter 12 | import torch.nn.functional as F 13 | 14 | 15 | class FlowAugmentor: 16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): 17 | 18 | # spatial augmentation params 19 | self.crop_size = crop_size 20 | self.min_scale = min_scale 21 | self.max_scale = max_scale 22 | self.spatial_aug_prob = 0.8 23 | self.stretch_prob = 0.8 24 | self.max_stretch = 0.2 25 | 26 | # flip augmentation params 27 | self.do_flip = do_flip 28 | self.h_flip_prob = 0.5 29 | self.v_flip_prob = 0.1 30 | 31 | # photometric augmentation params 32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) 33 | self.asymmetric_color_aug_prob = 0.2 34 | self.eraser_aug_prob = 0.5 35 | 36 | def color_transform(self, img1, img2): 37 | """ Photometric augmentation """ 38 | 39 | # asymmetric 40 | if np.random.rand() < self.asymmetric_color_aug_prob: 41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 43 | 44 | # symmetric 45 | else: 46 | image_stack = np.concatenate([img1, img2], axis=0) 47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 48 | img1, img2 = np.split(image_stack, 2, axis=0) 49 | 50 | return img1, img2 51 | 52 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 53 | """ Occlusion augmentation """ 54 | 55 | ht, wd = img1.shape[:2] 56 | if np.random.rand() < self.eraser_aug_prob: 57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 58 | for _ in range(np.random.randint(1, 3)): 59 | x0 = np.random.randint(0, wd) 60 | y0 = np.random.randint(0, ht) 61 | dx = np.random.randint(bounds[0], bounds[1]) 62 | dy = np.random.randint(bounds[0], bounds[1]) 63 | img2[y0:y0+dy, x0:x0+dx, :] = 
mean_color 64 | 65 | return img1, img2 66 | 67 | def spatial_transform(self, img1, img2, flow): 68 | # randomly sample scale 69 | ht, wd = img1.shape[:2] 70 | min_scale = np.maximum( 71 | (self.crop_size[0] + 8) / float(ht), 72 | (self.crop_size[1] + 8) / float(wd)) 73 | 74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 75 | scale_x = scale 76 | scale_y = scale 77 | if np.random.rand() < self.stretch_prob: 78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 80 | 81 | scale_x = np.clip(scale_x, min_scale, None) 82 | scale_y = np.clip(scale_y, min_scale, None) 83 | 84 | if np.random.rand() < self.spatial_aug_prob: 85 | # rescale the images 86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 89 | flow = flow * [scale_x, scale_y] 90 | 91 | if self.do_flip: 92 | if np.random.rand() < self.h_flip_prob: # h-flip 93 | img1 = img1[:, ::-1] 94 | img2 = img2[:, ::-1] 95 | flow = flow[:, ::-1] * [-1.0, 1.0] 96 | 97 | if np.random.rand() < self.v_flip_prob: # v-flip 98 | img1 = img1[::-1, :] 99 | img2 = img2[::-1, :] 100 | flow = flow[::-1, :] * [1.0, -1.0] 101 | 102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 104 | 105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 106 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 108 | 109 | return img1, img2, flow 110 | 111 | def __call__(self, img1, img2, flow): 112 | img1, img2 = self.color_transform(img1, img2) 113 | img1, img2 = self.eraser_transform(img1, img2) 114 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 115 | 116 | img1 = np.ascontiguousarray(img1) 117 | img2 = np.ascontiguousarray(img2) 118 | flow = np.ascontiguousarray(flow) 119 | 120 | return img1, img2, flow 121 | 122 | class SparseFlowAugmentor: 123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): 124 | # spatial augmentation params 125 | self.crop_size = crop_size 126 | self.min_scale = min_scale 127 | self.max_scale = max_scale 128 | self.spatial_aug_prob = 0.8 129 | self.stretch_prob = 0.8 130 | self.max_stretch = 0.2 131 | 132 | # flip augmentation params 133 | self.do_flip = do_flip 134 | self.h_flip_prob = 0.5 135 | self.v_flip_prob = 0.1 136 | 137 | # photometric augmentation params 138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) 139 | self.asymmetric_color_aug_prob = 0.2 140 | self.eraser_aug_prob = 0.5 141 | 142 | def color_transform(self, img1, img2): 143 | image_stack = np.concatenate([img1, img2], axis=0) 144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 145 | img1, img2 = np.split(image_stack, 2, axis=0) 146 | return img1, img2 147 | 148 | def eraser_transform(self, img1, img2): 149 | ht, wd = img1.shape[:2] 150 | if np.random.rand() < self.eraser_aug_prob: 151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 152 | for _ in range(np.random.randint(1, 3)): 153 | x0 = np.random.randint(0, wd) 154 | y0 = np.random.randint(0, ht) 155 | dx = np.random.randint(50, 100) 156 | dy = np.random.randint(50, 100) 157 | img2[y0:y0+dy, x0:x0+dx, :] 
= mean_color 158 | 159 | return img1, img2 160 | 161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 162 | ht, wd = flow.shape[:2] 163 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 164 | coords = np.stack(coords, axis=-1) 165 | 166 | coords = coords.reshape(-1, 2).astype(np.float32) 167 | flow = flow.reshape(-1, 2).astype(np.float32) 168 | valid = valid.reshape(-1).astype(np.float32) 169 | 170 | coords0 = coords[valid>=1] 171 | flow0 = flow[valid>=1] 172 | 173 | ht1 = int(round(ht * fy)) 174 | wd1 = int(round(wd * fx)) 175 | 176 | coords1 = coords0 * [fx, fy] 177 | flow1 = flow0 * [fx, fy] 178 | 179 | xx = np.round(coords1[:,0]).astype(np.int32) 180 | yy = np.round(coords1[:,1]).astype(np.int32) 181 | 182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 183 | xx = xx[v] 184 | yy = yy[v] 185 | flow1 = flow1[v] 186 | 187 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 189 | 190 | flow_img[yy, xx] = flow1 191 | valid_img[yy, xx] = 1 192 | 193 | return flow_img, valid_img 194 | 195 | def spatial_transform(self, img1, img2, flow, valid): 196 | # randomly sample scale 197 | 198 | ht, wd = img1.shape[:2] 199 | min_scale = np.maximum( 200 | (self.crop_size[0] + 1) / float(ht), 201 | (self.crop_size[1] + 1) / float(wd)) 202 | 203 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 204 | scale_x = np.clip(scale, min_scale, None) 205 | scale_y = np.clip(scale, min_scale, None) 206 | 207 | if np.random.rand() < self.spatial_aug_prob: 208 | # rescale the images 209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 212 | 213 | if self.do_flip: 214 | if np.random.rand() < 0.5: # h-flip 215 | img1 = img1[:, ::-1] 216 | img2 = img2[:, ::-1] 217 | flow = flow[:, ::-1] * [-1.0, 1.0] 218 | valid = valid[:, ::-1] 219 | 220 | margin_y = 20 221 | margin_x = 50 222 | 223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 225 | 226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 228 | 229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 233 | return img1, img2, flow, valid 234 | 235 | 236 | def __call__(self, img1, img2, flow, valid): 237 | img1, img2 = self.color_transform(img1, img2) 238 | img1, img2 = self.eraser_transform(img1, img2) 239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 240 | 241 | img1 = np.ascontiguousarray(img1) 242 | img2 = np.ascontiguousarray(img2) 243 | flow = np.ascontiguousarray(flow) 244 | valid = np.ascontiguousarray(valid) 245 | 246 | return img1, img2, flow, valid 247 | -------------------------------------------------------------------------------- /core/utils/drop.py: -------------------------------------------------------------------------------- 1 | """ DropBlock, DropPath 2 | 3 | PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. 
4 | 5 | Papers: 6 | DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) 7 | 8 | Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) 9 | 10 | Code: 11 | DropBlock impl inspired by two Tensorflow impl that I liked: 12 | - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 13 | - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py 14 | 15 | Hacked together by / Copyright 2020 Ross Wightman 16 | """ 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | def drop_block_2d( 23 | x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, 24 | with_noise: bool = False, inplace: bool = False, batchwise: bool = False): 25 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf 26 | 27 | DropBlock with an experimental gaussian noise option. This layer has been tested on a few training 28 | runs with success, but needs further validation and possibly optimization for lower runtime impact. 29 | """ 30 | B, C, H, W = x.shape 31 | total_size = W * H 32 | clipped_block_size = min(block_size, min(W, H)) 33 | # seed_drop_rate, the gamma parameter 34 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( 35 | (W - block_size + 1) * (H - block_size + 1)) 36 | 37 | # Forces the block to be inside the feature map. 38 | w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) 39 | valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ 40 | ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) 41 | valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) 42 | 43 | if batchwise: 44 | # one mask for whole batch, quite a bit faster 45 | uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) 46 | else: 47 | uniform_noise = torch.rand_like(x) 48 | block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) 49 | block_mask = -F.max_pool2d( 50 | -block_mask, 51 | kernel_size=clipped_block_size, # block_size, 52 | stride=1, 53 | padding=clipped_block_size // 2) 54 | 55 | if with_noise: 56 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) 57 | if inplace: 58 | x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) 59 | else: 60 | x = x * block_mask + normal_noise * (1 - block_mask) 61 | else: 62 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype) 63 | if inplace: 64 | x.mul_(block_mask * normalize_scale) 65 | else: 66 | x = x * block_mask * normalize_scale 67 | return x 68 | 69 | 70 | def drop_block_fast_2d( 71 | x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7, 72 | gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False): 73 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf 74 | 75 | DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid 76 | block mask at edges. 
77 | """ 78 | B, C, H, W = x.shape 79 | total_size = W * H 80 | clipped_block_size = min(block_size, min(W, H)) 81 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( 82 | (W - block_size + 1) * (H - block_size + 1)) 83 | 84 | block_mask = torch.empty_like(x).bernoulli_(gamma) 85 | block_mask = F.max_pool2d( 86 | block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) 87 | 88 | if with_noise: 89 | normal_noise = torch.empty_like(x).normal_() 90 | if inplace: 91 | x.mul_(1. - block_mask).add_(normal_noise * block_mask) 92 | else: 93 | x = x * (1. - block_mask) + normal_noise * block_mask 94 | else: 95 | block_mask = 1 - block_mask 96 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype) 97 | if inplace: 98 | x.mul_(block_mask * normalize_scale) 99 | else: 100 | x = x * block_mask * normalize_scale 101 | return x 102 | 103 | 104 | class DropBlock2d(nn.Module): 105 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf 106 | """ 107 | 108 | def __init__( 109 | self, 110 | drop_prob: float = 0.1, 111 | block_size: int = 7, 112 | gamma_scale: float = 1.0, 113 | with_noise: bool = False, 114 | inplace: bool = False, 115 | batchwise: bool = False, 116 | fast: bool = True): 117 | super(DropBlock2d, self).__init__() 118 | self.drop_prob = drop_prob 119 | self.gamma_scale = gamma_scale 120 | self.block_size = block_size 121 | self.with_noise = with_noise 122 | self.inplace = inplace 123 | self.batchwise = batchwise 124 | self.fast = fast # FIXME finish comparisons of fast vs not 125 | 126 | def forward(self, x): 127 | if not self.training or not self.drop_prob: 128 | return x 129 | if self.fast: 130 | return drop_block_fast_2d( 131 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace) 132 | else: 133 | return drop_block_2d( 134 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) 135 | 136 | 137 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): 138 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 139 | 140 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 141 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 142 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 143 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 144 | 'survival rate' as the argument. 145 | 146 | """ 147 | if drop_prob == 0. or not training: 148 | return x 149 | keep_prob = 1 - drop_prob 150 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 151 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 152 | if keep_prob > 0.0 and scale_by_keep: 153 | random_tensor.div_(keep_prob) 154 | return x * random_tensor 155 | 156 | 157 | class DropPath(nn.Module): 158 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
159 | """ 160 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): 161 | super(DropPath, self).__init__() 162 | self.drop_prob = drop_prob 163 | self.scale_by_keep = scale_by_keep 164 | 165 | def forward(self, x): 166 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) -------------------------------------------------------------------------------- /core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. 
Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /core/utils/helpers.py: 
-------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from itertools import repeat 6 | import collections.abc 7 | 8 | 9 | # From PyTorch internals 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, collections.abc.Iterable): 13 | return x 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | 18 | to_1tuple = _ntuple(1) 19 | to_2tuple = _ntuple(2) 20 | to_3tuple = _ntuple(3) 21 | to_4tuple = _ntuple(4) 22 | to_ntuple = _ntuple 23 | 24 | 25 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9): 26 | min_value = min_value or divisor 27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 28 | # Make sure that round down does not go down by more than 10%. 29 | if new_v < round_limit * v: 30 | new_v += divisor 31 | return new_v -------------------------------------------------------------------------------- /core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd): 75 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * 
flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /core/utils/weight_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | 5 | from torch.nn.init import _calculate_fan_in_and_fan_out 6 | 7 | 8 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 11 | def norm_cdf(x): 12 | # Computes standard normal cumulative distribution function 13 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 14 | 15 | if (mean < a - 2 * std) or (mean > b + 2 * std): 16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 17 | "The distribution of values may be incorrect.", 18 | stacklevel=2) 19 | 20 | with torch.no_grad(): 21 | # Values are generated by using a truncated uniform distribution and 22 | # then using the inverse CDF for the normal distribution. 23 | # Get upper and lower cdf values 24 | l = norm_cdf((a - mean) / std) 25 | u = norm_cdf((b - mean) / std) 26 | 27 | # Uniformly fill tensor with values from [l, u], then translate to 28 | # [2l-1, 2u-1]. 29 | tensor.uniform_(2 * l - 1, 2 * u - 1) 30 | 31 | # Use inverse cdf transform for normal distribution to get truncated 32 | # standard normal 33 | tensor.erfinv_() 34 | 35 | # Transform to proper mean, std 36 | tensor.mul_(std * math.sqrt(2.)) 37 | tensor.add_(mean) 38 | 39 | # Clamp to ensure it's in the proper range 40 | tensor.clamp_(min=a, max=b) 41 | return tensor 42 | 43 | 44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 45 | # type: (Tensor, float, float, float, float) -> Tensor 46 | r"""Fills the input Tensor with values drawn from a truncated 47 | normal distribution. The values are effectively drawn from the 48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 49 | with values outside :math:`[a, b]` redrawn until they are within 50 | the bounds. The method used for generating the random values works 51 | best when :math:`a \leq \text{mean} \leq b`. 
52 | Args: 53 | tensor: an n-dimensional `torch.Tensor` 54 | mean: the mean of the normal distribution 55 | std: the standard deviation of the normal distribution 56 | a: the minimum cutoff value 57 | b: the maximum cutoff value 58 | Examples: 59 | >>> w = torch.empty(3, 5) 60 | >>> nn.init.trunc_normal_(w) 61 | """ 62 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 63 | 64 | 65 | def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): 66 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 67 | if mode == 'fan_in': 68 | denom = fan_in 69 | elif mode == 'fan_out': 70 | denom = fan_out 71 | elif mode == 'fan_avg': 72 | denom = (fan_in + fan_out) / 2 73 | 74 | variance = scale / denom 75 | 76 | if distribution == "truncated_normal": 77 | # constant is stddev of standard normal truncated to (-2, 2) 78 | trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978) 79 | elif distribution == "normal": 80 | tensor.normal_(std=math.sqrt(variance)) 81 | elif distribution == "uniform": 82 | bound = math.sqrt(3 * variance) 83 | tensor.uniform_(-bound, bound) 84 | else: 85 | raise ValueError(f"invalid distribution {distribution}") 86 | 87 | 88 | def lecun_normal_(tensor): 89 | variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal') -------------------------------------------------------------------------------- /demo-frames/frame_0016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0016.png -------------------------------------------------------------------------------- /demo-frames/frame_0017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0017.png -------------------------------------------------------------------------------- /demo-frames/frame_0018.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0018.png -------------------------------------------------------------------------------- /demo-frames/frame_0019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0019.png -------------------------------------------------------------------------------- /demo-frames/frame_0020.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0020.png -------------------------------------------------------------------------------- /demo-frames/frame_0021.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0021.png -------------------------------------------------------------------------------- /demo-frames/frame_0022.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0022.png 
-------------------------------------------------------------------------------- /demo-frames/frame_0023.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0023.png -------------------------------------------------------------------------------- /demo-frames/frame_0024.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0024.png -------------------------------------------------------------------------------- /demo-frames/frame_0025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng94/GMFlowNet/4d45870450b8418f8ed7d97252bd2f014ecf1ce9/demo-frames/frame_0025.png -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | import argparse 5 | import os 6 | import cv2 7 | import glob 8 | import numpy as np 9 | import torch 10 | from PIL import Image 11 | 12 | from utils import flow_viz 13 | from utils.utils import InputPadder 14 | from core import create_model 15 | 16 | 17 | DEVICE = 'cuda' 18 | 19 | def load_image(imfile): 20 | img = np.array(Image.open(imfile)).astype(np.uint8) 21 | img = torch.from_numpy(img).permute(2, 0, 1).float() 22 | return img[None].to(DEVICE) 23 | 24 | 25 | def viz(img, flo): 26 | img = img[0].permute(1,2,0).cpu().numpy() 27 | flo = flo[0].permute(1,2,0).cpu().numpy() 28 | 29 | # map flow to rgb image 30 | flo = flow_viz.flow_to_image(flo) 31 | img_flo = np.concatenate([img, flo], axis=0) 32 | 33 | # import matplotlib.pyplot as plt 34 | # plt.imshow(img_flo / 255.0) 35 | # plt.show() 36 | 37 | cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) 38 | cv2.waitKey() 39 | 40 | 41 | def demo(args): 42 | model = torch.nn.DataParallel(create_model(args)) 43 | model.load_state_dict(torch.load(args.ckpt)) 44 | 45 | model = model.module 46 | model.to(DEVICE) 47 | model.eval() 48 | 49 | with torch.no_grad(): 50 | images = glob.glob(os.path.join(args.path, '*.png')) + \ 51 | glob.glob(os.path.join(args.path, '*.jpg')) 52 | 53 | images = sorted(images) 54 | for imfile1, imfile2 in zip(images[:-1], images[1:]): 55 | image1 = load_image(imfile1) 56 | image2 = load_image(imfile2) 57 | 58 | padder = InputPadder(image1.shape) 59 | image1, image2 = padder.pad(image1, image2) 60 | 61 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 62 | viz(image1, flow_up) 63 | 64 | 65 | if __name__ == '__main__': 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument('--model', default='gmflownet', help="mdoel class. 
``_model.py should be in ./core and `Model` should be defined in this file") 68 | parser.add_argument('--ckpt', help="restored checkpoint") 69 | 70 | parser.add_argument('--path', help="dataset for evaluation") 71 | parser.add_argument('--use_mix_attn', action='store_true', help='use mixture of POLA and axial attentions') 72 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 73 | parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation') 74 | args = parser.parse_args() 75 | 76 | demo(args) 77 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | from PIL import Image 5 | import argparse 6 | import os 7 | import time 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | import matplotlib.pyplot as plt 12 | 13 | import datasets 14 | from utils import flow_viz 15 | from utils import frame_utils 16 | 17 | # from raft import RAFT, RAFT_Transformer 18 | from core import create_model 19 | from utils.utils import InputPadder, forward_interpolate 20 | 21 | 22 | @torch.no_grad() 23 | def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'): 24 | """ Create submission for the Sintel leaderboard """ 25 | model.eval() 26 | for dstype in ['clean', 'final']: 27 | test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype) 28 | 29 | flow_prev, sequence_prev = None, None 30 | for test_id in range(len(test_dataset)): 31 | image1, image2, (sequence, frame) = test_dataset[test_id] 32 | if sequence != sequence_prev: 33 | flow_prev = None 34 | 35 | padder = InputPadder(image1.shape) 36 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 37 | 38 | flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True) 39 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 40 | 41 | if warm_start: 42 | flow_prev = forward_interpolate(flow_low[0])[None].cuda() 43 | 44 | output_dir = os.path.join(output_path, dstype, sequence) 45 | output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1)) 46 | 47 | if not os.path.exists(output_dir): 48 | os.makedirs(output_dir) 49 | 50 | frame_utils.writeFlow(output_file, flow) 51 | sequence_prev = sequence 52 | 53 | 54 | @torch.no_grad() 55 | def create_kitti_submission(model, iters=24, output_path='kitti_submission'): 56 | """ Create submission for the KITTI leaderboard """ 57 | model.eval() 58 | test_dataset = datasets.KITTI(split='testing', aug_params=None) 59 | 60 | if not os.path.exists(output_path): 61 | os.makedirs(output_path) 62 | 63 | for test_id in range(len(test_dataset)): 64 | image1, image2, (frame_id, ) = test_dataset[test_id] 65 | padder = InputPadder(image1.shape, mode='kitti') 66 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 67 | 68 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True) 69 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 70 | 71 | output_filename = os.path.join(output_path, frame_id) 72 | frame_utils.writeFlowKITTI(output_filename, flow) 73 | 74 | 75 | @torch.no_grad() 76 | def validate_chairs(model, iters=24): 77 | """ Perform evaluation on the FlyingChairs (test) split """ 78 | model.eval() 79 | epe_list = [] 80 | 81 | val_dataset = datasets.FlyingChairs(split='validation') 82 | for 
val_id in range(len(val_dataset)): 83 | image1, image2, flow_gt, _ = val_dataset[val_id] 84 | image1 = image1[None].cuda() 85 | image2 = image2[None].cuda() 86 | 87 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True) 88 | epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt() 89 | epe_list.append(epe.view(-1).numpy()) 90 | 91 | epe = np.mean(np.concatenate(epe_list)) 92 | print("Validation Chairs EPE: %f" % epe) 93 | return {'chairs': epe} 94 | 95 | 96 | @torch.no_grad() 97 | def validate_sintel(model, iters=32, warm_start=False): 98 | """ Perform validation using the Sintel (train) split """ 99 | model.eval() 100 | results = {} 101 | for dstype in ['clean', 'final']: 102 | val_dataset = datasets.MpiSintel(split='training', dstype=dstype, is_validate=True) 103 | epe_list = [] 104 | 105 | flow_prev, sequence_prev = None, None 106 | for val_id in range(len(val_dataset)): 107 | image1, image2, flow_gt, _, (sequence, frame) = val_dataset[val_id] 108 | image1 = image1[None].cuda() 109 | image2 = image2[None].cuda() 110 | 111 | if sequence != sequence_prev: 112 | flow_prev = None 113 | 114 | padder = InputPadder(image1.shape) 115 | image1, image2 = padder.pad(image1, image2) 116 | 117 | flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True) 118 | flow = padder.unpad(flow_pr[0]).cpu() 119 | 120 | if warm_start: 121 | flow_prev = forward_interpolate(flow_low[0])[None].cuda() 122 | 123 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 124 | epe_list.append(epe.view(-1).numpy()) 125 | 126 | sequence_prev = sequence 127 | 128 | epe_all = np.concatenate(epe_list) 129 | epe = np.mean(epe_all) 130 | px1 = np.mean(epe_all<1) 131 | px3 = np.mean(epe_all<3) 132 | px5 = np.mean(epe_all<5) 133 | 134 | print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5)) 135 | results[dstype] = np.mean(epe_list) 136 | 137 | return results 138 | 139 | 140 | @torch.no_grad() 141 | def validate_kitti(model, iters=24): 142 | """ Perform validation using the KITTI-2015 (train) split """ 143 | model.eval() 144 | val_dataset = datasets.KITTI(split='training') 145 | 146 | out_list, epe_list = [], [] 147 | for val_id in range(len(val_dataset)): 148 | image1, image2, flow_gt, valid_gt = val_dataset[val_id] 149 | image1 = image1[None].cuda() 150 | image2 = image2[None].cuda() 151 | 152 | padder = InputPadder(image1.shape, mode='kitti') 153 | image1, image2 = padder.pad(image1, image2) 154 | 155 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True) 156 | flow = padder.unpad(flow_pr[0]).cpu() 157 | 158 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 159 | mag = torch.sum(flow_gt**2, dim=0).sqrt() 160 | 161 | epe = epe.view(-1) 162 | mag = mag.view(-1) 163 | val = valid_gt.view(-1) >= 0.5 164 | 165 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float() 166 | epe_list.append(epe[val].mean().item()) 167 | out_list.append(out[val].cpu().numpy()) 168 | 169 | epe_list = np.array(epe_list) 170 | out_list = np.concatenate(out_list) 171 | 172 | epe = np.mean(epe_list) 173 | f1 = 100 * np.mean(out_list) 174 | 175 | print("Validation KITTI: %f, %f" % (epe, f1)) 176 | return {'kitti-epe': epe, 'kitti-f1': f1} 177 | 178 | 179 | if __name__ == '__main__': 180 | parser = argparse.ArgumentParser() 181 | parser.add_argument('--model', default='gmflownet', help="model class. 
182 |     parser.add_argument('--ckpt', help="restored checkpoint")
183 | 
184 |     parser.add_argument('--dataset', help="dataset for evaluation")
185 |     parser.add_argument('--use_mix_attn', action='store_true', help='use mixture of POLA and axial attentions')
186 |     parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
187 |     parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
188 |     args = parser.parse_args()
189 | 
190 |     model = torch.nn.DataParallel(create_model(args))
191 |     model.load_state_dict(torch.load(args.ckpt), strict=True)
192 | 
193 |     model.cuda()
194 |     model.eval()
195 | 
196 |     # create_sintel_submission(model.module, warm_start=True)
197 |     # create_kitti_submission(model.module)
198 | 
199 |     with torch.no_grad():
200 |         if args.dataset == 'chairs':
201 |             validate_chairs(model.module)
202 | 
203 |         elif args.dataset == 'sintel':
204 |             validate_sintel(model.module)
205 | 
206 |         elif args.dataset == 'sintel_test':
207 |             create_sintel_submission(model.module)
208 | 
209 |         elif args.dataset == 'kitti':
210 |             validate_kitti(model.module)
211 | 
212 |         elif args.dataset == 'kitti_test':
213 |             create_kitti_submission(model.module)
214 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import sys
3 | sys.path.append('core')
4 | 
5 | import argparse, configparser
6 | import os
7 | import cv2
8 | import time
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 | 
12 | import torch
13 | import torch.nn as nn
14 | import torch.optim as optim
15 | import torch.nn.functional as F
16 | 
17 | from torch.utils.data import DataLoader
18 | # from torch.optim import Adam as AdamW
19 | from torch.optim.adamw import AdamW
20 | from core.onecyclelr import OneCycleLR
21 | from core import create_model
22 | 
23 | from loss import compute_supervision_coarse, compute_coarse_loss, backwarp
24 | 
25 | import evaluate
26 | import datasets
27 | 
28 | from tensorboardX import SummaryWriter
29 | 
30 | try:
31 |     from torch.cuda.amp import GradScaler
32 | except ImportError:
33 |     # dummy GradScaler for PyTorch < 1.6
34 |     class GradScaler:
35 |         def __init__(self, enabled=False):
36 |             pass
37 |         def scale(self, loss):
38 |             return loss
39 |         def unscale_(self, optimizer):
40 |             pass
41 |         def step(self, optimizer):
42 |             optimizer.step()
43 |         def update(self):
44 |             pass
45 | 
46 | 
47 | # exclude extremely large displacements
48 | MAX_FLOW = 400
49 | SUM_FREQ = 100
50 | VAL_FREQ = 5000
51 | 
52 | 
53 | def sequence_loss(train_outputs, image1, image2, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW, use_matching_loss=False):
54 |     """ Loss function defined over sequence of flow predictions """
55 |     flow_preds, softCorrMap = train_outputs
56 | 
57 |     # original RAFT loss
58 |     n_predictions = len(flow_preds)
59 |     flow_loss = 0.0
60 | 
61 |     # exclude invalid pixels and extremely large displacements
62 |     mag = torch.sum(flow_gt**2, dim=1).sqrt()
63 |     valid = (valid >= 0.5) & (mag < max_flow)
64 | 
65 |     for i in range(n_predictions):
66 |         i_weight = gamma**(n_predictions - i - 1)
67 |         i_loss = (flow_preds[i] - flow_gt).abs()
68 |         flow_loss += i_weight * (valid[:, None].float() * i_loss).mean()
69 | 
70 |     epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
71 |     epe = epe.view(-1)[valid.view(-1)]
72 | 
73 |     metrics = {
74 |         'epe': epe.mean().item(),
75 |         '1px': (epe < 1).float().mean().item(),
76 |         '3px': (epe < 3).float().mean().item(),
77 |         '5px': (epe < 5).float().mean().item(),
78 |     }
79 | 
80 |     if use_matching_loss:
81 |         # enable global matching loss. Try to use it in late stages of the training
82 |         img_2back1 = backwarp(image2, flow_gt)
83 |         occlusionMap = (image1 - img_2back1).mean(1, keepdims=True)  # (N, 1, H, W)
84 |         occlusionMap = torch.abs(occlusionMap) > 20
85 |         occlusionMap = occlusionMap.float()
86 | 
87 |         conf_matrix_gt = compute_supervision_coarse(flow_gt, occlusionMap, 8)  # 8 is the RAFT downsampling factor
88 | 
89 |         matchLossCfg = configparser.ConfigParser()
90 |         matchLossCfg.POS_WEIGHT = 1
91 |         matchLossCfg.NEG_WEIGHT = 1
92 |         matchLossCfg.FOCAL_ALPHA = 0.25
93 |         matchLossCfg.FOCAL_GAMMA = 2.0
94 |         matchLossCfg.COARSE_TYPE = 'cross_entropy'
95 |         match_loss = compute_coarse_loss(softCorrMap, conf_matrix_gt, matchLossCfg)
96 | 
97 |         flow_loss = flow_loss + 0.01*match_loss
98 | 
99 |     return flow_loss, metrics
100 | 
101 | 
102 | def count_parameters(model):
103 |     return sum(p.numel() for p in model.parameters() if p.requires_grad)
104 | 
105 | 
106 | def fetch_optimizer(args, model, last_iters=-1):
107 |     """ Create the optimizer and learning rate scheduler """
108 |     optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
109 | 
110 |     scheduler = OneCycleLR(optimizer, args.lr, args.num_steps+100,
111 |         pct_start=0.05, cycle_momentum=False, anneal_strategy='linear', last_epoch=last_iters)
112 | 
113 |     return optimizer, scheduler
114 | 
115 | 
116 | class Logger:
117 |     def __init__(self, model, scheduler, total_steps=0, log_dir=None):
118 |         self.model = model
119 |         self.scheduler = scheduler
120 |         self.total_steps = total_steps
121 |         self.running_loss = {}
122 |         self.writer = None
123 |         self.log_dir = log_dir
124 | 
125 |     def _print_training_status(self):
126 |         metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())]
127 |         training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
128 |         metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)
129 | 
130 |         # print the training status
131 |         print(training_str + metrics_str)
132 | 
133 |         if self.writer is None:
134 |             self.writer = SummaryWriter(logdir=self.log_dir)
135 | 
136 |         for k in self.running_loss:
137 |             self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps)
138 |             self.running_loss[k] = 0.0
139 | 
140 |     def push(self, metrics):
141 |         self.total_steps += 1
142 | 
143 |         for key in metrics:
144 |             if key not in self.running_loss:
145 |                 self.running_loss[key] = 0.0
146 | 
147 |             self.running_loss[key] += metrics[key]
148 | 
149 |         if self.total_steps % SUM_FREQ == SUM_FREQ-1:
150 |             self._print_training_status()
151 |             self.running_loss = {}
152 | 
153 |     def write_dict(self, results):
154 |         if self.writer is None:
155 |             self.writer = SummaryWriter(logdir=self.log_dir)
156 | 
157 |         for key in results:
158 |             self.writer.add_scalar(key, results[key], self.total_steps)
159 | 
160 |     def close(self):
161 |         self.writer.close()
162 | 
163 | 
164 | def train(args):
165 | 
166 |     model = nn.DataParallel(create_model(args), device_ids=args.gpus)
167 |     print("Parameter Count: %d" % count_parameters(model))
168 | 
169 |     if args.restore_ckpt is not None:
170 |         model.load_state_dict(torch.load(args.restore_ckpt), strict=True)
171 | 
172 |     model.cuda()
173 |     model.train()
174 | 
175 |     if args.stage != 'chairs':
176 |         model.module.freeze_bn()
177 | 
178 |     if args.restore_ckpt is not None:
179 |         strStep = os.path.split(args.restore_ckpt)[-1].split('_')[0]
180 |         total_steps = int(strStep) if strStep.isdigit() else 0
181 |     else:
182 |         total_steps = 0
183 | 
184 |     train_loader = datasets.fetch_dataloader(args, TRAIN_DS='C+T+K/S')
185 |     optimizer, scheduler = fetch_optimizer(args, model, total_steps)
186 | 
187 |     scaler = GradScaler(enabled=args.mixed_precision)
188 |     logger = Logger(model, scheduler, total_steps, os.path.join('runs', args.name))
189 | 
190 |     add_noise = True
191 | 
192 |     should_keep_training = True
193 |     while should_keep_training:
194 | 
195 |         for i_batch, data_blob in enumerate(train_loader):
196 |             optimizer.zero_grad()
197 |             image1, image2, flow, valid = [x.cuda() for x in data_blob]
198 | 
199 |             if args.add_noise:
200 |                 stdv = np.random.uniform(0.0, 5.0)
201 |                 image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
202 |                 image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)
203 | 
204 |             flow_predictions = model(image1, image2, iters=args.iters)
205 | 
206 |             loss, metrics = sequence_loss(flow_predictions, image1, image2, flow, valid, gamma=args.gamma, use_matching_loss=args.use_mix_attn)
207 |             scaler.scale(loss).backward()
208 |             scaler.unscale_(optimizer)
209 |             torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
210 | 
211 |             scaler.step(optimizer)
212 |             scheduler.step()
213 |             scaler.update()
214 | 
215 |             logger.push(metrics)
216 | 
217 |             if total_steps % VAL_FREQ == VAL_FREQ - 1:
218 |                 PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
219 |                 torch.save(model.state_dict(), PATH)
220 | 
221 |                 results = {}
222 |                 for val_dataset in args.validation:
223 |                     if val_dataset == 'chairs':
224 |                         results.update(evaluate.validate_chairs(model.module))
225 |                     elif val_dataset == 'sintel':
226 |                         results.update(evaluate.validate_sintel(model.module))
227 |                     elif val_dataset == 'kitti':
228 |                         results.update(evaluate.validate_kitti(model.module))
229 | 
230 |                 logger.write_dict(results)
231 | 
232 |                 model.train()
233 |                 if args.stage != 'chairs':
234 |                     model.module.freeze_bn()
235 | 
236 |             total_steps += 1
237 | 
238 |             if total_steps > args.num_steps:
239 |                 should_keep_training = False
240 |                 break
241 | 
242 |     logger.close()
243 |     PATH = 'checkpoints/%s.pth' % args.name
244 |     torch.save(model.state_dict(), PATH)
245 | 
246 |     return PATH
247 | 
248 | 
249 | if __name__ == '__main__':
250 |     parser = argparse.ArgumentParser()
251 |     parser.add_argument('--name', default='gmflownet', help="name of your experiment. The saved checkpoint will be named after this in `./checkpoints/`.")
252 |     parser.add_argument('--model', default='gmflownet', help="model class. ``_model.py should be in ./core and `Model` should be defined in this file")
253 |     parser.add_argument('--stage', help="determines which dataset to use for training")
254 |     parser.add_argument('--restore_ckpt', help="restore checkpoint")
255 |     parser.add_argument('--use_mix_attn', action='store_true', help='use mixture of POLA and axial attentions')
256 |     parser.add_argument('--validation', type=str, nargs='+')
257 | 
258 |     parser.add_argument('--lr', type=float, default=0.00002)
259 |     parser.add_argument('--num_steps', type=int, default=100000)
260 |     parser.add_argument('--batch_size', type=int, default=6)
261 |     parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
262 |     parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
263 |     parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
264 | 
265 |     parser.add_argument('--iters', type=int, default=12)
266 |     parser.add_argument('--wdecay', type=float, default=.00005)
267 |     parser.add_argument('--epsilon', type=float, default=1e-8)
268 |     parser.add_argument('--clip', type=float, default=1.0)
269 |     parser.add_argument('--dropout', type=float, default=0.0)
270 |     parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
271 |     parser.add_argument('--add_noise', action='store_true')
272 |     args = parser.parse_args()
273 | 
274 |     torch.set_num_threads(16)
275 | 
276 |     torch.manual_seed(1234)
277 |     np.random.seed(1234)
278 | 
279 |     if not os.path.exists('checkpoints'):
280 |         os.mkdir('checkpoints')
281 |     if not os.path.exists('runs'):
282 |         os.mkdir('runs')
283 | 
284 |     os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, args.gpus))
285 |     args.gpus = [i for i in range(len(args.gpus))]
286 |     train(args)
287 | 
--------------------------------------------------------------------------------
/train_gmflownet.sh:
--------------------------------------------------------------------------------
1 | python -u train.py --model gmflownet --name gmflownet-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 120000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
2 | python -u train.py --model gmflownet --name gmflownet-things --stage things --validation sintel kitti --restore_ckpt checkpoints/gmflownet-chairs.pth --gpus 0 1 --num_steps 160000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
3 | python -u train.py --model gmflownet --name gmflownet-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/gmflownet-things.pth --gpus 0 1 --num_steps 160000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma 0.85
4 | python -u train.py --model gmflownet --name gmflownet-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/gmflownet-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85
5 | 
6 | 
7 | 
--------------------------------------------------------------------------------
/train_gmflownet_mix.sh:
--------------------------------------------------------------------------------
1 | python -u train.py --model gmflownet --name gmflownet_mix-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 120000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 --use_mix_attn
2 | python -u train.py --model gmflownet --name gmflownet_mix-things --stage things --validation sintel kitti --restore_ckpt checkpoints/gmflownet_mix-chairs.pth --gpus 0 1 --num_steps 160000 --batch_size 8 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 --use_mix_attn
3 | python -u train.py --model gmflownet --name gmflownet_mix-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/gmflownet_mix-things.pth --gpus 0 1 --num_steps 160000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma 0.85 --use_mix_attn
4 | python -u train.py --model gmflownet --name gmflownet_mix-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/gmflownet_mix-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --use_mix_attn
5 | 
--------------------------------------------------------------------------------