├── VERSION
├── basicsr
│   ├── __init__.py
│   ├── metrics
│   │   ├── niqe_pris_params.npz
│   │   ├── __init__.py
│   │   ├── metric_util.py
│   │   ├── fid.py
│   │   └── niqe.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── download_util.py
│   │   ├── dist_util.py
│   │   ├── options.py
│   │   ├── create_lmdb.py
│   │   ├── npz2voxel.py
│   │   ├── misc.py
│   │   ├── img_util.py
│   │   ├── logger.py
│   │   ├── flow_util.py
│   │   ├── file_client.py
│   │   └── lmdb_util.py
│   ├── data
│   │   ├── data_sampler.py
│   │   ├── ffhq_dataset.py
│   │   ├── single_image_dataset.py
│   │   ├── prefetch_dataloader.py
│   │   ├── event_util.py
│   │   ├── __init__.py
│   │   ├── voxelnpz_image_dataset.py
│   │   └── npz_image_dataset.py
│   ├── models
│   │   ├── archs
│   │   │   └── __init__.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── loss_util.py
│   │   │   └── losses.py
│   │   ├── __init__.py
│   │   └── lr_scheduler.py
│   ├── version.py
│   ├── demo.py
│   └── test.py
├── figures
│   ├── scer.png
│   ├── models.png
│   ├── modules.pdf
│   ├── table_gopro.png
│   ├── table_reblur.png
│   ├── qualitative_GoPro_1.jpg
│   ├── qualitative_GoPro_2.png
│   ├── qualitative_REBlur_1.jpg
│   └── qualitative_REBlur_2.png
├── .gitignore
├── requirements.txt
├── scripts
│   ├── download_gdrive.py
│   ├── data_preparation
│   │   ├── README.md
│   │   ├── noise_function.py
│   │   ├── make_voxels_esim.py
│   │   ├── make_voxels_real.py
│   │   └── raw_event_dataset.py
│   ├── publish_models.py
│   └── download_pretrained_models.py
├── datasets
│   └── README.md
├── setup.cfg
├── options
│   ├── test
│   │   ├── GoPro
│   │   │   └── EFNet.yml
│   │   └── REBlur
│   │       └── Finetune_EFNet.yml
│   └── train
│       ├── HighREV
│       │   ├── EFNet_HighREV_Deblur.yml
│       │   └── EFNet_HighREV_Deblur_voxel.yml
│       ├── GoPro
│       │   └── EFNet.yml
│       └── REBlur
│           └── Finetune_EFNet.yml
├── setup.py
└── README.md

/VERSION:
--------------------------------------------------------------------------------
1.2.0
--------------------------------------------------------------------------------
/basicsr/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/figures/scer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/scer.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
.vscode/
.eggs/
basicsr.egg-info/
experiments/
--------------------------------------------------------------------------------
/figures/models.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/models.png
--------------------------------------------------------------------------------
/figures/modules.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/modules.pdf
--------------------------------------------------------------------------------
/figures/table_gopro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/table_gopro.png
--------------------------------------------------------------------------------
/figures/table_reblur.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/table_reblur.png
--------------------------------------------------------------------------------
/figures/qualitative_GoPro_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/qualitative_GoPro_1.jpg
--------------------------------------------------------------------------------
/figures/qualitative_GoPro_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/qualitative_GoPro_2.png
--------------------------------------------------------------------------------
/figures/qualitative_REBlur_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/qualitative_REBlur_1.jpg
--------------------------------------------------------------------------------
/figures/qualitative_REBlur_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/qualitative_REBlur_2.png
--------------------------------------------------------------------------------
/basicsr/metrics/niqe_pris_params.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/metrics/niqe_pris_params.npz
--------------------------------------------------------------------------------
/basicsr/version.py:
--------------------------------------------------------------------------------
# GENERATED VERSION FILE
# TIME: Fri Feb 7 15:11:50 2025
__version__ = '1.2.0+a8a710a'
short_version = '1.2.0'
version_info = (1, 2, 0)
--------------------------------------------------------------------------------
/basicsr/metrics/__init__.py:
--------------------------------------------------------------------------------
from .niqe import calculate_niqe
from .psnr_ssim import calculate_psnr, calculate_ssim

__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']
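
# Illustrative usage sketch, not part of the original file. Keyword arguments
# are assumed from the metric blocks in options/*.yml (crop_border: 0,
# test_y_channel: false); img1/img2 are HWC image ndarrays in [0, 255]:
# psnr = calculate_psnr(img1, img2, crop_border=0, test_y_channel=False)
# ssim = calculate_ssim(img1, img2, crop_border=0, test_y_channel=False)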
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
addict
future
lmdb
numpy
opencv-python
Pillow
pyyaml
requests
scikit-image
scipy
tb-nightly
tqdm
yapf
h5py
--------------------------------------------------------------------------------
/basicsr/models/losses/__init__.py:
--------------------------------------------------------------------------------
from .losses import (L1Loss, MSELoss, PSNRLoss, WeightedTVLoss, CharbonnierLoss)

__all__ = [
    'L1Loss', 'MSELoss', 'PSNRLoss', 'WeightedTVLoss', 'CharbonnierLoss'
]
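
# Illustrative usage sketch, not part of the original file. The constructor
# arguments mirror the `pixel_opt` block of options/train/*.yml; the exact
# signature is assumed from the BasicSR-style losses.py:
# import torch
# criterion = PSNRLoss(loss_weight=0.5, reduction='mean')
# loss = criterion(torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64))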
--------------------------------------------------------------------------------
/scripts/download_gdrive.py:
--------------------------------------------------------------------------------
import argparse

from basicsr.utils.download_util import download_file_from_google_drive

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--id', type=str, help='File id')
    parser.add_argument('--output', type=str, help='Save path')
    args = parser.parse_args()

    download_file_from_google_drive(args.id, args.output)
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
# Symlink/Put all the datasets here

It is recommended to symlink your dataset root to this folder - `datasets` - with the command `ln -s xxx yyy`.

Please refer to [DatasetPreparation.md](../docs/DatasetPreparation.md) for more details about data preparation.

---

It is recommended to symlink your data into this `datasets` directory via `ln -s xxx yyy`.

For more details on data preparation, see [DatasetPreparation_CN.md](../docs/DatasetPreparation_CN.md).
--------------------------------------------------------------------------------
/scripts/data_preparation/README.md:
--------------------------------------------------------------------------------
### Generate voxels from raw event files (h5)

Convert a raw GoPro event file (h5) to a voxel file (h5):
```
python make_voxels_esim.py --input_path /your/path/to/raw/event/h5file --save_path /your/path/to/save/voxel/h5file --voxel_method SCER_esim
```

Convert a raw REBlur event file (h5) to a voxel file (h5):
```
python make_voxels_real.py --input_path /your/path/to/raw/event/h5file --save_path /your/path/to/save/voxel/h5file --voxel_method SCER_real_data
```
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[flake8]
ignore =
    # line break before binary operator (W503)
    W503,
    # line break after binary operator (W504)
    W504,
max-line-length=79

[yapf]
based_on_style = pep8
blank_line_before_nested_class_or_def = true
split_before_expression_after_opening_paren = true

[isort]
line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = basicsr
known_third_party = PIL,cv2,lmdb,numpy,requests,scipy,skimage,torch,torchvision,tqdm,yaml
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
--------------------------------------------------------------------------------
/scripts/data_preparation/noise_function.py:
--------------------------------------------------------------------------------
import random

import torch


def put_hot_pixels_in_voxel_(voxel, hot_pixel_range=20, hot_pixel_fraction=0.00002):
    num_hot_pixels = int(hot_pixel_fraction * voxel.shape[-1] * voxel.shape[-2])
    x = torch.randint(0, voxel.shape[-1], (num_hot_pixels,))
    y = torch.randint(0, voxel.shape[-2], (num_hot_pixels,))
    for i in range(num_hot_pixels):
        # voxel[..., :, y[i], x[i]] = random.uniform(-hot_pixel_range, hot_pixel_range)
        voxel[..., :, y[i], x[i]] = random.randint(-hot_pixel_range, hot_pixel_range)


def add_noise_to_voxel(voxel, noise_std=1.0, noise_fraction=0.1):
    noise = noise_std * torch.randn_like(voxel)  # mean = 0, std = noise_std
    if noise_fraction < 1.0:
        mask = torch.rand_like(voxel) >= noise_fraction
        noise.masked_fill_(mask, 0)
    return voxel + noise
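
# Illustrative usage sketch, not part of the original file. Assumes an event
# voxel grid shaped (bins, H, W), as produced by the SCER voxel scripts:
# voxel = torch.zeros(6, 256, 256)
# put_hot_pixels_in_voxel_(voxel)  # adds hot pixels in place
# noisy = add_noise_to_voxel(voxel, noise_std=0.1, noise_fraction=0.1)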
--------------------------------------------------------------------------------
/basicsr/demo.py:
--------------------------------------------------------------------------------
import torch
from basicsr.models import create_model
from basicsr.train import parse_options
from basicsr.utils import FileClient, imfrombytes, img2tensor, padding


def main():
    # parse options, set distributed setting, set random seed
    opt = parse_options(is_train=False)

    img_path = opt['img_path'].get('input_img')
    output_path = opt['img_path'].get('output_img')

    ## 1. read image
    file_client = FileClient('disk')

    img_bytes = file_client.get(img_path, None)
    try:
        img = imfrombytes(img_bytes, float32=True)
    except Exception:
        raise Exception('path {} not working'.format(img_path))

    img = img2tensor(img, bgr2rgb=True, float32=True)

    ## 2. run inference
    model = create_model(opt)
    model.single_image_inference(img, output_path)

    print('inference {} .. finished.'.format(img_path))


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/basicsr/utils/__init__.py:
--------------------------------------------------------------------------------
from .file_client import FileClient
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding
from .logger import (MessageLogger, get_env_info, get_root_logger,
                     init_tb_logger, init_wandb_logger)
from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename,
                   scandir, scandir_SIDD, set_random_seed, sizeof_fmt)
from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k)

__all__ = [
    # file_client.py
    'FileClient',
    # img_util.py
    'img2tensor',
    'tensor2img',
    'imfrombytes',
    'imwrite',
    'crop_border',
    # logger.py
    'MessageLogger',
    'init_tb_logger',
    'init_wandb_logger',
    'get_root_logger',
    'get_env_info',
    # misc.py
    'set_random_seed',
    'get_time_str',
    'mkdir_and_rename',
    'make_exp_dirs',
    'scandir',
    'check_resume',
    'sizeof_fmt',
    'padding',
    'create_lmdb_for_reds',
    'create_lmdb_for_gopro',
    'create_lmdb_for_rain13k',
]
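
# Illustrative usage sketch, not part of the original file. The BasicSR-style
# image helpers convert between HWC-BGR float ndarrays in [0, 1] and CHW-RGB
# tensors (tensor2img's rgb2bgr/uint8 defaults are assumptions):
# import numpy as np
# img = np.random.rand(64, 64, 3).astype('float32')  # HWC, BGR, [0, 1]
# t = img2tensor(img, bgr2rgb=True, float32=True)     # CHW, RGB torch tensor
# back = tensor2img(t, rgb2bgr=True)                  # HWC, BGR, uint8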
--------------------------------------------------------------------------------
/basicsr/models/__init__.py:
--------------------------------------------------------------------------------
import importlib
from os import path as osp

from basicsr.utils import get_root_logger, scandir

# automatically scan and import model modules
# scan all the files under the 'models' folder and collect files ending with
# '_model.py'
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [
    osp.splitext(osp.basename(v))[0] for v in scandir(model_folder)
    if v.endswith('_model.py')
]
# import all the model modules
_model_modules = [
    importlib.import_module(f'basicsr.models.{file_name}')
    for file_name in model_filenames
]


def create_model(opt):
    """Create model.

    Args:
        opt (dict): Configuration. It contains:
            model_type (str): Model type.
    """
    model_type = opt['model_type']

    # dynamic instantiation
    for module in _model_modules:
        model_cls = getattr(module, model_type, None)
        if model_cls is not None:
            break
    if model_cls is None:
        raise ValueError(f'Model {model_type} is not found.')

    model = model_cls(opt)

    logger = get_root_logger()
    logger.info(f'Model [{model.__class__.__name__}] is created.')
    return model
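
# Illustrative usage sketch, not part of the original file. `model_type` in
# the option dict selects a class defined in basicsr/models/*_model.py, e.g.
# the ImageEventRestorationModel named in options/train/GoPro/EFNet.yml:
# opt = parse_options(is_train=True)  # opt['model_type'] comes from the yml
# model = create_model(opt)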
--------------------------------------------------------------------------------
/options/test/GoPro/EFNet.yml:
--------------------------------------------------------------------------------
# general settings
name: EFNet_test
model_type: TestImageEventRestorationModel
scale: 1
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 10

# dataset and data loader settings
datasets:
  test:
    name: gopro-bestmodel-test
    type: H5ImageDataset

    dataroot: /cluster/work/cvl/leisun/GOPRO_fullsize_h5_bin6_ver3/test # for debug

    # add
    norm_voxel: true
    return_voxel: true
    return_gt_frame: false
    return_mask: true
    use_mask: true

    crop_size: ~
    use_flip: false
    use_rot: false
    io_backend:
      type: h5

  dataset_name: GoPro

# network structures
network_g:
  type: EFNet
  wf: 64
  fuse_before_downsample: true


# path
path:
  pretrain_network_g: /cluster/work/cvl/leisun/log/experiments/EVTransformer-Finetune-2e4-300iter/models/net_g_latest.pth
  strict_load_g: true
  resume_state: ~
  root: /cluster/work/cvl/leisun/EFNet_inference/ # set this option ONLY in TEST!!!

# validation settings
val:
  save_img: true
  grids: ~
  crop_size: ~
  rgb2bgr: false # for my h5 data, it's false

# dist training settings
dist_params:
  backend: nccl
  port: 29500
--------------------------------------------------------------------------------
/basicsr/models/archs/__init__.py:
--------------------------------------------------------------------------------
import importlib
from os import path as osp

from basicsr.utils import scandir

# automatically scan and import arch modules
# scan all the files under the 'archs' folder and collect files ending with
# '_arch.py'
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [
    osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder)
    if v.endswith('_arch.py')
]
# import all the arch modules
_arch_modules = [
    importlib.import_module(f'basicsr.models.archs.{file_name}')
    for file_name in arch_filenames
]


def dynamic_instantiation(modules, cls_type, opt):
    """Dynamically instantiate class.

    Args:
        modules (list[importlib modules]): List of modules from importlib
            files.
        cls_type (str): Class type.
        opt (dict): Class initialization kwargs.

    Returns:
        class: Instantiated class.
    """

    for module in modules:
        cls_ = getattr(module, cls_type, None)
        if cls_ is not None:
            break
    if cls_ is None:
        raise ValueError(f'{cls_type} is not found.')
    return cls_(**opt)


def define_network(opt):
    network_type = opt.pop('type')
    net = dynamic_instantiation(_arch_modules, network_type, opt)
    return net
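
# Illustrative usage sketch, not part of the original file. `define_network`
# consumes the `network_g` block of the option files (note the pop('type')):
# net_g = define_network({'type': 'EFNet', 'wf': 64,
#                         'fuse_before_downsample': True})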
--------------------------------------------------------------------------------
/basicsr/metrics/metric_util.py:
--------------------------------------------------------------------------------
import numpy as np

from basicsr.utils.matlab_functions import bgr2ycbcr


def reorder_image(img, input_order='HWC'):
    """Reorder images to 'HWC' order.

    If the input_order is (h, w), return (h, w, 1);
    If the input_order is (c, h, w), return (h, w, c);
    If the input_order is (h, w, c), return as it is.

    Args:
        img (ndarray): Input image.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            If the input image shape is (h, w), input_order will not have
            effects. Default: 'HWC'.

    Returns:
        ndarray: reordered image.
    """

    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            "'HWC' and 'CHW'")
    if len(img.shape) == 2:
        img = img[..., None]
    if input_order == 'CHW':
        img = img.transpose(1, 2, 0)
    return img


def to_y_channel(img):
    """Change to Y channel of YCbCr.

    Args:
        img (ndarray): Images with range [0, 255].

    Returns:
        (ndarray): Images with range [0, 255] (float type) without round.
    """
    img = img.astype(np.float32) / 255.
    if img.ndim == 3 and img.shape[2] == 3:
        img = bgr2ycbcr(img, y_only=True)
        img = img[..., None]
    return img * 255.
--------------------------------------------------------------------------------
/options/test/REBlur/Finetune_EFNet.yml:
--------------------------------------------------------------------------------
# general settings
name: EFNet_Finetune_test
model_type: ImageEventRestorationModel
scale: 1
num_gpu: 1 # 4
manual_seed: 10

datasets:
  test:
    name: gopro-bestmodel-test
    type: H5ImageDataset


    # dataroot: /cluster/work/cvl/leisun/REBlur_addition # seems h5
    dataroot: /cluster/work/cvl/leisun/REBlur/test # REBlur

    # keep true if use events
    norm_voxel: true
    return_voxel: true

    return_mask: true # dataloader yields mask
    use_mask: true # use mask as input in model(data)

    filename_tmpl: '{}'
    io_backend:
      type: h5

    # gt_size: 256
    crop_size: 256
    use_flip: true
    use_rot: true

    # data loader settings
    use_shuffle: true
    num_worker_per_gpu: 3
    batch_size_per_gpu: 4
    dataset_enlarge_ratio: 1
    prefetch_mode: cpu
    num_prefetch_queue: 2

  dataset_name: REBlur

# network structures
network_g:
  type: EFNet
  wf: 64 #64
  fuse_before_downsample: true


# path
path:
  pretrain_network_g: /cluster/work/cvl/leisun/log/experiments/EV_Transformer_channelattention_simple_20witer/models/net_g_latest.pth
  strict_load_g: true
  resume_state: ~
  root: /cluster/work/cvl/leisun/EFNet_inference/ # set this option ONLY in TEST!!!

val:
  save_img: true
  grids: ~
  crop_size: ~
  rgb2bgr: false # for my h5 data, it's false

dist_params:
  backend: nccl
  port: 29500
--------------------------------------------------------------------------------
/basicsr/data/data_sampler.py:
--------------------------------------------------------------------------------
import math
import torch
from torch.utils.data.sampler import Sampler


class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler.
    Support enlarging the dataset for iteration-based training, saving time
    when restarting the dataloader after each epoch.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = math.ceil(
            len(self.dataset) * ratio / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()

        dataset_size = len(self.dataset)
        indices = [v % dataset_size for v in indices]

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
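
# Illustrative usage sketch, not part of the original file (the ratio mirrors
# the dataset_enlarge_ratio: 4 setting in options/train/*.yml):
# sampler = EnlargedSampler(train_set, num_replicas=1, rank=0, ratio=4)
# loader = torch.utils.data.DataLoader(train_set, batch_size=4, sampler=sampler)
# sampler.set_epoch(epoch)  # re-seed the deterministic shuffle every epoch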
--------------------------------------------------------------------------------
/basicsr/test.py:
--------------------------------------------------------------------------------
import logging
import torch
from os import path as osp

from basicsr.data import create_dataloader, create_dataset
from basicsr.models import create_model
from basicsr.train import parse_options
from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
                           make_exp_dirs)
from basicsr.utils.options import dict2str


def main():
    # parse options, set distributed setting, set random seed
    opt = parse_options(is_train=False)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # mkdir and initialize loggers
    make_exp_dirs(opt)
    log_file = osp.join(opt['path']['log'],
                        f"test_{opt['name']}_{get_time_str()}.log")
    logger = get_root_logger(
        logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
    logger.info(get_env_info())
    logger.info(dict2str(opt))

    # create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(
            test_set,
            dataset_opt,
            num_gpu=opt['num_gpu'],
            dist=opt['dist'],
            sampler=None,
            seed=opt['manual_seed'])
        logger.info(
            f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
        test_loaders.append(test_loader)

    # create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = opt['datasets']['test']['name']
        logger.info(f'Testing {test_set_name}...')
        rgb2bgr = opt['val'].get('rgb2bgr', True)
        # whether to use uint8 images to compute metrics
        use_image = opt['val'].get('use_image', True)
        model.validation(
            test_loader,
            current_iter=opt['name'],
            tb_logger=None,
            save_img=opt['val']['save_img'],
            rgb2bgr=rgb2bgr, use_image=use_image)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/basicsr/utils/download_util.py:
--------------------------------------------------------------------------------
import math
import requests
from tqdm import tqdm

from .misc import sizeof_fmt


def download_file_from_google_drive(file_id, save_path):
    """Download files from google drive.

    Ref:
    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive  # noqa E501

    Args:
        file_id (str): File id.
        save_path (str): Save path.
    """

    session = requests.Session()
    URL = 'https://docs.google.com/uc?export=download'
    params = {'id': file_id}

    response = session.get(URL, params=params, stream=True)
    token = get_confirm_token(response)
    if token:
        params['confirm'] = token
        response = session.get(URL, params=params, stream=True)

    # get file size
    response_file_size = session.get(
        URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
    if 'Content-Range' in response_file_size.headers:
        file_size = int(
            response_file_size.headers['Content-Range'].split('/')[1])
    else:
        file_size = None

    save_response_content(response, save_path, file_size)


def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None


def save_response_content(response,
                          destination,
                          file_size=None,
                          chunk_size=32768):
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')

        readable_file_size = sizeof_fmt(file_size)
    else:
        pbar = None

    with open(destination, 'wb') as f:
        downloaded_size = 0
        for chunk in response.iter_content(chunk_size):
            downloaded_size += chunk_size
            if pbar is not None:
                pbar.update(1)
                pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
                                     f'/ {readable_file_size}')
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
        if pbar is not None:
            pbar.close()
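
# Illustrative usage sketch, not part of the original file (the file id and
# save path below are placeholders; see also scripts/download_gdrive.py for
# the CLI wrapper around this function):
# download_file_from_google_drive('YOUR_FILE_ID',
#                                 'experiments/pretrained_models/EFNet.pth')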
--------------------------------------------------------------------------------
/basicsr/data/ffhq_dataset.py:
--------------------------------------------------------------------------------
from os import path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.transforms import augment
from basicsr.utils import FileClient, imfrombytes, img2tensor


class FFHQDataset(data.Dataset):
    """FFHQ dataset for StyleGAN.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwarg.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.
    """

    def __init__(self, opt):
        super(FFHQDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError("'dataroot_gt' should end with '.lmdb', "
                                 f'but received {self.gt_folder}')
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # FFHQ has 70000 images in total
            self.paths = [
                osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)
            ]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        img_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        return {'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
--------------------------------------------------------------------------------
/scripts/publish_models.py:
--------------------------------------------------------------------------------
import glob
import subprocess
import torch
from os import path as osp
from torch.serialization import _is_zipfile, _open_file_like


def update_sha(paths):
    print('# Update sha ...')
    for idx, path in enumerate(paths):
        print(f'{idx+1:03d}: Processing {path}')
        net = torch.load(path, map_location=torch.device('cpu'))
        basename = osp.basename(path)
        if 'params' not in net and 'params_ema' not in net:
            raise ValueError(f'Please check! Model {basename} does not '
                             f"have 'params'/'params_ema' key.")
        else:
            if '-' in basename:
                # check whether the sha is the latest
                old_sha = basename.split('-')[1].split('.')[0]
                new_sha = subprocess.check_output(['sha256sum',
                                                   path]).decode()[:8]
                if old_sha != new_sha:
                    final_file = path.split('-')[0] + f'-{new_sha}.pth'
                    print(f'\tSave from {path} to {final_file}')
                    subprocess.Popen(['mv', path, final_file])
            else:
                sha = subprocess.check_output(['sha256sum', path]).decode()[:8]
                final_file = path.split('.pth')[0] + f'-{sha}.pth'
                print(f'\tSave from {path} to {final_file}')
                subprocess.Popen(['mv', path, final_file])


def convert_to_backward_compatible_models(paths):
    """Convert to backward compatible pth files.

    PyTorch 1.6 uses an updated version of torch.save. In order to be
    compatible with previous PyTorch versions, save it with
    _use_new_zipfile_serialization=False.
    """
    print('# Convert to backward compatible pth files ...')
    for idx, path in enumerate(paths):
        print(f'{idx+1:03d}: Processing {path}')
        flag_need_conversion = False
        with _open_file_like(path, 'rb') as opened_file:
            if _is_zipfile(opened_file):
                flag_need_conversion = True

        if flag_need_conversion:
            net = torch.load(path, map_location=torch.device('cpu'))
            print('\tConverting to compatible pth file...')
            torch.save(net, path, _use_new_zipfile_serialization=False)


if __name__ == '__main__':
    paths = glob.glob('experiments/pretrained_models/*.pth') + glob.glob(
        'experiments/pretrained_models/**/*.pth')
    convert_to_backward_compatible_models(paths)
    update_sha(paths)
--------------------------------------------------------------------------------
/basicsr/data/single_image_dataset.py:
--------------------------------------------------------------------------------
from os import path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paths_from_lmdb
from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir


class SingleImageDataset(data.Dataset):
    """Read only lq images in the test phase.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).

    There are two modes:
    1. 'meta_info_file': Use meta information file to generate paths.
    2. 'folder': Scan folders to generate paths.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
    """

    def __init__(self, opt):
        super(SingleImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None
        self.lq_folder = opt['dataroot_lq']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder]
            self.io_backend_opt['client_keys'] = ['lq']
            self.paths = paths_from_lmdb(self.lq_folder)
        elif 'meta_info_file' in self.opt:
            with open(self.opt['meta_info_file'], 'r') as fin:
                self.paths = [
                    osp.join(self.lq_folder,
                             line.split(' ')[0]) for line in fin
                ]
        else:
            self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load lq image
        lq_path = self.paths[index]
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # TODO: color space transform
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
        return {'lq': img_lq, 'lq_path': lq_path}

    def __len__(self):
        return len(self.paths)
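
# Illustrative usage sketch, not part of the original file (folder mode simply
# scans dataroot_lq for images; the path below is a placeholder):
# dataset = SingleImageDataset({'dataroot_lq': 'datasets/demo_blur',
#                               'io_backend': {'type': 'disk'}})
# sample = dataset[0]  # {'lq': CHW float tensor, 'lq_path': str}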
--------------------------------------------------------------------------------
/options/train/HighREV/EFNet_HighREV_Deblur.yml:
--------------------------------------------------------------------------------
# general settings
name: EFNet_highrev_single_deblur # add 'debug' for quick debug
model_type: ImageEventRestorationModel
scale: 1
num_gpu: 1 #4
manual_seed: 10

datasets:
  train:
    name: highrev-train
    type: NpzPngSingleDeblurDataset

    # dataroot: /work/lei_sun/HighREV/train
    dataroot: /PATH/TO/DATASET/HighREV/train

    voxel_bins: 6
    gt_size: 256
    # keep true if use events
    norm_voxel: true
    use_hflip: true
    use_rot: true

    filename_tmpl: '{}'
    io_backend:
      type: disk

    # data loader settings
    use_shuffle: true
    num_worker_per_gpu: 3
    batch_size_per_gpu: 4 # 4 for 2080, 8 for titan
    dataset_enlarge_ratio: 4 # accelerate, equals the num_gpu
    prefetch_mode: cpu
    num_prefetch_queue: 2

  val:
    name: highrev-val
    type: NpzPngSingleDeblurDataset
    voxel_bins: 6
    # dataroot: /work/lei_sun/HighREV/val
    dataroot: /PATH/TO/DATASET/HighREV/val
    gt_size: ~
    norm_voxel: true

    io_backend:
      type: disk

    use_hflip: false
    use_rot: false

  dataset_name: HighREV

# network structures
network_g:
  type: EFNet
  wf: 64
  fuse_before_downsample: true

# path
path:
  pretrain_network_g: ~
  strict_load_g: true
  resume_state: ~
  training_states: ~ # save current training model states, for resume

# training settings
train:
  optim_g:
    type: AdamW
    lr: !!float 2e-4
    weight_decay: !!float 1e-4
    betas: [0.9, 0.99]

  scheduler:
    type: TrueCosineAnnealingLR
    T_max: 200000
    eta_min: !!float 1e-7

  total_iter: 200000
  warmup_iter: -1 # no warm up

  # losses
  pixel_opt:
    type: PSNRLoss
    loss_weight: 0.5
    reduction: mean

# validation settings
val:
  val_freq: !!float 5e4 # 2e4
  save_img: false
  grids: ~
  crop_size: ~ # use it if the gpu memory is not enough for whole-image inference
  max_minibatch: 8

  metrics:
    psnr:
      type: calculate_psnr
      crop_border: 0
      test_y_channel: false

    ssim:
      type: calculate_ssim
      crop_border: 0
      test_y_channel: false

# logging settings
logger:
  print_freq: 200
  save_checkpoint_freq: !!float 2e4
  use_tb_logger: true
  wandb:
    project: your_project_name
    resume_id: x

# dist training settings
dist_params:
  backend: nccl
  port: 29500
--------------------------------------------------------------------------------
/basicsr/utils/dist_util.py:
--------------------------------------------------------------------------------
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py  # noqa: E501
import functools
import os
import subprocess
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def init_dist(launcher, backend='nccl', **kwargs):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')


def _init_dist_pytorch(backend, **kwargs):
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be the
    system environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in
    the system environment variables, then a default port ``29500`` will be
    used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)


def get_dist_info():
    if dist.is_available():
        initialized = dist.is_initialized()
    else:
        initialized = False
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size


def master_only(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper
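
# Illustrative usage sketch, not part of the original file: restrict side
# effects (e.g. saving files) to rank 0 in distributed runs.
# @master_only
# def save_visuals(path):
#     ...
# rank, world_size = get_dist_info()  # returns (0, 1) when not initialized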
94 | 95 | # validation settings 96 | val: 97 | val_freq: !!float 5e4 # 2e4 98 | save_img: false 99 | grids: ~ 100 | crop_size: ~ 101 | max_minibatch: 8 102 | 103 | metrics: 104 | psnr: 105 | type: calculate_psnr 106 | crop_border: 0 107 | test_y_channel: false 108 | 109 | ssim: 110 | type: calculate_ssim 111 | crop_border: 0 112 | test_y_channel: false 113 | 114 | # logging settings 115 | logger: 116 | print_freq: 200 117 | save_checkpoint_freq: !!float 2e4 118 | use_tb_logger: true 119 | wandb: 120 | project: your_project_name 121 | resume_id: x 122 | 123 | # dist training settings 124 | dist_params: 125 | backend: nccl 126 | port: 29500 127 | -------------------------------------------------------------------------------- /options/train/REBlur/Finetune_EFNet.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: EFNet_Finetune # add 'debug' for quick debug 3 | model_type: ImageEventRestorationModel 4 | scale: 1 5 | num_gpu: 1 # 4 6 | manual_seed: 10 7 | 8 | datasets: 9 | train: 10 | name: REBlur-h5-train 11 | type: H5ImageDataset 12 | 13 | # dataroot: ./datasets/train 14 | dataroot: /cluster/work/cvl/leisun/REBlur/train # REBlur 15 | 16 | # keep true if use events 17 | norm_voxel: true 18 | return_voxel: true 19 | 20 | return_mask: true 21 | use_mask: true 22 | 23 | filename_tmpl: '{}' 24 | io_backend: 25 | type: h5 26 | 27 | crop_size: 256 28 | use_flip: true 29 | use_rot: true 30 | 31 | # data loader settings 32 | use_shuffle: true 33 | num_worker_per_gpu: 3 34 | batch_size_per_gpu: 4 35 | dataset_enlarge_ratio: 1 36 | prefetch_mode: cpu 37 | num_prefetch_queue: 2 38 | 39 | val: 40 | name: REBlur-h5-test 41 | type: H5ImageDataset 42 | 43 | # dataroot: ./datasets/test 44 | dataroot: /cluster/work/cvl/leisun/REBlur/test 45 | 46 | # keep true 47 | norm_voxel: true # true ! 
48 | return_voxel: true 49 | 50 | return_mask: true # dataloader yields mask 51 | use_mask: true 52 | 53 | io_backend: 54 | type: h5 55 | 56 | crop_size: ~ 57 | use_flip: false 58 | use_rot: false 59 | 60 | dataset_name: REBlur 61 | 62 | # network structures 63 | network_g: 64 | type: EFNet 65 | wf: 64 66 | fuse_before_downsample: true 67 | 68 | 69 | # path 70 | path: 71 | pretrain_network_g: ~ # pretrain path 72 | strict_load_g: true 73 | resume_state: ~ 74 | training_states: ~ # save current training model states, for resume 75 | 76 | 77 | # training settings 78 | train: 79 | optim_g: 80 | type: AdamW 81 | lr: !!float 2e-5 82 | weight_decay: !!float 1e-4 83 | betas: [0.9, 0.99] 84 | 85 | 86 | scheduler: 87 | type: TrueCosineAnnealingLR 88 | T_max: 600 # finetune 89 | eta_min: !!float 1e-7 90 | 91 | total_iter: 600 # finetune 92 | warmup_iter: -1 # no warm up 93 | 94 | # losses 95 | pixel_opt: 96 | type: PSNRLoss 97 | loss_weight: 0.5 98 | reduction: mean 99 | 100 | # validation settings 101 | val: 102 | val_freq: !!float 5e4 # 2e4 103 | save_img: false 104 | grids: ~ 105 | crop_size: ~ 106 | max_minibatch: 8 107 | 108 | metrics: 109 | psnr: 110 | type: calculate_psnr 111 | crop_border: 0 112 | test_y_channel: false 113 | 114 | # add ssim 115 | ssim: 116 | type: calculate_ssim 117 | crop_border: 0 118 | test_y_channel: false 119 | 120 | # logging settings 121 | logger: 122 | print_freq: 50 123 | save_checkpoint_freq: !!float 100 124 | use_tb_logger: true 125 | wandb: 126 | project: ~ 127 | resume_id: ~ 128 | 129 | # dist training settings 130 | dist_params: 131 | backend: nccl 132 | port: 29500 133 | 134 | -------------------------------------------------------------------------------- /options/train/HighREV/EFNet_HighREV_Deblur_voxel.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: EFNet_highrev_single_deblur_voxel_debug # add 'debug' for quick debug 3 | model_type: ImageEventRestorationModel 4 | scale: 1 5 | num_gpu: 1 #4 6 | manual_seed: 10 7 | 8 | datasets: 9 | train: 10 | name: highrev-train 11 | type: VoxelnpzPngSingleDeblurDataset 12 | 13 | # dataroot: /work/lei_sun/HighREV/train 14 | # dataroot_voxel: /work/lei_sun/HighREV_voxel/train/voxel 15 | dataroot: /PATH/TO/DATASET/HighREV/train 16 | dataroot_voxel: /PATH/TO/DATASET/HighREV_voxel/train/voxel 17 | 18 | gt_size: 256 19 | # keep true if use events 20 | norm_voxel: true 21 | use_hflip: true 22 | use_rot: true 23 | 24 | filename_tmpl: '{}' 25 | io_backend: 26 | type: disk 27 | 28 | # data loader settings 29 | use_shuffle: true 30 | num_worker_per_gpu: 3 31 | batch_size_per_gpu: 4 # 4 for 2080, 8 for titan 32 | dataset_enlarge_ratio: 4 # accelerate; equals num_gpu 33 | prefetch_mode: cpu 34 | num_prefetch_queue: 2 35 | 36 | val: 37 | name: highrev-val 38 | type: VoxelnpzPngSingleDeblurDataset 39 | 40 | # dataroot: /work/lei_sun/HighREV/val 41 | # dataroot_voxel: /work/lei_sun/HighREV_voxel/val/voxel 42 | dataroot: /PATH/TO/DATASET/HighREV/val 43 | dataroot_voxel: /PATH/TO/DATASET/HighREV_voxel/val/voxel 44 | 45 | 46 | gt_size: ~ 47 | norm_voxel: true 48 | 49 | io_backend: 50 | type: disk 51 | 52 | use_hflip: false 53 | use_rot: false 54 | 55 | dataset_name: HighREV 56 | 57 | # network structures 58 | network_g: 59 | type: EFNet 60 | wf: 64 61 | fuse_before_downsample: true 62 | 63 | # path 64 | path: 65 | pretrain_network_g: ~ 66 | strict_load_g: true 67 | resume_state: ~ 68 | training_states: ~ # save current training model states, for resume 69 | 
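# To resume an interrupted run, point resume_state at a saved state file.
# The path below is illustrative only (state files are written under
# experiments/<name>/training_states/ once training has started):
#   resume_state: ./experiments/EFNet_highrev_single_deblur_voxel/training_states/100000.state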
70 | # training settings 71 | train: 72 | optim_g: 73 | type: AdamW 74 | lr: !!float 2e-4 75 | weight_decay: !!float 1e-4 76 | betas: [0.9, 0.99] 77 | 78 | scheduler: 79 | type: TrueCosineAnnealingLR 80 | T_max: 200000 81 | eta_min: !!float 1e-7 82 | 83 | total_iter: 200000 84 | warmup_iter: -1 # no warm up 85 | 86 | # losses 87 | pixel_opt: 88 | type: PSNRLoss 89 | loss_weight: 0.5 90 | reduction: mean 91 | 92 | # validation settings 93 | val: 94 | val_freq: !!float 5e4 # 2e4 95 | save_img: false 96 | grids: ~ 97 | crop_size: ~ # use it if the GPU memory is not enough for whole-image inference 98 | max_minibatch: 8 99 | 100 | metrics: 101 | psnr: 102 | type: calculate_psnr 103 | crop_border: 0 104 | test_y_channel: false 105 | 106 | ssim: 107 | type: calculate_ssim 108 | crop_border: 0 109 | test_y_channel: false 110 | 111 | # logging settings 112 | logger: 113 | print_freq: 200 114 | save_checkpoint_freq: !!float 2e4 115 | use_tb_logger: true 116 | wandb: 117 | project: your_project_name 118 | resume_id: x 119 | 120 | # dist training settings 121 | dist_params: 122 | backend: nccl 123 | port: 29500 124 | -------------------------------------------------------------------------------- /basicsr/models/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from torch.nn import functional as F 3 | 4 | 5 | def reduce_loss(loss, reduction): 6 | """Reduce loss as specified. 7 | 8 | Args: 9 | loss (Tensor): Elementwise loss tensor. 10 | reduction (str): Options are 'none', 'mean' and 'sum'. 11 | 12 | Returns: 13 | Tensor: Reduced loss tensor. 14 | """ 15 | reduction_enum = F._Reduction.get_enum(reduction) 16 | # none: 0, elementwise_mean:1, sum: 2 17 | if reduction_enum == 0: 18 | return loss 19 | elif reduction_enum == 1: 20 | return loss.mean() 21 | else: 22 | return loss.sum() 23 | 24 | 25 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 26 | """Apply element-wise weight and reduce loss. 27 | 28 | Args: 29 | loss (Tensor): Element-wise loss. 30 | weight (Tensor): Element-wise weights. Default: None. 31 | reduction (str): Same as built-in losses of PyTorch. Options are 32 | 'none', 'mean' and 'sum'. Default: 'mean'. 33 | 34 | Returns: 35 | Tensor: Loss values. 36 | """ 37 | # if weight is specified, apply element-wise weight 38 | if weight is not None: 39 | assert weight.dim() == loss.dim() 40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 41 | loss = loss * weight 42 | 43 | # if weight is not specified or reduction is sum, just reduce the loss 44 | if weight is None or reduction == 'sum': 45 | loss = reduce_loss(loss, reduction) 46 | # if reduction is mean, then compute mean over weight region 47 | elif reduction == 'mean': 48 | if weight.size(1) > 1: 49 | weight = weight.sum() 50 | else: 51 | weight = weight.sum() * loss.size(1) 52 | loss = loss.sum() / weight 53 | 54 | return loss 55 | 56 | 57 | def weighted_loss(loss_func): 58 | """Create a weighted version of a given loss function. 59 | 60 | To use this decorator, the loss function must have the signature like 61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 62 | element-wise loss without any reduction. This decorator will add weight 63 | and reduction arguments to the function. The decorated function will have 64 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 65 | **kwargs)`.
66 | 67 | :Example: 68 | 69 | >>> import torch 70 | >>> @weighted_loss 71 | >>> def l1_loss(pred, target): 72 | >>> return (pred - target).abs() 73 | 74 | >>> pred = torch.Tensor([0, 2, 3]) 75 | >>> target = torch.Tensor([1, 1, 1]) 76 | >>> weight = torch.Tensor([1, 0, 1]) 77 | 78 | >>> l1_loss(pred, target) 79 | tensor(1.3333) 80 | >>> l1_loss(pred, target, weight) 81 | tensor(1.5000) 82 | >>> l1_loss(pred, target, reduction='none') 83 | tensor([1., 1., 2.]) 84 | >>> l1_loss(pred, target, weight, reduction='sum') 85 | tensor(3.) 86 | """ 87 | 88 | @functools.wraps(loss_func) 89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 90 | # get element-wise loss 91 | loss = loss_func(pred, target, **kwargs) 92 | loss = weight_reduce_loss(loss, weight, reduction) 93 | return loss 94 | 95 | return wrapper 96 | -------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Ref: 11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | 13 | Args: 14 | generator: Python generator. 15 | num_prefetch_queue (int): Number of prefetch queue. 16 | """ 17 | 18 | def __init__(self, generator, num_prefetch_queue): 19 | threading.Thread.__init__(self) 20 | self.queue = Queue.Queue(num_prefetch_queue) 21 | self.generator = generator 22 | self.daemon = True 23 | self.start() 24 | 25 | def run(self): 26 | for item in self.generator: 27 | self.queue.put(item) 28 | self.queue.put(None) 29 | 30 | def __next__(self): 31 | next_item = self.queue.get() 32 | if next_item is None: 33 | raise StopIteration 34 | return next_item 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | 40 | class PrefetchDataLoader(DataLoader): 41 | """Prefetch version of dataloader. 42 | 43 | Ref: 44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 45 | 46 | TODO: 47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 48 | ddp. 49 | 50 | Args: 51 | num_prefetch_queue (int): Number of prefetch queue. 52 | kwargs (dict): Other arguments for dataloader. 53 | """ 54 | 55 | def __init__(self, num_prefetch_queue, **kwargs): 56 | self.num_prefetch_queue = num_prefetch_queue 57 | super(PrefetchDataLoader, self).__init__(**kwargs) 58 | 59 | def __iter__(self): 60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 61 | 62 | 63 | class CPUPrefetcher(): 64 | """CPU prefetcher. 65 | 66 | Args: 67 | loader: Dataloader. 68 | """ 69 | 70 | def __init__(self, loader): 71 | self.ori_loader = loader 72 | self.loader = iter(loader) 73 | 74 | def next(self): 75 | try: 76 | return next(self.loader) 77 | except StopIteration: 78 | return None 79 | 80 | def reset(self): 81 | self.loader = iter(self.ori_loader) 82 | 83 | 84 | class CUDAPrefetcher(): 85 | """CUDA prefetcher. 86 | 87 | Ref: 88 | https://github.com/NVIDIA/apex/issues/304# 89 | 90 | It may consume more GPU memory. 91 | 92 | Args: 93 | loader: Dataloader. 94 | opt (dict): Options.
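Example (a sketch; assumes ``opt['num_gpu'] > 0`` so batches land on CUDA, and that training consumes batches until ``None`` is returned):
            >>> prefetcher = CUDAPrefetcher(train_loader, opt)
            >>> batch = prefetcher.next()
            >>> while batch is not None:
            ...     # run one training step on `batch` here
            ...     batch = prefetcher.next()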
95 | """ 96 | 97 | def __init__(self, loader, opt): 98 | self.ori_loader = loader 99 | self.loader = iter(loader) 100 | self.opt = opt 101 | self.stream = torch.cuda.Stream() 102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 103 | self.preload() 104 | 105 | def preload(self): 106 | try: 107 | self.batch = next(self.loader) # self.batch is a dict 108 | except StopIteration: 109 | self.batch = None 110 | return None 111 | # put tensors to gpu 112 | with torch.cuda.stream(self.stream): 113 | for k, v in self.batch.items(): 114 | if torch.is_tensor(v): 115 | self.batch[k] = self.batch[k].to( 116 | device=self.device, non_blocking=True) 117 | 118 | def next(self): 119 | torch.cuda.current_stream().wait_stream(self.stream) 120 | batch = self.batch 121 | self.preload() 122 | return batch 123 | 124 | def reset(self): 125 | self.loader = iter(self.ori_loader) 126 | self.preload() 127 | -------------------------------------------------------------------------------- /basicsr/metrics/fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from scipy import linalg 5 | from tqdm import tqdm 6 | 7 | from basicsr.models.archs.inception import InceptionV3 8 | 9 | 10 | def load_patched_inception_v3(device='cuda', 11 | resize_input=True, 12 | normalize_input=False): 13 | # we may not resize the input, but in [rosinality/stylegan2-pytorch] it 14 | # does resize the input. 15 | inception = InceptionV3([3], 16 | resize_input=resize_input, 17 | normalize_input=normalize_input) 18 | inception = nn.DataParallel(inception).eval().to(device) 19 | return inception 20 | 21 | 22 | @torch.no_grad() 23 | def extract_inception_features(data_generator, 24 | inception, 25 | len_generator=None, 26 | device='cuda'): 27 | """Extract inception features. 28 | 29 | Args: 30 | data_generator (generator): A data generator. 31 | inception (nn.Module): Inception model. 32 | len_generator (int): Length of the data_generator to show the 33 | progressbar. Default: None. 34 | device (str): Device. Default: cuda. 35 | 36 | Returns: 37 | Tensor: Extracted features. 38 | """ 39 | if len_generator is not None: 40 | pbar = tqdm(total=len_generator, unit='batch', desc='Extract') 41 | else: 42 | pbar = None 43 | features = [] 44 | 45 | for data in data_generator: 46 | if pbar: 47 | pbar.update(1) 48 | data = data.to(device) 49 | feature = inception(data)[0].view(data.shape[0], -1) 50 | features.append(feature.to('cpu')) 51 | if pbar: 52 | pbar.close() 53 | features = torch.cat(features, 0) 54 | return features 55 | 56 | 57 | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): 58 | """Numpy implementation of the Frechet Distance. 59 | 60 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 61 | and X_2 ~ N(mu_2, C_2) is 62 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 63 | Stable version by Dougal J. Sutherland. 64 | 65 | Args: 66 | mu1 (np.array): The sample mean over activations. 67 | sigma1 (np.array): The covariance matrix over activations for 68 | generated samples. 69 | mu2 (np.array): The sample mean over activations, precalculated on a 70 | representative data set. 71 | sigma2 (np.array): The covariance matrix over activations, 72 | precalculated on a representative data set. 73 | 74 | Returns: 75 | float: The Frechet Distance.
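Example (illustrative; ``feats1``/``feats2`` are pre-extracted Inception feature arrays of shape [N, 2048]):
            >>> mu1, sigma1 = feats1.mean(axis=0), np.cov(feats1, rowvar=False)
            >>> mu2, sigma2 = feats2.mean(axis=0), np.cov(feats2, rowvar=False)
            >>> fid = calculate_fid(mu1, sigma1, mu2, sigma2)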
76 | """ 77 | assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' 78 | assert sigma1.shape == sigma2.shape, ( 79 | 'Two covariances have different dimensions') 80 | 81 | cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) 82 | 83 | # Product might be almost singular 84 | if not np.isfinite(cov_sqrt).all(): 85 | print(f'Product of cov matrices is singular. Adding {eps} to diagonal ' 86 | 'of cov estimates') 87 | offset = np.eye(sigma1.shape[0]) * eps 88 | cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) 89 | 90 | # Numerical error might give slight imaginary component 91 | if np.iscomplexobj(cov_sqrt): 92 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 93 | m = np.max(np.abs(cov_sqrt.imag)) 94 | raise ValueError(f'Imaginary component {m}') 95 | cov_sqrt = cov_sqrt.real 96 | 97 | mean_diff = mu1 - mu2 98 | mean_norm = mean_diff @ mean_diff 99 | trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt) 100 | fid = mean_norm + trace 101 | 102 | return fid 103 | -------------------------------------------------------------------------------- /basicsr/data/event_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def events_to_voxel_grid(events, num_bins, width, height, return_format='CHW'): 6 | """ 7 | Build a voxel grid with bilinear interpolation in the time domain from a set of events. 8 | 9 | :param events: a [N x 4] NumPy array containing one event per row in the form: [timestamp, x, y, polarity] 10 | :param num_bins: number of bins in the temporal axis of the voxel grid 11 | :param width, height: dimensions of the voxel grid 12 | :param return_format: 'CHW' or 'HWC' 13 | """ 14 | 15 | assert(events.shape[1] == 4) 16 | assert(num_bins > 0) 17 | assert(width > 0) 18 | assert(height > 0) 19 | 20 | voxel_grid = np.zeros((num_bins, height, width), np.float32).ravel() 21 | # print('DEBUG: voxel.shape:{}'.format(voxel_grid.shape)) 22 | 23 | # normalize the event timestamps so that they lie between 0 and num_bins 24 | last_stamp = events[-1, 0] 25 | # print('last stamp:{}'.format(last_stamp)) 26 | # print('max stamp:{}'.format(events[:, 0].max())) 27 | # print('timestamp:{}'.format(events[:, 0])) 28 | # print('polarity:{}'.format(events[:, -1])) 29 | 30 | first_stamp = events[0, 0] 31 | deltaT = last_stamp - first_stamp 32 | 33 | if deltaT == 0: 34 | deltaT = 1.0 35 | 36 | events[:, 0] = (num_bins - 1) * (events[:, 0] - first_stamp) / deltaT # 37 | ts = events[:, 0] 38 | xs = events[:, 1].astype(int) 39 | ys = events[:, 2].astype(int) 40 | pols = events[:, 3] 41 | pols[pols == 0] = -1 # polarity should be +1 / -1 42 | 43 | tis = ts.astype(int) 44 | dts = ts - tis 45 | vals_left = pols * (1.0 - dts) 46 | vals_right = pols * dts 47 | 48 | valid_indices = tis < num_bins # [True True ... True] 49 | # print('x max:{}'.format(xs[valid_indices].max())) 50 | # print('y max:{}'.format(ys[valid_indices].max())) 51 | # print('tix max:{}'.format(tis[valid_indices].max())) 52 | 53 | np.add.at(voxel_grid, xs[valid_indices] + ys[valid_indices] * width # flat index = x + y*width + t*width*height
54 | + tis[valid_indices] * width * height, vals_left[valid_indices]) 55 | 56 | valid_indices = (tis + 1) < num_bins 57 | np.add.at(voxel_grid, xs[valid_indices] + ys[valid_indices] * width 58 | + (tis[valid_indices] + 1) * width * height, vals_right[valid_indices]) 59 | 60 | voxel_grid = np.reshape(voxel_grid, (num_bins, height, width)) 61 | 62 | if return_format == 'CHW': 63 | return voxel_grid 64 | elif return_format == 'HWC': 65 | return voxel_grid.transpose(1,2,0) 66 | 67 | def voxel_norm(voxel): 68 | """ 69 | Normalize the voxel grid 70 | 71 | :param voxel: the unnormalized voxel grid 72 | :return voxel: the normalized voxel grid 73 | """ 74 | nonzero_ev = (voxel != 0) 75 | num_nonzeros = nonzero_ev.sum() 76 | # print('DEBUG: num_nonzeros:{}'.format(num_nonzeros)) 77 | if num_nonzeros > 0: 78 | # compute mean and stddev of the **nonzero** elements of the event tensor 79 | # we do not use PyTorch's default mean() and std() functions since it's faster 80 | # to compute it by hand than applying those funcs to a masked array 81 | mean = voxel.sum() / num_nonzeros 82 | stddev = torch.sqrt((voxel ** 2).sum() / num_nonzeros - mean ** 2) 83 | mask = nonzero_ev.float() 84 | voxel = mask * (voxel - mean) / stddev 85 | 86 | return voxel 87 | 88 | 89 | 90 | def filter_event(x,y,p,t, s_e_index=[0,6]): 91 | ''' 92 | s_e_index: start and end indices of the exposure (both ends inclusive) 93 | ''' 94 | t_1=t.squeeze(1) 95 | uniqw, inverse = np.unique(t_1, return_inverse=True) 96 | discretized_ts = np.bincount(inverse) 97 | index_exposure_start = np.sum(discretized_ts[0:s_e_index[0]]) 98 | index_exposure_end = np.sum(discretized_ts[0:s_e_index[1]+1]) 99 | x_1 = x[index_exposure_start:index_exposure_end] 100 | y_1 = y[index_exposure_start:index_exposure_end] 101 | p_1 = p[index_exposure_start:index_exposure_end] 102 | t_1 = t[index_exposure_start:index_exposure_end] 103 | 104 | return x_1, y_1, p_1, t_1 105 | 106 | -------------------------------------------------------------------------------- /basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from collections import OrderedDict 3 | from os import path as osp 4 | 5 | 6 | def ordered_yaml(): 7 | """Support OrderedDict for yaml. 8 | 9 | Returns: 10 | yaml Loader and Dumper. 11 | """ 12 | try: 13 | from yaml import CDumper as Dumper 14 | from yaml import CLoader as Loader 15 | except ImportError: 16 | from yaml import Dumper, Loader 17 | 18 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 19 | 20 | def dict_representer(dumper, data): 21 | return dumper.represent_dict(data.items()) 22 | 23 | def dict_constructor(loader, node): 24 | return OrderedDict(loader.construct_pairs(node)) 25 | 26 | Dumper.add_representer(OrderedDict, dict_representer) 27 | Loader.add_constructor(_mapping_tag, dict_constructor) 28 | return Loader, Dumper 29 | 30 | 31 | def parse(opt_path, is_train=True): 32 | """Parse option file. 33 | 34 | Args: 35 | opt_path (str): Option file path. 36 | is_train (bool): Indicate whether in training or not. Default: True. 37 | 38 | Returns: 39 | (dict): Options.
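Example (using an option file shipped with this repo):
            >>> opt = parse('options/train/GoPro/EFNet.yml', is_train=True)
            >>> opt['model_type']
            'ImageEventRestorationModel'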
40 | """ 41 | with open(opt_path, mode='r') as f: 42 | Loader, _ = ordered_yaml() 43 | opt = yaml.load(f, Loader=Loader) 44 | 45 | opt['is_train'] = is_train 46 | 47 | # datasets 48 | if 'datasets' in opt: 49 | for phase, dataset in opt['datasets'].items(): 50 | # for several datasets, e.g., test_1, test_2 51 | phase = phase.split('_')[0] 52 | dataset['phase'] = phase 53 | if 'scale' in opt: 54 | dataset['scale'] = opt['scale'] 55 | if dataset.get('dataroot_gt') is not None: 56 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 57 | if dataset.get('dataroot_lq') is not None: 58 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 59 | 60 | # paths 61 | for key, val in opt['path'].items(): 62 | if (val is not None) and ('resume_state' in key 63 | or 'pretrain_network' in key): 64 | opt['path'][key] = osp.expanduser(val) 65 | 66 | # opt['path']['root'] = osp.abspath( 67 | # osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 68 | opt['path']['root'] = opt['path'].get('root',osp.abspath( 69 | osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))) 70 | 71 | if is_train: 72 | experiments_root = osp.join(opt['path']['root'], 'experiments', 73 | opt['name']) 74 | opt['path']['experiments_root'] = experiments_root 75 | opt['path']['models'] = osp.join(experiments_root, 'models') 76 | opt['path']['training_states'] = osp.join(experiments_root, 77 | 'training_states') 78 | opt['path']['log'] = experiments_root 79 | opt['path']['visualization'] = osp.join(experiments_root, 80 | 'visualization') 81 | 82 | # change some options for debug mode 83 | if 'debug' in opt['name']: 84 | if 'val' in opt: 85 | opt['val']['val_freq'] = 8 86 | opt['logger']['print_freq'] = 1 87 | opt['logger']['save_checkpoint_freq'] = 8 88 | else: # test 89 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 90 | opt['path']['results_root'] = results_root 91 | opt['path']['log'] = results_root 92 | # opt['path']['visualization'] = opt['path'].get('visualization', osp.join(results_root, 'visualization')) 93 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 94 | 95 | return opt 96 | 97 | 98 | def dict2str(opt, indent_level=1): 99 | """dict to string for printing options. 100 | 101 | Args: 102 | opt (dict): Option dict. 103 | indent_level (int): Indent level. Default: 1. 104 | 105 | Return: 106 | (str): Option string for printing. 
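Example (illustrative; output spacing follows ``indent_level``):
            >>> s = dict2str({'scale': 1, 'network_g': {'type': 'EFNet'}})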
107 | """ 108 | msg = '\n' 109 | for k, v in opt.items(): 110 | if isinstance(v, dict): 111 | msg += ' ' * (indent_level * 2) + k + ':[' 112 | msg += dict2str(v, indent_level + 1) 113 | msg += ' ' * (indent_level * 2) + ']\n' 114 | else: 115 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 116 | return msg 117 | -------------------------------------------------------------------------------- /scripts/data_preparation/make_voxels_esim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import cv2 4 | import os 5 | from .raw_event_dataset import * 6 | from .noise_function import add_noise_to_voxel, put_hot_pixels_in_voxel_ 7 | import argparse 8 | import h5py 9 | import torch 10 | import time 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--input_path", default="/scratch/leisun/Datasets/GOPRO_fullsize_h5/test", help="Path to hdf5 file") 15 | parser.add_argument("--save_path", default="/scratch/e_work/GOPRO_SCER/test") 16 | parser.add_argument("--voxel_method", default="SCER_esim", help="SCER_esim, SCER_real_data") 17 | parser.add_argument("--add_noise", default=True, help="add noise (e.g., hot pixels) to the voxel grid") 18 | 19 | exposure_time = 1/240 20 | num_bins = 6 21 | num_pixels = 1280*720 22 | 23 | args = parser.parse_args() 24 | 25 | os.makedirs(args.save_path, exist_ok=True) 26 | 27 | def voxel2mask(voxel): 28 | mask_final = np.zeros_like(voxel[0, :, :]) 29 | mask = (voxel != 0) 30 | for i in range(mask.shape[0]): 31 | mask_final = np.logical_or(mask_final, mask[i, :, :]) 32 | # to uint8 image 33 | mask_img = mask_final * np.ones_like(mask_final) * 255 34 | mask_img = mask_img[..., np.newaxis] # H,W,C 35 | mask_img = np.uint8(mask_img) 36 | 37 | return mask_img 38 | 39 | def main(): 40 | # dataset settings 41 | # voxel_method = {'method': 'SCER_esim'} 42 | voxel_method = {'method': args.voxel_method} 43 | 44 | has_exposure_time = False # False for ESIM data, True for our real (Seems) data 45 | # no data augmentation 46 | data_augment = {} 47 | dataset_kwargs = {'transforms': data_augment, 'voxel_method': voxel_method, 'num_bins': num_bins, 48 | 'has_exposure_time': has_exposure_time, 'combined_voxel_channels': True, 'voxel_temporal_bilinear': True} # voxel_temporal_bilinear for voxel grid 49 | 50 | file_folder_path = args.input_path 51 | output_file_folder_path = args.save_path 52 | h5_file_paths = [os.path.join(file_folder_path, s) for s in os.listdir(file_folder_path)] 53 | output_h5_file_paths = [os.path.join(output_file_folder_path, s) for s in os.listdir(file_folder_path)] 54 | # for each h5_file 55 | for i in range(len(h5_file_paths)): 56 | print("processing file: {}".format(h5_file_paths[i])) 57 | # h5_file = h5py.File(h5_file_paths[i], 'a') 58 | h5_file = h5py.File(output_h5_file_paths[i],'a') 59 | # voxel_dset = h5_file.create_group("voxels") 60 | # N = int(os.path.basename(h5_file_paths[i]).split('_')[1]) # num of sharp imgs used to produce the blurry image 61 | dataset_kwargs.update({'data_path':h5_file_paths[i]}) 62 | dloader = GoproEsimH5Dataset(**dataset_kwargs) 63 | num_img = 0 64 | for item in dloader: 65 | voxel=item['voxel'] 66 | num_events = item['num_events'] 67 | blur = item['frame'] # C,H,W 68 | sharp = item['frame_gt'] 69 | 70 | # add noise to the voxel here 71 | if args.add_noise: 72 | # print("Add noise to voxels") 73 | voxel = add_noise_to_voxel(voxel, noise_std=1.0, noise_fraction=0.05) 74 | put_hot_pixels_in_voxel_(voxel, hot_pixel_range=20, 
hot_pixel_fraction=0.00002) 75 | 76 | voxel_np = voxel.numpy() # shape: bin+1,H,W 77 | blur_np=np.uint8(np.clip(255*blur.numpy(), 0, 255)) 78 | sharp_np=np.uint8(np.clip(255*sharp.numpy(), 0, 255)) 79 | mask_img = voxel2mask(voxel_np) 80 | #close filter 81 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)) 82 | mask_img_close = cv2.morphologyEx(mask_img, cv2.MORPH_CLOSE, kernel, iterations=1) 83 | # print(mask_img_close.shape) 84 | mask_img_close = mask_img_close[np.newaxis,...] # H,W -> C,H,W C=1 85 | 86 | # save to h5 image 87 | voxel_dset=h5_file.create_dataset("voxels/voxel{:09d}".format(num_img), data=voxel_np, dtype=np.dtype(np.float32)) 88 | image_dset=h5_file.create_dataset("images/image{:09d}".format(num_img), data=blur_np, dtype=np.dtype(np.uint8)) 89 | sharp_image_dset=h5_file.create_dataset("sharp_images/image{:09d}".format(num_img), data=sharp_np, dtype=np.dtype(np.uint8)) 90 | mask_dset=h5_file.create_dataset("masks/mask{:09d}".format(num_img), data=mask_img_close, dtype=np.dtype(np.uint8)) 91 | 92 | 93 | 94 | voxel_dset.attrs['size']=voxel_np.shape 95 | image_dset.attrs['size']=blur_np.shape 96 | sharp_image_dset.attrs['size']=sharp_np.shape 97 | mask_dset.attrs['size']=mask_img_close.shape 98 | 99 | 100 | num_img+=1 101 | sensor_resolution = [blur_np.shape[1], blur_np.shape[2]] 102 | h5_file.attrs['sensor_resolution'] = sensor_resolution 103 | 104 | 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | 110 | -------------------------------------------------------------------------------- /scripts/data_preparation/make_voxels_real.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import cv2 4 | import os 5 | from .raw_event_dataset import * 6 | from .noise_function import add_noise_to_voxel, put_hot_pixels_in_voxel_ 7 | import argparse 8 | import h5py 9 | import torch 10 | import time 11 | 12 | # num_bins = 6 # SCER_esim 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--input_path", default="/scratch/leisun/REBlur_h5", help="Path to hdf5 file") 15 | parser.add_argument("--save_path", default="/scratch/leisun/REBlur_SCER") 16 | 17 | parser.add_argument("--voxel_method", default="SCER_real_data", help="SCER_esim, SCER_real_data, SBT, All_accumulate=SBT + bin=1") 18 | parser.add_argument("--add_noise", default=False, help="add noise (e.g., hot pixels) to the voxel grid") 19 | 20 | has_exposure_time = True # False for ESIM data, True for our real (Seems) data 21 | exposure_time = 1/240 22 | num_bins = 6 23 | num_pixels = 1280*720 24 | has_gt_frame = False 25 | 26 | args = parser.parse_args() 27 | 28 | os.makedirs(args.save_path, exist_ok=True) 29 | 30 | def voxel2mask(voxel): 31 | mask_final = np.zeros_like(voxel[0, :, :]) 32 | mask = (voxel != 0) 33 | for i in range(mask.shape[0]): 34 | mask_final = np.logical_or(mask_final, mask[i, :, :]) 35 | # to uint8 image 36 | mask_img = mask_final * np.ones_like(mask_final) * 255 37 | mask_img = mask_img[..., np.newaxis] # H,W,C 38 | mask_img = np.uint8(mask_img) 39 | 40 | return mask_img 41 | 42 | def main(): 43 | # dataset settings 44 | # voxel_method = {'method': 'SCER_esim'} 45 | voxel_method = {'method': args.voxel_method} 46 | 47 | # no data augmentation 48 | data_augment = {} 49 | dataset_kwargs = {'transforms': data_augment, 'voxel_method': voxel_method, 'num_bins': num_bins, 'return_gt_frame':has_gt_frame, 50 | 'has_exposure_time': has_exposure_time, 'combined_voxel_channels': True, 'keep_middle':False } 51 | 52 | 
file_folder_path = args.input_path 53 | output_file_folder_path = args.save_path 54 | h5_file_paths = [os.path.join(file_folder_path, s) for s in os.listdir(file_folder_path)] 55 | output_h5_file_paths = [os.path.join(output_file_folder_path, s) for s in os.listdir(file_folder_path)] 56 | # for each h5_file 57 | for i in range(len(h5_file_paths)): 58 | print("processing file: {}".format(h5_file_paths[i])) 59 | # h5_file = h5py.File(h5_file_paths[i], 'a') 60 | h5_file = h5py.File(output_h5_file_paths[i],'a') 61 | 62 | dataset_kwargs.update({'data_path':h5_file_paths[i]}) 63 | dloader = SeemsH5Dataset(**dataset_kwargs) 64 | num_img = 0 65 | for item in dloader: 66 | voxel=item['voxel'] 67 | num_events = item['num_events'] 68 | blur = item['frame'] # C,H,W 69 | if has_gt_frame: 70 | sharp = item['frame_gt'] 71 | 72 | # 1,262,320 -> 1,260,320 73 | voxel = voxel[:,:-2,:] 74 | blur = blur[:,:-2,:] 75 | if has_gt_frame: 76 | sharp = sharp[:,:-2,:] 77 | 78 | 79 | # add noise to voxel Here 80 | if args.add_noise: 81 | # print("Add noisy to voxels") 82 | voxel = add_noise_to_voxel(voxel, noise_std=1.0, noise_fraction=0.05) 83 | put_hot_pixels_in_voxel_(voxel, hot_pixel_range=20, hot_pixel_fraction=0.00002) 84 | 85 | 86 | voxel_np = voxel.numpy() # shape: bin+1,H,W 87 | blur_np=np.uint8(np.clip(255*blur.numpy(), 0, 255)) 88 | if has_gt_frame: 89 | sharp_np=np.uint8(np.clip(255*sharp.numpy(), 0, 255)) 90 | 91 | mask_img = voxel2mask(voxel_np) 92 | #close filter 93 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)) 94 | mask_img_close = cv2.morphologyEx(mask_img, cv2.MORPH_CLOSE, kernel, iterations=1) 95 | # print(mask_img_close.shape) 96 | mask_img_close = mask_img_close[np.newaxis,...] # H,W -> C,H,W C=1 97 | 98 | # save to h5 image 99 | voxel_dset=h5_file.create_dataset("voxels/voxel{:09d}".format(num_img), data=voxel_np, dtype=np.dtype(np.float32)) 100 | image_dset=h5_file.create_dataset("images/image{:09d}".format(num_img), data=blur_np, dtype=np.dtype(np.uint8)) 101 | if has_gt_frame: 102 | 103 | sharp_image_dset=h5_file.create_dataset("sharp_images/image{:09d}".format(num_img), data=sharp_np, dtype=np.dtype(np.uint8)) 104 | 105 | mask_dset=h5_file.create_dataset("masks/mask{:09d}".format(num_img), data=mask_img_close, dtype=np.dtype(np.uint8)) 106 | 107 | 108 | 109 | voxel_dset.attrs['size']=voxel_np.shape 110 | image_dset.attrs['size']=blur_np.shape 111 | if has_gt_frame: 112 | sharp_image_dset.attrs['size']=sharp_np.shape 113 | mask_dset.attrs['size']=mask_img_close.shape 114 | 115 | 116 | num_img+=1 117 | sensor_resolution = [blur_np.shape[1], blur_np.shape[2]] 118 | h5_file.attrs['sensor_resolution'] = sensor_resolution 119 | 120 | 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | 126 | -------------------------------------------------------------------------------- /basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from functools import partial 7 | from os import path as osp 8 | 9 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 10 | from basicsr.utils import get_root_logger, scandir 11 | from basicsr.utils.dist_util import get_dist_info 12 | from basicsr.data.h5_image_dataset import * 13 | 14 | __all__ = ['create_dataset', 'create_dataloader'] 15 | 16 | # automatically scan and import dataset modules 17 | # scan all the files under the data folder with '_dataset' in file names 18 | 
data_folder = osp.dirname(osp.abspath(__file__)) 19 | dataset_filenames = [ 20 | osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) 21 | if v.endswith('_dataset.py') 22 | ] 23 | # import all the dataset modules 24 | _dataset_modules = [ 25 | importlib.import_module(f'basicsr.data.{file_name}') 26 | for file_name in dataset_filenames 27 | ] 28 | 29 | 30 | def create_dataset(dataset_opt): 31 | """Create dataset. 32 | 33 | Args: 34 | dataset_opt (dict): Configuration for dataset. It contains: 35 | name (str): Dataset name. 36 | type (str): Dataset type. 37 | """ 38 | dataset_type = dataset_opt['type'] 39 | 40 | # dynamic instantiation 41 | for module in _dataset_modules: 42 | dataset_cls = getattr(module, dataset_type, None) 43 | if dataset_cls is not None: 44 | break 45 | if dataset_cls is None: 46 | raise ValueError(f'Dataset {dataset_type} is not found.') 47 | 48 | if dataset_type == "H5ImageDataset": 49 | dataset = concatenate_h5_datasets(dataset_cls, dataset_opt) 50 | 51 | else: 52 | dataset = dataset_cls(dataset_opt) 53 | 54 | logger = get_root_logger() 55 | logger.info( 56 | f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} ' 57 | 'is created.') 58 | return dataset 59 | 60 | 61 | 62 | def create_dataloader(dataset, 63 | dataset_opt, 64 | num_gpu=1, 65 | dist=False, 66 | sampler=None, 67 | seed=None): 68 | """Create dataloader. 69 | 70 | Args: 71 | dataset (torch.utils.data.Dataset): Dataset. 72 | dataset_opt (dict): Dataset options. It contains the following keys: 73 | phase (str): 'train' or 'val'. 74 | num_worker_per_gpu (int): Number of workers for each GPU. 75 | batch_size_per_gpu (int): Training batch size for each GPU. 76 | num_gpu (int): Number of GPUs. Used only in the train phase. 77 | Default: 1. 78 | dist (bool): Whether in distributed training. Used only in the train 79 | phase. Default: False. 80 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 81 | seed (int | None): Seed. Default: None 82 | """ 83 | phase = dataset_opt['phase'] 84 | rank, _ = get_dist_info() 85 | if phase == 'train': 86 | if dist: # distributed training 87 | batch_size = dataset_opt['batch_size_per_gpu'] 88 | num_workers = dataset_opt['num_worker_per_gpu'] 89 | else: # non-distributed training 90 | multiplier = 1 if num_gpu == 0 else num_gpu 91 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 92 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 93 | dataloader_args = dict( 94 | dataset=dataset, 95 | batch_size=batch_size, 96 | shuffle=False, 97 | num_workers=num_workers, 98 | sampler=sampler, 99 | drop_last=True) 100 | if sampler is None: 101 | dataloader_args['shuffle'] = True 102 | dataloader_args['worker_init_fn'] = partial( 103 | worker_init_fn, num_workers=num_workers, rank=rank, 104 | seed=seed) if seed is not None else None 105 | elif phase in ['val', 'test']: # validation 106 | dataloader_args = dict( 107 | dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 108 | else: 109 | raise ValueError(f'Wrong dataset phase: {phase}. 
' 110 | "Supported ones are 'train', 'val' and 'test'.") 111 | 112 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 113 | 114 | prefetch_mode = dataset_opt.get('prefetch_mode') 115 | if prefetch_mode == 'cpu': # CPUPrefetcher 116 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 117 | logger = get_root_logger() 118 | logger.info(f'Use {prefetch_mode} prefetch dataloader: ' 119 | f'num_prefetch_queue = {num_prefetch_queue}') 120 | return PrefetchDataLoader( 121 | num_prefetch_queue=num_prefetch_queue, **dataloader_args) 122 | else: 123 | # prefetch_mode=None: Normal dataloader 124 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 125 | return torch.utils.data.DataLoader(**dataloader_args) 126 | 127 | 128 | def worker_init_fn(worker_id, num_workers, rank, seed): 129 | # Set the worker seed to num_workers * rank + worker_id + seed 130 | worker_seed = num_workers * rank + worker_id + seed 131 | np.random.seed(worker_seed) 132 | random.seed(worker_seed) 133 | -------------------------------------------------------------------------------- /basicsr/utils/create_lmdb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os import path as osp 3 | import cv2, os, scipy.io as scio  # added: used by create_lmdb_for_SIDD below 4 | from basicsr.utils import scandir 5 | from basicsr.utils.lmdb_util import make_lmdb_from_imgs 6 | from tqdm import tqdm 7 | def prepare_keys(folder_path, suffix='png'): 8 | """Prepare image path list and keys for a dataset folder. 9 | 10 | Args: 11 | folder_path (str): Folder path. 12 | 13 | Returns: 14 | list[str]: Image path list. 15 | list[str]: Key list. 16 | """ 17 | print('Reading image path list ...') 18 | img_path_list = sorted( 19 | list(scandir(folder_path, suffix=suffix, recursive=False))) 20 | keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)] 21 | 22 | return img_path_list, keys 23 | 24 | def create_lmdb_for_reds(): 25 | folder_path = './datasets/REDS/val/sharp_300' 26 | lmdb_path = './datasets/REDS/val/sharp_300.lmdb' 27 | img_path_list, keys = prepare_keys(folder_path, 'png') 28 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 29 | # 30 | folder_path = './datasets/REDS/val/blur_300' 31 | lmdb_path = './datasets/REDS/val/blur_300.lmdb' 32 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 33 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 34 | 35 | folder_path = './datasets/REDS/train/train_sharp' 36 | lmdb_path = './datasets/REDS/train/train_sharp.lmdb' 37 | img_path_list, keys = prepare_keys(folder_path, 'png') 38 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 39 | 40 | folder_path = './datasets/REDS/train/train_blur_jpeg' 41 | lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb' 42 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 43 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 44 | 45 | 46 | def create_lmdb_for_gopro(): 47 | folder_path = './datasets/GoPro/train/blur_crops' 48 | lmdb_path = './datasets/GoPro/train/blur_crops.lmdb' 49 | 50 | img_path_list, keys = prepare_keys(folder_path, 'png') 51 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 52 | 53 | folder_path = './datasets/GoPro/train/sharp_crops' 54 | lmdb_path = './datasets/GoPro/train/sharp_crops.lmdb' 55 | 56 | img_path_list, keys = prepare_keys(folder_path, 'png') 57 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 58 | 59 | folder_path = './datasets/GoPro/test/target' 60 | lmdb_path = './datasets/GoPro/test/target.lmdb' 61 | 62 | 
img_path_list, keys = prepare_keys(folder_path, 'png') 63 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 64 | 65 | folder_path = './datasets/GoPro/test/input' 66 | lmdb_path = './datasets/GoPro/test/input.lmdb' 67 | 68 | img_path_list, keys = prepare_keys(folder_path, 'png') 69 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 70 | 71 | def create_lmdb_for_rain13k(): 72 | folder_path = './datasets/Rain13k/train/input' 73 | lmdb_path = './datasets/Rain13k/train/input.lmdb' 74 | 75 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 76 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 77 | 78 | folder_path = './datasets/Rain13k/train/target' 79 | lmdb_path = './datasets/Rain13k/train/target.lmdb' 80 | 81 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 82 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 83 | 84 | def create_lmdb_for_SIDD(): 85 | folder_path = './datasets/SIDD/train/input_crops' 86 | lmdb_path = './datasets/SIDD/train/input_crops.lmdb' 87 | 88 | img_path_list, keys = prepare_keys(folder_path, 'PNG') 89 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 90 | 91 | folder_path = './datasets/SIDD/train/gt_crops' 92 | lmdb_path = './datasets/SIDD/train/gt_crops.lmdb' 93 | 94 | img_path_list, keys = prepare_keys(folder_path, 'PNG') 95 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 96 | 97 | #for val 98 | folder_path = './datasets/SIDD/val/input_crops' 99 | lmdb_path = './datasets/SIDD/val/input_crops.lmdb' 100 | mat_path = './datasets/SIDD/ValidationNoisyBlocksSrgb.mat' 101 | if not osp.exists(folder_path): 102 | os.makedirs(folder_path) 103 | assert osp.exists(mat_path) 104 | data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb'] 105 | N, B, H ,W, C = data.shape 106 | data = data.reshape(N*B, H, W, C) 107 | for i in tqdm(range(N*B)): 108 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 109 | img_path_list, keys = prepare_keys(folder_path, 'png') 110 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 111 | 112 | folder_path = './datasets/SIDD/val/gt_crops' 113 | lmdb_path = './datasets/SIDD/val/gt_crops.lmdb' 114 | mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat' 115 | if not osp.exists(folder_path): 116 | os.makedirs(folder_path) 117 | assert osp.exists(mat_path) 118 | data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb'] 119 | N, B, H ,W, C = data.shape 120 | data = data.reshape(N*B, H, W, C) 121 | for i in tqdm(range(N*B)): 122 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 123 | img_path_list, keys = prepare_keys(folder_path, 'png') 124 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 125 | -------------------------------------------------------------------------------- /basicsr/data/voxelnpz_image_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data as data 2 | from torchvision.transforms.functional import normalize 3 | from tqdm import tqdm 4 | import os 5 | from pathlib import Path 6 | import random 7 | import numpy as np 8 | import torch 9 | 10 | from basicsr.data.data_util import (paired_paths_from_folder, 11 | paired_paths_from_lmdb, 12 | paired_paths_from_meta_info_file, 13 | recursive_glob) 14 | from basicsr.data.event_util import events_to_voxel_grid, voxel_norm 15 | from 
basicsr.data.transforms import augment, triple_random_crop, random_augmentation 16 | from basicsr.utils import FileClient, imfrombytes, img2tensor, padding, get_root_logger 17 | from torch.utils.data.dataloader import default_collate 18 | 19 | 20 | class VoxelnpzPngSingleDeblurDataset(data.Dataset): 21 | """Paired voxel (npz) and blurry image (png) dataset for event-based single image deblurring. 22 | --HighREV 23 | |----train 24 | | |----blur 25 | | | |----SEQNAME_%5d.png 26 | | | |----... 27 | | |----voxel 28 | | | |----SEQNAME_%5d.npz 29 | | | |----... 30 | | |----sharp 31 | | | |----SEQNAME_%5d.png 32 | | | |----... 33 | |----val 34 | ... 35 | 36 | 37 | Args: 38 | opt (dict): Config for train dataset. It contains the following keys: 39 | dataroot (str): Data root path. 40 | io_backend (dict): IO backend type and other kwarg. 41 | num_end_interpolation (int): Number of sharp frames to reconstruct in each blurry image. 42 | num_inter_interpolation (int): Number of sharp frames to interpolate between two blurry images. 43 | phase (str): 'train' or 'test' 44 | 45 | gt_size (int): Cropped patch size for gt patches. 46 | random_reverse (bool): Random reverse input frames. 47 | use_hflip (bool): Use horizontal flips. 48 | use_rot (bool): Use rotation (use vertical flip and transposing h 49 | and w for implementation). 50 | 51 | scale (bool): Scale, which will be added automatically. 52 | """ 53 | 54 | def __init__(self, opt): 55 | super(VoxelnpzPngSingleDeblurDataset, self).__init__() 56 | self.opt = opt 57 | self.dataroot = Path(opt['dataroot']) 58 | self.dataroot_voxel = Path(opt['dataroot_voxel']) 59 | self.split = 'train' if opt['phase'] == 'train' else 'val' # train or val 60 | self.norm_voxel = opt['norm_voxel'] 61 | self.dataPath = [] 62 | 63 | blur_frames = sorted(recursive_glob(rootdir=os.path.join(self.dataroot, 'blur'), suffix='.png')) 64 | blur_frames = [os.path.join(self.dataroot, 'blur', blur_frame) for blur_frame in blur_frames] 65 | 66 | sharp_frames = sorted(recursive_glob(rootdir=os.path.join(self.dataroot, 'sharp'), suffix='.png')) 67 | sharp_frames = [os.path.join(self.dataroot, 'sharp', sharp_frame) for sharp_frame in sharp_frames] 68 | 69 | event_frames = sorted(recursive_glob(rootdir=self.dataroot_voxel, suffix='.npz')) 70 | event_frames = [os.path.join(self.dataroot_voxel, event_frame) for event_frame in event_frames] 71 | 72 | assert len(blur_frames) == len(sharp_frames) == len(event_frames), f"Mismatch in blur ({len(blur_frames)}), sharp ({len(sharp_frames)}), and event ({len(event_frames)}) frame counts."
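        # NOTE: sample pairing relies on the three sorted listings lining up
        # index-for-index (blur/SEQ_x.png <-> sharp/SEQ_x.png <-> voxel/SEQ_x.npz,
        # per the layout sketched in the class docstring); only the counts are
        # checked above, matching basenames are assumed.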
73 | 74 | for i in range(len(blur_frames)): 75 | self.dataPath.append({ 76 | 'blur_path': blur_frames[i], 77 | 'sharp_path': sharp_frames[i], 78 | 'event_paths': event_frames[i], 79 | }) 80 | logger = get_root_logger() 81 | logger.info(f"Dataset initialized with {len(self.dataPath)} samples.") 82 | 83 | # file client (io backend) 84 | self.file_client = None 85 | self.io_backend_opt = opt['io_backend'] 86 | # import pdb; pdb.set_trace() 87 | 88 | def __getitem__(self, index): 89 | if self.file_client is None: 90 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 91 | scale = self.opt['scale'] 92 | gt_size = self.opt['gt_size'] 93 | 94 | image_path = self.dataPath[index]['blur_path'] 95 | gt_path = self.dataPath[index]['sharp_path'] 96 | event_path = self.dataPath[index]['event_paths'] 97 | 98 | # get LQ 99 | img_bytes = self.file_client.get(image_path) # 'lq' 100 | img_lq = imfrombytes(img_bytes, float32=True) 101 | # get GT 102 | img_bytes = self.file_client.get(gt_path) # 'gt' 103 | img_gt = imfrombytes(img_bytes, float32=True) 104 | 105 | voxel = np.load(event_path)['voxel'] 106 | 107 | ## Data augmentation 108 | # voxel shape: h,w,c 109 | # crop 110 | if gt_size is not None: 111 | img_gt, img_lq, voxel = triple_random_crop(img_gt, img_lq, voxel, gt_size, scale, gt_path) 112 | 113 | # flip, rotate 114 | total_input = [img_lq, img_gt, voxel] 115 | img_results = augment(total_input, self.opt['use_hflip'], self.opt['use_rot']) 116 | 117 | img_results = img2tensor(img_results) # hwc -> chw 118 | img_lq, img_gt, voxel = img_results 119 | 120 | ## Norm voxel 121 | if self.norm_voxel: 122 | voxel = voxel_norm(voxel) 123 | 124 | origin_index = os.path.basename(image_path).split('.')[0] 125 | 126 | return {'frame': img_lq, 'frame_gt': img_gt, 'voxel': voxel, 'image_name': origin_index} 127 | 128 | def __len__(self): 129 | return len(self.dataPath) 130 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | import os 6 | import subprocess 7 | import sys 8 | import time 9 | import torch 10 | from torch.utils.cpp_extension import (BuildExtension, CppExtension, 11 | CUDAExtension) 12 | 13 | version_file = 'basicsr/version.py' 14 | 15 | 16 | def readme(): 17 | return '' 18 | # with open('README.md', encoding='utf-8') as f: 19 | # content = f.read() 20 | # return content 21 | 22 | 23 | def get_git_hash(): 24 | 25 | def _minimal_ext_cmd(cmd): 26 | # construct minimal environment 27 | env = {} 28 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 29 | v = os.environ.get(k) 30 | if v is not None: 31 | env[k] = v 32 | # LANGUAGE is used on win32 33 | env['LANGUAGE'] = 'C' 34 | env['LANG'] = 'C' 35 | env['LC_ALL'] = 'C' 36 | out = subprocess.Popen( 37 | cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 38 | return out 39 | 40 | try: 41 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 42 | sha = out.strip().decode('ascii') 43 | except OSError: 44 | sha = 'unknown' 45 | 46 | return sha 47 | 48 | 49 | def get_hash(): 50 | if os.path.exists('.git'): 51 | sha = get_git_hash()[:7] 52 | elif os.path.exists(version_file): 53 | try: 54 | from basicsr.version import __version__ 55 | sha = __version__.split('+')[-1] 56 | except ImportError: 57 | raise ImportError('Unable to get git version') 58 | else: 59 | sha = 'unknown' 60 | 61 | return sha 62 | 63 | 64 | def 
write_version_py(): 65 | content = """# GENERATED VERSION FILE 66 | # TIME: {} 67 | __version__ = '{}' 68 | short_version = '{}' 69 | version_info = ({}) 70 | """ 71 | sha = get_hash() 72 | with open('VERSION', 'r') as f: 73 | SHORT_VERSION = f.read().strip() 74 | VERSION_INFO = ', '.join( 75 | [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) 76 | VERSION = SHORT_VERSION + '+' + sha 77 | 78 | version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION, 79 | VERSION_INFO) 80 | with open(version_file, 'w') as f: 81 | f.write(version_file_str) 82 | 83 | 84 | def get_version(): 85 | with open(version_file, 'r') as f: 86 | exec(compile(f.read(), version_file, 'exec')) 87 | return locals()['__version__'] 88 | 89 | 90 | def make_cuda_ext(name, module, sources, sources_cuda=None): 91 | if sources_cuda is None: 92 | sources_cuda = [] 93 | define_macros = [] 94 | extra_compile_args = {'cxx': []} 95 | 96 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 97 | define_macros += [('WITH_CUDA', None)] 98 | extension = CUDAExtension 99 | extra_compile_args['nvcc'] = [ 100 | '-D__CUDA_NO_HALF_OPERATORS__', 101 | '-D__CUDA_NO_HALF_CONVERSIONS__', 102 | '-D__CUDA_NO_HALF2_OPERATORS__', 103 | ] 104 | sources += sources_cuda 105 | else: 106 | print(f'Compiling {name} without CUDA') 107 | extension = CppExtension 108 | 109 | return extension( 110 | name=f'{module}.{name}', 111 | sources=[os.path.join(*module.split('.'), p) for p in sources], 112 | define_macros=define_macros, 113 | extra_compile_args=extra_compile_args) 114 | 115 | 116 | def get_requirements(filename='requirements.txt'): 117 | return [] 118 | here = os.path.dirname(os.path.realpath(__file__)) 119 | with open(os.path.join(here, filename), 'r') as f: 120 | requires = [line.replace('\n', '') for line in f.readlines()] 121 | return requires 122 | 123 | 124 | if __name__ == '__main__': 125 | if '--no_cuda_ext' in sys.argv: 126 | ext_modules = [] 127 | sys.argv.remove('--no_cuda_ext') 128 | else: 129 | ext_modules = [ 130 | make_cuda_ext( 131 | name='deform_conv_ext', 132 | module='basicsr.models.ops.dcn', 133 | sources=['src/deform_conv_ext.cpp'], 134 | sources_cuda=[ 135 | 'src/deform_conv_cuda.cpp', 136 | 'src/deform_conv_cuda_kernel.cu' 137 | ]), 138 | make_cuda_ext( 139 | name='fused_act_ext', 140 | module='basicsr.models.ops.fused_act', 141 | sources=['src/fused_bias_act.cpp'], 142 | sources_cuda=['src/fused_bias_act_kernel.cu']), 143 | make_cuda_ext( 144 | name='upfirdn2d_ext', 145 | module='basicsr.models.ops.upfirdn2d', 146 | sources=['src/upfirdn2d.cpp'], 147 | sources_cuda=['src/upfirdn2d_kernel.cu']), 148 | ] 149 | 150 | write_version_py() 151 | setup( 152 | name='basicsr', 153 | version=get_version(), 154 | description='Open Source Image and Video Super-Resolution Toolbox', 155 | long_description=readme(), 156 | author='Xintao Wang', 157 | author_email='xintao.wang@outlook.com', 158 | keywords='computer vision, restoration, super resolution', 159 | url='https://github.com/xinntao/BasicSR', 160 | packages=find_packages( 161 | exclude=('options', 'datasets', 'experiments', 'results', 162 | 'tb_logger', 'wandb')), 163 | classifiers=[ 164 | 'Development Status :: 4 - Beta', 165 | 'License :: OSI Approved :: Apache Software License', 166 | 'Operating System :: OS Independent', 167 | 'Programming Language :: Python :: 3', 168 | 'Programming Language :: Python :: 3.7', 169 | 'Programming Language :: Python :: 3.8', 170 | ], 171 | license='Apache License 2.0', 172 | setup_requires=['cython', 
'numpy'], 173 | install_requires=get_requirements(), 174 | ext_modules=ext_modules, 175 | cmdclass={'build_ext': BuildExtension}, 176 | zip_safe=False) 177 | -------------------------------------------------------------------------------- /scripts/download_pretrained_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os import path as osp 4 | 5 | from basicsr.utils.download_util import download_file_from_google_drive 6 | 7 | 8 | def download_pretrained_models(method, file_ids): 9 | save_path_root = f'./experiments/pretrained_models/{method}' 10 | os.makedirs(save_path_root, exist_ok=True) 11 | 12 | for file_name, file_id in file_ids.items(): 13 | save_path = osp.abspath(osp.join(save_path_root, file_name)) 14 | if osp.exists(save_path): 15 | user_response = input( 16 | f'{file_name} already exists. Do you want to overwrite it? Y/N\n') 17 | if user_response.lower() == 'y': 18 | print(f'Overwriting {file_name} at {save_path}') 19 | download_file_from_google_drive(file_id, save_path) 20 | elif user_response.lower() == 'n': 21 | print(f'Skipping {file_name}') 22 | else: 23 | raise ValueError('Wrong input. Only accepts Y/N.') 24 | else: 25 | print(f'Downloading {file_name} to {save_path}') 26 | download_file_from_google_drive(file_id, save_path) 27 | 28 | 29 | if __name__ == '__main__': 30 | parser = argparse.ArgumentParser() 31 | 32 | parser.add_argument( 33 | 'method', 34 | type=str, 35 | help=( 36 | "Options: 'ESRGAN', 'EDVR', 'StyleGAN', 'EDSR', 'DUF', 'DFDNet', " 37 | "'dlib'. Set to 'all' if you want to download all the models.")) 38 | args = parser.parse_args() 39 | 40 | file_ids = { 41 | 'ESRGAN': { 42 | 'ESRGAN_SRx4_DF2KOST_official-ff704c30.pth': # file name 43 | '1b3_bWZTjNO3iL2js1yWkJfjZykcQgvzT', # file id 44 | 'ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth': 45 | '1swaV5iBMFfg-DL6ZyiARztbhutDCWXMM' 46 | }, 47 | 'EDVR': { 48 | 'EDVR_L_x4_SR_REDS_official-9f5f5039.pth': 49 | '127KXEjlCwfoPC1aXyDkluNwr9elwyHNb', 50 | 'EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth': 51 | '1aVR3lkX6ItCphNLcT7F5bbbC484h4Qqy', 52 | 'EDVR_M_woTSA_x4_SR_REDS_official-1edf645c.pth': 53 | '1C_WdN-NyNj-P7SOB5xIVuHl4EBOwd-Ny', 54 | 'EDVR_M_x4_SR_REDS_official-32075921.pth': 55 | '1dd6aFj-5w2v08VJTq5mS9OFsD-wALYD6', 56 | 'EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth': 57 | '1GZz_87ybR8eAAY3X2HWwI3L6ny7-5Yvl', 58 | 'EDVR_L_deblur_REDS_official-ca46bd8c.pth': 59 | '1_ma2tgHscZtkIY2tEJkVdU-UP8bnqBRE', 60 | 'EDVR_L_deblurcomp_REDS_official-0e988e5c.pth': 61 | '1fEoSeLFnHSBbIs95Au2W197p8e4ws4DW' 62 | }, 63 | 'StyleGAN': { 64 | 'stylegan2_ffhq_config_f_1024_official-b09c3668.pth': 65 | '163PfuVSYKh4vhkYkfEaufw84CiF4pvWG', 66 | 'stylegan2_ffhq_config_f_1024_discriminator_official-806ddc5e.pth': 67 | '1wyOdcJnMtAT_fEwXYJObee7hcLzI8usT', 68 | 'stylegan2_cat_config_f_256_official-b82c74e3.pth': 69 | '1dGUvw8FLch50FEDAgAa6st1AXGnjduc7', 70 | 'stylegan2_cat_config_f_256_discriminator_official-f6f5ed5c.pth': 71 | '19wuj7Ztg56QtwEs01-p_LjQeoz6G11kF', 72 | 'stylegan2_church_config_f_256_official-12725a53.pth': 73 | '1Rcpguh4t833wHlFrWz9UuqFcSYERyd2d', 74 | 'stylegan2_church_config_f_256_discriminator_official-feba65b0.pth': # noqa: E501 75 | '1ImOfFUOwKqDDKZCxxM4VUdPQCc-j85Z9', 76 | 'stylegan2_car_config_f_512_official-32c42d4e.pth': 77 | '1FviBGvzORv4T3w0c3m7BaIfLNeEd0dC8', 78 | 'stylegan2_car_config_f_512_discriminator_official-31f302ab.pth': 79 | '1hlZ7M2GrK6cDFd2FIYazPxOZXTUfudB3', 80 | 'stylegan2_horse_config_f_256_official-d3d97ebc.pth': 81 | 
'1LV4OR22tJN19HHfGk0e7dVqMhjD0APRm', 82 | 'stylegan2_horse_config_f_256_discriminator_official-efc5e50e.pth': 83 | '1T8xbI-Tz8EeSg3gCmQBNqGjLP5l3Mv84' 84 | }, 85 | 'EDSR': { 86 | 'EDSR_Mx2_f64b16_DIV2K_official-3ba7b086.pth': 87 | '1mREMGVDymId3NzIc2u90sl_X4-pb4ZcV', 88 | 'EDSR_Mx3_f64b16_DIV2K_official-6908f88a.pth': 89 | '1EriqQqlIiRyPbrYGBbwr_FZzvb3iwqz5', 90 | 'EDSR_Mx4_f64b16_DIV2K_official-0c287733.pth': 91 | '1bCK6cFYU01uJudLgUUe-jgx-tZ3ikOWn', 92 | 'EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth': 93 | '15257lZCRZ0V6F9LzTyZFYbbPrqNjKyMU', 94 | 'EDSR_Lx3_f256b32_DIV2K_official-3660f70d.pth': 95 | '18q_D434sLG_rAZeHGonAX8dkqjoyZ2su', 96 | 'EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth': 97 | '1GCi30YYCzgMCcgheGWGusP9aWKOAy5vl' 98 | }, 99 | 'DUF': { 100 | 'DUF_x2_16L_official-39537cb9.pth': 101 | '1e91cEZOlUUk35keK9EnuK0F54QegnUKo', 102 | 'DUF_x3_16L_official-34ce53ec.pth': 103 | '1XN6aQj20esM7i0hxTbfiZr_SL8i4PZ76', 104 | 'DUF_x4_16L_official-bf8f0cfa.pth': 105 | '1V_h9U1CZgLSHTv1ky2M3lvuH-hK5hw_J', 106 | 'DUF_x4_28L_official-cbada450.pth': 107 | '1M8w0AMBJW65MYYD-_8_be0cSH_SHhDQ4', 108 | 'DUF_x4_52L_official-483d2c78.pth': 109 | '1GcmEWNr7mjTygi-QCOVgQWOo5OCNbh_T' 110 | }, 111 | 'DFDNet': { 112 | 'DFDNet_dict_512-f79685f0.pth': 113 | '1iH00oMsoN_1OJaEQw3zP7_wqiAYMnY79', 114 | 'DFDNet_official-d1fa5650.pth': 115 | '1u6Sgcp8gVoy4uVTrOJKD3y9RuqH2JBAe' 116 | }, 117 | 'dlib': { 118 | 'mmod_human_face_detector-4cb19393.dat': 119 | '1FUM-hcoxNzFCOpCWbAUStBBMiU4uIGIL', 120 | 'shape_predictor_5_face_landmarks-c4b1e980.dat': 121 | '1PNPSmFjmbuuUDd5Mg5LDxyk7tu7TQv2F', 122 | 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': 123 | '1IneH-O-gNkG0SQpNCplwxtOAtRCkG2ni' 124 | } 125 | } 126 | 127 | if args.method == 'all': 128 | for method in file_ids.keys(): 129 | download_pretrained_models(method, file_ids[method]) 130 | else: 131 | download_pretrained_models(args.method, file_ids[args.method]) 132 | -------------------------------------------------------------------------------- /basicsr/utils/npz2voxel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import time 4 | 5 | def events_to_voxel_grid(events, num_bins, width, height, return_format='CHW'): 6 | """ 7 | Build a voxel grid with bilinear interpolation in the time domain from a set of events. 
8 | 9 | :param events: a [N x 4] NumPy array containing one event per row in the form: [timestamp, x, y, polarity] 10 | :param num_bins: number of bins in the temporal axis of the voxel grid 11 | :param width, height: dimensions of the voxel grid 12 | :param return_format: 'CHW' or 'HWC' 13 | """ 14 | 15 | assert(events.shape[1] == 4) 16 | assert(num_bins > 0) 17 | assert(width > 0) 18 | assert(height > 0) 19 | 20 | voxel_grid = np.zeros((num_bins, height, width), np.float32).ravel() 22 | 23 | # scale the event timestamps so that they lie in [0, num_bins - 1] 24 | last_stamp = events[-1, 0] 29 | 30 | first_stamp = events[0, 0] 31 | deltaT = last_stamp - first_stamp 32 | 33 | if deltaT == 0: 34 | deltaT = 1.0 35 | 36 | events[:, 0] = (num_bins - 1) * (events[:, 0] - first_stamp) / deltaT 37 | ts = events[:, 0] 38 | xs = events[:, 1].astype(int) 39 | ys = events[:, 2].astype(int) 40 | pols = events[:, 3] 41 | pols[pols == 0] = -1 # polarity should be +1 / -1 42 | 43 | tis = ts.astype(int) 44 | dts = ts - tis 45 | vals_left = pols * (1.0 - dts) 46 | vals_right = pols * dts 47 | 48 | valid_indices = tis < num_bins 52 | 53 | np.add.at(voxel_grid, xs[valid_indices] + ys[valid_indices] * width # flat index: x + y * width + t * width * height 54 | + tis[valid_indices] * width * height, vals_left[valid_indices]) 55 | 56 | valid_indices = (tis + 1) < num_bins 57 | np.add.at(voxel_grid, xs[valid_indices] + ys[valid_indices] * width 58 | + (tis[valid_indices] + 1) * width * height, vals_right[valid_indices]) 59 | 60 | voxel_grid = np.reshape(voxel_grid, (num_bins, height, width)) 61 | 62 | if return_format == 'CHW': 63 | return voxel_grid 64 | elif return_format == 'HWC': 65 | return voxel_grid.transpose(1,2,0) 66 | 
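# A tiny worked example of the bilinear time-splatting above (synthetic
# events, not dataset values). With num_bins=2, width=height=2 and
#     events = np.array([[0.0, 0, 0, 1],      # [t, x, y, p]
#                        [1.0, 1, 1, 0]], dtype=np.float32)
# the timestamps are rescaled to [0, 1]; the first event falls entirely into
# bin 0 at (y=0, x=0) with weight +1, and the second (polarity 0 -> -1)
# entirely into bin 1 at (y=1, x=1). An event at rescaled time t=0.3 would
# instead be split bilinearly: 0.7 * p into bin 0 and 0.3 * p into bin 1.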
67 | def recursive_glob(rootdir='.', suffix=''): 68 | """Performs a recursive glob over rootdir, returning bare file names (not full paths). 69 | :param rootdir: the root directory to walk 70 | :param suffix: the file suffix to match 71 | """ 75 | return [filename 76 | for looproot, _, filenames in os.walk(rootdir) 77 | for filename in filenames if filename.endswith(suffix)] 78 | 79 | 80 | def main(): 81 | 82 | # dataroot = '/work/lei_sun/HighREV/val' # switch to the val split here 83 | dataroot = '/work/lei_sun/HighREV/train' 84 | 85 | voxel_root = dataroot + '/voxel' 86 | if not os.path.exists(voxel_root): 87 | os.makedirs(voxel_root) 88 | voxel_bins = 6 89 | blur_frames = sorted(recursive_glob(rootdir=os.path.join(dataroot, 'blur'), suffix='.png')) 90 | blur_frames = [os.path.join(dataroot, 'blur', blur_frame) for blur_frame in blur_frames] 91 | h_lq, w_lq = 1224, 1632 92 | event_frames = sorted(recursive_glob(rootdir=os.path.join(dataroot, 'event'), suffix='.npz')) 93 | event_frames = [os.path.join(dataroot, 'event', event_frame) for event_frame in event_frames] 94 | 95 | all_event_lists = [] 96 | for i in range(len(blur_frames)): 97 | blur_name = os.path.basename(blur_frames[i]) # e.g., SEQNAME_00001.png 98 | base_name = os.path.splitext(blur_name)[0] # remove .png, keep SEQNAME_00001 99 | event_list = sorted([f for f in event_frames if f.startswith(os.path.join(dataroot, 'event', base_name + '_'))])[:-1] # drop the last chunk: it lies outside the exposure time 100 | all_event_lists.append(event_list) 101 | 102 | total_time = 0 103 | num_events_processed = 0 104 | total_events = sum(len(lst) for lst in all_event_lists) # total number of event files 105 | print("loop started") 106 | start_time = time.time() 107 | for event_list in all_event_lists: 108 | loop_start_time = time.time() 109 | 110 | events = [np.load(event_path) for event_path in event_list] 111 | parts = event_list[0].rsplit('_', 1) 112 | base_name = parts[0].split('/')[-1] 113 | print(base_name, 'start', event_list[0]) 114 | all_quad_event_array = np.zeros((0,4)).astype(np.float32) 115 | for event in events: 116 | # IMPORTANT: the dataset stores x and y swapped, 117 | # so swap them back here. 118 | y = event['x'].astype(np.float32) 119 | x = event['y'].astype(np.float32) 120 | t = event['timestamp'].astype(np.float32) 121 | p = event['polarity'].astype(np.float32) 122 | 123 | this_quad_event_array = np.concatenate((t,x,y,p),axis=1) # N,4 124 | all_quad_event_array = np.concatenate((all_quad_event_array, this_quad_event_array), axis=0) 125 | voxel = events_to_voxel_grid(all_quad_event_array, num_bins=voxel_bins, width=w_lq, height=h_lq, return_format='HWC') 126 | # voxel: float32, (1224, 1632, 6) 127 | 128 | voxel_path = os.path.join(voxel_root, base_name + '.npz') 132 | 133 | np.savez(voxel_path, voxel=voxel) 134 | 135 | loop_end_time = time.time() 136 | loop_duration = loop_end_time - loop_start_time 137 | total_time += loop_duration 138 | num_events_processed += len(events) 139 | 140 | print(f"Loop {num_events_processed}/{total_events} took {loop_duration:.2f} seconds.") 141 | print("") 142 | 143 | 144 | end_time = time.time() 145 | total_duration = end_time - start_time 146 | print(f"Total time taken: {total_duration:.2f} seconds.") 147 | 148 | 149 | if __name__ == '__main__': 150 | main() -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import time 5 | import torch 6 | from os import path as osp 7 | 8 | from .dist_util import master_only 9 | from .logger import get_root_logger 10 | 11 | 12 | def set_random_seed(seed): 13 | """Set random seeds.""" 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | 20 | 21 | def get_time_str(): 22 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 23 | 24 | 25 | def mkdir_and_rename(path): 26 | """Make a directory. If the path exists, rename it with a timestamp and create a new one. 27 | 28 | Args: 29 | path (str): Folder path. 30 | """ 31 | if osp.exists(path): 32 | new_name = path + '_archived_' + get_time_str() 33 | print(f'Path already exists. Rename it to {new_name}', flush=True) 34 | os.rename(path, new_name) 35 | os.makedirs(path, exist_ok=True) 36 | 
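# Behavior sketch for mkdir_and_rename (hypothetical path): if
# 'experiments/exp1' already exists, it is renamed to something like
# 'experiments/exp1_archived_20240101_120000' and a fresh, empty
# 'experiments/exp1' is created in its place.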
37 | 38 | @master_only 39 | def make_exp_dirs(opt): 40 | """Make dirs for experiments.""" 41 | path_opt = opt['path'].copy() 42 | if opt['is_train']: 43 | mkdir_and_rename(path_opt.pop('experiments_root')) 44 | else: 45 | mkdir_and_rename(path_opt.pop('results_root')) 46 | for key, path in path_opt.items(): 47 | if ('strict_load' not in key) and ('pretrain_network' 48 | not in key) and ('resume' 49 | not in key): 50 | os.makedirs(path, exist_ok=True) 51 | 52 | 53 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 54 | """Scan a directory to find the files of interest. 55 | 56 | Args: 57 | dir_path (str): Path of the directory. 58 | suffix (str | tuple(str), optional): File suffix that we are 59 | interested in. Default: None. 60 | recursive (bool, optional): If set to True, recursively scan the 61 | directory. Default: False. 62 | full_path (bool, optional): If set to True, include the dir_path. 63 | Default: False. 64 | 65 | Returns: 66 | A generator for all the files of interest, with relative paths. 67 | """ 68 | 69 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 70 | raise TypeError('"suffix" must be a string or tuple of strings') 71 | 72 | root = dir_path 73 | 74 | def _scandir(dir_path, suffix, recursive): 75 | for entry in os.scandir(dir_path): 76 | if not entry.name.startswith('.') and entry.is_file(): 77 | if full_path: 78 | return_path = entry.path 79 | else: 80 | return_path = osp.relpath(entry.path, root) 81 | 82 | if suffix is None: 83 | yield return_path 84 | elif return_path.endswith(suffix): 85 | yield return_path 86 | else: 87 | if recursive: 88 | yield from _scandir( 89 | entry.path, suffix=suffix, recursive=recursive) 90 | else: 91 | continue 92 | 93 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 94 | 95 | def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False): 96 | """Scan a directory to find the files of interest. 97 | 98 | Args: 99 | dir_path (str): Path of the directory. 100 | keywords (str | tuple(str), optional): File keywords that we are 101 | interested in. Default: None. 102 | recursive (bool, optional): If set to True, recursively scan the 103 | directory. Default: False. 104 | full_path (bool, optional): If set to True, include the dir_path. 105 | Default: False. 106 | 107 | Returns: 108 | A generator for all the files of interest, with relative paths. 109 | """ 110 | 111 | if (keywords is not None) and not isinstance(keywords, (str, tuple)): 112 | raise TypeError('"keywords" must be a string or tuple of strings') 113 | 114 | root = dir_path 115 | 116 | def _scandir(dir_path, keywords, recursive): 117 | for entry in os.scandir(dir_path): 118 | if not entry.name.startswith('.') and entry.is_file(): 119 | if full_path: 120 | return_path = entry.path 121 | else: 122 | return_path = osp.relpath(entry.path, root) 123 | 124 | if keywords is None: 125 | yield return_path 126 | elif return_path.find(keywords) >= 0: 127 | yield return_path 128 | else: 129 | if recursive: 130 | yield from _scandir( 131 | entry.path, keywords=keywords, recursive=recursive) 132 | else: 133 | continue 134 | 135 | return _scandir(dir_path, keywords=keywords, recursive=recursive) 136 | 
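# Usage sketch for the two scanners above (paths are hypothetical):
#
#     list(scandir('datasets/GoPro/train/blur', suffix='.png'))
#     # -> ['000001.png', '000002.png', ...]   (paths relative to dir_path)
#     list(scandir('datasets/GoPro', suffix='.png', recursive=True))
#     # -> ['train/blur/000001.png', ...]      (walks sub-directories)
#
# Both functions return lazy generators, so wrap them in list() or sorted()
# if you need to iterate more than once.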
137 | def check_resume(opt, resume_iter): 138 | """Check resume states and pretrain_network paths. 139 | 140 | Args: 141 | opt (dict): Options. 142 | resume_iter (int): Resume iteration. 143 | """ 144 | logger = get_root_logger() 145 | if opt['path']['resume_state']: 146 | # get all the networks 147 | networks = [key for key in opt.keys() if key.startswith('network_')] 148 | flag_pretrain = False 149 | for network in networks: 150 | if opt['path'].get(f'pretrain_{network}') is not None: 151 | flag_pretrain = True 152 | if flag_pretrain: 153 | logger.warning( 154 | 'pretrain_network path will be ignored during resuming.') 155 | # set pretrained model paths 156 | for network in networks: 157 | name = f'pretrain_{network}' 158 | basename = network.replace('network_', '') 159 | if opt['path'].get('ignore_resume_networks') is None or ( 160 | basename not in opt['path']['ignore_resume_networks']): 161 | opt['path'][name] = osp.join( 162 | opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 163 | logger.info(f"Set {name} to {opt['path'][name]}") 164 | 165 | 166 | def sizeof_fmt(size, suffix='B'): 167 | """Get human readable file size. 168 | 169 | Args: 170 | size (int): File size. 171 | suffix (str): Suffix. Default: 'B'. 172 | 173 | Returns: 174 | str: Formatted file size. 175 | """ 176 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 177 | if abs(size) < 1024.0: 178 | return f'{size:3.1f} {unit}{suffix}' 179 | size /= 1024.0 180 | return f'{size:3.1f} Y{suffix}' 181 | -------------------------------------------------------------------------------- /basicsr/utils/img_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def img2tensor(imgs, bgr2rgb=True, float32=True): 10 | """Numpy array to tensor. 11 | 12 | Args: 13 | imgs (list[ndarray] | ndarray): Input images. 14 | bgr2rgb (bool): Whether to change bgr to rgb. 15 | float32 (bool): Whether to change to float32. 16 | 17 | Returns: 18 | list[tensor] | tensor: Tensor images. If returned results only have 19 | one element, just return tensor. 20 | """ 21 | 22 | def _totensor(img, bgr2rgb, float32): 23 | if img.shape[2] == 3 and bgr2rgb: 24 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 25 | img = torch.from_numpy(img.transpose(2, 0, 1)) 26 | if float32: 27 | img = img.float() 28 | return img 29 | 30 | if isinstance(imgs, list): 31 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 32 | else: 33 | return _totensor(imgs, bgr2rgb, float32) 34 | 35 | 36 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 37 | """Convert torch Tensors into image numpy arrays. 38 | 39 | After clamping to [min, max], values will be normalized to [0, 1]. 40 | 41 | Args: 42 | tensor (Tensor or list[Tensor]): Accept shapes: 43 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); 44 | 2) 3D Tensor of shape (3/1 x H x W); 45 | 3) 2D Tensor of shape (H x W). 46 | Tensor channel should be in RGB order. 47 | rgb2bgr (bool): Whether to change rgb to bgr. 48 | out_type (numpy type): output types. If ``np.uint8``, transform outputs 49 | to uint8 type with range [0, 255]; otherwise, float type with 50 | range [0, 1]. Default: ``np.uint8``. 51 | min_max (tuple[int]): min and max values for clamp. 52 | 53 | Returns: 54 | (ndarray or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of 55 | shape (H x W). The channel order is BGR. 
56 | """ 57 | if not (torch.is_tensor(tensor) or 58 | (isinstance(tensor, list) 59 | and all(torch.is_tensor(t) for t in tensor))): 60 | raise TypeError( 61 | f'tensor or list of tensors expected, got {type(tensor)}') 62 | 63 | if torch.is_tensor(tensor): 64 | tensor = [tensor] 65 | result = [] 66 | for _tensor in tensor: 67 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 68 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 69 | 70 | n_dim = _tensor.dim() 71 | if n_dim == 4: 72 | img_np = make_grid( 73 | _tensor, nrow=int(math.sqrt(_tensor.size(0))), 74 | normalize=False).numpy() 75 | img_np = img_np.transpose(1, 2, 0) 76 | if rgb2bgr: 77 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 78 | elif n_dim == 3: 79 | img_np = _tensor.numpy() 80 | img_np = img_np.transpose(1, 2, 0) 81 | if img_np.shape[2] == 1: # gray image 82 | img_np = np.squeeze(img_np, axis=2) 83 | else: 84 | if rgb2bgr: 85 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 86 | elif n_dim == 2: 87 | img_np = _tensor.numpy() 88 | else: 89 | raise TypeError('Only support 4D, 3D or 2D tensor. ' 90 | f'But received with dimension: {n_dim}') 91 | if out_type == np.uint8: 92 | # Unlike MATLAB, numpy.uint8() will NOT round by default, so round explicitly. 93 | img_np = (img_np * 255.0).round() 94 | img_np = img_np.astype(out_type) 95 | result.append(img_np) 96 | if len(result) == 1: 97 | result = result[0] 98 | return result 99 | 100 | 101 | def imfrombytes(content, flag='color', float32=False): 102 | """Read an image from bytes. 103 | 104 | Args: 105 | content (bytes): Image bytes got from files or other streams. 106 | flag (str): Flags specifying the color type of a loaded image, 107 | candidates are `color`, `grayscale` and `unchanged`. 108 | float32 (bool): Whether to change to float32. If True, will also norm 109 | to [0, 1]. Default: False. 110 | 111 | Returns: 112 | ndarray: Loaded image array. 113 | """ 114 | img_np = np.frombuffer(content, np.uint8) 115 | imread_flags = { 116 | 'color': cv2.IMREAD_COLOR, 117 | 'grayscale': cv2.IMREAD_GRAYSCALE, 118 | 'unchanged': cv2.IMREAD_UNCHANGED 119 | } 120 | if img_np is None: 121 | raise ValueError('imfrombytes: received empty content.') 122 | img = cv2.imdecode(img_np, imread_flags[flag]) 123 | if float32: 124 | img = img.astype(np.float32) / 255. 125 | return img 126 | 127 | def padding(img_lq, img_gt, gt_size): 128 | h, w, _ = img_lq.shape 129 | 130 | h_pad = max(0, gt_size - h) 131 | w_pad = max(0, gt_size - w) 132 | 133 | if h_pad == 0 and w_pad == 0: 134 | return img_lq, img_gt 135 | 136 | img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 137 | img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 139 | return img_lq, img_gt 140 | 141 | def imwrite(img, file_path, params=None, auto_mkdir=True): 142 | """Write image to file. 143 | 144 | Args: 145 | img (ndarray): Image array to be written. 146 | file_path (str): Image file path. 147 | params (None or list): Same as opencv's :func:`imwrite` interface. 148 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 149 | whether to create it automatically. 150 | 151 | Returns: 152 | bool: Successful or not. 153 | """ 154 | if auto_mkdir: 155 | dir_name = os.path.abspath(os.path.dirname(file_path)) 156 | os.makedirs(dir_name, exist_ok=True) 157 | return cv2.imwrite(file_path, img, params) 158 | 159 | 
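# A minimal round-trip sketch for the helpers above (illustrative only):
#
#     t = torch.rand(3, 64, 64)         # CHW float tensor in [0, 1], RGB
#     img = tensor2img(t)               # -> uint8 HWC ndarray, BGR, [0, 255]
#     imwrite(img, '/tmp/example.png')  # parent folder created automatically
#     back = img2tensor(img) / 255.0    # -> float CHW tensor, RGB again
#
# Note the convention: tensors are RGB/CHW, ndarrays are BGR/HWC (OpenCV).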
160 | def crop_border(imgs, crop_border): 161 | """Crop borders of images. 162 | 163 | Args: 164 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 165 | crop_border (int): Crop border for each end of height and width. 166 | 167 | Returns: 168 | list[ndarray]: Cropped images. 169 | """ 170 | if crop_border == 0: 171 | return imgs 172 | else: 173 | if isinstance(imgs, list): 174 | return [ 175 | v[crop_border:-crop_border, crop_border:-crop_border, ...] 176 | for v in imgs 177 | ] 178 | else: 179 | return imgs[crop_border:-crop_border, crop_border:-crop_border, 180 | ...] 181 | -------------------------------------------------------------------------------- /basicsr/utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | 5 | from .dist_util import get_dist_info, master_only 6 | 7 | 8 | class MessageLogger(): 9 | """Message logger for printing. 10 | 11 | Args: 12 | opt (dict): Config. It contains the following keys: 13 | name (str): Exp name. 14 | logger (dict): Contains 'print_freq' (int) for logger interval. 15 | train (dict): Contains 'total_iter' (int) for total iters. 16 | use_tb_logger (bool): Use tensorboard logger. 17 | start_iter (int): Start iter. Default: 1. 18 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 19 | """ 20 | 21 | def __init__(self, opt, start_iter=1, tb_logger=None): 22 | self.exp_name = opt['name'] 23 | self.interval = opt['logger']['print_freq'] 24 | self.start_iter = start_iter 25 | self.max_iters = opt['train']['total_iter'] 26 | self.use_tb_logger = opt['logger']['use_tb_logger'] 27 | self.tb_logger = tb_logger 28 | self.start_time = time.time() 29 | self.logger = get_root_logger() 30 | 31 | @master_only 32 | def __call__(self, log_vars): 33 | """Format logging message. 34 | 35 | Args: 36 | log_vars (dict): It contains the following keys: 37 | epoch (int): Epoch number. 38 | iter (int): Current iter. 39 | lrs (list): List for learning rates. 40 | 41 | time (float): Iter time. 42 | data_time (float): Data time for each iter. 
43 | """ 44 | # epoch, iter, learning rates 45 | epoch = log_vars.pop('epoch') 46 | current_iter = log_vars.pop('iter') 47 | lrs = log_vars.pop('lrs') 48 | 49 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' 50 | f'iter:{current_iter:8,d}, lr:(') 51 | for v in lrs: 52 | message += f'{v:.3e},' 53 | message += ')] ' 54 | 55 | # time and estimated time 56 | if 'time' in log_vars.keys(): 57 | iter_time = log_vars.pop('time') 58 | data_time = log_vars.pop('data_time') 59 | 60 | total_time = time.time() - self.start_time 61 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 62 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 63 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 64 | message += f'[eta: {eta_str}, ' 65 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' 66 | 67 | # other items, especially losses 68 | for k, v in log_vars.items(): 69 | message += f'{k}: {v:.4e} ' 70 | # tensorboard logger 71 | if self.use_tb_logger and 'debug' not in self.exp_name: 72 | if k.startswith('l_'): 73 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) 74 | else: 75 | self.tb_logger.add_scalar(k, v, current_iter) 76 | self.logger.info(message) 77 | 78 | 79 | @master_only 80 | def init_tb_logger(log_dir): 81 | from torch.utils.tensorboard import SummaryWriter 82 | tb_logger = SummaryWriter(log_dir=log_dir) 83 | return tb_logger 84 | 85 | 86 | @master_only 87 | def init_wandb_logger(opt): 88 | """We now only use wandb to sync tensorboard log.""" 89 | import wandb 90 | logger = logging.getLogger('basicsr') 91 | 92 | project = opt['logger']['wandb']['project'] 93 | resume_id = opt['logger']['wandb'].get('resume_id') 94 | if resume_id: 95 | wandb_id = resume_id 96 | resume = 'allow' 97 | logger.warning(f'Resume wandb logger with id={wandb_id}.') 98 | else: 99 | wandb_id = wandb.util.generate_id() 100 | resume = 'never' 101 | 102 | wandb.init( 103 | id=wandb_id, 104 | resume=resume, 105 | name=opt['name'], 106 | config=opt, 107 | project=project, 108 | sync_tensorboard=True) 109 | 110 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') 111 | 112 | 113 | def get_root_logger(logger_name='basicsr', 114 | log_level=logging.INFO, 115 | log_file=None): 116 | """Get the root logger. 117 | 118 | The logger will be initialized if it has not been initialized. By default a 119 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 120 | also be added. 121 | 122 | Args: 123 | logger_name (str): root logger name. Default: 'basicsr'. 124 | log_file (str | None): The log filename. If specified, a FileHandler 125 | will be added to the root logger. 126 | log_level (int): The root logger level. Note that only the process of 127 | rank 0 is affected, while other processes will set the level to 128 | "Error" and be silent most of the time. 129 | 130 | Returns: 131 | logging.Logger: The root logger. 
132 | """ 133 | logger = logging.getLogger(logger_name) 134 | # if the logger has been initialized, just return it 135 | if logger.hasHandlers(): 136 | return logger 137 | 138 | format_str = '%(asctime)s %(levelname)s: %(message)s' 139 | logging.basicConfig(format=format_str, level=log_level) 140 | rank, _ = get_dist_info() 141 | if rank != 0: 142 | logger.setLevel('ERROR') 143 | elif log_file is not None: 144 | file_handler = logging.FileHandler(log_file, 'w') 145 | file_handler.setFormatter(logging.Formatter(format_str)) 146 | file_handler.setLevel(log_level) 147 | logger.addHandler(file_handler) 148 | 149 | return logger 150 | 151 | 152 | def get_env_info(): 153 | """Get environment information. 154 | 155 | Currently, only log the software version. 156 | """ 157 | import torch 158 | import torchvision 159 | 160 | from basicsr.version import __version__ 161 | msg = r""" 162 | ____ _ _____ ____ 163 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \ 164 | / __ |/ __ `// ___// // ___/\__ \ / /_/ / 165 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ 166 | /_____/ \__,_//____//_/ \___//____//_/ |_| 167 | ______ __ __ __ __ 168 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / / 169 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / 170 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ 171 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) 172 | """ 173 | msg += ('\nVersion Information: ' 174 | f'\n\tBasicSR: {__version__}' 175 | f'\n\tPyTorch: {torch.__version__}' 176 | f'\n\tTorchVision: {torchvision.__version__}') 177 | return msg 178 | -------------------------------------------------------------------------------- /basicsr/data/npz_image_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data as data 2 | from torchvision.transforms.functional import normalize 3 | from tqdm import tqdm 4 | import os 5 | from pathlib import Path 6 | import random 7 | import numpy as np 8 | import torch 9 | 10 | from basicsr.data.data_util import (paired_paths_from_folder, 11 | paired_paths_from_lmdb, 12 | paired_paths_from_meta_info_file, 13 | recursive_glob) 14 | from basicsr.data.event_util import events_to_voxel_grid, voxel_norm 15 | from basicsr.data.transforms import augment, triple_random_crop, random_augmentation 16 | from basicsr.utils import FileClient, imfrombytes, img2tensor, padding, get_root_logger 17 | from torch.utils.data.dataloader import default_collate 18 | 19 | 20 | class NpzPngSingleDeblurDataset(data.Dataset): 21 | """Paired npz and png dataset for event-based single image deblurring. 22 | --HighREV 23 | |----train 24 | | |----blur 25 | | | |----SEQNAME_%5d.png 26 | | | |----... 27 | | |----event 28 | | | |----SEQNAME_%5d_%2d.npz 29 | | | |----... 30 | | |----sharp 31 | | | |----SEQNAME_%5d.png 32 | | | |----... 33 | |----val 34 | ... 35 | 36 | 37 | Args: 38 | opt (dict): Config for train dataset. It contains the following keys: 39 | dataroot (str): Data root path. 40 | io_backend (dict): IO backend type and other kwarg. 41 | num_end_interpolation (int): Number of sharp frames to reconstruct in each blurry image. 42 | num_inter_interpolation (int): Number of sharp frames to interpolate between two blurry images. 43 | phase (str): 'train' or 'test' 44 | 45 | gt_size (int): Cropped patched size for gt patches. 46 | random_reverse (bool): Random reverse input frames. 47 | use_hflip (bool): Use horizontal flips. 
48 | use_rot (bool): Use rotation (use vertical flip and transposing h 49 | and w for implementation). 50 | 51 | scale (bool): Scale, which will be added automatically. 52 | """ 53 | 54 | def __init__(self, opt): 55 | super(NpzPngSingleDeblurDataset, self).__init__() 56 | self.opt = opt 57 | self.dataroot = Path(opt['dataroot']) 58 | self.split = 'train' if opt['phase'] == 'train' else 'val' # train or val 59 | self.voxel_bins = opt['voxel_bins'] 60 | self.norm_voxel = opt['norm_voxel'] 61 | 62 | self.dataPath = [] 63 | 64 | blur_frames = sorted(recursive_glob(rootdir=os.path.join(self.dataroot, 'blur'), suffix='.png')) 65 | blur_frames = [os.path.join(self.dataroot, 'blur', blur_frame) for blur_frame in blur_frames] 66 | 67 | sharp_frames = sorted(recursive_glob(rootdir=os.path.join(self.dataroot, 'sharp'), suffix='.png')) 68 | sharp_frames = [os.path.join(self.dataroot, 'sharp', sharp_frame) for sharp_frame in sharp_frames] 69 | 70 | event_frames = sorted(recursive_glob(rootdir=os.path.join(self.dataroot, 'event'), suffix='.npz')) 71 | event_frames = [os.path.join(self.dataroot, 'event', event_frame) for event_frame in event_frames] 72 | 73 | assert len(blur_frames) == len(sharp_frames), f"Mismatch in blur ({len(blur_frames)}) and sharp ({len(sharp_frames)}) frame counts." 74 | 75 | for i in range(len(blur_frames)): 76 | blur_name = os.path.basename(blur_frames[i]) # e.g., SEQNAME_00001.png 77 | base_name = os.path.splitext(blur_name)[0] # Remove .png, get SEQNAME_00001 78 | event_list = sorted([f for f in event_frames if f.startswith(os.path.join(self.dataroot, 'event', base_name + '_'))])[:-1] # remove the last one because it is out of the exposure time 79 | 80 | self.dataPath.append({ 81 | 'blur_path': blur_frames[i], 82 | 'sharp_path': sharp_frames[i], 83 | 'event_paths': event_list 84 | }) 85 | logger = get_root_logger() 86 | logger.info(f"Dataset initialized with {len(self.dataPath)} samples.") 87 | 88 | # file client (io backend) 89 | self.file_client = None 90 | self.io_backend_opt = opt['io_backend'] 91 | # import pdb; pdb.set_trace() 92 | 93 | def __getitem__(self, index): 94 | if self.file_client is None: 95 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 96 | scale = self.opt['scale'] 97 | gt_size = self.opt['gt_size'] 98 | 99 | image_path = self.dataPath[index]['blur_path'] 100 | gt_path = self.dataPath[index]['sharp_path'] 101 | event_paths = self.dataPath[index]['event_paths'] 102 | 103 | # get LQ 104 | img_bytes = self.file_client.get(image_path) # 'lq' 105 | img_lq = imfrombytes(img_bytes, float32=True) 106 | # get GT 107 | img_bytes = self.file_client.get(gt_path) # 'gt' 108 | img_gt = imfrombytes(img_bytes, float32=True) 109 | 110 | h_lq, w_lq, _ = img_lq.shape 111 | 112 | ## Read event and convert to voxel grid: 113 | events = [np.load(event_path) for event_path in event_paths] 114 | # npz -> ndarray 115 | all_quad_event_array = np.zeros((0,4)).astype(np.float32) 116 | for event in events: 117 | ### IMPORTANT: dataset mistake x and y !!!!!!!! 118 | ### Switch x and y here !!!! 
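# NOTE on the swap below: 'x' in the .npz chunks is read as the image row
# and 'y' as the column because the dataset writer stored them crosswise.
# Each field is assumed to have shape (N, 1), so the axis=1 concatenation
# below yields an (N, 4) array ordered [t, x, y, p] -- the layout that
# events_to_voxel_grid from basicsr.data.event_util expects.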
119 | y = event['x'].astype(np.float32) 120 | x = event['y'].astype(np.float32) 121 | t = event['timestamp'].astype(np.float32) 122 | p = event['polarity'].astype(np.float32) 123 | 124 | this_quad_event_array = np.concatenate((t,x,y,p),axis=1) # N,4 125 | all_quad_event_array = np.concatenate((all_quad_event_array, this_quad_event_array), axis=0) 126 | voxel = events_to_voxel_grid(all_quad_event_array, num_bins=self.voxel_bins, width=w_lq, height=h_lq, return_format='HWC') 127 | 128 | ## Data augmentation 129 | # voxel shape: h,w,c 130 | # crop 131 | if gt_size is not None: 132 | img_gt, img_lq, voxel = triple_random_crop(img_gt, img_lq, voxel, gt_size, scale, gt_path) 133 | 134 | # flip, rotate 135 | total_input = [img_lq, img_gt, voxel] 136 | img_results = augment(total_input, self.opt['use_hflip'], self.opt['use_rot']) 137 | 138 | img_results = img2tensor(img_results) # hwc -> chw 139 | img_lq, img_gt, voxel = img_results 140 | 141 | ## Norm voxel 142 | if self.norm_voxel: 143 | voxel = voxel_norm(voxel) 144 | 145 | origin_index = os.path.basename(image_path).split('.')[0] 146 | 147 | return {'frame': img_lq, 'frame_gt': img_gt, 'voxel': voxel, 'image_name': origin_index} 148 | 149 | def __len__(self): 150 | return len(self.dataPath) 151 | -------------------------------------------------------------------------------- /basicsr/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class MultiStepRestartLR(_LRScheduler): 7 | """ MultiStep with restarts learning rate scheme. 8 | 9 | Args: 10 | optimizer (torch.nn.optimizer): Torch optimizer. 11 | milestones (list): Iterations that will decrease learning rate. 12 | gamma (float): Decrease ratio. Default: 0.1. 13 | restarts (list): Restart iterations. Default: [0]. 14 | restart_weights (list): Restart weights at each restart iteration. 15 | Default: [1]. 16 | last_epoch (int): Used in _LRScheduler. Default: -1. 17 | """ 18 | 19 | def __init__(self, 20 | optimizer, 21 | milestones, 22 | gamma=0.1, 23 | restarts=(0, ), 24 | restart_weights=(1, ), 25 | last_epoch=-1): 26 | self.milestones = Counter(milestones) 27 | self.gamma = gamma 28 | self.restarts = restarts 29 | self.restart_weights = restart_weights 30 | assert len(self.restarts) == len( 31 | self.restart_weights), 'restarts and their weights do not match.' 32 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 33 | 34 | def get_lr(self): 35 | if self.last_epoch in self.restarts: 36 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 37 | return [ 38 | group['initial_lr'] * weight 39 | for group in self.optimizer.param_groups 40 | ] 41 | if self.last_epoch not in self.milestones: 42 | return [group['lr'] for group in self.optimizer.param_groups] 43 | return [ 44 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 45 | for group in self.optimizer.param_groups 46 | ] 47 | 48 | class LinearLR(_LRScheduler): 49 | """Linearly decaying learning rate scheme. 50 | 51 | Args: 52 | optimizer (torch.nn.optimizer): Torch optimizer. 53 | total_iter (int): Total number of iterations; the lr decays 54 | linearly from the initial lr to 0 over this span. 55 | last_epoch (int): Used in _LRScheduler. Default: -1. 
56 | """ 57 | 58 | def __init__(self, 59 | optimizer, 60 | total_iter, 61 | last_epoch=-1): 62 | self.total_iter = total_iter 63 | super(LinearLR, self).__init__(optimizer, last_epoch) 64 | 65 | def get_lr(self): 66 | process = self.last_epoch / self.total_iter 67 | weight = (1 - process) 69 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups] 70 | 71 | class VibrateLR(_LRScheduler): 72 | """Oscillating ('vibrating') learning rate scheme. 73 | 74 | Args: 75 | optimizer (torch.nn.optimizer): Torch optimizer. 76 | total_iter (int): Total number of iterations, used to derive the 77 | decay envelope and the oscillation period. 78 | last_epoch (int): Used in _LRScheduler. Default: -1. 79 | """ 80 | 81 | def __init__(self, 82 | optimizer, 83 | total_iter, 84 | last_epoch=-1): 85 | self.total_iter = total_iter 86 | super(VibrateLR, self).__init__(optimizer, last_epoch) 87 | 88 | def get_lr(self): 89 | process = self.last_epoch / self.total_iter 90 | 91 | f = 0.1 92 | if process < 3 / 8: 93 | f = 1 - process * 8 / 3 94 | elif process < 5 / 8: 95 | f = 0.2 96 | 97 | T = self.total_iter // 80 98 | Th = T // 2 99 | 100 | t = self.last_epoch % T 101 | 102 | f2 = t / Th 103 | if t >= Th: 104 | f2 = 2 - f2 105 | 106 | weight = f * f2 107 | 108 | if self.last_epoch < Th: 109 | weight = max(0.1, weight) 110 | 112 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups] 113 | 114 | def get_position_from_periods(iteration, cumulative_period): 115 | """Get the position from a period list. 116 | 117 | It will return the index of the right-closest number in the period list. 118 | For example, the cumulative_period = [100, 200, 300, 400], 119 | if iteration == 50, return 0; 120 | if iteration == 210, return 2; 121 | if iteration == 300, return 2. 122 | 123 | Args: 124 | iteration (int): Current iteration. 125 | cumulative_period (list[int]): Cumulative period list. 126 | 127 | Returns: 128 | int: The position of the right-closest number in the period list. 129 | """ 130 | for i, period in enumerate(cumulative_period): 131 | if iteration <= period: 132 | return i 133 | 134 | 135 | class CosineAnnealingRestartLR(_LRScheduler): 136 | """ Cosine annealing with restarts learning rate scheme. 137 | 138 | An example of config: 139 | periods = [10, 10, 10, 10] 140 | restart_weights = [1, 0.5, 0.5, 0.5] 141 | eta_min=1e-7 142 | 143 | It has four cycles, each with 10 iterations. At the 10th, 20th, and 30th iterations, the 144 | scheduler will restart with the weights in restart_weights. 145 | 146 | Args: 147 | optimizer (torch.nn.optimizer): Torch optimizer. 148 | periods (list): Period for each cosine annealing cycle. 149 | restart_weights (list): Restart weights at each restart iteration. 150 | Default: [1]. 151 | eta_min (float): The minimum lr. Default: 0. 152 | last_epoch (int): Used in _LRScheduler. Default: -1. 153 | """ 154 | 155 | def __init__(self, 156 | optimizer, 157 | periods, 158 | restart_weights=(1, ), 159 | eta_min=0, 160 | last_epoch=-1): 161 | self.periods = periods 162 | self.restart_weights = restart_weights 163 | self.eta_min = eta_min 164 | assert (len(self.periods) == len(self.restart_weights) 165 | ), 'periods and restart_weights should have the same length.' 166 | self.cumulative_period = [ 167 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) 168 | ] 169 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 170 | 171 | def get_lr(self): 172 | idx = get_position_from_periods(self.last_epoch, 173 | self.cumulative_period) 174 | current_weight = self.restart_weights[idx] 175 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 176 | current_period = self.periods[idx] 177 | 178 | return [ 179 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 180 | (1 + math.cos(math.pi * ( 181 | (self.last_epoch - nearest_restart) / current_period))) 182 | for base_lr in self.base_lrs 183 | ] 184 | 
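# Worked example for CosineAnnealingRestartLR, using the config from the
# docstring above and assuming a single param group with base_lr = 1e-4:
#
#     scheduler = CosineAnnealingRestartLR(
#         optimizer, periods=[10, 10, 10, 10],
#         restart_weights=[1, 0.5, 0.5, 0.5], eta_min=1e-7)
#
#     iter 0  -> 1.0e-4   (cosine term is 2: full weight-1 lr)
#     iter 5  -> ~5.0e-5  (halfway down the first cosine)
#     iter 10 -> ~1.0e-7  (end of cycle 1: lr reaches eta_min)
#     iter 11 -> ~4.9e-5  (restart, scaled by weight 0.5)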
-------------------------------------------------------------------------------- /basicsr/utils/flow_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501 2 | import cv2 3 | import numpy as np 4 | import os 5 | 6 | 7 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): 8 | """Read an optical flow map. 9 | 10 | Args: 11 | flow_path (ndarray or str): Flow path. 12 | quantize (bool): whether to read quantized pair, if set to True, 13 | remaining args will be passed to :func:`dequantize_flow`. 14 | concat_axis (int): The axis that dx and dy are concatenated, 15 | can be either 0 or 1. Ignored if quantize is False. 16 | 17 | Returns: 18 | ndarray: Optical flow represented as a (h, w, 2) numpy array 19 | """ 20 | if quantize: 21 | assert concat_axis in [0, 1] 22 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) 23 | if cat_flow.ndim != 2: 24 | raise IOError(f'{flow_path} is not a valid quantized flow file, ' 25 | f'its dimension is {cat_flow.ndim}.') 26 | assert cat_flow.shape[concat_axis] % 2 == 0 27 | dx, dy = np.split(cat_flow, 2, axis=concat_axis) 28 | flow = dequantize_flow(dx, dy, *args, **kwargs) 29 | else: 30 | with open(flow_path, 'rb') as f: 31 | try: 32 | header = f.read(4).decode('utf-8') 33 | except Exception: 34 | raise IOError(f'Invalid flow file: {flow_path}') 35 | else: 36 | if header != 'PIEH': 37 | raise IOError(f'Invalid flow file: {flow_path}, ' 38 | 'header does not contain PIEH') 39 | 40 | w = np.fromfile(f, np.int32, 1).squeeze() 41 | h = np.fromfile(f, np.int32, 1).squeeze() 42 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) 43 | 44 | return flow.astype(np.float32) 45 | 46 | 47 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): 48 | """Write optical flow to file. 49 | 50 | If the flow is not quantized, it will be saved as a .flo file losslessly, 51 | otherwise a jpeg image which is lossy but of much smaller size. (dx and dy 52 | will be concatenated horizontally into a single image if quantize is True.) 53 | 54 | Args: 55 | flow (ndarray): (h, w, 2) array of optical flow. 56 | filename (str): Output filepath. 57 | quantize (bool): Whether to quantize the flow and save it as a single 58 | jpeg image (dx and dy concatenated). If set to True, remaining 59 | args will be passed to :func:`quantize_flow`. 60 | concat_axis (int): The axis that dx and dy are concatenated, 61 | can be either 0 or 1. Ignored if quantize is False. 
62 | """ 63 | if not quantize: 64 | with open(filename, 'wb') as f: 65 | f.write('PIEH'.encode('utf-8')) 66 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) 67 | flow = flow.astype(np.float32) 68 | flow.tofile(f) 69 | f.flush() 70 | else: 71 | assert concat_axis in [0, 1] 72 | dx, dy = quantize_flow(flow, *args, **kwargs) 73 | dxdy = np.concatenate((dx, dy), axis=concat_axis) 74 | os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True) 75 | cv2.imwrite(filename, dxdy) 76 | 77 | 78 | def quantize_flow(flow, max_val=0.02, norm=True): 79 | """Quantize flow to [0, 255]. 80 | 81 | After this step, the size of flow will be much smaller, and can be 82 | dumped as jpeg images. 83 | 84 | Args: 85 | flow (ndarray): (h, w, 2) array of optical flow. 86 | max_val (float): Maximum value of flow, values beyond 87 | [-max_val, max_val] will be truncated. 88 | norm (bool): Whether to divide flow values by image width/height. 89 | 90 | Returns: 91 | tuple[ndarray]: Quantized dx and dy. 92 | """ 93 | h, w, _ = flow.shape 94 | dx = flow[..., 0] 95 | dy = flow[..., 1] 96 | if norm: 97 | dx = dx / w # avoid inplace operations 98 | dy = dy / h 99 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. 100 | flow_comps = [ 101 | quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy] 102 | ] 103 | return tuple(flow_comps) 104 | 105 | 106 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True): 107 | """Recover from quantized flow. 108 | 109 | Args: 110 | dx (ndarray): Quantized dx. 111 | dy (ndarray): Quantized dy. 112 | max_val (float): Maximum value used when quantizing. 113 | denorm (bool): Whether to multiply flow values with width/height. 114 | 115 | Returns: 116 | ndarray: Dequantized flow. 117 | """ 118 | assert dx.shape == dy.shape 119 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) 120 | 121 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] 122 | 123 | if denorm: 124 | dx *= dx.shape[1] 125 | dy *= dy.shape[0] 126 | flow = np.dstack((dx, dy)) 127 | return flow 128 | 129 | 
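# Round-trip sketch for the quantize/dequantize pair above (synthetic flow):
#
#     flow = np.random.uniform(-0.1, 0.1, (4, 6, 2)).astype(np.float32)
#     dx, dy = quantize_flow(flow, max_val=0.02, norm=True)   # two uint8 maps
#     rec = dequantize_flow(dx, dy, max_val=0.02, denorm=True)
#
# With norm=True the components are first divided by width/height, clipped
# to [-max_val, max_val] and mapped onto 255 levels, so the round trip is
# lossy: the error is bounded by one level, 2 * max_val / 255 (times w or h
# after denorm).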
130 | def quantize(arr, min_val, max_val, levels, dtype=np.int64): 131 | """Quantize an array of (-inf, inf) to [0, levels-1]. 132 | 133 | Args: 134 | arr (ndarray): Input array. 135 | min_val (scalar): Minimum value to be clipped. 136 | max_val (scalar): Maximum value to be clipped. 137 | levels (int): Quantization levels. 138 | dtype (np.type): The type of the quantized array. 139 | 140 | Returns: 141 | ndarray: Quantized array. 142 | """ 143 | if not (isinstance(levels, int) and levels > 1): 144 | raise ValueError( 145 | f'levels must be a positive integer, but got {levels}') 146 | if min_val >= max_val: 147 | raise ValueError( 148 | f'min_val ({min_val}) must be smaller than max_val ({max_val})') 149 | 150 | arr = np.clip(arr, min_val, max_val) - min_val 151 | quantized_arr = np.minimum( 152 | np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) 153 | 154 | return quantized_arr 155 | 156 | 157 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64): 158 | """Dequantize an array. 159 | 160 | Args: 161 | arr (ndarray): Input array. 162 | min_val (scalar): Minimum value to be clipped. 163 | max_val (scalar): Maximum value to be clipped. 164 | levels (int): Quantization levels. 165 | dtype (np.type): The type of the dequantized array. 166 | 167 | Returns: 168 | ndarray: Dequantized array. 169 | """ 170 | if not (isinstance(levels, int) and levels > 1): 171 | raise ValueError( 172 | f'levels must be a positive integer, but got {levels}') 173 | if min_val >= max_val: 174 | raise ValueError( 175 | f'min_val ({min_val}) must be smaller than max_val ({max_val})') 176 | 177 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - 178 | min_val) / levels + min_val 179 | 180 | return dequantized_arr 181 | -------------------------------------------------------------------------------- /basicsr/utils/file_client.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseStorageBackend(metaclass=ABCMeta): 6 | """Abstract class of storage backends. 7 | 8 | All backends need to implement two apis: ``get()`` and ``get_text()``. 9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 10 | as texts. 11 | """ 12 | 13 | @abstractmethod 14 | def get(self, filepath): 15 | pass 16 | 17 | @abstractmethod 18 | def get_text(self, filepath): 19 | pass 20 | 21 | 22 | class MemcachedBackend(BaseStorageBackend): 23 | """Memcached storage backend. 24 | 25 | Attributes: 26 | server_list_cfg (str): Config file for memcached server list. 27 | client_cfg (str): Config file for memcached client. 28 | sys_path (str | None): Additional path to be appended to `sys.path`. 29 | Default: None. 30 | """ 31 | 32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None): 33 | if sys_path is not None: 34 | import sys 35 | sys.path.append(sys_path) 36 | try: 37 | import mc 38 | except ImportError: 39 | raise ImportError( 40 | 'Please install memcached to enable MemcachedBackend.') 41 | 42 | self.server_list_cfg = server_list_cfg 43 | self.client_cfg = client_cfg 44 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, 45 | self.client_cfg) 46 | # mc.pyvector serves as a pointer to a memory cache 47 | self._mc_buffer = mc.pyvector() 48 | 49 | def get(self, filepath): 50 | filepath = str(filepath) 51 | import mc 52 | self._client.Get(filepath, self._mc_buffer) 53 | value_buf = mc.ConvertBuffer(self._mc_buffer) 54 | return value_buf 55 | 56 | def get_text(self, filepath): 57 | raise NotImplementedError 58 | 59 | 60 | class HardDiskBackend(BaseStorageBackend): 61 | """Raw hard disks storage backend.""" 62 | 63 | def get(self, filepath): 64 | filepath = str(filepath) 65 | with open(filepath, 'rb') as f: 66 | value_buf = f.read() 67 | return value_buf 68 | 69 | def get_text(self, filepath): 70 | filepath = str(filepath) 71 | with open(filepath, 'r') as f: 72 | value_buf = f.read() 73 | return value_buf 74 | 75 | 76 | class LmdbBackend(BaseStorageBackend): 77 | """Lmdb storage backend. 78 | 79 | Args: 80 | db_paths (str | list[str]): Lmdb database paths. 81 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 82 | readonly (bool, optional): Lmdb environment parameter. If True, 83 | disallow any write operations. Default: True. 84 | lock (bool, optional): Lmdb environment parameter. If False, when 85 | concurrent access occurs, do not lock the database. Default: False. 86 | readahead (bool, optional): Lmdb environment parameter. If False, 87 | disable the OS filesystem readahead mechanism, which may improve 88 | random read performance when a database is larger than RAM. 89 | Default: False. 
90 | 91 | Attributes: 92 | db_paths (list): Lmdb database path. 93 | _client (list): A list of several lmdb envs. 94 | """ 95 | 96 | def __init__(self, 97 | db_paths, 98 | client_keys='default', 99 | readonly=True, 100 | lock=False, 101 | readahead=False, 102 | **kwargs): 103 | try: 104 | import lmdb 105 | except ImportError: 106 | raise ImportError('Please install lmdb to enable LmdbBackend.') 107 | 108 | if isinstance(client_keys, str): 109 | client_keys = [client_keys] 110 | 111 | if isinstance(db_paths, list): 112 | self.db_paths = [str(v) for v in db_paths] 113 | elif isinstance(db_paths, str): 114 | self.db_paths = [str(db_paths)] 115 | assert len(client_keys) == len(self.db_paths), ( 116 | 'client_keys and db_paths should have the same length, ' 117 | f'but received {len(client_keys)} and {len(self.db_paths)}.') 118 | 119 | self._client = {} 120 | 121 | for client, path in zip(client_keys, self.db_paths): 122 | self._client[client] = lmdb.open( 123 | path, 124 | readonly=readonly, 125 | lock=lock, 126 | readahead=readahead, 127 | map_size=8*1024*10485760, 128 | # max_readers=1, 129 | **kwargs) 130 | 131 | def get(self, filepath, client_key): 132 | """Get values according to the filepath from one lmdb named client_key. 133 | 134 | Args: 135 | filepath (str | obj:`Path`): Here, filepath is the lmdb key. 136 | client_key (str): Used for distinguishing different lmdb envs. 137 | """ 138 | filepath = str(filepath) 139 | assert client_key in self._client, (f'client_key {client_key} is not ' 140 | 'in lmdb clients.') 141 | client = self._client[client_key] 142 | with client.begin(write=False) as txn: 143 | value_buf = txn.get(filepath.encode('ascii')) 144 | return value_buf 145 | 146 | def get_text(self, filepath): 147 | raise NotImplementedError 148 | 149 | 150 | class FileClient(object): 151 | """A general file client to access files in different backends. 152 | 153 | The client loads a file or text in a specified backend from its path 154 | and returns it as a binary file. It can also register other backend 155 | accessors with a given name and backend class. 156 | 157 | Attributes: 158 | backend (str): The storage backend type. Options are "disk", 159 | "memcached" and "lmdb". 160 | client (:obj:`BaseStorageBackend`): The backend object. 161 | """ 162 | 163 | _backends = { 164 | 'disk': HardDiskBackend, 165 | 'memcached': MemcachedBackend, 166 | 'lmdb': LmdbBackend, 167 | } 168 | 169 | def __init__(self, backend='disk', **kwargs): 170 | if backend not in self._backends: 171 | raise ValueError( 172 | f'Backend {backend} is not supported. Currently supported ones' 173 | f' are {list(self._backends.keys())}') 174 | self.backend = backend 175 | self.client = self._backends[backend](**kwargs) 176 | 177 | def get(self, filepath, client_key='default'): 178 | # client_key is used only for lmdb, where different fileclients have 179 | # different lmdb environments. 180 | if self.backend == 'lmdb': 181 | return self.client.get(filepath, client_key) 182 | else: 183 | return self.client.get(filepath) 184 | 185 | def get_text(self, filepath): 186 | return self.client.get_text(filepath) 187 | 
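# Usage sketch (paths are illustrative):
#
#     client = FileClient('disk')
#     img_bytes = client.get('datasets/GoPro/train/blur/000001.png')
#     # -> raw bytes; decode e.g. with basicsr.utils.imfrombytes
#
#     lmdb_client = FileClient(
#         'lmdb', db_paths=['datasets/GoPro/train_blur.lmdb'],
#         client_keys=['lq'])
#     img_bytes = lmdb_client.get('000001', client_key='lq')  # lmdb key, not a path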
-------------------------------------------------------------------------------- /basicsr/models/losses/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | from basicsr.models.losses.loss_util import weighted_loss 7 | 8 | _reduction_modes = ['none', 'mean', 'sum'] 9 | 10 | 11 | @weighted_loss 12 | def l1_loss(pred, target): 13 | return F.l1_loss(pred, target, reduction='none') 14 | 15 | 16 | @weighted_loss 17 | def mse_loss(pred, target): 18 | return F.mse_loss(pred, target, reduction='none') 19 | 20 | 21 | # AT (attention transfer) loss: compares normalized spatial attention maps 22 | def at(x): 23 | return F.normalize(x.pow(2).mean(1).view(x.size(0), -1)) 24 | 25 | def at_loss(x, y): 26 | return (at(x) - at(y)).pow(2).mean() 27 | 28 | @weighted_loss 29 | def charbonnier_loss(pred, target, eps=1e-12): 30 | return torch.sqrt((pred - target)**2 + eps) 31 | 35 | 36 | 37 | class L1Loss(nn.Module): 38 | """L1 (mean absolute error, MAE) loss. 39 | 40 | Args: 41 | loss_weight (float): Loss weight for L1 loss. Default: 1.0. 42 | reduction (str): Specifies the reduction to apply to the output. 43 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 44 | """ 45 | 46 | def __init__(self, loss_weight=1.0, reduction='mean'): 47 | super(L1Loss, self).__init__() 48 | if reduction not in ['none', 'mean', 'sum']: 49 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' 50 | f'Supported ones are: {_reduction_modes}') 51 | 52 | self.loss_weight = loss_weight 53 | self.reduction = reduction 54 | 55 | def forward(self, pred, target, weight=None, **kwargs): 56 | """ 57 | Args: 58 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 59 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 60 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 61 | weights. Default: None. 62 | """ 63 | return self.loss_weight * l1_loss( 64 | pred, target, weight, reduction=self.reduction) 65 | 66 | class MSELoss(nn.Module): 67 | """MSE (L2) loss. 68 | 69 | Args: 70 | loss_weight (float): Loss weight for MSE loss. Default: 1.0. 71 | reduction (str): Specifies the reduction to apply to the output. 72 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 73 | """ 74 | 75 | def __init__(self, loss_weight=1.0, reduction='mean'): 76 | super(MSELoss, self).__init__() 77 | if reduction not in ['none', 'mean', 'sum']: 78 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' 79 | f'Supported ones are: {_reduction_modes}') 80 | 81 | self.loss_weight = loss_weight 82 | self.reduction = reduction 83 | 84 | def forward(self, pred, target, weight=None, **kwargs): 85 | """ 86 | Args: 87 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 88 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 89 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 90 | weights. Default: None. 
91 | """ 92 | return self.loss_weight * mse_loss( 93 | pred, target, weight, reduction=self.reduction) 94 | 95 | class PSNRLoss(nn.Module): 96 | 97 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False): 98 | super(PSNRLoss, self).__init__() 99 | assert reduction == 'mean' 100 | self.loss_weight = loss_weight 101 | self.scale = 10 / np.log(10) 102 | self.toY = toY 103 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1) 104 | self.first = True 105 | 106 | def forward(self, pred, target): 107 | assert len(pred.size()) == 4 108 | if self.toY: 109 | if self.first: 110 | self.coef = self.coef.to(pred.device) 111 | self.first = False 112 | 113 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 114 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 115 | 116 | pred, target = pred / 255., target / 255. 119 | 120 | return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean() 121 | 122 | 123 | class SRNLoss(nn.Module): 124 | 125 | def __init__(self): 126 | super(SRNLoss, self).__init__() 127 | 128 | def forward(self, preds, target): 129 | 130 | gt1 = target 131 | B,C,H,W = gt1.shape 132 | gt2 = F.interpolate(gt1, size=(H // 2, W // 2), mode='bilinear', align_corners=False) 133 | gt3 = F.interpolate(gt1, size=(H // 4, W // 4), mode='bilinear', align_corners=False) 134 | 135 | l1 = mse_loss(preds[0] , gt3) 136 | l2 = mse_loss(preds[1] , gt2) 137 | l3 = mse_loss(preds[2] , gt1) 138 | 139 | return l1+l2+l3 140 | 141 | 142 | 143 | class CharbonnierLoss(nn.Module): 144 | """Charbonnier loss (one variant of Robust L1Loss, a differentiable 145 | variant of L1Loss). 146 | Described in "Deep Laplacian Pyramid Networks for Fast and Accurate 147 | Super-Resolution". 148 | Args: 149 | loss_weight (float): Loss weight for L1 loss. Default: 1.0. 150 | reduction (str): Specifies the reduction to apply to the output. 151 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 152 | eps (float): A value used to control the curvature near zero. 153 | Default: 1e-12. 154 | """ 155 | 156 | def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12): 157 | super(CharbonnierLoss, self).__init__() 158 | if reduction not in ['none', 'mean', 'sum']: 159 | raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') 160 | 161 | self.loss_weight = loss_weight 162 | self.reduction = reduction 163 | self.eps = eps 164 | 165 | def forward(self, pred, target, weight=None, **kwargs): 166 | """ 167 | Args: 168 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 169 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 170 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 171 | weights. Default: None. 172 | """ 173 | return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction) 174 | 175 | 
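# Quick sanity checks for the losses above (synthetic tensors):
#
#     pred = torch.zeros(1, 3, 8, 8); target = torch.ones(1, 3, 8, 8)
#     L1Loss()(pred, target)            # -> 1.0
#     MSELoss()(pred, target)           # -> 1.0
#     CharbonnierLoss()(pred, target)   # -> ~1.0 (sqrt(diff^2 + eps))
#     PSNRLoss()(pred, target)          # -> ~0.0, (10 / ln 10) * ln(1)
#
# PSNRLoss returns (10 / ln 10) * ln(mse + 1e-8), i.e. the negative of the
# PSNR in dB, so minimizing it maximizes PSNR.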
180 | """ 181 | 182 | def __init__(self, loss_weight=1.0): 183 | super(WeightedTVLoss, self).__init__(loss_weight=loss_weight) 184 | 185 | def forward(self, pred, weight=None): 186 | if weight is None: 187 | y_weight = None 188 | x_weight = None 189 | else: 190 | y_weight = weight[:, :, :-1, :] 191 | x_weight = weight[:, :, :, :-1] 192 | 193 | y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight) 194 | x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight) 195 | 196 | loss = x_diff + y_diff 197 | 198 | return loss -------------------------------------------------------------------------------- /basicsr/utils/lmdb_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import lmdb 3 | import sys 4 | from multiprocessing import Pool 5 | from os import path as osp 6 | from tqdm import tqdm 7 | 8 | 9 | def make_lmdb_from_imgs(data_path, 10 | lmdb_path, 11 | img_path_list, 12 | keys, 13 | batch=5000, 14 | compress_level=1, 15 | multiprocessing_read=False, 16 | n_thread=40, 17 | map_size=None): 18 | """Make lmdb from images. 19 | 20 | Contents of lmdb. The file structure is: 21 | example.lmdb 22 | ├── data.mdb 23 | ├── lock.mdb 24 | ├── meta_info.txt 25 | 26 | The data.mdb and lock.mdb are standard lmdb files and you can refer to 27 | https://lmdb.readthedocs.io/en/release/ for more details. 28 | 29 | The meta_info.txt is a specified txt file to record the meta information 30 | of our datasets. It will be automatically created when preparing 31 | datasets by our provided dataset tools. 32 | Each line in the txt file records 1)image name (with extension), 33 | 2)image shape, and 3)compression level, separated by a white space. 34 | 35 | For example, the meta information could be: 36 | `000_00000000.png (720,1280,3) 1`, which means: 37 | 1) image name (with extension): 000_00000000.png; 38 | 2) image shape: (720,1280,3); 39 | 3) compression level: 1 40 | 41 | We use the image name without extension as the lmdb key. 42 | 43 | If `multiprocessing_read` is True, it will read all the images to memory 44 | using multiprocessing. Thus, your server needs to have enough memory. 45 | 46 | Args: 47 | data_path (str): Data path for reading images. 48 | lmdb_path (str): Lmdb save path. 49 | img_path_list (str): Image path list. 50 | keys (str): Used for lmdb keys. 51 | batch (int): After processing batch images, lmdb commits. 52 | Default: 5000. 53 | compress_level (int): Compress level when encoding images. Default: 1. 54 | multiprocessing_read (bool): Whether use multiprocessing to read all 55 | the images to memory. Default: False. 56 | n_thread (int): For multiprocessing. 57 | map_size (int | None): Map size for lmdb env. If None, use the 58 | estimated size from images. Default: None 59 | """ 60 | 61 | assert len(img_path_list) == len(keys), ( 62 | 'img_path_list and keys should have the same length, ' 63 | f'but got {len(img_path_list)} and {len(keys)}') 64 | print(f'Create lmdb for {data_path}, save to {lmdb_path}...') 65 | print(f'Total images: {len(img_path_list)}') 66 | if not lmdb_path.endswith('.lmdb'): 67 | raise ValueError("lmdb_path must end with '.lmdb'.") 68 | if osp.exists(lmdb_path): 69 | print(f'Folder {lmdb_path} already exists. 
70 |         sys.exit(1)
71 | 
72 |     if multiprocessing_read:
73 |         # read all the images to memory (multiprocessing)
74 |         dataset = {}  # use dict to keep the order for multiprocessing
75 |         shapes = {}
76 |         print(f'Read images with multiprocessing, #thread: {n_thread} ...')
77 |         pbar = tqdm(total=len(img_path_list), unit='image')
78 | 
79 |         def callback(arg):
80 |             """get the image data and update pbar."""
81 |             key, dataset[key], shapes[key] = arg
82 |             pbar.update(1)
83 |             pbar.set_description(f'Read {key}')
84 | 
85 |         pool = Pool(n_thread)
86 |         for path, key in zip(img_path_list, keys):
87 |             pool.apply_async(
88 |                 read_img_worker,
89 |                 args=(osp.join(data_path, path), key, compress_level),
90 |                 callback=callback)
91 |         pool.close()
92 |         pool.join()
93 |         pbar.close()
94 |         print(f'Finish reading {len(img_path_list)} images.')
95 | 
96 |     # create lmdb environment
97 |     if map_size is None:
98 |         # obtain data size for one image
99 |         img = cv2.imread(
100 |             osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
101 |         _, img_byte = cv2.imencode(
102 |             '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
103 |         data_size_per_img = img_byte.nbytes
104 |         print('Data size per image is: ', data_size_per_img)
105 |         data_size = data_size_per_img * len(img_path_list)
106 |         map_size = data_size * 10
107 | 
108 |     env = lmdb.open(lmdb_path, map_size=map_size)
109 | 
110 |     # write data to lmdb
111 |     pbar = tqdm(total=len(img_path_list), unit='chunk')
112 |     txn = env.begin(write=True)
113 |     txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
114 |     for idx, (path, key) in enumerate(zip(img_path_list, keys)):
115 |         pbar.update(1)
116 |         pbar.set_description(f'Write {key}')
117 |         key_byte = key.encode('ascii')
118 |         if multiprocessing_read:
119 |             img_byte = dataset[key]
120 |             h, w, c = shapes[key]
121 |         else:
122 |             _, img_byte, img_shape = read_img_worker(
123 |                 osp.join(data_path, path), key, compress_level)
124 |             h, w, c = img_shape
125 | 
126 |         txn.put(key_byte, img_byte)
127 |         # write meta information
128 |         txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
129 |         if idx % batch == 0:
130 |             txn.commit()
131 |             txn = env.begin(write=True)
132 |     pbar.close()
133 |     txn.commit()
134 |     env.close()
135 |     txt_file.close()
136 |     print('\nFinish writing lmdb.')
137 | 
138 | 
139 | def read_img_worker(path, key, compress_level):
140 |     """Read image worker.
141 | 
142 |     Args:
143 |         path (str): Image path.
144 |         key (str): Image key.
145 |         compress_level (int): Compress level when encoding images.
146 | 
147 |     Returns:
148 |         str: Image key.
149 |         byte: Image byte.
150 |         tuple[int]: Image shape.
151 |     """
152 | 
153 |     img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
154 |     if img.ndim == 2:
155 |         h, w = img.shape
156 |         c = 1
157 |     else:
158 |         h, w, c = img.shape
159 |     _, img_byte = cv2.imencode('.png', img,
160 |                                [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
161 |     return (key, img_byte, (h, w, c))
162 | 
163 | 
164 | class LmdbMaker:
165 |     """LMDB Maker.
166 | 
167 |     Args:
168 |         lmdb_path (str): Lmdb save path.
169 |         map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
170 |         batch (int): After processing batch images, lmdb commits.
171 |             Default: 5000.
172 |         compress_level (int): Compress level when encoding images. Default: 1.
173 | """ 174 | 175 | def __init__(self, 176 | lmdb_path, 177 | map_size=1024**4, 178 | batch=5000, 179 | compress_level=1): 180 | if not lmdb_path.endswith('.lmdb'): 181 | raise ValueError("lmdb_path must end with '.lmdb'.") 182 | if osp.exists(lmdb_path): 183 | print(f'Folder {lmdb_path} already exists. Exit.') 184 | sys.exit(1) 185 | 186 | self.lmdb_path = lmdb_path 187 | self.batch = batch 188 | self.compress_level = compress_level 189 | self.env = lmdb.open(lmdb_path, map_size=map_size) 190 | self.txn = self.env.begin(write=True) 191 | self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 192 | self.counter = 0 193 | 194 | def put(self, img_byte, key, img_shape): 195 | self.counter += 1 196 | key_byte = key.encode('ascii') 197 | self.txn.put(key_byte, img_byte) 198 | # write meta information 199 | h, w, c = img_shape 200 | self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') 201 | if self.counter % self.batch == 0: 202 | self.txn.commit() 203 | self.txn = self.env.begin(write=True) 204 | 205 | def close(self): 206 | self.txn.commit() 207 | self.env.close() 208 | self.txt_file.close() 209 | -------------------------------------------------------------------------------- /basicsr/metrics/niqe.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | from scipy.ndimage.filters import convolve 5 | from scipy.special import gamma 6 | 7 | from basicsr.metrics.metric_util import reorder_image, to_y_channel 8 | 9 | 10 | def estimate_aggd_param(block): 11 | """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) paramters. 12 | 13 | Args: 14 | block (ndarray): 2D Image block. 15 | 16 | Returns: 17 | tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD 18 | distribution (Estimating the parames in Equation 7 in the paper). 19 | """ 20 | block = block.flatten() 21 | gam = np.arange(0.2, 10.001, 0.001) # len = 9801 22 | gam_reciprocal = np.reciprocal(gam) 23 | r_gam = np.square(gamma(gam_reciprocal * 2)) / ( 24 | gamma(gam_reciprocal) * gamma(gam_reciprocal * 3)) 25 | 26 | left_std = np.sqrt(np.mean(block[block < 0]**2)) 27 | right_std = np.sqrt(np.mean(block[block > 0]**2)) 28 | gammahat = left_std / right_std 29 | rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2) 30 | rhatnorm = (rhat * (gammahat**3 + 1) * 31 | (gammahat + 1)) / ((gammahat**2 + 1)**2) 32 | array_position = np.argmin((r_gam - rhatnorm)**2) 33 | 34 | alpha = gam[array_position] 35 | beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha)) 36 | beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha)) 37 | return (alpha, beta_l, beta_r) 38 | 39 | 40 | def compute_feature(block): 41 | """Compute features. 42 | 43 | Args: 44 | block (ndarray): 2D Image block. 45 | 46 | Returns: 47 | list: Features with length of 18. 48 | """ 49 | feat = [] 50 | alpha, beta_l, beta_r = estimate_aggd_param(block) 51 | feat.extend([alpha, (beta_l + beta_r) / 2]) 52 | 53 | # distortions disturb the fairly regular structure of natural images. 54 | # This deviation can be captured by analyzing the sample distribution of 55 | # the products of pairs of adjacent coefficients computed along 56 | # horizontal, vertical and diagonal orientations. 57 | shifts = [[0, 1], [1, 0], [1, 1], [1, -1]] 58 | for i in range(len(shifts)): 59 | shifted_block = np.roll(block, shifts[i], axis=(0, 1)) 60 | alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block) 61 | # Eq. 
62 |         mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
63 |         feat.extend([alpha, mean, beta_l, beta_r])
64 |     return feat
65 | 
66 | 
67 | def niqe(img,
68 |          mu_pris_param,
69 |          cov_pris_param,
70 |          gaussian_window,
71 |          block_size_h=96,
72 |          block_size_w=96):
73 |     """Calculate NIQE (Natural Image Quality Evaluator) metric.
74 | 
75 |     Ref: Making a "Completely Blind" Image Quality Analyzer.
76 |     This implementation could produce almost the same results as the official
77 |     MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
78 | 
79 |     Note that we do not include block overlap height and width, since they are
80 |     always 0 in the official implementation.
81 | 
82 |     For good performance, the official implementation advises dividing the
83 |     distorted image into patches of the same size as those used for the
84 |     construction of the multivariate Gaussian model.
85 | 
86 |     Args:
87 |         img (ndarray): Input image whose quality needs to be computed. The
88 |             image must be a gray or Y (of YCbCr) image with shape (h, w).
89 |             Range [0, 255] with float type.
90 |         mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
91 |             model calculated on the pristine dataset.
92 |         cov_pris_param (ndarray): Covariance of a pre-defined multivariate
93 |             Gaussian model calculated on the pristine dataset.
94 |         gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the
95 |             image.
96 |         block_size_h (int): Height of the blocks into which the image is divided.
97 |             Default: 96 (the official recommended value).
98 |         block_size_w (int): Width of the blocks into which the image is divided.
99 |             Default: 96 (the official recommended value).
100 |     """
101 |     assert img.ndim == 2, (
102 |         'Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
103 |     # crop image
104 |     h, w = img.shape
105 |     num_block_h = math.floor(h / block_size_h)
106 |     num_block_w = math.floor(w / block_size_w)
107 |     img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]
108 | 
109 |     distparam = []  # distparam holds the multiscale features
110 |     for scale in (1, 2):  # perform on two scales (1, 2)
111 |         mu = convolve(img, gaussian_window, mode='nearest')
112 |         sigma = np.sqrt(
113 |             np.abs(
114 |                 convolve(np.square(img), gaussian_window, mode='nearest') -
115 |                 np.square(mu)))
116 |         # normalize, as in Eq. 1 in the paper
117 |         img_normalized = (img - mu) / (sigma + 1)
118 | 
119 |         feat = []
120 |         for idx_w in range(num_block_w):
121 |             for idx_h in range(num_block_h):
122 |                 # process each block
123 |                 block = img_normalized[idx_h * block_size_h //
124 |                                        scale:(idx_h + 1) * block_size_h //
125 |                                        scale, idx_w * block_size_w //
126 |                                        scale:(idx_w + 1) * block_size_w //
127 |                                        scale]
128 |                 feat.append(compute_feature(block))
129 | 
130 |         distparam.append(np.array(feat))
131 |         # TODO: matlab bicubic downsample with anti-aliasing
132 |         # for simplicity, now we use opencv instead, which will result in
133 |         # a slight difference.
134 |         if scale == 1:
135 |             h, w = img.shape
136 |             img = cv2.resize(
137 |                 img / 255., (w // 2, h // 2), interpolation=cv2.INTER_LINEAR)
138 |             img = img * 255.
139 | 
140 |     distparam = np.concatenate(distparam, axis=1)
141 | 
142 |     # fit a MVG (multivariate Gaussian) model to distorted patch features
143 |     mu_distparam = np.nanmean(distparam, axis=0)
144 |     # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
145 |     distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
146 |     cov_distparam = np.cov(distparam_no_nan, rowvar=False)
147 | 
148 |     # compute niqe quality, Eq. 10 in the paper
149 |     invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
150 |     quality = np.matmul(
151 |         np.matmul((mu_pris_param - mu_distparam), invcov_param),
152 |         np.transpose((mu_pris_param - mu_distparam)))
153 |     quality = np.sqrt(quality)
154 | 
155 |     return quality
156 | 
157 | 
158 | def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y'):
159 |     """Calculate NIQE (Natural Image Quality Evaluator) metric.
160 | 
161 |     Ref: Making a "Completely Blind" Image Quality Analyzer.
162 |     This implementation could produce almost the same results as the official
163 |     MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
164 | 
165 |     We use the official params estimated from the pristine dataset.
166 |     We use the recommended block size (96, 96) without overlaps.
167 | 
168 |     Args:
169 |         img (ndarray): Input image whose quality needs to be computed.
170 |             The input image must be in range [0, 255] with float/int type.
171 |             The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order)
172 |             If the input order is 'HWC' or 'CHW', it will be converted to gray
173 |             or Y (of YCbCr) image according to the ``convert_to`` argument.
174 |         crop_border (int): Cropped pixels in each edge of an image. These
175 |             pixels are not involved in the metric calculation.
176 |         input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
177 |             Default: 'HWC'.
178 |         convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'.
179 |             Default: 'y'.
180 | 
181 |     Returns:
182 |         float: NIQE result.
183 |     """
184 | 
185 |     # we use the official params estimated from the pristine dataset.
186 |     niqe_pris_params = np.load('basicsr/metrics/niqe_pris_params.npz')
187 |     mu_pris_param = niqe_pris_params['mu_pris_param']
188 |     cov_pris_param = niqe_pris_params['cov_pris_param']
189 |     gaussian_window = niqe_pris_params['gaussian_window']
190 | 
191 |     img = img.astype(np.float32)
192 |     if input_order != 'HW':
193 |         img = reorder_image(img, input_order=input_order)
194 |         if convert_to == 'y':
195 |             img = to_y_channel(img)
196 |         elif convert_to == 'gray':
197 |             img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
198 |         img = np.squeeze(img)
199 | 
200 |     if crop_border != 0:
201 |         img = img[crop_border:-crop_border, crop_border:-crop_border]
202 | 
203 |     niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window)
204 | 
205 |     return niqe_result
206 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/mefnet-multi-scale-event-fusion-network-for/deblurring-on-gopro)](https://paperswithcode.com/sota/deblurring-on-gopro?p=mefnet-multi-scale-event-fusion-network-for)
2 | 
3 | Event-based Fusion for Motion Deblurring with Cross-modal Attention
4 | ---
5 | #### Lei Sun, Christos Sakaridis, Jingyun Liang, Qi Jiang, Kailun Yang, Peng Sun, Yaozu Ye, Kaiwei Wang, Luc Van Gool
6 | #### Paper: https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136780403.pdf
7 | > Traditional frame-based cameras inevitably suffer from motion blur due to long exposure times.
As a kind of bio-inspired camera, the event camera records the intensity changes in an asynchronous way with high temporal resolution, providing valid image degradation information within the exposure time. In this paper, we rethink the event-based image deblurring problem and unfold it into an end-to-end two-stage image restoration network. To effectively fuse event and image features, we design an event-image cross-modal attention module applied at multiple levels of our network, which allows the network to focus on relevant features from the event branch and filter out noise. We also introduce a novel symmetric cumulative event representation specifically for image deblurring, as well as an event mask gated connection between the two stages of our network which helps avoid information loss. At the dataset level, to foster event-based motion deblurring and to facilitate evaluation on challenging real-world images, we introduce the Real Event Blur (REBlur) dataset, captured with an event camera in an illumination-controlled optical laboratory. Our Event Fusion Network (EFNet) sets the new state of the art in motion deblurring, surpassing both the prior best-performing image-based method and all event-based methods with public implementations on the GoPro dataset (by up to 2.47dB) and on our REBlur dataset, even in extreme blurry conditions.
8 | 
9 | 
10 | ### News
11 | - February 2025: [NTIRE 2025](https://cvlai.net/ntire/2025/) [the First Challenge on Event-Based Image Deblurring](https://codalab.lisn.upsaclay.fr/competitions/21498) starts! An example config file is provided at `./options/train/HighREV/EFNet_HighREV_Deblur.yml`
12 | - October 2024: Real-time deblurring system built! Please check our video demo on this GitHub page.
13 | - September 25: Updated the dataset links. Datasets are available now.
14 | - July 14 2022: :tada: :tada: Our paper was accepted to ECCV 2022 as an oral presentation (top 2.7% of submissions).
15 | - July 14 2022: The repository is under construction.
16 | 
17 | ### NTIRE 2025 the First Challenge on Event-Based Image Deblurring
18 | We provide brief instructions here. For more details, please refer to [this repo](https://github.com/AHupuJR/NTIRE2025_EventDeblur_challenge)
19 | 
20 | #### HighREV dataset download
21 | 
22 | [HighREV dataset](https://codalab.lisn.upsaclay.fr/my/datasets/download/9f275580-9b38-4984-b995-1e59e96b6111): Mandatory, with raw events.
23 | 
24 | command: `wget -O HighREV.zip "https://codalab.lisn.upsaclay.fr/my/datasets/download/9f275580-9b38-4984-b995-1e59e96b6111"`
25 | 
26 | 
27 | [Processed voxel grid of events](https://codalab.lisn.upsaclay.fr/my/datasets/download/c83e95ab-d4e6-4b9f-b7de-e3d3b45356e3): Optional. Using the processed voxel grids can speed up data processing during training.
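After downloading, you can sanity-check a processed voxel file with a few lines of numpy. This is a minimal sketch: the file path and the array key inside the `.npz` are illustrative assumptions and may differ in the actual release.

```python
import numpy as np

# hypothetical path -- point this at any .npz from the voxel-grid download
with np.load('HighREV/train/voxel/example.npz') as f:
    print(list(f.keys()))             # names of the stored arrays
    voxel = f[list(f.keys())[0]]      # e.g. an event voxel grid of shape (bins, H, W)
    print(voxel.shape, voxel.dtype)
```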
28 | 
29 | 
30 | #### Dataset codes:
31 | `./basicsr/data/npz_image_dataset.py` for raw events
32 | `./basicsr/data/voxelnpz_image_dataset.py` for voxel grids
33 | 
34 | 
35 | #### Example:
36 | ```
37 | git clone https://github.com/AHupuJR/EFNet
38 | cd EFNet
39 | pip install -r requirements.txt
40 | python setup.py develop --no_cuda_ext
41 | 
42 | python ./basicsr/train.py -opt options/train/HighREV/EFNet_HighREV_Deblur.yml
43 | ```
44 | 
45 | 
46 | ### Real-time system
47 | https://github.com/AHupuJR/EFNet/assets/35628326/b888bb41-33f9-46de-b1b6-50d5291bf4ed
48 | 
49 | 
50 | [Bilibili video](https://www.bilibili.com/video/BV158411r7nY/?spm_id_from=333.999.0.0&vd_source=4c2f67c9c80eb20c01a758e38fc07198)
51 | 
52 | Contributed by Xiaolei Gu, Xiao Jin (Jiaxing Research Institute, Zhejiang University), Yuhan Bao (Zhejiang University).
53 | 
54 | From left to right: blurry image, deblurred image, visualized events.
55 | 
56 | Feel free to contact me about potential applications.
57 | 
58 | ### Network Architecture
59 | 
60 | <img src="figures/models.png" alt="arch">
61 | 
62 | 
63 | ### Symmetric Cumulative Event Representation (SCER)
64 | 
65 | <img src="figures/scer.png" alt="scer">
66 | 
67 | ### Results
68 | <details><summary>GoPro dataset (Click to expand)</summary>
69 | <img src="figures/qualitative_GoPro_1.jpg" alt="gopro1">
70 | <img src="figures/qualitative_GoPro_2.png" alt="gopro2">
71 | <img src="figures/table_gopro.png" alt="gopro_table">
72 | </details>
73 | 74 |
74 | <details><summary>REBlur dataset (Click to expand)</summary>
75 | <img src="figures/qualitative_REBlur_1.jpg" alt="reblur1">
76 | <img src="figures/qualitative_REBlur_2.png" alt="reblur2">
77 | <img src="figures/table_reblur.png" alt="reblur_table">
78 | </details>
79 | 
80 | ### Installation
81 | This implementation is based on [BasicSR](https://github.com/xinntao/BasicSR), an open-source toolbox for image/video restoration tasks.
82 | 
83 | ```
84 | python 3.8.5
85 | pytorch 1.7.1
86 | cuda 11.0
87 | ```
88 | 
89 | 
90 | 
91 | ```
92 | git clone https://github.com/AHupuJR/EFNet
93 | cd EFNet
94 | pip install -r requirements.txt
95 | python setup.py develop --no_cuda_ext
96 | ```
97 | 
98 | ### Dataset
99 | Use GoPro events to train the model. If you want to use your own event representation instead of SCER, download the GoPro raw events and use EFNet/scripts/data_preparation/make_voxels_esim.py to produce your own event representation.
100 | 
101 | GoPro with SCER: [[ETH_share_link](https://data.vision.ee.ethz.ch/csakarid/shared/EFNet/GOPRO.zip)] [[BaiduYunPan](https://pan.baidu.com/s/1TxWdMB2LjdlgIvuc6QN-Bg)/code: 3wm8]
102 | 
103 | REBlur with SCER: [[ETH_share_link](https://data.vision.ee.ethz.ch/csakarid/shared/EFNet/REBlur.zip)] [[BaiduYunPan](https://pan.baidu.com/s/13v0CjlFUXt9TxXI0Co9tQQ?pwd=f6ha#list/path=%2F)/code:f6ha]
104 | 
105 | We also provide scripts in [./scripts/data_preparation/](./scripts/data_preparation/) to convert raw event files to SCER. You can also design your own event representation by modifying the scripts. Raw event files download:
106 | 
107 | GoPro with raw events: [[ETH_share_link](https://data.vision.ee.ethz.ch/csakarid/shared/EFNet/GOPRO_rawevents.zip)] [[BaiduYunPan](link)/TODO]
108 | 
109 | REBlur with raw events: [[ETH_share_link](https://data.vision.ee.ethz.ch/csakarid/shared/EFNet/REBlur_rawevents.zip)] [[BaiduYunPan](link)/TODO]
110 | 
111 | 
112 | ### Train
113 | ---
114 | #### GoPro
115 | 
116 | * prepare data
117 | 
118 |   * download the GoPro events dataset (see [Dataset](#dataset)) to
119 |   ```bash
120 |   ./datasets
121 |   ```
122 | 
123 |   * the directory structure should look like:
124 | 
125 |   ```bash
126 |   ./datasets/
127 |   ./datasets/DATASET_NAME/
128 |   ./datasets/DATASET_NAME/train/
129 |   ./datasets/DATASET_NAME/test/
130 |   ```
131 | 
132 | * train
133 | 
134 |   * ```python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/GoPro/EFNet.yml --launcher pytorch```
135 | 
136 | * eval
137 |   * Download [pretrained model](https://drive.google.com/file/d/19O-B-K4IODMENQblwHbSqNu0TX0IF-iA/view?usp=sharing) to ./experiments/pretrained_models/EFNet-GoPro.pth
138 |   * ```python basicsr/test.py -opt options/test/GoPro/EFNet.yml```
139 | 
140 | 
141 | #### REBlur
142 | 
143 | * prepare data
144 | 
145 |   * download the REBlur dataset (see [Dataset](#dataset)) to
146 |   ```bash
147 |   ./datasets
148 |   ```
149 | 
150 |   * the directory structure should look like:
151 | 
152 |   ```bash
153 |   ./datasets/
154 |   ./datasets/DATASET_NAME/
155 |   ./datasets/DATASET_NAME/train/
156 |   ./datasets/DATASET_NAME/test/
157 |   ```
158 | 
159 | * finetune
160 | 
161 |   * ```python ./basicsr/train.py -opt options/train/REBlur/Finetune_EFNet.yml```
162 | 
163 | * eval
164 |   * Download [pretrained model](https://drive.google.com/file/d/1yMGnwfYsxWbVp7r-oc8ls9qnOEDavG3h/view?usp=sharing) to ./experiments/pretrained_models/EFNet-REBlur.pth
165 |   * ```python basicsr/test.py -opt options/test/REBlur/Finetune_EFNet.yml```
166 | 
167 | 
168 | ### Qualitative results
169 | All the qualitative results can be downloaded through Google Drive:
170 | 
171 | [GoPro test](https://drive.google.com/file/d/17jXR5U9e3-8dXPUxB0-wBhDFg60Oe8US/view?usp=sharing)
172 | 
173 | [REBlur test](https://drive.google.com/file/d/17jXR5U9e3-8dXPUxB0-wBhDFg60Oe8US/view?usp=sharing)
174 | 
175 | [REBlur addition](https://drive.google.com/file/d/17jXR5U9e3-8dXPUxB0-wBhDFg60Oe8US/view?usp=sharing)
176 | 
177 | 
178 | ### Citations
179 | 
180 | ```
181 | @inproceedings{sun2022event,
182 |   title={Event-Based Fusion for Motion Deblurring with Cross-modal Attention},
183 |   author={Sun, Lei and Sakaridis, Christos and Liang, Jingyun and Jiang, Qi and Yang, Kailun and Sun, Peng and Ye, Yaozu and Wang, Kaiwei and Gool, Luc Van},
184 |   booktitle={European Conference on Computer Vision},
185 |   pages={412--428},
186 |   year={2022},
187 |   organization={Springer}
188 | }
189 | ```
190 | 
191 | 
192 | ### Contact
193 | Should you have any questions, please feel free to contact leo_sun@zju.edu.cn or leisun@ee.ethz.ch.
194 | 
195 | 
196 | ### License and Acknowledgement
197 | 
198 | This project is under the Apache 2.0 license, and it is based on [BasicSR](https://github.com/xinntao/BasicSR), which is also under the Apache 2.0 license. Thanks for the inspiration and code from [HINet](https://github.com/megvii-model/HINet) and [event_utils](https://github.com/TimoStoff/event_utils).
199 | 
200 | 
201 | 
--------------------------------------------------------------------------------
/scripts/data_preparation/raw_event_dataset.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | from .base_dataset import BaseVoxelDataset
3 | import pandas as pd
4 | import numpy as np
5 | 
6 | class GoproEsimH5Dataset(BaseVoxelDataset):
7 |     """
8 |     Dataloader for events saved in the Monash University HDF5 events format.
9 |     For GOPRO ESIM data the frame index is shifted by one (index + 1).
10 |     """
11 | 
12 |     def get_frame(self, index):
13 |         """
14 |         discard the first and the last frame in GoproEsim
15 |         """
16 |         return self.h5_file['images']['image{:09d}'.format(index+1)][:]
17 | 
18 |     def get_gt_frame(self, index):
19 |         """
20 |         discard the first and the last frame in GoproEsim
21 |         """
22 |         return self.h5_file['sharp_images']['image{:09d}'.format(index+1)][:]
23 | 
24 |     def get_events(self, idx0, idx1):
25 |         xs = self.h5_file['events/xs'][idx0:idx1]
26 |         ys = self.h5_file['events/ys'][idx0:idx1]
27 |         ts = self.h5_file['events/ts'][idx0:idx1]
28 |         ps = self.h5_file['events/ps'][idx0:idx1] * 2.0 - 1.0  # map polarity {0, 1} to {-1, 1}
29 |         return xs, ys, ts, ps
30 | 
31 |     def load_data(self, data_path):
32 |         self.data_sources = ('esim', 'ijrr', 'mvsec', 'eccd', 'hqfd', 'unknown')
33 |         try:
34 |             self.h5_file = h5py.File(data_path, 'r')
35 |         except OSError as err:
36 |             print("Couldn't open {}: {}".format(data_path, err))
37 | 
38 |         if self.sensor_resolution is None:
39 |             self.sensor_resolution = self.h5_file.attrs['sensor_resolution'][0:2]
40 |         else:
41 |             self.sensor_resolution = self.sensor_resolution[0:2]
42 |         print("sensor resolution = {}".format(self.sensor_resolution))
43 |         self.has_flow = 'flow' in self.h5_file.keys() and len(self.h5_file['flow']) > 0
44 |         self.t0 = self.h5_file['events/ts'][0]
45 |         self.tk = self.h5_file['events/ts'][-1]
46 |         self.num_events = self.h5_file.attrs["num_events"]
47 |         self.num_frames = self.h5_file.attrs["num_imgs"]
48 | 
49 |         self.frame_ts = []
50 |         for img_name in self.h5_file['images']:
51 |             self.frame_ts.append(self.h5_file['images/{}'.format(img_name)].attrs['timestamp'])
52 | 
53 |         data_source = self.h5_file.attrs.get('source', 'unknown')
54 |         try:
55 |             self.data_source_idx = self.data_sources.index(data_source)
56 |         except ValueError:
57 |             self.data_source_idx = -1
58 | 
59 |     def find_ts_index(self, timestamp):
60 |         idx = binary_search_h5_dset(self.h5_file['events/ts'], timestamp)
61 |         return idx
62 | 
63 |     def ts(self, index):
64 |         return self.h5_file['events/ts'][index]
65 | 
66 |     def compute_frame_indices(self):
67 |         frame_indices = []
68 |         start_idx = 0
69 |         for img_name in self.h5_file['images']:
70 |             end_idx = self.h5_file['images/{}'.format(img_name)].attrs['event_idx']
71 |             frame_indices.append([start_idx, end_idx])
72 |             start_idx = end_idx
73 |         return frame_indices
74 | 
75 | 
76 | class SeemsH5Dataset(BaseVoxelDataset):
77 |     """
78 |     Dataloader for events saved in the Monash University HDF5 events format.
79 |     The H5 data additionally contains per-frame exposure information.
80 |     """
81 | 
82 |     def get_frame(self, index):
83 |         """
84 |         frames are indexed directly (no offset, unlike GoproEsim)
85 |         """
86 |         return self.h5_file['images']['image{:09d}'.format(index)][:]
87 | 
88 |     def get_gt_frame(self, index):
89 |         """
90 |         frames are indexed directly (no offset, unlike GoproEsim)
91 |         """
92 |         return self.h5_file['sharp_images']['image{:09d}'.format(index)][:]
93 | 
94 |     def get_events(self, idx0, idx1):
95 |         xs = self.h5_file['events/xs'][idx0:idx1]
96 |         ys = self.h5_file['events/ys'][idx0:idx1]
97 |         ts = self.h5_file['events/ts'][idx0:idx1]
98 |         ps = self.h5_file['events/ps'][idx0:idx1] * 2.0 - 1.0  # map polarity {0, 1} to {-1, 1}
99 |         return xs, ys, ts, ps
100 | 
101 |     def load_data(self, data_path):
102 |         self.data_sources = ('esim', 'ijrr', 'mvsec', 'eccd', 'hqfd', 'unknown')
103 |         try:
104 |             self.h5_file = h5py.File(data_path, 'r')
105 |         except OSError as err:
106 |             print("Couldn't open {}: {}".format(data_path, err))
107 | 
108 |         if self.sensor_resolution is None:
109 |             self.sensor_resolution = self.h5_file.attrs['sensor_resolution'][0:2]
110 |         else:
111 |             self.sensor_resolution = self.sensor_resolution[0:2]
112 |         print("sensor resolution = {}".format(self.sensor_resolution))
113 |         self.has_flow = 'flow' in self.h5_file.keys() and len(self.h5_file['flow']) > 0
114 |         self.t0 = self.h5_file['events/ts'][0]
115 |         self.tk = self.h5_file['events/ts'][-1]
116 |         self.num_events = self.h5_file.attrs["num_events"]
117 |         self.num_frames = self.h5_file.attrs["num_imgs"]
118 | 
119 |         self.frame_ts = []
120 |         if self.has_exposure_time:
121 |             self.frame_exposure_start = []
122 |             self.frame_exposure_end = []
123 |             self.frame_exposure_time = []
124 |         for img_name in self.h5_file['images']:
125 |             self.frame_ts.append(self.h5_file['images/{}'.format(img_name)].attrs['timestamp'])
126 |             if self.has_exposure_time:
127 |                 self.frame_exposure_start.append(self.h5_file['images/{}'.format(img_name)].attrs['exposure_start'])
128 |                 self.frame_exposure_end.append(self.h5_file['images/{}'.format(img_name)].attrs['exposure_end'])
129 |                 self.frame_exposure_time.append(self.h5_file['images/{}'.format(img_name)].attrs['exposure_time'])
130 |         data_source = self.h5_file.attrs.get('source', 'unknown')
131 |         try:
132 |             self.data_source_idx = self.data_sources.index(data_source)
133 |         except ValueError:
134 |             self.data_source_idx = -1
135 | 
136 |     def find_ts_index(self, timestamp):
137 |         idx = binary_search_h5_dset(self.h5_file['events/ts'], timestamp)
138 |         return idx
139 | 
140 |     def ts(self, index):
141 |         return self.h5_file['events/ts'][index]
142 | 
143 |     def compute_frame_indices(self):
144 |         frame_indices = []
145 |         start_idx = 0
146 |         for img_name in self.h5_file['images']:
147 |             end_idx = self.h5_file['images/{}'.format(img_name)].attrs['event_idx']
148 |             frame_indices.append([start_idx, end_idx])
149 |             start_idx = end_idx
150 |         return frame_indices
151 | 
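# For reference, the loaders above assume an HDF5 layout along these lines
# (a sketch inferred from the accessors in this file; exact attribute names
# may vary between datasets):
#
#   file.h5
#   ├── events/xs, events/ys, events/ts, events/ps    # 1-D per-event arrays
#   ├── images/image{:09d}         # blurry frames; attrs: 'timestamp',
#   │                              # 'event_idx', and, for SeemsH5Dataset,
#   │                              # 'exposure_start'/'exposure_end'/'exposure_time'
#   ├── sharp_images/image{:09d}   # ground-truth sharp frames
#   └── attrs: 'sensor_resolution', 'num_events', 'num_imgs', 'source'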
152 | 
153 | class GoproEsimH52NpzDataset(BaseVoxelDataset):
154 |     """
155 |     Variant of GoproEsimH5Dataset that reads frame timestamps and image paths from an external CSV file (used when converting H5 events to npz).
156 |     """
157 | 
158 |     def get_frame(self, index):
159 |         """
160 |         discard the first and the last frame in GoproEsim
161 |         """
162 |         return self.h5_file['images']['image{:09d}'.format(index + 1)][:]
163 | 
164 |     def get_gt_frame(self, index):
165 |         """
166 |         discard the first and the last frame in GoproEsim
167 |         """
168 |         return self.h5_file['sharp_images']['image{:09d}'.format(index + 1)][:]
169 | 
170 |     def get_events(self, idx0, idx1):
171 |         xs = self.h5_file['events/xs'][idx0:idx1]
172 |         ys = self.h5_file['events/ys'][idx0:idx1]
173 |         ts = self.h5_file['events/ts'][idx0:idx1]
174 |         ps = self.h5_file['events/ps'][idx0:idx1] * 2.0 - 1.0  # map polarity {0, 1} to {-1, 1}
175 |         return xs, ys, ts, ps
176 | 
177 |     def load_data(self, data_path, csv_path):
178 |         self.data_sources = ('esim', 'ijrr', 'mvsec', 'eccd', 'hqfd', 'unknown')
179 |         try:
180 |             self.h5_file = h5py.File(data_path, 'r')
181 |         except OSError as err:
182 |             print("Couldn't open {}: {}".format(data_path, err))
183 |         pd_file = pd.read_csv(csv_path, header=None, names=['t', 'path'], dtype={'t': np.float64},
184 |                               engine='c')
185 | 
186 |         img_time_stamps = pd_file.t.values.astype(np.float64)
187 |         img_names = pd_file.path.values
188 | 
189 |         if self.sensor_resolution is None:
190 |             self.sensor_resolution = self.h5_file.attrs['sensor_resolution'][0:2]
191 |         else:
192 |             self.sensor_resolution = self.sensor_resolution[0:2]
193 | 
194 |         print("sensor resolution = {}".format(self.sensor_resolution))
195 |         self.has_flow = 'flow' in self.h5_file.keys() and len(self.h5_file['flow']) > 0
196 | 
197 |         self.t0 = self.h5_file['events/ts'][0]
198 |         self.tk = self.h5_file['events/ts'][-1]
199 |         self.num_events = self.h5_file.attrs["num_events"]
200 | 
201 | 
202 |         self.frame_ts = list(img_time_stamps)
203 |         self.img_names = list(img_names)
204 |         self.num_frames = len(self.frame_ts)
205 | 
206 |         self.data_source_idx = -1
207 | 
208 | 
209 |     def find_ts_index(self, timestamp):
210 |         idx = binary_search_h5_dset(self.h5_file['events/ts'], timestamp)
211 |         return idx
212 | 
213 |     def ts(self, index):
214 |         return self.h5_file['events/ts'][index]
215 | 
216 |     def compute_frame_indices(self):
217 |         frame_indices = []
218 |         start_idx = 0
219 |         for img_name in self.h5_file['images']:
220 |             end_idx = self.h5_file['images/{}'.format(img_name)].attrs['event_idx']
221 |             frame_indices.append([start_idx, end_idx])
222 |             start_idx = end_idx
223 |         return frame_indices
224 | 
225 | 
226 | def binary_search_h5_dset(dset, x, l=None, r=None, side='left'):
227 |     """
228 |     Binary search for a timestamp in an HDF5 event file, without
229 |     loading the entire file into RAM
230 |     @param dset The HDF5 dataset
231 |     @param x The timestamp being searched for
232 |     @param l Starting guess for the left side (0 if None is chosen)
233 |     @param r Starting guess for the right side (len(dset) - 1 if None is chosen)
234 |     @param side Which side to take final result for if exact match is not found
235 |     @returns Index of nearest event to 'x'
236 |     """
237 |     l = 0 if l is None else l
238 |     r = len(dset)-1 if r is None else r
239 |     while l <= r:
240 |         mid = l + (r - l)//2
241 |         midval = dset[mid]
242 |         if midval == x:
243 |             return mid
244 |         elif midval < x:
245 |             l = mid + 1
246 |         else:
247 |             r = mid - 1
248 |     if side == 'left':
249 |         return l
250 |     return r
251 | 
252 | 
253 | 
--------------------------------------------------------------------------------