├── .gitignore ├── LICENSE ├── README.md ├── argparser.py ├── assets └── viz.png ├── core ├── DCEIFlow.py ├── RAFT.py ├── __init__.py ├── backbone │ └── raft_encoder.py ├── corr │ ├── __init__.py │ └── raft_corr.py ├── decoder │ ├── __init__.py │ ├── raft_updater.py │ └── with_event_updater.py ├── loss │ ├── Combine.py │ ├── L1Loss.py │ └── __init__.py └── metric │ ├── Combine.py │ ├── EPE.py │ └── __init__.py ├── evaluate.py ├── logger.py ├── main.py ├── trainer.py └── utils ├── __init__.py ├── augmentor ├── __init__.py ├── event_augmentor.py └── image_augmentor.py ├── datasets ├── FlyingChairs2.py ├── MVSEC.py ├── MVSEC_utils.py └── __init__.py ├── event_uitls.py ├── file_io.py ├── sample_utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints 2 | results 3 | logs* 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | env/ 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # dotenv 87 | .env 88 | 89 | # virtualenv 90 | .venv 91 | venv/ 92 | ENV/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | # input data, saved log, checkpoints 108 | data/ 109 | input/ 110 | saved/ 111 | 112 | # editor, os cache directory 113 | .vscode/ 114 | .idea/ 115 | __MACOSX/ 116 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Zhexiong Wan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DCEIFlow 2 | This repository contains the source code for our paper: 3 | 4 | Learning Dense and Continuous Optical Flow from an Event Camera 5 | 6 | IEEE Transactions on Image Processing (TIP 2022) 7 | 8 | Zhexiong Wan, Yuchao Dai, Yuxin Mao 9 | 10 | [Project Page](https://npucvr.github.io/DCEIFlow/), [arXiv](https://arxiv.org/abs/2211.09078), [IEEE](https://ieeexplore.ieee.org/document/9950520), [Supp](https://npucvr.github.io/DCEIFlow/Supp_Final_compressed.pdf). 11 | 12 | 13 | 14 | If you have any questions, please do not hesitate to open an issue or contact me at wanzhexiong@mail.nwpu.edu.cn 15 | 16 | ## Requirements 17 | The code has been tested with PyTorch 1.12.1 and CUDA 11.7. 18 | 19 | ## Pretrained Weights 20 | 21 | Pretrained weights can be downloaded from 22 | [Google Drive](https://drive.google.com/drive/folders/1Dh7BqXozY59SZKOgVj7_yZ5d09R8qilw?usp=share_link). 23 | 24 | Please put them into the `checkpoint` folder. 25 | 26 | ## Evaluation 27 | 28 | To evaluate our model, you first need to download the HDF5 version of the [MVSEC](https://daniilidis-group.github.io/mvsec/download/) dataset. 29 | 30 | ``` 31 | data/MVSEC_HDF5 32 | ├── indoor_flying 33 | │   ├── indoor_flying1_data.hdf5 34 | │   ├── indoor_flying1_gt.hdf5 35 | │   ├── indoor_flying2_data.hdf5 36 | │   ├── indoor_flying2_gt.hdf5 37 | │   ├── indoor_flying3_data.hdf5 38 | │   ├── indoor_flying3_gt.hdf5 39 | ├── outdoor_day 40 | │   ├── outdoor_day1_data.hdf5 41 | │   ├── outdoor_day1_gt.hdf5 42 | │   ├── outdoor_day2_data.hdf5 43 | │   ├── outdoor_day2_gt.hdf5 44 | ``` 45 | 46 | After the environment is configured and the pretrained weights are downloaded, run the following command to reproduce the results reported in the paper. 47 | 48 | ``` 49 | python main.py --task test --stage mvsec --checkpoint ./checkpoint/DCEIFlow_paper.pth 50 | ``` 51 | 52 | The results reported in our paper use events simulated with [esim_py](https://github.com/uzh-rpg/rpg_vid2e). We also provide another pretrained model trained on events from a newer simulator, [DVS-Voltmeter](https://github.com/Lynn0306/DVS-Voltmeter); thanks to this open source project. It improves performance on Chairs and Sintel, while results on MVSEC are essentially unchanged. **We recommend using the updated model because of its better generalization performance.** 53 | 54 | ``` 55 | python main.py --task test --stage mvsec --checkpoint ./checkpoint/DCEIFlow.pth 56 | ``` 57 | 58 | 59 | ## Training 60 | 61 | To train our model, you need to download the [FlyingChairs2](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs2) dataset and simulate the events corresponding to each pair of frames. 62 | ``` 63 | data/FlyingChairs2 64 | ├── train 65 | ├── val 66 | ├── events_train 67 | ├── events_val 68 | ``` 69 | 70 | The simulated events are stored in HDF5 format with names of the form ```******-event.hdf5```. 
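As a rough illustration, such a file can be read with `h5py`. This is only a hedged sketch: the dataset keys (`x`, `y`, `t`, `p`) below are assumptions for illustration, not the repository's confirmed layout.

```
import h5py
import numpy as np

def load_simulated_events(path):
    # Hypothetical sketch: the dataset names 'x', 'y', 't', 'p' are assumptions.
    # See read_event_h5() in utils/file_io.py for the actual on-disk layout.
    with h5py.File(path, "r") as f:
        xs = np.asarray(f["x"])  # event x coordinates (pixels)
        ys = np.asarray(f["y"])  # event y coordinates (pixels)
        ts = np.asarray(f["t"])  # event timestamps
        ps = np.asarray(f["p"])  # event polarities
    return xs, ys, ts, ps
```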
Please refer to the ``read_event_h5()`` function in ``utils/file_io.py`` for the actual format. 71 | 72 | After completing the simulation, you can run the following command to start training. 73 | 74 | ``` 75 | CUDA_VISIBLE_DEVICES="0," python main.py --task train --stage chairs2 --isbi --model DCEIFlow --batch_size 4 --epoch 200 --lr 0.0004 --weight_decay 0.0001 --loss_gamma=0.80 --name DCEIFlow 76 | ``` 77 | 78 | ## Citation 79 | If our work or code helps you, please cite our paper. 80 | 81 | **If our code is useful for your new research, we hope you will also open source your code, including the training part.** 82 | 83 | ``` 84 | @article{wan2022DCEIFlow, 85 | author={Wan, Zhexiong and Dai, Yuchao and Mao, Yuxin}, 86 | title={Learning Dense and Continuous Optical Flow From an Event Camera}, 87 | journal={IEEE Transactions on Image Processing}, 88 | year={2022}, 89 | volume={31}, 90 | pages={7237-7251}, 91 | doi={10.1109/TIP.2022.3220938} 92 | } 93 | ``` 94 | 95 | ## Acknowledgments 96 | 97 | This research was sponsored by Zhejiang Lab. 98 | 99 | Thanks to the associate editor and the reviewers for their comments, which were very helpful in improving our paper. 100 | 101 | Thanks also to the following helpful open source projects: 102 | 103 | [RAFT](https://github.com/princeton-vl/RAFT), 104 | [event_utils](https://github.com/TimoStoff/event_utils), 105 | [EV-FlowNet](https://github.com/daniilidis-group/EV-FlowNet), 106 | [mvsec_eval](https://github.com/TimoStoff/mvsec_eval), 107 | [esim_py](https://github.com/uzh-rpg/rpg_vid2e), 108 | [DVS-Voltmeter](https://github.com/Lynn0306/DVS-Voltmeter), 109 | [Spike-FlowNet](https://github.com/chan8972/Spike-FlowNet). 110 | -------------------------------------------------------------------------------- /argparser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import datetime 4 | 5 | class ArgParser: 6 | def __init__(self): 7 | self.args = None 8 | self.parse = argparse.ArgumentParser() 9 | self.parse.add_argument("--batch_size", type=int, default=1, help="") 10 | self.parse.add_argument("--test_batch_size", type=int, default=-1, help="") 11 | self.parse.add_argument("--lr", type=float, default=0.0001, help="") 12 | self.parse.add_argument("--weight_decay", type=float, default=0.0001, help="") 13 | self.parse.add_argument("--epsilon", type=float, default=1e-8, help="") 14 | self.parse.add_argument("--clip", type=float, default=1.0, help="") 15 | 16 | self.parse.add_argument("--distributed", type=str, default="ddp", help="ddp (distributed data parallel) or dp (data parallel)") 17 | self.parse.add_argument("--ip", type=str, default="127.0.0.1", help="ddp ip") 18 | self.parse.add_argument("--port", type=str, default="23130", help="ddp port") 19 | 20 | self.parse.add_argument('--gpus', type=int, nargs='+', default=[-1]) 21 | self.parse.add_argument('--jobs', type=int, default=4) 22 | self.parse.add_argument('--mixed_precision', action='store_true', default=False) 23 | self.parse.add_argument("--resume", action='store_true', help="") 24 | self.parse.add_argument("--checkpoint", type=str, default="", help="") 25 | 26 | self.parse.add_argument("--model", type=str, default="DCEIFlow", help="") 27 | self.parse.add_argument('--iters', type=int, default=6, help="number of iterative flow refinement updates") 28 | self.parse.add_argument("--backbone", type=str, default="BasicEncoder", help="") 29 | self.parse.add_argument("--corr", type=str, default="Corr", help="") 30 | self.parse.add_argument("--decoder", type=str, default="Updater", 
help="") 31 | self.parse.add_argument('--small', action='store_true', default=False, help='use small model') 32 | self.parse.add_argument('--warm_start', action='store_true', default=False, help='use warm start in evaluate stage') 33 | 34 | self.parse.add_argument('--event_bins', type=int, default=5, \ 35 | help='number of bins in the voxel grid event tensor') 36 | self.parse.add_argument('--no_event_polarity', dest='no_event_polarity', action='store_true', \ 37 | default=False, help='Don not divide event voxel by polarity') 38 | 39 | self.parse.add_argument("--loss", type=str, nargs='+', default=["L1Loss"], help="") 40 | self.parse.add_argument("--loss_gamma", type=float, default=0.8, help="") 41 | self.parse.add_argument("--loss_weights", type=float, nargs='+', default=[1.0], help="") 42 | 43 | self.parse.add_argument("--name", type=str, default="", help="") 44 | self.parse.add_argument("--task", type=str, default="train", help="") 45 | self.parse.add_argument("--stage", type=str, default="chairs2", help="") 46 | self.parse.add_argument("--metric", type=str, nargs='+', default=["epe"], help="") 47 | 48 | self.parse.add_argument('--isbi', action='store_true', default=False, help='bidirection flow training') 49 | self.parse.add_argument("--seed", type=int, default=20, help="") 50 | self.parse.add_argument('--not_save_board', action='store_true', default=False) 51 | self.parse.add_argument('--not_save_log', action='store_true', default=False) 52 | self.parse.add_argument("--log_path", type=str, default="logs", help="") 53 | 54 | self.parse.add_argument("--epoch", type=int, default=200, help="") 55 | self.parse.add_argument("--eval_feq", type=int, default=5, help="every eval_feq for epoch") 56 | self.parse.add_argument("--save_feq", type=int, default=5, help="every save_feq for epoch") 57 | self.parse.add_argument("--save_path", type=str, default="logs", help="") 58 | self.parse.add_argument("--debug", action='store_true', default=False) 59 | 60 | self.parse.add_argument("--crop_size", type=int, nargs='+', default=[368, 496]) 61 | self.parse.add_argument("--pad", type=int, default=8, help="") 62 | 63 | self.parse.add_argument('--skip_num', type=int, default=1, help='skip images in dataset to get more events') 64 | self.parse.add_argument('--skip_mode', type=str, default='i', \ 65 | help='skip images mode in dataset to get more events i(interrupt)/c(continue)') 66 | 67 | def _print(self): 68 | print(">>> ======= Options ==========") 69 | for k, v in vars(self.args).items(): 70 | print(k, '=', v) 71 | print("<<< ======= Options ==========") 72 | 73 | def logger_debug(self, logger): 74 | for k, v in vars(self.args).items(): 75 | logger.log_debug('{}\t=\t{}'.format(k, v), "argparser") 76 | 77 | def parser(self): 78 | self.args = self.parse.parse_args() 79 | 80 | if self.args.test_batch_size == -1: 81 | self.args.test_batch_size = self.args.batch_size 82 | 83 | if self.args.name == '' or self.args.task[:4] == 'test' or self.args.task == 'submit': 84 | self.args.not_save_board = True 85 | self.args.not_save_log = True 86 | if self.args.task != 'test': 87 | print("not_save_board and not_save_log are set to True") 88 | 89 | time = datetime.datetime.now() 90 | self.args.log_name = "{:0>2d}{:0>2d}{:0>2d}_{:0>2d}{:0>2d}{:0>2d}_{}".format(time.year % 100, 91 | time.month, time.day, time.hour, time.minute, time.second, self.args.name) if self.args.name != "" else "" 92 | self.args.log_path = os.path.join(self.args.log_path, self.args.log_name) 93 | self.args.save_path = os.path.join(self.args.save_path, 
self.args.log_name) 94 | 95 | if self.args.task != 'test': 96 | self._print() 97 | 98 | return self.args 99 | -------------------------------------------------------------------------------- /assets/viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danqu130/DCEIFlow/f1c24c1199aa033cc09ab2979acaeecc0cf3a3f7/assets/viz.png -------------------------------------------------------------------------------- /core/DCEIFlow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import sys 6 | sys.path.append('.') 7 | sys.path.append('core') 8 | 9 | from core.decoder.with_event_updater import BasicUpdateBlockNoMask, SmallUpdateBlock 10 | from core.backbone.raft_encoder import BasicEncoder, SmallEncoder 11 | from core.corr.raft_corr import CorrBlock, AlternateCorrBlock 12 | from utils.sample_utils import coords_grid, upflow8 13 | 14 | try: 15 | autocast = torch.cuda.amp.autocast 16 | except: 17 | # dummy autocast for PyTorch < 1.6 18 | class autocast: 19 | def __init__(self, enabled): 20 | pass 21 | def __enter__(self): 22 | pass 23 | def __exit__(self, *args): 24 | pass 25 | 26 | 27 | class EIFusion(nn.Module): 28 | def __init__(self, input_dim=256): 29 | super().__init__() 30 | self.conv1 = nn.Conv2d(input_dim, 192, 1, padding=0) 31 | self.conv2 = nn.Conv2d(input_dim, 192, 1, padding=0) 32 | self.convo = nn.Conv2d(192*2, input_dim, 3, padding=1) 33 | 34 | def forward(self, x1, x2): 35 | c1 = F.relu(self.conv1(x1)) 36 | c2 = F.relu(self.conv2(x2)) 37 | out = torch.cat([c1, c2], dim=1) 38 | out = F.relu(self.convo(out)) 39 | return out + x1 40 | 41 | 42 | class DCEIFlow(nn.Module): 43 | def __init__(self, args): 44 | super().__init__() 45 | self.args = args 46 | self.small = False 47 | self.dropout = 0 48 | self.alternate_corr = False 49 | 50 | self.isbi = args.isbi 51 | self.event_bins = args.event_bins if args.no_event_polarity is True else 2 * args.event_bins 52 | 53 | if self.small: 54 | self.hidden_dim = hdim = 96 55 | self.context_dim = cdim = 64 56 | args.corr_levels = 4 57 | args.corr_radius = 3 58 | else: 59 | self.hidden_dim = hdim = 128 60 | self.context_dim = cdim = 128 61 | args.corr_levels = 4 62 | args.corr_radius = 4 63 | 64 | # feature network, context network, and update block 65 | if self.small: 66 | self.fnet = SmallEncoder(input_dim=3, output_dim=128, norm_fn='instance', dropout=self.dropout) 67 | self.cnet = SmallEncoder(input_dim=3, output_dim=hdim+cdim, norm_fn='none', dropout=self.dropout) 68 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 69 | self.fusion = EIFusion(input_dim=128) 70 | self.enet = SmallEncoder(input_dim=self.event_bins, output_dim=128, norm_fn='instance', dropout=self.dropout) 71 | else: 72 | self.fnet = BasicEncoder(input_dim=3, output_dim=256, norm_fn='instance', dropout=self.dropout) 73 | self.cnet = BasicEncoder(input_dim=3, output_dim=hdim+cdim, norm_fn='batch', dropout=self.dropout) 74 | self.update_block = BasicUpdateBlockNoMask(self.args, hidden_dim=hdim) 75 | self.fusion = EIFusion(input_dim=256) 76 | self.enet = BasicEncoder(input_dim=self.event_bins, output_dim=256, norm_fn='instance', dropout=self.dropout) 77 | 78 | def freeze_bn(self): 79 | for m in self.modules(): 80 | if isinstance(m, nn.BatchNorm2d): 81 | m.eval() 82 | 83 | def initialize_flow(self, img): 84 | """ Flow is represented as difference between two 
coordinate grids flow = coords1 - coords0""" 85 | N, C, H, W = img.shape 86 | coords0 = coords_grid(N, H//8, W//8).to(img.device) 87 | coords1 = coords_grid(N, H//8, W//8).to(img.device) 88 | 89 | # optical flow computed as difference: flow = coords1 - coords0 90 | return coords0, coords1 91 | 92 | def upsample_flow(self, flow, mask): 93 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 94 | N, _, H, W = flow.shape 95 | mask = mask.view(N, 1, 9, 8, 8, H, W) 96 | mask = torch.softmax(mask, dim=2) 97 | 98 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 99 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 100 | 101 | up_flow = torch.sum(mask * up_flow, dim=2) 102 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 103 | return up_flow.reshape(N, 2, 8*H, 8*W) 104 | 105 | def forward(self, batch, iters=12, flow_init=None, upsample=True): 106 | """ Estimate optical flow between pair of frames """ 107 | 108 | image1 = batch['image1'] 109 | image1 = 2 * (image1 / 255.0) - 1.0 110 | image1 = image1.contiguous() 111 | 112 | image2 = None 113 | if self.training or self.isbi: 114 | assert 'image2' in batch.keys() 115 | image2 = batch['image2'] 116 | image2 = 2 * (image2 / 255.0) - 1.0 117 | image2 = image2.contiguous() 118 | 119 | event_voxel = batch['event_voxel'] 120 | event_voxel = 2 * event_voxel - 1.0 121 | event_voxel = event_voxel.contiguous() 122 | 123 | hdim = self.hidden_dim 124 | cdim = self.context_dim 125 | 126 | # run the feature network 127 | reversed_emap = None 128 | with autocast(enabled=self.args.mixed_precision): 129 | emap = self.enet(event_voxel) 130 | if self.isbi and 'reversed_event_voxel' in batch.keys(): 131 | assert image2 is not None 132 | fmap1, fmap2 = self.fnet([image1, image2]) 133 | reversed_event_voxel = batch['reversed_event_voxel'] 134 | reversed_event_voxel = 2 * reversed_event_voxel - 1.0 135 | reversed_event_voxel = reversed_event_voxel.contiguous() 136 | reversed_emap = self.enet(reversed_event_voxel) 137 | else: 138 | reversed_emap = None 139 | if image2 is None: 140 | fmap1 = self.fnet(image1) 141 | fmap2 = None 142 | else: 143 | fmap1, fmap2 = self.fnet([image1, image2]) 144 | 145 | fmap1 = fmap1.float() 146 | emap = emap.float() 147 | if fmap2 is not None: 148 | fmap2 = fmap2.float() 149 | 150 | with autocast(enabled=self.args.mixed_precision): 151 | pseudo_fmap2 = self.fusion(fmap1, emap) 152 | 153 | corr_fn = CorrBlock(fmap1, pseudo_fmap2, radius=self.args.corr_radius) 154 | 155 | # run the context network 156 | with autocast(enabled=self.args.mixed_precision): 157 | cnet = self.cnet(image1) 158 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 159 | net = torch.tanh(net) 160 | inp = torch.relu(inp) 161 | 162 | coords0, coords1 = self.initialize_flow(image1) 163 | 164 | if flow_init is not None: 165 | coords1 = coords1 + flow_init 166 | 167 | flow_predictions = [] 168 | flow_predictions_bw = [] 169 | flow_up = None 170 | flow_up_bw = None 171 | pseudo_fmap1 = None 172 | 173 | for itr in range(iters): 174 | coords1 = coords1.detach() 175 | corr = corr_fn(coords1) # index correlation volume 176 | 177 | flow = coords1 - coords0 178 | with autocast(enabled=self.args.mixed_precision): 179 | net, up_mask, delta_flow = self.update_block(net, inp, corr, emap, flow) 180 | 181 | # F(t+1) = F(t) + \Delta(t) 182 | coords1 = coords1 + delta_flow 183 | 184 | # upsample predictions 185 | if up_mask is None: 186 | flow_up = upflow8(coords1 - coords0) 187 | else: 188 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 189 | 190 | 
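# keep this iteration's upsampled prediction so the sequence loss can supervise every refinement step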
flow_predictions.append(flow_up) 191 | 192 | if fmap2 is not None and reversed_emap is not None: 193 | 194 | with autocast(enabled=self.args.mixed_precision): 195 | # pseudo_fmap1 = fmap2 + r_emap 196 | pseudo_fmap1 = self.fusion(fmap2, reversed_emap) 197 | 198 | if self.alternate_corr: 199 | corr_fn = AlternateCorrBlock(fmap2, pseudo_fmap1, radius=self.args.corr_radius) 200 | else: 201 | corr_fn = CorrBlock(fmap2, pseudo_fmap1, radius=self.args.corr_radius) 202 | 203 | # run the context network 204 | with autocast(enabled=self.args.mixed_precision): 205 | cnet = self.cnet(image2) 206 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 207 | net = torch.tanh(net) 208 | inp = torch.relu(inp) 209 | 210 | coords0, coords1 = self.initialize_flow(image2) 211 | 212 | if flow_init is not None: 213 | coords1 = coords1 + flow_init 214 | 215 | for itr in range(iters): 216 | coords1 = coords1.detach() 217 | corr = corr_fn(coords1) # index correlation volume 218 | 219 | flow = coords1 - coords0 220 | with autocast(enabled=self.args.mixed_precision): 221 | net, up_mask, delta_flow = self.update_block(net, inp, corr, reversed_emap, flow) 222 | 223 | # F(t+1) = F(t) + \Delta(t) 224 | coords1 = coords1 + delta_flow 225 | 226 | # upsample predictions 227 | if up_mask is None: 228 | flow_up_bw = upflow8(coords1 - coords0) 229 | else: 230 | flow_up_bw = self.upsample_flow(coords1 - coords0, up_mask) 231 | 232 | flow_predictions_bw.append(flow_up_bw) 233 | 234 | if self.training: 235 | batch = dict( 236 | flow_preds=flow_predictions, 237 | flow_preds_bw=flow_predictions_bw, 238 | flow_init=coords1 - coords0, 239 | flow_final=flow_up, 240 | flow_final_bw=flow_up_bw, 241 | fmap2_gt=fmap2, 242 | fmap2_pseudo=pseudo_fmap2, 243 | fmap1_gt=fmap1, 244 | fmap1_pseudo=pseudo_fmap1, 245 | ) 246 | else: 247 | batch = dict( 248 | flow_preds=flow_predictions, 249 | flow_init=coords1 - coords0, 250 | flow_final=flow_up, 251 | flow_final_bw=flow_up_bw, 252 | ) 253 | return batch 254 | -------------------------------------------------------------------------------- /core/RAFT.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import sys 7 | sys.path.append('.') 8 | sys.path.append('core') 9 | 10 | from core.decoder.raft_updater import BasicUpdateBlock, SmallUpdateBlock 11 | from core.backbone.raft_encoder import BasicEncoder, SmallEncoder 12 | from core.corr.raft_corr import CorrBlock, AlternateCorrBlock 13 | from utils.sample_utils import bilinear_sampler, coords_grid, upflow8 14 | 15 | try: 16 | autocast = torch.cuda.amp.autocast 17 | except: 18 | # dummy autocast for PyTorch < 1.6 19 | class autocast: 20 | def __init__(self, enabled): 21 | pass 22 | def __enter__(self): 23 | pass 24 | def __exit__(self, *args): 25 | pass 26 | 27 | 28 | class RAFT(nn.Module): 29 | def __init__(self, args): 30 | super().__init__() 31 | self.args = args 32 | self.small = False 33 | self.dropout = 0 34 | self.alternate_corr = False 35 | 36 | if self.small: 37 | self.hidden_dim = hdim = 96 38 | self.context_dim = cdim = 64 39 | args.corr_levels = 4 40 | args.corr_radius = 3 41 | 42 | else: 43 | self.hidden_dim = hdim = 128 44 | self.context_dim = cdim = 128 45 | args.corr_levels = 4 46 | args.corr_radius = 4 47 | 48 | # feature network, context network, and update block 49 | if self.small: 50 | self.fnet = SmallEncoder(input_dim=3, output_dim=128, norm_fn='instance', dropout=self.dropout) 51 | 
self.cnet = SmallEncoder(input_dim=3, output_dim=hdim+cdim, norm_fn='none', dropout=self.dropout) 52 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 53 | 54 | else: 55 | self.fnet = BasicEncoder(input_dim=3, output_dim=256, norm_fn='instance', dropout=self.dropout) 56 | self.cnet = BasicEncoder(input_dim=3, output_dim=hdim+cdim, norm_fn='batch', dropout=self.dropout) 57 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 58 | 59 | def freeze_bn(self): 60 | for m in self.modules(): 61 | if isinstance(m, nn.BatchNorm2d): 62 | m.eval() 63 | 64 | def initialize_flow(self, img): 65 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 66 | N, C, H, W = img.shape 67 | coords0 = coords_grid(N, H//8, W//8).to(img.device) 68 | coords1 = coords_grid(N, H//8, W//8).to(img.device) 69 | 70 | # optical flow computed as difference: flow = coords1 - coords0 71 | return coords0, coords1 72 | 73 | def upsample_flow(self, flow, mask): 74 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 75 | N, _, H, W = flow.shape 76 | mask = mask.view(N, 1, 9, 8, 8, H, W) 77 | mask = torch.softmax(mask, dim=2) 78 | 79 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 80 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 81 | 82 | up_flow = torch.sum(mask * up_flow, dim=2) 83 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 84 | return up_flow.reshape(N, 2, 8*H, 8*W) 85 | 86 | def forward(self, batch, iters=12, flow_init=None, upsample=True): 87 | """ Estimate optical flow between pair of frames """ 88 | image1 = batch['image1'] 89 | image2 = batch['image2'] 90 | image1 = 2 * (image1 / 255.0) - 1.0 91 | image2 = 2 * (image2 / 255.0) - 1.0 92 | image1 = image1.contiguous() 93 | image2 = image2.contiguous() 94 | 95 | hdim = self.hidden_dim 96 | cdim = self.context_dim 97 | 98 | # run the feature network 99 | with autocast(enabled=self.args.mixed_precision): 100 | fmap1, fmap2 = self.fnet([image1, image2]) 101 | 102 | fmap1 = fmap1.float() 103 | fmap2 = fmap2.float() 104 | if self.alternate_corr: 105 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 106 | else: 107 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 108 | 109 | # run the context network 110 | with autocast(enabled=self.args.mixed_precision): 111 | cnet = self.cnet(image1) 112 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 113 | net = torch.tanh(net) 114 | inp = torch.relu(inp) 115 | 116 | coords0, coords1 = self.initialize_flow(image1) 117 | 118 | if flow_init is not None: 119 | coords1 = coords1 + flow_init 120 | 121 | flow_predictions = [] 122 | 123 | for itr in range(iters): 124 | coords1 = coords1.detach() 125 | corr = corr_fn(coords1) # index correlation volume 126 | 127 | flow = coords1 - coords0 128 | with autocast(enabled=self.args.mixed_precision): 129 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 130 | 131 | # F(t+1) = F(t) + \Delta(t) 132 | coords1 = coords1 + delta_flow 133 | 134 | # upsample predictions 135 | if up_mask is None: 136 | flow_up = upflow8(coords1 - coords0) 137 | else: 138 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 139 | 140 | flow_predictions.append(flow_up) 141 | 142 | batch = dict( 143 | flow_preds=flow_predictions, 144 | flow_init=coords1 - coords0, 145 | flow_final=flow_up, 146 | fmap2_gt=None, 147 | fmap2_pseudo=None, 148 | ) 149 | return batch 150 | -------------------------------------------------------------------------------- /core/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/danqu130/DCEIFlow/f1c24c1199aa033cc09ab2979acaeecc0cf3a3f7/core/__init__.py -------------------------------------------------------------------------------- /core/backbone/raft_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d( 11 | in_planes, planes, kernel_size=3, padding=1, stride=stride) 12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 13 | self.relu = nn.ReLU(inplace=True) 14 | 15 | num_groups = planes // 8 16 | 17 | if norm_fn == 'group': 18 | self.norm1 = nn.GroupNorm( 19 | num_groups=num_groups, num_channels=planes) 20 | self.norm2 = nn.GroupNorm( 21 | num_groups=num_groups, num_channels=planes) 22 | if not stride == 1: 23 | self.norm3 = nn.GroupNorm( 24 | num_groups=num_groups, num_channels=planes) 25 | 26 | elif norm_fn == 'batch': 27 | self.norm1 = nn.BatchNorm2d(planes) 28 | self.norm2 = nn.BatchNorm2d(planes) 29 | if not stride == 1: 30 | self.norm3 = nn.BatchNorm2d(planes) 31 | 32 | elif norm_fn == 'instance': 33 | self.norm1 = nn.InstanceNorm2d(planes) 34 | self.norm2 = nn.InstanceNorm2d(planes) 35 | if not stride == 1: 36 | self.norm3 = nn.InstanceNorm2d(planes) 37 | 38 | elif norm_fn == 'none': 39 | self.norm1 = nn.Sequential() 40 | self.norm2 = nn.Sequential() 41 | if not stride == 1: 42 | self.norm3 = nn.Sequential() 43 | 44 | if stride == 1: 45 | self.downsample = None 46 | 47 | else: 48 | self.downsample = nn.Sequential( 49 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 50 | 51 | def forward(self, x): 52 | y = x 53 | y = self.relu(self.norm1(self.conv1(y))) 54 | y = self.relu(self.norm2(self.conv2(y))) 55 | 56 | if self.downsample is not None: 57 | x = self.downsample(x) 58 | 59 | return self.relu(x+y) 60 | 61 | 62 | class BottleneckBlock(nn.Module): 63 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 64 | super(BottleneckBlock, self).__init__() 65 | 66 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 67 | self.conv2 = nn.Conv2d(planes//4, planes//4, 68 | kernel_size=3, padding=1, stride=stride) 69 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 70 | self.relu = nn.ReLU(inplace=True) 71 | 72 | num_groups = planes // 8 73 | 74 | if norm_fn == 'group': 75 | self.norm1 = nn.GroupNorm( 76 | num_groups=num_groups, num_channels=planes//4) 77 | self.norm2 = nn.GroupNorm( 78 | num_groups=num_groups, num_channels=planes//4) 79 | self.norm3 = nn.GroupNorm( 80 | num_groups=num_groups, num_channels=planes) 81 | if not stride == 1: 82 | self.norm4 = nn.GroupNorm( 83 | num_groups=num_groups, num_channels=planes) 84 | 85 | elif norm_fn == 'batch': 86 | self.norm1 = nn.BatchNorm2d(planes//4) 87 | self.norm2 = nn.BatchNorm2d(planes//4) 88 | self.norm3 = nn.BatchNorm2d(planes) 89 | if not stride == 1: 90 | self.norm4 = nn.BatchNorm2d(planes) 91 | 92 | elif norm_fn == 'instance': 93 | self.norm1 = nn.InstanceNorm2d(planes//4) 94 | self.norm2 = nn.InstanceNorm2d(planes//4) 95 | self.norm3 = nn.InstanceNorm2d(planes) 96 | if not stride == 1: 97 | self.norm4 = nn.InstanceNorm2d(planes) 98 | 99 | elif norm_fn == 'none': 100 | self.norm1 = nn.Sequential() 101 | 
self.norm2 = nn.Sequential() 102 | self.norm3 = nn.Sequential() 103 | if not stride == 1: 104 | self.norm4 = nn.Sequential() 105 | 106 | if stride == 1: 107 | self.downsample = None 108 | 109 | else: 110 | self.downsample = nn.Sequential( 111 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 112 | 113 | def forward(self, x): 114 | y = x 115 | y = self.relu(self.norm1(self.conv1(y))) 116 | y = self.relu(self.norm2(self.conv2(y))) 117 | y = self.relu(self.norm3(self.conv3(y))) 118 | 119 | if self.downsample is not None: 120 | x = self.downsample(x) 121 | 122 | return self.relu(x+y) 123 | 124 | 125 | class BasicEncoder(nn.Module): 126 | def __init__(self, input_dim=3, output_dim=128, norm_fn='batch', dropout=0.0): 127 | super(BasicEncoder, self).__init__() 128 | self.norm_fn = norm_fn 129 | 130 | if self.norm_fn == 'group': 131 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 132 | 133 | elif self.norm_fn == 'batch': 134 | self.norm1 = nn.BatchNorm2d(64) 135 | 136 | elif self.norm_fn == 'instance': 137 | self.norm1 = nn.InstanceNorm2d(64) 138 | 139 | elif self.norm_fn == 'none': 140 | self.norm1 = nn.Sequential() 141 | 142 | self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3) 143 | self.relu1 = nn.ReLU(inplace=True) 144 | 145 | self.in_planes = 64 146 | self.layer1 = self._make_layer(64, stride=1) 147 | self.layer2 = self._make_layer(96, stride=2) 148 | self.layer3 = self._make_layer(128, stride=2) 149 | 150 | # output convolution 151 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 152 | 153 | self.dropout = None 154 | if dropout > 0: 155 | self.dropout = nn.Dropout2d(p=dropout) 156 | 157 | for m in self.modules(): 158 | if isinstance(m, nn.Conv2d): 159 | nn.init.kaiming_normal_( 160 | m.weight, mode='fan_out', nonlinearity='relu') 161 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 162 | if m.weight is not None: 163 | nn.init.constant_(m.weight, 1) 164 | if m.bias is not None: 165 | nn.init.constant_(m.bias, 0) 166 | 167 | def _make_layer(self, dim, stride=1): 168 | layer1 = ResidualBlock(self.in_planes, dim, 169 | self.norm_fn, stride=stride) 170 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 171 | layers = (layer1, layer2) 172 | 173 | self.in_planes = dim 174 | return nn.Sequential(*layers) 175 | 176 | def forward(self, x): 177 | 178 | # if input is list, combine batch dimension 179 | is_list = isinstance(x, tuple) or isinstance(x, list) 180 | if is_list: 181 | batch_dim = x[0].shape[0] 182 | x = torch.cat(x, dim=0) 183 | 184 | x = self.conv1(x) 185 | x = self.norm1(x) 186 | x = self.relu1(x) 187 | 188 | x = self.layer1(x) 189 | x = self.layer2(x) 190 | x = self.layer3(x) 191 | 192 | x = self.conv2(x) 193 | 194 | if self.training and self.dropout is not None: 195 | x = self.dropout(x) 196 | 197 | if is_list: 198 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 199 | 200 | return x 201 | 202 | 203 | class SmallEncoder(nn.Module): 204 | def __init__(self, input_dim=3, output_dim=128, norm_fn='batch', dropout=0.0): 205 | super(SmallEncoder, self).__init__() 206 | self.norm_fn = norm_fn 207 | 208 | if self.norm_fn == 'group': 209 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 210 | 211 | elif self.norm_fn == 'batch': 212 | self.norm1 = nn.BatchNorm2d(32) 213 | 214 | elif self.norm_fn == 'instance': 215 | self.norm1 = nn.InstanceNorm2d(32) 216 | 217 | elif self.norm_fn == 'none': 218 | self.norm1 = nn.Sequential() 219 | 220 | self.conv1 = nn.Conv2d(input_dim, 32, kernel_size=7, 
stride=2, padding=3) 221 | self.relu1 = nn.ReLU(inplace=True) 222 | 223 | self.in_planes = 32 224 | self.layer1 = self._make_layer(32, stride=1) 225 | self.layer2 = self._make_layer(64, stride=2) 226 | self.layer3 = self._make_layer(96, stride=2) 227 | 228 | self.dropout = None 229 | if dropout > 0: 230 | self.dropout = nn.Dropout2d(p=dropout) 231 | 232 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 233 | 234 | for m in self.modules(): 235 | if isinstance(m, nn.Conv2d): 236 | nn.init.kaiming_normal_( 237 | m.weight, mode='fan_out', nonlinearity='relu') 238 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 239 | if m.weight is not None: 240 | nn.init.constant_(m.weight, 1) 241 | if m.bias is not None: 242 | nn.init.constant_(m.bias, 0) 243 | 244 | def _make_layer(self, dim, stride=1): 245 | layer1 = BottleneckBlock( 246 | self.in_planes, dim, self.norm_fn, stride=stride) 247 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 248 | layers = (layer1, layer2) 249 | 250 | self.in_planes = dim 251 | return nn.Sequential(*layers) 252 | 253 | def forward(self, x): 254 | 255 | # if input is list, combine batch dimension 256 | is_list = isinstance(x, tuple) or isinstance(x, list) 257 | if is_list: 258 | batch_dim = x[0].shape[0] 259 | x = torch.cat(x, dim=0) 260 | 261 | x = self.conv1(x) 262 | x = self.norm1(x) 263 | x = self.relu1(x) 264 | 265 | x = self.layer1(x) 266 | x = self.layer2(x) 267 | x = self.layer3(x) 268 | x = self.conv2(x) 269 | 270 | if self.training and self.dropout is not None: 271 | x = self.dropout(x) 272 | 273 | if is_list: 274 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 275 | 276 | return x 277 | -------------------------------------------------------------------------------- /core/corr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danqu130/DCEIFlow/f1c24c1199aa033cc09ab2979acaeecc0cf3a3f7/core/corr/__init__.py -------------------------------------------------------------------------------- /core/corr/raft_corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import sys 4 | sys.path.append('core') 5 | from utils.sample_utils import bilinear_sampler, coords_grid 6 | 7 | try: 8 | import alt_cuda_corr 9 | except: 10 | # alt_cuda_corr is not compiled 11 | pass 12 | 13 | 14 | class CorrBlock: 15 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 16 | self.num_levels = num_levels 17 | self.radius = radius 18 | self.corr_pyramid = [] 19 | 20 | # all pairs correlation 21 | corr = CorrBlock.corr(fmap1, fmap2) 22 | 23 | batch, h1, w1, dim, h2, w2 = corr.shape 24 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 25 | 26 | self.corr_pyramid.append(corr) 27 | for i in range(self.num_levels-1): 28 | corr = F.avg_pool2d(corr, 2, stride=2) 29 | self.corr_pyramid.append(corr) 30 | 31 | def __call__(self, coords): 32 | r = self.radius 33 | coords = coords.permute(0, 2, 3, 1) 34 | batch, h1, w1, _ = coords.shape 35 | 36 | out_pyramid = [] 37 | for i in range(self.num_levels): 38 | corr = self.corr_pyramid[i] 39 | dx = torch.linspace(-r, r, 2*r+1) 40 | dy = torch.linspace(-r, r, 2*r+1) 41 | delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), 42 | axis=-1).to(coords.device) 43 | 44 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 45 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 46 | coords_lvl = centroid_lvl + delta_lvl 47 | 48 | corr = 
bilinear_sampler(corr, coords_lvl) 49 | corr = corr.view(batch, h1, w1, -1) 50 | out_pyramid.append(corr) 51 | 52 | out = torch.cat(out_pyramid, dim=-1) 53 | return out.permute(0, 3, 1, 2).contiguous().float() 54 | 55 | @staticmethod 56 | def corr(fmap1, fmap2): 57 | batch, dim, ht, wd = fmap1.shape 58 | fmap1 = fmap1.view(batch, dim, ht*wd) 59 | fmap2 = fmap2.view(batch, dim, ht*wd) 60 | 61 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 62 | corr = corr.view(batch, ht, wd, 1, ht, wd) 63 | # return corr / torch.sqrt(torch.tensor(dim).float()) 64 | return corr.mul_(1.0/torch.sqrt(torch.tensor(dim).float())) 65 | 66 | 67 | class HalfCorrBlock: 68 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 69 | self.num_levels = num_levels 70 | self.radius = radius 71 | self.corr_pyramid = [] 72 | 73 | # all pairs correlation 74 | corr = self.corr(fmap1, fmap2) 75 | 76 | batch, h1, w1, dim, h2, w2 = corr.shape 77 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 78 | 79 | self.corr_pyramid.append(corr) 80 | for i in range(self.num_levels-1): 81 | corr = F.avg_pool2d(corr, 2, stride=2) 82 | self.corr_pyramid.append(corr) 83 | 84 | def __call__(self, coords): 85 | r = self.radius 86 | coords = coords.permute(0, 2, 3, 1) 87 | batch, h1, w1, _ = coords.shape 88 | 89 | out_pyramid = [] 90 | for i in range(self.num_levels): 91 | corr = self.corr_pyramid[i] 92 | dx = torch.linspace(-r, r, 2*r+1) 93 | dy = torch.linspace(-r, r, 2*r+1) 94 | delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), 95 | axis=-1).half().to(coords.device) 96 | 97 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 98 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 99 | coords_lvl = centroid_lvl.half() + delta_lvl.half() 100 | 101 | corr = bilinear_sampler(corr, coords_lvl) 102 | corr = corr.view(batch, h1, w1, -1) 103 | out_pyramid.append(corr) 104 | 105 | out = torch.cat(out_pyramid, dim=-1) 106 | return out.permute(0, 3, 1, 2).contiguous().float() 107 | 108 | @staticmethod 109 | def corr(fmap1, fmap2): 110 | batch, dim, ht, wd = fmap1.shape 111 | fmap1 = fmap1.view(batch, dim, ht*wd).half() 112 | fmap2 = fmap2.view(batch, dim, ht*wd).half() 113 | 114 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 115 | corr = corr.view(batch, ht, wd, 1, ht, wd) 116 | # return corr / torch.sqrt(torch.tensor(dim).float()) 117 | return corr.mul_(1.0/torch.sqrt(torch.tensor(dim).float())) 118 | 119 | 120 | class EfficientCorrBlock: 121 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 122 | self.num_levels = num_levels 123 | self.radius = radius 124 | self.corr_pyramid = [] 125 | 126 | self.fmap1 = fmap1 127 | self.fmap2 = fmap2 128 | 129 | def corr(self, fmap1, fmap2, coords): 130 | 131 | B, D, H, W = fmap2.shape 132 | fmap1 = fmap1.unsqueeze(dim=-1) 133 | fmap2 = fmap2.unsqueeze(dim=-1) 134 | 135 | # map grid coordinates to [-1,1] 136 | xgrid, ygrid = coords.split([1, 1], dim=-1) 137 | xgrid = 2*xgrid/(W-1) - 1 138 | ygrid = 2*ygrid/(H-1) - 1 139 | zgrid = torch.zeros_like(xgrid) - 1 140 | grid = torch.cat([zgrid, xgrid, ygrid], dim=-1) 141 | 142 | fmapw = F.grid_sample(fmap2, grid, align_corners=True) 143 | 144 | corr = torch.sum(fmap1*fmapw, dim=1) 145 | return corr / torch.sqrt(torch.tensor(D).float()) 146 | 147 | def __call__(self, coords): 148 | 149 | r = self.radius 150 | coords = coords.permute(0, 2, 3, 1) 151 | batch, h1, w1, _ = coords.shape 152 | 153 | fmap1 = self.fmap1 154 | fmap2 = self.fmap2 155 | 156 | out_pyramid = [] 157 | for i in range(self.num_levels): 158 | dx = torch.linspace(-r, r, 2*r+1) 
159 | dy = torch.linspace(-r, r, 2*r+1) 160 | delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), 161 | axis=-1).to(coords.device) 162 | 163 | centroid_lvl = coords.reshape(batch, h1, w1, 1, 2) / 2**i 164 | coords_lvl = centroid_lvl + delta.view(-1, 2) 165 | 166 | corr = self.corr(fmap1, fmap2, coords_lvl) 167 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 168 | out_pyramid.append(corr) 169 | 170 | out = torch.cat(out_pyramid, dim=-1) 171 | return out.permute(0, 3, 1, 2).contiguous().float() 172 | 173 | 174 | class AlternateCorrBlock: 175 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 176 | self.num_levels = num_levels 177 | self.radius = radius 178 | 179 | self.pyramid = [(fmap1, fmap2)] 180 | for i in range(self.num_levels): 181 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 182 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 183 | self.pyramid.append((fmap1, fmap2)) 184 | 185 | def __call__(self, coords): 186 | coords = coords.permute(0, 2, 3, 1) 187 | B, H, W, _ = coords.shape 188 | dim = self.pyramid[0][0].shape[1] 189 | 190 | corr_list = [] 191 | for i in range(self.num_levels): 192 | r = self.radius 193 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 194 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 195 | 196 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 197 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 198 | corr_list.append(corr.squeeze(1)) 199 | 200 | corr = torch.stack(corr_list, dim=1) 201 | corr = corr.reshape(B, -1, H, W) 202 | return corr / torch.sqrt(torch.tensor(dim).float()) 203 | -------------------------------------------------------------------------------- /core/decoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danqu130/DCEIFlow/f1c24c1199aa033cc09ab2979acaeecc0cf3a3f7/core/decoder/__init__.py -------------------------------------------------------------------------------- /core/decoder/raft_updater.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | 17 | class ConvGRU(nn.Module): 18 | def __init__(self, hidden_dim=128, input_dim=192+128): 19 | super(ConvGRU, self).__init__() 20 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 23 | 24 | def forward(self, h, x): 25 | hx = torch.cat([h, x], dim=1) 26 | 27 | z = torch.sigmoid(self.convz(hx)) 28 | r = torch.sigmoid(self.convr(hx)) 29 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 30 | 31 | h = (1-z) * h + z * q 32 | return h 33 | 34 | 35 | class SepConvGRU(nn.Module): 36 | def __init__(self, hidden_dim=128, input_dim=192+128): 37 | super(SepConvGRU, self).__init__() 38 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, 39 | hidden_dim, (1, 5), padding=(0, 2)) 40 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, 41 | hidden_dim, (1, 5), padding=(0, 2)) 42 | self.convq1 = 
nn.Conv2d(hidden_dim+input_dim, 43 | hidden_dim, (1, 5), padding=(0, 2)) 44 | 45 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, 46 | hidden_dim, (5, 1), padding=(2, 0)) 47 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, 48 | hidden_dim, (5, 1), padding=(2, 0)) 49 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, 50 | hidden_dim, (5, 1), padding=(2, 0)) 51 | 52 | def forward(self, h, x): 53 | # horizontal 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz1(hx)) 56 | r = torch.sigmoid(self.convr1(hx)) 57 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | # vertical 61 | hx = torch.cat([h, x], dim=1) 62 | z = torch.sigmoid(self.convz2(hx)) 63 | r = torch.sigmoid(self.convr2(hx)) 64 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 65 | h = (1-z) * h + z * q 66 | 67 | return h 68 | 69 | 70 | class SmallMotionEncoder(nn.Module): 71 | def __init__(self, args): 72 | super(SmallMotionEncoder, self).__init__() 73 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 74 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 75 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 76 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 77 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 78 | 79 | def forward(self, flow, corr): 80 | cor = F.relu(self.convc1(corr)) 81 | flo = F.relu(self.convf1(flow)) 82 | flo = F.relu(self.convf2(flo)) 83 | cor_flo = torch.cat([cor, flo], dim=1) 84 | out = F.relu(self.conv(cor_flo)) 85 | return torch.cat([out, flow], dim=1) 86 | 87 | 88 | class BasicMotionEncoder(nn.Module): 89 | def __init__(self, args): 90 | super(BasicMotionEncoder, self).__init__() 91 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 92 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 93 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 94 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 95 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 96 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 97 | 98 | def forward(self, flow, corr): 99 | cor = F.relu(self.convc1(corr)) 100 | cor = F.relu(self.convc2(cor)) 101 | flo = F.relu(self.convf1(flow)) 102 | flo = F.relu(self.convf2(flo)) 103 | 104 | cor_flo = torch.cat([cor, flo], dim=1) 105 | out = F.relu(self.conv(cor_flo)) 106 | return torch.cat([out, flow], dim=1) 107 | 108 | 109 | class SmallUpdateBlock(nn.Module): 110 | def __init__(self, args, hidden_dim=96): 111 | super(SmallUpdateBlock, self).__init__() 112 | self.encoder = SmallMotionEncoder(args) 113 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 114 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 115 | 116 | def forward(self, net, inp, corr, flow): 117 | motion_features = self.encoder(flow, corr) 118 | inp = torch.cat([inp, motion_features], dim=1) 119 | net = self.gru(net, inp) 120 | delta_flow = self.flow_head(net) 121 | 122 | return net, None, delta_flow 123 | 124 | 125 | class BasicUpdateBlock(nn.Module): 126 | def __init__(self, args, hidden_dim=128, input_dim=128): 127 | super(BasicUpdateBlock, self).__init__() 128 | self.args = args 129 | self.encoder = BasicMotionEncoder(args) 130 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 131 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 132 | 133 | self.mask = nn.Sequential( 134 | nn.Conv2d(128, 256, 3, padding=1), 135 | nn.ReLU(inplace=True), 136 | nn.Conv2d(256, 64*9, 1, padding=0)) 137 | 138 | def forward(self, net, inp, corr, flow, upsample=True): 139 | motion_features = self.encoder(flow, corr) 140 | 
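# append the motion features to the context features to form the GRU input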
inp = torch.cat([inp, motion_features], dim=1) 141 | 142 | net = self.gru(net, inp) 143 | delta_flow = self.flow_head(net) 144 | 145 | # scale mask to balance gradients 146 | mask = .25 * self.mask(net) 147 | return net, mask, delta_flow 148 | -------------------------------------------------------------------------------- /core/decoder/with_event_updater.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | 17 | class ConvGRU(nn.Module): 18 | def __init__(self, hidden_dim=128, input_dim=192+128): 19 | super(ConvGRU, self).__init__() 20 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 23 | 24 | def forward(self, h, x): 25 | hx = torch.cat([h, x], dim=1) 26 | 27 | z = torch.sigmoid(self.convz(hx)) 28 | r = torch.sigmoid(self.convr(hx)) 29 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 30 | 31 | h = (1-z) * h + z * q 32 | return h 33 | 34 | 35 | class SepConvGRU(nn.Module): 36 | def __init__(self, hidden_dim=128, input_dim=192+128): 37 | super(SepConvGRU, self).__init__() 38 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, 39 | hidden_dim, (1, 5), padding=(0, 2)) 40 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, 41 | hidden_dim, (1, 5), padding=(0, 2)) 42 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, 43 | hidden_dim, (1, 5), padding=(0, 2)) 44 | 45 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, 46 | hidden_dim, (5, 1), padding=(2, 0)) 47 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, 48 | hidden_dim, (5, 1), padding=(2, 0)) 49 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, 50 | hidden_dim, (5, 1), padding=(2, 0)) 51 | 52 | def forward(self, h, x): 53 | # horizontal 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz1(hx)) 56 | r = torch.sigmoid(self.convr1(hx)) 57 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | # vertical 61 | hx = torch.cat([h, x], dim=1) 62 | z = torch.sigmoid(self.convz2(hx)) 63 | r = torch.sigmoid(self.convr2(hx)) 64 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 65 | h = (1-z) * h + z * q 66 | 67 | return h 68 | 69 | 70 | class SmallMotionEncoder(nn.Module): 71 | def __init__(self, args): 72 | super(SmallMotionEncoder, self).__init__() 73 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 74 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 75 | self.conve1 = nn.Conv2d(128, 64, 1, padding=0) 76 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 77 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 78 | self.conv = nn.Conv2d(96+32+64, 80, 3, padding=1) 79 | 80 | def forward(self, flow, emap, corr): 81 | cor = F.relu(self.convc1(corr)) 82 | ema = F.relu(self.conve1(emap)) 83 | flo = F.relu(self.convf1(flow)) 84 | flo = F.relu(self.convf2(flo)) 85 | cor_ema_flo = torch.cat([cor, ema, flo], dim=1) 86 | out = F.relu(self.conv(cor_ema_flo)) 87 | return torch.cat([out, flow], dim=1) 88 | 89 | 90 | class 
BasicMotionEncoder(nn.Module): 91 | def __init__(self, args): 92 | super(BasicMotionEncoder, self).__init__() 93 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 94 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 95 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 96 | self.conve1 = nn.Conv2d(256, 128, 1, padding=0) 97 | self.conve2 = nn.Conv2d(128, 64, 3, padding=1) 98 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 99 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 100 | self.conv = nn.Conv2d(64+192+64, 128-2, 3, padding=1) 101 | 102 | def forward(self, flow, emap, corr): 103 | cor = F.relu(self.convc1(corr)) 104 | cor = F.relu(self.convc2(cor)) 105 | ema = F.relu(self.conve1(emap)) 106 | ema = F.relu(self.conve2(ema)) 107 | flo = F.relu(self.convf1(flow)) 108 | flo = F.relu(self.convf2(flo)) 109 | 110 | cor_ema_flo = torch.cat([cor, ema, flo], dim=1) 111 | out = F.relu(self.conv(cor_ema_flo)) 112 | return torch.cat([out, flow], dim=1) 113 | 114 | 115 | class SmallUpdateBlock(nn.Module): 116 | def __init__(self, args, hidden_dim=96): 117 | super(SmallUpdateBlock, self).__init__() 118 | self.encoder = SmallMotionEncoder(args) 119 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 121 | 122 | def forward(self, net, inp, corr, emap, flow): 123 | motion_features = self.encoder(flow, emap, corr) 124 | inp = torch.cat([inp, motion_features], dim=1) 125 | net = self.gru(net, inp) 126 | delta_flow = self.flow_head(net) 127 | 128 | return net, None, delta_flow 129 | 130 | 131 | class BasicUpdateBlock(nn.Module): 132 | def __init__(self, args, hidden_dim=128, input_dim=128): 133 | super(BasicUpdateBlock, self).__init__() 134 | self.args = args 135 | self.encoder = BasicMotionEncoder(args) 136 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 137 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 138 | 139 | self.mask = nn.Sequential( 140 | nn.Conv2d(128, 256, 3, padding=1), 141 | nn.ReLU(inplace=True), 142 | nn.Conv2d(256, 64*9, 1, padding=0)) 143 | 144 | def forward(self, net, inp, corr, emap, flow): 145 | motion_features = self.encoder(flow, emap, corr) 146 | inp = torch.cat([inp, motion_features], dim=1) 147 | 148 | net = self.gru(net, inp) 149 | delta_flow = self.flow_head(net) 150 | 151 | # scale mask to balance gradients 152 | mask = .25 * self.mask(net) 153 | return net, mask, delta_flow 154 | 155 | 156 | class BasicUpdateBlockNoMask(nn.Module): 157 | def __init__(self, args, hidden_dim=128, input_dim=128): 158 | super().__init__() 159 | self.args = args 160 | self.encoder = BasicMotionEncoder(args) 161 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 162 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 163 | 164 | def forward(self, net, inp, corr, emap, flow): 165 | motion_features = self.encoder(flow, emap, corr) 166 | inp = torch.cat([inp, motion_features], dim=1) 167 | 168 | net = self.gru(net, inp) 169 | delta_flow = self.flow_head(net) 170 | 171 | return net, None, delta_flow 172 | -------------------------------------------------------------------------------- /core/loss/Combine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import sys 6 | sys.path.append('core') 7 | 8 | from utils.utils import build_module 9 | 10 | 11 | class Combine(nn.Module): 12 | def __init__(self, args): 13 | super().__init__() 14 | 
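# build one loss module per name in --loss; each is weighted by the matching --loss_weights entry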
self.loss_names = args.loss 15 | self.loss_weights = args.loss_weights 16 | self.loss_num = len(self.loss_names) 17 | self.loss = [] 18 | for i in range(self.loss_num): 19 | self.loss.append(build_module("core.loss", self.loss_names[i])(args)) 20 | 21 | def forward(self, output, target): 22 | 23 | loss_all = 0. 24 | loss_dict = {} 25 | for i in range(self.loss_num): 26 | loss_each, loss_metric = self.loss[i](output, target) 27 | loss_all += loss_each * self.loss_weights[i] 28 | loss_dict.update(loss_metric) 29 | 30 | loss_dict.update({ 31 | "loss": loss_all, 32 | }) 33 | 34 | return loss_dict 35 | -------------------------------------------------------------------------------- /core/loss/L1Loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class L1Loss(nn.Module): 7 | def __init__(self, args): 8 | super().__init__() 9 | self.iters = args.iters 10 | self.gamma = args.loss_gamma 11 | self.isbi = args.isbi 12 | self.max_flow = 400 13 | 14 | def rescaleflow_tosize(self, flow, new_size, mode='bilinear'): 15 | if new_size[0] == flow.shape[2] and new_size[1] == flow.shape[3]: 16 | return flow 17 | 18 | h_scale = new_size[0] / flow.shape[2] 19 | w_scale = new_size[1] / flow.shape[3] 20 | assert h_scale == w_scale 21 | return h_scale * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 22 | 23 | def resizeflow_tosize(self, flow, new_size, mode='bilinear'): 24 | if new_size[0] == flow.shape[2] and new_size[1] == flow.shape[3]: 25 | return flow 26 | 27 | h_scale = new_size[0] / flow.shape[2] 28 | w_scale = new_size[1] / flow.shape[3] 29 | assert h_scale == w_scale 30 | return F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 31 | 32 | def compute(self, flow_preds, fmap2_gt, fmap2_pseudo, flow_gt, valid_original): 33 | 34 | flow_loss = pseudo_loss = 0.0 35 | # exclude invalid pixels and extremely large displacements 36 | mag = torch.sum(flow_gt**2, dim=1, keepdim=True).sqrt() 37 | valid = (valid_original >= 0.5) & (mag < self.max_flow) 38 | 39 | for i in range(len(flow_preds)): 40 | i_weight = self.gamma**(len(flow_preds) - i - 1) 41 | if flow_gt.shape == flow_preds[i].shape: 42 | i_loss = (flow_preds[i] - flow_gt).abs() 43 | flow_loss += i_weight * (valid * i_loss).mean() 44 | else: 45 | scaled_flow_gt = self.resizeflow_tosize(flow_gt, flow_preds[i].shape[2:]) 46 | i_loss = (flow_preds[i] - scaled_flow_gt).abs() 47 | scaled_mag = torch.sum(scaled_flow_gt**2, dim=1, keepdim=True).sqrt() 48 | scaled_valid = (self.resizeflow_tosize(valid_original, flow_preds[i].shape[2:]) >= 0.5) & (scaled_mag < self.max_flow) 49 | flow_loss += i_weight * (scaled_valid * i_loss).mean() 50 | 51 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt() 52 | epe = epe.view(-1)[valid.view(-1)] 53 | 54 | if fmap2_pseudo is not None: 55 | if isinstance(fmap2_pseudo, list): 56 | for i in range(len(fmap2_pseudo)): 57 | i_weight = self.gamma**(len(fmap2_pseudo) - i - 1) if len(fmap2_pseudo) != 1 else 1.0 58 | i_loss = F.l1_loss(fmap2_pseudo[i], fmap2_gt[i]) * 10 59 | pseudo_loss += i_weight * i_loss 60 | else: 61 | pseudo_loss = F.l1_loss(fmap2_pseudo, fmap2_gt) * 10 62 | 63 | flow_loss += pseudo_loss 64 | 65 | if fmap2_pseudo is None: 66 | metrics = { 67 | 'l1loss': flow_loss, 68 | 'epe': epe.mean(), 69 | '1px': (epe < 1).float().mean(), 70 | '3px': (epe < 3).float().mean(), 71 | '5px': (epe < 5).float().mean(), 72 | } 73 | else: 74 | metrics = { 75 | 'l1loss': flow_loss, 76 | 
'epe': epe.mean(), 77 | 'pseudo': pseudo_loss, 78 | '1px': (epe < 1).float().mean(), 79 | '3px': (epe < 3).float().mean(), 80 | '5px': (epe < 5).float().mean(), 81 | } 82 | 83 | return flow_loss, metrics 84 | 85 | 86 | def forward(self, out, target): 87 | """ Loss function defined over sequence of flow predictions """ 88 | flow_loss = 0.0 89 | 90 | flow_preds = out['flow_preds'] 91 | fmap2_gt = out['fmap2_gt'] 92 | fmap2_pseudo = out['fmap2_pseudo'] 93 | flow_gt = target['flow_gt'] 94 | valid = target['flow_valid'] 95 | flow_loss_fw, metrics_fw = self.compute(flow_preds, fmap2_gt, fmap2_pseudo, flow_gt, valid) 96 | 97 | if not self.isbi: 98 | return flow_loss_fw, metrics_fw 99 | else: 100 | assert 'flow_preds_bw' in out.keys() 101 | flow_preds = out['flow_preds_bw'] 102 | fmap2_gt = out['fmap1_gt'] 103 | fmap2_pseudo = out['fmap1_pseudo'] 104 | flow_gt = target['flow10_gt'] 105 | valid = target['flow10_valid'] 106 | flow_loss_bw, metrics_bw = self.compute(flow_preds, fmap2_gt, fmap2_pseudo, flow_gt, valid) 107 | 108 | flow_loss = (flow_loss_fw + flow_loss_bw) * 0.5 109 | metrics = {} 110 | for key in metrics_fw: 111 | assert key in metrics_bw.keys() 112 | metrics[key] = (metrics_fw[key] + metrics_bw[key]) * 0.5 113 | 114 | return flow_loss, metrics 115 | -------------------------------------------------------------------------------- /core/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danqu130/DCEIFlow/f1c24c1199aa033cc09ab2979acaeecc0cf3a3f7/core/loss/__init__.py -------------------------------------------------------------------------------- /core/metric/Combine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import sys 6 | sys.path.append('core') 7 | 8 | from utils.utils import build_module 9 | 10 | 11 | class Combine: 12 | def __init__(self, args): 13 | super().__init__() 14 | self.metric_names = args.metric 15 | self.metric_num = len(self.metric_names) 16 | self.metrics = [] 17 | for i in range(self.metric_num): 18 | if self.metric_names[i] == 'epe': 19 | self.metrics.append(build_module("core.metric", "EPE")(args)) 20 | else: 21 | self.metrics.append(build_module("core.metric", self.metric_names[i])(args)) 22 | self.all_metrics = {} 23 | 24 | def clear(self): 25 | self.all_metrics = {} 26 | 27 | def calculate(self, output, target, name=None): 28 | metrics = {} 29 | for i in range(self.metric_num): 30 | metric_each = self.metrics[i](output, target, name) 31 | metrics.update(metric_each) 32 | return metrics 33 | 34 | def push(self, metric_each): 35 | for key in metric_each: 36 | if key not in self.all_metrics.keys(): 37 | self.all_metrics[key] = [] 38 | self.all_metrics[key].append(metric_each[key]) 39 | return self.all_metrics 40 | 41 | def get_all(self): 42 | return self.all_metrics 43 | 44 | def summary(self): 45 | 46 | metrics_summary = {} 47 | metrics_str = "" 48 | for key, values in self.all_metrics.items(): 49 | num = sum(values) / len(values) 50 | metrics_summary[key] = num 51 | metrics_str += "{}:{:8.6f},".format(key, num) 52 | self.clear() 53 | return metrics_str, metrics_summary 54 | -------------------------------------------------------------------------------- /core/metric/EPE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def epe(flow_pred, flow_gt, valid_gt=None): 5 | 6 | epe = 
torch.sum((flow_pred - flow_gt)**2, dim=1).sqrt() 7 | mag = torch.sum(flow_gt**2, dim=1).sqrt() 8 | 9 | epe = epe.view(-1) 10 | mag = mag.view(-1) 11 | 12 | outlier = (epe > 3.0).float() 13 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float() 14 | 15 | if valid_gt is not None: 16 | val = valid_gt.view(-1) >= 0.5 17 | metrics = { 18 | 'epe': epe[val].mean(), 19 | '1px': (epe[val] < 1).float().mean(), 20 | '3px': (epe[val] < 3).float().mean(), 21 | '5px': (epe[val] < 5).float().mean(), 22 | 'F1': out[val].mean() * 100, 23 | 'ol': outlier[val].mean() * 100, 24 | } 25 | else: 26 | metrics = { 27 | 'epe': epe.mean(), 28 | '1px': (epe < 1).float().mean(), 29 | '3px': (epe < 3).float().mean(), 30 | '5px': (epe < 5).float().mean(), 31 | 'F1': out.mean() * 100, 32 | 'ol': outlier.mean() * 100, 33 | } 34 | return metrics 35 | 36 | 37 | def epe_f1(flow_pred, flow_gt, valid_gt=None): 38 | 39 | epe = torch.sum((flow_pred - flow_gt)**2, dim=1).sqrt() 40 | mag = torch.sum(flow_gt**2, dim=1).sqrt() 41 | 42 | epe = epe.view(-1) 43 | mag = mag.view(-1) 44 | 45 | outlier = (epe > 3.0).float() 46 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float() 47 | 48 | if valid_gt is not None: 49 | val = valid_gt.view(-1) >= 0.5 50 | metrics = { 51 | 'epe': epe[val].mean(), 52 | '1px': (epe[val] < 1).float().mean(), 53 | '3px': (epe[val] < 3).float().mean(), 54 | '5px': (epe[val] < 5).float().mean(), 55 | 'F1': out[val].mean() * 100, 56 | 'ol': outlier[val].mean() * 100, 57 | } 58 | else: 59 | metrics = { 60 | 'epe': epe.mean(), 61 | '1px': (epe < 1).float().mean(), 62 | '3px': (epe < 3).float().mean(), 63 | '5px': (epe < 5).float().mean(), 64 | 'F1': out.mean() * 100, 65 | 'ol': outlier.mean() * 100, 66 | } 67 | 68 | return metrics 69 | 70 | class EPE: 71 | def __init__(self, args): 72 | pass 73 | 74 | def cal(self, flow_pred, flow_gt, flow_valid, event_valid, name): 75 | if 'mvsec' in name: 76 | if 'outdoor' in name: 77 | # remove bottom car 78 | # https://github.com/daniilidis-group/EV-FlowNet/blob/master/src/eval_utils.py#L10 79 | flow_pred = flow_pred[:, :, 0:190, :].contiguous() 80 | flow_gt = flow_gt[:, :, 0:190, :].contiguous() 81 | flow_valid = flow_valid[:, :, 0:190, :].contiguous() 82 | event_valid = event_valid[:, :, 0:190, :].contiguous() 83 | 84 | metric = epe_f1(flow_pred, flow_gt, flow_valid) 85 | masked_metric = epe_f1(flow_pred, flow_gt, flow_valid * event_valid) 86 | 87 | for key, values in masked_metric.items(): 88 | new_key = "emasked_{}".format(key) 89 | assert new_key not in metric 90 | metric[new_key] = values 91 | 92 | else: 93 | metric = epe(flow_pred, flow_gt) 94 | 95 | return metric 96 | 97 | def __call__(self, output, target, name=None): 98 | assert name is not None 99 | 100 | flow_pred = output['flow_pred'] 101 | flow_gt = target['flow_gt'] 102 | flow_valid = target['flow_valid'] 103 | if 'event_valid' in target.keys(): 104 | event_valid = target['event_valid'] 105 | else: 106 | event_valid = None 107 | 108 | metric = self.cal(flow_pred, flow_gt, flow_valid, event_valid, name) 109 | 110 | return metric 111 | -------------------------------------------------------------------------------- /core/metric/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danqu130/DCEIFlow/f1c24c1199aa033cc09ab2979acaeecc0cf3a3f7/core/metric/__init__.py -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 
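# ---- annotation: not part of the repository source ----
# A tiny sanity check of the metric conventions in core/metric/EPE.py above:
# EPE is the per-pixel L2 distance between predicted and ground-truth flow,
# and the KITTI-style F1 outlier rate counts pixels with EPE > 3 px *and*
# EPE > 5% of the ground-truth magnitude. Hypothetical standalone usage:
#   import torch
#   flow_gt = torch.zeros(1, 2, 4, 4)
#   flow_gt[:, 0] = 10.0                 # uniform 10 px horizontal motion
#   flow_pred = flow_gt.clone()
#   flow_pred[:, 0, 0, 0] = 14.0         # one of 16 pixels is off by 4 px
#   err = torch.sum((flow_pred - flow_gt) ** 2, dim=1).sqrt()
#   mag = torch.sum(flow_gt ** 2, dim=1).sqrt()
#   err.mean()                                              # EPE = 4/16 = 0.25
#   ((err > 3.0) & (err / mag > 0.05)).float().mean() * 100 # F1 = 6.25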
2 | sys.path.append('core') 3 | import os 4 | os.environ["KMP_BLOCKTIME"] = "0" 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | import time 10 | from tqdm import tqdm 11 | import torch 12 | from torch.utils.data import DataLoader 13 | import torch.distributed as dist 14 | 15 | from utils.utils import InputPadder 16 | 17 | 18 | def reduce_list(lists, nprocs): 19 | new_lists = {} 20 | for key, value in lists.items(): 21 | rt = value.clone() 22 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 23 | rt /= nprocs 24 | new_lists[key] = rt.item() 25 | return new_lists 26 | 27 | 28 | def reduce_tensor(tensor, nprocs): 29 | rt = tensor.clone() 30 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 31 | rt /= nprocs 32 | return rt 33 | 34 | 35 | def evaluates(args, model, datasets, names, metric_fun, logger=None): 36 | if logger is not None: 37 | _print = logger.log_info 38 | else: 39 | def print_line(line, subname=None): 40 | print(line) 41 | _print = print_line 42 | 43 | metrics = {} 44 | for val_set, name in zip(datasets, names): 45 | if args.distributed == 'ddp': 46 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_set, shuffle=False, \ 47 | seed=args.seed, drop_last=False) 48 | else: 49 | val_sampler = None 50 | val_loader = DataLoader(val_set, args.test_batch_size, num_workers=args.jobs, sampler=val_sampler) 51 | if args.distributed != 'ddp' or args.local_rank == 0: 52 | _print(">>> For evaluate {}, use length (bs/loader/set): ({}/{}/{})".format( \ 53 | name, args.test_batch_size, len(val_loader), len(val_set)), "evaluates") 54 | metric = evaluate(args, model, val_loader, name, metric_fun, logger=logger) 55 | 56 | for key, values in metric.items(): 57 | new_key = "val_{}/{}".format(name, key) 58 | assert new_key not in metrics 59 | metrics[new_key] = values 60 | 61 | return metrics 62 | 63 | 64 | def evaluate(args, model, dataloader, name, metric_fun, logger=None): 65 | if logger is not None: 66 | _print = logger.log_info 67 | else: 68 | def print_line(line, subname=None): 69 | print(line) 70 | _print = print_line 71 | 72 | start = time.time() 73 | model.eval() 74 | 75 | metric_fun.clear() 76 | 77 | if args.distributed != 'ddp' or args.local_rank == 0: 78 | bar = tqdm(total=len(dataloader), position=0, leave=True) 79 | 80 | for index, batch in enumerate(dataloader): 81 | 82 | for key in batch.keys(): 83 | if torch.is_tensor(batch[key]): 84 | batch[key] = batch[key].cuda(args.gpus[args.local_rank] \ 85 | if args.local_rank != -1 else 0, non_blocking=True) 86 | 87 | padder = InputPadder(batch['image1'].shape, div=args.pad) 88 | pad_batch = padder.pad_batch(batch) 89 | 90 | torch.cuda.synchronize() 91 | tm = time.time() 92 | 93 | with torch.no_grad(): 94 | output = model(pad_batch, iters=args.iters) 95 | 96 | torch.cuda.synchronize() 97 | elapsed = time.time() - tm 98 | 99 | output['flow_pred'] = padder.unpad(output['flow_final']) 100 | if args.isbi and 'flow_final_bw' in output.keys(): 101 | output['flow_pred_bw'] = padder.unpad(output['flow_final_bw']) 102 | 103 | if 'image1_valid' in batch.keys(): 104 | output['flow_pred'][batch['image1_valid'].repeat(1, 2, 1, 1) < 0.5] = 0 105 | 106 | metric_each = metric_fun.calculate(output, batch, name) 107 | 108 | if args.distributed == 'ddp': 109 | torch.distributed.barrier() 110 | reduced_metric_each = reduce_list(metric_each, args.nprocs) 111 | else: 112 | reduced_metric_each = metric_each 113 | 114 | reduced_metric_each.update({'time': elapsed}) 115 | 116 | if args.distributed != 'ddp' or args.local_rank == 0: 117 | 
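# ---- annotation: not part of the repository source ----
# evaluate() in this file depends on InputPadder from utils/utils.py, which is
# not shown in this part of the dump. From its call sites, it pads H and W up
# to a multiple of `div` before inference and crops predictions back with
# unpad(). A minimal sketch consistent with that usage only; the repository's
# own class may pad differently (e.g. symmetrically):
#   class InputPadder:
#       def __init__(self, shape, div=8):
#           self.h, self.w = shape[-2:]
#           self.pad = [0, (-self.w) % div, 0, (-self.h) % div]  # right/bottom
#       def pad(self, x):
#           return torch.nn.functional.pad(x, self.pad, mode='replicate')
#       def unpad(self, x):
#           return x[..., :self.h, :self.w]
# pad_batch(batch) presumably applies pad() to every tensor entry of the dict.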
metric_fun.push(reduced_metric_each) 118 | 119 | if args.distributed != 'ddp' or args.local_rank == 0: 120 | if 'masked_epe' in metric_each.keys(): 121 | bar.set_description("{}/{}[{}:{}], time:{:8.6f}, epe:{:8.6f}, masked_epe:{:8.6f}".format(index * len(batch['basename']), \ 122 | len(dataloader.dataset), batch['raw_index'][0], batch['basename'][0], elapsed, metric_each['epe'], metric_each['masked_epe'])) 123 | else: 124 | bar.set_description("{}/{}[{}:{}],time:{:8.6f}, epe:{:8.6f}".format(index * len(batch['basename']), \ 125 | len(dataloader.dataset), batch['raw_index'][0], batch['basename'][0], elapsed, metric_each['epe'])) 126 | bar.update(1) 127 | 128 | if args.distributed != 'ddp' or args.local_rank == 0: 129 | bar.close() 130 | metrics_str, all_metrics = metric_fun.summary() 131 | metric_fun.clear() 132 | 133 | if args.distributed != 'ddp' or args.local_rank == 0: 134 | _print("<<< In {} eval: {} (100X F1), with time {}s.".format(name, metrics_str, time.time() - start), "evaluate") 135 | 136 | model.train() 137 | return all_metrics 138 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('core') 4 | import time 5 | import logging 6 | import torchvision 7 | from torch.utils.tensorboard import SummaryWriter 8 | from utils.utils import ensure_folder 9 | 10 | 11 | class Logger: 12 | def __init__(self, args, main="main"): 13 | self.name = args.name 14 | if self.name == "": 15 | self.not_save_board = True 16 | self.not_save_log = True 17 | else: 18 | self.not_save_board = args.not_save_board 19 | self.not_save_log = args.not_save_log 20 | 21 | self.log_path = args.log_path 22 | self.debug = args.debug 23 | 24 | self.each_steps = 0 25 | self.running_loss = {} 26 | 27 | self.last_time = None 28 | self.writer = None 29 | self.logger = None 30 | self.logger_main = main 31 | self.log_dir = self.log_path 32 | if not self.not_save_board or not self.not_save_log: 33 | ensure_folder(self.log_dir) 34 | 35 | def _log_summary(self, index): 36 | metrics_data = [self.running_loss[k] / self.each_steps for k in self.running_loss.keys()] 37 | keys = self.running_loss.keys() 38 | 39 | # metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data) 40 | metrics_str = "" 41 | for data, key in zip(metrics_data, keys): 42 | metrics_str += "{}:{:8.6f}, ".format(key, data) 43 | latest_time = time.time() 44 | metrics_str += "time:{:8.6f}s.".format(latest_time - self.last_time) 45 | self.last_time = latest_time 46 | 47 | # print the training status 48 | self.log_info("Summary {}, {}".format(index, metrics_str), "trainer") 49 | 50 | def _write_summary(self, index): 51 | if self.not_save_board: 52 | return 53 | 54 | if self.writer is None: 55 | self.init_writer() 56 | 57 | for k in self.running_loss: 58 | self.writer.add_scalar(k, self.running_loss[k]/self.each_steps, index) 59 | 60 | def _clear_summary(self): 61 | for k in self.running_loss: 62 | self.running_loss[k] = 0.0 63 | self.each_steps = 0 64 | self.last_time = None 65 | 66 | def push(self, metrics, group=None, last=True): 67 | if last is True: 68 | self.each_steps += 1 69 | 70 | if self.last_time is None: 71 | self.last_time = time.time() 72 | 73 | for key in metrics: 74 | if group is not None: 75 | loss_key = "{}/{}".format(group, key) 76 | else: 77 | loss_key = key 78 | 79 | if loss_key not in self.running_loss: 80 | self.running_loss[loss_key] = 0.0 81 | 
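# ---- annotation: not part of the repository source ----
# Logger.push (this method) accumulates raw metric values, and _log_summary /
# _write_summary above divide by self.each_steps, so each summary reports the
# per-step mean since the last summary() call. A hypothetical trace, assuming
# an args namespace as built by argparser.py:
#   log = Logger(args)
#   log.push({'epe': 2.0}, group='loss')   # each_steps -> 1
#   log.push({'epe': 4.0}, group='loss')   # each_steps -> 2
#   log.summary(index=1)                   # reports loss/epe = (2+4)/2 = 3.0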
self.running_loss[loss_key] += metrics[key] 82 | 83 | def summary(self, index): 84 | self._log_summary(index) 85 | self._write_summary(index) 86 | self._clear_summary() 87 | 88 | def write_dict(self, index, results, group=None): 89 | if self.not_save_board: 90 | return 91 | 92 | if self.writer is None: 93 | self.init_writer() 94 | 95 | for key in results: 96 | if group is not None: 97 | self.writer.add_scalar("{}/{}".format(group, key), results[key], index) 98 | else: 99 | self.writer.add_scalar(key, results[key], index) 100 | 101 | def write_image(self, index, name, image): 102 | if self.not_save_board: 103 | return 104 | 105 | if self.writer is None: 106 | self.init_writer() 107 | grid = torchvision.utils.make_grid(image) 108 | self.writer.add_image(name, grid, index) 109 | 110 | def init_writer(self): 111 | self.writer = SummaryWriter(log_dir=self.log_dir) 112 | 113 | def init_logger(self, name=None): 114 | log_path = os.path.join(self.log_dir, "{}.log".format(self.name if name is None else name)) 115 | 116 | self.logger = logging.getLogger(self.logger_main) 117 | self.logger.setLevel(logging.DEBUG) 118 | 119 | if not self.not_save_board: 120 | handler = logging.FileHandler(log_path) 121 | formatter = logging.Formatter('[%(asctime)s-%(name)s-%(levelname)s]: %(message)s', \ 122 | datefmt='%Y/%m/%d %H:%M:%S') 123 | handler.setFormatter(formatter) 124 | self.logger.addHandler(handler) 125 | 126 | stream_handler = logging.StreamHandler(sys.stdout) 127 | formatter = logging.Formatter('[%(asctime)s]: %(message)s', datefmt='%m/%d %H:%M:%S') 128 | stream_handler.setFormatter(formatter) 129 | if self.debug: 130 | stream_handler.setLevel(logging.DEBUG) 131 | else: 132 | stream_handler.setLevel(logging.INFO) 133 | self.logger.addHandler(stream_handler) 134 | 135 | def log_error(self, error, subname=None): 136 | if self.logger is None: 137 | self.init_logger() 138 | if subname is not None: 139 | logger = logging.getLogger("{}.{}".format(self.logger_main, subname)) 140 | logger.error(error) 141 | else: 142 | self.logger.error(error) 143 | 144 | def log_warn(self, warn, subname=None): 145 | if self.logger is None: 146 | self.init_logger() 147 | if subname is not None: 148 | logger = logging.getLogger("{}.{}".format(self.logger_main, subname)) 149 | logger.warning(warn) 150 | else: 151 | self.logger.warning(warn) 152 | 153 | def log_info(self, info, subname=None): 154 | if self.logger is None: 155 | self.init_logger() 156 | if subname is not None: 157 | logger = logging.getLogger("{}.{}".format(self.logger_main, subname)) 158 | logger.info(info) 159 | else: 160 | self.logger.info(info) 161 | 162 | def log_debug(self, debug, subname=None): 163 | if self.logger is None: 164 | self.init_logger() 165 | if subname is not None: 166 | logger = logging.getLogger("{}.{}".format(self.logger_main, subname)) 167 | logger.debug(debug) 168 | else: 169 | self.logger.debug(debug) 170 | 171 | def close(self): 172 | if self.writer is not None: 173 | self.writer.close() 174 | self.writer = None 175 | if self.logger is not None: 176 | logging.shutdown() 177 | self.logger = None 178 | 179 | def is_init_writer(self): 180 | return self.writer is not None 181 | 182 | def is_init_logger(self): 183 | return self.logger is not None 184 | 185 | def is_init(self): 186 | return self.is_init_writer() and self.is_init_logger() 187 | 188 | def __del__(self): 189 | self.close() 190 | -------------------------------------------------------------------------------- /main.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | import os 4 | os.environ["KMP_BLOCKTIME"] = "0" 5 | import numpy as np 6 | np.finfo(np.dtype("float32")) 7 | np.finfo(np.dtype("float64")) 8 | import cv2 9 | cv2.setNumThreads(0) 10 | cv2.ocl.setUseOpenCL(False) 11 | 12 | import time 13 | import torch 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | from torch.utils.data import DataLoader 17 | import torch.distributed as dist 18 | import torch.multiprocessing as mp 19 | 20 | from logger import Logger 21 | from argparser import ArgParser 22 | from trainer import Trainer 23 | from evaluate import evaluates 24 | from utils.datasets import fetch_dataset, fetch_test_dataset 25 | from utils.utils import setup_seed, count_parameters, count_all_parameters, build_module 26 | 27 | 28 | try: 29 | from torch.cuda.amp import GradScaler 30 | except: 31 | # dummy GradScaler for PyTorch < 1.6; accepts and ignores the enabled kwarg 32 | class GradScaler: 33 | def __init__(self, enabled=False): 34 | pass 35 | def scale(self, loss): 36 | return loss 37 | def unscale_(self, optimizer): 38 | pass 39 | def step(self, optimizer): 40 | optimizer.step() 41 | def update(self): 42 | pass 43 | 44 | 45 | def train(local_rank, args): 46 | 47 | args.local_rank = local_rank 48 | 49 | if args.local_rank == 0: 50 | logger = Logger(args) 51 | for k, v in vars(args).items(): 52 | logger.log_debug('{}\t=\t{}'.format(k, v), "argparser") 53 | _print = logger.log_info 54 | else: 55 | logger = None 56 | def print_line(line, subname=None): 57 | print(line) 58 | _print = print_line 59 | 60 | if args.distributed == 'ddp': 61 | dist.init_process_group(backend='nccl', init_method='tcp://{}:{}'.format(args.ip, args.port), world_size=args.nprocs, rank=local_rank) 62 | torch.cuda.set_device(args.local_rank) 63 | 64 | train_set, val_sets, val_setnames = fetch_dataset(args) 65 | if args.distributed == 'ddp': 66 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True, seed=args.seed, drop_last=False) 67 | train_loader = DataLoader(train_set, args.batch_size, num_workers=args.jobs, sampler=train_sampler) 68 | else: 69 | train_sampler = None 70 | train_loader = DataLoader(train_set, args.batch_size, shuffle=True, num_workers=args.jobs, \ 71 | pin_memory=True, drop_last=True, sampler=None) 72 | if args.distributed != 'ddp' or args.local_rank == 0: 73 | _print("Use training set {} with length: bs/loader/dataset ({}/{}({})/{})".format( \ 74 | args.stage, args.batch_size, len(train_loader), len(train_loader.dataset), len(train_set))) 75 | 76 | assert len(val_setnames) == len(val_sets) 77 | val_length_str = "" 78 | for val_set, name in zip(val_sets, val_setnames): 79 | val_length_str += "({}/{}),".format(name, len(val_set)) 80 | if args.distributed != 'ddp' or args.local_rank == 0: 81 | _print("Use validation set: test_bs={}, name/datalength:{}".format( \ 82 | args.test_batch_size, val_length_str)) 83 | 84 | model = build_module("core", args.model)(args) 85 | if args.distributed == 'ddp': 86 | model.cuda(args.gpus[args.local_rank]) 87 | model.train() 88 | model = torch.nn.parallel.DistributedDataParallel(model, \ 89 | device_ids=[args.gpus[args.local_rank]]) 90 | _print("Use DistributedDataParallel at gpu {} with find_unused_parameters:False".format( \ 91 | args.gpus[args.local_rank])) 92 | else: 93 | model = nn.DataParallel(model, device_ids=args.gpus) 94 | model.cuda(args.gpus[0]) 95 | model.train() 96 | 97 | loss = build_module("core.loss", "Combine")(args) 98 | if
args.distributed != 'ddp' or args.local_rank == 0: 99 | _print("Use losses: {} with weights: {}".format(args.loss, args.loss_weights)) 100 | 101 | metric_fun = build_module("core.metric", "Combine")(args) 102 | if args.distributed != 'ddp' or args.local_rank == 0: 103 | _print("Use metrics: {}".format(args.metric)) 104 | 105 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, \ 106 | eps=args.epsilon) 107 | if args.distributed != 'ddp' or args.local_rank == 0: 108 | _print("Use optimizer: {} with init lr:{}, decay:{}, epsilon:{} ".format( \ 109 | "AdamW", args.lr, args.weight_decay, args.epsilon)) 110 | 111 | lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, \ 112 | steps_per_epoch=len(train_loader), epochs=args.epoch) 113 | if args.distributed != 'ddp' or args.local_rank == 0: 114 | _print("Use scheduler: {}, with epoch:{}, steps_per_epoch {}".format( \ 115 | "OneCycleLR", args.epoch, len(train_loader))) 116 | 117 | scaler = GradScaler(enabled=args.mixed_precision) 118 | if args.distributed != 'ddp' or args.local_rank == 0: 119 | _print("Use gradscaler with mixed_precision? {}".format(args.mixed_precision)) 120 | 121 | trainer = Trainer(args, model, loss=loss, optimizer=optimizer, \ 122 | lr_scheduler=lr_scheduler, scaler=scaler, logger=logger) 123 | start = 0 124 | if args.checkpoint != '': 125 | start = trainer.load(args.checkpoint, only_model=False if args.resume else True) 126 | 127 | if args.distributed != 'ddp' or args.local_rank == 0: 128 | _print("For model {} with name {}, Parameter Count: {}(trainable)/{}(all), gpus: {}".format( \ 129 | args.model, args.name if args.name != "" else "NoNAME", count_parameters(trainer.model), \ 130 | count_all_parameters(trainer.model), args.gpus)) 131 | _print("Use small? 
{}".format(args.small)) 132 | 133 | setup_seed(args.seed) 134 | 135 | for i in range(start+1, args.epoch+1): 136 | if args.distributed != 'ddp' or args.local_rank == 0: 137 | _print(">>> Start the {}/{} training epoch with save feq {} at stage {}".format( \ 138 | i, args.epoch, args.save_feq, args.stage), "training") 139 | if train_sampler is not None: 140 | train_sampler.set_epoch(i) 141 | trainer.run_epoch(train_loader) 142 | if args.local_rank == 0 and logger is not None: 143 | logger.summary(i) 144 | 145 | if i % args.eval_feq == 0: 146 | if args.distributed != 'ddp' or args.local_rank == 0: 147 | _print(">>> Run {} evaluate epoch".format(i), "training") 148 | scores = evaluates(args, model, val_sets, val_setnames, metric_fun, logger=logger) 149 | if args.local_rank == 0 and logger is not None: 150 | logger.write_dict(i, scores) 151 | if args.local_rank == 0 and i % args.save_feq == 0: 152 | trainer.store(args.save_path, args.name, i) 153 | 154 | if args.local_rank == 0: 155 | dist.destroy_process_group() 156 | _print("Destroy_process_group", 'train') 157 | 158 | if logger is not None: 159 | logger.close() 160 | 161 | 162 | def test(local_rank, args, logger=None): 163 | 164 | args.local_rank = local_rank 165 | 166 | if logger is not None: 167 | _print = logger.log_info 168 | else: 169 | def print_line(line, subname=None): 170 | print(line) 171 | _print = print_line 172 | 173 | if args.distributed == 'ddp': 174 | dist.init_process_group(backend='nccl', init_method='tcp://{}:{}'.format(args.ip, args.port), world_size=args.nprocs, rank=local_rank) 175 | torch.cuda.set_device(args.local_rank) 176 | 177 | assert args.checkpoint != '' 178 | 179 | start = time.time() 180 | test_sets, test_setnames = fetch_test_dataset(args) 181 | 182 | assert len(test_setnames) == len(test_sets) 183 | test_length_str = "" 184 | for test_set, name in zip(test_sets, test_setnames): 185 | test_length_str += "({}/{}),".format(name, len(test_set)) 186 | 187 | if args.distributed != 'ddp' or args.local_rank == 0: 188 | _print("Use test set: test_bs={}, name/datalength:{}".format(args.test_batch_size, test_length_str), 'test') 189 | 190 | metric_fun = build_module("core.metric", "Combine")(args) 191 | if args.distributed != 'ddp' or args.local_rank == 0: 192 | _print("Use metrics: {}".format(args.metric), 'test') 193 | 194 | model = build_module("core", args.model)(args) 195 | 196 | if args.checkpoint != '': 197 | if args.distributed != 'ddp' or args.local_rank == 0: 198 | _print("Evalulate Model {} for checkpoint {}".format(args.model, args.checkpoint), 'test') 199 | _print("For model {} with name {}, Parameter Count: {}(trainable)/{}(all), gpus: {}".format( \ 200 | args.model, args.name if args.name != "" else "NoNAME", count_parameters(model), \ 201 | count_all_parameters(model), args.gpus)) 202 | 203 | state_dict = torch.load(args.checkpoint, map_location=torch.device("cpu")) 204 | try: 205 | if "model" in state_dict.keys(): 206 | state_dict = state_dict.pop("model") 207 | elif 'model_state_dict' in state_dict.keys(): 208 | state_dict = state_dict.pop("model_state_dict") 209 | 210 | if "module." 
in list(state_dict.keys())[0]: 211 | for key in list(state_dict.keys()): 212 | state_dict.update({key[7:]:state_dict.pop(key)}) 213 | 214 | model.load_state_dict(state_dict) 215 | except: 216 | raise KeyError("'model' not in or mismatch state_dict.keys(), please check checkpoint path {}".format(args.checkpoint)) 217 | else: 218 | raise NotImplementedError("Please set --checkpoint") 219 | 220 | if args.distributed == 'ddp': 221 | model.cuda(args.local_rank) 222 | model.eval() 223 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpus[local_rank]]) 224 | else: 225 | model = nn.DataParallel(model, device_ids=args.gpus) 226 | model.cuda(args.gpus[0]) 227 | model.eval() 228 | 229 | scores = evaluates(args, model, test_sets, test_setnames, metric_fun, logger=logger) 230 | 231 | summary_str = "" 232 | for key in scores.keys(): 233 | summary_str += "({}/{}),".format(key, scores[key]) 234 | 235 | if args.distributed == 'ddp' and args.local_rank == 0: 236 | dist.destroy_process_group() 237 | _print("Destroy_process_group", 'test') 238 | 239 | _print("Test complete, {}, time consuming {}/s".format(summary_str, time.time() - start), 'test') 240 | 241 | 242 | if __name__ == "__main__": 243 | argparser = ArgParser() 244 | args = argparser.parser() 245 | setup_seed(args.seed) 246 | 247 | if args.gpus[0] == -1: 248 | args.gpus = [i for i in range(torch.cuda.device_count())] 249 | args.nprocs = len(args.gpus) 250 | 251 | if args.task == "train": 252 | if args.distributed == 'ddp': 253 | mp.spawn(train, nprocs=args.nprocs, args=(args, )) 254 | else: 255 | train(-1, args) 256 | elif args.task[:4] == "test": 257 | if args.distributed == 'ddp': 258 | mp.spawn(test, nprocs=args.nprocs, args=(args, )) 259 | else: 260 | test(-1, args) 261 | else: 262 | print("task error") 263 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | import os 4 | from tqdm import tqdm 5 | import torch 6 | import torch.nn as nn 7 | import torch.distributed as dist 8 | 9 | from utils.utils import ensure_folder 10 | 11 | 12 | def reduce_list(lists, nprocs): 13 | new_lists = {} 14 | for key, value in lists.items(): 15 | rt = value.clone() 16 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 17 | rt /= nprocs 18 | new_lists[key] = rt.item() 19 | return new_lists 20 | 21 | 22 | def reduce_tensor(tensor, nprocs): 23 | rt = tensor.clone() 24 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 25 | rt /= nprocs 26 | return rt 27 | 28 | 29 | class Trainer: 30 | def __init__(self, args, model, loss=None, optimizer=None, logger=None, lr_scheduler=None, scaler=None): 31 | self.args = args 32 | self.model = model 33 | self.loss = loss 34 | self.optimizer = optimizer 35 | self.logger = logger 36 | self.lr_scheduler = lr_scheduler 37 | self.scaler = scaler 38 | 39 | if self.logger is None: 40 | def print_line(line, subname=None): 41 | if self.args.local_rank == 0: 42 | print(line) 43 | self.log_info = print_line 44 | else: 45 | self.log_info = self.logger.log_info 46 | 47 | def weight_fix(self, way, refer_dict=None): 48 | 49 | # fix weights 50 | if way == 'checkpoint': 51 | assert refer_dict is not None 52 | for n, p in self.model.named_parameters(): 53 | if n in refer_dict.keys(): 54 | p.requires_grad = False 55 | elif way == 'encoder': 56 | for n, p in self.model.named_parameters(): 57 | if 'fnet' in n or 'cnet' in n or 'enet' in n or 'fusion' in n: 58 |
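# ---- annotation: not part of the repository source ----
# Both main.test() above and the checkpoint loaders later in this class
# (partial_load / load) must cope with checkpoints saved from nn.DataParallel
# or DistributedDataParallel, whose parameter names carry a "module." prefix.
# The key-renaming loop in test() can equivalently be written as one dict
# comprehension (a sketch, not the repository's code):
#   state_dict = {k[len("module."):] if k.startswith("module.") else k: v
#                 for k, v in state_dict.items()}
#   model.load_state_dict(state_dict)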
p.requires_grad = False 59 | elif way == 'event': 60 | for n, p in self.model.named_parameters(): 61 | if 'enet' in n or 'fusion' in n: 62 | p.requires_grad = False 63 | elif way == 'eventencoder': 64 | for n, p in self.model.named_parameters(): 65 | if 'enet' in n: 66 | p.requires_grad = False 67 | elif way == 'eventfusion': 68 | for n, p in self.model.named_parameters(): 69 | if 'fusion' in n: 70 | p.requires_grad = False 71 | elif way == 'imageencoder': 72 | for n, p in self.model.named_parameters(): 73 | if 'fnet' in n or 'cnet' in n: 74 | p.requires_grad = False 75 | elif way == 'raft': 76 | for n, p in self.model.named_parameters(): 77 | if 'fnet' in n or 'cnet' in n or 'update_block' in n: 78 | p.requires_grad = False 79 | elif way == 'allencoder': 80 | for n, p in self.model.named_parameters(): 81 | if 'enet' in n or 'fusion' in n or 'fnet' in n or 'cnet' in n: 82 | p.requires_grad = False 83 | elif way == 'update': 84 | for n, p in self.model.named_parameters(): 85 | if 'update_block' in n: 86 | p.requires_grad = False 87 | 88 | self.log_info("Weight fix way: {} complete.".format(way if way != "" else "None"), "trainer") 89 | 90 | def partial_load(self, path, weight_fix=None, not_load=False): 91 | # partial parameters loading 92 | assert path != '' 93 | load_dict = torch.load(path, map_location=torch.device("cpu")) 94 | try: 95 | if "model" not in load_dict.keys(): 96 | pretrained_dict = {k: v for k, v in load_dict.items() if k in self.model.state_dict().keys() \ 97 | and k != 'module.update_block.encoder.conv.weight' \ 98 | and k != 'module.update_block.encoder.conv.bias' \ 99 | and not k.startswith('module.update_block.flow_enc')} 100 | else: 101 | pretrained_dict = {k: v for k, v in load_dict.pop("model").items() if k in self.model.state_dict().keys() \ 102 | and k != 'module.update_block.encoder.conv.weight' \ 103 | and k != 'module.update_block.encoder.conv.bias' \ 104 | and not k.startswith('module.update_block.flow_enc')} 105 | assert len(pretrained_dict.keys()) > 0 106 | if not not_load: 107 | self.model.load_state_dict(pretrained_dict, strict=False) 108 | self.log_info("Partial load model from {} complete.".format(path), "trainer") 109 | else: 110 | self.log_info("Partial load dict from {} only for weight fix, but not load to model.".format(path), "trainer") 111 | except: 112 | raise KeyError("'model' not in or mismatch state_dict.keys(), please check partial checkpoint path {}".format(path)) 113 | 114 | self.weight_fix(weight_fix, pretrained_dict) 115 | 116 | def load(self, path, only_model=True): 117 | assert path != '' 118 | state_dict = torch.load(path, map_location=torch.device("cpu")) 119 | try: 120 | if "model" not in state_dict.keys(): 121 | self.model.load_state_dict(state_dict) 122 | else: 123 | self.model.load_state_dict(state_dict.pop("model")) 124 | except: 125 | raise KeyError("'model' not in or mismatch state_dict.keys(), please check checkpoint path {}".format(path)) 126 | 127 | index = 0 128 | if not only_model: 129 | try: 130 | self.optimizer.load_state_dict(state_dict.pop("optimizer")) 131 | except: 132 | self.log_info("'optimizer' not in state_dict.keys(), skip it.", "trainer") 133 | 134 | try: 135 | self.lr_scheduler.load_state_dict(state_dict.pop("lr_scheduler")) 136 | except: 137 | self.log_info("'lr_scheduler' not in state_dict.keys(), skip it.", "trainer") 138 | 139 | try: 140 | index = state_dict.pop("index") 141 | except: 142 | self.log_info("'index' not in state_dict.keys(), set to 0.", "trainer") 143 | 144 | self.log_info("Load 
model/optimizer/index from {} complete, index {}".format(path, index), "trainer") 145 | else: 146 | self.log_info("Load model from {} complete, index {}".format(path, index), "trainer") 147 | 148 | return index 149 | 150 | def store(self, path, name, index=None): 151 | if path != "" and name != "": 152 | checkpoint = {} 153 | checkpoint["model"] = self.model.state_dict() 154 | checkpoint["optimizer"] = self.optimizer.state_dict() 155 | checkpoint["lr_scheduler"] = self.lr_scheduler.state_dict() 156 | checkpoint["index"] = index 157 | 158 | ensure_folder(path) 159 | save_path = os.path.join(path, "{}_{}.pth".format(name, checkpoint["index"])) 160 | torch.save(checkpoint, save_path) 161 | self.log_info("<<< Save model to {} complete".format(save_path), "trainer") 162 | 163 | def run_epoch(self, dataloader): 164 | self.model.train() 165 | 166 | if self.args.local_rank == 0: 167 | self.bar = tqdm(total=len(dataloader), position=0, leave=True) 168 | for index, batch in enumerate(dataloader): 169 | 170 | for key in batch.keys(): 171 | if torch.is_tensor(batch[key]): 172 | batch[key] = batch[key].cuda(self.args.gpus[self.args.local_rank] \ 173 | if self.args.local_rank != -1 else 0, non_blocking=True) 174 | 175 | # output = self.model(batch['img1'], batch['img2'], self.args.iters) 176 | output = self.model(batch, self.args.iters) 177 | loss = self.loss(output, batch) 178 | 179 | self.optimizer.zero_grad() 180 | 181 | torch.distributed.barrier() 182 | reduced_loss = reduce_list(loss, self.args.nprocs) 183 | 184 | self.scaler.scale(loss['loss']).backward() 185 | self.scaler.unscale_(self.optimizer) 186 | nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip) 187 | 188 | self.scaler.step(self.optimizer) 189 | self.lr_scheduler.step() 190 | self.scaler.update() 191 | 192 | if self.args.local_rank == 0: 193 | self.bar_update(reduced_loss) 194 | 195 | if self.logger is not None: 196 | self.logger.push(reduced_loss, 'loss', last=False) 197 | self.logger.push({'lr': self.optimizer.state_dict()['param_groups'][0]['lr']}) 198 | 199 | if self.args.local_rank == 0: 200 | self.bar.close() 201 | 202 | def bar_update(self, loss): 203 | loss_description = "" 204 | for data, key in zip(loss.values(), loss.keys()): 205 | loss_description += "{}:{:5.4f}, ".format(key, data) if 'px' not in key else "{}:{:4.3f}, ".format(key, data) 206 | self.bar.set_description(loss_description) 207 | self.bar.update(1) 208 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danqu130/DCEIFlow/f1c24c1199aa033cc09ab2979acaeecc0cf3a3f7/utils/__init__.py -------------------------------------------------------------------------------- /utils/augmentor/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_augmentor import FlowAugmentor as ImageFlowAugmentor 2 | from .image_augmentor import SparseFlowAugmentor as ImageSparseFlowAugmentor 3 | from .event_augmentor import EventFlowAugmentor as EventFlowAugmentor 4 | from .event_augmentor import SparseEventFlowAugmentor as EventSparseFlowAugmentor 5 | 6 | 7 | def fetch_augmentor(is_event=True, is_sparse=False, aug_params=None): 8 | if is_event: 9 | if is_sparse: 10 | return EventSparseFlowAugmentor(**aug_params) 11 | else: 12 | return EventFlowAugmentor(**aug_params) 13 | else: 14 | if is_sparse: 15 | return ImageSparseFlowAugmentor(**aug_params) 16 | else: 17 | 
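# ---- annotation: not part of the repository source ----
# fetch_augmentor() above dispatches on two flags: is_event selects the
# event-aware augmentors below (which also transform the event voxel grid),
# and is_sparse selects the variants that carry validity masks for sparse
# ground truth. Hypothetical usage mirroring the dataset code (the crop size
# here is made up for illustration):
#   aug_params = {'crop_size': (368, 496), 'min_scale': -0.2,
#                 'max_scale': 0.5, 'do_flip': False}
#   augmentor = fetch_augmentor(is_event=True, is_sparse=False,
#                               aug_params=aug_params)
#   event, img1, img2, flow, flow10, occ, occ10, event_r = augmentor(
#       event, img1, img2, flow)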
return ImageFlowAugmentor(**aug_params) 18 | -------------------------------------------------------------------------------- /utils/augmentor/event_augmentor.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core/utils') 3 | import cv2 4 | cv2.setNumThreads(0) 5 | cv2.ocl.setUseOpenCL(False) 6 | 7 | import torch 8 | import numpy as np 9 | from PIL import Image 10 | from torchvision.transforms import ColorJitter 11 | 12 | 13 | def resize_flow(flow, des_height, des_width, method='bilinear'): 14 | # improper for sparse flow 15 | src_height = flow.shape[1] 16 | src_width = flow.shape[2] 17 | if src_width == des_width and src_height == des_height: 18 | return flow 19 | ratio_height = float(des_height) / float(src_height) 20 | ratio_width = float(des_width) / float(src_width) 21 | 22 | flow = np.transpose(flow, (1, 2, 0)) 23 | if method == 'bilinear': 24 | flow = cv2.resize( 25 | flow, (des_width, des_height), interpolation=cv2.INTER_LINEAR) 26 | elif method == 'nearest': 27 | flow = cv2.resize( 28 | flow, (des_width, des_height), interpolation=cv2.INTER_NEAREST) 29 | else: 30 | raise Exception('Invalid resize flow method!') 31 | flow = np.transpose(flow, (2, 0, 1)) 32 | 33 | flow[0, :, :] = flow[0, :, :] * ratio_width 34 | flow[1, :, :] = flow[1, :, :] * ratio_height 35 | return flow 36 | 37 | 38 | def horizontal_flip_flow(flow): 39 | flow = np.transpose(flow, (1, 2, 0)) 40 | flow = np.copy(np.fliplr(flow)) 41 | flow = np.transpose(flow, (2, 0, 1)) 42 | flow[0, :, :] *= -1 43 | return flow 44 | 45 | 46 | def vertical_flip_flow(flow): 47 | flow = np.transpose(flow, (1, 2, 0)) 48 | flow = np.copy(np.flipud(flow)) 49 | flow = np.transpose(flow, (2, 0, 1)) 50 | flow[1, :, :] *= -1 51 | return flow 52 | 53 | 54 | def remove_ambiguity_flow(flow_img, err_img, threshold_err=10.0): 55 | thre_flow = flow_img 56 | mask_img = np.ones(err_img.shape, dtype=np.uint8) 57 | mask_img[err_img > threshold_err] = 0.0 58 | thre_flow[err_img > threshold_err] = 0.0 59 | return thre_flow, mask_img 60 | 61 | 62 | class EventFlowAugmentor: 63 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False, spatial_aug_prob=0.8): 64 | # spatial augmentation params 65 | self.crop_size = crop_size 66 | self.min_scale = min_scale 67 | self.max_scale = max_scale 68 | self.spatial_aug_prob = spatial_aug_prob 69 | self.stretch_prob = 0.8 70 | self.max_stretch = 0.2 71 | 72 | # flip augmentation params 73 | self.do_flip = do_flip 74 | self.h_flip_prob = 0.5 75 | self.v_flip_prob = 0.1 76 | 77 | # photometric augmentation params 78 | self.photo_aug = ColorJitter( 79 | brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) 80 | self.asymmetric_color_aug_prob = 0.2 81 | # self.eraser_aug_prob = 0.5 82 | 83 | def color_transform(self, img1, img2): 84 | """ Photometric augmentation """ 85 | 86 | # asymmetric 87 | if torch.FloatTensor(1).uniform_(0, 1).item() < self.asymmetric_color_aug_prob: 88 | img1 = np.array(self.photo_aug( 89 | Image.fromarray(img1)), dtype=np.uint8) 90 | img2 = np.array(self.photo_aug( 91 | Image.fromarray(img2)), dtype=np.uint8) 92 | 93 | # symmetric 94 | else: 95 | image_stack = np.concatenate([img1, img2], axis=0) 96 | image_stack = np.array(self.photo_aug( 97 | Image.fromarray(image_stack)), dtype=np.uint8) 98 | img1, img2 = np.split(image_stack, 2, axis=0) 99 | 100 | return img1, img2 101 | 102 | def spatial_transform(self, event, img1, img2, flow, flow10=None, occ=None, occ10=None, event_r=None): 103 | 104 | if 
self.do_flip: 105 | if torch.FloatTensor(1).uniform_(0, 1).item() < self.h_flip_prob: # h-flip 106 | event = event[:, :, ::-1] 107 | img1 = img1[:, ::-1] 108 | img2 = img2[:, ::-1] 109 | flow = flow[:, ::-1] * [-1.0, 1.0] 110 | if flow10 is not None: 111 | flow10 = flow10[:, ::-1] * [-1.0, 1.0] 112 | if occ is not None: 113 | occ = occ[:, ::-1] 114 | if occ10 is not None: 115 | occ10 = occ10[:, ::-1] 116 | if event_r is not None: 117 | event_r = event_r[:, :, ::-1] 118 | 119 | if torch.FloatTensor(1).uniform_(0, 1).item() < self.v_flip_prob: # v-flip 120 | event = event[:, ::-1, :] 121 | img1 = img1[::-1, :] 122 | img2 = img2[::-1, :] 123 | flow = flow[::-1, :] * [1.0, -1.0] 124 | if flow10 is not None: 125 | flow10 = flow10[::-1, :] * [1.0, -1.0] 126 | if occ is not None: 127 | occ = occ[::-1, :] 128 | if occ10 is not None: 129 | occ10 = occ10[::-1, :] 130 | if event_r is not None: 131 | event_r = event_r[:, ::-1, :] 132 | 133 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 134 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 135 | 136 | event = event[:, y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 137 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 138 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 139 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 140 | if flow10 is not None: 141 | flow10 = flow10[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 142 | if occ is not None: 143 | occ = occ[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 144 | if occ10 is not None: 145 | occ10 = occ10[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 146 | if event_r is not None: 147 | event_r = event_r[:, y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 148 | 149 | return event, img1, img2, flow, flow10, occ, occ10, event_r 150 | 151 | def __call__(self, event, img1, img2, flow, flow10=None, occ=None, occ10=None, event_r=None): 152 | img1, img2 = self.color_transform(img1, img2) 153 | event, img1, img2, flow, flow10, occ, occ10, event_r = self.spatial_transform(\ 154 | event, img1, img2, flow, flow10, occ, occ10, event_r) 155 | 156 | event = np.ascontiguousarray(event) 157 | img1 = np.ascontiguousarray(img1) 158 | img2 = np.ascontiguousarray(img2) 159 | flow = np.ascontiguousarray(flow) 160 | if flow10 is not None: 161 | flow10 = np.ascontiguousarray(flow10) 162 | if occ is not None: 163 | occ = np.ascontiguousarray(occ) 164 | if occ10 is not None: 165 | occ10 = np.ascontiguousarray(occ10) 166 | if event_r is not None: 167 | event_r = np.ascontiguousarray(event_r) 168 | 169 | return event, img1, img2, flow, flow10, occ, occ10, event_r 170 | 171 | 172 | # TODO SparseEventFlowAugmentor 173 | class SparseEventFlowAugmentor: 174 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False, spatial_aug_prob=0.8): 175 | self.crop_size = crop_size 176 | self.min_scale = min_scale 177 | self.max_scale = max_scale 178 | self.spatial_aug_prob = spatial_aug_prob 179 | self.stretch_prob = 0.8 180 | self.max_stretch = 0.2 181 | 182 | # flip augmentation params 183 | self.do_flip = do_flip 184 | self.h_flip_prob = 0.5 185 | self.v_flip_prob = 0.1 186 | 187 | # photometric augmentation params 188 | self.photo_aug = ColorJitter( 189 | brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) 190 | self.asymmetric_color_aug_prob = 0.2 191 | # self.eraser_aug_prob = 0.5 192 | 193 | def color_transform(self, img1, img2): 194 | """ Photometric augmentation """ 195 | 196 | # asymmetric 197 | if 
torch.FloatTensor(1).uniform_(0, 1).item() < self.asymmetric_color_aug_prob: 198 | img1 = np.array(self.photo_aug( 199 | Image.fromarray(img1)), dtype=np.uint8) 200 | img2 = np.array(self.photo_aug( 201 | Image.fromarray(img2)), dtype=np.uint8) 202 | 203 | # symmetric 204 | else: 205 | image_stack = np.concatenate([img1, img2], axis=0) 206 | image_stack = np.array(self.photo_aug( 207 | Image.fromarray(image_stack)), dtype=np.uint8) 208 | img1, img2 = np.split(image_stack, 2, axis=0) 209 | 210 | return img1, img2 211 | 212 | def spatial_transform(self, event, img1, img2, flow, valid=None, flow10=None, valid10=None): 213 | if self.do_flip: 214 | if torch.FloatTensor(1).uniform_(0, 1).item() < self.h_flip_prob: # h-flip 215 | event = event[:, :, ::-1] 216 | img1 = img1[:, ::-1] 217 | img2 = img2[:, ::-1] 218 | flow = flow[:, ::-1] * [-1.0, 1.0] 219 | if valid is not None: 220 | valid = valid[:, ::-1] 221 | if flow10 is not None and valid10 is not None: 222 | flow10 = flow10[:, ::-1] * [-1.0, 1.0] 223 | valid10 = valid10[:, ::-1] 224 | 225 | if torch.FloatTensor(1).uniform_(0, 1).item() < self.v_flip_prob: # v-flip 226 | event = event[:, ::-1, :] 227 | img1 = img1[::-1, :] 228 | img2 = img2[::-1, :] 229 | flow = flow[::-1, :] * [1.0, -1.0] 230 | if valid is not None: 231 | valid = valid[::-1, :] 232 | if flow10 is not None and valid10 is not None: 233 | flow10 = flow10[::-1, :] * [1.0, -1.0] 234 | valid10 = valid10[::-1, :] 235 | 236 | if img1.shape[0] == self.crop_size[0] or img1.shape[1] == self.crop_size[1]: 237 | pass 238 | else: 239 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 240 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 241 | 242 | event = event[:, y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 243 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 244 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 245 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 246 | 247 | if valid is not None: 248 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 249 | if flow10 is not None and valid10 is not None: 250 | flow10 = flow10[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 251 | valid10 = valid10[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 252 | 253 | return event, img1, img2, flow, valid, flow10, valid10 254 | 255 | def __call__(self, event, img1, img2, flow, valid=None, flow10=None, valid10=None): 256 | 257 | img1, img2 = self.color_transform(img1, img2) 258 | event, img1, img2, flow, valid, flow10, valid10 = self.spatial_transform(\ 259 | event, img1, img2, flow, valid, flow10, valid10) 260 | 261 | event = np.ascontiguousarray(event) 262 | img1 = np.ascontiguousarray(img1) 263 | img2 = np.ascontiguousarray(img2) 264 | flow = np.ascontiguousarray(flow) 265 | 266 | if valid is not None: 267 | valid = np.ascontiguousarray(valid) 268 | if flow10 is not None: 269 | flow10 = np.ascontiguousarray(flow10) 270 | if valid10 is not None: 271 | valid10 = np.ascontiguousarray(valid10) 272 | 273 | return event, img1, img2, flow, valid, flow10, valid10 274 | -------------------------------------------------------------------------------- /utils/augmentor/image_augmentor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | import torch.nn.functional as F 5 | from torchvision.transforms import ColorJitter 6 | 7 | import sys 8 | sys.path.append('core/utils') 9 | import cv2 10 | cv2.setNumThreads(0) 11 |
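# ---- annotation: not part of the repository source ----
# Both augmentor files share the same photometric trick: with probability
# asymmetric_color_aug_prob the two frames are jittered independently;
# otherwise they are stacked and jittered once, so brightness/contrast changes
# stay consistent across the pair. A runnable toy version of the symmetric
# path:
#   import numpy as np
#   from PIL import Image
#   from torchvision.transforms import ColorJitter
#   jitter = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4,
#                        hue=0.5 / 3.14)
#   img1 = np.zeros((64, 64, 3), dtype=np.uint8)
#   img2 = np.full((64, 64, 3), 128, dtype=np.uint8)
#   stack = np.concatenate([img1, img2], axis=0)
#   stack = np.array(jitter(Image.fromarray(stack)), dtype=np.uint8)
#   img1, img2 = np.split(stack, 2, axis=0)      # identical jitter on both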
cv2.ocl.setUseOpenCL(False) 12 | 13 | 14 | class FlowAugmentor: 15 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): 16 | 17 | # spatial augmentation params 18 | self.crop_size = crop_size 19 | self.min_scale = min_scale 20 | self.max_scale = max_scale 21 | self.spatial_aug_prob = 0.8 22 | self.stretch_prob = 0.8 23 | self.max_stretch = 0.2 24 | 25 | # flip augmentation params 26 | self.do_flip = do_flip 27 | self.h_flip_prob = 0.5 28 | self.v_flip_prob = 0.1 29 | 30 | # photometric augmentation params 31 | self.photo_aug = ColorJitter( 32 | brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) 33 | self.asymmetric_color_aug_prob = 0.2 34 | self.eraser_aug_prob = 0.5 35 | 36 | def color_transform(self, img1, img2): 37 | """ Photometric augmentation """ 38 | 39 | # asymmetric 40 | if torch.FloatTensor(1).uniform_(0, 1).item() < self.asymmetric_color_aug_prob: 41 | img1 = np.array(self.photo_aug( 42 | Image.fromarray(img1)), dtype=np.uint8) 43 | img2 = np.array(self.photo_aug( 44 | Image.fromarray(img2)), dtype=np.uint8) 45 | 46 | # symmetric 47 | else: 48 | image_stack = np.concatenate([img1, img2], axis=0) 49 | image_stack = np.array(self.photo_aug( 50 | Image.fromarray(image_stack)), dtype=np.uint8) 51 | img1, img2 = np.split(image_stack, 2, axis=0) 52 | 53 | return img1, img2 54 | 55 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 56 | """ Occlusion augmentation """ 57 | 58 | ht, wd = img1.shape[:2] 59 | if torch.FloatTensor(1).uniform_(0, 1).item() < self.eraser_aug_prob: 60 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 61 | for _ in range(np.random.randint(1, 3)): 62 | x0 = np.random.randint(0, wd) 63 | y0 = np.random.randint(0, ht) 64 | dx = np.random.randint(bounds[0], bounds[1]) 65 | dy = np.random.randint(bounds[0], bounds[1]) 66 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 67 | 68 | return img1, img2 69 | 70 | def spatial_transform(self, img1, img2, flow, flow10): 71 | # randomly sample scale 72 | ht, wd = img1.shape[:2] 73 | min_scale = np.maximum( 74 | (self.crop_size[0] + 8) / float(ht), 75 | (self.crop_size[1] + 8) / float(wd)) 76 | 77 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 78 | scale_x = scale 79 | scale_y = scale 80 | if torch.FloatTensor(1).uniform_(0, 1).item() < self.stretch_prob: 81 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, 82 | self.max_stretch) 83 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, 84 | self.max_stretch) 85 | 86 | scale_x = np.clip(scale_x, min_scale, None) 87 | scale_y = np.clip(scale_y, min_scale, None) 88 | 89 | if torch.FloatTensor(1).uniform_(0, 1).item() < self.spatial_aug_prob: 90 | # rescale the images 91 | img1 = cv2.resize(img1, None, fx=scale_x, 92 | fy=scale_y, interpolation=cv2.INTER_LINEAR) 93 | img2 = cv2.resize(img2, None, fx=scale_x, 94 | fy=scale_y, interpolation=cv2.INTER_LINEAR) 95 | flow = cv2.resize(flow, None, fx=scale_x, 96 | fy=scale_y, interpolation=cv2.INTER_LINEAR) 97 | flow = flow * [scale_x, scale_y] 98 | if flow10 is not None: 99 | flow10 = cv2.resize(flow10, None, fx=scale_x, 100 | fy=scale_y, interpolation=cv2.INTER_LINEAR) 101 | flow10 = flow10 * [scale_x, scale_y] 102 | 103 | if self.do_flip: 104 | if torch.FloatTensor(1).uniform_(0, 1).item() < self.h_flip_prob: # h-flip 105 | img1 = img1[:, ::-1] 106 | img2 = img2[:, ::-1] 107 | flow = flow[:, ::-1] * [-1.0, 1.0] 108 | if flow10 is not None: 109 | flow10 = flow10[:, ::-1] * [-1.0, 1.0] 110 | 111 | if torch.FloatTensor(1).uniform_(0, 1).item() < self.v_flip_prob: # v-flip 112 
| img1 = img1[::-1, :] 113 | img2 = img2[::-1, :] 114 | flow = flow[::-1, :] * [1.0, -1.0] 115 | if flow10 is not None: 116 | flow10 = flow10[::-1, :] * [1.0, -1.0] 117 | 118 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 119 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 120 | 121 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 122 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 123 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 124 | if flow10 is not None: 125 | flow10 = flow10[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 126 | 127 | return img1, img2, flow, flow10 128 | 129 | def __call__(self, img1, img2, flow, flow10=None): 130 | img1, img2 = self.color_transform(img1, img2) 131 | img1, img2 = self.eraser_transform(img1, img2) 132 | img1, img2, flow, flow10 = self.spatial_transform(img1, img2, flow, flow10) 133 | 134 | img1 = np.ascontiguousarray(img1) 135 | img2 = np.ascontiguousarray(img2) 136 | flow = np.ascontiguousarray(flow) 137 | if flow10 is not None: 138 | flow10 = np.ascontiguousarray(flow10) 139 | 140 | return img1, img2, flow, flow10 141 | 142 | 143 | class SparseFlowAugmentor: 144 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False, spatial_aug_prob=0.8): 145 | # spatial augmentation params 146 | self.crop_size = crop_size 147 | self.min_scale = min_scale 148 | self.max_scale = max_scale 149 | self.spatial_aug_prob = spatial_aug_prob 150 | self.stretch_prob = 0.8 151 | self.max_stretch = 0.2 152 | 153 | # flip augmentation params 154 | self.do_flip = do_flip 155 | self.h_flip_prob = 0.5 156 | self.v_flip_prob = 0.1 157 | 158 | # photometric augmentation params 159 | self.photo_aug = ColorJitter( 160 | brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) 161 | self.asymmetric_color_aug_prob = 0.2 162 | self.eraser_aug_prob = 0.5 163 | 164 | def color_transform(self, img1, img2): 165 | image_stack = np.concatenate([img1, img2], axis=0) 166 | image_stack = np.array(self.photo_aug( 167 | Image.fromarray(image_stack)), dtype=np.uint8) 168 | img1, img2 = np.split(image_stack, 2, axis=0) 169 | return img1, img2 170 | 171 | def eraser_transform(self, img1, img2): 172 | ht, wd = img1.shape[:2] 173 | if torch.FloatTensor(1).uniform_(0, 1).item() < self.eraser_aug_prob: 174 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 175 | for _ in range(np.random.randint(1, 3)): 176 | x0 = np.random.randint(0, wd) 177 | y0 = np.random.randint(0, ht) 178 | dx = np.random.randint(50, 100) 179 | dy = np.random.randint(50, 100) 180 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 181 | 182 | return img1, img2 183 | 184 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 185 | ht, wd = flow.shape[:2] 186 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 187 | coords = np.stack(coords, axis=-1) 188 | 189 | coords = coords.reshape(-1, 2).astype(np.float32) 190 | flow = flow.reshape(-1, 2).astype(np.float32) 191 | valid = valid.reshape(-1).astype(np.float32) 192 | 193 | coords0 = coords[valid >= 1] 194 | flow0 = flow[valid >= 1] 195 | 196 | ht1 = int(round(ht * fy)) 197 | wd1 = int(round(wd * fx)) 198 | 199 | coords1 = coords0 * [fx, fy] 200 | flow1 = flow0 * [fx, fy] 201 | 202 | xx = np.round(coords1[:, 0]).astype(np.int32) 203 | yy = np.round(coords1[:, 1]).astype(np.int32) 204 | 205 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 206 | xx = xx[v] 207 | yy = yy[v] 208 | flow1 = flow1[v] 209 | 210 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 211 | valid_img 
= np.zeros([ht1, wd1], dtype=np.int32) 212 | 213 | flow_img[yy, xx] = flow1 214 | valid_img[yy, xx] = 1 215 | 216 | return flow_img, valid_img 217 | 218 | def spatial_transform(self, img1, img2, flow, valid, flow10=None, valid10=None): 219 | # randomly sample scale 220 | 221 | ht, wd = img1.shape[:2] 222 | min_scale = np.maximum( 223 | (self.crop_size[0] + 1) / float(ht), 224 | (self.crop_size[1] + 1) / float(wd)) 225 | 226 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 227 | scale_x = np.clip(scale, min_scale, None) 228 | scale_y = np.clip(scale, min_scale, None) 229 | 230 | if torch.FloatTensor(1).uniform_(0, 1).item() < self.spatial_aug_prob or img1.shape[0] - self.crop_size[0] < 0: 231 | # rescale the images 232 | img1 = cv2.resize(img1, None, fx=scale_x, 233 | fy=scale_y, interpolation=cv2.INTER_LINEAR) 234 | img2 = cv2.resize(img2, None, fx=scale_x, 235 | fy=scale_y, interpolation=cv2.INTER_LINEAR) 236 | flow, valid = self.resize_sparse_flow_map( 237 | flow, valid, fx=scale_x, fy=scale_y) 238 | if flow10 is not None and valid10 is not None: 239 | flow10, valid10 = self.resize_sparse_flow_map( 240 | flow10, valid10, fx=scale_x, fy=scale_y) 241 | 242 | if self.do_flip: 243 | if torch.FloatTensor(1).uniform_(0, 1).item() < 0.5: # h-flip 244 | img1 = img1[:, ::-1] 245 | img2 = img2[:, ::-1] 246 | flow = flow[:, ::-1] * [-1.0, 1.0] 247 | valid = valid[:, ::-1] 248 | if flow10 is not None and valid10 is not None: 249 | flow10 = flow10[:, ::-1] * [-1.0, 1.0] 250 | valid10 = valid10[:, ::-1] 251 | 252 | margin_y = 20 253 | margin_x = 50 254 | 255 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 256 | x0 = np.random.randint(-margin_x, 257 | img1.shape[1] - self.crop_size[1] + margin_x) 258 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 259 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 260 | 261 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 262 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 263 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 264 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 265 | if flow10 is not None and valid10 is not None: 266 | flow10 = flow10[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 267 | valid10 = valid10[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 268 | return img1, img2, flow, valid, flow10, valid10 269 | 270 | def __call__(self, img1, img2, flow, valid, flow10=None, valid10=None): 271 | img1, img2 = self.color_transform(img1, img2) 272 | img1, img2 = self.eraser_transform(img1, img2) 273 | img1, img2, flow, valid, flow10, valid10 = self.spatial_transform( 274 | img1, img2, flow, valid, flow10, valid10) 275 | 276 | img1 = np.ascontiguousarray(img1) 277 | img2 = np.ascontiguousarray(img2) 278 | flow = np.ascontiguousarray(flow) 279 | valid = np.ascontiguousarray(valid) 280 | if flow10 is not None and valid10 is not None: 281 | flow10 = np.ascontiguousarray(flow10) 282 | valid10 = np.ascontiguousarray(valid10) 283 | 284 | return img1, img2, flow, valid, flow10, valid10 285 | -------------------------------------------------------------------------------- /utils/datasets/FlyingChairs2.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["KMP_BLOCKTIME"] = "0" 3 | import sys 4 | sys.path.append('utils') 5 | sys.path.append('utils/datasets') 6 | sys.path.append('utils/augmentor') 7 | import cv2 8 | cv2.setNumThreads(0) 9 | cv2.ocl.setUseOpenCL(False) 10 | 11 | import 
numpy as np 12 | from glob import glob 13 | import torch 14 | import torch.utils.data.dataset as dataset 15 | 16 | from utils.event_uitls import eventsToVoxel 17 | from utils.file_io import read_gen, readDenseFlow, read_event_h5 18 | from utils.augmentor import fetch_augmentor 19 | 20 | 21 | FlyingChairs2_BAD_ID = [ 22 | '0000114', 23 | '0000163', 24 | '0000491', 25 | '0000621', 26 | '0000107', 27 | '0011516', 28 | '0011949', 29 | '0019593', 30 | '0013451', 31 | '0006500', 32 | '0019693', 33 | '0009912', 34 | '0016755', 35 | '0016809', 36 | '0011031', 37 | '0001888', 38 | '0001535', 39 | '0002853', 40 | '0009141', 41 | '0009677', 42 | '0016628', 43 | '0003666', 44 | '0008214', 45 | '0012774', 46 | '0007896', 47 | '0012890', 48 | '0011034', 49 | '0016447', 50 | '0002242', 51 | '0013501', 52 | '0012985', 53 | '0014770', 54 | '0018237', 55 | '0019582', 56 | '0019767', ] 57 | 58 | 59 | VALIDATE_INDICES = [ 60 | 5, 17, 42, 45, 58, 62, 96, 111, 117, 120, 121, 131, 132, 61 | 152, 160, 248, 263, 264, 291, 293, 295, 299, 316, 320, 336, 62 | 337, 343, 358, 399, 401, 429, 438, 468, 476, 494, 509, 528, 63 | 531, 572, 581, 583, 588, 593, 681, 688, 696, 714, 767, 786, 64 | 810, 825, 836, 841, 883, 917, 937, 942, 970, 974, 980, 1016, 65 | 1043, 1064, 1118, 1121, 1133, 1153, 1155, 1158, 1159, 1173, 66 | 1187, 1219, 1237, 1238, 1259, 1266, 1278, 1296, 1354, 1378, 67 | 1387, 1494, 1508, 1518, 1574, 1601, 1614, 1668, 1673, 1699, 68 | 1712, 1714, 1737, 1841, 1872, 1879, 1901, 1921, 1934, 1961, 69 | 1967, 1978, 2018, 2030, 2039, 2043, 2061, 2113, 2204, 2216, 70 | 2236, 2250, 2274, 2292, 2310, 2342, 2359, 2374, 2382, 2399, 71 | 2415, 2419, 2483, 2502, 2504, 2576, 2589, 2590, 2622, 2624, 72 | 2636, 2651, 2655, 2658, 2659, 2664, 2672, 2706, 2707, 2709, 73 | 2725, 2732, 2761, 2827, 2864, 2866, 2905, 2922, 2929, 2966, 74 | 2972, 2993, 3010, 3025, 3031, 3040, 3041, 3070, 3113, 3124, 75 | 3129, 3137, 3141, 3157, 3183, 3206, 3219, 3247, 3253, 3272, 76 | 3276, 3321, 3328, 3333, 3338, 3341, 3346, 3351, 3396, 3419, 77 | 3430, 3433, 3448, 3455, 3463, 3503, 3526, 3529, 3537, 3555, 78 | 3577, 3584, 3591, 3594, 3597, 3603, 3613, 3615, 3670, 3676, 79 | 3678, 3697, 3723, 3728, 3734, 3745, 3750, 3752, 3779, 3782, 80 | 3813, 3817, 3819, 3854, 3885, 3944, 3947, 3970, 3985, 4011, 81 | 4022, 4071, 4075, 4132, 4158, 4167, 4190, 4194, 4207, 4246, 82 | 4249, 4298, 4307, 4317, 4318, 4319, 4320, 4382, 4399, 4401, 83 | 4407, 4416, 4423, 4484, 4491, 4493, 4517, 4525, 4538, 4578, 84 | 4606, 4609, 4620, 4623, 4637, 4646, 4662, 4668, 4716, 4739, 85 | 4747, 4770, 4774, 4776, 4785, 4800, 4845, 4863, 4891, 4904, 86 | 4922, 4925, 4956, 4963, 4964, 4994, 5011, 5019, 5036, 5038, 87 | 5041, 5055, 5118, 5122, 5130, 5162, 5164, 5178, 5196, 5227, 88 | 5266, 5270, 5273, 5279, 5299, 5310, 5314, 5363, 5375, 5384, 89 | 5393, 5414, 5417, 5433, 5448, 5494, 5505, 5509, 5525, 5566, 90 | 5581, 5602, 5609, 5620, 5653, 5670, 5678, 5690, 5700, 5703, 91 | 5724, 5752, 5765, 5803, 5811, 5860, 5881, 5895, 5912, 5915, 92 | 5940, 5952, 5966, 5977, 5988, 6007, 6037, 6061, 6069, 6080, 93 | 6111, 6127, 6146, 6161, 6166, 6168, 6178, 6182, 6190, 6220, 94 | 6235, 6253, 6270, 6343, 6372, 6379, 6410, 6411, 6442, 6453, 95 | 6481, 6498, 6500, 6509, 6532, 6541, 6543, 6560, 6576, 6580, 96 | 6594, 6595, 6609, 6625, 6629, 6644, 6658, 6673, 6680, 6698, 97 | 6699, 6702, 6705, 6741, 6759, 6785, 6792, 6794, 6809, 6810, 98 | 6830, 6838, 6869, 6871, 6889, 6925, 6995, 7003, 7026, 7029, 99 | 7080, 7082, 7097, 7102, 7116, 7165, 7200, 7232, 7271, 7282, 100 | 7324, 7333, 7335, 7372, 7387, 
7407, 7472, 7474, 7482, 7489, 101 | 7499, 7516, 7533, 7536, 7566, 7620, 7654, 7691, 7704, 7722, 102 | 7746, 7750, 7773, 7806, 7821, 7827, 7851, 7873, 7880, 7884, 103 | 7904, 7912, 7948, 7964, 7965, 7984, 7989, 7992, 8035, 8050, 104 | 8074, 8091, 8094, 8113, 8116, 8151, 8159, 8171, 8179, 8194, 105 | 8195, 8239, 8263, 8290, 8295, 8312, 8367, 8374, 8387, 8407, 106 | 8437, 8439, 8518, 8556, 8588, 8597, 8601, 8651, 8657, 8723, 107 | 8759, 8763, 8785, 8802, 8813, 8826, 8854, 8856, 8866, 8918, 108 | 8922, 8923, 8932, 8958, 8967, 9003, 9018, 9078, 9095, 9104, 109 | 9112, 9129, 9147, 9170, 9171, 9197, 9200, 9249, 9253, 9270, 110 | 9282, 9288, 9295, 9321, 9323, 9324, 9347, 9399, 9403, 9417, 111 | 9426, 9427, 9439, 9468, 9486, 9496, 9511, 9516, 9518, 9529, 112 | 9557, 9563, 9564, 9584, 9586, 9591, 9599, 9600, 9601, 9632, 113 | 9654, 9667, 9678, 9696, 9716, 9723, 9740, 9820, 9824, 9825, 114 | 9828, 9863, 9866, 9868, 9889, 9929, 9938, 9953, 9967, 10019, 115 | 10020, 10025, 10059, 10111, 10118, 10125, 10174, 10194, 116 | 10201, 10202, 10220, 10221, 10226, 10242, 10250, 10276, 117 | 10295, 10302, 10305, 10327, 10351, 10360, 10369, 10393, 118 | 10407, 10438, 10455, 10463, 10465, 10470, 10478, 10503, 119 | 10508, 10509, 10809, 11080, 11331, 11607, 11610, 11864, 120 | 12390, 12393, 12396, 12399, 12671, 12921, 12930, 13178, 121 | 13453, 13717, 14499, 14517, 14775, 15297, 15556, 15834, 122 | 15839, 16126, 16127, 16386, 16633, 16644, 16651, 17166, 123 | 17169, 17958, 17959, 17962, 18224, 21176, 21180, 21190, 124 | 21802, 21803, 21806, 22584, 22857, 22858, 22866] 125 | 126 | 127 | class FlyingChairs2(dataset.Dataset): 128 | def __init__(self, args, data_root, data_kind='train', aug_params=None): 129 | super().__init__() 130 | self.args = args 131 | self.event_bins = args.event_bins 132 | self.event_polarity = False if args.no_event_polarity else True 133 | self.data_root = data_root 134 | 135 | if data_kind[:5] == 'train': 136 | self.data_split = 'train' 137 | if len(data_kind) > 5: 138 | self.data_mode = data_kind[5:] 139 | else: 140 | self.data_mode = 'train' 141 | elif data_kind[:3] == 'val': 142 | self.data_split = 'val' 143 | self.data_mode = 'full' 144 | else: 145 | raise NotImplementedError( 146 | "Unsupported data kind {}".format(data_kind)) 147 | 148 | self.augmentor = None 149 | if aug_params is not None: 150 | self.augmentor = fetch_augmentor( 151 | is_event=True, is_sparse=False, aug_params=aug_params) 152 | 153 | self.fetch_valids() 154 | self.data_length = len(self.image1_filenames) 155 | 156 | def fetch_valids(self): 157 | 158 | images_root = os.path.join(self.data_root, self.data_split) 159 | events_root = os.path.join(self.data_root, "events_" + self.data_split) 160 | 161 | image1_filenames = sorted( 162 | glob(os.path.join(images_root, "*-img_0.png"))) 163 | image2_filenames = sorted( 164 | glob(os.path.join(images_root, "*-img_1.png"))) 165 | flow01_filenames = sorted( 166 | glob(os.path.join(images_root, "*-flow_01.flo"))) 167 | flow10_filenames = sorted( 168 | glob(os.path.join(images_root, "*-flow_10.flo"))) 169 | event_filenames = sorted( 170 | glob(os.path.join(events_root, "*-event.hdf5"))) 171 | 172 | validate_indices = [ 173 | x for x in VALIDATE_INDICES if x in range(len(image1_filenames))] 174 | list_of_indices = None 175 | if self.data_mode[:3] == "val": 176 | list_of_indices = validate_indices 177 | elif self.data_mode == "full": 178 | list_of_indices = range(len(image1_filenames)) 179 | elif self.data_mode == "train": 180 | list_of_indices = [x for x in range( 181 | 
len(image1_filenames)) if x not in validate_indices] 182 | else: 183 | raise NotImplementedError( 184 | "Unsupported data mode {}".format(self.data_mode)) 185 | 186 | final_indices = [] 187 | for i in range(len(image1_filenames)): 188 | im1_base_fileid = (os.path.basename( 189 | image1_filenames[i])).split("-", 2)[0] 190 | im2_base_fileid = (os.path.basename( 191 | image2_filenames[i])).split("-", 2)[0] 192 | flow_f_base_fileid = (os.path.basename( 193 | flow01_filenames[i])).split("-", 2)[0] 194 | flow_b_base_fileid = (os.path.basename( 195 | flow10_filenames[i])).split("-", 2)[0] 196 | event_base_fileid = (os.path.basename( 197 | event_filenames[i])).split("-", 2)[0] 198 | 199 | assert (im1_base_fileid == im2_base_fileid) 200 | assert (im1_base_fileid == flow_f_base_fileid) 201 | assert (im1_base_fileid == flow_b_base_fileid) 202 | assert (im1_base_fileid == event_base_fileid) 203 | if i in list_of_indices and im1_base_fileid not in FlyingChairs2_BAD_ID: 204 | final_indices.append(i) 205 | 206 | self.image1_filenames = [image1_filenames[i] for i in final_indices] 207 | self.image2_filenames = [image2_filenames[i] for i in final_indices] 208 | self.flow01_filenames = [flow01_filenames[i] for i in final_indices] 209 | self.flow10_filenames = [flow10_filenames[i] for i in final_indices] 210 | self.event_filenames = [event_filenames[i] for i in final_indices] 211 | 212 | def load_data_by_index(self, index): 213 | im1_filename = self.image1_filenames[index] 214 | im2_filename = self.image2_filenames[index] 215 | flow01_filename = self.flow01_filenames[index] 216 | flow10_filename = self.flow10_filenames[index] 217 | event_filename = self.event_filenames[index] 218 | 219 | im1_nparray = np.array(read_gen(im1_filename)).astype(np.uint8) 220 | im2_nparray = np.array(read_gen(im2_filename)).astype(np.uint8) 221 | flow01_nparray = readDenseFlow(flow01_filename) 222 | flow10_nparray = readDenseFlow(flow10_filename) 223 | events_nparray = read_event_h5(event_filename) 224 | 225 | return im1_nparray, im2_nparray, flow01_nparray, flow10_nparray, events_nparray 226 | 227 | def __getitem__(self, index): 228 | 229 | index = index % self.data_length 230 | 231 | im1_filename = self.image1_filenames[index] 232 | basename = os.path.basename(im1_filename)[:6] 233 | 234 | im1_nparray, im2_nparray, flow01_nparray, flow10_nparray, events_nparray = \ 235 | self.load_data_by_index(index) 236 | 237 | height, width = im1_nparray.shape[:2] 238 | 239 | event_voxel_nparray = eventsToVoxel(events_nparray, num_bins=self.event_bins, height=height, 240 | width=width, event_polarity=self.event_polarity, temporal_bilinear=True) 241 | 242 | event_voxel_reversed_nparray = None 243 | if self.args.isbi: 244 | event_x = np.flip(events_nparray[:, 0].astype(np.int64), axis=0) 245 | event_y = np.flip(events_nparray[:, 1].astype(np.int64), axis=0) 246 | event_pols = np.flip(-1 * events_nparray[:, 3].astype(np.int64), axis=0) 247 | event_timestamps = events_nparray[:, 2] 248 | event_timestamps = np.flip( 249 | event_timestamps.max() - event_timestamps, axis=0) 250 | events_nparray_bi = np.concatenate( 251 | (event_x[:, np.newaxis], event_y[:, np.newaxis], event_timestamps[:, np.newaxis], event_pols[:, np.newaxis]), axis=1) 252 | event_voxel_reversed_nparray = eventsToVoxel( 253 | events_nparray_bi, num_bins=self.event_bins, height=height, width=width, event_polarity=self.event_polarity, temporal_bilinear=True) 254 | 255 | valid = None 256 | valid10 = None 257 | 258 | if self.augmentor is not None: 259 | event_voxel_nparray, im1_nparray,
im2_nparray, flow01_nparray, flow10_nparray, \ 260 | _, _, event_voxel_reversed_nparray = self.augmentor( 261 | event_voxel_nparray, im1_nparray, im2_nparray, flow01_nparray, flow10_nparray, 262 | event_r=event_voxel_reversed_nparray) 263 | 264 | image1 = torch.from_numpy(im1_nparray).permute(2, 0, 1).float() 265 | image2 = torch.from_numpy(im2_nparray).permute(2, 0, 1).float() 266 | flow = torch.from_numpy(flow01_nparray).permute(2, 0, 1).float() 267 | flow10 = torch.from_numpy(flow10_nparray).permute(2, 0, 1).float() 268 | 269 | event_voxel = torch.from_numpy(event_voxel_nparray).float() 270 | if self.args.isbi: 271 | reversed_event_voxel = torch.from_numpy(event_voxel_reversed_nparray).float() 272 | 273 | event_valid = torch.norm(event_voxel, p=2, dim=0, keepdim=False) > 0 274 | event_valid = event_valid.float().unsqueeze(0) 275 | 276 | if valid is not None: 277 | valid = torch.from_numpy(valid) 278 | else: 279 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 280 | valid = valid.float().unsqueeze(0) 281 | 282 | if flow10 is not None: 283 | if valid10 is not None: 284 | valid10 = torch.from_numpy(valid10) 285 | else: 286 | valid10 = (flow10[0].abs() < 1000) & (flow10[1].abs() < 1000) 287 | valid10 = valid10.float().unsqueeze(0) 288 | else: 289 | flow10 = torch.zeros_like(flow, dtype=torch.float32) 290 | valid10 = torch.zeros_like(valid, dtype=torch.float32) 291 | 292 | if self.args.isbi: 293 | batch = dict( 294 | index=index, 295 | raw_index=index, 296 | basename=basename, 297 | height=height, 298 | width=width, 299 | image1=image1, 300 | image2=image2, 301 | event_voxel=event_voxel, 302 | event_valid=event_valid, 303 | reversed_event_voxel=reversed_event_voxel, 304 | flow_gt=flow, 305 | flow_valid=valid, 306 | flow10_gt=flow10, 307 | flow10_valid=valid10, 308 | ) 309 | else: 310 | batch = dict( 311 | index=index, 312 | raw_index=index, 313 | basename=basename, 314 | height=height, 315 | width=width, 316 | image1=image1, 317 | image2=image2, 318 | event_voxel=event_voxel, 319 | event_valid=event_valid, 320 | flow_gt=flow, 321 | flow_valid=valid, 322 | ) 323 | 324 | return batch 325 | 326 | def get_raw_events(self, index): 327 | event_filename = self.event_filenames[index] 328 | events_nparray = read_event_h5(event_filename) 329 | return events_nparray 330 | 331 | def get_raw_events_length(self, index): 332 | return len(self.get_raw_events(index)) 333 | 334 | def __len__(self): 335 | return self.data_length 336 | -------------------------------------------------------------------------------- /utils/datasets/MVSEC.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["KMP_BLOCKTIME"] = "0" 3 | import sys 4 | sys.path.append('utils') 5 | sys.path.append('utils/datasets') 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | import h5py 10 | from glob import glob 11 | import numpy as np 12 | 13 | import torch 14 | from torch.utils.data import ConcatDataset 15 | from torch.utils.data import random_split 16 | import torch.utils.data.dataset as dataset 17 | 18 | from .MVSEC_utils import generate_corresponding_gt_flow 19 | from utils.augmentor import fetch_augmentor 20 | from utils.event_uitls import eventsToVoxel 21 | 22 | 23 | DatasetMapping = { 24 | 'in1': 'indoor_flying/indoor_flying1', 25 | 'inday1': 'indoor_flying/indoor_flying1', 26 | 'indoor1': 'indoor_flying/indoor_flying1', 27 | 'indoor_flying1': 'indoor_flying/indoor_flying1', 28 | 29 | 'in2': 'indoor_flying/indoor_flying2', 30 | 'inday2':
'indoor_flying/indoor_flying2', 31 | 'indoor2': 'indoor_flying/indoor_flying2', 32 | 'indoor_flying2': 'indoor_flying/indoor_flying2', 33 | 34 | 'in3': 'indoor_flying/indoor_flying3', 35 | 'inday3': 'indoor_flying/indoor_flying3', 36 | 'indoor3': 'indoor_flying/indoor_flying3', 37 | 'indoor_flying3': 'indoor_flying/indoor_flying3', 38 | 39 | 'in4': 'indoor_flying/indoor_flying4', 40 | 'inday4': 'indoor_flying/indoor_flying4', 41 | 'indoor4': 'indoor_flying/indoor_flying4', 42 | 'indoor_flying4': 'indoor_flying/indoor_flying4', 43 | 44 | 'out1': 'outdoor_day/outdoor_day1', 45 | 'outday1': 'outdoor_day/outdoor_day1', 46 | 'outdoor1': 'outdoor_day/outdoor_day1', 47 | 'outdoor_day1': 'outdoor_day/outdoor_day1', 48 | 49 | 'out2': 'outdoor_day/outdoor_day2', 50 | 'outday2': 'outdoor_day/outdoor_day2', 51 | 'outdoor2': 'outdoor_day/outdoor_day2', 52 | 'outdoor_day2': 'outdoor_day/outdoor_day2', 53 | } 54 | 55 | Valid_Time_Index = { 56 | 'indoor_flying/indoor_flying1': [314, 2199], 57 | 'indoor_flying/indoor_flying2': [314, 2199], 58 | 'indoor_flying/indoor_flying3': [314, 2199], 59 | 'indoor_flying/indoor_flying4': [196, 570], 60 | 'outdoor_day/outdoor_day1': [245, 3000], 61 | 'outdoor_day/outdoor_day2': [4375, 7002], 62 | } 63 | 64 | 65 | class MVSEC(dataset.Dataset): 66 | def __init__(self, args, data_root, data_split='in1', data_mode='full', train_ratio=0.6, skip_num=None, aug_params=None): 67 | super().__init__() 68 | self.args = args 69 | 70 | self.args.crop_size = [256, 256] 71 | self.data_root = data_root 72 | self.data_split = data_split 73 | assert data_split in DatasetMapping.keys() 74 | self.data_filepath = os.path.join( 75 | data_root, DatasetMapping[data_split] + '_data.hdf5') 76 | self.gt_filepath = os.path.join( 77 | data_root, DatasetMapping[data_split] + '_gt.hdf5') 78 | assert os.path.isfile(self.data_filepath) 79 | assert os.path.isfile(self.gt_filepath) 80 | 81 | self.data_mode = data_mode 82 | self.train_ratio = train_ratio 83 | 84 | self.event_bins = args.event_bins 85 | self.event_polarity = False if args.no_event_polarity else True 86 | 87 | self.augmentor = None 88 | if aug_params is not None: 89 | self.augmentor = fetch_augmentor(is_event=True, is_sparse=True, aug_params=aug_params) 90 | 91 | if skip_num is None: 92 | self.skip_num = args.skip_num 93 | else: 94 | self.skip_num = skip_num 95 | 96 | # 'continue' or 'interrupt' or 'skip by events number' 97 | if args.skip_mode == 'continue' or args.skip_mode == 'c': 98 | self.skip_mode = 'c' 99 | elif args.skip_mode == 'interrupt' or args.skip_mode == 'i': 100 | self.skip_mode = 'i' 101 | else: 102 | raise NotImplementedError("skip mode {} is not supported!".format(args.skip_mode)) 103 | 104 | self.raw_index_shift = Valid_Time_Index[DatasetMapping[data_split]][0] 105 | self.raw_index_max = Valid_Time_Index[DatasetMapping[data_split]][1] - 1 106 | if self.skip_mode == 'i': 107 | self.data_length = (self.raw_index_max - 108 | self.raw_index_shift) // self.skip_num - 1 109 | elif self.skip_mode == 'c': 110 | self.data_length = self.raw_index_max - \ 111 | self.raw_index_shift - (self.skip_num - 1) 112 | 113 | np.random.seed(20) 114 | split_index = np.random.rand(self.data_length) <= self.train_ratio 115 | if self.data_mode == 'full': 116 | self.INDEX_MAP = [i for i in range(self.data_length)] 117 | elif self.data_mode == 'train': 118 | self.INDEX_MAP = [i for i in range(self.data_length) if split_index[i] ] 119 | elif self.data_mode == 'val': 120 | self.INDEX_MAP = [i for i in range(self.data_length) if not split_index[i] ] 
121 | else: 122 | raise NotImplementedError("unknown data mode {}".format(self.data_mode)) 123 | self.data_length = len(self.INDEX_MAP) 124 | 125 | def open_hdf5(self): 126 | 127 | data_file = h5py.File(self.data_filepath, 'r') 128 | self.events_data = data_file.get('davis/left/events') 129 | self.image_data = data_file.get('davis/left/image_raw') 130 | self.image_ts_data = data_file.get('davis/left/image_raw_ts') 131 | self.image_event_inds = data_file.get('davis/left/image_raw_event_inds') 132 | assert len(self.image_data) == len(self.image_ts_data) 133 | 134 | gt_file = h5py.File(self.gt_filepath, 'r') 135 | self.flow_dist_data = gt_file.get('davis/left/flow_dist') 136 | self.flow_dist_ts = gt_file.get('davis/left/flow_dist_ts') 137 | self.flow_dist_ts_numpy = np.array(self.flow_dist_ts, dtype=np.float64) 138 | 139 | self.image_length = len(self.image_data) 140 | self.event_length = len(self.events_data) 141 | self.flow_length = len(self.flow_dist_data) 142 | 143 | assert self.data_length <= self.image_length 144 | 145 | def __getitem__(self, index): 146 | 147 | if not hasattr(self, 'events_data'): 148 | self.open_hdf5() 149 | 150 | if self.skip_mode == 'i': 151 | raw_index = self.INDEX_MAP[index] * self.skip_num + self.raw_index_shift 152 | elif self.skip_mode == 'c': 153 | raw_index = self.INDEX_MAP[index] + self.raw_index_shift 154 | assert raw_index < self.raw_index_max 155 | 156 | image1 = self.image_data[raw_index] 157 | image1_ts = self.image_ts_data[raw_index] 158 | image1_event_index = self.image_event_inds[raw_index] 159 | image2 = self.image_data[raw_index + self.skip_num] 160 | image2_ts = self.image_ts_data[raw_index + self.skip_num] 161 | image2_event_index = self.image_event_inds[raw_index + self.skip_num] 162 | assert image1_event_index < image2_event_index 163 | assert image2_event_index < self.event_length 164 | 165 | if self.skip_mode == 'i' or self.skip_mode == 'c': 166 | events = self.events_data[image1_event_index:image2_event_index] 167 | next_ts = image2_ts 168 | 169 | height, width = image1.shape[:2] 170 | event_voxel = eventsToVoxel(events, num_bins=self.event_bins, height=height, width=width, \ 171 | event_polarity=self.event_polarity, temporal_bilinear=True) 172 | 173 | flow_left_index = np.searchsorted(self.flow_dist_ts_numpy, image1_ts, side='right') - 1 174 | flow_right_index = np.searchsorted(self.flow_dist_ts_numpy, next_ts, side='right') 175 | assert flow_left_index <= flow_right_index 176 | assert flow_left_index < self.flow_length 177 | assert flow_right_index < self.flow_length 178 | 179 | flows = self.flow_dist_data[flow_left_index:flow_right_index] 180 | flows_ts = self.flow_dist_ts_numpy[flow_left_index:flow_right_index+1] 181 | 182 | final_flow = generate_corresponding_gt_flow(flows, flows_ts, image1_ts, next_ts) 183 | final_flow = final_flow.transpose(1, 2, 0) 184 | 185 | # grayscale images 186 | if len(image1.shape) == 2: 187 | image1 = np.tile(image1[..., None], (1, 1, 3)) 188 | image2 = np.tile(image2[..., None], (1, 1, 3)) 189 | else: 190 | image1 = image1[..., :3] 191 | image2 = image2[..., :3] 192 | 193 | crop_height, crop_width = self.args.crop_size[:2] 194 | if 'out' in self.data_split: 195 | assert crop_height < height and crop_width < width 196 | start_y = (height - crop_height) // 2 197 | start_x = (width - crop_width) // 2 198 | image1 = image1[start_y:start_y+crop_height, start_x:start_x+crop_width, :] 199 | image2 = image2[start_y:start_y+crop_height, start_x:start_x+crop_width, :] 200 | event_voxel = event_voxel[:,
start_y:start_y+crop_height, start_x:start_x+crop_width] 201 | final_flow = final_flow[start_y:start_y+crop_height, start_x:start_x+crop_width, :] 202 | 203 | if self.augmentor is not None: 204 | event_voxel, image1, image2, final_flow, _, _, _ = \ 205 | self.augmentor(event_voxel, image1, image2, final_flow) 206 | 207 | height, width = image1.shape[:2] 208 | image1 = torch.from_numpy(image1).permute(2, 0, 1).float() 209 | image2 = torch.from_numpy(image2).permute(2, 0, 1).float() 210 | final_flow = torch.from_numpy(final_flow).permute(2, 0, 1).float() 211 | event_voxel = torch.from_numpy(event_voxel).float() 212 | 213 | event_valid = torch.norm(event_voxel, p=2, dim=0, keepdim=False) > 0 214 | event_valid = event_valid.float().unsqueeze(0) 215 | 216 | flow_valid = (torch.norm(final_flow, p=2, dim=0, keepdim=False) > 0) & (final_flow[0].abs() < 1000) & (final_flow[1].abs() < 1000) 217 | flow_valid = flow_valid.float().unsqueeze(0) 218 | 219 | if height == crop_height and width == crop_width: 220 | pass 221 | else: 222 | assert crop_height < height and crop_width < width 223 | start_y = (height - crop_height) // 2 224 | start_x = (width - crop_width) // 2 225 | image1 = image1[:, start_y:start_y+crop_height, start_x:start_x+crop_width] 226 | image2 = image2[:, start_y:start_y+crop_height, start_x:start_x+crop_width] 227 | event_voxel = event_voxel[:, start_y:start_y+crop_height, start_x:start_x+crop_width] 228 | event_valid = event_valid[:, start_y:start_y+crop_height, start_x:start_x+crop_width] 229 | final_flow = final_flow[:, start_y:start_y+crop_height, start_x:start_x+crop_width] 230 | flow_valid = flow_valid[:, start_y:start_y+crop_height, start_x:start_x+crop_width] 231 | 232 | height, width = image1.shape[-2:] 233 | 234 | basename = "{}_{:0>5d}".format(self.data_split, index) 235 | 236 | batch = dict( 237 | index=index, 238 | raw_index=raw_index, 239 | basename=basename, 240 | height=height, 241 | width=width, 242 | image1=image1, 243 | image2=image2, 244 | event_voxel=event_voxel, 245 | event_valid=event_valid, 246 | flow_gt=final_flow, 247 | flow_valid=flow_valid, 248 | ) 249 | 250 | return batch 251 | 252 | def get_raw_events(self, index): 253 | 254 | if not hasattr(self, 'events_data'): 255 | self.open_hdf5() 256 | 257 | if self.skip_mode == 'i': 258 | raw_index = self.INDEX_MAP[index] * self.skip_num + self.raw_index_shift 259 | elif self.skip_mode == 'c' or self.skip_mode == 'e': 260 | raw_index = self.INDEX_MAP[index] + self.raw_index_shift 261 | assert raw_index < self.raw_index_max 262 | 263 | image1_event_index = self.image_event_inds[raw_index] 264 | image2_event_index = self.image_event_inds[raw_index + self.skip_num] 265 | assert image1_event_index < image2_event_index 266 | assert image2_event_index < self.event_length 267 | 268 | if self.skip_mode == 'i' or self.skip_mode == 'c': 269 | events = self.events_data[image1_event_index:image2_event_index] 270 | 271 | return events 272 | 273 | def __len__(self): 274 | return self.data_length 275 | --------------------------------------------------------------------------------
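For reference, a minimal sketch (not part of the repository) of pulling one sample through the `MVSEC` dataset above, run from the repository root; the `Namespace` fields stand in for the parsed command-line arguments this class reads, and `crop_size` is set by the class itself:

```
# Illustrative only: fetch a single MVSEC sample and inspect its shapes.
from argparse import Namespace
from torch.utils.data import DataLoader
from utils.datasets import MVSEC

args = Namespace(event_bins=5, no_event_polarity=False,
                 skip_num=1, skip_mode='continue')
dataset = MVSEC(args, './data/MVSEC_HDF5', data_split='indoor_flying1')
loader = DataLoader(dataset, batch_size=1, shuffle=False)

batch = next(iter(loader))
print(batch['event_voxel'].shape)  # (1, 2 * event_bins, 256, 256) with polarity on
print(batch['flow_gt'].shape)      # (1, 2, 256, 256)
print(batch['flow_valid'].shape)   # (1, 1, 256, 256)
```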
/utils/datasets/MVSEC_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import numpy as np 3 | import math 4 | import cv2 5 | 6 | """ 7 | Calculates per pixel flow error between flow_pred and flow_gt. 8 | event_img is used to mask out any pixels without events (i.e. where event_img is 0). 9 | If is_car is True, only the top 190 rows of the images will be evaluated to remove the hood of 10 | the car which does not appear in the GT. 11 | """ 12 | def flow_error_dense(flow_gt, flow_pred, event_img=None, is_car=False): 13 | max_row = flow_gt.shape[1] 14 | 15 | if event_img is None: 16 | event_img = np.ones(flow_pred.shape[0:2]) 17 | if is_car: 18 | max_row = 190 19 | 20 | event_img_cropped = event_img[:max_row, :] 21 | flow_gt_cropped = flow_gt[:max_row, :, :] 22 | 23 | flow_pred_cropped = flow_pred[:max_row, :, :] 24 | 25 | event_mask = event_img_cropped > 0 26 | 27 | # Only compute error over points that are valid in the GT (not inf or 0). 28 | flow_mask = np.logical_and( 29 | np.logical_and(~np.isinf(flow_gt_cropped[:, :, 0]), ~np.isinf(flow_gt_cropped[:, :, 1])), 30 | np.linalg.norm(flow_gt_cropped, axis=2) > 0) 31 | total_mask = np.squeeze(np.logical_and(event_mask, flow_mask)) 32 | 33 | gt_masked = flow_gt_cropped[total_mask, :] 34 | pred_masked = flow_pred_cropped[total_mask, :] 35 | 36 | # Average endpoint error. 37 | EE = np.linalg.norm(gt_masked - pred_masked, axis=-1) 38 | n_points = EE.shape[0] 39 | AEE = np.mean(EE) 40 | 41 | # Percentage of points with EE < 3 pixels. 42 | thresh = 3. 43 | percent_AEE = float((EE < thresh).sum()) / float(EE.shape[0] + 1e-5) 44 | 45 | return AEE, percent_AEE, n_points 46 | 47 | """ 48 | Propagates x_indices and y_indices by their flow, as defined in x_flow, y_flow. 49 | x_mask and y_mask are zeroed out at each pixel where the indices leave the image. 50 | The optional scale_factor will scale the final displacement. 51 | """ 52 | def prop_flow(x_flow, y_flow, x_indices, y_indices, x_mask, y_mask, scale_factor=1.0): 53 | flow_x_interp = cv2.remap(x_flow, 54 | x_indices, 55 | y_indices, 56 | cv2.INTER_NEAREST) 57 | 58 | flow_y_interp = cv2.remap(y_flow, 59 | x_indices, 60 | y_indices, 61 | cv2.INTER_NEAREST) 62 | 63 | x_mask[flow_x_interp == 0] = False 64 | y_mask[flow_y_interp == 0] = False 65 | 66 | x_indices += flow_x_interp * scale_factor 67 | y_indices += flow_y_interp * scale_factor 68 | 69 | return 70 | 71 | """ 72 | The ground truth flow maps are not time synchronized with the grayscale images. Therefore, we 73 | need to propagate the ground truth flow over the time between two images. 74 | This function assumes that the ground truth flow is in terms of pixel displacement, not velocity. 75 | 76 | Pseudo code for this process is as follows: 77 | 78 | x_orig = range(cols) 79 | y_orig = range(rows) 80 | x_prop = x_orig 81 | y_prop = y_orig 82 | Find all GT flows that fit in [image_timestamp, image_timestamp+image_dt]. 83 | for all of these flows: 84 | x_prop = x_prop + gt_flow_x(x_prop, y_prop) 85 | y_prop = y_prop + gt_flow_y(x_prop, y_prop) 86 | 87 | The final flow, then, is x_prop - x_orig, y_prop - y_orig. 88 | Note that this is flow in terms of pixel displacement, with units of pixels, not pixel velocity. 89 | 90 | Inputs: 91 | flows - array of per pixel flow maps; flows[i][0] and flows[i][1] are the x and y 92 | flows for the i-th ground truth interval. 93 | flows_ts - timestamp for each flow array (one more entry than flows). 94 | start_time, end_time - gt flow will be estimated between start_time and end_time.
95 | """ 96 | def generate_corresponding_gt_flow(flows, 97 | flows_ts, 98 | start_time, 99 | end_time): 100 | 101 | flow_length = len(flows) 102 | assert flow_length == len(flows_ts) - 1 103 | 104 | x_flow = flows[0][0] 105 | y_flow = flows[0][1] 106 | gt_dt = flows_ts[1] - flows_ts[0] 107 | pre_dt = end_time - start_time 108 | 109 | # if gt_dt > pre_dt: 110 | if start_time > flows_ts[0] and end_time <= flows_ts[1]: 111 | x_flow *= pre_dt / gt_dt 112 | y_flow *= pre_dt / gt_dt 113 | return np.concatenate((x_flow[np.newaxis, :], y_flow[np.newaxis, :]), axis=0) 114 | 115 | x_indices, y_indices = np.meshgrid(np.arange(x_flow.shape[1]), 116 | np.arange(x_flow.shape[0])) 117 | 118 | x_indices = x_indices.astype(np.float32) 119 | y_indices = y_indices.astype(np.float32) 120 | 121 | orig_x_indices = np.copy(x_indices) 122 | orig_y_indices = np.copy(y_indices) 123 | 124 | # Mask keeps track of the points that leave the image, and zeros out the flow afterwards. 125 | x_mask = np.ones(x_indices.shape, dtype=bool) 126 | y_mask = np.ones(y_indices.shape, dtype=bool) 127 | 128 | scale_factor = (flows_ts[1] - start_time) / gt_dt 129 | total_dt = flows_ts[1] - start_time 130 | 131 | prop_flow(x_flow, y_flow, 132 | x_indices, y_indices, 133 | x_mask, y_mask, 134 | scale_factor=scale_factor) 135 | 136 | for i in range(1, flow_length-1): 137 | x_flow = flows[i][0] 138 | y_flow = flows[i][1] 139 | 140 | prop_flow(x_flow, y_flow, 141 | x_indices, y_indices, 142 | x_mask, y_mask) 143 | 144 | total_dt += flows_ts[i+1] - flows_ts[i] 145 | 146 | gt_dt = flows_ts[flow_length] - flows_ts[flow_length-1] 147 | pred_dt = end_time - flows_ts[flow_length-1] 148 | total_dt += pred_dt 149 | 150 | x_flow = flows[flow_length-1][0] 151 | y_flow = flows[flow_length-1][1] 152 | 153 | scale_factor = pred_dt / gt_dt 154 | 155 | prop_flow(x_flow, y_flow, 156 | x_indices, y_indices, 157 | x_mask, y_mask, 158 | scale_factor) 159 | 160 | x_shift = x_indices - orig_x_indices 161 | y_shift = y_indices - orig_y_indices 162 | x_shift[~x_mask] = 0 163 | y_shift[~y_mask] = 0 164 | 165 | return np.concatenate((x_shift[np.newaxis, :], y_shift[np.newaxis, :]), axis=0) 166 | -------------------------------------------------------------------------------- /utils/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .FlyingChairs2 import FlyingChairs2 2 | from .MVSEC import MVSEC 3 | 4 | def fetch_dataset(args): 5 | """ Create the data loader for the corresponding training set """ 6 | 7 | train_dataset = None 8 | val_datasets = None 9 | val_setnames = None 10 | 11 | if args.stage == 'chairs2': 12 | aug_params = {'crop_size': args.crop_size, 13 | 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': True} 14 | train_dataset = FlyingChairs2(args, './data/FlyingChairs2', data_kind='train', aug_params=aug_params) 15 | dataset1_val = FlyingChairs2(args, './data/FlyingChairs2', data_kind='trainval') 16 | val_datasets = [dataset1_val] 17 | val_setnames = ['chairs2trainval'] 18 | 19 | assert train_dataset is not None 20 | 21 | return train_dataset, val_datasets, val_setnames 22 | 23 | 24 | def fetch_test_dataset(args): 25 | """ Create the torch Dataset for the corresponding testing set / name """ 26 | 27 | test_datasets = None 28 | names = None 29 | 30 | if args.stage == 'chairs2' or args.stage == 'chairs2val': 31 | dataset = FlyingChairs2(args, './data/FlyingChairs2', data_kind='val') 32 | test_datasets = [dataset] 33 | names = ['chairs2val'] 34 | 35 | elif args.stage == 'chairs2train': 36 | dataset 
= FlyingChairs2(args, './data/FlyingChairs2', data_kind='train') 37 | test_datasets = [dataset] 38 | names = ['chairs2train'] 39 | 40 | elif args.stage == 'mvsec' or args.stage == 'mvsecfull': 41 | dataset1 = MVSEC(args, './data/MVSEC_HDF5', data_split='indoor_flying1') 42 | dataset2 = MVSEC(args, './data/MVSEC_HDF5', data_split='indoor_flying2') 43 | dataset3 = MVSEC(args, './data/MVSEC_HDF5', data_split='indoor_flying3') 44 | dataset4 = MVSEC(args, './data/MVSEC_HDF5', data_split='outdoor_day1') 45 | dataset5 = MVSEC(args, './data/MVSEC_HDF5', data_split='outdoor_day2') 46 | test_datasets = [dataset1, dataset2, dataset3, dataset4, dataset5] 47 | names = ['mvsecval/indoor_flying1', 'mvsecval/indoor_flying2', 'mvsecval/indoor_flying3', \ 48 | 'mvsecval/outdoor_day1', 'mvsecval/outdoor_day2'] 49 | 50 | assert test_datasets is not None 51 | 52 | return test_datasets, names 53 | -------------------------------------------------------------------------------- /utils/event_uitls.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('utils') 3 | import numpy as np 4 | import cv2 5 | cv2.setNumThreads(0) 6 | cv2.ocl.setUseOpenCL(False) 7 | import torch 8 | 9 | # from https://github.com/TimoStoff/event_utils/blob/master/lib/representations/voxel_grid.py 10 | 11 | def interpolate_to_image(pxs, pys, dxs, dys, weights, img): 12 | """ 13 | Accumulate x and y coords to an image using bilinear interpolation 14 | """ 15 | img.index_put_((pys, pxs ), weights*(1.0-dxs)*(1.0-dys), accumulate=True) 16 | img.index_put_((pys, pxs+1), weights*dxs*(1.0-dys), accumulate=True) 17 | img.index_put_((pys+1, pxs ), weights*(1.0-dxs)*dys, accumulate=True) 18 | img.index_put_((pys+1, pxs+1), weights*dxs*dys, accumulate=True) 19 | return img 20 | 21 | 22 | def binary_search_torch_tensor(t, l, r, x, side='left'): 23 | """ 24 | Binary search sorted pytorch tensor 25 | """ 26 | if r is None: 27 | r = len(t)-1 28 | while l <= r: 29 | mid = l + (r - l)//2; 30 | midval = t[mid] 31 | if midval == x: 32 | return mid 33 | elif midval < x: 34 | l = mid + 1 35 | else: 36 | r = mid - 1 37 | if side == 'left': 38 | return l 39 | return r 40 | 41 | 42 | def events_to_image_torch(xs, ys, ps, 43 | device=None, sensor_size=(180, 240), clip_out_of_range=True, 44 | interpolation=None, padding=True): 45 | """ 46 | Method to turn event tensor to image. Allows for bilinear interpolation. 47 | :param xs: tensor of x coords of events 48 | :param ys: tensor of y coords of events 49 | :param ps: tensor of event polarities/weights 50 | :param device: the device on which the image is. If none, set to events device 51 | :param sensor_size: the size of the image sensor/output image 52 | :param clip_out_of_range: if the events go beyond the desired image size, 53 | clip the events to fit into the image 54 | :param interpolation: which interpolation to use. 
Options=None,'bilinear' 55 | :param padding: if bilinear interpolation, allow padding the image by 1 to allow events to fit 56 | """ 57 | if device is None: 58 | device = xs.device 59 | if interpolation == 'bilinear' and padding: 60 | img_size = (sensor_size[0]+1, sensor_size[1]+1) 61 | else: 62 | img_size = list(sensor_size) 63 | 64 | mask = torch.ones(xs.size(), device=device) 65 | if clip_out_of_range: 66 | zero_v = torch.tensor([0.], device=device) 67 | ones_v = torch.tensor([1.], device=device) 68 | clipx = img_size[1] if interpolation is None and padding==False else img_size[1]-1 69 | clipy = img_size[0] if interpolation is None and padding==False else img_size[0]-1 70 | mask = torch.where(xs>=clipx, zero_v, ones_v)*torch.where(ys>=clipy, zero_v, ones_v) 71 | 72 | img = torch.zeros(img_size, dtype=torch.float32).to(device) 73 | if interpolation == 'bilinear' and xs.dtype is not torch.long and ys.dtype is not torch.long: 74 | pxs = (xs.floor()).float() 75 | pys = (ys.floor()).float() 76 | dxs = (xs-pxs).float() 77 | dys = (ys-pys).float() 78 | pxs = (pxs*mask).long() 79 | pys = (pys*mask).long() 80 | masked_ps = ps.squeeze()*mask 81 | interpolate_to_image(pxs, pys, dxs, dys, masked_ps, img) 82 | else: 83 | if xs.dtype is not torch.long: 84 | xs = xs.long().to(device) 85 | if ys.dtype is not torch.long: 86 | ys = ys.long().to(device) 87 | img.index_put_((ys, xs), ps.float(), accumulate=True) 88 | return img 89 | 90 |
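As a quick illustration of the bilinear splatting above (a self-contained check, not repository code; it assumes the functions of this module are in scope), two events with fractional coordinates are each spread over their four neighbouring pixels, and the total weight is conserved:

```
import torch

xs = torch.tensor([1.25, 3.50])   # fractional x coordinates
ys = torch.tensor([2.75, 0.50])   # fractional y coordinates
ps = torch.tensor([1.0, -1.0])    # one positive, one negative event

img = events_to_image_torch(xs, ys, ps, sensor_size=(5, 5),
                            interpolation='bilinear', padding=True)
print(img.shape)         # (6, 6): padding grows the grid by one pixel
print(img.sum().item())  # 0.0: the +1 and -1 weights are conserved
```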
91 | def events_to_voxel_torch(xs, ys, ts, ps, B, device=None, sensor_size=(180, 240), temporal_bilinear=True): 92 | """ 93 | Turn set of events to a voxel grid tensor, using temporal bilinear interpolation 94 | Parameters 95 | ---------- 96 | xs : list of event x coordinates (torch tensor) 97 | ys : list of event y coordinates (torch tensor) 98 | ts : list of event timestamps (torch tensor) 99 | ps : list of event polarities (torch tensor) 100 | B : number of bins in output voxel grids (int) 101 | device : device to put voxel grid. If left empty, same device as events 102 | sensor_size : the size of the event sensor/output voxels 103 | temporal_bilinear : whether the events should be naively 104 | accumulated to the voxels (faster), or properly 105 | temporally distributed 106 | Returns 107 | ------- 108 | voxel: voxel of the events between t0 and t1 109 | """ 110 | if isinstance(xs, np.ndarray): 111 | xs = torch.from_numpy(xs) 112 | ys = torch.from_numpy(ys) 113 | ts = torch.from_numpy(ts) 114 | ps = torch.from_numpy(ps) 115 | 116 | if device is None: 117 | device = xs.device 118 | assert(len(xs)==len(ys) and len(ys)==len(ts) and len(ts)==len(ps)) 119 | 120 | bins = [] 121 | dt = ts[-1]-ts[0] 122 | t_norm = (ts-ts[0])/dt*(B-1) 123 | zeros = torch.zeros(t_norm.size()) 124 | for bi in range(B): 125 | if temporal_bilinear: 126 | bilinear_weights = torch.max(zeros, 1.0-torch.abs(t_norm-bi)) 127 | weights = ps*bilinear_weights 128 | vb = events_to_image_torch(xs, ys, 129 | weights, device, sensor_size=sensor_size, 130 | clip_out_of_range=False) 131 | else: 132 | tstart = ts[0] + dt*bi 133 | tend = tstart + dt 134 | beg = binary_search_torch_tensor(ts, 0, len(ts)-1, tstart) 135 | end = binary_search_torch_tensor(ts, 0, len(ts)-1, tend) 136 | vb = events_to_image_torch(xs[beg:end], ys[beg:end], 137 | ps[beg:end], device, sensor_size=sensor_size, 138 | clip_out_of_range=False) 139 | bins.append(vb) 140 | bins = torch.stack(bins) 141 | return bins 142 | 143 | 144 | def events_to_neg_pos_voxel_torch(xs, ys, ts, ps, B, device=None, 145 | sensor_size=(180, 240), temporal_bilinear=True): 146 | """ 147 | Turn set of events to a voxel grid tensor, using temporal bilinear interpolation. 148 | Positive and negative events are put into separate voxel grids 149 | Parameters 150 | ---------- 151 | xs : list of event x coordinates 152 | ys : list of event y coordinates 153 | ts : list of event timestamps 154 | ps : list of event polarities 155 | B : number of bins in output voxel grids (int) 156 | device : the device that the events are on 157 | sensor_size : the size of the event sensor/output voxels 158 | temporal_bilinear : whether the events should be naively 159 | accumulated to the voxels (faster), or properly 160 | temporally distributed 161 | Returns 162 | ------- 163 | voxel_pos: voxel of the positive events 164 | voxel_neg: voxel of the negative events 165 | """ 166 | 167 | if isinstance(xs, np.ndarray): 168 | xs = torch.from_numpy(xs) 169 | ys = torch.from_numpy(ys) 170 | ts = torch.from_numpy(ts) 171 | ps = torch.from_numpy(ps) 172 | 173 | zero_v = torch.tensor([0.]) 174 | ones_v = torch.tensor([1.]) 175 | pos_weights = torch.where(ps>0, ones_v, zero_v) 176 | neg_weights = torch.where(ps<=0, ones_v, zero_v) 177 | 178 | voxel_pos = events_to_voxel_torch(xs, ys, ts, pos_weights, B, device=device, 179 | sensor_size=sensor_size, temporal_bilinear=temporal_bilinear) 180 | voxel_neg = events_to_voxel_torch(xs, ys, ts, neg_weights, B, device=device, 181 | sensor_size=sensor_size, temporal_bilinear=temporal_bilinear) 182 | 183 | return voxel_pos, voxel_neg 184 | 185 | 186 | def eventsToXYTP(events, process=False): 187 | event_x = events[:, 0].astype(np.int64) 188 | event_y = events[:, 1].astype(np.int64) 189 | event_pols = events[:, 3].astype(np.int64) 190 | 191 | event_timestamps = events[:, 2] 192 | 193 | if process: 194 | last_stamp = event_timestamps[-1] 195 | first_stamp = event_timestamps[0] 196 | deltaT = last_stamp - first_stamp 197 | event_timestamps = (event_timestamps - first_stamp) / deltaT 198 | 199 | # event_pols[event_pols
== 0] = -1 # polarity should be +1 / -1 200 | 201 | return event_x, event_y, event_timestamps, event_pols 202 | 203 | 204 | def eventsToVoxel(events, num_bins=5, height=None, width=None, event_polarity=False, temporal_bilinear=True): 205 | return eventsToVoxelTorch(events, num_bins, height, width, event_polarity, temporal_bilinear).numpy() 206 | 207 | 208 | def eventsToVoxelTorch(events, num_bins=5, height=None, width=None, event_polarity=False, temporal_bilinear=True): 209 | xs, ys, ts, ps = eventsToXYTP(events, process=True) 210 | 211 | if height is None or width is None: 212 | width = xs.max() + 1 213 | height = ys.max() + 1 214 | 215 | if not event_polarity: 216 | # generate voxel grid which has size num_bins x H x W 217 | voxel_grid = events_to_voxel_torch(xs, ys, ts, ps, num_bins, sensor_size=(height, width), temporal_bilinear=temporal_bilinear) 218 | else: 219 | # generate voxel grid which has size 2*num_bins x H x W 220 | voxel_grid = events_to_neg_pos_voxel_torch(xs, ys, ts, ps, num_bins, sensor_size=(height, width), temporal_bilinear=temporal_bilinear) 221 | voxel_grid = torch.cat([voxel_grid[0], voxel_grid[1]], 0) 222 | 223 | return voxel_grid 224 | --------------------------------------------------------------------------------
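Below is a minimal sketch (not from the repository) of driving `eventsToVoxel` with a synthetic `(N, 4)` event array in the `[x, y, t, p]` column layout that `eventsToXYTP` expects; the timestamps must be ascending:

```
import numpy as np
from utils.event_uitls import eventsToVoxel  # note the module's spelling

rng = np.random.default_rng(0)
n = 1000
events = np.stack([
    rng.integers(0, 240, n),   # x in [0, 240)
    rng.integers(0, 180, n),   # y in [0, 180)
    np.sort(rng.random(n)),    # t, ascending timestamps
    rng.choice([-1, 1], n),    # p, event polarity
], axis=1).astype(np.float32)

voxel = eventsToVoxel(events, num_bins=5, height=180, width=240,
                      event_polarity=True)
print(voxel.shape)  # (10, 180, 240): positive and negative grids stacked
```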
/utils/file_io.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import re 4 | import h5py 5 | import imageio 6 | import numpy as np 7 | import pandas as pd 8 | from PIL import Image 9 | 10 | import sys 11 | sys.path.append('utils') 12 | import cv2 13 | cv2.setNumThreads(0) 14 | cv2.ocl.setUseOpenCL(False) 15 | 16 | 17 | def readFlow(fn): 18 | """ Read .flo file in Middlebury format""" 19 | # Code adapted from: 20 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 21 | 22 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 23 | # print 'fn = %s'%(fn) 24 | with open(fn, 'rb') as f: 25 | magic = np.fromfile(f, np.float32, count=1) 26 | if 202021.25 != magic: 27 | print('Magic number incorrect. Invalid .flo file') 28 | return None 29 | else: 30 | w = np.fromfile(f, np.int32, count=1) 31 | h = np.fromfile(f, np.int32, count=1) 32 | # print 'Reading %d x %d flo file\n' % (w, h) 33 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 34 | # Reshape data into 3D array (columns, rows, bands) 35 | # The reshape here is for visualization, the original code is (w,h,2) 36 | return np.resize(data, (int(h), int(w), 2)) 37 | 38 | 39 | def readPFM(file): 40 | file = open(file, 'rb') 41 | 42 | color = None 43 | width = None 44 | height = None 45 | scale = None 46 | endian = None 47 | 48 | header = file.readline().rstrip() 49 | if header == b'PF': 50 | color = True 51 | elif header == b'Pf': 52 | color = False 53 | else: 54 | raise Exception('Not a PFM file.') 55 | 56 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 57 | if dim_match: 58 | width, height = map(int, dim_match.groups()) 59 | else: 60 | raise Exception('Malformed PFM header.') 61 | 62 | scale = float(file.readline().rstrip()) 63 | if scale < 0: # little-endian 64 | endian = '<' 65 | scale = -scale 66 | else: 67 | endian = '>' # big-endian 68 | 69 | data = np.fromfile(file, endian + 'f') 70 | shape = (height, width, 3) if color else (height, width) 71 | 72 | data = np.reshape(data, shape) 73 | data = np.flipud(data) 74 | return data 75 | 76 | 77 | def readFlowKITTI(filename): 78 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 79 | flow = flow[:,:,::-1].astype(np.float32) 80 | flow, valid = flow[:, :, :2], flow[:, :, 2] 81 | flow = (flow - 2**15) / 64.0 82 | return flow, valid 83 | 84 | 85 | def read_gen(file_name): 86 | ext = os.path.splitext(file_name)[-1] 87 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 88 | return Image.open(file_name) 89 | elif ext == '.bin' or ext == '.raw': 90 | return np.load(file_name) 91 | elif ext == '.flo' or ext == '.pfm': 92 | return readDenseFlow(file_name) 93 | return [] 94 | 95 | 96 | def readDenseFlow(file_name): 97 | ext = os.path.splitext(file_name)[-1] 98 | if ext == '.flo': 99 | return readFlow(file_name).astype(np.float32) 100 | elif ext == '.pfm': 101 | flow = readPFM(file_name).astype(np.float32) 102 | if len(flow.shape) == 2: 103 | return flow 104 | else: 105 | return flow[:, :, :-1] 106 | return [] 107 | 108 | 109 | def read_event_h5(path): 110 | file = h5py.File(path, 'r') 111 | length = len(file['x']) 112 | events = np.zeros([length, 4], dtype=np.float32) 113 | events[:, 0] = file['x'] 114 | events[:, 1] = file['y'] 115 | events[:, 2] = file['t'] 116 | events[:, 3] = file['p'] 117 | file.close() 118 | return events 119 | --------------------------------------------------------------------------------
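To show the on-disk layout `read_event_h5` assumes, here is an illustrative round trip (the `/tmp` path and the tiny arrays are placeholders): the HDF5 file stores separate `x`, `y`, `t`, `p` datasets, which come back as one `(N, 4)` float32 array:

```
import h5py
import numpy as np
from utils.file_io import read_event_h5

with h5py.File('/tmp/demo-event.hdf5', 'w') as f:
    f.create_dataset('x', data=np.array([10, 20], dtype=np.uint16))
    f.create_dataset('y', data=np.array([5, 6], dtype=np.uint16))
    f.create_dataset('t', data=np.array([0.0, 1.0], dtype=np.float64))
    f.create_dataset('p', data=np.array([1, -1], dtype=np.int8))

events = read_event_h5('/tmp/demo-event.hdf5')
print(events.shape, events.dtype)  # (2, 4) float32
```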
/utils/sample_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | def forward_interpolate(flow): 8 | flow = flow.detach().cpu().numpy() 9 | dx, dy = flow[0], flow[1] 10 | 11 | ht, wd = dx.shape 12 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 13 | 14 | x1 = x0 + dx 15 | y1 = y0 + dy 16 | 17 | x1 = x1.reshape(-1) 18 | y1 = y1.reshape(-1) 19 | dx = dx.reshape(-1) 20 | dy = dy.reshape(-1) 21 | 22 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 23 | x1 = x1[valid] 24 | y1 = y1[valid] 25 | dx = dx[valid] 26 | dy = dy[valid] 27 | 28 | flow_x = interpolate.griddata( 29 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 30 | 31 | flow_y = interpolate.griddata( 32 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 33 | 34 | flow = np.stack([flow_x, flow_y], axis=0) 35 | return torch.from_numpy(flow).float() 36 | 37 | 38 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 39 | """ Wrapper for grid_sample, uses pixel coordinates """ 40 | H, W = img.shape[-2:] 41 | xgrid, ygrid = coords.split([1, 1], dim=-1) 42 | xgrid = 2*xgrid/(W-1) - 1 43 | ygrid = 2*ygrid/(H-1) - 1 44 | 45 | grid = torch.cat([xgrid, ygrid], dim=-1) 46 | img = F.grid_sample(img, grid, align_corners=True) 47 | 48 | if mask: 49 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 50 | return img, mask.float() 51 | 52 | return img 53 | 54 | 55 | def coords_grid(batch, ht, wd): 56 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing='ij') 57 | coords = torch.stack(coords[::-1], dim=0).float() 58 | return coords[None].repeat(batch, 1, 1, 1) 59 | 60 | 61 | def upflow16(flow, mode='bilinear'): 62 | new_size = (16 * flow.shape[2], 16 * flow.shape[3]) 63 | return 16 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 64 | 65 | 66 | def upflow8(flow, mode='bilinear'): 67 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 68 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 69 | 70 | 71 | def upflow4(flow, mode='bilinear'): 72 | new_size = (4 * flow.shape[2], 4 * flow.shape[3]) 73 | return 4 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 74 | 75 | 76 | def upflow2(flow, mode='bilinear'): 77 | new_size = (2 * flow.shape[2], 2 * flow.shape[3]) 78 | return 2 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 79 | 80 | 81 | def downflow2(flow, mode='bilinear'): 82 | new_size = (flow.shape[2] // 2, flow.shape[3] // 2) 83 | return 0.5 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 84 | 85 | def downflow4(flow, mode='bilinear'): 86 | new_size = (flow.shape[2] // 4, flow.shape[3] // 4) 87 | return 0.25 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 88 | 89 | def downflow8(flow, mode='bilinear'): 90 | new_size = (flow.shape[2] // 8, flow.shape[3] // 8) 91 | return 0.125 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 92 | 93 | 94 | def downflow2_pool2d(flow): 95 | _, _, h, w = flow.size() 96 | return F.adaptive_avg_pool2d(flow, [h//2, w//2]) 97 | 98 | 99 | def downgrid2(coord, mode='bilinear'): 100 | b, _, ht, wd = coord.shape 101 | coords_0 = coords_grid(b, ht, wd).to(coord.device) 102 | flow = coord - coords_0 103 | flow = 0.5 * F.interpolate(flow, size=(ht // 2, wd // 2), 104 | mode=mode, align_corners=True) 105 | coords_1 = coords_grid(b, ht // 2, wd // 2).to(coord.device) 106 | return coords_1 + flow 107 | 108 | 109 | def make_coord_center(shape, ranges=None, flatten=True): 110 | """ Make coordinates at grid centers.
111 | """ 112 | coord_seqs = [] 113 | for i, n in enumerate(shape): 114 | if ranges is None: 115 | v0, v1 = -1, 1 116 | else: 117 | v0, v1 = ranges[i] 118 | r = (v1 - v0) / (2 * n) 119 | seq = v0 + r + (2 * r) * torch.arange(n).float() 120 | coord_seqs.append(seq) 121 | ret = torch.stack(torch.meshgrid(*coord_seqs, indexing='ij'), dim=-1) 122 | if flatten: 123 | ret = ret.view(-1, ret.shape[-1]) 124 | return ret 125 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import numpy as np 5 | import logging 6 | import importlib 7 | import torch.nn.functional as F 8 | from pytz import timezone 9 | from datetime import datetime 10 | 11 | 12 | def setup_seed(seed): 13 | os.environ['PYTHONHASHSEED'] = str(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed(seed) 16 | torch.cuda.manual_seed_all(seed) 17 | np.random.seed(seed) 18 | random.seed(seed) 19 | torch.backends.cudnn.deterministic = True 20 | torch.backends.cudnn.benchmark = False # benchmarking selects cuDNN algorithms non-deterministically, so keep it off for reproducible runs 21 | 22 | 23 | class InputPadder: 24 | """ Pads images such that dimensions are divisible by 8 """ 25 | def __init__(self, dims, div=8, mode='sintel'): 26 | self.ht, self.wd = dims[-2:] 27 | pad_ht = (((self.ht // div) + 1) * div - self.ht) % div 28 | pad_wd = (((self.wd // div) + 1) * div - self.wd) % div 29 | if mode == 'sintel': 30 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 31 | else: 32 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 33 | 34 | def pad_batch(self, batch): 35 | pad_batch = {} 36 | for key in batch.keys(): 37 | if torch.is_tensor(batch[key]) and len(batch[key].shape) == 4: 38 | pad_batch[key] = F.pad(batch[key], self._pad, mode='replicate') 39 | elif torch.is_tensor(batch[key]) and len(batch[key].shape) == 3: 40 | pad_batch[key] = F.pad(batch[key].unsqueeze(1), self._pad, mode='replicate').squeeze(1) 41 | return pad_batch 42 | 43 | def pad(self, *inputs): 44 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 45 | 46 | def unpad(self, x): 47 | if x is not None: 48 | ht, wd = x.shape[-2:] 49 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 50 | if len(x.shape) == 4: 51 | return x[..., c[0]:c[1], c[2]:c[3]] 52 | elif len(x.shape) == 3: 53 | return x[:, c[0]:c[1], c[2]:c[3]] 54 | else: 55 | raise NotImplementedError('not supported pad size for x.shape {}'.format(x.shape)) 56 | else: 57 | return None 58 | 59 | 60 | def ensure_folder(path): 61 | if not os.path.isdir(path): 62 | os.makedirs(path) 63 | 64 | 65 | def count_parameters(model): 66 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 67 | 68 | 69 | def count_all_parameters(model): 70 | return sum(p.numel() for p in model.parameters()) 71 | 72 | 73 | def build_module(module_path, module_name): 74 | module_path = module_path + '.' + module_name 75 | try: 76 | module = importlib.import_module(module_path) 77 | module = getattr(module, module_name) 78 | except Exception as e: 79 | logging.exception(e) 80 | raise ModuleNotFoundError("No module named '{}'".format(module_path)) 81 | 82 | return module 83 | --------------------------------------------------------------------------------
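To close, a minimal sketch of the intended `InputPadder` workflow (the model call is hypothetical; any network producing a flow map of the padded size would do): pad both images to a multiple of 8, run inference, then unpad the predicted flow back to the original resolution:

```
import torch
from utils.utils import InputPadder

image1 = torch.randn(1, 3, 260, 346)  # e.g. MVSEC resolution
image2 = torch.randn(1, 3, 260, 346)

padder = InputPadder(image1.shape, div=8)
image1_p, image2_p = padder.pad(image1, image2)
print(image1_p.shape)  # torch.Size([1, 3, 264, 352])

# flow_pred = model(image1_p, image2_p)  # hypothetical model call
flow_pred = torch.zeros(1, 2, 264, 352)  # stand-in for a network output
flow = padder.unpad(flow_pred)
print(flow.shape)      # torch.Size([1, 2, 260, 346])
```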