├── README.md
├── alt_cuda_corr
│   ├── correlation.cpp
│   ├── correlation_kernel.cu
│   └── setup.py
├── chairs_split.txt
├── core
│   ├── __init__.py
│   ├── corr.py
│   ├── datasets.py
│   ├── extractor.py
│   ├── fd_corr.py
│   ├── fd_decoder.py
│   ├── fd_encoder.py
│   ├── flowdiffuser.py
│   ├── module.py
│   ├── raft.py
│   ├── update.py
│   └── utils
│       ├── __init__.py
│       ├── augmentor.py
│       ├── flow_viz.py
│       ├── frame_utils.py
│       └── utils.py
├── eval.sh
├── evaluate.py
├── train.py
└── train.sh

/README.md:
--------------------------------------------------------------------------------
# [CVPR 2024] FlowDiffuser: Advancing Optical Flow Estimation with Diffusion Models

Ao Luo<sup>1,2</sup>, Xin Li<sup>3</sup>, Fan Yang<sup>3</sup>, Jiangyu Liu<sup>2</sup>, Haoqiang Fan<sup>2</sup>, and Shuaicheng Liu<sup>4,2</sup>

<sup>1</sup>Southwest Jiaotong University, <sup>2</sup>Megvii Research, <sup>3</sup>Group 42, <sup>4</sup>University of Electronic Science and Technology of China

This project provides the official implementation of '[**FlowDiffuser: Advancing Optical Flow Estimation with Diffusion Models**](https://openaccess.thecvf.com/content/CVPR2024/papers/Luo_FlowDiffuser_Advancing_Optical_Flow_Estimation_with_Diffusion_Models_CVPR_2024_paper.pdf)'.

## Abstract

Optical flow estimation, a process of predicting pixel-wise displacement between consecutive frames, has commonly been approached as a regression task in the age of deep learning. Despite notable advancements, this de facto paradigm unfortunately falls short in generalization performance when trained on synthetic or constrained data. Pioneering a paradigm shift, we reformulate optical flow estimation as a conditional flow generation challenge, unveiling FlowDiffuser, a new family of optical flow models that could have stronger learning and generalization capabilities. FlowDiffuser estimates optical flow through a ‘noise-to-flow’ strategy, progressively eliminating noise from randomly generated flows conditioned on the provided pairs. To optimize accuracy and efficiency, our FlowDiffuser incorporates a novel Conditional Recurrent Denoising Decoder (Conditional-RDD), streamlining the flow estimation process. It incorporates a unique Hidden State Denoising (HSD) paradigm, effectively leveraging the information from previous time steps. Moreover, FlowDiffuser can be easily integrated into existing flow networks, leading to significant improvements in performance metrics compared to conventional implementations. Experiments on challenging benchmarks, including Sintel and KITTI, demonstrate the effectiveness of our FlowDiffuser with superior performance to existing state-of-the-art models.
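
The ‘noise-to-flow’ strategy described in the abstract can be illustrated with a minimal NumPy sketch. It assumes three things. First, a deterministic DDIM-style sampler (eta = 0) with a linear beta schedule. Second, a denoiser that predicts the clean flow directly, in the spirit of DiffusionDet, which the Acknowledgement lists as one of the code bases. Third, hypothetical values for the schedule and step counts. The `make_oracle_denoiser` helper is a hypothetical stand-in for a trained network. None of this is the repository's Conditional-RDD, and the sketch does not model Hidden State Denoising.

```python
import numpy as np


def linear_alphas_cumprod(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for an (assumed) linear beta schedule."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    return np.cumprod(1.0 - betas)


def noise_to_flow(denoiser, img1, img2, shape, sample_steps=4, num_steps=1000, seed=0):
    """Deterministic DDIM-style sampling (eta = 0) with a clean-flow-predicting denoiser.

    Starts from a random flow field and repeatedly moves the denoiser's
    clean-flow estimate to the next, smaller noise level.
    """
    rng = np.random.default_rng(seed)
    alpha_bar = linear_alphas_cumprod(num_steps)
    # e.g. [999, 749, 499, 249, -1]; -1 marks the final, noise-free step
    times = np.linspace(-1, num_steps - 1, sample_steps + 1).astype(int)[::-1]

    x = rng.standard_normal(shape)  # x_T: a randomly generated "flow"
    for t, t_next in zip(times[:-1], times[1:]):
        x0_hat = denoiser(x, t, img1, img2)  # estimate of the clean flow
        # noise implied by the current sample and the clean-flow estimate
        eps = (x - np.sqrt(alpha_bar[t]) * x0_hat) / np.sqrt(1.0 - alpha_bar[t])
        a_next = alpha_bar[t_next] if t_next >= 0 else 1.0
        x = np.sqrt(a_next) * x0_hat + np.sqrt(1.0 - a_next) * eps
    return x


def make_oracle_denoiser(target_flow):
    """Hypothetical stand-in for the conditional network: always predicts target_flow."""
    def denoiser(x_t, t, img1, img2):
        return target_flow
    return denoiser


if __name__ == "__main__":
    H, W = 8, 12
    target = np.stack([np.full((H, W), 1.5), np.full((H, W), -0.5)])  # (u, v) channels
    img1 = img2 = np.zeros((3, H, W))
    flow = noise_to_flow(make_oracle_denoiser(target), img1, img2, shape=(2, H, W))
    print(flow.shape, np.allclose(flow, target))  # (2, 8, 12) True
```

With a perfect clean-flow estimate, the final step (`t_next = -1`, where the noise weight is zero) returns that estimate exactly. A trained network conditioned on the image pair would instead refine its estimate at every step.
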

## Overview

![FlowDiffuser](https://github.com/LA30/FlowDiffuser/assets/47421121/3d90bb6b-a3a0-411d-b8f6-66119f08b2b2)

![comparison](https://github.com/LA30/FlowDiffuser/assets/47421121/5463a1a7-c9bd-4596-afe5-d5b0c0ed7b40)


## Requirements

Python 3.8 with the following packages:
```Shell
pytorch 1.9.0
torchvision 0.10.0
numpy 1.19.5
opencv-python 4.6.0.66
timm 0.6.12
scipy 1.5.4
matplotlib 3.3.4
```


## Usage

1. Download the [Sintel](http://sintel.is.tue.mpg.de/) and [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) datasets, and set the root path of each dataset class in `./core/datasets.py`.

2. Put the `*.pth` weight file ([GoogleDrive](https://drive.google.com/file/d/1msAB8-ibMCTUEQbT6yjV1y13D_-6LTSX/view?usp=sharing)) into the folder `./weights`.

3. Evaluate on Sintel and KITTI:
```Shell
./eval.sh
```


## Q & A

Due to a change in my job, I am busy with other matters. If you have any questions, please email me at aoluo@swjtu.edu.cn. I will respond at my earliest convenience.


## Citation

If you find this work helpful, please cite:
```
@inproceedings{luo2024flowdiffuser,
  title={FlowDiffuser: Advancing Optical Flow Estimation with Diffusion Models},
  author={Luo, Ao and Li, Xin and Yang, Fan and Liu, Jiangyu and Fan, Haoqiang and Liu, Shuaicheng},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={19167--19176},
  year={2024}
}
```

## Acknowledgement

The code builds on [RAFT](https://github.com/princeton-vl/RAFT), [SKFlow](https://github.com/littlespray/SKFlow), [DiffusionDet](https://github.com/ShoufaChen/DiffusionDet), and [EMD-Flow](https://github.com/gddcx/EMD-Flow). We thank the authors for their contributions.
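
The Acknowledgement notes that the code builds on RAFT. The most direct inheritance is the correlation volume in `core/corr.py`, which works in three steps:

* `CorrBlock.corr` takes the dot product between every pixel of `fmap1` and every pixel of `fmap2`.
* It scales the result by 1/sqrt(D).
* The constructor then average-pools the `fmap2` axes by 2 to build a pyramid.

The NumPy sketch below mirrors only that construction, for illustration. The helper names are hypothetical, and the sketch omits both the radius-limited bilinear sampling in `CorrBlock.__call__` and the `alt_cuda_corr` kernel.

```python
import numpy as np


def all_pairs_corr(fmap1, fmap2):
    """fmap1, fmap2: (B, D, H, W) feature maps.

    Returns a (B, H, W, H, W) volume with
    corr[b, i, j, k, l] = <fmap1[b, :, i, j], fmap2[b, :, k, l]> / sqrt(D).
    """
    B, D, H, W = fmap1.shape
    f1 = fmap1.reshape(B, D, H * W)
    f2 = fmap2.reshape(B, D, H * W)
    corr = np.matmul(f1.transpose(0, 2, 1), f2)  # (B, H*W, H*W) dot products
    return corr.reshape(B, H, W, H, W) / np.sqrt(D)


def avg_pool_2x2(x):
    """2x2 average pooling with stride 2 over the last two axes (odd edges dropped)."""
    h = x.shape[-2] // 2 * 2
    w = x.shape[-1] // 2 * 2
    x = x[..., :h, :w]
    return 0.25 * (x[..., 0::2, 0::2] + x[..., 0::2, 1::2]
                   + x[..., 1::2, 0::2] + x[..., 1::2, 1::2])


def corr_pyramid(fmap1, fmap2, num_levels=4):
    """Full-resolution volume plus (num_levels - 1) pooled copies.

    Only the fmap2 axes are pooled, so every level keeps one entry per fmap1 pixel.
    """
    corr = all_pairs_corr(fmap1, fmap2)
    pyramid = [corr]
    for _ in range(num_levels - 1):
        corr = avg_pool_2x2(corr)
        pyramid.append(corr)
    return pyramid


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    f1 = rng.standard_normal((1, 8, 16, 16))
    f2 = rng.standard_normal((1, 8, 16, 16))
    for level in corr_pyramid(f1, f2):
        print(level.shape)
    # (1, 16, 16, 16, 16), (1, 16, 16, 8, 8), (1, 16, 16, 4, 4), (1, 16, 16, 2, 2)
```

Because only the second image's axes are pooled, each level keeps one correlation entry per `fmap1` pixel. This is what lets the lookup index every level with the same per-pixel coordinates divided by `2**i`.
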
67 | -------------------------------------------------------------------------------- /alt_cuda_corr/correlation.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | // CUDA forward declarations 5 | std::vector corr_cuda_forward( 6 | torch::Tensor fmap1, 7 | torch::Tensor fmap2, 8 | torch::Tensor coords, 9 | int radius); 10 | 11 | std::vector corr_cuda_backward( 12 | torch::Tensor fmap1, 13 | torch::Tensor fmap2, 14 | torch::Tensor coords, 15 | torch::Tensor corr_grad, 16 | int radius); 17 | 18 | // C++ interface 19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 22 | 23 | std::vector corr_forward( 24 | torch::Tensor fmap1, 25 | torch::Tensor fmap2, 26 | torch::Tensor coords, 27 | int radius) { 28 | CHECK_INPUT(fmap1); 29 | CHECK_INPUT(fmap2); 30 | CHECK_INPUT(coords); 31 | 32 | return corr_cuda_forward(fmap1, fmap2, coords, radius); 33 | } 34 | 35 | 36 | std::vector corr_backward( 37 | torch::Tensor fmap1, 38 | torch::Tensor fmap2, 39 | torch::Tensor coords, 40 | torch::Tensor corr_grad, 41 | int radius) { 42 | CHECK_INPUT(fmap1); 43 | CHECK_INPUT(fmap2); 44 | CHECK_INPUT(coords); 45 | CHECK_INPUT(corr_grad); 46 | 47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); 48 | } 49 | 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &corr_forward, "CORR forward"); 53 | m.def("backward", &corr_backward, "CORR backward"); 54 | } -------------------------------------------------------------------------------- /alt_cuda_corr/correlation_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | 7 | #define BLOCK_H 4 8 | #define BLOCK_W 8 9 | #define BLOCK_HW BLOCK_H * BLOCK_W 10 | 
#define CHANNEL_STRIDE 32 11 | 12 | 13 | __forceinline__ __device__ 14 | bool within_bounds(int h, int w, int H, int W) { 15 | return h >= 0 && h < H && w >= 0 && w < W; 16 | } 17 | 18 | template 19 | __global__ void corr_forward_kernel( 20 | const torch::PackedTensorAccessor32 fmap1, 21 | const torch::PackedTensorAccessor32 fmap2, 22 | const torch::PackedTensorAccessor32 coords, 23 | torch::PackedTensorAccessor32 corr, 24 | int r) 25 | { 26 | const int b = blockIdx.x; 27 | const int h0 = blockIdx.y * blockDim.x; 28 | const int w0 = blockIdx.z * blockDim.y; 29 | const int tid = threadIdx.x * blockDim.y + threadIdx.y; 30 | 31 | const int H1 = fmap1.size(1); 32 | const int W1 = fmap1.size(2); 33 | const int H2 = fmap2.size(1); 34 | const int W2 = fmap2.size(2); 35 | const int N = coords.size(1); 36 | const int C = fmap1.size(3); 37 | 38 | __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1]; 39 | __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1]; 40 | __shared__ scalar_t x2s[BLOCK_HW]; 41 | __shared__ scalar_t y2s[BLOCK_HW]; 42 | 43 | for (int c=0; c(floor(y2s[k1]))-r+iy; 76 | int w2 = static_cast(floor(x2s[k1]))-r+ix; 77 | int c2 = tid % CHANNEL_STRIDE; 78 | 79 | auto fptr = fmap2[b][h2][w2]; 80 | if (within_bounds(h2, w2, H2, W2)) 81 | f2[c2][k1] = fptr[c+c2]; 82 | else 83 | f2[c2][k1] = 0.0; 84 | } 85 | 86 | __syncthreads(); 87 | 88 | scalar_t s = 0.0; 89 | for (int k=0; k 0 && ix > 0 && within_bounds(h1, w1, H1, W1)) 105 | *(corr_ptr + ix_nw) += nw; 106 | 107 | if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1)) 108 | *(corr_ptr + ix_ne) += ne; 109 | 110 | if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1)) 111 | *(corr_ptr + ix_sw) += sw; 112 | 113 | if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1)) 114 | *(corr_ptr + ix_se) += se; 115 | } 116 | } 117 | } 118 | } 119 | } 120 | 121 | 122 | template 123 | __global__ void corr_backward_kernel( 124 | const torch::PackedTensorAccessor32 fmap1, 125 | const torch::PackedTensorAccessor32 fmap2, 126 | const 
torch::PackedTensorAccessor32 coords, 127 | const torch::PackedTensorAccessor32 corr_grad, 128 | torch::PackedTensorAccessor32 fmap1_grad, 129 | torch::PackedTensorAccessor32 fmap2_grad, 130 | torch::PackedTensorAccessor32 coords_grad, 131 | int r) 132 | { 133 | 134 | const int b = blockIdx.x; 135 | const int h0 = blockIdx.y * blockDim.x; 136 | const int w0 = blockIdx.z * blockDim.y; 137 | const int tid = threadIdx.x * blockDim.y + threadIdx.y; 138 | 139 | const int H1 = fmap1.size(1); 140 | const int W1 = fmap1.size(2); 141 | const int H2 = fmap2.size(1); 142 | const int W2 = fmap2.size(2); 143 | const int N = coords.size(1); 144 | const int C = fmap1.size(3); 145 | 146 | __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1]; 147 | __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1]; 148 | 149 | __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1]; 150 | __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1]; 151 | 152 | __shared__ scalar_t x2s[BLOCK_HW]; 153 | __shared__ scalar_t y2s[BLOCK_HW]; 154 | 155 | for (int c=0; c(floor(y2s[k1]))-r+iy; 190 | int w2 = static_cast(floor(x2s[k1]))-r+ix; 191 | int c2 = tid % CHANNEL_STRIDE; 192 | 193 | auto fptr = fmap2[b][h2][w2]; 194 | if (within_bounds(h2, w2, H2, W2)) 195 | f2[c2][k1] = fptr[c+c2]; 196 | else 197 | f2[c2][k1] = 0.0; 198 | 199 | f2_grad[c2][k1] = 0.0; 200 | } 201 | 202 | __syncthreads(); 203 | 204 | const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1]; 205 | scalar_t g = 0.0; 206 | 207 | int ix_nw = H1*W1*((iy-1) + rd*(ix-1)); 208 | int ix_ne = H1*W1*((iy-1) + rd*ix); 209 | int ix_sw = H1*W1*(iy + rd*(ix-1)); 210 | int ix_se = H1*W1*(iy + rd*ix); 211 | 212 | if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1)) 213 | g += *(grad_ptr + ix_nw) * dy * dx; 214 | 215 | if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1)) 216 | g += *(grad_ptr + ix_ne) * dy * (1-dx); 217 | 218 | if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1)) 219 | g += *(grad_ptr + ix_sw) * (1-dy) * dx; 220 | 221 | if (iy < rd && 
ix < rd && within_bounds(h1, w1, H1, W1)) 222 | g += *(grad_ptr + ix_se) * (1-dy) * (1-dx); 223 | 224 | for (int k=0; k(floor(y2s[k1]))-r+iy; 232 | int w2 = static_cast(floor(x2s[k1]))-r+ix; 233 | int c2 = tid % CHANNEL_STRIDE; 234 | 235 | scalar_t* fptr = &fmap2_grad[b][h2][w2][0]; 236 | if (within_bounds(h2, w2, H2, W2)) 237 | atomicAdd(fptr+c+c2, f2_grad[c2][k1]); 238 | } 239 | } 240 | } 241 | } 242 | __syncthreads(); 243 | 244 | 245 | for (int k=0; k corr_cuda_forward( 261 | torch::Tensor fmap1, 262 | torch::Tensor fmap2, 263 | torch::Tensor coords, 264 | int radius) 265 | { 266 | const auto B = coords.size(0); 267 | const auto N = coords.size(1); 268 | const auto H = coords.size(2); 269 | const auto W = coords.size(3); 270 | 271 | const auto rd = 2 * radius + 1; 272 | auto opts = fmap1.options(); 273 | auto corr = torch::zeros({B, N, rd*rd, H, W}, opts); 274 | 275 | const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W); 276 | const dim3 threads(BLOCK_H, BLOCK_W); 277 | 278 | corr_forward_kernel<<>>( 279 | fmap1.packed_accessor32(), 280 | fmap2.packed_accessor32(), 281 | coords.packed_accessor32(), 282 | corr.packed_accessor32(), 283 | radius); 284 | 285 | return {corr}; 286 | } 287 | 288 | std::vector corr_cuda_backward( 289 | torch::Tensor fmap1, 290 | torch::Tensor fmap2, 291 | torch::Tensor coords, 292 | torch::Tensor corr_grad, 293 | int radius) 294 | { 295 | const auto B = coords.size(0); 296 | const auto N = coords.size(1); 297 | 298 | const auto H1 = fmap1.size(1); 299 | const auto W1 = fmap1.size(2); 300 | const auto H2 = fmap2.size(1); 301 | const auto W2 = fmap2.size(2); 302 | const auto C = fmap1.size(3); 303 | 304 | auto opts = fmap1.options(); 305 | auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts); 306 | auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts); 307 | auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts); 308 | 309 | const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W); 310 | const dim3 
threads(BLOCK_H, BLOCK_W); 311 | 312 | 313 | corr_backward_kernel<<>>( 314 | fmap1.packed_accessor32(), 315 | fmap2.packed_accessor32(), 316 | coords.packed_accessor32(), 317 | corr_grad.packed_accessor32(), 318 | fmap1_grad.packed_accessor32(), 319 | fmap2_grad.packed_accessor32(), 320 | coords_grad.packed_accessor32(), 321 | radius); 322 | 323 | return {fmap1_grad, fmap2_grad, coords_grad}; 324 | } -------------------------------------------------------------------------------- /alt_cuda_corr/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='correlation', 7 | ext_modules=[ 8 | CUDAExtension('alt_cuda_corr', 9 | sources=['correlation.cpp', 'correlation_kernel.cu'], 10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | 16 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LA30/FlowDiffuser/9aff9c6e8c68f809e40bb0ae4273621276686168/core/__init__.py -------------------------------------------------------------------------------- /core/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch*h1*w1, 
dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 38 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class AlternateCorrBlock: 64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 65 | self.num_levels = num_levels 66 | self.radius = radius 67 | 68 | self.pyramid = [(fmap1, fmap2)] 69 | for i in range(self.num_levels): 70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 72 | self.pyramid.append((fmap1, fmap2)) 73 | 74 | def __call__(self, coords): 75 | coords = coords.permute(0, 2, 3, 1) 76 | B, H, W, _ = coords.shape 77 | dim = self.pyramid[0][0].shape[1] 78 | 79 | corr_list = [] 80 | for i in range(self.num_levels): 81 | r = self.radius 82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 83 | 
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 84 | 85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 87 | corr_list.append(corr.squeeze(1)) 88 | 89 | corr = torch.stack(corr_list, dim=1) 90 | corr = corr.reshape(B, -1, H, W) 91 | return corr / torch.sqrt(torch.tensor(dim).float()) 92 | -------------------------------------------------------------------------------- /core/datasets.py: -------------------------------------------------------------------------------- 1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | import torch.nn.functional as F 7 | 8 | import os 9 | import math 10 | import random 11 | from glob import glob 12 | import os.path as osp 13 | 14 | from utils import frame_utils 15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor 16 | 17 | 18 | class FlowDataset(data.Dataset): 19 | def __init__(self, aug_params=None, sparse=False): 20 | self.augmentor = None 21 | self.sparse = sparse 22 | if aug_params is not None: 23 | if sparse: 24 | self.augmentor = SparseFlowAugmentor(**aug_params) 25 | else: 26 | self.augmentor = FlowAugmentor(**aug_params) 27 | 28 | self.is_test = False 29 | self.init_seed = False 30 | self.flow_list = [] 31 | self.image_list = [] 32 | self.extra_info = [] 33 | 34 | def __getitem__(self, index): 35 | 36 | if self.is_test: 37 | img1 = frame_utils.read_gen(self.image_list[index][0]) 38 | img2 = frame_utils.read_gen(self.image_list[index][1]) 39 | img1 = np.array(img1).astype(np.uint8)[..., :3] 40 | img2 = np.array(img2).astype(np.uint8)[..., :3] 41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 43 | return img1, img2, self.extra_info[index] 44 | 45 | if not self.init_seed: 46 | worker_info = torch.utils.data.get_worker_info() 47 | if worker_info is 
not None: 48 | torch.manual_seed(worker_info.id) 49 | np.random.seed(worker_info.id) 50 | random.seed(worker_info.id) 51 | self.init_seed = True 52 | 53 | index = index % len(self.image_list) 54 | valid = None 55 | if self.sparse: 56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) 57 | else: 58 | flow = frame_utils.read_gen(self.flow_list[index]) 59 | 60 | img1 = frame_utils.read_gen(self.image_list[index][0]) 61 | img2 = frame_utils.read_gen(self.image_list[index][1]) 62 | 63 | flow = np.array(flow).astype(np.float32) 64 | img1 = np.array(img1).astype(np.uint8) 65 | img2 = np.array(img2).astype(np.uint8) 66 | 67 | # grayscale images 68 | if len(img1.shape) == 2: 69 | img1 = np.tile(img1[...,None], (1, 1, 3)) 70 | img2 = np.tile(img2[...,None], (1, 1, 3)) 71 | else: 72 | img1 = img1[..., :3] 73 | img2 = img2[..., :3] 74 | 75 | if self.augmentor is not None: 76 | if self.sparse: 77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) 78 | else: 79 | img1, img2, flow = self.augmentor(img1, img2, flow) 80 | 81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 84 | 85 | if valid is not None: 86 | valid = torch.from_numpy(valid) 87 | else: 88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 89 | 90 | return img1, img2, flow, valid.float() 91 | 92 | 93 | def __rmul__(self, v): 94 | self.flow_list = v * self.flow_list 95 | self.image_list = v * self.image_list 96 | return self 97 | 98 | def __len__(self): 99 | return len(self.image_list) 100 | 101 | 102 | class MpiSintel(FlowDataset): 103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): 104 | super(MpiSintel, self).__init__(aug_params) 105 | flow_root = osp.join(root, split, 'flow') 106 | image_root = osp.join(root, split, dstype) 107 | 108 | if split == 'test': 109 | self.is_test = True 110 | 111 | for scene 
in os.listdir(image_root): 112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) 113 | for i in range(len(image_list)-1): 114 | self.image_list += [ [image_list[i], image_list[i+1]] ] 115 | self.extra_info += [ (scene, i) ] # scene and frame_id 116 | 117 | if split != 'test': 118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) 119 | 120 | 121 | class FlyingChairs(FlowDataset): 122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): 123 | super(FlyingChairs, self).__init__(aug_params) 124 | 125 | images = sorted(glob(osp.join(root, '*.ppm'))) 126 | flows = sorted(glob(osp.join(root, '*.flo'))) 127 | assert (len(images)//2 == len(flows)) 128 | 129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) 130 | for i in range(len(flows)): 131 | xid = split_list[i] 132 | if (split=='training' and xid==1) or (split=='validation' and xid==2): 133 | self.flow_list += [ flows[i] ] 134 | self.image_list += [ [images[2*i], images[2*i+1]] ] 135 | 136 | 137 | class FlyingThings3D(FlowDataset): 138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): 139 | super(FlyingThings3D, self).__init__(aug_params) 140 | 141 | for cam in ['left']: 142 | for direction in ['into_future', 'into_past']: 143 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) 144 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 145 | 146 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) 147 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 148 | 149 | for idir, fdir in zip(image_dirs, flow_dirs): 150 | images = sorted(glob(osp.join(idir, '*.png')) ) 151 | flows = sorted(glob(osp.join(fdir, '*.pfm')) ) 152 | for i in range(len(flows)-1): 153 | if direction == 'into_future': 154 | self.image_list += [ [images[i], images[i+1]] ] 155 | self.flow_list += [ flows[i] ] 156 | elif direction == 'into_past': 157 | 
self.image_list += [ [images[i+1], images[i]] ] 158 | self.flow_list += [ flows[i+1] ] 159 | 160 | 161 | class KITTI(FlowDataset): 162 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): 163 | super(KITTI, self).__init__(aug_params, sparse=True) 164 | if split == 'testing': 165 | self.is_test = True 166 | 167 | root = osp.join(root, split) 168 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) 169 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) 170 | 171 | for img1, img2 in zip(images1, images2): 172 | frame_id = img1.split('/')[-1] 173 | self.extra_info += [ [frame_id] ] 174 | self.image_list += [ [img1, img2] ] 175 | 176 | if split == 'training': 177 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) 178 | 179 | 180 | class HD1K(FlowDataset): 181 | def __init__(self, aug_params=None, root='datasets/HD1k'): 182 | super(HD1K, self).__init__(aug_params, sparse=True) 183 | 184 | seq_ix = 0 185 | while 1: 186 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 187 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 188 | 189 | if len(flows) == 0: 190 | break 191 | 192 | for i in range(len(flows)-1): 193 | self.flow_list += [flows[i]] 194 | self.image_list += [ [images[i], images[i+1]] ] 195 | 196 | seq_ix += 1 197 | 198 | 199 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): 200 | """ Create the data loader for the corresponding training set """ 201 | 202 | if args.stage == 'chairs': 203 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} 204 | train_dataset = FlyingChairs(aug_params, split='training') 205 | 206 | elif args.stage == 'things': 207 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} 208 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') 209 | final_dataset = FlyingThings3D(aug_params,
dstype='frames_finalpass') 210 | train_dataset = clean_dataset + final_dataset 211 | 212 | elif args.stage == 'sintel': 213 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} 214 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass') 215 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') 216 | sintel_final = MpiSintel(aug_params, split='training', dstype='final') 217 | 218 | if TRAIN_DS == 'C+T+K+S+H': 219 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) 220 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) 221 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things 222 | 223 | elif TRAIN_DS == 'C+T+K/S': 224 | train_dataset = 100*sintel_clean + 100*sintel_final + things 225 | 226 | elif args.stage == 'kitti': 227 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} 228 | train_dataset = KITTI(aug_params, split='training') 229 | 230 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, 231 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True) 232 | 233 | print('Training with %d image pairs' % len(train_dataset)) 234 | return train_loader 235 | 236 | -------------------------------------------------------------------------------- /core/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = 
planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not stride == 1: 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not stride == 1: 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | 48 | def forward(self, x): 49 | y = x 50 | y = self.relu(self.norm1(self.conv1(y))) 51 | y = self.relu(self.norm2(self.conv2(y))) 52 | 53 | if self.downsample is not None: 54 | x = self.downsample(x) 55 | 56 | return self.relu(x+y) 57 | 58 | 59 | 60 | class BottleneckBlock(nn.Module): 61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 62 | super(BottleneckBlock, self).__init__() 63 | 64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | num_groups = planes // 8 70 | 71 | if norm_fn == 'group': 72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, 
num_channels=planes) 75 | if not stride == 1: 76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 77 | 78 | elif norm_fn == 'batch': 79 | self.norm1 = nn.BatchNorm2d(planes//4) 80 | self.norm2 = nn.BatchNorm2d(planes//4) 81 | self.norm3 = nn.BatchNorm2d(planes) 82 | if not stride == 1: 83 | self.norm4 = nn.BatchNorm2d(planes) 84 | 85 | elif norm_fn == 'instance': 86 | self.norm1 = nn.InstanceNorm2d(planes//4) 87 | self.norm2 = nn.InstanceNorm2d(planes//4) 88 | self.norm3 = nn.InstanceNorm2d(planes) 89 | if not stride == 1: 90 | self.norm4 = nn.InstanceNorm2d(planes) 91 | 92 | elif norm_fn == 'none': 93 | self.norm1 = nn.Sequential() 94 | self.norm2 = nn.Sequential() 95 | self.norm3 = nn.Sequential() 96 | if not stride == 1: 97 | self.norm4 = nn.Sequential() 98 | 99 | if stride == 1: 100 | self.downsample = None 101 | 102 | else: 103 | self.downsample = nn.Sequential( 104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 105 | 106 | 107 | def forward(self, x): 108 | y = x 109 | y = self.relu(self.norm1(self.conv1(y))) 110 | y = self.relu(self.norm2(self.conv2(y))) 111 | y = self.relu(self.norm3(self.conv3(y))) 112 | 113 | if self.downsample is not None: 114 | x = self.downsample(x) 115 | 116 | return self.relu(x+y) 117 | 118 | class BasicEncoder(nn.Module): 119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 120 | super(BasicEncoder, self).__init__() 121 | self.norm_fn = norm_fn 122 | 123 | if self.norm_fn == 'group': 124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 125 | 126 | elif self.norm_fn == 'batch': 127 | self.norm1 = nn.BatchNorm2d(64) 128 | 129 | elif self.norm_fn == 'instance': 130 | self.norm1 = nn.InstanceNorm2d(64) 131 | 132 | elif self.norm_fn == 'none': 133 | self.norm1 = nn.Sequential() 134 | 135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 136 | self.relu1 = nn.ReLU(inplace=True) 137 | 138 | self.in_planes = 64 139 | self.layer1 = 
self._make_layer(64, stride=1) 140 | self.layer2 = self._make_layer(96, stride=2) 141 | self.layer3 = self._make_layer(128, stride=2) 142 | 143 | # output convolution 144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 145 | 146 | self.dropout = None 147 | if dropout > 0: 148 | self.dropout = nn.Dropout2d(p=dropout) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 154 | if m.weight is not None: 155 | nn.init.constant_(m.weight, 1) 156 | if m.bias is not None: 157 | nn.init.constant_(m.bias, 0) 158 | 159 | def _make_layer(self, dim, stride=1): 160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 162 | layers = (layer1, layer2) 163 | 164 | self.in_planes = dim 165 | return nn.Sequential(*layers) 166 | 167 | 168 | def forward(self, x): 169 | 170 | # if input is list, combine batch dimension 171 | is_list = isinstance(x, tuple) or isinstance(x, list) 172 | if is_list: 173 | batch_dim = x[0].shape[0] 174 | x = torch.cat(x, dim=0) 175 | 176 | x = self.conv1(x) 177 | x = self.norm1(x) 178 | x = self.relu1(x) 179 | 180 | x = self.layer1(x) 181 | x = self.layer2(x) 182 | x = self.layer3(x) 183 | 184 | x = self.conv2(x) 185 | 186 | if self.training and self.dropout is not None: 187 | x = self.dropout(x) 188 | 189 | if is_list: 190 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 191 | 192 | return x 193 | 194 | 195 | class SmallEncoder(nn.Module): 196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 197 | super(SmallEncoder, self).__init__() 198 | self.norm_fn = norm_fn 199 | 200 | if self.norm_fn == 'group': 201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 202 | 203 | elif self.norm_fn == 'batch': 204 | self.norm1 = nn.BatchNorm2d(32) 205 | 206 | elif self.norm_fn == 
'instance': 207 | self.norm1 = nn.InstanceNorm2d(32) 208 | 209 | elif self.norm_fn == 'none': 210 | self.norm1 = nn.Sequential() 211 | 212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) 213 | self.relu1 = nn.ReLU(inplace=True) 214 | 215 | self.in_planes = 32 216 | self.layer1 = self._make_layer(32, stride=1) 217 | self.layer2 = self._make_layer(64, stride=2) 218 | self.layer3 = self._make_layer(96, stride=2) 219 | 220 | self.dropout = None 221 | if dropout > 0: 222 | self.dropout = nn.Dropout2d(p=dropout) 223 | 224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 225 | 226 | for m in self.modules(): 227 | if isinstance(m, nn.Conv2d): 228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 230 | if m.weight is not None: 231 | nn.init.constant_(m.weight, 1) 232 | if m.bias is not None: 233 | nn.init.constant_(m.bias, 0) 234 | 235 | def _make_layer(self, dim, stride=1): 236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) 237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 238 | layers = (layer1, layer2) 239 | 240 | self.in_planes = dim 241 | return nn.Sequential(*layers) 242 | 243 | 244 | def forward(self, x): 245 | 246 | # if input is list, combine batch dimension 247 | is_list = isinstance(x, tuple) or isinstance(x, list) 248 | if is_list: 249 | batch_dim = x[0].shape[0] 250 | x = torch.cat(x, dim=0) 251 | 252 | x = self.conv1(x) 253 | x = self.norm1(x) 254 | x = self.relu1(x) 255 | 256 | x = self.layer1(x) 257 | x = self.layer2(x) 258 | x = self.layer3(x) 259 | x = self.conv2(x) 260 | 261 | if self.training and self.dropout is not None: 262 | x = self.dropout(x) 263 | 264 | if is_list: 265 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 266 | 267 | return x 268 | -------------------------------------------------------------------------------- /core/fd_corr.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from utils.utils import bilinear_sampler 4 | 5 | 6 | class CorrBlock_FD_Sp4: 7 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4, coords_init=None, rad=1): 8 | self.num_levels = num_levels 9 | self.radius = radius 10 | self.corr_pyramid = [] 11 | 12 | corr = CorrBlock_FD_Sp4.corr(fmap1, fmap2, coords_init, r=rad) 13 | 14 | batch, h1, w1, dim, h2, w2 = corr.shape 15 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 16 | 17 | self.corr_pyramid.append(corr) 18 | for i in range(self.num_levels-1): 19 | corr = F.avg_pool2d(corr, 2, stride=2) 20 | self.corr_pyramid.append(corr) 21 | 22 | def __call__(self, coords): 23 | r = self.radius 24 | coords = coords.permute(0, 2, 3, 1) 25 | batch, h1, w1, _ = coords.shape 26 | 27 | out_pyramid = [] 28 | for i in range(self.num_levels): 29 | corr = self.corr_pyramid[i] 30 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 31 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 32 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 33 | 34 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 35 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 36 | coords_lvl = centroid_lvl + delta_lvl 37 | 38 | corr = bilinear_sampler(corr, coords_lvl) 39 | corr = corr.view(batch, h1, w1, -1) 40 | out_pyramid.append(corr) 41 | 42 | out = torch.cat(out_pyramid, dim=-1) 43 | return out.permute(0, 3, 1, 2).contiguous().float() 44 | 45 | @staticmethod 46 | def corr(fmap1, fmap2, coords_init, r): 47 | batch, dim, ht, wd = fmap1.shape 48 | fmap1 = fmap1.view(batch, dim, ht*wd) 49 | fmap2 = fmap2.view(batch, dim, ht*wd) 50 | 51 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 52 | corr = corr.view(batch, ht, wd, 1, ht, wd) 53 | # return corr / torch.sqrt(torch.tensor(dim).float()) 54 | 55 | coords = coords_init.permute(0, 2, 3, 1).contiguous() 56 | batch, h1, w1, _ = coords.shape 57 | 58 | 
corr = corr.view(batch*h1*w1, 1, h1, w1) 59 | 60 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 61 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 62 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 63 | 64 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) 65 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 66 | coords_lvl = centroid_lvl + delta_lvl 67 | 68 | corr = bilinear_sampler(corr, coords_lvl) 69 | 70 | corr = corr.view(batch, h1, w1, 1, 2*r+1, 2*r+1) 71 | return corr.permute(0, 1, 2, 3, 5, 4).contiguous() / torch.sqrt(torch.tensor(dim).float()) 72 | -------------------------------------------------------------------------------- /core/fd_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | from corr import CorrBlock 7 | 8 | 9 | # --- transformer modules 10 | class TransformerModule(nn.Module): 11 | def __init__(self, args): 12 | super().__init__() 13 | self.args = args 14 | self.tb = TransBlocks(args) 15 | 16 | def forward(self, fmap1, fmap2, inp): 17 | batch, ch, ht, wd = fmap1.shape 18 | fmap1, fmap2 = self.tb(fmap1, fmap2) 19 | corr_fn = CorrBlock(fmap1, fmap2, num_levels=self.args.corr_levels, radius=self.args.corr_radius) 20 | return corr_fn 21 | 22 | 23 | class TransBlocks(nn.Module): 24 | def __init__(self, args): 25 | super().__init__() 26 | dim = args.m_dim 27 | mlp_scale = 4 28 | window_size = [8, 8] 29 | num_layers = [2, 2] 30 | 31 | self.num_layers = len(num_layers) 32 | self.blocks = nn.ModuleList() 33 | for n in range(self.num_layers): 34 | if n == self.num_layers - 1: 35 | self.blocks.append( 36 | BasicLayer(num_layer=num_layers[n], dim=dim, mlp_scale=mlp_scale, window_size=window_size, cross=False)) 37 | else: 38 | self.blocks.append( 39 | BasicLayer(num_layer=num_layers[n], dim=dim, mlp_scale=mlp_scale, window_size=window_size, cross=True)) 40 | 41 | def forward(self, 
fmap1, fmap2): 42 | _, _, ht, wd = fmap1.shape 43 | pad_h, pad_w = (8 - (ht % 8)) % 8, (8 - (wd % 8)) % 8 44 | _pad = [pad_w // 2, pad_w - pad_w // 2, pad_h, 0] 45 | fmap1 = F.pad(fmap1, pad=_pad, mode='constant', value=0) 46 | fmap2 = F.pad(fmap2, pad=_pad, mode='constant', value=0) 47 | mask = torch.zeros([1, ht, wd]).to(fmap1.device) 48 | mask = torch.nn.functional.pad(mask, pad=_pad, mode='constant', value=1) 49 | mask = mask.bool() 50 | fmap1 = fmap1.permute(0, 2, 3, 1).contiguous().float() 51 | fmap2 = fmap2.permute(0, 2, 3, 1).contiguous().float() 52 | 53 | for idx, blk in enumerate(self.blocks): 54 | fmap1, fmap2 = blk(fmap1, fmap2, mask=mask) 55 | 56 | _, ht, wd, _ = fmap1.shape 57 | fmap1 = fmap1[:, _pad[2]:ht - _pad[3], _pad[0]:wd - _pad[1], :] 58 | fmap2 = fmap2[:, _pad[2]:ht - _pad[3], _pad[0]:wd - _pad[1], :] 59 | 60 | fmap1 = fmap1.permute(0, 3, 1, 2).contiguous() 61 | fmap2 = fmap2.permute(0, 3, 1, 2).contiguous() 62 | 63 | return fmap1, fmap2 64 | 65 | 66 | def window_partition(fmap, window_size): 67 | """ 68 | :param fmap: shape:B, H, W, C 69 | :param window_size: Wh, Ww 70 | :return: shape: B*nW, Wh*Ww, C 71 | """ 72 | B, H, W, C = fmap.shape 73 | fmap = fmap.reshape(B, H//window_size[0], window_size[0], W//window_size[1], window_size[1], C) 74 | fmap = fmap.permute(0, 1, 3, 2, 4, 5).contiguous() 75 | fmap = fmap.reshape(B*(H//window_size[0])*(W//window_size[1]), window_size[0]*window_size[1], C) 76 | return fmap 77 | 78 | 79 | def window_reverse(fmap, window_size, H, W): 80 | """ 81 | :param fmap: shape:B*nW, Wh*Ww, dim 82 | :param window_size: Wh, Ww 83 | :param H: original image height 84 | :param W: original image width 85 | :return: shape: B, H, W, C 86 | """ 87 | Bnw, _, dim = fmap.shape 88 | nW = (H // window_size[0]) * (W // window_size[1]) 89 | fmap = fmap.reshape(Bnw//nW, H // window_size[0], W // window_size[1], window_size[0], window_size[1], dim) 90 | fmap = fmap.permute(0, 1, 3, 2, 4, 5).contiguous() 91 | fmap = 
fmap.reshape(Bnw//nW, H, W, dim) 92 | return fmap 93 | 94 | 95 | class WindowAttention(nn.Module): 96 | def __init__(self, dim, window_size, scale=None): 97 | super().__init__() 98 | self.dim = dim 99 | self.scale = scale or dim ** (-0.5) 100 | self.q = nn.Linear(in_features=dim, out_features=dim) 101 | self.k = nn.Linear(in_features=dim, out_features=dim) 102 | self.v = nn.Linear(in_features=dim, out_features=dim) 103 | self.softmax = nn.Softmax(dim=-1) 104 | self.proj = nn.Linear(dim, dim) 105 | 106 | def forward(self, fmap, mask=None): 107 | """ 108 | :param fmap: B*nW, Wh*Ww, dim 109 | :param mask: nW, Wh*Ww, Wh*Ww (additive attention mask) 110 | :return: B*nW, Wh*Ww, dim 111 | """ 112 | Bnw, WhWw, dim = fmap.shape 113 | q = self.q(fmap) 114 | k = self.k(fmap) 115 | v = self.v(fmap) 116 | 117 | q = q * self.scale 118 | attn = q @ k.transpose(1, 2) 119 | 120 | if mask is not None: 121 | nw = mask.shape[0] 122 | attn = attn.reshape(Bnw//nw, nw, WhWw, WhWw) + mask.unsqueeze(0) 123 | attn = attn.reshape(Bnw, WhWw, WhWw) 124 | attn = self.softmax(attn) 125 | else: 126 | attn = self.softmax(attn) 127 | x = attn @ v 128 | x = self.proj(x) 129 | return x 130 | 131 | 132 | class GlobalAttention(nn.Module): 133 | def __init__(self, dim, scale=None): 134 | super().__init__() 135 | self.dim = dim 136 | self.scale = scale or dim ** (-0.5) 137 | self.q = nn.Linear(in_features=dim, out_features=dim) 138 | self.k = nn.Linear(in_features=dim, out_features=dim) 139 | self.v = nn.Linear(in_features=dim, out_features=dim) 140 | self.softmax = nn.Softmax(dim=-1) 141 | self.proj = nn.Linear(in_features=dim, out_features=dim) 142 | 143 | def forward(self, fmap1, fmap2, mask=None): 144 | """ 145 | :param fmap1: B, H, W, C 146 | :param fmap2: B, H, W, C 147 | :param mask: 1, H, W (bool, True at padded positions) 148 | :return: B, H, W, C 149 | """ 150 | B, H, W, C = fmap1.shape 151 | q = self.q(fmap1) 152 | k = self.k(fmap2) 153 | v = self.v(fmap2) 154 | 155 | q, k, v = map(lambda x: x.reshape(B, H*W, C), [q, k, v]) 156 | 157 | q = q * self.scale 158 | 
attn = q @ k.transpose(1, 2) 159 | if mask is not None: 160 | mask = mask.reshape(1, H * W, 1) | mask.reshape(1, 1, H * W) # batch, hw, hw 161 | mask = mask.float() * -100.0 162 | attn = attn + mask 163 | attn = self.softmax(attn) 164 | x = attn @ v # B, HW, C 165 | 166 | x = self.proj(x) 167 | 168 | x = x.reshape(B, H, W, C) 169 | return x 170 | 171 | 172 | class SelfTransformerBlcok(nn.Module): 173 | def __init__(self, dim, mlp_scale, window_size, shift_size=None, norm=None): 174 | super().__init__() 175 | self.dim = dim 176 | self.window_size = window_size 177 | self.shift_size = shift_size 178 | 179 | if norm == 'layer': 180 | self.layer_norm1 = nn.LayerNorm(dim) 181 | self.layer_norm2 = nn.LayerNorm(dim) 182 | else: 183 | self.layer_norm1 = nn.Identity() 184 | self.layer_norm2 = nn.Identity() 185 | 186 | self.self_attn = WindowAttention(dim=dim, window_size=window_size) 187 | self.mlp = nn.Sequential( 188 | nn.Linear(dim, dim * mlp_scale), 189 | nn.GELU(), 190 | nn.Linear(dim * mlp_scale, dim) 191 | ) 192 | 193 | def forward(self, fmap, mask=None): 194 | """ 195 | :param fmap: shape: B, H, W, C 196 | :return: B, H, W, C 197 | """ 198 | B, H, W, C = fmap.shape 199 | 200 | shortcut = fmap 201 | fmap = self.layer_norm1(fmap) 202 | 203 | if self.shift_size is not None: 204 | shifted_fmap = torch.roll(fmap, [-self.shift_size[0], -self.shift_size[1]], dims=(1, 2)) 205 | if mask is not None: 206 | shifted_mask = torch.roll(mask, [-self.shift_size[0], -self.shift_size[1]], dims=(1, 2)) 207 | else: 208 | shifted_fmap = fmap 209 | if mask is not None: 210 | shifted_mask = mask 211 | 212 | win_fmap = window_partition(shifted_fmap, window_size=self.window_size) 213 | if mask is not None: 214 | pad_mask = window_partition(shifted_mask.unsqueeze(-1), self.window_size) 215 | pad_mask = pad_mask.reshape(-1, self.window_size[0] * self.window_size[1], 1) \ 216 | | pad_mask.reshape(-1, 1, self.window_size[0] * self.window_size[1]) 217 | 218 | if self.shift_size is not None: 219 
| h_slice = [slice(0, -self.window_size[0]), slice(-self.window_size[0], -self.shift_size[0]), slice(-self.shift_size[0], None)] 220 | w_slice = [slice(0, -self.window_size[1]), slice(-self.window_size[1], -self.shift_size[1]), slice(-self.shift_size[1], None)] 221 | img_mask = torch.zeros([1, H, W, 1]).to(win_fmap.device) 222 | count = 0 223 | for h in h_slice: 224 | for w in w_slice: 225 | img_mask[:, h, w, :] = count 226 | count += 1 227 | win_mask = window_partition(img_mask, self.window_size) 228 | win_mask = win_mask.reshape(-1, self.window_size[0] * self.window_size[1]) # nW, Wh*Ww 229 | attn_mask = win_mask.unsqueeze(2) - win_mask.unsqueeze(1) # nw, Wh*Ww, Wh*Ww 230 | if mask is not None: 231 | attn_mask = attn_mask.masked_fill(attn_mask == 0, 0.0).masked_fill((attn_mask != 0) | pad_mask, -100.0) 232 | else: 233 | attn_mask = attn_mask.masked_fill(attn_mask == 0, 0.0).masked_fill(attn_mask != 0, -100.0) 234 | attn_fmap = self.self_attn(win_fmap, attn_mask) 235 | else: 236 | if mask is not None: 237 | pad_mask = pad_mask.float() 238 | pad_mask = pad_mask.masked_fill(pad_mask != 0, -100.0).masked_fill(pad_mask == 0, 0.0) 239 | attn_fmap = self.self_attn(win_fmap, pad_mask) 240 | else: 241 | attn_fmap = self.self_attn(win_fmap, None) 242 | shifted_fmap = window_reverse(attn_fmap, self.window_size, H, W) 243 | 244 | if self.shift_size is not None: 245 | fmap = torch.roll(shifted_fmap, [self.shift_size[0], self.shift_size[1]], dims=(1, 2)) 246 | else: 247 | fmap = shifted_fmap 248 | 249 | fmap = shortcut + fmap 250 | fmap = fmap + self.mlp(self.layer_norm2(fmap)) # B, H, W, C 251 | return fmap 252 | 253 | 254 | class CrossTransformerBlcok(nn.Module): 255 | def __init__(self, dim, mlp_scale, norm=None): 256 | super().__init__() 257 | self.dim = dim 258 | 259 | if norm == 'layer': 260 | self.layer_norm1 = nn.LayerNorm(dim) 261 | self.layer_norm2 = nn.LayerNorm(dim) 262 | self.layer_norm3 = nn.LayerNorm(dim) 263 | else: 264 | self.layer_norm1 = nn.Identity() 265 | 
self.layer_norm2 = nn.Identity() 266 | self.layer_norm3 = nn.Identity() 267 | self.cross_attn = GlobalAttention(dim=dim) 268 | self.mlp = nn.Sequential( 269 | nn.Linear(dim, dim * mlp_scale), 270 | nn.GELU(), 271 | nn.Linear(dim * mlp_scale, dim) 272 | ) 273 | 274 | def forward(self, fmap1, fmap2, mask=None): 275 | """ 276 | :param fmap1: shape: B, H, W, C 277 | :param fmap2: shape: B, H, W, C 278 | :return: B, H, W, C 279 | """ 280 | shortcut = fmap1 281 | 282 | fmap1 = self.layer_norm1(fmap1) 283 | fmap2 = self.layer_norm2(fmap2) 284 | 285 | attn_fmap = self.cross_attn(fmap1, fmap2, mask) 286 | attn_fmap = shortcut + attn_fmap 287 | fmap = attn_fmap + self.mlp(self.layer_norm3(attn_fmap)) # B, H, W, C 288 | return fmap 289 | 290 | 291 | class BasicLayer(nn.Module): 292 | def __init__(self, num_layer, dim, mlp_scale, window_size, cross=False): 293 | super().__init__() 294 | assert num_layer % 2 == 0, "The number of Transformer Block must be even!" 295 | self.blocks = nn.ModuleList() 296 | for n in range(num_layer): 297 | shift_size = None if n % 2 == 0 else [window_size[0]//2, window_size[1]//2] 298 | self.blocks.append( 299 | SelfTransformerBlcok( 300 | dim=dim, 301 | mlp_scale=mlp_scale, 302 | window_size=window_size, 303 | shift_size=shift_size, 304 | norm='layer')) 305 | 306 | if cross: 307 | self.cross_transformer = CrossTransformerBlcok(dim=dim, mlp_scale=mlp_scale, norm='layer') 308 | 309 | self.cross = cross 310 | 311 | def forward(self, fmap1, fmap2, mask=None): 312 | """ 313 | :param fmap1: B, H, W, C 314 | :param fmap2: B, H, W, C 315 | :return: B, H, W, C 316 | """ 317 | B = fmap1.shape[0] 318 | fmap = torch.cat([fmap1, fmap2], dim=0) 319 | for blk in self.blocks: 320 | fmap = blk(fmap, mask) 321 | fmap1, fmap2 = torch.split(fmap, [B]*2, dim=0) 322 | if self.cross: 323 | fmap2 = self.cross_transformer(fmap2, fmap1, mask) + fmap2 324 | fmap1 = self.cross_transformer(fmap1, fmap2, mask) + fmap1 325 | return fmap1, fmap2 326 | 327 | 328 | # --- upsample 
modules 329 | class UpSampleMask8(nn.Module): 330 | def __init__(self, dim): 331 | super().__init__() 332 | self.up_sample_mask = nn.Sequential( 333 | nn.Conv2d(in_channels=dim, out_channels=256, kernel_size=3, padding=1, stride=1), 334 | nn.ReLU(inplace=True), 335 | nn.Conv2d(in_channels=256, out_channels=64 * 9, kernel_size=1, stride=1) 336 | ) 337 | 338 | def forward(self, data): 339 | """ 340 | :param data: B, C, H, W 341 | :return: batch, 8*8*9, H, W 342 | """ 343 | mask = self.up_sample_mask(data) # B, 64*9, H, W 344 | return mask 345 | 346 | 347 | class UpSampleMask4(nn.Module): 348 | def __init__(self, dim): 349 | super().__init__() 350 | self.up_sample_mask = nn.Sequential( 351 | nn.Conv2d(in_channels=dim, out_channels=256, kernel_size=3, padding=1, stride=1), 352 | nn.ReLU(inplace=True), 353 | nn.Conv2d(in_channels=256, out_channels=16 * 9, kernel_size=1, stride=1) 354 | ) 355 | 356 | def forward(self, data): 357 | """ 358 | :param data: B, C, H, W 359 | :return: batch, 4*4*9, H, W 360 | """ 361 | mask = self.up_sample_mask(data) # B, 16*9, H, W 362 | return mask 363 | 364 | 365 | # --- SK decoder modules 366 | class PCBlock4_Deep_nopool_res(nn.Module): 367 | def __init__(self, C_in, C_out, k_conv): 368 | super().__init__() 369 | self.conv_list = nn.ModuleList([ 370 | nn.Conv2d(C_in, C_in, kernel, stride=1, padding=kernel//2, groups=C_in) for kernel in k_conv]) 371 | 372 | self.ffn1 = nn.Sequential( 373 | nn.Conv2d(C_in, int(1.5*C_in), 1, padding=0), 374 | nn.GELU(), 375 | nn.Conv2d(int(1.5*C_in), C_in, 1, padding=0), 376 | ) 377 | self.pw = nn.Conv2d(C_in, C_in, 1, padding=0) 378 | self.ffn2 = nn.Sequential( 379 | nn.Conv2d(C_in, int(1.5*C_in), 1, padding=0), 380 | nn.GELU(), 381 | nn.Conv2d(int(1.5*C_in), C_out, 1, padding=0), 382 | ) 383 | 384 | def forward(self, x): 385 | x = F.gelu(x + self.ffn1(x)) 386 | for conv in self.conv_list: 387 | x = F.gelu(x + conv(x)) 388 | x = F.gelu(x + self.pw(x)) 389 | x = self.ffn2(x) 390 | return x 391 | 392 | 393 | 
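The two mask heads above emit `9 * f * f` channels (f = 8 for `UpSampleMask8`, f = 4 for `UpSampleMask4`): for every coarse pixel, one logit per 3×3 neighbour for each of the f×f fine sub-pixels. As a minimal sketch, assuming RAFT-style convex upsampling, the NumPy function below shows how such a mask can be consumed. It is written to follow the reshape and permute order of `up_sample_flow8`/`up_sample_flow4` in `core/flowdiffuser.py`, but it is a standalone illustration in NumPy rather than PyTorch, and the name `convex_upsample` is hypothetical (not part of this repository).

```python
import numpy as np


def convex_upsample(flow, mask, factor):
    """Hypothetical helper: upsample a coarse flow field with convex weights.

    flow: (B, 2, H, W) coarse flow, in coarse-pixel units.
    mask: (B, 9 * factor * factor, H, W) unnormalized logits, laid out like
          the outputs of UpSampleMask8 (factor=8) or UpSampleMask4 (factor=4).
    Returns an array of shape (B, 2, H * factor, W * factor).
    """
    B, _, H, W = flow.shape
    # Scale to fine-pixel units, zero-pad, and gather every 3x3 neighbourhood
    # in row-major kernel order (the same order torch unfold uses).
    padded = np.pad(factor * flow, ((0, 0), (0, 0), (1, 1), (1, 1)))
    neigh = np.stack([padded[:, :, dy:dy + H, dx:dx + W]
                      for dy in range(3) for dx in range(3)], axis=2)  # (B, 2, 9, H, W)
    # Numerically stable softmax over the 9 neighbours for each fine sub-pixel.
    m = mask.reshape(B, 1, 9, factor, factor, H, W)
    m = np.exp(m - m.max(axis=2, keepdims=True))
    m = m / m.sum(axis=2, keepdims=True)
    # Weighted sum over neighbours -> (B, 2, factor, factor, H, W).
    up = np.sum(neigh[:, :, :, None, None] * m, axis=2)
    # Interleave sub-pixel offsets with coarse positions: (B, 2, H, f, W, f).
    up = up.transpose(0, 1, 4, 2, 5, 3)
    return up.reshape(B, 2, H * factor, W * factor)
```

Because the weights are a softmax over the nine neighbours, each upsampled vector is a convex combination of its (scaled) coarse neighbours. Away from the zero-padded border, an all-zero mask therefore reduces to a 3×3 box average, and a mask strongly peaked on the centre neighbour reduces to nearest-neighbour upsampling.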
class SKMotionEncoder6_Deep_nopool_res(nn.Module): 394 | def __init__(self, args): 395 | super().__init__() 396 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 397 | self.convc1 = PCBlock4_Deep_nopool_res(cor_planes, 256, k_conv=args.k_conv) 398 | self.convc2 = PCBlock4_Deep_nopool_res(256, 192, k_conv=args.k_conv) 399 | 400 | self.convf1 = nn.Conv2d(2, 128, 1, 1, 0) 401 | self.convf2 = PCBlock4_Deep_nopool_res(128, 64, k_conv=args.k_conv) 402 | 403 | self.conv = PCBlock4_Deep_nopool_res(64+192, 128-2, k_conv=args.k_conv) 404 | 405 | def forward(self, flow, corr): 406 | cor = F.gelu(self.convc1(corr)) 407 | 408 | cor = self.convc2(cor) 409 | 410 | flo = self.convf1(flow) 411 | flo = self.convf2(flo) 412 | 413 | cor_flo = torch.cat([cor, flo], dim=1) 414 | out = self.conv(cor_flo) 415 | 416 | return torch.cat([out, flow], dim=1) 417 | 418 | 419 | class SKUpdateBlock6_Deep_nopoolres_AllDecoder(nn.Module): 420 | def __init__(self, args, hidden_dim): 421 | super().__init__() 422 | self.args = args 423 | self.encoder = SKMotionEncoder6_Deep_nopool_res(args) 424 | self.gru = PCBlock4_Deep_nopool_res(128+hidden_dim+hidden_dim+128, 128, k_conv=args.PCUpdater_conv) 425 | self.flow_head = PCBlock4_Deep_nopool_res(128, 2, k_conv=args.k_conv) 426 | 427 | self.mask = nn.Sequential( 428 | nn.Conv2d(128, 256, 3, padding=1), 429 | nn.ReLU(inplace=True), 430 | nn.Conv2d(256, 64*9, 1, padding=0)) 431 | 432 | self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=self.args.num_heads) 433 | 434 | def forward(self, net, inp, corr, flow, attention): 435 | motion_features = self.encoder(flow, corr) 436 | motion_features_global = self.aggregator(attention, motion_features) 437 | inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) 438 | 439 | # Attentional update 440 | net = self.gru(torch.cat([net, inp_cat], dim=1)) 441 | 442 | delta_flow = self.flow_head(net) 443 | 444 | # scale mask to balance gradients 445 | mask = .25 * 
self.mask(net) 446 | return net, mask, delta_flow 447 | 448 | 449 | class Aggregator(nn.Module): 450 | def __init__(self, args, chnn, heads=1): 451 | super().__init__() 452 | self.scale = chnn ** -0.5 453 | self.to_qk = nn.Conv2d(chnn, chnn * 2, 1, bias=False) 454 | self.to_v = nn.Conv2d(128, 128, 1, bias=False) 455 | self.gamma = nn.Parameter(torch.zeros(1)) 456 | 457 | def forward(self, *inputs): 458 | feat_ctx, feat_mo, itr = inputs 459 | 460 | feat_shape = feat_mo.shape 461 | b, c, h, w = feat_shape 462 | c_c = feat_ctx.shape[1] 463 | 464 | if itr == 0: 465 | feat_q, feat_k = self.to_qk(feat_ctx).chunk(2, dim=1) 466 | feat_q = self.scale * feat_q.view(b, c_c, h*w) 467 | feat_k = feat_k.view(b, c_c, h*w) 468 | 469 | attn = torch.einsum('b c n, b c m -> b m n', feat_q, feat_k) 470 | attn = attn.view(b, 1, h*w, h*w) 471 | self.attn = attn.softmax(2).view(b, h*w, h*w).permute(0, 2, 1).contiguous() 472 | 473 | feat_v = self.to_v(feat_mo).view(b, c, h*w) 474 | feat_o = torch.einsum('b n m, b c m -> b c n', self.attn, feat_v).contiguous().view(b, c, h, w) 475 | feat_o = feat_mo + feat_o * self.gamma 476 | return feat_o 477 | 478 | 479 | class SKUpdate(nn.Module): 480 | def __init__(self, args, hidden_dim): 481 | super().__init__() 482 | self.args = args 483 | d_dim = args.c_dim 484 | 485 | self.encoder = SKMotionEncoder6_Deep_nopool_res(args) 486 | self.gru = PCBlock4_Deep_nopool_res(128+hidden_dim+hidden_dim+128, d_dim, k_conv=args.PCUpdater_conv) 487 | self.flow_head = PCBlock4_Deep_nopool_res(d_dim, 2, k_conv=args.k_conv) 488 | self.aggregator = Aggregator(self.args, d_dim) 489 | 490 | def forward(self, net, inp, corr, flow, itr=None, sp4=False): 491 | motion_features = self.encoder(flow, corr) 492 | 493 | if not sp4: 494 | motion_features_global = self.aggregator(inp, motion_features, itr) 495 | else: 496 | motion_features_global = motion_features 497 | 498 | inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) 499 | net = 
self.gru(torch.cat([net, inp_cat], dim=1)) 500 | 501 | delta_flow = self.flow_head(net) 502 | return net, delta_flow 503 | 504 | 505 | class SinusoidalPositionEmbeddings(nn.Module): 506 | def __init__(self, dim): 507 | super().__init__() 508 | self.dim = dim 509 | 510 | def forward(self, time): 511 | device = time.device 512 | half_dim = self.dim // 2 513 | embeddings = math.log(10000) / (half_dim - 1) 514 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 515 | embeddings = time[:, None] * embeddings[None, :] 516 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 517 | return embeddings 518 | 519 | 520 | class ConvEE(nn.Module): 521 | def __init__(self, C_in, C_out): 522 | super().__init__() 523 | groups = 4 524 | self.conv1 = nn.Sequential( 525 | nn.GroupNorm(groups, C_in), 526 | nn.GELU(), 527 | nn.Conv2d(C_in, C_in, 3, padding=1), 528 | nn.GroupNorm(groups, C_in)) 529 | self.conv2 = nn.Sequential( 530 | nn.GELU(), 531 | nn.Conv2d(C_in, C_in, 3, padding=1)) 532 | self.gamma = nn.Parameter(torch.zeros(1)) 533 | 534 | def forward(self, x, t_emb): 535 | scale, shift = t_emb 536 | x_res = x 537 | x = self.conv1(x) 538 | 539 | x = x * (scale + 1) + shift 540 | 541 | x = self.conv2(x) 542 | x_o = x * self.gamma 543 | 544 | return x_o 545 | 546 | 547 | class SKUpdateDFM(nn.Module): 548 | def __init__(self, args, hidden_dim): 549 | super().__init__() 550 | self.args = args 551 | chnn = hidden_dim 552 | self.conv_ee = ConvEE(chnn, chnn) 553 | 554 | d_model = 256 555 | self.d_model = d_model 556 | time_dim = d_model * 2 557 | self.time_mlp = nn.Sequential( 558 | SinusoidalPositionEmbeddings(d_model), 559 | nn.Linear(d_model, time_dim), 560 | nn.GELU(), 561 | nn.Linear(time_dim, time_dim)) 562 | self.chnn_o = 256 563 | self.block_time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, self.chnn_o)) 564 | 565 | def forward(self, net, inp, corr, flow, itr, first_step=False, dfm_params=[]): 566 | t, funcs, i_ddim, dfm_itrs = 
dfm_params 567 | b = t.shape[0] 568 | time_emb = self.time_mlp(t) 569 | 570 | scale_shift = self.block_time_mlp(time_emb) 571 | scale_shift = scale_shift.view(b, 256, 1, 1) 572 | scale, shift = scale_shift.chunk(2, dim=1) 573 | 574 | motion_features = funcs.encoder(flow, corr) 575 | 576 | if first_step: 577 | self.shape = net.shape 578 | 579 | if self.shape == net.shape: 580 | feat_mo = funcs.aggregator(inp, motion_features, itr) 581 | else: 582 | feat_mo = motion_features 583 | 584 | feat_mo = self.conv_ee(feat_mo, [scale, shift]) 585 | 586 | inp = torch.cat([inp, motion_features, feat_mo], dim=1) 587 | net = funcs.gru(torch.cat([net, inp], dim=1)) 588 | 589 | net = net * (scale + 1) + shift 590 | 591 | delta_flow = funcs.flow_head(net) 592 | 593 | return net, delta_flow 594 | 595 | 596 | 597 | -------------------------------------------------------------------------------- /core/fd_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import timm 6 | import numpy as np 7 | 8 | 9 | class twins_svt_large(nn.Module): 10 | def __init__(self, pretrained=True): 11 | super().__init__() 12 | self.svt = timm.create_model('twins_svt_large', pretrained=pretrained) 13 | 14 | del self.svt.head 15 | del self.svt.patch_embeds[2] 16 | del self.svt.patch_embeds[2] 17 | del self.svt.blocks[2] 18 | del self.svt.blocks[2] 19 | del self.svt.pos_block[2] 20 | del self.svt.pos_block[2] 21 | 22 | def forward(self, x, data=None, layer=2): 23 | B = x.shape[0] 24 | x_4 = None 25 | for i, (embed, drop, blocks, pos_blk) in enumerate( 26 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): 27 | 28 | patch_size = embed.patch_size 29 | if i == layer - 1: 30 | embed.patch_size = (1, 1) 31 | embed.proj.stride = embed.patch_size 32 | x_4 = torch.nn.functional.pad(x, [1, 0, 1, 0], mode='constant', value=0) 33 | x_4, size_4 = embed(x_4) 34 | 
size_4 = (size_4[0] - 1, size_4[1] - 1) 35 | x_4 = drop(x_4) 36 | for j, blk in enumerate(blocks): 37 | x_4 = blk(x_4, size_4) 38 | if j == 0: 39 | x_4 = pos_blk(x_4, size_4) 40 | 41 | if i < len(self.svt.depths) - 1: 42 | x_4 = x_4.reshape(B, *size_4, -1).permute(0, 3, 1, 2).contiguous() 43 | 44 | embed.patch_size = patch_size 45 | embed.proj.stride = patch_size 46 | x, size = embed(x) 47 | x = drop(x) 48 | for j, blk in enumerate(blocks): 49 | x = blk(x, size) 50 | if j==0: 51 | x = pos_blk(x, size) 52 | if i < len(self.svt.depths) - 1: 53 | x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() 54 | 55 | if i == layer-1: 56 | break 57 | 58 | return x, x_4 59 | 60 | def compute_params(self, layer=2): 61 | num = 0 62 | for i, (embed, drop, blocks, pos_blk) in enumerate( 63 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): 64 | 65 | for param in embed.parameters(): 66 | num += np.prod(param.size()) 67 | 68 | for param in drop.parameters(): 69 | num += np.prod(param.size()) 70 | 71 | for param in blocks.parameters(): 72 | num += np.prod(param.size()) 73 | 74 | for param in pos_blk.parameters(): 75 | num += np.prod(param.size()) 76 | 77 | if i == layer-1: 78 | break 79 | 80 | for param in getattr(self.svt, 'head', nn.Identity()).parameters(): # head is deleted in __init__ 81 | num += np.prod(param.size()) 82 | 83 | return num 84 | 85 | 86 | class twins_svt_small_context(nn.Module): 87 | def __init__(self, pretrained=True): 88 | super().__init__() 89 | self.svt = timm.create_model('twins_svt_small', pretrained=pretrained) 90 | 91 | del self.svt.head 92 | del self.svt.patch_embeds[2] 93 | del self.svt.patch_embeds[2] 94 | del self.svt.blocks[2] 95 | del self.svt.blocks[2] 96 | del self.svt.pos_block[2] 97 | del self.svt.pos_block[2] 98 | 99 | def forward(self, x, data=None, layer=2): 100 | B = x.shape[0] 101 | x_4 = None 102 | for i, (embed, drop, blocks, pos_blk) in enumerate( 103 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): 104 | 105 | 
patch_size = embed.patch_size 106 | if i == layer - 1: 107 | embed.patch_size = (1, 1) 108 | embed.proj.stride = embed.patch_size 109 | x_4 = torch.nn.functional.pad(x, [1, 0, 1, 0], mode='constant', value=0) 110 | x_4, size_4 = embed(x_4) 111 | size_4 = (size_4[0] - 1, size_4[1] - 1) 112 | x_4 = drop(x_4) 113 | for j, blk in enumerate(blocks): 114 | x_4 = blk(x_4, size_4) 115 | if j == 0: 116 | x_4 = pos_blk(x_4, size_4) 117 | 118 | if i < len(self.svt.depths) - 1: 119 | x_4 = x_4.reshape(B, *size_4, -1).permute(0, 3, 1, 2).contiguous() 120 | 121 | embed.patch_size = patch_size 122 | embed.proj.stride = patch_size 123 | x, size = embed(x) 124 | x = drop(x) 125 | for j, blk in enumerate(blocks): 126 | x = blk(x, size) 127 | if j == 0: 128 | x = pos_blk(x, size) 129 | if i < len(self.svt.depths) - 1: 130 | x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() 131 | 132 | if i == layer - 1: 133 | break 134 | 135 | return x, x_4 136 | 137 | def compute_params(self, layer=2): 138 | num = 0 139 | for i, (embed, drop, blocks, pos_blk) in enumerate( 140 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): 141 | 142 | for param in embed.parameters(): 143 | num += np.prod(param.size()) 144 | 145 | for param in drop.parameters(): 146 | num += np.prod(param.size()) 147 | 148 | for param in blocks.parameters(): 149 | num += np.prod(param.size()) 150 | 151 | for param in pos_blk.parameters(): 152 | num += np.prod(param.size()) 153 | 154 | if i == layer - 1: 155 | break 156 | 157 | for param in getattr(self.svt, 'head', nn.Identity()).parameters(): # head is deleted in __init__ 158 | num += np.prod(param.size()) 159 | 160 | return num 161 | 162 | -------------------------------------------------------------------------------- /core/flowdiffuser.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | from corr import CorrBlock 7 | from utils.utils import 
coords_grid

from fd_encoder import twins_svt_large, twins_svt_small_context
from fd_decoder import UpSampleMask8, UpSampleMask4, TransformerModule, SKUpdate, SinusoidalPositionEmbeddings, SKUpdateDFM
from fd_corr import CorrBlock_FD_Sp4

autocast = torch.cuda.amp.autocast


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d


def extract(a, t, x_shape):
    """extract the appropriate t index for a batch of indices"""
    batch_size = t.shape[0]
    out = a.gather(-1, t)
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)


def ste_round(x):
    return torch.round(x) - x.detach() + x


class FlowDiffuser(nn.Module):
    def __init__(self, args):
        super().__init__()
        print('\n ---------- model: FlowDiffuser ---------- \n')

        args.corr_levels = 4
        args.corr_radius = 4
        args.m_dim = 256
        args.c_dim = c_dim = 128
        args.iters_const6 = 6

        self.args = args
        self.args.UpdateBlock = 'SKUpdateBlock6_Deep_nopoolres_AllDecoder'
        self.args.k_conv = [1, 15]
        self.args.PCUpdater_conv = [1, 7]
        self.sp4 = True
        self.rad = 8

        self.fnet = twins_svt_large(pretrained=True)
        self.cnet = twins_svt_small_context(pretrained=True)
        self.trans = TransformerModule(args)
        self.C_inp = nn.Conv2d(in_channels=c_dim, out_channels=c_dim, kernel_size=1)
        self.C_net = nn.Conv2d(in_channels=c_dim, out_channels=c_dim, kernel_size=1)
        self.update = SKUpdate(self.args, hidden_dim=c_dim)
        self.um8 = UpSampleMask8(c_dim)
        self.um4 = UpSampleMask4(c_dim)
        self.zero = nn.Parameter(torch.zeros(12), requires_grad=False)

        self.diffusion = True
        if self.diffusion:
            self.update_dfm = SKUpdateDFM(self.args, hidden_dim=c_dim)

            timesteps = 1000
            sampling_timesteps = 4
            recurr_itrs = 6
            print(' -- denoise steps: %d \n' % sampling_timesteps)
            print(' -- recurrent iterations: %d \n' % recurr_itrs)

            self.ddim_n = sampling_timesteps
            self.recurr_itrs = recurr_itrs
            self.n_sc = 0.1
            self.scale = nn.Parameter(torch.ones(1) * 0.5, requires_grad=False)
            self.n_lambda = 0.2

            self.objective = 'pred_x0'
            betas = cosine_beta_schedule(timesteps)
            alphas = 1. - betas
            alphas_cumprod = torch.cumprod(alphas, dim=0)
            alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
            timesteps, = betas.shape
            self.num_timesteps = int(timesteps)

            self.sampling_timesteps = default(sampling_timesteps, timesteps)
            assert self.sampling_timesteps <= timesteps
            self.is_ddim_sampling = self.sampling_timesteps < timesteps
            self.ddim_sampling_eta = 1.
            self.self_condition = False

            self.register_buffer('betas', betas)
            self.register_buffer('alphas_cumprod', alphas_cumprod)
            self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
            self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
            self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
            self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
            self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
            self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

            posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
            self.register_buffer('posterior_variance', posterior_variance)
            self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))
            self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
            self.register_buffer('posterior_mean_coef2',
                                 (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def up_sample_flow8(self, flow, mask):
        B, _, H, W = flow.shape
        flow = torch.nn.functional.unfold(8 * flow, kernel_size=[3, 3], stride=[1, 1], padding=[1, 1])
        flow = flow.reshape(B, 2, 9, 1, 1, H, W)
        mask = mask.reshape(B, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)
        up_flow = torch.sum(flow * mask, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3).contiguous()
        up_flow = up_flow.reshape(B, 2, H * 8, W * 8)
        return up_flow

    def up_sample_flow4(self, flow, mask):
        B, _, H, W = flow.shape
        flow = torch.nn.functional.unfold(4 * flow, kernel_size=[3, 3], stride=[1, 1], padding=[1, 1])
        flow = flow.reshape(B, 2, 9, 1, 1, H, W)
        mask = mask.reshape(B, 1, 9, 4, 4, H, W)
        mask = torch.softmax(mask, dim=2)
        up_flow = torch.sum(flow * mask, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3).contiguous()
        up_flow = up_flow.reshape(B, 2, H * 4, W * 4)
        return up_flow

    def initialize_flow8(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H//8, W//8, device=img.device).permute(0, 2, 3, 1).contiguous()
        coords1 = coords_grid(N, H//8, W//8, device=img.device).permute(0, 2, 3, 1).contiguous()
        return coords0, coords1

    def initialize_flow4(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H//4, W//4, device=img.device).permute(0, 2, 3, 1).contiguous()
        coords1 = coords_grid(N, H//4, W//4, device=img.device).permute(0, 2, 3, 1).contiguous()
        return coords0, coords1

    def _train_dfm(self, feat_shape, flow_gt, net, inp8, coords0, coords1):
        b, c, h, w = feat_shape
        if len(flow_gt.shape) == 3:
            flow_gt = flow_gt.unsqueeze(0)
        flow_gt_sp8 = F.interpolate(flow_gt, (h, w), mode='bilinear', align_corners=True) / 8.

        x_t, noises, t = self._prepare_targets(flow_gt_sp8)
        x_t = x_t * self.norm_const
        coords1 = coords1 + x_t.float()

        flow_up_s = []
        for ii in range(self.recurr_itrs):
            t_ii = (t - t / self.recurr_itrs * ii).int()

            itr = ii
            first_step = False if itr != 0 else True

            coords1 = coords1.detach()
            corr = self.corr_fn(coords1)
            flow = coords1 - coords0
            with autocast(enabled=self.args.mixed_precision):
                dfm_params = [t_ii, self.update, ii, 0]
                net, delta_flow = self.update_dfm(net, inp8, corr, flow, itr, first_step=first_step, dfm_params=dfm_params)
                up_mask = self.um8(net)

            coords1 = coords1 + delta_flow
            flow = coords1 - coords0

            flow_up = self.up_sample_flow8(flow, up_mask)
            flow_up_s.append(flow_up)

        return flow_up_s, coords1, net

    def _prepare_targets(self, flow_gt):
        noise = torch.randn(flow_gt.shape, device=self.device)
        t = torch.randint(0, self.num_timesteps, (1,), device=self.device).long()

        x_start = flow_gt / self.norm_const
        x_start = x_start * self.scale
        x_t = self._q_sample(x_start=x_start, t=t, noise=noise)
        x_t = torch.clamp(x_t, min=-1, max=1)
        x_t = x_t * self.n_sc
        return x_t, noise, t

    def _q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    @torch.no_grad()
    def _ddim_sample(self, feat_shape, net, inp, coords0, coords1_init, clip_denoised=True):
        batch, c, h, w = feat_shape
        shape = (batch, 2, h, w)
        total_timesteps, sampling_timesteps, eta, objective = self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))
        x_in = torch.randn(shape, device=self.device)

        flow_s = []
        x_start = None
        pred_s = None
        for i_ddim, time_s in enumerate(time_pairs):
            time, time_next = time_s
            time_cond = torch.full((batch,), time, device=self.device, dtype=torch.long)
            t_next = torch.full((batch,), time_next, device=self.device, dtype=torch.long)

            x_pred, inner_flow_s, pred_s = self._model_predictions(x_in, time_cond, net, inp, coords0, coords1_init, i_ddim, pred_s, t_next)
            flow_s = flow_s + inner_flow_s

            alpha = self.alphas_cumprod[time]
            # For the final pair time_next == -1, which wraps to the last entry of
            # alphas_cumprod; the x_in produced by that last step is not used afterwards.
            alpha_next = self.alphas_cumprod[time_next]

            x_t = x_in
            x_pred = x_pred * self.scale
            x_pred = torch.clamp(x_pred, min=-1 * self.scale, max=self.scale)
            # Deterministic DDIM step: recover the implied noise from x_t and the
            # predicted x_0, then re-noise x_0 to time_next (no sigma term is added).
            eps = (1 / (1 - alpha).sqrt()) * (x_t - alpha.sqrt() * x_pred)
            x_next = alpha_next.sqrt() * x_pred + (1 - alpha_next).sqrt() * eps
            x_in = x_next

        net, up_mask, coords1 = pred_s

        return coords1, net, flow_s

    def _model_predictions(self, x, t, net, inp8, coords0, coords1, i_ddim, pred_last=None, t_next=None):
        x_flow = torch.clamp(x, min=-1, max=1)
        x_flow = x_flow * self.n_sc
        x_flow = x_flow * self.norm_const

        if pred_last:
            net, _, coords1 = pred_last
            x_flow = x_flow * self.n_lambda

        coords1 = coords1 + x_flow.float()

        flow_s = []
        for ii in range(self.recurr_itrs):
            t_ii = (t - (t - 0) / self.recurr_itrs * ii).int()

            corr = self.corr_fn(coords1)
            flow = coords1 - coords0

            with autocast(enabled=self.args.mixed_precision):
                itr = ii
                first_step = False if itr != 0 else True
                dfm_params = [t_ii, self.update, ii, 0]
                net, delta_flow = self.update_dfm(net, inp8, corr, flow, itr, first_step=first_step, dfm_params=dfm_params)
                up_mask = self.um8(net)

            coords1 = coords1 + delta_flow

            flow = coords1 - coords0
            flow_up = self.up_sample_flow8(flow, up_mask)

            flow_s.append(flow_up)

        flow = coords1 - coords0
        x_pred = flow / self.norm_const

        return x_pred, flow_s, [net, up_mask, coords1]

    def _predict_noise_from_start(self, x_t, t, x0):
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) /
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def forward(self, image1, image2, test_mode=False, iters=None, flow_gt=None, flow_init=None):

        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        with autocast(enabled=self.args.mixed_precision):
            fmap = self.fnet(torch.cat([image1, image2], dim=0))
            inp = self.cnet(image1)

        fmap, fmap4 = fmap
        inp, inp4 = inp
        fmap = fmap.float()
        fmap4 = fmap4.float()
        inp = inp.float()
        inp4 = inp4.float()

        fmap1_4, fmap2_4 = torch.chunk(fmap4, chunks=2, dim=0)
        fmap1_8, fmap2_8 = torch.chunk(fmap, chunks=2, dim=0)
        inp8 = self.C_inp(inp)
        net = self.C_net(inp)

        corr_fn = self.trans(fmap1_8, fmap2_8, inp8)

        coords0, coords1 = self.initialize_flow8(image1)
        coords0 = coords0.permute(0, 3, 1, 2).contiguous()
        coords1 = coords1.permute(0, 3, 1, 2).contiguous()

        flow_list = []
        if flow_init is not None:
            if flow_init.shape[-2:] != coords1.shape[-2:]:
                flow_init = F.interpolate(flow_init, coords1.shape[-2:], mode='bilinear', align_corners=True) * 0.5
            coords1 = coords1 + flow_init

        if self.diffusion:
            self.corr_fn = corr_fn
            self.device = fmap1_8.device
            h, w = fmap1_8.shape[-2:]
            self.norm_const = torch.as_tensor([w, h], dtype=torch.float, device=self.device).view(1, 2, 1, 1)

            if self.training:
                coords1 = coords1.detach()
                flow_up_s, coords1, net = self._train_dfm(fmap1_8.shape, flow_gt, net, inp8, coords0, coords1)
            else:
                coords1, net, flow_up_s = self._ddim_sample(fmap1_8.shape, net, inp8, coords0, coords1)

        if self.sp4:
            flow4 = torch.nn.functional.interpolate(2 * (coords1 - coords0), scale_factor=2, mode='bilinear', align_corners=True)
            coords0, coords1 = self.initialize_flow4(image1)
            coords0 = coords0.permute(0, 3, 1, 2).contiguous()
            coords1 = coords1.permute(0, 3, 1, 2).contiguous()
            coords1 = coords1 + flow4

            net = torch.nn.functional.interpolate(net, scale_factor=2, mode='bilinear', align_corners=True)
            coords1_rd = ste_round(coords1)

            corr_fn4 = CorrBlock_FD_Sp4(fmap1_4, fmap2_4, num_levels=self.args.corr_levels, radius=self.args.corr_radius, coords_init=coords1_rd, rad=self.rad)

            for itr in range(self.args.iters_const6):
                coords1 = coords1.detach()
                corr = corr_fn4(coords1 - coords1_rd + self.rad)

                flow = coords1 - coords0
                with autocast(enabled=self.args.mixed_precision):
                    net, delta_flow = self.update(net, inp4, corr, flow, itr, sp4=True)
                    up_mask = self.um4(net)

                coords1 = coords1 + delta_flow
                flow_up = self.up_sample_flow4(coords1 - coords0, up_mask)
                flow_up_s.append(flow_up)

        flow_list = flow_list + flow_up_s

        if test_mode:
            flow = coords1 - coords0
            return flow, flow_list[-1]

        return flow_list

--------------------------------------------------------------------------------
/core/module.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

import time
import os
from torch.nn.parameter import Parameter
import math
import numpy as np
import cv2


class KPAFlowDec(nn.Module):
    def __init__(self, args, chnn=128):
        super().__init__()
        self.args = args
        cor_planes = 4 * (2 * args.corr_radius + 1) ** 2
        self.C_cor = nn.Sequential(
            nn.Conv2d(cor_planes, 256, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 192, 3, padding=1),
            nn.ReLU(inplace=True))
        self.C_flo = nn.Sequential(
            nn.Conv2d(2, 128, 7, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True))
        self.C_mo = nn.Sequential(
            nn.Conv2d(192+64, 128-2, 3, padding=1),
            nn.ReLU(inplace=True))

        self.kpa = KPA(args, chnn)
        self.gru = SepConvGRU(hidden_dim=chnn, input_dim=chnn+chnn+chnn)
        self.C_flow = nn.Sequential(
            nn.Conv2d(chnn, chnn*2, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(chnn*2, 2, 3, padding=1))
        self.C_mask = nn.Sequential(
            nn.Conv2d(chnn, chnn*2, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(chnn*2, 64*9, 1, padding=0))

    def _mo_enc(self, flow, corr, itr):
        feat_cor = self.C_cor(corr)
        feat_flo = self.C_flo(flow)
        feat_cat = torch.cat([feat_cor, feat_flo], dim=1)
        feat_mo = self.C_mo(feat_cat)
        feat_mo = torch.cat([feat_mo, flow], dim=1)
        return feat_mo

    def forward(self, net, inp, corr, flow, itr, upsample=True):
        feat_mo = self._mo_enc(flow, corr, itr)
        feat_moa = self.kpa(inp, feat_mo, itr)
        inp = torch.cat([inp, feat_mo, feat_moa], dim=1)
        net = self.gru(net, inp)
        delta_flow = self.C_flow(net)

        # scale mask to balance gradients
        mask = .25 * self.C_mask(net)
        return net, mask, delta_flow


class KPA(nn.Module):
    def __init__(self, args, chnn):
        super().__init__()
        self.unfold_type = 'x311'
        if 'kitti' in args.dataset:
            self.sc = 15
        else:
            self.sc = 19

        self.unfold = nn.Unfold(kernel_size=3*self.sc, dilation=1, padding=self.sc, stride=self.sc)
        self.scale = chnn ** -0.5
        self.to_qk = nn.Conv2d(chnn, chnn * 2, 1, bias=False)
        self.to_v = nn.Conv2d(chnn, chnn, 1, bias=False)
        self.gamma = nn.Parameter(torch.zeros(1))

        h_k = (3 * self.sc - 1) / 2
        self.w_prelu = nn.Parameter(torch.zeros(1) + 1/h_k)
        self.scp = 0.02
        self.b = 1.

    def _FS(self, attn, shape):
        b, c, h, w, h_sc, w_sc = shape
        device = attn.device
        k = int(math.sqrt(attn.shape[1]))
        crd_k = torch.linspace(0, k-1, k).to(device)
        x = crd_k.view(1, 1, k, 1, 1).expand(b, 1, k, h, w)
        y = crd_k.view(1, k, 1, 1, 1).expand(b, k, 1, h, w)

        sc = torch.tensor(self.sc).to(device)
        idx_x = sc.view(1, 1, 1, 1, 1).expand(b, 1, 1, h, w)
        idx_y = sc.view(1, 1, 1, 1, 1).expand(b, 1, 1, h, w)
        crd_w = torch.linspace(0, w-1, w).to(device)
        crd_h = torch.linspace(0, h-1, h).to(device)
        idx_x = idx_x + crd_w.view(1, 1, 1, 1, w).expand(b, 1, 1, h, w) % self.sc
        idx_y = idx_y + crd_h.view(1, 1, 1, h, 1).expand(b, 1, 1, h, w) % self.sc

        half_ker = torch.tensor(self.sc * 2).to(device)
        o_x = -1 * F.prelu(abs(x - idx_x) - half_ker, self.w_prelu * self.scp) + self.b
        o_x[o_x < 0] = 0
        o_y = -1 * F.prelu(abs(y - idx_y) - half_ker, self.w_prelu * self.scp) + self.b
        o_y[o_y < 0] = 0
        ker_S = o_x * o_y
        ker_S = ker_S.view(b, k**2, h_sc, self.sc, w_sc, self.sc).permute(0, 1, 2, 4, 3, 5).contiguous().view(b, k**2, h_sc*w_sc, self.sc**2)
        return ker_S

    def forward(self, *inputs):
        feat_ci, feat_mi, itr = inputs
        b, c, h_in, w_in = feat_mi.shape

        x_pad = self.sc - w_in % self.sc
        y_pad = self.sc - h_in % self.sc
        feat_c = F.pad(feat_ci, (0, x_pad, 0, y_pad))
        feat_m = F.pad(feat_mi, (0, x_pad, 0, y_pad))
        b, c, h, w = feat_c.shape
        h_sc = h // self.sc
        w_sc = w // self.sc

        fm = torch.ones(1, 1, h_in, w_in).to(feat_m.device)
        fm = F.pad(fm, (0, x_pad, 0, y_pad))
        fm_k = self.unfold(fm).view(1, 1, -1, h_sc*w_sc)
        fm_q = fm.view(1, 1, h_sc, self.sc, w_sc, self.sc).permute(0, 1, 2, 4, 3, 5).contiguous().view(1, 1, h_sc*w_sc, self.sc**2)
        am = torch.einsum('b c k n, b c n s -> b k n s', fm_k, fm_q)
        am = (am - 1) * 99.
        am = am.repeat(b, 1, 1, 1)

        if itr == 0:
            feat_q, feat_k = self.to_qk(feat_c).chunk(2, dim=1)
            feat_k = self.unfold(feat_k).view(b, c, -1, h_sc*w_sc)
            feat_k = self.scale * feat_k
            feat_q = feat_q.view(b, c, h_sc, self.sc, w_sc, self.sc).permute(0, 1, 2, 4, 3, 5).contiguous().view(b, c, h_sc*w_sc, self.sc**2)
            attn = torch.einsum('b c k n, b c n s -> b k n s', feat_k, feat_q)
            attn = attn + am

            ker_S = self._FS(attn, [b, c, h, w, h_sc, w_sc])
            attn_kpa = ker_S.view(attn.shape) * attn
            self.attn = F.softmax(attn_kpa, dim=1)

        feat_v = self.to_v(feat_m)
        feat_v = self.unfold(feat_v).view(b, c, -1, h_sc*w_sc)
        feat_r = torch.einsum('b k n s, b c k n -> b c n s', self.attn, feat_v)
        feat_r = feat_r.view(b, c, h_sc, w_sc, self.sc, self.sc).permute(0, 1, 2, 4, 3, 5).contiguous().view(b, c, h, w)
        feat_r = feat_r[:,:,:h_in,:w_in]

        feat_o = feat_mi + feat_r * self.gamma
        return feat_o


class KPAEnc(nn.Module):
    def __init__(self, args, chnn, sc):
        super().__init__()
        self.sc = sc
        self.unfold = nn.Unfold(kernel_size=3*self.sc, dilation=1, padding=self.sc, stride=self.sc)
        self.scale = chnn ** -0.5
        self.to_qk = nn.Conv2d(chnn, chnn * 2, 1, bias=False)
        self.to_v = nn.Conv2d(chnn, chnn, 1, bias=False)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.mask_k = True

    def forward(self, inputs):
        feat_i = inputs
        b, c, h_in, w_in = feat_i.shape
        x_pad = self.sc - w_in % self.sc
        y_pad = self.sc - h_in % self.sc
        feat = F.pad(feat_i, (0, x_pad, 0, y_pad))
        b, c, h, w = feat.shape
        h_sc = h // self.sc
        w_sc = w // self.sc

        fm = torch.ones(1, 1, h_in, w_in).to(feat.device)
        fm = F.pad(fm, (0, x_pad, 0, y_pad))
        fm_k = self.unfold(fm).view(1, 1, -1, h_sc*w_sc)
        fm_q = fm.view(1, 1, h_sc, self.sc, w_sc, self.sc).permute(0, 1, 2, 4, 3, 5).contiguous().view(1, 1, h_sc*w_sc, self.sc**2)
        am = torch.einsum('b c k n, b c n s -> b k n s', fm_k, fm_q)
        am = (am - 1) * 99.
        am = am.repeat(b, 1, 1, 1)

        feat_q, feat_k = self.to_qk(feat).chunk(2, dim=1)
        feat_k = self.unfold(feat_k).view(b, c, -1, h_sc*w_sc)
        feat_k = self.scale * feat_k
        feat_q = feat_q.view(b, c, h_sc, self.sc, w_sc, self.sc).permute(0, 1, 2, 4, 3, 5).contiguous().view(b, c, h_sc*w_sc, self.sc**2)
        attn = torch.einsum('b c k n, b c n s -> b k n s', feat_k, feat_q)

        attn = attn + am
        self.attn = F.softmax(attn, dim=1)

        feat_v = self.to_v(feat)
        feat_v = self.unfold(feat_v).view(b, c, -1, h_sc*w_sc)
        feat_r = torch.einsum('b k n s, b c k n -> b c n s', self.attn, feat_v)
        feat_r = feat_r.view(b, c, h_sc, w_sc, self.sc, self.sc).permute(0, 1, 2, 4, 3, 5).contiguous().view(b, c, h, w)
        feat_r = feat_r[:,:,:h_in,:w_in]
        feat_o = feat_i + feat_r * self.gamma
        return feat_o


class SepConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))

        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))

    def forward(self, h, x):
        # horizontal
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz1(hx))
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
        h = (1-z) * h + z * q

        # vertical
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
        h = (1-z) * h + z * q

        return h

--------------------------------------------------------------------------------
/core/raft.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from update import BasicUpdateBlock, SmallUpdateBlock
from extractor import BasicEncoder, SmallEncoder
from corr import CorrBlock, AlternateCorrBlock
from utils.utils import bilinear_sampler, coords_grid, upflow8

try:
    autocast = torch.cuda.amp.autocast
except:
    # dummy autocast for PyTorch < 1.6
    class autocast:
        def __init__(self, enabled):
            pass
        def __enter__(self):
            pass
        def __exit__(self, *args):
            pass


class RAFT(nn.Module):
    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args

        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3

        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            args.corr_levels = 4
            args.corr_radius = 4

        if 'dropout' not in self.args:
            self.args.dropout = 0

        if 'alternate_corr' not in self.args:
            self.args.alternate_corr = False

        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

        else:
            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
            self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H//8, W//8, device=img.device)
        coords1 = coords_grid(N, H//8, W//8, device=img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(8 * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8*H, 8*W)


    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames """

        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            cnet = self.cnet(image1)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            coords1 = coords1.detach()
            corr = corr_fn(coords1) # index correlation volume

            flow = coords1 - coords0
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        if test_mode:
            return coords1 - coords0, flow_up

        return flow_predictions

--------------------------------------------------------------------------------
/core/update.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class FlowHead(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))

class ConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)

        z = torch.sigmoid(self.convz(hx))
        r = torch.sigmoid(self.convr(hx))
        q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))

        h = (1-z) * h + z * q
        return h

class SepConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))

        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))


    def forward(self, h, x):
        # horizontal
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz1(hx))
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
        h = (1-z) * h + z * q

        # vertical
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
        h = (1-z) * h + z * q

        return h

class SmallMotionEncoder(nn.Module):
    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        cor = F.relu(self.convc1(corr))
        flo = F.relu(self.convf1(flow))
        flo = F.relu(self.convf2(flo))
        cor_flo = torch.cat([cor, flo], dim=1)
        out = F.relu(self.conv(cor_flo))
        return torch.cat([out, flow], dim=1)

class BasicMotionEncoder(nn.Module):
    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)

    def forward(self, flow, corr):
        cor = F.relu(self.convc1(corr))
        cor = F.relu(self.convc2(cor))
        flo = F.relu(self.convf1(flow))
        flo = F.relu(self.convf2(flo))

        cor_flo = torch.cat([cor, flo], dim=1)
        out = F.relu(self.conv(cor_flo))
        return torch.cat([out, flow], dim=1)

class SmallUpdateBlock(nn.Module):
    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)
        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        return net, None, delta_flow

class BasicUpdateBlock(nn.Module):
    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64*9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = .25 * self.mask(net)
        return net, mask, delta_flow

--------------------------------------------------------------------------------
/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LA30/FlowDiffuser/9aff9c6e8c68f809e40bb0ae4273621276686168/core/utils/__init__.py
--------------------------------------------------------------------------------
/core/utils/augmentor.py:
--------------------------------------------------------------------------------
import numpy as np
import random
import math
from PIL import Image

import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

import torch
from torchvision.transforms import ColorJitter
import torch.nn.functional as F


class FlowAugmentor:
    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):

        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        """ Photometric augmentation """

        # asymmetric
        if np.random.rand() < self.asymmetric_color_aug_prob:
            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)

        # symmetric
        else:
            image_stack = np.concatenate([img1, img2], axis=0)
            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
            img1, img2 = np.split(image_stack, 2, axis=0)

        return img1, img2

    def eraser_transform(self, img1, img2, bounds=[50, 100]):
        """ Occlusion augmentation """

        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(bounds[0], bounds[1])
                dy = np.random.randint(bounds[0], bounds[1])
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def spatial_transform(self, img1, img2, flow):
        # randomly sample scale
        ht, wd = img1.shape[:2]
        min_scale = np.maximum(
            (self.crop_size[0] + 8) / float(ht),
            (self.crop_size[1] + 8) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = scale
        scale_y = scale
        if np.random.rand() < self.stretch_prob:
            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)

        scale_x = np.clip(scale_x, min_scale, None)
        scale_y = np.clip(scale_y, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow = flow * [scale_x, scale_y]

        if self.do_flip:
            if np.random.rand() < self.h_flip_prob: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                flow = flow[:, ::-1] * [-1.0, 1.0]

            if np.random.rand() < self.v_flip_prob: # v-flip
                img1 = img1[::-1, :]
                img2 = img2[::-1, :]
                flow = flow[::-1, :] * [1.0, -1.0]

        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
        x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]

        return img1, img2, flow

    def __call__(self, img1, img2, flow):
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow = self.spatial_transform(img1, img2, flow)

        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)

        return img1, img2, flow

class SparseFlowAugmentor:
    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        image_stack = np.concatenate([img1, img2], axis=0)
        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
        img1, img2 = np.split(image_stack, 2, axis=0)
        return img1, img2

    def eraser_transform(self, img1, img2):
        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(50, 100)
                dy = np.random.randint(50, 100)
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
        ht, wd = flow.shape[:2]
        coords = np.meshgrid(np.arange(wd), np.arange(ht))
        coords = np.stack(coords, axis=-1)

        coords = coords.reshape(-1, 2).astype(np.float32)
        flow = flow.reshape(-1, 2).astype(np.float32)
        valid = valid.reshape(-1).astype(np.float32)

        coords0 = coords[valid>=1]
        flow0 = flow[valid>=1]

        ht1 = int(round(ht * fy))
        wd1 = int(round(wd * fx))

        coords1 = coords0 * [fx, fy]
        flow1 = flow0 * [fx, fy]

        xx = np.round(coords1[:,0]).astype(np.int32)
        yy = np.round(coords1[:,1]).astype(np.int32)

        v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
        xx = xx[v]
        yy = yy[v]
        flow1 = flow1[v]

        flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
        valid_img = np.zeros([ht1, wd1], dtype=np.int32)

        flow_img[yy, xx] = flow1
        valid_img[yy, xx] = 1

        return flow_img, valid_img

    def spatial_transform(self, img1, img2, flow, valid):
        # randomly sample scale

        ht, wd = img1.shape[:2]
        min_scale = np.maximum(
            (self.crop_size[0] + 1) / float(ht),
            (self.crop_size[1] + 1) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = np.clip(scale, min_scale, None)
        scale_y = np.clip(scale, min_scale, None)
206 | 207 | if np.random.rand() < self.spatial_aug_prob: 208 | # rescale the images 209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 212 | 213 | if self.do_flip: 214 | if np.random.rand() < 0.5: # h-flip 215 | img1 = img1[:, ::-1] 216 | img2 = img2[:, ::-1] 217 | flow = flow[:, ::-1] * [-1.0, 1.0] 218 | valid = valid[:, ::-1] 219 | 220 | margin_y = 20 221 | margin_x = 50 222 | 223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 225 | 226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 228 | 229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 233 | return img1, img2, flow, valid 234 | 235 | 236 | def __call__(self, img1, img2, flow, valid): 237 | img1, img2 = self.color_transform(img1, img2) 238 | img1, img2 = self.eraser_transform(img1, img2) 239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 240 | 241 | img1 = np.ascontiguousarray(img1) 242 | img2 = np.ascontiguousarray(img2) 243 | flow = np.ascontiguousarray(flow) 244 | valid = np.ascontiguousarray(valid) 245 | 246 | return img1, img2, flow, valid 247 | -------------------------------------------------------------------------------- /core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # 
Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] =
255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two-dimensional flow field of shape [H,W,2]. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, -clip_flow, clip_flow) # clip symmetrically so negative (left/up) motion is preserved 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect.
Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /core/utils/utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = 
torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd, device): 75 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python evaluate.py --model=weights/FlowDiffuser-things.pth --dataset=sintel 3 | python evaluate.py --model=weights/FlowDiffuser-things.pth --dataset=kitti -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | from PIL import Image 5 | import argparse 6 | import os 7 | import time 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | import matplotlib.pyplot as plt 12 | 13 | import datasets 14 | from utils import flow_viz 15 | from utils import frame_utils 16 | 17 | from flowdiffuser import FlowDiffuser 18 | 19 | from utils.utils import InputPadder, forward_interpolate 20 | 21 | 22 | @torch.no_grad() 23 | def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'): 24 | """ Create submission for the Sintel leaderboard """ 25 | model.eval() 26 | for dstype in ['clean', 'final']: 27 | test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype) 28 | 29 | flow_prev, 
sequence_prev = None, None 30 | for test_id in range(len(test_dataset)): 31 | image1, image2, (sequence, frame) = test_dataset[test_id] 32 | # if sequence != sequence_prev: 33 | # flow_prev = None 34 | 35 | if (sequence != sequence_prev) or (dstype == 'final' and sequence in ['market_4', ]) or dstype == 'clean': 36 | flow_prev = None 37 | 38 | padder = InputPadder(image1.shape) 39 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 40 | 41 | flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True) 42 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 43 | 44 | if warm_start: 45 | flow_prev = forward_interpolate(flow_low[0])[None].cuda() 46 | 47 | output_dir = os.path.join(output_path, dstype, sequence) 48 | output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1)) 49 | 50 | if not os.path.exists(output_dir): 51 | os.makedirs(output_dir) 52 | 53 | frame_utils.writeFlow(output_file, flow) 54 | sequence_prev = sequence 55 | 56 | 57 | @torch.no_grad() 58 | def create_kitti_submission(model, iters=24, output_path='kitti_submission'): 59 | """ Create submission for the KITTI leaderboard """ 60 | model.eval() 61 | test_dataset = datasets.KITTI(split='testing', aug_params=None) 62 | 63 | if not os.path.exists(output_path): 64 | os.makedirs(output_path) 65 | 66 | for test_id in range(len(test_dataset)): 67 | image1, image2, (frame_id, ) = test_dataset[test_id] 68 | padder = InputPadder(image1.shape, mode='kitti') 69 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 70 | 71 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True) 72 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 73 | 74 | output_filename = os.path.join(output_path, frame_id) 75 | frame_utils.writeFlowKITTI(output_filename, flow) 76 | 77 | 78 | @torch.no_grad() 79 | def validate_chairs(model, iters=24): 80 | """ Perform evaluation on the FlyingChairs (test) split """ 81 |
model.eval() 82 | epe_list = [] 83 | 84 | val_dataset = datasets.FlyingChairs(split='validation') 85 | for val_id in range(len(val_dataset)): 86 | image1, image2, flow_gt, _ = val_dataset[val_id] 87 | image1 = image1[None].cuda() 88 | image2 = image2[None].cuda() 89 | 90 | _, flow_pr = model(image1, image2, iters=iters, test_mode=True) 91 | epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt() 92 | epe_list.append(epe.view(-1).numpy()) 93 | 94 | epe = np.mean(np.concatenate(epe_list)) 95 | print("Validation Chairs EPE: %f" % epe) 96 | return {'chairs': epe} 97 | 98 | 99 | @torch.no_grad() 100 | def validate_sintel(model, iters=32): 101 | """ Perform validation using the Sintel (train) split """ 102 | model.eval() 103 | results = {} 104 | for dstype in ['clean', 'final']: 105 | val_dataset = datasets.MpiSintel(split='training', dstype=dstype) 106 | epe_list = [] 107 | 108 | for val_id in range(len(val_dataset)): 109 | image1, image2, flow_gt, _ = val_dataset[val_id] 110 | image1 = image1[None].cuda() 111 | image2 = image2[None].cuda() 112 | 113 | padder = InputPadder(image1.shape) 114 | image1, image2 = padder.pad(image1, image2) 115 | 116 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True) 117 | flow = padder.unpad(flow_pr[0]).cpu() 118 | 119 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 120 | epe_list.append(epe.view(-1).numpy()) 121 | 122 | epe_all = np.concatenate(epe_list) 123 | epe = np.mean(epe_all) 124 | px1 = np.mean(epe_all<1) 125 | px3 = np.mean(epe_all<3) 126 | px5 = np.mean(epe_all<5) 127 | 128 | print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5)) 129 | results[dstype] = np.mean(epe_list) 130 | 131 | return results 132 | 133 | 134 | @torch.no_grad() 135 | def validate_kitti(model, iters=24): 136 | """ Perform validation using the KITTI-2015 (train) split """ 137 | model.eval() 138 | val_dataset = datasets.KITTI(split='training') 139 | 140 | out_list, epe_list = [], [] 141 | for
val_id in range(len(val_dataset)): 142 | image1, image2, flow_gt, valid_gt = val_dataset[val_id] 143 | image1 = image1[None].cuda() 144 | image2 = image2[None].cuda() 145 | 146 | padder = InputPadder(image1.shape, mode='kitti') 147 | image1, image2 = padder.pad(image1, image2) 148 | 149 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True) 150 | flow = padder.unpad(flow_pr[0]).cpu() 151 | 152 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 153 | mag = torch.sum(flow_gt**2, dim=0).sqrt() 154 | 155 | epe = epe.view(-1) 156 | mag = mag.view(-1) 157 | val = valid_gt.view(-1) >= 0.5 158 | 159 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float() 160 | epe_list.append(epe[val].mean().item()) 161 | out_list.append(out[val].cpu().numpy()) 162 | 163 | epe_list = np.array(epe_list) 164 | out_list = np.concatenate(out_list) 165 | 166 | epe = np.mean(epe_list) 167 | f1 = 100 * np.mean(out_list) 168 | 169 | print("Validation KITTI: %f, %f" % (epe, f1)) 170 | return {'kitti-epe': epe, 'kitti-f1': f1} 171 | 172 | 173 | if __name__ == '__main__': 174 | parser = argparse.ArgumentParser() 175 | parser.add_argument('--model', help="restore checkpoint") 176 | parser.add_argument('--dataset', help="dataset for evaluation") 177 | parser.add_argument('--small', action='store_true', help='use small model') 178 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 179 | parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation') 180 | args = parser.parse_args() 181 | 182 | model = torch.nn.DataParallel(FlowDiffuser(args)) 183 | model.load_state_dict(torch.load(args.model)) 184 | 185 | model.cuda() 186 | model.eval() 187 | 188 | # create_sintel_submission(model.module, warm_start=True) 189 | # create_kitti_submission(model.module) 190 | 191 | with torch.no_grad(): 192 | if args.dataset == 'chairs': 193 | validate_chairs(model.module) 194 | 195 | elif args.dataset == 'sintel': 196 |
validate_sintel(model.module) 197 | 198 | elif args.dataset == 'kitti': 199 | validate_kitti(model.module) 200 | 201 | 202 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import sys 3 | sys.path.append('core') 4 | 5 | import argparse 6 | import os 7 | import cv2 8 | import time 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | 17 | from torch.utils.data import DataLoader 18 | import evaluate 19 | import datasets 20 | from flowdiffuser import FlowDiffuser 21 | 22 | from torch.utils.tensorboard import SummaryWriter 23 | 24 | from torch.cuda.amp import GradScaler 25 | 26 | 27 | # exclude extremely large displacements 28 | MAX_FLOW = 400 29 | SUM_FREQ = 100 30 | VAL_FREQ = 5000 31 | 32 | 33 | def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW): 34 | """ Loss function defined over sequence of flow predictions """ 35 | 36 | n_predictions = len(flow_preds) 37 | flow_loss = 0.0 38 | 39 | # exclude invalid pixels and extremely large displacements 40 | mag = torch.sum(flow_gt**2, dim=1).sqrt() 41 | valid = (valid >= 0.5) & (mag < max_flow) 42 | 43 | for i in range(n_predictions): 44 | i_weight = gamma**(n_predictions - i - 1) 45 | i_loss = (flow_preds[i] - flow_gt).abs() 46 | flow_loss += i_weight * (valid[:, None] * i_loss).mean() 47 | 48 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt() 49 | epe = epe.view(-1)[valid.view(-1)] 50 | 51 | metrics = { 52 | 'epe': epe.mean().item(), 53 | '1px': (epe < 1).float().mean().item(), 54 | '3px': (epe < 3).float().mean().item(), 55 | '5px': (epe < 5).float().mean().item(), 56 | } 57 | 58 | return flow_loss, metrics 59 | 60 | 61 | def count_parameters(model): 62 | return sum(p.numel() for p
in model.parameters() if p.requires_grad) 63 | 64 | 65 | def fetch_optimizer(args, model): 66 | """ Create the optimizer and learning rate scheduler """ 67 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon) 68 | 69 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100, 70 | pct_start=0.05, cycle_momentum=False, anneal_strategy='linear') 71 | 72 | return optimizer, scheduler 73 | 74 | 75 | class Logger: 76 | def __init__(self, model, scheduler): 77 | self.model = model 78 | self.scheduler = scheduler 79 | self.total_steps = 0 80 | self.running_loss = {} 81 | self.writer = None 82 | 83 | def _print_training_status(self): 84 | metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())] 85 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0]) 86 | metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data) 87 | 88 | # print the training status 89 | print(training_str + metrics_str) 90 | 91 | if self.writer is None: 92 | self.writer = SummaryWriter() 93 | 94 | for k in self.running_loss: 95 | self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps) 96 | self.running_loss[k] = 0.0 97 | 98 | def push(self, metrics): 99 | self.total_steps += 1 100 | 101 | for key in metrics: 102 | if key not in self.running_loss: 103 | self.running_loss[key] = 0.0 104 | 105 | self.running_loss[key] += metrics[key] 106 | 107 | if self.total_steps % SUM_FREQ == SUM_FREQ-1: 108 | self._print_training_status() 109 | self.running_loss = {} 110 | 111 | def write_dict(self, results): 112 | if self.writer is None: 113 | self.writer = SummaryWriter() 114 | 115 | for key in results: 116 | self.writer.add_scalar(key, results[key], self.total_steps) 117 | 118 | def close(self): 119 | self.writer.close() 120 | 121 | 122 | def train(args): 123 | 124 | model = nn.DataParallel(FlowDiffuser(args), device_ids=args.gpus) 125 | 
print("Parameter Count: %d" % count_parameters(model)) 126 | 127 | if args.restore_ckpt is not None: 128 | model.load_state_dict(torch.load(args.restore_ckpt), strict=False) 129 | 130 | model.cuda() 131 | model.train() 132 | 133 | if args.stage != 'chairs': 134 | model.module.freeze_bn() 135 | 136 | train_loader = datasets.fetch_dataloader(args) 137 | optimizer, scheduler = fetch_optimizer(args, model) 138 | 139 | total_steps = 0 140 | scaler = GradScaler(enabled=args.mixed_precision) 141 | logger = Logger(model, scheduler) 142 | 143 | VAL_FREQ = 5000 144 | add_noise = True 145 | 146 | should_keep_training = True 147 | while should_keep_training: 148 | 149 | for i_batch, data_blob in enumerate(train_loader): 150 | optimizer.zero_grad() 151 | image1, image2, flow, valid = [x.cuda() for x in data_blob] 152 | 153 | if args.add_noise: 154 | stdv = np.random.uniform(0.0, 5.0) 155 | image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0) 156 | image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0) 157 | 158 | flow_predictions = model(image1, image2, iters=args.iters, flow_gt=flow) 159 | 160 | loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma) 161 | scaler.scale(loss).backward() 162 | scaler.unscale_(optimizer) 163 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 164 | 165 | scaler.step(optimizer) 166 | scheduler.step() 167 | scaler.update() 168 | 169 | logger.push(metrics) 170 | 171 | if total_steps % VAL_FREQ == VAL_FREQ - 1: 172 | PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name) 173 | torch.save(model.state_dict(), PATH) 174 | 175 | results = {} 176 | for val_dataset in args.validation: 177 | if val_dataset == 'chairs': 178 | results.update(evaluate.validate_chairs(model.module)) 179 | elif val_dataset == 'sintel': 180 | results.update(evaluate.validate_sintel(model.module)) 181 | elif val_dataset == 'kitti': 182 | results.update(evaluate.validate_kitti(model.module)) 183 | 
184 | logger.write_dict(results) 185 | 186 | model.train() 187 | if args.stage != 'chairs': 188 | model.module.freeze_bn() 189 | 190 | total_steps += 1 191 | 192 | if total_steps > args.num_steps: 193 | should_keep_training = False 194 | break 195 | 196 | logger.close() 197 | PATH = 'checkpoints/%s.pth' % args.name 198 | torch.save(model.state_dict(), PATH) 199 | 200 | return PATH 201 | 202 | 203 | if __name__ == '__main__': 204 | parser = argparse.ArgumentParser() 205 | parser.add_argument('--name', default='flowdiffuser', help="name your experiment") 206 | parser.add_argument('--stage', help="determines which dataset to use for training") 207 | parser.add_argument('--restore_ckpt', help="restore checkpoint") 208 | parser.add_argument('--small', action='store_true', help='use small model') 209 | parser.add_argument('--validation', type=str, nargs='+') 210 | 211 | parser.add_argument('--lr', type=float, default=0.00002) 212 | parser.add_argument('--num_steps', type=int, default=100000) 213 | parser.add_argument('--batch_size', type=int, default=6) 214 | parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512]) 215 | parser.add_argument('--gpus', type=int, nargs='+', default=[0,1]) 216 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 217 | 218 | parser.add_argument('--iters', type=int, default=12) 219 | parser.add_argument('--wdecay', type=float, default=.00005) 220 | parser.add_argument('--epsilon', type=float, default=1e-8) 221 | parser.add_argument('--clip', type=float, default=1.0) 222 | parser.add_argument('--dropout', type=float, default=0.0) 223 | parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting') 224 | parser.add_argument('--add_noise', action='store_true') 225 | args = parser.parse_args() 226 | 227 | torch.manual_seed(1234) 228 | np.random.seed(1234) 229 | 230 | if not os.path.isdir('checkpoints'): 231 | os.mkdir('checkpoints') 232 | 233 | train(args) 
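`sequence_loss` in train.py weights the i-th of N refinement outputs by gamma^(N-i-1), so the final prediction gets weight 1 and earlier ones decay geometrically. A pixel counts only when `valid >= 0.5` and its ground-truth magnitude is below `MAX_FLOW` (400). Because `.mean()` runs over every element of the `[B, 2, H, W]` tensor, masked pixels add zeros to the sum but still count in the denominator. The following is a minimal pure-Python sketch of that computation for a single image, assuming flows are given as lists of `(u, v)` tuples (one per pixel); `sequence_weights` and `sequence_loss_sketch` are hypothetical names, not functions in the repository.

```python
import math


def sequence_weights(n_predictions, gamma=0.8):
    """Per-step weights: the last prediction gets 1.0, earlier ones decay by gamma."""
    return [gamma ** (n_predictions - i - 1) for i in range(n_predictions)]


def sequence_loss_sketch(flow_preds, flow_gt, valid, gamma=0.8, max_flow=400.0):
    """Exponentially weighted L1 loss over a sequence of flow predictions.

    flow_preds: list of predictions, each a list of (u, v) tuples (one per pixel).
    flow_gt:    list of (u, v) ground-truth tuples.
    valid:      list of per-pixel validity values (>= 0.5 means valid).
    """
    # Keep a pixel only if it is valid and its ground-truth displacement is not extreme.
    mask = [vf >= 0.5 and math.hypot(gu, gv) < max_flow
            for (gu, gv), vf in zip(flow_gt, valid)]
    # Mirror torch's .mean() over [B, 2, H, W]: divide by every element,
    # including masked ones (they contribute zero to the sum).
    num_elements = 2 * len(flow_gt)

    total = 0.0
    for weight, pred in zip(sequence_weights(len(flow_preds), gamma), flow_preds):
        l1 = sum(abs(pu - gu) + abs(pv - gv)
                 for (pu, pv), (gu, gv), keep in zip(pred, flow_gt, mask) if keep)
        total += weight * l1 / num_elements
    return total


if __name__ == "__main__":
    gt = [(0.0, 0.0), (500.0, 0.0)]       # second pixel exceeds max_flow and is ignored
    preds = [[(1.0, 0.0), (0.0, 0.0)],    # early, worse estimate
             [(0.0, 0.0), (0.0, 0.0)]]    # final estimate, exact on the kept pixel
    print(sequence_weights(2))            # [0.8, 1.0]
    print(sequence_loss_sketch(preds, gt, valid=[1.0, 1.0]))  # 0.2
```

Dividing by all elements rather than by the number of valid pixels means images with many invalid pixels (for example sparse KITTI ground truth) produce a smaller loss magnitude; the sketch keeps that behavior so it matches the training code.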
-------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p checkpoints 3 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python -u train.py --name fd-chairs --stage chairs --validation chairs --gpus 0 1 2 3 4 5 --num_steps 100000 --batch_size 12 --lr 0.00045 --image_size 384 512 --wdecay 0.0001 4 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python -u train.py --name fd-things --stage things --validation sintel --restore_ckpt checkpoints/fd-chairs.pth --gpus 0 1 2 3 4 5 --num_steps 200000 --batch_size 6 --lr 0.000175 --image_size 432 960 --wdecay 0.0001 5 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python -u train.py --name fd-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/fd-things.pth --gpus 0 1 2 3 4 5 --num_steps 180000 --batch_size 6 --lr 0.000175 --image_size 432 960 --wdecay 0.00001 --gamma=0.85 6 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python -u train.py --name fd-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/fd-sintel.pth --gpus 0 1 2 3 4 5 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --------------------------------------------------------------------------------
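Each `train.sh` stage passes `--lr` and `--num_steps` to `fetch_optimizer` in train.py, which builds `OneCycleLR(optimizer, args.lr, args.num_steps+100, pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')`. With the remaining arguments at their PyTorch defaults (`div_factor=25`, `final_div_factor=1e4`, `three_phase=False`), the learning rate should ramp linearly from `lr/25` to `lr` over roughly the first 5% of steps, then decay linearly toward `lr/25/1e4`. The sketch below is our reading of that two-phase schedule in pure Python, so per-step values for a stage can be inspected without building a model. `one_cycle_linear_lr` is a hypothetical helper, and the real scheduler remains authoritative. The extra 100 steps presumably leave headroom: the training loop calls `scheduler.step()` `num_steps + 1` times, and `OneCycleLR` raises an error once it is stepped past its total.

```python
def one_cycle_linear_lr(step, max_lr, total_steps, pct_start=0.05,
                        div_factor=25.0, final_div_factor=1e4):
    """Learning rate at `step` for a two-phase, linearly annealed one-cycle schedule.

    Sketch of our understanding of torch.optim.lr_scheduler.OneCycleLR with
    anneal_strategy='linear' and three_phase=False; not the library's own code.
    """
    if step > total_steps:
        raise ValueError("stepped past total_steps")
    initial_lr = max_lr / div_factor        # starting learning rate
    min_lr = initial_lr / final_div_factor  # learning rate at the final step
    warmup_end = pct_start * total_steps - 1  # last step of the warm-up phase
    last_step = total_steps - 1

    if step <= warmup_end:
        # Phase 1: linear ramp from initial_lr up to max_lr.
        pct = step / warmup_end
        return initial_lr + (max_lr - initial_lr) * pct
    # Phase 2: linear decay from max_lr down to min_lr.
    pct = (step - warmup_end) / (last_step - warmup_end)
    return max_lr + (min_lr - max_lr) * pct


if __name__ == "__main__":
    # Chairs stage from train.sh: --lr 0.00045 --num_steps 100000 (total = num_steps + 100).
    total = 100000 + 100
    for s in (0, 2500, 5004, 50000, 100099):
        print(s, one_cycle_linear_lr(s, 0.00045, total))
```

If PyTorch is installed, these values can be cross-checked by attaching a real `OneCycleLR` to a dummy optimizer and reading `get_last_lr()` after each `step()`.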