├── .gitignore ├── LICENSE ├── VERSION ├── basicsr ├── data │ ├── __init__.py │ ├── data_sampler.py │ ├── data_util.py │ ├── ffhq_dataset.py │ ├── meta_info │ │ ├── meta_info_DIV2K800sub_GT.txt │ │ ├── meta_info_REDS4_test_GT.txt │ │ ├── meta_info_REDS_GT.txt │ │ ├── meta_info_REDSofficial4_test_GT.txt │ │ ├── meta_info_REDSval_official_test_GT.txt │ │ ├── meta_info_Vimeo90K_test_GT.txt │ │ ├── meta_info_Vimeo90K_test_fast_GT.txt │ │ ├── meta_info_Vimeo90K_test_medium_GT.txt │ │ ├── meta_info_Vimeo90K_test_slow_GT.txt │ │ └── meta_info_Vimeo90K_train_GT.txt │ ├── paired_image_SR_LR_FullImage_Memory_dataset.py │ ├── paired_image_SR_LR_dataset.py │ ├── paired_image_dataset.py │ ├── prefetch_dataloader.py │ ├── reds_dataset.py │ ├── single_image_dataset.py │ ├── transforms.py │ ├── video_test_dataset.py │ └── vimeo90k_dataset.py ├── demo.py ├── demo_ssr.py ├── metrics │ ├── __init__.py │ ├── fid.py │ ├── metric_util.py │ ├── niqe.py │ ├── niqe_pris_params.npz │ └── psnr_ssim.py ├── models │ ├── __init__.py │ ├── archs │ │ ├── Baseline_arch.py │ │ ├── NAFNet_arch.py │ │ ├── NAFSSR_arch.py │ │ ├── __init__.py │ │ ├── arch_util.py │ │ └── local_arch.py │ ├── base_model.py │ ├── image_restoration_model.py │ ├── losses │ │ ├── __init__.py │ │ ├── loss_util.py │ │ └── losses.py │ └── lr_scheduler.py ├── test.py ├── train.py ├── utils │ ├── __init__.py │ ├── create_lmdb.py │ ├── dist_util.py │ ├── download_util.py │ ├── face_util.py │ ├── file_client.py │ ├── flow_util.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ └── options.py └── version.py ├── cog.yaml ├── datasets └── README.md ├── demo ├── blurry.jpg ├── denoise_img.png ├── lr_img_l.png ├── lr_img_r.png ├── noisy.png ├── sr_img_l.png └── sr_img_r.png ├── docs ├── GoPro.md ├── REDS.md ├── SIDD.md └── StereoSR.md ├── experiments └── pretrained_models │ └── README.md ├── figures ├── NAFSSR_arch.jpg ├── NAFSSR_params.jpg ├── PSNR_vs_MACs.jpg ├── StereoSR.gif ├── 
deblur.gif └── denoise.gif ├── options ├── test │ ├── GoPro │ │ ├── Baseline-width32.yml │ │ ├── Baseline-width64.yml │ │ ├── NAFNet-width32.yml │ │ └── NAFNet-width64.yml │ ├── NAFSSR │ │ ├── NAFSSR-B_2x.yml │ │ ├── NAFSSR-B_4x.yml │ │ ├── NAFSSR-L_2x.yml │ │ ├── NAFSSR-L_4x.yml │ │ ├── NAFSSR-S_2x.yml │ │ ├── NAFSSR-S_4x.yml │ │ ├── NAFSSR-T_2x.yml │ │ └── NAFSSR-T_4x.yml │ ├── REDS │ │ └── NAFNet-width64.yml │ └── SIDD │ │ ├── Baseline-width32.yml │ │ ├── Baseline-width64.yml │ │ ├── NAFNet-width32.yml │ │ └── NAFNet-width64.yml └── train │ ├── GoPro │ ├── Baseline-width32.yml │ ├── Baseline-width64.yml │ ├── NAFNet-width32.yml │ └── NAFNet-width64.yml │ ├── NAFSSR │ ├── NAFSSR-B_x2.yml │ ├── NAFSSR-B_x4.yml │ ├── NAFSSR-L_x2.yml │ ├── NAFSSR-L_x4.yml │ ├── NAFSSR-S_x2.yml │ ├── NAFSSR-S_x4.yml │ ├── NAFSSR-T_x2.yml │ └── NAFSSR-T_x4.yml │ ├── REDS │ └── NAFNet-width64.yml │ └── SIDD │ ├── Baseline-width32.yml │ ├── Baseline-width64.yml │ ├── NAFNet-width32.yml │ └── NAFNet-width64.yml ├── predict.py ├── readme.md ├── requirements.txt ├── scripts ├── data_preparation │ ├── gopro.py │ ├── reds.py │ └── sidd.py └── make_pickle.py ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea/* 3 | experiments 4 | logs/ 5 | *results* 6 | *__pycache__* 7 | *.sh 8 | datasets 9 | basicsr.egg-info -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 1.2.0 2 | -------------------------------------------------------------------------------- /basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------

import importlib
import numpy as np
import random
import torch
import torch.utils.data
from functools import partial
from os import path as osp

from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info

__all__ = ['create_dataset', 'create_dataloader']

# Automatically scan and import dataset modules: every file in this folder
# whose name ends with '_dataset.py' is imported so that its dataset classes
# can be looked up by name in create_dataset().
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [
    osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
    if v.endswith('_dataset.py')
]
# import all the dataset modules
_dataset_modules = [
    importlib.import_module(f'basicsr.data.{file_name}')
    for file_name in dataset_filenames
]


def create_dataset(dataset_opt):
    """Create a dataset instance from its registered class name.

    Args:
        dataset_opt (dict): Configuration for dataset. It contains:
            name (str): Dataset name.
            type (str): Dataset type (the class name to instantiate).

    Returns:
        torch.utils.data.Dataset: The instantiated dataset.

    Raises:
        ValueError: If no scanned module defines a class named
            ``dataset_opt['type']``.
    """
    dataset_type = dataset_opt['type']

    # Dynamic instantiation: look the class up in every scanned module.
    # Initialize to None so the error path below also works when
    # _dataset_modules is empty (previously dataset_cls could be unbound).
    dataset_cls = None
    for module in _dataset_modules:
        dataset_cls = getattr(module, dataset_type, None)
        if dataset_cls is not None:
            break
    if dataset_cls is None:
        raise ValueError(f'Dataset {dataset_type} is not found.')

    dataset = dataset_cls(dataset_opt)

    logger = get_root_logger()
    logger.info(
        f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
        'is created.')
    return dataset


def create_dataloader(dataset,
                      dataset_opt,
                      num_gpu=1,
                      dist=False,
                      sampler=None,
                      seed=None):
    """Create dataloader.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. It contains the following keys:
            phase (str): 'train' or 'val'.
            num_worker_per_gpu (int): Number of workers for each GPU.
            batch_size_per_gpu (int): Training batch size for each GPU.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed. Default: None

    Raises:
        ValueError: If ``phase`` is not 'train', 'val' or 'test'.
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()
    if phase == 'train':
        if dist:  # distributed training: options are already per process
            batch_size = dataset_opt['batch_size_per_gpu']
            num_workers = dataset_opt['num_worker_per_gpu']
        else:  # non-distributed training: one process drives all GPUs
            multiplier = 1 if num_gpu == 0 else num_gpu
            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
        dataloader_args = dict(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=sampler,
            drop_last=True,
            # persistent_workers=True is rejected by PyTorch with a
            # ValueError when num_workers == 0, so only enable it when
            # worker processes actually exist.
            persistent_workers=num_workers > 0)
        if sampler is None:
            dataloader_args['shuffle'] = True
        dataloader_args['worker_init_fn'] = partial(
            worker_init_fn, num_workers=num_workers, rank=rank,
            seed=seed) if seed is not None else None
    elif phase in ['val', 'test']:  # validation
        dataloader_args = dict(
            dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    else:
        raise ValueError(f'Wrong dataset phase: {phase}. '
                         "Supported ones are 'train', 'val' and 'test'.")

    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode == 'cpu':  # CPUPrefetcher
        num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
        logger = get_root_logger()
        logger.info(f'Use {prefetch_mode} prefetch dataloader: '
                    f'num_prefetch_queue = {num_prefetch_queue}')
        return PrefetchDataLoader(
            num_prefetch_queue=num_prefetch_queue, **dataloader_args)
    else:
        # prefetch_mode=None: Normal dataloader
        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
        return torch.utils.data.DataLoader(**dataloader_args)


def worker_init_fn(worker_id, num_workers, rank, seed):
    # Give every worker of every rank a distinct, deterministic seed:
    # num_workers * rank + worker_id + seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
# --------------------------------------------------------------------------
# /basicsr/data/data_sampler.py:
# --------------------------------------------------------------------------
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------

import math
import torch
from torch.utils.data.sampler import Sampler


class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler.
    Supports enlarging the dataset for iteration-based training, saving
    time when restarting the dataloader after each epoch.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = math.ceil(
            len(self.dataset) * ratio / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()

        # wrap indices beyond the real dataset length back into range
        dataset_size = len(self.dataset)
        indices = [v % dataset_size for v in indices]

        # subsample: every num_replicas-th index, offset by this rank
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
# --------------------------------------------------------------------------
# /basicsr/data/ffhq_dataset.py:
# --------------------------------------------------------------------------
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
from os import path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.transforms import augment
from basicsr.utils import FileClient, imfrombytes, img2tensor


class FFHQDataset(data.Dataset):
    """FFHQ dataset for StyleGAN.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwarg.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.

    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        # The file client is created lazily on first __getitem__ so that
        # each dataloader worker opens its own backend handle.
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']

        if self.io_backend_opt['type'] == 'lmdb':
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError("'dataroot_gt' should end with '.lmdb', "
                                 f'but received {self.gt_folder}')
            self.io_backend_opt['db_paths'] = self.gt_folder
            # Keys are the filenames listed in meta_info.txt, extension
            # stripped.
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as meta_file:
                self.paths = [line.split('.')[0] for line in meta_file]
        else:
            # FFHQ has 70000 images in total, named 00000000.png ...
            self.paths = [
                osp.join(self.gt_folder, f'{idx:08d}.png')
                for idx in range(70000)
            ]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
img_bytes = self.file_client.get(gt_path) 60 | img_gt = imfrombytes(img_bytes, float32=True) 61 | 62 | # random horizontal flip 63 | img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False) 64 | # BGR to RGB, HWC to CHW, numpy to tensor 65 | img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) 66 | # normalize 67 | normalize(img_gt, self.mean, self.std, inplace=True) 68 | return {'gt': img_gt, 'gt_path': gt_path} 69 | 70 | def __len__(self): 71 | return len(self.paths) 72 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDS4_test_GT.txt: -------------------------------------------------------------------------------- 1 | 000 100 (720,1280,3) 2 | 011 100 (720,1280,3) 3 | 015 100 (720,1280,3) 4 | 020 100 (720,1280,3) 5 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDS_GT.txt: -------------------------------------------------------------------------------- 1 | 000 100 (720,1280,3) 2 | 001 100 (720,1280,3) 3 | 002 100 (720,1280,3) 4 | 003 100 (720,1280,3) 5 | 004 100 (720,1280,3) 6 | 005 100 (720,1280,3) 7 | 006 100 (720,1280,3) 8 | 007 100 (720,1280,3) 9 | 008 100 (720,1280,3) 10 | 009 100 (720,1280,3) 11 | 010 100 (720,1280,3) 12 | 011 100 (720,1280,3) 13 | 012 100 (720,1280,3) 14 | 013 100 (720,1280,3) 15 | 014 100 (720,1280,3) 16 | 015 100 (720,1280,3) 17 | 016 100 (720,1280,3) 18 | 017 100 (720,1280,3) 19 | 018 100 (720,1280,3) 20 | 019 100 (720,1280,3) 21 | 020 100 (720,1280,3) 22 | 021 100 (720,1280,3) 23 | 022 100 (720,1280,3) 24 | 023 100 (720,1280,3) 25 | 024 100 (720,1280,3) 26 | 025 100 (720,1280,3) 27 | 026 100 (720,1280,3) 28 | 027 100 (720,1280,3) 29 | 028 100 (720,1280,3) 30 | 029 100 (720,1280,3) 31 | 030 100 (720,1280,3) 32 | 031 100 (720,1280,3) 33 | 032 100 (720,1280,3) 34 | 033 100 (720,1280,3) 35 | 034 100 (720,1280,3) 36 | 035 100 (720,1280,3) 37 | 036 100 (720,1280,3) 38 | 037 100 
(720,1280,3) 39 | 038 100 (720,1280,3) 40 | 039 100 (720,1280,3) 41 | 040 100 (720,1280,3) 42 | 041 100 (720,1280,3) 43 | 042 100 (720,1280,3) 44 | 043 100 (720,1280,3) 45 | 044 100 (720,1280,3) 46 | 045 100 (720,1280,3) 47 | 046 100 (720,1280,3) 48 | 047 100 (720,1280,3) 49 | 048 100 (720,1280,3) 50 | 049 100 (720,1280,3) 51 | 050 100 (720,1280,3) 52 | 051 100 (720,1280,3) 53 | 052 100 (720,1280,3) 54 | 053 100 (720,1280,3) 55 | 054 100 (720,1280,3) 56 | 055 100 (720,1280,3) 57 | 056 100 (720,1280,3) 58 | 057 100 (720,1280,3) 59 | 058 100 (720,1280,3) 60 | 059 100 (720,1280,3) 61 | 060 100 (720,1280,3) 62 | 061 100 (720,1280,3) 63 | 062 100 (720,1280,3) 64 | 063 100 (720,1280,3) 65 | 064 100 (720,1280,3) 66 | 065 100 (720,1280,3) 67 | 066 100 (720,1280,3) 68 | 067 100 (720,1280,3) 69 | 068 100 (720,1280,3) 70 | 069 100 (720,1280,3) 71 | 070 100 (720,1280,3) 72 | 071 100 (720,1280,3) 73 | 072 100 (720,1280,3) 74 | 073 100 (720,1280,3) 75 | 074 100 (720,1280,3) 76 | 075 100 (720,1280,3) 77 | 076 100 (720,1280,3) 78 | 077 100 (720,1280,3) 79 | 078 100 (720,1280,3) 80 | 079 100 (720,1280,3) 81 | 080 100 (720,1280,3) 82 | 081 100 (720,1280,3) 83 | 082 100 (720,1280,3) 84 | 083 100 (720,1280,3) 85 | 084 100 (720,1280,3) 86 | 085 100 (720,1280,3) 87 | 086 100 (720,1280,3) 88 | 087 100 (720,1280,3) 89 | 088 100 (720,1280,3) 90 | 089 100 (720,1280,3) 91 | 090 100 (720,1280,3) 92 | 091 100 (720,1280,3) 93 | 092 100 (720,1280,3) 94 | 093 100 (720,1280,3) 95 | 094 100 (720,1280,3) 96 | 095 100 (720,1280,3) 97 | 096 100 (720,1280,3) 98 | 097 100 (720,1280,3) 99 | 098 100 (720,1280,3) 100 | 099 100 (720,1280,3) 101 | 100 100 (720,1280,3) 102 | 101 100 (720,1280,3) 103 | 102 100 (720,1280,3) 104 | 103 100 (720,1280,3) 105 | 104 100 (720,1280,3) 106 | 105 100 (720,1280,3) 107 | 106 100 (720,1280,3) 108 | 107 100 (720,1280,3) 109 | 108 100 (720,1280,3) 110 | 109 100 (720,1280,3) 111 | 110 100 (720,1280,3) 112 | 111 100 (720,1280,3) 113 | 112 100 (720,1280,3) 114 | 113 100 
(720,1280,3) 115 | 114 100 (720,1280,3) 116 | 115 100 (720,1280,3) 117 | 116 100 (720,1280,3) 118 | 117 100 (720,1280,3) 119 | 118 100 (720,1280,3) 120 | 119 100 (720,1280,3) 121 | 120 100 (720,1280,3) 122 | 121 100 (720,1280,3) 123 | 122 100 (720,1280,3) 124 | 123 100 (720,1280,3) 125 | 124 100 (720,1280,3) 126 | 125 100 (720,1280,3) 127 | 126 100 (720,1280,3) 128 | 127 100 (720,1280,3) 129 | 128 100 (720,1280,3) 130 | 129 100 (720,1280,3) 131 | 130 100 (720,1280,3) 132 | 131 100 (720,1280,3) 133 | 132 100 (720,1280,3) 134 | 133 100 (720,1280,3) 135 | 134 100 (720,1280,3) 136 | 135 100 (720,1280,3) 137 | 136 100 (720,1280,3) 138 | 137 100 (720,1280,3) 139 | 138 100 (720,1280,3) 140 | 139 100 (720,1280,3) 141 | 140 100 (720,1280,3) 142 | 141 100 (720,1280,3) 143 | 142 100 (720,1280,3) 144 | 143 100 (720,1280,3) 145 | 144 100 (720,1280,3) 146 | 145 100 (720,1280,3) 147 | 146 100 (720,1280,3) 148 | 147 100 (720,1280,3) 149 | 148 100 (720,1280,3) 150 | 149 100 (720,1280,3) 151 | 150 100 (720,1280,3) 152 | 151 100 (720,1280,3) 153 | 152 100 (720,1280,3) 154 | 153 100 (720,1280,3) 155 | 154 100 (720,1280,3) 156 | 155 100 (720,1280,3) 157 | 156 100 (720,1280,3) 158 | 157 100 (720,1280,3) 159 | 158 100 (720,1280,3) 160 | 159 100 (720,1280,3) 161 | 160 100 (720,1280,3) 162 | 161 100 (720,1280,3) 163 | 162 100 (720,1280,3) 164 | 163 100 (720,1280,3) 165 | 164 100 (720,1280,3) 166 | 165 100 (720,1280,3) 167 | 166 100 (720,1280,3) 168 | 167 100 (720,1280,3) 169 | 168 100 (720,1280,3) 170 | 169 100 (720,1280,3) 171 | 170 100 (720,1280,3) 172 | 171 100 (720,1280,3) 173 | 172 100 (720,1280,3) 174 | 173 100 (720,1280,3) 175 | 174 100 (720,1280,3) 176 | 175 100 (720,1280,3) 177 | 176 100 (720,1280,3) 178 | 177 100 (720,1280,3) 179 | 178 100 (720,1280,3) 180 | 179 100 (720,1280,3) 181 | 180 100 (720,1280,3) 182 | 181 100 (720,1280,3) 183 | 182 100 (720,1280,3) 184 | 183 100 (720,1280,3) 185 | 184 100 (720,1280,3) 186 | 185 100 (720,1280,3) 187 | 186 100 (720,1280,3) 188 | 187 100 
(720,1280,3) 189 | 188 100 (720,1280,3) 190 | 189 100 (720,1280,3) 191 | 190 100 (720,1280,3) 192 | 191 100 (720,1280,3) 193 | 192 100 (720,1280,3) 194 | 193 100 (720,1280,3) 195 | 194 100 (720,1280,3) 196 | 195 100 (720,1280,3) 197 | 196 100 (720,1280,3) 198 | 197 100 (720,1280,3) 199 | 198 100 (720,1280,3) 200 | 199 100 (720,1280,3) 201 | 200 100 (720,1280,3) 202 | 201 100 (720,1280,3) 203 | 202 100 (720,1280,3) 204 | 203 100 (720,1280,3) 205 | 204 100 (720,1280,3) 206 | 205 100 (720,1280,3) 207 | 206 100 (720,1280,3) 208 | 207 100 (720,1280,3) 209 | 208 100 (720,1280,3) 210 | 209 100 (720,1280,3) 211 | 210 100 (720,1280,3) 212 | 211 100 (720,1280,3) 213 | 212 100 (720,1280,3) 214 | 213 100 (720,1280,3) 215 | 214 100 (720,1280,3) 216 | 215 100 (720,1280,3) 217 | 216 100 (720,1280,3) 218 | 217 100 (720,1280,3) 219 | 218 100 (720,1280,3) 220 | 219 100 (720,1280,3) 221 | 220 100 (720,1280,3) 222 | 221 100 (720,1280,3) 223 | 222 100 (720,1280,3) 224 | 223 100 (720,1280,3) 225 | 224 100 (720,1280,3) 226 | 225 100 (720,1280,3) 227 | 226 100 (720,1280,3) 228 | 227 100 (720,1280,3) 229 | 228 100 (720,1280,3) 230 | 229 100 (720,1280,3) 231 | 230 100 (720,1280,3) 232 | 231 100 (720,1280,3) 233 | 232 100 (720,1280,3) 234 | 233 100 (720,1280,3) 235 | 234 100 (720,1280,3) 236 | 235 100 (720,1280,3) 237 | 236 100 (720,1280,3) 238 | 237 100 (720,1280,3) 239 | 238 100 (720,1280,3) 240 | 239 100 (720,1280,3) 241 | 240 100 (720,1280,3) 242 | 241 100 (720,1280,3) 243 | 242 100 (720,1280,3) 244 | 243 100 (720,1280,3) 245 | 244 100 (720,1280,3) 246 | 245 100 (720,1280,3) 247 | 246 100 (720,1280,3) 248 | 247 100 (720,1280,3) 249 | 248 100 (720,1280,3) 250 | 249 100 (720,1280,3) 251 | 250 100 (720,1280,3) 252 | 251 100 (720,1280,3) 253 | 252 100 (720,1280,3) 254 | 253 100 (720,1280,3) 255 | 254 100 (720,1280,3) 256 | 255 100 (720,1280,3) 257 | 256 100 (720,1280,3) 258 | 257 100 (720,1280,3) 259 | 258 100 (720,1280,3) 260 | 259 100 (720,1280,3) 261 | 260 100 (720,1280,3) 262 | 261 100 
(720,1280,3) 263 | 262 100 (720,1280,3) 264 | 263 100 (720,1280,3) 265 | 264 100 (720,1280,3) 266 | 265 100 (720,1280,3) 267 | 266 100 (720,1280,3) 268 | 267 100 (720,1280,3) 269 | 268 100 (720,1280,3) 270 | 269 100 (720,1280,3) 271 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDSofficial4_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 246 100 (720,1280,3) 4 | 257 100 (720,1280,3) 5 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDSval_official_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 242 100 (720,1280,3) 4 | 243 100 (720,1280,3) 5 | 244 100 (720,1280,3) 6 | 245 100 (720,1280,3) 7 | 246 100 (720,1280,3) 8 | 247 100 (720,1280,3) 9 | 248 100 (720,1280,3) 10 | 249 100 (720,1280,3) 11 | 250 100 (720,1280,3) 12 | 251 100 (720,1280,3) 13 | 252 100 (720,1280,3) 14 | 253 100 (720,1280,3) 15 | 254 100 (720,1280,3) 16 | 255 100 (720,1280,3) 17 | 256 100 (720,1280,3) 18 | 257 100 (720,1280,3) 19 | 258 100 (720,1280,3) 20 | 259 100 (720,1280,3) 21 | 260 100 (720,1280,3) 22 | 261 100 (720,1280,3) 23 | 262 100 (720,1280,3) 24 | 263 100 (720,1280,3) 25 | 264 100 (720,1280,3) 26 | 265 100 (720,1280,3) 27 | 266 100 (720,1280,3) 28 | 267 100 (720,1280,3) 29 | 268 100 (720,1280,3) 30 | 269 100 (720,1280,3) 31 | -------------------------------------------------------------------------------- /basicsr/data/paired_image_dataset.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import (paired_paths_from_folder,
                                    paired_paths_from_lmdb,
                                    paired_paths_from_meta_info_file)
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor, padding


class PairedImageDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
    GT image pairs.

    There are three modes:
    1. 'lmdb': Use lmdb files.
        If opt['io_backend'] == lmdb.
    2. 'meta_info_file': Use meta information file to generate paths.
        If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
    3. 'folder': Scan folders to generate paths.
        The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the
                template excludes the file extension. Default: '{}'.
            gt_size (int): Cropped patch size for gt patches.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(PairedImageDataset, self).__init__()
        self.opt = opt
        # File client (io backend) is created lazily in __getitem__ so each
        # dataloader worker opens its own handle.
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt.get('mean')
        self.std = opt.get('std')

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        self.filename_tmpl = opt.get('filename_tmpl', '{}')

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb(
                [self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif self.opt.get('meta_info_file') is not None:
            self.paths = paired_paths_from_meta_info_file(
                [self.lq_folder, self.gt_folder], ['lq', 'gt'],
                self.opt['meta_info_file'], self.filename_tmpl)
        else:
            self.paths = paired_paths_from_folder(
                [self.lq_folder, self.gt_folder], ['lq', 'gt'],
                self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        try:
            img_gt = imfrombytes(img_bytes, float32=True)
        except Exception as e:
            # A bare `except:` here would also swallow KeyboardInterrupt /
            # SystemExit; catch Exception and chain the decoding error so
            # the root cause is preserved.
            raise Exception("gt path {} not working".format(gt_path)) from e

        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        try:
            img_lq = imfrombytes(img_bytes, float32=True)
        except Exception as e:
            raise Exception("lq path {} not working".format(lq_path)) from e

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # pad images smaller than the crop size before cropping
            img_gt, img_lq = padding(img_gt, img_lq, gt_size)

            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale,
                                                gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'],
                                     self.opt['use_rot'])

        # TODO: color space transform
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq],
                                    bgr2rgb=True,
                                    float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {
            'lq': img_lq,
            'gt': img_gt,
            'lq_path': lq_path,
            'gt_path': gt_path
        }

    def __len__(self):
        return len(self.paths)
# --------------------------------------------------------------------------
# /basicsr/data/prefetch_dataloader.py:
# --------------------------------------------------------------------------
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader


class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Ref:
    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        self.start()

    def run(self):
        # Items are enqueued as (exc, item) pairs so the consumer can tell
        # payloads, errors and the end marker apart. Previously a raw item
        # of None was indistinguishable from the end marker (truncating
        # iteration), and an exception in the generator silently killed
        # this thread, leaving __next__() blocked forever on queue.get().
        try:
            for item in self.generator:
                self.queue.put((None, item))
        except Exception as exc:
            # Re-raised in the consumer thread by __next__().
            self.queue.put((exc, None))
        self.queue.put(None)  # end-of-iteration marker

    def __next__(self):
        entry = self.queue.get()
        if entry is None:
            raise StopIteration
        exc, item = entry
        if exc is not None:
            raise exc
        return item

    def __iter__(self):
        return self


class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Ref:
    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super(PrefetchDataLoader, self).__init__(**kwargs)

    def __iter__(self):
        # Wrap the normal DataLoader iterator in a background prefetch
        # thread.
        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)


class CPUPrefetcher():
    """CPU prefetcher.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        """Return the next batch, or None when the loader is exhausted."""
        try:
            return next(self.loader)
        except StopIteration:
            return None

    def reset(self):
        """Restart iteration from the beginning of the original loader."""
        self.loader = iter(self.ori_loader)


class CUDAPrefetcher():
    """CUDA prefetcher.

    Ref:
    https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.stream = torch.cuda.Stream()
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.preload()

    def preload(self):
        """Fetch the next batch and start its async host-to-device copy."""
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to gpu on a side stream so the copy overlaps compute
        with torch.cuda.stream(self.stream):
            for k, v in self.batch.items():
                if torch.is_tensor(v):
                    self.batch[k] = self.batch[k].to(
                        device=self.device, non_blocking=True)

    def next(self):
        # wait until the async copy issued by preload() has finished
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        self.loader = iter(self.ori_loader)
        self.preload()
# --------------------------------------------------------------------------
# /basicsr/data/single_image_dataset.py:
# --------------------------------------------------------------------------
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
class SingleImageDataset(data.Dataset):
    """Read only lq images in the test phase.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).

    There are two modes:
    1. 'meta_info_file': Use meta information file to generate paths.
    2. 'folder': Scan folders to generate paths.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            mean (list | None, optional): Per-channel mean for normalization.
            std (list | None, optional): Per-channel std for normalization.
    """

    def __init__(self, opt):
        super(SingleImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend); created lazily in __getitem__ so it is
        # instantiated inside each dataloader worker process
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt.get('mean')
        self.std = opt.get('std')
        self.lq_folder = opt['dataroot_lq']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder]
            self.io_backend_opt['client_keys'] = ['lq']
            self.paths = paths_from_lmdb(self.lq_folder)
        elif 'meta_info_file' in self.opt:
            with open(self.opt['meta_info_file'], 'r') as fin:
                # strip() guards against the trailing newline leaking into
                # the path when a meta line has no extra columns (previously
                # `line.split(' ')[0]` kept the '\n' for such lines)
                self.paths = [
                    osp.join(self.lq_folder,
                             line.strip().split(' ')[0]) for line in fin
                ]
        else:
            self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))

    def __getitem__(self, index):
        """Load one LQ image.

        Returns:
            dict: {'lq': CHW float tensor in RGB, 'lq_path': str}
        """
        if self.file_client is None:
            # 'type' is popped so the remaining opts can be passed as kwargs
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load lq image
        lq_path = self.paths[index]
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # TODO: color space transform
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
        # normalize (in place) only when mean/std were configured
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
        return {'lq': img_lq, 'lq_path': lq_path}

    def __len__(self):
        return len(self.paths)
class Vimeo90KDataset(data.Dataset):
    """Vimeo90K dataset for training.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt

    Each line contains:
    1. clip name; 2. frame number; 3. image shape, separated by a white space.
    Examples:
        00001/0001 7 (256,448,3)
        00001/0002 7 (256,448,3)

    Key examples: "00001/0001"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    The neighboring frame list for different num_frame:
    num_frame | frame list
             1 | 4
             3 | 3,4,5
             5 | 2,3,4,5,6
             7 | 1,2,3,4,5,6,7

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.

            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patched size for gt patches.
            random_reverse (bool): Random reverse input frames.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale (bool): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(Vimeo90KDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(
            opt['dataroot_lq'])

        # Each meta line starts with the clip key (e.g. "00001/0001"); the
        # remaining columns (frame count, shape) are ignored here.
        with open(opt['meta_info_file'], 'r') as fin:
            self.keys = [line.split(' ')[0] for line in fin]

        # file client (io backend); created lazily in __getitem__ so it is
        # instantiated inside each dataloader worker process
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # indices of input images: a window of num_frame frames centered on
        # frame 4 of the 7-frame clip (e.g. num_frame=5 -> [2,3,4,5,6])
        self.neighbor_list = [
            i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])
        ]

        # temporal augmentation configs
        self.random_reverse = opt['random_reverse']
        logger = get_root_logger()
        logger.info(f'Random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            # 'type' is popped so the remaining opts can be passed as kwargs
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # random reverse
        # NOTE(review): this reverses self.neighbor_list in place, so the
        # temporal order persists across subsequent __getitem__ calls; each
        # call then has a 50% chance of flipping it again.
        if self.random_reverse and random.random() < 0.5:
            self.neighbor_list.reverse()

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the GT frame (im4.png) -- the center frame of the clip
        if self.is_lmdb:
            img_gt_path = f'{key}/im4'
        else:
            img_gt_path = self.gt_root / clip / seq / 'im4.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in self.neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{neighbor}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

        # randomly crop (the same spatial crop is applied to GT and all LQ
        # frames; img_gt_path is only used for error messages)
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale,
                                             img_gt_path)

        # augmentation - flip, rotate; GT is appended so it receives exactly
        # the same random transform as the LQ frames, then split off again
        img_lqs.append(img_gt)
        img_results = augment(img_lqs, self.opt['use_flip'],
                              self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        # img_lqs: (t, c, h, w)
        # img_gt: (c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)
def main():
    """Single-image restoration demo: read one image, run the model once,
    and write the restored result to disk.

    Input/output paths come from the parsed options under opt['img_path'].
    """
    # parse options, set distributed setting, set random seed
    opt = parse_options(is_train=False)
    opt['num_gpu'] = torch.cuda.device_count()

    img_path = opt['img_path'].get('input_img')
    output_path = opt['img_path'].get('output_img')

    ## 1. read image
    file_client = FileClient('disk')

    img_bytes = file_client.get(img_path, None)
    try:
        img = imfrombytes(img_bytes, float32=True)
    except Exception as e:
        # Previously a bare `except:`, which discarded the root cause and
        # also caught SystemExit/KeyboardInterrupt; chain the original error.
        raise Exception("path {} not working".format(img_path)) from e

    img = img2tensor(img, bgr2rgb=True, float32=True)

    ## 2. run inference
    opt['dist'] = False
    model = create_model(opt)

    model.feed_data(data={'lq': img.unsqueeze(dim=0)})

    # grids mode splits the image into overlapping patches for inference
    if model.opt['val'].get('grids', False):
        model.grids()

    model.test()

    if model.opt['val'].get('grids', False):
        model.grids_inverse()

    visuals = model.get_current_visuals()
    sr_img = tensor2img([visuals['result']])
    imwrite(sr_img, output_path)

    print(f'inference {img_path} .. finished. saved to {output_path}')

if __name__ == '__main__':
    main()


# ---------------- basicsr/demo_ssr.py ----------------

def parse_options(is_train=True):
    """Parse CLI arguments and the option YAML file for stereo SR inference.

    Adds the four stereo-specific path arguments, sets up (or disables)
    distributed execution, and seeds RNGs from `manual_seed`.

    Returns:
        dict: The option dict, with 'img_path' holding the four paths.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-opt', type=str, required=True, help='Path to option YAML file.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)

    parser.add_argument('--input_l_path', type=str, required=True, help='The path to the input left image. For stereo image inference only.')
    parser.add_argument('--input_r_path', type=str, required=True, help='The path to the input right image. For stereo image inference only.')
    parser.add_argument('--output_l_path', type=str, required=True, help='The path to the output left image. For stereo image inference only.')
    parser.add_argument('--output_r_path', type=str, required=True, help='The path to the output right image. For stereo image inference only.')

    args = parser.parse_args()
    opt = parse(args.opt, is_train=is_train)

    # distributed settings
    if args.launcher == 'none':
        opt['dist'] = False
        print('Disable distributed.', flush=True)
    else:
        opt['dist'] = True
        if args.launcher == 'slurm' and 'dist_params' in opt:
            init_dist(args.launcher, **opt['dist_params'])
        else:
            init_dist(args.launcher)
            print('init dist .. ', args.launcher)

    opt['rank'], opt['world_size'] = get_dist_info()

    # random seed (rank offset keeps per-process streams distinct)
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    set_random_seed(seed + opt['rank'])

    opt['img_path'] = {
        'input_l': args.input_l_path,
        'input_r': args.input_r_path,
        'output_l': args.output_l_path,
        'output_r': args.output_r_path
    }

    return opt

def imread(img_path):
    """Read an image from disk into a float CHW RGB tensor.

    Raises:
        Exception: if the file cannot be decoded.
    """
    file_client = FileClient('disk')
    img_bytes = file_client.get(img_path, None)
    try:
        img = imfrombytes(img_bytes, float32=True)
    except Exception as e:
        # Previously a bare `except:` (same defect as in demo.py); chain the
        # original decoding error so the real cause is visible.
        raise Exception("path {} not working".format(img_path)) from e

    img = img2tensor(img, bgr2rgb=True, float32=True)
    return img

def main():
    """Stereo SR demo: read the left/right views, run the model on the
    channel-concatenated pair, and write both restored views."""
    # parse options, set distributed setting, set random seed
    opt = parse_options(is_train=False)
    opt['num_gpu'] = torch.cuda.device_count()

    img_l_path = opt['img_path'].get('input_l')
    img_r_path = opt['img_path'].get('input_r')
    output_l_path = opt['img_path'].get('output_l')
    output_r_path = opt['img_path'].get('output_r')

    ## 1. read image
    img_l = imread(img_l_path)
    img_r = imread(img_r_path)
    # the stereo model consumes both views stacked along channels (6ch input)
    img = torch.cat([img_l, img_r], dim=0)

    ## 2. run inference
    opt['dist'] = False
    model = create_model(opt)

    model.feed_data(data={'lq': img.unsqueeze(dim=0)})

    if model.opt['val'].get('grids', False):
        model.grids()

    model.test()

    if model.opt['val'].get('grids', False):
        model.grids_inverse()

    visuals = model.get_current_visuals()
    # split the 6-channel result back into the two 3-channel views
    sr_img_l = visuals['result'][:, :3]
    sr_img_r = visuals['result'][:, 3:]
    sr_img_l, sr_img_r = tensor2img([sr_img_l, sr_img_r])
    imwrite(sr_img_l, output_l_path)
    imwrite(sr_img_r, output_r_path)

    print(f'inference {img_l_path} .. finished. saved to {output_l_path}')
    print(f'inference {img_r_path} .. finished. saved to {output_r_path}')

if __name__ == '__main__':
    main()
def load_patched_inception_v3(device='cuda',
                              resize_input=True,
                              normalize_input=False):
    """Load the FID InceptionV3 (block index 3 features) wrapped in
    DataParallel and set to eval mode on `device`.

    Returns:
        nn.Module: The wrapped inception network.
    """
    # we may not resize the input, but in [rosinality/stylegan2-pytorch] it
    # does resize the input.
    inception = InceptionV3([3],
                            resize_input=resize_input,
                            normalize_input=normalize_input)
    inception = nn.DataParallel(inception).eval().to(device)
    return inception


@torch.no_grad()
def extract_inception_features(data_generator,
                               inception,
                               len_generator=None,
                               device='cuda'):
    """Extract inception features.

    Args:
        data_generator (generator): A data generator.
        inception (nn.Module): Inception model.
        len_generator (int): Length of the data_generator to show the
            progressbar. Default: None.
        device (str): Device. Default: cuda.

    Returns:
        Tensor: Extracted features, concatenated over all batches (on CPU).
    """
    if len_generator is not None:
        pbar = tqdm(total=len_generator, unit='batch', desc='Extract')
    else:
        pbar = None
    features = []

    for data in data_generator:
        if pbar:
            pbar.update(1)
        data = data.to(device)
        # flatten the (N, C, 1, 1) feature map to (N, C)
        feature = inception(data)[0].view(data.shape[0], -1)
        # accumulate on CPU to keep GPU memory bounded
        features.append(feature.to('cpu'))
    if pbar:
        pbar.close()
    features = torch.cat(features, 0)
    return features


def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.

    Args:
        mu1 (np.array): The sample mean over activations.
        sigma1 (np.array): The covariance matrix over activations for
            generated samples.
        mu2 (np.array): The sample mean over activations, precalculated on a
            representative data set.
        sigma2 (np.array): The covariance matrix over activations,
            precalculated on a representative data set.
        eps (float): Diagonal offset used when the covariance product is
            near-singular. Default: 1e-6.

    Returns:
        float: The Frechet Distance.
    """
    assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, (
        'Two covariances have different dimensions')

    cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)

    # Product might be almost singular
    if not np.isfinite(cov_sqrt).all():
        # Fix: this was a plain string literal missing the f prefix, so
        # '{eps}' was printed verbatim instead of the actual value.
        print(f'Product of cov matrices is singular. Adding {eps} to diagonal '
              'of cov estimates')
        offset = np.eye(sigma1.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(cov_sqrt):
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))
            raise ValueError(f'Imaginary component {m}')
        cov_sqrt = cov_sqrt.real

    mean_diff = mu1 - mu2
    mean_norm = mean_diff @ mean_diff
    trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
    fid = mean_norm + trace

    return fid


# ---------------- basicsr/metrics/metric_util.py ----------------

def reorder_image(img, input_order='HWC'):
    """Reorder images to 'HWC' order.

    If the input_order is (h, w), return (h, w, 1);
    If the input_order is (c, h, w), return (h, w, c);
    If the input_order is (h, w, c), return as it is.

    Args:
        img (ndarray): Input image.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            If the input image shape is (h, w), input_order will not have
            effects. Default: 'HWC'.

    Returns:
        ndarray: reordered image.

    Raises:
        ValueError: If input_order is neither 'HWC' nor 'CHW'.
    """

    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            "'HWC' and 'CHW'")
    if len(img.shape) == 2:
        # grayscale (h, w) -> (h, w, 1)
        img = img[..., None]
    if input_order == 'CHW':
        img = img.transpose(1, 2, 0)
    return img


def to_y_channel(img):
    """Change to Y channel of YCbCr.

    Args:
        img (ndarray): Images with range [0, 255].

    Returns:
        (ndarray): Images with range [0, 255] (float type) without round.
    """
    img = img.astype(np.float32) / 255.
    if img.ndim == 3 and img.shape[2] == 3:
        # assumes BGR channel order (OpenCV convention) -- bgr2ycbcr
        img = bgr2ycbcr(img, y_only=True)
        img = img[..., None]
    return img * 255.
def create_model(opt):
    """Create model.

    Looks up `opt['model_type']` in every auto-discovered `*_model.py`
    module and instantiates the first matching class with `opt`.

    Args:
        opt (dict): Configuration. It contains:
            model_type (str): Model type (class name to instantiate).

    Returns:
        The instantiated model object.

    Raises:
        ValueError: If no scanned module defines a class named `model_type`.
    """
    model_type = opt['model_type']

    # dynamic instantiation
    # Initialize before the loop: if _model_modules is empty the original
    # code left model_cls unbound, turning the failure into a NameError
    # instead of the intended ValueError below.
    model_cls = None
    for module in _model_modules:
        model_cls = getattr(module, model_type, None)
        if model_cls is not None:
            break
    if model_cls is None:
        raise ValueError(f'Model {model_type} is not found.')

    model = model_cls(opt)

    logger = get_root_logger()
    logger.info(f'Model [{model.__class__.__name__}] is created.')
    return model
class SCAM(nn.Module):
    '''
    Stereo Cross Attention Module (SCAM).

    Exchanges information between the left and right view features via
    bidirectional cross attention computed along the width dimension.
    beta and gamma are initialized to zero, so the module starts out as an
    identity (pure residual) mapping.
    '''
    def __init__(self, c):
        super().__init__()
        # 1/sqrt(c) attention scaling
        self.scale = c ** -0.5

        self.norm_l = LayerNorm2d(c)
        self.norm_r = LayerNorm2d(c)
        self.l_proj1 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
        self.r_proj1 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)

        # learnable per-channel gates for the two cross-view residuals,
        # zero-initialized
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

        self.l_proj2 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
        self.r_proj2 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)

    def forward(self, x_l, x_r):
        # queries from normalized+projected features
        Q_l = self.l_proj1(self.norm_l(x_l)).permute(0, 2, 3, 1)  # B, H, W, c
        Q_r_T = self.r_proj1(self.norm_r(x_r)).permute(0, 2, 1, 3) # B, H, c, W (transposed)

        # values from the raw (un-normalized) features
        V_l = self.l_proj2(x_l).permute(0, 2, 3, 1)  # B, H, W, c
        V_r = self.r_proj2(x_r).permute(0, 2, 3, 1)  # B, H, W, c

        # (B, H, W, c) x (B, H, c, W) -> (B, H, W, W)
        # one attention map per image row, shared by both directions
        attention = torch.matmul(Q_l, Q_r_T) * self.scale

        F_r2l = torch.matmul(torch.softmax(attention, dim=-1), V_r)  #B, H, W, c
        F_l2r = torch.matmul(torch.softmax(attention.permute(0, 1, 3, 2), dim=-1), V_l) #B, H, W, c

        # scale by the learned gates and add back as residuals
        F_r2l = F_r2l.permute(0, 3, 1, 2) * self.beta
        F_l2r = F_l2r.permute(0, 3, 1, 2) * self.gamma
        return x_l + F_r2l, x_r + F_l2r

class DropPath(nn.Module):
    '''
    Stochastic-depth wrapper.

    During training, with probability `drop_rate` the wrapped module is
    skipped entirely (inputs returned unchanged); otherwise the module's
    residual contribution is rescaled by 1/(1 - drop_rate) so the
    expectation matches eval time. The drop decision is a single
    np.random draw per forward call (whole batch, not per sample).
    '''
    def __init__(self, drop_rate, module):
        super().__init__()
        self.drop_rate = drop_rate
        self.module = module

    def forward(self, *feats):
        if self.training and np.random.rand() < self.drop_rate:
            # skipped: identity on all inputs
            return feats

        new_feats = self.module(*feats)
        factor = 1. / (1 - self.drop_rate) if self.training else 1.

        # amplify only the residual (new_x - x), leaving the skip path as-is
        if self.training and factor != 1.:
            new_feats = tuple([x+factor*(new_x-x) for x, new_x in zip(feats, new_feats)])
        return new_feats

class NAFBlockSR(nn.Module):
    '''
    NAFBlock for Super-Resolution.

    Applies one NAFBlock independently to each view; when `fusion` is on,
    follows it with a SCAM cross-view fusion step.
    '''
    def __init__(self, c, fusion=False, drop_out_rate=0.):
        super().__init__()
        self.blk = NAFBlock(c, drop_out_rate=drop_out_rate)
        self.fusion = SCAM(c) if fusion else None

    def forward(self, *feats):
        # per-view processing with a shared (same-weights) NAFBlock
        feats = tuple([self.blk(x) for x in feats])
        if self.fusion:
            feats = self.fusion(*feats)
        return feats

class NAFNetSR(nn.Module):
    '''
    NAFNet for Super-Resolution.

    When `dual` is True the input holds both stereo views concatenated
    along channels; they are split, processed (with optional cross-view
    fusion between blocks fusion_from..fusion_to), re-concatenated, and a
    bilinear-upsampled copy of the input is added as a global skip.
    '''
    def __init__(self, up_scale=4, width=48, num_blks=16, img_channel=3, drop_path_rate=0., drop_out_rate=0., fusion_from=-1, fusion_to=-1, dual=False):
        super().__init__()
        self.dual = dual    # dual input for stereo SR (left view, right view)
        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                              bias=True)
        self.body = MySequential(
            *[DropPath(
                drop_path_rate,
                NAFBlockSR(
                    width,
                    fusion=(fusion_from <= i and i <= fusion_to),
                    drop_out_rate=drop_out_rate
                )) for i in range(num_blks)]
        )

        # 3x3 conv to img_channel * scale^2 followed by PixelShuffle upsampling
        self.up = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=img_channel * up_scale**2, kernel_size=3, padding=1, stride=1, groups=1, bias=True),
            nn.PixelShuffle(up_scale)
        )
        self.up_scale = up_scale

    def forward(self, inp):
        # global skip: bilinear upsampling of the (possibly dual) input
        inp_hr = F.interpolate(inp, scale_factor=self.up_scale, mode='bilinear')
        if self.dual:
            inp = inp.chunk(2, dim=1)
        else:
            inp = (inp, )
        feats = [self.intro(x) for x in inp]
        feats = self.body(*feats)
        out = torch.cat([self.up(x) for x in feats], dim=1)
        out = out + inp_hr
        return out

class NAFSSR(Local_Base, NAFNetSR):
    '''NAFNetSR configured for stereo SR, converted for test-time local
    statistics (TLSC): adaptive average pools are replaced so inference
    works on resolutions larger than the training crop.'''
    def __init__(self, *args, train_size=(1, 6, 30, 90), fast_imp=False, fusion_from=-1, fusion_to=1000, **kwargs):
        Local_Base.__init__(self)
        NAFNetSR.__init__(self, *args, img_channel=3, fusion_from=fusion_from, fusion_to=fusion_to, dual=True, **kwargs)

        N, C, H, W = train_size
        # local pooling window is based on 1.5x the training spatial size
        base_size = (int(H * 1.5), int(W * 1.5))

        self.eval()
        with torch.no_grad():
            # replace adaptive pools and run one dummy forward (see Local_Base)
            self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)

# Offline complexity profiling with ptflops (not executed on import).
if __name__ == '__main__':
    num_blks = 128
    width = 128
    droppath=0.1
    train_size = (1, 6, 30, 90)

    net = NAFSSR(up_scale=2,train_size=train_size, fast_imp=True, width=width, num_blks=num_blks, drop_path_rate=droppath)

    inp_shape = (6, 64, 64)

    from ptflops import get_model_complexity_info
    FLOPS = 0
    macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)

    # params = float(params[:-4])
    print(params)
    macs = float(macs[:-4]) + FLOPS / 10 ** 9

    print('mac', macs, params)

    # from basicsr.models.archs.arch_util import measure_inference_speed
    # net = net.cuda()
    # data = torch.randn((1, 6, 128, 128)).cuda()
    # measure_inference_speed(net, (data,))
def dynamic_instantiation(modules, cls_type, opt):
    """Dynamically instantiate class.

    Args:
        modules (list[importlib modules]): List of modules from importlib
            files.
        cls_type (str): Class type.
        opt (dict): Class initialization kwargs.

    Returns:
        class: Instantiated class.

    Raises:
        ValueError: If no module defines an attribute named `cls_type`.
    """
    # Initialize before the loop: if `modules` is empty the original code
    # left cls_ unbound, turning the failure into a NameError instead of
    # the intended ValueError below.
    cls_ = None
    for module in modules:
        cls_ = getattr(module, cls_type, None)
        if cls_ is not None:
            break
    if cls_ is None:
        raise ValueError(f'{cls_type} is not found.')
    return cls_(**opt)


def define_network(opt):
    """Build a network from its option dict; 'type' names the arch class
    and the remaining keys become constructor kwargs."""
    network_type = opt.pop('type')
    net = dynamic_instantiation(_arch_modules, network_type, opt)
    return net


# ---------------- basicsr/models/archs/local_arch.py ----------------

class AvgPool2d(nn.Module):
    """Local average pooling used by TLSC (test-time local statistics
    conversion). Replaces a global nn.AdaptiveAvgPool2d(1): instead of one
    global mean, it averages over a sliding local window whose size is
    derived from `base_size` scaled by input-size / train-size, so
    statistics at test time match the training crop.
    """

    def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.base_size = base_size
        # pad the pooled map back to the input size with replicate padding
        self.auto_pad = auto_pad

        # only used for fast implementation
        self.fast_imp = fast_imp
        self.rs = [5, 4, 3, 2, 1]  # candidate subsampling strides
        self.max_r1 = self.rs[0]
        self.max_r2 = self.rs[0]
        self.train_size = train_size

    def extra_repr(self) -> str:
        return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
            self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
        )

    def forward(self, x):
        # Lazily derive the kernel from base_size on first call, scaled by
        # the ratio of the actual input size to the training size.
        if self.kernel_size is None and self.base_size:
            train_size = self.train_size
            if isinstance(self.base_size, int):
                self.base_size = (self.base_size, self.base_size)
            self.kernel_size = list(self.base_size)
            self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
            self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]

            # only used for fast implementation
            self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
            self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])

        # window covers the whole map -> equivalent to global average pooling
        if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
            return F.adaptive_avg_pool2d(x, 1)

        if self.fast_imp:   # Non-equivalent implementation but faster
            h, w = x.shape[2:]
            if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
                out = F.adaptive_avg_pool2d(x, 1)
            else:
                # subsample by the largest stride that divides h/w evenly
                r1 = [r for r in self.rs if h % r == 0][0]
                r2 = [r for r in self.rs if w % r == 0][0]
                # reduction_constraint
                r1 = min(self.max_r1, r1)
                r2 = min(self.max_r2, r2)
                # 2D prefix sums on the subsampled map; window sums come
                # from four corner lookups (summed-area table)
                s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
                n, c, h, w = s.shape
                k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
                out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
                out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
        else:
            n, c, h, w = x.shape
            s = x.cumsum(dim=-1).cumsum_(dim=-2)
            s = torch.nn.functional.pad(s, (1, 0, 1, 0))  # pad 0 for convenience
            k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
            # exact sliding-window mean via the summed-area table
            s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
            out = s4 + s1 - s2 - s3
            out = out / (k1 * k2)

        if self.auto_pad:
            n, c, h, w = x.shape
            _h, _w = out.shape[2:]
            # print(x.shape, self.kernel_size)
            pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
            out = torch.nn.functional.pad(out, pad2d, mode='replicate')

        return out

def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
    """Recursively swap every nn.AdaptiveAvgPool2d(1) in `model` for the
    local AvgPool2d above (in place)."""
    for n, m in model.named_children():
        if len(list(m.children())) > 0:
            ## compound module, go inside it
            replace_layers(m, base_size, train_size, fast_imp, **kwargs)

        if isinstance(m, nn.AdaptiveAvgPool2d):
            pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
            # only the global-pooling case is supported
            assert m.output_size == 1
            setattr(model, n, pool)


'''
ref.
@article{chu2021tlsc,
  title={Revisiting Global Statistics Aggregation for Improving Image Restoration},
  author={Chu, Xiaojie and Chen, Liangyu and and Chen, Chengpeng and Lu, Xin},
  journal={arXiv preprint arXiv:2112.04491},
  year={2021}
}
'''
class Local_Base():
    """Mixin providing TLSC conversion for inference-time use."""

    def convert(self, *args, train_size, **kwargs):
        # replace global pools, then run one dummy forward at the training
        # size so the lazy kernel-size computation in AvgPool2d is primed
        replace_layers(self, *args, train_size=train_size, **kwargs)
        imgs = torch.rand(train_size)
        with torch.no_grad():
            self.forward(imgs)
import functools
from torch.nn import functional as F


def reduce_loss(loss, reduction):
    """Reduce an element-wise loss tensor.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): One of 'none', 'mean' and 'sum'.

    Returns:
        Tensor: Reduced loss tensor.
    """
    # F._Reduction maps: 'none' -> 0, 'mean' -> 1, 'sum' -> 2 (and validates).
    mode = F._Reduction.get_enum(reduction)
    if mode == 1:
        return loss.mean()
    if mode == 2:
        return loss.sum()
    return loss


def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply an element-wise weight, then reduce.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights. Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values.
    """
    if weight is not None:
        # Weight must broadcast over the channel dim (size 1) or match it.
        assert weight.dim() == loss.dim()
        assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # Unweighted, or 'sum': the plain reduction is already correct.
    if weight is None or reduction == 'sum':
        return reduce_loss(loss, reduction)

    # Weighted 'mean': normalize by the total weight instead of the count.
    if reduction == 'mean':
        if weight.size(1) > 1:
            denom = weight.sum()
        else:
            # Single-channel weight broadcasts over C channels.
            denom = weight.sum() * loss.size(1)
        loss = loss.sum() / denom

    return loss


def weighted_loss(loss_func):
    """Create a weighted version of a given element-wise loss function.

    ``loss_func`` must have the signature ``loss_func(pred, target,
    **kwargs)`` and return an unreduced element-wise loss. The decorated
    function gains ``weight=None`` and ``reduction='mean'`` arguments whose
    semantics follow :func:`weight_reduce_loss`.
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        elementwise = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(elementwise, weight, reduction)

    return wrapper
import torch
from torch import nn as nn
from torch.nn import functional as F
import numpy as np

from basicsr.models.losses.loss_util import weighted_loss

# Reduction modes accepted by the loss wrappers below.
_reduction_modes = ['none', 'mean', 'sum']


@weighted_loss
def l1_loss(pred, target):
    """Element-wise L1 distance; reduction is handled by the decorator."""
    return F.l1_loss(pred, target, reduction='none')


@weighted_loss
def mse_loss(pred, target):
    """Element-wise squared error; reduction is handled by the decorator."""
    return F.mse_loss(pred, target, reduction='none')


# @weighted_loss
# def charbonnier_loss(pred, target, eps=1e-12):
#     return torch.sqrt((pred - target)**2 + eps)


class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(L1Loss, self).__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        weighted = l1_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * weighted


class MSELoss(nn.Module):
    """MSE (L2) loss.

    Args:
        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(MSELoss, self).__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        weighted = mse_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * weighted
class PSNRLoss(nn.Module):
    """Negative-PSNR loss (in dB), averaged over the batch.

    Per image this computes ``loss_weight * (10 / ln 10) * ln(mse + 1e-8)``,
    i.e. ``-PSNR`` up to the 1e-8 stabilizer, so minimizing this loss
    maximizes PSNR.

    Args:
        loss_weight (float): Multiplier on the loss. Default: 1.0.
        reduction (str): Only 'mean' is supported.
        toY (bool): If True, convert RGB inputs in [0, 1] to the Y channel
            of YCbCr (BT.601 coefficients) before computing the loss.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
        super(PSNRLoss, self).__init__()
        assert reduction == 'mean'
        self.loss_weight = loss_weight
        # 10 / ln(10) converts a natural log into 10 * log10 (decibels).
        self.scale = 10 / np.log(10)
        self.toY = toY
        # RGB -> Y (BT.601) coefficients, shaped for NCHW broadcasting.
        # NOTE(review): kept as a plain attribute (not a registered buffer)
        # to preserve the original state_dict layout.
        self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
        self.first = True

    def forward(self, pred, target):
        """Compute the loss for NCHW ``pred`` / ``target`` batches."""
        assert len(pred.size()) == 4
        if self.toY:
            # Lazily move the conversion coefficients to the input's device on
            # the first call (assumes the device does not change afterwards).
            if self.first:
                self.coef = self.coef.to(pred.device)
                self.first = False

            pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
            target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
            pred, target = pred / 255., target / 255.

        # Fix: removed a dead `pass` statement and a duplicated shape assert
        # that followed the toY branch in the original.
        return self.loss_weight * self.scale * torch.log(
            ((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
num_gpu=opt['num_gpu'], 45 | dist=opt['dist'], 46 | sampler=None, 47 | seed=opt['manual_seed']) 48 | logger.info( 49 | f"Number of test images in {dataset_opt['name']}: {len(test_set)}") 50 | test_loaders.append(test_loader) 51 | 52 | # create model 53 | model = create_model(opt) 54 | 55 | for test_loader in test_loaders: 56 | test_set_name = test_loader.dataset.opt['name'] 57 | logger.info(f'Testing {test_set_name}...') 58 | rgb2bgr = opt['val'].get('rgb2bgr', True) 59 | # wheather use uint8 image to compute metrics 60 | use_image = opt['val'].get('use_image', True) 61 | model.validation( 62 | test_loader, 63 | current_iter=opt['name'], 64 | tb_logger=None, 65 | save_img=opt['val']['save_img'], 66 | rgb2bgr=rgb2bgr, use_image=use_image) 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | from .file_client import FileClient 8 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding 9 | from .logger import (MessageLogger, get_env_info, get_root_logger, 10 | init_tb_logger, init_wandb_logger) 11 | from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, 12 | scandir, scandir_SIDD, set_random_seed, sizeof_fmt) 13 | from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k) 14 | 15 | __all__ = [ 16 | # file_client.py 17 | 'FileClient', 18 | # img_util.py 19 | 'img2tensor', 20 | 'tensor2img', 21 | 'imfrombytes', 22 | 'imwrite', 23 | 'crop_border', 24 | # logger.py 25 | 'MessageLogger', 26 | 'init_tb_logger', 27 | 'init_wandb_logger', 28 | 'get_root_logger', 29 | 'get_env_info', 30 | # misc.py 31 | 'set_random_seed', 32 | 'get_time_str', 33 | 'mkdir_and_rename', 34 | 'make_exp_dirs', 35 | 'scandir', 36 | 'scandir_SIDD', 37 | 'check_resume', 38 | 'sizeof_fmt', 39 | 'padding', 40 | 'create_lmdb_for_reds', 41 | 'create_lmdb_for_gopro', 42 | 'create_lmdb_for_rain13k', 43 | ] 44 | -------------------------------------------------------------------------------- /basicsr/utils/create_lmdb.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
def prepare_keys(folder_path, suffix='png'):
    """Prepare the image path list and lmdb keys for a folder of images.

    Fix: the original docstring claimed this was DIV2K-specific, but the
    helper is shared by the REDS/GoPro/Rain13k/SIDD lmdb builders below.

    Args:
        folder_path (str): Folder path.
        suffix (str): Image file suffix (extension without the dot).
            Default: 'png'.

    Returns:
        list[str]: Sorted image path list.
        list[str]: Key list (each path with its '.suffix' tail stripped).
    """
    print('Reading image path list ...')
    img_path_list = sorted(
        list(scandir(folder_path, suffix=suffix, recursive=False)))
    # img_path_list is already sorted; the original re-sorted it redundantly
    # inside the comprehension.
    keys = [img_path.split('.{}'.format(suffix))[0] for img_path in img_path_list]

    return img_path_list, keys


def create_lmdb_for_reds():
    """Create lmdb files for the REDS training set (sharp + blur_jpeg).

    The commented-out section builds the 300-image validation lmdbs and is
    kept as an optional manual step.
    """
    # folder_path = './datasets/REDS/val/sharp_300'
    # lmdb_path = './datasets/REDS/val/sharp_300.lmdb'
    # img_path_list, keys = prepare_keys(folder_path, 'png')
    # make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
    #
    # folder_path = './datasets/REDS/val/blur_300'
    # lmdb_path = './datasets/REDS/val/blur_300.lmdb'
    # img_path_list, keys = prepare_keys(folder_path, 'jpg')
    # make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)

    folder_path = './datasets/REDS/train/train_sharp'
    lmdb_path = './datasets/REDS/train/train_sharp.lmdb'
    img_path_list, keys = prepare_keys(folder_path, 'png')
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)

    folder_path = './datasets/REDS/train/train_blur_jpeg'
    lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb'
    img_path_list, keys = prepare_keys(folder_path, 'jpg')
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
def create_lmdb_for_rain13k():
    """Create lmdb files for the Rain13k training set (input + target)."""
    pairs = (
        ('./datasets/Rain13k/train/input', './datasets/Rain13k/train/input.lmdb'),
        ('./datasets/Rain13k/train/target', './datasets/Rain13k/train/target.lmdb'),
    )
    for folder_path, lmdb_path in pairs:
        img_path_list, keys = prepare_keys(folder_path, 'jpg')
        make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
prepare_keys(folder_path, 'PNG') 101 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 102 | 103 | #for val 104 | ''' 105 | 106 | folder_path = './datasets/SIDD/val/input_crops' 107 | lmdb_path = './datasets/SIDD/val/input_crops.lmdb' 108 | mat_path = './datasets/SIDD/ValidationNoisyBlocksSrgb.mat' 109 | if not osp.exists(folder_path): 110 | os.makedirs(folder_path) 111 | assert osp.exists(mat_path) 112 | data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb'] 113 | N, B, H ,W, C = data.shape 114 | data = data.reshape(N*B, H, W, C) 115 | for i in tqdm(range(N*B)): 116 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 117 | img_path_list, keys = prepare_keys(folder_path, 'png') 118 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 119 | 120 | folder_path = './datasets/SIDD/val/gt_crops' 121 | lmdb_path = './datasets/SIDD/val/gt_crops.lmdb' 122 | mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat' 123 | if not osp.exists(folder_path): 124 | os.makedirs(folder_path) 125 | assert osp.exists(mat_path) 126 | data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb'] 127 | N, B, H ,W, C = data.shape 128 | data = data.reshape(N*B, H, W, C) 129 | for i in tqdm(range(N*B)): 130 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 131 | img_path_list, keys = prepare_keys(folder_path, 'png') 132 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 133 | ''' 134 | -------------------------------------------------------------------------------- /basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
def init_dist(launcher, backend='nccl', **kwargs):
    """Initialize distributed training for the given launcher.

    Args:
        launcher (str): Either 'pytorch' or 'slurm'.
        backend (str): torch.distributed backend. Default: 'nccl'.
    """
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    dispatch = {
        'pytorch': _init_dist_pytorch,
        'slurm': _init_dist_slurm,
    }
    if launcher not in dispatch:
        raise ValueError(f'Invalid launcher type: {launcher}')
    dispatch[launcher](backend, **kwargs)


def _init_dist_pytorch(backend, **kwargs):
    # The torch launcher exports RANK; pin each process to one local GPU.
    rank = int(os.environ['RANK'])
    torch.cuda.set_device(rank % torch.cuda.device_count())
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_slurm(backend, port=None):
    """Initialize the slurm distributed training environment.

    If ``port`` is not given, the ``MASTER_PORT`` environment variable is
    used; if that is also unset, the torch.distributed default 29500 applies.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    # The first hostname in the allocation acts as the rendezvous master.
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' not in os.environ:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)


def get_dist_info():
    """Return ``(rank, world_size)``; ``(0, 1)`` when not distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1


def master_only(func):
    """Decorator: run ``func`` only on rank 0 (other ranks return None)."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper
def download_file_from_google_drive(file_id, save_path):
    """Download files from google drive.

    Ref:
    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive  # noqa E501

    Args:
        file_id (str): File id.
        save_path (str): Save path.
    """

    session = requests.Session()
    URL = 'https://docs.google.com/uc?export=download'
    params = {'id': file_id}

    response = session.get(URL, params=params, stream=True)
    # Large files require confirming the "can't scan for viruses" warning.
    token = get_confirm_token(response)
    if token:
        params['confirm'] = token
        response = session.get(URL, params=params, stream=True)

    # Probe the total size with a ranged request; Content-Range ends with
    # "/<total>".
    response_file_size = session.get(
        URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
    if 'Content-Range' in response_file_size.headers:
        file_size = int(
            response_file_size.headers['Content-Range'].split('/')[1])
    else:
        file_size = None

    save_response_content(response, save_path, file_size)


def get_confirm_token(response):
    """Return the google-drive 'download_warning' cookie value, if present."""
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None


def save_response_content(response,
                          destination,
                          file_size=None,
                          chunk_size=32768):
    """Stream the response body to ``destination``, with optional progress.

    Fix: the original incremented the running byte counter by the nominal
    ``chunk_size`` for every chunk — including keep-alive (empty) chunks and
    the short final chunk — over-reporting progress. Count the bytes
    actually written instead.

    Args:
        response: A streaming ``requests`` response.
        destination (str): Output file path.
        file_size (int, optional): Total size in bytes; enables the tqdm bar.
        chunk_size (int): Read granularity in bytes. Default: 32768.
    """
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')

        readable_file_size = sizeof_fmt(file_size)
    else:
        pbar = None

    with open(destination, 'wb') as f:
        downloaded_size = 0
        for chunk in response.iter_content(chunk_size):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                downloaded_size += len(chunk)
            if pbar is not None:
                pbar.update(1)
                pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
                                     f'/ {readable_file_size}')
    if pbar is not None:
        pbar.close()
46 | """ 47 | with open(opt_path, mode='r') as f: 48 | Loader, _ = ordered_yaml() 49 | opt = yaml.load(f, Loader=Loader) 50 | 51 | opt['is_train'] = is_train 52 | 53 | # datasets 54 | if 'datasets' in opt: 55 | for phase, dataset in opt['datasets'].items(): 56 | # for several datasets, e.g., test_1, test_2 57 | phase = phase.split('_')[0] 58 | dataset['phase'] = phase 59 | if 'scale' in opt: 60 | dataset['scale'] = opt['scale'] 61 | if dataset.get('dataroot_gt') is not None: 62 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 63 | if dataset.get('dataroot_lq') is not None: 64 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 65 | 66 | # paths 67 | for key, val in opt['path'].items(): 68 | if (val is not None) and ('resume_state' in key 69 | or 'pretrain_network' in key): 70 | opt['path'][key] = osp.expanduser(val) 71 | opt['path']['root'] = osp.abspath( 72 | osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 73 | if is_train: 74 | experiments_root = osp.join(opt['path']['root'], 'experiments', 75 | opt['name']) 76 | opt['path']['experiments_root'] = experiments_root 77 | opt['path']['models'] = osp.join(experiments_root, 'models') 78 | opt['path']['training_states'] = osp.join(experiments_root, 79 | 'training_states') 80 | opt['path']['log'] = experiments_root 81 | opt['path']['visualization'] = osp.join(experiments_root, 82 | 'visualization') 83 | 84 | # change some options for debug mode 85 | if 'debug' in opt['name']: 86 | if 'val' in opt: 87 | opt['val']['val_freq'] = 8 88 | opt['logger']['print_freq'] = 1 89 | opt['logger']['save_checkpoint_freq'] = 8 90 | else: # test 91 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 92 | opt['path']['results_root'] = results_root 93 | opt['path']['log'] = results_root 94 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 95 | 96 | return opt 97 | 98 | 99 | def dict2str(opt, indent_level=1): 100 | """dict to string for printing options. 
101 | 102 | Args: 103 | opt (dict): Option dict. 104 | indent_level (int): Indent level. Default: 1. 105 | 106 | Return: 107 | (str): Option string for printing. 108 | """ 109 | msg = '\n' 110 | for k, v in opt.items(): 111 | if isinstance(v, dict): 112 | msg += ' ' * (indent_level * 2) + k + ':[' 113 | msg += dict2str(v, indent_level + 1) 114 | msg += ' ' * (indent_level * 2) + ']\n' 115 | else: 116 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 117 | return msg 118 | -------------------------------------------------------------------------------- /basicsr/version.py: -------------------------------------------------------------------------------- 1 | # GENERATED VERSION FILE 2 | # TIME: Mon Apr 18 21:35:20 2022 3 | __version__ = '1.2.0+386ca20' 4 | short_version = '1.2.0' 5 | version_info = (1, 2, 0) 6 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | cuda: "11.3" 3 | gpu: true 4 | python_version: "3.9" 5 | system_packages: 6 | - "libgl1-mesa-glx" 7 | - "libglib2.0-0" 8 | python_packages: 9 | - "numpy==1.21.1" 10 | - "ipython==7.21.0" 11 | - "addict==2.4.0" 12 | - "future==0.18.2" 13 | - "lmdb==1.3.0" 14 | - "opencv-python==4.5.5.64" 15 | - "Pillow==9.1.0" 16 | - "pyyaml==6.0" 17 | - "torch==1.11.0" 18 | - "torchvision==0.12.0" 19 | - "tqdm==4.64.0" 20 | - "scipy==1.8.0" 21 | - "scikit-image==0.19.2" 22 | - "matplotlib==3.5.1" 23 | 24 | predict: "predict.py:Predictor" 25 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | ### Data Preparation 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /demo/blurry.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/megvii-research/NAFNet/2b4af71ebe098a92a75910c233a3965a3e93ede4/demo/blurry.jpg -------------------------------------------------------------------------------- /demo/denoise_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/NAFNet/2b4af71ebe098a92a75910c233a3965a3e93ede4/demo/denoise_img.png -------------------------------------------------------------------------------- /demo/lr_img_l.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/NAFNet/2b4af71ebe098a92a75910c233a3965a3e93ede4/demo/lr_img_l.png -------------------------------------------------------------------------------- /demo/lr_img_r.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/NAFNet/2b4af71ebe098a92a75910c233a3965a3e93ede4/demo/lr_img_r.png -------------------------------------------------------------------------------- /demo/noisy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/NAFNet/2b4af71ebe098a92a75910c233a3965a3e93ede4/demo/noisy.png -------------------------------------------------------------------------------- /demo/sr_img_l.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/NAFNet/2b4af71ebe098a92a75910c233a3965a3e93ede4/demo/sr_img_l.png -------------------------------------------------------------------------------- /demo/sr_img_r.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/NAFNet/2b4af71ebe098a92a75910c233a3965a3e93ede4/demo/sr_img_r.png -------------------------------------------------------------------------------- /docs/GoPro.md: 
-------------------------------------------------------------------------------- 1 | # reproduce the GoPro dataset results 2 | 3 | 4 | 5 | ### 1. Data Preparation 6 | 7 | ##### Download the train set and place it in ```./datasets/GoPro/train```: 8 | 9 | * [google drive](https://drive.google.com/file/d/1zgALzrLCC_tcXKu_iHQTHukKUVT1aodI/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1fdsn-M5JhxCL7oThEgt1Sw?pwd=9d26) 10 | * it should be like ```./datasets/GoPro/train/input ``` and ```./datasets/GoPro/train/target``` 11 | * ```python scripts/data_preparation/gopro.py``` to crop the train image pairs to 512x512 patches and make the data into lmdb format. 12 | 13 | ##### Download the evaluation data (in lmdb format) and place it in ```./datasets/GoPro/test/```: 14 | 15 | * [google drive](https://drive.google.com/file/d/1abXSfeRGrzj2mQ2n2vIBHtObU6vXvr7C/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1oZtEtYB7-2p3fCIspky_mw?pwd=rmv9) 16 | * it should be like ```./datasets/GoPro/test/input.lmdb``` and ```./datasets/GoPro/test/target.lmdb``` 17 | 18 | 19 | 20 | ### 2. Training 21 | 22 | * NAFNet-GoPro-width32: 23 | 24 | ``` 25 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/GoPro/NAFNet-width32.yml --launcher pytorch 26 | ``` 27 | 28 | * NAFNet-GoPro-width64: 29 | 30 | ``` 31 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/GoPro/NAFNet-width64.yml --launcher pytorch 32 | ``` 33 | 34 | * Baseline-GoPro-width32: 35 | 36 | ``` 37 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/GoPro/Baseline-width32.yml --launcher pytorch 38 | ``` 39 | 40 | * Baseline-GoPro-width64: 41 | 42 | ``` 43 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/GoPro/Baseline-width64.yml --launcher pytorch 44 | ``` 45 | 46 | * 8 gpus by default. 
Set ```--nproc_per_node``` to # of gpus for distributed validation. 47 | 48 | 49 | 50 | 51 | ### 3. Evaluation 52 | 53 | 54 | ##### Download the pretrain model in ```./experiments/pretrained_models/``` 55 | * **NAFNet-GoPro-width32**: [google drive](https://drive.google.com/file/d/1Fr2QadtDCEXg6iwWX8OzeZLbHOx2t5Bj/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1AbgG0yoROHmrRQN7dgzDvQ?pwd=so6v) 56 | * **NAFNet-GoPro-width64**: [google drive](https://drive.google.com/file/d/1S0PVRbyTakYY9a82kujgZLbMihfNBLfC/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1g-E1x6En-PbYXm94JfI1vg?pwd=wnwh) 57 | * **Baseline-GoPro-width32**: [google drive](https://drive.google.com/file/d/14z7CxRzVkYEhFgsZg79GlPTEr3VFIGyl/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1WnFKYTAQyAQ9XuD5nlHw_Q?pwd=oieh) 58 | * **Baseline-GoPro-width64**: [google drive](https://drive.google.com/file/d/1yy0oPNJjJxfaEmO0pfPW_TpeoCotYkuO/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1Fqi2T4nyF_wo4wh1QpgIGg?pwd=we36) 59 | 60 | 61 | 62 | ##### Testing on GoPro dataset 63 | 64 | * NAFNet-GoPro-width32: 65 | ``` 66 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/GoPro/NAFNet-width32.yml --launcher pytorch 67 | ``` 68 | 69 | * NAFNet-GoPro-width64: 70 | ``` 71 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/GoPro/NAFNet-width64.yml --launcher pytorch 72 | ``` 73 | 74 | * Baseline-GoPro-width32: 75 | ``` 76 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/GoPro/Baseline-width32.yml --launcher pytorch 77 | ``` 78 | 79 | * Baseline-GoPro-width64: 80 | ``` 81 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/GoPro/Baseline-width64.yml --launcher pytorch 82 | ``` 83 | 84 | * Test by a single gpu by default. 
Set ```--nproc_per_node``` to # of gpus for distributed validation. 85 | 86 | -------------------------------------------------------------------------------- /docs/REDS.md: -------------------------------------------------------------------------------- 1 | # reproduce the REDS dataset results 2 | 3 | 4 | 5 | ### 1. Data Preparation 6 | 7 | ##### Download the train set and place it in ```./datasets/REDS/train```: 8 | 9 | * google drive ([link](https://drive.google.com/file/d/1VTXyhwrTgcaUWklG-6Dh4MyCmYvX39mW/view) and [link](https://drive.google.com/file/d/1YLksKtMhd2mWyVSkvhDaDLWSc1qYNCz-/view)) or SNU CVLab Server ([link](http://data.cv.snu.ac.kr:8008/webdav/dataset/REDS/train_blur_jpeg.zip) and [link](http://data.cv.snu.ac.kr:8008/webdav/dataset/REDS/train_sharp.zip)) 10 | * it should be like ```./datasets/REDS/train/train_blur_jpeg ``` and ```./datasets/REDS/train/train_sharp``` 11 | * ```python scripts/data_preparation/reds.py``` to make the data into lmdb format. 12 | 13 | ##### Download the evaluation data (in lmdb format) and place it in ```./datasets/REDS/val/```: 14 | 15 | * [google drive](https://drive.google.com/file/d/1_WPxX6mDSzdyigvie_OlpI-Dknz7RHKh/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1yUGdGFHQGCB5LZKt9dVecw?pwd=ikki), 16 | * it should be like ```./datasets/REDS/val/blur_300.lmdb``` and ```./datasets/REDS/val/sharp_300.lmdb``` 17 | 18 | 19 | 20 | ### 2. Training 21 | 22 | * NAFNet-REDS-width64: 23 | 24 | ``` 25 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/REDS/NAFNet-width64.yml --launcher pytorch 26 | ``` 27 | 28 | * 8 gpus by default. Set ```--nproc_per_node``` to # of gpus for distributed validation. 29 | 30 | 31 | 32 | 33 | ### 3. 
Evaluation 34 | 35 | 36 | ##### Download the pretrain model in ```./experiments/pretrained_models/``` 37 | * **NAFNet-REDS-width64**: [google drive](https://drive.google.com/file/d/14D4V4raNYIOhETfcuuLI3bGLB-OYIv6X/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1vg89ccbpIxg3mK9IONBfGg?pwd=9fas) 38 | 39 | 40 | 41 | ##### Testing on REDS dataset 42 | 43 | * NAFNet-REDS-width64: 44 | ``` 45 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/REDS/NAFNet-width64.yml --launcher pytorch 46 | ``` 47 | 48 | * Test by a single gpu by default. Set ```--nproc_per_node``` to # of gpus for distributed validation. 49 | 50 | -------------------------------------------------------------------------------- /docs/SIDD.md: -------------------------------------------------------------------------------- 1 | # reproduce the SIDD dataset results 2 | 3 | 4 | 5 | ### 1. Data Preparation 6 | 7 | ##### Download the train set and place it in ```./datasets/SIDD/Data```: 8 | 9 | * [google drive](https://drive.google.com/file/d/1UHjWZzLPGweA9ZczmV8lFSRcIxqiOVJw/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1EnBVjrfFBiXIRPBgjFrifg?pwd=sl6h), 10 | * ```python scripts/data_preparation/sidd.py``` to crop the train image pairs to 512x512 patches and make the data into lmdb format. 11 | 12 | ##### Download the evaluation data (in lmdb format) and place it in ```./datasets/SIDD/val/```: 13 | 14 | * [google drive](https://drive.google.com/file/d/1gZx_K2vmiHalRNOb1aj93KuUQ2guOlLp/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1I9N5fDa4SNP0nuHEy6k-rw?pwd=59d7), 15 | * it should be like ```./datasets/SIDD/val/input_crops.lmdb``` and ```./datasets/SIDD/val/gt_crops.lmdb``` 16 | 17 | 18 | 19 | ### 2. 
Training 20 | 21 | * NAFNet-SIDD-width32: 22 | 23 | ``` 24 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/SIDD/NAFNet-width32.yml --launcher pytorch 25 | ``` 26 | 27 | * NAFNet-SIDD-width64: 28 | 29 | ``` 30 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/SIDD/NAFNet-width64.yml --launcher pytorch 31 | ``` 32 | 33 | * Baseline-SIDD-width32: 34 | 35 | ``` 36 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/SIDD/Baseline-width32.yml --launcher pytorch 37 | ``` 38 | 39 | * Baseline-SIDD-width64: 40 | 41 | ``` 42 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/SIDD/Baseline-width64.yml --launcher pytorch 43 | ``` 44 | 45 | * 8 gpus by default. Set ```--nproc_per_node``` to # of gpus for distributed validation. 46 | 47 | 48 | 49 | 50 | ### 3. Evaluation 51 | 52 | 53 | ##### Download the pretrain model in ```./experiments/pretrained_models/``` 54 | 55 | * **NAFNet-SIDD-width32**: [google drive](https://drive.google.com/file/d/1lsByk21Xw-6aW7epCwOQxvm6HYCQZPHZ/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1Xses38SWl-7wuyuhaGNhaw?pwd=um97) 56 | 57 | * **NAFNet-SIDD-width64**: [google drive](https://drive.google.com/file/d/14Fht1QQJ2gMlk4N1ERCRuElg8JfjrWWR/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/198kYyVSrY_xZF0jGv9U0sQ?pwd=dton) 58 | 59 | * **Baseline-SIDD-width32**: [google drive](https://drive.google.com/file/d/1NhqVcqkDcYvYgF_P4BOOfo9tuTcKDuhW/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1wkskmCRKhXq6dGa6Ns8D0A?pwd=0rin) 60 | 61 | * **Baseline-SIDD-width64**: [google drive](https://drive.google.com/file/d/1wQ1HHHPhSp70_ledMBZhDhIGjZQs16wO/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1ivruGfSRGfWq5AEB8qc7YQ?pwd=t9w8) 62 | 63 | 64 | ##### Testing on SIDD dataset 65 | 66 | * NAFNet-SIDD-width32: 
67 | 68 | ``` 69 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/NAFNet-width32.yml --launcher pytorch 70 | ``` 71 | 72 | * NAFNet-SIDD-width64: 73 | 74 | ``` 75 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/NAFNet-width64.yml --launcher pytorch 76 | ``` 77 | 78 | * Baseline-SIDD-width32: 79 | 80 | ``` 81 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/Baseline-width32.yml --launcher pytorch 82 | ``` 83 | 84 | * Baseline-SIDD-width64: 85 | 86 | ``` 87 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/Baseline-width64.yml --launcher pytorch 88 | ``` 89 | 90 | * Test by a single gpu by default. Set ```--nproc_per_node``` to # of gpus for distributed validation. 91 | 92 | -------------------------------------------------------------------------------- /experiments/pretrained_models/README.md: -------------------------------------------------------------------------------- 1 | ### Pretrained NAFNet Models 2 | --- 3 | 4 | please refer to https://github.com/megvii-research/NAFNet/#results-and-pre-trained-models, and download the pretrained models into ./experiments/pretrained_models 5 | -------------------------------------------------------------------------------- /figures/NAFSSR_arch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/NAFNet/2b4af71ebe098a92a75910c233a3965a3e93ede4/figures/NAFSSR_arch.jpg -------------------------------------------------------------------------------- /figures/NAFSSR_params.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/NAFNet/2b4af71ebe098a92a75910c233a3965a3e93ede4/figures/NAFSSR_params.jpg 
-------------------------------------------------------------------------------- /figures/PSNR_vs_MACs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/NAFNet/2b4af71ebe098a92a75910c233a3965a3e93ede4/figures/PSNR_vs_MACs.jpg -------------------------------------------------------------------------------- /figures/StereoSR.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/NAFNet/2b4af71ebe098a92a75910c233a3965a3e93ede4/figures/StereoSR.gif -------------------------------------------------------------------------------- /figures/deblur.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/NAFNet/2b4af71ebe098a92a75910c233a3965a3e93ede4/figures/deblur.gif -------------------------------------------------------------------------------- /figures/denoise.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/NAFNet/2b4af71ebe098a92a75910c233a3965a3e93ede4/figures/denoise.gif -------------------------------------------------------------------------------- /options/test/GoPro/Baseline-width32.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: Baseline-GoPro-width32-test 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 1 # set num_gpu: 0 for cpu mode 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | 17 | test: 18 | name: gopro-test 19 | type: PairedImageDataset 20 | 21 | dataroot_gt: ./datasets/GoPro/test/target.lmdb 22 | dataroot_lq: ./datasets/GoPro/test/input.lmdb 23 | 24 | io_backend: 25 | type: lmdb 26 | 27 | # network structures 28 | network_g: 29 | type: BaselineLocal 30 | width: 32 31 | enc_blk_nums: [1, 1, 1, 28] 32 | middle_blk_num: 1 33 | dec_blk_nums: [1, 1, 1, 1] 34 | dw_expand: 1 35 | ffn_expand: 2 36 | 37 | # path 38 | path: 39 | pretrain_network_g: experiments/pretrained_models/Baseline-GoPro-width32.pth 40 | strict_load_g: true 41 | resume_state: ~ 42 | 43 | # validation settings 44 | val: 45 | save_img: true 46 | grids: false 47 | 48 | 49 | metrics: 50 | psnr: # metric name, can be arbitrary 51 | type: calculate_psnr 52 | crop_border: 0 53 | test_y_channel: false 54 | ssim: 55 | type: calculate_ssim 56 | crop_border: 0 57 | test_y_channel: false 58 | 59 | # dist training settings 60 | dist_params: 61 | backend: nccl 62 | port: 29500 63 | -------------------------------------------------------------------------------- /options/test/GoPro/Baseline-width64.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: Baseline-GoPro-width64-test 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 1 # set num_gpu: 0 for cpu mode 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | 17 | test: 18 | name: gopro-test 19 | type: PairedImageDataset 20 | 21 | dataroot_gt: ./datasets/GoPro/test/target.lmdb 22 | dataroot_lq: ./datasets/GoPro/test/input.lmdb 23 | 24 | io_backend: 25 | type: lmdb 26 | 27 | # network structures 28 | network_g: 29 | type: BaselineLocal 30 | width: 64 31 | enc_blk_nums: [1, 1, 1, 28] 32 | middle_blk_num: 1 33 | dec_blk_nums: [1, 1, 1, 1] 34 | dw_expand: 2 35 | ffn_expand: 2 36 | 37 | # path 38 | path: 39 | pretrain_network_g: experiments/pretrained_models/Baseline-GoPro-width64.pth 40 | strict_load_g: true 41 | resume_state: ~ 42 | 43 | # validation settings 44 | val: 45 | save_img: true 46 | grids: false 47 | 48 | 49 | metrics: 50 | psnr: # metric name, can be arbitrary 51 | type: calculate_psnr 52 | crop_border: 0 53 | test_y_channel: false 54 | ssim: 55 | type: calculate_ssim 56 | crop_border: 0 57 | test_y_channel: false 58 | 59 | # dist training settings 60 | dist_params: 61 | backend: nccl 62 | port: 29500 63 | -------------------------------------------------------------------------------- /options/test/GoPro/NAFNet-width32.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNet-GoPro-width32-test 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 1 # set num_gpu: 0 for cpu mode 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | 17 | test: 18 | name: gopro-test 19 | type: PairedImageDataset 20 | 21 | dataroot_gt: ./datasets/GoPro/test/target.lmdb 22 | dataroot_lq: ./datasets/GoPro/test/input.lmdb 23 | 24 | io_backend: 25 | type: lmdb 26 | 27 | # network structures 28 | network_g: 29 | type: NAFNetLocal 30 | width: 32 31 | enc_blk_nums: [1, 1, 1, 28] 32 | middle_blk_num: 1 33 | dec_blk_nums: [1, 1, 1, 1] 34 | 35 | # path 36 | path: 37 | pretrain_network_g: experiments/pretrained_models/NAFNet-GoPro-width32.pth 38 | strict_load_g: true 39 | resume_state: ~ 40 | 41 | # validation settings 42 | val: 43 | save_img: true 44 | grids: false 45 | 46 | 47 | metrics: 48 | psnr: # metric name, can be arbitrary 49 | type: calculate_psnr 50 | crop_border: 0 51 | test_y_channel: false 52 | ssim: 53 | type: calculate_ssim 54 | crop_border: 0 55 | test_y_channel: false 56 | 57 | # dist training settings 58 | dist_params: 59 | backend: nccl 60 | port: 29500 61 | -------------------------------------------------------------------------------- /options/test/GoPro/NAFNet-width64.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNet-GoPro-width64-test 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 1 # set num_gpu: 0 for cpu mode 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | 17 | test: 18 | name: gopro-test 19 | type: PairedImageDataset 20 | 21 | dataroot_gt: ./datasets/GoPro/test/target.lmdb 22 | dataroot_lq: ./datasets/GoPro/test/input.lmdb 23 | 24 | io_backend: 25 | type: lmdb 26 | 27 | # network structures 28 | network_g: 29 | type: NAFNetLocal 30 | width: 64 31 | enc_blk_nums: [1, 1, 1, 28] 32 | middle_blk_num: 1 33 | dec_blk_nums: [1, 1, 1, 1] 34 | 35 | # path 36 | path: 37 | pretrain_network_g: experiments/pretrained_models/NAFNet-GoPro-width64.pth 38 | strict_load_g: true 39 | resume_state: ~ 40 | 41 | # validation settings 42 | val: 43 | save_img: true 44 | grids: false 45 | 46 | 47 | metrics: 48 | psnr: # metric name, can be arbitrary 49 | type: calculate_psnr 50 | crop_border: 0 51 | test_y_channel: false 52 | ssim: 53 | type: calculate_ssim 54 | crop_border: 0 55 | test_y_channel: false 56 | 57 | # dist training settings 58 | dist_params: 59 | backend: nccl 60 | port: 29500 61 | -------------------------------------------------------------------------------- /options/test/NAFSSR/NAFSSR-B_2x.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFSSR-B_2x 9 | model_type: ImageRestorationModel 10 | scale: 2 11 | num_gpu: 1 # set num_gpu: 0 for cpu mode 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | test0: 17 | name: KITTI2012 18 | type: PairedStereoImageDataset 19 | dataroot_gt: datasets/StereoSR/test/KITTI2012/hr 20 | dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x2 21 | io_backend: 22 | type: disk 23 | 24 | test1: 25 | name: KITTI2015 26 | type: PairedStereoImageDataset 27 | dataroot_gt: datasets/StereoSR/test/KITTI2015/hr 28 | dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x2 29 | io_backend: 30 | type: disk 31 | 32 | test2: 33 | name: Middlebury 34 | type: PairedStereoImageDataset 35 | dataroot_gt: datasets/StereoSR/test/Middlebury/hr 36 | dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x2 37 | io_backend: 38 | type: disk 39 | 40 | test3: 41 | name: Flickr1024 42 | type: PairedStereoImageDataset 43 | dataroot_gt: datasets/StereoSR/test/Flickr1024/hr 44 | dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2 45 | io_backend: 46 | type: disk 47 | 48 | # network structures 49 | network_g: 50 | type: NAFSSR 51 | up_scale: 2 52 | width: 96 53 | num_blks: 64 54 | 55 | 56 | # path 57 | path: 58 | pretrain_network_g: experiments/pretrained_models/NAFSSR-B_2x.pth 59 | strict_load_g: true 60 | resume_state: ~ 61 | 62 | # validation settings 63 | val: 64 | save_img: true 65 | grids: false 66 | 67 | metrics: 68 | psnr: # metric name, can be arbitrary 69 | type: calculate_psnr 70 | crop_border: 0 71 | test_y_channel: false 72 | ssim: 73 | type: calculate_skimage_ssim 74 | # psnr_left: # metric name, can be arbitrary 75 | # type: calculate_psnr_left 76 | # crop_border: 0 77 | # 
test_y_channel: false 78 | # ssim_left: 79 | # type: calculate_skimage_ssim_left 80 | 81 | 82 | # dist training settings 83 | dist_params: 84 | backend: nccl 85 | port: 29500 86 | -------------------------------------------------------------------------------- /options/test/NAFSSR/NAFSSR-B_4x.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFSSR-B_4x 9 | model_type: ImageRestorationModel 10 | scale: 4 11 | num_gpu: 1 # set num_gpu: 0 for cpu mode 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | test0: 17 | name: KITTI2012 18 | type: PairedStereoImageDataset 19 | dataroot_gt: datasets/StereoSR/test/KITTI2012/hr 20 | dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x4 21 | io_backend: 22 | type: disk 23 | 24 | test1: 25 | name: KITTI2015 26 | type: PairedStereoImageDataset 27 | dataroot_gt: datasets/StereoSR/test/KITTI2015/hr 28 | dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x4 29 | io_backend: 30 | type: disk 31 | 32 | test2: 33 | name: Middlebury 34 | type: PairedStereoImageDataset 35 | dataroot_gt: datasets/StereoSR/test/Middlebury/hr 36 | dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x4 37 | io_backend: 38 | type: disk 39 | 40 | test3: 41 | name: Flickr1024 42 | type: PairedStereoImageDataset 43 | dataroot_gt: datasets/StereoSR/test/Flickr1024/hr 44 | dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4 45 | io_backend: 46 | type: disk 47 | 48 | # network structures 49 | network_g: 50 | type: NAFSSR 51 | up_scale: 4 52 | width: 96 53 | num_blks: 64 54 
| 55 | 56 | # path 57 | path: 58 | pretrain_network_g: experiments/pretrained_models/NAFSSR-B_4x.pth 59 | strict_load_g: true 60 | resume_state: ~ 61 | 62 | # validation settings 63 | val: 64 | save_img: true 65 | grids: false 66 | 67 | metrics: 68 | psnr: # metric name, can be arbitrary 69 | type: calculate_psnr 70 | crop_border: 0 71 | test_y_channel: false 72 | ssim: 73 | type: calculate_skimage_ssim 74 | # psnr_left: # metric name, can be arbitrary 75 | # type: calculate_psnr_left 76 | # crop_border: 0 77 | # test_y_channel: false 78 | # ssim_left: 79 | # type: calculate_skimage_ssim_left 80 | 81 | 82 | # dist training settings 83 | dist_params: 84 | backend: nccl 85 | port: 29500 86 | -------------------------------------------------------------------------------- /options/test/NAFSSR/NAFSSR-L_2x.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFSSR-L_2x 9 | model_type: ImageRestorationModel 10 | scale: 2 11 | num_gpu: 1 # set num_gpu: 0 for cpu mode 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | test0: 17 | name: KITTI2012 18 | type: PairedStereoImageDataset 19 | dataroot_gt: datasets/StereoSR/test/KITTI2012/hr 20 | dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x2 21 | io_backend: 22 | type: disk 23 | 24 | test1: 25 | name: KITTI2015 26 | type: PairedStereoImageDataset 27 | dataroot_gt: datasets/StereoSR/test/KITTI2015/hr 28 | dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x2 29 | io_backend: 30 | type: disk 31 | 32 | test2: 33 | name: Middlebury 34 | type: PairedStereoImageDataset 35 | dataroot_gt: datasets/StereoSR/test/Middlebury/hr 36 | dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x2 37 | io_backend: 38 | type: disk 39 | 40 | test3: 41 | name: Flickr1024 42 | type: PairedStereoImageDataset 43 | dataroot_gt: datasets/StereoSR/test/Flickr1024/hr 44 | dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2 45 | io_backend: 46 | type: disk 47 | 48 | # network structures 49 | network_g: 50 | type: NAFSSR 51 | up_scale: 2 52 | width: 128 53 | num_blks: 128 54 | 55 | 56 | # path 57 | path: 58 | pretrain_network_g: experiments/pretrained_models/NAFSSR-L_2x.pth 59 | strict_load_g: true 60 | resume_state: ~ 61 | 62 | # validation settings 63 | val: 64 | save_img: true 65 | grids: false 66 | 67 | metrics: 68 | psnr: # metric name, can be arbitrary 69 | type: calculate_psnr 70 | crop_border: 0 71 | test_y_channel: false 72 | ssim: 73 | type: calculate_skimage_ssim 74 | # psnr_left: # metric name, can be arbitrary 75 | # type: calculate_psnr_left 76 | # crop_border: 0 77 | 
# test_y_channel: false 78 | # ssim_left: 79 | # type: calculate_skimage_ssim_left 80 | 81 | 82 | # dist training settings 83 | dist_params: 84 | backend: nccl 85 | port: 29500 86 | -------------------------------------------------------------------------------- /options/test/NAFSSR/NAFSSR-L_4x.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFSSR-L_4x 9 | model_type: ImageRestorationModel 10 | scale: 4 11 | num_gpu: 1 # set num_gpu: 0 for cpu mode 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | test0: 17 | name: KITTI2012 18 | type: PairedStereoImageDataset 19 | dataroot_gt: datasets/StereoSR/test/KITTI2012/hr 20 | dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x4 21 | io_backend: 22 | type: disk 23 | 24 | test1: 25 | name: KITTI2015 26 | type: PairedStereoImageDataset 27 | dataroot_gt: datasets/StereoSR/test/KITTI2015/hr 28 | dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x4 29 | io_backend: 30 | type: disk 31 | 32 | test2: 33 | name: Middlebury 34 | type: PairedStereoImageDataset 35 | dataroot_gt: datasets/StereoSR/test/Middlebury/hr 36 | dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x4 37 | io_backend: 38 | type: disk 39 | 40 | test3: 41 | name: Flickr1024 42 | type: PairedStereoImageDataset 43 | dataroot_gt: datasets/StereoSR/test/Flickr1024/hr 44 | dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4 45 | io_backend: 46 | type: disk 47 | 48 | # network structures 49 | network_g: 50 | type: NAFSSR 51 | up_scale: 4 52 | width: 128 53 | num_blks: 128 
54 | 55 | 56 | # path 57 | path: 58 | pretrain_network_g: experiments/pretrained_models/NAFSSR-L_4x.pth 59 | strict_load_g: true 60 | resume_state: ~ 61 | 62 | # validation settings 63 | val: 64 | save_img: true 65 | grids: false 66 | 67 | metrics: 68 | psnr: # metric name, can be arbitrary 69 | type: calculate_psnr 70 | crop_border: 0 71 | test_y_channel: false 72 | ssim: 73 | type: calculate_skimage_ssim 74 | # psnr_left: # metric name, can be arbitrary 75 | # type: calculate_psnr_left 76 | # crop_border: 0 77 | # test_y_channel: false 78 | # ssim_left: 79 | # type: calculate_skimage_ssim_left 80 | 81 | 82 | # dist training settings 83 | dist_params: 84 | backend: nccl 85 | port: 29500 86 | -------------------------------------------------------------------------------- /options/test/NAFSSR/NAFSSR-S_2x.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFSSR-S_2x 9 | model_type: ImageRestorationModel 10 | scale: 2 11 | num_gpu: 1 # set num_gpu: 0 for cpu mode 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | test0: 17 | name: KITTI2012 18 | type: PairedStereoImageDataset 19 | dataroot_gt: datasets/StereoSR/test/KITTI2012/hr 20 | dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x2 21 | io_backend: 22 | type: disk 23 | 24 | test1: 25 | name: KITTI2015 26 | type: PairedStereoImageDataset 27 | dataroot_gt: datasets/StereoSR/test/KITTI2015/hr 28 | dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x2 29 | io_backend: 30 | type: disk 31 | 32 | test2: 33 | name: Middlebury 34 | type: PairedStereoImageDataset 35 | dataroot_gt: datasets/StereoSR/test/Middlebury/hr 36 | dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x2 37 | io_backend: 38 | type: disk 39 | 40 | test3: 41 | name: Flickr1024 42 | type: PairedStereoImageDataset 43 | dataroot_gt: datasets/StereoSR/test/Flickr1024/hr 44 | dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2 45 | io_backend: 46 | type: disk 47 | 48 | # network structures 49 | network_g: 50 | type: NAFSSR 51 | up_scale: 2 52 | width: 64 53 | num_blks: 32 54 | 55 | 56 | # path 57 | path: 58 | pretrain_network_g: experiments/pretrained_models/NAFSSR-S_2x.pth 59 | strict_load_g: true 60 | resume_state: ~ 61 | 62 | # validation settings 63 | val: 64 | save_img: true 65 | grids: false 66 | 67 | metrics: 68 | psnr: # metric name, can be arbitrary 69 | type: calculate_psnr 70 | crop_border: 0 71 | test_y_channel: false 72 | ssim: 73 | type: calculate_skimage_ssim 74 | # psnr_left: # metric name, can be arbitrary 75 | # type: calculate_psnr_left 76 | # crop_border: 0 77 | # 
test_y_channel: false 78 | # ssim_left: 79 | # type: calculate_skimage_ssim_left 80 | 81 | 82 | # dist training settings 83 | dist_params: 84 | backend: nccl 85 | port: 29500 86 | -------------------------------------------------------------------------------- /options/test/NAFSSR/NAFSSR-S_4x.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFSSR-S_4x 9 | model_type: ImageRestorationModel 10 | scale: 4 11 | num_gpu: 1 # set num_gpu: 0 for cpu mode 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | test0: 17 | name: KITTI2012 18 | type: PairedStereoImageDataset 19 | dataroot_gt: datasets/StereoSR/test/KITTI2012/hr 20 | dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x4 21 | io_backend: 22 | type: disk 23 | 24 | test1: 25 | name: KITTI2015 26 | type: PairedStereoImageDataset 27 | dataroot_gt: datasets/StereoSR/test/KITTI2015/hr 28 | dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x4 29 | io_backend: 30 | type: disk 31 | 32 | test2: 33 | name: Middlebury 34 | type: PairedStereoImageDataset 35 | dataroot_gt: datasets/StereoSR/test/Middlebury/hr 36 | dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x4 37 | io_backend: 38 | type: disk 39 | 40 | test3: 41 | name: Flickr1024 42 | type: PairedStereoImageDataset 43 | dataroot_gt: datasets/StereoSR/test/Flickr1024/hr 44 | dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4 45 | io_backend: 46 | type: disk 47 | 48 | # network structures 49 | network_g: 50 | type: NAFSSR 51 | up_scale: 4 52 | width: 64 53 | num_blks: 32 54 
| 55 | 56 | # path 57 | path: 58 | pretrain_network_g: experiments/pretrained_models/NAFSSR-S_4x.pth 59 | strict_load_g: true 60 | resume_state: ~ 61 | 62 | # validation settings 63 | val: 64 | save_img: true 65 | grids: false 66 | 67 | metrics: 68 | psnr: # metric name, can be arbitrary 69 | type: calculate_psnr 70 | crop_border: 0 71 | test_y_channel: false 72 | ssim: 73 | type: calculate_skimage_ssim 74 | # psnr_left: # metric name, can be arbitrary 75 | # type: calculate_psnr_left 76 | # crop_border: 0 77 | # test_y_channel: false 78 | # ssim_left: 79 | # type: calculate_skimage_ssim_left 80 | 81 | 82 | # dist training settings 83 | dist_params: 84 | backend: nccl 85 | port: 29500 86 | -------------------------------------------------------------------------------- /options/test/NAFSSR/NAFSSR-T_2x.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFSSR-T_2x 9 | model_type: ImageRestorationModel 10 | scale: 2 11 | num_gpu: 1 # set num_gpu: 0 for cpu mode 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | test0: 17 | name: KITTI2012 18 | type: PairedStereoImageDataset 19 | dataroot_gt: datasets/StereoSR/test/KITTI2012/hr 20 | dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x2 21 | io_backend: 22 | type: disk 23 | 24 | test1: 25 | name: KITTI2015 26 | type: PairedStereoImageDataset 27 | dataroot_gt: datasets/StereoSR/test/KITTI2015/hr 28 | dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x2 29 | io_backend: 30 | type: disk 31 | 32 | test2: 33 | name: Middlebury 34 | type: PairedStereoImageDataset 35 | dataroot_gt: datasets/StereoSR/test/Middlebury/hr 36 | dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x2 37 | io_backend: 38 | type: disk 39 | 40 | test3: 41 | name: Flickr1024 42 | type: PairedStereoImageDataset 43 | dataroot_gt: datasets/StereoSR/test/Flickr1024/hr 44 | dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2 45 | io_backend: 46 | type: disk 47 | 48 | # network structures 49 | network_g: 50 | type: NAFSSR 51 | up_scale: 2 52 | width: 48 53 | num_blks: 16 54 | 55 | 56 | # path 57 | path: 58 | pretrain_network_g: experiments/pretrained_models/NAFSSR-T_2x.pth 59 | strict_load_g: true 60 | resume_state: ~ 61 | 62 | # validation settings 63 | val: 64 | save_img: true 65 | grids: false 66 | 67 | metrics: 68 | psnr: # metric name, can be arbitrary 69 | type: calculate_psnr 70 | crop_border: 0 71 | test_y_channel: false 72 | ssim: 73 | type: calculate_skimage_ssim 74 | # psnr_left: # metric name, can be arbitrary 75 | # type: calculate_psnr_left 76 | # crop_border: 0 77 | # 
test_y_channel: false 78 | # ssim_left: 79 | # type: calculate_skimage_ssim_left 80 | 81 | 82 | # dist training settings 83 | dist_params: 84 | backend: nccl 85 | port: 29500 86 | -------------------------------------------------------------------------------- /options/test/NAFSSR/NAFSSR-T_4x.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFSSR-T_4x 9 | model_type: ImageRestorationModel 10 | scale: 4 11 | num_gpu: 1 # set num_gpu: 0 for cpu mode 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | test0: 17 | name: KITTI2012 18 | type: PairedStereoImageDataset 19 | dataroot_gt: datasets/StereoSR/test/KITTI2012/hr 20 | dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x4 21 | io_backend: 22 | type: disk 23 | 24 | test1: 25 | name: KITTI2015 26 | type: PairedStereoImageDataset 27 | dataroot_gt: datasets/StereoSR/test/KITTI2015/hr 28 | dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x4 29 | io_backend: 30 | type: disk 31 | 32 | test2: 33 | name: Middlebury 34 | type: PairedStereoImageDataset 35 | dataroot_gt: datasets/StereoSR/test/Middlebury/hr 36 | dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x4 37 | io_backend: 38 | type: disk 39 | 40 | test3: 41 | name: Flickr1024 42 | type: PairedStereoImageDataset 43 | dataroot_gt: datasets/StereoSR/test/Flickr1024/hr 44 | dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4 45 | io_backend: 46 | type: disk 47 | 48 | # network structures 49 | network_g: 50 | type: NAFSSR 51 | up_scale: 4 52 | width: 48 53 | num_blks: 16 54 
| 55 | 56 | # path 57 | path: 58 | pretrain_network_g: experiments/pretrained_models/NAFSSR-T_4x.pth 59 | strict_load_g: true 60 | resume_state: ~ 61 | 62 | # validation settings 63 | val: 64 | save_img: true 65 | grids: false 66 | 67 | metrics: 68 | psnr: # metric name, can be arbitrary 69 | type: calculate_psnr 70 | crop_border: 0 71 | test_y_channel: false 72 | ssim: 73 | type: calculate_skimage_ssim 74 | # psnr_left: # metric name, can be arbitrary 75 | # type: calculate_psnr_left 76 | # crop_border: 0 77 | # test_y_channel: false 78 | # ssim_left: 79 | # type: calculate_skimage_ssim_left 80 | 81 | 82 | # dist training settings 83 | dist_params: 84 | backend: nccl 85 | port: 29500 86 | -------------------------------------------------------------------------------- /options/test/REDS/NAFNet-width64.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNet-REDS-width64-test 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 1 # set num_gpu: 0 for cpu mode 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | 17 | test: 18 | name: reds-val300-test 19 | type: PairedImageDataset 20 | 21 | dataroot_gt: ./datasets/REDS/val/sharp_300.lmdb 22 | dataroot_lq: ./datasets/REDS/val/blur_300.lmdb 23 | 24 | io_backend: 25 | type: lmdb 26 | 27 | # network structures 28 | network_g: 29 | type: NAFNetLocal 30 | width: 64 31 | enc_blk_nums: [1, 1, 1, 28] 32 | middle_blk_num: 1 33 | dec_blk_nums: [1, 1, 1, 1] 34 | 35 | # path 36 | path: 37 | pretrain_network_g: experiments/pretrained_models/NAFNet-REDS-width64.pth 38 | strict_load_g: true 39 | resume_state: ~ 40 | 41 | # validation settings 42 | val: 43 | save_img: true 44 | grids: false 45 | 46 | 47 | metrics: 48 | psnr: # metric name, can be arbitrary 49 | type: calculate_psnr 50 | crop_border: 0 51 | test_y_channel: false 52 | ssim: 53 | type: calculate_ssim 54 | crop_border: 0 55 | test_y_channel: false 56 | 57 | # dist training settings 58 | dist_params: 59 | backend: nccl 60 | port: 29500 61 | -------------------------------------------------------------------------------- /options/test/SIDD/Baseline-width32.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: Baseline-SIDD-width32-test 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 1 # set num_gpu: 0 for cpu mode 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | 17 | val: 18 | name: SIDD_val 19 | type: PairedImageDataset 20 | 21 | dataroot_gt: ./datasets/SIDD/val/gt_crops.lmdb 22 | dataroot_lq: ./datasets/SIDD/val/input_crops.lmdb 23 | 24 | io_backend: 25 | type: lmdb 26 | 27 | # network structures 28 | network_g: 29 | type: Baseline 30 | width: 32 31 | enc_blk_nums: [2, 2, 4, 8] 32 | middle_blk_num: 12 33 | dec_blk_nums: [2, 2, 2, 2] 34 | dw_expand: 1 35 | ffn_expand: 2 36 | 37 | # path 38 | path: 39 | pretrain_network_g: experiments/pretrained_models/Baseline-SIDD-width32.pth 40 | strict_load_g: true 41 | resume_state: ~ 42 | 43 | # validation settings 44 | val: 45 | save_img: true 46 | grids: false 47 | use_image: false 48 | 49 | metrics: 50 | psnr: # metric name, can be arbitrary 51 | type: calculate_psnr 52 | crop_border: 0 53 | test_y_channel: false 54 | ssim: 55 | type: calculate_ssim 56 | crop_border: 0 57 | test_y_channel: false 58 | 59 | # dist training settings 60 | dist_params: 61 | backend: nccl 62 | port: 29500 63 | -------------------------------------------------------------------------------- /options/test/SIDD/Baseline-width64.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: Baseline-SIDD-width64-test 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 1 # set num_gpu: 0 for cpu mode 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | 17 | val: 18 | name: SIDD_val 19 | type: PairedImageDataset 20 | 21 | dataroot_gt: ./datasets/SIDD/val/gt_crops.lmdb 22 | dataroot_lq: ./datasets/SIDD/val/input_crops.lmdb 23 | 24 | io_backend: 25 | type: lmdb 26 | 27 | # network structures 28 | network_g: 29 | type: Baseline 30 | width: 64 31 | enc_blk_nums: [2, 2, 4, 8] 32 | middle_blk_num: 12 33 | dec_blk_nums: [2, 2, 2, 2] 34 | dw_expand: 2 35 | ffn_expand: 2 36 | 37 | # path 38 | path: 39 | pretrain_network_g: experiments/pretrained_models/Baseline-SIDD-width64.pth 40 | strict_load_g: true 41 | resume_state: ~ 42 | 43 | # validation settings 44 | val: 45 | save_img: true 46 | grids: false 47 | use_image: false 48 | 49 | metrics: 50 | psnr: # metric name, can be arbitrary 51 | type: calculate_psnr 52 | crop_border: 0 53 | test_y_channel: false 54 | ssim: 55 | type: calculate_ssim 56 | crop_border: 0 57 | test_y_channel: false 58 | 59 | # dist training settings 60 | dist_params: 61 | backend: nccl 62 | port: 29500 63 | -------------------------------------------------------------------------------- /options/test/SIDD/NAFNet-width32.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNet-SIDD-width32-test 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 1 # set num_gpu: 0 for cpu mode 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | 17 | val: 18 | name: SIDD_val 19 | type: PairedImageDataset 20 | 21 | dataroot_gt: ./datasets/SIDD/val/gt_crops.lmdb 22 | dataroot_lq: ./datasets/SIDD/val/input_crops.lmdb 23 | 24 | io_backend: 25 | type: lmdb 26 | 27 | # network structures 28 | network_g: 29 | type: NAFNet 30 | width: 32 31 | enc_blk_nums: [2, 2, 4, 8] 32 | middle_blk_num: 12 33 | dec_blk_nums: [2, 2, 2, 2] 34 | 35 | # path 36 | path: 37 | pretrain_network_g: experiments/pretrained_models/NAFNet-SIDD-width32.pth 38 | strict_load_g: true 39 | resume_state: ~ 40 | 41 | # validation settings 42 | val: 43 | save_img: true 44 | grids: false 45 | use_image: false 46 | 47 | metrics: 48 | psnr: # metric name, can be arbitrary 49 | type: calculate_psnr 50 | crop_border: 0 51 | test_y_channel: false 52 | ssim: 53 | type: calculate_ssim 54 | crop_border: 0 55 | test_y_channel: false 56 | 57 | # dist training settings 58 | dist_params: 59 | backend: nccl 60 | port: 29500 61 | -------------------------------------------------------------------------------- /options/test/SIDD/NAFNet-width64.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNet-SIDD-width64-test 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 1 # set num_gpu: 0 for cpu mode 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | 17 | val: 18 | name: SIDD_val 19 | type: PairedImageDataset 20 | 21 | dataroot_gt: ./datasets/SIDD/val/gt_crops.lmdb 22 | dataroot_lq: ./datasets/SIDD/val/input_crops.lmdb 23 | 24 | io_backend: 25 | type: lmdb 26 | 27 | # network structures 28 | network_g: 29 | type: NAFNet 30 | width: 64 31 | enc_blk_nums: [2, 2, 4, 8] 32 | middle_blk_num: 12 33 | dec_blk_nums: [2, 2, 2, 2] 34 | 35 | # path 36 | path: 37 | pretrain_network_g: experiments/pretrained_models/NAFNet-SIDD-width64.pth 38 | strict_load_g: true 39 | resume_state: ~ 40 | 41 | # validation settings 42 | val: 43 | save_img: true 44 | grids: false 45 | use_image: false 46 | 47 | metrics: 48 | psnr: # metric name, can be arbitrary 49 | type: calculate_psnr 50 | crop_border: 0 51 | test_y_channel: false 52 | ssim: 53 | type: calculate_ssim 54 | crop_border: 0 55 | test_y_channel: false 56 | 57 | # dist training settings 58 | dist_params: 59 | backend: nccl 60 | port: 29500 61 | -------------------------------------------------------------------------------- /options/train/GoPro/Baseline-width32.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: Baseline-GoPro-width32 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 8 12 | manual_seed: 42 13 | 14 | datasets: 15 | train: 16 | name: gopro-train 17 | type: PairedImageDataset 18 | dataroot_gt: ./datasets/GoPro/train/sharp_crops.lmdb 19 | dataroot_lq: ./datasets/GoPro/train/blur_crops.lmdb 20 | 21 | filename_tmpl: '{}' 22 | io_backend: 23 | type: lmdb 24 | 25 | gt_size: 256 26 | use_flip: true 27 | use_rot: true 28 | 29 | # data loader 30 | use_shuffle: true 31 | num_worker_per_gpu: 4 32 | batch_size_per_gpu: 4 33 | dataset_enlarge_ratio: 1 34 | prefetch_mode: ~ 35 | 36 | val: 37 | name: gopro-test 38 | type: PairedImageDataset 39 | dataroot_gt: ./datasets/GoPro/test/target.lmdb 40 | dataroot_lq: ./datasets/GoPro/test/input.lmdb 41 | io_backend: 42 | type: lmdb 43 | 44 | 45 | network_g: 46 | type: BaselineLocal 47 | width: 32 48 | enc_blk_nums: [1, 1, 1, 28] 49 | middle_blk_num: 1 50 | dec_blk_nums: [1, 1, 1, 1] 51 | dw_expand: 1 52 | ffn_expand: 2 53 | 54 | # path 55 | path: 56 | pretrain_network_g: ~ 57 | strict_load_g: true 58 | resume_state: ~ 59 | 60 | # training settings 61 | train: 62 | optim_g: 63 | type: AdamW 64 | lr: !!float 1e-3 65 | weight_decay: !!float 1e-3 66 | betas: [0.9, 0.9] 67 | 68 | scheduler: 69 | type: TrueCosineAnnealingLR 70 | T_max: 200000 71 | eta_min: !!float 1e-7 72 | 73 | total_iter: 200000 74 | warmup_iter: -1 # no warm up 75 | 76 | # losses 77 | pixel_opt: 78 | type: PSNRLoss 79 | loss_weight: 1 80 | reduction: mean 81 | 82 | # validation settings 83 | val: 84 | val_freq: !!float 2e4 85 | save_img: false 86 | 87 | 88 | metrics: 89 | psnr: # metric name, can be arbitrary 90 | type: calculate_psnr 91 | crop_border: 0 92 | 
test_y_channel: false 93 | ssim: 94 | type: calculate_ssim 95 | crop_border: 0 96 | test_y_channel: false 97 | 98 | # logging settings 99 | logger: 100 | print_freq: 200 101 | save_checkpoint_freq: !!float 5e3 102 | use_tb_logger: true 103 | wandb: 104 | project: ~ 105 | resume_id: ~ 106 | 107 | # dist training settings 108 | dist_params: 109 | backend: nccl 110 | port: 29500 111 | -------------------------------------------------------------------------------- /options/train/GoPro/Baseline-width64.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: Baseline-GoPro-width64 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 8 12 | manual_seed: 10 13 | 14 | datasets: 15 | train: 16 | name: gopro-train 17 | type: PairedImageDataset 18 | dataroot_gt: ./datasets/GoPro/train/sharp_crops.lmdb 19 | dataroot_lq: ./datasets/GoPro/train/blur_crops.lmdb 20 | 21 | filename_tmpl: '{}' 22 | io_backend: 23 | type: lmdb 24 | 25 | gt_size: 256 26 | use_flip: true 27 | use_rot: true 28 | 29 | # data loader 30 | use_shuffle: true 31 | num_worker_per_gpu: 8 32 | batch_size_per_gpu: 8 33 | dataset_enlarge_ratio: 1 34 | prefetch_mode: ~ 35 | 36 | val: 37 | name: gopro-test 38 | type: PairedImageDataset 39 | dataroot_gt: ./datasets/GoPro/test/target.lmdb 40 | dataroot_lq: ./datasets/GoPro/test/input.lmdb 41 | io_backend: 42 | type: lmdb 43 | 44 | 45 | network_g: 46 | type: BaselineLocal 47 | width: 64 48 | enc_blk_nums: [1, 1, 1, 28] 49 | middle_blk_num: 1 50 | dec_blk_nums: [1, 1, 1, 1] 51 | dw_expand: 2 52 | 
ffn_expand: 2 53 | 54 | # path 55 | path: 56 | pretrain_network_g: ~ 57 | strict_load_g: true 58 | resume_state: ~ 59 | 60 | # training settings 61 | train: 62 | optim_g: 63 | type: AdamW 64 | lr: !!float 1e-3 65 | weight_decay: 0. 66 | betas: [0.9, 0.9] 67 | 68 | scheduler: 69 | type: TrueCosineAnnealingLR 70 | T_max: 400000 71 | eta_min: !!float 1e-7 72 | 73 | total_iter: 400000 74 | warmup_iter: -1 # no warm up 75 | 76 | # losses 77 | pixel_opt: 78 | type: PSNRLoss 79 | loss_weight: 1 80 | reduction: mean 81 | 82 | # validation settings 83 | val: 84 | val_freq: !!float 2e4 85 | save_img: false 86 | 87 | 88 | metrics: 89 | psnr: # metric name, can be arbitrary 90 | type: calculate_psnr 91 | crop_border: 0 92 | test_y_channel: false 93 | ssim: 94 | type: calculate_ssim 95 | crop_border: 0 96 | test_y_channel: false 97 | 98 | # logging settings 99 | logger: 100 | print_freq: 200 101 | save_checkpoint_freq: !!float 5e3 102 | use_tb_logger: true 103 | wandb: 104 | project: ~ 105 | resume_id: ~ 106 | 107 | # dist training settings 108 | dist_params: 109 | backend: nccl 110 | port: 29500 111 | -------------------------------------------------------------------------------- /options/train/GoPro/NAFNet-width32.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNet-GoPro-width32 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 8 12 | manual_seed: 42 13 | 14 | datasets: 15 | train: 16 | name: gopro-train 17 | type: PairedImageDataset 18 | dataroot_gt: ./datasets/GoPro/train/sharp_crops.lmdb 19 | dataroot_lq: ./datasets/GoPro/train/blur_crops.lmdb 20 | 21 | filename_tmpl: '{}' 22 | io_backend: 23 | type: lmdb 24 | 25 | gt_size: 256 26 | use_flip: true 27 | use_rot: true 28 | 29 | # data loader 30 | use_shuffle: true 31 | num_worker_per_gpu: 4 32 | batch_size_per_gpu: 4 33 | dataset_enlarge_ratio: 1 34 | prefetch_mode: ~ 35 | 36 | val: 37 | name: gopro-test 38 | type: PairedImageDataset 39 | dataroot_gt: ./datasets/GoPro/test/target.lmdb 40 | dataroot_lq: ./datasets/GoPro/test/input.lmdb 41 | io_backend: 42 | type: lmdb 43 | 44 | 45 | network_g: 46 | type: NAFNetLocal 47 | width: 32 48 | enc_blk_nums: [1, 1, 1, 28] 49 | middle_blk_num: 1 50 | dec_blk_nums: [1, 1, 1, 1] 51 | 52 | # path 53 | path: 54 | pretrain_network_g: ~ 55 | strict_load_g: true 56 | resume_state: ~ 57 | 58 | # training settings 59 | train: 60 | optim_g: 61 | type: AdamW 62 | lr: !!float 1e-3 63 | weight_decay: !!float 1e-3 64 | betas: [0.9, 0.9] 65 | 66 | scheduler: 67 | type: TrueCosineAnnealingLR 68 | T_max: 200000 69 | eta_min: !!float 1e-7 70 | 71 | total_iter: 200000 72 | warmup_iter: -1 # no warm up 73 | 74 | # losses 75 | pixel_opt: 76 | type: PSNRLoss 77 | loss_weight: 1 78 | reduction: mean 79 | 80 | # validation settings 81 | val: 82 | val_freq: !!float 2e4 83 | save_img: false 84 | 85 | 86 | metrics: 87 | psnr: # metric name, can be arbitrary 88 | type: calculate_psnr 89 | crop_border: 0 90 | test_y_channel: false 91 | ssim: 92 | 
type: calculate_ssim 93 | crop_border: 0 94 | test_y_channel: false 95 | 96 | # logging settings 97 | logger: 98 | print_freq: 200 99 | save_checkpoint_freq: !!float 5e3 100 | use_tb_logger: true 101 | wandb: 102 | project: ~ 103 | resume_id: ~ 104 | 105 | # dist training settings 106 | dist_params: 107 | backend: nccl 108 | port: 29500 109 | -------------------------------------------------------------------------------- /options/train/GoPro/NAFNet-width64.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNet-GoPro-width64 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 8 12 | manual_seed: 10 13 | 14 | datasets: 15 | train: 16 | name: gopro-train 17 | type: PairedImageDataset 18 | dataroot_gt: ./datasets/GoPro/train/sharp_crops.lmdb 19 | dataroot_lq: ./datasets/GoPro/train/blur_crops.lmdb 20 | 21 | filename_tmpl: '{}' 22 | io_backend: 23 | type: lmdb 24 | 25 | gt_size: 256 26 | use_flip: true 27 | use_rot: true 28 | 29 | # data loader 30 | use_shuffle: true 31 | num_worker_per_gpu: 8 32 | batch_size_per_gpu: 8 33 | dataset_enlarge_ratio: 1 34 | prefetch_mode: ~ 35 | 36 | val: 37 | name: gopro-test 38 | type: PairedImageDataset 39 | dataroot_gt: ./datasets/GoPro/test/target.lmdb 40 | dataroot_lq: ./datasets/GoPro/test/input.lmdb 41 | io_backend: 42 | type: lmdb 43 | 44 | 45 | network_g: 46 | type: NAFNetLocal 47 | width: 64 48 | enc_blk_nums: [1, 1, 1, 28] 49 | middle_blk_num: 1 50 | dec_blk_nums: [1, 1, 1, 1] 51 | 52 | # path 53 | path: 54 | pretrain_network_g: ~ 55 | 
strict_load_g: true 56 | resume_state: ~ 57 | 58 | # training settings 59 | train: 60 | optim_g: 61 | type: AdamW 62 | lr: !!float 1e-3 63 | weight_decay: !!float 1e-3 64 | betas: [0.9, 0.9] 65 | 66 | scheduler: 67 | type: TrueCosineAnnealingLR 68 | T_max: 400000 69 | eta_min: !!float 1e-7 70 | 71 | total_iter: 400000 72 | warmup_iter: -1 # no warm up 73 | 74 | # losses 75 | pixel_opt: 76 | type: PSNRLoss 77 | loss_weight: 1 78 | reduction: mean 79 | 80 | # validation settings 81 | val: 82 | val_freq: !!float 2e4 83 | save_img: false 84 | 85 | 86 | metrics: 87 | psnr: # metric name, can be arbitrary 88 | type: calculate_psnr 89 | crop_border: 0 90 | test_y_channel: false 91 | ssim: 92 | type: calculate_ssim 93 | crop_border: 0 94 | test_y_channel: false 95 | 96 | # logging settings 97 | logger: 98 | print_freq: 200 99 | save_checkpoint_freq: !!float 5e3 100 | use_tb_logger: true 101 | wandb: 102 | project: ~ 103 | resume_id: ~ 104 | 105 | # dist training settings 106 | dist_params: 107 | backend: nccl 108 | port: 29500 109 | -------------------------------------------------------------------------------- /options/train/NAFSSR/NAFSSR-B_x2.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNetSR-B_x2 9 | model_type: ImageRestorationModel 10 | scale: 2 11 | num_gpu: 8 12 | manual_seed: 10 13 | 14 | datasets: 15 | train: 16 | name: Flickr1024-sr-train 17 | type: PairedStereoImageDataset 18 | dataroot_gt: datasets/StereoSR/patches_x2/ 19 | dataroot_lq: datasets/StereoSR/patches_x2/ 20 | io_backend: 21 | type: disk 22 | 23 | gt_size_h: 60 24 | gt_size_w: 180 25 | use_hflip: true 26 | use_vflip: true 27 | use_rot: false 28 | flip_RGB: true 29 | 30 | # data loader 31 | use_shuffle: true 32 | num_worker_per_gpu: 4 33 | batch_size_per_gpu: 4 34 | dataset_enlarge_ratio: 1 35 | prefetch_mode: ~ 36 | 37 | val: 38 | name: Flickr1024-sr-test 39 | type: PairedStereoImageDataset 40 | dataroot_gt: datasets/StereoSR/test/Flickr1024/hr 41 | dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | network_g: 47 | type: NAFSSR 48 | up_scale: 2 49 | width: 96 50 | num_blks: 64 51 | drop_path_rate: 0.2 52 | train_size: [1, 6, 30, 90] 53 | drop_out_rate: 0. 54 | 55 | # path 56 | path: 57 | pretrain_network_g: ~ 58 | strict_load_g: true 59 | resume_state: ~ 60 | 61 | # training settings 62 | train: 63 | optim_g: 64 | type: AdamW 65 | lr: !!float 3e-3 66 | weight_decay: !!float 0 67 | betas: [0.9, 0.9] 68 | 69 | scheduler: 70 | type: TrueCosineAnnealingLR 71 | T_max: 100000 72 | eta_min: !!float 1e-7 73 | 74 | total_iter: 100000 75 | warmup_iter: -1 # no warm up 76 | mixup: false 77 | 78 | # losses 79 | pixel_opt: 80 | type: MSELoss 81 | loss_weight: 1. 
82 | reduction: mean 83 | 84 | # validation settings 85 | val: 86 | val_freq: !!float 2e4 87 | save_img: false 88 | trans_num: 1 89 | 90 | max_minibatch: 1 91 | 92 | metrics: 93 | psnr: # metric name, can be arbitrary 94 | type: calculate_psnr 95 | crop_border: 0 96 | test_y_channel: false 97 | ssim: 98 | type: calculate_skimage_ssim 99 | 100 | # logging settings 101 | logger: 102 | print_freq: 200 103 | save_checkpoint_freq: !!float 1e4 104 | use_tb_logger: true 105 | wandb: 106 | project: ~ 107 | resume_id: ~ 108 | 109 | # dist training settings 110 | dist_params: 111 | backend: nccl 112 | port: 29500 113 | -------------------------------------------------------------------------------- /options/train/NAFSSR/NAFSSR-B_x4.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNetSR-B_x4 9 | model_type: ImageRestorationModel 10 | scale: 4 11 | num_gpu: 8 12 | manual_seed: 10 13 | 14 | datasets: 15 | train: 16 | name: Flickr1024-sr-train 17 | type: PairedStereoImageDataset 18 | dataroot_gt: datasets/StereoSR/patches_x4/ 19 | dataroot_lq: datasets/StereoSR/patches_x4/ 20 | io_backend: 21 | type: disk 22 | 23 | gt_size_h: 120 24 | gt_size_w: 360 25 | use_hflip: true 26 | use_vflip: true 27 | use_rot: false 28 | flip_RGB: true 29 | 30 | # data loader 31 | use_shuffle: true 32 | num_worker_per_gpu: 4 33 | batch_size_per_gpu: 4 34 | dataset_enlarge_ratio: 1 35 | prefetch_mode: ~ 36 | 37 | val: 38 | name: Flickr1024-sr-test 39 | type: PairedStereoImageDataset 40 | dataroot_gt: 
datasets/StereoSR/test/Flickr1024/hr 41 | dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | network_g: 47 | type: NAFSSR 48 | up_scale: 4 49 | width: 96 50 | num_blks: 64 51 | drop_path_rate: 0.2 52 | train_size: [1, 6, 30, 90] 53 | drop_out_rate: 0. 54 | 55 | # path 56 | path: 57 | pretrain_network_g: ~ 58 | strict_load_g: true 59 | resume_state: ~ 60 | 61 | # training settings 62 | train: 63 | optim_g: 64 | type: AdamW 65 | lr: !!float 3e-3 66 | weight_decay: !!float 0 67 | betas: [0.9, 0.9] 68 | 69 | scheduler: 70 | type: TrueCosineAnnealingLR 71 | T_max: 100000 72 | eta_min: !!float 1e-7 73 | 74 | total_iter: 100000 75 | warmup_iter: -1 # no warm up 76 | mixup: false 77 | 78 | # losses 79 | pixel_opt: 80 | type: MSELoss 81 | loss_weight: 1. 82 | reduction: mean 83 | 84 | # validation settings 85 | val: 86 | val_freq: !!float 2e4 87 | save_img: false 88 | trans_num: 1 89 | 90 | max_minibatch: 1 91 | 92 | metrics: 93 | psnr: # metric name, can be arbitrary 94 | type: calculate_psnr 95 | crop_border: 0 96 | test_y_channel: false 97 | ssim: 98 | type: calculate_skimage_ssim 99 | 100 | # logging settings 101 | logger: 102 | print_freq: 200 103 | save_checkpoint_freq: !!float 1e4 104 | use_tb_logger: true 105 | wandb: 106 | project: ~ 107 | resume_id: ~ 108 | 109 | # dist training settings 110 | dist_params: 111 | backend: nccl 112 | port: 29500 113 | -------------------------------------------------------------------------------- /options/train/NAFSSR/NAFSSR-L_x2.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNetSR-L_x2 9 | model_type: ImageRestorationModel 10 | scale: 2 11 | num_gpu: 8 12 | manual_seed: 10 13 | 14 | datasets: 15 | train: 16 | name: Flickr1024-sr-train 17 | type: PairedStereoImageDataset 18 | dataroot_gt: datasets/StereoSR/patches_x2/ 19 | dataroot_lq: datasets/StereoSR/patches_x2/ 20 | io_backend: 21 | type: disk 22 | 23 | gt_size_h: 60 24 | gt_size_w: 180 25 | use_hflip: true 26 | use_vflip: true 27 | use_rot: false 28 | flip_RGB: true 29 | 30 | # data loader 31 | use_shuffle: true 32 | num_worker_per_gpu: 4 33 | batch_size_per_gpu: 4 34 | dataset_enlarge_ratio: 1 35 | prefetch_mode: ~ 36 | 37 | val: 38 | name: Flickr1024-sr-test 39 | type: PairedStereoImageDataset 40 | dataroot_gt: datasets/StereoSR/test/Flickr1024/hr 41 | dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | network_g: 47 | type: NAFSSR 48 | up_scale: 2 49 | width: 128 50 | num_blks: 128 51 | drop_path_rate: 0.3 52 | train_size: [1, 6, 30, 90] 53 | drop_out_rate: 0. 54 | 55 | # path 56 | path: 57 | pretrain_network_g: ~ 58 | strict_load_g: true 59 | resume_state: ~ 60 | 61 | # training settings 62 | train: 63 | optim_g: 64 | type: AdamW 65 | lr: !!float 3e-3 66 | weight_decay: !!float 0 67 | betas: [0.9, 0.9] 68 | 69 | scheduler: 70 | type: TrueCosineAnnealingLR 71 | T_max: 100000 72 | eta_min: !!float 1e-7 73 | 74 | total_iter: 100000 75 | warmup_iter: -1 # no warm up 76 | mixup: false 77 | 78 | # losses 79 | pixel_opt: 80 | type: MSELoss 81 | loss_weight: 1. 
82 | reduction: mean 83 | 84 | # validation settings 85 | val: 86 | val_freq: !!float 2e4 87 | save_img: false 88 | trans_num: 1 89 | 90 | max_minibatch: 1 91 | 92 | metrics: 93 | psnr: # metric name, can be arbitrary 94 | type: calculate_psnr 95 | crop_border: 0 96 | test_y_channel: false 97 | ssim: 98 | type: calculate_skimage_ssim 99 | 100 | # logging settings 101 | logger: 102 | print_freq: 200 103 | save_checkpoint_freq: !!float 1e4 104 | use_tb_logger: true 105 | wandb: 106 | project: ~ 107 | resume_id: ~ 108 | 109 | # dist training settings 110 | dist_params: 111 | backend: nccl 112 | port: 29500 113 | -------------------------------------------------------------------------------- /options/train/NAFSSR/NAFSSR-L_x4.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNetSR-L_x4 9 | model_type: ImageRestorationModel 10 | scale: 4 11 | num_gpu: 8 12 | manual_seed: 10 13 | 14 | datasets: 15 | train: 16 | name: Flickr1024-sr-train 17 | type: PairedStereoImageDataset 18 | dataroot_gt: datasets/StereoSR/patches_x4/ 19 | dataroot_lq: datasets/StereoSR/patches_x4/ 20 | io_backend: 21 | type: disk 22 | 23 | gt_size_h: 120 24 | gt_size_w: 360 25 | use_hflip: true 26 | use_vflip: true 27 | use_rot: false 28 | flip_RGB: true 29 | 30 | # data loader 31 | use_shuffle: true 32 | num_worker_per_gpu: 4 33 | batch_size_per_gpu: 4 34 | dataset_enlarge_ratio: 1 35 | prefetch_mode: ~ 36 | 37 | val: 38 | name: Flickr1024-sr-test 39 | type: PairedStereoImageDataset 40 | dataroot_gt: 
datasets/StereoSR/test/Flickr1024/hr 41 | dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | network_g: 47 | type: NAFSSR 48 | up_scale: 4 49 | width: 128 50 | num_blks: 128 51 | drop_path_rate: 0.3 52 | train_size: [1, 6, 30, 90] 53 | drop_out_rate: 0. 54 | 55 | # path 56 | path: 57 | pretrain_network_g: ~ 58 | strict_load_g: true 59 | resume_state: ~ 60 | 61 | # training settings 62 | train: 63 | optim_g: 64 | type: AdamW 65 | lr: !!float 3e-3 66 | weight_decay: !!float 0 67 | betas: [0.9, 0.9] 68 | 69 | scheduler: 70 | type: TrueCosineAnnealingLR 71 | T_max: 100000 72 | eta_min: !!float 1e-7 73 | 74 | total_iter: 100000 75 | warmup_iter: -1 # no warm up 76 | mixup: false 77 | 78 | # losses 79 | pixel_opt: 80 | type: MSELoss 81 | loss_weight: 1. 82 | reduction: mean 83 | 84 | # validation settings 85 | val: 86 | val_freq: !!float 2e4 87 | save_img: false 88 | trans_num: 1 89 | 90 | max_minibatch: 1 91 | 92 | metrics: 93 | psnr: # metric name, can be arbitrary 94 | type: calculate_psnr 95 | crop_border: 0 96 | test_y_channel: false 97 | ssim: 98 | type: calculate_skimage_ssim 99 | 100 | # logging settings 101 | logger: 102 | print_freq: 200 103 | save_checkpoint_freq: !!float 1e4 104 | use_tb_logger: true 105 | wandb: 106 | project: ~ 107 | resume_id: ~ 108 | 109 | # dist training settings 110 | dist_params: 111 | backend: nccl 112 | port: 29500 113 | -------------------------------------------------------------------------------- /options/train/NAFSSR/NAFSSR-S_x2.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNetSR-S_x2 9 | model_type: ImageRestorationModel 10 | scale: 2 11 | num_gpu: 8 12 | manual_seed: 10 13 | 14 | datasets: 15 | train: 16 | name: Flickr1024-sr-train 17 | type: PairedStereoImageDataset 18 | dataroot_gt: datasets/StereoSR/patches_x2/ 19 | dataroot_lq: datasets/StereoSR/patches_x2/ 20 | io_backend: 21 | type: disk 22 | 23 | gt_size_h: 60 24 | gt_size_w: 180 25 | use_hflip: true 26 | use_vflip: true 27 | use_rot: false 28 | flip_RGB: true 29 | 30 | # data loader 31 | use_shuffle: true 32 | num_worker_per_gpu: 4 33 | batch_size_per_gpu: 4 34 | dataset_enlarge_ratio: 1 35 | prefetch_mode: ~ 36 | 37 | val: 38 | name: Flickr1024-sr-test 39 | type: PairedStereoImageDataset 40 | dataroot_gt: datasets/StereoSR/test/Flickr1024/hr 41 | dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | network_g: 47 | type: NAFSSR 48 | up_scale: 2 49 | width: 64 50 | num_blks: 32 51 | drop_path_rate: 0.1 52 | train_size: [1, 6, 30, 90] 53 | drop_out_rate: 0. 54 | 55 | # path 56 | path: 57 | pretrain_network_g: ~ 58 | strict_load_g: true 59 | resume_state: ~ 60 | 61 | # training settings 62 | train: 63 | optim_g: 64 | type: AdamW 65 | lr: !!float 3e-3 66 | weight_decay: !!float 0 67 | betas: [0.9, 0.9] 68 | 69 | scheduler: 70 | type: TrueCosineAnnealingLR 71 | T_max: 100000 72 | eta_min: !!float 1e-7 73 | 74 | total_iter: 100000 75 | warmup_iter: -1 # no warm up 76 | mixup: false 77 | 78 | # losses 79 | pixel_opt: 80 | type: MSELoss 81 | loss_weight: 1. 
82 | reduction: mean 83 | 84 | # validation settings 85 | val: 86 | val_freq: !!float 2e4 87 | save_img: false 88 | trans_num: 1 89 | 90 | max_minibatch: 1 91 | 92 | metrics: 93 | psnr: # metric name, can be arbitrary 94 | type: calculate_psnr 95 | crop_border: 0 96 | test_y_channel: false 97 | ssim: 98 | type: calculate_skimage_ssim 99 | 100 | # logging settings 101 | logger: 102 | print_freq: 200 103 | save_checkpoint_freq: !!float 1e4 104 | use_tb_logger: true 105 | wandb: 106 | project: ~ 107 | resume_id: ~ 108 | 109 | # dist training settings 110 | dist_params: 111 | backend: nccl 112 | port: 29500 113 | -------------------------------------------------------------------------------- /options/train/NAFSSR/NAFSSR-S_x4.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNetSR-S_x4 9 | model_type: ImageRestorationModel 10 | scale: 4 11 | num_gpu: 8 12 | manual_seed: 10 13 | 14 | datasets: 15 | train: 16 | name: Flickr1024-sr-train 17 | type: PairedStereoImageDataset 18 | dataroot_gt: datasets/StereoSR/patches_x4/ 19 | dataroot_lq: datasets/StereoSR/patches_x4/ 20 | io_backend: 21 | type: disk 22 | 23 | gt_size_h: 120 24 | gt_size_w: 360 25 | use_hflip: true 26 | use_vflip: true 27 | use_rot: false 28 | flip_RGB: true 29 | 30 | # data loader 31 | use_shuffle: true 32 | num_worker_per_gpu: 4 33 | batch_size_per_gpu: 4 34 | dataset_enlarge_ratio: 1 35 | prefetch_mode: ~ 36 | 37 | val: 38 | name: Flickr1024-sr-test 39 | type: PairedStereoImageDataset 40 | dataroot_gt: 
datasets/StereoSR/test/Flickr1024/hr 41 | dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | network_g: 47 | type: NAFSSR 48 | up_scale: 4 49 | width: 64 50 | num_blks: 32 51 | drop_path_rate: 0.1 52 | train_size: [1, 6, 30, 90] 53 | drop_out_rate: 0. 54 | 55 | # path 56 | path: 57 | pretrain_network_g: ~ 58 | strict_load_g: true 59 | resume_state: ~ 60 | 61 | # training settings 62 | train: 63 | optim_g: 64 | type: AdamW 65 | lr: !!float 3e-3 66 | weight_decay: !!float 0 67 | betas: [0.9, 0.9] 68 | 69 | scheduler: 70 | type: TrueCosineAnnealingLR 71 | T_max: 100000 72 | eta_min: !!float 1e-7 73 | 74 | total_iter: 100000 75 | warmup_iter: -1 # no warm up 76 | mixup: false 77 | 78 | # losses 79 | pixel_opt: 80 | type: MSELoss 81 | loss_weight: 1. 82 | reduction: mean 83 | 84 | # validation settings 85 | val: 86 | val_freq: !!float 2e4 87 | save_img: false 88 | trans_num: 1 89 | 90 | max_minibatch: 1 91 | 92 | metrics: 93 | psnr: # metric name, can be arbitrary 94 | type: calculate_psnr 95 | crop_border: 0 96 | test_y_channel: false 97 | ssim: 98 | type: calculate_skimage_ssim 99 | 100 | # logging settings 101 | logger: 102 | print_freq: 200 103 | save_checkpoint_freq: !!float 1e4 104 | use_tb_logger: true 105 | wandb: 106 | project: ~ 107 | resume_id: ~ 108 | 109 | # dist training settings 110 | dist_params: 111 | backend: nccl 112 | port: 29500 113 | -------------------------------------------------------------------------------- /options/train/NAFSSR/NAFSSR-T_x2.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNetSR-T_x2 9 | model_type: ImageRestorationModel 10 | scale: 2 11 | num_gpu: 8 12 | manual_seed: 10 13 | 14 | datasets: 15 | train: 16 | name: Flickr1024-sr-train 17 | type: PairedStereoImageDataset 18 | dataroot_gt: datasets/StereoSR/patches_x2/ 19 | dataroot_lq: datasets/StereoSR/patches_x2/ 20 | io_backend: 21 | type: disk 22 | 23 | gt_size_h: 60 24 | gt_size_w: 180 25 | use_hflip: true 26 | use_vflip: true 27 | use_rot: false 28 | flip_RGB: true 29 | 30 | # data loader 31 | use_shuffle: true 32 | num_worker_per_gpu: 4 33 | batch_size_per_gpu: 4 34 | dataset_enlarge_ratio: 1 35 | prefetch_mode: ~ 36 | 37 | val: 38 | name: Flickr1024-sr-test 39 | type: PairedStereoImageDataset 40 | dataroot_gt: datasets/StereoSR/test/Flickr1024/hr 41 | dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | network_g: 47 | type: NAFSSR 48 | up_scale: 2 49 | width: 48 50 | num_blks: 16 51 | drop_path_rate: 0. 52 | train_size: [1, 6, 30, 90] 53 | drop_out_rate: 0. 54 | 55 | # path 56 | path: 57 | pretrain_network_g: ~ 58 | strict_load_g: true 59 | resume_state: ~ 60 | 61 | # training settings 62 | train: 63 | optim_g: 64 | type: AdamW 65 | lr: !!float 3e-3 66 | weight_decay: !!float 0 67 | betas: [0.9, 0.9] 68 | 69 | scheduler: 70 | type: TrueCosineAnnealingLR 71 | T_max: 400000 72 | eta_min: !!float 1e-7 73 | 74 | total_iter: 400000 75 | warmup_iter: -1 # no warm up 76 | mixup: false 77 | 78 | # losses 79 | pixel_opt: 80 | type: MSELoss 81 | loss_weight: 1. 
82 | reduction: mean 83 | 84 | # validation settings 85 | val: 86 | val_freq: !!float 2e4 87 | save_img: false 88 | trans_num: 1 89 | 90 | max_minibatch: 1 91 | 92 | metrics: 93 | psnr: # metric name, can be arbitrary 94 | type: calculate_psnr 95 | crop_border: 0 96 | test_y_channel: false 97 | ssim: 98 | type: calculate_skimage_ssim 99 | 100 | # logging settings 101 | logger: 102 | print_freq: 200 103 | save_checkpoint_freq: !!float 1e4 104 | use_tb_logger: true 105 | wandb: 106 | project: ~ 107 | resume_id: ~ 108 | 109 | # dist training settings 110 | dist_params: 111 | backend: nccl 112 | port: 29500 113 | -------------------------------------------------------------------------------- /options/train/NAFSSR/NAFSSR-T_x4.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNetSR-T_x4 9 | model_type: ImageRestorationModel 10 | scale: 4 11 | num_gpu: 8 12 | manual_seed: 10 13 | 14 | datasets: 15 | train: 16 | name: Flickr1024-sr-train 17 | type: PairedStereoImageDataset 18 | dataroot_gt: datasets/StereoSR/patches_x4/ 19 | dataroot_lq: datasets/StereoSR/patches_x4/ 20 | io_backend: 21 | type: disk 22 | 23 | gt_size_h: 120 24 | gt_size_w: 360 25 | use_hflip: true 26 | use_vflip: true 27 | use_rot: false 28 | flip_RGB: true 29 | 30 | # data loader 31 | use_shuffle: true 32 | num_worker_per_gpu: 4 33 | batch_size_per_gpu: 4 34 | dataset_enlarge_ratio: 1 35 | prefetch_mode: ~ 36 | 37 | val: 38 | name: Flickr1024-sr-test 39 | type: PairedStereoImageDataset 40 | dataroot_gt: 
datasets/StereoSR/test/Flickr1024/hr 41 | dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | network_g: 47 | type: NAFSSR 48 | up_scale: 4 49 | width: 48 50 | num_blks: 16 51 | drop_path_rate: 0. 52 | train_size: [1, 6, 30, 90] 53 | drop_out_rate: 0. 54 | 55 | # path 56 | path: 57 | pretrain_network_g: ~ 58 | strict_load_g: true 59 | resume_state: ~ 60 | 61 | # training settings 62 | train: 63 | optim_g: 64 | type: AdamW 65 | lr: !!float 3e-3 66 | weight_decay: !!float 0 67 | betas: [0.9, 0.9] 68 | 69 | scheduler: 70 | type: TrueCosineAnnealingLR 71 | T_max: 400000 72 | eta_min: !!float 1e-7 73 | 74 | total_iter: 400000 75 | warmup_iter: -1 # no warm up 76 | mixup: false 77 | 78 | # losses 79 | pixel_opt: 80 | type: MSELoss 81 | loss_weight: 1. 82 | reduction: mean 83 | 84 | # validation settings 85 | val: 86 | val_freq: !!float 2e4 87 | save_img: false 88 | trans_num: 1 89 | 90 | max_minibatch: 1 91 | 92 | metrics: 93 | psnr: # metric name, can be arbitrary 94 | type: calculate_psnr 95 | crop_border: 0 96 | test_y_channel: false 97 | ssim: 98 | type: calculate_skimage_ssim 99 | 100 | # logging settings 101 | logger: 102 | print_freq: 200 103 | save_checkpoint_freq: !!float 1e4 104 | use_tb_logger: true 105 | wandb: 106 | project: ~ 107 | resume_id: ~ 108 | 109 | # dist training settings 110 | dist_params: 111 | backend: nccl 112 | port: 29500 113 | -------------------------------------------------------------------------------- /options/train/REDS/NAFNet-width64.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNet-REDS-width64 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 8 12 | manual_seed: 10 13 | 14 | datasets: 15 | train: 16 | name: reds-train 17 | type: PairedImageDataset 18 | dataroot_gt: ./datasets/REDS/train/train_sharp.lmdb 19 | dataroot_lq: ./datasets/REDS/train/train_blur_jpeg.lmdb 20 | 21 | filename_tmpl: '{}' 22 | io_backend: 23 | type: lmdb 24 | 25 | gt_size: 256 26 | use_flip: true 27 | use_rot: true 28 | 29 | # data loader 30 | use_shuffle: true 31 | num_worker_per_gpu: 8 32 | batch_size_per_gpu: 8 33 | dataset_enlarge_ratio: 1 34 | prefetch_mode: ~ 35 | 36 | val: 37 | name: reds-val300-test 38 | type: PairedImageDataset 39 | dataroot_gt: ./datasets/REDS/val/sharp_300.lmdb 40 | dataroot_lq: ./datasets/REDS/val/blur_300.lmdb 41 | io_backend: 42 | type: lmdb 43 | 44 | 45 | network_g: 46 | type: NAFNetLocal 47 | width: 64 48 | enc_blk_nums: [1, 1, 1, 28] 49 | middle_blk_num: 1 50 | dec_blk_nums: [1, 1, 1, 1] 51 | 52 | # path 53 | path: 54 | pretrain_network_g: ~ 55 | strict_load_g: true 56 | resume_state: ~ 57 | 58 | # training settings 59 | train: 60 | optim_g: 61 | type: AdamW 62 | lr: !!float 1e-3 63 | weight_decay: !!float 1e-3 64 | betas: [0.9, 0.9] 65 | 66 | scheduler: 67 | type: TrueCosineAnnealingLR 68 | T_max: 400000 69 | eta_min: !!float 1e-7 70 | 71 | total_iter: 400000 72 | warmup_iter: -1 # no warm up 73 | 74 | # losses 75 | pixel_opt: 76 | type: PSNRLoss 77 | loss_weight: 1 78 | reduction: mean 79 | 80 | # validation settings 81 | val: 82 | val_freq: !!float 2e4 83 | save_img: false 84 | 85 | 86 | metrics: 87 | psnr: # metric name, can be arbitrary 88 | type: calculate_psnr 89 | crop_border: 0 90 | test_y_channel: false 91 | 
ssim: 92 | type: calculate_ssim 93 | crop_border: 0 94 | test_y_channel: false 95 | 96 | # logging settings 97 | logger: 98 | print_freq: 200 99 | save_checkpoint_freq: !!float 5e3 100 | use_tb_logger: true 101 | wandb: 102 | project: ~ 103 | resume_id: ~ 104 | 105 | # dist training settings 106 | dist_params: 107 | backend: nccl 108 | port: 29500 109 | -------------------------------------------------------------------------------- /options/train/SIDD/Baseline-width32.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: Baseline-SIDD-width32 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 8 12 | manual_seed: 10 13 | 14 | datasets: 15 | train: 16 | name: SIDD 17 | type: PairedImageDataset 18 | dataroot_gt: ./datasets/SIDD/train/gt_crops.lmdb 19 | dataroot_lq: ./datasets/SIDD/train/input_crops.lmdb 20 | 21 | filename_tmpl: '{}' 22 | io_backend: 23 | type: lmdb 24 | 25 | gt_size: 256 26 | use_flip: false 27 | use_rot: false 28 | 29 | # data loader 30 | use_shuffle: true 31 | num_worker_per_gpu: 4 32 | batch_size_per_gpu: 4 33 | dataset_enlarge_ratio: 1 34 | prefetch_mode: ~ 35 | 36 | val: 37 | name: SIDD_val 38 | type: PairedImageDataset 39 | dataroot_gt: ./datasets/SIDD/val/gt_crops.lmdb 40 | dataroot_lq: ./datasets/SIDD/val/input_crops.lmdb 41 | io_backend: 42 | type: lmdb 43 | 44 | 45 | network_g: 46 | type: Baseline 47 | width: 32 48 | enc_blk_nums: [2, 2, 4, 8] 49 | middle_blk_num: 12 50 | dec_blk_nums: [2, 2, 2, 2] 51 | dw_expand: 1 52 | ffn_expand: 2 53 | 54 | # path 55 | path: 56 
| pretrain_network_g: ~ 57 | strict_load_g: true 58 | resume_state: ~ 59 | 60 | # training settings 61 | train: 62 | optim_g: 63 | type: AdamW 64 | lr: !!float 1e-3 65 | weight_decay: 0. 66 | betas: [0.9, 0.9] 67 | 68 | scheduler: 69 | type: TrueCosineAnnealingLR 70 | T_max: 200000 71 | eta_min: !!float 1e-7 72 | 73 | total_iter: 200000 74 | warmup_iter: -1 # no warm up 75 | 76 | # losses 77 | pixel_opt: 78 | type: PSNRLoss 79 | loss_weight: 1 80 | reduction: mean 81 | 82 | # validation settings 83 | val: 84 | val_freq: !!float 2e4 85 | save_img: false 86 | use_image: false 87 | 88 | metrics: 89 | psnr: # metric name, can be arbitrary 90 | type: calculate_psnr 91 | crop_border: 0 92 | test_y_channel: false 93 | ssim: 94 | type: calculate_ssim 95 | crop_border: 0 96 | test_y_channel: false 97 | 98 | # logging settings 99 | logger: 100 | print_freq: 200 101 | save_checkpoint_freq: !!float 5e3 102 | use_tb_logger: true 103 | wandb: 104 | project: ~ 105 | resume_id: ~ 106 | 107 | # dist training settings 108 | dist_params: 109 | backend: nccl 110 | port: 29500 111 | -------------------------------------------------------------------------------- /options/train/SIDD/Baseline-width64.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: Baseline-SIDD-width64 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 8 12 | manual_seed: 10 13 | 14 | datasets: 15 | train: 16 | name: SIDD 17 | type: PairedImageDataset 18 | dataroot_gt: ./datasets/SIDD/train/gt_crops.lmdb 19 | dataroot_lq: ./datasets/SIDD/train/input_crops.lmdb 20 | 21 | filename_tmpl: '{}' 22 | io_backend: 23 | type: lmdb 24 | 25 | gt_size: 256 26 | use_flip: false 27 | use_rot: false 28 | 29 | # data loader 30 | use_shuffle: true 31 | num_worker_per_gpu: 8 32 | batch_size_per_gpu: 8 33 | dataset_enlarge_ratio: 1 34 | prefetch_mode: ~ 35 | 36 | val: 37 | name: SIDD_val 38 | type: PairedImageDataset 39 | dataroot_gt: ./datasets/SIDD/val/gt_crops.lmdb 40 | dataroot_lq: ./datasets/SIDD/val/input_crops.lmdb 41 | io_backend: 42 | type: lmdb 43 | 44 | 45 | network_g: 46 | type: Baseline 47 | width: 64 48 | enc_blk_nums: [2, 2, 4, 8] 49 | middle_blk_num: 12 50 | dec_blk_nums: [2, 2, 2, 2] 51 | dw_expand: 2 52 | ffn_expand: 2 53 | 54 | # path 55 | path: 56 | pretrain_network_g: ~ 57 | strict_load_g: true 58 | resume_state: ~ 59 | 60 | # training settings 61 | train: 62 | optim_g: 63 | type: AdamW 64 | lr: !!float 1e-3 65 | weight_decay: 0. 
66 | betas: [0.9, 0.9] 67 | 68 | scheduler: 69 | type: TrueCosineAnnealingLR 70 | T_max: 400000 71 | eta_min: !!float 1e-7 72 | 73 | total_iter: 400000 74 | warmup_iter: -1 # no warm up 75 | 76 | # losses 77 | pixel_opt: 78 | type: PSNRLoss 79 | loss_weight: 1 80 | reduction: mean 81 | 82 | # validation settings 83 | val: 84 | val_freq: !!float 2e4 85 | save_img: false 86 | use_image: false 87 | 88 | metrics: 89 | psnr: # metric name, can be arbitrary 90 | type: calculate_psnr 91 | crop_border: 0 92 | test_y_channel: false 93 | ssim: 94 | type: calculate_ssim 95 | crop_border: 0 96 | test_y_channel: false 97 | 98 | # logging settings 99 | logger: 100 | print_freq: 200 101 | save_checkpoint_freq: !!float 5e3 102 | use_tb_logger: true 103 | wandb: 104 | project: ~ 105 | resume_id: ~ 106 | 107 | # dist training settings 108 | dist_params: 109 | backend: nccl 110 | port: 29500 111 | -------------------------------------------------------------------------------- /options/train/SIDD/NAFNet-width32.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNet-SIDD-width32 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 8 12 | manual_seed: 10 13 | 14 | datasets: 15 | train: 16 | name: SIDD 17 | type: PairedImageDataset 18 | dataroot_gt: ./datasets/SIDD/train/gt_crops.lmdb 19 | dataroot_lq: ./datasets/SIDD/train/input_crops.lmdb 20 | 21 | filename_tmpl: '{}' 22 | io_backend: 23 | type: lmdb 24 | 25 | gt_size: 256 26 | use_flip: false 27 | use_rot: false 28 | 29 | # data loader 30 | use_shuffle: true 31 | num_worker_per_gpu: 4 32 | batch_size_per_gpu: 4 33 | dataset_enlarge_ratio: 1 34 | prefetch_mode: ~ 35 | 36 | val: 37 | name: SIDD_val 38 | type: PairedImageDataset 39 | dataroot_gt: ./datasets/SIDD/val/gt_crops.lmdb 40 | dataroot_lq: ./datasets/SIDD/val/input_crops.lmdb 41 | io_backend: 42 | type: lmdb 43 | 44 | 45 | network_g: 46 | type: NAFNet 47 | width: 32 48 | enc_blk_nums: [2, 2, 4, 8] 49 | middle_blk_num: 12 50 | dec_blk_nums: [2, 2, 2, 2] 51 | 52 | # path 53 | path: 54 | pretrain_network_g: ~ 55 | strict_load_g: true 56 | resume_state: ~ 57 | 58 | # training settings 59 | train: 60 | optim_g: 61 | type: AdamW 62 | lr: !!float 1e-3 63 | weight_decay: 0. 
64 | betas: [0.9, 0.9] 65 | 66 | scheduler: 67 | type: TrueCosineAnnealingLR 68 | T_max: 200000 69 | eta_min: !!float 1e-7 70 | 71 | total_iter: 200000 72 | warmup_iter: -1 # no warm up 73 | 74 | # losses 75 | pixel_opt: 76 | type: PSNRLoss 77 | loss_weight: 1 78 | reduction: mean 79 | 80 | # validation settings 81 | val: 82 | val_freq: !!float 2e4 83 | save_img: false 84 | use_image: false 85 | 86 | metrics: 87 | psnr: # metric name, can be arbitrary 88 | type: calculate_psnr 89 | crop_border: 0 90 | test_y_channel: false 91 | ssim: 92 | type: calculate_ssim 93 | crop_border: 0 94 | test_y_channel: false 95 | 96 | # logging settings 97 | logger: 98 | print_freq: 200 99 | save_checkpoint_freq: !!float 5e3 100 | use_tb_logger: true 101 | wandb: 102 | project: ~ 103 | resume_id: ~ 104 | 105 | # dist training settings 106 | dist_params: 107 | backend: nccl 108 | port: 29500 109 | -------------------------------------------------------------------------------- /options/train/SIDD/NAFNet-width64.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: NAFNet-SIDD-width64 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 8 12 | manual_seed: 10 13 | 14 | datasets: 15 | train: 16 | name: SIDD 17 | type: PairedImageDataset 18 | dataroot_gt: ./datasets/SIDD/train/gt_crops.lmdb 19 | dataroot_lq: ./datasets/SIDD/train/input_crops.lmdb 20 | 21 | filename_tmpl: '{}' 22 | io_backend: 23 | type: lmdb 24 | 25 | gt_size: 256 26 | use_flip: false 27 | use_rot: false 28 | 29 | # data loader 30 | use_shuffle: true 31 | num_worker_per_gpu: 8 32 | batch_size_per_gpu: 8 33 | dataset_enlarge_ratio: 1 34 | prefetch_mode: ~ 35 | 36 | val: 37 | name: SIDD_val 38 | type: PairedImageDataset 39 | dataroot_gt: ./datasets/SIDD/val/gt_crops.lmdb 40 | dataroot_lq: ./datasets/SIDD/val/input_crops.lmdb 41 | io_backend: 42 | type: lmdb 43 | 44 | 45 | network_g: 46 | type: NAFNet 47 | width: 64 48 | enc_blk_nums: [2, 2, 4, 8] 49 | middle_blk_num: 12 50 | dec_blk_nums: [2, 2, 2, 2] 51 | 52 | # path 53 | path: 54 | pretrain_network_g: ~ 55 | strict_load_g: true 56 | resume_state: ~ 57 | 58 | # training settings 59 | train: 60 | optim_g: 61 | type: AdamW 62 | lr: !!float 1e-3 63 | weight_decay: 0. 
import torch
import numpy as np
import cv2
import tempfile
import matplotlib.pyplot as plt
from cog import BasePredictor, Path, Input, BaseModel

from basicsr.models import create_model
from basicsr.utils import img2tensor as _img2tensor, tensor2img, imwrite
from basicsr.utils.options import parse


class Predictor(BasePredictor):
    """Cog predictor exposing NAFNet denoising/deblurring and NAFSSR stereo SR."""

    def setup(self):
        """Load all three restoration models once, at container start."""

        def _load(opt_path):
            # Parse the test-time option file and force single-process
            # (non-distributed) inference.
            opt = parse(opt_path, is_train=False)
            opt["dist"] = False
            return create_model(opt)

        # Keys are the user-facing task names; do not rename them — they must
        # match the `choices` below exactly (including the "Debluring" spelling).
        self.models = {
            "Image Denoising": _load("options/test/SIDD/NAFNet-width64.yml"),
            "Image Debluring": _load("options/test/GoPro/NAFNet-width64.yml"),
            "Stereo Image Super-Resolution": _load(
                "options/test/NAFSSR/NAFSSR-L_4x.yml"
            ),
        }

    def predict(
        self,
        task_type: str = Input(
            choices=[
                "Image Denoising",
                "Image Debluring",
                "Stereo Image Super-Resolution",
            ],
            default="Image Debluring",
            description="Choose task type.",
        ),
        image: Path = Input(
            description="Input image. Stereo Image Super-Resolution, upload the left image here.",
        ),
        image_r: Path = Input(
            default=None,
            description="Right Input image for Stereo Image Super-Resolution. Optional, only valid for Stereo"
            " Image Super-Resolution task.",
        ),
    ) -> Path:
        """Run the selected model on the input image(s) and return the output path."""
        out_path = Path(tempfile.mkdtemp()) / "output.png"

        model = self.models[task_type]
        if task_type == "Stereo Image Super-Resolution":
            assert image_r is not None, (
                "Please provide both left and right input image for "
                "Stereo Image Super-Resolution task."
            )
            inp_l = img2tensor(imread(str(image)))
            inp_r = img2tensor(imread(str(image_r)))
            stereo_image_inference(model, inp_l, inp_r, str(out_path))
        else:
            # BUGFIX: the original created a second temp dir here and rebound
            # out_path; one output location is enough for both branches.
            inp = img2tensor(imread(str(image)))
            single_image_inference(model, inp, str(out_path))

        return out_path


def imread(img_path):
    """Load an image from disk and return it as an RGB uint8 array."""
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img


def img2tensor(img, bgr2rgb=False, float32=True):
    """Scale a HWC uint8 image to [0, 1] and convert it to a CHW torch tensor."""
    img = img.astype(np.float32) / 255.0
    return _img2tensor(img, bgr2rgb=bgr2rgb, float32=float32)


def single_image_inference(model, img, save_path):
    """Restore one image tensor with `model` and write the result to `save_path`."""
    model.feed_data(data={"lq": img.unsqueeze(dim=0)})

    # Optional tiled ("grids") inference for large inputs, as configured in
    # the model's validation options.
    if model.opt["val"].get("grids", False):
        model.grids()

    model.test()

    if model.opt["val"].get("grids", False):
        model.grids_inverse()

    visuals = model.get_current_visuals()
    sr_img = tensor2img([visuals["result"]])
    imwrite(sr_img, save_path)


def stereo_image_inference(model, img_l, img_r, out_path):
    """Super-resolve a stereo pair and save a stacked left/right visualization."""
    # NAFSSR consumes the pair as a single 6-channel tensor.
    img = torch.cat([img_l, img_r], dim=0)
    model.feed_data(data={"lq": img.unsqueeze(dim=0)})

    if model.opt["val"].get("grids", False):
        model.grids()

    model.test()

    if model.opt["val"].get("grids", False):
        model.grids_inverse()

    visuals = model.get_current_visuals()
    # First three channels are the left view, the rest the right view.
    img_L = visuals["result"][:, :3]
    img_R = visuals["result"][:, 3:]
    img_L, img_R = tensor2img([img_L, img_R], rgb2bgr=False)

    # Save both views in one figure, left above right.
    h, w = img_L.shape[:2]
    fig = plt.figure(figsize=(w // 40, h // 40))
    ax1 = fig.add_subplot(2, 1, 1)
    plt.title("NAFSSR output (Left)", fontsize=14)
    ax1.axis("off")
    ax1.imshow(img_L)

    ax2 = fig.add_subplot(2, 1, 2)
    plt.title("NAFSSR output (Right)", fontsize=14)
    ax2.axis("off")
    ax2.imshow(img_R)

    plt.subplots_adjust(hspace=0.08)
    plt.savefig(str(out_path), bbox_inches="tight", dpi=600)
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import cv2
import numpy as np
import os
import sys
from multiprocessing import Pool
from os import path as osp
from tqdm import tqdm

from basicsr.utils import scandir
from basicsr.utils.create_lmdb import create_lmdb_for_gopro


def main():
    """Crop the GoPro train split into overlapping patches, then build lmdbs."""
    opt = {
        'n_thread': 20,
        'compression_level': 3,
        'crop_size': 512,
        'step': 256,
        'thresh_size': 0,
    }

    # Blurry inputs first, then the sharp targets, with identical settings.
    opt['input_folder'] = './datasets/GoPro/train/input'
    opt['save_folder'] = './datasets/GoPro/train/blur_crops'
    extract_subimages(opt)

    opt['input_folder'] = './datasets/GoPro/train/target'
    opt['save_folder'] = './datasets/GoPro/train/sharp_crops'
    extract_subimages(opt)

    create_lmdb_for_gopro()


def extract_subimages(opt):
    """Crop every image under ``opt['input_folder']`` into sub-images.

    Args:
        opt (dict): Configuration dict. It contains:
            input_folder (str): Path to the input folder.
            save_folder (str): Path to save folder.
            n_thread (int): Worker-process count.
    """
    src_dir = opt['input_folder']
    dst_dir = opt['save_folder']

    # Refuse to mix new patches into an existing folder.
    if osp.exists(dst_dir):
        print(f'Folder {dst_dir} already exists. Exit.')
        sys.exit(1)
    os.makedirs(dst_dir)
    print(f'mkdir {dst_dir} ...')

    paths = list(scandir(src_dir, full_path=True))

    pbar = tqdm(total=len(paths), unit='image', desc='Extract')
    pool = Pool(opt['n_thread'])
    for p in paths:
        pool.apply_async(
            worker, args=(p, opt), callback=lambda _: pbar.update(1))
    pool.close()
    pool.join()
    pbar.close()
    print('All processes done.')


def worker(path, opt):
    """Crop one image into overlapping patches and write them as PNGs.

    Args:
        path (str): Image path.
        opt (dict): Configuration dict. It contains:
            crop_size (int): Patch side length.
            step (int): Step of the overlapped sliding window.
            thresh_size (int): Leftover margins larger than this get an
                extra edge-aligned patch.
            save_folder (str): Path to save folder.
            compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.

    Returns:
        str: Progress message for the pool callback.
    """
    crop_size = opt['crop_size']
    step = opt['step']
    thresh_size = opt['thresh_size']
    img_name, extension = osp.splitext(osp.basename(path))

    # Strip DIV2K-style scale suffixes from the stem.
    for token in ('x2', 'x3', 'x4', 'x8'):
        img_name = img_name.replace(token, '')

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)

    if img.ndim not in (2, 3):
        raise ValueError(f'Image ndim should be 2 or 3, but got {img.ndim}')
    h, w = img.shape[:2]

    def _starts(length):
        # Regular grid of window starts, plus one edge-aligned start when the
        # remaining margin exceeds thresh_size.
        pos = np.arange(0, length - crop_size + 1, step)
        if length - (pos[-1] + crop_size) > thresh_size:
            pos = np.append(pos, length - crop_size)
        return pos

    row_starts = _starts(h)
    col_starts = _starts(w)

    index = 0
    for top in row_starts:
        for left in col_starts:
            index += 1
            patch = np.ascontiguousarray(
                img[top:top + crop_size, left:left + crop_size, ...])
            cv2.imwrite(
                osp.join(opt['save_folder'],
                         f'{img_name}_s{index:03d}{extension}'),
                patch,
                [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
    return f'Processing {img_name} ...'


if __name__ == '__main__':
    main()
# ------------------------------------------------------------------------
# Copyright (c) 2021 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
'''
for val set, extract the subset val-300

'''
import os
import time  # kept for backward compatibility with earlier revisions


def make_val_300(folder, dst):
    """Copy every frame whose index ends in 9 from ``folder`` into ``dst``.

    Matches the former ``cp folder/*9.* dst`` shell call, but shell-free:
    a failed copy now raises instead of being silently ignored.

    Args:
        folder (str): Source directory of flattened frames.
        dst (str): Destination directory; created if missing.
    """
    import glob
    import shutil

    if not os.path.exists(dst):
        os.mkdir(dst)
    for src in glob.glob(os.path.join(folder, '*9.*')):
        shutil.copy(src, dst)


def flatten_folders(folder):
    """Flatten ``folder/<vid>/<frame>.<ext>`` into ``folder/<vid>_<frame>.<ext>``.

    Expects up to 300 numbered video folders ('000'..'299'), each holding
    exactly 100 frames ('00000000'..'00000099') with a .jpg or .png suffix.
    Each frame is moved up one level and the emptied video folder is removed.

    Replaces the former ``os.system('mv ...')`` / ``os.system('rm -r ...')``
    calls (which ignored failures and needed a sleep to throttle the shell)
    with ``os.replace`` / ``shutil.rmtree``, which raise on error.

    Raises:
        FileNotFoundError: If an expected frame has neither suffix.
    """
    import shutil

    for vid in range(300):
        vid_dir = '{:03}'.format(vid)
        if not os.path.exists(os.path.join(folder, vid_dir)):
            continue

        print('working on .. {} .. {}'.format(folder, vid))
        for fid in range(100):
            stem = '{:08}'.format(fid)

            suffix = None
            for suf in ('.jpg', '.png'):
                if os.path.exists(os.path.join(folder, vid_dir, stem + suf)):
                    suffix = suf
                    break
            # Was a bare `assert`; raise so the check survives `python -O`.
            if suffix is None:
                raise FileNotFoundError(
                    'missing frame {} in {}'.format(
                        stem, os.path.join(folder, vid_dir)))

            src = os.path.join(folder, vid_dir, stem + suffix)
            dst = os.path.join(folder, '{}_{}{}'.format(vid_dir, stem, suffix))
            os.replace(src, dst)  # same filesystem, so this is atomic
        shutil.rmtree(os.path.join(folder, vid_dir))


if __name__ == '__main__':
    # Deferred: importing basicsr is only needed when run as a script, and
    # keeping it here lets the helpers above be imported standalone.
    from basicsr.utils.create_lmdb import create_lmdb_for_reds

    flatten_folders('./datasets/REDS/train/train_blur_jpeg')
    flatten_folders('./datasets/REDS/train/train_sharp')

    # flatten_folders('./datasets/REDS/val/val_blur_jpeg')
    # flatten_folders('./datasets/REDS/val/val_sharp')
    # make_val_300('./datasets/REDS/val/val_blur_jpeg', './datasets/REDS/val/blur_300')
    # make_val_300('./datasets/REDS/val/val_sharp', './datasets/REDS/val/sharp_300')

    create_lmdb_for_reds()
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import cv2
import numpy as np
import os
import sys
from multiprocessing import Pool
from os import path as osp
from tqdm import tqdm

from basicsr.utils import scandir_SIDD
from basicsr.utils.create_lmdb import create_lmdb_for_SIDD


def main():
    """Crop SIDD noisy/clean image pairs into patches, then build the lmdbs."""
    opt = {
        'n_thread': 20,
        'compression_level': 3,
        'input_folder': './datasets/SIDD/Data',
        'crop_size': 512,
        'step': 384,
        'thresh_size': 0,
    }

    # Noisy inputs first, then ground-truth, selected by filename keyword.
    opt['save_folder'] = './datasets/SIDD/train/input_crops'
    opt['keywords'] = '_NOISY'
    extract_subimages(opt)

    opt['save_folder'] = './datasets/SIDD/train/gt_crops'
    opt['keywords'] = '_GT'
    extract_subimages(opt)

    create_lmdb_for_SIDD()


def extract_subimages(opt):
    """Crop every matching image under ``opt['input_folder']`` into sub-images.

    Args:
        opt (dict): Configuration dict. It contains:
            input_folder (str): Path to the input folder.
            save_folder (str): Path to save folder.
            keywords (str): Filename keyword used to select images.
            n_thread (int): Worker-process count.
    """
    src_dir = opt['input_folder']
    dst_dir = opt['save_folder']
    if not osp.exists(dst_dir):
        os.makedirs(dst_dir)
        print(f'mkdir {dst_dir} ...')
    else:
        # Unlike the GoPro script this does NOT exit: the original
        # deliberately commented out sys.exit(1) and keeps going.
        print(f'Folder {dst_dir} already exists. Exit.')

    paths = list(
        scandir_SIDD(
            src_dir, keywords=opt['keywords'], recursive=True,
            full_path=True))

    pbar = tqdm(total=len(paths), unit='image', desc='Extract')
    pool = Pool(opt['n_thread'])
    for p in paths:
        pool.apply_async(
            worker, args=(p, opt), callback=lambda _: pbar.update(1))
    pool.close()
    pool.join()
    pbar.close()
    print('All processes done.')


def worker(path, opt):
    """Crop one image into overlapping patches and write them as PNGs.

    Args:
        path (str): Image path.
        opt (dict): Configuration dict. It contains:
            crop_size (int): Patch side length.
            step (int): Step of the overlapped sliding window.
            thresh_size (int): Leftover margins larger than this get an
                extra edge-aligned patch.
            keywords (str): Keyword stripped from the output file stem.
            save_folder (str): Path to save folder.
            compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.

    Returns:
        str: Progress message for the pool callback.
    """
    crop_size = opt['crop_size']
    step = opt['step']
    thresh_size = opt['thresh_size']
    img_name, extension = osp.splitext(osp.basename(path))

    img_name = img_name.replace(opt['keywords'], '')

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)

    if img.ndim not in (2, 3):
        raise ValueError(f'Image ndim should be 2 or 3, but got {img.ndim}')
    h, w = img.shape[:2]

    def _starts(length):
        # Regular grid of window starts, plus one edge-aligned start when the
        # remaining margin exceeds thresh_size.
        pos = np.arange(0, length - crop_size + 1, step)
        if length - (pos[-1] + crop_size) > thresh_size:
            pos = np.append(pos, length - crop_size)
        return pos

    row_starts = _starts(h)
    col_starts = _starts(w)

    index = 0
    for top in row_starts:
        for left in col_starts:
            index += 1
            patch = np.ascontiguousarray(
                img[top:top + crop_size, left:left + crop_size, ...])
            cv2.imwrite(
                osp.join(opt['save_folder'],
                         f'{img_name}_s{index:03d}{extension}'),
                patch,
                [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
    return f'Processing {img_name} ...'


if __name__ == '__main__':
    main()
    # make sidd to lmdb
"""Pack the NTIRE22 stereo-SR training pairs into two pickle files.

Reads 800 left/right PNG pairs for both the LR_x4 and HR folders, stacks
each pair along the channel axis (so one sample is a single 6-channel
array), and dumps the LR and HR lists to ./datasets/*.pickle.
"""
import os
import pickle

import cv2
import numpy as np

PATH = './datasets/SR/NTIRE22-StereoSR/Train'

LR_FOLDER = 'LR_x4'
HR_FOLDER = 'HR'


lr_lists = []
hr_lists = []

cnt = 0

for idx in range(1, 801):

    L_name = f'{idx:04}_L.png'
    R_name = f'{idx:04}_R.png'

    LR_L = cv2.imread(os.path.join(PATH, LR_FOLDER, L_name))
    LR_R = cv2.imread(os.path.join(PATH, LR_FOLDER, R_name))

    HR_L = cv2.imread(os.path.join(PATH, HR_FOLDER, L_name))
    HR_R = cv2.imread(os.path.join(PATH, HR_FOLDER, R_name))

    # Concatenate the left/right views along the channel axis so each
    # sample travels as one array.
    LR = np.concatenate([LR_L, LR_R], axis=-1)
    HR = np.concatenate([HR_L, HR_R], axis=-1)

    lr_lists.append(LR)
    hr_lists.append(HR)

    cnt = cnt + 1
    if cnt % 50 == 0:
        print(f'cnt .. {cnt}, idx: {idx}')


# `import pickle` used to sit here, mid-file; hoisted to the top (PEP 8).
with open('./datasets/ntire-stereo-sr.train.lr.pickle', 'wb') as f:
    pickle.dump(lr_lists, f)

with open('./datasets/ntire-stereo-sr.train.hr.pickle', 'wb') as f:
    pickle.dump(hr_lists, f)
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | #!/usr/bin/env python 8 | 9 | from setuptools import find_packages, setup 10 | 11 | import os 12 | import subprocess 13 | import sys 14 | import time 15 | import torch 16 | from torch.utils.cpp_extension import (BuildExtension, CppExtension, 17 | CUDAExtension) 18 | 19 | version_file = 'basicsr/version.py' 20 | 21 | 22 | def readme(): 23 | return '' 24 | # with open('README.md', encoding='utf-8') as f: 25 | # content = f.read() 26 | # return content 27 | 28 | 29 | def get_git_hash(): 30 | 31 | def _minimal_ext_cmd(cmd): 32 | # construct minimal environment 33 | env = {} 34 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 35 | v = os.environ.get(k) 36 | if v is not None: 37 | env[k] = v 38 | # LANGUAGE is used on win32 39 | env['LANGUAGE'] = 'C' 40 | env['LANG'] = 'C' 41 | env['LC_ALL'] = 'C' 42 | out = subprocess.Popen( 43 | cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 44 | return out 45 | 46 | try: 47 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 48 | sha = out.strip().decode('ascii') 49 | except OSError: 50 | sha = 'unknown' 51 | 52 | return sha 53 | 54 | 55 | def get_hash(): 56 | if os.path.exists('.git'): 57 | sha = get_git_hash()[:7] 58 | elif os.path.exists(version_file): 59 | try: 60 | from basicsr.version import __version__ 61 | sha = __version__.split('+')[-1] 62 | except ImportError: 63 | raise ImportError('Unable to get git version') 64 | else: 65 | sha = 'unknown' 66 | 67 | return sha 68 | 69 | 70 | def write_version_py(): 71 | content = """# GENERATED VERSION FILE 72 | # TIME: {} 73 | __version__ = '{}' 74 | short_version = '{}' 75 | version_info = ({}) 76 | """ 77 | sha = get_hash() 78 | with open('VERSION', 'r') as f: 79 | SHORT_VERSION = f.read().strip() 80 | VERSION_INFO 
= ', '.join( 81 | [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) 82 | VERSION = SHORT_VERSION + '+' + sha 83 | 84 | version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION, 85 | VERSION_INFO) 86 | with open(version_file, 'w') as f: 87 | f.write(version_file_str) 88 | 89 | 90 | def get_version(): 91 | with open(version_file, 'r') as f: 92 | exec(compile(f.read(), version_file, 'exec')) 93 | return locals()['__version__'] 94 | 95 | 96 | def make_cuda_ext(name, module, sources, sources_cuda=None): 97 | if sources_cuda is None: 98 | sources_cuda = [] 99 | define_macros = [] 100 | extra_compile_args = {'cxx': []} 101 | 102 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 103 | define_macros += [('WITH_CUDA', None)] 104 | extension = CUDAExtension 105 | extra_compile_args['nvcc'] = [ 106 | '-D__CUDA_NO_HALF_OPERATORS__', 107 | '-D__CUDA_NO_HALF_CONVERSIONS__', 108 | '-D__CUDA_NO_HALF2_OPERATORS__', 109 | ] 110 | sources += sources_cuda 111 | else: 112 | print(f'Compiling {name} without CUDA') 113 | extension = CppExtension 114 | 115 | return extension( 116 | name=f'{module}.{name}', 117 | sources=[os.path.join(*module.split('.'), p) for p in sources], 118 | define_macros=define_macros, 119 | extra_compile_args=extra_compile_args) 120 | 121 | 122 | def get_requirements(filename='requirements.txt'): 123 | return [] 124 | here = os.path.dirname(os.path.realpath(__file__)) 125 | with open(os.path.join(here, filename), 'r') as f: 126 | requires = [line.replace('\n', '') for line in f.readlines()] 127 | return requires 128 | 129 | 130 | if __name__ == '__main__': 131 | if '--no_cuda_ext' in sys.argv: 132 | ext_modules = [] 133 | sys.argv.remove('--no_cuda_ext') 134 | else: 135 | ext_modules = [ 136 | make_cuda_ext( 137 | name='deform_conv_ext', 138 | module='basicsr.models.ops.dcn', 139 | sources=['src/deform_conv_ext.cpp'], 140 | sources_cuda=[ 141 | 'src/deform_conv_cuda.cpp', 142 | 'src/deform_conv_cuda_kernel.cu' 
143 | ]), 144 | make_cuda_ext( 145 | name='fused_act_ext', 146 | module='basicsr.models.ops.fused_act', 147 | sources=['src/fused_bias_act.cpp'], 148 | sources_cuda=['src/fused_bias_act_kernel.cu']), 149 | make_cuda_ext( 150 | name='upfirdn2d_ext', 151 | module='basicsr.models.ops.upfirdn2d', 152 | sources=['src/upfirdn2d.cpp'], 153 | sources_cuda=['src/upfirdn2d_kernel.cu']), 154 | ] 155 | 156 | write_version_py() 157 | setup( 158 | name='basicsr', 159 | version=get_version(), 160 | description='Open Source Image and Video Super-Resolution Toolbox', 161 | long_description=readme(), 162 | author='Xintao Wang', 163 | author_email='xintao.wang@outlook.com', 164 | keywords='computer vision, restoration, super resolution', 165 | url='https://github.com/xinntao/BasicSR', 166 | packages=find_packages( 167 | exclude=('options', 'datasets', 'experiments', 'results', 168 | 'tb_logger', 'wandb')), 169 | classifiers=[ 170 | 'Development Status :: 4 - Beta', 171 | 'License :: OSI Approved :: Apache Software License', 172 | 'Operating System :: OS Independent', 173 | 'Programming Language :: Python :: 3', 174 | 'Programming Language :: Python :: 3.7', 175 | 'Programming Language :: Python :: 3.8', 176 | ], 177 | license='Apache License 2.0', 178 | setup_requires=['cython', 'numpy'], 179 | install_requires=get_requirements(), 180 | ext_modules=ext_modules, 181 | cmdclass={'build_ext': BuildExtension}, 182 | zip_safe=False) 183 | --------------------------------------------------------------------------------