├── LICENSE ├── README.md ├── chairs_split.txt ├── opticalflow ├── __init__.py ├── api │ ├── __init__.py │ ├── data_augment.py │ ├── evaluate.py │ ├── inference.py │ ├── init_model.py │ ├── manage_data.py │ ├── postprocess.py │ └── preprocess.py ├── core │ ├── __init__.py │ ├── data_aug │ │ ├── __init__.py │ │ └── distortion.py │ ├── dataset │ │ ├── DatasetManagerBase.py │ │ ├── KITTIManager.py │ │ └── __init__.py │ └── model │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── csflow.py │ │ ├── external │ │ ├── csflow.py │ │ ├── panoflow_csflow.py │ │ ├── panoflow_raft.py │ │ └── raft.py │ │ ├── panoflow_csflow.py │ │ ├── panoflow_raft.py │ │ └── raft.py ├── dataset │ ├── __init__.py │ ├── base_flow.py │ ├── flow360.py │ ├── flying_chairs.py │ ├── flying_things.py │ └── omni.py └── utils │ ├── __init__.py │ ├── augmentor.py │ ├── flow_utils.py │ ├── logger.py │ └── utils.py ├── results ├── Flow360.png ├── compare_.png ├── compare_quant.png └── panoshow.png ├── setup.cfg ├── setup.py └── tools ├── eval.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Hao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PanoFlow: Learning Optical Flow for Panoramic Images 2 | The implementations of [PanoFlow: Learning Optical Flow for Panoramic Images](https://arxiv.org/pdf/2202.13388.pdf). 3 | We achieve state-of-the-art accuracy on the public OmniFlowNet dataset and the proposed FlowScape (Flow360) dataset. 4 | This repository is built on the basis of [CSFlow](https://github.com/MasterHow/CSFlow). 5 | 6 | ![](results/panoshow.png) 7 | ![](results/compare_quant.png) 8 | ![](results/compare_.png) 9 | 10 | # FlowScape (Flow360) Dataset 11 | ![](results/Flow360.png) 12 | From left to right: overlapping image pairs, optical flow, and semantics. 13 | FlowScape (Flow360) dataset consists of 8 various city maps in four weathers: 14 | sunny, fog, cloud, and rain. 15 | We collect 100 consecutive panoramic images at each random position, 16 | resulting in a total of 6,400 frames with a resolution of 1024 x 512 , 17 | each with optical flow ground truth and semantic labels, 18 | which can be used for training and evaluation. 
In the current release, 19 | we provide optical flow ground truth in the classic format (i.e. the traditional flow). 20 | If you need 360° flow ground truth, you can simply convert it by referring to the 21 | [paper](https://arxiv.org/pdf/2202.13388.pdf). 22 | 23 | Since the flow field of panoramic images usually contains large displacements that interfere with visualization and fade colors, 24 | we modified the optical flow visualization method 25 | and lowered the color saturation of flow values greater than the threshold. 26 | 27 | ``` 28 | better_flow_to_image(flow, alpha=0.1, max_flow=25) 29 | ``` 30 | The function can be found in flow_utils.py. In our paper, we set alpha=0.1 and max_flow=25. 31 | 32 | The valid mask excludes pixels whose semantics are sky. 33 | 34 | The semantic labels are as follows: 35 | ``` 36 | camvid_colors = OrderedDict([ 37 | ("Unlabeled", np.array([0, 0, 0], dtype=np.uint8)), 38 | ("Building", np.array([70, 70, 70], dtype=np.uint8)), 39 | ("Fence", np.array([100, 40, 40], dtype=np.uint8)), 40 | ("Other", np.array([55, 90, 80], dtype=np.uint8)), 41 | ("Pedestrian", np.array([220, 20, 60], dtype=np.uint8)), 42 | ("Pole", np.array([153, 153, 153], dtype=np.uint8)), 43 | ("RoadLine", np.array([157, 234, 50], dtype=np.uint8)), 44 | ("Road", np.array([128, 64, 128], dtype=np.uint8)), 45 | ("SideWalk", np.array([244, 35, 232], dtype=np.uint8)), 46 | ("Vegetation", np.array([107, 142, 35], dtype=np.uint8)), 47 | ("Vehicles", np.array([0, 0, 142], dtype=np.uint8)), 48 | ("Wall", np.array([102, 102, 156], dtype=np.uint8)), 49 | ("TrafficSign", np.array([220, 220, 0], dtype=np.uint8)), 50 | ("Sky", np.array([70, 130, 180], dtype=np.uint8)), 51 | ("Ground", np.array([81, 0, 81], dtype=np.uint8)), 52 | ("Bridge", np.array([150, 100, 100], dtype=np.uint8)), 53 | ("RailTrack", np.array([230, 150, 140], dtype=np.uint8)), 54 | ("GroundRail", np.array([180, 165, 180], dtype=np.uint8)), 55 | ("TrafficLight", np.array([250, 170, 30], dtype=np.uint8)), 56 | ("Static", np.array([110, 190, 160], dtype=np.uint8)), 57 | ("Dynamic", np.array([170, 120, 50], dtype=np.uint8)), 58 | ("Water", np.array([45, 60, 150], dtype=np.uint8)), 59 | ("Terrain", np.array([145, 170, 100], dtype=np.uint8)), 60 | ]) 61 | ``` 62 | 63 | You can download our FlowScape (Flow360) dataset via the following links. 64 | 65 | Download link 1 [Tencent WeiYun](https://share.weiyun.com/SoXICYgh) 66 | 67 | Download link 2 [Baidu Cloud](https://pan.baidu.com/s/1ZjY6J-zN5Wb7JxRMeHvQSw?pwd=7u2v) 68 | 69 | Download link 3 [Google Drive](https://drive.google.com/file/d/1cKJZBRprwS6fu6Nf4U0eU6lkqB88tW_v/view?usp=sharing) 70 | 71 | The content of the above links is identical; if you encounter network problems, try switching to another link. 72 | 73 | # Install 74 | ``` 75 | python setup.py develop 76 | ``` 77 | 78 | # Pretrained Model 79 | The pretrained models used in the paper can be found here: 80 | 81 | Download link 1 [Tencent WeiYun](https://share.weiyun.com/SIpeQTNE) 82 | 83 | Download link 2 [Baidu Cloud](https://pan.baidu.com/s/10pmFoK8_Tc2y4790mQyBfA?pwd=FLOW) 84 | 85 | Download link 3 [Google Drive](https://drive.google.com/drive/folders/1Li3PpkjmxYWL4tdkY_tR8wOfOC9YmJE2?usp=sharing) 86 | 87 | Notice that the checkpoints do not include the CFE; 88 | since CFE is an estimation method, 89 | you only need to enable it at inference time to obtain the 360° flow. 
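As a usage note for the FlowScape annotations described above, here is a minimal sketch of how the sky-excluding valid mask can be built from the semantic color table (the helper name and file handling are our own assumptions, not part of the released API):
```
import cv2
import numpy as np

def make_valid_mask(semantic_path):
    # Mark every pixel whose semantic color is not "Sky" ([70, 130, 180]).
    sem = cv2.imread(semantic_path)[:, :, ::-1]         # BGR -> RGB
    sky_color = np.array([70, 130, 180], dtype=np.uint8)
    valid = ~np.all(sem == sky_color, axis=-1)           # True where not sky
    return valid.astype(np.float32)                      # 1.0 = valid, 0.0 = sky
```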
90 | 91 | # Train and Eval 92 | To train, use the following command format: 93 | ``` 94 | python ./tools/train.py 95 | --model PanoFlow(CSFlow) 96 | --dataset Flow360 97 | --data_root $YOUR_DATA_PATH$ 98 | --batch_size 6 99 | --name PanoFlow(CSFlow)-test 100 | --validation Chairs 101 | --val_Chairs_root $YOUR_DATA_PATH$ 102 | --num_steps 100 103 | --lr 0.000125 104 | --image_size 400 720 105 | --wdecay 0.0001 106 | ``` 107 | To evaluate, use the following command format: 108 | ``` 109 | python ./tools/eval.py 110 | --model PanoFlow(CSFlow) 111 | --restore_ckpt ./checkpoints/PanoFlow(CSFlow)-wo-CFE.pth 112 | --CFE 113 | --validation Flow360 114 | --val_Flow360_root $YOUR_DATA_PATH$ 115 | ``` 116 | For more details, please check the code or refer to our [paper](https://arxiv.org/pdf/2202.13388.pdf). 117 | 118 | # Folder Hierarchy 119 | \* local: create these folders in your local repository; they will not be uploaded to the remote repository. 120 | ``` 121 | ├── data (local) # Store test/training data 122 | ├── checkpoints (local) # Store the checkpoints 123 | ├── runs (local) # Store the training logs 124 | ├── opticalflow # All source code 125 | | ├─ api # Called by tools 126 | | ├─ core # Core code called by other directories. Provides datasets, models, etc. 127 | | | ├─ dataset # I/O of each dataset 128 | | | ├─ model # Models, including all modules derived from nn.Module 129 | | | ├─ util # Utility functions 130 | ├── tools # Scripts for testing and training 131 | ├── work_dirs (local) # For developers to save their own code and assets 132 | ``` 133 | 134 | # Citation 135 | If you find our project helpful in your research, please cite: 136 | ``` 137 | @article{shi2022panoflow, 138 | title={PanoFlow: Learning optical flow for panoramic images}, 139 | author={Shi, Hao and Zhou, Yifan and Yang, Kailun and Ye, Yaozu and Yin, Xiaoting and Yin, Zhe and Meng, Shi and Wang, Kaiwei}, 140 | journal={arXiv preprint arXiv:2202.13388}, 141 | year={2022} 142 | } 143 | ``` 144 | 145 | # Devs 146 | Hao Shi, YiFan Zhou 147 | 148 | # Need Help? 149 | If you have any questions, feel free to e-mail me at haoshi@zju.edu.cn, and I will try my best to help you. 
=) -------------------------------------------------------------------------------- /opticalflow/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | importlib.import_module('opticalflow.api') 4 | importlib.import_module('opticalflow.core') 5 | importlib.import_module('opticalflow.utils') 6 | -------------------------------------------------------------------------------- /opticalflow/api/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_augment import distort_flow, distort_img 2 | from .evaluate import init_evaluator 3 | from .inference import inference 4 | from .init_model import init_model 5 | from .manage_data import create_dataloader, load_data, output_data 6 | from .postprocess import postprocess_data 7 | from .preprocess import preprocess_data 8 | 9 | __all__ = [ 10 | 'load_data', 'output_data', 'init_model', 'preprocess_data', 11 | 'create_dataloader', 'inference', 'postprocess_data', 'init_evaluator', 12 | 'distort_img', 'distort_flow' 13 | ] 14 | -------------------------------------------------------------------------------- /opticalflow/api/data_augment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from opticalflow.core.data_aug import RadialDistortion 4 | 5 | 6 | def __build_distortion(method: str, **kwargs): 7 | if method == 'radial': 8 | if not kwargs: 9 | kwargs = {'ks': [0, 1e-5, 0, 1e-14, 0, 1e-15]} 10 | distortion = RadialDistortion(**kwargs) 11 | else: 12 | raise NotImplementedError(f'Unknown method: {method}') 13 | return distortion 14 | 15 | 16 | def distort_img(tensor: np.ndarray, 17 | method: str = 'radial', 18 | resolution=None, 19 | nearest=None, 20 | inverse=False, 21 | **kwargs) -> np.ndarray: 22 | distortion = __build_distortion(method, **kwargs) 23 | return distortion.distort_img(tensor, output_resolution=resolution, nearest_inter=nearest, inverse=inverse) 24 | 25 | 26 | def distort_flow(tensor: np.ndarray, 27 | method: str = 'radial', 28 | resolution=None, 29 | nearest=None, 30 | **kwargs) -> np.ndarray: 31 | distortion = __build_distortion(method, **kwargs) 32 | return distortion.distort_flow(tensor, output_resolution=resolution, nearest_inter=nearest) 33 | -------------------------------------------------------------------------------- /opticalflow/api/evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import opticalflow.dataset as dataset 5 | from opticalflow.utils.utils import InputPadder 6 | from opticalflow.utils.flow_utils import convert_360_gt 7 | 8 | class Evaluator(): 9 | 10 | def __init__(self, data_size: int = None): 11 | self._loss_list = [] 12 | self._data_size = data_size 13 | self._count = 0 14 | self._data_completed_count = 0 15 | 16 | def record_result(self, y, y_gt): 17 | pass 18 | 19 | def record_loss(self, loss, current_batch_size): 20 | self._loss_list.append(loss) 21 | self._data_completed_count += current_batch_size 22 | 23 | def print_current_evaluation_result(self): 24 | self._count += 1 25 | loss = self._loss_list[-1] 26 | print(f'({self._count}) loss: {loss:>7f} ' 27 | f'[{self._data_completed_count:>5d}/{self._data_size:>5d}]') 28 | 29 | def print_all_evaluation_result(self): 30 | print('Looks good') 31 | 32 | 33 | def init_evaluator(data_size=1): 34 | return Evaluator(data_size) 35 | 36 | 37 | def sequence_loss(flow_preds, flow_gt, valid, 
gamma=0.8, max_flow=400): 38 | """Loss function defined over sequence of flow predictions, from RAFT.""" 39 | 40 | n_predictions = len(flow_preds) 41 | flow_loss = 0.0 42 | 43 | # exlude invalid pixels and extremely large diplacements 44 | mag = torch.sum(flow_gt**2, dim=1).sqrt() 45 | valid = (valid >= 0.5) & (mag < max_flow) 46 | 47 | for i in range(n_predictions): 48 | i_weight = gamma**(n_predictions - i - 1) 49 | i_loss = (flow_preds[i] - flow_gt).abs() 50 | flow_loss += i_weight * (valid[:, None] * i_loss).mean() 51 | 52 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt() 53 | epe = epe.view(-1)[valid.view(-1)] 54 | 55 | metrics = { 56 | 'epe': epe.mean().item(), 57 | '1px': (epe < 1).float().mean().item(), 58 | '3px': (epe < 3).float().mean().item(), 59 | '5px': (epe < 5).float().mean().item(), 60 | } 61 | 62 | return flow_loss, metrics 63 | 64 | 65 | class Not360Exception(Exception): 66 | def __init__(self): 67 | print("Not the 360° flow ground truth! Please check if your gt is converted or trun off the CFE.") 68 | 69 | 70 | @torch.no_grad() 71 | def validate_chairs(model, data_root, gpus=[0]): 72 | """Perform evaluation on the FlyingChairs (test) split, from RAFT, 73 | modified.""" 74 | model.eval() 75 | epe_list = [] 76 | 77 | val_dataset = dataset.FlyingChairs(split='validation', root=data_root) 78 | for val_id in range(len(val_dataset)): 79 | image1, image2, flow_gt, _ = val_dataset[val_id] 80 | image1 = image1[None].cuda(gpus[0]) 81 | image2 = image2[None].cuda(gpus[0]) 82 | 83 | # zip image 84 | image_pair = torch.stack((image1, image2)) 85 | 86 | _, flow_pr = model._model(image_pair, test_mode=True) 87 | epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt() 88 | epe_list.append(epe.view(-1).numpy()) 89 | 90 | epe = np.mean(np.concatenate(epe_list)) 91 | print('Validation Chairs EPE: %f' % epe) 92 | return {'chairs': epe} 93 | 94 | 95 | @torch.no_grad() 96 | def validate_omni(model, data_root, gpus=[0]): 97 | """Peform validation using the OmniDataset""" 98 | model.eval() 99 | results = {} 100 | epe_any = [] 101 | 102 | for dstype in ['CartoonTree', 'Forest', 'LowPolyModels']: 103 | val_dataset = dataset.OmniDataset(root=data_root, dstype=dstype) 104 | epe_list = [] 105 | 106 | for val_id in range(len(val_dataset)): 107 | image1, image2, flow_gt, _ = val_dataset[val_id] 108 | image1 = image1[None].cuda(gpus[0]) 109 | image2 = image2[None].cuda(gpus[0]) 110 | 111 | padder = InputPadder(image1.shape) 112 | image1, image2 = padder.pad(image1, image2) 113 | 114 | # zip image 115 | image_pair = torch.stack((image1, image2)) 116 | 117 | flow_low, flow_pr = model._model(image_pair, test_mode=True) 118 | flow = padder.unpad(flow_pr[0]).cpu() 119 | 120 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 121 | epe_list.append(epe.view(-1).numpy()) 122 | 123 | epe_any.append(epe_list) 124 | epe_all = np.concatenate(epe_list) 125 | epe = np.mean(epe_all) 126 | 127 | print('Validation Omni (%s) EPE: %f' % 128 | (dstype, epe)) 129 | results[dstype] = np.mean(epe_list) 130 | 131 | epe_final_all = np.concatenate(epe_any) 132 | epe_final = np.mean(epe_final_all) 133 | print('Validation Omni (all) EPE: %f' % 134 | (epe_final)) 135 | 136 | return results 137 | 138 | 139 | @torch.no_grad() 140 | def validate_omni_cfe(model, data_root, cvt_gt=False, gpus=[0]): 141 | """Peform validation using the OmniFlowNet Dataset, under PanoFlow Framework""" 142 | model.eval() 143 | results = {} 144 | epe_any = [] 145 | 146 | for dstype in ['CartoonTree', 'Forest', 'LowPolyModels']: 147 | 
val_dataset = dataset.OmniDataset(root=data_root, dstype=dstype) 148 | epe_list = [] 149 | 150 | for val_id in range(len(val_dataset)): 151 | image1, image2, flow_gt, _ = val_dataset[val_id] 152 | image1 = image1[None].cuda(gpus[0]) 153 | image2 = image2[None].cuda(gpus[0]) 154 | 155 | padder = InputPadder(image1.shape) 156 | image1, image2 = padder.pad(image1, image2) 157 | 158 | # convert gt to 360 flow 159 | if cvt_gt: 160 | flow_gt = convert_360_gt(flow_gt) 161 | 162 | # check if is 360 flow gt 163 | if flow_gt[0, :, :].max() > flow_gt.shape[2]//2: 164 | raise Not360Exception() 165 | 166 | # zip image 167 | image_pair = torch.stack((image1, image2)) 168 | 169 | # generate fmaps 170 | fmap1, fmap2, cnet1 = model._model(image_pair, test_mode=True, gen_fmap=True) 171 | 172 | # split fmaps # 173 | img_A1 = fmap1[:, :, :, 0:fmap1.shape[3] // 2] 174 | img_B1 = fmap1[:, :, :, fmap1.shape[3] // 2:] 175 | img_A2 = fmap2[:, :, :, 0:fmap2.shape[3] // 2] 176 | img_B2 = fmap2[:, :, :, fmap2.shape[3] // 2:] 177 | 178 | cnet_A1 = cnet1[:, :, :, 0:fmap1.shape[3] // 2] 179 | cnet_B1 = cnet1[:, :, :, fmap1.shape[3] // 2:] 180 | 181 | # prepare fmap pairs # 182 | img11 = torch.cat([img_B1, img_A1], dim=3) 183 | img21 = torch.cat([img_B2, img_A2], dim=3) 184 | cnet11 = torch.cat([cnet_B1, cnet_A1], dim=3) 185 | img_pair_B1A1 = torch.stack((img11, img21, cnet11)) 186 | 187 | img12 = torch.cat([img_A1, img_B1], dim=3) 188 | img22 = torch.cat([img_A2, img_B2], dim=3) 189 | cnet12 = torch.cat([cnet_A1, cnet_B1], dim=3) 190 | img_pair_A1B1 = torch.stack((img12, img22, cnet12)) 191 | 192 | # flow prediction # 193 | # skip encoder 194 | 195 | _, flow_pr_B1A1 = model._model(img_pair_B1A1, test_mode=True, skip_encode=True) 196 | 197 | _, flow_pr_A1B1 = model._model(img_pair_A1B1, test_mode=True, skip_encode=True) 198 | 199 | flow_pr_A1 = flow_pr_B1A1[:, :, :, flow_pr_B1A1.shape[3] // 2:] 200 | flow_pr_A2 = flow_pr_A1B1[:, :, :, 0:flow_pr_A1B1.shape[3] // 2] 201 | 202 | flow_pr_A = torch.minimum(flow_pr_A1, flow_pr_A2) 203 | 204 | flow_pr_B1 = flow_pr_B1A1[:, :, :, 0:flow_pr_B1A1.shape[3] // 2] 205 | flow_pr_B2 = flow_pr_A1B1[:, :, :, flow_pr_A1B1.shape[3] // 2:] 206 | 207 | flow_pr_B = torch.minimum(flow_pr_B1, flow_pr_B2) 208 | 209 | # all 210 | flow_pr = torch.cat([flow_pr_A, flow_pr_B], dim=3) 211 | flow_pr[:, :, :, flow_pr.shape[3] // 2] = flow_pr[:, :, :, (flow_pr.shape[3] // 2) + 1] 212 | flow_pr[:, :, :, (flow_pr.shape[3] // 2) - 1] = flow_pr[:, :, :, (flow_pr.shape[3] // 2) - 2] 213 | 214 | flow = padder.unpad(flow_pr[0]).cpu() 215 | 216 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 217 | epe_list.append(epe.view(-1).numpy()) 218 | 219 | epe_any.append(epe_list) 220 | epe_all = np.concatenate(epe_list) 221 | epe = np.mean(epe_all) 222 | print('Validation Omni (%s) EPE: %f' % 223 | (dstype, epe)) 224 | dstype = 'Flow360' + dstype 225 | results[dstype] = np.mean(epe_list) 226 | 227 | epe_final_all = np.concatenate(epe_any) 228 | epe_final = np.mean(epe_final_all) 229 | print('Validation Omni (all) EPE: %f' % 230 | (epe_final)) 231 | 232 | return results 233 | 234 | 235 | @torch.no_grad() 236 | def validate_flow360(model, data_root, gpus=[0]): 237 | """Peform validation using the Flow360 (test) split""" 238 | model.eval() 239 | results = {} 240 | epe_any = [] 241 | 242 | for dstype in ['cloud', 'fog', 'rain', 'sunny']: 243 | val_dataset = dataset.Flow360( 244 | split='test', root=data_root, dstype=dstype) 245 | epe_list = [] 246 | 247 | for val_id in range(len(val_dataset)): 248 | image1, image2, 
flow_gt, _ = val_dataset[val_id] 249 | image1 = image1[None].cuda(gpus[0]) 250 | image2 = image2[None].cuda(gpus[0]) 251 | 252 | padder = InputPadder(image1.shape) 253 | image1, image2 = padder.pad(image1, image2) 254 | 255 | # zip image 256 | image_pair = torch.stack((image1, image2)) 257 | 258 | flow_low, flow_pr = model._model(image_pair, test_mode=True) 259 | flow = padder.unpad(flow_pr[0]).cpu() 260 | 261 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 262 | epe_list.append(epe.view(-1).numpy()) 263 | 264 | epe_any.append(epe_list) 265 | epe_all = np.concatenate(epe_list) 266 | epe = np.mean(epe_all) 267 | 268 | print('Validation FLow360 (%s) EPE: %f' % 269 | (dstype, epe)) 270 | dstype = 'Flow360' + dstype 271 | results[dstype] = np.mean(epe_list) 272 | 273 | epe_final_all = np.concatenate(epe_any) 274 | epe_final = np.mean(epe_final_all) 275 | print('Validation FLow360 (all) EPE: %f' % 276 | (epe_final)) 277 | 278 | return results 279 | 280 | 281 | @torch.no_grad() 282 | def validate_flow360_cfe(model, data_root, cvt_gt=False, gpus=[0]): 283 | """Peform validation using the Flow360 (test) split, under PanoFlow Framework""" 284 | model.eval() 285 | results = {} 286 | epe_any = [] 287 | 288 | for dstype in ['cloud', 'fog', 'rain', 'sunny']: 289 | val_dataset = dataset.Flow360( 290 | split='test', root=data_root, dstype=dstype) 291 | epe_list = [] 292 | 293 | for val_id in range(len(val_dataset)): 294 | image1, image2, flow_gt, _ = val_dataset[val_id] 295 | image1 = image1[None].cuda(gpus[0]) 296 | image2 = image2[None].cuda(gpus[0]) 297 | 298 | padder = InputPadder(image1.shape) 299 | image1, image2 = padder.pad(image1, image2) 300 | 301 | # convert gt to 360 flow 302 | if cvt_gt: 303 | flow_gt = convert_360_gt(flow_gt) 304 | 305 | # check if is 360 flow gt 306 | if flow_gt[0, :, :].max() > flow_gt.shape[2]//2: 307 | raise Not360Exception() 308 | 309 | # zip image 310 | image_pair = torch.stack((image1, image2)) 311 | 312 | # generate fmaps 313 | fmap1, fmap2, cnet1 = model._model(image_pair, test_mode=True, gen_fmap=True) 314 | 315 | # split fmaps # 316 | img_A1 = fmap1[:, :, :, 0:fmap1.shape[3] // 2] 317 | img_B1 = fmap1[:, :, :, fmap1.shape[3] // 2:] 318 | img_A2 = fmap2[:, :, :, 0:fmap2.shape[3] // 2] 319 | img_B2 = fmap2[:, :, :, fmap2.shape[3] // 2:] 320 | 321 | cnet_A1 = cnet1[:, :, :, 0:fmap1.shape[3] // 2] 322 | cnet_B1 = cnet1[:, :, :, fmap1.shape[3] // 2:] 323 | 324 | # prepare fmap pairs # 325 | img11 = torch.cat([img_B1, img_A1], dim=3) 326 | img21 = torch.cat([img_B2, img_A2], dim=3) 327 | cnet11 = torch.cat([cnet_B1, cnet_A1], dim=3) 328 | img_pair_B1A1 = torch.stack((img11, img21, cnet11)) 329 | 330 | img12 = torch.cat([img_A1, img_B1], dim=3) 331 | img22 = torch.cat([img_A2, img_B2], dim=3) 332 | cnet12 = torch.cat([cnet_A1, cnet_B1], dim=3) 333 | img_pair_A1B1 = torch.stack((img12, img22, cnet12)) 334 | 335 | # flow prediction # 336 | # skip encoder 337 | 338 | _, flow_pr_B1A1 = model._model(img_pair_B1A1, test_mode=True, skip_encode=True) 339 | 340 | _, flow_pr_A1B1 = model._model(img_pair_A1B1, test_mode=True, skip_encode=True) 341 | 342 | flow_pr_A1 = flow_pr_B1A1[:, :, :, flow_pr_B1A1.shape[3] // 2:] 343 | flow_pr_A2 = flow_pr_A1B1[:, :, :, 0:flow_pr_A1B1.shape[3] // 2] 344 | 345 | flow_pr_A = torch.minimum(flow_pr_A1, flow_pr_A2) 346 | 347 | flow_pr_B1 = flow_pr_B1A1[:, :, :, 0:flow_pr_B1A1.shape[3] // 2] 348 | flow_pr_B2 = flow_pr_A1B1[:, :, :, flow_pr_A1B1.shape[3] // 2:] 349 | 350 | flow_pr_B = torch.minimum(flow_pr_B1, flow_pr_B2) 351 | 352 | # all 353 
| flow_pr = torch.cat([flow_pr_A, flow_pr_B], dim=3) 354 | flow_pr[:, :, :, flow_pr.shape[3] // 2] = flow_pr[:, :, :, (flow_pr.shape[3] // 2) + 1] 355 | flow_pr[:, :, :, (flow_pr.shape[3] // 2) - 1] = flow_pr[:, :, :, (flow_pr.shape[3] // 2) - 2] 356 | 357 | flow = padder.unpad(flow_pr[0]).cpu() 358 | 359 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 360 | epe_list.append(epe.view(-1).numpy()) 361 | 362 | epe_any.append(epe_list) 363 | epe_all = np.concatenate(epe_list) 364 | epe = np.mean(epe_all) 365 | print('Validation FLow360 (%s) EPE: %f' % 366 | (dstype, epe)) 367 | dstype = 'Flow360' + dstype 368 | results[dstype] = np.mean(epe_list) 369 | 370 | epe_final_all = np.concatenate(epe_any) 371 | epe_final = np.mean(epe_final_all) 372 | print('Validation FLow360 (all) EPE: %f' % 373 | (epe_final)) 374 | 375 | return results 376 | 377 | 378 | @torch.no_grad() 379 | def validate_flow360_cfe_double_estimate(model, data_root, gpus=[0]): 380 | """Peform validation using the Flow360 (test) split, under PanoFlow Framework, use double estimate setting""" 381 | model.eval() 382 | results = {} 383 | epe_any = [] 384 | 385 | for dstype in ['cloud', 'fog', 'rain', 'sunny']: 386 | val_dataset = dataset.Flow360( 387 | split='test', root=data_root, dstype=dstype) 388 | epe_list = [] 389 | 390 | for val_id in range(len(val_dataset)): 391 | image1, image2, flow_gt, _ = val_dataset[val_id] 392 | image1 = image1[None].cuda(gpus[0]) 393 | image2 = image2[None].cuda(gpus[0]) 394 | 395 | padder = InputPadder(image1.shape) 396 | image1, image2 = padder.pad(image1, image2) 397 | 398 | # check if is 360 flow gt 399 | if flow_gt[0, :, :].max() > flow_gt.shape[2]//2: 400 | raise Not360Exception() 401 | 402 | # split images # 403 | img_A1 = image1[:, :, :, 0:image1.shape[3] // 2] 404 | img_B1 = image1[:, :, :, image1.shape[3] // 2:] 405 | img_A2 = image2[:, :, :, 0:image2.shape[3] // 2] 406 | img_B2 = image2[:, :, :, image2.shape[3] // 2:] 407 | 408 | # prepare image pairs # 409 | img11 = torch.cat([img_B1, img_A1], dim=3) 410 | img21 = torch.cat([img_B2, img_A2], dim=3) 411 | img_pair_B1A1 = torch.stack((img11, img21)) 412 | 413 | img12 = torch.cat([img_A1, img_B1], dim=3) 414 | img22 = torch.cat([img_A2, img_B2], dim=3) 415 | img_pair_A1B1 = torch.stack((img12, img22)) 416 | 417 | # double estimate 418 | _, flow_pr_B1A1 = model._model(img_pair_B1A1, test_mode=True) 419 | 420 | _, flow_pr_A1B1 = model._model(img_pair_A1B1, test_mode=True) 421 | 422 | flow_pr_A1 = flow_pr_B1A1[:, :, :, flow_pr_B1A1.shape[3] // 2:] 423 | flow_pr_A2 = flow_pr_A1B1[:, :, :, 0:flow_pr_A1B1.shape[3] // 2] 424 | 425 | flow_pr_A = torch.minimum(flow_pr_A1, flow_pr_A2) 426 | 427 | flow_pr_B1 = flow_pr_B1A1[:, :, :, 0:flow_pr_B1A1.shape[3] // 2] 428 | flow_pr_B2 = flow_pr_A1B1[:, :, :, flow_pr_A1B1.shape[3] // 2:] 429 | 430 | flow_pr_B = torch.minimum(flow_pr_B1, flow_pr_B2) 431 | 432 | # all 433 | flow_pr = torch.cat([flow_pr_A, flow_pr_B], dim=3) 434 | flow_pr[:, :, :, flow_pr.shape[3] // 2] = flow_pr[:, :, :, (flow_pr.shape[3] // 2) + 1] 435 | flow_pr[:, :, :, (flow_pr.shape[3] // 2) - 1] = flow_pr[:, :, :, (flow_pr.shape[3] // 2) - 2] 436 | 437 | flow = padder.unpad(flow_pr[0]).cpu() 438 | 439 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 440 | epe_list.append(epe.view(-1).numpy()) 441 | 442 | epe_any.append(epe_list) 443 | epe_all = np.concatenate(epe_list) 444 | epe = np.mean(epe_all) 445 | print('Validation FLow360 (%s) EPE: %f' % 446 | (dstype, epe)) 447 | dstype = 'Flow360' + dstype 448 | results[dstype] = 
np.mean(epe_list) 449 | 450 | epe_final_all = np.concatenate(epe_any) 451 | epe_final = np.mean(epe_final_all) 452 | print('Validation FLow360 (all) EPE: %f' % 453 | (epe_final)) 454 | 455 | return results 456 | 457 | 458 | @torch.no_grad() 459 | def validate_flow360_cfe_same_padding(model, data_root, gpus=[0]): 460 | """Peform validation using the Flow360 (test) split, under PanoFlow Framework with same padding""" 461 | model.eval() 462 | results = {} 463 | epe_any = [] 464 | 465 | for dstype in ['cloud', 'fog', 'rain', 'sunny']: 466 | val_dataset = dataset.Flow360( 467 | split='test', root=data_root, dstype=dstype) 468 | epe_list = [] 469 | 470 | for val_id in range(len(val_dataset)): 471 | image1, image2, flow_gt, _ = val_dataset[val_id] 472 | image1 = image1[None].cuda(gpus[0]) 473 | image2 = image2[None].cuda(gpus[0]) 474 | 475 | padder = InputPadder(image1.shape) 476 | image1, image2 = padder.pad(image1, image2) 477 | 478 | # check if is 360 flow gt 479 | if flow_gt[0, :, :].max() > flow_gt.shape[2]//2: 480 | raise Not360Exception() 481 | 482 | # zip image 483 | image_pair = torch.stack((image1, image2)) 484 | 485 | # generate fmaps 486 | fmap1, fmap2, cnet1 = model._model(image_pair, test_mode=True, gen_fmap=True) 487 | 488 | # split fmaps # 489 | img_A1 = fmap1[:, :, :, 0:fmap1.shape[3] // 2] 490 | img_B1 = fmap1[:, :, :, fmap1.shape[3] // 2:] 491 | img_A2 = fmap2[:, :, :, 0:fmap2.shape[3] // 2] 492 | img_B2 = fmap2[:, :, :, fmap2.shape[3] // 2:] 493 | 494 | cnet_A1 = cnet1[:, :, :, 0:fmap1.shape[3] // 2] 495 | cnet_B1 = cnet1[:, :, :, fmap1.shape[3] // 2:] 496 | 497 | # prepare fmap pairs # 498 | # section A 499 | img11 = torch.cat([img_A1, img_A1], dim=3) 500 | img21 = torch.cat([img_B2, img_A2], dim=3) 501 | cnet11 = torch.cat([cnet_A1, cnet_A1], dim=3) 502 | img_pair_A1 = torch.stack((img11, img21, cnet11)) 503 | 504 | # section B 505 | img13 = torch.cat([img_B1, img_B1], dim=3) 506 | img23 = torch.cat([img_A2, img_B2], dim=3) 507 | cnet13 = torch.cat([img_B1, cnet_B1], dim=3) 508 | img_pair_B1 = torch.stack((img13, img23, cnet13)) 509 | 510 | # flow prediction # 511 | # skip encoder 512 | 513 | _, flow_pr_A = model._model(img_pair_A1, test_mode=True, skip_encode=True) 514 | 515 | _, flow_pr_B = model._model(img_pair_B1, test_mode=True, skip_encode=True) 516 | 517 | flow_pr_A1 = flow_pr_A[:, :, :, flow_pr_A.shape[3] // 2:] 518 | flow_pr_A2 = flow_pr_A[:, :, :, 0:flow_pr_A.shape[3] // 2] 519 | 520 | flow_pr_A = torch.minimum(flow_pr_A1, flow_pr_A2) 521 | 522 | flow_pr_B1 = flow_pr_B[:, :, :, flow_pr_B.shape[3] // 2:] 523 | flow_pr_B2 = flow_pr_B[:, :, :, 0:flow_pr_B.shape[3] // 2] 524 | 525 | flow_pr_B = torch.minimum(flow_pr_B1, flow_pr_B2) 526 | 527 | # all 528 | flow_pr = torch.cat([flow_pr_A, flow_pr_B], dim=3) 529 | flow_pr[:, :, :, flow_pr.shape[3] // 2] = flow_pr[:, :, :, (flow_pr.shape[3] // 2) + 1] 530 | flow_pr[:, :, :, (flow_pr.shape[3] // 2) - 1] = flow_pr[:, :, :, (flow_pr.shape[3] // 2) - 2] 531 | 532 | flow = padder.unpad(flow_pr[0]).cpu() 533 | 534 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 535 | epe_list.append(epe.view(-1).numpy()) 536 | 537 | epe_any.append(epe_list) 538 | epe_all = np.concatenate(epe_list) 539 | epe = np.mean(epe_all) 540 | print('Validation FLow360 (%s) EPE: %f' % 541 | (dstype, epe)) 542 | dstype = 'Flow360' + dstype 543 | results[dstype] = np.mean(epe_list) 544 | 545 | epe_final_all = np.concatenate(epe_any) 546 | epe_final = np.mean(epe_final_all) 547 | print('Validation FLow360 (all) EPE: %f' % 548 | (epe_final)) 549 | 
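The CFE validators above share a common idea: the panorama is split into two halves A and B, flow is estimated for both orderings (B|A and A|B), the two estimates of each half are fused with an element-wise torch.minimum, and the halves are stitched back into a full panorama. A minimal usage sketch of the Flow360 CFE validator follows; it is not the official entry point (that is tools/eval.py with --CFE), and the `args` object and data path are placeholders:
```
# Sketch only: assumes `args` has already been parsed the way tools/eval.py
# does (model='PanoFlow(CSFlow)', train=False, restore_ckpt=..., gpus=[0], ...);
# the data root below is a placeholder.
from opticalflow.api import init_model
from opticalflow.api.evaluate import validate_flow360_cfe

model = init_model(args)              # model construction; checkpoint handling follows tools/eval.py
results = validate_flow360_cfe(
    model,
    data_root='./data/Flow360',       # placeholder path to the Flow360 test split
    cvt_gt=True,                      # wrap the classic gt into the 360° definition first
    gpus=[0])
print(results)                        # per-weather EPE, keyed as 'Flow360<dstype>'
```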
-------------------------------------------------------------------------------- /opticalflow/api/inference.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | from opticalflow.utils.utils import (fill_order_keys, fix_order_keys, 6 | fix_read_order_keys) 7 | 8 | 9 | def inference(model, x, args): 10 | if args.model == 'RAFT': 11 | if args.train: 12 | model = torch.nn.DataParallel(model._model) 13 | try: 14 | model.load_state_dict(torch.load(args.checkpoint)) 15 | except (Exception): 16 | d1 = torch.load(args.checkpoint) 17 | d2 = OrderedDict([(fix_order_keys(k, 6), v) 18 | for k, v in d1.items()]) 19 | model.load_state_dict(d2) 20 | model = model.module 21 | model.to(args.DEVICE) 22 | model.eval() 23 | with torch.no_grad(): 24 | return model(x, test_mode=True) 25 | else: # fix model load bug when test 26 | model = torch.nn.DataParallel(model._model) 27 | try: 28 | model.load_state_dict( 29 | torch.load(args.checkpoint, map_location='cpu')) 30 | except (Exception): 31 | try: 32 | d1 = torch.load(args.checkpoint, map_location='cpu') 33 | d2 = OrderedDict([(fix_order_keys(k, 6), v) 34 | for k, v in d1.items()]) 35 | model.load_state_dict(d2) 36 | except (Exception): 37 | d1 = torch.load(args.checkpoint, map_location='cpu') 38 | d2 = OrderedDict([(fill_order_keys( 39 | fix_read_order_keys(k, 7), 40 | fill_value='module.', 41 | fill_position=0), v) for k, v in d1.items()]) 42 | model.load_state_dict(d2) 43 | model = model.module 44 | model.to(args.DEVICE) 45 | model.eval() 46 | with torch.no_grad(): 47 | return model(x, test_mode=True) 48 | else: 49 | with torch.no_grad(): 50 | return model(x) 51 | -------------------------------------------------------------------------------- /opticalflow/api/init_model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | from opticalflow.core.model import csflow, raft, panoflow_csflow, panoflow_raft 6 | from opticalflow.utils.utils import fill_order_keys, fix_read_order_keys 7 | 8 | 9 | def count_parameters(model): 10 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 11 | 12 | 13 | def init_CSFlow(args): 14 | if args.train: 15 | if args.change_gpu: 16 | device = torch.device(('cuda:' + str(args.gpus[0]))) 17 | model = csflow.CSFlow(args) 18 | model.to(device) 19 | print('Parameter Count: %d' % count_parameters(model)) 20 | 21 | # read checkpoint 22 | if args.restore_ckpt is not None: 23 | try: 24 | model.load_state_dict(torch.load(args.restore_ckpt)) 25 | except (Exception): 26 | try: 27 | d1 = torch.load(args.restore_ckpt) 28 | d2 = OrderedDict([ 29 | (fill_order_keys(k, fill_value='_model.'), v) 30 | for k, v in d1.items() 31 | ]) 32 | model.load_state_dict(d2) 33 | except: 34 | try: 35 | d1 = torch.load(args.restore_ckpt) 36 | d2 = OrderedDict([ 37 | (fix_read_order_keys(k, start_value=7), v) 38 | for k, v in d1.items() 39 | ]) 40 | model.load_state_dict(d2) 41 | except: 42 | d1 = torch.load(args.restore_ckpt) 43 | d2 = OrderedDict([(fill_order_keys( 44 | fix_read_order_keys(k, start_value=7), 45 | fill_value='_model.', 46 | fill_position=0), v) for k, v in d1.items()]) 47 | model.load_state_dict(d2) 48 | 49 | pass 50 | 51 | model.to(device) 52 | model.train() 53 | 54 | if args.dataset != 'Chairs': 55 | # model.module.freeze_bn() 56 | for m in model.modules(): 57 | if isinstance(m, torch.nn.BatchNorm2d): 58 | m.eval() 59 | 60 | return 
model 61 | else: 62 | model = torch.nn.DataParallel( 63 | csflow.CSFlow(args), device_ids=args.gpus) 64 | print('Parameter Count: %d' % count_parameters(model)) 65 | 66 | # read checkpoint 67 | if args.restore_ckpt is not None: 68 | try: 69 | model.load_state_dict(torch.load(args.restore_ckpt)) 70 | except (Exception): 71 | try: 72 | d1 = torch.load(args.restore_ckpt) 73 | d2 = OrderedDict([ 74 | (fill_order_keys(k, fill_value='_model.'), v) 75 | for k, v in d1.items() 76 | ]) 77 | model.load_state_dict(d2) 78 | except (Exception): 79 | d1 = torch.load(args.restore_ckpt, map_location='cpu') 80 | d2 = OrderedDict([(fill_order_keys( 81 | k, fill_value='module.', fill_position=0), v) 82 | for k, v in d1.items()]) 83 | model.load_state_dict(d2) 84 | pass 85 | 86 | model.cuda() 87 | model.train() 88 | 89 | if args.dataset != 'Chairs': 90 | # model.module.freeze_bn() 91 | for m in model.modules(): 92 | if isinstance(m, torch.nn.BatchNorm2d): 93 | m.eval() 94 | 95 | return model 96 | else: 97 | return csflow.CSFlow(args) 98 | 99 | 100 | def init_RAFT(args): 101 | if args.train: 102 | if args.change_gpu: 103 | device = torch.device(('cuda:' + str(args.gpus[0]))) 104 | model = raft.RAFT(args) 105 | model.to(device) 106 | print('Parameter Count: %d' % count_parameters(model)) 107 | 108 | # read checkpoint 109 | if args.restore_ckpt is not None: 110 | try: 111 | model.load_state_dict(torch.load(args.restore_ckpt)) 112 | except (Exception): 113 | try: 114 | d1 = torch.load(args.restore_ckpt) 115 | d2 = OrderedDict([ 116 | (fill_order_keys(k, fill_value='_model.'), v) 117 | for k, v in d1.items() 118 | ]) 119 | model.load_state_dict(d2) 120 | except: 121 | try: 122 | d1 = torch.load(args.restore_ckpt) 123 | d2 = OrderedDict([ 124 | (fix_read_order_keys(k, start_value=7), v) 125 | for k, v in d1.items() 126 | ]) 127 | model.load_state_dict(d2) 128 | except: 129 | d1 = torch.load(args.restore_ckpt) 130 | d2 = OrderedDict([(fill_order_keys( 131 | fix_read_order_keys(k, start_value=7), 132 | fill_value='_model.', 133 | fill_position=0), v) for k, v in d1.items()]) 134 | model.load_state_dict(d2) 135 | 136 | pass 137 | 138 | model.to(device) 139 | model.train() 140 | 141 | if args.dataset != 'Chairs': 142 | # model.module.freeze_bn() 143 | for m in model.modules(): 144 | if isinstance(m, torch.nn.BatchNorm2d): 145 | m.eval() 146 | 147 | return model 148 | else: 149 | model = torch.nn.DataParallel( 150 | raft.RAFT(args), device_ids=args.gpus) 151 | print('Parameter Count: %d' % count_parameters(model)) 152 | 153 | # read checkpoint 154 | if args.restore_ckpt is not None: 155 | try: 156 | model.load_state_dict(torch.load(args.restore_ckpt)) 157 | except (Exception): 158 | try: 159 | d1 = torch.load(args.restore_ckpt) 160 | d2 = OrderedDict([ 161 | (fill_order_keys(k, fill_value='_model.'), v) 162 | for k, v in d1.items() 163 | ]) 164 | model.load_state_dict(d2) 165 | except (Exception): 166 | d1 = torch.load(args.restore_ckpt, map_location='cpu') 167 | d2 = OrderedDict([(fill_order_keys( 168 | k, fill_value='module.', fill_position=0), v) 169 | for k, v in d1.items()]) 170 | model.load_state_dict(d2) 171 | pass 172 | 173 | model.cuda() 174 | model.train() 175 | 176 | if args.dataset != 'Chairs': 177 | # model.module.freeze_bn() 178 | for m in model.modules(): 179 | if isinstance(m, torch.nn.BatchNorm2d): 180 | m.eval() 181 | 182 | return model 183 | else: 184 | return raft.RAFT(args) 185 | 186 | 187 | def init_PanoCSFlow(args): 188 | if args.train: 189 | if args.change_gpu: 190 | device = 
torch.device(('cuda:' + str(args.gpus[0]))) 191 | model = panoflow_csflow.PanoCSFlow(args) 192 | model.to(device) 193 | print('Parameter Count: %d' % count_parameters(model)) 194 | 195 | # read checkpoint 196 | if args.restore_ckpt is not None: 197 | try: 198 | model.load_state_dict(torch.load(args.restore_ckpt)) 199 | except (Exception): 200 | try: 201 | d1 = torch.load(args.restore_ckpt) 202 | d2 = OrderedDict([ 203 | (fill_order_keys(k, fill_value='_model.'), v) 204 | for k, v in d1.items() 205 | ]) 206 | model.load_state_dict(d2) 207 | except: 208 | try: 209 | d1 = torch.load(args.restore_ckpt) 210 | d2 = OrderedDict([ 211 | (fix_read_order_keys(k, start_value=7), v) 212 | for k, v in d1.items() 213 | ]) 214 | model.load_state_dict(d2) 215 | except: 216 | d1 = torch.load(args.restore_ckpt) 217 | d2 = OrderedDict([(fill_order_keys( 218 | fix_read_order_keys(k, start_value=7), 219 | fill_value='_model.', 220 | fill_position=0), v) for k, v in d1.items()]) 221 | model.load_state_dict(d2) 222 | 223 | pass 224 | 225 | model.to(device) 226 | model.train() 227 | 228 | if args.dataset != 'Chairs': 229 | # model.module.freeze_bn() 230 | for m in model.modules(): 231 | if isinstance(m, torch.nn.BatchNorm2d): 232 | m.eval() 233 | 234 | return model 235 | else: 236 | model = torch.nn.DataParallel( 237 | panoflow_csflow.PanoCSFlow(args), device_ids=args.gpus) 238 | print('Parameter Count: %d' % count_parameters(model)) 239 | 240 | # read checkpoint 241 | if args.restore_ckpt is not None: 242 | try: 243 | model.load_state_dict(torch.load(args.restore_ckpt)) 244 | except (Exception): 245 | try: 246 | d1 = torch.load(args.restore_ckpt) 247 | d2 = OrderedDict([ 248 | (fill_order_keys(k, fill_value='_model.'), v) 249 | for k, v in d1.items() 250 | ]) 251 | model.load_state_dict(d2) 252 | except (Exception): 253 | d1 = torch.load(args.restore_ckpt, map_location='cpu') 254 | d2 = OrderedDict([(fill_order_keys( 255 | k, fill_value='module.', fill_position=0), v) 256 | for k, v in d1.items()]) 257 | model.load_state_dict(d2) 258 | pass 259 | 260 | model.cuda() 261 | model.train() 262 | 263 | if args.dataset != 'Chairs': 264 | # model.module.freeze_bn() 265 | for m in model.modules(): 266 | if isinstance(m, torch.nn.BatchNorm2d): 267 | m.eval() 268 | 269 | return model 270 | else: 271 | return panoflow_csflow.PanoCSFlow(args) 272 | 273 | 274 | def init_PanoRAFT(args): 275 | if args.train: 276 | if args.change_gpu: 277 | device = torch.device(('cuda:' + str(args.gpus[0]))) 278 | model = panoflow_raft.PanoRAFT(args) 279 | model.to(device) 280 | print('Parameter Count: %d' % count_parameters(model)) 281 | 282 | # read checkpoint 283 | if args.restore_ckpt is not None: 284 | try: 285 | model.load_state_dict(torch.load(args.restore_ckpt)) 286 | except (Exception): 287 | try: 288 | d1 = torch.load(args.restore_ckpt) 289 | d2 = OrderedDict([ 290 | (fill_order_keys(k, fill_value='_model.'), v) 291 | for k, v in d1.items() 292 | ]) 293 | model.load_state_dict(d2) 294 | except: 295 | try: 296 | d1 = torch.load(args.restore_ckpt) 297 | d2 = OrderedDict([ 298 | (fix_read_order_keys(k, start_value=7), v) 299 | for k, v in d1.items() 300 | ]) 301 | model.load_state_dict(d2) 302 | except: 303 | d1 = torch.load(args.restore_ckpt) 304 | d2 = OrderedDict([(fill_order_keys( 305 | fix_read_order_keys(k, start_value=7), 306 | fill_value='_model.', 307 | fill_position=0), v) for k, v in d1.items()]) 308 | model.load_state_dict(d2) 309 | 310 | pass 311 | 312 | model.to(device) 313 | model.train() 314 | 315 | if args.dataset != 
'Chairs': 316 | # model.module.freeze_bn() 317 | for m in model.modules(): 318 | if isinstance(m, torch.nn.BatchNorm2d): 319 | m.eval() 320 | 321 | return model 322 | else: 323 | model = torch.nn.DataParallel( 324 | panoflow_raft.PanoRAFT(args), device_ids=args.gpus) 325 | print('Parameter Count: %d' % count_parameters(model)) 326 | 327 | # read checkpoint 328 | if args.restore_ckpt is not None: 329 | try: 330 | model.load_state_dict(torch.load(args.restore_ckpt)) 331 | except (Exception): 332 | try: 333 | d1 = torch.load(args.restore_ckpt) 334 | d2 = OrderedDict([ 335 | (fill_order_keys(k, fill_value='_model.'), v) 336 | for k, v in d1.items() 337 | ]) 338 | model.load_state_dict(d2) 339 | except (Exception): 340 | d1 = torch.load(args.restore_ckpt, map_location='cpu') 341 | d2 = OrderedDict([(fill_order_keys( 342 | k, fill_value='module.', fill_position=0), v) 343 | for k, v in d1.items()]) 344 | model.load_state_dict(d2) 345 | pass 346 | 347 | model.cuda() 348 | model.train() 349 | 350 | if args.dataset != 'Chairs': 351 | # model.module.freeze_bn() 352 | for m in model.modules(): 353 | if isinstance(m, torch.nn.BatchNorm2d): 354 | m.eval() 355 | 356 | return model 357 | else: 358 | return panoflow_raft.PanoRAFT(args) 359 | 360 | 361 | class NoModelException(Exception): 362 | def __init__(self): 363 | print("Not the supported model!") 364 | 365 | 366 | def init_model(args): 367 | if args.model == 'CSFlow': 368 | return init_CSFlow(args) 369 | elif args.model == 'RAFT': 370 | return init_RAFT(args) 371 | elif args.model == 'PanoFlow(CSFlow)': 372 | return init_PanoCSFlow(args) 373 | elif args.model == 'PanoFlow(RAFT)': 374 | return init_PanoRAFT(args) 375 | else: 376 | raise NoModelException() 377 | -------------------------------------------------------------------------------- /opticalflow/api/manage_data.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import cv2 4 | 5 | from opticalflow.core.dataset import KITTIDemoManager 6 | from opticalflow.dataset import (FlyingChairs, FlyingThings3D, Flow360) 7 | 8 | 9 | def load_data(args): 10 | return KITTIDemoManager.load_images(args.img_prefix, args) 11 | 12 | 13 | def create_dataloader(data, args): 14 | return KITTIDemoManager.create_dataloader(data, args) 15 | 16 | 17 | def output_data(imgs, output_dir): 18 | file_path = osp.join(output_dir, 'demo.jpg') 19 | cv2.imwrite(file_path, imgs) 20 | 21 | 22 | def fetch_training_data(args): 23 | """Create the data loader for the corresponding trainign set.""" 24 | 25 | if args.dataset == 'Chairs': 26 | if args.model == 'PanoFlow(CSFlow)' or args.model == 'PanoFlow(RAFT)': 27 | do_distort = True 28 | else: 29 | do_distort = False 30 | aug_params = { 31 | 'crop_size': args.image_size, 32 | 'min_scale': -0.1, 33 | 'max_scale': 1.0, 34 | 'do_flip': True, 35 | 'do_distort': do_distort 36 | } 37 | training_data = FlyingChairs( 38 | aug_params, split='training', root=args.data_root) 39 | 40 | elif args.dataset == 'Things': 41 | if args.model == 'PanoFlow(CSFlow)' or args.model == 'PanoFlow(RAFT)': 42 | do_distort = True 43 | else: 44 | do_distort = False 45 | aug_params = { 46 | 'crop_size': args.image_size, 47 | 'min_scale': -0.4, 48 | 'max_scale': 0.8, 49 | 'do_flip': True, 50 | 'do_distort': do_distort 51 | } 52 | clean_dataset = FlyingThings3D( 53 | aug_params, dstype='frames_cleanpass', root=args.data_root) 54 | final_dataset = FlyingThings3D( 55 | aug_params, dstype='frames_finalpass', root=args.data_root) 56 | training_data = 
clean_dataset + final_dataset 57 | 58 | elif args.dataset == 'Flow360': 59 | aug_params = { 60 | 'crop_size': args.image_size, 61 | 'min_scale': -0.2, 62 | 'max_scale': 0.6, 63 | 'do_flip': True, 64 | 'do_distort': False 65 | } 66 | sunny = Flow360( 67 | aug_params, 68 | split='train', 69 | root=args.train_Flow360_root, 70 | dstype='sunny') 71 | cloud = Flow360( 72 | aug_params, 73 | split='train', 74 | root=args.train_Flow360_root, 75 | dstype='cloud') 76 | rain = Flow360( 77 | aug_params, 78 | split='train', 79 | root=args.train_Flow360_root, 80 | dstype='rain') 81 | fog = Flow360( 82 | aug_params, 83 | split='train', 84 | root=args.train_Flow360_root, 85 | dstype='fog') 86 | training_data = sunny + cloud + rain + fog 87 | 88 | 89 | print('Training with %d image pairs' % len(training_data)) 90 | return training_data 91 | -------------------------------------------------------------------------------- /opticalflow/api/postprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from opticalflow.core.dataset import KITTIDemoManager 4 | 5 | 6 | def postprocess_data(input_imgs, result: torch.tensor): 7 | result = KITTIDemoManager.postprocess_data(result) 8 | return result 9 | -------------------------------------------------------------------------------- /opticalflow/api/preprocess.py: -------------------------------------------------------------------------------- 1 | from opticalflow.core.dataset import KITTIDemoManager 2 | 3 | 4 | def preprocess_data(np_raw_data, args): 5 | return KITTIDemoManager.preprocess_data(np_raw_data, args) 6 | -------------------------------------------------------------------------------- /opticalflow/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_aug import * # noqa: F401,F403 2 | from .dataset import * # noqa: F401,F403 3 | from .model import * # noqa: F401,F403 4 | -------------------------------------------------------------------------------- /opticalflow/core/data_aug/__init__.py: -------------------------------------------------------------------------------- 1 | from .distortion import Distortion, RadialDistortion 2 | 3 | __all__ = ['Distortion', 'RadialDistortion'] 4 | -------------------------------------------------------------------------------- /opticalflow/core/data_aug/distortion.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import Optional, Sequence 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | 8 | class Distortion(metaclass=ABCMeta): 9 | """Abstract class for distorction.""" 10 | 11 | def __init__(self): 12 | pass 13 | 14 | @abstractmethod 15 | def distort_pos_func(self, pos: np.ndarray, resolution: tuple, 16 | inverse: bool) -> np.ndarray: 17 | """pos (np.ndarray): The position of each point in grid. resolution 18 | (tuple): The input resolution(height, width) in tuple. inverse (bool): 19 | If inverse is `True`, the function will map the distorted position to 20 | calibrated position. This parameter is useful in image calibration. 21 | 22 | Returns: 23 | np.ndarray: The distorted position of each point in grid. 24 | """ 25 | pass 26 | 27 | def distort_img_pos(self, 28 | height: int, 29 | width: int, 30 | pos_arr: np.ndarray = None, 31 | inverse: bool = False) -> np.ndarray: 32 | """Given a array of position, return their distortion position. 
33 | 34 | By default, the pos_arr comprises the coordinate of all pixels of an 35 | image. 36 | """ 37 | if pos_arr is None: 38 | ij_grid = np.indices((height, width), dtype=np.float32) 39 | pos_arr = np.transpose(ij_grid, (1, 2, 0)) 40 | 41 | new_pos = self.distort_pos_func(pos_arr, (height, width), inverse) 42 | new_pos = new_pos.astype(np.float32) 43 | return new_pos 44 | 45 | def distort_img(self, 46 | img: np.ndarray, 47 | inverse: bool = False, 48 | output_resolution: Optional[tuple] = None, 49 | nearest_inter: bool = None) -> np.ndarray: 50 | assert img.ndim >= 2 51 | height, width = img.shape[0], img.shape[1] 52 | 53 | distort_pos = self.distort_img_pos(height, width, inverse=inverse) 54 | 55 | distort_pos = np.flip(distort_pos, 2) 56 | if nearest_inter: 57 | res = cv2.remap(img, distort_pos, None, cv2.INTER_NEAREST) 58 | else: 59 | res = cv2.remap(img, distort_pos, None, cv2.INTER_LINEAR) 60 | res = res.astype(img.dtype) 61 | if output_resolution: 62 | output_resolution = np.flip(output_resolution, 0) 63 | if nearest_inter: 64 | res = cv2.resize( 65 | res, output_resolution, interpolation=cv2.INTER_NEAREST) 66 | else: 67 | res = cv2.resize( 68 | res, output_resolution, interpolation=cv2.INTER_LINEAR) 69 | return res 70 | 71 | def distort_flow(self, 72 | flow: np.ndarray, 73 | inverse: bool = False, 74 | output_resolution: Optional[tuple] = None, 75 | nearest_inter: bool = None) -> np.ndarray: 76 | height, width = flow.shape[0], flow.shape[1] 77 | i_indices = np.arange(height, dtype=np.float32) 78 | j_indices = np.arange(width, dtype=np.float32) 79 | ji_grid = np.meshgrid(j_indices, i_indices) 80 | ji_grid = np.transpose(ji_grid, (1, 2, 0)) 81 | ji_grid -= flow 82 | dst_pos = np.flip(ji_grid, 2) 83 | 84 | distort_pos = self.distort_img_pos(height, width, inverse=inverse) 85 | 86 | distort_origin_pos = self.distort_img_pos( 87 | height, width, inverse=not inverse) 88 | distort_dst_pos = self.distort_img_pos( 89 | height, width, dst_pos, inverse=not inverse) 90 | flow_d = (distort_origin_pos - distort_dst_pos).astype(np.float32) 91 | distort_pos = np.flip(distort_pos, 2) 92 | if nearest_inter: 93 | res = cv2.remap(flow_d, distort_pos, None, cv2.INTER_NEAREST) 94 | else: 95 | res = cv2.remap(flow_d, distort_pos, None, cv2.INTER_LINEAR) 96 | res = res.astype(np.float32) 97 | res = np.flip(res, 2) 98 | 99 | if output_resolution: 100 | ratio = np.array(output_resolution) / np.array(flow.shape[0:2]) 101 | output_resolution = np.flip(output_resolution, 0) 102 | if nearest_inter: 103 | res = cv2.resize( 104 | res, output_resolution, interpolation=cv2.INTER_NEAREST) 105 | else: 106 | res = cv2.resize( 107 | res, output_resolution, interpolation=cv2.INTER_LINEAR) 108 | ratio = np.flip(ratio, 0) 109 | np.multiply(res, ratio, res) 110 | return res 111 | 112 | 113 | class RadialDistortion(Distortion): 114 | 115 | def __init__(self, ks: Sequence[float]): 116 | """ 117 | A sixtic polynominal L(r) = 1 + k_1r + k_2r^2 + k_4r^4 + k_6r^6 118 | 119 | Args: 120 | ks(Sequence[float]): The parameters of k_1, k_2, ... 121 | """ 122 | super().__init__() 123 | self.ks = ks 124 | 125 | def distort_func(self, r: np.ndarray): 126 | res = 1 127 | r_pow = 1 128 | for k in self.ks: 129 | r_pow *= r 130 | res += k * r_pow 131 | 132 | return res 133 | 134 | def distort_pos_func(self, pos: np.ndarray, resolution: tuple, 135 | inverse: bool) -> np.ndarray: 136 | """pos (np.ndarray): The position of each point in grid. resolution 137 | (tuple): The input resolution(height, width) in tuple. 
inverse (bool): 138 | If inverse is `True`, the function will map the distorted position to 139 | calibrated position. This parameter is useful in image calibration. 140 | 141 | Returns: 142 | np.ndarray: The distorted position of each point in grid. 143 | """ 144 | height, width = resolution 145 | 146 | center = np.array([height / 2, width / 2], dtype=pos.dtype) 147 | dis_ij = pos - center 148 | r = np.linalg.norm(dis_ij, axis=2) 149 | Lr: np.ndarray = self.distort_func(r) 150 | Lr = Lr.reshape(*Lr.shape, 1) 151 | if inverse: 152 | np.divide(dis_ij, Lr, dis_ij) 153 | else: 154 | np.multiply(dis_ij, Lr, dis_ij) 155 | np.add(dis_ij, center, dis_ij) 156 | 157 | return dis_ij 158 | -------------------------------------------------------------------------------- /opticalflow/core/dataset/DatasetManagerBase.py: -------------------------------------------------------------------------------- 1 | # The interface class for all datasets 2 | import abc 3 | 4 | import numpy as np 5 | from torch.utils.data import DataLoader 6 | 7 | 8 | class DatasetManagerBase(metaclass=abc.ABCMeta): 9 | 10 | def __init__(self) -> None: 11 | pass 12 | 13 | @abc.abstractclassmethod 14 | def load_images(root_dir: str, **kwargs): 15 | """Load Images from directory and return a numpy array.""" 16 | pass 17 | 18 | @abc.abstractclassmethod 19 | def _preprocess_image(cls, image: np.ndarray) -> np.ndarray: 20 | """Preprocess single image.""" 21 | pass 22 | 23 | @abc.abstractclassmethod 24 | def create_dataloader(cls, images: np.ndarray) -> DataLoader: 25 | pass 26 | 27 | @abc.abstractclassmethod 28 | def _postprocess_image(cls, image: np.ndarray) -> np.ndarray: 29 | """Postprocess single image.""" 30 | pass 31 | 32 | @abc.abstractclassmethod 33 | def _preprocess_flow(cls, image: np.ndarray) -> np.ndarray: 34 | """Preprocess single flow image.""" 35 | pass 36 | 37 | @abc.abstractclassmethod 38 | def preprocess_data(cls, data): 39 | """Preprocess I/O images.""" 40 | pass 41 | 42 | @abc.abstractclassmethod 43 | def postprocess_data(cls, data): 44 | """Postprocess output images.""" 45 | pass 46 | -------------------------------------------------------------------------------- /opticalflow/core/dataset/KITTIManager.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import cvbase as cvb 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import DataLoader, Dataset 6 | 7 | from .DatasetManagerBase import DatasetManagerBase 8 | 9 | 10 | class KITTIDataset(Dataset): 11 | 12 | def __init__(self, images: torch.tensor) -> None: 13 | super().__init__() 14 | self._images = images 15 | self._len = len(images) 16 | 17 | def __len__(self): 18 | return self._len 19 | 20 | def __getitem__(self, index) -> torch.tensor: 21 | return self._images[index] 22 | 23 | 24 | class KITTIDemoManager(DatasetManagerBase): 25 | 26 | IMG1_SUFFIX = '_10.png' 27 | IMG2_SUFFIX = '_11.png' 28 | FLOW_SUFFIX_KITTI = '_flow_10.png' 29 | FLOW_SUFFIX = '_10.flo' 30 | BATCH_SIZE = 1 31 | 32 | def __init__(self) -> None: 33 | super().__init__() 34 | 35 | @classmethod 36 | def load_images(cls, img_prefix: str, args): 37 | """Load Images from directory and return a numpy array.""" 38 | # load imgs 39 | img1 = cv2.imread(img_prefix + cls.IMG1_SUFFIX).astype(np.float32) 40 | img2 = cv2.imread(img_prefix + cls.IMG2_SUFFIX).astype(np.float32) 41 | 42 | # load flow 43 | if args.dataset == 'KITTI': 44 | flow = cv2.imread(img_prefix + cls.FLOW_SUFFIX_KITTI, 45 | cv2.IMREAD_ANYDEPTH 46 | | 
cv2.IMREAD_COLOR).astype(np.float32) 47 | else: 48 | flow = cvb.read_flow(img_prefix + cls.FLOW_SUFFIX) 49 | 50 | return (img1, img2), flow 51 | 52 | @classmethod 53 | def _preprocess_image(cls, image: np.ndarray, args) -> np.ndarray: 54 | """Preprocess single image.""" 55 | if args.model == 'RAFT': 56 | image = image.transpose(2, 0, 1) 57 | return image 58 | 59 | @classmethod 60 | def _preprocess_flow(cls, image: np.ndarray, args) -> np.ndarray: 61 | """Preprocess single flow image.""" 62 | if args.dataset == 'KITTI': 63 | image = image[:, :, ::-1].astype(np.float32) 64 | flow, _ = image[:, :, :2], image[:, :, 2] 65 | flow = (flow - 2**15) / 64.0 66 | # flow is a (u, v) 2-channel image, u, v \in [-512, 512] 67 | else: 68 | flow = image 69 | return flow 70 | 71 | @classmethod 72 | def create_dataloader(cls, images, args) -> DataLoader: 73 | (x1, x2), y = images 74 | x1 = torch.from_numpy(x1) 75 | x2 = torch.from_numpy(x2) 76 | y = torch.from_numpy(y) 77 | 78 | x1 = x1.to(args.DEVICE) 79 | x2 = x2.to(args.DEVICE) 80 | y = y.to(args.DEVICE) 81 | 82 | images = (x1, x2), y 83 | 84 | dataset = KITTIDataset([images]) 85 | dataloader = DataLoader(dataset, batch_size=cls.BATCH_SIZE) 86 | return dataloader 87 | 88 | @classmethod 89 | def preprocess_data(cls, data, args): 90 | """Preprocess I/O images.""" 91 | (img1, img2), gt = data 92 | img1 = cls._preprocess_image(img1, args) 93 | img2 = cls._preprocess_image(img2, args) 94 | gt = cls._preprocess_flow(gt, args) 95 | return (img1, img2), gt 96 | 97 | @classmethod 98 | def postprocess_data(cls, data): 99 | """Postprocess output images.""" 100 | data = data.to('cpu').numpy() 101 | return cls._postprocess_image(data) 102 | -------------------------------------------------------------------------------- /opticalflow/core/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .KITTIManager import KITTIDemoManager 2 | 3 | __all__ = ['KITTIDemoManager'] 4 | -------------------------------------------------------------------------------- /opticalflow/core/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .csflow import CSFlow 2 | from .raft import RAFT 3 | from .panoflow_csflow import PanoCSFlow 4 | from .panoflow_raft import PanoRAFT 5 | 6 | __all__ = ['CSFlow', 'RAFT', 'PanoCSFlow', 'PanoRAFT'] 7 | -------------------------------------------------------------------------------- /opticalflow/core/model/base_model.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import enum 3 | from typing import Any 4 | 5 | import torch.nn as nn 6 | 7 | 8 | class ModelMode(enum.Enum): 9 | TRAIN = 0 10 | TEST = 1 11 | 12 | 13 | class BaseModel(nn.Module, metaclass=abc.ABCMeta): 14 | 15 | def __init__(self, mode: ModelMode = ModelMode.TEST): 16 | super().__init__() 17 | self._mode = mode 18 | 19 | def set_mode(self, value: ModelMode): 20 | self._mode = value 21 | 22 | @abc.abstractmethod 23 | def _preprocess(self, x: Any): 24 | pass 25 | 26 | @abc.abstractmethod 27 | def _forward_test(self, x: Any): 28 | pass 29 | 30 | @abc.abstractmethod 31 | def _forward_train(self, x: Any): 32 | pass 33 | 34 | def forward(self, x: Any): 35 | x = self._preprocess(x) 36 | if self._mode == ModelMode.TEST: 37 | x = self._forward_test(x) 38 | elif self._mode == ModelMode.TRAIN: 39 | x = self._forward_train(x) 40 | return x 41 | -------------------------------------------------------------------------------- 
/opticalflow/core/model/csflow.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from .base_model import BaseModel, ModelMode 4 | from .external import csflow 5 | 6 | 7 | class CSFlow(BaseModel): 8 | 9 | def __init__(self, args, mode: ModelMode = ModelMode.TEST): 10 | super().__init__(mode=mode) 11 | self._model = csflow.CSFlow(args) 12 | 13 | def _preprocess(self, x: Any): 14 | if isinstance(x, (tuple, list)): 15 | x = x[0] 16 | return x 17 | 18 | def _forward_test(self, x: Any): 19 | self._model.eval() 20 | return self._model(x) 21 | 22 | def _forward_train(self, x: Any): 23 | return self._model(x) 24 | -------------------------------------------------------------------------------- /opticalflow/core/model/external/csflow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | try: 6 | autocast = torch.cuda.amp.autocast 7 | except (Exception): 8 | # dummy autocast for PyTorch < 1.6 9 | class autocast: 10 | 11 | def __init__(self, enabled): 12 | pass 13 | 14 | def __enter__(self): 15 | pass 16 | 17 | def __exit__(self, *args): 18 | pass 19 | 20 | 21 | class CSFlow(nn.Module): 22 | 23 | def __init__(self, args): 24 | super(CSFlow, self).__init__() 25 | self.args = args 26 | 27 | self.hidden_dim = hdim = 128 28 | self.context_dim = cdim = 128 29 | args.corr_levels = 4 30 | args.corr_radius = 4 31 | 32 | if 'dropout' not in self.args: 33 | self.args.dropout = 0 34 | 35 | if 'alternate_corr' not in self.args: 36 | self.args.alternate_corr = False 37 | 38 | if 'mixed_precision' not in self.args: 39 | self.args.mixed_precision = False 40 | 41 | # feature network, context network, and update block 42 | self.fnet = BasicEncoder( 43 | output_dim=256, norm_fn='instance', dropout=args.dropout) 44 | 45 | self.cnet = BasicEncoder( 46 | output_dim=hdim + cdim, 47 | norm_fn='batch', 48 | dropout=args.dropout) 49 | 50 | self.strip_corr_block_v2 = StripCrossCorrMap_v2( 51 | in_chan=256, out_chan=256) 52 | self.update_block = BasicUpdateBlock( 53 | self.args, hidden_dim=hdim) 54 | 55 | def freeze_bn(self): 56 | for m in self.modules(): 57 | if isinstance(m, nn.BatchNorm2d): 58 | m.eval() 59 | 60 | def initialize_flow(self, img, dataset, train_flag): 61 | """Flow is represented as difference between two coordinate grids. 
62 | 63 | flow = coords1 - coords0, Modified by Hao 64 | """ 65 | N, C, H, W = img.shape 66 | 67 | if dataset == 'KITTI' and not train_flag: 68 | coords0 = coords_grid(N, H // 8 + 1, W // 8 + 1, device=img.device) 69 | coords1 = coords_grid(N, H // 8 + 1, W // 8 + 1, device=img.device) 70 | elif dataset == 'Sintel' and not train_flag: 71 | coords0 = coords_grid(N, H // 8 + 1, W // 8, device=img.device) 72 | coords1 = coords_grid(N, H // 8 + 1, W // 8, device=img.device) 73 | else: 74 | coords0 = coords_grid(N, H // 8, W // 8, device=img.device) 75 | coords1 = coords_grid(N, H // 8, W // 8, device=img.device) 76 | 77 | # optical flow computed as difference: flow = coords1 - coords0 78 | return coords0, coords1 79 | 80 | def upsample_flow(self, flow, mask): 81 | """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex 82 | combination.""" 83 | N, _, H, W = flow.shape 84 | mask = mask.view(N, 1, 9, 8, 8, H, W) 85 | mask = torch.softmax(mask, dim=2) 86 | 87 | up_flow = F.unfold(8 * flow, [3, 3], padding=1) 88 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 89 | 90 | up_flow = torch.sum(mask * up_flow, dim=2) 91 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 92 | return up_flow.reshape(N, 2, 8 * H, 8 * W) 93 | 94 | def forward(self, images, flow_init=None, upsample=True, test_mode=False, gen_fmap=False, skip_encode=False): 95 | """Estimate optical flow between pair of frames.""" 96 | 97 | if not skip_encode: 98 | # Modified, take image pairs as input 99 | image1 = images[0] 100 | image2 = images[1] 101 | image1 = 2 * (image1 / 255.0) - 1.0 102 | image2 = 2 * (image2 / 255.0) - 1.0 103 | 104 | image1 = image1.contiguous() 105 | image2 = image2.contiguous() 106 | 107 | hdim = self.hidden_dim 108 | cdim = self.context_dim 109 | 110 | # run the feature network 111 | with autocast(enabled=self.args.mixed_precision): 112 | fmap1, fmap2 = self.fnet([image1, image2]) 113 | 114 | fmap1 = fmap1.float() 115 | fmap2 = fmap2.float() 116 | else: 117 | hdim = self.hidden_dim 118 | cdim = self.context_dim 119 | fmap1 = images[0] 120 | fmap2 = images[1] 121 | 122 | # run the context network 123 | with autocast(enabled=self.args.mixed_precision): 124 | 125 | if not skip_encode: 126 | cnet = self.cnet(image1) 127 | 128 | if test_mode: 129 | if gen_fmap: 130 | return fmap1, fmap2, cnet 131 | else: 132 | cnet = images[2] 133 | 134 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 135 | net = torch.tanh(net) 136 | inp = torch.relu(inp) 137 | 138 | strip_coor_map, strip_corr_map_w, strip_corr_map_h = self.strip_corr_block_v2( 139 | [fmap1, fmap2]) 140 | corr_fn = CorrBlock_v2( 141 | fmap1, fmap2, strip_coor_map, radius=self.args.corr_radius) 142 | 143 | if not skip_encode: 144 | coords0, coords1 = self.initialize_flow(image1, self.args.dataset, 145 | self.args.train) 146 | else: 147 | b, c, h, w = fmap1.shape 148 | image1 = torch.zeros(b, c, 8 * h, 8 * w).cuda() 149 | coords0, coords1 = self.initialize_flow(image1, self.args.dataset, 150 | self.args.train) 151 | 152 | if flow_init is not None: 153 | coords1 = coords1 + flow_init 154 | 155 | flow_predictions = [] 156 | 157 | # init flow with regression before GRU iters 158 | corr_w_act = torch.nn.functional.softmax( 159 | strip_corr_map_w, dim=3) # B H1 W1 1 W2 160 | corr_h_act = torch.nn.functional.softmax( 161 | strip_corr_map_h, dim=4) # B H1 W1 H2 1 162 | 163 | flo_v = corr_w_act.mul(strip_corr_map_w) # B H1 W1 1 W2 164 | flo_u = corr_h_act.mul(strip_corr_map_h) # B H1 W1 H2 1 165 | 166 | flow_v = torch.sum(flo_v, dim=4).squeeze(dim=3) # B H1 W1 167 | flow_u 
= torch.sum(flo_u, dim=3).squeeze(dim=3) # B H1 W1 168 | 169 | corr_init = torch.stack((flow_u, flow_v), dim=1) # B 2 H1 W1 170 | 171 | coords1 = coords1.detach() 172 | coords1 = coords1 + corr_init 173 | 174 | # add loss 175 | flow_up = upflow8(coords1 - coords0) 176 | flow_predictions.append(flow_up) 177 | 178 | if not test_mode: 179 | for itr in range(self.args.iters): 180 | coords1 = coords1.detach() 181 | corr = corr_fn(coords1) # index correlation volume 182 | 183 | flow = coords1 - coords0 184 | with autocast(enabled=self.args.mixed_precision): 185 | 186 | net, up_mask, delta_flow = self.update_block( 187 | net, inp, corr, flow) 188 | 189 | # F(t+1) = F(t) + \Delta(t) 190 | coords1 = coords1 + delta_flow 191 | 192 | # upsample predictions 193 | if up_mask is None: 194 | flow_up = upflow8(coords1 - coords0) 195 | else: 196 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 197 | 198 | flow_predictions.append(flow_up) 199 | else: 200 | iters = self.args.eval_iters 201 | 202 | for itr in range(iters): 203 | coords1 = coords1.detach() 204 | corr = corr_fn(coords1) # index correlation volume 205 | 206 | flow = coords1 - coords0 207 | with autocast(enabled=self.args.mixed_precision): 208 | 209 | net, up_mask, delta_flow = self.update_block( 210 | net, inp, corr, flow) 211 | 212 | # F(t+1) = F(t) + \Delta(t) 213 | coords1 = coords1 + delta_flow 214 | 215 | # upsample predictions 216 | if up_mask is None: 217 | flow_up = upflow8(coords1 - coords0) 218 | else: 219 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 220 | 221 | flow_predictions.append(flow_up) 222 | 223 | if test_mode: 224 | return coords1 - coords0, flow_up 225 | 226 | return flow_predictions 227 | 228 | 229 | class StripCrossCorrMap_v2(nn.Module): 230 | """Strip Cross Corr Augmentation Module by Hao, version2.0""" 231 | 232 | def __init__(self, in_chan=256, out_chan=256, *args, **kwargs): 233 | super(StripCrossCorrMap_v2, self).__init__() 234 | self.conv1_1 = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 235 | self.conv1_2 = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 236 | self.conv2_1 = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 237 | self.conv2_2 = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 238 | 239 | self.init_weight() 240 | 241 | def forward(self, x): 242 | fmap1, fmap2 = x 243 | 244 | # vertical query map 245 | fmap1_w = self.conv1_1(fmap1) # B, 64, H, W 246 | batchsize, c_middle, h, w = fmap1_w.size() 247 | fmap1_w = fmap1_w.view(batchsize, c_middle, -1) 248 | 249 | # horizontal query map 250 | fmap1_h = self.conv1_2(fmap1) # B, 64, H, W 251 | batchsize, c_middle, h, w = fmap1_h.size() 252 | fmap1_h = fmap1_h.view(batchsize, c_middle, -1) 253 | 254 | # vertical striping map 255 | fmap2_w = self.conv2_1(fmap2) # B, 64, H, W 256 | fmap2_w = F.avg_pool2d(fmap2_w, [h, 1]) 257 | fmap2_w = fmap2_w.view(batchsize, c_middle, -1).permute(0, 2, 1) 258 | 259 | # horizontal striping map 260 | fmap2_h = self.conv2_2(fmap2) # B, 64, H, W 261 | fmap2_h = F.avg_pool2d(fmap2_h, [1, w]) 262 | fmap2_h = fmap2_h.view(batchsize, c_middle, -1).permute(0, 2, 1) 263 | 264 | # cross strip corr map 265 | strip_corr_map_w = torch.bmm(fmap2_w, fmap1_w).\ 266 | view(batchsize, w, h, w, 1).permute(0, 2, 3, 4, 1) # B H1 W1 1 W2 267 | strip_corr_map_h = torch.bmm(fmap2_h, fmap1_h).\ 268 | view(batchsize, h, h, w, 1).permute(0, 2, 3, 1, 4) # B H1 W1 H2 1 269 | 270 | return (strip_corr_map_w + strip_corr_map_h).view( 271 | batchsize, h, w, 1, h, w), strip_corr_map_w, strip_corr_map_h 
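    # Shape summary of the maps returned above:
    #   strip_corr_map_w: B x H1 x W1 x 1 x W2 - each source pixel of fmap1
    #     correlated with the height-averaged (per-column) strips of fmap2.
    #   strip_corr_map_h: B x H1 x W1 x H2 x 1 - each source pixel of fmap1
    #     correlated with the width-averaged (per-row) strips of fmap2.
    # Their broadcast sum, reshaped to B x H1 x W1 x 1 x H2 x W2, is the strip
    # cross-correlation volume that CorrBlock_v2 later concatenates with the
    # all-pairs 4D correlation along dim=3.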
272 | 273 | def init_weight(self): 274 | for ly in self.children(): 275 | if isinstance(ly, nn.Conv2d): 276 | nn.init.kaiming_normal_(ly.weight, a=1) 277 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 278 | 279 | def get_params(self): 280 | wd_params, nowd_params = [], [] 281 | for name, module in self.named_modules(): 282 | if isinstance(module, (nn.Linear, nn.Conv2d)): 283 | wd_params.append(module.weight) 284 | if not module.bias is None: 285 | nowd_params.append(module.bias) 286 | elif isinstance(module, torch.nn.BatchNorm2d): 287 | nowd_params += list(module.parameters()) 288 | return wd_params, nowd_params 289 | 290 | 291 | class ConvBNReLU(nn.Module): 292 | """Conv with BN and ReLU, used for Strip Corr Module""" 293 | 294 | def __init__(self, 295 | in_chan, 296 | out_chan, 297 | ks=3, 298 | stride=1, 299 | padding=1, 300 | *args, 301 | **kwargs): 302 | super(ConvBNReLU, self).__init__() 303 | self.conv = nn.Conv2d( 304 | in_chan, 305 | out_chan, 306 | kernel_size=ks, 307 | stride=stride, 308 | padding=padding, 309 | bias=False) 310 | self.bn = torch.nn.BatchNorm2d(out_chan) 311 | self.relu = nn.ReLU(inplace=True) 312 | 313 | def forward(self, x): 314 | x = self.conv(x) 315 | x = self.bn(x) 316 | x = self.relu(x) 317 | return x 318 | 319 | def init_weight(self): 320 | for ly in self.children(): 321 | if isinstance(ly, nn.Conv2d): 322 | nn.init.kaiming_normal_(ly.weight, a=1) 323 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 324 | 325 | 326 | class BasicUpdateBlock(nn.Module): 327 | """Modified by Hao, support for CSFlow""" 328 | 329 | def __init__(self, args, hidden_dim=128, input_dim=128): 330 | super(BasicUpdateBlock, self).__init__() 331 | self.args = args 332 | self.encoder = BasicMotionEncoder_v2(args) 333 | self.gru = SepConvGRU( 334 | hidden_dim=hidden_dim, input_dim=128 + hidden_dim) 335 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 336 | 337 | self.mask = nn.Sequential( 338 | nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True), 339 | nn.Conv2d(256, 64 * 9, 1, padding=0)) 340 | 341 | def forward(self, net, inp, corr, flow, upsample=True): 342 | motion_features = self.encoder(flow, corr) 343 | inp = torch.cat([inp, motion_features], dim=1) 344 | 345 | net = self.gru(net, inp) 346 | delta_flow = self.flow_head(net) 347 | 348 | # scale mask to balence gradients 349 | mask = .25 * self.mask(net) 350 | return net, mask, delta_flow 351 | 352 | 353 | def pool2x(x): 354 | return F.avg_pool2d(x, 3, stride=2, padding=1) 355 | 356 | 357 | def interp(x, dest): 358 | interp_args = {'mode': 'bilinear', 'align_corners': True} 359 | return F.interpolate(x, dest.shape[2:], **interp_args) 360 | 361 | 362 | class BasicEncoder(nn.Module): 363 | 364 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 365 | from torch.nn.modules.utils import _pair 366 | super(BasicEncoder, self).__init__() 367 | self.norm_fn = norm_fn 368 | 369 | if self.norm_fn == 'group': 370 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 371 | 372 | elif self.norm_fn == 'batch': 373 | self.norm1 = nn.BatchNorm2d(64) 374 | 375 | elif self.norm_fn == 'instance': 376 | self.norm1 = nn.InstanceNorm2d(64) 377 | 378 | elif self.norm_fn == 'none': 379 | self.norm1 = nn.Sequential() 380 | 381 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 382 | self.relu1 = nn.ReLU(inplace=True) 383 | 384 | self.in_planes = 64 385 | self.layer1 = self._make_layer(64, stride=1) 386 | self.layer2 = self._make_layer(96, stride=2) 387 | self.layer3 = self._make_layer(128, 
stride=2) 388 | 389 | # output convolution 390 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 391 | 392 | self.dropout = None 393 | if dropout > 0: 394 | self.dropout = nn.Dropout2d(p=dropout) 395 | 396 | for m in self.modules(): 397 | if isinstance(m, nn.Conv2d): 398 | nn.init.kaiming_normal_( 399 | m.weight, mode='fan_out', nonlinearity='relu') 400 | elif isinstance(m, 401 | (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 402 | if m.weight is not None: 403 | nn.init.constant_(m.weight, 1) 404 | if m.bias is not None: 405 | nn.init.constant_(m.bias, 0) 406 | 407 | def _make_layer(self, dim, stride=1): 408 | layer1 = ResidualBlock( 409 | self.in_planes, dim, self.norm_fn, stride=stride) 410 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 411 | layers = (layer1, layer2) 412 | 413 | self.in_planes = dim 414 | return nn.Sequential(*layers) 415 | 416 | def forward(self, x): 417 | 418 | # if input is list, combine batch dimension 419 | is_list = isinstance(x, tuple) or isinstance(x, list) 420 | if is_list: 421 | batch_dim = x[0].shape[0] 422 | x = torch.cat(x, dim=0) 423 | 424 | x = self.conv1(x) 425 | 426 | x = self.norm1(x) 427 | x = self.relu1(x) 428 | 429 | x = self.layer1(x) 430 | x = self.layer2(x) 431 | x = self.layer3(x) 432 | 433 | x = self.conv2(x) 434 | 435 | if self.training and self.dropout is not None: 436 | x = self.dropout(x) 437 | 438 | if is_list: 439 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 440 | 441 | return x 442 | 443 | 444 | class CorrBlock_v2: 445 | """Corr Block, modified by Hao, concat SC with 4D corr""" 446 | 447 | def __init__(self, 448 | fmap1, 449 | fmap2, 450 | strip_coor_map=None, 451 | num_levels=4, 452 | radius=4): 453 | self.num_levels = num_levels 454 | self.radius = radius 455 | self.corr_pyramid = [] 456 | 457 | # all pairs correlation 458 | corr = CorrBlock_v2.corr(fmap1, fmap2) 459 | 460 | if strip_coor_map is not None: 461 | # strip correlation augmentation with concat 462 | corr = torch.cat((corr, strip_coor_map), dim=3) 463 | 464 | batch, h1, w1, dim, h2, w2 = corr.shape 465 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 466 | 467 | self.corr_pyramid.append(corr) 468 | for i in range(self.num_levels - 1): 469 | corr = F.avg_pool2d(corr, 2, stride=2) 470 | self.corr_pyramid.append(corr) 471 | 472 | def __call__(self, coords): 473 | r = self.radius 474 | coords = coords.permute(0, 2, 3, 1) 475 | batch, h1, w1, _ = coords.shape 476 | 477 | out_pyramid = [] 478 | for i in range(self.num_levels): 479 | corr = self.corr_pyramid[i] 480 | dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device) 481 | dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device) 482 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 483 | 484 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i 485 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 486 | coords_lvl = centroid_lvl + delta_lvl 487 | 488 | corr = bilinear_sampler(corr, coords_lvl) 489 | corr = corr.view(batch, h1, w1, -1) 490 | out_pyramid.append(corr) 491 | 492 | out = torch.cat(out_pyramid, dim=-1) 493 | return out.permute(0, 3, 1, 2).contiguous().float() 494 | 495 | @staticmethod 496 | def corr(fmap1, fmap2): 497 | batch, dim, ht, wd = fmap1.shape 498 | fmap1 = fmap1.view(batch, dim, ht * wd) 499 | fmap2 = fmap2.view(batch, dim, ht * wd) 500 | 501 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 502 | corr = corr.view(batch, ht, wd, 1, ht, wd) 503 | return corr / torch.sqrt(torch.tensor(dim).float()) 504 | 505 | 506 | def bilinear_sampler(img, 
coords, mode='bilinear', mask=False): 507 | """Wrapper for grid_sample, uses pixel coordinates.""" 508 | H, W = img.shape[-2:] 509 | xgrid, ygrid = coords.split([1, 1], dim=-1) 510 | xgrid = 2 * xgrid / (W - 1) - 1 511 | ygrid = 2 * ygrid / (H - 1) - 1 512 | 513 | grid = torch.cat([xgrid, ygrid], dim=-1) 514 | img = F.grid_sample(img, grid, align_corners=True) 515 | 516 | if mask: 517 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 518 | return img, mask.float() 519 | 520 | return img 521 | 522 | 523 | def coords_grid(batch, ht, wd, device): 524 | coords = torch.meshgrid( 525 | torch.arange(ht, device=device), torch.arange(wd, device=device)) 526 | coords = torch.stack(coords[::-1], dim=0).float() 527 | return coords[None].repeat(batch, 1, 1, 1) 528 | 529 | 530 | def upflow8(flow, mode='bilinear'): 531 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 532 | return 8 * F.interpolate( 533 | flow, size=new_size, mode=mode, align_corners=True) 534 | 535 | 536 | class ResidualBlock(nn.Module): 537 | 538 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 539 | super(ResidualBlock, self).__init__() 540 | 541 | self.conv1 = nn.Conv2d( 542 | in_planes, planes, kernel_size=3, padding=1, stride=stride) 543 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 544 | self.relu = nn.ReLU(inplace=True) 545 | 546 | num_groups = planes // 8 547 | 548 | if norm_fn == 'group': 549 | self.norm1 = nn.GroupNorm( 550 | num_groups=num_groups, num_channels=planes) 551 | self.norm2 = nn.GroupNorm( 552 | num_groups=num_groups, num_channels=planes) 553 | if not stride == 1: 554 | self.norm3 = nn.GroupNorm( 555 | num_groups=num_groups, num_channels=planes) 556 | 557 | elif norm_fn == 'batch': 558 | self.norm1 = nn.BatchNorm2d(planes) 559 | self.norm2 = nn.BatchNorm2d(planes) 560 | if not stride == 1: 561 | self.norm3 = nn.BatchNorm2d(planes) 562 | 563 | elif norm_fn == 'instance': 564 | self.norm1 = nn.InstanceNorm2d(planes) 565 | self.norm2 = nn.InstanceNorm2d(planes) 566 | if not stride == 1: 567 | self.norm3 = nn.InstanceNorm2d(planes) 568 | 569 | elif norm_fn == 'none': 570 | self.norm1 = nn.Sequential() 571 | self.norm2 = nn.Sequential() 572 | if not stride == 1: 573 | self.norm3 = nn.Sequential() 574 | 575 | if stride == 1: 576 | self.downsample = None 577 | 578 | else: 579 | self.downsample = nn.Sequential( 580 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), 581 | self.norm3) 582 | 583 | def forward(self, x): 584 | y = x 585 | y = self.relu(self.norm1(self.conv1(y))) 586 | y = self.relu(self.norm2(self.conv2(y))) 587 | 588 | if self.downsample is not None: 589 | x = self.downsample(x) 590 | 591 | return self.relu(x + y) 592 | 593 | 594 | class BottleneckBlock(nn.Module): 595 | 596 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 597 | super(BottleneckBlock, self).__init__() 598 | 599 | self.conv1 = nn.Conv2d( 600 | in_planes, planes // 4, kernel_size=1, padding=0) 601 | self.conv2 = nn.Conv2d( 602 | planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride) 603 | self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) 604 | self.relu = nn.ReLU(inplace=True) 605 | 606 | num_groups = planes // 8 607 | 608 | if norm_fn == 'group': 609 | self.norm1 = nn.GroupNorm( 610 | num_groups=num_groups, num_channels=planes // 4) 611 | self.norm2 = nn.GroupNorm( 612 | num_groups=num_groups, num_channels=planes // 4) 613 | self.norm3 = nn.GroupNorm( 614 | num_groups=num_groups, num_channels=planes) 615 | if not stride == 1: 
616 | self.norm4 = nn.GroupNorm( 617 | num_groups=num_groups, num_channels=planes) 618 | 619 | elif norm_fn == 'batch': 620 | self.norm1 = nn.BatchNorm2d(planes // 4) 621 | self.norm2 = nn.BatchNorm2d(planes // 4) 622 | self.norm3 = nn.BatchNorm2d(planes) 623 | if not stride == 1: 624 | self.norm4 = nn.BatchNorm2d(planes) 625 | 626 | elif norm_fn == 'instance': 627 | self.norm1 = nn.InstanceNorm2d(planes // 4) 628 | self.norm2 = nn.InstanceNorm2d(planes // 4) 629 | self.norm3 = nn.InstanceNorm2d(planes) 630 | if not stride == 1: 631 | self.norm4 = nn.InstanceNorm2d(planes) 632 | 633 | elif norm_fn == 'none': 634 | self.norm1 = nn.Sequential() 635 | self.norm2 = nn.Sequential() 636 | self.norm3 = nn.Sequential() 637 | if not stride == 1: 638 | self.norm4 = nn.Sequential() 639 | 640 | if stride == 1: 641 | self.downsample = None 642 | 643 | else: 644 | self.downsample = nn.Sequential( 645 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), 646 | self.norm4) 647 | 648 | def forward(self, x): 649 | y = x 650 | y = self.relu(self.norm1(self.conv1(y))) 651 | y = self.relu(self.norm2(self.conv2(y))) 652 | y = self.relu(self.norm3(self.conv3(y))) 653 | 654 | if self.downsample is not None: 655 | x = self.downsample(x) 656 | 657 | return self.relu(x + y) 658 | 659 | 660 | class BasicMotionEncoder_v2(nn.Module): 661 | """Get Motion Feature from CSFlow, by Hao""" 662 | 663 | def __init__(self, args): 664 | super(BasicMotionEncoder_v2, self).__init__() 665 | # double cor_plances due to concat aug 666 | cor_planes = 2 * (args.corr_levels * (2 * args.corr_radius + 1)**2) 667 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 668 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 669 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 670 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 671 | self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) 672 | 673 | def forward(self, flow, corr): 674 | cor = F.relu(self.convc1(corr)) 675 | cor = F.relu(self.convc2(cor)) 676 | flo = F.relu(self.convf1(flow)) 677 | flo = F.relu(self.convf2(flo)) 678 | 679 | cor_flo = torch.cat([cor, flo], dim=1) 680 | out = F.relu(self.conv(cor_flo)) 681 | return torch.cat([out, flow], dim=1) 682 | 683 | 684 | class SepConvGRU(nn.Module): 685 | 686 | def __init__(self, hidden_dim=128, input_dim=192 + 128): 687 | super(SepConvGRU, self).__init__() 688 | self.convz1 = nn.Conv2d( 689 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 690 | self.convr1 = nn.Conv2d( 691 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 692 | self.convq1 = nn.Conv2d( 693 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 694 | 695 | self.convz2 = nn.Conv2d( 696 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 697 | self.convr2 = nn.Conv2d( 698 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 699 | self.convq2 = nn.Conv2d( 700 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 701 | 702 | def forward(self, h, x): 703 | # horizontal 704 | hx = torch.cat([h, x], dim=1) 705 | z = torch.sigmoid(self.convz1(hx)) 706 | r = torch.sigmoid(self.convr1(hx)) 707 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) 708 | h = (1 - z) * h + z * q 709 | 710 | # vertical 711 | hx = torch.cat([h, x], dim=1) 712 | z = torch.sigmoid(self.convz2(hx)) 713 | r = torch.sigmoid(self.convr2(hx)) 714 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) 715 | h = (1 - z) * h + z * q 716 | 717 | return h 718 | 719 | 720 | class FlowHead(nn.Module): 721 | 722 | def __init__(self, 
input_dim=128, hidden_dim=256): 723 | super(FlowHead, self).__init__() 724 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 725 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 726 | self.relu = nn.ReLU(inplace=True) 727 | 728 | def forward(self, x): 729 | return self.conv2(self.relu(self.conv1(x))) 730 | 731 | 732 | class ConvGRU(nn.Module): 733 | 734 | def __init__(self, hidden_dim=128, input_dim=192 + 128): 735 | super(ConvGRU, self).__init__() 736 | self.convz = nn.Conv2d( 737 | hidden_dim + input_dim, hidden_dim, 3, padding=1) 738 | self.convr = nn.Conv2d( 739 | hidden_dim + input_dim, hidden_dim, 3, padding=1) 740 | self.convq = nn.Conv2d( 741 | hidden_dim + input_dim, hidden_dim, 3, padding=1) 742 | 743 | def forward(self, h, x): 744 | hx = torch.cat([h, x], dim=1) 745 | 746 | z = torch.sigmoid(self.convz(hx)) 747 | r = torch.sigmoid(self.convr(hx)) 748 | q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) 749 | 750 | h = (1 - z) * h + z * q 751 | return h 752 | -------------------------------------------------------------------------------- /opticalflow/core/model/external/panoflow_csflow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torchvision.ops import DeformConv2d 6 | 7 | try: 8 | autocast = torch.cuda.amp.autocast 9 | except (Exception): 10 | # dummy autocast for PyTorch < 1.6 11 | class autocast: 12 | 13 | def __init__(self, enabled): 14 | pass 15 | 16 | def __enter__(self): 17 | pass 18 | 19 | def __exit__(self, *args): 20 | pass 21 | 22 | 23 | class PanoCSFlow(nn.Module): 24 | 25 | def __init__(self, args): 26 | super(PanoCSFlow, self).__init__() 27 | self.args = args 28 | 29 | self.hidden_dim = hdim = 128 30 | self.context_dim = cdim = 128 31 | args.corr_levels = 4 32 | args.corr_radius = 4 33 | 34 | if 'dcn' not in self.args: 35 | self.args.dcn = True 36 | 37 | if 'dropout' not in self.args: 38 | self.args.dropout = 0 39 | 40 | if 'alternate_corr' not in self.args: 41 | self.args.alternate_corr = False 42 | 43 | if 'mixed_precision' not in self.args: 44 | self.args.mixed_precision = False 45 | 46 | # feature network, context network, and update block 47 | self.fnet = BasicEncoder( 48 | output_dim=256, norm_fn='instance', dropout=args.dropout, 49 | dcn=self.args.dcn) 50 | 51 | self.cnet = BasicEncoder( 52 | output_dim=hdim + cdim, 53 | norm_fn='batch', 54 | dropout=args.dropout, 55 | dcn=self.args.dcn) 56 | 57 | self.strip_corr_block_v2 = StripCrossCorrMap_v2( 58 | in_chan=256, out_chan=256) 59 | self.update_block = BasicUpdateBlock( 60 | self.args, hidden_dim=hdim) 61 | 62 | def freeze_bn(self): 63 | for m in self.modules(): 64 | if isinstance(m, nn.BatchNorm2d): 65 | m.eval() 66 | 67 | def initialize_flow(self, img, dataset, train_flag): 68 | """Flow is represented as difference between two coordinate grids. 
69 | 70 | flow = coords1 - coords0, Modified by Hao 71 | """ 72 | N, C, H, W = img.shape 73 | 74 | if dataset == 'KITTI' and not train_flag: 75 | coords0 = coords_grid(N, H // 8 + 1, W // 8 + 1, device=img.device) 76 | coords1 = coords_grid(N, H // 8 + 1, W // 8 + 1, device=img.device) 77 | elif dataset == 'Sintel' and not train_flag: 78 | coords0 = coords_grid(N, H // 8 + 1, W // 8, device=img.device) 79 | coords1 = coords_grid(N, H // 8 + 1, W // 8, device=img.device) 80 | else: 81 | coords0 = coords_grid(N, H // 8, W // 8, device=img.device) 82 | coords1 = coords_grid(N, H // 8, W // 8, device=img.device) 83 | 84 | # optical flow computed as difference: flow = coords1 - coords0 85 | return coords0, coords1 86 | 87 | def upsample_flow(self, flow, mask): 88 | """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex 89 | combination.""" 90 | N, _, H, W = flow.shape 91 | mask = mask.view(N, 1, 9, 8, 8, H, W) 92 | mask = torch.softmax(mask, dim=2) 93 | 94 | up_flow = F.unfold(8 * flow, [3, 3], padding=1) 95 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 96 | 97 | up_flow = torch.sum(mask * up_flow, dim=2) 98 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 99 | return up_flow.reshape(N, 2, 8 * H, 8 * W) 100 | 101 | def forward(self, images, flow_init=None, upsample=True, test_mode=False, gen_fmap=False, skip_encode=False): 102 | """Estimate optical flow between pair of frames.""" 103 | 104 | if not skip_encode: 105 | # Modified, take image pairs as input 106 | image1 = images[0] 107 | image2 = images[1] 108 | image1 = 2 * (image1 / 255.0) - 1.0 109 | image2 = 2 * (image2 / 255.0) - 1.0 110 | 111 | image1 = image1.contiguous() 112 | image2 = image2.contiguous() 113 | 114 | hdim = self.hidden_dim 115 | cdim = self.context_dim 116 | 117 | # run the feature network 118 | with autocast(enabled=self.args.mixed_precision): 119 | fmap1, fmap2 = self.fnet([image1, image2]) 120 | 121 | fmap1 = fmap1.float() 122 | fmap2 = fmap2.float() 123 | else: 124 | hdim = self.hidden_dim 125 | cdim = self.context_dim 126 | fmap1 = images[0] 127 | fmap2 = images[1] 128 | 129 | # run the context network 130 | with autocast(enabled=self.args.mixed_precision): 131 | 132 | if not skip_encode: 133 | cnet = self.cnet(image1) 134 | 135 | if test_mode: 136 | if gen_fmap: 137 | return fmap1, fmap2, cnet 138 | else: 139 | cnet = images[2] 140 | 141 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 142 | net = torch.tanh(net) 143 | inp = torch.relu(inp) 144 | 145 | strip_coor_map, strip_corr_map_w, strip_corr_map_h = self.strip_corr_block_v2( 146 | [fmap1, fmap2]) 147 | corr_fn = CorrBlock_v2( 148 | fmap1, fmap2, strip_coor_map, radius=self.args.corr_radius) 149 | 150 | if not skip_encode: 151 | coords0, coords1 = self.initialize_flow(image1, self.args.dataset, 152 | self.args.train) 153 | else: 154 | b, c, h, w = fmap1.shape 155 | image1 = torch.zeros(b, c, 8 * h, 8 * w).cuda() 156 | coords0, coords1 = self.initialize_flow(image1, self.args.dataset, 157 | self.args.train) 158 | 159 | if flow_init is not None: 160 | coords1 = coords1 + flow_init 161 | 162 | flow_predictions = [] 163 | 164 | # init flow with regression before GRU iters 165 | corr_w_act = torch.nn.functional.softmax( 166 | strip_corr_map_w, dim=3) # B H1 W1 1 W2 167 | corr_h_act = torch.nn.functional.softmax( 168 | strip_corr_map_h, dim=4) # B H1 W1 H2 1 169 | 170 | flo_v = corr_w_act.mul(strip_corr_map_w) # B H1 W1 1 W2 171 | flo_u = corr_h_act.mul(strip_corr_map_h) # B H1 W1 H2 1 172 | 173 | flow_v = torch.sum(flo_v, dim=4).squeeze(dim=3) # B H1 W1 174 | 
flow_u = torch.sum(flo_u, dim=3).squeeze(dim=3) # B H1 W1 175 | 176 | corr_init = torch.stack((flow_u, flow_v), dim=1) # B 2 H1 W1 177 | 178 | coords1 = coords1.detach() 179 | coords1 = coords1 + corr_init 180 | 181 | # add loss 182 | flow_up = upflow8(coords1 - coords0) 183 | flow_predictions.append(flow_up) 184 | 185 | if not test_mode: 186 | for itr in range(self.args.iters): 187 | coords1 = coords1.detach() 188 | corr = corr_fn(coords1) # index correlation volume 189 | 190 | flow = coords1 - coords0 191 | with autocast(enabled=self.args.mixed_precision): 192 | 193 | net, up_mask, delta_flow = self.update_block( 194 | net, inp, corr, flow) 195 | 196 | # F(t+1) = F(t) + \Delta(t) 197 | coords1 = coords1 + delta_flow 198 | 199 | # upsample predictions 200 | if up_mask is None: 201 | flow_up = upflow8(coords1 - coords0) 202 | else: 203 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 204 | 205 | flow_predictions.append(flow_up) 206 | else: 207 | iters = self.args.eval_iters 208 | 209 | for itr in range(iters): 210 | coords1 = coords1.detach() 211 | corr = corr_fn(coords1) # index correlation volume 212 | 213 | flow = coords1 - coords0 214 | with autocast(enabled=self.args.mixed_precision): 215 | 216 | net, up_mask, delta_flow = self.update_block( 217 | net, inp, corr, flow) 218 | 219 | # F(t+1) = F(t) + \Delta(t) 220 | coords1 = coords1 + delta_flow 221 | 222 | # upsample predictions 223 | if up_mask is None: 224 | flow_up = upflow8(coords1 - coords0) 225 | else: 226 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 227 | 228 | flow_predictions.append(flow_up) 229 | 230 | if test_mode: 231 | return coords1 - coords0, flow_up 232 | 233 | return flow_predictions 234 | 235 | 236 | class StripCrossCorrMap_v2(nn.Module): 237 | """Strip Cross Corr Augmentation Module by Hao, version2.0""" 238 | 239 | def __init__(self, in_chan=256, out_chan=256, *args, **kwargs): 240 | super(StripCrossCorrMap_v2, self).__init__() 241 | self.conv1_1 = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 242 | self.conv1_2 = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 243 | self.conv2_1 = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 244 | self.conv2_2 = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 245 | 246 | self.init_weight() 247 | 248 | def forward(self, x): 249 | fmap1, fmap2 = x 250 | 251 | # vertical query map 252 | fmap1_w = self.conv1_1(fmap1) # B, 64, H, W 253 | batchsize, c_middle, h, w = fmap1_w.size() 254 | fmap1_w = fmap1_w.view(batchsize, c_middle, -1) 255 | 256 | # horizontal query map 257 | fmap1_h = self.conv1_2(fmap1) # B, 64, H, W 258 | batchsize, c_middle, h, w = fmap1_h.size() 259 | fmap1_h = fmap1_h.view(batchsize, c_middle, -1) 260 | 261 | # vertical striping map 262 | fmap2_w = self.conv2_1(fmap2) # B, 64, H, W 263 | fmap2_w = F.avg_pool2d(fmap2_w, [h, 1]) 264 | fmap2_w = fmap2_w.view(batchsize, c_middle, -1).permute(0, 2, 1) 265 | 266 | # horizontal striping map 267 | fmap2_h = self.conv2_2(fmap2) # B, 64, H, W 268 | fmap2_h = F.avg_pool2d(fmap2_h, [1, w]) 269 | fmap2_h = fmap2_h.view(batchsize, c_middle, -1).permute(0, 2, 1) 270 | 271 | # cross strip corr map 272 | strip_corr_map_w = torch.bmm(fmap2_w, fmap1_w).\ 273 | view(batchsize, w, h, w, 1).permute(0, 2, 3, 4, 1) # B H1 W1 1 W2 274 | strip_corr_map_h = torch.bmm(fmap2_h, fmap1_h).\ 275 | view(batchsize, h, h, w, 1).permute(0, 2, 3, 1, 4) # B H1 W1 H2 1 276 | 277 | return (strip_corr_map_w + strip_corr_map_h).view( 278 | batchsize, h, w, 1, h, w), strip_corr_map_w, 
strip_corr_map_h 279 | 280 | def init_weight(self): 281 | for ly in self.children(): 282 | if isinstance(ly, nn.Conv2d): 283 | nn.init.kaiming_normal_(ly.weight, a=1) 284 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 285 | 286 | def get_params(self): 287 | wd_params, nowd_params = [], [] 288 | for name, module in self.named_modules(): 289 | if isinstance(module, (nn.Linear, nn.Conv2d)): 290 | wd_params.append(module.weight) 291 | if not module.bias is None: 292 | nowd_params.append(module.bias) 293 | elif isinstance(module, torch.nn.BatchNorm2d): 294 | nowd_params += list(module.parameters()) 295 | return wd_params, nowd_params 296 | 297 | 298 | class ConvBNReLU(nn.Module): 299 | """Conv with BN and ReLU, used for Strip Corr Module""" 300 | 301 | def __init__(self, 302 | in_chan, 303 | out_chan, 304 | ks=3, 305 | stride=1, 306 | padding=1, 307 | *args, 308 | **kwargs): 309 | super(ConvBNReLU, self).__init__() 310 | self.conv = nn.Conv2d( 311 | in_chan, 312 | out_chan, 313 | kernel_size=ks, 314 | stride=stride, 315 | padding=padding, 316 | bias=False) 317 | self.bn = torch.nn.BatchNorm2d(out_chan) 318 | self.relu = nn.ReLU(inplace=True) 319 | 320 | def forward(self, x): 321 | x = self.conv(x) 322 | x = self.bn(x) 323 | x = self.relu(x) 324 | return x 325 | 326 | def init_weight(self): 327 | for ly in self.children(): 328 | if isinstance(ly, nn.Conv2d): 329 | nn.init.kaiming_normal_(ly.weight, a=1) 330 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 331 | 332 | 333 | class BasicUpdateBlock(nn.Module): 334 | """Modified by Hao, support for CSFlow""" 335 | 336 | def __init__(self, args, hidden_dim=128, input_dim=128): 337 | super(BasicUpdateBlock, self).__init__() 338 | self.args = args 339 | self.encoder = BasicMotionEncoder_v2(args) 340 | self.gru = SepConvGRU( 341 | hidden_dim=hidden_dim, input_dim=128 + hidden_dim) 342 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 343 | 344 | self.mask = nn.Sequential( 345 | nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True), 346 | nn.Conv2d(256, 64 * 9, 1, padding=0)) 347 | 348 | def forward(self, net, inp, corr, flow, upsample=True): 349 | motion_features = self.encoder(flow, corr) 350 | inp = torch.cat([inp, motion_features], dim=1) 351 | 352 | net = self.gru(net, inp) 353 | delta_flow = self.flow_head(net) 354 | 355 | # scale mask to balence gradients 356 | mask = .25 * self.mask(net) 357 | return net, mask, delta_flow 358 | 359 | 360 | def pool2x(x): 361 | return F.avg_pool2d(x, 3, stride=2, padding=1) 362 | 363 | 364 | def interp(x, dest): 365 | interp_args = {'mode': 'bilinear', 'align_corners': True} 366 | return F.interpolate(x, dest.shape[2:], **interp_args) 367 | 368 | 369 | class BasicEncoder(nn.Module): 370 | 371 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, dcn=False): 372 | from torch.nn.modules.utils import _pair 373 | super(BasicEncoder, self).__init__() 374 | self.norm_fn = norm_fn 375 | 376 | self.dcn = dcn 377 | 378 | if self.norm_fn == 'group': 379 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 380 | 381 | elif self.norm_fn == 'batch': 382 | self.norm1 = nn.BatchNorm2d(64) 383 | 384 | elif self.norm_fn == 'instance': 385 | self.norm1 = nn.InstanceNorm2d(64) 386 | 387 | elif self.norm_fn == 'none': 388 | self.norm1 = nn.Sequential() 389 | 390 | if dcn: 391 | self.conv_offset = nn.Conv2d( 392 | 3, 393 | 2 * 7 * 7, 394 | 7, 395 | stride=_pair(2), 396 | padding=_pair(3), 397 | dilation=_pair(1), 398 | bias=True) 399 | self.conv_offset.weight.data.zero_() 400 | 
self.conv_offset.bias.data.zero_() 401 | self.dconv = DeformConv2d(3, 64, kernel_size=7, stride=2, padding=3) 402 | else: 403 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 404 | self.relu1 = nn.ReLU(inplace=True) 405 | 406 | self.in_planes = 64 407 | self.layer1 = self._make_layer(64, stride=1) 408 | self.layer2 = self._make_layer(96, stride=2) 409 | self.layer3 = self._make_layer(128, stride=2) 410 | 411 | # output convolution 412 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 413 | 414 | self.dropout = None 415 | if dropout > 0: 416 | self.dropout = nn.Dropout2d(p=dropout) 417 | 418 | for m in self.modules(): 419 | if isinstance(m, nn.Conv2d): 420 | nn.init.kaiming_normal_( 421 | m.weight, mode='fan_out', nonlinearity='relu') 422 | elif isinstance(m, 423 | (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 424 | if m.weight is not None: 425 | nn.init.constant_(m.weight, 1) 426 | if m.bias is not None: 427 | nn.init.constant_(m.bias, 0) 428 | 429 | def _make_layer(self, dim, stride=1): 430 | layer1 = ResidualBlock( 431 | self.in_planes, dim, self.norm_fn, stride=stride) 432 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 433 | layers = (layer1, layer2) 434 | 435 | self.in_planes = dim 436 | return nn.Sequential(*layers) 437 | 438 | def forward(self, x): 439 | 440 | # if input is list, combine batch dimension 441 | is_list = isinstance(x, tuple) or isinstance(x, list) 442 | if is_list: 443 | batch_dim = x[0].shape[0] 444 | x = torch.cat(x, dim=0) 445 | 446 | if self.dcn: 447 | offset = self.conv_offset(x) 448 | x = self.dconv(x, offset) 449 | else: 450 | x = self.conv1(x) 451 | 452 | x = self.norm1(x) 453 | x = self.relu1(x) 454 | 455 | x = self.layer1(x) 456 | x = self.layer2(x) 457 | x = self.layer3(x) 458 | 459 | x = self.conv2(x) 460 | 461 | if self.training and self.dropout is not None: 462 | x = self.dropout(x) 463 | 464 | if is_list: 465 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 466 | 467 | return x 468 | 469 | 470 | class CorrBlock_v2: 471 | """Corr Block, modified by Hao, concat SC with 4D corr""" 472 | 473 | def __init__(self, 474 | fmap1, 475 | fmap2, 476 | strip_coor_map=None, 477 | num_levels=4, 478 | radius=4): 479 | self.num_levels = num_levels 480 | self.radius = radius 481 | self.corr_pyramid = [] 482 | 483 | # all pairs correlation 484 | corr = CorrBlock_v2.corr(fmap1, fmap2) 485 | 486 | if strip_coor_map is not None: 487 | # strip correlation augmentation with concat 488 | corr = torch.cat((corr, strip_coor_map), dim=3) 489 | 490 | batch, h1, w1, dim, h2, w2 = corr.shape 491 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 492 | 493 | self.corr_pyramid.append(corr) 494 | for i in range(self.num_levels - 1): 495 | corr = F.avg_pool2d(corr, 2, stride=2) 496 | self.corr_pyramid.append(corr) 497 | 498 | def __call__(self, coords): 499 | r = self.radius 500 | coords = coords.permute(0, 2, 3, 1) 501 | batch, h1, w1, _ = coords.shape 502 | 503 | out_pyramid = [] 504 | for i in range(self.num_levels): 505 | corr = self.corr_pyramid[i] 506 | dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device) 507 | dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device) 508 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 509 | 510 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i 511 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 512 | coords_lvl = centroid_lvl + delta_lvl 513 | 514 | corr = bilinear_sampler(corr, coords_lvl) 515 | corr = corr.view(batch, h1, w1, -1) 516 | out_pyramid.append(corr) 517 
| 518 | out = torch.cat(out_pyramid, dim=-1) 519 | return out.permute(0, 3, 1, 2).contiguous().float() 520 | 521 | @staticmethod 522 | def corr(fmap1, fmap2): 523 | batch, dim, ht, wd = fmap1.shape 524 | fmap1 = fmap1.view(batch, dim, ht * wd) 525 | fmap2 = fmap2.view(batch, dim, ht * wd) 526 | 527 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 528 | corr = corr.view(batch, ht, wd, 1, ht, wd) 529 | return corr / torch.sqrt(torch.tensor(dim).float()) 530 | 531 | 532 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 533 | """Wrapper for grid_sample, uses pixel coordinates.""" 534 | H, W = img.shape[-2:] 535 | xgrid, ygrid = coords.split([1, 1], dim=-1) 536 | xgrid = 2 * xgrid / (W - 1) - 1 537 | ygrid = 2 * ygrid / (H - 1) - 1 538 | 539 | grid = torch.cat([xgrid, ygrid], dim=-1) 540 | img = F.grid_sample(img, grid, align_corners=True) 541 | 542 | if mask: 543 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 544 | return img, mask.float() 545 | 546 | return img 547 | 548 | 549 | def coords_grid(batch, ht, wd, device): 550 | coords = torch.meshgrid( 551 | torch.arange(ht, device=device), torch.arange(wd, device=device)) 552 | coords = torch.stack(coords[::-1], dim=0).float() 553 | return coords[None].repeat(batch, 1, 1, 1) 554 | 555 | 556 | def upflow8(flow, mode='bilinear'): 557 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 558 | return 8 * F.interpolate( 559 | flow, size=new_size, mode=mode, align_corners=True) 560 | 561 | 562 | class ResidualBlock(nn.Module): 563 | 564 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 565 | super(ResidualBlock, self).__init__() 566 | 567 | self.conv1 = nn.Conv2d( 568 | in_planes, planes, kernel_size=3, padding=1, stride=stride) 569 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 570 | self.relu = nn.ReLU(inplace=True) 571 | 572 | num_groups = planes // 8 573 | 574 | if norm_fn == 'group': 575 | self.norm1 = nn.GroupNorm( 576 | num_groups=num_groups, num_channels=planes) 577 | self.norm2 = nn.GroupNorm( 578 | num_groups=num_groups, num_channels=planes) 579 | if not stride == 1: 580 | self.norm3 = nn.GroupNorm( 581 | num_groups=num_groups, num_channels=planes) 582 | 583 | elif norm_fn == 'batch': 584 | self.norm1 = nn.BatchNorm2d(planes) 585 | self.norm2 = nn.BatchNorm2d(planes) 586 | if not stride == 1: 587 | self.norm3 = nn.BatchNorm2d(planes) 588 | 589 | elif norm_fn == 'instance': 590 | self.norm1 = nn.InstanceNorm2d(planes) 591 | self.norm2 = nn.InstanceNorm2d(planes) 592 | if not stride == 1: 593 | self.norm3 = nn.InstanceNorm2d(planes) 594 | 595 | elif norm_fn == 'none': 596 | self.norm1 = nn.Sequential() 597 | self.norm2 = nn.Sequential() 598 | if not stride == 1: 599 | self.norm3 = nn.Sequential() 600 | 601 | if stride == 1: 602 | self.downsample = None 603 | 604 | else: 605 | self.downsample = nn.Sequential( 606 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), 607 | self.norm3) 608 | 609 | def forward(self, x): 610 | y = x 611 | y = self.relu(self.norm1(self.conv1(y))) 612 | y = self.relu(self.norm2(self.conv2(y))) 613 | 614 | if self.downsample is not None: 615 | x = self.downsample(x) 616 | 617 | return self.relu(x + y) 618 | 619 | 620 | class BottleneckBlock(nn.Module): 621 | 622 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 623 | super(BottleneckBlock, self).__init__() 624 | 625 | self.conv1 = nn.Conv2d( 626 | in_planes, planes // 4, kernel_size=1, padding=0) 627 | self.conv2 = nn.Conv2d( 628 | planes // 4, planes // 4, 
kernel_size=3, padding=1, stride=stride) 629 | self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) 630 | self.relu = nn.ReLU(inplace=True) 631 | 632 | num_groups = planes // 8 633 | 634 | if norm_fn == 'group': 635 | self.norm1 = nn.GroupNorm( 636 | num_groups=num_groups, num_channels=planes // 4) 637 | self.norm2 = nn.GroupNorm( 638 | num_groups=num_groups, num_channels=planes // 4) 639 | self.norm3 = nn.GroupNorm( 640 | num_groups=num_groups, num_channels=planes) 641 | if not stride == 1: 642 | self.norm4 = nn.GroupNorm( 643 | num_groups=num_groups, num_channels=planes) 644 | 645 | elif norm_fn == 'batch': 646 | self.norm1 = nn.BatchNorm2d(planes // 4) 647 | self.norm2 = nn.BatchNorm2d(planes // 4) 648 | self.norm3 = nn.BatchNorm2d(planes) 649 | if not stride == 1: 650 | self.norm4 = nn.BatchNorm2d(planes) 651 | 652 | elif norm_fn == 'instance': 653 | self.norm1 = nn.InstanceNorm2d(planes // 4) 654 | self.norm2 = nn.InstanceNorm2d(planes // 4) 655 | self.norm3 = nn.InstanceNorm2d(planes) 656 | if not stride == 1: 657 | self.norm4 = nn.InstanceNorm2d(planes) 658 | 659 | elif norm_fn == 'none': 660 | self.norm1 = nn.Sequential() 661 | self.norm2 = nn.Sequential() 662 | self.norm3 = nn.Sequential() 663 | if not stride == 1: 664 | self.norm4 = nn.Sequential() 665 | 666 | if stride == 1: 667 | self.downsample = None 668 | 669 | else: 670 | self.downsample = nn.Sequential( 671 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), 672 | self.norm4) 673 | 674 | def forward(self, x): 675 | y = x 676 | y = self.relu(self.norm1(self.conv1(y))) 677 | y = self.relu(self.norm2(self.conv2(y))) 678 | y = self.relu(self.norm3(self.conv3(y))) 679 | 680 | if self.downsample is not None: 681 | x = self.downsample(x) 682 | 683 | return self.relu(x + y) 684 | 685 | 686 | class BasicMotionEncoder_v2(nn.Module): 687 | """Get Motion Feature from CSFlow, by Hao""" 688 | 689 | def __init__(self, args): 690 | super(BasicMotionEncoder_v2, self).__init__() 691 | # double cor_plances due to concat aug 692 | cor_planes = 2 * (args.corr_levels * (2 * args.corr_radius + 1)**2) 693 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 694 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 695 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 696 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 697 | self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) 698 | 699 | def forward(self, flow, corr): 700 | cor = F.relu(self.convc1(corr)) 701 | cor = F.relu(self.convc2(cor)) 702 | flo = F.relu(self.convf1(flow)) 703 | flo = F.relu(self.convf2(flo)) 704 | 705 | cor_flo = torch.cat([cor, flo], dim=1) 706 | out = F.relu(self.conv(cor_flo)) 707 | return torch.cat([out, flow], dim=1) 708 | 709 | 710 | class SepConvGRU(nn.Module): 711 | 712 | def __init__(self, hidden_dim=128, input_dim=192 + 128): 713 | super(SepConvGRU, self).__init__() 714 | self.convz1 = nn.Conv2d( 715 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 716 | self.convr1 = nn.Conv2d( 717 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 718 | self.convq1 = nn.Conv2d( 719 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 720 | 721 | self.convz2 = nn.Conv2d( 722 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 723 | self.convr2 = nn.Conv2d( 724 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 725 | self.convq2 = nn.Conv2d( 726 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 727 | 728 | def forward(self, h, x): 729 | # horizontal 730 | hx = torch.cat([h, x], dim=1) 
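        # Standard ConvGRU gating, computed separably: update gate z and reset
        # gate r are sigmoids of conv([h, x]); the candidate state q is a tanh
        # of conv([r * h, x]); the state update is h = (1 - z) * h + z * q.
        # The (1, 5) kernels used here mix information along image rows; the
        # second pass below repeats the update with (5, 1) kernels along columns.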
731 | z = torch.sigmoid(self.convz1(hx)) 732 | r = torch.sigmoid(self.convr1(hx)) 733 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) 734 | h = (1 - z) * h + z * q 735 | 736 | # vertical 737 | hx = torch.cat([h, x], dim=1) 738 | z = torch.sigmoid(self.convz2(hx)) 739 | r = torch.sigmoid(self.convr2(hx)) 740 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) 741 | h = (1 - z) * h + z * q 742 | 743 | return h 744 | 745 | 746 | class FlowHead(nn.Module): 747 | 748 | def __init__(self, input_dim=128, hidden_dim=256): 749 | super(FlowHead, self).__init__() 750 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 751 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 752 | self.relu = nn.ReLU(inplace=True) 753 | 754 | def forward(self, x): 755 | return self.conv2(self.relu(self.conv1(x))) 756 | 757 | 758 | class ConvGRU(nn.Module): 759 | 760 | def __init__(self, hidden_dim=128, input_dim=192 + 128): 761 | super(ConvGRU, self).__init__() 762 | self.convz = nn.Conv2d( 763 | hidden_dim + input_dim, hidden_dim, 3, padding=1) 764 | self.convr = nn.Conv2d( 765 | hidden_dim + input_dim, hidden_dim, 3, padding=1) 766 | self.convq = nn.Conv2d( 767 | hidden_dim + input_dim, hidden_dim, 3, padding=1) 768 | 769 | def forward(self, h, x): 770 | hx = torch.cat([h, x], dim=1) 771 | 772 | z = torch.sigmoid(self.convz(hx)) 773 | r = torch.sigmoid(self.convr(hx)) 774 | q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) 775 | 776 | h = (1 - z) * h + z * q 777 | return h 778 | -------------------------------------------------------------------------------- /opticalflow/core/model/external/panoflow_raft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torchvision.ops import DeformConv2d 6 | 7 | try: 8 | autocast = torch.cuda.amp.autocast 9 | except (Exception): 10 | # dummy autocast for PyTorch < 1.6 11 | class autocast: 12 | 13 | def __init__(self, enabled): 14 | pass 15 | 16 | def __enter__(self): 17 | pass 18 | 19 | def __exit__(self, *args): 20 | pass 21 | 22 | 23 | class PanoRAFT(nn.Module): 24 | 25 | def __init__(self, args): 26 | super(PanoRAFT, self).__init__() 27 | self.args = args 28 | 29 | self.hidden_dim = hdim = 128 30 | self.context_dim = cdim = 128 31 | args.corr_levels = 4 32 | args.corr_radius = 4 33 | 34 | if 'dropout' not in self.args: 35 | self.args.dropout = 0 36 | 37 | if 'mixed_precision' not in self.args: 38 | self.args.mixed_precision = False 39 | 40 | # feature network, context network, and update block 41 | 42 | self.fnet = BasicEncoder( 43 | output_dim=256, norm_fn='instance', dropout=args.dropout, 44 | dcn=True) 45 | 46 | self.cnet = BasicEncoder( 47 | output_dim=hdim + cdim, 48 | norm_fn='batch', 49 | dropout=args.dropout, 50 | dcn=True) 51 | 52 | self.update_block = BasicUpdateBlock( 53 | self.args, hidden_dim=hdim) 54 | 55 | def freeze_bn(self): 56 | for m in self.modules(): 57 | if isinstance(m, nn.BatchNorm2d): 58 | m.eval() 59 | 60 | def initialize_flow(self, img, dataset, train_flag): 61 | """Flow is represented as difference between two coordinate grids. 
62 | 63 | flow = coords1 - coords0, Modified by Hao 64 | """ 65 | N, C, H, W = img.shape 66 | 67 | if dataset == 'KITTI' and not train_flag: 68 | coords0 = coords_grid(N, H // 8 + 1, W // 8 + 1, device=img.device) 69 | coords1 = coords_grid(N, H // 8 + 1, W // 8 + 1, device=img.device) 70 | elif dataset == 'Sintel' and not train_flag: 71 | coords0 = coords_grid(N, H // 8 + 1, W // 8, device=img.device) 72 | coords1 = coords_grid(N, H // 8 + 1, W // 8, device=img.device) 73 | else: 74 | coords0 = coords_grid(N, H // 8, W // 8, device=img.device) 75 | coords1 = coords_grid(N, H // 8, W // 8, device=img.device) 76 | 77 | # optical flow computed as difference: flow = coords1 - coords0 78 | return coords0, coords1 79 | 80 | def upsample_flow(self, flow, mask): 81 | """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex 82 | combination.""" 83 | N, _, H, W = flow.shape 84 | mask = mask.view(N, 1, 9, 8, 8, H, W) 85 | mask = torch.softmax(mask, dim=2) 86 | 87 | up_flow = F.unfold(8 * flow, [3, 3], padding=1) 88 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 89 | 90 | up_flow = torch.sum(mask * up_flow, dim=2) 91 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 92 | return up_flow.reshape(N, 2, 8 * H, 8 * W) 93 | 94 | def forward(self, images, flow_init=None, upsample=True, test_mode=False, gen_fmap=False, skip_encode=False): 95 | """Estimate optical flow between pair of frames.""" 96 | 97 | if not skip_encode: 98 | # Modified, take image pairs as input 99 | image1 = images[0] 100 | image2 = images[1] 101 | image1 = 2 * (image1 / 255.0) - 1.0 102 | image2 = 2 * (image2 / 255.0) - 1.0 103 | 104 | image1 = image1.contiguous() 105 | image2 = image2.contiguous() 106 | 107 | hdim = self.hidden_dim 108 | cdim = self.context_dim 109 | 110 | # run the feature network 111 | with autocast(enabled=self.args.mixed_precision): 112 | fmap1, fmap2 = self.fnet([image1, image2]) 113 | 114 | fmap1 = fmap1.float() 115 | fmap2 = fmap2.float() 116 | else: 117 | hdim = self.hidden_dim 118 | cdim = self.context_dim 119 | fmap1 = images[0] 120 | fmap2 = images[1] 121 | 122 | # run the context network 123 | with autocast(enabled=self.args.mixed_precision): 124 | if not skip_encode: 125 | cnet = self.cnet(image1) 126 | 127 | if test_mode: 128 | if gen_fmap: 129 | return fmap1, fmap2, cnet 130 | else: 131 | cnet = images[2] 132 | 133 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 134 | net = torch.tanh(net) 135 | inp = torch.relu(inp) 136 | 137 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 138 | 139 | if not skip_encode: 140 | coords0, coords1 = self.initialize_flow(image1, self.args.dataset, 141 | self.args.train) 142 | else: 143 | N, C, H, W = fmap1.shape 144 | coords0 = coords_grid(N, H, W, device=fmap1.device) 145 | coords1 = coords_grid(N, H, W, device=fmap1.device) 146 | 147 | if flow_init is not None: 148 | coords1 = coords1 + flow_init 149 | 150 | flow_predictions = [] 151 | 152 | if not test_mode: 153 | for itr in range(self.args.iters): 154 | coords1 = coords1.detach() 155 | corr = corr_fn(coords1) # index correlation volume 156 | 157 | flow = coords1 - coords0 158 | with autocast(enabled=self.args.mixed_precision): 159 | net, up_mask, delta_flow = self.update_block( 160 | net, inp, corr, flow) 161 | 162 | # F(t+1) = F(t) + \Delta(t) 163 | coords1 = coords1 + delta_flow 164 | 165 | # upsample predictions 166 | if up_mask is None: 167 | flow_up = upflow8(coords1 - coords0) 168 | else: 169 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 170 | 171 | 
flow_predictions.append(flow_up) 172 | else: 173 | iters = self.args.eval_iters 174 | 175 | for itr in range(iters): 176 | coords1 = coords1.detach() 177 | corr = corr_fn(coords1) # index correlation volume 178 | 179 | flow = coords1 - coords0 180 | with autocast(enabled=self.args.mixed_precision): 181 | net, up_mask, delta_flow = self.update_block( 182 | net, inp, corr, flow) 183 | 184 | # F(t+1) = F(t) + \Delta(t) 185 | coords1 = coords1 + delta_flow 186 | 187 | # upsample predictions 188 | if up_mask is None: 189 | flow_up = upflow8(coords1 - coords0) 190 | else: 191 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 192 | 193 | flow_predictions.append(flow_up) 194 | 195 | if test_mode: 196 | return coords1 - coords0, flow_up 197 | 198 | return flow_predictions 199 | 200 | 201 | class BasicUpdateBlock(nn.Module): 202 | """Modified by Hao""" 203 | 204 | def __init__(self, args, hidden_dim=128, input_dim=128): 205 | super(BasicUpdateBlock, self).__init__() 206 | self.args = args 207 | self.encoder = BasicMotionEncoder(args) 208 | self.gru = SepConvGRU( 209 | hidden_dim=hidden_dim, input_dim=128 + hidden_dim) 210 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 211 | 212 | self.mask = nn.Sequential( 213 | nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True), 214 | nn.Conv2d(256, 64 * 9, 1, padding=0)) 215 | 216 | def forward(self, net, inp, corr, flow, upsample=True): 217 | motion_features = self.encoder(flow, corr) 218 | inp = torch.cat([inp, motion_features], dim=1) 219 | 220 | net = self.gru(net, inp) 221 | delta_flow = self.flow_head(net) 222 | 223 | # scale mask to balence gradients 224 | mask = .25 * self.mask(net) 225 | return net, mask, delta_flow 226 | 227 | 228 | def pool2x(x): 229 | return F.avg_pool2d(x, 3, stride=2, padding=1) 230 | 231 | 232 | def interp(x, dest): 233 | interp_args = {'mode': 'bilinear', 'align_corners': True} 234 | return F.interpolate(x, dest.shape[2:], **interp_args) 235 | 236 | 237 | class BasicEncoder(nn.Module): 238 | 239 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, dcn=False): 240 | from torch.nn.modules.utils import _pair 241 | super(BasicEncoder, self).__init__() 242 | self.norm_fn = norm_fn 243 | 244 | self.dcn = dcn 245 | 246 | if self.norm_fn == 'group': 247 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 248 | 249 | elif self.norm_fn == 'batch': 250 | self.norm1 = nn.BatchNorm2d(64) 251 | 252 | elif self.norm_fn == 'instance': 253 | self.norm1 = nn.InstanceNorm2d(64) 254 | 255 | elif self.norm_fn == 'none': 256 | self.norm1 = nn.Sequential() 257 | 258 | if dcn: 259 | self.conv_offset = nn.Conv2d( 260 | 3, 261 | 2 * 7 * 7, 262 | 7, 263 | stride=_pair(2), 264 | padding=_pair(3), 265 | dilation=_pair(1), 266 | bias=True) 267 | self.conv_offset.weight.data.zero_() 268 | self.conv_offset.bias.data.zero_() 269 | self.dconv = DeformConv2d(3, 64, kernel_size=7, stride=2, padding=3) 270 | else: 271 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 272 | self.relu1 = nn.ReLU(inplace=True) 273 | 274 | self.in_planes = 64 275 | self.layer1 = self._make_layer(64, stride=1) 276 | self.layer2 = self._make_layer(96, stride=2) 277 | self.layer3 = self._make_layer(128, stride=2) 278 | 279 | # output convolution 280 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 281 | 282 | self.dropout = None 283 | if dropout > 0: 284 | self.dropout = nn.Dropout2d(p=dropout) 285 | 286 | for m in self.modules(): 287 | if isinstance(m, nn.Conv2d): 288 | nn.init.kaiming_normal_( 289 | m.weight, 
mode='fan_out', nonlinearity='relu') 290 | elif isinstance(m, 291 | (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 292 | if m.weight is not None: 293 | nn.init.constant_(m.weight, 1) 294 | if m.bias is not None: 295 | nn.init.constant_(m.bias, 0) 296 | 297 | def _make_layer(self, dim, stride=1): 298 | layer1 = ResidualBlock( 299 | self.in_planes, dim, self.norm_fn, stride=stride) 300 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 301 | layers = (layer1, layer2) 302 | 303 | self.in_planes = dim 304 | return nn.Sequential(*layers) 305 | 306 | def forward(self, x): 307 | 308 | # if input is list, combine batch dimension 309 | is_list = isinstance(x, tuple) or isinstance(x, list) 310 | if is_list: 311 | batch_dim = x[0].shape[0] 312 | x = torch.cat(x, dim=0) 313 | 314 | if self.dcn: 315 | offset = self.conv_offset(x) 316 | x = self.dconv(x, offset) 317 | else: 318 | x = self.conv1(x) 319 | 320 | x = self.norm1(x) 321 | x = self.relu1(x) 322 | 323 | x = self.layer1(x) 324 | x = self.layer2(x) 325 | x = self.layer3(x) 326 | 327 | x = self.conv2(x) 328 | 329 | if self.training and self.dropout is not None: 330 | x = self.dropout(x) 331 | 332 | if is_list: 333 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 334 | 335 | return x 336 | 337 | 338 | class CorrBlock: 339 | """Corr Block of RAFT, modified by Hao""" 340 | 341 | def __init__(self, 342 | fmap1, 343 | fmap2, 344 | strip_coor_map=None, 345 | num_levels=4, 346 | radius=4): 347 | self.num_levels = num_levels 348 | self.radius = radius 349 | self.corr_pyramid = [] 350 | 351 | # all pairs correlation 352 | corr = CorrBlock.corr(fmap1, fmap2) 353 | 354 | if strip_coor_map is not None: 355 | # strip correlation augmentation 356 | corr = corr + strip_coor_map 357 | 358 | batch, h1, w1, dim, h2, w2 = corr.shape 359 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 360 | 361 | self.corr_pyramid.append(corr) 362 | for i in range(self.num_levels - 1): 363 | corr = F.avg_pool2d(corr, 2, stride=2) 364 | self.corr_pyramid.append(corr) 365 | 366 | def __call__(self, coords): 367 | r = self.radius 368 | coords = coords.permute(0, 2, 3, 1) 369 | batch, h1, w1, _ = coords.shape 370 | 371 | out_pyramid = [] 372 | for i in range(self.num_levels): 373 | corr = self.corr_pyramid[i] 374 | dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device) 375 | dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device) 376 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 377 | 378 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i 379 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 380 | coords_lvl = centroid_lvl + delta_lvl 381 | 382 | corr = bilinear_sampler(corr, coords_lvl) 383 | corr = corr.view(batch, h1, w1, -1) 384 | out_pyramid.append(corr) 385 | 386 | out = torch.cat(out_pyramid, dim=-1) 387 | return out.permute(0, 3, 1, 2).contiguous().float() 388 | 389 | @staticmethod 390 | def corr(fmap1, fmap2): 391 | batch, dim, ht, wd = fmap1.shape 392 | fmap1 = fmap1.view(batch, dim, ht * wd) 393 | fmap2 = fmap2.view(batch, dim, ht * wd) 394 | 395 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 396 | corr = corr.view(batch, ht, wd, 1, ht, wd) 397 | return corr / torch.sqrt(torch.tensor(dim).float()) 398 | 399 | 400 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 401 | """Wrapper for grid_sample, uses pixel coordinates.""" 402 | H, W = img.shape[-2:] 403 | xgrid, ygrid = coords.split([1, 1], dim=-1) 404 | xgrid = 2 * xgrid / (W - 1) - 1 405 | ygrid = 2 * ygrid / (H - 1) - 1 406 | 407 | grid = 
torch.cat([xgrid, ygrid], dim=-1) 408 | img = F.grid_sample(img, grid, align_corners=True) 409 | 410 | if mask: 411 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 412 | return img, mask.float() 413 | 414 | return img 415 | 416 | 417 | def coords_grid(batch, ht, wd, device): 418 | coords = torch.meshgrid( 419 | torch.arange(ht, device=device), torch.arange(wd, device=device)) 420 | coords = torch.stack(coords[::-1], dim=0).float() 421 | return coords[None].repeat(batch, 1, 1, 1) 422 | 423 | 424 | def upflow8(flow, mode='bilinear'): 425 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 426 | return 8 * F.interpolate( 427 | flow, size=new_size, mode=mode, align_corners=True) 428 | 429 | 430 | class ResidualBlock(nn.Module): 431 | 432 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 433 | super(ResidualBlock, self).__init__() 434 | 435 | self.conv1 = nn.Conv2d( 436 | in_planes, planes, kernel_size=3, padding=1, stride=stride) 437 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 438 | self.relu = nn.ReLU(inplace=True) 439 | 440 | num_groups = planes // 8 441 | 442 | if norm_fn == 'group': 443 | self.norm1 = nn.GroupNorm( 444 | num_groups=num_groups, num_channels=planes) 445 | self.norm2 = nn.GroupNorm( 446 | num_groups=num_groups, num_channels=planes) 447 | if not stride == 1: 448 | self.norm3 = nn.GroupNorm( 449 | num_groups=num_groups, num_channels=planes) 450 | 451 | elif norm_fn == 'batch': 452 | self.norm1 = nn.BatchNorm2d(planes) 453 | self.norm2 = nn.BatchNorm2d(planes) 454 | if not stride == 1: 455 | self.norm3 = nn.BatchNorm2d(planes) 456 | 457 | elif norm_fn == 'instance': 458 | self.norm1 = nn.InstanceNorm2d(planes) 459 | self.norm2 = nn.InstanceNorm2d(planes) 460 | if not stride == 1: 461 | self.norm3 = nn.InstanceNorm2d(planes) 462 | 463 | elif norm_fn == 'none': 464 | self.norm1 = nn.Sequential() 465 | self.norm2 = nn.Sequential() 466 | if not stride == 1: 467 | self.norm3 = nn.Sequential() 468 | 469 | if stride == 1: 470 | self.downsample = None 471 | 472 | else: 473 | self.downsample = nn.Sequential( 474 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), 475 | self.norm3) 476 | 477 | def forward(self, x): 478 | y = x 479 | y = self.relu(self.norm1(self.conv1(y))) 480 | y = self.relu(self.norm2(self.conv2(y))) 481 | 482 | if self.downsample is not None: 483 | x = self.downsample(x) 484 | 485 | return self.relu(x + y) 486 | 487 | 488 | class BottleneckBlock(nn.Module): 489 | 490 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 491 | super(BottleneckBlock, self).__init__() 492 | 493 | self.conv1 = nn.Conv2d( 494 | in_planes, planes // 4, kernel_size=1, padding=0) 495 | self.conv2 = nn.Conv2d( 496 | planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride) 497 | self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) 498 | self.relu = nn.ReLU(inplace=True) 499 | 500 | num_groups = planes // 8 501 | 502 | if norm_fn == 'group': 503 | self.norm1 = nn.GroupNorm( 504 | num_groups=num_groups, num_channels=planes // 4) 505 | self.norm2 = nn.GroupNorm( 506 | num_groups=num_groups, num_channels=planes // 4) 507 | self.norm3 = nn.GroupNorm( 508 | num_groups=num_groups, num_channels=planes) 509 | if not stride == 1: 510 | self.norm4 = nn.GroupNorm( 511 | num_groups=num_groups, num_channels=planes) 512 | 513 | elif norm_fn == 'batch': 514 | self.norm1 = nn.BatchNorm2d(planes // 4) 515 | self.norm2 = nn.BatchNorm2d(planes // 4) 516 | self.norm3 = nn.BatchNorm2d(planes) 517 | if not 
stride == 1: 518 | self.norm4 = nn.BatchNorm2d(planes) 519 | 520 | elif norm_fn == 'instance': 521 | self.norm1 = nn.InstanceNorm2d(planes // 4) 522 | self.norm2 = nn.InstanceNorm2d(planes // 4) 523 | self.norm3 = nn.InstanceNorm2d(planes) 524 | if not stride == 1: 525 | self.norm4 = nn.InstanceNorm2d(planes) 526 | 527 | elif norm_fn == 'none': 528 | self.norm1 = nn.Sequential() 529 | self.norm2 = nn.Sequential() 530 | self.norm3 = nn.Sequential() 531 | if not stride == 1: 532 | self.norm4 = nn.Sequential() 533 | 534 | if stride == 1: 535 | self.downsample = None 536 | 537 | else: 538 | self.downsample = nn.Sequential( 539 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), 540 | self.norm4) 541 | 542 | def forward(self, x): 543 | y = x 544 | y = self.relu(self.norm1(self.conv1(y))) 545 | y = self.relu(self.norm2(self.conv2(y))) 546 | y = self.relu(self.norm3(self.conv3(y))) 547 | 548 | if self.downsample is not None: 549 | x = self.downsample(x) 550 | 551 | return self.relu(x + y) 552 | 553 | 554 | class BasicMotionEncoder(nn.Module): 555 | 556 | def __init__(self, args): 557 | super(BasicMotionEncoder, self).__init__() 558 | cor_planes = args.corr_levels * (2 * args.corr_radius + 1)**2 559 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 560 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 561 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 562 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 563 | self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) 564 | 565 | def forward(self, flow, corr): 566 | cor = F.relu(self.convc1(corr)) 567 | cor = F.relu(self.convc2(cor)) 568 | flo = F.relu(self.convf1(flow)) 569 | flo = F.relu(self.convf2(flo)) 570 | 571 | cor_flo = torch.cat([cor, flo], dim=1) 572 | out = F.relu(self.conv(cor_flo)) 573 | return torch.cat([out, flow], dim=1) 574 | 575 | 576 | class SepConvGRU(nn.Module): 577 | 578 | def __init__(self, hidden_dim=128, input_dim=192 + 128): 579 | super(SepConvGRU, self).__init__() 580 | self.convz1 = nn.Conv2d( 581 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 582 | self.convr1 = nn.Conv2d( 583 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 584 | self.convq1 = nn.Conv2d( 585 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 586 | 587 | self.convz2 = nn.Conv2d( 588 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 589 | self.convr2 = nn.Conv2d( 590 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 591 | self.convq2 = nn.Conv2d( 592 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 593 | 594 | def forward(self, h, x): 595 | # horizontal 596 | hx = torch.cat([h, x], dim=1) 597 | z = torch.sigmoid(self.convz1(hx)) 598 | r = torch.sigmoid(self.convr1(hx)) 599 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) 600 | h = (1 - z) * h + z * q 601 | 602 | # vertical 603 | hx = torch.cat([h, x], dim=1) 604 | z = torch.sigmoid(self.convz2(hx)) 605 | r = torch.sigmoid(self.convr2(hx)) 606 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) 607 | h = (1 - z) * h + z * q 608 | 609 | return h 610 | 611 | 612 | class FlowHead(nn.Module): 613 | 614 | def __init__(self, input_dim=128, hidden_dim=256): 615 | super(FlowHead, self).__init__() 616 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 617 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 618 | self.relu = nn.ReLU(inplace=True) 619 | 620 | def forward(self, x): 621 | return self.conv2(self.relu(self.conv1(x))) 622 | 623 | 624 | class ConvGRU(nn.Module): 625 | 626 | def 
__init__(self, hidden_dim=128, input_dim=192 + 128): 627 | super(ConvGRU, self).__init__() 628 | self.convz = nn.Conv2d( 629 | hidden_dim + input_dim, hidden_dim, 3, padding=1) 630 | self.convr = nn.Conv2d( 631 | hidden_dim + input_dim, hidden_dim, 3, padding=1) 632 | self.convq = nn.Conv2d( 633 | hidden_dim + input_dim, hidden_dim, 3, padding=1) 634 | 635 | def forward(self, h, x): 636 | hx = torch.cat([h, x], dim=1) 637 | 638 | z = torch.sigmoid(self.convz(hx)) 639 | r = torch.sigmoid(self.convr(hx)) 640 | q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) 641 | 642 | h = (1 - z) * h + z * q 643 | return h 644 | -------------------------------------------------------------------------------- /opticalflow/core/model/external/raft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torchvision.ops import DeformConv2d 6 | 7 | try: 8 | autocast = torch.cuda.amp.autocast 9 | except (Exception): 10 | # dummy autocast for PyTorch < 1.6 11 | class autocast: 12 | 13 | def __init__(self, enabled): 14 | pass 15 | 16 | def __enter__(self): 17 | pass 18 | 19 | def __exit__(self, *args): 20 | pass 21 | 22 | 23 | class RAFT(nn.Module): 24 | 25 | def __init__(self, args): 26 | super(RAFT, self).__init__() 27 | self.args = args 28 | 29 | self.hidden_dim = hdim = 128 30 | self.context_dim = cdim = 128 31 | args.corr_levels = 4 32 | args.corr_radius = 4 33 | 34 | if 'dropout' not in self.args: 35 | self.args.dropout = 0 36 | 37 | if 'mixed_precision' not in self.args: 38 | self.args.mixed_precision = False 39 | 40 | # feature network, context network, and update block 41 | 42 | self.fnet = BasicEncoder( 43 | output_dim=256, norm_fn='instance', dropout=args.dropout) 44 | 45 | self.cnet = BasicEncoder( 46 | output_dim=hdim + cdim, 47 | norm_fn='batch', 48 | dropout=args.dropout) 49 | 50 | self.update_block = BasicUpdateBlock( 51 | self.args, hidden_dim=hdim) 52 | 53 | def freeze_bn(self): 54 | for m in self.modules(): 55 | if isinstance(m, nn.BatchNorm2d): 56 | m.eval() 57 | 58 | def initialize_flow(self, img, dataset, train_flag): 59 | """Flow is represented as difference between two coordinate grids. 
60 | 61 | flow = coords1 - coords0, Modified by Hao 62 | """ 63 | N, C, H, W = img.shape 64 | 65 | if dataset == 'KITTI' and not train_flag: 66 | coords0 = coords_grid(N, H // 8 + 1, W // 8 + 1, device=img.device) 67 | coords1 = coords_grid(N, H // 8 + 1, W // 8 + 1, device=img.device) 68 | elif dataset == 'Sintel' and not train_flag: 69 | coords0 = coords_grid(N, H // 8 + 1, W // 8, device=img.device) 70 | coords1 = coords_grid(N, H // 8 + 1, W // 8, device=img.device) 71 | else: 72 | coords0 = coords_grid(N, H // 8, W // 8, device=img.device) 73 | coords1 = coords_grid(N, H // 8, W // 8, device=img.device) 74 | 75 | # optical flow computed as difference: flow = coords1 - coords0 76 | return coords0, coords1 77 | 78 | def upsample_flow(self, flow, mask): 79 | """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex 80 | combination.""" 81 | N, _, H, W = flow.shape 82 | mask = mask.view(N, 1, 9, 8, 8, H, W) 83 | mask = torch.softmax(mask, dim=2) 84 | 85 | up_flow = F.unfold(8 * flow, [3, 3], padding=1) 86 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 87 | 88 | up_flow = torch.sum(mask * up_flow, dim=2) 89 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 90 | return up_flow.reshape(N, 2, 8 * H, 8 * W) 91 | 92 | def forward(self, images, flow_init=None, upsample=True, test_mode=False, gen_fmap=False, skip_encode=False): 93 | """Estimate optical flow between pair of frames.""" 94 | 95 | if not skip_encode: 96 | # Modified, take image pairs as input 97 | image1 = images[0] 98 | image2 = images[1] 99 | image1 = 2 * (image1 / 255.0) - 1.0 100 | image2 = 2 * (image2 / 255.0) - 1.0 101 | 102 | image1 = image1.contiguous() 103 | image2 = image2.contiguous() 104 | 105 | hdim = self.hidden_dim 106 | cdim = self.context_dim 107 | 108 | # run the feature network 109 | with autocast(enabled=self.args.mixed_precision): 110 | fmap1, fmap2 = self.fnet([image1, image2]) 111 | 112 | fmap1 = fmap1.float() 113 | fmap2 = fmap2.float() 114 | else: 115 | hdim = self.hidden_dim 116 | cdim = self.context_dim 117 | fmap1 = images[0] 118 | fmap2 = images[1] 119 | 120 | # run the context network 121 | with autocast(enabled=self.args.mixed_precision): 122 | if not skip_encode: 123 | cnet = self.cnet(image1) 124 | 125 | if test_mode: 126 | if gen_fmap: 127 | return fmap1, fmap2, cnet 128 | else: 129 | cnet = images[2] 130 | 131 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 132 | net = torch.tanh(net) 133 | inp = torch.relu(inp) 134 | 135 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 136 | 137 | if not skip_encode: 138 | coords0, coords1 = self.initialize_flow(image1, self.args.dataset, 139 | self.args.train) 140 | else: 141 | N, C, H, W = fmap1.shape 142 | coords0 = coords_grid(N, H, W, device=fmap1.device) 143 | coords1 = coords_grid(N, H, W, device=fmap1.device) 144 | 145 | if flow_init is not None: 146 | coords1 = coords1 + flow_init 147 | 148 | flow_predictions = [] 149 | 150 | if not test_mode: 151 | for itr in range(self.args.iters): 152 | coords1 = coords1.detach() 153 | corr = corr_fn(coords1) # index correlation volume 154 | 155 | flow = coords1 - coords0 156 | with autocast(enabled=self.args.mixed_precision): 157 | net, up_mask, delta_flow = self.update_block( 158 | net, inp, corr, flow) 159 | 160 | # F(t+1) = F(t) + \Delta(t) 161 | coords1 = coords1 + delta_flow 162 | 163 | # upsample predictions 164 | if up_mask is None: 165 | flow_up = upflow8(coords1 - coords0) 166 | else: 167 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 168 | 169 | 
flow_predictions.append(flow_up) 170 | else: 171 | iters = self.args.eval_iters 172 | 173 | for itr in range(iters): 174 | coords1 = coords1.detach() 175 | corr = corr_fn(coords1) # index correlation volume 176 | 177 | flow = coords1 - coords0 178 | with autocast(enabled=self.args.mixed_precision): 179 | net, up_mask, delta_flow = self.update_block( 180 | net, inp, corr, flow) 181 | 182 | # F(t+1) = F(t) + \Delta(t) 183 | coords1 = coords1 + delta_flow 184 | 185 | # upsample predictions 186 | if up_mask is None: 187 | flow_up = upflow8(coords1 - coords0) 188 | else: 189 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 190 | 191 | flow_predictions.append(flow_up) 192 | 193 | if test_mode: 194 | return coords1 - coords0, flow_up 195 | 196 | return flow_predictions 197 | 198 | 199 | class BasicUpdateBlock(nn.Module): 200 | """Modified by Hao""" 201 | 202 | def __init__(self, args, hidden_dim=128, input_dim=128): 203 | super(BasicUpdateBlock, self).__init__() 204 | self.args = args 205 | self.encoder = BasicMotionEncoder(args) 206 | self.gru = SepConvGRU( 207 | hidden_dim=hidden_dim, input_dim=128 + hidden_dim) 208 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 209 | 210 | self.mask = nn.Sequential( 211 | nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True), 212 | nn.Conv2d(256, 64 * 9, 1, padding=0)) 213 | 214 | def forward(self, net, inp, corr, flow, upsample=True): 215 | motion_features = self.encoder(flow, corr) 216 | inp = torch.cat([inp, motion_features], dim=1) 217 | 218 | net = self.gru(net, inp) 219 | delta_flow = self.flow_head(net) 220 | 221 | # scale mask to balence gradients 222 | mask = .25 * self.mask(net) 223 | return net, mask, delta_flow 224 | 225 | 226 | def pool2x(x): 227 | return F.avg_pool2d(x, 3, stride=2, padding=1) 228 | 229 | 230 | def interp(x, dest): 231 | interp_args = {'mode': 'bilinear', 'align_corners': True} 232 | return F.interpolate(x, dest.shape[2:], **interp_args) 233 | 234 | 235 | class BasicEncoder(nn.Module): 236 | 237 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, dcn=False): 238 | from torch.nn.modules.utils import _pair 239 | super(BasicEncoder, self).__init__() 240 | self.norm_fn = norm_fn 241 | 242 | self.dcn = dcn 243 | 244 | if self.norm_fn == 'group': 245 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 246 | 247 | elif self.norm_fn == 'batch': 248 | self.norm1 = nn.BatchNorm2d(64) 249 | 250 | elif self.norm_fn == 'instance': 251 | self.norm1 = nn.InstanceNorm2d(64) 252 | 253 | elif self.norm_fn == 'none': 254 | self.norm1 = nn.Sequential() 255 | 256 | if dcn: 257 | self.conv_offset = nn.Conv2d( 258 | 3, 259 | 2 * 7 * 7, 260 | 7, 261 | stride=_pair(2), 262 | padding=_pair(3), 263 | dilation=_pair(1), 264 | bias=True) 265 | self.conv_offset.weight.data.zero_() 266 | self.conv_offset.bias.data.zero_() 267 | self.dconv = DeformConv2d(3, 64, kernel_size=7, stride=2, padding=3) 268 | else: 269 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 270 | self.relu1 = nn.ReLU(inplace=True) 271 | 272 | self.in_planes = 64 273 | self.layer1 = self._make_layer(64, stride=1) 274 | self.layer2 = self._make_layer(96, stride=2) 275 | self.layer3 = self._make_layer(128, stride=2) 276 | 277 | # output convolution 278 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 279 | 280 | self.dropout = None 281 | if dropout > 0: 282 | self.dropout = nn.Dropout2d(p=dropout) 283 | 284 | for m in self.modules(): 285 | if isinstance(m, nn.Conv2d): 286 | nn.init.kaiming_normal_( 287 | m.weight, 
mode='fan_out', nonlinearity='relu') 288 | elif isinstance(m, 289 | (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 290 | if m.weight is not None: 291 | nn.init.constant_(m.weight, 1) 292 | if m.bias is not None: 293 | nn.init.constant_(m.bias, 0) 294 | 295 | def _make_layer(self, dim, stride=1): 296 | layer1 = ResidualBlock( 297 | self.in_planes, dim, self.norm_fn, stride=stride) 298 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 299 | layers = (layer1, layer2) 300 | 301 | self.in_planes = dim 302 | return nn.Sequential(*layers) 303 | 304 | def forward(self, x): 305 | 306 | # if input is list, combine batch dimension 307 | is_list = isinstance(x, tuple) or isinstance(x, list) 308 | if is_list: 309 | batch_dim = x[0].shape[0] 310 | x = torch.cat(x, dim=0) 311 | 312 | if self.dcn: 313 | offset = self.conv_offset(x) 314 | x = self.dconv(x, offset) 315 | else: 316 | x = self.conv1(x) 317 | 318 | x = self.norm1(x) 319 | x = self.relu1(x) 320 | 321 | x = self.layer1(x) 322 | x = self.layer2(x) 323 | x = self.layer3(x) 324 | 325 | x = self.conv2(x) 326 | 327 | if self.training and self.dropout is not None: 328 | x = self.dropout(x) 329 | 330 | if is_list: 331 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 332 | 333 | return x 334 | 335 | 336 | class CorrBlock: 337 | """Corr Block of RAFT, modified by Hao""" 338 | 339 | def __init__(self, 340 | fmap1, 341 | fmap2, 342 | strip_coor_map=None, 343 | num_levels=4, 344 | radius=4): 345 | self.num_levels = num_levels 346 | self.radius = radius 347 | self.corr_pyramid = [] 348 | 349 | # all pairs correlation 350 | corr = CorrBlock.corr(fmap1, fmap2) 351 | 352 | if strip_coor_map is not None: 353 | # strip correlation augmentation 354 | corr = corr + strip_coor_map 355 | 356 | batch, h1, w1, dim, h2, w2 = corr.shape 357 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 358 | 359 | self.corr_pyramid.append(corr) 360 | for i in range(self.num_levels - 1): 361 | corr = F.avg_pool2d(corr, 2, stride=2) 362 | self.corr_pyramid.append(corr) 363 | 364 | def __call__(self, coords): 365 | r = self.radius 366 | coords = coords.permute(0, 2, 3, 1) 367 | batch, h1, w1, _ = coords.shape 368 | 369 | out_pyramid = [] 370 | for i in range(self.num_levels): 371 | corr = self.corr_pyramid[i] 372 | dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device) 373 | dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device) 374 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 375 | 376 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i 377 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 378 | coords_lvl = centroid_lvl + delta_lvl 379 | 380 | corr = bilinear_sampler(corr, coords_lvl) 381 | corr = corr.view(batch, h1, w1, -1) 382 | out_pyramid.append(corr) 383 | 384 | out = torch.cat(out_pyramid, dim=-1) 385 | return out.permute(0, 3, 1, 2).contiguous().float() 386 | 387 | @staticmethod 388 | def corr(fmap1, fmap2): 389 | batch, dim, ht, wd = fmap1.shape 390 | fmap1 = fmap1.view(batch, dim, ht * wd) 391 | fmap2 = fmap2.view(batch, dim, ht * wd) 392 | 393 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 394 | corr = corr.view(batch, ht, wd, 1, ht, wd) 395 | return corr / torch.sqrt(torch.tensor(dim).float()) 396 | 397 | 398 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 399 | """Wrapper for grid_sample, uses pixel coordinates.""" 400 | H, W = img.shape[-2:] 401 | xgrid, ygrid = coords.split([1, 1], dim=-1) 402 | xgrid = 2 * xgrid / (W - 1) - 1 403 | ygrid = 2 * ygrid / (H - 1) - 1 404 | 405 | grid = 
torch.cat([xgrid, ygrid], dim=-1) 406 | img = F.grid_sample(img, grid, align_corners=True) 407 | 408 | if mask: 409 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 410 | return img, mask.float() 411 | 412 | return img 413 | 414 | 415 | def coords_grid(batch, ht, wd, device): 416 | coords = torch.meshgrid( 417 | torch.arange(ht, device=device), torch.arange(wd, device=device)) 418 | coords = torch.stack(coords[::-1], dim=0).float() 419 | return coords[None].repeat(batch, 1, 1, 1) 420 | 421 | 422 | def upflow8(flow, mode='bilinear'): 423 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 424 | return 8 * F.interpolate( 425 | flow, size=new_size, mode=mode, align_corners=True) 426 | 427 | 428 | class ResidualBlock(nn.Module): 429 | 430 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 431 | super(ResidualBlock, self).__init__() 432 | 433 | self.conv1 = nn.Conv2d( 434 | in_planes, planes, kernel_size=3, padding=1, stride=stride) 435 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 436 | self.relu = nn.ReLU(inplace=True) 437 | 438 | num_groups = planes // 8 439 | 440 | if norm_fn == 'group': 441 | self.norm1 = nn.GroupNorm( 442 | num_groups=num_groups, num_channels=planes) 443 | self.norm2 = nn.GroupNorm( 444 | num_groups=num_groups, num_channels=planes) 445 | if not stride == 1: 446 | self.norm3 = nn.GroupNorm( 447 | num_groups=num_groups, num_channels=planes) 448 | 449 | elif norm_fn == 'batch': 450 | self.norm1 = nn.BatchNorm2d(planes) 451 | self.norm2 = nn.BatchNorm2d(planes) 452 | if not stride == 1: 453 | self.norm3 = nn.BatchNorm2d(planes) 454 | 455 | elif norm_fn == 'instance': 456 | self.norm1 = nn.InstanceNorm2d(planes) 457 | self.norm2 = nn.InstanceNorm2d(planes) 458 | if not stride == 1: 459 | self.norm3 = nn.InstanceNorm2d(planes) 460 | 461 | elif norm_fn == 'none': 462 | self.norm1 = nn.Sequential() 463 | self.norm2 = nn.Sequential() 464 | if not stride == 1: 465 | self.norm3 = nn.Sequential() 466 | 467 | if stride == 1: 468 | self.downsample = None 469 | 470 | else: 471 | self.downsample = nn.Sequential( 472 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), 473 | self.norm3) 474 | 475 | def forward(self, x): 476 | y = x 477 | y = self.relu(self.norm1(self.conv1(y))) 478 | y = self.relu(self.norm2(self.conv2(y))) 479 | 480 | if self.downsample is not None: 481 | x = self.downsample(x) 482 | 483 | return self.relu(x + y) 484 | 485 | 486 | class BottleneckBlock(nn.Module): 487 | 488 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 489 | super(BottleneckBlock, self).__init__() 490 | 491 | self.conv1 = nn.Conv2d( 492 | in_planes, planes // 4, kernel_size=1, padding=0) 493 | self.conv2 = nn.Conv2d( 494 | planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride) 495 | self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) 496 | self.relu = nn.ReLU(inplace=True) 497 | 498 | num_groups = planes // 8 499 | 500 | if norm_fn == 'group': 501 | self.norm1 = nn.GroupNorm( 502 | num_groups=num_groups, num_channels=planes // 4) 503 | self.norm2 = nn.GroupNorm( 504 | num_groups=num_groups, num_channels=planes // 4) 505 | self.norm3 = nn.GroupNorm( 506 | num_groups=num_groups, num_channels=planes) 507 | if not stride == 1: 508 | self.norm4 = nn.GroupNorm( 509 | num_groups=num_groups, num_channels=planes) 510 | 511 | elif norm_fn == 'batch': 512 | self.norm1 = nn.BatchNorm2d(planes // 4) 513 | self.norm2 = nn.BatchNorm2d(planes // 4) 514 | self.norm3 = nn.BatchNorm2d(planes) 515 | if not 
stride == 1: 516 | self.norm4 = nn.BatchNorm2d(planes) 517 | 518 | elif norm_fn == 'instance': 519 | self.norm1 = nn.InstanceNorm2d(planes // 4) 520 | self.norm2 = nn.InstanceNorm2d(planes // 4) 521 | self.norm3 = nn.InstanceNorm2d(planes) 522 | if not stride == 1: 523 | self.norm4 = nn.InstanceNorm2d(planes) 524 | 525 | elif norm_fn == 'none': 526 | self.norm1 = nn.Sequential() 527 | self.norm2 = nn.Sequential() 528 | self.norm3 = nn.Sequential() 529 | if not stride == 1: 530 | self.norm4 = nn.Sequential() 531 | 532 | if stride == 1: 533 | self.downsample = None 534 | 535 | else: 536 | self.downsample = nn.Sequential( 537 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), 538 | self.norm4) 539 | 540 | def forward(self, x): 541 | y = x 542 | y = self.relu(self.norm1(self.conv1(y))) 543 | y = self.relu(self.norm2(self.conv2(y))) 544 | y = self.relu(self.norm3(self.conv3(y))) 545 | 546 | if self.downsample is not None: 547 | x = self.downsample(x) 548 | 549 | return self.relu(x + y) 550 | 551 | 552 | class BasicMotionEncoder(nn.Module): 553 | 554 | def __init__(self, args): 555 | super(BasicMotionEncoder, self).__init__() 556 | cor_planes = args.corr_levels * (2 * args.corr_radius + 1)**2 557 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 558 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 559 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 560 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 561 | self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) 562 | 563 | def forward(self, flow, corr): 564 | cor = F.relu(self.convc1(corr)) 565 | cor = F.relu(self.convc2(cor)) 566 | flo = F.relu(self.convf1(flow)) 567 | flo = F.relu(self.convf2(flo)) 568 | 569 | cor_flo = torch.cat([cor, flo], dim=1) 570 | out = F.relu(self.conv(cor_flo)) 571 | return torch.cat([out, flow], dim=1) 572 | 573 | 574 | class SepConvGRU(nn.Module): 575 | 576 | def __init__(self, hidden_dim=128, input_dim=192 + 128): 577 | super(SepConvGRU, self).__init__() 578 | self.convz1 = nn.Conv2d( 579 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 580 | self.convr1 = nn.Conv2d( 581 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 582 | self.convq1 = nn.Conv2d( 583 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 584 | 585 | self.convz2 = nn.Conv2d( 586 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 587 | self.convr2 = nn.Conv2d( 588 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 589 | self.convq2 = nn.Conv2d( 590 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 591 | 592 | def forward(self, h, x): 593 | # horizontal 594 | hx = torch.cat([h, x], dim=1) 595 | z = torch.sigmoid(self.convz1(hx)) 596 | r = torch.sigmoid(self.convr1(hx)) 597 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) 598 | h = (1 - z) * h + z * q 599 | 600 | # vertical 601 | hx = torch.cat([h, x], dim=1) 602 | z = torch.sigmoid(self.convz2(hx)) 603 | r = torch.sigmoid(self.convr2(hx)) 604 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) 605 | h = (1 - z) * h + z * q 606 | 607 | return h 608 | 609 | 610 | class FlowHead(nn.Module): 611 | 612 | def __init__(self, input_dim=128, hidden_dim=256): 613 | super(FlowHead, self).__init__() 614 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 615 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 616 | self.relu = nn.ReLU(inplace=True) 617 | 618 | def forward(self, x): 619 | return self.conv2(self.relu(self.conv1(x))) 620 | 621 | 622 | class ConvGRU(nn.Module): 623 | 624 | def 
__init__(self, hidden_dim=128, input_dim=192 + 128): 625 | super(ConvGRU, self).__init__() 626 | self.convz = nn.Conv2d( 627 | hidden_dim + input_dim, hidden_dim, 3, padding=1) 628 | self.convr = nn.Conv2d( 629 | hidden_dim + input_dim, hidden_dim, 3, padding=1) 630 | self.convq = nn.Conv2d( 631 | hidden_dim + input_dim, hidden_dim, 3, padding=1) 632 | 633 | def forward(self, h, x): 634 | hx = torch.cat([h, x], dim=1) 635 | 636 | z = torch.sigmoid(self.convz(hx)) 637 | r = torch.sigmoid(self.convr(hx)) 638 | q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) 639 | 640 | h = (1 - z) * h + z * q 641 | return h 642 | -------------------------------------------------------------------------------- /opticalflow/core/model/panoflow_csflow.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from .base_model import BaseModel, ModelMode 4 | from .external import panoflow_csflow 5 | 6 | 7 | class PanoCSFlow(BaseModel): 8 | 9 | def __init__(self, args, mode: ModelMode = ModelMode.TEST): 10 | super().__init__(mode=mode) 11 | self._model = panoflow_csflow.PanoCSFlow(args) 12 | 13 | def _preprocess(self, x: Any): 14 | if isinstance(x, (tuple, list)): 15 | x = x[0] 16 | return x 17 | 18 | def _forward_test(self, x: Any): 19 | self._model.eval() 20 | return self._model(x) 21 | 22 | def _forward_train(self, x: Any): 23 | return self._model(x) 24 | -------------------------------------------------------------------------------- /opticalflow/core/model/panoflow_raft.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from .base_model import BaseModel, ModelMode 4 | from .external import panoflow_raft 5 | 6 | 7 | class PanoRAFT(BaseModel): 8 | 9 | def __init__(self, args, mode: ModelMode = ModelMode.TEST): 10 | super().__init__(mode=mode) 11 | self._model = panoflow_raft.PanoRAFT(args) 12 | 13 | def _preprocess(self, x: Any): 14 | if isinstance(x, (tuple, list)): 15 | x = x[0] 16 | return x 17 | 18 | def _forward_test(self, x: Any): 19 | self._model.eval() 20 | return self._model(x) 21 | 22 | def _forward_train(self, x: Any): 23 | return self._model(x) 24 | -------------------------------------------------------------------------------- /opticalflow/core/model/raft.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from .base_model import BaseModel, ModelMode 4 | from .external import raft 5 | 6 | 7 | class RAFT(BaseModel): 8 | 9 | def __init__(self, args, mode: ModelMode = ModelMode.TEST): 10 | super().__init__(mode=mode) 11 | self._model = raft.RAFT(args) 12 | 13 | def _preprocess(self, x: Any): 14 | if isinstance(x, (tuple, list)): 15 | x = x[0] 16 | return x 17 | 18 | def _forward_test(self, x: Any): 19 | self._model.eval() 20 | return self._model(x) 21 | 22 | def _forward_train(self, x: Any): 23 | return self._model(x) 24 | -------------------------------------------------------------------------------- /opticalflow/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_flow import FlowDataset 2 | from .flying_chairs import FlyingChairs 3 | from .flying_things import FlyingThings3D 4 | from .flow360 import Flow360 5 | from .omni import OmniDataset 6 | 7 | __all__ = [ 8 | 'FlowDataset', 'FlyingChairs', 'FlyingThings3D', 'Flow360', 'OmniDataset' 9 | ] 10 | -------------------------------------------------------------------------------- 
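The classes exported by `opticalflow/dataset/__init__.py` all derive from `FlowDataset` (defined in `base_flow.py` below), whose `__getitem__` yields `(img1, img2, flow, valid)` tensors, so they plug straight into a standard PyTorch `DataLoader`. The snippet below is a minimal usage sketch and not part of the repository: the dataset root, crop size, and batch size are illustrative assumptions, and it presumes the FlowScape (Flow360) data has been extracted to `datasets/Flow360`.

```python
# Minimal sketch (assumptions: data extracted to datasets/Flow360, PyTorch installed,
# crop size / batch size chosen for illustration only).
from torch.utils.data import DataLoader

from opticalflow.dataset import Flow360

# aug_params is forwarded to FlowAugmentor by FlowDataset; crop_size is the only
# required key, the rest mirror the augmentor defaults shown in augmentor.py.
aug_params = {
    'crop_size': (400, 720),
    'min_scale': -0.2,
    'max_scale': 0.5,
    'do_flip': True,
}

dataset = Flow360(
    aug_params=aug_params,
    split='train',
    root='datasets/Flow360',   # assumed extraction path
    dstype='sunny',
)
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)

for img1, img2, flow, valid in loader:
    # img1 / img2: [B, 3, H, W] float tensors in 0-255 range,
    # flow: [B, 2, H, W], valid: [B, H, W] float mask (sky excluded in FlowScape).
    print(img1.shape, flow.shape, valid.mean())
    break
```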
/opticalflow/dataset/base_flow.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | from opticalflow.utils import augmentor, flow_utils 8 | 9 | 10 | class FlowDataset(Dataset): 11 | 12 | def __init__(self, aug_params=None, sparse=False): 13 | self.augmentor = None 14 | self.sparse = sparse 15 | if aug_params is not None: 16 | if sparse: 17 | self.augmentor = augmentor.SparseFlowAugmentor(**aug_params) 18 | else: 19 | self.augmentor = augmentor.FlowAugmentor(**aug_params) 20 | 21 | self.is_test = False 22 | self.init_seed = False 23 | self.flow_list = [] 24 | self.image_list = [] 25 | self.extra_info = [] 26 | 27 | def __getitem__(self, index): 28 | 29 | if self.is_test: 30 | img1 = flow_utils.read_gen(self.image_list[index][0]) 31 | img2 = flow_utils.read_gen(self.image_list[index][1]) 32 | img1 = np.array(img1).astype(np.uint8)[..., :3] 33 | img2 = np.array(img2).astype(np.uint8)[..., :3] 34 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 35 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 36 | return img1, img2, self.extra_info[index] 37 | 38 | if not self.init_seed: 39 | worker_info = torch.utils.data.get_worker_info() 40 | if worker_info is not None: 41 | torch.manual_seed(worker_info.id) 42 | np.random.seed(worker_info.id) 43 | random.seed(worker_info.id) 44 | self.init_seed = True 45 | 46 | index = index % len(self.image_list) 47 | valid = None 48 | if self.sparse: 49 | flow, valid = flow_utils.readFlowKITTI(self.flow_list[index]) 50 | else: 51 | flow = flow_utils.read_gen(self.flow_list[index]) 52 | 53 | img1 = flow_utils.read_gen(self.image_list[index][0]) 54 | img2 = flow_utils.read_gen(self.image_list[index][1]) 55 | 56 | flow = np.array(flow).astype(np.float32) 57 | img1 = np.array(img1).astype(np.uint8) 58 | img2 = np.array(img2).astype(np.uint8) 59 | 60 | # grayscale images 61 | if len(img1.shape) == 2: 62 | img1 = np.tile(img1[..., None], (1, 1, 3)) 63 | img2 = np.tile(img2[..., None], (1, 1, 3)) 64 | else: 65 | img1 = img1[..., :3] 66 | img2 = img2[..., :3] 67 | 68 | if self.augmentor is not None: 69 | if self.sparse: 70 | img1, img2, flow, valid = self.augmentor( 71 | img1, img2, flow, valid) 72 | else: 73 | img1, img2, flow = self.augmentor(img1, img2, flow) 74 | 75 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 76 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 77 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 78 | 79 | if valid is not None: 80 | valid = torch.from_numpy(valid) 81 | else: 82 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 83 | 84 | return img1, img2, flow, valid.float() 85 | 86 | def __rmul__(self, v): 87 | self.flow_list = v * self.flow_list 88 | self.image_list = v * self.image_list 89 | return self 90 | 91 | def __len__(self): 92 | return len(self.image_list) 93 | -------------------------------------------------------------------------------- /opticalflow/dataset/flow360.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from glob import glob 4 | 5 | from .base_flow import FlowDataset 6 | 7 | 8 | class Flow360(FlowDataset): 9 | 10 | def __init__(self, 11 | aug_params=None, 12 | split='train', 13 | root='datasets/Flow360', 14 | dstype='sunny'): 15 | super(Flow360, self).__init__(aug_params) 16 | 17 | flow_root = osp.join(root, split, dstype, 'flow') 18 | image_root = 
osp.join(root, split, dstype, 'img') 19 | 20 | for scene in os.listdir(image_root): 21 | image_list = sorted(glob(osp.join(image_root, scene, '*.jpg'))) 22 | for i in range(len(image_list) - 1): 23 | self.image_list += [[image_list[i], image_list[i + 1]]] 24 | self.extra_info += [(scene, i)] # scene and frame_id 25 | 26 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) 27 | -------------------------------------------------------------------------------- /opticalflow/dataset/flying_chairs.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from glob import glob 3 | 4 | import numpy as np 5 | 6 | from .base_flow import FlowDataset 7 | 8 | 9 | class FlyingChairs(FlowDataset): 10 | 11 | def __init__(self, 12 | aug_params=None, 13 | split='train', 14 | root='datasets/FlyingChairs_release/data'): 15 | super(FlyingChairs, self).__init__(aug_params) 16 | 17 | images = sorted(glob(osp.join(root, '*.ppm'))) 18 | flows = sorted(glob(osp.join(root, '*.flo'))) 19 | assert (len(images) // 2 == len(flows)) 20 | 21 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) 22 | for i in range(len(flows)): 23 | xid = split_list[i] 24 | if (split == 'training' and xid == 1) or (split == 'validation' 25 | and xid == 2): 26 | self.flow_list += [flows[i]] 27 | self.image_list += [[images[2 * i], images[2 * i + 1]]] 28 | -------------------------------------------------------------------------------- /opticalflow/dataset/flying_things.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from glob import glob 3 | 4 | from .base_flow import FlowDataset 5 | 6 | 7 | class FlyingThings3D(FlowDataset): 8 | 9 | def __init__(self, 10 | aug_params=None, 11 | root='datasets/FlyingThings3D', 12 | dstype='frames_cleanpass'): 13 | super(FlyingThings3D, self).__init__(aug_params) 14 | 15 | for cam in ['left']: 16 | for direction in ['into_future', 'into_past']: 17 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) 18 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 19 | 20 | flow_dirs = sorted( 21 | glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) 22 | flow_dirs = sorted( 23 | [osp.join(f, direction, cam) for f in flow_dirs]) 24 | 25 | for idir, fdir in zip(image_dirs, flow_dirs): 26 | images = sorted(glob(osp.join(idir, '*.png'))) 27 | flows = sorted(glob(osp.join(fdir, '*.pfm'))) 28 | for i in range(len(flows) - 1): 29 | if direction == 'into_future': 30 | self.image_list += [[images[i], images[i + 1]]] 31 | self.flow_list += [flows[i]] 32 | elif direction == 'into_past': 33 | self.image_list += [[images[i + 1], images[i]]] 34 | self.flow_list += [flows[i + 1]] 35 | -------------------------------------------------------------------------------- /opticalflow/dataset/omni.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from glob import glob 3 | 4 | from .base_flow import FlowDataset 5 | 6 | 7 | class OmniDataset(FlowDataset): 8 | 9 | def __init__(self, 10 | aug_params=None, 11 | root='datasets/OMNIFLOWNET_DATASET', 12 | dstype='Forest', 13 | is_test=False): 14 | super(OmniDataset, self).__init__(aug_params) 15 | 16 | self.is_test = is_test 17 | self.dstype = dstype 18 | 19 | for id in range(5): 20 | if id == 0: 21 | name = self.dstype 22 | else: 23 | name = f'{self.dstype}_{id}' 24 | 25 | flow_root = osp.join(root, dstype, name, 'ground_truth') 26 | image_root = osp.join(root, dstype, 
name, 'images') 27 | 28 | image_list = sorted(glob(osp.join(image_root, '*.png'))) 29 | for i in range(len(image_list) - 1): 30 | self.image_list += [[image_list[i], image_list[i + 1]]] 31 | self.extra_info += [i] # frame_id 32 | 33 | self.flow_list += sorted( 34 | glob(osp.join(flow_root, '*.flo')))[0:-1] 35 | -------------------------------------------------------------------------------- /opticalflow/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /opticalflow/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import cv2 4 | import numpy as np 5 | from PIL import Image 6 | from torchvision.transforms import ColorJitter 7 | 8 | from opticalflow.api.data_augment import distort_flow, distort_img 9 | 10 | cv2.setNumThreads(0) 11 | cv2.ocl.setUseOpenCL(False) 12 | 13 | 14 | class FlowAugmentor: 15 | 16 | def __init__(self, 17 | crop_size, 18 | min_scale=-0.2, 19 | max_scale=0.5, 20 | do_flip=True, 21 | do_distort=False): 22 | 23 | # spatial augmentation params 24 | self.crop_size = crop_size 25 | self.min_scale = min_scale 26 | self.max_scale = max_scale 27 | self.spatial_aug_prob = 0.8 28 | self.stretch_prob = 0.8 29 | self.max_stretch = 0.2 30 | 31 | # flip augmentation params 32 | self.do_flip = do_flip 33 | self.h_flip_prob = 0.5 34 | self.v_flip_prob = 0.1 35 | 36 | # photometric augmentation params 37 | self.photo_aug = ColorJitter( 38 | brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14) 39 | self.asymmetric_color_aug_prob = 0.2 40 | self.eraser_aug_prob = 0.5 41 | 42 | # distortion augmentation params 43 | self.do_distort = do_distort 44 | self.do_distort_prob = 0.5 45 | self.k2_limit = 1e-6 46 | self.k4_limit = 1e-14 47 | self.k6_limit = 0 48 | 49 | def color_transform(self, img1, img2): 50 | """Photometric augmentation.""" 51 | 52 | # asymmetric 53 | if np.random.rand() < self.asymmetric_color_aug_prob: 54 | img1 = np.array( 55 | self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 56 | img2 = np.array( 57 | self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 58 | 59 | # symmetric 60 | else: 61 | image_stack = np.concatenate([img1, img2], axis=0) 62 | image_stack = np.array( 63 | self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 64 | img1, img2 = np.split(image_stack, 2, axis=0) 65 | 66 | return img1, img2 67 | 68 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 69 | """Occlusion augmentation.""" 70 | 71 | ht, wd = img1.shape[:2] 72 | if np.random.rand() < self.eraser_aug_prob: 73 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 74 | for _ in range(np.random.randint(1, 3)): 75 | x0 = np.random.randint(0, wd) 76 | y0 = np.random.randint(0, ht) 77 | dx = np.random.randint(bounds[0], bounds[1]) 78 | dy = np.random.randint(bounds[0], bounds[1]) 79 | img2[y0:y0 + dy, x0:x0 + dx, :] = mean_color 80 | 81 | return img1, img2 82 | 83 | def spatial_transform(self, img1, img2, flow): 84 | # randomly sample scale 85 | ht, wd = img1.shape[:2] 86 | min_scale = np.maximum((self.crop_size[0] + 8) / float(ht), 87 | (self.crop_size[1] + 8) / float(wd)) 88 | 89 | scale = 2**np.random.uniform(self.min_scale, self.max_scale) 90 | scale_x = scale 91 | scale_y = scale 92 | if np.random.rand() < self.stretch_prob: 93 | scale_x *= 2**np.random.uniform(-self.max_stretch, 94 | self.max_stretch) 95 | scale_y *= 2**np.random.uniform(-self.max_stretch, 96 | 
self.max_stretch) 97 | 98 | scale_x = np.clip(scale_x, min_scale, None) 99 | scale_y = np.clip(scale_y, min_scale, None) 100 | 101 | if np.random.rand() < self.spatial_aug_prob: 102 | # rescale the images 103 | img1 = cv2.resize( 104 | img1, 105 | None, 106 | fx=scale_x, 107 | fy=scale_y, 108 | interpolation=cv2.INTER_LINEAR) 109 | img2 = cv2.resize( 110 | img2, 111 | None, 112 | fx=scale_x, 113 | fy=scale_y, 114 | interpolation=cv2.INTER_LINEAR) 115 | flow = cv2.resize( 116 | flow, 117 | None, 118 | fx=scale_x, 119 | fy=scale_y, 120 | interpolation=cv2.INTER_LINEAR) 121 | flow = flow * [scale_x, scale_y] 122 | 123 | if self.do_flip: 124 | if np.random.rand() < self.h_flip_prob: # h-flip 125 | img1 = img1[:, ::-1] 126 | img2 = img2[:, ::-1] 127 | flow = flow[:, ::-1] * [-1.0, 1.0] 128 | 129 | if np.random.rand() < self.v_flip_prob: # v-flip 130 | img1 = img1[::-1, :] 131 | img2 = img2[::-1, :] 132 | flow = flow[::-1, :] * [1.0, -1.0] 133 | 134 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 135 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 136 | 137 | img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 138 | img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 139 | flow = flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 140 | 141 | return img1, img2, flow 142 | 143 | def distortion_transform(self, img1, img2, flow): 144 | """Optical distortion augmentation.""" 145 | 146 | if self.do_distort: 147 | if np.random.rand() < self.do_distort_prob: 148 | # random k 149 | k2 = random.uniform(-self.k2_limit, self.k2_limit) 150 | k4 = random.uniform(-self.k4_limit, self.k4_limit) 151 | k6 = random.uniform(-self.k6_limit, self.k6_limit) 152 | 153 | # distort img and flow 154 | img1 = distort_img(img1, 'radial', ks=[0, k2, 0, k4, 0, k6]) 155 | img2 = distort_img(img2, 'radial', ks=[0, k2, 0, k4, 0, k6]) 156 | flow = distort_flow(flow, 'radial', ks=[0, k2, 0, k4, 0, k6]) 157 | 158 | return img1, img2, flow 159 | 160 | def __call__(self, img1, img2, flow): 161 | img1, img2 = self.color_transform(img1, img2) 162 | img1, img2 = self.eraser_transform(img1, img2) 163 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 164 | img1, img2, flow = self.distortion_transform(img1, img2, flow) 165 | 166 | img1 = np.ascontiguousarray(img1) 167 | img2 = np.ascontiguousarray(img2) 168 | flow = np.ascontiguousarray(flow) 169 | 170 | return img1, img2, flow 171 | 172 | 173 | class SparseFlowAugmentor: 174 | """used only for KITTI, modified""" 175 | 176 | def __init__(self, 177 | crop_size, 178 | min_scale=-0.2, 179 | max_scale=0.5, 180 | do_flip=False, 181 | do_distort=False): 182 | # spatial augmentation params 183 | self.crop_size = crop_size 184 | self.min_scale = min_scale 185 | self.max_scale = max_scale 186 | self.spatial_aug_prob = 0.8 187 | self.stretch_prob = 0.8 188 | self.max_stretch = 0.2 189 | 190 | # flip augmentation params 191 | self.do_flip = do_flip 192 | self.h_flip_prob = 0.5 193 | self.v_flip_prob = 0.1 194 | 195 | # photometric augmentation params 196 | self.photo_aug = ColorJitter( 197 | brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3 / 3.14) 198 | self.asymmetric_color_aug_prob = 0.2 199 | self.eraser_aug_prob = 0.5 200 | 201 | def color_transform(self, img1, img2): 202 | image_stack = np.concatenate([img1, img2], axis=0) 203 | image_stack = np.array( 204 | self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 205 | img1, img2 = np.split(image_stack, 2, axis=0) 206 | return img1, img2 207 | 208 | def 
eraser_transform(self, img1, img2): 209 | ht, wd = img1.shape[:2] 210 | if np.random.rand() < self.eraser_aug_prob: 211 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 212 | for _ in range(np.random.randint(1, 3)): 213 | x0 = np.random.randint(0, wd) 214 | y0 = np.random.randint(0, ht) 215 | dx = np.random.randint(50, 100) 216 | dy = np.random.randint(50, 100) 217 | img2[y0:y0 + dy, x0:x0 + dx, :] = mean_color 218 | 219 | return img1, img2 220 | 221 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 222 | ht, wd = flow.shape[:2] 223 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 224 | coords = np.stack(coords, axis=-1) 225 | 226 | coords = coords.reshape(-1, 2).astype(np.float32) 227 | flow = flow.reshape(-1, 2).astype(np.float32) 228 | valid = valid.reshape(-1).astype(np.float32) 229 | 230 | coords0 = coords[valid >= 1] 231 | flow0 = flow[valid >= 1] 232 | 233 | ht1 = int(round(ht * fy)) 234 | wd1 = int(round(wd * fx)) 235 | 236 | coords1 = coords0 * [fx, fy] 237 | flow1 = flow0 * [fx, fy] 238 | 239 | xx = np.round(coords1[:, 0]).astype(np.int32) 240 | yy = np.round(coords1[:, 1]).astype(np.int32) 241 | 242 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 243 | xx = xx[v] 244 | yy = yy[v] 245 | flow1 = flow1[v] 246 | 247 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 248 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 249 | 250 | flow_img[yy, xx] = flow1 251 | valid_img[yy, xx] = 1 252 | 253 | return flow_img, valid_img 254 | 255 | def spatial_transform(self, img1, img2, flow, valid): 256 | # randomly sample scale 257 | 258 | ht, wd = img1.shape[:2] 259 | min_scale = np.maximum((self.crop_size[0] + 1) / float(ht), 260 | (self.crop_size[1] + 1) / float(wd)) 261 | 262 | scale = 2**np.random.uniform(self.min_scale, self.max_scale) 263 | scale_x = np.clip(scale, min_scale, None) 264 | scale_y = np.clip(scale, min_scale, None) 265 | 266 | if np.random.rand() < self.spatial_aug_prob: 267 | # rescale the images 268 | img1 = cv2.resize( 269 | img1, 270 | None, 271 | fx=scale_x, 272 | fy=scale_y, 273 | interpolation=cv2.INTER_LINEAR) 274 | img2 = cv2.resize( 275 | img2, 276 | None, 277 | fx=scale_x, 278 | fy=scale_y, 279 | interpolation=cv2.INTER_LINEAR) 280 | flow, valid = self.resize_sparse_flow_map( 281 | flow, valid, fx=scale_x, fy=scale_y) 282 | 283 | if self.do_flip: 284 | if np.random.rand() < 0.5: # h-flip 285 | img1 = img1[:, ::-1] 286 | img2 = img2[:, ::-1] 287 | flow = flow[:, ::-1] * [-1.0, 1.0] 288 | valid = valid[:, ::-1] 289 | 290 | margin_y = 20 291 | margin_x = 50 292 | 293 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 294 | x0 = np.random.randint(-margin_x, 295 | img1.shape[1] - self.crop_size[1] + margin_x) 296 | 297 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 298 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 299 | 300 | img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 301 | img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 302 | flow = flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 303 | valid = valid[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 304 | return img1, img2, flow, valid 305 | 306 | def __call__(self, img1, img2, flow, valid): 307 | img1, img2 = self.color_transform(img1, img2) 308 | img1, img2 = self.eraser_transform(img1, img2) 309 | img1, img2, flow, valid = self.spatial_transform( 310 | img1, img2, flow, valid) 311 | 312 | img1 = np.ascontiguousarray(img1) 313 | img2 = np.ascontiguousarray(img2) 314 | flow = 
np.ascontiguousarray(flow) 315 | valid = np.ascontiguousarray(valid) 316 | 317 | return img1, img2, flow, valid 318 | -------------------------------------------------------------------------------- /opticalflow/utils/flow_utils.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import re 3 | import cv2 4 | import torch 5 | import numpy as np 6 | from PIL import Image 7 | from scipy import interpolate 8 | from torch import from_numpy 9 | 10 | # the header of writeFlow() 11 | TAG_CHAR = np.array([202021.25], np.float32) 12 | 13 | 14 | def make_colorwheel(): 15 | """ 16 | Generates a color wheel for optical flow visualization as presented in: 17 | Baker et al. "A Database and Evaluation Methodology 18 | for Optical Flow" (ICCV, 2007) 19 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 20 | 21 | Code follows the original C++ source code of Daniel Scharstein. 22 | Code follows the the Matlab source code of Deqing Sun. 23 | 24 | Returns: 25 | np.ndarray: Color wheel 26 | """ 27 | 28 | RY = 15 29 | YG = 6 30 | GC = 4 31 | CB = 11 32 | BM = 13 33 | MR = 6 34 | 35 | ncols = RY + YG + GC + CB + BM + MR 36 | colorwheel = np.zeros((ncols, 3)) 37 | col = 0 38 | 39 | # RY 40 | colorwheel[0:RY, 0] = 255 41 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 42 | col = col + RY 43 | # YG 44 | colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 45 | colorwheel[col:col + YG, 1] = 255 46 | col = col + YG 47 | # GC 48 | colorwheel[col:col + GC, 1] = 255 49 | colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 50 | col = col + GC 51 | # CB 52 | colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 53 | colorwheel[col:col + CB, 2] = 255 54 | col = col + CB 55 | # BM 56 | colorwheel[col:col + BM, 2] = 255 57 | colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 58 | col = col + BM 59 | # MR 60 | colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 61 | colorwheel[col:col + MR, 0] = 255 62 | return colorwheel 63 | 64 | 65 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 66 | """Applies the flow color wheel to (possibly clipped) flow components u 67 | andv. 68 | 69 | According to the C++ source code of Daniel Scharstein 70 | According to the Matlab source code of Deqing Sun 71 | 72 | Args: 73 | u (np.ndarray): Input horizontal flow of shape [H,W] 74 | v (np.ndarray): Input vertical flow of shape [H,W] 75 | convert_to_bgr (bool, optional): Convert output image 76 | to BGR. Defaults to False. 
77 | 78 | Returns: 79 | np.ndarray: Flow visualization image of shape [H,W,3] 80 | """ 81 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 82 | colorwheel = make_colorwheel() # shape [55x3] 83 | ncols = colorwheel.shape[0] 84 | rad = np.sqrt(np.square(u) + np.square(v)) 85 | a = np.arctan2(-v, -u) / np.pi 86 | fk = (a + 1) / 2 * (ncols - 1) 87 | k0 = np.floor(fk).astype(np.int32) 88 | k1 = k0 + 1 89 | k1[k1 == ncols] = 0 90 | f = fk - k0 91 | for i in range(colorwheel.shape[1]): 92 | tmp = colorwheel[:, i] 93 | col0 = tmp[k0] / 255.0 94 | col1 = tmp[k1] / 255.0 95 | col = (1 - f) * col0 + f * col1 96 | idx = (rad <= 1) 97 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 98 | col[~idx] = col[~idx] * 0.75 # out of range 99 | # Note the 2-i => BGR instead of RGB 100 | ch_idx = 2 - i if convert_to_bgr else i 101 | flow_image[:, :, ch_idx] = np.floor(255 * col) 102 | return flow_image 103 | 104 | 105 | def better_flow_to_image(flow_uv, 106 | alpha=0.5, 107 | max_flow=724, 108 | clip_flow=None, 109 | convert_to_bgr=False): 110 | """Used for visualize extremely large-distance flow""" 111 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 112 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 113 | if clip_flow is not None: 114 | flow_uv = np.clip(flow_uv, 0, clip_flow) 115 | u = flow_uv[:, :, 0] 116 | v = flow_uv[:, :, 1] 117 | rad = np.sqrt(np.square(u) + np.square(v)) 118 | rad_max = max_flow 119 | param_with_alpha = np.power(rad / max_flow, alpha) 120 | epsilon = 1e-5 121 | u = param_with_alpha * u / (rad_max + epsilon) 122 | v = param_with_alpha * v / (rad_max + epsilon) 123 | return flow_uv_to_colors(u, v, convert_to_bgr) 124 | 125 | 126 | def forward_interpolate(flow): 127 | """Interpolate flow for warm start, from RAFT.""" 128 | flow = flow.detach().cpu().numpy() 129 | dx, dy = flow[0], flow[1] 130 | 131 | ht, wd = dx.shape 132 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 133 | 134 | x1 = x0 + dx 135 | y1 = y0 + dy 136 | 137 | x1 = x1.reshape(-1) 138 | y1 = y1.reshape(-1) 139 | dx = dx.reshape(-1) 140 | dy = dy.reshape(-1) 141 | 142 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 143 | x1 = x1[valid] 144 | y1 = y1[valid] 145 | dx = dx[valid] 146 | dy = dy[valid] 147 | 148 | flow_x = interpolate.griddata((x1, y1), 149 | dx, (x0, y0), 150 | method='nearest', 151 | fill_value=0) 152 | 153 | flow_y = interpolate.griddata((x1, y1), 154 | dy, (x0, y0), 155 | method='nearest', 156 | fill_value=0) 157 | 158 | flow = np.stack([flow_x, flow_y], axis=0) 159 | return from_numpy(flow).float() 160 | 161 | 162 | def readFlow(fn): 163 | """Read .flo file in Middlebury format.""" 164 | # Code adapted from: 165 | # http://stackoverflow.com/questions/28013200/reading-middlebury- \ 166 | # flow-files-with-python-bytes-array-numpy 167 | 168 | # WARNING: this will work on little-endian architectures 169 | # (eg Intel x86) only! 170 | # print 'fn = %s'%(fn) 171 | with open(fn, 'rb') as f: 172 | magic = np.fromfile(f, np.float32, count=1) 173 | if 202021.25 != magic: 174 | print('Magic number incorrect. 
Invalid .flo file') 175 | return None 176 | else: 177 | w = np.fromfile(f, np.int32, count=1) 178 | h = np.fromfile(f, np.int32, count=1) 179 | # print 'Reading %d x %d flo file\n' % (w, h) 180 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) 181 | # Reshape data into 3D array (columns, rows, bands) 182 | # The reshape here is for visualization, the original code 183 | # is (w,h,2) 184 | return np.resize(data, (int(h), int(w), 2)) 185 | 186 | 187 | def writeFlow(filename, uv, v=None): 188 | """Write optical flow to file, from RAFT. 189 | 190 | If v is None, uv is assumed to contain both u and v channels, stacked in 191 | depth. Original code by Deqing Sun, adapted from Daniel Scharstein. 192 | """ 193 | nBands = 2 194 | 195 | if v is None: 196 | assert (uv.ndim == 3) 197 | assert (uv.shape[2] == 2) 198 | u = uv[:, :, 0] 199 | v = uv[:, :, 1] 200 | else: 201 | u = uv 202 | 203 | assert (u.shape == v.shape) 204 | height, width = u.shape 205 | f = open(filename, 'wb') 206 | # write the header 207 | f.write(TAG_CHAR) 208 | np.array(width).astype(np.int32).tofile(f) 209 | np.array(height).astype(np.int32).tofile(f) 210 | # arrange into matrix form 211 | tmp = np.zeros((height, width * nBands)) 212 | tmp[:, np.arange(width) * 2] = u 213 | tmp[:, np.arange(width) * 2 + 1] = v 214 | tmp.astype(np.float32).tofile(f) 215 | f.close() 216 | 217 | 218 | def readPFM(file): 219 | file = open(file, 'rb') 220 | 221 | color = None 222 | width = None 223 | height = None 224 | scale = None 225 | endian = None 226 | 227 | header = file.readline().rstrip() 228 | if header == b'PF': 229 | color = True 230 | elif header == b'Pf': 231 | color = False 232 | else: 233 | raise Exception('Not a PFM file.') 234 | 235 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 236 | if dim_match: 237 | width, height = map(int, dim_match.groups()) 238 | else: 239 | raise Exception('Malformed PFM header.') 240 | 241 | scale = float(file.readline().rstrip()) 242 | if scale < 0: # little-endian 243 | endian = '<' 244 | scale = -scale 245 | else: 246 | endian = '>' # big-endian 247 | 248 | data = np.fromfile(file, endian + 'f') 249 | shape = (height, width, 3) if color else (height, width) 250 | 251 | data = np.reshape(data, shape) 252 | data = np.flipud(data) 253 | return data 254 | 255 | 256 | def read_gen(file_name, pil=False): 257 | ext = osp.splitext(file_name)[-1] 258 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 259 | return Image.open(file_name) 260 | elif ext == '.bin' or ext == '.raw': 261 | return np.load(file_name) 262 | elif ext == '.flo': 263 | return readFlow(file_name).astype(np.float32) 264 | elif ext == '.pfm': 265 | flow = readPFM(file_name).astype(np.float32) 266 | if len(flow.shape) == 2: 267 | return flow 268 | else: 269 | return flow[:, :, :-1] 270 | return [] 271 | 272 | 273 | def readFlowKITTI(filename): 274 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) 275 | flow = flow[:, :, ::-1].astype(np.float32) 276 | flow, valid = flow[:, :, :2], flow[:, :, 2] 277 | flow = (flow - 2**15) / 64.0 278 | return flow, valid 279 | 280 | 281 | def writeFlowKITTI(filename, uv): 282 | uv = 64.0 * uv + 2**15 283 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 284 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 285 | cv2.imwrite(filename, uv[..., ::-1]) 286 | 287 | 288 | def convert_360_gt(flow_gt): 289 | '''Convert gt to 360 flow''' 290 | flow_gt = flow_gt.unsqueeze(dim=0) 291 | flow_gt[:, 0] = torch.where(flow_gt[:, 0] > (flow_gt.shape[3] 
// 2), 292 | flow_gt[:, 0] - flow_gt.shape[3], 293 | flow_gt[:, 0]) 294 | flow_gt[:, 0] = torch.where(flow_gt[:, 0] < -(flow_gt.shape[3] // 2), 295 | flow_gt.shape[3] + flow_gt[:, 0], 296 | flow_gt[:, 0]) 297 | return flow_gt.squeeze() 298 | -------------------------------------------------------------------------------- /opticalflow/utils/logger.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | 3 | 4 | class Logger: 5 | """Logger from RAFT.""" 6 | 7 | def __init__(self, model, scheduler, SUM_FREQ=100, start_step=0): 8 | self.model = model 9 | self.scheduler = scheduler 10 | self.total_steps = start_step 11 | self.running_loss = {} 12 | self.writer = None 13 | self.SUM_FREQ = SUM_FREQ 14 | 15 | def _print_training_status(self): 16 | metrics_data = [ 17 | self.running_loss[k] / self.SUM_FREQ 18 | for k in sorted(self.running_loss.keys()) 19 | ] 20 | training_str = '[{:6d}, {:10.7f}] '.format( 21 | self.total_steps + 1, 22 | self.scheduler.get_last_lr()[0]) 23 | metrics_str = ('{:10.4f}, ' * len(metrics_data)).format(*metrics_data) 24 | 25 | # print the training status 26 | print(training_str + metrics_str) 27 | 28 | if self.writer is None: 29 | self.writer = SummaryWriter() 30 | 31 | for k in self.running_loss: 32 | self.writer.add_scalar(k, self.running_loss[k] / self.SUM_FREQ, 33 | self.total_steps) 34 | self.running_loss[k] = 0.0 35 | 36 | def push(self, metrics): 37 | self.total_steps += 1 38 | 39 | for key in metrics: 40 | if key not in self.running_loss: 41 | self.running_loss[key] = 0.0 42 | 43 | self.running_loss[key] += metrics[key] 44 | 45 | if self.total_steps % self.SUM_FREQ == self.SUM_FREQ - 1: 46 | self._print_training_status() 47 | self.running_loss = {} 48 | 49 | def write_dict(self, results): 50 | if self.writer is None: 51 | self.writer = SummaryWriter() 52 | 53 | for key in results: 54 | self.writer.add_scalar(key, results[key], self.total_steps) 55 | 56 | def close(self): 57 | self.writer.close() 58 | -------------------------------------------------------------------------------- /opticalflow/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from collections import OrderedDict 3 | import numpy as np 4 | 5 | 6 | class InputPadder: 7 | """Pads images such that dimensions are divisible by 8, from RAFT.""" 8 | 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [ 15 | pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, 16 | pad_ht - pad_ht // 2 17 | ] 18 | else: 19 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] 20 | 21 | def pad(self, *inputs): 22 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 23 | 24 | def unpad(self, x): 25 | ht, wd = x.shape[-2:] 26 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] 27 | return x[..., c[0]:c[1], c[2]:c[3]] 28 | 29 | 30 | def fill_order_keys(key, fill_value='_model.', fill_position=7): 31 | """fill order_dict keys in checkpoint, by Hao.""" 32 | return key[0:fill_position] + fill_value + key[fill_position:] 33 | 34 | 35 | def fix_order_keys(key, delete_value=6): 36 | """fix order_dict keys in checkpoint, by Hao.""" 37 | return key[0:delete_value] + key[13:] 38 | 39 | 40 | def fix_read_order_keys(key, start_value=7): 41 | """fix
reading restored ckpt order_dict keys, by Hao.""" 42 | return key[start_value:] 43 | 44 | 45 | # CARLA semantic labels 46 | camvid_colors = OrderedDict([ 47 | ("Unlabeled", np.array([0, 0, 0], dtype=np.uint8)), 48 | ("Building", np.array([70, 70, 70], dtype=np.uint8)), 49 | ("Fence", np.array([100, 40, 40], dtype=np.uint8)), 50 | ("Other", np.array([55, 90, 80], dtype=np.uint8)), 51 | ("Pedestrian", np.array([220, 20, 60], dtype=np.uint8)), 52 | ("Pole", np.array([153, 153, 153], dtype=np.uint8)), 53 | ("RoadLine", np.array([157, 234, 50], dtype=np.uint8)), 54 | ("Road", np.array([128, 64, 128], dtype=np.uint8)), 55 | ("SideWalk", np.array([244, 35, 232], dtype=np.uint8)), 56 | ("Vegetation", np.array([107, 142, 35], dtype=np.uint8)), 57 | ("Vehicles", np.array([0, 0, 142], dtype=np.uint8)), 58 | ("Wall", np.array([102, 102, 156], dtype=np.uint8)), 59 | ("TrafficSign", np.array([220, 220, 0], dtype=np.uint8)), 60 | ("Sky", np.array([70, 130, 180], dtype=np.uint8)), 61 | ("Ground", np.array([81, 0, 81], dtype=np.uint8)), 62 | ("Bridge", np.array([150, 100, 100], dtype=np.uint8)), 63 | ("RailTrack", np.array([230, 150, 140], dtype=np.uint8)), 64 | ("GroundRail", np.array([180, 165, 180], dtype=np.uint8)), 65 | ("TrafficLight", np.array([250, 170, 30], dtype=np.uint8)), 66 | ("Static", np.array([110, 190, 160], dtype=np.uint8)), 67 | ("Dynamic", np.array([170, 120, 50], dtype=np.uint8)), 68 | ("Water", np.array([45, 60, 150], dtype=np.uint8)), 69 | ("Terrain", np.array([145, 170, 100], dtype=np.uint8)), 70 | ]) 71 | 72 | 73 | def convert_label_to_grayscale(im): 74 | out = (np.ones(im.shape[:2]) * 255).astype(np.uint8) 75 | for gray_val, (label, rgb) in enumerate(camvid_colors.items()): 76 | match_pxls = np.where((im == np.asarray(rgb)).sum(-1) == 3) 77 | out[match_pxls] = gray_val 78 | assert (out != 79 | 255).all(), "rounding errors or missing classes in camvid_colors" 80 | return out.astype(np.uint8) 81 | 82 | 83 | def convert_label_to_rgb(im): 84 | out = np.zeros((im.shape[0], im.shape[1], 3)).astype(np.uint8) 85 | for gray_val, (label, rgb) in enumerate(camvid_colors.items()): 86 | match_x, match_y = np.where(im == gray_val) 87 | out[match_x, match_y] = rgb 88 | return out.astype(np.uint8) 89 | -------------------------------------------------------------------------------- /results/Flow360.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MasterHow/PanoFlow/4cae20dd58dc14bad074a26a787a3c474ad62efb/results/Flow360.png -------------------------------------------------------------------------------- /results/compare_.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MasterHow/PanoFlow/4cae20dd58dc14bad074a26a787a3c474ad62efb/results/compare_.png -------------------------------------------------------------------------------- /results/compare_quant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MasterHow/PanoFlow/4cae20dd58dc14bad074a26a787a3c474ad62efb/results/compare_quant.png -------------------------------------------------------------------------------- /results/panoshow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MasterHow/PanoFlow/4cae20dd58dc14bad074a26a787a3c474ad62efb/results/panoshow.png -------------------------------------------------------------------------------- /setup.cfg: 
-------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal=1 3 | 4 | [aliases] 5 | test=pytest 6 | 7 | [yapf] 8 | based_on_style = pep8 9 | blank_line_before_nested_class_or_def = true 10 | split_before_expression_after_opening_paren = true 11 | 12 | [isort] 13 | line_length = 79 14 | multi_line_output = 0 15 | known_standard_library = setuptools 16 | known_first_party = opticalflow 17 | known_third_party = opencv2 18 | no_lines_before = STDLIB,LOCALFOLDER 19 | default_section = THIRDPARTY 20 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | if __name__ == '__main__': 4 | setup(name='opticalflow', version=0.1) 5 | -------------------------------------------------------------------------------- /tools/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from opticalflow.api import init_model 4 | from opticalflow.api.evaluate import validate_flow360, validate_flow360_cfe, validate_omni_cfe, validate_omni 5 | 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser(description='Test and evaluate the model') 9 | parser.add_argument( 10 | '--model', 11 | help='The model used for inference', 12 | default='CSFlow', 13 | choices=['CSFlow', 'RAFT', 'PanoFlow(CSFlow)', 'PanoFlow(RAFT)']) 14 | parser.add_argument( 15 | '--CFE', 16 | help='inference under the CFE framework, details in the paper', 17 | action='store_true') 18 | parser.add_argument( 19 | '--restore_ckpt', 20 | help="Path of the checkpoint to restore, or None", 21 | default='./checkpoints/raft-things.pth') 22 | parser.add_argument( 23 | '--iters', 24 | type=int, 25 | help='Iterations of GRU unit when training', 26 | default=20) 27 | parser.add_argument( 28 | '--eval_iters', 29 | type=int, 30 | help='Iterations of GRU unit when eval', 31 | default=12) 32 | parser.add_argument( 33 | '--train', help='True or False', default=True, choices=[True, False]) 34 | parser.add_argument( 35 | '--eval', 36 | default=True, 37 | help='Whether to run evaluation or a test demo', 38 | choices=[True, False]) 39 | parser.add_argument( 40 | '--dataset', 41 | help='The data used to train', 42 | default='Things') 43 | parser.add_argument( 44 | '--val_Flow360_root', 45 | help='Root of the Flow360 validation dataset') 46 | parser.add_argument( 47 | '--val_Omni_root', 48 | help='Root of the Omni validation dataset') 49 | parser.add_argument( 50 | '--validation', 51 | type=str, 52 | nargs='+', 53 | default=[], 54 | help='The dataset used to validate RAFT') 55 | parser.add_argument( 56 | '--cvt_gt', 57 | help='convert gt to 360 flow', 58 | default=True) 59 | parser.add_argument( 60 | '--change_gpu', 61 | help='train on a cuda device other than cuda:0', 62 | action='store_true') 63 | parser.add_argument('--gpus', type=int, nargs='+', default=[0]) 64 | parser.add_argument('--DEVICE', help='The device to use', default='cuda') 65 | args = parser.parse_args() 66 | return args 67 | 68 | 69 | def main(): 70 | args = parse_args() 71 | 72 | model = init_model(args) 73 | 74 | results = {} 75 | if args.CFE: 76 | for val_dataset in args.validation: 77 | if val_dataset == 'Flow360': 78 | if args.change_gpu: 79 | results.update( 80 | validate_flow360_cfe(model, args.val_Flow360_root, cvt_gt=args.cvt_gt, gpus=args.gpus)) 81 | else: 82 | results.update( 83 | validate_flow360_cfe(model.module, args.val_Flow360_root, cvt_gt=args.cvt_gt)) 84 | if
val_dataset == 'Omni': 85 | if args.change_gpu: 86 | results.update( 87 | validate_omni_cfe(model, args.val_Omni_root, cvt_gt=args.cvt_gt, gpus=args.gpus)) 88 | else: 89 | results.update( 90 | validate_omni_cfe(model.module, args.val_Omni_root, cvt_gt=args.cvt_gt)) 91 | else: 92 | for val_dataset in args.validation: 93 | if val_dataset == 'Flow360': 94 | if args.change_gpu: 95 | results.update( 96 | validate_flow360(model, args.val_Flow360_root, args.gpus)) 97 | else: 98 | results.update( 99 | validate_flow360(model.module, args.val_Flow360_root)) 100 | elif val_dataset == 'Omni': 101 | if args.change_gpu: 102 | results.update( 103 | validate_omni(model, args.val_Omni_root, args.gpus)) 104 | else: 105 | results.update( 106 | validate_omni(model.module, args.val_Omni_root)) 107 | 108 | 109 | if __name__ == '__main__': 110 | main() 111 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import DataLoader 6 | 7 | from opticalflow.api import init_model, manage_data 8 | from opticalflow.api.evaluate import (sequence_loss, validate_chairs, validate_flow360) 9 | from opticalflow.utils.logger import Logger 10 | 11 | # sum events to one message 12 | SUM_FREQ = 100 13 | 14 | # validation frequency 15 | VAL_FREQ = 10000 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser( 20 | description='Train and evaluate the model') 21 | parser.add_argument( 22 | '--train', help='True or False', default=True, choices=[True, False]) 23 | parser.add_argument( 24 | '--model', 25 | help='The model used for training and inference', 26 | default='PanoFlow(CSFlow)', 27 | choices=['CSFlow', 'RAFT', 'PanoFlow(CSFlow)', 'PanoFlow(RAFT)']) 28 | parser.add_argument( 29 | '--name', help='name your experiment', default='PanoFlow(CSFlow)-test') 30 | parser.add_argument( 31 | '--restore_ckpt', 32 | help='Path of the checkpoint to restore, or None', 33 | default=None) 34 | parser.add_argument( 35 | '--start_step', 36 | type=int, 37 | help='Start of train steps, used for RESUME', 38 | default=0) 39 | parser.add_argument( 40 | '--num_steps', 41 | type=int, 42 | help='Total number of train steps', 43 | default=100) 44 | parser.add_argument('--batch_size', default=10) 45 | parser.add_argument( 46 | '--dataset', 47 | help='The data used to train', 48 | default='Chairs', 49 | choices=['Chairs', 'Things', 'Flow360']) 50 | parser.add_argument( 51 | '--image_size', 52 | type=int, 53 | nargs='+', 54 | help='Cropped image size used for training', 55 | default=[384, 512]) 56 | parser.add_argument( 57 | '--data_root', 58 | help='Root of the current training datasets') 59 | parser.add_argument( 60 | '--validation', 61 | type=str, 62 | nargs='+', 63 | help='The dataset used to validate') 64 | parser.add_argument( 65 | '--train_Flow360_root', 66 | help='Root of the Flow360 datasets') 67 | parser.add_argument( 68 | '--val_Chairs_root', 69 | help='Root of the Chairs validation datasets') 70 | parser.add_argument( 71 | '--val_Flow360_root', 72 | help='Root of the Flow360 validation dataset') 73 | parser.add_argument('--DEVICE', help='The device to use', default='cuda') 74 | parser.add_argument( 75 | '--lr', type=float, help='Learning rate', default=0.00002) 76 | parser.add_argument( 77 | '--wdecay', 78 | type=float, 79 | help='Weight decay of the optimizer', 80 | default=.00005) 81 | parser.add_argument( 82 | '--gamma', 83 |
type=float, 84 | help='Exponential weighting of the loss', 85 | default=0.8) 86 | parser.add_argument( 87 | '--change_gpu', 88 | help='train on a cuda device other than cuda:0', 89 | action='store_true') 90 | parser.add_argument('--gpus', type=int, nargs='+', default=[0]) 91 | parser.add_argument( 92 | '--iters', 93 | type=int, 94 | help='Iterations of GRU unit, 12 for training', 95 | default=12) 96 | parser.add_argument( 97 | '--eval_iters', 98 | type=int, 99 | help='Iterations of GRU unit when eval', 100 | default=24) 101 | parser.add_argument('--epsilon', type=float, default=1e-8) 102 | parser.add_argument('--clip', type=float, default=1.0) 103 | parser.add_argument('--dropout', type=float, default=0.0) 104 | parser.add_argument('--add_noise', action='store_true') 105 | args = parser.parse_args() 106 | return args 107 | 108 | 109 | def fetch_optimizer(args, model, data_length): 110 | """Create the optimizer and learning rate scheduler, from RAFT.""" 111 | optimizer = torch.optim.AdamW( 112 | model.parameters(), 113 | lr=args.lr, 114 | weight_decay=args.wdecay, 115 | eps=args.epsilon) 116 | 117 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 118 | optimizer, 119 | args.lr, 120 | args.num_steps + 100, 121 | pct_start=0.05, 122 | cycle_momentum=False, 123 | anneal_strategy='linear') 124 | 125 | return optimizer, scheduler 126 | 127 | 128 | def main(): 129 | args = parse_args() 130 | 131 | # random seed 132 | if args.model == 'CSFlow': 133 | torch.manual_seed(1234) 134 | np.random.seed(1234) 135 | 136 | batch_size = int(args.batch_size) 137 | 138 | # Prepare dataloader 139 | training_data = manage_data.fetch_training_data(args) 140 | dataloader = DataLoader( 141 | training_data, 142 | batch_size=batch_size, 143 | pin_memory=False, 144 | shuffle=True, 145 | num_workers=4, 146 | drop_last=True) 147 | 148 | # Init model 149 | model = init_model(args) 150 | 151 | # Wrap optimizer 152 | optimizer, scheduler = fetch_optimizer(args, model, len(training_data)) 153 | scaler = torch.cuda.amp.GradScaler() 154 | logger = Logger(model, scheduler, SUM_FREQ, args.start_step) 155 | 156 | # Wrap training loop 157 | 158 | total_steps = 0 159 | if args.start_step > 0: 160 | # For resuming training: fast-forward the LR scheduler to start_step 161 | should_keep_training = True 162 | while should_keep_training: 163 | scheduler.step() 164 | total_steps += 1 165 | 166 | if total_steps > args.start_step: 167 | break 168 | 169 | should_keep_training = True 170 | while should_keep_training: 171 | 172 | for i_batch, data_blob in enumerate(dataloader): 173 | optimizer.zero_grad() 174 | image1, image2, flow, valid = [ 175 | x.cuda(args.gpus[0]) for x in data_blob 176 | ] 177 | 178 | if args.add_noise: 179 | stdv = np.random.uniform(0.0, 5.0) 180 | image1 = (image1 + 181 | stdv * torch.randn(*image1.shape).cuda()).clamp( 182 | 0.0, 255.0) 183 | image2 = (image2 + 184 | stdv * torch.randn(*image2.shape).cuda()).clamp( 185 | 0.0, 255.0) 186 | 187 | # zip image 188 | image_pair = torch.stack((image1, image2)) 189 | flow_predictions = model(image_pair) 190 | 191 | loss, metrics = sequence_loss(flow_predictions, flow, valid, 192 | args.gamma) 193 | scaler.scale(loss).backward() 194 | scaler.unscale_(optimizer) 195 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 196 | 197 | scaler.step(optimizer) 198 | scheduler.step() 199 | scaler.update() 200 | 201 | logger.push(metrics) 202 | 203 | if total_steps % VAL_FREQ == VAL_FREQ - 1: 204 | PATH = './checkpoints/%d_%s.pth' % (total_steps + 1, 205 | args.name) 206 | torch.save(model.state_dict(), PATH) 207 | 208 | results =
{} 209 | for val_dataset in args.validation: 210 | if val_dataset == 'Chairs': 211 | if args.change_gpu: 212 | results.update( 213 | validate_chairs(model, 214 | args.val_Chairs_root, 215 | args.gpus)) 216 | else: 217 | results.update( 218 | validate_chairs(model.module, 219 | args.val_Chairs_root)) 220 | elif val_dataset == 'Flow360': 221 | if args.change_gpu: 222 | results.update( 223 | validate_flow360(model, args.val_Flow360_root, args.gpus)) 224 | else: 225 | results.update( 226 | validate_flow360(model.module, args.val_Flow360_root)) 227 | 228 | logger.write_dict(results) 229 | 230 | model.train() 231 | if args.dataset != 'Chairs': # freeze BatchNorm after the Chairs stage, following RAFT 232 | try: 233 | model.module.freeze_bn() 234 | except Exception: 235 | try: 236 | model.freeze_bn() 237 | except Exception: 238 | for m in model.modules(): 239 | if isinstance(m, torch.nn.BatchNorm2d): 240 | m.eval() 241 | 242 | total_steps += 1 243 | 244 | if total_steps > args.num_steps: 245 | should_keep_training = False 246 | break 247 | 248 | logger.close() 249 | PATH = './checkpoints/%s.pth' % args.name 250 | torch.save(model.state_dict(), PATH) 251 | 252 | print('Looks nice! Wish you good luck! =)') 253 | 254 | 255 | if __name__ == '__main__': 256 | main() 257 | --------------------------------------------------------------------------------