├── VERSION
├── basicsr
│   ├── __init__.py
│   ├── metrics
│   │   ├── niqe_pris_params.npz
│   │   ├── __pycache__
│   │   │   ├── niqe.cpython-38.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── psnr_ssim.cpython-38.pyc
│   │   │   └── metric_util.cpython-38.pyc
│   │   ├── __init__.py
│   │   ├── metric_util.py
│   │   ├── fid.py
│   │   └── niqe.py
│   ├── __pycache__
│   │   ├── version.cpython-38.pyc
│   │   └── __init__.cpython-38.pyc
│   ├── utils
│   │   ├── __pycache__
│   │   │   ├── logger.cpython-38.pyc
│   │   │   ├── misc.cpython-38.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── img_util.cpython-38.pyc
│   │   │   ├── options.cpython-38.pyc
│   │   │   ├── create_lmdb.cpython-38.pyc
│   │   │   ├── dist_util.cpython-38.pyc
│   │   │   ├── file_client.cpython-38.pyc
│   │   │   ├── flow_util.cpython-38.pyc
│   │   │   ├── lmdb_util.cpython-38.pyc
│   │   │   └── matlab_functions.cpython-38.pyc
│   │   ├── __init__.py
│   │   ├── download_util.py
│   │   ├── dist_util.py
│   │   ├── options.py
│   │   ├── create_lmdb.py
│   │   ├── npz2voxel.py
│   │   ├── misc.py
│   │   ├── img_util.py
│   │   ├── logger.py
│   │   ├── flow_util.py
│   │   ├── file_client.py
│   │   └── lmdb_util.py
│   ├── data
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── data_util.cpython-38.pyc
│   │   │   ├── data_sampler.cpython-38.pyc
│   │   │   ├── ffhq_dataset.cpython-38.pyc
│   │   │   ├── h5_augment.cpython-38.pyc
│   │   │   ├── reds_dataset.cpython-38.pyc
│   │   │   ├── transforms.cpython-38.pyc
│   │   │   ├── h5_image_dataset.cpython-38.pyc
│   │   │   ├── prefetch_dataloader.cpython-38.pyc
│   │   │   ├── paired_image_dataset.cpython-38.pyc
│   │   │   └── single_image_dataset.cpython-38.pyc
│   │   ├── data_sampler.py
│   │   ├── ffhq_dataset.py
│   │   ├── single_image_dataset.py
│   │   ├── prefetch_dataloader.py
│   │   ├── event_util.py
│   │   ├── __init__.py
│   │   ├── voxelnpz_image_dataset.py
│   │   └── npz_image_dataset.py
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── base_model.cpython-38.pyc
│   │   │   ├── lr_scheduler.cpython-38.pyc
│   │   │   ├── image_restoration_model.cpython-38.pyc
│   │   │   ├── Test_image_restoration_model.cpython-38.pyc
│   │   │   ├── image_event_restoration_model.cpython-38.pyc
│   │   │   └── Test_image_event_restoration_model.cpython-38.pyc
│   │   ├── archs
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── arch_util.cpython-38.pyc
│   │   │   │   └── EFNet_arch.cpython-38.pyc
│   │   │   └── __init__.py
│   │   ├── losses
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── losses.cpython-38.pyc
│   │   │   │   └── loss_util.cpython-38.pyc
│   │   │   ├── __init__.py
│   │   │   ├── loss_util.py
│   │   │   └── losses.py
│   │   ├── __init__.py
│   │   └── lr_scheduler.py
│   ├── version.py
│   ├── demo.py
│   └── test.py
├── figures
│   ├── scer.png
│   ├── ._SCER.pdf
│   ├── ._scer.png
│   ├── models.png
│   ├── modules.pdf
│   ├── ._models.pdf
│   ├── ._models.png
│   ├── ._modules.pdf
│   ├── table_gopro.png
│   ├── table_reblur.png
│   ├── ._table_gopro.png
│   ├── ._table_reblur.png
│   ├── qualitative_GoPro_1.jpg
│   ├── qualitative_GoPro_2.png
│   ├── qualitative_REBlur_1.jpg
│   ├── qualitative_REBlur_2.png
│   ├── ._qualitative_GoPro_1.jpg
│   ├── ._qualitative_GoPro_2.png
│   ├── ._qualitative_REBlur_1.jpg
│   └── ._qualitative_REBlur_2.png
├── .gitignore
├── requirements.txt
├── scripts
│   ├── download_gdrive.py
│   ├── data_preparation
│   │   ├── README.md
│   │   ├── noise_function.py
│   │   ├── make_voxels_esim.py
│   │   ├── make_voxels_real.py
│   │   └── raw_event_dataset.py
│   ├── publish_models.py
│   └── download_pretrained_models.py
├── datasets
│   └── README.md
├── setup.cfg
├── options
│   ├── test
│   │   ├── GoPro
│   │   │   └── EFNet.yml
│   │   └── REBlur
│   │       └── Finetune_EFNet.yml
│   └── train
│       ├── HighREV
│       │   ├── EFNet_HighREV_Deblur.yml
│       │   └── EFNet_HighREV_Deblur_voxel.yml
│       ├── GoPro
│       │   └── EFNet.yml
│       └── REBlur
│           └── Finetune_EFNet.yml
├── setup.py
└── README.md
/VERSION:
--------------------------------------------------------------------------------
1 | 1.2.0
2 |
--------------------------------------------------------------------------------
/basicsr/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/figures/scer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/scer.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .vscode/
3 | .eggs/
4 | basicsr.egg-info/
5 | experiments/
--------------------------------------------------------------------------------
/figures/._SCER.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/._SCER.pdf
--------------------------------------------------------------------------------
/figures/._scer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/._scer.png
--------------------------------------------------------------------------------
/figures/models.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/models.png
--------------------------------------------------------------------------------
/figures/modules.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/modules.pdf
--------------------------------------------------------------------------------
/figures/._models.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/._models.pdf
--------------------------------------------------------------------------------
/figures/._models.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/._models.png
--------------------------------------------------------------------------------
/figures/._modules.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/._modules.pdf
--------------------------------------------------------------------------------
/figures/table_gopro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/table_gopro.png
--------------------------------------------------------------------------------
/figures/table_reblur.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/table_reblur.png
--------------------------------------------------------------------------------
/figures/._table_gopro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/._table_gopro.png
--------------------------------------------------------------------------------
/figures/._table_reblur.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/._table_reblur.png
--------------------------------------------------------------------------------
/figures/qualitative_GoPro_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/qualitative_GoPro_1.jpg
--------------------------------------------------------------------------------
/figures/qualitative_GoPro_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/qualitative_GoPro_2.png
--------------------------------------------------------------------------------
/figures/qualitative_REBlur_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/qualitative_REBlur_1.jpg
--------------------------------------------------------------------------------
/figures/qualitative_REBlur_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/qualitative_REBlur_2.png
--------------------------------------------------------------------------------
/figures/._qualitative_GoPro_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/._qualitative_GoPro_1.jpg
--------------------------------------------------------------------------------
/figures/._qualitative_GoPro_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/._qualitative_GoPro_2.png
--------------------------------------------------------------------------------
/figures/._qualitative_REBlur_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/._qualitative_REBlur_1.jpg
--------------------------------------------------------------------------------
/figures/._qualitative_REBlur_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/figures/._qualitative_REBlur_2.png
--------------------------------------------------------------------------------
/basicsr/metrics/niqe_pris_params.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/metrics/niqe_pris_params.npz
--------------------------------------------------------------------------------
/basicsr/__pycache__/version.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/__pycache__/version.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/niqe.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/metrics/__pycache__/niqe.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/utils/__pycache__/logger.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/misc.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/utils/__pycache__/misc.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/data/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/data_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/data/__pycache__/data_util.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/img_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/utils/__pycache__/img_util.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/options.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/utils/__pycache__/options.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/data_sampler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/data/__pycache__/data_sampler.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/ffhq_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/data/__pycache__/ffhq_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/h5_augment.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/data/__pycache__/h5_augment.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/reds_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/data/__pycache__/reds_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/transforms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/data/__pycache__/transforms.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/metrics/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/psnr_ssim.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/metrics/__pycache__/psnr_ssim.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/__pycache__/base_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/models/__pycache__/base_model.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/create_lmdb.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/utils/__pycache__/create_lmdb.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/dist_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/utils/__pycache__/dist_util.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/file_client.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/utils/__pycache__/file_client.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/flow_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/utils/__pycache__/flow_util.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/lmdb_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/utils/__pycache__/lmdb_util.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/metric_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/metrics/__pycache__/metric_util.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/__pycache__/lr_scheduler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/models/__pycache__/lr_scheduler.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/h5_image_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/data/__pycache__/h5_image_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/archs/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/models/archs/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/archs/__pycache__/arch_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/models/archs/__pycache__/arch_util.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/losses/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/models/losses/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/losses/__pycache__/losses.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/models/losses/__pycache__/losses.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/matlab_functions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/utils/__pycache__/matlab_functions.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/prefetch_dataloader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/data/__pycache__/prefetch_dataloader.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/archs/__pycache__/EFNet_arch.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/models/archs/__pycache__/EFNet_arch.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/losses/__pycache__/loss_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/models/losses/__pycache__/loss_util.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/paired_image_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/data/__pycache__/paired_image_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/single_image_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/data/__pycache__/single_image_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/version.py:
--------------------------------------------------------------------------------
1 | # GENERATED VERSION FILE
2 | # TIME: Fri Feb 7 15:11:50 2025
3 | __version__ = '1.2.0+a8a710a'
4 | short_version = '1.2.0'
5 | version_info = (1, 2, 0)
6 |
--------------------------------------------------------------------------------
/basicsr/models/__pycache__/image_restoration_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/models/__pycache__/image_restoration_model.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/__pycache__/Test_image_restoration_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/models/__pycache__/Test_image_restoration_model.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/__pycache__/image_event_restoration_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/models/__pycache__/image_event_restoration_model.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .niqe import calculate_niqe
2 | from .psnr_ssim import calculate_psnr, calculate_ssim
3 |
4 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']
5 |
--------------------------------------------------------------------------------
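A quick, hedged usage sketch for the metrics exported above. It assumes the BasicSR-style signatures `calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False)` and the analogous `calculate_ssim`; check `psnr_ssim.py` for the exact arguments. The `crop_border=0` / `test_y_channel=False` values mirror the metric blocks in the training ymls:

```python
import numpy as np
from basicsr.metrics import calculate_psnr, calculate_ssim

restored = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
gt = restored.copy()  # identical images: PSNR is inf, SSIM is 1.0

print(calculate_psnr(restored, gt, crop_border=0, test_y_channel=False))
print(calculate_ssim(restored, gt, crop_border=0, test_y_channel=False))
```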
/requirements.txt:
--------------------------------------------------------------------------------
1 | addict
2 | future
3 | lmdb
4 | numpy
5 | opencv-python
6 | Pillow
7 | pyyaml
8 | requests
9 | scikit-image
10 | scipy
11 | tb-nightly
12 | tqdm
13 | yapf
14 | h5py
15 |
--------------------------------------------------------------------------------
/basicsr/models/__pycache__/Test_image_event_restoration_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AHupuJR/EFNet/HEAD/basicsr/models/__pycache__/Test_image_event_restoration_model.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .losses import (L1Loss, MSELoss, PSNRLoss, WeightedTVLoss, CharbonnierLoss)
2 |
3 | __all__ = [
4 | 'L1Loss', 'MSELoss', 'PSNRLoss', 'WeightedTVLoss', 'CharbonnierLoss'
5 | ]
6 |
--------------------------------------------------------------------------------
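`losses.py` itself is not included in this dump, so as orientation only: a minimal sketch of a PSNR-oriented loss consistent with the `pixel_opt: {type: PSNRLoss, loss_weight: 0.5, reduction: mean}` blocks in the training configs (the repo's actual definition may differ):

```python
import math

import torch
from torch import nn


class PSNRLossSketch(nn.Module):
    """Maximizing PSNR == minimizing log(MSE); constant offsets drop out."""

    def __init__(self, loss_weight=0.5, reduction='mean'):
        super().__init__()
        assert reduction == 'mean'
        self.loss_weight = loss_weight
        self.scale = 10.0 / math.log(10.0)  # convert ln to 10 * log10

    def forward(self, pred, target):
        mse = ((pred - target) ** 2).mean(dim=(1, 2, 3))  # per-sample MSE
        return self.loss_weight * self.scale * torch.log(mse + 1e-8).mean()
```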
/scripts/download_gdrive.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from basicsr.utils.download_util import download_file_from_google_drive
4 |
5 | if __name__ == '__main__':
6 | parser = argparse.ArgumentParser()
7 |
8 | parser.add_argument('--id', type=str, help='File id')
9 | parser.add_argument('--output', type=str, help='Save path')
10 | args = parser.parse_args()
11 |
12 |     download_file_from_google_drive(args.id, args.output)
13 |
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Symlink/Put all the datasets here
2 |
3 | It is recommended to symlink your dataset root into this folder (`datasets`) with the command `ln -s xxx yyy`.
4 |
5 | Please refer to [DatasetPreparation.md](../docs/DatasetPreparation.md) for more details about data preparation.
6 |
7 | ---
8 |
9 | It is recommended to symlink your data into the current `datasets` directory via `ln -s xxx yyy`.
10 | 
11 | For more details on data preparation, see [DatasetPreparation_CN.md](../docs/DatasetPreparation_CN.md) (Chinese version).
12 |
--------------------------------------------------------------------------------
/scripts/data_preparation/README.md:
--------------------------------------------------------------------------------
1 | ### Used to generate voxels from raw events (h5 files)
2 |
3 | Convert raw GoPro event file (h5) to voxel file (h5):
4 | ```
5 | python make_voxels_esim.py --input_path /your/path/to/raw/event/h5file --save_path /your/path/to/save/voxel/h5file --voxel_method SCER_esim
6 | ```
7 |
8 | Convert raw REBlur event file (h5) to voxel file (h5):
9 | ```
10 | python make_voxels_real.py --input_path /your/path/to/raw/event/h5file --save_path /your/path/to/save/voxel/h5file --voxel_method SCER_real_data
11 | ```
12 |
13 |
--------------------------------------------------------------------------------
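The two scripts referenced above implement the paper's SCER representation. For orientation only, here is a generic event-to-voxel accumulation with bilinear weighting in time, a common baseline representation and *not* the repo's SCER (see `make_voxels_esim.py` / `make_voxels_real.py` for that):

```python
import numpy as np


def events_to_voxel(xs, ys, ts, ps, bins, height, width):
    """Bin events (x, y, t, polarity) into `bins` temporal channels."""
    voxel = np.zeros((bins, height, width), dtype=np.float32)
    # normalize timestamps to [0, bins - 1]
    t_norm = (ts - ts.min()) / max(ts.max() - ts.min(), 1e-9) * (bins - 1)
    left = np.floor(t_norm).astype(int)
    right = np.clip(left + 1, 0, bins - 1)
    w_right = t_norm - left
    # each event splits its polarity between the two neighboring bins
    np.add.at(voxel, (left, ys, xs), ps * (1.0 - w_right))
    np.add.at(voxel, (right, ys, xs), ps * w_right)
    return voxel


# toy check: total polarity is preserved
xs = np.array([0, 1, 2]); ys = np.array([0, 1, 2])
ts = np.array([0.0, 0.5, 1.0]); ps = np.array([1.0, -1.0, 1.0])
print(events_to_voxel(xs, ys, ts, ps, bins=2, height=4, width=4).sum())  # 1.0
```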
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore =
3 | # line break before binary operator (W503)
4 | W503,
5 | # line break after binary operator (W504)
6 | W504,
7 | max-line-length=79
8 |
9 | [yapf]
10 | based_on_style = pep8
11 | blank_line_before_nested_class_or_def = true
12 | split_before_expression_after_opening_paren = true
13 |
14 | [isort]
15 | line_length = 79
16 | multi_line_output = 0
17 | known_standard_library = pkg_resources,setuptools
18 | known_first_party = basicsr
19 | known_third_party = PIL,cv2,lmdb,numpy,requests,scipy,skimage,torch,torchvision,tqdm,yaml
20 | no_lines_before = STDLIB,LOCALFOLDER
21 | default_section = THIRDPARTY
22 |
--------------------------------------------------------------------------------
/scripts/data_preparation/noise_function.py:
--------------------------------------------------------------------------------
1 | import random
2 | 
3 | import torch
4 | 
5 | 
6 | def put_hot_pixels_in_voxel_(voxel, hot_pixel_range=20, hot_pixel_fraction=0.00002):
7 |     num_hot_pixels = int(hot_pixel_fraction * voxel.shape[-1] * voxel.shape[-2])
8 |     x = torch.randint(0, voxel.shape[-1], (num_hot_pixels,))
9 |     y = torch.randint(0, voxel.shape[-2], (num_hot_pixels,))
10 |     for i in range(num_hot_pixels):
11 |         # voxel[..., :, y[i], x[i]] = random.uniform(-hot_pixel_range, hot_pixel_range)
12 |         voxel[..., :, y[i], x[i]] = random.randint(-hot_pixel_range, hot_pixel_range)
13 | 
14 | 
15 | def add_noise_to_voxel(voxel, noise_std=1.0, noise_fraction=0.1):
16 |     noise = noise_std * torch.randn_like(voxel)  # zero mean, std = noise_std
17 |     if noise_fraction < 1.0:
18 |         mask = torch.rand_like(voxel) >= noise_fraction
19 |         noise.masked_fill_(mask, 0)
20 |     return voxel + noise
21 | 
--------------------------------------------------------------------------------
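A small usage sketch for the two helpers above. It assumes you run it from `scripts/data_preparation/` so `noise_function` is importable; note the default `hot_pixel_fraction` yields zero hot pixels on small sensors, so a larger value is passed here:

```python
import torch

from noise_function import add_noise_to_voxel, put_hot_pixels_in_voxel_

voxel = torch.zeros(6, 180, 240)  # 6 temporal bins on a 180x240 sensor
noisy = add_noise_to_voxel(voxel, noise_std=1.0, noise_fraction=0.1)
put_hot_pixels_in_voxel_(noisy, hot_pixel_fraction=0.001)  # in-place (trailing underscore)
print(noisy.abs().max())
```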
/basicsr/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from basicsr.models import create_model
3 | from basicsr.train import parse_options
4 | from basicsr.utils import FileClient, imfrombytes, img2tensor, padding
5 |
6 | def main():
7 |     # parse options, set distributed setting, set random seed
8 | opt = parse_options(is_train=False)
9 |
10 | img_path = opt['img_path'].get('input_img')
11 | output_path = opt['img_path'].get('output_img')
12 |
13 |
14 | ## 1. read image
15 | file_client = FileClient('disk')
16 |
17 | img_bytes = file_client.get(img_path, None)
18 | try:
19 | img = imfrombytes(img_bytes, float32=True)
20 |     except Exception as e:
21 |         raise Exception("path {} not working".format(img_path)) from e
22 |
23 | img = img2tensor(img, bgr2rgb=True, float32=True)
24 |
25 |
26 |
27 | ## 2. run inference
28 | model = create_model(opt)
29 | model.single_image_inference(img, output_path)
30 |
31 | print('inference {} .. finished.'.format(img_path))
32 |
33 | if __name__ == '__main__':
34 | main()
35 |
36 |
--------------------------------------------------------------------------------
/basicsr/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .file_client import FileClient
2 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding
3 | from .logger import (MessageLogger, get_env_info, get_root_logger,
4 | init_tb_logger, init_wandb_logger)
5 | from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename,
6 | scandir, scandir_SIDD, set_random_seed, sizeof_fmt)
7 | from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k)
8 |
9 | __all__ = [
10 | # file_client.py
11 | 'FileClient',
12 | # img_util.py
13 | 'img2tensor',
14 | 'tensor2img',
15 | 'imfrombytes',
16 | 'imwrite',
17 | 'crop_border',
18 | # logger.py
19 | 'MessageLogger',
20 | 'init_tb_logger',
21 | 'init_wandb_logger',
22 | 'get_root_logger',
23 | 'get_env_info',
24 | # misc.py
25 | 'set_random_seed',
26 | 'get_time_str',
27 | 'mkdir_and_rename',
28 | 'make_exp_dirs',
29 | 'scandir',
30 | 'check_resume',
31 | 'sizeof_fmt',
32 | 'padding',
33 | 'create_lmdb_for_reds',
34 | 'create_lmdb_for_gopro',
35 | 'create_lmdb_for_rain13k',
36 | ]
37 |
--------------------------------------------------------------------------------
/basicsr/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from os import path as osp
3 |
4 | from basicsr.utils import get_root_logger, scandir
5 |
6 | # automatically scan and import model modules
7 | # scan all the files under the 'models' folder and collect files ending with
8 | # '_model.py'
9 | model_folder = osp.dirname(osp.abspath(__file__))
10 | model_filenames = [
11 | osp.splitext(osp.basename(v))[0] for v in scandir(model_folder)
12 | if v.endswith('_model.py')
13 | ]
14 | # import all the model modules
15 | _model_modules = [
16 | importlib.import_module(f'basicsr.models.{file_name}')
17 | for file_name in model_filenames
18 | ]
19 |
20 |
21 | def create_model(opt):
22 | """Create model.
23 |
24 | Args:
25 |         opt (dict): Configuration. It contains:
26 | model_type (str): Model type.
27 | """
28 | model_type = opt['model_type']
29 |
30 | # dynamic instantiation
31 | for module in _model_modules:
32 | model_cls = getattr(module, model_type, None)
33 | if model_cls is not None:
34 | break
35 | if model_cls is None:
36 | raise ValueError(f'Model {model_type} is not found.')
37 |
38 | model = model_cls(opt)
39 |
40 | logger = get_root_logger()
41 | logger.info(f'Model [{model.__class__.__name__}] is created.')
42 | return model
43 |
--------------------------------------------------------------------------------
/options/test/GoPro/EFNet.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: EFNet_test
3 | model_type: TestImageEventRestorationModel
4 | scale: 1
5 | num_gpu: 1 # set num_gpu: 0 for cpu mode
6 | manual_seed: 10
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | test:
11 | name: gopro-bestmodel-test
12 | type: H5ImageDataset
13 |
14 | dataroot: /cluster/work/cvl/leisun/GOPRO_fullsize_h5_bin6_ver3/test # for debug
15 |
16 | # add
17 | norm_voxel: true
18 | return_voxel: true
19 | return_gt_frame: false
20 | return_mask: true
21 | use_mask: true
22 |
23 | crop_size: ~
24 | use_flip: false
25 | use_rot: false
26 | io_backend:
27 | type: h5
28 |
29 | dataset_name: GoPro
30 |
31 | # network structures
32 | network_g:
33 | type: EFNet
34 | wf: 64
35 | fuse_before_downsample: true
36 |
37 |
38 | # path
39 | path:
40 | pretrain_network_g: /cluster/work/cvl/leisun/log/experiments/EVTransformer-Finetune-2e4-300iter/models/net_g_latest.pth
41 | strict_load_g: true
42 | resume_state: ~
43 | root: /cluster/work/cvl/leisun/EFNet_inference/ # set this option ONLY in TEST!!!
44 |
45 | # validation settings
46 | val:
47 | save_img: true
48 | grids: ~
49 | crop_size: ~
50 |   rgb2bgr: false # for my h5 data, it is false
51 |
52 | # dist training settings
53 | dist_params:
54 | backend: nccl
55 | port: 29500
56 |
--------------------------------------------------------------------------------
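A side note on two YAML conventions used throughout these configs: `~` parses to `None` (used for "unset" options such as `crop_size`), and `!!float 2e-4` forces a float, because dot-less scientific notation like `2e-4` is read as a *string* by YAML 1.1 parsers such as PyYAML:

```python
import yaml

cfg = yaml.safe_load("""
crop_size: ~
lr: !!float 2e-4
bare: 2e-4
""")
print(cfg)  # {'crop_size': None, 'lr': 0.0002, 'bare': '2e-4'}
```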
/basicsr/models/archs/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from os import path as osp
3 |
4 | from basicsr.utils import scandir
5 |
6 | # automatically scan and import arch modules
7 | # scan all the files under the 'archs' folder and collect files ending with
8 | # '_arch.py'
9 | arch_folder = osp.dirname(osp.abspath(__file__))
10 | arch_filenames = [
11 | osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder)
12 | if v.endswith('_arch.py')
13 | ]
14 | # import all the arch modules
15 | _arch_modules = [
16 | importlib.import_module(f'basicsr.models.archs.{file_name}')
17 | for file_name in arch_filenames
18 | ]
19 |
20 |
21 | def dynamic_instantiation(modules, cls_type, opt):
22 | """Dynamically instantiate class.
23 |
24 | Args:
25 | modules (list[importlib modules]): List of modules from importlib
26 | files.
27 | cls_type (str): Class type.
28 | opt (dict): Class initialization kwargs.
29 |
30 | Returns:
31 | class: Instantiated class.
32 | """
33 |
34 | for module in modules:
35 | cls_ = getattr(module, cls_type, None)
36 | if cls_ is not None:
37 | break
38 | if cls_ is None:
39 | raise ValueError(f'{cls_type} is not found.')
40 | return cls_(**opt)
41 |
42 |
43 | def define_network(opt):
44 | network_type = opt.pop('type')
45 | net = dynamic_instantiation(_arch_modules, network_type, opt)
46 | return net
47 |
--------------------------------------------------------------------------------
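Putting the pieces together, the `network_g` blocks in the ymls are consumed roughly as in the hedged sketch below. It assumes `EFNet_arch.py` defines an `EFNet` class accepting `wf` and `fuse_before_downsample`, which is what the configs suggest:

```python
from basicsr.models.archs import define_network

net_opt = {'type': 'EFNet', 'wf': 64, 'fuse_before_downsample': True}
net_g = define_network(net_opt)  # pops 'type', passes the rest as kwargs
```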
/basicsr/metrics/metric_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from basicsr.utils.matlab_functions import bgr2ycbcr
4 |
5 |
6 | def reorder_image(img, input_order='HWC'):
7 | """Reorder images to 'HWC' order.
8 |
9 | If the input_order is (h, w), return (h, w, 1);
10 | If the input_order is (c, h, w), return (h, w, c);
11 | If the input_order is (h, w, c), return as it is.
12 |
13 | Args:
14 | img (ndarray): Input image.
15 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
16 | If the input image shape is (h, w), input_order will not have
17 | effects. Default: 'HWC'.
18 |
19 | Returns:
20 | ndarray: reordered image.
21 | """
22 |
23 | if input_order not in ['HWC', 'CHW']:
24 | raise ValueError(
25 | f'Wrong input_order {input_order}. Supported input_orders are '
26 | "'HWC' and 'CHW'")
27 | if len(img.shape) == 2:
28 | img = img[..., None]
29 | if input_order == 'CHW':
30 | img = img.transpose(1, 2, 0)
31 | return img
32 |
33 |
34 | def to_y_channel(img):
35 | """Change to Y channel of YCbCr.
36 |
37 | Args:
38 | img (ndarray): Images with range [0, 255].
39 |
40 | Returns:
41 | (ndarray): Images with range [0, 255] (float type) without round.
42 | """
43 | img = img.astype(np.float32) / 255.
44 | if img.ndim == 3 and img.shape[2] == 3:
45 | img = bgr2ycbcr(img, y_only=True)
46 | img = img[..., None]
47 | return img * 255.
48 |
--------------------------------------------------------------------------------
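Usage sketch for `reorder_image`: a CHW array becomes HWC, and a 2-D grayscale image gains a trailing channel axis:

```python
import numpy as np

from basicsr.metrics.metric_util import reorder_image

chw = np.zeros((3, 64, 48))
print(reorder_image(chw, input_order='CHW').shape)  # (64, 48, 3)

gray = np.zeros((64, 48))
print(reorder_image(gray).shape)  # (64, 48, 1)
```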
/options/test/REBlur/Finetune_EFNet.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: EFNet_Finetune_test
3 | model_type: ImageEventRestorationModel
4 | scale: 1
5 | num_gpu: 1 # 4
6 | manual_seed: 10
7 |
8 | datasets:
9 | test:
10 |     name: reblur-test
11 | type: H5ImageDataset
12 |
13 |
14 | # dataroot: /cluster/work/cvl/leisun/REBlur_addition # seems h5
15 | dataroot: /cluster/work/cvl/leisun/REBlur/test # REBlur
16 |
17 | # keep true if use events
18 | norm_voxel: true
19 | return_voxel: true
20 |
21 |     return_mask: true # the dataloader also returns the mask (used in the loss)
22 |     use_mask: true # feed the mask to the model as an additional input
23 |
24 | filename_tmpl: '{}'
25 | io_backend:
26 | type: h5
27 |
28 | # gt_size: 256
29 | crop_size: 256
30 | use_flip: true
31 | use_rot: true
32 |
33 | # data loader settings
34 | use_shuffle: true
35 | num_worker_per_gpu: 3
36 | batch_size_per_gpu: 4
37 | dataset_enlarge_ratio: 1
38 | prefetch_mode: cpu
39 | num_prefetch_queue: 2
40 |
41 | dataset_name: REBlur
42 |
43 | # network structures
44 | network_g:
45 | type: EFNet
46 | wf: 64 #64
47 | fuse_before_downsample: true
48 |
49 |
50 | # path
51 | path:
52 | pretrain_network_g: /cluster/work/cvl/leisun/log/experiments/EV_Transformer_channelattention_simple_20witer/models/net_g_latest.pth
53 | strict_load_g: true
54 | resume_state: ~
55 | root: /cluster/work/cvl/leisun/EFNet_inference/ # set this option ONLY in TEST!!!
56 |
57 | val:
58 | save_img: true
59 | grids: ~
60 | crop_size: ~
61 |   rgb2bgr: false # for my h5 data, it is false
62 |
63 | dist_params:
64 | backend: nccl
65 | port: 29500
66 |
67 |
--------------------------------------------------------------------------------
/basicsr/data/data_sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.utils.data.sampler import Sampler
4 |
5 |
6 | class EnlargedSampler(Sampler):
7 | """Sampler that restricts data loading to a subset of the dataset.
8 |
9 | Modified from torch.utils.data.distributed.DistributedSampler
10 |     Support enlarging the dataset for iteration-based training, which
11 |     saves time otherwise spent restarting the dataloader after each
12 |     epoch.
13 | 
14 |
15 | Args:
16 | dataset (torch.utils.data.Dataset): Dataset used for sampling.
17 | num_replicas (int | None): Number of processes participating in
18 | the training. It is usually the world_size.
19 | rank (int | None): Rank of the current process within num_replicas.
20 | ratio (int): Enlarging ratio. Default: 1.
21 | """
22 |
23 | def __init__(self, dataset, num_replicas, rank, ratio=1):
24 | self.dataset = dataset
25 | self.num_replicas = num_replicas
26 | self.rank = rank
27 | self.epoch = 0
28 | self.num_samples = math.ceil(
29 | len(self.dataset) * ratio / self.num_replicas)
30 | self.total_size = self.num_samples * self.num_replicas
31 |
32 | def __iter__(self):
33 | # deterministically shuffle based on epoch
34 | g = torch.Generator()
35 | g.manual_seed(self.epoch)
36 | indices = torch.randperm(self.total_size, generator=g).tolist()
37 |
38 | dataset_size = len(self.dataset)
39 | indices = [v % dataset_size for v in indices]
40 |
41 | # subsample
42 | indices = indices[self.rank:self.total_size:self.num_replicas]
43 | assert len(indices) == self.num_samples
44 |
45 | return iter(indices)
46 |
47 | def __len__(self):
48 | return self.num_samples
49 |
50 | def set_epoch(self, epoch):
51 | self.epoch = epoch
52 |
--------------------------------------------------------------------------------
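Usage sketch for `EnlargedSampler`: with `ratio=4` on a 10-sample dataset and a single process, one pass over the sampler yields 40 indices (raw indices repeated via the modulo above), so iteration-based training restarts the dataloader a quarter as often:

```python
import torch

from basicsr.data.data_sampler import EnlargedSampler

dataset = torch.utils.data.TensorDataset(torch.arange(10))
sampler = EnlargedSampler(dataset, num_replicas=1, rank=0, ratio=4)
sampler.set_epoch(0)  # seeds the shuffle deterministically
print(len(sampler))  # 40
print(list(sampler)[:8])  # indices all in [0, 10)
```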
/basicsr/test.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | from os import path as osp
4 |
5 | from basicsr.data import create_dataloader, create_dataset
6 | from basicsr.models import create_model
7 | from basicsr.train import parse_options
8 | from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
9 | make_exp_dirs)
10 | from basicsr.utils.options import dict2str
11 |
12 |
13 | def main():
14 |     # parse options, set distributed setting, set random seed
15 | opt = parse_options(is_train=False)
16 |
17 | torch.backends.cudnn.benchmark = True
18 | # torch.backends.cudnn.deterministic = True
19 |
20 | # mkdir and initialize loggers
21 | make_exp_dirs(opt)
22 | log_file = osp.join(opt['path']['log'],
23 | f"test_{opt['name']}_{get_time_str()}.log")
24 | logger = get_root_logger(
25 | logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
26 | logger.info(get_env_info())
27 | logger.info(dict2str(opt))
28 |
29 | # create test dataset and dataloader
30 | test_loaders = []
31 | for phase, dataset_opt in sorted(opt['datasets'].items()):
32 | test_set = create_dataset(dataset_opt)
33 | test_loader = create_dataloader(
34 | test_set,
35 | dataset_opt,
36 | num_gpu=opt['num_gpu'],
37 | dist=opt['dist'],
38 | sampler=None,
39 | seed=opt['manual_seed'])
40 | logger.info(
41 | f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
42 | test_loaders.append(test_loader)
43 |
44 | # create model
45 | model = create_model(opt)
46 |
47 | for test_loader in test_loaders:
48 |         test_set_name = test_loader.dataset.opt['name']
49 | logger.info(f'Testing {test_set_name}...')
50 | rgb2bgr = opt['val'].get('rgb2bgr', True)
51 |         # whether to use uint8 images to compute metrics
52 | use_image = opt['val'].get('use_image', True)
53 | model.validation(
54 | test_loader,
55 | current_iter=opt['name'],
56 | tb_logger=None,
57 | save_img=opt['val']['save_img'],
58 | rgb2bgr=rgb2bgr, use_image=use_image)
59 |
60 |
61 | if __name__ == '__main__':
62 | main()
63 |
--------------------------------------------------------------------------------
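For reference, entry points in this BasicSR-derived codebase are typically launched as `python basicsr/test.py -opt options/test/GoPro/EFNet.yml`; the `-opt` flag is defined by `parse_options` in `basicsr/train.py`, which is not included in this dump, so treat the exact CLI as an assumption.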
/basicsr/utils/download_util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import requests
3 | from tqdm import tqdm
4 |
5 | from .misc import sizeof_fmt
6 |
7 |
8 | def download_file_from_google_drive(file_id, save_path):
9 | """Download files from google drive.
10 |
11 | Ref:
12 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
13 |
14 | Args:
15 | file_id (str): File id.
16 | save_path (str): Save path.
17 | """
18 |
19 | session = requests.Session()
20 | URL = 'https://docs.google.com/uc?export=download'
21 | params = {'id': file_id}
22 |
23 | response = session.get(URL, params=params, stream=True)
24 | token = get_confirm_token(response)
25 | if token:
26 | params['confirm'] = token
27 | response = session.get(URL, params=params, stream=True)
28 |
29 | # get file size
30 | response_file_size = session.get(
31 | URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
32 | if 'Content-Range' in response_file_size.headers:
33 | file_size = int(
34 | response_file_size.headers['Content-Range'].split('/')[1])
35 | else:
36 | file_size = None
37 |
38 | save_response_content(response, save_path, file_size)
39 |
40 |
41 | def get_confirm_token(response):
42 | for key, value in response.cookies.items():
43 | if key.startswith('download_warning'):
44 | return value
45 | return None
46 |
47 |
48 | def save_response_content(response,
49 | destination,
50 | file_size=None,
51 | chunk_size=32768):
52 | if file_size is not None:
53 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
54 |
55 | readable_file_size = sizeof_fmt(file_size)
56 | else:
57 | pbar = None
58 |
59 | with open(destination, 'wb') as f:
60 | downloaded_size = 0
61 | for chunk in response.iter_content(chunk_size):
62 |             downloaded_size += len(chunk)
63 | if pbar is not None:
64 | pbar.update(1)
65 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
66 | f'/ {readable_file_size}')
67 | if chunk: # filter out keep-alive new chunks
68 | f.write(chunk)
69 | if pbar is not None:
70 | pbar.close()
71 |
--------------------------------------------------------------------------------
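Programmatic usage of the downloader (`FILE_ID` is a placeholder, not a real id); this is exactly what `scripts/download_gdrive.py` wires up behind argparse:

```python
from basicsr.utils.download_util import download_file_from_google_drive

download_file_from_google_drive('FILE_ID', './pretrained/net_g_latest.pth')
```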
/basicsr/data/ffhq_dataset.py:
--------------------------------------------------------------------------------
1 | from os import path as osp
2 | from torch.utils import data as data
3 | from torchvision.transforms.functional import normalize
4 |
5 | from basicsr.data.transforms import augment
6 | from basicsr.utils import FileClient, imfrombytes, img2tensor
7 |
8 |
9 | class FFHQDataset(data.Dataset):
10 | """FFHQ dataset for StyleGAN.
11 |
12 | Args:
13 | opt (dict): Config for train datasets. It contains the following keys:
14 | dataroot_gt (str): Data root path for gt.
15 | io_backend (dict): IO backend type and other kwarg.
16 | mean (list | tuple): Image mean.
17 | std (list | tuple): Image std.
18 | use_hflip (bool): Whether to horizontally flip.
19 |
20 | """
21 |
22 | def __init__(self, opt):
23 | super(FFHQDataset, self).__init__()
24 | self.opt = opt
25 | # file client (io backend)
26 | self.file_client = None
27 | self.io_backend_opt = opt['io_backend']
28 |
29 | self.gt_folder = opt['dataroot_gt']
30 | self.mean = opt['mean']
31 | self.std = opt['std']
32 |
33 | if self.io_backend_opt['type'] == 'lmdb':
34 | self.io_backend_opt['db_paths'] = self.gt_folder
35 | if not self.gt_folder.endswith('.lmdb'):
36 | raise ValueError("'dataroot_gt' should end with '.lmdb', "
37 | f'but received {self.gt_folder}')
38 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
39 | self.paths = [line.split('.')[0] for line in fin]
40 | else:
41 | # FFHQ has 70000 images in total
42 | self.paths = [
43 | osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)
44 | ]
45 |
46 | def __getitem__(self, index):
47 | if self.file_client is None:
48 | self.file_client = FileClient(
49 | self.io_backend_opt.pop('type'), **self.io_backend_opt)
50 |
51 | # load gt image
52 | gt_path = self.paths[index]
53 | img_bytes = self.file_client.get(gt_path)
54 | img_gt = imfrombytes(img_bytes, float32=True)
55 |
56 | # random horizontal flip
57 | img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
58 | # BGR to RGB, HWC to CHW, numpy to tensor
59 | img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
60 | # normalize
61 | normalize(img_gt, self.mean, self.std, inplace=True)
62 | return {'gt': img_gt, 'gt_path': gt_path}
63 |
64 | def __len__(self):
65 | return len(self.paths)
66 |
--------------------------------------------------------------------------------
/scripts/publish_models.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import subprocess
3 | import torch
4 | from os import path as osp
5 | from torch.serialization import _is_zipfile, _open_file_like
6 |
7 |
8 | def update_sha(paths):
9 | print('# Update sha ...')
10 | for idx, path in enumerate(paths):
11 | print(f'{idx+1:03d}: Processing {path}')
12 | net = torch.load(path, map_location=torch.device('cpu'))
13 | basename = osp.basename(path)
14 | if 'params' not in net and 'params_ema' not in net:
15 | raise ValueError(f'Please check! Model {basename} does not '
16 | f"have 'params'/'params_ema' key.")
17 | else:
18 | if '-' in basename:
19 | # check whether the sha is the latest
20 | old_sha = basename.split('-')[1].split('.')[0]
21 | new_sha = subprocess.check_output(['sha256sum',
22 | path]).decode()[:8]
23 | if old_sha != new_sha:
24 | final_file = path.split('-')[0] + f'-{new_sha}.pth'
25 | print(f'\tSave from {path} to {final_file}')
26 | subprocess.Popen(['mv', path, final_file])
27 | else:
28 | sha = subprocess.check_output(['sha256sum', path]).decode()[:8]
29 | final_file = path.split('.pth')[0] + f'-{sha}.pth'
30 | print(f'\tSave from {path} to {final_file}')
31 | subprocess.Popen(['mv', path, final_file])
32 |
33 |
34 | def convert_to_backward_compatible_models(paths):
35 | """Convert to backward compatible pth files.
36 |
37 |     PyTorch 1.6 uses an updated version of torch.save. In order to be compatible
38 | with previous PyTorch version, save it with
39 | _use_new_zipfile_serialization=False.
40 | """
41 | print('# Convert to backward compatible pth files ...')
42 | for idx, path in enumerate(paths):
43 | print(f'{idx+1:03d}: Processing {path}')
44 | flag_need_conversion = False
45 | with _open_file_like(path, 'rb') as opened_file:
46 | if _is_zipfile(opened_file):
47 | flag_need_conversion = True
48 |
49 | if flag_need_conversion:
50 | net = torch.load(path, map_location=torch.device('cpu'))
51 | print('\tConverting to compatible pth file...')
52 | torch.save(net, path, _use_new_zipfile_serialization=False)
53 |
54 |
55 | if __name__ == '__main__':
56 | paths = glob.glob('experiments/pretrained_models/*.pth') + glob.glob(
57 | 'experiments/pretrained_models/**/*.pth')
58 | convert_to_backward_compatible_models(paths)
59 | update_sha(paths)
60 |
--------------------------------------------------------------------------------
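A portable alternative to shelling out to `sha256sum` (which the script above requires on PATH): the same 8-character prefix can be computed with `hashlib`. A sketch, not the repo's code:

```python
import hashlib


def sha256_prefix(path, length=8):
    """Stream the file in 1 MiB blocks to avoid loading it whole."""
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for block in iter(lambda: f.read(1 << 20), b''):
            h.update(block)
    return h.hexdigest()[:length]
```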
/basicsr/data/single_image_dataset.py:
--------------------------------------------------------------------------------
1 | from os import path as osp
2 | from torch.utils import data as data
3 | from torchvision.transforms.functional import normalize
4 |
5 | from basicsr.data.data_util import paths_from_lmdb
6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir
7 |
8 |
9 | class SingleImageDataset(data.Dataset):
10 | """Read only lq images in the test phase.
11 |
12 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).
13 |
14 | There are two modes:
15 | 1. 'meta_info_file': Use meta information file to generate paths.
16 | 2. 'folder': Scan folders to generate paths.
17 |
18 | Args:
19 | opt (dict): Config for train datasets. It contains the following keys:
20 | dataroot_lq (str): Data root path for lq.
21 | meta_info_file (str): Path for meta information file.
22 | io_backend (dict): IO backend type and other kwarg.
23 | """
24 |
25 | def __init__(self, opt):
26 | super(SingleImageDataset, self).__init__()
27 | self.opt = opt
28 | # file client (io backend)
29 | self.file_client = None
30 | self.io_backend_opt = opt['io_backend']
31 | self.mean = opt['mean'] if 'mean' in opt else None
32 | self.std = opt['std'] if 'std' in opt else None
33 | self.lq_folder = opt['dataroot_lq']
34 |
35 | if self.io_backend_opt['type'] == 'lmdb':
36 | self.io_backend_opt['db_paths'] = [self.lq_folder]
37 | self.io_backend_opt['client_keys'] = ['lq']
38 | self.paths = paths_from_lmdb(self.lq_folder)
39 | elif 'meta_info_file' in self.opt:
40 | with open(self.opt['meta_info_file'], 'r') as fin:
41 | self.paths = [
42 | osp.join(self.lq_folder,
43 | line.split(' ')[0]) for line in fin
44 | ]
45 | else:
46 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))
47 |
48 | def __getitem__(self, index):
49 | if self.file_client is None:
50 | self.file_client = FileClient(
51 | self.io_backend_opt.pop('type'), **self.io_backend_opt)
52 |
53 | # load lq image
54 | lq_path = self.paths[index]
55 | img_bytes = self.file_client.get(lq_path, 'lq')
56 | img_lq = imfrombytes(img_bytes, float32=True)
57 |
58 | # TODO: color space transform
59 | # BGR to RGB, HWC to CHW, numpy to tensor
60 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
61 | # normalize
62 | if self.mean is not None or self.std is not None:
63 | normalize(img_lq, self.mean, self.std, inplace=True)
64 | return {'lq': img_lq, 'lq_path': lq_path}
65 |
66 | def __len__(self):
67 | return len(self.paths)
68 |
--------------------------------------------------------------------------------
/options/train/HighREV/EFNet_HighREV_Deblur.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: EFNet_highrev_single_deblur # add debug for quick debug
3 | model_type: ImageEventRestorationModel
4 | scale: 1
5 | num_gpu: 1 #4
6 | manual_seed: 10
7 |
8 | datasets:
9 | train:
10 | name: highrev-train
11 | type: NpzPngSingleDeblurDataset
12 |
13 | # dataroot: /work/lei_sun/HighREV/train
14 | dataroot: /PATH/TO/DATASET/HighREV/train
15 |
16 | voxel_bins: 6
17 | gt_size: 256
18 | # keep true if use events
19 | norm_voxel: true
20 | use_hflip: true
21 | use_rot: true
22 |
23 | filename_tmpl: '{}'
24 | io_backend:
25 | type: disk
26 |
27 | # data loader settings
28 | use_shuffle: true
29 | num_worker_per_gpu: 3
30 | batch_size_per_gpu: 4 # 4 for 2080, 8 for titan
31 |     dataset_enlarge_ratio: 4 # accelerate; typically set equal to num_gpu
32 | prefetch_mode: cpu
33 | num_prefetch_queue: 2
34 |
35 | val:
36 | name: highrev-val
37 | type: NpzPngSingleDeblurDataset
38 | voxel_bins: 6
39 | # dataroot: /work/lei_sun/HighREV/val
40 | dataroot: /PATH/TO/DATASET/HighREV/val
41 | gt_size: ~
42 | norm_voxel: true
43 |
44 | io_backend:
45 | type: disk
46 |
47 | use_hflip: false
48 | use_rot: false
49 |
50 | dataset_name: HighREV
51 |
52 | # network structures
53 | network_g:
54 | type: EFNet
55 | wf: 64
56 | fuse_before_downsample: true
57 |
58 | # path
59 | path:
60 | pretrain_network_g: ~
61 | strict_load_g: true
62 | resume_state: ~
63 |   training_states: ~ # save current training model states, for resume
64 |
65 | # training settings
66 | train:
67 | optim_g:
68 | type: AdamW
69 | lr: !!float 2e-4
70 | weight_decay: !!float 1e-4
71 | betas: [0.9, 0.99]
72 |
73 | scheduler:
74 | type: TrueCosineAnnealingLR
75 | T_max: 200000
76 | eta_min: !!float 1e-7
77 |
78 | total_iter: 200000
79 | warmup_iter: -1 # no warm up
80 |
81 | # losses
82 | pixel_opt:
83 | type: PSNRLoss
84 | loss_weight: 0.5
85 | reduction: mean
86 |
87 | # validation settings
88 | val:
89 | val_freq: !!float 5e4 # 2e4
90 | save_img: false
91 | grids: ~
92 |   crop_size: ~ # use it if the GPU memory is not enough for whole-image inference
93 | max_minibatch: 8
94 |
95 | metrics:
96 | psnr:
97 | type: calculate_psnr
98 | crop_border: 0
99 | test_y_channel: false
100 |
101 | ssim:
102 | type: calculate_ssim
103 | crop_border: 0
104 | test_y_channel: false
105 |
106 | # logging settings
107 | logger:
108 | print_freq: 200
109 | save_checkpoint_freq: !!float 2e4
110 | use_tb_logger: true
111 | wandb:
112 | project: your_project_name
113 | resume_id: x
114 |
115 | # dist training settings
116 | dist_params:
117 | backend: nccl
118 | port: 29500
119 |
--------------------------------------------------------------------------------
/basicsr/utils/dist_util.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2 | import functools
3 | import os
4 | import subprocess
5 | import torch
6 | import torch.distributed as dist
7 | import torch.multiprocessing as mp
8 |
9 |
10 | def init_dist(launcher, backend='nccl', **kwargs):
11 | if mp.get_start_method(allow_none=True) is None:
12 | mp.set_start_method('spawn')
13 | if launcher == 'pytorch':
14 | _init_dist_pytorch(backend, **kwargs)
15 | elif launcher == 'slurm':
16 | _init_dist_slurm(backend, **kwargs)
17 | else:
18 | raise ValueError(f'Invalid launcher type: {launcher}')
19 |
20 |
21 | def _init_dist_pytorch(backend, **kwargs):
22 | rank = int(os.environ['RANK'])
23 | num_gpus = torch.cuda.device_count()
24 | torch.cuda.set_device(rank % num_gpus)
25 | dist.init_process_group(backend=backend, **kwargs)
26 |
27 |
28 | def _init_dist_slurm(backend, port=None):
29 | """Initialize slurm distributed training environment.
30 |
31 | If argument ``port`` is not specified, then the master port will be system
32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
33 | environment variable, then a default port ``29500`` will be used.
34 |
35 | Args:
36 | backend (str): Backend of torch.distributed.
37 | port (int, optional): Master port. Defaults to None.
38 | """
39 | proc_id = int(os.environ['SLURM_PROCID'])
40 | ntasks = int(os.environ['SLURM_NTASKS'])
41 | node_list = os.environ['SLURM_NODELIST']
42 | num_gpus = torch.cuda.device_count()
43 | torch.cuda.set_device(proc_id % num_gpus)
44 | addr = subprocess.getoutput(
45 | f'scontrol show hostname {node_list} | head -n1')
46 | # specify master port
47 | if port is not None:
48 | os.environ['MASTER_PORT'] = str(port)
49 | elif 'MASTER_PORT' in os.environ:
50 | pass # use MASTER_PORT in the environment variable
51 | else:
52 | # 29500 is torch.distributed default port
53 | os.environ['MASTER_PORT'] = '29500'
54 | os.environ['MASTER_ADDR'] = addr
55 | os.environ['WORLD_SIZE'] = str(ntasks)
56 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
57 | os.environ['RANK'] = str(proc_id)
58 | dist.init_process_group(backend=backend)
59 |
60 |
61 | def get_dist_info():
62 | if dist.is_available():
63 | initialized = dist.is_initialized()
64 | else:
65 | initialized = False
66 | if initialized:
67 | rank = dist.get_rank()
68 | world_size = dist.get_world_size()
69 | else:
70 | rank = 0
71 | world_size = 1
72 | return rank, world_size
73 |
74 |
75 | def master_only(func):
76 |
77 | @functools.wraps(func)
78 | def wrapper(*args, **kwargs):
79 | rank, _ = get_dist_info()
80 | if rank == 0:
81 | return func(*args, **kwargs)
82 |
83 | return wrapper
84 |
--------------------------------------------------------------------------------
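A minimal usage sketch for the helpers above, assuming a typical BasicSR-style entry point (the CLI flag name is illustrative):

import argparse
from basicsr.utils.dist_util import get_dist_info, init_dist, master_only

@master_only
def log_once(msg):
    # The decorated function runs only on rank 0; other ranks return None.
    print(msg)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none')
    args = parser.parse_args()
    if args.launcher != 'none':
        init_dist(args.launcher)  # sets up torch.distributed (NCCL backend by default)
    rank, world_size = get_dist_info()
    log_once(f'{world_size} process(es), reporting from rank {rank}')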
/options/train/GoPro/EFNet.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: EFNet_experiment # add 'debug' for quick debug
3 | model_type: ImageEventRestorationModel
4 | scale: 1
5 | num_gpu: 1 #4
6 | manual_seed: 10
7 |
8 | datasets:
9 | train:
10 | name: gopro-h5-train
11 | type: H5ImageDataset
12 |
13 | # dataroot: ./datasets/GoPro_scer/train
14 | dataroot: /cluster/work/cvl/leisun/GOPRO_fullsize_h5_bin6_ver3/train
15 |
16 |     # keep true when using events
17 | norm_voxel: true
18 | return_voxel: true
19 |
20 | return_mask: true
21 | use_mask: true
22 |
23 | filename_tmpl: '{}'
24 | io_backend:
25 | type: h5
26 |
27 | crop_size: 256
28 | use_flip: true
29 | use_rot: true
30 |
31 | # data loader settings
32 | use_shuffle: true
33 | num_worker_per_gpu: 3
34 | batch_size_per_gpu: 4 # 4 for 2080, 8 for titan
35 |     dataset_enlarge_ratio: 4 # accelerate; set equal to num_gpu
36 | prefetch_mode: cpu
37 | num_prefetch_queue: 2
38 |
39 | val:
40 | name: gopro-h5-test
41 | type: H5ImageDataset
42 |
43 | # dataroot: ./datasets/test
44 | dataroot: /cluster/work/cvl/leisun/GOPRO_fullsize_h5_bin6_ver3/test # for debug
45 |
46 | norm_voxel: true
47 | return_voxel: true
48 | return_mask: true
49 | use_mask: true
50 |
51 | io_backend:
52 | type: h5
53 |
54 | crop_size: ~
55 | use_flip: false
56 | use_rot: false
57 |
58 | dataset_name: GoPro
59 |
60 | # network structures
61 | network_g:
62 | type: EFNet
63 | wf: 64
64 | fuse_before_downsample: true
65 |
66 | # path
67 | path:
68 | pretrain_network_g: ~
69 | strict_load_g: true
70 | resume_state: ~
71 |   training_states: ~ # save current training model states, for resume
72 |
73 | # training settings
74 | train:
75 | optim_g:
76 | type: AdamW
77 | lr: !!float 2e-4
78 | weight_decay: !!float 1e-4
79 | betas: [0.9, 0.99]
80 |
81 | scheduler:
82 | type: TrueCosineAnnealingLR
83 | T_max: 200000
84 | eta_min: !!float 1e-7
85 |
86 | total_iter: 200000
87 | warmup_iter: -1 # no warm up
88 |
89 | # losses
90 | pixel_opt:
91 | type: PSNRLoss
92 | loss_weight: 0.5
93 | reduction: mean
94 |
95 | # validation settings
96 | val:
97 | val_freq: !!float 5e4 # 2e4
98 | save_img: false
99 | grids: ~
100 | crop_size: ~
101 | max_minibatch: 8
102 |
103 | metrics:
104 | psnr:
105 | type: calculate_psnr
106 | crop_border: 0
107 | test_y_channel: false
108 |
109 | ssim:
110 | type: calculate_ssim
111 | crop_border: 0
112 | test_y_channel: false
113 |
114 | # logging settings
115 | logger:
116 | print_freq: 200
117 | save_checkpoint_freq: !!float 2e4
118 | use_tb_logger: true
119 | wandb:
120 | project: your_project_name
121 | resume_id: x
122 |
123 | # dist training settings
124 | dist_params:
125 | backend: nccl
126 | port: 29500
127 |
--------------------------------------------------------------------------------
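The `optim_g` and `scheduler` blocks describe a standard AdamW plus cosine-annealing setup. A sketch of the resulting schedule using PyTorch's built-in `CosineAnnealingLR`, on the assumption that the repo's `TrueCosineAnnealingLR` follows the same `eta_min + 0.5 * (lr - eta_min) * (1 + cos(pi * t / T_max))` curve:

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=2e-4, weight_decay=1e-4, betas=(0.9, 0.99))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200000, eta_min=1e-7)
for step in range(3):
    optimizer.step()    # lr starts at 2e-4 ...
    scheduler.step()    # ... and anneals toward eta_min over 200k iterations
    print(step, scheduler.get_last_lr())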
/options/train/REBlur/Finetune_EFNet.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: EFNet_Finetune # add 'debug' for quick debug
3 | model_type: ImageEventRestorationModel
4 | scale: 1
5 | num_gpu: 1 # 4
6 | manual_seed: 10
7 |
8 | datasets:
9 | train:
10 | name: REBlur-h5-train
11 | type: H5ImageDataset
12 |
13 | # dataroot: ./datasets/train
14 | dataroot: /cluster/work/cvl/leisun/REBlur/train # REBlur
15 |
16 |     # keep true when using events
17 | norm_voxel: true
18 | return_voxel: true
19 |
20 | return_mask: true
21 | use_mask: true
22 |
23 | filename_tmpl: '{}'
24 | io_backend:
25 | type: h5
26 |
27 | crop_size: 256
28 | use_flip: true
29 | use_rot: true
30 |
31 | # data loader settings
32 | use_shuffle: true
33 | num_worker_per_gpu: 3
34 | batch_size_per_gpu: 4
35 | dataset_enlarge_ratio: 1
36 | prefetch_mode: cpu
37 | num_prefetch_queue: 2
38 |
39 | val:
40 | name: REBlur-h5-test
41 | type: H5ImageDataset
42 |
43 | # dataroot: ./datasets/test
44 | dataroot: /cluster/work/cvl/leisun/REBlur/test
45 |
46 | # keep true
47 |     norm_voxel: true
48 | return_voxel: true
49 |
50 | return_mask: true # dataloader yields mask
51 | use_mask: true
52 |
53 | io_backend:
54 | type: h5
55 |
56 | crop_size: ~
57 | use_flip: false
58 | use_rot: false
59 |
60 | dataset_name: REBlur
61 |
62 | # network structures
63 | network_g:
64 | type: EFNet
65 |   wf: 64
66 | fuse_before_downsample: true
67 |
68 |
69 | # path
70 | path:
71 | pretrain_network_g: ~ # pretrain path
72 | strict_load_g: true
73 | resume_state: ~
74 |   training_states: ~ # save current training model states, for resume
75 |
76 |
77 | # training settings
78 | train:
79 | optim_g:
80 | type: AdamW
81 | lr: !!float 2e-5
82 | weight_decay: !!float 1e-4
83 | betas: [0.9, 0.99]
84 |
85 |
86 | scheduler:
87 | type: TrueCosineAnnealingLR
88 | T_max: 600 # finetune
89 | eta_min: !!float 1e-7
90 |
91 | total_iter: 600 # finetune
92 | warmup_iter: -1 # no warm up
93 |
94 | # losses
95 | pixel_opt:
96 | type: PSNRLoss
97 | loss_weight: 0.5
98 | reduction: mean
99 |
100 | # validation settings
101 | val:
102 | val_freq: !!float 5e4 # 2e4
103 | save_img: false
104 | grids: ~
105 |     crop_size: ~
106 | max_minibatch: 8
107 |
108 | metrics:
109 | psnr:
110 | type: calculate_psnr
111 | crop_border: 0
112 | test_y_channel: false
113 |
114 | # add ssim
115 | ssim:
116 | type: calculate_ssim
117 | crop_border: 0
118 | test_y_channel: false
119 |
120 | # logging settings
121 | logger:
122 | print_freq: 50
123 | save_checkpoint_freq: !!float 100
124 | use_tb_logger: true
125 | wandb:
126 | project: ~
127 | resume_id: ~
128 |
129 | # dist training settings
130 | dist_params:
131 | backend: nccl
132 | port: 29500
133 |
134 |
--------------------------------------------------------------------------------
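Before this finetune starts, `pretrain_network_g` should point at a GoPro-trained checkpoint. A sketch of what that loading amounts to (the path is illustrative, and the 'params' key is an assumption based on the usual BasicSR checkpoint layout):

import torch

state = torch.load('experiments/EFNet_experiment/models/net_g_latest.pth', map_location='cpu')
weights = state.get('params', state) if isinstance(state, dict) else state  # assumed 'params' key
# net_g.load_state_dict(weights, strict=True)  # matches strict_load_g: true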
/options/train/HighREV/EFNet_HighREV_Deblur_voxel.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: EFNet_highrev_single_deblur_voxel_debug # add debug for quick debug
3 | model_type: ImageEventRestorationModel
4 | scale: 1
5 | num_gpu: 1 #4
6 | manual_seed: 10
7 |
8 | datasets:
9 | train:
10 | name: highrev-train
11 | type: VoxelnpzPngSingleDeblurDataset
12 |
13 | # dataroot: /work/lei_sun/HighREV/train
14 | # dataroot_voxel: /work/lei_sun/HighREV_voxel/train/voxel
15 | dataroot: /PATH/TO/DATASET/HighREV/train
16 | dataroot_voxel: /PATH/TO/DATASET/HighREV_voxel/train/voxel
17 |
18 | gt_size: 256
19 |     # keep true when using events
20 | norm_voxel: true
21 | use_hflip: true
22 | use_rot: true
23 |
24 | filename_tmpl: '{}'
25 | io_backend:
26 | type: disk
27 |
28 | # data loader settings
29 | use_shuffle: true
30 | num_worker_per_gpu: 3
31 | batch_size_per_gpu: 4 # 4 for 2080, 8 for titan
32 |     dataset_enlarge_ratio: 4 # accelerate; set equal to num_gpu
33 | prefetch_mode: cpu
34 | num_prefetch_queue: 2
35 |
36 | val:
37 | name: highrev-val
38 | type: VoxelnpzPngSingleDeblurDataset
39 |
40 | # dataroot: /work/lei_sun/HighREV/val
41 | # dataroot_voxel: /work/lei_sun/HighREV_voxel/val/voxel
42 |     dataroot: /PATH/TO/DATASET/HighREV/val
43 |     dataroot_voxel: /PATH/TO/DATASET/HighREV_voxel/val/voxel
44 |
45 |
46 | gt_size: ~
47 | norm_voxel: true
48 |
49 | io_backend:
50 | type: disk
51 |
52 | use_hflip: false
53 | use_rot: false
54 |
55 | dataset_name: HighREV
56 |
57 | # network structures
58 | network_g:
59 | type: EFNet
60 | wf: 64
61 | fuse_before_downsample: true
62 |
63 | # path
64 | path:
65 | pretrain_network_g: ~
66 | strict_load_g: true
67 | resume_state: ~
68 |   training_states: ~ # save current training model states, for resume
69 |
70 | # training settings
71 | train:
72 | optim_g:
73 | type: AdamW
74 | lr: !!float 2e-4
75 | weight_decay: !!float 1e-4
76 | betas: [0.9, 0.99]
77 |
78 | scheduler:
79 | type: TrueCosineAnnealingLR
80 | T_max: 200000
81 | eta_min: !!float 1e-7
82 |
83 | total_iter: 200000
84 | warmup_iter: -1 # no warm up
85 |
86 | # losses
87 | pixel_opt:
88 | type: PSNRLoss
89 | loss_weight: 0.5
90 | reduction: mean
91 |
92 | # validation settings
93 | val:
94 | val_freq: !!float 5e4 # 2e4
95 | save_img: false
96 | grids: ~
97 |     crop_size: ~ # use it if the GPU memory is not enough for whole-image inference
98 | max_minibatch: 8
99 |
100 | metrics:
101 | psnr:
102 | type: calculate_psnr
103 | crop_border: 0
104 | test_y_channel: false
105 |
106 | ssim:
107 | type: calculate_ssim
108 | crop_border: 0
109 | test_y_channel: false
110 |
111 | # logging settings
112 | logger:
113 | print_freq: 200
114 | save_checkpoint_freq: !!float 2e4
115 | use_tb_logger: true
116 | wandb:
117 | project: your_project_name
118 | resume_id: x
119 |
120 | # dist training settings
121 | dist_params:
122 | backend: nccl
123 | port: 29500
124 |
--------------------------------------------------------------------------------
/basicsr/models/losses/loss_util.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from torch.nn import functional as F
3 |
4 |
5 | def reduce_loss(loss, reduction):
6 | """Reduce loss as specified.
7 |
8 | Args:
9 | loss (Tensor): Elementwise loss tensor.
10 | reduction (str): Options are 'none', 'mean' and 'sum'.
11 |
12 | Returns:
13 | Tensor: Reduced loss tensor.
14 | """
15 | reduction_enum = F._Reduction.get_enum(reduction)
16 | # none: 0, elementwise_mean:1, sum: 2
17 | if reduction_enum == 0:
18 | return loss
19 | elif reduction_enum == 1:
20 | return loss.mean()
21 | else:
22 | return loss.sum()
23 |
24 |
25 | def weight_reduce_loss(loss, weight=None, reduction='mean'):
26 | """Apply element-wise weight and reduce loss.
27 |
28 | Args:
29 | loss (Tensor): Element-wise loss.
30 | weight (Tensor): Element-wise weights. Default: None.
31 | reduction (str): Same as built-in losses of PyTorch. Options are
32 | 'none', 'mean' and 'sum'. Default: 'mean'.
33 |
34 | Returns:
35 | Tensor: Loss values.
36 | """
37 | # if weight is specified, apply element-wise weight
38 | if weight is not None:
39 | assert weight.dim() == loss.dim()
40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
41 | loss = loss * weight
42 |
43 | # if weight is not specified or reduction is sum, just reduce the loss
44 | if weight is None or reduction == 'sum':
45 | loss = reduce_loss(loss, reduction)
46 | # if reduction is mean, then compute mean over weight region
47 | elif reduction == 'mean':
48 | if weight.size(1) > 1:
49 | weight = weight.sum()
50 | else:
51 | weight = weight.sum() * loss.size(1)
52 | loss = loss.sum() / weight
53 |
54 | return loss
55 |
56 |
57 | def weighted_loss(loss_func):
58 | """Create a weighted version of a given loss function.
59 |
60 | To use this decorator, the loss function must have the signature like
61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute
62 | element-wise loss without any reduction. This decorator will add weight
63 | and reduction arguments to the function. The decorated function will have
64 | the signature like `loss_func(pred, target, weight=None, reduction='mean',
65 | **kwargs)`.
66 |
67 | :Example:
68 |
69 | >>> import torch
70 | >>> @weighted_loss
71 | >>> def l1_loss(pred, target):
72 | >>> return (pred - target).abs()
73 |
74 | >>> pred = torch.Tensor([0, 2, 3])
75 | >>> target = torch.Tensor([1, 1, 1])
76 | >>> weight = torch.Tensor([1, 0, 1])
77 |
78 | >>> l1_loss(pred, target)
79 | tensor(1.3333)
80 | >>> l1_loss(pred, target, weight)
81 | tensor(1.5000)
82 | >>> l1_loss(pred, target, reduction='none')
83 | tensor([1., 1., 2.])
84 | >>> l1_loss(pred, target, weight, reduction='sum')
85 | tensor(3.)
86 | """
87 |
88 | @functools.wraps(loss_func)
89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
90 | # get element-wise loss
91 | loss = loss_func(pred, target, **kwargs)
92 | loss = weight_reduce_loss(loss, weight, reduction)
93 | return loss
94 |
95 | return wrapper
96 |
--------------------------------------------------------------------------------
/basicsr/data/prefetch_dataloader.py:
--------------------------------------------------------------------------------
1 | import queue as Queue
2 | import threading
3 | import torch
4 | from torch.utils.data import DataLoader
5 |
6 |
7 | class PrefetchGenerator(threading.Thread):
8 | """A general prefetch generator.
9 |
10 | Ref:
11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
12 |
13 | Args:
14 | generator: Python generator.
15 | num_prefetch_queue (int): Number of prefetch queue.
16 | """
17 |
18 | def __init__(self, generator, num_prefetch_queue):
19 | threading.Thread.__init__(self)
20 | self.queue = Queue.Queue(num_prefetch_queue)
21 | self.generator = generator
22 | self.daemon = True
23 | self.start()
24 |
25 | def run(self):
26 | for item in self.generator:
27 | self.queue.put(item)
28 | self.queue.put(None)
29 |
30 | def __next__(self):
31 | next_item = self.queue.get()
32 | if next_item is None:
33 | raise StopIteration
34 | return next_item
35 |
36 | def __iter__(self):
37 | return self
38 |
39 |
40 | class PrefetchDataLoader(DataLoader):
41 | """Prefetch version of dataloader.
42 |
43 | Ref:
44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
45 |
46 | TODO:
47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in
48 | ddp.
49 |
50 | Args:
51 | num_prefetch_queue (int): Number of prefetch queue.
52 | kwargs (dict): Other arguments for dataloader.
53 | """
54 |
55 | def __init__(self, num_prefetch_queue, **kwargs):
56 | self.num_prefetch_queue = num_prefetch_queue
57 | super(PrefetchDataLoader, self).__init__(**kwargs)
58 |
59 | def __iter__(self):
60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
61 |
62 |
63 | class CPUPrefetcher():
64 | """CPU prefetcher.
65 |
66 | Args:
67 | loader: Dataloader.
68 | """
69 |
70 | def __init__(self, loader):
71 | self.ori_loader = loader
72 | self.loader = iter(loader)
73 |
74 | def next(self):
75 | try:
76 | return next(self.loader)
77 | except StopIteration:
78 | return None
79 |
80 | def reset(self):
81 | self.loader = iter(self.ori_loader)
82 |
83 |
84 | class CUDAPrefetcher():
85 | """CUDA prefetcher.
86 |
87 | Ref:
88 | https://github.com/NVIDIA/apex/issues/304#
89 |
90 |     It may consume more GPU memory.
91 |
92 | Args:
93 | loader: Dataloader.
94 | opt (dict): Options.
95 | """
96 |
97 | def __init__(self, loader, opt):
98 | self.ori_loader = loader
99 | self.loader = iter(loader)
100 | self.opt = opt
101 | self.stream = torch.cuda.Stream()
102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
103 | self.preload()
104 |
105 | def preload(self):
106 | try:
107 | self.batch = next(self.loader) # self.batch is a dict
108 | except StopIteration:
109 | self.batch = None
110 | return None
111 | # put tensors to gpu
112 | with torch.cuda.stream(self.stream):
113 | for k, v in self.batch.items():
114 | if torch.is_tensor(v):
115 | self.batch[k] = self.batch[k].to(
116 | device=self.device, non_blocking=True)
117 |
118 | def next(self):
119 | torch.cuda.current_stream().wait_stream(self.stream)
120 | batch = self.batch
121 | self.preload()
122 | return batch
123 |
124 | def reset(self):
125 | self.loader = iter(self.ori_loader)
126 | self.preload()
127 |
--------------------------------------------------------------------------------
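A minimal sketch of the CPU prefetcher loop as used at train time (the dataset here is a stand-in):

import torch
from torch.utils.data import DataLoader, TensorDataset
from basicsr.data.prefetch_dataloader import CPUPrefetcher

loader = DataLoader(TensorDataset(torch.arange(8.0)), batch_size=2)
prefetcher = CPUPrefetcher(loader)
batch = prefetcher.next()
while batch is not None:
    # ... one training step on `batch` ...
    batch = prefetcher.next()
prefetcher.reset()  # rewind for the next epoch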
/basicsr/metrics/fid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from scipy import linalg
5 | from tqdm import tqdm
6 |
7 | from basicsr.models.archs.inception import InceptionV3
8 |
9 |
10 | def load_patched_inception_v3(device='cuda',
11 | resize_input=True,
12 | normalize_input=False):
13 | # we may not resize the input, but in [rosinality/stylegan2-pytorch] it
14 | # does resize the input.
15 | inception = InceptionV3([3],
16 | resize_input=resize_input,
17 | normalize_input=normalize_input)
18 | inception = nn.DataParallel(inception).eval().to(device)
19 | return inception
20 |
21 |
22 | @torch.no_grad()
23 | def extract_inception_features(data_generator,
24 | inception,
25 | len_generator=None,
26 | device='cuda'):
27 | """Extract inception features.
28 |
29 | Args:
30 | data_generator (generator): A data generator.
31 | inception (nn.Module): Inception model.
32 | len_generator (int): Length of the data_generator to show the
33 | progressbar. Default: None.
34 | device (str): Device. Default: cuda.
35 |
36 | Returns:
37 | Tensor: Extracted features.
38 | """
39 | if len_generator is not None:
40 | pbar = tqdm(total=len_generator, unit='batch', desc='Extract')
41 | else:
42 | pbar = None
43 | features = []
44 |
45 | for data in data_generator:
46 | if pbar:
47 | pbar.update(1)
48 | data = data.to(device)
49 | feature = inception(data)[0].view(data.shape[0], -1)
50 | features.append(feature.to('cpu'))
51 | if pbar:
52 | pbar.close()
53 | features = torch.cat(features, 0)
54 | return features
55 |
56 |
57 | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
58 | """Numpy implementation of the Frechet Distance.
59 |
60 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
61 | and X_2 ~ N(mu_2, C_2) is
62 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
63 | Stable version by Dougal J. Sutherland.
64 |
65 | Args:
66 | mu1 (np.array): The sample mean over activations.
67 | sigma1 (np.array): The covariance matrix over activations for
68 | generated samples.
69 | mu2 (np.array): The sample mean over activations, precalculated on an
70 | representative data set.
71 | sigma2 (np.array): The covariance matrix over activations,
72 | precalculated on an representative data set.
73 |
74 | Returns:
75 | float: The Frechet Distance.
76 | """
77 | assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
78 | assert sigma1.shape == sigma2.shape, (
79 | 'Two covariances have different dimensions')
80 |
81 | cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
82 |
83 | # Product might be almost singular
84 | if not np.isfinite(cov_sqrt).all():
85 |         print(f'Product of cov matrices is singular. Adding {eps} to diagonal '
86 | 'of cov estimates')
87 | offset = np.eye(sigma1.shape[0]) * eps
88 | cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
89 |
90 | # Numerical error might give slight imaginary component
91 | if np.iscomplexobj(cov_sqrt):
92 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
93 | m = np.max(np.abs(cov_sqrt.imag))
94 | raise ValueError(f'Imaginary component {m}')
95 | cov_sqrt = cov_sqrt.real
96 |
97 | mean_diff = mu1 - mu2
98 | mean_norm = mean_diff @ mean_diff
99 | trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
100 | fid = mean_norm + trace
101 |
102 | return fid
103 |
--------------------------------------------------------------------------------
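calculate_fid() expects Gaussian statistics, not raw features. A sketch of computing them from feature matrices (random stand-ins here; in practice they come from extract_inception_features):

import numpy as np
from basicsr.metrics.fid import calculate_fid

feats_a = np.random.randn(500, 64)
feats_b = np.random.randn(500, 64) + 0.1
mu1, sigma1 = feats_a.mean(axis=0), np.cov(feats_a, rowvar=False)
mu2, sigma2 = feats_b.mean(axis=0), np.cov(feats_b, rowvar=False)
print('FID:', calculate_fid(mu1, sigma1, mu2, sigma2))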
/basicsr/data/event_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def events_to_voxel_grid(events, num_bins, width, height, return_format='CHW'):
6 | """
7 | Build a voxel grid with bilinear interpolation in the time domain from a set of events.
8 |
9 | :param events: a [N x 4] NumPy array containing one event per row in the form: [timestamp, x, y, polarity]
10 | :param num_bins: number of bins in the temporal axis of the voxel grid
11 | :param width, height: dimensions of the voxel grid
12 | :param return_format: 'CHW' or 'HWC'
13 | """
14 |
15 | assert(events.shape[1] == 4)
16 | assert(num_bins > 0)
17 | assert(width > 0)
18 | assert(height > 0)
19 |
20 | voxel_grid = np.zeros((num_bins, height, width), np.float32).ravel()
21 | # print('DEBUG: voxel.shape:{}'.format(voxel_grid.shape))
22 |
23 | # normalize the event timestamps so that they lie between 0 and num_bins
24 | last_stamp = events[-1, 0]
25 | # print('last stamp:{}'.format(last_stamp))
26 | # print('max stamp:{}'.format(events[:, 0].max()))
27 | # print('timestamp:{}'.format(events[:, 0]))
28 | # print('polarity:{}'.format(events[:, -1]))
29 |
30 | first_stamp = events[0, 0]
31 | deltaT = last_stamp - first_stamp
32 |
33 | if deltaT == 0:
34 | deltaT = 1.0
35 |
36 |     events[:, 0] = (num_bins - 1) * (events[:, 0] - first_stamp) / deltaT
37 | ts = events[:, 0]
38 | xs = events[:, 1].astype(int)
39 | ys = events[:, 2].astype(int)
40 | pols = events[:, 3]
41 | pols[pols == 0] = -1 # polarity should be +1 / -1
42 |
43 | tis = ts.astype(int)
44 | dts = ts - tis
45 | vals_left = pols * (1.0 - dts)
46 | vals_right = pols * dts
47 |
48 | valid_indices = tis < num_bins # [True True ... True]
49 | # print('x max:{}'.format(xs[valid_indices].max()))
50 | # print('y max:{}'.format(ys[valid_indices].max()))
51 | # print('tix max:{}'.format(tis[valid_indices].max()))
52 |
53 |     np.add.at(voxel_grid, xs[valid_indices] + ys[valid_indices] * width
54 | + tis[valid_indices] * width * height, vals_left[valid_indices])
55 |
56 | valid_indices = (tis + 1) < num_bins
57 | np.add.at(voxel_grid, xs[valid_indices] + ys[valid_indices] * width
58 | + (tis[valid_indices] + 1) * width * height, vals_right[valid_indices])
59 |
60 | voxel_grid = np.reshape(voxel_grid, (num_bins, height, width))
61 |
62 | if return_format == 'CHW':
63 | return voxel_grid
64 | elif return_format == 'HWC':
65 | return voxel_grid.transpose(1,2,0)
66 |
67 | def voxel_norm(voxel):
68 | """
69 |     Normalize the voxel grid (zero mean, unit stddev over its nonzero entries)
70 | 
71 |     :param voxel: The unnormalized voxel grid
72 |     :return voxel: The normalized voxel grid
73 | """
74 | nonzero_ev = (voxel != 0)
75 | num_nonzeros = nonzero_ev.sum()
76 | # print('DEBUG: num_nonzeros:{}'.format(num_nonzeros))
77 | if num_nonzeros > 0:
78 | # compute mean and stddev of the **nonzero** elements of the event tensor
79 | # we do not use PyTorch's default mean() and std() functions since it's faster
80 | # to compute it by hand than applying those funcs to a masked array
81 | mean = voxel.sum() / num_nonzeros
82 | stddev = torch.sqrt((voxel ** 2).sum() / num_nonzeros - mean ** 2)
83 | mask = nonzero_ev.float()
84 | voxel = mask * (voxel - mean) / stddev
85 |
86 | return voxel
87 |
88 |
89 |
90 | def filter_event(x, y, p, t, s_e_index=[0, 6]):
91 | '''
92 |     s_e_index: [start, end] timestamp indices; both endpoints are inclusive
93 | '''
94 | t_1=t.squeeze(1)
95 | uniqw, inverse = np.unique(t_1, return_inverse=True)
96 | discretized_ts = np.bincount(inverse)
97 | index_exposure_start = np.sum(discretized_ts[0:s_e_index[0]])
98 | index_exposure_end = np.sum(discretized_ts[0:s_e_index[1]+1])
99 | x_1 = x[index_exposure_start:index_exposure_end]
100 | y_1 = y[index_exposure_start:index_exposure_end]
101 | p_1 = p[index_exposure_start:index_exposure_end]
102 | t_1 = t[index_exposure_start:index_exposure_end]
103 |
104 | return x_1, y_1, p_1, t_1
105 |
106 |
--------------------------------------------------------------------------------
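A tiny end-to-end check of events_to_voxel_grid and voxel_norm on synthetic events (rows are [timestamp, x, y, polarity]):

import numpy as np
import torch
from basicsr.data.event_util import events_to_voxel_grid, voxel_norm

events = np.array([
    [0.00, 1, 2, 1],
    [0.05, 3, 0, 0],   # polarity 0 is remapped to -1 inside the function
    [0.10, 2, 1, 1],
], dtype=np.float32)
voxel = events_to_voxel_grid(events, num_bins=6, width=4, height=3)  # shape (6, 3, 4)
voxel = voxel_norm(torch.from_numpy(voxel))  # voxel_norm expects a torch tensor
print(voxel.shape)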
/basicsr/utils/options.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from collections import OrderedDict
3 | from os import path as osp
4 |
5 |
6 | def ordered_yaml():
7 | """Support OrderedDict for yaml.
8 |
9 | Returns:
10 | yaml Loader and Dumper.
11 | """
12 | try:
13 | from yaml import CDumper as Dumper
14 | from yaml import CLoader as Loader
15 | except ImportError:
16 | from yaml import Dumper, Loader
17 |
18 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
19 |
20 | def dict_representer(dumper, data):
21 | return dumper.represent_dict(data.items())
22 |
23 | def dict_constructor(loader, node):
24 | return OrderedDict(loader.construct_pairs(node))
25 |
26 | Dumper.add_representer(OrderedDict, dict_representer)
27 | Loader.add_constructor(_mapping_tag, dict_constructor)
28 | return Loader, Dumper
29 |
30 |
31 | def parse(opt_path, is_train=True):
32 | """Parse option file.
33 |
34 | Args:
35 | opt_path (str): Option file path.
36 |         is_train (bool): Indicate whether in training or not. Default: True.
37 |
38 | Returns:
39 | (dict): Options.
40 | """
41 | with open(opt_path, mode='r') as f:
42 | Loader, _ = ordered_yaml()
43 | opt = yaml.load(f, Loader=Loader)
44 |
45 | opt['is_train'] = is_train
46 |
47 | # datasets
48 | if 'datasets' in opt:
49 | for phase, dataset in opt['datasets'].items():
50 | # for several datasets, e.g., test_1, test_2
51 | phase = phase.split('_')[0]
52 | dataset['phase'] = phase
53 | if 'scale' in opt:
54 | dataset['scale'] = opt['scale']
55 | if dataset.get('dataroot_gt') is not None:
56 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
57 | if dataset.get('dataroot_lq') is not None:
58 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
59 |
60 | # paths
61 | for key, val in opt['path'].items():
62 | if (val is not None) and ('resume_state' in key
63 | or 'pretrain_network' in key):
64 | opt['path'][key] = osp.expanduser(val)
65 |
66 | # opt['path']['root'] = osp.abspath(
67 | # osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
68 |     opt['path']['root'] = opt['path'].get('root', osp.abspath(
69 | osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)))
70 |
71 | if is_train:
72 | experiments_root = osp.join(opt['path']['root'], 'experiments',
73 | opt['name'])
74 | opt['path']['experiments_root'] = experiments_root
75 | opt['path']['models'] = osp.join(experiments_root, 'models')
76 | opt['path']['training_states'] = osp.join(experiments_root,
77 | 'training_states')
78 | opt['path']['log'] = experiments_root
79 | opt['path']['visualization'] = osp.join(experiments_root,
80 | 'visualization')
81 |
82 | # change some options for debug mode
83 | if 'debug' in opt['name']:
84 | if 'val' in opt:
85 | opt['val']['val_freq'] = 8
86 | opt['logger']['print_freq'] = 1
87 | opt['logger']['save_checkpoint_freq'] = 8
88 | else: # test
89 | results_root = osp.join(opt['path']['root'], 'results', opt['name'])
90 | opt['path']['results_root'] = results_root
91 | opt['path']['log'] = results_root
92 | # opt['path']['visualization'] = opt['path'].get('visualization', osp.join(results_root, 'visualization'))
93 | opt['path']['visualization'] = osp.join(results_root, 'visualization')
94 |
95 | return opt
96 |
97 |
98 | def dict2str(opt, indent_level=1):
99 | """dict to string for printing options.
100 |
101 | Args:
102 | opt (dict): Option dict.
103 | indent_level (int): Indent level. Default: 1.
104 |
105 | Return:
106 | (str): Option string for printing.
107 | """
108 | msg = '\n'
109 | for k, v in opt.items():
110 | if isinstance(v, dict):
111 | msg += ' ' * (indent_level * 2) + k + ':['
112 | msg += dict2str(v, indent_level + 1)
113 | msg += ' ' * (indent_level * 2) + ']\n'
114 | else:
115 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
116 | return msg
117 |
--------------------------------------------------------------------------------
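Typical usage of the two helpers above (the option path is one of the files in this repo):

from basicsr.utils.options import dict2str, parse

opt = parse('options/train/GoPro/EFNet.yml', is_train=True)
print(dict2str(opt))  # nested options rendered with indentation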
/scripts/data_preparation/make_voxels_esim.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | import cv2
4 | import os
5 | from .raw_event_dataset import *
6 | from .noise_function import add_noise_to_voxel, put_hot_pixels_in_voxel_
7 | import argparse
8 | import h5py
9 | import torch
10 | import time
11 |
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--input_path", default="/scratch/leisun/Datasets/GOPRO_fullsize_h5/test", help="Path to hdf5 file")
15 | parser.add_argument("--save_path", default="/scratch/e_work/GOPRO_SCER/test")
16 | parser.add_argument("--voxel_method", default="SCER_esim", help="SCER_esim, SCER_real_data")
17 | parser.add_argument("--add_noise", default=True, help="add noisy to voxel like hot pixel")
18 |
19 | exposure_time = 1/240
20 | num_bins = 6
21 | num_pixels = 1280*720
22 |
23 | args = parser.parse_args()
24 |
25 | os.makedirs(args.save_path, exist_ok=True)
26 |
27 | def voxel2mask(voxel):
28 | mask_final = np.zeros_like(voxel[0, :, :])
29 | mask = (voxel != 0)
30 | for i in range(mask.shape[0]):
31 | mask_final = np.logical_or(mask_final, mask[i, :, :])
32 | # to uint8 image
33 | mask_img = mask_final * np.ones_like(mask_final) * 255
34 | mask_img = mask_img[..., np.newaxis] # H,W,C
35 | mask_img = np.uint8(mask_img)
36 |
37 | return mask_img
38 |
39 | def main():
40 | # dataset settings
41 | # voxel_method = {'method': 'SCER_esim'}
42 | voxel_method = {'method': args.voxel_method}
43 |
44 |     has_exposure_time = False # False for ESIM-simulated data, True for our real (Seems) data
45 | # no data augmentation
46 | data_augment = {}
47 | dataset_kwargs = {'transforms': data_augment, 'voxel_method': voxel_method, 'num_bins': num_bins,
48 | 'has_exposure_time': has_exposure_time, 'combined_voxel_channels': True, 'voxel_temporal_bilinear': True} # voxel_temporal_bilinear for voxel grid
49 |
50 | file_folder_path = args.input_path
51 | output_file_folder_path = args.save_path
52 | h5_file_paths = [os.path.join(file_folder_path, s) for s in os.listdir(file_folder_path)]
53 | output_h5_file_paths = [os.path.join(output_file_folder_path, s) for s in os.listdir(file_folder_path)]
54 | # for each h5_file
55 | for i in range(len(h5_file_paths)):
56 | print("processing file: {}".format(h5_file_paths[i]))
57 | # h5_file = h5py.File(h5_file_paths[i], 'a')
58 | h5_file = h5py.File(output_h5_file_paths[i],'a')
59 | # voxel_dset = h5_file.create_group("voxels")
60 | # N = int(os.path.basename(h5_file_paths[i]).split('_')[1]) # num of the sharp img for blur image produce
61 | dataset_kwargs.update({'data_path':h5_file_paths[i]})
62 | dloader = GoproEsimH5Dataset(**dataset_kwargs)
63 | num_img = 0
64 | for item in dloader:
65 | voxel=item['voxel']
66 | num_events = item['num_events']
67 | blur = item['frame'] # C,H,W
68 | sharp = item['frame_gt']
69 |
70 | # add noise to voxel Here
71 | if args.add_noise:
72 | # print("Add noisy to voxels")
73 | voxel = add_noise_to_voxel(voxel, noise_std=1.0, noise_fraction=0.05)
74 | put_hot_pixels_in_voxel_(voxel, hot_pixel_range=20, hot_pixel_fraction=0.00002)
75 |
76 | voxel_np = voxel.numpy() # shape: bin+1,H,W
77 | blur_np=np.uint8(np.clip(255*blur.numpy(), 0, 255))
78 | sharp_np=np.uint8(np.clip(255*sharp.numpy(), 0, 255))
79 | mask_img = voxel2mask(voxel_np)
80 | #close filter
81 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
82 | mask_img_close = cv2.morphologyEx(mask_img, cv2.MORPH_CLOSE, kernel, iterations=1)
83 | # print(mask_img_close.shape)
84 | mask_img_close = mask_img_close[np.newaxis,...] # H,W -> C,H,W C=1
85 |
86 | # save to h5 image
87 | voxel_dset=h5_file.create_dataset("voxels/voxel{:09d}".format(num_img), data=voxel_np, dtype=np.dtype(np.float32))
88 | image_dset=h5_file.create_dataset("images/image{:09d}".format(num_img), data=blur_np, dtype=np.dtype(np.uint8))
89 | sharp_image_dset=h5_file.create_dataset("sharp_images/image{:09d}".format(num_img), data=sharp_np, dtype=np.dtype(np.uint8))
90 | mask_dset=h5_file.create_dataset("masks/mask{:09d}".format(num_img), data=mask_img_close, dtype=np.dtype(np.uint8))
91 |
92 |
93 |
94 | voxel_dset.attrs['size']=voxel_np.shape
95 | image_dset.attrs['size']=blur_np.shape
96 | sharp_image_dset.attrs['size']=sharp_np.shape
97 | mask_dset.attrs['size']=mask_img_close.shape
98 |
99 |
100 | num_img+=1
101 | sensor_resolution = [blur_np.shape[1], blur_np.shape[2]]
102 | h5_file.attrs['sensor_resolution'] = sensor_resolution
103 |
104 |
105 |
106 |
107 | if __name__ == "__main__":
108 | main()
109 |
110 |
--------------------------------------------------------------------------------
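For reference, voxel2mask above marks a pixel as active when any temporal bin is nonzero; a vectorized equivalent on a synthetic voxel:

import numpy as np

voxel = np.zeros((6, 4, 5), np.float32)
voxel[2, 1, 3] = 0.7
# Any nonzero bin at (y, x) -> 255, matching the logical_or loop in voxel2mask.
mask = ((voxel != 0).any(axis=0).astype(np.uint8) * 255)[..., np.newaxis]
print(mask.shape, mask[1, 3, 0])  # (4, 5, 1) 255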
/scripts/data_preparation/make_voxels_real.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | import cv2
4 | import os
5 | from .raw_event_dataset import *
6 | from .noise_function import add_noise_to_voxel, put_hot_pixels_in_voxel_
7 | import argparse
8 | import h5py
9 | import torch
10 | import time
11 |
12 | # num_bins = 6 # SCER_esim
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--input_path", default="/scratch/leisun/REBlur_h5", help="Path to hdf5 file")
15 | parser.add_argument("--save_path", default="/scratch/leisun/REBlur_SCER")
16 |
17 | parser.add_argument("--voxel_method", default="SCER_real_data", help="SCER_esim, SCER_real_data, SBT, All_accumulate=SBT + bin=1")
18 | parser.add_argument("--add_noise", default=False, help="add noisy to voxel like hot pixel")
19 |
20 | has_exposure_time = True # False for ESIM-simulated data, True for our real (Seems) data
21 | exposure_time = 1/240
22 | num_bins = 6
23 | num_pixels = 1280*720
24 | has_gt_frame = False
25 |
26 | args = parser.parse_args()
27 |
28 | os.makedirs(args.save_path, exist_ok=True)
29 |
30 | def voxel2mask(voxel):
31 | mask_final = np.zeros_like(voxel[0, :, :])
32 | mask = (voxel != 0)
33 | for i in range(mask.shape[0]):
34 | mask_final = np.logical_or(mask_final, mask[i, :, :])
35 | # to uint8 image
36 | mask_img = mask_final * np.ones_like(mask_final) * 255
37 | mask_img = mask_img[..., np.newaxis] # H,W,C
38 | mask_img = np.uint8(mask_img)
39 |
40 | return mask_img
41 |
42 | def main():
43 | # dataset settings
44 | # voxel_method = {'method': 'SCER_esim'}
45 | voxel_method = {'method': args.voxel_method}
46 |
47 | # no data augmentation
48 | data_augment = {}
49 | dataset_kwargs = {'transforms': data_augment, 'voxel_method': voxel_method, 'num_bins': num_bins, 'return_gt_frame':has_gt_frame,
50 | 'has_exposure_time': has_exposure_time, 'combined_voxel_channels': True, 'keep_middle':False }
51 |
52 | file_folder_path = args.input_path
53 | output_file_folder_path = args.save_path
54 | h5_file_paths = [os.path.join(file_folder_path, s) for s in os.listdir(file_folder_path)]
55 | output_h5_file_paths = [os.path.join(output_file_folder_path, s) for s in os.listdir(file_folder_path)]
56 | # for each h5_file
57 | for i in range(len(h5_file_paths)):
58 | print("processing file: {}".format(h5_file_paths[i]))
59 | # h5_file = h5py.File(h5_file_paths[i], 'a')
60 | h5_file = h5py.File(output_h5_file_paths[i],'a')
61 |
62 | dataset_kwargs.update({'data_path':h5_file_paths[i]})
63 | dloader = SeemsH5Dataset(**dataset_kwargs)
64 | num_img = 0
65 | for item in dloader:
66 | voxel=item['voxel']
67 | num_events = item['num_events']
68 | blur = item['frame'] # C,H,W
69 | if has_gt_frame:
70 | sharp = item['frame_gt']
71 |
72 | # 1,262,320 -> 1,260,320
73 | voxel = voxel[:,:-2,:]
74 | blur = blur[:,:-2,:]
75 | if has_gt_frame:
76 | sharp = sharp[:,:-2,:]
77 |
78 |
79 | # add noise to voxel Here
80 | if args.add_noise:
81 | # print("Add noisy to voxels")
82 | voxel = add_noise_to_voxel(voxel, noise_std=1.0, noise_fraction=0.05)
83 | put_hot_pixels_in_voxel_(voxel, hot_pixel_range=20, hot_pixel_fraction=0.00002)
84 |
85 |
86 | voxel_np = voxel.numpy() # shape: bin+1,H,W
87 | blur_np=np.uint8(np.clip(255*blur.numpy(), 0, 255))
88 | if has_gt_frame:
89 | sharp_np=np.uint8(np.clip(255*sharp.numpy(), 0, 255))
90 |
91 | mask_img = voxel2mask(voxel_np)
92 | #close filter
93 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
94 | mask_img_close = cv2.morphologyEx(mask_img, cv2.MORPH_CLOSE, kernel, iterations=1)
95 | # print(mask_img_close.shape)
96 | mask_img_close = mask_img_close[np.newaxis,...] # H,W -> C,H,W C=1
97 |
98 | # save to h5 image
99 | voxel_dset=h5_file.create_dataset("voxels/voxel{:09d}".format(num_img), data=voxel_np, dtype=np.dtype(np.float32))
100 | image_dset=h5_file.create_dataset("images/image{:09d}".format(num_img), data=blur_np, dtype=np.dtype(np.uint8))
101 | if has_gt_frame:
102 |
103 | sharp_image_dset=h5_file.create_dataset("sharp_images/image{:09d}".format(num_img), data=sharp_np, dtype=np.dtype(np.uint8))
104 |
105 | mask_dset=h5_file.create_dataset("masks/mask{:09d}".format(num_img), data=mask_img_close, dtype=np.dtype(np.uint8))
106 |
107 |
108 |
109 | voxel_dset.attrs['size']=voxel_np.shape
110 | image_dset.attrs['size']=blur_np.shape
111 | if has_gt_frame:
112 | sharp_image_dset.attrs['size']=sharp_np.shape
113 | mask_dset.attrs['size']=mask_img_close.shape
114 |
115 |
116 | num_img+=1
117 | sensor_resolution = [blur_np.shape[1], blur_np.shape[2]]
118 | h5_file.attrs['sensor_resolution'] = sensor_resolution
119 |
120 |
121 |
122 |
123 | if __name__ == "__main__":
124 | main()
125 |
126 |
--------------------------------------------------------------------------------
/basicsr/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import numpy as np
3 | import random
4 | import torch
5 | import torch.utils.data
6 | from functools import partial
7 | from os import path as osp
8 |
9 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader
10 | from basicsr.utils import get_root_logger, scandir
11 | from basicsr.utils.dist_util import get_dist_info
12 | from basicsr.data.h5_image_dataset import *
13 |
14 | __all__ = ['create_dataset', 'create_dataloader']
15 |
16 | # automatically scan and import dataset modules
17 | # scan all the files under the data folder with '_dataset' in file names
18 | data_folder = osp.dirname(osp.abspath(__file__))
19 | dataset_filenames = [
20 | osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
21 | if v.endswith('_dataset.py')
22 | ]
23 | # import all the dataset modules
24 | _dataset_modules = [
25 | importlib.import_module(f'basicsr.data.{file_name}')
26 | for file_name in dataset_filenames
27 | ]
28 |
29 |
30 | def create_dataset(dataset_opt):
31 | """Create dataset.
32 |
33 | Args:
34 |         dataset_opt (dict): Configuration for dataset. It contains:
35 | name (str): Dataset name.
36 | type (str): Dataset type.
37 | """
38 | dataset_type = dataset_opt['type']
39 |
40 | # dynamic instantiation
41 | for module in _dataset_modules:
42 | dataset_cls = getattr(module, dataset_type, None)
43 | if dataset_cls is not None:
44 | break
45 | if dataset_cls is None:
46 | raise ValueError(f'Dataset {dataset_type} is not found.')
47 |
48 | if dataset_type == "H5ImageDataset":
49 | dataset = concatenate_h5_datasets(dataset_cls, dataset_opt)
50 |
51 | else:
52 | dataset = dataset_cls(dataset_opt)
53 |
54 | logger = get_root_logger()
55 | logger.info(
56 | f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
57 | 'is created.')
58 | return dataset
59 |
60 |
61 |
62 | def create_dataloader(dataset,
63 | dataset_opt,
64 | num_gpu=1,
65 | dist=False,
66 | sampler=None,
67 | seed=None):
68 | """Create dataloader.
69 |
70 | Args:
71 | dataset (torch.utils.data.Dataset): Dataset.
72 | dataset_opt (dict): Dataset options. It contains the following keys:
73 | phase (str): 'train' or 'val'.
74 | num_worker_per_gpu (int): Number of workers for each GPU.
75 | batch_size_per_gpu (int): Training batch size for each GPU.
76 | num_gpu (int): Number of GPUs. Used only in the train phase.
77 | Default: 1.
78 | dist (bool): Whether in distributed training. Used only in the train
79 | phase. Default: False.
80 | sampler (torch.utils.data.sampler): Data sampler. Default: None.
81 | seed (int | None): Seed. Default: None
82 | """
83 | phase = dataset_opt['phase']
84 | rank, _ = get_dist_info()
85 | if phase == 'train':
86 | if dist: # distributed training
87 | batch_size = dataset_opt['batch_size_per_gpu']
88 | num_workers = dataset_opt['num_worker_per_gpu']
89 | else: # non-distributed training
90 | multiplier = 1 if num_gpu == 0 else num_gpu
91 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
92 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
93 | dataloader_args = dict(
94 | dataset=dataset,
95 | batch_size=batch_size,
96 | shuffle=False,
97 | num_workers=num_workers,
98 | sampler=sampler,
99 | drop_last=True)
100 | if sampler is None:
101 | dataloader_args['shuffle'] = True
102 | dataloader_args['worker_init_fn'] = partial(
103 | worker_init_fn, num_workers=num_workers, rank=rank,
104 | seed=seed) if seed is not None else None
105 | elif phase in ['val', 'test']: # validation
106 | dataloader_args = dict(
107 | dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
108 | else:
109 | raise ValueError(f'Wrong dataset phase: {phase}. '
110 | "Supported ones are 'train', 'val' and 'test'.")
111 |
112 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
113 |
114 | prefetch_mode = dataset_opt.get('prefetch_mode')
115 | if prefetch_mode == 'cpu': # CPUPrefetcher
116 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
117 | logger = get_root_logger()
118 | logger.info(f'Use {prefetch_mode} prefetch dataloader: '
119 | f'num_prefetch_queue = {num_prefetch_queue}')
120 | return PrefetchDataLoader(
121 | num_prefetch_queue=num_prefetch_queue, **dataloader_args)
122 | else:
123 | # prefetch_mode=None: Normal dataloader
124 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher
125 | return torch.utils.data.DataLoader(**dataloader_args)
126 |
127 |
128 | def worker_init_fn(worker_id, num_workers, rank, seed):
129 | # Set the worker seed to num_workers * rank + worker_id + seed
130 | worker_seed = num_workers * rank + worker_id + seed
131 | np.random.seed(worker_seed)
132 | random.seed(worker_seed)
133 |
--------------------------------------------------------------------------------
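A sketch of the create_dataset / create_dataloader flow with an option dict mirroring the YAML files above (values are illustrative; dataroot must point at real data):

from basicsr.data import create_dataloader, create_dataset

dataset_opt = {
    'name': 'gopro-h5-train', 'type': 'H5ImageDataset', 'phase': 'train',
    'dataroot': '/PATH/TO/GOPRO_h5/train', 'io_backend': {'type': 'h5'},
    'crop_size': 256, 'use_flip': True, 'use_rot': True,
    'norm_voxel': True, 'return_voxel': True, 'return_mask': True, 'use_mask': True,
    'num_worker_per_gpu': 3, 'batch_size_per_gpu': 4,
}
dataset = create_dataset(dataset_opt)
loader = create_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, seed=10)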
/basicsr/utils/create_lmdb.py:
--------------------------------------------------------------------------------
1 | import argparse, cv2, os  # cv2, os, scio and tqdm are used by create_lmdb_for_SIDD
2 | import scipy.io as scio
3 | from os import path as osp
4 | from tqdm import tqdm
5 | from basicsr.utils import scandir
6 | from basicsr.utils.lmdb_util import make_lmdb_from_imgs
7 | def prepare_keys(folder_path, suffix='png'):
8 | """Prepare image path list and keys for DIV2K dataset.
9 |
10 | Args:
11 | folder_path (str): Folder path.
12 |
13 | Returns:
14 | list[str]: Image path list.
15 | list[str]: Key list.
16 | """
17 | print('Reading image path list ...')
18 | img_path_list = sorted(
19 | list(scandir(folder_path, suffix=suffix, recursive=False)))
20 | keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)]
21 |
22 | return img_path_list, keys
23 |
24 | def create_lmdb_for_reds():
25 | folder_path = './datasets/REDS/val/sharp_300'
26 | lmdb_path = './datasets/REDS/val/sharp_300.lmdb'
27 | img_path_list, keys = prepare_keys(folder_path, 'png')
28 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
29 | #
30 | folder_path = './datasets/REDS/val/blur_300'
31 | lmdb_path = './datasets/REDS/val/blur_300.lmdb'
32 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
33 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
34 |
35 | folder_path = './datasets/REDS/train/train_sharp'
36 | lmdb_path = './datasets/REDS/train/train_sharp.lmdb'
37 | img_path_list, keys = prepare_keys(folder_path, 'png')
38 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
39 |
40 | folder_path = './datasets/REDS/train/train_blur_jpeg'
41 | lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb'
42 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
43 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
44 |
45 |
46 | def create_lmdb_for_gopro():
47 | folder_path = './datasets/GoPro/train/blur_crops'
48 | lmdb_path = './datasets/GoPro/train/blur_crops.lmdb'
49 |
50 | img_path_list, keys = prepare_keys(folder_path, 'png')
51 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
52 |
53 | folder_path = './datasets/GoPro/train/sharp_crops'
54 | lmdb_path = './datasets/GoPro/train/sharp_crops.lmdb'
55 |
56 | img_path_list, keys = prepare_keys(folder_path, 'png')
57 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
58 |
59 | folder_path = './datasets/GoPro/test/target'
60 | lmdb_path = './datasets/GoPro/test/target.lmdb'
61 |
62 | img_path_list, keys = prepare_keys(folder_path, 'png')
63 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
64 |
65 | folder_path = './datasets/GoPro/test/input'
66 | lmdb_path = './datasets/GoPro/test/input.lmdb'
67 |
68 | img_path_list, keys = prepare_keys(folder_path, 'png')
69 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
70 |
71 | def create_lmdb_for_rain13k():
72 | folder_path = './datasets/Rain13k/train/input'
73 | lmdb_path = './datasets/Rain13k/train/input.lmdb'
74 |
75 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
76 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
77 |
78 | folder_path = './datasets/Rain13k/train/target'
79 | lmdb_path = './datasets/Rain13k/train/target.lmdb'
80 |
81 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
82 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
83 |
84 | def create_lmdb_for_SIDD():
85 | folder_path = './datasets/SIDD/train/input_crops'
86 | lmdb_path = './datasets/SIDD/train/input_crops.lmdb'
87 |
88 | img_path_list, keys = prepare_keys(folder_path, 'PNG')
89 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
90 |
91 | folder_path = './datasets/SIDD/train/gt_crops'
92 | lmdb_path = './datasets/SIDD/train/gt_crops.lmdb'
93 |
94 | img_path_list, keys = prepare_keys(folder_path, 'PNG')
95 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
96 |
97 | #for val
98 | folder_path = './datasets/SIDD/val/input_crops'
99 | lmdb_path = './datasets/SIDD/val/input_crops.lmdb'
100 | mat_path = './datasets/SIDD/ValidationNoisyBlocksSrgb.mat'
101 | if not osp.exists(folder_path):
102 | os.makedirs(folder_path)
103 | assert osp.exists(mat_path)
104 | data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb']
105 | N, B, H ,W, C = data.shape
106 | data = data.reshape(N*B, H, W, C)
107 | for i in tqdm(range(N*B)):
108 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR))
109 | img_path_list, keys = prepare_keys(folder_path, 'png')
110 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
111 |
112 | folder_path = './datasets/SIDD/val/gt_crops'
113 | lmdb_path = './datasets/SIDD/val/gt_crops.lmdb'
114 | mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat'
115 | if not osp.exists(folder_path):
116 | os.makedirs(folder_path)
117 | assert osp.exists(mat_path)
118 | data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb']
119 | N, B, H ,W, C = data.shape
120 | data = data.reshape(N*B, H, W, C)
121 | for i in tqdm(range(N*B)):
122 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR))
123 | img_path_list, keys = prepare_keys(folder_path, 'png')
124 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
125 |
--------------------------------------------------------------------------------
/basicsr/data/voxelnpz_image_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils import data as data
2 | from torchvision.transforms.functional import normalize
3 | from tqdm import tqdm
4 | import os
5 | from pathlib import Path
6 | import random
7 | import numpy as np
8 | import torch
9 |
10 | from basicsr.data.data_util import (paired_paths_from_folder,
11 | paired_paths_from_lmdb,
12 | paired_paths_from_meta_info_file,
13 | recursive_glob)
14 | from basicsr.data.event_util import events_to_voxel_grid, voxel_norm
15 | from basicsr.data.transforms import augment, triple_random_crop, random_augmentation
16 | from basicsr.utils import FileClient, imfrombytes, img2tensor, padding, get_root_logger
17 | from torch.utils.data.dataloader import default_collate
18 |
19 |
20 | class VoxelnpzPngSingleDeblurDataset(data.Dataset):
21 | """Paired vxoel(npz) and blurry image (png) dataset for event-based single image deblurring.
22 | --HighREV
23 | |----train
24 | | |----blur
25 | | | |----SEQNAME_%5d.png
26 | | | |----...
27 | | |----voxel
28 | | | |----SEQNAME_%5d.npz
29 | | | |----...
30 | | |----sharp
31 | | | |----SEQNAME_%5d.png
32 | | | |----...
33 | |----val
34 | ...
35 |
36 |
37 | Args:
38 | opt (dict): Config for train dataset. It contains the following keys:
39 | dataroot (str): Data root path.
40 | io_backend (dict): IO backend type and other kwarg.
41 | num_end_interpolation (int): Number of sharp frames to reconstruct in each blurry image.
42 | num_inter_interpolation (int): Number of sharp frames to interpolate between two blurry images.
43 | phase (str): 'train' or 'test'
44 |
45 | gt_size (int): Cropped patched size for gt patches.
46 | random_reverse (bool): Random reverse input frames.
47 | use_hflip (bool): Use horizontal flips.
48 | use_rot (bool): Use rotation (use vertical flip and transposing h
49 | and w for implementation).
50 |
51 | scale (bool): Scale, which will be added automatically.
52 | """
53 |
54 | def __init__(self, opt):
55 | super(VoxelnpzPngSingleDeblurDataset, self).__init__()
56 | self.opt = opt
57 | self.dataroot = Path(opt['dataroot'])
58 | self.dataroot_voxel = Path(opt['dataroot_voxel'])
59 | self.split = 'train' if opt['phase'] == 'train' else 'val' # train or val
60 | self.norm_voxel = opt['norm_voxel']
61 | self.dataPath = []
62 |
63 | blur_frames = sorted(recursive_glob(rootdir=os.path.join(self.dataroot, 'blur'), suffix='.png'))
64 | blur_frames = [os.path.join(self.dataroot, 'blur', blur_frame) for blur_frame in blur_frames]
65 |
66 | sharp_frames = sorted(recursive_glob(rootdir=os.path.join(self.dataroot, 'sharp'), suffix='.png'))
67 | sharp_frames = [os.path.join(self.dataroot, 'sharp', sharp_frame) for sharp_frame in sharp_frames]
68 |
69 | event_frames = sorted(recursive_glob(rootdir=self.dataroot_voxel, suffix='.npz'))
70 | event_frames = [os.path.join(self.dataroot_voxel, event_frame) for event_frame in event_frames]
71 |
72 | assert len(blur_frames) == len(sharp_frames) == len(event_frames), f"Mismatch in blur ({len(blur_frames)}), sharp ({len(sharp_frames)}), and event ({len(event_frames)}) frame counts."
73 |
74 | for i in range(len(blur_frames)):
75 | self.dataPath.append({
76 | 'blur_path': blur_frames[i],
77 | 'sharp_path': sharp_frames[i],
78 | 'event_paths': event_frames[i],
79 | })
80 | logger = get_root_logger()
81 | logger.info(f"Dataset initialized with {len(self.dataPath)} samples.")
82 |
83 | # file client (io backend)
84 | self.file_client = None
85 | self.io_backend_opt = opt['io_backend']
86 | # import pdb; pdb.set_trace()
87 |
88 | def __getitem__(self, index):
89 | if self.file_client is None:
90 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
91 | scale = self.opt['scale']
92 | gt_size = self.opt['gt_size']
93 |
94 | image_path = self.dataPath[index]['blur_path']
95 | gt_path = self.dataPath[index]['sharp_path']
96 | event_path = self.dataPath[index]['event_paths']
97 |
98 | # get LQ
99 | img_bytes = self.file_client.get(image_path) # 'lq'
100 | img_lq = imfrombytes(img_bytes, float32=True)
101 | # get GT
102 | img_bytes = self.file_client.get(gt_path) # 'gt'
103 | img_gt = imfrombytes(img_bytes, float32=True)
104 |
105 | voxel = np.load(event_path)['voxel']
106 |
107 | ## Data augmentation
108 | # voxel shape: h,w,c
109 | # crop
110 | if gt_size is not None:
111 | img_gt, img_lq, voxel = triple_random_crop(img_gt, img_lq, voxel, gt_size, scale, gt_path)
112 |
113 | # flip, rotate
114 | total_input = [img_lq, img_gt, voxel]
115 | img_results = augment(total_input, self.opt['use_hflip'], self.opt['use_rot'])
116 |
117 | img_results = img2tensor(img_results) # hwc -> chw
118 | img_lq, img_gt, voxel = img_results
119 |
120 | ## Norm voxel
121 | if self.norm_voxel:
122 | voxel = voxel_norm(voxel)
123 |
124 | origin_index = os.path.basename(image_path).split('.')[0]
125 |
126 | return {'frame': img_lq, 'frame_gt': img_gt, 'voxel': voxel, 'image_name': origin_index}
127 |
128 | def __len__(self):
129 | return len(self.dataPath)
130 |
--------------------------------------------------------------------------------
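Sketch of producing a voxel .npz in the layout this dataset expects: an H x W x num_bins array stored under the 'voxel' key, one file per blurry frame (the filename is illustrative):

import numpy as np

voxel = np.zeros((256, 256, 6), dtype=np.float32)  # H, W, C(=num_bins)
np.savez('SEQNAME_00001.npz', voxel=voxel)
print(np.load('SEQNAME_00001.npz')['voxel'].shape)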
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import find_packages, setup
4 |
5 | import os
6 | import subprocess
7 | import sys
8 | import time
9 | import torch
10 | from torch.utils.cpp_extension import (BuildExtension, CppExtension,
11 | CUDAExtension)
12 |
13 | version_file = 'basicsr/version.py'
14 |
15 |
16 | def readme():
17 | return ''
18 | # with open('README.md', encoding='utf-8') as f:
19 | # content = f.read()
20 | # return content
21 |
22 |
23 | def get_git_hash():
24 |
25 | def _minimal_ext_cmd(cmd):
26 | # construct minimal environment
27 | env = {}
28 | for k in ['SYSTEMROOT', 'PATH', 'HOME']:
29 | v = os.environ.get(k)
30 | if v is not None:
31 | env[k] = v
32 | # LANGUAGE is used on win32
33 | env['LANGUAGE'] = 'C'
34 | env['LANG'] = 'C'
35 | env['LC_ALL'] = 'C'
36 | out = subprocess.Popen(
37 | cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
38 | return out
39 |
40 | try:
41 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
42 | sha = out.strip().decode('ascii')
43 | except OSError:
44 | sha = 'unknown'
45 |
46 | return sha
47 |
48 |
49 | def get_hash():
50 | if os.path.exists('.git'):
51 | sha = get_git_hash()[:7]
52 | elif os.path.exists(version_file):
53 | try:
54 | from basicsr.version import __version__
55 | sha = __version__.split('+')[-1]
56 | except ImportError:
57 | raise ImportError('Unable to get git version')
58 | else:
59 | sha = 'unknown'
60 |
61 | return sha
62 |
63 |
64 | def write_version_py():
65 | content = """# GENERATED VERSION FILE
66 | # TIME: {}
67 | __version__ = '{}'
68 | short_version = '{}'
69 | version_info = ({})
70 | """
71 | sha = get_hash()
72 | with open('VERSION', 'r') as f:
73 | SHORT_VERSION = f.read().strip()
74 | VERSION_INFO = ', '.join(
75 | [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
76 | VERSION = SHORT_VERSION + '+' + sha
77 |
78 | version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION,
79 | VERSION_INFO)
80 | with open(version_file, 'w') as f:
81 | f.write(version_file_str)
82 |
83 |
84 | def get_version():
85 | with open(version_file, 'r') as f:
86 | exec(compile(f.read(), version_file, 'exec'))
87 | return locals()['__version__']
88 |
89 |
90 | def make_cuda_ext(name, module, sources, sources_cuda=None):
91 | if sources_cuda is None:
92 | sources_cuda = []
93 | define_macros = []
94 | extra_compile_args = {'cxx': []}
95 |
96 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
97 | define_macros += [('WITH_CUDA', None)]
98 | extension = CUDAExtension
99 | extra_compile_args['nvcc'] = [
100 | '-D__CUDA_NO_HALF_OPERATORS__',
101 | '-D__CUDA_NO_HALF_CONVERSIONS__',
102 | '-D__CUDA_NO_HALF2_OPERATORS__',
103 | ]
104 | sources += sources_cuda
105 | else:
106 | print(f'Compiling {name} without CUDA')
107 | extension = CppExtension
108 |
109 | return extension(
110 | name=f'{module}.{name}',
111 | sources=[os.path.join(*module.split('.'), p) for p in sources],
112 | define_macros=define_macros,
113 | extra_compile_args=extra_compile_args)
114 |
115 |
116 | def get_requirements(filename='requirements.txt'):
117 | return []
118 | here = os.path.dirname(os.path.realpath(__file__))
119 | with open(os.path.join(here, filename), 'r') as f:
120 | requires = [line.replace('\n', '') for line in f.readlines()]
121 | return requires
122 |
123 |
124 | if __name__ == '__main__':
125 | if '--no_cuda_ext' in sys.argv:
126 | ext_modules = []
127 | sys.argv.remove('--no_cuda_ext')
128 | else:
129 | ext_modules = [
130 | make_cuda_ext(
131 | name='deform_conv_ext',
132 | module='basicsr.models.ops.dcn',
133 | sources=['src/deform_conv_ext.cpp'],
134 | sources_cuda=[
135 | 'src/deform_conv_cuda.cpp',
136 | 'src/deform_conv_cuda_kernel.cu'
137 | ]),
138 | make_cuda_ext(
139 | name='fused_act_ext',
140 | module='basicsr.models.ops.fused_act',
141 | sources=['src/fused_bias_act.cpp'],
142 | sources_cuda=['src/fused_bias_act_kernel.cu']),
143 | make_cuda_ext(
144 | name='upfirdn2d_ext',
145 | module='basicsr.models.ops.upfirdn2d',
146 | sources=['src/upfirdn2d.cpp'],
147 | sources_cuda=['src/upfirdn2d_kernel.cu']),
148 | ]
149 |
150 | write_version_py()
151 | setup(
152 | name='basicsr',
153 | version=get_version(),
154 | description='Open Source Image and Video Super-Resolution Toolbox',
155 | long_description=readme(),
156 | author='Xintao Wang',
157 | author_email='xintao.wang@outlook.com',
158 | keywords='computer vision, restoration, super resolution',
159 | url='https://github.com/xinntao/BasicSR',
160 | packages=find_packages(
161 | exclude=('options', 'datasets', 'experiments', 'results',
162 | 'tb_logger', 'wandb')),
163 | classifiers=[
164 | 'Development Status :: 4 - Beta',
165 | 'License :: OSI Approved :: Apache Software License',
166 | 'Operating System :: OS Independent',
167 | 'Programming Language :: Python :: 3',
168 | 'Programming Language :: Python :: 3.7',
169 | 'Programming Language :: Python :: 3.8',
170 | ],
171 | license='Apache License 2.0',
172 | setup_requires=['cython', 'numpy'],
173 | install_requires=get_requirements(),
174 | ext_modules=ext_modules,
175 | cmdclass={'build_ext': BuildExtension},
176 | zip_safe=False)
177 |
--------------------------------------------------------------------------------
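
For orientation, a sketch of the file that `write_version_py()` above generates as `basicsr/version.py`, assuming `VERSION` holds `1.2.0` and `get_hash()` falls back to 'unknown' (the timestamp is illustrative):

    # GENERATED VERSION FILE
    # TIME: Mon Jan  1 00:00:00 2024
    __version__ = '1.2.0+unknown'
    short_version = '1.2.0'
    version_info = (1, 2, 0)
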
/scripts/download_pretrained_models.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from os import path as osp
4 |
5 | from basicsr.utils.download_util import download_file_from_google_drive
6 |
7 |
8 | def download_pretrained_models(method, file_ids):
9 | save_path_root = f'./experiments/pretrained_models/{method}'
10 | os.makedirs(save_path_root, exist_ok=True)
11 |
12 | for file_name, file_id in file_ids.items():
13 | save_path = osp.abspath(osp.join(save_path_root, file_name))
14 | if osp.exists(save_path):
15 |             user_response = input(
16 |                 f'{file_name} already exists. Do you want to overwrite it? Y/N\n')
17 |             if user_response.lower() == 'y':
18 |                 print(f'Overwriting {file_name} at {save_path}')
19 |                 download_file_from_google_drive(file_id, save_path)
20 |             elif user_response.lower() == 'n':
21 |                 print(f'Skipping {file_name}')
22 |             else:
23 |                 raise ValueError('Wrong input. Only accepts Y/N.')
24 | else:
25 | print(f'Downloading {file_name} to {save_path}')
26 | download_file_from_google_drive(file_id, save_path)
27 |
28 |
29 | if __name__ == '__main__':
30 | parser = argparse.ArgumentParser()
31 |
32 | parser.add_argument(
33 | 'method',
34 | type=str,
35 | help=(
36 | "Options: 'ESRGAN', 'EDVR', 'StyleGAN', 'EDSR', 'DUF', 'DFDNet', "
37 | "'dlib'. Set to 'all' if you want to download all the models."))
38 | args = parser.parse_args()
39 |
40 | file_ids = {
41 | 'ESRGAN': {
42 | 'ESRGAN_SRx4_DF2KOST_official-ff704c30.pth': # file name
43 | '1b3_bWZTjNO3iL2js1yWkJfjZykcQgvzT', # file id
44 | 'ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth':
45 | '1swaV5iBMFfg-DL6ZyiARztbhutDCWXMM'
46 | },
47 | 'EDVR': {
48 | 'EDVR_L_x4_SR_REDS_official-9f5f5039.pth':
49 | '127KXEjlCwfoPC1aXyDkluNwr9elwyHNb',
50 | 'EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth':
51 | '1aVR3lkX6ItCphNLcT7F5bbbC484h4Qqy',
52 | 'EDVR_M_woTSA_x4_SR_REDS_official-1edf645c.pth':
53 | '1C_WdN-NyNj-P7SOB5xIVuHl4EBOwd-Ny',
54 | 'EDVR_M_x4_SR_REDS_official-32075921.pth':
55 | '1dd6aFj-5w2v08VJTq5mS9OFsD-wALYD6',
56 | 'EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth':
57 | '1GZz_87ybR8eAAY3X2HWwI3L6ny7-5Yvl',
58 | 'EDVR_L_deblur_REDS_official-ca46bd8c.pth':
59 | '1_ma2tgHscZtkIY2tEJkVdU-UP8bnqBRE',
60 | 'EDVR_L_deblurcomp_REDS_official-0e988e5c.pth':
61 | '1fEoSeLFnHSBbIs95Au2W197p8e4ws4DW'
62 | },
63 | 'StyleGAN': {
64 | 'stylegan2_ffhq_config_f_1024_official-b09c3668.pth':
65 | '163PfuVSYKh4vhkYkfEaufw84CiF4pvWG',
66 | 'stylegan2_ffhq_config_f_1024_discriminator_official-806ddc5e.pth':
67 | '1wyOdcJnMtAT_fEwXYJObee7hcLzI8usT',
68 | 'stylegan2_cat_config_f_256_official-b82c74e3.pth':
69 | '1dGUvw8FLch50FEDAgAa6st1AXGnjduc7',
70 | 'stylegan2_cat_config_f_256_discriminator_official-f6f5ed5c.pth':
71 | '19wuj7Ztg56QtwEs01-p_LjQeoz6G11kF',
72 | 'stylegan2_church_config_f_256_official-12725a53.pth':
73 | '1Rcpguh4t833wHlFrWz9UuqFcSYERyd2d',
74 | 'stylegan2_church_config_f_256_discriminator_official-feba65b0.pth': # noqa: E501
75 | '1ImOfFUOwKqDDKZCxxM4VUdPQCc-j85Z9',
76 | 'stylegan2_car_config_f_512_official-32c42d4e.pth':
77 | '1FviBGvzORv4T3w0c3m7BaIfLNeEd0dC8',
78 | 'stylegan2_car_config_f_512_discriminator_official-31f302ab.pth':
79 | '1hlZ7M2GrK6cDFd2FIYazPxOZXTUfudB3',
80 | 'stylegan2_horse_config_f_256_official-d3d97ebc.pth':
81 | '1LV4OR22tJN19HHfGk0e7dVqMhjD0APRm',
82 | 'stylegan2_horse_config_f_256_discriminator_official-efc5e50e.pth':
83 | '1T8xbI-Tz8EeSg3gCmQBNqGjLP5l3Mv84'
84 | },
85 | 'EDSR': {
86 | 'EDSR_Mx2_f64b16_DIV2K_official-3ba7b086.pth':
87 | '1mREMGVDymId3NzIc2u90sl_X4-pb4ZcV',
88 | 'EDSR_Mx3_f64b16_DIV2K_official-6908f88a.pth':
89 | '1EriqQqlIiRyPbrYGBbwr_FZzvb3iwqz5',
90 | 'EDSR_Mx4_f64b16_DIV2K_official-0c287733.pth':
91 | '1bCK6cFYU01uJudLgUUe-jgx-tZ3ikOWn',
92 | 'EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth':
93 | '15257lZCRZ0V6F9LzTyZFYbbPrqNjKyMU',
94 | 'EDSR_Lx3_f256b32_DIV2K_official-3660f70d.pth':
95 | '18q_D434sLG_rAZeHGonAX8dkqjoyZ2su',
96 | 'EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth':
97 | '1GCi30YYCzgMCcgheGWGusP9aWKOAy5vl'
98 | },
99 | 'DUF': {
100 | 'DUF_x2_16L_official-39537cb9.pth':
101 | '1e91cEZOlUUk35keK9EnuK0F54QegnUKo',
102 | 'DUF_x3_16L_official-34ce53ec.pth':
103 | '1XN6aQj20esM7i0hxTbfiZr_SL8i4PZ76',
104 | 'DUF_x4_16L_official-bf8f0cfa.pth':
105 | '1V_h9U1CZgLSHTv1ky2M3lvuH-hK5hw_J',
106 | 'DUF_x4_28L_official-cbada450.pth':
107 | '1M8w0AMBJW65MYYD-_8_be0cSH_SHhDQ4',
108 | 'DUF_x4_52L_official-483d2c78.pth':
109 | '1GcmEWNr7mjTygi-QCOVgQWOo5OCNbh_T'
110 | },
111 | 'DFDNet': {
112 | 'DFDNet_dict_512-f79685f0.pth':
113 | '1iH00oMsoN_1OJaEQw3zP7_wqiAYMnY79',
114 | 'DFDNet_official-d1fa5650.pth':
115 | '1u6Sgcp8gVoy4uVTrOJKD3y9RuqH2JBAe'
116 | },
117 | 'dlib': {
118 | 'mmod_human_face_detector-4cb19393.dat':
119 | '1FUM-hcoxNzFCOpCWbAUStBBMiU4uIGIL',
120 | 'shape_predictor_5_face_landmarks-c4b1e980.dat':
121 | '1PNPSmFjmbuuUDd5Mg5LDxyk7tu7TQv2F',
122 | 'shape_predictor_68_face_landmarks-fbdc2cb8.dat':
123 | '1IneH-O-gNkG0SQpNCplwxtOAtRCkG2ni'
124 | }
125 | }
126 |
127 | if args.method == 'all':
128 | for method in file_ids.keys():
129 | download_pretrained_models(method, file_ids[method])
130 | else:
131 | download_pretrained_models(args.method, file_ids[args.method])
132 |
--------------------------------------------------------------------------------
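
The script expects a single positional `method` argument, e.g. `python scripts/download_pretrained_models.py EDVR`, or `all` to iterate over every entry in `file_ids`. A minimal sketch of calling the helper directly; the method name, file name, and Google Drive id below are placeholders:

    from scripts.download_pretrained_models import download_pretrained_models

    download_pretrained_models('MyMethod', {
        'my_model-abcdef12.pth': 'placeholder-google-drive-id',
    })
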
/basicsr/utils/npz2voxel.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import time
4 |
5 | def events_to_voxel_grid(events, num_bins, width, height, return_format='CHW'):
6 | """
7 | Build a voxel grid with bilinear interpolation in the time domain from a set of events.
8 |
9 | :param events: a [N x 4] NumPy array containing one event per row in the form: [timestamp, x, y, polarity]
10 | :param num_bins: number of bins in the temporal axis of the voxel grid
11 | :param width, height: dimensions of the voxel grid
12 | :param return_format: 'CHW' or 'HWC'
13 | """
14 |
15 | assert(events.shape[1] == 4)
16 | assert(num_bins > 0)
17 | assert(width > 0)
18 | assert(height > 0)
19 |
20 | voxel_grid = np.zeros((num_bins, height, width), np.float32).ravel()
21 | # print('DEBUG: voxel.shape:{}'.format(voxel_grid.shape))
22 |
23 | # normalize the event timestamps so that they lie between 0 and num_bins
24 | last_stamp = events[-1, 0]
25 | # print('last stamp:{}'.format(last_stamp))
26 | # print('max stamp:{}'.format(events[:, 0].max()))
27 | # print('timestamp:{}'.format(events[:, 0]))
28 | # print('polarity:{}'.format(events[:, -1]))
29 |
30 | first_stamp = events[0, 0]
31 | deltaT = last_stamp - first_stamp
32 |
33 | if deltaT == 0:
34 | deltaT = 1.0
35 |
36 | events[:, 0] = (num_bins - 1) * (events[:, 0] - first_stamp) / deltaT #
37 | ts = events[:, 0]
38 | xs = events[:, 1].astype(int)
39 | ys = events[:, 2].astype(int)
40 | pols = events[:, 3]
41 | pols[pols == 0] = -1 # polarity should be +1 / -1
42 |
43 | tis = ts.astype(int)
44 | dts = ts - tis
45 | vals_left = pols * (1.0 - dts)
46 | vals_right = pols * dts
47 |
48 | valid_indices = tis < num_bins # [True True ... True]
49 | # print('x max:{}'.format(xs[valid_indices].max()))
50 | # print('y max:{}'.format(ys[valid_indices].max()))
51 | # print('tix max:{}'.format(tis[valid_indices].max()))
52 |
53 |     np.add.at(voxel_grid, xs[valid_indices] + ys[valid_indices] * width  # flattened (t, y, x) index
54 | + tis[valid_indices] * width * height, vals_left[valid_indices])
55 |
56 | valid_indices = (tis + 1) < num_bins
57 | np.add.at(voxel_grid, xs[valid_indices] + ys[valid_indices] * width
58 | + (tis[valid_indices] + 1) * width * height, vals_right[valid_indices])
59 |
60 | voxel_grid = np.reshape(voxel_grid, (num_bins, height, width))
61 |
62 | if return_format == 'CHW':
63 | return voxel_grid
64 | elif return_format == 'HWC':
65 | return voxel_grid.transpose(1,2,0)
66 |
67 | def recursive_glob(rootdir='.', suffix=''):
68 | """Performs recursive glob with given suffix and rootdir
69 | :param rootdir is the root directory
70 | :param suffix is the suffix to be searched
71 | """
72 | # return [os.path.join(looproot, filename)
73 | # for looproot, _, filenames in os.walk(rootdir)
74 | # for filename in filenames if filename.endswith(suffix)]
75 | return [filename
76 | for looproot, _, filenames in os.walk(rootdir)
77 | for filename in filenames if filename.endswith(suffix)]
78 |
79 |
80 | def main():
81 |
82 |     # dataroot = '/work/lei_sun/HighREV/val'  # switch to the val split if needed
83 |     dataroot = '/work/lei_sun/HighREV/train'
84 |
85 | voxel_root = dataroot + '/voxel'
86 | if not os.path.exists(voxel_root):
87 | os.makedirs(voxel_root)
88 | voxel_bins = 6
89 | blur_frames = sorted(recursive_glob(rootdir=os.path.join(dataroot, 'blur'), suffix='.png'))
90 | blur_frames = [os.path.join(dataroot, 'blur', blur_frame) for blur_frame in blur_frames]
91 | h_lq, w_lq = 1224, 1632
92 | event_frames = sorted(recursive_glob(rootdir=os.path.join(dataroot, 'event'), suffix='.npz'))
93 | event_frames = [os.path.join(dataroot, 'event', event_frame) for event_frame in event_frames]
94 |
95 | all_event_lists = []
96 | for i in range(len(blur_frames)):
97 | blur_name = os.path.basename(blur_frames[i]) # e.g., SEQNAME_00001.png
98 | base_name = os.path.splitext(blur_name)[0] # Remove .png, get SEQNAME_00001
99 | event_list = sorted([f for f in event_frames if f.startswith(os.path.join(dataroot, 'event', base_name + '_'))])[:-1] # remove the last one because it is out of the exposure time
100 | all_event_lists.append(event_list)
101 |
102 | total_time = 0
103 | num_events_processed = 0
104 | total_events = len(all_event_lists) * len(all_event_lists[0])
105 | print("loop started")
106 | start_time = time.time()
107 | for event_list in all_event_lists:
108 | loop_start_time = time.time()
109 |
110 | events = [np.load(event_path) for event_path in event_list]
111 | parts = event_list[0].rsplit('_', 1)
112 | base_name = parts[0].split('/')[-1]
113 | print(base_name, 'start', event_list[0])
114 | all_quad_event_array = np.zeros((0,4)).astype(np.float32)
115 | for event in events:
116 |             ### IMPORTANT: the dataset stores x and y swapped,
117 |             ### so they are switched back here
118 | y = event['x'].astype(np.float32)
119 | x = event['y'].astype(np.float32)
120 | t = event['timestamp'].astype(np.float32)
121 | p = event['polarity'].astype(np.float32)
122 |
123 | this_quad_event_array = np.concatenate((t,x,y,p),axis=1) # N,4
124 | all_quad_event_array = np.concatenate((all_quad_event_array, this_quad_event_array), axis=0)
125 | voxel = events_to_voxel_grid(all_quad_event_array, num_bins=voxel_bins, width=w_lq, height=h_lq, return_format='HWC')
126 | # print(voxel.dtype, voxel.shape) # float32, (1224, 1632, 6)
127 |
128 | voxel_path = os.path.join(voxel_root, base_name + '.npz')
129 | # print(f'saving to {voxel_path}')
130 |
131 | # voxel_path = base_name + '.npz'
132 |
133 | np.savez(voxel_path, voxel=voxel)
134 |
135 | loop_end_time = time.time()
136 | loop_duration = loop_end_time - loop_start_time
137 | total_time += loop_duration
138 | num_events_processed += len(events)
139 |
140 | print(f"Loop {num_events_processed}/{total_events} took {loop_duration:.2f} seconds.")
141 | print("")
142 |
143 |
144 | end_time = time.time()
145 | total_duration = end_time - start_time
146 | print(f"Total time taken: {total_duration:.2f} seconds.")
147 |
148 |
149 | if __name__ == '__main__':
150 | main()
--------------------------------------------------------------------------------
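
A small, self-contained sketch of the voxelization above using synthetic events (values invented for illustration): each event splits its polarity between the two temporal bins adjacent to its normalized timestamp.

    import numpy as np
    from basicsr.utils.npz2voxel import events_to_voxel_grid

    # [timestamp, x, y, polarity]; polarity 0 is remapped to -1 internally
    events = np.array([
        [0.00, 1, 2, 1],
        [0.05, 1, 2, 0],
        [0.10, 3, 0, 1],
    ], dtype=np.float32)

    voxel = events_to_voxel_grid(events, num_bins=6, width=4, height=3)
    print(voxel.shape)  # (6, 3, 4) with the default 'CHW' return_format
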
/basicsr/utils/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import random
4 | import time
5 | import torch
6 | from os import path as osp
7 |
8 | from .dist_util import master_only
9 | from .logger import get_root_logger
10 |
11 |
12 | def set_random_seed(seed):
13 | """Set random seeds."""
14 | random.seed(seed)
15 | np.random.seed(seed)
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed(seed)
18 | torch.cuda.manual_seed_all(seed)
19 |
20 |
21 | def get_time_str():
22 | return time.strftime('%Y%m%d_%H%M%S', time.localtime())
23 |
24 |
25 | def mkdir_and_rename(path):
26 | """mkdirs. If path exists, rename it with timestamp and create a new one.
27 |
28 | Args:
29 | path (str): Folder path.
30 | """
31 | if osp.exists(path):
32 | new_name = path + '_archived_' + get_time_str()
33 | print(f'Path already exists. Rename it to {new_name}', flush=True)
34 | os.rename(path, new_name)
35 | os.makedirs(path, exist_ok=True)
36 |
37 |
38 | @master_only
39 | def make_exp_dirs(opt):
40 | """Make dirs for experiments."""
41 | path_opt = opt['path'].copy()
42 | if opt['is_train']:
43 | mkdir_and_rename(path_opt.pop('experiments_root'))
44 | else:
45 | mkdir_and_rename(path_opt.pop('results_root'))
46 | for key, path in path_opt.items():
47 | if ('strict_load' not in key) and ('pretrain_network'
48 | not in key) and ('resume'
49 | not in key):
50 | os.makedirs(path, exist_ok=True)
51 |
52 |
53 | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
54 | """Scan a directory to find the interested files.
55 |
56 | Args:
57 | dir_path (str): Path of the directory.
58 | suffix (str | tuple(str), optional): File suffix that we are
59 | interested in. Default: None.
60 | recursive (bool, optional): If set to True, recursively scan the
61 | directory. Default: False.
62 | full_path (bool, optional): If set to True, include the dir_path.
63 | Default: False.
64 |
65 | Returns:
66 |         A generator for all the files of interest, with relative paths.
67 | """
68 |
69 | if (suffix is not None) and not isinstance(suffix, (str, tuple)):
70 | raise TypeError('"suffix" must be a string or tuple of strings')
71 |
72 | root = dir_path
73 |
74 | def _scandir(dir_path, suffix, recursive):
75 | for entry in os.scandir(dir_path):
76 | if not entry.name.startswith('.') and entry.is_file():
77 | if full_path:
78 | return_path = entry.path
79 | else:
80 | return_path = osp.relpath(entry.path, root)
81 |
82 | if suffix is None:
83 | yield return_path
84 | elif return_path.endswith(suffix):
85 | yield return_path
86 | else:
87 | if recursive:
88 | yield from _scandir(
89 | entry.path, suffix=suffix, recursive=recursive)
90 | else:
91 | continue
92 |
93 | return _scandir(dir_path, suffix=suffix, recursive=recursive)
94 |
95 | def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False):
96 | """Scan a directory to find the interested files.
97 |
98 | Args:
99 | dir_path (str): Path of the directory.
100 | keywords (str | tuple(str), optional): File keywords that we are
101 | interested in. Default: None.
102 | recursive (bool, optional): If set to True, recursively scan the
103 | directory. Default: False.
104 | full_path (bool, optional): If set to True, include the dir_path.
105 | Default: False.
106 |
107 | Returns:
108 |         A generator for all the files of interest, with relative paths.
109 | """
110 |
111 | if (keywords is not None) and not isinstance(keywords, (str, tuple)):
112 | raise TypeError('"keywords" must be a string or tuple of strings')
113 |
114 | root = dir_path
115 |
116 | def _scandir(dir_path, keywords, recursive):
117 | for entry in os.scandir(dir_path):
118 | if not entry.name.startswith('.') and entry.is_file():
119 | if full_path:
120 | return_path = entry.path
121 | else:
122 | return_path = osp.relpath(entry.path, root)
123 |
124 | if keywords is None:
125 | yield return_path
126 |                     elif return_path.find(keywords) >= 0:
127 | yield return_path
128 | else:
129 | if recursive:
130 | yield from _scandir(
131 | entry.path, keywords=keywords, recursive=recursive)
132 | else:
133 | continue
134 |
135 | return _scandir(dir_path, keywords=keywords, recursive=recursive)
136 |
137 | def check_resume(opt, resume_iter):
138 | """Check resume states and pretrain_network paths.
139 |
140 | Args:
141 | opt (dict): Options.
142 | resume_iter (int): Resume iteration.
143 | """
144 | logger = get_root_logger()
145 | if opt['path']['resume_state']:
146 | # get all the networks
147 | networks = [key for key in opt.keys() if key.startswith('network_')]
148 | flag_pretrain = False
149 | for network in networks:
150 | if opt['path'].get(f'pretrain_{network}') is not None:
151 | flag_pretrain = True
152 | if flag_pretrain:
153 | logger.warning(
154 | 'pretrain_network path will be ignored during resuming.')
155 | # set pretrained model paths
156 | for network in networks:
157 | name = f'pretrain_{network}'
158 | basename = network.replace('network_', '')
159 | if opt['path'].get('ignore_resume_networks') is None or (
160 | basename not in opt['path']['ignore_resume_networks']):
161 | opt['path'][name] = osp.join(
162 | opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
163 | logger.info(f"Set {name} to {opt['path'][name]}")
164 |
165 |
166 | def sizeof_fmt(size, suffix='B'):
167 | """Get human readable file size.
168 |
169 | Args:
170 | size (int): File size.
171 | suffix (str): Suffix. Default: 'B'.
172 |
173 | Return:
174 |         str: Formatted file size.
175 | """
176 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
177 | if abs(size) < 1024.0:
178 | return f'{size:3.1f} {unit}{suffix}'
179 | size /= 1024.0
180 | return f'{size:3.1f} Y{suffix}'
181 |
--------------------------------------------------------------------------------
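
A brief usage sketch for `scandir` (the directory path is a placeholder): it lazily yields matching files, relative to `dir_path` unless `full_path=True`.

    from basicsr.utils.misc import scandir

    for path in scandir('datasets/example/blur', suffix='.png',
                        recursive=True, full_path=False):
        print(path)  # e.g. 'seq/SEQNAME_00001.png', relative to the root
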
/basicsr/utils/img_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import numpy as np
4 | import os
5 | import torch
6 | from torchvision.utils import make_grid
7 |
8 |
9 | def img2tensor(imgs, bgr2rgb=True, float32=True):
10 | """Numpy array to tensor.
11 |
12 | Args:
13 | imgs (list[ndarray] | ndarray): Input images.
14 | bgr2rgb (bool): Whether to change bgr to rgb.
15 | float32 (bool): Whether to change to float32.
16 |
17 | Returns:
18 | list[tensor] | tensor: Tensor images. If returned results only have
19 | one element, just return tensor.
20 | """
21 |
22 | def _totensor(img, bgr2rgb, float32):
23 | if img.shape[2] == 3 and bgr2rgb:
24 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
25 | img = torch.from_numpy(img.transpose(2, 0, 1))
26 | if float32:
27 | img = img.float()
28 | return img
29 |
30 | if isinstance(imgs, list):
31 | return [_totensor(img, bgr2rgb, float32) for img in imgs]
32 | else:
33 | return _totensor(imgs, bgr2rgb, float32)
34 |
35 |
36 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
37 | """Convert torch Tensors into image numpy arrays.
38 |
39 | After clamping to [min, max], values will be normalized to [0, 1].
40 |
41 | Args:
42 | tensor (Tensor or list[Tensor]): Accept shapes:
43 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
44 | 2) 3D Tensor of shape (3/1 x H x W);
45 | 3) 2D Tensor of shape (H x W).
46 | Tensor channel should be in RGB order.
47 | rgb2bgr (bool): Whether to change rgb to bgr.
48 | out_type (numpy type): output types. If ``np.uint8``, transform outputs
49 | to uint8 type with range [0, 255]; otherwise, float type with
50 | range [0, 1]. Default: ``np.uint8``.
51 | min_max (tuple[int]): min and max values for clamp.
52 |
53 | Returns:
54 |         (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D ndarray of
55 | shape (H x W). The channel order is BGR.
56 | """
57 | if not (torch.is_tensor(tensor) or
58 | (isinstance(tensor, list)
59 | and all(torch.is_tensor(t) for t in tensor))):
60 | raise TypeError(
61 | f'tensor or list of tensors expected, got {type(tensor)}')
62 |
63 | if torch.is_tensor(tensor):
64 | tensor = [tensor]
65 | result = []
66 | for _tensor in tensor:
67 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
68 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
69 |
70 | n_dim = _tensor.dim()
71 | if n_dim == 4:
72 | img_np = make_grid(
73 | _tensor, nrow=int(math.sqrt(_tensor.size(0))),
74 | normalize=False).numpy()
75 | img_np = img_np.transpose(1, 2, 0)
76 | if rgb2bgr:
77 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
78 | elif n_dim == 3:
79 | img_np = _tensor.numpy()
80 | img_np = img_np.transpose(1, 2, 0)
81 | if img_np.shape[2] == 1: # gray image
82 | img_np = np.squeeze(img_np, axis=2)
83 | else:
84 | if rgb2bgr:
85 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
86 | elif n_dim == 2:
87 | img_np = _tensor.numpy()
88 | else:
89 | raise TypeError('Only support 4D, 3D or 2D tensor. '
90 | f'But received with dimension: {n_dim}')
91 | if out_type == np.uint8:
92 |         # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
93 | img_np = (img_np * 255.0).round()
94 | img_np = img_np.astype(out_type)
95 | result.append(img_np)
96 | if len(result) == 1:
97 | result = result[0]
98 | return result
99 |
100 |
101 | def imfrombytes(content, flag='color', float32=False):
102 | """Read an image from bytes.
103 |
104 | Args:
105 | content (bytes): Image bytes got from files or other streams.
106 | flag (str): Flags specifying the color type of a loaded image,
107 | candidates are `color`, `grayscale` and `unchanged`.
108 |         float32 (bool): Whether to change to float32. If True, will also
109 |             normalize to [0, 1]. Default: False.
110 |
111 | Returns:
112 | ndarray: Loaded image array.
113 | """
114 | img_np = np.frombuffer(content, np.uint8)
115 | imread_flags = {
116 | 'color': cv2.IMREAD_COLOR,
117 | 'grayscale': cv2.IMREAD_GRAYSCALE,
118 | 'unchanged': cv2.IMREAD_UNCHANGED
119 | }
120 |     img = cv2.imdecode(img_np, imread_flags[flag])
121 |     if img is None:
122 |         raise ValueError(f'Failed to decode image bytes (flag: {flag}).')
123 | if float32:
124 | img = img.astype(np.float32) / 255.
125 | return img
126 |
127 | def padding(img_lq, img_gt, gt_size):
128 | h, w, _ = img_lq.shape
129 |
130 | h_pad = max(0, gt_size - h)
131 | w_pad = max(0, gt_size - w)
132 |
133 | if h_pad == 0 and w_pad == 0:
134 | return img_lq, img_gt
135 |
136 | img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
137 | img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
138 | # print('img_lq', img_lq.shape, img_gt.shape)
139 | return img_lq, img_gt
140 |
141 | def imwrite(img, file_path, params=None, auto_mkdir=True):
142 | """Write image to file.
143 |
144 | Args:
145 | img (ndarray): Image array to be written.
146 | file_path (str): Image file path.
147 | params (None or list): Same as opencv's :func:`imwrite` interface.
148 | auto_mkdir (bool): If the parent folder of `file_path` does not exist,
149 | whether to create it automatically.
150 |
151 | Returns:
152 | bool: Successful or not.
153 | """
154 | if auto_mkdir:
155 | dir_name = os.path.abspath(os.path.dirname(file_path))
156 | os.makedirs(dir_name, exist_ok=True)
157 | return cv2.imwrite(file_path, img, params)
158 |
159 |
160 | def crop_border(imgs, crop_border):
161 | """Crop borders of images.
162 |
163 | Args:
164 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
165 |         crop_border (int): Crop border for each end of height and width.
166 |
167 | Returns:
168 | list[ndarray]: Cropped images.
169 | """
170 | if crop_border == 0:
171 | return imgs
172 | else:
173 | if isinstance(imgs, list):
174 | return [
175 | v[crop_border:-crop_border, crop_border:-crop_border, ...]
176 | for v in imgs
177 | ]
178 | else:
179 | return imgs[crop_border:-crop_border, crop_border:-crop_border,
180 | ...]
181 |
--------------------------------------------------------------------------------
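
A minimal round-trip sketch for `img2tensor` and `tensor2img` on a random array standing in for a decoded BGR image:

    import numpy as np
    from basicsr.utils.img_util import img2tensor, tensor2img

    img = np.random.rand(8, 8, 3).astype(np.float32)  # HWC, [0, 1], BGR
    t = img2tensor(img, bgr2rgb=True, float32=True)    # CHW float tensor, RGB
    out = tensor2img(t, rgb2bgr=True)                  # HWC uint8, BGR again
    print(t.shape, out.shape, out.dtype)               # [3, 8, 8] (8, 8, 3) uint8
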
/basicsr/utils/logger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import time
4 |
5 | from .dist_util import get_dist_info, master_only
6 |
7 |
8 | class MessageLogger():
9 | """Message logger for printing.
10 |
11 | Args:
12 | opt (dict): Config. It contains the following keys:
13 | name (str): Exp name.
14 |             logger (dict): Contains 'print_freq' (int) for logger interval.
15 | train (dict): Contains 'total_iter' (int) for total iters.
16 | use_tb_logger (bool): Use tensorboard logger.
17 | start_iter (int): Start iter. Default: 1.
18 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
19 | """
20 |
21 | def __init__(self, opt, start_iter=1, tb_logger=None):
22 | self.exp_name = opt['name']
23 | self.interval = opt['logger']['print_freq']
24 | self.start_iter = start_iter
25 | self.max_iters = opt['train']['total_iter']
26 | self.use_tb_logger = opt['logger']['use_tb_logger']
27 | self.tb_logger = tb_logger
28 | self.start_time = time.time()
29 | self.logger = get_root_logger()
30 |
31 | @master_only
32 | def __call__(self, log_vars):
33 | """Format logging message.
34 |
35 | Args:
36 | log_vars (dict): It contains the following keys:
37 | epoch (int): Epoch number.
38 | iter (int): Current iter.
39 | lrs (list): List for learning rates.
40 |
41 | time (float): Iter time.
42 | data_time (float): Data time for each iter.
43 | """
44 | # epoch, iter, learning rates
45 | epoch = log_vars.pop('epoch')
46 | current_iter = log_vars.pop('iter')
47 | lrs = log_vars.pop('lrs')
48 |
49 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '
50 | f'iter:{current_iter:8,d}, lr:(')
51 | for v in lrs:
52 | message += f'{v:.3e},'
53 | message += ')] '
54 |
55 | # time and estimated time
56 | if 'time' in log_vars.keys():
57 | iter_time = log_vars.pop('time')
58 | data_time = log_vars.pop('data_time')
59 |
60 | total_time = time.time() - self.start_time
61 | time_sec_avg = total_time / (current_iter - self.start_iter + 1)
62 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
63 | eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
64 | message += f'[eta: {eta_str}, '
65 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
66 |
67 | # other items, especially losses
68 | for k, v in log_vars.items():
69 | message += f'{k}: {v:.4e} '
70 | # tensorboard logger
71 | if self.use_tb_logger and 'debug' not in self.exp_name:
72 | if k.startswith('l_'):
73 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
74 | else:
75 | self.tb_logger.add_scalar(k, v, current_iter)
76 | self.logger.info(message)
77 |
78 |
79 | @master_only
80 | def init_tb_logger(log_dir):
81 | from torch.utils.tensorboard import SummaryWriter
82 | tb_logger = SummaryWriter(log_dir=log_dir)
83 | return tb_logger
84 |
85 |
86 | @master_only
87 | def init_wandb_logger(opt):
88 | """We now only use wandb to sync tensorboard log."""
89 | import wandb
90 | logger = logging.getLogger('basicsr')
91 |
92 | project = opt['logger']['wandb']['project']
93 | resume_id = opt['logger']['wandb'].get('resume_id')
94 | if resume_id:
95 | wandb_id = resume_id
96 | resume = 'allow'
97 | logger.warning(f'Resume wandb logger with id={wandb_id}.')
98 | else:
99 | wandb_id = wandb.util.generate_id()
100 | resume = 'never'
101 |
102 | wandb.init(
103 | id=wandb_id,
104 | resume=resume,
105 | name=opt['name'],
106 | config=opt,
107 | project=project,
108 | sync_tensorboard=True)
109 |
110 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
111 |
112 |
113 | def get_root_logger(logger_name='basicsr',
114 | log_level=logging.INFO,
115 | log_file=None):
116 | """Get the root logger.
117 |
118 | The logger will be initialized if it has not been initialized. By default a
119 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
120 | also be added.
121 |
122 | Args:
123 | logger_name (str): root logger name. Default: 'basicsr'.
124 | log_file (str | None): The log filename. If specified, a FileHandler
125 | will be added to the root logger.
126 | log_level (int): The root logger level. Note that only the process of
127 | rank 0 is affected, while other processes will set the level to
128 | "Error" and be silent most of the time.
129 |
130 | Returns:
131 | logging.Logger: The root logger.
132 | """
133 | logger = logging.getLogger(logger_name)
134 | # if the logger has been initialized, just return it
135 | if logger.hasHandlers():
136 | return logger
137 |
138 | format_str = '%(asctime)s %(levelname)s: %(message)s'
139 | logging.basicConfig(format=format_str, level=log_level)
140 | rank, _ = get_dist_info()
141 | if rank != 0:
142 | logger.setLevel('ERROR')
143 | elif log_file is not None:
144 | file_handler = logging.FileHandler(log_file, 'w')
145 | file_handler.setFormatter(logging.Formatter(format_str))
146 | file_handler.setLevel(log_level)
147 | logger.addHandler(file_handler)
148 |
149 | return logger
150 |
151 |
152 | def get_env_info():
153 | """Get environment information.
154 |
155 | Currently, only log the software version.
156 | """
157 | import torch
158 | import torchvision
159 |
160 | from basicsr.version import __version__
161 | msg = r"""
162 | ____ _ _____ ____
163 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \
164 | / __ |/ __ `// ___// // ___/\__ \ / /_/ /
165 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
166 | /_____/ \__,_//____//_/ \___//____//_/ |_|
167 | ______ __ __ __ __
168 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
169 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
170 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
171 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
172 | """
173 | msg += ('\nVersion Information: '
174 | f'\n\tBasicSR: {__version__}'
175 | f'\n\tPyTorch: {torch.__version__}'
176 | f'\n\tTorchVision: {torchvision.__version__}')
177 | return msg
178 |
--------------------------------------------------------------------------------
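
A short setup sketch (the log filename is a placeholder); because the root logger is cached, later `get_root_logger()` calls return the same configured instance:

    from basicsr.utils.logger import get_env_info, get_root_logger

    logger = get_root_logger(log_file='train.log')
    logger.info(get_env_info())
    logger.info('Training starts.')
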
/basicsr/data/npz_image_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils import data as data
2 | from torchvision.transforms.functional import normalize
3 | from tqdm import tqdm
4 | import os
5 | from pathlib import Path
6 | import random
7 | import numpy as np
8 | import torch
9 |
10 | from basicsr.data.data_util import (paired_paths_from_folder,
11 | paired_paths_from_lmdb,
12 | paired_paths_from_meta_info_file,
13 | recursive_glob)
14 | from basicsr.data.event_util import events_to_voxel_grid, voxel_norm
15 | from basicsr.data.transforms import augment, triple_random_crop, random_augmentation
16 | from basicsr.utils import FileClient, imfrombytes, img2tensor, padding, get_root_logger
17 | from torch.utils.data.dataloader import default_collate
18 |
19 |
20 | class NpzPngSingleDeblurDataset(data.Dataset):
21 | """Paired npz and png dataset for event-based single image deblurring.
22 | --HighREV
23 | |----train
24 | | |----blur
25 | | | |----SEQNAME_%5d.png
26 | | | |----...
27 | | |----event
28 | | | |----SEQNAME_%5d_%2d.npz
29 | | | |----...
30 | | |----sharp
31 | | | |----SEQNAME_%5d.png
32 | | | |----...
33 | |----val
34 | ...
35 |
36 |
37 | Args:
38 | opt (dict): Config for train dataset. It contains the following keys:
39 | dataroot (str): Data root path.
40 | io_backend (dict): IO backend type and other kwarg.
41 | num_end_interpolation (int): Number of sharp frames to reconstruct in each blurry image.
42 | num_inter_interpolation (int): Number of sharp frames to interpolate between two blurry images.
43 | phase (str): 'train' or 'test'
44 |
45 | gt_size (int): Cropped patched size for gt patches.
46 | random_reverse (bool): Random reverse input frames.
47 | use_hflip (bool): Use horizontal flips.
48 | use_rot (bool): Use rotation (use vertical flip and transposing h
49 | and w for implementation).
50 |
51 | scale (bool): Scale, which will be added automatically.
52 | """
53 |
54 | def __init__(self, opt):
55 | super(NpzPngSingleDeblurDataset, self).__init__()
56 | self.opt = opt
57 | self.dataroot = Path(opt['dataroot'])
58 | self.split = 'train' if opt['phase'] == 'train' else 'val' # train or val
59 | self.voxel_bins = opt['voxel_bins']
60 | self.norm_voxel = opt['norm_voxel']
61 |
62 | self.dataPath = []
63 |
64 | blur_frames = sorted(recursive_glob(rootdir=os.path.join(self.dataroot, 'blur'), suffix='.png'))
65 | blur_frames = [os.path.join(self.dataroot, 'blur', blur_frame) for blur_frame in blur_frames]
66 |
67 | sharp_frames = sorted(recursive_glob(rootdir=os.path.join(self.dataroot, 'sharp'), suffix='.png'))
68 | sharp_frames = [os.path.join(self.dataroot, 'sharp', sharp_frame) for sharp_frame in sharp_frames]
69 |
70 | event_frames = sorted(recursive_glob(rootdir=os.path.join(self.dataroot, 'event'), suffix='.npz'))
71 | event_frames = [os.path.join(self.dataroot, 'event', event_frame) for event_frame in event_frames]
72 |
73 | assert len(blur_frames) == len(sharp_frames), f"Mismatch in blur ({len(blur_frames)}) and sharp ({len(sharp_frames)}) frame counts."
74 |
75 | for i in range(len(blur_frames)):
76 | blur_name = os.path.basename(blur_frames[i]) # e.g., SEQNAME_00001.png
77 | base_name = os.path.splitext(blur_name)[0] # Remove .png, get SEQNAME_00001
78 | event_list = sorted([f for f in event_frames if f.startswith(os.path.join(self.dataroot, 'event', base_name + '_'))])[:-1] # remove the last one because it is out of the exposure time
79 |
80 | self.dataPath.append({
81 | 'blur_path': blur_frames[i],
82 | 'sharp_path': sharp_frames[i],
83 | 'event_paths': event_list
84 | })
85 | logger = get_root_logger()
86 | logger.info(f"Dataset initialized with {len(self.dataPath)} samples.")
87 |
88 | # file client (io backend)
89 | self.file_client = None
90 | self.io_backend_opt = opt['io_backend']
91 |
92 |
93 | def __getitem__(self, index):
94 | if self.file_client is None:
95 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
96 | scale = self.opt['scale']
97 | gt_size = self.opt['gt_size']
98 |
99 | image_path = self.dataPath[index]['blur_path']
100 | gt_path = self.dataPath[index]['sharp_path']
101 | event_paths = self.dataPath[index]['event_paths']
102 |
103 | # get LQ
104 | img_bytes = self.file_client.get(image_path) # 'lq'
105 | img_lq = imfrombytes(img_bytes, float32=True)
106 | # get GT
107 | img_bytes = self.file_client.get(gt_path) # 'gt'
108 | img_gt = imfrombytes(img_bytes, float32=True)
109 |
110 | h_lq, w_lq, _ = img_lq.shape
111 |
112 | ## Read event and convert to voxel grid:
113 | events = [np.load(event_path) for event_path in event_paths]
114 | # npz -> ndarray
115 | all_quad_event_array = np.zeros((0,4)).astype(np.float32)
116 | for event in events:
117 |             ### IMPORTANT: the dataset stores x and y swapped,
118 |             ### so they are switched back here
119 | y = event['x'].astype(np.float32)
120 | x = event['y'].astype(np.float32)
121 | t = event['timestamp'].astype(np.float32)
122 | p = event['polarity'].astype(np.float32)
123 |
124 | this_quad_event_array = np.concatenate((t,x,y,p),axis=1) # N,4
125 | all_quad_event_array = np.concatenate((all_quad_event_array, this_quad_event_array), axis=0)
126 | voxel = events_to_voxel_grid(all_quad_event_array, num_bins=self.voxel_bins, width=w_lq, height=h_lq, return_format='HWC')
127 |
128 | ## Data augmentation
129 | # voxel shape: h,w,c
130 | # crop
131 | if gt_size is not None:
132 | img_gt, img_lq, voxel = triple_random_crop(img_gt, img_lq, voxel, gt_size, scale, gt_path)
133 |
134 | # flip, rotate
135 | total_input = [img_lq, img_gt, voxel]
136 | img_results = augment(total_input, self.opt['use_hflip'], self.opt['use_rot'])
137 |
138 | img_results = img2tensor(img_results) # hwc -> chw
139 | img_lq, img_gt, voxel = img_results
140 |
141 | ## Norm voxel
142 | if self.norm_voxel:
143 | voxel = voxel_norm(voxel)
144 |
145 | origin_index = os.path.basename(image_path).split('.')[0]
146 |
147 | return {'frame': img_lq, 'frame_gt': img_gt, 'voxel': voxel, 'image_name': origin_index}
148 |
149 | def __len__(self):
150 | return len(self.dataPath)
151 |
--------------------------------------------------------------------------------
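
A hedged configuration sketch for `NpzPngSingleDeblurDataset`; the keys mirror those read in `__init__`/`__getitem__`, while the path and values are placeholders:

    from basicsr.data.npz_image_dataset import NpzPngSingleDeblurDataset

    opt = {
        'dataroot': '/path/to/HighREV/train',  # placeholder
        'phase': 'train',
        'voxel_bins': 6,
        'norm_voxel': True,
        'io_backend': {'type': 'disk'},
        'scale': 1,
        'gt_size': 256,
        'use_hflip': True,
        'use_rot': True,
    }
    dataset = NpzPngSingleDeblurDataset(opt)
    sample = dataset[0]  # dict: 'frame', 'frame_gt', 'voxel', 'image_name'
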
/basicsr/models/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from torch.optim.lr_scheduler import _LRScheduler
4 |
5 |
6 | class MultiStepRestartLR(_LRScheduler):
7 | """ MultiStep with restarts learning rate scheme.
8 |
9 | Args:
10 | optimizer (torch.nn.optimizer): Torch optimizer.
11 | milestones (list): Iterations that will decrease learning rate.
12 | gamma (float): Decrease ratio. Default: 0.1.
13 | restarts (list): Restart iterations. Default: [0].
14 | restart_weights (list): Restart weights at each restart iteration.
15 | Default: [1].
16 | last_epoch (int): Used in _LRScheduler. Default: -1.
17 | """
18 |
19 | def __init__(self,
20 | optimizer,
21 | milestones,
22 | gamma=0.1,
23 | restarts=(0, ),
24 | restart_weights=(1, ),
25 | last_epoch=-1):
26 | self.milestones = Counter(milestones)
27 | self.gamma = gamma
28 | self.restarts = restarts
29 | self.restart_weights = restart_weights
30 | assert len(self.restarts) == len(
31 | self.restart_weights), 'restarts and their weights do not match.'
32 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
33 |
34 | def get_lr(self):
35 | if self.last_epoch in self.restarts:
36 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
37 | return [
38 | group['initial_lr'] * weight
39 | for group in self.optimizer.param_groups
40 | ]
41 | if self.last_epoch not in self.milestones:
42 | return [group['lr'] for group in self.optimizer.param_groups]
43 | return [
44 | group['lr'] * self.gamma**self.milestones[self.last_epoch]
45 | for group in self.optimizer.param_groups
46 | ]
47 |
48 | class LinearLR(_LRScheduler):
49 | """
50 |
51 | Args:
52 | optimizer (torch.nn.optimizer): Torch optimizer.
53 | milestones (list): Iterations that will decrease learning rate.
54 | gamma (float): Decrease ratio. Default: 0.1.
55 | last_epoch (int): Used in _LRScheduler. Default: -1.
56 | """
57 |
58 | def __init__(self,
59 | optimizer,
60 | total_iter,
61 | last_epoch=-1):
62 | self.total_iter = total_iter
63 | super(LinearLR, self).__init__(optimizer, last_epoch)
64 |
65 | def get_lr(self):
66 | process = self.last_epoch / self.total_iter
67 | weight = (1 - process)
68 | # print('get lr ', [weight * group['initial_lr'] for group in self.optimizer.param_groups])
69 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
70 |
71 | class VibrateLR(_LRScheduler):
72 | """
73 |
74 | Args:
75 | optimizer (torch.nn.optimizer): Torch optimizer.
76 | milestones (list): Iterations that will decrease learning rate.
77 | gamma (float): Decrease ratio. Default: 0.1.
78 | last_epoch (int): Used in _LRScheduler. Default: -1.
79 | """
80 |
81 | def __init__(self,
82 | optimizer,
83 | total_iter,
84 | last_epoch=-1):
85 | self.total_iter = total_iter
86 | super(VibrateLR, self).__init__(optimizer, last_epoch)
87 |
88 | def get_lr(self):
89 | process = self.last_epoch / self.total_iter
90 |
91 | f = 0.1
92 | if process < 3 / 8:
93 | f = 1 - process * 8 / 3
94 | elif process < 5 / 8:
95 | f = 0.2
96 |
97 | T = self.total_iter // 80
98 | Th = T // 2
99 |
100 | t = self.last_epoch % T
101 |
102 | f2 = t / Th
103 | if t >= Th:
104 | f2 = 2 - f2
105 |
106 | weight = f * f2
107 |
108 | if self.last_epoch < Th:
109 | weight = max(0.1, weight)
110 |
111 | # print('f {}, T {}, Th {}, t {}, f2 {}'.format(f, T, Th, t, f2))
112 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
113 |
114 | def get_position_from_periods(iteration, cumulative_period):
115 | """Get the position from a period list.
116 |
117 | It will return the index of the right-closest number in the period list.
118 | For example, the cumulative_period = [100, 200, 300, 400],
119 | if iteration == 50, return 0;
120 | if iteration == 210, return 2;
121 | if iteration == 300, return 2.
122 |
123 | Args:
124 | iteration (int): Current iteration.
125 | cumulative_period (list[int]): Cumulative period list.
126 |
127 | Returns:
128 | int: The position of the right-closest number in the period list.
129 | """
130 | for i, period in enumerate(cumulative_period):
131 | if iteration <= period:
132 | return i
133 |
134 |
135 | class CosineAnnealingRestartLR(_LRScheduler):
136 | """ Cosine annealing with restarts learning rate scheme.
137 |
138 | An example of config:
139 | periods = [10, 10, 10, 10]
140 | restart_weights = [1, 0.5, 0.5, 0.5]
141 | eta_min=1e-7
142 |
143 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
144 | scheduler will restart with the weights in restart_weights.
145 |
146 | Args:
147 | optimizer (torch.nn.optimizer): Torch optimizer.
148 |         periods (list): Period for each cosine annealing cycle.
149 | restart_weights (list): Restart weights at each restart iteration.
150 | Default: [1].
151 |         eta_min (float): The minimum lr. Default: 0.
152 | last_epoch (int): Used in _LRScheduler. Default: -1.
153 | """
154 |
155 | def __init__(self,
156 | optimizer,
157 | periods,
158 | restart_weights=(1, ),
159 | eta_min=0,
160 | last_epoch=-1):
161 | self.periods = periods
162 | self.restart_weights = restart_weights
163 | self.eta_min = eta_min
164 | assert (len(self.periods) == len(self.restart_weights)
165 | ), 'periods and restart_weights should have the same length.'
166 | self.cumulative_period = [
167 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
168 | ]
169 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
170 |
171 | def get_lr(self):
172 | idx = get_position_from_periods(self.last_epoch,
173 | self.cumulative_period)
174 | current_weight = self.restart_weights[idx]
175 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
176 | current_period = self.periods[idx]
177 |
178 | return [
179 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
180 | (1 + math.cos(math.pi * (
181 | (self.last_epoch - nearest_restart) / current_period)))
182 | for base_lr in self.base_lrs
183 | ]
184 |
--------------------------------------------------------------------------------
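
A runnable sketch of the cosine-restart schedule using the example config from its docstring; the tiny model and optimizer are placeholders:

    import torch
    from basicsr.models.lr_scheduler import CosineAnnealingRestartLR

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    scheduler = CosineAnnealingRestartLR(
        optimizer, periods=[10, 10, 10, 10],
        restart_weights=[1, 0.5, 0.5, 0.5], eta_min=1e-7)

    for _ in range(40):
        optimizer.step()
        scheduler.step()  # lr restarts (at reduced weight) at iters 10, 20, 30
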
/basicsr/utils/flow_util.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
2 | import cv2
3 | import numpy as np
4 | import os
5 |
6 |
7 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
8 | """Read an optical flow map.
9 |
10 | Args:
11 | flow_path (ndarray or str): Flow path.
12 | quantize (bool): whether to read quantized pair, if set to True,
13 | remaining args will be passed to :func:`dequantize_flow`.
14 | concat_axis (int): The axis that dx and dy are concatenated,
15 | can be either 0 or 1. Ignored if quantize is False.
16 |
17 | Returns:
18 | ndarray: Optical flow represented as a (h, w, 2) numpy array
19 | """
20 | if quantize:
21 | assert concat_axis in [0, 1]
22 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
23 | if cat_flow.ndim != 2:
24 | raise IOError(f'{flow_path} is not a valid quantized flow file, '
25 | f'its dimension is {cat_flow.ndim}.')
26 | assert cat_flow.shape[concat_axis] % 2 == 0
27 | dx, dy = np.split(cat_flow, 2, axis=concat_axis)
28 | flow = dequantize_flow(dx, dy, *args, **kwargs)
29 | else:
30 | with open(flow_path, 'rb') as f:
31 | try:
32 | header = f.read(4).decode('utf-8')
33 | except Exception:
34 | raise IOError(f'Invalid flow file: {flow_path}')
35 | else:
36 | if header != 'PIEH':
37 | raise IOError(f'Invalid flow file: {flow_path}, '
38 | 'header does not contain PIEH')
39 |
40 | w = np.fromfile(f, np.int32, 1).squeeze()
41 | h = np.fromfile(f, np.int32, 1).squeeze()
42 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
43 |
44 | return flow.astype(np.float32)
45 |
46 |
47 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
48 | """Write optical flow to file.
49 |
50 | If the flow is not quantized, it will be saved as a .flo file losslessly,
51 | otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
52 | will be concatenated horizontally into a single image if quantize is True.)
53 |
54 | Args:
55 | flow (ndarray): (h, w, 2) array of optical flow.
56 | filename (str): Output filepath.
57 |         quantize (bool): Whether to quantize the flow and save it as a
58 |             single jpeg image. If set to True, remaining args will be
59 |             passed to :func:`quantize_flow`.
60 | concat_axis (int): The axis that dx and dy are concatenated,
61 | can be either 0 or 1. Ignored if quantize is False.
62 | """
63 | if not quantize:
64 | with open(filename, 'wb') as f:
65 | f.write('PIEH'.encode('utf-8'))
66 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
67 | flow = flow.astype(np.float32)
68 | flow.tofile(f)
69 | f.flush()
70 | else:
71 | assert concat_axis in [0, 1]
72 | dx, dy = quantize_flow(flow, *args, **kwargs)
73 | dxdy = np.concatenate((dx, dy), axis=concat_axis)
74 |         os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)
75 |         cv2.imwrite(filename, dxdy)
76 |
77 |
78 | def quantize_flow(flow, max_val=0.02, norm=True):
79 | """Quantize flow to [0, 255].
80 |
81 | After this step, the size of flow will be much smaller, and can be
82 | dumped as jpeg images.
83 |
84 | Args:
85 | flow (ndarray): (h, w, 2) array of optical flow.
86 | max_val (float): Maximum value of flow, values beyond
87 | [-max_val, max_val] will be truncated.
88 | norm (bool): Whether to divide flow values by image width/height.
89 |
90 | Returns:
91 | tuple[ndarray]: Quantized dx and dy.
92 | """
93 | h, w, _ = flow.shape
94 | dx = flow[..., 0]
95 | dy = flow[..., 1]
96 | if norm:
97 | dx = dx / w # avoid inplace operations
98 | dy = dy / h
99 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
100 | flow_comps = [
101 | quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
102 | ]
103 | return tuple(flow_comps)
104 |
105 |
106 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
107 | """Recover from quantized flow.
108 |
109 | Args:
110 | dx (ndarray): Quantized dx.
111 | dy (ndarray): Quantized dy.
112 | max_val (float): Maximum value used when quantizing.
113 | denorm (bool): Whether to multiply flow values with width/height.
114 |
115 | Returns:
116 | ndarray: Dequantized flow.
117 | """
118 | assert dx.shape == dy.shape
119 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
120 |
121 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
122 |
123 | if denorm:
124 | dx *= dx.shape[1]
125 |         dy *= dy.shape[0]
126 | flow = np.dstack((dx, dy))
127 | return flow
128 |
129 |
130 | def quantize(arr, min_val, max_val, levels, dtype=np.int64):
131 | """Quantize an array of (-inf, inf) to [0, levels-1].
132 |
133 | Args:
134 | arr (ndarray): Input array.
135 | min_val (scalar): Minimum value to be clipped.
136 | max_val (scalar): Maximum value to be clipped.
137 | levels (int): Quantization levels.
138 | dtype (np.type): The type of the quantized array.
139 |
140 | Returns:
141 |         ndarray: Quantized array.
142 | """
143 | if not (isinstance(levels, int) and levels > 1):
144 | raise ValueError(
145 | f'levels must be a positive integer, but got {levels}')
146 | if min_val >= max_val:
147 | raise ValueError(
148 | f'min_val ({min_val}) must be smaller than max_val ({max_val})')
149 |
150 | arr = np.clip(arr, min_val, max_val) - min_val
151 | quantized_arr = np.minimum(
152 | np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
153 |
154 | return quantized_arr
155 |
156 |
157 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
158 | """Dequantize an array.
159 |
160 | Args:
161 | arr (ndarray): Input array.
162 | min_val (scalar): Minimum value to be clipped.
163 | max_val (scalar): Maximum value to be clipped.
164 | levels (int): Quantization levels.
165 | dtype (np.type): The type of the dequantized array.
166 |
167 | Returns:
168 |         ndarray: Dequantized array.
169 | """
170 | if not (isinstance(levels, int) and levels > 1):
171 | raise ValueError(
172 | f'levels must be a positive integer, but got {levels}')
173 | if min_val >= max_val:
174 | raise ValueError(
175 | f'min_val ({min_val}) must be smaller than max_val ({max_val})')
176 |
177 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
178 | min_val) / levels + min_val
179 |
180 | return dequantized_arr
181 |
--------------------------------------------------------------------------------
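
A quick round-trip sketch for the quantization helpers on synthetic flow, with `norm`/`denorm` disabled so values compare directly; the reconstruction error is bounded by roughly one quantization level (~2 * max_val / 255):

    import numpy as np
    from basicsr.utils.flow_util import dequantize_flow, quantize_flow

    flow = np.random.uniform(-0.01, 0.01, size=(4, 6, 2)).astype(np.float32)
    dx, dy = quantize_flow(flow, max_val=0.02, norm=False)    # two uint8 maps
    rec = dequantize_flow(dx, dy, max_val=0.02, denorm=False)
    print(np.abs(rec - flow).max())  # small, on the order of 0.04 / 255
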
/basicsr/utils/file_client.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
2 | from abc import ABCMeta, abstractmethod
3 |
4 |
5 | class BaseStorageBackend(metaclass=ABCMeta):
6 | """Abstract class of storage backends.
7 |
8 | All backends need to implement two apis: ``get()`` and ``get_text()``.
9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
10 | as texts.
11 | """
12 |
13 | @abstractmethod
14 | def get(self, filepath):
15 | pass
16 |
17 | @abstractmethod
18 | def get_text(self, filepath):
19 | pass
20 |
21 |
22 | class MemcachedBackend(BaseStorageBackend):
23 | """Memcached storage backend.
24 |
25 | Attributes:
26 | server_list_cfg (str): Config file for memcached server list.
27 | client_cfg (str): Config file for memcached client.
28 | sys_path (str | None): Additional path to be appended to `sys.path`.
29 | Default: None.
30 | """
31 |
32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None):
33 | if sys_path is not None:
34 | import sys
35 | sys.path.append(sys_path)
36 | try:
37 | import mc
38 | except ImportError:
39 | raise ImportError(
40 | 'Please install memcached to enable MemcachedBackend.')
41 |
42 | self.server_list_cfg = server_list_cfg
43 | self.client_cfg = client_cfg
44 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
45 | self.client_cfg)
46 |         # mc.pyvector serves as a pointer to a memory cache buffer
47 | self._mc_buffer = mc.pyvector()
48 |
49 | def get(self, filepath):
50 | filepath = str(filepath)
51 | import mc
52 | self._client.Get(filepath, self._mc_buffer)
53 | value_buf = mc.ConvertBuffer(self._mc_buffer)
54 | return value_buf
55 |
56 | def get_text(self, filepath):
57 | raise NotImplementedError
58 |
59 |
60 | class HardDiskBackend(BaseStorageBackend):
61 | """Raw hard disks storage backend."""
62 |
63 | def get(self, filepath):
64 | filepath = str(filepath)
65 | with open(filepath, 'rb') as f:
66 | value_buf = f.read()
67 | return value_buf
68 |
69 | def get_text(self, filepath):
70 | filepath = str(filepath)
71 | with open(filepath, 'r') as f:
72 | value_buf = f.read()
73 | return value_buf
74 |
75 |
76 | class LmdbBackend(BaseStorageBackend):
77 | """Lmdb storage backend.
78 |
79 | Args:
80 | db_paths (str | list[str]): Lmdb database paths.
81 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
82 | readonly (bool, optional): Lmdb environment parameter. If True,
83 | disallow any write operations. Default: True.
84 | lock (bool, optional): Lmdb environment parameter. If False, when
85 | concurrent access occurs, do not lock the database. Default: False.
86 | readahead (bool, optional): Lmdb environment parameter. If False,
87 | disable the OS filesystem readahead mechanism, which may improve
88 | random read performance when a database is larger than RAM.
89 | Default: False.
90 |
91 | Attributes:
92 | db_paths (list): Lmdb database path.
93 | _client (list): A list of several lmdb envs.
94 | """
95 |
96 | def __init__(self,
97 | db_paths,
98 | client_keys='default',
99 | readonly=True,
100 | lock=False,
101 | readahead=False,
102 | **kwargs):
103 | try:
104 | import lmdb
105 | except ImportError:
106 | raise ImportError('Please install lmdb to enable LmdbBackend.')
107 |
108 | if isinstance(client_keys, str):
109 | client_keys = [client_keys]
110 |
111 | if isinstance(db_paths, list):
112 | self.db_paths = [str(v) for v in db_paths]
113 | elif isinstance(db_paths, str):
114 | self.db_paths = [str(db_paths)]
115 | assert len(client_keys) == len(self.db_paths), (
116 | 'client_keys and db_paths should have the same length, '
117 | f'but received {len(client_keys)} and {len(self.db_paths)}.')
118 |
119 | self._client = {}
120 |
121 | for client, path in zip(client_keys, self.db_paths):
122 | self._client[client] = lmdb.open(
123 | path,
124 | readonly=readonly,
125 | lock=lock,
126 | readahead=readahead,
127 | map_size=8*1024*10485760,
128 | # max_readers=1,
129 | **kwargs)
130 |
131 | def get(self, filepath, client_key):
132 | """Get values according to the filepath from one lmdb named client_key.
133 |
134 | Args:
135 | filepath (str | obj:`Path`): Here, filepath is the lmdb key.
136 |             client_key (str): Used for distinguishing different lmdb envs.
137 | """
138 | filepath = str(filepath)
139 | assert client_key in self._client, (f'client_key {client_key} is not '
140 | 'in lmdb clients.')
141 | client = self._client[client_key]
142 | with client.begin(write=False) as txn:
143 | value_buf = txn.get(filepath.encode('ascii'))
144 | return value_buf
145 |
146 | def get_text(self, filepath):
147 | raise NotImplementedError
148 |
149 |
150 | class FileClient(object):
151 | """A general file client to access files in different backend.
152 |
153 | The client loads a file or text in a specified backend from its path
154 | and return it as a binary file. it can also register other backend
155 | accessor with a given name and backend class.
156 |
157 | Attributes:
158 | backend (str): The storage backend type. Options are "disk",
159 | "memcached" and "lmdb".
160 | client (:obj:`BaseStorageBackend`): The backend object.
161 | """
162 |
163 | _backends = {
164 | 'disk': HardDiskBackend,
165 | 'memcached': MemcachedBackend,
166 | 'lmdb': LmdbBackend,
167 | }
168 |
169 | def __init__(self, backend='disk', **kwargs):
170 | if backend not in self._backends:
171 | raise ValueError(
172 | f'Backend {backend} is not supported. Currently supported ones'
173 | f' are {list(self._backends.keys())}')
174 | self.backend = backend
175 | self.client = self._backends[backend](**kwargs)
176 |
177 | def get(self, filepath, client_key='default'):
178 | # client_key is used only for lmdb, where different fileclients have
179 | # different lmdb environments.
180 | if self.backend == 'lmdb':
181 | return self.client.get(filepath, client_key)
182 | else:
183 | return self.client.get(filepath)
184 |
185 | def get_text(self, filepath):
186 | return self.client.get_text(filepath)
187 |
--------------------------------------------------------------------------------
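
A short sketch of the disk backend combined with `imfrombytes` (re-exported from `basicsr.utils`); the image path is a placeholder:

    from basicsr.utils import FileClient, imfrombytes

    client = FileClient('disk')
    img_bytes = client.get('datasets/example/blur_00001.png')  # placeholder
    img = imfrombytes(img_bytes, flag='color', float32=True)   # HWC BGR, [0, 1]
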
/basicsr/models/losses/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 | import numpy as np
5 |
6 | from basicsr.models.losses.loss_util import weighted_loss
7 |
8 | _reduction_modes = ['none', 'mean', 'sum']
9 |
10 |
11 | @weighted_loss
12 | def l1_loss(pred, target):
13 | return F.l1_loss(pred, target, reduction='none')
14 |
15 |
16 | @weighted_loss
17 | def mse_loss(pred, target):
18 | return F.mse_loss(pred, target, reduction='none')
19 |
20 |
21 | # AT loss
22 | def at(x):
23 | return F.normalize(x.pow(2).mean(1).view(x.size(0), -1))
24 |
25 | def at_loss(x, y):
26 | return (at(x) - at(y)).pow(2).mean()
27 |
28 | @weighted_loss
29 | def charbonnier_loss(pred, target, eps=1e-12):
30 | return torch.sqrt((pred - target)**2 + eps)
31 |
36 |
37 | class L1Loss(nn.Module):
38 | """L1 (mean absolute error, MAE) loss.
39 |
40 | Args:
41 | loss_weight (float): Loss weight for L1 loss. Default: 1.0.
42 | reduction (str): Specifies the reduction to apply to the output.
43 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
44 | """
45 |
46 | def __init__(self, loss_weight=1.0, reduction='mean'):
47 | super(L1Loss, self).__init__()
48 | if reduction not in ['none', 'mean', 'sum']:
49 | raise ValueError(f'Unsupported reduction mode: {reduction}. '
50 | f'Supported ones are: {_reduction_modes}')
51 |
52 | self.loss_weight = loss_weight
53 | self.reduction = reduction
54 |
55 | def forward(self, pred, target, weight=None, **kwargs):
56 | """
57 | Args:
58 | pred (Tensor): of shape (N, C, H, W). Predicted tensor.
59 | target (Tensor): of shape (N, C, H, W). Ground truth tensor.
60 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise
61 | weights. Default: None.
62 | """
63 | return self.loss_weight * l1_loss(
64 | pred, target, weight, reduction=self.reduction)
65 |
66 | class MSELoss(nn.Module):
67 | """MSE (L2) loss.
68 |
69 | Args:
70 | loss_weight (float): Loss weight for MSE loss. Default: 1.0.
71 | reduction (str): Specifies the reduction to apply to the output.
72 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
73 | """
74 |
75 | def __init__(self, loss_weight=1.0, reduction='mean'):
76 | super(MSELoss, self).__init__()
77 | if reduction not in ['none', 'mean', 'sum']:
78 | raise ValueError(f'Unsupported reduction mode: {reduction}. '
79 | f'Supported ones are: {_reduction_modes}')
80 |
81 | self.loss_weight = loss_weight
82 | self.reduction = reduction
83 |
84 | def forward(self, pred, target, weight=None, **kwargs):
85 | """
86 | Args:
87 | pred (Tensor): of shape (N, C, H, W). Predicted tensor.
88 | target (Tensor): of shape (N, C, H, W). Ground truth tensor.
89 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise
90 | weights. Default: None.
91 | """
92 | return self.loss_weight * mse_loss(
93 | pred, target, weight, reduction=self.reduction)
94 |
95 | class PSNRLoss(nn.Module):
96 |
97 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
98 | super(PSNRLoss, self).__init__()
99 | assert reduction == 'mean'
100 | self.loss_weight = loss_weight
101 | self.scale = 10 / np.log(10)
102 | self.toY = toY
103 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
104 | self.first = True
105 |
106 | def forward(self, pred, target):
107 | assert len(pred.size()) == 4
108 | if self.toY:
109 | if self.first:
110 | self.coef = self.coef.to(pred.device)
111 | self.first = False
112 |
113 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
114 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
115 |
116 | pred, target = pred / 255., target / 255.
117 |
118 | assert len(pred.size()) == 4
119 |
120 | return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
121 |
122 |
123 | class SRNLoss(nn.Module):
124 |
125 | def __init__(self):
126 | super(SRNLoss, self).__init__()
127 |
128 | def forward(self, preds, target):
129 |
130 | gt1 = target
131 | B, C, H, W = gt1.shape
132 | gt2 = F.interpolate(gt1, size=(H // 2, W // 2), mode='bilinear', align_corners=False)
133 | gt3 = F.interpolate(gt1, size=(H // 4, W // 4), mode='bilinear', align_corners=False)
134 |
135 | l1 = mse_loss(preds[0], gt3)
136 | l2 = mse_loss(preds[1], gt2)
137 | l3 = mse_loss(preds[2], gt1)
138 |
139 | return l1 + l2 + l3
140 |
141 |
142 |
143 | class CharbonnierLoss(nn.Module):
144 | """Charbonnier loss (one variant of Robust L1Loss, a differentiable
145 | variant of L1Loss).
146 | Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
147 | Super-Resolution".
148 | Args:
149 | loss_weight (float): Loss weight for L1 loss. Default: 1.0.
150 | reduction (str): Specifies the reduction to apply to the output.
151 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
152 | eps (float): A value used to control the curvature near zero.
153 | Default: 1e-12.
154 | """
155 |
156 | def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
157 | super(CharbonnierLoss, self).__init__()
158 | if reduction not in ['none', 'mean', 'sum']:
159 | raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
160 |
161 | self.loss_weight = loss_weight
162 | self.reduction = reduction
163 | self.eps = eps
164 |
165 | def forward(self, pred, target, weight=None, **kwargs):
166 | """
167 | Args:
168 | pred (Tensor): of shape (N, C, H, W). Predicted tensor.
169 | target (Tensor): of shape (N, C, H, W). Ground truth tensor.
170 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise
171 | weights. Default: None.
172 | """
173 | return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
174 |
175 |
176 | class WeightedTVLoss(L1Loss):
177 | """Weighted TV loss.
178 | Args:
179 | loss_weight (float): Loss weight. Default: 1.0.
180 | """
181 |
182 | def __init__(self, loss_weight=1.0):
183 | super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
184 |
185 | def forward(self, pred, weight=None):
186 | if weight is None:
187 | y_weight = None
188 | x_weight = None
189 | else:
190 | y_weight = weight[:, :, :-1, :]
191 | x_weight = weight[:, :, :, :-1]
192 |
193 | y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
194 | x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
195 |
196 | loss = x_diff + y_diff
197 |
198 | return loss
--------------------------------------------------------------------------------
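For orientation, a hypothetical usage sketch of the losses above: the reduction-aware losses share the `forward(pred, target, weight=None)` interface, and `PSNRLoss` returns `loss_weight * (10 / ln 10) * ln(MSE)`, which equals `-loss_weight * PSNR` for inputs in [0, 1], so minimizing it maximizes PSNR.

```python
import torch
from basicsr.models.losses.losses import CharbonnierLoss, PSNRLoss

pred = torch.rand(2, 3, 64, 64)    # hypothetical prediction in [0, 1]
target = torch.rand(2, 3, 64, 64)  # hypothetical ground truth in [0, 1]

# Element-wise weighted Charbonnier loss with mean reduction.
charb = CharbonnierLoss(loss_weight=1.0, reduction='mean', eps=1e-12)
loss = charb(pred, target, weight=torch.ones_like(pred))

# PSNRLoss is negative PSNR scaled by loss_weight; lower loss = higher PSNR.
psnr_loss = PSNRLoss(loss_weight=0.5)
loss_psnr = psnr_loss(pred, target)
```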
/basicsr/utils/lmdb_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import lmdb
3 | import sys
4 | from multiprocessing import Pool
5 | from os import path as osp
6 | from tqdm import tqdm
7 |
8 |
9 | def make_lmdb_from_imgs(data_path,
10 | lmdb_path,
11 | img_path_list,
12 | keys,
13 | batch=5000,
14 | compress_level=1,
15 | multiprocessing_read=False,
16 | n_thread=40,
17 | map_size=None):
18 | """Make lmdb from images.
19 |
20 | Contents of lmdb. The file structure is:
21 | example.lmdb
22 | ├── data.mdb
23 | ├── lock.mdb
24 | ├── meta_info.txt
25 |
26 | The data.mdb and lock.mdb are standard lmdb files and you can refer to
27 | https://lmdb.readthedocs.io/en/release/ for more details.
28 |
29 | The meta_info.txt is a specified txt file to record the meta information
30 | of our datasets. It will be automatically created when preparing
31 | datasets by our provided dataset tools.
32 | Each line in the txt file records 1) image name (with extension),
33 | 2) image shape, and 3) compression level, separated by a white space.
34 |
35 | For example, the meta information could be:
36 | `000_00000000.png (720,1280,3) 1`, which means:
37 | 1) image name (with extension): 000_00000000.png;
38 | 2) image shape: (720,1280,3);
39 | 3) compression level: 1
40 |
41 | We use the image name without extension as the lmdb key.
42 |
43 | If `multiprocessing_read` is True, it will read all the images to memory
44 | using multiprocessing. Thus, your server needs to have enough memory.
45 |
46 | Args:
47 | data_path (str): Data path for reading images.
48 | lmdb_path (str): Lmdb save path.
49 | img_path_list (list[str]): Image path list.
50 | keys (list[str]): Used for lmdb keys.
51 | batch (int): After processing batch images, lmdb commits.
52 | Default: 5000.
53 | compress_level (int): Compress level when encoding images. Default: 1.
54 | multiprocessing_read (bool): Whether to use multiprocessing to read
55 | all the images into memory. Default: False.
56 | n_thread (int): Number of workers for the multiprocessing read. Default: 40.
57 | map_size (int | None): Map size for lmdb env. If None, use the
58 | estimated size from images. Default: None
59 | """
60 |
61 | assert len(img_path_list) == len(keys), (
62 | 'img_path_list and keys should have the same length, '
63 | f'but got {len(img_path_list)} and {len(keys)}')
64 | print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
65 | print(f'Total images: {len(img_path_list)}')
66 | if not lmdb_path.endswith('.lmdb'):
67 | raise ValueError("lmdb_path must end with '.lmdb'.")
68 | if osp.exists(lmdb_path):
69 | print(f'Folder {lmdb_path} already exists. Exit.')
70 | sys.exit(1)
71 |
72 | if multiprocessing_read:
73 | # read all the images to memory (multiprocessing)
74 | dataset = {} # use dict to keep the order for multiprocessing
75 | shapes = {}
76 | print(f'Read images with multiprocessing, #thread: {n_thread} ...')
77 | pbar = tqdm(total=len(img_path_list), unit='image')
78 |
79 | def callback(arg):
80 | """get the image data and update pbar."""
81 | key, dataset[key], shapes[key] = arg
82 | pbar.update(1)
83 | pbar.set_description(f'Read {key}')
84 |
85 | pool = Pool(n_thread)
86 | for path, key in zip(img_path_list, keys):
87 | pool.apply_async(
88 | read_img_worker,
89 | args=(osp.join(data_path, path), key, compress_level),
90 | callback=callback)
91 | pool.close()
92 | pool.join()
93 | pbar.close()
94 | print(f'Finish reading {len(img_path_list)} images.')
95 |
96 | # create lmdb environment
97 | if map_size is None:
98 | # obtain data size for one image
99 | img = cv2.imread(
100 | osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
101 | _, img_byte = cv2.imencode(
102 | '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
103 | data_size_per_img = img_byte.nbytes
104 | print('Data size per image is: ', data_size_per_img)
105 | data_size = data_size_per_img * len(img_path_list)
106 | map_size = data_size * 10
107 |
108 | env = lmdb.open(lmdb_path, map_size=map_size)
109 |
110 | # write data to lmdb
111 | pbar = tqdm(total=len(img_path_list), unit='chunk')
112 | txn = env.begin(write=True)
113 | txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
114 | for idx, (path, key) in enumerate(zip(img_path_list, keys)):
115 | pbar.update(1)
116 | pbar.set_description(f'Write {key}')
117 | key_byte = key.encode('ascii')
118 | if multiprocessing_read:
119 | img_byte = dataset[key]
120 | h, w, c = shapes[key]
121 | else:
122 | _, img_byte, img_shape = read_img_worker(
123 | osp.join(data_path, path), key, compress_level)
124 | h, w, c = img_shape
125 |
126 | txn.put(key_byte, img_byte)
127 | # write meta information
128 | txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
129 | if idx % batch == 0:
130 | txn.commit()
131 | txn = env.begin(write=True)
132 | pbar.close()
133 | txn.commit()
134 | env.close()
135 | txt_file.close()
136 | print('\nFinish writing lmdb.')
137 |
138 |
139 | def read_img_worker(path, key, compress_level):
140 | """Read image worker.
141 |
142 | Args:
143 | path (str): Image path.
144 | key (str): Image key.
145 | compress_level (int): Compress level when encoding images.
146 |
147 | Returns:
148 | str: Image key.
149 | byte: Image byte.
150 | tuple[int]: Image shape.
151 | """
152 |
153 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
154 | if img.ndim == 2:
155 | h, w = img.shape
156 | c = 1
157 | else:
158 | h, w, c = img.shape
159 | _, img_byte = cv2.imencode('.png', img,
160 | [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
161 | return (key, img_byte, (h, w, c))
162 |
163 |
164 | class LmdbMaker():
165 | """LMDB Maker.
166 |
167 | Args:
168 | lmdb_path (str): Lmdb save path.
169 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
170 | batch (int): After processing batch images, lmdb commits.
171 | Default: 5000.
172 | compress_level (int): Compress level when encoding images. Default: 1.
173 | """
174 |
175 | def __init__(self,
176 | lmdb_path,
177 | map_size=1024**4,
178 | batch=5000,
179 | compress_level=1):
180 | if not lmdb_path.endswith('.lmdb'):
181 | raise ValueError("lmdb_path must end with '.lmdb'.")
182 | if osp.exists(lmdb_path):
183 | print(f'Folder {lmdb_path} already exists. Exit.')
184 | sys.exit(1)
185 |
186 | self.lmdb_path = lmdb_path
187 | self.batch = batch
188 | self.compress_level = compress_level
189 | self.env = lmdb.open(lmdb_path, map_size=map_size)
190 | self.txn = self.env.begin(write=True)
191 | self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
192 | self.counter = 0
193 |
194 | def put(self, img_byte, key, img_shape):
195 | self.counter += 1
196 | key_byte = key.encode('ascii')
197 | self.txn.put(key_byte, img_byte)
198 | # write meta information
199 | h, w, c = img_shape
200 | self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
201 | if self.counter % self.batch == 0:
202 | self.txn.commit()
203 | self.txn = self.env.begin(write=True)
204 |
205 | def close(self):
206 | self.txn.commit()
207 | self.env.close()
208 | self.txt_file.close()
209 |
--------------------------------------------------------------------------------
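A minimal sketch of building an lmdb incrementally with `LmdbMaker` (all paths and keys below are hypothetical); `read_img_worker` supplies the PNG-encoded bytes and the image shape that `put` expects:

```python
from basicsr.utils.lmdb_util import LmdbMaker, read_img_worker

# Write a single image under the key '000_00000000' (hypothetical paths).
maker = LmdbMaker('datasets/example.lmdb', map_size=1024**3)  # 1 GB map size
key, img_byte, img_shape = read_img_worker(
    'datasets/example/000_00000000.png', '000_00000000', compress_level=1)
maker.put(img_byte, key, img_shape)
maker.close()  # commits the final transaction and closes meta_info.txt
```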
/basicsr/metrics/niqe.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import numpy as np
4 | from scipy.ndimage.filters import convolve
5 | from scipy.special import gamma
6 |
7 | from basicsr.metrics.metric_util import reorder_image, to_y_channel
8 |
9 |
10 | def estimate_aggd_param(block):
11 | """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) paramters.
12 |
13 | Args:
14 | block (ndarray): 2D Image block.
15 |
16 | Returns:
17 | tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
18 | distribution (estimating the parameters in Equation 7 in the paper).
19 | """
20 | block = block.flatten()
21 | gam = np.arange(0.2, 10.001, 0.001) # len = 9801
22 | gam_reciprocal = np.reciprocal(gam)
23 | r_gam = np.square(gamma(gam_reciprocal * 2)) / (
24 | gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
25 |
26 | left_std = np.sqrt(np.mean(block[block < 0]**2))
27 | right_std = np.sqrt(np.mean(block[block > 0]**2))
28 | gammahat = left_std / right_std
29 | rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
30 | rhatnorm = (rhat * (gammahat**3 + 1) *
31 | (gammahat + 1)) / ((gammahat**2 + 1)**2)
32 | array_position = np.argmin((r_gam - rhatnorm)**2)
33 |
34 | alpha = gam[array_position]
35 | beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
36 | beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
37 | return (alpha, beta_l, beta_r)
38 |
39 |
40 | def compute_feature(block):
41 | """Compute features.
42 |
43 | Args:
44 | block (ndarray): 2D Image block.
45 |
46 | Returns:
47 | list: Features with length of 18.
48 | """
49 | feat = []
50 | alpha, beta_l, beta_r = estimate_aggd_param(block)
51 | feat.extend([alpha, (beta_l + beta_r) / 2])
52 |
53 | # distortions disturb the fairly regular structure of natural images.
54 | # This deviation can be captured by analyzing the sample distribution of
55 | # the products of pairs of adjacent coefficients computed along
56 | # horizontal, vertical and diagonal orientations.
57 | shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
58 | for i in range(len(shifts)):
59 | shifted_block = np.roll(block, shifts[i], axis=(0, 1))
60 | alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block)
61 | # Eq. 8
62 | mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
63 | feat.extend([alpha, mean, beta_l, beta_r])
64 | return feat
65 |
66 |
67 | def niqe(img,
68 | mu_pris_param,
69 | cov_pris_param,
70 | gaussian_window,
71 | block_size_h=96,
72 | block_size_w=96):
73 | """Calculate NIQE (Natural Image Quality Evaluator) metric.
74 |
75 | Ref: Making a "Completely Blind" Image Quality Analyzer.
76 | This implementation could produce almost the same results as the official
77 | MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
78 |
79 | Note that we do not include block overlap height and width, since they are
80 | always 0 in the official implementation.
81 |
82 | For good performance, the official implementation advises dividing the
83 | distorted image into patches of the same size as those used for the
84 | construction of the multivariate Gaussian model.
85 |
86 | Args:
87 | img (ndarray): Input image whose quality needs to be computed. The
88 | image must be a gray or Y (of YCbCr) image with shape (h, w).
89 | Range [0, 255] with float type.
90 | mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
91 | model calculated on the pristine dataset.
92 | cov_pris_param (ndarray): Covariance of a pre-defined multivariate
93 | Gaussian model calculated on the pristine dataset.
94 | gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the
95 | image.
96 | block_size_h (int): Height of the blocks into which the image is divided.
97 | Default: 96 (the official recommended value).
98 | block_size_w (int): Width of the blocks into which the image is divided.
99 | Default: 96 (the official recommended value).
100 | """
101 | assert img.ndim == 2, (
102 | 'Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
103 | # crop image
104 | h, w = img.shape
105 | num_block_h = math.floor(h / block_size_h)
106 | num_block_w = math.floor(w / block_size_w)
107 | img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]
108 |
109 | distparam = []  # distparam holds the multiscale features
110 | for scale in (1, 2): # perform on two scales (1, 2)
111 | mu = convolve(img, gaussian_window, mode='nearest')
112 | sigma = np.sqrt(
113 | np.abs(
114 | convolve(np.square(img), gaussian_window, mode='nearest') -
115 | np.square(mu)))
116 | # normalize, as in Eq. 1 in the paper
117 | img_normalized = (img - mu) / (sigma + 1)
118 |
119 | feat = []
120 | for idx_w in range(num_block_w):
121 | for idx_h in range(num_block_h):
122 | # process each block
123 | block = img_normalized[idx_h * block_size_h //
124 | scale:(idx_h + 1) * block_size_h //
125 | scale, idx_w * block_size_w //
126 | scale:(idx_w + 1) * block_size_w //
127 | scale]
128 | feat.append(compute_feature(block))
129 |
130 | distparam.append(np.array(feat))
131 | # TODO: matlab bicubic downsample with anti-aliasing
132 | # for simplicity, now we use opencv instead, which will result in
133 | # a slight difference.
134 | if scale == 1:
135 | h, w = img.shape
136 | img = cv2.resize(
137 | img / 255., (w // 2, h // 2), interpolation=cv2.INTER_LINEAR)
138 | img = img * 255.
139 |
140 | distparam = np.concatenate(distparam, axis=1)
141 |
142 | # fit a MVG (multivariate Gaussian) model to distorted patch features
143 | mu_distparam = np.nanmean(distparam, axis=0)
144 | # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
145 | distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
146 | cov_distparam = np.cov(distparam_no_nan, rowvar=False)
147 |
148 | # compute niqe quality, Eq. 10 in the paper
149 | invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
150 | quality = np.matmul(
151 | np.matmul((mu_pris_param - mu_distparam), invcov_param),
152 | np.transpose((mu_pris_param - mu_distparam)))
153 | quality = np.sqrt(quality)
154 |
155 | return quality
156 |
157 |
158 | def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y'):
159 | """Calculate NIQE (Natural Image Quality Evaluator) metric.
160 |
161 | Ref: Making a "Completely Blind" Image Quality Analyzer.
162 | This implementation could produce almost the same results as the official
163 | MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
164 |
165 | We use the official params estimated from the pristine dataset.
166 | We use the recommended block size (96, 96) without overlaps.
167 |
168 | Args:
169 | img (ndarray): Input image whose quality needs to be computed.
170 | The input image must be in range [0, 255] with float/int type.
171 | The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order)
172 | If the input order is 'HWC' or 'CHW', it will be converted to gray
173 | or Y (of YCbCr) image according to the ``convert_to`` argument.
174 | crop_border (int): Cropped pixels in each edge of an image. These
175 | pixels are not involved in the metric calculation.
176 | input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
177 | Default: 'HWC'.
178 | convert_to (str): Whether to convert to 'y' (of MATLAB YCbCr) or 'gray'.
179 | Default: 'y'.
180 |
181 | Returns:
182 | float: NIQE result.
183 | """
184 |
185 | # we use the official params estimated from the pristine dataset.
186 | niqe_pris_params = np.load('basicsr/metrics/niqe_pris_params.npz')
187 | mu_pris_param = niqe_pris_params['mu_pris_param']
188 | cov_pris_param = niqe_pris_params['cov_pris_param']
189 | gaussian_window = niqe_pris_params['gaussian_window']
190 |
191 | img = img.astype(np.float32)
192 | if input_order != 'HW':
193 | img = reorder_image(img, input_order=input_order)
194 | if convert_to == 'y':
195 | img = to_y_channel(img)
196 | elif convert_to == 'gray':
197 | img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
198 | img = np.squeeze(img)
199 |
200 | if crop_border != 0:
201 | img = img[crop_border:-crop_border, crop_border:-crop_border]
202 |
203 | niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window)
204 |
205 | return niqe_result
206 |
--------------------------------------------------------------------------------
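A minimal usage sketch of `calculate_niqe` (the image path is hypothetical). Note that the function loads `niqe_pris_params.npz` via a path relative to the working directory, so it should be called from the repository root:

```python
import cv2
from basicsr.metrics.niqe import calculate_niqe

# Hypothetical input: a BGR image read with OpenCV, values in [0, 255].
img = cv2.imread('datasets/example/result.png', cv2.IMREAD_COLOR)

# Convert to the Y channel of YCbCr and crop a 4-pixel border before scoring.
score = calculate_niqe(img, crop_border=4, input_order='HWC', convert_to='y')
print(f'NIQE: {score:.4f}')  # lower is better
```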
/README.md:
--------------------------------------------------------------------------------
1 | [](https://paperswithcode.com/sota/deblurring-on-gopro?p=mefnet-multi-scale-event-fusion-network-for)
2 |
3 | Event-based Fusion for Motion Deblurring with Cross-modal Attention
4 | ---
5 | #### Lei Sun, Christos Sakaridis, Jingyun Liang, Qi Jiang, Kailun Yang, Peng Sun, Yaozu Ye, Kaiwei Wang, Luc Van Gool
6 | #### Paper: https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136780403.pdf
7 | > Traditional frame-based cameras inevitably suffer from motion blur due to long exposure times. As a kind of bio-inspired camera, the event camera records the intensity changes in an asynchronous way with high temporal resolution, providing valid image degradation information within the exposure time. In this paper, we rethink the eventbased image deblurring problem and unfold it into an end-to-end two-stage image restoration network. To effectively fuse event and image features, we design an event-image cross-modal attention module applied at multiple levels of our network, which allows to focus on relevant features from the event branch and filter out noise. We also introduce a novel symmetric cumulative event representation specifically for image deblurring as well as an event mask gated connection between the two stages of our network which helps avoid information loss. At the dataset level, to foster event-based motion deblurring and to facilitate evaluation on challenging real-world images, we introduce the Real Event Blur (REBlur) dataset, captured with an event camera in an illumination controlled optical laboratory. Our Event Fusion Network (EFNet) sets the new state of the art in motion deblurring, surpassing both the prior best-performing image-based method and all event-based methods with public implementations on the GoPro dataset (by up to 2.47dB) and on our REBlur dataset, even in extreme blurry conditions.
8 |
9 |
10 | ### News
11 | - February 2025: [NTIRE 2025](https://cvlai.net/ntire/2025/) [the First Challenge on Event-Based Image Deblurring](https://codalab.lisn.upsaclay.fr/competitions/21498) starts! An example config file is provided at `./options/train/HighREV/EFNet_HighREV_Deblur.yml`
12 | - October 2024: Real-time deblurring system built! Please check the video demo on our GitHub page!
13 | - September 25: Updated the dataset link. Datasets are available now.
14 | - July 14 2022: :tada: :tada: Our paper was accepted to ECCV 2022 as an oral presentation (2.7% of submissions).
15 | - July 14 2022: The repository is under construction.
16 |
17 | ### NTIRE 2025 the First Challenge on Event-Based Image Deblurring
18 | We provide a simple instruction here. For more details please refer to [this repo](https://github.com/AHupuJR/NTIRE2025_EventDeblur_challenge)
19 |
20 | #### HighREV dataset download
21 |
22 | [HighREV dataset](https://codalab.lisn.upsaclay.fr/my/datasets/download/9f275580-9b38-4984-b995-1e59e96b6111): Mandatory, with raw events.
23 |
24 | command: `wget -O HighREV.zip "https://codalab.lisn.upsaclay.fr/my/datasets/download/9f275580-9b38-4984-b995-1e59e96b6111"`
25 |
26 |
27 | [Processed voxel grid of events](https://codalab.lisn.upsaclay.fr/my/datasets/download/c83e95ab-d4e6-4b9f-b7de-e3d3b45356e3): Optional. Using the processed voxel grids can speed up the data processing when training.
28 |
29 |
30 | #### Dataset code:
31 | `./basicsr/data/npz_image_dataset.py` for raw events
32 | `./basicsr/data/voxelnpz_image_dataset.py` for voxel grids
33 |
34 |
35 | #### Example:
36 | ```
37 | git clone https://github.com/AHupuJR/EFNet
38 | cd EFNet
39 | pip install -r requirements.txt
40 | python setup.py develop --no_cuda_ext
41 |
42 | python ./basicsr/train.py -opt options/train/HighREV/EFNet_HighREV_Deblur.yml
43 | ```
44 |
45 |
46 | ### Real-time system
47 | https://github.com/AHupuJR/EFNet/assets/35628326/b888bb41-33f9-46de-b1b6-50d5291bf4ed
48 |
49 |
50 | [Bilibili video](https://www.bilibili.com/video/BV158411r7nY/?spm_id_from=333.999.0.0&vd_source=4c2f67c9c80eb20c01a758e38fc07198)
51 |
52 | Contributed by Xiaolei Gu, Xiao Jin (Jiaxing Research Institute, Zhejiang University), Yuhan Bao (Zhejiang University).
53 |
54 | From left to right: blurry image, deblurred image, visualized events.
55 |
56 | Feel free to contact me about potential applications.
57 |
58 | ### Network Architecture
59 |
60 |
61 |
62 |
63 | ### Symmetric Cumulative Event Representation (SCER)
64 |
65 |
66 |
67 | ### Results
68 | GoPro dataset (Click to expand)
69 |
70 |
71 |
72 |
73 |
74 | REBlur dataset (Click to expand)
75 |
76 |
77 |
78 |
79 |
80 | ### Installation
81 | This implementation is based on [BasicSR](https://github.com/xinntao/BasicSR), an open-source toolbox for image/video restoration tasks.
82 |
83 | ```
84 | python 3.8.5
85 | pytorch 1.7.1
86 | cuda 11.0
87 | ```
88 |
89 |
90 |
91 | ```
92 | git clone https://github.com/AHupuJR/EFNet
93 | cd EFNet
94 | pip install -r requirements.txt
95 | python setup.py develop --no_cuda_ext
96 | ```
97 |
98 | ### Dataset
99 | Use GoPro events to train the model. If you want to use your own event representation instead of SCER, download the GoPro raw events and use `EFNet/scripts/data_preparation/make_voxels_esim.py` to produce it.
100 |
101 | GoPro with SCER: [[ETH_share_link](https://data.vision.ee.ethz.ch/csakarid/shared/EFNet/GOPRO.zip)] [[BaiduYunPan](https://pan.baidu.com/s/1TxWdMB2LjdlgIvuc6QN-Bg)/code: 3wm8]
102 |
103 | REBlur with SCER: [[ETH_share_link](https://data.vision.ee.ethz.ch/csakarid/shared/EFNet/REBlur.zip)] [[BaiduYunPan](https://pan.baidu.com/s/13v0CjlFUXt9TxXI0Co9tQQ?pwd=f6ha#list/path=%2F)/code:f6ha]
104 |
105 | We also provide scripts in [./scripts/data_preparation/](./scripts/data_preparation/) to convert raw event files to SCER. You can also design your own event representation by modifying the scripts. Raw event files download:
106 |
107 | GoPro with raw events: [[ETH_share_link](https://data.vision.ee.ethz.ch/csakarid/shared/EFNet/GOPRO_rawevents.zip)] [[BaiduYunPan](link)/TODO]
108 |
109 | REBlur with raw events: [[ETH_share_link](https://data.vision.ee.ethz.ch/csakarid/shared/EFNet/REBlur_rawevents.zip)] [[BaiduYunPan](link)/TODO]
110 |
111 |
112 | ### Train
113 | ---
114 | #### GoPro
115 |
116 | * prepare data
117 |
118 | * download the GoPro events dataset (see [Dataset](#dataset)) to
119 | ```bash
120 | ./datasets
121 | ```
122 |
123 | * the directory structure should look like:
124 |
125 | ```bash
126 | ./datasets/
127 | ./datasets/DATASET_NAME/
128 | ./datasets/DATASET_NAME/train/
129 | ./datasets/DATASET_NAME/test/
130 | ```
131 |
132 | * train
133 |
134 | * ```python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/GoPro/EFNet.yml --launcher pytorch```
135 |
136 | * eval
137 | * Download [pretrained model](https://drive.google.com/file/d/19O-B-K4IODMENQblwHbSqNu0TX0IF-iA/view?usp=sharing) to ./experiments/pretrained_models/EFNet-GoPro.pth
138 | * ```python basicsr/test.py -opt options/test/GoPro/EFNet.yml ```
139 |
140 |
141 | #### REBlur
142 |
143 | * prepare data
144 |
145 | * download the REBlur dataset (see [Dataset](#dataset)) to
146 | ```bash
147 | ./datasets
148 | ```
149 |
150 | * the directory structure should look like:
151 |
152 | ```bash
153 | ./datasets/
154 | ./datasets/DATASET_NAME/
155 | ./datasets/DATASET_NAME/train/
156 | ./datasets/DATASET_NAME/test/
157 | ```
158 |
159 | * finetune
160 |
161 | * ```python ./basicsr/train.py -opt options/train/REBlur/Finetune_EFNet.yml```
162 |
163 | * eval
164 | * Download [pretrained model](https://drive.google.com/file/d/1yMGnwfYsxWbVp7r-oc8ls9qnOEDavG3h/view?usp=sharing) to ./experiments/pretrained_models/EFNet-REBlur.pth
165 | * ```python basicsr/test.py -opt options/test/REBlur/Finetune_EFNet.yml ```
166 |
167 |
168 | ### Qualitative results
169 | All the qualitative results can be downloaded through Google Drive:
170 |
171 | [GoPro test](https://drive.google.com/file/d/17jXR5U9e3-8dXPUxB0-wBhDFg60Oe8US/view?usp=sharing)
172 |
173 | [REBlur test](https://drive.google.com/file/d/17jXR5U9e3-8dXPUxB0-wBhDFg60Oe8US/view?usp=sharing)
174 |
175 | [REBlur addition](https://drive.google.com/file/d/17jXR5U9e3-8dXPUxB0-wBhDFg60Oe8US/view?usp=sharing)
176 |
177 |
178 | ### Citations
179 |
180 | ```
181 | @inproceedings{sun2022event,
182 | title={Event-Based Fusion for Motion Deblurring with Cross-modal Attention},
183 | author={Sun, Lei and Sakaridis, Christos and Liang, Jingyun and Jiang, Qi and Yang, Kailun and Sun, Peng and Ye, Yaozu and Wang, Kaiwei and Gool, Luc Van},
184 | booktitle={European Conference on Computer Vision},
185 | pages={412--428},
186 | year={2022},
187 | organization={Springer}
188 | }
189 | ```
190 |
191 |
192 | ### Contact
193 | Should you have any questions, please feel free to contact leo_sun@zju.edu.cn or leisun@ee.ethz.ch.
194 |
195 |
196 | ### License and Acknowledgement
197 |
198 | This project is under the Apache 2.0 license, and it is based on [BasicSR](https://github.com/xinntao/BasicSR), which is also under the Apache 2.0 license. Thanks for the inspiration and code from [HINet](https://github.com/megvii-model/HINet) and [event_utils](https://github.com/TimoStoff/event_utils).
199 |
200 |
201 |
--------------------------------------------------------------------------------
/scripts/data_preparation/raw_event_dataset.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | from .base_dataset import BaseVoxelDataset
3 | import pandas as pd
4 | import numpy as np
5 |
6 | class GoproEsimH5Dataset(BaseVoxelDataset):
7 | """
8 | Dataloader for events saved in the Monash University HDF5 events format
9 | GOPRO ESIM data, index is different (index + 1)
10 | """
11 |
12 | def get_frame(self, index):
13 | """
14 | discard the first and the last frame in GoproEsim
15 | """
16 | return self.h5_file['images']['image{:09d}'.format(index+1)][:]
17 |
18 | def get_gt_frame(self, index):
19 | """
20 | discard the first and the last frame in GoproEsim
21 | """
22 | return self.h5_file['sharp_images']['image{:09d}'.format(index+1)][:]
23 |
24 | def get_events(self, idx0, idx1):
25 | xs = self.h5_file['events/xs'][idx0:idx1]
26 | ys = self.h5_file['events/ys'][idx0:idx1]
27 | ts = self.h5_file['events/ts'][idx0:idx1]
28 | ps = self.h5_file['events/ps'][idx0:idx1] * 2.0 - 1.0 # -1 and 1
29 | return xs, ys, ts, ps
30 |
31 | def load_data(self, data_path):
32 | self.data_sources = ('esim', 'ijrr', 'mvsec', 'eccd', 'hqfd', 'unknown')
33 | try:
34 | self.h5_file = h5py.File(data_path, 'r')
35 | except OSError as err:
36 | print("Couldn't open {}: {}".format(data_path, err))
37 |
38 | if self.sensor_resolution is None:
39 | self.sensor_resolution = self.h5_file.attrs['sensor_resolution'][0:2]
40 | else:
41 | self.sensor_resolution = self.sensor_resolution[0:2]
42 | print("sensor resolution = {}".format(self.sensor_resolution))
43 | self.has_flow = 'flow' in self.h5_file.keys() and len(self.h5_file['flow']) > 0
44 | self.t0 = self.h5_file['events/ts'][0]
45 | self.tk = self.h5_file['events/ts'][-1]
46 | self.num_events = self.h5_file.attrs["num_events"]
47 | self.num_frames = self.h5_file.attrs["num_imgs"]
48 |
49 | self.frame_ts = []
50 | for img_name in self.h5_file['images']:
51 | self.frame_ts.append(self.h5_file['images/{}'.format(img_name)].attrs['timestamp'])
52 |
53 | data_source = self.h5_file.attrs.get('source', 'unknown')
54 | try:
55 | self.data_source_idx = self.data_sources.index(data_source)
56 | except ValueError:
57 | self.data_source_idx = -1
58 |
59 | def find_ts_index(self, timestamp):
60 | idx = binary_search_h5_dset(self.h5_file['events/ts'], timestamp)
61 | return idx
62 |
63 | def ts(self, index):
64 | return self.h5_file['events/ts'][index]
65 |
66 | def compute_frame_indices(self):
67 | frame_indices = []
68 | start_idx = 0
69 | for img_name in self.h5_file['images']:
70 | end_idx = self.h5_file['images/{}'.format(img_name)].attrs['event_idx']
71 | frame_indices.append([start_idx, end_idx])
72 | start_idx = end_idx
73 | return frame_indices
74 |
75 |
76 | class SeemsH5Dataset(BaseVoxelDataset):
77 | """
78 | Dataloader for events saved in the Monash University HDF5 events format
79 | H5 data contains the exposure information
80 | """
81 |
82 | def get_frame(self, index):
83 | """
84 | discard the first and the last frame in GoproEsim
85 | """
86 | return self.h5_file['images']['image{:09d}'.format(index)][:]
87 |
88 | def get_gt_frame(self, index):
89 | """
90 | discard the first and the last frame in GoproEsim
91 | """
92 | return self.h5_file['sharp_images']['image{:09d}'.format(index)][:]
93 |
94 | def get_events(self, idx0, idx1):
95 | xs = self.h5_file['events/xs'][idx0:idx1]
96 | ys = self.h5_file['events/ys'][idx0:idx1]
97 | ts = self.h5_file['events/ts'][idx0:idx1]
98 | ps = self.h5_file['events/ps'][idx0:idx1] * 2.0 - 1.0 # -1 and 1
99 | return xs, ys, ts, ps
100 |
101 | def load_data(self, data_path):
102 | self.data_sources = ('esim', 'ijrr', 'mvsec', 'eccd', 'hqfd', 'unknown')
103 | try:
104 | self.h5_file = h5py.File(data_path, 'r')
105 | except OSError as err:
106 | print("Couldn't open {}: {}".format(data_path, err))
107 |
108 | if self.sensor_resolution is None:
109 | self.sensor_resolution = self.h5_file.attrs['sensor_resolution'][0:2]
110 | else:
111 | self.sensor_resolution = self.sensor_resolution[0:2]
112 | print("sensor resolution = {}".format(self.sensor_resolution))
113 | self.has_flow = 'flow' in self.h5_file.keys() and len(self.h5_file['flow']) > 0
114 | self.t0 = self.h5_file['events/ts'][0]
115 | self.tk = self.h5_file['events/ts'][-1]
116 | self.num_events = self.h5_file.attrs["num_events"]
117 | self.num_frames = self.h5_file.attrs["num_imgs"]
118 |
119 | self.frame_ts = []
120 | if self.has_exposure_time:
121 | self.frame_exposure_start=[]
122 | self.frame_exposure_end = []
123 | self.frame_exposure_time=[]
124 | for img_name in self.h5_file['images']:
125 | self.frame_ts.append(self.h5_file['images/{}'.format(img_name)].attrs['timestamp'])
126 | if self.has_exposure_time:
127 | self.frame_exposure_start.append(self.h5_file['images/{}'.format(img_name)].attrs['exposure_start'])
128 | self.frame_exposure_end.append(self.h5_file['images/{}'.format(img_name)].attrs['exposure_end'])
129 | self.frame_exposure_time.append(self.h5_file['images/{}'.format(img_name)].attrs['exposure_time'])
130 | data_source = self.h5_file.attrs.get('source', 'unknown')
131 | try:
132 | self.data_source_idx = self.data_sources.index(data_source)
133 | except ValueError:
134 | self.data_source_idx = -1
135 |
136 | def find_ts_index(self, timestamp):
137 | idx = binary_search_h5_dset(self.h5_file['events/ts'], timestamp)
138 | return idx
139 |
140 | def ts(self, index):
141 | return self.h5_file['events/ts'][index]
142 |
143 | def compute_frame_indices(self):
144 | frame_indices = []
145 | start_idx = 0
146 | for img_name in self.h5_file['images']:
147 | end_idx = self.h5_file['images/{}'.format(img_name)].attrs['event_idx']
148 | frame_indices.append([start_idx, end_idx])
149 | start_idx = end_idx
150 | return frame_indices
151 |
152 |
153 | class GoproEsimH52NpzDataset(BaseVoxelDataset):
154 | """
155 |
156 | """
157 |
158 | def get_frame(self, index):
159 | """
160 | discard the first and the last frame in GoproEsim
161 | """
162 | return self.h5_file['images']['image{:09d}'.format(index + 1)][:]
163 |
164 | def get_gt_frame(self, index):
165 | """
166 | discard the first and the last frame in GoproEsim
167 | """
168 | return self.h5_file['sharp_images']['image{:09d}'.format(index + 1)][:]
169 |
170 | def get_events(self, idx0, idx1):
171 | xs = self.h5_file['events/xs'][idx0:idx1]
172 | ys = self.h5_file['events/ys'][idx0:idx1]
173 | ts = self.h5_file['events/ts'][idx0:idx1]
174 | ps = self.h5_file['events/ps'][idx0:idx1] * 2.0 - 1.0 # -1 and 1
175 | return xs, ys, ts, ps
176 |
177 | def load_data(self, data_path, csv_path):
178 | self.data_sources = ('esim', 'ijrr', 'mvsec', 'eccd', 'hqfd', 'unknown')
179 | try:
180 | self.h5_file = h5py.File(data_path, 'r')
181 | except OSError as err:
182 | print("Couldn't open {}: {}".format(data_path, err))
183 | pd_file = pd.read_csv(csv_path, header=None, names=['t', 'path'], dtype={'t': np.float64},
184 | engine='c')
185 |
186 | img_time_stamps = pd_file.t.values.astype(np.float64)
187 | img_names = pd_file.path.values
188 |
189 | if self.sensor_resolution is None:
190 | self.sensor_resolution = self.h5_file.attrs['sensor_resolution'][0:2]
191 | else:
192 | self.sensor_resolution = self.sensor_resolution[0:2]
193 |
194 | print("sensor resolution = {}".format(self.sensor_resolution))
195 | self.has_flow = 'flow' in self.h5_file.keys() and len(self.h5_file['flow']) > 0
196 |
197 | self.t0 = self.h5_file['events/ts'][0]
198 | self.tk = self.h5_file['events/ts'][-1]
199 | self.num_events = self.h5_file.attrs["num_events"]
200 |
201 |
202 | self.frame_ts = list(img_time_stamps)
203 | self.img_names = list(img_names)
204 | self.num_frames = len(self.frame_ts)
205 |
206 | self.data_source_idx = -1
207 |
208 |
209 | def find_ts_index(self, timestamp):
210 | idx = binary_search_h5_dset(self.h5_file['events/ts'], timestamp)
211 | return idx
212 |
213 | def ts(self, index):
214 | return self.h5_file['events/ts'][index]
215 |
216 | def compute_frame_indices(self):
217 | frame_indices = []
218 | start_idx = 0
219 | for img_name in self.h5_file['images']:
220 | end_idx = self.h5_file['images/{}'.format(img_name)].attrs['event_idx']
221 | frame_indices.append([start_idx, end_idx])
222 | start_idx = end_idx
223 | return frame_indices
224 |
225 |
226 | def binary_search_h5_dset(dset, x, l=None, r=None, side='left'):
227 | """
228 | Binary search for a timestamp in an HDF5 event file, without
229 | loading the entire file into RAM
230 | @param dset The HDF5 dataset
231 | @param x The timestamp being searched for
232 | @param l Starting guess for the left side (0 if None is chosen)
233 | @param r Starting guess for the right side (-1 if None is chosen)
234 | @param side Which side to take final result for if exact match is not found
235 | @returns Index of nearest event to 'x'
236 | """
237 | l = 0 if l is None else l
238 | r = len(dset)-1 if r is None else r
239 | while l <= r:
240 | mid = l + (r - l)//2
241 | midval = dset[mid]
242 | if midval == x:
243 | return mid
244 | elif midval < x:
245 | l = mid + 1
246 | else:
247 | r = mid - 1
248 | if side == 'left':
249 | return l
250 | return r
251 |
252 |
253 |
--------------------------------------------------------------------------------
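For illustration, a sketch of using `binary_search_h5_dset` from the file above to slice the events falling inside a hypothetical exposure window (the file path and timestamps are assumptions):

```python
import h5py

# binary_search_h5_dset is the helper defined in raw_event_dataset.py above.
with h5py.File('datasets/example/seq_000.h5', 'r') as f:  # hypothetical file
    ts = f['events/ts']
    idx0 = binary_search_h5_dset(ts, 0.10)  # index at exposure start (s), assumed
    idx1 = binary_search_h5_dset(ts, 0.15)  # index at exposure end (s), assumed
    xs = f['events/xs'][idx0:idx1]
    ys = f['events/ys'][idx0:idx1]
    ps = f['events/ps'][idx0:idx1] * 2.0 - 1.0  # map polarities {0, 1} -> {-1, +1}
```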