├── LICENSE ├── README.md ├── assets ├── flow1d.png └── teaser.png ├── data ├── __init__.py ├── chairs_split.txt ├── datasets.py └── transforms.py ├── demo └── dogs-jump │ ├── 00033.jpg │ ├── 00034.jpg │ ├── 00035.jpg │ └── 00036.jpg ├── environment.yml ├── evaluate.py ├── flow1d ├── __init__.py ├── attention.py ├── correlation.py ├── extractor.py ├── flow1d.py ├── position.py └── update.py ├── loss.py ├── main.py ├── scripts ├── demo.sh ├── evaluate.sh └── train.sh └── utils ├── flow_viz.py ├── frame_utils.py ├── logger.py ├── misc.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Haofei Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Flow1D 2 | 3 | Official PyTorch implementation of the paper: 4 | 5 | [**High-Resolution Optical Flow from 1D Attention and Correlation**](https://arxiv.org/abs/2104.13918), **ICCV 2021, Oral** 6 | 7 | Authors: [Haofei Xu](https://haofeixu.github.io/), [Jiaolong Yang](https://jlyang.org/), [Jianfei Cai](https://jianfei-cai.github.io/), [Juyong Zhang](http://staff.ustc.edu.cn/~juyong/), [Xin Tong](https://scholar.google.com/citations?user=P91a-UQAAAAJ&hl=en&oi=ao) 8 | 9 | **11/15/2022 Update: Check out our new work: [Unifying Flow, Stereo and Depth Estimation](https://haofeixu.github.io/unimatch/) and code: [unimatch](https://github.com/autonomousvision/unimatch) for estimating optical flow with our new GMFlow model. [9 pretrained GMFlow models](https://github.com/autonomousvision/unimatch/blob/master/MODEL_ZOO.md) with different speed-accuracy trade-offs are also released. Check out our [Colab](https://colab.research.google.com/drive/1r5m-xVy3Kw60U-m5VB-aQ98oqqg_6cab?usp=sharing) and [HuggingFace](https://huggingface.co/spaces/haofeixu/unimatch) demo to play with GMFlow in your browser!** 10 | 11 | We enabled **4K resolution** optical flow estimation by factorizing 2D optical flow with 1D attention and 1D correlation. 13 | 14 | 15 |
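The memory saving from this factorization can be estimated with a quick back-of-the-envelope calculation (a rough sketch for intuition only: it assumes float32 correlation values, features at 1/8 of the input resolution, and ignores RAFT's multi-scale correlation pyramid; the helper below is ours, not part of the repo):

```python
# Rough correlation-volume memory for a 4K frame (2160 x 3840),
# with features at 1/8 resolution -> 270 x 480. Illustrative numbers only.
def corr_volume_bytes(h, w, bytes_per_val=4):
    full_2d = (h * w) ** 2 * bytes_per_val           # all-pairs 2D correlation
    factorized_1d = h * w * (h + w) * bytes_per_val  # 1D x- and y-correlations
    return full_2d, factorized_1d

full, fact = corr_volume_bytes(270, 480)
print('2D: %.1f GB, factorized 1D: %.0f MB' % (full / 1e9, fact / 1e6))
# -> roughly 67 GB vs. 389 MB, which is why the 1D formulation scales to 4K
```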

16 | 17 | 18 | 19 | 20 | The full framework: 21 | 22 | 23 | 24 | 25 |

26 | 27 | 28 | 29 | 30 | 31 | 32 | ## Installation 33 | 34 | Our code is based on pytorch 1.7.1, CUDA 10.2 and python 3.7. Higher version pytorch should also work well. 35 | 36 | We recommend using [conda](https://www.anaconda.com/distribution/) for installation: 37 | 38 | ``` 39 | conda env create -f environment.yml 40 | conda activate flow1d 41 | ``` 42 | 43 | ## Demos 44 | 45 | All pretrained models can be downloaded from [google drive](https://drive.google.com/file/d/1IzcmvxpY90DuXYkGkwitxslO1Psq52OI/view?usp=sharing). 46 | 47 | 48 | 49 | You can run a trained model on a sequence of images and visualize the results (as shown in [scripts/demo.sh](scripts/demo.sh)): 50 | 51 | ``` 52 | CUDA_VISIBLE_DEVICES=0 python main.py \ 53 | --resume pretrained/flow1d_highres-e0b98d7e.pth \ 54 | --val_iters 24 \ 55 | --inference_dir demo/dogs-jump \ 56 | --output_path output/flow1d-dogs-jump 57 | ``` 58 | 59 | 60 | 61 | ## Datasets 62 | 63 | The datasets used to train and evaluate Flow1D are as follows: 64 | 65 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) 66 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 67 | * [Sintel](http://sintel.is.tue.mpg.de/) 68 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) 69 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) 70 | 71 | By default the dataloader [datasets.py](data/datasets.py) assumes the datasets are located in folder `datasets` and are organized as follows: 72 | 73 | ``` 74 | datasets 75 | ├── FlyingChairs_release 76 | │ └── data 77 | ├── FlyingThings3D 78 | │ ├── frames_cleanpass 79 | │ ├── frames_finalpass 80 | │ └── optical_flow 81 | ├── HD1K 82 | │ ├── hd1k_challenge 83 | │ ├── hd1k_flow_gt 84 | │ ├── hd1k_flow_uncertainty 85 | │ └── hd1k_input 86 | ├── KITTI 87 | │ ├── testing 88 | │ └── training 89 | ├── Sintel 90 | │ ├── test 91 | │ └── training 92 | ``` 93 | 94 | It is recommended to symlink your dataset root to `datasets`: 95 | 96 | ```shell 97 | ln -s $YOUR_DATASET_ROOT datasets 98 | ``` 99 | 100 | Otherwise, you may need to change the corresponding paths in [datasets.py](data/datasets.py). 101 | 102 | 103 | 104 | ## Evaluation 105 | 106 | You can evaluate a trained Flow1D model by running: 107 | 108 | ``` 109 | CUDA_VISIBLE_DEVICES=0 python main.py --eval --val_dataset kitti --resume pretrained/flow1d_things-fd4bee1f.pth --val_iters 24 110 | ``` 111 | 112 | More evaluation scripts can be found in [scripts/evaluate.sh](scripts/evaluate.sh). 113 | 114 | 115 | 116 | ## Training 117 | 118 | All training scripts on FlyingChairs, FlyingThings3D, Sintel and KITTI datasets can be found in [scripts/train.sh](scripts/train.sh). 119 | 120 | Note that our Flow1D model can be trained on a single 32GB V100 GPU. You may need to tune the number of GPUs used for training according to your hardware. 121 | 122 | 123 | 124 | We support using tensorboard to monitor and visualize the training process. You can first start a tensorboard session with 125 | 126 | ```shell 127 | tensorboard --logdir checkpoints 128 | ``` 129 | 130 | and then access [http://localhost:6006](http://localhost:6006) in your browser. 
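Before launching training, it can be handy to check that your `datasets` folder matches the layout expected by [datasets.py](data/datasets.py) (see the Datasets section above). Below is a minimal sketch, not part of this repository; trim `EXPECTED` if you only train on a subset of the datasets:

```python
from pathlib import Path

# subfolders that data/datasets.py expects under each dataset root
EXPECTED = {
    'FlyingChairs_release': ['data'],
    'FlyingThings3D': ['frames_cleanpass', 'frames_finalpass', 'optical_flow'],
    'HD1K': ['hd1k_flow_gt', 'hd1k_input'],
    'KITTI': ['testing', 'training'],
    'Sintel': ['test', 'training'],
}

def check_datasets(root='datasets'):
    for name, subdirs in EXPECTED.items():
        for sub in subdirs:
            path = Path(root) / name / sub
            print('%7s  %s' % ('ok' if path.is_dir() else 'MISSING', path))

if __name__ == '__main__':
    check_datasets()
```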
131 | 132 | 133 | 134 | ## Citation 135 | 136 | If you find our work useful in your research, please consider citing our paper: 137 | 138 | ``` 139 | @inproceedings{xu2021high, 140 | title={High-Resolution Optical Flow from 1D Attention and Correlation}, 141 | author={Xu, Haofei and Yang, Jiaolong and Cai, Jianfei and Zhang, Juyong and Tong, Xin}, 142 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 143 | pages={10498--10507}, 144 | year={2021} 145 | } 146 | ``` 147 | 148 | 149 | 150 | ## Acknowledgements 151 | 152 | This project is heavily based on [RAFT](https://github.com/princeton-vl/RAFT). We thank the original authors for their excellent work. 153 | 154 | 155 | 156 | -------------------------------------------------------------------------------- /assets/flow1d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haofeixu/flow1d/ece861d2136e2eb2e99a9db71794d82c5782dbcb/assets/flow1d.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haofeixu/flow1d/ece861d2136e2eb2e99a9db71794d82c5782dbcb/assets/teaser.png -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import build_dataset 2 | from .datasets import (FlyingChairs, 3 | FlyingThings3D, 4 | MpiSintel, 5 | KITTI, 6 | HD1K, 7 | ) 8 | -------------------------------------------------------------------------------- /data/datasets.py: -------------------------------------------------------------------------------- 1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | 7 | import os 8 | import random 9 | from glob import glob 10 | import os.path as osp 11 | 12 | from utils import frame_utils 13 | from data.transforms import FlowAugmentor, SparseFlowAugmentor 14 | 15 | 16 | class FlowDataset(data.Dataset): 17 | def __init__(self, aug_params=None, sparse=False, 18 | ): 19 | self.augmentor = None 20 | self.sparse = sparse 21 | 22 | if aug_params is not None: 23 | if sparse: 24 | self.augmentor = SparseFlowAugmentor(**aug_params) 25 | else: 26 | self.augmentor = FlowAugmentor(**aug_params) 27 | 28 | self.is_test = False 29 | self.init_seed = False 30 | self.flow_list = [] 31 | self.image_list = [] 32 | self.extra_info = [] 33 | 34 | def __getitem__(self, index): 35 | 36 | if self.is_test: 37 | img1 = frame_utils.read_gen(self.image_list[index][0]) 38 | img2 = frame_utils.read_gen(self.image_list[index][1]) 39 | img1 = np.array(img1).astype(np.uint8)[..., :3] 40 | img2 = np.array(img2).astype(np.uint8)[..., :3] 41 | 42 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 43 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 44 | 45 | return img1, img2, self.extra_info[index] 46 | 47 | if not self.init_seed: 48 | worker_info = torch.utils.data.get_worker_info() 49 | if worker_info is not None: 50 | torch.manual_seed(worker_info.id) 51 | np.random.seed(worker_info.id) 52 | random.seed(worker_info.id) 53 | self.init_seed = True 54 | 55 | index = index % len(self.image_list) 56 | valid = None 57 | if self.sparse: 58 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) 59 | else: 60 | flow = 
frame_utils.read_gen(self.flow_list[index]) 61 | 62 | img1 = frame_utils.read_gen(self.image_list[index][0]) 63 | img2 = frame_utils.read_gen(self.image_list[index][1]) 64 | 65 | flow = np.array(flow).astype(np.float32) 66 | img1 = np.array(img1).astype(np.uint8) 67 | img2 = np.array(img2).astype(np.uint8) 68 | 69 | # grayscale images 70 | if len(img1.shape) == 2: 71 | img1 = np.tile(img1[..., None], (1, 1, 3)) 72 | img2 = np.tile(img2[..., None], (1, 1, 3)) 73 | else: 74 | img1 = img1[..., :3] 75 | img2 = img2[..., :3] 76 | 77 | if self.augmentor is not None: 78 | if self.sparse: 79 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) 80 | else: 81 | img1, img2, flow = self.augmentor(img1, img2, flow) 82 | 83 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 84 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 85 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 86 | 87 | if valid is not None: 88 | valid = torch.from_numpy(valid) 89 | else: 90 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 91 | 92 | return img1, img2, flow, valid.float() 93 | 94 | def __rmul__(self, v): 95 | self.flow_list = v * self.flow_list 96 | self.image_list = v * self.image_list 97 | 98 | return self 99 | 100 | def __len__(self): 101 | return len(self.image_list) 102 | 103 | 104 | class MpiSintel(FlowDataset): 105 | def __init__(self, aug_params=None, split='training', 106 | root='datasets/Sintel', 107 | dstype='clean'): 108 | super(MpiSintel, self).__init__(aug_params) 109 | 110 | flow_root = osp.join(root, split, 'flow') 111 | image_root = osp.join(root, split, dstype) 112 | 113 | if split == 'test': 114 | self.is_test = True 115 | 116 | for scene in os.listdir(image_root): 117 | image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) 118 | for i in range(len(image_list) - 1): 119 | self.image_list += [[image_list[i], image_list[i + 1]]] 120 | self.extra_info += [(scene, i)] # scene and frame_id 121 | 122 | if split != 'test': 123 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) 124 | 125 | 126 | class FlyingChairs(FlowDataset): 127 | def __init__(self, aug_params=None, split='train', 128 | root='datasets/FlyingChairs_release/data', 129 | ): 130 | super(FlyingChairs, self).__init__(aug_params) 131 | 132 | images = sorted(glob(osp.join(root, '*.ppm'))) 133 | flows = sorted(glob(osp.join(root, '*.flo'))) 134 | assert (len(images) // 2 == len(flows)) 135 | 136 | split_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'chairs_split.txt') 137 | split_list = np.loadtxt(split_file, dtype=np.int32) 138 | for i in range(len(flows)): 139 | xid = split_list[i] 140 | if (split == 'training' and xid == 1) or (split == 'validation' and xid == 2): 141 | self.flow_list += [flows[i]] 142 | self.image_list += [[images[2 * i], images[2 * i + 1]]] 143 | 144 | 145 | class FlyingThings3D(FlowDataset): 146 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', 147 | dstype='frames_cleanpass'): 148 | super(FlyingThings3D, self).__init__(aug_params) 149 | 150 | img_dir = root 151 | flow_dir = root 152 | 153 | for cam in ['left']: 154 | for direction in ['into_future', 'into_past']: 155 | image_dirs = sorted(glob(osp.join(img_dir, dstype, 'TRAIN/*/*'))) 156 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 157 | 158 | flow_dirs = sorted(glob(osp.join(flow_dir, 'optical_flow/TRAIN/*/*'))) 159 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 160 | 161 | for idir, fdir in zip(image_dirs, flow_dirs): 162 | 
images = sorted(glob(osp.join(idir, '*.png'))) 163 | flows = sorted(glob(osp.join(fdir, '*.pfm'))) 164 | for i in range(len(flows) - 1): 165 | if direction == 'into_future': 166 | self.image_list += [[images[i], images[i + 1]]] 167 | self.flow_list += [flows[i]] 168 | elif direction == 'into_past': 169 | self.image_list += [[images[i + 1], images[i]]] 170 | self.flow_list += [flows[i + 1]] 171 | 172 | 173 | class KITTI(FlowDataset): 174 | def __init__(self, aug_params=None, split='training', 175 | root='datasets/KITTI', 176 | ): 177 | super(KITTI, self).__init__(aug_params, sparse=True) 178 | if split == 'testing': 179 | self.is_test = True 180 | 181 | root = osp.join(root, split) 182 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) 183 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) 184 | 185 | for img1, img2 in zip(images1, images2): 186 | frame_id = img1.split('/')[-1] 187 | self.extra_info += [[frame_id]] 188 | self.image_list += [[img1, img2]] 189 | 190 | if split == 'training': 191 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) 192 | 193 | 194 | class HD1K(FlowDataset): 195 | def __init__(self, aug_params=None, root='datasets/HD1K'): 196 | super(HD1K, self).__init__(aug_params, sparse=True) 197 | 198 | seq_ix = 0 199 | while 1: 200 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 201 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 202 | 203 | if len(flows) == 0: 204 | break 205 | 206 | for i in range(len(flows) - 1): 207 | self.flow_list += [flows[i]] 208 | self.image_list += [[images[i], images[i + 1]]] 209 | 210 | seq_ix += 1 211 | 212 | 213 | def build_dataset(args): 214 | """ Create the data loader for the corresponding training set """ 215 | if args.stage == 'chairs': 216 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} 217 | 218 | train_dataset = FlyingChairs(aug_params, split='training') 219 | 220 | elif args.stage == 'things': 221 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} 222 | 223 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') 224 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') 225 | 226 | train_dataset = clean_dataset + final_dataset 227 | 228 | elif args.stage == 'sintel': 229 | # 1041 pairs for clean and final each 230 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} 231 | 232 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass') 233 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') 234 | sintel_final = MpiSintel(aug_params, split='training', dstype='final') 235 | 236 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) 237 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) 238 | 239 | train_dataset = 100 * sintel_clean + 100 * sintel_final + 200 * kitti + 5 * hd1k + things 240 | 241 | elif args.stage == 'kitti': 242 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} 243 | train_dataset = KITTI(aug_params, split='training') 244 | 245 | else: 246 | raise ValueError(f'stage {args.stage} is not supported') 247 | 248 | return train_dataset 249 | -------------------------------------------------------------------------------- /data/transforms.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | import cv2 5 | 6 | from torchvision.transforms import ColorJitter 7 | 8 | 9 | class FlowAugmentor: 10 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True, 11 | resize_when_needed=False, 12 | no_eraser_aug=False, 13 | ): 14 | # TODO: support resize to higher resolution, and then do croping 15 | # for instance, resize all slow_flow data to 1024x1280 16 | 17 | # spatial augmentation params 18 | self.crop_size = crop_size 19 | self.min_scale = min_scale 20 | self.max_scale = max_scale 21 | self.spatial_aug_prob = 0.8 22 | self.stretch_prob = 0.8 23 | self.max_stretch = 0.2 24 | 25 | self.resize_when_needed = resize_when_needed 26 | 27 | # flip augmentation params 28 | self.do_flip = do_flip 29 | self.h_flip_prob = 0.5 30 | self.v_flip_prob = 0.1 31 | 32 | # photometric augmentation params 33 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14) 34 | self.asymmetric_color_aug_prob = 0.2 35 | 36 | if no_eraser_aug: 37 | self.eraser_aug_prob = -1 38 | else: 39 | self.eraser_aug_prob = 0.5 40 | 41 | def color_transform(self, img1, img2): 42 | """ Photometric augmentation """ 43 | 44 | # asymmetric 45 | if np.random.rand() < self.asymmetric_color_aug_prob: 46 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 47 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 48 | 49 | # symmetric 50 | else: 51 | image_stack = np.concatenate([img1, img2], axis=0) 52 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 53 | img1, img2 = np.split(image_stack, 2, axis=0) 54 | 55 | return img1, img2 56 | 57 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 58 | """ Occlusion augmentation """ 59 | 60 | ht, wd = img1.shape[:2] 61 | if np.random.rand() < self.eraser_aug_prob: 62 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 63 | for _ in range(np.random.randint(1, 3)): 64 | x0 = np.random.randint(0, wd) 65 | y0 = np.random.randint(0, ht) 66 | dx = np.random.randint(bounds[0], bounds[1]) 67 | dy = np.random.randint(bounds[0], bounds[1]) 68 | img2[y0:y0 + dy, x0:x0 + dx, :] = mean_color 69 | 70 | return img1, img2 71 | 72 | def spatial_transform(self, img1, img2, flow, backward_flow=None, occlusion=None, backward_occlusion=None): 73 | # randomly sample scale 74 | ht, wd = img1.shape[:2] 75 | min_scale = np.maximum( 76 | (self.crop_size[0] + 8) / float(ht), 77 | (self.crop_size[1] + 8) / float(wd)) 78 | 79 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 80 | scale_x = scale 81 | scale_y = scale 82 | if np.random.rand() < self.stretch_prob: 83 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 84 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 85 | 86 | scale_x = np.clip(scale_x, min_scale, None) 87 | scale_y = np.clip(scale_y, min_scale, None) 88 | 89 | if np.random.rand() < self.spatial_aug_prob: 90 | # rescale the images 91 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 92 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 93 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 94 | flow = flow * [scale_x, scale_y] 95 | 96 | if backward_flow is not None: 97 | backward_flow = cv2.resize(backward_flow, None, fx=scale_x, fy=scale_y, 
interpolation=cv2.INTER_LINEAR) 98 | backward_flow = backward_flow * [scale_x, scale_y] 99 | 100 | if occlusion is not None: 101 | occlusion = cv2.resize(occlusion, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 102 | if backward_occlusion is not None: 103 | backward_occlusion = cv2.resize(backward_occlusion, None, fx=scale_x, fy=scale_y, 104 | interpolation=cv2.INTER_LINEAR) 105 | 106 | if self.do_flip: 107 | if np.random.rand() < self.h_flip_prob: # h-flip 108 | img1 = img1[:, ::-1] 109 | img2 = img2[:, ::-1] 110 | flow = flow[:, ::-1] * [-1.0, 1.0] 111 | 112 | if backward_flow is not None: 113 | backward_flow = backward_flow[:, ::-1] * [-1.0, 1.0] 114 | 115 | if occlusion is not None: 116 | occlusion = occlusion[:, ::-1] 117 | if backward_occlusion is not None: 118 | backward_occlusion = backward_occlusion[:, ::-1] 119 | 120 | if np.random.rand() < self.v_flip_prob: # v-flip 121 | img1 = img1[::-1, :] 122 | img2 = img2[::-1, :] 123 | flow = flow[::-1, :] * [1.0, -1.0] 124 | 125 | if backward_flow is not None: 126 | backward_flow = backward_flow[::-1, :] * [1.0, -1.0] 127 | 128 | if occlusion is not None: 129 | occlusion = occlusion[::-1, :] 130 | if backward_occlusion is not None: 131 | backward_occlusion = backward_occlusion[::-1, :] 132 | 133 | # In case no cropping 134 | if img1.shape[0] - self.crop_size[0] > 0: 135 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 136 | else: 137 | y0 = 0 138 | if img1.shape[1] - self.crop_size[1] > 0: 139 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 140 | else: 141 | x0 = 0 142 | 143 | img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 144 | img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 145 | flow = flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 146 | 147 | if backward_flow is not None: 148 | backward_flow = backward_flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 149 | 150 | if occlusion is not None: 151 | occlusion = occlusion[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 152 | 153 | if backward_occlusion is not None: 154 | backward_occlusion = backward_occlusion[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 155 | 156 | return img1, img2, flow, backward_flow, occlusion, backward_occlusion 157 | 158 | return img1, img2, flow, backward_flow, occlusion 159 | 160 | return img1, img2, flow, backward_flow 161 | 162 | return img1, img2, flow 163 | 164 | def resize(self, img1, img2, flow): 165 | ori_h, ori_w = img1.shape[:2] 166 | 167 | if ori_h < self.crop_size[0] and ori_w < self.crop_size[1]: 168 | # resize both h and w 169 | scale_y = self.crop_size[0] / ori_h 170 | scale_x = self.crop_size[1] / ori_w 171 | elif ori_h < self.crop_size[0]: # only resize h 172 | scale_y = self.crop_size[0] / ori_h 173 | scale_x = 1. 174 | elif ori_w < self.crop_size[1]: # only resize w 175 | scale_x = self.crop_size[1] / ori_w 176 | scale_y = 1. 
177 | else: 178 | raise ValueError('Original size %dx%d is not smaller than crop size %dx%d' % ( 179 | ori_h, ori_w, self.crop_size[0], self.crop_size[1] 180 | )) 181 | 182 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 183 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 184 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 185 | flow = flow * [scale_x, scale_y] 186 | 187 | return img1, img2, flow 188 | 189 | def __call__(self, img1, img2, flow, backward_flow=None, occlusion=None, backward_occlusion=None): 190 | img1, img2 = self.color_transform(img1, img2) 191 | img1, img2 = self.eraser_transform(img1, img2) 192 | 193 | if self.resize_when_needed: 194 | assert backward_flow is None 195 | # Resize only when original size is smaller than the crop size 196 | if img1.shape[0] < self.crop_size[0] or img1.shape[1] < self.crop_size[1]: 197 | img1, img2, flow = self.resize(img1, img2, flow) 198 | 199 | if backward_flow is not None: 200 | if occlusion is not None: 201 | if backward_occlusion is not None: 202 | img1, img2, flow, backward_flow, occlusion, backward_occlusion = self.spatial_transform( 203 | img1, img2, flow, backward_flow, occlusion, backward_occlusion) 204 | else: 205 | img1, img2, flow, backward_flow, occlusion = self.spatial_transform( 206 | img1, img2, flow, backward_flow, occlusion) 207 | else: 208 | img1, img2, flow, backward_flow = self.spatial_transform(img1, img2, flow, backward_flow) 209 | else: 210 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 211 | 212 | img1 = np.ascontiguousarray(img1) 213 | img2 = np.ascontiguousarray(img2) 214 | flow = np.ascontiguousarray(flow) 215 | 216 | if backward_flow is not None: 217 | backward_flow = np.ascontiguousarray(backward_flow) 218 | 219 | if occlusion is not None: 220 | occlusion = np.ascontiguousarray(occlusion) 221 | if backward_occlusion is not None: 222 | backward_occlusion = np.ascontiguousarray(backward_occlusion) 223 | return img1, img2, flow, backward_flow, occlusion, backward_occlusion 224 | 225 | return img1, img2, flow, backward_flow, occlusion 226 | 227 | return img1, img2, flow, backward_flow 228 | 229 | return img1, img2, flow 230 | 231 | 232 | class SparseFlowAugmentor: 233 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False, 234 | resize_when_needed=False, # used for slow flow dataset 235 | is_kitti=True, # for KITTI dataset, use sparse resize flow, other bilinear resize 236 | no_eraser_aug=False, 237 | ): 238 | # spatial augmentation params 239 | self.crop_size = crop_size 240 | self.min_scale = min_scale 241 | self.max_scale = max_scale 242 | self.spatial_aug_prob = 0.8 243 | self.stretch_prob = 0.8 244 | self.max_stretch = 0.2 245 | 246 | # flip augmentation params 247 | self.do_flip = do_flip 248 | self.h_flip_prob = 0.5 249 | self.v_flip_prob = 0.1 250 | 251 | # photometric augmentation params 252 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3 / 3.14) 253 | self.asymmetric_color_aug_prob = 0.2 254 | 255 | if no_eraser_aug: 256 | self.eraser_aug_prob = -1 257 | else: 258 | self.eraser_aug_prob = 0.5 259 | 260 | self.resize_when_needed = resize_when_needed 261 | self.is_kitti = is_kitti 262 | 263 | def color_transform(self, img1, img2): 264 | image_stack = np.concatenate([img1, img2], axis=0) 265 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 266 | img1, img2 = np.split(image_stack, 
2, axis=0) 267 | return img1, img2 268 | 269 | def eraser_transform(self, img1, img2): 270 | ht, wd = img1.shape[:2] 271 | if np.random.rand() < self.eraser_aug_prob: 272 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 273 | for _ in range(np.random.randint(1, 3)): 274 | x0 = np.random.randint(0, wd) 275 | y0 = np.random.randint(0, ht) 276 | dx = np.random.randint(50, 100) 277 | dy = np.random.randint(50, 100) 278 | img2[y0:y0 + dy, x0:x0 + dx, :] = mean_color 279 | 280 | return img1, img2 281 | 282 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 283 | ht, wd = flow.shape[:2] 284 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 285 | coords = np.stack(coords, axis=-1) 286 | 287 | coords = coords.reshape(-1, 2).astype(np.float32) 288 | flow = flow.reshape(-1, 2).astype(np.float32) 289 | valid = valid.reshape(-1).astype(np.float32) 290 | 291 | coords0 = coords[valid >= 1] 292 | flow0 = flow[valid >= 1] 293 | 294 | ht1 = int(round(ht * fy)) 295 | wd1 = int(round(wd * fx)) 296 | 297 | coords1 = coords0 * [fx, fy] 298 | flow1 = flow0 * [fx, fy] 299 | 300 | xx = np.round(coords1[:, 0]).astype(np.int32) 301 | yy = np.round(coords1[:, 1]).astype(np.int32) 302 | 303 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 304 | xx = xx[v] 305 | yy = yy[v] 306 | flow1 = flow1[v] 307 | 308 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 309 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 310 | 311 | flow_img[yy, xx] = flow1 312 | valid_img[yy, xx] = 1 313 | 314 | return flow_img, valid_img 315 | 316 | def resize(self, img1, img2, flow, valid): 317 | ori_h, ori_w = img1.shape[:2] 318 | 319 | if ori_h < self.crop_size[0] and ori_w < self.crop_size[1]: 320 | # resize both h and w 321 | scale_y = self.crop_size[0] / ori_h 322 | scale_x = self.crop_size[1] / ori_w 323 | elif ori_h < self.crop_size[0]: # only resize h 324 | scale_y = self.crop_size[0] / ori_h 325 | scale_x = 1. 326 | elif ori_w < self.crop_size[1]: # only resize w 327 | scale_x = self.crop_size[1] / ori_w 328 | scale_y = 1. 
329 | else: 330 | raise ValueError('Original size %dx%d is not smaller than crop size %dx%d' % ( 331 | ori_h, ori_w, self.crop_size[0], self.crop_size[1] 332 | )) 333 | 334 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 335 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 336 | 337 | if self.is_kitti: 338 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 339 | else: # for viper and slow flow datasets, only a few pixels are invalid 340 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 341 | # NOTE: don't forget scale flow also 342 | flow = flow * [scale_x, scale_y] 343 | 344 | valid = cv2.resize(valid, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_NEAREST) 345 | 346 | return img1, img2, flow, valid 347 | 348 | def spatial_transform(self, img1, img2, flow, valid): 349 | # randomly sample scale 350 | 351 | ht, wd = img1.shape[:2] 352 | min_scale = np.maximum( 353 | (self.crop_size[0] + 1) / float(ht), 354 | (self.crop_size[1] + 1) / float(wd)) 355 | 356 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 357 | scale_x = np.clip(scale, min_scale, None) 358 | scale_y = np.clip(scale, min_scale, None) 359 | 360 | if np.random.rand() < self.spatial_aug_prob: 361 | # rescale the images 362 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 363 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 364 | 365 | if self.is_kitti: 366 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 367 | else: # for viper and slow flow datasets, only a few pixels are invalid 368 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 369 | flow = flow * [scale_x, scale_y] 370 | 371 | valid = cv2.resize(valid, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_NEAREST) 372 | 373 | if self.do_flip: 374 | if np.random.rand() < 0.5: # h-flip 375 | img1 = img1[:, ::-1] 376 | img2 = img2[:, ::-1] 377 | flow = flow[:, ::-1] * [-1.0, 1.0] 378 | valid = valid[:, ::-1] 379 | 380 | margin_y = 20 381 | margin_x = 50 382 | 383 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 384 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 385 | 386 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 387 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 388 | 389 | img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 390 | img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 391 | flow = flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 392 | valid = valid[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 393 | return img1, img2, flow, valid 394 | 395 | def __call__(self, img1, img2, flow, valid): 396 | img1, img2 = self.color_transform(img1, img2) 397 | img1, img2 = self.eraser_transform(img1, img2) 398 | 399 | if self.resize_when_needed: 400 | # Resize only when original size is smaller than the crop size 401 | if img1.shape[0] < self.crop_size[0] or img1.shape[1] < self.crop_size[1]: 402 | img1, img2, flow, valid = self.resize(img1, img2, flow, valid) 403 | 404 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 405 | 406 | img1 = np.ascontiguousarray(img1) 407 | img2 = np.ascontiguousarray(img2) 408 | flow = np.ascontiguousarray(flow) 409 | valid = np.ascontiguousarray(valid) 410 | 411 | return img1, img2, 
flow, valid 412 | -------------------------------------------------------------------------------- /demo/dogs-jump/00033.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haofeixu/flow1d/ece861d2136e2eb2e99a9db71794d82c5782dbcb/demo/dogs-jump/00033.jpg -------------------------------------------------------------------------------- /demo/dogs-jump/00034.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haofeixu/flow1d/ece861d2136e2eb2e99a9db71794d82c5782dbcb/demo/dogs-jump/00034.jpg -------------------------------------------------------------------------------- /demo/dogs-jump/00035.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haofeixu/flow1d/ece861d2136e2eb2e99a9db71794d82c5782dbcb/demo/dogs-jump/00035.jpg -------------------------------------------------------------------------------- /demo/dogs-jump/00036.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haofeixu/flow1d/ece861d2136e2eb2e99a9db71794d82c5782dbcb/demo/dogs-jump/00036.jpg -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: flow1d 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - blas=1.0=mkl 8 | - ca-certificates=2020.6.24=0 9 | - certifi=2020.6.20=py37_0 10 | - cloudpickle=1.5.0=py_0 11 | - cudatoolkit=10.2.89=hfd86e86_1 12 | - cycler=0.10.0=py37_0 13 | - cytoolz=0.10.1=py37h7b6447c_0 14 | - dask-core=2.20.0=py_0 15 | - dbus=1.13.16=hb2f20db_0 16 | - decorator=4.4.2=py_0 17 | - expat=2.2.9=he6710b0_2 18 | - fontconfig=2.13.0=h9420a91_0 19 | - freetype=2.10.2=h5ab3b9f_0 20 | - glib=2.65.0=h3eb4bd4_0 21 | - gst-plugins-base=1.14.0=hbbd80ab_1 22 | - gstreamer=1.14.0=hb31296c_0 23 | - icu=58.2=he6710b0_3 24 | - imageio=2.9.0=py_0 25 | - intel-openmp=2020.1=217 26 | - jpeg=9b=h024ee3a_2 27 | - kiwisolver=1.2.0=py37hfd86e86_0 28 | - lcms2=2.11=h396b838_0 29 | - ld_impl_linux-64=2.33.1=h53a641e_7 30 | - libedit=3.1.20191231=h14c3975_1 31 | - libffi=3.3=he6710b0_2 32 | - libgcc-ng=9.1.0=hdf63c60_0 33 | - libgfortran-ng=7.3.0=hdf63c60_0 34 | - libpng=1.6.37=hbc83047_0 35 | - libstdcxx-ng=9.1.0=hdf63c60_0 36 | - libtiff=4.1.0=h2733197_1 37 | - libuuid=1.0.3=h1bed415_2 38 | - libxcb=1.14=h7b6447c_0 39 | - libxml2=2.9.10=he19cac6_1 40 | - lz4-c=1.9.2=he6710b0_0 41 | - matplotlib=3.2.2=0 42 | - matplotlib-base=3.2.2=py37hef1b27d_0 43 | - mkl=2020.1=217 44 | - mkl-service=2.3.0=py37he904b0f_0 45 | - mkl_fft=1.1.0=py37h23d657b_0 46 | - mkl_random=1.1.1=py37h0573a6f_0 47 | - ncurses=6.2=he6710b0_1 48 | - networkx=2.4=py_1 49 | - ninja=1.9.0=py37hfd86e86_0 50 | - numpy=1.18.5=py37ha1c710e_0 51 | - numpy-base=1.18.5=py37hde5b4d6_0 52 | - olefile=0.46=py37_0 53 | - openssl=1.1.1g=h7b6447c_0 54 | - pcre=8.44=he6710b0_0 55 | - pillow=7.2.0=py37hb39fc2d_0 56 | - pip=20.1.1=py37_1 57 | - pyparsing=2.4.7=py_0 58 | - pyqt=5.9.2=py37h05f1152_2 59 | - python=3.7.7=hcff3b4d_5 60 | - python-dateutil=2.8.1=py_0 61 | - pytorch=1.7.1=py3.7_cuda10.2.89_cudnn7.6.5_0 62 | - pywavelets=1.1.1=py37h7b6447c_0 63 | - pyyaml=5.3.1=py37h7b6447c_1 64 | - qt=5.9.7=h5867ecd_1 65 | - readline=8.0=h7b6447c_0 66 | - scikit-image=0.16.2=py37h0573a6f_0 67 | - scipy=1.5.0=py37h0b6359f_0 68 | - 
setuptools=49.2.0=py37_0 69 | - sip=4.19.8=py37hf484d3e_0 70 | - six=1.15.0=py_0 71 | - sqlite=3.32.3=h62c20be_0 72 | - tk=8.6.10=hbc83047_0 73 | - toolz=0.10.0=py_0 74 | - torchvision=0.6.0=py37_cu102 75 | - tornado=6.0.4=py37h7b6447c_1 76 | - wheel=0.34.2=py37_0 77 | - xz=5.2.5=h7b6447c_0 78 | - yaml=0.2.5=h7b6447c_0 79 | - zlib=1.2.11=h7b6447c_3 80 | - zstd=1.4.5=h0b5b093_0 81 | - pip: 82 | - absl-py==0.9.0 83 | - cachetools==4.1.1 84 | - chardet==3.0.4 85 | - google-auth==1.19.2 86 | - google-auth-oauthlib==0.4.1 87 | - grpcio==1.30.0 88 | - idna==2.10 89 | - importlib-metadata==1.7.0 90 | - markdown==3.2.2 91 | - oauthlib==3.1.0 92 | - opencv-python==4.3.0.36 93 | - protobuf==3.12.2 94 | - pyasn1==0.4.8 95 | - pyasn1-modules==0.2.8 96 | - requests==2.24.0 97 | - requests-oauthlib==1.3.0 98 | - rsa==4.6 99 | - tensorboard==2.2.2 100 | - tensorboard-plugin-wit==1.7.0 101 | - urllib3==1.25.9 102 | - werkzeug==1.0.1 103 | - zipp==3.1.0 104 | 105 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import torch 5 | 6 | import data 7 | from utils import frame_utils 8 | from utils.flow_viz import save_vis_flow_tofile 9 | 10 | from utils.utils import InputPadder, forward_interpolate 11 | from glob import glob 12 | 13 | 14 | @torch.no_grad() 15 | def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission', 16 | padding_factor=8, 17 | save_vis_flow=False, 18 | no_save_flo=False, 19 | **kwargs, 20 | ): 21 | """ Create submission for the Sintel leaderboard """ 22 | model.eval() 23 | for dstype in ['clean', 'final']: 24 | test_dataset = data.MpiSintel(split='test', aug_params=None, dstype=dstype) 25 | 26 | flow_prev, sequence_prev = None, None 27 | for test_id in range(len(test_dataset)): 28 | image1, image2, (sequence, frame) = test_dataset[test_id] 29 | if sequence != sequence_prev: 30 | flow_prev = None 31 | 32 | padder = InputPadder(image1.shape, padding_factor=padding_factor) 33 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 34 | 35 | flow_low, flow_pr = model(image1, image2, iters=iters, 36 | flow_init=flow_prev, 37 | test_mode=True) 38 | 39 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 40 | 41 | if warm_start: 42 | flow_prev = forward_interpolate(flow_low[0])[None].cuda() 43 | 44 | output_dir = os.path.join(output_path, dstype, sequence) 45 | output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame + 1)) 46 | 47 | if not os.path.exists(output_dir): 48 | os.makedirs(output_dir) 49 | 50 | if not no_save_flo: 51 | frame_utils.writeFlow(output_file, flow) 52 | sequence_prev = sequence 53 | 54 | # Save vis flow 55 | if save_vis_flow: 56 | vis_flow_file = output_file.replace('.flo', '.png') 57 | save_vis_flow_tofile(flow, vis_flow_file) 58 | 59 | 60 | @torch.no_grad() 61 | def create_kitti_submission(model, iters=24, output_path='kitti_submission', 62 | padding_factor=8, 63 | save_vis_flow=False, 64 | **kwargs, 65 | ): 66 | """ Create submission for the Sintel leaderboard """ 67 | model.eval() 68 | test_dataset = data.KITTI(split='testing', aug_params=None) 69 | 70 | if not os.path.exists(output_path): 71 | os.makedirs(output_path) 72 | 73 | for test_id in range(len(test_dataset)): 74 | image1, image2, (frame_id,) = test_dataset[test_id] 75 | padder = InputPadder(image1.shape, mode='kitti', padding_factor=padding_factor) 76 | 
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 77 | 78 | flow_pr = model(image1, image2, iters=iters, 79 | flow_init=None, 80 | test_mode=True)[-1] 81 | 82 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 83 | 84 | output_filename = os.path.join(output_path, frame_id) 85 | 86 | # Save vis flow 87 | if save_vis_flow: 88 | vis_flow_file = output_filename 89 | save_vis_flow_tofile(flow, vis_flow_file) 90 | else: 91 | frame_utils.writeFlowKITTI(output_filename, flow) 92 | 93 | 94 | @torch.no_grad() 95 | def validate_chairs(model, 96 | iters=24, 97 | **kwargs, 98 | ): 99 | """ Perform evaluation on the FlyingChairs (test) split """ 100 | model.eval() 101 | epe_list = [] 102 | results = {} 103 | 104 | val_dataset = data.FlyingChairs(split='validation') 105 | 106 | print('Number of validation image pairs: %d' % len(val_dataset)) 107 | 108 | for val_id in range(len(val_dataset)): 109 | image1, image2, flow_gt, _ = val_dataset[val_id] 110 | 111 | image1 = image1[None].cuda() 112 | image2 = image2[None].cuda() 113 | 114 | flow_pr = model(image1, image2, iters=iters, test_mode=True)[-1] # RAFT 115 | 116 | epe = torch.sum((flow_pr[0].cpu() - flow_gt) ** 2, dim=0).sqrt() 117 | epe_list.append(epe.view(-1).numpy()) 118 | 119 | epe_all = np.concatenate(epe_list) 120 | epe = np.mean(epe_all) 121 | px1 = np.mean(epe_all > 1) 122 | px3 = np.mean(epe_all > 3) 123 | px5 = np.mean(epe_all > 5) 124 | 125 | print("Validation Chairs EPE: %.3f, 1px: %.3f, 3px: %.3f, 5px: %.3f" % (epe, px1, px3, px5)) 126 | 127 | results['chairs_epe'] = epe 128 | results['chairs_1px'] = px1 129 | results['chairs_3px'] = px3 130 | results['chairs_5px'] = px5 131 | 132 | return results 133 | 134 | 135 | @torch.no_grad() 136 | def validate_sintel(model, 137 | count_time=False, 138 | padding_factor=8, 139 | iters=32, 140 | **kwargs, 141 | ): 142 | """ Peform validation using the Sintel (train) split """ 143 | model.eval() 144 | results = {} 145 | 146 | if count_time: 147 | total_time = 0 148 | num_runs = 100 149 | 150 | for dstype in ['clean', 'final']: 151 | val_dataset = data.MpiSintel(split='training', dstype=dstype) 152 | 153 | print('Number of validation image pairs: %d' % len(val_dataset)) 154 | epe_list = [] 155 | 156 | for val_id in range(len(val_dataset)): 157 | image1, image2, flow_gt, _ = val_dataset[val_id] 158 | image1 = image1[None].cuda() 159 | image2 = image2[None].cuda() 160 | 161 | padder = InputPadder(image1.shape, padding_factor=padding_factor) 162 | image1, image2 = padder.pad(image1, image2) 163 | 164 | if count_time and val_id >= 5: # 5 warmup 165 | torch.cuda.synchronize() 166 | time_start = time.perf_counter() 167 | 168 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True) 169 | 170 | if count_time and val_id >= 5: 171 | torch.cuda.synchronize() 172 | total_time += time.perf_counter() - time_start 173 | 174 | if val_id >= num_runs + 4: 175 | break 176 | 177 | flow = padder.unpad(flow_pr[0]).cpu() 178 | 179 | epe = torch.sum((flow - flow_gt) ** 2, dim=0).sqrt() 180 | epe_list.append(epe.view(-1).numpy()) 181 | 182 | epe_all = np.concatenate(epe_list) 183 | epe = np.mean(epe_all) 184 | px1 = np.mean(epe_all > 1) 185 | px3 = np.mean(epe_all > 3) 186 | px5 = np.mean(epe_all > 5) 187 | 188 | print("Validation Sintel (%s) EPE: %.3f, 1px: %.3f, 3px: %.3f, 5px: %.3f" % (dstype, epe, px1, px3, px5)) 189 | 190 | dstype = 'sintel_' + dstype 191 | 192 | results[dstype + '_epe'] = np.mean(epe_list) 193 | results[dstype + '_1px'] = px1 194 | results[dstype + '_3px'] = 
px3 195 | results[dstype + '_5px'] = px5 196 | 197 | if count_time: 198 | print('Time: %.3fs' % (total_time / num_runs)) 199 | break # only the clean pass when counting time 200 | 201 | return results 202 | 203 | 204 | @torch.no_grad() 205 | def validate_kitti(model, 206 | padding_factor=8, 207 | iters=24, 208 | **kwargs, 209 | ): 210 | """ Peform validation using the KITTI-2015 (train) split """ 211 | model.eval() 212 | 213 | val_dataset = data.KITTI(split='training') 214 | print('Number of validation image pairs: %d' % len(val_dataset)) 215 | 216 | out_list, epe_list = [], [] 217 | results = {} 218 | 219 | for val_id in range(len(val_dataset)): 220 | image1, image2, flow_gt, valid_gt = val_dataset[val_id] 221 | image1 = image1[None].cuda() 222 | image2 = image2[None].cuda() 223 | 224 | padder = InputPadder(image1.shape, mode='kitti', padding_factor=padding_factor) 225 | image1, image2 = padder.pad(image1, image2) 226 | 227 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True) 228 | 229 | flow = padder.unpad(flow_pr[0]).cpu() 230 | 231 | epe = torch.sum((flow - flow_gt) ** 2, dim=0).sqrt() 232 | mag = torch.sum(flow_gt ** 2, dim=0).sqrt() 233 | 234 | epe = epe.view(-1) 235 | mag = mag.view(-1) 236 | val = valid_gt.view(-1) >= 0.5 237 | 238 | out = ((epe > 3.0) & ((epe / mag) > 0.05)).float() 239 | 240 | epe_list.append(epe[val].mean().item()) 241 | out_list.append(out[val].cpu().numpy()) 242 | 243 | epe_list = np.array(epe_list) 244 | out_list = np.concatenate(out_list) 245 | 246 | epe = np.mean(epe_list) 247 | f1 = 100 * np.mean(out_list) 248 | 249 | print("Validation KITTI EPE: %.3f, F1-all: %.3f" % (epe, f1)) 250 | results['kitti_epe'] = epe 251 | results['kitti_f1'] = f1 252 | 253 | return results 254 | 255 | 256 | @torch.no_grad() 257 | def inference_on_dir(model, inference_dir, 258 | iters=32, warm_start=False, output_path='output', 259 | padding_factor=8, 260 | paired_data=False, # dir of paired data instead of a sequence 261 | save_flo_flow=False, # save as .flo for quantative evaluation 262 | **kwargs, 263 | ): 264 | """ Inference on a directory """ 265 | model.eval() 266 | 267 | if not os.path.exists(output_path): 268 | os.makedirs(output_path) 269 | 270 | filenames = sorted(glob(inference_dir + '/*')) 271 | print('%d images found' % len(filenames)) 272 | 273 | flow_prev, sequence_prev = None, None 274 | 275 | stride = 2 if paired_data else 1 276 | 277 | if paired_data: 278 | assert len(filenames) % 2 == 0 279 | 280 | for test_id in range(0, len(filenames) - 1, stride): 281 | image1 = frame_utils.read_gen(filenames[test_id]) 282 | image2 = frame_utils.read_gen(filenames[test_id + 1]) 283 | 284 | image1 = np.array(image1).astype(np.uint8) 285 | image2 = np.array(image2).astype(np.uint8) 286 | 287 | if len(image1.shape) == 2: # gray image, for example, HD1K 288 | image1 = np.tile(image1[..., None], (1, 1, 3)) 289 | image2 = np.tile(image2[..., None], (1, 1, 3)) 290 | else: 291 | image1 = image1[..., :3] 292 | image2 = image2[..., :3] 293 | 294 | image1 = torch.from_numpy(image1).permute(2, 0, 1).float() 295 | image2 = torch.from_numpy(image2).permute(2, 0, 1).float() 296 | 297 | if test_id == 0: 298 | flow_prev = None 299 | 300 | padder = InputPadder(image1.shape, padding_factor=padding_factor) 301 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 302 | 303 | flow_init = None 304 | flow_low, flow_pr = model(image1, image2, iters=iters, 305 | flow_init=flow_prev if flow_init is None else flow_init, 306 | test_mode=True) 307 | 308 | if 
warm_start: 309 | flow_prev = forward_interpolate(flow_low[0])[None].cuda() 310 | 311 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() # [H, W, 2] 312 | 313 | output_file = os.path.join(output_path, os.path.basename(filenames[test_id])[:-4] + '_flow.png') 314 | 315 | # Save vis flow 316 | save_vis_flow_tofile(flow, output_file) 317 | 318 | if save_flo_flow: 319 | output_file = os.path.join(output_path, os.path.basename(filenames[test_id])[:-4] + '_pred.flo') 320 | frame_utils.writeFlow(output_file, flow) 321 | -------------------------------------------------------------------------------- /flow1d/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haofeixu/flow1d/ece861d2136e2eb2e99a9db71794d82c5782dbcb/flow1d/__init__.py -------------------------------------------------------------------------------- /flow1d/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import copy 4 | 5 | 6 | class Attention1D(nn.Module): 7 | """Cross-Attention on x or y direction, 8 | without multi-head and dropout support for faster speed 9 | """ 10 | 11 | def __init__(self, in_channels, 12 | y_attention=False, 13 | double_cross_attn=False, # cross attn feature1 before computing cross attn feature2 14 | **kwargs, 15 | ): 16 | super(Attention1D, self).__init__() 17 | 18 | self.y_attention = y_attention 19 | self.double_cross_attn = double_cross_attn 20 | 21 | # self attn feature1 before cross attn 22 | if double_cross_attn: 23 | self.self_attn = copy.deepcopy(Attention1D(in_channels=in_channels, 24 | y_attention=not y_attention, 25 | ) 26 | ) 27 | 28 | self.query_conv = nn.Conv2d(in_channels, in_channels, 1) 29 | self.key_conv = nn.Conv2d(in_channels, in_channels, 1) 30 | 31 | # Initialize: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py#L138 32 | for p in self.parameters(): 33 | if p.dim() > 1: 34 | nn.init.xavier_uniform_(p) # original Transformer initialization 35 | 36 | def forward(self, feature1, feature2, position=None, value=None): 37 | b, c, h, w = feature1.size() 38 | 39 | # self attn before cross attn 40 | if self.double_cross_attn: 41 | feature1 = self.self_attn(feature1, feature1, position)[0] # self attn feature1 42 | 43 | query = feature1 + position if position is not None else feature1 44 | query = self.query_conv(query) # [B, C, H, W] 45 | 46 | key = feature2 + position if position is not None else feature2 47 | 48 | key = self.key_conv(key) # [B, C, H, W] 49 | value = feature2 if value is None else value # [B, C, H, W] 50 | scale_factor = c ** 0.5 51 | 52 | if self.y_attention: 53 | query = query.permute(0, 3, 2, 1) # [B, W, H, C] 54 | key = key.permute(0, 3, 1, 2) # [B, W, C, H] 55 | value = value.permute(0, 3, 2, 1) # [B, W, H, C] 56 | else: # x attention 57 | query = query.permute(0, 2, 3, 1) # [B, H, W, C] 58 | key = key.permute(0, 2, 1, 3) # [B, H, C, W] 59 | value = value.permute(0, 2, 3, 1) # [B, H, W, C] 60 | 61 | scores = torch.matmul(query, key) / scale_factor # [B, W, H, H] or [B, H, W, W] 62 | 63 | attention = torch.softmax(scores, dim=-1) # [B, W, H, H] or [B, H, W, W] 64 | 65 | out = torch.matmul(attention, value) # [B, W, H, C] or [B, H, W, C] 66 | 67 | if self.y_attention: 68 | out = out.permute(0, 3, 2, 1).contiguous() # [B, C, H, W] 69 | else: 70 | out = out.permute(0, 3, 1, 2).contiguous() # [B, C, H, W] 71 | 72 | return out, attention 73 | 
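A quick shape check for the `Attention1D` module above may help make the 1D formulation concrete (illustrative only: the feature size is arbitrary, the optional positional encoding is omitted, and `double_cross_attn` is left at its default, whereas the full model enables it):

```python
import torch
from flow1d.attention import Attention1D

b, c, h, w = 2, 256, 46, 62  # e.g. a 1/8-resolution feature map
feature1 = torch.randn(b, c, h, w)
feature2 = torch.randn(b, c, h, w)

attn_x = Attention1D(in_channels=c, y_attention=False)  # attend along the x direction
out, attention = attn_x(feature1, feature2, position=None)

print(out.shape)        # torch.Size([2, 256, 46, 62]), same as the input features
print(attention.shape)  # torch.Size([2, 46, 62, 62]), i.e. [B, H, W, W]
```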
-------------------------------------------------------------------------------- /flow1d/correlation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class Correlation1D: 6 | def __init__(self, feature1, feature2, 7 | radius=32, 8 | x_correlation=False, 9 | ): 10 | self.radius = radius 11 | self.x_correlation = x_correlation 12 | 13 | if self.x_correlation: 14 | self.corr = self.corr_x(feature1, feature2) # [B*H*W, 1, 1, W] 15 | else: 16 | self.corr = self.corr_y(feature1, feature2) # [B*H*W, 1, H, 1] 17 | 18 | def __call__(self, coords): 19 | r = self.radius 20 | coords = coords.permute(0, 2, 3, 1) # [B, H, W, 2] 21 | b, h, w = coords.shape[:3] 22 | 23 | if self.x_correlation: 24 | dx = torch.linspace(-r, r, 2 * r + 1) 25 | dy = torch.zeros_like(dx) 26 | delta_x = torch.stack((dx, dy), dim=-1).to(coords.device) # [2r+1, 2] 27 | 28 | coords_x = coords[:, :, :, 0] # [B, H, W] 29 | coords_x = torch.stack((coords_x, torch.zeros_like(coords_x)), dim=-1) # [B, H, W, 2] 30 | 31 | centroid_x = coords_x.view(b * h * w, 1, 1, 2) # [B*H*W, 1, 1, 2] 32 | coords_x = centroid_x + delta_x # [B*H*W, 1, 2r+1, 2] 33 | 34 | coords_x = 2 * coords_x / (w - 1) - 1 # [-1, 1], y is always 0 35 | 36 | corr_x = F.grid_sample(self.corr, coords_x, mode='bilinear', 37 | align_corners=True) # [B*H*W, G, 1, 2r+1] 38 | 39 | corr_x = corr_x.view(b, h, w, -1) # [B, H, W, (2r+1)*G] 40 | corr_x = corr_x.permute(0, 3, 1, 2).contiguous() # [B, (2r+1)*G, H, W] 41 | return corr_x 42 | else: # y correlation 43 | dy = torch.linspace(-r, r, 2 * r + 1) 44 | dx = torch.zeros_like(dy) 45 | delta_y = torch.stack((dx, dy), dim=-1).to(coords.device) # [2r+1, 2] 46 | delta_y = delta_y.unsqueeze(1).unsqueeze(0) # [1, 2r+1, 1, 2] 47 | 48 | coords_y = coords[:, :, :, 1] # [B, H, W] 49 | coords_y = torch.stack((torch.zeros_like(coords_y), coords_y), dim=-1) # [B, H, W, 2] 50 | 51 | centroid_y = coords_y.view(b * h * w, 1, 1, 2) # [B*H*W, 1, 1, 2] 52 | coords_y = centroid_y + delta_y # [B*H*W, 2r+1, 1, 2] 53 | 54 | coords_y = 2 * coords_y / (h - 1) - 1 # [-1, 1], x is always 0 55 | 56 | corr_y = F.grid_sample(self.corr, coords_y, mode='bilinear', 57 | align_corners=True) # [B*H*W, G, 2r+1, 1] 58 | 59 | corr_y = corr_y.view(b, h, w, -1) # [B, H, W, (2r+1)*G] 60 | corr_y = corr_y.permute(0, 3, 1, 2).contiguous() # [B, (2r+1)*G, H, W] 61 | 62 | return corr_y 63 | 64 | def corr_x(self, feature1, feature2): 65 | b, c, h, w = feature1.shape # [B, C, H, W] 66 | scale_factor = c ** 0.5 67 | 68 | # x direction 69 | feature1 = feature1.permute(0, 2, 3, 1) # [B, H, W, C] 70 | feature2 = feature2.permute(0, 2, 1, 3) # [B, H, C, W] 71 | corr = torch.matmul(feature1, feature2) # [B, H, W, W] 72 | 73 | corr = corr.unsqueeze(3).unsqueeze(3) # [B, H, W, 1, 1, W] 74 | corr = corr / scale_factor 75 | corr = corr.flatten(0, 2) # [B*H*W, 1, 1, W] 76 | 77 | return corr 78 | 79 | def corr_y(self, feature1, feature2): 80 | b, c, h, w = feature1.shape # [B, C, H, W] 81 | scale_factor = c ** 0.5 82 | 83 | # y direction 84 | feature1 = feature1.permute(0, 3, 2, 1) # [B, W, H, C] 85 | feature2 = feature2.permute(0, 3, 1, 2) # [B, W, C, H] 86 | corr = torch.matmul(feature1, feature2) # [B, W, H, H] 87 | 88 | corr = corr.permute(0, 2, 1, 3).contiguous().view(b, h, w, 1, h, 1) # [B, H, W, 1, H, 1] 89 | corr = corr / scale_factor 90 | corr = corr.flatten(0, 2) # [B*H*W, 1, H, 1] 91 | 92 | return corr 93 | 
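Similarly, a shape check for `Correlation1D` (illustrative only: the coordinate grid corresponds to zero flow and is built inline instead of with `utils.utils.coords_grid`, and in the full model the x-correlation is actually computed on the y-attended feature2):

```python
import torch
from flow1d.correlation import Correlation1D

b, c, h, w, r = 1, 256, 46, 62, 32
feature1 = torch.randn(b, c, h, w)
feature2 = torch.randn(b, c, h, w)

# identity coordinate grid (zero flow), shape [B, 2, H, W] in (x, y) order
ys, xs = torch.meshgrid(torch.arange(h).float(), torch.arange(w).float())
coords = torch.stack((xs, ys), dim=0)[None].repeat(b, 1, 1, 1)

corr_fn_x = Correlation1D(feature1, feature2, radius=r, x_correlation=True)
corr_x = corr_fn_x(coords)
print(corr_x.shape)  # torch.Size([1, 65, 46, 62]), i.e. (2r + 1) matching costs per pixel
```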
-------------------------------------------------------------------------------- /flow1d/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ResidualBlock(nn.Module): 6 | def __init__(self, in_planes, planes, norm_fn='group', stride=1, dilation=1): 7 | super(ResidualBlock, self).__init__() 8 | 9 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, 10 | dilation=dilation, padding=dilation, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 12 | dilation=dilation, padding=dilation) 13 | self.relu = nn.ReLU(inplace=True) 14 | 15 | num_groups = planes // 8 16 | 17 | if norm_fn == 'group': 18 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 20 | if not stride == 1: 21 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 22 | 23 | elif norm_fn == 'batch': 24 | self.norm1 = nn.BatchNorm2d(planes) 25 | self.norm2 = nn.BatchNorm2d(planes) 26 | if not stride == 1 or in_planes != planes: 27 | self.norm3 = nn.BatchNorm2d(planes) 28 | 29 | elif norm_fn == 'instance': 30 | self.norm1 = nn.InstanceNorm2d(planes) 31 | self.norm2 = nn.InstanceNorm2d(planes) 32 | if not stride == 1 or in_planes != planes: 33 | self.norm3 = nn.InstanceNorm2d(planes) 34 | 35 | elif norm_fn == 'none': 36 | self.norm1 = nn.Sequential() 37 | self.norm2 = nn.Sequential() 38 | if not stride == 1: 39 | self.norm3 = nn.Sequential() 40 | 41 | if stride == 1 and in_planes == planes: 42 | self.downsample = None 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | def forward(self, x): 48 | y = x 49 | y = self.relu(self.norm1(self.conv1(y))) 50 | y = self.relu(self.norm2(self.conv2(y))) 51 | 52 | if self.downsample is not None: 53 | x = self.downsample(x) 54 | 55 | return self.relu(x + y) 56 | 57 | 58 | class BasicEncoder(nn.Module): 59 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, 60 | **kwargs, 61 | ): 62 | super(BasicEncoder, self).__init__() 63 | self.norm_fn = norm_fn 64 | 65 | feature_dims = [64, 96, 128, 160] 66 | 67 | if self.norm_fn == 'group': 68 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=feature_dims[0]) 69 | 70 | elif self.norm_fn == 'batch': 71 | self.norm1 = nn.BatchNorm2d(feature_dims[0]) 72 | 73 | elif self.norm_fn == 'instance': 74 | self.norm1 = nn.InstanceNorm2d(feature_dims[0]) 75 | 76 | elif self.norm_fn == 'none': 77 | self.norm1 = nn.Sequential() 78 | 79 | self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3) 80 | self.relu1 = nn.ReLU(inplace=True) 81 | 82 | self.in_planes = feature_dims[0] 83 | self.layer1 = self._make_layer(feature_dims[0], stride=1) 84 | self.layer2 = self._make_layer(feature_dims[1], stride=2) # 1/4 85 | 86 | self.layer3 = self._make_layer(feature_dims[2], stride=2, dilation=1) 87 | 88 | self.conv2 = nn.Conv2d(feature_dims[2], output_dim, kernel_size=1) 89 | 90 | self.dropout = None 91 | if dropout > 0: 92 | self.dropout = nn.Dropout2d(p=dropout) 93 | 94 | for m in self.modules(): 95 | if isinstance(m, nn.Conv2d): 96 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 97 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 98 | if m.weight is not None: 99 | nn.init.constant_(m.weight, 1) 100 | if m.bias is not None: 101 | nn.init.constant_(m.bias, 0) 102 | 103 | def 
_make_layer(self, dim, stride=1, dilation=1): 104 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride, dilation=dilation) 105 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1, dilation=dilation) 106 | layers = (layer1, layer2) 107 | 108 | self.in_planes = dim 109 | return nn.Sequential(*layers) 110 | 111 | def forward(self, x): 112 | 113 | # if input is list, combine batch dimension 114 | is_list = isinstance(x, tuple) or isinstance(x, list) 115 | if is_list: 116 | batch_dim = x[0].shape[0] 117 | x = torch.cat(x, dim=0) 118 | 119 | x = self.conv1(x) 120 | x = self.norm1(x) 121 | x = self.relu1(x) 122 | 123 | x = self.layer1(x) # 1/2 124 | layer2 = self.layer2(x) # 1/4 125 | 126 | x = self.layer3(layer2) # 1/8 127 | 128 | x = self.conv2(x) 129 | 130 | if self.training and self.dropout is not None: 131 | x = self.dropout(x) 132 | 133 | if is_list: 134 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 135 | 136 | return x 137 | -------------------------------------------------------------------------------- /flow1d/flow1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .extractor import BasicEncoder 6 | from .attention import Attention1D 7 | from .position import PositionEmbeddingSine 8 | from .correlation import Correlation1D 9 | from .update import BasicUpdateBlock 10 | from utils.utils import coords_grid 11 | 12 | 13 | class Model(nn.Module): 14 | def __init__(self, 15 | downsample_factor=8, 16 | feature_channels=256, 17 | hidden_dim=128, 18 | context_dim=128, 19 | corr_radius=32, 20 | mixed_precision=False, 21 | **kwargs, 22 | ): 23 | super(Model, self).__init__() 24 | 25 | self.downsample_factor = downsample_factor 26 | 27 | self.feature_channels = feature_channels 28 | 29 | self.hidden_dim = hidden_dim 30 | self.context_dim = context_dim 31 | self.corr_radius = corr_radius 32 | 33 | self.mixed_precision = mixed_precision 34 | 35 | # feature network, context network, and update block 36 | self.fnet = BasicEncoder(output_dim=feature_channels, norm_fn='instance', 37 | ) 38 | 39 | self.cnet = BasicEncoder(output_dim=hidden_dim + context_dim, norm_fn='batch', 40 | ) 41 | 42 | # 1D attention 43 | corr_channels = (2 * corr_radius + 1) * 2 44 | 45 | self.attn_x = Attention1D(feature_channels, 46 | y_attention=False, 47 | double_cross_attn=True, 48 | ) 49 | self.attn_y = Attention1D(feature_channels, 50 | y_attention=True, 51 | double_cross_attn=True, 52 | ) 53 | 54 | # Update block 55 | self.update_block = BasicUpdateBlock(corr_channels=corr_channels, 56 | hidden_dim=hidden_dim, 57 | context_dim=context_dim, 58 | downsample_factor=downsample_factor, 59 | ) 60 | 61 | def freeze_bn(self): 62 | for m in self.modules(): 63 | if isinstance(m, nn.BatchNorm2d): 64 | m.eval() 65 | 66 | def initialize_flow(self, img, downsample=None): 67 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 68 | n, c, h, w = img.shape 69 | downsample_factor = self.downsample_factor if downsample is None else downsample 70 | coords0 = coords_grid(n, h // downsample_factor, w // downsample_factor).to(img.device) 71 | coords1 = coords_grid(n, h // downsample_factor, w // downsample_factor).to(img.device) 72 | 73 | # optical flow computed as difference: flow = coords1 - coords0 74 | return coords0, coords1 75 | 76 | def learned_upflow(self, flow, mask): 77 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex 
combination """ 78 | n, _, h, w = flow.shape 79 | mask = mask.view(n, 1, 9, self.downsample_factor, self.downsample_factor, h, w) 80 | mask = torch.softmax(mask, dim=2) 81 | 82 | up_flow = F.unfold(self.downsample_factor * flow, [3, 3], padding=1) 83 | up_flow = up_flow.view(n, 2, 9, 1, 1, h, w) 84 | 85 | up_flow = torch.sum(mask * up_flow, dim=2) 86 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 87 | return up_flow.reshape(n, 2, self.downsample_factor * h, self.downsample_factor * w) 88 | 89 | def forward(self, image1, image2, iters=12, flow_init=None, test_mode=False, 90 | ): 91 | """ Estimate optical flow between pair of frames """ 92 | image1 = 2 * (image1 / 255.0) - 1.0 93 | image2 = 2 * (image2 / 255.0) - 1.0 94 | 95 | # run the feature network 96 | feature1, feature2 = self.fnet([image1, image2]) 97 | 98 | # Used for attention loss computation, store the attention matrix 99 | attn_x_list = [] 100 | attn_y_list = [] 101 | 102 | hdim = self.hidden_dim 103 | cdim = self.context_dim 104 | 105 | # position encoding 106 | pos_channels = self.feature_channels // 2 107 | pos_enc = PositionEmbeddingSine(pos_channels) 108 | 109 | position = pos_enc(feature1) # [B, C, H, W] 110 | 111 | # 1D correlation 112 | feature2_x, attn_x = self.attn_x(feature1, feature2, position) 113 | corr_fn_y = Correlation1D(feature1, feature2_x, 114 | radius=self.corr_radius, 115 | x_correlation=False, 116 | ) 117 | 118 | feature2_y, attn_y = self.attn_y(feature1, feature2, position) 119 | corr_fn_x = Correlation1D(feature1, feature2_y, 120 | radius=self.corr_radius, 121 | x_correlation=True, 122 | ) 123 | 124 | # run the context network 125 | cnet = self.cnet(image1) # list of feature pyramid, low scale to high scale 126 | 127 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 128 | net = torch.tanh(net) 129 | inp = torch.relu(inp) 130 | 131 | coords0, coords1 = self.initialize_flow(image1) # 1/8 resolution or 1/4 132 | 133 | if flow_init is not None: # flow_init is 1/8 resolution or 1/4 134 | coords1 = coords1 + flow_init 135 | 136 | flow_predictions = [] 137 | for itr in range(iters): 138 | coords1 = coords1.detach() # stop gradient 139 | 140 | corr_x = corr_fn_x(coords1) 141 | corr_y = corr_fn_y(coords1) 142 | corr = torch.cat((corr_x, corr_y), dim=1) # [B, 2(2R+1), H, W] 143 | 144 | flow = coords1 - coords0 145 | 146 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow, 147 | upsample=not test_mode or itr == iters - 1, 148 | ) 149 | 150 | coords1 = coords1 + delta_flow 151 | 152 | if test_mode: 153 | # only upsample the last iteration 154 | if itr == iters - 1: 155 | flow_up = self.learned_upflow(coords1 - coords0, up_mask) 156 | 157 | return coords1 - coords0, flow_up 158 | else: 159 | # upsample predictions 160 | flow_up = self.learned_upflow(coords1 - coords0, up_mask) 161 | flow_predictions.append(flow_up) 162 | 163 | return flow_predictions, attn_x_list, attn_y_list, coords1 - coords0 164 | 165 | 166 | def build_model(args): 167 | return Model(downsample_factor=args.downsample_factor, 168 | feature_channels=args.feature_channels, 169 | corr_radius=args.corr_radius, 170 | hidden_dim=args.hidden_dim, 171 | context_dim=args.context_dim, 172 | mixed_precision=args.mixed_precision, 173 | ) 174 | -------------------------------------------------------------------------------- /flow1d/position.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | 5 | 6 | class PositionEmbeddingSine(nn.Module): 7 | """ 8 | 
https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py 9 | This is a more standard version of the position embedding, very similar to the one 10 | used by the Attention is all you need paper, generalized to work on images. 11 | """ 12 | 13 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): 14 | super().__init__() 15 | self.num_pos_feats = num_pos_feats 16 | self.temperature = temperature 17 | self.normalize = normalize 18 | if scale is not None and normalize is False: 19 | raise ValueError("normalize should be True if scale is passed") 20 | if scale is None: 21 | scale = 2 * math.pi 22 | self.scale = scale 23 | 24 | def forward(self, x): 25 | # x = tensor_list.tensors # [B, C, H, W] 26 | # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 27 | b, c, h, w = x.size() 28 | mask = torch.ones((b, h, w), device=x.device) # [B, H, W] 29 | y_embed = mask.cumsum(1, dtype=torch.float32) 30 | x_embed = mask.cumsum(2, dtype=torch.float32) 31 | if self.normalize: 32 | eps = 1e-6 33 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 34 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 35 | 36 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 37 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 38 | 39 | pos_x = x_embed[:, :, :, None] / dim_t 40 | pos_y = y_embed[:, :, :, None] / dim_t 41 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 42 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 43 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 44 | return pos 45 | -------------------------------------------------------------------------------- /flow1d/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256, 8 | ): 9 | super(FlowHead, self).__init__() 10 | 11 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 12 | 13 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 14 | self.relu = nn.ReLU(inplace=True) 15 | 16 | def forward(self, x): 17 | out = self.conv2(self.relu(self.conv1(x))) 18 | 19 | return out 20 | 21 | 22 | class SepConvGRU(nn.Module): 23 | def __init__(self, hidden_dim=128, input_dim=192 + 128, 24 | kernel_size=5, 25 | ): 26 | padding = (kernel_size - 1) // 2 27 | 28 | super(SepConvGRU, self).__init__() 29 | self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) 30 | self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) 31 | self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) 32 | 33 | self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) 34 | self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) 35 | self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) 36 | 37 | def forward(self, h, x): 38 | # horizontal 39 | hx = torch.cat([h, x], dim=1) 40 | z = torch.sigmoid(self.convz1(hx)) 41 | r = torch.sigmoid(self.convr1(hx)) 42 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) 43 | h = (1 - z) * h + z * q 44 | 45 | # vertical 46 | hx = 
torch.cat([h, x], dim=1)
47 | z = torch.sigmoid(self.convz2(hx))
48 | r = torch.sigmoid(self.convr2(hx))
49 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
50 | h = (1 - z) * h + z * q
51 |
52 | return h
53 |
54 |
55 | class BasicMotionEncoder(nn.Module):
56 | def __init__(self, corr_channels=324,
57 | ):
58 | super(BasicMotionEncoder, self).__init__()
59 |
60 | self.convc1 = nn.Conv2d(corr_channels, 256, 1, padding=0)
61 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
62 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
63 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
64 | self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)
65 |
66 | def forward(self, flow, corr):
67 | cor = F.relu(self.convc1(corr))
68 | cor = F.relu(self.convc2(cor))
69 | flo = F.relu(self.convf1(flow))
70 | flo = F.relu(self.convf2(flo))
71 |
72 | cor_flo = torch.cat([cor, flo], dim=1)
73 | out = F.relu(self.conv(cor_flo))
74 | return torch.cat([out, flow], dim=1)
75 |
76 |
77 | class BasicUpdateBlock(nn.Module):
78 | def __init__(self, corr_channels=324,
79 | hidden_dim=128,
80 | context_dim=128,
81 | downsample_factor=8,
82 | learn_upsample=True,
83 | **kwargs,
84 | ):
85 | super(BasicUpdateBlock, self).__init__()
86 |
87 | self.encoder = BasicMotionEncoder(corr_channels=corr_channels)
88 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=context_dim + hidden_dim)
89 |
90 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256,
91 | )
92 |
93 | self.learn_upsample = learn_upsample
94 |
95 | if learn_upsample:
96 | self.mask = nn.Sequential(
97 | nn.Conv2d(hidden_dim, 256, 3, padding=1),
98 | nn.ReLU(inplace=True),
99 | nn.Conv2d(256, downsample_factor ** 2 * 9, 1, padding=0))
100 |
101 | def forward(self, net, inp, corr, flow, upsample=True,
102 | **kwargs,
103 | ):
104 | motion_features = self.encoder(flow, corr)
105 |
106 | inp = torch.cat([inp, motion_features], dim=1)
107 |
108 | net = self.gru(net, inp)
109 | delta_flow = self.flow_head(net)
110 |
111 | if self.learn_upsample and upsample:
112 | # scale mask to balance gradients following RAFT
113 | mask = .25 * self.mask(net)
114 | else:
115 | mask = None
116 | return net, mask, delta_flow
117 |
-------------------------------------------------------------------------------- /loss.py: --------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def criterion(flow_preds, flow_gt, valid, gamma=0.8, max_flow=400,
5 | ):
6 | """ Loss function defined over a sequence of flow predictions
7 | """
8 |
9 | n_predictions = len(flow_preds)
10 | flow_loss = 0.0
11 |
12 | # exclude invalid pixels and extremely large displacements
13 | mag = torch.sum(flow_gt ** 2, dim=1).sqrt()
14 | valid = (valid >= 0.5) & (mag < max_flow)
15 |
16 | for i in range(n_predictions):
17 | i_weight = gamma ** (n_predictions - i - 1)
18 | i_loss = (flow_preds[i] - flow_gt).abs()
19 |
20 | flow_loss += i_weight * (valid[:, None] * i_loss).mean()
21 |
22 | epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt()
23 | epe = epe.view(-1)[valid.view(-1)]
24 |
25 | metrics = {
26 | 'epe': epe.mean().item(),
27 | '1px': (epe > 1).float().mean().item(),
28 | '3px': (epe > 3).float().mean().item(),
29 | '5px': (epe > 5).float().mean().item(),
30 | }
31 |
32 | return flow_loss, metrics
33 |
-------------------------------------------------------------------------------- /main.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from
torch.utils.tensorboard import SummaryWriter 4 | 5 | import argparse 6 | import numpy as np 7 | import os 8 | 9 | from data import build_dataset 10 | 11 | from flow1d.flow1d import build_model 12 | from loss import criterion 13 | from evaluate import (validate_chairs, validate_sintel, validate_kitti, 14 | create_kitti_submission, create_sintel_submission, 15 | inference_on_dir, 16 | ) 17 | 18 | from utils.logger import Logger 19 | from utils import misc 20 | 21 | 22 | def get_args_parser(): 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoints/tmp') 26 | parser.add_argument('--eval', action='store_true') 27 | 28 | # Dataset 29 | parser.add_argument('--image_size', default=[368, 496], type=int, nargs='+') 30 | parser.add_argument('--stage', default='chairs', type=str) 31 | parser.add_argument('--max_flow', default=400, type=int) 32 | parser.add_argument('--padding_factor', default=8, type=int) 33 | parser.add_argument('--val_dataset', default='chairs', type=str, nargs='+') 34 | 35 | # Create Sintel and KITTI submission 36 | parser.add_argument('--submission', action='store_true', 37 | help='Create submission') 38 | parser.add_argument('--warm_start', action='store_true') 39 | parser.add_argument('--output_path', default='output', type=str) 40 | parser.add_argument('--save_vis_flow', action='store_true') 41 | parser.add_argument('--no_save_flo', action='store_true') 42 | 43 | # Inference on a directory 44 | parser.add_argument('--inference_dir', default=None, type=str) 45 | parser.add_argument('--dir_paired_data', action='store_true', 46 | help='Paired data in a dir instead of a sequence') 47 | parser.add_argument('--save_flo_flow', action='store_true') 48 | 49 | # Training 50 | parser.add_argument('--lr', default=4e-4, type=float) 51 | parser.add_argument('--lr_warmup', default=0.05, type=float, 52 | help='Percentage of lr warmup') 53 | parser.add_argument('--batch_size', default=12, type=int) 54 | parser.add_argument('--num_workers', default=4, type=int) 55 | parser.add_argument('--weight_decay', default=1e-4, type=float) 56 | parser.add_argument('--grad_clip', default=1.0, type=float) 57 | parser.add_argument('--num_steps', default=100000, type=int) 58 | parser.add_argument('--seed', default=326, type=int) 59 | parser.add_argument('--summary_freq', default=100, type=int) 60 | parser.add_argument('--val_freq', default=5000, type=int) 61 | parser.add_argument('--save_ckpt_freq', default=50000, type=int) 62 | parser.add_argument('--resume', default=None, type=str) 63 | parser.add_argument('--no_resume_optimizer', action='store_true') 64 | parser.add_argument('--no_latest_ckpt', action='store_true') 65 | parser.add_argument('--save_latest_ckpt_freq', default=1000, type=int) 66 | parser.add_argument('--freeze_bn', action='store_true') 67 | 68 | parser.add_argument('--train_iters', default=12, type=int) 69 | parser.add_argument('--val_iters', default=12, type=int) 70 | 71 | # Flow1D 72 | parser.add_argument('--downsample_factor', default=8, type=int) 73 | parser.add_argument('--feature_channels', default=256, type=int) 74 | parser.add_argument('--corr_radius', default=32, type=int) 75 | parser.add_argument('--hidden_dim', default=128, type=int) 76 | parser.add_argument('--context_dim', default=128, type=int) 77 | parser.add_argument('--gamma', default=0.8, type=float, 78 | help='Exponential weighting') 79 | 80 | parser.add_argument('--mixed_precision', action='store_true') 81 | 82 | # Distributed training 83 | 
parser.add_argument('--local_rank', default=0, type=int) 84 | 85 | # Misc 86 | parser.add_argument('--count_time', action='store_true') 87 | 88 | return parser 89 | 90 | 91 | def main(args): 92 | if not args.eval and not args.submission and args.inference_dir is None: 93 | print('PyTorch version:', torch.__version__) 94 | print(args) 95 | misc.save_args(args) 96 | misc.check_path(args.checkpoint_dir) 97 | misc.save_command(args.checkpoint_dir) 98 | 99 | misc.check_path(args.output_path) 100 | 101 | seed = args.seed 102 | torch.manual_seed(seed) 103 | np.random.seed(seed) 104 | 105 | torch.backends.cudnn.benchmark = True 106 | 107 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 108 | 109 | # model 110 | model = build_model(args).to(device) 111 | 112 | if not args.eval: 113 | print('Model definition:') 114 | print(model) 115 | 116 | if torch.cuda.device_count() > 1: 117 | print('Use %d GPUs' % torch.cuda.device_count()) 118 | model = torch.nn.DataParallel(model) 119 | 120 | model_without_ddp = model.module 121 | else: 122 | model_without_ddp = model 123 | 124 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 125 | print('Number of params:', num_params) 126 | if not args.eval and not args.submission and args.inference_dir is None: 127 | save_name = '%d_parameters' % num_params 128 | open(os.path.join(args.checkpoint_dir, save_name), 'a').close() 129 | 130 | # optimizer 131 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, 132 | weight_decay=args.weight_decay) 133 | 134 | start_epoch = 0 135 | start_step = 0 136 | 137 | # resume checkpoints 138 | if args.resume: 139 | print('Load checkpoint: %s' % args.resume) 140 | checkpoint = torch.load(args.resume) 141 | weights = checkpoint['model'] if 'model' in checkpoint else checkpoint 142 | model_without_ddp.load_state_dict(weights, strict=False) 143 | 144 | if 'optimizer' in checkpoint and 'step' in checkpoint and 'epoch' in checkpoint and not \ 145 | args.no_resume_optimizer: 146 | print('Load optimizer') 147 | optimizer.load_state_dict(checkpoint['optimizer']) 148 | start_epoch = checkpoint['epoch'] 149 | start_step = checkpoint['step'] 150 | 151 | print('start_epoch: %d, start_step: %d' % (start_epoch, start_step)) 152 | 153 | # evaluate 154 | if args.eval: 155 | if 'chairs' in args.val_dataset: 156 | validate_chairs(model_without_ddp, 157 | iters=args.val_iters, 158 | ) 159 | elif 'sintel' in args.val_dataset: 160 | validate_sintel(model_without_ddp, 161 | iters=args.val_iters, 162 | padding_factor=args.padding_factor, 163 | count_time=args.count_time, 164 | ) 165 | elif 'kitti' in args.val_dataset: 166 | validate_kitti(model_without_ddp, 167 | iters=args.val_iters, 168 | padding_factor=args.padding_factor, 169 | ) 170 | else: 171 | raise ValueError(f'Dataset type {args.val_dataset} is not supported') 172 | 173 | return 174 | 175 | # create sintel and kitti submission 176 | if args.submission: 177 | if args.val_dataset[0] == 'sintel': 178 | create_sintel_submission(model_without_ddp, 179 | iters=args.val_iters, 180 | warm_start=args.warm_start, 181 | output_path=args.output_path, 182 | padding_factor=args.padding_factor, 183 | save_vis_flow=args.save_vis_flow, 184 | no_save_flo=args.no_save_flo, 185 | ) 186 | elif args.val_dataset[0] == 'kitti': 187 | create_kitti_submission(model_without_ddp, 188 | iters=args.val_iters, 189 | output_path=args.output_path, 190 | padding_factor=args.padding_factor, 191 | save_vis_flow=args.save_vis_flow, 192 | ) 193 | else: 194 | raise 
ValueError(f'Not supported dataset for submission') 195 | 196 | return 197 | 198 | # inferece on a dir 199 | if args.inference_dir is not None: 200 | inference_on_dir(model_without_ddp, 201 | inference_dir=args.inference_dir, 202 | iters=args.val_iters, 203 | warm_start=args.warm_start, 204 | output_path=args.output_path, 205 | padding_factor=args.padding_factor, 206 | paired_data=args.dir_paired_data, 207 | save_flo_flow=args.save_flo_flow, 208 | ) 209 | 210 | return 211 | 212 | # train datset 213 | train_dataset = build_dataset(args) 214 | print('Number of training images:', len(train_dataset)) 215 | 216 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, 217 | shuffle=True, num_workers=args.num_workers, 218 | pin_memory=True, drop_last=True) 219 | 220 | last_epoch = start_step if args.resume and not args.no_resume_optimizer else -1 221 | lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps + 10, 222 | pct_start=args.lr_warmup, cycle_momentum=False, 223 | anneal_strategy='linear', 224 | last_epoch=last_epoch, 225 | ) 226 | 227 | if args.local_rank == 0: 228 | summary_writer = SummaryWriter(args.checkpoint_dir) 229 | logger = Logger(lr_scheduler, summary_writer, args.summary_freq, 230 | start_step=start_step) 231 | 232 | total_steps = start_step 233 | epoch = start_epoch 234 | print('Start training') 235 | while total_steps < args.num_steps: 236 | model.train() 237 | 238 | # freeze BN after pretraining on chairs 239 | if args.freeze_bn: 240 | model_without_ddp.freeze_bn() 241 | 242 | print('Start epoch %d' % (epoch + 1)) 243 | for i, sample in enumerate(train_loader): 244 | img1, img2, flow_gt, valid = [x.to(device) for x in sample] 245 | 246 | flow_preds = model(img1, img2, iters=args.train_iters)[0] 247 | 248 | loss, metrics = criterion(flow_preds, flow_gt, valid, 249 | gamma=args.gamma, 250 | max_flow=args.max_flow) 251 | 252 | optimizer.zero_grad() 253 | loss.backward() 254 | 255 | # gradient clipping 256 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 257 | 258 | optimizer.step() 259 | lr_scheduler.step() 260 | 261 | if args.local_rank == 0: 262 | logger.push(metrics) 263 | 264 | logger.add_image_summary(img1, img2, flow_preds, flow_gt) 265 | 266 | total_steps += 1 267 | 268 | if total_steps % args.save_ckpt_freq == 0 or total_steps == args.num_steps: 269 | if args.local_rank == 0: 270 | print('Save checkpoint at step: %d' % total_steps) 271 | checkpoint_path = os.path.join(args.checkpoint_dir, 'step_%06d.pth' % total_steps) 272 | torch.save({ 273 | 'model': model_without_ddp.state_dict() 274 | }, checkpoint_path) 275 | 276 | if total_steps % args.save_latest_ckpt_freq == 0: 277 | # Save lastest checkpoint after each epoch 278 | checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint_latest.pth') 279 | 280 | if args.local_rank == 0: 281 | print('Save latest checkpoint') 282 | torch.save({ 283 | 'model': model_without_ddp.state_dict(), 284 | 'optimizer': optimizer.state_dict(), 285 | 'step': total_steps, 286 | 'epoch': epoch, 287 | }, checkpoint_path) 288 | 289 | if total_steps % args.val_freq == 0: 290 | if args.local_rank == 0: 291 | print('Start validation') 292 | 293 | val_results = {} 294 | # Support validation on multiple datasets 295 | if 'chairs' in args.val_dataset: 296 | results_dict = validate_chairs(model_without_ddp, 297 | iters=args.val_iters, 298 | ) 299 | val_results.update(results_dict) 300 | if 'sintel' in args.val_dataset: 301 | results_dict = 
validate_sintel(model_without_ddp, 302 | iters=args.val_iters, 303 | padding_factor=args.padding_factor, 304 | ) 305 | val_results.update(results_dict) 306 | 307 | if 'kitti' in args.val_dataset: 308 | results_dict = validate_kitti(model_without_ddp, 309 | iters=args.val_iters, 310 | padding_factor=args.padding_factor, 311 | ) 312 | val_results.update(results_dict) 313 | 314 | logger.write_dict(val_results) 315 | 316 | # Save validation results 317 | val_file = os.path.join(args.checkpoint_dir, 'val_results.txt') 318 | with open(val_file, 'a') as f: 319 | f.write('step: %06d\t' % total_steps) 320 | # order of metrics 321 | metrics = ['chairs_epe', 'chairs_1px', 'clean_epe', 'clean_1px', 'final_epe', 'final_1px', 322 | 'kitti_epe', 'kitti_f1'] 323 | for metric in metrics: 324 | if metric in val_results.keys(): 325 | f.write('%s: %.3f\t' % (metric, val_results[metric])) 326 | f.write('\n') 327 | 328 | model.train() 329 | 330 | # freeze BN after pretraining on chairs 331 | if args.freeze_bn: 332 | model_without_ddp.freeze_bn() 333 | 334 | if total_steps >= args.num_steps: 335 | print('Training done') 336 | 337 | return 338 | 339 | epoch += 1 340 | 341 | 342 | if __name__ == '__main__': 343 | parser = get_args_parser() 344 | args = parser.parse_args() 345 | 346 | main(args) 347 | -------------------------------------------------------------------------------- /scripts/demo.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CUDA_VISIBLE_DEVICES=0 python main.py \ 4 | --resume pretrained/flow1d_highres-e0b98d7e.pth \ 5 | --val_iters 24 \ 6 | --inference_dir demo/dogs-jump \ 7 | --output_path output/flow1d-dogs-jump 8 | -------------------------------------------------------------------------------- /scripts/evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # evaluate chairs & things trained model on kitti (24 iters) 5 | CUDA_VISIBLE_DEVICES=0 python main.py \ 6 | --eval \ 7 | --val_dataset kitti \ 8 | --resume pretrained/flow1d_things-fd4bee1f.pth \ 9 | --val_iters 24 10 | 11 | 12 | # evaluate chairs & things trained model on sintel (32 iters) 13 | CUDA_VISIBLE_DEVICES=0 python main.py \ 14 | --eval \ 15 | --val_dataset sintel \ 16 | --resume pretrained/flow1d_things-fd4bee1f.pth \ 17 | --val_iters 32 18 | 19 | 20 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # can be trained on a single 32G V100 GPU 4 | 5 | # chairs 6 | CHECKPOINT_DIR=checkpoints/chairs-flow1d && \ 7 | mkdir -p ${CHECKPOINT_DIR} && \ 8 | CUDA_VISIBLE_DEVICES=0 python main.py \ 9 | --checkpoint_dir ${CHECKPOINT_DIR} \ 10 | --batch_size 12 \ 11 | --val_dataset chairs sintel kitti \ 12 | --val_iters 12 \ 13 | --lr 4e-4 \ 14 | --image_size 368 496 \ 15 | --summary_freq 100 \ 16 | --val_freq 10000 \ 17 | --save_ckpt_freq 5000 \ 18 | --save_latest_ckpt_freq 1000 \ 19 | --num_steps 100000 \ 20 | 2>&1 | tee ${CHECKPOINT_DIR}/train.log 21 | 22 | # things 23 | CHECKPOINT_DIR=checkpoints/things-flow1d && \ 24 | mkdir -p ${CHECKPOINT_DIR} && \ 25 | CUDA_VISIBLE_DEVICES=0 python main.py \ 26 | --stage things \ 27 | --resume checkpoints/chairs-flow1d/step_100000.pth \ 28 | --no_resume_optimizer \ 29 | --checkpoint_dir ${CHECKPOINT_DIR} \ 30 | --batch_size 6 \ 31 | --val_dataset sintel kitti \ 32 | --val_iters 12 \ 33 | --lr 1.25e-4 
\ 34 | --image_size 400 720 \ 35 | --freeze_bn \ 36 | --summary_freq 100 \ 37 | --val_freq 10000 \ 38 | --save_ckpt_freq 5000 \ 39 | --save_latest_ckpt_freq 1000 \ 40 | --num_steps 100000 \ 41 | 2>&1 | tee ${CHECKPOINT_DIR}/train.log 42 | 43 | # sintel 44 | CHECKPOINT_DIR=checkpoints/sintel-flow1d && \ 45 | mkdir -p ${CHECKPOINT_DIR} && \ 46 | CUDA_VISIBLE_DEVICES=0 python main.py \ 47 | --stage sintel \ 48 | --resume checkpoints/things-flow1d/step_100000.pth \ 49 | --no_resume_optimizer \ 50 | --checkpoint_dir ${CHECKPOINT_DIR} \ 51 | --batch_size 6 \ 52 | --val_dataset sintel kitti \ 53 | --val_iters 12 \ 54 | --lr 1.25e-4 \ 55 | --weight_decay 1e-5 \ 56 | --gamma 0.85 \ 57 | --image_size 368 960 \ 58 | --freeze_bn \ 59 | --summary_freq 100 \ 60 | --val_freq 10000 \ 61 | --save_ckpt_freq 5000 \ 62 | --save_latest_ckpt_freq 1000 \ 63 | --num_steps 100000 \ 64 | 2>&1 | tee ${CHECKPOINT_DIR}/train.log 65 | 66 | # kitti 67 | CHECKPOINT_DIR=checkpoints/kitti-flow1d && \ 68 | mkdir -p ${CHECKPOINT_DIR} && \ 69 | CUDA_VISIBLE_DEVICES=0 python main.py \ 70 | --stage kitti \ 71 | --resume checkpoints/sintel-flow1d/step_100000.pth \ 72 | --no_resume_optimizer \ 73 | --checkpoint_dir ${CHECKPOINT_DIR} \ 74 | --batch_size 6 \ 75 | --val_dataset kitti \ 76 | --val_iters 12 \ 77 | --lr 1e-4 \ 78 | --weight_decay 1e-5 \ 79 | --gamma 0.85 \ 80 | --image_size 320 1024 \ 81 | --freeze_bn \ 82 | --summary_freq 100 \ 83 | --val_freq 10000 \ 84 | --save_ckpt_freq 5000 \ 85 | --save_latest_ckpt_freq 1000 \ 86 | --num_steps 50000 \ 87 | 2>&1 | tee ${CHECKPOINT_DIR}/train.log 88 | 89 | -------------------------------------------------------------------------------- /utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-08-03 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | 21 | 22 | def make_colorwheel(): 23 | ''' 24 | Generates a color wheel for optical flow visualization as presented in: 25 | Baker et al. 
"A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 26 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 27 | According to the C++ source code of Daniel Scharstein 28 | According to the Matlab source code of Deqing Sun 29 | ''' 30 | 31 | RY = 15 32 | YG = 6 33 | GC = 4 34 | CB = 11 35 | BM = 13 36 | MR = 6 37 | 38 | ncols = RY + YG + GC + CB + BM + MR 39 | colorwheel = np.zeros((ncols, 3)) 40 | col = 0 41 | 42 | # RY 43 | colorwheel[0:RY, 0] = 255 44 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 45 | col = col + RY 46 | # YG 47 | colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 48 | colorwheel[col:col + YG, 1] = 255 49 | col = col + YG 50 | # GC 51 | colorwheel[col:col + GC, 1] = 255 52 | colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 53 | col = col + GC 54 | # CB 55 | colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 56 | colorwheel[col:col + CB, 2] = 255 57 | col = col + CB 58 | # BM 59 | colorwheel[col:col + BM, 2] = 255 60 | colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 61 | col = col + BM 62 | # MR 63 | colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 64 | colorwheel[col:col + MR, 0] = 255 65 | return colorwheel 66 | 67 | 68 | def flow_compute_color(u, v, convert_to_bgr=False): 69 | ''' 70 | Applies the flow color wheel to (possibly clipped) flow components u and v. 71 | According to the C++ source code of Daniel Scharstein 72 | According to the Matlab source code of Deqing Sun 73 | :param u: np.ndarray, input horizontal flow 74 | :param v: np.ndarray, input vertical flow 75 | :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB 76 | :return: 77 | ''' 78 | 79 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 80 | 81 | colorwheel = make_colorwheel() # shape [55x3] 82 | ncols = colorwheel.shape[0] 83 | 84 | rad = np.sqrt(np.square(u) + np.square(v)) 85 | a = np.arctan2(-v, -u) / np.pi 86 | 87 | fk = (a + 1) / 2 * (ncols - 1) + 1 88 | k0 = np.floor(fk).astype(np.int32) 89 | k1 = k0 + 1 90 | k1[k1 == ncols] = 1 91 | f = fk - k0 92 | 93 | for i in range(colorwheel.shape[1]): 94 | tmp = colorwheel[:, i] 95 | col0 = tmp[k0] / 255.0 96 | col1 = tmp[k1] / 255.0 97 | col = (1 - f) * col0 + f * col1 98 | 99 | idx = (rad <= 1) 100 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 101 | col[~idx] = col[~idx] * 0.75 # out of range? 
102 | 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2 - i if convert_to_bgr else i 105 | flow_image[:, :, ch_idx] = np.floor(255 * col) 106 | 107 | return flow_image 108 | 109 | 110 | def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False): 111 | ''' 112 | Expects a two dimensional flow image of shape [H,W,2] 113 | According to the C++ source code of Daniel Scharstein 114 | According to the Matlab source code of Deqing Sun 115 | :param flow_uv: np.ndarray of shape [H,W,2] 116 | :param clip_flow: float, maximum clipping value for flow 117 | :return: 118 | ''' 119 | 120 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 121 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 122 | 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | 126 | u = flow_uv[:, :, 0] 127 | v = flow_uv[:, :, 1] 128 | 129 | rad = np.sqrt(np.square(u) + np.square(v)) 130 | rad_max = np.max(rad) 131 | 132 | epsilon = 1e-5 133 | u = u / (rad_max + epsilon) 134 | v = v / (rad_max + epsilon) 135 | 136 | return flow_compute_color(u, v, convert_to_bgr) 137 | 138 | 139 | UNKNOWN_FLOW_THRESH = 1e7 140 | SMALLFLOW = 0.0 141 | LARGEFLOW = 1e8 142 | 143 | 144 | def make_color_wheel(): 145 | """ 146 | Generate color wheel according Middlebury color code 147 | :return: Color wheel 148 | """ 149 | RY = 15 150 | YG = 6 151 | GC = 4 152 | CB = 11 153 | BM = 13 154 | MR = 6 155 | 156 | ncols = RY + YG + GC + CB + BM + MR 157 | 158 | colorwheel = np.zeros([ncols, 3]) 159 | 160 | col = 0 161 | 162 | # RY 163 | colorwheel[0:RY, 0] = 255 164 | colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY)) 165 | col += RY 166 | 167 | # YG 168 | colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG)) 169 | colorwheel[col:col + YG, 1] = 255 170 | col += YG 171 | 172 | # GC 173 | colorwheel[col:col + GC, 1] = 255 174 | colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC)) 175 | col += GC 176 | 177 | # CB 178 | colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB)) 179 | colorwheel[col:col + CB, 2] = 255 180 | col += CB 181 | 182 | # BM 183 | colorwheel[col:col + BM, 2] = 255 184 | colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM)) 185 | col += + BM 186 | 187 | # MR 188 | colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) 189 | colorwheel[col:col + MR, 0] = 255 190 | 191 | return colorwheel 192 | 193 | 194 | def compute_color(u, v): 195 | """ 196 | compute optical flow color map 197 | :param u: optical flow horizontal map 198 | :param v: optical flow vertical map 199 | :return: optical flow in color code 200 | """ 201 | [h, w] = u.shape 202 | img = np.zeros([h, w, 3]) 203 | nanIdx = np.isnan(u) | np.isnan(v) 204 | u[nanIdx] = 0 205 | v[nanIdx] = 0 206 | 207 | colorwheel = make_color_wheel() 208 | ncols = np.size(colorwheel, 0) 209 | 210 | rad = np.sqrt(u ** 2 + v ** 2) 211 | 212 | a = np.arctan2(-v, -u) / np.pi 213 | 214 | fk = (a + 1) / 2 * (ncols - 1) + 1 215 | 216 | k0 = np.floor(fk).astype(int) 217 | 218 | k1 = k0 + 1 219 | k1[k1 == ncols + 1] = 1 220 | f = fk - k0 221 | 222 | for i in range(0, np.size(colorwheel, 1)): 223 | tmp = colorwheel[:, i] 224 | col0 = tmp[k0 - 1] / 255 225 | col1 = tmp[k1 - 1] / 255 226 | col = (1 - f) * col0 + f * col1 227 | 228 | idx = rad <= 1 229 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 230 | notidx = np.logical_not(idx) 231 | 232 | col[notidx] *= 0.75 
233 | img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx))) 234 | 235 | return img 236 | 237 | 238 | # from https://github.com/gengshan-y/VCN 239 | def flow_to_image(flow): 240 | """ 241 | Convert flow into middlebury color code image 242 | :param flow: optical flow map 243 | :return: optical flow image in middlebury color 244 | """ 245 | u = flow[:, :, 0] 246 | v = flow[:, :, 1] 247 | 248 | maxu = -999. 249 | maxv = -999. 250 | minu = 999. 251 | minv = 999. 252 | 253 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) 254 | u[idxUnknow] = 0 255 | v[idxUnknow] = 0 256 | 257 | maxu = max(maxu, np.max(u)) 258 | minu = min(minu, np.min(u)) 259 | 260 | maxv = max(maxv, np.max(v)) 261 | minv = min(minv, np.min(v)) 262 | 263 | rad = np.sqrt(u ** 2 + v ** 2) 264 | maxrad = max(-1, np.max(rad)) 265 | 266 | u = u / (maxrad + np.finfo(float).eps) 267 | v = v / (maxrad + np.finfo(float).eps) 268 | 269 | img = compute_color(u, v) 270 | 271 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) 272 | img[idx] = 0 273 | 274 | return np.uint8(img) 275 | 276 | 277 | def save_vis_flow_tofile(flow, output_path): 278 | vis_flow = flow_to_image(flow) 279 | from PIL import Image 280 | img = Image.fromarray(vis_flow) 281 | img.save(output_path) 282 | 283 | 284 | def flow_tensor_to_image(flow): 285 | """Used for tensorboard visualization""" 286 | flow = flow.permute(1, 2, 0) # [H, W, 2] 287 | flow = flow.detach().cpu().numpy() 288 | flow = flow_to_image(flow) # [H, W, 3] 289 | flow = np.transpose(flow, (2, 0, 1)) # [3, H, W] 290 | 291 | return flow 292 | -------------------------------------------------------------------------------- /utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | import cv2 6 | 7 | TAG_CHAR = np.array([202021.25], np.float32) 8 | 9 | 10 | def readFlow(fn): 11 | """ Read .flo file in Middlebury format""" 12 | # Code adapted from: 13 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 14 | 15 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 16 | # print 'fn = %s'%(fn) 17 | with open(fn, 'rb') as f: 18 | magic = np.fromfile(f, np.float32, count=1) 19 | if 202021.25 != magic: 20 | print('Magic number incorrect. 
Invalid .flo file') 21 | return None 22 | else: 23 | w = np.fromfile(f, np.int32, count=1) 24 | h = np.fromfile(f, np.int32, count=1) 25 | # print 'Reading %d x %d flo file\n' % (w, h) 26 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) 27 | # Reshape data into 3D array (columns, rows, bands) 28 | # The reshape here is for visualization, the original code is (w,h,2) 29 | return np.resize(data, (int(h), int(w), 2)) 30 | 31 | 32 | def readPFM(file): 33 | file = open(file, 'rb') 34 | 35 | color = None 36 | width = None 37 | height = None 38 | scale = None 39 | endian = None 40 | 41 | header = file.readline().rstrip() 42 | if header == b'PF': 43 | color = True 44 | elif header == b'Pf': 45 | color = False 46 | else: 47 | raise Exception('Not a PFM file.') 48 | 49 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 50 | if dim_match: 51 | width, height = map(int, dim_match.groups()) 52 | else: 53 | raise Exception('Malformed PFM header.') 54 | 55 | scale = float(file.readline().rstrip()) 56 | if scale < 0: # little-endian 57 | endian = '<' 58 | scale = -scale 59 | else: 60 | endian = '>' # big-endian 61 | 62 | data = np.fromfile(file, endian + 'f') 63 | shape = (height, width, 3) if color else (height, width) 64 | 65 | data = np.reshape(data, shape) 66 | data = np.flipud(data) 67 | return data 68 | 69 | 70 | def writeFlow(filename, uv, v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert (uv.ndim == 3) 81 | assert (uv.shape[2] == 2) 82 | u = uv[:, :, 0] 83 | v = uv[:, :, 1] 84 | else: 85 | u = uv 86 | 87 | assert (u.shape == v.shape) 88 | height, width = u.shape 89 | f = open(filename, 'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width * nBands)) 96 | tmp[:, np.arange(width) * 2] = u 97 | tmp[:, np.arange(width) * 2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) 104 | flow = flow[:, :, ::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2 ** 15) / 64.0 107 | return flow, valid 108 | 109 | 110 | def readDispKITTI(filename): 111 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 112 | valid = disp > 0.0 113 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 114 | return flow, valid 115 | 116 | 117 | def writeFlowKITTI(filename, uv): 118 | uv = 64.0 * uv + 2 ** 15 119 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 120 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 121 | cv2.imwrite(filename, uv[..., ::-1]) 122 | 123 | 124 | def read_gen(file_name, pil=False): 125 | ext = splitext(file_name)[-1] 126 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 127 | return Image.open(file_name) 128 | elif ext == '.bin' or ext == '.raw': 129 | return np.load(file_name) 130 | elif ext == '.flo': 131 | return readFlow(file_name).astype(np.float32) 132 | elif ext == '.pfm': 133 | flow = readPFM(file_name).astype(np.float32) 134 | if len(flow.shape) == 2: 135 | return flow 136 | else: 137 | return flow[:, :, :-1] 138 | return [] 139 | 
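# Usage sketch (illustration only; the temporary file name and the random flow
# values are arbitrary assumptions): round-trip a flow field through the
# Middlebury .flo writer/reader above and check that it is preserved. As noted
# in readFlow, this assumes a little-endian machine.
if __name__ == '__main__':
    import os
    import tempfile

    flow = np.random.randn(32, 48, 2).astype(np.float32)
    path = os.path.join(tempfile.gettempdir(), 'flow1d_example.flo')

    writeFlow(path, flow)
    recovered = readFlow(path)

    print(recovered.shape)                 # (32, 48, 2)
    print(np.abs(recovered - flow).max())  # 0.0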
-------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils.flow_viz import flow_tensor_to_image 4 | 5 | 6 | class Logger: 7 | def __init__(self, lr_scheduler, 8 | summary_writer, 9 | summary_freq=100, 10 | start_step=0, 11 | ): 12 | self.lr_scheduler = lr_scheduler 13 | self.total_steps = start_step 14 | self.running_loss = {} 15 | self.summary_writer = summary_writer 16 | self.summary_freq = summary_freq 17 | 18 | def print_training_status(self, mode='train'): 19 | print('step: %06d \t epe: %.3f' % (self.total_steps, self.running_loss['epe'] / self.summary_freq)) 20 | 21 | for k in self.running_loss: 22 | self.summary_writer.add_scalar(mode + '/' + k, 23 | self.running_loss[k] / self.summary_freq, self.total_steps) 24 | self.running_loss[k] = 0.0 25 | 26 | def lr_summary(self): 27 | lr = self.lr_scheduler.get_last_lr()[0] 28 | self.summary_writer.add_scalar('lr', lr, self.total_steps) 29 | 30 | def add_image_summary(self, img1, img2, flow_preds, flow_gt, mode='train', 31 | pred_bidirectional_flow=False): 32 | if self.total_steps % self.summary_freq == 0: 33 | img_concat = torch.cat((img1[0].detach().cpu(), img2[0].detach().cpu()), dim=-1) 34 | img_concat = img_concat.type(torch.uint8) # convert to uint8 to visualize in tensorboard 35 | 36 | flow_pred = flow_tensor_to_image(flow_preds[-1][0]) 37 | forward_flow_gt = flow_tensor_to_image(flow_gt[0]) 38 | flow_concat = torch.cat((torch.from_numpy(flow_pred), 39 | torch.from_numpy(forward_flow_gt)), dim=-1) 40 | 41 | concat = torch.cat((img_concat, flow_concat), dim=-2) 42 | 43 | self.summary_writer.add_image(mode + '/img_pred_gt', concat, self.total_steps) 44 | 45 | def add_init_flow_summary(self, init_flow, mode='train', tag='init_flow'): 46 | if self.total_steps % self.summary_freq == 0: 47 | init_flow = flow_tensor_to_image(init_flow[0]) 48 | init_flow = torch.from_numpy(init_flow) 49 | 50 | self.summary_writer.add_image(mode + '/' + tag, init_flow, self.total_steps) 51 | 52 | def push(self, metrics, mode='train'): 53 | self.total_steps += 1 54 | 55 | self.lr_summary() 56 | 57 | for key in metrics: 58 | if key not in self.running_loss: 59 | self.running_loss[key] = 0.0 60 | 61 | self.running_loss[key] += metrics[key] 62 | 63 | if self.total_steps % self.summary_freq == 0: 64 | self.print_training_status(mode) 65 | self.running_loss = {} 66 | 67 | def write_dict(self, results): 68 | for key in results: 69 | tag = key.split('_')[0] 70 | tag = tag + '/' + key 71 | self.summary_writer.add_scalar(tag, results[key], self.total_steps) 72 | 73 | def close(self): 74 | self.summary_writer.close() 75 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import sys 4 | import json 5 | 6 | 7 | def read_text_lines(filepath): 8 | with open(filepath, 'r') as f: 9 | lines = f.readlines() 10 | lines = [l.rstrip() for l in lines] 11 | return lines 12 | 13 | 14 | def check_path(path): 15 | if not os.path.exists(path): 16 | os.makedirs(path, exist_ok=True) # explicitly set exist_ok when multi-processing 17 | 18 | 19 | def save_command(save_path, filename='command_train.txt'): 20 | check_path(save_path) 21 | command = sys.argv 22 | save_file = os.path.join(save_path, filename) 23 | # Save all training commands when resuming training 24 | with 
open(save_file, 'a') as f: 25 | f.write(' '.join(command)) 26 | f.write('\n\n') 27 | 28 | 29 | def save_args(args, filename='args.json'): 30 | args_dict = vars(args) 31 | check_path(args.checkpoint_dir) 32 | save_path = os.path.join(args.checkpoint_dir, filename) 33 | 34 | # Save all training args when resuming training 35 | with open(save_path, 'a') as f: 36 | json.dump(args_dict, f, indent=4, sort_keys=False) 37 | f.write('\n\n') 38 | 39 | 40 | def int_list(s): 41 | """Convert string to int list""" 42 | return [int(x) for x in s.split(',')] 43 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | 10 | def __init__(self, dims, mode='sintel', padding_factor=8): 11 | self.ht, self.wd = dims[-2:] 12 | pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor 13 | pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor 14 | if mode == 'sintel': 15 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2] 16 | else: 17 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] 18 | 19 | def pad(self, *inputs): 20 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 21 | 22 | def unpad(self, x): 23 | ht, wd = x.shape[-2:] 24 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] 25 | return x[..., c[0]:c[1], c[2]:c[3]] 26 | 27 | 28 | def forward_interpolate(flow): 29 | flow = flow.detach().cpu().numpy() # [2, H, W] 30 | dx, dy = flow[0], flow[1] 31 | 32 | ht, wd = dx.shape 33 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 34 | 35 | x1 = x0 + dx 36 | y1 = y0 + dy 37 | 38 | x1 = x1.reshape(-1) 39 | y1 = y1.reshape(-1) 40 | dx = dx.reshape(-1) 41 | dy = dy.reshape(-1) 42 | 43 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 44 | x1 = x1[valid] 45 | y1 = y1[valid] 46 | dx = dx[valid] 47 | dy = dy[valid] 48 | 49 | flow_x = interpolate.griddata( 50 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 51 | 52 | flow_y = interpolate.griddata( 53 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 54 | 55 | flow = np.stack([flow_x, flow_y], axis=0) 56 | return torch.from_numpy(flow).float() 57 | 58 | 59 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 60 | """ Wrapper for grid_sample, uses pixel coordinates """ 61 | if coords.size(-1) != 2: # [B, 2, H, W] -> [B, H, W, 2] 62 | coords = coords.permute(0, 2, 3, 1) 63 | 64 | H, W = img.shape[-2:] 65 | # H = height if height is not None else img.shape[-2] 66 | # W = width if width is not None else img.shape[-1] 67 | 68 | xgrid, ygrid = coords.split([1, 1], dim=-1) 69 | 70 | # To handle H or W equals to 1 by explicitly defining height and width 71 | if H == 1: 72 | assert ygrid.abs().max() < 1e-8 73 | H = 10 74 | if W == 1: 75 | assert xgrid.abs().max() < 1e-8 76 | W = 10 77 | 78 | xgrid = 2 * xgrid / (W - 1) - 1 79 | ygrid = 2 * ygrid / (H - 1) - 1 80 | 81 | grid = torch.cat([xgrid, ygrid], dim=-1) 82 | img = F.grid_sample(img, grid, mode=mode, align_corners=True) 83 | 84 | if mask: 85 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 86 | return img, mask.squeeze(-1).float() 87 | 88 | return img 89 | 90 | 91 | def coords_grid(batch, ht, wd, 
normalize=False): 92 | if normalize: # [-1, 1] 93 | coords = torch.meshgrid(2 * torch.arange(ht) / (ht - 1) - 1, 94 | 2 * torch.arange(wd) / (wd - 1) - 1) 95 | else: 96 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 97 | coords = torch.stack(coords[::-1], dim=0).float() 98 | return coords[None].repeat(batch, 1, 1, 1) # [B, 2, H, W] 99 | 100 | 101 | def coords_grid_np(h, w): # used for accumulating high speed sintel flow data 102 | coords = np.meshgrid(np.arange(h, dtype=np.float32), 103 | np.arange(w, dtype=np.float32), indexing='ij') 104 | coords = np.stack(coords[::-1], axis=-1) # [H, W, 2] 105 | 106 | return coords 107 | 108 | 109 | def normalize_coords(grid): 110 | """Normalize coordinates of image scale to [-1, 1] 111 | Args: 112 | grid: [B, 2, H, W] 113 | """ 114 | assert grid.size(1) == 2 115 | h, w = grid.size()[2:] 116 | grid[:, 0, :, :] = 2 * (grid[:, 0, :, :].clone() / (w - 1)) - 1 # x: [-1, 1] 117 | grid[:, 1, :, :] = 2 * (grid[:, 1, :, :].clone() / (h - 1)) - 1 # y: [-1, 1] 118 | # grid = grid.permute((0, 2, 3, 1)) # [B, H, W, 2] 119 | return grid 120 | 121 | 122 | def flow_warp(feature, flow, mask=False): 123 | b, c, h, w = feature.size() 124 | assert flow.size(1) == 2 125 | 126 | grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] 127 | 128 | return bilinear_sampler(feature, grid, mask=mask) 129 | 130 | 131 | def upflow8(flow, mode='bilinear'): 132 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 133 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 134 | 135 | 136 | def bilinear_upflow(flow, scale_factor=8): 137 | assert flow.size(1) == 2 138 | flow = F.interpolate(flow, scale_factor=scale_factor, 139 | mode='bilinear', align_corners=True) * scale_factor 140 | 141 | return flow 142 | 143 | 144 | def upsample_flow(flow, img): 145 | if flow.size(-1) != img.size(-1): 146 | scale_factor = img.size(-1) / flow.size(-1) 147 | flow = F.interpolate(flow, size=img.size()[-2:], 148 | mode='bilinear', align_corners=True) * scale_factor 149 | return flow 150 | 151 | 152 | def count_parameters(model): 153 | num = sum(p.numel() for p in model.parameters() if p.requires_grad) 154 | return num 155 | 156 | 157 | def set_bn_eval(m): 158 | classname = m.__class__.__name__ 159 | if classname.find('BatchNorm') != -1: 160 | m.eval() 161 | --------------------------------------------------------------------------------
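# Usage sketch for utils/utils.py (illustration only; the KITTI-like image size is
# an arbitrary assumption, and the snippet is expected to be run from the repository
# root so that `utils` is importable): pad an image pair so that height and width
# are divisible by 8, as done at evaluation time, then crop a prediction back to
# the original resolution.
if __name__ == '__main__':
    import torch
    from utils.utils import InputPadder

    img1 = torch.randn(1, 3, 375, 1242)
    img2 = torch.randn(1, 3, 375, 1242)

    padder = InputPadder(img1.shape, mode='kitti', padding_factor=8)
    img1, img2 = padder.pad(img1, img2)
    print(img1.shape)  # torch.Size([1, 3, 376, 1248])

    flow_pred = torch.zeros(1, 2, img1.shape[-2], img1.shape[-1])
    print(padder.unpad(flow_pred).shape)  # torch.Size([1, 2, 375, 1242])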