├── LICENSE
├── README.md
├── assets
│   ├── flow1d.png
│   └── teaser.png
├── data
│   ├── __init__.py
│   ├── chairs_split.txt
│   ├── datasets.py
│   └── transforms.py
├── demo
│   └── dogs-jump
│       ├── 00033.jpg
│       ├── 00034.jpg
│       ├── 00035.jpg
│       └── 00036.jpg
├── environment.yml
├── evaluate.py
├── flow1d
│   ├── __init__.py
│   ├── attention.py
│   ├── correlation.py
│   ├── extractor.py
│   ├── flow1d.py
│   ├── position.py
│   └── update.py
├── loss.py
├── main.py
├── scripts
│   ├── demo.sh
│   ├── evaluate.sh
│   └── train.sh
└── utils
    ├── flow_viz.py
    ├── frame_utils.py
    ├── logger.py
    ├── misc.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Haofei Xu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Flow1D
2 |
3 | Official PyTorch implementation of paper:
4 |
5 | [**High-Resolution Optical Flow from 1D Attention and Correlation**](https://arxiv.org/abs/2104.13918), **ICCV 2021, Oral**
6 |
7 | Authors: [Haofei Xu](https://haofeixu.github.io/), [Jiaolong Yang](https://jlyang.org/), [Jianfei Cai](https://jianfei-cai.github.io/), [Juyong Zhang](http://staff.ustc.edu.cn/~juyong/), [Xin Tong](https://scholar.google.com/citations?user=P91a-UQAAAAJ&hl=en&oi=ao)
8 |
9 | **11/15/2022 Update: Check out our new work: [Unifying Flow, Stereo and Depth Estimation](https://haofeixu.github.io/unimatch/) and code: [unimatch](https://github.com/autonomousvision/unimatch) for estimating optical flow with our new GMFlow model. [9 pretrained GMFlow models](https://github.com/autonomousvision/unimatch/blob/master/MODEL_ZOO.md) with different speed-accuracy trade-offs are also released. Check out our [Colab](https://colab.research.google.com/drive/1r5m-xVy3Kw60U-m5VB-aQ98oqqg_6cab?usp=sharing) and [HuggingFace](https://huggingface.co/spaces/haofeixu/unimatch) demos to play with GMFlow in your browser!**
10 |
11 | We enabled **4K resolution** optical flow estimation by factorizing 2D optical flow with 1D attention and 1D correlation.
12 |
13 |
14 | ![teaser](assets/teaser.png)
15 |
16 |
17 |
18 |
19 |
20 | The full framework:
21 |
22 |
23 |
24 |
25 | ![Flow1D framework](assets/flow1d.png)
26 |
27 |
28 |
29 |
30 |
31 |
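Concretely, the 1D correlation part of the framework can be sketched in a few lines of PyTorch (shapes follow `corr_x`/`corr_y` in [flow1d/correlation.py](flow1d/correlation.py); this is an illustration of the idea, not the full model):

```python
import torch

def factorized_correlation(feature1, feature2):
    # feature1, feature2: [B, C, H, W] feature maps from the shared encoder
    b, c, h, w = feature1.shape
    scale = c ** 0.5

    # 1D correlation along x: each pixel against all W positions in its row -> [B, H, W, W]
    corr_x = torch.matmul(feature1.permute(0, 2, 3, 1),          # [B, H, W, C]
                          feature2.permute(0, 2, 1, 3)) / scale  # [B, H, C, W]

    # 1D correlation along y: each pixel against all H positions in its column -> [B, W, H, H]
    corr_y = torch.matmul(feature1.permute(0, 3, 2, 1),          # [B, W, H, C]
                          feature2.permute(0, 3, 1, 2)) / scale  # [B, W, C, H]

    # the two 1D cost volumes take O(H*W*(H+W)) memory instead of O((H*W)^2) for a full
    # 2D cost volume, which is what makes high-resolution (4K) inference feasible
    return corr_x, corr_y
```
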
32 | ## Installation
33 |
34 | Our code is based on PyTorch 1.7.1, CUDA 10.2 and Python 3.7. Higher PyTorch versions should also work well.
35 |
36 | We recommend using [conda](https://www.anaconda.com/distribution/) for installation:
37 |
38 | ```
39 | conda env create -f environment.yml
40 | conda activate flow1d
41 | ```
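
You can optionally check that PyTorch and the CUDA runtime are visible from the new environment (the printed versions depend on your setup):

```shell
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```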
42 |
43 | ## Demos
44 |
45 | All pretrained models can be downloaded from [Google Drive](https://drive.google.com/file/d/1IzcmvxpY90DuXYkGkwitxslO1Psq52OI/view?usp=sharing).
46 |
47 |
48 |
49 | You can run a trained model on a sequence of images and visualize the results (as shown in [scripts/demo.sh](scripts/demo.sh)):
50 |
51 | ```
52 | CUDA_VISIBLE_DEVICES=0 python main.py \
53 | --resume pretrained/flow1d_highres-e0b98d7e.pth \
54 | --val_iters 24 \
55 | --inference_dir demo/dogs-jump \
56 | --output_path output/flow1d-dogs-jump
57 | ```
58 |
59 |
60 |
61 | ## Datasets
62 |
63 | The datasets used to train and evaluate Flow1D are as follows:
64 |
65 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
66 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
67 | * [Sintel](http://sintel.is.tue.mpg.de/)
68 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
69 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/)
70 |
71 | By default, the dataloader [datasets.py](data/datasets.py) assumes the datasets are located in the `datasets` folder and organized as follows:
72 |
73 | ```
74 | datasets
75 | ├── FlyingChairs_release
76 | │ └── data
77 | ├── FlyingThings3D
78 | │ ├── frames_cleanpass
79 | │ ├── frames_finalpass
80 | │ └── optical_flow
81 | ├── HD1K
82 | │ ├── hd1k_challenge
83 | │ ├── hd1k_flow_gt
84 | │ ├── hd1k_flow_uncertainty
85 | │ └── hd1k_input
86 | ├── KITTI
87 | │ ├── testing
88 | │ └── training
89 | ├── Sintel
90 | │ ├── test
91 | │ └── training
92 | ```
93 |
94 | It is recommended to symlink your dataset root to `datasets`:
95 |
96 | ```shell
97 | ln -s $YOUR_DATASET_ROOT datasets
98 | ```
99 |
100 | Otherwise, you may need to change the corresponding paths in [datasets.py](data/datasets.py).
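
Each dataset class in [datasets.py](data/datasets.py) also accepts a `root` argument, so non-default locations can be passed explicitly. A minimal sketch (the paths are illustrative):

```python
from data import FlyingChairs, KITTI

# point individual datasets at custom locations instead of the default `datasets/` layout
chairs = FlyingChairs(aug_params=None, split='training', root='/path/to/FlyingChairs_release/data')
kitti = KITTI(aug_params=None, split='training', root='/path/to/KITTI')
```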
101 |
102 |
103 |
104 | ## Evaluation
105 |
106 | You can evaluate a trained Flow1D model by running:
107 |
108 | ```
109 | CUDA_VISIBLE_DEVICES=0 python main.py --eval --val_dataset kitti --resume pretrained/flow1d_things-fd4bee1f.pth --val_iters 24
110 | ```
111 |
112 | More evaluation scripts can be found in [scripts/evaluate.sh](scripts/evaluate.sh).
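
For example, a Sintel evaluation follows the same pattern by switching the dataset and checkpoint. The exact names accepted by `--val_dataset` are defined in [main.py](main.py) and used in [scripts/evaluate.sh](scripts/evaluate.sh); `sintel` below is an assumption based on the `validate_sintel` function in [evaluate.py](evaluate.py):

```
CUDA_VISIBLE_DEVICES=0 python main.py --eval --val_dataset sintel --resume pretrained/flow1d_things-fd4bee1f.pth --val_iters 32
```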
113 |
114 |
115 |
116 | ## Training
117 |
118 | All training scripts on FlyingChairs, FlyingThings3D, Sintel and KITTI datasets can be found in [scripts/train.sh](scripts/train.sh).
119 |
120 | Note that our Flow1D model can be trained on a single 32GB V100 GPU. You may need to tune the number of GPUs used for training according to your hardware.
121 |
122 |
123 |
124 | We support using TensorBoard to monitor and visualize the training process. You can first start a TensorBoard session with
125 |
126 | ```shell
127 | tensorboard --logdir checkpoints
128 | ```
129 |
130 | and then access [http://localhost:6006](http://localhost:6006) in your browser.
131 |
132 |
133 |
134 | ## Citation
135 |
136 | If you find our work useful in your research, please consider citing our paper:
137 |
138 | ```
139 | @inproceedings{xu2021high,
140 | title={High-Resolution Optical Flow from 1D Attention and Correlation},
141 | author={Xu, Haofei and Yang, Jiaolong and Cai, Jianfei and Zhang, Juyong and Tong, Xin},
142 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
143 | pages={10498--10507},
144 | year={2021}
145 | }
146 | ```
147 |
148 |
149 |
150 | ## Acknowledgements
151 |
152 | This project is heavily based on [RAFT](https://github.com/princeton-vl/RAFT). We thank the original authors for their excellent work.
153 |
154 |
155 |
156 |
--------------------------------------------------------------------------------
/assets/flow1d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haofeixu/flow1d/ece861d2136e2eb2e99a9db71794d82c5782dbcb/assets/flow1d.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haofeixu/flow1d/ece861d2136e2eb2e99a9db71794d82c5782dbcb/assets/teaser.png
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import build_dataset
2 | from .datasets import (FlyingChairs,
3 | FlyingThings3D,
4 | MpiSintel,
5 | KITTI,
6 | HD1K,
7 | )
8 |
--------------------------------------------------------------------------------
/data/datasets.py:
--------------------------------------------------------------------------------
1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2 |
3 | import numpy as np
4 | import torch
5 | import torch.utils.data as data
6 |
7 | import os
8 | import random
9 | from glob import glob
10 | import os.path as osp
11 |
12 | from utils import frame_utils
13 | from data.transforms import FlowAugmentor, SparseFlowAugmentor
14 |
15 |
16 | class FlowDataset(data.Dataset):
17 | def __init__(self, aug_params=None, sparse=False,
18 | ):
19 | self.augmentor = None
20 | self.sparse = sparse
21 |
22 | if aug_params is not None:
23 | if sparse:
24 | self.augmentor = SparseFlowAugmentor(**aug_params)
25 | else:
26 | self.augmentor = FlowAugmentor(**aug_params)
27 |
28 | self.is_test = False
29 | self.init_seed = False
30 | self.flow_list = []
31 | self.image_list = []
32 | self.extra_info = []
33 |
34 | def __getitem__(self, index):
35 |
36 | if self.is_test:
37 | img1 = frame_utils.read_gen(self.image_list[index][0])
38 | img2 = frame_utils.read_gen(self.image_list[index][1])
39 | img1 = np.array(img1).astype(np.uint8)[..., :3]
40 | img2 = np.array(img2).astype(np.uint8)[..., :3]
41 |
42 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
43 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
44 |
45 | return img1, img2, self.extra_info[index]
46 |
47 | if not self.init_seed:
48 | worker_info = torch.utils.data.get_worker_info()
49 | if worker_info is not None:
50 | torch.manual_seed(worker_info.id)
51 | np.random.seed(worker_info.id)
52 | random.seed(worker_info.id)
53 | self.init_seed = True
54 |
55 | index = index % len(self.image_list)
56 | valid = None
57 | if self.sparse:
58 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
59 | else:
60 | flow = frame_utils.read_gen(self.flow_list[index])
61 |
62 | img1 = frame_utils.read_gen(self.image_list[index][0])
63 | img2 = frame_utils.read_gen(self.image_list[index][1])
64 |
65 | flow = np.array(flow).astype(np.float32)
66 | img1 = np.array(img1).astype(np.uint8)
67 | img2 = np.array(img2).astype(np.uint8)
68 |
69 | # grayscale images
70 | if len(img1.shape) == 2:
71 | img1 = np.tile(img1[..., None], (1, 1, 3))
72 | img2 = np.tile(img2[..., None], (1, 1, 3))
73 | else:
74 | img1 = img1[..., :3]
75 | img2 = img2[..., :3]
76 |
77 | if self.augmentor is not None:
78 | if self.sparse:
79 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
80 | else:
81 | img1, img2, flow = self.augmentor(img1, img2, flow)
82 |
83 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
84 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
85 | flow = torch.from_numpy(flow).permute(2, 0, 1).float()
86 |
87 | if valid is not None:
88 | valid = torch.from_numpy(valid)
89 | else:
90 |             valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)  # mark implausibly large flow values as invalid
91 |
92 | return img1, img2, flow, valid.float()
93 |
94 |     def __rmul__(self, v):  # support `int * dataset` to oversample a dataset (used in build_dataset)
95 | self.flow_list = v * self.flow_list
96 | self.image_list = v * self.image_list
97 |
98 | return self
99 |
100 | def __len__(self):
101 | return len(self.image_list)
102 |
103 |
104 | class MpiSintel(FlowDataset):
105 | def __init__(self, aug_params=None, split='training',
106 | root='datasets/Sintel',
107 | dstype='clean'):
108 | super(MpiSintel, self).__init__(aug_params)
109 |
110 | flow_root = osp.join(root, split, 'flow')
111 | image_root = osp.join(root, split, dstype)
112 |
113 | if split == 'test':
114 | self.is_test = True
115 |
116 | for scene in os.listdir(image_root):
117 | image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
118 | for i in range(len(image_list) - 1):
119 | self.image_list += [[image_list[i], image_list[i + 1]]]
120 | self.extra_info += [(scene, i)] # scene and frame_id
121 |
122 | if split != 'test':
123 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
124 |
125 |
126 | class FlyingChairs(FlowDataset):
127 | def __init__(self, aug_params=None, split='train',
128 | root='datasets/FlyingChairs_release/data',
129 | ):
130 | super(FlyingChairs, self).__init__(aug_params)
131 |
132 | images = sorted(glob(osp.join(root, '*.ppm')))
133 | flows = sorted(glob(osp.join(root, '*.flo')))
134 | assert (len(images) // 2 == len(flows))
135 |
136 | split_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'chairs_split.txt')
137 | split_list = np.loadtxt(split_file, dtype=np.int32)
138 | for i in range(len(flows)):
139 | xid = split_list[i]
140 | if (split == 'training' and xid == 1) or (split == 'validation' and xid == 2):
141 | self.flow_list += [flows[i]]
142 | self.image_list += [[images[2 * i], images[2 * i + 1]]]
143 |
144 |
145 | class FlyingThings3D(FlowDataset):
146 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D',
147 | dstype='frames_cleanpass'):
148 | super(FlyingThings3D, self).__init__(aug_params)
149 |
150 | img_dir = root
151 | flow_dir = root
152 |
153 | for cam in ['left']:
154 | for direction in ['into_future', 'into_past']:
155 | image_dirs = sorted(glob(osp.join(img_dir, dstype, 'TRAIN/*/*')))
156 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
157 |
158 | flow_dirs = sorted(glob(osp.join(flow_dir, 'optical_flow/TRAIN/*/*')))
159 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
160 |
161 | for idir, fdir in zip(image_dirs, flow_dirs):
162 | images = sorted(glob(osp.join(idir, '*.png')))
163 | flows = sorted(glob(osp.join(fdir, '*.pfm')))
164 | for i in range(len(flows) - 1):
165 | if direction == 'into_future':
166 | self.image_list += [[images[i], images[i + 1]]]
167 | self.flow_list += [flows[i]]
168 | elif direction == 'into_past':
169 | self.image_list += [[images[i + 1], images[i]]]
170 | self.flow_list += [flows[i + 1]]
171 |
172 |
173 | class KITTI(FlowDataset):
174 | def __init__(self, aug_params=None, split='training',
175 | root='datasets/KITTI',
176 | ):
177 | super(KITTI, self).__init__(aug_params, sparse=True)
178 | if split == 'testing':
179 | self.is_test = True
180 |
181 | root = osp.join(root, split)
182 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
183 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
184 |
185 | for img1, img2 in zip(images1, images2):
186 | frame_id = img1.split('/')[-1]
187 | self.extra_info += [[frame_id]]
188 | self.image_list += [[img1, img2]]
189 |
190 | if split == 'training':
191 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
192 |
193 |
194 | class HD1K(FlowDataset):
195 | def __init__(self, aug_params=None, root='datasets/HD1K'):
196 | super(HD1K, self).__init__(aug_params, sparse=True)
197 |
198 | seq_ix = 0
199 | while 1:
200 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
201 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
202 |
203 | if len(flows) == 0:
204 | break
205 |
206 | for i in range(len(flows) - 1):
207 | self.flow_list += [flows[i]]
208 | self.image_list += [[images[i], images[i + 1]]]
209 |
210 | seq_ix += 1
211 |
212 |
213 | def build_dataset(args):
214 | """ Create the data loader for the corresponding training set """
215 | if args.stage == 'chairs':
216 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
217 |
218 | train_dataset = FlyingChairs(aug_params, split='training')
219 |
220 | elif args.stage == 'things':
221 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
222 |
223 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
224 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
225 |
226 | train_dataset = clean_dataset + final_dataset
227 |
228 | elif args.stage == 'sintel':
229 | # 1041 pairs for clean and final each
230 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
231 |
232 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
233 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
234 | sintel_final = MpiSintel(aug_params, split='training', dstype='final')
235 |
236 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
237 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
238 |
239 | train_dataset = 100 * sintel_clean + 100 * sintel_final + 200 * kitti + 5 * hd1k + things
240 |
241 | elif args.stage == 'kitti':
242 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
243 | train_dataset = KITTI(aug_params, split='training')
244 |
245 | else:
246 | raise ValueError(f'stage {args.stage} is not supported')
247 |
248 | return train_dataset
249 |
--------------------------------------------------------------------------------
/data/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 | import cv2
5 |
6 | from torchvision.transforms import ColorJitter
7 |
8 |
9 | class FlowAugmentor:
10 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True,
11 | resize_when_needed=False,
12 | no_eraser_aug=False,
13 | ):
14 |         # TODO: support resize to higher resolution, and then do cropping
15 | # for instance, resize all slow_flow data to 1024x1280
16 |
17 | # spatial augmentation params
18 | self.crop_size = crop_size
19 | self.min_scale = min_scale
20 | self.max_scale = max_scale
21 | self.spatial_aug_prob = 0.8
22 | self.stretch_prob = 0.8
23 | self.max_stretch = 0.2
24 |
25 | self.resize_when_needed = resize_when_needed
26 |
27 | # flip augmentation params
28 | self.do_flip = do_flip
29 | self.h_flip_prob = 0.5
30 | self.v_flip_prob = 0.1
31 |
32 | # photometric augmentation params
33 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14)
34 | self.asymmetric_color_aug_prob = 0.2
35 |
36 | if no_eraser_aug:
37 | self.eraser_aug_prob = -1
38 | else:
39 | self.eraser_aug_prob = 0.5
40 |
41 | def color_transform(self, img1, img2):
42 | """ Photometric augmentation """
43 |
44 | # asymmetric
45 | if np.random.rand() < self.asymmetric_color_aug_prob:
46 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
47 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
48 |
49 | # symmetric
50 | else:
51 | image_stack = np.concatenate([img1, img2], axis=0)
52 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
53 | img1, img2 = np.split(image_stack, 2, axis=0)
54 |
55 | return img1, img2
56 |
57 | def eraser_transform(self, img1, img2, bounds=[50, 100]):
58 | """ Occlusion augmentation """
59 |
60 | ht, wd = img1.shape[:2]
61 | if np.random.rand() < self.eraser_aug_prob:
62 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
63 | for _ in range(np.random.randint(1, 3)):
64 | x0 = np.random.randint(0, wd)
65 | y0 = np.random.randint(0, ht)
66 | dx = np.random.randint(bounds[0], bounds[1])
67 | dy = np.random.randint(bounds[0], bounds[1])
68 | img2[y0:y0 + dy, x0:x0 + dx, :] = mean_color
69 |
70 | return img1, img2
71 |
72 | def spatial_transform(self, img1, img2, flow, backward_flow=None, occlusion=None, backward_occlusion=None):
73 | # randomly sample scale
74 | ht, wd = img1.shape[:2]
75 | min_scale = np.maximum(
76 | (self.crop_size[0] + 8) / float(ht),
77 | (self.crop_size[1] + 8) / float(wd))
78 |
79 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
80 | scale_x = scale
81 | scale_y = scale
82 | if np.random.rand() < self.stretch_prob:
83 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
84 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
85 |
86 | scale_x = np.clip(scale_x, min_scale, None)
87 | scale_y = np.clip(scale_y, min_scale, None)
88 |
89 | if np.random.rand() < self.spatial_aug_prob:
90 | # rescale the images
91 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
92 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
93 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
94 | flow = flow * [scale_x, scale_y]
95 |
96 | if backward_flow is not None:
97 | backward_flow = cv2.resize(backward_flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
98 | backward_flow = backward_flow * [scale_x, scale_y]
99 |
100 | if occlusion is not None:
101 | occlusion = cv2.resize(occlusion, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
102 | if backward_occlusion is not None:
103 | backward_occlusion = cv2.resize(backward_occlusion, None, fx=scale_x, fy=scale_y,
104 | interpolation=cv2.INTER_LINEAR)
105 |
106 | if self.do_flip:
107 | if np.random.rand() < self.h_flip_prob: # h-flip
108 | img1 = img1[:, ::-1]
109 | img2 = img2[:, ::-1]
110 | flow = flow[:, ::-1] * [-1.0, 1.0]
111 |
112 | if backward_flow is not None:
113 | backward_flow = backward_flow[:, ::-1] * [-1.0, 1.0]
114 |
115 | if occlusion is not None:
116 | occlusion = occlusion[:, ::-1]
117 | if backward_occlusion is not None:
118 | backward_occlusion = backward_occlusion[:, ::-1]
119 |
120 | if np.random.rand() < self.v_flip_prob: # v-flip
121 | img1 = img1[::-1, :]
122 | img2 = img2[::-1, :]
123 | flow = flow[::-1, :] * [1.0, -1.0]
124 |
125 | if backward_flow is not None:
126 | backward_flow = backward_flow[::-1, :] * [1.0, -1.0]
127 |
128 | if occlusion is not None:
129 | occlusion = occlusion[::-1, :]
130 | if backward_occlusion is not None:
131 | backward_occlusion = backward_occlusion[::-1, :]
132 |
133 |         # random crop; offset is zero if the image is not larger than the crop size
134 | if img1.shape[0] - self.crop_size[0] > 0:
135 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
136 | else:
137 | y0 = 0
138 | if img1.shape[1] - self.crop_size[1] > 0:
139 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
140 | else:
141 | x0 = 0
142 |
143 | img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
144 | img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
145 | flow = flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
146 |
147 | if backward_flow is not None:
148 | backward_flow = backward_flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
149 |
150 | if occlusion is not None:
151 | occlusion = occlusion[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
152 |
153 | if backward_occlusion is not None:
154 | backward_occlusion = backward_occlusion[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
155 |
156 |         if backward_occlusion is not None:
157 |             return img1, img2, flow, backward_flow, occlusion, backward_occlusion
158 |         if occlusion is not None:
159 |             return img1, img2, flow, backward_flow, occlusion
160 |         if backward_flow is not None:
161 |             return img1, img2, flow, backward_flow
162 |         return img1, img2, flow
163 |
164 | def resize(self, img1, img2, flow):
165 | ori_h, ori_w = img1.shape[:2]
166 |
167 | if ori_h < self.crop_size[0] and ori_w < self.crop_size[1]:
168 | # resize both h and w
169 | scale_y = self.crop_size[0] / ori_h
170 | scale_x = self.crop_size[1] / ori_w
171 | elif ori_h < self.crop_size[0]: # only resize h
172 | scale_y = self.crop_size[0] / ori_h
173 | scale_x = 1.
174 | elif ori_w < self.crop_size[1]: # only resize w
175 | scale_x = self.crop_size[1] / ori_w
176 | scale_y = 1.
177 | else:
178 | raise ValueError('Original size %dx%d is not smaller than crop size %dx%d' % (
179 | ori_h, ori_w, self.crop_size[0], self.crop_size[1]
180 | ))
181 |
182 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
183 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
184 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
185 | flow = flow * [scale_x, scale_y]
186 |
187 | return img1, img2, flow
188 |
189 | def __call__(self, img1, img2, flow, backward_flow=None, occlusion=None, backward_occlusion=None):
190 | img1, img2 = self.color_transform(img1, img2)
191 | img1, img2 = self.eraser_transform(img1, img2)
192 |
193 | if self.resize_when_needed:
194 | assert backward_flow is None
195 | # Resize only when original size is smaller than the crop size
196 | if img1.shape[0] < self.crop_size[0] or img1.shape[1] < self.crop_size[1]:
197 | img1, img2, flow = self.resize(img1, img2, flow)
198 |
199 | if backward_flow is not None:
200 | if occlusion is not None:
201 | if backward_occlusion is not None:
202 | img1, img2, flow, backward_flow, occlusion, backward_occlusion = self.spatial_transform(
203 | img1, img2, flow, backward_flow, occlusion, backward_occlusion)
204 | else:
205 | img1, img2, flow, backward_flow, occlusion = self.spatial_transform(
206 | img1, img2, flow, backward_flow, occlusion)
207 | else:
208 | img1, img2, flow, backward_flow = self.spatial_transform(img1, img2, flow, backward_flow)
209 | else:
210 | img1, img2, flow = self.spatial_transform(img1, img2, flow)
211 |
212 | img1 = np.ascontiguousarray(img1)
213 | img2 = np.ascontiguousarray(img2)
214 | flow = np.ascontiguousarray(flow)
215 |
216 | if backward_flow is not None:
217 | backward_flow = np.ascontiguousarray(backward_flow)
218 |
219 | if occlusion is not None:
220 | occlusion = np.ascontiguousarray(occlusion)
221 | if backward_occlusion is not None:
222 | backward_occlusion = np.ascontiguousarray(backward_occlusion)
223 | return img1, img2, flow, backward_flow, occlusion, backward_occlusion
224 |
225 | return img1, img2, flow, backward_flow, occlusion
226 |
227 |         if backward_flow is not None:
228 |             return img1, img2, flow, backward_flow
229 |         return img1, img2, flow
230 |
231 |
232 | class SparseFlowAugmentor:
233 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False,
234 | resize_when_needed=False, # used for slow flow dataset
235 |                  is_kitti=True, # for the KITTI dataset, use sparse flow resizing; otherwise bilinear resizing
236 | no_eraser_aug=False,
237 | ):
238 | # spatial augmentation params
239 | self.crop_size = crop_size
240 | self.min_scale = min_scale
241 | self.max_scale = max_scale
242 | self.spatial_aug_prob = 0.8
243 | self.stretch_prob = 0.8
244 | self.max_stretch = 0.2
245 |
246 | # flip augmentation params
247 | self.do_flip = do_flip
248 | self.h_flip_prob = 0.5
249 | self.v_flip_prob = 0.1
250 |
251 | # photometric augmentation params
252 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3 / 3.14)
253 | self.asymmetric_color_aug_prob = 0.2
254 |
255 | if no_eraser_aug:
256 | self.eraser_aug_prob = -1
257 | else:
258 | self.eraser_aug_prob = 0.5
259 |
260 | self.resize_when_needed = resize_when_needed
261 | self.is_kitti = is_kitti
262 |
263 | def color_transform(self, img1, img2):
264 | image_stack = np.concatenate([img1, img2], axis=0)
265 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
266 | img1, img2 = np.split(image_stack, 2, axis=0)
267 | return img1, img2
268 |
269 | def eraser_transform(self, img1, img2):
270 | ht, wd = img1.shape[:2]
271 | if np.random.rand() < self.eraser_aug_prob:
272 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
273 | for _ in range(np.random.randint(1, 3)):
274 | x0 = np.random.randint(0, wd)
275 | y0 = np.random.randint(0, ht)
276 | dx = np.random.randint(50, 100)
277 | dy = np.random.randint(50, 100)
278 | img2[y0:y0 + dy, x0:x0 + dx, :] = mean_color
279 |
280 | return img1, img2
281 |
282 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
283 | ht, wd = flow.shape[:2]
284 | coords = np.meshgrid(np.arange(wd), np.arange(ht))
285 | coords = np.stack(coords, axis=-1)
286 |
287 | coords = coords.reshape(-1, 2).astype(np.float32)
288 | flow = flow.reshape(-1, 2).astype(np.float32)
289 | valid = valid.reshape(-1).astype(np.float32)
290 |
291 | coords0 = coords[valid >= 1]
292 | flow0 = flow[valid >= 1]
293 |
294 | ht1 = int(round(ht * fy))
295 | wd1 = int(round(wd * fx))
296 |
297 | coords1 = coords0 * [fx, fy]
298 | flow1 = flow0 * [fx, fy]
299 |
300 | xx = np.round(coords1[:, 0]).astype(np.int32)
301 | yy = np.round(coords1[:, 1]).astype(np.int32)
302 |
303 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
304 | xx = xx[v]
305 | yy = yy[v]
306 | flow1 = flow1[v]
307 |
308 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
309 | valid_img = np.zeros([ht1, wd1], dtype=np.int32)
310 |
311 | flow_img[yy, xx] = flow1
312 | valid_img[yy, xx] = 1
313 |
314 | return flow_img, valid_img
315 |
316 | def resize(self, img1, img2, flow, valid):
317 | ori_h, ori_w = img1.shape[:2]
318 |
319 | if ori_h < self.crop_size[0] and ori_w < self.crop_size[1]:
320 | # resize both h and w
321 | scale_y = self.crop_size[0] / ori_h
322 | scale_x = self.crop_size[1] / ori_w
323 | elif ori_h < self.crop_size[0]: # only resize h
324 | scale_y = self.crop_size[0] / ori_h
325 | scale_x = 1.
326 | elif ori_w < self.crop_size[1]: # only resize w
327 | scale_x = self.crop_size[1] / ori_w
328 | scale_y = 1.
329 | else:
330 | raise ValueError('Original size %dx%d is not smaller than crop size %dx%d' % (
331 | ori_h, ori_w, self.crop_size[0], self.crop_size[1]
332 | ))
333 |
334 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
335 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
336 |
337 | if self.is_kitti:
338 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
339 | else: # for viper and slow flow datasets, only a few pixels are invalid
340 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
341 | # NOTE: don't forget scale flow also
342 | flow = flow * [scale_x, scale_y]
343 |
344 | valid = cv2.resize(valid, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_NEAREST)
345 |
346 | return img1, img2, flow, valid
347 |
348 | def spatial_transform(self, img1, img2, flow, valid):
349 | # randomly sample scale
350 |
351 | ht, wd = img1.shape[:2]
352 | min_scale = np.maximum(
353 | (self.crop_size[0] + 1) / float(ht),
354 | (self.crop_size[1] + 1) / float(wd))
355 |
356 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
357 | scale_x = np.clip(scale, min_scale, None)
358 | scale_y = np.clip(scale, min_scale, None)
359 |
360 | if np.random.rand() < self.spatial_aug_prob:
361 | # rescale the images
362 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
363 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
364 |
365 | if self.is_kitti:
366 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
367 | else: # for viper and slow flow datasets, only a few pixels are invalid
368 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
369 | flow = flow * [scale_x, scale_y]
370 |
371 | valid = cv2.resize(valid, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_NEAREST)
372 |
373 | if self.do_flip:
374 | if np.random.rand() < 0.5: # h-flip
375 | img1 = img1[:, ::-1]
376 | img2 = img2[:, ::-1]
377 | flow = flow[:, ::-1] * [-1.0, 1.0]
378 | valid = valid[:, ::-1]
379 |
380 | margin_y = 20
381 | margin_x = 50
382 |
383 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
384 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
385 |
386 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
387 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
388 |
389 | img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
390 | img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
391 | flow = flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
392 | valid = valid[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
393 | return img1, img2, flow, valid
394 |
395 | def __call__(self, img1, img2, flow, valid):
396 | img1, img2 = self.color_transform(img1, img2)
397 | img1, img2 = self.eraser_transform(img1, img2)
398 |
399 | if self.resize_when_needed:
400 | # Resize only when original size is smaller than the crop size
401 | if img1.shape[0] < self.crop_size[0] or img1.shape[1] < self.crop_size[1]:
402 | img1, img2, flow, valid = self.resize(img1, img2, flow, valid)
403 |
404 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
405 |
406 | img1 = np.ascontiguousarray(img1)
407 | img2 = np.ascontiguousarray(img2)
408 | flow = np.ascontiguousarray(flow)
409 | valid = np.ascontiguousarray(valid)
410 |
411 | return img1, img2, flow, valid
412 |
--------------------------------------------------------------------------------
/demo/dogs-jump/00033.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haofeixu/flow1d/ece861d2136e2eb2e99a9db71794d82c5782dbcb/demo/dogs-jump/00033.jpg
--------------------------------------------------------------------------------
/demo/dogs-jump/00034.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haofeixu/flow1d/ece861d2136e2eb2e99a9db71794d82c5782dbcb/demo/dogs-jump/00034.jpg
--------------------------------------------------------------------------------
/demo/dogs-jump/00035.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haofeixu/flow1d/ece861d2136e2eb2e99a9db71794d82c5782dbcb/demo/dogs-jump/00035.jpg
--------------------------------------------------------------------------------
/demo/dogs-jump/00036.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haofeixu/flow1d/ece861d2136e2eb2e99a9db71794d82c5782dbcb/demo/dogs-jump/00036.jpg
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: flow1d
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - blas=1.0=mkl
8 | - ca-certificates=2020.6.24=0
9 | - certifi=2020.6.20=py37_0
10 | - cloudpickle=1.5.0=py_0
11 | - cudatoolkit=10.2.89=hfd86e86_1
12 | - cycler=0.10.0=py37_0
13 | - cytoolz=0.10.1=py37h7b6447c_0
14 | - dask-core=2.20.0=py_0
15 | - dbus=1.13.16=hb2f20db_0
16 | - decorator=4.4.2=py_0
17 | - expat=2.2.9=he6710b0_2
18 | - fontconfig=2.13.0=h9420a91_0
19 | - freetype=2.10.2=h5ab3b9f_0
20 | - glib=2.65.0=h3eb4bd4_0
21 | - gst-plugins-base=1.14.0=hbbd80ab_1
22 | - gstreamer=1.14.0=hb31296c_0
23 | - icu=58.2=he6710b0_3
24 | - imageio=2.9.0=py_0
25 | - intel-openmp=2020.1=217
26 | - jpeg=9b=h024ee3a_2
27 | - kiwisolver=1.2.0=py37hfd86e86_0
28 | - lcms2=2.11=h396b838_0
29 | - ld_impl_linux-64=2.33.1=h53a641e_7
30 | - libedit=3.1.20191231=h14c3975_1
31 | - libffi=3.3=he6710b0_2
32 | - libgcc-ng=9.1.0=hdf63c60_0
33 | - libgfortran-ng=7.3.0=hdf63c60_0
34 | - libpng=1.6.37=hbc83047_0
35 | - libstdcxx-ng=9.1.0=hdf63c60_0
36 | - libtiff=4.1.0=h2733197_1
37 | - libuuid=1.0.3=h1bed415_2
38 | - libxcb=1.14=h7b6447c_0
39 | - libxml2=2.9.10=he19cac6_1
40 | - lz4-c=1.9.2=he6710b0_0
41 | - matplotlib=3.2.2=0
42 | - matplotlib-base=3.2.2=py37hef1b27d_0
43 | - mkl=2020.1=217
44 | - mkl-service=2.3.0=py37he904b0f_0
45 | - mkl_fft=1.1.0=py37h23d657b_0
46 | - mkl_random=1.1.1=py37h0573a6f_0
47 | - ncurses=6.2=he6710b0_1
48 | - networkx=2.4=py_1
49 | - ninja=1.9.0=py37hfd86e86_0
50 | - numpy=1.18.5=py37ha1c710e_0
51 | - numpy-base=1.18.5=py37hde5b4d6_0
52 | - olefile=0.46=py37_0
53 | - openssl=1.1.1g=h7b6447c_0
54 | - pcre=8.44=he6710b0_0
55 | - pillow=7.2.0=py37hb39fc2d_0
56 | - pip=20.1.1=py37_1
57 | - pyparsing=2.4.7=py_0
58 | - pyqt=5.9.2=py37h05f1152_2
59 | - python=3.7.7=hcff3b4d_5
60 | - python-dateutil=2.8.1=py_0
61 | - pytorch=1.7.1=py3.7_cuda10.2.89_cudnn7.6.5_0
62 | - pywavelets=1.1.1=py37h7b6447c_0
63 | - pyyaml=5.3.1=py37h7b6447c_1
64 | - qt=5.9.7=h5867ecd_1
65 | - readline=8.0=h7b6447c_0
66 | - scikit-image=0.16.2=py37h0573a6f_0
67 | - scipy=1.5.0=py37h0b6359f_0
68 | - setuptools=49.2.0=py37_0
69 | - sip=4.19.8=py37hf484d3e_0
70 | - six=1.15.0=py_0
71 | - sqlite=3.32.3=h62c20be_0
72 | - tk=8.6.10=hbc83047_0
73 | - toolz=0.10.0=py_0
74 | - torchvision=0.6.0=py37_cu102
75 | - tornado=6.0.4=py37h7b6447c_1
76 | - wheel=0.34.2=py37_0
77 | - xz=5.2.5=h7b6447c_0
78 | - yaml=0.2.5=h7b6447c_0
79 | - zlib=1.2.11=h7b6447c_3
80 | - zstd=1.4.5=h0b5b093_0
81 | - pip:
82 | - absl-py==0.9.0
83 | - cachetools==4.1.1
84 | - chardet==3.0.4
85 | - google-auth==1.19.2
86 | - google-auth-oauthlib==0.4.1
87 | - grpcio==1.30.0
88 | - idna==2.10
89 | - importlib-metadata==1.7.0
90 | - markdown==3.2.2
91 | - oauthlib==3.1.0
92 | - opencv-python==4.3.0.36
93 | - protobuf==3.12.2
94 | - pyasn1==0.4.8
95 | - pyasn1-modules==0.2.8
96 | - requests==2.24.0
97 | - requests-oauthlib==1.3.0
98 | - rsa==4.6
99 | - tensorboard==2.2.2
100 | - tensorboard-plugin-wit==1.7.0
101 | - urllib3==1.25.9
102 | - werkzeug==1.0.1
103 | - zipp==3.1.0
104 |
105 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | import torch
5 |
6 | import data
7 | from utils import frame_utils
8 | from utils.flow_viz import save_vis_flow_tofile
9 |
10 | from utils.utils import InputPadder, forward_interpolate
11 | from glob import glob
12 |
13 |
14 | @torch.no_grad()
15 | def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission',
16 | padding_factor=8,
17 | save_vis_flow=False,
18 | no_save_flo=False,
19 | **kwargs,
20 | ):
21 | """ Create submission for the Sintel leaderboard """
22 | model.eval()
23 | for dstype in ['clean', 'final']:
24 | test_dataset = data.MpiSintel(split='test', aug_params=None, dstype=dstype)
25 |
26 | flow_prev, sequence_prev = None, None
27 | for test_id in range(len(test_dataset)):
28 | image1, image2, (sequence, frame) = test_dataset[test_id]
29 | if sequence != sequence_prev:
30 | flow_prev = None
31 |
32 | padder = InputPadder(image1.shape, padding_factor=padding_factor)
33 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
34 |
35 | flow_low, flow_pr = model(image1, image2, iters=iters,
36 | flow_init=flow_prev,
37 | test_mode=True)
38 |
39 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
40 |
41 | if warm_start:
42 | flow_prev = forward_interpolate(flow_low[0])[None].cuda()
43 |
44 | output_dir = os.path.join(output_path, dstype, sequence)
45 | output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame + 1))
46 |
47 | if not os.path.exists(output_dir):
48 | os.makedirs(output_dir)
49 |
50 | if not no_save_flo:
51 | frame_utils.writeFlow(output_file, flow)
52 | sequence_prev = sequence
53 |
54 | # Save vis flow
55 | if save_vis_flow:
56 | vis_flow_file = output_file.replace('.flo', '.png')
57 | save_vis_flow_tofile(flow, vis_flow_file)
58 |
59 |
60 | @torch.no_grad()
61 | def create_kitti_submission(model, iters=24, output_path='kitti_submission',
62 | padding_factor=8,
63 | save_vis_flow=False,
64 | **kwargs,
65 | ):
66 | """ Create submission for the Sintel leaderboard """
67 | model.eval()
68 | test_dataset = data.KITTI(split='testing', aug_params=None)
69 |
70 | if not os.path.exists(output_path):
71 | os.makedirs(output_path)
72 |
73 | for test_id in range(len(test_dataset)):
74 | image1, image2, (frame_id,) = test_dataset[test_id]
75 | padder = InputPadder(image1.shape, mode='kitti', padding_factor=padding_factor)
76 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
77 |
78 | flow_pr = model(image1, image2, iters=iters,
79 | flow_init=None,
80 | test_mode=True)[-1]
81 |
82 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
83 |
84 | output_filename = os.path.join(output_path, frame_id)
85 |
86 | # Save vis flow
87 | if save_vis_flow:
88 | vis_flow_file = output_filename
89 | save_vis_flow_tofile(flow, vis_flow_file)
90 | else:
91 | frame_utils.writeFlowKITTI(output_filename, flow)
92 |
93 |
94 | @torch.no_grad()
95 | def validate_chairs(model,
96 | iters=24,
97 | **kwargs,
98 | ):
99 | """ Perform evaluation on the FlyingChairs (test) split """
100 | model.eval()
101 | epe_list = []
102 | results = {}
103 |
104 | val_dataset = data.FlyingChairs(split='validation')
105 |
106 | print('Number of validation image pairs: %d' % len(val_dataset))
107 |
108 | for val_id in range(len(val_dataset)):
109 | image1, image2, flow_gt, _ = val_dataset[val_id]
110 |
111 | image1 = image1[None].cuda()
112 | image2 = image2[None].cuda()
113 |
114 | flow_pr = model(image1, image2, iters=iters, test_mode=True)[-1] # RAFT
115 |
116 | epe = torch.sum((flow_pr[0].cpu() - flow_gt) ** 2, dim=0).sqrt()
117 | epe_list.append(epe.view(-1).numpy())
118 |
119 | epe_all = np.concatenate(epe_list)
120 | epe = np.mean(epe_all)
121 | px1 = np.mean(epe_all > 1)
122 | px3 = np.mean(epe_all > 3)
123 | px5 = np.mean(epe_all > 5)
124 |
125 | print("Validation Chairs EPE: %.3f, 1px: %.3f, 3px: %.3f, 5px: %.3f" % (epe, px1, px3, px5))
126 |
127 | results['chairs_epe'] = epe
128 | results['chairs_1px'] = px1
129 | results['chairs_3px'] = px3
130 | results['chairs_5px'] = px5
131 |
132 | return results
133 |
134 |
135 | @torch.no_grad()
136 | def validate_sintel(model,
137 | count_time=False,
138 | padding_factor=8,
139 | iters=32,
140 | **kwargs,
141 | ):
142 | """ Peform validation using the Sintel (train) split """
143 | model.eval()
144 | results = {}
145 |
146 | if count_time:
147 | total_time = 0
148 | num_runs = 100
149 |
150 | for dstype in ['clean', 'final']:
151 | val_dataset = data.MpiSintel(split='training', dstype=dstype)
152 |
153 | print('Number of validation image pairs: %d' % len(val_dataset))
154 | epe_list = []
155 |
156 | for val_id in range(len(val_dataset)):
157 | image1, image2, flow_gt, _ = val_dataset[val_id]
158 | image1 = image1[None].cuda()
159 | image2 = image2[None].cuda()
160 |
161 | padder = InputPadder(image1.shape, padding_factor=padding_factor)
162 | image1, image2 = padder.pad(image1, image2)
163 |
164 | if count_time and val_id >= 5: # 5 warmup
165 | torch.cuda.synchronize()
166 | time_start = time.perf_counter()
167 |
168 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
169 |
170 | if count_time and val_id >= 5:
171 | torch.cuda.synchronize()
172 | total_time += time.perf_counter() - time_start
173 |
174 | if val_id >= num_runs + 4:
175 | break
176 |
177 | flow = padder.unpad(flow_pr[0]).cpu()
178 |
179 | epe = torch.sum((flow - flow_gt) ** 2, dim=0).sqrt()
180 | epe_list.append(epe.view(-1).numpy())
181 |
182 | epe_all = np.concatenate(epe_list)
183 | epe = np.mean(epe_all)
184 | px1 = np.mean(epe_all > 1)
185 | px3 = np.mean(epe_all > 3)
186 | px5 = np.mean(epe_all > 5)
187 |
188 | print("Validation Sintel (%s) EPE: %.3f, 1px: %.3f, 3px: %.3f, 5px: %.3f" % (dstype, epe, px1, px3, px5))
189 |
190 | dstype = 'sintel_' + dstype
191 |
192 |         results[dstype + '_epe'] = epe
193 | results[dstype + '_1px'] = px1
194 | results[dstype + '_3px'] = px3
195 | results[dstype + '_5px'] = px5
196 |
197 | if count_time:
198 | print('Time: %.3fs' % (total_time / num_runs))
199 | break # only the clean pass when counting time
200 |
201 | return results
202 |
203 |
204 | @torch.no_grad()
205 | def validate_kitti(model,
206 | padding_factor=8,
207 | iters=24,
208 | **kwargs,
209 | ):
210 | """ Peform validation using the KITTI-2015 (train) split """
211 | model.eval()
212 |
213 | val_dataset = data.KITTI(split='training')
214 | print('Number of validation image pairs: %d' % len(val_dataset))
215 |
216 | out_list, epe_list = [], []
217 | results = {}
218 |
219 | for val_id in range(len(val_dataset)):
220 | image1, image2, flow_gt, valid_gt = val_dataset[val_id]
221 | image1 = image1[None].cuda()
222 | image2 = image2[None].cuda()
223 |
224 | padder = InputPadder(image1.shape, mode='kitti', padding_factor=padding_factor)
225 | image1, image2 = padder.pad(image1, image2)
226 |
227 | flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
228 |
229 | flow = padder.unpad(flow_pr[0]).cpu()
230 |
231 | epe = torch.sum((flow - flow_gt) ** 2, dim=0).sqrt()
232 | mag = torch.sum(flow_gt ** 2, dim=0).sqrt()
233 |
234 | epe = epe.view(-1)
235 | mag = mag.view(-1)
236 | val = valid_gt.view(-1) >= 0.5
237 |
238 |         out = ((epe > 3.0) & ((epe / mag) > 0.05)).float()  # KITTI outlier (Fl): EPE > 3px and > 5% of the GT flow magnitude
239 |
240 | epe_list.append(epe[val].mean().item())
241 | out_list.append(out[val].cpu().numpy())
242 |
243 | epe_list = np.array(epe_list)
244 | out_list = np.concatenate(out_list)
245 |
246 | epe = np.mean(epe_list)
247 | f1 = 100 * np.mean(out_list)
248 |
249 | print("Validation KITTI EPE: %.3f, F1-all: %.3f" % (epe, f1))
250 | results['kitti_epe'] = epe
251 | results['kitti_f1'] = f1
252 |
253 | return results
254 |
255 |
256 | @torch.no_grad()
257 | def inference_on_dir(model, inference_dir,
258 | iters=32, warm_start=False, output_path='output',
259 | padding_factor=8,
260 | paired_data=False, # dir of paired data instead of a sequence
261 |                      save_flo_flow=False, # save as .flo for quantitative evaluation
262 | **kwargs,
263 | ):
264 | """ Inference on a directory """
265 | model.eval()
266 |
267 | if not os.path.exists(output_path):
268 | os.makedirs(output_path)
269 |
270 | filenames = sorted(glob(inference_dir + '/*'))
271 | print('%d images found' % len(filenames))
272 |
273 | flow_prev, sequence_prev = None, None
274 |
275 | stride = 2 if paired_data else 1
276 |
277 | if paired_data:
278 | assert len(filenames) % 2 == 0
279 |
280 | for test_id in range(0, len(filenames) - 1, stride):
281 | image1 = frame_utils.read_gen(filenames[test_id])
282 | image2 = frame_utils.read_gen(filenames[test_id + 1])
283 |
284 | image1 = np.array(image1).astype(np.uint8)
285 | image2 = np.array(image2).astype(np.uint8)
286 |
287 | if len(image1.shape) == 2: # gray image, for example, HD1K
288 | image1 = np.tile(image1[..., None], (1, 1, 3))
289 | image2 = np.tile(image2[..., None], (1, 1, 3))
290 | else:
291 | image1 = image1[..., :3]
292 | image2 = image2[..., :3]
293 |
294 | image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
295 | image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
296 |
297 | if test_id == 0:
298 | flow_prev = None
299 |
300 | padder = InputPadder(image1.shape, padding_factor=padding_factor)
301 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
302 |
303 | flow_init = None
304 | flow_low, flow_pr = model(image1, image2, iters=iters,
305 | flow_init=flow_prev if flow_init is None else flow_init,
306 | test_mode=True)
307 |
308 | if warm_start:
309 | flow_prev = forward_interpolate(flow_low[0])[None].cuda()
310 |
311 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() # [H, W, 2]
312 |
313 | output_file = os.path.join(output_path, os.path.basename(filenames[test_id])[:-4] + '_flow.png')
314 |
315 | # Save vis flow
316 | save_vis_flow_tofile(flow, output_file)
317 |
318 | if save_flo_flow:
319 | output_file = os.path.join(output_path, os.path.basename(filenames[test_id])[:-4] + '_pred.flo')
320 | frame_utils.writeFlow(output_file, flow)
321 |
--------------------------------------------------------------------------------
/flow1d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haofeixu/flow1d/ece861d2136e2eb2e99a9db71794d82c5782dbcb/flow1d/__init__.py
--------------------------------------------------------------------------------
/flow1d/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import copy
4 |
5 |
6 | class Attention1D(nn.Module):
7 | """Cross-Attention on x or y direction,
8 | without multi-head and dropout support for faster speed
9 | """
10 |
11 | def __init__(self, in_channels,
12 | y_attention=False,
13 |                  double_cross_attn=False, # self-attention on feature1 before computing cross-attention with feature2
14 | **kwargs,
15 | ):
16 | super(Attention1D, self).__init__()
17 |
18 | self.y_attention = y_attention
19 | self.double_cross_attn = double_cross_attn
20 |
21 | # self attn feature1 before cross attn
22 | if double_cross_attn:
23 | self.self_attn = copy.deepcopy(Attention1D(in_channels=in_channels,
24 | y_attention=not y_attention,
25 | )
26 | )
27 |
28 | self.query_conv = nn.Conv2d(in_channels, in_channels, 1)
29 | self.key_conv = nn.Conv2d(in_channels, in_channels, 1)
30 |
31 | # Initialize: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py#L138
32 | for p in self.parameters():
33 | if p.dim() > 1:
34 | nn.init.xavier_uniform_(p) # original Transformer initialization
35 |
36 | def forward(self, feature1, feature2, position=None, value=None):
37 | b, c, h, w = feature1.size()
38 |
39 | # self attn before cross attn
40 | if self.double_cross_attn:
41 | feature1 = self.self_attn(feature1, feature1, position)[0] # self attn feature1
42 |
43 | query = feature1 + position if position is not None else feature1
44 | query = self.query_conv(query) # [B, C, H, W]
45 |
46 | key = feature2 + position if position is not None else feature2
47 |
48 | key = self.key_conv(key) # [B, C, H, W]
49 | value = feature2 if value is None else value # [B, C, H, W]
50 | scale_factor = c ** 0.5
51 |
52 | if self.y_attention:
53 | query = query.permute(0, 3, 2, 1) # [B, W, H, C]
54 | key = key.permute(0, 3, 1, 2) # [B, W, C, H]
55 | value = value.permute(0, 3, 2, 1) # [B, W, H, C]
56 | else: # x attention
57 | query = query.permute(0, 2, 3, 1) # [B, H, W, C]
58 | key = key.permute(0, 2, 1, 3) # [B, H, C, W]
59 | value = value.permute(0, 2, 3, 1) # [B, H, W, C]
60 |
61 | scores = torch.matmul(query, key) / scale_factor # [B, W, H, H] or [B, H, W, W]
62 |
63 | attention = torch.softmax(scores, dim=-1) # [B, W, H, H] or [B, H, W, W]
64 |
65 | out = torch.matmul(attention, value) # [B, W, H, C] or [B, H, W, C]
66 |
67 | if self.y_attention:
68 | out = out.permute(0, 3, 2, 1).contiguous() # [B, C, H, W]
69 | else:
70 | out = out.permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
71 |
72 | return out, attention
73 |
--------------------------------------------------------------------------------
/flow1d/correlation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | class Correlation1D:
6 | def __init__(self, feature1, feature2,
7 | radius=32,
8 | x_correlation=False,
9 | ):
10 | self.radius = radius
11 | self.x_correlation = x_correlation
12 |
13 | if self.x_correlation:
14 | self.corr = self.corr_x(feature1, feature2) # [B*H*W, 1, 1, W]
15 | else:
16 | self.corr = self.corr_y(feature1, feature2) # [B*H*W, 1, H, 1]
17 |
18 | def __call__(self, coords):
19 | r = self.radius
20 | coords = coords.permute(0, 2, 3, 1) # [B, H, W, 2]
21 | b, h, w = coords.shape[:3]
22 |
23 | if self.x_correlation:
24 | dx = torch.linspace(-r, r, 2 * r + 1)
25 | dy = torch.zeros_like(dx)
26 | delta_x = torch.stack((dx, dy), dim=-1).to(coords.device) # [2r+1, 2]
27 |
28 | coords_x = coords[:, :, :, 0] # [B, H, W]
29 | coords_x = torch.stack((coords_x, torch.zeros_like(coords_x)), dim=-1) # [B, H, W, 2]
30 |
31 | centroid_x = coords_x.view(b * h * w, 1, 1, 2) # [B*H*W, 1, 1, 2]
32 | coords_x = centroid_x + delta_x # [B*H*W, 1, 2r+1, 2]
33 |
34 | coords_x = 2 * coords_x / (w - 1) - 1 # [-1, 1], y is always 0
35 |
36 | corr_x = F.grid_sample(self.corr, coords_x, mode='bilinear',
37 | align_corners=True) # [B*H*W, G, 1, 2r+1]
38 |
39 | corr_x = corr_x.view(b, h, w, -1) # [B, H, W, (2r+1)*G]
40 | corr_x = corr_x.permute(0, 3, 1, 2).contiguous() # [B, (2r+1)*G, H, W]
41 | return corr_x
42 | else: # y correlation
43 | dy = torch.linspace(-r, r, 2 * r + 1)
44 | dx = torch.zeros_like(dy)
45 | delta_y = torch.stack((dx, dy), dim=-1).to(coords.device) # [2r+1, 2]
46 | delta_y = delta_y.unsqueeze(1).unsqueeze(0) # [1, 2r+1, 1, 2]
47 |
48 | coords_y = coords[:, :, :, 1] # [B, H, W]
49 | coords_y = torch.stack((torch.zeros_like(coords_y), coords_y), dim=-1) # [B, H, W, 2]
50 |
51 | centroid_y = coords_y.view(b * h * w, 1, 1, 2) # [B*H*W, 1, 1, 2]
52 | coords_y = centroid_y + delta_y # [B*H*W, 2r+1, 1, 2]
53 |
54 | coords_y = 2 * coords_y / (h - 1) - 1 # [-1, 1], x is always 0
55 |
56 | corr_y = F.grid_sample(self.corr, coords_y, mode='bilinear',
57 | align_corners=True) # [B*H*W, G, 2r+1, 1]
58 |
59 | corr_y = corr_y.view(b, h, w, -1) # [B, H, W, (2r+1)*G]
60 | corr_y = corr_y.permute(0, 3, 1, 2).contiguous() # [B, (2r+1)*G, H, W]
61 |
62 | return corr_y
63 |
64 | def corr_x(self, feature1, feature2):
65 | b, c, h, w = feature1.shape # [B, C, H, W]
66 | scale_factor = c ** 0.5
67 |
68 | # x direction
69 | feature1 = feature1.permute(0, 2, 3, 1) # [B, H, W, C]
70 | feature2 = feature2.permute(0, 2, 1, 3) # [B, H, C, W]
71 | corr = torch.matmul(feature1, feature2) # [B, H, W, W]
72 |
73 | corr = corr.unsqueeze(3).unsqueeze(3) # [B, H, W, 1, 1, W]
74 | corr = corr / scale_factor
75 | corr = corr.flatten(0, 2) # [B*H*W, 1, 1, W]
76 |
77 | return corr
78 |
79 | def corr_y(self, feature1, feature2):
80 | b, c, h, w = feature1.shape # [B, C, H, W]
81 | scale_factor = c ** 0.5
82 |
83 | # y direction
84 | feature1 = feature1.permute(0, 3, 2, 1) # [B, W, H, C]
85 | feature2 = feature2.permute(0, 3, 1, 2) # [B, W, C, H]
86 | corr = torch.matmul(feature1, feature2) # [B, W, H, H]
87 |
88 | corr = corr.permute(0, 2, 1, 3).contiguous().view(b, h, w, 1, h, 1) # [B, H, W, 1, H, 1]
89 | corr = corr / scale_factor
90 | corr = corr.flatten(0, 2) # [B*H*W, 1, H, 1]
91 |
92 | return corr
93 |
--------------------------------------------------------------------------------
/flow1d/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ResidualBlock(nn.Module):
6 | def __init__(self, in_planes, planes, norm_fn='group', stride=1, dilation=1):
7 | super(ResidualBlock, self).__init__()
8 |
9 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
10 | dilation=dilation, padding=dilation, stride=stride)
11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
12 | dilation=dilation, padding=dilation)
13 | self.relu = nn.ReLU(inplace=True)
14 |
15 | num_groups = planes // 8
16 |
17 | if norm_fn == 'group':
18 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
20 |             if not stride == 1 or in_planes != planes:
21 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
22 |
23 | elif norm_fn == 'batch':
24 | self.norm1 = nn.BatchNorm2d(planes)
25 | self.norm2 = nn.BatchNorm2d(planes)
26 | if not stride == 1 or in_planes != planes:
27 | self.norm3 = nn.BatchNorm2d(planes)
28 |
29 | elif norm_fn == 'instance':
30 | self.norm1 = nn.InstanceNorm2d(planes)
31 | self.norm2 = nn.InstanceNorm2d(planes)
32 | if not stride == 1 or in_planes != planes:
33 | self.norm3 = nn.InstanceNorm2d(planes)
34 |
35 | elif norm_fn == 'none':
36 | self.norm1 = nn.Sequential()
37 | self.norm2 = nn.Sequential()
38 |             if not stride == 1 or in_planes != planes:
39 | self.norm3 = nn.Sequential()
40 |
41 | if stride == 1 and in_planes == planes:
42 | self.downsample = None
43 | else:
44 | self.downsample = nn.Sequential(
45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46 |
47 | def forward(self, x):
48 | y = x
49 | y = self.relu(self.norm1(self.conv1(y)))
50 | y = self.relu(self.norm2(self.conv2(y)))
51 |
52 | if self.downsample is not None:
53 | x = self.downsample(x)
54 |
55 | return self.relu(x + y)
56 |
57 |
58 | class BasicEncoder(nn.Module):
59 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0,
60 | **kwargs,
61 | ):
62 | super(BasicEncoder, self).__init__()
63 | self.norm_fn = norm_fn
64 |
65 | feature_dims = [64, 96, 128, 160]
66 |
67 | if self.norm_fn == 'group':
68 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=feature_dims[0])
69 |
70 | elif self.norm_fn == 'batch':
71 | self.norm1 = nn.BatchNorm2d(feature_dims[0])
72 |
73 | elif self.norm_fn == 'instance':
74 | self.norm1 = nn.InstanceNorm2d(feature_dims[0])
75 |
76 | elif self.norm_fn == 'none':
77 | self.norm1 = nn.Sequential()
78 |
79 | self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3)
80 | self.relu1 = nn.ReLU(inplace=True)
81 |
82 | self.in_planes = feature_dims[0]
83 | self.layer1 = self._make_layer(feature_dims[0], stride=1)
84 | self.layer2 = self._make_layer(feature_dims[1], stride=2) # 1/4
85 |
86 | self.layer3 = self._make_layer(feature_dims[2], stride=2, dilation=1)
87 |
88 | self.conv2 = nn.Conv2d(feature_dims[2], output_dim, kernel_size=1)
89 |
90 | self.dropout = None
91 | if dropout > 0:
92 | self.dropout = nn.Dropout2d(p=dropout)
93 |
94 | for m in self.modules():
95 | if isinstance(m, nn.Conv2d):
96 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
97 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
98 | if m.weight is not None:
99 | nn.init.constant_(m.weight, 1)
100 | if m.bias is not None:
101 | nn.init.constant_(m.bias, 0)
102 |
103 | def _make_layer(self, dim, stride=1, dilation=1):
104 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride, dilation=dilation)
105 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1, dilation=dilation)
106 | layers = (layer1, layer2)
107 |
108 | self.in_planes = dim
109 | return nn.Sequential(*layers)
110 |
111 | def forward(self, x):
112 |
113 | # if the input is a list, concatenate along the batch dimension
114 | is_list = isinstance(x, tuple) or isinstance(x, list)
115 | if is_list:
116 | batch_dim = x[0].shape[0]
117 | x = torch.cat(x, dim=0)
118 |
119 | x = self.conv1(x)
120 | x = self.norm1(x)
121 | x = self.relu1(x)
122 |
123 | x = self.layer1(x) # 1/2
124 | layer2 = self.layer2(x) # 1/4
125 |
126 | x = self.layer3(layer2) # 1/8
127 |
128 | x = self.conv2(x)
129 |
130 | if self.training and self.dropout is not None:
131 | x = self.dropout(x)
132 |
133 | if is_list:
134 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
135 |
136 | return x
137 |
--------------------------------------------------------------------------------
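
A minimal usage sketch for `BasicEncoder` above (assumes the repo root is on `PYTHONPATH`). `conv1`, `layer2` and `layer3` each halve the spatial resolution, so features come out at 1/8 of the input size:

```python
import torch
from flow1d.extractor import BasicEncoder

encoder = BasicEncoder(output_dim=256, norm_fn='instance')

img = torch.randn(1, 3, 384, 512)   # input size divisible by 8
feat = encoder(img)                 # single tensor in -> single tensor out
print(feat.shape)                   # torch.Size([1, 256, 48, 64]), i.e. 1/8 resolution

# Passing a list concatenates along the batch dim and splits the result back,
# which is how the feature network processes both frames in one pass.
feat1, feat2 = encoder([img, img])
print(feat1.shape, feat2.shape)
```
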
/flow1d/flow1d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .extractor import BasicEncoder
6 | from .attention import Attention1D
7 | from .position import PositionEmbeddingSine
8 | from .correlation import Correlation1D
9 | from .update import BasicUpdateBlock
10 | from utils.utils import coords_grid
11 |
12 |
13 | class Model(nn.Module):
14 | def __init__(self,
15 | downsample_factor=8,
16 | feature_channels=256,
17 | hidden_dim=128,
18 | context_dim=128,
19 | corr_radius=32,
20 | mixed_precision=False,
21 | **kwargs,
22 | ):
23 | super(Model, self).__init__()
24 |
25 | self.downsample_factor = downsample_factor
26 |
27 | self.feature_channels = feature_channels
28 |
29 | self.hidden_dim = hidden_dim
30 | self.context_dim = context_dim
31 | self.corr_radius = corr_radius
32 |
33 | self.mixed_precision = mixed_precision
34 |
35 | # feature network, context network, and update block
36 | self.fnet = BasicEncoder(output_dim=feature_channels, norm_fn='instance',
37 | )
38 |
39 | self.cnet = BasicEncoder(output_dim=hidden_dim + context_dim, norm_fn='batch',
40 | )
41 |
42 | # 1D attention
43 | corr_channels = (2 * corr_radius + 1) * 2
44 |
45 | self.attn_x = Attention1D(feature_channels,
46 | y_attention=False,
47 | double_cross_attn=True,
48 | )
49 | self.attn_y = Attention1D(feature_channels,
50 | y_attention=True,
51 | double_cross_attn=True,
52 | )
53 |
54 | # Update block
55 | self.update_block = BasicUpdateBlock(corr_channels=corr_channels,
56 | hidden_dim=hidden_dim,
57 | context_dim=context_dim,
58 | downsample_factor=downsample_factor,
59 | )
60 |
61 | def freeze_bn(self):
62 | for m in self.modules():
63 | if isinstance(m, nn.BatchNorm2d):
64 | m.eval()
65 |
66 | def initialize_flow(self, img, downsample=None):
67 | """ Flow is represented as the difference between two coordinate grids: flow = coords1 - coords0 """
68 | n, c, h, w = img.shape
69 | downsample_factor = self.downsample_factor if downsample is None else downsample
70 | coords0 = coords_grid(n, h // downsample_factor, w // downsample_factor).to(img.device)
71 | coords1 = coords_grid(n, h // downsample_factor, w // downsample_factor).to(img.device)
72 |
73 | # optical flow computed as difference: flow = coords1 - coords0
74 | return coords0, coords1
75 |
76 | def learned_upflow(self, flow, mask):
77 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
78 | n, _, h, w = flow.shape
79 | mask = mask.view(n, 1, 9, self.downsample_factor, self.downsample_factor, h, w)
80 | mask = torch.softmax(mask, dim=2)
81 |
82 | up_flow = F.unfold(self.downsample_factor * flow, [3, 3], padding=1)
83 | up_flow = up_flow.view(n, 2, 9, 1, 1, h, w)
84 |
85 | up_flow = torch.sum(mask * up_flow, dim=2)
86 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
87 | return up_flow.reshape(n, 2, self.downsample_factor * h, self.downsample_factor * w)
88 |
89 | def forward(self, image1, image2, iters=12, flow_init=None, test_mode=False,
90 | ):
91 | """ Estimate optical flow between pair of frames """
92 | image1 = 2 * (image1 / 255.0) - 1.0
93 | image2 = 2 * (image2 / 255.0) - 1.0
94 |
95 | # run the feature network
96 | feature1, feature2 = self.fnet([image1, image2])
97 |
98 | # Store the attention matrices, which can be used for an attention loss
99 | attn_x_list = []
100 | attn_y_list = []
101 |
102 | hdim = self.hidden_dim
103 | cdim = self.context_dim
104 |
105 | # position encoding
106 | pos_channels = self.feature_channels // 2
107 | pos_enc = PositionEmbeddingSine(pos_channels)
108 |
109 | position = pos_enc(feature1) # [B, C, H, W]
110 |
111 | # 1D correlation
112 | feature2_x, attn_x = self.attn_x(feature1, feature2, position)
113 | corr_fn_y = Correlation1D(feature1, feature2_x,
114 | radius=self.corr_radius,
115 | x_correlation=False,
116 | )
117 |
118 | feature2_y, attn_y = self.attn_y(feature1, feature2, position)
119 | corr_fn_x = Correlation1D(feature1, feature2_y,
120 | radius=self.corr_radius,
121 | x_correlation=True,
122 | )
123 |
124 | # run the context network
125 | cnet = self.cnet(image1) # list of feature pyramid, low scale to high scale
126 |
127 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
128 | net = torch.tanh(net)
129 | inp = torch.relu(inp)
130 |
131 | coords0, coords1 = self.initialize_flow(image1) # 1/8 resolution or 1/4
132 |
133 | if flow_init is not None: # flow_init is 1/8 resolution or 1/4
134 | coords1 = coords1 + flow_init
135 |
136 | flow_predictions = []
137 | for itr in range(iters):
138 | coords1 = coords1.detach() # stop gradient
139 |
140 | corr_x = corr_fn_x(coords1)
141 | corr_y = corr_fn_y(coords1)
142 | corr = torch.cat((corr_x, corr_y), dim=1) # [B, 2(2R+1), H, W]
143 |
144 | flow = coords1 - coords0
145 |
146 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow,
147 | upsample=not test_mode or itr == iters - 1,
148 | )
149 |
150 | coords1 = coords1 + delta_flow
151 |
152 | if test_mode:
153 | # only upsample the last iteration
154 | if itr == iters - 1:
155 | flow_up = self.learned_upflow(coords1 - coords0, up_mask)
156 |
157 | return coords1 - coords0, flow_up
158 | else:
159 | # upsample predictions
160 | flow_up = self.learned_upflow(coords1 - coords0, up_mask)
161 | flow_predictions.append(flow_up)
162 |
163 | return flow_predictions, attn_x_list, attn_y_list, coords1 - coords0
164 |
165 |
166 | def build_model(args):
167 | return Model(downsample_factor=args.downsample_factor,
168 | feature_channels=args.feature_channels,
169 | corr_radius=args.corr_radius,
170 | hidden_dim=args.hidden_dim,
171 | context_dim=args.context_dim,
172 | mixed_precision=args.mixed_precision,
173 | )
174 |
--------------------------------------------------------------------------------
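
A minimal inference sketch for the `Model` above (assumes the repo root is on `PYTHONPATH`). Inputs are raw images in [0, 255]; sizes divisible by 8 are used here so no padding is needed (`InputPadder` in `utils/utils.py` handles arbitrary sizes):

```python
import torch
from flow1d.flow1d import Model

model = Model(downsample_factor=8, feature_channels=256,
              hidden_dim=128, context_dim=128, corr_radius=32).eval()

img1 = torch.randint(0, 256, (1, 3, 368, 496)).float()
img2 = torch.randint(0, 256, (1, 3, 368, 496)).float()

with torch.no_grad():
    # test_mode returns (low-resolution flow, upsampled flow) from the last iteration
    flow_low, flow_up = model(img1, img2, iters=12, test_mode=True)

print(flow_low.shape)   # torch.Size([1, 2, 46, 62])
print(flow_up.shape)    # torch.Size([1, 2, 368, 496])
```

With `test_mode=False` the model returns a tuple whose first element is the list of upsampled predictions from every iteration; `main.py` passes that list to `criterion` in `loss.py`.
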
/flow1d/position.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import math
4 |
5 |
6 | class PositionEmbeddingSine(nn.Module):
7 | """
8 | https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py
9 | This is a more standard version of the position embedding, very similar to the one
10 | used by the Attention is all you need paper, generalized to work on images.
11 | """
12 |
13 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
14 | super().__init__()
15 | self.num_pos_feats = num_pos_feats
16 | self.temperature = temperature
17 | self.normalize = normalize
18 | if scale is not None and normalize is False:
19 | raise ValueError("normalize should be True if scale is passed")
20 | if scale is None:
21 | scale = 2 * math.pi
22 | self.scale = scale
23 |
24 | def forward(self, x):
25 | # x = tensor_list.tensors # [B, C, H, W]
26 | # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0
27 | b, c, h, w = x.size()
28 | mask = torch.ones((b, h, w), device=x.device) # [B, H, W]
29 | y_embed = mask.cumsum(1, dtype=torch.float32)
30 | x_embed = mask.cumsum(2, dtype=torch.float32)
31 | if self.normalize:
32 | eps = 1e-6
33 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
34 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
35 |
36 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
37 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
38 |
39 | pos_x = x_embed[:, :, :, None] / dim_t
40 | pos_y = y_embed[:, :, :, None] / dim_t
41 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
42 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
43 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
44 | return pos
45 |
--------------------------------------------------------------------------------
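
A quick shape check for `PositionEmbeddingSine` above (assumes the repo root is on `PYTHONPATH`). With `num_pos_feats = feature_channels // 2`, as used in `flow1d.py`, the embedding has the same channel count as the features it encodes:

```python
import torch
from flow1d.position import PositionEmbeddingSine

pos_enc = PositionEmbeddingSine(num_pos_feats=128)   # feature_channels // 2

feat = torch.randn(2, 256, 46, 62)
pos = pos_enc(feat)                  # depends only on the spatial size of feat
print(pos.shape)                     # torch.Size([2, 256, 46, 62])
```
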
/flow1d/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256,
8 | ):
9 | super(FlowHead, self).__init__()
10 |
11 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
12 |
13 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
14 | self.relu = nn.ReLU(inplace=True)
15 |
16 | def forward(self, x):
17 | out = self.conv2(self.relu(self.conv1(x)))
18 |
19 | return out
20 |
21 |
22 | class SepConvGRU(nn.Module):
23 | def __init__(self, hidden_dim=128, input_dim=192 + 128,
24 | kernel_size=5,
25 | ):
26 | padding = (kernel_size - 1) // 2
27 |
28 | super(SepConvGRU, self).__init__()
29 | self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding))
30 | self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding))
31 | self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding))
32 |
33 | self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0))
34 | self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0))
35 | self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0))
36 |
37 | def forward(self, h, x):
38 | # horizontal
39 | hx = torch.cat([h, x], dim=1)
40 | z = torch.sigmoid(self.convz1(hx))
41 | r = torch.sigmoid(self.convr1(hx))
42 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
43 | h = (1 - z) * h + z * q
44 |
45 | # vertical
46 | hx = torch.cat([h, x], dim=1)
47 | z = torch.sigmoid(self.convz2(hx))
48 | r = torch.sigmoid(self.convr2(hx))
49 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
50 | h = (1 - z) * h + z * q
51 |
52 | return h
53 |
54 |
55 | class BasicMotionEncoder(nn.Module):
56 | def __init__(self, corr_channels=324,
57 | ):
58 | super(BasicMotionEncoder, self).__init__()
59 |
60 | self.convc1 = nn.Conv2d(corr_channels, 256, 1, padding=0)
61 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
62 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
63 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
64 | self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)
65 |
66 | def forward(self, flow, corr):
67 | cor = F.relu(self.convc1(corr))
68 | cor = F.relu(self.convc2(cor))
69 | flo = F.relu(self.convf1(flow))
70 | flo = F.relu(self.convf2(flo))
71 |
72 | cor_flo = torch.cat([cor, flo], dim=1)
73 | out = F.relu(self.conv(cor_flo))
74 | return torch.cat([out, flow], dim=1)
75 |
76 |
77 | class BasicUpdateBlock(nn.Module):
78 | def __init__(self, corr_channels=324,
79 | hidden_dim=128,
80 | context_dim=128,
81 | downsample_factor=8,
82 | learn_upsample=True,
83 | **kwargs,
84 | ):
85 | super(BasicUpdateBlock, self).__init__()
86 |
87 | self.encoder = BasicMotionEncoder(corr_channels=corr_channels)
88 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=context_dim + hidden_dim)
89 |
90 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256,
91 | )
92 |
93 | self.learn_upsample = learn_upsample
94 |
95 | if learn_upsample:
96 | self.mask = nn.Sequential(
97 | nn.Conv2d(hidden_dim, 256, 3, padding=1),
98 | nn.ReLU(inplace=True),
99 | nn.Conv2d(256, downsample_factor ** 2 * 9, 1, padding=0))
100 |
101 | def forward(self, net, inp, corr, flow, upsample=True,
102 | **kwargs,
103 | ):
104 | motion_features = self.encoder(flow, corr)
105 |
106 | inp = torch.cat([inp, motion_features], dim=1)
107 |
108 | net = self.gru(net, inp)
109 | delta_flow = self.flow_head(net)
110 |
111 | if self.learn_upsample and upsample:
112 | # scale mask to balance gradients, following RAFT
113 | mask = .25 * self.mask(net)
114 | else:
115 | mask = None
116 | return net, mask, delta_flow
117 |
--------------------------------------------------------------------------------
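
A shape-check sketch for `BasicUpdateBlock` above (assumes the repo root is on `PYTHONPATH`). With `corr_radius=32` the concatenated 1D correlations have 2 * (2 * 32 + 1) = 130 channels, and the block returns the updated hidden state, a 9 * 8**2 = 576-channel convex-upsampling mask and a 2-channel flow update:

```python
import torch
from flow1d.update import BasicUpdateBlock

block = BasicUpdateBlock(corr_channels=130, hidden_dim=128,
                         context_dim=128, downsample_factor=8)

net = torch.randn(1, 128, 46, 62)    # GRU hidden state
inp = torch.randn(1, 128, 46, 62)    # context features
corr = torch.randn(1, 130, 46, 62)   # concatenated 1D correlations
flow = torch.zeros(1, 2, 46, 62)     # current flow estimate

net, up_mask, delta_flow = block(net, inp, corr, flow)
print(up_mask.shape)      # torch.Size([1, 576, 46, 62])
print(delta_flow.shape)   # torch.Size([1, 2, 46, 62])
```
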
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def criterion(flow_preds, flow_gt, valid, gamma=0.8, max_flow=400,
5 | ):
6 | """ Loss function defined over sequence of flow predictions
7 | """
8 |
9 | n_predictions = len(flow_preds)
10 | flow_loss = 0.0
11 |
12 | # exclude invalid pixels and extremely large displacements
13 | mag = torch.sum(flow_gt ** 2, dim=1).sqrt()
14 | valid = (valid >= 0.5) & (mag < max_flow)
15 |
16 | for i in range(n_predictions):
17 | i_weight = gamma ** (n_predictions - i - 1)
18 | i_loss = (flow_preds[i] - flow_gt).abs()
19 |
20 | flow_loss += i_weight * (valid[:, None] * i_loss).mean()
21 |
22 | epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt()
23 | epe = epe.view(-1)[valid.view(-1)]
24 |
25 | metrics = {
26 | 'epe': epe.mean().item(),
27 | '1px': (epe > 1).float().mean().item(),
28 | '3px': (epe > 3).float().mean().item(),
29 | '5px': (epe > 5).float().mean().item(),
30 | }
31 |
32 | return flow_loss, metrics
33 |
--------------------------------------------------------------------------------
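
A minimal sketch of calling `criterion` above with dummy data (assumes the repo root is on `PYTHONPATH`). Later predictions in the sequence receive larger weights: with three predictions and `gamma=0.8` the weights are 0.8**2, 0.8 and 1.0:

```python
import torch
from loss import criterion

flow_gt = torch.zeros(2, 2, 8, 8)    # [B, 2, H, W] ground-truth flow
valid = torch.ones(2, 8, 8)          # [B, H, W], 1 marks valid pixels

# Three dummy predictions that get closer to the ground truth.
flow_preds = [torch.randn(2, 2, 8, 8) * s for s in (1.0, 0.5, 0.1)]

loss, metrics = criterion(flow_preds, flow_gt, valid, gamma=0.8)
print(loss.item(), metrics['epe'], metrics['1px'])
```
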
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torch.utils.tensorboard import SummaryWriter
4 |
5 | import argparse
6 | import numpy as np
7 | import os
8 |
9 | from data import build_dataset
10 |
11 | from flow1d.flow1d import build_model
12 | from loss import criterion
13 | from evaluate import (validate_chairs, validate_sintel, validate_kitti,
14 | create_kitti_submission, create_sintel_submission,
15 | inference_on_dir,
16 | )
17 |
18 | from utils.logger import Logger
19 | from utils import misc
20 |
21 |
22 | def get_args_parser():
23 | parser = argparse.ArgumentParser()
24 |
25 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoints/tmp')
26 | parser.add_argument('--eval', action='store_true')
27 |
28 | # Dataset
29 | parser.add_argument('--image_size', default=[368, 496], type=int, nargs='+')
30 | parser.add_argument('--stage', default='chairs', type=str)
31 | parser.add_argument('--max_flow', default=400, type=int)
32 | parser.add_argument('--padding_factor', default=8, type=int)
33 | parser.add_argument('--val_dataset', default='chairs', type=str, nargs='+')
34 |
35 | # Create Sintel and KITTI submission
36 | parser.add_argument('--submission', action='store_true',
37 | help='Create submission')
38 | parser.add_argument('--warm_start', action='store_true')
39 | parser.add_argument('--output_path', default='output', type=str)
40 | parser.add_argument('--save_vis_flow', action='store_true')
41 | parser.add_argument('--no_save_flo', action='store_true')
42 |
43 | # Inference on a directory
44 | parser.add_argument('--inference_dir', default=None, type=str)
45 | parser.add_argument('--dir_paired_data', action='store_true',
46 | help='Paired data in a dir instead of a sequence')
47 | parser.add_argument('--save_flo_flow', action='store_true')
48 |
49 | # Training
50 | parser.add_argument('--lr', default=4e-4, type=float)
51 | parser.add_argument('--lr_warmup', default=0.05, type=float,
52 | help='Percentage of lr warmup')
53 | parser.add_argument('--batch_size', default=12, type=int)
54 | parser.add_argument('--num_workers', default=4, type=int)
55 | parser.add_argument('--weight_decay', default=1e-4, type=float)
56 | parser.add_argument('--grad_clip', default=1.0, type=float)
57 | parser.add_argument('--num_steps', default=100000, type=int)
58 | parser.add_argument('--seed', default=326, type=int)
59 | parser.add_argument('--summary_freq', default=100, type=int)
60 | parser.add_argument('--val_freq', default=5000, type=int)
61 | parser.add_argument('--save_ckpt_freq', default=50000, type=int)
62 | parser.add_argument('--resume', default=None, type=str)
63 | parser.add_argument('--no_resume_optimizer', action='store_true')
64 | parser.add_argument('--no_latest_ckpt', action='store_true')
65 | parser.add_argument('--save_latest_ckpt_freq', default=1000, type=int)
66 | parser.add_argument('--freeze_bn', action='store_true')
67 |
68 | parser.add_argument('--train_iters', default=12, type=int)
69 | parser.add_argument('--val_iters', default=12, type=int)
70 |
71 | # Flow1D
72 | parser.add_argument('--downsample_factor', default=8, type=int)
73 | parser.add_argument('--feature_channels', default=256, type=int)
74 | parser.add_argument('--corr_radius', default=32, type=int)
75 | parser.add_argument('--hidden_dim', default=128, type=int)
76 | parser.add_argument('--context_dim', default=128, type=int)
77 | parser.add_argument('--gamma', default=0.8, type=float,
78 | help='Exponential weighting')
79 |
80 | parser.add_argument('--mixed_precision', action='store_true')
81 |
82 | # Distributed training
83 | parser.add_argument('--local_rank', default=0, type=int)
84 |
85 | # Misc
86 | parser.add_argument('--count_time', action='store_true')
87 |
88 | return parser
89 |
90 |
91 | def main(args):
92 | if not args.eval and not args.submission and args.inference_dir is None:
93 | print('PyTorch version:', torch.__version__)
94 | print(args)
95 | misc.save_args(args)
96 | misc.check_path(args.checkpoint_dir)
97 | misc.save_command(args.checkpoint_dir)
98 |
99 | misc.check_path(args.output_path)
100 |
101 | seed = args.seed
102 | torch.manual_seed(seed)
103 | np.random.seed(seed)
104 |
105 | torch.backends.cudnn.benchmark = True
106 |
107 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
108 |
109 | # model
110 | model = build_model(args).to(device)
111 |
112 | if not args.eval:
113 | print('Model definition:')
114 | print(model)
115 |
116 | if torch.cuda.device_count() > 1:
117 | print('Use %d GPUs' % torch.cuda.device_count())
118 | model = torch.nn.DataParallel(model)
119 |
120 | model_without_ddp = model.module
121 | else:
122 | model_without_ddp = model
123 |
124 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
125 | print('Number of params:', num_params)
126 | if not args.eval and not args.submission and args.inference_dir is None:
127 | save_name = '%d_parameters' % num_params
128 | open(os.path.join(args.checkpoint_dir, save_name), 'a').close()
129 |
130 | # optimizer
131 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr,
132 | weight_decay=args.weight_decay)
133 |
134 | start_epoch = 0
135 | start_step = 0
136 |
137 | # resume checkpoints
138 | if args.resume:
139 | print('Load checkpoint: %s' % args.resume)
140 | checkpoint = torch.load(args.resume)
141 | weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
142 | model_without_ddp.load_state_dict(weights, strict=False)
143 |
144 | if 'optimizer' in checkpoint and 'step' in checkpoint and 'epoch' in checkpoint and not \
145 | args.no_resume_optimizer:
146 | print('Load optimizer')
147 | optimizer.load_state_dict(checkpoint['optimizer'])
148 | start_epoch = checkpoint['epoch']
149 | start_step = checkpoint['step']
150 |
151 | print('start_epoch: %d, start_step: %d' % (start_epoch, start_step))
152 |
153 | # evaluate
154 | if args.eval:
155 | if 'chairs' in args.val_dataset:
156 | validate_chairs(model_without_ddp,
157 | iters=args.val_iters,
158 | )
159 | elif 'sintel' in args.val_dataset:
160 | validate_sintel(model_without_ddp,
161 | iters=args.val_iters,
162 | padding_factor=args.padding_factor,
163 | count_time=args.count_time,
164 | )
165 | elif 'kitti' in args.val_dataset:
166 | validate_kitti(model_without_ddp,
167 | iters=args.val_iters,
168 | padding_factor=args.padding_factor,
169 | )
170 | else:
171 | raise ValueError(f'Dataset type {args.val_dataset} is not supported')
172 |
173 | return
174 |
175 | # create sintel and kitti submission
176 | if args.submission:
177 | if args.val_dataset[0] == 'sintel':
178 | create_sintel_submission(model_without_ddp,
179 | iters=args.val_iters,
180 | warm_start=args.warm_start,
181 | output_path=args.output_path,
182 | padding_factor=args.padding_factor,
183 | save_vis_flow=args.save_vis_flow,
184 | no_save_flo=args.no_save_flo,
185 | )
186 | elif args.val_dataset[0] == 'kitti':
187 | create_kitti_submission(model_without_ddp,
188 | iters=args.val_iters,
189 | output_path=args.output_path,
190 | padding_factor=args.padding_factor,
191 | save_vis_flow=args.save_vis_flow,
192 | )
193 | else:
194 | raise ValueError(f'Dataset {args.val_dataset} is not supported for submission')
195 |
196 | return
197 |
198 | # inference on a directory
199 | if args.inference_dir is not None:
200 | inference_on_dir(model_without_ddp,
201 | inference_dir=args.inference_dir,
202 | iters=args.val_iters,
203 | warm_start=args.warm_start,
204 | output_path=args.output_path,
205 | padding_factor=args.padding_factor,
206 | paired_data=args.dir_paired_data,
207 | save_flo_flow=args.save_flo_flow,
208 | )
209 |
210 | return
211 |
212 | # training dataset
213 | train_dataset = build_dataset(args)
214 | print('Number of training images:', len(train_dataset))
215 |
216 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
217 | shuffle=True, num_workers=args.num_workers,
218 | pin_memory=True, drop_last=True)
219 |
220 | last_epoch = start_step if args.resume and not args.no_resume_optimizer else -1
221 | lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps + 10,
222 | pct_start=args.lr_warmup, cycle_momentum=False,
223 | anneal_strategy='linear',
224 | last_epoch=last_epoch,
225 | )
226 |
227 | if args.local_rank == 0:
228 | summary_writer = SummaryWriter(args.checkpoint_dir)
229 | logger = Logger(lr_scheduler, summary_writer, args.summary_freq,
230 | start_step=start_step)
231 |
232 | total_steps = start_step
233 | epoch = start_epoch
234 | print('Start training')
235 | while total_steps < args.num_steps:
236 | model.train()
237 |
238 | # freeze BN after pretraining on chairs
239 | if args.freeze_bn:
240 | model_without_ddp.freeze_bn()
241 |
242 | print('Start epoch %d' % (epoch + 1))
243 | for i, sample in enumerate(train_loader):
244 | img1, img2, flow_gt, valid = [x.to(device) for x in sample]
245 |
246 | flow_preds = model(img1, img2, iters=args.train_iters)[0]
247 |
248 | loss, metrics = criterion(flow_preds, flow_gt, valid,
249 | gamma=args.gamma,
250 | max_flow=args.max_flow)
251 |
252 | optimizer.zero_grad()
253 | loss.backward()
254 |
255 | # gradient clipping
256 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
257 |
258 | optimizer.step()
259 | lr_scheduler.step()
260 |
261 | if args.local_rank == 0:
262 | logger.push(metrics)
263 |
264 | logger.add_image_summary(img1, img2, flow_preds, flow_gt)
265 |
266 | total_steps += 1
267 |
268 | if total_steps % args.save_ckpt_freq == 0 or total_steps == args.num_steps:
269 | if args.local_rank == 0:
270 | print('Save checkpoint at step: %d' % total_steps)
271 | checkpoint_path = os.path.join(args.checkpoint_dir, 'step_%06d.pth' % total_steps)
272 | torch.save({
273 | 'model': model_without_ddp.state_dict()
274 | }, checkpoint_path)
275 |
276 | if total_steps % args.save_latest_ckpt_freq == 0:
277 | # Save the latest checkpoint periodically
278 | checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint_latest.pth')
279 |
280 | if args.local_rank == 0:
281 | print('Save latest checkpoint')
282 | torch.save({
283 | 'model': model_without_ddp.state_dict(),
284 | 'optimizer': optimizer.state_dict(),
285 | 'step': total_steps,
286 | 'epoch': epoch,
287 | }, checkpoint_path)
288 |
289 | if total_steps % args.val_freq == 0:
290 | if args.local_rank == 0:
291 | print('Start validation')
292 |
293 | val_results = {}
294 | # Support validation on multiple datasets
295 | if 'chairs' in args.val_dataset:
296 | results_dict = validate_chairs(model_without_ddp,
297 | iters=args.val_iters,
298 | )
299 | val_results.update(results_dict)
300 | if 'sintel' in args.val_dataset:
301 | results_dict = validate_sintel(model_without_ddp,
302 | iters=args.val_iters,
303 | padding_factor=args.padding_factor,
304 | )
305 | val_results.update(results_dict)
306 |
307 | if 'kitti' in args.val_dataset:
308 | results_dict = validate_kitti(model_without_ddp,
309 | iters=args.val_iters,
310 | padding_factor=args.padding_factor,
311 | )
312 | val_results.update(results_dict)
313 |
314 | logger.write_dict(val_results)
315 |
316 | # Save validation results
317 | val_file = os.path.join(args.checkpoint_dir, 'val_results.txt')
318 | with open(val_file, 'a') as f:
319 | f.write('step: %06d\t' % total_steps)
320 | # order of metrics
321 | metrics = ['chairs_epe', 'chairs_1px', 'clean_epe', 'clean_1px', 'final_epe', 'final_1px',
322 | 'kitti_epe', 'kitti_f1']
323 | for metric in metrics:
324 | if metric in val_results.keys():
325 | f.write('%s: %.3f\t' % (metric, val_results[metric]))
326 | f.write('\n')
327 |
328 | model.train()
329 |
330 | # freeze BN after pretraining on chairs
331 | if args.freeze_bn:
332 | model_without_ddp.freeze_bn()
333 |
334 | if total_steps >= args.num_steps:
335 | print('Training done')
336 |
337 | return
338 |
339 | epoch += 1
340 |
341 |
342 | if __name__ == '__main__':
343 | parser = get_args_parser()
344 | args = parser.parse_args()
345 |
346 | main(args)
347 |
--------------------------------------------------------------------------------
/scripts/demo.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CUDA_VISIBLE_DEVICES=0 python main.py \
4 | --resume pretrained/flow1d_highres-e0b98d7e.pth \
5 | --val_iters 24 \
6 | --inference_dir demo/dogs-jump \
7 | --output_path output/flow1d-dogs-jump
8 |
--------------------------------------------------------------------------------
/scripts/evaluate.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 |
4 | # evaluate chairs & things trained model on kitti (24 iters)
5 | CUDA_VISIBLE_DEVICES=0 python main.py \
6 | --eval \
7 | --val_dataset kitti \
8 | --resume pretrained/flow1d_things-fd4bee1f.pth \
9 | --val_iters 24
10 |
11 |
12 | # evaluate chairs & things trained model on sintel (32 iters)
13 | CUDA_VISIBLE_DEVICES=0 python main.py \
14 | --eval \
15 | --val_dataset sintel \
16 | --resume pretrained/flow1d_things-fd4bee1f.pth \
17 | --val_iters 32
18 |
19 |
20 |
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # can be trained on a single 32G V100 GPU
4 |
5 | # chairs
6 | CHECKPOINT_DIR=checkpoints/chairs-flow1d && \
7 | mkdir -p ${CHECKPOINT_DIR} && \
8 | CUDA_VISIBLE_DEVICES=0 python main.py \
9 | --checkpoint_dir ${CHECKPOINT_DIR} \
10 | --batch_size 12 \
11 | --val_dataset chairs sintel kitti \
12 | --val_iters 12 \
13 | --lr 4e-4 \
14 | --image_size 368 496 \
15 | --summary_freq 100 \
16 | --val_freq 10000 \
17 | --save_ckpt_freq 5000 \
18 | --save_latest_ckpt_freq 1000 \
19 | --num_steps 100000 \
20 | 2>&1 | tee ${CHECKPOINT_DIR}/train.log
21 |
22 | # things
23 | CHECKPOINT_DIR=checkpoints/things-flow1d && \
24 | mkdir -p ${CHECKPOINT_DIR} && \
25 | CUDA_VISIBLE_DEVICES=0 python main.py \
26 | --stage things \
27 | --resume checkpoints/chairs-flow1d/step_100000.pth \
28 | --no_resume_optimizer \
29 | --checkpoint_dir ${CHECKPOINT_DIR} \
30 | --batch_size 6 \
31 | --val_dataset sintel kitti \
32 | --val_iters 12 \
33 | --lr 1.25e-4 \
34 | --image_size 400 720 \
35 | --freeze_bn \
36 | --summary_freq 100 \
37 | --val_freq 10000 \
38 | --save_ckpt_freq 5000 \
39 | --save_latest_ckpt_freq 1000 \
40 | --num_steps 100000 \
41 | 2>&1 | tee ${CHECKPOINT_DIR}/train.log
42 |
43 | # sintel
44 | CHECKPOINT_DIR=checkpoints/sintel-flow1d && \
45 | mkdir -p ${CHECKPOINT_DIR} && \
46 | CUDA_VISIBLE_DEVICES=0 python main.py \
47 | --stage sintel \
48 | --resume checkpoints/things-flow1d/step_100000.pth \
49 | --no_resume_optimizer \
50 | --checkpoint_dir ${CHECKPOINT_DIR} \
51 | --batch_size 6 \
52 | --val_dataset sintel kitti \
53 | --val_iters 12 \
54 | --lr 1.25e-4 \
55 | --weight_decay 1e-5 \
56 | --gamma 0.85 \
57 | --image_size 368 960 \
58 | --freeze_bn \
59 | --summary_freq 100 \
60 | --val_freq 10000 \
61 | --save_ckpt_freq 5000 \
62 | --save_latest_ckpt_freq 1000 \
63 | --num_steps 100000 \
64 | 2>&1 | tee ${CHECKPOINT_DIR}/train.log
65 |
66 | # kitti
67 | CHECKPOINT_DIR=checkpoints/kitti-flow1d && \
68 | mkdir -p ${CHECKPOINT_DIR} && \
69 | CUDA_VISIBLE_DEVICES=0 python main.py \
70 | --stage kitti \
71 | --resume checkpoints/sintel-flow1d/step_100000.pth \
72 | --no_resume_optimizer \
73 | --checkpoint_dir ${CHECKPOINT_DIR} \
74 | --batch_size 6 \
75 | --val_dataset kitti \
76 | --val_iters 12 \
77 | --lr 1e-4 \
78 | --weight_decay 1e-5 \
79 | --gamma 0.85 \
80 | --image_size 320 1024 \
81 | --freeze_bn \
82 | --summary_freq 100 \
83 | --val_freq 10000 \
84 | --save_ckpt_freq 5000 \
85 | --save_latest_ckpt_freq 1000 \
86 | --num_steps 50000 \
87 | 2>&1 | tee ${CHECKPOINT_DIR}/train.log
88 |
89 |
--------------------------------------------------------------------------------
/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-08-03
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 |
21 |
22 | def make_colorwheel():
23 | '''
24 | Generates a color wheel for optical flow visualization as presented in:
25 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
26 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
27 | According to the C++ source code of Daniel Scharstein
28 | According to the Matlab source code of Deqing Sun
29 | '''
30 |
31 | RY = 15
32 | YG = 6
33 | GC = 4
34 | CB = 11
35 | BM = 13
36 | MR = 6
37 |
38 | ncols = RY + YG + GC + CB + BM + MR
39 | colorwheel = np.zeros((ncols, 3))
40 | col = 0
41 |
42 | # RY
43 | colorwheel[0:RY, 0] = 255
44 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
45 | col = col + RY
46 | # YG
47 | colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
48 | colorwheel[col:col + YG, 1] = 255
49 | col = col + YG
50 | # GC
51 | colorwheel[col:col + GC, 1] = 255
52 | colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
53 | col = col + GC
54 | # CB
55 | colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
56 | colorwheel[col:col + CB, 2] = 255
57 | col = col + CB
58 | # BM
59 | colorwheel[col:col + BM, 2] = 255
60 | colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
61 | col = col + BM
62 | # MR
63 | colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
64 | colorwheel[col:col + MR, 0] = 255
65 | return colorwheel
66 |
67 |
68 | def flow_compute_color(u, v, convert_to_bgr=False):
69 | '''
70 | Applies the flow color wheel to (possibly clipped) flow components u and v.
71 | According to the C++ source code of Daniel Scharstein
72 | According to the Matlab source code of Deqing Sun
73 | :param u: np.ndarray, input horizontal flow
74 | :param v: np.ndarray, input vertical flow
75 | :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
76 | :return:
77 | '''
78 |
79 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
80 |
81 | colorwheel = make_colorwheel() # shape [55x3]
82 | ncols = colorwheel.shape[0]
83 |
84 | rad = np.sqrt(np.square(u) + np.square(v))
85 | a = np.arctan2(-v, -u) / np.pi
86 |
87 | fk = (a + 1) / 2 * (ncols - 1) + 1
88 | k0 = np.floor(fk).astype(np.int32)
89 | k1 = k0 + 1
90 | k1[k1 == ncols] = 1
91 | f = fk - k0
92 |
93 | for i in range(colorwheel.shape[1]):
94 | tmp = colorwheel[:, i]
95 | col0 = tmp[k0] / 255.0
96 | col1 = tmp[k1] / 255.0
97 | col = (1 - f) * col0 + f * col1
98 |
99 | idx = (rad <= 1)
100 | col[idx] = 1 - rad[idx] * (1 - col[idx])
101 | col[~idx] = col[~idx] * 0.75 # out of range?
102 |
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2 - i if convert_to_bgr else i
105 | flow_image[:, :, ch_idx] = np.floor(255 * col)
106 |
107 | return flow_image
108 |
109 |
110 | def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
111 | '''
112 | Expects a two dimensional flow image of shape [H,W,2]
113 | According to the C++ source code of Daniel Scharstein
114 | According to the Matlab source code of Deqing Sun
115 | :param flow_uv: np.ndarray of shape [H,W,2]
116 | :param clip_flow: float, maximum clipping value for flow
117 | :return:
118 | '''
119 |
120 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
121 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
122 |
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 |
126 | u = flow_uv[:, :, 0]
127 | v = flow_uv[:, :, 1]
128 |
129 | rad = np.sqrt(np.square(u) + np.square(v))
130 | rad_max = np.max(rad)
131 |
132 | epsilon = 1e-5
133 | u = u / (rad_max + epsilon)
134 | v = v / (rad_max + epsilon)
135 |
136 | return flow_compute_color(u, v, convert_to_bgr)
137 |
138 |
139 | UNKNOWN_FLOW_THRESH = 1e7
140 | SMALLFLOW = 0.0
141 | LARGEFLOW = 1e8
142 |
143 |
144 | def make_color_wheel():
145 | """
146 | Generate color wheel according to the Middlebury color code
147 | :return: Color wheel
148 | """
149 | RY = 15
150 | YG = 6
151 | GC = 4
152 | CB = 11
153 | BM = 13
154 | MR = 6
155 |
156 | ncols = RY + YG + GC + CB + BM + MR
157 |
158 | colorwheel = np.zeros([ncols, 3])
159 |
160 | col = 0
161 |
162 | # RY
163 | colorwheel[0:RY, 0] = 255
164 | colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
165 | col += RY
166 |
167 | # YG
168 | colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))
169 | colorwheel[col:col + YG, 1] = 255
170 | col += YG
171 |
172 | # GC
173 | colorwheel[col:col + GC, 1] = 255
174 | colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
175 | col += GC
176 |
177 | # CB
178 | colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))
179 | colorwheel[col:col + CB, 2] = 255
180 | col += CB
181 |
182 | # BM
183 | colorwheel[col:col + BM, 2] = 255
184 | colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
185 | col += BM
186 |
187 | # MR
188 | colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
189 | colorwheel[col:col + MR, 0] = 255
190 |
191 | return colorwheel
192 |
193 |
194 | def compute_color(u, v):
195 | """
196 | compute optical flow color map
197 | :param u: optical flow horizontal map
198 | :param v: optical flow vertical map
199 | :return: optical flow in color code
200 | """
201 | [h, w] = u.shape
202 | img = np.zeros([h, w, 3])
203 | nanIdx = np.isnan(u) | np.isnan(v)
204 | u[nanIdx] = 0
205 | v[nanIdx] = 0
206 |
207 | colorwheel = make_color_wheel()
208 | ncols = np.size(colorwheel, 0)
209 |
210 | rad = np.sqrt(u ** 2 + v ** 2)
211 |
212 | a = np.arctan2(-v, -u) / np.pi
213 |
214 | fk = (a + 1) / 2 * (ncols - 1) + 1
215 |
216 | k0 = np.floor(fk).astype(int)
217 |
218 | k1 = k0 + 1
219 | k1[k1 == ncols + 1] = 1
220 | f = fk - k0
221 |
222 | for i in range(0, np.size(colorwheel, 1)):
223 | tmp = colorwheel[:, i]
224 | col0 = tmp[k0 - 1] / 255
225 | col1 = tmp[k1 - 1] / 255
226 | col = (1 - f) * col0 + f * col1
227 |
228 | idx = rad <= 1
229 | col[idx] = 1 - rad[idx] * (1 - col[idx])
230 | notidx = np.logical_not(idx)
231 |
232 | col[notidx] *= 0.75
233 | img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
234 |
235 | return img
236 |
237 |
238 | # from https://github.com/gengshan-y/VCN
239 | def flow_to_image(flow):
240 | """
241 | Convert flow into Middlebury color code image
242 | :param flow: optical flow map
243 | :return: optical flow image in middlebury color
244 | """
245 | u = flow[:, :, 0]
246 | v = flow[:, :, 1]
247 |
248 | maxu = -999.
249 | maxv = -999.
250 | minu = 999.
251 | minv = 999.
252 |
253 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
254 | u[idxUnknow] = 0
255 | v[idxUnknow] = 0
256 |
257 | maxu = max(maxu, np.max(u))
258 | minu = min(minu, np.min(u))
259 |
260 | maxv = max(maxv, np.max(v))
261 | minv = min(minv, np.min(v))
262 |
263 | rad = np.sqrt(u ** 2 + v ** 2)
264 | maxrad = max(-1, np.max(rad))
265 |
266 | u = u / (maxrad + np.finfo(float).eps)
267 | v = v / (maxrad + np.finfo(float).eps)
268 |
269 | img = compute_color(u, v)
270 |
271 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
272 | img[idx] = 0
273 |
274 | return np.uint8(img)
275 |
276 |
277 | def save_vis_flow_tofile(flow, output_path):
278 | vis_flow = flow_to_image(flow)
279 | from PIL import Image
280 | img = Image.fromarray(vis_flow)
281 | img.save(output_path)
282 |
283 |
284 | def flow_tensor_to_image(flow):
285 | """Used for tensorboard visualization"""
286 | flow = flow.permute(1, 2, 0) # [H, W, 2]
287 | flow = flow.detach().cpu().numpy()
288 | flow = flow_to_image(flow) # [H, W, 3]
289 | flow = np.transpose(flow, (2, 0, 1)) # [3, H, W]
290 |
291 | return flow
292 |
--------------------------------------------------------------------------------
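
A minimal usage sketch for the visualization helpers above (assumes the repo root is on `PYTHONPATH`). `flow_to_image` maps an [H, W, 2] flow field to a uint8 image using the Middlebury color wheel:

```python
import numpy as np
from utils.flow_viz import flow_to_image, save_vis_flow_tofile

flow = (np.random.randn(120, 160, 2) * 5.0).astype(np.float32)   # dummy flow field

vis = flow_to_image(flow)      # [H, W, 3] uint8 color image
print(vis.shape, vis.dtype)

# save_vis_flow_tofile(flow, 'flow_vis.png')   # same conversion, written to disk with PIL
```
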
/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 | import cv2
6 |
7 | TAG_CHAR = np.array([202021.25], np.float32)
8 |
9 |
10 | def readFlow(fn):
11 | """ Read .flo file in Middlebury format"""
12 | # Code adapted from:
13 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
14 |
15 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
16 | # print 'fn = %s'%(fn)
17 | with open(fn, 'rb') as f:
18 | magic = np.fromfile(f, np.float32, count=1)
19 | if 202021.25 != magic:
20 | print('Magic number incorrect. Invalid .flo file')
21 | return None
22 | else:
23 | w = np.fromfile(f, np.int32, count=1)
24 | h = np.fromfile(f, np.int32, count=1)
25 | # print 'Reading %d x %d flo file\n' % (w, h)
26 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
27 | # Reshape data into 3D array (columns, rows, bands)
28 | # The reshape here is for visualization, the original code is (w,h,2)
29 | return np.resize(data, (int(h), int(w), 2))
30 |
31 |
32 | def readPFM(file):
33 | file = open(file, 'rb')
34 |
35 | color = None
36 | width = None
37 | height = None
38 | scale = None
39 | endian = None
40 |
41 | header = file.readline().rstrip()
42 | if header == b'PF':
43 | color = True
44 | elif header == b'Pf':
45 | color = False
46 | else:
47 | raise Exception('Not a PFM file.')
48 |
49 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
50 | if dim_match:
51 | width, height = map(int, dim_match.groups())
52 | else:
53 | raise Exception('Malformed PFM header.')
54 |
55 | scale = float(file.readline().rstrip())
56 | if scale < 0: # little-endian
57 | endian = '<'
58 | scale = -scale
59 | else:
60 | endian = '>' # big-endian
61 |
62 | data = np.fromfile(file, endian + 'f')
63 | shape = (height, width, 3) if color else (height, width)
64 |
65 | data = np.reshape(data, shape)
66 | data = np.flipud(data)
67 | return data
68 |
69 |
70 | def writeFlow(filename, uv, v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert (uv.ndim == 3)
81 | assert (uv.shape[2] == 2)
82 | u = uv[:, :, 0]
83 | v = uv[:, :, 1]
84 | else:
85 | u = uv
86 |
87 | assert (u.shape == v.shape)
88 | height, width = u.shape
89 | f = open(filename, 'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width * nBands))
96 | tmp[:, np.arange(width) * 2] = u
97 | tmp[:, np.arange(width) * 2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
104 | flow = flow[:, :, ::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2 ** 15) / 64.0
107 | return flow, valid
108 |
109 |
110 | def readDispKITTI(filename):
111 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
112 | valid = disp > 0.0
113 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
114 | return flow, valid
115 |
116 |
117 | def writeFlowKITTI(filename, uv):
118 | uv = 64.0 * uv + 2 ** 15
119 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
120 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
121 | cv2.imwrite(filename, uv[..., ::-1])
122 |
123 |
124 | def read_gen(file_name, pil=False):
125 | ext = splitext(file_name)[-1]
126 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
127 | return Image.open(file_name)
128 | elif ext == '.bin' or ext == '.raw':
129 | return np.load(file_name)
130 | elif ext == '.flo':
131 | return readFlow(file_name).astype(np.float32)
132 | elif ext == '.pfm':
133 | flow = readPFM(file_name).astype(np.float32)
134 | if len(flow.shape) == 2:
135 | return flow
136 | else:
137 | return flow[:, :, :-1]
138 | return []
139 |
--------------------------------------------------------------------------------
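
A round-trip sketch for the Middlebury `.flo` helpers above (assumes the repo root is on `PYTHONPATH` and a little-endian machine, as noted in `readFlow`):

```python
import numpy as np
from utils.frame_utils import writeFlow, readFlow

flow = np.random.randn(32, 48, 2).astype(np.float32)   # [H, W, 2] dummy flow

writeFlow('tmp_flow.flo', flow)
loaded = readFlow('tmp_flow.flo')

print(loaded.shape)               # (32, 48, 2)
print(np.allclose(flow, loaded))  # True
```
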
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from utils.flow_viz import flow_tensor_to_image
4 |
5 |
6 | class Logger:
7 | def __init__(self, lr_scheduler,
8 | summary_writer,
9 | summary_freq=100,
10 | start_step=0,
11 | ):
12 | self.lr_scheduler = lr_scheduler
13 | self.total_steps = start_step
14 | self.running_loss = {}
15 | self.summary_writer = summary_writer
16 | self.summary_freq = summary_freq
17 |
18 | def print_training_status(self, mode='train'):
19 | print('step: %06d \t epe: %.3f' % (self.total_steps, self.running_loss['epe'] / self.summary_freq))
20 |
21 | for k in self.running_loss:
22 | self.summary_writer.add_scalar(mode + '/' + k,
23 | self.running_loss[k] / self.summary_freq, self.total_steps)
24 | self.running_loss[k] = 0.0
25 |
26 | def lr_summary(self):
27 | lr = self.lr_scheduler.get_last_lr()[0]
28 | self.summary_writer.add_scalar('lr', lr, self.total_steps)
29 |
30 | def add_image_summary(self, img1, img2, flow_preds, flow_gt, mode='train',
31 | pred_bidirectional_flow=False):
32 | if self.total_steps % self.summary_freq == 0:
33 | img_concat = torch.cat((img1[0].detach().cpu(), img2[0].detach().cpu()), dim=-1)
34 | img_concat = img_concat.type(torch.uint8) # convert to uint8 to visualize in tensorboard
35 |
36 | flow_pred = flow_tensor_to_image(flow_preds[-1][0])
37 | forward_flow_gt = flow_tensor_to_image(flow_gt[0])
38 | flow_concat = torch.cat((torch.from_numpy(flow_pred),
39 | torch.from_numpy(forward_flow_gt)), dim=-1)
40 |
41 | concat = torch.cat((img_concat, flow_concat), dim=-2)
42 |
43 | self.summary_writer.add_image(mode + '/img_pred_gt', concat, self.total_steps)
44 |
45 | def add_init_flow_summary(self, init_flow, mode='train', tag='init_flow'):
46 | if self.total_steps % self.summary_freq == 0:
47 | init_flow = flow_tensor_to_image(init_flow[0])
48 | init_flow = torch.from_numpy(init_flow)
49 |
50 | self.summary_writer.add_image(mode + '/' + tag, init_flow, self.total_steps)
51 |
52 | def push(self, metrics, mode='train'):
53 | self.total_steps += 1
54 |
55 | self.lr_summary()
56 |
57 | for key in metrics:
58 | if key not in self.running_loss:
59 | self.running_loss[key] = 0.0
60 |
61 | self.running_loss[key] += metrics[key]
62 |
63 | if self.total_steps % self.summary_freq == 0:
64 | self.print_training_status(mode)
65 | self.running_loss = {}
66 |
67 | def write_dict(self, results):
68 | for key in results:
69 | tag = key.split('_')[0]
70 | tag = tag + '/' + key
71 | self.summary_writer.add_scalar(tag, results[key], self.total_steps)
72 |
73 | def close(self):
74 | self.summary_writer.close()
75 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import sys
4 | import json
5 |
6 |
7 | def read_text_lines(filepath):
8 | with open(filepath, 'r') as f:
9 | lines = f.readlines()
10 | lines = [l.rstrip() for l in lines]
11 | return lines
12 |
13 |
14 | def check_path(path):
15 | if not os.path.exists(path):
16 | os.makedirs(path, exist_ok=True) # explicitly set exist_ok when multi-processing
17 |
18 |
19 | def save_command(save_path, filename='command_train.txt'):
20 | check_path(save_path)
21 | command = sys.argv
22 | save_file = os.path.join(save_path, filename)
23 | # Save all training commands when resuming training
24 | with open(save_file, 'a') as f:
25 | f.write(' '.join(command))
26 | f.write('\n\n')
27 |
28 |
29 | def save_args(args, filename='args.json'):
30 | args_dict = vars(args)
31 | check_path(args.checkpoint_dir)
32 | save_path = os.path.join(args.checkpoint_dir, filename)
33 |
34 | # Save all training args when resuming training
35 | with open(save_path, 'a') as f:
36 | json.dump(args_dict, f, indent=4, sort_keys=False)
37 | f.write('\n\n')
38 |
39 |
40 | def int_list(s):
41 | """Convert string to int list"""
42 | return [int(x) for x in s.split(',')]
43 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 |
10 | def __init__(self, dims, mode='sintel', padding_factor=8):
11 | self.ht, self.wd = dims[-2:]
12 | pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor
13 | pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor
14 | if mode == 'sintel':
15 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
16 | else:
17 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
18 |
19 | def pad(self, *inputs):
20 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
21 |
22 | def unpad(self, x):
23 | ht, wd = x.shape[-2:]
24 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
25 | return x[..., c[0]:c[1], c[2]:c[3]]
26 |
27 |
28 | def forward_interpolate(flow):
29 | flow = flow.detach().cpu().numpy() # [2, H, W]
30 | dx, dy = flow[0], flow[1]
31 |
32 | ht, wd = dx.shape
33 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
34 |
35 | x1 = x0 + dx
36 | y1 = y0 + dy
37 |
38 | x1 = x1.reshape(-1)
39 | y1 = y1.reshape(-1)
40 | dx = dx.reshape(-1)
41 | dy = dy.reshape(-1)
42 |
43 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
44 | x1 = x1[valid]
45 | y1 = y1[valid]
46 | dx = dx[valid]
47 | dy = dy[valid]
48 |
49 | flow_x = interpolate.griddata(
50 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
51 |
52 | flow_y = interpolate.griddata(
53 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
54 |
55 | flow = np.stack([flow_x, flow_y], axis=0)
56 | return torch.from_numpy(flow).float()
57 |
58 |
59 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
60 | """ Wrapper for grid_sample, uses pixel coordinates """
61 | if coords.size(-1) != 2: # [B, 2, H, W] -> [B, H, W, 2]
62 | coords = coords.permute(0, 2, 3, 1)
63 |
64 | H, W = img.shape[-2:]
65 | # H = height if height is not None else img.shape[-2]
66 | # W = width if width is not None else img.shape[-1]
67 |
68 | xgrid, ygrid = coords.split([1, 1], dim=-1)
69 |
70 | # Handle H or W equal to 1 by explicitly redefining the height and width
71 | if H == 1:
72 | assert ygrid.abs().max() < 1e-8
73 | H = 10
74 | if W == 1:
75 | assert xgrid.abs().max() < 1e-8
76 | W = 10
77 |
78 | xgrid = 2 * xgrid / (W - 1) - 1
79 | ygrid = 2 * ygrid / (H - 1) - 1
80 |
81 | grid = torch.cat([xgrid, ygrid], dim=-1)
82 | img = F.grid_sample(img, grid, mode=mode, align_corners=True)
83 |
84 | if mask:
85 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
86 | return img, mask.squeeze(-1).float()
87 |
88 | return img
89 |
90 |
91 | def coords_grid(batch, ht, wd, normalize=False):
92 | if normalize: # [-1, 1]
93 | coords = torch.meshgrid(2 * torch.arange(ht) / (ht - 1) - 1,
94 | 2 * torch.arange(wd) / (wd - 1) - 1)
95 | else:
96 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
97 | coords = torch.stack(coords[::-1], dim=0).float()
98 | return coords[None].repeat(batch, 1, 1, 1) # [B, 2, H, W]
99 |
100 |
101 | def coords_grid_np(h, w): # used for accumulating high speed sintel flow data
102 | coords = np.meshgrid(np.arange(h, dtype=np.float32),
103 | np.arange(w, dtype=np.float32), indexing='ij')
104 | coords = np.stack(coords[::-1], axis=-1) # [H, W, 2]
105 |
106 | return coords
107 |
108 |
109 | def normalize_coords(grid):
110 | """Normalize coordinates of image scale to [-1, 1]
111 | Args:
112 | grid: [B, 2, H, W]
113 | """
114 | assert grid.size(1) == 2
115 | h, w = grid.size()[2:]
116 | grid[:, 0, :, :] = 2 * (grid[:, 0, :, :].clone() / (w - 1)) - 1 # x: [-1, 1]
117 | grid[:, 1, :, :] = 2 * (grid[:, 1, :, :].clone() / (h - 1)) - 1 # y: [-1, 1]
118 | # grid = grid.permute((0, 2, 3, 1)) # [B, H, W, 2]
119 | return grid
120 |
121 |
122 | def flow_warp(feature, flow, mask=False):
123 | b, c, h, w = feature.size()
124 | assert flow.size(1) == 2
125 |
126 | grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W]
127 |
128 | return bilinear_sampler(feature, grid, mask=mask)
129 |
130 |
131 | def upflow8(flow, mode='bilinear'):
132 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
133 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
134 |
135 |
136 | def bilinear_upflow(flow, scale_factor=8):
137 | assert flow.size(1) == 2
138 | flow = F.interpolate(flow, scale_factor=scale_factor,
139 | mode='bilinear', align_corners=True) * scale_factor
140 |
141 | return flow
142 |
143 |
144 | def upsample_flow(flow, img):
145 | if flow.size(-1) != img.size(-1):
146 | scale_factor = img.size(-1) / flow.size(-1)
147 | flow = F.interpolate(flow, size=img.size()[-2:],
148 | mode='bilinear', align_corners=True) * scale_factor
149 | return flow
150 |
151 |
152 | def count_parameters(model):
153 | num = sum(p.numel() for p in model.parameters() if p.requires_grad)
154 | return num
155 |
156 |
157 | def set_bn_eval(m):
158 | classname = m.__class__.__name__
159 | if classname.find('BatchNorm') != -1:
160 | m.eval()
161 |
--------------------------------------------------------------------------------
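
A small usage sketch for `InputPadder` and `flow_warp` above (assumes the repo root is on `PYTHONPATH`). Padding makes arbitrary image sizes divisible by the padding factor, and warping with zero flow returns the input (up to interpolation error):

```python
import torch
from utils.utils import InputPadder, flow_warp

img1 = torch.randn(1, 3, 436, 1024)   # Sintel-sized frames
img2 = torch.randn(1, 3, 436, 1024)

padder = InputPadder(img1.shape, mode='sintel', padding_factor=8)
img1, img2 = padder.pad(img1, img2)
print(img1.shape)                     # torch.Size([1, 3, 440, 1024])

flow = torch.zeros(1, 2, 440, 1024)
warped = flow_warp(img2, flow)        # zero flow -> (near-)identity warp
print(torch.allclose(warped, img2, atol=1e-5))   # True
print(padder.unpad(warped).shape)     # torch.Size([1, 3, 436, 1024])
```
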