├── .gitignore ├── README.md ├── config ├── parser.py ├── train_ours.yml └── train_ours_exposuredecision.yml ├── datalist ├── generate_datalist.py ├── prepare.sh └── unknown_ex │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── dataloader ├── encodings.py ├── h5dataloader.py ├── h5dataset.py ├── h5dataset_fast.py ├── h5dataset_realdata.py └── util.py ├── generate_dataset ├── README.md ├── convert_unknown.py ├── syn_gopro.py ├── tools │ ├── add_hdf5_attribute.py │ ├── event_packagers.py │ ├── h5_to_memmap.py │ ├── read_events.py │ ├── rosbag_to_h5.py │ └── txt_to_h5.py └── upsampling │ ├── README.md │ ├── checkpoint │ └── .gitignore │ ├── example │ └── .gitignore │ ├── upsample.py │ └── utils │ ├── __init__.py │ ├── const.py │ ├── dataset.py │ ├── model.py │ ├── upsampler.py │ └── utils.py ├── infer_ours.py ├── logger ├── __init__.py ├── logger.py ├── logger_config.json └── visualization.py ├── loss ├── PerceptualSimilarity │ └── models │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── dist_model.py │ │ ├── networks_basic.py │ │ ├── pretrained_networks.py │ │ └── weights │ │ ├── v0.0 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth │ │ └── v0.1 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth ├── __init__.py ├── adversarial.py ├── discriminator.py ├── flow.py ├── reconstruction.py └── restore.py ├── models ├── DCNv2 │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── dcn_v2.py │ ├── dcn_v2_onnx.py │ ├── make.sh │ ├── setup.py │ ├── src │ │ ├── cpu │ │ │ ├── dcn_v2_cpu.cpp │ │ │ ├── dcn_v2_im2col_cpu.cpp │ │ │ ├── dcn_v2_im2col_cpu.h │ │ │ ├── dcn_v2_psroi_pooling_cpu.cpp │ │ │ └── vision.h │ │ ├── cuda │ │ │ ├── dcn_v2_cuda.cu │ │ │ ├── dcn_v2_im2col_cuda.cu │ │ │ ├── dcn_v2_im2col_cuda.h │ │ │ ├── dcn_v2_psroi_pooling_cuda.cu │ │ │ └── vision.h │ │ ├── dcn_v2.h │ │ └── vision.cpp │ ├── testcpu.py │ └── testcuda.py ├── FAC │ ├── README.md │ ├── __init__.py │ ├── install.sh │ └── kernelconv2d │ │ ├── KernelConv2D.py │ │ ├── KernelConv2D_cuda.cpp │ │ ├── KernelConv2D_cuda.h │ │ ├── KernelConv2D_kernel.cu │ │ ├── KernelConv2D_kernel.h │ │ ├── __init__.py │ │ ├── build │ │ ├── lib.linux-x86_64-3.8 │ │ │ └── kernelconv2d_cuda.cpython-38-x86_64-linux-gnu.so │ │ └── temp.linux-x86_64-3.8 │ │ │ ├── .ninja_deps │ │ │ ├── .ninja_log │ │ │ ├── KernelConv2D_cuda.o │ │ │ ├── KernelConv2D_kernel.o │ │ │ └── build.ninja │ │ ├── dist │ │ └── kernelconv2d_cuda-1.0.0-py3.8-linux-x86_64.egg │ │ ├── kernelconv2d_cuda.egg-info │ │ ├── PKG-INFO │ │ ├── SOURCES.txt │ │ ├── dependency_links.txt │ │ └── top_level.txt │ │ └── setup.py ├── Ours │ └── model_singleframe.py └── model_misc │ ├── base.py │ ├── model_util.py │ ├── resnet_3D.py │ ├── submodules.py │ └── unet.py ├── myutils ├── data_augmentation.py ├── event_visual_example.py ├── gradients.py ├── iwe.py ├── timers.py ├── utils.py └── vis_events │ ├── matplotlib_plot_events.py │ ├── tools │ ├── __init__.py │ ├── add_hdf5_attribute.py │ ├── event_packagers.py │ ├── h5_to_memmap.py │ ├── hxy_events2ply.py │ ├── read_events.py │ ├── rosbag_to_h5.py │ └── txt_to_h5.py │ └── visualization.py ├── scripts ├── infer_ours.sh └── train_ours.sh ├── train_ours.py └── train_ours_exposuredecision.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | ./upsampling/example/* 3 | ./tmp/* 4 | .vscode/* 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Event-based 
Blurry Frame Interpolation under Blind Exposure 2 | 3 | **Official implementation** of the following paper: 4 | 5 | Event-based Blurry Frame Interpolation under Blind Exposure by Wenming Weng, Yueyi Zhang, Zhiwei Xiong. In CVPR 2023. 6 | 7 | ## Dataset 8 | 9 | Please follow the instructions in the `generate_dataset` directory to generate the synthetic dataset. 10 | Our collected real-world dataset RealBlur-DAVIS can be downloaded from this [site](https://rec.ustc.edu.cn/share/421c2a20-0fd3-11ee-bf10-0bc6e486a7d3) with password: 64l7. 11 | 12 | ## Pretrained model 13 | 14 | The pretrained model will be released at this [site](https://rec.ustc.edu.cn/share/7582d0c0-0fd3-11ee-9b7b-233132bcb7d9) with password: uvon. 15 | 16 | ## Training and Inference 17 | 18 | Please check the files `scripts/train_ours.sh` and `scripts/infer_ours.sh` for training and inference. 19 | 20 | ## Citation 21 | 22 | If you find this work helpful, please consider citing our paper. 23 | 24 | ```latex 25 | @InProceedings{Weng_2023_CVPR, 26 | author = {Weng, Wenming and Zhang, Yueyi and Xiong, Zhiwei}, 27 | title = {Event-based Blurry Frame Interpolation under Blind Exposure}, 28 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition 29 | (CVPR)}, 30 | year = {2023}, 31 | } 32 | ``` 33 | 34 | ## Contact 35 | 36 | If you have any problems with the released code, please do not hesitate to contact me by email (wmweng@mail.ustc.edu.cn). 37 | -------------------------------------------------------------------------------- /config/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | import logging 4 | import argparse 5 | from pathlib import Path 6 | import yaml 7 | from datetime import datetime 8 | from operator import getitem 9 | from functools import reduce 10 | #local modules 11 | from logger import setup_logging 12 | 13 | 14 | class YAMLParser: 15 | """ YAML parser for config files """ 16 | def __init__(self, config, args, modification=None): 17 | # load config file and apply modification 18 | self._config = self._update_config(config, modification) 19 | self.args = args 20 | config_runid = self.config.get('id', None) 21 | parse_runid = args.runid 22 | assert config_runid is not None or parse_runid is not None, 'Please input runid!' 23 | if config_runid is not None: 24 | run_id = config_runid 25 | else: 26 | run_id = parse_runid 27 | 28 | # set save_dir where trained model and log will be saved. 29 | save_dir = Path(self.config['trainer']['output_path']) 30 | 31 | exper_name = self.config['experiment'] 32 | if run_id is None: # use timestamp as default run-id 33 | run_id = datetime.now().strftime(r'%m%d_%H%M%S') 34 | self._save_dir = save_dir / 'models' / exper_name / run_id 35 | self._log_dir = save_dir / 'logs' / exper_name / run_id 36 | 37 | # make directory for saving checkpoints and log.
38 | self.save_dir.mkdir(parents=True, exist_ok=True) 39 | self.log_dir.mkdir(parents=True, exist_ok=True) 40 | 41 | # save updated config file to the checkpoint dir 42 | self.log_config(self.config, self.log_dir / 'config.yml') 43 | 44 | # configure logging module 45 | setup_logging(self.log_dir) 46 | self.log_levels = { 47 | 0: logging.WARNING, 48 | 1: logging.INFO, 49 | 2: logging.DEBUG 50 | } 51 | 52 | @classmethod 53 | def from_args(cls, args, options=''): 54 | assert isinstance(args, argparse.ArgumentParser), 'args is not ArgumentParser class' 55 | 56 | for opt in options: 57 | args.add_argument(*opt.flags, default=None, type=opt.type) 58 | args = args.parse_args() 59 | 60 | msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example." 61 | assert args.config is not None, msg_no_cfg 62 | cfg_fname = Path(args.config) 63 | config = cls.parse_config(cfg_fname) 64 | 65 | # parse custom cli options into dictionary 66 | modification = {opt.target: getattr(args, cls._get_opt_name(opt.flags)) for opt in options} 67 | return cls(config, args, modification) 68 | 69 | def get_logger(self, name, verbosity=2): 70 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys()) 71 | assert verbosity in self.log_levels, msg_verbosity 72 | logger = logging.getLogger(name) 73 | logger.setLevel(self.log_levels[verbosity]) 74 | return logger 75 | 76 | @staticmethod 77 | def parse_config(file): 78 | with open(file) as fid: 79 | yaml_config = yaml.load(fid, Loader=yaml.FullLoader) 80 | 81 | return yaml_config 82 | 83 | @staticmethod 84 | def log_config(config, path_config): 85 | with open(path_config, "w") as fid: 86 | yaml.dump(config, fid) 87 | 88 | def __getitem__(self, name): 89 | """Access items like ordinary dict.""" 90 | return self.config[name] 91 | 92 | @staticmethod 93 | def _get_opt_name(flags): 94 | for flg in flags: 95 | if flg.startswith('--'): 96 | return flg.replace('--', '') 97 | return flags[0].replace('--', '') 98 | 99 | @classmethod 100 | def _update_config(cls, config, modification): 101 | if modification is None: 102 | return config 103 | 104 | for k, v in modification.items(): 105 | if v is not None: 106 | cls._set_by_path(config, k, v) 107 | return config 108 | 109 | @classmethod 110 | def _set_by_path(cls, tree, keys, value): 111 | """Set a value in a nested object in tree by sequence of keys.""" 112 | keys = keys.split(';') 113 | cls._get_by_path(tree, keys[:-1])[keys[-1]] = value 114 | 115 | @classmethod 116 | def _get_by_path(cls, tree, keys): 117 | """Access a nested object in tree by sequence of keys.""" 118 | return reduce(getitem, keys, tree) 119 | 120 | @property 121 | def config(self): 122 | return self._config 123 | 124 | @config.setter 125 | def config(self, config): 126 | self._config = config 127 | 128 | @property 129 | def save_dir(self): 130 | return self._save_dir 131 | 132 | @property 133 | def log_dir(self): 134 | return self._log_dir 135 | 136 | 137 | if __name__ == '__main__': 138 | args = argparse.ArgumentParser(description='test YAMLParser') 139 | args.add_argument('-c', '--config', default=None, type=str) 140 | args.add_argument('-id', '--runid', default=None, type=str) 141 | args.add_argument('-r', '--resume', default=None, type=str, 142 | help='path to latest checkpoint (default: None)') 143 | 144 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') 145 | options = [ 146 | CustomArgs(['-lr', '--learning_rate'], type=float, target='test;item1;body1'), 147 | 
CustomArgs(['-bs', '--batch_size'], type=int, target='data_loader;args;batch_size') 148 | ] 149 | 150 | config_parser = YAMLParser.from_args(args, options) 151 | train_logger = config_parser.get_logger('train') 152 | test_logger = config_parser.get_logger('test') 153 | 154 | pass -------------------------------------------------------------------------------- /config/train_ours.yml: -------------------------------------------------------------------------------- 1 | experiment: Ours 2 | id: name your ex 3 | 4 | SCALE: &SCALE 2 5 | ORI_SCALE: &ORI_SCALE down2 6 | TIME_BINS: &TIME_BINS 16 7 | NumFramePerPeriod: &NumFramePerPeriod 16 8 | NumFramePerBlurry: &NumFramePerBlurry 16 # valid if ExposureMethod is Fixed 9 | NumPeriodPerSeq: &NumPeriodPerSeq 4 # 4 10 | SlidingWindowSeq: &SlidingWindowSeq 4 # 4 11 | NumPeriodPerLoad: &NumPeriodPerLoad 1 12 | SlidingWindowLoad: &SlidingWindowLoad 1 13 | ExposureMethod: &ExposureMethod Custom # Auto/Fixed/Custom 14 | ExposureTime: &ExposureTime [9, 10, 11, 12, 13, 14, 15] # valid if ExposureMethod is Custom 15 | NeedNeighborGT: &NeedNeighborGT False 16 | DeblurPretrain: &DeblurPretrain False 17 | 18 | BatchSize: &BatchSize 8 # 8 19 | 20 | NoiseEnabled: &NoiseEnabled False # False for real-world data 21 | 22 | PATH_TO_OUTPUT: &PATH_TO_OUTPUT /path/to/output 23 | PATH_TO_TRAIN: &PATH_TO_TRAIN /path/to/input/train 24 | PATH_TO_VALID: &PATH_TO_VALID /path/to/input/valid 25 | 26 | model: 27 | name: EVFIAutoEx 28 | args: 29 | FrameBasech: 64 30 | EventBasech: 64 31 | InterCH: 64 32 | TB: *TIME_BINS 33 | norm: null 34 | activation: LeakyReLU 35 | 36 | # exposure decision 37 | # without exposure estimation 38 | UseGTEx: False # if True, then the configs below are invalid 39 | FixEx: null # if not None, then used for controlling, 0~1, and the configs below are invalid 40 | # with exposure estimation 41 | BlurryFashion: RGBLap # DarkCh, Lap, RGB, RGBDark, RGBLap 42 | BLInch: 4 # 1 for DarkCh, Lap; 3 for RGB; 4 for RGBDark, RGBLap 43 | UseEvents: True 44 | LoadPretrainEX: False # if LoadPretrainEX is True, then must provide PretrainedEXPath 45 | PretrainedEXPath: /path/to/pretrainedExposure 46 | FrozenEX: False 47 | 48 | # time-exposure control 49 | step: 12 # 8 16 50 | DualPath: True 51 | 52 | # modification 53 | residual: True 54 | 55 | # detail restoration 56 | DetailEnabled: True 57 | channels: [16, 24, 32, 64] # [16, 24, 32, 64] [32, 64, 96, 128], only valid when enabled is True 58 | 59 | optimizer: 60 | name: Adam 61 | args: 62 | lr: !!float 1e-4 # pretrain with 1e-4, finetune with 1e-5 63 | # weight_decay: !!float 1e-4 64 | betas: [0.9, 0.999] 65 | amsgrad: False 66 | 67 | # lr_scheduler: 68 | # name: ExponentialLR 69 | # args: 70 | # # step_size: 10000 # epochs or iterations according to the training mode 71 | # gamma: 0.95 72 | 73 | lr_scheduler: 74 | name: StepLR 75 | args: 76 | step_size: !!float 2e5 # 2e5 # epochs or iterations according to the training mode 77 | gamma: 0.5 78 | 79 | trainer: 80 | output_path: *PATH_TO_OUTPUT 81 | epoch_based_train: 82 | enabled: False 83 | epochs: 2 84 | save_period: 1 # save model every 'save_period' epoch 85 | train_log_step: 100 # total number for printing train log in one epoch 86 | valid_log_step: 100 # total number for printing train log in one epoch 87 | valid_step: 1 # epoch steps for validation 88 | iteration_based_train: 89 | enabled: True 90 | iterations: !!float 2e6 91 | save_period: 1000 # save model every 'save_period' iteration 92 | train_log_step: 50 # iteration steps for printing train log 93 | 
valid_log_step: 50 # iteration steps for printing valid log 94 | valid_step: 5000 # iteration steps for validation 95 | lr_change_rate: 1 # iteration steps to perform "lr_scheduler.step()" 96 | monitor: 'min valid_loss' 97 | early_stop: 10 # max valid instervals to continue to train 98 | tensorboard: True 99 | accu_step: 1 # increase batch size while saving memory 100 | do_validation: True 101 | lr_min: !!float 1e-6 102 | vis: 103 | enabled: True 104 | train_img_writer_num: 20 # iteration steps for visualizing train items 105 | valid_img_writer_num: 20 # iteration steps for visualizing valid items 106 | 107 | train_dataloader: 108 | use_ddp: True 109 | path_to_datalist_txt: *PATH_TO_TRAIN 110 | batch_size: *BatchSize 111 | shuffle: True 112 | num_workers: 4 113 | pin_memory: True 114 | drop_last: True 115 | dataset: 116 | scale: *SCALE 117 | ori_scale: *ORI_SCALE 118 | time_bins: *TIME_BINS 119 | NumFramePerPeriod: *NumFramePerPeriod 120 | NumFramePerBlurry: *NumFramePerBlurry 121 | NumPeriodPerSeq: *NumPeriodPerSeq 122 | SlidingWindowSeq: *SlidingWindowSeq 123 | NumPeriodPerLoad: *NumPeriodPerLoad 124 | SlidingWindowLoad: *SlidingWindowLoad 125 | ExposureMethod: *ExposureMethod 126 | ExposureTime: *ExposureTime 127 | NeedNeighborGT: *NeedNeighborGT 128 | DeblurPretrain: *DeblurPretrain 129 | data_augment: 130 | enabled: True 131 | augment: ['RandomCrop', 'CenterCrop', "HorizontalFlip", "VertivcalFlip", 'Noise', 'HotPixel'] 132 | random_crop: 133 | enabled: True 134 | size: [128, 128] # HxW, related to HR size 135 | center_crop: 136 | enabled: False 137 | size: [128, 128] 138 | flip: 139 | enabled: True 140 | horizontal_prob: 0.5 141 | vertical_prob: 0.5 142 | noise: 143 | enabled: *NoiseEnabled # False for real-world data 144 | noise_std: 1.0 145 | noise_fraction: 0.05 146 | hot_pixel: 147 | enabled: *NoiseEnabled # False for real-world data 148 | hot_pixel_std: 2.0 149 | hot_pixel_fraction: 0.001 150 | 151 | valid_dataloader: 152 | use_ddp: True 153 | path_to_datalist_txt: *PATH_TO_VALID 154 | batch_size: 2 155 | shuffle: False 156 | num_workers: 4 157 | pin_memory: True 158 | drop_last: False 159 | dataset: 160 | scale: *SCALE 161 | ori_scale: *ORI_SCALE 162 | time_bins: *TIME_BINS 163 | NumFramePerPeriod: *NumFramePerPeriod 164 | NumFramePerBlurry: *NumFramePerBlurry 165 | NumPeriodPerSeq: *NumPeriodPerSeq 166 | SlidingWindowSeq: *SlidingWindowSeq 167 | NumPeriodPerLoad: *NumPeriodPerLoad 168 | SlidingWindowLoad: *SlidingWindowLoad 169 | ExposureMethod: *ExposureMethod 170 | ExposureTime: *ExposureTime 171 | NeedNeighborGT: *NeedNeighborGT 172 | data_augment: 173 | enabled: True 174 | augment: ['RandomCrop', 'CenterCrop', "HorizontalFlip", "VertivcalFlip", 'Noise', 'HotPixel'] 175 | random_crop: 176 | enabled: False 177 | size: [128, 128] 178 | center_crop: 179 | enabled: True 180 | size: [128, 128] 181 | flip: 182 | enabled: False 183 | horizontal_prob: 0.5 184 | vertical_prob: 0.5 185 | noise: 186 | enabled: *NoiseEnabled # False for real-world data 187 | noise_std: 1.0 188 | noise_fraction: 0.05 189 | hot_pixel: 190 | enabled: *NoiseEnabled # False for real-world data 191 | hot_pixel_std: 2.0 192 | hot_pixel_fraction: 0.001 193 | -------------------------------------------------------------------------------- /config/train_ours_exposuredecision.yml: -------------------------------------------------------------------------------- 1 | experiment: ExposurePretrain 2 | id: name your ex 3 | 4 | SCALE: &SCALE 2 5 | ORI_SCALE: &ORI_SCALE down2 6 | TIME_BINS: &TIME_BINS 16 7 | 
NumFramePerPeriod: &NumFramePerPeriod 16 8 | NumFramePerBlurry: &NumFramePerBlurry 16 # valid if ExposureMethod is Fixed 9 | NumPeriodPerSeq: &NumPeriodPerSeq 2 10 | SlidingWindowSeq: &SlidingWindowSeq 2 11 | NumPeriodPerLoad: &NumPeriodPerLoad 1 12 | SlidingWindowLoad: &SlidingWindowLoad 1 13 | ExposureMethod: &ExposureMethod Custom # Auto/Fixed/Custom 14 | ExposureTime: &ExposureTime [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] # valid if ExposureMethod is Custom 15 | NeedNeighborGT: &NeedNeighborGT False 16 | DeblurPretrain: &DeblurPretrain False 17 | 18 | BatchSize: &BatchSize 4 19 | 20 | NoiseEnabled: &NoiseEnabled False # False for real-world data 21 | 22 | PATH_TO_OUTPUT: &PATH_TO_OUTPUT /path/to/output 23 | PATH_TO_TRAIN: &PATH_TO_TRAIN /path/to/input/train 24 | PATH_TO_VALID: &PATH_TO_VALID /path/to/input/valid 25 | 26 | model: 27 | name: ExposureDecision # ExposureDecision, ExposureDecisionNoEvents 28 | BlurryFashion: RGBLap # DarkCh, Lap, RGB, RGBDark, RGBLap 29 | args: 30 | EventInch: 32 # 2*TB 31 | 32 | BLInch: 4 # 1 for DarkCh, Lap; 3 for RGB; 4 for RGBDark, RGBLap 33 | 34 | InterCH: 64 35 | Group: 4 36 | norm: null 37 | activation: LeakyReLU 38 | 39 | optimizer: 40 | name: Adam 41 | args: 42 | lr: !!float 1e-4 # pretrain with 1e-4, finetune with 1e-5 43 | # weight_decay: !!float 1e-4 44 | betas: [0.9, 0.999] 45 | amsgrad: False 46 | 47 | # lr_scheduler: 48 | # name: ExponentialLR 49 | # args: 50 | # # step_size: 10000 # epochs or iterations according to the training mode 51 | # gamma: 0.95 52 | 53 | lr_scheduler: 54 | name: StepLR 55 | args: 56 | step_size: !!float 2e5 # 2e5 # epochs or iterations according to the training mode 57 | gamma: 0.5 58 | 59 | trainer: 60 | output_path: *PATH_TO_OUTPUT 61 | epoch_based_train: 62 | enabled: False 63 | epochs: 2 64 | save_period: 1 # save model every 'save_period' epoch 65 | train_log_step: 100 # total number for printing train log in one epoch 66 | valid_log_step: 100 # total number for printing train log in one epoch 67 | valid_step: 1 # epoch steps for validation 68 | iteration_based_train: 69 | enabled: True 70 | iterations: !!float 2e6 71 | save_period: 1000 # save model every 'save_period' iteration 72 | train_log_step: 50 # iteration steps for printing train log 73 | valid_log_step: 50 # iteration steps for printing valid log 74 | valid_step: 5000 # iteration steps for validation 75 | lr_change_rate: 1 # iteration steps to perform "lr_scheduler.step()" 76 | monitor: 'min valid_loss' 77 | early_stop: 10 # max valid instervals to continue to train 78 | tensorboard: True 79 | accu_step: 1 # increase batch size while saving memory 80 | do_validation: True 81 | lr_min: !!float 1e-6 82 | vis: 83 | enabled: True 84 | train_img_writer_num: 20 # iteration steps for visualizing train items 85 | valid_img_writer_num: 20 # iteration steps for visualizing valid items 86 | 87 | train_dataloader: 88 | use_ddp: True 89 | path_to_datalist_txt: *PATH_TO_TRAIN 90 | batch_size: *BatchSize 91 | shuffle: True 92 | num_workers: 4 93 | pin_memory: True 94 | drop_last: True 95 | dataset: 96 | scale: *SCALE 97 | ori_scale: *ORI_SCALE 98 | time_bins: *TIME_BINS 99 | NumFramePerPeriod: *NumFramePerPeriod 100 | NumFramePerBlurry: *NumFramePerBlurry 101 | NumPeriodPerSeq: *NumPeriodPerSeq 102 | SlidingWindowSeq: *SlidingWindowSeq 103 | NumPeriodPerLoad: *NumPeriodPerLoad 104 | SlidingWindowLoad: *SlidingWindowLoad 105 | ExposureMethod: *ExposureMethod 106 | ExposureTime: *ExposureTime 107 | NeedNeighborGT: *NeedNeighborGT 108 | DeblurPretrain: 
*DeblurPretrain 109 | data_augment: 110 | enabled: True 111 | augment: ['RandomCrop', 'CenterCrop', "HorizontalFlip", "VertivcalFlip", 'Noise', 'HotPixel'] 112 | random_crop: 113 | enabled: True 114 | size: [128, 128] # HxW, related to HR size 115 | center_crop: 116 | enabled: False 117 | size: [128, 128] 118 | flip: 119 | enabled: True 120 | horizontal_prob: 0.5 121 | vertical_prob: 0.5 122 | noise: 123 | enabled: *NoiseEnabled # False for real-world data 124 | noise_std: 1.0 125 | noise_fraction: 0.05 126 | hot_pixel: 127 | enabled: *NoiseEnabled # False for real-world data 128 | hot_pixel_std: 2.0 129 | hot_pixel_fraction: 0.001 130 | 131 | valid_dataloader: 132 | use_ddp: True 133 | path_to_datalist_txt: *PATH_TO_VALID 134 | batch_size: 2 135 | shuffle: False 136 | num_workers: 4 137 | pin_memory: True 138 | drop_last: False 139 | dataset: 140 | scale: *SCALE 141 | ori_scale: *ORI_SCALE 142 | time_bins: *TIME_BINS 143 | NumFramePerPeriod: *NumFramePerPeriod 144 | NumFramePerBlurry: *NumFramePerBlurry 145 | NumPeriodPerSeq: *NumPeriodPerSeq 146 | SlidingWindowSeq: *SlidingWindowSeq 147 | NumPeriodPerLoad: *NumPeriodPerLoad 148 | SlidingWindowLoad: *SlidingWindowLoad 149 | ExposureMethod: *ExposureMethod 150 | ExposureTime: *ExposureTime 151 | NeedNeighborGT: *NeedNeighborGT 152 | data_augment: 153 | enabled: True 154 | augment: ['RandomCrop', 'CenterCrop', "HorizontalFlip", "VertivcalFlip", 'Noise', 'HotPixel'] 155 | random_crop: 156 | enabled: False 157 | size: [128, 128] 158 | center_crop: 159 | enabled: True 160 | size: [128, 128] 161 | flip: 162 | enabled: False 163 | horizontal_prob: 0.5 164 | vertical_prob: 0.5 165 | noise: 166 | enabled: *NoiseEnabled # False for real-world data 167 | noise_std: 1.0 168 | noise_fraction: 0.05 169 | hot_pixel: 170 | enabled: *NoiseEnabled # False for real-world data 171 | hot_pixel_std: 2.0 172 | hot_pixel_fraction: 0.001 173 | -------------------------------------------------------------------------------- /datalist/generate_datalist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import argparse 4 | import glob 5 | 6 | 7 | def get_flags(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--data_path', required=True) 10 | parser.add_argument('--valid_data_path', default=None) 11 | parser.add_argument('--num', type=int, default=None) 12 | parser.add_argument('--valid_num', type=int, default=None) 13 | parser.add_argument('--portion', type=float, default=None) 14 | parser.add_argument('--mode', type=int, choices=[0, 1, 2, 3], required=True) 15 | parser.add_argument('--seed', type=int, default=123) 16 | parser.add_argument('--train_txt_name', type=str, default='train.txt') 17 | parser.add_argument('--valid_txt_name', type=str, default='valid.txt') 18 | flags = parser.parse_args() 19 | 20 | return flags 21 | 22 | 23 | def write_txt(path: str, data: list): 24 | with open(path, 'w') as f: 25 | f.writelines([str(i) + '\n' for i in data]) 26 | 27 | 28 | if __name__ == '__main__': 29 | flags = get_flags() 30 | 31 | data_path = flags.data_path 32 | valid_data_path = flags.valid_data_path 33 | num = flags.num 34 | valid_num = flags.valid_num 35 | portion = flags.portion 36 | mode = flags.mode 37 | train_txt_name = flags.train_txt_name 38 | valid_txt_name = flags.valid_txt_name 39 | seed = flags.seed 40 | 41 | assert os.path.exists(data_path) 42 | data_paths = sorted(glob.glob(os.path.join(data_path, '*.h5'))) 43 | data_len = len(data_paths) 44 | if valid_data_path is not 
None: 45 | assert os.path.exists(valid_data_path) 46 | valid_data_paths = sorted(glob.glob(os.path.join(valid_data_path, '*.h5'))) 47 | valid_data_len = len(valid_data_paths) 48 | 49 | if mode == 0: 50 | if num is None: 51 | num = data_len 52 | assert num > 0 and num <= data_len, f'num must be set > 0 and < data len {data_len}, but got {num}' 53 | random.seed(seed) 54 | train_candidates = sorted(random.sample(data_paths, num)) 55 | write_txt(f'datalist/{train_txt_name}', train_candidates) 56 | 57 | print(f'Sample {num} training items from {data_path}') 58 | 59 | elif mode == 1: 60 | assert num is not None 61 | assert valid_num is not None 62 | assert valid_num > 0 and valid_num < data_len 63 | assert num > 0 and num < data_len 64 | assert (valid_num + num) > 0 and (valid_num + num) <= data_len 65 | 66 | random.seed(seed) 67 | train_candidates = random.sample(data_paths, num) 68 | left_candidates = sorted(list(set(data_paths) - set(train_candidates))) 69 | random.seed(seed) 70 | valid_candidates = sorted(random.sample(left_candidates, valid_num)) 71 | 72 | write_txt(f'datalist/{train_txt_name}', train_candidates) 73 | write_txt(f'datalist/{valid_txt_name}', valid_candidates) 74 | 75 | print(f'Sample {num} training items from {data_path}') 76 | print(f'Sample {valid_num} validating items from {data_path}') 77 | 78 | elif mode == 2: 79 | assert portion is not None 80 | train_num = int(data_len * portion) 81 | random.seed(seed) 82 | train_candidates = random.sample(data_paths, train_num) 83 | valid_candidates = sorted(list(set(data_paths) - set(train_candidates))) 84 | 85 | write_txt(f'datalist/{train_txt_name}', train_candidates) 86 | write_txt(f'datalist/{valid_txt_name}', valid_candidates) 87 | 88 | print(f'Sample {train_num} training items from {data_path}') 89 | print(f'Sample {data_len - train_num} validating items from {data_path}') 90 | 91 | elif mode == 3: 92 | assert valid_data_path is not None 93 | assert valid_num is not None 94 | assert num is not None 95 | 96 | random.seed(seed) 97 | train_candidates = sorted(random.sample(data_paths, num)) 98 | random.seed(seed) 99 | valid_candidates = sorted(random.sample(valid_data_paths, valid_num)) 100 | 101 | write_txt(f'datalist/{train_txt_name}', train_candidates) 102 | write_txt(f'datalist/{valid_txt_name}', valid_candidates) 103 | 104 | print(f'Sample {num} training items from {data_path}') 105 | print(f'Sample {valid_num} validating items from {valid_data_path}') 106 | 107 | else: 108 | raise Exception(f'Invalid mode {mode}') 109 | -------------------------------------------------------------------------------- /datalist/prepare.sh: -------------------------------------------------------------------------------- 1 | ## example use for generating datalist 2 | 3 | python datalist/generate_datalist.py --mode 2 --data_path /path/to/h5 --portion 0.9 \ 4 | --train_txt_name train.txt \ 5 | --valid_txt_name valid.txt 6 | 7 | python datalist/generate_datalist.py --mode 1 --data_path /path/to/h5 \ 8 | --num 4 --valid_num 2 \ 9 | --train_txt_name train.txt \ 10 | --valid_txt_name valid.txt 11 | 12 | python datalist/generate_datalist.py --mode 0 --data_path /path/to/h5 \ 13 | --num 2 \ 14 | --train_txt_name train.txt 15 | -------------------------------------------------------------------------------- /datalist/unknown_ex/test.txt: -------------------------------------------------------------------------------- 1 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_00.h5 2 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_01.h5 3 
| /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_02.h5 4 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_03.h5 5 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_04.h5 6 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_05.h5 7 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_06.h5 8 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_07.h5 9 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_08.h5 10 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_09.h5 11 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_10.h5 12 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_11.h5 13 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_12.h5 14 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_13.h5 15 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_14.h5 16 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_15.h5 17 | -------------------------------------------------------------------------------- /datalist/unknown_ex/train.txt: -------------------------------------------------------------------------------- 1 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_00.h5 2 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_01.h5 3 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_02.h5 4 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_03.h5 5 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_04.h5 6 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_05.h5 7 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_06.h5 8 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_07.h5 9 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_08.h5 10 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_09.h5 11 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_10.h5 12 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_11.h5 13 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_12.h5 14 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_13.h5 15 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_14.h5 16 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_15.h5 17 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_16.h5 18 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_17.h5 19 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_18.h5 20 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_19.h5 21 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_20.h5 22 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_21.h5 23 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_22.h5 24 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_23.h5 25 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_24.h5 26 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_25.h5 27 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_26.h5 28 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_27.h5 29 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_28.h5 30 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_29.h5 31 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_30.h5 32 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_31.h5 33 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_32.h5 34 | 
/data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_33.h5 35 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_34.h5 36 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_35.h5 37 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_36.h5 38 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_37.h5 39 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_38.h5 40 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_39.h5 41 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_40.h5 42 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_41.h5 43 | /data2/wengwm/work/dataset/unknown_ex/h5/train/train_subset_42.h5 44 | -------------------------------------------------------------------------------- /datalist/unknown_ex/valid.txt: -------------------------------------------------------------------------------- 1 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_00.h5 2 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_01.h5 3 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_05.h5 4 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_08.h5 5 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_09.h5 6 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_10.h5 7 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_11.h5 8 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_12.h5 9 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_13.h5 10 | /data2/wengwm/work/dataset/unknown_ex/h5/test/test_subset_15.h5 11 | -------------------------------------------------------------------------------- /dataloader/util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @author: ZhangX 5 | """ 6 | import os 7 | import numpy as np 8 | 9 | 10 | def get_filename(path,suffix): 11 | ## function used to get file names 12 | namelist=[] 13 | filelist = os.listdir(path) 14 | for i in filelist: 15 | if os.path.splitext(i)[1] == suffix: 16 | namelist.append(i) 17 | namelist.sort() 18 | return namelist 19 | 20 | def mkdir(path): 21 | if not os.path.exists(path): 22 | os.makedirs(path) 23 | 24 | def normalize(img, max_val=255.): 25 | if (img.max() == img.min()): 26 | return img # no normalization 27 | else: 28 | return (img - img.min()) * max_val / (img.max() - img.min()) 29 | 30 | def fold_time_dim(inp): 31 | if inp.ndim == 4: 32 | T,C,H,W = inp.shape 33 | out = inp.reshape((T*C, H, W)) #[T,C,H,W] -> [T*C,H,W] 34 | elif inp.ndim == 5: 35 | N,T,C,H,W = inp.shape 36 | out = inp.reshape((N,T*C, H, W)) #[N,T,C,H,W] -> [N,T*C,H,W] 37 | return out 38 | 39 | def filter_events(event_data, start, end): 40 | ## filter events based on temporal dimension 41 | x = event_data['x'][event_data['t']>=start] 42 | y = event_data['y'][event_data['t']>=start] 43 | p = event_data['p'][event_data['t']>=start] 44 | t = event_data['t'][event_data['t']>=start] 45 | 46 | x = x[t<=end] 47 | y = y[t<=end] 48 | p = p[t<=end] 49 | t = t[t<=end] 50 | return x,y,p,t 51 | 52 | def filter_events_by_space(key,x1,x2,x3, start, end): 53 | ## filter events based on spatial dimension 54 | # start inclusive and end exclusive 55 | new_x1 = x1[key>=start] 56 | new_x2 = x2[key>=start] 57 | new_x3 = x3[key>=start] 58 | new_key = key[key>=start] 59 | 60 | new_x1 = new_x1[new_key=0 and noise<=1 98 | if noise>0: 99 | num_noise = int(noise * len(t)) 100 | img_size = (H, W) 101 | noise_x = 
np.random.randint(0,img_size[1],(num_noise,1)) 102 | noise_y = np.random.randint(0,img_size[0],(num_noise,1)) 103 | noise_p = np.random.randint(0,2,(num_noise,1)) 104 | noise_t = np.random.randint(0,idx+1,(num_noise,1)) 105 | # add noise 106 | np.add.at(eframe, noise_x + noise_y*W + noise_p*W*H + noise_t*W*H*C, 1) 107 | 108 | eframe = np.reshape(eframe, (T,C,H,W)) 109 | 110 | return eframe 111 | 112 | def event2frame(event, img_size, ts, f_span, total_span, num_frame, noise, roiTL=(0,0)): 113 | ## convert event streams to [T, C, H, W] event tensor, C=2 indicates polarity 114 | f_start, f_end = f_span 115 | total_start, total_end = total_span 116 | 117 | preE = np.zeros((num_frame, 2, img_size[0], img_size[1])) 118 | postE = np.zeros((num_frame, 2, img_size[0], img_size[1])) 119 | interval = (total_end - total_start) / num_frame # based on whole event range 120 | 121 | if event['t'].shape[0] > 0: 122 | preE = e2f_detail(event,preE,ts,f_start,interval, noise, roiTL, img_size) 123 | postE = e2f_detail(event,postE,ts,f_end,interval, noise, roiTL, img_size) 124 | 125 | pre_coef = (ts - f_start) / (f_end - f_start) 126 | post_coef = (f_end - ts) / (f_end - f_start) 127 | 128 | return preE, postE, pre_coef, post_coef 129 | 130 | def event_single_intergral(event, img_size, span, roiTL=(0,0)): 131 | ## generate event frames for sharp-event loss 132 | start, end = span 133 | H, W = img_size 134 | event_img = np.zeros((H, W)).ravel() 135 | 136 | x,y,p,t = filter_events(event, start, end) # filter events by temporal dim 137 | x,y,p,t = filter_events_by_space(x,y,p,t,roiTL[1], roiTL[1]+img_size[1]) # filter events by x dim 138 | y,x,p,t = filter_events_by_space(y,x,p,t,roiTL[0], roiTL[0]+img_size[0]) # filter events by y dim 139 | x -= roiTL[1] # shift minima to zero 140 | y -= roiTL[0] # shift minima to zero 141 | 142 | np.add.at(event_img, x + y*W, p) 143 | event_img = event_img.reshape((H,W)) 144 | 145 | return event_img -------------------------------------------------------------------------------- /generate_dataset/README.md: -------------------------------------------------------------------------------- 1 | ## Synthetic Dataset 2 | 3 | First download the GoPro dataset from this [site](https://seungjunnah.github.io/Datasets/gopro). Then run `syn_gopro.py` to generate the synthetic data. 4 | 5 | ## Semi-real Dataset 6 | 7 | First download the RealSharp-DAVIS dataset from this [site](https://intelpro.github.io/UEVD/). Then run `convert_unknown.py` to generate the semi-real data.
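Both scripts package the frames and events of each sequence into HDF5 files using the packagers in `tools/event_packagers.py`; `syn_gopro.py` additionally requires the `esim_py` event simulator to be installed. A minimal sketch of the two invocations is shown below (all paths are placeholders to replace with your own locations):

```bash
# Synthetic data: simulate events from GoPro frames and package them as HDF5
# (replace the placeholder paths with your local directories).
python generate_dataset/syn_gopro.py \
    --root_data_path /path/to/gopro \
    --path_to_h5 /path/to/output/h5

# Semi-real data: package the recorded RealSharp-DAVIS frames and events as HDF5
# (replace the placeholder paths with your local directories).
python generate_dataset/convert_unknown.py \
    --path_to_data /path/to/realsharp_davis \
    --path_to_h5 /path/to/output/h5
```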
8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /generate_dataset/convert_unknown.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from glob import glob 4 | import argparse 5 | import os 6 | import pandas as pd 7 | import cv2 8 | # Local modules 9 | from tools.event_packagers import * 10 | 11 | 12 | def get_flags(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--path_to_data', type=str, default='/path/to/data') 15 | parser.add_argument('--path_to_h5', type=str, default='/path/to/output') 16 | args = parser.parse_args() 17 | return args 18 | 19 | 20 | def main(): 21 | flags = get_flags() 22 | 23 | path_to_data = flags.path_to_data 24 | path_to_h5 = flags.path_to_h5 25 | os.makedirs(path_to_h5, exist_ok=True) 26 | 27 | sequences = [os.path.join(path_to_data, dir) for dir in os.listdir(path_to_data)] 28 | print(f'all sequences: {sequences}') 29 | 30 | for sequence in sequences: 31 | print(f'Processing sequence: {sequence}') 32 | events_file = os.path.join(sequence, 'events', 'events.npz') 33 | imgs_dir = sorted(glob(os.path.join(sequence, 'frames', '*.png'))) 34 | timestamps = pd.read_csv(os.path.join(sequence, 'frame_time.txt'), header=None).values.flatten().tolist() 35 | num_imgs = len(imgs_dir) 36 | 37 | h5_dir = os.path.join(path_to_h5, f'{os.path.basename(sequence)}.h5') 38 | ep = hdf5_packager_multiscale(h5_dir) 39 | 40 | print('Adding events') 41 | events = np.load(events_file)['data'] 42 | x, y, t, p = events['x'].astype(np.int16), events['y'].astype(np.int16), events['timestamp'].astype(np.float64), events['polarity'].astype(np.int8) 43 | p[p==0] = -1 44 | t = t / 1e6 # microsecs to seconds 45 | ep.package_events('ori', x, y, t, p) 46 | 47 | print('Adding images') 48 | for idx in range(num_imgs): 49 | img = cv2.imread(imgs_dir[idx], 1) 50 | timestamp = int(timestamps[idx].split(' ')[1]) / 1e6 # microsecs to seconds 51 | resolution = img.shape[0:2] 52 | ep.package_image('ori', img, timestamp, idx) 53 | 54 | ep.add_event_indices() 55 | ep.add_data(resolution) 56 | 57 | print('all {} files are done!'.format(len(sequences))) 58 | 59 | if __name__ == '__main__': 60 | main() -------------------------------------------------------------------------------- /generate_dataset/syn_gopro.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from glob import glob 5 | import argparse 6 | from tqdm import tqdm 7 | import esim_py 8 | import random 9 | import shutil 10 | # local modules 11 | from tools.event_packagers import hdf5_packager_multiscale 12 | 13 | 14 | config = { 15 | 'Cp_init': 0.1, 16 | 'Cn_init': 0.1, 17 | 'refractory_period': 1e-4, 18 | 'log_eps': 1e-3, 19 | 'use_log':True, 20 | 'CT_range': [0.2, 0.5], 21 | 'max_CT': 0.5, 22 | 'min_CT': 0.2, 23 | 'mu': 1, 24 | 'sigma': 0.1, 25 | 26 | 'fps': 240, 27 | } 28 | 29 | 30 | def write_img(img: np.ndarray, idx: int, imgs_dir: str): 31 | assert os.path.isdir(imgs_dir) 32 | path = os.path.join(imgs_dir, "%05d.png" % idx) 33 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 34 | cv2.imwrite(path, img) 35 | 36 | 37 | def write_timestamps(timestamps: list, timestamps_filename: str): 38 | with open(timestamps_filename, 'w') as t_file: 39 | t_file.writelines([str(t) + '\n' for t in timestamps]) 40 | 41 | 42 | def write_config(config_path): 43 | with open(config_path, 'w') as f: 44 | for key, value in config.items(): 
45 | f.write(f'{key}: {value} \n') 46 | 47 | 48 | def prepare_output_dir(src_dir: str, dest_dir: str): 49 | # Copy directory structure. 50 | def ignore_files(directory, files): 51 | return [f for f in files if os.path.isfile(os.path.join(directory, f))] 52 | shutil.copytree(src_dir, dest_dir, ignore=ignore_files) 53 | 54 | 55 | def get_flags(): 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument('--root_data_path', default='/path/to/data') 58 | parser.add_argument('--path_to_h5', default='/path/to/output') 59 | args = parser.parse_args() 60 | 61 | return args 62 | 63 | 64 | if __name__ == '__main__': 65 | flags = get_flags() 66 | 67 | root_data_path = flags.root_data_path 68 | path_to_h5 = flags.path_to_h5 69 | 70 | if not os.path.exists(path_to_h5): 71 | os.makedirs(path_to_h5, exist_ok=False) 72 | 73 | data_dirs = sorted([os.path.join(root_data_path, dir) for dir in os.listdir(root_data_path)]) 74 | resolution = None 75 | fps = config['fps'] 76 | 77 | esim = esim_py.EventSimulator(config['Cp_init'], 78 | config['Cn_init'], 79 | config['refractory_period'], 80 | config['log_eps'], 81 | config['use_log']) 82 | 83 | CT = [] 84 | for data_dir in tqdm(data_dirs): 85 | print(f'\n processing {data_dir}') 86 | path_to_rgb = os.path.join(data_dir, 'rgb') 87 | path_to_mono = os.path.join(data_dir, 'mono') 88 | path_timestamps = os.path.join(data_dir, 'timestamps.txt') 89 | 90 | prex = os.path.splitext(os.listdir(path_to_rgb)[0])[-1] 91 | rgb_imgs = sorted(glob(os.path.join(path_to_rgb, '*' + prex))) 92 | basename = os.path.basename(data_dir) 93 | h5_filename = basename + '.h5' 94 | ep = hdf5_packager_multiscale(os.path.join(path_to_h5, h5_filename)) 95 | 96 | # images writing 97 | for idx, img_path in enumerate(rgb_imgs): 98 | ori_img = cv2.imread(img_path, 1) 99 | if ori_img is None and idx == 0: 100 | print('Images is None! 
Donot write images!') 101 | break 102 | if resolution is None and idx == 0: 103 | resolution = ori_img.shape[:-1] 104 | ep.package_image('ori', ori_img, idx/fps, idx) 105 | 106 | # simulate events 107 | print('Events simulating and writing!') 108 | Cp = random.uniform(config['CT_range'][0], config['CT_range'][1]) 109 | Cn = random.gauss(config['mu'], config['sigma']) * Cp 110 | Cp = min(max(Cp, config['min_CT']), config['max_CT']) 111 | Cn = min(max(Cn, config['min_CT']), config['max_CT']) 112 | msg = f'{data_dir}:Cp={Cp}, Cn={Cn}' 113 | CT.append(msg) 114 | print(f'{msg}') 115 | esim.setParameters(Cp, Cn, config['refractory_period'], config['log_eps'], config['use_log']) 116 | events = esim.generateFromFolder(path_to_mono, path_timestamps) # x y t p 117 | ep.package_events('ori', events[:, 0], events[:, 1], events[:, 2], events[:, 3]) 118 | ep.add_event_indices() 119 | ep.add_data(resolution) 120 | resolution = None 121 | 122 | path_to_config = os.path.join(path_to_h5, 'config') 123 | os.makedirs(path_to_config) 124 | write_config(os.path.join(path_to_config, 'config.txt')) 125 | write_timestamps(CT, os.path.join(path_to_config, 'ct.txt')) 126 | print('all {} files are done!'.format(len(data_dirs))) 127 | 128 | -------------------------------------------------------------------------------- /generate_dataset/tools/add_hdf5_attribute.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import h5py 4 | import os 5 | import glob 6 | 7 | def endswith(path, extensions): 8 | for ext in extensions: 9 | if path.endswith(ext): 10 | return True 11 | return False 12 | 13 | def get_filepaths_from_path_or_file(path, extensions=[], datafile_extensions=[".txt", ".csv"]): 14 | files = [] 15 | path = path.rstrip("/") 16 | if os.path.isdir(path): 17 | for ext in extensions: 18 | files += sorted(glob.glob("{}/*{}".format(path, ext))) 19 | else: 20 | if endswith(path, extensions): 21 | files.append(path) 22 | elif endswith(path, datafile_extensions): 23 | with open(path, 'r') as f: 24 | #files.append(line) for line in f.readlines 25 | files = [line.strip() for line in f.readlines()] 26 | return files 27 | 28 | def add_attribute(h5_filepaths, group, attribute_name, attribute_value, dry_run=False): 29 | for h5_filepath in h5_filepaths: 30 | print("adding {}/{}[{}]={}".format(h5_filepath, group, attribute_name, attribute_value)) 31 | if dry_run: 32 | continue 33 | h5_file = h5py.File(h5_filepath, 'a') 34 | dset = h5_file["{}/".format(group)] 35 | dset.attrs[attribute_name] = attribute_value 36 | h5_file.close() 37 | 38 | if __name__ == "__main__": 39 | # arguments 40 | parser = argparse.ArgumentParser() 41 | parser._action_groups.pop() 42 | required = parser.add_argument_group('required arguments') 43 | optional = parser.add_argument_group('optional arguments') 44 | 45 | required.add_argument("--path", help="Can be either 1: path to individual hdf file, " + 46 | "2: txt file with list of hdf files, or " + 47 | "3: directory (all hdf files in directory will be processed).", required=True) 48 | required.add_argument("--attr_name", help="Name of new attribute", required=True) 49 | required.add_argument("--attr_val", help="Value of new attribute", required=True) 50 | optional.add_argument("--group", help="Group to add attribute to. 
Subgroups " + 51 | "are represented like paths, eg: /group1/subgroup2...", default="") 52 | optional.add_argument("--dry_run", default=0, type=int, 53 | help="If set to 1, will print changes without performing them") 54 | 55 | args = parser.parse_args() 56 | path = args.path 57 | extensions = [".hdf", ".h5"] 58 | files = get_filepaths_from_path_or_file(path, extensions=extensions) 59 | print(files) 60 | dry_run = False if args.dry_run <= 0 else True 61 | add_attribute(files, args.group, args.attr_name, args.attr_val, dry_run=dry_run) 62 | -------------------------------------------------------------------------------- /generate_dataset/tools/h5_to_memmap.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import h5py 3 | import numpy as np 4 | import os, shutil 5 | import json 6 | 7 | def find_safe_alternative(output_base_path): 8 | i = 0 9 | alternative_path = "{}_{:09d}".format(output_base_path, i) 10 | while(os.path.exists(alternative_path)): 11 | i += 1 12 | alternative_path = "{}_{:09d}".format(output_base_path, i) 13 | assert(i < 999999999) 14 | return alternative_path 15 | 16 | def save_additional_data_as_mmap(f, mmap_pth, data): 17 | data_path = os.path.join(mmap_pth, data['mmap_filename']) 18 | data_ts_path = os.path.join(mmap_pth, data['mmap_ts_filename']) 19 | data_event_idx_path = os.path.join(mmap_pth, data['mmap_event_idx_filename']) 20 | data_key = data['h5_key'] 21 | print('Writing {} to mmap {}, timestamps to {}'.format(data_key, data_path, data_ts_path)) 22 | h, w, c = 1, 1, 1 23 | if data_key in f.keys(): 24 | num_data = len(f[data_key].keys()) 25 | if num_data > 0: 26 | data_keys = f[data_key].keys() 27 | data_size = f[data_key][data_keys[0]].attrs['size'] 28 | h, w = data_size[0], data_size[1] 29 | c = 1 if len(data_size) <= 2 else data_size[2] 30 | else: 31 | num_data = 1 32 | mmp_imgs = np.memmap(data_path, dtype='uint8', mode='w+', shape=(num_data, h, w, c)) 33 | mmp_img_ts = np.memmap(data_ts_path, dtype='float64', mode='w+', shape=(num_data, 1)) 34 | mmp_event_indices = np.memmap(data_event_idx_path, dtype='uint16', mode='w+', shape=(num_data, 1)) 35 | 36 | if data_key in f.keys(): 37 | data = [] 38 | data_timestamps = [] 39 | data_event_index = [] 40 | for img_key in f[data_key].keys(): 41 | data.append(f[data_key][img_key][:]) 42 | data_timestamps.append(f[data_key][img_key].attrs['timestamp']) 43 | data_event_index.append(f[data_key][img_key].attrs['event_idx']) 44 | 45 | data_stack = np.expand_dims(np.stack(data), axis=3) 46 | data_ts_stack = np.expand_dims(np.stack(data_timestamps), axis=1) 47 | data_event_indices_stack = np.expand_dims(np.stack(data_event_index), axis=1) 48 | mmp_imgs[...] = data_stack 49 | mmp_img_ts[...] = data_ts_stack 50 | mmp_event_indices[...] 
= data_event_indices_stack 51 | 52 | def write_metadata(f, metadata_path): 53 | metadata = {} 54 | for attr in f.attrs: 55 | val = f.attrs[attr] 56 | if isinstance(val, np.ndarray): 57 | val = val.tolist() 58 | metadata[attr] = val 59 | with open(metadata_path, 'w') as js: 60 | json.dump(metadata, js) 61 | 62 | def h5_to_memmap(h5_file_path, output_base_path, overwrite=True): 63 | output_pth = output_base_path 64 | if os.path.exists(output_pth): 65 | if overwrite: 66 | print("Overwriting {}".format(output_pth)) 67 | shutil.rmtree(output_pth) 68 | else: 69 | output_pth = find_safe_alternative(output_base_path) 70 | print('Data will be extracted to: {}'.format(output_pth)) 71 | os.makedirs(output_pth) 72 | mmap_pth = os.path.join(output_pth, "memmap") 73 | os.makedirs(mmap_pth) 74 | 75 | ts_path = os.path.join(mmap_pth, 't.npy') 76 | xy_path = os.path.join(mmap_pth, 'xy.npy') 77 | ps_path = os.path.join(mmap_pth, 'p.npy') 78 | metadata_path = os.path.join(mmap_pth, 'metadata.json') 79 | 80 | additional_data = { 81 | "images": 82 | { 83 | 'h5_key' : 'images', 84 | 'mmap_filename' : 'images.npy', 85 | 'mmap_ts_filename' : 'timestamps.npy', 86 | 'mmap_event_idx_filename' : 'image_event_indices.npy', 87 | 'dims' : 3 88 | }, 89 | "flow": 90 | { 91 | 'h5_key' : 'flow', 92 | 'mmap_filename' : 'flow.npy', 93 | 'mmap_ts_filename' : 'flow_timestamps.npy', 94 | 'mmap_event_idx_filename' : 'flow_event_indices.npy', 95 | 'dims' : 3 96 | } 97 | } 98 | 99 | with h5py.File(h5_file_path, 'r') as f: 100 | num_events = f.attrs['num_events'] 101 | num_images = f.attrs['num_imgs'] 102 | num_flow = f.attrs['num_flow'] 103 | 104 | mmp_ts = np.memmap(ts_path, dtype='float64', mode='w+', shape=(num_events, 1)) 105 | mmp_xy = np.memmap(xy_path, dtype='int16', mode='w+', shape=(num_events, 2)) 106 | mmp_ps = np.memmap(ps_path, dtype='uint8', mode='w+', shape=(num_events, 1)) 107 | 108 | mmp_ts[:, 0] = f['events/ts'][:] 109 | mmp_xy[:, :] = np.stack((f['events/xs'][:], f['events/ys'][:])).transpose() 110 | mmp_ps[:, 0] = f['events/ps'][:] 111 | 112 | for data in additional_data: 113 | save_additional_data_as_mmap(f, mmap_pth, additional_data[data]) 114 | write_metadata(f, metadata_path) 115 | 116 | 117 | if __name__ == "__main__": 118 | """ 119 | Tool to convert this projects style hdf5 files to the memmap format used in some RPG projects 120 | """ 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument("path", help="HDF5 file to convert") 123 | parser.add_argument("--output_dir", default=None, help="Path to extract (same as bag if left empty)") 124 | parser.add_argument('--not_overwrite', action='store_false', help='If set, will not overwrite\ 125 | existing memmap, but will place safe alternative') 126 | 127 | args = parser.parse_args() 128 | 129 | bagname = os.path.splitext(os.path.basename(args.path))[0] 130 | if args.output_dir is None: 131 | output_path = os.path.join(os.path.dirname(os.path.abspath(args.path)), bagname) 132 | else: 133 | output_path = os.path.join(args.output_dir, bagname) 134 | h5_to_memmap(args.path, output_path, overwrite=args.not_overwrite) 135 | -------------------------------------------------------------------------------- /generate_dataset/tools/read_events.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import os 4 | 5 | def compute_indices(event_stamps, frame_stamps): 6 | indices_first = np.searchsorted(event_stamps[:,0], frame_stamps[1:]) 7 | indices_last = np.searchsorted(event_stamps[:,0], 
frame_stamps[:-1]) 8 | index = np.stack([indices_first, indices_last], -1) 9 | return index 10 | 11 | def read_memmap_events(memmap_path, skip_frames=1, return_events=False, images_file = 'images.npy', 12 | images_ts_file = 'timestamps.npy', optic_flow_file = 'optic_flow.npy', 13 | optic_flow_ts_file = 'optic_flow_timestamps.npy', events_xy_file = 'xy.npy', 14 | events_p_file = 'p.npy', events_t_file = 't.npy'): 15 | assert os.path.isdir(memmap_path), '%s is not a valid memmap_pathectory' % memmap_path 16 | 17 | data = {} 18 | has_flow = False 19 | for subroot, _, fnames in sorted(os.walk(memmap_path)): 20 | for fname in sorted(fnames): 21 | path = os.path.join(subroot, fname) 22 | if fname.endswith(".npy"): 23 | if fname=="index.npy": # index mapping image index to event idx 24 | indices = np.load(path) # N x 2 25 | assert len(indices.shape) == 2 and indices.shape[1] == 2 26 | indices = indices.astype("int64") # ignore event indices which are 0 (before first image) 27 | data["index"] = indices.T 28 | elif fname==images_ts_file: 29 | data["frame_stamps"] = np.load(path)[::skip_frames,...] 30 | elif fname==images_file: 31 | data["images"] = np.load(path, mmap_mode="r")[::skip_frames,...] 32 | elif fname==optic_flow_file: 33 | data["optic_flow"] = np.load(path, mmap_mode="r")[::skip_frames,...] 34 | has_flow = True 35 | elif fname==optic_flow_ts_file: 36 | data["optic_flow_stamps"] = np.load(path)[::skip_frames,...] 37 | 38 | handle = np.load(path, mmap_mode="r") 39 | if fname==events_t_file: # timestamps 40 | data["t"] = handle[:].squeeze() if return_events else handle 41 | data["t0"] = handle[0] 42 | elif fname==events_xy_file: # coordinates 43 | data["xy"] = handle[:].squeeze() if return_events else handle 44 | elif fname==events_p_file: # polarity 45 | data["p"] = handle[:].squeeze() if return_events else handle 46 | 47 | if len(data) > 0: 48 | data['path'] = subroot 49 | if "t" not in data: 50 | raise Exception(f"Ignoring memmap_pathectory {subroot} since no events") 51 | if not (len(data['p']) == len(data['xy']) and len(data['p']) == len(data['t'])): 52 | raise Exception(f"Events from {subroot} invalid") 53 | data["num_events"] = len(data['p']) 54 | 55 | if "index" not in data and "frame_stamps" in data: 56 | data["index"] = compute_indices(data["t"], data['frame_stamps']) 57 | return data 58 | 59 | def read_h5_events(hdf_path): 60 | f = h5py.File(hdf_path, 'r') 61 | if 'events/x' in f: 62 | #legacy 63 | events = np.stack((f['events/x'][:], f['events/y'][:], f['events/ts'][:], np.where(f['events/p'][:], 1, -1)), axis=1) 64 | else: 65 | events = np.stack((f['events/xs'][:], f['events/ys'][:], f['events/ts'][:], np.where(f['events/ps'][:], 1, -1)), axis=1) 66 | return events 67 | 68 | def read_h5_event_components(hdf_path): 69 | f = h5py.File(hdf_path, 'r') 70 | if 'events/x' in f: 71 | #legacy 72 | return (f['events/x'][:], f['events/y'][:], f['events/ts'][:], np.where(f['events/p'][:], 1, -1)) 73 | else: 74 | return (f['events/xs'][:], f['events/ys'][:], f['events/ts'][:], np.where(f['events/ps'][:], 1, -1)) 75 | -------------------------------------------------------------------------------- /generate_dataset/tools/txt_to_h5.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import argparse 3 | import os 4 | import h5py 5 | import pandas as pd 6 | import numpy as np 7 | from event_packagers import * 8 | 9 | 10 | def get_sensor_size(txt_path): 11 | try: 12 | header = pd.read_csv(txt_path, delim_whitespace=True, header=None, 
names=['width', 'height'], 13 | dtype={'width': np.int, 'height': np.int}, 14 | nrows=1) 15 | width, height = header.values[0] 16 | sensor_size = [height, width] 17 | except: 18 | sensor_size = None 19 | print('Warning: could not read sensor size from first line of {}'.format(txt_path)) 20 | return sensor_size 21 | 22 | 23 | def extract_txt(txt_path, output_path, zero_timestamps=False, 24 | packager=hdf5_packager): 25 | ep = packager(output_path) 26 | first_ts = -1 27 | t0 = -1 28 | if not os.path.exists(txt_path): 29 | print("{} does not exist!".format(txt_path)) 30 | return 31 | 32 | # compute sensor size 33 | sensor_size = get_sensor_size(txt_path) 34 | # Extract events to h5 35 | ep.set_data_available(num_images=0, num_flow=0) 36 | total_num_pos, total_num_neg, last_ts = 0, 0, 0 37 | 38 | chunksize = 100000 39 | iterator = pd.read_csv(txt_path, delim_whitespace=True, header=None, 40 | names=['t', 'x', 'y', 'pol'], 41 | dtype={'t': np.float64, 'x': np.int16, 'y': np.int16, 'pol': np.int16}, 42 | engine='c', 43 | skiprows=1, chunksize=chunksize, nrows=None, memory_map=True) 44 | 45 | for i, event_window in enumerate(iterator): 46 | events = event_window.values 47 | ts = events[:, 0].astype(np.float64) 48 | xs = events[:, 1].astype(np.int16) 49 | ys = events[:, 2].astype(np.int16) 50 | ps = events[:, 3] 51 | ps[ps < 0] = 0 # should be [0 or 1] 52 | ps = ps.astype(bool) 53 | 54 | if first_ts == -1: 55 | first_ts = ts[0] 56 | 57 | if zero_timestamps: 58 | ts -= first_ts 59 | last_ts = ts[-1] 60 | if sensor_size is None or sensor_size[0] < max(ys) or sensor_size[1] < max(xs): 61 | sensor_size = [max(xs), max(ys)] 62 | print("Sensor size inferred from events as {}".format(sensor_size)) 63 | 64 | sum_ps = sum(ps) 65 | total_num_pos += sum_ps 66 | total_num_neg += len(ps) - sum_ps 67 | ep.package_events(xs, ys, ts, ps) 68 | if i % 10 == 9: 69 | print('Events written: {} M'.format((total_num_pos + total_num_neg) / 1e6)) 70 | print('Events written: {} M'.format((total_num_pos + total_num_neg) / 1e6)) 71 | print("Detect sensor size {}".format(sensor_size)) 72 | t0 = 0 if zero_timestamps else first_ts 73 | ep.add_metadata(total_num_pos, total_num_neg, last_ts-t0, t0, last_ts, num_imgs=0, num_flow=0, sensor_size=sensor_size) 74 | 75 | 76 | def extract_txts(txt_paths, output_dir, zero_timestamps=False): 77 | for path in txt_paths: 78 | filename = os.path.splitext(os.path.basename(path))[0] 79 | out_path = os.path.join(output_dir, "{}.h5".format(filename)) 80 | print("Extracting {} to {}".format(path, out_path)) 81 | extract_txt(path, out_path, zero_timestamps=zero_timestamps) 82 | 83 | 84 | if __name__ == "__main__": 85 | """ 86 | Tool for converting txt events to an efficient HDF5 format that can be speedily 87 | accessed by python code. 
88 | """ 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument("path", help="txt file to extract or directory containing txt files") 91 | parser.add_argument("--output_dir", default="/tmp/extracted_data", help="Folder where to extract the data") 92 | parser.add_argument('--zero_timestamps', action='store_true', help='If true, timestamps will be offset to start at 0') 93 | args = parser.parse_args() 94 | 95 | print('Data will be extracted in folder: {}'.format(args.output_dir)) 96 | if not os.path.exists(args.output_dir): 97 | os.makedirs(args.output_dir) 98 | if os.path.isdir(args.path): 99 | txt_paths = sorted(list(glob.glob(os.path.join(args.path, "*.txt"))) 100 | + list(glob.glob(os.path.join(args.path, "*.zip")))) 101 | else: 102 | txt_paths = [args.path] 103 | extract_txts(txt_paths, args.output_dir, zero_timestamps=args.zero_timestamps) 104 | -------------------------------------------------------------------------------- /generate_dataset/upsampling/README.md: -------------------------------------------------------------------------------- 1 | # Adaptive Upsampling 2 | 3 | ## Generate Upsampled Video or Image Sequences 4 | You can use our example directory to experiment 5 | ```bash 6 | device=cpu 7 | # device=cuda:0 8 | python upsample.py --input_dir=example/original --output_dir=example/upsampled --device=$device 9 | 10 | ``` 11 | The **expected input structure** is as follows: 12 | ``` 13 | input_dir 14 | ├── seq0 15 | │   ├── fps.txt 16 | │   └── imgs 17 | │   ├── 00000001.png 18 | │   ├── 00000002.png 19 | │   ├── 00000003.png 20 | │   └── ....png 21 | ├── seq1 22 | │   └── video.mp4 23 | └── dirname_does_not_matter 24 | ├── fps.txt 25 | └── filename_does_not_matter.mov 26 | 27 | ``` 28 | - The number of sequences (subfolders of the input directory) is unlimited. 29 | - The `fps.txt` file 30 | - must specify the frames per second in the first line. The rest of the file should be empty (see example directory). 31 | - is required for sequences (such as seq0) with image files. 32 | - is **optional** for sequences with a video file. In case of a missing `fps.txt` file, the frames per second will be inferred from the metadata of the video file. 33 | 34 | The **resulting output structure** is as follows: 35 | ``` 36 | output_dir 37 | ├── seq0 38 | │   ├── imgs 39 | │   │ ├── 00000001.png 40 | │   │ ├── 00000002.png 41 | │   │ ├── 00000003.png 42 | │   │ └── ....png 43 | │ └── timestamps.txt 44 | ├── seq1 45 | │   ├── imgs 46 | │   │ ├── 00000001.png 47 | │   │ ├── 00000002.png 48 | │   │ ├── 00000003.png 49 | │   │ └── ....png 50 | │ └── timestamps.txt 51 | └── dirname_does_not_matter 52 |    ├── imgs 53 |    │ ├── 00000001.png 54 |    │ ├── 00000002.png 55 |    │ ├── 00000003.png 56 |    │ └── ....png 57 | └── timestamps.txt 58 | ``` 59 | The resulting image directories can later be used to generate events. The `timestamps.txt` file contains the timestamp of each image in seconds. 60 | 61 | 62 | ## Remarks 63 | - Use a GPU device whenever possible to speed up the upsampling procedure. 64 | - The upsampling will increase the storage requirements significantly. Try a small sample first to get an impression. 65 | - Downsample (height and width) your images and video to save storage space and processing time. 66 | - Why store the upsampling result in images: 67 | - Images support random access from a dataloader. A video file, for example, can typically only be accessed sequentally when we try to avoid loading the whole video into RAM. 
68 | - The same sequence can be accessed by multiple processes (e.g. PyTorch num\_workers > 1). 69 | - A well-established C++ interface exists for loading images. This is useful for generating events on the fly (needed for contrast threshold randomization) in C++ code without first loading the data in Python. 70 | If there is a need to store the resulting sequences in a different format, raise an issue (feature request) on this GitHub repository. 71 | - Be aware that upsampling videos might fail due to a [bug in scikit-video](https://github.com/scikit-video/scikit-video/issues/60). 72 | 73 | ### Generating Video Files from Images 74 | If you want to convert an ordered sequence of images (here png files) into video format, you can use the following command (you may have to deactivate the current conda environment): 75 | ```bash 76 | frame_rate=25 77 | img_dirpath="example/original/seq0/imgs" 78 | img_suffix=".png" 79 | output_file="video.mp4" 80 | ffmpeg -framerate $frame_rate -pattern_type glob -i "$img_dirpath/*$img_suffix" -c:v libx265 -x265-params lossless=1 $output_file 81 | ``` 82 | 83 | ### Generating Images from a Video File 84 | If you want to convert a video file to a sequence of images: 85 | ```bash 86 | input_file="video.mp4" 87 | output_dirpath="your_path_to_specify" 88 | ffmpeg -i $input_file "$output_dirpath/%08d.png" 89 | ``` 90 | -------------------------------------------------------------------------------- /generate_dataset/upsampling/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /generate_dataset/upsampling/example/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !/original/ 4 | -------------------------------------------------------------------------------- /generate_dataset/upsampling/upsample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | # Must be set before importing torch. 4 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 5 | 6 | from utils import Upsampler 7 | 8 | 9 | def get_flags(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--input_dir", required=True, help='Path to input directory. See README.md for expected structure of the directory.') 12 | parser.add_argument("--output_dir", required=True, help='Path to non-existing output directory. 
This script will generate the directory.') 13 | parser.add_argument("--device", type=str, default="cuda:0", help='Device to be used (cpu, cuda:X)') 14 | args = parser.parse_args() 15 | return args 16 | 17 | 18 | def main(): 19 | flags = get_flags() 20 | 21 | upsampler = Upsampler( 22 | input_dir=flags.input_dir, 23 | output_dir=flags.output_dir, 24 | device=flags.device) 25 | upsampler.upsample() 26 | 27 | 28 | if __name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- /generate_dataset/upsampling/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import Sequence 2 | from .upsampler import Upsampler 3 | from .utils import get_sequence_or_none 4 | -------------------------------------------------------------------------------- /generate_dataset/upsampling/utils/const.py: -------------------------------------------------------------------------------- 1 | mean = [0.429, 0.431, 0.397] 2 | std = [1, 1, 1] 3 | fps_filename = 'fps.txt' 4 | imgs_dirname = 'imgs' 5 | # TODO(magehrig): Use https://github.com/ahupp/python-magic instead. 6 | video_formats = {'.webm', '.mp4', '.m4p', '.m4v', '.avi', '.avchd', '.ogg', '.mov', '.ogv', '.vob', '.f4v', '.mkv', '.svi', '.m2v', '.mpg', '.mp2', '.mpeg', '.mpe', '.mpv', '.amv', '.wmv', '.flv', '.mts', '.m2ts', '.ts', '.qt', '.3gp', '.3g2', '.f4p', '.f4a', '.f4b'} 7 | img_formats = {'.png', '.jpg', '.jpeg', '.bmp', '.pbm', '.pgm', '.ppm', '.pnm', '.webp', '.tiff', '.tif'} 8 | -------------------------------------------------------------------------------- /generate_dataset/upsampling/utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | from fractions import Fraction 6 | from PIL import Image 7 | import skvideo.io 8 | import torch 9 | import torchvision.transforms as transforms 10 | 11 | from .const import mean, std, img_formats 12 | 13 | 14 | class Sequence: 15 | def __init__(self): 16 | normalize = transforms.Normalize(mean=mean, std=std) 17 | self.transform = transforms.Compose([transforms.ToTensor(), normalize]) 18 | 19 | def __iter__(self): 20 | return self 21 | 22 | def __next__(self): 23 | raise NotImplementedError 24 | 25 | def __len__(self): 26 | raise NotImplementedError 27 | 28 | 29 | class ImageSequence(Sequence): 30 | def __init__(self, imgs_dirpath: str, fps: float): 31 | super().__init__() 32 | self.fps = fps 33 | 34 | assert os.path.isdir(imgs_dirpath) 35 | self.imgs_dirpath = imgs_dirpath 36 | 37 | self.file_names = [f for f in os.listdir(imgs_dirpath) if self._is_img_file(f)] 38 | assert self.file_names 39 | self.file_names.sort() 40 | 41 | @classmethod 42 | def _is_img_file(cls, path: str): 43 | return Path(path).suffix.lower() in img_formats 44 | 45 | def __next__(self): 46 | for idx in range(0, len(self.file_names) - 1): 47 | file_paths = self._get_path_from_name([self.file_names[idx], self.file_names[idx + 1]]) 48 | imgs = list() 49 | for file_path in file_paths: 50 | img = self._pil_loader(file_path) 51 | img = self.transform(img) 52 | imgs.append(img) 53 | times_sec = [idx/self.fps, (idx + 1)/self.fps] 54 | yield imgs, times_sec 55 | 56 | def __len__(self): 57 | return len(self.file_names) - 1 58 | 59 | @staticmethod 60 | def _pil_loader(path): 61 | with open(path, 'rb') as f: 62 | img = Image.open(f) 63 | img = img.convert('RGB') 64 | 65 | w_orig, h_orig = img.size 66 | w, h = 
w_orig//32*32, h_orig//32*32 67 | 68 | left = (w_orig - w)//2 69 | upper = (h_orig - h)//2 70 | right = left + w 71 | lower = upper + h 72 | img = img.crop((left, upper, right, lower)) 73 | return img 74 | 75 | def _get_path_from_name(self, file_names: Union[list, str]) -> Union[list, str]: 76 | if isinstance(file_names, list): 77 | return [os.path.join(self.imgs_dirpath, f) for f in file_names] 78 | return os.path.join(self.imgs_dirpath, file_names) 79 | 80 | 81 | class VideoSequence(Sequence): 82 | def __init__(self, video_filepath: str, fps: float=None): 83 | super().__init__() 84 | metadata = skvideo.io.ffprobe(video_filepath) 85 | self.fps = fps 86 | if self.fps is None: 87 | self.fps = float(Fraction(metadata['video']['@avg_frame_rate'])) 88 | assert self.fps > 0, 'Could not retrieve fps from video metadata. fps: {}'.format(self.fps) 89 | print('Using video metadata: Got fps of {} frames/sec'.format(self.fps)) 90 | 91 | # Length is number of frames - 1 (because we return pairs). 92 | self.len = int(metadata['video']['@nb_frames']) - 1 93 | self.videogen = skvideo.io.vreader(video_filepath) 94 | self.last_frame = None 95 | 96 | def __next__(self): 97 | for idx, frame in enumerate(self.videogen): 98 | h_orig, w_orig, _ = frame.shape 99 | w, h = w_orig//32*32, h_orig//32*32 100 | 101 | left = (w_orig - w)//2 102 | upper = (h_orig - h)//2 103 | right = left + w 104 | lower = upper + h 105 | frame = frame[upper:lower, left:right] 106 | assert frame.shape[:2] == (h, w) 107 | frame = self.transform(frame) 108 | 109 | if self.last_frame is None: 110 | self.last_frame = frame 111 | continue 112 | last_frame_copy = self.last_frame.detach().clone() 113 | self.last_frame = frame 114 | imgs = [last_frame_copy, frame] 115 | times_sec = [(idx - 1)/self.fps, idx/self.fps] 116 | yield imgs, times_sec 117 | 118 | def __len__(self): 119 | return self.len 120 | -------------------------------------------------------------------------------- /generate_dataset/upsampling/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | from .const import fps_filename, imgs_dirname, video_formats 6 | from .dataset import Sequence, ImageSequence, VideoSequence 7 | 8 | def is_video_file(filepath: str) -> bool: 9 | return Path(filepath).suffix.lower() in video_formats 10 | 11 | def get_fps_file(dirpath: str) -> Union[None, str]: 12 | fps_file = os.path.join(dirpath, fps_filename) 13 | if os.path.isfile(fps_file): 14 | return fps_file 15 | return None 16 | 17 | def get_imgs_directory(dirpath: str) -> Union[None, str]: 18 | imgs_dir = os.path.join(dirpath, imgs_dirname) 19 | if os.path.isdir(imgs_dir): 20 | return imgs_dir 21 | return None 22 | 23 | def get_video_file(dirpath: str) -> Union[None, str]: 24 | filenames = [f for f in os.listdir(dirpath) if is_video_file(f)] 25 | if len(filenames) == 0: 26 | return None 27 | assert len(filenames) == 1 28 | filepath = os.path.join(dirpath, filenames[0]) 29 | return filepath 30 | 31 | def fps_from_file(fps_file) -> float: 32 | assert os.path.isfile(fps_file) 33 | with open(fps_file, 'r') as f: 34 | fps = float(f.readline().strip()) 35 | assert fps > 0, 'Expected fps to be larger than 0. 
Instead got fps={}'.format(fps) 36 | return fps 37 | 38 | def get_sequence_or_none(dirpath: str) -> Union[None, Sequence]: 39 | fps_file = get_fps_file(dirpath) 40 | if fps_file: 41 | # Must be a sequence (either ImageSequence or VideoSequence) 42 | fps = fps_from_file(fps_file) 43 | imgs_dir = get_imgs_directory(dirpath) 44 | if imgs_dir: 45 | return ImageSequence(imgs_dir, fps) 46 | video_file = get_video_file(dirpath) 47 | assert video_file is not None 48 | return VideoSequence(video_file, fps) 49 | # Can be VideoSequence if there is a video file. But have to use fps from meta data. 50 | video_file = get_video_file(dirpath) 51 | if video_file is not None: 52 | return VideoSequence(video_file) 53 | return None 54 | 55 | 56 | -------------------------------------------------------------------------------- /logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import * 2 | from .visualization import * -------------------------------------------------------------------------------- /logger/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | from pathlib import Path 4 | from collections import OrderedDict 5 | import json 6 | 7 | 8 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO): 9 | """ 10 | Setup logging configuration 11 | """ 12 | log_config = Path(log_config) 13 | if log_config.is_file(): 14 | config = read_json(log_config) 15 | # modify logging paths based on run config 16 | for _, handler in config['handlers'].items(): 17 | if 'filename' in handler: 18 | handler['filename'] = str(save_dir / handler['filename']) 19 | 20 | logging.config.dictConfig(config) 21 | else: 22 | print("Warning: logging configuration file is not found in {}.".format(log_config)) 23 | logging.basicConfig(level=default_level) 24 | 25 | 26 | def read_json(fname): 27 | fname = Path(fname) 28 | with fname.open('rt') as handle: 29 | return json.load(handle, object_hook=OrderedDict) -------------------------------------------------------------------------------- /logger/logger_config.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "version": 1, 4 | "disable_existing_loggers": false, 5 | "formatters": { 6 | "simple": {"format": "%(message)s"}, 7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"} 8 | }, 9 | "handlers": { 10 | "console": { 11 | "class": "logging.StreamHandler", 12 | "level": "DEBUG", 13 | "formatter": "simple", 14 | "stream": "ext://sys.stdout" 15 | }, 16 | "info_file_handler": { 17 | "class": "logging.handlers.RotatingFileHandler", 18 | "level": "INFO", 19 | "formatter": "datetime", 20 | "filename": "info.txt", 21 | "maxBytes": 10485760, 22 | "backupCount": 20, "encoding": "utf8" 23 | } 24 | }, 25 | "root": { 26 | "level": "INFO", 27 | "handlers": [ 28 | "console", 29 | "info_file_handler" 30 | ] 31 | } 32 | } -------------------------------------------------------------------------------- /logger/visualization.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from datetime import datetime 3 | 4 | 5 | class TensorboardWriter(): 6 | def __init__(self, log_dir, logger, enabled): 7 | self.writer = None 8 | self.selected_module = "" 9 | 10 | if enabled: 11 | log_dir = str(log_dir) 12 | 13 | # Retrieve vizualization writer. 
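# Prefer torch.utils.tensorboard and fall back to tensorboardX; if neither import succeeds, self.writer stays None and the warning below is logged.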
14 | succeeded = False 15 | for module in ["torch.utils.tensorboard", "tensorboardX"]: 16 | try: 17 | self.writer = importlib.import_module(module).SummaryWriter(log_dir) 18 | succeeded = True 19 | break 20 | except ImportError: 21 | succeeded = False 22 | self.selected_module = module 23 | 24 | if not succeeded: 25 | message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \ 26 | "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \ 27 | "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file." 28 | logger.warning(message) 29 | 30 | self.step = 0 31 | self.mode = '' 32 | 33 | self.tb_writer_ftns = { 34 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', 35 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding', 'add_video' 36 | } 37 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding', 'add_video'} 38 | self.timer = datetime.now() 39 | 40 | def set_step(self, step, mode='train'): 41 | self.mode = mode 42 | self.step = step 43 | if step == 0: 44 | self.timer = datetime.now() 45 | else: 46 | duration = datetime.now() - self.timer 47 | self.add_scalar('steps_per_sec', 1 / duration.total_seconds()) 48 | self.timer = datetime.now() 49 | 50 | def __getattr__(self, name): 51 | """ 52 | If visualization is configured to use: 53 | return add_data() methods of tensorboard with additional information (step, tag) added. 54 | Otherwise: 55 | return a blank function handle that does nothing 56 | """ 57 | if name in self.tb_writer_ftns: 58 | add_data = getattr(self.writer, name, None) 59 | 60 | def wrapper(tag, data, *args, **kwargs): 61 | if add_data is not None: 62 | # add mode(train/valid) tag 63 | if name not in self.tag_mode_exceptions: 64 | tag = '{}/{}'.format(tag, self.mode) 65 | add_data(tag, data, self.step, *args, **kwargs) 66 | return wrapper 67 | else: 68 | # default action for returning methods defined in this class, set_step() for instance. 
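# Note: __getattr__ is only invoked when normal attribute lookup has already failed, so attributes that do exist on this class never reach this branch; `object` has no `__getattr__`, so the lookup below raises AttributeError, which is re-raised with a clearer message.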
69 | try: 70 | attr = object.__getattr__(name) 71 | except AttributeError: 72 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name)) 73 | return attr 74 | -------------------------------------------------------------------------------- /loss/PerceptualSimilarity/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | # from skimage.measure import compare_ssim 8 | from skimage.metrics import structural_similarity 9 | import torch 10 | from torch.autograd import Variable 11 | 12 | # import sys, os 13 | # sys.path.append(os.getcwd()) 14 | from loss.PerceptualSimilarity.models import dist_model 15 | # import .dist_model 16 | # import dist_model 17 | 18 | 19 | class PerceptualLoss(torch.nn.Module): 20 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric) 21 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 22 | super(PerceptualLoss, self).__init__() 23 | print('Setting up Perceptual loss...') 24 | self.use_gpu = use_gpu 25 | self.spatial = spatial 26 | self.gpu_ids = gpu_ids 27 | self.model = dist_model.DistModel() 28 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids) 29 | print('...[%s] initialized'%self.model.name()) 30 | print('...Done') 31 | 32 | def forward(self, pred, target, normalize=False): 33 | """ 34 | Pred and target are Variables. 35 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 36 | If normalize is False, assumes the images are already between [-1,+1] 37 | 38 | Inputs pred and target are Nx3xHxW 39 | Output pytorch Variable N long 40 | """ 41 | 42 | if normalize: 43 | target = 2 * target - 1 44 | pred = 2 * pred - 1 45 | 46 | return self.model.forward(target, pred) 47 | 48 | def normalize_tensor(in_feat,eps=1e-10): 49 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 50 | return in_feat/(norm_factor+eps) 51 | 52 | def l2(p0, p1, range=255.): 53 | return .5*np.mean((p0 / range - p1 / range)**2) 54 | 55 | def psnr(p0, p1, peak=255.): 56 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 57 | 58 | def dssim(p0, p1, range=255.): 59 | return (1 - structural_similarity(p0, p1, data_range=range, multichannel=True)) / 2. 60 | 61 | def rgb2lab(in_img,mean_cent=False): 62 | from skimage import color 63 | img_lab = color.rgb2lab(in_img) 64 | if(mean_cent): 65 | img_lab[:,:,0] = img_lab[:,:,0]-50 66 | return img_lab 67 | 68 | def tensor2np(tensor_obj): 69 | # change dimension of a tensor object into a numpy array 70 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 71 | 72 | def np2tensor(np_obj): 73 | # change dimenion of np array into tensor array 74 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 75 | 76 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 77 | # image tensor to lab tensor 78 | from skimage import color 79 | 80 | img = tensor2im(image_tensor) 81 | img_lab = color.rgb2lab(img) 82 | if(mc_only): 83 | img_lab[:,:,0] = img_lab[:,:,0]-50 84 | if(to_norm and not mc_only): 85 | img_lab[:,:,0] = img_lab[:,:,0]-50 86 | img_lab = img_lab/100. 
87 | 88 | return np2tensor(img_lab) 89 | 90 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 91 | from skimage import color 92 | import warnings 93 | warnings.filterwarnings("ignore") 94 | 95 | lab = tensor2np(lab_tensor)*100. 96 | lab[:,:,0] = lab[:,:,0]+50 97 | 98 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 99 | if(return_inbnd): 100 | # convert back to lab, see if we match 101 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 102 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 103 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 104 | return (im2tensor(rgb_back),mask) 105 | else: 106 | return im2tensor(rgb_back) 107 | 108 | def rgb2lab(input): 109 | from skimage import color 110 | return color.rgb2lab(input / 255.) 111 | 112 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 113 | image_numpy = image_tensor[0].cpu().float().numpy() 114 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 115 | return image_numpy.astype(imtype) 116 | 117 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 118 | return torch.Tensor((image / factor - cent) 119 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 120 | 121 | def tensor2vec(vector_tensor): 122 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 123 | 124 | def voc_ap(rec, prec, use_07_metric=False): 125 | """ ap = voc_ap(rec, prec, [use_07_metric]) 126 | Compute VOC AP given precision and recall. 127 | If use_07_metric is true, uses the 128 | VOC 07 11 point method (default:False). 129 | """ 130 | if use_07_metric: 131 | # 11 point metric 132 | ap = 0. 133 | for t in np.arange(0., 1.1, 0.1): 134 | if np.sum(rec >= t) == 0: 135 | p = 0 136 | else: 137 | p = np.max(prec[rec >= t]) 138 | ap = ap + p / 11. 139 | else: 140 | # correct AP calculation 141 | # first append sentinel values at the end 142 | mrec = np.concatenate(([0.], rec, [1.])) 143 | mpre = np.concatenate(([0.], prec, [0.])) 144 | 145 | # compute the precision envelope 146 | for i in range(mpre.size - 1, 0, -1): 147 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 148 | 149 | # to calculate area under PR curve, look for points 150 | # where X axis (recall) changes value 151 | i = np.where(mrec[1:] != mrec[:-1])[0] 152 | 153 | # and sum (\Delta recall) * prec 154 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 155 | return ap 156 | 157 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 158 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 159 | image_numpy = image_tensor[0].cpu().float().numpy() 160 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 161 | return image_numpy.astype(imtype) 162 | 163 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 164 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 165 | return torch.Tensor((image / factor - cent) 166 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 167 | -------------------------------------------------------------------------------- /loss/PerceptualSimilarity/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | from pdb import set_trace as st 5 | from IPython import embed 6 | 7 | class BaseModel(): 8 | def __init__(self): 9 | pass; 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def initialize(self, use_gpu=True, gpu_ids=[0]): 15 | self.use_gpu = use_gpu 16 | self.gpu_ids = gpu_ids 17 | 18 | def forward(self): 19 
| pass 20 | 21 | def get_image_paths(self): 22 | pass 23 | 24 | def optimize_parameters(self): 25 | pass 26 | 27 | def get_current_visuals(self): 28 | return self.input 29 | 30 | def get_current_errors(self): 31 | return {} 32 | 33 | def save(self, label): 34 | pass 35 | 36 | # helper saving function that can be used by subclasses 37 | def save_network(self, network, path, network_label, epoch_label): 38 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 39 | save_path = os.path.join(path, save_filename) 40 | torch.save(network.state_dict(), save_path) 41 | 42 | # helper loading function that can be used by subclasses 43 | def load_network(self, network, network_label, epoch_label): 44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 45 | save_path = os.path.join(self.save_dir, save_filename) 46 | print('Loading network from %s'%save_path) 47 | network.load_state_dict(torch.load(save_path)) 48 | 49 | def update_learning_rate(): 50 | pass 51 | 52 | def get_image_paths(self): 53 | return self.image_paths 54 | 55 | def save_done(self, flag=False): 56 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 57 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 58 | 59 | -------------------------------------------------------------------------------- /loss/PerceptualSimilarity/models/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | from IPython import embed 5 | 6 | class squeezenet(torch.nn.Module): 7 | def __init__(self, requires_grad=False, pretrained=True): 8 | super(squeezenet, self).__init__() 9 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 10 | self.slice1 = torch.nn.Sequential() 11 | self.slice2 = torch.nn.Sequential() 12 | self.slice3 = torch.nn.Sequential() 13 | self.slice4 = torch.nn.Sequential() 14 | self.slice5 = torch.nn.Sequential() 15 | self.slice6 = torch.nn.Sequential() 16 | self.slice7 = torch.nn.Sequential() 17 | self.N_slices = 7 18 | for x in range(2): 19 | self.slice1.add_module(str(x), pretrained_features[x]) 20 | for x in range(2,5): 21 | self.slice2.add_module(str(x), pretrained_features[x]) 22 | for x in range(5, 8): 23 | self.slice3.add_module(str(x), pretrained_features[x]) 24 | for x in range(8, 10): 25 | self.slice4.add_module(str(x), pretrained_features[x]) 26 | for x in range(10, 11): 27 | self.slice5.add_module(str(x), pretrained_features[x]) 28 | for x in range(11, 12): 29 | self.slice6.add_module(str(x), pretrained_features[x]) 30 | for x in range(12, 13): 31 | self.slice7.add_module(str(x), pretrained_features[x]) 32 | if not requires_grad: 33 | for param in self.parameters(): 34 | param.requires_grad = False 35 | 36 | def forward(self, X): 37 | h = self.slice1(X) 38 | h_relu1 = h 39 | h = self.slice2(h) 40 | h_relu2 = h 41 | h = self.slice3(h) 42 | h_relu3 = h 43 | h = self.slice4(h) 44 | h_relu4 = h 45 | h = self.slice5(h) 46 | h_relu5 = h 47 | h = self.slice6(h) 48 | h_relu6 = h 49 | h = self.slice7(h) 50 | h_relu7 = h 51 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 52 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 53 | 54 | return out 55 | 56 | 57 | class alexnet(torch.nn.Module): 58 | def __init__(self, requires_grad=False, pretrained=True): 59 | super(alexnet, self).__init__() 60 | alexnet_pretrained_features = 
tv.alexnet(pretrained=pretrained).features 61 | self.slice1 = torch.nn.Sequential() 62 | self.slice2 = torch.nn.Sequential() 63 | self.slice3 = torch.nn.Sequential() 64 | self.slice4 = torch.nn.Sequential() 65 | self.slice5 = torch.nn.Sequential() 66 | self.N_slices = 5 67 | for x in range(2): 68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 69 | for x in range(2, 5): 70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 71 | for x in range(5, 8): 72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 73 | for x in range(8, 10): 74 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 75 | for x in range(10, 12): 76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 77 | if not requires_grad: 78 | for param in self.parameters(): 79 | param.requires_grad = False 80 | 81 | def forward(self, X): 82 | h = self.slice1(X) 83 | h_relu1 = h 84 | h = self.slice2(h) 85 | h_relu2 = h 86 | h = self.slice3(h) 87 | h_relu3 = h 88 | h = self.slice4(h) 89 | h_relu4 = h 90 | h = self.slice5(h) 91 | h_relu5 = h 92 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 94 | 95 | return out 96 | 97 | class vgg16(torch.nn.Module): 98 | def __init__(self, requires_grad=False, pretrained=True): 99 | super(vgg16, self).__init__() 100 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 101 | self.slice1 = torch.nn.Sequential() 102 | self.slice2 = torch.nn.Sequential() 103 | self.slice3 = torch.nn.Sequential() 104 | self.slice4 = torch.nn.Sequential() 105 | self.slice5 = torch.nn.Sequential() 106 | self.N_slices = 5 107 | for x in range(4): 108 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 109 | for x in range(4, 9): 110 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(9, 16): 112 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(16, 23): 114 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(23, 30): 116 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 117 | if not requires_grad: 118 | for param in self.parameters(): 119 | param.requires_grad = False 120 | 121 | def forward(self, X): 122 | h = self.slice1(X) 123 | h_relu1_2 = h 124 | h = self.slice2(h) 125 | h_relu2_2 = h 126 | h = self.slice3(h) 127 | h_relu3_3 = h 128 | h = self.slice4(h) 129 | h_relu4_3 = h 130 | h = self.slice5(h) 131 | h_relu5_3 = h 132 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 133 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 134 | 135 | return out 136 | 137 | 138 | 139 | class resnet(torch.nn.Module): 140 | def __init__(self, requires_grad=False, pretrained=True, num=18): 141 | super(resnet, self).__init__() 142 | if(num==18): 143 | self.net = tv.resnet18(pretrained=pretrained) 144 | elif(num==34): 145 | self.net = tv.resnet34(pretrained=pretrained) 146 | elif(num==50): 147 | self.net = tv.resnet50(pretrained=pretrained) 148 | elif(num==101): 149 | self.net = tv.resnet101(pretrained=pretrained) 150 | elif(num==152): 151 | self.net = tv.resnet152(pretrained=pretrained) 152 | self.N_slices = 5 153 | 154 | self.conv1 = self.net.conv1 155 | self.bn1 = self.net.bn1 156 | self.relu = self.net.relu 157 | self.maxpool = self.net.maxpool 158 | self.layer1 = self.net.layer1 159 | self.layer2 = self.net.layer2 160 | self.layer3 = self.net.layer3 161 
| self.layer4 = self.net.layer4 162 | 163 | def forward(self, X): 164 | h = self.conv1(X) 165 | h = self.bn1(h) 166 | h = self.relu(h) 167 | h_relu1 = h 168 | h = self.maxpool(h) 169 | h = self.layer1(h) 170 | h_conv2 = h 171 | h = self.layer2(h) 172 | h_conv3 = h 173 | h = self.layer3(h) 174 | h_conv4 = h 175 | h = self.layer4(h) 176 | h_conv5 = h 177 | 178 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 179 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 180 | 181 | return out 182 | -------------------------------------------------------------------------------- /loss/PerceptualSimilarity/models/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WarranWeng/EBFI-BE/8595a11a84242e08c85c5bee6a8bfe956015d2e4/loss/PerceptualSimilarity/models/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /loss/PerceptualSimilarity/models/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WarranWeng/EBFI-BE/8595a11a84242e08c85c5bee6a8bfe956015d2e4/loss/PerceptualSimilarity/models/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /loss/PerceptualSimilarity/models/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WarranWeng/EBFI-BE/8595a11a84242e08c85c5bee6a8bfe956015d2e4/loss/PerceptualSimilarity/models/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /loss/PerceptualSimilarity/models/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WarranWeng/EBFI-BE/8595a11a84242e08c85c5bee6a8bfe956015d2e4/loss/PerceptualSimilarity/models/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /loss/PerceptualSimilarity/models/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WarranWeng/EBFI-BE/8595a11a84242e08c85c5bee6a8bfe956015d2e4/loss/PerceptualSimilarity/models/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /loss/PerceptualSimilarity/models/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WarranWeng/EBFI-BE/8595a11a84242e08c85c5bee6a8bfe956015d2e4/loss/PerceptualSimilarity/models/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .flow import * 2 | from .reconstruction import * 3 | from .restore import * 4 | from .adversarial import Adversarial -------------------------------------------------------------------------------- /loss/adversarial.py: -------------------------------------------------------------------------------- 1 | from loss import discriminator 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | import torch.optim.lr_scheduler as lrs 8 | 9 | 10 | def make_optimizer(args, my_model): 11 | trainable = filter(lambda x: x.requires_grad, my_model.parameters()) 
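# args is a plain dict, e.g. {'type': 'ADAMax', 'lr': 0.001, 'weight_decay': 0}, as passed in by Adversarial.__init__ below.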
12 | 13 | if args['type'] == 'SGD': 14 | optimizer_function = optim.SGD 15 | kwargs = {'momentum': 0.9} 16 | elif args['type'] == 'ADAM': 17 | optimizer_function = optim.Adam 18 | kwargs = { 19 | 'betas': (0.9, 0.999), 20 | 'eps': 1e-08 21 | } 22 | elif args['type'] == 'ADAMax': 23 | optimizer_function = optim.Adamax 24 | kwargs = { 25 | 'betas': (0.9, 0.999), 26 | 'eps': 1e-08 27 | } 28 | elif args['type'] == 'RMSprop': 29 | optimizer_function = optim.RMSprop 30 | kwargs = {'eps': 1e-08} 31 | 32 | kwargs['lr'] = args['lr'] 33 | kwargs['weight_decay'] = args['weight_decay'] 34 | 35 | return optimizer_function(trainable, **kwargs) 36 | 37 | 38 | def make_scheduler(args, my_optimizer): 39 | if args['decay_type'] == 'step': 40 | scheduler = lrs.StepLR( 41 | my_optimizer, 42 | step_size=args['lr_decay'], 43 | gamma=args['gamma'] 44 | ) 45 | elif args['decay_type'].find('step') >= 0: 46 | milestones = args['decay_type'].split('_') 47 | milestones.pop(0) 48 | milestones = list(map(lambda x: int(x), milestones)) 49 | scheduler = lrs.MultiStepLR( 50 | my_optimizer, 51 | milestones=milestones, 52 | gamma=args['gamma'] 53 | ) 54 | elif args['decay_type'] == 'plateau': 55 | scheduler = lrs.ReduceLROnPlateau( 56 | my_optimizer, 57 | mode='max', 58 | factor=args['gamma'], 59 | patience=args['patience'], 60 | threshold=0.01, # metric to be used is psnr 61 | threshold_mode='abs', 62 | verbose=True 63 | ) 64 | 65 | return scheduler 66 | 67 | 68 | class Adversarial(nn.Module): 69 | def __init__(self, PatchSize, gan_type): 70 | super(Adversarial, self).__init__() 71 | self.gan_type = gan_type 72 | self.gan_k = 1 73 | if gan_type == 'T_WGAN_GP': 74 | self.discriminator = discriminator.Temporal_Discriminator(PatchSize) 75 | elif gan_type == 'FI_GAN': 76 | self.discriminator = discriminator.FI_Discriminator(PatchSize) 77 | elif gan_type == 'FI_Cond_GAN': 78 | self.discriminator = discriminator.FI_Cond_Discriminator(PatchSize) 79 | elif gan_type == 'STGAN': 80 | self.discriminator = discriminator.ST_Discriminator(PatchSize) 81 | else: 82 | self.discriminator = discriminator.Discriminator(PatchSize, gan_type) 83 | 84 | # optimizer 85 | if gan_type != 'WGAN_GP' and gan_type != 'T_WGAN_GP': 86 | self.optimizer = make_optimizer(args={'type': 'ADAMax', 'lr': 0.001, 'weight_decay': 0}, my_model=self.discriminator) 87 | else: 88 | self.optimizer = optim.Adam( 89 | self.discriminator.parameters(), 90 | betas=(0, 0.9), eps=1e-8, lr=1e-5 91 | ) 92 | self.scheduler = make_scheduler(args={'decay_type': 'plateau', 'gamma': 0.5, 'patience': 5}, my_optimizer=self.optimizer) 93 | 94 | def forward(self, fake, real, input_frames=None): 95 | # if len(input_frames) == 4: 96 | # input_frames = input_frames[1:3] 97 | fake_detach = fake.detach() 98 | 99 | self.loss = 0 100 | for _ in range(self.gan_k): 101 | self.optimizer.zero_grad() 102 | # discriminator forward pass 103 | if self.gan_type in ['T_WGAN_GP', 'FI_Cond_GAN', 'STGAN']: 104 | d_fake = self.discriminator(input_frames[:, 0], fake_detach, input_frames[:, 1]) 105 | d_real = self.discriminator(input_frames[:, 0], real, input_frames[:, 1]) 106 | elif self.gan_type == 'FI_GAN': 107 | d_01 = self.discriminator(input_frames[:, 0], fake_detach) 108 | d_12 = self.discriminator(fake_detach, input_frames[:, 1]) 109 | else: 110 | d_fake = self.discriminator(fake_detach) 111 | d_real = self.discriminator(real) 112 | 113 | # compute discriminator loss 114 | if self.gan_type in ['GAN', 'FI_Cond_GAN', 'STGAN']: 115 | label_fake = torch.zeros_like(d_fake) 116 | label_real = 
torch.ones_like(d_real) 117 | loss_d = F.binary_cross_entropy_with_logits(d_fake, label_fake) + F.binary_cross_entropy_with_logits(d_real, label_real) 118 | elif self.gan_type == 'FI_GAN': 119 | label_01 = torch.zeros_like(d_01) 120 | label_12 = torch.ones_like(d_12) 121 | loss_d = F.binary_cross_entropy_with_logits(d_01, label_01) + F.binary_cross_entropy_with_logits(d_12, label_12) 122 | elif self.gan_type.find('WGAN') >= 0: 123 | loss_d = (d_fake - d_real).mean() 124 | if self.gan_type.find('GP') >= 0: 125 | epsilon = torch.rand_like(fake) 126 | hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon) 127 | hat.requires_grad = True 128 | d_hat = self.discriminator(hat) 129 | gradients = torch.autograd.grad( 130 | outputs=d_hat.sum(), inputs=hat, 131 | retain_graph=True, create_graph=True, only_inputs=True 132 | )[0] 133 | gradients = gradients.view(gradients.size(0), -1) 134 | gradient_norm = gradients.norm(2, dim=1) 135 | gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean() 136 | loss_d += gradient_penalty 137 | 138 | # Discriminator update 139 | self.loss += loss_d.item() 140 | loss_d.backward() 141 | self.optimizer.step() 142 | 143 | if self.gan_type == 'WGAN': 144 | for p in self.discriminator.parameters(): 145 | p.data.clamp_(-1, 1) 146 | 147 | self.loss /= self.gan_k 148 | 149 | if self.gan_type == 'GAN': 150 | d_fake_for_g = self.discriminator(fake) 151 | loss_g = F.binary_cross_entropy_with_logits(d_fake_for_g, label_real) 152 | 153 | elif self.gan_type == 'FI_GAN': 154 | d_01_for_g = F.sigmoid(self.discriminator(input_frames[:, 0], fake)) 155 | d_12_for_g = F.sigmoid(self.discriminator(fake, input_frames[:, 1])) 156 | loss_g = d_01_for_g * torch.log(d_01_for_g + 1e-12) + d_12_for_g * torch.log(d_12_for_g + 1e-12) 157 | loss_g = loss_g.mean() 158 | 159 | elif self.gan_type.find('WGAN') >= 0: 160 | d_fake_for_g = self.discriminator(fake) 161 | loss_g = -d_fake_for_g.mean() 162 | 163 | elif self.gan_type in ['FI_Cond_GAN', 'STGAN']: 164 | d_fake_for_g = self.discriminator(input_frames[:, 0], fake, input_frames[:, 1]) 165 | loss_g = F.binary_cross_entropy_with_logits(d_fake_for_g, label_real) 166 | 167 | # Generator loss 168 | return loss_g 169 | -------------------------------------------------------------------------------- /loss/reconstruction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import numpy as np 6 | import torch.nn.functional as F 7 | 8 | from .flow import AveragedIWE 9 | 10 | parent_dir_name = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 11 | sys.path.append(parent_dir_name) 12 | 13 | from myutils.iwe import deblur_events 14 | from myutils.gradients import Sobel 15 | 16 | 17 | class BrightnessConstancy(torch.nn.Module): 18 | """ 19 | Self-supervised image reconstruction loss, as described in Section 3.4 of the paper 'Back to Event Basics: 20 | Self-Supervised Image Reconstruction for Event Cameras via Photometric Constancy', Paredes-Valles et al., CVPR'21. 21 | The reconstruction loss is the combination of three components. 22 | 1) Image reconstruction through the generative model of event cameras. The reconstruction error propagates back 23 | through the spatial gradients of the reconstructed images. The loss consists in an L2-norm of the difference of the 24 | brightness increment images that can be obtained through the generative model and by means of event integration. 25 | 2) Temporal consistency. 
Simple L1-norm of the warping error between two consecutive reconstructed frames. 26 | 3) Image regularization. Conventional total variation formulation. 27 | """ 28 | 29 | def __init__(self, config, device): 30 | super(BrightnessConstancy, self).__init__() 31 | self.sobel = Sobel(device) 32 | self.res = config["loader"]["resolution"] 33 | self.flow_scaling = max(config["loader"]["resolution"]) 34 | self.weights = config["loss"]["reconstruction_regul_weight"] 35 | 36 | col_idx = np.linspace(0, self.res[1] - 1, num=self.res[1]) 37 | row_idx = np.linspace(0, self.res[0] - 1, num=self.res[0]) 38 | mx, my = np.meshgrid(col_idx, row_idx) 39 | indices = np.zeros((1, 2, self.res[0], self.res[1])) 40 | indices[:, 0, :, :] = my 41 | indices[:, 1, :, :] = mx 42 | self.indices = torch.from_numpy(indices).float().to(device) 43 | 44 | self.averaged_iwe = AveragedIWE(config, device) 45 | 46 | def generative_model(self, flow, img, inputs): 47 | """ 48 | :param flow: [batch_size x 2 x H x W] optical flow map 49 | :param img: [batch_size x 1 x H x W] last reconstructed image 50 | :param inputs: dataloader dictionary 51 | :return generative model loss 52 | """ 53 | 54 | event_cnt = inputs["inp_cnt"].to(flow.device) 55 | event_list = inputs["inp_list"].to(flow.device) 56 | pol_mask = inputs["inp_pol_mask"].to(flow.device) 57 | 58 | # mask optical flow with input events 59 | flow_mask = torch.sum(event_cnt, dim=1, keepdim=True) 60 | flow_mask[flow_mask > 0] = 1 61 | flow = flow * flow_mask 62 | 63 | # foward warping metrics 64 | warped_y = self.indices[:, 0:1, :, :] - flow[:, 1:2, :, :] * self.flow_scaling 65 | warped_x = self.indices[:, 1:2, :, :] - flow[:, 0:1, :, :] * self.flow_scaling 66 | warped_y = 2 * warped_y / (self.res[0] - 1) - 1 67 | warped_x = 2 * warped_x / (self.res[1] - 1) - 1 68 | grid_pos = torch.cat([warped_x, warped_y], dim=1).permute(0, 2, 3, 1) 69 | 70 | # warped predicted brightness increment (previous image) 71 | img_gradx, img_grady = self.sobel(img) 72 | warped_img_grady = F.grid_sample(img_grady, grid_pos, mode="bilinear", padding_mode="zeros") 73 | warped_img_gradx = F.grid_sample(img_gradx, grid_pos, mode="bilinear", padding_mode="zeros") 74 | pred_deltaL = warped_img_gradx * flow[:, 0:1, :, :] + warped_img_grady * flow[:, 1:2, :, :] 75 | pred_deltaL = pred_deltaL * self.flow_scaling 76 | 77 | # warped brightness increment from the averaged image of warped events 78 | avg_iwe = self.averaged_iwe(flow, event_list, pol_mask) 79 | event_deltaL = avg_iwe[:, 0:1, :, :] - avg_iwe[:, 1:2, :, :] # C == 1 80 | 81 | # squared L2 norm - brightness constancy error 82 | bc_error = event_deltaL + pred_deltaL 83 | bc_error = ( 84 | torch.norm( 85 | bc_error.view( 86 | bc_error.shape[0], 87 | bc_error.shape[1], 88 | 1, 89 | -1, 90 | ), 91 | p=2, 92 | dim=3, 93 | ) 94 | ** 2 95 | ) # norm in the spatial dimension 96 | 97 | return bc_error.sum() 98 | 99 | def temporal_consistency(self, flow, prev_img, img): 100 | """ 101 | :param flow: [batch_size x 2 x H x W] optical flow map 102 | :param prev_img: [batch_size x 1 x H x W] previous reconstructed image 103 | :param img: [batch_size x 1 x H x W] last reconstructed image 104 | :return weighted temporal consistency loss 105 | """ 106 | 107 | # foward warping metrics 108 | warped_y = self.indices[:, 0:1, :, :] - flow[:, 1:2, :, :] * self.flow_scaling 109 | warped_x = self.indices[:, 1:2, :, :] - flow[:, 0:1, :, :] * self.flow_scaling 110 | warped_y = 2 * warped_y / (self.res[0] - 1) - 1 111 | warped_x = 2 * warped_x / (self.res[1] - 1) - 1 112 | 
grid_pos = torch.cat([warped_x, warped_y], dim=1).permute(0, 2, 3, 1) 113 | 114 | # temporal consistency 115 | warped_prev_img = F.grid_sample(prev_img, grid_pos, mode="bilinear", padding_mode="zeros") 116 | tc_error = img - warped_prev_img 117 | tc_error = ( 118 | torch.norm( 119 | tc_error.view( 120 | tc_error.shape[0], 121 | tc_error.shape[1], 122 | 1, 123 | -1, 124 | ), 125 | p=1, 126 | dim=3, 127 | ) 128 | ** 1 129 | ) # norm in the spatial dimension 130 | tc_error = tc_error.sum() 131 | 132 | return self.weights[1] * tc_error 133 | 134 | def regularization(self, img): 135 | """ 136 | :param img: [batch_size x 1 x H x W] last reconstructed image 137 | :return weighted image regularization loss 138 | """ 139 | 140 | # conventional total variation with forward differences 141 | img_dx = torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]) 142 | img_dy = torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]) 143 | tv_error = img_dx.sum() + img_dy.sum() 144 | 145 | return self.weights[0] * tv_error 146 | -------------------------------------------------------------------------------- /loss/restore.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from skimage.metrics import structural_similarity as SSIM 5 | from skimage.metrics import peak_signal_noise_ratio as PSNR 6 | import torch.nn.functional as F 7 | # local modules 8 | from .PerceptualSimilarity import models 9 | 10 | 11 | class perceptual_loss(): 12 | def __init__(self, weight=1.0, net='alex', use_gpu=True, gpu_ids=[0]): 13 | """ 14 | Wrapper for PerceptualSimilarity.models.PerceptualLoss 15 | """ 16 | self.model = models.PerceptualLoss(net=net, use_gpu=use_gpu, gpu_ids=gpu_ids) 17 | self.weight = weight 18 | 19 | def __call__(self, pred, target, normalize=True): 20 | """ 21 | pred and target are Tensors with shape N x C x H x W (C {1, 3}) 22 | normalize scales images from [0, 1] to [-1, 1] (default: True) 23 | PerceptualLoss expects N x 3 x H x W. 
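        Illustrative usage (tensor shapes and values are assumed, not taken from the repo):
            loss_fn = perceptual_loss(weight=1.0, net='alex', use_gpu=False)
            d = loss_fn(pred, target)  # pred, target in [0, 1], shape N x 1 x H x W; tiled to 3 channels below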
24 | """ 25 | assert pred.size() == target.size() 26 | 27 | if pred.shape[1] == 1: 28 | pred = torch.cat([pred, pred, pred], dim=1) 29 | target = torch.cat([target, target, target], dim=1) 30 | dist = self.model.forward(pred, target, normalize=normalize) 31 | elif pred.shape[1] == 3: 32 | dist = self.model.forward(pred, target, normalize=normalize) 33 | else: 34 | num_ch = pred.shape[1] 35 | dist = 0 36 | for idx in range(num_ch): 37 | dist += self.model.forward(pred[:, idx].repeat(1, 3, 1, 1), target[:, idx].repeat(1, 3, 1, 1), normalize=normalize) 38 | dist /= num_ch 39 | 40 | return self.weight * dist.mean() 41 | 42 | 43 | class ssim_loss(): 44 | def __init__(self): 45 | self.ssim = SSIM 46 | 47 | def __call__(self, pred, tgt): 48 | """ 49 | pred, tgt: torch.tensor, 1xNxHxW 50 | """ 51 | assert pred.size() == tgt.size() 52 | pred = pred.squeeze().cpu().numpy() 53 | tgt = tgt.squeeze().cpu().numpy() 54 | 55 | if len(pred.shape) == 3: 56 | num_ch = pred.shape[0] 57 | loss = 0 58 | for idx in range(num_ch): 59 | loss += self.ssim(pred[idx], tgt[idx]) 60 | loss /= num_ch 61 | else: 62 | loss = self.ssim(pred, tgt) 63 | 64 | return loss 65 | 66 | 67 | class psnr_loss(): 68 | def __init__(self): 69 | self.psnr = PSNR 70 | 71 | def __call__(self, pred, tgt): 72 | """ 73 | pred, tgt: torch.tensor, 1xNxHxW 74 | """ 75 | assert pred.size() == tgt.size() 76 | pred = pred.squeeze().cpu().numpy() 77 | tgt = tgt.squeeze().cpu().numpy() 78 | 79 | if len(pred.shape) == 3: 80 | num_ch = pred.shape[0] 81 | loss = 0 82 | for idx in range(num_ch): 83 | # data_range = max(tgt[idx].max()-tgt.min(), pred[idx].max()-pred[idx].min()) 84 | data_range = tgt[idx].max()-tgt.min() 85 | loss += self.psnr(tgt[idx], pred[idx], data_range=data_range) 86 | loss /= num_ch 87 | else: 88 | loss = self.psnr(pred.clip(0, 1), tgt.clip(0, 1)) 89 | 90 | # loss = self.psnr((tgt.squeeze().cpu().numpy()*255).astype(np.uint8), (pred.squeeze().cpu().numpy()*255).astype(np.uint8), data_range=255) 91 | 92 | return loss 93 | 94 | 95 | class CharbonnierLoss(nn.Module): 96 | """Charbonnier Loss (L1)""" 97 | 98 | def __init__(self, eps=1e-3): 99 | super(CharbonnierLoss, self).__init__() 100 | self.eps = eps 101 | 102 | def forward(self, x, y): 103 | diff = x - y 104 | loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 105 | return loss 106 | 107 | 108 | class Ternary(nn.Module): 109 | def __init__(self, patch_size=7): 110 | super(Ternary, self).__init__() 111 | self.patch_size = patch_size 112 | out_channels = patch_size * patch_size 113 | self.w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels)) 114 | self.w = np.transpose(self.w, (3, 2, 0, 1)) 115 | if torch.cuda.is_available(): 116 | self.w = torch.tensor(self.w).float().cuda() 117 | 118 | def transform(self, tensor): 119 | tensor_ = tensor.mean(dim=1, keepdim=True) 120 | patches = F.conv2d(tensor_, self.w, padding=self.patch_size//2, bias=None) 121 | loc_diff = patches - tensor_ 122 | loc_diff_norm = loc_diff / torch.sqrt(0.81 + loc_diff ** 2) 123 | 124 | return loc_diff_norm 125 | 126 | def valid_mask(self, tensor): 127 | padding = self.patch_size//2 128 | b, c, h, w = tensor.size() 129 | inner = torch.ones(b, 1, h - 2 * padding, w - 2 * padding).type_as(tensor) 130 | mask = F.pad(inner, [padding] * 4) 131 | 132 | return mask 133 | 134 | def forward(self, x, y): 135 | loc_diff_x = self.transform(x) 136 | loc_diff_y = self.transform(y) 137 | diff = loc_diff_x - loc_diff_y.detach() 138 | dist = (diff ** 2 / (0.1 + diff ** 2)).mean(dim=1, keepdim=True) 139 | mask 
= self.valid_mask(x) 140 | loss = (dist * mask).mean() 141 | 142 | return loss 143 | 144 | 145 | # laplacian loss 146 | class GaussianConv(nn.Module): 147 | def __init__(self): 148 | super(GaussianConv, self).__init__() 149 | kernel = torch.tensor([[1., 4., 6., 4., 1], 150 | [4., 16., 24., 16., 4.], 151 | [6., 24., 36., 24., 6.], 152 | [4., 16., 24., 16., 4.], 153 | [1., 4., 6., 4., 1.]]) 154 | self.kernel = nn.Parameter(kernel.div(256).repeat(3,1,1,1), requires_grad=False) 155 | 156 | def forward(self, x, factor=1): 157 | c, h, w = x.shape[1:] 158 | p = (self.kernel.shape[-1]-1)//2 159 | blurred = F.conv2d(F.pad(x, pad=(p,p,p,p), mode='reflect'), factor*self.kernel, groups=c) 160 | return blurred 161 | 162 | class LaplacianPyramid(nn.Module): 163 | """ 164 | Implementing "The Laplacian pyramid as a compact image code." Burt, Peter J., and Edward H. Adelson. 165 | """ 166 | def __init__(self, max_level=5): 167 | super(LaplacianPyramid, self).__init__() 168 | self.gaussian_conv = GaussianConv() 169 | self.max_level = max_level 170 | 171 | def forward(self, X): 172 | pyramid = [] 173 | current = X 174 | for _ in range(self.max_level-1): 175 | blurred = self.gaussian_conv(current) 176 | reduced = self.reduce(blurred) 177 | expanded = self.expand(reduced) 178 | diff = current - expanded 179 | pyramid.append(diff) 180 | current = reduced 181 | 182 | pyramid.append(current) 183 | 184 | return pyramid 185 | 186 | def reduce(self, x): 187 | return F.avg_pool2d(x, 2) 188 | 189 | def expand(self, x): 190 | # injecting even zero rows 191 | tmp = torch.cat([x, torch.zeros_like(x).to(x.device)], dim=3) 192 | tmp = tmp.view(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3]) 193 | tmp = tmp.permute(0,1,3,2) 194 | # injecting even zero columns 195 | tmp = torch.cat([tmp, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2]*2, device=x.device)], dim=3) 196 | tmp = tmp.view(x.shape[0], x.shape[1], x.shape[3]*2, x.shape[2]*2) 197 | x_up = tmp.permute(0,1,3,2) 198 | # convolve with 4 x Gaussian kernel 199 | return self.gaussian_conv(x_up, factor=4) 200 | 201 | class LaplacianLoss(nn.Module): 202 | def __init__(self): 203 | super(LaplacianLoss, self).__init__() 204 | 205 | self.criterion = nn.L1Loss(reduction='sum') 206 | self.lap = LaplacianPyramid() 207 | 208 | def forward(self, x, y): 209 | x_lap, y_lap = self.lap(x), self.lap(y) 210 | return sum(2**i * self.criterion(a, b) for i, (a, b) in enumerate(zip(x_lap, y_lap))) 211 | -------------------------------------------------------------------------------- /models/DCNv2/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .idea 3 | *.so 4 | *.o 5 | *pyc 6 | _ext 7 | build 8 | DCNv2.egg-info 9 | dist 10 | vendor/ 11 | 12 | -------------------------------------------------------------------------------- /models/DCNv2/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Charles Shang 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. 
Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /models/DCNv2/README.md: -------------------------------------------------------------------------------- 1 | # DCNv2 latest 2 | 3 | 4 | 5 | DCN is used in many models and performs well, but support for this op in industry tooling (PyTorch, ONNX, TensorRT, etc.) is still limited. This repo makes DCNv2 available across PyTorch versions. 6 | 7 | ![image-20210217183103121](https://gitee.com/jinfagang/picbed/raw/master/img/image-20210217183103121.png) 8 | 9 | PyTorch 1.7 inference with a CenterNet-DLA model. It works on PyTorch 1.7, so you can use it with your RTX 30 series cards. 10 | 11 | 12 | 13 | ## Updates 14 | 15 | - **2021.03.24**: PyTorch 1.8 is confirmed to work with the master branch; feel free to use it. 16 | - **2021.02.18**: Happy new year! PyTorch 1.7 is finally supported on the master branch! **Lower versions should theoretically also work; if not, please open an issue!** 17 | - **2020.09.23**: The master branch now works for PyTorch 1.6 by default; for older versions you will need the corresponding separate branch. 18 | - **2020.08.25**: Check out the pytorch1.6 branch for PyTorch 1.6 support; you will hit an error like `THCudaBlas_Sgemv undefined` if you build the master branch with PyTorch 1.6. 
master branch now work for pytorch 1.5; 19 | 20 | 21 | 22 | ## Contact 23 | 24 | If you have any question, please using this platform post questions: http://t.manaai.cn 25 | -------------------------------------------------------------------------------- /models/DCNv2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WarranWeng/EBFI-BE/8595a11a84242e08c85c5bee6a8bfe956015d2e4/models/DCNv2/__init__.py -------------------------------------------------------------------------------- /models/DCNv2/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm *.so 3 | rm -r build/ 4 | python3 setup.py build develop 5 | -------------------------------------------------------------------------------- /models/DCNv2/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import glob 4 | import os 5 | 6 | import torch 7 | from setuptools import find_packages, setup 8 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension 9 | 10 | requirements = ["torch", "torchvision"] 11 | 12 | 13 | def get_extensions(): 14 | this_dir = os.path.dirname(os.path.abspath(__file__)) 15 | extensions_dir = os.path.join(this_dir, "src") 16 | 17 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 18 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 19 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 20 | os.environ["CC"] = "g++" 21 | sources = main_file + source_cpu 22 | extension = CppExtension 23 | extra_compile_args = {"cxx": []} 24 | define_macros = [] 25 | 26 | 27 | if torch.cuda.is_available() and CUDA_HOME is not None: 28 | extension = CUDAExtension 29 | sources += source_cuda 30 | define_macros += [("WITH_CUDA", None)] 31 | extra_compile_args["nvcc"] = [ 32 | "-DCUDA_HAS_FP16=1", 33 | "-D__CUDA_NO_HALF_OPERATORS__", 34 | "-D__CUDA_NO_HALF_CONVERSIONS__", 35 | "-D__CUDA_NO_HALF2_OPERATORS__", 36 | ] 37 | else: 38 | # raise NotImplementedError('Cuda is not available') 39 | pass 40 | 41 | sources = [os.path.join(extensions_dir, s) for s in sources] 42 | include_dirs = [extensions_dir] 43 | ext_modules = [ 44 | extension( 45 | "_ext", 46 | sources, 47 | include_dirs=include_dirs, 48 | define_macros=define_macros, 49 | extra_compile_args=extra_compile_args, 50 | ) 51 | ] 52 | return ext_modules 53 | 54 | 55 | setup( 56 | name="DCNv2", 57 | version="0.1", 58 | author="charlesshang", 59 | url="https://github.com/charlesshang/DCNv2", 60 | description="deformable convolutional networks", 61 | packages=find_packages(exclude=("configs", "tests")), 62 | # install_requires=requirements, 63 | ext_modules=get_extensions(), 64 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 65 | ) 66 | -------------------------------------------------------------------------------- /models/DCNv2/src/cpu/dcn_v2_im2col_cpu.h: -------------------------------------------------------------------------------- 1 | 2 | /*! 3 | ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** 4 | * 5 | * COPYRIGHT 6 | * 7 | * All contributions by the University of California: 8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents) 9 | * All rights reserved. 10 | * 11 | * All other contributions: 12 | * Copyright (c) 2014-2017, the respective contributors 13 | * All rights reserved. 
14 | * 15 | * Caffe uses a shared copyright model: each contributor holds copyright over 16 | * their contributions to Caffe. The project versioning records all such 17 | * contribution and copyright details. If a contributor wants to further mark 18 | * their specific copyright on a particular contribution, they should indicate 19 | * their copyright solely in the commit message of the change when it is 20 | * committed. 21 | * 22 | * LICENSE 23 | * 24 | * Redistribution and use in source and binary forms, with or without 25 | * modification, are permitted provided that the following conditions are met: 26 | * 27 | * 1. Redistributions of source code must retain the above copyright notice, this 28 | * list of conditions and the following disclaimer. 29 | * 2. Redistributions in binary form must reproduce the above copyright notice, 30 | * this list of conditions and the following disclaimer in the documentation 31 | * and/or other materials provided with the distribution. 32 | * 33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 34 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 35 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 36 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 37 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 38 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 40 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 41 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 42 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 43 | * 44 | * CONTRIBUTION AGREEMENT 45 | * 46 | * By contributing to the BVLC/caffe repository through pull-request, comment, 47 | * or otherwise, the contributor releases their content to the 48 | * license and copyright terms herein. 49 | * 50 | ***************** END Caffe Copyright Notice and Disclaimer ******************** 51 | * 52 | * Copyright (c) 2018 Microsoft 53 | * Licensed under The MIT License [see LICENSE for details] 54 | * \file modulated_deformable_im2col.h 55 | * \brief Function definitions of converting an image to 56 | * column matrix based on kernel, padding, dilation, and offset. 57 | * These functions are mainly used in deformable convolution operators. 58 | * \ref: https://arxiv.org/abs/1811.11168 59 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu 60 | */ 61 | 62 | /***************** Adapted by Charles Shang *********************/ 63 | // modified from the CUDA version for CPU use by Daniel K. 
Suhendro 64 | 65 | #ifndef DCN_V2_IM2COL_CPU 66 | #define DCN_V2_IM2COL_CPU 67 | 68 | #ifdef __cplusplus 69 | extern "C" 70 | { 71 | #endif 72 | 73 | void modulated_deformable_im2col_cpu(const float *data_im, const float *data_offset, const float *data_mask, 74 | const int batch_size, const int channels, const int height_im, const int width_im, 75 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 76 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 77 | const int dilation_h, const int dilation_w, 78 | const int deformable_group, float *data_col); 79 | 80 | void modulated_deformable_col2im_cpu(const float *data_col, const float *data_offset, const float *data_mask, 81 | const int batch_size, const int channels, const int height_im, const int width_im, 82 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 83 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 84 | const int dilation_h, const int dilation_w, 85 | const int deformable_group, float *grad_im); 86 | 87 | void modulated_deformable_col2im_coord_cpu(const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, 88 | const int batch_size, const int channels, const int height_im, const int width_im, 89 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 90 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 91 | const int dilation_h, const int dilation_w, 92 | const int deformable_group, 93 | float *grad_offset, float *grad_mask); 94 | 95 | #ifdef __cplusplus 96 | } 97 | #endif 98 | 99 | #endif -------------------------------------------------------------------------------- /models/DCNv2/src/cpu/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor 5 | dcn_v2_cpu_forward(const at::Tensor &input, 6 | const at::Tensor &weight, 7 | const at::Tensor &bias, 8 | const at::Tensor &offset, 9 | const at::Tensor &mask, 10 | const int kernel_h, 11 | const int kernel_w, 12 | const int stride_h, 13 | const int stride_w, 14 | const int pad_h, 15 | const int pad_w, 16 | const int dilation_h, 17 | const int dilation_w, 18 | const int deformable_group); 19 | 20 | std::vector 21 | dcn_v2_cpu_backward(const at::Tensor &input, 22 | const at::Tensor &weight, 23 | const at::Tensor &bias, 24 | const at::Tensor &offset, 25 | const at::Tensor &mask, 26 | const at::Tensor &grad_output, 27 | int kernel_h, int kernel_w, 28 | int stride_h, int stride_w, 29 | int pad_h, int pad_w, 30 | int dilation_h, int dilation_w, 31 | int deformable_group); 32 | 33 | 34 | std::tuple 35 | dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input, 36 | const at::Tensor &bbox, 37 | const at::Tensor &trans, 38 | const int no_trans, 39 | const float spatial_scale, 40 | const int output_dim, 41 | const int group_size, 42 | const int pooled_size, 43 | const int part_size, 44 | const int sample_per_part, 45 | const float trans_std); 46 | 47 | std::tuple 48 | dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad, 49 | const at::Tensor &input, 50 | const at::Tensor &bbox, 51 | const at::Tensor &trans, 52 | const at::Tensor &top_count, 53 | const int no_trans, 54 | const float spatial_scale, 55 | const int output_dim, 56 | const int group_size, 57 | const int pooled_size, 58 | const int part_size, 59 | const int sample_per_part, 60 | const float trans_std); 
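// These CPU declarations mirror the CUDA variants in src/cuda/vision.h. The dispatchers
// in src/dcn_v2.h select the *_cpu_* or *_cuda_* implementation at runtime, based on
// whether the input tensor lives on the GPU and whether the extension was built with
// WITH_CUDA defined.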
-------------------------------------------------------------------------------- /models/DCNv2/src/cuda/dcn_v2_im2col_cuda.h: -------------------------------------------------------------------------------- 1 | 2 | /*! 3 | ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** 4 | * 5 | * COPYRIGHT 6 | * 7 | * All contributions by the University of California: 8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents) 9 | * All rights reserved. 10 | * 11 | * All other contributions: 12 | * Copyright (c) 2014-2017, the respective contributors 13 | * All rights reserved. 14 | * 15 | * Caffe uses a shared copyright model: each contributor holds copyright over 16 | * their contributions to Caffe. The project versioning records all such 17 | * contribution and copyright details. If a contributor wants to further mark 18 | * their specific copyright on a particular contribution, they should indicate 19 | * their copyright solely in the commit message of the change when it is 20 | * committed. 21 | * 22 | * LICENSE 23 | * 24 | * Redistribution and use in source and binary forms, with or without 25 | * modification, are permitted provided that the following conditions are met: 26 | * 27 | * 1. Redistributions of source code must retain the above copyright notice, this 28 | * list of conditions and the following disclaimer. 29 | * 2. Redistributions in binary form must reproduce the above copyright notice, 30 | * this list of conditions and the following disclaimer in the documentation 31 | * and/or other materials provided with the distribution. 32 | * 33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 34 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 35 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 36 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 37 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 38 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 40 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 41 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 42 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 43 | * 44 | * CONTRIBUTION AGREEMENT 45 | * 46 | * By contributing to the BVLC/caffe repository through pull-request, comment, 47 | * or otherwise, the contributor releases their content to the 48 | * license and copyright terms herein. 49 | * 50 | ***************** END Caffe Copyright Notice and Disclaimer ******************** 51 | * 52 | * Copyright (c) 2018 Microsoft 53 | * Licensed under The MIT License [see LICENSE for details] 54 | * \file modulated_deformable_im2col.h 55 | * \brief Function definitions of converting an image to 56 | * column matrix based on kernel, padding, dilation, and offset. 57 | * These functions are mainly used in deformable convolution operators. 
58 | * \ref: https://arxiv.org/abs/1811.11168 59 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu 60 | */ 61 | 62 | /***************** Adapted by Charles Shang *********************/ 63 | 64 | #ifndef DCN_V2_IM2COL_CUDA 65 | #define DCN_V2_IM2COL_CUDA 66 | 67 | #ifdef __cplusplus 68 | extern "C" 69 | { 70 | #endif 71 | 72 | void modulated_deformable_im2col_cuda(cudaStream_t stream, 73 | const float *data_im, const float *data_offset, const float *data_mask, 74 | const int batch_size, const int channels, const int height_im, const int width_im, 75 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 76 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 77 | const int dilation_h, const int dilation_w, 78 | const int deformable_group, float *data_col); 79 | 80 | void modulated_deformable_col2im_cuda(cudaStream_t stream, 81 | const float *data_col, const float *data_offset, const float *data_mask, 82 | const int batch_size, const int channels, const int height_im, const int width_im, 83 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 84 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 85 | const int dilation_h, const int dilation_w, 86 | const int deformable_group, float *grad_im); 87 | 88 | void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, 89 | const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, 90 | const int batch_size, const int channels, const int height_im, const int width_im, 91 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 92 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 93 | const int dilation_h, const int dilation_w, 94 | const int deformable_group, 95 | float *grad_offset, float *grad_mask); 96 | 97 | #ifdef __cplusplus 98 | } 99 | #endif 100 | 101 | #endif -------------------------------------------------------------------------------- /models/DCNv2/src/cuda/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | at::Tensor 5 | dcn_v2_cuda_forward(const at::Tensor &input, 6 | const at::Tensor &weight, 7 | const at::Tensor &bias, 8 | const at::Tensor &offset, 9 | const at::Tensor &mask, 10 | const int kernel_h, 11 | const int kernel_w, 12 | const int stride_h, 13 | const int stride_w, 14 | const int pad_h, 15 | const int pad_w, 16 | const int dilation_h, 17 | const int dilation_w, 18 | const int deformable_group); 19 | 20 | std::vector 21 | dcn_v2_cuda_backward(const at::Tensor &input, 22 | const at::Tensor &weight, 23 | const at::Tensor &bias, 24 | const at::Tensor &offset, 25 | const at::Tensor &mask, 26 | const at::Tensor &grad_output, 27 | int kernel_h, int kernel_w, 28 | int stride_h, int stride_w, 29 | int pad_h, int pad_w, 30 | int dilation_h, int dilation_w, 31 | int deformable_group); 32 | 33 | 34 | std::tuple 35 | dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, 36 | const at::Tensor &bbox, 37 | const at::Tensor &trans, 38 | const int no_trans, 39 | const float spatial_scale, 40 | const int output_dim, 41 | const int group_size, 42 | const int pooled_size, 43 | const int part_size, 44 | const int sample_per_part, 45 | const float trans_std); 46 | 47 | std::tuple 48 | dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, 49 | const at::Tensor &input, 50 | const at::Tensor &bbox, 51 | const at::Tensor &trans, 52 | 
const at::Tensor &top_count, 53 | const int no_trans, 54 | const float spatial_scale, 55 | const int output_dim, 56 | const int group_size, 57 | const int pooled_size, 58 | const int part_size, 59 | const int sample_per_part, 60 | const float trans_std); -------------------------------------------------------------------------------- /models/DCNv2/src/dcn_v2.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cpu/vision.h" 4 | 5 | #ifdef WITH_CUDA 6 | #include "cuda/vision.h" 7 | #endif 8 | 9 | at::Tensor 10 | dcn_v2_forward(const at::Tensor &input, 11 | const at::Tensor &weight, 12 | const at::Tensor &bias, 13 | const at::Tensor &offset, 14 | const at::Tensor &mask, 15 | const int kernel_h, 16 | const int kernel_w, 17 | const int stride_h, 18 | const int stride_w, 19 | const int pad_h, 20 | const int pad_w, 21 | const int dilation_h, 22 | const int dilation_w, 23 | const int deformable_group) 24 | { 25 | if (input.type().is_cuda()) 26 | { 27 | #ifdef WITH_CUDA 28 | return dcn_v2_cuda_forward(input, weight, bias, offset, mask, 29 | kernel_h, kernel_w, 30 | stride_h, stride_w, 31 | pad_h, pad_w, 32 | dilation_h, dilation_w, 33 | deformable_group); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | else{ 39 | return dcn_v2_cpu_forward(input, weight, bias, offset, mask, 40 | kernel_h, kernel_w, 41 | stride_h, stride_w, 42 | pad_h, pad_w, 43 | dilation_h, dilation_w, 44 | deformable_group); 45 | } 46 | } 47 | 48 | std::vector 49 | dcn_v2_backward(const at::Tensor &input, 50 | const at::Tensor &weight, 51 | const at::Tensor &bias, 52 | const at::Tensor &offset, 53 | const at::Tensor &mask, 54 | const at::Tensor &grad_output, 55 | int kernel_h, int kernel_w, 56 | int stride_h, int stride_w, 57 | int pad_h, int pad_w, 58 | int dilation_h, int dilation_w, 59 | int deformable_group) 60 | { 61 | if (input.type().is_cuda()) 62 | { 63 | #ifdef WITH_CUDA 64 | return dcn_v2_cuda_backward(input, 65 | weight, 66 | bias, 67 | offset, 68 | mask, 69 | grad_output, 70 | kernel_h, kernel_w, 71 | stride_h, stride_w, 72 | pad_h, pad_w, 73 | dilation_h, dilation_w, 74 | deformable_group); 75 | #else 76 | AT_ERROR("Not compiled with GPU support"); 77 | #endif 78 | } 79 | else{ 80 | return dcn_v2_cpu_backward(input, 81 | weight, 82 | bias, 83 | offset, 84 | mask, 85 | grad_output, 86 | kernel_h, kernel_w, 87 | stride_h, stride_w, 88 | pad_h, pad_w, 89 | dilation_h, dilation_w, 90 | deformable_group); 91 | } 92 | } 93 | 94 | std::tuple 95 | dcn_v2_psroi_pooling_forward(const at::Tensor &input, 96 | const at::Tensor &bbox, 97 | const at::Tensor &trans, 98 | const int no_trans, 99 | const float spatial_scale, 100 | const int output_dim, 101 | const int group_size, 102 | const int pooled_size, 103 | const int part_size, 104 | const int sample_per_part, 105 | const float trans_std) 106 | { 107 | if (input.type().is_cuda()) 108 | { 109 | #ifdef WITH_CUDA 110 | return dcn_v2_psroi_pooling_cuda_forward(input, 111 | bbox, 112 | trans, 113 | no_trans, 114 | spatial_scale, 115 | output_dim, 116 | group_size, 117 | pooled_size, 118 | part_size, 119 | sample_per_part, 120 | trans_std); 121 | #else 122 | AT_ERROR("Not compiled with GPU support"); 123 | #endif 124 | } 125 | else{ 126 | return dcn_v2_psroi_pooling_cpu_forward(input, 127 | bbox, 128 | trans, 129 | no_trans, 130 | spatial_scale, 131 | output_dim, 132 | group_size, 133 | pooled_size, 134 | part_size, 135 | sample_per_part, 136 | trans_std); 137 | } 138 | } 139 | 140 | 
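// Backward pass of the deformable PSROI pooling op. Like the entry points above, it is a
// thin runtime dispatcher: CUDA tensors are routed to dcn_v2_psroi_pooling_cuda_backward
// when the extension was built with WITH_CUDA; everything else falls back to the CPU version.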
std::tuple 141 | dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad, 142 | const at::Tensor &input, 143 | const at::Tensor &bbox, 144 | const at::Tensor &trans, 145 | const at::Tensor &top_count, 146 | const int no_trans, 147 | const float spatial_scale, 148 | const int output_dim, 149 | const int group_size, 150 | const int pooled_size, 151 | const int part_size, 152 | const int sample_per_part, 153 | const float trans_std) 154 | { 155 | if (input.type().is_cuda()) 156 | { 157 | #ifdef WITH_CUDA 158 | return dcn_v2_psroi_pooling_cuda_backward(out_grad, 159 | input, 160 | bbox, 161 | trans, 162 | top_count, 163 | no_trans, 164 | spatial_scale, 165 | output_dim, 166 | group_size, 167 | pooled_size, 168 | part_size, 169 | sample_per_part, 170 | trans_std); 171 | #else 172 | AT_ERROR("Not compiled with GPU support"); 173 | #endif 174 | } 175 | else{ 176 | return dcn_v2_psroi_pooling_cpu_backward(out_grad, 177 | input, 178 | bbox, 179 | trans, 180 | top_count, 181 | no_trans, 182 | spatial_scale, 183 | output_dim, 184 | group_size, 185 | pooled_size, 186 | part_size, 187 | sample_per_part, 188 | trans_std); 189 | } 190 | } -------------------------------------------------------------------------------- /models/DCNv2/src/vision.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "dcn_v2.h" 3 | 4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 5 | m.def("dcn_v2_forward", &dcn_v2_forward, "dcn_v2_forward"); 6 | m.def("dcn_v2_backward", &dcn_v2_backward, "dcn_v2_backward"); 7 | m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, "dcn_v2_psroi_pooling_forward"); 8 | m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, "dcn_v2_psroi_pooling_backward"); 9 | } 10 | -------------------------------------------------------------------------------- /models/FAC/README.md: -------------------------------------------------------------------------------- 1 | # Filter Adaptive Convolutional (FAC) Layer 2 | FAC layer applies generated spatially variant filters (element-wise) to the features. 3 | 4 | (Here we release the full code of FAC layer, including both the forwards and the backwards pass.) 5 | 6 | ## Prerequisites 7 | - CUDA 8.0/9.0/10.0 8 | - gcc 4.9+ 9 | - Pytorch 1.0+ 10 | 11 | Note that if your CUDA is 10.2+, you need to modify KernelConv2D_cuda.cpp: 12 | ``` 13 | 1. Uncomment: #include 14 | 2. 
Modify: at::cuda::getCurrentCUDAStream() -> c10::cuda::getCurrentCUDAStream() 15 | ``` 16 | 17 | ## Install 18 | ``` 19 | bash install.sh 20 | 21 | ``` 22 | -------------------------------------------------------------------------------- /models/FAC/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WarranWeng/EBFI-BE/8595a11a84242e08c85c5bee6a8bfe956015d2e4/models/FAC/__init__.py -------------------------------------------------------------------------------- /models/FAC/install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ./kernelconv2d 3 | python setup.py clean 4 | python setup.py install --user 5 | -------------------------------------------------------------------------------- /models/FAC/kernelconv2d/KernelConv2D.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Developed by Shangchen Zhou 5 | import torch 6 | from torch import nn 7 | from torch.autograd import Function 8 | import kernelconv2d_cuda 9 | import random 10 | 11 | 12 | class KernelConv2DFunction(Function): 13 | # def __init__(self, kernel_size=3): 14 | # super(KernelConv2DFunction, self).__init__() 15 | # self.kernel_size = kernel_size 16 | @staticmethod 17 | def forward(ctx, input, kernel, kernel_size): 18 | ctx.kernel_size = kernel_size 19 | assert (input.is_contiguous() == True) 20 | assert (kernel.is_contiguous() == True) 21 | ctx.save_for_backward(input, kernel) 22 | assert (ctx.kernel_size == int((kernel.size(1) / input.size(1)) ** 0.5)) 23 | intKernelSize = ctx.kernel_size 24 | intBatches = input.size(0) 25 | intInputDepth = input.size(1) 26 | intInputHeight = input.size(2) 27 | intInputWidth = input.size(3) 28 | intOutputHeight = kernel.size(2) 29 | intOutputWidth = kernel.size(3) 30 | 31 | assert (intInputHeight - intKernelSize == intOutputHeight - 1) 32 | assert (intInputWidth - intKernelSize == intOutputWidth - 1) 33 | 34 | with torch.cuda.device_of(input): 35 | output = input.new().resize_(intBatches, intInputDepth, intOutputHeight, intOutputWidth).zero_() 36 | if input.is_cuda == True: 37 | kernelconv2d_cuda.forward(input, kernel, intKernelSize, output) 38 | elif input.is_cuda == False: 39 | raise NotImplementedError() # CPU VERSION NOT IMPLEMENTED 40 | print(5) 41 | 42 | return output 43 | 44 | @staticmethod 45 | def backward(ctx, grad_output): 46 | input, kernel = ctx.saved_tensors 47 | intKernelSize = ctx.kernel_size 48 | grad_output = grad_output.contiguous() 49 | with torch.cuda.device_of(input): 50 | grad_input = input.new().resize_(input.size()).zero_() 51 | grad_kernel = kernel.new().resize_(kernel.size()).zero_() 52 | if grad_output.is_cuda == True: 53 | kernelconv2d_cuda.backward(input, kernel, intKernelSize, grad_output, grad_input, grad_kernel) 54 | 55 | elif grad_output.is_cuda == False: 56 | raise NotImplementedError() # CPU VERSION NOT IMPLEMENTED 57 | 58 | return grad_input, grad_kernel, None 59 | 60 | 61 | def gradient_check(): 62 | kernel_size_list = [1, 3] 63 | len_list = [8, 10] 64 | for i in range(10): 65 | B = random.randint(1, 4) 66 | C = i + 1 67 | K = random.choice(kernel_size_list) 68 | H = random.choice(len_list) 69 | W = random.choice(len_list) 70 | input = torch.randn(B, C, H + K - 1, W + K - 1, requires_grad=True).cuda() 71 | kernel = torch.randn(B, C * K * K, H, W, requires_grad=True).cuda() 72 | # linear function, thus eps set to 1e-1 73 | 
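        # gradcheck compares the analytical gradients from backward() against
        # finite-difference estimates. Note that the call below uses the old
        # callable-Function style; with the static-method API defined above, the usual
        # pattern would be to hand gradcheck a small wrapper instead, e.g.
        # lambda inp, ker: KernelConv2DFunction.apply(inp, ker, K)
        # (a hypothetical wrapper, not part of this file).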
print(torch.autograd.gradcheck(KernelConv2DFunction(K), (input, kernel), eps=1e-1, atol=1e-5, rtol=1e-3, 74 | raise_exception=True)) 75 | 76 | 77 | class KernelConv2D(nn.Module): 78 | def __init__(self, kernel_size): 79 | super(KernelConv2D, self).__init__() 80 | assert (kernel_size % 2 == 1) 81 | self.kernel_size = kernel_size 82 | self.pad = torch.nn.ReplicationPad2d( 83 | [(kernel_size - 1) // 2, (kernel_size - 1) // 2, (kernel_size - 1) // 2, (kernel_size - 1) // 2]) 84 | 85 | def forward(self, input, kernel): 86 | input_pad = self.pad(input) 87 | return KernelConv2DFunction.apply(input_pad, kernel, self.kernel_size) 88 | -------------------------------------------------------------------------------- /models/FAC/kernelconv2d/KernelConv2D_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include //for CUDA 10.2+ 5 | #include 6 | 7 | #include "KernelConv2D_kernel.h" 8 | 9 | 10 | int KernelConv2D_forward_cuda( 11 | at::Tensor& input, 12 | at::Tensor& kernel, 13 | int kernel_size, 14 | at::Tensor& output 15 | ) { 16 | int success = KernelConv2D_forward_cuda_kernel( 17 | input, 18 | kernel, 19 | kernel_size, 20 | output, 21 | // at::cuda::getCurrentCUDAStream() 22 | //at::globalContext().getCurrentCUDAStream() //for torch 0.4.1 23 | c10::cuda::getCurrentCUDAStream() //for CUDA 10.2+ 24 | 25 | ); 26 | if (!success) { 27 | AT_ERROR("CUDA call failed"); 28 | } 29 | return 1; 30 | } 31 | 32 | int KernelConv2D_backward_cuda( 33 | at::Tensor& input, 34 | at::Tensor& kernel, 35 | int kernel_size, 36 | at::Tensor& grad_output, 37 | at::Tensor& grad_input, 38 | at::Tensor& grad_kernel 39 | ) { 40 | 41 | int success = KernelConv2D_backward_cuda_kernel( 42 | input, 43 | kernel, 44 | kernel_size, 45 | grad_output, 46 | grad_input, 47 | grad_kernel, 48 | // at::cuda::getCurrentCUDAStream() 49 | //at::globalContext().getCurrentCUDAStream() //for torch 0.4.1 50 | c10::cuda::getCurrentCUDAStream() //for CUDA 10.2+ 51 | ); 52 | if (!success) { 53 | AT_ERROR("CUDA call failed"); 54 | } 55 | return 1; 56 | } 57 | 58 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 59 | m.def("forward", &KernelConv2D_forward_cuda, "KernelConv2D forward (CUDA)"); 60 | m.def("backward", &KernelConv2D_backward_cuda, "KernelConv2D backward (CUDA)"); 61 | } 62 | 63 | 64 | -------------------------------------------------------------------------------- /models/FAC/kernelconv2d/KernelConv2D_cuda.h: -------------------------------------------------------------------------------- 1 | int KernelConv2D_forward_cuda( 2 | at::Tensor& input, 3 | at::Tensor& kernel, 4 | int kernel_size, 5 | at::Tensor& output 6 | ); 7 | 8 | int KernelConv2D_backward_cuda( 9 | at::Tensor& input, 10 | at::Tensor& kernel, 11 | int kernel_size, 12 | at::Tensor& grad_output, 13 | at::Tensor& grad_input, 14 | at::Tensor& grad_kernel 15 | ); 16 | -------------------------------------------------------------------------------- /models/FAC/kernelconv2d/KernelConv2D_kernel.h: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | int KernelConv2D_forward_cuda_kernel( 6 | at::Tensor& input, 7 | at::Tensor& kernel, 8 | int kernel_size, 9 | at::Tensor& output, 10 | cudaStream_t stream 11 | ); 12 | 13 | int KernelConv2D_backward_cuda_kernel( 14 | at::Tensor& input, 15 | at::Tensor& kernel, 16 | int kernel_size, 17 | at::Tensor& grad_output, 18 | at::Tensor& grad_input, 19 | at::Tensor& grad_kernel, 20 | 
cudaStream_t stream 21 | ); 22 | 23 | 24 | #ifdef __cplusplus 25 | } 26 | #endif 27 | -------------------------------------------------------------------------------- /models/FAC/kernelconv2d/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WarranWeng/EBFI-BE/8595a11a84242e08c85c5bee6a8bfe956015d2e4/models/FAC/kernelconv2d/__init__.py -------------------------------------------------------------------------------- /models/FAC/kernelconv2d/build/lib.linux-x86_64-3.8/kernelconv2d_cuda.cpython-38-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WarranWeng/EBFI-BE/8595a11a84242e08c85c5bee6a8bfe956015d2e4/models/FAC/kernelconv2d/build/lib.linux-x86_64-3.8/kernelconv2d_cuda.cpython-38-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /models/FAC/kernelconv2d/build/temp.linux-x86_64-3.8/.ninja_deps: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WarranWeng/EBFI-BE/8595a11a84242e08c85c5bee6a8bfe956015d2e4/models/FAC/kernelconv2d/build/temp.linux-x86_64-3.8/.ninja_deps -------------------------------------------------------------------------------- /models/FAC/kernelconv2d/build/temp.linux-x86_64-3.8/.ninja_log: -------------------------------------------------------------------------------- 1 | # ninja log v5 2 | 0 16265 1656638840935668097 /data2/wengwm/work/code/EVFI-BE/models/FAC/kernelconv2d/build/temp.linux-x86_64-3.8/KernelConv2D_cuda.o 641e039f201ef056 3 | 0 29746 1656638854407540912 /data2/wengwm/work/code/EVFI-BE/models/FAC/kernelconv2d/build/temp.linux-x86_64-3.8/KernelConv2D_kernel.o 4bbb724a9158c3db 4 | -------------------------------------------------------------------------------- /models/FAC/kernelconv2d/build/temp.linux-x86_64-3.8/KernelConv2D_cuda.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WarranWeng/EBFI-BE/8595a11a84242e08c85c5bee6a8bfe956015d2e4/models/FAC/kernelconv2d/build/temp.linux-x86_64-3.8/KernelConv2D_cuda.o -------------------------------------------------------------------------------- /models/FAC/kernelconv2d/build/temp.linux-x86_64-3.8/KernelConv2D_kernel.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WarranWeng/EBFI-BE/8595a11a84242e08c85c5bee6a8bfe956015d2e4/models/FAC/kernelconv2d/build/temp.linux-x86_64-3.8/KernelConv2D_kernel.o -------------------------------------------------------------------------------- /models/FAC/kernelconv2d/build/temp.linux-x86_64-3.8/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = c++ 3 | nvcc = /usr/local/cuda/bin/nvcc 4 | 5 | cflags = -pthread -B /opt/conda/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/opt/conda/lib/python3.8/site-packages/torch/include -I/opt/conda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/opt/conda/lib/python3.8/site-packages/torch/include/TH -I/opt/conda/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda/include -I/opt/conda/include/python3.8 -c 6 | post_cflags = -std=c++14 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' 
-DTORCH_EXTENSION_NAME=kernelconv2d_cuda -D_GLIBCXX_USE_CXX11_ABI=0 7 | cuda_cflags = -I/opt/conda/lib/python3.8/site-packages/torch/include -I/opt/conda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/opt/conda/lib/python3.8/site-packages/torch/include/TH -I/opt/conda/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda/include -I/opt/conda/include/python3.8 -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_86,code=sm_86 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=kernelconv2d_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++14 9 | ldflags = 10 | 11 | rule compile 12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 13 | depfile = $out.d 14 | deps = gcc 15 | 16 | rule cuda_compile 17 | depfile = $out.d 18 | deps = gcc 19 | command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags 20 | 21 | 22 | 23 | build /data2/wengwm/work/code/EVFI-BE/models/FAC/kernelconv2d/build/temp.linux-x86_64-3.8/KernelConv2D_cuda.o: compile /data2/wengwm/work/code/EVFI-BE/models/FAC/kernelconv2d/KernelConv2D_cuda.cpp 24 | build /data2/wengwm/work/code/EVFI-BE/models/FAC/kernelconv2d/build/temp.linux-x86_64-3.8/KernelConv2D_kernel.o: cuda_compile /data2/wengwm/work/code/EVFI-BE/models/FAC/kernelconv2d/KernelConv2D_kernel.cu 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /models/FAC/kernelconv2d/dist/kernelconv2d_cuda-1.0.0-py3.8-linux-x86_64.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WarranWeng/EBFI-BE/8595a11a84242e08c85c5bee6a8bfe956015d2e4/models/FAC/kernelconv2d/dist/kernelconv2d_cuda-1.0.0-py3.8-linux-x86_64.egg -------------------------------------------------------------------------------- /models/FAC/kernelconv2d/kernelconv2d_cuda.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: kernelconv2d-cuda 3 | Version: 1.0.0 4 | Summary: UNKNOWN 5 | Home-page: UNKNOWN 6 | Author: UNKNOWN 7 | Author-email: UNKNOWN 8 | License: UNKNOWN 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- /models/FAC/kernelconv2d/kernelconv2d_cuda.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | KernelConv2D_cuda.cpp 2 | KernelConv2D_kernel.cu 3 | setup.py 4 | kernelconv2d_cuda.egg-info/PKG-INFO 5 | kernelconv2d_cuda.egg-info/SOURCES.txt 6 | kernelconv2d_cuda.egg-info/dependency_links.txt 7 | kernelconv2d_cuda.egg-info/top_level.txt -------------------------------------------------------------------------------- /models/FAC/kernelconv2d/kernelconv2d_cuda.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- 
/models/FAC/kernelconv2d/kernelconv2d_cuda.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | kernelconv2d_cuda 2 | -------------------------------------------------------------------------------- /models/FAC/kernelconv2d/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from setuptools import setup, find_packages 5 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 6 | 7 | cxx_args = ['-std=c++14'] 8 | 9 | nvcc_args = [ 10 | '-gencode', 'arch=compute_50,code=sm_50', 11 | '-gencode', 'arch=compute_52,code=sm_52', 12 | '-gencode', 'arch=compute_60,code=sm_60', 13 | '-gencode', 'arch=compute_61,code=sm_61', 14 | '-gencode', 'arch=compute_70,code=sm_70', 15 | '-gencode', 'arch=compute_75,code=sm_75', 16 | '-gencode', 'arch=compute_86,code=sm_86' 17 | ] 18 | 19 | setup( 20 | version='1.0.0', 21 | name='kernelconv2d_cuda', 22 | ext_modules=[ 23 | CUDAExtension('kernelconv2d_cuda', [ 24 | 'KernelConv2D_cuda.cpp', 25 | 'KernelConv2D_kernel.cu' 26 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) 27 | ], 28 | cmdclass={ 29 | 'build_ext': BuildExtension 30 | }) 31 | -------------------------------------------------------------------------------- /models/model_misc/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from UZH-RPG https://github.com/uzh-rpg/rpg_e2vid 3 | """ 4 | 5 | from abc import abstractmethod 6 | 7 | import numpy as np 8 | import torch.nn as nn 9 | 10 | 11 | class BaseModel(nn.Module): 12 | """ 13 | Base class for all models 14 | """ 15 | 16 | @abstractmethod 17 | def forward(self, *inputs): 18 | """ 19 | Forward pass logic 20 | 21 | :return: Model output 22 | """ 23 | raise NotImplementedError 24 | 25 | def __str__(self): 26 | """ 27 | Model prints with number of trainable parameters 28 | """ 29 | # model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 30 | # params = sum([np.prod(p.size()) for p in model_parameters]) 31 | trained_params = sum(p.numel() for p in self.parameters() if p.requires_grad) 32 | all_params = sum(p.numel() for p in self.parameters()) 33 | return super().__str__() + "\nTrainable parameters: {} \nAll parameters: {}".format(trained_params, all_params) 34 | -------------------------------------------------------------------------------- /myutils/data_augmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torchvision.transforms 4 | from math import sin, cos, pi 5 | import numbers 6 | import numpy as np 7 | import random 8 | from typing import Union 9 | 10 | 11 | # class RandomCrop(object): 12 | # """Crop the tensor at a random location. 
13 | # """ 14 | 15 | # def __init__(self, size, preserve_mosaicing_pattern=False): 16 | # if isinstance(size, numbers.Number): 17 | # self.size = (int(size), int(size)) 18 | # else: 19 | # self.size = size 20 | 21 | # self.preserve_mosaicing_pattern = preserve_mosaicing_pattern 22 | 23 | # @staticmethod 24 | # def get_params(x, output_size): 25 | # w, h = x.shape[2], x.shape[1] 26 | # th, tw = output_size 27 | # assert th % 4 == 0 and tw % 4 == 0 28 | # if th > h or tw > w: 29 | # raise Exception("Input size {}x{} is less than desired cropped \ 30 | # size {}x{} - input tensor shape = {}".format(w,h,tw,th,x.shape)) 31 | # if w == tw and h == th: 32 | # return 0, 0, h, w 33 | 34 | # i = random.randint(0, h - th) 35 | # j = random.randint(0, w - tw) 36 | 37 | # i = int(i // 4) * 4 38 | # j = int(j // 4) * 4 39 | 40 | # return i, j, th, tw 41 | 42 | # def crop(self, x, i, j, h, w): 43 | # """ 44 | # x: [C x H x W] Tensor to be rotated. 45 | # Returns: 46 | # Tensor: Cropped tensor. 47 | # """ 48 | # if self.preserve_mosaicing_pattern: 49 | # # make sure that i and j are even, to preserve the mosaicing pattern 50 | # if i % 2 == 1: 51 | # i = i + 1 52 | # if j % 2 == 1: 53 | # j = j + 1 54 | 55 | # return x[:, i:i + h, j:j + w] 56 | 57 | # def __call__(self, pred_frame, ori_left_frame, down2_left_frame, down4_left_frame, 58 | # ori_right_frame, down2_right_frame, down4_right_frame, 59 | # hr_event_cnt, ori_lr_event_cnt, down2_lr_event_cnt, down4_lr_event_cnt): 60 | # assert pred_frame.size()[-2:] == hr_event_cnt.size()[-2:] 61 | 62 | # i, j, h, w = self.get_params(pred_frame, self.size) 63 | 64 | # pred_frame = self.crop(pred_frame, i, j, h, w) 65 | # ori_left_frame = self.crop(ori_left_frame, i, j, h, w) 66 | # down2_left_frame = self.crop(down2_left_frame, i//2, j//2, h//2, w//2) 67 | # down4_left_frame = self.crop(down4_left_frame, i//4, j//4, h//4, w//4) 68 | # ori_right_frame = self.crop(ori_right_frame, i, j, h, w) 69 | # down2_right_frame = self.crop(down2_right_frame, i//2, j//2, h//2, w//2) 70 | # down4_right_frame = self.crop(down4_right_frame, i//4, j//4, h//4, w//4) 71 | 72 | # hr_event_cnt = self.crop(hr_event_cnt, i, j, h, w) 73 | # ori_lr_event_cnt = self.crop(ori_lr_event_cnt, i, j, h, w) 74 | # down2_lr_event_cnt = self.crop(down2_lr_event_cnt, i//2, j//2, h//2, w//2) 75 | # down4_lr_event_cnt = self.crop(down4_lr_event_cnt, i//4, j//4, h//4, w//4) 76 | 77 | # return pred_frame, ori_left_frame, down2_left_frame, down4_left_frame, \ 78 | # ori_right_frame, down2_right_frame, down4_right_frame, \ 79 | # hr_event_cnt, ori_lr_event_cnt, down2_lr_event_cnt, down4_lr_event_cnt 80 | 81 | # def __repr__(self): 82 | # return self.__class__.__name__ + '(size={0})'.format(self.size) 83 | 84 | 85 | def RandomCrop(size, scale, 86 | pred_frame, 87 | ori_left_frame, down2_left_frame, down4_left_frame, 88 | ori_right_frame, down2_right_frame, down4_right_frame, 89 | hr_event_cnt, lr_event_cnt, lr_scaled_event_cnt, 90 | ori_left_lr_event_cnt, down2_left_lr_event_cnt, down4_left_lr_event_cnt, 91 | ori_right_lr_event_cnt, down2_right_lr_event_cnt, down4_right_lr_event_cnt 92 | ): 93 | def get_params(x, output_size): 94 | w, h = x.shape[2], x.shape[1] 95 | th, tw = output_size 96 | assert th % 4 == 0 and tw % 4 == 0 97 | if th > h or tw > w: 98 | raise Exception("Input size {}x{} is less than desired cropped \ 99 | size {}x{} - input tensor shape = {}".format(w,h,tw,th,x.shape)) 100 | if w == tw and h == th: 101 | return 0, 0, h, w 102 | 103 | i = random.randint(0, h - th) 104 | j = 
random.randint(0, w - tw) 105 | 106 | i = int(i // 4) * 4 107 | j = int(j // 4) * 4 108 | 109 | return i, j, th, tw 110 | 111 | def crop(x, i, j, h, w): 112 | """ 113 | x: [C x H x W] Tensor to be rotated. 114 | Returns: 115 | Tensor: Cropped tensor. 116 | """ 117 | 118 | return x[:, i:i + h, j:j + w] 119 | 120 | assert pred_frame.size()[-2:] == hr_event_cnt.size()[-2:] 121 | 122 | i, j, h, w = get_params(pred_frame, size) 123 | 124 | pred_frame = crop(pred_frame, i, j, h, w) 125 | ori_left_frame = crop(ori_left_frame, i, j, h, w) 126 | down2_left_frame = crop(down2_left_frame, i//2, j//2, h//2, w//2) 127 | down4_left_frame = crop(down4_left_frame, i//4, j//4, h//4, w//4) 128 | ori_right_frame = crop(ori_right_frame, i, j, h, w) 129 | down2_right_frame = crop(down2_right_frame, i//2, j//2, h//2, w//2) 130 | down4_right_frame = crop(down4_right_frame, i//4, j//4, h//4, w//4) 131 | 132 | hr_event_cnt = crop(hr_event_cnt, i, j, h, w) 133 | lr_event_cnt = crop(lr_event_cnt, i//scale, j//scale, h//scale, w//scale) 134 | lr_scaled_event_cnt = crop(lr_scaled_event_cnt, i, j, h, w) 135 | ori_left_lr_event_cnt = crop(ori_left_lr_event_cnt, i, j, h, w) 136 | down2_left_lr_event_cnt = crop(down2_left_lr_event_cnt, i//2, j//2, h//2, w//2) 137 | down4_left_lr_event_cnt = crop(down4_left_lr_event_cnt, i//4, j//4, h//4, w//4) 138 | ori_right_lr_event_cnt = crop(ori_right_lr_event_cnt, i, j, h, w) 139 | down2_right_lr_event_cnt = crop(down2_right_lr_event_cnt, i//2, j//2, h//2, w//2) 140 | down4_right_lr_event_cnt = crop(down4_right_lr_event_cnt, i//4, j//4, h//4, w//4) 141 | 142 | return pred_frame, \ 143 | ori_left_frame, down2_left_frame, down4_left_frame, \ 144 | ori_right_frame, down2_right_frame, down4_right_frame, \ 145 | hr_event_cnt, lr_event_cnt, lr_scaled_event_cnt, \ 146 | ori_left_lr_event_cnt, down2_left_lr_event_cnt, down4_left_lr_event_cnt, \ 147 | ori_right_lr_event_cnt, down2_right_lr_event_cnt, down4_right_lr_event_cnt 148 | -------------------------------------------------------------------------------- /myutils/event_visual_example.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import torch 5 | import h5py 6 | import argparse 7 | 8 | 9 | class Visualization: 10 | def __init__(self, px=400, color_scheme='green_red', eval_id=-1): 11 | self.img_idx = 0 12 | self.px = px 13 | self.color_scheme = color_scheme # gray / blue_red / green_red 14 | 15 | def plot_event(self, event_cnt, is_save, name='events_img'): 16 | 17 | event_img = (self.events_to_image(event_cnt, self.color_scheme)*255).astype(np.uint8) 18 | 19 | cv2.namedWindow(f"{name}", cv2.WINDOW_NORMAL) 20 | cv2.resizeWindow(f"{name}", int(self.px), int(self.px)) 21 | cv2.imshow(f"{name}", event_img) 22 | 23 | if is_save: 24 | filename = '/tmp/event_img.png' 25 | cv2.imwrite(filename, event_img) 26 | 27 | def plot_frame(self, frame, name='frame'): 28 | 29 | cv2.namedWindow(f"{name}", cv2.WINDOW_NORMAL) 30 | cv2.resizeWindow(f"{name}", int(self.px), int(self.px)) 31 | cv2.imshow(f"{name}", frame) 32 | 33 | @staticmethod 34 | def events_to_image(inp_events, color_scheme="green_red"): 35 | """ 36 | Visualize the input events. 
37 | :param inp_events: [H x W x 2] per-pixel and per-polarity event count, numpy.narray 38 | :param color_scheme: green_red/gray/blue_red 39 | :return event_image: [H x W x 3] color-coded event image, range: [0, 1] 40 | """ 41 | assert color_scheme in ['green_red', 'gray', 'blue_red'], f'Not support {color_scheme}' 42 | 43 | pos = inp_events[:, :, 0] 44 | neg = inp_events[:, :, 1] 45 | pos_max = np.percentile(pos, 99) 46 | pos_min = np.percentile(pos, 1) 47 | neg_max = np.percentile(neg, 99) 48 | neg_min = np.percentile(neg, 1) 49 | max = pos_max if pos_max > neg_max else neg_max 50 | 51 | if pos_min != max: 52 | pos = (pos - pos_min) / (max - pos_min) 53 | if neg_min != max: 54 | neg = (neg - neg_min) / (max - neg_min) 55 | 56 | pos = np.clip(pos, 0, 1) 57 | neg = np.clip(neg, 0, 1) 58 | 59 | event_image = np.ones((inp_events.shape[0], inp_events.shape[1])) 60 | if color_scheme == "gray": 61 | event_image *= 0.5 62 | pos *= 0.5 63 | neg *= -0.5 64 | event_image += pos + neg 65 | 66 | elif color_scheme == "green_red": 67 | event_image = np.repeat(event_image[:, :, np.newaxis], 3, axis=2) 68 | event_image *= 0 69 | mask_pos = pos > 0 70 | mask_neg = neg > 0 71 | mask_not_pos = pos == 0 72 | mask_not_neg = neg == 0 73 | 74 | event_image[:, :, 0][mask_pos] = 0 75 | event_image[:, :, 1][mask_pos] = pos[mask_pos] 76 | event_image[:, :, 2][mask_pos * mask_not_neg] = 0 77 | event_image[:, :, 2][mask_neg] = neg[mask_neg] 78 | event_image[:, :, 0][mask_neg] = 0 79 | event_image[:, :, 1][mask_neg * mask_not_pos] = 0 80 | 81 | elif color_scheme == "blue_red": 82 | event_image = np.repeat(event_image[:, :, np.newaxis], 3, axis=2) 83 | event_image *= 0 84 | mask_pos = pos > 0 85 | mask_neg = neg > 0 86 | mask_not_pos = pos == 0 87 | mask_not_neg = neg == 0 88 | 89 | event_image[:, :, 1][mask_pos] = 0 90 | event_image[:, :, 0][mask_pos] = pos[mask_pos] 91 | event_image[:, :, 2][mask_pos * mask_not_neg] = 0 92 | event_image[:, :, 2][mask_neg] = neg[mask_neg] 93 | event_image[:, :, 1][mask_neg] = 0 94 | event_image[:, :, 0][mask_neg * mask_not_pos] = 0 95 | 96 | return event_image 97 | 98 | def events_to_channels(self, xs, ys, ps, sensor_size=(180, 240)): 99 | """ 100 | Generate a two-channel event image containing event counters. 101 | """ 102 | 103 | assert len(xs) == len(ys) and len(ys) == len(ps) 104 | 105 | xs = torch.from_numpy(xs) 106 | ys = torch.from_numpy(ys) 107 | ps = torch.from_numpy(ps) 108 | 109 | mask_pos = ps.clone() 110 | mask_neg = ps.clone() 111 | mask_pos[ps < 0] = 0 112 | mask_neg[ps > 0] = 0 113 | 114 | pos_cnt = self.events2image(xs, ys, ps * mask_pos, sensor_size=sensor_size) 115 | neg_cnt = self.events2image(xs, ys, ps * mask_neg, sensor_size=sensor_size) 116 | 117 | return torch.stack([pos_cnt, neg_cnt], dim=-1).numpy() 118 | 119 | @staticmethod 120 | def events2image(xs, ys, ps, sensor_size=(180, 240)): 121 | """ 122 | Accumulate events into an image. 
123 | """ 124 | 125 | device = xs.device 126 | img_size = list(sensor_size) 127 | img = torch.zeros(img_size).to(device) 128 | 129 | if xs.dtype is not torch.long: 130 | xs = xs.long().to(device) 131 | if ys.dtype is not torch.long: 132 | ys = ys.long().to(device) 133 | img.index_put_((ys, xs), ps, accumulate=True) 134 | 135 | return img 136 | 137 | 138 | def get_flags(): 139 | parser = argparse.ArgumentParser() 140 | parser.add_argument('--h5_file_path', required=True) 141 | parser.add_argument('--idx', type=int, default=3) 142 | flags = parser.parse_args() 143 | 144 | return flags 145 | 146 | 147 | if __name__ == '__main__': 148 | """ 149 | usage for event visualization: 150 | python utils/event_visual_example.py --h5_file_path path/to/event.h5 151 | """ 152 | 153 | flags = get_flags() 154 | 155 | h5_file_path = flags.h5_file_path 156 | idx = flags.idx 157 | 158 | assert os.path.isfile(h5_file_path) 159 | assert idx < 7 and idx > 0 160 | 161 | vis = Visualization() 162 | 163 | event_h5 = h5py.File(h5_file_path, 'r') 164 | 165 | sensor_resolution = event_h5.attrs['sensor_resolution'] 166 | 167 | frame0_h5 = event_h5['images']['image{:09d}'.format(idx-1)] 168 | frame1_h5 = event_h5['images']['image{:09d}'.format(idx)] 169 | frame0 = frame0_h5[:] 170 | frame1 = frame1_h5[:] 171 | 172 | events_idx = [frame0_h5.attrs['event_idx'], frame1_h5.attrs['event_idx']] 173 | 174 | xs = event_h5['events/xs'][events_idx[0]:events_idx[1]].astype(np.float32) 175 | ys = event_h5['events/ys'][events_idx[0]:events_idx[1]].astype(np.float32) 176 | ts = event_h5['events/ts'][events_idx[0]:events_idx[1]].astype(np.float32) 177 | ps = event_h5['events/ps'][events_idx[0]:events_idx[1]].astype(np.float32) 178 | 179 | event_cnt = vis.events_to_channels(xs, ys, ps, sensor_resolution) 180 | 181 | vis.plot_event(event_cnt, is_save=True) # set is_save to True to save event img 182 | vis.plot_frame(frame0) 183 | 184 | cv2.waitKey() 185 | cv2.destroyAllWindows() 186 | 187 | -------------------------------------------------------------------------------- /myutils/gradients.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Sobel(nn.Module): 8 | """ 9 | Computes the spatial gradients of 3D data using Sobel filters. 
10 | """ 11 | 12 | def __init__(self, device): 13 | super().__init__() 14 | self.pad = nn.ReplicationPad2d(1) 15 | a = np.zeros((1, 1, 3, 3)) 16 | b = np.zeros((1, 1, 3, 3)) 17 | a[0, :, :, :] = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]) 18 | b[0, :, :, :] = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]) 19 | self.a = torch.from_numpy(a).float().to(device) 20 | self.b = torch.from_numpy(b).float().to(device) 21 | 22 | def forward(self, x): 23 | """ 24 | :param x: [batch_size x 1 x H x W] input tensor 25 | :return gradx: [batch_size x 2 x H x W-1] spatial gradient in the x direction 26 | :return grady: [batch_size x 2 x H-1 x W] spatial gradient in the y direction 27 | """ 28 | 29 | x = x.view(-1, 1, x.shape[2], x.shape[3]) # (batch * channels, 1, height, width) 30 | x = self.pad(x) 31 | gradx = F.conv2d(x, self.a, groups=1) / 8 # normalized gradients 32 | grady = F.conv2d(x, self.b, groups=1) / 8 # normalized gradients 33 | return gradx, grady 34 | -------------------------------------------------------------------------------- /myutils/iwe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def purge_unfeasible(x, res): 5 | """ 6 | Purge unfeasible event locations by setting their interpolation weights to zero. 7 | :param x: location of motion compensated events 8 | :param res: resolution of the image space 9 | :return masked indices 10 | :return mask for interpolation weights 11 | """ 12 | 13 | mask = torch.ones((x.shape[0], x.shape[1], 1)).to(x.device) 14 | mask_y = (x[:, :, 0:1] < 0) + (x[:, :, 0:1] >= res[0]) 15 | mask_x = (x[:, :, 1:2] < 0) + (x[:, :, 1:2] >= res[1]) 16 | mask[mask_y + mask_x] = 0 17 | return x * mask, mask 18 | 19 | 20 | def get_interpolation(events, flow, tref, res, flow_scaling, round_idx=False): 21 | """ 22 | Warp the input events according to the provided optical flow map and compute the bilinar interpolation 23 | (or rounding) weights to distribute the events to the closes (integer) locations in the image space. 24 | :param events: [batch_size x N x 4] input events (y, x, ts, p) 25 | :param flow: [batch_size x 2 x H x W] optical flow map 26 | :param tref: reference time toward which events are warped 27 | :param res: resolution of the image space 28 | :param flow_scaling: scalar that multiplies the optical flow map 29 | :param round_idx: whether or not to round the event locations instead of doing bilinear interp. 
(default = False) 30 | :return interpolated event indices 31 | :return interpolation weights 32 | """ 33 | 34 | # event propagation 35 | warped_events = events[:, :, 1:3] + (tref - events[:, :, 0:1]) * flow * flow_scaling 36 | 37 | if round_idx: 38 | 39 | # no bilinear interpolation 40 | idx = torch.round(warped_events) 41 | weights = torch.ones(idx.shape).to(events.device) 42 | 43 | else: 44 | 45 | # get scattering indices 46 | top_y = torch.floor(warped_events[:, :, 0:1]) 47 | bot_y = torch.floor(warped_events[:, :, 0:1] + 1) 48 | left_x = torch.floor(warped_events[:, :, 1:2]) 49 | right_x = torch.floor(warped_events[:, :, 1:2] + 1) 50 | 51 | top_left = torch.cat([top_y, left_x], dim=2) 52 | top_right = torch.cat([top_y, right_x], dim=2) 53 | bottom_left = torch.cat([bot_y, left_x], dim=2) 54 | bottom_right = torch.cat([bot_y, right_x], dim=2) 55 | idx = torch.cat([top_left, top_right, bottom_left, bottom_right], dim=1) 56 | 57 | # get scattering interpolation weights 58 | warped_events = torch.cat([warped_events for i in range(4)], dim=1) 59 | zeros = torch.zeros(warped_events.shape).to(events.device) 60 | weights = torch.max(zeros, 1 - torch.abs(warped_events - idx)) 61 | 62 | # purge unfeasible indices 63 | idx, mask = purge_unfeasible(idx, res) 64 | 65 | # make unfeasible weights zero 66 | weights = torch.prod(weights, dim=-1, keepdim=True) * mask # bilinear interpolation 67 | 68 | # prepare indices 69 | idx[:, :, 0] *= res[1] # torch.view is row-major 70 | idx = torch.sum(idx, dim=2, keepdim=True) 71 | 72 | return idx, weights 73 | 74 | 75 | def interpolate(idx, weights, res, polarity_mask=None): 76 | """ 77 | Create an image-like representation of the warped events. 78 | :param idx: [batch_size x N x 1] warped event locations 79 | :param weights: [batch_size x N x 1] interpolation weights for the warped events 80 | :param res: resolution of the image space 81 | :param polarity_mask: [batch_size x N x 2] polarity mask for the warped events (default = None) 82 | :return image of warped events 83 | """ 84 | 85 | if polarity_mask is not None: 86 | weights = weights * polarity_mask 87 | iwe = torch.zeros((idx.shape[0], res[0] * res[1], 1)).to(idx.device) 88 | iwe = iwe.scatter_add_(1, idx.long(), weights) 89 | iwe = iwe.view((idx.shape[0], 1, res[0], res[1])) 90 | return iwe 91 | 92 | 93 | def deblur_events(flow, event_list, res, flow_scaling=128, round_idx=True, polarity_mask=None): 94 | """ 95 | Deblur the input events given an optical flow map. 96 | Event timestamp needs to be normalized between 0 and 1. 97 | :param flow: [batch_size x 2 x H x W] optical flow map 98 | :param events: [batch_size x N x 4] input events (y, x, ts, p) 99 | :param res: resolution of the image space 100 | :param flow_scaling: scalar that multiplies the optical flow map 101 | :param round_idx: whether or not to round the event locations instead of doing bilinear interp. 
(default = True) 102 | :param polarity_mask: [batch_size x N x 2] polarity mask for the warped events (default = None) 103 | :return iwe: [batch_size x 1 x H x W] image of warped events 104 | """ 105 | 106 | # flow vector per input event 107 | flow_idx = event_list[:, :, 1:3].clone() 108 | flow_idx[:, :, 0] *= res[1] # torch.view is row-major 109 | flow_idx = torch.sum(flow_idx, dim=2) 110 | 111 | # get flow for every event in the list 112 | flow = flow.view(flow.shape[0], 2, -1) 113 | event_flowy = torch.gather(flow[:, 1, :], 1, flow_idx.long()) # vertical component 114 | event_flowx = torch.gather(flow[:, 0, :], 1, flow_idx.long()) # horizontal component 115 | event_flowy = event_flowy.view(event_flowy.shape[0], event_flowy.shape[1], 1) 116 | event_flowx = event_flowx.view(event_flowx.shape[0], event_flowx.shape[1], 1) 117 | event_flow = torch.cat([event_flowy, event_flowx], dim=2) 118 | 119 | # interpolate forward 120 | fw_idx, fw_weights = get_interpolation(event_list, event_flow, 1, res, flow_scaling, round_idx=round_idx) 121 | if not round_idx and polarity_mask is not None: 122 | polarity_mask = torch.cat([polarity_mask for i in range(4)], dim=1) 123 | 124 | # image of (forward) warped events 125 | iwe = interpolate(fw_idx.long(), fw_weights, res, polarity_mask=polarity_mask) 126 | 127 | return iwe 128 | 129 | 130 | def compute_pol_iwe(flow, event_list, res, pos_mask, neg_mask, flow_scaling=128, round_idx=True): 131 | """ 132 | Create a per-polarity image of warped events given an optical flow map. 133 | :param flow: [batch_size x 2 x H x W] optical flow map 134 | :param event_list: [batch_size x N x 4] input events (ts, y, x, p) 135 | :param res: resolution of the image space 136 | :param pos_mask: [batch_size x N x 1] polarity mask for positive events 137 | :param neg_mask: [batch_size x N x 1] polarity mask for negative events 138 | :param flow_scaling: scalar that multiplies the optical flow map 139 | :param round_idx: whether or not to round the event locations instead of doing bilinear interp.
(default = True) 140 | :return iwe: [batch_size x 2 x H x W] image of warped events 141 | """ 142 | 143 | iwe_pos = deblur_events( 144 | flow, event_list, res, flow_scaling=flow_scaling, round_idx=round_idx, polarity_mask=pos_mask 145 | ) 146 | iwe_neg = deblur_events( 147 | flow, event_list, res, flow_scaling=flow_scaling, round_idx=round_idx, polarity_mask=neg_mask 148 | ) 149 | iwe = torch.cat([iwe_pos, iwe_neg], dim=1) 150 | 151 | return iwe 152 | -------------------------------------------------------------------------------- /myutils/timers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import time 4 | import numpy as np 5 | import atexit 6 | from collections import defaultdict 7 | 8 | 9 | cuda_timers = defaultdict(list) 10 | timers = defaultdict(list) 11 | 12 | 13 | # class CudaTimer: 14 | # def __init__(self, timer_name=''): 15 | # self.timer_name = timer_name 16 | 17 | # self.start = torch.cuda.Event(enable_timing=True) 18 | # self.end = torch.cuda.Event(enable_timing=True) 19 | 20 | # def __enter__(self): 21 | # self.start.record() 22 | # return self 23 | 24 | # def __exit__(self, *args): 25 | # self.end.record() 26 | # torch.cuda.synchronize() 27 | # cuda_timers[self.timer_name].append(self.start.elapsed_time(self.end)) 28 | 29 | class CudaTimer: 30 | def __init__(self, timer_name=''): 31 | self.timer_name = timer_name 32 | 33 | def __enter__(self): 34 | self.start = time.time() 35 | return self 36 | 37 | def __exit__(self, *args): 38 | self.end = time.time() 39 | self.interval = self.end - self.start # measured in seconds 40 | self.interval *= 1000.0 # convert to milliseconds 41 | timers[self.timer_name].append(self.interval) 42 | 43 | class Timer: 44 | def __init__(self, timer_name='', logger=None): 45 | self.timer_name = timer_name 46 | self.logger = logger 47 | 48 | def __enter__(self): 49 | self.start = time.time() 50 | return self 51 | 52 | def __exit__(self, *args): 53 | self.end = time.time() 54 | self.interval = self.end - self.start # measured in seconds 55 | self.interval *= 1000.0 # convert to milliseconds 56 | timers[self.timer_name].append(self.interval) 57 | 58 | if self.timer_name == 'Time of training one epoch' and self.logger != None: 59 | if dist.get_rank() == 0: 60 | if self.interval < 1000.0: 61 | self.logger.info('{}: {:.2f} ms'.format(self.timer_name, self.interval)) 62 | else: 63 | self.logger.info('{}: {:.2f} s'.format(self.timer_name, self.interval / 1000.0)) 64 | 65 | 66 | def print_timing_info(): 67 | print('== Timing statistics ==') 68 | for timer_name, timing_values in [*cuda_timers.items(), *timers.items()]: 69 | timing_value = np.mean(np.array(timing_values)) 70 | if timing_value < 1000.0: 71 | print('{}: {:.2f} ms ({} samples)'.format(timer_name, timing_value, len(timing_values))) 72 | else: 73 | print('{}: {:.2f} s ({} samples)'.format(timer_name, timing_value / 1000.0, len(timing_values))) 74 | 75 | 76 | # this will print all the timer values upon termination of any program that imported this file 77 | atexit.register(print_timing_info) 78 | -------------------------------------------------------------------------------- /myutils/vis_events/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WarranWeng/EBFI-BE/8595a11a84242e08c85c5bee6a8bfe956015d2e4/myutils/vis_events/tools/__init__.py -------------------------------------------------------------------------------- 
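A minimal usage sketch for myutils/iwe.py and myutils/timers.py above; the random events, zero flow, resolution and flow_scaling=128 are illustrative assumptions only, and the imports assume the repository root is the working directory:

import torch

from myutils.iwe import compute_pol_iwe
from myutils.timers import Timer

B, N, H, W = 1, 4096, 180, 240
ts = torch.rand(B, N, 1)                              # normalized timestamps in [0, 1]
ys = torch.randint(0, H, (B, N, 1)).float()           # pixel rows
xs = torch.randint(0, W, (B, N, 1)).float()           # pixel columns
ps = torch.randint(0, 2, (B, N, 1)).float() * 2 - 1   # polarities in {-1, +1}
event_list = torch.cat([ts, ys, xs, ps], dim=2)       # (ts, y, x, p), the layout the code above consumes
pos_mask = (ps > 0).float()
neg_mask = (ps < 0).float()
flow = torch.zeros(B, 2, H, W)                        # zero flow: events are accumulated without warping

with Timer('compute_pol_iwe'):                        # wall-clock timing, reported by print_timing_info at exit
    iwe = compute_pol_iwe(flow, event_list, (H, W), pos_mask, neg_mask, flow_scaling=128, round_idx=True)
print(iwe.shape)                                      # torch.Size([1, 2, 180, 240])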
/myutils/vis_events/tools/add_hdf5_attribute.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import h5py 4 | import os 5 | import glob 6 | 7 | def endswith(path, extensions): 8 | for ext in extensions: 9 | if path.endswith(ext): 10 | return True 11 | return False 12 | 13 | def get_filepaths_from_path_or_file(path, extensions=[], datafile_extensions=[".txt", ".csv"]): 14 | files = [] 15 | path = path.rstrip("/") 16 | if os.path.isdir(path): 17 | for ext in extensions: 18 | files += sorted(glob.glob("{}/*{}".format(path, ext))) 19 | else: 20 | if endswith(path, extensions): 21 | files.append(path) 22 | elif endswith(path, datafile_extensions): 23 | with open(path, 'r') as f: 24 | #files.append(line) for line in f.readlines 25 | files = [line.strip() for line in f.readlines()] 26 | return files 27 | 28 | def add_attribute(h5_filepaths, group, attribute_name, attribute_value, dry_run=False): 29 | for h5_filepath in h5_filepaths: 30 | print("adding {}/{}[{}]={}".format(h5_filepath, group, attribute_name, attribute_value)) 31 | if dry_run: 32 | continue 33 | h5_file = h5py.File(h5_filepath, 'a') 34 | dset = h5_file["{}/".format(group)] 35 | dset.attrs[attribute_name] = attribute_value 36 | h5_file.close() 37 | 38 | if __name__ == "__main__": 39 | # arguments 40 | parser = argparse.ArgumentParser() 41 | parser._action_groups.pop() 42 | required = parser.add_argument_group('required arguments') 43 | optional = parser.add_argument_group('optional arguments') 44 | 45 | required.add_argument("--path", help="Can be either 1: path to individual hdf file, " + 46 | "2: txt file with list of hdf files, or " + 47 | "3: directory (all hdf files in directory will be processed).", required=True) 48 | required.add_argument("--attr_name", help="Name of new attribute", required=True) 49 | required.add_argument("--attr_val", help="Value of new attribute", required=True) 50 | optional.add_argument("--group", help="Group to add attribute to. 
Subgroups " + 51 | "are represented like paths, eg: /group1/subgroup2...", default="") 52 | optional.add_argument("--dry_run", default=0, type=int, 53 | help="If set to 1, will print changes without performing them") 54 | 55 | args = parser.parse_args() 56 | path = args.path 57 | extensions = [".hdf", ".h5"] 58 | files = get_filepaths_from_path_or_file(path, extensions=extensions) 59 | print(files) 60 | dry_run = False if args.dry_run <= 0 else True 61 | add_attribute(files, args.group, args.attr_name, args.attr_val, dry_run=dry_run) 62 | -------------------------------------------------------------------------------- /myutils/vis_events/tools/event_packagers.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | import h5py 3 | import cv2 as cv 4 | import numpy as np 5 | 6 | class packager(): 7 | 8 | __metaclass__ = ABCMeta 9 | 10 | def __init__(self, name, output_path, max_buffer_size=1000000): 11 | self.name = name 12 | self.output_path = output_path 13 | self.max_buffer_size = max_buffer_size 14 | 15 | @abstractmethod 16 | def package_events(self, xs, ys, ts, ps): 17 | pass 18 | 19 | @abstractmethod 20 | def package_image(self, frame, timestamp): 21 | pass 22 | 23 | @abstractmethod 24 | def package_flow(self, flow, timestamp): 25 | pass 26 | 27 | @abstractmethod 28 | def add_metadata(self, num_events, num_pos, num_neg, 29 | duration, t0, tk, num_imgs, num_flow): 30 | pass 31 | 32 | @abstractmethod 33 | def set_data_available(self, num_images, num_flow): 34 | pass 35 | 36 | class hdf5_packager(packager): 37 | """ 38 | This class packages data to hdf5 files 39 | """ 40 | def __init__(self, output_path, max_buffer_size=1000000): 41 | packager.__init__(self, 'hdf5', output_path, max_buffer_size) 42 | print("CREATING FILE IN {}".format(output_path)) 43 | self.events_file = h5py.File(output_path, 'w') 44 | self.event_xs = self.events_file.create_dataset("events/xs", (0, ), dtype=np.dtype(np.int16), maxshape=(None, ), chunks=True) 45 | self.event_ys = self.events_file.create_dataset("events/ys", (0, ), dtype=np.dtype(np.int16), maxshape=(None, ), chunks=True) 46 | self.event_ts = self.events_file.create_dataset("events/ts", (0, ), dtype=np.dtype(np.float64), maxshape=(None, ), chunks=True) 47 | self.event_ps = self.events_file.create_dataset("events/ps", (0, ), dtype=np.dtype(np.bool_), maxshape=(None, ), chunks=True) 48 | 49 | def append_to_dataset(self, dataset, data): 50 | dataset.resize(dataset.shape[0] + len(data), axis=0) 51 | if len(data) == 0: 52 | return 53 | dataset[-len(data):] = data[:] 54 | 55 | def package_events(self, xs, ys, ts, ps): 56 | self.append_to_dataset(self.event_xs, xs) 57 | self.append_to_dataset(self.event_ys, ys) 58 | self.append_to_dataset(self.event_ts, ts) 59 | self.append_to_dataset(self.event_ps, ps) 60 | 61 | def package_image(self, image, timestamp, img_idx): 62 | image_dset = self.events_file.create_dataset("images/image{:09d}".format(img_idx), 63 | data=image, dtype=np.dtype(np.uint8)) 64 | image_dset.attrs['size'] = image.shape 65 | image_dset.attrs['timestamp'] = timestamp 66 | image_dset.attrs['type'] = "greyscale" if image.shape[-1] == 1 or len(image.shape) == 2 else "color_bgr" 67 | 68 | def package_flow(self, flow_image, timestamp, flow_idx): 69 | flow_dset = self.events_file.create_dataset("flow/flow{:09d}".format(flow_idx), 70 | data=flow_image, dtype=np.dtype(np.float32)) 71 | flow_dset.attrs['size'] = flow_image.shape 72 | flow_dset.attrs['timestamp'] = timestamp 73 | 
74 | def add_event_indices(self): 75 | datatypes = ['images', 'flow'] 76 | for datatype in datatypes: 77 | if datatype in self.events_file.keys(): 78 | s = 0 79 | added = 0 80 | ts = self.events_file["events/ts"][s:s+self.max_buffer_size] 81 | for image in self.events_file[datatype]: 82 | img_ts = self.events_file[datatype][image].attrs['timestamp'] 83 | event_idx = np.searchsorted(ts, img_ts) 84 | if event_idx == len(ts): 85 | added += len(ts) 86 | s += self.max_buffer_size 87 | ts = self.events_file["events/ts"][s:s+self.max_buffer_size] 88 | event_idx = np.searchsorted(ts, img_ts) 89 | event_idx = max(0, event_idx-1) 90 | self.events_file[datatype][image].attrs['event_idx'] = event_idx + added 91 | 92 | def add_metadata(self, num_pos, num_neg, 93 | duration, t0, tk, num_imgs, num_flow, sensor_size): 94 | self.events_file.attrs['num_events'] = num_pos+num_neg 95 | self.events_file.attrs['num_pos'] = num_pos 96 | self.events_file.attrs['num_neg'] = num_neg 97 | self.events_file.attrs['duration'] = tk-t0 98 | self.events_file.attrs['t0'] = t0 99 | self.events_file.attrs['tk'] = tk 100 | self.events_file.attrs['num_imgs'] = num_imgs 101 | self.events_file.attrs['num_flow'] = num_flow 102 | self.events_file.attrs['sensor_resolution'] = sensor_size 103 | self.add_event_indices() 104 | 105 | def set_data_available(self, num_images, num_flow): 106 | if num_images > 0: 107 | self.image_dset = self.events_file.create_group("images") 108 | self.image_dset.attrs['num_images'] = num_images 109 | if num_flow > 0: 110 | self.flow_dset = self.events_file.create_group("flow") 111 | self.flow_dset.attrs['num_images'] = num_flow 112 | 113 | -------------------------------------------------------------------------------- /myutils/vis_events/tools/h5_to_memmap.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import h5py 3 | import numpy as np 4 | import os, shutil 5 | import json 6 | 7 | def find_safe_alternative(output_base_path): 8 | i = 0 9 | alternative_path = "{}_{:09d}".format(output_base_path, i) 10 | while(os.path.exists(alternative_path)): 11 | i += 1 12 | alternative_path = "{}_{:09d}".format(output_base_path, i) 13 | assert(i < 999999999) 14 | return alternative_path 15 | 16 | def save_additional_data_as_mmap(f, mmap_pth, data): 17 | data_path = os.path.join(mmap_pth, data['mmap_filename']) 18 | data_ts_path = os.path.join(mmap_pth, data['mmap_ts_filename']) 19 | data_event_idx_path = os.path.join(mmap_pth, data['mmap_event_idx_filename']) 20 | data_key = data['h5_key'] 21 | print('Writing {} to mmap {}, timestamps to {}'.format(data_key, data_path, data_ts_path)) 22 | h, w, c = 1, 1, 1 23 | if data_key in f.keys(): 24 | num_data = len(f[data_key].keys()) 25 | if num_data > 0: 26 | data_keys = f[data_key].keys() 27 | data_size = f[data_key][data_keys[0]].attrs['size'] 28 | h, w = data_size[0], data_size[1] 29 | c = 1 if len(data_size) <= 2 else data_size[2] 30 | else: 31 | num_data = 1 32 | mmp_imgs = np.memmap(data_path, dtype='uint8', mode='w+', shape=(num_data, h, w, c)) 33 | mmp_img_ts = np.memmap(data_ts_path, dtype='float64', mode='w+', shape=(num_data, 1)) 34 | mmp_event_indices = np.memmap(data_event_idx_path, dtype='uint16', mode='w+', shape=(num_data, 1)) 35 | 36 | if data_key in f.keys(): 37 | data = [] 38 | data_timestamps = [] 39 | data_event_index = [] 40 | for img_key in f[data_key].keys(): 41 | data.append(f[data_key][img_key][:]) 42 | data_timestamps.append(f[data_key][img_key].attrs['timestamp']) 43 | 
data_event_index.append(f[data_key][img_key].attrs['event_idx']) 44 | 45 | data_stack = np.expand_dims(np.stack(data), axis=3) 46 | data_ts_stack = np.expand_dims(np.stack(data_timestamps), axis=1) 47 | data_event_indices_stack = np.expand_dims(np.stack(data_event_index), axis=1) 48 | mmp_imgs[...] = data_stack 49 | mmp_img_ts[...] = data_ts_stack 50 | mmp_event_indices[...] = data_event_indices_stack 51 | 52 | def write_metadata(f, metadata_path): 53 | metadata = {} 54 | for attr in f.attrs: 55 | val = f.attrs[attr] 56 | if isinstance(val, np.ndarray): 57 | val = val.tolist() 58 | metadata[attr] = val 59 | with open(metadata_path, 'w') as js: 60 | json.dump(metadata, js) 61 | 62 | def h5_to_memmap(h5_file_path, output_base_path, overwrite=True): 63 | output_pth = output_base_path 64 | if os.path.exists(output_pth): 65 | if overwrite: 66 | print("Overwriting {}".format(output_pth)) 67 | shutil.rmtree(output_pth) 68 | else: 69 | output_pth = find_safe_alternative(output_base_path) 70 | print('Data will be extracted to: {}'.format(output_pth)) 71 | os.makedirs(output_pth) 72 | mmap_pth = os.path.join(output_pth, "memmap") 73 | os.makedirs(mmap_pth) 74 | 75 | ts_path = os.path.join(mmap_pth, 't.npy') 76 | xy_path = os.path.join(mmap_pth, 'xy.npy') 77 | ps_path = os.path.join(mmap_pth, 'p.npy') 78 | metadata_path = os.path.join(mmap_pth, 'metadata.json') 79 | 80 | additional_data = { 81 | "images": 82 | { 83 | 'h5_key' : 'images', 84 | 'mmap_filename' : 'images.npy', 85 | 'mmap_ts_filename' : 'timestamps.npy', 86 | 'mmap_event_idx_filename' : 'image_event_indices.npy', 87 | 'dims' : 3 88 | }, 89 | "flow": 90 | { 91 | 'h5_key' : 'flow', 92 | 'mmap_filename' : 'flow.npy', 93 | 'mmap_ts_filename' : 'flow_timestamps.npy', 94 | 'mmap_event_idx_filename' : 'flow_event_indices.npy', 95 | 'dims' : 3 96 | } 97 | } 98 | 99 | with h5py.File(h5_file_path, 'r') as f: 100 | num_events = f.attrs['num_events'] 101 | num_images = f.attrs['num_imgs'] 102 | num_flow = f.attrs['num_flow'] 103 | 104 | mmp_ts = np.memmap(ts_path, dtype='float64', mode='w+', shape=(num_events, 1)) 105 | mmp_xy = np.memmap(xy_path, dtype='int16', mode='w+', shape=(num_events, 2)) 106 | mmp_ps = np.memmap(ps_path, dtype='uint8', mode='w+', shape=(num_events, 1)) 107 | 108 | mmp_ts[:, 0] = f['events/ts'][:] 109 | mmp_xy[:, :] = np.stack((f['events/xs'][:], f['events/ys'][:])).transpose() 110 | mmp_ps[:, 0] = f['events/ps'][:] 111 | 112 | for data in additional_data: 113 | save_additional_data_as_mmap(f, mmap_pth, additional_data[data]) 114 | write_metadata(f, metadata_path) 115 | 116 | 117 | if __name__ == "__main__": 118 | """ 119 | Tool to convert this projects style hdf5 files to the memmap format used in some RPG projects 120 | """ 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument("path", help="HDF5 file to convert") 123 | parser.add_argument("--output_dir", default=None, help="Path to extract (same as bag if left empty)") 124 | parser.add_argument('--not_overwrite', action='store_false', help='If set, will not overwrite\ 125 | existing memmap, but will place safe alternative') 126 | 127 | args = parser.parse_args() 128 | 129 | bagname = os.path.splitext(os.path.basename(args.path))[0] 130 | if args.output_dir is None: 131 | output_path = os.path.join(os.path.dirname(os.path.abspath(args.path)), bagname) 132 | else: 133 | output_path = os.path.join(args.output_dir, bagname) 134 | h5_to_memmap(args.path, output_path, overwrite=args.not_overwrite) 135 | 
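# Usage sketch (hypothetical paths, for illustration): convert seq.h5 into memmap arrays under /data/seq/memmap/
#   python h5_to_memmap.py /data/seq.h5 --output_dir /data
# By default an existing output directory is overwritten; pass --not_overwrite to write to a fresh numbered directory instead.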
-------------------------------------------------------------------------------- /myutils/vis_events/tools/hxy_events2ply.py: -------------------------------------------------------------------------------- 1 | from numpy.lib.type_check import imag 2 | from plyfile import PlyData, PlyElement 3 | import numpy as np 4 | from scipy.optimize.optimize import vecnorm 5 | import cv2 6 | import h5py 7 | import os 8 | 9 | 10 | def read_h5_events(filename, start_idx, end_idx, inp_prex = 'down4'): 11 | h5_file = h5py.File(filename, 'r') 12 | xs = h5_file[f'{inp_prex}_events/xs'][start_idx:end_idx] 13 | ys = h5_file[f'{inp_prex}_events/ys'][start_idx:end_idx] 14 | ts = h5_file[f'{inp_prex}_events/ts'][start_idx:end_idx] 15 | ps = h5_file[f'{inp_prex}_events/ps'][start_idx:end_idx] 16 | 17 | sensor_resolution = h5_file.attrs['sensor_resolution'] 18 | 19 | return sensor_resolution, xs, ys, ts, ps 20 | 21 | 22 | def main(): 23 | filename = '/media/wwm/wwmdisk/data/Nfs/train_h5/bee.h5' 24 | basename = os.path.basename(filename).split('.')[0] 25 | TOTAL_COUNT = 122033 26 | start_idx = 0 27 | 28 | vertices_final = np.empty(0, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 29 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) 30 | sensor_resolution, xs, ys, ts, ps = read_h5_events(filename, start_idx=start_idx, end_idx=start_idx+TOTAL_COUNT) 31 | # flip 32 | H, W = sensor_resolution 33 | max = ts.max() 34 | min = ts.min() 35 | ts = (ts - min) / (max - min) * H 36 | 37 | # xs = 240 - xs 38 | # ys = 180 - ys 39 | # image 40 | # image = np.zeros((180,240),dtype=np.uint8) 41 | # for i in range(TOTAL_COUNT): 42 | # x, y = xs[i], ys[i] 43 | # image[180-y,x-1] = 255 44 | # cv2.imwrite("stacking_{}.png".format(i),image) 45 | # LAYER_NUM = 10 46 | # LAYER_COUNT = TOTAL_COUNT // LAYER_NUM 47 | print("event count: {}, duration time {:0.2f}".format(TOTAL_COUNT, ts[-1]-ts[0])) 48 | # bottom_xs,bottom_ys,bottom_ts = xs[0:LAYER_COUNT],ys[0:LAYER_COUNT],ts[0:LAYER_COUNT] 49 | # top_xs,top_ys,top_ts = xs[int(9*LAYER_COUNT):],ys[int(9*LAYER_COUNT):],ts[int(9*LAYER_COUNT):] 50 | # connect the proper data structures 51 | vertices = np.empty(TOTAL_COUNT, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 52 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) 53 | # red = np.ones((LAYER_COUNT,))*255*(i%3==0) 54 | # green = np.ones((LAYER_COUNT,))*255*(i%3==1) 55 | # blue = np.ones((LAYER_COUNT,))*255*(i%3==2) 56 | red = ps*255 57 | green = ps*0 58 | blue = (ps==-1)*255 59 | vertices['x'] = xs.astype('f4') 60 | vertices['y'] = ys.astype('f4') 61 | vertices['z'] = ts.astype('f4') 62 | vertices['red'] = red.astype('u1') 63 | vertices['green'] = green.astype('u1') 64 | vertices['blue'] = blue.astype('u1') 65 | vertices_final = np.concatenate((vertices_final, vertices)) 66 | 67 | 68 | # save as ply 69 | ply = PlyData([PlyElement.describe(vertices_final, 'vertex')], text=False) 70 | ply.write(f'/disk/work/output/dataset_ply/{basename}.ply') 71 | 72 | 73 | if __name__ == '__main__': 74 | main() -------------------------------------------------------------------------------- /myutils/vis_events/tools/read_events.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import os 4 | 5 | def compute_indices(event_stamps, frame_stamps): 6 | indices_first = np.searchsorted(event_stamps[:,0], frame_stamps[1:]) 7 | indices_last = np.searchsorted(event_stamps[:,0], frame_stamps[:-1]) 8 | index = np.stack([indices_first, indices_last], -1) 9 | return index 10 | 11 | def 
read_memmap_events(memmap_path, skip_frames=1, return_events=False, images_file = 'images.npy', 12 | images_ts_file = 'timestamps.npy', optic_flow_file = 'optic_flow.npy', 13 | optic_flow_ts_file = 'optic_flow_timestamps.npy', events_xy_file = 'xy.npy', 14 | events_p_file = 'p.npy', events_t_file = 't.npy'): 15 | assert os.path.isdir(memmap_path), '%s is not a valid directory' % memmap_path 16 | 17 | data = {} 18 | has_flow = False 19 | for subroot, _, fnames in sorted(os.walk(memmap_path)): 20 | for fname in sorted(fnames): 21 | path = os.path.join(subroot, fname) 22 | if fname.endswith(".npy"): 23 | if fname=="index.npy": # index mapping image index to event idx 24 | indices = np.load(path) # N x 2 25 | assert len(indices.shape) == 2 and indices.shape[1] == 2 26 | indices = indices.astype("int64") # ignore event indices which are 0 (before first image) 27 | data["index"] = indices.T 28 | elif fname==images_ts_file: 29 | data["frame_stamps"] = np.load(path)[::skip_frames,...] 30 | elif fname==images_file: 31 | data["images"] = np.load(path, mmap_mode="r")[::skip_frames,...] 32 | elif fname==optic_flow_file: 33 | data["optic_flow"] = np.load(path, mmap_mode="r")[::skip_frames,...] 34 | has_flow = True 35 | elif fname==optic_flow_ts_file: 36 | data["optic_flow_stamps"] = np.load(path)[::skip_frames,...] 37 | 38 | handle = np.load(path, mmap_mode="r") 39 | if fname==events_t_file: # timestamps 40 | data["t"] = handle[:].squeeze() if return_events else handle 41 | data["t0"] = handle[0] 42 | elif fname==events_xy_file: # coordinates 43 | data["xy"] = handle[:].squeeze() if return_events else handle 44 | elif fname==events_p_file: # polarity 45 | data["p"] = handle[:].squeeze() if return_events else handle 46 | 47 | if len(data) > 0: 48 | data['path'] = subroot 49 | if "t" not in data: 50 | raise Exception(f"Ignoring directory {subroot} since no events") 51 | if not (len(data['p']) == len(data['xy']) and len(data['p']) == len(data['t'])): 52 | raise Exception(f"Events from {subroot} invalid") 53 | data["num_events"] = len(data['p']) 54 | 55 | if "index" not in data and "frame_stamps" in data: 56 | data["index"] = compute_indices(data["t"], data['frame_stamps']) 57 | return data 58 | 59 | def read_h5_events(hdf_path): 60 | f = h5py.File(hdf_path, 'r') 61 | if 'events/x' in f: 62 | #legacy 63 | events = np.stack((f['events/x'][:], f['events/y'][:], f['events/ts'][:], np.where(f['events/p'][:], 1, -1)), axis=1) 64 | else: 65 | events = np.stack((f['events/xs'][:], f['events/ys'][:], f['events/ts'][:], np.where(f['events/ps'][:], 1, -1)), axis=1) 66 | return events 67 | 68 | def read_h5_event_components(hdf_path): 69 | f = h5py.File(hdf_path, 'r') 70 | if 'events/x' in f: 71 | #legacy 72 | return (f['events/x'][:], f['events/y'][:], f['events/ts'][:], np.where(f['events/p'][:], 1, -1)) 73 | else: 74 | return (f['events/xs'][:], f['events/ys'][:], f['events/ts'][:], np.where(f['events/ps'][:], 1, -1)) 75 | -------------------------------------------------------------------------------- /myutils/vis_events/tools/txt_to_h5.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import argparse 3 | import os 4 | import h5py 5 | import pandas as pd 6 | import numpy as np 7 | from event_packagers import * 8 | 9 | 10 | def get_sensor_size(txt_path): 11 | try: 12 | header = pd.read_csv(txt_path, delim_whitespace=True, header=None, names=['width', 'height'], 13 | dtype={'width': int, 'height': int}, 14 | nrows=1) 15 | width, height =
header.values[0] 16 | sensor_size = [height, width] 17 | except Exception: 18 | sensor_size = None 19 | print('Warning: could not read sensor size from first line of {}'.format(txt_path)) 20 | return sensor_size 21 | 22 | 23 | def extract_txt(txt_path, output_path, zero_timestamps=False, 24 | packager=hdf5_packager): 25 | ep = packager(output_path) 26 | first_ts = -1 27 | t0 = -1 28 | if not os.path.exists(txt_path): 29 | print("{} does not exist!".format(txt_path)) 30 | return 31 | 32 | # compute sensor size 33 | sensor_size = get_sensor_size(txt_path) 34 | # Extract events to h5 35 | ep.set_data_available(num_images=0, num_flow=0) 36 | total_num_pos, total_num_neg, last_ts = 0, 0, 0 37 | 38 | chunksize = 100000 39 | iterator = pd.read_csv(txt_path, delim_whitespace=True, header=None, 40 | names=['t', 'x', 'y', 'pol'], 41 | dtype={'t': np.float64, 'x': np.int16, 'y': np.int16, 'pol': np.int16}, 42 | engine='c', 43 | skiprows=1, chunksize=chunksize, nrows=None, memory_map=True) 44 | 45 | for i, event_window in enumerate(iterator): 46 | events = event_window.values 47 | ts = events[:, 0].astype(np.float64) 48 | xs = events[:, 1].astype(np.int16) 49 | ys = events[:, 2].astype(np.int16) 50 | ps = events[:, 3] 51 | ps[ps < 0] = 0 # should be [0 or 1] 52 | ps = ps.astype(bool) 53 | 54 | if first_ts == -1: 55 | first_ts = ts[0] 56 | 57 | if zero_timestamps: 58 | ts -= first_ts 59 | last_ts = ts[-1] 60 | if sensor_size is None or sensor_size[0] < max(ys) or sensor_size[1] < max(xs): 61 | sensor_size = [max(ys), max(xs)] # [height, width], consistent with get_sensor_size 62 | print("Sensor size inferred from events as {}".format(sensor_size)) 63 | 64 | sum_ps = sum(ps) 65 | total_num_pos += sum_ps 66 | total_num_neg += len(ps) - sum_ps 67 | ep.package_events(xs, ys, ts, ps) 68 | if i % 10 == 9: 69 | print('Events written: {} M'.format((total_num_pos + total_num_neg) / 1e6)) 70 | print('Events written: {} M'.format((total_num_pos + total_num_neg) / 1e6)) 71 | print("Detected sensor size [h={}, w={}]".format(sensor_size[0], sensor_size[1])) 72 | t0 = 0 if zero_timestamps else first_ts 73 | ep.add_metadata(total_num_pos, total_num_neg, last_ts-t0, t0, last_ts, num_imgs=0, num_flow=0, sensor_size=sensor_size) 74 | 75 | 76 | def extract_txts(txt_paths, output_dir, zero_timestamps=False): 77 | for path in txt_paths: 78 | filename = os.path.splitext(os.path.basename(path))[0] 79 | out_path = os.path.join(output_dir, "{}.h5".format(filename)) 80 | print("Extracting {} to {}".format(path, out_path)) 81 | extract_txt(path, out_path, zero_timestamps=zero_timestamps) 82 | 83 | 84 | if __name__ == "__main__": 85 | """ 86 | Tool for converting txt events to an efficient HDF5 format that can be speedily 87 | accessed by python code. 88 | Input path can be a single file or a directory containing files. 89 | Individual input event files can be txt or zip with format matching 90 | https://github.com/uzh-rpg/rpg_e2vid: 91 | 92 | width height 93 | t1 x1 y1 p1 94 | t2 x2 y2 p2 95 | t3 x3 y3 p3 96 | ... 97 | 98 | 99 | i.e. first line of file is sensor size, width first and height second. 100 | This script only does events -> h5, not images or anything else (yet).
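Example invocation (hypothetical paths, for illustration):
    python txt_to_h5.py /data/events_txt --output_dir /data/events_h5 --zero_timestamps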
101 | """ 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument("path", help="txt file to extract or directory containing txt files") 104 | parser.add_argument("--output_dir", default="/tmp/extracted_data", help="Folder where to extract the data") 105 | parser.add_argument('--zero_timestamps', action='store_true', help='If true, timestamps will be offset to start at 0') 106 | args = parser.parse_args() 107 | 108 | print('Data will be extracted in folder: {}'.format(args.output_dir)) 109 | if not os.path.exists(args.output_dir): 110 | os.makedirs(args.output_dir) 111 | if os.path.isdir(args.path): 112 | txt_paths = sorted(list(glob.glob(os.path.join(args.path, "*.txt"))) 113 | + list(glob.glob(os.path.join(args.path, "*.zip")))) 114 | else: 115 | txt_paths = [args.path] 116 | extract_txts(txt_paths, args.output_dir, zero_timestamps=args.zero_timestamps) 117 | -------------------------------------------------------------------------------- /scripts/infer_ours.sh: -------------------------------------------------------------------------------- 1 | ############## synthetic data 2 | CUDA_VISIBLE_DEVICES='0' \ 3 | python infer_ours.py \ 4 | --model_path /path/to/model \ 5 | --data_list /path/to/test.txt \ 6 | --output_path /path/to/output \ 7 | --scale 2 \ 8 | --ori_scale down2 \ 9 | --time_bins 16 \ 10 | --num_frame_per_period 16 \ # tune for different exposure assumptions 11 | --num_frame_per_blurry 3 \ # tune for different exposure assumptions 12 | --num_period_per_seq 2 \ 13 | --sliding_window_seq 2 \ 14 | --num_period_per_load 1 \ 15 | --sliding_window_load 1 \ 16 | --exposure_method Fixed \ # Auto/Fixed/Custom 17 | --noise_enabled 18 | 19 | 20 | ############## real-world data: RealBlur-DAVIS 21 | CUDA_VISIBLE_DEVICES='1' \ 22 | python infer_ours.py \ 23 | --model_path /path/to/model \ 24 | --data_list /path/to/test.txt \ 25 | --output_path /path/to/output \ 26 | --scale 2 \ 27 | --ori_scale down2 \ 28 | --time_bins 16 \ 29 | --interp_num 256 \ # define frame number for interpolation 30 | --num_period_per_seq 2 \ 31 | --sliding_window_seq 2 \ 32 | --num_period_per_load 1 \ 33 | --sliding_window_load 1 \ 34 | --noise_enabled \ 35 | --real_blur 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /scripts/train_ours.sh: -------------------------------------------------------------------------------- 1 | # first train exposure estimation 2 | CUDA_VISIBLE_DEVICES='0' \ 3 | python -m torch.distributed.launch --nproc_per_node 1 --use_env --master_port 355827 \ 4 | train_ours_exposuredecision.py -c config\train_ours_exposuredecision.yml -id provide_id_name 5 | 6 | # then train the whole model 7 | CUDA_VISIBLE_DEVICES='1' \ 8 | python -m torch.distributed.launch --nproc_per_node 1 --use_env --master_port 355829 \ 9 | train_ours.py -c config/train_ours.yml -id provide_id_name --------------------------------------------------------------------------------