├── VERSION ├── basicsr ├── ops │ ├── __init__.py │ ├── upfirdn2d │ │ ├── __init__.py │ │ ├── src │ │ │ └── upfirdn2d.cpp │ │ └── upfirdn2d.py │ ├── fused_act │ │ ├── __init__.py │ │ ├── src │ │ │ ├── fused_bias_act.cpp │ │ │ └── fused_bias_act_kernel.cu │ │ └── fused_act.py │ └── dcn │ │ ├── __init__.py │ │ └── src │ │ └── deform_conv_ext.cpp ├── __init__.py ├── metrics │ ├── __init__.py │ ├── metric_util.py │ └── psnr_ssim.py ├── losses │ ├── __init__.py │ └── loss_util.py ├── utils │ ├── __init__.py │ ├── download_util.py │ ├── registry.py │ ├── img_process_util.py │ ├── dist_util.py │ ├── misc.py │ ├── file_client.py │ ├── flow_util.py │ ├── img_util.py │ ├── options-DESKTOP-S7HM52K.py │ ├── options.py │ ├── lmdb_util.py │ ├── logger.py │ └── face_util.py ├── models │ ├── __init__.py │ └── lr_scheduler.py ├── archs │ ├── __init__.py │ ├── vgg_arch.py │ └── mask_guided_jdnsr_arch.py ├── test.py └── data │ ├── data_sampler.py │ ├── prefetch_dataloader.py │ └── __init__.py ├── scripts ├── __init__.py ├── metrics │ └── __init__.py ├── data_preparation │ ├── __init__.py │ ├── generate_avg_ct_for_maskguided2022.py │ ├── generate_lr_images.py │ └── create_lmdb.py ├── matlab_scripts │ ├── back_projection │ │ ├── backprojection.m │ │ ├── main_bp.m │ │ └── main_reverse_filter.m │ ├── generate_LR_Vimeo90K.m │ └── generate_bicubic_img.m ├── model_conversion │ ├── convert_ridnet.py │ ├── convert_dfdnet.py │ └── convert_stylegan.py ├── generate_bicubic_sr_img.py └── test_dual_guided_jdnsr_2022.py ├── Low-dose CT image super-resolution network.pdf ├── requirements.txt ├── setup.cfg ├── LICENSE ├── LICENSE-stylegan2-pytorch ├── README.md └── LICENSE-NVIDIA ├── options ├── test │ ├── test_mask_guided_jdnsr_x2_3dircadb.yml │ ├── test_mask_guided_jdnsr_x4_3dircadb.yml │ ├── test_mask_guided_jdnsr_x4_pancreas.yml │ └── test_mask_guided_jdnsr_x2_pancreas.yml └── train │ ├── train_mask_guided_jdnsr_x2_3dircadb.yml │ ├── train_mask_guided_jdnsr_x2_pancreas.yml │ ├── train_mask_guided_jdnsr_x4_3dircadb.yml │ └── train_mask_guided_jdnsr_x4_pancreas.yml ├── .pre-commit-config.yaml ├── .gitignore ├── README.md └── setup.py /VERSION: -------------------------------------------------------------------------------- 1 | 1.3.4.4 2 | -------------------------------------------------------------------------------- /basicsr/ops/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/data_preparation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/__init__.py: -------------------------------------------------------------------------------- 1 | from .upfirdn2d import upfirdn2d 2 | 3 | __all__ = ['upfirdn2d'] 4 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 
2 | 3 | __all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] 4 | -------------------------------------------------------------------------------- /Low-dose CT image super-resolution network.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neu-szy/dual-guidance_LDCT_SR/HEAD/Low-dose CT image super-resolution network.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict 2 | future 3 | lmdb 4 | numpy 5 | matplotlib 6 | opencv-python 7 | Pillow 8 | pyyaml 9 | requests 10 | scikit-image 11 | scipy 12 | tb-nightly 13 | torch>=1.7 14 | torchvision 15 | tqdm 16 | yapf 17 | easydict 18 | icecream 19 | einops 20 | timm 21 | -------------------------------------------------------------------------------- /basicsr/ops/dcn/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, 2 | modulated_deform_conv) 3 | 4 | __all__ = [ 5 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', 6 | 'modulated_deform_conv' 7 | ] 8 | -------------------------------------------------------------------------------- /basicsr/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/xinntao/BasicSR 2 | # flake8: noqa 3 | from .archs import * 4 | from .data import * 5 | from .losses import * 6 | from .metrics import * 7 | from .models import * 8 | from .ops import * 9 | from .test import * 10 | from .train import * 11 | from .utils import * 12 | from .version import __gitsha__, __version__ 13 | -------------------------------------------------------------------------------- /basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from basicsr.utils.registry import METRIC_REGISTRY 4 | from .psnr_ssim import calculate_psnr, calculate_ssim 5 | 6 | __all__ = ['calculate_psnr', 'calculate_ssim'] 7 | 8 | 9 | def calculate_metric(data, opt): 10 | """Calculate metric from data and options. 11 | 12 | Args: 13 | opt (dict): Configuration. It must contain: 14 | type (str): Metric type.
15 | """ 16 | opt = deepcopy(opt) 17 | metric_type = opt.pop('type') # calculate_psnr and so on 18 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) 19 | return metric 20 | -------------------------------------------------------------------------------- /scripts/matlab_scripts/back_projection/backprojection.m: -------------------------------------------------------------------------------- 1 | function [im_h] = backprojection(im_h, im_l, maxIter) 2 | 3 | [row_l, col_l,~] = size(im_l); 4 | [row_h, col_h,~] = size(im_h); 5 | 6 | p = fspecial('gaussian', 5, 1); 7 | p = p.^2; 8 | p = p./sum(p(:)); 9 | 10 | im_l = double(im_l); 11 | im_h = double(im_h); 12 | 13 | for ii = 1:maxIter 14 | im_l_s = imresize(im_h, [row_l, col_l], 'bicubic'); 15 | im_diff = im_l - im_l_s; 16 | im_diff = imresize(im_diff, [row_h, col_h], 'bicubic'); 17 | im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same'); 18 | im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same'); 19 | im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same'); 20 | end 21 | -------------------------------------------------------------------------------- /scripts/matlab_scripts/back_projection/main_bp.m: -------------------------------------------------------------------------------- 1 | clear; close all; clc; 2 | 3 | LR_folder = './LR'; % LR 4 | preout_folder = './results'; % pre output 5 | save_folder = './results_20bp'; 6 | filepaths = dir(fullfile(preout_folder, '*.png')); 7 | max_iter = 20; 8 | 9 | if ~ exist(save_folder, 'dir') 10 | mkdir(save_folder); 11 | end 12 | 13 | for idx_im = 1:length(filepaths) 14 | fprintf([num2str(idx_im) '\n']); 15 | im_name = filepaths(idx_im).name; 16 | im_LR = im2double(imread(fullfile(LR_folder, im_name))); 17 | im_out = im2double(imread(fullfile(preout_folder, im_name))); 18 | %tic 19 | im_out = backprojection(im_out, im_LR, max_iter); 20 | %toc 21 | imwrite(im_out, fullfile(save_folder, im_name)); 22 | end 23 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | # line break before binary operator (W503) 4 | W503, 5 | # line break after binary operator (W504) 6 | W504, 7 | max-line-length=120 8 | 9 | [yapf] 10 | based_on_style = pep8 11 | column_limit = 120 12 | blank_line_before_nested_class_or_def = true 13 | split_before_expression_after_opening_paren = true 14 | 15 | [isort] 16 | line_length = 120 17 | multi_line_output = 0 18 | known_standard_library = pkg_resources,setuptools 19 | known_first_party = basicsr 20 | known_third_party = PIL,cv2,lmdb,numpy,requests,scipy,skimage,torch,torchvision,tqdm,yaml 21 | no_lines_before = STDLIB,LOCALFOLDER 22 | default_section = THIRDPARTY 23 | 24 | [codespell] 25 | skip = .git,./docs/build,*.cfg 26 | count = 27 | quiet-level = 3 28 | ignore-words-list = gool 29 | -------------------------------------------------------------------------------- /scripts/model_conversion/convert_ridnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | from basicsr.archs.ridnet_arch import RIDNet 5 | 6 | if __name__ == '__main__': 7 | ori_net_checkpoint = torch.load( 8 | 'experiments/pretrained_models/RIDNet/RIDNet_official_original.pt', map_location=lambda storage, loc: storage) 9 | rid_net = RIDNet(3, 64, 3) 10 | new_ridnet_dict = OrderedDict() 11 | 12 | rid_net_namelist = [] 13 | for name, param in 
rid_net.named_parameters(): 14 | rid_net_namelist.append(name) 15 | 16 | count = 0 17 | for name, param in ori_net_checkpoint.items(): 18 | new_ridnet_dict[rid_net_namelist[count]] = param 19 | count += 1 20 | 21 | rid_net.load_state_dict(new_ridnet_dict) 22 | torch.save(rid_net.state_dict(), 'experiments/pretrained_models/RIDNet/RIDNet.pth') 23 | -------------------------------------------------------------------------------- /scripts/matlab_scripts/back_projection/main_reverse_filter.m: -------------------------------------------------------------------------------- 1 | clear; close all; clc; 2 | 3 | LR_folder = './LR'; % LR 4 | preout_folder = './results'; % pre output 5 | save_folder = './results_20if'; 6 | filepaths = dir(fullfile(preout_folder, '*.png')); 7 | max_iter = 20; 8 | 9 | if ~ exist(save_folder, 'dir') 10 | mkdir(save_folder); 11 | end 12 | 13 | for idx_im = 1:length(filepaths) 14 | fprintf([num2str(idx_im) '\n']); 15 | im_name = filepaths(idx_im).name; 16 | im_LR = im2double(imread(fullfile(LR_folder, im_name))); 17 | im_out = im2double(imread(fullfile(preout_folder, im_name))); 18 | J = imresize(im_LR,4,'bicubic'); 19 | %tic 20 | for m = 1:max_iter 21 | im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic')); 22 | end 23 | %toc 24 | imwrite(im_out, fullfile(save_folder, im_name)); 25 | end 26 | -------------------------------------------------------------------------------- /scripts/data_preparation/generate_avg_ct_for_maskguided2022.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import glob 4 | 5 | import numpy as np 6 | 7 | 8 | def main(lq_dir, save_dir, scale): 9 | os.makedirs(save_dir, exist_ok=True) 10 | img_names = os.listdir(lq_dir) 11 | num1s = [os.path.splitext(i)[0].split("_")[0] for i in img_names] 12 | num1s = list(set(num1s)) 13 | for num1 in num1s: 14 | img_total = np.zeros((512 // scale, 512 // scale), dtype=np.float64) 15 | n = 0 16 | img_num1_paths = glob.glob(os.path.join(lq_dir, f"{num1}_*.*")) 17 | for img_num1_path in img_num1_paths: 18 | img = cv2.imread(img_num1_path, flags=cv2.IMREAD_GRAYSCALE) 19 | img_total += img 20 | n += 1 21 | img_avg = img_total / n 22 | cv2.imwrite(os.path.join(save_dir, f"{num1}.png"), np.uint8(img_avg)) 23 | 24 | if __name__ == '__main__': 25 | for d in ["3dircadb", "pancreas"]: 26 | for s in [2, 4]: 27 | for p in ["train", "val", "test"]: 28 | main(f"/home/zhiyi/data/{d}/img/lr_nd/x{s}/{p}", f"/home/zhiyi/data/{d}/img/lr_nd/x{s}/{p}_avg", s) -------------------------------------------------------------------------------- /basicsr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import importlib 3 | from copy import deepcopy 4 | from os import path as osp 5 | 6 | from basicsr.utils import get_root_logger, scandir 7 | from basicsr.utils.registry import LOSS_REGISTRY 8 | 9 | __all__ = ['build_loss'] 10 | 11 | # automatically scan and import loss modules for registry 12 | # scan all the files under the 'losses' folder and collect files ending with '_loss.py' 13 | loss_folder = osp.dirname(osp.abspath(__file__)) 14 | loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')] 15 | # import all the loss modules 16 | _model_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames] 17 | 18 | def build_loss(opt): 19 | """Build loss from options. 20 | Args: 21 | opt (dict): Configuration. 
It must contain: 22 | type (str): Loss type. 23 | """ 24 | opt = deepcopy(opt) 25 | loss_type = opt.pop('type') 26 | loss = LOSS_REGISTRY.get(loss_type)(**opt) 27 | logger = get_root_logger() 28 | logger.info(f'Loss [{loss.__class__.__name__}] is created.') 29 | return loss -------------------------------------------------------------------------------- /basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffjpeg import DiffJPEG 2 | from .file_client import FileClient 3 | from .img_process_util import USMSharp, usm_sharp 4 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img 5 | from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger 6 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt 7 | 8 | __all__ = [ 9 | # file_client.py 10 | 'FileClient', 11 | # img_util.py 12 | 'img2tensor', 13 | 'tensor2img', 14 | 'imfrombytes', 15 | 'imwrite', 16 | 'crop_border', 17 | # logger.py 18 | 'MessageLogger', 19 | 'AvgTimer', 20 | 'init_tb_logger', 21 | 'init_wandb_logger', 22 | 'get_root_logger', 23 | 'get_env_info', 24 | # misc.py 25 | 'set_random_seed', 26 | 'get_time_str', 27 | 'mkdir_and_rename', 28 | 'make_exp_dirs', 29 | 'scandir', 30 | 'check_resume', 31 | 'sizeof_fmt', 32 | # diffjpeg 33 | 'DiffJPEG', 34 | # img_process_util 35 | 'USMSharp', 36 | 'usm_sharp' 37 | ] 38 | -------------------------------------------------------------------------------- /basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import MODEL_REGISTRY 7 | 8 | __all__ = ['build_model'] 9 | 10 | # automatically scan and import model modules for registry 11 | # scan all the files under the 'models' folder and collect files ending with 12 | # '_model.py' 13 | model_folder = osp.dirname(osp.abspath(__file__)) 14 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 15 | # import all the model modules 16 | _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] 17 | 18 | 19 | def build_model(opt): 20 | """Build model from options. 21 | 22 | Args: 23 | opt (dict): Configuration. It must contain: 24 | model_type (str): Model type.
25 | """ 26 | opt = deepcopy(opt) 27 | model = MODEL_REGISTRY.get(opt['model_type'])(opt) 28 | 29 | 30 | logger = get_root_logger() 31 | logger.info(f'Model [{model.__class__.__name__}] is created.') 32 | return model 33 | -------------------------------------------------------------------------------- /LICENSE/LICENSE-stylegan2-pytorch: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/src/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp 2 | #include 3 | 4 | 5 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 6 | int up_x, int up_y, int down_x, int down_y, 7 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 11 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 12 | 13 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 14 | int up_x, int up_y, int down_x, int down_y, 15 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 16 | CHECK_CUDA(input); 17 | CHECK_CUDA(kernel); 18 | 19 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 24 | } 25 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp 2 | #include 3 | 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, 6 | const torch::Tensor& bias, 7 | const torch::Tensor& refer, 8 | int act, int grad, float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); 
CHECK_CONTIGUOUS(x) 13 | 14 | torch::Tensor fused_bias_act(const torch::Tensor& input, 15 | const torch::Tensor& bias, 16 | const torch::Tensor& refer, 17 | int act, int grad, float alpha, float scale) { 18 | CHECK_CUDA(input); 19 | CHECK_CUDA(bias); 20 | 21 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 22 | } 23 | 24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 25 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 26 | } 27 | -------------------------------------------------------------------------------- /basicsr/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ['build_network'] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with 12 | # '_arch.py' 13 | arch_folder = osp.dirname(osp.abspath(__file__)) 14 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 15 | # import all the arch modules 16 | _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] 17 | 18 | 19 | def build_network(opt): 20 | opt = deepcopy(opt) 21 | network_type = opt.pop('type') 22 | net = ARCH_REGISTRY.get(network_type)(**opt) 23 | logger = get_root_logger() 24 | logger.info(f'Network [{net.__class__.__name__}] is created.') 25 | return net 26 | 27 | if __name__ == "__main__": 28 | arch_folder = osp.dirname(osp.abspath(__file__)) 29 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 30 | print(arch_filenames) -------------------------------------------------------------------------------- /options/test/test_mask_guided_jdnsr_x2_3dircadb.yml: -------------------------------------------------------------------------------- 1 | name: test_mask_guided_jdnsr_x2_3dircadb 2 | suffix: ~ # add suffix to saved images 3 | model_type: Mask_Guided_Model 4 | scale: 2 5 | crop_border: ~ # crop border when evaluation.
If None, crop the scale pixels 6 | num_gpu: 1 # set num_gpu: 0 for cpu mode 7 | manual_seed: 0 8 | 9 | datasets: 10 | test_1: 11 | name: 3dircadb_test 12 | type: PairedMASKDataset 13 | dataroot_gt: /home/zhiyi/data/3dircadb/img/hr_nd/test 14 | dataroot_gt_lr: /home/zhiyi/data/3dircadb/img/lr_nd/x2/test 15 | dataroot_lq: /home/zhiyi/data/3dircadb/img/lr_ld/x2/test 16 | dataroot_mask: /home/zhiyi/data/3dircadb/mask/x2/test 17 | dataroot_avg_ct: /home/zhiyi/data/3dircadb/img/lr_ld/x2/test_avg 18 | io_backend: 19 | type: disk 20 | 21 | # network structures 22 | network_g: 23 | type: MaskGuidedJDNSR 24 | scale: 2 25 | 26 | # path 27 | path: 28 | pretrain_network_g: ~ # your model path 29 | strict_load_g: true 30 | 31 | # validation settings 32 | val: 33 | save_img: True 34 | suffix: ~ # add suffix to saved images, if None, use exp name 35 | 36 | metrics: 37 | psnr: # metric name, can be arbitrary 38 | type: calculate_psnr 39 | crop_border: 2 40 | test_y_channel: true 41 | ssim: 42 | type: calculate_ssim 43 | crop_border: 2 44 | test_y_channel: true -------------------------------------------------------------------------------- /options/test/test_mask_guided_jdnsr_x4_3dircadb.yml: -------------------------------------------------------------------------------- 1 | name: test_mask_guided_jdnsr_x4_3dircadb 2 | suffix: ~ # add suffix to saved images 3 | model_type: Mask_Guided_Model 4 | scale: 4 5 | crop_border: ~ # crop border when evaluation. If None, crop the scale pixels 6 | num_gpu: 1 # set num_gpu: 0 for cpu mode 7 | manual_seed: 0 8 | 9 | datasets: 10 | test_1: 11 | name: 3dircadb_test 12 | type: PairedMASKDataset 13 | dataroot_gt: /home/zhiyi/data/3dircadb/img/hr_nd/test 14 | dataroot_gt_lr: /home/zhiyi/data/3dircadb/img/lr_nd/x4/test 15 | dataroot_lq: /home/zhiyi/data/3dircadb/img/lr_ld/x4/test 16 | dataroot_mask: /home/zhiyi/data/3dircadb/mask/x4/test 17 | dataroot_avg_ct: /home/zhiyi/data/3dircadb/img/lr_ld/x4/test_avg 18 | io_backend: 19 | type: disk 20 | 21 | # network structures 22 | network_g: 23 | type: MaskGuidedJDNSR 24 | scale: 4 25 | 26 | # path 27 | path: 28 | pretrain_network_g: ~ # your model path 29 | strict_load_g: true 30 | 31 | # validation settings 32 | val: 33 | save_img: True 34 | suffix: ~ # add suffix to saved images, if None, use exp name 35 | 36 | metrics: 37 | psnr: # metric name, can be arbitrary 38 | type: calculate_psnr 39 | crop_border: 4 40 | test_y_channel: true 41 | ssim: 42 | type: calculate_ssim 43 | crop_border: 4 44 | test_y_channel: true -------------------------------------------------------------------------------- /options/test/test_mask_guided_jdnsr_x4_pancreas.yml: -------------------------------------------------------------------------------- 1 | name: test_mask_guided_jdnsr_x4_pancreas 2 | suffix: ~ # add suffix to saved images 3 | model_type: Mask_Guided_Model 4 | scale: 4 5 | crop_border: ~ # crop border when evaluation. 
If None, crop the scale pixels 6 | num_gpu: 1 # set num_gpu: 0 for cpu mode 7 | manual_seed: 0 8 | 9 | datasets: 10 | test_1: 11 | name: pancreas_test 12 | type: PairedMASKDataset 13 | dataroot_gt: /home/zhiyi/data/pancreas/img/hr_nd/test 14 | dataroot_gt_lr: /home/zhiyi/data/pancreas/img/lr_nd/x4/test 15 | dataroot_lq: /home/zhiyi/data/pancreas/img/lr_ld/x4/test 16 | dataroot_mask: /home/zhiyi/data/pancreas/mask/x4/test 17 | dataroot_avg_ct: /home/zhiyi/data/pancreas/img/lr_ld/x4/test_avg 18 | io_backend: 19 | type: disk 20 | 21 | # network structures 22 | network_g: 23 | type: MaskGuidedJDNSR 24 | scale: 4 25 | 26 | # path 27 | path: 28 | pretrain_network_g: ~ # your model path 29 | strict_load_g: true 30 | 31 | # validation settings 32 | val: 33 | save_img: True 34 | suffix: ~ # add suffix to saved images, if None, use exp name 35 | 36 | metrics: 37 | psnr: # metric name, can be arbitrary 38 | type: calculate_psnr 39 | crop_border: 4 40 | test_y_channel: true 41 | ssim: 42 | type: calculate_ssim 43 | crop_border: 4 44 | test_y_channel: true -------------------------------------------------------------------------------- /options/test/test_mask_guided_jdnsr_x2_pancreas.yml: -------------------------------------------------------------------------------- 1 | name: test_mask_guided_jdnsr_x2_pancreas 2 | suffix: ~ # add suffix to saved images 3 | model_type: Mask_Guided_Model 4 | scale: 2 5 | crop_border: ~ # crop border when evaluation. If None, crop the scale pixels 6 | num_gpu: 1 # set num_gpu: 0 for cpu mode 7 | manual_seed: 0 8 | 9 | datasets: 10 | test_1: 11 | name: pancreas_test 12 | type: PairedMASKDataset 13 | dataroot_gt: /home/zhiyi/data/pancreas/img/hr_nd/test 14 | dataroot_gt_lr: /home/zhiyi/data/pancreas/img/lr_nd/x2/test 15 | dataroot_lq: /home/zhiyi/data/pancreas/img/lr_ld/x2/test 16 | dataroot_mask: /home/zhiyi/data/pancreas/mask/x2/test 17 | dataroot_avg_ct: /home/zhiyi/data/pancreas/img/lr_ld/x2/test_avg 18 | io_backend: 19 | type: disk 20 | 21 | # network structures 22 | network_g: 23 | type: MaskGuidedJDNSR 24 | scale: 2 25 | 26 | 27 | # path 28 | path: 29 | pretrain_network_g: ~ # your model path 30 | strict_load_g: true 31 | 32 | # validation settings 33 | val: 34 | save_img: True 35 | suffix: ~ # add suffix to saved images, if None, use exp name 36 | 37 | metrics: 38 | psnr: # metric name, can be arbitrary 39 | type: calculate_psnr 40 | crop_border: 2 41 | test_y_channel: true 42 | ssim: 43 | type: calculate_ssim 44 | crop_border: 2 45 | test_y_channel: true -------------------------------------------------------------------------------- /basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from basicsr.utils.matlab_functions import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img, input_order='HWC'): 7 | """Reorder images to 'HWC' order. 8 | 9 | If the input_order is (h, w), return (h, w, 1); 10 | If the input_order is (c, h, w), return (h, w, c); 11 | If the input_order is (h, w, c), return as it is. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 16 | If the input image shape is (h, w), input_order will not have 17 | effects. Default: 'HWC'. 18 | 19 | Returns: 20 | ndarray: reordered image. 21 | """ 22 | 23 | if input_order not in ['HWC', 'CHW']: 24 | raise ValueError(f'Wrong input_order {input_order}. 
Supported input_orders are ' "'HWC' and 'CHW'") 25 | if len(img.shape) == 2: 26 | img = img[..., None] 27 | if input_order == 'CHW': 28 | img = img.transpose(1, 2, 0) 29 | return img 30 | 31 | 32 | def to_y_channel(img): 33 | """Change to Y channel of YCbCr. 34 | 35 | Args: 36 | img (ndarray): Images with range [0, 255]. 37 | 38 | Returns: 39 | (ndarray): Images with range [0, 255] (float type) without round. 40 | """ 41 | img = img.astype(np.float32) / 255. 42 | if img.ndim == 3 and img.shape[2] == 3: 43 | img = bgr2ycbcr(img, y_only=True) 44 | img = img[..., None] 45 | return img * 255. 46 | -------------------------------------------------------------------------------- /scripts/matlab_scripts/generate_LR_Vimeo90K.m: -------------------------------------------------------------------------------- 1 | function generate_LR_Vimeo90K() 2 | %% matlab code to generate bicubic-downsampled images for the Vimeo90K dataset 3 | 4 | up_scale = 4; 5 | mod_scale = 4; 6 | idx = 0; 7 | filepaths = dir('/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences/*/*/*.png'); 8 | for i = 1 : length(filepaths) 9 | [~,imname,ext] = fileparts(filepaths(i).name); 10 | folder_path = filepaths(i).folder; 11 | save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_matlabLRx4'); 12 | if ~exist(save_LR_folder, 'dir') 13 | mkdir(save_LR_folder); 14 | end 15 | if isempty(imname) 16 | disp('Ignore . folder.'); 17 | elseif strcmp(imname, '.') 18 | disp('Ignore .. folder.'); 19 | else 20 | idx = idx + 1; 21 | str_result = sprintf('%d\t%s.\n', idx, imname); 22 | fprintf(str_result); 23 | % read image 24 | img = imread(fullfile(folder_path, [imname, ext])); 25 | img = im2double(img); 26 | % modcrop 27 | img = modcrop(img, mod_scale); 28 | % LR 29 | im_LR = imresize(img, 1/up_scale, 'bicubic'); 30 | if exist('save_LR_folder', 'var') 31 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png'])); 32 | end 33 | end 34 | end 35 | end 36 | 37 | %% modcrop 38 | function img = modcrop(img, modulo) 39 | if size(img,3) == 1 40 | sz = size(img); 41 | sz = sz - mod(sz, modulo); 42 | img = img(1:sz(1), 1:sz(2)); 43 | else 44 | tmpsz = size(img); 45 | sz = tmpsz(1:2); 46 | sz = sz - mod(sz, modulo); 47 | img = img(1:sz(1), 1:sz(2),:); 48 | end 49 | end 50 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # flake8 3 | - repo: https://github.com/PyCQA/flake8 4 | rev: 3.8.3 5 | hooks: 6 | - id: flake8 7 | args: ["--config=setup.cfg", "--ignore=W504, W503"] 8 | 9 | # modify known_third_party 10 | - repo: https://github.com/asottile/seed-isort-config 11 | rev: v2.2.0 12 | hooks: 13 | - id: seed-isort-config 14 | 15 | # isort 16 | - repo: https://github.com/timothycrosley/isort 17 | rev: 5.2.2 18 | hooks: 19 | - id: isort 20 | 21 | # yapf 22 | - repo: https://github.com/pre-commit/mirrors-yapf 23 | rev: v0.30.0 24 | hooks: 25 | - id: yapf 26 | 27 | # codespell 28 | - repo: https://github.com/codespell-project/codespell 29 | rev: v2.1.0 30 | hooks: 31 | - id: codespell 32 | 33 | # pre-commit-hooks 34 | - repo: https://github.com/pre-commit/pre-commit-hooks 35 | rev: v3.2.0 36 | hooks: 37 | - id: trailing-whitespace # Trim trailing whitespace 38 | - id: check-yaml # Attempt to load all yaml files to verify syntax 39 | - id: check-merge-conflict # Check for files that contain merge conflict strings 40 | - id: double-quote-string-fixer # Replace double quoted strings with single quoted
strings 41 | - id: end-of-file-fixer # Make sure files end in a newline and only a newline 42 | - id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0 43 | - id: fix-encoding-pragma # Remove the coding pragma: # -*- coding: utf-8 -*- 44 | args: ["--remove"] 45 | - id: mixed-line-ending # Replace or check mixed line ending 46 | args: ["--fix=lf"] 47 | -------------------------------------------------------------------------------- /basicsr/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import torch 5 | from os import path as osp 6 | 7 | from basicsr.data import build_dataloader, build_dataset 8 | from basicsr.models import build_model 9 | from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs 10 | from basicsr.utils.options import dict2str, parse_options 11 | 12 | 13 | def test_pipeline(root_path): 14 | # parse options, set distributed setting, set random seed 15 | opt, _ = parse_options(root_path, is_train=False) 16 | torch.backends.cudnn.benchmark = True 17 | # torch.backends.cudnn.deterministic = True 18 | 19 | # mkdir and initialize loggers 20 | make_exp_dirs(opt) 21 | log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log") 22 | logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 23 | logger.info(get_env_info()) 24 | logger.info(dict2str(opt)) 25 | 26 | # create test dataset and dataloader 27 | test_loaders = [] 28 | for _, dataset_opt in sorted(opt['datasets'].items()): 29 | test_set = build_dataset(dataset_opt) 30 | test_loader = build_dataloader( 31 | test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) 32 | logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}") 33 | test_loaders.append(test_loader) 34 | 35 | # create model 36 | model = build_model(opt) 37 | 38 | for test_loader in test_loaders: 39 | test_set_name = test_loader.dataset.opt['name'] 40 | logger.info(f'Testing {test_set_name}...') 41 | model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img']) 42 | 43 | 44 | if __name__ == '__main__': 45 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 46 | test_pipeline(root_path) 47 | -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restarting the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1.
19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas # 1 24 | self.rank = rank # 0 25 | self.epoch = 0 26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) # dataset的len在类里定义了,就是文件夹内图片的数量 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | def __iter__(self): 30 | # deterministically shuffle based on epoch 31 | g = torch.Generator() # 用于生成随机数 32 | g.manual_seed(self.epoch) 33 | indices = torch.randperm(self.total_size, generator=g).tolist() # 返回一个[0,total_size]的打乱的列表 34 | 35 | dataset_size = len(self.dataset) 36 | indices = [v % dataset_size for v in indices] 37 | 38 | # subsample 39 | indices = indices[self.rank: self.total_size: self.num_replicas] 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | def set_epoch(self, epoch): 48 | self.epoch = epoch 49 | -------------------------------------------------------------------------------- /scripts/generate_bicubic_sr_img.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | from PIL import Image 4 | 5 | def main(dir, save_dir, scale): 6 | img_list = os.listdir(dir) 7 | if not os.path.isdir(save_dir): 8 | os.makedirs(save_dir) 9 | for img_name in tqdm(img_list): 10 | img_path = os.path.join(dir, img_name) 11 | img = Image.open(img_path) 12 | w, h = img.size 13 | img = img.resize((w*scale, h*scale), Image.BICUBIC) 14 | save_path = os.path.join(save_dir, img_name) 15 | img.save(save_path) 16 | 17 | # for name in ["belly", "lung"]: 18 | # for scale in [2, 4, 8]: 19 | # main(f"/home/zhiyi/data/medical/{name}/x{scale}_valid", 20 | # f"/home/zhiyi/data/medical/{name}/x{scale}_valid_bicubic_sr", 21 | # scale) 22 | 23 | # for a, b in zip( 24 | # ["x4_coronal", "x4_sagittal", "x4_val_artifact1e3", "x4_val_cubic", "x4_val_linear", "x4_val_nearest", "x4_val_pepper", "x4_val_gaussian", "x4_multiHU/hu_-40_110", "x4_multiHU/hu_-100_900", "x4_multiHU/hu_-926_26"][3:6], 25 | # ["x4_coronal_bicubic_sr", "x4_sagittal_bicubic_sr", "x4_val_artifact1e3_bicubic_sr", "x4_val_cubic_bicubic_sr", "x4_val_linear_bicubic_sr", "x4_val_nearest_bicubic_sr", "x4_val_pepper_bicubic_sr", "x4_val_gaussian_bicubic_sr", "x4_multiHU/hu_-40_110_bicubic_sr", "x4_multiHU/hu_-100_900_bicubic_sr", "x4_multiHU/hu_-926_26_bicubic_sr"][3:6] 26 | # ): 27 | # main(f"/home/zhiyi/data/medical/belly/zt_extra/{a}", f"/home/zhiyi/data/medical/belly/zt_extra/{b}", 4) 28 | 29 | # for a, b in zip( 30 | # ["x2_val_cubic", "x2_val_linear", "x2_val_nearest"], 31 | # ["x2_val_cubic_bicubic_sr", "x2_val_linear_bicubic_sr", "x2_val_nearest_bicubic_sr"] 32 | # ): 33 | # main(f"/home/zhiyi/data/medical/belly/zt_extra/{a}", f"/home/zhiyi/data/medical/belly/zt_extra/{b}", 2) 34 | 35 | 36 | if __name__ == '__main__': 37 | for data in ["3dircadb", "pancreas"]: 38 | for s in [2, 4]: 39 | for p in ["val", "test"]: 40 | main(f"/home/zhiyi/data/{data}/img/lr_ld/x{s}/{p}", f"/home/zhiyi/data/{data}/img/lr_ld_bicubic/x{s}/{p}", s) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ignored folders 2 | datasets/* 3 | experiments/* 4 | results/* 5 | tb_logger/* 6 | wandb/* 7 | tmp/* 8 | 9 | *.DS_Store 10 | .idea 11 | 12 | # ignored files 13 | version.py 14 | 15 | # ignored files with suffix 16 
| *.html 17 | *.png 18 | *.jpeg 19 | *.jpg 20 | *.gif 21 | *.pth 22 | *.zip 23 | 24 | # template 25 | 26 | # Byte-compiled / optimized / DLL files 27 | __pycache__/ 28 | *.py[cod] 29 | *$py.class 30 | 31 | # C extensions 32 | *.so 33 | 34 | # Distribution / packaging 35 | .Python 36 | build/ 37 | develop-eggs/ 38 | dist/ 39 | downloads/ 40 | eggs/ 41 | .eggs/ 42 | lib/ 43 | lib64/ 44 | parts/ 45 | sdist/ 46 | var/ 47 | wheels/ 48 | *.egg-info/ 49 | .installed.cfg 50 | *.egg 51 | MANIFEST 52 | 53 | # PyInstaller 54 | # Usually these files are written by a python script from a template 55 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 56 | *.manifest 57 | *.spec 58 | 59 | # Installer logs 60 | pip-log.txt 61 | pip-delete-this-directory.txt 62 | 63 | # Unit test / coverage reports 64 | htmlcov/ 65 | .tox/ 66 | .coverage 67 | .coverage.* 68 | .cache 69 | nosetests.xml 70 | coverage.xml 71 | *.cover 72 | .hypothesis/ 73 | .pytest_cache/ 74 | 75 | # Translations 76 | *.mo 77 | *.pot 78 | 79 | # Django stuff: 80 | *.log 81 | local_settings.py 82 | db.sqlite3 83 | 84 | # Flask stuff: 85 | instance/ 86 | .webassets-cache 87 | 88 | # Scrapy stuff: 89 | .scrapy 90 | 91 | # Sphinx documentation 92 | docs/_build/ 93 | 94 | # PyBuilder 95 | target/ 96 | 97 | # Jupyter Notebook 98 | .ipynb_checkpoints 99 | 100 | # pyenv 101 | .python-version 102 | 103 | # celery beat schedule file 104 | celerybeat-schedule 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | 131 | 132 | .vscode/ 133 | .idea/ 134 | -------------------------------------------------------------------------------- /LICENSE/README.md: -------------------------------------------------------------------------------- 1 | # License and Acknowledgement 2 | 3 | This BasicSR project is released under the Apache 2.0 license. 4 | 5 | - StyleGAN2 6 | - The codes are modified from the repository [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch). Many thanks to the author - [Kim Seonghyeon](https://rosinality.github.io/) :blush: for translating from the official TensorFlow codes to PyTorch ones. Here is the [license](LICENSE-stylegan2-pytorch) of stylegan2-pytorch. 7 | - The official repository is https://github.com/NVlabs/stylegan2, and here is the [NVIDIA license](./LICENSE-NVIDIA). 8 | - DFDNet 9 | - The codes are largely modified from the repository [DFDNet](https://github.com/csxmli2016/DFDNet). Their license is [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](https://creativecommons.org/licenses/by-nc-sa/4.0/). 10 | - DiffJPEG 11 | - Modified from https://github.com/mlomnitz/DiffJPEG. 12 | - [pytorch-image-models](https://github.com/rwightman/pytorch-image-models/) 13 | - We use the implementation of `DropPath` and `trunc_normal_` from [pytorch-image-models](https://github.com/rwightman/pytorch-image-models/). The LICENSE is included as [LICENSE_pytorch-image-models](LICENSE/LICENSE_pytorch-image-models). 14 | - [SwinIR](https://github.com/JingyunLiang/SwinIR) 15 | - The arch implementation of SwinIR is from [SwinIR](https://github.com/JingyunLiang/SwinIR). The LICENSE is included as [LICENSE_SwinIR](LICENSE/LICENSE_SwinIR). 
16 | - [ECBSR](https://github.com/xindongzhang/ECBSR) 17 | - The arch implementation of ECBSR is from [ECBSR](https://github.com/xindongzhang/ECBSR). The LICENSE of ECBSR is [Apache License 2.0](https://github.com/xindongzhang/ECBSR/blob/main/LICENSE). 18 | 19 | ## References 20 | 21 | 1. NIQE metric: the codes are translated from the [official MATLAB codes](http://live.ece.utexas.edu/research/quality/niqe_release.zip) 22 | 23 | > A. Mittal, R. Soundararajan and A. C. Bovik, "Making a Completely Blind Image Quality Analyzer", IEEE Signal Processing Letters, 2012. 24 | 25 | 1. FID metric: the codes are modified from [pytorch-fid](https://github.com/mseitzer/pytorch-fid) and [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch). 26 | -------------------------------------------------------------------------------- /scripts/test_dual_guided_jdnsr_2022.py: -------------------------------------------------------------------------------- 1 | from scripts.metrics.calculate_psnr_ssim import main3 2 | import os 3 | import numpy as np 4 | from multiprocessing import Process 5 | 6 | 7 | def sort_imgs(name): 8 | if "." in name: 9 | basename = os.path.splitext(name)[0] 10 | num1, num2 = basename.split("_") 11 | return int(num1) * 10000 + int(num2) 12 | 13 | 14 | def _get_psnr_and_ssim_for_bicubic(gts_dir, restored_dir, scale): 15 | gts = os.listdir(gts_dir) 16 | restoreds = os.listdir(restored_dir) 17 | gts = sorted(gts, key=sort_imgs) 18 | restoreds = sorted(restoreds, key=sort_imgs) 19 | gts = [os.path.join(gts_dir, i) for i in gts] 20 | restoreds = [os.path.join(restored_dir, j) for j in restoreds] 21 | psnr, ssim = main3(gts, restoreds, test_y_channel=True, crop_border=scale) 22 | print(f"{restored_dir} x{scale} \n PSNR AVG: {np.mean(psnr)} PSNR STD: {np.std(psnr)} \n SSIM AVG: {np.mean(ssim)} SSIM STD: {np.std(ssim)}") 23 | 24 | def get_psnr_and_ssim_for_bicubic(): 25 | # run this on 172.17.27.170 26 | 27 | bicubic_root_dir_3dircadb_x2 = "/home/zhiyi/data/3dircadb/img/lr_ld_bicubic/x2/test" 28 | bicubic_root_dir_3dircadb_x4 = "/home/zhiyi/data/3dircadb/img/lr_ld_bicubic/x4/test" 29 | bicubic_root_dir_pancreas_x2 = "/home/zhiyi/data/pancreas/img/lr_ld_bicubic/x2/test" 30 | bicubic_root_dir_pancreas_x4 = "/home/zhiyi/data/pancreas/img/lr_ld_bicubic/x4/test" 31 | 32 | gt_dir_3dircadb = "/home/zhiyi/data/3dircadb/img/hr_nd/test" 33 | gt_dir_pancreas = "/home/zhiyi/data/pancreas/img/hr_nd/test" 34 | 35 | tasks = [] 36 | tasks.append( 37 | Process(target=_get_psnr_and_ssim_for_bicubic, args=(gt_dir_3dircadb, bicubic_root_dir_3dircadb_x2, 2)) 38 | ) 39 | tasks.append( 40 | Process(target=_get_psnr_and_ssim_for_bicubic, args=(gt_dir_3dircadb, bicubic_root_dir_3dircadb_x4, 4)) 41 | ) 42 | tasks.append( 43 | Process(target=_get_psnr_and_ssim_for_bicubic, args=(gt_dir_pancreas, bicubic_root_dir_pancreas_x2, 2)) 44 | ) 45 | tasks.append( 46 | Process(target=_get_psnr_and_ssim_for_bicubic, args=(gt_dir_pancreas, bicubic_root_dir_pancreas_x4, 4)) 47 | ) 48 | 49 | for task in tasks: 50 | task.start() 51 | 52 | if __name__ == '__main__': 53 | get_psnr_and_ssim_for_bicubic() -------------------------------------------------------------------------------- /basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import requests 3 | from tqdm import tqdm 4 | 5 | from .misc import sizeof_fmt 6 | 7 | 8 | def download_file_from_google_drive(file_id, save_path): 9 | """Download files from Google Drive.
10 | 11 | Ref: 12 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 13 | 14 | Args: 15 | file_id (str): File id. 16 | save_path (str): Save path. 17 | """ 18 | 19 | session = requests.Session() 20 | URL = 'https://docs.google.com/uc?export=download' 21 | params = {'id': file_id} 22 | 23 | response = session.get(URL, params=params, stream=True) 24 | token = get_confirm_token(response) 25 | if token: 26 | params['confirm'] = token 27 | response = session.get(URL, params=params, stream=True) 28 | 29 | # get file size 30 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 31 | if 'Content-Range' in response_file_size.headers: 32 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) 33 | else: 34 | file_size = None 35 | 36 | save_response_content(response, save_path, file_size) 37 | 38 | 39 | def get_confirm_token(response): 40 | for key, value in response.cookies.items(): 41 | if key.startswith('download_warning'): 42 | return value 43 | return None 44 | 45 | 46 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 47 | if file_size is not None: 48 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 49 | 50 | readable_file_size = sizeof_fmt(file_size) 51 | else: 52 | pbar = None 53 | 54 | with open(destination, 'wb') as f: 55 | downloaded_size = 0 56 | for chunk in response.iter_content(chunk_size): 57 | downloaded_size += len(chunk) 58 | if pbar is not None: 59 | pbar.update(1) 60 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') 61 | if chunk: # filter out keep-alive new chunks 62 | f.write(chunk) 63 | if pbar is not None: 64 | pbar.close() 65 | -------------------------------------------------------------------------------- /basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry: 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj): 39 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 40 | f"in '{self._name}' registry!") 41 | self._obj_map[name] = obj 42 | 43 | 44 | def register(self, obj=None): 45 | """ 46 | Register the given object under the name `obj.__name__`. 47 | Can be used as either a decorator or not. 48 | See docstring of this class for usage.
49 | """ 50 | if obj is None: 51 | # 如果是空的,register等价于deco(是一个装饰器,用来登记修饰的函数) 52 | # used as a decorator 53 | def deco(func_or_class): 54 | name = func_or_class.__name__ 55 | self._do_register(name, func_or_class) # 使self._obj_map[name] = func_or_class 如果已经有了会报错 56 | return func_or_class 57 | 58 | return deco 59 | 60 | # used as a function call 61 | name = obj.__name__ 62 | self._do_register(name, obj) 63 | 64 | def get(self, name): 65 | ret = self._obj_map.get(name) 66 | if ret is None: 67 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 68 | return ret 69 | 70 | def __contains__(self, name): 71 | return name in self._obj_map 72 | 73 | def __iter__(self): 74 | return iter(self._obj_map.items()) 75 | 76 | def keys(self): 77 | return self._obj_map.keys() 78 | 79 | 80 | DATASET_REGISTRY = Registry('dataset') 81 | ARCH_REGISTRY = Registry('arch') 82 | MODEL_REGISTRY = Registry('model') 83 | LOSS_REGISTRY = Registry('loss') 84 | METRIC_REGISTRY = Registry('metric') 85 | -------------------------------------------------------------------------------- /scripts/matlab_scripts/generate_bicubic_img.m: -------------------------------------------------------------------------------- 1 | function generate_bicubic_img() 2 | %% matlab code to genetate mod images, bicubic-downsampled images and 3 | %% bicubic_upsampled images 4 | 5 | %% set configurations 6 | % comment the unnecessary lines 7 | input_folder = '../../datasets/Set5/original'; 8 | save_mod_folder = '../../datasets/Set5/GTmod12'; 9 | save_lr_folder = '../../datasets/Set5/LRbicx2'; 10 | % save_bic_folder = ''; 11 | 12 | mod_scale = 12; 13 | up_scale = 2; 14 | 15 | if exist('save_mod_folder', 'var') 16 | if exist(save_mod_folder, 'dir') 17 | disp(['It will cover ', save_mod_folder]); 18 | else 19 | mkdir(save_mod_folder); 20 | end 21 | end 22 | if exist('save_lr_folder', 'var') 23 | if exist(save_lr_folder, 'dir') 24 | disp(['It will cover ', save_lr_folder]); 25 | else 26 | mkdir(save_lr_folder); 27 | end 28 | end 29 | if exist('save_bic_folder', 'var') 30 | if exist(save_bic_folder, 'dir') 31 | disp(['It will cover ', save_bic_folder]); 32 | else 33 | mkdir(save_bic_folder); 34 | end 35 | end 36 | 37 | idx = 0; 38 | filepaths = dir(fullfile(input_folder,'*.*')); 39 | for i = 1 : length(filepaths) 40 | [paths, img_name, ext] = fileparts(filepaths(i).name); 41 | if isempty(img_name) 42 | disp('Ignore . folder.'); 43 | elseif strcmp(img_name, '.') 44 | disp('Ignore .. 
folder.'); 45 | else 46 | idx = idx + 1; 47 | str_result = sprintf('%d\t%s.\n', idx, img_name); 48 | fprintf(str_result); 49 | 50 | % read image 51 | img = imread(fullfile(input_folder, [img_name, ext])); 52 | img = im2double(img); 53 | 54 | % modcrop 55 | img = modcrop(img, mod_scale); 56 | if exist('save_mod_folder', 'var') 57 | imwrite(img, fullfile(save_mod_folder, [img_name, '.png'])); 58 | end 59 | 60 | % LR 61 | im_lr = imresize(img, 1/up_scale, 'bicubic'); 62 | if exist('save_lr_folder', 'var') 63 | imwrite(im_lr, fullfile(save_lr_folder, [img_name, '.png'])); 64 | end 65 | 66 | % Bicubic 67 | if exist('save_bic_folder', 'var') 68 | im_bicubic = imresize(im_lr, up_scale, 'bicubic'); 69 | imwrite(im_bicubic, fullfile(save_bic_folder, [img_name, '.png'])); 70 | end 71 | end 72 | end 73 | end 74 | 75 | %% modcrop 76 | function img = modcrop(img, modulo) 77 | if size(img,3) == 1 78 | sz = size(img); 79 | sz = sz - mod(sz, modulo); 80 | img = img(1:sz(1), 1:sz(2)); 81 | else 82 | tmpsz = size(img); 83 | sz = tmpsz(1:2); 84 | sz = sz - mod(sz, modulo); 85 | img = img(1:sz(1), 1:sz(2),:); 86 | end 87 | end 88 | -------------------------------------------------------------------------------- /basicsr/utils/img_process_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def filter2D(img, kernel): 8 | """PyTorch version of cv2.filter2D 9 | 10 | Args: 11 | img (Tensor): (b, c, h, w) 12 | kernel (Tensor): (b, k, k) 13 | """ 14 | k = kernel.size(-1) 15 | b, c, h, w = img.size() 16 | if k % 2 == 1: 17 | img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') 18 | else: 19 | raise ValueError('Wrong kernel size') 20 | 21 | ph, pw = img.size()[-2:] 22 | 23 | if kernel.size(0) == 1: 24 | # apply the same kernel to all batch images 25 | img = img.view(b * c, 1, ph, pw) 26 | kernel = kernel.view(1, 1, k, k) 27 | return F.conv2d(img, kernel, padding=0).view(b, c, h, w) 28 | else: 29 | img = img.view(1, b * c, ph, pw) 30 | kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) 31 | return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) 32 | 33 | 34 | def usm_sharp(img, weight=0.5, radius=50, threshold=10): 35 | """USM sharpening. 36 | 37 | Input image: I; Blurry image: B. 38 | 1. sharp = I + weight * (I - B) 39 | 2. Mask = 1 if abs(I - B) > threshold, else: 0 40 | 3. SoftMask = Blur(Mask) 41 | 4. Out = SoftMask * sharp + (1 - SoftMask) * I 42 | 43 | 44 | Args: 45 | img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. 46 | weight (float): Sharp weight. Default: 0.5. 47 | radius (float): Kernel size of Gaussian blur. Default: 50.
48 | threshold (int): Residual threshold on a 0-255 scale; pixels with abs(I - B) * 255 > threshold are sharpened. Default: 10. 49 | """ 50 | if radius % 2 == 0: 51 | radius += 1 52 | blur = cv2.GaussianBlur(img, (radius, radius), 0) 53 | residual = img - blur 54 | mask = np.abs(residual) * 255 > threshold 55 | mask = mask.astype('float32') 56 | soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) 57 | 58 | sharp = img + weight * residual 59 | sharp = np.clip(sharp, 0, 1) 60 | return soft_mask * sharp + (1 - soft_mask) * img 61 | 62 | 63 | class USMSharp(torch.nn.Module): 64 | 65 | def __init__(self, radius=50, sigma=0): 66 | super(USMSharp, self).__init__() 67 | if radius % 2 == 0: 68 | radius += 1 69 | self.radius = radius 70 | kernel = cv2.getGaussianKernel(radius, sigma) 71 | kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) 72 | self.register_buffer('kernel', kernel) 73 | 74 | def forward(self, img, weight=0.5, threshold=10): 75 | blur = filter2D(img, self.kernel) 76 | residual = img - blur 77 | 78 | mask = torch.abs(residual) * 255 > threshold 79 | mask = mask.float() 80 | soft_mask = filter2D(mask, self.kernel) 81 | sharp = img + weight * residual 82 | sharp = torch.clip(sharp, 0, 1) 83 | return soft_mask * sharp + (1 - soft_mask) * img 84 | -------------------------------------------------------------------------------- /basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None.
38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') 45 | # specify master port 46 | if port is not None: 47 | os.environ['MASTER_PORT'] = str(port) 48 | elif 'MASTER_PORT' in os.environ: 49 | pass # use MASTER_PORT in the environment variable 50 | else: 51 | # 29500 is torch.distributed default port 52 | os.environ['MASTER_PORT'] = '29500' 53 | os.environ['MASTER_ADDR'] = addr 54 | os.environ['WORLD_SIZE'] = str(ntasks) 55 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 56 | os.environ['RANK'] = str(proc_id) 57 | dist.init_process_group(backend=backend) 58 | 59 | 60 | def get_dist_info(): 61 | if dist.is_available(): 62 | initialized = dist.is_initialized() 63 | else: 64 | initialized = False 65 | if initialized: 66 | rank = dist.get_rank() 67 | world_size = dist.get_world_size() 68 | else: 69 | rank = 0 70 | world_size = 1 71 | return rank, world_size 72 | 73 | 74 | def master_only(func): 75 | 76 | @functools.wraps(func) 77 | def wrapper(*args, **kwargs): 78 | rank, _ = get_dist_info() 79 | if rank == 0: 80 | return func(*args, **kwargs) 81 | 82 | return wrapper 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # README 2 | **[MICCAI2023]** The official implementation of **Low-dose CT image super-resolution network with dual-guidance feature distillation and dual-path content communication**. 3 | 4 | This repository is modified from [BasicSR](https://github.com/XPixelGroup/BasicSR). Thanks for the open source code of [BasicSR](https://github.com/XPixelGroup/BasicSR). 5 | ## Installation 6 | ```bash 7 | conda create -n new_env python=3.9.7 -y 8 | conda activate new_env 9 | pip install -r requirements.txt 10 | pip install -e . 11 | ``` 12 | More details could be found in [the installation ducoment of BasicSR](https://github.com/XPixelGroup/BasicSR/blob/master/docs/INSTALL.md). 13 | ## Data preparation 14 | You should prepare your data in this way: 15 | ``` 16 | data_rootdir 17 | - dataset_name 18 | - img 19 | - hr_nd 20 | - train 21 | - val 22 | - test 23 | - lr_ld 24 | - x2 25 | - train 26 | - train_avg 27 | - val 28 | - val_avg 29 | - test 30 | - test_avg 31 | - x4 32 | - train 33 | - train_avg 34 | - val 35 | - val_avg 36 | - test 37 | - test_avg 38 | - lr_nd 39 | - x2 40 | - train 41 | - val 42 | - test 43 | - x4 44 | - train 45 | - val 46 | - test 47 | -mask 48 | - hr 49 | - train 50 | - val 51 | - test 52 | - x2 53 | - train 54 | - val 55 | - test 56 | - x4 57 | - train 58 | - val 59 | - test 60 | ``` 61 | And you should modify the path in configuration files in "opations/train/\*.yml" or "opations/test/\*.yml". 62 | ## Training 63 | Run: 64 | ```bash 65 | python basicsr/train.py --opt options/train/your_config_file.yml 66 | ``` 67 | The model files will be saved in "experiments" folder. 68 | ## Testing 69 | Firstly, you should modify the model paths in "opations/test/\*.yml". 70 | Then, run: 71 | ```bash 72 | python basicsr/test.py --opt options/test/your_config_file.yml 73 | ``` 74 | The results will be saved in "results" folder. 75 | 76 | An example, including models and dataset, could be found in [BaiduDisk:z3gy](https://pan.baidu.com/s/1l7lXLCOJWeQVZOt_ldtGbg). 
77 | 78 | ## Cite 79 | ``` 80 | @inproceedings{chi2023low, 81 | title={Low-Dose CT Image Super-Resolution Network with Dual-Guidance Feature Distillation and Dual-Path Content Communication}, 82 | author={Chi, Jianning and Sun, Zhiyi and Zhao, Tianli and Wang, Huan and Yu, Xiaosheng and Wu, Chengdong}, 83 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 84 | pages={98--108}, 85 | year={2023}, 86 | organization={Springer} 87 | } 88 | ``` 89 | -------------------------------------------------------------------------------- /options/train/train_mask_guided_jdnsr_x2_3dircadb.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: train_mask_guided_jdnsr_x2_3dircadb 3 | model_type: Mask_Guided_Model 4 | num_gpu: 1 # set num_gpu: 0 for cpu mode 5 | manual_seed: 10 6 | scale: 2 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: 3dircadb_train 12 | type: PairedMASKDataset 13 | dataroot_gt: /home/zhiyi/data/3dircadb/img/hr_nd/train 14 | dataroot_gt_lr: /home/zhiyi/data/3dircadb/img/lr_nd/x2/train 15 | dataroot_lq: /home/zhiyi/data/3dircadb/img/lr_ld/x2/train 16 | dataroot_mask: /home/zhiyi/data/3dircadb/mask/x2/train 17 | dataroot_avg_ct: /home/zhiyi/data/3dircadb/img/lr_ld/x2/train_avg 18 | 19 | filename_tmpl: '{}' 20 | io_backend: 21 | type: disk 22 | 23 | gt_size: 128 24 | use_flip: true 25 | use_rot: true 26 | 27 | # data loader 28 | use_shuffle: true 29 | num_worker_per_gpu: 6 30 | batch_size_per_gpu: 16 31 | dataset_enlarge_ratio: 100 32 | prefetch_mode: ~ 33 | 34 | val: 35 | name: 3dircadb_val 36 | type: PairedMASKDataset 37 | dataroot_gt: /home/zhiyi/data/3dircadb/img/hr_nd/val 38 | dataroot_gt_lr: /home/zhiyi/data/3dircadb/img/lr_nd/x2/val 39 | dataroot_lq: /home/zhiyi/data/3dircadb/img/lr_ld/x2/val 40 | dataroot_mask: /home/zhiyi/data/3dircadb/mask/x2/val 41 | dataroot_avg_ct: /home/zhiyi/data/3dircadb/img/lr_ld/x2/val_avg 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | network_g: 47 | type: MaskGuidedJDNSR 48 | scale: 2 49 | num_feat: 64 50 | mode: jdnsr 51 | num_block: 10 52 | 53 | # path 54 | path: 55 | pretrain_network_g: ~ 56 | strict_load_g: true 57 | resume_state: ~ 58 | 59 | # training settings 60 | train: 61 | ema_decay: 0 62 | optim_g: 63 | type: Adam 64 | lr: !!float 1e-4 65 | weight_decay: 0 66 | betas: [0.9, 0.99] 67 | 68 | scheduler: 69 | type: MultiStepLR 70 | milestones: [200000, 400000] 71 | gamma: 0.5 72 | 73 | total_iter: 250000 74 | warmup_iter: -1 # no warm up 75 | 76 | # losses 77 | hrLD_pixel_opt: 78 | type: L1Loss 79 | loss_weight: 0.2 80 | reduction: mean 81 | 82 | LRnd_pixel_opt: 83 | type: L1Loss 84 | loss_weight: 0.2 85 | reduction: mean 86 | 87 | hrnd_pixel_opt: 88 | type: L1Loss 89 | loss_weight: 1.0 90 | 91 | # tv_opt: 92 | # type: WeightedTVLoss 93 | # loss_weight: 0.2 94 | 95 | # validation settings 96 | val: 97 | val_freq: !!float 1e4 98 | save_img: false 99 | 100 | metrics: 101 | psnr: # metric name, can be arbitrary 102 | type: calculate_psnr 103 | crop_border: 2 104 | test_y_channel: true 105 | 106 | ssim: # metric name, can be arbitrary 107 | type: calculate_ssim 108 | crop_border: 2 109 | test_y_channel: true 110 | 111 | # logging settings 112 | logger: 113 | print_freq: 100 114 | save_checkpoint_freq: !!float 1e4 115 | use_tb_logger: true 116 | wandb: 117 | project: ~ 118 | resume_id: ~ 119 | 120 | # dist training settings 121 | dist_params: 122 | backend: nccl 123 | port: 29500 124 | 
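# A sketch of how to resume an interrupted run (the iteration number is a
# placeholder; BasicSR-style trainers save state files under
# experiments/<name>/training_states/):
#
# path:
#   pretrain_network_g: ~
#   strict_load_g: true
#   resume_state: experiments/train_mask_guided_jdnsr_x2_3dircadb/training_states/10000.state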
-------------------------------------------------------------------------------- /options/train/train_mask_guided_jdnsr_x2_pancreas.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: train_mask_guided_jdnsr_x2_pancreas 3 | model_type: Mask_Guided_Model 4 | num_gpu: 1 # set num_gpu: 0 for cpu mode 5 | manual_seed: 10 6 | scale: 2 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: pancreas_train 12 | type: PairedMASKDataset 13 | dataroot_gt: /home/zhiyi/data/pancreas/img/hr_nd/train 14 | dataroot_gt_lr: /home/zhiyi/data/pancreas/img/lr_nd/x2/train 15 | dataroot_lq: /home/zhiyi/data/pancreas/img/lr_ld/x2/train 16 | dataroot_mask: /home/zhiyi/data/pancreas/mask/x2/train 17 | dataroot_avg_ct: /home/zhiyi/data/pancreas/img/lr_ld/x2/train_avg 18 | 19 | filename_tmpl: '{}' 20 | io_backend: 21 | type: disk 22 | 23 | gt_size: 128 24 | use_flip: true 25 | use_rot: true 26 | 27 | # data loader 28 | use_shuffle: true 29 | num_worker_per_gpu: 6 30 | batch_size_per_gpu: 16 31 | dataset_enlarge_ratio: 100 32 | prefetch_mode: ~ 33 | 34 | val: 35 | name: pancreas_val 36 | type: PairedMASKDataset 37 | dataroot_gt: /home/zhiyi/data/pancreas/img/hr_nd/val 38 | dataroot_gt_lr: /home/zhiyi/data/pancreas/img/lr_nd/x2/val 39 | dataroot_lq: /home/zhiyi/data/pancreas/img/lr_ld/x2/val 40 | dataroot_mask: /home/zhiyi/data/pancreas/mask/x2/val 41 | dataroot_avg_ct: /home/zhiyi/data/pancreas/img/lr_ld/x2/val_avg 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | network_g: 47 | type: MaskGuidedJDNSR 48 | scale: 2 49 | num_feat: 64 50 | mode: jdnsr 51 | num_block: 10 52 | 53 | # path 54 | path: 55 | pretrain_network_g: ~ 56 | strict_load_g: true 57 | resume_state: ~ 58 | 59 | # training settings 60 | train: 61 | ema_decay: 0 62 | optim_g: 63 | type: Adam 64 | lr: !!float 1e-4 65 | weight_decay: 0 66 | betas: [0.9, 0.99] 67 | 68 | scheduler: 69 | type: MultiStepLR 70 | milestones: [200000, 400000] 71 | gamma: 0.5 72 | 73 | total_iter: 250000 74 | warmup_iter: -1 # no warm up 75 | 76 | # losses 77 | hrLD_pixel_opt: 78 | type: L1Loss 79 | loss_weight: 0.2 80 | reduction: mean 81 | 82 | LRnd_pixel_opt: 83 | type: L1Loss 84 | loss_weight: 0.2 85 | reduction: mean 86 | 87 | hrnd_pixel_opt: 88 | type: L1Loss 89 | loss_weight: 1.0 90 | 91 | # tv_opt: 92 | # type: WeightedTVLoss 93 | # loss_weight: 0.2 94 | 95 | # validation settings 96 | val: 97 | val_freq: !!float 1e4 98 | save_img: false 99 | 100 | metrics: 101 | psnr: # metric name, can be arbitrary 102 | type: calculate_psnr 103 | crop_border: 2 104 | test_y_channel: true 105 | 106 | ssim: # metric name, can be arbitrary 107 | type: calculate_ssim 108 | crop_border: 2 109 | test_y_channel: true 110 | 111 | # logging settings 112 | logger: 113 | print_freq: 100 114 | save_checkpoint_freq: !!float 1e4 115 | use_tb_logger: true 116 | wandb: 117 | project: ~ 118 | resume_id: ~ 119 | 120 | # dist training settings 121 | dist_params: 122 | backend: nccl 123 | port: 29500 124 | -------------------------------------------------------------------------------- /options/train/train_mask_guided_jdnsr_x4_3dircadb.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: train_mask_guided_jdnsr_x4_3dircadb 3 | model_type: Mask_Guided_Model 4 | num_gpu: 1 # set num_gpu: 0 for cpu mode 5 | manual_seed: 10 6 | scale: 4 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: 3dircadb_train 12 | 
type: PairedMASKDataset 13 | dataroot_gt: /home/zhiyi/data/3dircadb/img/hr_nd/train 14 | dataroot_gt_lr: /home/zhiyi/data/3dircadb/img/lr_nd/x4/train 15 | dataroot_lq: /home/zhiyi/data/3dircadb/img/lr_ld/x4/train 16 | dataroot_mask: /home/zhiyi/data/3dircadb/mask/x4/train 17 | dataroot_avg_ct: /home/zhiyi/data/3dircadb/img/lr_ld/x4/train_avg 18 | 19 | filename_tmpl: '{}' 20 | io_backend: 21 | type: disk 22 | 23 | gt_size: 128 24 | use_flip: true 25 | use_rot: true 26 | 27 | # data loader 28 | use_shuffle: true 29 | num_worker_per_gpu: 6 30 | batch_size_per_gpu: 16 31 | dataset_enlarge_ratio: 100 32 | prefetch_mode: ~ 33 | 34 | val: 35 | name: 3dircadb_val 36 | type: PairedMASKDataset 37 | dataroot_gt: /home/zhiyi/data/3dircadb/img/hr_nd/val 38 | dataroot_gt_lr: /home/zhiyi/data/3dircadb/img/lr_nd/x4/val 39 | dataroot_lq: /home/zhiyi/data/3dircadb/img/lr_ld/x4/val 40 | dataroot_mask: /home/zhiyi/data/3dircadb/mask/x4/val 41 | dataroot_avg_ct: /home/zhiyi/data/3dircadb/img/lr_ld/x4/val_avg 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | network_g: 47 | type: MaskGuidedJDNSR 48 | scale: 4 49 | num_feat: 64 50 | mode: jdnsr 51 | num_block: 10 52 | 53 | # path 54 | path: 55 | pretrain_network_g: ~ 56 | strict_load_g: true 57 | resume_state: ~ 58 | 59 | # training settings 60 | train: 61 | ema_decay: 0 62 | optim_g: 63 | type: Adam 64 | lr: !!float 1e-4 65 | weight_decay: 0 66 | betas: [0.9, 0.99] 67 | 68 | scheduler: 69 | type: MultiStepLR 70 | milestones: [200000, 400000] 71 | gamma: 0.5 72 | 73 | total_iter: 250000 74 | warmup_iter: -1 # no warm up 75 | 76 | # losses 77 | hrLD_pixel_opt: 78 | type: L1Loss 79 | loss_weight: 0.2 80 | reduction: mean 81 | 82 | LRnd_pixel_opt: 83 | type: L1Loss 84 | loss_weight: 0.2 85 | reduction: mean 86 | 87 | hrnd_pixel_opt: 88 | type: L1Loss 89 | loss_weight: 1.0 90 | 91 | # tv_opt: 92 | # type: WeightedTVLoss 93 | # loss_weight: 0.2 94 | 95 | # validation settings 96 | val: 97 | val_freq: !!float 1e4 98 | save_img: false 99 | 100 | metrics: 101 | psnr: # metric name, can be arbitrary 102 | type: calculate_psnr 103 | crop_border: 4 104 | test_y_channel: true 105 | 106 | ssim: # metric name, can be arbitrary 107 | type: calculate_ssim 108 | crop_border: 4 109 | test_y_channel: true 110 | 111 | # logging settings 112 | logger: 113 | print_freq: 100 114 | save_checkpoint_freq: !!float 1e4 115 | use_tb_logger: true 116 | wandb: 117 | project: ~ 118 | resume_id: ~ 119 | 120 | # dist training settings 121 | dist_params: 122 | backend: nccl 123 | port: 29500 124 | -------------------------------------------------------------------------------- /options/train/train_mask_guided_jdnsr_x4_pancreas.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: train_mask_guided_jdnsr_x4_pancreas 3 | model_type: Mask_Guided_Model 4 | num_gpu: 1 # set num_gpu: 0 for cpu mode 5 | manual_seed: 10 6 | scale: 4 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: pancreas_train 12 | type: PairedMASKDataset 13 | dataroot_gt: /home/zhiyi/data/pancreas/img/hr_nd/train 14 | dataroot_gt_lr: /home/zhiyi/data/pancreas/img/lr_nd/x4/train 15 | dataroot_lq: /home/zhiyi/data/pancreas/img/lr_ld/x4/train 16 | dataroot_mask: /home/zhiyi/data/pancreas/mask/x4/train 17 | dataroot_avg_ct: /home/zhiyi/data/pancreas/img/lr_ld/x4/train_avg 18 | 19 | filename_tmpl: '{}' 20 | io_backend: 21 | type: disk 22 | 23 | gt_size: 128 24 | use_flip: true 25 | use_rot: true 26 | 27 | # data loader 28 | 
use_shuffle: true 29 | num_worker_per_gpu: 6 30 | batch_size_per_gpu: 16 31 | dataset_enlarge_ratio: 100 32 | prefetch_mode: ~ 33 | 34 | val: 35 | name: pancreas_val 36 | type: PairedMASKDataset 37 | dataroot_gt: /home/zhiyi/data/pancreas/img/hr_nd/val 38 | dataroot_gt_lr: /home/zhiyi/data/pancreas/img/lr_nd/x4/val 39 | dataroot_lq: /home/zhiyi/data/pancreas/img/lr_ld/x4/val 40 | dataroot_mask: /home/zhiyi/data/pancreas/mask/x4/val 41 | dataroot_avg_ct: /home/zhiyi/data/pancreas/img/lr_ld/x4/val_avg 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | network_g: 47 | type: MaskGuidedJDNSR 48 | scale: 4 49 | num_feat: 64 50 | mode: jdnsr 51 | num_block: 10 52 | 53 | # path 54 | path: 55 | pretrain_network_g: ~ 56 | strict_load_g: true 57 | resume_state: ~ 58 | 59 | # training settings 60 | train: 61 | ema_decay: 0 62 | optim_g: 63 | type: Adam 64 | lr: !!float 1e-4 65 | weight_decay: 0 66 | betas: [0.9, 0.99] 67 | 68 | scheduler: 69 | type: MultiStepLR 70 | milestones: [200000, 400000] 71 | gamma: 0.5 72 | 73 | total_iter: 250000 74 | warmup_iter: -1 # no warm up 75 | 76 | # losses 77 | hrLD_pixel_opt: 78 | type: L1Loss 79 | loss_weight: 0.2 80 | reduction: mean 81 | 82 | LRnd_pixel_opt: 83 | type: L1Loss 84 | loss_weight: 0.2 85 | reduction: mean 86 | 87 | hrnd_pixel_opt: 88 | type: L1Loss 89 | loss_weight: 1.0 90 | 91 | # tv_opt: 92 | # type: WeightedTVLoss 93 | # loss_weight: 0.2 94 | 95 | # validation settings 96 | val: 97 | val_freq: !!float 1e4 98 | save_img: false 99 | 100 | metrics: 101 | psnr: # metric name, can be arbitrary 102 | type: calculate_psnr 103 | crop_border: 4 104 | test_y_channel: true 105 | 106 | ssim: # metric name, can be arbitrary 107 | type: calculate_ssim 108 | crop_border: 4 109 | test_y_channel: true 110 | 111 | # logging settings 112 | logger: 113 | print_freq: 100 114 | save_checkpoint_freq: !!float 1e4 115 | use_tb_logger: true 116 | wandb: 117 | project: ~ 118 | resume_id: ~ 119 | 120 | # dist training settings 121 | dist_params: 122 | backend: nccl 123 | port: 29500 124 | -------------------------------------------------------------------------------- /scripts/model_conversion/convert_dfdnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from basicsr.archs.dfdnet_arch import DFDNet 4 | from basicsr.archs.vgg_arch import NAMES 5 | 6 | 7 | def convert_net(ori_net, crt_net): 8 | 9 | for crt_k, _ in crt_net.items(): 10 | # vgg feature extractor 11 | if 'vgg_extractor' in crt_k: 12 | ori_k = crt_k.replace('vgg_extractor', 'VggExtract').replace('vgg_net', 'model') 13 | if 'mean' in crt_k: 14 | ori_k = ori_k.replace('mean', 'RGB_mean') 15 | elif 'std' in crt_k: 16 | ori_k = ori_k.replace('std', 'RGB_std') 17 | else: 18 | idx = NAMES['vgg19'].index(crt_k.split('.')[2]) 19 | if 'weight' in crt_k: 20 | ori_k = f'VggExtract.model.features.{idx}.weight' 21 | else: 22 | ori_k = f'VggExtract.model.features.{idx}.bias' 23 | elif 'attn_blocks' in crt_k: 24 | if 'left_eye' in crt_k: 25 | ori_k = crt_k.replace('attn_blocks.left_eye', 'le') 26 | elif 'right_eye' in crt_k: 27 | ori_k = crt_k.replace('attn_blocks.right_eye', 're') 28 | elif 'mouth' in crt_k: 29 | ori_k = crt_k.replace('attn_blocks.mouth', 'mo') 30 | elif 'nose' in crt_k: 31 | ori_k = crt_k.replace('attn_blocks.nose', 'no') 32 | else: 33 | raise ValueError('Wrong!') 34 | elif 'multi_scale_dilation' in crt_k: 35 | if 'conv_blocks' in crt_k: 36 | _, _, c, d, e = crt_k.split('.') 37 | ori_k = f'MSDilate.conv{int(c)+1}.{d}.{e}' 38 | 
else: 39 | ori_k = crt_k.replace('multi_scale_dilation.conv_fusion', 'MSDilate.convi') 40 | 41 | elif crt_k.startswith('upsample'): 42 | ori_k = crt_k.replace('upsample', 'up') 43 | if 'scale_block' in crt_k: 44 | ori_k = ori_k.replace('scale_block', 'ScaleModel1') 45 | elif 'shift_block' in crt_k: 46 | ori_k = ori_k.replace('shift_block', 'ShiftModel1') 47 | 48 | elif 'upsample4' in crt_k and 'body' in crt_k: 49 | ori_k = ori_k.replace('body', 'Model') 50 | 51 | else: 52 | print('unprocess key: ', crt_k) 53 | 54 | # replace 55 | if crt_net[crt_k].size() != ori_net[ori_k].size(): 56 | raise ValueError('Wrong tensor size: \n' 57 | f'crt_net: {crt_net[crt_k].size()}\n' 58 | f'ori_net: {ori_net[ori_k].size()}') 59 | else: 60 | crt_net[crt_k] = ori_net[ori_k] 61 | 62 | return crt_net 63 | 64 | 65 | if __name__ == '__main__': 66 | ori_net = torch.load('experiments/pretrained_models/DFDNet/DFDNet_official_original.pth') 67 | dfd_net = DFDNet(64, dict_path='experiments/pretrained_models/DFDNet/DFDNet_dict_512.pth') 68 | crt_net = dfd_net.state_dict() 69 | crt_net_params = convert_net(ori_net, crt_net) 70 | 71 | torch.save( 72 | dict(params=crt_net_params), 73 | 'experiments/pretrained_models/DFDNet/DFDNet_official.pth', 74 | _use_new_zipfile_serialization=False) 75 | -------------------------------------------------------------------------------- /basicsr/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from torch.nn import functional as F 3 | 4 | 5 | def reduce_loss(loss, reduction): 6 | """Reduce loss as specified. 7 | 8 | Args: 9 | loss (Tensor): Elementwise loss tensor. 10 | reduction (str): Options are 'none', 'mean' and 'sum'. 11 | 12 | Returns: 13 | Tensor: Reduced loss tensor. 14 | """ 15 | reduction_enum = F._Reduction.get_enum(reduction) 16 | # none: 0, elementwise_mean:1, sum: 2 17 | if reduction_enum == 0: 18 | return loss 19 | elif reduction_enum == 1: 20 | return loss.mean() 21 | else: 22 | return loss.sum() 23 | 24 | 25 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 26 | """Apply element-wise weight and reduce loss. 27 | 28 | Args: 29 | loss (Tensor): Element-wise loss. 30 | weight (Tensor): Element-wise weights. Default: None. 31 | reduction (str): Same as built-in losses of PyTorch. Options are 32 | 'none', 'mean' and 'sum'. Default: 'mean'. 33 | 34 | Returns: 35 | Tensor: Loss values. 36 | """ 37 | # if weight is specified, apply element-wise weight 38 | if weight is not None: 39 | assert weight.dim() == loss.dim() 40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 41 | loss = loss * weight 42 | 43 | # if weight is not specified or reduction is sum, just reduce the loss 44 | if weight is None or reduction == 'sum': 45 | loss = reduce_loss(loss, reduction) 46 | # if reduction is mean, then compute mean over weight region 47 | elif reduction == 'mean': 48 | if weight.size(1) > 1: 49 | weight = weight.sum() 50 | else: 51 | weight = weight.sum() * loss.size(1) 52 | loss = loss.sum() / weight 53 | 54 | return loss 55 | 56 | 57 | def weighted_loss(loss_func): 58 | """Create a weighted version of a given loss function. 59 | 60 | To use this decorator, the loss function must have the signature like 61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 62 | element-wise loss without any reduction. This decorator will add weight 63 | and reduction arguments to the function. 
The decorated function will have 64 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 65 | **kwargs)`. 66 | 67 | :Example: 68 | 69 | >>> import torch 70 | >>> @weighted_loss 71 | >>> def l1_loss(pred, target): 72 | >>> return (pred - target).abs() 73 | 74 | >>> pred = torch.Tensor([0, 2, 3]) 75 | >>> target = torch.Tensor([1, 1, 1]) 76 | >>> weight = torch.Tensor([1, 0, 1]) 77 | 78 | >>> l1_loss(pred, target) 79 | tensor(1.3333) 80 | >>> l1_loss(pred, target, weight) 81 | tensor(1.5000) 82 | >>> l1_loss(pred, target, reduction='none') 83 | tensor([1., 1., 2.]) 84 | >>> l1_loss(pred, target, weight, reduction='sum') 85 | tensor(3.) 86 | """ 87 | 88 | @functools.wraps(loss_func) 89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 90 | # get element-wise loss 91 | loss = loss_func(pred, target, **kwargs) 92 | loss = weight_reduce_loss(loss, weight, reduction) 93 | return loss 94 | 95 | return wrapper 96 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/fused_act.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | from torch.autograd import Function 7 | 8 | BASICSR_JIT = os.getenv('BASICSR_JIT') 9 | if BASICSR_JIT == 'True': 10 | from torch.utils.cpp_extension import load 11 | module_path = os.path.dirname(__file__) 12 | fused_act_ext = load( 13 | 'fused', 14 | sources=[ 15 | os.path.join(module_path, 'src', 'fused_bias_act.cpp'), 16 | os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), 17 | ], 18 | ) 19 | else: 20 | try: 21 | from . import fused_act_ext 22 | except ImportError: 23 | pass 24 | # avoid annoying print output 25 | # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' 26 | # '1. compile with BASICSR_EXT=True. or\n ' 27 | # '2. 
set BASICSR_JIT=True during running') 28 | 29 | 30 | class FusedLeakyReLUFunctionBackward(Function): 31 | 32 | @staticmethod 33 | def forward(ctx, grad_output, out, negative_slope, scale): 34 | ctx.save_for_backward(out) 35 | ctx.negative_slope = negative_slope 36 | ctx.scale = scale 37 | 38 | empty = grad_output.new_empty(0) 39 | 40 | grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) 41 | 42 | dim = [0] 43 | 44 | if grad_input.ndim > 2: 45 | dim += list(range(2, grad_input.ndim)) 46 | 47 | grad_bias = grad_input.sum(dim).detach() 48 | 49 | return grad_input, grad_bias 50 | 51 | @staticmethod 52 | def backward(ctx, gradgrad_input, gradgrad_bias): 53 | out, = ctx.saved_tensors 54 | gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, 55 | ctx.scale) 56 | 57 | return gradgrad_out, None, None, None 58 | 59 | 60 | class FusedLeakyReLUFunction(Function): 61 | 62 | @staticmethod 63 | def forward(ctx, input, bias, negative_slope, scale): 64 | empty = input.new_empty(0) 65 | out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 66 | ctx.save_for_backward(out) 67 | ctx.negative_slope = negative_slope 68 | ctx.scale = scale 69 | 70 | return out 71 | 72 | @staticmethod 73 | def backward(ctx, grad_output): 74 | out, = ctx.saved_tensors 75 | 76 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) 77 | 78 | return grad_input, grad_bias, None, None 79 | 80 | 81 | class FusedLeakyReLU(nn.Module): 82 | 83 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5): 84 | super().__init__() 85 | 86 | self.bias = nn.Parameter(torch.zeros(channel)) 87 | self.negative_slope = negative_slope 88 | self.scale = scale 89 | 90 | def forward(self, input): 91 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 92 | 93 | 94 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): 95 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 96 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu 2 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 3 | // 4 | // This work is made available under the Nvidia Source Code License-NC. 5 | // To view a copy of this license, visit 6 | // https://nvlabs.github.io/stylegan2/license.html 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | 19 | template 20 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 21 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 22 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 23 | 24 | scalar_t zero = 0.0; 25 | 26 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 27 | scalar_t x = p_x[xi]; 28 | 29 | if (use_bias) { 30 | x += p_b[(xi / step_b) % size_b]; 31 | } 32 | 33 | scalar_t ref = use_ref ? 
p_ref[xi] : zero; 34 | 35 | scalar_t y; 36 | 37 | switch (act * 10 + grad) { 38 | default: 39 | case 10: y = x; break; 40 | case 11: y = x; break; 41 | case 12: y = 0.0; break; 42 | 43 | case 30: y = (x > 0.0) ? x : x * alpha; break; 44 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 45 | case 32: y = 0.0; break; 46 | } 47 | 48 | out[xi] = y * scale; 49 | } 50 | } 51 | 52 | 53 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 54 | int act, int grad, float alpha, float scale) { 55 | int curDevice = -1; 56 | cudaGetDevice(&curDevice); 57 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 58 | 59 | auto x = input.contiguous(); 60 | auto b = bias.contiguous(); 61 | auto ref = refer.contiguous(); 62 | 63 | int use_bias = b.numel() ? 1 : 0; 64 | int use_ref = ref.numel() ? 1 : 0; 65 | 66 | int size_x = x.numel(); 67 | int size_b = b.numel(); 68 | int step_b = 1; 69 | 70 | for (int i = 1 + 1; i < x.dim(); i++) { 71 | step_b *= x.size(i); 72 | } 73 | 74 | int loop_x = 4; 75 | int block_size = 4 * 32; 76 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 77 | 78 | auto y = torch::empty_like(x); 79 | 80 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 81 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( 82 | y.data_ptr<scalar_t>(), 83 | x.data_ptr<scalar_t>(), 84 | b.data_ptr<scalar_t>(), 85 | ref.data_ptr<scalar_t>(), 86 | act, 87 | grad, 88 | alpha, 89 | scale, 90 | loop_x, 91 | size_x, 92 | step_b, 93 | size_b, 94 | use_bias, 95 | use_ref 96 | ); 97 | }); 98 | 99 | return y; 100 | } 101 | -------------------------------------------------------------------------------- /scripts/data_preparation/generate_lr_images.py: -------------------------------------------------------------------------------- 1 | # Generate bicubic downsampled images. 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | import tqdm 6 | 7 | 8 | def main(data_dir="/home/zhiyi/data/medical/belly", hr_name="hr", lr_name="x2", scale=2, lr_name_template=None): 9 | # Scale factor 10 | hr_dir = os.path.join(data_dir, hr_name) 11 | assert os.path.isdir(hr_dir), "HR directory does not exist" 12 | lr_dir = os.path.join(data_dir, lr_name) 13 | if not os.path.isdir(lr_dir): 14 | os.makedirs(lr_dir) 15 | hr_list = os.listdir(hr_dir) 16 | 17 | for i in tqdm.tqdm(range(len(hr_list))): 18 | # print(f"Image {i}") 19 | img_hr = os.path.join(hr_dir, hr_list[i]) 20 | img_hr = Image.open(img_hr) 21 | dsize = (img_hr.size[0] // scale, img_hr.size[1] // scale) 22 | img_hr = np.array(img_hr) 23 | if len(img_hr.shape) == 3: 24 | img_hr = img_hr[:img_hr.shape[0]//scale*scale, :img_hr.shape[1]//scale*scale, :] 25 | else: 26 | img_hr = img_hr[:img_hr.shape[0] // scale * scale, :img_hr.shape[1] // scale * scale] 27 | img_hr = Image.fromarray(img_hr) 28 | img_lr = img_hr.resize(dsize, Image.BICUBIC) 29 | if lr_name_template is None: 30 | lr_path = os.path.join(lr_dir, hr_list[i]) 31 | else: 32 | b, e = os.path.splitext(hr_list[i]) 33 | b = b.split("_") 34 | t = [] 35 | for part in b: # avoid shadowing the outer loop variable `i` 36 | if part.isdigit(): 37 | t.append(part) 38 | lr_name = lr_name_template.format(*t) 39 | lr_path = os.path.join(lr_dir, lr_name) 40 | img_lr.save(lr_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | # d = ["B100", "manga109", "urban100"] 45 | # for i in [8]: 46 | # main(f"/home/zhiyi/data/Set14", i) 47 | # for s in [2, 3, 4, 8]: 48 | # main(r"/data/zy/DIV2K/val", s) 49 | 50 | # for s in [2, 4, 8]: 51 | # main("/home/zhiyi/data/medical/belly", "hr_mask", f"x{s}_mask", s, "image_{}_{}.jpg") 52 | # main("/home/zhiyi/data/medical/belly", 
"hr_mask_valid", f"x{s}_mask_valid", s, "image_{}_{}.jpg") 53 | 54 | # img_names = os.listdir("/home/zhiyi/data/medical/belly/hr_mask_origin") 55 | # save_dir = "/home/zhiyi/data/medical/belly/hr_mask" 56 | # if not os.path.isdir(save_dir): 57 | # os.makedirs(save_dir) 58 | # for img_name in img_names: 59 | # b, e = os.path.splitext(img_name) 60 | # b = b.split("_") 61 | # t = [] 62 | # for i in b: 63 | # if i.isdigit(): 64 | # t.append(i) 65 | # img = Image.open("/home/zhiyi/data/medical/belly/hr_mask_origin/" + img_name) 66 | # save_path = os.path.join(save_dir, "image_{}_{}.jpg".format(*t)) 67 | # img.save(save_path) 68 | # 69 | # img_names = os.listdir("/home/zhiyi/data/medical/belly/hr_mask_valid_origin") 70 | # save_dir = "/home/zhiyi/data/medical/belly/hr_mask_valid" 71 | # if not os.path.isdir(save_dir): 72 | # os.makedirs(save_dir) 73 | # for img_name in img_names: 74 | # b, e = os.path.splitext(img_name) 75 | # b = b.split("_") 76 | # t = [] 77 | # for i in b: 78 | # if i.isdigit(): 79 | # t.append(i) 80 | # img = Image.open("/home/zhiyi/data/medical/belly/hr_mask_valid_origin/" + img_name) 81 | # save_path = os.path.join(save_dir, "image_{}_{}.jpg".format(*t)) 82 | # img.save(save_path) 83 | 84 | for s in [2, 4]: 85 | for n in ["train", "val", "test"]: 86 | # main("/home/zhiyi/data/3dircadb/img", f"hr_ld/{n}", f"lr_ld/x{s}/{n}", s) 87 | main("/home/zhiyi/data/pancreas/mask_old", f"hr/{n}", f"x{s}/{n}", s) -------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Ref: 11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | 13 | Args: 14 | generator: Python generator. 15 | num_prefetch_queue (int): Number of prefetch queue. 16 | """ 17 | 18 | def __init__(self, generator, num_prefetch_queue): 19 | threading.Thread.__init__(self) 20 | self.queue = Queue.Queue(num_prefetch_queue) 21 | self.generator = generator 22 | self.daemon = True 23 | self.start() 24 | 25 | def run(self): 26 | for item in self.generator: 27 | self.queue.put(item) 28 | self.queue.put(None) 29 | 30 | def __next__(self): 31 | next_item = self.queue.get() 32 | if next_item is None: 33 | raise StopIteration 34 | return next_item 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | 40 | class PrefetchDataLoader(DataLoader): 41 | """Prefetch version of dataloader. 42 | 43 | Ref: 44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 45 | 46 | TODO: 47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 48 | ddp. 49 | 50 | Args: 51 | num_prefetch_queue (int): Number of prefetch queue. 52 | kwargs (dict): Other arguments for dataloader. 53 | """ 54 | 55 | def __init__(self, num_prefetch_queue, **kwargs): 56 | self.num_prefetch_queue = num_prefetch_queue 57 | super(PrefetchDataLoader, self).__init__(**kwargs) 58 | 59 | def __iter__(self): 60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 61 | 62 | 63 | class CPUPrefetcher(): 64 | """CPU prefetcher. 65 | 66 | Args: 67 | loader: Dataloader. 
68 | """ 69 | 70 | def __init__(self, loader): 71 | self.ori_loader = loader 72 | self.loader = iter(loader) 73 | 74 | def next(self): 75 | try: 76 | return next(self.loader) 77 | except StopIteration: 78 | return None 79 | 80 | 81 | def reset(self): 82 | self.loader = iter(self.ori_loader) 83 | 84 | 85 | class CUDAPrefetcher(): 86 | """CUDA prefetcher. 87 | 88 | Ref: 89 | https://github.com/NVIDIA/apex/issues/304# 90 | 91 | It may consums more GPU memory. 92 | 93 | Args: 94 | loader: Dataloader. 95 | opt (dict): Options. 96 | """ 97 | 98 | def __init__(self, loader, opt): 99 | self.ori_loader = loader 100 | self.loader = iter(loader) 101 | self.opt = opt 102 | self.stream = torch.cuda.Stream() 103 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 104 | self.preload() 105 | 106 | def preload(self): 107 | try: 108 | self.batch = next(self.loader) # self.batch is a dict 109 | except StopIteration: 110 | self.batch = None 111 | return None 112 | # put tensors to gpu 113 | with torch.cuda.stream(self.stream): 114 | for k, v in self.batch.items(): 115 | if torch.is_tensor(v): 116 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 117 | 118 | def next(self): 119 | torch.cuda.current_stream().wait_stream(self.stream) 120 | batch = self.batch 121 | self.preload() 122 | return batch 123 | 124 | def reset(self): 125 | self.loader = iter(self.ori_loader) 126 | self.preload() 127 | -------------------------------------------------------------------------------- /scripts/model_conversion/convert_stylegan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator, StyleGAN2Generator 4 | 5 | 6 | def convert_net_g(ori_net, crt_net): 7 | """Convert network generator.""" 8 | 9 | for crt_k, crt_v in crt_net.items(): 10 | if 'style_mlp' in crt_k: 11 | ori_k = crt_k.replace('style_mlp', 'style') 12 | elif 'constant_input.weight' in crt_k: 13 | ori_k = crt_k.replace('constant_input.weight', 'input.input') 14 | # style conv1 15 | elif 'style_conv1.modulated_conv' in crt_k: 16 | ori_k = crt_k.replace('style_conv1.modulated_conv', 'conv1.conv') 17 | elif 'style_conv1' in crt_k: 18 | if crt_v.shape == torch.Size([1]): 19 | ori_k = crt_k.replace('style_conv1', 'conv1.noise') 20 | else: 21 | ori_k = crt_k.replace('style_conv1', 'conv1') 22 | # style conv 23 | elif 'style_convs' in crt_k: 24 | ori_k = crt_k.replace('style_convs', 'convs').replace('modulated_conv', 'conv') 25 | if crt_v.shape == torch.Size([1]): 26 | ori_k = ori_k.replace('.weight', '.noise.weight') 27 | # to_rgb1 28 | elif 'to_rgb1.modulated_conv' in crt_k: 29 | ori_k = crt_k.replace('to_rgb1.modulated_conv', 'to_rgb1.conv') 30 | # to_rgbs 31 | elif 'to_rgbs' in crt_k: 32 | ori_k = crt_k.replace('modulated_conv', 'conv') 33 | elif 'noises' in crt_k: 34 | ori_k = crt_k.replace('.noise', '.noise_') 35 | else: 36 | ori_k = crt_k 37 | 38 | # replace 39 | if crt_net[crt_k].size() != ori_net[ori_k].size(): 40 | raise ValueError('Wrong tensor size: \n' 41 | f'crt_net: {crt_net[crt_k].size()}\n' 42 | f'ori_net: {ori_net[ori_k].size()}') 43 | else: 44 | crt_net[crt_k] = ori_net[ori_k] 45 | 46 | return crt_net 47 | 48 | 49 | def convert_net_d(ori_net, crt_net): 50 | """Convert network discriminator.""" 51 | 52 | for crt_k, _ in crt_net.items(): 53 | if 'conv_body' in crt_k: 54 | ori_k = crt_k.replace('conv_body', 'convs') 55 | else: 56 | ori_k = crt_k 57 | 58 | # replace 59 | if crt_net[crt_k].size() != 
ori_net[ori_k].size(): 60 | raise ValueError('Wrong tensor size: \n' 61 | f'crt_net: {crt_net[crt_k].size()}\n' 62 | f'ori_net: {ori_net[ori_k].size()}') 63 | else: 64 | crt_net[crt_k] = ori_net[ori_k] 65 | return crt_net 66 | 67 | 68 | if __name__ == '__main__': 69 | """Convert official stylegan2 weights from stylegan2-pytorch.""" 70 | 71 | # configuration 72 | ori_net = torch.load('experiments/pretrained_models/stylegan2-ffhq.pth') 73 | save_path_g = 'experiments/pretrained_models/stylegan2_ffhq_config_f_1024_official.pth' # noqa: E501 74 | save_path_d = 'experiments/pretrained_models/stylegan2_ffhq_config_f_1024_discriminator_official.pth' # noqa: E501 75 | out_size = 1024 76 | channel_multiplier = 1 77 | 78 | # convert generator 79 | crt_net = StyleGAN2Generator(out_size, num_style_feat=512, num_mlp=8, channel_multiplier=channel_multiplier) 80 | crt_net = crt_net.state_dict() 81 | 82 | crt_net_params_ema = convert_net_g(ori_net['g_ema'], crt_net) 83 | torch.save(dict(params_ema=crt_net_params_ema, latent_avg=ori_net['latent_avg']), save_path_g) 84 | 85 | # convert discriminator 86 | crt_net = StyleGAN2Discriminator(out_size, channel_multiplier=channel_multiplier) 87 | crt_net = crt_net.state_dict() 88 | 89 | crt_net_params = convert_net_d(ori_net['d'], crt_net) 90 | torch.save(dict(params=crt_net_params), save_path_d) 91 | -------------------------------------------------------------------------------- /basicsr/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class MultiStepRestartLR(_LRScheduler): 7 | """ MultiStep with restarts learning rate scheme. 8 | 9 | Args: 10 | optimizer (torch.nn.optimizer): Torch optimizer. 11 | milestones (list): Iterations that will decrease learning rate. 12 | gamma (float): Decrease ratio. Default: 0.1. 13 | restarts (list): Restart iterations. Default: [0]. 14 | restart_weights (list): Restart weights at each restart iteration. 15 | Default: [1]. 16 | last_epoch (int): Used in _LRScheduler. Default: -1. 17 | """ 18 | 19 | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): 20 | self.milestones = Counter(milestones) 21 | self.gamma = gamma 22 | self.restarts = restarts 23 | self.restart_weights = restart_weights 24 | assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' 25 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 26 | 27 | def get_lr(self): 28 | if self.last_epoch in self.restarts: 29 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 30 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 31 | if self.last_epoch not in self.milestones: 32 | return [group['lr'] for group in self.optimizer.param_groups] 33 | return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] 34 | 35 | 36 | def get_position_from_periods(iteration, cumulative_period): 37 | """Get the position from a period list. 38 | 39 | It will return the index of the right-closest number in the period list. 40 | For example, the cumulative_period = [100, 200, 300, 400], 41 | if iteration == 50, return 0; 42 | if iteration == 210, return 2; 43 | if iteration == 300, return 2. 44 | 45 | Args: 46 | iteration (int): Current iteration. 47 | cumulative_period (list[int]): Cumulative period list. 
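    Example:
        The cases listed above, in doctest form:

        >>> get_position_from_periods(50, [100, 200, 300, 400])
        0
        >>> get_position_from_periods(210, [100, 200, 300, 400])
        2
        >>> get_position_from_periods(300, [100, 200, 300, 400])
        2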
48 | 49 | Returns: 50 | int: The position of the right-closest number in the period list. 51 | """ 52 | for i, period in enumerate(cumulative_period): 53 | if iteration <= period: 54 | return i 55 | 56 | 57 | class CosineAnnealingRestartLR(_LRScheduler): 58 | """ Cosine annealing with restarts learning rate scheme. 59 | 60 | An example of config: 61 | periods = [10, 10, 10, 10] 62 | restart_weights = [1, 0.5, 0.5, 0.5] 63 | eta_min=1e-7 64 | 65 | It has four cycles, each with 10 iterations. At the 10th, 20th and 30th iterations, the 66 | scheduler will restart with the weights in restart_weights. 67 | 68 | Args: 69 | optimizer (torch.nn.optimizer): Torch optimizer. 70 | periods (list): Period for each cosine annealing cycle. 71 | restart_weights (list): Restart weights at each restart iteration. 72 | Default: [1]. 73 | eta_min (float): The minimum lr. Default: 0. 74 | last_epoch (int): Used in _LRScheduler. Default: -1. 75 | """ 76 | 77 | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): 78 | self.periods = periods 79 | self.restart_weights = restart_weights 80 | self.eta_min = eta_min 81 | assert (len(self.periods) == len( 82 | self.restart_weights)), 'periods and restart_weights should have the same length.' 83 | self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] 84 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 85 | 86 | def get_lr(self): 87 | idx = get_position_from_periods(self.last_epoch, self.cumulative_period) 88 | current_weight = self.restart_weights[idx] 89 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 90 | current_period = self.periods[idx] 91 | 92 | return [ 93 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 94 | (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) 95 | for base_lr in self.base_lrs 96 | ] 97 | -------------------------------------------------------------------------------- /basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from copy import deepcopy 7 | from functools import partial 8 | from os import path as osp 9 | 10 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 11 | from basicsr.utils import get_root_logger, scandir 12 | from basicsr.utils.dist_util import get_dist_info 13 | from basicsr.utils.registry import DATASET_REGISTRY 14 | 15 | __all__ = ['build_dataset', 'build_dataloader'] 16 | 17 | # automatically scan and import dataset modules for registry 18 | # scan all the files under the data folder with '_dataset' in file names 19 | data_folder = osp.dirname(osp.abspath(__file__)) 20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 21 | # import all the dataset modules 22 | _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] 23 | 24 | 25 | def build_dataset(dataset_opt): 26 | """Build dataset from options. 27 | 28 | Args: 29 | dataset_opt (dict): Configuration for dataset. It must contain: 30 | name (str): Dataset name. 31 | type (str): Dataset type. 
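    Example:
        A sketch only; a real ``dataset_opt`` carries many more keys (data
        roots, augmentation flags, ...), as in the YAML files under
        options/train above:

        >>> dataset_opt = {'name': '3dircadb_train', 'type': 'PairedMASKDataset', 'phase': 'train'}
        >>> dataset = build_dataset(dataset_opt)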
32 | """ 33 | dataset_opt = deepcopy(dataset_opt) 34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) # 得到一个类的实例,例如basicsr.data.paired_image_dataset.PairedImageDataset 35 | logger = get_root_logger() 36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.') 37 | return dataset 38 | 39 | 40 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): 41 | """Build dataloader. 42 | 43 | Args: 44 | dataset (torch.utils.data.Dataset): Dataset. 45 | dataset_opt (dict): Dataset options. It contains the following keys: 46 | phase (str): 'train' or 'val'. 47 | num_worker_per_gpu (int): Number of workers for each GPU. 48 | batch_size_per_gpu (int): Training batch size for each GPU. 49 | num_gpu (int): Number of GPUs. Used only in the train phase. 50 | Default: 1. 51 | dist (bool): Whether in distributed training. Used only in the train 52 | phase. Default: False. 53 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 54 | seed (int | None): Seed. Default: None 55 | """ 56 | phase = dataset_opt['phase'] # train or test 57 | rank, _ = get_dist_info() 58 | if phase == 'train': 59 | if dist: # distributed training 60 | batch_size = dataset_opt['batch_size_per_gpu'] 61 | num_workers = dataset_opt['num_worker_per_gpu'] 62 | else: # non-distributed training 63 | multiplier = 1 if num_gpu == 0 else num_gpu 64 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 65 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 66 | dataloader_args = dict( 67 | dataset=dataset, 68 | batch_size=batch_size, 69 | shuffle=False, 70 | num_workers=num_workers, 71 | sampler=sampler, 72 | drop_last=True) # 创建一个字典 73 | if sampler is None: 74 | dataloader_args['shuffle'] = True 75 | dataloader_args['worker_init_fn'] = partial( 76 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None 77 | elif phase in ['val', 'test']: # validation 78 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 79 | else: 80 | raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.") 81 | 82 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 83 | dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) 84 | 85 | prefetch_mode = dataset_opt.get('prefetch_mode') 86 | if prefetch_mode == 'cpu': # CPUPrefetcher 87 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 88 | logger = get_root_logger() 89 | logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') 90 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) 91 | else: 92 | # prefetch_mode=None: Normal dataloader 93 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 94 | return torch.utils.data.DataLoader(**dataloader_args) 95 | 96 | 97 | def worker_init_fn(worker_id, num_workers, rank, seed): 98 | # Set the worker seed to num_workers * rank + worker_id + seed 99 | worker_seed = num_workers * rank + worker_id + seed 100 | np.random.seed(worker_seed) 101 | random.seed(worker_seed) 102 | -------------------------------------------------------------------------------- /LICENSE/LICENSE-NVIDIA: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 
2 | 3 | 4 | Nvidia Source Code License-NC 5 | 6 | ======================================================================= 7 | 8 | 1. Definitions 9 | 10 | "Licensor" means any person or entity that distributes its Work. 11 | 12 | "Software" means the original work of authorship made available under 13 | this License. 14 | 15 | "Work" means the Software and any additions to or derivative works of 16 | the Software that are made available under this License. 17 | 18 | "Nvidia Processors" means any central processing unit (CPU), graphics 19 | processing unit (GPU), field-programmable gate array (FPGA), 20 | application-specific integrated circuit (ASIC) or any combination 21 | thereof designed, made, sold, or provided by Nvidia or its affiliates. 22 | 23 | The terms "reproduce," "reproduction," "derivative works," and 24 | "distribution" have the meaning as provided under U.S. copyright law; 25 | provided, however, that for the purposes of this License, derivative 26 | works shall not include works that remain separable from, or merely 27 | link (or bind by name) to the interfaces of, the Work. 28 | 29 | Works, including the Software, are "made available" under this License 30 | by including in or with the Work either (a) a copyright notice 31 | referencing the applicability of this License to the Work, or (b) a 32 | copy of this License. 33 | 34 | 2. License Grants 35 | 36 | 2.1 Copyright Grant. Subject to the terms and conditions of this 37 | License, each Licensor grants to you a perpetual, worldwide, 38 | non-exclusive, royalty-free, copyright license to reproduce, 39 | prepare derivative works of, publicly display, publicly perform, 40 | sublicense and distribute its Work and any resulting derivative 41 | works in any form. 42 | 43 | 3. Limitations 44 | 45 | 3.1 Redistribution. You may reproduce or distribute the Work only 46 | if (a) you do so under this License, (b) you include a complete 47 | copy of this License with your distribution, and (c) you retain 48 | without modification any copyright, patent, trademark, or 49 | attribution notices that are present in the Work. 50 | 51 | 3.2 Derivative Works. You may specify that additional or different 52 | terms apply to the use, reproduction, and distribution of your 53 | derivative works of the Work ("Your Terms") only if (a) Your Terms 54 | provide that the use limitation in Section 3.3 applies to your 55 | derivative works, and (b) you identify the specific derivative 56 | works that are subject to Your Terms. Notwithstanding Your Terms, 57 | this License (including the redistribution requirements in Section 58 | 3.1) will continue to apply to the Work itself. 59 | 60 | 3.3 Use Limitation. The Work and any derivative works thereof only 61 | may be used or intended for use non-commercially. The Work or 62 | derivative works thereof may be used or intended for use by Nvidia 63 | or its affiliates commercially or non-commercially. As used herein, 64 | "non-commercially" means for research or evaluation purposes only. 65 | 66 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 67 | against any Licensor (including any claim, cross-claim or 68 | counterclaim in a lawsuit) to enforce any patents that you allege 69 | are infringed by any Work, then your rights under this License from 70 | such Licensor (including the grants in Sections 2.1 and 2.2) will 71 | terminate immediately. 72 | 73 | 3.5 Trademarks. 
This License does not grant any rights to use any 74 | Licensor's or its affiliates' names, logos, or trademarks, except 75 | as necessary to reproduce the notices described in this License. 76 | 77 | 3.6 Termination. If you violate any term of this License, then your 78 | rights under this License (including the grants in Sections 2.1 and 79 | 2.2) will terminate immediately. 80 | 81 | 4. Disclaimer of Warranty. 82 | 83 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 84 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 85 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 86 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 87 | THIS LICENSE. 88 | 89 | 5. Limitation of Liability. 90 | 91 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 92 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 93 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 94 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 95 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 96 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 97 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 98 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 99 | THE POSSIBILITY OF SUCH DAMAGES. 100 | 101 | ======================================================================= 102 | -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import time 5 | import torch 6 | from os import path as osp 7 | 8 | from .dist_util import master_only 9 | 10 | 11 | def set_random_seed(seed): 12 | """Set random seeds.""" 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | 20 | def get_time_str(): 21 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 22 | 23 | 24 | def mkdir_and_rename(path): 25 | """mkdirs. If path exists, rename it with timestamp and create a new one. 26 | 27 | Args: 28 | path (str): Folder path. 29 | """ 30 | if osp.exists(path): 31 | new_name = path + '_archived_' + get_time_str() 32 | print(f'Path already exists. Rename it to {new_name}', flush=True) 33 | os.rename(path, new_name) 34 | os.makedirs(path, exist_ok=True) 35 | 36 | 37 | @master_only 38 | def make_exp_dirs(opt): 39 | """Make dirs for experiments.""" 40 | path_opt = opt['path'].copy() 41 | if opt['is_train']: 42 | mkdir_and_rename(path_opt.pop('experiments_root')) 43 | else: 44 | mkdir_and_rename(path_opt.pop('results_root')) 45 | for key, path in path_opt.items(): 46 | if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key): 47 | continue 48 | else: 49 | os.makedirs(path, exist_ok=True) 50 | 51 | 52 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 53 | """Scan a directory to find the interested files. 54 | 55 | Args: 56 | dir_path (str): Path of the directory. 57 | suffix (str | tuple(str), optional): File suffix that we are 58 | interested in. Default: None. 59 | recursive (bool, optional): If set to True, recursively scan the 60 | directory. Default: False. 61 | full_path (bool, optional): If set to True, include the dir_path. 62 | Default: False. 
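    Example:
        A sketch; the directory and suffix are placeholders:

        >>> for path in scandir('datasets', suffix='.png', recursive=True):
        ...     print(path)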
63 | 64 | Returns: 65 | A generator over all files of interest, with relative paths. 66 | """ 67 | 68 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 69 | raise TypeError('"suffix" must be a string or tuple of strings') 70 | 71 | root = dir_path 72 | 73 | def _scandir(dir_path, suffix, recursive): 74 | for entry in os.scandir(dir_path): 75 | if not entry.name.startswith('.') and entry.is_file(): 76 | if full_path: 77 | return_path = entry.path 78 | else: 79 | return_path = osp.relpath(entry.path, root) 80 | 81 | if suffix is None: 82 | yield return_path 83 | elif return_path.endswith(suffix): 84 | yield return_path 85 | else: 86 | if recursive: 87 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 88 | else: 89 | continue 90 | 91 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 92 | 93 | 94 | def check_resume(opt, resume_iter): 95 | """Check resume states and pretrain_network paths. 96 | 97 | Args: 98 | opt (dict): Options. 99 | resume_iter (int): Resume iteration. 100 | """ 101 | if opt['path']['resume_state']: 102 | # get all the networks 103 | networks = [key for key in opt.keys() if key.startswith('network_')] # e.g., ['network_g'] 104 | flag_pretrain = False 105 | for network in networks: 106 | if opt['path'].get(f'pretrain_{network}') is not None: # a pretrained model path has been set 107 | flag_pretrain = True 108 | if flag_pretrain: 109 | print('pretrain_network path will be ignored during resuming.') 110 | # set pretrained model paths 111 | for network in networks: 112 | name = f'pretrain_{network}' 113 | basename = network.replace('network_', '') # g 114 | if opt['path'].get('ignore_resume_networks') is None or (network 115 | not in opt['path']['ignore_resume_networks']): 116 | opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 117 | print(f"Set {name} to {opt['path'][name]}") 118 | 119 | # change param_key to params in resume 120 | param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] 121 | for param_key in param_keys: 122 | if opt['path'][param_key] == 'params_ema': 123 | opt['path'][param_key] = 'params' 124 | print(f'Set {param_key} to params') 125 | 126 | 127 | def sizeof_fmt(size, suffix='B'): 128 | """Get human readable file size. 129 | 130 | Args: 131 | size (int): File size. 132 | suffix (str): Suffix. Default: 'B'. 133 | 134 | Returns: 135 | str: Formatted file size. 136 | """ 137 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 138 | if abs(size) < 1024.0: 139 | return f'{size:3.1f} {unit}{suffix}' 140 | size /= 1024.0 141 | return f'{size:3.1f} Y{suffix}' 142 | -------------------------------------------------------------------------------- /basicsr/metrics/psnr_ssim.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import math 4 | 5 | from basicsr.metrics.metric_util import reorder_image, to_y_channel 6 | from basicsr.utils.registry import METRIC_REGISTRY 7 | 8 | 9 | @METRIC_REGISTRY.register() # the decorator is actually the inner deco() of register() 10 | def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): 11 | """Calculate PSNR (Peak Signal-to-Noise Ratio). 12 | 13 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 14 | 15 | Args: 16 | img (ndarray): Images with range [0, 255]. 17 | img2 (ndarray): Images with range [0, 255]. 18 | crop_border (int): Cropped pixels in each edge of an image. These 19 | pixels are not involved in the PSNR calculation. 
20 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 21 | Default: 'HWC'. 22 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 23 | 24 | Returns: 25 | float: psnr result. 26 | """ 27 | 28 | assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') 29 | if input_order not in ['HWC', 'CHW']: 30 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') 31 | img = reorder_image(img, input_order=input_order) 32 | img2 = reorder_image(img2, input_order=input_order) 33 | img = img.astype(np.float64) 34 | img2 = img2.astype(np.float64) 35 | 36 | if crop_border != 0: 37 | img = img[crop_border:-crop_border, crop_border:-crop_border, ...] 38 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 39 | 40 | if test_y_channel: 41 | img = to_y_channel(img) 42 | img2 = to_y_channel(img2) 43 | 44 | mse = np.mean((img - img2)**2) 45 | if mse == 0: 46 | # return float('inf') 47 | if np.sum(img) == 0: 48 | return None 49 | else: 50 | return 100 51 | 52 | return 20. * np.log10(255. / np.sqrt(mse)) 53 | 54 | def _ssim(img, img2): 55 | """Calculate SSIM (structural similarity) for one channel images. 56 | 57 | It is called by func:`calculate_ssim`. 58 | 59 | Args: 60 | img (ndarray): Images with range [0, 255] with order 'HWC'. 61 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 62 | 63 | Returns: 64 | float: ssim result. 65 | """ 66 | 67 | c1 = (0.01 * 255)**2 68 | c2 = (0.03 * 255)**2 69 | 70 | img = img.astype(np.float64) 71 | img2 = img2.astype(np.float64) 72 | kernel = cv2.getGaussianKernel(11, 1.5) 73 | window = np.outer(kernel, kernel.transpose()) 74 | 75 | mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] 76 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 77 | mu1_sq = mu1**2 78 | mu2_sq = mu2**2 79 | mu1_mu2 = mu1 * mu2 80 | sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq 81 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 82 | sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 83 | 84 | ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)) 85 | return ssim_map.mean() 86 | 87 | 88 | @METRIC_REGISTRY.register() 89 | def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): 90 | """Calculate SSIM (structural similarity). 91 | 92 | Ref: 93 | Image quality assessment: From error visibility to structural similarity 94 | 95 | The results are the same as that of the official released MATLAB code in 96 | https://ece.uwaterloo.ca/~z70wang/research/ssim/. 97 | 98 | For three-channel images, SSIM is calculated for each channel and then 99 | averaged. 100 | 101 | Args: 102 | img (ndarray): Images with range [0, 255]. 103 | img2 (ndarray): Images with range [0, 255]. 104 | crop_border (int): Cropped pixels in each edge of an image. These 105 | pixels are not involved in the SSIM calculation. 106 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 107 | Default: 'HWC'. 108 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 109 | 110 | Returns: 111 | float: ssim result. 112 | """ 113 | 114 | assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') 115 | if input_order not in ['HWC', 'CHW']: 116 | raise ValueError(f'Wrong input_order {input_order}. 
Supported input_orders are ' '"HWC" and "CHW"') 117 | img = reorder_image(img, input_order=input_order) 118 | img2 = reorder_image(img2, input_order=input_order) 119 | img = img.astype(np.float64) 120 | img2 = img2.astype(np.float64) 121 | 122 | if crop_border != 0: 123 | img = img[crop_border:-crop_border, crop_border:-crop_border, ...] 124 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 125 | 126 | if test_y_channel: 127 | img = to_y_channel(img) 128 | img2 = to_y_channel(img2) 129 | 130 | ssims = [] 131 | for i in range(img.shape[2]): 132 | ssims.append(_ssim(img[..., i], img2[..., i])) 133 | return np.array(ssims).mean() 134 | 135 | if __name__ == "__main__": 136 | img = img2 = np.random.random((128, 128)) 137 | calculate_psnr(img, img2, 4) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | import os 6 | import subprocess 7 | import time 8 | import torch 9 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 10 | 11 | version_file = 'basicsr/version.py' 12 | 13 | 14 | def readme(): 15 | with open('README.md', encoding='utf-8') as f: 16 | content = f.read() 17 | return content 18 | 19 | 20 | def get_git_hash(): 21 | 22 | def _minimal_ext_cmd(cmd): 23 | # construct minimal environment 24 | env = {} 25 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 26 | v = os.environ.get(k) 27 | if v is not None: 28 | env[k] = v 29 | # LANGUAGE is used on win32 30 | env['LANGUAGE'] = 'C' 31 | env['LANG'] = 'C' 32 | env['LC_ALL'] = 'C' 33 | out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 34 | return out 35 | 36 | try: 37 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 38 | sha = out.strip().decode('ascii') 39 | except OSError: 40 | sha = 'unknown' 41 | 42 | return sha 43 | 44 | 45 | def get_hash(): 46 | if os.path.exists('.git'): 47 | sha = get_git_hash()[:7] 48 | elif os.path.exists(version_file): 49 | try: 50 | from basicsr.version import __version__ 51 | sha = __version__.split('+')[-1] 52 | except ImportError: 53 | raise ImportError('Unable to get git version') 54 | else: 55 | sha = 'unknown' 56 | 57 | return sha 58 | 59 | 60 | def write_version_py(): 61 | content = """# GENERATED VERSION FILE 62 | # TIME: {} 63 | __version__ = '{}' 64 | __gitsha__ = '{}' 65 | version_info = ({}) 66 | """ 67 | sha = get_hash() 68 | with open('VERSION', 'r') as f: 69 | SHORT_VERSION = f.read().strip() 70 | VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) 71 | 72 | version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) 73 | with open(version_file, 'w') as f: 74 | f.write(version_file_str) 75 | 76 | 77 | def get_version(): 78 | with open(version_file, 'r') as f: 79 | exec(compile(f.read(), version_file, 'exec')) 80 | return locals()['__version__'] 81 | 82 | 83 | def make_cuda_ext(name, module, sources, sources_cuda=None): 84 | if sources_cuda is None: 85 | sources_cuda = [] 86 | define_macros = [] 87 | extra_compile_args = {'cxx': []} 88 | 89 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 90 | define_macros += [('WITH_CUDA', None)] 91 | extension = CUDAExtension 92 | extra_compile_args['nvcc'] = [ 93 | '-D__CUDA_NO_HALF_OPERATORS__', 94 | '-D__CUDA_NO_HALF_CONVERSIONS__', 95 | '-D__CUDA_NO_HALF2_OPERATORS__', 96 | ] 97 | sources += 
sources_cuda 98 | else: 99 | print(f'Compiling {name} without CUDA') 100 | extension = CppExtension 101 | 102 | return extension( 103 | name=f'{module}.{name}', 104 | sources=[os.path.join(*module.split('.'), p) for p in sources], 105 | define_macros=define_macros, 106 | extra_compile_args=extra_compile_args) 107 | 108 | 109 | def get_requirements(filename='requirements.txt'): 110 | here = os.path.dirname(os.path.realpath(__file__)) 111 | with open(os.path.join(here, filename), 'r') as f: 112 | requires = [line.replace('\n', '') for line in f.readlines()] 113 | return requires 114 | 115 | 116 | if __name__ == '__main__': 117 | cuda_ext = os.getenv('BASICSR_EXT') # whether compile cuda ext 118 | if cuda_ext == 'True': 119 | ext_modules = [ 120 | make_cuda_ext( 121 | name='deform_conv_ext', 122 | module='basicsr.ops.dcn', 123 | sources=['src/deform_conv_ext.cpp'], 124 | sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']), 125 | make_cuda_ext( 126 | name='fused_act_ext', 127 | module='basicsr.ops.fused_act', 128 | sources=['src/fused_bias_act.cpp'], 129 | sources_cuda=['src/fused_bias_act_kernel.cu']), 130 | make_cuda_ext( 131 | name='upfirdn2d_ext', 132 | module='basicsr.ops.upfirdn2d', 133 | sources=['src/upfirdn2d.cpp'], 134 | sources_cuda=['src/upfirdn2d_kernel.cu']), 135 | ] 136 | else: 137 | ext_modules = [] 138 | 139 | write_version_py() 140 | setup( 141 | name='basicsr', 142 | version=get_version(), 143 | description='Open Source Image and Video Super-Resolution Toolbox', 144 | long_description=readme(), 145 | long_description_content_type='text/markdown', 146 | author='Xintao Wang', 147 | author_email='xintao.wang@outlook.com', 148 | keywords='computer vision, restoration, super resolution', 149 | url='https://github.com/xinntao/BasicSR', 150 | include_package_data=True, 151 | packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')), 152 | classifiers=[ 153 | 'Development Status :: 4 - Beta', 154 | 'License :: OSI Approved :: Apache Software License', 155 | 'Operating System :: OS Independent', 156 | 'Programming Language :: Python :: 3', 157 | 'Programming Language :: Python :: 3.7', 158 | 'Programming Language :: Python :: 3.8', 159 | ], 160 | license='Apache License 2.0', 161 | setup_requires=['cython', 'numpy'], 162 | install_requires=get_requirements(), 163 | ext_modules=ext_modules, 164 | cmdclass={'build_ext': BuildExtension}, 165 | zip_safe=False) 166 | -------------------------------------------------------------------------------- /basicsr/utils/file_client.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseStorageBackend(metaclass=ABCMeta): 6 | """Abstract class of storage backends. 7 | 8 | All backends need to implement two apis: ``get()`` and ``get_text()``. 9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 10 | as texts. 11 | """ 12 | 13 | @abstractmethod 14 | def get(self, filepath): 15 | pass 16 | 17 | @abstractmethod 18 | def get_text(self, filepath): 19 | pass 20 | 21 | 22 | class MemcachedBackend(BaseStorageBackend): 23 | """Memcached storage backend. 24 | 25 | Attributes: 26 | server_list_cfg (str): Config file for memcached server list. 27 | client_cfg (str): Config file for memcached client. 
28 |         sys_path (str | None): Additional path to be appended to `sys.path`. 29 |             Default: None. 30 |     """ 31 | 32 |     def __init__(self, server_list_cfg, client_cfg, sys_path=None): 33 |         if sys_path is not None: 34 |             import sys 35 |             sys.path.append(sys_path) 36 |         try: 37 |             import mc 38 |         except ImportError: 39 |             raise ImportError('Please install memcached to enable MemcachedBackend.') 40 | 41 |         self.server_list_cfg = server_list_cfg 42 |         self.client_cfg = client_cfg 43 |         self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) 44 |         # mc.pyvector serves as a pointer to a memory cache buffer 45 |         self._mc_buffer = mc.pyvector() 46 | 47 |     def get(self, filepath): 48 |         filepath = str(filepath) 49 |         import mc 50 |         self._client.Get(filepath, self._mc_buffer) 51 |         value_buf = mc.ConvertBuffer(self._mc_buffer) 52 |         return value_buf 53 | 54 |     def get_text(self, filepath): 55 |         raise NotImplementedError 56 | 57 | 58 | class HardDiskBackend(BaseStorageBackend): 59 |     """Raw hard disk storage backend.""" 60 | 61 |     def get(self, filepath): 62 |         filepath = str(filepath) 63 |         with open(filepath, 'rb') as f: 64 |             value_buf = f.read() 65 |         return value_buf 66 | 67 |     def get_text(self, filepath): 68 |         filepath = str(filepath) 69 |         with open(filepath, 'r') as f: 70 |             value_buf = f.read() 71 |         return value_buf 72 | 73 | 74 | class LmdbBackend(BaseStorageBackend): 75 |     """Lmdb storage backend. 76 | 77 |     Args: 78 |         db_paths (str | list[str]): Lmdb database paths. 79 |         client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 80 |         readonly (bool, optional): Lmdb environment parameter. If True, 81 |             disallow any write operations. Default: True. 82 |         lock (bool, optional): Lmdb environment parameter. If False, when 83 |             concurrent access occurs, do not lock the database. Default: False. 84 |         readahead (bool, optional): Lmdb environment parameter. If False, 85 |             disable the OS filesystem readahead mechanism, which may improve 86 |             random read performance when a database is larger than RAM. 87 |             Default: False. 88 | 89 |     Attributes: 90 |         db_paths (list): Lmdb database paths. 91 |         _client (dict): A dict of several lmdb envs. 92 |     """ 93 | 94 |     def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): 95 |         try: 96 |             import lmdb 97 |         except ImportError: 98 |             raise ImportError('Please install lmdb to enable LmdbBackend.') 99 | 100 |         if isinstance(client_keys, str): 101 |             client_keys = [client_keys] 102 | 103 |         if isinstance(db_paths, list): 104 |             self.db_paths = [str(v) for v in db_paths] 105 |         elif isinstance(db_paths, str): 106 |             self.db_paths = [str(db_paths)] 107 |         assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' 108 |                                                         f'but received {len(client_keys)} and {len(self.db_paths)}.') 109 | 110 |         self._client = {} 111 |         for client, path in zip(client_keys, self.db_paths): 112 |             self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) 113 | 114 |     def get(self, filepath, client_key): 115 |         """Get values according to the filepath from one lmdb named client_key. 116 | 117 |         Args: 118 |             filepath (str | obj:`Path`): Here, filepath is the lmdb key. 119 |             client_key (str): Used for distinguishing different lmdb envs. 
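        Example:
            A minimal sketch; the database path and key below are
            placeholders for illustration only:

            >>> backend = LmdbBackend('datasets/example.lmdb', client_keys='gt')
            >>> img_bytes = backend.get('0001_s001', 'gt')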
120 | """ 121 | filepath = str(filepath) 122 | assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.') 123 | client = self._client[client_key] 124 | with client.begin(write=False) as txn: 125 | value_buf = txn.get(filepath.encode('ascii')) 126 | return value_buf 127 | 128 | def get_text(self, filepath): 129 | raise NotImplementedError 130 | 131 | 132 | class FileClient(object): 133 | """A general file client to access files in different backend. 134 | 135 | The client loads a file or text in a specified backend from its path 136 | and return it as a binary file. it can also register other backend 137 | accessor with a given name and backend class. 138 | 139 | Attributes: 140 | backend (str): The storage backend type. Options are "disk", 141 | "memcached" and "lmdb". 142 | client (:obj:`BaseStorageBackend`): The backend object. 143 | """ 144 | # 里面是三个类 145 | _backends = { 146 | 'disk': HardDiskBackend, 147 | 'memcached': MemcachedBackend, 148 | 'lmdb': LmdbBackend, 149 | } 150 | 151 | def __init__(self, backend='disk', **kwargs): 152 | if backend not in self._backends: 153 | raise ValueError(f'Backend {backend} is not supported. Currently supported ones' 154 | f' are {list(self._backends.keys())}') 155 | self.backend = backend 156 | self.client = self._backends[backend](**kwargs) 157 | 158 | def get(self, filepath, client_key='default'): 159 | # client_key is used only for lmdb, where different fileclients have 160 | # different lmdb environments. 161 | if self.backend == 'lmdb': 162 | return self.client.get(filepath, client_key) 163 | else: 164 | return self.client.get(filepath) 165 | 166 | def get_text(self, filepath): 167 | return self.client.get_text(filepath) 168 | -------------------------------------------------------------------------------- /basicsr/archs/vgg_arch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from torch import nn as nn 5 | from torchvision.models import vgg as vgg 6 | 7 | from basicsr.utils.registry import ARCH_REGISTRY 8 | 9 | VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' 10 | NAMES = { 11 | 'vgg11': [ 12 | 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 13 | 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 14 | 'pool5' 15 | ], 16 | 'vgg13': [ 17 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 18 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 19 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' 20 | ], 21 | 'vgg16': [ 22 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 23 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 24 | 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 25 | 'pool5' 26 | ], 27 | 'vgg19': [ 28 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 29 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', 30 | 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', 31 | 'conv5_2', 'relu5_2', 'conv5_3', 
'relu5_3', 'conv5_4', 'relu5_4', 'pool5' 32 |     ] 33 | } 34 | 35 | 36 | def insert_bn(names): 37 |     """Insert bn layer after each conv. 38 | 39 |     Args: 40 |         names (list): The list of layer names. 41 | 42 |     Returns: 43 |         list: The list of layer names with bn layers. 44 |     """ 45 |     names_bn = [] 46 |     for name in names: 47 |         names_bn.append(name) 48 |         if 'conv' in name: 49 |             position = name.replace('conv', '') 50 |             names_bn.append('bn' + position) 51 |     return names_bn 52 | 53 | 54 | @ARCH_REGISTRY.register() 55 | class VGGFeatureExtractor(nn.Module): 56 |     """VGG network for feature extraction. 57 | 58 |     In this implementation, we allow users to choose whether to use normalization 59 |     on the input feature and the type of vgg network. Note that the pretrained 60 |     path must fit the vgg type. 61 | 62 |     Args: 63 |         layer_name_list (list[str]): Forward function returns the corresponding 64 |             features according to the layer_name_list. 65 |             Example: {'relu1_1', 'relu2_1', 'relu3_1'}. 66 |         vgg_type (str): Set the type of vgg network. Default: 'vgg19'. 67 |         use_input_norm (bool): If True, normalize the input image. Importantly, 68 |             the input feature must be in the range [0, 1]. Default: True. 69 |         range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. 70 |             Default: False. 71 |         requires_grad (bool): If True, the parameters of VGG network will be 72 |             optimized. Default: False. 73 |         remove_pooling (bool): If True, the max pooling operations in VGG net 74 |             will be removed. Default: False. 75 |         pooling_stride (int): The stride of max pooling operation. Default: 2. 76 |     """ 77 | 78 |     def __init__(self, 79 |                  layer_name_list, 80 |                  vgg_type='vgg19', 81 |                  use_input_norm=True, 82 |                  range_norm=False, 83 |                  requires_grad=False, 84 |                  remove_pooling=False, 85 |                  pooling_stride=2): 86 |         super(VGGFeatureExtractor, self).__init__() 87 | 88 |         self.layer_name_list = layer_name_list 89 |         self.use_input_norm = use_input_norm 90 |         self.range_norm = range_norm 91 | 92 |         self.names = NAMES[vgg_type.replace('_bn', '')] 93 |         if 'bn' in vgg_type: 94 |             self.names = insert_bn(self.names) 95 | 96 |         # only borrow layers that will be used to avoid unused params 97 |         max_idx = 0 98 |         for v in layer_name_list: 99 |             idx = self.names.index(v) 100 |             if idx > max_idx: 101 |                 max_idx = idx 102 | 103 |         if os.path.exists(VGG_PRETRAIN_PATH): 104 |             vgg_net = getattr(vgg, vgg_type)(pretrained=False) 105 |             state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) 106 |             vgg_net.load_state_dict(state_dict) 107 |         else: 108 |             vgg_net = getattr(vgg, vgg_type)(pretrained=True) 109 | 110 |         features = vgg_net.features[:max_idx + 1] 111 | 112 |         modified_net = OrderedDict() 113 |         for k, v in zip(self.names, features): 114 |             if 'pool' in k: 115 |                 # if remove_pooling is true, pooling operation will be removed 116 |                 if remove_pooling: 117 |                     continue 118 |                 else: 119 |                     # in some cases, we may want to change the default stride 120 |                     modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) 121 |             else: 122 |                 modified_net[k] = v 123 | 124 |         self.vgg_net = nn.Sequential(modified_net) 125 | 126 |         if not requires_grad: 127 |             self.vgg_net.eval() 128 |             for param in self.parameters(): 129 |                 param.requires_grad = False 130 |         else: 131 |             self.vgg_net.train() 132 |             for param in self.parameters(): 133 |                 param.requires_grad = True 134 | 135 |         if self.use_input_norm: 136 |             # the mean is for image with range [0, 1] 137 |             self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 138 |             # the std is for image 
with range [0, 1] 139 |             self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 140 | 141 |     def forward(self, x): 142 |         """Forward function. 143 | 144 |         Args: 145 |             x (Tensor): Input tensor with shape (n, c, h, w). 146 | 147 |         Returns: 148 |             Tensor: Forward results. 149 |         """ 150 |         if self.range_norm: 151 |             x = (x + 1) / 2 152 |         if self.use_input_norm: 153 |             x = (x - self.mean) / self.std 154 | 155 |         output = {} 156 |         for key, layer in self.vgg_net._modules.items(): 157 |             x = layer(x) 158 |             if key in self.layer_name_list: 159 |                 output[key] = x.clone() 160 | 161 |         return output 162 | -------------------------------------------------------------------------------- /scripts/data_preparation/create_lmdb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os import path as osp 3 | 4 | from basicsr.utils import scandir 5 | from basicsr.utils.lmdb_util import make_lmdb_from_imgs 6 | 7 | 8 | def create_lmdb_for_div2k(): 9 |     """Create lmdb files for DIV2K dataset. 10 | 11 |     Usage: 12 |         Before running this script, please run `extract_subimages.py`. 13 |         Typically, there are four folders to be processed for DIV2K dataset. 14 |             DIV2K_train_HR_sub 15 |             DIV2K_train_LR_bicubic/X2_sub 16 |             DIV2K_train_LR_bicubic/X3_sub 17 |             DIV2K_train_LR_bicubic/X4_sub 18 |         Remember to modify opt configurations according to your settings. 19 |     """ 20 |     # HR images 21 |     folder_path = 'datasets/DIV2K/DIV2K_train_HR_sub' 22 |     lmdb_path = 'datasets/DIV2K/DIV2K_train_HR_sub.lmdb' 23 |     img_path_list, keys = prepare_keys_div2k(folder_path) 24 |     make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 25 | 26 |     # LRx2 images 27 |     folder_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X2_sub' 28 |     lmdb_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic_X2_sub.lmdb' 29 |     img_path_list, keys = prepare_keys_div2k(folder_path) 30 |     make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 31 | 32 |     # LRx3 images 33 |     folder_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X3_sub' 34 |     lmdb_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic_X3_sub.lmdb' 35 |     img_path_list, keys = prepare_keys_div2k(folder_path) 36 |     make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 37 | 38 |     # LRx4 images 39 |     folder_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub' 40 |     lmdb_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb' 41 |     img_path_list, keys = prepare_keys_div2k(folder_path) 42 |     make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 43 | 44 | 45 | def prepare_keys_div2k(folder_path): 46 |     """Prepare image path list and keys for DIV2K dataset. 47 | 48 |     Args: 49 |         folder_path (str): Folder path. 50 | 51 |     Returns: 52 |         list[str]: Image path list. 53 |         list[str]: Key list. 54 |     """ 55 |     print('Reading image path list ...') 56 |     img_path_list = sorted(list(scandir(folder_path, suffix='png', recursive=False))) 57 |     keys = [img_path.split('.png')[0] for img_path in sorted(img_path_list)] 58 | 59 |     return img_path_list, keys 60 | 61 | 62 | def create_lmdb_for_reds(): 63 |     """Create lmdb files for REDS dataset. 64 | 65 |     Usage: 66 |         Before running this script, please run `merge_reds_train_val.py`. 67 |         We take two folders for example: 68 |             train_sharp 69 |             train_sharp_bicubic 70 |         Remember to modify opt configurations according to your settings; a sample invocation is shown below. 
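    Example:
        A typical invocation from the repository root (the dataset folders
        hard-coded below must already exist):

            python scripts/data_preparation/create_lmdb.py --dataset reds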
71 |     """ 72 |     # train_sharp 73 |     folder_path = 'datasets/REDS/train_sharp' 74 |     lmdb_path = 'datasets/REDS/train_sharp_with_val.lmdb' 75 |     img_path_list, keys = prepare_keys_reds(folder_path) 76 |     make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True) 77 | 78 |     # train_sharp_bicubic 79 |     folder_path = 'datasets/REDS/train_sharp_bicubic' 80 |     lmdb_path = 'datasets/REDS/train_sharp_bicubic_with_val.lmdb' 81 |     img_path_list, keys = prepare_keys_reds(folder_path) 82 |     make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True) 83 | 84 | 85 | def prepare_keys_reds(folder_path): 86 |     """Prepare image path list and keys for REDS dataset. 87 | 88 |     Args: 89 |         folder_path (str): Folder path. 90 | 91 |     Returns: 92 |         list[str]: Image path list. 93 |         list[str]: Key list. 94 |     """ 95 |     print('Reading image path list ...') 96 |     img_path_list = sorted(list(scandir(folder_path, suffix='png', recursive=True))) 97 |     keys = [v.split('.png')[0] for v in img_path_list]  # example: 000/00000000 98 | 99 |     return img_path_list, keys 100 | 101 | 102 | def create_lmdb_for_vimeo90k(): 103 |     """Create lmdb files for Vimeo90K dataset. 104 | 105 |     Usage: 106 |         Remember to modify opt configurations according to your settings. 107 |     """ 108 |     # GT 109 |     folder_path = 'datasets/vimeo90k/vimeo_septuplet/sequences' 110 |     lmdb_path = 'datasets/vimeo90k/vimeo90k_train_GT_only4th.lmdb' 111 |     train_list_path = 'datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt' 112 |     img_path_list, keys = prepare_keys_vimeo90k(folder_path, train_list_path, 'gt') 113 |     make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True) 114 | 115 |     # LQ 116 |     folder_path = 'datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences' 117 |     lmdb_path = 'datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb' 118 |     train_list_path = 'datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt' 119 |     img_path_list, keys = prepare_keys_vimeo90k(folder_path, train_list_path, 'lq') 120 |     make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True) 121 | 122 | 123 | def prepare_keys_vimeo90k(folder_path, train_list_path, mode): 124 |     """Prepare image path list and keys for Vimeo90K dataset. 125 | 126 |     Args: 127 |         folder_path (str): Folder path. 128 |         train_list_path (str): Path to the official train list. 129 |         mode (str): One of 'gt' or 'lq'. 130 | 131 |     Returns: 132 |         list[str]: Image path list. 133 |         list[str]: Key list. 
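    Example:
        For an illustrative train-list line ``00001/0266``, the generated
        entries look like:

            img_path_list: ['00001/0266/im1.png', ..., '00001/0266/im7.png']
            keys: ['00001/0266/im1', ..., '00001/0266/im7']

        In 'gt' mode only the ``im4`` entries are kept afterwards.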
134 | """ 135 | print('Reading image path list ...') 136 | with open(train_list_path, 'r') as fin: 137 | train_list = [line.strip() for line in fin] 138 | 139 | img_path_list = [] 140 | keys = [] 141 | for line in train_list: 142 | folder, sub_folder = line.split('/') 143 | img_path_list.extend([osp.join(folder, sub_folder, f'im{j + 1}.png') for j in range(7)]) 144 | keys.extend([f'{folder}/{sub_folder}/im{j + 1}' for j in range(7)]) 145 | 146 | if mode == 'gt': 147 | print('Only keep the 4th frame for the gt mode.') 148 | img_path_list = [v for v in img_path_list if v.endswith('im4.png')] 149 | keys = [v for v in keys if v.endswith('/im4')] 150 | 151 | return img_path_list, keys 152 | 153 | 154 | if __name__ == '__main__': 155 | parser = argparse.ArgumentParser() 156 | 157 | parser.add_argument( 158 | '--dataset', 159 | type=str, 160 | help=("Options: 'DIV2K', 'REDS', 'Vimeo90K' " 161 | 'You may need to modify the corresponding configurations in codes.')) 162 | args = parser.parse_args() 163 | dataset = args.dataset.lower() 164 | if dataset == 'div2k': 165 | create_lmdb_for_div2k() 166 | elif dataset == 'reds': 167 | create_lmdb_for_reds() 168 | elif dataset == 'vimeo90k': 169 | create_lmdb_for_vimeo90k() 170 | else: 171 | raise ValueError('Wrong dataset.') 172 | -------------------------------------------------------------------------------- /basicsr/utils/flow_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501 2 | import cv2 3 | import numpy as np 4 | import os 5 | 6 | 7 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): 8 | """Read an optical flow map. 9 | 10 | Args: 11 | flow_path (ndarray or str): Flow path. 12 | quantize (bool): whether to read quantized pair, if set to True, 13 | remaining args will be passed to :func:`dequantize_flow`. 14 | concat_axis (int): The axis that dx and dy are concatenated, 15 | can be either 0 or 1. Ignored if quantize is False. 16 | 17 | Returns: 18 | ndarray: Optical flow represented as a (h, w, 2) numpy array 19 | """ 20 | if quantize: 21 | assert concat_axis in [0, 1] 22 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) 23 | if cat_flow.ndim != 2: 24 | raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.') 25 | assert cat_flow.shape[concat_axis] % 2 == 0 26 | dx, dy = np.split(cat_flow, 2, axis=concat_axis) 27 | flow = dequantize_flow(dx, dy, *args, **kwargs) 28 | else: 29 | with open(flow_path, 'rb') as f: 30 | try: 31 | header = f.read(4).decode('utf-8') 32 | except Exception: 33 | raise IOError(f'Invalid flow file: {flow_path}') 34 | else: 35 | if header != 'PIEH': 36 | raise IOError(f'Invalid flow file: {flow_path}, ' 'header does not contain PIEH') 37 | 38 | w = np.fromfile(f, np.int32, 1).squeeze() 39 | h = np.fromfile(f, np.int32, 1).squeeze() 40 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) 41 | 42 | return flow.astype(np.float32) 43 | 44 | 45 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): 46 | """Write optical flow to file. 47 | 48 | If the flow is not quantized, it will be saved as a .flo file losslessly, 49 | otherwise a jpeg image which is lossy but of much smaller size. (dx and dy 50 | will be concatenated horizontally into a single image if quantize is True.) 51 | 52 | Args: 53 | flow (ndarray): (h, w, 2) array of optical flow. 54 | filename (str): Output filepath. 
55 |         quantize (bool): Whether to quantize the flow and save it to 2 jpeg 56 |             images. If set to True, remaining args will be passed to 57 |             :func:`quantize_flow`. 58 |         concat_axis (int): The axis that dx and dy are concatenated, 59 |             can be either 0 or 1. Ignored if quantize is False. 60 |     """ 61 |     if not quantize: 62 |         with open(filename, 'wb') as f: 63 |             f.write('PIEH'.encode('utf-8')) 64 |             np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) 65 |             flow = flow.astype(np.float32) 66 |             flow.tofile(f) 67 |             f.flush() 68 |     else: 69 |         assert concat_axis in [0, 1] 70 |         dx, dy = quantize_flow(flow, *args, **kwargs) 71 |         dxdy = np.concatenate((dx, dy), axis=concat_axis) 72 |         os.makedirs(os.path.dirname(filename), exist_ok=True) 73 |         cv2.imwrite(filename, dxdy) 74 | 75 | 76 | def quantize_flow(flow, max_val=0.02, norm=True): 77 |     """Quantize flow to [0, 255]. 78 | 79 |     After this step, the size of flow will be much smaller, and can be 80 |     dumped as jpeg images. 81 | 82 |     Args: 83 |         flow (ndarray): (h, w, 2) array of optical flow. 84 |         max_val (float): Maximum value of flow, values beyond 85 |             [-max_val, max_val] will be truncated. 86 |         norm (bool): Whether to divide flow values by image width/height. 87 | 88 |     Returns: 89 |         tuple[ndarray]: Quantized dx and dy. 90 |     """ 91 |     h, w, _ = flow.shape 92 |     dx = flow[..., 0] 93 |     dy = flow[..., 1] 94 |     if norm: 95 |         dx = dx / w  # avoid inplace operations 96 |         dy = dy / h 97 |     # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. 98 |     flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]] 99 |     return tuple(flow_comps) 100 | 101 | 102 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True): 103 |     """Recover from quantized flow. 104 | 105 |     Args: 106 |         dx (ndarray): Quantized dx. 107 |         dy (ndarray): Quantized dy. 108 |         max_val (float): Maximum value used when quantizing. 109 |         denorm (bool): Whether to multiply flow values with width/height. 110 | 111 |     Returns: 112 |         ndarray: Dequantized flow. 113 |     """ 114 |     assert dx.shape == dy.shape 115 |     assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) 116 | 117 |     dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] 118 | 119 |     if denorm: 120 |         dx *= dx.shape[1] 121 |         dy *= dy.shape[0]  # dx and dy have the same shape (asserted above) 122 |     flow = np.dstack((dx, dy)) 123 |     return flow 124 | 125 | 126 | def quantize(arr, min_val, max_val, levels, dtype=np.int64): 127 |     """Quantize an array of (-inf, inf) to [0, levels-1]. 128 | 129 |     Args: 130 |         arr (ndarray): Input array. 131 |         min_val (scalar): Minimum value to be clipped. 132 |         max_val (scalar): Maximum value to be clipped. 133 |         levels (int): Quantization levels. 134 |         dtype (np.type): The type of the quantized array. 135 | 136 |     Returns: 137 |         ndarray: Quantized array. 138 |     """ 139 |     if not (isinstance(levels, int) and levels > 1): 140 |         raise ValueError(f'levels must be a positive integer, but got {levels}') 141 |     if min_val >= max_val: 142 |         raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') 143 | 144 |     arr = np.clip(arr, min_val, max_val) - min_val 145 |     quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) 146 | 147 |     return quantized_arr 148 | 149 | 150 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64): 151 |     """Dequantize an array. 152 | 153 |     Args: 154 |         arr (ndarray): Input array. 155 |         min_val (scalar): Minimum value to be clipped. 156 |         max_val (scalar): Maximum value to be clipped. 157 |         levels (int): Quantization levels. 
158 |         dtype (np.type): The type of the dequantized array. 159 | 160 |     Returns: 161 |         ndarray: Dequantized array. 162 |     """ 163 |     if not (isinstance(levels, int) and levels > 1): 164 |         raise ValueError(f'levels must be a positive integer, but got {levels}') 165 |     if min_val >= max_val: 166 |         raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') 167 | 168 |     dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val 169 | 170 |     return dequantized_arr 171 | -------------------------------------------------------------------------------- /basicsr/utils/img_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def img2tensor(imgs, bgr2rgb=True, float32=True): 10 |     """Numpy array to tensor. 11 | 12 |     Args: 13 |         imgs (list[ndarray] | ndarray): Input images. 14 |         bgr2rgb (bool): Whether to change bgr to rgb. 15 |         float32 (bool): Whether to change to float32. 16 | 17 |     Returns: 18 |         list[tensor] | tensor: Tensor images. If returned results only have 19 |             one element, just return tensor. 20 |     """ 21 | 22 |     def _totensor(img, bgr2rgb, float32): 23 |         if img.shape[2] == 3 and bgr2rgb: 24 |             if img.dtype == 'float64': 25 |                 img = img.astype('float32') 26 |             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 27 |         img = torch.from_numpy(img.transpose(2, 0, 1)) 28 |         if float32: 29 |             img = img.float() 30 |         return img 31 | 32 |     if isinstance(imgs, list): 33 |         return [_totensor(img, bgr2rgb, float32) for img in imgs] 34 |     else: 35 |         return _totensor(imgs, bgr2rgb, float32) 36 | 37 | 38 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 39 |     """Convert torch Tensors into image numpy arrays. 40 | 41 |     After clamping to [min, max], values will be normalized to [0, 1]. 42 | 43 |     Args: 44 |         tensor (Tensor or list[Tensor]): Accept shapes: 45 |             1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); 46 |             2) 3D Tensor of shape (3/1 x H x W); 47 |             3) 2D Tensor of shape (H x W). 48 |             Tensor channel should be in RGB order. 49 |         rgb2bgr (bool): Whether to change rgb to bgr. 50 |         out_type (numpy type): output types. If ``np.uint8``, transform outputs 51 |             to uint8 type with range [0, 255]; otherwise, float type with 52 |             range [0, 1]. Default: ``np.uint8``. 53 |         min_max (tuple[int]): min and max values for clamp. 54 | 55 |     Returns: 56 |         (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D ndarray of 57 |             shape (H x W). The channel order is BGR. 
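    Example:
        A minimal sketch with a random tensor (shapes are illustrative):

        >>> import torch
        >>> t = torch.rand(1, 3, 8, 8)  # (B, C, H, W), RGB, range [0, 1]
        >>> img = tensor2img(t)  # ndarray of shape (8, 8, 3), BGR, uint8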
58 | """ 59 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): 60 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') 61 | 62 | if torch.is_tensor(tensor): 63 | tensor = [tensor] 64 | result = [] 65 | for _tensor in tensor: 66 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 67 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 68 | 69 | n_dim = _tensor.dim() 70 | if n_dim == 4: 71 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() 72 | img_np = img_np.transpose(1, 2, 0) 73 | if rgb2bgr: 74 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 75 | elif n_dim == 3: 76 | img_np = _tensor.numpy() 77 | img_np = img_np.transpose(1, 2, 0) 78 | if img_np.shape[2] == 1: # gray image 79 | img_np = np.squeeze(img_np, axis=2) 80 | else: 81 | if rgb2bgr: 82 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 83 | elif n_dim == 2: 84 | img_np = _tensor.numpy() 85 | else: 86 | raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') 87 | if out_type == np.uint8: 88 | # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 89 | img_np = (img_np * 255.0).round() 90 | img_np = img_np.astype(out_type) 91 | result.append(img_np) 92 | if len(result) == 1: 93 | result = result[0] 94 | return result 95 | 96 | 97 | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): 98 | """This implementation is slightly faster than tensor2img. 99 | It now only supports torch tensor with shape (1, c, h, w). 100 | 101 | Args: 102 | tensor (Tensor): Now only support torch tensor with (1, c, h, w). 103 | rgb2bgr (bool): Whether to change rgb to bgr. Default: True. 104 | min_max (tuple[int]): min and max values for clamp. 105 | """ 106 | output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) 107 | output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 108 | output = output.type(torch.uint8).cpu().numpy() 109 | if rgb2bgr: 110 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 111 | return output 112 | 113 | 114 | def imfrombytes(content, flag='color', float32=False): 115 | """Read an image from bytes. 116 | 117 | Args: 118 | content (bytes): Image bytes got from files or other streams. 119 | flag (str): Flags specifying the color type of a loaded image, 120 | candidates are `color`, `grayscale` and `unchanged`. 121 | float32 (bool): Whether to change to float32., If True, will also norm 122 | to [0, 1]. Default: False. 123 | 124 | Returns: 125 | ndarray: Loaded image array. 126 | """ 127 | img_np = np.frombuffer(content, np.uint8) 128 | imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} 129 | img = cv2.imdecode(img_np, imread_flags[flag]) 130 | if float32: 131 | img = img.astype(np.float32) / 255. 132 | return img 133 | 134 | 135 | def imwrite(img, file_path, params=None, auto_mkdir=True): 136 | """Write image to file. 137 | 138 | Args: 139 | img (ndarray): Image array to be written. 140 | file_path (str): Image file path. 141 | params (None or list): Same as opencv's :func:`imwrite` interface. 142 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 143 | whether to create it automatically. 144 | 145 | Returns: 146 | bool: Successful or not. 
147 | """ 148 | if auto_mkdir: 149 | dir_name = os.path.abspath(os.path.dirname(file_path)) 150 | os.makedirs(dir_name, exist_ok=True) 151 | ok = cv2.imwrite(file_path, img, params) 152 | if not ok: 153 | raise IOError('Failed in writing images.') 154 | 155 | 156 | def crop_border(imgs, crop_border): 157 | """Crop borders of images. 158 | 159 | Args: 160 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 161 | crop_border (int): Crop border for each end of height and weight. 162 | 163 | Returns: 164 | list[ndarray]: Cropped images. 165 | """ 166 | if crop_border == 0: 167 | return imgs 168 | else: 169 | if isinstance(imgs, list): 170 | return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] 171 | else: 172 | return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] 173 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 2 | 3 | import os 4 | import torch 5 | from torch.autograd import Function 6 | from torch.nn import functional as F 7 | 8 | BASICSR_JIT = os.getenv('BASICSR_JIT') 9 | if BASICSR_JIT == 'True': 10 | from torch.utils.cpp_extension import load 11 | module_path = os.path.dirname(__file__) 12 | upfirdn2d_ext = load( 13 | 'upfirdn2d', 14 | sources=[ 15 | os.path.join(module_path, 'src', 'upfirdn2d.cpp'), 16 | os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'), 17 | ], 18 | ) 19 | else: 20 | try: 21 | from . import upfirdn2d_ext 22 | except ImportError: 23 | pass 24 | # avoid annoying print output 25 | # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' 26 | # '1. compile with BASICSR_EXT=True. or\n ' 27 | # '2. 
set BASICSR_JIT=True during running') 28 | 29 | 30 | class UpFirDn2dBackward(Function): 31 | 32 | @staticmethod 33 | def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): 34 | 35 | up_x, up_y = up 36 | down_x, down_y = down 37 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 38 | 39 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 40 | 41 | grad_input = upfirdn2d_ext.upfirdn2d( 42 | grad_output, 43 | grad_kernel, 44 | down_x, 45 | down_y, 46 | up_x, 47 | up_y, 48 | g_pad_x0, 49 | g_pad_x1, 50 | g_pad_y0, 51 | g_pad_y1, 52 | ) 53 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 54 | 55 | ctx.save_for_backward(kernel) 56 | 57 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 58 | 59 | ctx.up_x = up_x 60 | ctx.up_y = up_y 61 | ctx.down_x = down_x 62 | ctx.down_y = down_y 63 | ctx.pad_x0 = pad_x0 64 | ctx.pad_x1 = pad_x1 65 | ctx.pad_y0 = pad_y0 66 | ctx.pad_y1 = pad_y1 67 | ctx.in_size = in_size 68 | ctx.out_size = out_size 69 | 70 | return grad_input 71 | 72 | @staticmethod 73 | def backward(ctx, gradgrad_input): 74 | kernel, = ctx.saved_tensors 75 | 76 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 77 | 78 | gradgrad_out = upfirdn2d_ext.upfirdn2d( 79 | gradgrad_input, 80 | kernel, 81 | ctx.up_x, 82 | ctx.up_y, 83 | ctx.down_x, 84 | ctx.down_y, 85 | ctx.pad_x0, 86 | ctx.pad_x1, 87 | ctx.pad_y0, 88 | ctx.pad_y1, 89 | ) 90 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], 91 | # ctx.out_size[1], ctx.in_size[3]) 92 | gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) 93 | 94 | return gradgrad_out, None, None, None, None, None, None, None, None 95 | 96 | 97 | class UpFirDn2d(Function): 98 | 99 | @staticmethod 100 | def forward(ctx, input, kernel, up, down, pad): 101 | up_x, up_y = up 102 | down_x, down_y = down 103 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 104 | 105 | kernel_h, kernel_w = kernel.shape 106 | _, channel, in_h, in_w = input.shape 107 | ctx.in_size = input.shape 108 | 109 | input = input.reshape(-1, in_h, in_w, 1) 110 | 111 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 112 | 113 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 114 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 115 | ctx.out_size = (out_h, out_w) 116 | 117 | ctx.up = (up_x, up_y) 118 | ctx.down = (down_x, down_y) 119 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 120 | 121 | g_pad_x0 = kernel_w - pad_x0 - 1 122 | g_pad_y0 = kernel_h - pad_y0 - 1 123 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 124 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 125 | 126 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 127 | 128 | out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) 129 | # out = out.view(major, out_h, out_w, minor) 130 | out = out.view(-1, channel, out_h, out_w) 131 | 132 | return out 133 | 134 | @staticmethod 135 | def backward(ctx, grad_output): 136 | kernel, grad_kernel = ctx.saved_tensors 137 | 138 | grad_input = UpFirDn2dBackward.apply( 139 | grad_output, 140 | kernel, 141 | grad_kernel, 142 | ctx.up, 143 | ctx.down, 144 | ctx.pad, 145 | ctx.g_pad, 146 | ctx.in_size, 147 | ctx.out_size, 148 | ) 149 | 150 | return grad_input, None, None, None, None 151 | 152 | 153 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 154 | if input.device.type == 'cpu': 155 | out = upfirdn2d_native(input, kernel, 
up, up, down, down, pad[0], pad[1], pad[0], pad[1]) 156 | else: 157 | out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) 158 | 159 | return out 160 | 161 | 162 | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): 163 | _, channel, in_h, in_w = input.shape 164 | input = input.reshape(-1, in_h, in_w, 1) 165 | 166 | _, in_h, in_w, minor = input.shape 167 | kernel_h, kernel_w = kernel.shape 168 | 169 | out = input.view(-1, in_h, 1, in_w, 1, minor) 170 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 171 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 172 | 173 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) 174 | out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] 175 | 176 | out = out.permute(0, 3, 1, 2) 177 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) 178 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 179 | out = F.conv2d(out, w) 180 | out = out.reshape( 181 | -1, 182 | minor, 183 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 184 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 185 | ) 186 | out = out.permute(0, 2, 3, 1) 187 | out = out[:, ::down_y, ::down_x, :] 188 | 189 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 190 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 191 | 192 | return out.view(-1, channel, out_h, out_w) 193 | -------------------------------------------------------------------------------- /basicsr/utils/options-DESKTOP-S7HM52K.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import torch 4 | import yaml 5 | from collections import OrderedDict 6 | from os import path as osp 7 | 8 | from basicsr.utils import set_random_seed 9 | from basicsr.utils.dist_util import get_dist_info, init_dist, master_only 10 | 11 | 12 | def ordered_yaml(): 13 | """Support OrderedDict for yaml. 14 | 15 | Returns: 16 | yaml Loader and Dumper. 17 | """ 18 | try: 19 | from yaml import CDumper as Dumper 20 | from yaml import CLoader as Loader 21 | except ImportError: 22 | from yaml import Dumper, Loader 23 | 24 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 25 | 26 | def dict_representer(dumper, data): 27 | return dumper.represent_dict(data.items()) 28 | 29 | def dict_constructor(loader, node): 30 | return OrderedDict(loader.construct_pairs(node)) 31 | 32 | Dumper.add_representer(OrderedDict, dict_representer) 33 | Loader.add_constructor(_mapping_tag, dict_constructor) 34 | return Loader, Dumper 35 | 36 | 37 | def dict2str(opt, indent_level=1): 38 | """dict to string for printing options. 39 | 40 | Args: 41 | opt (dict): Option dict. 42 | indent_level (int): Indent level. Default: 1. 43 | 44 | Return: 45 | (str): Option string for printing. 
46 | """ 47 | msg = '\n' 48 | for k, v in opt.items(): 49 | if isinstance(v, dict): 50 | msg += ' ' * (indent_level * 2) + k + ':[' 51 | msg += dict2str(v, indent_level + 1) 52 | msg += ' ' * (indent_level * 2) + ']\n' 53 | else: 54 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 55 | return msg 56 | 57 | 58 | def _postprocess_yml_value(value): 59 | # None 60 | if value == '~' or value.lower() == 'none': 61 | return None 62 | # bool 63 | if value.lower() == 'true': 64 | return True 65 | elif value.lower() == 'false': 66 | return False 67 | # !!float number 68 | if value.startswith('!!float'): 69 | return float(value.replace('!!float', '')) 70 | # number 71 | if value.isdigit(): 72 | return int(value) 73 | elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: 74 | return float(value) 75 | # list 76 | if value.startswith('['): 77 | return eval(value) 78 | # str 79 | return value 80 | 81 | 82 | def parse_options(root_path, is_train=True): 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument('--opt', default="/home/zhiyi/pycharm-tmp/basicsr/options/test/RCAN/train_RCAN_x4.yml", type=str, required=False, help='Path to option YAML file.') 85 | parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') 86 | parser.add_argument('--auto_resume', action='store_true') 87 | parser.add_argument('--debug', action='store_true') 88 | parser.add_argument('--local_rank', type=int, default=0) 89 | parser.add_argument( 90 | '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999') 91 | args = parser.parse_args() 92 | 93 | # parse yml to dict 94 | 95 | with open(args.opt, mode='r') as f: 96 | opt = yaml.load(f, Loader=ordered_yaml()[0]) 97 | 98 | # distributed settings 99 | if args.launcher == 'none': 100 | opt['dist'] = False 101 | print('Disable distributed.', flush=True) 102 | else: 103 | opt['dist'] = True 104 | if args.launcher == 'slurm' and 'dist_params' in opt: 105 | init_dist(args.launcher, **opt['dist_params']) 106 | else: 107 | init_dist(args.launcher) 108 | opt['rank'], opt['world_size'] = get_dist_info() 109 | 110 | # random seed 111 | seed = opt.get('manual_seed') 112 | if seed is None: 113 | seed = random.randint(1, 10000) 114 | opt['manual_seed'] = seed 115 | set_random_seed(seed + opt['rank']) 116 | 117 | # force to update yml options 118 | if args.force_yml is not None: 119 | for entry in args.force_yml: 120 | # now do not support creating new keys 121 | keys, value = entry.split('=') 122 | keys, value = keys.strip(), value.strip() 123 | value = _postprocess_yml_value(value) 124 | eval_str = 'opt' 125 | for key in keys.split(':'): 126 | eval_str += f'["{key}"]' 127 | eval_str += '=value' 128 | # using exec function 129 | exec(eval_str) 130 | 131 | opt['auto_resume'] = args.auto_resume 132 | opt['is_train'] = is_train 133 | 134 | # debug setting 135 | if args.debug and not opt['name'].startswith('debug'): 136 | opt['name'] = 'debug_' + opt['name'] 137 | 138 | if opt['num_gpu'] == 'auto': 139 | opt['num_gpu'] = torch.cuda.device_count() 140 | 141 | # datasets 142 | for phase, dataset in opt['datasets'].items(): 143 | # for multiple datasets, e.g., val_1, val_2; test_1, test_2 144 | phase = phase.split('_')[0] 145 | dataset['phase'] = phase 146 | if 'scale' in opt: 147 | dataset['scale'] = opt['scale'] 148 | if dataset.get('dataroot_gt') is not None: 149 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 150 | if dataset.get('dataroot_lq') is not 
None: 151 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 152 | 153 | # paths 154 | for key, val in opt['path'].items(): 155 | if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): 156 | opt['path'][key] = osp.expanduser(val) 157 | 158 | if is_train: 159 | experiments_root = osp.join(root_path, 'experiments', opt['name']) 160 | opt['path']['experiments_root'] = experiments_root 161 | opt['path']['models'] = osp.join(experiments_root, 'models') 162 | opt['path']['training_states'] = osp.join(experiments_root, 'training_states') 163 | opt['path']['log'] = experiments_root 164 | opt['path']['visualization'] = osp.join(experiments_root, 'visualization') 165 | 166 | # change some options for debug mode 167 | if 'debug' in opt['name']: 168 | if 'val' in opt: 169 | opt['val']['val_freq'] = 8 170 | opt['logger']['print_freq'] = 1 171 | opt['logger']['save_checkpoint_freq'] = 8 172 | else: # test 173 | results_root = osp.join(root_path, 'results', opt['name']) 174 | opt['path']['results_root'] = results_root 175 | opt['path']['log'] = results_root 176 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 177 | 178 | return opt, args 179 | 180 | 181 | @master_only 182 | def copy_opt_file(opt_file, experiments_root): 183 | # copy the yml file to the experiment root 184 | import sys 185 | import time 186 | from shutil import copyfile 187 | cmd = ' '.join(sys.argv) 188 | filename = osp.join(experiments_root, osp.basename(opt_file)) 189 | copyfile(opt_file, filename) 190 | 191 | with open(filename, 'r+') as f: 192 | lines = f.readlines() 193 | lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n') 194 | f.seek(0) 195 | f.writelines(lines) 196 | -------------------------------------------------------------------------------- /basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import torch 4 | import yaml 5 | from collections import OrderedDict 6 | from os import path as osp 7 | 8 | from basicsr.utils import set_random_seed 9 | from basicsr.utils.dist_util import get_dist_info, init_dist, master_only 10 | 11 | 12 | def ordered_yaml(): 13 | """Support OrderedDict for yaml. 14 | 15 | Returns: 16 | yaml Loader and Dumper. 17 | """ 18 | try: 19 | from yaml import CDumper as Dumper 20 | from yaml import CLoader as Loader 21 | except ImportError: 22 | from yaml import Dumper, Loader 23 | 24 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 25 | 26 | def dict_representer(dumper, data): 27 | return dumper.represent_dict(data.items()) 28 | 29 | def dict_constructor(loader, node): 30 | return OrderedDict(loader.construct_pairs(node)) 31 | 32 | Dumper.add_representer(OrderedDict, dict_representer) 33 | Loader.add_constructor(_mapping_tag, dict_constructor) 34 | return Loader, Dumper 35 | 36 | 37 | def dict2str(opt, indent_level=1): 38 | """dict to string for printing options. 39 | 40 | Args: 41 | opt (dict): Option dict. 42 | indent_level (int): Indent level. Default: 1. 43 | 44 | Return: 45 | (str): Option string for printing. 
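    Example:
        An illustrative call:

        >>> print(dict2str({'scale': 4, 'path': {'root': './'}}))

        produces roughly:

          scale: 4
          path:[
            root: ./
          ]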
46 | """ 47 | msg = '\n' 48 | for k, v in opt.items(): 49 | if isinstance(v, dict): 50 | msg += ' ' * (indent_level * 2) + k + ':[' 51 | msg += dict2str(v, indent_level + 1) 52 | msg += ' ' * (indent_level * 2) + ']\n' 53 | else: 54 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 55 | return msg 56 | 57 | 58 | def _postprocess_yml_value(value): 59 | # None 60 | if value == '~' or value.lower() == 'none': 61 | return None 62 | # bool 63 | if value.lower() == 'true': 64 | return True 65 | elif value.lower() == 'false': 66 | return False 67 | # !!float number 68 | if value.startswith('!!float'): 69 | return float(value.replace('!!float', '')) 70 | # number 71 | if value.isdigit(): 72 | return int(value) 73 | elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: 74 | return float(value) 75 | # list 76 | if value.startswith('['): 77 | return eval(value) 78 | # str 79 | return value 80 | 81 | 82 | def parse_options(root_path, is_train=True): 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument('--opt', default="/home/zhiyi/pycharm-tmp/basicsr/options/train/RCAN/train_RCAN_x3.yml", type=str, required=False, help='Path to option YAML file.') 85 | parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') 86 | parser.add_argument('--auto_resume', action='store_true') 87 | parser.add_argument('--debug', action='store_true') 88 | parser.add_argument('--local_rank', type=int, default=0) 89 | parser.add_argument( 90 | '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999') 91 | args = parser.parse_args() 92 | 93 | # parse yml to dict 94 | 95 | with open(args.opt, mode='r') as f: 96 | opt = yaml.load(f, Loader=ordered_yaml()[0]) 97 | 98 | # distributed settings 99 | if args.launcher == 'none': 100 | opt['dist'] = False 101 | print('Disable distributed.', flush=True) 102 | else: 103 | opt['dist'] = True 104 | if args.launcher == 'slurm' and 'dist_params' in opt: 105 | init_dist(args.launcher, **opt['dist_params']) 106 | else: 107 | init_dist(args.launcher) 108 | opt['rank'], opt['world_size'] = get_dist_info() # 与分布式训练有关的参数 109 | 110 | # random seed 111 | seed = opt.get('manual_seed') 112 | if seed is None: 113 | seed = random.randint(1, 10000) 114 | opt['manual_seed'] = seed 115 | set_random_seed(seed + opt['rank']) 116 | 117 | # force to update yml options 118 | if args.force_yml is not None: 119 | for entry in args.force_yml: 120 | # now do not support creating new keys 121 | keys, value = entry.split('=') 122 | keys, value = keys.strip(), value.strip() 123 | value = _postprocess_yml_value(value) 124 | eval_str = 'opt' 125 | for key in keys.split(':'): 126 | eval_str += f'["{key}"]' 127 | eval_str += '=value' 128 | # using exec function 129 | exec(eval_str) 130 | 131 | opt['auto_resume'] = args.auto_resume 132 | opt['is_train'] = is_train 133 | 134 | # debug setting 135 | if args.debug and not opt['name'].startswith('debug'): 136 | opt['name'] = 'debug_' + opt['name'] 137 | 138 | if opt['num_gpu'] == 'auto': 139 | opt['num_gpu'] = torch.cuda.device_count() 140 | 141 | # datasets 142 | for phase, dataset in opt['datasets'].items(): 143 | # for multiple datasets, e.g., val_1, val_2; test_1, test_2 144 | phase = phase.split('_')[0] 145 | dataset['phase'] = phase 146 | if 'scale' in opt: 147 | dataset['scale'] = opt['scale'] 148 | if dataset.get('dataroot_gt') is not None: 149 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 150 | if 
dataset.get('dataroot_lq') is not None: 151 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 152 | 153 | # paths 154 | for key, val in opt['path'].items(): 155 | if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): 156 | opt['path'][key] = osp.expanduser(val) 157 | 158 | if is_train: 159 | experiments_root = osp.join(root_path, 'experiments', opt['name']) 160 | opt['path']['experiments_root'] = experiments_root 161 | opt['path']['models'] = osp.join(experiments_root, 'models') 162 | opt['path']['training_states'] = osp.join(experiments_root, 'training_states') 163 | opt['path']['log'] = experiments_root 164 | opt['path']['visualization'] = osp.join(experiments_root, 'visualization') 165 | 166 | # change some options for debug mode 167 | if 'debug' in opt['name']: 168 | if 'val' in opt: 169 | opt['val']['val_freq'] = 8 170 | opt['logger']['print_freq'] = 1 171 | opt['logger']['save_checkpoint_freq'] = 8 172 | else: # test 173 | results_root = osp.join(root_path, 'results', opt['name']) 174 | opt['path']['results_root'] = results_root 175 | opt['path']['log'] = results_root 176 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 177 | 178 | return opt, args 179 | 180 | 181 | @master_only 182 | def copy_opt_file(opt_file, experiments_root): 183 | # copy the yml file to the experiment root 184 | import sys 185 | import time 186 | from shutil import copyfile 187 | cmd = ' '.join(sys.argv) 188 | filename = osp.join(experiments_root, osp.basename(opt_file)) 189 | copyfile(opt_file, filename) 190 | 191 | with open(filename, 'r+') as f: 192 | lines = f.readlines() 193 | lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n') 194 | f.seek(0) 195 | f.writelines(lines) 196 | -------------------------------------------------------------------------------- /basicsr/utils/lmdb_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import lmdb 3 | import sys 4 | from multiprocessing import Pool 5 | from os import path as osp 6 | from tqdm import tqdm 7 | 8 | 9 | def make_lmdb_from_imgs(data_path, 10 | lmdb_path, 11 | img_path_list, 12 | keys, 13 | batch=5000, 14 | compress_level=1, 15 | multiprocessing_read=False, 16 | n_thread=40, 17 | map_size=None): 18 | """Make lmdb from images. 19 | 20 | Contents of lmdb. The file structure is: 21 | example.lmdb 22 | ├── data.mdb 23 | ├── lock.mdb 24 | ├── meta_info.txt 25 | 26 | The data.mdb and lock.mdb are standard lmdb files and you can refer to 27 | https://lmdb.readthedocs.io/en/release/ for more details. 28 | 29 | The meta_info.txt is a specified txt file to record the meta information 30 | of our datasets. It will be automatically created when preparing 31 | datasets by our provided dataset tools. 32 | Each line in the txt file records 1)image name (with extension), 33 | 2)image shape, and 3)compression level, separated by a white space. 34 | 35 | For example, the meta information could be: 36 | `000_00000000.png (720,1280,3) 1`, which means: 37 | 1) image name (with extension): 000_00000000.png; 38 | 2) image shape: (720,1280,3); 39 | 3) compression level: 1 40 | 41 | We use the image name without extension as the lmdb key. 42 | 43 | If `multiprocessing_read` is True, it will read all the images to memory 44 | using multiprocessing. Thus, your server needs to have enough memory. 45 | 46 | Args: 47 | data_path (str): Data path for reading images. 48 | lmdb_path (str): Lmdb save path. 
49 |         img_path_list (list[str]): Image path list.
50 |         keys (list[str]): Used for lmdb keys.
51 |         batch (int): After processing batch images, lmdb commits.
52 |             Default: 5000.
53 |         compress_level (int): Compress level when encoding images. Default: 1.
54 |         multiprocessing_read (bool): Whether use multiprocessing to read all
55 |             the images to memory. Default: False.
56 |         n_thread (int): Thread number for multiprocessing.
57 |         map_size (int | None): Map size for lmdb env. If None, use the
58 |             estimated size from images. Default: None
59 |     """
60 |
61 |     assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
62 |                                              f'but got {len(img_path_list)} and {len(keys)}')
63 |     print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
64 |     print(f'Total images: {len(img_path_list)}')
65 |     if not lmdb_path.endswith('.lmdb'):
66 |         raise ValueError("lmdb_path must end with '.lmdb'.")
67 |     if osp.exists(lmdb_path):
68 |         print(f'Folder {lmdb_path} already exists. Exit.')
69 |         sys.exit(1)
70 |
71 |     if multiprocessing_read:
72 |         # read all the images to memory (multiprocessing)
73 |         dataset = {}  # use dict to keep the order for multiprocessing
74 |         shapes = {}
75 |         print(f'Read images with multiprocessing, #thread: {n_thread} ...')
76 |         pbar = tqdm(total=len(img_path_list), unit='image')
77 |
78 |         def callback(arg):
79 |             """get the image data and update pbar."""
80 |             key, dataset[key], shapes[key] = arg
81 |             pbar.update(1)
82 |             pbar.set_description(f'Read {key}')
83 |
84 |         pool = Pool(n_thread)
85 |         for path, key in zip(img_path_list, keys):
86 |             pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
87 |         pool.close()
88 |         pool.join()
89 |         pbar.close()
90 |         print(f'Finish reading {len(img_path_list)} images.')
91 |
92 |     # create lmdb environment
93 |     if map_size is None:
94 |         # obtain data size for one image
95 |         img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
96 |         _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
97 |         data_size_per_img = img_byte.nbytes
98 |         print('Data size per image is: ', data_size_per_img)
99 |         data_size = data_size_per_img * len(img_path_list)
100 |         map_size = data_size * 10
101 |
102 |     env = lmdb.open(lmdb_path, map_size=map_size)
103 |
104 |     # write data to lmdb
105 |     pbar = tqdm(total=len(img_path_list), unit='chunk')
106 |     txn = env.begin(write=True)
107 |     txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
108 |     for idx, (path, key) in enumerate(zip(img_path_list, keys)):
109 |         pbar.update(1)
110 |         pbar.set_description(f'Write {key}')
111 |         key_byte = key.encode('ascii')
112 |         if multiprocessing_read:
113 |             img_byte = dataset[key]
114 |             h, w, c = shapes[key]
115 |         else:
116 |             _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
117 |             h, w, c = img_shape
118 |
119 |         txn.put(key_byte, img_byte)
120 |         # write meta information
121 |         txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
122 |         if idx % batch == 0:
123 |             txn.commit()
124 |             txn = env.begin(write=True)
125 |     pbar.close()
126 |     txn.commit()
127 |     env.close()
128 |     txt_file.close()
129 |     print('\nFinish writing lmdb.')
130 |
131 |
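# A minimal usage sketch of make_lmdb_from_imgs with hypothetical paths and keys:
#
#     img_path_list = ['0001.png', '0002.png']
#     keys = ['0001', '0002']  # image names without extension, used as lmdb keys
#     make_lmdb_from_imgs('datasets/gt', 'datasets/gt.lmdb', img_path_list, keys)
#
# This produces datasets/gt.lmdb/{data.mdb, lock.mdb, meta_info.txt}, where each meta
# line looks like '0001.png (720,1280,3) 1' as described in the docstring above.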
132 | def read_img_worker(path, key, compress_level):
133 |     """Read image worker.
134 |
135 |     Args:
136 |         path (str): Image path.
137 |         key (str): Image key.
138 |         compress_level (int): Compress level when encoding images.
139 |
140 |     Returns:
141 |         str: Image key.
142 |         byte: Image byte.
143 |         tuple[int]: Image shape.
144 |     """
145 |
146 |     img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
147 |     if img.ndim == 2:
148 |         h, w = img.shape
149 |         c = 1
150 |     else:
151 |         h, w, c = img.shape
152 |     _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
153 |     return (key, img_byte, (h, w, c))
154 |
155 |
156 | class LmdbMaker():
157 |     """LMDB Maker.
158 |
159 |     Args:
160 |         lmdb_path (str): Lmdb save path.
161 |         map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
162 |         batch (int): After processing batch images, lmdb commits.
163 |             Default: 5000.
164 |         compress_level (int): Compress level when encoding images. Default: 1.
165 |     """
166 |
167 |     def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
168 |         if not lmdb_path.endswith('.lmdb'):
169 |             raise ValueError("lmdb_path must end with '.lmdb'.")
170 |         if osp.exists(lmdb_path):
171 |             print(f'Folder {lmdb_path} already exists. Exit.')
172 |             sys.exit(1)
173 |
174 |         self.lmdb_path = lmdb_path
175 |         self.batch = batch
176 |         self.compress_level = compress_level
177 |         self.env = lmdb.open(lmdb_path, map_size=map_size)
178 |         self.txn = self.env.begin(write=True)
179 |         self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
180 |         self.counter = 0
181 |
182 |     def put(self, img_byte, key, img_shape):
183 |         self.counter += 1
184 |         key_byte = key.encode('ascii')
185 |         self.txn.put(key_byte, img_byte)
186 |         # write meta information
187 |         h, w, c = img_shape
188 |         self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
189 |         if self.counter % self.batch == 0:
190 |             self.txn.commit()
191 |             self.txn = self.env.begin(write=True)
192 |
193 |     def close(self):
194 |         self.txn.commit()
195 |         self.env.close()
196 |         self.txt_file.close()
197 |
--------------------------------------------------------------------------------
/basicsr/utils/logger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import time
4 |
5 | from .dist_util import get_dist_info, master_only
6 |
7 | initialized_logger = {}
8 |
9 |
10 | class AvgTimer():
11 |
12 |     def __init__(self, window=200):
13 |         self.window = window  # average window
14 |         self.current_time = 0
15 |         self.total_time = 0
16 |         self.count = 0
17 |         self.avg_time = 0
18 |         self.start()
19 |
20 |     def start(self):
21 |         self.start_time = time.time()
22 |
23 |     def record(self):
24 |         self.count += 1
25 |         self.current_time = time.time() - self.start_time
26 |         self.total_time += self.current_time
27 |         # calculate average time
28 |         self.avg_time = self.total_time / self.count
29 |         # reset
30 |         if self.count > self.window:
31 |             self.count = 0
32 |             self.total_time = 0
33 |
34 |     def get_current_time(self):
35 |         return self.current_time
36 |
37 |     def get_avg_time(self):
38 |         return self.avg_time
39 |
40 |
41 | class MessageLogger():
42 |     """Message logger for printing.
43 |
44 |     Args:
45 |         opt (dict): Config. It contains the following keys:
46 |             name (str): Exp name.
47 |             logger (dict): Contains 'print_freq' (int) for logger interval.
48 |             train (dict): Contains 'total_iter' (int) for total iters.
49 |             use_tb_logger (bool): Use tensorboard logger.
50 |         start_iter (int): Start iter. Default: 1.
51 |         tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
52 | """ 53 | 54 | def __init__(self, opt, start_iter=1, tb_logger=None): 55 | self.exp_name = opt['name'] 56 | self.interval = opt['logger']['print_freq'] 57 | self.start_iter = start_iter 58 | self.max_iters = opt['train']['total_iter'] 59 | self.use_tb_logger = opt['logger']['use_tb_logger'] 60 | self.tb_logger = tb_logger 61 | self.start_time = time.time() 62 | self.logger = get_root_logger() 63 | 64 | def reset_start_time(self): 65 | self.start_time = time.time() 66 | 67 | @master_only 68 | def __call__(self, log_vars): 69 | """Format logging message. 70 | 71 | Args: 72 | log_vars (dict): It contains the following keys: 73 | epoch (int): Epoch number. 74 | iter (int): Current iter. 75 | lrs (list): List for learning rates. 76 | 77 | time (float): Iter time. 78 | data_time (float): Data time for each iter. 79 | """ 80 | # epoch, iter, learning rates 81 | epoch = log_vars.pop('epoch') 82 | current_iter = log_vars.pop('iter') 83 | lrs = log_vars.pop('lrs') 84 | 85 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(') 86 | for v in lrs: 87 | message += f'{v:.3e},' 88 | message += ')] ' 89 | 90 | # time and estimated time 91 | if 'time' in log_vars.keys(): 92 | iter_time = log_vars.pop('time') 93 | data_time = log_vars.pop('data_time') 94 | 95 | total_time = time.time() - self.start_time 96 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 97 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 98 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 99 | message += f'[eta: {eta_str}, ' 100 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' 101 | 102 | # other items, especially losses 103 | for k, v in log_vars.items(): 104 | message += f'{k}: {v:.4e} ' 105 | # tensorboard logger 106 | if self.use_tb_logger and 'debug' not in self.exp_name: 107 | if k.startswith('l_'): 108 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) 109 | else: 110 | self.tb_logger.add_scalar(k, v, current_iter) 111 | self.logger.info(message) 112 | 113 | 114 | @master_only 115 | def init_tb_logger(log_dir): 116 | from torch.utils.tensorboard import SummaryWriter 117 | tb_logger = SummaryWriter(log_dir=log_dir) 118 | return tb_logger 119 | 120 | 121 | @master_only 122 | def init_wandb_logger(opt): 123 | """We now only use wandb to sync tensorboard log.""" 124 | import wandb 125 | logger = get_root_logger() 126 | 127 | project = opt['logger']['wandb']['project'] 128 | resume_id = opt['logger']['wandb'].get('resume_id') 129 | if resume_id: 130 | wandb_id = resume_id 131 | resume = 'allow' 132 | logger.warning(f'Resume wandb logger with id={wandb_id}.') 133 | else: 134 | wandb_id = wandb.util.generate_id() 135 | resume = 'never' 136 | 137 | wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) 138 | 139 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') 140 | 141 | 142 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): 143 | """Get the root logger. 144 | 145 | The logger will be initialized if it has not been initialized. By default a 146 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 147 | also be added. 148 | 149 | Args: 150 | logger_name (str): root logger name. Default: 'basicsr'. 151 | log_file (str | None): The log filename. If specified, a FileHandler 152 | will be added to the root logger. 153 | log_level (int): The root logger level. 
142 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
143 |     """Get the root logger.
144 |
145 |     The logger will be initialized if it has not been initialized. By default a
146 |     StreamHandler will be added. If `log_file` is specified, a FileHandler will
147 |     also be added.
148 |
149 |     Args:
150 |         logger_name (str): root logger name. Default: 'basicsr'.
151 |         log_file (str | None): The log filename. If specified, a FileHandler
152 |             will be added to the root logger.
153 |         log_level (int): The root logger level. Note that only the process of
154 |             rank 0 is affected, while other processes will set the level to
155 |             "Error" and be silent most of the time.
156 |
157 |     Returns:
158 |         logging.Logger: The root logger.
159 |     """
160 |     logger = logging.getLogger(logger_name)
161 |     # if the logger has been initialized, just return it
162 |     if logger_name in initialized_logger:
163 |         return logger
164 |
165 |     format_str = '%(asctime)s %(levelname)s: %(message)s'
166 |     stream_handler = logging.StreamHandler()
167 |     stream_handler.setFormatter(logging.Formatter(format_str))
168 |     logger.addHandler(stream_handler)
169 |     logger.propagate = False
170 |     rank, _ = get_dist_info()
171 |     if rank != 0:
172 |         logger.setLevel('ERROR')
173 |     elif log_file is not None:
174 |         logger.setLevel(log_level)
175 |         # add file handler
176 |         file_handler = logging.FileHandler(log_file, 'w')
177 |         file_handler.setFormatter(logging.Formatter(format_str))
178 |         file_handler.setLevel(log_level)
179 |         logger.addHandler(file_handler)
180 |     initialized_logger[logger_name] = True
181 |     return logger
182 |
183 |
184 | def get_env_info():
185 |     """Get environment information.
186 |
187 |     Currently, only log the software version.
188 |     """
189 |     import torch
190 |     import torchvision
191 |
192 |     from basicsr.version import __version__
193 |     msg = r"""
194 | ____ _ _____ ____
195 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \
196 | / __ |/ __ `// ___// // ___/\__ \ / /_/ /
197 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
198 | /_____/ \__,_//____//_/ \___//____//_/ |_|
199 | ______ __ __ __ __
200 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
201 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
202 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
203 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
204 |     """
205 |     msg += ('\nVersion Information: '
206 |             f'\n\tBasicSR: {__version__}'
207 |             f'\n\tPyTorch: {torch.__version__}'
208 |             f'\n\tTorchVision: {torchvision.__version__}')
209 |     return msg
210 |
--------------------------------------------------------------------------------
/basicsr/ops/dcn/src/deform_conv_ext.cpp:
--------------------------------------------------------------------------------
1 | // modify from
2 | // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
3 |
4 | #include <torch/extension.h>
5 | #include <ATen/DivRtn.h>
6 |
7 | #include <cmath>
8 | #include <vector>
9 |
10 | #define WITH_CUDA // always use cuda
11 | #ifdef WITH_CUDA
12 | int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
13 |                              at::Tensor offset, at::Tensor output,
14 |                              at::Tensor columns, at::Tensor ones, int kW,
15 |                              int kH, int dW, int dH, int padW, int padH,
16 |                              int dilationW, int dilationH, int group,
17 |                              int deformable_group, int im2col_step);
18 |
19 | int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
20 |                                     at::Tensor gradOutput, at::Tensor gradInput,
21 |                                     at::Tensor gradOffset, at::Tensor weight,
22 |                                     at::Tensor columns, int kW, int kH, int dW,
23 |                                     int dH, int padW, int padH, int dilationW,
24 |                                     int dilationH, int group,
25 |                                     int deformable_group, int im2col_step);
26 |
27 | int deform_conv_backward_parameters_cuda(
28 |     at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
29 |     at::Tensor gradWeight, // at::Tensor gradBias,
30 |     at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
31 |     int padW, int padH, int dilationW, int dilationH, int group,
32 |     int deformable_group, float scale, int im2col_step);
33 |
34 | void modulated_deform_conv_cuda_forward(
35 |     at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
36 |     at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
37 |     int kernel_h, int kernel_w, const int stride_h, const int stride_w,
38 |     const int pad_h, const int pad_w, const int dilation_h,
39 |     const int dilation_w, const int group, const int deformable_group,
40 |     const bool with_bias);
41 |
42 | void modulated_deform_conv_cuda_backward(
43 |     at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
44 |     at::Tensor offset, at::Tensor mask, at::Tensor columns,
45 |     at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
46 |     at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
47 |     int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
48 |     int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
49 |     const bool with_bias);
50 | #endif
51 |
52 | int deform_conv_forward(at::Tensor input, at::Tensor weight,
53 |                         at::Tensor offset, at::Tensor output,
54 |                         at::Tensor columns, at::Tensor ones, int kW,
55 |                         int kH, int dW, int dH, int padW, int padH,
56 |                         int dilationW, int dilationH, int group,
57 |                         int deformable_group, int im2col_step) {
58 |   if (input.device().is_cuda()) {
59 | #ifdef WITH_CUDA
60 |     return deform_conv_forward_cuda(input, weight, offset, output, columns,
61 |         ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group,
62 |         deformable_group, im2col_step);
63 | #else
64 |     AT_ERROR("deform conv is not compiled with GPU support");
65 | #endif
66 |   }
67 |   AT_ERROR("deform conv is not implemented on CPU");
68 | }
69 |
70 | int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
71 |                                at::Tensor gradOutput, at::Tensor gradInput,
72 |                                at::Tensor gradOffset, at::Tensor weight,
73 |                                at::Tensor columns, int kW, int kH, int dW,
74 |                                int dH, int padW, int padH, int dilationW,
75 |                                int dilationH, int group,
76 |                                int deformable_group, int im2col_step) {
77 |   if (input.device().is_cuda()) {
78 | #ifdef WITH_CUDA
79 |     return deform_conv_backward_input_cuda(input, offset, gradOutput,
80 |         gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH,
81 |         dilationW, dilationH, group, deformable_group, im2col_step);
82 | #else
83 |     AT_ERROR("deform conv is not compiled with GPU support");
84 | #endif
85 |   }
86 |   AT_ERROR("deform conv is not implemented on CPU");
87 | }
88 |
89 | int deform_conv_backward_parameters(
90 |     at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
91 |     at::Tensor gradWeight, // at::Tensor gradBias,
92 |     at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
93 |     int padW, int padH, int dilationW, int dilationH, int group,
94 |     int deformable_group, float scale, int im2col_step) {
95 |   if (input.device().is_cuda()) {
96 | #ifdef WITH_CUDA
97 |     return deform_conv_backward_parameters_cuda(input, offset, gradOutput,
98 |         gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW,
99 |         dilationH, group, deformable_group, scale, im2col_step);
100 | #else
101 |     AT_ERROR("deform conv is not compiled with GPU support");
102 | #endif
103 |   }
104 |   AT_ERROR("deform conv is not implemented on CPU");
105 | }
106 |
107 | void modulated_deform_conv_forward(
108 |     at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
109 |     at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
110 |     int kernel_h, int kernel_w, const int stride_h, const int stride_w,
111 |     const int pad_h, const int pad_w, const int dilation_h,
112 |     const int dilation_w, const int group, const int deformable_group,
113 |     const bool with_bias) {
114 |   if (input.device().is_cuda()) {
115 | #ifdef WITH_CUDA
116 |     return modulated_deform_conv_cuda_forward(input, weight, bias, ones,
117 |         offset, mask, output, columns, kernel_h, kernel_w, stride_h,
118 |         stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
119 |         deformable_group, with_bias);
120 | #else
121 |     AT_ERROR("modulated deform conv is not compiled with GPU support");
122 | #endif
123 |   }
124 |   AT_ERROR("modulated deform conv is not implemented on CPU");
125 | }
126 |
127 | void modulated_deform_conv_backward(
128 |     at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
129 |     at::Tensor offset, at::Tensor mask, at::Tensor columns,
130 |     at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
131 |     at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
132 |     int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
133 |     int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
134 |     const bool with_bias) {
135 |   if (input.device().is_cuda()) {
136 | #ifdef WITH_CUDA
137 |     return modulated_deform_conv_cuda_backward(input, weight, bias, ones,
138 |         offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset,
139 |         grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w,
140 |         pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
141 |         with_bias);
142 | #else
143 |     AT_ERROR("modulated deform conv is not compiled with GPU support");
144 | #endif
145 |   }
146 |   AT_ERROR("modulated deform conv is not implemented on CPU");
147 | }
148 |
149 |
150 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
151 |   m.def("deform_conv_forward", &deform_conv_forward,
152 |         "deform forward");
153 |   m.def("deform_conv_backward_input", &deform_conv_backward_input,
154 |         "deform_conv_backward_input");
155 |   m.def("deform_conv_backward_parameters",
156 |         &deform_conv_backward_parameters,
157 |         "deform_conv_backward_parameters");
158 |   m.def("modulated_deform_conv_forward",
159 |         &modulated_deform_conv_forward,
160 |         "modulated deform conv forward");
161 |   m.def("modulated_deform_conv_backward",
162 |         &modulated_deform_conv_backward,
163 |         "modulated deform conv backward");
164 | }
165 |
--------------------------------------------------------------------------------
/basicsr/archs/mask_guided_jdnsr_arch.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | import math
6 | from basicsr.utils.registry import ARCH_REGISTRY
7 |
8 |
9 |
10 | class DepthwiseSeparableConv(nn.Module):
11 |     def __init__(
12 |         self,
13 |         in_channels,
14 |         out_channels,
15 |         kernel_size=3
16 |     ):
17 |         super(DepthwiseSeparableConv, self).__init__()
18 |         self.depth_conv = nn.Conv2d(in_channels, in_channels, kernel_size, 1, kernel_size // 2, groups=in_channels)  # depthwise stage keeps the channel count
19 |         self.point_conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0)  # pointwise 1x1 maps to out_channels
20 |
21 |     def forward(self, x):
22 |         x = self.depth_conv(x)
23 |         x = self.point_conv(x)
24 |         return x
25 |
26 |
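# Note on DepthwiseSeparableConv: the depthwise stage (groups=in_channels) keeps the
# channel count at in_channels, and only the 1x1 pointwise stage maps to out_channels.
# A shape check with hypothetical sizes:
#
#     conv = DepthwiseSeparableConv(in_channels=32, out_channels=64)
#     y = conv(torch.randn(1, 32, 16, 16))
#     assert y.shape == (1, 64, 16, 16)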
27 | class Fusion(nn.Module):
28 |     def __init__(
29 |         self,
30 |         num_feat
31 |     ):
32 |         super(Fusion, self).__init__()
33 |         self.Q = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
34 |         self.K = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
35 |         self.V = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
36 |
37 |
38 |     def forward(self, x, avg_ct, mask):
39 |         b, c, w, h = x.shape
40 |
41 |         q = self.Q(mask)
42 |         k = self.K(x)
43 |         v = self.V(avg_ct)
44 |
45 |         qk = torch.matmul(q.view(b, c, w*h), k.view(b, w*h, c)) / math.sqrt(w*h)  # (b, c, c) channel-affinity matrix
46 |         qk = torch.softmax(qk.view(b, c*c), dim=1).view(b, c, c)
47 |         return torch.matmul(qk, v.view(b, c, w*h)).view(b, c, w, h) + x
48 |
49 |
50 | class SAM(nn.Module):
51 |     def __init__(
52 |         self,
53 |         num_feat
54 |     ):
55 |         super(SAM, self).__init__()
56 |         self.conv_prelu_deconv = nn.Sequential(
57 |             nn.Conv2d(num_feat, num_feat, 3, 1, 1, dilation=2),
58 |             nn.PReLU(),
59 |             nn.ConvTranspose2d(num_feat, num_feat, 3, 1, 1, dilation=2)
60 |         )
61 |         self.deconv_prelu_conv = nn.Sequential(
62 |             nn.ConvTranspose2d(num_feat, num_feat, 3, 1, 1, dilation=2),
63 |             nn.PReLU(),
64 |             nn.Conv2d(num_feat, num_feat, 3, 1, 1, dilation=2)
65 |         )
66 |
67 |     def forward(self, x):
68 |         return self.conv_prelu_deconv(x) + self.deconv_prelu_conv(x) + x
69 |
70 |
71 | class CAM(nn.Module):
72 |     def __init__(
73 |         self,
74 |         num_feat,
75 |         squeeze_factor=16
76 |     ):
77 |         super(CAM, self).__init__()
78 |         self.avg_pooling = nn.AdaptiveAvgPool2d((1, 1))
79 |         self.mlp = nn.Sequential(
80 |             nn.Linear(num_feat, num_feat // squeeze_factor),
81 |             nn.PReLU(),
82 |             nn.Linear(num_feat // squeeze_factor, num_feat)
83 |         )
84 |         self.sigmoid = nn.Sigmoid()
85 |
86 |     def forward(self, x):
87 |         b, c, w, h = x.shape
88 |         shortcut = x
89 |         x = self.avg_pooling(x).view(b, c)
90 |         x = self.mlp(x)
91 |         x = self.sigmoid(x).view(b, c, 1, 1)
92 |         x = x * shortcut
93 |         return x + shortcut
94 |
95 |
96 | class CALayer(nn.Module):
97 |     def __init__(self, num_feat, reduction=16):
98 |         super(CALayer, self).__init__()
99 |         self.body = nn.Sequential(
100 |             nn.Conv2d(num_feat, num_feat // reduction, 1, 1, 0),
101 |             nn.PReLU(),
102 |             nn.Conv2d(num_feat // reduction, num_feat, 1, 1, 0),
103 |             nn.Sigmoid(),
104 |         )
105 |         self.avg = nn.AdaptiveAvgPool2d(1)
106 |
107 |     def forward(self, x):
108 |         y = self.avg(x)
109 |         y = self.body(y)
110 |         return torch.mul(x, y)
111 |
112 |
113 | class ConvPReluCAM(nn.Module):
114 |     def __init__(
115 |         self,
116 |         num_feat
117 |     ):
118 |         super(ConvPReluCAM, self).__init__()
119 |         self.f = nn.Sequential(
120 |             nn.Conv2d(num_feat, num_feat, 3, 1, 1),
121 |             nn.PReLU(),
122 |             CAM(num_feat)
123 |         )
124 |
125 |     def forward(self, x):
126 |         return self.f(x)
127 |
128 |
129 | class Block(nn.Module):
130 |     def __init__(
131 |         self,
132 |         num_feat
133 |     ):
134 |         super(Block, self).__init__()
135 |         self.block1 = ConvPReluCAM(num_feat)
136 |         self.block2 = ConvPReluCAM(num_feat)
137 |         self.block3 = ConvPReluCAM(num_feat)
138 |         self.block4 = ConvPReluCAM(num_feat)
139 |         self.sam1 = SAM(num_feat)
140 |         self.sam2 = SAM(num_feat)
141 |
142 |     def forward(self, x):
143 |         x0 = x
144 |         sam1 = self.sam1(x0)
145 |         x1 = self.block1(x0)
146 |         sam2 = self.sam2(x1)
147 |         x2 = self.block2(x1) + sam1
148 |         x3 = self.block3(x2) + sam2
149 |         x4 = self.block4(x3)
150 |         return x4
151 |
152 |
153 | class Upsample(nn.Sequential):
154 |     """Upsample module.
155 |
156 |     Args:
157 |         scale (int): Scale factor. Supported scales: 2^n and 3.
158 |         num_feat (int): Channel number of intermediate features.
159 |     """
160 |
161 |     def __init__(self, scale, num_feat):
162 |         m = []
163 |         if (scale & (scale - 1)) == 0:  # scale = 2^n
164 |             for _ in range(int(math.log(scale, 2))):
165 |                 m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
166 |                 m.append(nn.PixelShuffle(2))
167 |         elif scale == 3:
168 |             m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
169 |             m.append(nn.PixelShuffle(3))
170 |         else:
171 |             raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
172 |         super(Upsample, self).__init__(*m)
173 |
174 |
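# Note on Upsample: each conv + nn.PixelShuffle(2) stage trades a 4x channel expansion
# for a 2x spatial upscale, so scale = 2^n is realized as n such stages (and scale = 3
# as a single 9x-channel stage). A shape check with hypothetical sizes:
#
#     up = Upsample(scale=4, num_feat=64)  # two conv + shuffle stages
#     y = up(torch.randn(1, 64, 32, 32))
#     assert y.shape == (1, 64, 128, 128)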
175 | class SRHead(nn.Module):
176 |     def __init__(
177 |         self,
178 |         scale,
179 |         out_channels,
180 |         num_feat
181 |     ):
182 |         super(SRHead, self).__init__()
183 |         self.conv_before_upsample = nn.Sequential(
184 |             nn.Conv2d(num_feat, num_feat, 3, 1, 1),
185 |             nn.PReLU()
186 |         )
187 |         self.upsample = Upsample(scale, num_feat)
188 |         self.conv_last = nn.Conv2d(num_feat, out_channels, 3, 1, 1)
189 |
190 |     def forward(self, x):
191 |         x = self.conv_before_upsample(x)
192 |         features = self.upsample(x)
193 |         sr = self.conv_last(features)
194 |         return sr, features
195 |
196 |
197 | class DenoiseHead(nn.Module):
198 |     def __init__(
199 |         self,
200 |         out_channels,
201 |         num_feat
202 |     ):
203 |         super(DenoiseHead, self).__init__()
204 |         self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
205 |         self.conv2 = nn.Conv2d(num_feat, out_channels, 3, 1, 1)
206 |
207 |     def forward(self, x):
208 |         features = self.conv1(x)
209 |         denoise = self.conv2(features)
210 |         return denoise, features
211 |
212 |
213 | @ARCH_REGISTRY.register()
214 | class MaskGuidedJDNSR(nn.Module):
215 |     def __init__(
216 |         self,
217 |         in_channels=1,
218 |         out_channels=1,
219 |         scale=2,
220 |         num_feat=64,
221 |         num_block=6,
222 |         mode="jdnsr",
223 |         loop=None
224 |     ):
225 |         super(MaskGuidedJDNSR, self).__init__()
226 |         # assert mode in ["jdnsr", "segmentation"], f"Support jdnsr and segmentation mode only, do not support {mode}"
227 |         self.mode = mode
228 |         self.conv_first = nn.Conv2d(in_channels, num_feat, 3, 1, 1, bias=True)
229 |         backbone = []
230 |         if mode == "jdnsr":
231 |             # super-resolution and denoising
232 |             for _ in range(num_block):
233 |                 backbone.append(
234 |                     nn.ModuleList(
235 |                         [
236 |                             Fusion(num_feat),
237 |                             Block(num_feat)
238 |                         ]
239 |                     )
240 |                 )
241 |             self.backbone = nn.ModuleList(backbone)
242 |             self.sr_head = SRHead(scale=scale, out_channels=out_channels, num_feat=num_feat)
243 |             self.denoise_head = DenoiseHead(out_channels=out_channels, num_feat=num_feat)
244 |
245 |             self.bicubic_up = nn.Upsample(scale_factor=scale, mode="bicubic")
246 |
247 |
248 |         else:
249 |             raise NotImplementedError(f"mode {mode} is not supported")
250 |
251 |     def forward(self, x, avg_ct=None, mask=None):
252 |         if self.mode == "jdnsr":
253 |             x_origin = x
254 |             x_origin_bicubic = self.bicubic_up(x_origin)
255 |             x = self.conv_first(x)
256 |             avg_ct = self.conv_first(avg_ct)
257 |             mask = self.conv_first(mask)
258 |             for (f, b) in self.backbone:
259 |                 x = f(x, avg_ct, mask)
260 |                 x = b(x)
261 |             hrLD, hrLD_features = self.sr_head(x)
262 |             LRnd, LRnd_features = self.denoise_head(x)
263 |             hrnd1, _ = self.denoise_head(hrLD_features)
264 |             hrnd2, _ = self.sr_head(LRnd_features)
265 |
266 |             hrLD = hrLD + x_origin_bicubic
267 |             LRnd = LRnd + x_origin
268 |             hrnd1 = hrnd1 + x_origin_bicubic
269 |             hrnd2 = hrnd2 + x_origin_bicubic
270 |
271 |             return hrLD, LRnd, (hrnd1 + hrnd2) / 2
272 |
273 |
--------------------------------------------------------------------------------
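A minimal forward-pass sketch of the MaskGuidedJDNSR architecture above. The tensor
sizes are hypothetical; avg_ct and mask stand for the average-CT prior and the organ
mask that guide the Fusion blocks:

    import torch
    from basicsr.archs.mask_guided_jdnsr_arch import MaskGuidedJDNSR

    net = MaskGuidedJDNSR(in_channels=1, out_channels=1, scale=2, num_feat=64, num_block=6)
    x = torch.randn(1, 1, 64, 64)       # low-dose, low-resolution CT slice
    avg_ct = torch.randn(1, 1, 64, 64)  # average-CT guidance
    mask = torch.randn(1, 1, 64, 64)    # organ-mask guidance
    hrLD, LRnd, hrnd = net(x, avg_ct, mask)
    # hrLD: (1, 1, 128, 128) super-resolved (still noisy) branch
    # LRnd: (1, 1, 64, 64)   denoised low-resolution branch
    # hrnd: (1, 1, 128, 128) averaged joint denoising + super-resolution output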
/basicsr/utils/face_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 | import torch
5 | from skimage import transform as trans
6 |
7 | from basicsr.utils import imwrite
8 |
9 | try:
10 |     import dlib
11 | except ImportError:
12 |     print('Please install dlib before testing face restoration. ' 'Reference: https://github.com/davisking/dlib')
13 |
14 |
15 | class FaceRestorationHelper(object):
16 |     """Helper for the face restoration pipeline."""
17 |
18 |     def __init__(self, upscale_factor, face_size=512):
19 |         self.upscale_factor = upscale_factor
20 |         self.face_size = (face_size, face_size)
21 |
22 |         # standard 5 landmarks for FFHQ faces with 1024 x 1024
23 |         self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
24 |                                        [337.91089109, 488.38613861], [437.95049505, 493.51485149],
25 |                                        [513.58415842, 678.5049505]])
26 |         self.face_template = self.face_template / (1024 // face_size)
27 |         # for estimating the 2D similarity transformation
28 |         self.similarity_trans = trans.SimilarityTransform()
29 |
30 |         self.all_landmarks_5 = []
31 |         self.all_landmarks_68 = []
32 |         self.affine_matrices = []
33 |         self.inverse_affine_matrices = []
34 |         self.cropped_faces = []
35 |         self.restored_faces = []
36 |         self.save_png = True
37 |
38 |     def init_dlib(self, detection_path, landmark5_path, landmark68_path):
39 |         """Initialize the dlib detectors and predictors."""
40 |         self.face_detector = dlib.cnn_face_detection_model_v1(detection_path)
41 |         self.shape_predictor_5 = dlib.shape_predictor(landmark5_path)
42 |         self.shape_predictor_68 = dlib.shape_predictor(landmark68_path)
43 |
44 |     def free_dlib_gpu_memory(self):
45 |         del self.face_detector
46 |         del self.shape_predictor_5
47 |         del self.shape_predictor_68
48 |
49 |     def read_input_image(self, img_path):
50 |         # self.input_img is Numpy array, (h, w, c) with RGB order
51 |         self.input_img = dlib.load_rgb_image(img_path)
52 |
53 |     def detect_faces(self, img_path, upsample_num_times=1, only_keep_largest=False):
54 |         """
55 |         Args:
56 |             img_path (str): Image path.
57 |             upsample_num_times (int): Upsamples the image before running the
58 |                 face detector.
59 |
60 |         Returns:
61 |             int: Number of detected faces.
62 |         """
63 |         self.read_input_image(img_path)
64 |         det_faces = self.face_detector(self.input_img, upsample_num_times)
65 |         if len(det_faces) == 0:
66 |             print('No face detected. Try to increase upsample_num_times.')
67 |         else:
68 |             if only_keep_largest:
69 |                 print('Detected several faces; only keeping the largest.')
70 |                 face_areas = []
71 |                 for i in range(len(det_faces)):
72 |                     face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
73 |                         det_faces[i].rect.bottom() - det_faces[i].rect.top())
74 |                     face_areas.append(face_area)
75 |                 largest_idx = face_areas.index(max(face_areas))
76 |                 self.det_faces = [det_faces[largest_idx]]
77 |             else:
78 |                 self.det_faces = det_faces
79 |         return len(self.det_faces)
80 |
81 |     def get_face_landmarks_5(self):
82 |         for face in self.det_faces:
83 |             shape = self.shape_predictor_5(self.input_img, face.rect)
84 |             landmark = np.array([[part.x, part.y] for part in shape.parts()])
85 |             self.all_landmarks_5.append(landmark)
86 |         return len(self.all_landmarks_5)
87 |
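    # A minimal sketch of the intended calling order for this helper; the dlib model
    # file names below are the stock dlib models and are hypothetical here:
    #
    #     helper = FaceRestorationHelper(upscale_factor=2, face_size=512)
    #     helper.init_dlib('mmod_human_face_detector.dat',
    #                      'shape_predictor_5_face_landmarks.dat',
    #                      'shape_predictor_68_face_landmarks.dat')
    #     helper.detect_faces('input.jpg', only_keep_largest=True)
    #     helper.get_face_landmarks_5()
    #     helper.warp_crop_faces(save_cropped_path='cropped.png')
    #     # run the restoration network on helper.cropped_faces, then:
    #     # helper.add_restored_face(restored_face)
    #     # helper.paste_faces_to_input_image('restored.png')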
88 |     def get_face_landmarks_68(self):
89 |         """Get 68 dense landmarks for cropped images.
90 |
91 |         Should only have one face at most in the cropped image.
92 |         """
93 |         num_detected_face = 0
94 |         for idx, face in enumerate(self.cropped_faces):
95 |             # face detection
96 |             det_face = self.face_detector(face, 1)  # TODO: can we remove it?
97 |             if len(det_face) == 0:
98 |                 print(f'Cannot find faces in cropped image with index {idx}.')
99 |                 self.all_landmarks_68.append(None)
100 |             else:
101 |                 if len(det_face) > 1:
102 |                     print('Detected several faces in the cropped face. Use the '
103 |                           'largest one. Note that it will also cause overlap '
104 |                           'during paste_faces_to_input_image.')
105 |                     face_areas = []
106 |                     for i in range(len(det_face)):
107 |                         face_area = (det_face[i].rect.right() - det_face[i].rect.left()) * (
108 |                             det_face[i].rect.bottom() - det_face[i].rect.top())
109 |                         face_areas.append(face_area)
110 |                     largest_idx = face_areas.index(max(face_areas))
111 |                     face_rect = det_face[largest_idx].rect
112 |                 else:
113 |                     face_rect = det_face[0].rect
114 |                 shape = self.shape_predictor_68(face, face_rect)
115 |                 landmark = np.array([[part.x, part.y] for part in shape.parts()])
116 |                 self.all_landmarks_68.append(landmark)
117 |                 num_detected_face += 1
118 |
119 |         return num_detected_face
120 |
121 |     def warp_crop_faces(self, save_cropped_path=None, save_inverse_affine_path=None):
122 |         """Get affine matrix, warp and cropped faces.
123 |
124 |         Also get inverse affine matrix for post-processing.
125 |         """
126 |         for idx, landmark in enumerate(self.all_landmarks_5):
127 |             # use 5 landmarks to get affine matrix
128 |             self.similarity_trans.estimate(landmark, self.face_template)
129 |             affine_matrix = self.similarity_trans.params[0:2, :]
130 |             self.affine_matrices.append(affine_matrix)
131 |             # warp and crop faces
132 |             cropped_face = cv2.warpAffine(self.input_img, affine_matrix, self.face_size)
133 |             self.cropped_faces.append(cropped_face)
134 |             # save the cropped face
135 |             if save_cropped_path is not None:
136 |                 path, ext = os.path.splitext(save_cropped_path)
137 |                 if self.save_png:
138 |                     save_path = f'{path}_{idx:02d}.png'
139 |                 else:
140 |                     save_path = f'{path}_{idx:02d}{ext}'
141 |
142 |                 imwrite(cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path)
143 |
144 |             # get inverse affine matrix
145 |             self.similarity_trans.estimate(self.face_template, landmark * self.upscale_factor)
146 |             inverse_affine = self.similarity_trans.params[0:2, :]
147 |             self.inverse_affine_matrices.append(inverse_affine)
148 |             # save inverse affine matrices
149 |             if save_inverse_affine_path is not None:
150 |                 path, _ = os.path.splitext(save_inverse_affine_path)
151 |                 save_path = f'{path}_{idx:02d}.pth'
152 |                 torch.save(inverse_affine, save_path)
153 |
154 |     def add_restored_face(self, face):
155 |         self.restored_faces.append(face)
156 |
157 |     def paste_faces_to_input_image(self, save_path):
158 |         # operate in the BGR order
159 |         input_img = cv2.cvtColor(self.input_img, cv2.COLOR_RGB2BGR)
160 |         h, w, _ = input_img.shape
161 |         h_up, w_up = h * self.upscale_factor, w * self.upscale_factor
162 |         # simply resize the background
163 |         upsample_img = cv2.resize(input_img, (w_up, h_up))
164 |         assert len(self.restored_faces) == len(
165 |             self.inverse_affine_matrices), ('lengths of restored_faces and inverse_affine_matrices are different.')
166 |         for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
167 |             inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
168 |             mask = np.ones((*self.face_size, 3), dtype=np.float32)
169 |             inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
170 |             # remove the black borders
171 |             inv_mask_erosion = cv2.erode(inv_mask, np.ones((2 * self.upscale_factor, 2 * self.upscale_factor),
172 |                                                            np.uint8))
173 |             inv_restored_remove_border = inv_mask_erosion * inv_restored
174 |             total_face_area = np.sum(inv_mask_erosion) // 3
175 |             # compute the fusion edge based on the area of face
176 |             w_edge = int(total_face_area**0.5) // 20
177 |             erosion_radius = w_edge * 2
178 |             inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
179 |             blur_size = w_edge * 2
180 |             inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
181 |             upsample_img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * upsample_img
182 |         if self.save_png:
183 |             save_path = save_path.replace('.jpg', '.png').replace('.jpeg', '.png')
184 |         imwrite(upsample_img.astype(np.uint8), save_path)
185 |
186 |     def clean_all(self):
187 |         self.all_landmarks_5 = []
188 |         self.all_landmarks_68 = []
189 |         self.restored_faces = []
190 |         self.affine_matrices = []
191 |         self.cropped_faces = []
192 |         self.inverse_affine_matrices = []
193 |
--------------------------------------------------------------------------------