├── basicsr
│   ├── ops
│   │   ├── __init__.py
│   │   ├── upfirdn2d
│   │   │   ├── __init__.py
│   │   │   ├── src
│   │   │   │   └── upfirdn2d.cpp
│   │   │   └── upfirdn2d.py
│   │   ├── fused_act
│   │   │   ├── __init__.py
│   │   │   ├── src
│   │   │   │   ├── fused_bias_act.cpp
│   │   │   │   └── fused_bias_act_kernel.cu
│   │   │   └── fused_act.py
│   │   └── dcn
│   │       ├── __init__.py
│   │       └── src
│   │           └── deform_conv_ext.cpp
│   ├── losses
│   │   ├── lpips
│   │   │   ├── weights
│   │   │   │   ├── v0.0
│   │   │   │   │   ├── vgg.pth
│   │   │   │   │   ├── alex.pth
│   │   │   │   │   └── squeeze.pth
│   │   │   │   └── v0.1
│   │   │   │       ├── vgg.pth
│   │   │   │       ├── alex.pth
│   │   │   │       └── squeeze.pth
│   │   │   ├── __init__.py
│   │   │   ├── pretrained_networks.py
│   │   │   └── lpips.py
│   │   ├── __init__.py
│   │   └── loss_util.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── metric_util.py
│   │   └── psnr_ssim.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── registry.py
│   │   ├── dist_util.py
│   │   ├── download_util.py
│   │   ├── options.py
│   │   ├── video_util.py
│   │   ├── misc.py
│   │   ├── file_client.py
│   │   ├── img_util.py
│   │   ├── logger.py
│   │   ├── lmdb_util.py
│   │   └── util.py
│   ├── archs
│   │   ├── __init__.py
│   │   ├── rrdbnet_arch.py
│   │   ├── AIEM.py
│   │   ├── DRSW_arch.py
│   │   └── arcface_arch.py
│   ├── models
│   │   ├── __init__.py
│   │   └── lr_scheduler.py
│   ├── data
│   │   ├── data_sampler.py
│   │   ├── prefetch_dataloader.py
│   │   ├── __init__.py
│   │   ├── paired_image_dataset.py
│   │   └── transforms.py
│   ├── options
│   │   ├── test.yml
│   │   ├── VarFormer_train_stage1.yml
│   │   ├── options.py
│   │   └── VarFormer_train_stage2.yml
│   ├── varformer_test.py
│   ├── pretrain.py
│   └── train.py
├── docs
│   └── fig1_4_00.png
└── README.md

/basicsr/ops/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/docs/fig1_4_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siywang541/Varformer/HEAD/docs/fig1_4_00.png
--------------------------------------------------------------------------------
/basicsr/ops/upfirdn2d/__init__.py:
--------------------------------------------------------------------------------
1 | from .upfirdn2d import upfirdn2d
2 | 
3 | __all__ = ['upfirdn2d']
4 | 
--------------------------------------------------------------------------------
/basicsr/losses/lpips/weights/v0.0/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siywang541/Varformer/HEAD/basicsr/losses/lpips/weights/v0.0/vgg.pth
--------------------------------------------------------------------------------
/basicsr/losses/lpips/weights/v0.1/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siywang541/Varformer/HEAD/basicsr/losses/lpips/weights/v0.1/vgg.pth
--------------------------------------------------------------------------------
/basicsr/losses/lpips/weights/v0.0/alex.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siywang541/Varformer/HEAD/basicsr/losses/lpips/weights/v0.0/alex.pth
--------------------------------------------------------------------------------
/basicsr/losses/lpips/weights/v0.1/alex.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siywang541/Varformer/HEAD/basicsr/losses/lpips/weights/v0.1/alex.pth
--------------------------------------------------------------------------------
/basicsr/losses/lpips/weights/v0.0/squeeze.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siywang541/Varformer/HEAD/basicsr/losses/lpips/weights/v0.0/squeeze.pth
--------------------------------------------------------------------------------
/basicsr/losses/lpips/weights/v0.1/squeeze.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siywang541/Varformer/HEAD/basicsr/losses/lpips/weights/v0.1/squeeze.pth
--------------------------------------------------------------------------------
/basicsr/ops/fused_act/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | 
3 | __all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
4 | 
--------------------------------------------------------------------------------
/basicsr/ops/dcn/__init__.py:
--------------------------------------------------------------------------------
1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
2 |                           modulated_deform_conv)
3 | 
4 | __all__ = [
5 |     'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
6 |     'modulated_deform_conv'
7 | ]
8 | 
--------------------------------------------------------------------------------
/basicsr/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | 
3 | from utils.registry import METRIC_REGISTRY
4 | from .psnr_ssim import calculate_psnr, calculate_ssim
5 | 
6 | __all__ = ['calculate_psnr', 'calculate_ssim']
7 | 
8 | 
9 | def calculate_metric(data, opt):
10 |     """Calculate metric from data and options.
11 | 
12 |     Args:
13 |         opt (dict): Configuration. It must contain:
14 |             type (str): Model type.
15 | """ 16 | opt = deepcopy(opt) 17 | metric_type = opt.pop('type') 18 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) 19 | return metric 20 | -------------------------------------------------------------------------------- /basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .file_client import FileClient 2 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img 3 | from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger 4 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt 5 | 6 | __all__ = [ 7 | # file_client.py 8 | 'FileClient', 9 | # img_util.py 10 | 'img2tensor', 11 | 'tensor2img', 12 | 'imfrombytes', 13 | 'imwrite', 14 | 'crop_border', 15 | # logger.py 16 | 'MessageLogger', 17 | 'init_tb_logger', 18 | 'init_wandb_logger', 19 | 'get_root_logger', 20 | 'get_env_info', 21 | # misc.py 22 | 'set_random_seed', 23 | 'get_time_str', 24 | 'mkdir_and_rename', 25 | 'make_exp_dirs', 26 | 'scandir', 27 | 'check_resume', 28 | 'sizeof_fmt' 29 | ] 30 | -------------------------------------------------------------------------------- /basicsr/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from utils import get_root_logger, scandir 6 | from utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ['build_network'] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with 12 | # '_arch.py' 13 | arch_folder = osp.dirname(osp.abspath(__file__)) 14 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 15 | # import all the arch modules 16 | _arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames] 17 | 18 | 19 | def build_network(opt): 20 | opt = deepcopy(opt) 21 | network_type = opt.pop('type') 22 | net = ARCH_REGISTRY.get(network_type)(**opt) 23 | logger = get_root_logger() 24 | logger.info(f'Network [{net.__class__.__name__}] is created.') 25 | return net 26 | -------------------------------------------------------------------------------- /basicsr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from utils import get_root_logger 4 | from utils.registry import LOSS_REGISTRY 5 | from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, PerceptualLoss2, WeightedTVLoss, g_path_regularize, 6 | gradient_penalty_loss, r1_penalty,PSNRLoss,SSIMLoss) 7 | 8 | __all__ = [ 9 | 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'PerceptualLoss2', 'GANLoss', 'gradient_penalty_loss', 10 | 'r1_penalty', 'g_path_regularize', 'PSNRLoss', 'SSIMLoss' 11 | ] 12 | 13 | 14 | def build_loss(opt): 15 | """Build loss from options. 16 | 17 | Args: 18 | opt (dict): Configuration. It must constain: 19 | type (str): Model type. 
20 | """ 21 | opt = deepcopy(opt) 22 | loss_type = opt.pop('type') 23 | loss = LOSS_REGISTRY.get(loss_type)(**opt) 24 | logger = get_root_logger() 25 | logger.info(f'Loss [{loss.__class__.__name__}] is created.') 26 | return loss 27 | -------------------------------------------------------------------------------- /basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from utils import get_root_logger, scandir 6 | from utils.registry import MODEL_REGISTRY 7 | 8 | __all__ = ['build_model'] 9 | 10 | # automatically scan and import model modules for registry 11 | # scan all the files under the 'models' folder and collect files ending with 12 | # '_model.py' 13 | model_folder = osp.dirname(osp.abspath(__file__)) 14 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 15 | # import all the model modules 16 | _model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames] 17 | 18 | 19 | def build_model(opt): 20 | """Build model from options. 21 | 22 | Args: 23 | opt (dict): Configuration. It must constain: 24 | model_type (str): Model type. 25 | """ 26 | opt = deepcopy(opt) 27 | model = MODEL_REGISTRY.get(opt['model_type'])(opt) 28 | logger = get_root_logger() 29 | logger.info(f'Model [{model.__class__.__name__}] is created.') 30 | return model 31 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/src/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp 2 | #include 3 | 4 | 5 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 6 | int up_x, int up_y, int down_x, int down_y, 7 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 11 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 12 | 13 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 14 | int up_x, int up_y, int down_x, int down_y, 15 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 16 | CHECK_CUDA(input); 17 | CHECK_CUDA(kernel); 18 | 19 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 24 | } 25 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp 2 | #include 3 | 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, 6 | const torch::Tensor& bias, 7 | const torch::Tensor& refer, 8 | int act, int grad, float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | torch::Tensor fused_bias_act(const torch::Tensor& input, 15 | const 
torch::Tensor& bias, 16 | const torch::Tensor& refer, 17 | int act, int grad, float alpha, float scale) { 18 | CHECK_CUDA(input); 19 | CHECK_CUDA(bias); 20 | 21 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 22 | } 23 | 24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 25 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 26 | } 27 | -------------------------------------------------------------------------------- /basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from utils.matlab_functions import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img, input_order='HWC'): 7 | """Reorder images to 'HWC' order. 8 | 9 | If the input_order is (h, w), return (h, w, 1); 10 | If the input_order is (c, h, w), return (h, w, c); 11 | If the input_order is (h, w, c), return as it is. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 16 | If the input image shape is (h, w), input_order will not have 17 | effects. Default: 'HWC'. 18 | 19 | Returns: 20 | ndarray: reordered image. 21 | """ 22 | 23 | if input_order not in ['HWC', 'CHW']: 24 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'") 25 | if len(img.shape) == 2: 26 | img = img[..., None] 27 | if input_order == 'CHW': 28 | img = img.transpose(1, 2, 0) 29 | return img 30 | 31 | 32 | def to_y_channel(img): 33 | """Change to Y channel of YCbCr. 34 | 35 | Args: 36 | img (ndarray): Images with range [0, 255]. 37 | 38 | Returns: 39 | (ndarray): Images with range [0, 255] (float type) without round. 40 | """ 41 | img = img.astype(np.float32) / 255. 42 | if img.ndim == 3 and img.shape[2] == 3: 43 | img = bgr2ycbcr(img, y_only=True) 44 | img = img[..., None] 45 | return img * 255. 46 | -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restart the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1. 
19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | def __iter__(self): 30 | # deterministically shuffle based on epoch 31 | g = torch.Generator() 32 | g.manual_seed(self.epoch) 33 | indices = torch.randperm(self.total_size, generator=g).tolist() 34 | 35 | dataset_size = len(self.dataset) 36 | indices = [v % dataset_size for v in indices] 37 | 38 | # subsample 39 | indices = indices[self.rank:self.total_size:self.num_replicas] 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | def set_epoch(self, epoch): 48 | self.epoch = epoch 49 | -------------------------------------------------------------------------------- /basicsr/options/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | suffix: ~ # add suffix to saved images 3 | distortion: sr 4 | scale: 1 5 | crop_border: 0 # crop border when evaluation. If None(~), crop the scale pixels 6 | input_size: 256 7 | mean: [0.5, 0.5, 0.5] 8 | std: [0.5, 0.5, 0.5] 9 | 10 | 11 | # network structures 12 | network_g: 13 | type: VarFormer2 # VarFormer5 #VarFormer1 14 | depth: 16 #24 #20 15 | patch_nums: [1, 2, 3, 4, 5, 6, 8, 10, 13, 16] 16 | attn_l2_norm: True 17 | var_force_dpth: 7 18 | n_layers: 4 19 | if_enhance: True 20 | dec_adjust: False 21 | var_cross_c: 3 22 | ch_mult: [1, 1, 2, 2, 4] 23 | num_res_blocks: 2 24 | 25 | 26 | fix_modules: [] 27 | vqgan_path: ./experiments/pretrained_models/vae_ch160v4096z32.pth 28 | var_path: ./experiments/pretrained_models/var_d16.pth # keep the same with depth 29 | 30 | 31 | datasets: 32 | test_0: 33 | name: low 34 | mode: LQGTDataset 35 | dataroot_GT: /home/wangsy/dataset/llight/LOLdataset/eval15/high 36 | dataroot_LQ: /home/wangsy/dataset/llight/LOLdataset/eval15/low 37 | test_1: 38 | name: test100 39 | mode: LQGTDataset 40 | dataroot_GT: /home/wangsy/dataset/derain/test/Test100/target 41 | dataroot_LQ: /home/wangsy/dataset/derain/test/Test100/input 42 | test_2: 43 | name: haze 44 | mode: LQGTDataset3 45 | dataroot_GT: /home/wangsy/dataset/dehaze/SOTS/outdoor/gt 46 | dataroot_LQ: /home/wangsy/dataset/dehaze/SOTS/outdoor/hazy 47 | test_3: 48 | name: urban100_0_50 49 | mode: LQGTDataset 50 | dataroot_GT: /home/wangsy/dataset/denoise/urban100_0_50/target 51 | dataroot_LQ: /home/wangsy/dataset/denoise/urban100_0_50/input 52 | test_4: 53 | name: GoPro 54 | mode: LQGTDataset 55 | dataroot_GT: /home/wangsy/dataset/deblur/GoPro/test/groundtruth 56 | dataroot_LQ: /home/wangsy/dataset/deblur/GoPro/test/input 57 | test_5: 58 | name: SIDD 59 | mode: LQGTDataset 60 | dataroot_GT: /home/wangsy/dataset/denoise/SIDD/val/target 61 | dataroot_LQ: /home/wangsy/dataset/denoise/SIDD/val/target 62 | 63 | #### path 64 | path: 65 | root: ./experiments/ 66 | pretrain_model: ./experiments/pretrained_models/net_g_last.pth 67 | strict_load: false 68 | -------------------------------------------------------------------------------- /basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry(): 5 | """ 6 | The registry that 
provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj): 39 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 40 | f"in '{self._name}' registry!") 41 | self._obj_map[name] = obj 42 | 43 | def register(self, obj=None): 44 | """ 45 | Register the given object under the the name `obj.__name__`. 46 | Can be used as either a decorator or not. 47 | See docstring of this class for usage. 48 | """ 49 | if obj is None: 50 | # used as a decorator 51 | def deco(func_or_class): 52 | name = func_or_class.__name__ 53 | self._do_register(name, func_or_class) 54 | return func_or_class 55 | 56 | return deco 57 | 58 | # used as a function call 59 | name = obj.__name__ 60 | self._do_register(name, obj) 61 | 62 | def get(self, name): 63 | ret = self._obj_map.get(name) 64 | if ret is None: 65 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 66 | return ret 67 | 68 | def __contains__(self, name): 69 | return name in self._obj_map 70 | 71 | def __iter__(self): 72 | return iter(self._obj_map.items()) 73 | 74 | def keys(self): 75 | return self._obj_map.keys() 76 | 77 | 78 | DATASET_REGISTRY = Registry('dataset') 79 | ARCH_REGISTRY = Registry('arch') 80 | MODEL_REGISTRY = Registry('model') 81 | LOSS_REGISTRY = Registry('loss') 82 | METRIC_REGISTRY = Registry('metric') 83 | -------------------------------------------------------------------------------- /basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 
38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') 45 | # specify master port 46 | if port is not None: 47 | os.environ['MASTER_PORT'] = str(port) 48 | elif 'MASTER_PORT' in os.environ: 49 | pass # use MASTER_PORT in the environment variable 50 | else: 51 | # 29500 is torch.distributed default port 52 | os.environ['MASTER_PORT'] = '29500' 53 | os.environ['MASTER_ADDR'] = addr 54 | os.environ['WORLD_SIZE'] = str(ntasks) 55 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 56 | os.environ['RANK'] = str(proc_id) 57 | dist.init_process_group(backend=backend) 58 | 59 | 60 | def get_dist_info(): 61 | if dist.is_available(): 62 | initialized = dist.is_initialized() 63 | else: 64 | initialized = False 65 | if initialized: 66 | rank = dist.get_rank() 67 | world_size = dist.get_world_size() 68 | else: 69 | rank = 0 70 | world_size = 1 71 | return rank, world_size 72 | 73 | 74 | def master_only(func): 75 | 76 | @functools.wraps(func) 77 | def wrapper(*args, **kwargs): 78 | rank, _ = get_dist_info() 79 | if rank == 0: 80 | return func(*args, **kwargs) 81 | 82 | return wrapper 83 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/fused_act.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | 7 | try: 8 | from . import fused_act_ext 9 | except ImportError: 10 | import os 11 | BASICSR_JIT = os.getenv('BASICSR_JIT') 12 | if BASICSR_JIT == 'True': 13 | from torch.utils.cpp_extension import load 14 | module_path = os.path.dirname(__file__) 15 | fused_act_ext = load( 16 | 'fused', 17 | sources=[ 18 | os.path.join(module_path, 'src', 'fused_bias_act.cpp'), 19 | os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), 20 | ], 21 | ) 22 | 23 | 24 | class FusedLeakyReLUFunctionBackward(Function): 25 | 26 | @staticmethod 27 | def forward(ctx, grad_output, out, negative_slope, scale): 28 | ctx.save_for_backward(out) 29 | ctx.negative_slope = negative_slope 30 | ctx.scale = scale 31 | 32 | empty = grad_output.new_empty(0) 33 | 34 | grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) 35 | 36 | dim = [0] 37 | 38 | if grad_input.ndim > 2: 39 | dim += list(range(2, grad_input.ndim)) 40 | 41 | grad_bias = grad_input.sum(dim).detach() 42 | 43 | return grad_input, grad_bias 44 | 45 | @staticmethod 46 | def backward(ctx, gradgrad_input, gradgrad_bias): 47 | out, = ctx.saved_tensors 48 | gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, 49 | ctx.scale) 50 | 51 | return gradgrad_out, None, None, None 52 | 53 | 54 | class FusedLeakyReLUFunction(Function): 55 | 56 | @staticmethod 57 | def forward(ctx, input, bias, negative_slope, scale): 58 | empty = input.new_empty(0) 59 | out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 60 | ctx.save_for_backward(out) 61 | ctx.negative_slope = negative_slope 62 | ctx.scale = scale 63 | 64 | return out 65 | 66 | @staticmethod 67 | def backward(ctx, grad_output): 68 | out, = ctx.saved_tensors 69 | 70 | 
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) 71 | 72 | return grad_input, grad_bias, None, None 73 | 74 | 75 | class FusedLeakyReLU(nn.Module): 76 | 77 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5): 78 | super().__init__() 79 | 80 | self.bias = nn.Parameter(torch.zeros(channel)) 81 | self.negative_slope = negative_slope 82 | self.scale = scale 83 | 84 | def forward(self, input): 85 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 86 | 87 | 88 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): 89 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 90 | -------------------------------------------------------------------------------- /basicsr/varformer_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | import time 5 | import argparse 6 | from collections import OrderedDict 7 | import numpy as np 8 | 9 | from options import options as option 10 | import utils.util as util 11 | from torchvision.transforms.functional import normalize 12 | from archs import build_network 13 | from utils.misc import gpu_is_available, get_device 14 | import torch 15 | from glob import glob 16 | import tqdm 17 | import cv2 18 | import torch.nn.functional as F 19 | from utils import tensor2img 20 | 21 | 22 | #### options 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('-opt', type=str, required=True, help='Path to options YMAL file.') 25 | opt = option.parse(parser.parse_args().opt, is_train=False) 26 | opt = option.dict_to_nonedict(opt) 27 | 28 | util.mkdirs( 29 | (path for key, path in opt['path'].items() 30 | if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) 31 | util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, 32 | screen=True, tofile=True) 33 | logger = logging.getLogger('base') 34 | logger.info(option.dict2str(opt)) 35 | 36 | device = get_device() 37 | model = build_network(opt['network_g']).to(device) 38 | ckpt_path = opt['path']['pretrain_model'] 39 | checkpoint = torch.load(ckpt_path) 40 | 41 | 42 | ckpt = checkpoint['params_ema'] 43 | model.load_state_dict(ckpt,strict=opt['path']['strict_load']) 44 | 45 | model.eval() 46 | 47 | 48 | #### Create test dataset and dataloader 49 | test_loaders = [] 50 | for phase, dataset_opt in sorted(opt['datasets'].items()): 51 | paths_LQ = dataset_opt['dataroot_LQ'] 52 | paths_GT = dataset_opt['dataroot_GT'] 53 | 54 | crop_border = opt['crop_border'] if opt['crop_border'] is not None else opt['scale'] 55 | need_GT = False if dataset_opt['dataroot_GT'] is None else True 56 | 57 | test_set_name = dataset_opt['name'] 58 | logger.info('\nTesting [{:s}]...'.format(test_set_name)) 59 | dataset_dir = osp.join(opt['path']['results_root'], test_set_name) 60 | util.mkdir(dataset_dir) 61 | 62 | img_path_list = glob(os.path.join(paths_LQ, "*.png")) + glob(os.path.join(paths_LQ, "*.jpg")) 63 | 64 | for img_path in tqdm.tqdm(img_path_list): 65 | """ Load an image """ 66 | img_name = os.path.basename(img_path) 67 | img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) 68 | H, W, C = img.shape 69 | img = img.astype(np.float32) / 255. 
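# Preprocessing before inference (values assume the defaults in options/test.yml:
# input_size 256, mean/std [0.5, 0.5, 0.5]). The steps below resize the image to a
# fixed input_size x input_size square, swap BGR -> RGB (cv2.imread returns BGR),
# convert the HWC numpy array to a CHW float tensor, and normalize with mean/std 0.5,
# which maps pixel values to roughly [-1, 1]. The original (H, W) captured above can
# be used afterwards to resize the restored output back to the source resolution.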
70 | img = cv2.resize(img, (opt['input_size'], opt['input_size']), interpolation=cv2.INTER_LINEAR) 71 | if img.shape[2] == 3: 72 | img = img[:, :, [2, 1, 0]] 73 | img = torch.from_numpy(np.ascontiguousarray(np.transpose(img, (2, 0, 1)))).float() 74 | normalize(img, opt['mean'], opt['std'], inplace=True) 75 | 76 | model.eval() 77 | output, _ = model(img.unsqueeze(0).to(device)) 78 | sr_img = output.detach().cpu() 79 | sr_img = tensor2img(sr_img) 80 | save_img_path = osp.join(dataset_dir, img_name) 81 | util.save_img(sr_img, save_img_path) 82 | -------------------------------------------------------------------------------- /basicsr/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from torch.nn import functional as F 3 | 4 | 5 | def reduce_loss(loss, reduction): 6 | """Reduce loss as specified. 7 | 8 | Args: 9 | loss (Tensor): Elementwise loss tensor. 10 | reduction (str): Options are 'none', 'mean' and 'sum'. 11 | 12 | Returns: 13 | Tensor: Reduced loss tensor. 14 | """ 15 | reduction_enum = F._Reduction.get_enum(reduction) 16 | # none: 0, elementwise_mean:1, sum: 2 17 | if reduction_enum == 0: 18 | return loss 19 | elif reduction_enum == 1: 20 | return loss.mean() 21 | else: 22 | return loss.sum() 23 | 24 | 25 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 26 | """Apply element-wise weight and reduce loss. 27 | 28 | Args: 29 | loss (Tensor): Element-wise loss. 30 | weight (Tensor): Element-wise weights. Default: None. 31 | reduction (str): Same as built-in losses of PyTorch. Options are 32 | 'none', 'mean' and 'sum'. Default: 'mean'. 33 | 34 | Returns: 35 | Tensor: Loss values. 36 | """ 37 | # if weight is specified, apply element-wise weight 38 | if weight is not None: 39 | assert weight.dim() == loss.dim() 40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 41 | loss = loss * weight 42 | 43 | # if weight is not specified or reduction is sum, just reduce the loss 44 | if weight is None or reduction == 'sum': 45 | loss = reduce_loss(loss, reduction) 46 | # if reduction is mean, then compute mean over weight region 47 | elif reduction == 'mean': 48 | if weight.size(1) > 1: 49 | weight = weight.sum() 50 | else: 51 | weight = weight.sum() * loss.size(1) 52 | loss = loss.sum() / weight 53 | 54 | return loss 55 | 56 | 57 | def weighted_loss(loss_func): 58 | """Create a weighted version of a given loss function. 59 | 60 | To use this decorator, the loss function must have the signature like 61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 62 | element-wise loss without any reduction. This decorator will add weight 63 | and reduction arguments to the function. The decorated function will have 64 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 65 | **kwargs)`. 66 | 67 | :Example: 68 | 69 | >>> import torch 70 | >>> @weighted_loss 71 | >>> def l1_loss(pred, target): 72 | >>> return (pred - target).abs() 73 | 74 | >>> pred = torch.Tensor([0, 2, 3]) 75 | >>> target = torch.Tensor([1, 1, 1]) 76 | >>> weight = torch.Tensor([1, 0, 1]) 77 | 78 | >>> l1_loss(pred, target) 79 | tensor(1.3333) 80 | >>> l1_loss(pred, target, weight) 81 | tensor(1.5000) 82 | >>> l1_loss(pred, target, reduction='none') 83 | tensor([1., 1., 2.]) 84 | >>> l1_loss(pred, target, weight, reduction='sum') 85 | tensor(3.) 
86 | """ 87 | 88 | @functools.wraps(loss_func) 89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 90 | # get element-wise loss 91 | loss = loss_func(pred, target, **kwargs) 92 | loss = weight_reduce_loss(loss, weight, reduction) 93 | return loss 94 | 95 | return wrapper 96 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Navigating Image Restoration with VAR’s Distribution Alignment Prior [![arXiv](https://img.shields.io/badge/arXiv%20paper-2412.21063-b31b1b.svg)](https://arxiv.org/abs/2412.21063v2)  2 | 3 | This repository contains the official implementation of the CVPR 2025 poster paper [Navigating Image Restoration with VAR’s Distribution Alignment Prior](https://arxiv.org/abs/2412.21063v2). 4 | ## 🚀 Abstract 5 | ![Teaser](docs/fig1_4_00.png) 6 | >Generative models trained on extensive high-quality datasets effectively capture the structural and statistical properties of clean images, rendering them powerful priors for transforming degraded features into clean ones in image restoration. VAR, a novel image generative paradigm, surpasses diffusion models in generation quality by applying a next-scale prediction approach. It progressively captures both global structures and fine-grained details through the autoregressive process, consistent with the multi-scale restoration principle widely acknowledged in the restoration community. Furthermore, we observe that during the image reconstruction process utilizing VAR, scale predictions automatically modulate the input, facilitating the alignment of representations at subsequent scales with the distribution of clean images. To harness VAR’s adaptive distribution alignment capability in image restoration tasks, we formulate the multi-scale latent representations within VAR as the restoration prior, thus advancing our delicately designed VarFormer framework. The strategic application of these priors enables our VarFormer to achieve remarkable generalization on unseen tasks while also reducing training computational costs. 7 | 8 | 9 | 10 | ## 🔥 Data preparing 11 | we refer to [DiffUIR](https://github.com/iSEE-Laboratory/DiffUIR), the other datasets you could download from [AdaIR](https://github.com/c-yn/AdaIR/blob/main/INSTALL.md) and [DF2K](https://github.com/XPixelGroup/BasicSR/blob/master/docs/DatasetPreparation.md) 12 | 13 | - The data structure is like this: 14 | ``` 15 | dataset 16 | ├── GoPro 17 | │ ├── train 18 | │ ├── test 19 | ├── LOL 20 | │ ├── our485 21 | │ ├── eval15 22 | --- 23 | ``` 24 | 25 | ## 🔥Stage 1 26 | 27 | Download [VAR](https://github.com/FoundationVision/VAR) into `experiments/pretrained_models/` 28 | 29 | ``` 30 | 31 | python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4344 basicsr/pretrain.py -opt basicsr/options/VarFormer_train_stage1.yml --launcher pytorch 32 | 33 | ``` 34 | 35 | ## 🔥Stage 2 36 | 37 | You can train based on stage 1 or just start from scratch. Modify the configuration file `options/VarFormer_train_stage2.yml` accordingly. 38 | 39 | ``` 40 | python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4344 basicsr/train.py -opt basicsr/options/VarFormer_train_stage2.yml --launcher pytorch 41 | 42 | ``` 43 | 44 | ## 🔥Inference 45 | 46 | Modify the configuration file `options/test.yml` accordingly. 
Download [VarFormer_16](https://huggingface.co/wsy541/VarFormer/resolve/main/net_g_last.pth) into `experiments/pretrained_models/` 47 | 48 | ``` 49 | 50 | python basicsr/varformer_test.py -opt basicsr/options/test.yml 51 | 52 | ``` 53 | 54 | 55 | ## Acknowledgements 56 | 57 | This code is built upon [VAR](https://github.com/FoundationVision/VAR) and [CodeFormer](https://github.com/sczhou/CodeFormer), thanks for their excellent work! 58 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu 2 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 3 | // 4 | // This work is made available under the Nvidia Source Code License-NC. 5 | // To view a copy of this license, visit 6 | // https://nvlabs.github.io/stylegan2/license.html 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | 19 | template 20 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 21 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 22 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 23 | 24 | scalar_t zero = 0.0; 25 | 26 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 27 | scalar_t x = p_x[xi]; 28 | 29 | if (use_bias) { 30 | x += p_b[(xi / step_b) % size_b]; 31 | } 32 | 33 | scalar_t ref = use_ref ? p_ref[xi] : zero; 34 | 35 | scalar_t y; 36 | 37 | switch (act * 10 + grad) { 38 | default: 39 | case 10: y = x; break; 40 | case 11: y = x; break; 41 | case 12: y = 0.0; break; 42 | 43 | case 30: y = (x > 0.0) ? x : x * alpha; break; 44 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 45 | case 32: y = 0.0; break; 46 | } 47 | 48 | out[xi] = y * scale; 49 | } 50 | } 51 | 52 | 53 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 54 | int act, int grad, float alpha, float scale) { 55 | int curDevice = -1; 56 | cudaGetDevice(&curDevice); 57 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 58 | 59 | auto x = input.contiguous(); 60 | auto b = bias.contiguous(); 61 | auto ref = refer.contiguous(); 62 | 63 | int use_bias = b.numel() ? 1 : 0; 64 | int use_ref = ref.numel() ? 
1 : 0; 65 | 66 | int size_x = x.numel(); 67 | int size_b = b.numel(); 68 | int step_b = 1; 69 | 70 | for (int i = 1 + 1; i < x.dim(); i++) { 71 | step_b *= x.size(i); 72 | } 73 | 74 | int loop_x = 4; 75 | int block_size = 4 * 32; 76 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 77 | 78 | auto y = torch::empty_like(x); 79 | 80 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 81 | fused_bias_act_kernel<<>>( 82 | y.data_ptr(), 83 | x.data_ptr(), 84 | b.data_ptr(), 85 | ref.data_ptr(), 86 | act, 87 | grad, 88 | alpha, 89 | scale, 90 | loop_x, 91 | size_x, 92 | step_b, 93 | size_b, 94 | use_bias, 95 | use_ref 96 | ); 97 | }); 98 | 99 | return y; 100 | } 101 | -------------------------------------------------------------------------------- /basicsr/options/VarFormer_train_stage1.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: VarFormer2_pre 3 | model_type: VarFormerIdxModel4_2_losszd #VarFormerIdxModel4_2 4 | num_gpu: 2 5 | manual_seed: 114514 #3407 6 | scale: 1 7 | crop_border: 0 8 | weights: [] 9 | 10 | 11 | datasets: 12 | pretrain: 13 | name: DF2K 14 | type: LQGTDataset3_pre 15 | data_type: img 16 | data_num: 1 17 | 18 | dataroot_GT0: /home/wangsy/dataset/DF2K/HR 19 | dataroot_LQ0: /home/wangsy/dataset/DF2K/HR 20 | 21 | need_GT: true 22 | cond_scale: 4 23 | use_shuffle: true 24 | GT_size: 256 # #480 # 160 | 480 25 | use_flip: false 26 | use_rot: false 27 | # data loader 28 | num_worker_per_gpu: 2 29 | batch_size_per_gpu: 4 30 | dataset_enlarge_ratio: 10 31 | prefetch_mode: ~ 32 | 33 | # network structures 34 | network_g: 35 | type: VarFormer2_varRec 36 | depth: 20 # 16 #24 ## keep the same with var_path 37 | patch_nums: [1, 2, 3, 4, 5, 6, 8, 10, 13, 16] 38 | attn_l2_norm: True 39 | var_force_dpth: 7 40 | n_layers: 4 41 | if_enhance: True 42 | 43 | fix_modules: ['encoder','quant_conv','quantize','blocks','decoder','post_quant_conv'] 44 | vqgan_path: ./experiments/pretrained_models/vae_ch160v4096z32.pth # pretrained VQGAN 45 | var_path: ./experiments/pretrained_models/var_d20.pth # var_d16.pth # var_d24.pth # keep the same with depth 46 | network_vqgan: # this config is needed if no pre-calculated latent 47 | type: VarVQAutoEncoder 48 | model_path: ./experiments/pretrained_models/vae_ch160v4096z32.pth 49 | 50 | # path 51 | path: 52 | pretrain_network_g: ~ 53 | param_key_g: params_ema 54 | strict_load_g: false 55 | pretrain_network_d: ~ #/experiments/pretrained_models/vqgan_discriminator.pth 56 | strict_load_d: true 57 | resume_state: ~ 58 | 59 | # base_lr(4.5e-6)*bach_size(4) 60 | train: 61 | use_hq_feat_loss: true 62 | feat_loss_weight: 0.5 63 | cross_entropy_loss: true 64 | entropy_loss_weight: 0.1 65 | fidelity_weight: 0 66 | perceptual_loss_weight: 1.0 67 | use_pixel_opt: true 68 | 69 | optim_g: 70 | type: Adam 71 | lr: !!float 1e-4 72 | weight_decay: 0 73 | betas: [0.9, 0.99] 74 | 75 | scheduler: 76 | type: MultiStepLR 77 | milestones: [400000, 450000] 78 | gamma: 0.5 79 | 80 | # scheduler: 81 | # type: CosineAnnealingRestartLR 82 | # periods: [500000] 83 | # restart_weights: [1] 84 | # eta_min: !!float 2e-5 # no lr reduce in official vqgan code 85 | 86 | total_iter: 500000 87 | 88 | warmup_iter: -1 # no warm up 89 | ema_decay: 0.998 90 | 91 | use_adaptive_weight: true 92 | 93 | net_g_start_iter: 0 94 | net_d_iters: 1 95 | net_d_start_iter: 0 96 | manual_seed: 620664 97 | 98 | # perceptual_opt: 99 | # type: PerceptualLoss 100 | # use_input_norm: true 101 | # range_norm: 
true 102 | # layer_weights: {'relu5_4': 1.} 103 | 104 | # perceptual_opt: 105 | # type: LPIPSLoss 106 | # loss_weight: 1.0 107 | # use_input_norm: true 108 | # range_norm: true 109 | 110 | # psnr_opt: 111 | # type: PSNRLoss 112 | 113 | # ssim_opt: 114 | # type: SSIMLoss 115 | 116 | 117 | # logging settings 118 | logger: 119 | print_freq: 100 120 | save_checkpoint_freq: !!float 1e4 121 | use_tb_logger: true 122 | wandb: 123 | project: ~ 124 | resume_id: ~ 125 | 126 | # dist training settings 127 | dist_params: 128 | backend: nccl 129 | port: 29412 130 | 131 | find_unused_parameters: true 132 | -------------------------------------------------------------------------------- /basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import requests 4 | from torch.hub import download_url_to_file, get_dir 5 | from tqdm import tqdm 6 | from urllib.parse import urlparse 7 | 8 | from .misc import sizeof_fmt 9 | 10 | 11 | def download_file_from_google_drive(file_id, save_path): 12 | """Download files from google drive. 13 | Ref: 14 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 15 | Args: 16 | file_id (str): File id. 17 | save_path (str): Save path. 18 | """ 19 | 20 | session = requests.Session() 21 | URL = 'https://docs.google.com/uc?export=download' 22 | params = {'id': file_id} 23 | 24 | response = session.get(URL, params=params, stream=True) 25 | token = get_confirm_token(response) 26 | if token: 27 | params['confirm'] = token 28 | response = session.get(URL, params=params, stream=True) 29 | 30 | # get file size 31 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 32 | print(response_file_size) 33 | if 'Content-Range' in response_file_size.headers: 34 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) 35 | else: 36 | file_size = None 37 | 38 | save_response_content(response, save_path, file_size) 39 | 40 | 41 | def get_confirm_token(response): 42 | for key, value in response.cookies.items(): 43 | if key.startswith('download_warning'): 44 | return value 45 | return None 46 | 47 | 48 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 49 | if file_size is not None: 50 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 51 | 52 | readable_file_size = sizeof_fmt(file_size) 53 | else: 54 | pbar = None 55 | 56 | with open(destination, 'wb') as f: 57 | downloaded_size = 0 58 | for chunk in response.iter_content(chunk_size): 59 | downloaded_size += chunk_size 60 | if pbar is not None: 61 | pbar.update(1) 62 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') 63 | if chunk: # filter out keep-alive new chunks 64 | f.write(chunk) 65 | if pbar is not None: 66 | pbar.close() 67 | 68 | 69 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 70 | """Load file form http url, will download models if necessary. 71 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 72 | Args: 73 | url (str): URL to be downloaded. 74 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 75 | Default: None. 76 | progress (bool): Whether to show the download progress. Default: True. 77 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 
78 | Returns: 79 | str: The path to the downloaded file. 80 | """ 81 | if model_dir is None: # use the pytorch hub_dir 82 | hub_dir = get_dir() 83 | model_dir = os.path.join(hub_dir, 'checkpoints') 84 | 85 | os.makedirs(model_dir, exist_ok=True) 86 | 87 | parts = urlparse(url) 88 | filename = os.path.basename(parts.path) 89 | if file_name is not None: 90 | filename = file_name 91 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 92 | if not os.path.exists(cached_file): 93 | print(f'Downloading: "{url}" to {cached_file}\n') 94 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 95 | return cached_file -------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Ref: 11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | 13 | Args: 14 | generator: Python generator. 15 | num_prefetch_queue (int): Number of prefetch queue. 16 | """ 17 | 18 | def __init__(self, generator, num_prefetch_queue): 19 | threading.Thread.__init__(self) 20 | self.queue = Queue.Queue(num_prefetch_queue) 21 | self.generator = generator 22 | self.daemon = True 23 | self.start() 24 | 25 | def run(self): 26 | for item in self.generator: 27 | self.queue.put(item) 28 | self.queue.put(None) 29 | 30 | def __next__(self): 31 | next_item = self.queue.get() 32 | if next_item is None: 33 | raise StopIteration 34 | return next_item 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | 40 | class PrefetchDataLoader(DataLoader): 41 | """Prefetch version of dataloader. 42 | 43 | Ref: 44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 45 | 46 | TODO: 47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 48 | ddp. 49 | 50 | Args: 51 | num_prefetch_queue (int): Number of prefetch queue. 52 | kwargs (dict): Other arguments for dataloader. 53 | """ 54 | 55 | def __init__(self, num_prefetch_queue, **kwargs): 56 | self.num_prefetch_queue = num_prefetch_queue 57 | super(PrefetchDataLoader, self).__init__(**kwargs) 58 | 59 | def __iter__(self): 60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 61 | 62 | 63 | class CPUPrefetcher(): 64 | """CPU prefetcher. 65 | 66 | Args: 67 | loader: Dataloader. 68 | """ 69 | 70 | def __init__(self, loader): 71 | self.ori_loader = loader 72 | self.loader = iter(loader) 73 | 74 | def next(self): 75 | try: 76 | return next(self.loader) 77 | except StopIteration: 78 | return None 79 | 80 | def reset(self): 81 | self.loader = iter(self.ori_loader) 82 | 83 | 84 | class CUDAPrefetcher(): 85 | """CUDA prefetcher. 86 | 87 | Ref: 88 | https://github.com/NVIDIA/apex/issues/304# 89 | 90 | It may consums more GPU memory. 91 | 92 | Args: 93 | loader: Dataloader. 94 | opt (dict): Options. 
95 | """ 96 | 97 | def __init__(self, loader, opt): 98 | self.ori_loader = loader 99 | self.loader = iter(loader) 100 | self.opt = opt 101 | self.stream = torch.cuda.Stream() 102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 103 | self.preload() 104 | 105 | def preload(self): 106 | try: 107 | self.batch = next(self.loader) # self.batch is a dict 108 | except StopIteration: 109 | self.batch = None 110 | return None 111 | # put tensors to gpu 112 | with torch.cuda.stream(self.stream): 113 | for k, v in self.batch.items(): 114 | if torch.is_tensor(v): 115 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 116 | 117 | def next(self): 118 | torch.cuda.current_stream().wait_stream(self.stream) 119 | batch = self.batch 120 | self.preload() 121 | return batch 122 | 123 | def reset(self): 124 | self.loader = iter(self.ori_loader) 125 | self.preload() 126 | -------------------------------------------------------------------------------- /basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import time 3 | from collections import OrderedDict 4 | from os import path as osp 5 | from utils.misc import get_time_str 6 | 7 | def ordered_yaml(): 8 | """Support OrderedDict for yaml. 9 | 10 | Returns: 11 | yaml Loader and Dumper. 12 | """ 13 | try: 14 | from yaml import CDumper as Dumper 15 | from yaml import CLoader as Loader 16 | except ImportError: 17 | from yaml import Dumper, Loader 18 | 19 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 20 | 21 | def dict_representer(dumper, data): 22 | return dumper.represent_dict(data.items()) 23 | 24 | def dict_constructor(loader, node): 25 | return OrderedDict(loader.construct_pairs(node)) 26 | 27 | Dumper.add_representer(OrderedDict, dict_representer) 28 | Loader.add_constructor(_mapping_tag, dict_constructor) 29 | return Loader, Dumper 30 | 31 | 32 | def parse(opt_path, root_path, is_train=True): 33 | """Parse option file. 34 | 35 | Args: 36 | opt_path (str): Option file path. 37 | is_train (str): Indicate whether in training or not. Default: True. 38 | 39 | Returns: 40 | (dict): Options. 
41 | """ 42 | # print('opt_path============================',opt_path) 43 | with open(opt_path, mode='r') as f: 44 | Loader, _ = ordered_yaml() 45 | opt = yaml.load(f, Loader=Loader) 46 | 47 | # print('is_train==============================',is_train,opt) 48 | opt['is_train'] = is_train 49 | 50 | # opt['name'] = f"{get_time_str()}_{opt['name']}" 51 | if opt['path'].get('resume_state', None): # Shangchen added 52 | resume_state_path = opt['path'].get('resume_state') 53 | opt['name'] = resume_state_path.split("/")[-3] 54 | else: 55 | opt['name'] = f"{get_time_str()}_{opt['name']}" 56 | 57 | 58 | # datasets 59 | for phase, dataset in opt['datasets'].items(): 60 | # for several datasets, e.g., test_1, test_2 61 | phase = phase.split('_')[0] 62 | dataset['phase'] = phase 63 | if 'scale' in opt: 64 | dataset['scale'] = opt['scale'] 65 | if dataset.get('dataroot_gt') is not None: 66 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 67 | if dataset.get('dataroot_lq') is not None: 68 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 69 | 70 | # paths 71 | for key, val in opt['path'].items(): 72 | if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): 73 | opt['path'][key] = osp.expanduser(val) 74 | 75 | if is_train: 76 | experiments_root = osp.join(root_path, 'experiments', opt['name']) 77 | opt['path']['experiments_root'] = experiments_root 78 | opt['path']['models'] = osp.join(experiments_root, 'models') 79 | opt['path']['training_states'] = osp.join(experiments_root, 'training_states') 80 | opt['path']['log'] = experiments_root 81 | opt['path']['visualization'] = osp.join(experiments_root, 'visualization') 82 | 83 | else: # test 84 | results_root = osp.join(root_path, 'results', opt['name']) 85 | opt['path']['results_root'] = results_root 86 | opt['path']['log'] = results_root 87 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 88 | 89 | return opt 90 | 91 | 92 | def dict2str(opt, indent_level=1): 93 | """dict to string for printing options. 94 | 95 | Args: 96 | opt (dict): Option dict. 97 | indent_level (int): Indent level. Default: 1. 98 | 99 | Return: 100 | (str): Option string for printing. 101 | """ 102 | msg = '\n' 103 | for k, v in opt.items(): 104 | if isinstance(v, dict): 105 | msg += ' ' * (indent_level * 2) + k + ':[' 106 | msg += dict2str(v, indent_level + 1) 107 | msg += ' ' * (indent_level * 2) + ']\n' 108 | else: 109 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 110 | return msg 111 | -------------------------------------------------------------------------------- /basicsr/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class MultiStepRestartLR(_LRScheduler): 7 | """ MultiStep with restarts learning rate scheme. 8 | 9 | Args: 10 | optimizer (torch.nn.optimizer): Torch optimizer. 11 | milestones (list): Iterations that will decrease learning rate. 12 | gamma (float): Decrease ratio. Default: 0.1. 13 | restarts (list): Restart iterations. Default: [0]. 14 | restart_weights (list): Restart weights at each restart iteration. 15 | Default: [1]. 16 | last_epoch (int): Used in _LRScheduler. Default: -1. 
17 | """ 18 | 19 | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): 20 | self.milestones = Counter(milestones) 21 | self.gamma = gamma 22 | self.restarts = restarts 23 | self.restart_weights = restart_weights 24 | assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' 25 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 26 | 27 | def get_lr(self): 28 | if self.last_epoch in self.restarts: 29 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 30 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 31 | if self.last_epoch not in self.milestones: 32 | return [group['lr'] for group in self.optimizer.param_groups] 33 | return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] 34 | 35 | 36 | def get_position_from_periods(iteration, cumulative_period): 37 | """Get the position from a period list. 38 | 39 | It will return the index of the right-closest number in the period list. 40 | For example, the cumulative_period = [100, 200, 300, 400], 41 | if iteration == 50, return 0; 42 | if iteration == 210, return 2; 43 | if iteration == 300, return 2. 44 | 45 | Args: 46 | iteration (int): Current iteration. 47 | cumulative_period (list[int]): Cumulative period list. 48 | 49 | Returns: 50 | int: The position of the right-closest number in the period list. 51 | """ 52 | for i, period in enumerate(cumulative_period): 53 | if iteration <= period: 54 | return i 55 | 56 | 57 | class CosineAnnealingRestartLR(_LRScheduler): 58 | """ Cosine annealing with restarts learning rate scheme. 59 | 60 | An example of config: 61 | periods = [10, 10, 10, 10] 62 | restart_weights = [1, 0.5, 0.5, 0.5] 63 | eta_min=1e-7 64 | 65 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 66 | scheduler will restart with the weights in restart_weights. 67 | 68 | Args: 69 | optimizer (torch.nn.optimizer): Torch optimizer. 70 | periods (list): Period for each cosine anneling cycle. 71 | restart_weights (list): Restart weights at each restart iteration. 72 | Default: [1]. 73 | eta_min (float): The mimimum lr. Default: 0. 74 | last_epoch (int): Used in _LRScheduler. Default: -1. 75 | """ 76 | 77 | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): 78 | self.periods = periods 79 | self.restart_weights = restart_weights 80 | self.eta_min = eta_min 81 | assert (len(self.periods) == len( 82 | self.restart_weights)), 'periods and restart_weights should have the same length.' 
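# Worked example using the config from the class docstring
# (periods = [10, 10, 10, 10], restart_weights = [1, 0.5, 0.5, 0.5]):
# cumulative_period below becomes [10, 20, 30, 40]. An iteration such as 25 falls at
# index 2, so get_lr() anneals from 0.5 * base_lr down to eta_min over the
# 10-iteration cycle that restarted at iteration 20.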
83 | self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] 84 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 85 | 86 | def get_lr(self): 87 | idx = get_position_from_periods(self.last_epoch, self.cumulative_period) 88 | current_weight = self.restart_weights[idx] 89 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 90 | current_period = self.periods[idx] 91 | 92 | return [ 93 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 94 | (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) 95 | for base_lr in self.base_lrs 96 | ] 97 | -------------------------------------------------------------------------------- /basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from copy import deepcopy 7 | from functools import partial 8 | from os import path as osp 9 | 10 | from data.prefetch_dataloader import PrefetchDataLoader 11 | from utils import get_root_logger, scandir 12 | from utils.dist_util import get_dist_info 13 | from utils.registry import DATASET_REGISTRY 14 | 15 | __all__ = ['build_dataset', 'build_dataloader'] 16 | 17 | # automatically scan and import dataset modules for registry 18 | # scan all the files under the data folder with '_dataset' in file names 19 | data_folder = osp.dirname(osp.abspath(__file__)) 20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 21 | # import all the dataset modules 22 | _dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames] 23 | 24 | 25 | def build_dataset(dataset_opt): 26 | """Build dataset from options. 27 | 28 | Args: 29 | dataset_opt (dict): Configuration for dataset. It must constain: 30 | name (str): Dataset name. 31 | type (str): Dataset type. 32 | """ 33 | dataset_opt = deepcopy(dataset_opt) 34 | print('type========================',dataset_opt['type']) 35 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) 36 | logger = get_root_logger() 37 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.') 38 | return dataset 39 | 40 | 41 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): 42 | """Build dataloader. 43 | 44 | Args: 45 | dataset (torch.utils.data.Dataset): Dataset. 46 | dataset_opt (dict): Dataset options. It contains the following keys: 47 | phase (str): 'train' or 'val'. 48 | num_worker_per_gpu (int): Number of workers for each GPU. 49 | batch_size_per_gpu (int): Training batch size for each GPU. 50 | num_gpu (int): Number of GPUs. Used only in the train phase. 51 | Default: 1. 52 | dist (bool): Whether in distributed training. Used only in the train 53 | phase. Default: False. 54 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 55 | seed (int | None): Seed. 
Default: None 56 | """ 57 | phase = dataset_opt['phase'] 58 | rank, _ = get_dist_info() 59 | if phase == 'train': 60 | if dist: # distributed training 61 | batch_size = dataset_opt['batch_size_per_gpu'] 62 | num_workers = dataset_opt['num_worker_per_gpu'] 63 | else: # non-distributed training 64 | multiplier = 1 if num_gpu == 0 else num_gpu 65 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 66 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 67 | dataloader_args = dict( 68 | dataset=dataset, 69 | batch_size=batch_size, 70 | shuffle=False, 71 | num_workers=num_workers, 72 | sampler=sampler, 73 | drop_last=True) 74 | if sampler is None: 75 | dataloader_args['shuffle'] = True 76 | dataloader_args['worker_init_fn'] = partial( 77 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None 78 | elif phase in ['val', 'test']: # validation 79 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 80 | else: 81 | raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.") 82 | 83 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 84 | 85 | prefetch_mode = dataset_opt.get('prefetch_mode') 86 | if prefetch_mode == 'cpu': # CPUPrefetcher 87 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 88 | logger = get_root_logger() 89 | logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}') 90 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) 91 | else: 92 | # prefetch_mode=None: Normal dataloader 93 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 94 | return torch.utils.data.DataLoader(**dataloader_args) 95 | 96 | 97 | def worker_init_fn(worker_id, num_workers, rank, seed): 98 | # Set the worker seed to num_workers * rank + worker_id + seed 99 | worker_seed = num_workers * rank + worker_id + seed 100 | np.random.seed(worker_seed) 101 | random.seed(worker_seed) 102 | -------------------------------------------------------------------------------- /basicsr/options/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | import yaml 5 | from utils.util import OrderedYaml 6 | Loader, Dumper = OrderedYaml() 7 | 8 | 9 | def parse(opt_path, is_train=True): 10 | with open(opt_path, mode='r',encoding='utf-8') as f: 11 | opt = yaml.load(f, Loader=Loader) 12 | 13 | 14 | opt['is_train'] = is_train 15 | if opt['distortion'] == 'sr': 16 | scale = opt['scale'] 17 | 18 | # datasets 19 | for phase, dataset in opt['datasets'].items(): 20 | phase = phase.split('_')[0] 21 | dataset['phase'] = phase 22 | if opt['distortion'] == 'sr': 23 | dataset['scale'] = scale 24 | is_lmdb = False 25 | if dataset.get('dataroot_HQ', None) is not None: 26 | dataset['dataroot_HQ'] = osp.expanduser(dataset['dataroot_HQ']) 27 | if dataset['dataroot_HQ'].endswith('lmdb'): 28 | is_lmdb = True 29 | if dataset.get('dataroot_GT_bg', None) is not None: 30 | dataset['dataroot_GT_bg'] = osp.expanduser(dataset['dataroot_GT_bg']) 31 | if dataset.get('dataroot_LQ', None) is not None: 32 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ']) 33 | if dataset['dataroot_LQ'].endswith('lmdb'): 34 | is_lmdb = True 35 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img' 36 | if dataset['mode'].endswith('mc'): # for memcached 37 | dataset['data_type'] = 'mc' 38 | dataset['mode'] = 
dataset['mode'].replace('_mc', '') 39 | 40 | # path 41 | for key, path in opt['path'].items(): 42 | if path and key in opt['path'] and key != 'strict_load': 43 | opt['path'][key] = osp.expanduser(path) 44 | 45 | if is_train: 46 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name']) 47 | opt['path']['experiments_root'] = experiments_root 48 | opt['path']['models'] = osp.join(experiments_root, 'models') 49 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state') 50 | opt['path']['log'] = experiments_root 51 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images') 52 | 53 | # change some options for debug mode 54 | if 'debug' in opt['name']: 55 | opt['train']['val_freq'] = 8 56 | opt['logger']['print_freq'] = 1 57 | opt['logger']['save_checkpoint_freq'] = 8 58 | else: # test 59 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 60 | opt['path']['results_root'] = results_root 61 | opt['path']['log'] = results_root 62 | 63 | 64 | 65 | return opt 66 | 67 | 68 | def dict2str(opt, indent_l=1): 69 | '''dict to string for logger''' 70 | msg = '' 71 | for k, v in opt.items(): 72 | if isinstance(v, dict): 73 | msg += ' ' * (indent_l * 2) + k + ':[\n' 74 | msg += dict2str(v, indent_l + 1) 75 | msg += ' ' * (indent_l * 2) + ']\n' 76 | else: 77 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 78 | return msg 79 | 80 | 81 | class NoneDict(dict): 82 | def __missing__(self, key): 83 | return None 84 | 85 | 86 | # convert to NoneDict, which return None for missing key. 87 | def dict_to_nonedict(opt): 88 | if isinstance(opt, dict): 89 | new_opt = dict() 90 | for key, sub_opt in opt.items(): 91 | new_opt[key] = dict_to_nonedict(sub_opt) 92 | return NoneDict(**new_opt) 93 | elif isinstance(opt, list): 94 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 95 | else: 96 | return opt 97 | 98 | 99 | def check_resume(opt, resume_iter): 100 | '''Check resume states and pretrain_model paths''' 101 | logger = logging.getLogger('base') 102 | if opt['path']['resume_state']: 103 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( 104 | 'pretrain_model_D', None) is not None: 105 | logger.warning('pretrain_model path will be ignored when resuming training.') 106 | 107 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], 108 | '{}_G.pth'.format(resume_iter)) 109 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) 110 | if 'gan' in opt['model']: 111 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'], 112 | '{}_D.pth'.format(resume_iter)) 113 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D']) 114 | -------------------------------------------------------------------------------- /basicsr/data/paired_image_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data as data 2 | from torchvision.transforms.functional import normalize 3 | 4 | from data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file 5 | from data.transforms import augment, paired_random_crop 6 | from utils import FileClient, imfrombytes, img2tensor 7 | from utils.registry import DATASET_REGISTRY 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class PairedImageDataset(data.Dataset): 12 | """Paired image dataset for image restoration. 13 | 14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and 15 | GT image pairs. 
16 | 17 | There are three modes: 18 | 1. 'lmdb': Use lmdb files. 19 | If opt['io_backend'] == lmdb. 20 | 2. 'meta_info_file': Use meta information file to generate paths. 21 | If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. 22 | 3. 'folder': Scan folders to generate paths. 23 | The rest. 24 | 25 | Args: 26 | opt (dict): Config for train datasets. It contains the following keys: 27 | dataroot_gt (str): Data root path for gt. 28 | dataroot_lq (str): Data root path for lq. 29 | meta_info_file (str): Path for meta information file. 30 | io_backend (dict): IO backend type and other kwarg. 31 | filename_tmpl (str): Template for each filename. Note that the 32 | template excludes the file extension. Default: '{}'. 33 | gt_size (int): Cropped patched size for gt patches. 34 | use_flip (bool): Use horizontal flips. 35 | use_rot (bool): Use rotation (use vertical flip and transposing h 36 | and w for implementation). 37 | 38 | scale (bool): Scale, which will be added automatically. 39 | phase (str): 'train' or 'val'. 40 | """ 41 | 42 | def __init__(self, opt): 43 | super(PairedImageDataset, self).__init__() 44 | self.opt = opt 45 | # file client (io backend) 46 | self.file_client = None 47 | self.io_backend_opt = opt['io_backend'] 48 | self.mean = opt['mean'] if 'mean' in opt else None 49 | self.std = opt['std'] if 'std' in opt else None 50 | 51 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 52 | if 'filename_tmpl' in opt: 53 | self.filename_tmpl = opt['filename_tmpl'] 54 | else: 55 | self.filename_tmpl = '{}' 56 | 57 | if self.io_backend_opt['type'] == 'lmdb': 58 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 59 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 60 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 61 | elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None: 62 | self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'], 63 | self.opt['meta_info_file'], self.filename_tmpl) 64 | else: 65 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 66 | 67 | def __getitem__(self, index): 68 | if self.file_client is None: 69 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 70 | 71 | scale = self.opt['scale'] 72 | 73 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 74 | # image range: [0, 1], float32. 
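# Illustrative sketch (assumed values and paths, not from the original configs): a minimal
# folder-mode option dict for this dataset might look like
#   opt = dict(
#       phase='train', scale=1, gt_size=256,
#       dataroot_gt='datasets/train/gt', dataroot_lq='datasets/train/lq',
#       io_backend=dict(type='disk'),
#       use_flip=True, use_rot=False, mean=None, std=None)
#   dataset = PairedImageDataset(opt)
# so that paths are generated by paired_paths_from_folder and images are read from disk
# through FileClient.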
75 | gt_path = self.paths[index]['gt_path'] 76 | img_bytes = self.file_client.get(gt_path, 'gt') 77 | img_gt = imfrombytes(img_bytes, float32=True) 78 | lq_path = self.paths[index]['lq_path'] 79 | img_bytes = self.file_client.get(lq_path, 'lq') 80 | img_lq = imfrombytes(img_bytes, float32=True) 81 | 82 | # augmentation for training 83 | if self.opt['phase'] == 'train': 84 | gt_size = self.opt['gt_size'] 85 | # random crop 86 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 87 | # flip, rotation 88 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot']) 89 | 90 | 91 | # wsy 92 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, 256, scale, gt_path) 93 | 94 | # TODO: color space transform 95 | # BGR to RGB, HWC to CHW, numpy to tensor 96 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 97 | # normalize 98 | if self.mean is not None or self.std is not None: 99 | normalize(img_lq, self.mean, self.std, inplace=True) 100 | normalize(img_gt, self.mean, self.std, inplace=True) 101 | 102 | return {'in': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 103 | 104 | def __len__(self): 105 | return len(self.paths) 106 | -------------------------------------------------------------------------------- /basicsr/metrics/psnr_ssim.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from metrics.metric_util import reorder_image, to_y_channel 5 | from utils.registry import METRIC_REGISTRY 6 | 7 | 8 | @METRIC_REGISTRY.register() 9 | def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False): 10 | """Calculate PSNR (Peak Signal-to-Noise Ratio). 11 | 12 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 13 | 14 | Args: 15 | img1 (ndarray): Images with range [0, 255]. 16 | img2 (ndarray): Images with range [0, 255]. 17 | crop_border (int): Cropped pixels in each edge of an image. These 18 | pixels are not involved in the PSNR calculation. 19 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 20 | Default: 'HWC'. 21 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 22 | 23 | Returns: 24 | float: psnr result. 25 | """ 26 | 27 | assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 28 | if input_order not in ['HWC', 'CHW']: 29 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') 30 | img1 = reorder_image(img1, input_order=input_order) 31 | img2 = reorder_image(img2, input_order=input_order) 32 | img1 = img1.astype(np.float64) 33 | img2 = img2.astype(np.float64) 34 | 35 | if crop_border != 0: 36 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 37 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 38 | 39 | if test_y_channel: 40 | img1 = to_y_channel(img1) 41 | img2 = to_y_channel(img2) 42 | 43 | mse = np.mean((img1 - img2)**2) 44 | if mse == 0: 45 | return float('inf') 46 | return 20. * np.log10(255. / np.sqrt(mse)) 47 | 48 | 49 | def _ssim(img1, img2): 50 | """Calculate SSIM (structural similarity) for one channel images. 51 | 52 | It is called by func:`calculate_ssim`. 53 | 54 | Args: 55 | img1 (ndarray): Images with range [0, 255] with order 'HWC'. 56 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 57 | 58 | Returns: 59 | float: ssim result. 
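# Illustrative usage sketch (not part of the original file): calculate_psnr above and
# calculate_ssim below are registered in METRIC_REGISTRY and expect HWC arrays with values
# in [0, 255], e.g.
#   psnr = calculate_psnr(img_gt, img_restored, crop_border=4, input_order='HWC',
#                         test_y_channel=False)
#   ssim = calculate_ssim(img_gt, img_restored, crop_border=4)
# where img_gt and img_restored are ndarrays of the same shape.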
60 | """ 61 | 62 | C1 = (0.01 * 255)**2 63 | C2 = (0.03 * 255)**2 64 | 65 | img1 = img1.astype(np.float64) 66 | img2 = img2.astype(np.float64) 67 | kernel = cv2.getGaussianKernel(11, 1.5) 68 | window = np.outer(kernel, kernel.transpose()) 69 | 70 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] 71 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 72 | mu1_sq = mu1**2 73 | mu2_sq = mu2**2 74 | mu1_mu2 = mu1 * mu2 75 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 76 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 77 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 78 | 79 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 80 | return ssim_map.mean() 81 | 82 | 83 | @METRIC_REGISTRY.register() 84 | def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False): 85 | """Calculate SSIM (structural similarity). 86 | 87 | Ref: 88 | Image quality assessment: From error visibility to structural similarity 89 | 90 | The results are the same as that of the official released MATLAB code in 91 | https://ece.uwaterloo.ca/~z70wang/research/ssim/. 92 | 93 | For three-channel images, SSIM is calculated for each channel and then 94 | averaged. 95 | 96 | Args: 97 | img1 (ndarray): Images with range [0, 255]. 98 | img2 (ndarray): Images with range [0, 255]. 99 | crop_border (int): Cropped pixels in each edge of an image. These 100 | pixels are not involved in the SSIM calculation. 101 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 102 | Default: 'HWC'. 103 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 104 | 105 | Returns: 106 | float: ssim result. 107 | """ 108 | 109 | assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 110 | if input_order not in ['HWC', 'CHW']: 111 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') 112 | img1 = reorder_image(img1, input_order=input_order) 113 | img2 = reorder_image(img2, input_order=input_order) 114 | img1 = img1.astype(np.float64) 115 | img2 = img2.astype(np.float64) 116 | 117 | if crop_border != 0: 118 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 119 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 120 | 121 | if test_y_channel: 122 | img1 = to_y_channel(img1) 123 | img2 = to_y_channel(img2) 124 | 125 | ssims = [] 126 | for i in range(img1.shape[2]): 127 | ssims.append(_ssim(img1[..., i], img2[..., i])) 128 | return np.array(ssims).mean() 129 | -------------------------------------------------------------------------------- /basicsr/archs/rrdbnet_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from utils.registry import ARCH_REGISTRY 6 | from .arch_util import default_init_weights, make_layer, pixel_unshuffle 7 | 8 | 9 | class ResidualDenseBlock(nn.Module): 10 | """Residual Dense Block. 11 | 12 | Used in RRDB block in ESRGAN. 13 | 14 | Args: 15 | num_feat (int): Channel number of intermediate features. 16 | num_grow_ch (int): Channels for each growth. 
17 | """ 18 | 19 | def __init__(self, num_feat=64, num_grow_ch=32): 20 | super(ResidualDenseBlock, self).__init__() 21 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 22 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 23 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) 24 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) 25 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 26 | 27 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | 29 | # initialization 30 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 31 | 32 | def forward(self, x): 33 | x1 = self.lrelu(self.conv1(x)) 34 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 35 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 36 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 37 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 38 | # Emperically, we use 0.2 to scale the residual for better performance 39 | return x5 * 0.2 + x 40 | 41 | 42 | class RRDB(nn.Module): 43 | """Residual in Residual Dense Block. 44 | 45 | Used in RRDB-Net in ESRGAN. 46 | 47 | Args: 48 | num_feat (int): Channel number of intermediate features. 49 | num_grow_ch (int): Channels for each growth. 50 | """ 51 | 52 | def __init__(self, num_feat, num_grow_ch=32): 53 | super(RRDB, self).__init__() 54 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 55 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 56 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 57 | 58 | def forward(self, x): 59 | out = self.rdb1(x) 60 | out = self.rdb2(out) 61 | out = self.rdb3(out) 62 | # Emperically, we use 0.2 to scale the residual for better performance 63 | return out * 0.2 + x 64 | 65 | 66 | @ARCH_REGISTRY.register() 67 | class RRDBNet(nn.Module): 68 | """Networks consisting of Residual in Residual Dense Block, which is used 69 | in ESRGAN. 70 | 71 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 72 | 73 | We extend ESRGAN for scale x2 and scale x1. 74 | Note: This is one option for scale 1, scale 2 in RRDBNet. 75 | We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size 76 | and enlarge the channel size before feeding inputs into the main ESRGAN architecture. 77 | 78 | Args: 79 | num_in_ch (int): Channel number of inputs. 80 | num_out_ch (int): Channel number of outputs. 81 | num_feat (int): Channel number of intermediate features. 82 | Default: 64 83 | num_block (int): Block number in the trunk network. Defaults: 23 84 | num_grow_ch (int): Channels for each growth. Default: 32. 
85 | """ 86 | 87 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): 88 | super(RRDBNet, self).__init__() 89 | self.scale = scale 90 | if scale == 2: 91 | num_in_ch = num_in_ch * 4 92 | elif scale == 1: 93 | num_in_ch = num_in_ch * 16 94 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 95 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) 96 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 97 | # upsample 98 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 99 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 100 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 101 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 102 | 103 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 104 | 105 | def forward(self, x): 106 | if self.scale == 2: 107 | feat = pixel_unshuffle(x, scale=2) 108 | elif self.scale == 1: 109 | feat = pixel_unshuffle(x, scale=4) 110 | else: 111 | feat = x 112 | feat = self.conv_first(feat) 113 | body_feat = self.conv_body(self.body(feat)) 114 | feat = feat + body_feat 115 | # upsample 116 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) 117 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) 118 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 119 | return out -------------------------------------------------------------------------------- /basicsr/utils/video_util.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The code is modified from the Real-ESRGAN: 3 | https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan_video.py 4 | 5 | ''' 6 | import cv2 7 | import sys 8 | import numpy as np 9 | 10 | try: 11 | import ffmpeg 12 | except ImportError: 13 | import pip 14 | pip.main(['install', '--user', 'ffmpeg-python']) 15 | import ffmpeg 16 | 17 | def get_video_meta_info(video_path): 18 | ret = {} 19 | probe = ffmpeg.probe(video_path) 20 | video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video'] 21 | has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams']) 22 | ret['width'] = video_streams[0]['width'] 23 | ret['height'] = video_streams[0]['height'] 24 | ret['fps'] = eval(video_streams[0]['avg_frame_rate']) 25 | ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None 26 | ret['nb_frames'] = int(video_streams[0]['nb_frames']) 27 | return ret 28 | 29 | class VideoReader: 30 | def __init__(self, video_path): 31 | self.paths = [] # for image&folder type 32 | self.audio = None 33 | try: 34 | self.stream_reader = ( 35 | ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='bgr24', 36 | loglevel='error').run_async( 37 | pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) 38 | except FileNotFoundError: 39 | print('Please install ffmpeg (not ffmpeg-python) by running\n', 40 | '\t$ conda install -c conda-forge ffmpeg') 41 | sys.exit(0) 42 | 43 | meta = get_video_meta_info(video_path) 44 | self.width = meta['width'] 45 | self.height = meta['height'] 46 | self.input_fps = meta['fps'] 47 | self.audio = meta['audio'] 48 | self.nb_frames = meta['nb_frames'] 49 | 50 | self.idx = 0 51 | 52 | def get_resolution(self): 53 | return self.height, self.width 54 | 55 | def get_fps(self): 56 | if self.input_fps is not None: 57 | return self.input_fps 58 | return 24 59 | 60 | def get_audio(self): 61 | return self.audio 62 | 63 | def __len__(self): 
64 | return self.nb_frames 65 | 66 | def get_frame_from_stream(self): 67 | img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3) # 3 bytes for one pixel 68 | if not img_bytes: 69 | return None 70 | img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3]) 71 | return img 72 | 73 | def get_frame_from_list(self): 74 | if self.idx >= self.nb_frames: 75 | return None 76 | img = cv2.imread(self.paths[self.idx]) 77 | self.idx += 1 78 | return img 79 | 80 | def get_frame(self): 81 | return self.get_frame_from_stream() 82 | 83 | 84 | def close(self): 85 | self.stream_reader.stdin.close() 86 | self.stream_reader.wait() 87 | 88 | 89 | class VideoWriter: 90 | def __init__(self, video_save_path, height, width, fps, audio): 91 | if height > 2160: 92 | print('You are generating video that is larger than 4K, which will be very slow due to IO speed.', 93 | 'We highly recommend to decrease the outscale(aka, -s).') 94 | if audio is not None: 95 | self.stream_writer = ( 96 | ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}', 97 | framerate=fps).output( 98 | audio, 99 | video_save_path, 100 | pix_fmt='yuv420p', 101 | vcodec='libx264', 102 | loglevel='error', 103 | acodec='copy').overwrite_output().run_async( 104 | pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) 105 | else: 106 | self.stream_writer = ( 107 | ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}', 108 | framerate=fps).output( 109 | video_save_path, pix_fmt='yuv420p', vcodec='libx264', 110 | loglevel='error').overwrite_output().run_async( 111 | pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) 112 | 113 | def write_frame(self, frame): 114 | try: 115 | frame = frame.astype(np.uint8).tobytes() 116 | self.stream_writer.stdin.write(frame) 117 | except BrokenPipeError: 118 | print('Please re-install ffmpeg and libx264 by running\n', 119 | '\t$ conda install -c conda-forge ffmpeg\n', 120 | '\t$ conda install -c conda-forge x264') 121 | sys.exit(0) 122 | 123 | def close(self): 124 | self.stream_writer.stdin.close() 125 | self.stream_writer.wait() -------------------------------------------------------------------------------- /basicsr/options/VarFormer_train_stage2.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: VarFormer2 3 | model_type: VarFormerIdxModel4_2_losszd 4 | num_gpu: 4 5 | manual_seed: 114514 6 | scale: 1 7 | crop_border: 0 8 | weights: [] # the sampling proportion of different datasets, [] means equal proportion 9 | 10 | 11 | datasets: 12 | train: 13 | name: rain_ll_blur_noise_haze 14 | type: LQGTDataset3_weight 15 | data_type: img 16 | data_num: 5 17 | 18 | dataroot_GT0: /home/wangsy/dataset/GoPro/train/groundtruth 19 | dataroot_LQ0: /home/wangsy/dataset/GoPro/train/input 20 | 21 | dataroot_GT1: /home/wangsy/dataset/llight/LOLdataset/our485/high 22 | dataroot_LQ1: /home/wangsy/dataset/llight/LOLdataset/our485/low 23 | 24 | dataroot_GT2: /home/wangsy/dataset/denoise/SIDD/train/target 25 | dataroot_LQ2: /home/wangsy/dataset/denoise/SIDD/train/input 26 | 27 | dataroot_GT3: /home/wangsy/dataset/derain/rain13k/train/target 28 | dataroot_LQ3: /home/wangsy/dataset/derain/rain13k/train/input 29 | 30 | dataroot_GT4: /home/wangsy/dataset/WED_BSD400/gt 31 | dataroot_LQ4: /home/wangsy/dataset/WED_BSD400/g_noise 32 | 33 | haze: true 34 | dataroot_LQ_z: /home/wangsy/dataset/dehaze/RESIDE2/OTS_ALPHA2/haze/OTS 35 | dataroot_GT_z: 
/home/wangsy/dataset/dehaze/RESIDE2/OTS_ALPHA2/clear/clear_images 36 | 37 | 38 | 39 | need_GT: true 40 | cond_scale: 4 41 | use_shuffle: true 42 | GT_size: 256 # #480 # 160 | 480 43 | use_flip: false 44 | use_rot: false 45 | # data loader 46 | num_worker_per_gpu: 4 47 | batch_size_per_gpu: 16 48 | dataset_enlarge_ratio: 10 49 | prefetch_mode: ~ 50 | 51 | 52 | val: 53 | name: Test100 54 | type: LQGTDataset 55 | dataroot_GT: /home/wangsy/dataset/derain/test/Test100/target 56 | dataroot_LQ: /home/wangsy/dataset/derain/test/Test100/input 57 | GT_size: 256 58 | cond_scale: 4 59 | save_img: false 60 | need_GT: true 61 | data_type: img 62 | 63 | 64 | # network structures 65 | network_g: 66 | type: VarFormer2 # VarFormer5 #VarFormer1 67 | depth: 20 # 16 #24 ## keep the same with var_path 68 | patch_nums: [1, 2, 3, 4, 5, 6, 8, 10, 13, 16] 69 | attn_l2_norm: True 70 | var_force_dpth: 7 71 | n_layers: 4 72 | if_enhance: True 73 | dec_adjust: True 74 | var_cross_c: 3 75 | ch_mult: [1, 1, 2, 2, 4] 76 | num_res_blocks: 2 77 | dropout: 0.0 78 | div_part: 3 79 | 80 | fix_modules: ['encoder','quant_conv','quantize','blocks','decoder','post_quant_conv'] 81 | vqgan_path: ./experiments/pretrained_models/vae_ch160v4096z32.pth # pretrained VQGAN 82 | var_path: ./experiments/pretrained_models/var_d20.pth # var_d16.pth # var_d24.pth # keep the same with depth 83 | network_vqgan: # this config is needed if no pre-calculated latent 84 | type: VarVQAutoEncoder 85 | model_path: ./experiments/pretrained_models/vae_ch160v4096z32.pth 86 | 87 | # path 88 | path: 89 | pretrain_network_g: ~ # pth from stage 1 or train from scratch 90 | param_key_g: params_ema 91 | strict_load_g: false 92 | pretrain_network_d: ~ #./experiments/pretrained_models/vqgan_discriminator.pth 93 | strict_load_d: true 94 | resume_state: ~ 95 | 96 | # base_lr(4.5e-6)*bach_size(4) 97 | train: 98 | use_hq_feat_loss: true 99 | feat_loss_weight: 0.5 100 | cross_entropy_loss: true 101 | entropy_loss_weight: 0.1 102 | fidelity_weight: 0 103 | perceptual_loss_weight: 1.0 104 | use_pixel_opt: true 105 | 106 | optim_g: 107 | type: Adam 108 | lr: !!float 1e-4 109 | weight_decay: 0 110 | betas: [0.9, 0.99] 111 | 112 | scheduler: 113 | type: MultiStepLR 114 | milestones: [400000, 450000] 115 | gamma: 0.5 116 | 117 | # scheduler: 118 | # type: CosineAnnealingRestartLR 119 | # periods: [500000] 120 | # restart_weights: [1] 121 | # eta_min: !!float 2e-5 # no lr reduce in official vqgan code 122 | 123 | total_iter: 700000 124 | 125 | warmup_iter: -1 # no warm up 126 | ema_decay: 0.998 127 | 128 | use_adaptive_weight: true 129 | 130 | net_g_start_iter: 0 131 | net_d_iters: 1 132 | net_d_start_iter: 0 133 | manual_seed: 620664 134 | 135 | perceptual_opt: 136 | type: PerceptualLoss 137 | use_input_norm: true 138 | range_norm: true 139 | layer_weights: {'relu5_4': 1.} 140 | 141 | # perceptual_opt: 142 | # type: LPIPSLoss 143 | # loss_weight: 1.0 144 | # use_input_norm: true 145 | # range_norm: true 146 | 147 | psnr_opt: 148 | type: PSNRLoss 149 | 150 | # ssim_opt: 151 | # type: SSIMLoss 152 | 153 | 154 | # validation settings 155 | val: 156 | val_freq: !!float 5e10 # no validation 157 | save_img: true 158 | 159 | metrics: 160 | psnr: # metric name, can be arbitrary 161 | type: calculate_psnr 162 | crop_border: 4 163 | test_y_channel: false 164 | ssim: # metric name, can be arbitrary 165 | type: calculate_ssim 166 | crop_border: 4 167 | test_y_channel: false 168 | 169 | # logging settings 170 | logger: 171 | print_freq: 100 172 | save_checkpoint_freq: !!float 1e4 173 | 
use_tb_logger: true 174 | wandb: 175 | project: ~ 176 | resume_id: ~ 177 | 178 | # dist training settings 179 | dist_params: 180 | backend: nccl 181 | port: 29412 182 | 183 | find_unused_parameters: true 184 | -------------------------------------------------------------------------------- /basicsr/losses/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | import torch 8 | # from torch.autograd import Variable 9 | 10 | from losses.lpips.trainer import * 11 | from losses.lpips.lpips import * 12 | 13 | def normalize_tensor(in_feat,eps=1e-10): 14 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 15 | return in_feat/(norm_factor+eps) 16 | 17 | def l2(p0, p1, range=255.): 18 | return .5*np.mean((p0 / range - p1 / range)**2) 19 | 20 | def psnr(p0, p1, peak=255.): 21 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 22 | 23 | def dssim(p0, p1, range=255.): 24 | from skimage.measure import compare_ssim 25 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 26 | 27 | def rgb2lab(in_img,mean_cent=False): 28 | from skimage import color 29 | img_lab = color.rgb2lab(in_img) 30 | if(mean_cent): 31 | img_lab[:,:,0] = img_lab[:,:,0]-50 32 | return img_lab 33 | 34 | def tensor2np(tensor_obj): 35 | # change dimension of a tensor object into a numpy array 36 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 37 | 38 | def np2tensor(np_obj): 39 | # change dimenion of np array into tensor array 40 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 41 | 42 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 43 | # image tensor to lab tensor 44 | from skimage import color 45 | 46 | img = tensor2im(image_tensor) 47 | img_lab = color.rgb2lab(img) 48 | if(mc_only): 49 | img_lab[:,:,0] = img_lab[:,:,0]-50 50 | if(to_norm and not mc_only): 51 | img_lab[:,:,0] = img_lab[:,:,0]-50 52 | img_lab = img_lab/100. 53 | 54 | return np2tensor(img_lab) 55 | 56 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 57 | from skimage import color 58 | import warnings 59 | warnings.filterwarnings("ignore") 60 | 61 | lab = tensor2np(lab_tensor)*100. 62 | lab[:,:,0] = lab[:,:,0]+50 63 | 64 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 65 | if(return_inbnd): 66 | # convert back to lab, see if we match 67 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 68 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 69 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 70 | return (im2tensor(rgb_back),mask) 71 | else: 72 | return im2tensor(rgb_back) 73 | 74 | def load_image(path): 75 | if(path[-3:] == 'dng'): 76 | import rawpy 77 | with rawpy.imread(path) as raw: 78 | img = raw.postprocess() 79 | elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png' or path[-4:]=='jpeg'): 80 | import cv2 81 | return cv2.imread(path)[:,:,::-1] 82 | else: 83 | img = (255*plt.imread(path)[:,:,:3]).astype('uint8') 84 | 85 | return img 86 | 87 | def rgb2lab(input): 88 | from skimage import color 89 | return color.rgb2lab(input / 255.) 
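# Illustrative note (not part of the original file): the helpers above follow a fixed
# layout convention -- tensor2np() turns a 1xCxHxW tensor into an HxWxC float array and
# np2tensor() performs the inverse, e.g.
#   t = np2tensor(np.zeros((64, 64, 3)))   # -> torch.Size([1, 3, 64, 64])
#   a = tensor2np(t)                       # -> shape (64, 64, 3)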
90 | 91 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 92 | image_numpy = image_tensor[0].cpu().float().numpy() 93 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 94 | return image_numpy.astype(imtype) 95 | 96 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 97 | return torch.Tensor((image / factor - cent) 98 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 99 | 100 | def tensor2vec(vector_tensor): 101 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 102 | 103 | 104 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 105 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 106 | image_numpy = image_tensor[0].cpu().float().numpy() 107 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 108 | return image_numpy.astype(imtype) 109 | 110 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 111 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 112 | return torch.Tensor((image / factor - cent) 113 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 114 | 115 | 116 | 117 | def voc_ap(rec, prec, use_07_metric=False): 118 | """ ap = voc_ap(rec, prec, [use_07_metric]) 119 | Compute VOC AP given precision and recall. 120 | If use_07_metric is true, uses the 121 | VOC 07 11 point method (default:False). 122 | """ 123 | if use_07_metric: 124 | # 11 point metric 125 | ap = 0. 126 | for t in np.arange(0., 1.1, 0.1): 127 | if np.sum(rec >= t) == 0: 128 | p = 0 129 | else: 130 | p = np.max(prec[rec >= t]) 131 | ap = ap + p / 11. 132 | else: 133 | # correct AP calculation 134 | # first append sentinel values at the end 135 | mrec = np.concatenate(([0.], rec, [1.])) 136 | mpre = np.concatenate(([0.], prec, [0.])) 137 | 138 | # compute the precision envelope 139 | for i in range(mpre.size - 1, 0, -1): 140 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 141 | 142 | # to calculate area under PR curve, look for points 143 | # where X axis (recall) changes value 144 | i = np.where(mrec[1:] != mrec[:-1])[0] 145 | 146 | # and sum (\Delta recall) * prec 147 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 148 | return ap 149 | 150 | -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import random 4 | import time 5 | import torch 6 | import numpy as np 7 | from os import path as osp 8 | 9 | from .dist_util import master_only 10 | from .logger import get_root_logger 11 | 12 | IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ 13 | torch.__version__)[0][:3])] >= [1, 12, 0] 14 | 15 | def gpu_is_available(): 16 | if IS_HIGH_VERSION: 17 | if torch.backends.mps.is_available(): 18 | return True 19 | return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False 20 | 21 | def get_device(gpu_id=None): 22 | if gpu_id is None: 23 | gpu_str = '' 24 | elif isinstance(gpu_id, int): 25 | gpu_str = f':{gpu_id}' 26 | else: 27 | raise TypeError('Input should be int value.') 28 | 29 | if IS_HIGH_VERSION: 30 | if torch.backends.mps.is_available(): 31 | return torch.device('mps'+gpu_str) 32 | return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu') 33 | 34 | 35 | def set_random_seed(seed): 36 | """Set random seeds.""" 37 | random.seed(seed) 38 | np.random.seed(seed) 39 
| torch.manual_seed(seed) 40 | torch.cuda.manual_seed(seed) 41 | torch.cuda.manual_seed_all(seed) 42 | 43 | 44 | def get_time_str(): 45 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 46 | 47 | 48 | def mkdir_and_rename(path): 49 | """mkdirs. If path exists, rename it with timestamp and create a new one. 50 | 51 | Args: 52 | path (str): Folder path. 53 | """ 54 | if osp.exists(path): 55 | new_name = path + '_archived_' + get_time_str() 56 | print(f'Path already exists. Rename it to {new_name}', flush=True) 57 | os.rename(path, new_name) 58 | os.makedirs(path, exist_ok=True) 59 | 60 | 61 | @master_only 62 | def make_exp_dirs(opt): 63 | """Make dirs for experiments.""" 64 | path_opt = opt['path'].copy() 65 | if opt['is_train']: 66 | mkdir_and_rename(path_opt.pop('experiments_root')) 67 | else: 68 | mkdir_and_rename(path_opt.pop('results_root')) 69 | for key, path in path_opt.items(): 70 | if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key): 71 | os.makedirs(path, exist_ok=True) 72 | 73 | 74 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 75 | """Scan a directory to find the interested files. 76 | 77 | Args: 78 | dir_path (str): Path of the directory. 79 | suffix (str | tuple(str), optional): File suffix that we are 80 | interested in. Default: None. 81 | recursive (bool, optional): If set to True, recursively scan the 82 | directory. Default: False. 83 | full_path (bool, optional): If set to True, include the dir_path. 84 | Default: False. 85 | 86 | Returns: 87 | A generator for all the interested files with relative pathes. 88 | """ 89 | 90 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 91 | raise TypeError('"suffix" must be a string or tuple of strings') 92 | 93 | root = dir_path 94 | 95 | def _scandir(dir_path, suffix, recursive): 96 | for entry in os.scandir(dir_path): 97 | if not entry.name.startswith('.') and entry.is_file(): 98 | if full_path: 99 | return_path = entry.path 100 | else: 101 | return_path = osp.relpath(entry.path, root) 102 | 103 | if suffix is None: 104 | yield return_path 105 | elif return_path.endswith(suffix): 106 | yield return_path 107 | else: 108 | if recursive: 109 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 110 | else: 111 | continue 112 | 113 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 114 | 115 | 116 | def check_resume(opt, resume_iter): 117 | """Check resume states and pretrain_network paths. 118 | 119 | Args: 120 | opt (dict): Options. 121 | resume_iter (int): Resume iteration. 122 | """ 123 | logger = get_root_logger() 124 | if opt['path']['resume_state']: 125 | # get all the networks 126 | networks = [key for key in opt.keys() if key.startswith('network_')] 127 | flag_pretrain = False 128 | for network in networks: 129 | if opt['path'].get(f'pretrain_{network}') is not None: 130 | flag_pretrain = True 131 | if flag_pretrain: 132 | logger.warning('pretrain_network path will be ignored during resuming.') 133 | # set pretrained model paths 134 | for network in networks: 135 | name = f'pretrain_{network}' 136 | basename = network.replace('network_', '') 137 | if opt['path'].get('ignore_resume_networks') is None or (basename 138 | not in opt['path']['ignore_resume_networks']): 139 | opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 140 | logger.info(f"Set {name} to {opt['path'][name]}") 141 | 142 | 143 | def sizeof_fmt(size, suffix='B'): 144 | """Get human readable file size. 
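# Illustrative note (not part of the original file): the loop below repeatedly divides by
# 1024 until the value drops under one unit, e.g. sizeof_fmt(1536) returns '1.5 KB' and
# sizeof_fmt(3 * 1024**3) returns '3.0 GB'.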
145 | 146 | Args: 147 | size (int): File size. 148 | suffix (str): Suffix. Default: 'B'. 149 | 150 | Return: 151 | str: Formated file siz. 152 | """ 153 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 154 | if abs(size) < 1024.0: 155 | return f'{size:3.1f} {unit}{suffix}' 156 | size /= 1024.0 157 | return f'{size:3.1f} Y{suffix}' 158 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.nn import functional as F 6 | 7 | try: 8 | from . import upfirdn2d_ext 9 | except ImportError: 10 | import os 11 | BASICSR_JIT = os.getenv('BASICSR_JIT') 12 | if BASICSR_JIT == 'True': 13 | from torch.utils.cpp_extension import load 14 | module_path = os.path.dirname(__file__) 15 | upfirdn2d_ext = load( 16 | 'upfirdn2d', 17 | sources=[ 18 | os.path.join(module_path, 'src', 'upfirdn2d.cpp'), 19 | os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'), 20 | ], 21 | ) 22 | 23 | 24 | class UpFirDn2dBackward(Function): 25 | 26 | @staticmethod 27 | def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): 28 | 29 | up_x, up_y = up 30 | down_x, down_y = down 31 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 32 | 33 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 34 | 35 | grad_input = upfirdn2d_ext.upfirdn2d( 36 | grad_output, 37 | grad_kernel, 38 | down_x, 39 | down_y, 40 | up_x, 41 | up_y, 42 | g_pad_x0, 43 | g_pad_x1, 44 | g_pad_y0, 45 | g_pad_y1, 46 | ) 47 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 48 | 49 | ctx.save_for_backward(kernel) 50 | 51 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 52 | 53 | ctx.up_x = up_x 54 | ctx.up_y = up_y 55 | ctx.down_x = down_x 56 | ctx.down_y = down_y 57 | ctx.pad_x0 = pad_x0 58 | ctx.pad_x1 = pad_x1 59 | ctx.pad_y0 = pad_y0 60 | ctx.pad_y1 = pad_y1 61 | ctx.in_size = in_size 62 | ctx.out_size = out_size 63 | 64 | return grad_input 65 | 66 | @staticmethod 67 | def backward(ctx, gradgrad_input): 68 | kernel, = ctx.saved_tensors 69 | 70 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 71 | 72 | gradgrad_out = upfirdn2d_ext.upfirdn2d( 73 | gradgrad_input, 74 | kernel, 75 | ctx.up_x, 76 | ctx.up_y, 77 | ctx.down_x, 78 | ctx.down_y, 79 | ctx.pad_x0, 80 | ctx.pad_x1, 81 | ctx.pad_y0, 82 | ctx.pad_y1, 83 | ) 84 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], 85 | # ctx.out_size[1], ctx.in_size[3]) 86 | gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) 87 | 88 | return gradgrad_out, None, None, None, None, None, None, None, None 89 | 90 | 91 | class UpFirDn2d(Function): 92 | 93 | @staticmethod 94 | def forward(ctx, input, kernel, up, down, pad): 95 | up_x, up_y = up 96 | down_x, down_y = down 97 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 98 | 99 | kernel_h, kernel_w = kernel.shape 100 | batch, channel, in_h, in_w = input.shape 101 | ctx.in_size = input.shape 102 | 103 | input = input.reshape(-1, in_h, in_w, 1) 104 | 105 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 106 | 107 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 108 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 109 | ctx.out_size = (out_h, out_w) 110 | 111 | 
ctx.up = (up_x, up_y) 112 | ctx.down = (down_x, down_y) 113 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 114 | 115 | g_pad_x0 = kernel_w - pad_x0 - 1 116 | g_pad_y0 = kernel_h - pad_y0 - 1 117 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 118 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 119 | 120 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 121 | 122 | out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) 123 | # out = out.view(major, out_h, out_w, minor) 124 | out = out.view(-1, channel, out_h, out_w) 125 | 126 | return out 127 | 128 | @staticmethod 129 | def backward(ctx, grad_output): 130 | kernel, grad_kernel = ctx.saved_tensors 131 | 132 | grad_input = UpFirDn2dBackward.apply( 133 | grad_output, 134 | kernel, 135 | grad_kernel, 136 | ctx.up, 137 | ctx.down, 138 | ctx.pad, 139 | ctx.g_pad, 140 | ctx.in_size, 141 | ctx.out_size, 142 | ) 143 | 144 | return grad_input, None, None, None, None 145 | 146 | 147 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 148 | if input.device.type == 'cpu': 149 | out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) 150 | else: 151 | out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) 152 | 153 | return out 154 | 155 | 156 | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): 157 | _, channel, in_h, in_w = input.shape 158 | input = input.reshape(-1, in_h, in_w, 1) 159 | 160 | _, in_h, in_w, minor = input.shape 161 | kernel_h, kernel_w = kernel.shape 162 | 163 | out = input.view(-1, in_h, 1, in_w, 1, minor) 164 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 165 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 166 | 167 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) 168 | out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] 169 | 170 | out = out.permute(0, 3, 1, 2) 171 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) 172 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 173 | out = F.conv2d(out, w) 174 | out = out.reshape( 175 | -1, 176 | minor, 177 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 178 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 179 | ) 180 | out = out.permute(0, 2, 3, 1) 181 | out = out[:, ::down_y, ::down_x, :] 182 | 183 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 184 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 185 | 186 | return out.view(-1, channel, out_h, out_w) 187 | -------------------------------------------------------------------------------- /basicsr/utils/file_client.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseStorageBackend(metaclass=ABCMeta): 6 | """Abstract class of storage backends. 7 | 8 | All backends need to implement two apis: ``get()`` and ``get_text()``. 9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 10 | as texts. 
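# Illustrative sketch (hypothetical subclass, not part of the original file): a new backend
# only has to provide the two methods required here, e.g.
#   class DictBackend(BaseStorageBackend):
#       def __init__(self, store):
#           self.store = store              # maps filepath -> bytes
#       def get(self, filepath):
#           return self.store[str(filepath)]
#       def get_text(self, filepath):
#           return self.store[str(filepath)].decode('utf-8')
# and could then be registered in FileClient._backends under a new key.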
11 | """ 12 | 13 | @abstractmethod 14 | def get(self, filepath): 15 | pass 16 | 17 | @abstractmethod 18 | def get_text(self, filepath): 19 | pass 20 | 21 | 22 | class MemcachedBackend(BaseStorageBackend): 23 | """Memcached storage backend. 24 | 25 | Attributes: 26 | server_list_cfg (str): Config file for memcached server list. 27 | client_cfg (str): Config file for memcached client. 28 | sys_path (str | None): Additional path to be appended to `sys.path`. 29 | Default: None. 30 | """ 31 | 32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None): 33 | if sys_path is not None: 34 | import sys 35 | sys.path.append(sys_path) 36 | try: 37 | import mc 38 | except ImportError: 39 | raise ImportError('Please install memcached to enable MemcachedBackend.') 40 | 41 | self.server_list_cfg = server_list_cfg 42 | self.client_cfg = client_cfg 43 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) 44 | # mc.pyvector servers as a point which points to a memory cache 45 | self._mc_buffer = mc.pyvector() 46 | 47 | def get(self, filepath): 48 | filepath = str(filepath) 49 | import mc 50 | self._client.Get(filepath, self._mc_buffer) 51 | value_buf = mc.ConvertBuffer(self._mc_buffer) 52 | return value_buf 53 | 54 | def get_text(self, filepath): 55 | raise NotImplementedError 56 | 57 | 58 | class HardDiskBackend(BaseStorageBackend): 59 | """Raw hard disks storage backend.""" 60 | 61 | def get(self, filepath): 62 | filepath = str(filepath) 63 | with open(filepath, 'rb') as f: 64 | value_buf = f.read() 65 | return value_buf 66 | 67 | def get_text(self, filepath): 68 | filepath = str(filepath) 69 | with open(filepath, 'r') as f: 70 | value_buf = f.read() 71 | return value_buf 72 | 73 | 74 | class LmdbBackend(BaseStorageBackend): 75 | """Lmdb storage backend. 76 | 77 | Args: 78 | db_paths (str | list[str]): Lmdb database paths. 79 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 80 | readonly (bool, optional): Lmdb environment parameter. If True, 81 | disallow any write operations. Default: True. 82 | lock (bool, optional): Lmdb environment parameter. If False, when 83 | concurrent access occurs, do not lock the database. Default: False. 84 | readahead (bool, optional): Lmdb environment parameter. If False, 85 | disable the OS filesystem readahead mechanism, which may improve 86 | random read performance when a database is larger than RAM. 87 | Default: False. 88 | 89 | Attributes: 90 | db_paths (list): Lmdb database path. 91 | _client (list): A list of several lmdb envs. 92 | """ 93 | 94 | def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): 95 | try: 96 | import lmdb 97 | except ImportError: 98 | raise ImportError('Please install lmdb to enable LmdbBackend.') 99 | 100 | if isinstance(client_keys, str): 101 | client_keys = [client_keys] 102 | 103 | if isinstance(db_paths, list): 104 | self.db_paths = [str(v) for v in db_paths] 105 | elif isinstance(db_paths, str): 106 | self.db_paths = [str(db_paths)] 107 | assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' 108 | f'but received {len(client_keys)} and {len(self.db_paths)}.') 109 | 110 | self._client = {} 111 | for client, path in zip(client_keys, self.db_paths): 112 | self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) 113 | 114 | def get(self, filepath, client_key): 115 | """Get values according to the filepath from one lmdb named client_key. 
116 | 117 | Args: 118 | filepath (str | obj:`Path`): Here, filepath is the lmdb key. 119 | client_key (str): Used for distinguishing differnet lmdb envs. 120 | """ 121 | filepath = str(filepath) 122 | assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.') 123 | client = self._client[client_key] 124 | with client.begin(write=False) as txn: 125 | value_buf = txn.get(filepath.encode('ascii')) 126 | return value_buf 127 | 128 | def get_text(self, filepath): 129 | raise NotImplementedError 130 | 131 | 132 | class FileClient(object): 133 | """A general file client to access files in different backend. 134 | 135 | The client loads a file or text in a specified backend from its path 136 | and return it as a binary file. it can also register other backend 137 | accessor with a given name and backend class. 138 | 139 | Attributes: 140 | backend (str): The storage backend type. Options are "disk", 141 | "memcached" and "lmdb". 142 | client (:obj:`BaseStorageBackend`): The backend object. 143 | """ 144 | 145 | _backends = { 146 | 'disk': HardDiskBackend, 147 | 'memcached': MemcachedBackend, 148 | 'lmdb': LmdbBackend, 149 | } 150 | 151 | def __init__(self, backend='disk', **kwargs): 152 | if backend not in self._backends: 153 | raise ValueError(f'Backend {backend} is not supported. Currently supported ones' 154 | f' are {list(self._backends.keys())}') 155 | self.backend = backend 156 | self.client = self._backends[backend](**kwargs) 157 | 158 | def get(self, filepath, client_key='default'): 159 | # client_key is used only for lmdb, where different fileclients have 160 | # different lmdb environments. 161 | if self.backend == 'lmdb': 162 | return self.client.get(filepath, client_key) 163 | else: 164 | return self.client.get(filepath) 165 | 166 | def get_text(self, filepath): 167 | return self.client.get_text(filepath) 168 | -------------------------------------------------------------------------------- /basicsr/utils/img_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def img2tensor(imgs, bgr2rgb=True, float32=True): 10 | """Numpy array to tensor. 11 | 12 | Args: 13 | imgs (list[ndarray] | ndarray): Input images. 14 | bgr2rgb (bool): Whether to change bgr to rgb. 15 | float32 (bool): Whether to change to float32. 16 | 17 | Returns: 18 | list[tensor] | tensor: Tensor images. If returned results only have 19 | one element, just return tensor. 20 | """ 21 | 22 | def _totensor(img, bgr2rgb, float32): 23 | if img.shape[2] == 3 and bgr2rgb: 24 | if img.dtype == 'float64': 25 | img = img.astype('float32') 26 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 27 | img = torch.from_numpy(img.transpose(2, 0, 1)) 28 | if float32: 29 | img = img.float() 30 | return img 31 | 32 | if isinstance(imgs, list): 33 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 34 | else: 35 | return _totensor(imgs, bgr2rgb, float32) 36 | 37 | 38 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 39 | """Convert torch Tensors into image numpy arrays. 40 | 41 | After clamping to [min, max], values will be normalized to [0, 1]. 42 | 43 | Args: 44 | tensor (Tensor or list[Tensor]): Accept shapes: 45 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); 46 | 2) 3D Tensor of shape (3/1 x H x W); 47 | 3) 2D Tensor of shape (H x W). 48 | Tensor channel should be in RGB order. 
49 | rgb2bgr (bool): Whether to change rgb to bgr. 50 | out_type (numpy type): output types. If ``np.uint8``, transform outputs 51 | to uint8 type with range [0, 255]; otherwise, float type with 52 | range [0, 1]. Default: ``np.uint8``. 53 | min_max (tuple[int]): min and max values for clamp. 54 | 55 | Returns: 56 | (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of 57 | shape (H x W). The channel order is BGR. 58 | """ 59 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): 60 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') 61 | 62 | if torch.is_tensor(tensor): 63 | tensor = [tensor] 64 | result = [] 65 | for _tensor in tensor: 66 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 67 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 68 | 69 | n_dim = _tensor.dim() 70 | if n_dim == 4: 71 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() 72 | img_np = img_np.transpose(1, 2, 0) 73 | if rgb2bgr: 74 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 75 | elif n_dim == 3: 76 | img_np = _tensor.numpy() 77 | img_np = img_np.transpose(1, 2, 0) 78 | if img_np.shape[2] == 1: # gray image 79 | img_np = np.squeeze(img_np, axis=2) 80 | else: 81 | if rgb2bgr: 82 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 83 | elif n_dim == 2: 84 | img_np = _tensor.numpy() 85 | else: 86 | raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}') 87 | if out_type == np.uint8: 88 | # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 89 | img_np = (img_np * 255.0).round() 90 | img_np = img_np.astype(out_type) 91 | result.append(img_np) 92 | if len(result) == 1: 93 | result = result[0] 94 | return result 95 | 96 | 97 | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): 98 | """This implementation is slightly faster than tensor2img. 99 | It now only supports torch tensor with shape (1, c, h, w). 100 | 101 | Args: 102 | tensor (Tensor): Now only support torch tensor with (1, c, h, w). 103 | rgb2bgr (bool): Whether to change rgb to bgr. Default: True. 104 | min_max (tuple[int]): min and max values for clamp. 105 | """ 106 | output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) 107 | output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 108 | output = output.type(torch.uint8).cpu().numpy() 109 | if rgb2bgr: 110 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 111 | return output 112 | 113 | 114 | def imfrombytes(content, flag='color', float32=False): 115 | """Read an image from bytes. 116 | 117 | Args: 118 | content (bytes): Image bytes got from files or other streams. 119 | flag (str): Flags specifying the color type of a loaded image, 120 | candidates are `color`, `grayscale` and `unchanged`. 121 | float32 (bool): Whether to change to float32., If True, will also norm 122 | to [0, 1]. Default: False. 123 | 124 | Returns: 125 | ndarray: Loaded image array. 126 | """ 127 | img_np = np.frombuffer(content, np.uint8) 128 | imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} 129 | img = cv2.imdecode(img_np, imread_flags[flag]) 130 | if float32: 131 | img = img.astype(np.float32) / 255. 132 | return img 133 | 134 | 135 | def imwrite(img, file_path, params=None, auto_mkdir=True): 136 | """Write image to file. 137 | 138 | Args: 139 | img (ndarray): Image array to be written. 140 | file_path (str): Image file path. 
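# Illustrative sketch (assumed paths, not part of the original file): a typical read/write
# round trip with the helpers in this module is
#   content = open('input.png', 'rb').read()
#   img = imfrombytes(content, flag='color', float32=True)        # HWC, BGR, range [0, 1]
#   imwrite((img * 255.0).round().astype('uint8'), 'out/result.png')
# where imwrite creates the 'out/' folder automatically because auto_mkdir is True.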
141 | params (None or list): Same as opencv's :func:`imwrite` interface. 142 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 143 | whether to create it automatically. 144 | 145 | Returns: 146 | bool: Successful or not. 147 | """ 148 | if auto_mkdir: 149 | dir_name = os.path.abspath(os.path.dirname(file_path)) 150 | os.makedirs(dir_name, exist_ok=True) 151 | return cv2.imwrite(file_path, img, params) 152 | 153 | 154 | def crop_border(imgs, crop_border): 155 | """Crop borders of images. 156 | 157 | Args: 158 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 159 | crop_border (int): Crop border for each end of height and weight. 160 | 161 | Returns: 162 | list[ndarray]: Cropped images. 163 | """ 164 | if crop_border == 0: 165 | return imgs 166 | else: 167 | if isinstance(imgs, list): 168 | return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] 169 | else: 170 | return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] 171 | -------------------------------------------------------------------------------- /basicsr/utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | 5 | from .dist_util import get_dist_info, master_only 6 | 7 | initialized_logger = {} 8 | 9 | 10 | class MessageLogger(): 11 | """Message logger for printing. 12 | Args: 13 | opt (dict): Config. It contains the following keys: 14 | name (str): Exp name. 15 | logger (dict): Contains 'print_freq' (str) for logger interval. 16 | train (dict): Contains 'total_iter' (int) for total iters. 17 | use_tb_logger (bool): Use tensorboard logger. 18 | start_iter (int): Start iter. Default: 1. 19 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 20 | """ 21 | 22 | def __init__(self, opt, start_iter=1, tb_logger=None): 23 | self.exp_name = opt['name'] 24 | self.interval = opt['logger']['print_freq'] 25 | self.start_iter = start_iter 26 | self.max_iters = opt['train']['total_iter'] 27 | self.use_tb_logger = opt['logger']['use_tb_logger'] 28 | self.tb_logger = tb_logger 29 | self.start_time = time.time() 30 | self.logger = get_root_logger() 31 | 32 | @master_only 33 | def __call__(self, log_vars): 34 | """Format logging message. 35 | Args: 36 | log_vars (dict): It contains the following keys: 37 | epoch (int): Epoch number. 38 | iter (int): Current iter. 39 | lrs (list): List for learning rates. 40 | time (float): Iter time. 41 | data_time (float): Data time for each iter. 
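                Any remaining keys (typically losses) are printed as-is and, when the
                tensorboard logger is enabled, written as scalars.

        Example (a hypothetical call; the `l_pix` loss key and all values are illustrative only):
            msg_logger({'epoch': 1, 'iter': 100, 'lrs': [1e-4],
                        'time': 0.45, 'data_time': 0.02, 'l_pix': 0.031})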
42 | """ 43 | # epoch, iter, learning rates 44 | epoch = log_vars.pop('epoch') 45 | current_iter = log_vars.pop('iter') 46 | lrs = log_vars.pop('lrs') 47 | 48 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(') 49 | for v in lrs: 50 | message += f'{v:.3e},' 51 | message += ')] ' 52 | 53 | # time and estimated time 54 | if 'time' in log_vars.keys(): 55 | iter_time = log_vars.pop('time') 56 | data_time = log_vars.pop('data_time') 57 | 58 | total_time = time.time() - self.start_time 59 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 60 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 61 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 62 | message += f'[eta: {eta_str}, ' 63 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' 64 | 65 | # other items, especially losses 66 | for k, v in log_vars.items(): 67 | message += f'{k}: {v:.4e} ' 68 | # tensorboard logger 69 | if self.use_tb_logger: 70 | # if k.startswith('l_'): 71 | # self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) 72 | # else: 73 | self.tb_logger.add_scalar(k, v, current_iter) 74 | self.logger.info(message) 75 | 76 | 77 | @master_only 78 | def init_tb_logger(log_dir): 79 | from torch.utils.tensorboard import SummaryWriter 80 | tb_logger = SummaryWriter(log_dir=log_dir) 81 | return tb_logger 82 | 83 | 84 | @master_only 85 | def init_wandb_logger(opt): 86 | """We now only use wandb to sync tensorboard log.""" 87 | import wandb 88 | logger = logging.getLogger('basicsr') 89 | 90 | project = opt['logger']['wandb']['project'] 91 | resume_id = opt['logger']['wandb'].get('resume_id') 92 | if resume_id: 93 | wandb_id = resume_id 94 | resume = 'allow' 95 | logger.warning(f'Resume wandb logger with id={wandb_id}.') 96 | else: 97 | wandb_id = wandb.util.generate_id() 98 | resume = 'never' 99 | 100 | wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) 101 | 102 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') 103 | 104 | 105 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): 106 | """Get the root logger. 107 | The logger will be initialized if it has not been initialized. By default a 108 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 109 | also be added. 110 | Args: 111 | logger_name (str): root logger name. Default: 'basicsr'. 112 | log_file (str | None): The log filename. If specified, a FileHandler 113 | will be added to the root logger. 114 | log_level (int): The root logger level. Note that only the process of 115 | rank 0 is affected, while other processes will set the level to 116 | "Error" and be silent most of the time. 117 | Returns: 118 | logging.Logger: The root logger. 
119 | """ 120 | logger = logging.getLogger(logger_name) 121 | # if the logger has been initialized, just return it 122 | if logger_name in initialized_logger: 123 | return logger 124 | 125 | format_str = '%(asctime)s %(levelname)s: %(message)s' 126 | stream_handler = logging.StreamHandler() 127 | stream_handler.setFormatter(logging.Formatter(format_str)) 128 | logger.addHandler(stream_handler) 129 | logger.propagate = False 130 | rank, _ = get_dist_info() 131 | if rank != 0: 132 | logger.setLevel('ERROR') 133 | elif log_file is not None: 134 | logger.setLevel(log_level) 135 | # add file handler 136 | # file_handler = logging.FileHandler(log_file, 'w') 137 | file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log 138 | file_handler.setFormatter(logging.Formatter(format_str)) 139 | file_handler.setLevel(log_level) 140 | logger.addHandler(file_handler) 141 | initialized_logger[logger_name] = True 142 | return logger 143 | 144 | 145 | def get_env_info(): 146 | """Get environment information. 147 | Currently, only log the software version. 148 | """ 149 | import torch 150 | import torchvision 151 | 152 | # from basicsr.version import __version__ 153 | msg = r""" 154 | ____ _ _____ ____ 155 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \ 156 | / __ |/ __ `// ___// // ___/\__ \ / /_/ / 157 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ 158 | /_____/ \__,_//____//_/ \___//____//_/ |_| 159 | ______ __ __ __ __ 160 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / / 161 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / 162 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ 163 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) 164 | """ 165 | # msg += ('\nVersion Information: ' 166 | # f'\n\tBasicSR: {__version__}' 167 | # f'\n\tPyTorch: {torch.__version__}' 168 | # f'\n\tTorchVision: {torchvision.__version__}') 169 | return msg -------------------------------------------------------------------------------- /basicsr/losses/lpips/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | 5 | class squeezenet(torch.nn.Module): 6 | def __init__(self, requires_grad=False, pretrained=True): 7 | super(squeezenet, self).__init__() 8 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 9 | self.slice1 = torch.nn.Sequential() 10 | self.slice2 = torch.nn.Sequential() 11 | self.slice3 = torch.nn.Sequential() 12 | self.slice4 = torch.nn.Sequential() 13 | self.slice5 = torch.nn.Sequential() 14 | self.slice6 = torch.nn.Sequential() 15 | self.slice7 = torch.nn.Sequential() 16 | self.N_slices = 7 17 | for x in range(2): 18 | self.slice1.add_module(str(x), pretrained_features[x]) 19 | for x in range(2,5): 20 | self.slice2.add_module(str(x), pretrained_features[x]) 21 | for x in range(5, 8): 22 | self.slice3.add_module(str(x), pretrained_features[x]) 23 | for x in range(8, 10): 24 | self.slice4.add_module(str(x), pretrained_features[x]) 25 | for x in range(10, 11): 26 | self.slice5.add_module(str(x), pretrained_features[x]) 27 | for x in range(11, 12): 28 | self.slice6.add_module(str(x), pretrained_features[x]) 29 | for x in range(12, 13): 30 | self.slice7.add_module(str(x), pretrained_features[x]) 31 | if not requires_grad: 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | def forward(self, X): 36 | h = self.slice1(X) 37 | h_relu1 = h 38 | h = self.slice2(h) 39 | h_relu2 = h 
40 | h = self.slice3(h) 41 | h_relu3 = h 42 | h = self.slice4(h) 43 | h_relu4 = h 44 | h = self.slice5(h) 45 | h_relu5 = h 46 | h = self.slice6(h) 47 | h_relu6 = h 48 | h = self.slice7(h) 49 | h_relu7 = h 50 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 51 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 52 | 53 | return out 54 | 55 | 56 | class alexnet(torch.nn.Module): 57 | def __init__(self, requires_grad=False, pretrained=True): 58 | super(alexnet, self).__init__() 59 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 60 | self.slice1 = torch.nn.Sequential() 61 | self.slice2 = torch.nn.Sequential() 62 | self.slice3 = torch.nn.Sequential() 63 | self.slice4 = torch.nn.Sequential() 64 | self.slice5 = torch.nn.Sequential() 65 | self.N_slices = 5 66 | for x in range(2): 67 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 68 | for x in range(2, 5): 69 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 70 | for x in range(5, 8): 71 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 72 | for x in range(8, 10): 73 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 74 | for x in range(10, 12): 75 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 76 | if not requires_grad: 77 | for param in self.parameters(): 78 | param.requires_grad = False 79 | 80 | def forward(self, X): 81 | h = self.slice1(X) 82 | h_relu1 = h 83 | h = self.slice2(h) 84 | h_relu2 = h 85 | h = self.slice3(h) 86 | h_relu3 = h 87 | h = self.slice4(h) 88 | h_relu4 = h 89 | h = self.slice5(h) 90 | h_relu5 = h 91 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 92 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 93 | 94 | return out 95 | 96 | class vgg16(torch.nn.Module): 97 | def __init__(self, requires_grad=False, pretrained=True): 98 | super(vgg16, self).__init__() 99 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 100 | self.slice1 = torch.nn.Sequential() 101 | self.slice2 = torch.nn.Sequential() 102 | self.slice3 = torch.nn.Sequential() 103 | self.slice4 = torch.nn.Sequential() 104 | self.slice5 = torch.nn.Sequential() 105 | self.N_slices = 5 106 | for x in range(4): 107 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 108 | for x in range(4, 9): 109 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 110 | for x in range(9, 16): 111 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 112 | for x in range(16, 23): 113 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 114 | for x in range(23, 30): 115 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 116 | if not requires_grad: 117 | for param in self.parameters(): 118 | param.requires_grad = False 119 | 120 | def forward(self, X): 121 | h = self.slice1(X) 122 | h_relu1_2 = h 123 | h = self.slice2(h) 124 | h_relu2_2 = h 125 | h = self.slice3(h) 126 | h_relu3_3 = h 127 | h = self.slice4(h) 128 | h_relu4_3 = h 129 | h = self.slice5(h) 130 | h_relu5_3 = h 131 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 132 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 133 | 134 | return out 135 | 136 | 137 | 138 | class resnet(torch.nn.Module): 139 | def __init__(self, requires_grad=False, pretrained=True, num=18): 140 | super(resnet, self).__init__() 141 | if(num==18): 142 | self.net = 
tv.resnet18(pretrained=pretrained) 143 | elif(num==34): 144 | self.net = tv.resnet34(pretrained=pretrained) 145 | elif(num==50): 146 | self.net = tv.resnet50(pretrained=pretrained) 147 | elif(num==101): 148 | self.net = tv.resnet101(pretrained=pretrained) 149 | elif(num==152): 150 | self.net = tv.resnet152(pretrained=pretrained) 151 | self.N_slices = 5 152 | 153 | self.conv1 = self.net.conv1 154 | self.bn1 = self.net.bn1 155 | self.relu = self.net.relu 156 | self.maxpool = self.net.maxpool 157 | self.layer1 = self.net.layer1 158 | self.layer2 = self.net.layer2 159 | self.layer3 = self.net.layer3 160 | self.layer4 = self.net.layer4 161 | 162 | def forward(self, X): 163 | h = self.conv1(X) 164 | h = self.bn1(h) 165 | h = self.relu(h) 166 | h_relu1 = h 167 | h = self.maxpool(h) 168 | h = self.layer1(h) 169 | h_conv2 = h 170 | h = self.layer2(h) 171 | h_conv3 = h 172 | h = self.layer3(h) 173 | h_conv4 = h 174 | h = self.layer4(h) 175 | h_conv5 = h 176 | 177 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 178 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 179 | 180 | return out 181 | -------------------------------------------------------------------------------- /basicsr/utils/lmdb_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import lmdb 3 | import sys 4 | from multiprocessing import Pool 5 | from os import path as osp 6 | from tqdm import tqdm 7 | 8 | 9 | def make_lmdb_from_imgs(data_path, 10 | lmdb_path, 11 | img_path_list, 12 | keys, 13 | batch=5000, 14 | compress_level=1, 15 | multiprocessing_read=False, 16 | n_thread=40, 17 | map_size=None): 18 | """Make lmdb from images. 19 | 20 | Contents of lmdb. The file structure is: 21 | example.lmdb 22 | ├── data.mdb 23 | ├── lock.mdb 24 | ├── meta_info.txt 25 | 26 | The data.mdb and lock.mdb are standard lmdb files and you can refer to 27 | https://lmdb.readthedocs.io/en/release/ for more details. 28 | 29 | The meta_info.txt is a specified txt file to record the meta information 30 | of our datasets. It will be automatically created when preparing 31 | datasets by our provided dataset tools. 32 | Each line in the txt file records 1)image name (with extension), 33 | 2)image shape, and 3)compression level, separated by a white space. 34 | 35 | For example, the meta information could be: 36 | `000_00000000.png (720,1280,3) 1`, which means: 37 | 1) image name (with extension): 000_00000000.png; 38 | 2) image shape: (720,1280,3); 39 | 3) compression level: 1 40 | 41 | We use the image name without extension as the lmdb key. 42 | 43 | If `multiprocessing_read` is True, it will read all the images to memory 44 | using multiprocessing. Thus, your server needs to have enough memory. 45 | 46 | Args: 47 | data_path (str): Data path for reading images. 48 | lmdb_path (str): Lmdb save path. 49 | img_path_list (str): Image path list. 50 | keys (str): Used for lmdb keys. 51 | batch (int): After processing batch images, lmdb commits. 52 | Default: 5000. 53 | compress_level (int): Compress level when encoding images. Default: 1. 54 | multiprocessing_read (bool): Whether use multiprocessing to read all 55 | the images to memory. Default: False. 56 | n_thread (int): For multiprocessing. 57 | map_size (int | None): Map size for lmdb env. If None, use the 58 | estimated size from images. 
Default: None 59 | """ 60 | 61 | assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' 62 | f'but got {len(img_path_list)} and {len(keys)}') 63 | print(f'Create lmdb for {data_path}, save to {lmdb_path}...') 64 | print(f'Totoal images: {len(img_path_list)}') 65 | if not lmdb_path.endswith('.lmdb'): 66 | raise ValueError("lmdb_path must end with '.lmdb'.") 67 | if osp.exists(lmdb_path): 68 | print(f'Folder {lmdb_path} already exists. Exit.') 69 | sys.exit(1) 70 | 71 | if multiprocessing_read: 72 | # read all the images to memory (multiprocessing) 73 | dataset = {} # use dict to keep the order for multiprocessing 74 | shapes = {} 75 | print(f'Read images with multiprocessing, #thread: {n_thread} ...') 76 | pbar = tqdm(total=len(img_path_list), unit='image') 77 | 78 | def callback(arg): 79 | """get the image data and update pbar.""" 80 | key, dataset[key], shapes[key] = arg 81 | pbar.update(1) 82 | pbar.set_description(f'Read {key}') 83 | 84 | pool = Pool(n_thread) 85 | for path, key in zip(img_path_list, keys): 86 | pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) 87 | pool.close() 88 | pool.join() 89 | pbar.close() 90 | print(f'Finish reading {len(img_path_list)} images.') 91 | 92 | # create lmdb environment 93 | if map_size is None: 94 | # obtain data size for one image 95 | img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) 96 | _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 97 | data_size_per_img = img_byte.nbytes 98 | print('Data size per image is: ', data_size_per_img) 99 | data_size = data_size_per_img * len(img_path_list) 100 | map_size = data_size * 10 101 | 102 | env = lmdb.open(lmdb_path, map_size=map_size) 103 | 104 | # write data to lmdb 105 | pbar = tqdm(total=len(img_path_list), unit='chunk') 106 | txn = env.begin(write=True) 107 | txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 108 | for idx, (path, key) in enumerate(zip(img_path_list, keys)): 109 | pbar.update(1) 110 | pbar.set_description(f'Write {key}') 111 | key_byte = key.encode('ascii') 112 | if multiprocessing_read: 113 | img_byte = dataset[key] 114 | h, w, c = shapes[key] 115 | else: 116 | _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) 117 | h, w, c = img_shape 118 | 119 | txn.put(key_byte, img_byte) 120 | # write meta information 121 | txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') 122 | if idx % batch == 0: 123 | txn.commit() 124 | txn = env.begin(write=True) 125 | pbar.close() 126 | txn.commit() 127 | env.close() 128 | txt_file.close() 129 | print('\nFinish writing lmdb.') 130 | 131 | 132 | def read_img_worker(path, key, compress_level): 133 | """Read image worker. 134 | 135 | Args: 136 | path (str): Image path. 137 | key (str): Image key. 138 | compress_level (int): Compress level when encoding images. 139 | 140 | Returns: 141 | str: Image key. 142 | byte: Image byte. 143 | tuple[int]: Image shape. 144 | """ 145 | 146 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 147 | if img.ndim == 2: 148 | h, w = img.shape 149 | c = 1 150 | else: 151 | h, w, c = img.shape 152 | _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 153 | return (key, img_byte, (h, w, c)) 154 | 155 | 156 | class LmdbMaker(): 157 | """LMDB Maker. 158 | 159 | Args: 160 | lmdb_path (str): Lmdb save path. 161 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. 
162 | batch (int): After processing batch images, lmdb commits. 163 | Default: 5000. 164 | compress_level (int): Compress level when encoding images. Default: 1. 165 | """ 166 | 167 | def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): 168 | if not lmdb_path.endswith('.lmdb'): 169 | raise ValueError("lmdb_path must end with '.lmdb'.") 170 | if osp.exists(lmdb_path): 171 | print(f'Folder {lmdb_path} already exists. Exit.') 172 | sys.exit(1) 173 | 174 | self.lmdb_path = lmdb_path 175 | self.batch = batch 176 | self.compress_level = compress_level 177 | self.env = lmdb.open(lmdb_path, map_size=map_size) 178 | self.txn = self.env.begin(write=True) 179 | self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 180 | self.counter = 0 181 | 182 | def put(self, img_byte, key, img_shape): 183 | self.counter += 1 184 | key_byte = key.encode('ascii') 185 | self.txn.put(key_byte, img_byte) 186 | # write meta information 187 | h, w, c = img_shape 188 | self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') 189 | if self.counter % self.batch == 0: 190 | self.txn.commit() 191 | self.txn = self.env.begin(write=True) 192 | 193 | def close(self): 194 | self.txn.commit() 195 | self.env.close() 196 | self.txt_file.close() 197 | -------------------------------------------------------------------------------- /basicsr/archs/AIEM.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torchvision 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | class LayerNormFunction(torch.autograd.Function): 9 | @staticmethod 10 | def forward(ctx, x, weight, bias, eps): 11 | ctx.eps = eps 12 | N, C, H, W = x.size() 13 | mu = x.mean(1, keepdim=True) 14 | var = (x - mu).pow(2).mean(1, keepdim=True) 15 | 16 | y = (x - mu) / (var + eps).sqrt() 17 | weight, bias, y = weight.contiguous(), bias.contiguous(), y.contiguous() # avoid cuda error 18 | ctx.save_for_backward(y, var, weight) 19 | y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) 20 | return y 21 | 22 | @staticmethod 23 | def backward(ctx, grad_output): 24 | eps = ctx.eps 25 | 26 | N, C, H, W = grad_output.size() 27 | y, var, weight = ctx.saved_tensors 28 | g = grad_output * weight.view(1, C, 1, 1) 29 | mean_g = g.mean(dim=1, keepdim=True) 30 | 31 | mean_gy = (g * y).mean(dim=1, keepdim=True) 32 | gx = 1. 
/ torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) 33 | return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum( 34 | dim=0), None 35 | 36 | 37 | class LayerNorm2d(nn.Module): 38 | def __init__(self, channels, eps=1e-6, requires_grad=True): 39 | super(LayerNorm2d, self).__init__() 40 | self.register_parameter('weight', nn.Parameter(torch.ones(channels), requires_grad=requires_grad)) 41 | self.register_parameter('bias', nn.Parameter(torch.zeros(channels), requires_grad=requires_grad)) 42 | self.eps = eps 43 | 44 | def forward(self, x): 45 | return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) 46 | 47 | 48 | class SimpleGate(nn.Module): 49 | def forward(self, x): 50 | x1, x2 = x.chunk(2, dim=1) 51 | return x1 * x2 52 | 53 | class LKA(nn.Module): 54 | def __init__(self, inp_dim, out_dim): 55 | super().__init__() 56 | self.conv0 = nn.Conv2d(inp_dim, inp_dim, 5, padding=2, groups=inp_dim) 57 | self.conv_spatial = nn.Conv2d(inp_dim, inp_dim, 7, stride=1, padding=9, groups=inp_dim, dilation=3) 58 | self.conv1 = nn.Conv2d(inp_dim, out_dim, 1) 59 | 60 | def forward(self, x): 61 | attn = self.conv0(x) 62 | attn = self.conv_spatial(attn) 63 | attn = self.conv1(attn) 64 | 65 | return attn 66 | 67 | 68 | class IMAConv(nn.Module): 69 | ''' Mutual Affine Convolution (MAConv) layer ''' 70 | def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=True, split=4, n_curve=3): 71 | super(IMAConv, self).__init__() 72 | assert split >= 2, 'Num of splits should be larger than one' 73 | 74 | self.num_split = split 75 | splits = [1 / split] * split 76 | self.in_split, self.in_split_rest, self.out_split = [], [], [] 77 | self.n_curve = n_curve 78 | self.relu = nn.ReLU(inplace=False) 79 | 80 | for i in range(self.num_split): 81 | in_split = round(in_channel * splits[i]) if i < self.num_split - 1 else in_channel - sum(self.in_split) 82 | in_split_rest = in_channel - in_split 83 | out_split = round(out_channel * splits[i]) if i < self.num_split - 1 else in_channel - sum(self.out_split) 84 | 85 | self.in_split.append(in_split) 86 | self.in_split_rest.append(in_split_rest) 87 | self.out_split.append(out_split) 88 | 89 | setattr(self, 'predictA{}'.format(i), nn.Sequential(*[ 90 | nn.Conv2d(in_split_rest, in_split, 5, stride=1, padding=2),nn.ReLU(inplace=True), 91 | nn.Conv2d(in_split, in_split, 3, stride=1, padding=1),nn.ReLU(inplace=True), 92 | nn.Conv2d(in_split, n_curve, 1, stride=1, padding=0), 93 | nn.Sigmoid() 94 | ])) 95 | setattr(self, 'conv{}'.format(i), nn.Conv2d(in_channels=in_split, out_channels=out_split, 96 | kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)) 97 | 98 | def forward(self, input): 99 | input = torch.split(input, self.in_split, dim=1) 100 | output = [] 101 | 102 | for i in range(self.num_split): 103 | a = getattr(self, 'predictA{}'.format(i))(torch.cat(input[:i] + input[i + 1:], 1)) 104 | x = self.relu(input[i]) - self.relu(input[i]-1) 105 | for j in range(self.n_curve): 106 | x = x + a[:,j:j+1]*x*(1-x) 107 | output.append(getattr(self, 'conv{}'.format(i))(x)) 108 | 109 | return torch.cat(output, 1) 110 | 111 | class AIEM(nn.Module): 112 | def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0., split_group=4, n_curve=3): 113 | super().__init__() 114 | dw_channel = c * DW_Expand 115 | self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, 116 | bias=True) 117 | self.conv2 = nn.Conv2d(in_channels=dw_channel, 
out_channels=dw_channel, kernel_size=3, padding=1, stride=1, 118 | groups=dw_channel, 119 | bias=True) 120 | self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, 121 | groups=1, bias=True) 122 | 123 | # Simplified Channel Attention 124 | self.sca1 = LKA(dw_channel, dw_channel//2) 125 | self.sca2 = nn.Sequential( 126 | nn.AdaptiveAvgPool2d(1), 127 | nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, 128 | groups=1, bias=True), 129 | ) 130 | 131 | # SimpleGate 132 | self.sg = SimpleGate() 133 | 134 | ffn_channel = FFN_Expand * c 135 | self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) 136 | self.IMAC = IMAConv(in_channel=ffn_channel // 2, out_channel=ffn_channel // 2, kernel_size=3, stride=1, padding=1, bias=True, split=split_group, n_curve=n_curve) 137 | self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) 138 | 139 | self.norm1 = LayerNorm2d(c) 140 | self.norm2 = LayerNorm2d(c) 141 | 142 | self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() 143 | self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() 144 | 145 | self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) 146 | self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) 147 | 148 | def forward(self, inp): 149 | x = inp 150 | 151 | x = self.norm1(x) 152 | 153 | x = self.conv1(x) 154 | x = self.conv2(x) 155 | # x = self.sg(x) 156 | x = (self.sca2(x) * self.sca1(x)) * self.sg(x) 157 | x = self.conv3(x) 158 | 159 | x = self.dropout1(x) 160 | 161 | y = inp + x * self.beta 162 | x = self.conv4(self.norm2(y)) 163 | x = self.sg(x) 164 | x = self.IMAC(x) 165 | x = self.conv5(x) 166 | 167 | 168 | x = self.dropout2(x) 169 | 170 | return y + x * self.gamma 171 | 172 | class EnhanceLayers(nn.Module): 173 | def __init__(self, embed_dim=256, n_layers=4, split_group=4, n_curve=3): 174 | super().__init__() 175 | self.blks = nn.ModuleList() 176 | for i in range(n_layers): 177 | layer = AIEM(embed_dim, DW_Expand=2, FFN_Expand=2, drop_out_rate=0., split_group=split_group, n_curve=n_curve) 178 | self.blks.append(layer) 179 | 180 | def forward(self, x): 181 | for m in self.blks: 182 | x = m(x) 183 | return x 184 | -------------------------------------------------------------------------------- /basicsr/archs/DRSW_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from pdb import set_trace as stx 5 | import numbers 6 | 7 | from einops import rearrange 8 | 9 | def to_3d(x): 10 | return rearrange(x, 'b c h w -> b (h w) c') 11 | 12 | def to_4d(x, h, w): 13 | return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 14 | 15 | class BiasFree_LayerNorm(nn.Module): 16 | def __init__(self, normalized_shape): 17 | super(BiasFree_LayerNorm, self).__init__() 18 | if isinstance(normalized_shape, numbers.Integral): 19 | normalized_shape = (normalized_shape,) 20 | normalized_shape = torch.Size(normalized_shape) 21 | 22 | assert len(normalized_shape) == 1 23 | 24 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 25 | self.normalized_shape = normalized_shape 26 | 27 | def forward(self, x): 28 | sigma = x.var(-1, keepdim=True, unbiased=False) 29 | return x / torch.sqrt(sigma + 1e-5) * self.weight 30 | 31 | class 
WithBias_LayerNorm(nn.Module): 32 | def __init__(self, normalized_shape): 33 | super(WithBias_LayerNorm, self).__init__() 34 | if isinstance(normalized_shape, numbers.Integral): 35 | normalized_shape = (normalized_shape,) 36 | normalized_shape = torch.Size(normalized_shape) 37 | 38 | assert len(normalized_shape) == 1 39 | 40 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 41 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 42 | self.normalized_shape = normalized_shape 43 | 44 | def forward(self, x): 45 | mu = x.mean(-1, keepdim=True) 46 | sigma = x.var(-1, keepdim=True, unbiased=False) 47 | return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias 48 | 49 | class LayerNorm(nn.Module): 50 | def __init__(self, dim, LayerNorm_type): 51 | super(LayerNorm, self).__init__() 52 | if LayerNorm_type == 'BiasFree': 53 | self.body = BiasFree_LayerNorm(dim) 54 | else: 55 | self.body = WithBias_LayerNorm(dim) 56 | 57 | def forward(self, x): 58 | h, w = x.shape[-2:] 59 | return to_4d(self.body(to_3d(x)), h, w) 60 | 61 | ## Mixed-Scale Feed-forward Network (MSFN) 62 | class FeedForward(nn.Module): 63 | def __init__(self, dim, ffn_expansion_factor, bias): 64 | super(FeedForward, self).__init__() 65 | 66 | hidden_features = int(dim * ffn_expansion_factor) 67 | 68 | self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) 69 | 70 | self.dwconv3x3 = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, groups=hidden_features * 2, bias=bias) 71 | self.dwconv5x5 = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=5, stride=1, padding=2, groups=hidden_features * 2, bias=bias) 72 | self.relu3 = nn.ReLU() 73 | self.relu5 = nn.ReLU() 74 | 75 | self.dwconv3x3_1 = nn.Conv2d(hidden_features * 2, hidden_features, kernel_size=3, stride=1, padding=1, groups=hidden_features , bias=bias) 76 | self.dwconv5x5_1 = nn.Conv2d(hidden_features * 2, hidden_features, kernel_size=5, stride=1, padding=2, groups=hidden_features , bias=bias) 77 | 78 | self.relu3_1 = nn.ReLU() 79 | self.relu5_1 = nn.ReLU() 80 | 81 | self.project_out = nn.Conv2d(hidden_features * 2, dim, kernel_size=1, bias=bias) 82 | 83 | def forward(self, x): 84 | x = self.project_in(x) 85 | x1_3, x2_3 = self.relu3(self.dwconv3x3(x)).chunk(2, dim=1) 86 | x1_5, x2_5 = self.relu5(self.dwconv5x5(x)).chunk(2, dim=1) 87 | 88 | x1 = torch.cat([x1_3, x1_5], dim=1) 89 | x2 = torch.cat([x2_3, x2_5], dim=1) 90 | 91 | x1 = self.relu3_1(self.dwconv3x3_1(x1)) 92 | x2 = self.relu5_1(self.dwconv5x5_1(x2)) 93 | 94 | x = torch.cat([x1, x2], dim=1) 95 | 96 | x = self.project_out(x) 97 | 98 | return x 99 | 100 | ## Top-K Sparse Attention (TKSA) 101 | class Attention(nn.Module): 102 | def __init__(self, dim, num_heads, bias): 103 | super(Attention, self).__init__() 104 | self.num_heads = num_heads 105 | 106 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 107 | 108 | self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) 109 | self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias) 110 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) 111 | self.attn_drop = nn.Dropout(0.) 
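        # attn1-attn4 below are learnable scalar weights (initialized to 0.2) that
        # blend the outputs of the four top-k sparse-attention branches in forward().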
112 | 113 | self.attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True) 114 | self.attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True) 115 | self.attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True) 116 | self.attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True) 117 | 118 | def forward(self, x): 119 | b, c, h, w = x.shape 120 | 121 | qkv = self.qkv_dwconv(self.qkv(x)) 122 | q, k, v = qkv.chunk(3, dim=1) 123 | 124 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 125 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 126 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 127 | 128 | q = torch.nn.functional.normalize(q, dim=-1) 129 | k = torch.nn.functional.normalize(k, dim=-1) 130 | 131 | _, _, C, _ = q.shape 132 | 133 | mask1 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False) 134 | mask2 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False) 135 | mask3 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False) 136 | mask4 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False) 137 | 138 | attn = (q @ k.transpose(-2, -1)) * self.temperature 139 | 140 | index = torch.topk(attn, k=int(C/2), dim=-1, largest=True)[1] 141 | mask1.scatter_(-1, index, 1.) 142 | attn1 = torch.where(mask1 > 0, attn, torch.full_like(attn, float('-inf'))) 143 | 144 | index = torch.topk(attn, k=int(C*2/3), dim=-1, largest=True)[1] 145 | mask2.scatter_(-1, index, 1.) 146 | attn2 = torch.where(mask2 > 0, attn, torch.full_like(attn, float('-inf'))) 147 | 148 | index = torch.topk(attn, k=int(C*3/4), dim=-1, largest=True)[1] 149 | mask3.scatter_(-1, index, 1.) 150 | attn3 = torch.where(mask3 > 0, attn, torch.full_like(attn, float('-inf'))) 151 | 152 | index = torch.topk(attn, k=int(C*4/5), dim=-1, largest=True)[1] 153 | mask4.scatter_(-1, index, 1.) 
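        # Each mask keeps only the top-k entries in each row of the (C x C) attention
        # map (k = C/2, 2C/3, 3C/4 and 4C/5 respectively); the remaining entries are
        # set to -inf so that the following softmax assigns them zero weight.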
154 | attn4 = torch.where(mask4 > 0, attn, torch.full_like(attn, float('-inf'))) 155 | 156 | attn1 = attn1.softmax(dim=-1) 157 | attn2 = attn2.softmax(dim=-1) 158 | attn3 = attn3.softmax(dim=-1) 159 | attn4 = attn4.softmax(dim=-1) 160 | 161 | out1 = (attn1 @ v) 162 | out2 = (attn2 @ v) 163 | out3 = (attn3 @ v) 164 | out4 = (attn4 @ v) 165 | 166 | out = out1 * self.attn1 + out2 * self.attn2 + out3 * self.attn3 + out4 * self.attn4 167 | 168 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 169 | 170 | out = self.project_out(out) 171 | return out 172 | 173 | ## Sparse Transformer Block (STB) 174 | class TransformerBlock(nn.Module): 175 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): 176 | super(TransformerBlock, self).__init__() 177 | 178 | self.norm1 = LayerNorm(dim, LayerNorm_type) 179 | self.attn = Attention(dim, num_heads, bias) 180 | self.norm2 = LayerNorm(dim, LayerNorm_type) 181 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias) 182 | 183 | def forward(self, x): 184 | x = x + self.attn(self.norm1(x)) 185 | x = x + self.ffn(self.norm2(x)) 186 | 187 | return x 188 | -------------------------------------------------------------------------------- /basicsr/ops/dcn/src/deform_conv_ext.cpp: -------------------------------------------------------------------------------- 1 | // modify from 2 | // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | #define WITH_CUDA // always use cuda 11 | #ifdef WITH_CUDA 12 | int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, 13 | at::Tensor offset, at::Tensor output, 14 | at::Tensor columns, at::Tensor ones, int kW, 15 | int kH, int dW, int dH, int padW, int padH, 16 | int dilationW, int dilationH, int group, 17 | int deformable_group, int im2col_step); 18 | 19 | int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, 20 | at::Tensor gradOutput, at::Tensor gradInput, 21 | at::Tensor gradOffset, at::Tensor weight, 22 | at::Tensor columns, int kW, int kH, int dW, 23 | int dH, int padW, int padH, int dilationW, 24 | int dilationH, int group, 25 | int deformable_group, int im2col_step); 26 | 27 | int deform_conv_backward_parameters_cuda( 28 | at::Tensor input, at::Tensor offset, at::Tensor gradOutput, 29 | at::Tensor gradWeight, // at::Tensor gradBias, 30 | at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, 31 | int padW, int padH, int dilationW, int dilationH, int group, 32 | int deformable_group, float scale, int im2col_step); 33 | 34 | void modulated_deform_conv_cuda_forward( 35 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 36 | at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, 37 | int kernel_h, int kernel_w, const int stride_h, const int stride_w, 38 | const int pad_h, const int pad_w, const int dilation_h, 39 | const int dilation_w, const int group, const int deformable_group, 40 | const bool with_bias); 41 | 42 | void modulated_deform_conv_cuda_backward( 43 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 44 | at::Tensor offset, at::Tensor mask, at::Tensor columns, 45 | at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, 46 | at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, 47 | int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, 48 | int pad_w, int dilation_h, int 
dilation_w, int group, int deformable_group, 49 | const bool with_bias); 50 | #endif 51 | 52 | int deform_conv_forward(at::Tensor input, at::Tensor weight, 53 | at::Tensor offset, at::Tensor output, 54 | at::Tensor columns, at::Tensor ones, int kW, 55 | int kH, int dW, int dH, int padW, int padH, 56 | int dilationW, int dilationH, int group, 57 | int deformable_group, int im2col_step) { 58 | if (input.device().is_cuda()) { 59 | #ifdef WITH_CUDA 60 | return deform_conv_forward_cuda(input, weight, offset, output, columns, 61 | ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group, 62 | deformable_group, im2col_step); 63 | #else 64 | AT_ERROR("deform conv is not compiled with GPU support"); 65 | #endif 66 | } 67 | AT_ERROR("deform conv is not implemented on CPU"); 68 | } 69 | 70 | int deform_conv_backward_input(at::Tensor input, at::Tensor offset, 71 | at::Tensor gradOutput, at::Tensor gradInput, 72 | at::Tensor gradOffset, at::Tensor weight, 73 | at::Tensor columns, int kW, int kH, int dW, 74 | int dH, int padW, int padH, int dilationW, 75 | int dilationH, int group, 76 | int deformable_group, int im2col_step) { 77 | if (input.device().is_cuda()) { 78 | #ifdef WITH_CUDA 79 | return deform_conv_backward_input_cuda(input, offset, gradOutput, 80 | gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH, 81 | dilationW, dilationH, group, deformable_group, im2col_step); 82 | #else 83 | AT_ERROR("deform conv is not compiled with GPU support"); 84 | #endif 85 | } 86 | AT_ERROR("deform conv is not implemented on CPU"); 87 | } 88 | 89 | int deform_conv_backward_parameters( 90 | at::Tensor input, at::Tensor offset, at::Tensor gradOutput, 91 | at::Tensor gradWeight, // at::Tensor gradBias, 92 | at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, 93 | int padW, int padH, int dilationW, int dilationH, int group, 94 | int deformable_group, float scale, int im2col_step) { 95 | if (input.device().is_cuda()) { 96 | #ifdef WITH_CUDA 97 | return deform_conv_backward_parameters_cuda(input, offset, gradOutput, 98 | gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW, 99 | dilationH, group, deformable_group, scale, im2col_step); 100 | #else 101 | AT_ERROR("deform conv is not compiled with GPU support"); 102 | #endif 103 | } 104 | AT_ERROR("deform conv is not implemented on CPU"); 105 | } 106 | 107 | void modulated_deform_conv_forward( 108 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 109 | at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, 110 | int kernel_h, int kernel_w, const int stride_h, const int stride_w, 111 | const int pad_h, const int pad_w, const int dilation_h, 112 | const int dilation_w, const int group, const int deformable_group, 113 | const bool with_bias) { 114 | if (input.device().is_cuda()) { 115 | #ifdef WITH_CUDA 116 | return modulated_deform_conv_cuda_forward(input, weight, bias, ones, 117 | offset, mask, output, columns, kernel_h, kernel_w, stride_h, 118 | stride_w, pad_h, pad_w, dilation_h, dilation_w, group, 119 | deformable_group, with_bias); 120 | #else 121 | AT_ERROR("modulated deform conv is not compiled with GPU support"); 122 | #endif 123 | } 124 | AT_ERROR("modulated deform conv is not implemented on CPU"); 125 | } 126 | 127 | void modulated_deform_conv_backward( 128 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 129 | at::Tensor offset, at::Tensor mask, at::Tensor columns, 130 | at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, 131 | 
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, 132 | int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, 133 | int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, 134 | const bool with_bias) { 135 | if (input.device().is_cuda()) { 136 | #ifdef WITH_CUDA 137 | return modulated_deform_conv_cuda_backward(input, weight, bias, ones, 138 | offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset, 139 | grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, 140 | pad_h, pad_w, dilation_h, dilation_w, group, deformable_group, 141 | with_bias); 142 | #else 143 | AT_ERROR("modulated deform conv is not compiled with GPU support"); 144 | #endif 145 | } 146 | AT_ERROR("modulated deform conv is not implemented on CPU"); 147 | } 148 | 149 | 150 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 151 | m.def("deform_conv_forward", &deform_conv_forward, 152 | "deform forward"); 153 | m.def("deform_conv_backward_input", &deform_conv_backward_input, 154 | "deform_conv_backward_input"); 155 | m.def("deform_conv_backward_parameters", 156 | &deform_conv_backward_parameters, 157 | "deform_conv_backward_parameters"); 158 | m.def("modulated_deform_conv_forward", 159 | &modulated_deform_conv_forward, 160 | "modulated deform conv forward"); 161 | m.def("modulated_deform_conv_backward", 162 | &modulated_deform_conv_backward, 163 | "modulated deform conv backward"); 164 | } 165 | -------------------------------------------------------------------------------- /basicsr/data/transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | 4 | 5 | def mod_crop(img, scale): 6 | """Mod crop images, used during testing. 7 | 8 | Args: 9 | img (ndarray): Input image. 10 | scale (int): Scale factor. 11 | 12 | Returns: 13 | ndarray: Result image. 14 | """ 15 | img = img.copy() 16 | if img.ndim in (2, 3): 17 | h, w = img.shape[0], img.shape[1] 18 | h_remainder, w_remainder = h % scale, w % scale 19 | img = img[:h - h_remainder, :w - w_remainder, ...] 20 | else: 21 | raise ValueError(f'Wrong img ndim: {img.ndim}.') 22 | return img 23 | 24 | 25 | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path): 26 | """Paired random crop. 27 | 28 | It crops lists of lq and gt images with corresponding locations. 29 | 30 | Args: 31 | img_gts (list[ndarray] | ndarray): GT images. Note that all images 32 | should have the same shape. If the input is an ndarray, it will 33 | be transformed to a list containing itself. 34 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images 35 | should have the same shape. If the input is an ndarray, it will 36 | be transformed to a list containing itself. 37 | gt_patch_size (int): GT patch size. 38 | scale (int): Scale factor. 39 | gt_path (str): Path to ground-truth. 40 | 41 | Returns: 42 | list[ndarray] | ndarray: GT images and LQ images. If returned results 43 | only have one element, just return ndarray. 44 | """ 45 | 46 | if not isinstance(img_gts, list): 47 | img_gts = [img_gts] 48 | if not isinstance(img_lqs, list): 49 | img_lqs = [img_lqs] 50 | 51 | h_lq, w_lq, _ = img_lqs[0].shape 52 | h_gt, w_gt, _ = img_gts[0].shape 53 | lq_patch_size = gt_patch_size // scale 54 | 55 | if h_gt != h_lq * scale or w_gt != w_lq * scale: 56 | raise ValueError(f'Scale mismatches. 
GT ({h_gt}, {w_gt}) is not {scale}x ', 57 | f'multiplication of LQ ({h_lq}, {w_lq}).') 58 | if h_lq < lq_patch_size or w_lq < lq_patch_size: 59 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' 60 | f'({lq_patch_size}, {lq_patch_size}). ' 61 | f'Please remove {gt_path}.') 62 | 63 | # randomly choose top and left coordinates for lq patch 64 | top = random.randint(0, h_lq - lq_patch_size) 65 | left = random.randint(0, w_lq - lq_patch_size) 66 | 67 | # crop lq patch 68 | img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] 69 | 70 | # crop corresponding gt patch 71 | top_gt, left_gt = int(top * scale), int(left * scale) 72 | img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] 73 | if len(img_gts) == 1: 74 | img_gts = img_gts[0] 75 | if len(img_lqs) == 1: 76 | img_lqs = img_lqs[0] 77 | return img_gts, img_lqs 78 | 79 | 80 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): 81 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). 82 | 83 | We use vertical flip and transpose for rotation implementation. 84 | All the images in the list use the same augmentation. 85 | 86 | Args: 87 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input 88 | is an ndarray, it will be transformed to a list. 89 | hflip (bool): Horizontal flip. Default: True. 90 | rotation (bool): Ratotation. Default: True. 91 | flows (list[ndarray]: Flows to be augmented. If the input is an 92 | ndarray, it will be transformed to a list. 93 | Dimension is (h, w, 2). Default: None. 94 | return_status (bool): Return the status of flip and rotation. 95 | Default: False. 96 | 97 | Returns: 98 | list[ndarray] | ndarray: Augmented images and flows. If returned 99 | results only have one element, just return ndarray. 100 | 101 | """ 102 | hflip = hflip and random.random() < 0.5 103 | vflip = rotation and random.random() < 0.5 104 | rot90 = rotation and random.random() < 0.5 105 | 106 | def _augment(img): 107 | if hflip: # horizontal 108 | cv2.flip(img, 1, img) 109 | if vflip: # vertical 110 | cv2.flip(img, 0, img) 111 | if rot90: 112 | img = img.transpose(1, 0, 2) 113 | return img 114 | 115 | def _augment_flow(flow): 116 | if hflip: # horizontal 117 | cv2.flip(flow, 1, flow) 118 | flow[:, :, 0] *= -1 119 | if vflip: # vertical 120 | cv2.flip(flow, 0, flow) 121 | flow[:, :, 1] *= -1 122 | if rot90: 123 | flow = flow.transpose(1, 0, 2) 124 | flow = flow[:, :, [1, 0]] 125 | return flow 126 | 127 | if not isinstance(imgs, list): 128 | imgs = [imgs] 129 | imgs = [_augment(img) for img in imgs] 130 | if len(imgs) == 1: 131 | imgs = imgs[0] 132 | 133 | if flows is not None: 134 | if not isinstance(flows, list): 135 | flows = [flows] 136 | flows = [_augment_flow(flow) for flow in flows] 137 | if len(flows) == 1: 138 | flows = flows[0] 139 | return imgs, flows 140 | else: 141 | if return_status: 142 | return imgs, (hflip, vflip, rot90) 143 | else: 144 | return imgs 145 | 146 | 147 | 148 | 149 | def augment2(imgs1, imgs2, hflip=True, rotation=True, flows=None, return_status=False): 150 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). 151 | 152 | We use vertical flip and transpose for rotation implementation. 153 | All the images in the list use the same augmentation. 154 | 155 | Args: 156 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input 157 | is an ndarray, it will be transformed to a list. 158 | hflip (bool): Horizontal flip. Default: True. 
159 | rotation (bool): Ratotation. Default: True. 160 | flows (list[ndarray]: Flows to be augmented. If the input is an 161 | ndarray, it will be transformed to a list. 162 | Dimension is (h, w, 2). Default: None. 163 | return_status (bool): Return the status of flip and rotation. 164 | Default: False. 165 | 166 | Returns: 167 | list[ndarray] | ndarray: Augmented images and flows. If returned 168 | results only have one element, just return ndarray. 169 | 170 | """ 171 | hflip = hflip and random.random() < 0.5 172 | vflip = rotation and random.random() < 0.5 173 | rot90 = rotation and random.random() < 0.5 174 | 175 | def _augment(img): 176 | if hflip: # horizontal 177 | cv2.flip(img, 1, img) 178 | if vflip: # vertical 179 | cv2.flip(img, 0, img) 180 | if rot90: 181 | img = img.transpose(1, 0, 2) 182 | return img 183 | 184 | def _augment_flow(flow): 185 | if hflip: # horizontal 186 | cv2.flip(flow, 1, flow) 187 | flow[:, :, 0] *= -1 188 | if vflip: # vertical 189 | cv2.flip(flow, 0, flow) 190 | flow[:, :, 1] *= -1 191 | if rot90: 192 | flow = flow.transpose(1, 0, 2) 193 | flow = flow[:, :, [1, 0]] 194 | return flow 195 | 196 | if not isinstance(imgs1, list): 197 | imgs1 = [imgs1] 198 | imgs1 = [_augment(img) for img in imgs1] 199 | if len(imgs1) == 1: 200 | imgs1 = imgs1[0] 201 | 202 | if not isinstance(imgs2, list): 203 | imgs2 = [imgs2] 204 | imgs2 = [_augment(img) for img in imgs2] 205 | if len(imgs2) == 1: 206 | imgs2 = imgs2[0] 207 | 208 | if return_status: 209 | return imgs1, imgs2, (hflip, vflip, rot90) 210 | else: 211 | return imgs1, imgs2 212 | 213 | 214 | 215 | def img_rotate(img, angle, center=None, scale=1.0): 216 | """Rotate image. 217 | 218 | Args: 219 | img (ndarray): Image to be rotated. 220 | angle (float): Rotation angle in degrees. Positive values mean 221 | counter-clockwise rotation. 222 | center (tuple[int]): Rotation center. If the center is None, 223 | initialize it as the center of the image. Default: None. 224 | scale (float): Isotropic scale factor. Default: 1.0. 
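    Example (a minimal sketch; `img` is any H x W x C image array):
        rotated = img_rotate(img, angle=90)  # 90 deg counter-clockwise about the image center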
225 | """ 226 | (h, w) = img.shape[:2] 227 | 228 | if center is None: 229 | center = (w // 2, h // 2) 230 | 231 | matrix = cv2.getRotationMatrix2D(center, angle, scale) 232 | rotated_img = cv2.warpAffine(img, matrix, (w, h)) 233 | return rotated_img 234 | -------------------------------------------------------------------------------- /basicsr/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import math 5 | from datetime import datetime 6 | import random 7 | import logging 8 | from collections import OrderedDict 9 | import numpy as np 10 | import cv2 11 | import torch 12 | from torchvision.utils import make_grid 13 | from shutil import get_terminal_size 14 | 15 | import yaml 16 | try: 17 | from yaml import CLoader as Loader, CDumper as Dumper 18 | except ImportError: 19 | from yaml import Loader, Dumper 20 | 21 | 22 | def OrderedYaml(): 23 | '''yaml orderedDict support''' 24 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 25 | 26 | def dict_representer(dumper, data): 27 | return dumper.represent_dict(data.items()) 28 | 29 | def dict_constructor(loader, node): 30 | return OrderedDict(loader.construct_pairs(node)) 31 | 32 | Dumper.add_representer(OrderedDict, dict_representer) 33 | Loader.add_constructor(_mapping_tag, dict_constructor) 34 | return Loader, Dumper 35 | 36 | 37 | #################### 38 | # miscellaneous 39 | #################### 40 | 41 | 42 | def get_timestamp(): 43 | return datetime.now().strftime('%y%m%d-%H%M%S') 44 | 45 | 46 | def mkdir(path): 47 | if not os.path.exists(path): 48 | os.makedirs(path) 49 | 50 | 51 | def mkdirs(paths): 52 | if isinstance(paths, str): 53 | mkdir(paths) 54 | else: 55 | for path in paths: 56 | mkdir(path) 57 | 58 | 59 | def mkdir_and_rename(path): 60 | if os.path.exists(path): 61 | new_name = path + '_archived_' + get_timestamp() 62 | print('Path already exists. Rename it to [{:s}]'.format(new_name)) 63 | logger = logging.getLogger('base') 64 | logger.info('Path already exists. 
Rename it to [{:s}]'.format(new_name)) 65 | os.rename(path, new_name) 66 | os.makedirs(path) 67 | 68 | 69 | def set_random_seed(seed): 70 | random.seed(seed) 71 | np.random.seed(seed) 72 | torch.manual_seed(seed) 73 | torch.cuda.manual_seed_all(seed) 74 | 75 | 76 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): 77 | '''set up logger''' 78 | lg = logging.getLogger(logger_name) 79 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', 80 | datefmt='%y-%m-%d %H:%M:%S') 81 | lg.setLevel(level) 82 | if tofile: 83 | log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp())) 84 | fh = logging.FileHandler(log_file, mode='w') 85 | fh.setFormatter(formatter) 86 | lg.addHandler(fh) 87 | if screen: 88 | sh = logging.StreamHandler() 89 | sh.setFormatter(formatter) 90 | lg.addHandler(sh) 91 | 92 | 93 | #################### 94 | # image convert 95 | #################### 96 | 97 | 98 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): 99 | ''' 100 | Converts a torch Tensor into an image Numpy array 101 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 102 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 103 | ''' 104 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp 105 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] 106 | n_dim = tensor.dim() 107 | if n_dim == 4: 108 | n_img = len(tensor) 109 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() 110 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 111 | elif n_dim == 3: 112 | img_np = tensor.numpy() 113 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 114 | elif n_dim == 2: 115 | img_np = tensor.numpy() 116 | else: 117 | raise TypeError( 118 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 119 | if out_type == np.uint8: 120 | img_np = (img_np * 255.0).round() 121 | img_np = np.clip(img_np, 0, 255) 122 | # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. 
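    # so the explicit round() and the clip to [0, 255] above run before the cast below.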
123 | return img_np.astype(out_type) 124 | 125 | 126 | def save_img(img, img_path, mode='RGB'): 127 | cv2.imwrite(img_path, img) 128 | 129 | 130 | #################### 131 | # metric 132 | #################### 133 | 134 | 135 | def calculate_psnr(img1, img2): 136 | # img1 and img2 have range [0, 255] 137 | img1 = img1.astype(np.float64) 138 | img2 = img2.astype(np.float64) 139 | mse = np.mean((img1 - img2)**2) 140 | if mse == 0: 141 | return float('inf') 142 | return 20 * math.log10(255.0 / math.sqrt(mse)) 143 | 144 | 145 | def ssim(img1, img2): 146 | C1 = (0.01 * 255)**2 147 | C2 = (0.03 * 255)**2 148 | 149 | img1 = img1.astype(np.float64) 150 | img2 = img2.astype(np.float64) 151 | kernel = cv2.getGaussianKernel(11, 1.5) 152 | window = np.outer(kernel, kernel.transpose()) 153 | 154 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 155 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 156 | mu1_sq = mu1**2 157 | mu2_sq = mu2**2 158 | mu1_mu2 = mu1 * mu2 159 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 160 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 161 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 162 | 163 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 164 | (sigma1_sq + sigma2_sq + C2)) 165 | return ssim_map.mean() 166 | 167 | 168 | def calculate_ssim(img1, img2): 169 | '''calculate SSIM 170 | the same outputs as MATLAB's 171 | img1, img2: [0, 255] 172 | ''' 173 | if not img1.shape == img2.shape: 174 | raise ValueError('Input images must have the same dimensions.') 175 | if img1.ndim == 2: 176 | return ssim(img1, img2) 177 | elif img1.ndim == 3: 178 | if img1.shape[2] == 3: 179 | ssims = [] 180 | for i in range(3): 181 | ssims.append(ssim(img1, img2)) 182 | return np.array(ssims).mean() 183 | elif img1.shape[2] == 1: 184 | return ssim(np.squeeze(img1), np.squeeze(img2)) 185 | else: 186 | raise ValueError('Wrong input image dimensions.') 187 | 188 | 189 | class ProgressBar(object): 190 | '''A progress bar which can print the progress 191 | modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py 192 | ''' 193 | 194 | def __init__(self, task_num=0, bar_width=50, start=True): 195 | self.task_num = task_num 196 | max_bar_width = self._get_max_bar_width() 197 | self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width) 198 | self.completed = 0 199 | if start: 200 | self.start() 201 | 202 | def _get_max_bar_width(self): 203 | terminal_width, _ = get_terminal_size() 204 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) 205 | if max_bar_width < 10: 206 | print('terminal width is too small ({}), please consider widen the terminal for better ' 207 | 'progressbar visualization'.format(terminal_width)) 208 | max_bar_width = 10 209 | return max_bar_width 210 | 211 | def start(self): 212 | if self.task_num > 0: 213 | sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format( 214 | ' ' * self.bar_width, self.task_num, 'Start...')) 215 | else: 216 | sys.stdout.write('completed: 0, elapsed: 0s') 217 | sys.stdout.flush() 218 | self.start_time = time.time() 219 | 220 | def update(self, msg='In progress...'): 221 | self.completed += 1 222 | elapsed = time.time() - self.start_time 223 | fps = self.completed / elapsed 224 | if self.task_num > 0: 225 | percentage = self.completed / float(self.task_num) 226 | eta = int(elapsed * (1 - percentage) / percentage + 0.5) 227 | mark_width = int(self.bar_width * percentage) 228 | 
bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width) 229 | sys.stdout.write('\033[2F') # cursor up 2 lines 230 | sys.stdout.write('\033[J') # clean the output (remove extra chars since last display) 231 | sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format( 232 | bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg)) 233 | else: 234 | sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format( 235 | self.completed, int(elapsed + 0.5), fps)) 236 | sys.stdout.flush() 237 | -------------------------------------------------------------------------------- /basicsr/pretrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import logging 4 | import math 5 | import random 6 | import time 7 | import torch 8 | from os import path as osp 9 | import numpy as np 10 | from data import build_dataloader, build_dataset 11 | from data.data_sampler import EnlargedSampler 12 | from data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher 13 | from models import build_model 14 | from utils import (MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger, 15 | init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed) 16 | from utils.dist_util import get_dist_info, init_dist 17 | from utils.options import dict2str, parse 18 | 19 | import warnings 20 | warnings.filterwarnings("ignore", category=UserWarning) 21 | 22 | def parse_options(root_path, is_train=True): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') 25 | parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') 26 | parser.add_argument('--local-rank', type=int, default=0) 27 | args = parser.parse_args() 28 | opt = parse(args.opt, root_path, is_train=is_train) 29 | 30 | # distributed settings 31 | if args.launcher == 'none': 32 | opt['dist'] = False 33 | print('Disable distributed.', flush=True) 34 | else: 35 | opt['dist'] = True 36 | if args.launcher == 'slurm' and 'dist_params' in opt: 37 | init_dist(args.launcher, **opt['dist_params']) 38 | else: 39 | init_dist(args.launcher) 40 | 41 | opt['rank'], opt['world_size'] = get_dist_info() 42 | 43 | # random seed 44 | seed = opt.get('manual_seed') 45 | if seed is None: 46 | seed = 3407 47 | opt['manual_seed'] = seed 48 | set_random_seed(seed + opt['rank']) 49 | # seed 50 | torch.manual_seed(seed) 51 | random.seed(seed) 52 | np.random.seed(seed) 53 | torch.backends.cudnn.deterministic = True 54 | torch.backends.cudnn.benchmark = False 55 | 56 | return opt 57 | 58 | 59 | def init_loggers(opt): 60 | log_file = osp.join(opt['path']['log'], f"pretrain_{opt['name']}.log") 61 | logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 62 | logger.info(get_env_info()) 63 | logger.info(dict2str(opt)) 64 | 65 | # initialize wandb logger before tensorboard logger to allow proper sync: 66 | if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None): 67 | assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb') 68 | init_wandb_logger(opt) 69 | tb_logger = None 70 | if opt['logger'].get('use_tb_logger'): 71 | tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name'])) 72 | return logger, tb_logger 73 | 74 | 75 | def create_pretrain_dataloader(opt, logger): 76 | 
pretrain_loader = None 77 | for phase, dataset_opt in opt['datasets'].items(): 78 | dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) 79 | train_set = build_dataset(dataset_opt) 80 | pretrain_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio) 81 | pretrain_loader = build_dataloader( 82 | train_set, 83 | dataset_opt, 84 | num_gpu=opt['num_gpu'], 85 | dist=opt['dist'], 86 | sampler=pretrain_sampler, 87 | seed=opt['manual_seed']) 88 | 89 | num_iter_per_epoch = math.ceil( 90 | len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size'])) 91 | total_iters = int(opt['train']['total_iter']) 92 | total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) 93 | logger.info('Training statistics:' 94 | f'\n\tNumber of train images: {len(train_set)}' 95 | f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}' 96 | f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' 97 | f'\n\tWorld size (gpu number): {opt["world_size"]}' 98 | f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' 99 | f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') 100 | 101 | return pretrain_loader, pretrain_sampler, total_epochs, total_iters 102 | 103 | 104 | def train_pipeline(root_path): 105 | opt = parse_options(root_path, is_train=True) 106 | 107 | torch.backends.cudnn.benchmark = True 108 | 109 | if opt['path'].get('resume_state'): 110 | device_id = torch.cuda.current_device() 111 | resume_state = torch.load( 112 | opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) 113 | else: 114 | resume_state = None 115 | 116 | # mkdir for experiments and logger 117 | if resume_state is None: 118 | make_exp_dirs(opt) 119 | if opt['logger'].get('use_tb_logger') and opt['rank'] == 0: 120 | mkdir_and_rename(osp.join('tb_logger', opt['name'])) 121 | 122 | # initialize loggers 123 | logger, tb_logger = init_loggers(opt) 124 | 125 | # create train and validation dataloaders 126 | result = create_pretrain_dataloader(opt, logger) 127 | pretrain_loader, pretrain_sampler, total_epochs, total_iters = result 128 | 129 | # create model 130 | if resume_state: # resume training 131 | check_resume(opt, resume_state['iter']) 132 | model = build_model(opt) 133 | model.resume_training(resume_state) # handle optimizers and schedulers 134 | logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.") 135 | start_epoch = resume_state['epoch'] 136 | current_iter = resume_state['iter'] 137 | else: 138 | model = build_model(opt) 139 | start_epoch = 0 140 | current_iter = 0 141 | 142 | # create message logger (formatted outputs) 143 | msg_logger = MessageLogger(opt, current_iter, tb_logger) 144 | 145 | # dataloader prefetcher 146 | prefetch_mode = opt['datasets']['pretrain'].get('prefetch_mode') 147 | if prefetch_mode is None or prefetch_mode == 'cpu': 148 | prefetcher = CPUPrefetcher(pretrain_loader) 149 | elif prefetch_mode == 'cuda': 150 | prefetcher = CUDAPrefetcher(pretrain_loader, opt) 151 | logger.info(f'Use {prefetch_mode} prefetch dataloader') 152 | if opt['datasets']['pretrain'].get('pin_memory') is not True: 153 | raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') 154 | else: 155 | raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.' 
"Supported ones are: None, 'cuda', 'cpu'.") 156 | 157 | # training 158 | logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter+1}') 159 | data_time, iter_time = time.time(), time.time() 160 | start_time = time.time() 161 | 162 | for epoch in range(start_epoch, total_epochs + 1): 163 | # break 164 | pretrain_sampler.set_epoch(epoch) 165 | prefetcher.reset() 166 | pretrain_data = prefetcher.next() 167 | while pretrain_data is not None: 168 | data_time = time.time() - data_time 169 | current_iter += 1 170 | if current_iter > total_iters: 171 | break 172 | # update learning rate 173 | model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) 174 | # training 175 | model.feed_data(pretrain_data) 176 | model.optimize_parameters(current_iter) 177 | iter_time = time.time() - iter_time 178 | # log 179 | if current_iter % opt['logger']['print_freq'] == 0: 180 | log_vars = {'epoch': epoch, 'iter': current_iter} 181 | log_vars.update({'lrs': model.get_current_learning_rate()}) 182 | log_vars.update({'time': iter_time, 'data_time': data_time}) 183 | log_vars.update(model.get_current_log()) 184 | msg_logger(log_vars) 185 | # save models and training states 186 | if current_iter % opt['logger']['save_checkpoint_freq'] == 0: 187 | logger.info('Saving models and training states.') 188 | model.save(epoch, current_iter) 189 | 190 | data_time = time.time() 191 | iter_time = time.time() 192 | pretrain_data = prefetcher.next() 193 | 194 | 195 | consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time))) 196 | logger.info(f'End of training. Time consumed: {consumed_time}') 197 | logger.info('Save the latest model.') 198 | model.save(epoch=-1, current_iter=-1) 199 | if tb_logger: 200 | tb_logger.close() 201 | 202 | 203 | if __name__ == '__main__': 204 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 205 | train_pipeline(root_path) 206 | 207 | -------------------------------------------------------------------------------- /basicsr/archs/arcface_arch.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from utils.registry import ARCH_REGISTRY 3 | 4 | 5 | def conv3x3(inplanes, outplanes, stride=1): 6 | """A simple wrapper for 3x3 convolution with padding. 7 | 8 | Args: 9 | inplanes (int): Channel number of inputs. 10 | outplanes (int): Channel number of outputs. 11 | stride (int): Stride in convolution. Default: 1. 12 | """ 13 | return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | """Basic residual block used in the ResNetArcFace architecture. 18 | 19 | Args: 20 | inplanes (int): Channel number of inputs. 21 | planes (int): Channel number of outputs. 22 | stride (int): Stride in convolution. Default: 1. 23 | downsample (nn.Module): The downsample module. Default: None. 
24 | """ 25 | expansion = 1 # output channel expansion ratio 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class IRBlock(nn.Module): 57 | """Improved residual block (IR Block) used in the ResNetArcFace architecture. 58 | 59 | Args: 60 | inplanes (int): Channel number of inputs. 61 | planes (int): Channel number of outputs. 62 | stride (int): Stride in convolution. Default: 1. 63 | downsample (nn.Module): The downsample module. Default: None. 64 | use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. 65 | """ 66 | expansion = 1 # output channel expansion ratio 67 | 68 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): 69 | super(IRBlock, self).__init__() 70 | self.bn0 = nn.BatchNorm2d(inplanes) 71 | self.conv1 = conv3x3(inplanes, inplanes) 72 | self.bn1 = nn.BatchNorm2d(inplanes) 73 | self.prelu = nn.PReLU() 74 | self.conv2 = conv3x3(inplanes, planes, stride) 75 | self.bn2 = nn.BatchNorm2d(planes) 76 | self.downsample = downsample 77 | self.stride = stride 78 | self.use_se = use_se 79 | if self.use_se: 80 | self.se = SEBlock(planes) 81 | 82 | def forward(self, x): 83 | residual = x 84 | out = self.bn0(x) 85 | out = self.conv1(out) 86 | out = self.bn1(out) 87 | out = self.prelu(out) 88 | 89 | out = self.conv2(out) 90 | out = self.bn2(out) 91 | if self.use_se: 92 | out = self.se(out) 93 | 94 | if self.downsample is not None: 95 | residual = self.downsample(x) 96 | 97 | out += residual 98 | out = self.prelu(out) 99 | 100 | return out 101 | 102 | 103 | class Bottleneck(nn.Module): 104 | """Bottleneck block used in the ResNetArcFace architecture. 105 | 106 | Args: 107 | inplanes (int): Channel number of inputs. 108 | planes (int): Channel number of outputs. 109 | stride (int): Stride in convolution. Default: 1. 110 | downsample (nn.Module): The downsample module. Default: None. 
111 | """ 112 | expansion = 4 # output channel expansion ratio 113 | 114 | def __init__(self, inplanes, planes, stride=1, downsample=None): 115 | super(Bottleneck, self).__init__() 116 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 117 | self.bn1 = nn.BatchNorm2d(planes) 118 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 119 | self.bn2 = nn.BatchNorm2d(planes) 120 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 121 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 122 | self.relu = nn.ReLU(inplace=True) 123 | self.downsample = downsample 124 | self.stride = stride 125 | 126 | def forward(self, x): 127 | residual = x 128 | 129 | out = self.conv1(x) 130 | out = self.bn1(out) 131 | out = self.relu(out) 132 | 133 | out = self.conv2(out) 134 | out = self.bn2(out) 135 | out = self.relu(out) 136 | 137 | out = self.conv3(out) 138 | out = self.bn3(out) 139 | 140 | if self.downsample is not None: 141 | residual = self.downsample(x) 142 | 143 | out += residual 144 | out = self.relu(out) 145 | 146 | return out 147 | 148 | 149 | class SEBlock(nn.Module): 150 | """The squeeze-and-excitation block (SEBlock) used in the IRBlock. 151 | 152 | Args: 153 | channel (int): Channel number of inputs. 154 | reduction (int): Channel reduction ration. Default: 16. 155 | """ 156 | 157 | def __init__(self, channel, reduction=16): 158 | super(SEBlock, self).__init__() 159 | self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information 160 | self.fc = nn.Sequential( 161 | nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel), 162 | nn.Sigmoid()) 163 | 164 | def forward(self, x): 165 | b, c, _, _ = x.size() 166 | y = self.avg_pool(x).view(b, c) 167 | y = self.fc(y).view(b, c, 1, 1) 168 | return x * y 169 | 170 | 171 | @ARCH_REGISTRY.register() 172 | class ResNetArcFace(nn.Module): 173 | """ArcFace with ResNet architectures. 174 | 175 | Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition. 176 | 177 | Args: 178 | block (str): Block used in the ArcFace architecture. 179 | layers (tuple(int)): Block numbers in each layer. 180 | use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. 
181 | """ 182 | 183 | def __init__(self, block, layers, use_se=True): 184 | if block == 'IRBlock': 185 | block = IRBlock 186 | self.inplanes = 64 187 | self.use_se = use_se 188 | super(ResNetArcFace, self).__init__() 189 | 190 | self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False) 191 | self.bn1 = nn.BatchNorm2d(64) 192 | self.prelu = nn.PReLU() 193 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 194 | self.layer1 = self._make_layer(block, 64, layers[0]) 195 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 196 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 197 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 198 | self.bn4 = nn.BatchNorm2d(512) 199 | self.dropout = nn.Dropout() 200 | self.fc5 = nn.Linear(512 * 8 * 8, 512) 201 | self.bn5 = nn.BatchNorm1d(512) 202 | 203 | # initialization 204 | for m in self.modules(): 205 | if isinstance(m, nn.Conv2d): 206 | nn.init.xavier_normal_(m.weight) 207 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 208 | nn.init.constant_(m.weight, 1) 209 | nn.init.constant_(m.bias, 0) 210 | elif isinstance(m, nn.Linear): 211 | nn.init.xavier_normal_(m.weight) 212 | nn.init.constant_(m.bias, 0) 213 | 214 | def _make_layer(self, block, planes, num_blocks, stride=1): 215 | downsample = None 216 | if stride != 1 or self.inplanes != planes * block.expansion: 217 | downsample = nn.Sequential( 218 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 219 | nn.BatchNorm2d(planes * block.expansion), 220 | ) 221 | layers = [] 222 | layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) 223 | self.inplanes = planes 224 | for _ in range(1, num_blocks): 225 | layers.append(block(self.inplanes, planes, use_se=self.use_se)) 226 | 227 | return nn.Sequential(*layers) 228 | 229 | def forward(self, x): 230 | x = self.conv1(x) 231 | x = self.bn1(x) 232 | x = self.prelu(x) 233 | x = self.maxpool(x) 234 | 235 | x = self.layer1(x) 236 | x = self.layer2(x) 237 | x = self.layer3(x) 238 | x = self.layer4(x) 239 | x = self.bn4(x) 240 | x = self.dropout(x) 241 | x = x.view(x.size(0), -1) 242 | x = self.fc5(x) 243 | x = self.bn5(x) 244 | 245 | return x -------------------------------------------------------------------------------- /basicsr/train.py: -------------------------------------------------------------------------------- 1 | # os.environ['PYTORCH_CUDA_MAX_SPLIT_SIZE_MB'] = '128' 2 | import argparse 3 | import datetime 4 | import logging 5 | import math 6 | import copy 7 | import random 8 | import time 9 | import torch 10 | from os import path as osp 11 | import numpy as np 12 | from data import build_dataloader, build_dataset 13 | from data.data_sampler import EnlargedSampler 14 | from data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher 15 | from models import build_model 16 | from utils import (MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger, 17 | init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed) 18 | from utils.dist_util import get_dist_info, init_dist 19 | from utils.options import dict2str, parse 20 | 21 | import warnings 22 | # ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. 
23 | warnings.filterwarnings("ignore", category=UserWarning) 24 | 25 | def parse_options(root_path, is_train=True): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') 28 | parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') 29 | parser.add_argument('--local_rank', type=int, default=0) 30 | args = parser.parse_args() 31 | opt = parse(args.opt, root_path, is_train=is_train) 32 | 33 | # distributed settings 34 | if args.launcher == 'none': 35 | opt['dist'] = False 36 | print('Disable distributed.', flush=True) 37 | else: 38 | opt['dist'] = True 39 | if args.launcher == 'slurm' and 'dist_params' in opt: 40 | init_dist(args.launcher, **opt['dist_params']) 41 | else: 42 | init_dist(args.launcher) 43 | 44 | opt['rank'], opt['world_size'] = get_dist_info() 45 | 46 | # random seed 47 | seed = opt.get('manual_seed') 48 | if seed is None: 49 | seed = random.randint(1, 10000) 50 | opt['manual_seed'] = seed 51 | set_random_seed(seed + opt['rank']) 52 | # seed 53 | torch.manual_seed(seed) 54 | random.seed(seed) 55 | np.random.seed(seed) 56 | torch.backends.cudnn.deterministic = True 57 | torch.backends.cudnn.benchmark = False 58 | 59 | return opt 60 | 61 | 62 | def init_loggers(opt): 63 | log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log") 64 | logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 65 | logger.info(get_env_info()) 66 | logger.info(dict2str(opt)) 67 | 68 | # initialize wandb logger before tensorboard logger to allow proper sync: 69 | if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None): 70 | assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb') 71 | init_wandb_logger(opt) 72 | tb_logger = None 73 | if opt['logger'].get('use_tb_logger'): 74 | tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name'])) 75 | return logger, tb_logger 76 | 77 | 78 | def create_train_val_dataloader(opt, logger): 79 | # create train and val dataloaders 80 | train_loader, val_loader = None, None 81 | for phase, dataset_opt in opt['datasets'].items(): 82 | if phase == 'train': 83 | dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) 84 | train_set = build_dataset(dataset_opt) 85 | train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio) 86 | train_loader = build_dataloader( 87 | train_set, 88 | dataset_opt, 89 | num_gpu=opt['num_gpu'], 90 | dist=opt['dist'], 91 | sampler=train_sampler, 92 | seed=opt['manual_seed']) 93 | 94 | num_iter_per_epoch = math.ceil( 95 | len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size'])) 96 | total_iters = int(opt['train']['total_iter']) 97 | total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) 98 | logger.info('Training statistics:' 99 | f'\n\tNumber of train images: {len(train_set)}' 100 | f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}' 101 | f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' 102 | f'\n\tWorld size (gpu number): {opt["world_size"]}' 103 | f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' 104 | f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') 105 | 106 | elif phase == 'val': 107 | val_set = build_dataset(dataset_opt) 108 | val_loader = build_dataloader( 109 | val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, 
seed=opt['manual_seed']) 110 | logger.info(f'Number of val images/folders in {dataset_opt["name"]}: ' f'{len(val_set)}') 111 | else: 112 | raise ValueError(f'Dataset phase {phase} is not recognized.') 113 | 114 | return train_loader, train_sampler, val_loader, total_epochs, total_iters 115 | 116 | 117 | def train_pipeline(root_path): 118 | # parse options, set distributed setting, set ramdom seed 119 | opt = parse_options(root_path, is_train=True) 120 | 121 | torch.backends.cudnn.benchmark = True 122 | # torch.backends.cudnn.deterministic = True 123 | 124 | # load resume states if necessary 125 | if opt['path'].get('resume_state'): 126 | device_id = torch.cuda.current_device() 127 | resume_state = torch.load( 128 | opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) 129 | else: 130 | resume_state = None 131 | 132 | # mkdir for experiments and logger 133 | if resume_state is None: 134 | make_exp_dirs(opt) 135 | if opt['logger'].get('use_tb_logger') and opt['rank'] == 0: 136 | mkdir_and_rename(osp.join('tb_logger', opt['name'])) 137 | 138 | # initialize loggers 139 | logger, tb_logger = init_loggers(opt) 140 | 141 | # create train and validation dataloaders 142 | result = create_train_val_dataloader(opt, logger) 143 | train_loader, train_sampler, val_loader, total_epochs, total_iters = result 144 | 145 | # create model 146 | if resume_state: # resume training 147 | check_resume(opt, resume_state['iter']) 148 | model = build_model(opt) 149 | model.resume_training(resume_state) # handle optimizers and schedulers 150 | logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.") 151 | start_epoch = resume_state['epoch'] 152 | current_iter = resume_state['iter'] 153 | else: 154 | model = build_model(opt) 155 | start_epoch = 0 156 | current_iter = 0 157 | 158 | # create message logger (formatted outputs) 159 | msg_logger = MessageLogger(opt, current_iter, tb_logger) 160 | 161 | # dataloader prefetcher 162 | prefetch_mode = opt['datasets']['train'].get('prefetch_mode') 163 | if prefetch_mode is None or prefetch_mode == 'cpu': 164 | prefetcher = CPUPrefetcher(train_loader) 165 | elif prefetch_mode == 'cuda': 166 | prefetcher = CUDAPrefetcher(train_loader, opt) 167 | logger.info(f'Use {prefetch_mode} prefetch dataloader') 168 | if opt['datasets']['train'].get('pin_memory') is not True: 169 | raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') 170 | else: 171 | raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.' 
"Supported ones are: None, 'cuda', 'cpu'.") 172 | 173 | # training 174 | logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter+1}') 175 | data_time, iter_time = time.time(), time.time() 176 | start_time = time.time() 177 | 178 | for epoch in range(start_epoch, total_epochs + 1): 179 | # break 180 | train_sampler.set_epoch(epoch) 181 | prefetcher.reset() 182 | train_data = prefetcher.next() 183 | # validation 184 | if opt.get('val') is not None and opt['datasets'].get('val') is not None: 185 | model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) 186 | while train_data is not None: 187 | data_time = time.time() - data_time 188 | 189 | current_iter += 1 190 | if current_iter > total_iters: 191 | break 192 | # update learning rate 193 | model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) 194 | # training 195 | model.feed_data(train_data) 196 | model.optimize_parameters(current_iter) 197 | iter_time = time.time() - iter_time 198 | # log 199 | if current_iter % opt['logger']['print_freq'] == 0: 200 | log_vars = {'epoch': epoch, 'iter': current_iter} 201 | log_vars.update({'lrs': model.get_current_learning_rate()}) 202 | log_vars.update({'time': iter_time, 'data_time': data_time}) 203 | log_vars.update(model.get_current_log()) 204 | msg_logger(log_vars) 205 | 206 | # save models and training states 207 | if current_iter % opt['logger']['save_checkpoint_freq'] == 0: 208 | logger.info('Saving models and training states.') 209 | model.save(epoch, current_iter) 210 | 211 | # validation 212 | if opt.get('val') is not None and opt['datasets'].get('val') is not None: 213 | model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) 214 | 215 | data_time = time.time() 216 | iter_time = time.time() 217 | train_data = prefetcher.next() 218 | # end of iter 219 | 220 | # end of epoch 221 | 222 | consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time))) 223 | logger.info(f'End of training. Time consumed: {consumed_time}') 224 | logger.info('Save the latest model.') 225 | model.save(epoch=-1, current_iter=-1) # -1 stands for the latest 226 | if opt.get('val') is not None and opt['datasets'].get('val'): 227 | model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) 228 | if tb_logger: 229 | tb_logger.close() 230 | 231 | 232 | if __name__ == '__main__': 233 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 234 | train_pipeline(root_path) 235 | 236 | 237 | 238 | 239 | -------------------------------------------------------------------------------- /basicsr/losses/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.init as init 7 | from torch.autograd import Variable 8 | import numpy as np 9 | from . 
import pretrained_networks as pn 10 | import torch.nn 11 | 12 | import losses.lpips as lpips 13 | 14 | def spatial_average(in_tens, keepdim=True): 15 | return in_tens.mean([2,3],keepdim=keepdim) 16 | 17 | def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W 18 | in_H, in_W = in_tens.shape[2], in_tens.shape[3] 19 | return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens) 20 | 21 | # Learned perceptual metric 22 | class LPIPS(nn.Module): 23 | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, 24 | pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True): 25 | """ Initializes a perceptual loss torch.nn.Module 26 | 27 | Parameters (default listed first) 28 | --------------------------------- 29 | lpips : bool 30 | [True] use linear layers on top of base/trunk network 31 | [False] means no linear layers; each layer is averaged together 32 | pretrained : bool 33 | This flag controls the linear layers, which are only in effect when lpips=True above 34 | [True] means linear layers are calibrated with human perceptual judgments 35 | [False] means linear layers are randomly initialized 36 | pnet_rand : bool 37 | [False] means trunk loaded with ImageNet classification weights 38 | [True] means randomly initialized trunk 39 | net : str 40 | ['alex','vgg','squeeze'] are the base/trunk networks available 41 | version : str 42 | ['v0.1'] is the default and latest 43 | ['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1) 44 | model_path : 'str' 45 | [None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1 46 | 47 | The following parameters should only be changed if training the network 48 | 49 | eval_mode : bool 50 | [True] is for test mode (default) 51 | [False] is for training mode 52 | pnet_tune 53 | [False] tune the base/trunk network 54 | [True] keep base/trunk frozen 55 | use_dropout : bool 56 | [True] to use dropout when training linear layers 57 | [False] for no dropout when training linear layers 58 | """ 59 | 60 | super(LPIPS, self).__init__() 61 | if(verbose): 62 | print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'% 63 | ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off')) 64 | 65 | self.pnet_type = net 66 | self.pnet_tune = pnet_tune 67 | self.pnet_rand = pnet_rand 68 | self.spatial = spatial 69 | self.lpips = lpips # false means baseline of just averaging all layers 70 | self.version = version 71 | self.scaling_layer = ScalingLayer() 72 | 73 | if(self.pnet_type in ['vgg','vgg16']): 74 | net_type = pn.vgg16 75 | self.chns = [64,128,256,512,512] 76 | elif(self.pnet_type=='alex'): 77 | net_type = pn.alexnet 78 | self.chns = [64,192,384,256,256] 79 | elif(self.pnet_type=='squeeze'): 80 | net_type = pn.squeezenet 81 | self.chns = [64,128,256,384,384,512,512] 82 | self.L = len(self.chns) 83 | 84 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 85 | 86 | if(lpips): 87 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 88 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 89 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 90 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 91 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 92 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] 93 | 
if(self.pnet_type=='squeeze'): # 7 layers for squeezenet 94 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 95 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 96 | self.lins+=[self.lin5,self.lin6] 97 | self.lins = nn.ModuleList(self.lins) 98 | 99 | if(pretrained): 100 | if(model_path is None): 101 | import inspect 102 | import os 103 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net))) 104 | 105 | if(verbose): 106 | print('Loading model from: %s'%model_path) 107 | self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False) 108 | 109 | if(eval_mode): 110 | self.eval() 111 | 112 | def forward(self, in0, in1, retPerLayer=False, normalize=False): 113 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] 114 | in0 = 2 * in0 - 1 115 | in1 = 2 * in1 - 1 116 | 117 | # v0.0 - original release had a bug, where input was not scaled 118 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) 119 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 120 | feats0, feats1, diffs = {}, {}, {} 121 | 122 | for kk in range(self.L): 123 | feats0[kk], feats1[kk] = lpips.normalize_tensor(outs0[kk]), lpips.normalize_tensor(outs1[kk]) 124 | diffs[kk] = (feats0[kk]-feats1[kk])**2 125 | 126 | if(self.lpips): 127 | if(self.spatial): 128 | res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] 129 | else: 130 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] 131 | else: 132 | if(self.spatial): 133 | res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] 134 | else: 135 | res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] 136 | 137 | val = 0 138 | for l in range(self.L): 139 | val += res[l] 140 | 141 | if(retPerLayer): 142 | return (val, res) 143 | else: 144 | return val 145 | 146 | 147 | class ScalingLayer(nn.Module): 148 | def __init__(self): 149 | super(ScalingLayer, self).__init__() 150 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) 151 | self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) 152 | 153 | def forward(self, inp): 154 | return (inp - self.shift) / self.scale 155 | 156 | 157 | class NetLinLayer(nn.Module): 158 | ''' A single linear layer which does a 1x1 conv ''' 159 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 160 | super(NetLinLayer, self).__init__() 161 | 162 | layers = [nn.Dropout(),] if(use_dropout) else [] 163 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] 164 | self.model = nn.Sequential(*layers) 165 | 166 | def forward(self, x): 167 | return self.model(x) 168 | 169 | class Dist2LogitLayer(nn.Module): 170 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 171 | def __init__(self, chn_mid=32, use_sigmoid=True): 172 | super(Dist2LogitLayer, self).__init__() 173 | 174 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] 175 | layers += [nn.LeakyReLU(0.2,True),] 176 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] 177 | layers += [nn.LeakyReLU(0.2,True),] 178 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] 179 | if(use_sigmoid): 180 | layers += [nn.Sigmoid(),] 181 | self.model = 
nn.Sequential(*layers) 182 | 183 | def forward(self,d0,d1,eps=0.1): 184 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) 185 | 186 | class BCERankingLoss(nn.Module): 187 | def __init__(self, chn_mid=32): 188 | super(BCERankingLoss, self).__init__() 189 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 190 | # self.parameters = list(self.net.parameters()) 191 | self.loss = torch.nn.BCELoss() 192 | 193 | def forward(self, d0, d1, judge): 194 | per = (judge+1.)/2. 195 | self.logit = self.net.forward(d0,d1) 196 | return self.loss(self.logit, per) 197 | 198 | # L2, DSSIM metrics 199 | class FakeNet(nn.Module): 200 | def __init__(self, use_gpu=True, colorspace='Lab'): 201 | super(FakeNet, self).__init__() 202 | self.use_gpu = use_gpu 203 | self.colorspace = colorspace 204 | 205 | class L2(FakeNet): 206 | def forward(self, in0, in1, retPerLayer=None): 207 | assert(in0.size()[0]==1) # currently only supports batchSize 1 208 | 209 | if(self.colorspace=='RGB'): 210 | (N,C,X,Y) = in0.size() 211 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) 212 | return value 213 | elif(self.colorspace=='Lab'): 214 | value = lpips.l2(lpips.tensor2np(lpips.tensor2tensorlab(in0.data,to_norm=False)), 215 | lpips.tensor2np(lpips.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 216 | ret_var = Variable( torch.Tensor((value,) ) ) 217 | if(self.use_gpu): 218 | ret_var = ret_var.cuda() 219 | return ret_var 220 | 221 | class DSSIM(FakeNet): 222 | 223 | def forward(self, in0, in1, retPerLayer=None): 224 | assert(in0.size()[0]==1) # currently only supports batchSize 1 225 | 226 | if(self.colorspace=='RGB'): 227 | value = lpips.dssim(1.*lpips.tensor2im(in0.data), 1.*lpips.tensor2im(in1.data), range=255.).astype('float') 228 | elif(self.colorspace=='Lab'): 229 | value = lpips.dssim(lpips.tensor2np(lpips.tensor2tensorlab(in0.data,to_norm=False)), 230 | lpips.tensor2np(lpips.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 231 | ret_var = Variable( torch.Tensor((value,) ) ) 232 | if(self.use_gpu): 233 | ret_var = ret_var.cuda() 234 | return ret_var 235 | 236 | def print_network(net): 237 | num_params = 0 238 | for param in net.parameters(): 239 | num_params += param.numel() 240 | print('Network',net) 241 | print('Total number of parameters: %d' % num_params) 242 | --------------------------------------------------------------------------------
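For reference, a minimal usage sketch (not part of the repository) of the LPIPS module defined in basicsr/losses/lpips/lpips.py above. It assumes the basicsr directory is on PYTHONPATH, so that the in-repo import style `losses.lpips` used by lpips.py resolves, and that the bundled weights under losses/lpips/weights/ are available; the import path and names here are illustrative assumptions, not a documented interface.

import torch
from losses.lpips.lpips import LPIPS  # assumed import path, mirroring the repo's own `import losses.lpips`

# build the metric with an AlexNet trunk and the calibrated v0.1 linear layers
loss_fn = LPIPS(net='alex', version='0.1')

# LPIPS expects inputs in [-1, 1]; pass normalize=True when images are in [0, 1]
img0 = torch.rand(1, 3, 64, 64)
img1 = torch.rand(1, 3, 64, 64)
dist = loss_fn(img0, img1, normalize=True)  # shape (1, 1, 1, 1): one perceptual distance per image
print(float(dist))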