├── EQSR ├── __init__.py ├── archs │ ├── Involution.py │ ├── Involution_Kernel_Abs_Pos.py │ ├── Involution_Kernel_Abs_Pos_test.py │ ├── Involution_PE.py │ ├── __init__.py │ ├── hat_ModMBFormer_Sim_arch.py │ ├── hat_arch.py │ ├── kernel │ │ ├── kernel_mean_4.25.png │ │ └── kernel_var_4.25.png │ └── swinir.py ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── imagenet_paired_dataset.cpython-37.pyc │ ├── imagenet_paired_dataset.py │ └── meta_info │ │ └── meta_info_DF2Ksub_GT.txt └── models │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── hat_model.cpython-37.pyc │ └── hat_model.py ├── LICENSE ├── README.md ├── basicsr ├── __init__.py ├── archs │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── arch_util.cpython-37.pyc │ │ ├── basicvsr_arch.cpython-37.pyc │ │ ├── basicvsrpp_arch.cpython-37.pyc │ │ ├── dfdnet_arch.cpython-37.pyc │ │ ├── dfdnet_util.cpython-37.pyc │ │ ├── discriminator_arch.cpython-37.pyc │ │ ├── discriminator_arch.cpython-37.pyc.139952149786672 │ │ ├── duf_arch.cpython-37.pyc │ │ ├── ecbsr_arch.cpython-37.pyc │ │ ├── edsr_arch.cpython-37.pyc │ │ ├── edvr_arch.cpython-37.pyc │ │ ├── hifacegan_arch.cpython-37.pyc │ │ ├── hifacegan_util.cpython-37.pyc │ │ ├── rcan_arch.cpython-37.pyc │ │ ├── ridnet_arch.cpython-37.pyc │ │ ├── rrdbnet_arch.cpython-37.pyc │ │ ├── spynet_arch.cpython-37.pyc │ │ ├── srresnet_arch.cpython-37.pyc │ │ ├── srvgg_arch.cpython-37.pyc │ │ ├── stylegan2_arch.cpython-37.pyc │ │ ├── stylegan2_bilinear_arch.cpython-37.pyc │ │ ├── swinir_arch.cpython-37.pyc │ │ ├── tof_arch.cpython-37.pyc │ │ └── vgg_arch.cpython-37.pyc │ ├── arch_util.py │ ├── basicvsr_arch.py │ ├── basicvsrpp_arch.py │ ├── dfdnet_arch.py │ ├── dfdnet_util.py │ ├── discriminator_arch.py │ ├── duf_arch.py │ ├── ecbsr_arch.py │ ├── edsr_arch.py │ ├── edvr_arch.py │ ├── hifacegan_arch.py │ ├── hifacegan_util.py │ ├── inception.py │ ├── rcan_arch.py │ ├── ridnet_arch.py │ ├── rrdbnet_arch.py │ ├── spynet_arch.py │ ├── srresnet_arch.py │ ├── srvgg_arch.py │ ├── stylegan2_arch.py │ ├── stylegan2_bilinear_arch.py │ ├── swinir_arch.py │ ├── tof_arch.py │ └── vgg_arch.py ├── data │ ├── DSF_DF2K_dataset.py │ ├── DSF_imagenet_dataset.py │ ├── DSF_nopre_DF2K_dataset.py │ ├── DSF_val_dataset.py │ ├── LTE_imagenet_dataset.py │ ├── __init__.py │ ├── data_sampler.py │ ├── data_util.py │ ├── degradations.py │ ├── ffhq_dataset.py │ ├── meta_info │ │ ├── meta_info_DIV2K800sub_GT.txt │ │ ├── meta_info_REDS4_test_GT.txt │ │ ├── meta_info_REDS_GT.txt │ │ ├── meta_info_REDSofficial4_test_GT.txt │ │ ├── meta_info_REDSval_official_test_GT.txt │ │ ├── meta_info_Vimeo90K_test_GT.txt │ │ ├── meta_info_Vimeo90K_test_fast_GT.txt │ │ ├── meta_info_Vimeo90K_test_medium_GT.txt │ │ ├── meta_info_Vimeo90K_test_slow_GT.txt │ │ └── meta_info_Vimeo90K_train_GT.txt │ ├── paired_image_dataset.py │ ├── prefetch_dataloader.py │ ├── realesrgan_dataset.py │ ├── realesrgan_paired_dataset.py │ ├── reds_dataset.py │ ├── single_image_dataset.py │ ├── transforms.py │ ├── video_test_dataset.py │ └── vimeo90k_dataset.py ├── losses │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── basic_loss.cpython-37.pyc │ │ ├── gan_loss.cpython-37.pyc │ │ ├── loss_util.cpython-37.pyc │ │ └── loss_util.cpython-37.pyc.139946364727344 │ ├── basic_loss.py │ ├── gan_loss.py │ └── loss_util.py ├── metrics │ ├── README.md │ ├── README_CN.md │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── metric_util.cpython-37.pyc │ │ ├── 
niqe.cpython-37.pyc │ │ ├── psnr_ssim.cpython-37.pyc │ │ └── psnr_ssim.cpython-37.pyc.139946364729136 │ ├── fid.py │ ├── metric_util.py │ ├── niqe.py │ ├── niqe_pris_params.npz │ ├── psnr_ssim.py │ └── test_metrics │ │ └── test_psnr_ssim.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── base_model.cpython-37.pyc │ │ ├── edvr_model.cpython-37.pyc │ │ ├── esrgan_model.cpython-37.pyc │ │ ├── hifacegan_model.cpython-37.pyc │ │ ├── lr_scheduler.cpython-37.pyc │ │ ├── realesrgan_model.cpython-37.pyc │ │ ├── realesrnet_model.cpython-37.pyc │ │ ├── sr_model.cpython-37.pyc │ │ ├── sr_model.cpython-37.pyc.139704389111088 │ │ ├── srgan_model.cpython-37.pyc │ │ ├── stylegan2_model.cpython-37.pyc │ │ ├── swinir_model.cpython-37.pyc │ │ ├── video_base_model.cpython-37.pyc │ │ ├── video_gan_model.cpython-37.pyc │ │ ├── video_recurrent_gan_model.cpython-37.pyc │ │ └── video_recurrent_model.cpython-37.pyc │ ├── base_model.py │ ├── edvr_model.py │ ├── esrgan_model.py │ ├── hifacegan_model.py │ ├── lr_scheduler.py │ ├── realesrgan_model.py │ ├── realesrnet_model.py │ ├── sr_model.py │ ├── srgan_model.py │ ├── stylegan2_model.py │ ├── swinir_model.py │ ├── video_base_model.py │ ├── video_gan_model.py │ ├── video_recurrent_gan_model.py │ └── video_recurrent_model.py ├── ops │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── __init__.cpython-37.pyc.139952150904880 │ ├── dcn │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── deform_conv.cpython-37.pyc │ │ │ └── deform_conv.cpython-37.pyc.139952150906544 │ │ ├── deform_conv.py │ │ └── src │ │ │ ├── deform_conv_cuda.cpp │ │ │ ├── deform_conv_cuda_kernel.cu │ │ │ └── deform_conv_ext.cpp │ ├── fused_act │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── fused_act.cpython-37.pyc │ │ │ └── fused_act.cpython-37.pyc.139952150906928 │ │ ├── fused_act.py │ │ └── src │ │ │ ├── fused_bias_act.cpp │ │ │ └── fused_bias_act_kernel.cu │ └── upfirdn2d │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── upfirdn2d.cpython-37.pyc │ │ └── upfirdn2d.cpython-37.pyc.139952150101424 │ │ ├── src │ │ ├── upfirdn2d.cpp │ │ └── upfirdn2d_kernel.cu │ │ └── upfirdn2d.py ├── test.py ├── train.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── color_util.cpython-37.pyc │ ├── diffjpeg.cpython-37.pyc │ ├── dist_util.cpython-37.pyc │ ├── dist_util.cpython-37.pyc.140634762200752 │ ├── file_client.cpython-37.pyc │ ├── flow_util.cpython-37.pyc │ ├── img_process_util.cpython-37.pyc │ ├── img_util.cpython-37.pyc │ ├── logger.cpython-37.pyc │ ├── matlab_functions.cpython-37.pyc │ ├── misc.cpython-37.pyc │ ├── options.cpython-37.pyc │ └── registry.cpython-37.pyc │ ├── color_util.py │ ├── diffjpeg.py │ ├── dist_util.py │ ├── download_util.py │ ├── file_client.py │ ├── flow_util.py │ ├── img_process_util.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ ├── options.py │ ├── plot_util.py │ └── registry.py ├── datasets └── datasets.txt ├── doc └── img │ └── motivation.PNG ├── options ├── test │ ├── test.yml │ ├── test_scale.yml │ ├── testx234.yml │ ├── testx6.yml │ └── testx8.yml └── train │ ├── train_EQSR_ImageNet_from_scratch.yml │ └── train_EQSR_finetune_from_ImageNet_pretrain.yml ├── predict.py ├── test.py └── train.py /EQSR/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .archs import * 3 | 
from .data import * 4 | from .models import * -------------------------------------------------------------------------------- /EQSR/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import scandir 5 | 6 | # automatically scan and import arch modules for registry 7 | # scan all the files that end with '_arch.py' under the archs folder 8 | arch_folder = osp.dirname(osp.abspath(__file__)) 9 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 10 | # import all the arch modules 11 | _arch_modules = [importlib.import_module(f'EQSR.archs.{file_name}') for file_name in arch_filenames] 12 | -------------------------------------------------------------------------------- /EQSR/archs/kernel/kernel_mean_4.25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/EQSR/archs/kernel/kernel_mean_4.25.png -------------------------------------------------------------------------------- /EQSR/archs/kernel/kernel_var_4.25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/EQSR/archs/kernel/kernel_var_4.25.png -------------------------------------------------------------------------------- /EQSR/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import scandir 5 | 6 | # automatically scan and import dataset modules for registry 7 | # scan all the files that end with '_dataset.py' under the data folder 8 | data_folder = osp.dirname(osp.abspath(__file__)) 9 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 10 | # import all the dataset modules 11 | _dataset_modules = [importlib.import_module(f'EQSR.data.{file_name}') for file_name in dataset_filenames] 12 | -------------------------------------------------------------------------------- /EQSR/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/EQSR/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /EQSR/data/__pycache__/imagenet_paired_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/EQSR/data/__pycache__/imagenet_paired_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /EQSR/data/imagenet_paired_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os.path as osp 4 | from torch.utils import data as data 5 | from torchvision.transforms.functional import normalize 6 | 7 | from basicsr.data.data_util import paths_from_lmdb, scandir 8 | from basicsr.data.transforms import augment, paired_random_crop 9 | from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr 10 | from basicsr.utils.matlab_functions import imresize 11 | 
from basicsr.utils.registry import DATASET_REGISTRY 12 | 13 | 14 | @DATASET_REGISTRY.register() 15 | class ImageNetPairedDataset(data.Dataset): 16 | 17 | def __init__(self, opt): 18 | super(ImageNetPairedDataset, self).__init__() 19 | self.opt = opt 20 | # file client (io backend) 21 | self.file_client = None 22 | self.io_backend_opt = opt['io_backend'] 23 | self.mean = opt['mean'] if 'mean' in opt else None 24 | self.std = opt['std'] if 'std' in opt else None 25 | self.gt_folder = opt['dataroot_gt'] 26 | 27 | if self.io_backend_opt['type'] == 'lmdb': 28 | self.io_backend_opt['db_paths'] = [self.gt_folder] 29 | self.io_backend_opt['client_keys'] = ['gt'] 30 | self.paths = paths_from_lmdb(self.gt_folder) 31 | elif 'meta_info_file' in self.opt: 32 | with open(self.opt['meta_info_file'], 'r') as fin: 33 | self.paths = [osp.join(self.gt_folder, line.split(' ')[0]) for line in fin] 34 | else: 35 | self.paths = sorted(list(scandir(self.gt_folder, full_path=True))) 36 | 37 | def __getitem__(self, index): 38 | if self.file_client is None: 39 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 40 | 41 | scale = self.opt['scale'] 42 | 43 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 44 | # image range: [0, 1], float32. 45 | gt_path = self.paths[index] 46 | img_bytes = self.file_client.get(gt_path, 'gt') 47 | img_gt = imfrombytes(img_bytes, float32=True) 48 | 49 | # modcrop 50 | size_h, size_w, _ = img_gt.shape 51 | size_h = size_h - size_h % scale 52 | size_w = size_w - size_w % scale 53 | img_gt = img_gt[0:size_h, 0:size_w, :] 54 | 55 | # generate training pairs 56 | size_h = max(size_h, self.opt['gt_size']) 57 | size_w = max(size_w, self.opt['gt_size']) 58 | img_gt = cv2.resize(img_gt, (size_w, size_h)) 59 | img_lq = imresize(img_gt, 1 / scale) 60 | 61 | img_gt = np.ascontiguousarray(img_gt, dtype=np.float32) 62 | img_lq = np.ascontiguousarray(img_lq, dtype=np.float32) 63 | 64 | # augmentation for training 65 | if self.opt['phase'] == 'train': 66 | gt_size = self.opt['gt_size'] 67 | # random crop 68 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 69 | # flip, rotation 70 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 71 | 72 | # color space transform 73 | if 'color' in self.opt and self.opt['color'] == 'y': 74 | img_gt = rgb2ycbcr(img_gt, y_only=True)[..., None] 75 | img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] 76 | 77 | # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets 78 | # TODO: It is better to update the datasets, rather than force to crop 79 | if self.opt['phase'] != 'train': 80 | img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] 81 | 82 | # BGR to RGB, HWC to CHW, numpy to tensor 83 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 84 | # normalize 85 | if self.mean is not None or self.std is not None: 86 | normalize(img_lq, self.mean, self.std, inplace=True) 87 | normalize(img_gt, self.mean, self.std, inplace=True) 88 | 89 | return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path} 90 | 91 | def __len__(self): 92 | return len(self.paths) 93 | -------------------------------------------------------------------------------- /EQSR/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import scandir 5 | 6 | # automatically scan and import 
model modules for registry 7 | # scan all the files that end with '_model.py' under the model folder 8 | model_folder = osp.dirname(osp.abspath(__file__)) 9 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 10 | # import all the model modules 11 | _model_modules = [importlib.import_module(f'EQSR.models.{file_name}') for file_name in model_filenames] 12 | -------------------------------------------------------------------------------- /EQSR/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/EQSR/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /EQSR/models/__pycache__/hat_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/EQSR/models/__pycache__/hat_model.cpython-37.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Arbitrary-Scale Image Super-Resolution via Scale-Equivariance Pursuit (EQSR) 2 | ## Accepted by CVPR 2023 3 | **The official PyTorch repository** 4 | 5 | Our paper can be downloaded from [EQSR](https://openaccess.thecvf.com/content/CVPR2023/papers/Wang_Deep_Arbitrary-Scale_Image_Super-Resolution_via_Scale-Equivariance_Pursuit_CVPR_2023_paper.pdf). 6 | 7 | ## Introduction 8 | 9 | EQSR is designed to pursue scale-equivariance in image super-resolution. 10 | We compare the PSNR degradation rates of our method and ArbSR. 11 | Taking the SOTA fixed-scale method HAT as a reference, our model exhibits a more stable degradation as the scale increases, reflecting 12 | the equivariance of our method. 13 | ![motivation](/doc/img/motivation.PNG) 14 | 15 | 16 | ## Installation 17 | **Clone this repo:** 18 | ```bash 19 | git clone https://github.com/neuralchen/EQSR.git 20 | cd EQSR 21 | ``` 22 | **Dependencies:** 23 | - python3.7+ 24 | - pytorch 25 | - pyyaml, scipy, tqdm, imageio, einops, opencv-python 26 | - cupy 27 | 28 | (Note: Please do not install basicsr directly with "pip install"; version differences might lead to issues.) 29 | 30 | ## Training 31 | 32 | We divide the training into two stages: the first stage pretrains on the ImageNet dataset, and the second stage fine-tunes on the DF2K dataset. 33 | 34 | You can modify parameters such as batch size, iterations, and learning rate in the configuration files. 35 | 36 | - Phase 1 37 | Modify the dataset path in options/train/train*.yml, and run the following command to train. 38 | ``` 39 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=4320 train.py -opt options/train/train_EQSR_ImageNet_from_scratch.yml --launcher pytorch 40 | ``` 41 | - Phase 2 42 | Modify the dataset paths and the pretrained model path in the configuration file.
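For reference, the relevant fields in a BasicSR-style option file typically look like the sketch below; the key names follow BasicSR conventions and the paths are placeholders, so check the shipped options/train/train_EQSR_finetune_from_ImageNet_pretrain.yml for the exact layout.
```
datasets:
  train:
    dataroot_gt: /path/to/DF2K/GT   # ground-truth training images
path:
  pretrain_network_g: /path/to/EQSR_ImageNet_pretrain.pth   # Phase 1 checkpoint
  strict_load_g: true
```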
43 | 44 | ``` 45 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=4320 train.py -opt options/train/train_EQSR_finetune_from_ImageNet_pretrain.yml --launcher pytorch 46 | ``` 47 | 48 | ## Datasets 49 | TODO 50 | 51 | ### Preprocess 52 | TODO 53 | 54 | ## Inference with a pretrained EQSR model 55 | ### Pretrained Models 56 | - Baidu Netdisk (百度网盘): https://pan.baidu.com/s/1ui-GSbAQLuTyOmxBlAQZVg 57 | - Extraction Code (提取码): lspg 58 | 59 | Modify the dataset path and pretrained model path in options/test/test*.yml, and run the following commands to test. 60 | If GPU memory is limited, consider prepending ```PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:32``` to these commands. 61 | 62 | ``` 63 | python test.py -opt options/test/testx234.yml 64 | python test.py -opt options/test/testx6.yml 65 | python test.py -opt options/test/testx8.yml 66 | ``` 67 | 68 | ## Acknowledgements 69 | This code is built on [HAT](https://github.com/XPixelGroup/HAT) and [BasicSR](https://github.com/XPixelGroup/BasicSR). We thank the authors for sharing their code. 70 | 71 | ## To cite our paper 72 | 73 | ## Related Projects 74 | -------------------------------------------------------------------------------- /basicsr/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/xinntao/BasicSR 2 | # flake8: noqa 3 | from .archs import * 4 | from .data import * 5 | from .losses import * 6 | from .metrics import * 7 | from .models import * 8 | from .ops import * 9 | from .test import * 10 | from .train import * 11 | from .utils import * 12 | # from .version import __gitsha__, __version__ 13 | -------------------------------------------------------------------------------- /basicsr/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ['build_network'] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with '_arch.py' 12 | arch_folder = osp.dirname(osp.abspath(__file__)) 13 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 14 | # import all the arch modules 15 | _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] 16 | 17 | 18 | def build_network(opt): 19 | opt = deepcopy(opt) 20 | network_type = opt.pop('type') 21 | net = ARCH_REGISTRY.get(network_type)(**opt) 22 | logger = get_root_logger() 23 | logger.info(f'Network [{net.__class__.__name__}] is created.') 24 | return net 25 | -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/arch_util.cpython-37.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/arch_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/basicvsr_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/basicvsr_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/basicvsrpp_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/basicvsrpp_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/dfdnet_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/dfdnet_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/dfdnet_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/dfdnet_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/discriminator_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/discriminator_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/discriminator_arch.cpython-37.pyc.139952149786672: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/discriminator_arch.cpython-37.pyc.139952149786672 -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/duf_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/duf_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/ecbsr_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/ecbsr_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/edsr_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/edsr_arch.cpython-37.pyc -------------------------------------------------------------------------------- 
/basicsr/archs/__pycache__/edvr_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/edvr_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/hifacegan_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/hifacegan_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/hifacegan_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/hifacegan_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/rcan_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/rcan_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/ridnet_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/ridnet_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/rrdbnet_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/rrdbnet_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/spynet_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/spynet_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/srresnet_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/srresnet_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/srvgg_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/srvgg_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/stylegan2_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/stylegan2_arch.cpython-37.pyc 
-------------------------------------------------------------------------------- /basicsr/archs/__pycache__/stylegan2_bilinear_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/stylegan2_bilinear_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/swinir_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/swinir_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/tof_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/tof_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/__pycache__/vgg_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/archs/__pycache__/vgg_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/archs/dfdnet_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Function 5 | from torch.nn.utils.spectral_norm import spectral_norm 6 | 7 | 8 | class BlurFunctionBackward(Function): 9 | 10 | @staticmethod 11 | def forward(ctx, grad_output, kernel, kernel_flip): 12 | ctx.save_for_backward(kernel, kernel_flip) 13 | grad_input = F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1]) 14 | return grad_input 15 | 16 | @staticmethod 17 | def backward(ctx, gradgrad_output): 18 | kernel, _ = ctx.saved_tensors 19 | grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1]) 20 | return grad_input, None, None 21 | 22 | 23 | class BlurFunction(Function): 24 | 25 | @staticmethod 26 | def forward(ctx, x, kernel, kernel_flip): 27 | ctx.save_for_backward(kernel, kernel_flip) 28 | output = F.conv2d(x, kernel, padding=1, groups=x.shape[1]) 29 | return output 30 | 31 | @staticmethod 32 | def backward(ctx, grad_output): 33 | kernel, kernel_flip = ctx.saved_tensors 34 | grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip) 35 | return grad_input, None, None 36 | 37 | 38 | blur = BlurFunction.apply 39 | 40 | 41 | class Blur(nn.Module): 42 | 43 | def __init__(self, channel): 44 | super().__init__() 45 | kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32) 46 | kernel = kernel.view(1, 1, 3, 3) 47 | kernel = kernel / kernel.sum() 48 | kernel_flip = torch.flip(kernel, [2, 3]) 49 | 50 | self.kernel = kernel.repeat(channel, 1, 1, 1) 51 | self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1) 52 | 53 | def forward(self, x): 54 | return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x)) 55 | 56 | 57 | def calc_mean_std(feat, eps=1e-5): 58 | """Calculate mean and std for adaptive_instance_normalization. 59 | 60 | Args: 61 | feat (Tensor): 4D tensor. 
62 | eps (float): A small value added to the variance to avoid 63 | divide-by-zero. Default: 1e-5. 64 | """ 65 | size = feat.size() 66 | assert len(size) == 4, 'The input feature should be a 4D tensor.' 67 | n, c = size[:2] 68 | feat_var = feat.view(n, c, -1).var(dim=2) + eps 69 | feat_std = feat_var.sqrt().view(n, c, 1, 1) 70 | feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1) 71 | return feat_mean, feat_std 72 | 73 | 74 | def adaptive_instance_normalization(content_feat, style_feat): 75 | """Adaptive instance normalization. 76 | 77 | Adjust the reference features to have similar color and illumination 78 | to those of the degraded features. 79 | 80 | Args: 81 | content_feat (Tensor): The reference feature. 82 | style_feat (Tensor): The degraded features. 83 | """ 84 | size = content_feat.size() 85 | style_mean, style_std = calc_mean_std(style_feat) 86 | content_mean, content_std = calc_mean_std(content_feat) 87 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) 88 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 89 | 90 | 91 | def AttentionBlock(in_channel): 92 | return nn.Sequential( 93 | spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True), 94 | spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1))) 95 | 96 | 97 | def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True): 98 | """Conv block used in MSDilationBlock.""" 99 | 100 | return nn.Sequential( 101 | spectral_norm( 102 | nn.Conv2d( 103 | in_channels, 104 | out_channels, 105 | kernel_size=kernel_size, 106 | stride=stride, 107 | dilation=dilation, 108 | padding=((kernel_size - 1) // 2) * dilation, 109 | bias=bias)), 110 | nn.LeakyReLU(0.2), 111 | spectral_norm( 112 | nn.Conv2d( 113 | out_channels, 114 | out_channels, 115 | kernel_size=kernel_size, 116 | stride=stride, 117 | dilation=dilation, 118 | padding=((kernel_size - 1) // 2) * dilation, 119 | bias=bias)), 120 | ) 121 | 122 | 123 | class MSDilationBlock(nn.Module): 124 | """Multi-scale dilation block.""" 125 | 126 | def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True): 127 | super(MSDilationBlock, self).__init__() 128 | 129 | self.conv_blocks = nn.ModuleList() 130 | for i in range(4): 131 | self.conv_blocks.append(conv_block(in_channels, in_channels, kernel_size, dilation=dilation[i], bias=bias)) 132 | self.conv_fusion = spectral_norm( 133 | nn.Conv2d( 134 | in_channels * 4, 135 | in_channels, 136 | kernel_size=kernel_size, 137 | stride=1, 138 | padding=(kernel_size - 1) // 2, 139 | bias=bias)) 140 | 141 | def forward(self, x): 142 | out = [] 143 | for i in range(4): 144 | out.append(self.conv_blocks[i](x)) 145 | out = torch.cat(out, 1) 146 | out = self.conv_fusion(out) + x 147 | return out 148 | 149 | 150 | class UpResBlock(nn.Module): 151 | 152 | def __init__(self, in_channel): 153 | super(UpResBlock, self).__init__() 154 | self.body = nn.Sequential( 155 | nn.Conv2d(in_channel, in_channel, 3, 1, 1), 156 | nn.LeakyReLU(0.2, True), 157 | nn.Conv2d(in_channel, in_channel, 3, 1, 1), 158 | ) 159 | 160 | def forward(self, x): 161 | out = x + self.body(x) 162 | return out 163 | -------------------------------------------------------------------------------- /basicsr/archs/edsr_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | from basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer 5 | from
basicsr.utils.registry import ARCH_REGISTRY 6 | 7 | 8 | @ARCH_REGISTRY.register() 9 | class EDSR(nn.Module): 10 | """EDSR network structure. 11 | 12 | Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution. 13 | Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch 14 | 15 | Args: 16 | num_in_ch (int): Channel number of inputs. 17 | num_out_ch (int): Channel number of outputs. 18 | num_feat (int): Channel number of intermediate features. 19 | Default: 64. 20 | num_block (int): Block number in the trunk network. Default: 16. 21 | upscale (int): Upsampling factor. Support 2^n and 3. 22 | Default: 4. 23 | res_scale (float): Used to scale the residual in residual block. 24 | Default: 1. 25 | img_range (float): Image range. Default: 255. 26 | rgb_mean (tuple[float]): Image mean in RGB orders. 27 | Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. 28 | """ 29 | 30 | def __init__(self, 31 | num_in_ch, 32 | num_out_ch, 33 | num_feat=64, 34 | num_block=16, 35 | upscale=4, 36 | res_scale=1, 37 | img_range=255., 38 | rgb_mean=(0.4488, 0.4371, 0.4040)): 39 | super(EDSR, self).__init__() 40 | 41 | self.img_range = img_range 42 | self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) 43 | 44 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 45 | self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True) 46 | self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 47 | self.upsample = Upsample(upscale, num_feat) 48 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 49 | 50 | def forward(self, x): 51 | self.mean = self.mean.type_as(x) 52 | 53 | x = (x - self.mean) * self.img_range 54 | x = self.conv_first(x) 55 | res = self.conv_after_body(self.body(x)) 56 | res += x 57 | 58 | x = self.conv_last(self.upsample(res)) 59 | x = x / self.img_range + self.mean 60 | 61 | return x 62 | -------------------------------------------------------------------------------- /basicsr/archs/rcan_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | from .arch_util import Upsample, make_layer 6 | 7 | 8 | class ChannelAttention(nn.Module): 9 | """Channel attention used in RCAN. 10 | 11 | Args: 12 | num_feat (int): Channel number of intermediate features. 13 | squeeze_factor (int): Channel squeeze factor. Default: 16. 14 | """ 15 | 16 | def __init__(self, num_feat, squeeze_factor=16): 17 | super(ChannelAttention, self).__init__() 18 | self.attention = nn.Sequential( 19 | nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0), 20 | nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid()) 21 | 22 | def forward(self, x): 23 | y = self.attention(x) 24 | return x * y 25 | 26 | 27 | class RCAB(nn.Module): 28 | """Residual Channel Attention Block (RCAB) used in RCAN. 29 | 30 | Args: 31 | num_feat (int): Channel number of intermediate features. 32 | squeeze_factor (int): Channel squeeze factor. Default: 16. 33 | res_scale (float): Scale the residual. Default: 1. 
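        Example (an illustrative shape check, not part of the original codebase;
            the sizes are arbitrary and the block preserves spatial resolution):
            >>> import torch
            >>> block = RCAB(num_feat=64)
            >>> block(torch.randn(2, 64, 48, 48)).shape
            torch.Size([2, 64, 48, 48])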
34 | """ 35 | 36 | def __init__(self, num_feat, squeeze_factor=16, res_scale=1): 37 | super(RCAB, self).__init__() 38 | self.res_scale = res_scale 39 | 40 | self.rcab = nn.Sequential( 41 | nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1), 42 | ChannelAttention(num_feat, squeeze_factor)) 43 | 44 | def forward(self, x): 45 | res = self.rcab(x) * self.res_scale 46 | return res + x 47 | 48 | 49 | class ResidualGroup(nn.Module): 50 | """Residual Group of RCAB. 51 | 52 | Args: 53 | num_feat (int): Channel number of intermediate features. 54 | num_block (int): Block number in the body network. 55 | squeeze_factor (int): Channel squeeze factor. Default: 16. 56 | res_scale (float): Scale the residual. Default: 1. 57 | """ 58 | 59 | def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1): 60 | super(ResidualGroup, self).__init__() 61 | 62 | self.residual_group = make_layer( 63 | RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale) 64 | self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 65 | 66 | def forward(self, x): 67 | res = self.conv(self.residual_group(x)) 68 | return res + x 69 | 70 | 71 | @ARCH_REGISTRY.register() 72 | class RCAN(nn.Module): 73 | """Residual Channel Attention Networks. 74 | 75 | ``Paper: Image Super-Resolution Using Very Deep Residual Channel Attention Networks`` 76 | 77 | Reference: https://github.com/yulunzhang/RCAN 78 | 79 | Args: 80 | num_in_ch (int): Channel number of inputs. 81 | num_out_ch (int): Channel number of outputs. 82 | num_feat (int): Channel number of intermediate features. 83 | Default: 64. 84 | num_group (int): Number of ResidualGroup. Default: 10. 85 | num_block (int): Number of RCAB in ResidualGroup. Default: 16. 86 | squeeze_factor (int): Channel squeeze factor. Default: 16. 87 | upscale (int): Upsampling factor. Support 2^n and 3. 88 | Default: 4. 89 | res_scale (float): Used to scale the residual in residual block. 90 | Default: 1. 91 | img_range (float): Image range. Default: 255. 92 | rgb_mean (tuple[float]): Image mean in RGB orders. 93 | Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. 
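        Example (an illustrative x4 forward pass; the hyper-parameters here are
            deliberately small and are not the defaults listed above):
            >>> import torch
            >>> net = RCAN(num_in_ch=3, num_out_ch=3, num_feat=64, num_group=2, num_block=2, upscale=4)
            >>> net(torch.randn(1, 3, 32, 32)).shape
            torch.Size([1, 3, 128, 128])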
94 | """ 95 | 96 | def __init__(self, 97 | num_in_ch, 98 | num_out_ch, 99 | num_feat=64, 100 | num_group=10, 101 | num_block=16, 102 | squeeze_factor=16, 103 | upscale=4, 104 | res_scale=1, 105 | img_range=255., 106 | rgb_mean=(0.4488, 0.4371, 0.4040)): 107 | super(RCAN, self).__init__() 108 | 109 | self.img_range = img_range 110 | self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) 111 | 112 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 113 | self.body = make_layer( 114 | ResidualGroup, 115 | num_group, 116 | num_feat=num_feat, 117 | num_block=num_block, 118 | squeeze_factor=squeeze_factor, 119 | res_scale=res_scale) 120 | self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 121 | self.upsample = Upsample(upscale, num_feat) 122 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 123 | 124 | def forward(self, x): 125 | self.mean = self.mean.type_as(x) 126 | 127 | x = (x - self.mean) * self.img_range 128 | x = self.conv_first(x) 129 | res = self.conv_after_body(self.body(x)) 130 | res += x 131 | 132 | x = self.conv_last(self.upsample(res)) 133 | x = x / self.img_range + self.mean 134 | 135 | return x 136 | -------------------------------------------------------------------------------- /basicsr/archs/rrdbnet_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | from .arch_util import default_init_weights, make_layer, pixel_unshuffle 7 | 8 | 9 | class ResidualDenseBlock(nn.Module): 10 | """Residual Dense Block. 11 | 12 | Used in RRDB block in ESRGAN. 13 | 14 | Args: 15 | num_feat (int): Channel number of intermediate features. 16 | num_grow_ch (int): Channels for each growth. 17 | """ 18 | 19 | def __init__(self, num_feat=64, num_grow_ch=32): 20 | super(ResidualDenseBlock, self).__init__() 21 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 22 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 23 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) 24 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) 25 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 26 | 27 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | 29 | # initialization 30 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 31 | 32 | def forward(self, x): 33 | x1 = self.lrelu(self.conv1(x)) 34 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 35 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 36 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 37 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 38 | # Empirically, we use 0.2 to scale the residual for better performance 39 | return x5 * 0.2 + x 40 | 41 | 42 | class RRDB(nn.Module): 43 | """Residual in Residual Dense Block. 44 | 45 | Used in RRDB-Net in ESRGAN. 46 | 47 | Args: 48 | num_feat (int): Channel number of intermediate features. 49 | num_grow_ch (int): Channels for each growth. 
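        Example (a minimal sketch of stacking RRDB blocks with ``make_layer``
            imported above; the shapes are illustrative and spatial size is preserved):
            >>> import torch
            >>> trunk = make_layer(RRDB, 2, num_feat=64, num_grow_ch=32)
            >>> trunk(torch.randn(1, 64, 24, 24)).shape
            torch.Size([1, 64, 24, 24])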
50 | """ 51 | 52 | def __init__(self, num_feat, num_grow_ch=32): 53 | super(RRDB, self).__init__() 54 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 55 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 56 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 57 | 58 | def forward(self, x): 59 | out = self.rdb1(x) 60 | out = self.rdb2(out) 61 | out = self.rdb3(out) 62 | # Empirically, we use 0.2 to scale the residual for better performance 63 | return out * 0.2 + x 64 | 65 | 66 | @ARCH_REGISTRY.register() 67 | class RRDBNet(nn.Module): 68 | """Networks consisting of Residual in Residual Dense Blocks, which are used 69 | in ESRGAN. 70 | 71 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 72 | 73 | We extend ESRGAN for scale x2 and scale x1. 74 | Note: This is one option for scale 1, scale 2 in RRDBNet. 75 | We first employ pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size 76 | and enlarge the channel size before feeding inputs into the main ESRGAN architecture. 77 | 78 | Args: 79 | num_in_ch (int): Channel number of inputs. 80 | num_out_ch (int): Channel number of outputs. 81 | num_feat (int): Channel number of intermediate features. 82 | Default: 64. 83 | num_block (int): Block number in the trunk network. Default: 23. 84 | num_grow_ch (int): Channels for each growth. Default: 32. 85 | """ 86 | 87 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): 88 | super(RRDBNet, self).__init__() 89 | self.scale = scale 90 | if scale == 2: 91 | num_in_ch = num_in_ch * 4 92 | elif scale == 1: 93 | num_in_ch = num_in_ch * 16 94 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 95 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) 96 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 97 | # upsample 98 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 99 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 100 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 101 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 102 | 103 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 104 | 105 | def forward(self, x): 106 | if self.scale == 2: 107 | feat = pixel_unshuffle(x, scale=2) 108 | elif self.scale == 1: 109 | feat = pixel_unshuffle(x, scale=4) 110 | else: 111 | feat = x 112 | feat = self.conv_first(feat) 113 | body_feat = self.conv_body(self.body(feat)) 114 | feat = feat + body_feat 115 | # upsample 116 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) 117 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) 118 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 119 | return out 120 | -------------------------------------------------------------------------------- /basicsr/archs/spynet_arch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn as nn 4 | from torch.nn import functional as F 5 | 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | from .arch_util import flow_warp 8 | 9 | 10 | class BasicModule(nn.Module): 11 | """Basic Module for SpyNet.
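    Each pyramid level maps an 8-channel input (3-channel reference image, 3-channel warped supporting image, and 2-channel upsampled flow, concatenated in ``process`` below) to a 2-channel flow residual. A minimal shape check with illustrative sizes:
        >>> import torch
        >>> m = BasicModule()
        >>> m(torch.randn(1, 8, 64, 64)).shape
        torch.Size([1, 2, 64, 64])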
12 | """ 13 | 14 | def __init__(self): 15 | super(BasicModule, self).__init__() 16 | 17 | self.basic_module = nn.Sequential( 18 | nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 19 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 20 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 21 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 22 | nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) 23 | 24 | def forward(self, tensor_input): 25 | return self.basic_module(tensor_input) 26 | 27 | 28 | @ARCH_REGISTRY.register() 29 | class SpyNet(nn.Module): 30 | """SpyNet architecture. 31 | 32 | Args: 33 | load_path (str): path for pretrained SpyNet. Default: None. 34 | """ 35 | 36 | def __init__(self, load_path=None): 37 | super(SpyNet, self).__init__() 38 | self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)]) 39 | if load_path: 40 | self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) 41 | 42 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 43 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 44 | 45 | def preprocess(self, tensor_input): 46 | tensor_output = (tensor_input - self.mean) / self.std 47 | return tensor_output 48 | 49 | def process(self, ref, supp): 50 | flow = [] 51 | 52 | ref = [self.preprocess(ref)] 53 | supp = [self.preprocess(supp)] 54 | 55 | for level in range(5): 56 | ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) 57 | supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) 58 | 59 | flow = ref[0].new_zeros( 60 | [ref[0].size(0), 2, 61 | int(math.floor(ref[0].size(2) / 2.0)), 62 | int(math.floor(ref[0].size(3) / 2.0))]) 63 | 64 | for level in range(len(ref)): 65 | upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 66 | 67 | if upsampled_flow.size(2) != ref[level].size(2): 68 | upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate') 69 | if upsampled_flow.size(3) != ref[level].size(3): 70 | upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate') 71 | 72 | flow = self.basic_module[level](torch.cat([ 73 | ref[level], 74 | flow_warp( 75 | supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'), 76 | upsampled_flow 77 | ], 1)) + upsampled_flow 78 | 79 | return flow 80 | 81 | def forward(self, ref, supp): 82 | assert ref.size() == supp.size() 83 | 84 | h, w = ref.size(2), ref.size(3) 85 | w_floor = math.floor(math.ceil(w / 32.0) * 32.0) 86 | h_floor = math.floor(math.ceil(h / 32.0) * 32.0) 87 | 88 | ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False) 89 | supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False) 90 | 91 | flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False) 92 | 93 | flow[:, 0, :, :] *= float(w) / float(w_floor) 94 | flow[:, 1, :, :] *= float(h) / float(h_floor) 95 | 96 | return flow 97 | -------------------------------------------------------------------------------- /basicsr/archs/srresnet_arch.py: 
-------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer 6 | 7 | 8 | @ARCH_REGISTRY.register() 9 | class MSRResNet(nn.Module): 10 | """Modified SRResNet. 11 | 12 | A compacted version modified from SRResNet in 13 | "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" 14 | It uses residual blocks without BN, similar to EDSR. 15 | Currently, it supports x2, x3 and x4 upsampling scale factor. 16 | 17 | Args: 18 | num_in_ch (int): Channel number of inputs. Default: 3. 19 | num_out_ch (int): Channel number of outputs. Default: 3. 20 | num_feat (int): Channel number of intermediate features. Default: 64. 21 | num_block (int): Block number in the body network. Default: 16. 22 | upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4. 23 | """ 24 | 25 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4): 26 | super(MSRResNet, self).__init__() 27 | self.upscale = upscale 28 | 29 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 30 | self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat) 31 | 32 | # upsampling 33 | if self.upscale in [2, 3]: 34 | self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1) 35 | self.pixel_shuffle = nn.PixelShuffle(self.upscale) 36 | elif self.upscale == 4: 37 | self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) 38 | self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) 39 | self.pixel_shuffle = nn.PixelShuffle(2) 40 | 41 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 42 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 43 | 44 | # activation function 45 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 46 | 47 | # initialization 48 | default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1) 49 | if self.upscale == 4: 50 | default_init_weights(self.upconv2, 0.1) 51 | 52 | def forward(self, x): 53 | feat = self.lrelu(self.conv_first(x)) 54 | out = self.body(feat) 55 | 56 | if self.upscale == 4: 57 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 58 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) 59 | elif self.upscale in [2, 3]: 60 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 61 | 62 | out = self.conv_last(self.lrelu(self.conv_hr(out))) 63 | base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False) 64 | out += base 65 | return out 66 | -------------------------------------------------------------------------------- /basicsr/archs/srvgg_arch.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | 6 | 7 | @ARCH_REGISTRY.register(suffix='basicsr') 8 | class SRVGGNetCompact(nn.Module): 9 | """A compact VGG-style network structure for super-resolution. 10 | 11 | It is a compact network structure, which performs upsampling in the last layer and no convolution is 12 | conducted on the HR feature space. 13 | 14 | Args: 15 | num_in_ch (int): Channel number of inputs. Default: 3. 16 | num_out_ch (int): Channel number of outputs. Default: 3. 17 | num_feat (int): Channel number of intermediate features. Default: 64. 
18 | num_conv (int): Number of convolution layers in the body network. Default: 16. 19 | upscale (int): Upsampling factor. Default: 4. 20 | act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. 21 | """ 22 | 23 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): 24 | super(SRVGGNetCompact, self).__init__() 25 | self.num_in_ch = num_in_ch 26 | self.num_out_ch = num_out_ch 27 | self.num_feat = num_feat 28 | self.num_conv = num_conv 29 | self.upscale = upscale 30 | self.act_type = act_type 31 | 32 | self.body = nn.ModuleList() 33 | # the first conv 34 | self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) 35 | # the first activation 36 | if act_type == 'relu': 37 | activation = nn.ReLU(inplace=True) 38 | elif act_type == 'prelu': 39 | activation = nn.PReLU(num_parameters=num_feat) 40 | elif act_type == 'leakyrelu': 41 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 42 | self.body.append(activation) 43 | 44 | # the body structure 45 | for _ in range(num_conv): 46 | self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) 47 | # activation 48 | if act_type == 'relu': 49 | activation = nn.ReLU(inplace=True) 50 | elif act_type == 'prelu': 51 | activation = nn.PReLU(num_parameters=num_feat) 52 | elif act_type == 'leakyrelu': 53 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 54 | self.body.append(activation) 55 | 56 | # the last conv 57 | self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) 58 | # upsample 59 | self.upsampler = nn.PixelShuffle(upscale) 60 | 61 | def forward(self, x): 62 | out = x 63 | for i in range(0, len(self.body)): 64 | out = self.body[i](out) 65 | 66 | out = self.upsampler(out) 67 | # add the nearest upsampled image, so that the network learns the residual 68 | base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') 69 | out += base 70 | return out 71 | -------------------------------------------------------------------------------- /basicsr/archs/tof_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | from .arch_util import flow_warp 7 | 8 | 9 | class BasicModule(nn.Module): 10 | """Basic module of SPyNet. 11 | 12 | Note that unlike the architecture in spynet_arch.py, the basic module 13 | here contains batch normalization. 14 | """ 15 | 16 | def __init__(self): 17 | super(BasicModule, self).__init__() 18 | self.basic_module = nn.Sequential( 19 | nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False), 20 | nn.BatchNorm2d(32), nn.ReLU(inplace=True), 21 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3, bias=False), 22 | nn.BatchNorm2d(64), nn.ReLU(inplace=True), 23 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False), 24 | nn.BatchNorm2d(32), nn.ReLU(inplace=True), 25 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3, bias=False), 26 | nn.BatchNorm2d(16), nn.ReLU(inplace=True), 27 | nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) 28 | 29 | def forward(self, tensor_input): 30 | """ 31 | Args: 32 | tensor_input (Tensor): Input tensor with shape (b, 8, h, w). 33 | 8 channels contain: 34 | [reference image (3), neighbor image (3), initial flow (2)]. 
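                (In ``SPyNetTOF.forward`` below, this tensor is assembled as ``torch.cat([ref[i], flow_warp(supp[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1)``, i.e. the neighbor image is first warped by the upsampled flow before concatenation.)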
35 | 36 | Returns: 37 | Tensor: Estimated flow with shape (b, 2, h, w) 38 | """ 39 | return self.basic_module(tensor_input) 40 | 41 | 42 | class SPyNetTOF(nn.Module): 43 | """SPyNet architecture for TOF. 44 | 45 | Note that this implementation is specifically for TOFlow. Please use :file:`spynet_arch.py` for general use. 46 | They differ in the following aspects: 47 | 48 | 1. The basic modules here contain BatchNorm. 49 | 2. Normalization and denormalization are not done here, as they are done in TOFlow. 50 | 51 | ``Paper: Optical Flow Estimation using a Spatial Pyramid Network`` 52 | 53 | Reference: https://github.com/Coldog2333/pytoflow 54 | 55 | Args: 56 | load_path (str): Path for pretrained SPyNet. Default: None. 57 | """ 58 | 59 | def __init__(self, load_path=None): 60 | super(SPyNetTOF, self).__init__() 61 | 62 | self.basic_module = nn.ModuleList([BasicModule() for _ in range(4)]) 63 | if load_path: 64 | self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) 65 | 66 | def forward(self, ref, supp): 67 | """ 68 | Args: 69 | ref (Tensor): Reference image with shape of (b, 3, h, w). 70 | supp: The supporting image to be warped: (b, 3, h, w). 71 | 72 | Returns: 73 | Tensor: Estimated optical flow: (b, 2, h, w). 74 | """ 75 | num_batches, _, h, w = ref.size() 76 | ref = [ref] 77 | supp = [supp] 78 | 79 | # generate downsampled frames 80 | for _ in range(3): 81 | ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) 82 | supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) 83 | 84 | # flow computation 85 | flow = ref[0].new_zeros(num_batches, 2, h // 16, w // 16) 86 | for i in range(4): 87 | flow_up = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 88 | flow = flow_up + self.basic_module[i]( 89 | torch.cat([ref[i], flow_warp(supp[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1)) 90 | return flow 91 | 92 | 93 | @ARCH_REGISTRY.register() 94 | class TOFlow(nn.Module): 95 | """PyTorch implementation of TOFlow. 96 | 97 | In TOFlow, the LR frames are pre-upsampled and have the same size with the GT frames. 98 | 99 | ``Paper: Video Enhancement with Task-Oriented Flow`` 100 | 101 | Reference: https://github.com/anchen1011/toflow 102 | 103 | Reference: https://github.com/Coldog2333/pytoflow 104 | 105 | Args: 106 | adapt_official_weights (bool): Whether to adapt the weights translated 107 | from the official implementation. Set to false if you want to 108 | train from scratch. 
Default: False 109 | """ 110 | 111 | def __init__(self, adapt_official_weights=False): 112 | super(TOFlow, self).__init__() 113 | self.adapt_official_weights = adapt_official_weights 114 | self.ref_idx = 0 if adapt_official_weights else 3 115 | 116 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 117 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 118 | 119 | # flow estimation module 120 | self.spynet = SPyNetTOF() 121 | 122 | # reconstruction module 123 | self.conv_1 = nn.Conv2d(3 * 7, 64, 9, 1, 4) 124 | self.conv_2 = nn.Conv2d(64, 64, 9, 1, 4) 125 | self.conv_3 = nn.Conv2d(64, 64, 1) 126 | self.conv_4 = nn.Conv2d(64, 3, 1) 127 | 128 | # activation function 129 | self.relu = nn.ReLU(inplace=True) 130 | 131 | def normalize(self, img): 132 | return (img - self.mean) / self.std 133 | 134 | def denormalize(self, img): 135 | return img * self.std + self.mean 136 | 137 | def forward(self, lrs): 138 | """ 139 | Args: 140 | lrs: Input lr frames: (b, 7, 3, h, w). 141 | 142 | Returns: 143 | Tensor: SR frame: (b, 3, h, w). 144 | """ 145 | # In the official implementation, the 0-th frame is the reference frame 146 | if self.adapt_official_weights: 147 | lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6], :, :, :] 148 | 149 | num_batches, num_lrs, _, h, w = lrs.size() 150 | 151 | lrs = self.normalize(lrs.view(-1, 3, h, w)) 152 | lrs = lrs.view(num_batches, num_lrs, 3, h, w) 153 | 154 | lr_ref = lrs[:, self.ref_idx, :, :, :] 155 | lr_aligned = [] 156 | for i in range(7): # 7 frames 157 | if i == self.ref_idx: 158 | lr_aligned.append(lr_ref) 159 | else: 160 | lr_supp = lrs[:, i, :, :, :] 161 | flow = self.spynet(lr_ref, lr_supp) 162 | lr_aligned.append(flow_warp(lr_supp, flow.permute(0, 2, 3, 1))) 163 | 164 | # reconstruction 165 | hr = torch.stack(lr_aligned, dim=1) 166 | hr = hr.view(num_batches, -1, h, w) 167 | hr = self.relu(self.conv_1(hr)) 168 | hr = self.relu(self.conv_2(hr)) 169 | hr = self.relu(self.conv_3(hr)) 170 | hr = self.conv_4(hr) + lr_ref 171 | 172 | return self.denormalize(hr) 173 | -------------------------------------------------------------------------------- /basicsr/archs/vgg_arch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from torch import nn as nn 5 | from torchvision.models import vgg as vgg 6 | 7 | from basicsr.utils.registry import ARCH_REGISTRY 8 | 9 | VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' 10 | NAMES = { 11 | 'vgg11': [ 12 | 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 13 | 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 14 | 'pool5' 15 | ], 16 | 'vgg13': [ 17 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 18 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 19 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' 20 | ], 21 | 'vgg16': [ 22 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 23 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 24 | 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 25 | 'pool5' 26 | ], 27 | 'vgg19': [ 28 | 'conv1_1', 'relu1_1', 
'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 29 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', 30 | 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', 31 | 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' 32 | ] 33 | } 34 | 35 | 36 | def insert_bn(names): 37 | """Insert bn layer after each conv. 38 | 39 | Args: 40 | names (list): The list of layer names. 41 | 42 | Returns: 43 | list: The list of layer names with bn layers. 44 | """ 45 | names_bn = [] 46 | for name in names: 47 | names_bn.append(name) 48 | if 'conv' in name: 49 | position = name.replace('conv', '') 50 | names_bn.append('bn' + position) 51 | return names_bn 52 | 53 | 54 | @ARCH_REGISTRY.register() 55 | class VGGFeatureExtractor(nn.Module): 56 | """VGG network for feature extraction. 57 | 58 | In this implementation, we allow users to choose whether to normalize 59 | the input features and which type of VGG network to use. Note that the 60 | pretrained path must match the vgg type. 61 | 62 | Args: 63 | layer_name_list (list[str]): Forward function returns the corresponding 64 | features according to the layer_name_list. 65 | Example: ['relu1_1', 'relu2_1', 'relu3_1']. 66 | vgg_type (str): Set the type of vgg network. Default: 'vgg19'. 67 | use_input_norm (bool): If True, normalize the input image. Importantly, 68 | the input must be in the range [0, 1]. Default: True. 69 | range_norm (bool): If True, normalize images from the range [-1, 1] to [0, 1]. 70 | Default: False. 71 | requires_grad (bool): If True, the parameters of the VGG network will be 72 | optimized. Default: False. 73 | remove_pooling (bool): If True, the max pooling operations in the VGG net 74 | will be removed. Default: False. 75 | pooling_stride (int): The stride of the max pooling operation. Default: 2.
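Example (a minimal usage sketch; it assumes the pretrained VGG19
weights can be loaded, either from VGG_PRETRAIN_PATH or via the
torchvision download, and the shapes shown are for a 64x64 input):
    >>> import torch
    >>> extractor = VGGFeatureExtractor(layer_name_list=['relu1_1', 'relu3_1'])
    >>> feats = extractor(torch.rand(1, 3, 64, 64))
    >>> [(k, tuple(v.shape)) for k, v in feats.items()]
    [('relu1_1', (1, 64, 64, 64)), ('relu3_1', (1, 256, 16, 16))]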
76 | """ 77 | 78 | def __init__(self, 79 | layer_name_list, 80 | vgg_type='vgg19', 81 | use_input_norm=True, 82 | range_norm=False, 83 | requires_grad=False, 84 | remove_pooling=False, 85 | pooling_stride=2): 86 | super(VGGFeatureExtractor, self).__init__() 87 | 88 | self.layer_name_list = layer_name_list 89 | self.use_input_norm = use_input_norm 90 | self.range_norm = range_norm 91 | 92 | self.names = NAMES[vgg_type.replace('_bn', '')] 93 | if 'bn' in vgg_type: 94 | self.names = insert_bn(self.names) 95 | 96 | # only borrow layers that will be used to avoid unused params 97 | max_idx = 0 98 | for v in layer_name_list: 99 | idx = self.names.index(v) 100 | if idx > max_idx: 101 | max_idx = idx 102 | 103 | if os.path.exists(VGG_PRETRAIN_PATH): 104 | vgg_net = getattr(vgg, vgg_type)(pretrained=False) 105 | state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) 106 | vgg_net.load_state_dict(state_dict) 107 | else: 108 | vgg_net = getattr(vgg, vgg_type)(pretrained=True) 109 | 110 | features = vgg_net.features[:max_idx + 1] 111 | 112 | modified_net = OrderedDict() 113 | for k, v in zip(self.names, features): 114 | if 'pool' in k: 115 | # if remove_pooling is true, pooling operation will be removed 116 | if remove_pooling: 117 | continue 118 | else: 119 | # in some cases, we may want to change the default stride 120 | modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) 121 | else: 122 | modified_net[k] = v 123 | 124 | self.vgg_net = nn.Sequential(modified_net) 125 | 126 | if not requires_grad: 127 | self.vgg_net.eval() 128 | for param in self.parameters(): 129 | param.requires_grad = False 130 | else: 131 | self.vgg_net.train() 132 | for param in self.parameters(): 133 | param.requires_grad = True 134 | 135 | if self.use_input_norm: 136 | # the mean is for image with range [0, 1] 137 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 138 | # the std is for image with range [0, 1] 139 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 140 | 141 | def forward(self, x): 142 | """Forward function. 143 | 144 | Args: 145 | x (Tensor): Input tensor with shape (n, c, h, w). 146 | 147 | Returns: 148 | Tensor: Forward results. 
149 | """ 150 | if self.range_norm: 151 | x = (x + 1) / 2 152 | if self.use_input_norm: 153 | x = (x - self.mean) / self.std 154 | 155 | output = {} 156 | for key, layer in self.vgg_net._modules.items(): 157 | x = layer(x) 158 | if key in self.layer_name_list: 159 | output[key] = x.clone() 160 | 161 | return output 162 | -------------------------------------------------------------------------------- /basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from copy import deepcopy 7 | from functools import partial 8 | from os import path as osp 9 | 10 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 11 | from basicsr.utils import get_root_logger, scandir 12 | from basicsr.utils.dist_util import get_dist_info 13 | from basicsr.utils.registry import DATASET_REGISTRY 14 | 15 | __all__ = ['build_dataset', 'build_dataloader'] 16 | 17 | # automatically scan and import dataset modules for registry 18 | # scan all the files under the data folder with '_dataset' in file names 19 | data_folder = osp.dirname(osp.abspath(__file__)) 20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 21 | # import all the dataset modules 22 | _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] 23 | 24 | 25 | def build_dataset(dataset_opt): 26 | """Build dataset from options. 27 | 28 | Args: 29 | dataset_opt (dict): Configuration for dataset. It must contain: 30 | name (str): Dataset name. 31 | type (str): Dataset type. 32 | """ 33 | dataset_opt = deepcopy(dataset_opt) 34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) 35 | logger = get_root_logger() 36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') 37 | return dataset 38 | 39 | 40 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): 41 | """Build dataloader. 42 | 43 | Args: 44 | dataset (torch.utils.data.Dataset): Dataset. 45 | dataset_opt (dict): Dataset options. It contains the following keys: 46 | phase (str): 'train' or 'val'. 47 | num_worker_per_gpu (int): Number of workers for each GPU. 48 | batch_size_per_gpu (int): Training batch size for each GPU. 49 | num_gpu (int): Number of GPUs. Used only in the train phase. 50 | Default: 1. 51 | dist (bool): Whether in distributed training. Used only in the train 52 | phase. Default: False. 53 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 54 | seed (int | None): Seed. 
Default: None 55 | """ 56 | phase = dataset_opt['phase'] 57 | rank, _ = get_dist_info() 58 | if phase == 'train': 59 | if dist: # distributed training 60 | batch_size = dataset_opt['batch_size_per_gpu'] 61 | num_workers = dataset_opt['num_worker_per_gpu'] 62 | else: # non-distributed training 63 | multiplier = 1 if num_gpu == 0 else num_gpu 64 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 65 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 66 | dataloader_args = dict( 67 | dataset=dataset, 68 | batch_size=batch_size, 69 | shuffle=False, 70 | num_workers=num_workers, 71 | sampler=sampler, 72 | drop_last=True) 73 | if sampler is None: 74 | dataloader_args['shuffle'] = True 75 | dataloader_args['worker_init_fn'] = partial( 76 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None 77 | elif phase in ['val', 'test']: # validation 78 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 79 | else: 80 | raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.") 81 | 82 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 83 | dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) 84 | 85 | prefetch_mode = dataset_opt.get('prefetch_mode') 86 | if prefetch_mode == 'cpu': # CPUPrefetcher 87 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 88 | logger = get_root_logger() 89 | logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') 90 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) 91 | else: 92 | # prefetch_mode=None: Normal dataloader 93 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 94 | return torch.utils.data.DataLoader(**dataloader_args) 95 | 96 | 97 | def worker_init_fn(worker_id, num_workers, rank, seed): 98 | # Set the worker seed to num_workers * rank + worker_id + seed 99 | worker_seed = num_workers * rank + worker_id + seed 100 | np.random.seed(worker_seed) 101 | random.seed(worker_seed) 102 | -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler. 10 | It supports enlarging the dataset for iteration-based training, which saves 11 | time when restarting the dataloader after each epoch. 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1.
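Example (a sketch of how this sampler is typically combined with
``basicsr.data.build_dataloader``; ``dataset`` and ``dataset_opt`` are
assumed to come from ``build_dataset`` and the training options):
    >>> sampler = EnlargedSampler(dataset, num_replicas=1, rank=0, ratio=100)
    >>> loader = build_dataloader(
    ...     dataset, dataset_opt, num_gpu=1, dist=False, sampler=sampler, seed=0)
    >>> sampler.set_epoch(0)  # call once per epoch for a deterministic reshuffle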
19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | def __iter__(self): 30 | # deterministically shuffle based on epoch 31 | g = torch.Generator() 32 | g.manual_seed(self.epoch) 33 | indices = torch.randperm(self.total_size, generator=g).tolist() 34 | 35 | dataset_size = len(self.dataset) 36 | indices = [v % dataset_size for v in indices] 37 | 38 | # subsample 39 | indices = indices[self.rank:self.total_size:self.num_replicas] 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | def set_epoch(self, epoch): 48 | self.epoch = epoch 49 | -------------------------------------------------------------------------------- /basicsr/data/ffhq_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from os import path as osp 4 | from torch.utils import data as data 5 | from torchvision.transforms.functional import normalize 6 | 7 | from basicsr.data.transforms import augment 8 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor 9 | from basicsr.utils.registry import DATASET_REGISTRY 10 | 11 | 12 | @DATASET_REGISTRY.register() 13 | class FFHQDataset(data.Dataset): 14 | """FFHQ dataset for StyleGAN. 15 | 16 | Args: 17 | opt (dict): Config for train datasets. It contains the following keys: 18 | dataroot_gt (str): Data root path for gt. 19 | io_backend (dict): IO backend type and other kwarg. 20 | mean (list | tuple): Image mean. 21 | std (list | tuple): Image std. 22 | use_hflip (bool): Whether to horizontally flip. 
23 | 24 | """ 25 | 26 | def __init__(self, opt): 27 | super(FFHQDataset, self).__init__() 28 | self.opt = opt 29 | # file client (io backend) 30 | self.file_client = None 31 | self.io_backend_opt = opt['io_backend'] 32 | 33 | self.gt_folder = opt['dataroot_gt'] 34 | self.mean = opt['mean'] 35 | self.std = opt['std'] 36 | 37 | if self.io_backend_opt['type'] == 'lmdb': 38 | self.io_backend_opt['db_paths'] = self.gt_folder 39 | if not self.gt_folder.endswith('.lmdb'): 40 | raise ValueError("'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") 41 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: 42 | self.paths = [line.split('.')[0] for line in fin] 43 | else: 44 | # FFHQ has 70000 images in total 45 | self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)] 46 | 47 | def __getitem__(self, index): 48 | if self.file_client is None: 49 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | # load gt image 52 | gt_path = self.paths[index] 53 | # avoid errors caused by high latency in reading files 54 | retry = 3 55 | while retry > 0: 56 | try: 57 | img_bytes = self.file_client.get(gt_path) 58 | except Exception as e: 59 | logger = get_root_logger() 60 | logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}') 61 | # change another file to read 62 | index = random.randint(0, self.__len__()) 63 | gt_path = self.paths[index] 64 | time.sleep(1) # sleep 1s for occasional server congestion 65 | else: 66 | break 67 | finally: 68 | retry -= 1 69 | img_gt = imfrombytes(img_bytes, float32=True) 70 | 71 | # random horizontal flip 72 | img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False) 73 | # BGR to RGB, HWC to CHW, numpy to tensor 74 | img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) 75 | # normalize 76 | normalize(img_gt, self.mean, self.std, inplace=True) 77 | return {'gt': img_gt, 'gt_path': gt_path} 78 | 79 | def __len__(self): 80 | return len(self.paths) 81 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDS4_test_GT.txt: -------------------------------------------------------------------------------- 1 | 000 100 (720,1280,3) 2 | 011 100 (720,1280,3) 3 | 015 100 (720,1280,3) 4 | 020 100 (720,1280,3) 5 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDSofficial4_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 246 100 (720,1280,3) 4 | 257 100 (720,1280,3) 5 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDSval_official_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 242 100 (720,1280,3) 4 | 243 100 (720,1280,3) 5 | 244 100 (720,1280,3) 6 | 245 100 (720,1280,3) 7 | 246 100 (720,1280,3) 8 | 247 100 (720,1280,3) 9 | 248 100 (720,1280,3) 10 | 249 100 (720,1280,3) 11 | 250 100 (720,1280,3) 12 | 251 100 (720,1280,3) 13 | 252 100 (720,1280,3) 14 | 253 100 (720,1280,3) 15 | 254 100 (720,1280,3) 16 | 255 100 (720,1280,3) 17 | 256 100 (720,1280,3) 18 | 257 100 (720,1280,3) 19 | 258 100 (720,1280,3) 20 | 259 100 (720,1280,3) 21 | 260 100 (720,1280,3) 22 | 261 100 (720,1280,3) 23 | 262 100 (720,1280,3) 24 | 263 100 (720,1280,3) 25 | 264 100 (720,1280,3) 26 | 265 100 
(720,1280,3) 27 | 266 100 (720,1280,3) 28 | 267 100 (720,1280,3) 29 | 268 100 (720,1280,3) 30 | 269 100 (720,1280,3) 31 | -------------------------------------------------------------------------------- /basicsr/data/paired_image_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data as data 2 | from torchvision.transforms.functional import normalize 3 | 4 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file 5 | from basicsr.data.transforms import augment, paired_random_crop 6 | from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor 7 | from basicsr.utils.registry import DATASET_REGISTRY 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class PairedImageDataset(data.Dataset): 12 | """Paired image dataset for image restoration. 13 | 14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. 15 | 16 | There are three modes: 17 | 18 | 1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb. 19 | 2. **meta_info_file**: Use meta information file to generate paths. \ 20 | If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. 21 | 3. **folder**: Scan folders to generate paths. The rest. 22 | 23 | Args: 24 | opt (dict): Config for train datasets. It contains the following keys: 25 | dataroot_gt (str): Data root path for gt. 26 | dataroot_lq (str): Data root path for lq. 27 | meta_info_file (str): Path for meta information file. 28 | io_backend (dict): IO backend type and other kwarg. 29 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. 30 | Default: '{}'. 31 | gt_size (int): Cropped patch size for gt patches. 32 | use_hflip (bool): Use horizontal flips. 33 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). 34 | scale (int): Scale, which will be added automatically. 35 | phase (str): 'train' or 'val'. 36 | """ 37 | 38 | def __init__(self, opt): 39 | super(PairedImageDataset, self).__init__() 40 | self.opt = opt 41 | # file client (io backend) 42 | self.file_client = None 43 | self.io_backend_opt = opt['io_backend'] 44 | self.mean = opt['mean'] if 'mean' in opt else None 45 | self.std = opt['std'] if 'std' in opt else None 46 | 47 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 48 | if 'filename_tmpl' in opt: 49 | self.filename_tmpl = opt['filename_tmpl'] 50 | else: 51 | self.filename_tmpl = '{}' 52 | 53 | if self.io_backend_opt['type'] == 'lmdb': 54 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 55 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 56 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 57 | elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None: 58 | self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'], 59 | self.opt['meta_info_file'], self.filename_tmpl) 60 | else: 61 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 62 | 63 | def __getitem__(self, index): 64 | if self.file_client is None: 65 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 66 | 67 | scale = self.opt['scale'] 68 | 69 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 70 | # image range: [0, 1], float32.
71 | gt_path = self.paths[index]['gt_path'] 72 | img_bytes = self.file_client.get(gt_path, 'gt') 73 | img_gt = imfrombytes(img_bytes, float32=True) 74 | lq_path = self.paths[index]['lq_path'] 75 | img_bytes = self.file_client.get(lq_path, 'lq') 76 | img_lq = imfrombytes(img_bytes, float32=True) 77 | 78 | # augmentation for training 79 | if self.opt['phase'] == 'train': 80 | gt_size = self.opt['gt_size'] 81 | # random crop 82 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 83 | # flip, rotation 84 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 85 | 86 | # color space transform 87 | if 'color' in self.opt and self.opt['color'] == 'y': 88 | img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None] 89 | img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None] 90 | 91 | # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets 92 | # TODO: It is better to update the datasets, rather than force to crop 93 | if self.opt['phase'] != 'train': 94 | img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] 95 | 96 | # BGR to RGB, HWC to CHW, numpy to tensor 97 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 98 | # normalize 99 | if self.mean is not None or self.std is not None: 100 | normalize(img_lq, self.mean, self.std, inplace=True) 101 | normalize(img_gt, self.mean, self.std, inplace=True) 102 | 103 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 104 | 105 | def __len__(self): 106 | return len(self.paths) 107 | -------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 11 | 12 | Args: 13 | generator: Python generator. 14 | num_prefetch_queue (int): Number of prefetch queue. 15 | """ 16 | 17 | def __init__(self, generator, num_prefetch_queue): 18 | threading.Thread.__init__(self) 19 | self.queue = Queue.Queue(num_prefetch_queue) 20 | self.generator = generator 21 | self.daemon = True 22 | self.start() 23 | 24 | def run(self): 25 | for item in self.generator: 26 | self.queue.put(item) 27 | self.queue.put(None) 28 | 29 | def __next__(self): 30 | next_item = self.queue.get() 31 | if next_item is None: 32 | raise StopIteration 33 | return next_item 34 | 35 | def __iter__(self): 36 | return self 37 | 38 | 39 | class PrefetchDataLoader(DataLoader): 40 | """Prefetch version of dataloader. 41 | 42 | Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 43 | 44 | TODO: 45 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 46 | ddp. 47 | 48 | Args: 49 | num_prefetch_queue (int): Number of prefetch queue. 50 | kwargs (dict): Other arguments for dataloader. 51 | """ 52 | 53 | def __init__(self, num_prefetch_queue, **kwargs): 54 | self.num_prefetch_queue = num_prefetch_queue 55 | super(PrefetchDataLoader, self).__init__(**kwargs) 56 | 57 | def __iter__(self): 58 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 59 | 60 | 61 | class CPUPrefetcher(): 62 | """CPU prefetcher. 63 | 64 | Args: 65 | loader: Dataloader. 
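Example (a minimal sketch; ``loader`` is assumed to be any iterable
dataloader, e.g. one returned by basicsr.data.build_dataloader):
    >>> prefetcher = CPUPrefetcher(loader)
    >>> batch = prefetcher.next()
    >>> while batch is not None:
    ...     pass  # feed ``batch`` to the model here
    ...     batch = prefetcher.next()
    >>> prefetcher.reset()  # rewind for the next epoch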
66 | """ 67 | 68 | def __init__(self, loader): 69 | self.ori_loader = loader 70 | self.loader = iter(loader) 71 | 72 | def next(self): 73 | try: 74 | return next(self.loader) 75 | except StopIteration: 76 | return None 77 | 78 | def reset(self): 79 | self.loader = iter(self.ori_loader) 80 | 81 | 82 | class CUDAPrefetcher(): 83 | """CUDA prefetcher. 84 | 85 | Reference: https://github.com/NVIDIA/apex/issues/304# 86 | 87 | It may consume more GPU memory. 88 | 89 | Args: 90 | loader: Dataloader. 91 | opt (dict): Options. 92 | """ 93 | 94 | def __init__(self, loader, opt): 95 | self.ori_loader = loader 96 | self.loader = iter(loader) 97 | self.opt = opt 98 | self.stream = torch.cuda.Stream() 99 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 100 | self.preload() 101 | 102 | def preload(self): 103 | try: 104 | self.batch = next(self.loader) # self.batch is a dict 105 | except StopIteration: 106 | self.batch = None 107 | return None 108 | # put tensors to gpu 109 | with torch.cuda.stream(self.stream): 110 | for k, v in self.batch.items(): 111 | if torch.is_tensor(v): 112 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 113 | 114 | def next(self): 115 | torch.cuda.current_stream().wait_stream(self.stream) 116 | batch = self.batch 117 | self.preload() 118 | return batch 119 | 120 | def reset(self): 121 | self.loader = iter(self.ori_loader) 122 | self.preload() 123 | -------------------------------------------------------------------------------- /basicsr/data/realesrgan_paired_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb 6 | from basicsr.data.transforms import augment, paired_random_crop 7 | from basicsr.utils import FileClient, imfrombytes, img2tensor 8 | from basicsr.utils.registry import DATASET_REGISTRY 9 | 10 | 11 | @DATASET_REGISTRY.register(suffix='basicsr') 12 | class RealESRGANPairedDataset(data.Dataset): 13 | """Paired image dataset for image restoration. 14 | 15 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. 16 | 17 | There are three modes: 18 | 19 | 1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb. 20 | 2. **meta_info_file**: Use meta information file to generate paths. \ 21 | If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. 22 | 3. **folder**: Scan folders to generate paths. The rest. 23 | 24 | Args: 25 | opt (dict): Config for train datasets. It contains the following keys: 26 | dataroot_gt (str): Data root path for gt. 27 | dataroot_lq (str): Data root path for lq. 28 | meta_info (str): Path for meta information file. 29 | io_backend (dict): IO backend type and other kwarg. 30 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. 31 | Default: '{}'. 32 | gt_size (int): Cropped patched size for gt patches. 33 | use_hflip (bool): Use horizontal flips. 34 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). 35 | scale (bool): Scale, which will be added automatically. 36 | phase (str): 'train' or 'val'. 
37 | """ 38 | 39 | def __init__(self, opt): 40 | super(RealESRGANPairedDataset, self).__init__() 41 | self.opt = opt 42 | self.file_client = None 43 | self.io_backend_opt = opt['io_backend'] 44 | # mean and std for normalizing the input images 45 | self.mean = opt['mean'] if 'mean' in opt else None 46 | self.std = opt['std'] if 'std' in opt else None 47 | 48 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 49 | self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}' 50 | 51 | # file client (lmdb io backend) 52 | if self.io_backend_opt['type'] == 'lmdb': 53 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 54 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 55 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 56 | elif 'meta_info' in self.opt and self.opt['meta_info'] is not None: 57 | # disk backend with meta_info 58 | # Each line in the meta_info describes the relative path to an image 59 | with open(self.opt['meta_info']) as fin: 60 | paths = [line.strip() for line in fin] 61 | self.paths = [] 62 | for path in paths: 63 | gt_path, lq_path = path.split(', ') 64 | gt_path = os.path.join(self.gt_folder, gt_path) 65 | lq_path = os.path.join(self.lq_folder, lq_path) 66 | self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)])) 67 | else: 68 | # disk backend 69 | # it will scan the whole folder to get meta info 70 | # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file 71 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 72 | 73 | def __getitem__(self, index): 74 | if self.file_client is None: 75 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 76 | 77 | scale = self.opt['scale'] 78 | 79 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 80 | # image range: [0, 1], float32. 
81 | gt_path = self.paths[index]['gt_path'] 82 | img_bytes = self.file_client.get(gt_path, 'gt') 83 | img_gt = imfrombytes(img_bytes, float32=True) 84 | lq_path = self.paths[index]['lq_path'] 85 | img_bytes = self.file_client.get(lq_path, 'lq') 86 | img_lq = imfrombytes(img_bytes, float32=True) 87 | 88 | # augmentation for training 89 | if self.opt['phase'] == 'train': 90 | gt_size = self.opt['gt_size'] 91 | # random crop 92 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 93 | # flip, rotation 94 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 95 | 96 | # BGR to RGB, HWC to CHW, numpy to tensor 97 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 98 | # normalize 99 | if self.mean is not None or self.std is not None: 100 | normalize(img_lq, self.mean, self.std, inplace=True) 101 | normalize(img_gt, self.mean, self.std, inplace=True) 102 | 103 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 104 | 105 | def __len__(self): 106 | return len(self.paths) 107 | -------------------------------------------------------------------------------- /basicsr/data/single_image_dataset.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paths_from_lmdb 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir 7 | from basicsr.utils.registry import DATASET_REGISTRY 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class SingleImageDataset(data.Dataset): 12 | """Read only lq images in the test phase. 13 | 14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). 15 | 16 | There are two modes: 17 | 1. 'meta_info_file': Use meta information file to generate paths. 18 | 2. 'folder': Scan folders to generate paths. 19 | 20 | Args: 21 | opt (dict): Config for train datasets. It contains the following keys: 22 | dataroot_lq (str): Data root path for lq. 23 | meta_info_file (str): Path for meta information file. 24 | io_backend (dict): IO backend type and other kwarg. 
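Example (a test-phase sketch; the folder path is an assumption and
should point at a directory of LQ images):
    >>> opt = dict(dataroot_lq='datasets/Set14/LRbicx4', io_backend=dict(type='disk'))
    >>> dataset = SingleImageDataset(opt)
    >>> sample = dataset[0]  # {'lq': CHW tensor in [0, 1], 'lq_path': str}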
25 | """ 26 | 27 | def __init__(self, opt): 28 | super(SingleImageDataset, self).__init__() 29 | self.opt = opt 30 | # file client (io backend) 31 | self.file_client = None 32 | self.io_backend_opt = opt['io_backend'] 33 | self.mean = opt['mean'] if 'mean' in opt else None 34 | self.std = opt['std'] if 'std' in opt else None 35 | self.lq_folder = opt['dataroot_lq'] 36 | 37 | if self.io_backend_opt['type'] == 'lmdb': 38 | self.io_backend_opt['db_paths'] = [self.lq_folder] 39 | self.io_backend_opt['client_keys'] = ['lq'] 40 | self.paths = paths_from_lmdb(self.lq_folder) 41 | elif 'meta_info_file' in self.opt: 42 | with open(self.opt['meta_info_file'], 'r') as fin: 43 | self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin] 44 | else: 45 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) 46 | 47 | def __getitem__(self, index): 48 | if self.file_client is None: 49 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | # load lq image 52 | lq_path = self.paths[index] 53 | img_bytes = self.file_client.get(lq_path, 'lq') 54 | img_lq = imfrombytes(img_bytes, float32=True) 55 | 56 | # color space transform 57 | if 'color' in self.opt and self.opt['color'] == 'y': 58 | img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] 59 | 60 | # BGR to RGB, HWC to CHW, numpy to tensor 61 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) 62 | # normalize 63 | if self.mean is not None or self.std is not None: 64 | normalize(img_lq, self.mean, self.std, inplace=True) 65 | return {'lq': img_lq, 'lq_path': lq_path} 66 | 67 | def __len__(self): 68 | return len(self.paths) 69 | -------------------------------------------------------------------------------- /basicsr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import LOSS_REGISTRY 7 | from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty 8 | 9 | __all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize'] 10 | 11 | # automatically scan and import loss modules for registry 12 | # scan all the files under the 'losses' folder and collect files ending with '_loss.py' 13 | loss_folder = osp.dirname(osp.abspath(__file__)) 14 | loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')] 15 | # import all the loss modules 16 | _model_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames] 17 | 18 | 19 | def build_loss(opt): 20 | """Build loss from options. 21 | 22 | Args: 23 | opt (dict): Configuration. It must contain: 24 | type (str): Model type. 
25 | """ 26 | opt = deepcopy(opt) 27 | loss_type = opt.pop('type') 28 | loss = LOSS_REGISTRY.get(loss_type)(**opt) 29 | logger = get_root_logger() 30 | logger.info(f'Loss [{loss.__class__.__name__}] is created.') 31 | return loss 32 | -------------------------------------------------------------------------------- /basicsr/losses/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/losses/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/losses/__pycache__/basic_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/losses/__pycache__/basic_loss.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/losses/__pycache__/gan_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/losses/__pycache__/gan_loss.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/losses/__pycache__/loss_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/losses/__pycache__/loss_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/losses/__pycache__/loss_util.cpython-37.pyc.139946364727344: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/losses/__pycache__/loss_util.cpython-37.pyc.139946364727344 -------------------------------------------------------------------------------- /basicsr/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def reduce_loss(loss, reduction): 7 | """Reduce loss as specified. 8 | 9 | Args: 10 | loss (Tensor): Elementwise loss tensor. 11 | reduction (str): Options are 'none', 'mean' and 'sum'. 12 | 13 | Returns: 14 | Tensor: Reduced loss tensor. 15 | """ 16 | reduction_enum = F._Reduction.get_enum(reduction) 17 | # none: 0, elementwise_mean:1, sum: 2 18 | if reduction_enum == 0: 19 | return loss 20 | elif reduction_enum == 1: 21 | return loss.mean() 22 | else: 23 | return loss.sum() 24 | 25 | 26 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 27 | """Apply element-wise weight and reduce loss. 28 | 29 | Args: 30 | loss (Tensor): Element-wise loss. 31 | weight (Tensor): Element-wise weights. Default: None. 32 | reduction (str): Same as built-in losses of PyTorch. Options are 33 | 'none', 'mean' and 'sum'. Default: 'mean'. 34 | 35 | Returns: 36 | Tensor: Loss values. 
37 | """ 38 | # if weight is specified, apply element-wise weight 39 | if weight is not None: 40 | assert weight.dim() == loss.dim() 41 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 42 | loss = loss * weight 43 | 44 | # if weight is not specified or reduction is sum, just reduce the loss 45 | if weight is None or reduction == 'sum': 46 | loss = reduce_loss(loss, reduction) 47 | # if reduction is mean, then compute mean over weight region 48 | elif reduction == 'mean': 49 | if weight.size(1) > 1: 50 | weight = weight.sum() 51 | else: 52 | weight = weight.sum() * loss.size(1) 53 | loss = loss.sum() / weight 54 | 55 | return loss 56 | 57 | 58 | def weighted_loss(loss_func): 59 | """Create a weighted version of a given loss function. 60 | 61 | To use this decorator, the loss function must have the signature like 62 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 63 | element-wise loss without any reduction. This decorator will add weight 64 | and reduction arguments to the function. The decorated function will have 65 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 66 | **kwargs)`. 67 | 68 | :Example: 69 | 70 | >>> import torch 71 | >>> @weighted_loss 72 | >>> def l1_loss(pred, target): 73 | >>> return (pred - target).abs() 74 | 75 | >>> pred = torch.Tensor([0, 2, 3]) 76 | >>> target = torch.Tensor([1, 1, 1]) 77 | >>> weight = torch.Tensor([1, 0, 1]) 78 | 79 | >>> l1_loss(pred, target) 80 | tensor(1.3333) 81 | >>> l1_loss(pred, target, weight) 82 | tensor(1.5000) 83 | >>> l1_loss(pred, target, reduction='none') 84 | tensor([1., 1., 2.]) 85 | >>> l1_loss(pred, target, weight, reduction='sum') 86 | tensor(3.) 87 | """ 88 | 89 | @functools.wraps(loss_func) 90 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 91 | # get element-wise loss 92 | loss = loss_func(pred, target, **kwargs) 93 | loss = weight_reduce_loss(loss, weight, reduction) 94 | return loss 95 | 96 | return wrapper 97 | 98 | 99 | def get_local_weights(residual, ksize): 100 | """Get local weights for generating the artifact map of LDL. 101 | 102 | It is only called by the `get_refined_artifact_map` function. 103 | 104 | Args: 105 | residual (Tensor): Residual between predicted and ground truth images. 106 | ksize (Int): size of the local window. 107 | 108 | Returns: 109 | Tensor: weight for each pixel to be discriminated as an artifact pixel 110 | """ 111 | 112 | pad = (ksize - 1) // 2 113 | residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect') 114 | 115 | unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1) 116 | pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1) 117 | 118 | return pixel_level_weight 119 | 120 | 121 | def get_refined_artifact_map(img_gt, img_output, img_ema, ksize): 122 | """Calculate the artifact map of LDL 123 | (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022) 124 | 125 | Args: 126 | img_gt (Tensor): ground truth images. 127 | img_output (Tensor): output images given by the optimizing model. 128 | img_ema (Tensor): output images given by the ema model. 129 | ksize (Int): size of the local window. 130 | 131 | Returns: 132 | overall_weight: weight for each pixel to be discriminated as an artifact pixel 133 | (calculated based on both local and global observations). 
134 | """ 135 | 136 | residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True) 137 | residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True) 138 | 139 | patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5) 140 | pixel_level_weight = get_local_weights(residual_sr.clone(), ksize) 141 | overall_weight = patch_level_weight * pixel_level_weight 142 | 143 | overall_weight[residual_sr < residual_ema] = 0 144 | 145 | return overall_weight 146 | -------------------------------------------------------------------------------- /basicsr/metrics/README.md: -------------------------------------------------------------------------------- 1 | # Metrics 2 | 3 | [English](README.md) **|** [简体中文](README_CN.md) 4 | 5 | - [约定](#约定) 6 | - [PSNR 和 SSIM](#psnr-和-ssim) 7 | 8 | ## 约定 9 | 10 | 因为不同的输入类型会导致结果的不同,因此我们对输入做如下约定: 11 | 12 | - Numpy 类型 (一般是 cv2 的结果) 13 | - UINT8: BGR, [0, 255], (h, w, c) 14 | - float: BGR, [0, 1], (h, w, c). 一般作为中间结果 15 | - Tensor 类型 16 | - float: RGB, [0, 1], (n, c, h, w) 17 | 18 | 其他约定: 19 | 20 | - 以 `_pt` 结尾的是 PyTorch 结果 21 | - PyTorch version 支持 batch 计算 22 | - 颜色转换在 float32 上做;metric计算在 float64 上做 23 | 24 | ## PSNR 和 SSIM 25 | 26 | PSNR 和 SSIM 的结果趋势是一致的,即一般 PSNR 高,则 SSIM 也高。 27 | 在实现上, PSNR 的各种实现都很一致。SSIM 有各种各样的实现,我们这里和 MATLAB 最原始版本保持 (参考 [NTIRE17比赛](https://competitions.codalab.org/competitions/16306#participate) 的 [evaluation代码](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378)) 28 | 29 | 下面列了各个实现的结果比对. 30 | 总结:PyTorch 实现和 MATLAB 实现基本一致,在 GPU 运行上会有稍许差异 31 | 32 | - PSNR 比对 33 | 34 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | 35 | |:---| :---: | :---: | :---: | :---: | :---: | 36 | |baboon| RGB | 20.419710 | 20.419710 | 20.419710 |20.419710 | 37 | |baboon| Y | - |22.441898 | 22.441899 | 22.444916| 38 | |comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 | 39 | |comic | Y | - | 21.720398 | 21.720398 | 21.721663| 40 | 41 | - SSIM 比对 42 | 43 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | 44 | |:---| :---: | :---: | :---: | :---: | :---: | 45 | |baboon| RGB | 0.391853 | 0.391853 | 0.391853|0.391853 | 46 | |baboon| Y | - |0.453097| 0.453097 | 0.453171| 47 | |comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738| 48 | |comic | Y | - | 0.585511 | 0.585511 | 0.585522 | 49 | -------------------------------------------------------------------------------- /basicsr/metrics/README_CN.md: -------------------------------------------------------------------------------- 1 | # Metrics 2 | 3 | [English](README.md) **|** [简体中文](README_CN.md) 4 | 5 | - [约定](#约定) 6 | - [PSNR 和 SSIM](#psnr-和-ssim) 7 | 8 | ## 约定 9 | 10 | 因为不同的输入类型会导致结果的不同,因此我们对输入做如下约定: 11 | 12 | - Numpy 类型 (一般是 cv2 的结果) 13 | - UINT8: BGR, [0, 255], (h, w, c) 14 | - float: BGR, [0, 1], (h, w, c). 一般作为中间结果 15 | - Tensor 类型 16 | - float: RGB, [0, 1], (n, c, h, w) 17 | 18 | 其他约定: 19 | 20 | - 以 `_pt` 结尾的是 PyTorch 结果 21 | - PyTorch version 支持 batch 计算 22 | - 颜色转换在 float32 上做;metric计算在 float64 上做 23 | 24 | ## PSNR 和 SSIM 25 | 26 | PSNR 和 SSIM 的结果趋势是一致的,即一般 PSNR 高,则 SSIM 也高。 27 | 在实现上, PSNR 的各种实现都很一致。SSIM 有各种各样的实现,我们这里和 MATLAB 最原始版本保持 (参考 [NTIRE17比赛](https://competitions.codalab.org/competitions/16306#participate) 的 [evaluation代码](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378)) 28 | 29 | 下面列了各个实现的结果比对. 
30 | 总结：PyTorch 实现和 MATLAB 实现基本一致，在 GPU 运行上会有稍许差异 31 | 32 | - PSNR 比对 33 | 34 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | 35 | |:---| :---: | :---: | :---: | :---: | :---: | 36 | |baboon| RGB | 20.419710 | 20.419710 | 20.419710 |20.419710 | 37 | |baboon| Y | - |22.441898 | 22.441899 | 22.444916| 38 | |comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 | 39 | |comic | Y | - | 21.720398 | 21.720398 | 21.721663| 40 | 41 | - SSIM 比对 42 | 43 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | 44 | |:---| :---: | :---: | :---: | :---: | :---: | 45 | |baboon| RGB | 0.391853 | 0.391853 | 0.391853|0.391853 | 46 | |baboon| Y | - |0.453097| 0.453097 | 0.453171| 47 | |comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738| 48 | |comic | Y | - | 0.585511 | 0.585511 | 0.585522 | 49 | -------------------------------------------------------------------------------- /basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from basicsr.utils.registry import METRIC_REGISTRY 4 | from .niqe import calculate_niqe 5 | from .psnr_ssim import calculate_psnr, calculate_ssim 6 | 7 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] 8 | 9 | 10 | def calculate_metric(data, opt): 11 | """Calculate metric from data and options. 12 | 13 | Args: 14 | opt (dict): Configuration. It must contain: 15 | type (str): Metric type. 16 | """ 17 | opt = deepcopy(opt) 18 | metric_type = opt.pop('type') 19 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) 20 | return metric 21 | --------------------------------------------------------------------------------
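# Usage sketch for `calculate_metric` above (illustrative only, not part of the
# library): the registered name 'calculate_psnr' is resolved through
# METRIC_REGISTRY, the `data` dict supplies the metric's array arguments by
# keyword, and the remaining keys of `opt` are forwarded as metric options.
import numpy as np
from basicsr.metrics import calculate_metric

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)   # stand-in restored image, HWC, BGR
img2 = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)  # stand-in ground truth
psnr = calculate_metric(dict(img=img, img2=img2), dict(type='calculate_psnr', crop_border=4))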
/basicsr/metrics/fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from scipy import linalg 5 | from tqdm import tqdm 6 | 7 | from basicsr.archs.inception import InceptionV3 8 | 9 | 10 | def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False): 11 | # we may not resize the input, but in [rosinality/stylegan2-pytorch] it 12 | # does resize the input. 13 | inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input) 14 | inception = nn.DataParallel(inception).eval().to(device) 15 | return inception 16 | 17 | 18 | @torch.no_grad() 19 | def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'): 20 | """Extract inception features. 21 | 22 | Args: 23 | data_generator (generator): A data generator. 24 | inception (nn.Module): Inception model. 25 | len_generator (int): Length of the data_generator to show the 26 | progressbar. Default: None. 27 | device (str): Device. Default: cuda. 28 | 29 | Returns: 30 | Tensor: Extracted features. 31 | """ 32 | if len_generator is not None: 33 | pbar = tqdm(total=len_generator, unit='batch', desc='Extract') 34 | else: 35 | pbar = None 36 | features = [] 37 | 38 | for data in data_generator: 39 | if pbar: 40 | pbar.update(1) 41 | data = data.to(device) 42 | feature = inception(data)[0].view(data.shape[0], -1) 43 | features.append(feature.to('cpu')) 44 | if pbar: 45 | pbar.close() 46 | features = torch.cat(features, 0) 47 | return features 48 | 49 | 50 | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): 51 | """Numpy implementation of the Frechet Distance. 52 | 53 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is: 54 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 55 | Stable version by Dougal J. Sutherland. 56 | 57 | Args: 58 | mu1 (np.array): The sample mean over activations. 59 | sigma1 (np.array): The covariance matrix over activations for generated samples. 60 | mu2 (np.array): The sample mean over activations, precalculated on a representative data set. 61 | sigma2 (np.array): The covariance matrix over activations, precalculated on a representative data set. 62 | 63 | Returns: 64 | float: The Frechet Distance. 65 | """ 66 | assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' 67 | assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions') 68 | 69 | cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) 70 | 71 | # Product might be almost singular 72 | if not np.isfinite(cov_sqrt).all(): 73 | print(f'Product of cov matrices is singular.
Adding {eps} to diagonal of cov estimates') 74 | offset = np.eye(sigma1.shape[0]) * eps 75 | cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) 76 | 77 | # Numerical error might give slight imaginary component 78 | if np.iscomplexobj(cov_sqrt): 79 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 80 | m = np.max(np.abs(cov_sqrt.imag)) 81 | raise ValueError(f'Imaginary component {m}') 82 | cov_sqrt = cov_sqrt.real 83 | 84 | mean_diff = mu1 - mu2 85 | mean_norm = mean_diff @ mean_diff 86 | trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt) 87 | fid = mean_norm + trace 88 | 89 | return fid 90 | -------------------------------------------------------------------------------- /basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from basicsr.utils import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img, input_order='HWC'): 7 | """Reorder images to 'HWC' order. 8 | 9 | If the input_order is (h, w), return (h, w, 1); 10 | If the input_order is (c, h, w), return (h, w, c); 11 | If the input_order is (h, w, c), return as it is. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 16 | If the input image shape is (h, w), input_order will not have 17 | effects. Default: 'HWC'. 18 | 19 | Returns: 20 | ndarray: reordered image. 21 | """ 22 | 23 | if input_order not in ['HWC', 'CHW']: 24 | raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'") 25 | if len(img.shape) == 2: 26 | img = img[..., None] 27 | if input_order == 'CHW': 28 | img = img.transpose(1, 2, 0) 29 | return img 30 | 31 | 32 | def to_y_channel(img): 33 | """Change to Y channel of YCbCr. 34 | 35 | Args: 36 | img (ndarray): Images with range [0, 255]. 37 | 38 | Returns: 39 | (ndarray): Images with range [0, 255] (float type) without round. 40 | """ 41 | img = img.astype(np.float32) / 255. 42 | if img.ndim == 3 and img.shape[2] == 3: 43 | img = bgr2ycbcr(img, y_only=True) 44 | img = img[..., None] 45 | return img * 255. 
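# Usage sketch (illustrative, not part of the original module): reorder a CHW
# array to HWC and extract its Y channel with the helpers above.
if __name__ == '__main__':
    chw = np.random.randint(0, 256, (3, 32, 32)).astype(np.float64)
    hwc = reorder_image(chw, input_order='CHW')  # (3, 32, 32) -> (32, 32, 3)
    y = to_y_channel(hwc)  # (32, 32, 1), float64, range [0, 255]
    print(hwc.shape, y.shape)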
46 | -------------------------------------------------------------------------------- /basicsr/metrics/niqe_pris_params.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/metrics/niqe_pris_params.npz -------------------------------------------------------------------------------- /basicsr/metrics/test_metrics/test_psnr_ssim.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | 4 | from basicsr.metrics import calculate_psnr, calculate_ssim 5 | from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt 6 | from basicsr.utils import img2tensor 7 | 8 | 9 | def test(img_path, img_path2, crop_border, test_y_channel=False): 10 | img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) 11 | img2 = cv2.imread(img_path2, cv2.IMREAD_UNCHANGED) 12 | 13 | # --------------------- Numpy --------------------- 14 | psnr = calculate_psnr(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel) 15 | ssim = calculate_ssim(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel) 16 | print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}') 17 | 18 | # --------------------- PyTorch (CPU) --------------------- 19 | img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0) 20 | img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0) 21 | 22 | psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 23 | ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 24 | print(f'\tTensor (CPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}') 25 | 26 | # --------------------- PyTorch (GPU) --------------------- 27 | img = img.cuda() 28 | img2 = img2.cuda() 29 | psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 30 | ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 31 | print(f'\tTensor (GPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}') 32 | 33 | psnr_pth = calculate_psnr_pt( 34 | torch.repeat_interleave(img, 2, dim=0), 35 | torch.repeat_interleave(img2, 2, dim=0), 36 | crop_border=crop_border, 37 | test_y_channel=test_y_channel) 38 | ssim_pth = calculate_ssim_pt( 39 | torch.repeat_interleave(img, 2, dim=0), 40 | torch.repeat_interleave(img2, 2, dim=0), 41 | crop_border=crop_border, 42 | test_y_channel=test_y_channel) 43 | print(f'\tTensor (GPU batch) \tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,' 44 | f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}') 45 | 46 | 47 | if __name__ == '__main__': 48 | test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=False) 49 | test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=True) 50 | 51 | test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=False) 52 | test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=True) 53 | -------------------------------------------------------------------------------- /basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from 
basicsr.utils.registry import MODEL_REGISTRY 7 | 8 | __all__ = ['build_model'] 9 | 10 | # automatically scan and import model modules for registry 11 | # scan all the files under the 'models' folder and collect files ending with '_model.py' 12 | model_folder = osp.dirname(osp.abspath(__file__)) 13 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 14 | # import all the model modules 15 | _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] 16 | 17 | 18 | def build_model(opt): 19 | """Build model from options. 20 | 21 | Args: 22 | opt (dict): Configuration. It must contain: 23 | model_type (str): Model type. 24 | """ 25 | opt = deepcopy(opt) 26 | model = MODEL_REGISTRY.get(opt['model_type'])(opt) 27 | logger = get_root_logger() 28 | logger.info(f'Model [{model.__class__.__name__}] is created.') 29 | return model 30 | -------------------------------------------------------------------------------- /basicsr/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/base_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/models/__pycache__/base_model.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/edvr_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/models/__pycache__/edvr_model.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/esrgan_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/models/__pycache__/esrgan_model.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/hifacegan_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/models/__pycache__/hifacegan_model.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/lr_scheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/models/__pycache__/lr_scheduler.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/realesrgan_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/models/__pycache__/realesrgan_model.cpython-37.pyc 
-------------------------------------------------------------------------------- /basicsr/models/__pycache__/realesrnet_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/models/__pycache__/realesrnet_model.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/sr_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/models/__pycache__/sr_model.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/sr_model.cpython-37.pyc.139704389111088: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/models/__pycache__/sr_model.cpython-37.pyc.139704389111088 -------------------------------------------------------------------------------- /basicsr/models/__pycache__/srgan_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/models/__pycache__/srgan_model.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/stylegan2_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/models/__pycache__/stylegan2_model.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/swinir_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/models/__pycache__/swinir_model.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/video_base_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/models/__pycache__/video_base_model.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/video_gan_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/models/__pycache__/video_gan_model.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/video_recurrent_gan_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/models/__pycache__/video_recurrent_gan_model.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/video_recurrent_model.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/models/__pycache__/video_recurrent_model.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/edvr_model.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils import get_root_logger 2 | from basicsr.utils.registry import MODEL_REGISTRY 3 | from .video_base_model import VideoBaseModel 4 | 5 | 6 | @MODEL_REGISTRY.register() 7 | class EDVRModel(VideoBaseModel): 8 | """EDVR Model. 9 | 10 | Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks. # noqa: E501 11 | """ 12 | 13 | def __init__(self, opt): 14 | super(EDVRModel, self).__init__(opt) 15 | if self.is_train: 16 | self.train_tsa_iter = opt['train'].get('tsa_iter') 17 | 18 | def setup_optimizers(self): 19 | train_opt = self.opt['train'] 20 | dcn_lr_mul = train_opt.get('dcn_lr_mul', 1) 21 | logger = get_root_logger() 22 | logger.info(f'Multiply the learning rate for dcn with {dcn_lr_mul}.') 23 | if dcn_lr_mul == 1: 24 | optim_params = self.net_g.parameters() 25 | else: # separate dcn params and normal params for different lr 26 | normal_params = [] 27 | dcn_params = [] 28 | for name, param in self.net_g.named_parameters(): 29 | if 'dcn' in name: 30 | dcn_params.append(param) 31 | else: 32 | normal_params.append(param) 33 | optim_params = [ 34 | { # add normal params first 35 | 'params': normal_params, 36 | 'lr': train_opt['optim_g']['lr'] 37 | }, 38 | { 39 | 'params': dcn_params, 40 | 'lr': train_opt['optim_g']['lr'] * dcn_lr_mul 41 | }, 42 | ] 43 | 44 | optim_type = train_opt['optim_g'].pop('type') 45 | self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) 46 | self.optimizers.append(self.optimizer_g) 47 | 48 | def optimize_parameters(self, current_iter): 49 | if self.train_tsa_iter: 50 | if current_iter == 1: 51 | logger = get_root_logger() 52 | logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.') 53 | for name, param in self.net_g.named_parameters(): 54 | if 'fusion' not in name: 55 | param.requires_grad = False 56 | elif current_iter == self.train_tsa_iter: 57 | logger = get_root_logger() 58 | logger.warning('Train all the parameters.') 59 | for param in self.net_g.parameters(): 60 | param.requires_grad = True 61 | 62 | super(EDVRModel, self).optimize_parameters(current_iter) 63 | -------------------------------------------------------------------------------- /basicsr/models/esrgan_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | from basicsr.utils.registry import MODEL_REGISTRY 5 | from .srgan_model import SRGANModel 6 | 7 | 8 | @MODEL_REGISTRY.register() 9 | class ESRGANModel(SRGANModel): 10 | """ESRGAN model for single image super-resolution.""" 11 | 12 | def optimize_parameters(self, current_iter): 13 | # optimize net_g 14 | for p in self.net_d.parameters(): 15 | p.requires_grad = False 16 | 17 | self.optimizer_g.zero_grad() 18 | self.output = self.net_g(self.lq) 19 | 20 | l_g_total = 0 21 | loss_dict = OrderedDict() 22 | if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): 23 | # pixel loss 24 | if self.cri_pix: 25 | l_g_pix = self.cri_pix(self.output, self.gt) 26 | l_g_total += l_g_pix 27 | loss_dict['l_g_pix'] = l_g_pix 28 | # 
perceptual loss 29 | if self.cri_perceptual: 30 | l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) 31 | if l_g_percep is not None: 32 | l_g_total += l_g_percep 33 | loss_dict['l_g_percep'] = l_g_percep 34 | if l_g_style is not None: 35 | l_g_total += l_g_style 36 | loss_dict['l_g_style'] = l_g_style 37 | # gan loss (relativistic gan) 38 | real_d_pred = self.net_d(self.gt).detach() 39 | fake_g_pred = self.net_d(self.output) 40 | l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False) 41 | l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False) 42 | l_g_gan = (l_g_real + l_g_fake) / 2 43 | 44 | l_g_total += l_g_gan 45 | loss_dict['l_g_gan'] = l_g_gan 46 | 47 | l_g_total.backward() 48 | self.optimizer_g.step() 49 | 50 | # optimize net_d 51 | for p in self.net_d.parameters(): 52 | p.requires_grad = True 53 | 54 | self.optimizer_d.zero_grad() 55 | # gan loss (relativistic gan) 56 | 57 | # In order to avoid the error in distributed training: 58 | # "Error detected in CudnnBatchNormBackward: RuntimeError: one of 59 | # the variables needed for gradient computation has been modified by 60 | # an inplace operation", 61 | # we separate the backwards for real and fake, and also detach the 62 | # tensor for calculating mean. 63 | 64 | # real 65 | fake_d_pred = self.net_d(self.output).detach() 66 | real_d_pred = self.net_d(self.gt) 67 | l_d_real = self.cri_gan(real_d_pred - torch.mean(fake_d_pred), True, is_disc=True) * 0.5 68 | l_d_real.backward() 69 | # fake 70 | fake_d_pred = self.net_d(self.output.detach()) 71 | l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred.detach()), False, is_disc=True) * 0.5 72 | l_d_fake.backward() 73 | self.optimizer_d.step() 74 | 75 | loss_dict['l_d_real'] = l_d_real 76 | loss_dict['l_d_fake'] = l_d_fake 77 | loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) 78 | loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) 79 | 80 | self.log_dict = self.reduce_loss_dict(loss_dict) 81 | 82 | if self.ema_decay > 0: 83 | self.model_ema(decay=self.ema_decay) 84 | -------------------------------------------------------------------------------- /basicsr/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class MultiStepRestartLR(_LRScheduler): 7 | """ MultiStep with restarts learning rate scheme. 8 | 9 | Args: 10 | optimizer (torch.nn.optimizer): Torch optimizer. 11 | milestones (list): Iterations that will decrease learning rate. 12 | gamma (float): Decrease ratio. Default: 0.1. 13 | restarts (list): Restart iterations. Default: [0]. 14 | restart_weights (list): Restart weights at each restart iteration. 15 | Default: [1]. 16 | last_epoch (int): Used in _LRScheduler. Default: -1. 17 | """ 18 | 19 | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): 20 | self.milestones = Counter(milestones) 21 | self.gamma = gamma 22 | self.restarts = restarts 23 | self.restart_weights = restart_weights 24 | assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' 
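# note (added): _LRScheduler.__init__ below invokes an initial step(), so get_lr runs right away; with the default restarts=(0, ), the starting lr is initial_lr * restart_weights[0]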
25 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 26 | 27 | def get_lr(self): 28 | if self.last_epoch in self.restarts: 29 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 30 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 31 | if self.last_epoch not in self.milestones: 32 | return [group['lr'] for group in self.optimizer.param_groups] 33 | return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] 34 | 35 | 36 | def get_position_from_periods(iteration, cumulative_period): 37 | """Get the position from a period list. 38 | 39 | It will return the index of the right-closest number in the period list. 40 | For example, the cumulative_period = [100, 200, 300, 400], 41 | if iteration == 50, return 0; 42 | if iteration == 210, return 2; 43 | if iteration == 300, return 2. 44 | 45 | Args: 46 | iteration (int): Current iteration. 47 | cumulative_period (list[int]): Cumulative period list. 48 | 49 | Returns: 50 | int: The position of the right-closest number in the period list. 51 | """ 52 | for i, period in enumerate(cumulative_period): 53 | if iteration <= period: 54 | return i 55 | 56 | 57 | class CosineAnnealingRestartLR(_LRScheduler): 58 | """ Cosine annealing with restarts learning rate scheme. 59 | 60 | An example of config: 61 | periods = [10, 10, 10, 10] 62 | restart_weights = [1, 0.5, 0.5, 0.5] 63 | eta_min=1e-7 64 | 65 | It has four cycles, each with 10 iterations. At the 10th, 20th, and 30th 66 | iterations, the scheduler will restart with the weights in restart_weights. 67 | 68 | Args: 69 | optimizer (torch.nn.optimizer): Torch optimizer. 70 | periods (list): Period for each cosine annealing cycle. 71 | restart_weights (list): Restart weights at each restart iteration. 72 | Default: [1]. 73 | eta_min (float): The minimum lr. Default: 0. 74 | last_epoch (int): Used in _LRScheduler. Default: -1. 75 | """ 76 | 77 | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): 78 | self.periods = periods 79 | self.restart_weights = restart_weights 80 | self.eta_min = eta_min 81 | assert (len(self.periods) == len( 82 | self.restart_weights)), 'periods and restart_weights should have the same length.' 
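# e.g. periods = [10, 10, 10, 10] gives cumulative_period = [10, 20, 30, 40]; each entry marks the final iteration of its cosine cycle (added note)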
83 | self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] 84 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 85 | 86 | def get_lr(self): 87 | idx = get_position_from_periods(self.last_epoch, self.cumulative_period) 88 | current_weight = self.restart_weights[idx] 89 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 90 | current_period = self.periods[idx] 91 | 92 | return [ 93 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 94 | (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) 95 | for base_lr in self.base_lrs 96 | ] 97 | -------------------------------------------------------------------------------- /basicsr/models/srgan_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | from basicsr.archs import build_network 5 | from basicsr.losses import build_loss 6 | from basicsr.utils import get_root_logger 7 | from basicsr.utils.registry import MODEL_REGISTRY 8 | from .sr_model import SRModel 9 | 10 | 11 | @MODEL_REGISTRY.register() 12 | class SRGANModel(SRModel): 13 | """SRGAN model for single image super-resolution.""" 14 | 15 | def init_training_settings(self): 16 | train_opt = self.opt['train'] 17 | 18 | self.ema_decay = train_opt.get('ema_decay', 0) 19 | if self.ema_decay > 0: 20 | logger = get_root_logger() 21 | logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') 22 | # define network net_g with Exponential Moving Average (EMA) 23 | # net_g_ema is used only for testing on one GPU and saving 24 | # There is no need to wrap with DistributedDataParallel 25 | self.net_g_ema = build_network(self.opt['network_g']).to(self.device) 26 | # load pretrained model 27 | load_path = self.opt['path'].get('pretrain_network_g', None) 28 | if load_path is not None: 29 | self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') 30 | else: 31 | self.model_ema(0) # copy net_g weight 32 | self.net_g_ema.eval() 33 | 34 | # define network net_d 35 | self.net_d = build_network(self.opt['network_d']) 36 | self.net_d = self.model_to_device(self.net_d) 37 | self.print_network(self.net_d) 38 | 39 | # load pretrained models 40 | load_path = self.opt['path'].get('pretrain_network_d', None) 41 | if load_path is not None: 42 | param_key = self.opt['path'].get('param_key_d', 'params') 43 | self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key) 44 | 45 | self.net_g.train() 46 | self.net_d.train() 47 | 48 | # define losses 49 | if train_opt.get('pixel_opt'): 50 | self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) 51 | else: 52 | self.cri_pix = None 53 | 54 | if train_opt.get('ldl_opt'): 55 | self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device) 56 | else: 57 | self.cri_ldl = None 58 | 59 | if train_opt.get('perceptual_opt'): 60 | self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) 61 | else: 62 | self.cri_perceptual = None 63 | 64 | if train_opt.get('gan_opt'): 65 | self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) 66 | 67 | self.net_d_iters = train_opt.get('net_d_iters', 1) 68 | self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) 69 | 70 | # set up optimizers and schedulers 71 | self.setup_optimizers() 72 | self.setup_schedulers() 73 | 74 | def setup_optimizers(self): 75 | train_opt = self.opt['train'] 76 | # 
optimizer g 77 | optim_type = train_opt['optim_g'].pop('type') 78 | self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g']) 79 | self.optimizers.append(self.optimizer_g) 80 | # optimizer d 81 | optim_type = train_opt['optim_d'].pop('type') 82 | self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) 83 | self.optimizers.append(self.optimizer_d) 84 | 85 | def optimize_parameters(self, current_iter): 86 | # optimize net_g 87 | for p in self.net_d.parameters(): 88 | p.requires_grad = False 89 | 90 | self.optimizer_g.zero_grad() 91 | self.output = self.net_g(self.lq) 92 | 93 | l_g_total = 0 94 | loss_dict = OrderedDict() 95 | if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): 96 | # pixel loss 97 | if self.cri_pix: 98 | l_g_pix = self.cri_pix(self.output, self.gt) 99 | l_g_total += l_g_pix 100 | loss_dict['l_g_pix'] = l_g_pix 101 | # perceptual loss 102 | if self.cri_perceptual: 103 | l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) 104 | if l_g_percep is not None: 105 | l_g_total += l_g_percep 106 | loss_dict['l_g_percep'] = l_g_percep 107 | if l_g_style is not None: 108 | l_g_total += l_g_style 109 | loss_dict['l_g_style'] = l_g_style 110 | # gan loss 111 | fake_g_pred = self.net_d(self.output) 112 | l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) 113 | l_g_total += l_g_gan 114 | loss_dict['l_g_gan'] = l_g_gan 115 | 116 | l_g_total.backward() 117 | self.optimizer_g.step() 118 | 119 | # optimize net_d 120 | for p in self.net_d.parameters(): 121 | p.requires_grad = True 122 | 123 | self.optimizer_d.zero_grad() 124 | # real 125 | real_d_pred = self.net_d(self.gt) 126 | l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) 127 | loss_dict['l_d_real'] = l_d_real 128 | loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) 129 | l_d_real.backward() 130 | # fake 131 | fake_d_pred = self.net_d(self.output.detach()) 132 | l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) 133 | loss_dict['l_d_fake'] = l_d_fake 134 | loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) 135 | l_d_fake.backward() 136 | self.optimizer_d.step() 137 | 138 | self.log_dict = self.reduce_loss_dict(loss_dict) 139 | 140 | if self.ema_decay > 0: 141 | self.model_ema(decay=self.ema_decay) 142 | 143 | def save(self, epoch, current_iter): 144 | if hasattr(self, 'net_g_ema'): 145 | self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) 146 | else: 147 | self.save_network(self.net_g, 'net_g', current_iter) 148 | self.save_network(self.net_d, 'net_d', current_iter) 149 | self.save_training_state(epoch, current_iter) 150 | -------------------------------------------------------------------------------- /basicsr/models/swinir_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | from basicsr.utils.registry import MODEL_REGISTRY 5 | from .sr_model import SRModel 6 | 7 | 8 | @MODEL_REGISTRY.register() 9 | class SwinIRModel(SRModel): 10 | 11 | def test(self): 12 | # pad to multiplication of window_size 13 | window_size = self.opt['network_g']['window_size'] 14 | scale = self.opt.get('scale', 1) 15 | mod_pad_h, mod_pad_w = 0, 0 16 | _, _, h, w = self.lq.size() 17 | if h % window_size != 0: 18 | mod_pad_h = window_size - h % window_size 19 | if w % window_size != 0: 20 | mod_pad_w = window_size - w % window_size 21 | img = 
F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect') 22 | if hasattr(self, 'net_g_ema'): 23 | self.net_g_ema.eval() 24 | with torch.no_grad(): 25 | self.output = self.net_g_ema(img) 26 | else: 27 | self.net_g.eval() 28 | with torch.no_grad(): 29 | self.output = self.net_g(img) 30 | self.net_g.train() 31 | 32 | _, _, h, w = self.output.size() 33 | self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale] 34 | -------------------------------------------------------------------------------- /basicsr/models/video_gan_model.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils.registry import MODEL_REGISTRY 2 | from .srgan_model import SRGANModel 3 | from .video_base_model import VideoBaseModel 4 | 5 | 6 | @MODEL_REGISTRY.register() 7 | class VideoGANModel(SRGANModel, VideoBaseModel): 8 | """Video GAN model. 9 | 10 | Use multiple inheritance. 11 | It will first use the functions of :class:`SRGANModel`: 12 | 13 | - :func:`init_training_settings` 14 | - :func:`setup_optimizers` 15 | - :func:`optimize_parameters` 16 | - :func:`save` 17 | 18 | Then find functions in :class:`VideoBaseModel`. 19 | """ 20 | -------------------------------------------------------------------------------- /basicsr/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/ops/__init__.py -------------------------------------------------------------------------------- /basicsr/ops/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/ops/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/ops/__pycache__/__init__.cpython-37.pyc.139952150904880: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/ops/__pycache__/__init__.cpython-37.pyc.139952150904880 -------------------------------------------------------------------------------- /basicsr/ops/dcn/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, 2 | modulated_deform_conv) 3 | 4 | __all__ = [ 5 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', 6 | 'modulated_deform_conv' 7 | ] 8 | -------------------------------------------------------------------------------- /basicsr/ops/dcn/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/ops/dcn/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/ops/dcn/__pycache__/deform_conv.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/ops/dcn/__pycache__/deform_conv.cpython-37.pyc -------------------------------------------------------------------------------- 
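A minimal usage sketch for the deformable-convolution exports listed in basicsr/ops/dcn/__init__.py above (added for illustration; it is not part of the repository and assumes the deform_conv CUDA extension has been compiled, e.g. with BASICSR_EXT=True, and that a GPU is available):

import torch

from basicsr.ops.dcn import DeformConvPack

# DeformConvPack predicts its own offsets with an internal conv layer,
# so it can be used like a drop-in replacement for nn.Conv2d.
dconv = DeformConvPack(64, 64, kernel_size=3, stride=1, padding=1).cuda()
x = torch.randn(2, 64, 32, 32).cuda()
out = dconv(x)  # spatial size preserved: (2, 64, 32, 32)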
/basicsr/ops/dcn/__pycache__/deform_conv.cpython-37.pyc.139952150906544: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/ops/dcn/__pycache__/deform_conv.cpython-37.pyc.139952150906544 -------------------------------------------------------------------------------- /basicsr/ops/fused_act/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | 3 | __all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] 4 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/ops/fused_act/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/ops/fused_act/__pycache__/fused_act.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/ops/fused_act/__pycache__/fused_act.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/ops/fused_act/__pycache__/fused_act.cpython-37.pyc.139952150906928: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/ops/fused_act/__pycache__/fused_act.cpython-37.pyc.139952150906928 -------------------------------------------------------------------------------- /basicsr/ops/fused_act/fused_act.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | from torch.autograd import Function 7 | 8 | BASICSR_JIT = os.getenv('BASICSR_JIT') 9 | if BASICSR_JIT == 'True': 10 | from torch.utils.cpp_extension import load 11 | module_path = os.path.dirname(__file__) 12 | fused_act_ext = load( 13 | 'fused', 14 | sources=[ 15 | os.path.join(module_path, 'src', 'fused_bias_act.cpp'), 16 | os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), 17 | ], 18 | ) 19 | else: 20 | try: 21 | from . import fused_act_ext 22 | except ImportError: 23 | pass 24 | # avoid annoying print output 25 | # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' 26 | # '1. compile with BASICSR_EXT=True. or\n ' 27 | # '2. 
set BASICSR_JIT=True during running') 28 | 29 | 30 | class FusedLeakyReLUFunctionBackward(Function): 31 | 32 | @staticmethod 33 | def forward(ctx, grad_output, out, negative_slope, scale): 34 | ctx.save_for_backward(out) 35 | ctx.negative_slope = negative_slope 36 | ctx.scale = scale 37 | 38 | empty = grad_output.new_empty(0) 39 | 40 | grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) 41 | 42 | dim = [0] 43 | 44 | if grad_input.ndim > 2: 45 | dim += list(range(2, grad_input.ndim)) 46 | 47 | grad_bias = grad_input.sum(dim).detach() 48 | 49 | return grad_input, grad_bias 50 | 51 | @staticmethod 52 | def backward(ctx, gradgrad_input, gradgrad_bias): 53 | out, = ctx.saved_tensors 54 | gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, 55 | ctx.scale) 56 | 57 | return gradgrad_out, None, None, None 58 | 59 | 60 | class FusedLeakyReLUFunction(Function): 61 | 62 | @staticmethod 63 | def forward(ctx, input, bias, negative_slope, scale): 64 | empty = input.new_empty(0) 65 | out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 66 | ctx.save_for_backward(out) 67 | ctx.negative_slope = negative_slope 68 | ctx.scale = scale 69 | 70 | return out 71 | 72 | @staticmethod 73 | def backward(ctx, grad_output): 74 | out, = ctx.saved_tensors 75 | 76 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) 77 | 78 | return grad_input, grad_bias, None, None 79 | 80 | 81 | class FusedLeakyReLU(nn.Module): 82 | 83 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5): 84 | super().__init__() 85 | 86 | self.bias = nn.Parameter(torch.zeros(channel)) 87 | self.negative_slope = negative_slope 88 | self.scale = scale 89 | 90 | def forward(self, input): 91 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 92 | 93 | 94 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): 95 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 96 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp 2 | #include 3 | 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, 6 | const torch::Tensor& bias, 7 | const torch::Tensor& refer, 8 | int act, int grad, float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | torch::Tensor fused_bias_act(const torch::Tensor& input, 15 | const torch::Tensor& bias, 16 | const torch::Tensor& refer, 17 | int act, int grad, float alpha, float scale) { 18 | CHECK_CUDA(input); 19 | CHECK_CUDA(bias); 20 | 21 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 22 | } 23 | 24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 25 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 26 | } 27 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // from 
https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu 2 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 3 | // 4 | // This work is made available under the Nvidia Source Code License-NC. 5 | // To view a copy of this license, visit 6 | // https://nvlabs.github.io/stylegan2/license.html 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | 19 | template 20 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 21 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 22 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 23 | 24 | scalar_t zero = 0.0; 25 | 26 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 27 | scalar_t x = p_x[xi]; 28 | 29 | if (use_bias) { 30 | x += p_b[(xi / step_b) % size_b]; 31 | } 32 | 33 | scalar_t ref = use_ref ? p_ref[xi] : zero; 34 | 35 | scalar_t y; 36 | 37 | switch (act * 10 + grad) { 38 | default: 39 | case 10: y = x; break; 40 | case 11: y = x; break; 41 | case 12: y = 0.0; break; 42 | 43 | case 30: y = (x > 0.0) ? x : x * alpha; break; 44 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 45 | case 32: y = 0.0; break; 46 | } 47 | 48 | out[xi] = y * scale; 49 | } 50 | } 51 | 52 | 53 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 54 | int act, int grad, float alpha, float scale) { 55 | int curDevice = -1; 56 | cudaGetDevice(&curDevice); 57 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 58 | 59 | auto x = input.contiguous(); 60 | auto b = bias.contiguous(); 61 | auto ref = refer.contiguous(); 62 | 63 | int use_bias = b.numel() ? 1 : 0; 64 | int use_ref = ref.numel() ? 
1 : 0; 65 | 66 | int size_x = x.numel(); 67 | int size_b = b.numel(); 68 | int step_b = 1; 69 | 70 | for (int i = 1 + 1; i < x.dim(); i++) { 71 | step_b *= x.size(i); 72 | } 73 | 74 | int loop_x = 4; 75 | int block_size = 4 * 32; 76 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 77 | 78 | auto y = torch::empty_like(x); 79 | 80 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 81 | fused_bias_act_kernel<<>>( 82 | y.data_ptr(), 83 | x.data_ptr(), 84 | b.data_ptr(), 85 | ref.data_ptr(), 86 | act, 87 | grad, 88 | alpha, 89 | scale, 90 | loop_x, 91 | size_x, 92 | step_b, 93 | size_b, 94 | use_bias, 95 | use_ref 96 | ); 97 | }); 98 | 99 | return y; 100 | } 101 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/__init__.py: -------------------------------------------------------------------------------- 1 | from .upfirdn2d import upfirdn2d 2 | 3 | __all__ = ['upfirdn2d'] 4 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/ops/upfirdn2d/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/__pycache__/upfirdn2d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/ops/upfirdn2d/__pycache__/upfirdn2d.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/__pycache__/upfirdn2d.cpython-37.pyc.139952150101424: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/ops/upfirdn2d/__pycache__/upfirdn2d.cpython-37.pyc.139952150101424 -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/src/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp 2 | #include 3 | 4 | 5 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 6 | int up_x, int up_y, int down_x, int down_y, 7 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 11 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 12 | 13 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 14 | int up_x, int up_y, int down_x, int down_y, 15 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 16 | CHECK_CUDA(input); 17 | CHECK_CUDA(kernel); 18 | 19 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 24 | } 25 | -------------------------------------------------------------------------------- /basicsr/test.py: -------------------------------------------------------------------------------- 1 | import 
logging 2 | import torch 3 | from os import path as osp 4 | 5 | from basicsr.data import build_dataloader, build_dataset 6 | from basicsr.models import build_model 7 | from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs 8 | from basicsr.utils.options import dict2str, parse_options 9 | 10 | 11 | def test_pipeline(root_path): 12 | print("---> from my pipeline") 13 | # parse options, set distributed setting, set random seed 14 | opt, _ = parse_options(root_path, is_train=False) 15 | 16 | torch.backends.cudnn.benchmark = True 17 | # torch.backends.cudnn.deterministic = True 18 | 19 | # mkdir and initialize loggers 20 | make_exp_dirs(opt) 21 | log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log") 22 | logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 23 | # logger.info(get_env_info()) 24 | logger.info(dict2str(opt)) 25 | 26 | # create test dataset and dataloader 27 | test_loaders = [] 28 | for _, dataset_opt in sorted(opt['datasets'].items()): 29 | test_set = build_dataset(dataset_opt) 30 | test_loader = build_dataloader( 31 | test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) 32 | logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}") 33 | test_loaders.append(test_loader) 34 | 35 | # create model 36 | model = build_model(opt) 37 | net = model.net_g 38 | 39 | total = sum(l.numel() for l in net.parameters()) 40 | print("HAT--", total) 41 | # for l in net.parameters(): 42 | # print(l) 43 | 44 | for test_loader in test_loaders: 45 | test_set_name = test_loader.dataset.opt['name'] 46 | logger.info(f'Testing {test_set_name}...') 47 | model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img']) 48 | 49 | print("--test time", model.t) 50 | print(torch.cuda.max_memory_allocated()) 51 | 52 | 53 | if __name__ == '__main__': 54 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 55 | test_pipeline(root_path) 56 | -------------------------------------------------------------------------------- /basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb 2 | from .diffjpeg import DiffJPEG 3 | from .file_client import FileClient 4 | from .img_process_util import USMSharp, usm_sharp 5 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img 6 | from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger 7 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt 8 | from .options import yaml_load 9 | 10 | __all__ = [ 11 | # color_util.py 12 | 'bgr2ycbcr', 13 | 'rgb2ycbcr', 14 | 'rgb2ycbcr_pt', 15 | 'ycbcr2bgr', 16 | 'ycbcr2rgb', 17 | # file_client.py 18 | 'FileClient', 19 | # img_util.py 20 | 'img2tensor', 21 | 'tensor2img', 22 | 'imfrombytes', 23 | 'imwrite', 24 | 'crop_border', 25 | # logger.py 26 | 'MessageLogger', 27 | 'AvgTimer', 28 | 'init_tb_logger', 29 | 'init_wandb_logger', 30 | 'get_root_logger', 31 | 'get_env_info', 32 | # misc.py 33 | 'set_random_seed', 34 | 'get_time_str', 35 | 'mkdir_and_rename', 36 | 'make_exp_dirs', 37 | 'scandir', 38 | 'check_resume', 39 | 'sizeof_fmt', 40 | # diffjpeg 41 | 'DiffJPEG', 42 | # img_process_util 43 | 'USMSharp', 44 | 'usm_sharp', 45 | # options 46 | 
'yaml_load' 47 | ] 48 | -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/color_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/utils/__pycache__/color_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/diffjpeg.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/utils/__pycache__/diffjpeg.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/dist_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/utils/__pycache__/dist_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/dist_util.cpython-37.pyc.140634762200752: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/utils/__pycache__/dist_util.cpython-37.pyc.140634762200752 -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/file_client.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/utils/__pycache__/file_client.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/flow_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/utils/__pycache__/flow_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/img_process_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/utils/__pycache__/img_process_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/img_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/utils/__pycache__/img_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/utils/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/matlab_functions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/utils/__pycache__/matlab_functions.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/utils/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/utils/__pycache__/options.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/registry.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/basicsr/utils/__pycache__/registry.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 
38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | print("init_dist_slurm -- ", proc_id, ntasks, node_list, num_gpus) 44 | torch.cuda.set_device(proc_id % num_gpus) 45 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') 46 | # specify master port 47 | if port is not None: 48 | os.environ['MASTER_PORT'] = str(port) 49 | elif 'MASTER_PORT' in os.environ: 50 | pass # use MASTER_PORT in the environment variable 51 | else: 52 | # 29500 is torch.distributed default port 53 | os.environ['MASTER_PORT'] = '29500' 54 | os.environ['MASTER_ADDR'] = addr 55 | os.environ['WORLD_SIZE'] = str(ntasks) 56 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 57 | os.environ['RANK'] = str(proc_id) 58 | dist.init_process_group(backend=backend) 59 | 60 | 61 | def get_dist_info(): 62 | if dist.is_available(): 63 | initialized = dist.is_initialized() 64 | else: 65 | initialized = False 66 | if initialized: 67 | rank = dist.get_rank() 68 | world_size = dist.get_world_size() 69 | else: 70 | rank = 0 71 | world_size = 1 72 | return rank, world_size 73 | 74 | 75 | def master_only(func): 76 | 77 | @functools.wraps(func) 78 | def wrapper(*args, **kwargs): 79 | rank, _ = get_dist_info() 80 | if rank == 0: 81 | return func(*args, **kwargs) 82 | 83 | return wrapper 84 | -------------------------------------------------------------------------------- /basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import requests 4 | from torch.hub import download_url_to_file, get_dir 5 | from tqdm import tqdm 6 | from urllib.parse import urlparse 7 | 8 | from .misc import sizeof_fmt 9 | 10 | 11 | def download_file_from_google_drive(file_id, save_path): 12 | """Download files from google drive. 13 | 14 | Reference: https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive 15 | 16 | Args: 17 | file_id (str): File id. 18 | save_path (str): Save path. 
19 | """ 20 | 21 | session = requests.Session() 22 | URL = 'https://docs.google.com/uc?export=download' 23 | params = {'id': file_id} 24 | 25 | response = session.get(URL, params=params, stream=True) 26 | token = get_confirm_token(response) 27 | if token: 28 | params['confirm'] = token 29 | response = session.get(URL, params=params, stream=True) 30 | 31 | # get file size 32 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 33 | if 'Content-Range' in response_file_size.headers: 34 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) 35 | else: 36 | file_size = None 37 | 38 | save_response_content(response, save_path, file_size) 39 | 40 | 41 | def get_confirm_token(response): 42 | for key, value in response.cookies.items(): 43 | if key.startswith('download_warning'): 44 | return value 45 | return None 46 | 47 | 48 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 49 | if file_size is not None: 50 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 51 | 52 | readable_file_size = sizeof_fmt(file_size) 53 | else: 54 | pbar = None 55 | 56 | with open(destination, 'wb') as f: 57 | downloaded_size = 0 58 | for chunk in response.iter_content(chunk_size): 59 | downloaded_size += chunk_size 60 | if pbar is not None: 61 | pbar.update(1) 62 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') 63 | if chunk: # filter out keep-alive new chunks 64 | f.write(chunk) 65 | if pbar is not None: 66 | pbar.close() 67 | 68 | 69 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 70 | """Load file form http url, will download models if necessary. 71 | 72 | Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 73 | 74 | Args: 75 | url (str): URL to be downloaded. 76 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 77 | Default: None. 78 | progress (bool): Whether to show the download progress. Default: True. 79 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 80 | 81 | Returns: 82 | str: The path to the downloaded file. 83 | """ 84 | if model_dir is None: # use the pytorch hub_dir 85 | hub_dir = get_dir() 86 | model_dir = os.path.join(hub_dir, 'checkpoints') 87 | 88 | os.makedirs(model_dir, exist_ok=True) 89 | 90 | parts = urlparse(url) 91 | filename = os.path.basename(parts.path) 92 | if file_name is not None: 93 | filename = file_name 94 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 95 | if not os.path.exists(cached_file): 96 | print(f'Downloading: "{url}" to {cached_file}\n') 97 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 98 | return cached_file 99 | -------------------------------------------------------------------------------- /basicsr/utils/file_client.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseStorageBackend(metaclass=ABCMeta): 6 | """Abstract class of storage backends. 7 | 8 | All backends need to implement two apis: ``get()`` and ``get_text()``. 9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 10 | as texts. 
11 | """ 12 | 13 | @abstractmethod 14 | def get(self, filepath): 15 | pass 16 | 17 | @abstractmethod 18 | def get_text(self, filepath): 19 | pass 20 | 21 | 22 | class MemcachedBackend(BaseStorageBackend): 23 | """Memcached storage backend. 24 | 25 | Attributes: 26 | server_list_cfg (str): Config file for memcached server list. 27 | client_cfg (str): Config file for memcached client. 28 | sys_path (str | None): Additional path to be appended to `sys.path`. 29 | Default: None. 30 | """ 31 | 32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None): 33 | if sys_path is not None: 34 | import sys 35 | sys.path.append(sys_path) 36 | try: 37 | import mc 38 | except ImportError: 39 | raise ImportError('Please install memcached to enable MemcachedBackend.') 40 | 41 | self.server_list_cfg = server_list_cfg 42 | self.client_cfg = client_cfg 43 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) 44 | # mc.pyvector servers as a point which points to a memory cache 45 | self._mc_buffer = mc.pyvector() 46 | 47 | def get(self, filepath): 48 | filepath = str(filepath) 49 | import mc 50 | self._client.Get(filepath, self._mc_buffer) 51 | value_buf = mc.ConvertBuffer(self._mc_buffer) 52 | return value_buf 53 | 54 | def get_text(self, filepath): 55 | raise NotImplementedError 56 | 57 | 58 | class HardDiskBackend(BaseStorageBackend): 59 | """Raw hard disks storage backend.""" 60 | 61 | def get(self, filepath): 62 | filepath = str(filepath) 63 | with open(filepath, 'rb') as f: 64 | value_buf = f.read() 65 | return value_buf 66 | 67 | def get_text(self, filepath): 68 | filepath = str(filepath) 69 | with open(filepath, 'r') as f: 70 | value_buf = f.read() 71 | return value_buf 72 | 73 | 74 | class LmdbBackend(BaseStorageBackend): 75 | """Lmdb storage backend. 76 | 77 | Args: 78 | db_paths (str | list[str]): Lmdb database paths. 79 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 80 | readonly (bool, optional): Lmdb environment parameter. If True, 81 | disallow any write operations. Default: True. 82 | lock (bool, optional): Lmdb environment parameter. If False, when 83 | concurrent access occurs, do not lock the database. Default: False. 84 | readahead (bool, optional): Lmdb environment parameter. If False, 85 | disable the OS filesystem readahead mechanism, which may improve 86 | random read performance when a database is larger than RAM. 87 | Default: False. 88 | 89 | Attributes: 90 | db_paths (list): Lmdb database path. 91 | _client (list): A list of several lmdb envs. 92 | """ 93 | 94 | def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): 95 | try: 96 | import lmdb 97 | except ImportError: 98 | raise ImportError('Please install lmdb to enable LmdbBackend.') 99 | 100 | if isinstance(client_keys, str): 101 | client_keys = [client_keys] 102 | 103 | if isinstance(db_paths, list): 104 | self.db_paths = [str(v) for v in db_paths] 105 | elif isinstance(db_paths, str): 106 | self.db_paths = [str(db_paths)] 107 | assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' 108 | f'but received {len(client_keys)} and {len(self.db_paths)}.') 109 | 110 | self._client = {} 111 | for client, path in zip(client_keys, self.db_paths): 112 | self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) 113 | 114 | def get(self, filepath, client_key): 115 | """Get values according to the filepath from one lmdb named client_key. 
116 | 117 | Args: 118 | filepath (str | obj:`Path`): Here, filepath is the lmdb key. 119 | client_key (str): Used for distinguishing different lmdb envs. 120 | """ 121 | filepath = str(filepath) 122 | assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.') 123 | client = self._client[client_key] 124 | with client.begin(write=False) as txn: 125 | value_buf = txn.get(filepath.encode('ascii')) 126 | return value_buf 127 | 128 | def get_text(self, filepath): 129 | raise NotImplementedError 130 | 131 | 132 | class FileClient(object): 133 | """A general file client to access files in different backends. 134 | 135 | The client loads a file or text in a specified backend from its path 136 | and returns it as a binary file. It can also register other backend 137 | accessors with a given name and backend class. 138 | 139 | Attributes: 140 | backend (str): The storage backend type. Options are "disk", 141 | "memcached" and "lmdb". 142 | client (:obj:`BaseStorageBackend`): The backend object. 143 | """ 144 | 145 | _backends = { 146 | 'disk': HardDiskBackend, 147 | 'memcached': MemcachedBackend, 148 | 'lmdb': LmdbBackend, 149 | } 150 | 151 | def __init__(self, backend='disk', **kwargs): 152 | if backend not in self._backends: 153 | raise ValueError(f'Backend {backend} is not supported. Currently supported ones' 154 | f' are {list(self._backends.keys())}') 155 | self.backend = backend 156 | self.client = self._backends[backend](**kwargs) 157 | 158 | def get(self, filepath, client_key='default'): 159 | # client_key is used only for lmdb, where different fileclients have 160 | # different lmdb environments. 161 | if self.backend == 'lmdb': 162 | return self.client.get(filepath, client_key) 163 | else: 164 | return self.client.get(filepath) 165 | 166 | def get_text(self, filepath): 167 | return self.client.get_text(filepath) 168 | -------------------------------------------------------------------------------- /basicsr/utils/flow_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py  # noqa: E501 2 | import cv2 3 | import numpy as np 4 | import os 5 | 6 | 7 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): 8 | """Read an optical flow map. 9 | 10 | Args: 11 | flow_path (str): Path to the flow file. 12 | quantize (bool): Whether to read a quantized pair; if set to True, 13 | remaining args will be passed to :func:`dequantize_flow`. 14 | concat_axis (int): The axis that dx and dy are concatenated, 15 | can be either 0 or 1. Ignored if quantize is False.
16 | 17 | Returns: 18 | ndarray: Optical flow represented as a (h, w, 2) numpy array 19 | """ 20 | if quantize: 21 | assert concat_axis in [0, 1] 22 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) 23 | if cat_flow.ndim != 2: 24 | raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.') 25 | assert cat_flow.shape[concat_axis] % 2 == 0 26 | dx, dy = np.split(cat_flow, 2, axis=concat_axis) 27 | flow = dequantize_flow(dx, dy, *args, **kwargs) 28 | else: 29 | with open(flow_path, 'rb') as f: 30 | try: 31 | header = f.read(4).decode('utf-8') 32 | except Exception: 33 | raise IOError(f'Invalid flow file: {flow_path}') 34 | else: 35 | if header != 'PIEH': 36 | raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH') 37 | 38 | w = np.fromfile(f, np.int32, 1).squeeze() 39 | h = np.fromfile(f, np.int32, 1).squeeze() 40 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) 41 | 42 | return flow.astype(np.float32) 43 | 44 | 45 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): 46 | """Write optical flow to file. 47 | 48 | If the flow is not quantized, it will be saved as a .flo file losslessly, 49 | otherwise a jpeg image which is lossy but of much smaller size. (dx and dy 50 | will be concatenated horizontally into a single image if quantize is True.) 51 | 52 | Args: 53 | flow (ndarray): (h, w, 2) array of optical flow. 54 | filename (str): Output filepath. 55 | quantize (bool): Whether to quantize the flow and save it to 2 jpeg 56 | images. If set to True, remaining args will be passed to 57 | :func:`quantize_flow`. 58 | concat_axis (int): The axis that dx and dy are concatenated, 59 | can be either 0 or 1. Ignored if quantize is False. 60 | """ 61 | if not quantize: 62 | with open(filename, 'wb') as f: 63 | f.write('PIEH'.encode('utf-8')) 64 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) 65 | flow = flow.astype(np.float32) 66 | flow.tofile(f) 67 | f.flush() 68 | else: 69 | assert concat_axis in [0, 1] 70 | dx, dy = quantize_flow(flow, *args, **kwargs) 71 | dxdy = np.concatenate((dx, dy), axis=concat_axis) 72 | os.makedirs(os.path.dirname(filename), exist_ok=True) 73 | cv2.imwrite(filename, dxdy) 74 | 75 | 76 | def quantize_flow(flow, max_val=0.02, norm=True): 77 | """Quantize flow to [0, 255]. 78 | 79 | After this step, the size of flow will be much smaller, and can be 80 | dumped as jpeg images. 81 | 82 | Args: 83 | flow (ndarray): (h, w, 2) array of optical flow. 84 | max_val (float): Maximum value of flow, values beyond 85 | [-max_val, max_val] will be truncated. 86 | norm (bool): Whether to divide flow values by image width/height. 87 | 88 | Returns: 89 | tuple[ndarray]: Quantized dx and dy. 90 | """ 91 | h, w, _ = flow.shape 92 | dx = flow[..., 0] 93 | dy = flow[..., 1] 94 | if norm: 95 | dx = dx / w # avoid inplace operations 96 | dy = dy / h 97 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. 98 | flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]] 99 | return tuple(flow_comps) 100 | 101 | 102 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True): 103 | """Recover from quantized flow. 104 | 105 | Args: 106 | dx (ndarray): Quantized dx. 107 | dy (ndarray): Quantized dy. 108 | max_val (float): Maximum value used when quantizing. 109 | denorm (bool): Whether to multiply flow values with width/height. 110 | 111 | Returns: 112 | ndarray: Dequantized flow. 
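Example:
    A quantize/dequantize round trip on synthetic flow (illustrative
    only; the reconstruction is lossy within [-max_val, max_val]):

    >>> flow = np.random.uniform(-0.02, 0.02, (4, 4, 2)).astype(np.float32)
    >>> dx, dy = quantize_flow(flow, max_val=0.02, norm=False)
    >>> rec = dequantize_flow(dx, dy, max_val=0.02, denorm=False)  # shape (4, 4, 2)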
113 | """ 114 | assert dx.shape == dy.shape 115 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) 116 | 117 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] 118 | 119 | if denorm: 120 | dx *= dx.shape[1] 121 | dy *= dx.shape[0] 122 | flow = np.dstack((dx, dy)) 123 | return flow 124 | 125 | 126 | def quantize(arr, min_val, max_val, levels, dtype=np.int64): 127 | """Quantize an array of (-inf, inf) to [0, levels-1]. 128 | 129 | Args: 130 | arr (ndarray): Input array. 131 | min_val (scalar): Minimum value to be clipped. 132 | max_val (scalar): Maximum value to be clipped. 133 | levels (int): Quantization levels. 134 | dtype (np.type): The type of the quantized array. 135 | 136 | Returns: 137 | tuple: Quantized array. 138 | """ 139 | if not (isinstance(levels, int) and levels > 1): 140 | raise ValueError(f'levels must be a positive integer, but got {levels}') 141 | if min_val >= max_val: 142 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') 143 | 144 | arr = np.clip(arr, min_val, max_val) - min_val 145 | quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) 146 | 147 | return quantized_arr 148 | 149 | 150 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64): 151 | """Dequantize an array. 152 | 153 | Args: 154 | arr (ndarray): Input array. 155 | min_val (scalar): Minimum value to be clipped. 156 | max_val (scalar): Maximum value to be clipped. 157 | levels (int): Quantization levels. 158 | dtype (np.type): The type of the dequantized array. 159 | 160 | Returns: 161 | tuple: Dequantized array. 162 | """ 163 | if not (isinstance(levels, int) and levels > 1): 164 | raise ValueError(f'levels must be a positive integer, but got {levels}') 165 | if min_val >= max_val: 166 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') 167 | 168 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val 169 | 170 | return dequantized_arr 171 | -------------------------------------------------------------------------------- /basicsr/utils/img_process_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def filter2D(img, kernel): 8 | """PyTorch version of cv2.filter2D 9 | 10 | Args: 11 | img (Tensor): (b, c, h, w) 12 | kernel (Tensor): (b, k, k) 13 | """ 14 | k = kernel.size(-1) 15 | b, c, h, w = img.size() 16 | if k % 2 == 1: 17 | img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') 18 | else: 19 | raise ValueError('Wrong kernel size') 20 | 21 | ph, pw = img.size()[-2:] 22 | 23 | if kernel.size(0) == 1: 24 | # apply the same kernel to all batch images 25 | img = img.view(b * c, 1, ph, pw) 26 | kernel = kernel.view(1, 1, k, k) 27 | return F.conv2d(img, kernel, padding=0).view(b, c, h, w) 28 | else: 29 | img = img.view(1, b * c, ph, pw) 30 | kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) 31 | return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) 32 | 33 | 34 | def usm_sharp(img, weight=0.5, radius=50, threshold=10): 35 | """USM sharpening. 36 | 37 | Input image: I; Blurry image: B. 38 | 1. sharp = I + weight * (I - B) 39 | 2. Mask = 1 if abs(I - B) > threshold, else: 0 40 | 3. Blur mask: 41 | 4. Out = Mask * sharp + (1 - Mask) * I 42 | 43 | 44 | Args: 45 | img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. 46 | weight (float): Sharp weight. 
Default: 0.5. 47 | radius (float): Kernel size of Gaussian blur. Default: 50. 48 | threshold (int): Mask threshold, compared against abs(I - B) * 255. Default: 10. 49 | """ 50 | if radius % 2 == 0: 51 | radius += 1 52 | blur = cv2.GaussianBlur(img, (radius, radius), 0) 53 | residual = img - blur 54 | mask = np.abs(residual) * 255 > threshold 55 | mask = mask.astype('float32') 56 | soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) 57 | 58 | sharp = img + weight * residual 59 | sharp = np.clip(sharp, 0, 1) 60 | return soft_mask * sharp + (1 - soft_mask) * img 61 | 62 | 63 | class USMSharp(torch.nn.Module): 64 | 65 | def __init__(self, radius=50, sigma=0): 66 | super(USMSharp, self).__init__() 67 | if radius % 2 == 0: 68 | radius += 1 69 | self.radius = radius 70 | kernel = cv2.getGaussianKernel(radius, sigma) 71 | kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) 72 | self.register_buffer('kernel', kernel) 73 | 74 | def forward(self, img, weight=0.5, threshold=10): 75 | blur = filter2D(img, self.kernel) 76 | residual = img - blur 77 | 78 | mask = torch.abs(residual) * 255 > threshold 79 | mask = mask.float() 80 | soft_mask = filter2D(mask, self.kernel) 81 | sharp = img + weight * residual 82 | sharp = torch.clip(sharp, 0, 1) 83 | return soft_mask * sharp + (1 - soft_mask) * img 84 | -------------------------------------------------------------------------------- /basicsr/utils/img_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def img2tensor(imgs, bgr2rgb=True, float32=True): 10 | """Numpy array to tensor. 11 | 12 | Args: 13 | imgs (list[ndarray] | ndarray): Input images. 14 | bgr2rgb (bool): Whether to change bgr to rgb. 15 | float32 (bool): Whether to change to float32. 16 | 17 | Returns: 18 | list[tensor] | tensor: Tensor images. If returned results only have 19 | one element, just return tensor. 20 | """ 21 | 22 | def _totensor(img, bgr2rgb, float32): 23 | if img.shape[2] == 3 and bgr2rgb: 24 | if img.dtype == 'float64': 25 | img = img.astype('float32') 26 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 27 | img = torch.from_numpy(img.transpose(2, 0, 1)) 28 | if float32: 29 | img = img.float() 30 | return img 31 | 32 | if isinstance(imgs, list): 33 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 34 | else: 35 | return _totensor(imgs, bgr2rgb, float32) 36 | 37 | 38 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 39 | """Convert torch Tensors into image numpy arrays. 40 | 41 | After clamping to [min, max], values will be normalized to [0, 1]. 42 | 43 | Args: 44 | tensor (Tensor or list[Tensor]): Accept shapes: 45 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); 46 | 2) 3D Tensor of shape (3/1 x H x W); 47 | 3) 2D Tensor of shape (H x W). 48 | Tensor channel should be in RGB order. 49 | rgb2bgr (bool): Whether to change rgb to bgr. 50 | out_type (numpy type): output types. If ``np.uint8``, transform outputs 51 | to uint8 type with range [0, 255]; otherwise, float type with 52 | range [0, 1]. Default: ``np.uint8``. 53 | min_max (tuple[int]): min and max values for clamp. 54 | 55 | Returns: 56 | (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D ndarray of 57 | shape (H x W). The channel order is BGR.
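Example:
    An illustrative conversion on a synthetic tensor:

    >>> t = torch.rand(1, 3, 16, 16)  # (B, C, H, W), RGB, values in [0, 1]
    >>> img = tensor2img(t)  # (16, 16, 3) uint8 ndarray in BGR order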
58 | """ 59 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): 60 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') 61 | 62 | if torch.is_tensor(tensor): 63 | tensor = [tensor] 64 | result = [] 65 | for _tensor in tensor: 66 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 67 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 68 | 69 | n_dim = _tensor.dim() 70 | if n_dim == 4: 71 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() 72 | img_np = img_np.transpose(1, 2, 0) 73 | if rgb2bgr: 74 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 75 | elif n_dim == 3: 76 | img_np = _tensor.numpy() 77 | img_np = img_np.transpose(1, 2, 0) 78 | if img_np.shape[2] == 1: # gray image 79 | img_np = np.squeeze(img_np, axis=2) 80 | else: 81 | if rgb2bgr: 82 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 83 | elif n_dim == 2: 84 | img_np = _tensor.numpy() 85 | else: 86 | raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') 87 | if out_type == np.uint8: 88 | # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 89 | img_np = (img_np * 255.0).round() 90 | img_np = img_np.astype(out_type) 91 | result.append(img_np) 92 | if len(result) == 1: 93 | result = result[0] 94 | return result 95 | 96 | 97 | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): 98 | """This implementation is slightly faster than tensor2img. 99 | It now only supports torch tensor with shape (1, c, h, w). 100 | 101 | Args: 102 | tensor (Tensor): Now only support torch tensor with (1, c, h, w). 103 | rgb2bgr (bool): Whether to change rgb to bgr. Default: True. 104 | min_max (tuple[int]): min and max values for clamp. 105 | """ 106 | output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) 107 | output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 108 | output = output.type(torch.uint8).cpu().numpy() 109 | if rgb2bgr: 110 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 111 | return output 112 | 113 | 114 | def imfrombytes(content, flag='color', float32=False): 115 | """Read an image from bytes. 116 | 117 | Args: 118 | content (bytes): Image bytes got from files or other streams. 119 | flag (str): Flags specifying the color type of a loaded image, 120 | candidates are `color`, `grayscale` and `unchanged`. 121 | float32 (bool): Whether to change to float32., If True, will also norm 122 | to [0, 1]. Default: False. 123 | 124 | Returns: 125 | ndarray: Loaded image array. 126 | """ 127 | img_np = np.frombuffer(content, np.uint8) 128 | imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} 129 | img = cv2.imdecode(img_np, imread_flags[flag]) 130 | if float32: 131 | img = img.astype(np.float32) / 255. 132 | return img 133 | 134 | 135 | def imwrite(img, file_path, params=None, auto_mkdir=True): 136 | """Write image to file. 137 | 138 | Args: 139 | img (ndarray): Image array to be written. 140 | file_path (str): Image file path. 141 | params (None or list): Same as opencv's :func:`imwrite` interface. 142 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 143 | whether to create it automatically. 144 | 145 | Returns: 146 | bool: Successful or not. 
147 | """ 148 | if auto_mkdir: 149 | dir_name = os.path.abspath(os.path.dirname(file_path)) 150 | os.makedirs(dir_name, exist_ok=True) 151 | ok = cv2.imwrite(file_path, img, params) 152 | if not ok: 153 | raise IOError('Failed in writing images.') 154 | 155 | 156 | def crop_border(imgs, crop_border): 157 | """Crop borders of images. 158 | 159 | Args: 160 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 161 | crop_border (int): Crop border for each end of height and weight. 162 | 163 | Returns: 164 | list[ndarray]: Cropped images. 165 | """ 166 | if crop_border == 0: 167 | return imgs 168 | else: 169 | if isinstance(imgs, list): 170 | return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] 171 | else: 172 | return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] 173 | -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import time 5 | import torch 6 | from os import path as osp 7 | 8 | from .dist_util import master_only 9 | 10 | 11 | def set_random_seed(seed): 12 | """Set random seeds.""" 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | 20 | def get_time_str(): 21 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 22 | 23 | 24 | def mkdir_and_rename(path): 25 | """mkdirs. If path exists, rename it with timestamp and create a new one. 26 | 27 | Args: 28 | path (str): Folder path. 29 | """ 30 | if osp.exists(path): 31 | new_name = path + '_archived_' + get_time_str() 32 | print(f'Path already exists. Rename it to {new_name}', flush=True) 33 | os.rename(path, new_name) 34 | os.makedirs(path, exist_ok=True) 35 | 36 | 37 | @master_only 38 | def make_exp_dirs(opt): 39 | """Make dirs for experiments.""" 40 | path_opt = opt['path'].copy() 41 | if opt['is_train']: 42 | mkdir_and_rename(path_opt.pop('experiments_root')) 43 | else: 44 | mkdir_and_rename(path_opt.pop('results_root')) 45 | for key, path in path_opt.items(): 46 | if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key): 47 | continue 48 | else: 49 | os.makedirs(path, exist_ok=True) 50 | 51 | 52 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 53 | """Scan a directory to find the interested files. 54 | 55 | Args: 56 | dir_path (str): Path of the directory. 57 | suffix (str | tuple(str), optional): File suffix that we are 58 | interested in. Default: None. 59 | recursive (bool, optional): If set to True, recursively scan the 60 | directory. Default: False. 61 | full_path (bool, optional): If set to True, include the dir_path. 62 | Default: False. 63 | 64 | Returns: 65 | A generator for all the interested files with relative paths. 
66 | """ 67 | 68 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 69 | raise TypeError('"suffix" must be a string or tuple of strings') 70 | 71 | root = dir_path 72 | 73 | def _scandir(dir_path, suffix, recursive): 74 | for entry in os.scandir(dir_path): 75 | if not entry.name.startswith('.') and entry.is_file(): 76 | if full_path: 77 | return_path = entry.path 78 | else: 79 | return_path = osp.relpath(entry.path, root) 80 | 81 | if suffix is None: 82 | yield return_path 83 | elif return_path.endswith(suffix): 84 | yield return_path 85 | else: 86 | if recursive: 87 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 88 | else: 89 | continue 90 | 91 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 92 | 93 | 94 | def check_resume(opt, resume_iter): 95 | """Check resume states and pretrain_network paths. 96 | 97 | Args: 98 | opt (dict): Options. 99 | resume_iter (int): Resume iteration. 100 | """ 101 | if opt['path']['resume_state']: 102 | # get all the networks 103 | networks = [key for key in opt.keys() if key.startswith('network_')] 104 | flag_pretrain = False 105 | for network in networks: 106 | if opt['path'].get(f'pretrain_{network}') is not None: 107 | flag_pretrain = True 108 | if flag_pretrain: 109 | print('pretrain_network path will be ignored during resuming.') 110 | # set pretrained model paths 111 | for network in networks: 112 | name = f'pretrain_{network}' 113 | basename = network.replace('network_', '') 114 | if opt['path'].get('ignore_resume_networks') is None or (network 115 | not in opt['path']['ignore_resume_networks']): 116 | opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 117 | print(f"Set {name} to {opt['path'][name]}") 118 | 119 | # change param_key to params in resume 120 | param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] 121 | for param_key in param_keys: 122 | if opt['path'][param_key] == 'params_ema': 123 | opt['path'][param_key] = 'params' 124 | print(f'Set {param_key} to params') 125 | 126 | 127 | def sizeof_fmt(size, suffix='B'): 128 | """Get human readable file size. 129 | 130 | Args: 131 | size (int): File size. 132 | suffix (str): Suffix. Default: 'B'. 133 | 134 | Return: 135 | str: Formatted file size. 136 | """ 137 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 138 | if abs(size) < 1024.0: 139 | return f'{size:3.1f} {unit}{suffix}' 140 | size /= 1024.0 141 | return f'{size:3.1f} Y{suffix}' 142 | -------------------------------------------------------------------------------- /basicsr/utils/plot_util.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def read_data_from_tensorboard(log_path, tag): 5 | """Get raw data (steps and values) from tensorboard events. 6 | 7 | Args: 8 | log_path (str): Path to the tensorboard log. 9 | tag (str): tag to be read. 10 | """ 11 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator 12 | 13 | # tensorboard event 14 | event_acc = EventAccumulator(log_path) 15 | event_acc.Reload() 16 | scalar_list = event_acc.Tags()['scalars'] 17 | print('tag list: ', scalar_list) 18 | steps = [int(s.step) for s in event_acc.Scalars(tag)] 19 | values = [s.value for s in event_acc.Scalars(tag)] 20 | return steps, values 21 | 22 | 23 | def read_data_from_txt_2v(path, pattern, step_one=False): 24 | """Read data from txt with 2 returned values (usually [step, value]). 25 | 26 | Args: 27 | path (str): path to the txt file. 
28 | pattern (str): re (regular expression) pattern. 29 | step_one (bool): add 1 to steps. Default: False. 30 | """ 31 | with open(path) as f: 32 | lines = f.readlines() 33 | lines = [line.strip() for line in lines] 34 | steps = [] 35 | values = [] 36 | 37 | pattern = re.compile(pattern) 38 | for line in lines: 39 | match = pattern.match(line) 40 | if match: 41 | steps.append(int(match.group(1))) 42 | values.append(float(match.group(2))) 43 | if step_one: 44 | steps = [v + 1 for v in steps] 45 | return steps, values 46 | 47 | 48 | def read_data_from_txt_1v(path, pattern): 49 | """Read data from txt with 1 returned value. 50 | 51 | Args: 52 | path (str): path to the txt file. 53 | pattern (str): re (regular expression) pattern. 54 | """ 55 | with open(path) as f: 56 | lines = f.readlines() 57 | lines = [line.strip() for line in lines] 58 | data = [] 59 | 60 | pattern = re.compile(pattern) 61 | for line in lines: 62 | match = pattern.match(line) 63 | if match: 64 | data.append(float(match.group(1))) 65 | return data 66 | 67 | 68 | def smooth_data(values, smooth_weight): 69 | """ Smooth data using 1st-order IIR low-pass filter (what tensorflow does). 70 | 71 | Reference: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704  # noqa: E501 72 | 73 | Args: 74 | values (list): A list of values to be smoothed. 75 | smooth_weight (float): Smooth weight. 76 | """ 77 | values_sm = [] 78 | last_sm_value = values[0] 79 | for value in values: 80 | value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value 81 | values_sm.append(value_sm) 82 | last_sm_value = value_sm 83 | return values_sm 84 | -------------------------------------------------------------------------------- /basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py  # noqa: E501 2 | 3 | 4 | class Registry(): 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj, suffix=None): 39 | if isinstance(suffix, str): 40 | name = name + '_' + suffix 41 | 42 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 43 | f"in '{self._name}' registry!") 44 | self._obj_map[name] = obj 45 | 46 | def register(self, obj=None, suffix=None): 47 | """ 48 | Register the given object under the name `obj.__name__`. 49 | Can be used as either a decorator or not. 50 | See the docstring of this class, and the example below, for usage.
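Example:
    An illustrative registration (the registry and class names are
    hypothetical):

    >>> DEMO_REGISTRY = Registry('demo')
    >>> @DEMO_REGISTRY.register()
    ... class MyArch:
    ...     pass
    >>> arch_cls = DEMO_REGISTRY.get('MyArch')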
51 | """ 52 | if obj is None: 53 | # used as a decorator 54 | def deco(func_or_class): 55 | name = func_or_class.__name__ 56 | self._do_register(name, func_or_class, suffix) 57 | return func_or_class 58 | 59 | return deco 60 | 61 | # used as a function call 62 | name = obj.__name__ 63 | self._do_register(name, obj, suffix) 64 | 65 | def get(self, name, suffix='basicsr'): 66 | ret = self._obj_map.get(name) 67 | if ret is None: 68 | ret = self._obj_map.get(name + '_' + suffix) 69 | print(f'Name {name} is not found, use name: {name}_{suffix}!') 70 | if ret is None: 71 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 72 | return ret 73 | 74 | def __contains__(self, name): 75 | return name in self._obj_map 76 | 77 | def __iter__(self): 78 | return iter(self._obj_map.items()) 79 | 80 | def keys(self): 81 | return self._obj_map.keys() 82 | 83 | 84 | DATASET_REGISTRY = Registry('dataset') 85 | ARCH_REGISTRY = Registry('arch') 86 | MODEL_REGISTRY = Registry('model') 87 | LOSS_REGISTRY = Registry('loss') 88 | METRIC_REGISTRY = Registry('metric') 89 | -------------------------------------------------------------------------------- /datasets/datasets.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/datasets/datasets.txt -------------------------------------------------------------------------------- /doc/img/motivation.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/EQSR/e10a258ed81c3d2589fe773d7ca7421282347e30/doc/img/motivation.PNG -------------------------------------------------------------------------------- /options/test/test.yml: -------------------------------------------------------------------------------- 1 | name: DU_6 2 | model_type: HATModel 3 | scale: 8 4 | num_gpu: 1 # set num_gpu: 0 for cpu mode 5 | manual_seed: 0 6 | 7 | datasets: 8 | test_1: # the 1st test dataset 9 | name: one_pic 10 | type: DSF_val_Dataset 11 | dataroot_gt: /data3/GeoSR_data/benchmark/Set5/LR_bicubic/X4 12 | dataroot_lq: ./set5-eqs/set5x4-D/6 13 | io_backend: 14 | type: disk 15 | 16 | 17 | # network structures 18 | network_g: 19 | type: EQSR 20 | upscale: 4 21 | in_chans: 3 22 | img_size: 64 23 | window_size: 16 24 | compress_ratio: 3 25 | squeeze_factor: 30 26 | conv_scale: 0.01 27 | overlap_ratio: 0.5 28 | img_range: 1. 
29 | depths: [5, 5, 5, 5, 5, 5] 30 | embed_dim: 180 31 | num_heads: [6, 6, 6, 6, 6, 6] 32 | mlp_ratio: 2 33 | upsampler: 'pixelshuffle' 34 | resi_connection: '1conv' 35 | 36 | 37 | # path 38 | path: 39 | pretrain_network_g: /data3/KITTI/HAT/experiments/train_ModMBFormer_Sim_DSF_ImageNet_DF2K/models/net_g_latest.pth 40 | strict_load_g: true 41 | param_key_g: 'params_ema' 42 | 43 | # validation settings 44 | val: 45 | save_img: true 46 | suffix: ~ # add suffix to saved images, if None, use exp name 47 | 48 | -------------------------------------------------------------------------------- /options/test/test_scale.yml: -------------------------------------------------------------------------------- 1 | name: task_name 2 | model_type: HATModel 3 | scale: 5.75 4 | num_gpu: 1 # set num_gpu: 0 for cpu mode 5 | manual_seed: 0 6 | 7 | datasets: 8 | test_1: # the 1st test dataset 9 | name: Set5 10 | type: DSF_val_downsample_Dataset 11 | dataroot_gt: /data3/GeoSR_data/benchmark/Set5/HR 12 | # dataroot_lq: /data3/GeoSR_data/benchmark/Set5/LR_bicubic/X2 13 | io_backend: 14 | type: disk 15 | 16 | test_2: # the 2nd test dataset 17 | name: Set14 18 | type: DSF_val_downsample_Dataset 19 | dataroot_gt: /data3/GeoSR_data/benchmark/Set14/HR 20 | # dataroot_lq: /data3/GeoSR_data/benchmark/Set14/LR_bicubic/X2 21 | io_backend: 22 | type: disk 23 | 24 | test_3: 25 | name: Urban100 26 | type: DSF_val_downsample_Dataset 27 | dataroot_gt: /data3/GeoSR_data/benchmark/Urban100/HR 28 | # dataroot_lq: /data3/GeoSR_data/benchmark/Urban100/LR_bicubic/X2 29 | io_backend: 30 | type: disk 31 | 32 | test_4: 33 | name: B100 34 | type: DSF_val_downsample_Dataset 35 | dataroot_gt: /data3/GeoSR_data/benchmark/B100/HR 36 | # dataroot_lq: /data3/GeoSR_data/benchmark/B100/LR_bicubic/X2 37 | io_backend: 38 | type: disk 39 | 40 | test_5: 41 | name: Manga109 42 | type: DSF_val_downsample_Dataset 43 | dataroot_gt: datasets/benchmark/Manga109/HR 44 | # dataroot_lq: ./datasets/manga109/LRbicx4 45 | io_backend: 46 | type: disk 47 | 48 | # network structures 49 | network_g: 50 | type: EQSR 51 | upscale: 4 52 | in_chans: 3 53 | img_size: 64 54 | window_size: 16 55 | compress_ratio: 3 56 | squeeze_factor: 30 57 | conv_scale: 0.01 58 | overlap_ratio: 0.5 59 | img_range: 1. 
60 | depths: [5, 5, 5, 5, 5, 5] 61 | embed_dim: 180 62 | num_heads: [6, 6, 6, 6, 6, 6] 63 | mlp_ratio: 2 64 | upsampler: 'pixelshuffle' 65 | resi_connection: '1conv' 66 | 67 | 68 | # path 69 | path: 70 | pretrain_network_g: /data3/KITTI/HAT/experiments/train_ModMBFormer_Sim_DSF_ImageNet_DF2K_2/models/net_g_latest.pth 71 | strict_load_g: true 72 | param_key_g: 'params_ema' 73 | 74 | # validation settings 75 | val: 76 | save_img: false 77 | suffix: ~ # add suffix to saved images, if None, use exp name 78 | 79 | metrics: 80 | psnr: # metric name, can be arbitrary 81 | type: calculate_psnr 82 | crop_border: 4 83 | test_y_channel: true 84 | ssim: 85 | type: calculate_ssim 86 | crop_border: 4 87 | test_y_channel: true 88 | -------------------------------------------------------------------------------- /options/test/testx234.yml: -------------------------------------------------------------------------------- 1 | name: task_name 2 | model_type: HATModel 3 | scale: 2 4 | num_gpu: 1 # set num_gpu: 0 for cpu mode 5 | manual_seed: 0 6 | 7 | datasets: 8 | test_1: # the 1st test dataset 9 | name: Set5x2 10 | type: DSF_val_Dataset 11 | dataroot_gt: /data3/GeoSR_data/benchmark/Set5/HR 12 | dataroot_lq: /data3/GeoSR_data/benchmark/Set5/LR_bicubic/X2 13 | io_backend: 14 | type: disk 15 | 16 | test_2: # the 2nd test dataset 17 | name: Set14x2 18 | type: DSF_val_Dataset 19 | dataroot_gt: /data3/GeoSR_data/benchmark/Set14/HR 20 | dataroot_lq: /data3/GeoSR_data/benchmark/Set14/LR_bicubic/X2 21 | io_backend: 22 | type: disk 23 | 24 | test_3: 25 | name: Urban100x2 26 | type: DSF_val_Dataset 27 | dataroot_gt: /data3/GeoSR_data/benchmark/Urban100/HR 28 | dataroot_lq: /data3/GeoSR_data/benchmark/Urban100/LR_bicubic/X2 29 | io_backend: 30 | type: disk 31 | 32 | test_4: 33 | name: B100x2 34 | type: DSF_val_Dataset 35 | dataroot_gt: /data3/GeoSR_data/benchmark/B100/HR 36 | dataroot_lq: /data3/GeoSR_data/benchmark/B100/LR_bicubic/X2 37 | io_backend: 38 | type: disk 39 | 40 | test_401: 41 | name: Manga109x2 42 | type: DSF_val_Dataset 43 | dataroot_gt: /data3/GeoSR_data/benchmark/Manga109/HR 44 | dataroot_lq: /data3/GeoSR_data/benchmark/Manga109/LR_bicubic/X2 45 | io_backend: 46 | type: disk 47 | 48 | test_5: # the 1st test dataset 49 | name: Set5x3 50 | type: DSF_val_Dataset 51 | dataroot_gt: /data3/GeoSR_data/benchmark/Set5/HR 52 | dataroot_lq: /data3/GeoSR_data/benchmark/Set5/LR_bicubic/X3 53 | io_backend: 54 | type: disk 55 | 56 | test_6: # the 2nd test dataset 57 | name: Set14x3 58 | type: DSF_val_Dataset 59 | dataroot_gt: /data3/GeoSR_data/benchmark/Set14/HR 60 | dataroot_lq: /data3/GeoSR_data/benchmark/Set14/LR_bicubic/X3 61 | io_backend: 62 | type: disk 63 | 64 | test_7: 65 | name: Urban100x3 66 | type: DSF_val_Dataset 67 | dataroot_gt: /data3/GeoSR_data/benchmark/Urban100/HR 68 | dataroot_lq: /data3/GeoSR_data/benchmark/Urban100/LR_bicubic/X3 69 | io_backend: 70 | type: disk 71 | 72 | test_8: 73 | name: B100x3 74 | type: DSF_val_Dataset 75 | dataroot_gt: /data3/GeoSR_data/benchmark/B100/HR 76 | dataroot_lq: /data3/GeoSR_data/benchmark/B100/LR_bicubic/X3 77 | io_backend: 78 | type: disk 79 | 80 | test_801: 81 | name: Manga109x3 82 | type: DSF_val_Dataset 83 | dataroot_gt: /data3/GeoSR_data/benchmark/Manga109/HR 84 | dataroot_lq: /data3/GeoSR_data/benchmark/Manga109/LR_bicubic/X3 85 | io_backend: 86 | type: disk 87 | 88 | test_9: # the 1st test dataset 89 | name: Set5x4 90 | type: DSF_val_Dataset 91 | dataroot_gt: /data3/GeoSR_data/benchmark/Set5/HR 92 | dataroot_lq: /data3/GeoSR_data/benchmark/Set5/LR_bicubic/X4 
93 | io_backend: 94 | type: disk 95 | 96 | test_900: # the 2nd test dataset 97 | name: Set14x4 98 | type: DSF_val_Dataset 99 | dataroot_gt: /data3/GeoSR_data/benchmark/Set14/HR 100 | dataroot_lq: /data3/GeoSR_data/benchmark/Set14/LR_bicubic/X4 101 | io_backend: 102 | type: disk 103 | 104 | test_901: 105 | name: Urban100x4 106 | type: DSF_val_Dataset 107 | dataroot_gt: /data3/GeoSR_data/benchmark/Urban100/HR 108 | dataroot_lq: /data3/GeoSR_data/benchmark/Urban100/LR_bicubic/X4 109 | io_backend: 110 | type: disk 111 | 112 | test_902: 113 | name: B100x4 114 | type: DSF_val_Dataset 115 | dataroot_gt: /data3/GeoSR_data/benchmark/B100/HR 116 | dataroot_lq: /data3/GeoSR_data/benchmark/B100/LR_bicubic/X4 117 | io_backend: 118 | type: disk 119 | 120 | test_903: 121 | name: Manga109x4 122 | type: DSF_val_Dataset 123 | dataroot_gt: /data3/GeoSR_data/benchmark/Manga109/HR 124 | dataroot_lq: /data3/GeoSR_data/benchmark/Manga109/LR_bicubic/X4 125 | io_backend: 126 | type: disk 127 | 128 | # network structures 129 | network_g: 130 | type: EQSR 131 | upscale: 3 132 | in_chans: 3 133 | img_size: 64 134 | window_size: 16 135 | compress_ratio: 3 136 | squeeze_factor: 30 137 | conv_scale: 0.01 138 | overlap_ratio: 0.5 139 | img_range: 1. 140 | depths: [5, 5, 5, 5, 5, 5] 141 | embed_dim: 180 142 | num_heads: [6, 6, 6, 6, 6, 6] 143 | mlp_ratio: 2 144 | upsampler: 'pixelshuffle' 145 | resi_connection: '1conv' 146 | 147 | 148 | # path 149 | path: 150 | pretrain_network_g: ./experiments/train_ModMBFormer_Sim_DSF_ImageNet_DF2K/models/net_g_latest.pth 151 | strict_load_g: true 152 | param_key_g: 'params_ema' 153 | 154 | # validation settings 155 | val: 156 | save_img: false 157 | suffix: ~ # add suffix to saved images, if None, use exp name 158 | 159 | metrics: 160 | psnr: # metric name, can be arbitrary 161 | type: calculate_psnr 162 | crop_border: 4 163 | test_y_channel: true 164 | ssim: 165 | type: calculate_ssim 166 | crop_border: 4 167 | test_y_channel: true 168 | -------------------------------------------------------------------------------- /options/test/testx6.yml: -------------------------------------------------------------------------------- 1 | name: task_name 2 | model_type: HATModel 3 | scale: 6 4 | num_gpu: 1 # set num_gpu: 0 for cpu mode 5 | manual_seed: 0 6 | 7 | datasets: 8 | 9 | train: 10 | name: DF2K 11 | type: DSF_DF2K_Dataset 12 | dataroot_gt: /data3/GeoSR_data/benchmark/Urban100/HR 13 | dataroot_lq: /data3/GeoSR_data/benchmark/Urban100/LR_bicubic/X2 14 | io_backend: 15 | type: disk 16 | 17 | patch_size: 32 18 | 19 | # data loader 20 | use_shuffle: true 21 | num_worker_per_gpu: 6 22 | batch_size_per_gpu: 4 23 | dataset_enlarge_ratio: 1 24 | prefetch_mode: ~ 25 | 26 | 27 | test_1: # the 1st test dataset 28 | name: Set5 29 | type: DSF_val_downsample_Dataset 30 | dataroot_gt: /data3/GeoSR_data/benchmark/Set5/HR 31 | # dataroot_lq: /data3/GeoSR_data/benchmark/Set5/LR_bicubic/X2 32 | io_backend: 33 | type: disk 34 | 35 | test_2: # the 2nd test dataset 36 | name: Set14 37 | type: DSF_val_downsample_Dataset 38 | dataroot_gt: /data3/GeoSR_data/benchmark/Set14/HR 39 | # dataroot_lq: /data3/GeoSR_data/benchmark/Set14/LR_bicubic/X2 40 | io_backend: 41 | type: disk 42 | 43 | test_3: 44 | name: Urban100 45 | type: DSF_val_downsample_Dataset 46 | dataroot_gt: /data3/GeoSR_data/benchmark/Urban100/HR 47 | # dataroot_lq: /data3/GeoSR_data/benchmark/Urban100/LR_bicubic/X2 48 | io_backend: 49 | type: disk 50 | 51 | test_4: 52 | name: B100 53 | type: DSF_val_downsample_Dataset 54 | dataroot_gt: 
/data3/GeoSR_data/benchmark/B100/HR 55 | # dataroot_lq: /data3/GeoSR_data/benchmark/B100/LR_bicubic/X2 56 | io_backend: 57 | type: disk 58 | 59 | test_5: 60 | name: Manga109 61 | type: DSF_val_downsample_Dataset 62 | dataroot_gt: datasets/benchmark/Manga109/HR 63 | io_backend: 64 | type: disk 65 | 66 | # network structures 67 | network_g: 68 | type: EQSR 69 | upscale: 4 70 | in_chans: 3 71 | img_size: 64 72 | window_size: 16 73 | compress_ratio: 3 74 | squeeze_factor: 30 75 | conv_scale: 0.01 76 | overlap_ratio: 0.5 77 | img_range: 1. 78 | depths: [5, 5, 5, 5, 5, 5] 79 | embed_dim: 180 80 | num_heads: [6, 6, 6, 6, 6, 6] 81 | mlp_ratio: 2 82 | upsampler: 'pixelshuffle' 83 | resi_connection: '1conv' 84 | 85 | 86 | # path 87 | path: 88 | pretrain_network_g: /data3/KITTI/HAT/experiments/train_ModMBFormer_Sim_DSF_DF2K_DF2K_only/models/net_g_latest.pth 89 | strict_load_g: true 90 | param_key_g: 'params_ema' 91 | 92 | # validation settings 93 | val: 94 | save_img: false 95 | suffix: ~ # add suffix to saved images, if None, use exp name 96 | 97 | metrics: 98 | psnr: # metric name, can be arbitrary 99 | type: calculate_psnr 100 | crop_border: 6 101 | test_y_channel: true 102 | ssim: 103 | type: calculate_ssim 104 | crop_border: 6 105 | test_y_channel: true 106 | -------------------------------------------------------------------------------- /options/test/testx8.yml: -------------------------------------------------------------------------------- 1 | name: task_name 2 | model_type: HATModel 3 | scale: 8 4 | num_gpu: 1 # set num_gpu: 0 for cpu mode 5 | manual_seed: 0 6 | 7 | datasets: 8 | test_1: # the 1st test dataset 9 | name: Set5 10 | type: DSF_val_downsample_Dataset 11 | dataroot_gt: /data3/GeoSR_data/benchmark/Set5/HR 12 | # dataroot_lq: /data3/GeoSR_data/benchmark/Set5/LR_bicubic/X2 13 | io_backend: 14 | type: disk 15 | 16 | test_2: # the 2nd test dataset 17 | name: Set14 18 | type: DSF_val_downsample_Dataset 19 | dataroot_gt: /data3/GeoSR_data/benchmark/Set14/HR 20 | # dataroot_lq: /data3/GeoSR_data/benchmark/Set14/LR_bicubic/X2 21 | io_backend: 22 | type: disk 23 | 24 | test_3: 25 | name: Urban100 26 | type: DSF_val_downsample_Dataset 27 | dataroot_gt: /data3/GeoSR_data/benchmark/Urban100/HR 28 | # dataroot_lq: /data3/GeoSR_data/benchmark/Urban100/LR_bicubic/X2 29 | io_backend: 30 | type: disk 31 | 32 | test_4: 33 | name: B100 34 | type: DSF_val_downsample_Dataset 35 | dataroot_gt: /data3/GeoSR_data/benchmark/B100/HR 36 | # dataroot_lq: /data3/GeoSR_data/benchmark/B100/LR_bicubic/X2 37 | io_backend: 38 | type: disk 39 | 40 | test_5: 41 | name: Manga109 42 | type: DSF_val_downsample_Dataset 43 | dataroot_gt: datasets/benchmark/Manga109/HR 44 | # dataroot_lq: ./datasets/manga109/LRbicx4 45 | io_backend: 46 | type: disk 47 | 48 | # network structures 49 | network_g: 50 | type: EQSR 51 | upscale: 4 52 | in_chans: 3 53 | img_size: 64 54 | window_size: 16 55 | compress_ratio: 3 56 | squeeze_factor: 30 57 | conv_scale: 0.01 58 | overlap_ratio: 0.5 59 | img_range: 1. 
60 | depths: [5, 5, 5, 5, 5, 5] 61 | embed_dim: 180 62 | num_heads: [6, 6, 6, 6, 6, 6] 63 | mlp_ratio: 2 64 | upsampler: 'pixelshuffle' 65 | resi_connection: '1conv' 66 | 67 | 68 | # path 69 | path: 70 | pretrain_network_g: /data3/KITTI/HAT/experiments/train_ModMBFormer_Sim_DSF_DF2K_DF2K_only/models/net_g_latest.pth 71 | strict_load_g: true 72 | param_key_g: 'params_ema' 73 | 74 | # validation settings 75 | val: 76 | save_img: false 77 | suffix: ~ # add suffix to saved images, if None, use exp name 78 | 79 | metrics: 80 | psnr: # metric name, can be arbitrary 81 | type: calculate_psnr 82 | crop_border: 10 83 | test_y_channel: true 84 | ssim: 85 | type: calculate_ssim 86 | crop_border: 10 87 | test_y_channel: true 88 | -------------------------------------------------------------------------------- /options/train/train_EQSR_ImageNet_from_scratch.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: train_ModMBFormer_Sim_DSF_ImageNet_from_scratch 3 | model_type: HATModel 4 | scale: 8 5 | num_gpu: auto 6 | manual_seed: 0 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: ImageNet 12 | type: DSF_imagenet_Dataset 13 | dataroot_gt: datasets/ImageNet/train 14 | io_backend: 15 | type: disk 16 | 17 | patch_size: 48 18 | use_hflip: true 19 | use_rot: true 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 4 24 | batch_size_per_gpu: 8 25 | dataset_enlarge_ratio: 1 26 | prefetch_mode: ~ 27 | 28 | val_1: 29 | name: Set5x4 30 | type: DSF_val_Dataset 31 | dataroot_gt: datasets/benchmark/Set5/HR 32 | dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 33 | io_backend: 34 | type: disk 35 | 36 | val_2: 37 | name: Set5x8 38 | type: DSF_val_downsample_Dataset 39 | dataroot_gt: datasets/benchmark/Set5/HR 40 | io_backend: 41 | type: disk 42 | 43 | # val_3: 44 | # name: Urban100 45 | # type: PairedImageDataset 46 | # dataroot_gt: ./datasets/urban100/GTmod3 47 | # dataroot_lq: ./datasets/urban100/LRbicx3 48 | # io_backend: 49 | # type: disk 50 | 51 | 52 | # network structures 53 | network_g: 54 | type: EQSR 55 | upscale: 3 56 | in_chans: 3 57 | img_size: 64 58 | window_size: 16 59 | compress_ratio: 3 60 | squeeze_factor: 30 61 | conv_scale: 0.01 62 | overlap_ratio: 0.5 63 | img_range: 1. 64 | depths: [5, 5, 5, 5, 5, 5] 65 | embed_dim: 180 66 | num_heads: [6, 6, 6, 6, 6, 6] 67 | mlp_ratio: 2 68 | upsampler: 'pixelshuffle' 69 | resi_connection: '1conv' 70 | 71 | # path 72 | path: 73 | pretrain_network_g: ~ 74 | strict_load_g: true 75 | resume_state: ~ 76 | 77 | # training settings 78 | train: 79 | ema_decay: 0.999 80 | optim_g: 81 | type: Adam 82 | lr: !!float 2e-4 83 | weight_decay: 0 84 | betas: [0.9, 0.99] 85 | 86 | scheduler: 87 | type: MultiStepLR 88 | milestones: [300000, 500000, 650000, 700000, 750000] 89 | gamma: 0.5 90 | 91 | total_iter: 800000 92 | warmup_iter: -1 # no warm up 93 | 94 | # losses 95 | pixel_opt: 96 | type: L1Loss 97 | loss_weight: 1.0 98 | reduction: mean 99 | 100 | # validation settings 101 | val: 102 | val_freq: !!float 1e3 103 | save_img: false 104 | pbar: False 105 | 106 | metrics: 107 | psnr: 108 | type: calculate_psnr 109 | crop_border: 3 110 | test_y_channel: true 111 | better: higher # the higher, the better. Default: higher 112 | ssim: 113 | type: calculate_ssim 114 | crop_border: 3 115 | test_y_channel: true 116 | better: higher # the higher, the better. 
Default: higher 117 | 118 | # logging settings 119 | logger: 120 | print_freq: 100 121 | save_checkpoint_freq: !!float 1e4 122 | use_tb_logger: true 123 | wandb: 124 | project: ~ 125 | resume_id: ~ 126 | 127 | # dist training settings 128 | dist_params: 129 | backend: nccl 130 | port: 29500 131 | -------------------------------------------------------------------------------- /options/train/train_EQSR_finetune_from_ImageNet_pretrain.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: train_ModMBFormer_Sim_DSF_ImageNet_DF2K 3 | model_type: HATModel 4 | scale: 8 5 | num_gpu: auto 6 | manual_seed: 0 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: DF2K 12 | type: DSF_DF2K_Dataset 13 | dataroot_gt: datasets/DF2K_HR 14 | io_backend: 15 | type: disk 16 | 17 | patch_size: 48 18 | use_hflip: true 19 | use_rot: true 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 0 24 | batch_size_per_gpu: 8 25 | dataset_enlarge_ratio: 20 26 | prefetch_mode: ~ 27 | 28 | val_1: 29 | name: Set5x4 30 | type: DSF_val_Dataset 31 | dataroot_gt: datasets/benchmark/Set5/HR 32 | dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 33 | io_backend: 34 | type: disk 35 | 36 | val_2: 37 | name: Set5x8 38 | type: DSF_val_downsample_Dataset 39 | dataroot_gt: datasets/benchmark/Set5/HR 40 | # dataroot_lq: ./datasets/Set14/LRbicx2 41 | io_backend: 42 | type: disk 43 | 44 | # val_3: 45 | # name: Urban100 46 | # type: PairedImageDataset 47 | # dataroot_gt: ./datasets/urban100/GTmod2 48 | # dataroot_lq: ./datasets/urban100/LRbicx2 49 | # io_backend: 50 | # type: disk 51 | 52 | 53 | # network structures 54 | network_g: 55 | type: EQSR 56 | upscale: 3 57 | in_chans: 3 58 | img_size: 64 59 | window_size: 16 60 | compress_ratio: 3 61 | squeeze_factor: 30 62 | conv_scale: 0.01 63 | overlap_ratio: 0.5 64 | img_range: 1. 65 | depths: [5, 5, 5, 5, 5, 5] 66 | embed_dim: 180 67 | num_heads: [6, 6, 6, 6, 6, 6] 68 | mlp_ratio: 2 69 | upsampler: 'pixelshuffle' 70 | resi_connection: '1conv' 71 | 72 | # path 73 | path: 74 | pretrain_network_g: ./experiments/train_ModMBFormer_Sim_DSF_ImageNet_from_scratch/models/net_g_latest.pth 75 | param_key_g: 'params_ema' 76 | strict_load_g: true 77 | resume_state: ~ 78 | 79 | # training settings 80 | train: 81 | ema_decay: 0.999 82 | optim_g: 83 | type: Adam 84 | lr: !!float 1e-5 85 | weight_decay: 0 86 | betas: [0.9, 0.99] 87 | 88 | scheduler: 89 | type: MultiStepLR 90 | milestones: [125000, 200000, 225000, 240000] 91 | gamma: 0.5 92 | 93 | total_iter: 250000 94 | warmup_iter: -1 # no warm up 95 | 96 | # losses 97 | pixel_opt: 98 | type: L1Loss 99 | loss_weight: 1.0 100 | reduction: mean 101 | 102 | # validation settings 103 | val: 104 | val_freq: !!float 1e3 105 | save_img: false 106 | pbar: False 107 | 108 | metrics: 109 | psnr: 110 | type: calculate_psnr 111 | crop_border: 3 112 | test_y_channel: true 113 | better: higher # the higher, the better. Default: higher 114 | ssim: 115 | type: calculate_ssim 116 | crop_border: 3 117 | test_y_channel: true 118 | better: higher # the higher, the better. 
Default: higher 119 | 120 | # logging settings 121 | logger: 122 | print_freq: 100 123 | save_checkpoint_freq: !!float 1e4 124 | use_tb_logger: true 125 | wandb: 126 | project: ~ 127 | resume_id: ~ 128 | 129 | # dist training settings 130 | dist_params: 131 | backend: nccl 132 | port: 29500 133 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tempfile 3 | import shutil 4 | import os 5 | from PIL import Image 6 | import subprocess 7 | from cog import BasePredictor, Input, Path 8 | 9 | 10 | class Predictor(BasePredictor): 11 | def predict( 12 | self, 13 | image: Path = Input( 14 | description="Input Image.", 15 | ), 16 | ) -> Path: 17 | input_dir = "input_dir" 18 | output_path = Path(tempfile.mkdtemp()) / "output.png" 19 | 20 | try: 21 | # remove stale working dirs from a previous run, then recreate the input dir 22 | for d in [input_dir, "results"]: 23 | if os.path.exists(d): 24 | shutil.rmtree(d) 25 | os.makedirs(input_dir, exist_ok=False) 26 | input_path = os.path.join(input_dir, os.path.basename(image)) 27 | shutil.copy(str(image), input_path) 28 | subprocess.call( 29 | [ 30 | "python", 31 | "hat/test.py", 32 | "-opt", 33 | "options/test/HAT_SRx4_ImageNet-LR.yml", 34 | ] 35 | ) 36 | res_dir = os.path.join( 37 | "results", "HAT_SRx4_ImageNet-LR", "visualization", "custom" 38 | ) 39 | assert ( 40 | len(os.listdir(res_dir)) == 1 41 | ), "Should contain only one result for Single prediction." 42 | res = Image.open(os.path.join(res_dir, os.listdir(res_dir)[0])) 43 | res.save(str(output_path)) 44 | 45 | finally: 46 | # always clean up the working dirs, even if prediction failed 47 | shutil.rmtree(input_dir, ignore_errors=True) 48 | shutil.rmtree("results", ignore_errors=True) 49 | 50 | return output_path 51 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import os.path as osp 3 | 4 | import EQSR.archs 5 | import EQSR.data 6 | import EQSR.models 7 | from basicsr.test import test_pipeline 8 | 9 | if __name__ == '__main__': 10 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 11 | test_pipeline(root_path) 12 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import os.path as osp 3 | import torch 4 | print(torch.version.cuda)  # sanity check: report the CUDA version this PyTorch build was compiled against 5 | 6 | import EQSR.archs 7 | import EQSR.data 8 | import EQSR.models 9 | from basicsr.train import train_pipeline 10 | 11 | if __name__ == '__main__': 12 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 13 | train_pipeline(root_path) 14 | --------------------------------------------------------------------------------
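Usage sketch: with the entry points above, runs are typically launched as follows. The option files are the ones shipped under options/; the distributed flags follow the stock BasicSR launcher conventions, and the GPU count, port and option files are illustrative and should be adapted to your environment.

python train.py -opt options/train/train_EQSR_ImageNet_from_scratch.yml
python test.py -opt options/test/testx234.yml
# multi-GPU (e.g. 4 cards) via the PyTorch launcher:
python -m torch.distributed.launch --nproc_per_node=4 --master_port=29500 train.py -opt options/train/train_EQSR_finetune_from_ImageNet_pretrain.yml --launcher pytorch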