├── .gitignore ├── LICENSE ├── README.md ├── VERSION ├── basicsr ├── __init__.py ├── archs │ ├── __init__.py │ ├── arch_util.py │ ├── art_arch.py │ └── artunet_arch.py ├── data │ ├── __init__.py │ ├── data_sampler.py │ ├── data_util.py │ ├── paired_image_dataset.py │ ├── prefetch_dataloader.py │ ├── single_image_dataset.py │ └── transforms.py ├── losses │ ├── __init__.py │ ├── loss_util.py │ └── losses.py ├── metrics │ ├── __init__.py │ ├── metric_util.py │ └── psnr_ssim.py ├── models │ ├── __init__.py │ ├── art_model.py │ ├── base_model.py │ ├── lr_scheduler.py │ └── sr_model.py ├── test.py ├── train.py ├── utils │ ├── __init__.py │ ├── dist_util.py │ ├── file_client.py │ ├── img_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ ├── options.py │ └── registry.py └── version.py ├── datasets ├── README.md └── example │ └── example.png ├── experiments └── pretrained_models │ └── README.md ├── figs ├── CAR.png ├── ColorDN.png ├── ComS_img_092_ART_x4.png ├── ComS_img_092_Bicubic_x4.png ├── ComS_img_092_HR_x4.png ├── ComS_img_092_SwinIR_x4.png ├── ComS_img_098_ART_x4.png ├── ComS_img_098_Bicubic_x4.png ├── ComS_img_098_HR_x4.png ├── ComS_img_098_SwinIR_x4.png ├── RealDN.png ├── Resize_ComL_img_092_HR_x4.png ├── Resize_ComL_img_098_HR_x4.png ├── SR.png ├── Visual_DN.png ├── Visual_SR_1.png ├── Visual_SR_2.png └── git.png ├── options ├── apply │ ├── test_ART_SR_x2_without_groundTruth.yml │ ├── test_ART_SR_x3_without_groundTruth.yml │ └── test_ART_SR_x4_without_groundTruth.yml ├── test │ ├── test_ART_CAR_q10.yml │ ├── test_ART_CAR_q30.yml │ ├── test_ART_CAR_q40.yml │ ├── test_ART_ColorDN_level15.yml │ ├── test_ART_ColorDN_level25.yml │ ├── test_ART_ColorDN_level50.yml │ ├── test_ART_SR_x2.yml │ ├── test_ART_SR_x3.yml │ ├── test_ART_SR_x4.yml │ ├── test_ART_S_SR_x2.yml │ ├── test_ART_S_SR_x3.yml │ └── test_ART_S_SR_x4.yml └── train │ ├── train_ART_CAR_q10.yml │ ├── train_ART_CAR_q30.yml │ ├── train_ART_CAR_q40.yml │ ├── train_ART_ColorDN_level15.yml │ ├── train_ART_ColorDN_level25.yml │ ├── train_ART_ColorDN_level50.yml │ ├── train_ART_SR_x2.yml │ ├── train_ART_SR_x3.yml │ ├── train_ART_SR_x4.yml │ ├── train_ART_S_SR_x2.yml │ ├── train_ART_S_SR_x3.yml │ └── train_ART_S_SR_x4.yml ├── realDenoising ├── README.md ├── VERSION ├── basicsr │ ├── data │ │ ├── __init__.py │ │ ├── data_sampler.py │ │ ├── data_util.py │ │ ├── ffhq_dataset.py │ │ ├── meta_info │ │ │ ├── meta_info_DIV2K800sub_GT.txt │ │ │ ├── meta_info_REDS4_test_GT.txt │ │ │ ├── meta_info_REDS_GT.txt │ │ │ ├── meta_info_REDSofficial4_test_GT.txt │ │ │ ├── meta_info_REDSval_official_test_GT.txt │ │ │ ├── meta_info_Vimeo90K_test_GT.txt │ │ │ ├── meta_info_Vimeo90K_test_fast_GT.txt │ │ │ ├── meta_info_Vimeo90K_test_medium_GT.txt │ │ │ ├── meta_info_Vimeo90K_test_slow_GT.txt │ │ │ └── meta_info_Vimeo90K_train_GT.txt │ │ ├── paired_image_dataset.py │ │ ├── prefetch_dataloader.py │ │ ├── reds_dataset.py │ │ ├── single_image_dataset.py │ │ ├── transforms.py │ │ ├── video_test_dataset.py │ │ └── vimeo90k_dataset.py │ ├── metrics │ │ ├── __init__.py │ │ ├── fid.py │ │ ├── metric_util.py │ │ ├── niqe.py │ │ ├── niqe_pris_params.npz │ │ └── psnr_ssim.py │ ├── models │ │ ├── __init__.py │ │ ├── archs │ │ │ ├── __init__.py │ │ │ ├── arch_util.py │ │ │ └── artunet_arch.py │ │ ├── base_model.py │ │ ├── image_restoration_model.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── loss_util.py │ │ │ └── losses.py │ │ └── lr_scheduler.py │ ├── test.py │ ├── train.py │ ├── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ └── 
__init__.cpython-38.pyc │ │ ├── bundle_submissions.py │ │ ├── create_lmdb.py │ │ ├── dist_util.py │ │ ├── download_util.py │ │ ├── face_util.py │ │ ├── file_client.py │ │ ├── flow_util.py │ │ ├── img_util.py │ │ ├── lmdb_util.py │ │ ├── logger.py │ │ ├── matlab_functions.py │ │ ├── misc.py │ │ └── options.py │ └── version.py ├── evaluate_sidd.m ├── options │ └── train_ART_RealDN.yml ├── setup.cfg ├── setup.py ├── test_real_denoising_dnd.py ├── test_real_denoising_sidd.py └── utils.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Build and Release Folders 2 | datasets/* 3 | experiments/* 4 | results/* 5 | tb_logger/* 6 | wandb/* 7 | tmp/* 8 | 9 | # Other files and folders 10 | .settings/ 11 | 12 | # Executables 13 | *.swf 14 | *.air 15 | *.ipa 16 | *.apk 17 | 18 | #other files 19 | *.html 20 | *.png 21 | *.jpeg 22 | *.jpg 23 | *.gif 24 | *.pth 25 | *.zip 26 | 27 | # Project files, i.e. `.project`, `.actionScriptProperties` and `.flexProperties` 28 | # should NOT be excluded as they contain compiler settings and other important 29 | # information for Eclipse / Flash Builder. 30 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 1.3.5 2 | -------------------------------------------------------------------------------- /basicsr/__init__.py: -------------------------------------------------------------------------------- 1 | from .archs import * 2 | from .data import * 3 | from .metrics import * 4 | from .models import * 5 | from .test import * 6 | from .utils import * 7 | -------------------------------------------------------------------------------- /basicsr/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ['build_network'] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with 12 | # '_arch.py' 13 | arch_folder = osp.dirname(osp.abspath(__file__)) 14 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 15 | # import all the arch modules 16 | _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] 17 | 18 | 19 | def build_network(opt): 20 | opt = deepcopy(opt) 21 | network_type = opt.pop('type') 22 | net = ARCH_REGISTRY.get(network_type)(**opt) 23 | logger = get_root_logger() 24 | logger.info(f'Network [{net.__class__.__name__}] is created.') 25 | return net 26 | -------------------------------------------------------------------------------- /basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from copy import deepcopy 7 | from functools import partial 8 | from os import path as osp 9 | 10 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 11 | from basicsr.utils import get_root_logger, scandir 12 | from basicsr.utils.dist_util import get_dist_info 13 | from basicsr.utils.registry import DATASET_REGISTRY 14 | 15 | __all__ = ['build_dataset', 
'build_dataloader'] 16 | 17 | # automatically scan and import dataset modules for registry 18 | # scan all the files under the data folder with '_dataset' in file names 19 | data_folder = osp.dirname(osp.abspath(__file__)) 20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 21 | # import all the dataset modules 22 | _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] 23 | 24 | 25 | def build_dataset(dataset_opt): 26 | """Build dataset from options. 27 | 28 | Args: 29 | dataset_opt (dict): Configuration for dataset. It must contain: 30 | name (str): Dataset name. 31 | type (str): Dataset type. 32 | """ 33 | dataset_opt = deepcopy(dataset_opt) 34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) 35 | logger = get_root_logger() 36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') 37 | return dataset 38 | 39 | 40 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): 41 | """Build dataloader. 42 | 43 | Args: 44 | dataset (torch.utils.data.Dataset): Dataset. 45 | dataset_opt (dict): Dataset options. It contains the following keys: 46 | phase (str): 'train', 'val' or 'test'. 47 | num_worker_per_gpu (int): Number of workers for each GPU. 48 | batch_size_per_gpu (int): Training batch size for each GPU. 49 | num_gpu (int): Number of GPUs. Used only in the train phase. 50 | Default: 1. 51 | dist (bool): Whether in distributed training. Used only in the train 52 | phase. Default: False. 53 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 54 | seed (int | None): Seed. Default: None. 55 | """ 56 | phase = dataset_opt['phase'] 57 | rank, _ = get_dist_info() 58 | if phase == 'train': 59 | if dist: # distributed training 60 | batch_size = dataset_opt['batch_size_per_gpu'] 61 | num_workers = dataset_opt['num_worker_per_gpu'] 62 | else: # non-distributed training 63 | multiplier = 1 if num_gpu == 0 else num_gpu 64 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 65 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 66 | dataloader_args = dict( 67 | dataset=dataset, 68 | batch_size=batch_size, 69 | shuffle=False, 70 | num_workers=num_workers, 71 | sampler=sampler, 72 | drop_last=True) 73 | if sampler is None: 74 | dataloader_args['shuffle'] = True 75 | dataloader_args['worker_init_fn'] = partial( 76 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None 77 | elif phase in ['val', 'test']: # validation 78 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 79 | else: 80 | raise ValueError(f"Wrong dataset phase: {phase}. 
Supported ones are 'train', 'val' and 'test'.") 81 | 82 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 83 | dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) 84 | 85 | prefetch_mode = dataset_opt.get('prefetch_mode') 86 | if prefetch_mode == 'cpu': # CPUPrefetcher 87 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 88 | logger = get_root_logger() 89 | logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') 90 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) 91 | else: 92 | # prefetch_mode=None: Normal dataloader 93 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 94 | return torch.utils.data.DataLoader(**dataloader_args) 95 | 96 | 97 | def worker_init_fn(worker_id, num_workers, rank, seed): 98 | # Set the worker seed to num_workers * rank + worker_id + seed 99 | worker_seed = num_workers * rank + worker_id + seed 100 | np.random.seed(worker_seed) 101 | random.seed(worker_seed) 102 | -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restarting the dataloader after each epoch. 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1. 19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | def __iter__(self): 30 | # deterministically shuffle based on epoch 31 | g = torch.Generator() 32 | g.manual_seed(self.epoch) 33 | indices = torch.randperm(self.total_size, generator=g).tolist() 34 | 35 | dataset_size = len(self.dataset) 36 | indices = [v % dataset_size for v in indices] 37 | 38 | # subsample 39 | indices = indices[self.rank:self.total_size:self.num_replicas] 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | def set_epoch(self, epoch): 48 | self.epoch = epoch 49 | -------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Ref: 11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | 13 | Args: 14 | generator: Python generator. 15 | num_prefetch_queue (int): Size of the prefetch queue. 
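A larger queue lets the background thread buffer more items ahead of consumption, at the cost of memory.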
16 | """ 17 | 18 | def __init__(self, generator, num_prefetch_queue): 19 | threading.Thread.__init__(self) 20 | self.queue = Queue.Queue(num_prefetch_queue) 21 | self.generator = generator 22 | self.daemon = True 23 | self.start() 24 | 25 | def run(self): 26 | for item in self.generator: 27 | self.queue.put(item) 28 | self.queue.put(None) 29 | 30 | def __next__(self): 31 | next_item = self.queue.get() 32 | if next_item is None: 33 | raise StopIteration 34 | return next_item 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | 40 | class PrefetchDataLoader(DataLoader): 41 | """Prefetch version of dataloader. 42 | 43 | Ref: 44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 45 | 46 | TODO: 47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 48 | ddp. 49 | 50 | Args: 51 | num_prefetch_queue (int): Number of prefetch queue. 52 | kwargs (dict): Other arguments for dataloader. 53 | """ 54 | 55 | def __init__(self, num_prefetch_queue, **kwargs): 56 | self.num_prefetch_queue = num_prefetch_queue 57 | super(PrefetchDataLoader, self).__init__(**kwargs) 58 | 59 | def __iter__(self): 60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 61 | 62 | 63 | class CPUPrefetcher(): 64 | """CPU prefetcher. 65 | 66 | Args: 67 | loader: Dataloader. 68 | """ 69 | 70 | def __init__(self, loader): 71 | self.ori_loader = loader 72 | self.loader = iter(loader) 73 | 74 | def next(self): 75 | try: 76 | return next(self.loader) 77 | except StopIteration: 78 | return None 79 | 80 | def reset(self): 81 | self.loader = iter(self.ori_loader) 82 | 83 | 84 | class CUDAPrefetcher(): 85 | """CUDA prefetcher. 86 | 87 | Ref: 88 | https://github.com/NVIDIA/apex/issues/304# 89 | 90 | It may consums more GPU memory. 91 | 92 | Args: 93 | loader: Dataloader. 94 | opt (dict): Options. 95 | """ 96 | 97 | def __init__(self, loader, opt): 98 | self.ori_loader = loader 99 | self.loader = iter(loader) 100 | self.opt = opt 101 | self.stream = torch.cuda.Stream() 102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 103 | self.preload() 104 | 105 | def preload(self): 106 | try: 107 | self.batch = next(self.loader) # self.batch is a dict 108 | except StopIteration: 109 | self.batch = None 110 | return None 111 | # put tensors to gpu 112 | with torch.cuda.stream(self.stream): 113 | for k, v in self.batch.items(): 114 | if torch.is_tensor(v): 115 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 116 | 117 | def next(self): 118 | torch.cuda.current_stream().wait_stream(self.stream) 119 | batch = self.batch 120 | self.preload() 121 | return batch 122 | 123 | def reset(self): 124 | self.loader = iter(self.ori_loader) 125 | self.preload() 126 | -------------------------------------------------------------------------------- /basicsr/data/single_image_dataset.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paths_from_lmdb 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir 7 | from basicsr.utils.matlab_functions import rgb2ycbcr 8 | from basicsr.utils.registry import DATASET_REGISTRY 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class SingleImageDataset(data.Dataset): 13 | """Read only lq images in the test phase. 14 | 15 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). 
16 | 17 | There are two modes: 18 | 1. 'meta_info_file': Use meta information file to generate paths. 19 | 2. 'folder': Scan folders to generate paths. 20 | 21 | Args: 22 | opt (dict): Config for the dataset. It contains the following keys: 23 | dataroot_lq (str): Data root path for lq. 24 | meta_info_file (str): Path for meta information file. 25 | io_backend (dict): IO backend type and other kwarg. 26 | """ 27 | 28 | def __init__(self, opt): 29 | super(SingleImageDataset, self).__init__() 30 | self.opt = opt 31 | # file client (io backend) 32 | self.file_client = None 33 | self.io_backend_opt = opt['io_backend'] 34 | self.mean = opt['mean'] if 'mean' in opt else None 35 | self.std = opt['std'] if 'std' in opt else None 36 | self.lq_folder = opt['dataroot_lq'] 37 | 38 | if self.io_backend_opt['type'] == 'lmdb': 39 | self.io_backend_opt['db_paths'] = [self.lq_folder] 40 | self.io_backend_opt['client_keys'] = ['lq'] 41 | self.paths = paths_from_lmdb(self.lq_folder) 42 | elif 'meta_info_file' in self.opt: 43 | with open(self.opt['meta_info_file'], 'r') as fin: 44 | self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin] 45 | else: 46 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) 47 | 48 | def __getitem__(self, index): 49 | if self.file_client is None: 50 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 51 | 52 | # load lq image 53 | lq_path = self.paths[index] 54 | img_bytes = self.file_client.get(lq_path, 'lq') 55 | img_lq = imfrombytes(img_bytes, float32=True) 56 | 57 | # color space transform 58 | if 'color' in self.opt and self.opt['color'] == 'y': 59 | img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] 60 | 61 | # BGR to RGB, HWC to CHW, numpy to tensor 62 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) 63 | # normalize 64 | if self.mean is not None or self.std is not None: 65 | normalize(img_lq, self.mean, self.std, inplace=True) 66 | return {'lq': img_lq, 'lq_path': lq_path} 67 | 68 | def __len__(self): 69 | return len(self.paths) 70 | -------------------------------------------------------------------------------- /basicsr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from basicsr.utils import get_root_logger 4 | from basicsr.utils.registry import LOSS_REGISTRY 5 | from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, WeightedTVLoss, g_path_regularize, 6 | gradient_penalty_loss, r1_penalty) 7 | 8 | __all__ = [ 9 | 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'GANLoss', 'gradient_penalty_loss', 10 | 'r1_penalty', 'g_path_regularize' 11 | ] 12 | 13 | 14 | def build_loss(opt): 15 | """Build loss from options. 16 | 17 | Args: 18 | opt (dict): Configuration. It must contain: 19 | type (str): Loss type. 20 | """ 21 | opt = deepcopy(opt) 22 | loss_type = opt.pop('type') 23 | loss = LOSS_REGISTRY.get(loss_type)(**opt) 24 | logger = get_root_logger() 25 | logger.info(f'Loss [{loss.__class__.__name__}] is created.') 26 | return loss 27 | -------------------------------------------------------------------------------- /basicsr/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from torch.nn import functional as F 3 | 4 | 5 | def reduce_loss(loss, reduction): 6 | """Reduce loss as specified. 7 | 8 | Args: 9 | loss (Tensor): Elementwise loss tensor. 
10 | reduction (str): Options are 'none', 'mean' and 'sum'. 11 | 12 | Returns: 13 | Tensor: Reduced loss tensor. 14 | """ 15 | reduction_enum = F._Reduction.get_enum(reduction) 16 | # none: 0, elementwise_mean:1, sum: 2 17 | if reduction_enum == 0: 18 | return loss 19 | elif reduction_enum == 1: 20 | return loss.mean() 21 | else: 22 | return loss.sum() 23 | 24 | 25 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 26 | """Apply element-wise weight and reduce loss. 27 | 28 | Args: 29 | loss (Tensor): Element-wise loss. 30 | weight (Tensor): Element-wise weights. Default: None. 31 | reduction (str): Same as built-in losses of PyTorch. Options are 32 | 'none', 'mean' and 'sum'. Default: 'mean'. 33 | 34 | Returns: 35 | Tensor: Loss values. 36 | """ 37 | # if weight is specified, apply element-wise weight 38 | if weight is not None: 39 | assert weight.dim() == loss.dim() 40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 41 | loss = loss * weight 42 | 43 | # if weight is not specified or reduction is sum, just reduce the loss 44 | if weight is None or reduction == 'sum': 45 | loss = reduce_loss(loss, reduction) 46 | # if reduction is mean, then compute mean over weight region 47 | elif reduction == 'mean': 48 | if weight.size(1) > 1: 49 | weight = weight.sum() 50 | else: 51 | weight = weight.sum() * loss.size(1) 52 | loss = loss.sum() / weight 53 | 54 | return loss 55 | 56 | 57 | def weighted_loss(loss_func): 58 | """Create a weighted version of a given loss function. 59 | 60 | To use this decorator, the loss function must have the signature like 61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 62 | element-wise loss without any reduction. This decorator will add weight 63 | and reduction arguments to the function. The decorated function will have 64 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 65 | **kwargs)`. 66 | 67 | :Example: 68 | 69 | >>> import torch 70 | >>> @weighted_loss 71 | >>> def l1_loss(pred, target): 72 | >>> return (pred - target).abs() 73 | 74 | >>> pred = torch.Tensor([0, 2, 3]) 75 | >>> target = torch.Tensor([1, 1, 1]) 76 | >>> weight = torch.Tensor([1, 0, 1]) 77 | 78 | >>> l1_loss(pred, target) 79 | tensor(1.3333) 80 | >>> l1_loss(pred, target, weight) 81 | tensor(1.5000) 82 | >>> l1_loss(pred, target, reduction='none') 83 | tensor([1., 1., 2.]) 84 | >>> l1_loss(pred, target, weight, reduction='sum') 85 | tensor(3.) 86 | """ 87 | 88 | @functools.wraps(loss_func) 89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 90 | # get element-wise loss 91 | loss = loss_func(pred, target, **kwargs) 92 | loss = weight_reduce_loss(loss, weight, reduction) 93 | return loss 94 | 95 | return wrapper 96 | -------------------------------------------------------------------------------- /basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from basicsr.utils.registry import METRIC_REGISTRY 4 | from .psnr_ssim import calculate_psnr, calculate_ssim 5 | 6 | __all__ = ['calculate_psnr', 'calculate_ssim'] 7 | 8 | 9 | def calculate_metric(data, opt): 10 | """Calculate metric from data and options. 11 | 12 | Args: 13 | opt (dict): Configuration. It must contain: 14 | type (str): Metric type. 
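data (dict): The metric inputs (e.g. img, img2), forwarded to the metric function as keyword arguments.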
15 | """ 16 | opt = deepcopy(opt) 17 | metric_type = opt.pop('type') 18 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) 19 | return metric 20 | -------------------------------------------------------------------------------- /basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from basicsr.utils.matlab_functions import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img, input_order='HWC'): 7 | """Reorder images to 'HWC' order. 8 | 9 | If the input_order is (h, w), return (h, w, 1); 10 | If the input_order is (c, h, w), return (h, w, c); 11 | If the input_order is (h, w, c), return as it is. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 16 | If the input image shape is (h, w), input_order will not have 17 | effects. Default: 'HWC'. 18 | 19 | Returns: 20 | ndarray: reordered image. 21 | """ 22 | 23 | if input_order not in ['HWC', 'CHW']: 24 | raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'") 25 | if len(img.shape) == 2: 26 | img = img[..., None] 27 | if input_order == 'CHW': 28 | img = img.transpose(1, 2, 0) 29 | return img 30 | 31 | 32 | def to_y_channel(img): 33 | """Change to Y channel of YCbCr. 34 | 35 | Args: 36 | img (ndarray): Images with range [0, 255]. 37 | 38 | Returns: 39 | (ndarray): Images with range [0, 255] (float type) without round. 40 | """ 41 | img = img.astype(np.float32) / 255. 42 | if img.ndim == 3 and img.shape[2] == 3: 43 | img = bgr2ycbcr(img, y_only=True) 44 | img = img[..., None] 45 | return img * 255. 46 | -------------------------------------------------------------------------------- /basicsr/metrics/psnr_ssim.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from basicsr.metrics.metric_util import reorder_image, to_y_channel 5 | from basicsr.utils.registry import METRIC_REGISTRY 6 | 7 | 8 | @METRIC_REGISTRY.register() 9 | def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): 10 | """Calculate PSNR (Peak Signal-to-Noise Ratio). 11 | 12 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 13 | 14 | Args: 15 | img (ndarray): Images with range [0, 255]. 16 | img2 (ndarray): Images with range [0, 255]. 17 | crop_border (int): Cropped pixels in each edge of an image. These 18 | pixels are not involved in the PSNR calculation. 19 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 20 | Default: 'HWC'. 21 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 22 | 23 | Returns: 24 | float: psnr result. 25 | """ 26 | 27 | assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') 28 | if input_order not in ['HWC', 'CHW']: 29 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') 30 | img = reorder_image(img, input_order=input_order) 31 | img2 = reorder_image(img2, input_order=input_order) 32 | img = img.astype(np.float64) 33 | img2 = img2.astype(np.float64) 34 | 35 | if crop_border != 0: 36 | img = img[crop_border:-crop_border, crop_border:-crop_border, ...] 37 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 38 | 39 | if test_y_channel: 40 | img = to_y_channel(img) 41 | img2 = to_y_channel(img2) 42 | 43 | mse = np.mean((img - img2)**2) 44 | if mse == 0: 45 | return float('inf') 46 | return 20. 
* np.log10(255. / np.sqrt(mse)) 47 | 48 | 49 | def _ssim(img, img2): 50 | """Calculate SSIM (structural similarity) for one channel images. 51 | 52 | It is called by func:`calculate_ssim`. 53 | 54 | Args: 55 | img (ndarray): Images with range [0, 255] with order 'HWC'. 56 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 57 | 58 | Returns: 59 | float: ssim result. 60 | """ 61 | 62 | c1 = (0.01 * 255)**2 63 | c2 = (0.03 * 255)**2 64 | 65 | img = img.astype(np.float64) 66 | img2 = img2.astype(np.float64) 67 | kernel = cv2.getGaussianKernel(11, 1.5) 68 | window = np.outer(kernel, kernel.transpose()) 69 | 70 | mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] 71 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 72 | mu1_sq = mu1**2 73 | mu2_sq = mu2**2 74 | mu1_mu2 = mu1 * mu2 75 | sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq 76 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 77 | sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 78 | 79 | ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)) 80 | return ssim_map.mean() 81 | 82 | 83 | @METRIC_REGISTRY.register() 84 | def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): 85 | """Calculate SSIM (structural similarity). 86 | 87 | Ref: 88 | Image quality assessment: From error visibility to structural similarity 89 | 90 | The results are the same as that of the official released MATLAB code in 91 | https://ece.uwaterloo.ca/~z70wang/research/ssim/. 92 | 93 | For three-channel images, SSIM is calculated for each channel and then 94 | averaged. 95 | 96 | Args: 97 | img (ndarray): Images with range [0, 255]. 98 | img2 (ndarray): Images with range [0, 255]. 99 | crop_border (int): Cropped pixels in each edge of an image. These 100 | pixels are not involved in the SSIM calculation. 101 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 102 | Default: 'HWC'. 103 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 104 | 105 | Returns: 106 | float: ssim result. 107 | """ 108 | 109 | assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') 110 | if input_order not in ['HWC', 'CHW']: 111 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') 112 | img = reorder_image(img, input_order=input_order) 113 | img2 = reorder_image(img2, input_order=input_order) 114 | img = img.astype(np.float64) 115 | img2 = img2.astype(np.float64) 116 | 117 | if crop_border != 0: 118 | img = img[crop_border:-crop_border, crop_border:-crop_border, ...] 119 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 
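# As in calculate_psnr, border cropping is applied before the optional Y-channel conversion.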
120 | 121 | if test_y_channel: 122 | img = to_y_channel(img) 123 | img2 = to_y_channel(img2) 124 | 125 | ssims = [] 126 | for i in range(img.shape[2]): 127 | ssims.append(_ssim(img[..., i], img2[..., i])) 128 | return np.array(ssims).mean() 129 | -------------------------------------------------------------------------------- /basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import MODEL_REGISTRY 7 | 8 | __all__ = ['build_model'] 9 | 10 | # automatically scan and import model modules for registry 11 | # scan all the files under the 'models' folder and collect files ending with 12 | # '_model.py' 13 | model_folder = osp.dirname(osp.abspath(__file__)) 14 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 15 | # import all the model modules 16 | _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] 17 | 18 | 19 | def build_model(opt): 20 | """Build model from options. 21 | 22 | Args: 23 | opt (dict): Configuration. It must contain: 24 | model_type (str): Model type. 25 | """ 26 | opt = deepcopy(opt) 27 | model = MODEL_REGISTRY.get(opt['model_type'])(opt) 28 | logger = get_root_logger() 29 | logger.info(f'Model [{model.__class__.__name__}] is created.') 30 | return model 31 | -------------------------------------------------------------------------------- /basicsr/models/art_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | from basicsr.utils.registry import MODEL_REGISTRY 5 | from basicsr.models.sr_model import SRModel 6 | 7 | 8 | @MODEL_REGISTRY.register() 9 | class ARTModel(SRModel): 10 | """ART model for image restoration.""" 11 | 12 | # test by partitioning 13 | def test(self): 14 | _, C, h, w = self.lq.size() 15 | split_token_h = h // 200 + 1 # number of horizontal cut sections 16 | split_token_w = w // 200 + 1 # number of vertical cut sections 17 | # padding 18 | mod_pad_h, mod_pad_w = 0, 0 19 | if h % split_token_h != 0: 20 | mod_pad_h = split_token_h - h % split_token_h 21 | if w % split_token_w != 0: 22 | mod_pad_w = split_token_w - w % split_token_w 23 | img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect') 24 | _, _, H, W = img.size() 25 | split_h = H // split_token_h # height of each partition 26 | split_w = W // split_token_w # width of each partition 27 | # overlapping 28 | shave_h = split_h // 10 29 | shave_w = split_w // 10 30 | scale = self.opt.get('scale', 1) 31 | ral = H // split_h 32 | row = W // split_w 33 | slices = [] # list of partition borders 34 | for i in range(ral): 35 | for j in range(row): 36 | if i == 0 and i == ral - 1: 37 | top = slice(i * split_h, (i + 1) * split_h) 38 | elif i == 0: 39 | top = slice(i*split_h, (i+1)*split_h+shave_h) 40 | elif i == ral - 1: 41 | top = slice(i*split_h-shave_h, (i+1)*split_h) 42 | else: 43 | top = slice(i*split_h-shave_h, (i+1)*split_h+shave_h) 44 | if j == 0 and j == row - 1: 45 | left = slice(j*split_w, (j+1)*split_w) 46 | elif j == 0: 47 | left = slice(j*split_w, (j+1)*split_w+shave_w) 48 | elif j == row - 1: 49 | left = slice(j*split_w-shave_w, (j+1)*split_w) 50 | else: 51 | left = slice(j*split_w-shave_w, (j+1)*split_w+shave_w) 52 | temp = (top, left) 53 | 
slices.append(temp) 54 | img_chops = [] # list of partitions 55 | for temp in slices: 56 | top, left = temp 57 | img_chops.append(img[..., top, left]) 58 | if hasattr(self, 'net_g_ema'): 59 | self.net_g_ema.eval() 60 | with torch.no_grad(): 61 | outputs = [] 62 | for chop in img_chops: 63 | out = self.net_g_ema(chop) # image processing of each partition 64 | outputs.append(out) 65 | _img = torch.zeros(1, C, H * scale, W * scale) 66 | # merge 67 | for i in range(ral): 68 | for j in range(row): 69 | top = slice(i * split_h * scale, (i + 1) * split_h * scale) 70 | left = slice(j * split_w * scale, (j + 1) * split_w * scale) 71 | if i == 0: 72 | _top = slice(0, split_h * scale) 73 | else: 74 | _top = slice(shave_h*scale, (shave_h+split_h)*scale) 75 | if j == 0: 76 | _left = slice(0, split_w*scale) 77 | else: 78 | _left = slice(shave_w*scale, (shave_w+split_w)*scale) 79 | _img[..., top, left] = outputs[i * row + j][..., _top, _left] 80 | self.output = _img 81 | else: 82 | self.net_g.eval() 83 | with torch.no_grad(): 84 | outputs = [] 85 | for chop in img_chops: 86 | out = self.net_g(chop) # image processing of each partition 87 | outputs.append(out) 88 | _img = torch.zeros(1, C, H * scale, W * scale) 89 | # merge 90 | for i in range(ral): 91 | for j in range(row): 92 | top = slice(i * split_h * scale, (i + 1) * split_h * scale) 93 | left = slice(j * split_w * scale, (j + 1) * split_w * scale) 94 | if i == 0: 95 | _top = slice(0, split_h * scale) 96 | else: 97 | _top = slice(shave_h * scale, (shave_h + split_h) * scale) 98 | if j == 0: 99 | _left = slice(0, split_w * scale) 100 | else: 101 | _left = slice(shave_w * scale, (shave_w + split_w) * scale) 102 | _img[..., top, left] = outputs[i * row + j][..., _top, _left] 103 | self.output = _img 104 | self.net_g.train() 105 | _, _, h, w = self.output.size() 106 | self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale] 107 | -------------------------------------------------------------------------------- /basicsr/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class MultiStepRestartLR(_LRScheduler): 7 | """ MultiStep with restarts learning rate scheme. 8 | 9 | Args: 10 | optimizer (torch.nn.optimizer): Torch optimizer. 11 | milestones (list): Iterations that will decrease learning rate. 12 | gamma (float): Decrease ratio. Default: 0.1. 13 | restarts (list): Restart iterations. Default: [0]. 14 | restart_weights (list): Restart weights at each restart iteration. 15 | Default: [1]. 16 | last_epoch (int): Used in _LRScheduler. Default: -1. 17 | """ 18 | 19 | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): 20 | self.milestones = Counter(milestones) 21 | self.gamma = gamma 22 | self.restarts = restarts 23 | self.restart_weights = restart_weights 24 | assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' 
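# _LRScheduler.__init__ triggers an initial step(), so get_lr() is evaluated once at construction (last_epoch == 0).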
25 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 26 | 27 | def get_lr(self): 28 | if self.last_epoch in self.restarts: 29 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 30 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 31 | if self.last_epoch not in self.milestones: 32 | return [group['lr'] for group in self.optimizer.param_groups] 33 | return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] 34 | 35 | 36 | def get_position_from_periods(iteration, cumulative_period): 37 | """Get the position from a period list. 38 | 39 | It will return the index of the right-closest number in the period list. 40 | For example, the cumulative_period = [100, 200, 300, 400], 41 | if iteration == 50, return 0; 42 | if iteration == 210, return 2; 43 | if iteration == 300, return 2. 44 | 45 | Args: 46 | iteration (int): Current iteration. 47 | cumulative_period (list[int]): Cumulative period list. 48 | 49 | Returns: 50 | int: The position of the right-closest number in the period list. 51 | """ 52 | for i, period in enumerate(cumulative_period): 53 | if iteration <= period: 54 | return i 55 | 56 | 57 | class CosineAnnealingRestartLR(_LRScheduler): 58 | """ Cosine annealing with restarts learning rate scheme. 59 | 60 | An example of config: 61 | periods = [10, 10, 10, 10] 62 | restart_weights = [1, 0.5, 0.5, 0.5] 63 | eta_min=1e-7 64 | 65 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 66 | scheduler will restart with the weights in restart_weights. 67 | 68 | Args: 69 | optimizer (torch.nn.optimizer): Torch optimizer. 70 | periods (list): Period for each cosine annealing cycle. 71 | restart_weights (list): Restart weights at each restart iteration. 72 | Default: [1]. 73 | eta_min (float): The minimum lr. Default: 0. 74 | last_epoch (int): Used in _LRScheduler. Default: -1. 75 | """ 76 | 77 | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): 78 | self.periods = periods 79 | self.restart_weights = restart_weights 80 | self.eta_min = eta_min 81 | assert (len(self.periods) == len( 82 | self.restart_weights)), 'periods and restart_weights should have the same length.' 
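# cumulative_period[i] marks the iteration at which the i-th cosine cycle ends, e.g. periods = [10, 10, 10, 10] gives [10, 20, 30, 40].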
83 | self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] 84 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 85 | 86 | def get_lr(self): 87 | idx = get_position_from_periods(self.last_epoch, self.cumulative_period) 88 | current_weight = self.restart_weights[idx] 89 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 90 | current_period = self.periods[idx] 91 | 92 | return [ 93 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 94 | (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) 95 | for base_lr in self.base_lrs 96 | ] 97 | -------------------------------------------------------------------------------- /basicsr/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from os import path as osp 4 | 5 | from basicsr.data import build_dataloader, build_dataset 6 | from basicsr.models import build_model 7 | from basicsr.utils import get_root_logger, get_time_str, make_exp_dirs 8 | from basicsr.utils.options import dict2str, parse_options 9 | 10 | 11 | def test_pipeline(root_path): 12 | # parse options, set distributed setting, set random seed 13 | opt, _ = parse_options(root_path, is_train=False) 14 | 15 | torch.backends.cudnn.benchmark = True 16 | # torch.backends.cudnn.deterministic = True 17 | 18 | # mkdir and initialize loggers 19 | make_exp_dirs(opt) 20 | log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log") 21 | logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 22 | logger.info(dict2str(opt)) 23 | 24 | # create test dataset and dataloader 25 | test_loaders = [] 26 | for _, dataset_opt in sorted(opt['datasets'].items()): 27 | test_set = build_dataset(dataset_opt) 28 | test_loader = build_dataloader( 29 | test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) 30 | logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}") 31 | test_loaders.append(test_loader) 32 | 33 | # create model 34 | model = build_model(opt) 35 | 36 | for test_loader in test_loaders: 37 | test_set_name = test_loader.dataset.opt['name'] 38 | logger.info(f'Testing {test_set_name}...') 39 | model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img']) 40 | 41 | 42 | if __name__ == '__main__': 43 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 44 | test_pipeline(root_path) 45 | -------------------------------------------------------------------------------- /basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .file_client import FileClient 2 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img 3 | from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger 4 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt 5 | 6 | __all__ = [ 7 | # file_client.py 8 | 'FileClient', 9 | # img_util.py 10 | 'img2tensor', 11 | 'tensor2img', 12 | 'imfrombytes', 13 | 'imwrite', 14 | 'crop_border', 15 | # logger.py 16 | 'MessageLogger', 17 | 'AvgTimer', 18 | 'init_tb_logger', 19 | 'init_wandb_logger', 20 | 'get_root_logger', 21 | 'get_env_info', 22 | # misc.py 23 | 'set_random_seed', 24 | 'get_time_str', 25 | 'mkdir_and_rename', 26 | 
'make_exp_dirs', 27 | 'scandir', 28 | 'check_resume', 29 | 'sizeof_fmt', 30 | ] 31 | -------------------------------------------------------------------------------- /basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be the 32 | system environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in 33 | the system environment, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') 45 | # specify master port 46 | if port is not None: 47 | os.environ['MASTER_PORT'] = str(port) 48 | elif 'MASTER_PORT' in os.environ: 49 | pass # use MASTER_PORT in the environment variable 50 | else: 51 | # 29500 is torch.distributed default port 52 | os.environ['MASTER_PORT'] = '29500' 53 | os.environ['MASTER_ADDR'] = addr 54 | os.environ['WORLD_SIZE'] = str(ntasks) 55 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 56 | os.environ['RANK'] = str(proc_id) 57 | dist.init_process_group(backend=backend) 58 | 59 | 60 | def get_dist_info(): 61 | if dist.is_available(): 62 | initialized = dist.is_initialized() 63 | else: 64 | initialized = False 65 | if initialized: 66 | rank = dist.get_rank() 67 | world_size = dist.get_world_size() 68 | else: 69 | rank = 0 70 | world_size = 1 71 | return rank, world_size 72 | 73 | 74 | def master_only(func): 75 | 76 | @functools.wraps(func) 77 | def wrapper(*args, **kwargs): 78 | rank, _ = get_dist_info() 79 | if rank == 0: 80 | return func(*args, **kwargs) 81 | 82 | return wrapper 83 | -------------------------------------------------------------------------------- /basicsr/utils/file_client.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseStorageBackend(metaclass=ABCMeta): 6 | """Abstract class of storage backends. 7 | 8 | All backends need to implement two APIs: ``get()`` and ``get_text()``. 
9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 10 | as text. 11 | """ 12 | 13 | @abstractmethod 14 | def get(self, filepath): 15 | pass 16 | 17 | @abstractmethod 18 | def get_text(self, filepath): 19 | pass 20 | 21 | 22 | class MemcachedBackend(BaseStorageBackend): 23 | """Memcached storage backend. 24 | 25 | Attributes: 26 | server_list_cfg (str): Config file for memcached server list. 27 | client_cfg (str): Config file for memcached client. 28 | sys_path (str | None): Additional path to be appended to `sys.path`. 29 | Default: None. 30 | """ 31 | 32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None): 33 | if sys_path is not None: 34 | import sys 35 | sys.path.append(sys_path) 36 | try: 37 | import mc 38 | except ImportError: 39 | raise ImportError('Please install memcached to enable MemcachedBackend.') 40 | 41 | self.server_list_cfg = server_list_cfg 42 | self.client_cfg = client_cfg 43 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) 44 | # mc.pyvector serves as a pointer to a memory cache 45 | self._mc_buffer = mc.pyvector() 46 | 47 | def get(self, filepath): 48 | filepath = str(filepath) 49 | import mc 50 | self._client.Get(filepath, self._mc_buffer) 51 | value_buf = mc.ConvertBuffer(self._mc_buffer) 52 | return value_buf 53 | 54 | def get_text(self, filepath): 55 | raise NotImplementedError 56 | 57 | 58 | class HardDiskBackend(BaseStorageBackend): 59 | """Raw hard disk storage backend.""" 60 | 61 | def get(self, filepath): 62 | filepath = str(filepath) 63 | with open(filepath, 'rb') as f: 64 | value_buf = f.read() 65 | return value_buf 66 | 67 | def get_text(self, filepath): 68 | filepath = str(filepath) 69 | with open(filepath, 'r') as f: 70 | value_buf = f.read() 71 | return value_buf 72 | 73 | 74 | class LmdbBackend(BaseStorageBackend): 75 | """Lmdb storage backend. 76 | 77 | Args: 78 | db_paths (str | list[str]): Lmdb database paths. 79 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 80 | readonly (bool, optional): Lmdb environment parameter. If True, 81 | disallow any write operations. Default: True. 82 | lock (bool, optional): Lmdb environment parameter. If False, when 83 | concurrent access occurs, do not lock the database. Default: False. 84 | readahead (bool, optional): Lmdb environment parameter. If False, 85 | disable the OS filesystem readahead mechanism, which may improve 86 | random read performance when a database is larger than RAM. 87 | Default: False. 88 | 89 | Attributes: 90 | db_paths (list): Lmdb database paths. 91 | _client (dict): A dict of several lmdb envs. 
92 | """ 93 | 94 | def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): 95 | try: 96 | import lmdb 97 | except ImportError: 98 | raise ImportError('Please install lmdb to enable LmdbBackend.') 99 | 100 | if isinstance(client_keys, str): 101 | client_keys = [client_keys] 102 | 103 | if isinstance(db_paths, list): 104 | self.db_paths = [str(v) for v in db_paths] 105 | elif isinstance(db_paths, str): 106 | self.db_paths = [str(db_paths)] 107 | assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' 108 | f'but received {len(client_keys)} and {len(self.db_paths)}.') 109 | 110 | self._client = {} 111 | for client, path in zip(client_keys, self.db_paths): 112 | self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) 113 | 114 | def get(self, filepath, client_key): 115 | """Get values according to the filepath from one lmdb named client_key. 116 | 117 | Args: 118 | filepath (str | obj:`Path`): Here, filepath is the lmdb key. 119 | client_key (str): Used for distinguishing different lmdb envs. 120 | """ 121 | filepath = str(filepath) 122 | assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.') 123 | client = self._client[client_key] 124 | with client.begin(write=False) as txn: 125 | value_buf = txn.get(filepath.encode('ascii')) 126 | return value_buf 127 | 128 | def get_text(self, filepath): 129 | raise NotImplementedError 130 | 131 | 132 | class FileClient(object): 133 | """A general file client to access files in different backend. 134 | 135 | The client loads a file or text in a specified backend from its path 136 | and return it as a binary file. it can also register other backend 137 | accessor with a given name and backend class. 138 | 139 | Attributes: 140 | backend (str): The storage backend type. Options are "disk", 141 | "memcached" and "lmdb". 142 | client (:obj:`BaseStorageBackend`): The backend object. 143 | """ 144 | 145 | _backends = { 146 | 'disk': HardDiskBackend, 147 | 'memcached': MemcachedBackend, 148 | 'lmdb': LmdbBackend, 149 | } 150 | 151 | def __init__(self, backend='disk', **kwargs): 152 | if backend not in self._backends: 153 | raise ValueError(f'Backend {backend} is not supported. Currently supported ones' 154 | f' are {list(self._backends.keys())}') 155 | self.backend = backend 156 | self.client = self._backends[backend](**kwargs) 157 | 158 | def get(self, filepath, client_key='default'): 159 | # client_key is used only for lmdb, where different fileclients have 160 | # different lmdb environments. 
161 | if self.backend == 'lmdb': 162 | return self.client.get(filepath, client_key) 163 | else: 164 | return self.client.get(filepath) 165 | 166 | def get_text(self, filepath): 167 | return self.client.get_text(filepath) 168 | -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import time 5 | import torch 6 | from os import path as osp 7 | 8 | from .dist_util import master_only 9 | 10 | 11 | def set_random_seed(seed): 12 | """Set random seeds.""" 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | 20 | def get_time_str(): 21 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 22 | 23 | 24 | def mkdir_and_rename(path): 25 | """mkdirs. If path exists, rename it with timestamp and create a new one. 26 | 27 | Args: 28 | path (str): Folder path. 29 | """ 30 | if osp.exists(path): 31 | new_name = path + '_archived_' + get_time_str() 32 | print(f'Path already exists. Rename it to {new_name}', flush=True) 33 | os.rename(path, new_name) 34 | os.makedirs(path, exist_ok=True) 35 | 36 | 37 | @master_only 38 | def make_exp_dirs(opt): 39 | """Make dirs for experiments.""" 40 | path_opt = opt['path'].copy() 41 | if opt['is_train']: 42 | mkdir_and_rename(path_opt.pop('experiments_root')) 43 | else: 44 | mkdir_and_rename(path_opt.pop('results_root')) 45 | for key, path in path_opt.items(): 46 | if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key): 47 | continue 48 | else: 49 | os.makedirs(path, exist_ok=True) 50 | 51 | 52 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 53 | """Scan a directory to find the interested files. 54 | 55 | Args: 56 | dir_path (str): Path of the directory. 57 | suffix (str | tuple(str), optional): File suffix that we are 58 | interested in. Default: None. 59 | recursive (bool, optional): If set to True, recursively scan the 60 | directory. Default: False. 61 | full_path (bool, optional): If set to True, include the dir_path. 62 | Default: False. 63 | 64 | Returns: 65 | A generator for all the interested files with relative paths. 66 | """ 67 | 68 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 69 | raise TypeError('"suffix" must be a string or tuple of strings') 70 | 71 | root = dir_path 72 | 73 | def _scandir(dir_path, suffix, recursive): 74 | for entry in os.scandir(dir_path): 75 | if not entry.name.startswith('.') and entry.is_file(): 76 | if full_path: 77 | return_path = entry.path 78 | else: 79 | return_path = osp.relpath(entry.path, root) 80 | 81 | if suffix is None: 82 | yield return_path 83 | elif return_path.endswith(suffix): 84 | yield return_path 85 | else: 86 | if recursive: 87 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 88 | else: 89 | continue 90 | 91 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 92 | 93 | 94 | def check_resume(opt, resume_iter): 95 | """Check resume states and pretrain_network paths. 96 | 97 | Args: 98 | opt (dict): Options. 99 | resume_iter (int): Resume iteration. 
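Note that ``opt`` is modified in place: pretrain paths are rewritten to the checkpoints matching ``resume_iter``.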
100 | """ 101 | if opt['path']['resume_state']: 102 | # get all the networks 103 | networks = [key for key in opt.keys() if key.startswith('network_')] 104 | flag_pretrain = False 105 | for network in networks: 106 | if opt['path'].get(f'pretrain_{network}') is not None: 107 | flag_pretrain = True 108 | if flag_pretrain: 109 | print('pretrain_network path will be ignored during resuming.') 110 | # set pretrained model paths 111 | for network in networks: 112 | name = f'pretrain_{network}' 113 | basename = network.replace('network_', '') 114 | if opt['path'].get('ignore_resume_networks') is None or (network 115 | not in opt['path']['ignore_resume_networks']): 116 | opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 117 | print(f"Set {name} to {opt['path'][name]}") 118 | 119 | # change param_key to params in resume 120 | param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] 121 | for param_key in param_keys: 122 | if opt['path'][param_key] == 'params_ema': 123 | opt['path'][param_key] = 'params' 124 | print(f'Set {param_key} to params') 125 | 126 | 127 | def sizeof_fmt(size, suffix='B'): 128 | """Get human readable file size. 129 | 130 | Args: 131 | size (int): File size. 132 | suffix (str): Suffix. Default: 'B'. 133 | 134 | Return: 135 | str: Formatted file siz. 136 | """ 137 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 138 | if abs(size) < 1024.0: 139 | return f'{size:3.1f} {unit}{suffix}' 140 | size /= 1024.0 141 | return f'{size:3.1f} Y{suffix}' 142 | -------------------------------------------------------------------------------- /basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry(): 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj): 39 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 40 | f"in '{self._name}' registry!") 41 | self._obj_map[name] = obj 42 | 43 | def register(self, obj=None): 44 | """ 45 | Register the given object under the the name `obj.__name__`. 46 | Can be used as either a decorator or not. 47 | See docstring of this class for usage. 
48 | """ 49 | if obj is None: 50 | # used as a decorator 51 | def deco(func_or_class): 52 | name = func_or_class.__name__ 53 | self._do_register(name, func_or_class) 54 | return func_or_class 55 | 56 | return deco 57 | 58 | # used as a function call 59 | name = obj.__name__ 60 | self._do_register(name, obj) 61 | 62 | def get(self, name): 63 | ret = self._obj_map.get(name) 64 | if ret is None: 65 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 66 | return ret 67 | 68 | def __contains__(self, name): 69 | return name in self._obj_map 70 | 71 | def __iter__(self): 72 | return iter(self._obj_map.items()) 73 | 74 | def keys(self): 75 | return self._obj_map.keys() 76 | 77 | 78 | DATASET_REGISTRY = Registry('dataset') 79 | ARCH_REGISTRY = Registry('arch') 80 | MODEL_REGISTRY = Registry('model') 81 | LOSS_REGISTRY = Registry('loss') 82 | METRIC_REGISTRY = Registry('metric') 83 | -------------------------------------------------------------------------------- /basicsr/version.py: -------------------------------------------------------------------------------- 1 | # GENERATED VERSION FILE 2 | # TIME: Mon Nov 21 05:26:10 2022 3 | __version__ = '1.3.5' 4 | __gitsha__ = 'unknown' 5 | version_info = (1, 3, 5) 6 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | For training and testing, we provide the directory structure. You can download the complete datasets and put thme here. 2 | 3 | ```shell 4 | |-- datasets 5 | # image SR - train 6 | |-- DF2K 7 | |-- HR 8 | |-- LR_bicubic 9 | |-- X2 10 | |-- X3 11 | |-- X4 12 | # color image denoising - train 13 | |-- DFWB_RGB 14 | |-- HQ 15 | # real image denoising - train & val 16 | |-- SIDD 17 | |-- train 18 | |-- target_crops 19 | |-- input_crops 20 | |-- val 21 | |-- target_crops 22 | |-- input_crops 23 | # grayscale JPEG compression artifact reduction - train 24 | |-- DFWB_CAR 25 | |-- HQ 26 | |-- LQ 27 | |-- 10 28 | |-- 30 29 | |-- 40 30 | # image SR - test 31 | |-- SR 32 | |-- Set5 33 | |-- HR 34 | |-- LR_bicubic 35 | |-- X2 36 | |-- X3 37 | |-- X4 38 | |-- Set14 39 | |-- HR 40 | |-- LR_bicubic 41 | |-- X2 42 | |-- X3 43 | |-- X4 44 | |-- B100 45 | |-- HR 46 | |-- LR_bicubic 47 | |-- X2 48 | |-- X3 49 | |-- X4 50 | |-- Urban100 51 | |-- HR 52 | |-- LR_bicubic 53 | |-- X2 54 | |-- X3 55 | |-- X4 56 | |-- Manga109 57 | |-- HR 58 | |-- LR_bicubic 59 | |-- X2 60 | |-- X3 61 | |-- X4 62 | # gaussian color image denoising - test 63 | |-- ColorDN 64 | |-- CBSD68HQ 65 | |-- Kodak24HQ 66 | |-- McMasterHQ 67 | |-- Urban100HQ 68 | # real image denoising - test 69 | |-- RealDN 70 | |-- SIDD 71 | |-- ValidationGtBlocksSrgb.mat 72 | |-- ValidationNoisyBlocksSrgb.mat 73 | |-- DND 74 | |-- info.mat 75 | |-- ValidationNoisyBlocksSrgb 76 | |-- 0001.mat 77 | |-- 0002.mat 78 | : 79 | |-- 0050.mat 80 | # grayscale JPEG compression artifact reduction - test 81 | |-- CAR 82 | |-- classic5 83 | |-- Classic5_HQ 84 | |-- Classic5_LQ 85 | |-- 10 86 | |-- 30 87 | |-- 40 88 | |-- LIVE1 89 | |-- LIVE1_HQ 90 | |-- LIVE1_LQ 91 | |-- 10 92 | |-- 30 93 | |-- 40 94 | ``` 95 | 96 | -------------------------------------------------------------------------------- /datasets/example/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/datasets/example/example.png 
-------------------------------------------------------------------------------- /experiments/pretrained_models/README.md: -------------------------------------------------------------------------------- 1 | You can download our pre-trained models from the links provided on the home page. The expected directory structure is shown here. 2 | 3 | ```shell 4 | |-- experiments/pretrained_models/ 5 | -- SR_ART_x2.pth 6 | -- SR_ART_x3.pth 7 | -- SR_ART_x4.pth 8 | -- SR_ART_S_x2.pth 9 | -- SR_ART_S_x3.pth 10 | -- SR_ART_S_x4.pth 11 | -- ColorDN_ART_level15.pth 12 | -- ColorDN_ART_level25.pth 13 | -- ColorDN_ART_level50.pth 14 | -- RealDN_ART.pth 15 | -- CAR_ART_q10.pth 16 | -- CAR_ART_q30.pth 17 | -- CAR_ART_q40.pth 18 | 19 | ``` 20 | 21 | -------------------------------------------------------------------------------- /figs/CAR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/CAR.png -------------------------------------------------------------------------------- /figs/ColorDN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/ColorDN.png -------------------------------------------------------------------------------- /figs/ComS_img_092_ART_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/ComS_img_092_ART_x4.png -------------------------------------------------------------------------------- /figs/ComS_img_092_Bicubic_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/ComS_img_092_Bicubic_x4.png -------------------------------------------------------------------------------- /figs/ComS_img_092_HR_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/ComS_img_092_HR_x4.png -------------------------------------------------------------------------------- /figs/ComS_img_092_SwinIR_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/ComS_img_092_SwinIR_x4.png -------------------------------------------------------------------------------- /figs/ComS_img_098_ART_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/ComS_img_098_ART_x4.png -------------------------------------------------------------------------------- /figs/ComS_img_098_Bicubic_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/ComS_img_098_Bicubic_x4.png -------------------------------------------------------------------------------- /figs/ComS_img_098_HR_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/ComS_img_098_HR_x4.png
-------------------------------------------------------------------------------- /figs/ComS_img_098_SwinIR_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/ComS_img_098_SwinIR_x4.png -------------------------------------------------------------------------------- /figs/RealDN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/RealDN.png -------------------------------------------------------------------------------- /figs/Resize_ComL_img_092_HR_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/Resize_ComL_img_092_HR_x4.png -------------------------------------------------------------------------------- /figs/Resize_ComL_img_098_HR_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/Resize_ComL_img_098_HR_x4.png -------------------------------------------------------------------------------- /figs/SR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/SR.png -------------------------------------------------------------------------------- /figs/Visual_DN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/Visual_DN.png -------------------------------------------------------------------------------- /figs/Visual_SR_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/Visual_SR_1.png -------------------------------------------------------------------------------- /figs/Visual_SR_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/Visual_SR_2.png -------------------------------------------------------------------------------- /figs/git.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/figs/git.png -------------------------------------------------------------------------------- /options/apply/test_ART_SR_x2_without_groundTruth.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: x2 3 | model_type: ARTModel 4 | scale: 2 5 | num_gpu: 1 6 | manual_seed: 10 7 | 8 | datasets: 9 | test_1: # the 1st test dataset 10 | name: example # replace with the name of your own data folder 11 | type: SingleImageDataset 12 | dataroot_lq: datasets/example # replace with the path to your own dataset 13 | io_backend: 14 | type: disk 15 | 16 | 17 | 18 | # network structures 19 | network_g: 20 | type: ART 21 | upscale: 2 22 | in_chans: 3 23 | img_size: 64 24 | window_size: 8 25 | img_range: 1.
26 | depths: [6, 6, 6, 6, 6, 6] 27 | interval: [4, 4, 4, 4, 4, 4] 28 | embed_dim: 180 29 | num_heads: [6, 6, 6, 6, 6, 6] 30 | mlp_ratio: 4 31 | upsampler: 'pixelshuffle' 32 | resi_connection: '1conv' 33 | 34 | # path 35 | path: 36 | pretrain_network_g: experiments/pretrained_models/SR_ART_x2.pth 37 | strict_load_g: true 38 | 39 | # validation settings 40 | val: 41 | save_img: true 42 | suffix: ~ # add suffix to saved images, if None, use exp name -------------------------------------------------------------------------------- /options/apply/test_ART_SR_x3_without_groundTruth.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: x3 3 | model_type: ARTModel 4 | scale: 3 5 | num_gpu: 1 6 | manual_seed: 10 7 | 8 | datasets: 9 | test_1: # the 1st test dataset 10 | name: example # replace with the name of your own data folder 11 | type: SingleImageDataset 12 | dataroot_lq: datasets/example # replace with the path to your own dataset 13 | io_backend: 14 | type: disk 15 | 16 | 17 | 18 | # network structures 19 | network_g: 20 | type: ART 21 | upscale: 3 22 | in_chans: 3 23 | img_size: 64 24 | window_size: 8 25 | img_range: 1. 26 | depths: [6, 6, 6, 6, 6, 6] 27 | interval: [4, 4, 4, 4, 4, 4] 28 | embed_dim: 180 29 | num_heads: [6, 6, 6, 6, 6, 6] 30 | mlp_ratio: 4 31 | upsampler: 'pixelshuffle' 32 | resi_connection: '1conv' 33 | 34 | # path 35 | path: 36 | pretrain_network_g: experiments/pretrained_models/SR_ART_x3.pth 37 | strict_load_g: true 38 | 39 | # validation settings 40 | val: 41 | save_img: true 42 | suffix: ~ # add suffix to saved images, if None, use exp name -------------------------------------------------------------------------------- /options/apply/test_ART_SR_x4_without_groundTruth.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: x4 3 | model_type: ARTModel 4 | scale: 4 5 | num_gpu: 1 6 | manual_seed: 10 7 | 8 | datasets: 9 | test_1: # the 1st test dataset 10 | name: example # replace with the name of your own data folder 11 | type: SingleImageDataset 12 | dataroot_lq: datasets/example # replace with the path to your own dataset 13 | io_backend: 14 | type: disk 15 | 16 | 17 | 18 | # network structures 19 | network_g: 20 | type: ART 21 | upscale: 4 22 | in_chans: 3 23 | img_size: 64 24 | window_size: 8 25 | img_range: 1.
26 | depths: [6, 6, 6, 6, 6, 6] 27 | interval: [4, 4, 4, 4, 4, 4] 28 | embed_dim: 180 29 | num_heads: [6, 6, 6, 6, 6, 6] 30 | mlp_ratio: 4 31 | upsampler: 'pixelshuffle' 32 | resi_connection: '1conv' 33 | 34 | # path 35 | path: 36 | pretrain_network_g: experiments/pretrained_models/SR_ART_x4.pth 37 | strict_load_g: true 38 | 39 | # validation settings 40 | val: 41 | save_img: true 42 | suffix: ~ # add suffix to saved images, if None, use exp name -------------------------------------------------------------------------------- /options/test/test_ART_CAR_q10.yml: -------------------------------------------------------------------------------- 1 | name: test_ART_CAR_q10 2 | model_type: ARTModel 3 | scale: 1 4 | num_gpu: 1 5 | manual_seed: 0 6 | 7 | datasets: 8 | test_1: 9 | task: CAR 10 | name: Classic5 11 | type: PairedImageDataset 12 | dataroot_gt: datasets/CAR/classic5/Classic5_HQ 13 | dataroot_lq: datasets/CAR/classic5/Classic5_LQ/10 14 | filename_tmpl: '{}' 15 | io_backend: 16 | type: disk 17 | 18 | test_2: 19 | task: CAR 20 | name: LIVE1 21 | type: PairedImageDataset 22 | dataroot_gt: datasets/CAR/LIVE1/LIVE1_HQ 23 | dataroot_lq: datasets/CAR/LIVE1/LIVE1_LQ/10 24 | filename_tmpl: '{}' 25 | io_backend: 26 | type: disk 27 | 28 | # network structures 29 | network_g: 30 | type: ART 31 | upscale: 1 32 | in_chans: 1 33 | img_size: 126 34 | window_size: 7 35 | img_range: 255. 36 | depths: [6, 6, 6, 6, 6, 6] 37 | interval: [18, 18, 13, 13, 7, 7] 38 | embed_dim: 180 39 | num_heads: [6, 6, 6, 6, 6, 6] 40 | mlp_ratio: 4 41 | 42 | 43 | # path 44 | path: 45 | pretrain_network_g: experiments/pretrained_models/CAR_ART_q10.pth 46 | strict_load_g: true 47 | 48 | # validation settings 49 | val: 50 | save_img: false 51 | suffix: ~ # add suffix to saved images, if None, use exp name 52 | selfensemble_testing: false 53 | patchwise_testing: false 54 | 55 | metrics: 56 | psnr: # metric name, can be arbitrary 57 | type: calculate_psnr 58 | crop_border: 0 59 | test_y_channel: true 60 | ssim: 61 | type: calculate_ssim 62 | crop_border: 0 63 | test_y_channel: true 64 | 65 | -------------------------------------------------------------------------------- /options/test/test_ART_CAR_q30.yml: -------------------------------------------------------------------------------- 1 | name: test_ART_CAR_q30 2 | model_type: ARTModel 3 | scale: 1 4 | num_gpu: 1 5 | manual_seed: 0 6 | 7 | datasets: 8 | test_1: 9 | task: CAR 10 | name: Classic5 11 | type: PairedImageDataset 12 | dataroot_gt: datasets/CAR/classic5/Classic5_HQ 13 | dataroot_lq: datasets/CAR/classic5/Classic5_LQ/30 14 | filename_tmpl: '{}' 15 | io_backend: 16 | type: disk 17 | 18 | test_2: 19 | task: CAR 20 | name: LIVE1 21 | type: PairedImageDataset 22 | dataroot_gt: datasets/CAR/LIVE1/LIVE1_HQ 23 | dataroot_lq: datasets/CAR/LIVE1/LIVE1_LQ/30 24 | filename_tmpl: '{}' 25 | io_backend: 26 | type: disk 27 | 28 | # network structures 29 | network_g: 30 | type: ART 31 | upscale: 1 32 | in_chans: 1 33 | img_size: 126 34 | window_size: 7 35 | img_range: 255. 
36 | depths: [6, 6, 6, 6, 6, 6] 37 | interval: [18, 18, 13, 13, 7, 7] 38 | embed_dim: 180 39 | num_heads: [6, 6, 6, 6, 6, 6] 40 | mlp_ratio: 4 41 | 42 | 43 | # path 44 | path: 45 | pretrain_network_g: experiments/pretrained_models/CAR_ART_q30.pth 46 | strict_load_g: true 47 | 48 | # validation settings 49 | val: 50 | save_img: false 51 | suffix: ~ # add suffix to saved images, if None, use exp name 52 | selfensemble_testing: false 53 | patchwise_testing: false 54 | 55 | metrics: 56 | psnr: # metric name, can be arbitrary 57 | type: calculate_psnr 58 | crop_border: 0 59 | test_y_channel: true 60 | ssim: 61 | type: calculate_ssim 62 | crop_border: 0 63 | test_y_channel: true 64 | 65 | -------------------------------------------------------------------------------- /options/test/test_ART_CAR_q40.yml: -------------------------------------------------------------------------------- 1 | name: test_ART_CAR_q40 2 | model_type: ARTModel 3 | scale: 1 4 | num_gpu: 1 5 | manual_seed: 0 6 | 7 | datasets: 8 | test_1: 9 | task: CAR 10 | name: Classic5 11 | type: PairedImageDataset 12 | dataroot_gt: datasets/CAR/classic5/Classic5_HQ 13 | dataroot_lq: datasets/CAR/classic5/Classic5_LQ/40 14 | filename_tmpl: '{}' 15 | io_backend: 16 | type: disk 17 | 18 | test_2: 19 | task: CAR 20 | name: LIVE1 21 | type: PairedImageDataset 22 | dataroot_gt: datasets/CAR/LIVE1/LIVE1_HQ 23 | dataroot_lq: datasets/CAR/LIVE1/LIVE1_LQ/40 24 | filename_tmpl: '{}' 25 | io_backend: 26 | type: disk 27 | 28 | # network structures 29 | network_g: 30 | type: ART 31 | upscale: 1 32 | in_chans: 1 33 | img_size: 126 34 | window_size: 7 35 | img_range: 255. 36 | depths: [6, 6, 6, 6, 6, 6] 37 | interval: [18, 18, 13, 13, 7, 7] 38 | embed_dim: 180 39 | num_heads: [6, 6, 6, 6, 6, 6] 40 | mlp_ratio: 4 41 | 42 | 43 | # path 44 | path: 45 | pretrain_network_g: experiments/pretrained_models/CAR_ART_q40.pth 46 | strict_load_g: true 47 | 48 | # validation settings 49 | val: 50 | save_img: false 51 | suffix: ~ # add suffix to saved images, if None, use exp name 52 | selfensemble_testing: false 53 | patchwise_testing: false 54 | 55 | metrics: 56 | psnr: # metric name, can be arbitrary 57 | type: calculate_psnr 58 | crop_border: 0 59 | test_y_channel: true 60 | ssim: 61 | type: calculate_ssim 62 | crop_border: 0 63 | test_y_channel: true 64 | 65 | -------------------------------------------------------------------------------- /options/test/test_ART_ColorDN_level15.yml: -------------------------------------------------------------------------------- 1 | name: test_ART_ColorDN_level15 2 | model_type: ARTModel 3 | scale: 1 4 | num_gpu: 1 5 | manual_seed: 0 6 | 7 | datasets: 8 | test_1: 9 | task: denoising_color 10 | name: CBSD68 11 | type: PairedImageDataset 12 | dataroot_gt: datasets/ColorDN/CBSD68HQ 13 | dataroot_lq: datasets/ColorDN/CBSD68HQ 14 | filename_tmpl: '{}' 15 | noise: 15 # 15/25/50 16 | io_backend: 17 | type: disk 18 | 19 | test_2: 20 | task: denoising_color 21 | name: Kodak24 22 | type: PairedImageDataset 23 | dataroot_gt: datasets/ColorDN/Kodak24HQ 24 | dataroot_lq: datasets/ColorDN/Kodak24HQ 25 | filename_tmpl: '{}' 26 | noise: 15 # 15/25/50 27 | io_backend: 28 | type: disk 29 | 30 | test_3: 31 | task: denoising_color 32 | name: McMaster 33 | type: PairedImageDataset 34 | dataroot_gt: datasets/ColorDN/McMasterHQ 35 | dataroot_lq: datasets/ColorDN/McMasterHQ 36 | filename_tmpl: '{}' 37 | noise: 15 # 15/25/50 38 | io_backend: 39 | type: disk 40 | 41 | test_4: 42 | task: denoising_color 43 | name: Urban100 44 | type: PairedImageDataset 45 
| dataroot_gt: datasets/ColorDN/Urban100HQ 46 | dataroot_lq: datasets/ColorDN/Urban100HQ 47 | filename_tmpl: '{}' 48 | noise: 15 # 15/25/50 49 | io_backend: 50 | type: disk 51 | 52 | # network structures 53 | network_g: 54 | type: ART 55 | upscale: 1 56 | in_chans: 3 57 | img_size: 128 58 | window_size: 8 59 | img_range: 1. 60 | depths: [6, 6, 6, 6, 6, 6] 61 | interval: [16, 16, 12, 12, 8, 8] 62 | embed_dim: 180 63 | num_heads: [6, 6, 6, 6, 6, 6] 64 | mlp_ratio: 4 65 | 66 | 67 | # path 68 | path: 69 | pretrain_network_g: experiments/pretrained_models/ColorDN_ART_level15.pth 70 | 71 | # validation settings 72 | val: 73 | save_img: false 74 | suffix: ~ # add suffix to saved images, if None, use exp name 75 | selfensemble_testing: false 76 | patchwise_testing: false 77 | 78 | metrics: 79 | psnr: # metric name, can be arbitrary 80 | type: calculate_psnr 81 | crop_border: 0 82 | test_y_channel: false 83 | ssim: 84 | type: calculate_ssim 85 | crop_border: 0 86 | test_y_channel: false 87 | -------------------------------------------------------------------------------- /options/test/test_ART_ColorDN_level25.yml: -------------------------------------------------------------------------------- 1 | name: test_ART_ColorDN_level25 2 | model_type: ARTModel 3 | scale: 1 4 | num_gpu: 1 5 | manual_seed: 0 6 | 7 | datasets: 8 | test_1: 9 | task: denoising_color 10 | name: CBSD68 11 | type: PairedImageDataset 12 | dataroot_gt: datasets/ColorDN/CBSD68HQ 13 | dataroot_lq: datasets/ColorDN/CBSD68HQ 14 | filename_tmpl: '{}' 15 | noise: 25 # 15/25/50 16 | io_backend: 17 | type: disk 18 | 19 | test_2: 20 | task: denoising_color 21 | name: Kodak24 22 | type: PairedImageDataset 23 | dataroot_gt: datasets/ColorDN/Kodak24HQ 24 | dataroot_lq: datasets/ColorDN/Kodak24HQ 25 | filename_tmpl: '{}' 26 | noise: 25 # 15/25/50 27 | io_backend: 28 | type: disk 29 | 30 | test_3: 31 | task: denoising_color 32 | name: McMaster 33 | type: PairedImageDataset 34 | dataroot_gt: datasets/ColorDN/McMasterHQ 35 | dataroot_lq: datasets/ColorDN/McMasterHQ 36 | filename_tmpl: '{}' 37 | noise: 25 # 15/25/50 38 | io_backend: 39 | type: disk 40 | 41 | test_4: 42 | task: denoising_color 43 | name: Urban100 44 | type: PairedImageDataset 45 | dataroot_gt: datasets/ColorDN/Urban100HQ 46 | dataroot_lq: datasets/ColorDN/Urban100HQ 47 | filename_tmpl: '{}' 48 | noise: 25 # 15/25/50 49 | io_backend: 50 | type: disk 51 | 52 | # network structures 53 | network_g: 54 | type: ART 55 | upscale: 1 56 | in_chans: 3 57 | img_size: 128 58 | window_size: 8 59 | img_range: 1. 
60 | depths: [6, 6, 6, 6, 6, 6] 61 | interval: [16, 16, 12, 12, 8, 8] 62 | embed_dim: 180 63 | num_heads: [6, 6, 6, 6, 6, 6] 64 | mlp_ratio: 4 65 | 66 | 67 | # path 68 | path: 69 | pretrain_network_g: experiments/pretrained_models/ColorDN_ART_level25.pth 70 | 71 | # validation settings 72 | val: 73 | save_img: false 74 | suffix: ~ # add suffix to saved images, if None, use exp name 75 | selfensemble_testing: false 76 | patchwise_testing: false 77 | 78 | metrics: 79 | psnr: # metric name, can be arbitrary 80 | type: calculate_psnr 81 | crop_border: 0 82 | test_y_channel: false 83 | ssim: 84 | type: calculate_ssim 85 | crop_border: 0 86 | test_y_channel: false 87 | -------------------------------------------------------------------------------- /options/test/test_ART_ColorDN_level50.yml: -------------------------------------------------------------------------------- 1 | name: test_ART_ColorDN_level50 2 | model_type: ARTModel 3 | scale: 1 4 | num_gpu: 1 5 | manual_seed: 0 6 | 7 | datasets: 8 | test_1: 9 | task: denoising_color 10 | name: CBSD68 11 | type: PairedImageDataset 12 | dataroot_gt: datasets/ColorDN/CBSD68HQ 13 | dataroot_lq: datasets/ColorDN/CBSD68HQ 14 | filename_tmpl: '{}' 15 | noise: 50 # 15/25/50 16 | io_backend: 17 | type: disk 18 | 19 | test_2: 20 | task: denoising_color 21 | name: Kodak24 22 | type: PairedImageDataset 23 | dataroot_gt: datasets/ColorDN/Kodak24HQ 24 | dataroot_lq: datasets/ColorDN/Kodak24HQ 25 | filename_tmpl: '{}' 26 | noise: 50 # 15/25/50 27 | io_backend: 28 | type: disk 29 | 30 | test_3: 31 | task: denoising_color 32 | name: McMaster 33 | type: PairedImageDataset 34 | dataroot_gt: datasets/ColorDN/McMasterHQ 35 | dataroot_lq: datasets/ColorDN/McMasterHQ 36 | filename_tmpl: '{}' 37 | noise: 50 # 15/25/50 38 | io_backend: 39 | type: disk 40 | 41 | test_4: 42 | task: denoising_color 43 | name: Urban100 44 | type: PairedImageDataset 45 | dataroot_gt: datasets/ColorDN/Urban100HQ 46 | dataroot_lq: datasets/ColorDN/Urban100HQ 47 | filename_tmpl: '{}' 48 | noise: 50 # 15/25/50 49 | io_backend: 50 | type: disk 51 | 52 | # network structures 53 | network_g: 54 | type: ART 55 | upscale: 1 56 | in_chans: 3 57 | img_size: 128 58 | window_size: 8 59 | img_range: 1. 
60 | depths: [6, 6, 6, 6, 6, 6] 61 | interval: [16, 16, 12, 12, 8, 8] 62 | embed_dim: 180 63 | num_heads: [6, 6, 6, 6, 6, 6] 64 | mlp_ratio: 4 65 | 66 | 67 | # path 68 | path: 69 | pretrain_network_g: experiments/pretrained_models/ColorDN_ART_level50.pth 70 | 71 | # validation settings 72 | val: 73 | save_img: false 74 | suffix: ~ # add suffix to saved images, if None, use exp name 75 | selfensemble_testing: false 76 | patchwise_testing: false 77 | 78 | metrics: 79 | psnr: # metric name, can be arbitrary 80 | type: calculate_psnr 81 | crop_border: 0 82 | test_y_channel: false 83 | ssim: 84 | type: calculate_ssim 85 | crop_border: 0 86 | test_y_channel: false 87 | -------------------------------------------------------------------------------- /options/test/test_ART_SR_x2.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: test_ART_SR_x2 3 | model_type: ARTModel 4 | scale: 2 5 | num_gpu: 1 6 | manual_seed: 10 7 | 8 | datasets: 9 | test_1: # the 1st test dataset 10 | name: Set5 11 | type: PairedImageDataset 12 | dataroot_gt: datasets/SR/Set5/HR 13 | dataroot_lq: datasets/SR/Set5/LR_bicubic/X2 14 | filename_tmpl: '{}x2' 15 | io_backend: 16 | type: disk 17 | 18 | test_2: # the 2nd test dataset 19 | name: Set14 20 | type: PairedImageDataset 21 | dataroot_gt: datasets/SR/Set14/HR 22 | dataroot_lq: datasets/SR/Set14/LR_bicubic/X2 23 | filename_tmpl: '{}x2' 24 | io_backend: 25 | type: disk 26 | 27 | test_3: # the 3rd test dataset 28 | name: B100 29 | type: PairedImageDataset 30 | dataroot_gt: datasets/SR/B100/HR 31 | dataroot_lq: datasets/SR/B100/LR_bicubic/X2 32 | filename_tmpl: '{}x2' 33 | io_backend: 34 | type: disk 35 | 36 | test_4: # the 4th test dataset 37 | name: Urban100 38 | type: PairedImageDataset 39 | dataroot_gt: datasets/SR/Urban100/HR 40 | dataroot_lq: datasets/SR/Urban100/LR_bicubic/X2 41 | filename_tmpl: '{}x2' 42 | io_backend: 43 | type: disk 44 | 45 | test_5: # the 5th test dataset 46 | name: Manga109 47 | type: PairedImageDataset 48 | dataroot_gt: datasets/SR/Manga109/HR 49 | dataroot_lq: datasets/SR/Manga109/LR_bicubic/X2 50 | filename_tmpl: '{}_LRBI_x2' 51 | io_backend: 52 | type: disk 53 | 54 | 55 | 56 | # network structures 57 | network_g: 58 | type: ART 59 | upscale: 2 60 | in_chans: 3 61 | img_size: 64 62 | window_size: 8 63 | img_range: 1. 
64 | depths: [6, 6, 6, 6, 6, 6] 65 | interval: [4, 4, 4, 4, 4, 4] 66 | embed_dim: 180 67 | num_heads: [6, 6, 6, 6, 6, 6] 68 | mlp_ratio: 4 69 | upsampler: 'pixelshuffle' 70 | resi_connection: '1conv' 71 | 72 | # path 73 | path: 74 | pretrain_network_g: experiments/pretrained_models/SR_ART_x2.pth 75 | strict_load_g: true 76 | 77 | # validation settings 78 | val: 79 | save_img: true 80 | suffix: ~ # add suffix to saved images, if None, use exp name 81 | 82 | metrics: 83 | psnr: # metric name, can be arbitrary 84 | type: calculate_psnr 85 | crop_border: 2 86 | test_y_channel: true 87 | ssim: 88 | type: calculate_ssim 89 | crop_border: 2 90 | test_y_channel: true 91 | -------------------------------------------------------------------------------- /options/test/test_ART_SR_x3.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: test_ART_SR_x3 3 | model_type: ARTModel 4 | scale: 3 5 | num_gpu: 1 6 | manual_seed: 10 7 | 8 | datasets: 9 | test_1: # the 1st test dataset 10 | name: Set5 11 | type: PairedImageDataset 12 | dataroot_gt: datasets/SR/Set5/HR 13 | dataroot_lq: datasets/SR/Set5/LR_bicubic/X3 14 | filename_tmpl: '{}x3' 15 | io_backend: 16 | type: disk 17 | 18 | test_2: # the 2nd test dataset 19 | name: Set14 20 | type: PairedImageDataset 21 | dataroot_gt: datasets/SR/Set14/HR 22 | dataroot_lq: datasets/SR/Set14/LR_bicubic/X3 23 | filename_tmpl: '{}x3' 24 | io_backend: 25 | type: disk 26 | 27 | test_3: # the 3rd test dataset 28 | name: B100 29 | type: PairedImageDataset 30 | dataroot_gt: datasets/SR/B100/HR 31 | dataroot_lq: datasets/SR/B100/LR_bicubic/X3 32 | filename_tmpl: '{}x3' 33 | io_backend: 34 | type: disk 35 | 36 | test_4: # the 4th test dataset 37 | name: Urban100 38 | type: PairedImageDataset 39 | dataroot_gt: datasets/SR/Urban100/HR 40 | dataroot_lq: datasets/SR/Urban100/LR_bicubic/X3 41 | filename_tmpl: '{}x3' 42 | io_backend: 43 | type: disk 44 | 45 | test_5: # the 5th test dataset 46 | name: Manga109 47 | type: PairedImageDataset 48 | dataroot_gt: datasets/SR/Manga109/HR 49 | dataroot_lq: datasets/SR/Manga109/LR_bicubic/X3 50 | filename_tmpl: '{}_LRBI_x3' 51 | io_backend: 52 | type: disk 53 | 54 | 55 | 56 | # network structures 57 | network_g: 58 | type: ART 59 | upscale: 3 60 | in_chans: 3 61 | img_size: 64 62 | window_size: 8 63 | img_range: 1. 
64 | depths: [6, 6, 6, 6, 6, 6] 65 | interval: [4, 4, 4, 4, 4, 4] 66 | embed_dim: 180 67 | num_heads: [6, 6, 6, 6, 6, 6] 68 | mlp_ratio: 4 69 | upsampler: 'pixelshuffle' 70 | resi_connection: '1conv' 71 | 72 | # path 73 | path: 74 | pretrain_network_g: experiments/pretrained_models/SR_ART_x3.pth 75 | strict_load_g: true 76 | 77 | # validation settings 78 | val: 79 | save_img: true 80 | suffix: ~ # add suffix to saved images, if None, use exp name 81 | 82 | metrics: 83 | psnr: # metric name, can be arbitrary 84 | type: calculate_psnr 85 | crop_border: 3 86 | test_y_channel: true 87 | ssim: 88 | type: calculate_ssim 89 | crop_border: 3 90 | test_y_channel: true 91 | -------------------------------------------------------------------------------- /options/test/test_ART_SR_x4.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: test_ART_SR_x4 3 | model_type: ARTModel 4 | scale: 4 5 | num_gpu: 1 6 | manual_seed: 10 7 | 8 | datasets: 9 | test_1: # the 1st test dataset 10 | name: Set5 11 | type: PairedImageDataset 12 | dataroot_gt: datasets/SR/Set5/HR 13 | dataroot_lq: datasets/SR/Set5/LR_bicubic/X4 14 | filename_tmpl: '{}x4' 15 | io_backend: 16 | type: disk 17 | 18 | test_2: # the 2nd test dataset 19 | name: Set14 20 | type: PairedImageDataset 21 | dataroot_gt: datasets/SR/Set14/HR 22 | dataroot_lq: datasets/SR/Set14/LR_bicubic/X4 23 | filename_tmpl: '{}x4' 24 | io_backend: 25 | type: disk 26 | 27 | test_3: # the 3rd test dataset 28 | name: B100 29 | type: PairedImageDataset 30 | dataroot_gt: datasets/SR/B100/HR 31 | dataroot_lq: datasets/SR/B100/LR_bicubic/X4 32 | filename_tmpl: '{}x4' 33 | io_backend: 34 | type: disk 35 | 36 | test_4: # the 4th test dataset 37 | name: Urban100 38 | type: PairedImageDataset 39 | dataroot_gt: datasets/SR/Urban100/HR 40 | dataroot_lq: datasets/SR/Urban100/LR_bicubic/X4 41 | filename_tmpl: '{}x4' 42 | io_backend: 43 | type: disk 44 | 45 | test_5: # the 5th test dataset 46 | name: Manga109 47 | type: PairedImageDataset 48 | dataroot_gt: datasets/SR/Manga109/HR 49 | dataroot_lq: datasets/SR/Manga109/LR_bicubic/X4 50 | filename_tmpl: '{}_LRBI_x4' 51 | io_backend: 52 | type: disk 53 | 54 | 55 | 56 | # network structures 57 | network_g: 58 | type: ART 59 | upscale: 4 60 | in_chans: 3 61 | img_size: 64 62 | window_size: 8 63 | img_range: 1. 
64 | depths: [6, 6, 6, 6, 6, 6] 65 | interval: [4, 4, 4, 4, 4, 4] 66 | embed_dim: 180 67 | num_heads: [6, 6, 6, 6, 6, 6] 68 | mlp_ratio: 4 69 | upsampler: 'pixelshuffle' 70 | resi_connection: '1conv' 71 | 72 | # path 73 | path: 74 | pretrain_network_g: experiments/pretrained_models/SR_ART_x4.pth 75 | strict_load_g: true 76 | 77 | # validation settings 78 | val: 79 | save_img: true 80 | suffix: ~ # add suffix to saved images, if None, use exp name 81 | 82 | metrics: 83 | psnr: # metric name, can be arbitrary 84 | type: calculate_psnr 85 | crop_border: 4 86 | test_y_channel: true 87 | ssim: 88 | type: calculate_ssim 89 | crop_border: 4 90 | test_y_channel: true 91 | -------------------------------------------------------------------------------- /options/test/test_ART_S_SR_x2.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: test_ART_S_SR_x2 3 | model_type: ARTModel 4 | scale: 2 5 | num_gpu: 1 6 | manual_seed: 10 7 | 8 | datasets: 9 | test_1: # the 1st test dataset 10 | name: Set5 11 | type: PairedImageDataset 12 | dataroot_gt: datasets/SR/Set5/HR 13 | dataroot_lq: datasets/SR/Set5/LR_bicubic/X2 14 | filename_tmpl: '{}x2' 15 | io_backend: 16 | type: disk 17 | 18 | test_2: # the 2nd test dataset 19 | name: Set14 20 | type: PairedImageDataset 21 | dataroot_gt: datasets/SR/Set14/HR 22 | dataroot_lq: datasets/SR/Set14/LR_bicubic/X2 23 | filename_tmpl: '{}x2' 24 | io_backend: 25 | type: disk 26 | 27 | test_3: # the 3rd test dataset 28 | name: B100 29 | type: PairedImageDataset 30 | dataroot_gt: datasets/SR/B100/HR 31 | dataroot_lq: datasets/SR/B100/LR_bicubic/X2 32 | filename_tmpl: '{}x2' 33 | io_backend: 34 | type: disk 35 | 36 | test_4: # the 4th test dataset 37 | name: Urban100 38 | type: PairedImageDataset 39 | dataroot_gt: datasets/SR/Urban100/HR 40 | dataroot_lq: datasets/SR/Urban100/LR_bicubic/X2 41 | filename_tmpl: '{}x2' 42 | io_backend: 43 | type: disk 44 | 45 | test_5: # the 5th test dataset 46 | name: Manga109 47 | type: PairedImageDataset 48 | dataroot_gt: datasets/SR/Manga109/HR 49 | dataroot_lq: datasets/SR/Manga109/LR_bicubic/X2 50 | filename_tmpl: '{}_LRBI_x2' 51 | io_backend: 52 | type: disk 53 | 54 | 55 | 56 | # network structures 57 | network_g: 58 | type: ART 59 | upscale: 2 60 | in_chans: 3 61 | img_size: 64 62 | window_size: 8 63 | img_range: 1. 
64 | depths: [6, 6, 6, 6, 6, 6] 65 | interval: [8, 8, 8, 8, 8, 8] 66 | embed_dim: 180 67 | num_heads: [6, 6, 6, 6, 6, 6] 68 | mlp_ratio: 2 69 | upsampler: 'pixelshuffle' 70 | resi_connection: '1conv' 71 | 72 | # path 73 | path: 74 | pretrain_network_g: experiments/pretrained_models/SR_ART_S_x2.pth 75 | strict_load_g: true 76 | 77 | # validation settings 78 | val: 79 | save_img: true 80 | suffix: ~ # add suffix to saved images, if None, use exp name 81 | 82 | metrics: 83 | psnr: # metric name, can be arbitrary 84 | type: calculate_psnr 85 | crop_border: 2 86 | test_y_channel: true 87 | ssim: 88 | type: calculate_ssim 89 | crop_border: 2 90 | test_y_channel: true 91 | -------------------------------------------------------------------------------- /options/test/test_ART_S_SR_x3.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: test_ART_S_SR_x3 3 | model_type: ARTModel 4 | scale: 3 5 | num_gpu: 1 6 | manual_seed: 10 7 | 8 | datasets: 9 | test_1: # the 1st test dataset 10 | name: Set5 11 | type: PairedImageDataset 12 | dataroot_gt: datasets/SR/Set5/HR 13 | dataroot_lq: datasets/SR/Set5/LR_bicubic/X3 14 | filename_tmpl: '{}x3' 15 | io_backend: 16 | type: disk 17 | 18 | test_2: # the 2nd test dataset 19 | name: Set14 20 | type: PairedImageDataset 21 | dataroot_gt: datasets/SR/Set14/HR 22 | dataroot_lq: datasets/SR/Set14/LR_bicubic/X3 23 | filename_tmpl: '{}x3' 24 | io_backend: 25 | type: disk 26 | 27 | test_3: # the 3rd test dataset 28 | name: B100 29 | type: PairedImageDataset 30 | dataroot_gt: datasets/SR/B100/HR 31 | dataroot_lq: datasets/SR/B100/LR_bicubic/X3 32 | filename_tmpl: '{}x3' 33 | io_backend: 34 | type: disk 35 | 36 | test_4: # the 4th test dataset 37 | name: Urban100 38 | type: PairedImageDataset 39 | dataroot_gt: datasets/SR/Urban100/HR 40 | dataroot_lq: datasets/SR/Urban100/LR_bicubic/X3 41 | filename_tmpl: '{}x3' 42 | io_backend: 43 | type: disk 44 | 45 | test_5: # the 5th test dataset 46 | name: Manga109 47 | type: PairedImageDataset 48 | dataroot_gt: datasets/SR/Manga109/HR 49 | dataroot_lq: datasets/SR/Manga109/LR_bicubic/X3 50 | filename_tmpl: '{}_LRBI_x3' 51 | io_backend: 52 | type: disk 53 | 54 | 55 | 56 | # network structures 57 | network_g: 58 | type: ART 59 | upscale: 3 60 | in_chans: 3 61 | img_size: 64 62 | window_size: 8 63 | img_range: 1. 
64 | depths: [6, 6, 6, 6, 6, 6] 65 | interval: [8, 8, 8, 8, 8, 8] 66 | embed_dim: 180 67 | num_heads: [6, 6, 6, 6, 6, 6] 68 | mlp_ratio: 2 69 | upsampler: 'pixelshuffle' 70 | resi_connection: '1conv' 71 | 72 | # path 73 | path: 74 | pretrain_network_g: experiments/pretrained_models/SR_ART_S_x3.pth 75 | strict_load_g: true 76 | 77 | # validation settings 78 | val: 79 | save_img: true 80 | suffix: ~ # add suffix to saved images, if None, use exp name 81 | 82 | metrics: 83 | psnr: # metric name, can be arbitrary 84 | type: calculate_psnr 85 | crop_border: 3 86 | test_y_channel: true 87 | ssim: 88 | type: calculate_ssim 89 | crop_border: 3 90 | test_y_channel: true 91 | -------------------------------------------------------------------------------- /options/test/test_ART_S_SR_x4.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: test_ART_S_SR_x4 3 | model_type: ARTModel 4 | scale: 4 5 | num_gpu: 1 6 | manual_seed: 10 7 | 8 | datasets: 9 | test_1: # the 1st test dataset 10 | name: Set5 11 | type: PairedImageDataset 12 | dataroot_gt: datasets/SR/Set5/HR 13 | dataroot_lq: datasets/SR/Set5/LR_bicubic/X4 14 | filename_tmpl: '{}x4' 15 | io_backend: 16 | type: disk 17 | 18 | test_2: # the 2nd test dataset 19 | name: Set14 20 | type: PairedImageDataset 21 | dataroot_gt: datasets/SR/Set14/HR 22 | dataroot_lq: datasets/SR/Set14/LR_bicubic/X4 23 | filename_tmpl: '{}x4' 24 | io_backend: 25 | type: disk 26 | 27 | test_3: # the 3rd test dataset 28 | name: B100 29 | type: PairedImageDataset 30 | dataroot_gt: datasets/SR/B100/HR 31 | dataroot_lq: datasets/SR/B100/LR_bicubic/X4 32 | filename_tmpl: '{}x4' 33 | io_backend: 34 | type: disk 35 | 36 | test_4: # the 4th test dataset 37 | name: Urban100 38 | type: PairedImageDataset 39 | dataroot_gt: datasets/SR/Urban100/HR 40 | dataroot_lq: datasets/SR/Urban100/LR_bicubic/X4 41 | filename_tmpl: '{}x4' 42 | io_backend: 43 | type: disk 44 | 45 | test_5: # the 5th test dataset 46 | name: Manga109 47 | type: PairedImageDataset 48 | dataroot_gt: datasets/SR/Manga109/HR 49 | dataroot_lq: datasets/SR/Manga109/LR_bicubic/X4 50 | filename_tmpl: '{}_LRBI_x4' 51 | io_backend: 52 | type: disk 53 | 54 | 55 | 56 | # network structures 57 | network_g: 58 | type: ART 59 | upscale: 4 60 | in_chans: 3 61 | img_size: 64 62 | window_size: 8 63 | img_range: 1.
64 | depths: [6, 6, 6, 6, 6, 6] 65 | interval: [8, 8, 8, 8, 8, 8] 66 | embed_dim: 180 67 | num_heads: [6, 6, 6, 6, 6, 6] 68 | mlp_ratio: 2 69 | upsampler: 'pixelshuffle' 70 | resi_connection: '1conv' 71 | 72 | # path 73 | path: 74 | pretrain_network_g: experiments/pretrained_models/SR_ART_S_x4.pth 75 | strict_load_g: true 76 | 77 | # validation settings 78 | val: 79 | save_img: true 80 | suffix: ~ # add suffix to saved images, if None, use exp name 81 | 82 | metrics: 83 | psnr: # metric name, can be arbitrary 84 | type: calculate_psnr 85 | crop_border: 4 86 | test_y_channel: true 87 | ssim: 88 | type: calculate_ssim 89 | crop_border: 4 90 | test_y_channel: true 91 | -------------------------------------------------------------------------------- /options/train/train_ART_CAR_q10.yml: -------------------------------------------------------------------------------- 1 | # general settings for CAR training 2 | name: ART_CAR_q10 3 | model_type: ARTModel 4 | scale: 1 5 | num_gpu: 4 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | task: CAR 12 | name: DFWB_CAR 13 | type: PairedImageDataset 14 | dataroot_gt: datasets/DFWB_CAR/HQ 15 | dataroot_lq: datasets/DFWB_CAR/LQ/10 16 | filename_tmpl: '{}' 17 | io_backend: 18 | type: disk 19 | 20 | gt_size: 126 21 | use_hflip: true 22 | use_rot: true 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 2 28 | dataset_enlarge_ratio: 100 29 | prefetch_mode: ~ 30 | 31 | val: 32 | task: CAR 33 | name: Classic5 34 | type: PairedImageDataset 35 | dataroot_gt: datasets/CAR/classic5/Classic5_HQ 36 | dataroot_lq: datasets/CAR/classic5/Classic5_LQ/10 37 | filename_tmpl: '{}' 38 | io_backend: 39 | type: disk 40 | 41 | # network structures 42 | network_g: 43 | type: ART 44 | upscale: 1 45 | in_chans: 1 46 | img_size: 126 47 | window_size: 7 48 | img_range: 255. 49 | depths: [6, 6, 6, 6, 6, 6] 50 | interval: [18, 18, 13, 13, 7, 7] 51 | embed_dim: 180 52 | num_heads: [6, 6, 6, 6, 6, 6] 53 | mlp_ratio: 4 54 | 55 | # path 56 | path: 57 | pretrain_network_g: experiments/pretrained_models/CAR_ART_q40.pth # save training time if we finetune from quality 40 and halve initial lr.
58 | strict_load_g: true 59 | resume_state: ~ 60 | 61 | # training settings 62 | train: 63 | optim_g: 64 | type: Adam 65 | # lr: !!float 2e-4 # for jpeg 40 66 | lr: !!float 1e-4 # for jpeg 10/30 67 | weight_decay: 0 68 | betas: [0.9, 0.99] 69 | 70 | scheduler: 71 | type: MultiStepLR 72 | # milestones: [ 800000, 1200000, 1400000, 1500000 ] # for jpeg 40 73 | milestones: [ 400000, 600000, 700000, 750000 ] # for jpeg 10/30 74 | gamma: 0.5 75 | 76 | # total_iter: 1600000 # for jpeg 40 77 | total_iter: 800000 # for jpeg 10/30 78 | warmup_iter: -1 # no warm up 79 | 80 | # losses 81 | pixel_opt: 82 | type: CharbonnierLoss 83 | loss_weight: 1.0 84 | reduction: mean 85 | eps: !!float 1e-3 86 | 87 | # validation settings 88 | val: 89 | val_freq: !!float 5e3 90 | save_img: false 91 | 92 | metrics: 93 | psnr: # metric name, can be arbitrary 94 | type: calculate_psnr 95 | crop_border: 0 96 | test_y_channel: true 97 | 98 | # logging settings 99 | logger: 100 | print_freq: 200 101 | save_checkpoint_freq: !!float 2e4 102 | use_tb_logger: true 103 | wandb: 104 | project: ~ 105 | resume_id: ~ 106 | 107 | # dist training settings 108 | dist_params: 109 | backend: nccl 110 | port: 29500 111 | -------------------------------------------------------------------------------- /options/train/train_ART_CAR_q30.yml: -------------------------------------------------------------------------------- 1 | # general settings for CAR training 2 | name: ART_CAR_q30 3 | model_type: ARTModel 4 | scale: 1 5 | num_gpu: 4 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | task: CAR 12 | name: DFWB_CAR 13 | type: PairedImageDataset 14 | dataroot_gt: datasets/DFWB_CAR/HQ 15 | dataroot_lq: datasets/DFWB_CAR/LQ/30 16 | filename_tmpl: '{}' 17 | io_backend: 18 | type: disk 19 | 20 | gt_size: 126 21 | use_hflip: true 22 | use_rot: true 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 2 28 | dataset_enlarge_ratio: 100 29 | prefetch_mode: ~ 30 | 31 | val: 32 | task: CAR 33 | name: Classic5 34 | type: PairedImageDataset 35 | dataroot_gt: datasets/CAR/classic5/Classic5_HQ 36 | dataroot_lq: datasets/CAR/classic5/Classic5_LQ/30 37 | filename_tmpl: '{}' 38 | io_backend: 39 | type: disk 40 | 41 | # network structures 42 | network_g: 43 | type: ART 44 | upscale: 1 45 | in_chans: 1 46 | img_size: 126 47 | window_size: 7 48 | img_range: 255. 49 | depths: [6, 6, 6, 6, 6, 6] 50 | interval: [18, 18, 13, 13, 7, 7] 51 | embed_dim: 180 52 | num_heads: [6, 6, 6, 6, 6, 6] 53 | mlp_ratio: 4 54 | 55 | # path 56 | path: 57 | pretrain_network_g: experiments/pretrained_models/CAR_ART_q40.pth # save training time if we finetune from quality 40 and halve initial lr.
58 | strict_load_g: true 59 | resume_state: ~ 60 | 61 | # training settings 62 | train: 63 | optim_g: 64 | type: Adam 65 | # lr: !!float 2e-4 # for jpeg 40 66 | lr: !!float 1e-4 # for jpeg 10/30 67 | weight_decay: 0 68 | betas: [0.9, 0.99] 69 | 70 | scheduler: 71 | type: MultiStepLR 72 | # milestones: [ 800000, 1200000, 1400000, 1500000 ] # for jpeg 40 73 | milestones: [ 400000, 600000, 700000, 750000 ] # for jpeg 10/30 74 | gamma: 0.5 75 | 76 | # total_iter: 1600000 # for jpeg 40 77 | total_iter: 800000 # for jpeg 10/30 78 | warmup_iter: -1 # no warm up 79 | 80 | # losses 81 | pixel_opt: 82 | type: CharbonnierLoss 83 | loss_weight: 1.0 84 | reduction: mean 85 | eps: !!float 1e-3 86 | 87 | # validation settings 88 | val: 89 | val_freq: !!float 5e3 90 | save_img: false 91 | 92 | metrics: 93 | psnr: # metric name, can be arbitrary 94 | type: calculate_psnr 95 | crop_border: 0 96 | test_y_channel: true 97 | 98 | # logging settings 99 | logger: 100 | print_freq: 200 101 | save_checkpoint_freq: !!float 2e4 102 | use_tb_logger: true 103 | wandb: 104 | project: ~ 105 | resume_id: ~ 106 | 107 | # dist training settings 108 | dist_params: 109 | backend: nccl 110 | port: 29500 111 | -------------------------------------------------------------------------------- /options/train/train_ART_CAR_q40.yml: -------------------------------------------------------------------------------- 1 | # general settings for CAR training 2 | name: ART_CAR_q40 3 | model_type: ARTModel 4 | scale: 1 5 | num_gpu: 4 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | task: CAR 12 | name: DFWB_CAR 13 | type: PairedImageDataset 14 | dataroot_gt: datasets/DFWB_CAR/HQ 15 | dataroot_lq: datasets/DFWB_CAR/LQ/40 16 | filename_tmpl: '{}' 17 | io_backend: 18 | type: disk 19 | 20 | gt_size: 126 21 | use_hflip: true 22 | use_rot: true 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 2 28 | dataset_enlarge_ratio: 100 29 | prefetch_mode: ~ 30 | 31 | val: 32 | task: CAR 33 | name: Classic5 34 | type: PairedImageDataset 35 | dataroot_gt: datasets/CAR/classic5/Classic5_HQ 36 | dataroot_lq: datasets/CAR/classic5/Classic5_LQ/40 37 | filename_tmpl: '{}' 38 | io_backend: 39 | type: disk 40 | 41 | # network structures 42 | network_g: 43 | type: ART 44 | upscale: 1 45 | in_chans: 1 46 | img_size: 126 47 | window_size: 7 48 | img_range: 255.
49 | depths: [6, 6, 6, 6, 6, 6] 50 | interval: [18, 18, 13, 13, 7, 7] 51 | embed_dim: 180 52 | num_heads: [6, 6, 6, 6, 6, 6] 53 | mlp_ratio: 4 54 | 55 | # path 56 | path: 57 | pretrain_network_g: ~ 58 | strict_load_g: true 59 | resume_state: ~ 60 | 61 | # training settings 62 | train: 63 | optim_g: 64 | type: Adam 65 | lr: !!float 2e-4 # for jpeg 40 66 | weight_decay: 0 67 | betas: [0.9, 0.99] 68 | 69 | scheduler: 70 | type: MultiStepLR 71 | milestones: [ 800000, 1200000, 1400000, 1500000 ] # for jpeg 40 72 | gamma: 0.5 73 | 74 | total_iter: 1600000 # for jpeg 40 75 | warmup_iter: -1 # no warm up 76 | 77 | # losses 78 | pixel_opt: 79 | type: CharbonnierLoss 80 | loss_weight: 1.0 81 | reduction: mean 82 | eps: !!float 1e-3 83 | 84 | # validation settings 85 | val: 86 | val_freq: !!float 5e3 87 | save_img: false 88 | 89 | metrics: 90 | psnr: # metric name, can be arbitrary 91 | type: calculate_psnr 92 | crop_border: 0 93 | test_y_channel: true 94 | 95 | # logging settings 96 | logger: 97 | print_freq: 200 98 | save_checkpoint_freq: !!float 2e4 99 | use_tb_logger: true 100 | wandb: 101 | project: ~ 102 | resume_id: ~ 103 | 104 | # dist training settings 105 | dist_params: 106 | backend: nccl 107 | port: 29500 108 | -------------------------------------------------------------------------------- /options/train/train_ART_ColorDN_level15.yml: -------------------------------------------------------------------------------- 1 | # general settings for ColorDN training 2 | name: ART_ColorDN_level15 3 | model_type: ARTModel 4 | scale: 1 5 | num_gpu: 4 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | task: denoising_color 12 | noise: 15 # 15/25/50 13 | name: DFWB_RGB 14 | type: PairedImageDataset 15 | dataroot_gt: datasets/DFWB_RGB/HQ 16 | dataroot_lq: datasets/DFWB_RGB/HQ 17 | filename_tmpl: '{}' 18 | io_backend: 19 | type: disk 20 | 21 | gt_size: 128 22 | use_hflip: true 23 | use_rot: true 24 | 25 | # data loader 26 | use_shuffle: true 27 | num_worker_per_gpu: 8 28 | batch_size_per_gpu: 2 29 | dataset_enlarge_ratio: 100 30 | prefetch_mode: ~ 31 | 32 | val: 33 | task: denoising_color 34 | noise: 15 # 15/25/50 35 | name: McMaster 36 | type: PairedImageDataset 37 | dataroot_gt: datasets/ColorDN/McMasterHQ 38 | dataroot_lq: datasets/ColorDN/McMasterHQ 39 | filename_tmpl: '{}' 40 | io_backend: 41 | type: disk 42 | 43 | # network structures 44 | network_g: 45 | type: ART 46 | upscale: 1 47 | in_chans: 3 48 | img_size: 128 49 | window_size: 8 50 | img_range: 1. 
51 | depths: [6, 6, 6, 6, 6, 6] 52 | interval: [16, 16, 12, 12, 8, 8] 53 | embed_dim: 180 54 | num_heads: [6, 6, 6, 6, 6, 6] 55 | mlp_ratio: 4 56 | 57 | 58 | path: 59 | pretrain_network_g: ~ 60 | strict_load_g: true 61 | resume_state: ~ 62 | 63 | # training settings 64 | train: 65 | optim_g: 66 | type: Adam 67 | lr: !!float 2e-4 # for noise 15 68 | # lr: !!float 1e-4 # for noise 25/50 69 | weight_decay: 0 70 | betas: [0.9, 0.99] 71 | 72 | scheduler: 73 | type: MultiStepLR 74 | milestones: [ 800000, 1200000, 1400000, 1500000 ] # for noise 15 75 | # milestones: [ 400000, 600000, 700000, 750000 ] # for noise 25/50 76 | gamma: 0.5 77 | 78 | total_iter: 1600000 # for noise 15 79 | # total_iter: 800000 # for noise 25/50 80 | warmup_iter: -1 # no warm up 81 | 82 | # losses 83 | pixel_opt: 84 | type: CharbonnierLoss 85 | loss_weight: 1.0 86 | reduction: mean 87 | eps: !!float 1e-3 88 | 89 | # validation settings 90 | val: 91 | val_freq: !!float 5e3 92 | save_img: false 93 | 94 | metrics: 95 | psnr: # metric name, can be arbitrary 96 | type: calculate_psnr 97 | crop_border: 0 98 | test_y_channel: false 99 | 100 | # logging settings 101 | logger: 102 | print_freq: 200 103 | save_checkpoint_freq: !!float 2e4 104 | use_tb_logger: true 105 | wandb: 106 | project: ~ 107 | resume_id: ~ 108 | 109 | # dist training settings 110 | dist_params: 111 | backend: nccl 112 | port: 29500 -------------------------------------------------------------------------------- /options/train/train_ART_ColorDN_level25.yml: -------------------------------------------------------------------------------- 1 | # general settings for ColorDN training 2 | name: ART_ColorDN_level25 3 | model_type: ARTModel 4 | scale: 1 5 | num_gpu: 4 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | task: denoising_color 12 | noise: 25 # 15/25/50 13 | name: DFWB_RGB 14 | type: PairedImageDataset 15 | dataroot_gt: datasets/DFWB_RGB/HQ 16 | dataroot_lq: datasets/DFWB_RGB/HQ 17 | filename_tmpl: '{}' 18 | io_backend: 19 | type: disk 20 | 21 | gt_size: 128 22 | use_hflip: true 23 | use_rot: true 24 | 25 | # data loader 26 | use_shuffle: true 27 | num_worker_per_gpu: 8 28 | batch_size_per_gpu: 2 29 | dataset_enlarge_ratio: 100 30 | prefetch_mode: ~ 31 | 32 | val: 33 | task: denoising_color 34 | noise: 25 # 15/25/50 35 | name: McMaster 36 | type: PairedImageDataset 37 | dataroot_gt: datasets/ColorDN/McMasterHQ 38 | dataroot_lq: datasets/ColorDN/McMasterHQ 39 | filename_tmpl: '{}' 40 | io_backend: 41 | type: disk 42 | 43 | # network structures 44 | network_g: 45 | type: ART 46 | upscale: 1 47 | in_chans: 3 48 | img_size: 128 49 | window_size: 8 50 | img_range: 1. 51 | depths: [6, 6, 6, 6, 6, 6] 52 | interval: [16, 16, 12, 12, 8, 8] 53 | embed_dim: 180 54 | num_heads: [6, 6, 6, 6, 6, 6] 55 | mlp_ratio: 4 56 | 57 | 58 | path: 59 | pretrain_network_g: experiments/pretrained_models/ColorDN_ART_level15.pth # save training time if we finetune from noise 15 and halve initial lr. 
60 | strict_load_g: true 61 | resume_state: ~ 62 | 63 | # training settings 64 | train: 65 | optim_g: 66 | type: Adam 67 | # lr: !!float 2e-4 # for noise 15 68 | lr: !!float 1e-4 # for noise 25/50 69 | weight_decay: 0 70 | betas: [0.9, 0.99] 71 | 72 | scheduler: 73 | type: MultiStepLR 74 | # milestones: [ 800000, 1200000, 1400000, 1500000 ] # for noise 15 75 | milestones: [ 400000, 600000, 700000, 750000 ] # for noise 25/50 76 | gamma: 0.5 77 | 78 | # total_iter: 1600000 # for noise 15 79 | total_iter: 800000 # for noise 25/50 80 | warmup_iter: -1 # no warm up 81 | 82 | # losses 83 | pixel_opt: 84 | type: CharbonnierLoss 85 | loss_weight: 1.0 86 | reduction: mean 87 | eps: !!float 1e-3 88 | 89 | # validation settings 90 | val: 91 | val_freq: !!float 5e3 92 | save_img: false 93 | 94 | metrics: 95 | psnr: # metric name, can be arbitrary 96 | type: calculate_psnr 97 | crop_border: 0 98 | test_y_channel: false 99 | 100 | # logging settings 101 | logger: 102 | print_freq: 200 103 | save_checkpoint_freq: !!float 2e4 104 | use_tb_logger: true 105 | wandb: 106 | project: ~ 107 | resume_id: ~ 108 | 109 | # dist training settings 110 | dist_params: 111 | backend: nccl 112 | port: 29500 -------------------------------------------------------------------------------- /options/train/train_ART_ColorDN_level50.yml: -------------------------------------------------------------------------------- 1 | # general settings for ColorDN training 2 | name: ART_ColorDN_level50 3 | model_type: ARTModel 4 | scale: 1 5 | num_gpu: 4 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | task: denoising_color 12 | noise: 50 # 15/25/50 13 | name: DFWB_RGB 14 | type: PairedImageDataset 15 | dataroot_gt: datasets/DFWB_RGB/HQ 16 | dataroot_lq: datasets/DFWB_RGB/HQ 17 | filename_tmpl: '{}' 18 | io_backend: 19 | type: disk 20 | 21 | gt_size: 128 22 | use_hflip: true 23 | use_rot: true 24 | 25 | # data loader 26 | use_shuffle: true 27 | num_worker_per_gpu: 8 28 | batch_size_per_gpu: 2 29 | dataset_enlarge_ratio: 100 30 | prefetch_mode: ~ 31 | 32 | val: 33 | task: denoising_color 34 | noise: 50 # 15/25/50 35 | name: McMaster 36 | type: PairedImageDataset 37 | dataroot_gt: datasets/ColorDN/McMasterHQ 38 | dataroot_lq: datasets/ColorDN/McMasterHQ 39 | filename_tmpl: '{}' 40 | io_backend: 41 | type: disk 42 | 43 | # network structures 44 | network_g: 45 | type: ART 46 | upscale: 1 47 | in_chans: 3 48 | img_size: 128 49 | window_size: 8 50 | img_range: 1. 51 | depths: [6, 6, 6, 6, 6, 6] 52 | interval: [16, 16, 12, 12, 8, 8] 53 | embed_dim: 180 54 | num_heads: [6, 6, 6, 6, 6, 6] 55 | mlp_ratio: 4 56 | 57 | 58 | path: 59 | pretrain_network_g: experiments/pretrained_models/ColorDN_ART_level15.pth # save training time if we finetune from noise 15 and halve initial lr. 
60 | strict_load_g: true 61 | resume_state: ~ 62 | 63 | # training settings 64 | train: 65 | optim_g: 66 | type: Adam 67 | # lr: !!float 2e-4 # for noise 15 68 | lr: !!float 1e-4 # for noise 25/50 69 | weight_decay: 0 70 | betas: [0.9, 0.99] 71 | 72 | scheduler: 73 | type: MultiStepLR 74 | # milestones: [ 800000, 1200000, 1400000, 1500000 ] # for noise 15 75 | milestones: [ 400000, 600000, 700000, 750000 ] # for noise 25/50 76 | gamma: 0.5 77 | 78 | # total_iter: 1600000 # for noise 15 79 | total_iter: 800000 # for noise 25/50 80 | warmup_iter: -1 # no warm up 81 | 82 | # losses 83 | pixel_opt: 84 | type: CharbonnierLoss 85 | loss_weight: 1.0 86 | reduction: mean 87 | eps: !!float 1e-3 88 | 89 | # validation settings 90 | val: 91 | val_freq: !!float 5e3 92 | save_img: false 93 | 94 | metrics: 95 | psnr: # metric name, can be arbitrary 96 | type: calculate_psnr 97 | crop_border: 0 98 | test_y_channel: false 99 | 100 | # logging settings 101 | logger: 102 | print_freq: 200 103 | save_checkpoint_freq: !!float 2e4 104 | use_tb_logger: true 105 | wandb: 106 | project: ~ 107 | resume_id: ~ 108 | 109 | # dist training settings 110 | dist_params: 111 | backend: nccl 112 | port: 29500 -------------------------------------------------------------------------------- /options/train/train_ART_SR_x2.yml: -------------------------------------------------------------------------------- 1 | # general settings for image SR training 2 | name: ART_SR_x2 3 | model_type: ARTModel 4 | scale: 2 5 | num_gpu: 4 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | task: SR 12 | name: DF2K 13 | type: PairedImageDataset 14 | dataroot_gt: datasets/DF2K/HR 15 | dataroot_lq: datasets/DF2K/LR_bicubic/X2 16 | filename_tmpl: '{}x2' 17 | io_backend: 18 | type: disk 19 | 20 | gt_size: 128 21 | use_hflip: true 22 | use_rot: true 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 8 28 | dataset_enlarge_ratio: 100 29 | prefetch_mode: ~ 30 | 31 | val: 32 | name: Set5 33 | type: PairedImageDataset 34 | dataroot_gt: datasets/SR/Set5/HR 35 | dataroot_lq: datasets/SR/Set5/LR_bicubic/X2 36 | filename_tmpl: '{}x2' 37 | io_backend: 38 | type: disk 39 | 40 | # network structures 41 | network_g: 42 | type: ART 43 | upscale: 2 44 | in_chans: 3 45 | img_size: 64 46 | window_size: 8 47 | img_range: 1. 
48 | depths: [6, 6, 6, 6, 6, 6] 49 | interval: [4, 4, 4, 4, 4, 4] 50 | embed_dim: 180 51 | num_heads: [6, 6, 6, 6, 6, 6] 52 | mlp_ratio: 4 53 | upsampler: 'pixelshuffle' 54 | resi_connection: '1conv' 55 | 56 | # path 57 | path: 58 | pretrain_network_g: ~ 59 | strict_load_g: true 60 | resume_state: ~ 61 | 62 | # training settings 63 | train: 64 | optim_g: 65 | type: Adam 66 | lr: !!float 2e-4 67 | weight_decay: 0 68 | betas: [0.9, 0.99] 69 | 70 | scheduler: 71 | type: MultiStepLR 72 | milestones: [250000, 400000, 450000, 475000] 73 | gamma: 0.5 74 | 75 | total_iter: 500000 76 | warmup_iter: -1 # no warm up 77 | 78 | # losses 79 | pixel_opt: 80 | type: L1Loss 81 | loss_weight: 1.0 82 | reduction: mean 83 | # validation settings 84 | val: 85 | val_freq: !!float 5e3 86 | save_img: true 87 | 88 | metrics: 89 | psnr: # metric name, can be arbitrary 90 | type: calculate_psnr 91 | crop_border: 2 92 | test_y_channel: true 93 | 94 | 95 | 96 | # logging settings 97 | logger: 98 | print_freq: 200 99 | save_checkpoint_freq: !!float 5e3 100 | use_tb_logger: true 101 | wandb: 102 | project: ~ 103 | resume_id: ~ 104 | 105 | # dist training settings 106 | dist_params: 107 | backend: nccl 108 | port: 29500 -------------------------------------------------------------------------------- /options/train/train_ART_SR_x3.yml: -------------------------------------------------------------------------------- 1 | # general settings for image SR training 2 | name: ART_SR_x3 3 | model_type: ARTModel 4 | scale: 3 5 | num_gpu: 4 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | task: SR 12 | name: DF2K 13 | type: PairedImageDataset 14 | dataroot_gt: datasets/DF2K/HR 15 | dataroot_lq: datasets/DF2K/LR_bicubic/X3 16 | filename_tmpl: '{}x3' 17 | io_backend: 18 | type: disk 19 | 20 | gt_size: 192 21 | use_hflip: true 22 | use_rot: true 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 8 28 | dataset_enlarge_ratio: 100 29 | prefetch_mode: ~ 30 | 31 | val: 32 | name: Set5 33 | type: PairedImageDataset 34 | dataroot_gt: datasets/SR/Set5/HR 35 | dataroot_lq: datasets/SR/Set5/LR_bicubic/X3 36 | filename_tmpl: '{}x3' 37 | io_backend: 38 | type: disk 39 | 40 | # network structures 41 | network_g: 42 | type: ART 43 | upscale: 3 44 | in_chans: 3 45 | img_size: 64 46 | window_size: 8 47 | img_range: 1. 48 | depths: [6, 6, 6, 6, 6, 6] 49 | interval: [4, 4, 4, 4, 4, 4] 50 | embed_dim: 180 51 | num_heads: [6, 6, 6, 6, 6, 6] 52 | mlp_ratio: 4 53 | upsampler: 'pixelshuffle' 54 | resi_connection: '1conv' 55 | 56 | # path 57 | path: 58 | pretrain_network_g: experiments/pretrained_models/SR_ART_x2.pth # save training time if we finetune from x2 and halve initial lr. 
59 | strict_load_g: false 60 | resume_state: ~ 61 | 62 | # training settings 63 | train: 64 | optim_g: 65 | type: Adam 66 | # lr: !!float 2e-4 67 | lr: !!float 1e-4 68 | weight_decay: 0 69 | betas: [0.9, 0.99] 70 | 71 | scheduler: 72 | type: MultiStepLR 73 | # milestones: [ 250000, 400000, 450000, 475000 ] 74 | milestones: [ 125000, 200000, 225000, 237500 ] 75 | gamma: 0.5 76 | 77 | # total_iter: 500000 78 | total_iter: 250000 79 | warmup_iter: -1 # no warm up 80 | 81 | # losses 82 | pixel_opt: 83 | type: L1Loss 84 | loss_weight: 1.0 85 | reduction: mean 86 | # validation settings 87 | val: 88 | val_freq: !!float 5e3 89 | save_img: true 90 | 91 | metrics: 92 | psnr: # metric name, can be arbitrary 93 | type: calculate_psnr 94 | crop_border: 3 95 | test_y_channel: true 96 | 97 | 98 | 99 | # logging settings 100 | logger: 101 | print_freq: 200 102 | save_checkpoint_freq: !!float 5e3 103 | use_tb_logger: true 104 | wandb: 105 | project: ~ 106 | resume_id: ~ 107 | 108 | # dist training settings 109 | dist_params: 110 | backend: nccl 111 | port: 29500 -------------------------------------------------------------------------------- /options/train/train_ART_SR_x4.yml: -------------------------------------------------------------------------------- 1 | # general settings for image SR training 2 | name: ART_SR_x4 3 | model_type: ARTModel 4 | scale: 4 5 | num_gpu: 4 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | task: SR 12 | name: DF2K 13 | type: PairedImageDataset 14 | dataroot_gt: datasets/DF2K/HR 15 | dataroot_lq: datasets/DF2K/LR_bicubic/X4 16 | filename_tmpl: '{}x4' 17 | io_backend: 18 | type: disk 19 | 20 | gt_size: 256 21 | use_hflip: true 22 | use_rot: true 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 8 28 | dataset_enlarge_ratio: 100 29 | prefetch_mode: ~ 30 | 31 | val: 32 | name: Set5 33 | type: PairedImageDataset 34 | dataroot_gt: datasets/SR/Set5/HR 35 | dataroot_lq: datasets/SR/Set5/LR_bicubic/X4 36 | filename_tmpl: '{}x4' 37 | io_backend: 38 | type: disk 39 | 40 | # network structures 41 | network_g: 42 | type: ART 43 | upscale: 4 44 | in_chans: 3 45 | img_size: 64 46 | window_size: 8 47 | img_range: 1. 48 | depths: [6, 6, 6, 6, 6, 6] 49 | interval: [4, 4, 4, 4, 4, 4] 50 | embed_dim: 180 51 | num_heads: [6, 6, 6, 6, 6, 6] 52 | mlp_ratio: 4 53 | upsampler: 'pixelshuffle' 54 | resi_connection: '1conv' 55 | 56 | # path 57 | path: 58 | pretrain_network_g: experiments/pretrained_models/SR_ART_x2.pth # save training time if we finetune from x2 and halve initial lr. 
59 | strict_load_g: false 60 | resume_state: ~ 61 | 62 | # training settings 63 | train: 64 | optim_g: 65 | type: Adam 66 | # lr: !!float 2e-4 67 | lr: !!float 1e-4 68 | weight_decay: 0 69 | betas: [0.9, 0.99] 70 | 71 | scheduler: 72 | type: MultiStepLR 73 | # milestones: [ 250000, 400000, 450000, 475000 ] 74 | milestones: [ 125000, 200000, 225000, 237500 ] 75 | gamma: 0.5 76 | 77 | # total_iter: 500000 78 | total_iter: 250000 79 | warmup_iter: -1 # no warm up 80 | 81 | # losses 82 | pixel_opt: 83 | type: L1Loss 84 | loss_weight: 1.0 85 | reduction: mean 86 | # validation settings 87 | val: 88 | val_freq: !!float 5e3 89 | save_img: true 90 | 91 | metrics: 92 | psnr: # metric name, can be arbitrary 93 | type: calculate_psnr 94 | crop_border: 4 95 | test_y_channel: true 96 | 97 | 98 | 99 | # logging settings 100 | logger: 101 | print_freq: 200 102 | save_checkpoint_freq: !!float 5e3 103 | use_tb_logger: true 104 | wandb: 105 | project: ~ 106 | resume_id: ~ 107 | 108 | # dist training settings 109 | dist_params: 110 | backend: nccl 111 | port: 29500 -------------------------------------------------------------------------------- /options/train/train_ART_S_SR_x2.yml: -------------------------------------------------------------------------------- 1 | # general settings for image SR training 2 | name: ART_S_SR_x2 3 | model_type: ARTModel 4 | scale: 2 5 | num_gpu: 4 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | task: SR 12 | name: DF2K 13 | type: PairedImageDataset 14 | dataroot_gt: datasets/DF2K/HR 15 | dataroot_lq: datasets/DF2K/LR_bicubic/X2 16 | filename_tmpl: '{}x2' 17 | io_backend: 18 | type: disk 19 | 20 | gt_size: 128 21 | use_hflip: true 22 | use_rot: true 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 8 28 | dataset_enlarge_ratio: 100 29 | prefetch_mode: ~ 30 | 31 | val: 32 | name: Set5 33 | type: PairedImageDataset 34 | dataroot_gt: datasets/SR/Set5/HR 35 | dataroot_lq: datasets/SR/Set5/LR_bicubic/X2 36 | filename_tmpl: '{}x2' 37 | io_backend: 38 | type: disk 39 | 40 | # network structures 41 | network_g: 42 | type: ART 43 | upscale: 2 44 | in_chans: 3 45 | img_size: 64 46 | window_size: 8 47 | img_range: 1. 
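# ART-S (small) variant: relative to ART above, the attention interval below doubles from 4 to 8 (sparser attention) and mlp_ratio halves from 4 to 2, trading some capacity for lower compute.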
48 | depths: [6, 6, 6, 6, 6, 6] 49 | interval: [8, 8, 8, 8, 8, 8] 50 | embed_dim: 180 51 | num_heads: [6, 6, 6, 6, 6, 6] 52 | mlp_ratio: 2 53 | upsampler: 'pixelshuffle' 54 | resi_connection: '1conv' 55 | 56 | # path 57 | path: 58 | pretrain_network_g: ~ 59 | strict_load_g: true 60 | resume_state: ~ 61 | 62 | # training settings 63 | train: 64 | optim_g: 65 | type: Adam 66 | lr: !!float 2e-4 67 | weight_decay: 0 68 | betas: [0.9, 0.99] 69 | 70 | scheduler: 71 | type: MultiStepLR 72 | milestones: [250000, 400000, 450000, 475000] 73 | gamma: 0.5 74 | 75 | total_iter: 500000 76 | warmup_iter: -1 # no warm up 77 | 78 | # losses 79 | pixel_opt: 80 | type: L1Loss 81 | loss_weight: 1.0 82 | reduction: mean 83 | # validation settings 84 | val: 85 | val_freq: !!float 5e3 86 | save_img: true 87 | 88 | metrics: 89 | psnr: # metric name, can be arbitrary 90 | type: calculate_psnr 91 | crop_border: 2 92 | test_y_channel: true 93 | 94 | 95 | 96 | # logging settings 97 | logger: 98 | print_freq: 200 99 | save_checkpoint_freq: !!float 5e3 100 | use_tb_logger: true 101 | wandb: 102 | project: ~ 103 | resume_id: ~ 104 | 105 | # dist training settings 106 | dist_params: 107 | backend: nccl 108 | port: 29500 -------------------------------------------------------------------------------- /options/train/train_ART_S_SR_x3.yml: -------------------------------------------------------------------------------- 1 | # general settings for image SR training 2 | name: ART_S_SR_x3 3 | model_type: ARTModel 4 | scale: 3 5 | num_gpu: 4 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | task: SR 12 | name: DF2K 13 | type: PairedImageDataset 14 | dataroot_gt: datasets/DF2K/HR 15 | dataroot_lq: datasets/DF2K/LR_bicubic/X3 16 | filename_tmpl: '{}x3' 17 | io_backend: 18 | type: disk 19 | 20 | gt_size: 192 21 | use_hflip: true 22 | use_rot: true 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 8 28 | dataset_enlarge_ratio: 100 29 | prefetch_mode: ~ 30 | 31 | val: 32 | name: Set5 33 | type: PairedImageDataset 34 | dataroot_gt: datasets/SR/Set5/HR 35 | dataroot_lq: datasets/SR/Set5/LR_bicubic/X3 36 | filename_tmpl: '{}x3' 37 | io_backend: 38 | type: disk 39 | 40 | # network structures 41 | network_g: 42 | type: ART 43 | upscale: 3 44 | in_chans: 3 45 | img_size: 64 46 | window_size: 8 47 | img_range: 1. 48 | depths: [6, 6, 6, 6, 6, 6] 49 | interval: [8, 8, 8, 8, 8, 8] 50 | embed_dim: 180 51 | num_heads: [6, 6, 6, 6, 6, 6] 52 | mlp_ratio: 2 53 | upsampler: 'pixelshuffle' 54 | resi_connection: '1conv' 55 | 56 | # path 57 | path: 58 | pretrain_network_g: experiments/pretrained_models/SR_ART_S_x2.pth # save training time if we finetune from x2 and halve initial lr. 
59 | strict_load_g: false 60 | resume_state: ~ 61 | 62 | # training settings 63 | train: 64 | optim_g: 65 | type: Adam 66 | # lr: !!float 2e-4 67 | lr: !!float 1e-4 68 | weight_decay: 0 69 | betas: [0.9, 0.99] 70 | 71 | scheduler: 72 | type: MultiStepLR 73 | # milestones: [ 250000, 400000, 450000, 475000 ] 74 | milestones: [ 125000, 200000, 225000, 237500 ] 75 | gamma: 0.5 76 | 77 | # total_iter: 500000 78 | total_iter: 250000 79 | warmup_iter: -1 # no warm up 80 | 81 | # losses 82 | pixel_opt: 83 | type: L1Loss 84 | loss_weight: 1.0 85 | reduction: mean 86 | # validation settings 87 | val: 88 | val_freq: !!float 5e3 89 | save_img: true 90 | 91 | metrics: 92 | psnr: # metric name, can be arbitrary 93 | type: calculate_psnr 94 | crop_border: 3 95 | test_y_channel: true 96 | 97 | 98 | 99 | # logging settings 100 | logger: 101 | print_freq: 200 102 | save_checkpoint_freq: !!float 5e3 103 | use_tb_logger: true 104 | wandb: 105 | project: ~ 106 | resume_id: ~ 107 | 108 | # dist training settings 109 | dist_params: 110 | backend: nccl 111 | port: 29500 -------------------------------------------------------------------------------- /options/train/train_ART_S_SR_x4.yml: -------------------------------------------------------------------------------- 1 | # general settings for image SR training 2 | name: ART_S_SR_x4 3 | model_type: ARTModel 4 | scale: 4 5 | num_gpu: 4 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | task: SR 12 | name: DF2K 13 | type: PairedImageDataset 14 | dataroot_gt: datasets/DF2K/HR 15 | dataroot_lq: datasets/DF2K/LR_bicubic/X4 16 | filename_tmpl: '{}x4' 17 | io_backend: 18 | type: disk 19 | 20 | gt_size: 256 21 | use_hflip: true 22 | use_rot: true 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 8 28 | dataset_enlarge_ratio: 100 29 | prefetch_mode: ~ 30 | 31 | val: 32 | name: Set5 33 | type: PairedImageDataset 34 | dataroot_gt: datasets/SR/Set5/HR 35 | dataroot_lq: datasets/SR/Set5/LR_bicubic/X4 36 | filename_tmpl: '{}x4' 37 | io_backend: 38 | type: disk 39 | 40 | # network structures 41 | network_g: 42 | type: ART 43 | upscale: 4 44 | in_chans: 3 45 | img_size: 64 46 | window_size: 8 47 | img_range: 1. 48 | depths: [6, 6, 6, 6, 6, 6] 49 | interval: [8, 8, 8, 8, 8, 8] 50 | embed_dim: 180 51 | num_heads: [6, 6, 6, 6, 6, 6] 52 | mlp_ratio: 2 53 | upsampler: 'pixelshuffle' 54 | resi_connection: '1conv' 55 | 56 | # path 57 | path: 58 | pretrain_network_g: experiments/pretrained_models/SR_ART_S_x2.pth # save training time if we finetune from x2 and halve initial lr. 
59 | strict_load_g: false 60 | resume_state: ~ 61 | 62 | # training settings 63 | train: 64 | optim_g: 65 | type: Adam 66 | # lr: !!float 2e-4 67 | lr: !!float 1e-4 68 | weight_decay: 0 69 | betas: [0.9, 0.99] 70 | 71 | scheduler: 72 | type: MultiStepLR 73 | # milestones: [ 250000, 400000, 450000, 475000 ] 74 | milestones: [ 125000, 200000, 225000, 237500 ] 75 | gamma: 0.5 76 | 77 | # total_iter: 500000 78 | total_iter: 250000 79 | warmup_iter: -1 # no warm up 80 | 81 | # losses 82 | pixel_opt: 83 | type: L1Loss 84 | loss_weight: 1.0 85 | reduction: mean 86 | # validation settings 87 | val: 88 | val_freq: !!float 5e3 89 | save_img: true 90 | 91 | metrics: 92 | psnr: # metric name, can be arbitrary 93 | type: calculate_psnr 94 | crop_border: 4 95 | test_y_channel: true 96 | 97 | 98 | 99 | # logging settings 100 | logger: 101 | print_freq: 200 102 | save_checkpoint_freq: !!float 5e3 103 | use_tb_logger: true 104 | wandb: 105 | project: ~ 106 | resume_id: ~ 107 | 108 | # dist training settings 109 | dist_params: 110 | backend: nccl 111 | port: 29500 -------------------------------------------------------------------------------- /realDenoising/README.md: -------------------------------------------------------------------------------- 1 | For the real image denoising task, we make two changes compared with the other tasks. First, we design our ART model with a U-Net structure, similar to Restormer. You can find the model file at `basicsr/models/archs/artunet_arch.py`. Second, we train our ART under the same training settings as Restormer, i.e., the same [BasicSR](https://github.com/xinntao/BasicSR) environment (v1.2.0) and the same progressive training strategy. This allows fair comparisons with Restormer. 2 | 3 | Note that this folder is built based on [Restormer](https://github.com/swz30/Restormer). Thanks for their awesome work. -------------------------------------------------------------------------------- /realDenoising/VERSION: -------------------------------------------------------------------------------- 1 | 1.2.0 2 | -------------------------------------------------------------------------------- /realDenoising/basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from functools import partial 7 | from os import path as osp 8 | 9 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 10 | from basicsr.utils import get_root_logger, scandir 11 | from basicsr.utils.dist_util import get_dist_info 12 | 13 | __all__ = ['create_dataset', 'create_dataloader'] 14 | 15 | # automatically scan and import dataset modules 16 | # scan all the files under the data folder with '_dataset' in file names 17 | data_folder = osp.dirname(osp.abspath(__file__)) 18 | dataset_filenames = [ 19 | osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) 20 | if v.endswith('_dataset.py') 21 | ] 22 | # import all the dataset modules 23 | _dataset_modules = [ 24 | importlib.import_module(f'basicsr.data.{file_name}') 25 | for file_name in dataset_filenames 26 | ] 27 | 28 | 29 | def create_dataset(dataset_opt): 30 | """Create dataset. 31 | 32 | Args: 33 | dataset_opt (dict): Configuration for dataset. It contains: 34 | name (str): Dataset name. 35 | type (str): Dataset type.
36 | """ 37 | dataset_type = dataset_opt['type'] 38 | 39 | # dynamic instantiation 40 | for module in _dataset_modules: 41 | dataset_cls = getattr(module, dataset_type, None) 42 | if dataset_cls is not None: 43 | break 44 | if dataset_cls is None: 45 | raise ValueError(f'Dataset {dataset_type} is not found.') 46 | 47 | dataset = dataset_cls(dataset_opt) 48 | 49 | logger = get_root_logger() 50 | logger.info( 51 | f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} ' 52 | 'is created.') 53 | return dataset 54 | 55 | 56 | def create_dataloader(dataset, 57 | dataset_opt, 58 | num_gpu=1, 59 | dist=False, 60 | sampler=None, 61 | seed=None): 62 | """Create dataloader. 63 | 64 | Args: 65 | dataset (torch.utils.data.Dataset): Dataset. 66 | dataset_opt (dict): Dataset options. It contains the following keys: 67 | phase (str): 'train' or 'val'. 68 | num_worker_per_gpu (int): Number of workers for each GPU. 69 | batch_size_per_gpu (int): Training batch size for each GPU. 70 | num_gpu (int): Number of GPUs. Used only in the train phase. 71 | Default: 1. 72 | dist (bool): Whether in distributed training. Used only in the train 73 | phase. Default: False. 74 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 75 | seed (int | None): Seed. Default: None 76 | """ 77 | phase = dataset_opt['phase'] 78 | rank, _ = get_dist_info() 79 | if phase == 'train': 80 | if dist: # distributed training 81 | batch_size = dataset_opt['batch_size_per_gpu'] 82 | num_workers = dataset_opt['num_worker_per_gpu'] 83 | else: # non-distributed training 84 | multiplier = 1 if num_gpu == 0 else num_gpu 85 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 86 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 87 | dataloader_args = dict( 88 | dataset=dataset, 89 | batch_size=batch_size, 90 | shuffle=False, 91 | num_workers=num_workers, 92 | sampler=sampler, 93 | drop_last=True) 94 | if sampler is None: 95 | dataloader_args['shuffle'] = True 96 | dataloader_args['worker_init_fn'] = partial( 97 | worker_init_fn, num_workers=num_workers, rank=rank, 98 | seed=seed) if seed is not None else None 99 | elif phase in ['val', 'test']: # validation 100 | dataloader_args = dict( 101 | dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 102 | else: 103 | raise ValueError(f'Wrong dataset phase: {phase}. 
' 104 | "Supported ones are 'train', 'val' and 'test'.") 105 | 106 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 107 | 108 | prefetch_mode = dataset_opt.get('prefetch_mode') 109 | if prefetch_mode == 'cpu': # CPUPrefetcher 110 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 111 | logger = get_root_logger() 112 | logger.info(f'Use {prefetch_mode} prefetch dataloader: ' 113 | f'num_prefetch_queue = {num_prefetch_queue}') 114 | return PrefetchDataLoader( 115 | num_prefetch_queue=num_prefetch_queue, **dataloader_args) 116 | else: 117 | # prefetch_mode=None: Normal dataloader 118 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 119 | return torch.utils.data.DataLoader(**dataloader_args) 120 | 121 | 122 | def worker_init_fn(worker_id, num_workers, rank, seed): 123 | # Set the worker seed to num_workers * rank + worker_id + seed 124 | worker_seed = num_workers * rank + worker_id + seed 125 | np.random.seed(worker_seed) 126 | random.seed(worker_seed) 127 | -------------------------------------------------------------------------------- /realDenoising/basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restart the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1. 19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil( 27 | len(self.dataset) * ratio / self.num_replicas) 28 | self.total_size = self.num_samples * self.num_replicas 29 | 30 | def __iter__(self): 31 | # deterministically shuffle based on epoch 32 | g = torch.Generator() 33 | g.manual_seed(self.epoch) 34 | indices = torch.randperm(self.total_size, generator=g).tolist() 35 | 36 | dataset_size = len(self.dataset) 37 | indices = [v % dataset_size for v in indices] 38 | 39 | # subsample 40 | indices = indices[self.rank:self.total_size:self.num_replicas] 41 | assert len(indices) == self.num_samples 42 | 43 | return iter(indices) 44 | 45 | def __len__(self): 46 | return self.num_samples 47 | 48 | def set_epoch(self, epoch): 49 | self.epoch = epoch 50 | -------------------------------------------------------------------------------- /realDenoising/basicsr/data/ffhq_dataset.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.transforms import augment 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor 7 | 8 | 9 | class FFHQDataset(data.Dataset): 10 | """FFHQ dataset for StyleGAN. 11 | 12 | Args: 13 | opt (dict): Config for train datasets. It contains the following keys: 14 | dataroot_gt (str): Data root path for gt. 
15 | io_backend (dict): IO backend type and other kwarg. 16 | mean (list | tuple): Image mean. 17 | std (list | tuple): Image std. 18 | use_hflip (bool): Whether to horizontally flip. 19 | 20 | """ 21 | 22 | def __init__(self, opt): 23 | super(FFHQDataset, self).__init__() 24 | self.opt = opt 25 | # file client (io backend) 26 | self.file_client = None 27 | self.io_backend_opt = opt['io_backend'] 28 | 29 | self.gt_folder = opt['dataroot_gt'] 30 | self.mean = opt['mean'] 31 | self.std = opt['std'] 32 | 33 | if self.io_backend_opt['type'] == 'lmdb': 34 | self.io_backend_opt['db_paths'] = self.gt_folder 35 | if not self.gt_folder.endswith('.lmdb'): 36 | raise ValueError("'dataroot_gt' should end with '.lmdb', " 37 | f'but received {self.gt_folder}') 38 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: 39 | self.paths = [line.split('.')[0] for line in fin] 40 | else: 41 | # FFHQ has 70000 images in total 42 | self.paths = [ 43 | osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000) 44 | ] 45 | 46 | def __getitem__(self, index): 47 | if self.file_client is None: 48 | self.file_client = FileClient( 49 | self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | # load gt image 52 | gt_path = self.paths[index] 53 | img_bytes = self.file_client.get(gt_path) 54 | img_gt = imfrombytes(img_bytes, float32=True) 55 | 56 | # random horizontal flip 57 | img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False) 58 | # BGR to RGB, HWC to CHW, numpy to tensor 59 | img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) 60 | # normalize 61 | normalize(img_gt, self.mean, self.std, inplace=True) 62 | return {'gt': img_gt, 'gt_path': gt_path} 63 | 64 | def __len__(self): 65 | return len(self.paths) 66 | -------------------------------------------------------------------------------- /realDenoising/basicsr/data/meta_info/meta_info_REDS4_test_GT.txt: -------------------------------------------------------------------------------- 1 | 000 100 (720,1280,3) 2 | 011 100 (720,1280,3) 3 | 015 100 (720,1280,3) 4 | 020 100 (720,1280,3) 5 | -------------------------------------------------------------------------------- /realDenoising/basicsr/data/meta_info/meta_info_REDSofficial4_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 246 100 (720,1280,3) 4 | 257 100 (720,1280,3) 5 | -------------------------------------------------------------------------------- /realDenoising/basicsr/data/meta_info/meta_info_REDSval_official_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 242 100 (720,1280,3) 4 | 243 100 (720,1280,3) 5 | 244 100 (720,1280,3) 6 | 245 100 (720,1280,3) 7 | 246 100 (720,1280,3) 8 | 247 100 (720,1280,3) 9 | 248 100 (720,1280,3) 10 | 249 100 (720,1280,3) 11 | 250 100 (720,1280,3) 12 | 251 100 (720,1280,3) 13 | 252 100 (720,1280,3) 14 | 253 100 (720,1280,3) 15 | 254 100 (720,1280,3) 16 | 255 100 (720,1280,3) 17 | 256 100 (720,1280,3) 18 | 257 100 (720,1280,3) 19 | 258 100 (720,1280,3) 20 | 259 100 (720,1280,3) 21 | 260 100 (720,1280,3) 22 | 261 100 (720,1280,3) 23 | 262 100 (720,1280,3) 24 | 263 100 (720,1280,3) 25 | 264 100 (720,1280,3) 26 | 265 100 (720,1280,3) 27 | 266 100 (720,1280,3) 28 | 267 100 (720,1280,3) 29 | 268 100 (720,1280,3) 30 | 269 100 (720,1280,3) 31 | -------------------------------------------------------------------------------- 
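Aside: each meta_info line above follows the `clip frame_count (h,w,c)` convention documented in the Vimeo90K dataset docstring further below. A minimal parsing sketch (`parse_meta_info` is an illustrative helper, not part of this repo):

def parse_meta_info(line):
    # '000 100 (720,1280,3)' -> ('000', 100, (720, 1280, 3))
    clip, num_frames, shape = line.split()
    h, w, c = (int(v) for v in shape.strip('()').split(','))
    return clip, int(num_frames), (h, w, c)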
/realDenoising/basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Ref: 11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | 13 | Args: 14 | generator: Python generator. 15 | num_prefetch_queue (int): Size of the prefetch queue. 16 | """ 17 | 18 | def __init__(self, generator, num_prefetch_queue): 19 | threading.Thread.__init__(self) 20 | self.queue = Queue.Queue(num_prefetch_queue) 21 | self.generator = generator 22 | self.daemon = True 23 | self.start() 24 | 25 | def run(self): 26 | for item in self.generator: 27 | self.queue.put(item) 28 | self.queue.put(None) 29 | 30 | def __next__(self): 31 | next_item = self.queue.get() 32 | if next_item is None: 33 | raise StopIteration 34 | return next_item 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | 40 | class PrefetchDataLoader(DataLoader): 41 | """Prefetch version of dataloader. 42 | 43 | Ref: 44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 45 | 46 | TODO: 47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 48 | ddp. 49 | 50 | Args: 51 | num_prefetch_queue (int): Size of the prefetch queue. 52 | kwargs (dict): Other arguments for dataloader. 53 | """ 54 | 55 | def __init__(self, num_prefetch_queue, **kwargs): 56 | self.num_prefetch_queue = num_prefetch_queue 57 | super(PrefetchDataLoader, self).__init__(**kwargs) 58 | 59 | def __iter__(self): 60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 61 | 62 | 63 | class CPUPrefetcher(): 64 | """CPU prefetcher. 65 | 66 | Args: 67 | loader: Dataloader. 68 | """ 69 | 70 | def __init__(self, loader): 71 | self.ori_loader = loader 72 | self.loader = iter(loader) 73 | 74 | def next(self): 75 | try: 76 | return next(self.loader) 77 | except StopIteration: 78 | return None 79 | 80 | def reset(self): 81 | self.loader = iter(self.ori_loader) 82 | 83 | 84 | class CUDAPrefetcher(): 85 | """CUDA prefetcher. 86 | 87 | Ref: 88 | https://github.com/NVIDIA/apex/issues/304# 89 | 90 | It may consume more GPU memory. 91 | 92 | Args: 93 | loader: Dataloader. 94 | opt (dict): Options.
95 | """ 96 | 97 | def __init__(self, loader, opt): 98 | self.ori_loader = loader 99 | self.loader = iter(loader) 100 | self.opt = opt 101 | self.stream = torch.cuda.Stream() 102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 103 | self.preload() 104 | 105 | def preload(self): 106 | try: 107 | self.batch = next(self.loader) # self.batch is a dict 108 | except StopIteration: 109 | self.batch = None 110 | return None 111 | # put tensors to gpu 112 | with torch.cuda.stream(self.stream): 113 | for k, v in self.batch.items(): 114 | if torch.is_tensor(v): 115 | self.batch[k] = self.batch[k].to( 116 | device=self.device, non_blocking=True) 117 | 118 | def next(self): 119 | torch.cuda.current_stream().wait_stream(self.stream) 120 | batch = self.batch 121 | self.preload() 122 | return batch 123 | 124 | def reset(self): 125 | self.loader = iter(self.ori_loader) 126 | self.preload() 127 | -------------------------------------------------------------------------------- /realDenoising/basicsr/data/single_image_dataset.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paths_from_lmdb 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir 7 | 8 | 9 | class SingleImageDataset(data.Dataset): 10 | """Read only lq images in the test phase. 11 | 12 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). 13 | 14 | There are two modes: 15 | 1. 'meta_info_file': Use meta information file to generate paths. 16 | 2. 'folder': Scan folders to generate paths. 17 | 18 | Args: 19 | opt (dict): Config for train datasets. It contains the following keys: 20 | dataroot_lq (str): Data root path for lq. 21 | meta_info_file (str): Path for meta information file. 22 | io_backend (dict): IO backend type and other kwarg. 
23 | """ 24 | 25 | def __init__(self, opt): 26 | super(SingleImageDataset, self).__init__() 27 | self.opt = opt 28 | # file client (io backend) 29 | self.file_client = None 30 | self.io_backend_opt = opt['io_backend'] 31 | self.mean = opt['mean'] if 'mean' in opt else None 32 | self.std = opt['std'] if 'std' in opt else None 33 | self.lq_folder = opt['dataroot_lq'] 34 | 35 | if self.io_backend_opt['type'] == 'lmdb': 36 | self.io_backend_opt['db_paths'] = [self.lq_folder] 37 | self.io_backend_opt['client_keys'] = ['lq'] 38 | self.paths = paths_from_lmdb(self.lq_folder) 39 | elif 'meta_info_file' in self.opt: 40 | with open(self.opt['meta_info_file'], 'r') as fin: 41 | self.paths = [ 42 | osp.join(self.lq_folder, 43 | line.split(' ')[0]) for line in fin 44 | ] 45 | else: 46 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) 47 | 48 | def __getitem__(self, index): 49 | if self.file_client is None: 50 | self.file_client = FileClient( 51 | self.io_backend_opt.pop('type'), **self.io_backend_opt) 52 | 53 | # load lq image 54 | lq_path = self.paths[index] 55 | img_bytes = self.file_client.get(lq_path, 'lq') 56 | img_lq = imfrombytes(img_bytes, float32=True) 57 | 58 | # TODO: color space transform 59 | # BGR to RGB, HWC to CHW, numpy to tensor 60 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) 61 | # normalize 62 | if self.mean is not None or self.std is not None: 63 | normalize(img_lq, self.mean, self.std, inplace=True) 64 | return {'lq': img_lq, 'lq_path': lq_path} 65 | 66 | def __len__(self): 67 | return len(self.paths) 68 | -------------------------------------------------------------------------------- /realDenoising/basicsr/data/vimeo90k_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from pathlib import Path 4 | from torch.utils import data as data 5 | 6 | from basicsr.data.transforms import augment, paired_random_crop 7 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor 8 | 9 | 10 | class Vimeo90KDataset(data.Dataset): 11 | """Vimeo90K dataset for training. 12 | 13 | The keys are generated from a meta info txt file. 14 | basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt 15 | 16 | Each line contains: 17 | 1. clip name; 2. frame number; 3. image shape, seperated by a white space. 18 | Examples: 19 | 00001/0001 7 (256,448,3) 20 | 00001/0002 7 (256,448,3) 21 | 22 | Key examples: "00001/0001" 23 | GT (gt): Ground-Truth; 24 | LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames. 25 | 26 | The neighboring frame list for different num_frame: 27 | num_frame | frame list 28 | 1 | 4 29 | 3 | 3,4,5 30 | 5 | 2,3,4,5,6 31 | 7 | 1,2,3,4,5,6,7 32 | 33 | Args: 34 | opt (dict): Config for train dataset. It contains the following keys: 35 | dataroot_gt (str): Data root path for gt. 36 | dataroot_lq (str): Data root path for lq. 37 | meta_info_file (str): Path for meta information file. 38 | io_backend (dict): IO backend type and other kwarg. 39 | 40 | num_frame (int): Window size for input frames. 41 | gt_size (int): Cropped patched size for gt patches. 42 | random_reverse (bool): Random reverse input frames. 43 | use_flip (bool): Use horizontal flips. 44 | use_rot (bool): Use rotation (use vertical flip and transposing h 45 | and w for implementation). 46 | 47 | scale (bool): Scale, which will be added automatically. 
48 | """ 49 | 50 | def __init__(self, opt): 51 | super(Vimeo90KDataset, self).__init__() 52 | self.opt = opt 53 | self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path( 54 | opt['dataroot_lq']) 55 | 56 | with open(opt['meta_info_file'], 'r') as fin: 57 | self.keys = [line.split(' ')[0] for line in fin] 58 | 59 | # file client (io backend) 60 | self.file_client = None 61 | self.io_backend_opt = opt['io_backend'] 62 | self.is_lmdb = False 63 | if self.io_backend_opt['type'] == 'lmdb': 64 | self.is_lmdb = True 65 | self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root] 66 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 67 | 68 | # indices of input images 69 | self.neighbor_list = [ 70 | i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame']) 71 | ] 72 | 73 | # temporal augmentation configs 74 | self.random_reverse = opt['random_reverse'] 75 | logger = get_root_logger() 76 | logger.info(f'Random reverse is {self.random_reverse}.') 77 | 78 | def __getitem__(self, index): 79 | if self.file_client is None: 80 | self.file_client = FileClient( 81 | self.io_backend_opt.pop('type'), **self.io_backend_opt) 82 | 83 | # random reverse 84 | if self.random_reverse and random.random() < 0.5: 85 | self.neighbor_list.reverse() 86 | 87 | scale = self.opt['scale'] 88 | gt_size = self.opt['gt_size'] 89 | key = self.keys[index] 90 | clip, seq = key.split('/') # key example: 00001/0001 91 | 92 | # get the GT frame (im4.png) 93 | if self.is_lmdb: 94 | img_gt_path = f'{key}/im4' 95 | else: 96 | img_gt_path = self.gt_root / clip / seq / 'im4.png' 97 | img_bytes = self.file_client.get(img_gt_path, 'gt') 98 | img_gt = imfrombytes(img_bytes, float32=True) 99 | 100 | # get the neighboring LQ frames 101 | img_lqs = [] 102 | for neighbor in self.neighbor_list: 103 | if self.is_lmdb: 104 | img_lq_path = f'{clip}/{seq}/im{neighbor}' 105 | else: 106 | img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png' 107 | img_bytes = self.file_client.get(img_lq_path, 'lq') 108 | img_lq = imfrombytes(img_bytes, float32=True) 109 | img_lqs.append(img_lq) 110 | 111 | # randomly crop 112 | img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, 113 | img_gt_path) 114 | 115 | # augmentation - flip, rotate 116 | img_lqs.append(img_gt) 117 | img_results = augment(img_lqs, self.opt['use_flip'], 118 | self.opt['use_rot']) 119 | 120 | img_results = img2tensor(img_results) 121 | img_lqs = torch.stack(img_results[0:-1], dim=0) 122 | img_gt = img_results[-1] 123 | 124 | # img_lqs: (t, c, h, w) 125 | # img_gt: (c, h, w) 126 | # key: str 127 | return {'lq': img_lqs, 'gt': img_gt, 'key': key} 128 | 129 | def __len__(self): 130 | return len(self.keys) 131 | -------------------------------------------------------------------------------- /realDenoising/basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .niqe import calculate_niqe 2 | from .psnr_ssim import calculate_psnr, calculate_ssim 3 | 4 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] 5 | -------------------------------------------------------------------------------- /realDenoising/basicsr/metrics/fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from scipy import linalg 5 | from tqdm import tqdm 6 | 7 | from basicsr.models.archs.inception import InceptionV3 8 | 9 | 10 | def load_patched_inception_v3(device='cuda', 11 | resize_input=True, 12 | 
normalize_input=False): 13 | # we may not resize the input, but in [rosinality/stylegan2-pytorch] it 14 | # does resize the input. 15 | inception = InceptionV3([3], 16 | resize_input=resize_input, 17 | normalize_input=normalize_input) 18 | inception = nn.DataParallel(inception).eval().to(device) 19 | return inception 20 | 21 | 22 | @torch.no_grad() 23 | def extract_inception_features(data_generator, 24 | inception, 25 | len_generator=None, 26 | device='cuda'): 27 | """Extract inception features. 28 | 29 | Args: 30 | data_generator (generator): A data generator. 31 | inception (nn.Module): Inception model. 32 | len_generator (int): Length of the data_generator to show the 33 | progressbar. Default: None. 34 | device (str): Device. Default: cuda. 35 | 36 | Returns: 37 | Tensor: Extracted features. 38 | """ 39 | if len_generator is not None: 40 | pbar = tqdm(total=len_generator, unit='batch', desc='Extract') 41 | else: 42 | pbar = None 43 | features = [] 44 | 45 | for data in data_generator: 46 | if pbar: 47 | pbar.update(1) 48 | data = data.to(device) 49 | feature = inception(data)[0].view(data.shape[0], -1) 50 | features.append(feature.to('cpu')) 51 | if pbar: 52 | pbar.close() 53 | features = torch.cat(features, 0) 54 | return features 55 | 56 | 57 | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): 58 | """Numpy implementation of the Frechet Distance. 59 | 60 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 61 | and X_2 ~ N(mu_2, C_2) is 62 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 63 | Stable version by Dougal J. Sutherland. 64 | 65 | Args: 66 | mu1 (np.array): The sample mean over activations. 67 | sigma1 (np.array): The covariance matrix over activations for 68 | generated samples. 69 | mu2 (np.array): The sample mean over activations, precalculated on a 70 | representative data set. 71 | sigma2 (np.array): The covariance matrix over activations, 72 | precalculated on a representative data set. 73 | 74 | Returns: 75 | float: The Frechet Distance. 76 | """ 77 | assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' 78 | assert sigma1.shape == sigma2.shape, ( 79 | 'Two covariances have different dimensions') 80 | 81 | cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) 82 | 83 | # Product might be almost singular 84 | if not np.isfinite(cov_sqrt).all(): 85 | print(f'Product of cov matrices is singular. Adding {eps} to diagonal ' 86 | 'of cov estimates') 87 | offset = np.eye(sigma1.shape[0]) * eps 88 | cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) 89 | 90 | # Numerical error might give slight imaginary component 91 | if np.iscomplexobj(cov_sqrt): 92 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 93 | m = np.max(np.abs(cov_sqrt.imag)) 94 | raise ValueError(f'Imaginary component {m}') 95 | cov_sqrt = cov_sqrt.real 96 | 97 | mean_diff = mu1 - mu2 98 | mean_norm = mean_diff @ mean_diff 99 | trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt) 100 | fid = mean_norm + trace 101 | 102 | return fid 103 | -------------------------------------------------------------------------------- /realDenoising/basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from basicsr.utils.matlab_functions import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img, input_order='HWC'): 7 | """Reorder images to 'HWC' order.
8 | 9 | If the input_order is (h, w), return (h, w, 1); 10 | If the input_order is (c, h, w), return (h, w, c); 11 | If the input_order is (h, w, c), return as it is. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 16 | If the input image shape is (h, w), input_order has no 17 | effect. Default: 'HWC'. 18 | 19 | Returns: 20 | ndarray: Reordered image. 21 | """ 22 | 23 | if input_order not in ['HWC', 'CHW']: 24 | raise ValueError( 25 | f'Wrong input_order {input_order}. Supported input_orders are ' 26 | "'HWC' and 'CHW'") 27 | if len(img.shape) == 2: 28 | img = img[..., None] 29 | if input_order == 'CHW': 30 | img = img.transpose(1, 2, 0) 31 | return img 32 | 33 | 34 | def to_y_channel(img): 35 | """Change to Y channel of YCbCr. 36 | 37 | Args: 38 | img (ndarray): Images with range [0, 255]. 39 | 40 | Returns: 41 | (ndarray): Images with range [0, 255] (float type) without round. 42 | """ 43 | img = img.astype(np.float32) / 255. 44 | if img.ndim == 3 and img.shape[2] == 3: 45 | img = bgr2ycbcr(img, y_only=True) 46 | img = img[..., None] 47 | return img * 255. 48 | -------------------------------------------------------------------------------- /realDenoising/basicsr/metrics/niqe_pris_params.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/realDenoising/basicsr/metrics/niqe_pris_params.npz -------------------------------------------------------------------------------- /realDenoising/basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import get_root_logger, scandir 5 | 6 | # automatically scan and import model modules 7 | # scan all the files under the 'models' folder and collect files ending with 8 | # '_model.py' 9 | model_folder = osp.dirname(osp.abspath(__file__)) 10 | model_filenames = [ 11 | osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) 12 | if v.endswith('_model.py') 13 | ] 14 | # import all the model modules 15 | _model_modules = [ 16 | importlib.import_module(f'basicsr.models.{file_name}') 17 | for file_name in model_filenames 18 | ] 19 | 20 | 21 | def create_model(opt): 22 | """Create model. 23 | 24 | Args: 25 | opt (dict): Configuration. It contains: 26 | model_type (str): Model type.
27 | """ 28 | model_type = opt['model_type'] 29 | 30 | # dynamic instantiation 31 | for module in _model_modules: 32 | model_cls = getattr(module, model_type, None) 33 | if model_cls is not None: 34 | break 35 | if model_cls is None: 36 | raise ValueError(f'Model {model_type} is not found.') 37 | 38 | model = model_cls(opt) 39 | 40 | logger = get_root_logger() 41 | logger.info(f'Model [{model.__class__.__name__}] is created.') 42 | return model 43 | -------------------------------------------------------------------------------- /realDenoising/basicsr/models/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import scandir 5 | 6 | # automatically scan and import arch modules 7 | # scan all the files under the 'archs' folder and collect files ending with 8 | # '_arch.py' 9 | arch_folder = osp.dirname(osp.abspath(__file__)) 10 | arch_filenames = [ 11 | osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) 12 | if v.endswith('_arch.py') 13 | ] 14 | # import all the arch modules 15 | _arch_modules = [ 16 | importlib.import_module(f'basicsr.models.archs.{file_name}') 17 | for file_name in arch_filenames 18 | ] 19 | 20 | 21 | def dynamic_instantiation(modules, cls_type, opt): 22 | """Dynamically instantiate class. 23 | 24 | Args: 25 | modules (list[importlib modules]): List of modules from importlib 26 | files. 27 | cls_type (str): Class type. 28 | opt (dict): Class initialization kwargs. 29 | 30 | Returns: 31 | class: Instantiated class. 32 | """ 33 | 34 | for module in modules: 35 | cls_ = getattr(module, cls_type, None) 36 | if cls_ is not None: 37 | break 38 | if cls_ is None: 39 | raise ValueError(f'{cls_type} is not found.') 40 | return cls_(**opt) 41 | 42 | 43 | def define_network(opt): 44 | network_type = opt.pop('type') 45 | net = dynamic_instantiation(_arch_modules, network_type, opt) 46 | return net 47 | -------------------------------------------------------------------------------- /realDenoising/basicsr/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import (L1Loss, MSELoss, PSNRLoss, CharbonnierLoss) 2 | 3 | __all__ = [ 4 | 'L1Loss', 'MSELoss', 'PSNRLoss', 'CharbonnierLoss', 5 | ] 6 | -------------------------------------------------------------------------------- /realDenoising/basicsr/models/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from torch.nn import functional as F 3 | 4 | 5 | def reduce_loss(loss, reduction): 6 | """Reduce loss as specified. 7 | 8 | Args: 9 | loss (Tensor): Elementwise loss tensor. 10 | reduction (str): Options are 'none', 'mean' and 'sum'. 11 | 12 | Returns: 13 | Tensor: Reduced loss tensor. 14 | """ 15 | reduction_enum = F._Reduction.get_enum(reduction) 16 | # none: 0, elementwise_mean:1, sum: 2 17 | if reduction_enum == 0: 18 | return loss 19 | elif reduction_enum == 1: 20 | return loss.mean() 21 | else: 22 | return loss.sum() 23 | 24 | 25 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 26 | """Apply element-wise weight and reduce loss. 27 | 28 | Args: 29 | loss (Tensor): Element-wise loss. 30 | weight (Tensor): Element-wise weights. Default: None. 31 | reduction (str): Same as built-in losses of PyTorch. Options are 32 | 'none', 'mean' and 'sum'. Default: 'mean'. 33 | 34 | Returns: 35 | Tensor: Loss values. 
36 | """ 37 | # if weight is specified, apply element-wise weight 38 | if weight is not None: 39 | assert weight.dim() == loss.dim() 40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 41 | loss = loss * weight 42 | 43 | # if weight is not specified or reduction is sum, just reduce the loss 44 | if weight is None or reduction == 'sum': 45 | loss = reduce_loss(loss, reduction) 46 | # if reduction is mean, then compute mean over weight region 47 | elif reduction == 'mean': 48 | if weight.size(1) > 1: 49 | weight = weight.sum() 50 | else: 51 | weight = weight.sum() * loss.size(1) 52 | loss = loss.sum() / weight 53 | 54 | return loss 55 | 56 | 57 | def weighted_loss(loss_func): 58 | """Create a weighted version of a given loss function. 59 | 60 | To use this decorator, the loss function must have the signature like 61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 62 | element-wise loss without any reduction. This decorator will add weight 63 | and reduction arguments to the function. The decorated function will have 64 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 65 | **kwargs)`. 66 | 67 | :Example: 68 | 69 | >>> import torch 70 | >>> @weighted_loss 71 | >>> def l1_loss(pred, target): 72 | >>> return (pred - target).abs() 73 | 74 | >>> pred = torch.Tensor([0, 2, 3]) 75 | >>> target = torch.Tensor([1, 1, 1]) 76 | >>> weight = torch.Tensor([1, 0, 1]) 77 | 78 | >>> l1_loss(pred, target) 79 | tensor(1.3333) 80 | >>> l1_loss(pred, target, weight) 81 | tensor(1.5000) 82 | >>> l1_loss(pred, target, reduction='none') 83 | tensor([1., 1., 2.]) 84 | >>> l1_loss(pred, target, weight, reduction='sum') 85 | tensor(3.) 86 | """ 87 | 88 | @functools.wraps(loss_func) 89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 90 | # get element-wise loss 91 | loss = loss_func(pred, target, **kwargs) 92 | loss = weight_reduce_loss(loss, weight, reduction) 93 | return loss 94 | 95 | return wrapper 96 | -------------------------------------------------------------------------------- /realDenoising/basicsr/models/losses/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | from basicsr.models.losses.loss_util import weighted_loss 7 | 8 | _reduction_modes = ['none', 'mean', 'sum'] 9 | 10 | 11 | @weighted_loss 12 | def l1_loss(pred, target): 13 | return F.l1_loss(pred, target, reduction='none') 14 | 15 | 16 | @weighted_loss 17 | def mse_loss(pred, target): 18 | return F.mse_loss(pred, target, reduction='none') 19 | 20 | 21 | # @weighted_loss 22 | # def charbonnier_loss(pred, target, eps=1e-12): 23 | # return torch.sqrt((pred - target)**2 + eps) 24 | 25 | 26 | class L1Loss(nn.Module): 27 | """L1 (mean absolute error, MAE) loss. 28 | 29 | Args: 30 | loss_weight (float): Loss weight for L1 loss. Default: 1.0. 31 | reduction (str): Specifies the reduction to apply to the output. 32 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 33 | """ 34 | 35 | def __init__(self, loss_weight=1.0, reduction='mean'): 36 | super(L1Loss, self).__init__() 37 | if reduction not in ['none', 'mean', 'sum']: 38 | raise ValueError(f'Unsupported reduction mode: {reduction}. 
' 39 | f'Supported ones are: {_reduction_modes}') 40 | 41 | self.loss_weight = loss_weight 42 | self.reduction = reduction 43 | 44 | def forward(self, pred, target, weight=None, **kwargs): 45 | """ 46 | Args: 47 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 48 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 49 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 50 | weights. Default: None. 51 | """ 52 | return self.loss_weight * l1_loss( 53 | pred, target, weight, reduction=self.reduction) 54 | 55 | class MSELoss(nn.Module): 56 | """MSE (L2) loss. 57 | 58 | Args: 59 | loss_weight (float): Loss weight for MSE loss. Default: 1.0. 60 | reduction (str): Specifies the reduction to apply to the output. 61 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 62 | """ 63 | 64 | def __init__(self, loss_weight=1.0, reduction='mean'): 65 | super(MSELoss, self).__init__() 66 | if reduction not in ['none', 'mean', 'sum']: 67 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' 68 | f'Supported ones are: {_reduction_modes}') 69 | 70 | self.loss_weight = loss_weight 71 | self.reduction = reduction 72 | 73 | def forward(self, pred, target, weight=None, **kwargs): 74 | """ 75 | Args: 76 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 77 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 78 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 79 | weights. Default: None. 80 | """ 81 | return self.loss_weight * mse_loss( 82 | pred, target, weight, reduction=self.reduction) 83 | 84 | class PSNRLoss(nn.Module): 85 | 86 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False): 87 | super(PSNRLoss, self).__init__() 88 | assert reduction == 'mean' 89 | self.loss_weight = loss_weight 90 | self.scale = 10 / np.log(10) 91 | self.toY = toY 92 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1) 93 | self.first = True 94 | 95 | def forward(self, pred, target): 96 | assert len(pred.size()) == 4 97 | if self.toY: 98 | if self.first: 99 | self.coef = self.coef.to(pred.device) 100 | self.first = False 101 | 102 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 103 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 104 | 105 | pred, target = pred / 255., target / 255. 
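# Note: self.scale = 10 / ln(10), so the value returned below equals
# loss_weight * 10 * log10(mean squared error + 1e-8), i.e. the negative of
# the PSNR for signals in [0, 1]; minimizing this loss maximizes PSNR.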
106 | 107 | assert len(pred.size()) == 4 108 | 109 | return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean() 110 | 111 | class CharbonnierLoss(nn.Module): 112 | """Charbonnier Loss (L1).""" 113 | 114 | def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-3): 115 | super(CharbonnierLoss, self).__init__() 116 | self.loss_weight = loss_weight 117 | self.eps = eps 118 | 119 | def forward(self, x, y): 120 | diff = x - y 121 | # loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 122 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps * self.eps))) 123 | return self.loss_weight * loss 124 | -------------------------------------------------------------------------------- /realDenoising/basicsr/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from os import path as osp 4 | 5 | from basicsr.data import create_dataloader, create_dataset 6 | from basicsr.models import create_model 7 | from basicsr.train import parse_options 8 | from basicsr.utils import (get_env_info, get_root_logger, get_time_str, 9 | make_exp_dirs) 10 | from basicsr.utils.options import dict2str 11 | 12 | 13 | def main(): 14 | # parse options, set distributed setting, set random seed 15 | opt = parse_options(is_train=False) 16 | 17 | torch.backends.cudnn.benchmark = True 18 | # torch.backends.cudnn.deterministic = True 19 | 20 | # mkdir and initialize loggers 21 | make_exp_dirs(opt) 22 | log_file = osp.join(opt['path']['log'], 23 | f"test_{opt['name']}_{get_time_str()}.log") 24 | logger = get_root_logger( 25 | logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 26 | logger.info(get_env_info()) 27 | logger.info(dict2str(opt)) 28 | 29 | # create test dataset and dataloader 30 | test_loaders = [] 31 | for phase, dataset_opt in sorted(opt['datasets'].items()): 32 | test_set = create_dataset(dataset_opt) 33 | test_loader = create_dataloader( 34 | test_set, 35 | dataset_opt, 36 | num_gpu=opt['num_gpu'], 37 | dist=opt['dist'], 38 | sampler=None, 39 | seed=opt['manual_seed']) 40 | logger.info( 41 | f"Number of test images in {dataset_opt['name']}: {len(test_set)}") 42 | test_loaders.append(test_loader) 43 | 44 | # create model 45 | model = create_model(opt) 46 | 47 | for test_loader in test_loaders: 48 | test_set_name = test_loader.dataset.opt['name'] 49 | logger.info(f'Testing {test_set_name}...') 50 | rgb2bgr = opt['val'].get('rgb2bgr', True) 51 | # whether to use uint8 images to compute metrics 52 | use_image = opt['val'].get('use_image', True) 53 | model.validation( 54 | test_loader, 55 | current_iter=opt['name'], 56 | tb_logger=None, 57 | save_img=opt['val']['save_img'], 58 | rgb2bgr=rgb2bgr, use_image=use_image) 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | -------------------------------------------------------------------------------- /realDenoising/basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .file_client import FileClient 2 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding, padding_DP, imfrombytesDP 3 | from .logger import (MessageLogger, get_env_info, get_root_logger, 4 | init_tb_logger, init_wandb_logger) 5 | from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, 6 | scandir, scandir_SIDD, set_random_seed, sizeof_fmt) 7 | from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k) 8 | 9 | __all__ = [ 10 | # file_client.py 11 | 'FileClient', 12 | # img_util.py
13 | 'img2tensor', 14 | 'tensor2img', 15 | 'imfrombytes', 16 | 'imwrite', 17 | 'crop_border', 18 | # logger.py 19 | 'MessageLogger', 20 | 'init_tb_logger', 21 | 'init_wandb_logger', 22 | 'get_root_logger', 23 | 'get_env_info', 24 | # misc.py 25 | 'set_random_seed', 26 | 'get_time_str', 27 | 'mkdir_and_rename', 28 | 'make_exp_dirs', 29 | 'scandir', 30 | 'check_resume', 31 | 'sizeof_fmt', 32 | 'padding', 33 | 'padding_DP', 34 | 'imfrombytesDP', 35 | 'create_lmdb_for_reds', 36 | 'create_lmdb_for_gopro', 37 | 'create_lmdb_for_rain13k', 38 | ] 39 | -------------------------------------------------------------------------------- /realDenoising/basicsr/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gladzhang/ART/3ee01f35294405748a06ab24cf4abc26b41e6746/realDenoising/basicsr/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /realDenoising/basicsr/utils/bundle_submissions.py: -------------------------------------------------------------------------------- 1 | # Author: Tobias Plötz, TU Darmstadt (tobias.ploetz@visinf.tu-darmstadt.de) 2 | 3 | # This file is part of the implementation as described in the CVPR 2017 paper: 4 | # Tobias Plötz and Stefan Roth, Benchmarking Denoising Algorithms with Real Photographs. 5 | # Please see the file LICENSE.txt for the license governing this code. 6 | 7 | 8 | import numpy as np 9 | import scipy.io as sio 10 | import os 11 | import h5py 12 | 13 | def bundle_submissions_raw(submission_folder, session): 14 | ''' 15 | Bundles submission data for raw denoising 16 | 17 | submission_folder Folder where denoised images reside 18 | 19 | Output is written to <submission_folder>/<session>. Please submit 20 | the content of this folder. 21 | ''' 22 | 23 | out_folder = os.path.join(submission_folder, session) 24 | # out_folder = os.path.join(submission_folder, "bundled/") 25 | try: 26 | os.mkdir(out_folder) 27 | except OSError: pass 28 | 29 | israw = True 30 | eval_version = "1.0" 31 | 32 | for i in range(50): 33 | Idenoised = np.zeros((20,), dtype=object)  # plain object: np.object was removed in NumPy 1.24 34 | for bb in range(20): 35 | filename = '%04d_%02d.mat' % (i+1, bb+1) 36 | s = sio.loadmat(os.path.join(submission_folder, filename)) 37 | Idenoised_crop = s["Idenoised_crop"] 38 | Idenoised[bb] = Idenoised_crop 39 | filename = '%04d.mat' % (i+1) 40 | sio.savemat(os.path.join(out_folder, filename), 41 | {"Idenoised": Idenoised, 42 | "israw": israw, 43 | "eval_version": eval_version}, 44 | ) 45 | 46 | def bundle_submissions_srgb(submission_folder, session): 47 | ''' 48 | Bundles submission data for sRGB denoising 49 | 50 | submission_folder Folder where denoised images reside 51 | 52 | Output is written to <submission_folder>/<session>. Please submit 53 | the content of this folder.
54 | ''' 55 | out_folder = os.path.join(submission_folder, session) 56 | # out_folder = os.path.join(submission_folder, "bundled/") 57 | try: 58 | os.mkdir(out_folder) 59 | except OSError: pass 60 | israw = False 61 | eval_version = "1.0" 62 | 63 | for i in range(50): 64 | Idenoised = np.zeros((20,), dtype=object) 65 | for bb in range(20): 66 | filename = '%04d_%02d.mat' % (i+1, bb+1) 67 | s = sio.loadmat(os.path.join(submission_folder, filename)) 68 | Idenoised_crop = s["Idenoised_crop"] 69 | Idenoised[bb] = Idenoised_crop 70 | filename = '%04d.mat' % (i+1) 71 | sio.savemat(os.path.join(out_folder, filename), 72 | {"Idenoised": Idenoised, 73 | "israw": israw, 74 | "eval_version": eval_version}, 75 | ) 76 | 77 | 78 | 79 | def bundle_submissions_srgb_v1(submission_folder, session): 80 | ''' 81 | Bundles submission data for sRGB denoising 82 | 83 | submission_folder Folder where denoised images reside 84 | 85 | Output is written to <submission_folder>/<session>. Please submit 86 | the content of this folder. 87 | ''' 88 | out_folder = os.path.join(submission_folder, session) 89 | # out_folder = os.path.join(submission_folder, "bundled/") 90 | try: 91 | os.mkdir(out_folder) 92 | except OSError: pass 93 | israw = False 94 | eval_version = "1.0" 95 | 96 | for i in range(50): 97 | Idenoised = np.zeros((20,), dtype=object) 98 | for bb in range(20): 99 | filename = '%04d_%d.mat' % (i+1, bb+1) 100 | s = sio.loadmat(os.path.join(submission_folder, filename)) 101 | Idenoised_crop = s["Idenoised_crop"] 102 | Idenoised[bb] = Idenoised_crop 103 | filename = '%04d.mat' % (i+1) 104 | sio.savemat(os.path.join(out_folder, filename), 105 | {"Idenoised": Idenoised, 106 | "israw": israw, 107 | "eval_version": eval_version}, 108 | ) -------------------------------------------------------------------------------- /realDenoising/basicsr/utils/create_lmdb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os import path as osp 3 | import os, cv2, scipy.io as scio  # required by create_lmdb_for_SIDD 4 | from basicsr.utils import scandir 5 | from basicsr.utils.lmdb_util import make_lmdb_from_imgs 6 | from tqdm import tqdm  # required by create_lmdb_for_SIDD 7 | def prepare_keys(folder_path, suffix='png'): 8 | """Prepare image path list and keys for a dataset folder. 9 | 10 | Args: 11 | folder_path (str): Folder path. 12 | 13 | Returns: 14 | list[str]: Image path list. 15 | list[str]: Key list.
16 | """ 17 | print('Reading image path list ...') 18 | img_path_list = sorted( 19 | list(scandir(folder_path, suffix=suffix, recursive=False))) 20 | keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)] 21 | 22 | return img_path_list, keys 23 | 24 | def create_lmdb_for_reds(): 25 | folder_path = './datasets/REDS/val/sharp_300' 26 | lmdb_path = './datasets/REDS/val/sharp_300.lmdb' 27 | img_path_list, keys = prepare_keys(folder_path, 'png') 28 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 29 | # 30 | folder_path = './datasets/REDS/val/blur_300' 31 | lmdb_path = './datasets/REDS/val/blur_300.lmdb' 32 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 33 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 34 | 35 | folder_path = './datasets/REDS/train/train_sharp' 36 | lmdb_path = './datasets/REDS/train/train_sharp.lmdb' 37 | img_path_list, keys = prepare_keys(folder_path, 'png') 38 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 39 | 40 | folder_path = './datasets/REDS/train/train_blur_jpeg' 41 | lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb' 42 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 43 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 44 | 45 | 46 | def create_lmdb_for_gopro(): 47 | folder_path = './datasets/GoPro/train/blur_crops' 48 | lmdb_path = './datasets/GoPro/train/blur_crops.lmdb' 49 | 50 | img_path_list, keys = prepare_keys(folder_path, 'png') 51 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 52 | 53 | folder_path = './datasets/GoPro/train/sharp_crops' 54 | lmdb_path = './datasets/GoPro/train/sharp_crops.lmdb' 55 | 56 | img_path_list, keys = prepare_keys(folder_path, 'png') 57 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 58 | 59 | folder_path = './datasets/GoPro/test/target' 60 | lmdb_path = './datasets/GoPro/test/target.lmdb' 61 | 62 | img_path_list, keys = prepare_keys(folder_path, 'png') 63 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 64 | 65 | folder_path = './datasets/GoPro/test/input' 66 | lmdb_path = './datasets/GoPro/test/input.lmdb' 67 | 68 | img_path_list, keys = prepare_keys(folder_path, 'png') 69 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 70 | 71 | def create_lmdb_for_rain13k(): 72 | folder_path = './datasets/Rain13k/train/input' 73 | lmdb_path = './datasets/Rain13k/train/input.lmdb' 74 | 75 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 76 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 77 | 78 | folder_path = './datasets/Rain13k/train/target' 79 | lmdb_path = './datasets/Rain13k/train/target.lmdb' 80 | 81 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 82 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 83 | 84 | def create_lmdb_for_SIDD(): 85 | folder_path = './datasets/SIDD/train/input_crops' 86 | lmdb_path = './datasets/SIDD/train/input_crops.lmdb' 87 | 88 | img_path_list, keys = prepare_keys(folder_path, 'PNG') 89 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 90 | 91 | folder_path = './datasets/SIDD/train/gt_crops' 92 | lmdb_path = './datasets/SIDD/train/gt_crops.lmdb' 93 | 94 | img_path_list, keys = prepare_keys(folder_path, 'PNG') 95 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 96 | 97 | #for val 98 | folder_path = './datasets/SIDD/val/input_crops' 99 | lmdb_path = './datasets/SIDD/val/input_crops.lmdb' 100 | mat_path = 
'./datasets/SIDD/ValidationNoisyBlocksSrgb.mat' 101 | if not osp.exists(folder_path): 102 | os.makedirs(folder_path) 103 | assert osp.exists(mat_path) 104 | data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb'] 105 | N, B, H ,W, C = data.shape 106 | data = data.reshape(N*B, H, W, C) 107 | for i in tqdm(range(N*B)): 108 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 109 | img_path_list, keys = prepare_keys(folder_path, 'png') 110 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 111 | 112 | folder_path = './datasets/SIDD/val/gt_crops' 113 | lmdb_path = './datasets/SIDD/val/gt_crops.lmdb' 114 | mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat' 115 | if not osp.exists(folder_path): 116 | os.makedirs(folder_path) 117 | assert osp.exists(mat_path) 118 | data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb'] 119 | N, B, H ,W, C = data.shape 120 | data = data.reshape(N*B, H, W, C) 121 | for i in tqdm(range(N*B)): 122 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 123 | img_path_list, keys = prepare_keys(folder_path, 'png') 124 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 125 | -------------------------------------------------------------------------------- /realDenoising/basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 
38 |     """
39 |     proc_id = int(os.environ['SLURM_PROCID'])
40 |     ntasks = int(os.environ['SLURM_NTASKS'])
41 |     node_list = os.environ['SLURM_NODELIST']
42 |     num_gpus = torch.cuda.device_count()
43 |     torch.cuda.set_device(proc_id % num_gpus)
44 |     addr = subprocess.getoutput(
45 |         f'scontrol show hostname {node_list} | head -n1')
46 |     # specify master port
47 |     if port is not None:
48 |         os.environ['MASTER_PORT'] = str(port)
49 |     elif 'MASTER_PORT' in os.environ:
50 |         pass  # use MASTER_PORT in the environment variable
51 |     else:
52 |         # 29500 is torch.distributed default port
53 |         os.environ['MASTER_PORT'] = '29500'
54 |     os.environ['MASTER_ADDR'] = addr
55 |     os.environ['WORLD_SIZE'] = str(ntasks)
56 |     os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
57 |     os.environ['RANK'] = str(proc_id)
58 |     dist.init_process_group(backend=backend)
59 |
60 |
61 | def get_dist_info():
62 |     if dist.is_available():
63 |         initialized = dist.is_initialized()
64 |     else:
65 |         initialized = False
66 |     if initialized:
67 |         rank = dist.get_rank()
68 |         world_size = dist.get_world_size()
69 |     else:
70 |         rank = 0
71 |         world_size = 1
72 |     return rank, world_size
73 |
74 |
75 | def master_only(func):
76 |
77 |     @functools.wraps(func)
78 |     def wrapper(*args, **kwargs):
79 |         rank, _ = get_dist_info()
80 |         if rank == 0:
81 |             return func(*args, **kwargs)
82 |
83 |     return wrapper
84 |
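A short sketch of how these helpers are typically combined at the top of a distributed training entry point (the launcher choice and the printed message are illustrative, not from this repo):

    from basicsr.utils.dist_util import init_dist, get_dist_info, master_only

    init_dist('pytorch')  # expects RANK etc. in the environment; use 'slurm' under Slurm
    rank, world_size = get_dist_info()

    @master_only
    def report():
        print(f'initialized {world_size} ranks')  # runs on rank 0 only

    report()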
-------------------------------------------------------------------------------- /realDenoising/basicsr/utils/download_util.py: --------------------------------------------------------------------------------
1 | import math
2 | import requests
3 | from tqdm import tqdm
4 |
5 | from .misc import sizeof_fmt
6 |
7 |
8 | def download_file_from_google_drive(file_id, save_path):
9 |     """Download files from google drive.
10 |
11 |     Ref:
12 |     https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
13 |
14 |     Args:
15 |         file_id (str): File id.
16 |         save_path (str): Save path.
17 |     """
18 |
19 |     session = requests.Session()
20 |     URL = 'https://docs.google.com/uc?export=download'
21 |     params = {'id': file_id}
22 |
23 |     response = session.get(URL, params=params, stream=True)
24 |     token = get_confirm_token(response)
25 |     if token:
26 |         params['confirm'] = token
27 |         response = session.get(URL, params=params, stream=True)
28 |
29 |     # get file size
30 |     response_file_size = session.get(
31 |         URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
32 |     if 'Content-Range' in response_file_size.headers:
33 |         file_size = int(
34 |             response_file_size.headers['Content-Range'].split('/')[1])
35 |     else:
36 |         file_size = None
37 |
38 |     save_response_content(response, save_path, file_size)
39 |
40 |
41 | def get_confirm_token(response):
42 |     for key, value in response.cookies.items():
43 |         if key.startswith('download_warning'):
44 |             return value
45 |     return None
46 |
47 |
48 | def save_response_content(response,
49 |                           destination,
50 |                           file_size=None,
51 |                           chunk_size=32768):
52 |     if file_size is not None:
53 |         pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
54 |
55 |         readable_file_size = sizeof_fmt(file_size)
56 |     else:
57 |         pbar = None
58 |
59 |     with open(destination, 'wb') as f:
60 |         downloaded_size = 0
61 |         for chunk in response.iter_content(chunk_size):
62 |             downloaded_size += len(chunk)  # count actual bytes; the last chunk may be short
63 |             if pbar is not None:
64 |                 pbar.update(1)
65 |                 pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
66 |                                      f'/ {readable_file_size}')
67 |             if chunk:  # filter out keep-alive new chunks
68 |                 f.write(chunk)
69 |         if pbar is not None:
70 |             pbar.close()
71 |
-------------------------------------------------------------------------------- /realDenoising/basicsr/utils/options.py: --------------------------------------------------------------------------------
1 | import yaml
2 | from collections import OrderedDict
3 | from os import path as osp
4 |
5 |
6 | def ordered_yaml():
7 |     """Support OrderedDict for yaml.
8 |
9 |     Returns:
10 |         yaml Loader and Dumper.
11 |     """
12 |     try:
13 |         from yaml import CDumper as Dumper
14 |         from yaml import CLoader as Loader
15 |     except ImportError:
16 |         from yaml import Dumper, Loader
17 |
18 |     _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
19 |
20 |     def dict_representer(dumper, data):
21 |         return dumper.represent_dict(data.items())
22 |
23 |     def dict_constructor(loader, node):
24 |         return OrderedDict(loader.construct_pairs(node))
25 |
26 |     Dumper.add_representer(OrderedDict, dict_representer)
27 |     Loader.add_constructor(_mapping_tag, dict_constructor)
28 |     return Loader, Dumper
29 |
30 |
31 | def parse(opt_path, is_train=True):
32 |     """Parse option file.
33 |
34 |     Args:
35 |         opt_path (str): Option file path.
36 |         is_train (bool): Indicate whether in training or not. Default: True.
37 |
38 |     Returns:
39 |         (dict): Options.
40 | """ 41 | with open(opt_path, mode='r') as f: 42 | Loader, _ = ordered_yaml() 43 | opt = yaml.load(f, Loader=Loader) 44 | 45 | opt['is_train'] = is_train 46 | 47 | # datasets 48 | for phase, dataset in opt['datasets'].items(): 49 | # for several datasets, e.g., test_1, test_2 50 | phase = phase.split('_')[0] 51 | dataset['phase'] = phase 52 | if 'scale' in opt: 53 | dataset['scale'] = opt['scale'] 54 | if dataset.get('dataroot_gt') is not None: 55 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 56 | if dataset.get('dataroot_lq') is not None: 57 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 58 | 59 | # paths 60 | for key, val in opt['path'].items(): 61 | if (val is not None) and ('resume_state' in key 62 | or 'pretrain_network' in key): 63 | opt['path'][key] = osp.expanduser(val) 64 | opt['path']['root'] = osp.abspath( 65 | osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 66 | if is_train: 67 | experiments_root = osp.join(opt['path']['root'], 'experiments', 68 | opt['name']) 69 | opt['path']['experiments_root'] = experiments_root 70 | opt['path']['models'] = osp.join(experiments_root, 'models') 71 | opt['path']['training_states'] = osp.join(experiments_root, 72 | 'training_states') 73 | opt['path']['log'] = experiments_root 74 | opt['path']['visualization'] = osp.join(experiments_root, 75 | 'visualization') 76 | 77 | # change some options for debug mode 78 | if 'debug' in opt['name']: 79 | if 'val' in opt: 80 | opt['val']['val_freq'] = 8 81 | opt['logger']['print_freq'] = 1 82 | opt['logger']['save_checkpoint_freq'] = 8 83 | else: # test 84 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 85 | opt['path']['results_root'] = results_root 86 | opt['path']['log'] = results_root 87 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 88 | 89 | return opt 90 | 91 | 92 | def dict2str(opt, indent_level=1): 93 | """dict to string for printing options. 94 | 95 | Args: 96 | opt (dict): Option dict. 97 | indent_level (int): Indent level. Default: 1. 98 | 99 | Return: 100 | (str): Option string for printing. 
101 | """ 102 | msg = '\n' 103 | for k, v in opt.items(): 104 | if isinstance(v, dict): 105 | msg += ' ' * (indent_level * 2) + k + ':[' 106 | msg += dict2str(v, indent_level + 1) 107 | msg += ' ' * (indent_level * 2) + ']\n' 108 | else: 109 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 110 | return msg 111 | -------------------------------------------------------------------------------- /realDenoising/basicsr/version.py: -------------------------------------------------------------------------------- 1 | # GENERATED VERSION FILE 2 | # TIME: Thu Nov 10 10:11:46 2022 3 | __version__ = '1.2.0+10018c6' 4 | short_version = '1.2.0' 5 | version_info = (1, 2, 0) 6 | -------------------------------------------------------------------------------- /realDenoising/evaluate_sidd.m: -------------------------------------------------------------------------------- 1 | close all;clear all; 2 | 3 | denoised = load('./results/Real_Denoising/SIDD/mat/Idenoised.mat'); 4 | gt = load('../datasets/RealDN/SIDD/ValidationGtBlocksSrgb.mat'); 5 | 6 | denoised = denoised.Idenoised; 7 | gt = gt.ValidationGtBlocksSrgb; 8 | gt = im2single(gt); 9 | 10 | total_psnr = 0; 11 | total_ssim = 0; 12 | for i = 1:40 13 | for k = 1:32 14 | denoised_patch = squeeze(denoised(i,k,:,:,:)); 15 | gt_patch = squeeze(gt(i,k,:,:,:)); 16 | ssim_val = ssim(denoised_patch, gt_patch); 17 | psnr_val = psnr(denoised_patch, gt_patch); 18 | total_ssim = total_ssim + ssim_val; 19 | total_psnr = total_psnr + psnr_val; 20 | end 21 | end 22 | qm_psnr = total_psnr / (40*32); 23 | qm_ssim = total_ssim / (40*32); 24 | 25 | fprintf('PSNR: %f SSIM: %f\n', qm_psnr, qm_ssim); 26 | 27 | -------------------------------------------------------------------------------- /realDenoising/options/train_ART_RealDN.yml: -------------------------------------------------------------------------------- 1 | # general settings for RealDN training 2 | name: ART_RealDN 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 8 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_PairedImage 13 | dataroot_gt: ../datasets/SIDD/train/target_crops 14 | dataroot_lq: ../datasets/SIDD/train/input_crops 15 | geometric_augs: true 16 | 17 | filename_tmpl: '{}' 18 | io_backend: 19 | type: disk 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 8 24 | batch_size_per_gpu: 8 25 | 26 | ### -------------Progressive training-------------------------- 27 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu 28 | iters: [92000,64000,48000,36000,36000,24000] 29 | gt_size: 384 # Max patch size for progressive training 30 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training. 
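# Schedule sanity check: the six (mini_batch_size, gt_size) pairs above run for
# 92k + 64k + 48k + 36k + 36k + 24k = 300k iterations in total, matching
# total_iter in the train settings below; e.g. the first 92k iterations train
# with batch 8 on 128x128 crops and the final 24k with batch 1 on 384x384 crops.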
31 | ### ------------------------------------------------------------ 32 | 33 | ### ------- Training on single fixed-patch size 128x128--------- 34 | # mini_batch_sizes: [8] 35 | # iters: [300000] 36 | # gt_size: 128 37 | # gt_sizes: [128] 38 | ### ------------------------------------------------------------ 39 | 40 | dataset_enlarge_ratio: 100 41 | prefetch_mode: ~ 42 | 43 | val: 44 | name: ValSet 45 | type: Dataset_PairedImage 46 | dataroot_gt: ../datasets/SIDD/val/target_crops 47 | dataroot_lq: ../datasets/SIDD/val/input_crops 48 | io_backend: 49 | type: disk 50 | 51 | # network structures 52 | network_g: 53 | type: ARTUNet 54 | inp_channels: 3 55 | out_channels: 3 56 | dim: 48 57 | num_blocks: [4, 6, 6, 8] 58 | num_refinement_blocks: 4 59 | heads: [1, 2, 4, 8] 60 | window_size: [8, 8, 8, 8] 61 | mlp_ratio: 4 62 | interval: [32, 16, 8, 4] 63 | bias: False 64 | dual_pixel_task: False 65 | 66 | 67 | # path 68 | path: 69 | pretrain_network_g: ~ 70 | strict_load_g: true 71 | resume_state: ~ 72 | 73 | # training settings 74 | train: 75 | total_iter: 300000 76 | warmup_iter: -1 # no warm up 77 | use_grad_clip: true 78 | 79 | # Split 300k iterations into two cycles. 80 | # 1st cycle: fixed 3e-4 LR for 92k iters. 81 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 82 | scheduler: 83 | type: CosineAnnealingRestartCyclicLR 84 | periods: [92000, 208000] 85 | restart_weights: [1,1] 86 | eta_mins: [0.0003,0.000001] 87 | 88 | mixing_augs: 89 | mixup: true 90 | mixup_beta: 1.2 91 | use_identity: true 92 | 93 | optim_g: 94 | type: AdamW 95 | lr: !!float 3e-4 96 | weight_decay: !!float 1e-4 97 | betas: [0.9, 0.999] 98 | 99 | # losses 100 | pixel_opt: 101 | type: L1Loss 102 | loss_weight: 1 103 | reduction: mean 104 | 105 | # validation settings 106 | val: 107 | window_size: 8 108 | val_freq: !!float 4e3 109 | save_img: false 110 | rgb2bgr: true 111 | use_image: false 112 | max_minibatch: 8 113 | 114 | metrics: 115 | psnr: # metric name, can be arbitrary 116 | type: calculate_psnr 117 | crop_border: 0 118 | test_y_channel: false 119 | 120 | # logging settings 121 | logger: 122 | print_freq: 1000 123 | save_checkpoint_freq: !!float 4e3 124 | use_tb_logger: true 125 | wandb: 126 | project: ~ 127 | resume_id: ~ 128 | 129 | # dist training settings 130 | dist_params: 131 | backend: nccl 132 | port: 29500 133 | -------------------------------------------------------------------------------- /realDenoising/setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | # line break before binary operator (W503) 4 | W503, 5 | # line break after binary operator (W504) 6 | W504, 7 | max-line-length=79 8 | 9 | [yapf] 10 | based_on_style = pep8 11 | blank_line_before_nested_class_or_def = true 12 | split_before_expression_after_opening_paren = true 13 | 14 | [isort] 15 | line_length = 79 16 | multi_line_output = 0 17 | known_standard_library = pkg_resources,setuptools 18 | known_first_party = basicsr 19 | known_third_party = PIL,cv2,lmdb,numpy,requests,scipy,skimage,torch,torchvision,tqdm,yaml 20 | no_lines_before = STDLIB,LOCALFOLDER 21 | default_section = THIRDPARTY 22 | -------------------------------------------------------------------------------- /realDenoising/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | import os 6 | import subprocess 7 | import sys 8 | import time 9 | import torch 10 | from 
torch.utils.cpp_extension import (BuildExtension, CppExtension, 11 | CUDAExtension) 12 | 13 | version_file = 'basicsr/version.py' 14 | 15 | 16 | def readme(): 17 | return '' 18 | # with open('README.md', encoding='utf-8') as f: 19 | # content = f.read() 20 | # return content 21 | 22 | 23 | def get_git_hash(): 24 | 25 | def _minimal_ext_cmd(cmd): 26 | # construct minimal environment 27 | env = {} 28 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 29 | v = os.environ.get(k) 30 | if v is not None: 31 | env[k] = v 32 | # LANGUAGE is used on win32 33 | env['LANGUAGE'] = 'C' 34 | env['LANG'] = 'C' 35 | env['LC_ALL'] = 'C' 36 | out = subprocess.Popen( 37 | cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 38 | return out 39 | 40 | try: 41 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 42 | sha = out.strip().decode('ascii') 43 | except OSError: 44 | sha = 'unknown' 45 | 46 | return sha 47 | 48 | 49 | def get_hash(): 50 | if os.path.exists('.git'): 51 | sha = get_git_hash()[:7] 52 | elif os.path.exists(version_file): 53 | try: 54 | from basicsr.version import __version__ 55 | sha = __version__.split('+')[-1] 56 | except ImportError: 57 | raise ImportError('Unable to get git version') 58 | else: 59 | sha = 'unknown' 60 | 61 | return sha 62 | 63 | 64 | def write_version_py(): 65 | content = """# GENERATED VERSION FILE 66 | # TIME: {} 67 | __version__ = '{}' 68 | short_version = '{}' 69 | version_info = ({}) 70 | """ 71 | sha = get_hash() 72 | with open('VERSION', 'r') as f: 73 | SHORT_VERSION = f.read().strip() 74 | VERSION_INFO = ', '.join( 75 | [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) 76 | VERSION = SHORT_VERSION + '+' + sha 77 | 78 | version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION, 79 | VERSION_INFO) 80 | with open(version_file, 'w') as f: 81 | f.write(version_file_str) 82 | 83 | 84 | def get_version(): 85 | with open(version_file, 'r') as f: 86 | exec(compile(f.read(), version_file, 'exec')) 87 | return locals()['__version__'] 88 | 89 | 90 | def make_cuda_ext(name, module, sources, sources_cuda=None): 91 | if sources_cuda is None: 92 | sources_cuda = [] 93 | define_macros = [] 94 | extra_compile_args = {'cxx': []} 95 | 96 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 97 | define_macros += [('WITH_CUDA', None)] 98 | extension = CUDAExtension 99 | extra_compile_args['nvcc'] = [ 100 | '-D__CUDA_NO_HALF_OPERATORS__', 101 | '-D__CUDA_NO_HALF_CONVERSIONS__', 102 | '-D__CUDA_NO_HALF2_OPERATORS__', 103 | ] 104 | sources += sources_cuda 105 | else: 106 | print(f'Compiling {name} without CUDA') 107 | extension = CppExtension 108 | 109 | return extension( 110 | name=f'{module}.{name}', 111 | sources=[os.path.join(*module.split('.'), p) for p in sources], 112 | define_macros=define_macros, 113 | extra_compile_args=extra_compile_args) 114 | 115 | 116 | def get_requirements(filename='requirements.txt'): 117 | return [] 118 | here = os.path.dirname(os.path.realpath(__file__)) 119 | with open(os.path.join(here, filename), 'r') as f: 120 | requires = [line.replace('\n', '') for line in f.readlines()] 121 | return requires 122 | 123 | 124 | if __name__ == '__main__': 125 | if '--no_cuda_ext' in sys.argv: 126 | ext_modules = [] 127 | sys.argv.remove('--no_cuda_ext') 128 | else: 129 | ext_modules = [ 130 | make_cuda_ext( 131 | name='deform_conv_ext', 132 | module='basicsr.models.ops.dcn', 133 | sources=['src/deform_conv_ext.cpp'], 134 | sources_cuda=[ 135 | 'src/deform_conv_cuda.cpp', 136 | 'src/deform_conv_cuda_kernel.cu' 137 | ]), 
138 |         make_cuda_ext(
139 |             name='fused_act_ext',
140 |             module='basicsr.models.ops.fused_act',
141 |             sources=['src/fused_bias_act.cpp'],
142 |             sources_cuda=['src/fused_bias_act_kernel.cu']),
143 |         make_cuda_ext(
144 |             name='upfirdn2d_ext',
145 |             module='basicsr.models.ops.upfirdn2d',
146 |             sources=['src/upfirdn2d.cpp'],
147 |             sources_cuda=['src/upfirdn2d_kernel.cu']),
148 |     ]
149 |
150 |     write_version_py()
151 |     setup(
152 |         name='basicsr',
153 |         version=get_version(),
154 |         description='Open Source Image and Video Super-Resolution Toolbox',
155 |         long_description=readme(),
156 |         author='Xintao Wang',
157 |         author_email='xintao.wang@outlook.com',
158 |         keywords='computer vision, restoration, super resolution',
159 |         url='https://github.com/xinntao/BasicSR',
160 |         packages=find_packages(
161 |             exclude=('options', 'datasets', 'experiments', 'results',
162 |                      'tb_logger', 'wandb')),
163 |         classifiers=[
164 |             'Development Status :: 4 - Beta',
165 |             'License :: OSI Approved :: Apache Software License',
166 |             'Operating System :: OS Independent',
167 |             'Programming Language :: Python :: 3',
168 |             'Programming Language :: Python :: 3.7',
169 |             'Programming Language :: Python :: 3.8',
170 |         ],
171 |         license='Apache License 2.0',
172 |         setup_requires=['cython', 'numpy'],
173 |         install_requires=get_requirements(),
174 |         ext_modules=ext_modules,
175 |         cmdclass={'build_ext': BuildExtension},
176 |         zip_safe=False)
177 |
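The realDenoising sub-package is normally installed in development mode before running the test scripts below; the '--no_cuda_ext' branch handled above skips compiling the deform_conv/fused_act/upfirdn2d CUDA extensions, which the ART scripts in this folder do not use (commands are a plausible invocation, not quoted from the repo docs):

    cd realDenoising
    python setup.py develop --no_cuda_ext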
-------------------------------------------------------------------------------- /realDenoising/test_real_denoising_dnd.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import argparse
4 | from tqdm import tqdm
5 |
6 | import torch
7 | import torch.nn as nn
8 | import utils
9 |
10 | from basicsr.models.archs.artunet_arch import ARTUNet
11 | from skimage import img_as_ubyte
12 | import h5py
13 | import scipy.io as sio
14 | import yaml
15 |
16 | parser = argparse.ArgumentParser(description='Real Image Denoising')
17 |
18 | parser.add_argument('--input_dir', default='../datasets/RealDN/DND/', type=str, help='Directory of validation images')
19 | parser.add_argument('--result_dir', default='./results/Real_Denoising/DND/', type=str, help='Directory for results')
20 | parser.add_argument('--weights', default='../experiments/pretrained_models/RealDN_ART.pth', type=str, help='Path to weights')
21 | parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory')
22 |
23 | args = parser.parse_args()
24 |
25 | ####### Load model options #######
26 |
27 | opt_str = r"""
28 | type: ARTUNet
29 | inp_channels: 3
30 | out_channels: 3
31 | dim: 48
32 | num_blocks: [4, 6, 6, 8]
33 | num_refinement_blocks: 4
34 | heads: [1, 2, 4, 8]
35 | window_size: [8, 8, 8, 8]
36 | mlp_ratio: 4
37 | interval: [32, 16, 8, 4]
38 | bias: False
39 | dual_pixel_task: False
40 | """
41 |
42 | opt = yaml.safe_load(opt_str)
43 | network_type = opt.pop('type')  # 'ARTUNet'; leaves only the constructor kwargs in opt
44 | ##########################################
45 |
46 | result_dir_mat = os.path.join(args.result_dir, 'mat')
47 | os.makedirs(result_dir_mat, exist_ok=True)
48 |
49 | if args.save_images:
50 |     result_dir_png = os.path.join(args.result_dir, 'png')
51 |     os.makedirs(result_dir_png, exist_ok=True)
52 |
53 | model_restoration = ARTUNet(**opt)
54 |
55 | checkpoint = torch.load(args.weights)
56 | model_restoration.load_state_dict(checkpoint['params'])
57 | print("===>Testing using weights: ", args.weights)
58 | model_restoration.cuda()
59 | model_restoration = nn.DataParallel(model_restoration)
60 | model_restoration.eval()
61 |
62 | israw = False
63 | eval_version = "1.0"
64 |
65 | # Load info
66 | infos = h5py.File(os.path.join(args.input_dir, 'info.mat'), 'r')
67 | info = infos['info']
68 | bb = info['boundingboxes']
69 |
70 | # Process data
71 | with torch.no_grad():
72 |     for i in tqdm(range(50)):
73 |         Idenoised = np.zeros((20,), dtype=object)
74 |         filename = '%04d.mat' % (i + 1)
75 |         filepath = os.path.join(args.input_dir, 'images_srgb', filename)
76 |         img = h5py.File(filepath, 'r')
77 |         Inoisy = np.float32(np.array(img['InoisySRGB']).T)
78 |
79 |         # bounding box
80 |         ref = bb[0][i]
81 |         boxes = np.array(info[ref]).T
82 |
83 |         for k in range(20):
84 |             idx = [int(boxes[k, 0] - 1), int(boxes[k, 2]), int(boxes[k, 1] - 1), int(boxes[k, 3])]
85 |             noisy_patch = torch.from_numpy(Inoisy[idx[0]:idx[1], idx[2]:idx[3], :]).unsqueeze(0).permute(0, 3, 1, 2).cuda()
86 |             restored_patch = model_restoration(noisy_patch)
87 |             restored_patch = torch.clamp(restored_patch, 0, 1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
88 |             Idenoised[k] = restored_patch
89 |
90 |             if args.save_images:
91 |                 save_file = os.path.join(result_dir_png, '%04d_%02d.png' % (i + 1, k + 1))
92 |                 denoised_img = img_as_ubyte(restored_patch)
93 |                 utils.save_img(save_file, denoised_img)
94 |
95 |         # save denoised data
96 |         sio.savemat(os.path.join(result_dir_mat, filename),
97 |                     {"Idenoised": Idenoised,
98 |                      "israw": israw,
99 |                      "eval_version": eval_version},
100 |                     )
101 |
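Both test scripts load network weights from checkpoint['params'], the key under which BasicSR stores model weights when saving during training. A slightly more tolerant loader for checkpoints saved as a bare state_dict would look like the sketch below (the fallback branch is an assumption, not code from this repo):

    checkpoint = torch.load(args.weights)
    state_dict = checkpoint['params'] if 'params' in checkpoint else checkpoint
    model_restoration.load_state_dict(state_dict)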
-------------------------------------------------------------------------------- /realDenoising/test_real_denoising_sidd.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import argparse
4 | from tqdm import tqdm
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | from basicsr.models.archs.artunet_arch import ARTUNet
10 | from skimage import img_as_ubyte
11 | import scipy.io as sio
12 | import utils
13 | import yaml
14 |
15 | parser = argparse.ArgumentParser(description='Real Image Denoising')
16 |
17 | parser.add_argument('--input_dir', default='../datasets/RealDN/SIDD/', type=str, help='Directory of validation images')
18 | parser.add_argument('--result_dir', default='./results/Real_Denoising/SIDD/', type=str, help='Directory for results')
19 | parser.add_argument('--weights', default='../experiments/pretrained_models/RealDN_ART.pth', type=str, help='Path to weights')
20 | parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory')
21 |
22 | args = parser.parse_args()
23 |
24 | ####### Load yaml #######
25 | opt_str = r"""
26 | type: ARTUNet
27 | inp_channels: 3
28 | out_channels: 3
29 | dim: 48
30 | num_blocks: [4, 6, 6, 8]
31 | num_refinement_blocks: 4
32 | heads: [1, 2, 4, 8]
33 | window_size: [8, 8, 8, 8]
34 | mlp_ratio: 4
35 | interval: [32, 16, 8, 4]
36 | bias: False
37 | dual_pixel_task: False
38 | """
39 |
40 | opt = yaml.safe_load(opt_str)
41 | network_type = opt.pop('type')  # 'ARTUNet'; leaves only the constructor kwargs in opt
42 | ##########################
43 |
44 | result_dir_mat = os.path.join(args.result_dir, 'mat')
45 | os.makedirs(result_dir_mat, exist_ok=True)
46 |
47 | if args.save_images:
48 |     result_dir_png = os.path.join(args.result_dir, 'png')
49 |     os.makedirs(result_dir_png, exist_ok=True)
50 |
51 | model_restoration = ARTUNet(**opt)
52 |
53 | checkpoint = torch.load(args.weights)
54 | model_restoration.load_state_dict(checkpoint['params'])
55 | print("===>Testing using weights: ", args.weights)
56 | model_restoration.cuda()
57 | model_restoration = nn.DataParallel(model_restoration)
58 | model_restoration.eval()
59 |
60 | # Process data: 40 validation images x 32 noisy sRGB blocks each
61 | filepath = os.path.join(args.input_dir, 'ValidationNoisyBlocksSrgb.mat')
62 | img = sio.loadmat(filepath)
63 | Inoisy = np.float32(np.array(img['ValidationNoisyBlocksSrgb']))
64 | Inoisy /= 255.
65 | restored = np.zeros_like(Inoisy)
66 | with torch.no_grad():
67 |     for i in tqdm(range(40)):
68 |         for k in range(32):
69 |             noisy_patch = torch.from_numpy(Inoisy[i, k, :, :, :]).unsqueeze(0).permute(0, 3, 1, 2).cuda()
70 |             restored_patch = model_restoration(noisy_patch)
71 |             restored_patch = torch.clamp(restored_patch, 0, 1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
72 |             restored[i, k, :, :, :] = restored_patch
73 |
74 |             if args.save_images:
75 |                 save_file = os.path.join(result_dir_png, '%04d_%02d.png' % (i + 1, k + 1))
76 |                 utils.save_img(save_file, img_as_ubyte(restored_patch))
77 |
78 | # save denoised data
79 | sio.savemat(os.path.join(result_dir_mat, 'Idenoised.mat'), {"Idenoised": restored})
80 |
-------------------------------------------------------------------------------- /realDenoising/utils.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cv2
4 | import math
5 |
6 | def calculate_psnr(img1, img2, border=0):
7 |     # img1 and img2 have range [0, 255]
8 |     #img1 = img1.squeeze()
9 |     #img2 = img2.squeeze()
10 |     if not img1.shape == img2.shape:
11 |         raise ValueError('Input images must have the same dimensions.')
12 |     h, w = img1.shape[:2]
13 |     img1 = img1[border:h-border, border:w-border]
14 |     img2 = img2[border:h-border, border:w-border]
15 |
16 |     img1 = img1.astype(np.float64)
17 |     img2 = img2.astype(np.float64)
18 |     mse = np.mean((img1 - img2)**2)
19 |     if mse == 0:
20 |         return float('inf')
21 |     return 20 * math.log10(255.0 / math.sqrt(mse))
22 |
23 |
24 | # --------------------------------------------
25 | # SSIM
26 | # --------------------------------------------
27 | def calculate_ssim(img1, img2, border=0):
28 |     '''calculate SSIM
29 |     the same outputs as MATLAB's
30 |     img1, img2: [0, 255]
31 |     '''
32 |     #img1 = img1.squeeze()
33 |     #img2 = img2.squeeze()
34 |     if not img1.shape == img2.shape:
35 |         raise ValueError('Input images must have the same dimensions.')
36 |     h, w = img1.shape[:2]
37 |     img1 = img1[border:h-border, border:w-border]
38 |     img2 = img2[border:h-border, border:w-border]
39 |
40 |     if img1.ndim == 2:
41 |         return ssim(img1, img2)
42 |     elif img1.ndim == 3:
43 |         if img1.shape[2] == 3:
44 |             ssims = []
45 |             for i in range(3):
46 |                 ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
47 |             return np.array(ssims).mean()
48 |         elif img1.shape[2] == 1:
49 |             return ssim(np.squeeze(img1), np.squeeze(img2))
50 |     else:
51 |         raise ValueError('Wrong input image dimensions.')
52 |
53 |
54 | def ssim(img1, img2):
55 |     C1 = (0.01 * 255)**2
56 |     C2 = (0.03 * 255)**2
57 |
58 |     img1 = img1.astype(np.float64)
59 |     img2 = img2.astype(np.float64)
60 |     kernel = cv2.getGaussianKernel(11, 1.5)
61 |     window = np.outer(kernel, kernel.transpose())
62 |
63 |     mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
64 |     mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
65 |     mu1_sq = mu1**2
66 |     mu2_sq = mu2**2
67 |     mu1_mu2 = mu1 * mu2
68 |     sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
69 |     sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
70 |     sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
71 |
72 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 73 | (sigma1_sq + sigma2_sq + C2)) 74 | return ssim_map.mean() 75 | 76 | def load_img(filepath): 77 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB) 78 | 79 | def save_img(filepath, img): 80 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 81 | 82 | def load_gray_img(filepath): 83 | return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2) 84 | 85 | def save_gray_img(filepath, img): 86 | cv2.imwrite(filepath, img) 87 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict 2 | future 3 | lmdb 4 | numpy>=1.17 5 | opencv-python 6 | Pillow 7 | pyyaml 8 | requests 9 | scikit-image 10 | scipy 11 | tb-nightly 12 | torch>=1.8 13 | torchvision 14 | tqdm 15 | yapf 16 | timm 17 | einops 18 | h5py 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | import os 6 | import subprocess 7 | import time 8 | import torch 9 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 10 | 11 | version_file = 'basicsr/version.py' 12 | 13 | 14 | def readme(): 15 | with open('README.md', encoding='utf-8') as f: 16 | content = f.read() 17 | return content 18 | 19 | 20 | def get_git_hash(): 21 | 22 | def _minimal_ext_cmd(cmd): 23 | # construct minimal environment 24 | env = {} 25 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 26 | v = os.environ.get(k) 27 | if v is not None: 28 | env[k] = v 29 | # LANGUAGE is used on win32 30 | env['LANGUAGE'] = 'C' 31 | env['LANG'] = 'C' 32 | env['LC_ALL'] = 'C' 33 | out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 34 | return out 35 | 36 | try: 37 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 38 | sha = out.strip().decode('ascii') 39 | except OSError: 40 | sha = 'unknown' 41 | 42 | return sha 43 | 44 | 45 | def get_hash(): 46 | if os.path.exists('.git'): 47 | sha = get_git_hash()[:7] 48 | # currently ignore this 49 | # elif os.path.exists(version_file): 50 | # try: 51 | # from basicsr.version import __version__ 52 | # sha = __version__.split('+')[-1] 53 | # except ImportError: 54 | # raise ImportError('Unable to get git version') 55 | else: 56 | sha = 'unknown' 57 | 58 | return sha 59 | 60 | 61 | def write_version_py(): 62 | content = """# GENERATED VERSION FILE 63 | # TIME: {} 64 | __version__ = '{}' 65 | __gitsha__ = '{}' 66 | version_info = ({}) 67 | """ 68 | sha = get_hash() 69 | with open('VERSION', 'r') as f: 70 | SHORT_VERSION = f.read().strip() 71 | VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) 72 | 73 | version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) 74 | with open(version_file, 'w') as f: 75 | f.write(version_file_str) 76 | 77 | 78 | def get_version(): 79 | with open(version_file, 'r') as f: 80 | exec(compile(f.read(), version_file, 'exec')) 81 | return locals()['__version__'] 82 | 83 | 84 | def make_cuda_ext(name, module, sources, sources_cuda=None): 85 | if sources_cuda is None: 86 | sources_cuda = [] 87 | define_macros = [] 88 | extra_compile_args = {'cxx': []} 89 | 90 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 91 | define_macros += 
[('WITH_CUDA', None)] 92 | extension = CUDAExtension 93 | extra_compile_args['nvcc'] = [ 94 | '-D__CUDA_NO_HALF_OPERATORS__', 95 | '-D__CUDA_NO_HALF_CONVERSIONS__', 96 | '-D__CUDA_NO_HALF2_OPERATORS__', 97 | ] 98 | sources += sources_cuda 99 | else: 100 | print(f'Compiling {name} without CUDA') 101 | extension = CppExtension 102 | 103 | return extension( 104 | name=f'{module}.{name}', 105 | sources=[os.path.join(*module.split('.'), p) for p in sources], 106 | define_macros=define_macros, 107 | extra_compile_args=extra_compile_args) 108 | 109 | 110 | def get_requirements(filename='requirements.txt'): 111 | here = os.path.dirname(os.path.realpath(__file__)) 112 | with open(os.path.join(here, filename), 'r') as f: 113 | requires = [line.replace('\n', '') for line in f.readlines()] 114 | return requires 115 | 116 | 117 | if __name__ == '__main__': 118 | cuda_ext = os.getenv('BASICSR_EXT') # whether compile cuda ext 119 | if cuda_ext == 'True': 120 | ext_modules = [ 121 | make_cuda_ext( 122 | name='deform_conv_ext', 123 | module='basicsr.ops.dcn', 124 | sources=['src/deform_conv_ext.cpp'], 125 | sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']), 126 | make_cuda_ext( 127 | name='fused_act_ext', 128 | module='basicsr.ops.fused_act', 129 | sources=['src/fused_bias_act.cpp'], 130 | sources_cuda=['src/fused_bias_act_kernel.cu']), 131 | make_cuda_ext( 132 | name='upfirdn2d_ext', 133 | module='basicsr.ops.upfirdn2d', 134 | sources=['src/upfirdn2d.cpp'], 135 | sources_cuda=['src/upfirdn2d_kernel.cu']), 136 | ] 137 | else: 138 | ext_modules = [] 139 | 140 | write_version_py() 141 | setup( 142 | name='basicsr', 143 | version=get_version(), 144 | description='Open Source Image and Video Super-Resolution Toolbox', 145 | long_description=readme(), 146 | long_description_content_type='text/markdown', 147 | author='Xintao Wang', 148 | author_email='xintao.wang@outlook.com', 149 | keywords='computer vision, restoration, super resolution', 150 | url='https://github.com/xinntao/BasicSR', 151 | include_package_data=True, 152 | packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')), 153 | classifiers=[ 154 | 'Development Status :: 4 - Beta', 155 | 'License :: OSI Approved :: Apache Software License', 156 | 'Operating System :: OS Independent', 157 | 'Programming Language :: Python :: 3', 158 | 'Programming Language :: Python :: 3.7', 159 | 'Programming Language :: Python :: 3.8', 160 | ], 161 | license='Apache License 2.0', 162 | setup_requires=['cython', 'numpy'], 163 | install_requires=get_requirements(), 164 | ext_modules=ext_modules, 165 | cmdclass={'build_ext': BuildExtension}, 166 | zip_safe=False) 167 | --------------------------------------------------------------------------------
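A sketch of the two install paths this root setup.py supports; the BASICSR_EXT environment variable read above opts in to compiling the optional dcn/fused_act/upfirdn2d CUDA extensions, while the plain command installs the pure-Python package (shell commands are illustrative usage, not quoted from the README):

    pip install -r requirements.txt
    python setup.py develop                    # pure-Python install, no CUDA extensions
    BASICSR_EXT=True python setup.py develop   # additionally builds the CUDA ops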